Handle listen/accept/send/receive in afd_endpoint

This commit is contained in:
Igor Pissolati
2025-05-20 00:56:22 -03:00
parent 0f4cc3365c
commit f5ed0752e3
9 changed files with 412 additions and 11 deletions

View File

@@ -77,6 +77,33 @@ namespace network
return ::bind(this->socket_, &target.get_addr(), target.get_size()) == 0;
}
// NOLINTNEXTLINE(readability-make-member-function-const)
bool socket::listen(int backlog)
{
int result = ::listen(this->socket_, backlog);
if (result == 0)
{
listening_ = true;
return true;
}
return false;
}
// NOLINTNEXTLINE(readability-make-member-function-const)
SOCKET socket::accept(address& address)
{
sockaddr addr{};
int addrlen = sizeof(sockaddr);
const auto s = ::accept(this->socket_, &addr, &addrlen);
if (s != INVALID_SOCKET)
{
address.set_address(&addr, addrlen);
}
return s;
}
// NOLINTNEXTLINE(readability-make-member-function-const)
bool socket::set_blocking(const bool blocking)
{
@@ -158,6 +185,11 @@ namespace network
return this->is_valid() && is_socket_ready(this->socket_, in_poll);
}
bool socket::is_listening() const
{
return this->is_valid() && listening_;
}
bool socket::sleep_sockets(const std::span<const socket*>& sockets, const std::chrono::milliseconds timeout,
const bool in_poll)
{

View File

@@ -47,6 +47,8 @@ namespace network
bool is_valid() const;
bool bind(const address& target);
bool listen(int backlog);
SOCKET accept(address& address);
bool set_blocking(bool blocking);
static bool set_blocking(SOCKET s, bool blocking);
@@ -62,6 +64,7 @@ namespace network
int get_address_family() const;
bool is_ready(bool in_poll) const;
bool is_listening() const;
static bool sleep_sockets(const std::span<const socket*>& sockets, std::chrono::milliseconds timeout,
bool in_poll);
@@ -74,5 +77,6 @@ namespace network
private:
SOCKET socket_ = INVALID_SOCKET;
bool listening_{};
};
}

View File

@@ -43,6 +43,7 @@ using NTSTATUS = std::uint32_t;
#define STATUS_FILE_IS_A_DIRECTORY ((NTSTATUS)0xC00000BAL)
#define STATUS_NOT_SUPPORTED ((NTSTATUS)0xC00000BBL)
#define STATUS_INVALID_ADDRESS ((NTSTATUS)0xC0000141L)
#define STATUS_CONNECTION_RESET ((NTSTATUS)0xC000020DL)
#define STATUS_NOT_FOUND ((NTSTATUS)0xC0000225L)
#define STATUS_CONNECTION_REFUSED ((NTSTATUS)0xC0000236L)
#define STATUS_TIMER_RESOLUTION_NOT_SET ((NTSTATUS)0xC0000245L)

View File

@@ -259,11 +259,6 @@ namespace
socket_events |= POLLRDNORM;
}
if (poll_events & AFD_POLL_RECEIVE_EXPEDITED)
{
socket_events |= POLLRDNORM;
}
if (poll_events & AFD_POLL_RECEIVE_EXPEDITED)
{
socket_events |= POLLRDBAND;
@@ -277,13 +272,13 @@ namespace
return socket_events;
}
ULONG map_socket_response_events_to_afd(const int16_t socket_events)
ULONG map_socket_response_events_to_afd(const int16_t socket_events, const bool is_listener)
{
ULONG afd_events = 0;
if (socket_events & POLLRDNORM)
{
afd_events |= (AFD_POLL_ACCEPT | AFD_POLL_RECEIVE);
afd_events |= is_listener ? AFD_POLL_ACCEPT : AFD_POLL_RECEIVE;
}
if (socket_events & POLLRDBAND)
@@ -350,11 +345,10 @@ namespace
}
auto entry = handle_info_obj.read(i);
entry.PollEvents = map_socket_response_events_to_afd(pfd.revents);
entry.PollEvents = map_socket_response_events_to_afd(pfd.revents, pfd.s->is_listening());
entry.Status = STATUS_SUCCESS;
handle_info_obj.write(entry, current_index++);
break;
}
assert(current_index == static_cast<size_t>(count));
@@ -369,6 +363,12 @@ namespace
struct afd_endpoint : io_device
{
struct pending_connection
{
network::address remote_address;
std::unique_ptr<network::i_socket> accepted_socket;
};
std::unique_ptr<network::i_socket> s_{};
bool executing_delayed_ioctl_{};
@@ -376,6 +376,8 @@ namespace
std::optional<bool> require_poll_{};
std::optional<io_device_context> delayed_ioctl_{};
std::optional<std::chrono::steady_clock::time_point> timeout_{};
std::unordered_map<LONG, pending_connection> pending_connections_{};
LONG next_sequence_{0};
afd_endpoint()
{
@@ -501,14 +503,32 @@ namespace
return STATUS_NOT_SUPPORTED;
}
win_emu.log.print(color::dark_gray, "--> AFD IOCTL: %X\n", c.io_control_code);
if (this->delayed_ioctl_)
{
if (auto* e = win_emu.process.events.get(c.event))
{
e->signaled = false;
}
}
const auto request = _AFD_REQUEST(c.io_control_code);
win_emu.log.print(color::dark_gray, "--> AFD IOCTL: %X (%X)\n", c.io_control_code, request);
switch (request)
{
case AFD_BIND:
return this->ioctl_bind(win_emu, c);
case AFD_START_LISTEN:
return this->ioctl_listen(win_emu, c);
case AFD_WAIT_FOR_LISTEN:
return this->ioctl_wait_for_listen(win_emu, c);
case AFD_ACCEPT:
return this->ioctl_accept(win_emu, c);
case AFD_SEND:
return this->ioctl_send(win_emu, c);
case AFD_RECEIVE:
return this->ioctl_receive(win_emu, c);
case AFD_SEND_DATAGRAM:
return this->ioctl_send_datagram(win_emu, c);
case AFD_RECEIVE_DATAGRAM:
@@ -517,9 +537,11 @@ namespace
return this->ioctl_poll(win_emu, c);
case AFD_SET_CONTEXT:
case AFD_GET_INFORMATION:
case AFD_SET_INFORMATION:
case AFD_QUERY_HANDLES:
return STATUS_SUCCESS;
default:
win_emu.log.print(color::gray, "Unsupported AFD IOCTL: %X\n", c.io_control_code);
win_emu.log.print(color::gray, "Unsupported AFD IOCTL: %X (%X)\n", c.io_control_code, request);
return STATUS_NOT_SUPPORTED;
}
}
@@ -550,6 +572,274 @@ namespace
return STATUS_SUCCESS;
}
NTSTATUS ioctl_listen(windows_emulator& win_emu, const io_device_context& c) const
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
if (c.input_buffer_length < sizeof(AFD_LISTEN_INFO))
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto listen_info = win_emu.emu().read_memory<AFD_LISTEN_INFO>(c.input_buffer);
if (!this->s_->listen(static_cast<int>(listen_info.MaximumConnectionQueue)))
{
return STATUS_INVALID_PARAMETER;
}
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = 0;
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_wait_for_listen(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
if (c.output_buffer_length < sizeof(AFD_LISTEN_RESPONSE_INFO))
{
return STATUS_BUFFER_TOO_SMALL;
}
network::address remote_address{};
auto accepted_socket_ptr = this->s_->accept(remote_address);
if (!accepted_socket_ptr)
{
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, true);
return STATUS_PENDING;
}
return STATUS_UNSUCCESSFUL;
}
if (!remote_address.is_ipv4())
{
throw std::runtime_error("Unsupported address family");
}
pending_connection pending{};
pending.remote_address = remote_address;
pending.accepted_socket = std::move(accepted_socket_ptr);
LONG sequence = next_sequence_++;
pending_connections_.try_emplace(sequence, std::move(pending));
AFD_LISTEN_RESPONSE_INFO response{};
response.Sequence = sequence;
auto transport_buffer = convert_to_win_address(win_emu, remote_address);
memcpy(&response.RemoteAddress, transport_buffer.data(), sizeof(win_sockaddr));
win_emu.emu().write_memory<AFD_LISTEN_RESPONSE_INFO>(c.output_buffer, response);
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = sizeof(AFD_LISTEN_RESPONSE_INFO);
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_accept(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
if (c.input_buffer_length < sizeof(AFD_ACCEPT_INFO))
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto accept_info = win_emu.emu().read_memory<AFD_ACCEPT_INFO>(c.input_buffer);
const auto it = pending_connections_.find(accept_info.Sequence);
if (it == pending_connections_.end())
{
return STATUS_INVALID_PARAMETER;
}
auto& accepted_socket = it->second.accepted_socket;
auto* target_device = win_emu.process.devices.get(accept_info.AcceptHandle);
if (!target_device)
{
return STATUS_INVALID_HANDLE;
}
auto* target_endpoint = target_device->get_internal_device<afd_endpoint>();
if (!target_endpoint)
{
return STATUS_INVALID_HANDLE;
}
target_endpoint->s_ = std::move(accepted_socket);
pending_connections_.erase(it);
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = 0;
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_receive(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
auto& emu = win_emu.emu();
if (c.input_buffer_length < sizeof(AFD_RECV_INFO<EmulatorTraits<Emu64>>))
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto receive_info = emu.read_memory<AFD_RECV_INFO<EmulatorTraits<Emu64>>>(c.input_buffer);
if (!receive_info.BufferArray || receive_info.BufferCount == 0)
{
return STATUS_INVALID_PARAMETER;
}
if (receive_info.BufferCount > 1)
{
// TODO: Scatter/Gather
return STATUS_NOT_SUPPORTED;
}
const auto wsabuf = emu.read_memory<EMU_WSABUF<EmulatorTraits<Emu64>>>(receive_info.BufferArray);
if (!wsabuf.buf || wsabuf.len == 0)
{
return STATUS_INVALID_PARAMETER;
}
std::vector<std::byte> host_buffer;
host_buffer.resize(wsabuf.len);
const auto bytes_received = this->s_->recv(host_buffer);
if (bytes_received < 0)
{
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, true);
return STATUS_PENDING;
}
if (error == SERR(ECONNRESET))
return STATUS_CONNECTION_RESET;
return STATUS_UNSUCCESSFUL;
}
emu.write_memory(wsabuf.buf, host_buffer.data(), static_cast<size_t>(bytes_received));
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = static_cast<ULONG_PTR>(bytes_received);
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_send(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
auto& emu = win_emu.emu();
if (c.input_buffer_length < sizeof(AFD_SEND_INFO<EmulatorTraits<Emu64>>))
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto send_info = emu.read_memory<AFD_SEND_INFO<EmulatorTraits<Emu64>>>(c.input_buffer);
if (!send_info.BufferArray || send_info.BufferCount == 0)
{
return STATUS_INVALID_PARAMETER;
}
if (send_info.BufferCount > 1)
{
// TODO: Scatter/Gather
return STATUS_NOT_SUPPORTED;
}
const auto wsabuf = emu.read_memory<EMU_WSABUF<EmulatorTraits<Emu64>>>(send_info.BufferArray);
if (!wsabuf.buf || wsabuf.len == 0)
{
return STATUS_INVALID_PARAMETER;
}
std::vector<std::byte> host_buffer;
host_buffer.resize(wsabuf.len);
emu.read_memory(wsabuf.buf, host_buffer.data(), host_buffer.size());
const auto bytes_sent = this->s_->send(host_buffer);
if (bytes_sent < 0)
{
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, false);
return STATUS_PENDING;
}
if (error == SERR(ECONNRESET))
return STATUS_CONNECTION_RESET;
return STATUS_UNSUCCESSFUL;
}
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = static_cast<ULONG_PTR>(bytes_sent);
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
static std::vector<network::i_socket*> resolve_endpoints(windows_emulator& win_emu,
const std::span<const AFD_POLL_HANDLE_INFO64> handles)
{

View File

@@ -6,6 +6,32 @@
typedef LONG TDI_STATUS;
struct win_sockaddr
{
USHORT sa_family;
CHAR sa_data[14];
};
struct AFD_LISTEN_INFO
{
BOOLEAN SanActive;
ULONG MaximumConnectionQueue;
BOOLEAN UseDelayedAcceptance;
};
struct AFD_LISTEN_RESPONSE_INFO
{
LONG Sequence;
win_sockaddr RemoteAddress;
};
struct AFD_ACCEPT_INFO
{
BOOLEAN SanActive;
LONG Sequence;
handle AcceptHandle;
};
template <typename Traits>
struct TDI_CONNECTION_INFORMATION
{

View File

@@ -14,8 +14,11 @@ namespace network
virtual int get_last_error() = 0;
virtual bool is_ready(bool in_poll) = 0;
virtual bool is_listening() = 0;
virtual bool bind(const address& addr) = 0;
virtual bool listen(int backlog) = 0;
virtual std::unique_ptr<i_socket> accept(address& address) = 0;
virtual sent_size send(std::span<const std::byte> data) = 0;
virtual sent_size sendto(const address& destination, std::span<const std::byte> data) = 0;

View File

@@ -3,6 +3,11 @@
namespace network
{
socket_wrapper::socket_wrapper(SOCKET s)
: socket_(s)
{
}
socket_wrapper::socket_wrapper(const int af, const int type, const int protocol)
: socket_(af, type, protocol)
{
@@ -23,11 +28,32 @@ namespace network
return this->socket_.is_ready(in_poll);
}
bool socket_wrapper::is_listening()
{
return this->socket_.is_listening();
}
bool socket_wrapper::bind(const address& addr)
{
return this->socket_.bind(addr);
}
bool socket_wrapper::listen(int backlog)
{
return this->socket_.listen(backlog);
}
std::unique_ptr<i_socket> socket_wrapper::accept(address& address)
{
const auto s = this->socket_.accept(address);
if (s == INVALID_SOCKET)
{
return nullptr;
}
return std::make_unique<socket_wrapper>(s);
}
sent_size socket_wrapper::send(const std::span<const std::byte> data)
{
return ::send(this->socket_.get_socket(), reinterpret_cast<const char*>(data.data()),

View File

@@ -7,6 +7,7 @@ namespace network
class socket_wrapper : public i_socket
{
public:
socket_wrapper(SOCKET s);
socket_wrapper(int af, int type, int protocol);
~socket_wrapper() override = default;
@@ -15,8 +16,11 @@ namespace network
int get_last_error() override;
bool is_ready(bool in_poll) override;
bool is_listening() override;
bool bind(const address& addr) override;
bool listen(int backlog) override;
std::unique_ptr<i_socket> accept(address& address) override;
sent_size send(std::span<const std::byte> data) override;
sent_size sendto(const address& destination, std::span<const std::byte> data) override;

View File

@@ -65,12 +65,27 @@ namespace network
return true;
}
bool is_listening() override
{
return false;
}
bool bind(const address& addr) override
{
this->a = addr;
return true;
}
bool listen(int) override
{
throw std::runtime_error("Not implemented");
}
std::unique_ptr<i_socket> accept(address&) override
{
throw std::runtime_error("Not implemented");
}
sent_size send(std::span<const std::byte>) override
{
throw std::runtime_error("Not implemented");