diff --git a/driver/callbacks.c b/driver/callbacks.c index 8a7ebfd..180847b 100644 --- a/driver/callbacks.c +++ b/driver/callbacks.c @@ -150,6 +150,8 @@ InitialiseDriverList() return status; } + KeAcquireGuardedMutex(&head->lock); + /* skip hal.dll and ntoskrnl.exe */ for (UINT32 index = 2; index < modules.module_count; index++) { entry = ImpExAllocatePool2(POOL_FLAG_NON_PAGED, @@ -183,11 +185,11 @@ InitialiseDriverList() entry->hashed = FALSE; } - KeAcquireGuardedMutex(&head->lock); InsertHeadList(&head->list_entry, &entry->list_entry); - KeReleaseGuardedMutex(&head->lock); } + KeReleaseGuardedMutex(&head->lock); + head->active = TRUE; if (modules.address) @@ -591,7 +593,7 @@ FindThreadListEntryByThreadAddress(_In_ HANDLE ThreadId, { PRB_TREE tree = GetThreadTree(); RtlRbTreeAcquireLock(tree); - *Entry = RtlRbTreeFindNode(tree, &ThreadId); + *Entry = RtlRbTreeFindNodeObject(tree, &ThreadId); RtlRbTreeReleaselock(tree); } @@ -740,7 +742,7 @@ ThreadCreateNotifyRoutine(_In_ HANDLE ProcessId, entry->apc_queued = FALSE; } else { - entry = RtlRbTreeFindNode(tree, &ThreadId); + entry = RtlRbTreeFindNodeObject(tree, &ThreadId); if (!entry) goto end; diff --git a/driver/containers/tree.c b/driver/containers/tree.c index 8a11970..0f7c828 100644 --- a/driver/containers/tree.c +++ b/driver/containers/tree.c @@ -424,6 +424,27 @@ RtlpRbTreeTransplant(_In_ PRB_TREE Tree, Replacement->parent = Target->parent; } +STATIC +PVOID +RtlpRbTreeFindNode(_In_ PRB_TREE Tree, _In_ PVOID Key) +{ + INT32 result = 0; + PRB_TREE_NODE current = Tree->root; + + while (current) { + result = Tree->compare(Key, current->object); + + if (result == RB_TREE_EQUAL) + return current; + else if (result == RB_TREE_LESS_THAN) + current = current->left; + else + current = current->right; + } + + return NULL; +} + /* * ASSUMES LOCK IS HELD! * @@ -447,27 +468,13 @@ RtlpRbTreeTransplant(_In_ PRB_TREE Tree, VOID RtlRbTreeDeleteNode(_In_ PRB_TREE Tree, _In_ PVOID Key) { - UINT32 result = 0; - COLOUR colour = 0; - PRB_TREE_NODE node = Tree->root; PRB_TREE_NODE target = NULL; PRB_TREE_NODE child = NULL; - PRB_TREE_NODE parent = NULL; PRB_TREE_NODE successor = NULL; + COLOUR colour = {0}; - while (node) { - result = Tree->compare(Key, node->object); - if (result == RB_TREE_EQUAL) { - target = node; - break; - } - else if (result == RB_TREE_LESS_THAN) { - node = node->left; - } - else { - node = node->right; - } - } + /* We want the node not the object */ + target = RtlpRbTreeFindNode(Tree, Key); if (!target) return; @@ -509,8 +516,11 @@ RtlRbTreeDeleteNode(_In_ PRB_TREE Tree, _In_ PVOID Key) ExFreeToLookasideListEx(&Tree->pool, target); } +/* Public API that is used to find the node object for an associated key. Should + * be used externally when wanting to find an object with a key value. If you + * are wanting to get the node itself, use the RtlpRbTreeFindNode routine. */ PVOID -RtlRbTreeFindNode(_In_ PRB_TREE Tree, _In_ PVOID Key) +RtlRbTreeFindNodeObject(_In_ PRB_TREE Tree, _In_ PVOID Key) { INT32 result = 0; PRB_TREE_NODE current = Tree->root; diff --git a/driver/containers/tree.h b/driver/containers/tree.h index 3469656..49bf3c5 100644 --- a/driver/containers/tree.h +++ b/driver/containers/tree.h @@ -43,7 +43,7 @@ VOID RtlRbTreeDeleteNode(_In_ PRB_TREE Tree, _In_ PVOID Key); PVOID -RtlRbTreeFindNode(_In_ PRB_TREE Tree, _In_ PVOID Key); +RtlRbTreeFindNodeObject(_In_ PRB_TREE Tree, _In_ PVOID Key); VOID RtlRbTreeEnumerate(_In_ PRB_TREE Tree,