Handle connect & Fix polling

This commit is contained in:
Igor Pissolati
2025-05-20 18:03:39 -03:00
parent f5ed0752e3
commit 4b83b20e19
8 changed files with 139 additions and 31 deletions

View File

@@ -77,23 +77,23 @@ namespace network
return ::bind(this->socket_, &target.get_addr(), target.get_size()) == 0;
}
// NOLINTNEXTLINE(readability-make-member-function-const)
bool socket::connect(const address& target)
{
return ::connect(this->socket_, &target.get_addr(), target.get_size()) == 0;
}
// NOLINTNEXTLINE(readability-make-member-function-const)
bool socket::listen(int backlog)
{
int result = ::listen(this->socket_, backlog);
if (result == 0)
{
listening_ = true;
return true;
}
return false;
return ::listen(this->socket_, backlog) == 0;
}
// NOLINTNEXTLINE(readability-make-member-function-const)
SOCKET socket::accept(address& address)
{
sockaddr addr{};
int addrlen = sizeof(sockaddr);
socklen_t addrlen = sizeof(sockaddr);
const auto s = ::accept(this->socket_, &addr, &addrlen);
if (s != INVALID_SOCKET)
@@ -187,7 +187,7 @@ namespace network
bool socket::is_listening() const
{
return this->is_valid() && listening_;
return this->is_valid() && is_socket_listening(this->socket_);
}
bool socket::sleep_sockets(const std::span<const socket*>& sockets, const std::chrono::milliseconds timeout,
@@ -246,6 +246,14 @@ namespace network
return !socket_is_ready;
}
bool socket::is_socket_listening(SOCKET s)
{
int val{};
socklen_t len = sizeof(val);
return getsockopt(s, SOL_SOCKET, SO_ACCEPTCONN, reinterpret_cast<char*>(&val), &len) != SOCKET_ERROR &&
val == 1;
}
bool socket::sleep_sockets_until(const std::span<const socket*>& sockets,
const std::chrono::high_resolution_clock::time_point time_point,
const bool in_poll)

View File

@@ -47,6 +47,7 @@ namespace network
bool is_valid() const;
bool bind(const address& target);
bool connect(const address& target);
bool listen(int backlog);
SOCKET accept(address& address);
@@ -72,11 +73,11 @@ namespace network
std::chrono::high_resolution_clock::time_point time_point, bool in_poll);
static bool is_socket_ready(SOCKET s, bool in_poll);
static bool is_socket_listening(SOCKET s);
void close();
private:
SOCKET socket_ = INVALID_SOCKET;
bool listening_{};
};
}

View File

@@ -225,7 +225,7 @@ namespace
const io_device_context& c)
{
constexpr auto info_size = offsetof(AFD_POLL_INFO64, Handles);
if (!c.input_buffer || c.input_buffer_length < info_size)
if (!c.input_buffer || c.input_buffer_length < info_size || c.input_buffer != c.output_buffer)
{
throw std::runtime_error("Bad AFD poll data");
}
@@ -264,7 +264,7 @@ namespace
socket_events |= POLLRDBAND;
}
if (poll_events & (AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND))
if (poll_events & (AFD_POLL_CONNECT | AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND))
{
socket_events |= POLLWRNORM;
}
@@ -272,7 +272,8 @@ namespace
return socket_events;
}
ULONG map_socket_response_events_to_afd(const int16_t socket_events, const bool is_listener)
ULONG map_socket_response_events_to_afd(const int16_t socket_events, const ULONG afd_poll_events,
const bool is_listener)
{
ULONG afd_events = 0;
@@ -288,19 +289,20 @@ namespace
if (socket_events & POLLWRNORM)
{
afd_events |= (AFD_POLL_CONNECT_FAIL | AFD_POLL_SEND);
afd_events |= AFD_POLL_SEND;
}
if ((socket_events & (POLLHUP | POLLERR)) == (POLLHUP | POLLERR))
if ((socket_events & (POLLHUP | POLLERR)) == (POLLHUP | POLLERR) &&
afd_poll_events & (AFD_POLL_CONNECT_FAIL | AFD_POLL_ABORT))
{
afd_events |= (AFD_POLL_CONNECT_FAIL | AFD_POLL_ABORT);
}
else if (socket_events & POLLHUP)
else if (socket_events & POLLHUP && afd_poll_events & AFD_POLL_DISCONNECT)
{
afd_events |= AFD_POLL_DISCONNECT;
}
if (socket_events & POLLNVAL)
if (socket_events & POLLNVAL && afd_poll_events & AFD_POLL_LOCAL_CLOSE)
{
afd_events |= AFD_POLL_LOCAL_CLOSE;
}
@@ -344,8 +346,10 @@ namespace
continue;
}
const auto& handle = handles[i];
auto entry = handle_info_obj.read(i);
entry.PollEvents = map_socket_response_events_to_afd(pfd.revents, pfd.s->is_listening());
entry.PollEvents = map_socket_response_events_to_afd(pfd.revents, handle.PollEvents, pfd.s->is_listening());
entry.Status = STATUS_SUCCESS;
handle_info_obj.write(entry, current_index++);
@@ -358,6 +362,14 @@ namespace
info.NumberOfHandles = static_cast<ULONG>(current_index); //
});
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = info_size + sizeof(AFD_POLL_HANDLE_INFO64) * current_index;
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
@@ -376,6 +388,7 @@ namespace
std::optional<bool> require_poll_{};
std::optional<io_device_context> delayed_ioctl_{};
std::optional<std::chrono::steady_clock::time_point> timeout_{};
std::optional<std::function<void(const io_device_context&)>> timeout_callback_{};
std::unordered_map<LONG, pending_connection> pending_connections_{};
LONG next_sequence_{0};
@@ -417,15 +430,16 @@ namespace
this->s_->set_blocking(false);
}
void delay_ioctrl(const io_device_context& c,
void delay_ioctrl(const io_device_context& c, const std::optional<bool> require_poll = {},
const std::optional<std::chrono::steady_clock::time_point> timeout = {},
const std::optional<bool> require_poll = {})
const std::optional<std::function<void(const io_device_context&)>>& timeout_callback = {})
{
if (this->executing_delayed_ioctl_)
{
return;
}
this->timeout_callback_ = timeout_callback;
this->timeout_ = timeout;
this->require_poll_ = require_poll;
this->delayed_ioctl_ = c;
@@ -433,6 +447,7 @@ namespace
void clear_pending_state()
{
this->timeout_callback_ = {};
this->timeout_ = {};
this->require_poll_ = {};
this->delayed_ioctl_ = {};
@@ -466,6 +481,11 @@ namespace
}
write_io_status(this->delayed_ioctl_->io_status_block, STATUS_TIMEOUT);
if (this->timeout_callback_)
{
(*this->timeout_callback_)(*this->delayed_ioctl_);
}
}
auto* e = win_emu.process.events.get(this->delayed_ioctl_->event);
@@ -499,7 +519,7 @@ namespace
{
if (_AFD_BASE(c.io_control_code) != FSCTL_AFD_BASE)
{
win_emu.log.print(color::cyan, "Bad AFD IOCTL: %X\n", c.io_control_code);
win_emu.log.print(color::cyan, "Bad AFD IOCTL: 0x%X\n", c.io_control_code);
return STATUS_NOT_SUPPORTED;
}
@@ -513,12 +533,14 @@ namespace
const auto request = _AFD_REQUEST(c.io_control_code);
win_emu.log.print(color::dark_gray, "--> AFD IOCTL: %X (%X)\n", c.io_control_code, request);
win_emu.log.print(color::dark_gray, "--> AFD IOCTL: 0x%X (%d)\n", c.io_control_code, request);
switch (request)
{
case AFD_BIND:
return this->ioctl_bind(win_emu, c);
case AFD_CONNECT:
return this->ioctl_connect(win_emu, c);
case AFD_START_LISTEN:
return this->ioctl_listen(win_emu, c);
case AFD_WAIT_FOR_LISTEN:
@@ -539,13 +561,54 @@ namespace
case AFD_GET_INFORMATION:
case AFD_SET_INFORMATION:
case AFD_QUERY_HANDLES:
case AFD_TRANSPORT_IOCTL:
return STATUS_SUCCESS;
default:
win_emu.log.print(color::gray, "Unsupported AFD IOCTL: %X (%X)\n", c.io_control_code, request);
win_emu.log.print(color::gray, "Unsupported AFD IOCTL: 0x%X (%d)\n", c.io_control_code, request);
return STATUS_NOT_SUPPORTED;
}
}
NTSTATUS ioctl_connect(windows_emulator& win_emu, const io_device_context& c)
{
if (!this->s_)
{
throw std::runtime_error("Invalid AFD endpoint socket!");
}
auto data = win_emu.emu().read_memory(c.input_buffer, c.input_buffer_length);
constexpr auto address_offset = 12;
if (data.size() < address_offset)
{
return STATUS_BUFFER_TOO_SMALL;
}
const auto addr = convert_to_host_address(win_emu, std::span(data).subspan(address_offset));
if (!this->s_->connect(addr))
{
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, true);
return STATUS_PENDING;
}
return STATUS_UNSUCCESSFUL;
}
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = 0;
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
NTSTATUS ioctl_bind(windows_emulator& win_emu, const io_device_context& c) const
{
if (!this->s_)
@@ -569,6 +632,14 @@ namespace
return STATUS_ADDRESS_ALREADY_ASSOCIATED;
}
if (c.io_status_block)
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = 0;
c.io_status_block.write(block);
}
return STATUS_SUCCESS;
}
@@ -622,7 +693,7 @@ namespace
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, true);
this->delay_ioctrl(c, true);
return STATUS_PENDING;
}
return STATUS_UNSUCCESSFUL;
@@ -751,12 +822,14 @@ namespace
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, true);
this->delay_ioctrl(c, true);
return STATUS_PENDING;
}
if (error == SERR(ECONNRESET))
{
return STATUS_CONNECTION_RESET;
}
return STATUS_UNSUCCESSFUL;
}
@@ -767,7 +840,7 @@ namespace
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = static_cast<ULONG_PTR>(bytes_received);
block.Information = static_cast<uint32_t>(bytes_received);
c.io_status_block.write(block);
}
@@ -819,12 +892,14 @@ namespace
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, false);
this->delay_ioctrl(c, false);
return STATUS_PENDING;
}
if (error == SERR(ECONNRESET))
{
return STATUS_CONNECTION_RESET;
}
return STATUS_UNSUCCESSFUL;
}
@@ -833,7 +908,7 @@ namespace
{
IO_STATUS_BLOCK<EmulatorTraits<Emu64>> block{};
block.Status = STATUS_SUCCESS;
block.Information = static_cast<ULONG_PTR>(bytes_sent);
block.Information = static_cast<uint32_t>(bytes_sent);
c.io_status_block.write(block);
}
@@ -892,7 +967,12 @@ namespace
timeout = utils::convert_delay_interval_to_time_point(win_emu.clock(), info.Timeout);
}
this->delay_ioctrl(c, timeout);
this->delay_ioctrl(c, {}, timeout, [&](const io_device_context& dc) {
const emulator_object<AFD_POLL_INFO64> info_obj{win_emu.emu(), dc.input_buffer};
info_obj.access([&](AFD_POLL_INFO64& info) {
info.NumberOfHandles = 0; //
});
});
}
return STATUS_PENDING;
@@ -931,7 +1011,7 @@ namespace
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, true);
this->delay_ioctrl(c, true);
return STATUS_PENDING;
}
@@ -991,7 +1071,7 @@ namespace
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, false);
this->delay_ioctrl(c, false);
return STATUS_PENDING;
}

View File

@@ -199,5 +199,11 @@ struct AFD_POLL_INFO64
#define AFD_NO_OPERATION 39
#define AFD_VALIDATE_GROUP 40
#define AFD_GET_UNACCEPTED_CONNECT_DATA 41
#define AFD_ROUTING_INTERFACE_QUERY 42
#define AFD_ROUTING_INTERFACE_CHANGE 43
#define AFD_ADDRESS_LIST_QUERY 44
#define AFD_ADDRESS_LIST_CHANGE 45
#define AFD_JOIN_LEAF 46
#define AFD_TRANSPORT_IOCTL 47
// NOLINTEND(modernize-use-using,cppcoreguidelines-avoid-c-arrays,hicpp-avoid-c-arrays,modernize-avoid-c-arrays)

View File

@@ -17,6 +17,7 @@ namespace network
virtual bool is_listening() = 0;
virtual bool bind(const address& addr) = 0;
virtual bool connect(const address& addr) = 0;
virtual bool listen(int backlog) = 0;
virtual std::unique_ptr<i_socket> accept(address& address) = 0;

View File

@@ -38,6 +38,11 @@ namespace network
return this->socket_.bind(addr);
}
bool socket_wrapper::connect(const address& addr)
{
return this->socket_.connect(addr);
}
bool socket_wrapper::listen(int backlog)
{
return this->socket_.listen(backlog);

View File

@@ -19,6 +19,7 @@ namespace network
bool is_listening() override;
bool bind(const address& addr) override;
bool connect(const address& addr) override;
bool listen(int backlog) override;
std::unique_ptr<i_socket> accept(address& address) override;

View File

@@ -76,6 +76,12 @@ namespace network
return true;
}
bool connect(const address& addr) override
{
this->a = addr;
return true;
}
bool listen(int) override
{
throw std::runtime_error("Not implemented");