diff --git a/src/driver/common/fortprov.c b/src/driver/common/fortprov.c index 731b63ef..756ba973 100644 --- a/src/driver/common/fortprov.c +++ b/src/driver/common/fortprov.c @@ -3,6 +3,19 @@ #include "fortdef.h" #include "fortprov.h" +FORT_API DWORD fort_prov_trans_close(HANDLE transEngine, DWORD status) +{ + if (NT_SUCCESS(status)) { + status = fort_prov_trans_commit(transEngine); + } else { + fort_prov_trans_abort(transEngine); + } + + fort_prov_close(transEngine); + + return status; +} + FORT_API void fort_prov_unregister(HANDLE transEngine) { HANDLE engine = transEngine; @@ -167,11 +180,7 @@ FORT_API DWORD fort_prov_register(HANDLE transEngine, BOOL is_boot) } if (!transEngine) { - if (NT_SUCCESS(status)) { - status = fort_prov_trans_commit(engine); - } - - fort_prov_close(engine); + status = fort_prov_trans_close(transEngine, status); } end: @@ -243,11 +252,7 @@ FORT_API DWORD fort_prov_flow_register(HANDLE transEngine, BOOL filter_transport } if (!transEngine) { - if (NT_SUCCESS(status)) { - status = fort_prov_trans_commit(engine); - } - - fort_prov_close(engine); + status = fort_prov_trans_close(transEngine, status); } end: @@ -313,13 +318,7 @@ FORT_API DWORD fort_prov_reauth(HANDLE transEngine) } if (!transEngine) { - if (NT_SUCCESS(status)) { - status = fort_prov_trans_commit(engine); - } else { - fort_prov_trans_abort(engine); - } - - fort_prov_close(engine); + status = fort_prov_trans_close(transEngine, status); } end: diff --git a/src/driver/common/fortprov.h b/src/driver/common/fortprov.h index 49420d4d..e9fe9ff8 100644 --- a/src/driver/common/fortprov.h +++ b/src/driver/common/fortprov.h @@ -13,6 +13,8 @@ extern "C" { #endif +FORT_API DWORD fort_prov_trans_close(HANDLE transEngine, DWORD status); + FORT_API void fort_prov_unregister(HANDLE transEngine); FORT_API void fort_prov_flow_unregister(HANDLE transEngine); diff --git a/src/driver/fortcout.c b/src/driver/fortcout.c index 84f281ea..a99f3f45 100644 --- a/src/driver/fortcout.c +++ b/src/driver/fortcout.c @@ -605,10 +605,54 @@ FORT_API void fort_callout_remove(void) } } +static NTSTATUS fort_callout_force_reauth_prov( + const FORT_CONF_FLAGS old_conf_flags, const FORT_CONF_FLAGS conf_flags, HANDLE engine) +{ + NTSTATUS status; + + /* Check provider filters */ + BOOL prov_recreated = FALSE; + if (old_conf_flags.prov_boot != conf_flags.prov_boot) { + fort_prov_unregister(engine); + + if ((status = fort_prov_register(engine, conf_flags.prov_boot))) + return status; + + prov_recreated = TRUE; + } + + /* Check flow filter */ + { + PFORT_STAT stat = &fort_device()->stat; + + const PFORT_CONF_GROUP conf_group = &stat->conf_group; + const UINT16 filter_bits = (conf_group->fragment_bits | conf_group->limit_bits); + + const BOOL old_filter_transport = + fort_device_flag(&fort_device()->conf, FORT_DEVICE_FILTER_TRANSPORT) != 0; + const BOOL filter_transport = (conf_flags.group_bits & filter_bits) != 0; + + if (prov_recreated || old_conf_flags.log_stat != conf_flags.log_stat + || old_filter_transport != filter_transport) { + fort_device_flag_set( + &fort_device()->conf, FORT_DEVICE_FILTER_TRANSPORT, filter_transport); + + fort_prov_flow_unregister(engine); + + if (conf_flags.log_stat) { + if ((status = fort_prov_flow_register(engine, filter_transport))) + return status; + } + } + } + + /* Force reauth filter */ + return fort_prov_reauth(engine); +} + FORT_API NTSTATUS fort_callout_force_reauth( const FORT_CONF_FLAGS old_conf_flags, UINT32 defer_flush_bits) { - PFORT_STAT stat = &fort_device()->stat; NTSTATUS status; fort_timer_update(&fort_device()->log_timer, FALSE); @@ -626,6 +670,8 @@ FORT_API NTSTATUS fort_callout_force_reauth( /* Handle log_stat */ if (old_conf_flags.log_stat != conf_flags.log_stat) { + PFORT_STAT stat = &fort_device()->stat; + fort_stat_update(stat, conf_flags.log_stat); if (!conf_flags.log_stat) { @@ -639,64 +685,19 @@ FORT_API NTSTATUS fort_callout_force_reauth( /* Open provider */ HANDLE engine; - if ((status = fort_prov_open(&engine))) - goto end; + if (!(status = fort_prov_open(&engine))) { + fort_prov_trans_begin(engine); - fort_prov_trans_begin(engine); + status = fort_callout_force_reauth_prov(old_conf_flags, conf_flags, engine); - /* Check provider filters */ - BOOL prov_recreated = FALSE; - if (old_conf_flags.prov_boot != conf_flags.prov_boot) { - fort_prov_unregister(engine); - - if ((status = fort_prov_register(engine, conf_flags.prov_boot))) - goto cleanup; - - prov_recreated = TRUE; + status = fort_prov_trans_close(engine, status); } - /* Check flow filter */ - { - const PFORT_CONF_GROUP conf_group = &stat->conf_group; - const UINT16 filter_bits = (conf_group->fragment_bits | conf_group->limit_bits); - - const BOOL old_filter_transport = - fort_device_flag(&fort_device()->conf, FORT_DEVICE_FILTER_TRANSPORT) != 0; - const BOOL filter_transport = (conf_flags.group_bits & filter_bits) != 0; - - if (prov_recreated || old_conf_flags.log_stat != conf_flags.log_stat - || old_filter_transport != filter_transport) { - fort_device_flag_set( - &fort_device()->conf, FORT_DEVICE_FILTER_TRANSPORT, filter_transport); - - fort_prov_flow_unregister(engine); - - if (conf_flags.log_stat) { - if ((status = fort_prov_flow_register(engine, filter_transport))) - goto cleanup; - } - } - } - - /* Force reauth filter */ - if ((status = fort_prov_reauth(engine))) - goto cleanup; - - fort_timer_update(&fort_device()->log_timer, - (conf_flags.allow_all_new || conf_flags.log_blocked || conf_flags.log_stat - || conf_flags.log_blocked_ip)); - -cleanup: if (NT_SUCCESS(status)) { - status = fort_prov_trans_commit(engine); + fort_timer_update(&fort_device()->log_timer, + (conf_flags.allow_all_new || conf_flags.log_blocked || conf_flags.log_stat + || conf_flags.log_blocked_ip)); } else { - fort_prov_trans_abort(engine); - } - - fort_prov_close(engine); - -end: - if (!NT_SUCCESS(status)) { DbgPrintEx(DPFLTR_IHVNETWORK_ID, DPFLTR_ERROR_LEVEL, "FORT: Callout Reauth: Error: %x\n", status); }