From c8c1e000a308d60cca951746548cefae83ecf0c3 Mon Sep 17 00:00:00 2001 From: momo5502 Date: Sat, 11 Jan 2025 20:56:50 +0100 Subject: [PATCH 1/4] Separate udp socket implementation from generic socket --- src/common/network/socket.cpp | 83 +++++++++++++++---------------- src/common/network/socket.hpp | 14 +++--- src/common/network/udp_socket.cpp | 37 ++++++++++++++ src/common/network/udp_socket.hpp | 22 ++++++++ 4 files changed, 106 insertions(+), 50 deletions(-) create mode 100644 src/common/network/udp_socket.cpp create mode 100644 src/common/network/udp_socket.hpp diff --git a/src/common/network/socket.cpp b/src/common/network/socket.cpp index 5b81a998..93a67d1c 100644 --- a/src/common/network/socket.cpp +++ b/src/common/network/socket.cpp @@ -1,4 +1,5 @@ #include "socket.hpp" +#include "address.hpp" #include @@ -6,11 +7,15 @@ using namespace std::literals; namespace network { - socket::socket(const int af) - : address_family_(af) + socket::socket(SOCKET s) + : socket_(s) + { + } + + socket::socket(const int af, const int type, const int protocol) { initialize_wsa(); - this->socket_ = ::socket(af, SOCK_DGRAM, IPPROTO_UDP); + this->socket_ = ::socket(af, type, protocol); if (af == AF_INET6) { @@ -36,16 +41,18 @@ namespace network { this->release(); this->socket_ = obj.socket_; - this->port_ = obj.port_; - this->address_family_ = obj.address_family_; obj.socket_ = INVALID_SOCKET; - obj.address_family_ = AF_UNSPEC; } return *this; } + socket::operator bool() const + { + return this->socket_ != INVALID_SOCKET; + } + void socket::release() { if (this->socket_ != INVALID_SOCKET) @@ -57,41 +64,7 @@ namespace network bool socket::bind_port(const address& target) { - const auto result = bind(this->socket_, &target.get_addr(), target.get_size()) == 0; - if (result) - { - this->port_ = target.get_port(); - } - - return result; - } - - bool socket::send(const address& target, const void* data, const size_t size) const - { - const auto res = sendto(this->socket_, static_cast(data), static_cast(size), 0, - &target.get_addr(), target.get_size()); - return static_cast(res) == size; - } - - bool socket::send(const address& target, const std::string& data) const - { - return this->send(target, data.data(), data.size()); - } - - bool socket::receive(address& source, std::string& data) const - { - char buffer[0x2000]; - auto len = source.get_max_size(); - - const auto result = - recvfrom(this->socket_, buffer, static_cast(sizeof(buffer)), 0, &source.get_addr(), &len); - if (result == SOCKET_ERROR) - { - return false; - } - - data.assign(buffer, buffer + result); - return true; + return bind(this->socket_, &target.get_addr(), target.get_size()) == 0; } bool socket::set_blocking(const bool blocking) @@ -156,14 +129,38 @@ namespace network return this->socket_; } + std::optional
socket::get_name() const + { + address a{}; + auto len = a.get_max_size(); + if (getsockname(this->socket_, &a.get_addr(), &len) == SOCKET_ERROR) + { + return std::nullopt; + } + + return a; + } + uint16_t socket::get_port() const { - return this->port_; + const auto address = this->get_name(); + if (!address) + { + return 0; + } + + return address->get_port(); } int socket::get_address_family() const { - return this->address_family_; + const auto address = this->get_name(); + if (!address) + { + return AF_UNSPEC; + } + + return address->get_addr().sa_family; } bool socket::sleep_sockets(const std::span& sockets, const std::chrono::milliseconds timeout) diff --git a/src/common/network/socket.hpp b/src/common/network/socket.hpp index 0900e6ed..f520ad41 100644 --- a/src/common/network/socket.hpp +++ b/src/common/network/socket.hpp @@ -4,6 +4,7 @@ #include #include +#include #ifdef _WIN32 using send_size = int; @@ -27,7 +28,9 @@ namespace network public: socket() = default; - socket(int af); + socket(SOCKET s); + + socket(int af, int type, int protocol); ~socket(); socket(const socket& obj) = delete; @@ -36,11 +39,9 @@ namespace network socket(socket&& obj) noexcept; socket& operator=(socket&& obj) noexcept; - bool bind_port(const address& target); + operator bool() const; - [[maybe_unused]] bool send(const address& target, const void* data, size_t size) const; - [[maybe_unused]] bool send(const address& target, const std::string& data) const; - bool receive(address& source, std::string& data) const; + bool bind_port(const address& target); bool set_blocking(bool blocking); static bool set_blocking(SOCKET s, bool blocking); @@ -51,6 +52,7 @@ namespace network SOCKET get_socket() const; uint16_t get_port() const; + std::optional
get_name() const; int get_address_family() const; @@ -61,8 +63,6 @@ namespace network static bool is_socket_ready(SOCKET s, bool in_poll); private: - int address_family_{AF_UNSPEC}; - uint16_t port_ = 0; SOCKET socket_ = INVALID_SOCKET; void release(); diff --git a/src/common/network/udp_socket.cpp b/src/common/network/udp_socket.cpp new file mode 100644 index 00000000..99c6e9e6 --- /dev/null +++ b/src/common/network/udp_socket.cpp @@ -0,0 +1,37 @@ +#include "udp_socket.hpp" + +namespace network +{ + udp_socket::udp_socket(const int af) + : socket(af, SOCK_DGRAM, IPPROTO_UDP) + { + } + + bool udp_socket::send(const address& target, const void* data, const size_t size) const + { + const auto res = sendto(this->get_socket(), static_cast(data), static_cast(size), 0, + &target.get_addr(), target.get_size()); + return static_cast(res) == size; + } + + bool udp_socket::send(const address& target, const std::string_view data) const + { + return this->send(target, data.data(), data.size()); + } + + bool udp_socket::receive(address& source, std::string& data) const + { + char buffer[0x2000]; + auto len = source.get_max_size(); + + const auto result = + recvfrom(this->get_socket(), buffer, static_cast(sizeof(buffer)), 0, &source.get_addr(), &len); + if (result == SOCKET_ERROR) + { + return false; + } + + data.assign(buffer, buffer + result); + return true; + } +} diff --git a/src/common/network/udp_socket.hpp b/src/common/network/udp_socket.hpp new file mode 100644 index 00000000..8bc68f70 --- /dev/null +++ b/src/common/network/udp_socket.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "socket.hpp" + +#include + +namespace network +{ + struct udp_socket : socket + { + udp_socket(int af); + udp_socket() = default; + ~udp_socket() = default; + + udp_socket(udp_socket&& obj) noexcept = default; + udp_socket& operator=(udp_socket&& obj) noexcept = default; + + [[maybe_unused]] bool send(const address& target, const void* data, size_t size) const; + [[maybe_unused]] bool send(const address& target, std::string_view data) const; + bool receive(address& source, std::string& data) const; + }; +} From 21e2f6f999bac65dd1d9de6c88f90ad0eb44c0a5 Mon Sep 17 00:00:00 2001 From: momo5502 Date: Sat, 11 Jan 2025 21:29:55 +0100 Subject: [PATCH 2/4] Prepare TCP support --- src/common/network/tcp_client_socket.cpp | 40 ++++++++++++++++++++++++ src/common/network/tcp_client_socket.hpp | 33 +++++++++++++++++++ src/common/network/tcp_server_socket.cpp | 35 +++++++++++++++++++++ src/common/network/tcp_server_socket.hpp | 25 +++++++++++++++ 4 files changed, 133 insertions(+) create mode 100644 src/common/network/tcp_client_socket.cpp create mode 100644 src/common/network/tcp_client_socket.hpp create mode 100644 src/common/network/tcp_server_socket.cpp create mode 100644 src/common/network/tcp_server_socket.hpp diff --git a/src/common/network/tcp_client_socket.cpp b/src/common/network/tcp_client_socket.cpp new file mode 100644 index 00000000..a30e4d5a --- /dev/null +++ b/src/common/network/tcp_client_socket.cpp @@ -0,0 +1,40 @@ +#include "tcp_client_socket.hpp" + +namespace network +{ + tcp_client_socket::tcp_client_socket(SOCKET s, const address& target) + : socket(s), + target_(target) + { + } + + bool tcp_client_socket::send(const void* data, const size_t size) const + { + const auto res = ::send(this->get_socket(), static_cast(data), static_cast(size), 0); + return static_cast(res) == size; + } + + bool tcp_client_socket::send(const std::string_view data) const + { + return this->send(data.data(), data.size()); + } + + bool tcp_client_socket::receive(std::string& data) const + { + char buffer[0x2000]; + + const auto result = recv(this->get_socket(), buffer, static_cast(sizeof(buffer)), 0); + if (result == SOCKET_ERROR) + { + return false; + } + + data.assign(buffer, buffer + result); + return true; + } + + address tcp_client_socket::get_target() const + { + return this->target_; + } +} diff --git a/src/common/network/tcp_client_socket.hpp b/src/common/network/tcp_client_socket.hpp new file mode 100644 index 00000000..f09bbf9a --- /dev/null +++ b/src/common/network/tcp_client_socket.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include "socket.hpp" + +#include + +namespace network +{ + class tcp_server_socket; + + class tcp_client_socket : public socket + { + // TODO: Construct and connect client! + + tcp_client_socket() = default; + ~tcp_client_socket() = default; + + tcp_client_socket(tcp_client_socket&& obj) noexcept = default; + tcp_client_socket& operator=(tcp_client_socket&& obj) noexcept = default; + + [[maybe_unused]] bool send(const void* data, size_t size) const; + [[maybe_unused]] bool send(std::string_view data) const; + bool receive(std::string& data) const; + + address get_target() const; + + private: + address target_{}; + + friend tcp_server_socket; + tcp_client_socket(SOCKET s, const address& target); + }; +} diff --git a/src/common/network/tcp_server_socket.cpp b/src/common/network/tcp_server_socket.cpp new file mode 100644 index 00000000..da9b5437 --- /dev/null +++ b/src/common/network/tcp_server_socket.cpp @@ -0,0 +1,35 @@ +#include "tcp_server_socket.hpp" + +namespace network +{ + tcp_server_socket::tcp_server_socket(const int af) + : socket(af, SOCK_STREAM, IPPROTO_TCP) + { + } + + tcp_client_socket tcp_server_socket::accept() + { + this->listen(); + + address a{}; + auto len = a.get_max_size(); + auto s = ::accept(this->get_socket(), &a.get_addr(), &len); + + if (s == INVALID_SOCKET) + { + return {}; + } + + return tcp_client_socket{s, a}; + } + + void tcp_server_socket::listen() + { + if (this->listening_) + { + return; + } + + this->listening_ = ::listen(this->get_socket(), 32) == 0; + } +} diff --git a/src/common/network/tcp_server_socket.hpp b/src/common/network/tcp_server_socket.hpp new file mode 100644 index 00000000..78289f01 --- /dev/null +++ b/src/common/network/tcp_server_socket.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include "socket.hpp" +#include "tcp_client_socket.hpp" + +namespace network +{ + class tcp_server_socket : public socket + { + tcp_server_socket(int af); + + tcp_server_socket() = default; + ~tcp_server_socket() = default; + + tcp_server_socket(tcp_server_socket&& obj) noexcept = default; + tcp_server_socket& operator=(tcp_server_socket&& obj) noexcept = default; + + tcp_client_socket accept(); + + private: + bool listening_{false}; + + void listen(); + }; +} From 8333c25f2ce1988d10e8a59191e592aab981d0ac Mon Sep 17 00:00:00 2001 From: momo5502 Date: Sun, 12 Jan 2025 08:23:47 +0100 Subject: [PATCH 3/4] Finish tcp client socket --- src/common/network/address.cpp | 6 ++-- src/common/network/address.hpp | 5 +-- src/common/network/socket.cpp | 4 +-- src/common/network/socket.hpp | 5 +-- src/common/network/tcp_client_socket.cpp | 40 +++++++++++++++++++++--- src/common/network/tcp_client_socket.hpp | 13 ++++---- src/common/network/tcp_server_socket.cpp | 9 +----- src/common/network/tcp_server_socket.hpp | 3 +- src/common/network/udp_socket.hpp | 2 +- 9 files changed, 57 insertions(+), 30 deletions(-) diff --git a/src/common/network/address.cpp b/src/common/network/address.cpp index bb116a90..2fb173ef 100644 --- a/src/common/network/address.cpp +++ b/src/common/network/address.cpp @@ -38,7 +38,7 @@ namespace network this->address_.sa_family = AF_UNSPEC; } - address::address(const std::string& addr, const std::optional& family) + address::address(const std::string_view addr, const std::optional& family) : address() { this->parse(addr, family); @@ -286,7 +286,7 @@ namespace network return is_ipv4() || is_ipv6(); } - void address::parse(std::string addr, const std::optional& family) + void address::parse(std::string_view addr, const std::optional& family) { std::optional port_value{}; @@ -298,7 +298,7 @@ namespace network addr = addr.substr(0, pos); } - this->resolve(addr, family); + this->resolve(std::string(addr), family); if (port_value) { diff --git a/src/common/network/address.hpp b/src/common/network/address.hpp index f8b6c7df..aa434850 100644 --- a/src/common/network/address.hpp +++ b/src/common/network/address.hpp @@ -26,6 +26,7 @@ #endif #include +#include #include #include @@ -42,7 +43,7 @@ namespace network { public: address(); - address(const std::string& addr, const std::optional& family = {}); + address(std::string_view addr, const std::optional& family = std::nullopt); address(const sockaddr_in& addr); address(const sockaddr_in6& addr); address(const sockaddr* addr, socklen_t length); @@ -91,7 +92,7 @@ namespace network sockaddr_storage storage_; }; - void parse(std::string addr, const std::optional& family = {}); + void parse(std::string_view addr, const std::optional& family = {}); void resolve(const std::string& hostname, const std::optional& family = {}); }; } diff --git a/src/common/network/socket.cpp b/src/common/network/socket.cpp index 93a67d1c..88f6d2cd 100644 --- a/src/common/network/socket.cpp +++ b/src/common/network/socket.cpp @@ -62,9 +62,9 @@ namespace network } } - bool socket::bind_port(const address& target) + bool socket::bind(const address& target) { - return bind(this->socket_, &target.get_addr(), target.get_size()) == 0; + return ::bind(this->socket_, &target.get_addr(), target.get_size()) == 0; } bool socket::set_blocking(const bool blocking) diff --git a/src/common/network/socket.hpp b/src/common/network/socket.hpp index f520ad41..4696fb88 100644 --- a/src/common/network/socket.hpp +++ b/src/common/network/socket.hpp @@ -11,6 +11,7 @@ using send_size = int; #define GET_SOCKET_ERROR() (WSAGetLastError()) #define poll WSAPoll #define SOCK_WOULDBLOCK WSAEWOULDBLOCK +#define SHUT_RDWR SD_BOTH #else using SOCKET = int; using send_size = size_t; @@ -31,7 +32,7 @@ namespace network socket(SOCKET s); socket(int af, int type, int protocol); - ~socket(); + virtual ~socket(); socket(const socket& obj) = delete; socket& operator=(const socket& obj) = delete; @@ -41,7 +42,7 @@ namespace network operator bool() const; - bool bind_port(const address& target); + bool bind(const address& target); bool set_blocking(bool blocking); static bool set_blocking(SOCKET s, bool blocking); diff --git a/src/common/network/tcp_client_socket.cpp b/src/common/network/tcp_client_socket.cpp index a30e4d5a..83953d38 100644 --- a/src/common/network/tcp_client_socket.cpp +++ b/src/common/network/tcp_client_socket.cpp @@ -2,12 +2,24 @@ namespace network { - tcp_client_socket::tcp_client_socket(SOCKET s, const address& target) - : socket(s), - target_(target) + tcp_client_socket::tcp_client_socket(const int af) + : socket(af, SOCK_STREAM, IPPROTO_TCP) { } + tcp_client_socket::tcp_client_socket(SOCKET s) + : socket(s) + { + } + + tcp_client_socket::~tcp_client_socket() + { + if (*this && this->get_target()) + { + ::shutdown(this->get_socket(), SHUT_RDWR); + } + } + bool tcp_client_socket::send(const void* data, const size_t size) const { const auto res = ::send(this->get_socket(), static_cast(data), static_cast(size), 0); @@ -33,8 +45,26 @@ namespace network return true; } - address tcp_client_socket::get_target() const + std::optional
tcp_client_socket::get_target() const { - return this->target_; + address a{}; + auto len = a.get_max_size(); + if (getpeername(this->get_socket(), &a.get_addr(), &len) == SOCKET_ERROR) + { + return std::nullopt; + } + + return a; + } + + bool tcp_client_socket::connect(const address& target) + { + if (::connect(this->get_socket(), &target.get_addr(), target.get_size()) != SOCKET_ERROR) + { + return true; + } + + const auto error = GET_SOCKET_ERROR(); + return error == SOCK_WOULDBLOCK; } } diff --git a/src/common/network/tcp_client_socket.hpp b/src/common/network/tcp_client_socket.hpp index f09bbf9a..8d297215 100644 --- a/src/common/network/tcp_client_socket.hpp +++ b/src/common/network/tcp_client_socket.hpp @@ -10,10 +10,11 @@ namespace network class tcp_client_socket : public socket { - // TODO: Construct and connect client! + public: + tcp_client_socket(int af); tcp_client_socket() = default; - ~tcp_client_socket() = default; + ~tcp_client_socket() override; tcp_client_socket(tcp_client_socket&& obj) noexcept = default; tcp_client_socket& operator=(tcp_client_socket&& obj) noexcept = default; @@ -22,12 +23,12 @@ namespace network [[maybe_unused]] bool send(std::string_view data) const; bool receive(std::string& data) const; - address get_target() const; + std::optional
get_target() const; + + bool connect(const address& target); private: - address target_{}; - friend tcp_server_socket; - tcp_client_socket(SOCKET s, const address& target); + tcp_client_socket(SOCKET s); }; } diff --git a/src/common/network/tcp_server_socket.cpp b/src/common/network/tcp_server_socket.cpp index da9b5437..806bb875 100644 --- a/src/common/network/tcp_server_socket.cpp +++ b/src/common/network/tcp_server_socket.cpp @@ -13,14 +13,7 @@ namespace network address a{}; auto len = a.get_max_size(); - auto s = ::accept(this->get_socket(), &a.get_addr(), &len); - - if (s == INVALID_SOCKET) - { - return {}; - } - - return tcp_client_socket{s, a}; + return ::accept(this->get_socket(), &a.get_addr(), &len); } void tcp_server_socket::listen() diff --git a/src/common/network/tcp_server_socket.hpp b/src/common/network/tcp_server_socket.hpp index 78289f01..3d62fb1f 100644 --- a/src/common/network/tcp_server_socket.hpp +++ b/src/common/network/tcp_server_socket.hpp @@ -7,10 +7,11 @@ namespace network { class tcp_server_socket : public socket { + public: tcp_server_socket(int af); tcp_server_socket() = default; - ~tcp_server_socket() = default; + ~tcp_server_socket() override = default; tcp_server_socket(tcp_server_socket&& obj) noexcept = default; tcp_server_socket& operator=(tcp_server_socket&& obj) noexcept = default; diff --git a/src/common/network/udp_socket.hpp b/src/common/network/udp_socket.hpp index 8bc68f70..2e240728 100644 --- a/src/common/network/udp_socket.hpp +++ b/src/common/network/udp_socket.hpp @@ -10,7 +10,7 @@ namespace network { udp_socket(int af); udp_socket() = default; - ~udp_socket() = default; + ~udp_socket() override = default; udp_socket(udp_socket&& obj) noexcept = default; udp_socket& operator=(udp_socket&& obj) noexcept = default; From dd226bd45a5b151cba22733e4aabd2df048ad0ab Mon Sep 17 00:00:00 2001 From: momo5502 Date: Sun, 12 Jan 2025 08:43:34 +0100 Subject: [PATCH 4/4] Fix compilation --- src/common/network/tcp_client_socket.cpp | 6 +++++- src/common/network/tcp_client_socket.hpp | 2 +- src/common/network/tcp_server_socket.cpp | 8 +++++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/common/network/tcp_client_socket.cpp b/src/common/network/tcp_client_socket.cpp index 83953d38..9739b394 100644 --- a/src/common/network/tcp_client_socket.cpp +++ b/src/common/network/tcp_client_socket.cpp @@ -1,5 +1,7 @@ #include "tcp_client_socket.hpp" +#include + namespace network { tcp_client_socket::tcp_client_socket(const int af) @@ -7,9 +9,11 @@ namespace network { } - tcp_client_socket::tcp_client_socket(SOCKET s) + tcp_client_socket::tcp_client_socket(SOCKET s, const address& target) : socket(s) { + (void)target; + assert(this->get_target() == target); } tcp_client_socket::~tcp_client_socket() diff --git a/src/common/network/tcp_client_socket.hpp b/src/common/network/tcp_client_socket.hpp index 8d297215..e5b586a1 100644 --- a/src/common/network/tcp_client_socket.hpp +++ b/src/common/network/tcp_client_socket.hpp @@ -29,6 +29,6 @@ namespace network private: friend tcp_server_socket; - tcp_client_socket(SOCKET s); + tcp_client_socket(SOCKET s, const address& target); }; } diff --git a/src/common/network/tcp_server_socket.cpp b/src/common/network/tcp_server_socket.cpp index 806bb875..026f5952 100644 --- a/src/common/network/tcp_server_socket.cpp +++ b/src/common/network/tcp_server_socket.cpp @@ -13,7 +13,13 @@ namespace network address a{}; auto len = a.get_max_size(); - return ::accept(this->get_socket(), &a.get_addr(), &len); + const auto s = ::accept(this->get_socket(), &a.get_addr(), &len); + if (s == INVALID_SOCKET) + { + return {}; + } + + return {s, a}; } void tcp_server_socket::listen()