From 90b1b683dd7e963a8fe49f307f3c9eb40f85d835 Mon Sep 17 00:00:00 2001 From: Nodir Temirkhodjaev Date: Fri, 6 Jan 2023 17:30:34 +0300 Subject: [PATCH] Driver: Simplify flows handling --- src/driver/common/fortconf.c | 4 +- src/driver/fortcout.c | 139 +++++++++++----------- src/driver/fortdev.c | 15 ++- src/driver/fortdev.h | 5 - src/driver/fortps.c | 4 +- src/driver/fortstat.c | 217 +++++++++++++++++++++++------------ src/driver/fortstat.h | 33 ++---- src/driver/fortwrk.c | 4 +- 8 files changed, 239 insertions(+), 182 deletions(-) diff --git a/src/driver/common/fortconf.c b/src/driver/common/fortconf.c index 07738f8b..65e41ec6 100644 --- a/src/driver/common/fortconf.c +++ b/src/driver/common/fortconf.c @@ -188,10 +188,10 @@ FORT_API BOOL fort_conf_app_exe_equal( static BOOL fort_conf_app_wild_equal( const PFORT_APP_ENTRY app_entry, const PVOID path, UINT32 path_len) { - const WCHAR *app_path = (const WCHAR *) (app_entry + 1); - UNUSED(path_len); + const WCHAR *app_path = (const WCHAR *) (app_entry + 1); + return wildmatch(app_path, (const WCHAR *) path) == WM_MATCH; } diff --git a/src/driver/fortcout.c b/src/driver/fortcout.c index 6663f0c9..256861b6 100644 --- a/src/driver/fortcout.c +++ b/src/driver/fortcout.c @@ -40,12 +40,13 @@ static void fort_callout_classify_continue(FWPS_CLASSIFY_OUT0 *classifyOut) static BOOL fort_callout_classify_blocked_log_stat(const FWPS_INCOMING_VALUES0 *inFixedValues, const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, const FWPS_FILTER0 *filter, - FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, int remoteIpField, - int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, BOOL inbound, - UINT32 classify_flags, const UINT32 *remote_ip, FORT_CONF_FLAGS conf_flags, + UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, + int remoteIpField, int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, + BOOL inbound, UINT32 classify_flags, const UINT32 *remote_ip, FORT_CONF_FLAGS conf_flags, UINT32 process_id, PCUNICODE_STRING real_path, PFORT_CONF_REF conf_ref, INT8 *block_reason, BOOL blocked, FORT_APP_FLAGS app_flags, PIRP *irp, ULONG_PTR *info) { + PFORT_FLOW flow = (PFORT_FLOW) flowContext; const UINT64 flow_id = inMetaValues->flowHandle; const IPPROTO ip_proto = (IPPROTO) inFixedValues->incomingValue[ipProtoField].value.uint8; @@ -56,7 +57,7 @@ static BOOL fort_callout_classify_blocked_log_stat(const FWPS_INCOMING_VALUES0 * BOOL is_new_proc = FALSE; - const NTSTATUS status = fort_flow_associate(&fort_device()->stat, flow_id, process_id, + const NTSTATUS status = fort_flow_associate(&fort_device()->stat, flow, flow_id, process_id, group_index, isIPv6, is_tcp, inbound, is_reauth, &is_new_proc); if (!NT_SUCCESS(status)) { @@ -77,9 +78,9 @@ static BOOL fort_callout_classify_blocked_log_stat(const FWPS_INCOMING_VALUES0 * static BOOL fort_callout_classify_blocked_log(const FWPS_INCOMING_VALUES0 *inFixedValues, const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, const FWPS_FILTER0 *filter, - FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, int remoteIpField, - int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, BOOL inbound, - UINT32 classify_flags, const UINT32 *remote_ip, FORT_CONF_FLAGS conf_flags, + UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, + int remoteIpField, int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, + BOOL inbound, UINT32 classify_flags, const UINT32 *remote_ip, FORT_CONF_FLAGS conf_flags, UINT32 process_id, PCUNICODE_STRING path, PCUNICODE_STRING real_path, PFORT_CONF_REF conf_ref, INT8 *block_reason, BOOL blocked, PIRP *irp, ULONG_PTR *info) { @@ -91,10 +92,10 @@ static BOOL fort_callout_classify_blocked_log(const FWPS_INCOMING_VALUES0 *inFix || !fort_conf_app_blocked(&conf_ref->conf, app_flags, block_reason)) { if (conf_flags.log_stat && fort_callout_classify_blocked_log_stat(inFixedValues, inMetaValues, filter, - classifyOut, flagsField, localIpField, remoteIpField, localPortField, - remotePortField, ipProtoField, isIPv6, inbound, classify_flags, remote_ip, - conf_flags, process_id, real_path, conf_ref, block_reason, blocked, - app_flags, irp, info)) + flowContext, classifyOut, flagsField, localIpField, remoteIpField, + localPortField, remotePortField, ipProtoField, isIPv6, inbound, + classify_flags, remote_ip, conf_flags, process_id, real_path, conf_ref, + block_reason, blocked, app_flags, irp, info)) return TRUE; /* blocked */ blocked = FALSE; /* allow */ @@ -118,9 +119,9 @@ static BOOL fort_callout_classify_blocked_log(const FWPS_INCOMING_VALUES0 *inFix static BOOL fort_callout_classify_blocked(const FWPS_INCOMING_VALUES0 *inFixedValues, const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, const FWPS_FILTER0 *filter, - FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, int remoteIpField, - int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, BOOL inbound, - UINT32 classify_flags, const UINT32 *remote_ip, FORT_CONF_FLAGS conf_flags, + UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, + int remoteIpField, int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, + BOOL inbound, UINT32 classify_flags, const UINT32 *remote_ip, FORT_CONF_FLAGS conf_flags, UINT32 process_id, PUNICODE_STRING path, PUNICODE_STRING real_path, PFORT_CONF_REF conf_ref, INT8 *block_reason, PIRP *irp, ULONG_PTR *info) { @@ -151,18 +152,18 @@ static BOOL fort_callout_classify_blocked(const FWPS_INCOMING_VALUES0 *inFixedVa blocked = FALSE; } - return fort_callout_classify_blocked_log(inFixedValues, inMetaValues, filter, classifyOut, - flagsField, localIpField, remoteIpField, localPortField, remotePortField, ipProtoField, - isIPv6, inbound, classify_flags, remote_ip, conf_flags, process_id, path, real_path, - conf_ref, block_reason, blocked, irp, info); + return fort_callout_classify_blocked_log(inFixedValues, inMetaValues, filter, flowContext, + classifyOut, flagsField, localIpField, remoteIpField, localPortField, remotePortField, + ipProtoField, isIPv6, inbound, classify_flags, remote_ip, conf_flags, process_id, path, + real_path, conf_ref, block_reason, blocked, irp, info); } static void fort_callout_classify_check(const FWPS_INCOMING_VALUES0 *inFixedValues, const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, const FWPS_FILTER0 *filter, - FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, int remoteIpField, - int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, BOOL inbound, - UINT32 classify_flags, const UINT32 *remote_ip, PFORT_CONF_REF conf_ref, PIRP *irp, - ULONG_PTR *info) + UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, + int remoteIpField, int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, + BOOL inbound, UINT32 classify_flags, const UINT32 *remote_ip, PFORT_CONF_REF conf_ref, + PIRP *irp, ULONG_PTR *info) { const FORT_CONF_FLAGS conf_flags = conf_ref->conf.flags; @@ -184,9 +185,9 @@ static void fort_callout_classify_check(const FWPS_INCOMING_VALUES0 *inFixedValu INT8 block_reason = FORT_BLOCK_REASON_UNKNOWN; const BOOL blocked = fort_callout_classify_blocked(inFixedValues, inMetaValues, filter, - classifyOut, flagsField, localIpField, remoteIpField, localPortField, remotePortField, - ipProtoField, isIPv6, inbound, classify_flags, remote_ip, conf_flags, process_id, &path, - &real_path, conf_ref, &block_reason, irp, info); + flowContext, classifyOut, flagsField, localIpField, remoteIpField, localPortField, + remotePortField, ipProtoField, isIPv6, inbound, classify_flags, remote_ip, conf_flags, + process_id, &path, &real_path, conf_ref, &block_reason, irp, info); if (blocked) { /* Log the blocked connection */ @@ -220,8 +221,9 @@ static void fort_callout_classify_check(const FWPS_INCOMING_VALUES0 *inFixedValu static void fort_callout_classify(const FWPS_INCOMING_VALUES0 *inFixedValues, const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, const FWPS_FILTER0 *filter, - FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, int remoteIpField, - int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, BOOL inbound) + UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, + int remoteIpField, int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, + BOOL inbound) { PIRP irp = NULL; ULONG_PTR info; @@ -252,9 +254,9 @@ static void fort_callout_classify(const FWPS_INCOMING_VALUES0 *inFixedValues, return; } - fort_callout_classify_check(inFixedValues, inMetaValues, filter, classifyOut, flagsField, - localIpField, remoteIpField, localPortField, remotePortField, ipProtoField, isIPv6, - inbound, classify_flags, remote_ip, conf_ref, &irp, &info); + fort_callout_classify_check(inFixedValues, inMetaValues, filter, flowContext, classifyOut, + flagsField, localIpField, remoteIpField, localPortField, remotePortField, ipProtoField, + isIPv6, inbound, classify_flags, remote_ip, conf_ref, &irp, &info); fort_conf_ref_put(device_conf, conf_ref); @@ -265,12 +267,11 @@ static void fort_callout_classify(const FWPS_INCOMING_VALUES0 *inFixedValues, static void NTAPI fort_callout_connect_v4(const FWPS_INCOMING_VALUES0 *inFixedValues, const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, void *layerData, - const FWPS_FILTER0 *filter, const UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut) + const FWPS_FILTER0 *filter, UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut) { UNUSED(layerData); - UNUSED(flowContext); - fort_callout_classify(inFixedValues, inMetaValues, filter, classifyOut, + fort_callout_classify(inFixedValues, inMetaValues, filter, flowContext, classifyOut, FWPS_FIELD_ALE_AUTH_CONNECT_V4_FLAGS, FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_LOCAL_ADDRESS, FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_REMOTE_ADDRESS, FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_LOCAL_PORT, @@ -280,12 +281,11 @@ static void NTAPI fort_callout_connect_v4(const FWPS_INCOMING_VALUES0 *inFixedVa static void NTAPI fort_callout_connect_v6(const FWPS_INCOMING_VALUES0 *inFixedValues, const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, void *layerData, - const FWPS_FILTER0 *filter, const UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut) + const FWPS_FILTER0 *filter, UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut) { UNUSED(layerData); - UNUSED(flowContext); - fort_callout_classify(inFixedValues, inMetaValues, filter, classifyOut, + fort_callout_classify(inFixedValues, inMetaValues, filter, flowContext, classifyOut, FWPS_FIELD_ALE_AUTH_CONNECT_V6_FLAGS, FWPS_FIELD_ALE_AUTH_CONNECT_V6_IP_LOCAL_ADDRESS, FWPS_FIELD_ALE_AUTH_CONNECT_V6_IP_REMOTE_ADDRESS, FWPS_FIELD_ALE_AUTH_CONNECT_V6_IP_LOCAL_PORT, @@ -295,12 +295,11 @@ static void NTAPI fort_callout_connect_v6(const FWPS_INCOMING_VALUES0 *inFixedVa static void NTAPI fort_callout_accept_v4(const FWPS_INCOMING_VALUES0 *inFixedValues, const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, void *layerData, - const FWPS_FILTER0 *filter, const UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut) + const FWPS_FILTER0 *filter, UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut) { UNUSED(layerData); - UNUSED(flowContext); - fort_callout_classify(inFixedValues, inMetaValues, filter, classifyOut, + fort_callout_classify(inFixedValues, inMetaValues, filter, flowContext, classifyOut, FWPS_FIELD_ALE_AUTH_RECV_ACCEPT_V4_FLAGS, FWPS_FIELD_ALE_AUTH_RECV_ACCEPT_V4_IP_LOCAL_ADDRESS, FWPS_FIELD_ALE_AUTH_RECV_ACCEPT_V4_IP_REMOTE_ADDRESS, @@ -311,12 +310,11 @@ static void NTAPI fort_callout_accept_v4(const FWPS_INCOMING_VALUES0 *inFixedVal static void NTAPI fort_callout_accept_v6(const FWPS_INCOMING_VALUES0 *inFixedValues, const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, void *layerData, - const FWPS_FILTER0 *filter, const UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut) + const FWPS_FILTER0 *filter, UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut) { UNUSED(layerData); - UNUSED(flowContext); - fort_callout_classify(inFixedValues, inMetaValues, filter, classifyOut, + fort_callout_classify(inFixedValues, inMetaValues, filter, flowContext, classifyOut, FWPS_FIELD_ALE_AUTH_RECV_ACCEPT_V6_FLAGS, FWPS_FIELD_ALE_AUTH_RECV_ACCEPT_V6_IP_LOCAL_ADDRESS, FWPS_FIELD_ALE_AUTH_RECV_ACCEPT_V6_IP_REMOTE_ADDRESS, @@ -338,10 +336,10 @@ static NTSTATUS NTAPI fort_callout_notify( static void fort_callout_flow_classify(const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut, UINT32 dataSize, BOOL inbound) { - const UINT32 headerSize = inbound ? inMetaValues->transportHeaderSize : 0; - UNUSED(classifyOut); + const UINT32 headerSize = inbound ? inMetaValues->transportHeaderSize : 0; + fort_flow_classify(&fort_device()->stat, flowContext, headerSize + dataSize, inbound); } @@ -349,18 +347,17 @@ static void NTAPI fort_callout_stream_classify(const FWPS_INCOMING_VALUES0 *inFi const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, FWPS_STREAM_CALLOUT_IO_PACKET0 *packet, const FWPS_FILTER0 *filter, UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut) { + UNUSED(inFixedValues); + const FWPS_STREAM_DATA0 *streamData = packet->streamData; const UINT32 streamFlags = streamData->flags; const UINT32 dataSize = (UINT32) streamData->dataLength; const BOOL inbound = (streamFlags & FWPS_STREAM_FLAG_RECEIVE) != 0; - UNUSED(inFixedValues); - fort_callout_flow_classify(inMetaValues, flowContext, classifyOut, dataSize, inbound); fort_callout_classify_permit(filter, classifyOut); - return; } static void fort_callout_datagram_classify(const FWPS_INCOMING_VALUES0 *inFixedValues, @@ -449,7 +446,7 @@ FORT_API NTSTATUS fort_callout_install(PDEVICE_OBJECT device) c.calloutKey = FORT_GUID_CALLOUT_CONNECT_V4; c.classifyFn = (FWPS_CALLOUT_CLASSIFY_FN0) fort_callout_connect_v4; - status = FwpsCalloutRegister0(device, &c, &fort_device()->connect4_id); + status = FwpsCalloutRegister0(device, &c, &stat->connect4_id); if (!NT_SUCCESS(status)) { LOG("Register Connect V4: Error: %x\n", status); TRACE(FORT_CALLOUT_REGISTER_CONNECT_V4_ERROR, status, 0, 0); @@ -460,7 +457,7 @@ FORT_API NTSTATUS fort_callout_install(PDEVICE_OBJECT device) c.calloutKey = FORT_GUID_CALLOUT_CONNECT_V6; c.classifyFn = (FWPS_CALLOUT_CLASSIFY_FN0) fort_callout_connect_v6; - status = FwpsCalloutRegister0(device, &c, &fort_device()->connect6_id); + status = FwpsCalloutRegister0(device, &c, &stat->connect6_id); if (!NT_SUCCESS(status)) { LOG("Register Connect V6: Error: %x\n", status); TRACE(FORT_CALLOUT_REGISTER_CONNECT_V6_ERROR, status, 0, 0); @@ -471,7 +468,7 @@ FORT_API NTSTATUS fort_callout_install(PDEVICE_OBJECT device) c.calloutKey = FORT_GUID_CALLOUT_ACCEPT_V4; c.classifyFn = (FWPS_CALLOUT_CLASSIFY_FN0) fort_callout_accept_v4; - status = FwpsCalloutRegister0(device, &c, &fort_device()->accept4_id); + status = FwpsCalloutRegister0(device, &c, &stat->accept4_id); if (!NT_SUCCESS(status)) { LOG("Register Accept V4: Error: %x\n", status); TRACE(FORT_CALLOUT_REGISTER_ACCEPT_V4_ERROR, status, 0, 0); @@ -482,7 +479,7 @@ FORT_API NTSTATUS fort_callout_install(PDEVICE_OBJECT device) c.calloutKey = FORT_GUID_CALLOUT_ACCEPT_V6; c.classifyFn = (FWPS_CALLOUT_CLASSIFY_FN0) fort_callout_accept_v6; - status = FwpsCalloutRegister0(device, &c, &fort_device()->accept6_id); + status = FwpsCalloutRegister0(device, &c, &stat->accept6_id); if (!NT_SUCCESS(status)) { LOG("Register Accept V6: Error: %x\n", status); TRACE(FORT_CALLOUT_REGISTER_ACCEPT_V6_ERROR, status, 0, 0); @@ -602,62 +599,62 @@ FORT_API void fort_callout_remove(void) { PFORT_STAT stat = &fort_device()->stat; - if (fort_device()->connect4_id) { - FwpsCalloutUnregisterById0(fort_device()->connect4_id); - fort_device()->connect4_id = 0; + if (stat->connect4_id != 0) { + FwpsCalloutUnregisterById0(stat->connect4_id); + stat->connect4_id = 0; } - if (fort_device()->connect6_id) { - FwpsCalloutUnregisterById0(fort_device()->connect6_id); - fort_device()->connect6_id = 0; + if (stat->connect6_id != 0) { + FwpsCalloutUnregisterById0(stat->connect6_id); + stat->connect6_id = 0; } - if (fort_device()->accept4_id) { - FwpsCalloutUnregisterById0(fort_device()->accept4_id); - fort_device()->accept4_id = 0; + if (stat->accept4_id != 0) { + FwpsCalloutUnregisterById0(stat->accept4_id); + stat->accept4_id = 0; } - if (fort_device()->accept6_id) { - FwpsCalloutUnregisterById0(fort_device()->accept6_id); - fort_device()->accept6_id = 0; + if (stat->accept6_id != 0) { + FwpsCalloutUnregisterById0(stat->accept6_id); + stat->accept6_id = 0; } - if (stat->stream4_id) { + if (stat->stream4_id != 0) { FwpsCalloutUnregisterById0(stat->stream4_id); stat->stream4_id = 0; } - if (stat->stream6_id) { + if (stat->stream6_id != 0) { FwpsCalloutUnregisterById0(stat->stream6_id); stat->stream6_id = 0; } - if (stat->datagram4_id) { + if (stat->datagram4_id != 0) { FwpsCalloutUnregisterById0(stat->datagram4_id); stat->datagram4_id = 0; } - if (stat->datagram6_id) { + if (stat->datagram6_id != 0) { FwpsCalloutUnregisterById0(stat->datagram6_id); stat->datagram6_id = 0; } - if (stat->in_transport4_id) { + if (stat->in_transport4_id != 0) { FwpsCalloutUnregisterById0(stat->in_transport4_id); stat->in_transport4_id = 0; } - if (stat->in_transport6_id) { + if (stat->in_transport6_id != 0) { FwpsCalloutUnregisterById0(stat->in_transport6_id); stat->in_transport6_id = 0; } - if (stat->out_transport4_id) { + if (stat->out_transport4_id != 0) { FwpsCalloutUnregisterById0(stat->out_transport4_id); stat->out_transport4_id = 0; } - if (stat->out_transport6_id) { + if (stat->out_transport6_id != 0) { FwpsCalloutUnregisterById0(stat->out_transport6_id); stat->out_transport6_id = 0; } diff --git a/src/driver/fortdev.c b/src/driver/fortdev.c index 041aa969..fea200a5 100644 --- a/src/driver/fortdev.c +++ b/src/driver/fortdev.c @@ -52,10 +52,10 @@ FORT_API void NTAPI fort_app_period_timer(void) FORT_API NTSTATUS fort_device_create(PDEVICE_OBJECT device, PIRP irp) { - NTSTATUS status = STATUS_SUCCESS; - UNUSED(device); + NTSTATUS status = STATUS_SUCCESS; + /* Device opened */ if (fort_device_flag_set(&g_device->conf, FORT_DEVICE_IS_OPENED, TRUE) & FORT_DEVICE_IS_OPENED) { @@ -111,12 +111,11 @@ FORT_API NTSTATUS fort_device_cleanup(PDEVICE_OBJECT device, PIRP irp) static void fort_device_cancel_pending(PDEVICE_OBJECT device, PIRP irp) { - ULONG_PTR info; - NTSTATUS status; - UNUSED(device); - status = fort_buffer_cancel_pending(&g_device->buffer, irp, &info); + ULONG_PTR info; + + const NTSTATUS status = fort_buffer_cancel_pending(&g_device->buffer, irp, &info); IoReleaseCancelSpinLock(irp->CancelIrql); /* before IoCompleteRequest()! */ @@ -290,10 +289,10 @@ static NTSTATUS fort_device_control_process( FORT_API NTSTATUS fort_device_control(PDEVICE_OBJECT device, PIRP irp) { - ULONG_PTR info = 0; - UNUSED(device); + ULONG_PTR info = 0; + const PIO_STACK_LOCATION irp_stack = IoGetCurrentIrpStackLocation(irp); const NTSTATUS status = fort_device_control_process(irp_stack, irp, &info); diff --git a/src/driver/fortdev.h b/src/driver/fortdev.h index ca35fefc..495c7203 100644 --- a/src/driver/fortdev.h +++ b/src/driver/fortdev.h @@ -13,11 +13,6 @@ typedef struct fort_device { - UINT32 connect4_id; - UINT32 connect6_id; - UINT32 accept4_id; - UINT32 accept6_id; - PDEVICE_OBJECT object; PCALLBACK_OBJECT power_cb_obj; diff --git a/src/driver/fortps.c b/src/driver/fortps.c index 55576f41..64f5622b 100644 --- a/src/driver/fortps.c +++ b/src/driver/fortps.c @@ -585,10 +585,10 @@ static void fort_pstree_handle_new_proc(PFORT_PSTREE ps_tree, PCUNICODE_STRING p static void NTAPI fort_pstree_notify( PEPROCESS process, HANDLE processHandle, PPS_CREATE_NOTIFY_INFO createInfo) { - PFORT_PSTREE ps_tree = &fort_device()->ps_tree; - UNUSED(process); + PFORT_PSTREE ps_tree = &fort_device()->ps_tree; + const DWORD processId = (DWORD) (ptrdiff_t) processHandle; const tommy_key_t pid_hash = (tommy_key_t) tommy_hash_u32(0, &processId, sizeof(DWORD)); diff --git a/src/driver/fortstat.c b/src/driver/fortstat.c index 5806d632..e8d6937b 100644 --- a/src/driver/fortstat.c +++ b/src/driver/fortstat.c @@ -124,6 +124,28 @@ FORT_API UCHAR fort_flow_flags(PFORT_FLOW flow) return fort_flow_flags_set(flow, 0, TRUE); } +static void fort_flow_context_flow_init( + PFORT_STAT stat, BOOL isIPv6, BOOL inbound, UINT16 *layerId, UINT32 *calloutId) +{ + if (inbound) { + if (isIPv6) { + *layerId = FWPS_LAYER_ALE_AUTH_RECV_ACCEPT_V6; + *calloutId = stat->accept6_id; + } else { + *layerId = FWPS_LAYER_ALE_AUTH_RECV_ACCEPT_V4; + *calloutId = stat->accept4_id; + } + } else { + if (isIPv6) { + *layerId = FWPS_LAYER_ALE_AUTH_CONNECT_V6; + *calloutId = stat->connect6_id; + } else { + *layerId = FWPS_LAYER_ALE_AUTH_CONNECT_V4; + *calloutId = stat->datagram4_id; + } + } +} + static void fort_flow_context_stream_init( PFORT_STAT stat, BOOL isIPv6, BOOL is_tcp, UINT16 *layerId, UINT32 *calloutId) { @@ -168,6 +190,17 @@ static void fort_flow_context_transport_init( } } +inline static NTSTATUS fort_flow_context_flow_set( + PFORT_STAT stat, UINT64 flow_id, UINT64 flowContext, BOOL isIPv6, BOOL inbound) +{ + UINT16 layerId; + UINT32 calloutId; + + fort_flow_context_flow_init(stat, isIPv6, inbound, &layerId, &calloutId); + + return FwpsFlowAssociateContext0(flow_id, layerId, calloutId, flowContext); +} + inline static NTSTATUS fort_flow_context_stream_set( PFORT_STAT stat, UINT64 flow_id, UINT64 flowContext, BOOL isIPv6, BOOL is_tcp) { @@ -196,10 +229,22 @@ static void fort_flow_context_set( const UINT64 flow_id = flow->flow_id; const UINT64 flowContext = (UINT64) flow; + fort_flow_context_flow_set(stat, flow_id, flowContext, isIPv6, inbound); fort_flow_context_stream_set(stat, flow_id, flowContext, isIPv6, is_tcp); fort_flow_context_transport_set(stat, flow_id, flowContext, isIPv6, inbound); } +inline static NTSTATUS fort_flow_context_flow_remove( + PFORT_STAT stat, UINT64 flow_id, BOOL isIPv6, BOOL inbound) +{ + UINT16 layerId; + UINT32 calloutId; + + fort_flow_context_flow_init(stat, isIPv6, inbound, &layerId, &calloutId); + + return FwpsFlowRemoveContext0(flow_id, layerId, calloutId); +} + inline static NTSTATUS fort_flow_context_stream_remove( PFORT_STAT stat, UINT64 flow_id, BOOL isIPv6, BOOL is_tcp) { @@ -222,7 +267,27 @@ inline static NTSTATUS fort_flow_context_transport_remove( return FwpsFlowRemoveContext0(flow_id, layerId, calloutId); } -static void fort_flow_context_remove(PFORT_STAT stat, PFORT_FLOW flow) +static BOOL fort_flow_context_remove_id( + PFORT_STAT stat, UINT64 flow_id, BOOL isIPv6, BOOL is_tcp, BOOL inbound) +{ + NTSTATUS status; + + status = fort_flow_context_flow_remove(stat, flow_id, isIPv6, inbound); + if (status == STATUS_PENDING) + return FALSE; + + status = fort_flow_context_stream_remove(stat, flow_id, isIPv6, is_tcp); + if (status == STATUS_PENDING) + return FALSE; + + status = fort_flow_context_transport_remove(stat, flow_id, isIPv6, inbound); + if (status == STATUS_PENDING) + return FALSE; + + return TRUE; +} + +static BOOL fort_flow_context_remove(PFORT_STAT stat, PFORT_FLOW flow) { const UINT64 flow_id = flow->flow_id; @@ -231,13 +296,12 @@ static void fort_flow_context_remove(PFORT_STAT stat, PFORT_FLOW flow) const BOOL is_tcp = (flow_flags & FORT_FLOW_TCP); const BOOL isIPv6 = (flow_flags & FORT_FLOW_IP6); - const NTSTATUS stream_status = fort_flow_context_stream_remove(stat, flow_id, isIPv6, is_tcp); - const NTSTATUS transport_status = - fort_flow_context_transport_remove(stat, flow_id, isIPv6, inbound); + return fort_flow_context_remove_id(stat, flow_id, isIPv6, is_tcp, inbound); +} - if (stream_status == STATUS_PENDING || transport_status == STATUS_PENDING) { - fort_stat_flags_set(stat, FORT_STAT_FLOW_PENDING, TRUE); - } +static BOOL fort_flow_is_closed(PFORT_FLOW flow) +{ + return (flow->opt.proc_index == FORT_PROC_BAD_INDEX); } static void fort_flow_close(PFORT_FLOW flow) @@ -245,39 +309,26 @@ static void fort_flow_close(PFORT_FLOW flow) flow->opt.proc_index = FORT_PROC_BAD_INDEX; } -static PFORT_FLOW fort_flow_get(PFORT_STAT stat, UINT64 flow_id, tommy_key_t flow_hash) -{ - PFORT_FLOW flow = (PFORT_FLOW) tommy_hashdyn_bucket(&stat->flows_map, flow_hash); - - while (flow != NULL) { - if (flow->flow_hash == flow_hash && flow->flow_id == flow_id) - return flow; - - flow = flow->next; - } - return NULL; -} - static void fort_flow_free(PFORT_STAT stat, PFORT_FLOW flow) { - const UINT16 proc_index = flow->opt.proc_index; - - if (proc_index != FORT_PROC_BAD_INDEX) { - fort_stat_proc_dec(stat, proc_index); + if (!fort_flow_is_closed(flow)) { + fort_stat_proc_dec(stat, flow->opt.proc_index); } - tommy_hashdyn_remove_existing(&stat->flows_map, (tommy_hashdyn_node *) flow); - /* Add to free chain */ flow->next = stat->flow_free; stat->flow_free = flow; + + stat->flow_count--; } -static PFORT_FLOW fort_flow_new(PFORT_STAT stat, UINT64 flow_id, const tommy_key_t flow_hash, - BOOL isIPv6, BOOL is_tcp, BOOL inbound) +static PFORT_FLOW fort_flow_new( + PFORT_STAT stat, UINT64 flow_id, BOOL isIPv6, BOOL is_tcp, BOOL inbound) { PFORT_FLOW flow; + stat->flow_count++; + if (stat->flow_free != NULL) { flow = stat->flow_free; stat->flow_free = flow->next; @@ -291,8 +342,6 @@ static PFORT_FLOW fort_flow_new(PFORT_STAT stat, UINT64 flow_id, const tommy_key flow = tommy_arrayof_ref(&stat->flows, size); } - tommy_hashdyn_insert(&stat->flows_map, (tommy_hashdyn_node *) flow, 0, flow_hash); - flow->flow_id = flow_id; fort_flow_context_set(stat, flow, isIPv6, is_tcp, inbound); @@ -300,26 +349,26 @@ static PFORT_FLOW fort_flow_new(PFORT_STAT stat, UINT64 flow_id, const tommy_key return flow; } -static NTSTATUS fort_flow_add(PFORT_STAT stat, UINT64 flow_id, UCHAR group_index, UINT16 proc_index, - UCHAR speed_limit, BOOL isIPv6, BOOL is_tcp, BOOL inbound, BOOL is_reauth) +static NTSTATUS fort_flow_add(PFORT_STAT stat, PFORT_FLOW flow, UINT64 flow_id, UCHAR group_index, + UINT16 proc_index, UCHAR speed_limit, BOOL isIPv6, BOOL is_tcp, BOOL inbound, + BOOL is_reauth) { - const tommy_key_t flow_hash = fort_flow_hash(flow_id); - PFORT_FLOW flow = fort_flow_get(stat, flow_id, flow_hash); BOOL is_new_flow = FALSE; if (flow == NULL) { if (is_reauth) { - /* Block existing flow after reauth. to be able to associate a flow-context */ - return FORT_STATUS_FLOW_BLOCK; + /* Remove existing flow context after reauth. to be able to associate a flow-context */ + if (!fort_flow_context_remove_id(stat, flow_id, isIPv6, is_tcp, inbound)) + return FORT_STATUS_FLOW_BLOCK; } - flow = fort_flow_new(stat, flow_id, flow_hash, isIPv6, is_tcp, inbound); + flow = fort_flow_new(stat, flow_id, isIPv6, is_tcp, inbound); if (flow == NULL) return STATUS_INSUFFICIENT_RESOURCES; is_new_flow = TRUE; } else { - is_new_flow = (flow->opt.proc_index == FORT_PROC_BAD_INDEX); + is_new_flow = fort_flow_is_closed(flow); } if (is_new_flow) { @@ -340,35 +389,54 @@ FORT_API void fort_stat_open(PFORT_STAT stat) tommy_hashdyn_init(&stat->procs_map); tommy_arrayof_init(&stat->flows, sizeof(FORT_FLOW)); - tommy_hashdyn_init(&stat->flows_map); KeInitializeSpinLock(&stat->lock); } -static BOOL fort_stat_close_flows(PFORT_STAT stat) +static BOOL fort_stat_close_flows(PFORT_STAT stat, UINT32 *flow_index) { + BOOL is_pending = FALSE; + KLOCK_QUEUE_HANDLE lock_queue; KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue); + { + assert((fort_stat_flags(stat) & FORT_STAT_CLOSED) != 0); - fort_stat_flags_set(stat, FORT_STAT_FLOW_PENDING, FALSE); + tommy_arrayof *flows = &stat->flows; - tommy_hashdyn_foreach_node_arg(&stat->flows_map, &fort_flow_context_remove, stat); + UINT32 i = *flow_index; - const BOOL not_pending = (fort_stat_flags(stat) & FORT_STAT_FLOW_PENDING) == 0; + for (; stat->flow_count != 0; ++i) { + PFORT_FLOW flow = tommy_arrayof_ref(flows, i); + if (fort_flow_is_closed(flow)) + continue; + if (!fort_flow_context_remove(stat, flow)) { + is_pending = TRUE; + break; + } + + fort_flow_close(flow); + + stat->flow_count--; + } + + *flow_index = i; + } KeReleaseInStackQueuedSpinLock(&lock_queue); - return not_pending; + return !is_pending; } FORT_API void fort_stat_close(PFORT_STAT stat) { fort_stat_flags_set(stat, FORT_STAT_CLOSED, TRUE); - while (!fort_stat_close_flows(stat)) { + UINT32 flow_index = 0; + while (!fort_stat_close_flows(stat, &flow_index)) { /* Wait for asynchronously deleting flows */ LARGE_INTEGER delay; - delay.QuadPart = -5000000; /* sleep 500000us (500ms) */ + delay.QuadPart = -3000000; /* sleep 300000us (300ms) */ KeDelayExecutionThread(KernelMode, FALSE, &delay); } @@ -380,11 +448,26 @@ FORT_API void fort_stat_close(PFORT_STAT stat) tommy_hashdyn_done(&stat->procs_map); tommy_arrayof_done(&stat->flows); - tommy_hashdyn_done(&stat->flows_map); KeReleaseInStackQueuedSpinLock(&lock_queue); } +static void fort_stat_clear_flows(PFORT_STAT stat) +{ + tommy_arrayof *flows = &stat->flows; + UINT32 flow_count = stat->flow_count; + + for (UINT32 i = 0; flow_count != 0; ++i) { + PFORT_FLOW flow = tommy_arrayof_ref(flows, i); + if (fort_flow_is_closed(flow)) + continue; + + fort_flow_close(flow); + + --flow_count; + } +} + static void fort_stat_clear(PFORT_STAT stat) { KLOCK_QUEUE_HANDLE lock_queue; @@ -393,7 +476,8 @@ static void fort_stat_clear(PFORT_STAT stat) fort_stat_proc_active_clear(stat); tommy_hashdyn_foreach_node_arg(&stat->procs_map, &fort_stat_proc_free, stat); - tommy_hashdyn_foreach_node(&stat->flows_map, &fort_flow_close); + + fort_stat_clear_flows(stat); KeReleaseInStackQueuedSpinLock(&lock_queue); } @@ -427,8 +511,8 @@ FORT_API void fort_stat_conf_flags_update(PFORT_STAT stat, const PFORT_CONF_FLAG KeReleaseInStackQueuedSpinLock(&lock_queue); } -static NTSTATUS fort_flow_associate_proc(PFORT_STAT stat, UINT32 process_id, BOOL is_reauth, - BOOL *is_new_proc, PFORT_STAT_PROC *proc) +static NTSTATUS fort_flow_associate_proc( + PFORT_STAT stat, UINT32 process_id, BOOL *is_new_proc, PFORT_STAT_PROC *proc) { if ((fort_stat_flags(stat) & FORT_STAT_LOG) == 0) return STATUS_DEVICE_DATA_ERROR; @@ -438,11 +522,6 @@ static NTSTATUS fort_flow_associate_proc(PFORT_STAT stat, UINT32 process_id, BOO *proc = fort_stat_proc_get(stat, process_id, proc_hash); if (*proc == NULL) { - if (is_reauth) { - /* Block existing flow after reauth. to be able to associate a flow-context */ - return FORT_STATUS_FLOW_BLOCK; - } - *proc = fort_stat_proc_add(stat, process_id); if (*proc == NULL) @@ -454,9 +533,9 @@ static NTSTATUS fort_flow_associate_proc(PFORT_STAT stat, UINT32 process_id, BOO return STATUS_SUCCESS; } -FORT_API NTSTATUS fort_flow_associate(PFORT_STAT stat, UINT64 flow_id, UINT32 process_id, - UCHAR group_index, BOOL isIPv6, BOOL is_tcp, BOOL inbound, BOOL is_reauth, - BOOL *is_new_proc) +FORT_API NTSTATUS fort_flow_associate(PFORT_STAT stat, PFORT_FLOW flow, UINT64 flow_id, + UINT32 process_id, UCHAR group_index, BOOL isIPv6, BOOL is_tcp, BOOL inbound, + BOOL is_reauth, BOOL *is_new_proc) { NTSTATUS status; @@ -464,14 +543,14 @@ FORT_API NTSTATUS fort_flow_associate(PFORT_STAT stat, UINT64 flow_id, UINT32 pr KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue); PFORT_STAT_PROC proc = NULL; - status = fort_flow_associate_proc(stat, process_id, is_reauth, is_new_proc, &proc); + status = fort_flow_associate_proc(stat, process_id, is_new_proc, &proc); /* Add flow */ if (NT_SUCCESS(status)) { const UCHAR speed_limit = fort_stat_group_speed_limit(stat, group_index); - status = fort_flow_add(stat, flow_id, group_index, proc->proc_index, speed_limit, isIPv6, - is_tcp, inbound, is_reauth); + status = fort_flow_add(stat, flow, flow_id, group_index, proc->proc_index, speed_limit, + isIPv6, is_tcp, inbound, is_reauth); if (!NT_SUCCESS(status) && *is_new_proc) { fort_stat_proc_free(stat, proc); @@ -505,18 +584,14 @@ FORT_API void fort_flow_classify(PFORT_STAT stat, UINT64 flowContext, UINT32 dat KLOCK_QUEUE_HANDLE lock_queue; KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue); - if ((fort_stat_flags(stat) & FORT_STAT_LOG) != 0) { - const FORT_FLOW_OPT opt = flow->opt; + if ((fort_stat_flags(stat) & FORT_STAT_LOG) != 0 && !fort_flow_is_closed(flow)) { + PFORT_STAT_PROC proc = tommy_arrayof_ref(&stat->procs, flow->opt.proc_index); + UINT32 *proc_bytes = inbound ? &proc->traf.in_bytes : &proc->traf.out_bytes; - if (opt.proc_index != FORT_PROC_BAD_INDEX) { - PFORT_STAT_PROC proc = tommy_arrayof_ref(&stat->procs, opt.proc_index); - UINT32 *proc_bytes = inbound ? &proc->traf.in_bytes : &proc->traf.out_bytes; + /* Add traffic to process */ + *proc_bytes += data_len; - /* Add traffic to process */ - *proc_bytes += data_len; - - fort_stat_proc_active_add(stat, proc); - } + fort_stat_proc_active_add(stat, proc); } KeReleaseInStackQueuedSpinLock(&lock_queue); diff --git a/src/driver/fortstat.h b/src/driver/fortstat.h index 9f4d0843..a087f408 100644 --- a/src/driver/fortstat.h +++ b/src/driver/fortstat.h @@ -61,45 +61,37 @@ typedef struct fort_flow_opt }; } FORT_FLOW_OPT, *PFORT_FLOW_OPT; -/* Synchronize with tommy_hashdyn_node! */ typedef struct fort_flow { - struct fort_flow *next; - struct fort_flow *prev; + FORT_FLOW_OPT opt; union { -#if defined(_WIN64) UINT64 flow_id; -#else - FORT_FLOW_OPT opt; -#endif - void *data; /* tommy_hashdyn_node::data */ + struct fort_flow *next; }; - - tommy_key_t flow_hash; /* tommy_hashdyn_node::index */ - -#if defined(_WIN64) - FORT_FLOW_OPT opt; -#else - UINT64 flow_id; -#endif } FORT_FLOW, *PFORT_FLOW; #define FORT_STAT_CLOSED 0x01 #define FORT_STAT_LOG 0x02 #define FORT_STAT_TIME_CHANGED 0x04 -#define FORT_STAT_FLOW_PENDING 0x08 /* used only on driver unloading */ typedef struct fort_stat { UCHAR volatile flags; UINT16 proc_active_count; + UINT32 flow_count; + + UINT32 connect4_id; + UINT32 connect6_id; + UINT32 accept4_id; + UINT32 accept6_id; UINT32 stream4_id; UINT32 stream6_id; UINT32 datagram4_id; UINT32 datagram6_id; + UINT32 in_transport4_id; UINT32 in_transport6_id; UINT32 out_transport4_id; @@ -114,7 +106,6 @@ typedef struct fort_stat tommy_hashdyn procs_map; tommy_arrayof flows; - tommy_hashdyn flows_map; FORT_CONF_GROUP conf_group; @@ -145,9 +136,9 @@ FORT_API void fort_stat_conf_update(PFORT_STAT stat, const PFORT_CONF_IO conf_io FORT_API void fort_stat_conf_flags_update(PFORT_STAT stat, const PFORT_CONF_FLAGS conf_flags); -FORT_API NTSTATUS fort_flow_associate(PFORT_STAT stat, UINT64 flow_id, UINT32 process_id, - UCHAR group_index, BOOL isIPv6, BOOL is_tcp, BOOL inbound, BOOL is_reauth, - BOOL *is_new_proc); +FORT_API NTSTATUS fort_flow_associate(PFORT_STAT stat, PFORT_FLOW flow, UINT64 flow_id, + UINT32 process_id, UCHAR group_index, BOOL isIPv6, BOOL is_tcp, BOOL inbound, + BOOL is_reauth, BOOL *is_new_proc); FORT_API void fort_flow_delete(PFORT_STAT stat, UINT64 flowContext); diff --git a/src/driver/fortwrk.c b/src/driver/fortwrk.c index 0a5eb952..ac0103c9 100644 --- a/src/driver/fortwrk.c +++ b/src/driver/fortwrk.c @@ -8,11 +8,11 @@ static void NTAPI fort_worker_callback(PDEVICE_OBJECT device, PVOID context) { + UNUSED(device); + PFORT_WORKER worker = (PFORT_WORKER) context; const UCHAR id_bits = InterlockedAnd8(&worker->id_bits, 0); - UNUSED(device); - if (id_bits & (1 << FORT_WORKER_REAUTH)) { worker->funcs[FORT_WORKER_REAUTH](); }