diff --git a/src/common/network/socket.cpp b/src/common/network/socket.cpp index 5b81a998..93a67d1c 100644 --- a/src/common/network/socket.cpp +++ b/src/common/network/socket.cpp @@ -1,4 +1,5 @@ #include "socket.hpp" +#include "address.hpp" #include @@ -6,11 +7,15 @@ using namespace std::literals; namespace network { - socket::socket(const int af) - : address_family_(af) + socket::socket(SOCKET s) + : socket_(s) + { + } + + socket::socket(const int af, const int type, const int protocol) { initialize_wsa(); - this->socket_ = ::socket(af, SOCK_DGRAM, IPPROTO_UDP); + this->socket_ = ::socket(af, type, protocol); if (af == AF_INET6) { @@ -36,16 +41,18 @@ namespace network { this->release(); this->socket_ = obj.socket_; - this->port_ = obj.port_; - this->address_family_ = obj.address_family_; obj.socket_ = INVALID_SOCKET; - obj.address_family_ = AF_UNSPEC; } return *this; } + socket::operator bool() const + { + return this->socket_ != INVALID_SOCKET; + } + void socket::release() { if (this->socket_ != INVALID_SOCKET) @@ -57,41 +64,7 @@ namespace network bool socket::bind_port(const address& target) { - const auto result = bind(this->socket_, &target.get_addr(), target.get_size()) == 0; - if (result) - { - this->port_ = target.get_port(); - } - - return result; - } - - bool socket::send(const address& target, const void* data, const size_t size) const - { - const auto res = sendto(this->socket_, static_cast(data), static_cast(size), 0, - &target.get_addr(), target.get_size()); - return static_cast(res) == size; - } - - bool socket::send(const address& target, const std::string& data) const - { - return this->send(target, data.data(), data.size()); - } - - bool socket::receive(address& source, std::string& data) const - { - char buffer[0x2000]; - auto len = source.get_max_size(); - - const auto result = - recvfrom(this->socket_, buffer, static_cast(sizeof(buffer)), 0, &source.get_addr(), &len); - if (result == SOCKET_ERROR) - { - return false; - } - - data.assign(buffer, buffer + result); - return true; + return bind(this->socket_, &target.get_addr(), target.get_size()) == 0; } bool socket::set_blocking(const bool blocking) @@ -156,14 +129,38 @@ namespace network return this->socket_; } + std::optional
socket::get_name() const + { + address a{}; + auto len = a.get_max_size(); + if (getsockname(this->socket_, &a.get_addr(), &len) == SOCKET_ERROR) + { + return std::nullopt; + } + + return a; + } + uint16_t socket::get_port() const { - return this->port_; + const auto address = this->get_name(); + if (!address) + { + return 0; + } + + return address->get_port(); } int socket::get_address_family() const { - return this->address_family_; + const auto address = this->get_name(); + if (!address) + { + return AF_UNSPEC; + } + + return address->get_addr().sa_family; } bool socket::sleep_sockets(const std::span& sockets, const std::chrono::milliseconds timeout) diff --git a/src/common/network/socket.hpp b/src/common/network/socket.hpp index 0900e6ed..f520ad41 100644 --- a/src/common/network/socket.hpp +++ b/src/common/network/socket.hpp @@ -4,6 +4,7 @@ #include #include +#include #ifdef _WIN32 using send_size = int; @@ -27,7 +28,9 @@ namespace network public: socket() = default; - socket(int af); + socket(SOCKET s); + + socket(int af, int type, int protocol); ~socket(); socket(const socket& obj) = delete; @@ -36,11 +39,9 @@ namespace network socket(socket&& obj) noexcept; socket& operator=(socket&& obj) noexcept; - bool bind_port(const address& target); + operator bool() const; - [[maybe_unused]] bool send(const address& target, const void* data, size_t size) const; - [[maybe_unused]] bool send(const address& target, const std::string& data) const; - bool receive(address& source, std::string& data) const; + bool bind_port(const address& target); bool set_blocking(bool blocking); static bool set_blocking(SOCKET s, bool blocking); @@ -51,6 +52,7 @@ namespace network SOCKET get_socket() const; uint16_t get_port() const; + std::optional
get_name() const; int get_address_family() const; @@ -61,8 +63,6 @@ namespace network static bool is_socket_ready(SOCKET s, bool in_poll); private: - int address_family_{AF_UNSPEC}; - uint16_t port_ = 0; SOCKET socket_ = INVALID_SOCKET; void release(); diff --git a/src/common/network/udp_socket.cpp b/src/common/network/udp_socket.cpp new file mode 100644 index 00000000..99c6e9e6 --- /dev/null +++ b/src/common/network/udp_socket.cpp @@ -0,0 +1,37 @@ +#include "udp_socket.hpp" + +namespace network +{ + udp_socket::udp_socket(const int af) + : socket(af, SOCK_DGRAM, IPPROTO_UDP) + { + } + + bool udp_socket::send(const address& target, const void* data, const size_t size) const + { + const auto res = sendto(this->get_socket(), static_cast(data), static_cast(size), 0, + &target.get_addr(), target.get_size()); + return static_cast(res) == size; + } + + bool udp_socket::send(const address& target, const std::string_view data) const + { + return this->send(target, data.data(), data.size()); + } + + bool udp_socket::receive(address& source, std::string& data) const + { + char buffer[0x2000]; + auto len = source.get_max_size(); + + const auto result = + recvfrom(this->get_socket(), buffer, static_cast(sizeof(buffer)), 0, &source.get_addr(), &len); + if (result == SOCKET_ERROR) + { + return false; + } + + data.assign(buffer, buffer + result); + return true; + } +} diff --git a/src/common/network/udp_socket.hpp b/src/common/network/udp_socket.hpp new file mode 100644 index 00000000..8bc68f70 --- /dev/null +++ b/src/common/network/udp_socket.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "socket.hpp" + +#include + +namespace network +{ + struct udp_socket : socket + { + udp_socket(int af); + udp_socket() = default; + ~udp_socket() = default; + + udp_socket(udp_socket&& obj) noexcept = default; + udp_socket& operator=(udp_socket&& obj) noexcept = default; + + [[maybe_unused]] bool send(const address& target, const void* data, size_t size) const; + [[maybe_unused]] bool send(const address& target, std::string_view data) const; + bool receive(address& source, std::string& data) const; + }; +}