minor fixes

This commit is contained in:
Roy Shilkrot 2023-08-20 07:13:37 +03:00
parent b14ba3e93f
commit f84e48fb6c
2 changed files with 18 additions and 19 deletions

View File

@ -37,11 +37,11 @@ struct transcription_filter_data {
struct circlebuf input_buffers[MAX_PREPROC_CHANNELS]; struct circlebuf input_buffers[MAX_PREPROC_CHANNELS];
/* Resampler */ /* Resampler */
audio_resampler_t *resampler; audio_resampler_t *resampler = nullptr;
/* whisper */ /* whisper */
std::string whisper_model_path = "models/ggml-tiny.en.bin"; std::string whisper_model_path = "models/ggml-tiny.en.bin";
struct whisper_context *whisper_context; struct whisper_context *whisper_context = nullptr;
whisper_full_params whisper_params; whisper_full_params whisper_params;
float filler_p_threshold; float filler_p_threshold;
@ -50,21 +50,21 @@ struct transcription_filter_data {
bool vad_enabled; bool vad_enabled;
int log_level; int log_level;
bool log_words; bool log_words;
bool active; bool active = false;
// Text source to output the subtitles // Text source to output the subtitles
obs_weak_source_t *text_source; obs_weak_source_t *text_source = nullptr;
char *text_source_name; char *text_source_name = nullptr;
std::unique_ptr<std::mutex> text_source_mutex; std::unique_ptr<std::mutex> text_source_mutex = nullptr;
// Callback to set the text in the output text source (subtitles) // Callback to set the text in the output text source (subtitles)
std::function<void(const std::string &str)> setTextCallback; std::function<void(const std::string &str)> setTextCallback;
// Use std for thread and mutex // Use std for thread and mutex
std::thread whisper_thread; std::thread whisper_thread;
std::unique_ptr<std::mutex> whisper_buf_mutex; std::unique_ptr<std::mutex> whisper_buf_mutex = nullptr;
std::unique_ptr<std::mutex> whisper_ctx_mutex; std::unique_ptr<std::mutex> whisper_ctx_mutex = nullptr;
std::unique_ptr<std::condition_variable> wshiper_thread_cv; std::unique_ptr<std::condition_variable> wshiper_thread_cv = nullptr;
}; };
// Audio packet info // Audio packet info

View File

@ -206,8 +206,9 @@ void transcription_filter_update(void *data, obs_data_t *s)
obs_weak_source_release(old_weak_text_source); obs_weak_source_release(old_weak_text_source);
} }
const char *new_model_path = obs_data_get_string(s, "whisper_model_path"); std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
if (strcmp(new_model_path, gf->whisper_model_path.c_str()) != 0) {
if (new_model_path != gf->whisper_model_path) {
// model path changed, reload the model // model path changed, reload the model
obs_log(LOG_INFO, "model path changed, reloading model"); obs_log(LOG_INFO, "model path changed, reloading model");
if (gf->whisper_context != nullptr) { if (gf->whisper_context != nullptr) {
@ -220,7 +221,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
if (gf->whisper_thread.joinable()) { if (gf->whisper_thread.joinable()) {
gf->whisper_thread.join(); gf->whisper_thread.join();
} }
gf->whisper_model_path = bstrdup(new_model_path); gf->whisper_model_path = new_model_path;
// check if the model exists, if not, download it // check if the model exists, if not, download it
if (!check_if_model_exists(gf->whisper_model_path)) { if (!check_if_model_exists(gf->whisper_model_path)) {
@ -229,8 +230,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->whisper_model_path, [gf](int download_status) { gf->whisper_model_path, [gf](int download_status) {
if (download_status == 0) { if (download_status == 0) {
obs_log(LOG_INFO, "Model download complete"); obs_log(LOG_INFO, "Model download complete");
gf->whisper_context = init_whisper_context( gf->whisper_context = init_whisper_context(gf->whisper_model_path);
gf->whisper_model_path);
gf->whisper_thread = std::thread(whisper_loop, gf); gf->whisper_thread = std::thread(whisper_loop, gf);
} else { } else {
obs_log(LOG_ERROR, "Model download failed"); obs_log(LOG_ERROR, "Model download failed");
@ -321,8 +321,6 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
gf->resampler = audio_resampler_create(&dst, &src); gf->resampler = audio_resampler_create(&dst, &src);
gf->active = true;
gf->whisper_buf_mutex = std::unique_ptr<std::mutex>(new std::mutex()); gf->whisper_buf_mutex = std::unique_ptr<std::mutex>(new std::mutex());
gf->whisper_ctx_mutex = std::unique_ptr<std::mutex>(new std::mutex()); gf->whisper_ctx_mutex = std::unique_ptr<std::mutex>(new std::mutex());
gf->wshiper_thread_cv = gf->wshiper_thread_cv =
@ -340,12 +338,11 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
std::lock_guard<std::mutex> lock(*gf->text_source_mutex); std::lock_guard<std::mutex> lock(*gf->text_source_mutex);
obs_weak_source_t *text_source = gf->text_source; if (!gf->text_source) {
if (!text_source) {
obs_log(LOG_ERROR, "text_source is null"); obs_log(LOG_ERROR, "text_source is null");
return; return;
} }
auto target = obs_weak_source_get_source(text_source); auto target = obs_weak_source_get_source(gf->text_source);
if (!target) { if (!target) {
obs_log(LOG_ERROR, "text_source target is null"); obs_log(LOG_ERROR, "text_source target is null");
return; return;
@ -362,6 +359,8 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
// start the thread // start the thread
gf->whisper_thread = std::thread(whisper_loop, gf); gf->whisper_thread = std::thread(whisper_loop, gf);
gf->active = true;
return gf; return gf;
} }