Fix sockets and add test (#114)

This commit is contained in:
Maurice Heumann
2025-01-26 16:51:41 +01:00
committed by GitHub
8 changed files with 401 additions and 53 deletions

View File

@@ -22,6 +22,16 @@ namespace utils
{ a.deserialize(deserializer) } -> std::same_as<void>;
};
template <typename T>
struct is_optional : std::false_type
{
};
template <typename T>
struct is_optional<std::optional<T>> : std::true_type
{
};
namespace detail
{
template <typename, typename = void>
@@ -349,6 +359,12 @@ namespace utils
const uint64_t old_size = this->buffer_.size();
#endif
if (this->break_offset_ && this->buffer_.size() <= *this->break_offset_ &&
this->buffer_.size() + length > *this->break_offset_)
{
throw std::runtime_error("Break offset reached!");
}
const auto* byte_buffer = static_cast<const std::byte*>(buffer);
this->buffer_.insert(this->buffer_.end(), byte_buffer, byte_buffer + length);
@@ -365,6 +381,7 @@ namespace utils
}
template <typename T>
requires(!is_optional<T>::value)
void write(const T& object)
{
constexpr auto is_trivially_copyable = std::is_trivially_copyable_v<T>;
@@ -475,8 +492,47 @@ namespace utils
return std::move(this->buffer_);
}
void set_break_offset(const size_t break_offset)
{
this->break_offset_ = break_offset;
}
std::optional<size_t> get_diff(const buffer_serializer& other) const
{
auto& b1 = this->get_buffer();
auto& b2 = other.get_buffer();
const auto s1 = b1.size();
const auto s2 = b2.size();
for (size_t i = 0; i < s1 && i < s2; ++i)
{
if (b1.at(i) != b2.at(i))
{
return i;
}
}
if (s1 != s2)
{
return std::min(s1, s2);
}
return std::nullopt;
}
void print_diff(const buffer_serializer& other) const
{
const auto diff = this->get_diff(other);
if (diff)
{
printf("Diff at %zd\n", *diff);
}
}
private:
std::vector<std::byte> buffer_{};
std::optional<size_t> break_offset_{};
};
template <>

View File

@@ -9,3 +9,7 @@ list(SORT SRC_FILES)
add_executable(test-sample ${SRC_FILES})
momo_assign_source_group(${SRC_FILES})
target_link_libraries(test-sample PRIVATE
emulator-common
)

View File

@@ -9,7 +9,7 @@
#include <filesystem>
#include <string_view>
#include <Windows.h>
#include <network/udp_socket.hpp>
using namespace std::literals;
@@ -195,7 +195,7 @@ std::optional<std::string> read_registry_string(const HKEY root, const char* pat
return "";
}
return {std::string(data, min(length - 1, sizeof(data)))};
return {std::string(data, std::min(static_cast<size_t>(length - 1), sizeof(data)))};
}
bool test_registry()
@@ -231,6 +231,36 @@ bool test_exceptions()
}
}
bool test_socket()
{
network::udp_socket receiver{AF_INET};
const network::udp_socket sender{AF_INET};
const network::address destination{"127.0.0.1:28970", AF_INET};
constexpr std::string_view send_data = "Hello World";
if (!receiver.bind(destination))
{
puts("Failed to bind socket!");
return false;
}
if (!sender.send(destination, send_data))
{
puts("Failed to send data!");
return false;
}
const auto response = receiver.receive();
if (!response)
{
puts("Failed to recieve data!");
return false;
}
return send_data == response->second;
}
void throw_access_violation()
{
if (do_the_task)
@@ -256,7 +286,7 @@ bool test_ud2_exception(void* address)
{
__try
{
static_cast<void (*)()>(address)();
reinterpret_cast<void (*)()>(address)();
return false;
}
__except (EXCEPTION_EXECUTE_HANDLER)
@@ -301,12 +331,18 @@ void print_time()
puts(res ? "Success" : "Fail"); \
}
int main(int argc, const char* argv[])
int main(const int argc, const char* argv[])
{
if (argc == 2 && argv[1] == "-time"sv)
bool reproducible = false;
if (argc == 2)
{
print_time();
return 0;
if (argv[1] == "-time"sv)
{
print_time();
return 0;
}
reproducible = argv[1] == "-reproducible"sv;
}
bool valid = true;
@@ -320,5 +356,10 @@ int main(int argc, const char* argv[])
RUN_TEST(test_native_exceptions, "Native Exceptions")
RUN_TEST(test_tls, "TLS")
if (!reproducible)
{
RUN_TEST(test_socket, "Socket")
}
return valid ? 0 : 1;
}

View File

@@ -110,6 +110,8 @@ CALL :collect_dll mscms.dll
CALL :collect_dll ktmw32.dll
CALL :collect_dll shcore.dll
CALL :collect_dll diagnosticdatasettings.dll
CALL :collect_dll mswsock.dll
CALL :collect_dll umpdc.dll
CALL :collect_dll locale.nls

View File

@@ -38,14 +38,20 @@ namespace test
return env;
}
inline windows_emulator create_sample_emulator(emulator_settings settings, emulator_callbacks callbacks = {})
inline windows_emulator create_sample_emulator(emulator_settings settings, const bool reproducible = false,
emulator_callbacks callbacks = {})
{
const auto is_verbose = enable_verbose_logging();
if (is_verbose)
{
settings.disable_logging = false;
settings.verbose_calls = true;
// settings.verbose_calls = true;
}
if (reproducible)
{
settings.arguments = {u"-reproducible"};
}
settings.application = "c:/test-sample.exe";
@@ -53,13 +59,73 @@ namespace test
return windows_emulator{std::move(settings), std::move(callbacks)};
}
inline windows_emulator create_sample_emulator()
inline windows_emulator create_sample_emulator(const bool reproducible = false)
{
emulator_settings settings{
.disable_logging = true,
.use_relative_time = true,
};
return create_sample_emulator(std::move(settings));
return create_sample_emulator(std::move(settings), reproducible);
}
inline void bisect_emulation(windows_emulator& emu)
{
utils::buffer_serializer start_state{};
emu.serialize(start_state);
emu.start();
const auto limit = emu.process().executed_instructions;
const auto reset_emulator = [&] {
utils::buffer_deserializer deserializer{start_state.get_buffer()};
emu.deserialize(deserializer);
};
const auto get_state_for_count = [&](const size_t count) {
reset_emulator();
emu.start({}, count);
utils::buffer_serializer state{};
emu.serialize(state);
return state;
};
const auto has_diff_after_count = [&](const size_t count) {
const auto s1 = get_state_for_count(count);
const auto s2 = get_state_for_count(count);
return s1.get_diff(s2).has_value();
};
if (!has_diff_after_count(limit))
{
puts("Emulation has no diff");
}
auto upper_bound = limit;
decltype(upper_bound) lower_bound = 0;
printf("Bounds: %" PRIx64 " - %" PRIx64 "\n", lower_bound, upper_bound);
while (lower_bound + 1 < upper_bound)
{
const auto diff = (upper_bound - lower_bound);
const auto pivot = lower_bound + (diff / 2);
const auto has_diff = has_diff_after_count(pivot);
auto* bound = has_diff ? &upper_bound : &lower_bound;
*bound = pivot;
printf("Bounds: %" PRIx64 " - %" PRIx64 "\n", lower_bound, upper_bound);
}
(void)get_state_for_count(lower_bound);
const auto rip = emu.emu().read_instruction_pointer();
printf("Diff detected after 0x%" PRIx64 " instructions at 0x%" PRIx64 " (%s)\n", lower_bound, rip,
emu.process().mod_manager.find_name(rip));
}
}

View File

@@ -4,7 +4,7 @@ namespace test
{
TEST(SerializationTest, ResettingEmulatorWorks)
{
auto emu = create_sample_emulator();
auto emu = create_sample_emulator(true);
utils::buffer_serializer start_state{};
emu.serialize(start_state);
@@ -31,7 +31,7 @@ namespace test
TEST(SerializationTest, SerializedDataIsReproducible)
{
auto emu1 = create_sample_emulator();
auto emu1 = create_sample_emulator(true);
emu1.start();
ASSERT_TERMINATED_SUCCESSFULLY(emu1);
@@ -55,7 +55,7 @@ namespace test
TEST(SerializationTest, EmulationIsReproducible)
{
auto emu1 = create_sample_emulator();
auto emu1 = create_sample_emulator(true);
emu1.start();
ASSERT_TERMINATED_SUCCESSFULLY(emu1);
@@ -63,7 +63,7 @@ namespace test
utils::buffer_serializer serializer1{};
emu1.serialize(serializer1);
auto emu2 = create_sample_emulator();
auto emu2 = create_sample_emulator(true);
emu2.start();
ASSERT_TERMINATED_SUCCESSFULLY(emu2);
@@ -76,7 +76,7 @@ namespace test
TEST(SerializationTest, DeserializedEmulatorBehavesLikeSource)
{
auto emu = create_sample_emulator();
auto emu = create_sample_emulator(true);
emu.start({}, 100);
utils::buffer_serializer serializer{};

View File

@@ -16,7 +16,7 @@ namespace test
.use_relative_time = false,
};
auto emu = create_sample_emulator(settings, callbacks);
auto emu = create_sample_emulator(settings, false, callbacks);
emu.start();
constexpr auto prefix = "Time: "sv;

View File

@@ -22,6 +22,189 @@ namespace
// ...
};
struct win_sockaddr
{
int16_t sa_family;
uint8_t sa_data[14];
};
struct win_sockaddr_in
{
int16_t sin_family;
uint16_t sin_port;
in_addr sin_addr;
uint8_t sin_zero[8];
};
struct win_sockaddr_in6
{
int16_t sin6_family;
uint16_t sin6_port;
uint32_t sin6_flowinfo;
in6_addr sin6_addr;
uint32_t sin6_scope_id;
};
static_assert(sizeof(win_sockaddr) == 16);
static_assert(sizeof(win_sockaddr_in) == 16);
static_assert(sizeof(win_sockaddr_in6) == 28);
static_assert(sizeof(win_sockaddr_in::sin_addr) == 4);
static_assert(sizeof(win_sockaddr_in6::sin6_addr) == 16);
static_assert(sizeof(win_sockaddr_in6::sin6_flowinfo) == sizeof(sockaddr_in6::sin6_flowinfo));
static_assert(sizeof(win_sockaddr_in6::sin6_scope_id) == sizeof(sockaddr_in6::sin6_scope_id));
const std::map<int, int> address_family_map{
{0, AF_UNSPEC}, //
{2, AF_INET}, //
{23, AF_INET6}, //
};
const std::map<int, int> socket_type_map{
{0, 0}, //
{1, SOCK_STREAM}, //
{2, SOCK_DGRAM}, //
{3, SOCK_RAW}, //
{4, SOCK_RDM}, //
};
const std::map<int, int> socket_protocol_map{
{0, 0}, //
{6, IPPROTO_TCP}, //
{17, IPPROTO_UDP}, //
{255, IPPROTO_RAW}, //
};
int16_t translate_host_to_win_address_family(const int host_af)
{
for (auto& entry : address_family_map)
{
if (entry.second == host_af)
{
return static_cast<int16_t>(entry.first);
}
}
throw std::runtime_error("Unknown host address family: " + std::to_string(host_af));
}
int translate_win_to_host_address_family(const int win_af)
{
const auto entry = address_family_map.find(win_af);
if (entry != address_family_map.end())
{
return entry->second;
}
throw std::runtime_error("Unknown address family: " + std::to_string(win_af));
}
int translate_win_to_host_type(const int win_type)
{
const auto entry = socket_type_map.find(win_type);
if (entry != socket_type_map.end())
{
return entry->second;
}
throw std::runtime_error("Unknown socket type: " + std::to_string(win_type));
}
int translate_win_to_host_protocol(const int win_protocol)
{
const auto entry = socket_protocol_map.find(win_protocol);
if (entry != socket_protocol_map.end())
{
return entry->second;
}
throw std::runtime_error("Unknown socket protocol: " + std::to_string(win_protocol));
}
std::vector<std::byte> convert_to_win_address(const network::address& a)
{
if (a.is_ipv4())
{
win_sockaddr_in win_addr{};
win_addr.sin_family = translate_host_to_win_address_family(a.get_family());
win_addr.sin_port = htons(a.get_port());
memcpy(&win_addr.sin_addr, &a.get_in_addr().sin_addr, sizeof(win_addr.sin_addr));
const auto ptr = reinterpret_cast<std::byte*>(&win_addr);
return {ptr, ptr + sizeof(win_addr)};
}
if (a.is_ipv6())
{
win_sockaddr_in6 win_addr{};
win_addr.sin6_family = translate_host_to_win_address_family(a.get_family());
win_addr.sin6_port = htons(a.get_port());
auto& addr = a.get_in6_addr();
memcpy(&win_addr.sin6_addr, &addr.sin6_addr, sizeof(win_addr.sin6_addr));
win_addr.sin6_flowinfo = addr.sin6_flowinfo;
win_addr.sin6_scope_id = addr.sin6_scope_id;
const auto ptr = reinterpret_cast<std::byte*>(&win_addr);
return {ptr, ptr + sizeof(win_addr)};
}
throw std::runtime_error("Unsupported host address family for conversion: " + std::to_string(a.get_family()));
}
network::address convert_to_host_address(const std::span<const std::byte> data)
{
if (data.size() < sizeof(win_sockaddr))
{
throw std::runtime_error("Bad address size");
}
win_sockaddr win_addr{};
memcpy(&win_addr, data.data(), sizeof(win_addr));
const auto family = translate_win_to_host_address_family(win_addr.sa_family);
network::address a{};
if (family == AF_INET)
{
if (data.size() < sizeof(win_sockaddr_in))
{
throw std::runtime_error("Bad IPv4 address size");
}
win_sockaddr_in win_addr4{};
memcpy(&win_addr4, data.data(), sizeof(win_addr4));
a.set_ipv4(win_addr4.sin_addr);
a.set_port(ntohs(win_addr4.sin_port));
return a;
}
if (family == AF_INET6)
{
if (data.size() < sizeof(win_sockaddr_in6))
{
throw std::runtime_error("Bad IPv6 address size");
}
win_sockaddr_in6 win_addr6{};
memcpy(&win_addr6, data.data(), sizeof(win_addr6));
a.set_ipv6(win_addr6.sin6_addr);
a.set_port(ntohs(win_addr6.sin6_port));
auto& addr = a.get_in6_addr();
addr.sin6_flowinfo = win_addr6.sin6_flowinfo;
addr.sin6_scope_id = win_addr6.sin6_scope_id;
return a;
}
throw std::runtime_error("Unsupported win address family for conversion: " + std::to_string(family));
}
afd_creation_data get_creation_data(windows_emulator& win_emu, const io_device_creation_data& data)
{
if (!data.buffer || data.length < sizeof(afd_creation_data))
@@ -216,8 +399,11 @@ namespace
const auto& data = *this->creation_data;
// TODO: values map to windows values; might not be the case for other platforms
const auto sock = socket(data.address_family, data.type, data.protocol);
const auto af = translate_win_to_host_address_family(data.address_family);
const auto type = translate_win_to_host_type(data.type);
const auto protocol = translate_win_to_host_protocol(data.protocol);
const auto sock = socket(af, type, protocol);
if (sock == INVALID_SOCKET)
{
throw std::runtime_error("Failed to create socket!");
@@ -290,20 +476,20 @@ namespace
void deserialize(utils::buffer_deserializer& buffer) override
{
buffer.read(this->creation_data);
buffer.read_optional(this->creation_data);
this->setup();
buffer.read(this->require_poll_);
buffer.read(this->delayed_ioctl_);
buffer.read(this->timeout_);
buffer.read_optional(this->require_poll_);
buffer.read_optional(this->delayed_ioctl_);
buffer.read_optional(this->timeout_);
}
void serialize(utils::buffer_serializer& buffer) const override
{
buffer.write(this->creation_data);
buffer.write(this->require_poll_);
buffer.write(this->delayed_ioctl_);
buffer.write(this->timeout_);
buffer.write_optional(this->creation_data);
buffer.write_optional(this->require_poll_);
buffer.write_optional(this->delayed_ioctl_);
buffer.write_optional(this->timeout_);
}
NTSTATUS io_control(windows_emulator& win_emu, const io_device_context& c) override
@@ -339,7 +525,7 @@ namespace
NTSTATUS ioctl_bind(windows_emulator& win_emu, const io_device_context& c) const
{
const auto data = win_emu.emu().read_memory(c.input_buffer, c.input_buffer_length);
auto data = win_emu.emu().read_memory(c.input_buffer, c.input_buffer_length);
constexpr auto address_offset = 4;
@@ -348,10 +534,7 @@ namespace
return STATUS_BUFFER_TOO_SMALL;
}
const auto* address = reinterpret_cast<const sockaddr*>(data.data() + address_offset);
const auto address_size = static_cast<socklen_t>(data.size() - address_offset);
const network::address addr(address, address_size);
const auto addr = convert_to_host_address(std::span(data).subspan(address_offset));
if (bind(*this->s_, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR)
{
@@ -431,28 +614,19 @@ namespace
const auto receive_info = emu.read_memory<AFD_RECV_DATAGRAM_INFO<EmulatorTraits<Emu64>>>(c.input_buffer);
const auto buffer = emu.read_memory<EMU_WSABUF<EmulatorTraits<Emu64>>>(receive_info.BufferArray);
std::vector<std::byte> address{};
unsigned long address_length = 0x1000;
if (receive_info.AddressLength)
{
address_length = emu.read_memory<ULONG>(receive_info.AddressLength);
}
address.resize(std::clamp(address_length, 1UL, 0x1000UL));
if (!buffer.len || buffer.len > 0x10000 || !buffer.buf)
{
return STATUS_INVALID_PARAMETER;
}
auto fromlength = static_cast<socklen_t>(address.size());
network::address from{};
auto from_length = from.get_max_size();
std::vector<char> data{};
data.resize(buffer.len);
const auto recevied_data = recvfrom(*this->s_, data.data(), static_cast<send_size>(data.size()), 0,
reinterpret_cast<sockaddr*>(address.data()), &fromlength);
&from.get_addr(), &from_length);
if (recevied_data < 0)
{
@@ -466,13 +640,20 @@ namespace
return STATUS_UNSUCCESSFUL;
}
assert(from.get_size() == from_length);
const auto data_size = std::min(data.size(), static_cast<size_t>(recevied_data));
emu.write_memory(buffer.buf, data.data(), data_size);
if (receive_info.Address && address_length)
const auto win_from = convert_to_win_address(from);
if (receive_info.Address && receive_info.AddressLength)
{
const auto address_size = std::min(address.size(), static_cast<size_t>(address_length));
emu.write_memory(receive_info.Address, address.data(), address_size);
const emulator_object<ULONG> address_length{emu, receive_info.AddressLength};
const auto address_size = std::min(win_from.size(), static_cast<size_t>(address_length.read()));
emu.write_memory(receive_info.Address, win_from.data(), address_size);
address_length.write(static_cast<ULONG>(address_size));
}
if (c.io_status_block)
@@ -497,17 +678,15 @@ namespace
const auto send_info = emu.read_memory<AFD_SEND_DATAGRAM_INFO<EmulatorTraits<Emu64>>>(c.input_buffer);
const auto buffer = emu.read_memory<EMU_WSABUF<EmulatorTraits<Emu64>>>(send_info.BufferArray);
const auto address = emu.read_memory(send_info.TdiConnInfo.RemoteAddress,
static_cast<size_t>(send_info.TdiConnInfo.RemoteAddressLength));
const network::address target(reinterpret_cast<const sockaddr*>(address.data()),
static_cast<socklen_t>(address.size()));
auto address_buffer = emu.read_memory(send_info.TdiConnInfo.RemoteAddress,
static_cast<size_t>(send_info.TdiConnInfo.RemoteAddressLength));
const auto target = convert_to_host_address(address_buffer);
const auto data = emu.read_memory(buffer.buf, buffer.len);
const auto sent_data =
sendto(*this->s_, reinterpret_cast<const char*>(data.data()), static_cast<send_size>(data.size()),
0 /* ? */, &target.get_addr(), target.get_size());
0 /* TODO */, &target.get_addr(), target.get_size());
if (sent_data < 0)
{