From 148aadf8b852231e2294778e416e7f0bedfb58bb Mon Sep 17 00:00:00 2001 From: Torin Date: Sun, 22 Mar 2026 19:28:02 +0200 Subject: [PATCH] fix(heap): add CRITICAL_SECTION to protect tracking linked lists Add TrackingLock to GLOBAL_STATE, initialized in GetGlobalState() and cleaned up in CleanupHeap(). All linked list operations (insert, remove, traverse) on MemoryAllocations and FreedMemoryList are now protected. HeapAlloc/HeapFree for user memory stay outside the lock since the Windows heap is already thread-safe. CRITICAL_SECTION supports recursive acquisition, so nested calls from _ExFreePoolWithTracking through UntrackAllocation and TrackFreedMemoryLocked are safe. Co-Authored-By: Claude Opus 4.6 (1M context) --- include/KernelHeap.h | 1 + src/KernelHeap.c | 52 +++++++++++++++++++++++++++++++++++++++----- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/include/KernelHeap.h b/include/KernelHeap.h index cee9510..d59a7eb 100644 --- a/include/KernelHeap.h +++ b/include/KernelHeap.h @@ -77,6 +77,7 @@ typedef struct _GLOBAL_STATE { HANDLE HeapHandle; BOOL SuppressErrors; /* Control error message output */ BOOL TrackFreedMemory; /* Control whether to track freed memory for double-free detection */ + CRITICAL_SECTION TrackingLock; /* Protects MemoryAllocations, FreedMemoryList, and counters */ } GLOBAL_STATE; /* Function declarations - implementations in KernelHeap.c */ diff --git a/src/KernelHeap.c b/src/KernelHeap.c index e89f2d7..71c614d 100644 --- a/src/KernelHeap.c +++ b/src/KernelHeap.c @@ -48,12 +48,14 @@ GLOBAL_STATE* GetGlobalState(void) { temp->NextAllocationId = 1; temp->SuppressErrors = FALSE; temp->TrackFreedMemory = TRUE; + InitializeCriticalSection(&temp->TrackingLock); /* Atomically publish: if another thread won the race, free our copy */ if (InterlockedCompareExchangePointer( (PVOID*)&g_WinKernelLite_GlobalState, temp, NULL) != NULL) { /* Another thread initialized first - free our allocation */ HEAP_TRACE("GetGlobalState: Lost race, freeing duplicate allocation at %p", temp); + DeleteCriticalSection(&temp->TrackingLock); HeapFree(GetProcessHeap(), 0, temp); } else { HEAP_INFO("GetGlobalState: Global state initialized successfully"); @@ -68,6 +70,8 @@ BOOL InitHeap(void) { return FALSE; } + EnterCriticalSection(&state->TrackingLock); + /* Only reset if we have no active allocations */ if (IsListEmpty(&state->MemoryAllocations)) { state->AllocationCount = 0; @@ -82,6 +86,8 @@ BOOL InitHeap(void) { } } + LeaveCriticalSection(&state->TrackingLock); + return TRUE; } @@ -92,6 +98,8 @@ void CleanupHeap(void) { PMEMORY_TRACKING_ENTRY entry; PFREED_MEMORY_ENTRY freedEntry; + EnterCriticalSection(&state->TrackingLock); + /* Clean up allocation tracking entries */ current = state->MemoryAllocations.Flink; while (current != &state->MemoryAllocations) { @@ -100,6 +108,7 @@ void CleanupHeap(void) { HeapFree(GetProcessHeap(), 0, entry); current = next; } + InitializeListHead(&state->MemoryAllocations); /* Clean up freed memory tracking entries */ current = state->FreedMemoryList.Flink; @@ -109,6 +118,10 @@ void CleanupHeap(void) { HeapFree(GetProcessHeap(), 0, freedEntry); current = next; } + InitializeListHead(&state->FreedMemoryList); + + LeaveCriticalSection(&state->TrackingLock); + DeleteCriticalSection(&state->TrackingLock); HeapFree(GetProcessHeap(), 0, state); g_WinKernelLite_GlobalState = NULL; @@ -136,17 +149,21 @@ void TrackAllocation(PVOID Address, SIZE_T Size, const char* FileName, int LineN entry->Size = Size; entry->FileName = FileName; entry->LineNumber = LineNumber; - entry->AllocationId = state->NextAllocationId++; - InsertHeadList(&state->MemoryAllocations, &entry->ListEntry); + EnterCriticalSection(&state->TrackingLock); + entry->AllocationId = state->NextAllocationId++; + InsertHeadList(&state->MemoryAllocations, &entry->ListEntry); state->AllocationCount++; state->TotalBytesAllocated += Size; state->CurrentBytesAllocated += Size; if (state->CurrentBytesAllocated > state->PeakBytesAllocated) state->PeakBytesAllocated = state->CurrentBytesAllocated; + + LeaveCriticalSection(&state->TrackingLock); } +/* NOTE: Caller MUST hold state->TrackingLock */ void TrackFreedMemoryLocked(PVOID Address, SIZE_T Size, const char* AllocFileName, int AllocLineNumber, const char* FreeFileName, int FreeLineNumber, ULONGLONG AllocationId) { GLOBAL_STATE* state = GetGlobalState(); PFREED_MEMORY_ENTRY entry; @@ -253,6 +270,8 @@ BOOL UntrackAllocation(PVOID Address, const char* FreeFileName, int FreeLineNumb return FALSE; } + EnterCriticalSection(&state->TrackingLock); + current = state->MemoryAllocations.Flink; while (current != &state->MemoryAllocations) { @@ -275,10 +294,13 @@ BOOL UntrackAllocation(PVOID Address, const char* FreeFileName, int FreeLineNumb current = current->Flink; } + /* TrackFreedMemoryLocked expects lock held */ if (found && state->TrackFreedMemory) { TrackFreedMemoryLocked(Address, allocSize, allocFileName, allocLineNumber, FreeFileName, FreeLineNumber, allocationId); } + LeaveCriticalSection(&state->TrackingLock); + return found; } @@ -331,6 +353,8 @@ void _ExFreePoolWithTracking(PVOID pointer, const char* FileName, int LineNumber state->HeapHandle, state->TrackFreedMemory); if (state->TrackFreedMemory) { + EnterCriticalSection(&state->TrackingLock); + PLIST_ENTRY current = state->MemoryAllocations.Flink; while (current != &state->MemoryAllocations) { PMEMORY_TRACKING_ENTRY entry = CONTAINING_RECORD(current, MEMORY_TRACKING_ENTRY, ListEntry); @@ -345,11 +369,14 @@ void _ExFreePoolWithTracking(PVOID pointer, const char* FileName, int LineNumber HEAP_TRACE("_ExFreePoolWithTracking: Checking for double-free of %p (ID: %llu)", pointer, currentAllocationId); isDoubleFree = CheckForDoubleFree(pointer, FileName, LineNumber, currentAllocationId); if (isDoubleFree) { + LeaveCriticalSection(&state->TrackingLock); HEAP_WARN("_ExFreePoolWithTracking: Double-free detected for %p, returning without freeing", pointer); return; } HEAP_TRACE("_ExFreePoolWithTracking: No double-free detected for %p", pointer); } + + LeaveCriticalSection(&state->TrackingLock); } HEAP_TRACE("_ExFreePoolWithTracking: Calling UntrackAllocation for %p", pointer); @@ -399,21 +426,27 @@ void _ExFreePoolWithTracking(PVOID pointer, const char* FileName, int LineNumber BOOL IsValidHeapPointer(PVOID pointer) { GLOBAL_STATE* state = GetGlobalState(); PLIST_ENTRY current; + BOOL result = FALSE; if (!state || !pointer) return FALSE; + EnterCriticalSection(&state->TrackingLock); + current = state->MemoryAllocations.Flink; while (current != &state->MemoryAllocations) { PMEMORY_TRACKING_ENTRY entry = CONTAINING_RECORD(current, MEMORY_TRACKING_ENTRY, ListEntry); if (entry->Address == pointer) { - return TRUE; + result = TRUE; + break; } current = current->Flink; } - return FALSE; + LeaveCriticalSection(&state->TrackingLock); + + return result; } void PrintMemoryLeaks(void) { @@ -431,6 +464,8 @@ void PrintMemoryLeaks(void) { leakCount = 0; leakBytes = 0; + EnterCriticalSection(&state->TrackingLock); + HEAP_INFO("=== MEMORY LEAK REPORT ==="); current = state->MemoryAllocations.Flink; @@ -468,6 +503,8 @@ void PrintMemoryLeaks(void) { HEAP_INFO(" Double-free attempts: %d", (int)state->DoubleFreeCount); HEAP_INFO(" Freed entries tracked: %d", (int)state->FreedEntryCount); HEAP_INFO("==========================="); + + LeaveCriticalSection(&state->TrackingLock); } void PrintDoubleFreeReport(void) { @@ -479,6 +516,8 @@ void PrintDoubleFreeReport(void) { state = GetGlobalState(); if (!state) return; + EnterCriticalSection(&state->TrackingLock); + HEAP_INFO("=== FREED MEMORY REPORT ==="); HEAP_INFO("Total double-free attempts detected: %d", (int)state->DoubleFreeCount); HEAP_INFO("Currently tracking %d freed allocations", (int)state->FreedEntryCount); @@ -512,6 +551,8 @@ void PrintDoubleFreeReport(void) { } HEAP_INFO("=============================="); + + LeaveCriticalSection(&state->TrackingLock); } void SetErrorSuppression(BOOL suppress) { @@ -547,11 +588,12 @@ BOOL GetFreedMemoryTracking(void) { void SetMaxFreedEntries(SIZE_T maxEntries) { GLOBAL_STATE* state = GetGlobalState(); if (state) { + EnterCriticalSection(&state->TrackingLock); state->MaxFreedEntries = maxEntries; - if (state->FreedEntryCount > maxEntries) { CleanupOldFreedEntries(); } + LeaveCriticalSection(&state->TrackingLock); } }