diff --git a/src/windows-emulator/devices/afd_endpoint.cpp b/src/windows-emulator/devices/afd_endpoint.cpp index 05f15e9c..0068561b 100644 --- a/src/windows-emulator/devices/afd_endpoint.cpp +++ b/src/windows-emulator/devices/afd_endpoint.cpp @@ -49,8 +49,8 @@ namespace static_assert(sizeof(win_sockaddr_in) == 16); static_assert(sizeof(win_sockaddr_in6) == 28); - static_assert(sizeof(win_sockaddr_in::sin_addr) == sizeof(sockaddr_in::sin_addr)); - static_assert(sizeof(win_sockaddr_in6::sin6_addr) == sizeof(sockaddr_in6::sin6_addr)); + static_assert(sizeof(win_sockaddr_in::sin_addr) == 4); + static_assert(sizeof(win_sockaddr_in6::sin6_addr) == 16); static_assert(sizeof(win_sockaddr_in6::sin6_flowinfo) == sizeof(sockaddr_in6::sin6_flowinfo)); static_assert(sizeof(win_sockaddr_in6::sin6_scope_id) == sizeof(sockaddr_in6::sin6_scope_id)); @@ -75,20 +75,20 @@ namespace {255, IPPROTO_RAW}, // }; - int translate_host_to_win_address_family(const int host_af) + int16_t translate_host_to_win_address_family(const int host_af) { for (auto& entry : address_family_map) { if (entry.second == host_af) { - return entry.first; + return static_cast(entry.first); } } throw std::runtime_error("Unknown host address family: " + std::to_string(host_af)); } - int translate_address_family(const int win_af) + int translate_win_to_host_address_family(const int win_af) { const auto entry = address_family_map.find(win_af); if (entry != address_family_map.end()) @@ -99,7 +99,7 @@ namespace throw std::runtime_error("Unknown address family: " + std::to_string(win_af)); } - int translate_type(const int win_type) + int translate_win_to_host_type(const int win_type) { const auto entry = socket_type_map.find(win_type); if (entry != socket_type_map.end()) @@ -110,7 +110,7 @@ namespace throw std::runtime_error("Unknown socket type: " + std::to_string(win_type)); } - int translate_protocol(const int win_protocol) + int translate_win_to_host_protocol(const int win_protocol) { const auto entry = socket_protocol_map.find(win_protocol); if (entry != socket_protocol_map.end()) @@ -121,6 +121,90 @@ namespace throw std::runtime_error("Unknown socket protocol: " + std::to_string(win_protocol)); } + std::vector convert_to_win_address(const network::address& a) + { + if (a.is_ipv4()) + { + win_sockaddr_in win_addr{}; + win_addr.sin_family = translate_host_to_win_address_family(a.get_family()); + win_addr.sin_port = htons(a.get_port()); + memcpy(&win_addr.sin_addr, &a.get_in_addr().sin_addr, sizeof(win_addr.sin_addr)); + + const auto ptr = reinterpret_cast(&win_addr); + return {ptr, ptr + sizeof(win_addr)}; + } + + if (a.is_ipv6()) + { + win_sockaddr_in6 win_addr{}; + win_addr.sin6_family = translate_host_to_win_address_family(a.get_family()); + win_addr.sin6_port = htons(a.get_port()); + + auto& addr = a.get_in6_addr(); + memcpy(&win_addr.sin6_addr, &addr.sin6_addr, sizeof(win_addr.sin6_addr)); + win_addr.sin6_flowinfo = addr.sin6_flowinfo; + win_addr.sin6_scope_id = addr.sin6_scope_id; + + const auto ptr = reinterpret_cast(&win_addr); + return {ptr, ptr + sizeof(win_addr)}; + } + + throw std::runtime_error("Unsupported host address family for conversion: " + std::to_string(a.get_family())); + } + + network::address convert_to_host_address(const std::span data) + { + if (data.size() < sizeof(win_sockaddr)) + { + throw std::runtime_error("Bad address size"); + } + + win_sockaddr win_addr{}; + memcpy(&win_addr, data.data(), sizeof(win_addr)); + + const auto family = translate_win_to_host_address_family(win_addr.sa_family); + + network::address a{}; + + if (family == AF_INET) + { + if (data.size() < sizeof(win_sockaddr_in)) + { + throw std::runtime_error("Bad IPv4 address size"); + } + + win_sockaddr_in win_addr4{}; + memcpy(&win_addr4, data.data(), sizeof(win_addr4)); + + a.set_ipv4(win_addr4.sin_addr); + a.set_port(ntohs(win_addr4.sin_port)); + + return a; + } + + if (family == AF_INET6) + { + if (data.size() < sizeof(win_sockaddr_in6)) + { + throw std::runtime_error("Bad IPv6 address size"); + } + + win_sockaddr_in6 win_addr6{}; + memcpy(&win_addr6, data.data(), sizeof(win_addr6)); + + a.set_ipv6(win_addr6.sin6_addr); + a.set_port(ntohs(win_addr6.sin6_port)); + + auto& addr = a.get_in6_addr(); + addr.sin6_flowinfo = win_addr6.sin6_flowinfo; + addr.sin6_scope_id = win_addr6.sin6_scope_id; + + return a; + } + + throw std::runtime_error("Unsupported win address family for conversion: " + std::to_string(family)); + } + afd_creation_data get_creation_data(windows_emulator& win_emu, const io_device_creation_data& data) { if (!data.buffer || data.length < sizeof(afd_creation_data)) @@ -315,9 +399,9 @@ namespace const auto& data = *this->creation_data; - const auto af = translate_address_family(data.address_family); - const auto type = translate_type(data.type); - const auto protocol = translate_protocol(data.protocol); + const auto af = translate_win_to_host_address_family(data.address_family); + const auto type = translate_win_to_host_type(data.type); + const auto protocol = translate_win_to_host_protocol(data.protocol); const auto sock = socket(af, type, protocol); if (sock == INVALID_SOCKET) @@ -450,13 +534,7 @@ namespace return STATUS_BUFFER_TOO_SMALL; } - auto* address = reinterpret_cast(data.data() + address_offset); - const auto address_size = static_cast(data.size() - address_offset); - - address->sa_family = - static_castsa_family)>(translate_address_family(address->sa_family)); - - const network::address addr(address, address_size); + const auto addr = convert_to_host_address(std::span(data).subspan(address_offset)); if (bind(*this->s_, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR) { @@ -536,30 +614,19 @@ namespace const auto receive_info = emu.read_memory>>(c.input_buffer); const auto buffer = emu.read_memory>>(receive_info.BufferArray); - std::vector address{}; - - unsigned long address_length = 0x1000; - if (receive_info.AddressLength) - { - address_length = emu.read_memory(receive_info.AddressLength); - } - - address.resize(std::clamp(address_length, 1UL, 0x1000UL)); - if (!buffer.len || buffer.len > 0x10000 || !buffer.buf) { return STATUS_INVALID_PARAMETER; } - auto fromlength = static_cast(address.size()); + network::address from{}; + auto from_length = from.get_max_size(); std::vector data{}; data.resize(buffer.len); - auto* sender = reinterpret_cast(address.data()); - - const auto recevied_data = - recvfrom(*this->s_, data.data(), static_cast(data.size()), 0, sender, &fromlength); + const auto recevied_data = recvfrom(*this->s_, data.data(), static_cast(data.size()), 0, + &from.get_addr(), &from_length); if (recevied_data < 0) { @@ -573,17 +640,20 @@ namespace return STATUS_UNSUCCESSFUL; } - // TODO: Translate rest of address struct? - sender->sa_family = - static_castsa_family)>(translate_host_to_win_address_family(sender->sa_family)); + assert(from.get_size() == from_length); const auto data_size = std::min(data.size(), static_cast(recevied_data)); emu.write_memory(buffer.buf, data.data(), data_size); - if (receive_info.Address && address_length) + const auto win_from = convert_to_win_address(from); + + if (receive_info.Address && receive_info.AddressLength) { - const auto address_size = std::min(address.size(), static_cast(address_length)); - emu.write_memory(receive_info.Address, address.data(), address_size); + const emulator_object address_length{emu, receive_info.AddressLength}; + const auto address_size = std::min(win_from.size(), static_cast(address_length.read())); + + emu.write_memory(receive_info.Address, win_from.data(), address_size); + address_length.write(static_cast(address_size)); } if (c.io_status_block) @@ -610,17 +680,13 @@ namespace auto address_buffer = emu.read_memory(send_info.TdiConnInfo.RemoteAddress, static_cast(send_info.TdiConnInfo.RemoteAddressLength)); - auto* address = reinterpret_cast(address_buffer.data()); - address->sa_family = - static_castsa_family)>(translate_address_family(address->sa_family)); - - const network::address target(address, static_cast(address_buffer.size())); + const auto target = convert_to_host_address(address_buffer); const auto data = emu.read_memory(buffer.buf, buffer.len); const auto sent_data = sendto(*this->s_, reinterpret_cast(data.data()), static_cast(data.size()), - 0 /* ? */, &target.get_addr(), target.get_size()); + 0 /* TODO */, &target.get_addr(), target.get_size()); if (sent_data < 0) {