Driver: Simplify flows handling

This commit is contained in:
Nodir Temirkhodjaev 2023-01-06 17:30:34 +03:00
parent 909dc5066e
commit 90b1b683dd
8 changed files with 239 additions and 182 deletions

View File

@ -188,10 +188,10 @@ FORT_API BOOL fort_conf_app_exe_equal(
static BOOL fort_conf_app_wild_equal(
const PFORT_APP_ENTRY app_entry, const PVOID path, UINT32 path_len)
{
const WCHAR *app_path = (const WCHAR *) (app_entry + 1);
UNUSED(path_len);
const WCHAR *app_path = (const WCHAR *) (app_entry + 1);
return wildmatch(app_path, (const WCHAR *) path) == WM_MATCH;
}

View File

@ -40,12 +40,13 @@ static void fort_callout_classify_continue(FWPS_CLASSIFY_OUT0 *classifyOut)
static BOOL fort_callout_classify_blocked_log_stat(const FWPS_INCOMING_VALUES0 *inFixedValues,
const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, const FWPS_FILTER0 *filter,
FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, int remoteIpField,
int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, BOOL inbound,
UINT32 classify_flags, const UINT32 *remote_ip, FORT_CONF_FLAGS conf_flags,
UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField,
int remoteIpField, int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6,
BOOL inbound, UINT32 classify_flags, const UINT32 *remote_ip, FORT_CONF_FLAGS conf_flags,
UINT32 process_id, PCUNICODE_STRING real_path, PFORT_CONF_REF conf_ref, INT8 *block_reason,
BOOL blocked, FORT_APP_FLAGS app_flags, PIRP *irp, ULONG_PTR *info)
{
PFORT_FLOW flow = (PFORT_FLOW) flowContext;
const UINT64 flow_id = inMetaValues->flowHandle;
const IPPROTO ip_proto = (IPPROTO) inFixedValues->incomingValue[ipProtoField].value.uint8;
@ -56,7 +57,7 @@ static BOOL fort_callout_classify_blocked_log_stat(const FWPS_INCOMING_VALUES0 *
BOOL is_new_proc = FALSE;
const NTSTATUS status = fort_flow_associate(&fort_device()->stat, flow_id, process_id,
const NTSTATUS status = fort_flow_associate(&fort_device()->stat, flow, flow_id, process_id,
group_index, isIPv6, is_tcp, inbound, is_reauth, &is_new_proc);
if (!NT_SUCCESS(status)) {
@ -77,9 +78,9 @@ static BOOL fort_callout_classify_blocked_log_stat(const FWPS_INCOMING_VALUES0 *
static BOOL fort_callout_classify_blocked_log(const FWPS_INCOMING_VALUES0 *inFixedValues,
const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, const FWPS_FILTER0 *filter,
FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, int remoteIpField,
int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, BOOL inbound,
UINT32 classify_flags, const UINT32 *remote_ip, FORT_CONF_FLAGS conf_flags,
UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField,
int remoteIpField, int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6,
BOOL inbound, UINT32 classify_flags, const UINT32 *remote_ip, FORT_CONF_FLAGS conf_flags,
UINT32 process_id, PCUNICODE_STRING path, PCUNICODE_STRING real_path,
PFORT_CONF_REF conf_ref, INT8 *block_reason, BOOL blocked, PIRP *irp, ULONG_PTR *info)
{
@ -91,10 +92,10 @@ static BOOL fort_callout_classify_blocked_log(const FWPS_INCOMING_VALUES0 *inFix
|| !fort_conf_app_blocked(&conf_ref->conf, app_flags, block_reason)) {
if (conf_flags.log_stat
&& fort_callout_classify_blocked_log_stat(inFixedValues, inMetaValues, filter,
classifyOut, flagsField, localIpField, remoteIpField, localPortField,
remotePortField, ipProtoField, isIPv6, inbound, classify_flags, remote_ip,
conf_flags, process_id, real_path, conf_ref, block_reason, blocked,
app_flags, irp, info))
flowContext, classifyOut, flagsField, localIpField, remoteIpField,
localPortField, remotePortField, ipProtoField, isIPv6, inbound,
classify_flags, remote_ip, conf_flags, process_id, real_path, conf_ref,
block_reason, blocked, app_flags, irp, info))
return TRUE; /* blocked */
blocked = FALSE; /* allow */
@ -118,9 +119,9 @@ static BOOL fort_callout_classify_blocked_log(const FWPS_INCOMING_VALUES0 *inFix
static BOOL fort_callout_classify_blocked(const FWPS_INCOMING_VALUES0 *inFixedValues,
const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, const FWPS_FILTER0 *filter,
FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, int remoteIpField,
int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, BOOL inbound,
UINT32 classify_flags, const UINT32 *remote_ip, FORT_CONF_FLAGS conf_flags,
UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField,
int remoteIpField, int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6,
BOOL inbound, UINT32 classify_flags, const UINT32 *remote_ip, FORT_CONF_FLAGS conf_flags,
UINT32 process_id, PUNICODE_STRING path, PUNICODE_STRING real_path, PFORT_CONF_REF conf_ref,
INT8 *block_reason, PIRP *irp, ULONG_PTR *info)
{
@ -151,18 +152,18 @@ static BOOL fort_callout_classify_blocked(const FWPS_INCOMING_VALUES0 *inFixedVa
blocked = FALSE;
}
return fort_callout_classify_blocked_log(inFixedValues, inMetaValues, filter, classifyOut,
flagsField, localIpField, remoteIpField, localPortField, remotePortField, ipProtoField,
isIPv6, inbound, classify_flags, remote_ip, conf_flags, process_id, path, real_path,
conf_ref, block_reason, blocked, irp, info);
return fort_callout_classify_blocked_log(inFixedValues, inMetaValues, filter, flowContext,
classifyOut, flagsField, localIpField, remoteIpField, localPortField, remotePortField,
ipProtoField, isIPv6, inbound, classify_flags, remote_ip, conf_flags, process_id, path,
real_path, conf_ref, block_reason, blocked, irp, info);
}
static void fort_callout_classify_check(const FWPS_INCOMING_VALUES0 *inFixedValues,
const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, const FWPS_FILTER0 *filter,
FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, int remoteIpField,
int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, BOOL inbound,
UINT32 classify_flags, const UINT32 *remote_ip, PFORT_CONF_REF conf_ref, PIRP *irp,
ULONG_PTR *info)
UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField,
int remoteIpField, int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6,
BOOL inbound, UINT32 classify_flags, const UINT32 *remote_ip, PFORT_CONF_REF conf_ref,
PIRP *irp, ULONG_PTR *info)
{
const FORT_CONF_FLAGS conf_flags = conf_ref->conf.flags;
@ -184,9 +185,9 @@ static void fort_callout_classify_check(const FWPS_INCOMING_VALUES0 *inFixedValu
INT8 block_reason = FORT_BLOCK_REASON_UNKNOWN;
const BOOL blocked = fort_callout_classify_blocked(inFixedValues, inMetaValues, filter,
classifyOut, flagsField, localIpField, remoteIpField, localPortField, remotePortField,
ipProtoField, isIPv6, inbound, classify_flags, remote_ip, conf_flags, process_id, &path,
&real_path, conf_ref, &block_reason, irp, info);
flowContext, classifyOut, flagsField, localIpField, remoteIpField, localPortField,
remotePortField, ipProtoField, isIPv6, inbound, classify_flags, remote_ip, conf_flags,
process_id, &path, &real_path, conf_ref, &block_reason, irp, info);
if (blocked) {
/* Log the blocked connection */
@ -220,8 +221,9 @@ static void fort_callout_classify_check(const FWPS_INCOMING_VALUES0 *inFixedValu
static void fort_callout_classify(const FWPS_INCOMING_VALUES0 *inFixedValues,
const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, const FWPS_FILTER0 *filter,
FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField, int remoteIpField,
int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6, BOOL inbound)
UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut, int flagsField, int localIpField,
int remoteIpField, int localPortField, int remotePortField, int ipProtoField, BOOL isIPv6,
BOOL inbound)
{
PIRP irp = NULL;
ULONG_PTR info;
@ -252,9 +254,9 @@ static void fort_callout_classify(const FWPS_INCOMING_VALUES0 *inFixedValues,
return;
}
fort_callout_classify_check(inFixedValues, inMetaValues, filter, classifyOut, flagsField,
localIpField, remoteIpField, localPortField, remotePortField, ipProtoField, isIPv6,
inbound, classify_flags, remote_ip, conf_ref, &irp, &info);
fort_callout_classify_check(inFixedValues, inMetaValues, filter, flowContext, classifyOut,
flagsField, localIpField, remoteIpField, localPortField, remotePortField, ipProtoField,
isIPv6, inbound, classify_flags, remote_ip, conf_ref, &irp, &info);
fort_conf_ref_put(device_conf, conf_ref);
@ -265,12 +267,11 @@ static void fort_callout_classify(const FWPS_INCOMING_VALUES0 *inFixedValues,
static void NTAPI fort_callout_connect_v4(const FWPS_INCOMING_VALUES0 *inFixedValues,
const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, void *layerData,
const FWPS_FILTER0 *filter, const UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut)
const FWPS_FILTER0 *filter, UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut)
{
UNUSED(layerData);
UNUSED(flowContext);
fort_callout_classify(inFixedValues, inMetaValues, filter, classifyOut,
fort_callout_classify(inFixedValues, inMetaValues, filter, flowContext, classifyOut,
FWPS_FIELD_ALE_AUTH_CONNECT_V4_FLAGS, FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_LOCAL_ADDRESS,
FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_REMOTE_ADDRESS,
FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_LOCAL_PORT,
@ -280,12 +281,11 @@ static void NTAPI fort_callout_connect_v4(const FWPS_INCOMING_VALUES0 *inFixedVa
static void NTAPI fort_callout_connect_v6(const FWPS_INCOMING_VALUES0 *inFixedValues,
const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, void *layerData,
const FWPS_FILTER0 *filter, const UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut)
const FWPS_FILTER0 *filter, UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut)
{
UNUSED(layerData);
UNUSED(flowContext);
fort_callout_classify(inFixedValues, inMetaValues, filter, classifyOut,
fort_callout_classify(inFixedValues, inMetaValues, filter, flowContext, classifyOut,
FWPS_FIELD_ALE_AUTH_CONNECT_V6_FLAGS, FWPS_FIELD_ALE_AUTH_CONNECT_V6_IP_LOCAL_ADDRESS,
FWPS_FIELD_ALE_AUTH_CONNECT_V6_IP_REMOTE_ADDRESS,
FWPS_FIELD_ALE_AUTH_CONNECT_V6_IP_LOCAL_PORT,
@ -295,12 +295,11 @@ static void NTAPI fort_callout_connect_v6(const FWPS_INCOMING_VALUES0 *inFixedVa
static void NTAPI fort_callout_accept_v4(const FWPS_INCOMING_VALUES0 *inFixedValues,
const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, void *layerData,
const FWPS_FILTER0 *filter, const UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut)
const FWPS_FILTER0 *filter, UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut)
{
UNUSED(layerData);
UNUSED(flowContext);
fort_callout_classify(inFixedValues, inMetaValues, filter, classifyOut,
fort_callout_classify(inFixedValues, inMetaValues, filter, flowContext, classifyOut,
FWPS_FIELD_ALE_AUTH_RECV_ACCEPT_V4_FLAGS,
FWPS_FIELD_ALE_AUTH_RECV_ACCEPT_V4_IP_LOCAL_ADDRESS,
FWPS_FIELD_ALE_AUTH_RECV_ACCEPT_V4_IP_REMOTE_ADDRESS,
@ -311,12 +310,11 @@ static void NTAPI fort_callout_accept_v4(const FWPS_INCOMING_VALUES0 *inFixedVal
static void NTAPI fort_callout_accept_v6(const FWPS_INCOMING_VALUES0 *inFixedValues,
const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, void *layerData,
const FWPS_FILTER0 *filter, const UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut)
const FWPS_FILTER0 *filter, UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut)
{
UNUSED(layerData);
UNUSED(flowContext);
fort_callout_classify(inFixedValues, inMetaValues, filter, classifyOut,
fort_callout_classify(inFixedValues, inMetaValues, filter, flowContext, classifyOut,
FWPS_FIELD_ALE_AUTH_RECV_ACCEPT_V6_FLAGS,
FWPS_FIELD_ALE_AUTH_RECV_ACCEPT_V6_IP_LOCAL_ADDRESS,
FWPS_FIELD_ALE_AUTH_RECV_ACCEPT_V6_IP_REMOTE_ADDRESS,
@ -338,10 +336,10 @@ static NTSTATUS NTAPI fort_callout_notify(
static void fort_callout_flow_classify(const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues,
UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut, UINT32 dataSize, BOOL inbound)
{
const UINT32 headerSize = inbound ? inMetaValues->transportHeaderSize : 0;
UNUSED(classifyOut);
const UINT32 headerSize = inbound ? inMetaValues->transportHeaderSize : 0;
fort_flow_classify(&fort_device()->stat, flowContext, headerSize + dataSize, inbound);
}
@ -349,18 +347,17 @@ static void NTAPI fort_callout_stream_classify(const FWPS_INCOMING_VALUES0 *inFi
const FWPS_INCOMING_METADATA_VALUES0 *inMetaValues, FWPS_STREAM_CALLOUT_IO_PACKET0 *packet,
const FWPS_FILTER0 *filter, UINT64 flowContext, FWPS_CLASSIFY_OUT0 *classifyOut)
{
UNUSED(inFixedValues);
const FWPS_STREAM_DATA0 *streamData = packet->streamData;
const UINT32 streamFlags = streamData->flags;
const UINT32 dataSize = (UINT32) streamData->dataLength;
const BOOL inbound = (streamFlags & FWPS_STREAM_FLAG_RECEIVE) != 0;
UNUSED(inFixedValues);
fort_callout_flow_classify(inMetaValues, flowContext, classifyOut, dataSize, inbound);
fort_callout_classify_permit(filter, classifyOut);
return;
}
static void fort_callout_datagram_classify(const FWPS_INCOMING_VALUES0 *inFixedValues,
@ -449,7 +446,7 @@ FORT_API NTSTATUS fort_callout_install(PDEVICE_OBJECT device)
c.calloutKey = FORT_GUID_CALLOUT_CONNECT_V4;
c.classifyFn = (FWPS_CALLOUT_CLASSIFY_FN0) fort_callout_connect_v4;
status = FwpsCalloutRegister0(device, &c, &fort_device()->connect4_id);
status = FwpsCalloutRegister0(device, &c, &stat->connect4_id);
if (!NT_SUCCESS(status)) {
LOG("Register Connect V4: Error: %x\n", status);
TRACE(FORT_CALLOUT_REGISTER_CONNECT_V4_ERROR, status, 0, 0);
@ -460,7 +457,7 @@ FORT_API NTSTATUS fort_callout_install(PDEVICE_OBJECT device)
c.calloutKey = FORT_GUID_CALLOUT_CONNECT_V6;
c.classifyFn = (FWPS_CALLOUT_CLASSIFY_FN0) fort_callout_connect_v6;
status = FwpsCalloutRegister0(device, &c, &fort_device()->connect6_id);
status = FwpsCalloutRegister0(device, &c, &stat->connect6_id);
if (!NT_SUCCESS(status)) {
LOG("Register Connect V6: Error: %x\n", status);
TRACE(FORT_CALLOUT_REGISTER_CONNECT_V6_ERROR, status, 0, 0);
@ -471,7 +468,7 @@ FORT_API NTSTATUS fort_callout_install(PDEVICE_OBJECT device)
c.calloutKey = FORT_GUID_CALLOUT_ACCEPT_V4;
c.classifyFn = (FWPS_CALLOUT_CLASSIFY_FN0) fort_callout_accept_v4;
status = FwpsCalloutRegister0(device, &c, &fort_device()->accept4_id);
status = FwpsCalloutRegister0(device, &c, &stat->accept4_id);
if (!NT_SUCCESS(status)) {
LOG("Register Accept V4: Error: %x\n", status);
TRACE(FORT_CALLOUT_REGISTER_ACCEPT_V4_ERROR, status, 0, 0);
@ -482,7 +479,7 @@ FORT_API NTSTATUS fort_callout_install(PDEVICE_OBJECT device)
c.calloutKey = FORT_GUID_CALLOUT_ACCEPT_V6;
c.classifyFn = (FWPS_CALLOUT_CLASSIFY_FN0) fort_callout_accept_v6;
status = FwpsCalloutRegister0(device, &c, &fort_device()->accept6_id);
status = FwpsCalloutRegister0(device, &c, &stat->accept6_id);
if (!NT_SUCCESS(status)) {
LOG("Register Accept V6: Error: %x\n", status);
TRACE(FORT_CALLOUT_REGISTER_ACCEPT_V6_ERROR, status, 0, 0);
@ -602,62 +599,62 @@ FORT_API void fort_callout_remove(void)
{
PFORT_STAT stat = &fort_device()->stat;
if (fort_device()->connect4_id) {
FwpsCalloutUnregisterById0(fort_device()->connect4_id);
fort_device()->connect4_id = 0;
if (stat->connect4_id != 0) {
FwpsCalloutUnregisterById0(stat->connect4_id);
stat->connect4_id = 0;
}
if (fort_device()->connect6_id) {
FwpsCalloutUnregisterById0(fort_device()->connect6_id);
fort_device()->connect6_id = 0;
if (stat->connect6_id != 0) {
FwpsCalloutUnregisterById0(stat->connect6_id);
stat->connect6_id = 0;
}
if (fort_device()->accept4_id) {
FwpsCalloutUnregisterById0(fort_device()->accept4_id);
fort_device()->accept4_id = 0;
if (stat->accept4_id != 0) {
FwpsCalloutUnregisterById0(stat->accept4_id);
stat->accept4_id = 0;
}
if (fort_device()->accept6_id) {
FwpsCalloutUnregisterById0(fort_device()->accept6_id);
fort_device()->accept6_id = 0;
if (stat->accept6_id != 0) {
FwpsCalloutUnregisterById0(stat->accept6_id);
stat->accept6_id = 0;
}
if (stat->stream4_id) {
if (stat->stream4_id != 0) {
FwpsCalloutUnregisterById0(stat->stream4_id);
stat->stream4_id = 0;
}
if (stat->stream6_id) {
if (stat->stream6_id != 0) {
FwpsCalloutUnregisterById0(stat->stream6_id);
stat->stream6_id = 0;
}
if (stat->datagram4_id) {
if (stat->datagram4_id != 0) {
FwpsCalloutUnregisterById0(stat->datagram4_id);
stat->datagram4_id = 0;
}
if (stat->datagram6_id) {
if (stat->datagram6_id != 0) {
FwpsCalloutUnregisterById0(stat->datagram6_id);
stat->datagram6_id = 0;
}
if (stat->in_transport4_id) {
if (stat->in_transport4_id != 0) {
FwpsCalloutUnregisterById0(stat->in_transport4_id);
stat->in_transport4_id = 0;
}
if (stat->in_transport6_id) {
if (stat->in_transport6_id != 0) {
FwpsCalloutUnregisterById0(stat->in_transport6_id);
stat->in_transport6_id = 0;
}
if (stat->out_transport4_id) {
if (stat->out_transport4_id != 0) {
FwpsCalloutUnregisterById0(stat->out_transport4_id);
stat->out_transport4_id = 0;
}
if (stat->out_transport6_id) {
if (stat->out_transport6_id != 0) {
FwpsCalloutUnregisterById0(stat->out_transport6_id);
stat->out_transport6_id = 0;
}

View File

@ -52,10 +52,10 @@ FORT_API void NTAPI fort_app_period_timer(void)
FORT_API NTSTATUS fort_device_create(PDEVICE_OBJECT device, PIRP irp)
{
NTSTATUS status = STATUS_SUCCESS;
UNUSED(device);
NTSTATUS status = STATUS_SUCCESS;
/* Device opened */
if (fort_device_flag_set(&g_device->conf, FORT_DEVICE_IS_OPENED, TRUE)
& FORT_DEVICE_IS_OPENED) {
@ -111,12 +111,11 @@ FORT_API NTSTATUS fort_device_cleanup(PDEVICE_OBJECT device, PIRP irp)
static void fort_device_cancel_pending(PDEVICE_OBJECT device, PIRP irp)
{
ULONG_PTR info;
NTSTATUS status;
UNUSED(device);
status = fort_buffer_cancel_pending(&g_device->buffer, irp, &info);
ULONG_PTR info;
const NTSTATUS status = fort_buffer_cancel_pending(&g_device->buffer, irp, &info);
IoReleaseCancelSpinLock(irp->CancelIrql); /* before IoCompleteRequest()! */
@ -290,10 +289,10 @@ static NTSTATUS fort_device_control_process(
FORT_API NTSTATUS fort_device_control(PDEVICE_OBJECT device, PIRP irp)
{
ULONG_PTR info = 0;
UNUSED(device);
ULONG_PTR info = 0;
const PIO_STACK_LOCATION irp_stack = IoGetCurrentIrpStackLocation(irp);
const NTSTATUS status = fort_device_control_process(irp_stack, irp, &info);

View File

@ -13,11 +13,6 @@
typedef struct fort_device
{
UINT32 connect4_id;
UINT32 connect6_id;
UINT32 accept4_id;
UINT32 accept6_id;
PDEVICE_OBJECT object;
PCALLBACK_OBJECT power_cb_obj;

View File

@ -585,10 +585,10 @@ static void fort_pstree_handle_new_proc(PFORT_PSTREE ps_tree, PCUNICODE_STRING p
static void NTAPI fort_pstree_notify(
PEPROCESS process, HANDLE processHandle, PPS_CREATE_NOTIFY_INFO createInfo)
{
PFORT_PSTREE ps_tree = &fort_device()->ps_tree;
UNUSED(process);
PFORT_PSTREE ps_tree = &fort_device()->ps_tree;
const DWORD processId = (DWORD) (ptrdiff_t) processHandle;
const tommy_key_t pid_hash = (tommy_key_t) tommy_hash_u32(0, &processId, sizeof(DWORD));

View File

@ -124,6 +124,28 @@ FORT_API UCHAR fort_flow_flags(PFORT_FLOW flow)
return fort_flow_flags_set(flow, 0, TRUE);
}
static void fort_flow_context_flow_init(
PFORT_STAT stat, BOOL isIPv6, BOOL inbound, UINT16 *layerId, UINT32 *calloutId)
{
if (inbound) {
if (isIPv6) {
*layerId = FWPS_LAYER_ALE_AUTH_RECV_ACCEPT_V6;
*calloutId = stat->accept6_id;
} else {
*layerId = FWPS_LAYER_ALE_AUTH_RECV_ACCEPT_V4;
*calloutId = stat->accept4_id;
}
} else {
if (isIPv6) {
*layerId = FWPS_LAYER_ALE_AUTH_CONNECT_V6;
*calloutId = stat->connect6_id;
} else {
*layerId = FWPS_LAYER_ALE_AUTH_CONNECT_V4;
*calloutId = stat->datagram4_id;
}
}
}
static void fort_flow_context_stream_init(
PFORT_STAT stat, BOOL isIPv6, BOOL is_tcp, UINT16 *layerId, UINT32 *calloutId)
{
@ -168,6 +190,17 @@ static void fort_flow_context_transport_init(
}
}
inline static NTSTATUS fort_flow_context_flow_set(
PFORT_STAT stat, UINT64 flow_id, UINT64 flowContext, BOOL isIPv6, BOOL inbound)
{
UINT16 layerId;
UINT32 calloutId;
fort_flow_context_flow_init(stat, isIPv6, inbound, &layerId, &calloutId);
return FwpsFlowAssociateContext0(flow_id, layerId, calloutId, flowContext);
}
inline static NTSTATUS fort_flow_context_stream_set(
PFORT_STAT stat, UINT64 flow_id, UINT64 flowContext, BOOL isIPv6, BOOL is_tcp)
{
@ -196,10 +229,22 @@ static void fort_flow_context_set(
const UINT64 flow_id = flow->flow_id;
const UINT64 flowContext = (UINT64) flow;
fort_flow_context_flow_set(stat, flow_id, flowContext, isIPv6, inbound);
fort_flow_context_stream_set(stat, flow_id, flowContext, isIPv6, is_tcp);
fort_flow_context_transport_set(stat, flow_id, flowContext, isIPv6, inbound);
}
inline static NTSTATUS fort_flow_context_flow_remove(
PFORT_STAT stat, UINT64 flow_id, BOOL isIPv6, BOOL inbound)
{
UINT16 layerId;
UINT32 calloutId;
fort_flow_context_flow_init(stat, isIPv6, inbound, &layerId, &calloutId);
return FwpsFlowRemoveContext0(flow_id, layerId, calloutId);
}
inline static NTSTATUS fort_flow_context_stream_remove(
PFORT_STAT stat, UINT64 flow_id, BOOL isIPv6, BOOL is_tcp)
{
@ -222,7 +267,27 @@ inline static NTSTATUS fort_flow_context_transport_remove(
return FwpsFlowRemoveContext0(flow_id, layerId, calloutId);
}
static void fort_flow_context_remove(PFORT_STAT stat, PFORT_FLOW flow)
static BOOL fort_flow_context_remove_id(
PFORT_STAT stat, UINT64 flow_id, BOOL isIPv6, BOOL is_tcp, BOOL inbound)
{
NTSTATUS status;
status = fort_flow_context_flow_remove(stat, flow_id, isIPv6, inbound);
if (status == STATUS_PENDING)
return FALSE;
status = fort_flow_context_stream_remove(stat, flow_id, isIPv6, is_tcp);
if (status == STATUS_PENDING)
return FALSE;
status = fort_flow_context_transport_remove(stat, flow_id, isIPv6, inbound);
if (status == STATUS_PENDING)
return FALSE;
return TRUE;
}
static BOOL fort_flow_context_remove(PFORT_STAT stat, PFORT_FLOW flow)
{
const UINT64 flow_id = flow->flow_id;
@ -231,13 +296,12 @@ static void fort_flow_context_remove(PFORT_STAT stat, PFORT_FLOW flow)
const BOOL is_tcp = (flow_flags & FORT_FLOW_TCP);
const BOOL isIPv6 = (flow_flags & FORT_FLOW_IP6);
const NTSTATUS stream_status = fort_flow_context_stream_remove(stat, flow_id, isIPv6, is_tcp);
const NTSTATUS transport_status =
fort_flow_context_transport_remove(stat, flow_id, isIPv6, inbound);
return fort_flow_context_remove_id(stat, flow_id, isIPv6, is_tcp, inbound);
}
if (stream_status == STATUS_PENDING || transport_status == STATUS_PENDING) {
fort_stat_flags_set(stat, FORT_STAT_FLOW_PENDING, TRUE);
}
static BOOL fort_flow_is_closed(PFORT_FLOW flow)
{
return (flow->opt.proc_index == FORT_PROC_BAD_INDEX);
}
static void fort_flow_close(PFORT_FLOW flow)
@ -245,39 +309,26 @@ static void fort_flow_close(PFORT_FLOW flow)
flow->opt.proc_index = FORT_PROC_BAD_INDEX;
}
static PFORT_FLOW fort_flow_get(PFORT_STAT stat, UINT64 flow_id, tommy_key_t flow_hash)
{
PFORT_FLOW flow = (PFORT_FLOW) tommy_hashdyn_bucket(&stat->flows_map, flow_hash);
while (flow != NULL) {
if (flow->flow_hash == flow_hash && flow->flow_id == flow_id)
return flow;
flow = flow->next;
}
return NULL;
}
static void fort_flow_free(PFORT_STAT stat, PFORT_FLOW flow)
{
const UINT16 proc_index = flow->opt.proc_index;
if (proc_index != FORT_PROC_BAD_INDEX) {
fort_stat_proc_dec(stat, proc_index);
if (!fort_flow_is_closed(flow)) {
fort_stat_proc_dec(stat, flow->opt.proc_index);
}
tommy_hashdyn_remove_existing(&stat->flows_map, (tommy_hashdyn_node *) flow);
/* Add to free chain */
flow->next = stat->flow_free;
stat->flow_free = flow;
stat->flow_count--;
}
static PFORT_FLOW fort_flow_new(PFORT_STAT stat, UINT64 flow_id, const tommy_key_t flow_hash,
BOOL isIPv6, BOOL is_tcp, BOOL inbound)
static PFORT_FLOW fort_flow_new(
PFORT_STAT stat, UINT64 flow_id, BOOL isIPv6, BOOL is_tcp, BOOL inbound)
{
PFORT_FLOW flow;
stat->flow_count++;
if (stat->flow_free != NULL) {
flow = stat->flow_free;
stat->flow_free = flow->next;
@ -291,8 +342,6 @@ static PFORT_FLOW fort_flow_new(PFORT_STAT stat, UINT64 flow_id, const tommy_key
flow = tommy_arrayof_ref(&stat->flows, size);
}
tommy_hashdyn_insert(&stat->flows_map, (tommy_hashdyn_node *) flow, 0, flow_hash);
flow->flow_id = flow_id;
fort_flow_context_set(stat, flow, isIPv6, is_tcp, inbound);
@ -300,26 +349,26 @@ static PFORT_FLOW fort_flow_new(PFORT_STAT stat, UINT64 flow_id, const tommy_key
return flow;
}
static NTSTATUS fort_flow_add(PFORT_STAT stat, UINT64 flow_id, UCHAR group_index, UINT16 proc_index,
UCHAR speed_limit, BOOL isIPv6, BOOL is_tcp, BOOL inbound, BOOL is_reauth)
static NTSTATUS fort_flow_add(PFORT_STAT stat, PFORT_FLOW flow, UINT64 flow_id, UCHAR group_index,
UINT16 proc_index, UCHAR speed_limit, BOOL isIPv6, BOOL is_tcp, BOOL inbound,
BOOL is_reauth)
{
const tommy_key_t flow_hash = fort_flow_hash(flow_id);
PFORT_FLOW flow = fort_flow_get(stat, flow_id, flow_hash);
BOOL is_new_flow = FALSE;
if (flow == NULL) {
if (is_reauth) {
/* Block existing flow after reauth. to be able to associate a flow-context */
return FORT_STATUS_FLOW_BLOCK;
/* Remove existing flow context after reauth. to be able to associate a flow-context */
if (!fort_flow_context_remove_id(stat, flow_id, isIPv6, is_tcp, inbound))
return FORT_STATUS_FLOW_BLOCK;
}
flow = fort_flow_new(stat, flow_id, flow_hash, isIPv6, is_tcp, inbound);
flow = fort_flow_new(stat, flow_id, isIPv6, is_tcp, inbound);
if (flow == NULL)
return STATUS_INSUFFICIENT_RESOURCES;
is_new_flow = TRUE;
} else {
is_new_flow = (flow->opt.proc_index == FORT_PROC_BAD_INDEX);
is_new_flow = fort_flow_is_closed(flow);
}
if (is_new_flow) {
@ -340,35 +389,54 @@ FORT_API void fort_stat_open(PFORT_STAT stat)
tommy_hashdyn_init(&stat->procs_map);
tommy_arrayof_init(&stat->flows, sizeof(FORT_FLOW));
tommy_hashdyn_init(&stat->flows_map);
KeInitializeSpinLock(&stat->lock);
}
static BOOL fort_stat_close_flows(PFORT_STAT stat)
static BOOL fort_stat_close_flows(PFORT_STAT stat, UINT32 *flow_index)
{
BOOL is_pending = FALSE;
KLOCK_QUEUE_HANDLE lock_queue;
KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue);
{
assert((fort_stat_flags(stat) & FORT_STAT_CLOSED) != 0);
fort_stat_flags_set(stat, FORT_STAT_FLOW_PENDING, FALSE);
tommy_arrayof *flows = &stat->flows;
tommy_hashdyn_foreach_node_arg(&stat->flows_map, &fort_flow_context_remove, stat);
UINT32 i = *flow_index;
const BOOL not_pending = (fort_stat_flags(stat) & FORT_STAT_FLOW_PENDING) == 0;
for (; stat->flow_count != 0; ++i) {
PFORT_FLOW flow = tommy_arrayof_ref(flows, i);
if (fort_flow_is_closed(flow))
continue;
if (!fort_flow_context_remove(stat, flow)) {
is_pending = TRUE;
break;
}
fort_flow_close(flow);
stat->flow_count--;
}
*flow_index = i;
}
KeReleaseInStackQueuedSpinLock(&lock_queue);
return not_pending;
return !is_pending;
}
FORT_API void fort_stat_close(PFORT_STAT stat)
{
fort_stat_flags_set(stat, FORT_STAT_CLOSED, TRUE);
while (!fort_stat_close_flows(stat)) {
UINT32 flow_index = 0;
while (!fort_stat_close_flows(stat, &flow_index)) {
/* Wait for asynchronously deleting flows */
LARGE_INTEGER delay;
delay.QuadPart = -5000000; /* sleep 500000us (500ms) */
delay.QuadPart = -3000000; /* sleep 300000us (300ms) */
KeDelayExecutionThread(KernelMode, FALSE, &delay);
}
@ -380,11 +448,26 @@ FORT_API void fort_stat_close(PFORT_STAT stat)
tommy_hashdyn_done(&stat->procs_map);
tommy_arrayof_done(&stat->flows);
tommy_hashdyn_done(&stat->flows_map);
KeReleaseInStackQueuedSpinLock(&lock_queue);
}
static void fort_stat_clear_flows(PFORT_STAT stat)
{
tommy_arrayof *flows = &stat->flows;
UINT32 flow_count = stat->flow_count;
for (UINT32 i = 0; flow_count != 0; ++i) {
PFORT_FLOW flow = tommy_arrayof_ref(flows, i);
if (fort_flow_is_closed(flow))
continue;
fort_flow_close(flow);
--flow_count;
}
}
static void fort_stat_clear(PFORT_STAT stat)
{
KLOCK_QUEUE_HANDLE lock_queue;
@ -393,7 +476,8 @@ static void fort_stat_clear(PFORT_STAT stat)
fort_stat_proc_active_clear(stat);
tommy_hashdyn_foreach_node_arg(&stat->procs_map, &fort_stat_proc_free, stat);
tommy_hashdyn_foreach_node(&stat->flows_map, &fort_flow_close);
fort_stat_clear_flows(stat);
KeReleaseInStackQueuedSpinLock(&lock_queue);
}
@ -427,8 +511,8 @@ FORT_API void fort_stat_conf_flags_update(PFORT_STAT stat, const PFORT_CONF_FLAG
KeReleaseInStackQueuedSpinLock(&lock_queue);
}
static NTSTATUS fort_flow_associate_proc(PFORT_STAT stat, UINT32 process_id, BOOL is_reauth,
BOOL *is_new_proc, PFORT_STAT_PROC *proc)
static NTSTATUS fort_flow_associate_proc(
PFORT_STAT stat, UINT32 process_id, BOOL *is_new_proc, PFORT_STAT_PROC *proc)
{
if ((fort_stat_flags(stat) & FORT_STAT_LOG) == 0)
return STATUS_DEVICE_DATA_ERROR;
@ -438,11 +522,6 @@ static NTSTATUS fort_flow_associate_proc(PFORT_STAT stat, UINT32 process_id, BOO
*proc = fort_stat_proc_get(stat, process_id, proc_hash);
if (*proc == NULL) {
if (is_reauth) {
/* Block existing flow after reauth. to be able to associate a flow-context */
return FORT_STATUS_FLOW_BLOCK;
}
*proc = fort_stat_proc_add(stat, process_id);
if (*proc == NULL)
@ -454,9 +533,9 @@ static NTSTATUS fort_flow_associate_proc(PFORT_STAT stat, UINT32 process_id, BOO
return STATUS_SUCCESS;
}
FORT_API NTSTATUS fort_flow_associate(PFORT_STAT stat, UINT64 flow_id, UINT32 process_id,
UCHAR group_index, BOOL isIPv6, BOOL is_tcp, BOOL inbound, BOOL is_reauth,
BOOL *is_new_proc)
FORT_API NTSTATUS fort_flow_associate(PFORT_STAT stat, PFORT_FLOW flow, UINT64 flow_id,
UINT32 process_id, UCHAR group_index, BOOL isIPv6, BOOL is_tcp, BOOL inbound,
BOOL is_reauth, BOOL *is_new_proc)
{
NTSTATUS status;
@ -464,14 +543,14 @@ FORT_API NTSTATUS fort_flow_associate(PFORT_STAT stat, UINT64 flow_id, UINT32 pr
KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue);
PFORT_STAT_PROC proc = NULL;
status = fort_flow_associate_proc(stat, process_id, is_reauth, is_new_proc, &proc);
status = fort_flow_associate_proc(stat, process_id, is_new_proc, &proc);
/* Add flow */
if (NT_SUCCESS(status)) {
const UCHAR speed_limit = fort_stat_group_speed_limit(stat, group_index);
status = fort_flow_add(stat, flow_id, group_index, proc->proc_index, speed_limit, isIPv6,
is_tcp, inbound, is_reauth);
status = fort_flow_add(stat, flow, flow_id, group_index, proc->proc_index, speed_limit,
isIPv6, is_tcp, inbound, is_reauth);
if (!NT_SUCCESS(status) && *is_new_proc) {
fort_stat_proc_free(stat, proc);
@ -505,18 +584,14 @@ FORT_API void fort_flow_classify(PFORT_STAT stat, UINT64 flowContext, UINT32 dat
KLOCK_QUEUE_HANDLE lock_queue;
KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue);
if ((fort_stat_flags(stat) & FORT_STAT_LOG) != 0) {
const FORT_FLOW_OPT opt = flow->opt;
if ((fort_stat_flags(stat) & FORT_STAT_LOG) != 0 && !fort_flow_is_closed(flow)) {
PFORT_STAT_PROC proc = tommy_arrayof_ref(&stat->procs, flow->opt.proc_index);
UINT32 *proc_bytes = inbound ? &proc->traf.in_bytes : &proc->traf.out_bytes;
if (opt.proc_index != FORT_PROC_BAD_INDEX) {
PFORT_STAT_PROC proc = tommy_arrayof_ref(&stat->procs, opt.proc_index);
UINT32 *proc_bytes = inbound ? &proc->traf.in_bytes : &proc->traf.out_bytes;
/* Add traffic to process */
*proc_bytes += data_len;
/* Add traffic to process */
*proc_bytes += data_len;
fort_stat_proc_active_add(stat, proc);
}
fort_stat_proc_active_add(stat, proc);
}
KeReleaseInStackQueuedSpinLock(&lock_queue);

View File

@ -61,45 +61,37 @@ typedef struct fort_flow_opt
};
} FORT_FLOW_OPT, *PFORT_FLOW_OPT;
/* Synchronize with tommy_hashdyn_node! */
typedef struct fort_flow
{
struct fort_flow *next;
struct fort_flow *prev;
FORT_FLOW_OPT opt;
union {
#if defined(_WIN64)
UINT64 flow_id;
#else
FORT_FLOW_OPT opt;
#endif
void *data; /* tommy_hashdyn_node::data */
struct fort_flow *next;
};
tommy_key_t flow_hash; /* tommy_hashdyn_node::index */
#if defined(_WIN64)
FORT_FLOW_OPT opt;
#else
UINT64 flow_id;
#endif
} FORT_FLOW, *PFORT_FLOW;
#define FORT_STAT_CLOSED 0x01
#define FORT_STAT_LOG 0x02
#define FORT_STAT_TIME_CHANGED 0x04
#define FORT_STAT_FLOW_PENDING 0x08 /* used only on driver unloading */
typedef struct fort_stat
{
UCHAR volatile flags;
UINT16 proc_active_count;
UINT32 flow_count;
UINT32 connect4_id;
UINT32 connect6_id;
UINT32 accept4_id;
UINT32 accept6_id;
UINT32 stream4_id;
UINT32 stream6_id;
UINT32 datagram4_id;
UINT32 datagram6_id;
UINT32 in_transport4_id;
UINT32 in_transport6_id;
UINT32 out_transport4_id;
@ -114,7 +106,6 @@ typedef struct fort_stat
tommy_hashdyn procs_map;
tommy_arrayof flows;
tommy_hashdyn flows_map;
FORT_CONF_GROUP conf_group;
@ -145,9 +136,9 @@ FORT_API void fort_stat_conf_update(PFORT_STAT stat, const PFORT_CONF_IO conf_io
FORT_API void fort_stat_conf_flags_update(PFORT_STAT stat, const PFORT_CONF_FLAGS conf_flags);
FORT_API NTSTATUS fort_flow_associate(PFORT_STAT stat, UINT64 flow_id, UINT32 process_id,
UCHAR group_index, BOOL isIPv6, BOOL is_tcp, BOOL inbound, BOOL is_reauth,
BOOL *is_new_proc);
FORT_API NTSTATUS fort_flow_associate(PFORT_STAT stat, PFORT_FLOW flow, UINT64 flow_id,
UINT32 process_id, UCHAR group_index, BOOL isIPv6, BOOL is_tcp, BOOL inbound,
BOOL is_reauth, BOOL *is_new_proc);
FORT_API void fort_flow_delete(PFORT_STAT stat, UINT64 flowContext);

View File

@ -8,11 +8,11 @@
static void NTAPI fort_worker_callback(PDEVICE_OBJECT device, PVOID context)
{
UNUSED(device);
PFORT_WORKER worker = (PFORT_WORKER) context;
const UCHAR id_bits = InterlockedAnd8(&worker->id_bits, 0);
UNUSED(device);
if (id_bits & (1 << FORT_WORKER_REAUTH)) {
worker->funcs[FORT_WORKER_REAUTH]();
}