mirror of
https://github.com/occ-ai/obs-localvocal
synced 2024-11-07 10:50:26 +00:00
add model downloader
This commit is contained in:
parent
86e719150d
commit
357b429b80
@ -50,7 +50,9 @@ endif()
|
||||
include(cmake/BuildWhispercpp.cmake)
|
||||
target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE Whispercpp)
|
||||
|
||||
target_sources(${CMAKE_PROJECT_NAME} PRIVATE src/plugin-main.c src/transcription-filter.cpp src/transcription-filter.c
|
||||
src/whisper-processing.cpp)
|
||||
target_sources(
|
||||
${CMAKE_PROJECT_NAME}
|
||||
PRIVATE src/plugin-main.c src/transcription-filter.cpp src/transcription-filter.c src/whisper-processing.cpp
|
||||
src/model-utils/model-downloader.cpp src/model-utils/model-downloader-ui.cpp)
|
||||
|
||||
set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name})
|
||||
|
@ -23,8 +23,8 @@
|
||||
"CMAKE_OSX_DEPLOYMENT_TARGET": "11.0",
|
||||
"CODESIGN_IDENTITY": "$penv{CODESIGN_IDENT}",
|
||||
"CODESIGN_TEAM": "$penv{CODESIGN_TEAM}",
|
||||
"ENABLE_FRONTEND_API": false,
|
||||
"ENABLE_QT": false
|
||||
"ENABLE_FRONTEND_API": true,
|
||||
"ENABLE_QT": true
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -53,8 +53,8 @@
|
||||
"cacheVariables": {
|
||||
"QT_VERSION": "6",
|
||||
"CMAKE_SYSTEM_VERSION": "10.0.18363.657",
|
||||
"ENABLE_FRONTEND_API": false,
|
||||
"ENABLE_QT": false
|
||||
"ENABLE_FRONTEND_API": true,
|
||||
"ENABLE_QT": true
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -81,8 +81,8 @@
|
||||
"cacheVariables": {
|
||||
"QT_VERSION": "6",
|
||||
"CMAKE_BUILD_TYPE": "RelWithDebInfo",
|
||||
"ENABLE_FRONTEND_API": false,
|
||||
"ENABLE_QT": false
|
||||
"ENABLE_FRONTEND_API": true,
|
||||
"ENABLE_QT": true
|
||||
}
|
||||
},
|
||||
{
|
||||
@ -110,8 +110,8 @@
|
||||
"cacheVariables": {
|
||||
"QT_VERSION": "6",
|
||||
"CMAKE_BUILD_TYPE": "RelWithDebInfo",
|
||||
"ENABLE_FRONTEND_API": false,
|
||||
"ENABLE_QT": false
|
||||
"ENABLE_FRONTEND_API": true,
|
||||
"ENABLE_QT": true
|
||||
}
|
||||
},
|
||||
{
|
||||
|
@ -73,6 +73,14 @@ function(_setup_obs_studio)
|
||||
set(_cmake_version "3.0.0")
|
||||
endif()
|
||||
|
||||
message(STATUS "Patch libobs")
|
||||
execute_process(
|
||||
COMMAND patch --forward "libobs/CMakeLists.txt" "${CMAKE_CURRENT_SOURCE_DIR}/patch_libobs.diff"
|
||||
RESULT_VARIABLE _process_result COMMAND_ERROR_IS_FATAL ANY
|
||||
WORKING_DIRECTORY "${dependencies_dir}/${_obs_destination}"
|
||||
)
|
||||
message(STATUS "Patch - done")
|
||||
|
||||
message(STATUS "Configure ${label} (${arch})")
|
||||
execute_process(
|
||||
COMMAND
|
||||
|
180
src/model-utils/model-downloader-ui.cpp
Normal file
180
src/model-utils/model-downloader-ui.cpp
Normal file
@ -0,0 +1,180 @@
|
||||
#include "model-downloader-ui.h"
|
||||
#include "plugin-support.h"
|
||||
|
||||
#include <obs-module.h>
|
||||
|
||||
const std::string MODEL_BASE_PATH = "https://huggingface.co/ggerganov/whisper.cpp";
|
||||
const std::string MODEL_PREFIX = "resolve/main/";
|
||||
|
||||
size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream)
|
||||
{
|
||||
size_t written = fwrite(ptr, size, nmemb, stream);
|
||||
return written;
|
||||
}
|
||||
|
||||
ModelDownloader::ModelDownloader(
|
||||
const std::string &model_name,
|
||||
std::function<void(int download_status)> download_finished_callback_, QWidget *parent)
|
||||
: QDialog(parent), download_finished_callback(download_finished_callback_)
|
||||
{
|
||||
this->setWindowTitle("Downloading model...");
|
||||
this->setWindowFlags(Qt::Dialog | Qt::WindowTitleHint | Qt::CustomizeWindowHint);
|
||||
this->setFixedSize(300, 100);
|
||||
|
||||
this->layout = new QVBoxLayout(this);
|
||||
|
||||
// Add a label for the model name
|
||||
QLabel *model_name_label = new QLabel(this);
|
||||
model_name_label->setText(QString::fromStdString(model_name));
|
||||
model_name_label->setAlignment(Qt::AlignCenter);
|
||||
this->layout->addWidget(model_name_label);
|
||||
|
||||
this->progress_bar = new QProgressBar(this);
|
||||
this->progress_bar->setRange(0, 100);
|
||||
this->progress_bar->setValue(0);
|
||||
this->progress_bar->setAlignment(Qt::AlignCenter);
|
||||
// Show progress as a percentage
|
||||
this->progress_bar->setFormat("%p%");
|
||||
this->layout->addWidget(this->progress_bar);
|
||||
|
||||
this->download_thread = new QThread();
|
||||
this->download_worker = new ModelDownloadWorker(model_name);
|
||||
this->download_worker->moveToThread(this->download_thread);
|
||||
|
||||
connect(this->download_thread, &QThread::started, this->download_worker,
|
||||
&ModelDownloadWorker::download_model);
|
||||
connect(this->download_worker, &ModelDownloadWorker::download_progress, this,
|
||||
&ModelDownloader::update_progress);
|
||||
connect(this->download_worker, &ModelDownloadWorker::download_finished, this,
|
||||
&ModelDownloader::download_finished);
|
||||
connect(this->download_worker, &ModelDownloadWorker::download_finished,
|
||||
this->download_thread, &QThread::quit);
|
||||
connect(this->download_worker, &ModelDownloadWorker::download_finished,
|
||||
this->download_worker, &ModelDownloadWorker::deleteLater);
|
||||
connect(this->download_worker, &ModelDownloadWorker::download_error, this,
|
||||
&ModelDownloader::show_error);
|
||||
connect(this->download_thread, &QThread::finished, this->download_thread,
|
||||
&QThread::deleteLater);
|
||||
|
||||
this->download_thread->start();
|
||||
}
|
||||
|
||||
void ModelDownloader::update_progress(int progress)
|
||||
{
|
||||
this->progress_bar->setValue(progress);
|
||||
}
|
||||
|
||||
void ModelDownloader::download_finished()
|
||||
{
|
||||
this->setWindowTitle("Download finished!");
|
||||
this->progress_bar->setValue(100);
|
||||
this->progress_bar->setFormat("Download finished!");
|
||||
this->progress_bar->setAlignment(Qt::AlignCenter);
|
||||
this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #05B8CC; }");
|
||||
// Add a button to close the dialog
|
||||
QPushButton *close_button = new QPushButton("Close", this);
|
||||
this->layout->addWidget(close_button);
|
||||
connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close);
|
||||
// Call the callback
|
||||
this->download_finished_callback(0);
|
||||
}
|
||||
|
||||
void ModelDownloader::show_error(const std::string &reason)
|
||||
{
|
||||
this->setWindowTitle("Download failed!");
|
||||
this->progress_bar->setFormat("Download failed!");
|
||||
this->progress_bar->setAlignment(Qt::AlignCenter);
|
||||
this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #FF0000; }");
|
||||
// Add a label to show the error
|
||||
QLabel *error_label = new QLabel(this);
|
||||
error_label->setText(QString::fromStdString(reason));
|
||||
error_label->setAlignment(Qt::AlignCenter);
|
||||
// Color red
|
||||
error_label->setStyleSheet("QLabel { color : red; }");
|
||||
this->layout->addWidget(error_label);
|
||||
// Add a button to close the dialog
|
||||
QPushButton *close_button = new QPushButton("Close", this);
|
||||
this->layout->addWidget(close_button);
|
||||
connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close);
|
||||
this->download_finished_callback(1);
|
||||
}
|
||||
|
||||
ModelDownloadWorker::ModelDownloadWorker(const std::string &model_name_)
|
||||
{
|
||||
this->model_name = model_name_;
|
||||
}
|
||||
|
||||
void ModelDownloadWorker::download_model()
|
||||
{
|
||||
std::string module_data_dir = obs_get_module_data_path(obs_current_module());
|
||||
// join the directory and the filename using the platform-specific separator
|
||||
std::string model_save_path = module_data_dir + "/" + this->model_name;
|
||||
obs_log(LOG_INFO, "Model save path: %s", model_save_path.c_str());
|
||||
|
||||
// extract filename from path in this->modle_name
|
||||
const std::string model_filename =
|
||||
this->model_name.substr(this->model_name.find_last_of("/\\") + 1);
|
||||
|
||||
std::string model_url = MODEL_BASE_PATH + "/" + MODEL_PREFIX + model_filename;
|
||||
obs_log(LOG_INFO, "Model URL: %s", model_url.c_str());
|
||||
|
||||
CURL *curl = curl_easy_init();
|
||||
if (curl) {
|
||||
FILE *fp = fopen(model_save_path.c_str(), "wb");
|
||||
if (fp == nullptr) {
|
||||
obs_log(LOG_ERROR, "Failed to open file %s.", model_save_path.c_str());
|
||||
emit download_error("Failed to open file.");
|
||||
return;
|
||||
}
|
||||
curl_easy_setopt(curl, CURLOPT_URL, model_url.c_str());
|
||||
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data);
|
||||
curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp);
|
||||
curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L);
|
||||
curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION,
|
||||
ModelDownloadWorker::progress_callback);
|
||||
curl_easy_setopt(curl, CURLOPT_XFERINFODATA, this);
|
||||
// Follow redirects
|
||||
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
|
||||
CURLcode res = curl_easy_perform(curl);
|
||||
if (res != CURLE_OK) {
|
||||
obs_log(LOG_ERROR, "Failed to download model %s.",
|
||||
this->model_name.c_str());
|
||||
emit download_error("Failed to download model.");
|
||||
}
|
||||
curl_easy_cleanup(curl);
|
||||
fclose(fp);
|
||||
} else {
|
||||
obs_log(LOG_ERROR, "Failed to initialize curl.");
|
||||
emit download_error("Failed to initialize curl.");
|
||||
}
|
||||
emit download_finished();
|
||||
}
|
||||
|
||||
int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow,
|
||||
curl_off_t, curl_off_t)
|
||||
{
|
||||
if (dltotal == 0) {
|
||||
return 0; // Unknown progress
|
||||
}
|
||||
ModelDownloadWorker *worker = (ModelDownloadWorker *)clientp;
|
||||
if (worker == nullptr) {
|
||||
obs_log(LOG_ERROR, "Worker is null.");
|
||||
return 1;
|
||||
}
|
||||
int progress = (int)(dlnow * 100l / dltotal);
|
||||
emit worker->download_progress(progress);
|
||||
return 0;
|
||||
}
|
||||
|
||||
ModelDownloader::~ModelDownloader()
|
||||
{
|
||||
this->download_thread->quit();
|
||||
this->download_thread->wait();
|
||||
delete this->download_thread;
|
||||
delete this->download_worker;
|
||||
}
|
||||
|
||||
ModelDownloadWorker::~ModelDownloadWorker()
|
||||
{
|
||||
// Do nothing
|
||||
}
|
54
src/model-utils/model-downloader-ui.h
Normal file
54
src/model-utils/model-downloader-ui.h
Normal file
@ -0,0 +1,54 @@
|
||||
#ifndef MODEL_DOWNLOADER_UI_H
|
||||
#define MODEL_DOWNLOADER_UI_H
|
||||
|
||||
#include <QtWidgets>
|
||||
#include <QThread>
|
||||
|
||||
#include <string>
|
||||
#include <functional>
|
||||
|
||||
#include <curl/curl.h>
|
||||
|
||||
class ModelDownloadWorker : public QObject {
|
||||
Q_OBJECT
|
||||
public:
|
||||
ModelDownloadWorker(const std::string &model_name);
|
||||
~ModelDownloadWorker();
|
||||
|
||||
public slots:
|
||||
void download_model();
|
||||
|
||||
signals:
|
||||
void download_progress(int progress);
|
||||
void download_finished();
|
||||
void download_error(const std::string &reason);
|
||||
|
||||
private:
|
||||
static int progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow,
|
||||
curl_off_t ultotal, curl_off_t ulnow);
|
||||
std::string model_name;
|
||||
};
|
||||
|
||||
class ModelDownloader : public QDialog {
|
||||
Q_OBJECT
|
||||
public:
|
||||
ModelDownloader(const std::string &model_name,
|
||||
std::function<void(int download_status)> download_finished_callback,
|
||||
QWidget *parent = nullptr);
|
||||
~ModelDownloader();
|
||||
|
||||
public slots:
|
||||
void update_progress(int progress);
|
||||
void download_finished();
|
||||
void show_error(const std::string &reason);
|
||||
|
||||
private:
|
||||
QVBoxLayout *layout;
|
||||
QProgressBar *progress_bar;
|
||||
QThread *download_thread;
|
||||
ModelDownloadWorker *download_worker;
|
||||
// Callback for when the download is finished
|
||||
std::function<void(int download_status)> download_finished_callback;
|
||||
};
|
||||
|
||||
#endif // MODEL_DOWNLOADER_UI_H
|
42
src/model-utils/model-downloader.cpp
Normal file
42
src/model-utils/model-downloader.cpp
Normal file
@ -0,0 +1,42 @@
|
||||
#include "model-downloader.h"
|
||||
#include "plugin-support.h"
|
||||
#include "model-downloader-ui.h"
|
||||
|
||||
#include <obs-module.h>
|
||||
#include <obs-frontend-api.h>
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include <curl/curl.h>
|
||||
|
||||
bool check_if_model_exists(const std::string &model_name)
|
||||
{
|
||||
obs_log(LOG_INFO, "Checking if model %s exists...", model_name.c_str());
|
||||
char *model_file_path = obs_module_file(model_name.c_str());
|
||||
obs_log(LOG_INFO, "Model file path: %s", model_file_path);
|
||||
if (model_file_path == nullptr) {
|
||||
obs_log(LOG_INFO, "Model %s does not exist.", model_name.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!std::filesystem::exists(model_file_path)) {
|
||||
obs_log(LOG_INFO, "Model %s does not exist.", model_file_path);
|
||||
bfree(model_file_path);
|
||||
return false;
|
||||
}
|
||||
bfree(model_file_path);
|
||||
return true;
|
||||
}
|
||||
|
||||
void download_model_with_ui_dialog(
|
||||
const std::string &model_name,
|
||||
std::function<void(int download_status)> download_finished_callback)
|
||||
{
|
||||
// Start the model downloader UI
|
||||
ModelDownloader *model_downloader = new ModelDownloader(
|
||||
model_name, download_finished_callback, (QWidget *)obs_frontend_get_main_window());
|
||||
model_downloader->show();
|
||||
}
|
14
src/model-utils/model-downloader.h
Normal file
14
src/model-utils/model-downloader.h
Normal file
@ -0,0 +1,14 @@
|
||||
#ifndef MODEL_DOWNLOADER_H
|
||||
#define MODEL_DOWNLOADER_H
|
||||
|
||||
#include <string>
|
||||
#include <functional>
|
||||
|
||||
bool check_if_model_exists(const std::string &model_name);
|
||||
|
||||
// Start the model downloader UI dialog with a callback for when the download is finished
|
||||
void download_model_with_ui_dialog(
|
||||
const std::string &model_name,
|
||||
std::function<void(int download_status)> download_finished_callback);
|
||||
|
||||
#endif // MODEL_DOWNLOADER_H
|
@ -33,7 +33,6 @@ struct transcription_filter_data {
|
||||
|
||||
/* PCM buffers */
|
||||
float *copy_buffers[MAX_PREPROC_CHANNELS];
|
||||
DARRAY(float) copy_output_buffers[MAX_PREPROC_CHANNELS];
|
||||
struct circlebuf info_buffer;
|
||||
struct circlebuf input_buffers[MAX_PREPROC_CHANNELS];
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include "transcription-filter-data.h"
|
||||
#include "whisper-processing.h"
|
||||
#include "whisper-language.h"
|
||||
#include "model-utils/model-downloader.h"
|
||||
|
||||
inline enum speaker_layout convert_speaker_layout(uint8_t channels)
|
||||
{
|
||||
@ -220,24 +221,24 @@ void transcription_filter_update(void *data, obs_data_t *s)
|
||||
gf->whisper_model_path = bstrdup(new_model_path);
|
||||
|
||||
// check if the model exists, if not, download it
|
||||
// if (!check_if_model_exists(gf->whisper_model_path)) {
|
||||
// obs_log(LOG_ERROR, "Whisper model does not exist");
|
||||
// download_model_with_ui_dialog(
|
||||
// gf->whisper_model_path, [gf](int download_status) {
|
||||
// if (download_status == 0) {
|
||||
// obs_log(LOG_INFO, "Model download complete");
|
||||
// gf->whisper_context = init_whisper_context(
|
||||
// gf->whisper_model_path);
|
||||
// gf->whisper_thread = std::thread(whisper_loop, gf);
|
||||
// } else {
|
||||
// obs_log(LOG_ERROR, "Model download failed");
|
||||
// }
|
||||
// });
|
||||
// } else {
|
||||
// Model exists, just load it
|
||||
gf->whisper_context = init_whisper_context(gf->whisper_model_path);
|
||||
gf->whisper_thread = std::thread(whisper_loop, gf);
|
||||
// }
|
||||
if (!check_if_model_exists(gf->whisper_model_path)) {
|
||||
obs_log(LOG_ERROR, "Whisper model does not exist");
|
||||
download_model_with_ui_dialog(
|
||||
gf->whisper_model_path, [gf](int download_status) {
|
||||
if (download_status == 0) {
|
||||
obs_log(LOG_INFO, "Model download complete");
|
||||
gf->whisper_context = init_whisper_context(
|
||||
gf->whisper_model_path);
|
||||
gf->whisper_thread = std::thread(whisper_loop, gf);
|
||||
} else {
|
||||
obs_log(LOG_ERROR, "Model download failed");
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Model exists, just load it
|
||||
gf->whisper_context = init_whisper_context(gf->whisper_model_path);
|
||||
gf->whisper_thread = std::thread(whisper_loop, gf);
|
||||
}
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
@ -250,6 +251,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
|
||||
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
|
||||
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
|
||||
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
|
||||
gf->whisper_params.translate = obs_data_get_bool(s, "translate");
|
||||
gf->whisper_params.no_context = obs_data_get_bool(s, "no_context");
|
||||
gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment");
|
||||
gf->whisper_params.print_special = obs_data_get_bool(s, "print_special");
|
||||
@ -390,6 +392,7 @@ void transcription_filter_defaults(obs_data_t *s)
|
||||
obs_data_set_default_string(s, "initial_prompt", "");
|
||||
obs_data_set_default_int(s, "n_threads", 4);
|
||||
obs_data_set_default_int(s, "n_max_text_ctx", 16384);
|
||||
obs_data_set_default_bool(s, "translate", false);
|
||||
obs_data_set_default_bool(s, "no_context", true);
|
||||
obs_data_set_default_bool(s, "single_segment", true);
|
||||
obs_data_set_default_bool(s, "print_special", false);
|
||||
@ -471,6 +474,7 @@ obs_properties_t *transcription_filter_properties(void *data)
|
||||
// int offset_ms; // start offset in ms
|
||||
// int duration_ms; // audio duration to process in ms
|
||||
// bool translate;
|
||||
obs_properties_add_bool(whisper_params_group, "translate", "translate");
|
||||
// bool no_context; // do not use past transcription (if any) as initial prompt for the decoder
|
||||
obs_properties_add_bool(whisper_params_group, "no_context", "no_context");
|
||||
// bool single_segment; // force single segment output (useful for streaming)
|
||||
|
@ -247,11 +247,6 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
|
||||
gf->log_level != LOG_DEBUG);
|
||||
}
|
||||
|
||||
// copy output buffer before potentially modifying it
|
||||
for (size_t c = 0; c < gf->channels; c++) {
|
||||
da_copy_array(gf->copy_output_buffers[c], gf->copy_buffers[c], gf->last_num_frames);
|
||||
}
|
||||
|
||||
if (!skipped_inference) {
|
||||
// run inference
|
||||
const struct DetectionResultWithText inference_result =
|
||||
|
Loading…
Reference in New Issue
Block a user