diff --git a/CMakeLists.txt b/CMakeLists.txt index 88f2875..2b819a6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,7 +50,9 @@ endif() include(cmake/BuildWhispercpp.cmake) target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE Whispercpp) -target_sources(${CMAKE_PROJECT_NAME} PRIVATE src/plugin-main.c src/transcription-filter.cpp src/transcription-filter.c - src/whisper-processing.cpp) +target_sources( + ${CMAKE_PROJECT_NAME} + PRIVATE src/plugin-main.c src/transcription-filter.cpp src/transcription-filter.c src/whisper-processing.cpp + src/model-utils/model-downloader.cpp src/model-utils/model-downloader-ui.cpp) set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name}) diff --git a/CMakePresets.json b/CMakePresets.json index fc86c17..bc80925 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -23,8 +23,8 @@ "CMAKE_OSX_DEPLOYMENT_TARGET": "11.0", "CODESIGN_IDENTITY": "$penv{CODESIGN_IDENT}", "CODESIGN_TEAM": "$penv{CODESIGN_TEAM}", - "ENABLE_FRONTEND_API": false, - "ENABLE_QT": false + "ENABLE_FRONTEND_API": true, + "ENABLE_QT": true } }, { @@ -53,8 +53,8 @@ "cacheVariables": { "QT_VERSION": "6", "CMAKE_SYSTEM_VERSION": "10.0.18363.657", - "ENABLE_FRONTEND_API": false, - "ENABLE_QT": false + "ENABLE_FRONTEND_API": true, + "ENABLE_QT": true } }, { @@ -81,8 +81,8 @@ "cacheVariables": { "QT_VERSION": "6", "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "ENABLE_FRONTEND_API": false, - "ENABLE_QT": false + "ENABLE_FRONTEND_API": true, + "ENABLE_QT": true } }, { @@ -110,8 +110,8 @@ "cacheVariables": { "QT_VERSION": "6", "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "ENABLE_FRONTEND_API": false, - "ENABLE_QT": false + "ENABLE_FRONTEND_API": true, + "ENABLE_QT": true } }, { diff --git a/cmake/common/buildspec_common.cmake b/cmake/common/buildspec_common.cmake index b2c2414..2e9c575 100644 --- a/cmake/common/buildspec_common.cmake +++ b/cmake/common/buildspec_common.cmake @@ -73,6 +73,14 @@ function(_setup_obs_studio) set(_cmake_version "3.0.0")
endif() + message(STATUS "Patch libobs") + execute_process( + COMMAND patch --forward "libobs/CMakeLists.txt" "${CMAKE_CURRENT_SOURCE_DIR}/patch_libobs.diff" + RESULT_VARIABLE _process_result COMMAND_ERROR_IS_FATAL ANY + WORKING_DIRECTORY "${dependencies_dir}/${_obs_destination}" + ) + message(STATUS "Patch - done") + message(STATUS "Configure ${label} (${arch})") execute_process( COMMAND diff --git a/src/model-utils/model-downloader-ui.cpp b/src/model-utils/model-downloader-ui.cpp new file mode 100644 index 0000000..9386273 --- /dev/null +++ b/src/model-utils/model-downloader-ui.cpp @@ -0,0 +1,187 @@ +#include "model-downloader-ui.h" +#include "plugin-support.h" + +#include + +const std::string MODEL_BASE_PATH = "https://huggingface.co/ggerganov/whisper.cpp"; +const std::string MODEL_PREFIX = "resolve/main/"; + +size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream) +{ + size_t written = fwrite(ptr, size, nmemb, stream); + return written; +} + +ModelDownloader::ModelDownloader( + const std::string &model_name, + std::function download_finished_callback_, QWidget *parent) + : QDialog(parent), download_finished_callback(download_finished_callback_) +{ + this->setWindowTitle("Downloading model..."); + this->setWindowFlags(Qt::Dialog | Qt::WindowTitleHint | Qt::CustomizeWindowHint); + this->setFixedSize(300, 100); + + this->layout = new QVBoxLayout(this); + + // Add a label for the model name + QLabel *model_name_label = new QLabel(this); + model_name_label->setText(QString::fromStdString(model_name)); + model_name_label->setAlignment(Qt::AlignCenter); + this->layout->addWidget(model_name_label); + + this->progress_bar = new QProgressBar(this); + this->progress_bar->setRange(0, 100); + this->progress_bar->setValue(0); + this->progress_bar->setAlignment(Qt::AlignCenter); + // Show progress as a percentage + this->progress_bar->setFormat("%p%"); + this->layout->addWidget(this->progress_bar); + + this->download_thread = new QThread(); +
this->download_worker = new ModelDownloadWorker(model_name); + this->download_worker->moveToThread(this->download_thread); + + connect(this->download_thread, &QThread::started, this->download_worker, + &ModelDownloadWorker::download_model); + connect(this->download_worker, &ModelDownloadWorker::download_progress, this, + &ModelDownloader::update_progress); + connect(this->download_worker, &ModelDownloadWorker::download_finished, this, + &ModelDownloader::download_finished); + connect(this->download_worker, &ModelDownloadWorker::download_finished, + this->download_thread, &QThread::quit); + connect(this->download_worker, &ModelDownloadWorker::download_error, this, + &ModelDownloader::show_error); + // A failed download must stop the thread's event loop too + connect(this->download_worker, &ModelDownloadWorker::download_error, + this->download_thread, &QThread::quit); + // No deleteLater connections: ~ModelDownloader() deletes the worker and the thread + + this->download_thread->start(); +} + +void ModelDownloader::update_progress(int progress) +{ + this->progress_bar->setValue(progress); +} + +void ModelDownloader::download_finished() +{ + this->setWindowTitle("Download finished!"); + this->progress_bar->setValue(100); + this->progress_bar->setFormat("Download finished!"); + this->progress_bar->setAlignment(Qt::AlignCenter); + this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #05B8CC; }"); + // Add a button to close the dialog + QPushButton *close_button = new QPushButton("Close", this); + this->layout->addWidget(close_button); + connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close); + // Call the callback + this->download_finished_callback(0); +} + +void ModelDownloader::show_error(const std::string &reason) +{ + this->setWindowTitle("Download failed!"); + this->progress_bar->setFormat("Download failed!"); + this->progress_bar->setAlignment(Qt::AlignCenter); + this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color:
#FF0000; }"); + // Add a label to show the error + QLabel *error_label = new QLabel(this); + error_label->setText(QString::fromStdString(reason)); + error_label->setAlignment(Qt::AlignCenter); + // Color red + error_label->setStyleSheet("QLabel { color : red; }"); + this->layout->addWidget(error_label); + // Add a button to close the dialog + QPushButton *close_button = new QPushButton("Close", this); + this->layout->addWidget(close_button); + connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close); + this->download_finished_callback(1); +} + +ModelDownloadWorker::ModelDownloadWorker(const std::string &model_name_) +{ + this->model_name = model_name_; +} + +void ModelDownloadWorker::download_model() +{ + std::string module_data_dir = obs_get_module_data_path(obs_current_module()); + // join the directory and the model name with a '/' separator + std::string model_save_path = module_data_dir + "/" + this->model_name; + obs_log(LOG_INFO, "Model save path: %s", model_save_path.c_str()); + + // extract filename from path in this->model_name + const std::string model_filename = + this->model_name.substr(this->model_name.find_last_of("/\\") + 1); + + std::string model_url = MODEL_BASE_PATH + "/" + MODEL_PREFIX + model_filename; + obs_log(LOG_INFO, "Model URL: %s", model_url.c_str()); + + CURL *curl = curl_easy_init(); + if (curl) { + FILE *fp = fopen(model_save_path.c_str(), "wb"); + if (fp == nullptr) { + obs_log(LOG_ERROR, "Failed to open file %s.", model_save_path.c_str()); + curl_easy_cleanup(curl); + emit download_error("Failed to open file."); + return; + } + curl_easy_setopt(curl, CURLOPT_URL, model_url.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp); + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); + curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, + ModelDownloadWorker::progress_callback); + curl_easy_setopt(curl, CURLOPT_XFERINFODATA, this); + // Follow redirects +
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + // Fail on HTTP errors (>= 400) instead of saving the error page as the model + curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L); + CURLcode res = curl_easy_perform(curl); + curl_easy_cleanup(curl); + fclose(fp); + if (res != CURLE_OK) { + obs_log(LOG_ERROR, "Failed to download model %s.", + this->model_name.c_str()); + emit download_error("Failed to download model."); + return; + } + } else { + obs_log(LOG_ERROR, "Failed to initialize curl."); + emit download_error("Failed to initialize curl."); + return; + } + // Reached only on success, so exactly one of download_error / download_finished fires + emit download_finished(); +} + +int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow, + curl_off_t, curl_off_t) +{ + if (dltotal == 0) { + return 0; // Unknown progress + } + ModelDownloadWorker *worker = (ModelDownloadWorker *)clientp; + if (worker == nullptr) { + obs_log(LOG_ERROR, "Worker is null."); + return 1; + } + int progress = (int)(dlnow * 100l / dltotal); + emit worker->download_progress(progress); + return 0; +} + +ModelDownloader::~ModelDownloader() +{ + this->download_thread->quit(); + this->download_thread->wait(); + // Delete the worker before the thread object it was moved to + delete this->download_worker; + delete this->download_thread; +} + +ModelDownloadWorker::~ModelDownloadWorker() +{ + // Do nothing +} diff --git a/src/model-utils/model-downloader-ui.h b/src/model-utils/model-downloader-ui.h new file mode 100644 index 0000000..b567a9a --- /dev/null +++ b/src/model-utils/model-downloader-ui.h @@ -0,0 +1,54 @@ +#ifndef MODEL_DOWNLOADER_UI_H +#define MODEL_DOWNLOADER_UI_H + +#include +#include + +#include +#include + +#include + +class ModelDownloadWorker : public QObject { + Q_OBJECT +public: + ModelDownloadWorker(const std::string &model_name); + ~ModelDownloadWorker(); + +public slots: + void download_model(); + +signals: + void download_progress(int progress); + void download_finished(); + void download_error(const std::string &reason); + +private: + static int progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow, + curl_off_t ultotal, curl_off_t ulnow); + std::string model_name; +}; + +class ModelDownloader : public QDialog { +
Q_OBJECT +public: + ModelDownloader(const std::string &model_name, + std::function download_finished_callback, + QWidget *parent = nullptr); + ~ModelDownloader(); + +public slots: + void update_progress(int progress); + void download_finished(); + void show_error(const std::string &reason); + +private: + QVBoxLayout *layout; + QProgressBar *progress_bar; + QThread *download_thread; + ModelDownloadWorker *download_worker; + // Callback for when the download is finished + std::function download_finished_callback; +}; + +#endif // MODEL_DOWNLOADER_UI_H diff --git a/src/model-utils/model-downloader.cpp b/src/model-utils/model-downloader.cpp new file mode 100644 index 0000000..af1f7a8 --- /dev/null +++ b/src/model-utils/model-downloader.cpp @@ -0,0 +1,43 @@ +#include "model-downloader.h" +#include "plugin-support.h" +#include "model-downloader-ui.h" + +#include +#include + +#include +#include +#include +#include + +#include + +bool check_if_model_exists(const std::string &model_name) +{ + obs_log(LOG_INFO, "Checking if model %s exists...", model_name.c_str()); + char *model_file_path = obs_module_file(model_name.c_str()); + if (model_file_path == nullptr) { + obs_log(LOG_INFO, "Model %s does not exist.", model_name.c_str()); + return false; + } + // Log the path only after the null check: passing NULL to %s is undefined + obs_log(LOG_INFO, "Model file path: %s", model_file_path); + + if (!std::filesystem::exists(model_file_path)) { + obs_log(LOG_INFO, "Model %s does not exist.", model_file_path); + bfree(model_file_path); + return false; + } + bfree(model_file_path); + return true; +} + +void download_model_with_ui_dialog( + const std::string &model_name, + std::function download_finished_callback) +{ + // Start the model downloader UI + ModelDownloader *model_downloader = new ModelDownloader( + model_name, download_finished_callback, (QWidget *)obs_frontend_get_main_window()); + model_downloader->show(); +} diff --git a/src/model-utils/model-downloader.h b/src/model-utils/model-downloader.h new file mode 100644 index 0000000..7d67ade ---
/dev/null +++ b/src/model-utils/model-downloader.h @@ -0,0 +1,14 @@ +#ifndef MODEL_DOWNLOADER_H +#define MODEL_DOWNLOADER_H + +#include +#include + +bool check_if_model_exists(const std::string &model_name); + +// Start the model downloader UI dialog with a callback for when the download is finished +void download_model_with_ui_dialog( + const std::string &model_name, + std::function download_finished_callback); + +#endif // MODEL_DOWNLOADER_H diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index d55dcce..4457368 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -33,7 +33,6 @@ struct transcription_filter_data { /* PCM buffers */ float *copy_buffers[MAX_PREPROC_CHANNELS]; - DARRAY(float) copy_output_buffers[MAX_PREPROC_CHANNELS]; struct circlebuf info_buffer; struct circlebuf input_buffers[MAX_PREPROC_CHANNELS]; diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index c7eac16..01a58d6 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -5,6 +5,7 @@ #include "transcription-filter-data.h" #include "whisper-processing.h" #include "whisper-language.h" +#include "model-utils/model-downloader.h" inline enum speaker_layout convert_speaker_layout(uint8_t channels) { @@ -220,24 +221,24 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->whisper_model_path = bstrdup(new_model_path); // check if the model exists, if not, download it - // if (!check_if_model_exists(gf->whisper_model_path)) { - // obs_log(LOG_ERROR, "Whisper model does not exist"); - // download_model_with_ui_dialog( - // gf->whisper_model_path, [gf](int download_status) { - // if (download_status == 0) { - // obs_log(LOG_INFO, "Model download complete"); - // gf->whisper_context = init_whisper_context( - // gf->whisper_model_path); - // gf->whisper_thread = std::thread(whisper_loop, gf); - // } else { - // obs_log(LOG_ERROR, "Model download failed"); - // } - // }); - // } else { -
// Model exists, just load it - gf->whisper_context = init_whisper_context(gf->whisper_model_path); - gf->whisper_thread = std::thread(whisper_loop, gf); - // } + if (!check_if_model_exists(gf->whisper_model_path)) { + obs_log(LOG_ERROR, "Whisper model does not exist"); + download_model_with_ui_dialog( + gf->whisper_model_path, [gf](int download_status) { + if (download_status == 0) { + obs_log(LOG_INFO, "Model download complete"); + gf->whisper_context = init_whisper_context( + gf->whisper_model_path); + gf->whisper_thread = std::thread(whisper_loop, gf); + } else { + obs_log(LOG_ERROR, "Model download failed"); + } + }); + } else { + // Model exists, just load it + gf->whisper_context = init_whisper_context(gf->whisper_model_path); + gf->whisper_thread = std::thread(whisper_loop, gf); + } } std::lock_guard lock(*gf->whisper_ctx_mutex); @@ -250,6 +251,7 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt"); gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads"); gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx"); + gf->whisper_params.translate = obs_data_get_bool(s, "translate"); gf->whisper_params.no_context = obs_data_get_bool(s, "no_context"); gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment"); gf->whisper_params.print_special = obs_data_get_bool(s, "print_special"); @@ -390,6 +392,7 @@ void transcription_filter_defaults(obs_data_t *s) obs_data_set_default_string(s, "initial_prompt", ""); obs_data_set_default_int(s, "n_threads", 4); obs_data_set_default_int(s, "n_max_text_ctx", 16384); + obs_data_set_default_bool(s, "translate", false); obs_data_set_default_bool(s, "no_context", true); obs_data_set_default_bool(s, "single_segment", true); obs_data_set_default_bool(s, "print_special", false); @@ -471,6 +474,7 @@ obs_properties_t *transcription_filter_properties(void *data) // int offset_ms; // start offset in ms
// int duration_ms; // audio duration to process in ms // bool translate; + obs_properties_add_bool(whisper_params_group, "translate", "translate"); // bool no_context; // do not use past transcription (if any) as initial prompt for the decoder obs_properties_add_bool(whisper_params_group, "no_context", "no_context"); // bool single_segment; // force single segment output (useful for streaming) diff --git a/src/whisper-processing.cpp b/src/whisper-processing.cpp index c692f8c..35e8d3e 100644 --- a/src/whisper-processing.cpp +++ b/src/whisper-processing.cpp @@ -247,11 +247,6 @@ void process_audio_from_buffer(struct transcription_filter_data *gf) gf->log_level != LOG_DEBUG); } - // copy output buffer before potentially modifying it - for (size_t c = 0; c < gf->channels; c++) { - da_copy_array(gf->copy_output_buffers[c], gf->copy_buffers[c], gf->last_num_frames); - } - if (!skipped_inference) { // run inference const struct DetectionResultWithText inference_result =