Fix basic registry support and add test

This commit is contained in:
momo5502
2024-11-04 18:36:18 +01:00
parent 6937827e59
commit 808dca6455
7 changed files with 74 additions and 12 deletions

View File

@@ -4,6 +4,7 @@
#include <fstream>
#include <thread>
#include <atomic>
#include <optional>
#include <vector>
#include <Windows.h>
@@ -92,6 +93,49 @@ bool test_io()
return text == buffer;
}
std::optional<std::string> read_registry_string(const HKEY root, const char* path, const char* value)
{
HKEY key{};
if (RegOpenKeyExA(root, path, 0, KEY_READ, &key) !=
ERROR_SUCCESS)
{
return std::nullopt;
}
char data[MAX_PATH]{};
DWORD length = sizeof(data);
const auto res = RegQueryValueExA(key, value, nullptr, nullptr, reinterpret_cast<uint8_t*>(data), &length);
if (RegCloseKey(key) != ERROR_SUCCESS)
{
return std::nullopt;
}
if (res != ERROR_SUCCESS)
{
return std::nullopt;
}
if (length == 0)
{
return "";
}
return {std::string(data, min(length - 1, sizeof(data)))};
}
bool test_registry()
{
const auto val = read_registry_string(HKEY_LOCAL_MACHINE, R"(SOFTWARE\Microsoft\Windows\CurrentVersion)",
"ProgramFilesDir");
if (!val)
{
return false;
}
return *val == "C:\\Program Files";
}
void throw_exception()
{
if (do_the_task)
@@ -128,7 +172,8 @@ bool test_native_exceptions()
throw_native_exception();
return false;
}
__except(EXCEPTION_EXECUTE_HANDLER) {
__except (EXCEPTION_EXECUTE_HANDLER)
{
return true;
}
}
@@ -146,6 +191,7 @@ int main(int /*argc*/, const char* /*argv*/[])
bool valid = true;
RUN_TEST(test_io, "I/O")
RUN_TEST(test_registry, "Registry")
RUN_TEST(test_threads, "Threads")
RUN_TEST(test_env, "Environment")
RUN_TEST(test_exceptions, "Exceptions")

View File

@@ -95,7 +95,7 @@ namespace handle_detail
};
}
template <handle_types::type Type, typename T>
template <handle_types::type Type, typename T, uint32_t IndexShift = 0>
requires(utils::Serializable<T>)
class handle_store
{
@@ -117,7 +117,7 @@ public:
h.bits = 0;
h.value.is_pseudo = false;
h.value.type = Type;
h.value.id = index;
h.value.id = index << IndexShift;
return h;
}
@@ -287,7 +287,7 @@ private:
return this->store_.end();
}
return this->store_.find(h.id);
return this->store_.find(static_cast<uint32_t>(h.id) >> IndexShift);
}
uint32_t find_free_index()

View File

@@ -394,7 +394,7 @@ struct process_context
handle_store<handle_types::file, file> files{};
handle_store<handle_types::semaphore, semaphore> semaphores{};
handle_store<handle_types::port, port> ports{};
handle_store<handle_types::registry, registry_key> registry_keys{};
handle_store<handle_types::registry, registry_key, 2> registry_keys{};
std::map<uint16_t, std::wstring> atoms{};
std::vector<std::byte> default_register_set{};

View File

@@ -149,7 +149,7 @@ const hive_value* hive_key::get_value(std::ifstream& file, const std::string_vie
auto& value = entry->second;
if (value.parsed)
if (!value.parsed)
{
value.data = read_file_data(file, MAIN_ROOT_OFFSET + value.data_offset, value.data_length);
value.parsed = true;

View File

@@ -7,6 +7,14 @@
namespace
{
void string_to_lower(std::string& str)
{
std::ranges::transform(str, str.begin(), [](const char val)
{
return static_cast<char>(std::tolower(static_cast<unsigned char>(val)));
});
}
std::filesystem::path canonicalize_path(const std::filesystem::path& key)
{
auto path = key.lexically_normal().wstring();
@@ -107,7 +115,7 @@ std::optional<registry_key> registry_manager::get_key(const std::filesystem::pat
{
registry_key reg_key{};
reg_key.hive = normal_key;
return { std::move(reg_key) };
return {std::move(reg_key)};
}
const auto iterator = this->find_hive(normal_key);
@@ -134,8 +142,10 @@ std::optional<registry_key> registry_manager::get_key(const std::filesystem::pat
return {std::move(reg_key)};
}
std::optional<registry_value> registry_manager::get_value(const registry_key& key, const std::string_view name)
std::optional<registry_value> registry_manager::get_value(const registry_key& key, std::string name)
{
string_to_lower(name);
const auto iterator = this->hives_.find(key.hive);
if (iterator == this->hives_.end())
{

View File

@@ -51,7 +51,7 @@ public:
void deserialize(utils::buffer_deserializer& buffer);
std::optional<registry_key> get_key(const std::filesystem::path& key);
std::optional<registry_value> get_value(const registry_key& key, const std::string_view name);
std::optional<registry_value> get_value(const registry_key& key, std::string name);
private:
std::filesystem::path hive_path_{};

View File

@@ -183,7 +183,7 @@ namespace
if (key_value_information_class == KeyValueBasicInformation)
{
const auto required_size = sizeof(KEY_VALUE_BASIC_INFORMATION) + (original_name.size() * 2) - 1;
const auto required_size = offsetof(KEY_VALUE_BASIC_INFORMATION, Name) + (original_name.size() * 2) - 1;
result_length.write(static_cast<ULONG>(required_size));
if (required_size > length)
@@ -208,7 +208,7 @@ namespace
if (key_value_information_class == KeyValuePartialInformation)
{
const auto required_size = sizeof(KEY_VALUE_PARTIAL_INFORMATION) + value->data.size() - 1;
const auto required_size = offsetof(KEY_VALUE_PARTIAL_INFORMATION, Data) + value->data.size();
result_length.write(static_cast<ULONG>(required_size));
if (required_size > length)
@@ -235,7 +235,7 @@ namespace
{
const auto name_size = original_name.size() * 2;
const auto value_size = value->data.size();
const auto required_size = sizeof(KEY_VALUE_FULL_INFORMATION) + name_size + value_size + -1;
const auto required_size = offsetof(KEY_VALUE_FULL_INFORMATION, Name) + name_size + value_size + -1;
result_length.write(static_cast<ULONG>(required_size));
if (required_size > length)
@@ -268,6 +268,11 @@ namespace
return STATUS_NOT_SUPPORTED;
}
NTSTATUS handle_NtCreateKey()
{
return STATUS_NOT_SUPPORTED;
}
NTSTATUS handle_NtSetInformationThread(const syscall_context& c, const uint64_t thread_handle,
const THREADINFOCLASS info_class,
const uint64_t thread_information,
@@ -2537,6 +2542,7 @@ void syscall_dispatcher::add_handlers(std::map<std::string, syscall_handler>& ha
add_handler(NtQueryKey);
add_handler(NtGetNlsSectionPtr);
add_handler(NtAccessCheck);
add_handler(NtCreateKey);
#undef add_handler
}