more refactoring

This commit is contained in:
donnaskiez 2024-05-05 23:58:36 +10:00
parent 8685725415
commit 2477ea004e
4 changed files with 77 additions and 160 deletions

View file

@ -128,7 +128,6 @@ ValidateThreadViaKernelApcCallback(_In_ PTHREAD_LIST_ENTRY ThreadListEntry,
# pragma alloc_text(PAGE, ApcRundownRoutine) # pragma alloc_text(PAGE, ApcRundownRoutine)
# pragma alloc_text(PAGE, ApcKernelRoutine) # pragma alloc_text(PAGE, ApcKernelRoutine)
# pragma alloc_text(PAGE, ApcNormalRoutine) # pragma alloc_text(PAGE, ApcNormalRoutine)
# pragma alloc_text(PAGE, FlipKThreadMiscFlagsFlag)
# pragma alloc_text(PAGE, ValidateThreadsViaKernelApc) # pragma alloc_text(PAGE, ValidateThreadsViaKernelApc)
# pragma alloc_text(PAGE, ValidateThreadViaKernelApcCallback) # pragma alloc_text(PAGE, ValidateThreadViaKernelApcCallback)
#endif #endif
@ -524,16 +523,12 @@ end:
* TODO: this probably doesnt need to return an NTSTATUS, we can just return a * TODO: this probably doesnt need to return an NTSTATUS, we can just return a
* boolean and remove the out variable. * boolean and remove the out variable.
*/ */
NTSTATUS BOOLEAN
IsInstructionPointerInInvalidRegion(_In_ UINT64 RIP, IsInstructionPointerInInvalidRegion(_In_ UINT64 RIP,
_In_ PSYSTEM_MODULES SystemModules, _In_ PSYSTEM_MODULES SystemModules)
_Out_ PBOOLEAN Result)
{ {
PAGED_CODE(); PAGED_CODE();
if (!RIP || !SystemModules || !Result)
return STATUS_INVALID_PARAMETER;
PRTL_MODULE_EXTENDED_INFO modules = PRTL_MODULE_EXTENDED_INFO modules =
(PRTL_MODULE_EXTENDED_INFO)SystemModules->address; (PRTL_MODULE_EXTENDED_INFO)SystemModules->address;
@ -543,17 +538,15 @@ IsInstructionPointerInInvalidRegion(_In_ UINT64 RIP,
UINT64 end = base + modules[index].ImageSize; UINT64 end = base + modules[index].ImageSize;
if (RIP >= base && RIP <= end) { if (RIP >= base && RIP <= end) {
*Result = TRUE; return FALSE;
return STATUS_SUCCESS;
} }
} }
*Result = FALSE; return TRUE;
return STATUS_SUCCESS;
} }
BOOLEAN BOOLEAN
IsInstructionPointerInsideModule(_In_ UINT64 Rip, IsInstructionPointerInsideSpecifiedModule(_In_ UINT64 Rip,
_In_ PRTL_MODULE_EXTENDED_INFO Module) _In_ PRTL_MODULE_EXTENDED_INFO Module)
{ {
UINT64 base = (UINT64)Module->ImageBase; UINT64 base = (UINT64)Module->ImageBase;
@ -685,21 +678,15 @@ AnalyseNmiData(_In_ PNMI_CONTEXT NmiContext, _In_ PSYSTEM_MODULES SystemModules)
* PsGetNextProcess ? * PsGetNextProcess ?
*/ */
if (!ValidateThreadsPspCidTableEntry(NmiContext[core].kthread)) { if (!DoesThreadHaveValidCidEntry(NmiContext[core].kthread)) {
ReportMissingCidTableEntry(&NmiContext[core]); ReportMissingCidTableEntry(&NmiContext[core]);
} }
if (NmiContext[core].user_thread) if (NmiContext[core].user_thread)
continue; continue;
status = IsInstructionPointerInInvalidRegion( if (IsInstructionPointerInInvalidRegion(
NmiContext[core].interrupted_rip, SystemModules, &flag); NmiContext[core].interrupted_rip, SystemModules))
if (!NT_SUCCESS(status)) {
continue;
}
if (!flag)
ReportInvalidRipFoundDuringNmi(&NmiContext[core]); ReportInvalidRipFoundDuringNmi(&NmiContext[core]);
} }
@ -971,14 +958,8 @@ ApcKernelRoutine(_In_ PRKAPC Apc,
* structure that we passed into KeInitializeApc as the last * structure that we passed into KeInitializeApc as the last
* argument. * argument.
*/ */
status = IsInstructionPointerInInvalidRegion( if (IsInstructionPointerInInvalidRegion(frames[index],
frames[index], context->modules, &flag); context->modules))
if (!NT_SUCCESS(status)) {
goto free;
}
if (!flag)
ReportApcStackwalkViolation(frames[index]); ReportApcStackwalkViolation(frames[index]);
} }
@ -1248,12 +1229,8 @@ ValidateDpcStackFrame(_In_ PDPC_CONTEXT Context, _In_ PSYSTEM_MODULES Modules)
for (UINT32 frame = 0; frame < Context->frames_captured; frame++) { for (UINT32 frame = 0; frame < Context->frames_captured; frame++) {
UINT64 rip = Context->stack_frame[frame]; UINT64 rip = Context->stack_frame[frame];
status = IsInstructionPointerInInvalidRegion(rip, Modules, &flag);
if (!NT_SUCCESS(status)) if (IsInstructionPointerInInvalidRegion(rip, Modules))
return;
if (!flag)
ReportDpcStackwalkViolation(Context, rip); ReportDpcStackwalkViolation(Context, rip);
} }
} }
@ -1335,17 +1312,7 @@ ValidateTableDispatchRoutines(_In_ PVOID* Base,
if (!Base[index]) if (!Base[index])
continue; continue;
status = if (IsInstructionPointerInInvalidRegion(Base[index], Modules))
IsInstructionPointerInInvalidRegion(Base[index], Modules, &flag);
if (!NT_SUCCESS(status)) {
DEBUG_ERROR(
"IsInstructionPointerInInvalidRegion failed with status %x",
status);
continue;
}
if (!flag)
*Routine = Base[index]; *Routine = Base[index];
} }
@ -1431,147 +1398,103 @@ ValidateHalDispatchTable(_Out_ PVOID* Routine, _In_ PSYSTEM_MODULES Modules)
* *
* What if there are 2 invalid routines? hmm.. tink. * What if there are 2 invalid routines? hmm.. tink.
*/ */
status = IsInstructionPointerInInvalidRegion(
HalQuerySystemInformation, Modules, &flag);
if (!flag && NT_SUCCESS(status)) if (IsInstructionPointerInInvalidRegion(HalQuerySystemInformation,
Modules)) {
*Routine = HalQuerySystemInformation; *Routine = HalQuerySystemInformation;
else goto end;
return status; }
status = IsInstructionPointerInInvalidRegion( if (IsInstructionPointerInInvalidRegion(HalSetSystemInformation, Modules)) {
HalSetSystemInformation, Modules, &flag);
if (!flag && NT_SUCCESS(status))
*Routine = HalSetSystemInformation; *Routine = HalSetSystemInformation;
else goto end;
return status; }
status = if (IsInstructionPointerInInvalidRegion(HalQueryBusSlots, Modules)) {
IsInstructionPointerInInvalidRegion(HalQueryBusSlots, Modules, &flag);
if (!flag && NT_SUCCESS(status))
*Routine = HalQueryBusSlots; *Routine = HalQueryBusSlots;
else goto end;
return status; }
status = IsInstructionPointerInInvalidRegion( if (IsInstructionPointerInInvalidRegion(HalReferenceHandlerForBus,
HalReferenceHandlerForBus, Modules, &flag); Modules)) {
if (!flag && NT_SUCCESS(status))
*Routine = HalReferenceHandlerForBus; *Routine = HalReferenceHandlerForBus;
else goto end;
return status; }
status = IsInstructionPointerInInvalidRegion( if (IsInstructionPointerInInvalidRegion(HalReferenceBusHandler, Modules)) {
HalReferenceBusHandler, Modules, &flag);
if (!flag && NT_SUCCESS(status))
*Routine = HalReferenceBusHandler; *Routine = HalReferenceBusHandler;
else goto end;
return status; }
status = IsInstructionPointerInInvalidRegion( if (IsInstructionPointerInInvalidRegion(HalDereferenceBusHandler,
HalDereferenceBusHandler, Modules, &flag); Modules)) {
if (!flag && NT_SUCCESS(status))
*Routine = HalDereferenceBusHandler; *Routine = HalDereferenceBusHandler;
else goto end;
return status; }
status = if (IsInstructionPointerInInvalidRegion(HalInitPnpDriver, Modules)) {
IsInstructionPointerInInvalidRegion(HalInitPnpDriver, Modules, &flag);
if (!flag && NT_SUCCESS(status))
*Routine = HalInitPnpDriver; *Routine = HalInitPnpDriver;
else goto end;
return status; }
status = IsInstructionPointerInInvalidRegion( if (IsInstructionPointerInInvalidRegion(HalInitPowerManagement, Modules)) {
HalInitPowerManagement, Modules, &flag);
if (!flag && NT_SUCCESS(status))
*Routine = HalInitPowerManagement; *Routine = HalInitPowerManagement;
else goto end;
return status; }
status = if (IsInstructionPointerInInvalidRegion(HalGetDmaAdapter, Modules)) {
IsInstructionPointerInInvalidRegion(HalGetDmaAdapter, Modules, &flag);
if (!flag && NT_SUCCESS(status))
*Routine = HalGetDmaAdapter; *Routine = HalGetDmaAdapter;
else goto end;
return status; }
status = IsInstructionPointerInInvalidRegion( if (IsInstructionPointerInInvalidRegion(HalGetInterruptTranslator,
HalGetInterruptTranslator, Modules, &flag); Modules)) {
if (!flag && NT_SUCCESS(status))
*Routine = HalGetInterruptTranslator; *Routine = HalGetInterruptTranslator;
else goto end;
return status; }
status = if (IsInstructionPointerInInvalidRegion(HalStartMirroring, Modules)) {
IsInstructionPointerInInvalidRegion(HalStartMirroring, Modules, &flag);
if (!flag && NT_SUCCESS(status))
*Routine = HalStartMirroring; *Routine = HalStartMirroring;
else goto end;
return status; }
status = if (IsInstructionPointerInInvalidRegion(HalEndMirroring, Modules)) {
IsInstructionPointerInInvalidRegion(HalEndMirroring, Modules, &flag);
if (!flag && NT_SUCCESS(status))
*Routine = HalEndMirroring; *Routine = HalEndMirroring;
else goto end;
return status; }
status = IsInstructionPointerInInvalidRegion( if (IsInstructionPointerInInvalidRegion(HalMirrorPhysicalMemory, Modules)) {
HalMirrorPhysicalMemory, Modules, &flag);
if (!flag && NT_SUCCESS(status))
*Routine = HalMirrorPhysicalMemory; *Routine = HalMirrorPhysicalMemory;
else goto end;
return status; }
status = IsInstructionPointerInInvalidRegion(HalEndOfBoot, Modules, &flag); if (IsInstructionPointerInInvalidRegion(HalEndOfBoot, Modules)) {
if (!flag && NT_SUCCESS(status))
*Routine = HalEndOfBoot; *Routine = HalEndOfBoot;
else goto end;
return status; }
status = if (IsInstructionPointerInInvalidRegion(HalMirrorVerify, Modules)) {
IsInstructionPointerInInvalidRegion(HalMirrorVerify, Modules, &flag);
if (!flag && NT_SUCCESS(status))
*Routine = HalMirrorVerify; *Routine = HalMirrorVerify;
else goto end;
return status; }
status = IsInstructionPointerInInvalidRegion( if (IsInstructionPointerInInvalidRegion(HalGetCachedAcpiTable, Modules)) {
HalGetCachedAcpiTable, Modules, &flag);
if (!flag && NT_SUCCESS(status))
*Routine = HalGetCachedAcpiTable; *Routine = HalGetCachedAcpiTable;
else goto end;
return status; }
status = IsInstructionPointerInInvalidRegion( if (IsInstructionPointerInInvalidRegion(HalSetPciErrorHandlerCallback,
HalSetPciErrorHandlerCallback, Modules, &flag); Modules)) {
if (!flag && NT_SUCCESS(status))
*Routine = HalSetPciErrorHandlerCallback; *Routine = HalSetPciErrorHandlerCallback;
else goto end;
return status; }
status = if (IsInstructionPointerInInvalidRegion(HalGetPrmCache, Modules)) {
IsInstructionPointerInInvalidRegion(HalGetPrmCache, Modules, &flag);
if (!flag && NT_SUCCESS(status))
*Routine = HalGetPrmCache; *Routine = HalGetPrmCache;
goto end;
}
end:
return status; return status;
} }
@ -1996,7 +1919,7 @@ ValidateWin32kBase_gDxgInterface()
DEBUG_INFO("regular entry: %p", dxg_interface[index]); DEBUG_INFO("regular entry: %p", dxg_interface[index]);
#endif #endif
if (!IsInstructionPointerInsideModule(entry, dxgkrnl)) { if (!IsInstructionPointerInsideSpecifiedModule(entry, dxgkrnl)) {
DEBUG_ERROR("invalid entry!!!"); DEBUG_ERROR("invalid entry!!!");
ReportWin32kBase_DxgInterfaceViolation(index, entry); ReportWin32kBase_DxgInterfaceViolation(index, entry);
} }

View file

@ -57,15 +57,9 @@ ValidateThreadsViaKernelApc();
VOID VOID
FreeApcStackwalkApcContextInformation(_Inout_ PAPC_STACKWALK_CONTEXT Context); FreeApcStackwalkApcContextInformation(_Inout_ PAPC_STACKWALK_CONTEXT Context);
NTSTATUS BOOLEAN
IsInstructionPointerInInvalidRegion(_In_ UINT64 RIP, IsInstructionPointerInInvalidRegion(_In_ UINT64 RIP,
_In_ PSYSTEM_MODULES SystemModules, _In_ PSYSTEM_MODULES SystemModules);
_Out_ PBOOLEAN Result);
VOID
FlipKThreadMiscFlagsFlag(_In_ PKTHREAD Thread,
_In_ ULONG FlagIndex,
_In_ BOOLEAN NewValue);
PVOID PVOID
FindDriverBaseNoApi(_In_ PDRIVER_OBJECT DriverObject, _In_ PWCH Name); FindDriverBaseNoApi(_In_ PDRIVER_OBJECT DriverObject, _In_ PWCH Name);

View file

@ -15,7 +15,7 @@
#endif #endif
BOOLEAN BOOLEAN
ValidateThreadsPspCidTableEntry(_In_ PETHREAD Thread) DoesThreadHaveValidCidEntry(_In_ PETHREAD Thread)
{ {
PAGED_CODE(); PAGED_CODE();

View file

@ -7,7 +7,7 @@
#include "callbacks.h" #include "callbacks.h"
BOOLEAN BOOLEAN
ValidateThreadsPspCidTableEntry(_In_ PETHREAD Thread); DoesThreadHaveValidCidEntry(_In_ PETHREAD Thread);
VOID VOID
DetectThreadsAttachedToProtectedProcess(); DetectThreadsAttachedToProtectedProcess();