Driver: Refactor fort_callout_force_reauth().

This commit is contained in:
Nodir Temirkhodjaev 2021-03-19 12:16:55 +03:00
parent 3f6f4e4302
commit e0ba5c668f
3 changed files with 72 additions and 70 deletions

View File

@ -3,6 +3,19 @@
#include "fortdef.h" #include "fortdef.h"
#include "fortprov.h" #include "fortprov.h"
FORT_API DWORD fort_prov_trans_close(HANDLE transEngine, DWORD status)
{
if (NT_SUCCESS(status)) {
status = fort_prov_trans_commit(transEngine);
} else {
fort_prov_trans_abort(transEngine);
}
fort_prov_close(transEngine);
return status;
}
FORT_API void fort_prov_unregister(HANDLE transEngine) FORT_API void fort_prov_unregister(HANDLE transEngine)
{ {
HANDLE engine = transEngine; HANDLE engine = transEngine;
@ -167,11 +180,7 @@ FORT_API DWORD fort_prov_register(HANDLE transEngine, BOOL is_boot)
} }
if (!transEngine) { if (!transEngine) {
if (NT_SUCCESS(status)) { status = fort_prov_trans_close(transEngine, status);
status = fort_prov_trans_commit(engine);
}
fort_prov_close(engine);
} }
end: end:
@ -243,11 +252,7 @@ FORT_API DWORD fort_prov_flow_register(HANDLE transEngine, BOOL filter_transport
} }
if (!transEngine) { if (!transEngine) {
if (NT_SUCCESS(status)) { status = fort_prov_trans_close(transEngine, status);
status = fort_prov_trans_commit(engine);
}
fort_prov_close(engine);
} }
end: end:
@ -313,13 +318,7 @@ FORT_API DWORD fort_prov_reauth(HANDLE transEngine)
} }
if (!transEngine) { if (!transEngine) {
if (NT_SUCCESS(status)) { status = fort_prov_trans_close(transEngine, status);
status = fort_prov_trans_commit(engine);
} else {
fort_prov_trans_abort(engine);
}
fort_prov_close(engine);
} }
end: end:

View File

@ -13,6 +13,8 @@
extern "C" { extern "C" {
#endif #endif
FORT_API DWORD fort_prov_trans_close(HANDLE transEngine, DWORD status);
FORT_API void fort_prov_unregister(HANDLE transEngine); FORT_API void fort_prov_unregister(HANDLE transEngine);
FORT_API void fort_prov_flow_unregister(HANDLE transEngine); FORT_API void fort_prov_flow_unregister(HANDLE transEngine);

View File

@ -605,10 +605,54 @@ FORT_API void fort_callout_remove(void)
} }
} }
static NTSTATUS fort_callout_force_reauth_prov(
const FORT_CONF_FLAGS old_conf_flags, const FORT_CONF_FLAGS conf_flags, HANDLE engine)
{
NTSTATUS status;
/* Check provider filters */
BOOL prov_recreated = FALSE;
if (old_conf_flags.prov_boot != conf_flags.prov_boot) {
fort_prov_unregister(engine);
if ((status = fort_prov_register(engine, conf_flags.prov_boot)))
return status;
prov_recreated = TRUE;
}
/* Check flow filter */
{
PFORT_STAT stat = &fort_device()->stat;
const PFORT_CONF_GROUP conf_group = &stat->conf_group;
const UINT16 filter_bits = (conf_group->fragment_bits | conf_group->limit_bits);
const BOOL old_filter_transport =
fort_device_flag(&fort_device()->conf, FORT_DEVICE_FILTER_TRANSPORT) != 0;
const BOOL filter_transport = (conf_flags.group_bits & filter_bits) != 0;
if (prov_recreated || old_conf_flags.log_stat != conf_flags.log_stat
|| old_filter_transport != filter_transport) {
fort_device_flag_set(
&fort_device()->conf, FORT_DEVICE_FILTER_TRANSPORT, filter_transport);
fort_prov_flow_unregister(engine);
if (conf_flags.log_stat) {
if ((status = fort_prov_flow_register(engine, filter_transport)))
return status;
}
}
}
/* Force reauth filter */
return fort_prov_reauth(engine);
}
FORT_API NTSTATUS fort_callout_force_reauth( FORT_API NTSTATUS fort_callout_force_reauth(
const FORT_CONF_FLAGS old_conf_flags, UINT32 defer_flush_bits) const FORT_CONF_FLAGS old_conf_flags, UINT32 defer_flush_bits)
{ {
PFORT_STAT stat = &fort_device()->stat;
NTSTATUS status; NTSTATUS status;
fort_timer_update(&fort_device()->log_timer, FALSE); fort_timer_update(&fort_device()->log_timer, FALSE);
@ -626,6 +670,8 @@ FORT_API NTSTATUS fort_callout_force_reauth(
/* Handle log_stat */ /* Handle log_stat */
if (old_conf_flags.log_stat != conf_flags.log_stat) { if (old_conf_flags.log_stat != conf_flags.log_stat) {
PFORT_STAT stat = &fort_device()->stat;
fort_stat_update(stat, conf_flags.log_stat); fort_stat_update(stat, conf_flags.log_stat);
if (!conf_flags.log_stat) { if (!conf_flags.log_stat) {
@ -639,64 +685,19 @@ FORT_API NTSTATUS fort_callout_force_reauth(
/* Open provider */ /* Open provider */
HANDLE engine; HANDLE engine;
if ((status = fort_prov_open(&engine))) if (!(status = fort_prov_open(&engine))) {
goto end;
fort_prov_trans_begin(engine); fort_prov_trans_begin(engine);
/* Check provider filters */ status = fort_callout_force_reauth_prov(old_conf_flags, conf_flags, engine);
BOOL prov_recreated = FALSE;
if (old_conf_flags.prov_boot != conf_flags.prov_boot) {
fort_prov_unregister(engine);
if ((status = fort_prov_register(engine, conf_flags.prov_boot))) status = fort_prov_trans_close(engine, status);
goto cleanup;
prov_recreated = TRUE;
} }
/* Check flow filter */ if (NT_SUCCESS(status)) {
{
const PFORT_CONF_GROUP conf_group = &stat->conf_group;
const UINT16 filter_bits = (conf_group->fragment_bits | conf_group->limit_bits);
const BOOL old_filter_transport =
fort_device_flag(&fort_device()->conf, FORT_DEVICE_FILTER_TRANSPORT) != 0;
const BOOL filter_transport = (conf_flags.group_bits & filter_bits) != 0;
if (prov_recreated || old_conf_flags.log_stat != conf_flags.log_stat
|| old_filter_transport != filter_transport) {
fort_device_flag_set(
&fort_device()->conf, FORT_DEVICE_FILTER_TRANSPORT, filter_transport);
fort_prov_flow_unregister(engine);
if (conf_flags.log_stat) {
if ((status = fort_prov_flow_register(engine, filter_transport)))
goto cleanup;
}
}
}
/* Force reauth filter */
if ((status = fort_prov_reauth(engine)))
goto cleanup;
fort_timer_update(&fort_device()->log_timer, fort_timer_update(&fort_device()->log_timer,
(conf_flags.allow_all_new || conf_flags.log_blocked || conf_flags.log_stat (conf_flags.allow_all_new || conf_flags.log_blocked || conf_flags.log_stat
|| conf_flags.log_blocked_ip)); || conf_flags.log_blocked_ip));
cleanup:
if (NT_SUCCESS(status)) {
status = fort_prov_trans_commit(engine);
} else { } else {
fort_prov_trans_abort(engine);
}
fort_prov_close(engine);
end:
if (!NT_SUCCESS(status)) {
DbgPrintEx(DPFLTR_IHVNETWORK_ID, DPFLTR_ERROR_LEVEL, "FORT: Callout Reauth: Error: %x\n", DbgPrintEx(DPFLTR_IHVNETWORK_ID, DPFLTR_ERROR_LEVEL, "FORT: Callout Reauth: Error: %x\n",
status); status);
} }