From 102165029da557b6be44de4bac0c3e22e8f6506a Mon Sep 17 00:00:00 2001 From: donnaskiez Date: Tue, 11 Jun 2024 21:41:55 +1000 Subject: [PATCH] small hashmap improvements --- driver/callbacks.c | 48 +++++++------- driver/integrity.c | 2 +- driver/io.c | 4 +- driver/map.c | 152 +++++++++++++++++++++++---------------------- driver/map.h | 50 ++++++++++----- driver/modules.c | 2 +- driver/pool.c | 4 +- 7 files changed, 142 insertions(+), 120 deletions(-) diff --git a/driver/callbacks.c b/driver/callbacks.c index f0a61e5..c951d2e 100644 --- a/driver/callbacks.c +++ b/driver/callbacks.c @@ -41,8 +41,7 @@ CleanupThreadListFreeCallback(_In_ PTHREAD_LIST_ENTRY ThreadListEntry) VOID UnregisterProcessCreateNotifyRoutine() { - PRTL_HASHMAP map = GetProcessHashmap(); - InterlockedExchange(&map->active, FALSE); + RtlHashmapSetInactive(GetProcessHashmap()); ImpPsSetCreateProcessNotifyRoutine(ProcessCreateNotifyRoutine, TRUE); } @@ -276,11 +275,11 @@ ImageLoadInsertNonSystemImageIntoProcessHashmap(_In_ PIMAGE_INFO ImageInfo, if (!NT_SUCCESS(status)) return; - KeAcquireGuardedMutex(&map->lock); + RtlHashmapAcquireLock(map); /* the PEPROCESS is the first element and is the only thing compared, hence * we can simply pass it in the context parameter.*/ - entry = RtlLookupEntryHashmap(GetProcessHashmap(), ProcessId, &ProcessId); + entry = RtlHashmapEntryLookup(GetProcessHashmap(), ProcessId, &ProcessId); /* critical error has occured */ if (!entry) { @@ -307,7 +306,7 @@ ImageLoadInsertNonSystemImageIntoProcessHashmap(_In_ PIMAGE_INFO ImageInfo, entry->list_count++; end: - KeReleaseGuardedMutex(&map->lock); + RtlHashmapReleaseLock(map); } VOID @@ -411,9 +410,9 @@ EnumerateProcessModuleList(_In_ HANDLE ProcessId, if (!map->active) return; - KeAcquireGuardedMutex(&map->lock); + RtlHashmapAcquireLock(map); - entry = RtlLookupEntryHashmap(map, ProcessId, &ProcessId); + entry = RtlHashmapEntryLookup(map, ProcessId, &ProcessId); if (!entry) goto end; @@ -427,7 +426,7 @@ EnumerateProcessModuleList(_In_ HANDLE ProcessId, } end: - KeReleaseGuardedMutex(&map->lock); + RtlHashmapReleaseLock(map); } VOID @@ -443,9 +442,9 @@ FindOurUserModeModuleEntry(_In_ PROCESS_MODULE_CALLBACK Callback, if (!map->active) return; - KeAcquireGuardedMutex(&map->lock); + RtlHashmapAcquireLock(map); - entry = RtlLookupEntryHashmap(map, session->km_handle, &session->km_handle); + entry = RtlHashmapEntryLookup(map, session->km_handle, &session->km_handle); if (!entry) return; @@ -462,7 +461,7 @@ FindOurUserModeModuleEntry(_In_ PROCESS_MODULE_CALLBACK Callback, } end: - KeReleaseGuardedMutex(&map->lock); + RtlHashmapReleaseLock(map); } VOID @@ -474,12 +473,11 @@ CleanupProcessHashmap() PLIST_ENTRY list = NULL; PPROCESS_MODULE_MAP_CONTEXT context = NULL; - map->active = FALSE; - - KeAcquireGuardedMutex(&map->lock); + RtlHashmapSetInactive(map); + RtlHashmapAcquireLock(map); /* First, free all module lists */ - RtlEnumerateHashmap(map, FreeProcessEntryModuleList, NULL); + RtlHashmapEnumerate(map, FreeProcessEntryModuleList, NULL); for (UINT32 index = 0; index < map->bucket_count; index++) { entry = &map->buckets[index]; @@ -495,9 +493,9 @@ CleanupProcessHashmap() ExDeleteLookasideListEx(&context->pool); ExFreePoolWithTag(map->context, POOL_TAG_HASHMAP); - RtlDeleteHashmap(map); + RtlHashmapDelete(map); - KeReleaseGuardedMutex(&map->lock); + RtlHashmapReleaseLock(map); } NTSTATUS @@ -529,7 +527,7 @@ InitialiseProcessHashmap() return status; } - status = RtlCreateHashmap(PROCESS_HASHMAP_BUCKET_COUNT, + status = RtlHashmapCreate(PROCESS_HASHMAP_BUCKET_COUNT, sizeof(PROCESS_LIST_ENTRY), ProcessHashmapHashFunction, ProcessHashmapCompareFunction, @@ -614,7 +612,7 @@ EnumerateAndPrintProcessHashmap() PLIST_ENTRY list_entry = NULL; PLIST_ENTRY mod_list_entry = NULL; - KeAcquireGuardedMutex(&map->lock); + RtlHashmapAcquireLock(map); for (UINT32 index = 0; index < map->bucket_count; index++) { list_head = &map->buckets[index]; @@ -646,7 +644,7 @@ EnumerateAndPrintProcessHashmap() } } - KeReleaseGuardedMutex(&map->lock); + RtlHashmapReleaseLock(map); } VOID @@ -673,10 +671,10 @@ ProcessCreateNotifyRoutine(_In_ HANDLE ParentId, process_name = ImpPsGetProcessImageFileName(process); - KeAcquireGuardedMutex(&map->lock); + RtlHashmapAcquireLock(map); if (Create) { - entry = RtlInsertEntryHashmap(map, ProcessId); + entry = RtlHashmapEntryInsert(map, ProcessId); if (!entry) goto end; @@ -701,7 +699,7 @@ ProcessCreateNotifyRoutine(_In_ HANDLE ParentId, } } else { - entry = RtlLookupEntryHashmap(map, ProcessId, &ProcessId); + entry = RtlHashmapEntryLookup(map, ProcessId, &ProcessId); if (!entry) { DEBUG_ERROR("UNABLE TO FIND PROCESS NODE!!!"); @@ -712,11 +710,11 @@ ProcessCreateNotifyRoutine(_In_ HANDLE ParentId, ImpObDereferenceObject(entry->process); FreeProcessEntryModuleList(entry, NULL); - RtlDeleteEntryHashmap(map, ProcessId, &ProcessId); + RtlHashmapEntryDelete(map, ProcessId, &ProcessId); } end: - KeReleaseGuardedMutex(&map->lock); + RtlHashmapReleaseLock(map); } VOID diff --git a/driver/integrity.c b/driver/integrity.c index a5835e6..4b0a6f6 100644 --- a/driver/integrity.c +++ b/driver/integrity.c @@ -1493,7 +1493,7 @@ StoreModuleExecutableRegionsx86(_In_ PRTL_MODULE_EXTENDED_INFO Module, PEPROCESS process = NULL; KAPC_STATE apc_state = {0}; - RtlEnumerateHashmap(GetProcessHashmap(), FindWinLogonProcess, &process); + RtlHashmapEnumerate(GetProcessHashmap(), FindWinLogonProcess, &process); if (!process) return STATUS_NOT_FOUND; diff --git a/driver/io.c b/driver/io.c index b3b7df1..5e2859c 100644 --- a/driver/io.c +++ b/driver/io.c @@ -447,7 +447,7 @@ SharedMappingWorkRoutine(_In_ PDEVICE_OBJECT DeviceObject, /* can maybe implement this better so we can extract a status * value */ - RtlEnumerateHashmap(GetProcessHashmap(), EnumerateProcessHandles, NULL); + RtlHashmapEnumerate(GetProcessHashmap(), EnumerateProcessHandles, NULL); break; @@ -898,7 +898,7 @@ DeviceControl(_In_ PDEVICE_OBJECT DeviceObject, _Inout_ PIRP Irp) /* can maybe implement this better so we can extract a status * value */ - RtlEnumerateHashmap(GetProcessHashmap(), EnumerateProcessHandles, NULL); + RtlHashmapEnumerate(GetProcessHashmap(), EnumerateProcessHandles, NULL); break; diff --git a/driver/map.c b/driver/map.c index 88f5aba..5a52f19 100644 --- a/driver/map.c +++ b/driver/map.c @@ -1,17 +1,27 @@ #include "map.h" +VOID +RtlHashmapDelete(_In_ PRTL_HASHMAP Hashmap) +{ + ExFreePoolWithTag(Hashmap->buckets, POOL_TAG_HASHMAP); + ExDeleteLookasideListEx(&Hashmap->pool); +} + NTSTATUS -RtlCreateHashmap(_In_ UINT32 BucketCount, +RtlHashmapCreate(_In_ UINT32 BucketCount, _In_ UINT32 EntryObjectSize, _In_ HASH_FUNCTION HashFunction, _In_ COMPARE_FUNCTION CompareFunction, - _In_ PVOID Context, + _In_opt_ PVOID Context, _Out_ PRTL_HASHMAP Hashmap) { NTSTATUS status = STATUS_UNSUCCESSFUL; UINT32 entry_size = sizeof(RTL_HASHMAP_ENTRY) + EntryObjectSize; PRTL_HASHMAP_ENTRY entry = NULL; + if (!CompareFunction || !HashFunction) + return STATUS_INVALID_PARAMETER; + Hashmap->buckets = ExAllocatePool2( POOL_FLAG_NON_PAGED, BucketCount * entry_size, POOL_TAG_HASHMAP); @@ -26,6 +36,21 @@ RtlCreateHashmap(_In_ UINT32 BucketCount, KeInitializeGuardedMutex(&Hashmap->lock); + status = ExInitializeLookasideListEx(&Hashmap->pool, + NULL, + NULL, + NonPagedPoolNx, + 0, + entry_size, + POOL_TAG_HASHMAP, + 0); + + if (!NT_SUCCESS(status)) { + DEBUG_ERROR("ExInitializeLookasideListEx: %x", status); + ExFreePoolWithTag(Hashmap->buckets, POOL_TAG_HASHMAP); + return status; + } + Hashmap->bucket_count = BucketCount; Hashmap->hash_function = HashFunction; Hashmap->compare_function = CompareFunction; @@ -39,15 +64,20 @@ RtlCreateHashmap(_In_ UINT32 BucketCount, FORCEINLINE STATIC PRTL_HASHMAP_ENTRY -RtlFindUnusedHashmapEntry(_In_ PRTL_HASHMAP_ENTRY Head) +RtlHashmapFindUnusedEntry(_In_ PLIST_ENTRY Head) { - PRTL_HASHMAP_ENTRY entry = Head; + PRTL_HASHMAP_ENTRY entry = NULL; + PLIST_ENTRY list_entry = Head->Flink; - while (entry) { - if (entry->in_use == FALSE) + while (list_entry != Head) { + entry = CONTAINING_RECORD(list_entry, RTL_HASHMAP_ENTRY, entry); + + if (entry->in_use == FALSE) { + entry->in_use = TRUE; return entry; + } - entry = CONTAINING_RECORD(entry->entry.Flink, RTL_HASHMAP_ENTRY, entry); + list_entry = list_entry->Flink; } return NULL; @@ -56,12 +86,9 @@ RtlFindUnusedHashmapEntry(_In_ PRTL_HASHMAP_ENTRY Head) FORCEINLINE STATIC PRTL_HASHMAP_ENTRY -RtlAllocateBucketListEntry(_In_ PRTL_HASHMAP Hashmap) +RtlHashmapAllocateBucketEntry(_In_ PRTL_HASHMAP Hashmap) { - PRTL_HASHMAP_ENTRY entry = - ExAllocatePool2(POOL_FLAG_NON_PAGED, - Hashmap->object_size + sizeof(RTL_HASHMAP_ENTRY), - POOL_TAG_HASHMAP); + PRTL_HASHMAP_ENTRY entry = ExAllocateFromLookasideListEx(&Hashmap->pool); if (!entry) return NULL; @@ -73,43 +100,34 @@ RtlAllocateBucketListEntry(_In_ PRTL_HASHMAP Hashmap) FORCEINLINE STATIC BOOLEAN -RtlIsIndexInHashmapRange(_In_ PRTL_HASHMAP Hashmap, _In_ UINT32 Index) +RtlHashmapIsIndexInRange(_In_ PRTL_HASHMAP Hashmap, _In_ UINT32 Index) { return Index < Hashmap->bucket_count ? TRUE : FALSE; } /* assumes map lock is held */ PVOID -RtlInsertEntryHashmap(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key) +RtlHashmapEntryInsert(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key) { - UINT32 index = 0; - PLIST_ENTRY list_head = NULL; - PLIST_ENTRY list_entry = NULL; - PRTL_HASHMAP_ENTRY entry = NULL; - PRTL_HASHMAP_ENTRY new_entry = NULL; + UINT32 index = 0; + PLIST_ENTRY list_head = NULL; + PRTL_HASHMAP_ENTRY entry = NULL; + PRTL_HASHMAP_ENTRY new_entry = NULL; index = Hashmap->hash_function(Key); - if (!RtlIsIndexInHashmapRange(Hashmap, index)) { + if (!RtlHashmapIsIndexInRange(Hashmap, index)) { DEBUG_ERROR("Key is not in range of buckets"); return NULL; } - list_head = &(&Hashmap->buckets[index])->entry; - list_entry = list_head->Flink; + list_head = &(&Hashmap->buckets[index])->entry; + entry = RtlHashmapFindUnusedEntry(list_head); - while (list_entry != list_head) { - entry = CONTAINING_RECORD(list_entry, RTL_HASHMAP_ENTRY, entry); + if (entry) + return entry; - if (entry->in_use == FALSE) { - entry->in_use = TRUE; - return entry->object; - } - - list_entry = list_entry->Flink; - } - - new_entry = RtlAllocateBucketListEntry(Hashmap); + new_entry = RtlHashmapAllocateBucketEntry(Hashmap); if (!new_entry) { DEBUG_ERROR("Failed to allocate new entry"); @@ -126,7 +144,7 @@ RtlInsertEntryHashmap(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key) * Also assumes lock is held. */ PVOID -RtlLookupEntryHashmap(_In_ PRTL_HASHMAP Hashmap, +RtlHashmapEntryLookup(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key, _In_ PVOID Compare) { @@ -135,7 +153,7 @@ RtlLookupEntryHashmap(_In_ PRTL_HASHMAP Hashmap, index = Hashmap->hash_function(Key); - if (!RtlIsIndexInHashmapRange(Hashmap, index)) { + if (!RtlHashmapIsIndexInRange(Hashmap, index)) { DEBUG_ERROR("Key is not in range of buckets"); return NULL; } @@ -159,83 +177,67 @@ RtlLookupEntryHashmap(_In_ PRTL_HASHMAP Hashmap, /* Assumes lock is held */ BOOLEAN -RtlDeleteEntryHashmap(_In_ PRTL_HASHMAP Hashmap, - _In_ UINT64 Key, - _In_ PVOID Compare) +RtlHashmapEntryDelete(_Inout_ PRTL_HASHMAP Hashmap, + _In_ UINT64 Key, + _In_ PVOID Compare) { - UINT32 index = 0; - PRTL_HASHMAP_ENTRY entry = NULL; - PRTL_HASHMAP_ENTRY next = NULL; + UINT32 index = 0; + PLIST_ENTRY list_head = NULL; + PLIST_ENTRY list_entry = NULL; + PRTL_HASHMAP_ENTRY entry = NULL; index = Hashmap->hash_function(Key); - if (!RtlIsIndexInHashmapRange(Hashmap, index)) { + if (!RtlHashmapIsIndexInRange(Hashmap, index)) { DEBUG_ERROR("Key is not in range of buckets"); return FALSE; } - entry = &Hashmap->buckets[index]; + list_head = &(&Hashmap->buckets[index])->entry; + list_entry = list_head->Flink; - while (entry) { - if (entry->in_use == FALSE) { - next = - CONTAINING_RECORD(entry->entry.Flink, RTL_HASHMAP_ENTRY, entry); + while (list_entry != list_head) { + entry = CONTAINING_RECORD(list_entry, RTL_HASHMAP_ENTRY, entry); - if (next == &Hashmap->buckets[index]) - break; - - entry = next; - continue; - } - - if (Hashmap->compare_function(entry->object, Compare)) { - if (entry == &Hashmap->buckets[index]) { + if (entry->in_use && + Hashmap->compare_function(entry->object, Compare)) { + if (entry == list_head) { entry->in_use = FALSE; } else { RemoveEntryList(&entry->entry); - ExFreePoolWithTag(entry, POOL_TAG_HASHMAP); + ExFreeToLookasideListEx(&Hashmap->pool, entry); } return TRUE; } - next = CONTAINING_RECORD(entry->entry.Flink, RTL_HASHMAP_ENTRY, entry); - - if (next == &Hashmap->buckets[index]) - break; - - entry = next; + list_entry = list_entry->Flink; } return FALSE; } VOID -RtlEnumerateHashmap(_In_ PRTL_HASHMAP Hashmap, +RtlHashmapEnumerate(_In_ PRTL_HASHMAP Hashmap, _In_ ENUMERATE_HASHMAP EnumerationCallback, _In_opt_ PVOID Context) { - PRTL_HASHMAP_ENTRY entry = NULL; + PLIST_ENTRY list_head = NULL; + PLIST_ENTRY list_entry = NULL; + PRTL_HASHMAP_ENTRY entry = NULL; for (UINT32 index = 0; index < Hashmap->bucket_count; index++) { - PLIST_ENTRY list_head = &Hashmap->buckets[index]; - PLIST_ENTRY list_entry = list_head->Flink; + list_head = &Hashmap->buckets[index]; + list_entry = list_head->Flink; while (list_entry != list_head) { entry = CONTAINING_RECORD(list_entry, RTL_HASHMAP_ENTRY, entry); - if (entry->in_use == TRUE) { + if (entry->in_use == TRUE) EnumerationCallback(entry->object, Context); - } list_entry = list_entry->Flink; } } -} - -VOID -RtlDeleteHashmap(_In_ PRTL_HASHMAP Hashmap) -{ - ExFreePoolWithTag(Hashmap->buckets, POOL_TAG_HASHMAP); } \ No newline at end of file diff --git a/driver/map.h b/driver/map.h index b207be5..00e139e 100644 --- a/driver/map.h +++ b/driver/map.h @@ -3,11 +3,6 @@ #include "common.h" -typedef UINT32 (*HASH_FUNCTION)(_In_ UINT64 Key); - -/* Struct1 being the node being compared to the value in Struct 2*/ -typedef BOOLEAN (*COMPARE_FUNCTION)(_In_ PVOID Struct1, _In_ PVOID Struct2); - /* To improve efficiency, each entry contains a common header * RTL_HASHMAP_ENTRY*, reducing the need to store a seperate pointer to the * entrys data. */ @@ -17,8 +12,10 @@ typedef struct _RTL_HASHMAP_ENTRY { CHAR object[]; } RTL_HASHMAP_ENTRY, *PRTL_HASHMAP_ENTRY; -typedef VOID (*ENUMERATE_HASHMAP)(_In_ PRTL_HASHMAP_ENTRY Entry, - _In_opt_ PVOID Context); +typedef UINT32 (*HASH_FUNCTION)(_In_ UINT64 Key); + +/* Struct1 being the node being compared to the value in Struct 2*/ +typedef BOOLEAN (*COMPARE_FUNCTION)(_In_ PVOID Struct1, _In_ PVOID Struct2); typedef struct _RTL_HASHMAP { /* Array of RTL_HASHMAP_ENTRIES with length = bucket_count */ @@ -45,34 +42,59 @@ typedef struct _RTL_HASHMAP { } RTL_HASHMAP, *PRTL_HASHMAP; +typedef VOID (*ENUMERATE_HASHMAP)(_In_ PRTL_HASHMAP_ENTRY Entry, + _In_opt_ PVOID Context); + /* Hashmap is caller allocated */ NTSTATUS -RtlCreateHashmap(_In_ UINT32 BucketCount, +RtlHashmapCreate(_In_ UINT32 BucketCount, _In_ UINT32 EntryObjectSize, _In_ HASH_FUNCTION HashFunction, _In_ COMPARE_FUNCTION CompareFunction, - _In_ PVOID Context, + _In_opt_ PVOID Context, _Out_ PRTL_HASHMAP Hashmap); PVOID -RtlInsertEntryHashmap(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key); +RtlHashmapEntryInsert(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key); PVOID -RtlLookupEntryHashmap(_In_ PRTL_HASHMAP Hashmap, +RtlHashmapEntryLookup(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key, _In_ PVOID Compare); BOOLEAN -RtlDeleteEntryHashmap(_In_ PRTL_HASHMAP Hashmap, +RtlHashmapEntryDelete(_Inout_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key, _In_ PVOID Compare); VOID -RtlEnumerateHashmap(_In_ PRTL_HASHMAP Hashmap, +RtlHashmapEnumerate(_In_ PRTL_HASHMAP Hashmap, _In_ ENUMERATE_HASHMAP EnumerationCallback, _In_opt_ PVOID Context); VOID -RtlDeleteHashmap(_In_ PRTL_HASHMAP Hashmap); +RtlHashmapDelete(_In_ PRTL_HASHMAP Hashmap); + +FORCEINLINE +VOID +RtlHashmapAcquireLock(_Inout_ PRTL_HASHMAP Hashmap) +{ + KeAcquireGuardedMutex(&Hashmap->lock); +} + +FORCEINLINE +VOID +RtlHashmapReleaseLock(_Inout_ PRTL_HASHMAP Hashmap) +{ + KeReleaseGuardedMutex(&Hashmap->lock); +} + +FORCEINLINE +VOID +RtlHashmapSetInactive(_Inout_ PRTL_HASHMAP Hashmap) +{ + Hashmap->active = FALSE; +} + #endif \ No newline at end of file diff --git a/driver/modules.c b/driver/modules.c index 98c48ec..61c5e7f 100644 --- a/driver/modules.c +++ b/driver/modules.c @@ -1957,7 +1957,7 @@ ValidateWin32kBase_gDxgInterface() goto end; } - RtlEnumerateHashmap(GetProcessHashmap(), FindWinLogonProcess, &winlogon); + RtlHashmapEnumerate(GetProcessHashmap(), FindWinLogonProcess, &winlogon); if (!winlogon) { status = STATUS_UNSUCCESSFUL; diff --git a/driver/pool.c b/driver/pool.c index 695321e..90a1a92 100644 --- a/driver/pool.c +++ b/driver/pool.c @@ -684,7 +684,7 @@ FindUnlinkedProcesses() UINT32 packet_size = CryptRequestRequiredBufferLength( sizeof(INVALID_PROCESS_ALLOCATION_REPORT)); - RtlEnumerateHashmap(GetProcessHashmap(), IncrementProcessCounter, &context); + RtlHashmapEnumerate(GetProcessHashmap(), IncrementProcessCounter, &context); if (context.process_count == 0) { DEBUG_ERROR("IncrementProcessCounter failed with no status."); @@ -701,7 +701,7 @@ FindUnlinkedProcesses() WalkKernelPageTables(&context); - RtlEnumerateHashmap( + RtlHashmapEnumerate( GetProcessHashmap(), CheckIfProcessAllocationIsInProcessList, &context); allocation_address = (PUINT64)context.process_buffer;