From 2477ea004ee67b069e0cc37a092d23f1aa72d9f8 Mon Sep 17 00:00:00 2001 From: donnaskiez Date: Sun, 5 May 2024 23:58:36 +1000 Subject: [PATCH] more refactoring --- driver/modules.c | 223 ++++++++++++++++------------------------------- driver/modules.h | 10 +-- driver/thread.c | 2 +- driver/thread.h | 2 +- 4 files changed, 77 insertions(+), 160 deletions(-) diff --git a/driver/modules.c b/driver/modules.c index 568c59e..cd0edb3 100644 --- a/driver/modules.c +++ b/driver/modules.c @@ -128,7 +128,6 @@ ValidateThreadViaKernelApcCallback(_In_ PTHREAD_LIST_ENTRY ThreadListEntry, # pragma alloc_text(PAGE, ApcRundownRoutine) # pragma alloc_text(PAGE, ApcKernelRoutine) # pragma alloc_text(PAGE, ApcNormalRoutine) -# pragma alloc_text(PAGE, FlipKThreadMiscFlagsFlag) # pragma alloc_text(PAGE, ValidateThreadsViaKernelApc) # pragma alloc_text(PAGE, ValidateThreadViaKernelApcCallback) #endif @@ -524,16 +523,12 @@ end: * TODO: this probably doesnt need to return an NTSTATUS, we can just return a * boolean and remove the out variable. */ -NTSTATUS +BOOLEAN IsInstructionPointerInInvalidRegion(_In_ UINT64 RIP, - _In_ PSYSTEM_MODULES SystemModules, - _Out_ PBOOLEAN Result) + _In_ PSYSTEM_MODULES SystemModules) { PAGED_CODE(); - if (!RIP || !SystemModules || !Result) - return STATUS_INVALID_PARAMETER; - PRTL_MODULE_EXTENDED_INFO modules = (PRTL_MODULE_EXTENDED_INFO)SystemModules->address; @@ -543,17 +538,15 @@ IsInstructionPointerInInvalidRegion(_In_ UINT64 RIP, UINT64 end = base + modules[index].ImageSize; if (RIP >= base && RIP <= end) { - *Result = TRUE; - return STATUS_SUCCESS; + return FALSE; } } - *Result = FALSE; - return STATUS_SUCCESS; + return TRUE; } BOOLEAN -IsInstructionPointerInsideModule(_In_ UINT64 Rip, +IsInstructionPointerInsideSpecifiedModule(_In_ UINT64 Rip, _In_ PRTL_MODULE_EXTENDED_INFO Module) { UINT64 base = (UINT64)Module->ImageBase; @@ -685,21 +678,15 @@ AnalyseNmiData(_In_ PNMI_CONTEXT NmiContext, _In_ PSYSTEM_MODULES SystemModules) * PsGetNextProcess ? */ - if (!ValidateThreadsPspCidTableEntry(NmiContext[core].kthread)) { + if (!DoesThreadHaveValidCidEntry(NmiContext[core].kthread)) { ReportMissingCidTableEntry(&NmiContext[core]); } if (NmiContext[core].user_thread) continue; - status = IsInstructionPointerInInvalidRegion( - NmiContext[core].interrupted_rip, SystemModules, &flag); - - if (!NT_SUCCESS(status)) { - continue; - } - - if (!flag) + if (IsInstructionPointerInInvalidRegion( + NmiContext[core].interrupted_rip, SystemModules)) ReportInvalidRipFoundDuringNmi(&NmiContext[core]); } @@ -971,14 +958,8 @@ ApcKernelRoutine(_In_ PRKAPC Apc, * structure that we passed into KeInitializeApc as the last * argument. */ - status = IsInstructionPointerInInvalidRegion( - frames[index], context->modules, &flag); - - if (!NT_SUCCESS(status)) { - goto free; - } - - if (!flag) + if (IsInstructionPointerInInvalidRegion(frames[index], + context->modules)) ReportApcStackwalkViolation(frames[index]); } @@ -1248,12 +1229,8 @@ ValidateDpcStackFrame(_In_ PDPC_CONTEXT Context, _In_ PSYSTEM_MODULES Modules) for (UINT32 frame = 0; frame < Context->frames_captured; frame++) { UINT64 rip = Context->stack_frame[frame]; - status = IsInstructionPointerInInvalidRegion(rip, Modules, &flag); - if (!NT_SUCCESS(status)) - return; - - if (!flag) + if (IsInstructionPointerInInvalidRegion(rip, Modules)) ReportDpcStackwalkViolation(Context, rip); } } @@ -1335,17 +1312,7 @@ ValidateTableDispatchRoutines(_In_ PVOID* Base, if (!Base[index]) continue; - status = - IsInstructionPointerInInvalidRegion(Base[index], Modules, &flag); - - if (!NT_SUCCESS(status)) { - DEBUG_ERROR( - "IsInstructionPointerInInvalidRegion failed with status %x", - status); - continue; - } - - if (!flag) + if (IsInstructionPointerInInvalidRegion(Base[index], Modules)) *Routine = Base[index]; } @@ -1431,147 +1398,103 @@ ValidateHalDispatchTable(_Out_ PVOID* Routine, _In_ PSYSTEM_MODULES Modules) * * What if there are 2 invalid routines? hmm.. tink. */ - status = IsInstructionPointerInInvalidRegion( - HalQuerySystemInformation, Modules, &flag); - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalQuerySystemInformation, + Modules)) { *Routine = HalQuerySystemInformation; - else - return status; + goto end; + } - status = IsInstructionPointerInInvalidRegion( - HalSetSystemInformation, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalSetSystemInformation, Modules)) { *Routine = HalSetSystemInformation; - else - return status; + goto end; + } - status = - IsInstructionPointerInInvalidRegion(HalQueryBusSlots, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalQueryBusSlots, Modules)) { *Routine = HalQueryBusSlots; - else - return status; + goto end; + } - status = IsInstructionPointerInInvalidRegion( - HalReferenceHandlerForBus, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalReferenceHandlerForBus, + Modules)) { *Routine = HalReferenceHandlerForBus; - else - return status; + goto end; + } - status = IsInstructionPointerInInvalidRegion( - HalReferenceBusHandler, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalReferenceBusHandler, Modules)) { *Routine = HalReferenceBusHandler; - else - return status; + goto end; + } - status = IsInstructionPointerInInvalidRegion( - HalDereferenceBusHandler, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalDereferenceBusHandler, + Modules)) { *Routine = HalDereferenceBusHandler; - else - return status; + goto end; + } - status = - IsInstructionPointerInInvalidRegion(HalInitPnpDriver, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalInitPnpDriver, Modules)) { *Routine = HalInitPnpDriver; - else - return status; + goto end; + } - status = IsInstructionPointerInInvalidRegion( - HalInitPowerManagement, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalInitPowerManagement, Modules)) { *Routine = HalInitPowerManagement; - else - return status; + goto end; + } - status = - IsInstructionPointerInInvalidRegion(HalGetDmaAdapter, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalGetDmaAdapter, Modules)) { *Routine = HalGetDmaAdapter; - else - return status; + goto end; + } - status = IsInstructionPointerInInvalidRegion( - HalGetInterruptTranslator, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalGetInterruptTranslator, + Modules)) { *Routine = HalGetInterruptTranslator; - else - return status; + goto end; + } - status = - IsInstructionPointerInInvalidRegion(HalStartMirroring, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalStartMirroring, Modules)) { *Routine = HalStartMirroring; - else - return status; + goto end; + } - status = - IsInstructionPointerInInvalidRegion(HalEndMirroring, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalEndMirroring, Modules)) { *Routine = HalEndMirroring; - else - return status; + goto end; + } - status = IsInstructionPointerInInvalidRegion( - HalMirrorPhysicalMemory, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalMirrorPhysicalMemory, Modules)) { *Routine = HalMirrorPhysicalMemory; - else - return status; + goto end; + } - status = IsInstructionPointerInInvalidRegion(HalEndOfBoot, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalEndOfBoot, Modules)) { *Routine = HalEndOfBoot; - else - return status; + goto end; + } - status = - IsInstructionPointerInInvalidRegion(HalMirrorVerify, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalMirrorVerify, Modules)) { *Routine = HalMirrorVerify; - else - return status; + goto end; + } - status = IsInstructionPointerInInvalidRegion( - HalGetCachedAcpiTable, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalGetCachedAcpiTable, Modules)) { *Routine = HalGetCachedAcpiTable; - else - return status; + goto end; + } - status = IsInstructionPointerInInvalidRegion( - HalSetPciErrorHandlerCallback, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalSetPciErrorHandlerCallback, + Modules)) { *Routine = HalSetPciErrorHandlerCallback; - else - return status; + goto end; + } - status = - IsInstructionPointerInInvalidRegion(HalGetPrmCache, Modules, &flag); - - if (!flag && NT_SUCCESS(status)) + if (IsInstructionPointerInInvalidRegion(HalGetPrmCache, Modules)) { *Routine = HalGetPrmCache; + goto end; + } +end: return status; } @@ -1996,7 +1919,7 @@ ValidateWin32kBase_gDxgInterface() DEBUG_INFO("regular entry: %p", dxg_interface[index]); #endif - if (!IsInstructionPointerInsideModule(entry, dxgkrnl)) { + if (!IsInstructionPointerInsideSpecifiedModule(entry, dxgkrnl)) { DEBUG_ERROR("invalid entry!!!"); ReportWin32kBase_DxgInterfaceViolation(index, entry); } diff --git a/driver/modules.h b/driver/modules.h index 357831f..074e7cd 100644 --- a/driver/modules.h +++ b/driver/modules.h @@ -57,15 +57,9 @@ ValidateThreadsViaKernelApc(); VOID FreeApcStackwalkApcContextInformation(_Inout_ PAPC_STACKWALK_CONTEXT Context); -NTSTATUS +BOOLEAN IsInstructionPointerInInvalidRegion(_In_ UINT64 RIP, - _In_ PSYSTEM_MODULES SystemModules, - _Out_ PBOOLEAN Result); - -VOID -FlipKThreadMiscFlagsFlag(_In_ PKTHREAD Thread, - _In_ ULONG FlagIndex, - _In_ BOOLEAN NewValue); + _In_ PSYSTEM_MODULES SystemModules); PVOID FindDriverBaseNoApi(_In_ PDRIVER_OBJECT DriverObject, _In_ PWCH Name); diff --git a/driver/thread.c b/driver/thread.c index 9ea7212..c6e775c 100644 --- a/driver/thread.c +++ b/driver/thread.c @@ -15,7 +15,7 @@ #endif BOOLEAN -ValidateThreadsPspCidTableEntry(_In_ PETHREAD Thread) +DoesThreadHaveValidCidEntry(_In_ PETHREAD Thread) { PAGED_CODE(); diff --git a/driver/thread.h b/driver/thread.h index 09237f5..8109dac 100644 --- a/driver/thread.h +++ b/driver/thread.h @@ -7,7 +7,7 @@ #include "callbacks.h" BOOLEAN -ValidateThreadsPspCidTableEntry(_In_ PETHREAD Thread); +DoesThreadHaveValidCidEntry(_In_ PETHREAD Thread); VOID DetectThreadsAttachedToProtectedProcess();