diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index 34b5d97..256ea1b 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -17,8 +17,8 @@ #include "transcription-utils.h" #include "translation/translation.h" #include "translation/translation-includes.h" - -#define SEND_TIMED_METADATA_URL "http://localhost:8080/timed-metadata" +#include "whisper-utils/whisper-utils.h" +#include "whisper-utils/whisper-model-utils.h" void send_caption_to_source(const std::string &target_source_name, const std::string &caption, struct transcription_filter_data *gf) @@ -130,7 +130,13 @@ void set_text_callback(struct transcription_filter_data *gf, if (gf->caption_to_stream) { obs_output_t *streaming_output = obs_frontend_get_streaming_output(); if (streaming_output) { - obs_output_output_caption_text1(streaming_output, str_copy.c_str()); + // calculate the duration in seconds + const uint64_t duration = + result.end_timestamp_ms - result.start_timestamp_ms; + obs_log(gf->log_level, "Sending caption to streaming output: %s", + str_copy.c_str()); + obs_output_output_caption_text2(streaming_output, str_copy.c_str(), + (double)duration / 1000.0); obs_output_release(streaming_output); } } @@ -285,3 +291,23 @@ void media_stopped_callback(void *data_, calldata_t *cd) gf_->active = false; reset_caption_state(gf_); } + +void enable_callback(void *data_, calldata_t *cd) +{ + transcription_filter_data *gf_ = static_cast<transcription_filter_data *>(data_); + bool enable = calldata_bool(cd, "enabled"); + if (enable) { + obs_log(gf_->log_level, "enable_callback: enable"); + gf_->active = true; + reset_caption_state(gf_); + // get filter settings from gf_->context + obs_data_t *settings = obs_source_get_settings(gf_->context); + update_whisper_model(gf_, settings); + obs_data_release(settings); + } else { + obs_log(gf_->log_level, "enable_callback: disable"); + gf_->active = false; + reset_caption_state(gf_); + 
shutdown_whisper_thread(gf_); + } +} diff --git a/src/transcription-filter-callbacks.h b/src/transcription-filter-callbacks.h index a49f099..138cf1e 100644 --- a/src/transcription-filter-callbacks.h +++ b/src/transcription-filter-callbacks.h @@ -22,5 +22,6 @@ void media_started_callback(void *data_, calldata_t *cd); void media_pause_callback(void *data_, calldata_t *cd); void media_restart_callback(void *data_, calldata_t *cd); void media_stopped_callback(void *data_, calldata_t *cd); +void enable_callback(void *data_, calldata_t *cd); #endif /* TRANSCRIPTION_FILTER_CALLBACKS_H */ diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index 313b35c..9cb26c0 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -79,6 +79,7 @@ struct transcription_filter_data { bool fix_utf8 = true; bool enable_audio_chunks_callback = false; bool source_signals_set = false; + bool initial_creation = true; // Last transcription result std::string last_text; diff --git a/src/transcription-filter.c b/src/transcription-filter.c index 6162fab..6d23e65 100644 --- a/src/transcription-filter.c +++ b/src/transcription-filter.c @@ -14,4 +14,6 @@ struct obs_source_info transcription_filter_info = { .deactivate = transcription_filter_deactivate, .filter_audio = transcription_filter_filter_audio, .filter_remove = transcription_filter_remove, + .show = transcription_filter_show, + .hide = transcription_filter_hide, }; diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index ad75f4c..8119682 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -57,6 +57,7 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_ if (!audio) { return nullptr; } + if (data == nullptr) { return audio; } @@ -137,6 +138,9 @@ void transcription_filter_destroy(void *data) struct transcription_filter_data *gf = static_cast<struct transcription_filter_data *>(data); + signal_handler_t *sh_filter = 
obs_source_get_signal_handler(gf->context); + signal_handler_disconnect(sh_filter, "enable", enable_callback, gf); + obs_log(gf->log_level, "filter destroy"); shutdown_whisper_thread(gf); @@ -167,7 +171,7 @@ void transcription_filter_update(void *data, obs_data_t *s) struct transcription_filter_data *gf = static_cast<struct transcription_filter_data *>(data); - gf->log_level = (int)obs_data_get_int(s, "log_level"); + gf->log_level = LOG_INFO; //(int)obs_data_get_int(s, "log_level"); gf->vad_enabled = obs_data_get_bool(s, "vad_enabled"); gf->log_words = obs_data_get_bool(s, "log_words"); gf->caption_to_stream = obs_data_get_bool(s, "caption_to_stream"); @@ -293,51 +297,61 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->text_source_name = new_text_source_name; } - obs_log(gf->log_level, "update whisper model"); - update_whisper_model(gf, s); - obs_log(gf->log_level, "update whisper params"); - std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex); + { + std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex); - gf->sentence_psum_accept_thresh = - (float)obs_data_get_double(s, "sentence_psum_accept_thresh"); + gf->sentence_psum_accept_thresh = + (float)obs_data_get_double(s, "sentence_psum_accept_thresh"); - gf->whisper_params = whisper_full_default_params( - (whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method")); - gf->whisper_params.duration_ms = (int)obs_data_get_int(s, "buffer_size_msec"); - if (!new_translate || gf->translation_model_index != "whisper-based-translation") { - gf->whisper_params.language = obs_data_get_string(s, "whisper_language_select"); - } else { - // take the language from gf->target_lang - gf->whisper_params.language = language_codes_2_reverse[gf->target_lang].c_str(); + gf->whisper_params = whisper_full_default_params( + (whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method")); + gf->whisper_params.duration_ms = (int)obs_data_get_int(s, "buffer_size_msec"); + if (!new_translate || gf->translation_model_index != "whisper-based-translation") { 
+ gf->whisper_params.language = + obs_data_get_string(s, "whisper_language_select"); + } else { + // take the language from gf->target_lang + gf->whisper_params.language = + language_codes_2_reverse[gf->target_lang].c_str(); + } + gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt"); + gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads"); + gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx"); + gf->whisper_params.translate = obs_data_get_bool(s, "whisper_translate"); + gf->whisper_params.no_context = obs_data_get_bool(s, "no_context"); + gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment"); + gf->whisper_params.print_special = obs_data_get_bool(s, "print_special"); + gf->whisper_params.print_progress = obs_data_get_bool(s, "print_progress"); + gf->whisper_params.print_realtime = obs_data_get_bool(s, "print_realtime"); + gf->whisper_params.print_timestamps = obs_data_get_bool(s, "print_timestamps"); + gf->whisper_params.token_timestamps = obs_data_get_bool(s, "token_timestamps"); + gf->whisper_params.thold_pt = (float)obs_data_get_double(s, "thold_pt"); + gf->whisper_params.thold_ptsum = (float)obs_data_get_double(s, "thold_ptsum"); + gf->whisper_params.max_len = (int)obs_data_get_int(s, "max_len"); + gf->whisper_params.split_on_word = obs_data_get_bool(s, "split_on_word"); + gf->whisper_params.max_tokens = (int)obs_data_get_int(s, "max_tokens"); + gf->whisper_params.speed_up = obs_data_get_bool(s, "speed_up"); + gf->whisper_params.suppress_blank = obs_data_get_bool(s, "suppress_blank"); + gf->whisper_params.suppress_non_speech_tokens = + obs_data_get_bool(s, "suppress_non_speech_tokens"); + gf->whisper_params.temperature = (float)obs_data_get_double(s, "temperature"); + gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts"); + gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty"); + + if (gf->vad_enabled && gf->vad) { + 
const float vad_threshold = (float)obs_data_get_double(s, "vad_threshold"); + gf->vad->set_threshold(vad_threshold); + } } - gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt"); - gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads"); - gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx"); - gf->whisper_params.translate = obs_data_get_bool(s, "whisper_translate"); - gf->whisper_params.no_context = obs_data_get_bool(s, "no_context"); - gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment"); - gf->whisper_params.print_special = obs_data_get_bool(s, "print_special"); - gf->whisper_params.print_progress = obs_data_get_bool(s, "print_progress"); - gf->whisper_params.print_realtime = obs_data_get_bool(s, "print_realtime"); - gf->whisper_params.print_timestamps = obs_data_get_bool(s, "print_timestamps"); - gf->whisper_params.token_timestamps = obs_data_get_bool(s, "token_timestamps"); - gf->whisper_params.thold_pt = (float)obs_data_get_double(s, "thold_pt"); - gf->whisper_params.thold_ptsum = (float)obs_data_get_double(s, "thold_ptsum"); - gf->whisper_params.max_len = (int)obs_data_get_int(s, "max_len"); - gf->whisper_params.split_on_word = obs_data_get_bool(s, "split_on_word"); - gf->whisper_params.max_tokens = (int)obs_data_get_int(s, "max_tokens"); - gf->whisper_params.speed_up = obs_data_get_bool(s, "speed_up"); - gf->whisper_params.suppress_blank = obs_data_get_bool(s, "suppress_blank"); - gf->whisper_params.suppress_non_speech_tokens = - obs_data_get_bool(s, "suppress_non_speech_tokens"); - gf->whisper_params.temperature = (float)obs_data_get_double(s, "temperature"); - gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts"); - gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty"); - if (gf->vad_enabled && gf->vad) { - const float vad_threshold = (float)obs_data_get_double(s, "vad_threshold"); - 
gf->vad->set_threshold(vad_threshold); + if (gf->initial_creation && obs_source_enabled(gf->context)) { + // source was enabled on creation + obs_data_t *settings = obs_source_get_settings(gf->context); + update_whisper_model(gf, settings); + obs_data_release(settings); + gf->active = true; + gf->initial_creation = false; } } @@ -421,12 +435,13 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) gf->whisper_model_path = std::string(""); // The update function will set the model path gf->whisper_context = nullptr; + signal_handler_t *sh_filter = obs_source_get_signal_handler(gf->context); + signal_handler_connect(sh_filter, "enable", enable_callback, gf); + obs_log(gf->log_level, "run update"); // get the settings updated on the filter data struct transcription_filter_update(gf, settings); - gf->active = true; - // handle the event OBS_FRONTEND_EVENT_RECORDING_STARTING to reset the srt sentence number // to match the subtitles with the recording obs_frontend_add_event_callback(recording_state_callback, gf); @@ -466,6 +481,20 @@ void transcription_filter_deactivate(void *data) gf->active = false; } +void transcription_filter_show(void *data) +{ + struct transcription_filter_data *gf = + static_cast<struct transcription_filter_data *>(data); + obs_log(gf->log_level, "filter show"); +} + +void transcription_filter_hide(void *data) +{ + struct transcription_filter_data *gf = + static_cast<struct transcription_filter_data *>(data); + obs_log(gf->log_level, "filter hide"); +} + void transcription_filter_defaults(obs_data_t *s) { obs_log(LOG_INFO, "filter defaults"); @@ -586,11 +615,11 @@ obs_properties_t *transcription_filter_properties(void *data) whisper_model_path_external, [](void *data_, obs_properties_t *props, obs_property_t *property, obs_data_t *settings) { - obs_log(LOG_INFO, "whisper_model_path_external modified"); UNUSED_PARAMETER(property); UNUSED_PARAMETER(props); struct transcription_filter_data *gf_ = static_cast<struct transcription_filter_data *>(data_); + obs_log(gf_->log_level, "whisper_model_path_external modified"); 
transcription_filter_update(gf_, settings); return true; }, diff --git a/src/transcription-filter.h b/src/transcription-filter.h index 922d44e..a357785 100644 --- a/src/transcription-filter.h +++ b/src/transcription-filter.h @@ -16,6 +16,8 @@ void transcription_filter_deactivate(void *data); void transcription_filter_defaults(obs_data_t *s); obs_properties_t *transcription_filter_properties(void *data); void transcription_filter_remove(void *data, obs_source_t *source); +void transcription_filter_show(void *data); +void transcription_filter_hide(void *data); const char *const PLUGIN_INFO_TEMPLATE = "LocalVocal (%1) by " diff --git a/src/whisper-utils/whisper-model-utils.cpp b/src/whisper-utils/whisper-model-utils.cpp index c9620c8..af0d9c0 100644 --- a/src/whisper-utils/whisper-model-utils.cpp +++ b/src/whisper-utils/whisper-model-utils.cpp @@ -20,6 +20,9 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s) obs_log(LOG_ERROR, "Cannot find Silero VAD model file"); return; } + obs_log(gf->log_level, "Silero VAD model file: %s", silero_vad_model_file); + std::string silero_vad_model_file_str = std::string(silero_vad_model_file); + bfree(silero_vad_model_file); if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path || is_external_model) { @@ -49,14 +52,15 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s) obs_log(LOG_WARNING, "Whisper model does not exist"); download_model_with_ui_dialog( model_info, - [gf, new_model_path, silero_vad_model_file]( + [gf, new_model_path, silero_vad_model_file_str]( int download_status, const std::string &path) { if (download_status == 0) { obs_log(LOG_INFO, "Model download complete"); gf->whisper_model_path = new_model_path; start_whisper_thread_with_path( - gf, path, silero_vad_model_file); + gf, path, + silero_vad_model_file_str.c_str()); } else { obs_log(LOG_ERROR, "Model download failed"); } @@ -65,7 +69,7 @@ void update_whisper_model(struct 
transcription_filter_data *gf, obs_data_t *s) // Model exists, just load it gf->whisper_model_path = new_model_path; start_whisper_thread_with_path(gf, model_file_found, - silero_vad_model_file); + silero_vad_model_file_str.c_str()); } } else { // new model is external file, get file location from file property @@ -82,8 +86,9 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s) } else { shutdown_whisper_thread(gf); gf->whisper_model_path = new_model_path; - start_whisper_thread_with_path(gf, external_model_file_path, - silero_vad_model_file); + start_whisper_thread_with_path( + gf, external_model_file_path, + silero_vad_model_file_str.c_str()); } } } @@ -101,6 +106,7 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s) gf->enable_token_ts_dtw, new_dtw_timestamps); gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps"); shutdown_whisper_thread(gf); - start_whisper_thread_with_path(gf, gf->whisper_model_path, silero_vad_model_file); + start_whisper_thread_with_path(gf, gf->whisper_model_path, + silero_vad_model_file_str.c_str()); } } diff --git a/src/whisper-utils/whisper-utils.cpp b/src/whisper-utils/whisper-utils.cpp index 7dc8f5c..c2e4929 100644 --- a/src/whisper-utils/whisper-utils.cpp +++ b/src/whisper-utils/whisper-utils.cpp @@ -5,6 +5,10 @@ #include <obs-module.h> +#ifdef _WIN32 +#include <Windows.h> +#endif + void shutdown_whisper_thread(struct transcription_filter_data *gf) { obs_log(gf->log_level, "shutdown_whisper_thread"); @@ -27,7 +31,8 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &whisper_model_path, const char *silero_vad_model_file) { - obs_log(gf->log_level, "start_whisper_thread_with_path: %s", whisper_model_path.c_str()); + obs_log(gf->log_level, "start_whisper_thread_with_path: %s, silero model path: %s", + whisper_model_path.c_str(), silero_vad_model_file); std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex); if (gf->whisper_context != nullptr) { 
obs_log(LOG_ERROR, "cannot init whisper: whisper_context is not null"); @@ -36,16 +41,22 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, // initialize Silero VAD #ifdef _WIN32 - std::wstring silero_vad_model_path; - silero_vad_model_path.assign(silero_vad_model_file, - silero_vad_model_file + strlen(silero_vad_model_file)); + // convert mbstring to wstring + int count = MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, + strlen(silero_vad_model_file), NULL, 0); + std::wstring silero_vad_model_path(count, 0); + MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, strlen(silero_vad_model_file), + &silero_vad_model_path[0], count); + obs_log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str()); #else std::string silero_vad_model_path = silero_vad_model_file; + obs_log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str()); #endif // roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py // for silero vad parameters gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE)); + obs_log(gf->log_level, "Create whisper context"); gf->whisper_context = init_whisper_context(whisper_model_path, gf); if (gf->whisper_context == nullptr) { obs_log(LOG_ERROR, "Failed to initialize whisper context");