Prepare support for serialization of non-default-constructible objects

This commit is contained in:
momo5502
2024-10-25 18:56:16 +02:00
parent d29e4a811f
commit d782c80f3f
9 changed files with 112 additions and 46 deletions

View File

@@ -7,6 +7,7 @@
#include <cstring>
#include <optional>
#include <functional>
#include <typeindex>
namespace utils
{
@@ -145,7 +146,7 @@ namespace utils
template <typename T>
T read()
{
T object{};
auto object = this->construct_object<T>();
this->read(object);
return object;
}
@@ -165,11 +166,11 @@ namespace utils
template <typename T, typename F>
requires(std::is_invocable_r_v<T, F>)
void read_optional(std::optional<T>& val, const F& constructor)
void read_optional(std::optional<T>& val, const F& factory)
{
if (this->read<bool>())
{
val.emplace(constructor());
val.emplace(factory());
this->read<T>(*val);
}
else
@@ -214,7 +215,7 @@ namespace utils
auto key = this->read<key_type>();
auto value = this->read<value_type>();
map[std::move(key)] = std::move(value);
map.emplace(std::move(key), std::move(value));
}
}
@@ -263,9 +264,41 @@ namespace utils
return this->offset_;
}
template <typename T, typename F>
requires(std::is_invocable_r_v<T, F>)
void register_factory(F factory)
{
this->factories_[std::type_index(typeid(T))] = [f = std::move(factory)]() -> T* {
return new T(f());
};
}
private:
size_t offset_{0};
std::span<const std::byte> buffer_{};
std::unordered_map<std::type_index, std::function<void*()>> factories_{};
template <typename T>
T construct_object()
{
if constexpr (std::is_default_constructible_v<T>)
{
return {};
}
const auto factory = this->factories_.find(std::type_index(typeid(T)));
if (factory == this->factories_.end())
{
throw std::runtime_error(
"Object construction failed. Missing factory for type: " + std::string(typeid(T).name()));
}
auto* object = static_cast<T*>(factory->second());
auto obj = std::move(*object);
delete object;
return obj;
}
};
class buffer_serializer

View File

@@ -8,7 +8,7 @@ namespace test
emu.logger.disable_output(true);
emu.start();
assert_terminated_successfully(emu);
ASSER_TERMINATED_SUCCESSFULLY(emu);
}
TEST(EmulationTest, CountedEmulationWorks)

View File

@@ -3,16 +3,11 @@
#include <gtest/gtest.h>
#include <windows_emulator.hpp>
namespace test
{
inline void assert_terminated_with_status(const windows_emulator& win_emu, const NTSTATUS status)
{
ASSERT_TRUE(win_emu.process().exit_status.has_value());
ASSERT_EQ(*win_emu.process().exit_status, status);
}
#define ASSER_TERMINATED_WITH_STATUS(win_emu, status) \
do { \
ASSERT_TRUE(win_emu.process().exit_status.has_value()); \
ASSERT_EQ(*win_emu.process().exit_status, status); \
} while(false)
inline void assert_terminated_successfully(const windows_emulator& win_emu)
{
assert_terminated_with_status(win_emu, STATUS_SUCCESS);
}
}
#define ASSER_TERMINATED_SUCCESSFULLY(win_emu) \
ASSER_TERMINATED_WITH_STATUS(win_emu, STATUS_SUCCESS)

View File

@@ -2,13 +2,13 @@
namespace test
{
TEST(SerializationTest, DISABLED_SerializedDataIsReproducible)
TEST(SerializationTest, SerializedDataIsReproducible)
{
windows_emulator emu1{ "./test-sample.exe" };
emu1.logger.disable_output(true);
emu1.start();
assert_terminated_successfully(emu1);
ASSER_TERMINATED_SUCCESSFULLY(emu1);
utils::buffer_serializer serializer1{};
emu1.serialize(serializer1);
@@ -30,7 +30,7 @@ namespace test
emu1.logger.disable_output(true);
emu1.start();
assert_terminated_successfully(emu1);
ASSER_TERMINATED_SUCCESSFULLY(emu1);
utils::buffer_serializer serializer1{};
emu1.serialize(serializer1);
@@ -39,7 +39,7 @@ namespace test
emu2.logger.disable_output(true);
emu2.start();
assert_terminated_successfully(emu2);
ASSER_TERMINATED_SUCCESSFULLY(emu2);
utils::buffer_serializer serializer2{};
emu2.serialize(serializer2);
@@ -63,10 +63,10 @@ namespace test
new_emu.deserialize(deserializer);
new_emu.start();
assert_terminated_successfully(new_emu);
ASSER_TERMINATED_SUCCESSFULLY(new_emu);
emu.start();
assert_terminated_successfully(emu);
ASSER_TERMINATED_SUCCESSFULLY(emu);
utils::buffer_serializer serializer1{};
utils::buffer_serializer serializer2{};

View File

@@ -7,7 +7,7 @@ class emulator_object
public:
using value_type = T;
emulator_object() = default;
//emulator_object() = default;
emulator_object(emulator& emu, const uint64_t address = 0)
: emu_(&emu)
@@ -25,7 +25,7 @@ public:
return this->address_;
}
uint64_t size() const
constexpr uint64_t size() const
{
return sizeof(T);
}
@@ -174,6 +174,11 @@ public:
return this->active_address_;
}
emulator& get_emulator() const
{
return *this->emu_;
}
void serialize(utils::buffer_serializer& buffer) const
{
buffer.write(this->address_);

View File

@@ -105,7 +105,7 @@ public:
handle store(T value)
{
auto index = this->find_free_index();
this->store_[index] = std::move(value);
this->store_.emplace(index, std::move(value));
return make_handle(index);
}

View File

@@ -192,7 +192,10 @@ private:
class emulator_thread : ref_counted_object
{
public:
emulator_thread() = default;
emulator_thread(x64_emulator& emu)
: emu_ptr(&emu)
{
}
emulator_thread(x64_emulator& emu, const process_context& context, uint64_t start_address, uint64_t argument,
uint64_t stack_size, uint32_t id);
@@ -205,20 +208,7 @@ public:
~emulator_thread()
{
if (marker.was_moved())
{
return;
}
if (this->stack_base)
{
this->emu_ptr->release_memory(this->stack_base, this->stack_size);
}
if (this->gs_segment)
{
this->gs_segment->release();
}
this->release();
}
moved_marker marker{};
@@ -285,6 +275,11 @@ public:
void serialize(utils::buffer_serializer& buffer) const
{
if (this->marker.was_moved())
{
throw std::runtime_error("Object was moved!");
}
buffer.write(this->stack_base);
buffer.write(this->stack_size);
buffer.write(this->start_address);
@@ -310,6 +305,13 @@ public:
void deserialize(utils::buffer_deserializer& buffer)
{
if (this->marker.was_moved())
{
throw std::runtime_error("Object was moved!");
}
this->release();
buffer.read(this->stack_base);
buffer.read(this->stack_size);
buffer.read(this->start_address);
@@ -335,6 +337,31 @@ public:
private:
void setup_registers(x64_emulator& emu, const process_context& context) const;
void release()
{
if (this->marker.was_moved())
{
return;
}
if (this->stack_base)
{
if (!this->emu_ptr)
{
throw std::runtime_error("Emulator was never assigned!");
}
this->emu_ptr->release_memory(this->stack_base, this->stack_size);
this->stack_base = 0;
}
if (this->gs_segment)
{
this->gs_segment->release();
this->gs_segment = {};
}
}
};
struct process_context
@@ -454,6 +481,7 @@ struct process_context
buffer.read_vector(this->default_register_set);
buffer.read(this->current_thread_id);
buffer.read(this->threads);
this->active_thread = this->threads.get(buffer.read<uint64_t>());

View File

@@ -1848,7 +1848,7 @@ namespace
const ULONG share_access,
const ULONG open_options)
{
return handle_NtCreateFile(c, file_handle, desired_access, object_attributes, io_status_block, {}, 0,
return handle_NtCreateFile(c, file_handle, desired_access, object_attributes, io_status_block, {c.emu}, 0,
share_access, FILE_OPEN, open_options, 0, 0);
}

View File

@@ -279,9 +279,9 @@ namespace
if (record.ExceptionRecord)
{
record_mapping[&record] = record_obj;
record_mapping.emplace(&record, record_obj);
emulator_object<EXCEPTION_RECORD> nested_record_obj{};
emulator_object<EXCEPTION_RECORD> nested_record_obj{allocator.get_emulator()};
const auto nested_record = record_mapping.find(record.ExceptionRecord);
if (nested_record != record_mapping.end())
@@ -584,7 +584,7 @@ void emulator_thread::mark_as_ready(const NTSTATUS status)
bool emulator_thread::is_thread_ready(process_context& context)
{
if(this->exit_status.has_value())
if (this->exit_status.has_value())
{
return false;
}
@@ -903,7 +903,7 @@ void windows_emulator::start(std::chrono::nanoseconds timeout, size_t count)
const auto current_instructions = this->process().executed_instructions;
const auto diff = current_instructions - start_instructions;
if(diff >= count)
if (diff >= count)
{
break;
}
@@ -923,6 +923,11 @@ void windows_emulator::serialize(utils::buffer_serializer& buffer) const
void windows_emulator::deserialize(utils::buffer_deserializer& buffer)
{
buffer.register_factory<emulator_thread>([this]
{
return emulator_thread(this->emu());
});
this->emu().deserialize(buffer);
this->process_.deserialize(buffer);
this->dispatcher_.deserialize(buffer);