Add socket abstraction

This commit is contained in:
Maurice Heumann
2025-03-20 15:17:43 +01:00
parent 2cb14a3555
commit 4da6642123
15 changed files with 437 additions and 71 deletions

View File

@@ -0,0 +1,26 @@
#pragma once
#include <span>
#include <network/socket.hpp>
namespace network
{
struct i_socket
{
virtual ~i_socket() = default;
virtual void set_blocking(bool blocking) = 0;
virtual int get_last_error() = 0;
virtual bool is_ready(bool in_poll) = 0;
virtual bool bind(const address& addr) = 0;
virtual sent_size send(std::span<const std::byte> data) = 0;
virtual sent_size sendto(const address& destination, std::span<const std::byte> data) = 0;
virtual sent_size recv(std::span<std::byte> data) = 0;
virtual sent_size recvfrom(address& source, std::span<std::byte> data) = 0;
};
}

View File

@@ -0,0 +1,55 @@
#include "socket_factory.hpp"
#include "socket_wrapper.hpp"
namespace network
{
socket_factory::socket_factory()
{
initialize_wsa();
}
std::unique_ptr<i_socket> socket_factory::create_socket(int af, int type, int protocol)
{
return std::make_unique<socket_wrapper>(af, type, protocol);
}
int socket_factory::poll_sockets(const std::span<poll_entry> entries)
{
std::vector<pollfd> poll_data{};
poll_data.reserve(entries.size());
for (const auto& entry : entries)
{
if (!entry.s)
{
throw std::runtime_error("Bad socket given!");
}
const auto* wrapper = dynamic_cast<socket_wrapper*>(entry.s);
if (!wrapper)
{
throw std::runtime_error("Socket was not created using the given factory");
}
pollfd fd{};
fd.fd = wrapper->get().get_socket();
fd.events = entry.events;
fd.revents = entry.revents;
poll_data.push_back(fd);
}
const auto res = poll(poll_data.data(), static_cast<uint32_t>(poll_data.size()), 0);
for (size_t i = 0; i < poll_data.size() && i < entries.size(); ++i)
{
auto& entry = entries[i];
const auto& fd = poll_data[i];
entry.events = fd.events;
entry.revents = fd.revents;
}
return res;
}
}

View File

@@ -0,0 +1,24 @@
#pragma once
#include "i_socket.hpp"
#include <memory>
namespace network
{
struct poll_entry
{
i_socket* s{};
int16_t events{};
int16_t revents{};
};
struct socket_factory
{
socket_factory();
virtual ~socket_factory() = default;
virtual std::unique_ptr<i_socket> create_socket(int af, int type, int protocol);
virtual int poll_sockets(std::span<poll_entry> entries);
};
}

View File

@@ -0,0 +1,59 @@
#include "socket_wrapper.hpp"
#include <cassert>
namespace network
{
socket_wrapper::socket_wrapper(const int af, const int type, const int protocol)
: socket_(af, type, protocol)
{
}
void socket_wrapper::set_blocking(const bool blocking)
{
this->socket_.set_blocking(blocking);
}
int socket_wrapper::get_last_error()
{
return GET_SOCKET_ERROR();
}
bool socket_wrapper::is_ready(const bool in_poll)
{
return this->socket_.is_ready(in_poll);
}
bool socket_wrapper::bind(const address& addr)
{
return this->socket_.bind(addr);
}
sent_size socket_wrapper::send(const std::span<const std::byte> data)
{
return ::send(this->socket_.get_socket(), reinterpret_cast<const char*>(data.data()),
static_cast<send_size>(data.size()), 0);
}
sent_size socket_wrapper::sendto(const address& destination, const std::span<const std::byte> data)
{
return ::sendto(this->socket_.get_socket(), reinterpret_cast<const char*>(data.data()),
static_cast<send_size>(data.size()), 0, &destination.get_addr(), destination.get_size());
}
sent_size socket_wrapper::recv(std::span<std::byte> data)
{
return ::recv(this->socket_.get_socket(), reinterpret_cast<char*>(data.data()),
static_cast<send_size>(data.size()), 0);
}
sent_size socket_wrapper::recvfrom(address& source, std::span<std::byte> data)
{
auto source_length = source.get_max_size();
const auto res = ::recvfrom(this->socket_.get_socket(), reinterpret_cast<char*>(data.data()),
static_cast<send_size>(data.size()), 0, &source.get_addr(), &source_length);
assert(source.get_size() == source_length);
return res;
}
}

View File

@@ -0,0 +1,35 @@
#pragma once
#include "i_socket.hpp"
namespace network
{
class socket_wrapper : public i_socket
{
public:
socket_wrapper(int af, int type, int protocol);
~socket_wrapper() override = default;
void set_blocking(bool blocking) override;
int get_last_error() override;
bool is_ready(bool in_poll) override;
bool bind(const address& addr) override;
sent_size send(std::span<const std::byte> data) override;
sent_size sendto(const address& destination, std::span<const std::byte> data) override;
sent_size recv(std::span<std::byte> data) override;
sent_size recvfrom(address& source, std::span<std::byte> data) override;
const socket& get() const
{
return this->socket_;
}
private:
socket socket_{};
};
}