diff --git a/driver/callbacks.c b/driver/callbacks.c index 6dec081..162113f 100644 --- a/driver/callbacks.c +++ b/driver/callbacks.c @@ -12,6 +12,7 @@ #include "crypt.h" #include "map.h" #include "util.h" +#include "tree.h" #define PROCESS_HASHMAP_BUCKET_COUNT 101 @@ -56,23 +57,18 @@ UnregisterImageLoadNotifyRoutine() VOID UnregisterThreadCreateNotifyRoutine() { - PTHREAD_LIST_HEAD list = GetThreadList(); - InterlockedExchange(&list->active, FALSE); + PRB_TREE tree = GetThreadTree(); + InterlockedExchange(&tree->active, FALSE); ImpPsRemoveCreateThreadNotifyRoutine(ThreadCreateNotifyRoutine); } VOID CleanupThreadListOnDriverUnload() { - PTHREAD_LIST_HEAD list = GetThreadList(); + PRB_TREE tree = GetThreadTree(); DEBUG_VERBOSE("Freeing thread list!"); - for (;;) { - if (!LookasideListFreeFirstEntry( - &list->start, &list->lock, CleanupThreadListFreeCallback)) { - ExDeleteLookasideListEx(&list->lookaside_list); - return; - } - } + RtlRbTreeEnumerate(tree, CleanupThreadListFreeCallback, NULL); + RtlRbTreeDeleteTree(tree); } VOID @@ -85,27 +81,6 @@ CleanupDriverListOnDriverUnload() } } -VOID -EnumerateThreadListWithCallbackRoutine( - _In_ THREADLIST_CALLBACK_ROUTINE CallbackRoutine, _In_opt_ PVOID Context) -{ - PTHREAD_LIST_HEAD list = GetThreadList(); - ImpKeAcquireGuardedMutex(&list->lock); - - if (!CallbackRoutine) - goto unlock; - - PTHREAD_LIST_ENTRY entry = list->start.Next; - - while (entry) { - CallbackRoutine(entry, Context); - entry = (PTHREAD_LIST_ENTRY)entry->list.Next; - } - -unlock: - ImpKeReleaseGuardedMutex(&list->lock); -} - VOID EnumerateDriverListWithCallbackRoutine( _In_ DRIVERLIST_CALLBACK_ROUTINE CallbackRoutine, _In_opt_ PVOID Context) @@ -299,7 +274,10 @@ ImageLoadInsertNonSystemImageIntoProcessHashmap(_In_ PIMAGE_INFO ImageInfo, module->base = ImageInfo->ImageBase; module->size = ImageInfo->ImageSize; - /* We dont care if this errors. */ + /* + * 1. We dont care if this errors + * 2. There is a bug with the conversion need 2 look into... + */ if (FullImageName) UnicodeToCharBufString( FullImageName, module->path, sizeof(module->path)); @@ -555,52 +533,45 @@ InitialiseProcessHashmap() return status; } +STATIC +UINT32 +ThreadListTreeCompare(_In_ PVOID Key, _In_ PVOID Object) +{ + HANDLE tid_1 = *((PHANDLE)Object); + HANDLE tid_2 = *((PHANDLE)Key); + + if (tid_2 < tid_1) + return RB_TREE_LESS_THAN; + else if (tid_2 > tid_1) + return RB_TREE_GREATER_THAN; + else + return RB_TREE_EQUAL; +} + NTSTATUS InitialiseThreadList() { - NTSTATUS status = STATUS_UNSUCCESSFUL; - PTHREAD_LIST_HEAD list = GetThreadList(); + NTSTATUS status = STATUS_UNSUCCESSFUL; + PRB_TREE tree = GetThreadTree(); - status = ExInitializeLookasideListEx(&list->lookaside_list, - NULL, - NULL, - POOL_NX_ALLOCATION, - 0, - sizeof(THREAD_LIST_ENTRY), - POOL_TAG_PROCESS_LIST, - 0); + status = + RtlRbTreeCreate(ThreadListTreeCompare, sizeof(THREAD_LIST_ENTRY), tree); - if (!NT_SUCCESS(status)) { - DEBUG_ERROR("ExInitializeLookasideListEx failed with status %x", - status); - return status; - } + if (!NT_SUCCESS(status)) + DEBUG_ERROR("RtlRbTreeCreate: %x", status); - InterlockedExchange(&list->active, TRUE); - ListInit(&list->start, &list->lock); + tree->active = TRUE; return status; } VOID -FindThreadListEntryByThreadAddress(_In_ PKTHREAD Thread, +FindThreadListEntryByThreadAddress(_In_ HANDLE ThreadId, _Out_ PTHREAD_LIST_ENTRY* Entry) { - PTHREAD_LIST_HEAD list = GetThreadList(); - ImpKeAcquireGuardedMutex(&list->lock); - *Entry = NULL; - - PTHREAD_LIST_ENTRY entry = (PTHREAD_LIST_ENTRY)list->start.Next; - - while (entry) { - if (entry->thread == Thread) { - *Entry = entry; - goto unlock; - } - - entry = entry->list.Next; - } -unlock: - ImpKeReleaseGuardedMutex(&list->lock); + PRB_TREE tree = GetThreadTree(); + RtlRbTreeAcquireLock(tree); + *Entry = RtlRbTreeFindNode(tree, &ThreadId); + RtlRbTreeReleaselock(tree); } FORCEINLINE @@ -718,10 +689,10 @@ ThreadCreateNotifyRoutine(_In_ HANDLE ProcessId, PTHREAD_LIST_ENTRY entry = NULL; PKTHREAD thread = NULL; PKPROCESS process = NULL; - PTHREAD_LIST_HEAD list = GetThreadList(); + PRB_TREE tree = GetThreadTree(); /* ensure we don't insert new entries if we are unloading */ - if (!list->active) + if (!tree->active) return; ImpPsLookupThreadByThreadId(ThreadId, &thread); @@ -730,33 +701,37 @@ ThreadCreateNotifyRoutine(_In_ HANDLE ProcessId, if (!thread || !process) return; + RtlRbTreeAcquireLock(tree); + if (Create) { - entry = ExAllocateFromLookasideListEx(&list->lookaside_list); + entry = RtlRbTreeInsertNode(tree, &ThreadId); if (!entry) - return; + goto end; ImpObfReferenceObject(thread); ImpObfReferenceObject(process); + entry->thread_id = ThreadId; entry->thread = thread; entry->owning_process = process; entry->apc = NULL; entry->apc_queued = FALSE; - - ListInsert(&list->start, &entry->list, &list->lock); } else { - FindThreadListEntryByThreadAddress(thread, &entry); + entry = RtlRbTreeFindNode(tree, &ThreadId); if (!entry) - return; + goto end; ImpObDereferenceObject(entry->thread); ImpObDereferenceObject(entry->owning_process); - LookasideListRemoveEntry(&list->start, entry, &list->lock); + RtlRbTreeDeleteNode(tree, &ThreadId); } + +end: + RtlRbTreeReleaselock(tree); } VOID diff --git a/driver/callbacks.h b/driver/callbacks.h index 0ca9150..dc53bad 100644 --- a/driver/callbacks.h +++ b/driver/callbacks.h @@ -69,13 +69,9 @@ VOID CleanupThreadListOnDriverUnload(); VOID -FindThreadListEntryByThreadAddress(_In_ PKTHREAD Thread, +FindThreadListEntryByThreadAddress(_In_ HANDLE ThreadId, _Out_ PTHREAD_LIST_ENTRY* Entry); -VOID -EnumerateThreadListWithCallbackRoutine( - _In_ THREADLIST_CALLBACK_ROUTINE CallbackRoutine, _In_opt_ PVOID Context); - VOID FindDriverEntryByBaseAddress(_In_ PVOID ImageBase, _Out_ PDRIVER_LIST_ENTRY* Entry); diff --git a/driver/common.h b/driver/common.h index 6705943..31a75c2 100644 --- a/driver/common.h +++ b/driver/common.h @@ -45,11 +45,8 @@ "donna-ac : [VERBOSE] : " fmt "\n", \ ##__VA_ARGS__) -#define HEX_DUMP(fmt, ...) \ - DbgPrintEx(DPFLTR_DEFAULT_ID, \ - LOG_VERBOSE_LEVEL, \ - fmt, \ - ##__VA_ARGS__) +#define HEX_DUMP(fmt, ...) \ + DbgPrintEx(DPFLTR_DEFAULT_ID, LOG_VERBOSE_LEVEL, fmt, ##__VA_ARGS__) #define STATIC static #define INLINE inline @@ -87,7 +84,7 @@ typedef struct _DRIVER_LIST_HEAD { } DRIVER_LIST_HEAD, *PDRIVER_LIST_HEAD; typedef struct _THREAD_LIST_ENTRY { - SINGLE_LIST_ENTRY list; + HANDLE thread_id; PKTHREAD thread; PKPROCESS owning_process; BOOLEAN apc_queued; @@ -337,6 +334,7 @@ typedef struct _ACTIVE_SESSION { #define POOL_TAG_IRP_QUEUE 'irpp' #define POOL_TAG_TIMER 'time' #define POOL_TAG_MODULE_LIST 'elom' +#define POOL_TAG_RB_TREE 'eert' #define POOL_TAG_HASHMAP 'hsah' #define IA32_APERF_MSR 0x000000E8 diff --git a/driver/driver.c b/driver/driver.c index 56739f0..447cf36 100644 --- a/driver/driver.c +++ b/driver/driver.c @@ -96,7 +96,7 @@ typedef struct _DRIVER_CONFIG { TIMER_OBJECT timer; ACTIVE_SESSION session_information; - THREAD_LIST_HEAD thread_list; + RB_TREE thread_tree; DRIVER_LIST_HEAD driver_list; RTL_HASHMAP process_hashmap; SHARED_MAPPING mapping; @@ -270,11 +270,11 @@ GetDriverConfigSystemInformation() return &g_DriverConfig->system_information; } -PTHREAD_LIST_HEAD -GetThreadList() +PRB_TREE +GetThreadTree() { PAGED_CODE(); - return &g_DriverConfig->thread_list; + return &g_DriverConfig->thread_tree; } PDRIVER_LIST_HEAD diff --git a/driver/driver.h b/driver/driver.h index 284878a..350945c 100644 --- a/driver/driver.h +++ b/driver/driver.h @@ -10,6 +10,7 @@ #include "integrity.h" #include "callbacks.h" #include "map.h" +#include "tree.h" BCRYPT_ALG_HANDLE* GetCryptHandle_AES(); @@ -50,8 +51,8 @@ GetDriverSymbolicLink(); PSYSTEM_INFORMATION GetDriverConfigSystemInformation(); -PTHREAD_LIST_HEAD -GetThreadList(); +PRB_TREE +GetThreadTree(); PDRIVER_LIST_HEAD GetDriverList(); diff --git a/driver/driver.vcxproj b/driver/driver.vcxproj index 86892ff..ff611e3 100644 --- a/driver/driver.vcxproj +++ b/driver/driver.vcxproj @@ -261,6 +261,7 @@ + @@ -283,6 +284,7 @@ + diff --git a/driver/driver.vcxproj.filters b/driver/driver.vcxproj.filters index 600e9d8..3f9c5b8 100644 --- a/driver/driver.vcxproj.filters +++ b/driver/driver.vcxproj.filters @@ -78,6 +78,9 @@ Source Files + + Source Files + @@ -152,6 +155,9 @@ Header Files + + Header Files + diff --git a/driver/list.c b/driver/list.c index aa289a7..8d9481b 100644 --- a/driver/list.c +++ b/driver/list.c @@ -120,7 +120,7 @@ LookasideListRemoveEntry(_Inout_ PSINGLE_LIST_ENTRY Head, { ImpKeAcquireGuardedMutex(Lock); - PTHREAD_LIST_HEAD head = GetThreadList(); + PTHREAD_LIST_HEAD head = GetThreadTree(); PSINGLE_LIST_ENTRY entry = Head->Next; if (!entry) @@ -153,7 +153,7 @@ LookasideListFreeFirstEntry(_Inout_ PSINGLE_LIST_ENTRY Head, { ImpKeAcquireGuardedMutex(Lock); - PTHREAD_LIST_HEAD head = GetThreadList(); + PTHREAD_LIST_HEAD head = GetThreadTree(); BOOLEAN result = FALSE; if (Head->Next) { diff --git a/driver/modules.c b/driver/modules.c index 61c5e7f..36ed4db 100644 --- a/driver/modules.c +++ b/driver/modules.c @@ -9,6 +9,7 @@ #include "thread.h" #include "pe.h" #include "crypt.h" +#include "tree.h" #define WHITELISTED_MODULE_TAG 'whte' @@ -1191,8 +1192,8 @@ ValidateThreadsViaKernelApc() InsertApcContext(context); SetApcAllocationInProgress(context); - EnumerateThreadListWithCallbackRoutine(ValidateThreadViaKernelApcCallback, - context); + RtlRbTreeEnumerate( + GetThreadTree(), ValidateThreadViaKernelApcCallback, context); UnsetApcAllocationInProgress(context); return status; } diff --git a/driver/thread.c b/driver/thread.c index 1cdca4a..42d19e9 100644 --- a/driver/thread.c +++ b/driver/thread.c @@ -8,6 +8,7 @@ #include "queue.h" #include "session.h" #include "imports.h" +#include "tree.h" #include "crypt.h" #ifdef ALLOC_PRAGMA @@ -137,8 +138,6 @@ DetectThreadsAttachedToProtectedProcess() { PAGED_CODE(); DEBUG_VERBOSE("Detecting threads attached to our process..."); - EnumerateThreadListWithCallbackRoutine(DetectAttachedThreadsProcessCallback, - NULL); + RtlRbTreeEnumerate( + GetThreadTree(), DetectAttachedThreadsProcessCallback, NULL); } - - diff --git a/driver/tree.c b/driver/tree.c new file mode 100644 index 0000000..8a11970 --- /dev/null +++ b/driver/tree.c @@ -0,0 +1,611 @@ +#include "tree.h" + +/* Caller allocated RB_TREE */ +NTSTATUS +RtlRbTreeCreate(_In_ RB_COMPARE Compare, + _In_ UINT32 ObjectSize, + _Out_ PRB_TREE Tree) +{ + NTSTATUS status = STATUS_UNSUCCESSFUL; + + if (!ARGUMENT_PRESENT(Compare)) + return STATUS_INVALID_PARAMETER; + + status = ExInitializeLookasideListEx(&Tree->pool, + NULL, + NULL, + NonPagedPoolNx, + 0, + ObjectSize + sizeof(RB_TREE_NODE), + POOL_TAG_RB_TREE, + 0); + + if (!NT_SUCCESS(status)) + return status; + + Tree->compare = Compare; + KeInitializeGuardedMutex(&Tree->lock); + + return STATUS_SUCCESS; +} + +/* This function is used to maintain the balance of a red-black tree by + * performing a left rotation around a given node. A left rotation moves the + * given node down to the left and its right child up to take its place. + * + * The structure of the tree before and after the rotation is as follows: + * + * Before Rotation: After Rotation: + * (Node) (Right_Child) + * / \ / \ + * (A) (Right_Child) -> (Node) (C) + * / \ / \ + * (B) (C) (A) (B) + */ +STATIC +VOID +RtlpRbTreeRotateLeft(_In_ PRB_TREE Tree, _In_ PRB_TREE_NODE Node) +{ + PRB_TREE_NODE right_child = Node->right; + Node->right = right_child->left; + + if (right_child->left) + right_child->left->parent = Node; + + right_child->parent = Node->parent; + + if (!Node->parent) + Tree->root = right_child; + else if (Node == Node->parent->left) + Node->parent->left = right_child; + else + Node->parent->right = right_child; + + right_child->left = Node; + Node->parent = right_child; +} + +/* + * This function is used to maintain the balance of a red-black tree by + * performing a right rotation around a given node. A right rotation moves the + * given node down to the right and its left child up to take its place. + * + * The structure of the tree before and after the rotation is as follows: + * + * Before Rotation: After Rotation: + * (Node) (Left_Child) + * / \ / \ + * (Left_Child) (C) -> (A) (Node) + * / \ / \ + * (A) (B) (B) (C) + * + */ +STATIC +VOID +RtlpRbTreeRotateRight(_In_ PRB_TREE Tree, _In_ PRB_TREE_NODE Node) +{ + PRB_TREE_NODE left_child = Node->left; + Node->left = left_child->right; + + if (left_child->right) + left_child->right->parent = Node; + + left_child->parent = Node->parent; + + if (!Node->parent) + Tree->root = left_child; + else if (Node == Node->parent->right) + Node->parent->right = left_child; + else + Node->parent->left = left_child; + + left_child->right = Node; + Node->parent = left_child; +} + +/* + * This function ensures the red-black tree properties are maintained after a + * new node is inserted. It adjusts the colors and performs rotations as + * necessary. + * + * Example scenario: + * + * Inserted Node causing a fixup: + * (Grandparent) (Parent) + * / \ / \ + * (Parent) (Uncle) -> (Node) (Grandparent) + * / / \ + * (Node) (Left) (Uncle) + */ +STATIC +VOID +RtlpRbTreeFixupInsert(_In_ PRB_TREE Tree, _In_ PRB_TREE_NODE Node) +{ + PRB_TREE_NODE uncle = NULL; + PRB_TREE_NODE parent = NULL; + PRB_TREE_NODE grandparent = NULL; + + while ((parent = Node->parent) && parent->colour == red) { + grandparent = parent->parent; + + if (parent == grandparent->left) { + uncle = grandparent->right; + + if (uncle && uncle->colour == red) { + parent->colour = black; + uncle->colour = black; + grandparent->colour = red; + Node = grandparent; + } + else { + if (Node == parent->right) { + RtlpRbTreeRotateLeft(Tree, parent); + Node = parent; + parent = Node->parent; + } + + parent->colour = black; + grandparent->colour = red; + RtlpRbTreeRotateRight(Tree, grandparent); + } + } + else { + uncle = grandparent->left; + + if (uncle && uncle->colour == red) { + parent->colour = black; + uncle->colour = black; + grandparent->colour = red; + Node = grandparent; + } + else { + if (Node == parent->left) { + RtlpRbTreeRotateRight(Tree, parent); + Node = parent; + parent = Node->parent; + } + + parent->colour = black; + grandparent->colour = red; + RtlpRbTreeRotateLeft(Tree, grandparent); + } + } + } + + Tree->root->colour = black; +} + +/* + * ASSUMES LOCK IS HELD! + * + * This function inserts a new node into the red-black tree, and then calls a + * fix-up routine to ensure the tree properties are maintained. + * + * Example insertion process: + * + * Before insertion: + * (Root) + * / \ + * (Left) (Right) + * + * After insertion: + * (Root) + * / \ + * (Left) (Right) + * / + * (Node) + * + * After fix-up: + * (Root) + * / \ + * (Left) (Node) + * \ + * (Right) + */ +PVOID +RtlRbTreeInsertNode(_In_ PRB_TREE Tree, _In_ PVOID Key) +{ + UINT32 result = 0; + PRB_TREE_NODE node = NULL; + PRB_TREE_NODE parent = NULL; + PRB_TREE_NODE current = NULL; + + node = ExAllocateFromLookasideListEx(&Tree->pool); + + if (!node) + return NULL; + + node->parent = NULL; + node->left = NULL; + node->right = NULL; + node->colour = red; + + current = Tree->root; + + while (current) { + parent = current; + result = Tree->compare(Key, current->object); + + if (result == RB_TREE_LESS_THAN) { + current = current->left; + } + else if (result == RB_TREE_GREATER_THAN) { + current = current->right; + } + else { + ExFreeToLookasideListEx(&Tree->pool, node); + return current->object; + } + } + + node->parent = parent; + + if (!parent) + Tree->root = node; + else if (result == RB_TREE_LESS_THAN) + parent->left = node; + else + parent->right = node; + + RtlpRbTreeFixupInsert(Tree, node); + + return node->object; +} + +/* + * ASSUMES LOCK IS HELD! + * + * This function traverses the left children of the given node to find and + * return the node with the minimum key in the subtree. + * + * Example traversal to find minimum: + * + * (Root) + * / \ + * (Left) (Right) + * / + * (Node) + * + * After finding minimum: + * (Root) + * / \ + * (Node) (Right) + * + * Returns the left-most node. + */ +STATIC +PRB_TREE_NODE +RtlpRbTreeMinimum(_In_ PRB_TREE_NODE Node) +{ + while (Node->left != NULL) + Node = Node->left; + + return Node; +} + +/* + * ASSUMES LOCK IS HELD! + * + * This function is called after a node is deleted from the Red-Black Tree. + * It ensures that the tree remains balanced and the Red-Black properties are + * maintained. It performs the necessary rotations and recoloring. + * + * Example fixup scenarios: + * + * Before Fixup: After Fixup: + * (Parent) (Parent) + * / \ / \ + * (Node) (Sibling) (Node) (Sibling) + * / \ / \ + * (Left) (Right) (Left) (Right) + * + * The fixup process ensures that the tree remains balanced. + */ +STATIC +VOID +RtlpRbTreeFixupDelete(_In_ PRB_TREE Tree, _In_ PRB_TREE_NODE Node) +{ + PRB_TREE_NODE sibling = NULL; + + while (Node != Tree->root && Node->colour == black) { + if (Node == Node->parent->left) { + sibling = Node->parent->right; + + if (sibling && sibling->colour == red) { + sibling->colour = black; + Node->parent->colour = red; + RtlpRbTreeRotateLeft(Tree, Node->parent); + sibling = Node->parent->right; + } + + if (sibling && (!sibling->left || sibling->left->colour == black) && + (!sibling->right || sibling->right->colour == black)) { + sibling->colour = red; + Node = Node->parent; + } + else { + if (sibling && + (!sibling->right || sibling->right->colour == black)) { + if (sibling->left) { + sibling->left->colour = black; + } + sibling->colour = red; + RtlpRbTreeRotateRight(Tree, sibling); + sibling = Node->parent->right; + } + + if (sibling) { + sibling->colour = Node->parent->colour; + Node->parent->colour = black; + if (sibling->right) { + sibling->right->colour = black; + } + RtlpRbTreeRotateLeft(Tree, Node->parent); + } + Node = Tree->root; + } + } + else { + sibling = Node->parent->left; + + if (sibling && sibling->colour == red) { + sibling->colour = black; + Node->parent->colour = red; + RtlpRbTreeRotateRight(Tree, Node->parent); + sibling = Node->parent->left; + } + + if (sibling && + (!sibling->right || sibling->right->colour == black) && + (!sibling->left || sibling->left->colour == black)) { + sibling->colour = red; + Node = Node->parent; + } + else { + if (sibling && + (!sibling->left || sibling->left->colour == black)) { + if (sibling->right) { + sibling->right->colour = black; + } + sibling->colour = red; + RtlpRbTreeRotateLeft(Tree, sibling); + sibling = Node->parent->left; + } + + if (sibling) { + sibling->colour = Node->parent->colour; + Node->parent->colour = black; + if (sibling->left) { + sibling->left->colour = black; + } + RtlpRbTreeRotateRight(Tree, Node->parent); + } + Node = Tree->root; + } + } + } + + Node->colour = black; +} + +/* + * ASSUMES LOCK IS HELD! + * + * This function replaces the subtree rooted at the node `toBeReplacedNode` with + * the subtree rooted at the node `replacementNode`. It adjusts the parent + * pointers accordingly. + * + * Example scenario: + * + * Before Transplant: After Transplant: + * (ParentNode) (ParentNode) + * / \ / \ + * (toBeReplaced) Sibling (Replacement) Sibling + * / \ / \ + * Left Right Left Right + * + * The transplant process ensures that the subtree rooted at `replacementNode` + * takes the place of the subtree rooted at `toBeReplacedNode`. + */ +STATIC +VOID +RtlpRbTreeTransplant(_In_ PRB_TREE Tree, + _In_ PRB_TREE_NODE Target, + _In_ PRB_TREE_NODE Replacement) +{ + if (!Target->parent) + Tree->root = Replacement; + else if (Target == Target->parent->left) + Target->parent->left = Replacement; + else + Target->parent->right = Replacement; + + if (Replacement) + Replacement->parent = Target->parent; +} + +/* + * ASSUMES LOCK IS HELD! + * + * This function removes a node with the specified key from the Red-Black Tree + * and ensures the tree remains balanced by performing necessary rotations and + * recoloring. + * + * Example scenario: + * + * Before Deletion: After Deletion: + * (ParentNode) (ParentNode) + * / \ / \ + * (TargetNode) Sibling (Replacement) Sibling + * / \ / \ + * LeftChild RightChild LeftChild RightChild + * + * The deletion process involves finding the target node, replacing it with a + * suitable successor or child, and ensuring the Red-Black Tree properties are + * maintained. + */ +VOID +RtlRbTreeDeleteNode(_In_ PRB_TREE Tree, _In_ PVOID Key) +{ + UINT32 result = 0; + COLOUR colour = 0; + PRB_TREE_NODE node = Tree->root; + PRB_TREE_NODE target = NULL; + PRB_TREE_NODE child = NULL; + PRB_TREE_NODE parent = NULL; + PRB_TREE_NODE successor = NULL; + + while (node) { + result = Tree->compare(Key, node->object); + if (result == RB_TREE_EQUAL) { + target = node; + break; + } + else if (result == RB_TREE_LESS_THAN) { + node = node->left; + } + else { + node = node->right; + } + } + + if (!target) + return; + + colour = target->colour; + + if (!target->left) { + child = target->right; + RtlpRbTreeTransplant(Tree, target, target->right); + } + else if (!target->right) { + child = target->left; + RtlpRbTreeTransplant(Tree, target, target->left); + } + else { + successor = RtlpRbTreeMinimum(target->right); + colour = successor->colour; + child = successor->right; + + if (successor->parent == target) { + if (child) + child->parent = successor; + } + else { + RtlpRbTreeTransplant(Tree, successor, successor->right); + successor->right = target->right; + successor->right->parent = successor; + } + + RtlpRbTreeTransplant(Tree, target, successor); + successor->left = target->left; + successor->left->parent = successor; + successor->colour = target->colour; + } + + if (colour == black && child) + RtlpRbTreeFixupDelete(Tree, child); + + ExFreeToLookasideListEx(&Tree->pool, target); +} + +PVOID +RtlRbTreeFindNode(_In_ PRB_TREE Tree, _In_ PVOID Key) +{ + INT32 result = 0; + PRB_TREE_NODE current = Tree->root; + + while (current) { + result = Tree->compare(Key, current->object); + + if (result == RB_TREE_EQUAL) + return current->object; + else if (result == RB_TREE_LESS_THAN) + current = current->left; + else + current = current->right; + } + + return NULL; +} + +STATIC +VOID +RtlpRbTreeEnumerate(_In_ PRB_TREE_NODE Node, + _In_ RB_ENUM_CALLBACK Callback, + _In_opt_ PVOID Context) +{ + if (Node == NULL) + return; + + RtlpRbTreeEnumerate(Node->left, Callback, Context); + Callback(Node->object, Context); + RtlpRbTreeEnumerate(Node->right, Callback, Context); +} + +VOID +RtlRbTreeEnumerate(_In_ PRB_TREE Tree, + _In_ RB_ENUM_CALLBACK Callback, + _In_opt_ PVOID Context) +{ + if (Tree->root == NULL) + return; + + RtlRbTreeAcquireLock(Tree); + RtlpRbTreeEnumerate(Tree->root, Callback, Context); + RtlRbTreeReleaselock(Tree); +} + +STATIC +VOID +RtlpPrintInOrder(PRB_TREE_NODE Node) +{ + if (Node == NULL) + return; + + RtlpPrintInOrder(Node->left); + + const char* color = (Node->colour == red) ? "Red" : "Black"; + DbgPrintEx(DPFLTR_DEFAULT_ID, + DPFLTR_INFO_LEVEL, + "Node: Key=%p, Color=%s\n", + *((PHANDLE)Node->object), + color); + + RtlpPrintInOrder(Node->right); +} + +/* assumes lock is held */ +VOID +RtlRbTreeInOrderPrint(_In_ PRB_TREE Tree) +{ + DEBUG_ERROR("*************************************************"); + DEBUG_ERROR("<><><><>STARTING IN ORDER PRINT <><><><><><"); + RtlpPrintInOrder(Tree->root); + DEBUG_ERROR("<><><><>ENDING IN ORDER PRINT <><><><><><"); + DEBUG_ERROR("*************************************************"); +} + +STATIC +VOID +RtlpRbTreeDeleteSubtree(_In_ PRB_TREE Tree, _In_ PRB_TREE_NODE Node) +{ + if (Node == NULL) + return; + + RtlpRbTreeDeleteSubtree(Tree, Node->left); + RtlpRbTreeDeleteSubtree(Tree, Node->right); + + ExFreeToLookasideListEx(&Tree->pool, Node); +} + +VOID +RtlRbTreeDeleteTree(_In_ PRB_TREE Tree) +{ + Tree->active = FALSE; + + RtlRbTreeAcquireLock(Tree); + RtlpRbTreeDeleteSubtree(Tree, Tree->root); + ExDeleteLookasideListEx(&Tree->pool); + RtlRbTreeReleaselock(Tree); +} \ No newline at end of file diff --git a/driver/tree.h b/driver/tree.h new file mode 100644 index 0000000..4ebda95 --- /dev/null +++ b/driver/tree.h @@ -0,0 +1,75 @@ +#ifndef TREE_H +#define TREE_H + +#include "common.h" + +#define RB_TREE_EQUAL 0 +#define RB_TREE_LESS_THAN 1 +#define RB_TREE_GREATER_THAN 2 + +typedef enum _COLOUR { red, black } COLOUR; + +typedef struct _RB_TREE_NODE { + struct _RB_TREE_NODE* parent; + struct _RB_TREE_NODE* left; + struct _RB_TREE_NODE* right; + COLOUR colour; + CHAR object[]; +} RB_TREE_NODE, *PRB_TREE_NODE; + +typedef UINT32 (*RB_COMPARE)(_In_ PVOID Key, _In_ PVOID Object); + +typedef struct _RB_TREE { + PRB_TREE_NODE root; + KGUARDED_MUTEX lock; + RB_COMPARE compare; + LOOKASIDE_LIST_EX pool; + UINT32 object_size; + UINT32 active; +} RB_TREE, *PRB_TREE; + +typedef VOID (*RB_CALLBACK)(PRB_TREE_NODE Node); +typedef VOID (*RB_ENUM_CALLBACK)(_In_ PVOID Object, _In_opt_ PVOID Context); + +PVOID +RtlRbTreeInsertNode(_In_ PRB_TREE Tree, _In_ PVOID Key); + +NTSTATUS +RtlRbTreeCreate(_In_ RB_COMPARE Compare, + _In_ UINT32 ObjectSize, + _Out_ PRB_TREE Tree); + +VOID +RtlRbTreeDeleteNode(_In_ PRB_TREE Tree, _In_ PVOID Key); + +PVOID +RtlRbTreeFindNode(_In_ PRB_TREE Tree, _In_ PVOID Key); + +VOID +RtlRbTreeEnumerate(_In_ PRB_TREE Tree, + _In_ RB_ENUM_CALLBACK Callback, + _In_opt_ PVOID Context); + +VOID +RtlRbTreeDeleteTree(_In_ PRB_TREE Tree); + +VOID +RtlRbTreeInOrderPrint(_In_ PRB_TREE Tree); + +FORCEINLINE +STATIC +VOID +RtlRbTreeAcquireLock(_Inout_ PRB_TREE Tree) +{ + KeAcquireGuardedMutex(&Tree->lock); +} + +FORCEINLINE +STATIC +VOID +RtlRbTreeReleaselock(_Inout_ PRB_TREE Tree) +{ + KeReleaseGuardedMutex(&Tree->lock); +} + +#endif \ No newline at end of file