diff --git a/src/common/network/address.cpp b/src/common/network/address.cpp index bb116a90..2fb173ef 100644 --- a/src/common/network/address.cpp +++ b/src/common/network/address.cpp @@ -38,7 +38,7 @@ namespace network this->address_.sa_family = AF_UNSPEC; } - address::address(const std::string& addr, const std::optional& family) + address::address(const std::string_view addr, const std::optional& family) : address() { this->parse(addr, family); @@ -286,7 +286,7 @@ namespace network return is_ipv4() || is_ipv6(); } - void address::parse(std::string addr, const std::optional& family) + void address::parse(std::string_view addr, const std::optional& family) { std::optional port_value{}; @@ -298,7 +298,7 @@ namespace network addr = addr.substr(0, pos); } - this->resolve(addr, family); + this->resolve(std::string(addr), family); if (port_value) { diff --git a/src/common/network/address.hpp b/src/common/network/address.hpp index f8b6c7df..aa434850 100644 --- a/src/common/network/address.hpp +++ b/src/common/network/address.hpp @@ -26,6 +26,7 @@ #endif #include +#include #include #include @@ -42,7 +43,7 @@ namespace network { public: address(); - address(const std::string& addr, const std::optional& family = {}); + address(std::string_view addr, const std::optional& family = std::nullopt); address(const sockaddr_in& addr); address(const sockaddr_in6& addr); address(const sockaddr* addr, socklen_t length); @@ -91,7 +92,7 @@ namespace network sockaddr_storage storage_; }; - void parse(std::string addr, const std::optional& family = {}); + void parse(std::string_view addr, const std::optional& family = {}); void resolve(const std::string& hostname, const std::optional& family = {}); }; } diff --git a/src/common/network/socket.cpp b/src/common/network/socket.cpp index 5b81a998..88f6d2cd 100644 --- a/src/common/network/socket.cpp +++ b/src/common/network/socket.cpp @@ -1,4 +1,5 @@ #include "socket.hpp" +#include "address.hpp" #include @@ -6,11 +7,15 @@ using namespace std::literals; namespace network { - socket::socket(const int af) - : address_family_(af) + socket::socket(SOCKET s) + : socket_(s) + { + } + + socket::socket(const int af, const int type, const int protocol) { initialize_wsa(); - this->socket_ = ::socket(af, SOCK_DGRAM, IPPROTO_UDP); + this->socket_ = ::socket(af, type, protocol); if (af == AF_INET6) { @@ -36,16 +41,18 @@ namespace network { this->release(); this->socket_ = obj.socket_; - this->port_ = obj.port_; - this->address_family_ = obj.address_family_; obj.socket_ = INVALID_SOCKET; - obj.address_family_ = AF_UNSPEC; } return *this; } + socket::operator bool() const + { + return this->socket_ != INVALID_SOCKET; + } + void socket::release() { if (this->socket_ != INVALID_SOCKET) @@ -55,43 +62,9 @@ namespace network } } - bool socket::bind_port(const address& target) + bool socket::bind(const address& target) { - const auto result = bind(this->socket_, &target.get_addr(), target.get_size()) == 0; - if (result) - { - this->port_ = target.get_port(); - } - - return result; - } - - bool socket::send(const address& target, const void* data, const size_t size) const - { - const auto res = sendto(this->socket_, static_cast(data), static_cast(size), 0, - &target.get_addr(), target.get_size()); - return static_cast(res) == size; - } - - bool socket::send(const address& target, const std::string& data) const - { - return this->send(target, data.data(), data.size()); - } - - bool socket::receive(address& source, std::string& data) const - { - char buffer[0x2000]; - auto len = source.get_max_size(); - - const auto result = - recvfrom(this->socket_, buffer, static_cast(sizeof(buffer)), 0, &source.get_addr(), &len); - if (result == SOCKET_ERROR) - { - return false; - } - - data.assign(buffer, buffer + result); - return true; + return ::bind(this->socket_, &target.get_addr(), target.get_size()) == 0; } bool socket::set_blocking(const bool blocking) @@ -156,14 +129,38 @@ namespace network return this->socket_; } + std::optional
socket::get_name() const + { + address a{}; + auto len = a.get_max_size(); + if (getsockname(this->socket_, &a.get_addr(), &len) == SOCKET_ERROR) + { + return std::nullopt; + } + + return a; + } + uint16_t socket::get_port() const { - return this->port_; + const auto address = this->get_name(); + if (!address) + { + return 0; + } + + return address->get_port(); } int socket::get_address_family() const { - return this->address_family_; + const auto address = this->get_name(); + if (!address) + { + return AF_UNSPEC; + } + + return address->get_addr().sa_family; } bool socket::sleep_sockets(const std::span& sockets, const std::chrono::milliseconds timeout) diff --git a/src/common/network/socket.hpp b/src/common/network/socket.hpp index 0900e6ed..4696fb88 100644 --- a/src/common/network/socket.hpp +++ b/src/common/network/socket.hpp @@ -4,12 +4,14 @@ #include #include +#include #ifdef _WIN32 using send_size = int; #define GET_SOCKET_ERROR() (WSAGetLastError()) #define poll WSAPoll #define SOCK_WOULDBLOCK WSAEWOULDBLOCK +#define SHUT_RDWR SD_BOTH #else using SOCKET = int; using send_size = size_t; @@ -27,8 +29,10 @@ namespace network public: socket() = default; - socket(int af); - ~socket(); + socket(SOCKET s); + + socket(int af, int type, int protocol); + virtual ~socket(); socket(const socket& obj) = delete; socket& operator=(const socket& obj) = delete; @@ -36,11 +40,9 @@ namespace network socket(socket&& obj) noexcept; socket& operator=(socket&& obj) noexcept; - bool bind_port(const address& target); + operator bool() const; - [[maybe_unused]] bool send(const address& target, const void* data, size_t size) const; - [[maybe_unused]] bool send(const address& target, const std::string& data) const; - bool receive(address& source, std::string& data) const; + bool bind(const address& target); bool set_blocking(bool blocking); static bool set_blocking(SOCKET s, bool blocking); @@ -51,6 +53,7 @@ namespace network SOCKET get_socket() const; uint16_t get_port() const; + std::optional
get_name() const; int get_address_family() const; @@ -61,8 +64,6 @@ namespace network static bool is_socket_ready(SOCKET s, bool in_poll); private: - int address_family_{AF_UNSPEC}; - uint16_t port_ = 0; SOCKET socket_ = INVALID_SOCKET; void release(); diff --git a/src/common/network/tcp_client_socket.cpp b/src/common/network/tcp_client_socket.cpp new file mode 100644 index 00000000..9739b394 --- /dev/null +++ b/src/common/network/tcp_client_socket.cpp @@ -0,0 +1,74 @@ +#include "tcp_client_socket.hpp" + +#include + +namespace network +{ + tcp_client_socket::tcp_client_socket(const int af) + : socket(af, SOCK_STREAM, IPPROTO_TCP) + { + } + + tcp_client_socket::tcp_client_socket(SOCKET s, const address& target) + : socket(s) + { + (void)target; + assert(this->get_target() == target); + } + + tcp_client_socket::~tcp_client_socket() + { + if (*this && this->get_target()) + { + ::shutdown(this->get_socket(), SHUT_RDWR); + } + } + + bool tcp_client_socket::send(const void* data, const size_t size) const + { + const auto res = ::send(this->get_socket(), static_cast(data), static_cast(size), 0); + return static_cast(res) == size; + } + + bool tcp_client_socket::send(const std::string_view data) const + { + return this->send(data.data(), data.size()); + } + + bool tcp_client_socket::receive(std::string& data) const + { + char buffer[0x2000]; + + const auto result = recv(this->get_socket(), buffer, static_cast(sizeof(buffer)), 0); + if (result == SOCKET_ERROR) + { + return false; + } + + data.assign(buffer, buffer + result); + return true; + } + + std::optional
tcp_client_socket::get_target() const + { + address a{}; + auto len = a.get_max_size(); + if (getpeername(this->get_socket(), &a.get_addr(), &len) == SOCKET_ERROR) + { + return std::nullopt; + } + + return a; + } + + bool tcp_client_socket::connect(const address& target) + { + if (::connect(this->get_socket(), &target.get_addr(), target.get_size()) != SOCKET_ERROR) + { + return true; + } + + const auto error = GET_SOCKET_ERROR(); + return error == SOCK_WOULDBLOCK; + } +} diff --git a/src/common/network/tcp_client_socket.hpp b/src/common/network/tcp_client_socket.hpp new file mode 100644 index 00000000..e5b586a1 --- /dev/null +++ b/src/common/network/tcp_client_socket.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include "socket.hpp" + +#include + +namespace network +{ + class tcp_server_socket; + + class tcp_client_socket : public socket + { + public: + tcp_client_socket(int af); + + tcp_client_socket() = default; + ~tcp_client_socket() override; + + tcp_client_socket(tcp_client_socket&& obj) noexcept = default; + tcp_client_socket& operator=(tcp_client_socket&& obj) noexcept = default; + + [[maybe_unused]] bool send(const void* data, size_t size) const; + [[maybe_unused]] bool send(std::string_view data) const; + bool receive(std::string& data) const; + + std::optional
get_target() const; + + bool connect(const address& target); + + private: + friend tcp_server_socket; + tcp_client_socket(SOCKET s, const address& target); + }; +} diff --git a/src/common/network/tcp_server_socket.cpp b/src/common/network/tcp_server_socket.cpp new file mode 100644 index 00000000..026f5952 --- /dev/null +++ b/src/common/network/tcp_server_socket.cpp @@ -0,0 +1,34 @@ +#include "tcp_server_socket.hpp" + +namespace network +{ + tcp_server_socket::tcp_server_socket(const int af) + : socket(af, SOCK_STREAM, IPPROTO_TCP) + { + } + + tcp_client_socket tcp_server_socket::accept() + { + this->listen(); + + address a{}; + auto len = a.get_max_size(); + const auto s = ::accept(this->get_socket(), &a.get_addr(), &len); + if (s == INVALID_SOCKET) + { + return {}; + } + + return {s, a}; + } + + void tcp_server_socket::listen() + { + if (this->listening_) + { + return; + } + + this->listening_ = ::listen(this->get_socket(), 32) == 0; + } +} diff --git a/src/common/network/tcp_server_socket.hpp b/src/common/network/tcp_server_socket.hpp new file mode 100644 index 00000000..3d62fb1f --- /dev/null +++ b/src/common/network/tcp_server_socket.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include "socket.hpp" +#include "tcp_client_socket.hpp" + +namespace network +{ + class tcp_server_socket : public socket + { + public: + tcp_server_socket(int af); + + tcp_server_socket() = default; + ~tcp_server_socket() override = default; + + tcp_server_socket(tcp_server_socket&& obj) noexcept = default; + tcp_server_socket& operator=(tcp_server_socket&& obj) noexcept = default; + + tcp_client_socket accept(); + + private: + bool listening_{false}; + + void listen(); + }; +} diff --git a/src/common/network/udp_socket.cpp b/src/common/network/udp_socket.cpp new file mode 100644 index 00000000..99c6e9e6 --- /dev/null +++ b/src/common/network/udp_socket.cpp @@ -0,0 +1,37 @@ +#include "udp_socket.hpp" + +namespace network +{ + udp_socket::udp_socket(const int af) + : socket(af, SOCK_DGRAM, IPPROTO_UDP) + { + } + + bool udp_socket::send(const address& target, const void* data, const size_t size) const + { + const auto res = sendto(this->get_socket(), static_cast(data), static_cast(size), 0, + &target.get_addr(), target.get_size()); + return static_cast(res) == size; + } + + bool udp_socket::send(const address& target, const std::string_view data) const + { + return this->send(target, data.data(), data.size()); + } + + bool udp_socket::receive(address& source, std::string& data) const + { + char buffer[0x2000]; + auto len = source.get_max_size(); + + const auto result = + recvfrom(this->get_socket(), buffer, static_cast(sizeof(buffer)), 0, &source.get_addr(), &len); + if (result == SOCKET_ERROR) + { + return false; + } + + data.assign(buffer, buffer + result); + return true; + } +} diff --git a/src/common/network/udp_socket.hpp b/src/common/network/udp_socket.hpp new file mode 100644 index 00000000..2e240728 --- /dev/null +++ b/src/common/network/udp_socket.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "socket.hpp" + +#include + +namespace network +{ + struct udp_socket : socket + { + udp_socket(int af); + udp_socket() = default; + ~udp_socket() override = default; + + udp_socket(udp_socket&& obj) noexcept = default; + udp_socket& operator=(udp_socket&& obj) noexcept = default; + + [[maybe_unused]] bool send(const address& target, const void* data, size_t size) const; + [[maybe_unused]] bool send(const address& target, std::string_view data) const; + bool receive(address& source, std::string& data) const; + }; +}