external model file

This commit is contained in:
Roy Shilkrot 2023-09-22 18:06:22 -04:00
parent 306c2883a5
commit ed355d647f

View File

@ -216,6 +216,23 @@ void set_text_callback(struct transcription_filter_data *gf, const std::string &
} }
}; };
void shutdown_whisper_thread(struct transcription_filter_data *gf) {
if (gf->whisper_context != nullptr) {
// acquire the mutex before freeing the context
if (!gf->whisper_ctx_mutex || !gf->wshiper_thread_cv) {
obs_log(LOG_ERROR, "whisper_ctx_mutex is null");
return;
}
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
whisper_free(gf->whisper_context);
gf->whisper_context = nullptr;
gf->wshiper_thread_cv->notify_all();
}
if (gf->whisper_thread.joinable()) {
gf->whisper_thread.join();
}
}
void transcription_filter_update(void *data, obs_data_t *s) void transcription_filter_update(void *data, obs_data_t *s)
{ {
struct transcription_filter_data *gf = struct transcription_filter_data *gf =
@ -296,22 +313,13 @@ void transcription_filter_update(void *data, obs_data_t *s)
if (new_model_path != gf->whisper_model_path) { if (new_model_path != gf->whisper_model_path) {
// model path changed, reload the model // model path changed, reload the model
obs_log(LOG_INFO, "model path changed, reloading model"); obs_log(LOG_INFO, "model path changed, reloading model");
if (gf->whisper_context != nullptr) { shutdown_whisper_thread(gf);
// acquire the mutex before freeing the context
if (!gf->whisper_ctx_mutex || !gf->wshiper_thread_cv) {
obs_log(LOG_ERROR, "whisper_ctx_mutex is null");
return;
}
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
whisper_free(gf->whisper_context);
gf->whisper_context = nullptr;
gf->wshiper_thread_cv->notify_all();
}
if (gf->whisper_thread.joinable()) {
gf->whisper_thread.join();
}
gf->whisper_model_path = new_model_path; gf->whisper_model_path = new_model_path;
// check if the new model is external file
if (new_model_path.find("!!!external!!!") == std::string::npos) {
// new model is not external file
// check if the model exists, if not, download it // check if the model exists, if not, download it
std::string model_file_found = find_model_file(gf->whisper_model_path); std::string model_file_found = find_model_file(gf->whisper_model_path);
if (model_file_found == "") { if (model_file_found == "") {
@ -334,6 +342,13 @@ void transcription_filter_update(void *data, obs_data_t *s)
std::thread new_whisper_thread(whisper_loop, gf); std::thread new_whisper_thread(whisper_loop, gf);
gf->whisper_thread.swap(new_whisper_thread); gf->whisper_thread.swap(new_whisper_thread);
} }
} else {
// new model is local file, get file location from file property
std::string external_model_file_path = obs_data_get_string(s, "whisper_model_path_external");
gf->whisper_context = init_whisper_context(external_model_file_path);
std::thread new_whisper_thread(whisper_loop, gf);
gf->whisper_thread.swap(new_whisper_thread);
}
} }
if (!gf->whisper_ctx_mutex) { if (!gf->whisper_ctx_mutex) {
@ -495,6 +510,9 @@ void transcription_filter_defaults(obs_data_t *s)
obs_properties_t *transcription_filter_properties(void *data) obs_properties_t *transcription_filter_properties(void *data)
{ {
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
obs_properties_t *ppts = obs_properties_create(); obs_properties_t *ppts = obs_properties_create();
obs_properties_add_bool(ppts, "vad_enabled", "VAD Enabled"); obs_properties_add_bool(ppts, "vad_enabled", "VAD Enabled");
@ -564,6 +582,46 @@ obs_properties_t *transcription_filter_properties(void *data)
obs_property_list_add_string(whisper_models_list, "Small (Eng) 466Mb", obs_property_list_add_string(whisper_models_list, "Small (Eng) 466Mb",
"models/ggml-small.en.bin"); "models/ggml-small.en.bin");
obs_property_list_add_string(whisper_models_list, "Small 466Mb", "models/ggml-small.bin"); obs_property_list_add_string(whisper_models_list, "Small 466Mb", "models/ggml-small.bin");
obs_property_list_add_string(whisper_models_list, "Load external model file", "!!!external!!!");
// Add a file selection input to select an external model file
obs_property_t* whisper_model_path_external = obs_properties_add_path(ppts, "whisper_model_path_external", "External model file",
OBS_PATH_FILE, "Model (*.bin)", NULL);
// Hide the external model file selection input
obs_property_set_visible(obs_properties_get(ppts, "whisper_model_path_external"), false);
obs_property_set_modified_callback2(whisper_model_path_external, [](void* data, obs_properties_t *props,
obs_property_t *property,
obs_data_t *settings) {
UNUSED_PARAMETER(property);
UNUSED_PARAMETER(props);
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
shutdown_whisper_thread(gf);
std::string external_model_file_path = obs_data_get_string(settings, "whisper_model_path_external");
gf->whisper_context = init_whisper_context(external_model_file_path);
std::thread new_whisper_thread(whisper_loop, gf);
gf->whisper_thread.swap(new_whisper_thread);
return true;
}, gf);
// Add a callback to the model list to handle the external model file selection
obs_property_set_modified_callback(whisper_models_list, [](obs_properties_t *props,
obs_property_t *property,
obs_data_t *settings) {
UNUSED_PARAMETER(property);
// If the selected model is the external model, show the external model file selection
// input
const char *new_model_path = obs_data_get_string(settings, "whisper_model_path");
if (strcmp(new_model_path, "!!!external!!!") == 0) {
obs_property_set_visible(
obs_properties_get(props, "whisper_model_path_external"), true);
} else {
obs_property_set_visible(
obs_properties_get(props, "whisper_model_path_external"), false);
}
return true;
});
obs_properties_t *whisper_params_group = obs_properties_create(); obs_properties_t *whisper_params_group = obs_properties_create();
obs_properties_add_group(ppts, "whisper_params_group", "Whisper Parameters", obs_properties_add_group(ppts, "whisper_params_group", "Whisper Parameters",