From 97937da6a88862401a68d55a70e3a7610ceb7f3d Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Sat, 16 Sep 2023 12:01:14 +0300 Subject: [PATCH] models to config folder --- src/model-utils/model-downloader-types.h | 2 + src/model-utils/model-downloader-ui.cpp | 74 +++++++++++++++++------- src/model-utils/model-downloader-ui.h | 15 +++-- src/model-utils/model-downloader.cpp | 41 ++++++++----- src/model-utils/model-downloader.h | 6 +- src/transcription-filter-data.h | 2 +- src/transcription-filter.cpp | 28 +++------ src/whisper-processing.cpp | 2 +- 8 files changed, 108 insertions(+), 62 deletions(-) create mode 100644 src/model-utils/model-downloader-types.h diff --git a/src/model-utils/model-downloader-types.h b/src/model-utils/model-downloader-types.h new file mode 100644 index 0000000..4c7e25a --- /dev/null +++ b/src/model-utils/model-downloader-types.h @@ -0,0 +1,2 @@ + +typedef std::function download_finished_callback_t; diff --git a/src/model-utils/model-downloader-ui.cpp b/src/model-utils/model-downloader-ui.cpp index 9386273..a05249a 100644 --- a/src/model-utils/model-downloader-ui.cpp +++ b/src/model-utils/model-downloader-ui.cpp @@ -3,6 +3,8 @@ #include +#include + const std::string MODEL_BASE_PATH = "https://huggingface.co/ggerganov/whisper.cpp"; const std::string MODEL_PREFIX = "resolve/main/"; @@ -14,12 +16,15 @@ size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream) ModelDownloader::ModelDownloader( const std::string &model_name, - std::function download_finished_callback_, QWidget *parent) + download_finished_callback_t download_finished_callback_, QWidget *parent) : QDialog(parent), download_finished_callback(download_finished_callback_) { - this->setWindowTitle("Downloading model..."); + this->setWindowTitle("LocalVocal: Downloading model..."); this->setWindowFlags(Qt::Dialog | Qt::WindowTitleHint | Qt::CustomizeWindowHint); this->setFixedSize(300, 100); + // Bring the dialog to the front + this->activateWindow(); + this->raise(); this->layout = new QVBoxLayout(this); @@ -59,24 +64,32 @@ ModelDownloader::ModelDownloader( this->download_thread->start(); } +void ModelDownloader::closeEvent(QCloseEvent *e) +{ + if (!this->mPrepareToClose) + e->ignore(); + else + QDialog::closeEvent(e); +} + +void ModelDownloader::close() +{ + this->mPrepareToClose = true; + + QDialog::close(); +} + void ModelDownloader::update_progress(int progress) { this->progress_bar->setValue(progress); } -void ModelDownloader::download_finished() +void ModelDownloader::download_finished(const std::string& path) { - this->setWindowTitle("Download finished!"); - this->progress_bar->setValue(100); - this->progress_bar->setFormat("Download finished!"); - this->progress_bar->setAlignment(Qt::AlignCenter); - this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #05B8CC; }"); - // Add a button to close the dialog - QPushButton *close_button = new QPushButton("Close", this); - this->layout->addWidget(close_button); - connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close); - // Call the callback - this->download_finished_callback(0); + // Call the callback with the path to the downloaded model + this->download_finished_callback(0, path); + // Close the dialog + this->close(); } void ModelDownloader::show_error(const std::string &reason) @@ -96,7 +109,7 @@ void ModelDownloader::show_error(const std::string &reason) QPushButton *close_button = new QPushButton("Close", this); this->layout->addWidget(close_button); connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close); - this->download_finished_callback(1); + this->download_finished_callback(1, ""); } ModelDownloadWorker::ModelDownloadWorker(const std::string &model_name_) @@ -106,9 +119,22 @@ ModelDownloadWorker::ModelDownloadWorker(const std::string &model_name_) void ModelDownloadWorker::download_model() { - std::string module_data_dir = obs_get_module_data_path(obs_current_module()); - // join the directory and the filename using the platform-specific separator - std::string model_save_path = module_data_dir + "/" + this->model_name; + char* module_config_path = obs_module_get_config_path(obs_current_module(), "models"); + // Check if the config folder exists + if (!std::filesystem::exists(module_config_path)) { + obs_log(LOG_WARNING, "Config folder does not exist: %s", module_config_path); + // Create the config folder + if (!std::filesystem::create_directories(module_config_path)) { + obs_log(LOG_ERROR, "Failed to create config folder: %s", module_config_path); + emit download_error("Failed to create config folder."); + return; + } + } + + char *model_save_path_str = + obs_module_get_config_path(obs_current_module(), this->model_name.c_str()); + std::string model_save_path(model_save_path_str); + bfree(model_save_path_str); obs_log(LOG_INFO, "Model save path: %s", model_save_path.c_str()); // extract filename from path in this->modle_name @@ -143,11 +169,11 @@ void ModelDownloadWorker::download_model() } curl_easy_cleanup(curl); fclose(fp); + emit download_finished(model_save_path); } else { obs_log(LOG_ERROR, "Failed to initialize curl."); emit download_error("Failed to initialize curl."); } - emit download_finished(); } int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow, @@ -168,9 +194,13 @@ int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, cu ModelDownloader::~ModelDownloader() { - this->download_thread->quit(); - this->download_thread->wait(); - delete this->download_thread; + if (this->download_thread != nullptr) { + if (this->download_thread->isRunning()) { + this->download_thread->quit(); + this->download_thread->wait(); + } + delete this->download_thread; + } delete this->download_worker; } diff --git a/src/model-utils/model-downloader-ui.h b/src/model-utils/model-downloader-ui.h index b567a9a..c8fea10 100644 --- a/src/model-utils/model-downloader-ui.h +++ b/src/model-utils/model-downloader-ui.h @@ -9,6 +9,8 @@ #include +#include "model-downloader-types.h" + class ModelDownloadWorker : public QObject { Q_OBJECT public: @@ -20,7 +22,7 @@ public slots: signals: void download_progress(int progress); - void download_finished(); + void download_finished(const std::string& path); void download_error(const std::string &reason); private: @@ -33,22 +35,27 @@ class ModelDownloader : public QDialog { Q_OBJECT public: ModelDownloader(const std::string &model_name, - std::function download_finished_callback, + download_finished_callback_t download_finished_callback, QWidget *parent = nullptr); ~ModelDownloader(); public slots: void update_progress(int progress); - void download_finished(); + void download_finished(const std::string& path); void show_error(const std::string &reason); +protected: + void closeEvent(QCloseEvent *e) override; + private: QVBoxLayout *layout; QProgressBar *progress_bar; QThread *download_thread; ModelDownloadWorker *download_worker; // Callback for when the download is finished - std::function download_finished_callback; + download_finished_callback_t download_finished_callback; + bool mPrepareToClose; + void close(); }; #endif // MODEL_DOWNLOADER_UI_H diff --git a/src/model-utils/model-downloader.cpp b/src/model-utils/model-downloader.cpp index af1f7a8..dbd043c 100644 --- a/src/model-utils/model-downloader.cpp +++ b/src/model-utils/model-downloader.cpp @@ -12,28 +12,43 @@ #include -bool check_if_model_exists(const std::string &model_name) +std::string find_model_file(const std::string &model_name) { - obs_log(LOG_INFO, "Checking if model %s exists...", model_name.c_str()); - char *model_file_path = obs_module_file(model_name.c_str()); - obs_log(LOG_INFO, "Model file path: %s", model_file_path); + const char *model_name_cstr = model_name.c_str(); + obs_log(LOG_INFO, "Checking if model %s exists in data...", model_name_cstr); + + char *model_file_path = obs_module_file(model_name_cstr); if (model_file_path == nullptr) { - obs_log(LOG_INFO, "Model %s does not exist.", model_name.c_str()); - return false; + obs_log(LOG_INFO, "Model %s not found in data.", model_name_cstr); + } else { + std::string model_file_path_str(model_file_path); + bfree(model_file_path); + if (!std::filesystem::exists(model_file_path_str)) { + obs_log(LOG_INFO, "Model not found in data: %s", model_file_path_str.c_str()); + } else { + obs_log(LOG_INFO, "Model found in data: %s", model_file_path_str.c_str()); + return model_file_path_str; + } } - if (!std::filesystem::exists(model_file_path)) { - obs_log(LOG_INFO, "Model %s does not exist.", model_file_path); - bfree(model_file_path); - return false; + // Check if model exists in the config folder + char *model_config_path_str = + obs_module_get_config_path(obs_current_module(), model_name_cstr); + std::string model_config_path(model_config_path_str); + bfree(model_config_path_str); + obs_log(LOG_INFO, "Model path in config: %s", model_config_path.c_str()); + if (std::filesystem::exists(model_config_path)) { + obs_log(LOG_INFO, "Model exists in config folder: %s", model_config_path.c_str()); + return model_config_path; } - bfree(model_file_path); - return true; + + obs_log(LOG_INFO, "Model %s not found.", model_name_cstr); + return ""; } void download_model_with_ui_dialog( const std::string &model_name, - std::function download_finished_callback) + download_finished_callback_t download_finished_callback) { // Start the model downloader UI ModelDownloader *model_downloader = new ModelDownloader( diff --git a/src/model-utils/model-downloader.h b/src/model-utils/model-downloader.h index 7d67ade..db5c317 100644 --- a/src/model-utils/model-downloader.h +++ b/src/model-utils/model-downloader.h @@ -4,11 +4,13 @@ #include #include -bool check_if_model_exists(const std::string &model_name); +#include "model-downloader-types.h" + +std::string find_model_file(const std::string &model_name); // Start the model downloader UI dialog with a callback for when the download is finished void download_model_with_ui_dialog( const std::string &model_name, - std::function download_finished_callback); + download_finished_callback_t download_finished_callback); #endif // MODEL_DOWNLOADER_H diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index d53af79..a43be9a 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -40,7 +40,7 @@ struct transcription_filter_data { audio_resampler_t *resampler = nullptr; /* whisper */ - std::string whisper_model_path = "models/ggml-tiny.en.bin"; + std::string whisper_model_path; struct whisper_context *whisper_context = nullptr; whisper_full_params whisper_params; diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 8d3a4ef..18e49b4 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -313,14 +313,14 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->whisper_model_path = new_model_path; // check if the model exists, if not, download it - if (!check_if_model_exists(gf->whisper_model_path)) { - obs_log(LOG_ERROR, "Whisper model does not exist"); + std::string model_file_found = find_model_file(gf->whisper_model_path); + if (model_file_found == "") { + obs_log(LOG_WARNING, "Whisper model does not exist"); download_model_with_ui_dialog( - gf->whisper_model_path, [gf](int download_status) { + gf->whisper_model_path, [gf](int download_status, const std::string &path) { if (download_status == 0) { obs_log(LOG_INFO, "Model download complete"); - gf->whisper_context = init_whisper_context( - gf->whisper_model_path); + gf->whisper_context = init_whisper_context(path); std::thread new_whisper_thread(whisper_loop, gf); gf->whisper_thread.swap(new_whisper_thread); } else { @@ -329,7 +329,7 @@ void transcription_filter_update(void *data, obs_data_t *s) }); } else { // Model exists, just load it - gf->whisper_context = init_whisper_context(gf->whisper_model_path); + gf->whisper_context = init_whisper_context(model_file_found); std::thread new_whisper_thread(whisper_loop, gf); gf->whisper_thread.swap(new_whisper_thread); } @@ -374,8 +374,8 @@ void transcription_filter_update(void *data, obs_data_t *s) void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) { - struct transcription_filter_data *gf = static_cast( - bzalloc(sizeof(struct transcription_filter_data))); + void *p = bzalloc(sizeof(struct transcription_filter_data)); + struct transcription_filter_data *gf = new (p) transcription_filter_data; // Get the number of channels for the input source gf->channels = audio_output_get_channels(obs_get_audio()); @@ -396,12 +396,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) } gf->context = filter; - gf->whisper_model_path = std::string(obs_data_get_string(settings, "whisper_model_path")); - gf->whisper_context = init_whisper_context(gf->whisper_model_path); - if (gf->whisper_context == nullptr) { - obs_log(LOG_ERROR, "Failed to load whisper model"); - return nullptr; - } + gf->whisper_model_path = ""; // The update function will set the model path gf->overlap_ms = OVERLAP_SIZE_MSEC; gf->overlap_frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms)); @@ -433,11 +428,6 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) // get the settings updated on the filter data struct transcription_filter_update(gf, settings); - obs_log(gf->log_level, "transcription_filter: start whisper thread"); - // start the thread - std::thread new_whisper_thread(whisper_loop, gf); - gf->whisper_thread.swap(new_whisper_thread); - gf->active = true; obs_log(gf->log_level, "transcription_filter: filter created."); diff --git a/src/whisper-processing.cpp b/src/whisper-processing.cpp index cf7ba7b..0eba784 100644 --- a/src/whisper-processing.cpp +++ b/src/whisper-processing.cpp @@ -73,7 +73,7 @@ bool vad_simple(float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, float v struct whisper_context *init_whisper_context(const std::string &model_path) { obs_log(LOG_INFO, "Loading whisper model from %s", model_path.c_str()); - struct whisper_context *ctx = whisper_init_from_file(obs_module_file(model_path.c_str())); + struct whisper_context *ctx = whisper_init_from_file(model_path.c_str()); if (ctx == nullptr) { obs_log(LOG_ERROR, "Failed to load whisper model"); return nullptr;