Driver: Store all flow handles to remove on close.

Workaround to properly unload the driver.
Unregistering callout doesn't auto-remove flow contexts.
This commit is contained in:
Nodir Temirkhodjaev 2017-12-05 17:12:39 +05:00
parent c9859945c1
commit fb65a3af36
2 changed files with 197 additions and 50 deletions

View File

@ -317,8 +317,8 @@ fort_callout_flow_classify_v4 (const FWPS_INCOMING_VALUES0 *inFixedValues,
{ {
FWPS_STREAM_DATA0 *streamData = packet->streamData; FWPS_STREAM_DATA0 *streamData = packet->streamData;
fort_stat_flow_classify(&g_device->stat, inMetaValues->flowHandle, fort_stat_flow_classify(&g_device->stat, flowContext,
g_device->flow4_id, flowContext, (UINT32) streamData->dataLength, (UINT32) streamData->dataLength,
(streamData->flags & FWPS_STREAM_FLAG_RECEIVE) != 0); (streamData->flags & FWPS_STREAM_FLAG_RECEIVE) != 0);
classifyOut->actionType = FWP_ACTION_CONTINUE; classifyOut->actionType = FWP_ACTION_CONTINUE;
@ -439,7 +439,7 @@ fort_callout_force_reauth (PDEVICE_OBJECT device,
end: end:
fort_timer_update(&g_device->timer, conf_flags); fort_timer_update(&g_device->timer, conf_flags);
fort_stat_update(&g_device->stat, conf_flags); fort_stat_update(&g_device->stat, conf_flags, g_device->flow4_id);
return status; return status;
} }
@ -648,15 +648,16 @@ fort_driver_unload (PDRIVER_OBJECT driver)
UNICODE_STRING device_link; UNICODE_STRING device_link;
if (g_device != NULL) { if (g_device != NULL) {
fort_timer_close(&g_device->timer);
fort_stat_close(&g_device->stat, g_device->flow4_id);
fort_buffer_close(&g_device->buffer);
if (!g_device->prov_boot) { if (!g_device->prov_boot) {
fort_prov_unregister(); fort_prov_unregister();
} }
fort_callout_remove(); fort_callout_remove();
fort_timer_close(&g_device->timer);
fort_stat_close(&g_device->stat);
fort_buffer_close(&g_device->buffer);
} }
RtlInitUnicodeString(&device_link, DOS_DEVICE_NAME); RtlInitUnicodeString(&device_link, DOS_DEVICE_NAME);

View File

@ -3,10 +3,12 @@
#define FORT_STAT_POOL_TAG 'SwfF' #define FORT_STAT_POOL_TAG 'SwfF'
#define FORT_PROC_BAD_INDEX ((UINT16) -1) #define FORT_PROC_BAD_INDEX ((UINT16) -1)
#define FORT_FLOW_BAD_INDEX ((UINT32) -1)
#define FORT_FLOW_BAD_ID ((UINT64) -1)
typedef struct fort_stat_traf { typedef struct fort_stat_traf {
UINT32 volatile in_bytes; UINT32 in_bytes;
UINT32 volatile out_bytes; UINT32 out_bytes;
} FORT_STAT_TRAF, *PFORT_STAT_TRAF; } FORT_STAT_TRAF, *PFORT_STAT_TRAF;
typedef struct fort_stat_proc { typedef struct fort_stat_proc {
@ -15,25 +17,52 @@ typedef struct fort_stat_proc {
UINT32 process_id; UINT32 process_id;
union { union {
LARGE_INTEGER volatile traf_all; LARGE_INTEGER traf_all;
FORT_STAT_TRAF traf; FORT_STAT_TRAF traf;
}; };
} FORT_STAT_PROC, *PFORT_STAT_PROC; } FORT_STAT_PROC, *PFORT_STAT_PROC;
typedef struct fort_stat_flow {
union {
LARGE_INTEGER l;
UINT64 flow_id;
};
} FORT_STAT_FLOW, *PFORT_STAT_FLOW;
typedef struct fort_stat { typedef struct fort_stat {
UINT16 id; UINT16 volatile closing;
UINT16 version;
UINT16 proc_count; UINT16 proc_count;
UINT16 proc_top; UINT16 proc_top;
UINT16 proc_end; UINT16 proc_end;
UINT16 proc_free_index; UINT16 proc_free_index;
UINT32 flow_count;
UINT32 flow_top;
UINT32 flow_end;
UINT32 flow_free_index;
PFORT_STAT_PROC procs; PFORT_STAT_PROC procs;
PFORT_STAT_FLOW flows;
KSPIN_LOCK lock; KSPIN_LOCK lock;
} FORT_STAT, *PFORT_STAT; } FORT_STAT, *PFORT_STAT;
static void
fort_stat_array_del (PVOID p)
{
ExFreePoolWithTag(p, FORT_STAT_POOL_TAG);
}
static PVOID
fort_stat_array_new (SIZE_T size)
{
return ExAllocatePoolWithTag(NonPagedPool, size, FORT_STAT_POOL_TAG);
}
static UINT16 static UINT16
fort_stat_proc_index (PFORT_STAT stat, UINT32 process_id) fort_stat_proc_index (PFORT_STAT stat, UINT32 process_id)
{ {
@ -82,19 +111,13 @@ fort_stat_proc_dec (PFORT_STAT stat, UINT16 proc_index)
proc->refcount--; proc->refcount--;
} }
static void
fort_stat_proc_del (PFORT_STAT_PROC procs)
{
ExFreePoolWithTag(procs, FORT_STAT_POOL_TAG);
}
static BOOL static BOOL
fort_stat_proc_realloc (PFORT_STAT stat) fort_stat_proc_realloc (PFORT_STAT stat)
{ {
const UINT16 proc_end = stat->proc_end; const UINT16 proc_end = stat->proc_end;
const UINT16 new_end = (proc_end ? proc_end : 8) * 3 / 2; const UINT16 new_end = (proc_end ? proc_end : 8) * 3 / 2;
PFORT_STAT_PROC new_procs = ExAllocatePoolWithTag(NonPagedPool, PFORT_STAT_PROC new_procs = fort_stat_array_new(
new_end * sizeof(FORT_STAT_PROC), FORT_STAT_POOL_TAG); new_end * sizeof(FORT_STAT_PROC));
if (new_procs == NULL) if (new_procs == NULL)
return FALSE; return FALSE;
@ -103,7 +126,7 @@ fort_stat_proc_realloc (PFORT_STAT stat)
PFORT_STAT_PROC procs = stat->procs; PFORT_STAT_PROC procs = stat->procs;
RtlCopyMemory(new_procs, procs, stat->proc_top * sizeof(FORT_STAT_PROC)); RtlCopyMemory(new_procs, procs, stat->proc_top * sizeof(FORT_STAT_PROC));
fort_stat_proc_del(procs); fort_stat_array_del(procs);
} }
stat->proc_end = new_end; stat->proc_end = new_end;
@ -141,23 +164,109 @@ fort_stat_proc_add (PFORT_STAT stat, UINT32 process_id)
return proc_index; return proc_index;
} }
static void
fort_stat_flow_free (PFORT_STAT stat, UINT32 flow_index)
{
if (flow_index == stat->flow_top - 1) {
/* Chop from buffer */
stat->flow_top--;
} else {
/* Add to free chain */
PFORT_STAT_FLOW flow = &stat->flows[flow_index];
flow->l.HighPart = FORT_FLOW_BAD_INDEX;
flow->l.LowPart = stat->flow_free_index;
stat->flow_free_index = flow_index;
}
stat->flow_count--;
}
static BOOL
fort_stat_flow_realloc (PFORT_STAT stat)
{
const UINT32 flow_end = stat->flow_end;
const UINT32 new_end = (flow_end ? flow_end : 8) * 2;
PFORT_STAT_FLOW new_flows = fort_stat_array_new(
new_end * sizeof(FORT_STAT_FLOW));
if (new_flows == NULL)
return FALSE;
if (flow_end) {
PFORT_STAT_FLOW flows = stat->flows;
RtlCopyMemory(new_flows, flows, stat->flow_top * sizeof(FORT_STAT_FLOW));
fort_stat_array_del(flows);
}
stat->flow_end = new_end;
stat->flows = new_flows;
return TRUE;
}
static UINT32
fort_stat_flow_add (PFORT_STAT stat, UINT64 flow_id)
{
PFORT_STAT_FLOW flow;
UINT32 flow_index = stat->flow_free_index;
if (flow_index != FORT_FLOW_BAD_INDEX) {
flow = &stat->flows[flow_index];
stat->flow_free_index = flow->l.LowPart;
} else {
if (stat->flow_top >= stat->flow_end
&& !fort_stat_flow_realloc(stat)) {
return FORT_FLOW_BAD_INDEX;
}
flow_index = stat->flow_top++;
flow = &stat->flows[flow_index];
}
flow->flow_id = flow_id;
stat->flow_count++;
return flow_index;
}
static NTSTATUS
fort_stat_flow_set_context (UINT64 flow_id, UINT32 callout_id,
UINT64 flow_context)
{
return FwpsFlowAssociateContext0(flow_id, FWPS_LAYER_STREAM_V4,
callout_id, flow_context);
}
static void
fort_stat_flow_remove_context (UINT64 flow_id, UINT32 callout_id)
{
FwpsFlowRemoveContext0(flow_id, FWPS_LAYER_STREAM_V4, callout_id);
}
static void static void
fort_stat_init (PFORT_STAT stat) fort_stat_init (PFORT_STAT stat)
{ {
stat->proc_free_index = FORT_PROC_BAD_INDEX; stat->proc_free_index = FORT_PROC_BAD_INDEX;
stat->flow_free_index = FORT_FLOW_BAD_INDEX;
KeInitializeSpinLock(&stat->lock); KeInitializeSpinLock(&stat->lock);
} }
static void static void
fort_stat_close (PFORT_STAT stat) fort_stat_close (PFORT_STAT stat, UINT32 callout_id)
{ {
KLOCK_QUEUE_HANDLE lock_queue; KLOCK_QUEUE_HANDLE lock_queue;
KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue); KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue);
if (stat->procs != NULL) { stat->closing = TRUE;
fort_stat_proc_del(stat->procs); stat->version++;
if (stat->procs != NULL) {
fort_stat_array_del(stat->procs);
stat->procs = NULL; stat->procs = NULL;
stat->proc_count = 0; stat->proc_count = 0;
@ -165,9 +274,35 @@ fort_stat_close (PFORT_STAT stat)
stat->proc_end = 0; stat->proc_end = 0;
stat->proc_free_index = FORT_PROC_BAD_INDEX; stat->proc_free_index = FORT_PROC_BAD_INDEX;
stat->id++;
} }
if (stat->flows != NULL) {
// Remove flow contexts
{
PFORT_STAT_FLOW flow = stat->flows;
UINT32 count = stat->flow_count;
for (; count != 0; ++flow) {
if (flow->l.HighPart == FORT_FLOW_BAD_INDEX)
continue;
fort_stat_flow_remove_context(flow->flow_id, callout_id);
--count;
}
}
fort_stat_array_del(stat->flows);
stat->flows = NULL;
stat->flow_count = 0;
stat->flow_top = 0;
stat->flow_end = 0;
stat->flow_free_index = FORT_FLOW_BAD_INDEX;
}
stat->closing = FALSE;
KeReleaseInStackQueuedSpinLock(&lock_queue); KeReleaseInStackQueuedSpinLock(&lock_queue);
} }
@ -177,12 +312,11 @@ fort_stat_flow_associate (PFORT_STAT stat, UINT64 flow_id,
{ {
KLOCK_QUEUE_HANDLE lock_queue; KLOCK_QUEUE_HANDLE lock_queue;
UINT64 flow_context; UINT64 flow_context;
UINT32 flow_index;
UINT16 proc_index; UINT16 proc_index;
BOOL is_new = FALSE; BOOL is_new = FALSE;
NTSTATUS status; NTSTATUS status;
FwpsFlowRemoveContext0(flow_id, FWPS_LAYER_STREAM_V4, callout_id);
KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue); KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue);
proc_index = fort_stat_proc_index(stat, process_id); proc_index = fort_stat_proc_index(stat, process_id);
@ -198,21 +332,33 @@ fort_stat_flow_associate (PFORT_STAT stat, UINT64 flow_id,
is_new = TRUE; is_new = TRUE;
} }
flow_context = (UINT64) process_id flow_index = fort_stat_flow_add(stat, flow_id);
| ((UINT64) proc_index << 32)
| ((UINT64) stat->id << 48);
status = FwpsFlowAssociateContext0(flow_id, if (flow_index == FORT_FLOW_BAD_INDEX) {
FWPS_LAYER_STREAM_V4, callout_id, flow_context); status = STATUS_INSUFFICIENT_RESOURCES;
goto cleanup;
}
flow_context = (UINT64) flow_index
| ((UINT64) proc_index << 32)
| ((UINT64) stat->version << 48);
status = fort_stat_flow_set_context(flow_id, callout_id, flow_context);
if (NT_SUCCESS(status)) { if (NT_SUCCESS(status)) {
fort_stat_proc_inc(stat, proc_index); fort_stat_proc_inc(stat, proc_index);
} }
else if (is_new) {
cleanup:
if (!NT_SUCCESS(status)) {
fort_stat_flow_free(stat, flow_index);
if (is_new) {
PFORT_STAT_PROC proc = &stat->procs[proc_index]; PFORT_STAT_PROC proc = &stat->procs[proc_index];
fort_stat_proc_free(stat, proc, proc_index); fort_stat_proc_free(stat, proc, proc_index);
} }
}
end: end:
KeReleaseInStackQueuedSpinLock(&lock_queue); KeReleaseInStackQueuedSpinLock(&lock_queue);
@ -224,30 +370,35 @@ static void
fort_stat_flow_delete (PFORT_STAT stat, UINT64 flow_context) fort_stat_flow_delete (PFORT_STAT stat, UINT64 flow_context)
{ {
KLOCK_QUEUE_HANDLE lock_queue; KLOCK_QUEUE_HANDLE lock_queue;
const UINT32 flow_index = (UINT32) (flow_context >> 32);
const UINT16 proc_index = (UINT16) (flow_context >> 32); const UINT16 proc_index = (UINT16) (flow_context >> 32);
const UINT16 stat_id = (UINT16) (flow_context >> 48); const UINT16 stat_version = (UINT16) (flow_context >> 48);
if (stat->closing)
return;
KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue); KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue);
if (stat_id == stat->id) { if (stat_version == stat->version) {
fort_stat_flow_free(stat, flow_index);
fort_stat_proc_dec(stat, proc_index); fort_stat_proc_dec(stat, proc_index);
} }
KeReleaseInStackQueuedSpinLock(&lock_queue); KeReleaseInStackQueuedSpinLock(&lock_queue);
} }
static void static void
fort_stat_flow_classify (PFORT_STAT stat, UINT64 flow_id, fort_stat_flow_classify (PFORT_STAT stat, UINT64 flow_context,
UINT32 callout_id, UINT64 flow_context,
UINT32 data_len, BOOL inbound) UINT32 data_len, BOOL inbound)
{ {
PFORT_STAT_PROC proc; PFORT_STAT_PROC proc;
KLOCK_QUEUE_HANDLE lock_queue; KLOCK_QUEUE_HANDLE lock_queue;
const UINT32 process_id = (UINT32) flow_context;
const UINT16 proc_index = (UINT16) (flow_context >> 32); const UINT16 proc_index = (UINT16) (flow_context >> 32);
const UINT16 stat_id = (UINT16) (flow_context >> 48); const UINT16 stat_version = (UINT16) (flow_context >> 48);
BOOL is_old_flow = FALSE;
if (stat->closing)
return;
KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue); KeAcquireInStackQueuedSpinLock(&stat->lock, &lock_queue);
if (stat_id == stat->id) { if (stat_version == stat->version) {
proc = &stat->procs[proc_index]; proc = &stat->procs[proc_index];
if (inbound) { if (inbound) {
@ -255,21 +406,16 @@ fort_stat_flow_classify (PFORT_STAT stat, UINT64 flow_id,
} else { } else {
proc->traf.out_bytes += data_len; proc->traf.out_bytes += data_len;
} }
} else {
is_old_flow = TRUE;
} }
KeReleaseInStackQueuedSpinLock(&lock_queue); KeReleaseInStackQueuedSpinLock(&lock_queue);
if (is_old_flow) {
FwpsFlowRemoveContext0(flow_id, FWPS_LAYER_STREAM_V4, callout_id);
}
} }
static void static void
fort_stat_update (PFORT_STAT stat, const FORT_CONF_FLAGS conf_flags) fort_stat_update (PFORT_STAT stat, const FORT_CONF_FLAGS conf_flags,
UINT32 callout_id)
{ {
if (!conf_flags.log_stat) { if (!conf_flags.log_stat) {
fort_stat_close(stat); fort_stat_close(stat, callout_id);
} }
} }