add model downloader

This commit is contained in:
Roy Shilkrot 2023-08-13 17:55:04 +03:00
parent 86e719150d
commit 357b429b80
10 changed files with 332 additions and 34 deletions

View File

@ -50,7 +50,9 @@ endif()
include(cmake/BuildWhispercpp.cmake)
target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE Whispercpp)
target_sources(${CMAKE_PROJECT_NAME} PRIVATE src/plugin-main.c src/transcription-filter.cpp src/transcription-filter.c
src/whisper-processing.cpp)
target_sources(
${CMAKE_PROJECT_NAME}
PRIVATE src/plugin-main.c src/transcription-filter.cpp src/transcription-filter.c src/whisper-processing.cpp
src/model-utils/model-downloader.cpp src/model-utils/model-downloader-ui.cpp)
set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name})

View File

@ -23,8 +23,8 @@
"CMAKE_OSX_DEPLOYMENT_TARGET": "11.0",
"CODESIGN_IDENTITY": "$penv{CODESIGN_IDENT}",
"CODESIGN_TEAM": "$penv{CODESIGN_TEAM}",
"ENABLE_FRONTEND_API": false,
"ENABLE_QT": false
"ENABLE_FRONTEND_API": true,
"ENABLE_QT": true
}
},
{
@ -53,8 +53,8 @@
"cacheVariables": {
"QT_VERSION": "6",
"CMAKE_SYSTEM_VERSION": "10.0.18363.657",
"ENABLE_FRONTEND_API": false,
"ENABLE_QT": false
"ENABLE_FRONTEND_API": true,
"ENABLE_QT": true
}
},
{
@ -81,8 +81,8 @@
"cacheVariables": {
"QT_VERSION": "6",
"CMAKE_BUILD_TYPE": "RelWithDebInfo",
"ENABLE_FRONTEND_API": false,
"ENABLE_QT": false
"ENABLE_FRONTEND_API": true,
"ENABLE_QT": true
}
},
{
@ -110,8 +110,8 @@
"cacheVariables": {
"QT_VERSION": "6",
"CMAKE_BUILD_TYPE": "RelWithDebInfo",
"ENABLE_FRONTEND_API": false,
"ENABLE_QT": false
"ENABLE_FRONTEND_API": true,
"ENABLE_QT": true
}
},
{

View File

@ -73,6 +73,14 @@ function(_setup_obs_studio)
set(_cmake_version "3.0.0")
endif()
message(STATUS "Patch libobs")
execute_process(
COMMAND patch --forward "libobs/CMakeLists.txt" "${CMAKE_CURRENT_SOURCE_DIR}/patch_libobs.diff"
RESULT_VARIABLE _process_result COMMAND_ERROR_IS_FATAL ANY
WORKING_DIRECTORY "${dependencies_dir}/${_obs_destination}"
)
message(STATUS "Patch - done")
message(STATUS "Configure ${label} (${arch})")
execute_process(
COMMAND

View File

@ -0,0 +1,180 @@
#include "model-downloader-ui.h"
#include "plugin-support.h"
#include <obs-module.h>
// Base URL of the repository the model files are fetched from.
const std::string MODEL_BASE_PATH = "https://huggingface.co/ggerganov/whisper.cpp";
// Path segment inserted between the base URL and the model file name when the
// download URL is built (MODEL_BASE_PATH + "/" + MODEL_PREFIX + <file name>).
const std::string MODEL_PREFIX = "resolve/main/";
// Write callback handed to libcurl (CURLOPT_WRITEFUNCTION): forwards the
// received chunk to the FILE* passed as CURLOPT_WRITEDATA and reports what
// fwrite() managed to store.
size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream)
{
    return fwrite(ptr, size, nmemb, stream);
}
/**
 * Build the download dialog and start downloading the model in the background.
 *
 * The dialog shows the model name and a percentage progress bar. A
 * ModelDownloadWorker is moved to a dedicated QThread and its download_model()
 * slot runs as soon as that thread starts.
 *
 * @param model_name                  Model name/path, shown in the dialog and
 *                                    handed to the worker.
 * @param download_finished_callback_ Invoked with 0 from download_finished()
 *                                    and with 1 from show_error().
 * @param parent                      Parent widget of the dialog.
 */
ModelDownloader::ModelDownloader(
    const std::string &model_name,
    std::function<void(int download_status)> download_finished_callback_, QWidget *parent)
    : QDialog(parent), download_finished_callback(download_finished_callback_)
{
    this->setWindowTitle("Downloading model...");
    this->setWindowFlags(Qt::Dialog | Qt::WindowTitleHint | Qt::CustomizeWindowHint);
    this->setFixedSize(300, 100);
    this->layout = new QVBoxLayout(this);

    // Add a label for the model name
    QLabel *model_name_label = new QLabel(this);
    model_name_label->setText(QString::fromStdString(model_name));
    model_name_label->setAlignment(Qt::AlignCenter);
    this->layout->addWidget(model_name_label);

    this->progress_bar = new QProgressBar(this);
    this->progress_bar->setRange(0, 100);
    this->progress_bar->setValue(0);
    this->progress_bar->setAlignment(Qt::AlignCenter);
    // Show progress as a percentage
    this->progress_bar->setFormat("%p%");
    this->layout->addWidget(this->progress_bar);

    // The thread and the worker are owned by this dialog: the destructor stops
    // the thread and deletes both. They must therefore NOT also be scheduled
    // for deleteLater() here, otherwise the destructor would call quit()/wait()
    // on, and delete, objects that Qt has already destroyed.
    this->download_thread = new QThread();
    this->download_worker = new ModelDownloadWorker(model_name);
    this->download_worker->moveToThread(this->download_thread);

    connect(this->download_thread, &QThread::started, this->download_worker,
            &ModelDownloadWorker::download_model);
    connect(this->download_worker, &ModelDownloadWorker::download_progress, this,
            &ModelDownloader::update_progress);
    connect(this->download_worker, &ModelDownloadWorker::download_finished, this,
            &ModelDownloader::download_finished);
    // Leave the worker thread's event loop once the download is done.
    connect(this->download_worker, &ModelDownloadWorker::download_finished,
            this->download_thread, &QThread::quit);
    // NOTE(review): download_error carries a std::string across threads; confirm
    // the Qt version in use can queue that argument type without an explicit
    // qRegisterMetaType<std::string>() call.
    connect(this->download_worker, &ModelDownloadWorker::download_error, this,
            &ModelDownloader::show_error);

    this->download_thread->start();
}
void ModelDownloader::update_progress(int progress)
{
this->progress_bar->setValue(progress);
}
void ModelDownloader::download_finished()
{
this->setWindowTitle("Download finished!");
this->progress_bar->setValue(100);
this->progress_bar->setFormat("Download finished!");
this->progress_bar->setAlignment(Qt::AlignCenter);
this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #05B8CC; }");
// Add a button to close the dialog
QPushButton *close_button = new QPushButton("Close", this);
this->layout->addWidget(close_button);
connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close);
// Call the callback
this->download_finished_callback(0);
}
void ModelDownloader::show_error(const std::string &reason)
{
this->setWindowTitle("Download failed!");
this->progress_bar->setFormat("Download failed!");
this->progress_bar->setAlignment(Qt::AlignCenter);
this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #FF0000; }");
// Add a label to show the error
QLabel *error_label = new QLabel(this);
error_label->setText(QString::fromStdString(reason));
error_label->setAlignment(Qt::AlignCenter);
// Color red
error_label->setStyleSheet("QLabel { color : red; }");
this->layout->addWidget(error_label);
// Add a button to close the dialog
QPushButton *close_button = new QPushButton("Close", this);
this->layout->addWidget(close_button);
connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close);
this->download_finished_callback(1);
}
// Remember which model download_model() should fetch.
ModelDownloadWorker::ModelDownloadWorker(const std::string &model_name_)
    : model_name(model_name_)
{
}
void ModelDownloadWorker::download_model()
{
std::string module_data_dir = obs_get_module_data_path(obs_current_module());
// join the directory and the filename using the platform-specific separator
std::string model_save_path = module_data_dir + "/" + this->model_name;
obs_log(LOG_INFO, "Model save path: %s", model_save_path.c_str());
// extract filename from path in this->modle_name
const std::string model_filename =
this->model_name.substr(this->model_name.find_last_of("/\\") + 1);
std::string model_url = MODEL_BASE_PATH + "/" + MODEL_PREFIX + model_filename;
obs_log(LOG_INFO, "Model URL: %s", model_url.c_str());
CURL *curl = curl_easy_init();
if (curl) {
FILE *fp = fopen(model_save_path.c_str(), "wb");
if (fp == nullptr) {
obs_log(LOG_ERROR, "Failed to open file %s.", model_save_path.c_str());
emit download_error("Failed to open file.");
return;
}
curl_easy_setopt(curl, CURLOPT_URL, model_url.c_str());
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp);
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION,
ModelDownloadWorker::progress_callback);
curl_easy_setopt(curl, CURLOPT_XFERINFODATA, this);
// Follow redirects
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
CURLcode res = curl_easy_perform(curl);
if (res != CURLE_OK) {
obs_log(LOG_ERROR, "Failed to download model %s.",
this->model_name.c_str());
emit download_error("Failed to download model.");
}
curl_easy_cleanup(curl);
fclose(fp);
} else {
obs_log(LOG_ERROR, "Failed to initialize curl.");
emit download_error("Failed to initialize curl.");
}
emit download_finished();
}
int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow,
curl_off_t, curl_off_t)
{
if (dltotal == 0) {
return 0; // Unknown progress
}
ModelDownloadWorker *worker = (ModelDownloadWorker *)clientp;
if (worker == nullptr) {
obs_log(LOG_ERROR, "Worker is null.");
return 1;
}
int progress = (int)(dlnow * 100l / dltotal);
emit worker->download_progress(progress);
return 0;
}
// Stop the download thread, wait until it has actually finished, then release
// the thread object and the worker. Both pointers are dereferenced/deleted
// unconditionally, so they are expected to still be valid here.
ModelDownloader::~ModelDownloader()
{
    QThread *thread = download_thread;
    thread->quit();
    thread->wait();
    delete thread;

    delete download_worker;
}
// The worker holds no resources of its own; nothing to release.
ModelDownloadWorker::~ModelDownloadWorker() = default;

View File

@ -0,0 +1,54 @@
#ifndef MODEL_DOWNLOADER_UI_H
#define MODEL_DOWNLOADER_UI_H
#include <QtWidgets>
#include <QThread>
#include <string>
#include <functional>
#include <curl/curl.h>
// Worker object that downloads one model file with libcurl. ModelDownloader
// moves it to a separate QThread and triggers download_model() from there.
class ModelDownloadWorker : public QObject {
Q_OBJECT
public:
// model_name: model name/path; used to build both the local save path and the
// download URL.
ModelDownloadWorker(const std::string &model_name);
~ModelDownloadWorker();
public slots:
// Perform the download; results are reported through the signals below.
void download_model();
signals:
// Download progress as a whole percentage (bytes received * 100 / total bytes).
void download_progress(int progress);
// Reports that download_model() has completed.
void download_finished();
// Reports a failure together with a short human-readable reason.
void download_error(const std::string &reason);
private:
// libcurl transfer-info callback; clientp is the worker instance. Only the
// download counters are used to compute the percentage.
static int progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow,
curl_off_t ultotal, curl_off_t ulnow);
// Model name/path as passed to the constructor.
std::string model_name;
};
// Dialog that shows the progress of a model download performed by a
// ModelDownloadWorker on a background thread, and reports the outcome through
// a caller-supplied callback.
class ModelDownloader : public QDialog {
Q_OBJECT
public:
// model_name: model to download (also displayed in the dialog).
// download_finished_callback: called with 0 from download_finished() and with
// 1 from show_error().
// parent: parent widget of the dialog.
ModelDownloader(const std::string &model_name,
std::function<void(int download_status)> download_finished_callback,
QWidget *parent = nullptr);
// Stops the download thread and deletes the thread and the worker.
~ModelDownloader();
public slots:
// Set the progress bar to the given percentage.
void update_progress(int progress);
// Show the "finished" state, add a Close button, call the callback with 0.
void download_finished();
// Show the "failed" state with the reason, add a Close button, call the
// callback with 1.
void show_error(const std::string &reason);
private:
// Vertical layout holding the label, progress bar and later-added widgets.
QVBoxLayout *layout;
// Progress bar with range 0-100.
QProgressBar *progress_bar;
// Thread the worker is moved to; deleted in the destructor.
QThread *download_thread;
// Worker performing the download; deleted in the destructor.
ModelDownloadWorker *download_worker;
// Callback for when the download is finished
std::function<void(int download_status)> download_finished_callback;
};
#endif // MODEL_DOWNLOADER_UI_H

View File

@ -0,0 +1,42 @@
#include "model-downloader.h"
#include "plugin-support.h"
#include "model-downloader-ui.h"
#include <obs-module.h>
#include <obs-frontend-api.h>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <curl/curl.h>
/**
 * Check whether the given model file is present among the module's files.
 *
 * The name is resolved with obs_module_file(); the function returns true only
 * if that call yields a path and the path exists on disk. Filesystem errors
 * are treated as "does not exist" rather than thrown.
 *
 * @param model_name Model name/path relative to the module data location.
 * @return true when the resolved file exists, false otherwise.
 */
bool check_if_model_exists(const std::string &model_name)
{
    obs_log(LOG_INFO, "Checking if model %s exists...", model_name.c_str());

    char *model_file_path = obs_module_file(model_name.c_str());
    if (model_file_path == nullptr) {
        obs_log(LOG_INFO, "Model %s does not exist.", model_name.c_str());
        return false;
    }
    // Only log the resolved path once it is known to be non-null: passing a
    // null pointer to a "%s" conversion is undefined behavior.
    obs_log(LOG_INFO, "Model file path: %s", model_file_path);

    // Non-throwing overload: an exception here would leak model_file_path and
    // propagate into the caller.
    std::error_code ec;
    const bool exists = std::filesystem::exists(model_file_path, ec) && !ec;
    if (!exists) {
        obs_log(LOG_INFO, "Model %s does not exist.", model_file_path);
    }

    bfree(model_file_path);
    return exists;
}
void download_model_with_ui_dialog(
const std::string &model_name,
std::function<void(int download_status)> download_finished_callback)
{
// Start the model downloader UI
ModelDownloader *model_downloader = new ModelDownloader(
model_name, download_finished_callback, (QWidget *)obs_frontend_get_main_window());
model_downloader->show();
}

View File

@ -0,0 +1,14 @@
#ifndef MODEL_DOWNLOADER_H
#define MODEL_DOWNLOADER_H
#include <string>
#include <functional>
// Return true when the file that obs_module_file(model_name) resolves to
// exists on disk, false otherwise.
bool check_if_model_exists(const std::string &model_name);
// Start the model downloader UI dialog with a callback for when the download is finished
// The callback receives download_status 0 when the download finished and 1
// when an error was reported.
void download_model_with_ui_dialog(
const std::string &model_name,
std::function<void(int download_status)> download_finished_callback);
#endif // MODEL_DOWNLOADER_H

View File

@ -33,7 +33,6 @@ struct transcription_filter_data {
/* PCM buffers */
float *copy_buffers[MAX_PREPROC_CHANNELS];
DARRAY(float) copy_output_buffers[MAX_PREPROC_CHANNELS];
struct circlebuf info_buffer;
struct circlebuf input_buffers[MAX_PREPROC_CHANNELS];

View File

@ -5,6 +5,7 @@
#include "transcription-filter-data.h"
#include "whisper-processing.h"
#include "whisper-language.h"
#include "model-utils/model-downloader.h"
inline enum speaker_layout convert_speaker_layout(uint8_t channels)
{
@ -220,24 +221,24 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->whisper_model_path = bstrdup(new_model_path);
// check if the model exists, if not, download it
// if (!check_if_model_exists(gf->whisper_model_path)) {
// obs_log(LOG_ERROR, "Whisper model does not exist");
// download_model_with_ui_dialog(
// gf->whisper_model_path, [gf](int download_status) {
// if (download_status == 0) {
// obs_log(LOG_INFO, "Model download complete");
// gf->whisper_context = init_whisper_context(
// gf->whisper_model_path);
// gf->whisper_thread = std::thread(whisper_loop, gf);
// } else {
// obs_log(LOG_ERROR, "Model download failed");
// }
// });
// } else {
// Model exists, just load it
gf->whisper_context = init_whisper_context(gf->whisper_model_path);
gf->whisper_thread = std::thread(whisper_loop, gf);
// }
if (!check_if_model_exists(gf->whisper_model_path)) {
obs_log(LOG_ERROR, "Whisper model does not exist");
download_model_with_ui_dialog(
gf->whisper_model_path, [gf](int download_status) {
if (download_status == 0) {
obs_log(LOG_INFO, "Model download complete");
gf->whisper_context = init_whisper_context(
gf->whisper_model_path);
gf->whisper_thread = std::thread(whisper_loop, gf);
} else {
obs_log(LOG_ERROR, "Model download failed");
}
});
} else {
// Model exists, just load it
gf->whisper_context = init_whisper_context(gf->whisper_model_path);
gf->whisper_thread = std::thread(whisper_loop, gf);
}
}
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
@ -250,6 +251,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
gf->whisper_params.translate = obs_data_get_bool(s, "translate");
gf->whisper_params.no_context = obs_data_get_bool(s, "no_context");
gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment");
gf->whisper_params.print_special = obs_data_get_bool(s, "print_special");
@ -390,6 +392,7 @@ void transcription_filter_defaults(obs_data_t *s)
obs_data_set_default_string(s, "initial_prompt", "");
obs_data_set_default_int(s, "n_threads", 4);
obs_data_set_default_int(s, "n_max_text_ctx", 16384);
obs_data_set_default_bool(s, "translate", false);
obs_data_set_default_bool(s, "no_context", true);
obs_data_set_default_bool(s, "single_segment", true);
obs_data_set_default_bool(s, "print_special", false);
@ -471,6 +474,7 @@ obs_properties_t *transcription_filter_properties(void *data)
// int offset_ms; // start offset in ms
// int duration_ms; // audio duration to process in ms
// bool translate;
obs_properties_add_bool(whisper_params_group, "translate", "translate");
// bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
obs_properties_add_bool(whisper_params_group, "no_context", "no_context");
// bool single_segment; // force single segment output (useful for streaming)

View File

@ -247,11 +247,6 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
gf->log_level != LOG_DEBUG);
}
// copy output buffer before potentially modifying it
for (size_t c = 0; c < gf->channels; c++) {
da_copy_array(gf->copy_output_buffers[c], gf->copy_buffers[c], gf->last_num_frames);
}
if (!skipped_inference) {
// run inference
const struct DetectionResultWithText inference_result =