diff --git a/src/windows-emulator/devices/afd_endpoint.cpp b/src/windows-emulator/devices/afd_endpoint.cpp index fd3dd8e3..bc1ab997 100644 --- a/src/windows-emulator/devices/afd_endpoint.cpp +++ b/src/windows-emulator/devices/afd_endpoint.cpp @@ -273,23 +273,37 @@ namespace } ULONG map_socket_response_events_to_afd(const int16_t socket_events, const ULONG afd_poll_events, - const bool is_listener) + const bool is_listening, const bool is_connecting) { ULONG afd_events = 0; if (socket_events & POLLRDNORM) { - afd_events |= is_listener ? AFD_POLL_ACCEPT : AFD_POLL_RECEIVE; + if (!is_listening && afd_poll_events & AFD_POLL_RECEIVE) + { + afd_events |= AFD_POLL_RECEIVE; + } + else if (is_listening && afd_poll_events & AFD_POLL_ACCEPT) + { + afd_events |= AFD_POLL_ACCEPT; + } } - if (socket_events & POLLRDBAND) + if (socket_events & POLLRDBAND && afd_poll_events & AFD_POLL_RECEIVE_EXPEDITED) { afd_events |= AFD_POLL_RECEIVE_EXPEDITED; } if (socket_events & POLLWRNORM) { - afd_events |= AFD_POLL_SEND; + if (!is_connecting && afd_poll_events & AFD_POLL_SEND) + { + afd_events |= AFD_POLL_SEND; + } + else if (is_connecting && afd_poll_events & AFD_POLL_CONNECT) + { + afd_events |= AFD_POLL_CONNECT; + } } if ((socket_events & (POLLHUP | POLLERR)) == (POLLHUP | POLLERR) && @@ -310,68 +324,6 @@ namespace return afd_events; } - NTSTATUS perform_poll(windows_emulator& win_emu, const io_device_context& c, - const std::span endpoints, - const std::span handles) - { - std::vector poll_data{}; - poll_data.resize(endpoints.size()); - - for (size_t i = 0; i < endpoints.size() && i < handles.size(); ++i) - { - auto& pfd = poll_data.at(i); - const auto& handle = handles[i]; - - pfd.s = endpoints[i]; - pfd.events = map_afd_request_events_to_socket(handle.PollEvents); - pfd.revents = pfd.events; - } - - const auto count = win_emu.socket_factory().poll_sockets(poll_data); - if (count <= 0) - { - return STATUS_PENDING; - } - - constexpr auto info_size = offsetof(AFD_POLL_INFO64, Handles); - const emulator_object handle_info_obj{win_emu.emu(), c.input_buffer + info_size}; - - size_t current_index = 0; - - for (size_t i = 0; i < endpoints.size(); ++i) - { - const auto& pfd = poll_data.at(i); - if (pfd.revents == 0) - { - continue; - } - - const auto& handle = handles[i]; - - auto entry = handle_info_obj.read(i); - entry.PollEvents = map_socket_response_events_to_afd(pfd.revents, handle.PollEvents, pfd.s->is_listening()); - entry.Status = STATUS_SUCCESS; - - handle_info_obj.write(entry, current_index++); - } - - assert(current_index == static_cast(count)); - - const emulator_object info_obj{win_emu.emu(), c.input_buffer}; - info_obj.access([&](AFD_POLL_INFO64& info) { - info.NumberOfHandles = static_cast(current_index); // - }); - - if (c.io_status_block) - { - IO_STATUS_BLOCK> block{}; - block.Information = info_size + sizeof(AFD_POLL_HANDLE_INFO64) * current_index; - c.io_status_block.write(block); - } - - return STATUS_SUCCESS; - } - struct afd_endpoint : io_device { struct pending_connection @@ -387,10 +339,15 @@ namespace std::optional require_poll_{}; std::optional delayed_ioctl_{}; std::optional timeout_{}; - std::optional> timeout_callback_{}; + std::optional> timeout_callback_{}; + std::unordered_map pending_connections_{}; LONG next_sequence_{0}; + std::optional event_select_event_{}; + ULONG event_select_mask_{0}; + ULONG triggered_events_{0}; + afd_endpoint() { network::initialize_wsa(); @@ -431,7 +388,8 @@ namespace void delay_ioctrl(const io_device_context& c, const std::optional require_poll = {}, const std::optional timeout = {}, - const std::optional>& timeout_callback = {}) + const std::optional>& + timeout_callback = {}) { if (this->executing_delayed_ioctl_) { @@ -454,46 +412,89 @@ namespace void work(windows_emulator& win_emu) override { - if (!this->delayed_ioctl_ || !this->s_) + if (!this->s_ || (!this->delayed_ioctl_ && !this->event_select_mask_)) { return; } - this->executing_delayed_ioctl_ = true; - const auto _ = utils::finally([&] { this->executing_delayed_ioctl_ = false; }); + network::poll_entry pfd{}; + pfd.s = this->s_.get(); - if (this->require_poll_.has_value()) + if (this->delayed_ioctl_ && this->require_poll_.has_value()) { - const auto is_ready = this->s_->is_ready(*this->require_poll_); - if (!is_ready) + pfd.events |= *this->require_poll_ ? POLLIN : POLLOUT; + } + if (this->event_select_mask_) + { + pfd.events = + static_cast(pfd.events | map_afd_request_events_to_socket(this->event_select_mask_)); + } + pfd.revents = pfd.events; + + if (pfd.events != 0) + { + win_emu.socket_factory().poll_sockets(std::span{&pfd, 1}); + } + + const auto socket_events = pfd.revents; + + if (socket_events && this->event_select_mask_) + { + const bool is_connecting = + this->delayed_ioctl_ && _AFD_REQUEST(this->delayed_ioctl_->io_control_code) == AFD_CONNECT; + ULONG current_events = map_socket_response_events_to_afd(socket_events, this->event_select_mask_, + pfd.s->is_listening(), is_connecting); + + if ((current_events & ~this->triggered_events_) != 0) { - return; + this->triggered_events_ |= current_events; + + if (auto* event = win_emu.process.events.get(*this->event_select_event_)) + { + event->signaled = true; + } } } - const auto status = this->execute_ioctl(win_emu, *this->delayed_ioctl_); - if (status == STATUS_PENDING) + if (this->delayed_ioctl_) { - if (!this->timeout_ || this->timeout_ > win_emu.clock().steady_now()) + this->executing_delayed_ioctl_ = true; + const auto _ = utils::finally([&] { this->executing_delayed_ioctl_ = false; }); + + if (this->require_poll_.has_value()) { - return; + const auto is_ready = + socket_events & ((*this->require_poll_ ? POLLIN : POLLOUT) | POLLHUP | POLLERR); + if (!is_ready) + { + return; + } } - write_io_status(this->delayed_ioctl_->io_status_block, STATUS_TIMEOUT); - - if (this->timeout_callback_) + const auto status = this->execute_ioctl(win_emu, *this->delayed_ioctl_); + if (status == STATUS_PENDING) { - (*this->timeout_callback_)(*this->delayed_ioctl_); + if (!this->timeout_ || this->timeout_ > win_emu.clock().steady_now()) + { + return; + } + + write_io_status(this->delayed_ioctl_->io_status_block, STATUS_TIMEOUT); + + if (this->timeout_callback_) + { + (*this->timeout_callback_)(win_emu, *this->delayed_ioctl_); + } } - } - auto* e = win_emu.process.events.get(this->delayed_ioctl_->event); - if (e) - { - e->signaled = true; - } + auto* e = win_emu.process.events.get(this->delayed_ioctl_->event); + if (e) + { + e->signaled = true; + } - this->clear_pending_state(); + this->clear_pending_state(); + } } void deserialize_object(utils::buffer_deserializer& buffer) override @@ -548,6 +549,12 @@ namespace return this->ioctl_receive_datagram(win_emu, c); case AFD_POLL: return this->ioctl_poll(win_emu, c); + case AFD_GET_ADDRESS: + return this->ioctl_get_address(win_emu, c); + case AFD_EVENT_SELECT: + return this->ioctl_event_select(win_emu, c); + case AFD_ENUM_NETWORK_EVENTS: + return this->ioctl_enum_network_events(win_emu, c); case AFD_SET_CONTEXT: case AFD_GET_INFORMATION: case AFD_SET_INFORMATION: @@ -878,12 +885,12 @@ namespace return STATUS_SUCCESS; } - static std::vector resolve_endpoints(windows_emulator& win_emu, - const std::span handles) + static std::vector resolve_endpoints(windows_emulator& win_emu, + const std::span handles) { auto& proc = win_emu.process; - std::vector endpoints{}; + std::vector endpoints{}; endpoints.reserve(handles.size()); for (const auto& handle : handles) @@ -900,12 +907,79 @@ namespace throw std::runtime_error("Invalid AFD endpoint!"); } - endpoints.push_back(endpoint->s_.get()); + endpoints.push_back(endpoint); } return endpoints; } + static NTSTATUS perform_poll(windows_emulator& win_emu, const io_device_context& c, + const std::span endpoints, + const std::span handles) + { + std::vector poll_data{}; + poll_data.resize(endpoints.size()); + + for (size_t i = 0; i < endpoints.size() && i < handles.size(); ++i) + { + auto& pfd = poll_data.at(i); + const auto& handle = handles[i]; + + pfd.s = endpoints[i]->s_.get(); + pfd.events = map_afd_request_events_to_socket(handle.PollEvents); + pfd.revents = pfd.events; + } + + const auto count = win_emu.socket_factory().poll_sockets(poll_data); + if (count <= 0) + { + return STATUS_PENDING; + } + + constexpr auto info_size = offsetof(AFD_POLL_INFO64, Handles); + const emulator_object handle_info_obj{win_emu.emu(), c.input_buffer + info_size}; + + size_t current_index = 0; + + for (size_t i = 0; i < endpoints.size(); ++i) + { + const auto& pfd = poll_data.at(i); + if (pfd.revents == 0) + { + continue; + } + + const auto& handle = handles[i]; + const auto& endpoint = endpoints[i]; + + const bool is_connecting = + endpoint->delayed_ioctl_ && _AFD_REQUEST(endpoint->delayed_ioctl_->io_control_code) == AFD_CONNECT; + + auto entry = handle_info_obj.read(i); + entry.PollEvents = map_socket_response_events_to_afd(pfd.revents, handle.PollEvents, + pfd.s->is_listening(), is_connecting); + entry.Status = STATUS_SUCCESS; + + handle_info_obj.write(entry, current_index++); + } + + assert(current_index == static_cast(count)); + + const emulator_object info_obj{win_emu.emu(), c.input_buffer}; + info_obj.access([&](AFD_POLL_INFO64& info) { + info.NumberOfHandles = static_cast(current_index); // + }); + + if (c.io_status_block) + { + IO_STATUS_BLOCK> block{}; + block.Information = info_size + sizeof(AFD_POLL_HANDLE_INFO64) * current_index; + c.io_status_block.write(block); + } + + return STATUS_SUCCESS; + } + NTSTATUS ioctl_poll(windows_emulator& win_emu, const io_device_context& c) { const auto [info, handles] = get_poll_info(win_emu, c); @@ -919,9 +993,21 @@ namespace if (!this->executing_delayed_ioctl_) { + const auto timeout_callback = [](windows_emulator& win_emu, const io_device_context& c) { + const emulator_object info_obj{win_emu.emu(), c.input_buffer}; + info_obj.access([&](AFD_POLL_INFO64& info) { + info.NumberOfHandles = 0; // + }); + }; + if (!info.Timeout.QuadPart) { - return status; + if (status == STATUS_PENDING) + { + timeout_callback(win_emu, c); + return STATUS_TIMEOUT; + } + return STATUS_SUCCESS; } std::optional timeout{}; @@ -930,12 +1016,7 @@ namespace timeout = utils::convert_delay_interval_to_time_point(win_emu.clock(), info.Timeout); } - this->delay_ioctrl(c, {}, timeout, [&win_emu](const io_device_context& dc) { - const emulator_object info_obj{win_emu.emu(), dc.input_buffer}; - info_obj.access([&](AFD_POLL_INFO64& info) { - info.NumberOfHandles = 0; // - }); - }); + this->delay_ioctrl(c, {}, timeout, timeout_callback); } return STATUS_PENDING; @@ -1050,6 +1131,107 @@ namespace return STATUS_SUCCESS; } + + NTSTATUS ioctl_get_address(windows_emulator& win_emu, const io_device_context& c) const + { + if (!this->s_) + { + throw std::runtime_error("Invalid AFD endpoint socket!"); + } + + const auto local_address = this->s_->get_local_address(); + if (!local_address) + { + return STATUS_INVALID_PARAMETER; + } + + std::vector win_addr_bytes = convert_to_win_address(win_emu, *local_address); + + if (c.output_buffer_length < win_addr_bytes.size()) + { + return STATUS_BUFFER_TOO_SMALL; + } + + win_emu.emu().write_memory(c.output_buffer, win_addr_bytes.data(), win_addr_bytes.size()); + + if (c.io_status_block) + { + IO_STATUS_BLOCK> block{}; + block.Information = static_cast(win_addr_bytes.size()); + c.io_status_block.write(block); + } + + return STATUS_SUCCESS; + } + + NTSTATUS ioctl_event_select(windows_emulator& win_emu, const io_device_context& c) + { + if (!this->s_) + { + throw std::runtime_error("Invalid AFD endpoint socket!"); + } + + if (c.input_buffer_length < sizeof(AFD_EVENT_SELECT_INFO)) + { + return STATUS_BUFFER_TOO_SMALL; + } + + const auto select_info = win_emu.emu().read_memory(c.input_buffer); + + this->event_select_event_ = select_info.Event; + this->event_select_mask_ = select_info.PollEvents; + this->triggered_events_ = 0; + + if (auto* event = win_emu.process.events.get(select_info.Event)) + { + event->signaled = false; + } + + return STATUS_SUCCESS; + } + + NTSTATUS ioctl_enum_network_events(windows_emulator& win_emu, const io_device_context& c) + { + if (!this->s_) + { + throw std::runtime_error("Invalid AFD endpoint socket!"); + } + + if (c.output_buffer_length < 56) + { + return STATUS_BUFFER_TOO_SMALL; + } + + if (c.input_buffer) + { + if (c.input_buffer_length == 0) + { + handle h{}; + h.bits = c.input_buffer; + + if (auto* event = win_emu.process.events.get(h)) + { + event->signaled = false; + } + } + else + { + return STATUS_NOT_SUPPORTED; + } + } + + win_emu.emu().write_memory(c.output_buffer, this->triggered_events_); + this->triggered_events_ = 0; + + if (c.io_status_block) + { + IO_STATUS_BLOCK> block{}; + block.Information = 56; + c.io_status_block.write(block); + } + + return STATUS_SUCCESS; + } }; struct afd_async_connect_hlp : stateless_device diff --git a/src/windows-emulator/devices/afd_types.hpp b/src/windows-emulator/devices/afd_types.hpp index 5c9818a6..a225282b 100644 --- a/src/windows-emulator/devices/afd_types.hpp +++ b/src/windows-emulator/devices/afd_types.hpp @@ -119,31 +119,37 @@ struct AFD_POLL_INFO64 AFD_POLL_HANDLE_INFO64 Handles[1]; }; -#define AFD_POLL_RECEIVE_BIT 0 -#define AFD_POLL_RECEIVE (1 << AFD_POLL_RECEIVE_BIT) -#define AFD_POLL_RECEIVE_EXPEDITED_BIT 1 -#define AFD_POLL_RECEIVE_EXPEDITED (1 << AFD_POLL_RECEIVE_EXPEDITED_BIT) -#define AFD_POLL_SEND_BIT 2 -#define AFD_POLL_SEND (1 << AFD_POLL_SEND_BIT) -#define AFD_POLL_DISCONNECT_BIT 3 -#define AFD_POLL_DISCONNECT (1 << AFD_POLL_DISCONNECT_BIT) -#define AFD_POLL_ABORT_BIT 4 -#define AFD_POLL_ABORT (1 << AFD_POLL_ABORT_BIT) -#define AFD_POLL_LOCAL_CLOSE_BIT 5 -#define AFD_POLL_LOCAL_CLOSE (1 << AFD_POLL_LOCAL_CLOSE_BIT) -#define AFD_POLL_CONNECT_BIT 6 -#define AFD_POLL_CONNECT (1 << AFD_POLL_CONNECT_BIT) -#define AFD_POLL_ACCEPT_BIT 7 -#define AFD_POLL_ACCEPT (1 << AFD_POLL_ACCEPT_BIT) -#define AFD_POLL_CONNECT_FAIL_BIT 8 -#define AFD_POLL_CONNECT_FAIL (1 << AFD_POLL_CONNECT_FAIL_BIT) -#define AFD_POLL_QOS_BIT 9 -#define AFD_POLL_QOS (1 << AFD_POLL_QOS_BIT) -#define AFD_POLL_GROUP_QOS_BIT 10 -#define AFD_POLL_GROUP_QOS (1 << AFD_POLL_GROUP_QOS_BIT) +#define AFD_POLL_RECEIVE_BIT 0 +#define AFD_POLL_RECEIVE (1 << AFD_POLL_RECEIVE_BIT) +#define AFD_POLL_RECEIVE_EXPEDITED_BIT 1 +#define AFD_POLL_RECEIVE_EXPEDITED (1 << AFD_POLL_RECEIVE_EXPEDITED_BIT) +#define AFD_POLL_SEND_BIT 2 +#define AFD_POLL_SEND (1 << AFD_POLL_SEND_BIT) +#define AFD_POLL_DISCONNECT_BIT 3 +#define AFD_POLL_DISCONNECT (1 << AFD_POLL_DISCONNECT_BIT) +#define AFD_POLL_ABORT_BIT 4 +#define AFD_POLL_ABORT (1 << AFD_POLL_ABORT_BIT) +#define AFD_POLL_LOCAL_CLOSE_BIT 5 +#define AFD_POLL_LOCAL_CLOSE (1 << AFD_POLL_LOCAL_CLOSE_BIT) +#define AFD_POLL_CONNECT_BIT 6 +#define AFD_POLL_CONNECT (1 << AFD_POLL_CONNECT_BIT) +#define AFD_POLL_ACCEPT_BIT 7 +#define AFD_POLL_ACCEPT (1 << AFD_POLL_ACCEPT_BIT) +#define AFD_POLL_CONNECT_FAIL_BIT 8 +#define AFD_POLL_CONNECT_FAIL (1 << AFD_POLL_CONNECT_FAIL_BIT) +#define AFD_POLL_QOS_BIT 9 +#define AFD_POLL_QOS (1 << AFD_POLL_QOS_BIT) +#define AFD_POLL_GROUP_QOS_BIT 10 +#define AFD_POLL_GROUP_QOS (1 << AFD_POLL_GROUP_QOS_BIT) -#define AFD_NUM_POLL_EVENTS 11 -#define AFD_POLL_ALL ((1 << AFD_NUM_POLL_EVENTS) - 1) +#define AFD_NUM_POLL_EVENTS 11 +#define AFD_POLL_ALL ((1 << AFD_NUM_POLL_EVENTS) - 1) + +struct AFD_EVENT_SELECT_INFO +{ + handle Event; + ULONG PollEvents; +}; #define _AFD_REQUEST(ioctl) ((((ULONG)(ioctl)) >> 2) & 0x03FF) #define _AFD_BASE(ioctl) ((((ULONG)(ioctl)) >> 12) & 0xFFFFF) diff --git a/src/windows-emulator/network/i_socket.hpp b/src/windows-emulator/network/i_socket.hpp index 946790ce..dbd84a85 100644 --- a/src/windows-emulator/network/i_socket.hpp +++ b/src/windows-emulator/network/i_socket.hpp @@ -16,6 +16,8 @@ namespace network virtual bool is_ready(bool in_poll) = 0; virtual bool is_listening() = 0; + virtual std::optional
get_local_address() = 0; + virtual bool bind(const address& addr) = 0; virtual bool connect(const address& addr) = 0; virtual bool listen(int backlog) = 0; diff --git a/src/windows-emulator/network/socket_wrapper.cpp b/src/windows-emulator/network/socket_wrapper.cpp index d1a0bd81..a869b55b 100644 --- a/src/windows-emulator/network/socket_wrapper.cpp +++ b/src/windows-emulator/network/socket_wrapper.cpp @@ -43,6 +43,22 @@ namespace network return res != SOCKET_ERROR && val == 1; } + std::optional
socket_wrapper::get_local_address() + { + sockaddr addr{}; + socklen_t addrlen = sizeof(sockaddr); + const auto res = ::getsockname(this->socket_.get_socket(), &addr, &addrlen); + + if (res != 0) + { + return {}; + } + + address address{}; + address.set_address(&addr, addrlen); + return address; + } + bool socket_wrapper::bind(const address& addr) { return this->socket_.bind(addr); diff --git a/src/windows-emulator/network/socket_wrapper.hpp b/src/windows-emulator/network/socket_wrapper.hpp index d3401d99..f9a64f40 100644 --- a/src/windows-emulator/network/socket_wrapper.hpp +++ b/src/windows-emulator/network/socket_wrapper.hpp @@ -18,6 +18,8 @@ namespace network bool is_ready(bool in_poll) override; bool is_listening() override; + std::optional
get_local_address() override; + bool bind(const address& addr) override; bool connect(const address& addr) override; bool listen(int backlog) override; diff --git a/src/windows-emulator/network/static_socket_factory.cpp b/src/windows-emulator/network/static_socket_factory.cpp index 5f40d49a..bc6be302 100644 --- a/src/windows-emulator/network/static_socket_factory.cpp +++ b/src/windows-emulator/network/static_socket_factory.cpp @@ -70,6 +70,11 @@ namespace network return false; } + std::optional
get_local_address() override + { + return this->a; + } + bool bind(const address& addr) override { this->a = addr;