diff --git a/.claude/settings.local.json b/.claude/settings.local.json index d832524..5d4c767 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -9,7 +9,11 @@ "Bash(openspec instructions:*)", "Bash(./build/bin/runTests.exe --gtest_filter=\"KernelPerf*\")", "Bash(powershell -Command \"Start-Process cmake -ArgumentList ''--build'',''build'',''--target'',''install_WinKernelLite'',''--config'',''Debug'' -Verb RunAs -Wait\")", - "Bash(openspec list:*)" + "Bash(openspec list:*)", + "Bash(grep -r \"#include\" /c/projects/winkernellite/include/*.h)", + "Bash(grep -r \"^#ifndef\\\\|^#define\" /c/projects/winkernellite/include/*.h)", + "Bash(grep:*)", + "Bash(./build/bin/runTests.exe)" ] } } diff --git a/include/File.h b/include/File.h index ccb2e28..e0b8cbf 100644 --- a/include/File.h +++ b/include/File.h @@ -145,84 +145,4 @@ NTSTATUS ZwCreateFile( } #endif -// Implementation of user mode file functions - -inline NTSTATUS ZwCreateFile( - OUT PHANDLE FileHandle, - IN ACCESS_MASK DesiredAccess, - IN POBJECT_ATTRIBUTES ObjectAttributes, - OUT PIO_STATUS_BLOCK IoStatusBlock, - IN PLARGE_INTEGER AllocationSize OPTIONAL, - IN ULONG FileAttributes, - IN ULONG ShareAccess, - IN ULONG CreateDisposition, - IN ULONG CreateOptions, - IN PVOID EaBuffer OPTIONAL, - IN ULONG EaLength -) -{ - // Validate parameters - if (!FileHandle || !ObjectAttributes || !ObjectAttributes->ObjectName || !IoStatusBlock) { - return STATUS_INVALID_PARAMETER; - } - - // These parameters are not used in this implementation - UNREFERENCED_PARAMETER(AllocationSize); - UNREFERENCED_PARAMETER(EaBuffer); - UNREFERENCED_PARAMETER(EaLength); - - // Convert NT CreateDisposition to Win32 CreateDisposition - DWORD dwCreationDisposition; - switch (CreateDisposition) { - case FILE_SUPERSEDE: - case FILE_OVERWRITE_IF: - dwCreationDisposition = CREATE_ALWAYS; - break; - case FILE_CREATE: - dwCreationDisposition = CREATE_NEW; - break; - case FILE_OPEN: - dwCreationDisposition = OPEN_EXISTING; - break; - case FILE_OPEN_IF: - dwCreationDisposition = OPEN_ALWAYS; - break; - case FILE_OVERWRITE: - dwCreationDisposition = TRUNCATE_EXISTING; - break; - default: - return STATUS_INVALID_PARAMETER; - } - - // Create file with CreateFileW - *FileHandle = CreateFileW( - ObjectAttributes->ObjectName->Buffer, - DesiredAccess, - ShareAccess, - NULL, // Security attributes not supported - dwCreationDisposition, - FileAttributes | (CreateOptions & 0x00FFFFFF), // Convert relevant options - NULL // Template file not supported - ); - - if (*FileHandle == INVALID_HANDLE_VALUE) { - DWORD error = GetLastError(); - IoStatusBlock->Status = STATUS_UNSUCCESSFUL; - - if (error == ERROR_FILE_NOT_FOUND) { - return STATUS_OBJECT_NAME_NOT_FOUND; - } else if (error == ERROR_ACCESS_DENIED) { - return STATUS_ACCESS_DENIED; - } else { - return STATUS_UNSUCCESSFUL; - } - } - - IoStatusBlock->Status = STATUS_SUCCESS; - IoStatusBlock->Information = FILE_OPENED; // Simplified - real implementation would have more accurate information - return STATUS_SUCCESS; -} - - - #endif // WINKERNEL_FILE_H diff --git a/include/KernelHeap.h b/include/KernelHeap.h index e688244..cee9510 100644 --- a/include/KernelHeap.h +++ b/include/KernelHeap.h @@ -79,33 +79,26 @@ typedef struct _GLOBAL_STATE { BOOL TrackFreedMemory; /* Control whether to track freed memory for double-free detection */ } GLOBAL_STATE; -/* Function declarations */ -__forceinline GLOBAL_STATE* GetGlobalState(void); -__forceinline BOOL InitHeap(void); -__forceinline void CleanupHeap(void); -__forceinline void TrackAllocation(PVOID Address, SIZE_T Size, const char* FileName, int LineNumber); -__forceinline BOOL UntrackAllocation(PVOID Address, const char* FreeFileName, int FreeLineNumber); -__forceinline void TrackFreedMemory(PVOID Address, SIZE_T Size, const char* AllocFileName, int AllocLineNumber, const char* FreeFileName, int FreeLineNumber); -__forceinline void TrackFreedMemoryLocked(PVOID Address, SIZE_T Size, const char* AllocFileName, int AllocLineNumber, const char* FreeFileName, int FreeLineNumber, ULONGLONG AllocationId); -__forceinline BOOL CheckForDoubleFree(PVOID Address, const char* FreeFileName, int FreeLineNumber, ULONGLONG AllocationId); -__forceinline void CleanupOldFreedEntries(void); -__forceinline PVOID ExAllocatePoolWithTracking(POOL_TYPE PoolType, SIZE_T NumberOfBytes, const char* FileName, int LineNumber); -__forceinline PVOID ExAllocatePool(POOL_TYPE PoolType, SIZE_T NumberOfBytes); -__forceinline PVOID ExAllocatePoolWithTag(POOL_TYPE PoolType, SIZE_T NumberOfBytes, ULONG Tag); -__forceinline PVOID ExAllocatePoolZeroWithTag(POOL_TYPE PoolType, SIZE_T NumberOfBytes, ULONG Tag); -__forceinline PVOID _ExAllocatePoolZeroWithTagTracked(POOL_TYPE PoolType, SIZE_T NumberOfBytes, ULONG Tag, const char* FileName, int LineNumber); -inline void _ExFreePoolWithTracking(PVOID pointer, const char* FileName, int LineNumber); -__forceinline void ExFreePool(PVOID pointer); -__forceinline void ExFreePoolWithTag(PVOID pointer, ULONG Tag); -__forceinline void PrintMemoryLeaks(void); -__forceinline void PrintDoubleFreeReport(void); -__forceinline void SetErrorSuppression(BOOL suppress); -__forceinline BOOL GetErrorSuppression(void); -__forceinline void SetFreedMemoryTracking(BOOL enable); -__forceinline BOOL GetFreedMemoryTracking(void); -__forceinline void SetMaxFreedEntries(SIZE_T maxEntries); -__forceinline SIZE_T GetMaxFreedEntries(void); -__forceinline BOOL IsValidHeapPointer(PVOID pointer); +/* Function declarations - implementations in KernelHeap.c */ +GLOBAL_STATE* GetGlobalState(void); +BOOL InitHeap(void); +void CleanupHeap(void); +void TrackAllocation(PVOID Address, SIZE_T Size, const char* FileName, int LineNumber); +BOOL UntrackAllocation(PVOID Address, const char* FreeFileName, int FreeLineNumber); +void TrackFreedMemoryLocked(PVOID Address, SIZE_T Size, const char* AllocFileName, int AllocLineNumber, const char* FreeFileName, int FreeLineNumber, ULONGLONG AllocationId); +BOOL CheckForDoubleFree(PVOID Address, const char* FreeFileName, int FreeLineNumber, ULONGLONG AllocationId); +void CleanupOldFreedEntries(void); +PVOID ExAllocatePoolWithTracking(POOL_TYPE PoolType, SIZE_T NumberOfBytes, const char* FileName, int LineNumber); +void _ExFreePoolWithTracking(PVOID pointer, const char* FileName, int LineNumber); +BOOL IsValidHeapPointer(PVOID pointer); +void PrintMemoryLeaks(void); +void PrintDoubleFreeReport(void); +void SetErrorSuppression(BOOL suppress); +BOOL GetErrorSuppression(void); +void SetFreedMemoryTracking(BOOL enable); +BOOL GetFreedMemoryTracking(void); +void SetMaxFreedEntries(SIZE_T maxEntries); +SIZE_T GetMaxFreedEntries(void); #ifdef __cplusplus } @@ -119,350 +112,19 @@ extern "C" { /* External declaration of global state variable - defined in KernelHeap.c */ extern GLOBAL_STATE* g_WinKernelLite_GlobalState; -/* Global state accessor - C-Compatible version */ -__forceinline GLOBAL_STATE* GetGlobalState(void) { - if (g_WinKernelLite_GlobalState == NULL) { - HEAP_TRACE("GetGlobalState: Creating new global state"); - GLOBAL_STATE* temp = (GLOBAL_STATE*)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(GLOBAL_STATE)); - if (temp != NULL) { - HEAP_VERBOSE("GetGlobalState: Allocated global state at %p (size: %zu)", temp, sizeof(GLOBAL_STATE)); - /* CriticalSection initialization removed for performance */ - temp->HeapHandle = GetProcessHeap(); - HEAP_VERBOSE("GetGlobalState: Using heap handle %p", temp->HeapHandle); - /* Initialize the linked lists immediately when creating global state */ - InitializeListHead(&temp->MemoryAllocations); - InitializeListHead(&temp->FreedMemoryList); - /* Initialize other fields to safe defaults */ - temp->AllocationCount = 0; - temp->TotalBytesAllocated = 0; - temp->CurrentBytesAllocated = 0; - temp->PeakBytesAllocated = 0; - temp->DoubleFreeCount = 0; - temp->FreedEntryCount = 0; - temp->MaxFreedEntries = 1000; /* Default: keep track of last 1000 freed allocations */ - temp->NextAllocationId = 1; /* Start allocation IDs at 1 */ - temp->SuppressErrors = FALSE; - temp->TrackFreedMemory = TRUE; /* Enable double-free tracking by default */ - g_WinKernelLite_GlobalState = temp; - HEAP_INFO("GetGlobalState: Global state initialized successfully"); - } else { - HEAP_ERROR("GetGlobalState: Failed to allocate global state"); - } - } - return g_WinKernelLite_GlobalState; -} - -__forceinline BOOL InitHeap(void) { - GLOBAL_STATE* state = GetGlobalState(); - if (!state || !state->HeapHandle) { - return FALSE; - } - - // Global state is already properly initialized by GetGlobalState() - // This function is now essentially a no-op but kept for compatibility - // Reset statistics if needed (this is safe to call multiple times) - // CriticalSection usage removed for performance - function is now non-thread-safe but faster - - // Only reset if we have no active allocations - if (IsListEmpty(&state->MemoryAllocations)) { - state->AllocationCount = 0; - state->TotalBytesAllocated = 0; - state->CurrentBytesAllocated = 0; - state->PeakBytesAllocated = 0; - state->DoubleFreeCount = 0; - // Don't reset freed memory tracking - it's useful to keep across test runs - state->SuppressErrors = FALSE; - state->TrackFreedMemory = TRUE; - if (state->MaxFreedEntries == 0) { - state->MaxFreedEntries = 1000; - } - } - - return TRUE; -} - -__forceinline void CleanupHeap(void) { - GLOBAL_STATE* state = GetGlobalState(); - if (state) { - // Clean up any remaining tracking entries - PLIST_ENTRY current, next; - PMEMORY_TRACKING_ENTRY entry; - PFREED_MEMORY_ENTRY freedEntry; - - // CriticalSection usage removed for performance - function is now non-thread-safe but faster - - // Clean up allocation tracking entries - current = state->MemoryAllocations.Flink; - while (current != &state->MemoryAllocations) { - next = current->Flink; // Save next before we free current - entry = CONTAINING_RECORD(current, MEMORY_TRACKING_ENTRY, ListEntry); - - // Free the tracking entry (the actual memory leak will remain) - HeapFree(GetProcessHeap(), 0, entry); - - current = next; - } - - // Clean up freed memory tracking entries - current = state->FreedMemoryList.Flink; - while (current != &state->FreedMemoryList) { - next = current->Flink; // Save next before we free current - freedEntry = CONTAINING_RECORD(current, FREED_MEMORY_ENTRY, ListEntry); - - // Free the freed memory tracking entry - HeapFree(GetProcessHeap(), 0, freedEntry); - - current = next; - } - - // CriticalSection cleanup removed for performance - HeapFree(GetProcessHeap(), 0, state); - g_WinKernelLite_GlobalState = NULL; - } -} - -__forceinline void TrackAllocation(PVOID Address, SIZE_T Size, const char* FileName, int LineNumber) { - GLOBAL_STATE* state; - PMEMORY_TRACKING_ENTRY entry; - - state = GetGlobalState(); - if (!state) { - return; - } - - // Allocate a new tracking entry using the system heap (not our tracked heap) - entry = (PMEMORY_TRACKING_ENTRY)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(MEMORY_TRACKING_ENTRY)); - if (!entry) { - if (!state->SuppressErrors) { - HEAP_ERROR("Failed to allocate memory tracking entry"); - } - return; - } - - // Initialize the tracking entry - entry->Address = Address; - entry->Size = Size; - entry->FileName = FileName; - entry->LineNumber = LineNumber; - entry->AllocationId = state->NextAllocationId++; // Assign unique ID - - // Fast path - no critical section overhead for performance - // This function is now non-thread-safe but much faster - - // Insert at the head of the list for O(1) insertion - InsertHeadList(&state->MemoryAllocations, &entry->ListEntry); - - state->AllocationCount++; - state->TotalBytesAllocated += Size; - state->CurrentBytesAllocated += Size; - if (state->CurrentBytesAllocated > state->PeakBytesAllocated) - state->PeakBytesAllocated = state->CurrentBytesAllocated; -} - -__forceinline void TrackFreedMemoryLocked(PVOID Address, SIZE_T Size, const char* AllocFileName, int AllocLineNumber, const char* FreeFileName, int FreeLineNumber, ULONGLONG AllocationId) { - GLOBAL_STATE* state = GetGlobalState(); - if (!state || !state->TrackFreedMemory) { - return; - } - - // Fast path - no critical section overhead for performance - // This function is now non-thread-safe but much faster - - // Allocate a new freed memory tracking entry - PFREED_MEMORY_ENTRY entry = (PFREED_MEMORY_ENTRY)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(FREED_MEMORY_ENTRY)); - - if (!entry) { - if (!state->SuppressErrors) { - HEAP_ERROR("Failed to allocate freed memory tracking entry"); - } - return; - } - - // Initialize the freed memory tracking entry - entry->Address = Address; - entry->Size = Size; - entry->AllocFileName = AllocFileName; - entry->AllocLineNumber = AllocLineNumber; - entry->FreeFileName = FreeFileName; - entry->FreeLineNumber = FreeLineNumber; - entry->ThreadId = GetCurrentThreadId(); - entry->AllocationId = AllocationId; // Store the allocation ID - GetSystemTimeAsFileTime(&entry->FreeTime); - - // Insert at the head of the freed memory list (non-thread-safe) - InsertHeadList(&state->FreedMemoryList, &entry->ListEntry); - state->FreedEntryCount++; - - // Clean up old entries if we've exceeded the maximum (non-thread-safe) - if (state->FreedEntryCount > state->MaxFreedEntries) { - CleanupOldFreedEntries(); - } -} - -__forceinline BOOL CheckForDoubleFree(PVOID Address, const char* FreeFileName, int FreeLineNumber, ULONGLONG AllocationId) { - GLOBAL_STATE* state; - PLIST_ENTRY current; - PFREED_MEMORY_ENTRY entry; - BOOL found = FALSE; - - if (!Address) return FALSE; // Address is required - - state = GetGlobalState(); - if (!state || !state->TrackFreedMemory) return FALSE; - - // Fast path - no critical section overhead for performance - // This function is now non-thread-safe but much faster - - // Walk the freed memory list looking for the exact address AND allocation ID match - current = state->FreedMemoryList.Flink; - while (current != &state->FreedMemoryList) { - entry = CONTAINING_RECORD(current, FREED_MEMORY_ENTRY, ListEntry); - - // Check both address AND allocation ID for exact match - if (entry->Address == Address && entry->AllocationId == AllocationId) { - found = TRUE; - state->DoubleFreeCount++; - - if (!state->SuppressErrors) { - HEAP_ERROR("=== DOUBLE-FREE DETECTED ==="); - HEAP_ERROR("Address: %p (Size: %zu bytes, Allocation ID: %llu)", Address, entry->Size, AllocationId); - HEAP_ERROR("Originally allocated at: %s:%d", entry->AllocFileName, entry->AllocLineNumber); - HEAP_ERROR("First freed at: %s:%d (Thread: %lu)", - entry->FreeFileName, entry->FreeLineNumber, entry->ThreadId); - HEAP_ERROR("Attempted second free at: %s:%d (Thread: %lu)", - FreeFileName, FreeLineNumber, GetCurrentThreadId()); - HEAP_ERROR("============================"); - } - break; - } - - current = current->Flink; - } - - return found; -} - -__forceinline void CleanupOldFreedEntries(void) { - GLOBAL_STATE* state = GetGlobalState(); - if (!state) return; - - // This function should be called while holding the MemoryTrackingLock - // Remove old entries from the tail of the list until we're under the limit - while (state->FreedEntryCount > state->MaxFreedEntries && !IsListEmpty(&state->FreedMemoryList)) { - PLIST_ENTRY lastEntry = state->FreedMemoryList.Blink; - PFREED_MEMORY_ENTRY entry = CONTAINING_RECORD(lastEntry, FREED_MEMORY_ENTRY, ListEntry); - - RemoveEntryList(lastEntry); - HeapFree(GetProcessHeap(), 0, entry); - state->FreedEntryCount--; - } -} - -__forceinline BOOL UntrackAllocation(PVOID Address, const char* FreeFileName, int FreeLineNumber) { - GLOBAL_STATE* state; - PLIST_ENTRY current; - PMEMORY_TRACKING_ENTRY entry; - BOOL found = FALSE; - const char* allocFileName = "Unknown"; - int allocLineNumber = 0; - SIZE_T allocSize = 0; - ULONGLONG allocationId = 0; - - if (!Address) { - return FALSE; - } - - state = GetGlobalState(); - if (!state) { - return FALSE; - } - - // Fast path - no critical section overhead for performance - // This function is now non-thread-safe but much faster - - // Walk the linked list looking for the address - current = state->MemoryAllocations.Flink; - - while (current != &state->MemoryAllocations) { - entry = CONTAINING_RECORD(current, MEMORY_TRACKING_ENTRY, ListEntry); - - if (entry->Address == Address) { - // Found the entry - save allocation info for freed memory tracking - allocFileName = entry->FileName; - allocLineNumber = entry->LineNumber; - allocSize = entry->Size; - allocationId = entry->AllocationId; - - // Remove it from the allocation list - RemoveEntryList(&entry->ListEntry); - - // Update statistics - state->CurrentBytesAllocated -= entry->Size; - - // Free the tracking entry itself (using system heap) - HeapFree(GetProcessHeap(), 0, entry); - - found = TRUE; - break; - } - - current = current->Flink; - } - - // If we found and removed the allocation, track it as freed memory - if (found && state->TrackFreedMemory) { - TrackFreedMemoryLocked(Address, allocSize, allocFileName, allocLineNumber, FreeFileName, FreeLineNumber, allocationId); - } - - return found; -} - -__forceinline PVOID ExAllocatePoolWithTracking(POOL_TYPE PoolType, SIZE_T NumberOfBytes, const char* FileName, int LineNumber) { - GLOBAL_STATE* state; - PVOID ptr; - - UNREFERENCED_PARAMETER(PoolType); // Mark as unused parameter to fix C4100 warning - - state = GetGlobalState(); - if (!state) { - return NULL; - } - - // Fast path - no heap validation overhead for performance - ptr = HeapAlloc(state->HeapHandle, 0, NumberOfBytes); - - // Return NULL if memory allocation failed - if (ptr == NULL) { - if (!state->SuppressErrors) { - HEAP_ERROR("Memory allocation failed for %zu bytes", NumberOfBytes); - } - return NULL; - } - - // Track the allocation (non-thread-safe but fast) - TrackAllocation(ptr, NumberOfBytes, FileName, LineNumber); - - return ptr; -} +/* Thin inline wrappers - these match the real WDK pattern of simple delegation */ __forceinline PVOID ExAllocatePool(POOL_TYPE PoolType, SIZE_T NumberOfBytes) { - // PoolType parameter is passed to ExAllocatePoolWithTracking for compatibility, but not used there return ExAllocatePoolWithTracking(PoolType, NumberOfBytes, "Unknown", 0); } __forceinline PVOID ExAllocatePoolWithTag(POOL_TYPE PoolType, SIZE_T NumberOfBytes, ULONG Tag) { - // The Tag parameter is ignored in this implementation since we're in user mode - // It's only used for kernel-mode debugging and memory tracking - UNREFERENCED_PARAMETER(Tag); // Mark as unused parameter to fix C4100 warning - + UNREFERENCED_PARAMETER(Tag); return ExAllocatePool(PoolType, NumberOfBytes); } __forceinline PVOID ExAllocatePoolZeroWithTag(POOL_TYPE PoolType, SIZE_T NumberOfBytes, ULONG Tag) { - // Allocate memory and initialize it to zeros - // PoolType and Tag are passed along but aren't actually used in our implementation - UNREFERENCED_PARAMETER(Tag); // Mark as unused parameter to fix C4100 warning - + UNREFERENCED_PARAMETER(Tag); PVOID memory = ExAllocatePoolWithTag(PoolType, NumberOfBytes, Tag); if (memory) { ZeroMemory(memory, NumberOfBytes); @@ -471,10 +133,7 @@ __forceinline PVOID ExAllocatePoolZeroWithTag(POOL_TYPE PoolType, SIZE_T NumberO } __forceinline PVOID _ExAllocatePoolZeroWithTagTracked(POOL_TYPE PoolType, SIZE_T NumberOfBytes, ULONG Tag, const char* FileName, int LineNumber) { - // Allocate memory with tracking and initialize it to zeros - // PoolType and Tag parameters are not used in this implementation - UNREFERENCED_PARAMETER(Tag); // Mark as unused parameter to fix C4100 warning - + UNREFERENCED_PARAMETER(Tag); PVOID memory = ExAllocatePoolWithTracking(PoolType, NumberOfBytes, FileName, LineNumber); if (memory) { ZeroMemory(memory, NumberOfBytes); @@ -482,292 +141,15 @@ __forceinline PVOID _ExAllocatePoolZeroWithTagTracked(POOL_TYPE PoolType, SIZE_T return memory; } -__forceinline BOOL IsValidHeapPointer(PVOID pointer) { - GLOBAL_STATE* state = GetGlobalState(); - if (!state || !pointer) return FALSE; - - // Fast path - no critical section overhead for performance - // This function is now non-thread-safe but much faster - - // Walk the linked list looking for the address - PLIST_ENTRY current = state->MemoryAllocations.Flink; - while (current != &state->MemoryAllocations) { - PMEMORY_TRACKING_ENTRY entry = CONTAINING_RECORD(current, MEMORY_TRACKING_ENTRY, ListEntry); - - if (entry->Address == pointer) { - return TRUE; - } - - current = current->Flink; - } - - return FALSE; -} - -inline void _ExFreePoolWithTracking(PVOID pointer, const char* FileName, int LineNumber) { - GLOBAL_STATE* state; - BOOL found; - BOOL isDoubleFree = FALSE; - ULONGLONG currentAllocationId = 0; - - HEAP_TRACE("_ExFreePoolWithTracking: Freeing %p from %s:%d", pointer, - FileName ? FileName : "Unknown", LineNumber); - - if (!pointer) { - HEAP_TRACE("_ExFreePoolWithTracking: NULL pointer provided, returning"); - return; - } - - state = GetGlobalState(); - if (!state) { - HEAP_ERROR("_ExFreePoolWithTracking: Failed to get global state"); - return; - } - - HEAP_VERBOSE("_ExFreePoolWithTracking: State info - heap handle: %p, TrackFreedMemory: %d", - state->HeapHandle, state->TrackFreedMemory); - - // First, get the allocation ID for this pointer from the active allocations list - if (state->TrackFreedMemory) { - PLIST_ENTRY current = state->MemoryAllocations.Flink; - while (current != &state->MemoryAllocations) { - PMEMORY_TRACKING_ENTRY entry = CONTAINING_RECORD(current, MEMORY_TRACKING_ENTRY, ListEntry); - if (entry->Address == pointer) { - currentAllocationId = entry->AllocationId; - break; - } - current = current->Flink; - } - - // Only check for double-free if we found a current allocation ID - if (currentAllocationId != 0) { - HEAP_TRACE("_ExFreePoolWithTracking: Checking for double-free of %p (ID: %llu)", pointer, currentAllocationId); - isDoubleFree = CheckForDoubleFree(pointer, FileName, LineNumber, currentAllocationId); - if (isDoubleFree) { - HEAP_WARN("_ExFreePoolWithTracking: Double-free detected for %p, returning without freeing", pointer); - // Don't actually free the memory again - just return immediately - return; - } - HEAP_TRACE("_ExFreePoolWithTracking: No double-free detected for %p", pointer); - } - } - - // Try to untrack the allocation - this will return TRUE if it was found and removed - HEAP_TRACE("_ExFreePoolWithTracking: Calling UntrackAllocation for %p", pointer); - found = UntrackAllocation(pointer, FileName, LineNumber); - - HEAP_VERBOSE("_ExFreePoolWithTracking: UntrackAllocation returned %d for %p", found, pointer); - - if (!found && !state->SuppressErrors) { - HEAP_WARN("_ExFreePoolWithTracking: Address %p not found in tracking, validating heap pointer", pointer); - // Check if it's a valid heap pointer that we just don't track - BOOL isValidHeapPointer = FALSE; - __try { - // Safer check: try to validate with Windows heap functions - isValidHeapPointer = HeapValidate(state->HeapHandle, 0, pointer); - HEAP_VERBOSE("_ExFreePoolWithTracking: HeapValidate returned %d for %p", isValidHeapPointer, pointer); - } - __except (EXCEPTION_EXECUTE_HANDLER) { - HEAP_ERROR("_ExFreePoolWithTracking: Exception during HeapValidate for %p: 0x%08X", - pointer, GetExceptionCode()); - isValidHeapPointer = FALSE; - } - - if (isValidHeapPointer) { - HEAP_WARN("Attempting to free untracked but valid heap memory at %p from %s:%d", - pointer, FileName, LineNumber); - } else { - HEAP_WARN("Attempting to free invalid memory pointer at %p from %s:%d", - pointer, FileName, LineNumber); - } - } - - // Free the memory only if: - // 1. It's not a double-free AND - // 2. Either we found it in our tracking (meaning it's valid) OR it's an untracked but valid heap pointer - if (!isDoubleFree) { - HEAP_TRACE("_ExFreePoolWithTracking: Calling HeapFree for %p", pointer); - __try { - HeapFree(state->HeapHandle, 0, pointer); - HEAP_VERBOSE("_ExFreePoolWithTracking: Successfully freed %p", pointer); - } - __except (EXCEPTION_EXECUTE_HANDLER) { - HEAP_ERROR("_ExFreePoolWithTracking: Exception occurred while freeing memory at %p from %s:%d (Exception: 0x%08X)", - pointer, FileName, LineNumber, GetExceptionCode()); - if (!state->SuppressErrors) { - HEAP_ERROR("Exception occurred while freeing memory at %p from %s:%d (Exception: 0x%08X)", - pointer, FileName, LineNumber, GetExceptionCode()); - } - } - } -} - __forceinline void ExFreePool(PVOID pointer) { _ExFreePoolWithTracking(pointer, "Unknown", 0); } __forceinline void ExFreePoolWithTag(PVOID pointer, ULONG Tag) { - // The Tag parameter is ignored in this implementation since we're in user mode - // It's only used for kernel-mode debugging and memory tracking - UNREFERENCED_PARAMETER(Tag); // Mark as unused parameter to fix C4100 warning + UNREFERENCED_PARAMETER(Tag); ExFreePool(pointer); } -__forceinline void PrintMemoryLeaks(void) { - GLOBAL_STATE* state; - BOOL foundLeaks; - SIZE_T leakCount; - SIZE_T leakBytes; - PLIST_ENTRY current; - PMEMORY_TRACKING_ENTRY entry; - - state = GetGlobalState(); - if (!state) return; - - foundLeaks = FALSE; - leakCount = 0; - leakBytes = 0; - - // CriticalSection usage removed for performance - function is now non-thread-safe but faster - - HEAP_INFO("=== MEMORY LEAK REPORT ==="); - - // Walk the linked list of allocations - current = state->MemoryAllocations.Flink; - while (current != &state->MemoryAllocations) { - entry = CONTAINING_RECORD(current, MEMORY_TRACKING_ENTRY, ListEntry); - - if (!foundLeaks) { - HEAP_INFO("Address | Size | Allocation Location"); - HEAP_INFO("------------- | -------- | ------------------"); - foundLeaks = TRUE; - } - - HEAP_INFO("%p | %8d | %s:%d", - entry->Address, - (int)entry->Size, - entry->FileName, - entry->LineNumber); - - leakCount++; - leakBytes += entry->Size; - - current = current->Flink; - } - - if (foundLeaks) { - HEAP_INFO("Total: %d leaks, %d bytes", (int)leakCount, (int)leakBytes); - } else { - HEAP_INFO("No memory leaks detected!"); - } - - HEAP_INFO("Memory usage statistics:"); - HEAP_INFO(" Total allocations: %d", (int)state->AllocationCount); - HEAP_INFO(" Total bytes allocated: %d", (int)state->TotalBytesAllocated); - HEAP_INFO(" Peak bytes allocated: %d", (int)state->PeakBytesAllocated); - HEAP_INFO(" Double-free attempts: %d", (int)state->DoubleFreeCount); - HEAP_INFO(" Freed entries tracked: %d", (int)state->FreedEntryCount); - HEAP_INFO("==========================="); -} - -__forceinline void PrintDoubleFreeReport(void) { - GLOBAL_STATE* state; - PLIST_ENTRY current; - PFREED_MEMORY_ENTRY entry; - SIZE_T entryCount = 0; - - state = GetGlobalState(); - if (!state) return; - - // CriticalSection usage removed for performance - function is now non-thread-safe but faster - - HEAP_INFO("=== FREED MEMORY REPORT ==="); - HEAP_INFO("Total double-free attempts detected: %d", (int)state->DoubleFreeCount); - HEAP_INFO("Currently tracking %d freed allocations", (int)state->FreedEntryCount); - HEAP_INFO("Maximum freed entries to track: %d", (int)state->MaxFreedEntries); - - if (state->FreedEntryCount > 0) { - HEAP_INFO("Recent freed allocations:"); - HEAP_INFO("Address | Size | Alloc ID | Alloc Location | Free Location | Thread"); - HEAP_INFO("------------- | -------- | -------- | --------------- | --------------- | ------"); - - // Walk the freed memory list (showing most recent first) - current = state->FreedMemoryList.Flink; - while (current != &state->FreedMemoryList && entryCount < 20) // Limit output to 20 entries - { - entry = CONTAINING_RECORD(current, FREED_MEMORY_ENTRY, ListEntry); - - HEAP_INFO("%p | %8d | %8llu | %15s:%-4d | %15s:%-4d | %6lu", - entry->Address, - (int)entry->Size, - entry->AllocationId, - entry->AllocFileName, entry->AllocLineNumber, - entry->FreeFileName, entry->FreeLineNumber, - entry->ThreadId); - - entryCount++; - current = current->Flink; - } - - if (state->FreedEntryCount > 20) { - HEAP_INFO("... and %d more entries", (int)(state->FreedEntryCount - 20)); - } - } - - HEAP_INFO("=============================="); -} - -__forceinline void SetErrorSuppression(BOOL suppress) { - GLOBAL_STATE* state = GetGlobalState(); - if (state) { - state->SuppressErrors = suppress; - } -} - -__forceinline BOOL GetErrorSuppression(void) { - GLOBAL_STATE* state = GetGlobalState(); - if (state) { - return state->SuppressErrors; - } - return FALSE; -} - -__forceinline void SetFreedMemoryTracking(BOOL enable) { - GLOBAL_STATE* state = GetGlobalState(); - if (state) { - state->TrackFreedMemory = enable; - } -} - -__forceinline BOOL GetFreedMemoryTracking(void) { - GLOBAL_STATE* state = GetGlobalState(); - if (state) { - return state->TrackFreedMemory; - } - return FALSE; -} - -__forceinline void SetMaxFreedEntries(SIZE_T maxEntries) { - GLOBAL_STATE* state = GetGlobalState(); - if (state) { - // CriticalSection usage removed for performance - function is now non-thread-safe but faster - state->MaxFreedEntries = maxEntries; - - // Clean up excess entries if needed - if (state->FreedEntryCount > maxEntries) { - CleanupOldFreedEntries(); - } - } -} - -__forceinline SIZE_T GetMaxFreedEntries(void) { - GLOBAL_STATE* state = GetGlobalState(); - if (state) { - return state->MaxFreedEntries; - } - return 0; -} - /* Macro definitions for automatic file and line capture */ #define ExAllocatePoolTracked(PoolType, NumberOfBytes) \ ExAllocatePoolWithTracking(PoolType, NumberOfBytes, __FILE__, __LINE__) @@ -815,4 +197,3 @@ __forceinline SIZE_T GetMaxFreedEntries(void) { #endif #endif /* WINKERNEL_KERNELHEAP_H_ */ - diff --git a/include/Registry.h b/include/Registry.h index 88a6a70..13e3681 100644 --- a/include/Registry.h +++ b/include/Registry.h @@ -131,7 +131,7 @@ typedef struct _KEY_VALUE_PARTIAL_INFORMATION { /** * @brief User mode implementation of ZwEnumerateKey * Enumerates registry key information - * + * * @param KeyHandle Handle to the registry key * @param Index Index of the subkey to enumerate * @param KeyInformationClass Type of information to retrieve @@ -140,7 +140,7 @@ typedef struct _KEY_VALUE_PARTIAL_INFORMATION { * @param ResultLength Pointer to receive the size of data written * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code */ -inline NTSTATUS ZwEnumerateKey( +NTSTATUS ZwEnumerateKey( IN HANDLE KeyHandle, IN ULONG Index, IN KEY_INFORMATION_CLASS KeyInformationClass, @@ -152,18 +152,32 @@ inline NTSTATUS ZwEnumerateKey( /** * @brief User mode implementation of ZwClose * Closes an object handle - * + * * @param Handle Handle to close * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code */ -inline NTSTATUS ZwClose( +NTSTATUS ZwClose( IN HANDLE Handle ); +/** + * @brief Helper function to parse a full registry path into root key and subkey path + * + * @param fullPath The full registry path (e.g., L"\\Registry\\Machine\\Software\\Test") + * @param rootKey Pointer to receive the root key handle + * @param subKeyPath Pointer to receive the subkey path + * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code + */ +NTSTATUS ParseRegistryPath( + IN PCWSTR fullPath, + OUT HKEY* rootKey, + OUT PCWSTR* subKeyPath +); + /** * @brief User mode implementation of ZwCreateKey * Creates or opens a registry key - * + * * @param KeyHandle Pointer to receive the handle to the key * @param DesiredAccess The access mask for the key * @param ObjectAttributes Attributes for the key @@ -173,7 +187,7 @@ inline NTSTATUS ZwClose( * @param Disposition Optional pointer to receive creation disposition * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code */ -inline NTSTATUS ZwCreateKey( +NTSTATUS ZwCreateKey( OUT PHANDLE KeyHandle, IN ACCESS_MASK DesiredAccess, IN POBJECT_ATTRIBUTES ObjectAttributes, @@ -186,7 +200,7 @@ inline NTSTATUS ZwCreateKey( /** * @brief User mode implementation of ZwSetValueKey * Sets a registry value - * + * * @param KeyHandle Handle to the registry key * @param ValueName Name of the value to set * @param TitleIndex Reserved, must be zero @@ -195,7 +209,7 @@ inline NTSTATUS ZwCreateKey( * @param DataSize Size of the value data in bytes * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code */ -inline NTSTATUS ZwSetValueKey( +NTSTATUS ZwSetValueKey( IN HANDLE KeyHandle, IN PUNICODE_STRING ValueName, IN ULONG TitleIndex, @@ -207,13 +221,13 @@ inline NTSTATUS ZwSetValueKey( /** * @brief User mode implementation of ZwOpenKey * Opens an existing registry key - * + * * @param KeyHandle Pointer to receive the handle to the key * @param DesiredAccess The access mask for the key * @param ObjectAttributes Attributes for the key * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code */ -inline NTSTATUS ZwOpenKey( +NTSTATUS ZwOpenKey( OUT PHANDLE KeyHandle, IN ACCESS_MASK DesiredAccess, IN POBJECT_ATTRIBUTES ObjectAttributes @@ -222,7 +236,7 @@ inline NTSTATUS ZwOpenKey( /** * @brief User mode implementation of ZwQueryValueKey * Queries registry key value information - * + * * @param KeyHandle Handle to the registry key * @param ValueName Name of the value to query * @param KeyValueInformationClass Type of information to retrieve @@ -231,7 +245,7 @@ inline NTSTATUS ZwOpenKey( * @param ResultLength Pointer to receive the size of data written or needed * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code */ -inline NTSTATUS ZwQueryValueKey( +NTSTATUS ZwQueryValueKey( IN HANDLE KeyHandle, IN PUNICODE_STRING ValueName, IN KEY_VALUE_INFORMATION_CLASS KeyValueInformationClass, @@ -243,7 +257,7 @@ inline NTSTATUS ZwQueryValueKey( /** * @brief User mode implementation of ZwEnumerateValueKey * Enumerates value entries for a registry key - * + * * @param KeyHandle Handle to the registry key * @param Index Index of the value entry to enumerate * @param KeyValueInformationClass Type of information to retrieve @@ -252,7 +266,7 @@ inline NTSTATUS ZwQueryValueKey( * @param ResultLength Pointer to receive the size of data written * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code */ -inline NTSTATUS ZwEnumerateValueKey( +NTSTATUS ZwEnumerateValueKey( IN HANDLE KeyHandle, IN ULONG Index, IN KEY_VALUE_INFORMATION_CLASS KeyValueInformationClass, @@ -291,1055 +305,111 @@ inline NTSTATUS ZwEnumerateValueKey( #define REG_OPENED_EXISTING_KEY 0x00000002L #endif -#ifdef __cplusplus -} -#endif - -// Implementation of user mode registry functions - -inline NTSTATUS ZwEnumerateKey( - IN HANDLE KeyHandle, - IN ULONG Index, - IN KEY_INFORMATION_CLASS KeyInformationClass, - OUT PVOID KeyInformation, - IN ULONG Length, - OUT PULONG ResultLength -) -{ - if (!KeyHandle) { - return STATUS_INVALID_PARAMETER; - } - if (!KeyInformation) { - return STATUS_INVALID_PARAMETER; - } - if (!ResultLength) { - return STATUS_INVALID_PARAMETER; - } - - WCHAR tempKeyName[1024]; - DWORD keyNameSize = sizeof(tempKeyName) / sizeof(WCHAR); - FILETIME lastWriteTime; - - LSTATUS status = RegEnumKeyExW( - (HKEY)KeyHandle, - Index, - tempKeyName, - &keyNameSize, - NULL, // Reserved - NULL, // Class name (not used) - NULL, // Class name length - &lastWriteTime // Last write time - ); - - // Convert the result to the requested information format - if (status == ERROR_SUCCESS) { - ULONG nameLength = (ULONG)(keyNameSize * sizeof(WCHAR)); - ULONG requiredSize = 0; - - switch (KeyInformationClass) { - case KeyBasicInformation: { - requiredSize = sizeof(KEY_BASIC_INFORMATION) + nameLength; - if (Length < requiredSize) { - if (ResultLength) *ResultLength = requiredSize; - return STATUS_BUFFER_TOO_SMALL; - } - - PKEY_BASIC_INFORMATION basicInfo = (PKEY_BASIC_INFORMATION)KeyInformation; - basicInfo->LastWriteTime.QuadPart = ((LARGE_INTEGER*)&lastWriteTime)->QuadPart; - basicInfo->TitleIndex = 0; - basicInfo->NameLength = nameLength; - memcpy_s(basicInfo->Name, nameLength, tempKeyName, nameLength); - - if (ResultLength) *ResultLength = requiredSize; - break; - } - - case KeyNodeInformation: { - requiredSize = sizeof(KEY_NODE_INFORMATION) + nameLength; - if (Length < requiredSize) { - if (ResultLength) *ResultLength = requiredSize; - return STATUS_BUFFER_TOO_SMALL; - } - - PKEY_NODE_INFORMATION nodeInfo = (PKEY_NODE_INFORMATION)KeyInformation; - nodeInfo->LastWriteTime.QuadPart = ((LARGE_INTEGER*)&lastWriteTime)->QuadPart; - nodeInfo->TitleIndex = 0; - nodeInfo->ClassLength = 0; - nodeInfo->ClassOffset = 0; - nodeInfo->NameLength = nameLength; - memcpy_s(nodeInfo->Name, nameLength, tempKeyName, nameLength); - - if (ResultLength) *ResultLength = requiredSize; - break; - } - - case KeyFullInformation: - default: - // Unsupported information class - return STATUS_INVALID_PARAMETER; - } - - return STATUS_SUCCESS; - } else if (status == ERROR_NO_MORE_ITEMS) { - return STATUS_NO_MORE_ENTRIES; - } else if (status == ERROR_MORE_DATA) { - return STATUS_BUFFER_TOO_SMALL; - } else { - return STATUS_INVALID_PARAMETER; - } -} - -inline NTSTATUS ZwClose( - IN HANDLE Handle -) -{ - // Check for obviously invalid handles upfront - if (Handle == NULL || Handle == INVALID_HANDLE_VALUE) { - return STATUS_INVALID_PARAMETER; - } - - // Check if the handle is valid using GetHandleInformation - DWORD flags = 0; - if (!GetHandleInformation(Handle, &flags)) { - return STATUS_INVALID_HANDLE; - } - - // Handle is valid, so close it - if (CloseHandle(Handle)) { - return STATUS_SUCCESS; - } else { - // This should be unreachable if GetHandleInformation succeeded - return STATUS_INVALID_HANDLE; - } -} - /** - * @brief Helper function to parse a full registry path into root key and subkey path - * - * @param fullPath The full registry path (e.g., L"\\Registry\\Machine\\Software\\Test") - * @param rootKey Pointer to receive the root key handle - * @param subKeyPath Pointer to receive the subkey path - * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code + * @brief Deletes a registry key (empty keys only). */ -inline NTSTATUS ParseRegistryPath( - IN PCWSTR fullPath, - OUT HKEY* rootKey, - OUT PCWSTR* subKeyPath -) -{ - if (!fullPath || !rootKey || !subKeyPath) { - return STATUS_INVALID_PARAMETER; - } - - // Check for known registry path prefixes - if (wcsncmp(fullPath, L"\\Registry\\Machine\\", 18) == 0) { - *rootKey = HKEY_LOCAL_MACHINE; - *subKeyPath = fullPath + 18; // Skip "\\Registry\\Machine\\" - return STATUS_SUCCESS; - } else if (wcsncmp(fullPath, L"\\Registry\\User\\", 15) == 0) { - *rootKey = HKEY_USERS; - *subKeyPath = fullPath + 15; // Skip "\\Registry\\User\\" - return STATUS_SUCCESS; - } else if (wcsncmp(fullPath, L"HKEY_LOCAL_MACHINE\\", 19) == 0) { - *rootKey = HKEY_LOCAL_MACHINE; - *subKeyPath = fullPath + 19; // Skip "HKEY_LOCAL_MACHINE\\" - return STATUS_SUCCESS; - } else if (wcsncmp(fullPath, L"HKEY_CURRENT_USER\\", 18) == 0) { - *rootKey = HKEY_CURRENT_USER; - *subKeyPath = fullPath + 18; // Skip "HKEY_CURRENT_USER\\" - return STATUS_SUCCESS; - } else if (wcsncmp(fullPath, L"HKEY_USERS\\", 11) == 0) { - *rootKey = HKEY_USERS; - *subKeyPath = fullPath + 11; // Skip "HKEY_USERS\\" - return STATUS_SUCCESS; - } else if (wcsncmp(fullPath, L"HKEY_CLASSES_ROOT\\", 18) == 0) { - *rootKey = HKEY_CLASSES_ROOT; - *subKeyPath = fullPath + 18; // Skip "HKEY_CLASSES_ROOT\\" - return STATUS_SUCCESS; - } - - // Unknown registry path format - return STATUS_OBJECT_PATH_SYNTAX_BAD; -} - -inline NTSTATUS ZwCreateKey( - OUT PHANDLE KeyHandle, - IN ACCESS_MASK DesiredAccess, - IN POBJECT_ATTRIBUTES ObjectAttributes, - IN ULONG TitleIndex, - IN PUNICODE_STRING Class OPTIONAL, - IN ULONG CreateOptions, - OUT PULONG Disposition OPTIONAL -) -{ - // Validate parameters - UNREFERENCED_PARAMETER(TitleIndex); // Mark as unused parameter to fix C4100 warning - - if (!KeyHandle) { - return STATUS_INVALID_PARAMETER; - } - if (!ObjectAttributes) { - return STATUS_INVALID_PARAMETER; - } - if (!ObjectAttributes->ObjectName) { - return STATUS_INVALID_PARAMETER; - } - - HKEY rootKey = NULL; - PCWSTR keyPath = NULL; - DWORD dispositionValue = 0; - LSTATUS status; - NTSTATUS ntStatus; - - // Get root key and subkey path from ObjectAttributes - if (ObjectAttributes->RootDirectory) { - // Root directory provided - use it with relative path - rootKey = (HKEY)ObjectAttributes->RootDirectory; - keyPath = ObjectAttributes->ObjectName->Buffer; - } else { - // No root directory - parse full path from ObjectName - ntStatus = ParseRegistryPath( - ObjectAttributes->ObjectName->Buffer, - &rootKey, - &keyPath - ); - if (!NT_SUCCESS(ntStatus)) { - return ntStatus; - } - } - - status = RegCreateKeyExW( - rootKey, - keyPath, - 0, // Reserved - Class ? Class->Buffer : NULL, // Class - CreateOptions, // Options - DesiredAccess, // Access - NULL, // Security attributes not supported - (PHKEY)KeyHandle, - &dispositionValue - ); - - // Pass back the disposition if requested - if (Disposition) { - *Disposition = dispositionValue; - } - - // Convert Windows error to NTSTATUS - if (status == ERROR_SUCCESS) { - return STATUS_SUCCESS; - } else if (status == ERROR_ACCESS_DENIED) { - return STATUS_ACCESS_DENIED; - } else if (status == ERROR_INVALID_PARAMETER) { - return STATUS_INVALID_PARAMETER; - } else { - return STATUS_UNSUCCESSFUL; - } -} - -inline NTSTATUS ZwSetValueKey( - IN HANDLE KeyHandle, - IN PUNICODE_STRING ValueName, - IN ULONG TitleIndex, - IN ULONG Type, - IN PVOID Data, - IN ULONG DataSize -) -{ - // Validate parameters - UNREFERENCED_PARAMETER(TitleIndex); // Mark as unused parameter to fix C4100 warning - - if (!KeyHandle) { - return STATUS_INVALID_PARAMETER; - } - if (!ValueName) { - return STATUS_INVALID_PARAMETER; - } - if (!ValueName->Buffer) { - return STATUS_INVALID_PARAMETER; - } - if (!Data && DataSize > 0) { - return STATUS_INVALID_PARAMETER; - } - - // avoid potential problems with the strings not being null-terminated as UNICODE_STRING is not required to have null terminated buffer - // and we don't know how exactly this works in the Win kernel - ULONG nameChars = ValueName->Length / sizeof(WCHAR); - WCHAR* nullTerminatedName = (WCHAR*)HeapAlloc(GetProcessHeap(), 0, (nameChars + 1) * sizeof(WCHAR)); - if (!nullTerminatedName) { - return STATUS_NO_MEMORY; - } - - // Copy the name and add null terminator - memcpy_s(nullTerminatedName, ValueName->Length, ValueName->Buffer, ValueName->Length); - nullTerminatedName[nameChars] = L'\0'; - - LSTATUS status = RegSetValueExW( - (HKEY)KeyHandle, - nullTerminatedName, - 0, // Reserved - Type, - (CONST BYTE*)Data, - DataSize - ); - - HeapFree(GetProcessHeap(), 0, nullTerminatedName); - - if (status == ERROR_SUCCESS) { - return STATUS_SUCCESS; - } else if (status == ERROR_ACCESS_DENIED) { - return STATUS_ACCESS_DENIED; - } else if (status == ERROR_INVALID_PARAMETER) { - return STATUS_INVALID_PARAMETER; - } else { - return STATUS_UNSUCCESSFUL; - } -} - -inline NTSTATUS ZwOpenKey( - OUT PHANDLE KeyHandle, - IN ACCESS_MASK DesiredAccess, - IN POBJECT_ATTRIBUTES ObjectAttributes -) -{ - // Validate parameters - if (!KeyHandle) { - return STATUS_INVALID_PARAMETER; - } - if (!ObjectAttributes) { - return STATUS_INVALID_PARAMETER; - } - if (!ObjectAttributes->ObjectName) { - return STATUS_INVALID_PARAMETER; - } - if (!ObjectAttributes->ObjectName->Buffer) { - return STATUS_INVALID_PARAMETER; - } - - HKEY rootKey = NULL; - PCWSTR keyPath = NULL; - LSTATUS status; - NTSTATUS ntStatus; - - // Get root key and subkey path from ObjectAttributes - if (ObjectAttributes->RootDirectory) { - // Root directory provided - use it with relative path - rootKey = (HKEY)ObjectAttributes->RootDirectory; - keyPath = ObjectAttributes->ObjectName->Buffer; - } else { - // No root directory - parse full path from ObjectName or assume HKEY_CURRENT_USER - ntStatus = ParseRegistryPath( - ObjectAttributes->ObjectName->Buffer, - &rootKey, - &keyPath - ); - if (!NT_SUCCESS(ntStatus)) { - // If path parsing fails, assume it's a relative path under HKEY_CURRENT_USER - rootKey = HKEY_CURRENT_USER; - keyPath = ObjectAttributes->ObjectName->Buffer; - } - } - - status = RegOpenKeyExW( - rootKey, - keyPath, - 0, // Reserved - DesiredAccess, // Access - (PHKEY)KeyHandle - ); - - // Convert Windows error to NTSTATUS - if (status == ERROR_SUCCESS) { - return STATUS_SUCCESS; - } else if (status == ERROR_FILE_NOT_FOUND) { - return STATUS_OBJECT_NAME_NOT_FOUND; - } else if (status == ERROR_PATH_NOT_FOUND) { - return STATUS_OBJECT_PATH_NOT_FOUND; - } else if (status == ERROR_ACCESS_DENIED) { - return STATUS_ACCESS_DENIED; - } else { - return STATUS_UNSUCCESSFUL; - } -} - -inline NTSTATUS ZwQueryValueKey( - IN HANDLE KeyHandle, - IN PUNICODE_STRING ValueName, - IN KEY_VALUE_INFORMATION_CLASS KeyValueInformationClass, - OUT PVOID KeyValueInformation, - IN ULONG Length, - OUT PULONG ResultLength -) -{ - if (!KeyHandle) { - return STATUS_INVALID_PARAMETER; - } - if (!ValueName) { - return STATUS_INVALID_PARAMETER; - } - if (!ValueName->Buffer) { - return STATUS_INVALID_PARAMETER; - } - if (!KeyValueInformation) { - return STATUS_INVALID_PARAMETER; - } - if (!ResultLength) { - return STATUS_INVALID_PARAMETER; - } - - DWORD dwType; - DWORD dwDataSize = 0; - LSTATUS status; - - switch (KeyValueInformationClass) { - case KeyValuePartialInformation: - { - PKEY_VALUE_PARTIAL_INFORMATION partialInfo = (PKEY_VALUE_PARTIAL_INFORMATION)KeyValueInformation; - - // First call to get required buffer size - status = RegQueryValueExW( - (HKEY)KeyHandle, - ValueName->Buffer, - NULL, - &dwType, - NULL, - &dwDataSize - ); - - if (status != ERROR_SUCCESS && status != ERROR_MORE_DATA) { - if (status == ERROR_FILE_NOT_FOUND) { - return STATUS_OBJECT_NAME_NOT_FOUND; - } else { - return STATUS_UNSUCCESSFUL; - } - } +NTSTATUS ZwDeleteKey(HANDLE KeyHandle); - // Calculate required size for the structure - ULONG requiredSize = sizeof(KEY_VALUE_PARTIAL_INFORMATION) + dwDataSize - sizeof(BYTE); - *ResultLength = requiredSize; - - if (Length < requiredSize) { - return STATUS_BUFFER_TOO_SMALL; - } - - // Setup the structure fields - partialInfo->TitleIndex = 0; - partialInfo->Type = dwType; - partialInfo->DataLength = dwDataSize; - - status = RegQueryValueExW( - (HKEY)KeyHandle, - ValueName->Buffer, - NULL, - &dwType, - (LPBYTE)partialInfo->Data, - &dwDataSize - ); - - if (status == ERROR_SUCCESS) { - return STATUS_SUCCESS; - } else if (status == ERROR_FILE_NOT_FOUND) { - return STATUS_OBJECT_NAME_NOT_FOUND; - } else if (status == ERROR_MORE_DATA) { - return STATUS_BUFFER_TOO_SMALL; - } else { - return STATUS_UNSUCCESSFUL; - } - } - - case KeyValueFullInformation: - { - PKEY_VALUE_FULL_INFORMATION fullInfo = (PKEY_VALUE_FULL_INFORMATION)KeyValueInformation; - // First call to get required buffer size - status = RegQueryValueExW( - (HKEY)KeyHandle, - ValueName->Buffer, - NULL, - &dwType, - NULL, - &dwDataSize - ); - - if (status != ERROR_SUCCESS && status != ERROR_MORE_DATA) { - if (status == ERROR_FILE_NOT_FOUND) { - return STATUS_OBJECT_NAME_NOT_FOUND; - } else if (status == ERROR_MORE_DATA) { - return STATUS_BUFFER_TOO_SMALL; - } else { - return STATUS_UNSUCCESSFUL; - } - } - - if (Length < sizeof(KEY_VALUE_FULL_INFORMATION)) { - *ResultLength = sizeof(KEY_VALUE_FULL_INFORMATION) + dwDataSize; - return STATUS_BUFFER_TOO_SMALL; - } - - // Setup the structure fields - fullInfo->TitleIndex = 0; - fullInfo->Type = dwType; - fullInfo->NameLength = ValueName->Length; - fullInfo->DataLength = dwDataSize; - fullInfo->DataOffset = FIELD_OFFSET(KEY_VALUE_FULL_INFORMATION, Name) + - ValueName->Length + sizeof(WCHAR); - - // Make sure we have enough space for name and data - if (Length < fullInfo->DataOffset + dwDataSize) { - *ResultLength = fullInfo->DataOffset + dwDataSize; - return STATUS_BUFFER_TOO_SMALL; - } - - // Copy the name - memcpy_s(fullInfo->Name, ValueName->Length + sizeof(WCHAR), ValueName->Buffer, ValueName->Length); - ((PWSTR)((PBYTE)fullInfo->Name + ValueName->Length))[0] = L'\0'; - - - status = RegQueryValueExW( - (HKEY)KeyHandle, - ValueName->Buffer, - NULL, - &dwType, - (LPBYTE)((PBYTE)KeyValueInformation + fullInfo->DataOffset), - &dwDataSize - ); - - *ResultLength = fullInfo->DataOffset + dwDataSize; - - if (status == ERROR_SUCCESS) { - return STATUS_SUCCESS; - } else if (status == ERROR_FILE_NOT_FOUND) { - return STATUS_OBJECT_NAME_NOT_FOUND; - } else if (status == ERROR_MORE_DATA) { - return STATUS_BUFFER_TOO_SMALL; - } else { - return STATUS_UNSUCCESSFUL; - } - } - - case KeyValueBasicInformation: - { - PKEY_VALUE_BASIC_INFORMATION basicInfo = (PKEY_VALUE_BASIC_INFORMATION)KeyValueInformation; - - // First call to get required buffer size and type - status = RegQueryValueExW( - (HKEY)KeyHandle, - ValueName->Buffer, - NULL, - &dwType, - NULL, - &dwDataSize - ); - - if (status != ERROR_SUCCESS && status != ERROR_MORE_DATA) { - if (status == ERROR_FILE_NOT_FOUND) { - return STATUS_OBJECT_NAME_NOT_FOUND; - } else { - return STATUS_UNSUCCESSFUL; - } - } - - // Calculate required size for the structure - ULONG requiredSize = sizeof(KEY_VALUE_BASIC_INFORMATION) + ValueName->Length; - *ResultLength = requiredSize; - - if (Length < requiredSize) { - return STATUS_BUFFER_TOO_SMALL; - } - - // Setup the structure fields - basicInfo->TitleIndex = 0; - basicInfo->Type = dwType; - basicInfo->NameLength = ValueName->Length; - - // Copy the name - memcpy_s(basicInfo->Name, ValueName->Length, ValueName->Buffer, ValueName->Length); - - return STATUS_SUCCESS; - } - - default: - return STATUS_INVALID_PARAMETER; - } -} - -inline NTSTATUS ZwEnumerateValueKey( - IN HANDLE KeyHandle, - IN ULONG Index, - IN KEY_VALUE_INFORMATION_CLASS KeyValueInformationClass, - OUT PVOID KeyValueInformation, - IN ULONG Length, - OUT PULONG ResultLength -) -{ - if (!KeyHandle) { - return STATUS_INVALID_PARAMETER; - } - if (!KeyValueInformation) { - return STATUS_INVALID_PARAMETER; - } - if (!ResultLength) { - return STATUS_INVALID_PARAMETER; - } - - // Status variable is not used in this function but kept for consistency with original API - // NTSTATUS status = STATUS_SUCCESS; - HKEY hKey = (HKEY)KeyHandle; - DWORD maxValueNameLen = 0, maxValueDataLen = 0; - WCHAR valueName[MAX_PATH] = { 0 }; - DWORD valueNameLen = MAX_PATH; - DWORD type = 0; - BYTE* data = NULL; - DWORD dataSize = 0; - - // Get value name and data size - LSTATUS winStatus = RegQueryInfoKeyW( - hKey, - NULL, - NULL, - NULL, - NULL, - NULL, - NULL, - NULL, - &maxValueNameLen, - &maxValueDataLen, - NULL, - NULL - ); - if (winStatus != ERROR_SUCCESS) { - return winStatus == ERROR_ACCESS_DENIED ? STATUS_ACCESS_DENIED : STATUS_UNSUCCESSFUL; - } - - // Get value name, type and data - valueNameLen = MAX_PATH; // Reset length before calling RegEnumValueW - winStatus = RegEnumValueW( - hKey, - Index, - valueName, - &valueNameLen, - NULL, - &type, - NULL, - &dataSize - ); - if (winStatus != ERROR_SUCCESS) { - return winStatus == ERROR_NO_MORE_ITEMS ? STATUS_NO_MORE_ENTRIES : - winStatus == ERROR_MORE_DATA ? STATUS_BUFFER_TOO_SMALL : - winStatus == ERROR_ACCESS_DENIED ? STATUS_ACCESS_DENIED : STATUS_UNSUCCESSFUL; - } - - // Fill in the information based on the requested class - switch (KeyValueInformationClass) { - case KeyValueBasicInformation: { - PKEY_VALUE_BASIC_INFORMATION pBasic = (PKEY_VALUE_BASIC_INFORMATION)KeyValueInformation; - size_t requiredSize = sizeof(KEY_VALUE_BASIC_INFORMATION) + (valueNameLen * sizeof(WCHAR)); - *ResultLength = (ULONG)requiredSize; - - if (Length < requiredSize) { - return STATUS_BUFFER_TOO_SMALL; - } - - pBasic->TitleIndex = 0; - pBasic->Type = type; - pBasic->NameLength = valueNameLen * sizeof(WCHAR); - memcpy_s(pBasic->Name, pBasic->NameLength, valueName, pBasic->NameLength); - break; - } - case KeyValueFullInformation: { - PKEY_VALUE_FULL_INFORMATION pFull = (PKEY_VALUE_FULL_INFORMATION)KeyValueInformation; - size_t requiredSize = sizeof(KEY_VALUE_FULL_INFORMATION) + (valueNameLen * sizeof(WCHAR)) + dataSize; - *ResultLength = (ULONG)requiredSize; - - if (Length < requiredSize) { - return STATUS_BUFFER_TOO_SMALL; - } - - data = (BYTE*)HeapAlloc(GetProcessHeap(), 0, dataSize); - if (!data) { - return STATUS_NO_MEMORY; - } - - DWORD actualDataSize = dataSize; - valueNameLen = MAX_PATH; // Reset length before calling RegEnumValueW - LSTATUS valueStatus = RegEnumValueW( - hKey, - Index, - valueName, - &valueNameLen, - NULL, - &type, - data, - &actualDataSize - ); - if (valueStatus != ERROR_SUCCESS) { - HeapFree(GetProcessHeap(), 0, data); - return valueStatus == ERROR_NO_MORE_ITEMS ? STATUS_NO_MORE_ENTRIES : - valueStatus == ERROR_MORE_DATA ? STATUS_BUFFER_TOO_SMALL : - valueStatus == ERROR_ACCESS_DENIED ? STATUS_ACCESS_DENIED : STATUS_UNSUCCESSFUL; - } - - pFull->TitleIndex = 0; - pFull->Type = type; - pFull->DataOffset = sizeof(KEY_VALUE_FULL_INFORMATION) + (valueNameLen * sizeof(WCHAR)); - pFull->DataLength = dataSize; - pFull->NameLength = valueNameLen * sizeof(WCHAR); - memcpy_s(pFull->Name, pFull->NameLength, valueName, pFull->NameLength); - memcpy_s((BYTE*)pFull + pFull->DataOffset, dataSize, data, dataSize); - HeapFree(GetProcessHeap(), 0, data); - break; - } - case KeyValuePartialInformation: { - PKEY_VALUE_PARTIAL_INFORMATION pPartial = (PKEY_VALUE_PARTIAL_INFORMATION)KeyValueInformation; - size_t requiredSize = sizeof(KEY_VALUE_PARTIAL_INFORMATION) + dataSize; - *ResultLength = (ULONG)requiredSize; - - if (Length < requiredSize) { - return STATUS_BUFFER_TOO_SMALL; - } - - data = (BYTE*)HeapAlloc(GetProcessHeap(), 0, dataSize); - if (!data) { - return STATUS_NO_MEMORY; - } - - DWORD actualDataSize = dataSize; - valueNameLen = MAX_PATH; // Reset length before calling RegEnumValueW - LSTATUS valueStatus = RegEnumValueW( - hKey, - Index, - valueName, - &valueNameLen, - NULL, - &type, - data, - &actualDataSize - ); - if (valueStatus != ERROR_SUCCESS) { - HeapFree(GetProcessHeap(), 0, data); - return valueStatus == ERROR_NO_MORE_ITEMS ? STATUS_NO_MORE_ENTRIES : - valueStatus == ERROR_MORE_DATA ? STATUS_BUFFER_TOO_SMALL : - valueStatus == ERROR_ACCESS_DENIED ? STATUS_ACCESS_DENIED : STATUS_UNSUCCESSFUL; - } - - pPartial->TitleIndex = 0; - pPartial->Type = type; - pPartial->DataLength = dataSize; - memcpy_s(pPartial->Data, dataSize, data, dataSize); - HeapFree(GetProcessHeap(), 0, data); - break; - } - - default: - return STATUS_INVALID_PARAMETER; - } +/** + * @brief Extended function to delete a subkey of a parent key. + */ +NTSTATUS ZwDeleteKeyEx(HKEY KeyHandle, LPCWSTR SubKey); - return STATUS_SUCCESS; +#ifdef __cplusplus } +#endif -// Simplified API for registry operations +/* Simplified API wrappers - kept inline because names conflict with Win32 macros */ -/** - * @brief Create or open a registry key (simplified wrapper) - * - * @param KeyHandle Pointer to receive the handle to the key - * @param DesiredAccess The access mask for the key - * @param RootKey Root registry key (e.g., HKEY_CURRENT_USER) - * @param SubKey Path to the subkey - * @param CreateOptions Options for creating the key - * @param Disposition Optional pointer to receive creation disposition - * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code - */ inline NTSTATUS RegCreateKey( - PHANDLE KeyHandle, - ACCESS_MASK DesiredAccess, - HANDLE RootKey, - LPCWSTR SubKey, - ULONG CreateOptions, - PULONG Disposition OPTIONAL -) + PHANDLE KeyHandle, ACCESS_MASK DesiredAccess, HANDLE RootKey, + LPCWSTR SubKey, ULONG CreateOptions, PULONG Disposition OPTIONAL) { - // Create Unicode string for subkey UNICODE_STRING unicodeSubKey; - RtlInitUnicodeString(&unicodeSubKey, SubKey); - - // Initialize object attributes OBJECT_ATTRIBUTES objectAttributes; + RtlInitUnicodeString(&unicodeSubKey, SubKey); InitializeObjectAttributes(&objectAttributes, &unicodeSubKey, OBJ_CASE_INSENSITIVE, RootKey, NULL); - - return ZwCreateKey( - KeyHandle, - DesiredAccess, - &objectAttributes, - 0, // TitleIndex, must be zero - NULL, // Class - CreateOptions, - Disposition - ); + return ZwCreateKey(KeyHandle, DesiredAccess, &objectAttributes, 0, NULL, CreateOptions, Disposition); } -/** - * @brief Open an existing registry key (simplified wrapper) - * - * @param KeyHandle Pointer to receive the handle to the key - * @param DesiredAccess The access mask for the key - * @param RootKey Root registry key (e.g., HKEY_CURRENT_USER) - * @param SubKey Path to the subkey - * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code - */ inline NTSTATUS RegOpenKey( - PHANDLE KeyHandle, - ACCESS_MASK DesiredAccess, - HANDLE RootKey, - LPCWSTR SubKey -) + PHANDLE KeyHandle, ACCESS_MASK DesiredAccess, HANDLE RootKey, LPCWSTR SubKey) { - // Create Unicode string for subkey UNICODE_STRING unicodeSubKey; - RtlInitUnicodeString(&unicodeSubKey, SubKey); - - // Initialize object attributes OBJECT_ATTRIBUTES objectAttributes; + RtlInitUnicodeString(&unicodeSubKey, SubKey); InitializeObjectAttributes(&objectAttributes, &unicodeSubKey, OBJ_CASE_INSENSITIVE, RootKey, NULL); - - // Call standard ZwOpenKey - return ZwOpenKey( - KeyHandle, - DesiredAccess, - &objectAttributes - ); + return ZwOpenKey(KeyHandle, DesiredAccess, &objectAttributes); } -/** - * @brief Set a registry value (simplified wrapper) - * - * @param KeyHandle Handle to the registry key - * @param ValueName Name of the value to set - * @param Type Type of the value data - * @param Data The value data to set - * @param DataSize Size of the value data in bytes - * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code - */ inline NTSTATUS RegSetValueKey( - HANDLE KeyHandle, - LPCWSTR ValueName, - ULONG Type, - PVOID Data, - ULONG DataSize -) + HANDLE KeyHandle, LPCWSTR ValueName, ULONG Type, PVOID Data, ULONG DataSize) { - // Create Unicode string for value name UNICODE_STRING unicodeValueName; RtlInitUnicodeString(&unicodeValueName, ValueName); - - // Call standard ZwSetValueKey - return ZwSetValueKey( - KeyHandle, - &unicodeValueName, - 0, // TitleIndex, must be zero - Type, - Data, - DataSize - ); + return ZwSetValueKey(KeyHandle, &unicodeValueName, 0, Type, Data, DataSize); } -/** - * @brief Enumerate registry keys (simplified wrapper) - * - * @param KeyHandle Handle to the registry key - * @param Index Index of the subkey to enumerate - * @param Name Buffer to receive the name of the subkey - * @param NameSize Size of the name buffer in characters - * @return NTSTATUS STATUS_SUCCESS on success, STATUS_NO_MORE_ENTRIES when done - */ inline NTSTATUS RegEnumerateKey( - HANDLE KeyHandle, - ULONG Index, - PWSTR Name, - ULONG NameSize -) + HANDLE KeyHandle, ULONG Index, PWSTR Name, ULONG NameSize) { - BYTE buffer[2048]; // Buffer for key information + BYTE buffer[2048]; ULONG resultLength; - - // Call the standard ZwEnumerateKey with KeyBasicInformation - NTSTATUS status = ZwEnumerateKey( - KeyHandle, - Index, - KeyBasicInformation, - buffer, - sizeof(buffer), - &resultLength - ); - + NTSTATUS status = ZwEnumerateKey(KeyHandle, Index, KeyBasicInformation, buffer, sizeof(buffer), &resultLength); if (NT_SUCCESS(status)) { - // Copy the name to the output buffer PKEY_BASIC_INFORMATION keyInfo = (PKEY_BASIC_INFORMATION)buffer; ULONG copyLen = min(NameSize - 1, keyInfo->NameLength / sizeof(WCHAR)); - wcsncpy_s(Name, NameSize / sizeof(WCHAR), keyInfo->Name, copyLen); Name[copyLen] = L'\0'; } - return status; } -/** - * @brief Query a registry value (simplified wrapper) - * - * @param KeyHandle Handle to the registry key - * @param ValueName Name of the value to query - * @param Type Pointer to receive the value type - * @param Data Buffer to receive the value data - * @param DataSize Pointer to the size of the data buffer, receives actual data size - * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code - */ inline NTSTATUS RegQueryValueKey( - HANDLE KeyHandle, - LPCWSTR ValueName, - PULONG Type, - PVOID Data, - PULONG DataSize -) + HANDLE KeyHandle, LPCWSTR ValueName, PULONG Type, PVOID Data, PULONG DataSize) { - // Create Unicode string for value name UNICODE_STRING unicodeValueName; + ULONG bufferSize, resultLength; + PKEY_VALUE_FULL_INFORMATION fullInfo; + NTSTATUS status; + RtlInitUnicodeString(&unicodeValueName, ValueName); - - // Allocate buffer for full information - ULONG bufferSize = sizeof(KEY_VALUE_FULL_INFORMATION) + *DataSize; - PKEY_VALUE_FULL_INFORMATION fullInfo = (PKEY_VALUE_FULL_INFORMATION)HeapAlloc(GetProcessHeap(), 0, bufferSize); if (!fullInfo) { - return STATUS_NO_MEMORY; - } - - // Query the value - ULONG resultLength; - NTSTATUS status = ZwQueryValueKey( - KeyHandle, - &unicodeValueName, - KeyValueFullInformation, - fullInfo, - bufferSize, - &resultLength - ); - - if (NT_SUCCESS(status)) { - if (Type) { - *Type = fullInfo->Type; - } + bufferSize = sizeof(KEY_VALUE_FULL_INFORMATION) + *DataSize; + fullInfo = (PKEY_VALUE_FULL_INFORMATION)HeapAlloc(GetProcessHeap(), 0, bufferSize); + if (!fullInfo) return STATUS_NO_MEMORY; - // Make sure we have enough space for the data + status = ZwQueryValueKey(KeyHandle, &unicodeValueName, KeyValueFullInformation, fullInfo, bufferSize, &resultLength); + if (NT_SUCCESS(status)) { + if (Type) *Type = fullInfo->Type; if (*DataSize < fullInfo->DataLength) { *DataSize = fullInfo->DataLength; status = STATUS_BUFFER_TOO_SMALL; } else { - // Copy the data memcpy_s(Data, *DataSize, (PBYTE)fullInfo + fullInfo->DataOffset, fullInfo->DataLength); *DataSize = fullInfo->DataLength; } } else if (status == STATUS_BUFFER_TOO_SMALL) { *DataSize = resultLength - FIELD_OFFSET(KEY_VALUE_FULL_INFORMATION, Name); } - HeapFree(GetProcessHeap(), 0, fullInfo); return status; } -/** - * @brief Opens a registry key for the specified access - * - * @param KeyHandle Pointer to receive the handle to the key - * @param DesiredAccess The access mask for the key - * @param RootKey Root registry key (e.g., HKEY_CURRENT_USER) - * @param SubKey Path to the subkey - * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code - */ inline NTSTATUS RegOpenKeyEx( - PHANDLE KeyHandle, - ACCESS_MASK DesiredAccess, - HANDLE RootKey, - LPCWSTR SubKey -) + PHANDLE KeyHandle, ACCESS_MASK DesiredAccess, HANDLE RootKey, LPCWSTR SubKey) { - // Create Unicode string for subkey UNICODE_STRING unicodeSubKey; - RtlInitUnicodeString(&unicodeSubKey, SubKey); - - // Initialize object attributes OBJECT_ATTRIBUTES objectAttributes; + RtlInitUnicodeString(&unicodeSubKey, SubKey); InitializeObjectAttributes(&objectAttributes, &unicodeSubKey, OBJ_CASE_INSENSITIVE, RootKey, NULL); - - // Call standard ZwOpenKey - return ZwOpenKey( - KeyHandle, - DesiredAccess, - &objectAttributes - ); + return ZwOpenKey(KeyHandle, DesiredAccess, &objectAttributes); } -/** - * @brief Enumerates value entries for a registry key - * Simplified wrapper for ZwEnumerateValueKey that returns full information - * - * @param KeyHandle Handle to the registry key - * @param Index Index of the value entry to enumerate - * @param ValueInformation Buffer to receive value information - * @param Length Size of the ValueInformation buffer in bytes - * @param ResultLength Pointer to receive the size of data written - * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code - */ inline NTSTATUS RegEnumerateValueKey( - IN HANDLE KeyHandle, - IN ULONG Index, - OUT PKEY_VALUE_FULL_INFORMATION ValueInformation, - IN ULONG Length, - OUT PULONG ResultLength -) { - return ZwEnumerateValueKey( - KeyHandle, - Index, - KeyValueFullInformation, - ValueInformation, - Length, - ResultLength - ); -} - -// ZwDeleteKey: Deletes a registry key using RegDeleteKeyW. Only works for empty keys (not recursive). -/** - * @brief Deletes a registry key (empty keys only). - * This function matches the Windows kernel's function signature. - * - * @param KeyHandle Handle to the key to delete - * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code - */ -inline NTSTATUS ZwDeleteKey(HANDLE KeyHandle) { - // This deletes the key that KeyHandle refers to - // We need to close the handle after the key is deleted - LSTATUS status = RegDeleteKeyW((HKEY)KeyHandle, L""); - - if (status != ERROR_SUCCESS) { - // Try to close the handle to avoid leaks even if delete failed - RegCloseKey((HKEY)KeyHandle); - return status; - } - - // Key was deleted, now close the handle (safe because RegDeleteKeyW doesn't invalidate it) - RegCloseKey((HKEY)KeyHandle); - return STATUS_SUCCESS; -} - -/** - * @brief Extended function to delete a subkey of a parent key. - * This is a WinKernelLite extension, not part of standard Windows kernel API. - * - * @param KeyHandle Handle to the parent key - * @param SubKey Name of the subkey to delete - * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code - */ -inline NTSTATUS ZwDeleteKeyEx(HKEY KeyHandle, LPCWSTR SubKey) { - // RegDeleteKeyW returns ERROR_SUCCESS (0) on success, map to STATUS_SUCCESS (0) - LSTATUS status = RegDeleteKeyW(KeyHandle, SubKey); - return (status == ERROR_SUCCESS) ? STATUS_SUCCESS : status; + IN HANDLE KeyHandle, IN ULONG Index, + OUT PKEY_VALUE_FULL_INFORMATION ValueInformation, IN ULONG Length, OUT PULONG ResultLength) +{ + return ZwEnumerateValueKey(KeyHandle, Index, KeyValueFullInformation, ValueInformation, Length, ResultLength); } #endif // WINKERNEL_REGISTRY_H diff --git a/include/Resource.h b/include/Resource.h index 1f30b99..e3b89a8 100644 --- a/include/Resource.h +++ b/include/Resource.h @@ -64,7 +64,7 @@ typedef struct _ERESOURCE { PVOID Address; ULONG_PTR CreatorBackTraceIndex; }; KSPIN_LOCK SpinLock; - + // User-mode implementation additions CRITICAL_SECTION CriticalSection; } ERESOURCE, * PERESOURCE; @@ -89,293 +89,22 @@ extern LIST_ENTRY g_WinKernelLite_SystemResourcesList; extern CRITICAL_SECTION g_WinKernelLite_SystemResourcesLock; extern BOOLEAN g_WinKernelLite_SystemResourcesInitialized; -#ifdef __cplusplus -} -#endif - #define IsOwnedExclusive(R) ((R)->Flag & ResourceOwnedExclusive) -inline void EnsureSystemResourcesListInitialized() -{ - if (!g_WinKernelLite_SystemResourcesInitialized) { - InitializeCriticalSection(&g_WinKernelLite_SystemResourcesLock); - InitializeListHead(&g_WinKernelLite_SystemResourcesList); - g_WinKernelLite_SystemResourcesInitialized = TRUE; - } -} - -inline void CleanupGlobalResources() -{ - /* Cleanup APC disable lock if initialized */ - if (g_WinKernelLite_KernelApcDisableLockInitialized) { - DeleteCriticalSection(&g_WinKernelLite_KernelApcDisableLock); - g_WinKernelLite_KernelApcDisableLockInitialized = FALSE; - } - - /* Cleanup system resources lock if initialized */ - if (g_WinKernelLite_SystemResourcesInitialized) { - DeleteCriticalSection(&g_WinKernelLite_SystemResourcesLock); - g_WinKernelLite_SystemResourcesInitialized = FALSE; - } -} - -// Initialize a resource -inline NTSTATUS -ExInitializeResourceLite( - IN PERESOURCE Resource -) -{ - if (!Resource) - return STATUS_INVALID_PARAMETER; - - ZeroMemory(Resource, sizeof(ERESOURCE)); - - // Make sure global list is initialized - EnsureSystemResourcesListInitialized(); - - // Initialize the critical section for our user-mode implementation - InitializeCriticalSection(&Resource->CriticalSection); - - // Add to the global system resources list - EnterCriticalSection(&g_WinKernelLite_SystemResourcesLock); - InsertTailList(&g_WinKernelLite_SystemResourcesList, &Resource->SystemResourcesList); - LeaveCriticalSection(&g_WinKernelLite_SystemResourcesLock); - - return STATUS_SUCCESS; -} - -// Delete a resource -inline NTSTATUS -ExDeleteResourceLite( - IN PERESOURCE Resource -) -{ - if (!Resource) - return STATUS_INVALID_PARAMETER; - - // Remove from global system resources list - EnterCriticalSection(&g_WinKernelLite_SystemResourcesLock); - RemoveEntryList(&Resource->SystemResourcesList); - LeaveCriticalSection(&g_WinKernelLite_SystemResourcesLock); - - // Delete the critical section - DeleteCriticalSection(&Resource->CriticalSection); - - return STATUS_SUCCESS; -} - +/* Function declarations - implementations in Resource.c */ +void EnsureSystemResourcesListInitialized(void); +void CleanupGlobalResources(void); +NTSTATUS ExInitializeResourceLite(IN PERESOURCE Resource); +NTSTATUS ExDeleteResourceLite(IN PERESOURCE Resource); +BOOLEAN ExAcquireResourceExclusiveLite(IN PERESOURCE Resource, IN BOOLEAN Wait); +BOOLEAN ExAcquireResourceSharedLite(IN PERESOURCE Resource, IN BOOLEAN Wait); +VOID ExReleaseResourceLite(IN PERESOURCE Resource); +VOID KeEnterCriticalRegion(VOID); +VOID KeLeaveCriticalRegion(VOID); +LONG GetKernelApcDisableCount(void); -inline BOOLEAN -ExAcquireResourceExclusiveLite( - IN PERESOURCE Resource, - IN BOOLEAN Wait -) -{ - ERESOURCE_THREAD CurrentThread; - BOOLEAN Result = FALSE; - - if (!Resource) - return FALSE; - - // Get the current thread ID as the resource thread - CurrentThread = (ERESOURCE_THREAD)GetCurrentThreadId(); - - // Enter the critical section if Wait is TRUE, otherwise try to enter - if (Wait) { - EnterCriticalSection(&Resource->CriticalSection); - } - else if (!TryEnterCriticalSection(&Resource->CriticalSection)) { - return FALSE; - } - - // Check if the resource is already owned - if (Resource->ActiveCount != 0) { - // If owned exclusively by the current thread, allow recursive exclusive acquisition - if (IsOwnedExclusive(Resource) && - (Resource->OwnerThreads[0].OwnerThread == CurrentThread)) { - Resource->OwnerThreads[0].OwnerCount += 1; - Result = TRUE; - } - else { - // Resource is owned by another thread or shared - cannot acquire exclusive - if (Wait == FALSE) { - Result = FALSE; - } - else { - // For simplicity in this user-mode implementation: - // If we need to wait and we're here, we know we're holding the critical section, - // but the resource is owned by someone else. We'll release and retry. - LeaveCriticalSection(&Resource->CriticalSection); - Sleep(1); // Yield to other threads - return ExAcquireResourceExclusiveLite(Resource, Wait); - } - } - } - else { - // Resource is not owned, so we can take it - Resource->Flag |= ResourceOwnedExclusive; - Resource->OwnerThreads[0].OwnerThread = CurrentThread; - Resource->OwnerThreads[0].OwnerCount = 1; - Resource->ActiveCount = 1; - Result = TRUE; - } - - // Always release the critical section - LeaveCriticalSection(&Resource->CriticalSection); return Result; -} - -// Acquire a resource for shared access -inline BOOLEAN -ExAcquireResourceSharedLite( - IN PERESOURCE Resource, - IN BOOLEAN Wait -) -{ - ERESOURCE_THREAD CurrentThread; - BOOLEAN Result = FALSE; - - if (!Resource) - return FALSE; - - // Get the current thread ID as the resource thread - CurrentThread = (ERESOURCE_THREAD)GetCurrentThreadId(); - - // Enter the critical section if Wait is TRUE, otherwise try to enter - if (Wait) { - EnterCriticalSection(&Resource->CriticalSection); - } - else if (!TryEnterCriticalSection(&Resource->CriticalSection)) { - return FALSE; - } - - // Check if the resource is already owned - if (Resource->ActiveCount != 0) { - // If owned exclusively by the current thread, we CANNOT acquire it shared - // This is a key difference from exclusive acquisition - no shared acquisition - // when holding exclusive, even for the same thread - if (IsOwnedExclusive(Resource)) { - if (Wait == FALSE) { - Result = FALSE; - } - else { - // For simplicity in this user-mode implementation: - // If we need to wait and we're here, we release and retry - LeaveCriticalSection(&Resource->CriticalSection); - Sleep(1); // Yield to other threads - return ExAcquireResourceSharedLite(Resource, Wait); - } - } - // It's owned shared, so we can add ourselves as another shared owner - else { - // For simplicity, in this implementation we'll just increment ActiveCount - Resource->ActiveCount++; - Result = TRUE; - } - } - else { - // Resource is not owned, so we can take it shared - Resource->ActiveCount = 1; - Result = TRUE; - } - - // Always leave the critical section when acquiring shared - LeaveCriticalSection(&Resource->CriticalSection); - - return Result; -} - -// Release a resource -inline VOID -ExReleaseResourceLite( - IN PERESOURCE Resource -) -{ - ERESOURCE_THREAD CurrentThread; - - if (!Resource) - return; - - // Get the current thread ID as the resource thread - CurrentThread = (ERESOURCE_THREAD)GetCurrentThreadId(); - - // Enter the critical section to safely modify the resource - EnterCriticalSection(&Resource->CriticalSection); - - if (IsOwnedExclusive(Resource)) { - // If owned exclusively, verify it's our thread - if (Resource->OwnerThreads[0].OwnerThread == CurrentThread) { - // Decrement the count, and if it reaches 0, release ownership - Resource->OwnerThreads[0].OwnerCount -= 1; - if (Resource->OwnerThreads[0].OwnerCount == 0) { - Resource->Flag &= ~ResourceOwnedExclusive; - Resource->OwnerThreads[0].OwnerThread = 0; - Resource->ActiveCount = 0; - } - } - } - else { - // For a shared resource, just decrement the active count - if (Resource->ActiveCount > 0) { - Resource->ActiveCount -= 1; - } - } - - // Release the critical section - LeaveCriticalSection(&Resource->CriticalSection); -} - -// Enter a critical region (disable APCs) -inline VOID -KeEnterCriticalRegion( - VOID -) -{ - /* Initialize the critical section if needed */ - if (!g_WinKernelLite_KernelApcDisableLockInitialized) { - InitializeCriticalSection(&g_WinKernelLite_KernelApcDisableLock); - g_WinKernelLite_KernelApcDisableLockInitialized = TRUE; - } - - /* In kernel mode, this disables normal kernel APCs */ - /* In our user-mode implementation, we'll use a critical section for thread safety */ - EnterCriticalSection(&g_WinKernelLite_KernelApcDisableLock); - InterlockedIncrement(&g_WinKernelLite_KernelApcDisableCount); - LeaveCriticalSection(&g_WinKernelLite_KernelApcDisableLock); -} - -// Leave a critical region (enable APCs) -inline VOID -KeLeaveCriticalRegion( - VOID -) -{ - /* Initialize the critical section if needed (safety check) */ - if (!g_WinKernelLite_KernelApcDisableLockInitialized) { - InitializeCriticalSection(&g_WinKernelLite_KernelApcDisableLock); - g_WinKernelLite_KernelApcDisableLockInitialized = TRUE; - } /* Re-enables normal kernel APCs */ - /* In our user-mode implementation, we'll use a critical section for thread safety */ - EnterCriticalSection(&g_WinKernelLite_KernelApcDisableLock); - InterlockedDecrement(&g_WinKernelLite_KernelApcDisableCount); - LeaveCriticalSection(&g_WinKernelLite_KernelApcDisableLock); -} - -// Get the current APC disable count (thread-safe) -inline LONG GetKernelApcDisableCount() -{ - LONG currentValue; - - /* Initialize the critical section if needed */ - if (!g_WinKernelLite_KernelApcDisableLockInitialized) { - InitializeCriticalSection(&g_WinKernelLite_KernelApcDisableLock); - g_WinKernelLite_KernelApcDisableLockInitialized = TRUE; - } - - EnterCriticalSection(&g_WinKernelLite_KernelApcDisableLock); - currentValue = g_WinKernelLite_KernelApcDisableCount; - LeaveCriticalSection(&g_WinKernelLite_KernelApcDisableLock); - - return currentValue; +#ifdef __cplusplus } +#endif #endif diff --git a/openspec/changes/refactor-headers-to-source/.openspec.yaml b/openspec/changes/refactor-headers-to-source/.openspec.yaml new file mode 100644 index 0000000..caac517 --- /dev/null +++ b/openspec/changes/refactor-headers-to-source/.openspec.yaml @@ -0,0 +1,2 @@ +schema: spec-driven +created: 2026-03-22 diff --git a/openspec/changes/refactor-headers-to-source/design.md b/openspec/changes/refactor-headers-to-source/design.md new file mode 100644 index 0000000..e68f300 --- /dev/null +++ b/openspec/changes/refactor-headers-to-source/design.md @@ -0,0 +1,59 @@ +## Context + +WinKernelLite headers contain both declarations and complex implementations inline. This architecture emerged organically but now blocks thread safety work identified in the code review. The project uses CMake with automatic source file discovery via glob patterns, so adding new `.c` files requires no build system changes. + +Current state by file: +- `KernelHeap.h`: ~600 lines, contains `GetGlobalState`, `TrackAllocation`, `UntrackAllocation`, `CheckForDoubleFree`, `_ExFreePoolWithTracking`, `PrintMemoryLeaks` - all touching `g_WinKernelLite_GlobalState` +- `Resource.h`: ~400 lines, contains full locking implementations with `CRITICAL_SECTION` usage and `Sleep(1)` spin loops +- `Registry.h`: ~850 lines, contains 6 Zw* functions with Win32 API mapping and buffer management +- `File.h`: ~230 lines, contains `ZwCreateFile` with disposition mapping and error translation +- Existing `.c` files: `KernelHeap.c` (1 line - global var), `Resource.c` (~30 lines - global vars), `Debug.c` (1 line - global var) + +## Goals / Non-Goals + +**Goals:** +- Move complex implementations (>5 lines) from headers to `.c` files +- Enable future thread safety fixes by having implementations in compilable source +- Match real WDK/kernel architecture: inline only what the WDK inlines +- All 36 existing tests pass without modification + +**Non-Goals:** +- Fixing any bugs (thread safety, API fidelity) - those are separate changes +- Changing any function signatures or behavior +- Refactoring UnicodeString.h - `RtlInitUnicodeString` and accessors are inline in the real WDK +- Refactoring LinkedList.h - all functions are inline macros in the real WDK +- Optimizing compilation or adding precompiled headers + +## Decisions + +### 1. Move implementations per-file, one header at a time + +**Decision:** Refactor each header independently in sequence: KernelHeap, Resource, File, Registry. + +**Rationale:** Each header is self-contained enough to move independently. Sequential approach lets us build and test after each move, catching issues early. KernelHeap first because it has the most global state and is the highest-priority target for thread safety. + +**Alternative considered:** Move all at once - rejected because a single broken move would be hard to bisect. + +### 2. Keep thin inline wrappers in headers for ExAllocatePoolWithTag / ExFreePoolWithTag + +**Decision:** These two functions stay inline in `KernelHeap.h` but call non-inline helpers in `KernelHeap.c`. + +**Rationale:** These are the primary API surface that consumers call directly. Keeping them inline preserves zero-overhead for the common case (the allocation itself is just `HeapAlloc`). The tracking/bookkeeping logic that touches global state moves to `.c`. This mirrors how the kernel separates the fast path (inline in wdm.h) from the pool manager internals (in ntoskrnl.exe). + +### 3. Function prototypes use same signatures, no new types + +**Decision:** The moved functions keep identical signatures. Headers gain `extern` function declarations where `__forceinline` definitions were. + +**Rationale:** Zero consumer impact. Tests and downstream projects like WinKernelCommLib compile without changes. + +### 4. Internal helpers become non-static in .c files + +**Decision:** Functions like `GetGlobalState`, `TrackAllocation`, `CheckForDoubleFree` become regular (non-static, non-inline) functions declared in the header and defined in the `.c` file. They retain their existing linkage (C linkage via `extern "C"` blocks). + +**Rationale:** These functions are called from inline wrappers in the header, so they cannot be `static`. They need external linkage. + +## Risks / Trade-offs + +- **[Risk] Subtle behavior change from losing inlining** -> Mitigation: These are wrapper functions calling Win32 APIs; the Win32 call dominates execution time. Inline vs non-inline overhead is negligible. All tests verify behavior is unchanged. +- **[Risk] Include order dependencies** -> Mitigation: Headers already include their dependencies (``, other project headers). Moving implementations doesn't change include requirements. +- **[Risk] Linker errors from duplicate definitions** -> Mitigation: Remove `__forceinline`/`inline` from moved functions. Verify single-definition rule with build test. diff --git a/openspec/changes/refactor-headers-to-source/proposal.md b/openspec/changes/refactor-headers-to-source/proposal.md new file mode 100644 index 0000000..2a5d0ab --- /dev/null +++ b/openspec/changes/refactor-headers-to-source/proposal.md @@ -0,0 +1,53 @@ +## Why + +WinKernelLite implements complex kernel API wrappers (50-100+ lines each) as `__forceinline` functions in header files. This prevents adding thread synchronization to global state, blocks debugger breakpoints on key functions, duplicates compiled code across every translation unit, and contradicts the real kernel's architecture where these functions live in `ntoskrnl.exe`/`ntdll.dll`, not in WDK headers. A code review identified thread safety as the top-priority fix, and that fix is blocked until implementations move out of headers. + +## What Changes + +- Move complex function implementations from 4 header files into corresponding `.c` source files +- Headers retain: type definitions, struct declarations, function prototypes, and simple inline wrappers (1-5 lines) that match real WDK inline patterns +- Source files receive: implementations >5 lines, functions touching global state, functions needing synchronization +- Create new source files: `src/Registry.c`, `src/File.c` +- Extend existing source files: `src/Resource.c`, `src/KernelHeap.c` +- No API signature changes - all existing function prototypes remain identical +- No behavioral changes - this is a pure structural refactor + +### Functions to move (by file): + +**Registry.h -> src/Registry.c (new):** +- `ZwOpenKey`, `ZwQueryValueKey`, `ZwSetValueKey`, `ZwDeleteKey`, `ZwEnumerateKey`, `ZwEnumerateValueKey` + +**File.h -> src/File.c (new):** +- `ZwCreateFile` + +**Resource.h -> src/Resource.c (existing):** +- `ExInitializeResourceLite`, `ExAcquireResourceExclusiveLite`, `ExAcquireResourceSharedLite`, `ExReleaseResourceLite`, `ExDeleteResourceLite`, `KeEnterCriticalRegion`, `KeLeaveCriticalRegion` + +**KernelHeap.h -> src/KernelHeap.c (existing):** +- `GetGlobalState`, `TrackAllocation`, `UntrackAllocation`, `CheckForDoubleFree`, `_ExFreePoolWithTracking`, `PrintMemoryLeaks` +- Keep `ExAllocatePoolWithTag`/`ExFreePoolWithTag` as thin inline wrappers calling the moved implementations + +### Functions to keep inline (match real WDK): +- `LinkedList.h` - all functions (WDK inlines these) +- `UnicodeString.h` - `RtlInitUnicodeString` and simple accessors (WDK inlines these) +- `KernelPerf.h` - `KeQueryPerformanceCounter` (simple delegation) + +## Capabilities + +### New Capabilities +- `header-source-separation`: Rules governing which functions belong in headers vs source files, based on real WDK/kernel architecture + +### Modified Capabilities + + +## Impact + +- **Headers**: Registry.h, File.h, Resource.h, KernelHeap.h - reduced to declarations + simple inlines +- **Source files**: 2 new (Registry.c, File.c), 2 extended (Resource.c, KernelHeap.c) +- **Build system**: CMakeLists.txt auto-discovers .c files via glob, so new source files are picked up automatically +- **Tests**: No changes required - all function signatures and behavior remain identical +- **Consumers**: No changes required - `#include` still provides the same API surface + +## Version Impact + +**PATCH** - Pure internal refactor. No public API changes, no behavioral changes, no new functions, no removed functions. All existing consumer code compiles and works identically. diff --git a/openspec/changes/refactor-headers-to-source/specs/header-source-separation/spec.md b/openspec/changes/refactor-headers-to-source/specs/header-source-separation/spec.md new file mode 100644 index 0000000..02867c2 --- /dev/null +++ b/openspec/changes/refactor-headers-to-source/specs/header-source-separation/spec.md @@ -0,0 +1,42 @@ +## ADDED Requirements + +### Requirement: Complex implementations SHALL reside in source files +Functions with implementations exceeding 5 lines, or functions that access global state, MUST be defined in `.c` source files. Headers SHALL contain only the function prototype (declaration). + +#### Scenario: Function accessing global state +- **WHEN** a function reads or writes any `g_WinKernelLite_*` global variable +- **THEN** its implementation MUST be in a `.c` file, not inline in a header + +#### Scenario: Function exceeding 5 lines +- **WHEN** a function body exceeds 5 lines of logic (excluding braces and blank lines) +- **THEN** its implementation MUST be in a `.c` file unless the real WDK declares the equivalent function inline + +### Requirement: Simple wrappers matching WDK inlines SHALL remain in headers +Functions that are declared inline or as macros in the real Windows Driver Kit headers (wdm.h, ntddk.h) SHALL remain as `__forceinline` in WinKernelLite headers. + +#### Scenario: Linked list operations +- **WHEN** implementing `InitializeListHead`, `InsertTailList`, `InsertHeadList`, `RemoveEntryList`, `IsListEmpty` +- **THEN** these SHALL remain `__forceinline` in `LinkedList.h` because the real WDK declares them inline + +#### Scenario: Simple kernel API delegation +- **WHEN** a function is a 1-5 line delegation to a Win32 API (e.g., `KeQueryPerformanceCounter` calling `QueryPerformanceCounter`) +- **THEN** it SHALL remain `__forceinline` in its header + +### Requirement: Refactored functions SHALL preserve identical signatures +All functions moved from headers to source files MUST retain their exact original function signature. No parameter types, return types, or calling conventions SHALL change. + +#### Scenario: Consumer code compilation +- **WHEN** a downstream project (e.g., WinKernelCommLib) includes WinKernelLite headers after the refactor +- **THEN** the project SHALL compile and link without any source code changes + +#### Scenario: Test suite passes unchanged +- **WHEN** the existing test suite is executed after the refactor +- **THEN** all tests SHALL pass without modification to any test file + +### Requirement: Header files SHALL declare moved functions with extern linkage +Functions moved to `.c` files SHALL have their prototypes declared in the corresponding header file within an `extern "C"` block for C++ compatibility. + +#### Scenario: C++ compilation +- **WHEN** a `.cpp` file includes a refactored header +- **THEN** the function prototypes SHALL be enclosed in `extern "C" { }` guards +- **THEN** linking SHALL succeed without name mangling issues diff --git a/openspec/changes/refactor-headers-to-source/tasks.md b/openspec/changes/refactor-headers-to-source/tasks.md new file mode 100644 index 0000000..c57654b --- /dev/null +++ b/openspec/changes/refactor-headers-to-source/tasks.md @@ -0,0 +1,19 @@ +## 1. KernelHeap refactor + +- [x] 1.1 Move `GetGlobalState`, `TrackAllocation`, `UntrackAllocation`, `CheckForDoubleFree`, `_ExFreePoolWithTracking`, `PrintMemoryLeaks` implementations from `KernelHeap.h` to `src/KernelHeap.c`. Replace inline definitions in header with function prototypes inside the existing `extern "C"` block. Keep `ExAllocatePoolWithTag` and `ExFreePoolWithTag` as thin inline wrappers calling the moved functions. +- [x] 1.2 Build and run all tests to verify KernelHeap refactor + +## 2. Resource refactor + +- [x] 2.1 Move `EnsureSystemResourcesListInitialized`, `ExInitializeResourceLite`, `ExAcquireResourceExclusiveLite`, `ExAcquireResourceSharedLite`, `ExReleaseResourceLite`, `ExDeleteResourceLite`, `GetKernelApcDisableCount`, `KeEnterCriticalRegion`, `KeLeaveCriticalRegion` implementations from `Resource.h` to `src/Resource.c`. Replace inline definitions in header with function prototypes. Keep simple macro `IsOwnedExclusive` inline. +- [x] 2.2 Build and run all tests to verify Resource refactor + +## 3. File refactor + +- [x] 3.1 Move `ZwCreateFile` implementation from `File.h` to new `src/File.c`. Replace inline definition in header with function prototype. +- [x] 3.2 Build and run all tests to verify File refactor + +## 4. Registry refactor + +- [x] 4.1 Move `ZwOpenKey`, `ZwQueryValueKey`, `ZwSetValueKey`, `ZwDeleteKey`, `ZwEnumerateKey`, `ZwEnumerateValueKey` implementations from `Registry.h` to new `src/Registry.c`. Replace inline definitions in header with function prototypes. +- [x] 4.2 Build and run all tests to verify Registry refactor diff --git a/src/File.c b/src/File.c new file mode 100644 index 0000000..6f80793 --- /dev/null +++ b/src/File.c @@ -0,0 +1,93 @@ +/* + * Copyright 2025 WinKernelLite Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/File.h" + +NTSTATUS ZwCreateFile( + OUT PHANDLE FileHandle, + IN ACCESS_MASK DesiredAccess, + IN POBJECT_ATTRIBUTES ObjectAttributes, + OUT PIO_STATUS_BLOCK IoStatusBlock, + IN PLARGE_INTEGER AllocationSize OPTIONAL, + IN ULONG FileAttributes, + IN ULONG ShareAccess, + IN ULONG CreateDisposition, + IN ULONG CreateOptions, + IN PVOID EaBuffer OPTIONAL, + IN ULONG EaLength +) +{ + DWORD dwCreationDisposition; + + /* Validate parameters */ + if (!FileHandle || !ObjectAttributes || !ObjectAttributes->ObjectName || !IoStatusBlock) { + return STATUS_INVALID_PARAMETER; + } + + UNREFERENCED_PARAMETER(AllocationSize); + UNREFERENCED_PARAMETER(EaBuffer); + UNREFERENCED_PARAMETER(EaLength); + + /* Convert NT CreateDisposition to Win32 CreateDisposition */ + switch (CreateDisposition) { + case FILE_SUPERSEDE: + case FILE_OVERWRITE_IF: + dwCreationDisposition = CREATE_ALWAYS; + break; + case FILE_CREATE: + dwCreationDisposition = CREATE_NEW; + break; + case FILE_OPEN: + dwCreationDisposition = OPEN_EXISTING; + break; + case FILE_OPEN_IF: + dwCreationDisposition = OPEN_ALWAYS; + break; + case FILE_OVERWRITE: + dwCreationDisposition = TRUNCATE_EXISTING; + break; + default: + return STATUS_INVALID_PARAMETER; + } + + /* Create file with CreateFileW */ + *FileHandle = CreateFileW( + ObjectAttributes->ObjectName->Buffer, + DesiredAccess, + ShareAccess, + NULL, /* Security attributes not supported */ + dwCreationDisposition, + FileAttributes | (CreateOptions & 0x00FFFFFF), /* Convert relevant options */ + NULL /* Template file not supported */ + ); + + if (*FileHandle == INVALID_HANDLE_VALUE) { + DWORD error = GetLastError(); + IoStatusBlock->Status = STATUS_UNSUCCESSFUL; + + if (error == ERROR_FILE_NOT_FOUND) { + return STATUS_OBJECT_NAME_NOT_FOUND; + } else if (error == ERROR_ACCESS_DENIED) { + return STATUS_ACCESS_DENIED; + } else { + return STATUS_UNSUCCESSFUL; + } + } + + IoStatusBlock->Status = STATUS_SUCCESS; + IoStatusBlock->Information = FILE_OPENED; /* Simplified */ + return STATUS_SUCCESS; +} diff --git a/src/KernelHeap.c b/src/KernelHeap.c index 1cef72b..7e70635 100644 --- a/src/KernelHeap.c +++ b/src/KernelHeap.c @@ -19,3 +19,536 @@ /* Global state variable definition - single instance across all translation units */ GLOBAL_STATE* g_WinKernelLite_GlobalState = NULL; + +GLOBAL_STATE* GetGlobalState(void) { + if (g_WinKernelLite_GlobalState == NULL) { + HEAP_TRACE("GetGlobalState: Creating new global state"); + GLOBAL_STATE* temp = (GLOBAL_STATE*)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(GLOBAL_STATE)); + if (temp != NULL) { + HEAP_VERBOSE("GetGlobalState: Allocated global state at %p (size: %zu)", temp, sizeof(GLOBAL_STATE)); + /* CriticalSection initialization removed for performance */ + temp->HeapHandle = GetProcessHeap(); + HEAP_VERBOSE("GetGlobalState: Using heap handle %p", temp->HeapHandle); + /* Initialize the linked lists immediately when creating global state */ + InitializeListHead(&temp->MemoryAllocations); + InitializeListHead(&temp->FreedMemoryList); + /* Initialize other fields to safe defaults */ + temp->AllocationCount = 0; + temp->TotalBytesAllocated = 0; + temp->CurrentBytesAllocated = 0; + temp->PeakBytesAllocated = 0; + temp->DoubleFreeCount = 0; + temp->FreedEntryCount = 0; + temp->MaxFreedEntries = 1000; /* Default: keep track of last 1000 freed allocations */ + temp->NextAllocationId = 1; /* Start allocation IDs at 1 */ + temp->SuppressErrors = FALSE; + temp->TrackFreedMemory = TRUE; /* Enable double-free tracking by default */ + g_WinKernelLite_GlobalState = temp; + HEAP_INFO("GetGlobalState: Global state initialized successfully"); + } else { + HEAP_ERROR("GetGlobalState: Failed to allocate global state"); + } + } + return g_WinKernelLite_GlobalState; +} + +BOOL InitHeap(void) { + GLOBAL_STATE* state = GetGlobalState(); + if (!state || !state->HeapHandle) { + return FALSE; + } + + /* Only reset if we have no active allocations */ + if (IsListEmpty(&state->MemoryAllocations)) { + state->AllocationCount = 0; + state->TotalBytesAllocated = 0; + state->CurrentBytesAllocated = 0; + state->PeakBytesAllocated = 0; + state->DoubleFreeCount = 0; + state->SuppressErrors = FALSE; + state->TrackFreedMemory = TRUE; + if (state->MaxFreedEntries == 0) { + state->MaxFreedEntries = 1000; + } + } + + return TRUE; +} + +void CleanupHeap(void) { + GLOBAL_STATE* state = GetGlobalState(); + if (state) { + PLIST_ENTRY current, next; + PMEMORY_TRACKING_ENTRY entry; + PFREED_MEMORY_ENTRY freedEntry; + + /* Clean up allocation tracking entries */ + current = state->MemoryAllocations.Flink; + while (current != &state->MemoryAllocations) { + next = current->Flink; + entry = CONTAINING_RECORD(current, MEMORY_TRACKING_ENTRY, ListEntry); + HeapFree(GetProcessHeap(), 0, entry); + current = next; + } + + /* Clean up freed memory tracking entries */ + current = state->FreedMemoryList.Flink; + while (current != &state->FreedMemoryList) { + next = current->Flink; + freedEntry = CONTAINING_RECORD(current, FREED_MEMORY_ENTRY, ListEntry); + HeapFree(GetProcessHeap(), 0, freedEntry); + current = next; + } + + HeapFree(GetProcessHeap(), 0, state); + g_WinKernelLite_GlobalState = NULL; + } +} + +void TrackAllocation(PVOID Address, SIZE_T Size, const char* FileName, int LineNumber) { + GLOBAL_STATE* state; + PMEMORY_TRACKING_ENTRY entry; + + state = GetGlobalState(); + if (!state) { + return; + } + + entry = (PMEMORY_TRACKING_ENTRY)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(MEMORY_TRACKING_ENTRY)); + if (!entry) { + if (!state->SuppressErrors) { + HEAP_ERROR("Failed to allocate memory tracking entry"); + } + return; + } + + entry->Address = Address; + entry->Size = Size; + entry->FileName = FileName; + entry->LineNumber = LineNumber; + entry->AllocationId = state->NextAllocationId++; + + InsertHeadList(&state->MemoryAllocations, &entry->ListEntry); + + state->AllocationCount++; + state->TotalBytesAllocated += Size; + state->CurrentBytesAllocated += Size; + if (state->CurrentBytesAllocated > state->PeakBytesAllocated) + state->PeakBytesAllocated = state->CurrentBytesAllocated; +} + +void TrackFreedMemoryLocked(PVOID Address, SIZE_T Size, const char* AllocFileName, int AllocLineNumber, const char* FreeFileName, int FreeLineNumber, ULONGLONG AllocationId) { + GLOBAL_STATE* state = GetGlobalState(); + PFREED_MEMORY_ENTRY entry; + + if (!state || !state->TrackFreedMemory) { + return; + } + + entry = (PFREED_MEMORY_ENTRY)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(FREED_MEMORY_ENTRY)); + + if (!entry) { + if (!state->SuppressErrors) { + HEAP_ERROR("Failed to allocate freed memory tracking entry"); + } + return; + } + + entry->Address = Address; + entry->Size = Size; + entry->AllocFileName = AllocFileName; + entry->AllocLineNumber = AllocLineNumber; + entry->FreeFileName = FreeFileName; + entry->FreeLineNumber = FreeLineNumber; + entry->ThreadId = GetCurrentThreadId(); + entry->AllocationId = AllocationId; + GetSystemTimeAsFileTime(&entry->FreeTime); + + InsertHeadList(&state->FreedMemoryList, &entry->ListEntry); + state->FreedEntryCount++; + + if (state->FreedEntryCount > state->MaxFreedEntries) { + CleanupOldFreedEntries(); + } +} + +BOOL CheckForDoubleFree(PVOID Address, const char* FreeFileName, int FreeLineNumber, ULONGLONG AllocationId) { + GLOBAL_STATE* state; + PLIST_ENTRY current; + PFREED_MEMORY_ENTRY entry; + BOOL found = FALSE; + + if (!Address) return FALSE; + + state = GetGlobalState(); + if (!state || !state->TrackFreedMemory) return FALSE; + + current = state->FreedMemoryList.Flink; + while (current != &state->FreedMemoryList) { + entry = CONTAINING_RECORD(current, FREED_MEMORY_ENTRY, ListEntry); + + if (entry->Address == Address && entry->AllocationId == AllocationId) { + found = TRUE; + state->DoubleFreeCount++; + + if (!state->SuppressErrors) { + HEAP_ERROR("=== DOUBLE-FREE DETECTED ==="); + HEAP_ERROR("Address: %p (Size: %zu bytes, Allocation ID: %llu)", Address, entry->Size, AllocationId); + HEAP_ERROR("Originally allocated at: %s:%d", entry->AllocFileName, entry->AllocLineNumber); + HEAP_ERROR("First freed at: %s:%d (Thread: %lu)", + entry->FreeFileName, entry->FreeLineNumber, entry->ThreadId); + HEAP_ERROR("Attempted second free at: %s:%d (Thread: %lu)", + FreeFileName, FreeLineNumber, GetCurrentThreadId()); + HEAP_ERROR("============================"); + } + break; + } + + current = current->Flink; + } + + return found; +} + +void CleanupOldFreedEntries(void) { + GLOBAL_STATE* state = GetGlobalState(); + if (!state) return; + + while (state->FreedEntryCount > state->MaxFreedEntries && !IsListEmpty(&state->FreedMemoryList)) { + PLIST_ENTRY lastEntry = state->FreedMemoryList.Blink; + PFREED_MEMORY_ENTRY entry = CONTAINING_RECORD(lastEntry, FREED_MEMORY_ENTRY, ListEntry); + + RemoveEntryList(lastEntry); + HeapFree(GetProcessHeap(), 0, entry); + state->FreedEntryCount--; + } +} + +BOOL UntrackAllocation(PVOID Address, const char* FreeFileName, int FreeLineNumber) { + GLOBAL_STATE* state; + PLIST_ENTRY current; + PMEMORY_TRACKING_ENTRY entry; + BOOL found = FALSE; + const char* allocFileName = "Unknown"; + int allocLineNumber = 0; + SIZE_T allocSize = 0; + ULONGLONG allocationId = 0; + + if (!Address) { + return FALSE; + } + + state = GetGlobalState(); + if (!state) { + return FALSE; + } + + current = state->MemoryAllocations.Flink; + + while (current != &state->MemoryAllocations) { + entry = CONTAINING_RECORD(current, MEMORY_TRACKING_ENTRY, ListEntry); + + if (entry->Address == Address) { + allocFileName = entry->FileName; + allocLineNumber = entry->LineNumber; + allocSize = entry->Size; + allocationId = entry->AllocationId; + + RemoveEntryList(&entry->ListEntry); + state->CurrentBytesAllocated -= entry->Size; + HeapFree(GetProcessHeap(), 0, entry); + + found = TRUE; + break; + } + + current = current->Flink; + } + + if (found && state->TrackFreedMemory) { + TrackFreedMemoryLocked(Address, allocSize, allocFileName, allocLineNumber, FreeFileName, FreeLineNumber, allocationId); + } + + return found; +} + +PVOID ExAllocatePoolWithTracking(POOL_TYPE PoolType, SIZE_T NumberOfBytes, const char* FileName, int LineNumber) { + GLOBAL_STATE* state; + PVOID ptr; + + UNREFERENCED_PARAMETER(PoolType); + + state = GetGlobalState(); + if (!state) { + return NULL; + } + + ptr = HeapAlloc(state->HeapHandle, 0, NumberOfBytes); + + if (ptr == NULL) { + if (!state->SuppressErrors) { + HEAP_ERROR("Memory allocation failed for %zu bytes", NumberOfBytes); + } + return NULL; + } + + TrackAllocation(ptr, NumberOfBytes, FileName, LineNumber); + + return ptr; +} + +void _ExFreePoolWithTracking(PVOID pointer, const char* FileName, int LineNumber) { + GLOBAL_STATE* state; + BOOL found; + BOOL isDoubleFree = FALSE; + ULONGLONG currentAllocationId = 0; + + HEAP_TRACE("_ExFreePoolWithTracking: Freeing %p from %s:%d", pointer, + FileName ? FileName : "Unknown", LineNumber); + + if (!pointer) { + HEAP_TRACE("_ExFreePoolWithTracking: NULL pointer provided, returning"); + return; + } + + state = GetGlobalState(); + if (!state) { + HEAP_ERROR("_ExFreePoolWithTracking: Failed to get global state"); + return; + } + + HEAP_VERBOSE("_ExFreePoolWithTracking: State info - heap handle: %p, TrackFreedMemory: %d", + state->HeapHandle, state->TrackFreedMemory); + + if (state->TrackFreedMemory) { + PLIST_ENTRY current = state->MemoryAllocations.Flink; + while (current != &state->MemoryAllocations) { + PMEMORY_TRACKING_ENTRY entry = CONTAINING_RECORD(current, MEMORY_TRACKING_ENTRY, ListEntry); + if (entry->Address == pointer) { + currentAllocationId = entry->AllocationId; + break; + } + current = current->Flink; + } + + if (currentAllocationId != 0) { + HEAP_TRACE("_ExFreePoolWithTracking: Checking for double-free of %p (ID: %llu)", pointer, currentAllocationId); + isDoubleFree = CheckForDoubleFree(pointer, FileName, LineNumber, currentAllocationId); + if (isDoubleFree) { + HEAP_WARN("_ExFreePoolWithTracking: Double-free detected for %p, returning without freeing", pointer); + return; + } + HEAP_TRACE("_ExFreePoolWithTracking: No double-free detected for %p", pointer); + } + } + + HEAP_TRACE("_ExFreePoolWithTracking: Calling UntrackAllocation for %p", pointer); + found = UntrackAllocation(pointer, FileName, LineNumber); + + HEAP_VERBOSE("_ExFreePoolWithTracking: UntrackAllocation returned %d for %p", found, pointer); + + if (!found && !state->SuppressErrors) { + HEAP_WARN("_ExFreePoolWithTracking: Address %p not found in tracking, validating heap pointer", pointer); + BOOL isValidHeapPtr = FALSE; + __try { + isValidHeapPtr = HeapValidate(state->HeapHandle, 0, pointer); + HEAP_VERBOSE("_ExFreePoolWithTracking: HeapValidate returned %d for %p", isValidHeapPtr, pointer); + } + __except (EXCEPTION_EXECUTE_HANDLER) { + HEAP_ERROR("_ExFreePoolWithTracking: Exception during HeapValidate for %p: 0x%08X", + pointer, GetExceptionCode()); + isValidHeapPtr = FALSE; + } + + if (isValidHeapPtr) { + HEAP_WARN("Attempting to free untracked but valid heap memory at %p from %s:%d", + pointer, FileName, LineNumber); + } else { + HEAP_WARN("Attempting to free invalid memory pointer at %p from %s:%d", + pointer, FileName, LineNumber); + } + } + + if (!isDoubleFree) { + HEAP_TRACE("_ExFreePoolWithTracking: Calling HeapFree for %p", pointer); + __try { + HeapFree(state->HeapHandle, 0, pointer); + HEAP_VERBOSE("_ExFreePoolWithTracking: Successfully freed %p", pointer); + } + __except (EXCEPTION_EXECUTE_HANDLER) { + HEAP_ERROR("_ExFreePoolWithTracking: Exception occurred while freeing memory at %p from %s:%d (Exception: 0x%08X)", + pointer, FileName, LineNumber, GetExceptionCode()); + if (!state->SuppressErrors) { + HEAP_ERROR("Exception occurred while freeing memory at %p from %s:%d (Exception: 0x%08X)", + pointer, FileName, LineNumber, GetExceptionCode()); + } + } + } +} + +BOOL IsValidHeapPointer(PVOID pointer) { + GLOBAL_STATE* state = GetGlobalState(); + PLIST_ENTRY current; + + if (!state || !pointer) return FALSE; + + current = state->MemoryAllocations.Flink; + while (current != &state->MemoryAllocations) { + PMEMORY_TRACKING_ENTRY entry = CONTAINING_RECORD(current, MEMORY_TRACKING_ENTRY, ListEntry); + + if (entry->Address == pointer) { + return TRUE; + } + + current = current->Flink; + } + + return FALSE; +} + +void PrintMemoryLeaks(void) { + GLOBAL_STATE* state; + BOOL foundLeaks; + SIZE_T leakCount; + SIZE_T leakBytes; + PLIST_ENTRY current; + PMEMORY_TRACKING_ENTRY entry; + + state = GetGlobalState(); + if (!state) return; + + foundLeaks = FALSE; + leakCount = 0; + leakBytes = 0; + + HEAP_INFO("=== MEMORY LEAK REPORT ==="); + + current = state->MemoryAllocations.Flink; + while (current != &state->MemoryAllocations) { + entry = CONTAINING_RECORD(current, MEMORY_TRACKING_ENTRY, ListEntry); + + if (!foundLeaks) { + HEAP_INFO("Address | Size | Allocation Location"); + HEAP_INFO("------------- | -------- | ------------------"); + foundLeaks = TRUE; + } + + HEAP_INFO("%p | %8d | %s:%d", + entry->Address, + (int)entry->Size, + entry->FileName, + entry->LineNumber); + + leakCount++; + leakBytes += entry->Size; + + current = current->Flink; + } + + if (foundLeaks) { + HEAP_INFO("Total: %d leaks, %d bytes", (int)leakCount, (int)leakBytes); + } else { + HEAP_INFO("No memory leaks detected!"); + } + + HEAP_INFO("Memory usage statistics:"); + HEAP_INFO(" Total allocations: %d", (int)state->AllocationCount); + HEAP_INFO(" Total bytes allocated: %d", (int)state->TotalBytesAllocated); + HEAP_INFO(" Peak bytes allocated: %d", (int)state->PeakBytesAllocated); + HEAP_INFO(" Double-free attempts: %d", (int)state->DoubleFreeCount); + HEAP_INFO(" Freed entries tracked: %d", (int)state->FreedEntryCount); + HEAP_INFO("==========================="); +} + +void PrintDoubleFreeReport(void) { + GLOBAL_STATE* state; + PLIST_ENTRY current; + PFREED_MEMORY_ENTRY entry; + SIZE_T entryCount = 0; + + state = GetGlobalState(); + if (!state) return; + + HEAP_INFO("=== FREED MEMORY REPORT ==="); + HEAP_INFO("Total double-free attempts detected: %d", (int)state->DoubleFreeCount); + HEAP_INFO("Currently tracking %d freed allocations", (int)state->FreedEntryCount); + HEAP_INFO("Maximum freed entries to track: %d", (int)state->MaxFreedEntries); + + if (state->FreedEntryCount > 0) { + HEAP_INFO("Recent freed allocations:"); + HEAP_INFO("Address | Size | Alloc ID | Alloc Location | Free Location | Thread"); + HEAP_INFO("------------- | -------- | -------- | --------------- | --------------- | ------"); + + current = state->FreedMemoryList.Flink; + while (current != &state->FreedMemoryList && entryCount < 20) + { + entry = CONTAINING_RECORD(current, FREED_MEMORY_ENTRY, ListEntry); + + HEAP_INFO("%p | %8d | %8llu | %15s:%-4d | %15s:%-4d | %6lu", + entry->Address, + (int)entry->Size, + entry->AllocationId, + entry->AllocFileName, entry->AllocLineNumber, + entry->FreeFileName, entry->FreeLineNumber, + entry->ThreadId); + + entryCount++; + current = current->Flink; + } + + if (state->FreedEntryCount > 20) { + HEAP_INFO("... and %d more entries", (int)(state->FreedEntryCount - 20)); + } + } + + HEAP_INFO("=============================="); +} + +void SetErrorSuppression(BOOL suppress) { + GLOBAL_STATE* state = GetGlobalState(); + if (state) { + state->SuppressErrors = suppress; + } +} + +BOOL GetErrorSuppression(void) { + GLOBAL_STATE* state = GetGlobalState(); + if (state) { + return state->SuppressErrors; + } + return FALSE; +} + +void SetFreedMemoryTracking(BOOL enable) { + GLOBAL_STATE* state = GetGlobalState(); + if (state) { + state->TrackFreedMemory = enable; + } +} + +BOOL GetFreedMemoryTracking(void) { + GLOBAL_STATE* state = GetGlobalState(); + if (state) { + return state->TrackFreedMemory; + } + return FALSE; +} + +void SetMaxFreedEntries(SIZE_T maxEntries) { + GLOBAL_STATE* state = GetGlobalState(); + if (state) { + state->MaxFreedEntries = maxEntries; + + if (state->FreedEntryCount > maxEntries) { + CleanupOldFreedEntries(); + } + } +} + +SIZE_T GetMaxFreedEntries(void) { + GLOBAL_STATE* state = GetGlobalState(); + if (state) { + return state->MaxFreedEntries; + } + return 0; +} diff --git a/src/Registry.c b/src/Registry.c new file mode 100644 index 0000000..2c806ed --- /dev/null +++ b/src/Registry.c @@ -0,0 +1,779 @@ +/* + * Copyright 2025 WinKernelLite Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/Registry.h" + +NTSTATUS ZwEnumerateKey( + IN HANDLE KeyHandle, + IN ULONG Index, + IN KEY_INFORMATION_CLASS KeyInformationClass, + OUT PVOID KeyInformation, + IN ULONG Length, + OUT PULONG ResultLength +) +{ + if (!KeyHandle) { + return STATUS_INVALID_PARAMETER; + } + if (!KeyInformation) { + return STATUS_INVALID_PARAMETER; + } + if (!ResultLength) { + return STATUS_INVALID_PARAMETER; + } + + WCHAR tempKeyName[1024]; + DWORD keyNameSize = sizeof(tempKeyName) / sizeof(WCHAR); + FILETIME lastWriteTime; + + LSTATUS status = RegEnumKeyExW( + (HKEY)KeyHandle, + Index, + tempKeyName, + &keyNameSize, + NULL, // Reserved + NULL, // Class name (not used) + NULL, // Class name length + &lastWriteTime // Last write time + ); + + // Convert the result to the requested information format + if (status == ERROR_SUCCESS) { + ULONG nameLength = (ULONG)(keyNameSize * sizeof(WCHAR)); + ULONG requiredSize = 0; + + switch (KeyInformationClass) { + case KeyBasicInformation: { + requiredSize = sizeof(KEY_BASIC_INFORMATION) + nameLength; + if (Length < requiredSize) { + if (ResultLength) *ResultLength = requiredSize; + return STATUS_BUFFER_TOO_SMALL; + } + + PKEY_BASIC_INFORMATION basicInfo = (PKEY_BASIC_INFORMATION)KeyInformation; + basicInfo->LastWriteTime.QuadPart = ((LARGE_INTEGER*)&lastWriteTime)->QuadPart; + basicInfo->TitleIndex = 0; + basicInfo->NameLength = nameLength; + memcpy_s(basicInfo->Name, nameLength, tempKeyName, nameLength); + + if (ResultLength) *ResultLength = requiredSize; + break; + } + + case KeyNodeInformation: { + requiredSize = sizeof(KEY_NODE_INFORMATION) + nameLength; + if (Length < requiredSize) { + if (ResultLength) *ResultLength = requiredSize; + return STATUS_BUFFER_TOO_SMALL; + } + + PKEY_NODE_INFORMATION nodeInfo = (PKEY_NODE_INFORMATION)KeyInformation; + nodeInfo->LastWriteTime.QuadPart = ((LARGE_INTEGER*)&lastWriteTime)->QuadPart; + nodeInfo->TitleIndex = 0; + nodeInfo->ClassLength = 0; + nodeInfo->ClassOffset = 0; + nodeInfo->NameLength = nameLength; + memcpy_s(nodeInfo->Name, nameLength, tempKeyName, nameLength); + + if (ResultLength) *ResultLength = requiredSize; + break; + } + + case KeyFullInformation: + default: + // Unsupported information class + return STATUS_INVALID_PARAMETER; + } + + return STATUS_SUCCESS; + } else if (status == ERROR_NO_MORE_ITEMS) { + return STATUS_NO_MORE_ENTRIES; + } else if (status == ERROR_MORE_DATA) { + return STATUS_BUFFER_TOO_SMALL; + } else { + return STATUS_INVALID_PARAMETER; + } +} + +NTSTATUS ZwClose( + IN HANDLE Handle +) +{ + // Check for obviously invalid handles upfront + if (Handle == NULL || Handle == INVALID_HANDLE_VALUE) { + return STATUS_INVALID_PARAMETER; + } + + // Check if the handle is valid using GetHandleInformation + DWORD flags = 0; + if (!GetHandleInformation(Handle, &flags)) { + return STATUS_INVALID_HANDLE; + } + + // Handle is valid, so close it + if (CloseHandle(Handle)) { + return STATUS_SUCCESS; + } else { + // This should be unreachable if GetHandleInformation succeeded + return STATUS_INVALID_HANDLE; + } +} + +/** + * @brief Helper function to parse a full registry path into root key and subkey path + * + * @param fullPath The full registry path (e.g., L"\\Registry\\Machine\\Software\\Test") + * @param rootKey Pointer to receive the root key handle + * @param subKeyPath Pointer to receive the subkey path + * @return NTSTATUS STATUS_SUCCESS on success or appropriate error code + */ +NTSTATUS ParseRegistryPath( + IN PCWSTR fullPath, + OUT HKEY* rootKey, + OUT PCWSTR* subKeyPath +) +{ + if (!fullPath || !rootKey || !subKeyPath) { + return STATUS_INVALID_PARAMETER; + } + + // Check for known registry path prefixes + if (wcsncmp(fullPath, L"\\Registry\\Machine\\", 18) == 0) { + *rootKey = HKEY_LOCAL_MACHINE; + *subKeyPath = fullPath + 18; // Skip "\\Registry\\Machine\\" + return STATUS_SUCCESS; + } else if (wcsncmp(fullPath, L"\\Registry\\User\\", 15) == 0) { + *rootKey = HKEY_USERS; + *subKeyPath = fullPath + 15; // Skip "\\Registry\\User\\" + return STATUS_SUCCESS; + } else if (wcsncmp(fullPath, L"HKEY_LOCAL_MACHINE\\", 19) == 0) { + *rootKey = HKEY_LOCAL_MACHINE; + *subKeyPath = fullPath + 19; // Skip "HKEY_LOCAL_MACHINE\\" + return STATUS_SUCCESS; + } else if (wcsncmp(fullPath, L"HKEY_CURRENT_USER\\", 18) == 0) { + *rootKey = HKEY_CURRENT_USER; + *subKeyPath = fullPath + 18; // Skip "HKEY_CURRENT_USER\\" + return STATUS_SUCCESS; + } else if (wcsncmp(fullPath, L"HKEY_USERS\\", 11) == 0) { + *rootKey = HKEY_USERS; + *subKeyPath = fullPath + 11; // Skip "HKEY_USERS\\" + return STATUS_SUCCESS; + } else if (wcsncmp(fullPath, L"HKEY_CLASSES_ROOT\\", 18) == 0) { + *rootKey = HKEY_CLASSES_ROOT; + *subKeyPath = fullPath + 18; // Skip "HKEY_CLASSES_ROOT\\" + return STATUS_SUCCESS; + } + + // Unknown registry path format + return STATUS_OBJECT_PATH_SYNTAX_BAD; +} + +NTSTATUS ZwCreateKey( + OUT PHANDLE KeyHandle, + IN ACCESS_MASK DesiredAccess, + IN POBJECT_ATTRIBUTES ObjectAttributes, + IN ULONG TitleIndex, + IN PUNICODE_STRING Class OPTIONAL, + IN ULONG CreateOptions, + OUT PULONG Disposition OPTIONAL +) +{ + // Validate parameters + UNREFERENCED_PARAMETER(TitleIndex); // Mark as unused parameter to fix C4100 warning + + if (!KeyHandle) { + return STATUS_INVALID_PARAMETER; + } + if (!ObjectAttributes) { + return STATUS_INVALID_PARAMETER; + } + if (!ObjectAttributes->ObjectName) { + return STATUS_INVALID_PARAMETER; + } + + HKEY rootKey = NULL; + PCWSTR keyPath = NULL; + DWORD dispositionValue = 0; + LSTATUS status; + NTSTATUS ntStatus; + + // Get root key and subkey path from ObjectAttributes + if (ObjectAttributes->RootDirectory) { + // Root directory provided - use it with relative path + rootKey = (HKEY)ObjectAttributes->RootDirectory; + keyPath = ObjectAttributes->ObjectName->Buffer; + } else { + // No root directory - parse full path from ObjectName + ntStatus = ParseRegistryPath( + ObjectAttributes->ObjectName->Buffer, + &rootKey, + &keyPath + ); + if (!NT_SUCCESS(ntStatus)) { + return ntStatus; + } + } + + status = RegCreateKeyExW( + rootKey, + keyPath, + 0, // Reserved + Class ? Class->Buffer : NULL, // Class + CreateOptions, // Options + DesiredAccess, // Access + NULL, // Security attributes not supported + (PHKEY)KeyHandle, + &dispositionValue + ); + + // Pass back the disposition if requested + if (Disposition) { + *Disposition = dispositionValue; + } + + // Convert Windows error to NTSTATUS + if (status == ERROR_SUCCESS) { + return STATUS_SUCCESS; + } else if (status == ERROR_ACCESS_DENIED) { + return STATUS_ACCESS_DENIED; + } else if (status == ERROR_INVALID_PARAMETER) { + return STATUS_INVALID_PARAMETER; + } else { + return STATUS_UNSUCCESSFUL; + } +} + +NTSTATUS ZwSetValueKey( + IN HANDLE KeyHandle, + IN PUNICODE_STRING ValueName, + IN ULONG TitleIndex, + IN ULONG Type, + IN PVOID Data, + IN ULONG DataSize +) +{ + // Validate parameters + UNREFERENCED_PARAMETER(TitleIndex); // Mark as unused parameter to fix C4100 warning + + if (!KeyHandle) { + return STATUS_INVALID_PARAMETER; + } + if (!ValueName) { + return STATUS_INVALID_PARAMETER; + } + if (!ValueName->Buffer) { + return STATUS_INVALID_PARAMETER; + } + if (!Data && DataSize > 0) { + return STATUS_INVALID_PARAMETER; + } + + // avoid potential problems with the strings not being null-terminated as UNICODE_STRING is not required to have null terminated buffer + // and we don't know how exactly this works in the Win kernel + ULONG nameChars = ValueName->Length / sizeof(WCHAR); + WCHAR* nullTerminatedName = (WCHAR*)HeapAlloc(GetProcessHeap(), 0, (nameChars + 1) * sizeof(WCHAR)); + if (!nullTerminatedName) { + return STATUS_NO_MEMORY; + } + + // Copy the name and add null terminator + memcpy_s(nullTerminatedName, ValueName->Length, ValueName->Buffer, ValueName->Length); + nullTerminatedName[nameChars] = L'\0'; + + LSTATUS status = RegSetValueExW( + (HKEY)KeyHandle, + nullTerminatedName, + 0, // Reserved + Type, + (CONST BYTE*)Data, + DataSize + ); + + HeapFree(GetProcessHeap(), 0, nullTerminatedName); + + if (status == ERROR_SUCCESS) { + return STATUS_SUCCESS; + } else if (status == ERROR_ACCESS_DENIED) { + return STATUS_ACCESS_DENIED; + } else if (status == ERROR_INVALID_PARAMETER) { + return STATUS_INVALID_PARAMETER; + } else { + return STATUS_UNSUCCESSFUL; + } +} + +NTSTATUS ZwOpenKey( + OUT PHANDLE KeyHandle, + IN ACCESS_MASK DesiredAccess, + IN POBJECT_ATTRIBUTES ObjectAttributes +) +{ + // Validate parameters + if (!KeyHandle) { + return STATUS_INVALID_PARAMETER; + } + if (!ObjectAttributes) { + return STATUS_INVALID_PARAMETER; + } + if (!ObjectAttributes->ObjectName) { + return STATUS_INVALID_PARAMETER; + } + if (!ObjectAttributes->ObjectName->Buffer) { + return STATUS_INVALID_PARAMETER; + } + + HKEY rootKey = NULL; + PCWSTR keyPath = NULL; + LSTATUS status; + NTSTATUS ntStatus; + + // Get root key and subkey path from ObjectAttributes + if (ObjectAttributes->RootDirectory) { + // Root directory provided - use it with relative path + rootKey = (HKEY)ObjectAttributes->RootDirectory; + keyPath = ObjectAttributes->ObjectName->Buffer; + } else { + // No root directory - parse full path from ObjectName or assume HKEY_CURRENT_USER + ntStatus = ParseRegistryPath( + ObjectAttributes->ObjectName->Buffer, + &rootKey, + &keyPath + ); + if (!NT_SUCCESS(ntStatus)) { + // If path parsing fails, assume it's a relative path under HKEY_CURRENT_USER + rootKey = HKEY_CURRENT_USER; + keyPath = ObjectAttributes->ObjectName->Buffer; + } + } + + status = RegOpenKeyExW( + rootKey, + keyPath, + 0, // Reserved + DesiredAccess, // Access + (PHKEY)KeyHandle + ); + + // Convert Windows error to NTSTATUS + if (status == ERROR_SUCCESS) { + return STATUS_SUCCESS; + } else if (status == ERROR_FILE_NOT_FOUND) { + return STATUS_OBJECT_NAME_NOT_FOUND; + } else if (status == ERROR_PATH_NOT_FOUND) { + return STATUS_OBJECT_PATH_NOT_FOUND; + } else if (status == ERROR_ACCESS_DENIED) { + return STATUS_ACCESS_DENIED; + } else { + return STATUS_UNSUCCESSFUL; + } +} + +NTSTATUS ZwQueryValueKey( + IN HANDLE KeyHandle, + IN PUNICODE_STRING ValueName, + IN KEY_VALUE_INFORMATION_CLASS KeyValueInformationClass, + OUT PVOID KeyValueInformation, + IN ULONG Length, + OUT PULONG ResultLength +) +{ + if (!KeyHandle) { + return STATUS_INVALID_PARAMETER; + } + if (!ValueName) { + return STATUS_INVALID_PARAMETER; + } + if (!ValueName->Buffer) { + return STATUS_INVALID_PARAMETER; + } + if (!KeyValueInformation) { + return STATUS_INVALID_PARAMETER; + } + if (!ResultLength) { + return STATUS_INVALID_PARAMETER; + } + + DWORD dwType; + DWORD dwDataSize = 0; + LSTATUS status; + + switch (KeyValueInformationClass) { + case KeyValuePartialInformation: + { + PKEY_VALUE_PARTIAL_INFORMATION partialInfo = (PKEY_VALUE_PARTIAL_INFORMATION)KeyValueInformation; + + // First call to get required buffer size + status = RegQueryValueExW( + (HKEY)KeyHandle, + ValueName->Buffer, + NULL, + &dwType, + NULL, + &dwDataSize + ); + + if (status != ERROR_SUCCESS && status != ERROR_MORE_DATA) { + if (status == ERROR_FILE_NOT_FOUND) { + return STATUS_OBJECT_NAME_NOT_FOUND; + } else { + return STATUS_UNSUCCESSFUL; + } + } + + // Calculate required size for the structure + ULONG requiredSize = sizeof(KEY_VALUE_PARTIAL_INFORMATION) + dwDataSize - sizeof(BYTE); + *ResultLength = requiredSize; + + if (Length < requiredSize) { + return STATUS_BUFFER_TOO_SMALL; + } + + // Setup the structure fields + partialInfo->TitleIndex = 0; + partialInfo->Type = dwType; + partialInfo->DataLength = dwDataSize; + + status = RegQueryValueExW( + (HKEY)KeyHandle, + ValueName->Buffer, + NULL, + &dwType, + (LPBYTE)partialInfo->Data, + &dwDataSize + ); + + if (status == ERROR_SUCCESS) { + return STATUS_SUCCESS; + } else if (status == ERROR_FILE_NOT_FOUND) { + return STATUS_OBJECT_NAME_NOT_FOUND; + } else if (status == ERROR_MORE_DATA) { + return STATUS_BUFFER_TOO_SMALL; + } else { + return STATUS_UNSUCCESSFUL; + } + } + + case KeyValueFullInformation: + { + PKEY_VALUE_FULL_INFORMATION fullInfo = (PKEY_VALUE_FULL_INFORMATION)KeyValueInformation; + // First call to get required buffer size + status = RegQueryValueExW( + (HKEY)KeyHandle, + ValueName->Buffer, + NULL, + &dwType, + NULL, + &dwDataSize + ); + + if (status != ERROR_SUCCESS && status != ERROR_MORE_DATA) { + if (status == ERROR_FILE_NOT_FOUND) { + return STATUS_OBJECT_NAME_NOT_FOUND; + } else if (status == ERROR_MORE_DATA) { + return STATUS_BUFFER_TOO_SMALL; + } else { + return STATUS_UNSUCCESSFUL; + } + } + + if (Length < sizeof(KEY_VALUE_FULL_INFORMATION)) { + *ResultLength = sizeof(KEY_VALUE_FULL_INFORMATION) + dwDataSize; + return STATUS_BUFFER_TOO_SMALL; + } + + // Setup the structure fields + fullInfo->TitleIndex = 0; + fullInfo->Type = dwType; + fullInfo->NameLength = ValueName->Length; + fullInfo->DataLength = dwDataSize; + fullInfo->DataOffset = FIELD_OFFSET(KEY_VALUE_FULL_INFORMATION, Name) + + ValueName->Length + sizeof(WCHAR); + + // Make sure we have enough space for name and data + if (Length < fullInfo->DataOffset + dwDataSize) { + *ResultLength = fullInfo->DataOffset + dwDataSize; + return STATUS_BUFFER_TOO_SMALL; + } + + // Copy the name + memcpy_s(fullInfo->Name, ValueName->Length + sizeof(WCHAR), ValueName->Buffer, ValueName->Length); + ((PWSTR)((PBYTE)fullInfo->Name + ValueName->Length))[0] = L'\0'; + + + status = RegQueryValueExW( + (HKEY)KeyHandle, + ValueName->Buffer, + NULL, + &dwType, + (LPBYTE)((PBYTE)KeyValueInformation + fullInfo->DataOffset), + &dwDataSize + ); + + *ResultLength = fullInfo->DataOffset + dwDataSize; + + if (status == ERROR_SUCCESS) { + return STATUS_SUCCESS; + } else if (status == ERROR_FILE_NOT_FOUND) { + return STATUS_OBJECT_NAME_NOT_FOUND; + } else if (status == ERROR_MORE_DATA) { + return STATUS_BUFFER_TOO_SMALL; + } else { + return STATUS_UNSUCCESSFUL; + } + } + + case KeyValueBasicInformation: + { + PKEY_VALUE_BASIC_INFORMATION basicInfo = (PKEY_VALUE_BASIC_INFORMATION)KeyValueInformation; + + // First call to get required buffer size and type + status = RegQueryValueExW( + (HKEY)KeyHandle, + ValueName->Buffer, + NULL, + &dwType, + NULL, + &dwDataSize + ); + + if (status != ERROR_SUCCESS && status != ERROR_MORE_DATA) { + if (status == ERROR_FILE_NOT_FOUND) { + return STATUS_OBJECT_NAME_NOT_FOUND; + } else { + return STATUS_UNSUCCESSFUL; + } + } + + // Calculate required size for the structure + ULONG requiredSize = sizeof(KEY_VALUE_BASIC_INFORMATION) + ValueName->Length; + *ResultLength = requiredSize; + + if (Length < requiredSize) { + return STATUS_BUFFER_TOO_SMALL; + } + + // Setup the structure fields + basicInfo->TitleIndex = 0; + basicInfo->Type = dwType; + basicInfo->NameLength = ValueName->Length; + + // Copy the name + memcpy_s(basicInfo->Name, ValueName->Length, ValueName->Buffer, ValueName->Length); + + return STATUS_SUCCESS; + } + + default: + return STATUS_INVALID_PARAMETER; + } +} + +NTSTATUS ZwEnumerateValueKey( + IN HANDLE KeyHandle, + IN ULONG Index, + IN KEY_VALUE_INFORMATION_CLASS KeyValueInformationClass, + OUT PVOID KeyValueInformation, + IN ULONG Length, + OUT PULONG ResultLength +) +{ + if (!KeyHandle) { + return STATUS_INVALID_PARAMETER; + } + if (!KeyValueInformation) { + return STATUS_INVALID_PARAMETER; + } + if (!ResultLength) { + return STATUS_INVALID_PARAMETER; + } + + // Status variable is not used in this function but kept for consistency with original API + // NTSTATUS status = STATUS_SUCCESS; + HKEY hKey = (HKEY)KeyHandle; + DWORD maxValueNameLen = 0, maxValueDataLen = 0; + WCHAR valueName[MAX_PATH] = { 0 }; + DWORD valueNameLen = MAX_PATH; + DWORD type = 0; + BYTE* data = NULL; + DWORD dataSize = 0; + + // Get value name and data size + LSTATUS winStatus = RegQueryInfoKeyW( + hKey, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + &maxValueNameLen, + &maxValueDataLen, + NULL, + NULL + ); + if (winStatus != ERROR_SUCCESS) { + return winStatus == ERROR_ACCESS_DENIED ? STATUS_ACCESS_DENIED : STATUS_UNSUCCESSFUL; + } + + // Get value name, type and data + valueNameLen = MAX_PATH; // Reset length before calling RegEnumValueW + winStatus = RegEnumValueW( + hKey, + Index, + valueName, + &valueNameLen, + NULL, + &type, + NULL, + &dataSize + ); + if (winStatus != ERROR_SUCCESS) { + return winStatus == ERROR_NO_MORE_ITEMS ? STATUS_NO_MORE_ENTRIES : + winStatus == ERROR_MORE_DATA ? STATUS_BUFFER_TOO_SMALL : + winStatus == ERROR_ACCESS_DENIED ? STATUS_ACCESS_DENIED : STATUS_UNSUCCESSFUL; + } + + // Fill in the information based on the requested class + switch (KeyValueInformationClass) { + case KeyValueBasicInformation: { + PKEY_VALUE_BASIC_INFORMATION pBasic = (PKEY_VALUE_BASIC_INFORMATION)KeyValueInformation; + size_t requiredSize = sizeof(KEY_VALUE_BASIC_INFORMATION) + (valueNameLen * sizeof(WCHAR)); + *ResultLength = (ULONG)requiredSize; + + if (Length < requiredSize) { + return STATUS_BUFFER_TOO_SMALL; + } + + pBasic->TitleIndex = 0; + pBasic->Type = type; + pBasic->NameLength = valueNameLen * sizeof(WCHAR); + memcpy_s(pBasic->Name, pBasic->NameLength, valueName, pBasic->NameLength); + break; + } + case KeyValueFullInformation: { + PKEY_VALUE_FULL_INFORMATION pFull = (PKEY_VALUE_FULL_INFORMATION)KeyValueInformation; + size_t requiredSize = sizeof(KEY_VALUE_FULL_INFORMATION) + (valueNameLen * sizeof(WCHAR)) + dataSize; + *ResultLength = (ULONG)requiredSize; + + if (Length < requiredSize) { + return STATUS_BUFFER_TOO_SMALL; + } + + data = (BYTE*)HeapAlloc(GetProcessHeap(), 0, dataSize); + if (!data) { + return STATUS_NO_MEMORY; + } + + DWORD actualDataSize = dataSize; + valueNameLen = MAX_PATH; // Reset length before calling RegEnumValueW + LSTATUS valueStatus = RegEnumValueW( + hKey, + Index, + valueName, + &valueNameLen, + NULL, + &type, + data, + &actualDataSize + ); + if (valueStatus != ERROR_SUCCESS) { + HeapFree(GetProcessHeap(), 0, data); + return valueStatus == ERROR_NO_MORE_ITEMS ? STATUS_NO_MORE_ENTRIES : + valueStatus == ERROR_MORE_DATA ? STATUS_BUFFER_TOO_SMALL : + valueStatus == ERROR_ACCESS_DENIED ? STATUS_ACCESS_DENIED : STATUS_UNSUCCESSFUL; + } + + pFull->TitleIndex = 0; + pFull->Type = type; + pFull->DataOffset = sizeof(KEY_VALUE_FULL_INFORMATION) + (valueNameLen * sizeof(WCHAR)); + pFull->DataLength = dataSize; + pFull->NameLength = valueNameLen * sizeof(WCHAR); + memcpy_s(pFull->Name, pFull->NameLength, valueName, pFull->NameLength); + memcpy_s((BYTE*)pFull + pFull->DataOffset, dataSize, data, dataSize); + HeapFree(GetProcessHeap(), 0, data); + break; + } + case KeyValuePartialInformation: { + PKEY_VALUE_PARTIAL_INFORMATION pPartial = (PKEY_VALUE_PARTIAL_INFORMATION)KeyValueInformation; + size_t requiredSize = sizeof(KEY_VALUE_PARTIAL_INFORMATION) + dataSize; + *ResultLength = (ULONG)requiredSize; + + if (Length < requiredSize) { + return STATUS_BUFFER_TOO_SMALL; + } + + data = (BYTE*)HeapAlloc(GetProcessHeap(), 0, dataSize); + if (!data) { + return STATUS_NO_MEMORY; + } + + DWORD actualDataSize = dataSize; + valueNameLen = MAX_PATH; // Reset length before calling RegEnumValueW + LSTATUS valueStatus = RegEnumValueW( + hKey, + Index, + valueName, + &valueNameLen, + NULL, + &type, + data, + &actualDataSize + ); + if (valueStatus != ERROR_SUCCESS) { + HeapFree(GetProcessHeap(), 0, data); + return valueStatus == ERROR_NO_MORE_ITEMS ? STATUS_NO_MORE_ENTRIES : + valueStatus == ERROR_MORE_DATA ? STATUS_BUFFER_TOO_SMALL : + valueStatus == ERROR_ACCESS_DENIED ? STATUS_ACCESS_DENIED : STATUS_UNSUCCESSFUL; + } + + pPartial->TitleIndex = 0; + pPartial->Type = type; + pPartial->DataLength = dataSize; + memcpy_s(pPartial->Data, dataSize, data, dataSize); + HeapFree(GetProcessHeap(), 0, data); + break; + } + + default: + return STATUS_INVALID_PARAMETER; + } + + return STATUS_SUCCESS; +} + +// ZwDeleteKey: Deletes a registry key using RegDeleteKeyW. Only works for empty keys (not recursive). +NTSTATUS ZwDeleteKey(HANDLE KeyHandle) { + // This deletes the key that KeyHandle refers to + // We need to close the handle after the key is deleted + LSTATUS status = RegDeleteKeyW((HKEY)KeyHandle, L""); + + if (status != ERROR_SUCCESS) { + // Try to close the handle to avoid leaks even if delete failed + RegCloseKey((HKEY)KeyHandle); + return status; + } + + // Key was deleted, now close the handle (safe because RegDeleteKeyW doesn't invalidate it) + RegCloseKey((HKEY)KeyHandle); + return STATUS_SUCCESS; +} + +NTSTATUS ZwDeleteKeyEx(HKEY KeyHandle, LPCWSTR SubKey) { + // RegDeleteKeyW returns ERROR_SUCCESS (0) on success, map to STATUS_SUCCESS (0) + LSTATUS status = RegDeleteKeyW(KeyHandle, SubKey); + return (status == ERROR_SUCCESS) ? STATUS_SUCCESS : status; +} diff --git a/src/Resource.c b/src/Resource.c index b954085..e4ce9ee 100644 --- a/src/Resource.c +++ b/src/Resource.c @@ -27,3 +27,276 @@ BOOLEAN g_WinKernelLite_KernelApcDisableLockInitialized = FALSE; LIST_ENTRY g_WinKernelLite_SystemResourcesList; CRITICAL_SECTION g_WinKernelLite_SystemResourcesLock; BOOLEAN g_WinKernelLite_SystemResourcesInitialized = FALSE; + +void EnsureSystemResourcesListInitialized(void) +{ + if (!g_WinKernelLite_SystemResourcesInitialized) { + InitializeCriticalSection(&g_WinKernelLite_SystemResourcesLock); + InitializeListHead(&g_WinKernelLite_SystemResourcesList); + g_WinKernelLite_SystemResourcesInitialized = TRUE; + } +} + +void CleanupGlobalResources(void) +{ + /* Cleanup APC disable lock if initialized */ + if (g_WinKernelLite_KernelApcDisableLockInitialized) { + DeleteCriticalSection(&g_WinKernelLite_KernelApcDisableLock); + g_WinKernelLite_KernelApcDisableLockInitialized = FALSE; + } + + /* Cleanup system resources lock if initialized */ + if (g_WinKernelLite_SystemResourcesInitialized) { + DeleteCriticalSection(&g_WinKernelLite_SystemResourcesLock); + g_WinKernelLite_SystemResourcesInitialized = FALSE; + } +} + +NTSTATUS +ExInitializeResourceLite( + IN PERESOURCE Resource +) +{ + if (!Resource) + return STATUS_INVALID_PARAMETER; + + ZeroMemory(Resource, sizeof(ERESOURCE)); + + // Make sure global list is initialized + EnsureSystemResourcesListInitialized(); + + // Initialize the critical section for our user-mode implementation + InitializeCriticalSection(&Resource->CriticalSection); + + // Add to the global system resources list + EnterCriticalSection(&g_WinKernelLite_SystemResourcesLock); + InsertTailList(&g_WinKernelLite_SystemResourcesList, &Resource->SystemResourcesList); + LeaveCriticalSection(&g_WinKernelLite_SystemResourcesLock); + + return STATUS_SUCCESS; +} + +NTSTATUS +ExDeleteResourceLite( + IN PERESOURCE Resource +) +{ + if (!Resource) + return STATUS_INVALID_PARAMETER; + + // Remove from global system resources list + EnterCriticalSection(&g_WinKernelLite_SystemResourcesLock); + RemoveEntryList(&Resource->SystemResourcesList); + LeaveCriticalSection(&g_WinKernelLite_SystemResourcesLock); + + // Delete the critical section + DeleteCriticalSection(&Resource->CriticalSection); + + return STATUS_SUCCESS; +} + +BOOLEAN +ExAcquireResourceExclusiveLite( + IN PERESOURCE Resource, + IN BOOLEAN Wait +) +{ + ERESOURCE_THREAD CurrentThread; + BOOLEAN Result = FALSE; + + if (!Resource) + return FALSE; + + // Get the current thread ID as the resource thread + CurrentThread = (ERESOURCE_THREAD)GetCurrentThreadId(); + + // Enter the critical section if Wait is TRUE, otherwise try to enter + if (Wait) { + EnterCriticalSection(&Resource->CriticalSection); + } + else if (!TryEnterCriticalSection(&Resource->CriticalSection)) { + return FALSE; + } + + // Check if the resource is already owned + if (Resource->ActiveCount != 0) { + // If owned exclusively by the current thread, allow recursive exclusive acquisition + if (IsOwnedExclusive(Resource) && + (Resource->OwnerThreads[0].OwnerThread == CurrentThread)) { + Resource->OwnerThreads[0].OwnerCount += 1; + Result = TRUE; + } + else { + // Resource is owned by another thread or shared - cannot acquire exclusive + if (Wait == FALSE) { + Result = FALSE; + } + else { + // For simplicity in this user-mode implementation: + // If we need to wait and we're here, we know we're holding the critical section, + // but the resource is owned by someone else. We'll release and retry. + LeaveCriticalSection(&Resource->CriticalSection); + Sleep(1); // Yield to other threads + return ExAcquireResourceExclusiveLite(Resource, Wait); + } + } + } + else { + // Resource is not owned, so we can take it + Resource->Flag |= ResourceOwnedExclusive; + Resource->OwnerThreads[0].OwnerThread = CurrentThread; + Resource->OwnerThreads[0].OwnerCount = 1; + Resource->ActiveCount = 1; + Result = TRUE; + } + + // Always release the critical section + LeaveCriticalSection(&Resource->CriticalSection); + return Result; +} + +BOOLEAN +ExAcquireResourceSharedLite( + IN PERESOURCE Resource, + IN BOOLEAN Wait +) +{ + ERESOURCE_THREAD CurrentThread; + BOOLEAN Result = FALSE; + + if (!Resource) + return FALSE; + + // Get the current thread ID as the resource thread + CurrentThread = (ERESOURCE_THREAD)GetCurrentThreadId(); + + // Enter the critical section if Wait is TRUE, otherwise try to enter + if (Wait) { + EnterCriticalSection(&Resource->CriticalSection); + } + else if (!TryEnterCriticalSection(&Resource->CriticalSection)) { + return FALSE; + } + + // Check if the resource is already owned + if (Resource->ActiveCount != 0) { + if (IsOwnedExclusive(Resource)) { + if (Wait == FALSE) { + Result = FALSE; + } + else { + // For simplicity in this user-mode implementation: + // If we need to wait and we're here, we release and retry + LeaveCriticalSection(&Resource->CriticalSection); + Sleep(1); // Yield to other threads + return ExAcquireResourceSharedLite(Resource, Wait); + } + } + // It's owned shared, so we can add ourselves as another shared owner + else { + Resource->ActiveCount++; + Result = TRUE; + } + } + else { + // Resource is not owned, so we can take it shared + Resource->ActiveCount = 1; + Result = TRUE; + } + + // Always leave the critical section when acquiring shared + LeaveCriticalSection(&Resource->CriticalSection); + + return Result; +} + +VOID +ExReleaseResourceLite( + IN PERESOURCE Resource +) +{ + ERESOURCE_THREAD CurrentThread; + + if (!Resource) + return; + + // Get the current thread ID as the resource thread + CurrentThread = (ERESOURCE_THREAD)GetCurrentThreadId(); + + // Enter the critical section to safely modify the resource + EnterCriticalSection(&Resource->CriticalSection); + + if (IsOwnedExclusive(Resource)) { + // If owned exclusively, verify it's our thread + if (Resource->OwnerThreads[0].OwnerThread == CurrentThread) { + // Decrement the count, and if it reaches 0, release ownership + Resource->OwnerThreads[0].OwnerCount -= 1; + if (Resource->OwnerThreads[0].OwnerCount == 0) { + Resource->Flag &= ~ResourceOwnedExclusive; + Resource->OwnerThreads[0].OwnerThread = 0; + Resource->ActiveCount = 0; + } + } + } + else { + // For a shared resource, just decrement the active count + if (Resource->ActiveCount > 0) { + Resource->ActiveCount -= 1; + } + } + + // Release the critical section + LeaveCriticalSection(&Resource->CriticalSection); +} + +VOID +KeEnterCriticalRegion( + VOID +) +{ + /* Initialize the critical section if needed */ + if (!g_WinKernelLite_KernelApcDisableLockInitialized) { + InitializeCriticalSection(&g_WinKernelLite_KernelApcDisableLock); + g_WinKernelLite_KernelApcDisableLockInitialized = TRUE; + } + + /* In kernel mode, this disables normal kernel APCs */ + /* In our user-mode implementation, we'll use a critical section for thread safety */ + EnterCriticalSection(&g_WinKernelLite_KernelApcDisableLock); + InterlockedIncrement(&g_WinKernelLite_KernelApcDisableCount); + LeaveCriticalSection(&g_WinKernelLite_KernelApcDisableLock); +} + +VOID +KeLeaveCriticalRegion( + VOID +) +{ + /* Initialize the critical section if needed (safety check) */ + if (!g_WinKernelLite_KernelApcDisableLockInitialized) { + InitializeCriticalSection(&g_WinKernelLite_KernelApcDisableLock); + g_WinKernelLite_KernelApcDisableLockInitialized = TRUE; + } + /* Re-enables normal kernel APCs */ + /* In our user-mode implementation, we'll use a critical section for thread safety */ + EnterCriticalSection(&g_WinKernelLite_KernelApcDisableLock); + InterlockedDecrement(&g_WinKernelLite_KernelApcDisableCount); + LeaveCriticalSection(&g_WinKernelLite_KernelApcDisableLock); +} + +LONG GetKernelApcDisableCount(void) +{ + LONG currentValue; + + /* Initialize the critical section if needed */ + if (!g_WinKernelLite_KernelApcDisableLockInitialized) { + InitializeCriticalSection(&g_WinKernelLite_KernelApcDisableLock); + g_WinKernelLite_KernelApcDisableLockInitialized = TRUE; + } + + EnterCriticalSection(&g_WinKernelLite_KernelApcDisableLock); + currentValue = g_WinKernelLite_KernelApcDisableCount; + LeaveCriticalSection(&g_WinKernelLite_KernelApcDisableLock); + + return currentValue; +}