diff --git a/src/common/network/address.cpp b/src/common/network/address.cpp index bb116a90..2fb173ef 100644 --- a/src/common/network/address.cpp +++ b/src/common/network/address.cpp @@ -38,7 +38,7 @@ namespace network this->address_.sa_family = AF_UNSPEC; } - address::address(const std::string& addr, const std::optional& family) + address::address(const std::string_view addr, const std::optional& family) : address() { this->parse(addr, family); @@ -286,7 +286,7 @@ namespace network return is_ipv4() || is_ipv6(); } - void address::parse(std::string addr, const std::optional& family) + void address::parse(std::string_view addr, const std::optional& family) { std::optional port_value{}; @@ -298,7 +298,7 @@ namespace network addr = addr.substr(0, pos); } - this->resolve(addr, family); + this->resolve(std::string(addr), family); if (port_value) { diff --git a/src/common/network/address.hpp b/src/common/network/address.hpp index f8b6c7df..aa434850 100644 --- a/src/common/network/address.hpp +++ b/src/common/network/address.hpp @@ -26,6 +26,7 @@ #endif #include +#include #include #include @@ -42,7 +43,7 @@ namespace network { public: address(); - address(const std::string& addr, const std::optional& family = {}); + address(std::string_view addr, const std::optional& family = std::nullopt); address(const sockaddr_in& addr); address(const sockaddr_in6& addr); address(const sockaddr* addr, socklen_t length); @@ -91,7 +92,7 @@ namespace network sockaddr_storage storage_; }; - void parse(std::string addr, const std::optional& family = {}); + void parse(std::string_view addr, const std::optional& family = {}); void resolve(const std::string& hostname, const std::optional& family = {}); }; } diff --git a/src/common/network/socket.cpp b/src/common/network/socket.cpp index 93a67d1c..88f6d2cd 100644 --- a/src/common/network/socket.cpp +++ b/src/common/network/socket.cpp @@ -62,9 +62,9 @@ namespace network } } - bool socket::bind_port(const address& target) + bool socket::bind(const address& target) { - return bind(this->socket_, &target.get_addr(), target.get_size()) == 0; + return ::bind(this->socket_, &target.get_addr(), target.get_size()) == 0; } bool socket::set_blocking(const bool blocking) diff --git a/src/common/network/socket.hpp b/src/common/network/socket.hpp index f520ad41..4696fb88 100644 --- a/src/common/network/socket.hpp +++ b/src/common/network/socket.hpp @@ -11,6 +11,7 @@ using send_size = int; #define GET_SOCKET_ERROR() (WSAGetLastError()) #define poll WSAPoll #define SOCK_WOULDBLOCK WSAEWOULDBLOCK +#define SHUT_RDWR SD_BOTH #else using SOCKET = int; using send_size = size_t; @@ -31,7 +32,7 @@ namespace network socket(SOCKET s); socket(int af, int type, int protocol); - ~socket(); + virtual ~socket(); socket(const socket& obj) = delete; socket& operator=(const socket& obj) = delete; @@ -41,7 +42,7 @@ namespace network operator bool() const; - bool bind_port(const address& target); + bool bind(const address& target); bool set_blocking(bool blocking); static bool set_blocking(SOCKET s, bool blocking); diff --git a/src/common/network/tcp_client_socket.cpp b/src/common/network/tcp_client_socket.cpp index a30e4d5a..83953d38 100644 --- a/src/common/network/tcp_client_socket.cpp +++ b/src/common/network/tcp_client_socket.cpp @@ -2,12 +2,24 @@ namespace network { - tcp_client_socket::tcp_client_socket(SOCKET s, const address& target) - : socket(s), - target_(target) + tcp_client_socket::tcp_client_socket(const int af) + : socket(af, SOCK_STREAM, IPPROTO_TCP) { } + tcp_client_socket::tcp_client_socket(SOCKET s) + : socket(s) + { + } + + tcp_client_socket::~tcp_client_socket() + { + if (*this && this->get_target()) + { + ::shutdown(this->get_socket(), SHUT_RDWR); + } + } + bool tcp_client_socket::send(const void* data, const size_t size) const { const auto res = ::send(this->get_socket(), static_cast(data), static_cast(size), 0); @@ -33,8 +45,26 @@ namespace network return true; } - address tcp_client_socket::get_target() const + std::optional
tcp_client_socket::get_target() const { - return this->target_; + address a{}; + auto len = a.get_max_size(); + if (getpeername(this->get_socket(), &a.get_addr(), &len) == SOCKET_ERROR) + { + return std::nullopt; + } + + return a; + } + + bool tcp_client_socket::connect(const address& target) + { + if (::connect(this->get_socket(), &target.get_addr(), target.get_size()) != SOCKET_ERROR) + { + return true; + } + + const auto error = GET_SOCKET_ERROR(); + return error == SOCK_WOULDBLOCK; } } diff --git a/src/common/network/tcp_client_socket.hpp b/src/common/network/tcp_client_socket.hpp index f09bbf9a..8d297215 100644 --- a/src/common/network/tcp_client_socket.hpp +++ b/src/common/network/tcp_client_socket.hpp @@ -10,10 +10,11 @@ namespace network class tcp_client_socket : public socket { - // TODO: Construct and connect client! + public: + tcp_client_socket(int af); tcp_client_socket() = default; - ~tcp_client_socket() = default; + ~tcp_client_socket() override; tcp_client_socket(tcp_client_socket&& obj) noexcept = default; tcp_client_socket& operator=(tcp_client_socket&& obj) noexcept = default; @@ -22,12 +23,12 @@ namespace network [[maybe_unused]] bool send(std::string_view data) const; bool receive(std::string& data) const; - address get_target() const; + std::optional
get_target() const; + + bool connect(const address& target); private: - address target_{}; - friend tcp_server_socket; - tcp_client_socket(SOCKET s, const address& target); + tcp_client_socket(SOCKET s); }; } diff --git a/src/common/network/tcp_server_socket.cpp b/src/common/network/tcp_server_socket.cpp index da9b5437..806bb875 100644 --- a/src/common/network/tcp_server_socket.cpp +++ b/src/common/network/tcp_server_socket.cpp @@ -13,14 +13,7 @@ namespace network address a{}; auto len = a.get_max_size(); - auto s = ::accept(this->get_socket(), &a.get_addr(), &len); - - if (s == INVALID_SOCKET) - { - return {}; - } - - return tcp_client_socket{s, a}; + return ::accept(this->get_socket(), &a.get_addr(), &len); } void tcp_server_socket::listen() diff --git a/src/common/network/tcp_server_socket.hpp b/src/common/network/tcp_server_socket.hpp index 78289f01..3d62fb1f 100644 --- a/src/common/network/tcp_server_socket.hpp +++ b/src/common/network/tcp_server_socket.hpp @@ -7,10 +7,11 @@ namespace network { class tcp_server_socket : public socket { + public: tcp_server_socket(int af); tcp_server_socket() = default; - ~tcp_server_socket() = default; + ~tcp_server_socket() override = default; tcp_server_socket(tcp_server_socket&& obj) noexcept = default; tcp_server_socket& operator=(tcp_server_socket&& obj) noexcept = default; diff --git a/src/common/network/udp_socket.hpp b/src/common/network/udp_socket.hpp index 8bc68f70..2e240728 100644 --- a/src/common/network/udp_socket.hpp +++ b/src/common/network/udp_socket.hpp @@ -10,7 +10,7 @@ namespace network { udp_socket(int af); udp_socket() = default; - ~udp_socket() = default; + ~udp_socket() override = default; udp_socket(udp_socket&& obj) noexcept = default; udp_socket& operator=(udp_socket&& obj) noexcept = default;