Convert socket addresses

This commit is contained in:
momo5502
2025-01-26 16:34:07 +01:00
parent b3d4d32fbd
commit 5387c45da2

View File

@@ -49,8 +49,8 @@ namespace
static_assert(sizeof(win_sockaddr_in) == 16);
static_assert(sizeof(win_sockaddr_in6) == 28);
static_assert(sizeof(win_sockaddr_in::sin_addr) == sizeof(sockaddr_in::sin_addr));
static_assert(sizeof(win_sockaddr_in6::sin6_addr) == sizeof(sockaddr_in6::sin6_addr));
static_assert(sizeof(win_sockaddr_in::sin_addr) == 4);
static_assert(sizeof(win_sockaddr_in6::sin6_addr) == 16);
static_assert(sizeof(win_sockaddr_in6::sin6_flowinfo) == sizeof(sockaddr_in6::sin6_flowinfo));
static_assert(sizeof(win_sockaddr_in6::sin6_scope_id) == sizeof(sockaddr_in6::sin6_scope_id));
@@ -75,20 +75,20 @@ namespace
{255, IPPROTO_RAW}, //
};
int translate_host_to_win_address_family(const int host_af)
int16_t translate_host_to_win_address_family(const int host_af)
{
for (auto& entry : address_family_map)
{
if (entry.second == host_af)
{
return entry.first;
return static_cast<int16_t>(entry.first);
}
}
throw std::runtime_error("Unknown host address family: " + std::to_string(host_af));
}
int translate_address_family(const int win_af)
int translate_win_to_host_address_family(const int win_af)
{
const auto entry = address_family_map.find(win_af);
if (entry != address_family_map.end())
@@ -99,7 +99,7 @@ namespace
throw std::runtime_error("Unknown address family: " + std::to_string(win_af));
}
int translate_type(const int win_type)
int translate_win_to_host_type(const int win_type)
{
const auto entry = socket_type_map.find(win_type);
if (entry != socket_type_map.end())
@@ -110,7 +110,7 @@ namespace
throw std::runtime_error("Unknown socket type: " + std::to_string(win_type));
}
int translate_protocol(const int win_protocol)
int translate_win_to_host_protocol(const int win_protocol)
{
const auto entry = socket_protocol_map.find(win_protocol);
if (entry != socket_protocol_map.end())
@@ -121,6 +121,90 @@ namespace
throw std::runtime_error("Unknown socket protocol: " + std::to_string(win_protocol));
}
std::vector<std::byte> convert_to_win_address(const network::address& a)
{
if (a.is_ipv4())
{
win_sockaddr_in win_addr{};
win_addr.sin_family = translate_host_to_win_address_family(a.get_family());
win_addr.sin_port = htons(a.get_port());
memcpy(&win_addr.sin_addr, &a.get_in_addr().sin_addr, sizeof(win_addr.sin_addr));
const auto ptr = reinterpret_cast<std::byte*>(&win_addr);
return {ptr, ptr + sizeof(win_addr)};
}
if (a.is_ipv6())
{
win_sockaddr_in6 win_addr{};
win_addr.sin6_family = translate_host_to_win_address_family(a.get_family());
win_addr.sin6_port = htons(a.get_port());
auto& addr = a.get_in6_addr();
memcpy(&win_addr.sin6_addr, &addr.sin6_addr, sizeof(win_addr.sin6_addr));
win_addr.sin6_flowinfo = addr.sin6_flowinfo;
win_addr.sin6_scope_id = addr.sin6_scope_id;
const auto ptr = reinterpret_cast<std::byte*>(&win_addr);
return {ptr, ptr + sizeof(win_addr)};
}
throw std::runtime_error("Unsupported host address family for conversion: " + std::to_string(a.get_family()));
}
network::address convert_to_host_address(const std::span<const std::byte> data)
{
if (data.size() < sizeof(win_sockaddr))
{
throw std::runtime_error("Bad address size");
}
win_sockaddr win_addr{};
memcpy(&win_addr, data.data(), sizeof(win_addr));
const auto family = translate_win_to_host_address_family(win_addr.sa_family);
network::address a{};
if (family == AF_INET)
{
if (data.size() < sizeof(win_sockaddr_in))
{
throw std::runtime_error("Bad IPv4 address size");
}
win_sockaddr_in win_addr4{};
memcpy(&win_addr4, data.data(), sizeof(win_addr4));
a.set_ipv4(win_addr4.sin_addr);
a.set_port(ntohs(win_addr4.sin_port));
return a;
}
if (family == AF_INET6)
{
if (data.size() < sizeof(win_sockaddr_in6))
{
throw std::runtime_error("Bad IPv6 address size");
}
win_sockaddr_in6 win_addr6{};
memcpy(&win_addr6, data.data(), sizeof(win_addr6));
a.set_ipv6(win_addr6.sin6_addr);
a.set_port(ntohs(win_addr6.sin6_port));
auto& addr = a.get_in6_addr();
addr.sin6_flowinfo = win_addr6.sin6_flowinfo;
addr.sin6_scope_id = win_addr6.sin6_scope_id;
return a;
}
throw std::runtime_error("Unsupported win address family for conversion: " + std::to_string(family));
}
afd_creation_data get_creation_data(windows_emulator& win_emu, const io_device_creation_data& data)
{
if (!data.buffer || data.length < sizeof(afd_creation_data))
@@ -315,9 +399,9 @@ namespace
const auto& data = *this->creation_data;
const auto af = translate_address_family(data.address_family);
const auto type = translate_type(data.type);
const auto protocol = translate_protocol(data.protocol);
const auto af = translate_win_to_host_address_family(data.address_family);
const auto type = translate_win_to_host_type(data.type);
const auto protocol = translate_win_to_host_protocol(data.protocol);
const auto sock = socket(af, type, protocol);
if (sock == INVALID_SOCKET)
@@ -450,13 +534,7 @@ namespace
return STATUS_BUFFER_TOO_SMALL;
}
auto* address = reinterpret_cast<sockaddr*>(data.data() + address_offset);
const auto address_size = static_cast<socklen_t>(data.size() - address_offset);
address->sa_family =
static_cast<decltype(address->sa_family)>(translate_address_family(address->sa_family));
const network::address addr(address, address_size);
const auto addr = convert_to_host_address(std::span(data).subspan(address_offset));
if (bind(*this->s_, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR)
{
@@ -536,30 +614,19 @@ namespace
const auto receive_info = emu.read_memory<AFD_RECV_DATAGRAM_INFO<EmulatorTraits<Emu64>>>(c.input_buffer);
const auto buffer = emu.read_memory<EMU_WSABUF<EmulatorTraits<Emu64>>>(receive_info.BufferArray);
std::vector<std::byte> address{};
unsigned long address_length = 0x1000;
if (receive_info.AddressLength)
{
address_length = emu.read_memory<ULONG>(receive_info.AddressLength);
}
address.resize(std::clamp(address_length, 1UL, 0x1000UL));
if (!buffer.len || buffer.len > 0x10000 || !buffer.buf)
{
return STATUS_INVALID_PARAMETER;
}
auto fromlength = static_cast<socklen_t>(address.size());
network::address from{};
auto from_length = from.get_max_size();
std::vector<char> data{};
data.resize(buffer.len);
auto* sender = reinterpret_cast<sockaddr*>(address.data());
const auto recevied_data =
recvfrom(*this->s_, data.data(), static_cast<send_size>(data.size()), 0, sender, &fromlength);
const auto recevied_data = recvfrom(*this->s_, data.data(), static_cast<send_size>(data.size()), 0,
&from.get_addr(), &from_length);
if (recevied_data < 0)
{
@@ -573,17 +640,20 @@ namespace
return STATUS_UNSUCCESSFUL;
}
// TODO: Translate rest of address struct?
sender->sa_family =
static_cast<decltype(sender->sa_family)>(translate_host_to_win_address_family(sender->sa_family));
assert(from.get_size() == from_length);
const auto data_size = std::min(data.size(), static_cast<size_t>(recevied_data));
emu.write_memory(buffer.buf, data.data(), data_size);
if (receive_info.Address && address_length)
const auto win_from = convert_to_win_address(from);
if (receive_info.Address && receive_info.AddressLength)
{
const auto address_size = std::min(address.size(), static_cast<size_t>(address_length));
emu.write_memory(receive_info.Address, address.data(), address_size);
const emulator_object<ULONG> address_length{emu, receive_info.AddressLength};
const auto address_size = std::min(win_from.size(), static_cast<size_t>(address_length.read()));
emu.write_memory(receive_info.Address, win_from.data(), address_size);
address_length.write(static_cast<ULONG>(address_size));
}
if (c.io_status_block)
@@ -610,17 +680,13 @@ namespace
auto address_buffer = emu.read_memory(send_info.TdiConnInfo.RemoteAddress,
static_cast<size_t>(send_info.TdiConnInfo.RemoteAddressLength));
auto* address = reinterpret_cast<sockaddr*>(address_buffer.data());
address->sa_family =
static_cast<decltype(address->sa_family)>(translate_address_family(address->sa_family));
const network::address target(address, static_cast<socklen_t>(address_buffer.size()));
const auto target = convert_to_host_address(address_buffer);
const auto data = emu.read_memory(buffer.buf, buffer.len);
const auto sent_data =
sendto(*this->s_, reinterpret_cast<const char*>(data.data()), static_cast<send_size>(data.size()),
0 /* ? */, &target.get_addr(), target.get_size());
0 /* TODO */, &target.get_addr(), target.get_size());
if (sent_data < 0)
{