Add socket abstraction

This commit is contained in:
Maurice Heumann
2025-03-20 15:17:43 +01:00
parent 2cb14a3555
commit 4da6642123
15 changed files with 437 additions and 71 deletions

View File

@@ -3,6 +3,7 @@
#include "afd_types.hpp"
#include "../windows_emulator.hpp"
#include "../network/socket_factory.hpp"
#include <network/address.hpp>
#include <network/socket.hpp>
@@ -313,10 +314,10 @@ namespace
}
NTSTATUS perform_poll(windows_emulator& win_emu, const io_device_context& c,
const std::span<const SOCKET> endpoints,
const std::span<network::i_socket* const> endpoints,
const std::span<const AFD_POLL_HANDLE_INFO64> handles)
{
std::vector<pollfd> poll_data{};
std::vector<network::poll_entry> poll_data{};
poll_data.resize(endpoints.size());
for (size_t i = 0; i < endpoints.size() && i < handles.size(); ++i)
@@ -324,12 +325,12 @@ namespace
auto& pfd = poll_data.at(i);
const auto& handle = handles[i];
pfd.fd = endpoints[i];
pfd.s = endpoints[i];
pfd.events = map_afd_request_events_to_socket(handle.PollEvents);
pfd.revents = pfd.events;
}
const auto count = poll(poll_data.data(), static_cast<uint32_t>(poll_data.size()), 0);
const auto count = win_emu.socket_factory().poll_sockets(poll_data);
if (count <= 0)
{
return STATUS_PENDING;
@@ -358,17 +359,20 @@ namespace
assert(current_index == static_cast<size_t>(count));
emulator_object<AFD_POLL_INFO64>{win_emu.emu(), c.input_buffer}.access(
[&](AFD_POLL_INFO64& info) { info.NumberOfHandles = static_cast<ULONG>(current_index); });
const emulator_object<AFD_POLL_INFO64> info_obj{win_emu.emu(), c.input_buffer};
info_obj.access([&](AFD_POLL_INFO64& info) {
info.NumberOfHandles = static_cast<ULONG>(current_index); //
});
return STATUS_SUCCESS;
}
struct afd_endpoint : io_device
{
std::unique_ptr<network::i_socket> s_{};
bool executing_delayed_ioctl_{};
std::optional<afd_creation_data> creation_data{};
std::optional<SOCKET> s_{};
std::optional<bool> require_poll_{};
std::optional<io_device_context> delayed_ioctl_{};
std::optional<std::chrono::steady_clock::time_point> timeout_{};
@@ -381,21 +385,15 @@ namespace
afd_endpoint(afd_endpoint&&) = delete;
afd_endpoint& operator=(afd_endpoint&&) = delete;
~afd_endpoint() override
{
if (this->s_)
{
closesocket(*this->s_);
}
}
~afd_endpoint() override = default;
void create(windows_emulator& win_emu, const io_device_creation_data& data) override
{
this->creation_data = get_creation_data(win_emu, data);
this->setup();
this->setup(win_emu.socket_factory());
}
void setup()
void setup(network::socket_factory& factory)
{
if (!this->creation_data)
{
@@ -408,15 +406,13 @@ namespace
const auto type = translate_win_to_host_type(data.type);
const auto protocol = translate_win_to_host_protocol(data.protocol);
const auto sock = socket(af, type, protocol);
if (sock == INVALID_SOCKET)
this->s_ = factory.create_socket(af, type, protocol);
if (!this->s_)
{
throw std::runtime_error("Failed to create socket!");
}
network::socket::set_blocking(sock, false);
this->s_ = sock;
this->s_->set_blocking(false);
}
void delay_ioctrl(const io_device_context& c,
@@ -452,7 +448,7 @@ namespace
if (this->require_poll_.has_value())
{
const auto is_ready = network::socket::is_socket_ready(*this->s_, *this->require_poll_);
const auto is_ready = this->s_->is_ready(*this->require_poll_);
if (!is_ready)
{
return;
@@ -482,7 +478,7 @@ namespace
void deserialize_object(utils::buffer_deserializer& buffer) override
{
buffer.read_optional(this->creation_data);
this->setup();
this->setup(buffer.read<socket_factory_wrapper>());
buffer.read_optional(this->require_poll_);
buffer.read_optional(this->delayed_ioctl_);
@@ -546,7 +542,7 @@ namespace
const auto addr = convert_to_host_address(win_emu, std::span(data).subspan(address_offset));
if (bind(*this->s_, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR)
if (!this->s_->bind(addr))
{
return STATUS_ADDRESS_ALREADY_ASSOCIATED;
}
@@ -554,12 +550,12 @@ namespace
return STATUS_SUCCESS;
}
static std::vector<SOCKET> resolve_endpoints(windows_emulator& win_emu,
const std::span<const AFD_POLL_HANDLE_INFO64> handles)
static std::vector<network::i_socket*> resolve_endpoints(windows_emulator& win_emu,
const std::span<const AFD_POLL_HANDLE_INFO64> handles)
{
auto& proc = win_emu.process;
std::vector<SOCKET> endpoints{};
std::vector<network::i_socket*> endpoints{};
endpoints.reserve(handles.size());
for (const auto& handle : handles)
@@ -576,7 +572,7 @@ namespace
throw std::runtime_error("Invalid AFD endpoint!");
}
endpoints.push_back(*endpoint->s_);
endpoints.push_back(endpoint->s_.get());
}
return endpoints;
@@ -635,17 +631,14 @@ namespace
}
network::address from{};
auto from_length = from.get_max_size();
std::vector<char> data{};
std::vector<std::byte> data{};
data.resize(buffer.len);
const auto recevied_data = recvfrom(*this->s_, data.data(), static_cast<send_size>(data.size()), 0,
&from.get_addr(), &from_length);
const auto recevied_data = this->s_->recvfrom(from, data);
if (recevied_data < 0)
{
const auto error = GET_SOCKET_ERROR();
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, true);
@@ -655,8 +648,6 @@ namespace
return STATUS_UNSUCCESSFUL;
}
assert(from.get_size() == from_length);
const auto data_size = std::min(data.size(), static_cast<size_t>(recevied_data));
emu.write_memory(buffer.buf, data.data(), data_size);
@@ -704,13 +695,10 @@ namespace
const auto target = convert_to_host_address(win_emu, address_buffer);
const auto data = emu.read_memory(buffer.buf, buffer.len);
const auto sent_data =
sendto(*this->s_, reinterpret_cast<const char*>(data.data()), static_cast<send_size>(data.size()),
0 /* TODO */, &target.get_addr(), target.get_size());
const auto sent_data = this->s_->sendto(target, data);
if (sent_data < 0)
{
const auto error = GET_SOCKET_ERROR();
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, false);