Handle listen/accept/send/receive in afd_endpoint

This commit is contained in:
Igor Pissolati
2025-05-20 00:56:22 -03:00
parent 0f4cc3365c
commit f5ed0752e3
9 changed files with 412 additions and 11 deletions

View File

@@ -259,11 +259,6 @@ namespace
socket_events |= POLLRDNORM;
}
if (poll_events & AFD_POLL_RECEIVE_EXPEDITED)
{
socket_events |= POLLRDNORM;
}
if (poll_events & AFD_POLL_RECEIVE_EXPEDITED)
{
socket_events |= POLLRDBAND;
@@ -277,13 +272,13 @@ namespace
return socket_events;
}
ULONG map_socket_response_events_to_afd(const int16_t socket_events)
ULONG map_socket_response_events_to_afd(const int16_t socket_events, const bool is_listener)
{
ULONG afd_events = 0;
if (socket_events & POLLRDNORM)
{
afd_events |= (AFD_POLL_ACCEPT | AFD_POLL_RECEIVE);
afd_events |= is_listener ? AFD_POLL_ACCEPT : AFD_POLL_RECEIVE;
}
if (socket_events & POLLRDBAND)
@@ -350,11 +345,10 @@ namespace
}
auto entry = handle_info_obj.read(i);
entry.PollEvents = map_socket_response_events_to_afd(pfd.revents);
entry.PollEvents = map_socket_response_events_to_afd(pfd.revents, pfd.s->is_listening());
entry.Status = STATUS_SUCCESS;
handle_info_obj.write(entry, current_index++);
break;
}
assert(current_index == static_cast<size_t>(count));
@@ -369,6 +363,12 @@ namespace
struct afd_endpoint : io_device
{
struct pending_connection
{
network::address remote_address;
std::unique_ptr<network::i_socket> accepted_socket;
};
std::unique_ptr<network::i_socket> s_{};
bool executing_delayed_ioctl_{};
@@ -376,6 +376,8 @@ namespace
std::optional<bool> require_poll_{};
std::optional<io_device_context> delayed_ioctl_{};
std::optional<std::chrono::steady_clock::time_point> timeout_{};
std::unordered_map<LONG, pending_connection> pending_connections_{};
LONG next_sequence_{0};
afd_endpoint()
{
@@ -501,14 +503,32 @@ namespace
return STATUS_NOT_SUPPORTED;
}
win_emu.log.print(color::dark_gray, "--> AFD IOCTL: %X\n", c.io_control_code);
if (this->delayed_ioctl_)
{
if (auto* e = win_emu.process.events.get(c.event))
{
e->signaled = false;
}
}
const auto request = _AFD_REQUEST(c.io_control_code);
win_emu.log.print(color::dark_gray, "--> AFD IOCTL: %X (%X)\n", c.io_control_code, request);
switch (request)
{
case AFD_BIND:
return this->ioctl_bind(win_emu, c);
case AFD_START_LISTEN:
return this->ioctl_listen(win_emu, c);
case AFD_WAIT_FOR_LISTEN:
return this->ioctl_wait_for_listen(win_emu, c);
case AFD_ACCEPT:
return this->ioctl_accept(win_emu, c);
case AFD_SEND:
return this->ioctl_send(win_emu, c);
case AFD_RECEIVE:
return this->ioctl_receive(win_emu, c);
case AFD_SEND_DATAGRAM:
return this->ioctl_send_datagram(win_emu, c);
case AFD_RECEIVE_DATAGRAM:
@@ -517,9 +537,11 @@ namespace
return this->ioctl_poll(win_emu, c);
case AFD_SET_CONTEXT:
case AFD_GET_INFORMATION:
case AFD_SET_INFORMATION:
case AFD_QUERY_HANDLES:
return STATUS_SUCCESS;
default:
win_emu.log.print(color::gray, "Unsupported AFD IOCTL: %X\n", c.io_control_code);
win_emu.log.print(color::gray, "Unsupported AFD IOCTL: %X (%X)\n", c.io_control_code, request);
return STATUS_NOT_SUPPORTED;
}
}
@@ -550,6 +572,274 @@ namespace
return STATUS_SUCCESS;
}
NTSTATUS ioctl_listen(windows_emulator& win_emu, const io_device_context& c) const
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
if (c.input_buffer_length < sizeof(AFD_LISTEN_INFO))
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto listen_info = win_emu.emu().read_memory<AFD_LISTEN_INFO>(c.input_buffer);
if (!this->s_->listen(static_cast<int>(listen_info.MaximumConnectionQueue)))
{
return STATUS_INVALID_PARAMETER;
}
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = 0;
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_wait_for_listen(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
if (c.output_buffer_length < sizeof(AFD_LISTEN_RESPONSE_INFO))
{
return STATUS_BUFFER_TOO_SMALL;
}
network::address remote_address{};
auto accepted_socket_ptr = this->s_->accept(remote_address);
if (!accepted_socket_ptr)
{
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, true);
return STATUS_PENDING;
}
return STATUS_UNSUCCESSFUL;
}
if (!remote_address.is_ipv4())
{
throw std::runtime_error("Unsupported address family");
}
pending_connection pending{};
pending.remote_address = remote_address;
pending.accepted_socket = std::move(accepted_socket_ptr);
LONG sequence = next_sequence_++;
pending_connections_.try_emplace(sequence, std::move(pending));
AFD_LISTEN_RESPONSE_INFO response{};
response.Sequence = sequence;
auto transport_buffer = convert_to_win_address(win_emu, remote_address);
memcpy(&response.RemoteAddress, transport_buffer.data(), sizeof(win_sockaddr));
win_emu.emu().write_memory<AFD_LISTEN_RESPONSE_INFO>(c.output_buffer, response);
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = sizeof(AFD_LISTEN_RESPONSE_INFO);
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_accept(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
if (c.input_buffer_length < sizeof(AFD_ACCEPT_INFO))
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto accept_info = win_emu.emu().read_memory<AFD_ACCEPT_INFO>(c.input_buffer);
const auto it = pending_connections_.find(accept_info.Sequence);
if (it == pending_connections_.end())
{
return STATUS_INVALID_PARAMETER;
}
auto& accepted_socket = it->second.accepted_socket;
auto* target_device = win_emu.process.devices.get(accept_info.AcceptHandle);
if (!target_device)
{
return STATUS_INVALID_HANDLE;
}
auto* target_endpoint = target_device->get_internal_device<afd_endpoint>();
if (!target_endpoint)
{
return STATUS_INVALID_HANDLE;
}
target_endpoint->s_ = std::move(accepted_socket);
pending_connections_.erase(it);
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = 0;
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_receive(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
auto& emu = win_emu.emu();
if (c.input_buffer_length < sizeof(AFD_RECV_INFO<EmulatorTraits<Emu64>>))
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto receive_info = emu.read_memory<AFD_RECV_INFO<EmulatorTraits<Emu64>>>(c.input_buffer);
if (!receive_info.BufferArray || receive_info.BufferCount == 0)
{
return STATUS_INVALID_PARAMETER;
}
if (receive_info.BufferCount > 1)
{
// TODO: Scatter/Gather
return STATUS_NOT_SUPPORTED;
}
const auto wsabuf = emu.read_memory<EMU_WSABUF<EmulatorTraits<Emu64>>>(receive_info.BufferArray);
if (!wsabuf.buf || wsabuf.len == 0)
{
return STATUS_INVALID_PARAMETER;
}
std::vector<std::byte> host_buffer;
host_buffer.resize(wsabuf.len);
const auto bytes_received = this->s_->recv(host_buffer);
if (bytes_received < 0)
{
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, true);
return STATUS_PENDING;
}
if (error == SERR(ECONNRESET))
return STATUS_CONNECTION_RESET;
return STATUS_UNSUCCESSFUL;
}
emu.write_memory(wsabuf.buf, host_buffer.data(), static_cast<size_t>(bytes_received));
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = static_cast<ULONG_PTR>(bytes_received);
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_send(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
auto& emu = win_emu.emu();
if (c.input_buffer_length < sizeof(AFD_SEND_INFO<EmulatorTraits<Emu64>>))
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto send_info = emu.read_memory<AFD_SEND_INFO<EmulatorTraits<Emu64>>>(c.input_buffer);
if (!send_info.BufferArray || send_info.BufferCount == 0)
{
return STATUS_INVALID_PARAMETER;
}
if (send_info.BufferCount > 1)
{
// TODO: Scatter/Gather
return STATUS_NOT_SUPPORTED;
}
const auto wsabuf = emu.read_memory<EMU_WSABUF<EmulatorTraits<Emu64>>>(send_info.BufferArray);
if (!wsabuf.buf || wsabuf.len == 0)
{
return STATUS_INVALID_PARAMETER;
}
std::vector<std::byte> host_buffer;
host_buffer.resize(wsabuf.len);
emu.read_memory(wsabuf.buf, host_buffer.data(), host_buffer.size());
const auto bytes_sent = this->s_->send(host_buffer);
if (bytes_sent < 0)
{
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, false);
return STATUS_PENDING;
}
if (error == SERR(ECONNRESET))
return STATUS_CONNECTION_RESET;
return STATUS_UNSUCCESSFUL;
}
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = static_cast<ULONG_PTR>(bytes_sent);
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
static std::vector<network::i_socket*> resolve_endpoints(windows_emulator& win_emu,
const std::span<const AFD_POLL_HANDLE_INFO64> handles)
{