diff --git a/src/windows-emulator/devices/afd_endpoint.cpp b/src/windows-emulator/devices/afd_endpoint.cpp index 48d43c08..558e714a 100644 --- a/src/windows-emulator/devices/afd_endpoint.cpp +++ b/src/windows-emulator/devices/afd_endpoint.cpp @@ -22,6 +22,73 @@ namespace // ... }; + const std::map address_family_map{ + {0, AF_UNSPEC}, // + {2, AF_INET}, // + {23, AF_INET6}, // + }; + + const std::map socket_type_map{ + {0, 0}, // + {1, SOCK_STREAM}, // + {2, SOCK_DGRAM}, // + {3, SOCK_RAW}, // + {4, SOCK_RDM}, // + }; + + const std::map socket_protocol_map{ + {0, 0}, // + {6, IPPROTO_TCP}, // + {17, IPPROTO_UDP}, // + {255, IPPROTO_RAW}, // + }; + + int translate_host_to_win_address_family(const int host_af) + { + for (auto& entry : address_family_map) + { + if (entry.second == host_af) + { + return entry.first; + } + } + + throw std::runtime_error("Unknown host address family: " + std::to_string(host_af)); + } + + int translate_address_family(const int win_af) + { + const auto entry = address_family_map.find(win_af); + if (entry != address_family_map.end()) + { + return entry->second; + } + + throw std::runtime_error("Unknown address family: " + std::to_string(win_af)); + } + + int translate_type(const int win_type) + { + const auto entry = socket_type_map.find(win_type); + if (entry != socket_type_map.end()) + { + return entry->second; + } + + throw std::runtime_error("Unknown socket type: " + std::to_string(win_type)); + } + + int translate_protocol(const int win_protocol) + { + const auto entry = socket_protocol_map.find(win_protocol); + if (entry != socket_protocol_map.end()) + { + return entry->second; + } + + throw std::runtime_error("Unknown socket protocol: " + std::to_string(win_protocol)); + } + afd_creation_data get_creation_data(windows_emulator& win_emu, const io_device_creation_data& data) { if (!data.buffer || data.length < sizeof(afd_creation_data)) @@ -216,8 +283,11 @@ namespace const auto& data = *this->creation_data; - // TODO: values map to windows values; might not be the case for other platforms - const auto sock = socket(data.address_family, data.type, data.protocol); + const auto af = translate_address_family(data.address_family); + const auto type = translate_type(data.type); + const auto protocol = translate_protocol(data.protocol); + + const auto sock = socket(af, type, protocol); if (sock == INVALID_SOCKET) { throw std::runtime_error("Failed to create socket!"); @@ -339,7 +409,7 @@ namespace NTSTATUS ioctl_bind(windows_emulator& win_emu, const io_device_context& c) const { - const auto data = win_emu.emu().read_memory(c.input_buffer, c.input_buffer_length); + auto data = win_emu.emu().read_memory(c.input_buffer, c.input_buffer_length); constexpr auto address_offset = 4; @@ -348,9 +418,12 @@ namespace return STATUS_BUFFER_TOO_SMALL; } - const auto* address = reinterpret_cast(data.data() + address_offset); + auto* address = reinterpret_cast(data.data() + address_offset); const auto address_size = static_cast(data.size() - address_offset); + address->sa_family = + static_castsa_family)>(translate_address_family(address->sa_family)); + const network::address addr(address, address_size); if (bind(*this->s_, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR) @@ -451,8 +524,10 @@ namespace std::vector data{}; data.resize(buffer.len); - const auto recevied_data = recvfrom(*this->s_, data.data(), static_cast(data.size()), 0, - reinterpret_cast(address.data()), &fromlength); + auto* sender = reinterpret_cast(address.data()); + + const auto recevied_data = + recvfrom(*this->s_, data.data(), static_cast(data.size()), 0, sender, &fromlength); if (recevied_data < 0) { @@ -466,6 +541,10 @@ namespace return STATUS_UNSUCCESSFUL; } + // TODO: Translate rest of address struct? + sender->sa_family = + static_castsa_family)>(translate_host_to_win_address_family(sender->sa_family)); + const auto data_size = std::min(data.size(), static_cast(recevied_data)); emu.write_memory(buffer.buf, data.data(), data_size); @@ -497,11 +576,13 @@ namespace const auto send_info = emu.read_memory>>(c.input_buffer); const auto buffer = emu.read_memory>>(send_info.BufferArray); - const auto address = emu.read_memory(send_info.TdiConnInfo.RemoteAddress, - static_cast(send_info.TdiConnInfo.RemoteAddressLength)); + auto address_buffer = emu.read_memory(send_info.TdiConnInfo.RemoteAddress, + static_cast(send_info.TdiConnInfo.RemoteAddressLength)); + auto* address = reinterpret_cast(address_buffer.data()); + address->sa_family = + static_castsa_family)>(translate_address_family(address->sa_family)); - const network::address target(reinterpret_cast(address.data()), - static_cast(address.size())); + const network::address target(address, static_cast(address_buffer.size())); const auto data = emu.read_memory(buffer.buf, buffer.len);