From b7493ebcd7d50575077b3c9af83b478cb04be2a5 Mon Sep 17 00:00:00 2001 From: donnaskiez Date: Thu, 13 Jun 2024 22:53:37 +1000 Subject: [PATCH] hashmap multithreading improvement --- driver/callbacks.c | 109 +++++++++++++++++++++------------------------ driver/io.c | 4 -- driver/map.c | 62 +++++++++++++++++--------- driver/map.h | 37 +++++++-------- 4 files changed, 109 insertions(+), 103 deletions(-) diff --git a/driver/callbacks.c b/driver/callbacks.c index c951d2e..6dec081 100644 --- a/driver/callbacks.c +++ b/driver/callbacks.c @@ -260,6 +260,7 @@ ImageLoadInsertNonSystemImageIntoProcessHashmap(_In_ PIMAGE_INFO ImageInfo, _In_opt_ PUNICODE_STRING FullImageName) { + UINT32 index = 0; NTSTATUS status = STATUS_UNSUCCESSFUL; PEPROCESS process = NULL; PRTL_HASHMAP map = GetProcessHashmap(); @@ -275,11 +276,12 @@ ImageLoadInsertNonSystemImageIntoProcessHashmap(_In_ PIMAGE_INFO ImageInfo, if (!NT_SUCCESS(status)) return; - RtlHashmapAcquireLock(map); + index = RtlHashmapHashKeyAndAcquireBucket(map, ProcessId); - /* the PEPROCESS is the first element and is the only thing compared, hence - * we can simply pass it in the context parameter.*/ - entry = RtlHashmapEntryLookup(GetProcessHashmap(), ProcessId, &ProcessId); + if (index == STATUS_INVALID_HASHMAP_INDEX) + return; + + entry = RtlHashmapEntryLookup(GetProcessHashmap(), index, &ProcessId); /* critical error has occured */ if (!entry) { @@ -306,7 +308,7 @@ ImageLoadInsertNonSystemImageIntoProcessHashmap(_In_ PIMAGE_INFO ImageInfo, entry->list_count++; end: - RtlHashmapReleaseLock(map); + RtlHashmapReleaseBucket(map, index); } VOID @@ -401,6 +403,7 @@ EnumerateProcessModuleList(_In_ HANDLE ProcessId, _In_ PROCESS_MODULE_CALLBACK Callback, _In_opt_ PVOID Context) { + UINT32 index = 0; PRTL_HASHMAP map = GetProcessHashmap(); BOOLEAN ret = FALSE; PPROCESS_LIST_ENTRY entry = NULL; @@ -410,9 +413,12 @@ EnumerateProcessModuleList(_In_ HANDLE ProcessId, if (!map->active) return; - RtlHashmapAcquireLock(map); + index = RtlHashmapHashKeyAndAcquireBucket(map, ProcessId); - entry = RtlHashmapEntryLookup(map, ProcessId, &ProcessId); + if (index == STATUS_INVALID_HASHMAP_INDEX) + return; + + entry = RtlHashmapEntryLookup(map, index, &ProcessId); if (!entry) goto end; @@ -426,13 +432,14 @@ EnumerateProcessModuleList(_In_ HANDLE ProcessId, } end: - RtlHashmapReleaseLock(map); + RtlHashmapReleaseBucket(map, index); } VOID FindOurUserModeModuleEntry(_In_ PROCESS_MODULE_CALLBACK Callback, _In_opt_ PVOID Context) { + UINT32 index = 0; PRTL_HASHMAP map = GetProcessHashmap(); PPROCESS_LIST_ENTRY entry = NULL; PACTIVE_SESSION session = GetActiveSession(); @@ -442,9 +449,12 @@ FindOurUserModeModuleEntry(_In_ PROCESS_MODULE_CALLBACK Callback, if (!map->active) return; - RtlHashmapAcquireLock(map); + index = RtlHashmapHashKeyAndAcquireBucket(map, session->km_handle); - entry = RtlHashmapEntryLookup(map, session->km_handle, &session->km_handle); + if (index == STATUS_INVALID_HASHMAP_INDEX) + return; + + entry = RtlHashmapEntryLookup(map, index, &session->km_handle); if (!entry) return; @@ -461,7 +471,7 @@ FindOurUserModeModuleEntry(_In_ PROCESS_MODULE_CALLBACK Callback, } end: - RtlHashmapReleaseLock(map); + RtlHashmapReleaseBucket(map, index); } VOID @@ -474,7 +484,6 @@ CleanupProcessHashmap() PPROCESS_MODULE_MAP_CONTEXT context = NULL; RtlHashmapSetInactive(map); - RtlHashmapAcquireLock(map); /* First, free all module lists */ RtlHashmapEnumerate(map, FreeProcessEntryModuleList, NULL); @@ -482,11 +491,15 @@ CleanupProcessHashmap() for (UINT32 index = 0; index < map->bucket_count; index++) { entry = &map->buckets[index]; + KeAcquireGuardedMutex(&map->locks[index]); + while (!IsListEmpty(&entry->entry)) { list = RemoveHeadList(&entry->entry); temp = CONTAINING_RECORD(list, RTL_HASHMAP_ENTRY, entry); ExFreePoolWithTag(temp, POOL_TAG_HASHMAP); } + + KeReleaseGuardedMutex(&map->locks[index]); } context = map->context; @@ -494,8 +507,6 @@ CleanupProcessHashmap() ExDeleteLookasideListEx(&context->pool); ExFreePoolWithTag(map->context, POOL_TAG_HASHMAP); RtlHashmapDelete(map); - - RtlHashmapReleaseLock(map); } NTSTATUS @@ -601,50 +612,30 @@ CanInitiateDeferredHashing(_In_ LPCSTR ProcessName, _In_ PDRIVER_LIST_HEAD Head) : FALSE; } +STATIC +VOID +PrintHashmapCallback(_In_ PPROCESS_LIST_ENTRY Entry, _In_opt_ PVOID Context) +{ + PPROCESS_MAP_MODULE_ENTRY module = NULL; + PLIST_ENTRY list = NULL; + + UNREFERENCED_PARAMETER(Context); + DEBUG_VERBOSE("Process ID: %p", Entry->process_id); + + for (list = Entry->module_list.Flink; list != &Entry->module_list; + list = list->Flink) { + module = CONTAINING_RECORD(list, PROCESS_MAP_MODULE_ENTRY, entry); + DEBUG_VERBOSE(" -> Module Base: %p, size: %lx, path: %s", + (PVOID)module->base, + module->size, + module->path); + } +} + VOID EnumerateAndPrintProcessHashmap() { - PRTL_HASHMAP map = GetProcessHashmap(); - PRTL_HASHMAP_ENTRY entry = NULL; - PPROCESS_LIST_ENTRY proc_entry = NULL; - PPROCESS_MAP_MODULE_ENTRY mod_entry = NULL; - PLIST_ENTRY list_head = NULL; - PLIST_ENTRY list_entry = NULL; - PLIST_ENTRY mod_list_entry = NULL; - - RtlHashmapAcquireLock(map); - - for (UINT32 index = 0; index < map->bucket_count; index++) { - list_head = &map->buckets[index]; - list_entry = list_head->Flink; - - DEBUG_VERBOSE("Bucket %u:\n", index); - - while (list_entry != list_head) { - entry = CONTAINING_RECORD(list_entry, RTL_HASHMAP_ENTRY, entry); - - if (entry->in_use == TRUE) { - proc_entry = (PPROCESS_LIST_ENTRY)entry->object; - DEBUG_VERBOSE(" -> process id: %lx", proc_entry->process_id); - DEBUG_VERBOSE(" -> process: %llx", proc_entry->process); - DEBUG_VERBOSE(" -> parent: %llx", proc_entry->parent); - - mod_list_entry = proc_entry->module_list.Flink; - while (mod_list_entry != &proc_entry->module_list) { - mod_entry = CONTAINING_RECORD( - mod_list_entry, PROCESS_MAP_MODULE_ENTRY, entry); - DEBUG_VERBOSE(" -> module base: %llx", mod_entry->base); - DEBUG_VERBOSE(" -> module size: %lx", mod_entry->size); - - mod_list_entry = mod_list_entry->Flink; - } - } - - list_entry = list_entry->Flink; - } - } - - RtlHashmapReleaseLock(map); + RtlHashmapEnumerate(GetProcessHashmap(), PrintHashmapCallback, NULL); } VOID @@ -652,7 +643,7 @@ ProcessCreateNotifyRoutine(_In_ HANDLE ParentId, _In_ HANDLE ProcessId, _In_ BOOLEAN Create) { - BOOLEAN new = FALSE; + UINT32 index = 0; PKPROCESS parent = NULL; PKPROCESS process = NULL; PDRIVER_LIST_HEAD driver_list = GetDriverList(); @@ -670,8 +661,10 @@ ProcessCreateNotifyRoutine(_In_ HANDLE ParentId, return; process_name = ImpPsGetProcessImageFileName(process); + index = RtlHashmapHashKeyAndAcquireBucket(map, ProcessId); - RtlHashmapAcquireLock(map); + if (index == STATUS_INVALID_HASHMAP_INDEX) + return; if (Create) { entry = RtlHashmapEntryInsert(map, ProcessId); @@ -714,7 +707,7 @@ ProcessCreateNotifyRoutine(_In_ HANDLE ParentId, } end: - RtlHashmapReleaseLock(map); + RtlHashmapReleaseBucket(map, index); } VOID diff --git a/driver/io.c b/driver/io.c index 5e2859c..ff7f266 100644 --- a/driver/io.c +++ b/driver/io.c @@ -1191,10 +1191,6 @@ DeviceCreate(_In_ PDEVICE_OBJECT DeviceObject, _Inout_ PIRP Irp) PAGED_CODE(); UNREFERENCED_PARAMETER(DeviceObject); DEBUG_INFO("Handle to driver opened."); - // NTSTATUS status = ValidatePciDevices(); - - // if (!NT_SUCCESS(status)) - // DEBUG_ERROR("ValidatePciDevices failed with status %x", status); IoCompleteRequest(Irp, IO_NO_INCREMENT); return Irp->IoStatus.Status; diff --git a/driver/map.c b/driver/map.c index 48f9cb9..3ef3186 100644 --- a/driver/map.c +++ b/driver/map.c @@ -4,6 +4,7 @@ VOID RtlHashmapDelete(_In_ PRTL_HASHMAP Hashmap) { ExFreePoolWithTag(Hashmap->buckets, POOL_TAG_HASHMAP); + ExFreePoolWithTag(Hashmap->locks, POOL_TAG_HASHMAP); ExDeleteLookasideListEx(&Hashmap->pool); } @@ -28,14 +29,22 @@ RtlHashmapCreate(_In_ UINT32 BucketCount, if (!Hashmap->buckets) return STATUS_INSUFFICIENT_RESOURCES; + Hashmap->locks = ExAllocatePool2(POOL_FLAG_NON_PAGED, + sizeof(KGUARDED_MUTEX) * BucketCount, + POOL_TAG_HASHMAP); + + if (!Hashmap->locks) { + ExFreePoolWithTag(Hashmap->buckets, POOL_TAG_HASHMAP); + return STATUS_INSUFFICIENT_RESOURCES; + } + for (UINT32 index = 0; index < BucketCount; index++) { entry = &Hashmap->buckets[index]; entry->in_use = FALSE; InitializeListHead(&entry->entry); + KeInitializeGuardedMutex(&Hashmap->locks[index]); } - KeInitializeGuardedMutex(&Hashmap->lock); - status = ExInitializeLookasideListEx(&Hashmap->pool, NULL, NULL, @@ -48,6 +57,7 @@ RtlHashmapCreate(_In_ UINT32 BucketCount, if (!NT_SUCCESS(status)) { DEBUG_ERROR("ExInitializeLookasideListEx: %x", status); ExFreePoolWithTag(Hashmap->buckets, POOL_TAG_HASHMAP); + ExFreePoolWithTag(Hashmap->locks, POOL_TAG_HASHMAP); return status; } @@ -105,21 +115,37 @@ RtlpHashmapIsIndexInRange(_In_ PRTL_HASHMAP Hashmap, _In_ UINT32 Index) return Index < Hashmap->bucket_count ? TRUE : FALSE; } +INT32 +RtlHashmapHashKeyAndAcquireBucket(_Inout_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key) +{ + UINT32 index = Hashmap->hash_function(Key); + + if (!RtlpHashmapIsIndexInRange(Hashmap, index)) + return -1; + + KeAcquireGuardedMutex(&Hashmap->locks[index]); + return index; +} + +VOID +RtlHashmapReleaseBucket(_Inout_ PRTL_HASHMAP Hashmap, _In_ UINT32 Index) +{ + /* No index check here, assuming we exit the caller early if we fail on + * acquisition */ + KeReleaseGuardedMutex(&Hashmap->locks[Index]); +} + /* assumes map lock is held */ PVOID -RtlHashmapEntryInsert(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key) +RtlHashmapEntryInsert(_In_ PRTL_HASHMAP Hashmap, _In_ UINT32 Index) { UINT32 index = 0; PLIST_ENTRY list_head = NULL; PRTL_HASHMAP_ENTRY entry = NULL; PRTL_HASHMAP_ENTRY new_entry = NULL; - index = Hashmap->hash_function(Key); - - if (!RtlpHashmapIsIndexInRange(Hashmap, index)) { - DEBUG_ERROR("Key is not in range of buckets"); + if (!Hashmap->active) return NULL; - } list_head = &(&Hashmap->buckets[index])->entry; entry = RtlpHashmapFindUnusedEntry(list_head); @@ -145,18 +171,14 @@ RtlHashmapEntryInsert(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key) */ PVOID RtlHashmapEntryLookup(_In_ PRTL_HASHMAP Hashmap, - _In_ UINT64 Key, + _In_ UINT32 Index, _In_ PVOID Compare) { UINT32 index = 0; PRTL_HASHMAP_ENTRY entry = NULL; - index = Hashmap->hash_function(Key); - - if (!RtlpHashmapIsIndexInRange(Hashmap, index)) { - DEBUG_ERROR("Key is not in range of buckets"); + if (!Hashmap->active) return NULL; - } entry = &Hashmap->buckets[index]; @@ -178,7 +200,7 @@ RtlHashmapEntryLookup(_In_ PRTL_HASHMAP Hashmap, /* Assumes lock is held */ BOOLEAN RtlHashmapEntryDelete(_Inout_ PRTL_HASHMAP Hashmap, - _In_ UINT64 Key, + _In_ UINT32 Index, _In_ PVOID Compare) { UINT32 index = 0; @@ -186,12 +208,8 @@ RtlHashmapEntryDelete(_Inout_ PRTL_HASHMAP Hashmap, PLIST_ENTRY list_entry = NULL; PRTL_HASHMAP_ENTRY entry = NULL; - index = Hashmap->hash_function(Key); - - if (!RtlpHashmapIsIndexInRange(Hashmap, index)) { - DEBUG_ERROR("Key is not in range of buckets"); + if (!Hashmap->active) return FALSE; - } list_head = &(&Hashmap->buckets[index])->entry; list_entry = list_head->Flink; @@ -229,6 +247,8 @@ RtlHashmapEnumerate(_In_ PRTL_HASHMAP Hashmap, PRTL_HASHMAP_ENTRY entry = NULL; for (UINT32 index = 0; index < Hashmap->bucket_count; index++) { + KeAcquireGuardedMutex(&Hashmap->locks[index]); + list_head = &Hashmap->buckets[index]; list_entry = list_head->Flink; @@ -240,5 +260,7 @@ RtlHashmapEnumerate(_In_ PRTL_HASHMAP Hashmap, list_entry = list_entry->Flink; } + + KeReleaseGuardedMutex(&Hashmap->locks[index]); } } \ No newline at end of file diff --git a/driver/map.h b/driver/map.h index 00e139e..9b6fc97 100644 --- a/driver/map.h +++ b/driver/map.h @@ -21,6 +21,9 @@ typedef struct _RTL_HASHMAP { /* Array of RTL_HASHMAP_ENTRIES with length = bucket_count */ PRTL_HASHMAP_ENTRY buckets; + /* per bucket locks */ + PKGUARDED_MUTEX locks; + /* Number of buckets, ideally a prime number */ UINT32 bucket_count; @@ -31,41 +34,41 @@ typedef struct _RTL_HASHMAP { HASH_FUNCTION hash_function; COMPARE_FUNCTION compare_function; - KGUARDED_MUTEX lock; - /* in the future bucket entries will use this */ LOOKASIDE_LIST_EX pool; /* user allocated context */ - PVOID context; - volatile UINT32 active; + PVOID context; + volatile UINT32 active; } RTL_HASHMAP, *PRTL_HASHMAP; typedef VOID (*ENUMERATE_HASHMAP)(_In_ PRTL_HASHMAP_ENTRY Entry, _In_opt_ PVOID Context); +#define STATUS_INVALID_HASHMAP_INDEX -1 + /* Hashmap is caller allocated */ NTSTATUS RtlHashmapCreate(_In_ UINT32 BucketCount, _In_ UINT32 EntryObjectSize, _In_ HASH_FUNCTION HashFunction, _In_ COMPARE_FUNCTION CompareFunction, - _In_opt_ PVOID Context, + _In_opt_ PVOID Context, _Out_ PRTL_HASHMAP Hashmap); PVOID -RtlHashmapEntryInsert(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key); +RtlHashmapEntryInsert(_In_ PRTL_HASHMAP Hashmap, _In_ UINT32 Index); PVOID RtlHashmapEntryLookup(_In_ PRTL_HASHMAP Hashmap, - _In_ UINT64 Key, + _In_ UINT32 Index, _In_ PVOID Compare); BOOLEAN RtlHashmapEntryDelete(_Inout_ PRTL_HASHMAP Hashmap, - _In_ UINT64 Key, - _In_ PVOID Compare); + _In_ UINT32 Index, + _In_ PVOID Compare); VOID RtlHashmapEnumerate(_In_ PRTL_HASHMAP Hashmap, @@ -75,19 +78,12 @@ RtlHashmapEnumerate(_In_ PRTL_HASHMAP Hashmap, VOID RtlHashmapDelete(_In_ PRTL_HASHMAP Hashmap); -FORCEINLINE -VOID -RtlHashmapAcquireLock(_Inout_ PRTL_HASHMAP Hashmap) -{ - KeAcquireGuardedMutex(&Hashmap->lock); -} +INT32 +RtlHashmapHashKeyAndAcquireBucket(_Inout_ PRTL_HASHMAP Hashmap, + _In_ UINT64 Key); -FORCEINLINE VOID -RtlHashmapReleaseLock(_Inout_ PRTL_HASHMAP Hashmap) -{ - KeReleaseGuardedMutex(&Hashmap->lock); -} +RtlHashmapReleaseBucket(_Inout_ PRTL_HASHMAP Hashmap, _In_ UINT32 Index); FORCEINLINE VOID @@ -96,5 +92,4 @@ RtlHashmapSetInactive(_Inout_ PRTL_HASHMAP Hashmap) Hashmap->active = FALSE; } - #endif \ No newline at end of file