diff --git a/CMakeLists.txt b/CMakeLists.txt index 6eecca90..172d696d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -134,7 +134,6 @@ set(SUNSHINE_TARGET_FILES sunshine/stream.h sunshine/video.cpp sunshine/video.h - sunshine/thread_safe.h sunshine/input.cpp sunshine/input.h sunshine/audio.cpp @@ -147,6 +146,8 @@ set(SUNSHINE_TARGET_FILES sunshine/move_by_copy.h sunshine/task_pool.h sunshine/thread_pool.h + sunshine/thread_safe.h + sunshine/sync.h ${PLATFORM_TARGET_FILES}) include_directories( diff --git a/sunshine/network.cpp b/sunshine/network.cpp index 029b1c76..833bed67 100644 --- a/sunshine/network.cpp +++ b/sunshine/network.cpp @@ -95,10 +95,10 @@ std::string_view to_enum_string(net_e net) { } host_t host_create(ENetAddress &addr, std::size_t peers, std::uint16_t port) { - enet_address_set_host(&addr, "0.0.0.0"); + enet_address_set_host(&addr, "::"); enet_address_set_port(&addr, port); - return host_t { enet_host_create(PF_INET, &addr, peers, 1, 0, 0) }; + return host_t { enet_host_create(AF_INET6, &addr, peers, 1, 0, 0) }; } void free_host(ENetHost *host) { diff --git a/sunshine/platform/common.h b/sunshine/platform/common.h index 01e659e4..b60741b7 100644 --- a/sunshine/platform/common.h +++ b/sunshine/platform/common.h @@ -88,6 +88,7 @@ using input_t = util::safe_ptr; std::string get_mac_address(const std::string_view &address); std::string from_sockaddr(const sockaddr *const); +std::pair from_sockaddr_ex(const sockaddr *const); std::unique_ptr microphone(std::uint32_t sample_rate); std::unique_ptr display(); diff --git a/sunshine/platform/linux.cpp b/sunshine/platform/linux.cpp index 7190ff43..9f094288 100644 --- a/sunshine/platform/linux.cpp +++ b/sunshine/platform/linux.cpp @@ -391,6 +391,24 @@ std::string from_sockaddr(const sockaddr *const ip_addr) { return std::string { data }; } +std::pair from_sockaddr_ex(const sockaddr *const ip_addr) { + char data[INET6_ADDRSTRLEN]; + + auto family = ip_addr->sa_family; + std::uint16_t port; + if(family == AF_INET6) { + inet_ntop(AF_INET6, &((sockaddr_in6*)ip_addr)->sin6_addr, data, INET6_ADDRSTRLEN); + port = ((sockaddr_in6*)ip_addr)->sin6_port; + } + + if(family == AF_INET) { + inet_ntop(AF_INET, &((sockaddr_in*)ip_addr)->sin_addr, data, INET_ADDRSTRLEN); + port = ((sockaddr_in*)ip_addr)->sin_port; + } + + return { port, std::string { data } }; +} + std::string get_mac_address(const std::string_view &address) { auto ifaddrs = get_ifaddrs(); for(auto pos = ifaddrs.get(); pos != nullptr; pos = pos->ifa_next) { diff --git a/sunshine/platform/windows.cpp b/sunshine/platform/windows.cpp index 23eeaa55..b9fd87da 100755 --- a/sunshine/platform/windows.cpp +++ b/sunshine/platform/windows.cpp @@ -101,6 +101,24 @@ std::string from_sockaddr(const sockaddr *const socket_address) { return std::string { data }; } +std::pair from_sockaddr_ex(const sockaddr *const ip_addr) { + char data[INET6_ADDRSTRLEN]; + + auto family = ip_addr->sa_family; + std::uint16_t port; + if(family == AF_INET6) { + inet_ntop(AF_INET6, &((sockaddr_in6*)ip_addr)->sin6_addr, data, INET6_ADDRSTRLEN); + port = ((sockaddr_in6*)ip_addr)->sin6_port; + } + + if(family == AF_INET) { + inet_ntop(AF_INET, &((sockaddr_in*)ip_addr)->sin_addr, data, INET_ADDRSTRLEN); + port = ((sockaddr_in*)ip_addr)->sin_port; + } + + return { port, std::string { data } }; +} + adapteraddrs_t get_adapteraddrs() { adapteraddrs_t info { nullptr }; ULONG size = 0; diff --git a/sunshine/stream.cpp b/sunshine/stream.cpp index f52659db..a8ece2b5 100644 --- a/sunshine/stream.cpp +++ b/sunshine/stream.cpp @@ -20,6 +20,7 @@ extern "C" { #include "utility.h" #include "stream.h" #include "thread_safe.h" +#include "sync.h" #include "input.h" #include "main.h" @@ -89,14 +90,109 @@ using audio_packet_t = util::c_ptr; using message_queue_t = std::shared_ptr>>; using message_queue_queue_t = std::shared_ptr>>; -using session_queue_t = std::shared_ptr>>; + +static inline void while_starting_do_nothing(std::atomic &state) { + while(state.load(std::memory_order_acquire) == session::state_e::STARTING) { + std::this_thread::sleep_for(1ms); + } +} + +class control_server_t { +public: + control_server_t(control_server_t &&) noexcept = default; + control_server_t &operator=(control_server_t &&) noexcept = default; + + explicit control_server_t(std::uint16_t port) : _host { net::host_create(_addr, config::stream.channels, port) } {} + + void emplace_addr_to_session(const std::string &addr, session_t &session) { + auto lg = _map_addr_session.lock(); + + _map_addr_session->emplace(addr, std::make_pair(0u, &session)); + } + + void erase_session(session_t &session) { + auto lg = _map_addr_session.lock(); + + auto pos = std::find_if(std::begin(_map_addr_session.raw), std::end(_map_addr_session.raw), [session_p=&session](auto ¤t_port_and_session) { + return session_p == current_port_and_session.second.second; + }); + + _map_addr_session->erase(pos); + } + + // Get session associated with address. + // If none are found, try to find a session not yet claimed. (It will be marked by a port of value 0 + // If none of those are found, return nullptr + session_t *get_session(const ENetAddress &address) { + TUPLE_2D(port, addr_string, platf::from_sockaddr_ex((sockaddr*)&address.address)); + + auto lg = _map_addr_session.lock(); + TUPLE_2D(begin, end, _map_addr_session->equal_range(addr_string)); + + auto it = std::end(_map_addr_session.raw); + for(auto pos = begin; pos != end; ++pos) { + TUPLE_2D_REF(session_port, session_p, pos->second); + + if(port == session_port) { + return session_p; + } + else if(session_port == 0) { + it = pos; + } + } + + if(it != std::end(_map_addr_session.raw)) { + TUPLE_2D_REF(session_port, session_p, it->second); + session_port = port; + + return session_p; + } + + return nullptr; + } + + // Circular dependency: + // iterate refers to session + // session refers to broadcast_ctx_t + // broadcast_ctx_t refers to control_server_t + // Therefore, iterate is implemented further down the source file + void iterate(std::chrono::milliseconds timeout); + + template + void iterate(std::chrono::duration timeout) { + iterate(std::chrono::floor(timeout)); + } + + void map(uint16_t type, std::function cb) { + _map_type_cb.emplace(type, std::move(cb)); + } + + void send(const std::string_view &payload) { + std::for_each(_host->peers, _host->peers + _host->peerCount, [payload](auto &peer) { + auto packet = enet_packet_create(payload.data(), payload.size(), ENET_PACKET_FLAG_RELIABLE); + if(enet_peer_send(&peer, 0, packet)) { + enet_packet_destroy(packet); + } + }); + + enet_host_flush(_host.get()); + } + + // Callbacks + std::unordered_map> _map_type_cb; + + // Mapping ip:port to session + util::sync_t>> _map_addr_session; + + ENetAddress _addr; + net::host_t _host; +}; struct broadcast_ctx_t { video::packet_queue_t video_packets; audio::packet_queue_t audio_packets; message_queue_queue_t message_queue_queue; - session_queue_t session_queue; std::thread recv_thread; std::thread video_thread; @@ -105,8 +201,9 @@ struct broadcast_ctx_t { asio::io_service io; - udp::socket video_sock { io, udp::endpoint(udp::v4(), VIDEO_STREAM_PORT) }; - udp::socket audio_sock { io, udp::endpoint(udp::v4(), AUDIO_STREAM_PORT) }; + udp::socket video_sock { io, udp::endpoint(udp::v6(), VIDEO_STREAM_PORT) }; + udp::socket audio_sock { io, udp::endpoint(udp::v6(), AUDIO_STREAM_PORT) }; + control_server_t control_server { CONTROL_PORT }; }; struct session_t { @@ -118,6 +215,7 @@ struct session_t { std::chrono::steady_clock::time_point pingTimeout; safe::shared_t::ptr_t broadcast_ref; + udp::endpoint video_peer; udp::endpoint audio_peer; @@ -138,101 +236,57 @@ std::shared_ptr input; static auto broadcast = safe::make_shared(start_broadcast, end_broadcast); safe::signal_t broadcast_shutdown_event; -class control_server_t { -public: - control_server_t(control_server_t &&) noexcept = default; - control_server_t &operator=(control_server_t &&) noexcept = default; +void control_server_t::iterate(std::chrono::milliseconds timeout) { + ENetEvent event; + auto res = enet_host_service(_host.get(), &event, timeout.count()); - explicit control_server_t(session_queue_t session_queue, std::uint16_t port) : session_queue { session_queue }, _host { net::host_create(_addr, config::stream.channels, port) } {} + if(res > 0) { + auto session = get_session(event.peer->address); + if(!session) { + BOOST_LOG(warning) << "Rejected connection from ["sv << platf::from_sockaddr((sockaddr*)&event.peer->address.address) << "]: it's not properly set up"sv; + enet_peer_disconnect_now(event.peer, 0); - void populate_addr_to_session() { - while(session_queue->peek()) { - auto session_opt = session_queue->pop(); - if(!session_opt) { - break; - } - TUPLE_2D_REF(addr_string, session, *session_opt); - - if(session) { - _map_addr_session.try_emplace(addr_string, session).second; - } - else { - _map_addr_session.erase(addr_string); - } + return; } - } - template - void iterate(std::chrono::duration timeout) { - ENetEvent event; - auto res = enet_host_service(_host.get(), &event, std::chrono::floor(timeout).count()); + session->pingTimeout = std::chrono::steady_clock::now() + config::stream.ping_timeout; - populate_addr_to_session(); - if(res > 0) { - auto addr_string = platf::from_sockaddr((sockaddr*)&event.peer->address.address); + switch(event.type) { + case ENET_EVENT_TYPE_RECEIVE: + { + net::packet_t packet { event.packet }; - auto it = _map_addr_session.find(addr_string); - if(it == std::end(_map_addr_session)) { - BOOST_LOG(warning) << "Rejected connection from ["sv << addr_string << "]: it's not properly set up"sv; - enet_peer_disconnect_now(event.peer, 0); - - return; - } - - auto &session = it->second; - session->pingTimeout = std::chrono::steady_clock::now() + config::stream.ping_timeout; - - switch(event.type) { - case ENET_EVENT_TYPE_RECEIVE: - { - net::packet_t packet { event.packet }; - - std::uint16_t *type = (std::uint16_t *)packet->data; - std::string_view payload { (char*)packet->data + sizeof(*type), packet->dataLength - sizeof(*type) }; + std::uint16_t *type = (std::uint16_t *)packet->data; + std::string_view payload { (char*)packet->data + sizeof(*type), packet->dataLength - sizeof(*type) }; - auto cb = _map_type_cb.find(*type); - if(cb == std::end(_map_type_cb)) { - BOOST_LOG(warning) - << "type [Unknown] { "sv << util::hex(*type).to_string_view() << " }"sv << std::endl - << "---data---"sv << std::endl << util::hex_vec(payload) << std::endl << "---end data---"sv; - } - - else { - cb->second(session, payload); - } + auto cb = _map_type_cb.find(*type); + if(cb == std::end(_map_type_cb)) { + BOOST_LOG(warning) + << "type [Unknown] { "sv << util::hex(*type).to_string_view() << " }"sv << std::endl + << "---data---"sv << std::endl << util::hex_vec(payload) << std::endl << "---end data---"sv; + } + + else { + cb->second(session, payload); } - break; - case ENET_EVENT_TYPE_CONNECT: - BOOST_LOG(info) << "CLIENT CONNECTED"sv; - break; - case ENET_EVENT_TYPE_DISCONNECT: - BOOST_LOG(info) << "CLIENT DISCONNECTED"sv; - // No more clients to send video data to ^_^ - if(session->state == session::state_e::RUNNING) { - session::stop(*session); - } - break; - case ENET_EVENT_TYPE_NONE: - break; } + break; + case ENET_EVENT_TYPE_CONNECT: + BOOST_LOG(info) << "CLIENT CONNECTED"sv; + break; + case ENET_EVENT_TYPE_DISCONNECT: + BOOST_LOG(info) << "CLIENT DISCONNECTED"sv; + // No more clients to send video data to ^_^ + if(session->state == session::state_e::RUNNING) { + session::stop(*session); + } + break; + case ENET_EVENT_TYPE_NONE: + break; } } - - void map(uint16_t type, std::function cb) { - _map_type_cb.emplace(type, std::move(cb)); - } - - void send(const std::string_view &payload); - - std::unordered_map> _map_type_cb; - std::unordered_map _map_addr_session; - - session_queue_t session_queue; - - ENetAddress _addr; - net::host_t _host; -}; +} namespace fec { using rs_t = util::safe_ptr; @@ -338,29 +392,16 @@ std::vector replace(const std::string_view &original, const std::string return replaced; } -void control_server_t::send(const std::string_view & payload) { - std::for_each(_host->peers, _host->peers + _host->peerCount, [payload](auto &peer) { - auto packet = enet_packet_create(payload.data(), payload.size(), ENET_PACKET_FLAG_RELIABLE); - if(enet_peer_send(&peer, 0, packet)) { - enet_packet_destroy(packet); - } - }); - - enet_host_flush(_host.get()); -} - -void controlBroadcastThread(safe::signal_t *shutdown_event, session_queue_t session_queue) { - control_server_t server { session_queue, CONTROL_PORT }; - - server.map(packetTypes[IDX_START_A], [&](session_t *session, const std::string_view &payload) { +void controlBroadcastThread(safe::signal_t *shutdown_event, control_server_t *server) { + server->map(packetTypes[IDX_START_A], [&](session_t *session, const std::string_view &payload) { BOOST_LOG(debug) << "type [IDX_START_A]"sv; }); - server.map(packetTypes[IDX_START_B], [&](session_t *session, const std::string_view &payload) { + server->map(packetTypes[IDX_START_B], [&](session_t *session, const std::string_view &payload) { BOOST_LOG(debug) << "type [IDX_START_B]"sv; }); - server.map(packetTypes[IDX_LOSS_STATS], [&](session_t *session, const std::string_view &payload) { + server->map(packetTypes[IDX_LOSS_STATS], [&](session_t *session, const std::string_view &payload) { int32_t *stats = (int32_t*)payload.data(); auto count = stats[0]; std::chrono::milliseconds t { stats[1] }; @@ -376,7 +417,7 @@ void controlBroadcastThread(safe::signal_t *shutdown_event, session_queue_t sess << "---end stats---"; }); - server.map(packetTypes[IDX_INVALIDATE_REF_FRAMES], [&](session_t *session, const std::string_view &payload) { + server->map(packetTypes[IDX_INVALIDATE_REF_FRAMES], [&](session_t *session, const std::string_view &payload) { std::int64_t *frames = (std::int64_t *)payload.data(); auto firstFrame = frames[0]; auto lastFrame = frames[1]; @@ -389,7 +430,7 @@ void controlBroadcastThread(safe::signal_t *shutdown_event, session_queue_t sess session->idr_events->raise(std::make_pair(firstFrame, lastFrame)); }); - server.map(packetTypes[IDX_INPUT_DATA], [&](session_t *session, const std::string_view &payload) { + server->map(packetTypes[IDX_INPUT_DATA], [&](session_t *session, const std::string_view &payload) { BOOST_LOG(debug) << "type [IDX_INPUT_DATA]"sv; int32_t tagged_cipher_length = util::endian::big(*(int32_t*)payload.data()); @@ -416,11 +457,16 @@ void controlBroadcastThread(safe::signal_t *shutdown_event, session_queue_t sess }); while(!shutdown_event->peek()) { - auto now = std::chrono::steady_clock::now(); - for(auto &[addr,session] : server._map_addr_session) { - if(now > session->pingTimeout) { - BOOST_LOG(info) << addr << ": Ping Timeout"sv; - session::stop(*session); + { + auto lg = server->_map_addr_session.lock(); + + auto now = std::chrono::steady_clock::now(); + for(auto &[addr,port_session] : server->_map_addr_session.raw) { + auto session = port_session.second; + if(now > session->pingTimeout) { + BOOST_LOG(info) << addr << ": Ping Timeout"sv; + session::stop(*session); + } } } @@ -433,13 +479,13 @@ void controlBroadcastThread(safe::signal_t *shutdown_event, session_queue_t sess payload[0] = packetTypes[IDX_TERMINATION]; payload[1] = reason; - server.send(std::string_view {(char*)payload.data(), payload.size()}); + server->send(std::string_view {(char*)payload.data(), payload.size()}); shutdown_event->raise(true); continue; } - server.iterate(500ms); + server->iterate(500ms); } } @@ -650,11 +696,10 @@ int start_broadcast(broadcast_ctx_t &ctx) { ctx.video_packets = std::make_shared(); ctx.audio_packets = std::make_shared(); ctx.message_queue_queue = std::make_shared(); - ctx.session_queue = std::make_shared(); ctx.video_thread = std::thread { videoBroadcastThread, &broadcast_shutdown_event, std::ref(ctx.video_sock), ctx.video_packets }; ctx.audio_thread = std::thread { audioBroadcastThread, &broadcast_shutdown_event, std::ref(ctx.audio_sock), ctx.audio_packets }; - ctx.control_thread = std::thread { controlBroadcastThread, &broadcast_shutdown_event, ctx.session_queue }; + ctx.control_thread = std::thread { controlBroadcastThread, &broadcast_shutdown_event, &ctx.control_server }; ctx.recv_thread = std::thread { recvThread, std::ref(ctx) }; @@ -727,12 +772,9 @@ void videoThread(session_t *session, std::string addr_str) { session::stop(*session); }); - while(session->state == session::state_e::STARTING) { - std::this_thread::sleep_for(1ms); - } + while_starting_do_nothing(session->state); auto addr = asio::ip::make_address(addr_str); - auto ref = broadcast.ref(); auto port = recv_ping(ref, socket_e::video, addr, config::stream.ping_timeout); if(port < 0) { @@ -751,9 +793,7 @@ void audioThread(session_t *session, std::string addr_str) { session::stop(*session); }); - while(session->state == session::state_e::STARTING) { - std::this_thread::sleep_for(1ms); - } + while_starting_do_nothing(session->state); auto addr = asio::ip::make_address(addr_str); @@ -776,11 +816,16 @@ state_e state(session_t &session) { } void stop(session_t &session) { - session.broadcast_ref->session_queue->raise(session.video_peer.address().to_string(), nullptr); - session.shutdown_event.raise(true); + while_starting_do_nothing(session.state); auto expected = state_e::RUNNING; - session.state.compare_exchange_strong(expected, state_e::STOPPING); + auto already_stopping = !session.state.compare_exchange_strong(expected, state_e::STOPPING); + if(already_stopping) { + return; + } + + session.broadcast_ref->control_server.erase_session(session); + session.shutdown_event.raise(true); } void join(session_t &session) { @@ -792,7 +837,7 @@ void join(session_t &session) { void start(session_t &session, const std::string &addr_string) { session.broadcast_ref = broadcast.ref(); - session.broadcast_ref->session_queue->raise(addr_string, &session); + session.broadcast_ref->control_server.emplace_addr_to_session(addr_string, session); session.pingTimeout = std::chrono::steady_clock::now() + config::stream.ping_timeout; diff --git a/sunshine/sync.h b/sunshine/sync.h new file mode 100644 index 00000000..c8b21f0a --- /dev/null +++ b/sunshine/sync.h @@ -0,0 +1,112 @@ +// +// Created by loki on 16-4-19. +// + +#ifndef SUNSHINE_SYNC_H +#define SUNSHINE_SYNC_H + +#include +#include +#include + +namespace util { + +template +class sync_t { +public: + static_assert(N > 0, "sync_t should have more than zero mutexes"); + using value_type = T; + + template + std::lock_guard lock() { + return std::lock_guard { std::get(_lock) }; + } + + template + sync_t(Args&&... args) : raw {std::forward(args)... } {} + + sync_t &operator=(sync_t &&other) noexcept { + for(auto &l : _lock) { + l.lock(); + } + + for(auto &l : other._lock) { + l.lock(); + } + + raw = std::move(other.raw); + + for(auto &l : _lock) { + l.unlock(); + } + + for(auto &l : other._lock) { + l.unlock(); + } + + return *this; + } + + sync_t &operator=(sync_t &other) noexcept { + for(auto &l : _lock) { + l.lock(); + } + + for(auto &l : other._lock) { + l.lock(); + } + + raw = other.raw; + + for(auto &l : _lock) { + l.unlock(); + } + + for(auto &l : other._lock) { + l.unlock(); + } + + return *this; + } + + sync_t &operator=(const value_type &val) noexcept { + for(auto &l : _lock) { + l.lock(); + } + + raw = val; + + for(auto &l : _lock) { + l.unlock(); + } + + return *this; + } + + sync_t &operator=(value_type &&val) noexcept { + for(auto &l : _lock) { + l.lock(); + } + + raw = std::move(val); + + for(auto &l : _lock) { + l.unlock(); + } + + return *this; + } + + value_type *operator->() { + return &raw; + } + + value_type raw; +private: + std::array _lock; +}; + +} + + +#endif //T_MAN_SYNC_H