Handle listen/accept/send/receive in afd_endpoint (#246)

This PR attempts to add support for `listen`, `accept`, `send`, and
`recv` to afd_endpoint.

The changes in this PR allow me to get a simple HTTP server running in
the emulator, but I'm still likely missing something as I run into a
mapping violation on `select` when performing chunked downloads from the
HTTP server.

Also, there are still some things I'm still uncertain about:

- Where the `delayed_ioctl_` event should be reset? It wasn't being
reset previously, which caused issues. In this PR, I'm resetting it with
every IO operation.
- ~~Is `AFD_LISTEN_RESPONSE_INFO` correct? The struct Windows expects
appears to be 20 bytes long, so I had to remove some fields from the
structure.~~
- ~~`pragma pack(push, 1)` seems to be necessary, but it doesn't seem to
be used in the other structures, maybe they fit neatly in the
alignment?~~
This commit is contained in:
Maurice Heumann
2025-05-21 18:12:50 +02:00
committed by GitHub
8 changed files with 465 additions and 25 deletions

View File

@@ -43,6 +43,7 @@ using NTSTATUS = std::uint32_t;
#define STATUS_FILE_IS_A_DIRECTORY ((NTSTATUS)0xC00000BAL)
#define STATUS_NOT_SUPPORTED ((NTSTATUS)0xC00000BBL)
#define STATUS_INVALID_ADDRESS ((NTSTATUS)0xC0000141L)
#define STATUS_CONNECTION_RESET ((NTSTATUS)0xC000020DL)
#define STATUS_NOT_FOUND ((NTSTATUS)0xC0000225L)
#define STATUS_CONNECTION_REFUSED ((NTSTATUS)0xC0000236L)
#define STATUS_TIMER_RESOLUTION_NOT_SET ((NTSTATUS)0xC0000245L)

View File

@@ -225,7 +225,7 @@ namespace
const io_device_context& c)
{
constexpr auto info_size = offsetof(AFD_POLL_INFO64, Handles);
if (!c.input_buffer || c.input_buffer_length < info_size)
if (!c.input_buffer || c.input_buffer_length < info_size || c.input_buffer != c.output_buffer)
{
throw std::runtime_error("Bad AFD poll data");
}
@@ -254,12 +254,7 @@ namespace
{
int16_t socket_events{};
if (poll_events & (AFD_POLL_ACCEPT | AFD_POLL_RECEIVE))
{
socket_events |= POLLRDNORM;
}
if (poll_events & AFD_POLL_RECEIVE_EXPEDITED)
if (poll_events & (AFD_POLL_DISCONNECT | AFD_POLL_ACCEPT | AFD_POLL_RECEIVE))
{
socket_events |= POLLRDNORM;
}
@@ -269,7 +264,7 @@ namespace
socket_events |= POLLRDBAND;
}
if (poll_events & (AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND))
if (poll_events & (AFD_POLL_CONNECT | AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND))
{
socket_events |= POLLWRNORM;
}
@@ -277,13 +272,14 @@ namespace
return socket_events;
}
ULONG map_socket_response_events_to_afd(const int16_t socket_events)
ULONG map_socket_response_events_to_afd(const int16_t socket_events, const ULONG afd_poll_events,
const bool is_listener)
{
ULONG afd_events = 0;
if (socket_events & POLLRDNORM)
{
afd_events |= (AFD_POLL_ACCEPT | AFD_POLL_RECEIVE);
afd_events |= is_listener ? AFD_POLL_ACCEPT : AFD_POLL_RECEIVE;
}
if (socket_events & POLLRDBAND)
@@ -293,19 +289,20 @@ namespace
if (socket_events & POLLWRNORM)
{
afd_events |= (AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND);
afd_events |= AFD_POLL_SEND;
}
if ((socket_events & (POLLHUP | POLLERR)) == (POLLHUP | POLLERR))
if ((socket_events & (POLLHUP | POLLERR)) == (POLLHUP | POLLERR) &&
afd_poll_events & (AFD_POLL_CONNECT_FAIL | AFD_POLL_ABORT))
{
afd_events |= (AFD_POLL_CONNECT_FAIL | AFD_POLL_ABORT);
}
else if (socket_events & POLLHUP)
else if (socket_events & POLLHUP && afd_poll_events & AFD_POLL_DISCONNECT)
{
afd_events |= AFD_POLL_DISCONNECT;
}
if (socket_events & POLLNVAL)
if (socket_events & POLLNVAL && afd_poll_events & AFD_POLL_LOCAL_CLOSE)
{
afd_events |= AFD_POLL_LOCAL_CLOSE;
}
@@ -349,12 +346,13 @@ namespace
continue;
}
const auto& handle = handles[i];
auto entry = handle_info_obj.read(i);
entry.PollEvents = map_socket_response_events_to_afd(pfd.revents);
entry.PollEvents = map_socket_response_events_to_afd(pfd.revents, handle.PollEvents, pfd.s->is_listening());
entry.Status = STATUS_SUCCESS;
handle_info_obj.write(entry, current_index++);
break;
}
assert(current_index == static_cast<size_t>(count));
@@ -364,11 +362,24 @@ namespace
info.NumberOfHandles = static_cast<ULONG>(current_index); //
});
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Information = info_size + sizeof(AFD_POLL_HANDLE_INFO64) * current_index;
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
struct afd_endpoint : io_device
{
struct pending_connection
{
network::address remote_address;
std::unique_ptr<network::i_socket> accepted_socket;
};
std::unique_ptr<network::i_socket> s_{};
bool executing_delayed_ioctl_{};
@@ -376,6 +387,9 @@ namespace
std::optional<bool> require_poll_{};
std::optional<io_device_context> delayed_ioctl_{};
std::optional<std::chrono::steady_clock::time_point> timeout_{};
std::optional<std::function<void(const io_device_context&)>> timeout_callback_{};
std::unordered_map<LONG, pending_connection> pending_connections_{};
LONG next_sequence_{0};
afd_endpoint()
{
@@ -415,15 +429,16 @@ namespace
this->s_->set_blocking(false);
}
void delay_ioctrl(const io_device_context& c,
void delay_ioctrl(const io_device_context& c, const std::optional<bool> require_poll = {},
const std::optional<std::chrono::steady_clock::time_point> timeout = {},
const std::optional<bool> require_poll = {})
const std::optional<std::function<void(const io_device_context&)>>& timeout_callback = {})
{
if (this->executing_delayed_ioctl_)
{
return;
}
this->timeout_callback_ = timeout_callback;
this->timeout_ = timeout;
this->require_poll_ = require_poll;
this->delayed_ioctl_ = c;
@@ -431,6 +446,7 @@ namespace
void clear_pending_state()
{
this->timeout_callback_ = {};
this->timeout_ = {};
this->require_poll_ = {};
this->delayed_ioctl_ = {};
@@ -464,6 +480,11 @@ namespace
}
write_io_status(this->delayed_ioctl_->io_status_block, STATUS_TIMEOUT);
if (this->timeout_callback_)
{
(*this->timeout_callback_)(*this->delayed_ioctl_);
}
}
auto* e = win_emu.process.events.get(this->delayed_ioctl_->event);
@@ -497,18 +518,30 @@ namespace
{
if (_AFD_BASE(c.io_control_code) != FSCTL_AFD_BASE)
{
win_emu.log.print(color::cyan, "Bad AFD IOCTL: %X\n", c.io_control_code);
win_emu.log.print(color::cyan, "Bad AFD IOCTL: 0x%X\n", c.io_control_code);
return STATUS_NOT_SUPPORTED;
}
win_emu.log.print(color::dark_gray, "--> AFD IOCTL: %X\n", c.io_control_code);
const auto request = _AFD_REQUEST(c.io_control_code);
win_emu.log.print(color::dark_gray, "--> AFD IOCTL: 0x%X (%d)\n", c.io_control_code, request);
switch (request)
{
case AFD_BIND:
return this->ioctl_bind(win_emu, c);
case AFD_CONNECT:
return this->ioctl_connect(win_emu, c);
case AFD_START_LISTEN:
return this->ioctl_listen(win_emu, c);
case AFD_WAIT_FOR_LISTEN:
return this->ioctl_wait_for_listen(win_emu, c);
case AFD_ACCEPT:
return this->ioctl_accept(win_emu, c);
case AFD_SEND:
return this->ioctl_send(win_emu, c);
case AFD_RECEIVE:
return this->ioctl_receive(win_emu, c);
case AFD_SEND_DATAGRAM:
return this->ioctl_send_datagram(win_emu, c);
case AFD_RECEIVE_DATAGRAM:
@@ -517,13 +550,48 @@ namespace
return this->ioctl_poll(win_emu, c);
case AFD_SET_CONTEXT:
case AFD_GET_INFORMATION:
case AFD_SET_INFORMATION:
case AFD_QUERY_HANDLES:
case AFD_TRANSPORT_IOCTL:
return STATUS_SUCCESS;
default:
win_emu.log.print(color::gray, "Unsupported AFD IOCTL: %X\n", c.io_control_code);
win_emu.log.print(color::gray, "Unsupported AFD IOCTL: 0x%X (%d)\n", c.io_control_code, request);
return STATUS_NOT_SUPPORTED;
}
}
NTSTATUS ioctl_connect(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
auto data = win_emu.emu().read_memory(c.input_buffer, c.input_buffer_length);
constexpr auto address_offset = 12;
if (data.size() < address_offset)
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto addr = convert_to_host_address(win_emu, std::span(data).subspan(address_offset));
if (!this->s_->connect(addr))
{
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, true);
return STATUS_PENDING;
}
return STATUS_UNSUCCESSFUL;
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_bind(windows_emulator& win_emu, const io_device_context& c) const
{
if (!this->s_)
@@ -550,6 +618,259 @@ namespace
return STATUS_SUCCESS;
}
NTSTATUS ioctl_listen(windows_emulator& win_emu, const io_device_context& c) const
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
if (c.input_buffer_length < sizeof(AFD_LISTEN_INFO))
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto listen_info = win_emu.emu().read_memory<AFD_LISTEN_INFO>(c.input_buffer);
if (!this->s_->listen(static_cast<int>(listen_info.MaximumConnectionQueue)))
{
return STATUS_INVALID_PARAMETER;
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_wait_for_listen(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
if (c.output_buffer_length < sizeof(AFD_LISTEN_RESPONSE_INFO))
{
return STATUS_BUFFER_TOO_SMALL;
}
network::address remote_address{};
auto accepted_socket_ptr = this->s_->accept(remote_address);
if (!accepted_socket_ptr)
{
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, true);
return STATUS_PENDING;
}
return STATUS_UNSUCCESSFUL;
}
if (!remote_address.is_ipv4())
{
throw std::runtime_error("Unsupported address family");
}
pending_connection pending{};
pending.remote_address = remote_address;
pending.accepted_socket = std::move(accepted_socket_ptr);
LONG sequence = next_sequence_++;
pending_connections_.try_emplace(sequence, std::move(pending));
AFD_LISTEN_RESPONSE_INFO response{};
response.Sequence = sequence;
auto transport_buffer = convert_to_win_address(win_emu, remote_address);
memcpy(&response.RemoteAddress, transport_buffer.data(), sizeof(win_sockaddr));
win_emu.emu().write_memory<AFD_LISTEN_RESPONSE_INFO>(c.output_buffer, response);
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Information = sizeof(AFD_LISTEN_RESPONSE_INFO);
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_accept(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
if (c.input_buffer_length < sizeof(AFD_ACCEPT_INFO))
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto accept_info = win_emu.emu().read_memory<AFD_ACCEPT_INFO>(c.input_buffer);
const auto it = pending_connections_.find(accept_info.Sequence);
if (it == pending_connections_.end())
{
return STATUS_INVALID_PARAMETER;
}
auto& accepted_socket = it->second.accepted_socket;
auto* target_device = win_emu.process.devices.get(accept_info.AcceptHandle);
if (!target_device)
{
return STATUS_INVALID_HANDLE;
}
auto* target_endpoint = target_device->get_internal_device<afd_endpoint>();
if (!target_endpoint)
{
return STATUS_INVALID_HANDLE;
}
target_endpoint->s_ = std::move(accepted_socket);
pending_connections_.erase(it);
return STATUS_SUCCESS;
}
NTSTATUS ioctl_receive(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
auto& emu = win_emu.emu();
if (c.input_buffer_length < sizeof(AFD_RECV_INFO<EmulatorTraits<Emu64>>))
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto receive_info = emu.read_memory<AFD_RECV_INFO<EmulatorTraits<Emu64>>>(c.input_buffer);
if (!receive_info.BufferArray || receive_info.BufferCount == 0)
{
return STATUS_INVALID_PARAMETER;
}
if (receive_info.BufferCount > 1)
{
// TODO: Scatter/Gather
return STATUS_NOT_SUPPORTED;
}
const auto wsabuf = emu.read_memory<EMU_WSABUF<EmulatorTraits<Emu64>>>(receive_info.BufferArray);
if (!wsabuf.buf || wsabuf.len == 0)
{
return STATUS_INVALID_PARAMETER;
}
std::vector<std::byte> host_buffer;
host_buffer.resize(wsabuf.len);
const auto bytes_received = this->s_->recv(host_buffer);
if (bytes_received < 0)
{
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, true);
return STATUS_PENDING;
}
if (error == SERR(ECONNRESET))
{
return STATUS_CONNECTION_RESET;
}
return STATUS_UNSUCCESSFUL;
}
emu.write_memory(wsabuf.buf, host_buffer.data(), static_cast<size_t>(bytes_received));
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Information = static_cast<uint32_t>(bytes_received);
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_send(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
auto& emu = win_emu.emu();
if (c.input_buffer_length < sizeof(AFD_SEND_INFO<EmulatorTraits<Emu64>>))
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto send_info = emu.read_memory<AFD_SEND_INFO<EmulatorTraits<Emu64>>>(c.input_buffer);
if (!send_info.BufferArray || send_info.BufferCount == 0)
{
return STATUS_INVALID_PARAMETER;
}
if (send_info.BufferCount > 1)
{
// TODO: Scatter/Gather
return STATUS_NOT_SUPPORTED;
}
const auto wsabuf = emu.read_memory<EMU_WSABUF<EmulatorTraits<Emu64>>>(send_info.BufferArray);
if (!wsabuf.buf || wsabuf.len == 0)
{
return STATUS_INVALID_PARAMETER;
}
std::vector<std::byte> host_buffer;
host_buffer.resize(wsabuf.len);
emu.read_memory(wsabuf.buf, host_buffer.data(), host_buffer.size());
const auto bytes_sent = this->s_->send(host_buffer);
if (bytes_sent < 0)
{
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, false);
return STATUS_PENDING;
}
if (error == SERR(ECONNRESET))
{
return STATUS_CONNECTION_RESET;
}
return STATUS_UNSUCCESSFUL;
}
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Information = static_cast<uint32_t>(bytes_sent);
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
static std::vector<network::i_socket*> resolve_endpoints(windows_emulator& win_emu,
const std::span<const AFD_POLL_HANDLE_INFO64> handles)
{
@@ -602,7 +923,12 @@ namespace
timeout = utils::convert_delay_interval_to_time_point(win_emu.clock(), info.Timeout);
}
this->delay_ioctrl(c, timeout);
this->delay_ioctrl(c, {}, timeout, [&win_emu](const io_device_context& dc) {
const emulator_object<AFD_POLL_INFO64> info_obj{win_emu.emu(), dc.input_buffer};
info_obj.access([&](AFD_POLL_INFO64& info) {
info.NumberOfHandles = 0; //
});
});
}
return STATUS_PENDING;
@@ -641,7 +967,7 @@ namespace
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, true);
this->delay_ioctrl(c, true);
return STATUS_PENDING;
}
@@ -701,7 +1027,7 @@ namespace
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, false);
this->delay_ioctrl(c, false);
return STATUS_PENDING;
}

View File

@@ -6,6 +6,32 @@
typedef LONG TDI_STATUS;
struct win_sockaddr
{
USHORT sa_family;
CHAR sa_data[14];
};
struct AFD_LISTEN_INFO
{
BOOLEAN SanActive;
ULONG MaximumConnectionQueue;
BOOLEAN UseDelayedAcceptance;
};
struct AFD_LISTEN_RESPONSE_INFO
{
LONG Sequence;
win_sockaddr RemoteAddress;
};
struct AFD_ACCEPT_INFO
{
BOOLEAN SanActive;
LONG Sequence;
handle AcceptHandle;
};
template <typename Traits>
struct TDI_CONNECTION_INFORMATION
{
@@ -173,5 +199,11 @@ struct AFD_POLL_INFO64
#define AFD_NO_OPERATION 39
#define AFD_VALIDATE_GROUP 40
#define AFD_GET_UNACCEPTED_CONNECT_DATA 41
#define AFD_ROUTING_INTERFACE_QUERY 42
#define AFD_ROUTING_INTERFACE_CHANGE 43
#define AFD_ADDRESS_LIST_QUERY 44
#define AFD_ADDRESS_LIST_CHANGE 45
#define AFD_JOIN_LEAF 46
#define AFD_TRANSPORT_IOCTL 47
// NOLINTEND(modernize-use-using,cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays)

View File

@@ -14,8 +14,12 @@ namespace network
virtual int get_last_error() = 0;
virtual bool is_ready(bool in_poll) = 0;
virtual bool is_listening() = 0;
virtual bool bind(const address& addr) = 0;
virtual bool connect(const address& addr) = 0;
virtual bool listen(int backlog) = 0;
virtual std::unique_ptr<i_socket> accept(address& address) = 0;
virtual sent_size send(std::span<const std::byte> data) = 0;
virtual sent_size sendto(const address& destination, std::span<const std::byte> data) = 0;

View File

@@ -3,6 +3,11 @@
namespace network
{
socket_wrapper::socket_wrapper(SOCKET s)
: socket_(s)
{
}
socket_wrapper::socket_wrapper(const int af, const int type, const int protocol)
: socket_(af, type, protocol)
{
@@ -23,11 +28,52 @@ namespace network
return this->socket_.is_ready(in_poll);
}
bool socket_wrapper::is_listening()
{
if (!this->socket_.is_valid())
{
return false;
}
int val{};
socklen_t len = sizeof(val);
const auto res =
getsockopt(this->socket_.get_socket(), SOL_SOCKET, SO_ACCEPTCONN, reinterpret_cast<char*>(&val), &len);
return res != SOCKET_ERROR && val == 1;
}
bool socket_wrapper::bind(const address& addr)
{
return this->socket_.bind(addr);
}
bool socket_wrapper::connect(const address& addr)
{
return ::connect(this->socket_.get_socket(), &addr.get_addr(), addr.get_size()) == 0;
}
bool socket_wrapper::listen(int backlog)
{
return ::listen(this->socket_.get_socket(), backlog) == 0;
}
std::unique_ptr<i_socket> socket_wrapper::accept(address& address)
{
sockaddr addr{};
socklen_t addrlen = sizeof(sockaddr);
const auto s = ::accept(this->socket_.get_socket(), &addr, &addrlen);
if (s == INVALID_SOCKET)
{
return nullptr;
}
address.set_address(&addr, addrlen);
return std::make_unique<socket_wrapper>(s);
}
sent_size socket_wrapper::send(const std::span<const std::byte> data)
{
return ::send(this->socket_.get_socket(), reinterpret_cast<const char*>(data.data()),

View File

@@ -7,6 +7,7 @@ namespace network
class socket_wrapper : public i_socket
{
public:
socket_wrapper(SOCKET s);
socket_wrapper(int af, int type, int protocol);
~socket_wrapper() override = default;
@@ -15,8 +16,12 @@ namespace network
int get_last_error() override;
bool is_ready(bool in_poll) override;
bool is_listening() override;
bool bind(const address& addr) override;
bool connect(const address& addr) override;
bool listen(int backlog) override;
std::unique_ptr<i_socket> accept(address& address) override;
sent_size send(std::span<const std::byte> data) override;
sent_size sendto(const address& destination, std::span<const std::byte> data) override;

View File

@@ -65,12 +65,33 @@ namespace network
return true;
}
bool is_listening() override
{
return false;
}
bool bind(const address& addr) override
{
this->a = addr;
return true;
}
bool connect(const address& addr) override
{
this->a = addr;
return true;
}
bool listen(int) override
{
throw std::runtime_error("Not implemented");
}
std::unique_ptr<i_socket> accept(address&) override
{
throw std::runtime_error("Not implemented");
}
sent_size send(std::span<const std::byte>) override
{
throw std::runtime_error("Not implemented");

View File

@@ -441,6 +441,11 @@ namespace syscalls
return STATUS_INVALID_HANDLE;
}
if (auto* e = c.win_emu.process.events.get(event))
{
e->signaled = false;
}
io_device_context context{c.emu};
context.event = event;
context.apc_routine = apc_routine;