Add support for user callbacks

This commit is contained in:
Igor Pissolati
2026-01-03 20:26:31 -03:00
parent 7c912146fb
commit 9fdc2a4ce6
13 changed files with 413 additions and 18 deletions

View File

@@ -167,8 +167,15 @@ namespace
const auto var_ptr = get_function_argument(win_emu.emu(), index);
if (var_ptr && !is_int_resource(var_ptr))
{
const auto str = read_string<CharType>(win_emu.memory, var_ptr);
print_string(win_emu.log, str);
try
{
const auto str = read_string<CharType>(win_emu.memory, var_ptr);
print_string(win_emu.log, str);
}
catch (...)
{
print_string(win_emu.log, "[failed to read]");
}
}
}

View File

@@ -59,6 +59,12 @@ class typed_emulator : public emulator
return result;
}
void write_stack(const size_t index, const pointer_type& value)
{
const auto sp = this->read_stack_pointer();
this->write_memory(sp + (index * pointer_size), &value, sizeof(value));
}
void push_stack(const pointer_type& value)
{
const auto sp = this->read_stack_pointer() - pointer_size;

View File

@@ -509,6 +509,26 @@ namespace
return true;
}
bool validate_primary_monitor(MONITORINFOEXA& mi)
{
if (std::string_view(mi.szDevice) != R"(\\.\DISPLAY1)")
{
return false;
}
if (mi.rcMonitor.left != 0 || mi.rcMonitor.top != 0 || mi.rcMonitor.right != 1920 || mi.rcMonitor.bottom != 1080)
{
return false;
}
if (!(mi.dwFlags & MONITORINFOF_PRIMARY))
{
return false;
}
return true;
}
bool test_monitor_info()
{
const POINT pt = {0, 0};
@@ -526,22 +546,35 @@ namespace
return false;
}
if (std::string_view(mi.szDevice) != R"(\\.\DISPLAY1)")
return validate_primary_monitor(mi);
}
BOOL CALLBACK monitor_enum_proc(HMONITOR hMonitor, HDC, LPRECT, LPARAM dwData)
{
auto* valid = reinterpret_cast<bool*>(dwData);
MONITORINFOEXA mi;
mi.cbSize = sizeof(mi);
if (!GetMonitorInfoA(hMonitor, &mi))
{
return FALSE;
}
*valid = validate_primary_monitor(mi);
return *valid ? TRUE : FALSE;
}
bool test_user_callback()
{
bool valid = false;
if (!EnumDisplayMonitors(nullptr, nullptr, monitor_enum_proc, reinterpret_cast<LPARAM>(&valid)))
{
return false;
}
if (mi.rcMonitor.left != 0 || mi.rcMonitor.top != 0 || mi.rcMonitor.right != 1920 || mi.rcMonitor.bottom != 1080)
{
return false;
}
if (!(mi.dwFlags & MONITORINFOF_PRIMARY))
{
return false;
}
return true;
return valid;
}
bool test_time_zone()
@@ -957,9 +990,9 @@ int main(const int argc, const char* argv[])
RUN_TEST(test_interrupts, "Interrupts")
}
RUN_TEST(test_tls, "TLS")
RUN_TEST(test_socket, "Socket")
RUN_TEST(test_apc, "APC")
RUN_TEST(test_user_callback, "User Callback")
return valid ? 0 : 1;
}

View File

@@ -35,6 +35,48 @@ struct pending_apc
}
};
enum class callback_id : uint32_t
{
Invalid = 0,
NtUserEnumDisplayMonitors,
};
struct callback_frame
{
callback_id handler_id;
uint64_t rip;
uint64_t rsp;
uint64_t r10;
uint64_t rcx;
uint64_t rdx;
uint64_t r8;
uint64_t r9;
void serialize(utils::buffer_serializer& buffer) const
{
buffer.write(this->handler_id);
buffer.write(this->rip);
buffer.write(this->rsp);
buffer.write(this->r10);
buffer.write(this->rcx);
buffer.write(this->rdx);
buffer.write(this->r8);
buffer.write(this->r9);
}
void deserialize(utils::buffer_deserializer& buffer)
{
buffer.read(this->handler_id);
buffer.read(this->rip);
buffer.read(this->rsp);
buffer.read(this->r10);
buffer.read(this->rcx);
buffer.read(this->rdx);
buffer.read(this->r8);
buffer.read(this->r9);
}
};
class emulator_thread : public ref_counted_object
{
public:
@@ -105,6 +147,8 @@ class emulator_thread : public ref_counted_object
bool debugger_hide{false};
std::vector<callback_frame> callback_stack;
void mark_as_ready(NTSTATUS status);
bool is_await_time_over(utils::clock& clock) const
@@ -188,6 +232,8 @@ class emulator_thread : public ref_counted_object
buffer.write_vector(this->last_registers);
buffer.write(this->debugger_hide);
buffer.write_vector(this->callback_stack);
}
void deserialize_object(utils::buffer_deserializer& buffer) override
@@ -236,6 +282,8 @@ class emulator_thread : public ref_counted_object
buffer.read_vector(this->last_registers);
buffer.read(this->debugger_hide);
buffer.read_vector(this->callback_stack);
}
void leak_memory()

View File

@@ -440,3 +440,48 @@ inline uint64_t get_function_argument(x86_64_emulator& emu, const size_t index,
return emu.read_stack(index + 1);
}
}
inline void set_function_argument(x86_64_emulator& emu, const size_t index, const uint64_t value, const bool is_syscall = false)
{
bool use_32bit_stack = false;
if (!is_syscall)
{
const auto cs_selector = emu.reg<uint16_t>(x86_register::cs);
const auto bitness = segment_utils::get_segment_bitness(emu, cs_selector);
use_32bit_stack = bitness && *bitness == segment_utils::segment_bitness::bit32;
}
if (use_32bit_stack)
{
const auto esp = emu.reg<uint32_t>(x86_register::esp);
const auto address = static_cast<uint64_t>(esp) + static_cast<uint64_t>((index + 1) * sizeof(uint32_t));
emu.write_memory<uint32_t>(address, static_cast<uint32_t>(value));
return;
}
switch (index)
{
case 0:
emu.reg(is_syscall ? x86_register::r10 : x86_register::rcx, value);
break;
case 1:
emu.reg(x86_register::rdx, value);
break;
case 2:
emu.reg(x86_register::r8, value);
break;
case 3:
emu.reg(x86_register::r9, value);
break;
default:
emu.write_stack(index + 1, value);
break;
}
}
constexpr size_t aligned_stack_space(const size_t arg_count)
{
const size_t slots = (arg_count < 4) ? 4 : arg_count;
const size_t bytes = slots * sizeof(uint64_t);
return (bytes + 15) & ~15;
}

View File

@@ -449,6 +449,64 @@ void process_context::setup(x86_64_emulator& emu, memory_manager& memory, regist
});
}
void process_context::setup_callback_hook(windows_emulator& win_emu, memory_manager& memory)
{
uint64_t sentinel_addr = this->callback_sentinel_addr;
if (!sentinel_addr)
{
auto sentinel_obj = this->base_allocator.reserve_page_aligned<uint8_t>();
sentinel_addr = sentinel_obj.value();
this->callback_sentinel_addr = sentinel_addr;
const uint8_t ret_opcode = 0xC3;
win_emu.emu().write_memory(sentinel_addr, &ret_opcode, 1);
const auto sentinel_aligned_length = page_align_up(sentinel_addr + 1) - sentinel_addr;
memory.protect_memory(sentinel_addr, static_cast<size_t>(sentinel_aligned_length), memory_permission::all);
}
auto& emu = win_emu.emu();
emu.hook_memory_execution(sentinel_addr, [&](uint64_t) {
auto* t = this->active_thread;
if (!t || t->callback_stack.empty())
{
return;
}
const auto frame = t->callback_stack.back();
t->callback_stack.pop_back();
const auto callbacks_before = t->callback_stack.size();
const uint64_t guest_result = emu.reg(x86_register::rax);
emu.reg(x86_register::rip, frame.rip);
emu.reg(x86_register::rsp, frame.rsp);
emu.reg(x86_register::r10, frame.r10);
emu.reg(x86_register::rcx, frame.rcx);
emu.reg(x86_register::rdx, frame.rdx);
emu.reg(x86_register::r8, frame.r8);
emu.reg(x86_register::r9, frame.r9);
win_emu.dispatcher.dispatch_completion(win_emu, frame.handler_id, guest_result);
uint64_t target_rip = emu.reg(x86_register::rip);
emu.reg(x86_register::rip, this->callback_sentinel_addr);
const bool new_callback_dispatched = t->callback_stack.size() > callbacks_before;
if (!new_callback_dispatched)
{
// Move past the syscall instruction
target_rip += 2;
}
const uint64_t ret_stack_ptr = frame.rsp - sizeof(emulator_pointer);
emu.write_memory(ret_stack_ptr, &target_rip, sizeof(target_rip));
emu.reg(x86_register::rsp, ret_stack_ptr);
});
}
void process_context::serialize(utils::buffer_serializer& buffer) const
{
buffer.write(this->shared_section_address);
@@ -496,6 +554,8 @@ void process_context::serialize(utils::buffer_serializer& buffer) const
buffer.write(this->threads);
buffer.write(this->threads.find_handle(this->active_thread).bits);
buffer.write(this->callback_sentinel_addr);
}
void process_context::deserialize(utils::buffer_deserializer& buffer)
@@ -551,6 +611,8 @@ void process_context::deserialize(utils::buffer_deserializer& buffer)
buffer.read(this->threads);
this->active_thread = this->threads.get(buffer.read<uint64_t>());
buffer.read(this->callback_sentinel_addr);
}
generic_handle_store* process_context::get_handle_store(const handle handle)

View File

@@ -76,6 +76,8 @@ struct process_context
const mapped_module& executable, const mapped_module& ntdll, const apiset::container& apiset_container,
const mapped_module* ntdll32 = nullptr);
void setup_callback_hook(windows_emulator& win_emu, memory_manager& memory);
handle create_thread(memory_manager& memory, uint64_t start_address, uint64_t argument, uint64_t stack_size, uint32_t create_flags,
bool initial_thread = false);
@@ -147,6 +149,8 @@ struct process_context
handle_store<handle_types::thread, emulator_thread> threads{};
emulator_thread* active_thread{nullptr};
emulator_pointer callback_sentinel_addr{0};
// Extended parameters from last NtMapViewOfSectionEx call
// These can be used by other syscalls like NtAllocateVirtualMemoryEx
uint64_t last_extended_params_numa_node{0};

View File

@@ -22,6 +22,7 @@ void syscall_dispatcher::deserialize(utils::buffer_deserializer& buffer)
{
buffer.read_map(this->handlers_);
this->add_handlers();
this->add_callbacks();
}
void syscall_dispatcher::setup(const exported_symbols& ntdll_exports, const std::span<const std::byte> ntdll_data,
@@ -36,6 +37,7 @@ void syscall_dispatcher::setup(const exported_symbols& ntdll_exports, const std:
map_syscalls(this->handlers_, win32u_syscalls);
this->add_handlers();
this->add_callbacks();
}
void syscall_dispatcher::add_handlers()
@@ -140,6 +142,41 @@ void syscall_dispatcher::dispatch_callback(windows_emulator& win_emu, std::strin
}
}
void syscall_dispatcher::dispatch_completion(windows_emulator& win_emu, callback_id callback_id, uint64_t guest_result)
{
auto& emu = win_emu.emu();
const syscall_context c{
.win_emu = win_emu,
.emu = emu,
.proc = win_emu.process,
.write_status = true,
};
const auto entry = this->callbacks_.find(callback_id);
if (entry == this->callbacks_.end())
{
win_emu.log.error("Unknown callback: 0x%X\n", static_cast<uint32_t>(callback_id));
c.emu.stop();
return;
}
try
{
entry->second(c, guest_result);
}
catch (std::exception& e)
{
win_emu.log.error("Callback 0x%X threw an exception - %s\n", static_cast<int>(callback_id), e.what());
emu.stop();
}
catch (...)
{
win_emu.log.error("Callback 0x%X threw an unknown exception\n", static_cast<int>(callback_id));
emu.stop();
}
}
syscall_dispatcher::syscall_dispatcher(const exported_symbols& ntdll_exports, const std::span<const std::byte> ntdll_data,
const exported_symbols& win32u_exports, const std::span<const std::byte> win32u_data)
{

View File

@@ -4,6 +4,7 @@
struct syscall_context;
using syscall_handler = void (*)(const syscall_context& c);
using callback_completion_handler = void (*)(const syscall_context& c, uint64_t guest_result);
struct syscall_handler_entry
{
@@ -22,6 +23,7 @@ class syscall_dispatcher
void dispatch(windows_emulator& win_emu);
static void dispatch_callback(windows_emulator& win_emu, std::string& syscall_name);
void dispatch_completion(windows_emulator& win_emu, callback_id callback_id, uint64_t guest_result);
void serialize(utils::buffer_serializer& buffer) const;
void deserialize(utils::buffer_deserializer& buffer);
@@ -36,7 +38,9 @@ class syscall_dispatcher
private:
std::map<uint64_t, syscall_handler_entry> handlers_{};
std::map<callback_id, callback_completion_handler> callbacks_{};
static void add_handlers(std::map<std::string, syscall_handler>& handler_mapping);
void add_handlers();
void add_callbacks();
};

View File

@@ -12,6 +12,7 @@ struct syscall_context
process_context& proc;
mutable bool write_status{true};
mutable bool retrigger_syscall{false};
mutable bool run_callback{false};
};
inline uint64_t get_syscall_argument(x86_64_emulator& emu, const size_t index)
@@ -132,15 +133,16 @@ T resolve_indexed_argument(x86_64_emulator& emu, size_t& index)
return resolve_argument<T>(emu, index++);
}
inline void write_syscall_result(const syscall_context& c, const uint64_t result, const uint64_t initial_ip)
inline void write_syscall_result(const syscall_context& c, const uint64_t result, const uint64_t initial_ip,
const bool is_callback_completion = false)
{
if (c.write_status && !c.retrigger_syscall)
if (c.write_status && !c.retrigger_syscall && !c.run_callback)
{
c.emu.reg<uint64_t>(x86_register::rax, result);
}
const auto new_ip = c.emu.read_instruction_pointer();
if (initial_ip != new_ip || c.retrigger_syscall)
if ((initial_ip != new_ip || c.retrigger_syscall || c.run_callback) && !is_callback_completion)
{
c.emu.reg(x86_register::rip, new_ip - 2);
}
@@ -176,6 +178,36 @@ syscall_handler make_syscall_handler()
return +[](const syscall_context& c) { forward_syscall(c, Handler); };
}
template <typename Result>
void forward_callback_completion(const syscall_context& c, const uint64_t /*guest_result*/, Result (*handler)())
{
const auto ip = c.emu.read_instruction_pointer();
const auto ret = handler();
write_syscall_result(c, static_cast<uint64_t>(ret), ip, true);
}
template <typename Result, typename GuestResult, typename... Args>
void forward_callback_completion(const syscall_context& c, const uint64_t guest_result,
Result (*handler)(const syscall_context&, GuestResult, Args...))
{
const auto ip = c.emu.read_instruction_pointer();
size_t index = 0;
std::tuple<const syscall_context&, GuestResult, Args...> func_args{
c, static_cast<GuestResult>(guest_result),
resolve_indexed_argument<std::remove_cv_t<std::remove_reference_t<Args>>>(c.emu, index)...};
const auto ret = std::apply(handler, std::move(func_args));
write_syscall_result(c, ret, ip, true);
}
template <auto Handler>
callback_completion_handler make_callback_completion_handler()
{
return +[](const syscall_context& c, const uint64_t guest_result) { forward_callback_completion(c, guest_result, Handler); };
}
template <typename T, typename Traits>
void write_attribute(emulator& emu, const PS_ATTRIBUTE<Traits>& attribute, const T& value)
{

View File

@@ -3,6 +3,7 @@
#include "cpu_context.hpp"
#include "emulator_utils.hpp"
#include "syscall_utils.hpp"
#include "user_callback_dispatch.hpp"
#include <numeric>
#include <cwctype>
@@ -1007,6 +1008,47 @@ namespace syscalls
return STATUS_UNSUCCESSFUL;
}
BOOL handle_NtUserEnumDisplayMonitors(const syscall_context& c, const hdc hdc_in, const uint64_t clip_rect_ptr, const uint64_t callback,
const uint64_t param)
{
if (!callback)
{
return FALSE;
}
const auto hmon = c.win_emu.process.default_monitor_handle.bits;
const auto display_info = c.proc.user_handles.get_display_info().read();
if (clip_rect_ptr)
{
RECT clip{};
c.emu.read_memory(clip_rect_ptr, &clip, sizeof(clip));
const emulator_object<USER_MONITOR> monitor_obj(c.emu, display_info.pPrimaryMonitor);
const auto monitor = monitor_obj.read();
auto effective_rc{monitor.rcMonitor};
effective_rc.left = std::max(effective_rc.left, clip.left);
effective_rc.top = std::max(effective_rc.top, clip.top);
effective_rc.right = std::min(effective_rc.right, clip.right);
effective_rc.bottom = std::min(effective_rc.bottom, clip.bottom);
if (effective_rc.right <= effective_rc.left || effective_rc.bottom <= effective_rc.top)
{
return TRUE;
}
}
const uint64_t rect_ptr = display_info.pPrimaryMonitor + offsetof(USER_MONITOR, rcMonitor);
dispatch_user_callback(c, callback_id::NtUserEnumDisplayMonitors, callback, hmon, hdc_in, rect_ptr, param);
return {};
}
BOOL completion_NtUserEnumDisplayMonitors(const syscall_context&, BOOL guest_result, const hdc /*hdc_in*/,
const uint64_t /*clip_rect_ptr*/, const uint64_t /*callback*/, const uint64_t /*param*/)
{
return guest_result;
}
NTSTATUS handle_NtAssociateWaitCompletionPacket()
{
return STATUS_SUCCESS;
@@ -1049,6 +1091,19 @@ namespace syscalls
return TRUE;
}
emulator_pointer handle_NtUserMapDesktopObject(const syscall_context& c, handle handle)
{
const auto index = handle.value.id;
if (index == 0 || index >= user_handle_table::MAX_HANDLES)
{
return 0;
}
const auto handle_entry = c.proc.user_handles.get_handle_table().read(static_cast<size_t>(index));
return handle_entry.pHead;
}
}
void syscall_dispatcher::add_handlers(std::map<std::string, syscall_handler>& handler_mapping)
@@ -1238,6 +1293,7 @@ void syscall_dispatcher::add_handlers(std::map<std::string, syscall_handler>& ha
add_handler(NtUserGetKeyboardType);
add_handler(NtUserEnumDisplayDevices);
add_handler(NtUserEnumDisplaySettings);
add_handler(NtUserEnumDisplayMonitors);
add_handler(NtUserSetProp);
add_handler(NtUserSetProp2);
add_handler(NtUserChangeWindowMessageFilterEx);
@@ -1266,6 +1322,20 @@ void syscall_dispatcher::add_handlers(std::map<std::string, syscall_handler>& ha
add_handler(NtRemoveProcessDebug);
add_handler(NtNotifyChangeDirectoryFileEx);
add_handler(NtUserGetHDevName);
add_handler(NtUserMapDesktopObject);
#undef add_handler
}
void syscall_dispatcher::add_callbacks()
{
#define add_callback(syscall) \
do \
{ \
this->callbacks_[callback_id::syscall] = make_callback_completion_handler<syscalls::completion_##syscall>(); \
} while (0)
add_callback(NtUserEnumDisplayMonitors);
#undef add_callback
}

View File

@@ -0,0 +1,46 @@
#pragma once
#include "syscall_utils.hpp"
// TODO: Here we are calling guest functions directly, but this is not how it works in the real Windows kernel.
// In the real implementation, the kernel invokes ntdll!KiUserCallbackDispatcher and passes a callback
// index that refers to an entry in PEB->KernelCallbackTable. The dispatcher then looks up the function
// pointer in that table and invokes the corresponding user-mode callback.
template <typename... Args>
void dispatch_user_callback(const syscall_context& c, callback_id completion_id, uint64_t func_address, Args... args)
{
const uint64_t original_rsp = c.emu.read_stack_pointer();
// Save syscall argument registers BEFORE modifying anything
const callback_frame frame{
.handler_id = completion_id,
.rip = c.emu.read_instruction_pointer(),
.rsp = original_rsp,
.r10 = c.emu.reg(x86_register::r10),
.rcx = c.emu.reg(x86_register::rcx),
.rdx = c.emu.reg(x86_register::rdx),
.r8 = c.emu.reg(x86_register::r8),
.r9 = c.emu.reg(x86_register::r9),
};
uint64_t stack_ptr = align_down(original_rsp, 16);
constexpr size_t arg_count = sizeof...(Args);
const size_t allocation_size = aligned_stack_space(arg_count);
stack_ptr -= allocation_size;
// Push the return address onto the stack (Simulating CALL)
stack_ptr -= sizeof(emulator_pointer);
c.emu.write_memory(stack_ptr, &c.proc.callback_sentinel_addr, sizeof(c.proc.callback_sentinel_addr));
c.proc.active_thread->callback_stack.push_back(frame);
c.emu.reg(x86_register::rsp, stack_ptr);
size_t index = 0;
(set_function_argument(c.emu, index++, static_cast<uint64_t>(args)), ...);
c.emu.reg(x86_register::rip, func_address);
c.run_callback = true;
}

View File

@@ -554,6 +554,7 @@ void windows_emulator::start(size_t count)
{
this->should_stop = false;
this->setup_process_if_necessary();
this->process.setup_callback_hook(*this, this->memory);
const auto use_count = count > 0;
const auto start_instructions = this->executed_instructions_;