diff --git a/src/windows-emulator/context_frame.cpp b/src/windows-emulator/context_frame.cpp index 01bd06d0..e7adf26b 100644 --- a/src/windows-emulator/context_frame.cpp +++ b/src/windows-emulator/context_frame.cpp @@ -5,7 +5,7 @@ namespace context_frame { void restore(x64_emulator& emu, const CONTEXT64& context) { - if (context.ContextFlags & CONTEXT_DEBUG_REGISTERS_64) + if ((context.ContextFlags & CONTEXT_DEBUG_REGISTERS_64) == CONTEXT_DEBUG_REGISTERS_64) { emu.reg(x64_register::dr0, context.Dr0); emu.reg(x64_register::dr1, context.Dr1); @@ -15,7 +15,7 @@ namespace context_frame emu.reg(x64_register::dr7, context.Dr7); } - if (context.ContextFlags & CONTEXT_CONTROL_64) + if ((context.ContextFlags & CONTEXT_CONTROL_64) == CONTEXT_CONTROL_64) { emu.reg(x64_register::ss, context.SegSs); emu.reg(x64_register::cs, context.SegCs); @@ -26,7 +26,7 @@ namespace context_frame emu.reg(x64_register::eflags, context.EFlags); } - if (context.ContextFlags & CONTEXT_INTEGER_64) + if ((context.ContextFlags & CONTEXT_INTEGER_64) == CONTEXT_INTEGER_64) { emu.reg(x64_register::rax, context.Rax); emu.reg(x64_register::rbx, context.Rbx); @@ -45,7 +45,7 @@ namespace context_frame emu.reg(x64_register::r15, context.R15); } - /*if (context.ContextFlags & CONTEXT_SEGMENTS) + /*if ((context.ContextFlags & CONTEXT_SEGMENTS) == CONTEXT_SEGMENTS) { emu.reg(x64_register::ds, context.SegDs); emu.reg(x64_register::es, context.SegEs); @@ -53,7 +53,7 @@ namespace context_frame emu.reg(x64_register::gs, context.SegGs); }*/ - if (context.ContextFlags & CONTEXT_FLOATING_POINT_64) + if ((context.ContextFlags & CONTEXT_FLOATING_POINT_64) == CONTEXT_FLOATING_POINT_64) { emu.reg(x64_register::fpcw, context.FltSave.ControlWord); emu.reg(x64_register::fpsw, context.FltSave.StatusWord); @@ -66,7 +66,7 @@ namespace context_frame } } - if (context.ContextFlags & CONTEXT_XSTATE_64) + if ((context.ContextFlags & CONTEXT_XSTATE_64) == CONTEXT_XSTATE_64) { emu.reg(x64_register::mxcsr, context.MxCsr); @@ -80,7 +80,7 @@ namespace context_frame void save(x64_emulator& emu, CONTEXT64& context) { - if (context.ContextFlags & CONTEXT_DEBUG_REGISTERS_64) + if ((context.ContextFlags & CONTEXT_DEBUG_REGISTERS_64) == CONTEXT_DEBUG_REGISTERS_64) { context.Dr0 = emu.reg(x64_register::dr0); context.Dr1 = emu.reg(x64_register::dr1); @@ -90,7 +90,7 @@ namespace context_frame context.Dr7 = emu.reg(x64_register::dr7); } - if (context.ContextFlags & CONTEXT_CONTROL_64) + if ((context.ContextFlags & CONTEXT_CONTROL_64) == CONTEXT_CONTROL_64) { context.SegSs = emu.reg(x64_register::ss); context.SegCs = emu.reg(x64_register::cs); @@ -99,7 +99,7 @@ namespace context_frame context.EFlags = emu.reg(x64_register::eflags); } - if (context.ContextFlags & CONTEXT_INTEGER_64) + if ((context.ContextFlags & CONTEXT_INTEGER_64) == CONTEXT_INTEGER_64) { context.Rax = emu.reg(x64_register::rax); context.Rbx = emu.reg(x64_register::rbx); @@ -118,7 +118,7 @@ namespace context_frame context.R15 = emu.reg(x64_register::r15); } - if (context.ContextFlags & CONTEXT_SEGMENTS_64) + if ((context.ContextFlags & CONTEXT_SEGMENTS_64) == CONTEXT_SEGMENTS_64) { context.SegDs = emu.reg(x64_register::ds); context.SegEs = emu.reg(x64_register::es); @@ -126,7 +126,7 @@ namespace context_frame context.SegGs = emu.reg(x64_register::gs); } - if (context.ContextFlags & CONTEXT_FLOATING_POINT_64) + if ((context.ContextFlags & CONTEXT_FLOATING_POINT_64) == CONTEXT_FLOATING_POINT_64) { context.FltSave.ControlWord = emu.reg(x64_register::fpcw); context.FltSave.StatusWord = emu.reg(x64_register::fpsw); @@ -138,7 +138,7 @@ namespace context_frame } } - if (context.ContextFlags & CONTEXT_XSTATE_64) + if ((context.ContextFlags & CONTEXT_XSTATE_64) == CONTEXT_INTEGER_64) { context.MxCsr = emu.reg(x64_register::mxcsr); for (int i = 0; i < 16; i++) diff --git a/src/windows-emulator/syscalls.cpp b/src/windows-emulator/syscalls.cpp index 3551638d..01d3b070 100644 --- a/src/windows-emulator/syscalls.cpp +++ b/src/windows-emulator/syscalls.cpp @@ -3543,7 +3543,9 @@ namespace } c.proc.active_thread->save(c.emu); - const auto _ = utils::finally([&] { c.proc.active_thread->restore(c.emu); }); + const auto _ = utils::finally([&] { + c.proc.active_thread->restore(c.emu); // + }); thread->restore(c.emu); @@ -3559,6 +3561,36 @@ namespace return STATUS_SUCCESS; } + NTSTATUS handle_NtSetContextThread(const syscall_context& c, const handle thread_handle, + const emulator_object thread_context) + { + const auto* thread = thread_handle == CURRENT_THREAD ? c.proc.active_thread : c.proc.threads.get(thread_handle); + + if (!thread) + { + return STATUS_INVALID_HANDLE; + } + + const auto needs_swich = thread != c.proc.active_thread; + + if (needs_swich) + { + c.proc.active_thread->save(c.emu); + thread->restore(c.emu); + } + + const auto _ = utils::finally([&] { + if (needs_swich) + { + c.proc.active_thread->restore(c.emu); // + } + }); + + const auto context = thread_context.read(); + context_frame::restore(c.emu, context); + return STATUS_SUCCESS; + } + NTSTATUS handle_NtYieldExecution(const syscall_context& c) { c.win_emu.yield_thread(); @@ -3689,6 +3721,7 @@ void syscall_dispatcher::add_handlers(std::map& ha add_handler(NtUserGetCursorPos); add_handler(NtUserReleaseDC); add_handler(NtUserFindExistingCursorIcon); + add_handler(NtSetContextThread); #undef add_handler }