Add socket abstraction

This commit is contained in:
Maurice Heumann
2025-03-20 15:17:43 +01:00
parent 2cb14a3555
commit 4da6642123
15 changed files with 437 additions and 71 deletions

View File

@@ -8,6 +8,7 @@
#ifdef _WIN32
using send_size = int;
using sent_size = int;
#define GET_SOCKET_ERROR() (WSAGetLastError())
#define poll WSAPoll
#define SERR(x) (WSA##x)
@@ -15,6 +16,7 @@ using send_size = int;
#else
using SOCKET = int;
using send_size = size_t;
using sent_size = ssize_t;
#define INVALID_SOCKET (SOCKET)(~0)
#define SOCKET_ERROR (-1)
#define GET_SOCKET_ERROR() (errno)

View File

@@ -400,16 +400,10 @@ void print_time()
int main(const int argc, const char* argv[])
{
bool reproducible = false;
if (argc == 2)
if (argc == 2 && argv[1] == "-time"sv)
{
if (argv[1] == "-time"sv)
{
print_time();
return 0;
}
reproducible = argv[1] == "-reproducible"sv;
print_time();
return 0;
}
bool valid = true;
@@ -423,13 +417,9 @@ int main(const int argc, const char* argv[])
RUN_TEST(test_exceptions, "Exceptions")
RUN_TEST(test_native_exceptions, "Native Exceptions")
RUN_TEST(test_tls, "TLS")
RUN_TEST(test_socket, "Socket")
Sleep(1);
if (!reproducible)
{
RUN_TEST(test_socket, "Socket")
}
return valid ? 0 : 1;
}

View File

@@ -4,6 +4,8 @@
#include <gtest/gtest.h>
#include <windows_emulator.hpp>
#include "static_socket_factory.hpp"
#define ASSERT_NOT_TERMINATED(win_emu) \
do \
{ \
@@ -40,7 +42,6 @@ namespace test
struct sample_configuration
{
bool reproducible{false};
bool print_time{false};
};
@@ -53,11 +54,6 @@ namespace test
settings.arguments.emplace_back(u"-time");
}
if (config.reproducible)
{
settings.arguments.emplace_back(u"-reproducible");
}
return settings;
}
@@ -74,7 +70,6 @@ namespace test
settings.emulation_root = get_emulator_root();
settings.port_mappings[28970] = static_cast<uint16_t>(getpid());
settings.path_mappings["C:\\a.txt"] =
std::filesystem::temp_directory_path() / ("emulator-test-file-" + std::to_string(getpid()) + ".txt");
@@ -82,6 +77,9 @@ namespace test
get_sample_app_settings(config),
settings,
std::move(callbacks),
emulator_interfaces{
.socket_factory = network::create_static_socket_factory(),
},
};
}
@@ -97,7 +95,7 @@ namespace test
inline windows_emulator create_reproducible_sample_emulator()
{
return create_sample_emulator({.reproducible = true});
return create_sample_emulator();
}
inline void bisect_emulation(windows_emulator& emu)

View File

@@ -84,7 +84,13 @@ namespace test
utils::buffer_deserializer deserializer{serializer.get_buffer()};
windows_emulator new_emu{{.emulation_root = get_emulator_root(), .use_relative_time = true}};
windows_emulator new_emu{
{.emulation_root = get_emulator_root(), .use_relative_time = true},
{
.socket_factory = network::create_static_socket_factory(),
},
};
new_emu.log.disable_output(true);
new_emu.deserialize(deserializer);

View File

@@ -0,0 +1,129 @@
#include "static_socket_factory.hpp"
#include <queue>
#include <stdexcept>
#include <unordered_map>
#include <network/socket.hpp>
namespace network
{
namespace
{
struct static_socket_factory : socket_factory
{
using packet_data = std::vector<std::byte>;
using packet = std::pair<address, packet_data>;
using packet_queue = std::queue<packet>;
using packet_mapping = std::unordered_map<address, packet_queue>;
std::shared_ptr<packet_mapping> packets = std::make_shared<packet_mapping>();
uint16_t port{0};
struct static_socket : i_socket
{
int error{0};
address a{};
std::shared_ptr<packet_mapping> packets{};
static_socket(static_socket_factory& f, const int af)
: packets(f.packets)
{
if (af == AF_INET)
{
a.set_ipv4(0);
}
else if (af == AF_INET6)
{
a.set_ipv6({});
}
else
{
throw std::runtime_error("Invalid address family");
}
a.set_port(++f.port);
}
~static_socket() override = default;
void set_blocking(const bool blocking) override
{
if (blocking)
{
throw std::runtime_error("Blocking sockets not supported yet!");
}
}
int get_last_error() override
{
return this->error;
}
bool is_ready(const bool) override
{
return true;
}
bool bind(const address& addr) override
{
this->a = addr;
return true;
}
sent_size send(std::span<const std::byte>) override
{
throw std::runtime_error("Not implemented");
}
sent_size sendto(const address& destination, std::span<const std::byte> data) override
{
this->error = 0;
(*this->packets)[destination].emplace(this->a, packet_data{data.begin(), data.end()});
return static_cast<int>(data.size());
}
sent_size recv(std::span<std::byte>) override
{
throw std::runtime_error("Not implemented");
}
sent_size recvfrom(address& source, std::span<std::byte> data) override
{
this->error = 0;
auto& q = (*this->packets)[this->a];
if (q.empty())
{
this->error = SERR(EWOULDBLOCK);
return -1;
}
const auto p = std::move(q.front());
q.pop();
memcpy(data.data(), p.second.data(), std::min(data.size(), p.second.size()));
source = p.first;
return static_cast<int>(p.second.size());
}
};
std::unique_ptr<i_socket> create_socket(const int af, const int, const int) override
{
return std::make_unique<static_socket>(*this, af);
}
int poll_sockets(std::span<poll_entry>) override
{
throw std::runtime_error("Not implemented");
}
};
}
std::unique_ptr<socket_factory> create_static_socket_factory()
{
return std::make_unique<static_socket_factory>();
}
}

View File

@@ -0,0 +1,8 @@
#pragma once
#include <network/socket_factory.hpp>
namespace network
{
std::unique_ptr<socket_factory> create_static_socket_factory();
}

View File

@@ -3,6 +3,7 @@
#include "afd_types.hpp"
#include "../windows_emulator.hpp"
#include "../network/socket_factory.hpp"
#include <network/address.hpp>
#include <network/socket.hpp>
@@ -313,10 +314,10 @@ namespace
}
NTSTATUS perform_poll(windows_emulator& win_emu, const io_device_context& c,
const std::span<const SOCKET> endpoints,
const std::span<network::i_socket* const> endpoints,
const std::span<const AFD_POLL_HANDLE_INFO64> handles)
{
std::vector<pollfd> poll_data{};
std::vector<network::poll_entry> poll_data{};
poll_data.resize(endpoints.size());
for (size_t i = 0; i < endpoints.size() && i < handles.size(); ++i)
@@ -324,12 +325,12 @@ namespace
auto& pfd = poll_data.at(i);
const auto& handle = handles[i];
pfd.fd = endpoints[i];
pfd.s = endpoints[i];
pfd.events = map_afd_request_events_to_socket(handle.PollEvents);
pfd.revents = pfd.events;
}
const auto count = poll(poll_data.data(), static_cast<uint32_t>(poll_data.size()), 0);
const auto count = win_emu.socket_factory().poll_sockets(poll_data);
if (count <= 0)
{
return STATUS_PENDING;
@@ -358,17 +359,20 @@ namespace
assert(current_index == static_cast<size_t>(count));
emulator_object<AFD_POLL_INFO64>{win_emu.emu(), c.input_buffer}.access(
[&](AFD_POLL_INFO64& info) { info.NumberOfHandles = static_cast<ULONG>(current_index); });
const emulator_object<AFD_POLL_INFO64> info_obj{win_emu.emu(), c.input_buffer};
info_obj.access([&](AFD_POLL_INFO64& info) {
info.NumberOfHandles = static_cast<ULONG>(current_index); //
});
return STATUS_SUCCESS;
}
struct afd_endpoint : io_device
{
std::unique_ptr<network::i_socket> s_{};
bool executing_delayed_ioctl_{};
std::optional<afd_creation_data> creation_data{};
std::optional<SOCKET> s_{};
std::optional<bool> require_poll_{};
std::optional<io_device_context> delayed_ioctl_{};
std::optional<std::chrono::steady_clock::time_point> timeout_{};
@@ -381,21 +385,15 @@ namespace
afd_endpoint(afd_endpoint&&) = delete;
afd_endpoint& operator=(afd_endpoint&&) = delete;
~afd_endpoint() override
{
if (this->s_)
{
closesocket(*this->s_);
}
}
~afd_endpoint() override = default;
void create(windows_emulator& win_emu, const io_device_creation_data& data) override
{
this->creation_data = get_creation_data(win_emu, data);
this->setup();
this->setup(win_emu.socket_factory());
}
void setup()
void setup(network::socket_factory& factory)
{
if (!this->creation_data)
{
@@ -408,15 +406,13 @@ namespace
const auto type = translate_win_to_host_type(data.type);
const auto protocol = translate_win_to_host_protocol(data.protocol);
const auto sock = socket(af, type, protocol);
if (sock == INVALID_SOCKET)
this->s_ = factory.create_socket(af, type, protocol);
if (!this->s_)
{
throw std::runtime_error("Failed to create socket!");
}
network::socket::set_blocking(sock, false);
this->s_ = sock;
this->s_->set_blocking(false);
}
void delay_ioctrl(const io_device_context& c,
@@ -452,7 +448,7 @@ namespace
if (this->require_poll_.has_value())
{
const auto is_ready = network::socket::is_socket_ready(*this->s_, *this->require_poll_);
const auto is_ready = this->s_->is_ready(*this->require_poll_);
if (!is_ready)
{
return;
@@ -482,7 +478,7 @@ namespace
void deserialize_object(utils::buffer_deserializer& buffer) override
{
buffer.read_optional(this->creation_data);
this->setup();
this->setup(buffer.read<socket_factory_wrapper>());
buffer.read_optional(this->require_poll_);
buffer.read_optional(this->delayed_ioctl_);
@@ -546,7 +542,7 @@ namespace
const auto addr = convert_to_host_address(win_emu, std::span(data).subspan(address_offset));
if (bind(*this->s_, &addr.get_addr(), addr.get_size()) == SOCKET_ERROR)
if (!this->s_->bind(addr))
{
return STATUS_ADDRESS_ALREADY_ASSOCIATED;
}
@@ -554,12 +550,12 @@ namespace
return STATUS_SUCCESS;
}
static std::vector<SOCKET> resolve_endpoints(windows_emulator& win_emu,
const std::span<const AFD_POLL_HANDLE_INFO64> handles)
static std::vector<network::i_socket*> resolve_endpoints(windows_emulator& win_emu,
const std::span<const AFD_POLL_HANDLE_INFO64> handles)
{
auto& proc = win_emu.process;
std::vector<SOCKET> endpoints{};
std::vector<network::i_socket*> endpoints{};
endpoints.reserve(handles.size());
for (const auto& handle : handles)
@@ -576,7 +572,7 @@ namespace
throw std::runtime_error("Invalid AFD endpoint!");
}
endpoints.push_back(*endpoint->s_);
endpoints.push_back(endpoint->s_.get());
}
return endpoints;
@@ -635,17 +631,14 @@ namespace
}
network::address from{};
auto from_length = from.get_max_size();
std::vector<char> data{};
std::vector<std::byte> data{};
data.resize(buffer.len);
const auto recevied_data = recvfrom(*this->s_, data.data(), static_cast<send_size>(data.size()), 0,
&from.get_addr(), &from_length);
const auto recevied_data = this->s_->recvfrom(from, data);
if (recevied_data < 0)
{
const auto error = GET_SOCKET_ERROR();
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, true);
@@ -655,8 +648,6 @@ namespace
return STATUS_UNSUCCESSFUL;
}
assert(from.get_size() == from_length);
const auto data_size = std::min(data.size(), static_cast<size_t>(recevied_data));
emu.write_memory(buffer.buf, data.data(), data_size);
@@ -704,13 +695,10 @@ namespace
const auto target = convert_to_host_address(win_emu, address_buffer);
const auto data = emu.read_memory(buffer.buf, buffer.len);
const auto sent_data =
sendto(*this->s_, reinterpret_cast<const char*>(data.data()), static_cast<send_size>(data.size()),
0 /* TODO */, &target.get_addr(), target.get_size());
const auto sent_data = this->s_->sendto(target, data);
if (sent_data < 0)
{
const auto error = GET_SOCKET_ERROR();
const auto error = this->s_->get_last_error();
if (error == SERR(EWOULDBLOCK))
{
this->delay_ioctrl(c, {}, false);

View File

@@ -8,6 +8,11 @@
#include <utils/time.hpp>
namespace network
{
struct socket_factory;
}
// TODO: Replace with pointer handling structure for future 32 bit support
using emulator_pointer = uint64_t;
@@ -51,6 +56,7 @@ using memory_manager_wrapper = object_wrapper<memory_manager>;
using module_manager_wrapper = object_wrapper<module_manager>;
using process_context_wrapper = object_wrapper<process_context>;
using windows_emulator_wrapper = object_wrapper<windows_emulator>;
using socket_factory_wrapper = object_wrapper<network::socket_factory>;
template <typename T>
class emulator_object

View File

@@ -0,0 +1,26 @@
#pragma once
#include <span>
#include <network/socket.hpp>
namespace network
{
struct i_socket
{
virtual ~i_socket() = default;
virtual void set_blocking(bool blocking) = 0;
virtual int get_last_error() = 0;
virtual bool is_ready(bool in_poll) = 0;
virtual bool bind(const address& addr) = 0;
virtual sent_size send(std::span<const std::byte> data) = 0;
virtual sent_size sendto(const address& destination, std::span<const std::byte> data) = 0;
virtual sent_size recv(std::span<std::byte> data) = 0;
virtual sent_size recvfrom(address& source, std::span<std::byte> data) = 0;
};
}

View File

@@ -0,0 +1,55 @@
#include "socket_factory.hpp"
#include "socket_wrapper.hpp"
namespace network
{
socket_factory::socket_factory()
{
initialize_wsa();
}
std::unique_ptr<i_socket> socket_factory::create_socket(int af, int type, int protocol)
{
return std::make_unique<socket_wrapper>(af, type, protocol);
}
int socket_factory::poll_sockets(const std::span<poll_entry> entries)
{
std::vector<pollfd> poll_data{};
poll_data.reserve(entries.size());
for (const auto& entry : entries)
{
if (!entry.s)
{
throw std::runtime_error("Bad socket given!");
}
const auto* wrapper = dynamic_cast<socket_wrapper*>(entry.s);
if (!wrapper)
{
throw std::runtime_error("Socket was not created using the given factory");
}
pollfd fd{};
fd.fd = wrapper->get().get_socket();
fd.events = entry.events;
fd.revents = entry.revents;
poll_data.push_back(fd);
}
const auto res = poll(poll_data.data(), static_cast<uint32_t>(poll_data.size()), 0);
for (size_t i = 0; i < poll_data.size() && i < entries.size(); ++i)
{
auto& entry = entries[i];
const auto& fd = poll_data[i];
entry.events = fd.events;
entry.revents = fd.revents;
}
return res;
}
}

View File

@@ -0,0 +1,24 @@
#pragma once
#include "i_socket.hpp"
#include <memory>
namespace network
{
struct poll_entry
{
i_socket* s{};
int16_t events{};
int16_t revents{};
};
struct socket_factory
{
socket_factory();
virtual ~socket_factory() = default;
virtual std::unique_ptr<i_socket> create_socket(int af, int type, int protocol);
virtual int poll_sockets(std::span<poll_entry> entries);
};
}

View File

@@ -0,0 +1,59 @@
#include "socket_wrapper.hpp"
#include <cassert>
namespace network
{
socket_wrapper::socket_wrapper(const int af, const int type, const int protocol)
: socket_(af, type, protocol)
{
}
void socket_wrapper::set_blocking(const bool blocking)
{
this->socket_.set_blocking(blocking);
}
int socket_wrapper::get_last_error()
{
return GET_SOCKET_ERROR();
}
bool socket_wrapper::is_ready(const bool in_poll)
{
return this->socket_.is_ready(in_poll);
}
bool socket_wrapper::bind(const address& addr)
{
return this->socket_.bind(addr);
}
sent_size socket_wrapper::send(const std::span<const std::byte> data)
{
return ::send(this->socket_.get_socket(), reinterpret_cast<const char*>(data.data()),
static_cast<send_size>(data.size()), 0);
}
sent_size socket_wrapper::sendto(const address& destination, const std::span<const std::byte> data)
{
return ::sendto(this->socket_.get_socket(), reinterpret_cast<const char*>(data.data()),
static_cast<send_size>(data.size()), 0, &destination.get_addr(), destination.get_size());
}
sent_size socket_wrapper::recv(std::span<std::byte> data)
{
return ::recv(this->socket_.get_socket(), reinterpret_cast<char*>(data.data()),
static_cast<send_size>(data.size()), 0);
}
sent_size socket_wrapper::recvfrom(address& source, std::span<std::byte> data)
{
auto source_length = source.get_max_size();
const auto res = ::recvfrom(this->socket_.get_socket(), reinterpret_cast<char*>(data.data()),
static_cast<send_size>(data.size()), 0, &source.get_addr(), &source_length);
assert(source.get_size() == source_length);
return res;
}
}

View File

@@ -0,0 +1,35 @@
#pragma once
#include "i_socket.hpp"
namespace network
{
class socket_wrapper : public i_socket
{
public:
socket_wrapper(int af, int type, int protocol);
~socket_wrapper() override = default;
void set_blocking(bool blocking) override;
int get_last_error() override;
bool is_ready(bool in_poll) override;
bool bind(const address& addr) override;
sent_size send(std::span<const std::byte> data) override;
sent_size sendto(const address& destination, std::span<const std::byte> data) override;
sent_size recv(std::span<std::byte> data) override;
sent_size recvfrom(address& source, std::span<std::byte> data) override;
const socket& get() const
{
return this->socket_;
}
private:
socket socket_{};
};
}

View File

@@ -179,8 +179,14 @@ namespace
}
};
std::unique_ptr<utils::clock> get_clock(const uint64_t& instructions, const bool use_relative_time)
std::unique_ptr<utils::clock> get_clock(emulator_interfaces& interfaces, const uint64_t& instructions,
const bool use_relative_time)
{
if (interfaces.clock)
{
return std::move(interfaces.clock);
}
if (use_relative_time)
{
return std::make_unique<instruction_tick_clock>(instructions);
@@ -188,6 +194,15 @@ namespace
return std::make_unique<utils::clock>();
}
std::unique_ptr<network::socket_factory> get_socket_factory(emulator_interfaces& interfaces)
{
if (interfaces.socket_factory)
{
return std::move(interfaces.socket_factory);
}
return std::make_unique<network::socket_factory>();
}
}
std::unique_ptr<x64_emulator> create_default_x64_emulator()
@@ -196,8 +211,9 @@ std::unique_ptr<x64_emulator> create_default_x64_emulator()
}
windows_emulator::windows_emulator(application_settings app_settings, const emulator_settings& settings,
emulator_callbacks callbacks, std::unique_ptr<x64_emulator> emu)
: windows_emulator(settings, std::move(emu))
emulator_callbacks callbacks, emulator_interfaces interfaces,
std::unique_ptr<x64_emulator> emu)
: windows_emulator(settings, std::move(interfaces), std::move(emu))
{
this->callbacks = std::move(callbacks);
@@ -205,9 +221,11 @@ windows_emulator::windows_emulator(application_settings app_settings, const emul
this->setup_process(app_settings);
}
windows_emulator::windows_emulator(const emulator_settings& settings, std::unique_ptr<x64_emulator> emu)
windows_emulator::windows_emulator(const emulator_settings& settings, emulator_interfaces interfaces,
std::unique_ptr<x64_emulator> emu)
: emu_(std::move(emu)),
clock_(get_clock(this->executed_instructions_, settings.use_relative_time)),
clock_(get_clock(interfaces, this->executed_instructions_, settings.use_relative_time)),
socket_factory_(get_socket_factory(interfaces)),
emulation_root{settings.emulation_root.empty() ? settings.emulation_root : absolute(settings.emulation_root)},
file_sys(emulation_root.empty() ? emulation_root : emulation_root / "filesys"),
memory(*this->emu_),
@@ -556,6 +574,10 @@ void windows_emulator::deserialize(utils::buffer_deserializer& buffer)
return clock_wrapper{this->clock()}; //
});
buffer.register_factory<socket_factory_wrapper>([this] {
return socket_factory_wrapper{this->socket_factory()}; //
});
buffer.read(this->executed_instructions_);
buffer.read(this->switch_thread_);

View File

@@ -11,6 +11,7 @@
#include "file_system.hpp"
#include "memory_manager.hpp"
#include "module/module_manager.hpp"
#include "network/socket_factory.hpp"
std::unique_ptr<x64_emulator> create_default_x64_emulator();
@@ -45,12 +46,19 @@ struct emulator_settings
std::set<std::string, std::less<>> modules{};
};
struct emulator_interfaces
{
std::unique_ptr<utils::clock> clock{};
std::unique_ptr<network::socket_factory> socket_factory{};
};
class windows_emulator
{
uint64_t executed_instructions_{0};
std::unique_ptr<x64_emulator> emu_{};
std::unique_ptr<utils::clock> clock_{};
std::unique_ptr<network::socket_factory> socket_factory_{};
public:
std::filesystem::path emulation_root{};
@@ -63,10 +71,10 @@ class windows_emulator
process_context process;
syscall_dispatcher dispatcher;
windows_emulator(const emulator_settings& settings = {},
windows_emulator(const emulator_settings& settings = {}, emulator_interfaces interfaces = {},
std::unique_ptr<x64_emulator> emu = create_default_x64_emulator());
windows_emulator(application_settings app_settings, const emulator_settings& settings = {},
emulator_callbacks callbacks = {},
emulator_callbacks callbacks = {}, emulator_interfaces interfaces = {},
std::unique_ptr<x64_emulator> emu = create_default_x64_emulator());
windows_emulator(windows_emulator&&) = delete;
@@ -96,6 +104,16 @@ class windows_emulator
return *this->clock_;
}
network::socket_factory& socket_factory()
{
return *this->socket_factory_;
}
const network::socket_factory& socket_factory() const
{
return *this->socket_factory_;
}
emulator_thread& current_thread() const
{
if (!this->process.active_thread)