diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp
index baa76cd..e9256d1 100644
--- a/src/transcription-filter.cpp
+++ b/src/transcription-filter.cpp
@@ -216,6 +216,30 @@ void set_text_callback(struct transcription_filter_data *gf, const std::string &
 	}
 };
 
+// Stop the whisper worker: free the whisper context (if one is loaded) while
+// holding whisper_ctx_mutex, notify wshiper_thread_cv, then join
+// whisper_thread if it is joinable.
+// NOTE(review): if the mutex or condition variable pointer is null this logs
+// and returns without joining, yet the callers in this patch carry on and
+// start a new thread anyway - confirm that is intended.
+void shutdown_whisper_thread(struct transcription_filter_data *gf)
+{
+	if (gf->whisper_context != nullptr) {
+		// acquire the mutex before freeing the context
+		if (!gf->whisper_ctx_mutex || !gf->wshiper_thread_cv) {
+			obs_log(LOG_ERROR, "whisper_ctx_mutex or wshiper_thread_cv is null");
+			return;
+		}
+		std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
+		whisper_free(gf->whisper_context);
+		gf->whisper_context = nullptr;
+		gf->wshiper_thread_cv->notify_all();
+	}
+	if (gf->whisper_thread.joinable()) {
+		gf->whisper_thread.join();
+	}
+}
+
 void transcription_filter_update(void *data, obs_data_t *s)
 {
 	struct transcription_filter_data *gf =
@@ -296,41 +320,43 @@ void transcription_filter_update(void *data, obs_data_t *s)
 	if (new_model_path != gf->whisper_model_path) {
 		// model path changed, reload the model
 		obs_log(LOG_INFO, "model path changed, reloading model");
-		if (gf->whisper_context != nullptr) {
-			// acquire the mutex before freeing the context
-			if (!gf->whisper_ctx_mutex || !gf->wshiper_thread_cv) {
-				obs_log(LOG_ERROR, "whisper_ctx_mutex is null");
-				return;
-			}
-			std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
-			whisper_free(gf->whisper_context);
-			gf->whisper_context = nullptr;
-			gf->wshiper_thread_cv->notify_all();
-		}
-		if (gf->whisper_thread.joinable()) {
-			gf->whisper_thread.join();
-		}
+		shutdown_whisper_thread(gf);
+
 		gf->whisper_model_path = new_model_path;
 
-		// check if the model exists, if not, download it
-		std::string model_file_found = find_model_file(gf->whisper_model_path);
-		if (model_file_found == "") {
-			obs_log(LOG_WARNING, "Whisper model does not exist");
-			download_model_with_ui_dialog(
-				gf->whisper_model_path,
-				[gf](int download_status, const std::string &path) {
-					if (download_status == 0) {
-						obs_log(LOG_INFO, "Model download complete");
-						gf->whisper_context = init_whisper_context(path);
-						std::thread new_whisper_thread(whisper_loop, gf);
-						gf->whisper_thread.swap(new_whisper_thread);
-					} else {
-						obs_log(LOG_ERROR, "Model download failed");
-					}
-				});
+		// check if the new model is external file
+		if (new_model_path.find("!!!external!!!") == std::string::npos) {
+			// new model is not external file
+			// check if the model exists, if not, download it
+			std::string model_file_found = find_model_file(gf->whisper_model_path);
+			if (model_file_found == "") {
+				obs_log(LOG_WARNING, "Whisper model does not exist");
+				download_model_with_ui_dialog(
+					gf->whisper_model_path,
+					[gf](int download_status, const std::string &path) {
+						if (download_status == 0) {
+							obs_log(LOG_INFO,
+								"Model download complete");
+							gf->whisper_context =
+								init_whisper_context(path);
+							std::thread new_whisper_thread(whisper_loop,
+										       gf);
+							gf->whisper_thread.swap(new_whisper_thread);
+						} else {
+							obs_log(LOG_ERROR, "Model download failed");
+						}
+					});
+			} else {
+				// Model exists, just load it
+				gf->whisper_context = init_whisper_context(model_file_found);
+				std::thread new_whisper_thread(whisper_loop, gf);
+				gf->whisper_thread.swap(new_whisper_thread);
+			}
 		} else {
-			// Model exists, just load it
-			gf->whisper_context = init_whisper_context(model_file_found);
+			// new model is local file, get file location from file property
+			std::string external_model_file_path =
+				obs_data_get_string(s, "whisper_model_path_external");
+			gf->whisper_context = init_whisper_context(external_model_file_path);
 			std::thread new_whisper_thread(whisper_loop, gf);
 			gf->whisper_thread.swap(new_whisper_thread);
 		}
@@ -495,6 +521,9 @@ void transcription_filter_defaults(obs_data_t *s)
 
 obs_properties_t *transcription_filter_properties(void *data)
 {
+	struct transcription_filter_data *gf =
+		static_cast<struct transcription_filter_data *>(data);
+
 	obs_properties_t *ppts = obs_properties_create();
 
 	obs_properties_add_bool(ppts, "vad_enabled", "VAD Enabled");
@@ -564,6 +593,51 @@ obs_properties_t *transcription_filter_properties(void *data)
 	obs_property_list_add_string(whisper_models_list, "Small (Eng) 466Mb",
 				     "models/ggml-small.en.bin");
 	obs_property_list_add_string(whisper_models_list, "Small 466Mb", "models/ggml-small.bin");
+	obs_property_list_add_string(whisper_models_list, "Load external model file",
+				     "!!!external!!!");
+
+	// Add a file selection input to select an external model file
+	obs_property_t *whisper_model_path_external =
+		obs_properties_add_path(ppts, "whisper_model_path_external", "External model file",
+					OBS_PATH_FILE, "Model (*.bin)", NULL);
+	// Hide the external model file selection input
+	obs_property_set_visible(obs_properties_get(ppts, "whisper_model_path_external"), false);
+
+	obs_property_set_modified_callback2(
+		whisper_model_path_external,
+		[](void *data_, obs_properties_t *props, obs_property_t *property,
+		   obs_data_t *settings) {
+			UNUSED_PARAMETER(property);
+			UNUSED_PARAMETER(props);
+			struct transcription_filter_data *gf_ =
+				static_cast<struct transcription_filter_data *>(data_);
+			shutdown_whisper_thread(gf_);
+			std::string external_model_file_path =
+				obs_data_get_string(settings, "whisper_model_path_external");
+			gf_->whisper_context = init_whisper_context(external_model_file_path);
+			std::thread new_whisper_thread(whisper_loop, gf_);
+			gf_->whisper_thread.swap(new_whisper_thread);
+			return true;
+		},
+		gf);
+
+	// Add a callback to the model list to handle the external model file selection
+	obs_property_set_modified_callback(whisper_models_list, [](obs_properties_t *props,
+								   obs_property_t *property,
+								   obs_data_t *settings) {
+		UNUSED_PARAMETER(property);
+		// If the selected model is the external model, show the external model file selection
+		// input
+		const char *new_model_path = obs_data_get_string(settings, "whisper_model_path");
+		if (strcmp(new_model_path, "!!!external!!!") == 0) {
+			obs_property_set_visible(
+				obs_properties_get(props, "whisper_model_path_external"), true);
+		} else {
+			obs_property_set_visible(
+				obs_properties_get(props, "whisper_model_path_external"), false);
+		}
+		return true;
+	});
 
 	obs_properties_t *whisper_params_group = obs_properties_create();
 	obs_properties_add_group(ppts, "whisper_params_group", "Whisper Parameters",