diff --git a/driver/callbacks.c b/driver/callbacks.c
index 6dec081..162113f 100644
--- a/driver/callbacks.c
+++ b/driver/callbacks.c
@@ -12,6 +12,7 @@
#include "crypt.h"
#include "map.h"
#include "util.h"
+#include "tree.h"
#define PROCESS_HASHMAP_BUCKET_COUNT 101
@@ -56,23 +57,18 @@ UnregisterImageLoadNotifyRoutine()
VOID
UnregisterThreadCreateNotifyRoutine()
{
- PTHREAD_LIST_HEAD list = GetThreadList();
- InterlockedExchange(&list->active, FALSE);
+ PRB_TREE tree = GetThreadTree();
+ InterlockedExchange(&tree->active, FALSE);
ImpPsRemoveCreateThreadNotifyRoutine(ThreadCreateNotifyRoutine);
}
VOID
CleanupThreadListOnDriverUnload()
{
- PTHREAD_LIST_HEAD list = GetThreadList();
+ PRB_TREE tree = GetThreadTree();
DEBUG_VERBOSE("Freeing thread list!");
- for (;;) {
- if (!LookasideListFreeFirstEntry(
- &list->start, &list->lock, CleanupThreadListFreeCallback)) {
- ExDeleteLookasideListEx(&list->lookaside_list);
- return;
- }
- }
+ RtlRbTreeEnumerate(tree, CleanupThreadListFreeCallback, NULL);
+ RtlRbTreeDeleteTree(tree);
}
VOID
@@ -85,27 +81,6 @@ CleanupDriverListOnDriverUnload()
}
}
-VOID
-EnumerateThreadListWithCallbackRoutine(
- _In_ THREADLIST_CALLBACK_ROUTINE CallbackRoutine, _In_opt_ PVOID Context)
-{
- PTHREAD_LIST_HEAD list = GetThreadList();
- ImpKeAcquireGuardedMutex(&list->lock);
-
- if (!CallbackRoutine)
- goto unlock;
-
- PTHREAD_LIST_ENTRY entry = list->start.Next;
-
- while (entry) {
- CallbackRoutine(entry, Context);
- entry = (PTHREAD_LIST_ENTRY)entry->list.Next;
- }
-
-unlock:
- ImpKeReleaseGuardedMutex(&list->lock);
-}
-
VOID
EnumerateDriverListWithCallbackRoutine(
_In_ DRIVERLIST_CALLBACK_ROUTINE CallbackRoutine, _In_opt_ PVOID Context)
@@ -299,7 +274,10 @@ ImageLoadInsertNonSystemImageIntoProcessHashmap(_In_ PIMAGE_INFO ImageInfo,
module->base = ImageInfo->ImageBase;
module->size = ImageInfo->ImageSize;
- /* We dont care if this errors. */
+ /*
+ * 1. We dont care if this errors
+ * 2. There is a bug with the conversion need 2 look into...
+ */
if (FullImageName)
UnicodeToCharBufString(
FullImageName, module->path, sizeof(module->path));
@@ -555,52 +533,45 @@ InitialiseProcessHashmap()
return status;
}
+STATIC
+UINT32
+ThreadListTreeCompare(_In_ PVOID Key, _In_ PVOID Object)
+{
+ HANDLE tid_1 = *((PHANDLE)Object);
+ HANDLE tid_2 = *((PHANDLE)Key);
+
+ if (tid_2 < tid_1)
+ return RB_TREE_LESS_THAN;
+ else if (tid_2 > tid_1)
+ return RB_TREE_GREATER_THAN;
+ else
+ return RB_TREE_EQUAL;
+}
+
NTSTATUS
InitialiseThreadList()
{
- NTSTATUS status = STATUS_UNSUCCESSFUL;
- PTHREAD_LIST_HEAD list = GetThreadList();
+ NTSTATUS status = STATUS_UNSUCCESSFUL;
+ PRB_TREE tree = GetThreadTree();
- status = ExInitializeLookasideListEx(&list->lookaside_list,
- NULL,
- NULL,
- POOL_NX_ALLOCATION,
- 0,
- sizeof(THREAD_LIST_ENTRY),
- POOL_TAG_PROCESS_LIST,
- 0);
+ status =
+ RtlRbTreeCreate(ThreadListTreeCompare, sizeof(THREAD_LIST_ENTRY), tree);
- if (!NT_SUCCESS(status)) {
- DEBUG_ERROR("ExInitializeLookasideListEx failed with status %x",
- status);
- return status;
- }
+ if (!NT_SUCCESS(status))
+ DEBUG_ERROR("RtlRbTreeCreate: %x", status);
- InterlockedExchange(&list->active, TRUE);
- ListInit(&list->start, &list->lock);
+ tree->active = TRUE;
return status;
}
VOID
-FindThreadListEntryByThreadAddress(_In_ PKTHREAD Thread,
+FindThreadListEntryByThreadAddress(_In_ HANDLE ThreadId,
_Out_ PTHREAD_LIST_ENTRY* Entry)
{
- PTHREAD_LIST_HEAD list = GetThreadList();
- ImpKeAcquireGuardedMutex(&list->lock);
- *Entry = NULL;
-
- PTHREAD_LIST_ENTRY entry = (PTHREAD_LIST_ENTRY)list->start.Next;
-
- while (entry) {
- if (entry->thread == Thread) {
- *Entry = entry;
- goto unlock;
- }
-
- entry = entry->list.Next;
- }
-unlock:
- ImpKeReleaseGuardedMutex(&list->lock);
+ PRB_TREE tree = GetThreadTree();
+ RtlRbTreeAcquireLock(tree);
+ *Entry = RtlRbTreeFindNode(tree, &ThreadId);
+ RtlRbTreeReleaselock(tree);
}
FORCEINLINE
@@ -718,10 +689,10 @@ ThreadCreateNotifyRoutine(_In_ HANDLE ProcessId,
PTHREAD_LIST_ENTRY entry = NULL;
PKTHREAD thread = NULL;
PKPROCESS process = NULL;
- PTHREAD_LIST_HEAD list = GetThreadList();
+ PRB_TREE tree = GetThreadTree();
/* ensure we don't insert new entries if we are unloading */
- if (!list->active)
+ if (!tree->active)
return;
ImpPsLookupThreadByThreadId(ThreadId, &thread);
@@ -730,33 +701,37 @@ ThreadCreateNotifyRoutine(_In_ HANDLE ProcessId,
if (!thread || !process)
return;
+ RtlRbTreeAcquireLock(tree);
+
if (Create) {
- entry = ExAllocateFromLookasideListEx(&list->lookaside_list);
+ entry = RtlRbTreeInsertNode(tree, &ThreadId);
if (!entry)
- return;
+ goto end;
ImpObfReferenceObject(thread);
ImpObfReferenceObject(process);
+ entry->thread_id = ThreadId;
entry->thread = thread;
entry->owning_process = process;
entry->apc = NULL;
entry->apc_queued = FALSE;
-
- ListInsert(&list->start, &entry->list, &list->lock);
}
else {
- FindThreadListEntryByThreadAddress(thread, &entry);
+ entry = RtlRbTreeFindNode(tree, &ThreadId);
if (!entry)
- return;
+ goto end;
ImpObDereferenceObject(entry->thread);
ImpObDereferenceObject(entry->owning_process);
- LookasideListRemoveEntry(&list->start, entry, &list->lock);
+ RtlRbTreeDeleteNode(tree, &ThreadId);
}
+
+end:
+ RtlRbTreeReleaselock(tree);
}
VOID
diff --git a/driver/callbacks.h b/driver/callbacks.h
index 0ca9150..dc53bad 100644
--- a/driver/callbacks.h
+++ b/driver/callbacks.h
@@ -69,13 +69,9 @@ VOID
CleanupThreadListOnDriverUnload();
VOID
-FindThreadListEntryByThreadAddress(_In_ PKTHREAD Thread,
+FindThreadListEntryByThreadAddress(_In_ HANDLE ThreadId,
_Out_ PTHREAD_LIST_ENTRY* Entry);
-VOID
-EnumerateThreadListWithCallbackRoutine(
- _In_ THREADLIST_CALLBACK_ROUTINE CallbackRoutine, _In_opt_ PVOID Context);
-
VOID
FindDriverEntryByBaseAddress(_In_ PVOID ImageBase,
_Out_ PDRIVER_LIST_ENTRY* Entry);
diff --git a/driver/common.h b/driver/common.h
index 6705943..31a75c2 100644
--- a/driver/common.h
+++ b/driver/common.h
@@ -45,11 +45,8 @@
"donna-ac : [VERBOSE] : " fmt "\n", \
##__VA_ARGS__)
-#define HEX_DUMP(fmt, ...) \
- DbgPrintEx(DPFLTR_DEFAULT_ID, \
- LOG_VERBOSE_LEVEL, \
- fmt, \
- ##__VA_ARGS__)
+#define HEX_DUMP(fmt, ...) \
+ DbgPrintEx(DPFLTR_DEFAULT_ID, LOG_VERBOSE_LEVEL, fmt, ##__VA_ARGS__)
#define STATIC static
#define INLINE inline
@@ -87,7 +84,7 @@ typedef struct _DRIVER_LIST_HEAD {
} DRIVER_LIST_HEAD, *PDRIVER_LIST_HEAD;
typedef struct _THREAD_LIST_ENTRY {
- SINGLE_LIST_ENTRY list;
+ HANDLE thread_id;
PKTHREAD thread;
PKPROCESS owning_process;
BOOLEAN apc_queued;
@@ -337,6 +334,7 @@ typedef struct _ACTIVE_SESSION {
#define POOL_TAG_IRP_QUEUE 'irpp'
#define POOL_TAG_TIMER 'time'
#define POOL_TAG_MODULE_LIST 'elom'
+#define POOL_TAG_RB_TREE 'eert'
#define POOL_TAG_HASHMAP 'hsah'
#define IA32_APERF_MSR 0x000000E8
diff --git a/driver/driver.c b/driver/driver.c
index 56739f0..447cf36 100644
--- a/driver/driver.c
+++ b/driver/driver.c
@@ -96,7 +96,7 @@ typedef struct _DRIVER_CONFIG {
TIMER_OBJECT timer;
ACTIVE_SESSION session_information;
- THREAD_LIST_HEAD thread_list;
+ RB_TREE thread_tree;
DRIVER_LIST_HEAD driver_list;
RTL_HASHMAP process_hashmap;
SHARED_MAPPING mapping;
@@ -270,11 +270,11 @@ GetDriverConfigSystemInformation()
return &g_DriverConfig->system_information;
}
-PTHREAD_LIST_HEAD
-GetThreadList()
+PRB_TREE
+GetThreadTree()
{
PAGED_CODE();
- return &g_DriverConfig->thread_list;
+ return &g_DriverConfig->thread_tree;
}
PDRIVER_LIST_HEAD
diff --git a/driver/driver.h b/driver/driver.h
index 284878a..350945c 100644
--- a/driver/driver.h
+++ b/driver/driver.h
@@ -10,6 +10,7 @@
#include "integrity.h"
#include "callbacks.h"
#include "map.h"
+#include "tree.h"
BCRYPT_ALG_HANDLE*
GetCryptHandle_AES();
@@ -50,8 +51,8 @@ GetDriverSymbolicLink();
PSYSTEM_INFORMATION
GetDriverConfigSystemInformation();
-PTHREAD_LIST_HEAD
-GetThreadList();
+PRB_TREE
+GetThreadTree();
PDRIVER_LIST_HEAD
GetDriverList();
diff --git a/driver/driver.vcxproj b/driver/driver.vcxproj
index 86892ff..ff611e3 100644
--- a/driver/driver.vcxproj
+++ b/driver/driver.vcxproj
@@ -261,6 +261,7 @@
+
@@ -283,6 +284,7 @@
+
diff --git a/driver/driver.vcxproj.filters b/driver/driver.vcxproj.filters
index 600e9d8..3f9c5b8 100644
--- a/driver/driver.vcxproj.filters
+++ b/driver/driver.vcxproj.filters
@@ -78,6 +78,9 @@
Source Files
+
+ Source Files
+
@@ -152,6 +155,9 @@
Header Files
+
+ Header Files
+
diff --git a/driver/list.c b/driver/list.c
index aa289a7..8d9481b 100644
--- a/driver/list.c
+++ b/driver/list.c
@@ -120,7 +120,7 @@ LookasideListRemoveEntry(_Inout_ PSINGLE_LIST_ENTRY Head,
{
ImpKeAcquireGuardedMutex(Lock);
- PTHREAD_LIST_HEAD head = GetThreadList();
+ PTHREAD_LIST_HEAD head = GetThreadTree();
PSINGLE_LIST_ENTRY entry = Head->Next;
if (!entry)
@@ -153,7 +153,7 @@ LookasideListFreeFirstEntry(_Inout_ PSINGLE_LIST_ENTRY Head,
{
ImpKeAcquireGuardedMutex(Lock);
- PTHREAD_LIST_HEAD head = GetThreadList();
+ PTHREAD_LIST_HEAD head = GetThreadTree();
BOOLEAN result = FALSE;
if (Head->Next) {
diff --git a/driver/modules.c b/driver/modules.c
index 61c5e7f..36ed4db 100644
--- a/driver/modules.c
+++ b/driver/modules.c
@@ -9,6 +9,7 @@
#include "thread.h"
#include "pe.h"
#include "crypt.h"
+#include "tree.h"
#define WHITELISTED_MODULE_TAG 'whte'
@@ -1191,8 +1192,8 @@ ValidateThreadsViaKernelApc()
InsertApcContext(context);
SetApcAllocationInProgress(context);
- EnumerateThreadListWithCallbackRoutine(ValidateThreadViaKernelApcCallback,
- context);
+ RtlRbTreeEnumerate(
+ GetThreadTree(), ValidateThreadViaKernelApcCallback, context);
UnsetApcAllocationInProgress(context);
return status;
}
diff --git a/driver/thread.c b/driver/thread.c
index 1cdca4a..42d19e9 100644
--- a/driver/thread.c
+++ b/driver/thread.c
@@ -8,6 +8,7 @@
#include "queue.h"
#include "session.h"
#include "imports.h"
+#include "tree.h"
#include "crypt.h"
#ifdef ALLOC_PRAGMA
@@ -137,8 +138,6 @@ DetectThreadsAttachedToProtectedProcess()
{
PAGED_CODE();
DEBUG_VERBOSE("Detecting threads attached to our process...");
- EnumerateThreadListWithCallbackRoutine(DetectAttachedThreadsProcessCallback,
- NULL);
+ RtlRbTreeEnumerate(
+ GetThreadTree(), DetectAttachedThreadsProcessCallback, NULL);
}
-
-
diff --git a/driver/tree.c b/driver/tree.c
new file mode 100644
index 0000000..8a11970
--- /dev/null
+++ b/driver/tree.c
@@ -0,0 +1,611 @@
+#include "tree.h"
+
+/* Caller allocated RB_TREE */
+NTSTATUS
+RtlRbTreeCreate(_In_ RB_COMPARE Compare,
+ _In_ UINT32 ObjectSize,
+ _Out_ PRB_TREE Tree)
+{
+ NTSTATUS status = STATUS_UNSUCCESSFUL;
+
+ if (!ARGUMENT_PRESENT(Compare))
+ return STATUS_INVALID_PARAMETER;
+
+ status = ExInitializeLookasideListEx(&Tree->pool,
+ NULL,
+ NULL,
+ NonPagedPoolNx,
+ 0,
+ ObjectSize + sizeof(RB_TREE_NODE),
+ POOL_TAG_RB_TREE,
+ 0);
+
+ if (!NT_SUCCESS(status))
+ return status;
+
+ Tree->compare = Compare;
+ KeInitializeGuardedMutex(&Tree->lock);
+
+ return STATUS_SUCCESS;
+}
+
+/* This function is used to maintain the balance of a red-black tree by
+ * performing a left rotation around a given node. A left rotation moves the
+ * given node down to the left and its right child up to take its place.
+ *
+ * The structure of the tree before and after the rotation is as follows:
+ *
+ * Before Rotation: After Rotation:
+ * (Node) (Right_Child)
+ * / \ / \
+ * (A) (Right_Child) -> (Node) (C)
+ * / \ / \
+ * (B) (C) (A) (B)
+ */
+STATIC
+VOID
+RtlpRbTreeRotateLeft(_In_ PRB_TREE Tree, _In_ PRB_TREE_NODE Node)
+{
+ PRB_TREE_NODE right_child = Node->right;
+ Node->right = right_child->left;
+
+ if (right_child->left)
+ right_child->left->parent = Node;
+
+ right_child->parent = Node->parent;
+
+ if (!Node->parent)
+ Tree->root = right_child;
+ else if (Node == Node->parent->left)
+ Node->parent->left = right_child;
+ else
+ Node->parent->right = right_child;
+
+ right_child->left = Node;
+ Node->parent = right_child;
+}
+
+/*
+ * This function is used to maintain the balance of a red-black tree by
+ * performing a right rotation around a given node. A right rotation moves the
+ * given node down to the right and its left child up to take its place.
+ *
+ * The structure of the tree before and after the rotation is as follows:
+ *
+ * Before Rotation: After Rotation:
+ * (Node) (Left_Child)
+ * / \ / \
+ * (Left_Child) (C) -> (A) (Node)
+ * / \ / \
+ * (A) (B) (B) (C)
+ *
+ */
+STATIC
+VOID
+RtlpRbTreeRotateRight(_In_ PRB_TREE Tree, _In_ PRB_TREE_NODE Node)
+{
+ PRB_TREE_NODE left_child = Node->left;
+ Node->left = left_child->right;
+
+ if (left_child->right)
+ left_child->right->parent = Node;
+
+ left_child->parent = Node->parent;
+
+ if (!Node->parent)
+ Tree->root = left_child;
+ else if (Node == Node->parent->right)
+ Node->parent->right = left_child;
+ else
+ Node->parent->left = left_child;
+
+ left_child->right = Node;
+ Node->parent = left_child;
+}
+
+/*
+ * This function ensures the red-black tree properties are maintained after a
+ * new node is inserted. It adjusts the colors and performs rotations as
+ * necessary.
+ *
+ * Example scenario:
+ *
+ * Inserted Node causing a fixup:
+ * (Grandparent) (Parent)
+ * / \ / \
+ * (Parent) (Uncle) -> (Node) (Grandparent)
+ * / / \
+ * (Node) (Left) (Uncle)
+ */
+STATIC
+VOID
+RtlpRbTreeFixupInsert(_In_ PRB_TREE Tree, _In_ PRB_TREE_NODE Node)
+{
+ PRB_TREE_NODE uncle = NULL;
+ PRB_TREE_NODE parent = NULL;
+ PRB_TREE_NODE grandparent = NULL;
+
+ while ((parent = Node->parent) && parent->colour == red) {
+ grandparent = parent->parent;
+
+ if (parent == grandparent->left) {
+ uncle = grandparent->right;
+
+ if (uncle && uncle->colour == red) {
+ parent->colour = black;
+ uncle->colour = black;
+ grandparent->colour = red;
+ Node = grandparent;
+ }
+ else {
+ if (Node == parent->right) {
+ RtlpRbTreeRotateLeft(Tree, parent);
+ Node = parent;
+ parent = Node->parent;
+ }
+
+ parent->colour = black;
+ grandparent->colour = red;
+ RtlpRbTreeRotateRight(Tree, grandparent);
+ }
+ }
+ else {
+ uncle = grandparent->left;
+
+ if (uncle && uncle->colour == red) {
+ parent->colour = black;
+ uncle->colour = black;
+ grandparent->colour = red;
+ Node = grandparent;
+ }
+ else {
+ if (Node == parent->left) {
+ RtlpRbTreeRotateRight(Tree, parent);
+ Node = parent;
+ parent = Node->parent;
+ }
+
+ parent->colour = black;
+ grandparent->colour = red;
+ RtlpRbTreeRotateLeft(Tree, grandparent);
+ }
+ }
+ }
+
+ Tree->root->colour = black;
+}
+
+/*
+ * ASSUMES LOCK IS HELD!
+ *
+ * This function inserts a new node into the red-black tree, and then calls a
+ * fix-up routine to ensure the tree properties are maintained.
+ *
+ * Example insertion process:
+ *
+ * Before insertion:
+ * (Root)
+ * / \
+ * (Left) (Right)
+ *
+ * After insertion:
+ * (Root)
+ * / \
+ * (Left) (Right)
+ * /
+ * (Node)
+ *
+ * After fix-up:
+ * (Root)
+ * / \
+ * (Left) (Node)
+ * \
+ * (Right)
+ */
+PVOID
+RtlRbTreeInsertNode(_In_ PRB_TREE Tree, _In_ PVOID Key)
+{
+ UINT32 result = 0;
+ PRB_TREE_NODE node = NULL;
+ PRB_TREE_NODE parent = NULL;
+ PRB_TREE_NODE current = NULL;
+
+ node = ExAllocateFromLookasideListEx(&Tree->pool);
+
+ if (!node)
+ return NULL;
+
+ node->parent = NULL;
+ node->left = NULL;
+ node->right = NULL;
+ node->colour = red;
+
+ current = Tree->root;
+
+ while (current) {
+ parent = current;
+ result = Tree->compare(Key, current->object);
+
+ if (result == RB_TREE_LESS_THAN) {
+ current = current->left;
+ }
+ else if (result == RB_TREE_GREATER_THAN) {
+ current = current->right;
+ }
+ else {
+ ExFreeToLookasideListEx(&Tree->pool, node);
+ return current->object;
+ }
+ }
+
+ node->parent = parent;
+
+ if (!parent)
+ Tree->root = node;
+ else if (result == RB_TREE_LESS_THAN)
+ parent->left = node;
+ else
+ parent->right = node;
+
+ RtlpRbTreeFixupInsert(Tree, node);
+
+ return node->object;
+}
+
+/*
+ * ASSUMES LOCK IS HELD!
+ *
+ * This function traverses the left children of the given node to find and
+ * return the node with the minimum key in the subtree.
+ *
+ * Example traversal to find minimum:
+ *
+ * (Root)
+ * / \
+ * (Left) (Right)
+ * /
+ * (Node)
+ *
+ * After finding minimum:
+ * (Root)
+ * / \
+ * (Node) (Right)
+ *
+ * Returns the left-most node.
+ */
+STATIC
+PRB_TREE_NODE
+RtlpRbTreeMinimum(_In_ PRB_TREE_NODE Node)
+{
+ while (Node->left != NULL)
+ Node = Node->left;
+
+ return Node;
+}
+
+/*
+ * ASSUMES LOCK IS HELD!
+ *
+ * This function is called after a node is deleted from the Red-Black Tree.
+ * It ensures that the tree remains balanced and the Red-Black properties are
+ * maintained. It performs the necessary rotations and recoloring.
+ *
+ * Example fixup scenarios:
+ *
+ * Before Fixup: After Fixup:
+ * (Parent) (Parent)
+ * / \ / \
+ * (Node) (Sibling) (Node) (Sibling)
+ * / \ / \
+ * (Left) (Right) (Left) (Right)
+ *
+ * The fixup process ensures that the tree remains balanced.
+ */
+STATIC
+VOID
+RtlpRbTreeFixupDelete(_In_ PRB_TREE Tree, _In_ PRB_TREE_NODE Node)
+{
+ PRB_TREE_NODE sibling = NULL;
+
+ while (Node != Tree->root && Node->colour == black) {
+ if (Node == Node->parent->left) {
+ sibling = Node->parent->right;
+
+ if (sibling && sibling->colour == red) {
+ sibling->colour = black;
+ Node->parent->colour = red;
+ RtlpRbTreeRotateLeft(Tree, Node->parent);
+ sibling = Node->parent->right;
+ }
+
+ if (sibling && (!sibling->left || sibling->left->colour == black) &&
+ (!sibling->right || sibling->right->colour == black)) {
+ sibling->colour = red;
+ Node = Node->parent;
+ }
+ else {
+ if (sibling &&
+ (!sibling->right || sibling->right->colour == black)) {
+ if (sibling->left) {
+ sibling->left->colour = black;
+ }
+ sibling->colour = red;
+ RtlpRbTreeRotateRight(Tree, sibling);
+ sibling = Node->parent->right;
+ }
+
+ if (sibling) {
+ sibling->colour = Node->parent->colour;
+ Node->parent->colour = black;
+ if (sibling->right) {
+ sibling->right->colour = black;
+ }
+ RtlpRbTreeRotateLeft(Tree, Node->parent);
+ }
+ Node = Tree->root;
+ }
+ }
+ else {
+ sibling = Node->parent->left;
+
+ if (sibling && sibling->colour == red) {
+ sibling->colour = black;
+ Node->parent->colour = red;
+ RtlpRbTreeRotateRight(Tree, Node->parent);
+ sibling = Node->parent->left;
+ }
+
+ if (sibling &&
+ (!sibling->right || sibling->right->colour == black) &&
+ (!sibling->left || sibling->left->colour == black)) {
+ sibling->colour = red;
+ Node = Node->parent;
+ }
+ else {
+ if (sibling &&
+ (!sibling->left || sibling->left->colour == black)) {
+ if (sibling->right) {
+ sibling->right->colour = black;
+ }
+ sibling->colour = red;
+ RtlpRbTreeRotateLeft(Tree, sibling);
+ sibling = Node->parent->left;
+ }
+
+ if (sibling) {
+ sibling->colour = Node->parent->colour;
+ Node->parent->colour = black;
+ if (sibling->left) {
+ sibling->left->colour = black;
+ }
+ RtlpRbTreeRotateRight(Tree, Node->parent);
+ }
+ Node = Tree->root;
+ }
+ }
+ }
+
+ Node->colour = black;
+}
+
+/*
+ * ASSUMES LOCK IS HELD!
+ *
+ * This function replaces the subtree rooted at the node `toBeReplacedNode` with
+ * the subtree rooted at the node `replacementNode`. It adjusts the parent
+ * pointers accordingly.
+ *
+ * Example scenario:
+ *
+ * Before Transplant: After Transplant:
+ * (ParentNode) (ParentNode)
+ * / \ / \
+ * (toBeReplaced) Sibling (Replacement) Sibling
+ * / \ / \
+ * Left Right Left Right
+ *
+ * The transplant process ensures that the subtree rooted at `replacementNode`
+ * takes the place of the subtree rooted at `toBeReplacedNode`.
+ */
+STATIC
+VOID
+RtlpRbTreeTransplant(_In_ PRB_TREE Tree,
+ _In_ PRB_TREE_NODE Target,
+ _In_ PRB_TREE_NODE Replacement)
+{
+ if (!Target->parent)
+ Tree->root = Replacement;
+ else if (Target == Target->parent->left)
+ Target->parent->left = Replacement;
+ else
+ Target->parent->right = Replacement;
+
+ if (Replacement)
+ Replacement->parent = Target->parent;
+}
+
+/*
+ * ASSUMES LOCK IS HELD!
+ *
+ * This function removes a node with the specified key from the Red-Black Tree
+ * and ensures the tree remains balanced by performing necessary rotations and
+ * recoloring.
+ *
+ * Example scenario:
+ *
+ * Before Deletion: After Deletion:
+ * (ParentNode) (ParentNode)
+ * / \ / \
+ * (TargetNode) Sibling (Replacement) Sibling
+ * / \ / \
+ * LeftChild RightChild LeftChild RightChild
+ *
+ * The deletion process involves finding the target node, replacing it with a
+ * suitable successor or child, and ensuring the Red-Black Tree properties are
+ * maintained.
+ */
+VOID
+RtlRbTreeDeleteNode(_In_ PRB_TREE Tree, _In_ PVOID Key)
+{
+ UINT32 result = 0;
+ COLOUR colour = 0;
+ PRB_TREE_NODE node = Tree->root;
+ PRB_TREE_NODE target = NULL;
+ PRB_TREE_NODE child = NULL;
+ PRB_TREE_NODE parent = NULL;
+ PRB_TREE_NODE successor = NULL;
+
+ while (node) {
+ result = Tree->compare(Key, node->object);
+ if (result == RB_TREE_EQUAL) {
+ target = node;
+ break;
+ }
+ else if (result == RB_TREE_LESS_THAN) {
+ node = node->left;
+ }
+ else {
+ node = node->right;
+ }
+ }
+
+ if (!target)
+ return;
+
+ colour = target->colour;
+
+ if (!target->left) {
+ child = target->right;
+ RtlpRbTreeTransplant(Tree, target, target->right);
+ }
+ else if (!target->right) {
+ child = target->left;
+ RtlpRbTreeTransplant(Tree, target, target->left);
+ }
+ else {
+ successor = RtlpRbTreeMinimum(target->right);
+ colour = successor->colour;
+ child = successor->right;
+
+ if (successor->parent == target) {
+ if (child)
+ child->parent = successor;
+ }
+ else {
+ RtlpRbTreeTransplant(Tree, successor, successor->right);
+ successor->right = target->right;
+ successor->right->parent = successor;
+ }
+
+ RtlpRbTreeTransplant(Tree, target, successor);
+ successor->left = target->left;
+ successor->left->parent = successor;
+ successor->colour = target->colour;
+ }
+
+ if (colour == black && child)
+ RtlpRbTreeFixupDelete(Tree, child);
+
+ ExFreeToLookasideListEx(&Tree->pool, target);
+}
+
+PVOID
+RtlRbTreeFindNode(_In_ PRB_TREE Tree, _In_ PVOID Key)
+{
+ INT32 result = 0;
+ PRB_TREE_NODE current = Tree->root;
+
+ while (current) {
+ result = Tree->compare(Key, current->object);
+
+ if (result == RB_TREE_EQUAL)
+ return current->object;
+ else if (result == RB_TREE_LESS_THAN)
+ current = current->left;
+ else
+ current = current->right;
+ }
+
+ return NULL;
+}
+
+STATIC
+VOID
+RtlpRbTreeEnumerate(_In_ PRB_TREE_NODE Node,
+ _In_ RB_ENUM_CALLBACK Callback,
+ _In_opt_ PVOID Context)
+{
+ if (Node == NULL)
+ return;
+
+ RtlpRbTreeEnumerate(Node->left, Callback, Context);
+ Callback(Node->object, Context);
+ RtlpRbTreeEnumerate(Node->right, Callback, Context);
+}
+
+VOID
+RtlRbTreeEnumerate(_In_ PRB_TREE Tree,
+ _In_ RB_ENUM_CALLBACK Callback,
+ _In_opt_ PVOID Context)
+{
+ if (Tree->root == NULL)
+ return;
+
+ RtlRbTreeAcquireLock(Tree);
+ RtlpRbTreeEnumerate(Tree->root, Callback, Context);
+ RtlRbTreeReleaselock(Tree);
+}
+
+STATIC
+VOID
+RtlpPrintInOrder(PRB_TREE_NODE Node)
+{
+ if (Node == NULL)
+ return;
+
+ RtlpPrintInOrder(Node->left);
+
+ const char* color = (Node->colour == red) ? "Red" : "Black";
+ DbgPrintEx(DPFLTR_DEFAULT_ID,
+ DPFLTR_INFO_LEVEL,
+ "Node: Key=%p, Color=%s\n",
+ *((PHANDLE)Node->object),
+ color);
+
+ RtlpPrintInOrder(Node->right);
+}
+
+/* assumes lock is held */
+VOID
+RtlRbTreeInOrderPrint(_In_ PRB_TREE Tree)
+{
+ DEBUG_ERROR("*************************************************");
+ DEBUG_ERROR("<><><><>STARTING IN ORDER PRINT <><><><><><");
+ RtlpPrintInOrder(Tree->root);
+ DEBUG_ERROR("<><><><>ENDING IN ORDER PRINT <><><><><><");
+ DEBUG_ERROR("*************************************************");
+}
+
+STATIC
+VOID
+RtlpRbTreeDeleteSubtree(_In_ PRB_TREE Tree, _In_ PRB_TREE_NODE Node)
+{
+ if (Node == NULL)
+ return;
+
+ RtlpRbTreeDeleteSubtree(Tree, Node->left);
+ RtlpRbTreeDeleteSubtree(Tree, Node->right);
+
+ ExFreeToLookasideListEx(&Tree->pool, Node);
+}
+
+VOID
+RtlRbTreeDeleteTree(_In_ PRB_TREE Tree)
+{
+ Tree->active = FALSE;
+
+ RtlRbTreeAcquireLock(Tree);
+ RtlpRbTreeDeleteSubtree(Tree, Tree->root);
+ ExDeleteLookasideListEx(&Tree->pool);
+ RtlRbTreeReleaselock(Tree);
+}
\ No newline at end of file
diff --git a/driver/tree.h b/driver/tree.h
new file mode 100644
index 0000000..4ebda95
--- /dev/null
+++ b/driver/tree.h
@@ -0,0 +1,75 @@
+#ifndef TREE_H
+#define TREE_H
+
+#include "common.h"
+
+#define RB_TREE_EQUAL 0
+#define RB_TREE_LESS_THAN 1
+#define RB_TREE_GREATER_THAN 2
+
+typedef enum _COLOUR { red, black } COLOUR;
+
+typedef struct _RB_TREE_NODE {
+ struct _RB_TREE_NODE* parent;
+ struct _RB_TREE_NODE* left;
+ struct _RB_TREE_NODE* right;
+ COLOUR colour;
+ CHAR object[];
+} RB_TREE_NODE, *PRB_TREE_NODE;
+
+typedef UINT32 (*RB_COMPARE)(_In_ PVOID Key, _In_ PVOID Object);
+
+typedef struct _RB_TREE {
+ PRB_TREE_NODE root;
+ KGUARDED_MUTEX lock;
+ RB_COMPARE compare;
+ LOOKASIDE_LIST_EX pool;
+ UINT32 object_size;
+ UINT32 active;
+} RB_TREE, *PRB_TREE;
+
+typedef VOID (*RB_CALLBACK)(PRB_TREE_NODE Node);
+typedef VOID (*RB_ENUM_CALLBACK)(_In_ PVOID Object, _In_opt_ PVOID Context);
+
+PVOID
+RtlRbTreeInsertNode(_In_ PRB_TREE Tree, _In_ PVOID Key);
+
+NTSTATUS
+RtlRbTreeCreate(_In_ RB_COMPARE Compare,
+ _In_ UINT32 ObjectSize,
+ _Out_ PRB_TREE Tree);
+
+VOID
+RtlRbTreeDeleteNode(_In_ PRB_TREE Tree, _In_ PVOID Key);
+
+PVOID
+RtlRbTreeFindNode(_In_ PRB_TREE Tree, _In_ PVOID Key);
+
+VOID
+RtlRbTreeEnumerate(_In_ PRB_TREE Tree,
+ _In_ RB_ENUM_CALLBACK Callback,
+ _In_opt_ PVOID Context);
+
+VOID
+RtlRbTreeDeleteTree(_In_ PRB_TREE Tree);
+
+VOID
+RtlRbTreeInOrderPrint(_In_ PRB_TREE Tree);
+
+FORCEINLINE
+STATIC
+VOID
+RtlRbTreeAcquireLock(_Inout_ PRB_TREE Tree)
+{
+ KeAcquireGuardedMutex(&Tree->lock);
+}
+
+FORCEINLINE
+STATIC
+VOID
+RtlRbTreeReleaselock(_Inout_ PRB_TREE Tree)
+{
+ KeReleaseGuardedMutex(&Tree->lock);
+}
+
+#endif
\ No newline at end of file