fix whisper model loading language (#139)

* refactor: Add boolean flag for whisper model loaded status

* refactor: Improve handling of whisper model paths in transcription filter

* refactor: Update whisper model path and add flag for model loaded status
This commit is contained in:
Roy Shilkrot 2024-07-17 18:54:34 -04:00 committed by GitHub
parent 44f072b5ff
commit 4e3fdcd6ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 38 additions and 25 deletions

View File

@ -93,6 +93,7 @@ struct transcription_filter_data {
// Output file path to write the subtitles
std::string output_file_path;
std::string whisper_model_file_currently_loaded;
bool whisper_model_loaded_new;
// Use std for thread and mutex
std::thread whisper_thread;

View File

@ -69,14 +69,20 @@ bool file_output_select_changed(obs_properties_t *props, obs_property_t *propert
return true;
}
bool external_model_file_selection(obs_properties_t *props, obs_property_t *property,
bool external_model_file_selection(void *data_, obs_properties_t *props, obs_property_t *property,
obs_data_t *settings)
{
UNUSED_PARAMETER(property);
struct transcription_filter_data *gf_ =
static_cast<struct transcription_filter_data *>(data_);
// If the selected model is the external model, show the external model file selection
// input
const char *new_model_path = obs_data_get_string(settings, "whisper_model_path");
const bool is_external = strcmp(new_model_path, "!!!external!!!") == 0;
const char *new_model_path_cstr =
obs_data_get_string(settings, "whisper_model_path") != nullptr
? obs_data_get_string(settings, "whisper_model_path")
: "";
const std::string new_model_path = new_model_path_cstr;
const bool is_external = (new_model_path.find("!!!external!!!") != std::string::npos);
if (is_external) {
obs_property_set_visible(obs_properties_get(props, "whisper_model_path_external"),
true);
@ -85,26 +91,29 @@ bool external_model_file_selection(obs_properties_t *props, obs_property_t *prop
false);
}
const std::string model_name = new_model_path;
// if the model is english-only -> hide all the languages but english
const bool is_english_only_internal = (model_name.find("English") != std::string::npos) &&
!is_external;
// clear the language selection list ("whisper_language_select")
obs_property_t *prop_lang = obs_properties_get(props, "whisper_language_select");
obs_property_list_clear(prop_lang);
if (is_english_only_internal) {
// add only the english language
obs_property_list_add_string(prop_lang, "English", "en");
// set the language to english
obs_data_set_string(settings, "whisper_language_select", "en");
} else {
// add all the languages
for (const auto &lang : whisper_available_lang) {
obs_property_list_add_string(prop_lang, lang.second.c_str(),
lang.first.c_str());
// check if this is a new model selection
if (gf_->whisper_model_loaded_new) {
// if the model is english-only -> hide all the languages but english
const bool is_english_only_internal =
(new_model_path.find("English") != std::string::npos) && !is_external;
// clear the language selection list ("whisper_language_select")
obs_property_t *prop_lang = obs_properties_get(props, "whisper_language_select");
obs_property_list_clear(prop_lang);
if (is_english_only_internal) {
// add only the english language
obs_property_list_add_string(prop_lang, "English", "en");
// set the language to english
obs_data_set_string(settings, "whisper_language_select", "en");
} else {
// add all the languages
for (const auto &lang : whisper_available_lang) {
obs_property_list_add_string(prop_lang, lang.second.c_str(),
lang.first.c_str());
}
// set the language to auto (default)
obs_data_set_string(settings, "whisper_language_select", "auto");
}
// set the language to auto (default)
obs_data_set_string(settings, "whisper_language_select", "auto");
gf_->whisper_model_loaded_new = false;
}
return true;
}
@ -131,7 +140,8 @@ bool translation_external_model_selection(obs_properties_t *props, obs_property_
return true;
}
void add_transcription_group_properties(obs_properties_t *ppts)
void add_transcription_group_properties(obs_properties_t *ppts,
struct transcription_filter_data *gf)
{
// add "Transcription" group
obs_properties_t *transcription_group = obs_properties_create();
@ -159,7 +169,7 @@ void add_transcription_group_properties(obs_properties_t *ppts)
obs_property_set_visible(obs_properties_get(ppts, "whisper_model_path_external"), false);
// Add a callback to the model list to handle the external model file selection
obs_property_set_modified_callback(whisper_models_list, external_model_file_selection);
obs_property_set_modified_callback2(whisper_models_list, external_model_file_selection, gf);
}
void add_translation_group_properties(obs_properties_t *ppts)
@ -474,7 +484,7 @@ obs_properties_t *transcription_filter_properties(void *data)
obs_property_set_modified_callback(advanced_settings, advanced_settings_callback);
add_general_group_properties(ppts);
add_transcription_group_properties(ppts);
add_transcription_group_properties(ppts, gf);
add_translation_group_properties(ppts);
add_file_output_group_properties(ppts);
add_buffered_output_group_properties(ppts);

View File

@ -62,6 +62,8 @@ void update_whisper_model(struct transcription_filter_data *gf)
// model path changed
obs_log(gf->log_level, "model path changed from %s to %s",
gf->whisper_model_path.c_str(), new_model_path.c_str());
gf->whisper_model_loaded_new = true;
}
// check if the new model is external file