hashmap multithreading improvement

This commit is contained in:
donnaskiez 2024-06-13 22:53:37 +10:00
parent 2f679357ed
commit b7493ebcd7
4 changed files with 109 additions and 103 deletions

View file

@ -260,6 +260,7 @@ ImageLoadInsertNonSystemImageIntoProcessHashmap(_In_ PIMAGE_INFO ImageInfo,
_In_opt_ PUNICODE_STRING _In_opt_ PUNICODE_STRING
FullImageName) FullImageName)
{ {
UINT32 index = 0;
NTSTATUS status = STATUS_UNSUCCESSFUL; NTSTATUS status = STATUS_UNSUCCESSFUL;
PEPROCESS process = NULL; PEPROCESS process = NULL;
PRTL_HASHMAP map = GetProcessHashmap(); PRTL_HASHMAP map = GetProcessHashmap();
@ -275,11 +276,12 @@ ImageLoadInsertNonSystemImageIntoProcessHashmap(_In_ PIMAGE_INFO ImageInfo,
if (!NT_SUCCESS(status)) if (!NT_SUCCESS(status))
return; return;
RtlHashmapAcquireLock(map); index = RtlHashmapHashKeyAndAcquireBucket(map, ProcessId);
/* the PEPROCESS is the first element and is the only thing compared, hence if (index == STATUS_INVALID_HASHMAP_INDEX)
* we can simply pass it in the context parameter.*/ return;
entry = RtlHashmapEntryLookup(GetProcessHashmap(), ProcessId, &ProcessId);
entry = RtlHashmapEntryLookup(GetProcessHashmap(), index, &ProcessId);
/* critical error has occured */ /* critical error has occured */
if (!entry) { if (!entry) {
@ -306,7 +308,7 @@ ImageLoadInsertNonSystemImageIntoProcessHashmap(_In_ PIMAGE_INFO ImageInfo,
entry->list_count++; entry->list_count++;
end: end:
RtlHashmapReleaseLock(map); RtlHashmapReleaseBucket(map, index);
} }
VOID VOID
@ -401,6 +403,7 @@ EnumerateProcessModuleList(_In_ HANDLE ProcessId,
_In_ PROCESS_MODULE_CALLBACK Callback, _In_ PROCESS_MODULE_CALLBACK Callback,
_In_opt_ PVOID Context) _In_opt_ PVOID Context)
{ {
UINT32 index = 0;
PRTL_HASHMAP map = GetProcessHashmap(); PRTL_HASHMAP map = GetProcessHashmap();
BOOLEAN ret = FALSE; BOOLEAN ret = FALSE;
PPROCESS_LIST_ENTRY entry = NULL; PPROCESS_LIST_ENTRY entry = NULL;
@ -410,9 +413,12 @@ EnumerateProcessModuleList(_In_ HANDLE ProcessId,
if (!map->active) if (!map->active)
return; return;
RtlHashmapAcquireLock(map); index = RtlHashmapHashKeyAndAcquireBucket(map, ProcessId);
entry = RtlHashmapEntryLookup(map, ProcessId, &ProcessId); if (index == STATUS_INVALID_HASHMAP_INDEX)
return;
entry = RtlHashmapEntryLookup(map, index, &ProcessId);
if (!entry) if (!entry)
goto end; goto end;
@ -426,13 +432,14 @@ EnumerateProcessModuleList(_In_ HANDLE ProcessId,
} }
end: end:
RtlHashmapReleaseLock(map); RtlHashmapReleaseBucket(map, index);
} }
VOID VOID
FindOurUserModeModuleEntry(_In_ PROCESS_MODULE_CALLBACK Callback, FindOurUserModeModuleEntry(_In_ PROCESS_MODULE_CALLBACK Callback,
_In_opt_ PVOID Context) _In_opt_ PVOID Context)
{ {
UINT32 index = 0;
PRTL_HASHMAP map = GetProcessHashmap(); PRTL_HASHMAP map = GetProcessHashmap();
PPROCESS_LIST_ENTRY entry = NULL; PPROCESS_LIST_ENTRY entry = NULL;
PACTIVE_SESSION session = GetActiveSession(); PACTIVE_SESSION session = GetActiveSession();
@ -442,9 +449,12 @@ FindOurUserModeModuleEntry(_In_ PROCESS_MODULE_CALLBACK Callback,
if (!map->active) if (!map->active)
return; return;
RtlHashmapAcquireLock(map); index = RtlHashmapHashKeyAndAcquireBucket(map, session->km_handle);
entry = RtlHashmapEntryLookup(map, session->km_handle, &session->km_handle); if (index == STATUS_INVALID_HASHMAP_INDEX)
return;
entry = RtlHashmapEntryLookup(map, index, &session->km_handle);
if (!entry) if (!entry)
return; return;
@ -461,7 +471,7 @@ FindOurUserModeModuleEntry(_In_ PROCESS_MODULE_CALLBACK Callback,
} }
end: end:
RtlHashmapReleaseLock(map); RtlHashmapReleaseBucket(map, index);
} }
VOID VOID
@ -474,7 +484,6 @@ CleanupProcessHashmap()
PPROCESS_MODULE_MAP_CONTEXT context = NULL; PPROCESS_MODULE_MAP_CONTEXT context = NULL;
RtlHashmapSetInactive(map); RtlHashmapSetInactive(map);
RtlHashmapAcquireLock(map);
/* First, free all module lists */ /* First, free all module lists */
RtlHashmapEnumerate(map, FreeProcessEntryModuleList, NULL); RtlHashmapEnumerate(map, FreeProcessEntryModuleList, NULL);
@ -482,11 +491,15 @@ CleanupProcessHashmap()
for (UINT32 index = 0; index < map->bucket_count; index++) { for (UINT32 index = 0; index < map->bucket_count; index++) {
entry = &map->buckets[index]; entry = &map->buckets[index];
KeAcquireGuardedMutex(&map->locks[index]);
while (!IsListEmpty(&entry->entry)) { while (!IsListEmpty(&entry->entry)) {
list = RemoveHeadList(&entry->entry); list = RemoveHeadList(&entry->entry);
temp = CONTAINING_RECORD(list, RTL_HASHMAP_ENTRY, entry); temp = CONTAINING_RECORD(list, RTL_HASHMAP_ENTRY, entry);
ExFreePoolWithTag(temp, POOL_TAG_HASHMAP); ExFreePoolWithTag(temp, POOL_TAG_HASHMAP);
} }
KeReleaseGuardedMutex(&map->locks[index]);
} }
context = map->context; context = map->context;
@ -494,8 +507,6 @@ CleanupProcessHashmap()
ExDeleteLookasideListEx(&context->pool); ExDeleteLookasideListEx(&context->pool);
ExFreePoolWithTag(map->context, POOL_TAG_HASHMAP); ExFreePoolWithTag(map->context, POOL_TAG_HASHMAP);
RtlHashmapDelete(map); RtlHashmapDelete(map);
RtlHashmapReleaseLock(map);
} }
NTSTATUS NTSTATUS
@ -601,50 +612,30 @@ CanInitiateDeferredHashing(_In_ LPCSTR ProcessName, _In_ PDRIVER_LIST_HEAD Head)
: FALSE; : FALSE;
} }
STATIC
VOID
PrintHashmapCallback(_In_ PPROCESS_LIST_ENTRY Entry, _In_opt_ PVOID Context)
{
PPROCESS_MAP_MODULE_ENTRY module = NULL;
PLIST_ENTRY list = NULL;
UNREFERENCED_PARAMETER(Context);
DEBUG_VERBOSE("Process ID: %p", Entry->process_id);
for (list = Entry->module_list.Flink; list != &Entry->module_list;
list = list->Flink) {
module = CONTAINING_RECORD(list, PROCESS_MAP_MODULE_ENTRY, entry);
DEBUG_VERBOSE(" -> Module Base: %p, size: %lx, path: %s",
(PVOID)module->base,
module->size,
module->path);
}
}
VOID VOID
EnumerateAndPrintProcessHashmap() EnumerateAndPrintProcessHashmap()
{ {
PRTL_HASHMAP map = GetProcessHashmap(); RtlHashmapEnumerate(GetProcessHashmap(), PrintHashmapCallback, NULL);
PRTL_HASHMAP_ENTRY entry = NULL;
PPROCESS_LIST_ENTRY proc_entry = NULL;
PPROCESS_MAP_MODULE_ENTRY mod_entry = NULL;
PLIST_ENTRY list_head = NULL;
PLIST_ENTRY list_entry = NULL;
PLIST_ENTRY mod_list_entry = NULL;
RtlHashmapAcquireLock(map);
for (UINT32 index = 0; index < map->bucket_count; index++) {
list_head = &map->buckets[index];
list_entry = list_head->Flink;
DEBUG_VERBOSE("Bucket %u:\n", index);
while (list_entry != list_head) {
entry = CONTAINING_RECORD(list_entry, RTL_HASHMAP_ENTRY, entry);
if (entry->in_use == TRUE) {
proc_entry = (PPROCESS_LIST_ENTRY)entry->object;
DEBUG_VERBOSE(" -> process id: %lx", proc_entry->process_id);
DEBUG_VERBOSE(" -> process: %llx", proc_entry->process);
DEBUG_VERBOSE(" -> parent: %llx", proc_entry->parent);
mod_list_entry = proc_entry->module_list.Flink;
while (mod_list_entry != &proc_entry->module_list) {
mod_entry = CONTAINING_RECORD(
mod_list_entry, PROCESS_MAP_MODULE_ENTRY, entry);
DEBUG_VERBOSE(" -> module base: %llx", mod_entry->base);
DEBUG_VERBOSE(" -> module size: %lx", mod_entry->size);
mod_list_entry = mod_list_entry->Flink;
}
}
list_entry = list_entry->Flink;
}
}
RtlHashmapReleaseLock(map);
} }
VOID VOID
@ -652,7 +643,7 @@ ProcessCreateNotifyRoutine(_In_ HANDLE ParentId,
_In_ HANDLE ProcessId, _In_ HANDLE ProcessId,
_In_ BOOLEAN Create) _In_ BOOLEAN Create)
{ {
BOOLEAN new = FALSE; UINT32 index = 0;
PKPROCESS parent = NULL; PKPROCESS parent = NULL;
PKPROCESS process = NULL; PKPROCESS process = NULL;
PDRIVER_LIST_HEAD driver_list = GetDriverList(); PDRIVER_LIST_HEAD driver_list = GetDriverList();
@ -670,8 +661,10 @@ ProcessCreateNotifyRoutine(_In_ HANDLE ParentId,
return; return;
process_name = ImpPsGetProcessImageFileName(process); process_name = ImpPsGetProcessImageFileName(process);
index = RtlHashmapHashKeyAndAcquireBucket(map, ProcessId);
RtlHashmapAcquireLock(map); if (index == STATUS_INVALID_HASHMAP_INDEX)
return;
if (Create) { if (Create) {
entry = RtlHashmapEntryInsert(map, ProcessId); entry = RtlHashmapEntryInsert(map, ProcessId);
@ -714,7 +707,7 @@ ProcessCreateNotifyRoutine(_In_ HANDLE ParentId,
} }
end: end:
RtlHashmapReleaseLock(map); RtlHashmapReleaseBucket(map, index);
} }
VOID VOID

View file

@ -1191,10 +1191,6 @@ DeviceCreate(_In_ PDEVICE_OBJECT DeviceObject, _Inout_ PIRP Irp)
PAGED_CODE(); PAGED_CODE();
UNREFERENCED_PARAMETER(DeviceObject); UNREFERENCED_PARAMETER(DeviceObject);
DEBUG_INFO("Handle to driver opened."); DEBUG_INFO("Handle to driver opened.");
// NTSTATUS status = ValidatePciDevices();
// if (!NT_SUCCESS(status))
// DEBUG_ERROR("ValidatePciDevices failed with status %x", status);
IoCompleteRequest(Irp, IO_NO_INCREMENT); IoCompleteRequest(Irp, IO_NO_INCREMENT);
return Irp->IoStatus.Status; return Irp->IoStatus.Status;

View file

@ -4,6 +4,7 @@ VOID
RtlHashmapDelete(_In_ PRTL_HASHMAP Hashmap) RtlHashmapDelete(_In_ PRTL_HASHMAP Hashmap)
{ {
ExFreePoolWithTag(Hashmap->buckets, POOL_TAG_HASHMAP); ExFreePoolWithTag(Hashmap->buckets, POOL_TAG_HASHMAP);
ExFreePoolWithTag(Hashmap->locks, POOL_TAG_HASHMAP);
ExDeleteLookasideListEx(&Hashmap->pool); ExDeleteLookasideListEx(&Hashmap->pool);
} }
@ -28,14 +29,22 @@ RtlHashmapCreate(_In_ UINT32 BucketCount,
if (!Hashmap->buckets) if (!Hashmap->buckets)
return STATUS_INSUFFICIENT_RESOURCES; return STATUS_INSUFFICIENT_RESOURCES;
Hashmap->locks = ExAllocatePool2(POOL_FLAG_NON_PAGED,
sizeof(KGUARDED_MUTEX) * BucketCount,
POOL_TAG_HASHMAP);
if (!Hashmap->locks) {
ExFreePoolWithTag(Hashmap->buckets, POOL_TAG_HASHMAP);
return STATUS_INSUFFICIENT_RESOURCES;
}
for (UINT32 index = 0; index < BucketCount; index++) { for (UINT32 index = 0; index < BucketCount; index++) {
entry = &Hashmap->buckets[index]; entry = &Hashmap->buckets[index];
entry->in_use = FALSE; entry->in_use = FALSE;
InitializeListHead(&entry->entry); InitializeListHead(&entry->entry);
KeInitializeGuardedMutex(&Hashmap->locks[index]);
} }
KeInitializeGuardedMutex(&Hashmap->lock);
status = ExInitializeLookasideListEx(&Hashmap->pool, status = ExInitializeLookasideListEx(&Hashmap->pool,
NULL, NULL,
NULL, NULL,
@ -48,6 +57,7 @@ RtlHashmapCreate(_In_ UINT32 BucketCount,
if (!NT_SUCCESS(status)) { if (!NT_SUCCESS(status)) {
DEBUG_ERROR("ExInitializeLookasideListEx: %x", status); DEBUG_ERROR("ExInitializeLookasideListEx: %x", status);
ExFreePoolWithTag(Hashmap->buckets, POOL_TAG_HASHMAP); ExFreePoolWithTag(Hashmap->buckets, POOL_TAG_HASHMAP);
ExFreePoolWithTag(Hashmap->locks, POOL_TAG_HASHMAP);
return status; return status;
} }
@ -105,21 +115,37 @@ RtlpHashmapIsIndexInRange(_In_ PRTL_HASHMAP Hashmap, _In_ UINT32 Index)
return Index < Hashmap->bucket_count ? TRUE : FALSE; return Index < Hashmap->bucket_count ? TRUE : FALSE;
} }
INT32
RtlHashmapHashKeyAndAcquireBucket(_Inout_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key)
{
UINT32 index = Hashmap->hash_function(Key);
if (!RtlpHashmapIsIndexInRange(Hashmap, index))
return -1;
KeAcquireGuardedMutex(&Hashmap->locks[index]);
return index;
}
VOID
RtlHashmapReleaseBucket(_Inout_ PRTL_HASHMAP Hashmap, _In_ UINT32 Index)
{
/* No index check here, assuming we exit the caller early if we fail on
* acquisition */
KeReleaseGuardedMutex(&Hashmap->locks[Index]);
}
/* assumes map lock is held */ /* assumes map lock is held */
PVOID PVOID
RtlHashmapEntryInsert(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key) RtlHashmapEntryInsert(_In_ PRTL_HASHMAP Hashmap, _In_ UINT32 Index)
{ {
UINT32 index = 0; UINT32 index = 0;
PLIST_ENTRY list_head = NULL; PLIST_ENTRY list_head = NULL;
PRTL_HASHMAP_ENTRY entry = NULL; PRTL_HASHMAP_ENTRY entry = NULL;
PRTL_HASHMAP_ENTRY new_entry = NULL; PRTL_HASHMAP_ENTRY new_entry = NULL;
index = Hashmap->hash_function(Key); if (!Hashmap->active)
if (!RtlpHashmapIsIndexInRange(Hashmap, index)) {
DEBUG_ERROR("Key is not in range of buckets");
return NULL; return NULL;
}
list_head = &(&Hashmap->buckets[index])->entry; list_head = &(&Hashmap->buckets[index])->entry;
entry = RtlpHashmapFindUnusedEntry(list_head); entry = RtlpHashmapFindUnusedEntry(list_head);
@ -145,18 +171,14 @@ RtlHashmapEntryInsert(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key)
*/ */
PVOID PVOID
RtlHashmapEntryLookup(_In_ PRTL_HASHMAP Hashmap, RtlHashmapEntryLookup(_In_ PRTL_HASHMAP Hashmap,
_In_ UINT64 Key, _In_ UINT32 Index,
_In_ PVOID Compare) _In_ PVOID Compare)
{ {
UINT32 index = 0; UINT32 index = 0;
PRTL_HASHMAP_ENTRY entry = NULL; PRTL_HASHMAP_ENTRY entry = NULL;
index = Hashmap->hash_function(Key); if (!Hashmap->active)
if (!RtlpHashmapIsIndexInRange(Hashmap, index)) {
DEBUG_ERROR("Key is not in range of buckets");
return NULL; return NULL;
}
entry = &Hashmap->buckets[index]; entry = &Hashmap->buckets[index];
@ -178,7 +200,7 @@ RtlHashmapEntryLookup(_In_ PRTL_HASHMAP Hashmap,
/* Assumes lock is held */ /* Assumes lock is held */
BOOLEAN BOOLEAN
RtlHashmapEntryDelete(_Inout_ PRTL_HASHMAP Hashmap, RtlHashmapEntryDelete(_Inout_ PRTL_HASHMAP Hashmap,
_In_ UINT64 Key, _In_ UINT32 Index,
_In_ PVOID Compare) _In_ PVOID Compare)
{ {
UINT32 index = 0; UINT32 index = 0;
@ -186,12 +208,8 @@ RtlHashmapEntryDelete(_Inout_ PRTL_HASHMAP Hashmap,
PLIST_ENTRY list_entry = NULL; PLIST_ENTRY list_entry = NULL;
PRTL_HASHMAP_ENTRY entry = NULL; PRTL_HASHMAP_ENTRY entry = NULL;
index = Hashmap->hash_function(Key); if (!Hashmap->active)
if (!RtlpHashmapIsIndexInRange(Hashmap, index)) {
DEBUG_ERROR("Key is not in range of buckets");
return FALSE; return FALSE;
}
list_head = &(&Hashmap->buckets[index])->entry; list_head = &(&Hashmap->buckets[index])->entry;
list_entry = list_head->Flink; list_entry = list_head->Flink;
@ -229,6 +247,8 @@ RtlHashmapEnumerate(_In_ PRTL_HASHMAP Hashmap,
PRTL_HASHMAP_ENTRY entry = NULL; PRTL_HASHMAP_ENTRY entry = NULL;
for (UINT32 index = 0; index < Hashmap->bucket_count; index++) { for (UINT32 index = 0; index < Hashmap->bucket_count; index++) {
KeAcquireGuardedMutex(&Hashmap->locks[index]);
list_head = &Hashmap->buckets[index]; list_head = &Hashmap->buckets[index];
list_entry = list_head->Flink; list_entry = list_head->Flink;
@ -240,5 +260,7 @@ RtlHashmapEnumerate(_In_ PRTL_HASHMAP Hashmap,
list_entry = list_entry->Flink; list_entry = list_entry->Flink;
} }
KeReleaseGuardedMutex(&Hashmap->locks[index]);
} }
} }

View file

@ -21,6 +21,9 @@ typedef struct _RTL_HASHMAP {
/* Array of RTL_HASHMAP_ENTRIES with length = bucket_count */ /* Array of RTL_HASHMAP_ENTRIES with length = bucket_count */
PRTL_HASHMAP_ENTRY buckets; PRTL_HASHMAP_ENTRY buckets;
/* per bucket locks */
PKGUARDED_MUTEX locks;
/* Number of buckets, ideally a prime number */ /* Number of buckets, ideally a prime number */
UINT32 bucket_count; UINT32 bucket_count;
@ -31,41 +34,41 @@ typedef struct _RTL_HASHMAP {
HASH_FUNCTION hash_function; HASH_FUNCTION hash_function;
COMPARE_FUNCTION compare_function; COMPARE_FUNCTION compare_function;
KGUARDED_MUTEX lock;
/* in the future bucket entries will use this */ /* in the future bucket entries will use this */
LOOKASIDE_LIST_EX pool; LOOKASIDE_LIST_EX pool;
/* user allocated context */ /* user allocated context */
PVOID context; PVOID context;
volatile UINT32 active; volatile UINT32 active;
} RTL_HASHMAP, *PRTL_HASHMAP; } RTL_HASHMAP, *PRTL_HASHMAP;
typedef VOID (*ENUMERATE_HASHMAP)(_In_ PRTL_HASHMAP_ENTRY Entry, typedef VOID (*ENUMERATE_HASHMAP)(_In_ PRTL_HASHMAP_ENTRY Entry,
_In_opt_ PVOID Context); _In_opt_ PVOID Context);
#define STATUS_INVALID_HASHMAP_INDEX -1
/* Hashmap is caller allocated */ /* Hashmap is caller allocated */
NTSTATUS NTSTATUS
RtlHashmapCreate(_In_ UINT32 BucketCount, RtlHashmapCreate(_In_ UINT32 BucketCount,
_In_ UINT32 EntryObjectSize, _In_ UINT32 EntryObjectSize,
_In_ HASH_FUNCTION HashFunction, _In_ HASH_FUNCTION HashFunction,
_In_ COMPARE_FUNCTION CompareFunction, _In_ COMPARE_FUNCTION CompareFunction,
_In_opt_ PVOID Context, _In_opt_ PVOID Context,
_Out_ PRTL_HASHMAP Hashmap); _Out_ PRTL_HASHMAP Hashmap);
PVOID PVOID
RtlHashmapEntryInsert(_In_ PRTL_HASHMAP Hashmap, _In_ UINT64 Key); RtlHashmapEntryInsert(_In_ PRTL_HASHMAP Hashmap, _In_ UINT32 Index);
PVOID PVOID
RtlHashmapEntryLookup(_In_ PRTL_HASHMAP Hashmap, RtlHashmapEntryLookup(_In_ PRTL_HASHMAP Hashmap,
_In_ UINT64 Key, _In_ UINT32 Index,
_In_ PVOID Compare); _In_ PVOID Compare);
BOOLEAN BOOLEAN
RtlHashmapEntryDelete(_Inout_ PRTL_HASHMAP Hashmap, RtlHashmapEntryDelete(_Inout_ PRTL_HASHMAP Hashmap,
_In_ UINT64 Key, _In_ UINT32 Index,
_In_ PVOID Compare); _In_ PVOID Compare);
VOID VOID
RtlHashmapEnumerate(_In_ PRTL_HASHMAP Hashmap, RtlHashmapEnumerate(_In_ PRTL_HASHMAP Hashmap,
@ -75,19 +78,12 @@ RtlHashmapEnumerate(_In_ PRTL_HASHMAP Hashmap,
VOID VOID
RtlHashmapDelete(_In_ PRTL_HASHMAP Hashmap); RtlHashmapDelete(_In_ PRTL_HASHMAP Hashmap);
FORCEINLINE INT32
VOID RtlHashmapHashKeyAndAcquireBucket(_Inout_ PRTL_HASHMAP Hashmap,
RtlHashmapAcquireLock(_Inout_ PRTL_HASHMAP Hashmap) _In_ UINT64 Key);
{
KeAcquireGuardedMutex(&Hashmap->lock);
}
FORCEINLINE
VOID VOID
RtlHashmapReleaseLock(_Inout_ PRTL_HASHMAP Hashmap) RtlHashmapReleaseBucket(_Inout_ PRTL_HASHMAP Hashmap, _In_ UINT32 Index);
{
KeReleaseGuardedMutex(&Hashmap->lock);
}
FORCEINLINE FORCEINLINE
VOID VOID
@ -96,5 +92,4 @@ RtlHashmapSetInactive(_Inout_ PRTL_HASHMAP Hashmap)
Hashmap->active = FALSE; Hashmap->active = FALSE;
} }
#endif #endif