Fix CUDA build, shuffle whisper files around (#58)

* fix CUDA build, shuffle whisper files around

* lint
This commit is contained in:
Roy Shilkrot 2023-11-20 09:18:06 -05:00 committed by GitHub
parent 33b9756624
commit 8c02e0c3fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 139 additions and 109 deletions

View File

@ -41,7 +41,7 @@ set(USE_SYSTEM_CURL
CACHE STRING "Use system cURL")
if(LOCALVOCAL_WITH_CUDA)
add_compile_definitions(-DLOCALVOCAL_WITH_CUDA)
add_compile_definitions("LOCALVOCAL_WITH_CUDA")
endif()
if(USE_SYSTEM_CURL)
@ -58,7 +58,12 @@ target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE Whispercpp)
target_sources(
${CMAKE_PROJECT_NAME}
PRIVATE src/plugin-main.c src/transcription-filter.cpp src/transcription-filter.c src/whisper-processing.cpp
src/model-utils/model-downloader.cpp src/model-utils/model-downloader-ui.cpp)
PRIVATE src/plugin-main.c
src/transcription-filter.cpp
src/transcription-filter.c
src/whisper-utils/whisper-processing.cpp
src/model-utils/model-downloader.cpp
src/model-utils/model-downloader-ui.cpp
src/whisper-utils/whisper-utils.cpp)
set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name})

View File

@ -26,11 +26,12 @@ if(WIN32)
# Build with CUDA Check that CUDA_TOOLKIT_ROOT_DIR is set
if(NOT DEFINED CUDA_TOOLKIT_ROOT_DIR)
message(FATAL_ERROR "CUDA_TOOLKIT_ROOT_DIR is not set. Please set it to the root directory of your CUDA "
"installation.")
"installation, e.g. `C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.4`")
endif(NOT DEFINED CUDA_TOOLKIT_ROOT_DIR)
set(WHISPER_ADDITIONAL_ENV "CUDAToolkit_ROOT=${CUDA_TOOLKIT_ROOT_DIR}")
set(WHISPER_ADDITIONAL_CMAKE_ARGS -DWHISPER_CUBLAS=ON -DCMAKE_GENERATOR_TOOLSET=cuda=${CUDA_TOOLKIT_ROOT_DIR})
set(WHISPER_ADDITIONAL_CMAKE_ARGS -DWHISPER_CUBLAS=ON -DWHISPER_OPENBLAS=OFF
-DCMAKE_GENERATOR_TOOLSET=cuda=${CUDA_TOOLKIT_ROOT_DIR})
else()
# Build with OpenBLAS
set(OpenBLAS_URL "https://github.com/xianyi/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x64.zip")

View File

@ -4,9 +4,10 @@
#include "plugin-support.h"
#include "transcription-filter.h"
#include "transcription-filter-data.h"
#include "whisper-processing.h"
#include "whisper-language.h"
#include "model-utils/model-downloader.h"
#include "whisper-utils/whisper-processing.h"
#include "whisper-utils/whisper-language.h"
#include "whisper-utils/whisper-utils.h"
#include <algorithm>
#include <fstream>
@ -364,47 +365,6 @@ void set_text_callback(struct transcription_filter_data *gf,
}
};
// Stop the background whisper worker and release everything it owns.
// Ordering is deliberate: free the context under the mutex and notify the
// condition variable first (presumably so whisper_loop can observe the null
// context and exit — TODO confirm against whisper_loop), then join the
// thread, then free the model path string.
void shutdown_whisper_thread(struct transcription_filter_data *gf)
{
	obs_log(gf->log_level, "shutdown_whisper_thread");
	if (gf->whisper_context != nullptr) {
		// acquire the mutex before freeing the context
		if (!gf->whisper_ctx_mutex || !gf->wshiper_thread_cv) {
			// NOTE(review): message mentions only the mutex, but this branch
			// also fires when the condition variable is null.
			obs_log(LOG_ERROR, "whisper_ctx_mutex is null");
			return;
		}
		std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
		whisper_free(gf->whisper_context);
		gf->whisper_context = nullptr;
		// wake any waiter on the condition variable
		gf->wshiper_thread_cv->notify_all();
	}
	if (gf->whisper_thread.joinable()) {
		gf->whisper_thread.join();
	}
	if (gf->whisper_model_path != nullptr) {
		bfree(gf->whisper_model_path);
		gf->whisper_model_path = nullptr;
	}
}
// Initialize a whisper context from the model file at `path` and launch the
// transcription worker thread (whisper_loop). Logs an error and does nothing
// if the mutex is missing or a context is already loaded.
void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path)
{
	obs_log(gf->log_level, "start_whisper_thread_with_path: %s", path.c_str());
	if (!gf->whisper_ctx_mutex) {
		obs_log(LOG_ERROR, "cannot init whisper: whisper_ctx_mutex is null");
		return;
	}
	// hold the mutex while installing the new context and thread
	std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
	if (gf->whisper_context != nullptr) {
		obs_log(LOG_ERROR, "cannot init whisper: whisper_context is not null");
		return;
	}
	gf->whisper_context = init_whisper_context(path);
	gf->whisper_model_file_currently_loaded = path;
	// move the freshly started thread into gf->whisper_thread
	std::thread new_whisper_thread(whisper_loop, gf);
	gf->whisper_thread.swap(new_whisper_thread);
}
void transcription_filter_update(void *data, obs_data_t *s)
{
struct transcription_filter_data *gf =
@ -489,67 +449,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
}
obs_log(gf->log_level, "transcription_filter: update whisper model");
// update the whisper model path
std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos;
if (gf->whisper_model_path == nullptr ||
strcmp(new_model_path.c_str(), gf->whisper_model_path) != 0 || is_external_model) {
// model path changed, reload the model
obs_log(gf->log_level, "model path changed from %s to %s", gf->whisper_model_path,
new_model_path.c_str());
// check if the new model is external file
if (!is_external_model) {
// new model is not external file
shutdown_whisper_thread(gf);
gf->whisper_model_path = bstrdup(new_model_path.c_str());
// check if the model exists, if not, download it
std::string model_file_found = find_model_file(gf->whisper_model_path);
if (model_file_found == "") {
obs_log(LOG_WARNING, "Whisper model does not exist");
download_model_with_ui_dialog(
gf->whisper_model_path,
[gf](int download_status, const std::string &path) {
if (download_status == 0) {
obs_log(LOG_INFO,
"Model download complete");
start_whisper_thread_with_path(gf, path);
} else {
obs_log(LOG_ERROR, "Model download failed");
}
});
} else {
// Model exists, just load it
start_whisper_thread_with_path(gf, model_file_found);
}
} else {
// new model is external file, get file location from file property
std::string external_model_file_path =
obs_data_get_string(s, "whisper_model_path_external");
if (external_model_file_path.empty()) {
obs_log(LOG_WARNING, "External model file path is empty");
} else {
// check if the external model file is not currently loaded
if (gf->whisper_model_file_currently_loaded ==
external_model_file_path) {
obs_log(LOG_INFO, "External model file is already loaded");
return;
} else {
shutdown_whisper_thread(gf);
gf->whisper_model_path = bstrdup(new_model_path.c_str());
start_whisper_thread_with_path(gf,
external_model_file_path);
}
}
}
} else {
// model path did not change
obs_log(LOG_DEBUG, "model path did not change: %s == %s", gf->whisper_model_path,
new_model_path.c_str());
}
update_whsiper_model_path(gf, s);
if (!gf->whisper_ctx_mutex) {
obs_log(LOG_ERROR, "whisper_ctx_mutex is null");

View File

@ -0,0 +1,110 @@
#include "whisper-utils.h"
#include "plugin-support.h"
#include "model-utils/model-downloader.h"
#include "whisper-processing.h"
void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t *s)
{
// update the whisper model path
std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos;
if (gf->whisper_model_path == nullptr ||
strcmp(new_model_path.c_str(), gf->whisper_model_path) != 0 || is_external_model) {
// model path changed, reload the model
obs_log(gf->log_level, "model path changed from %s to %s", gf->whisper_model_path,
new_model_path.c_str());
// check if the new model is external file
if (!is_external_model) {
// new model is not external file
shutdown_whisper_thread(gf);
gf->whisper_model_path = bstrdup(new_model_path.c_str());
// check if the model exists, if not, download it
std::string model_file_found = find_model_file(gf->whisper_model_path);
if (model_file_found == "") {
obs_log(LOG_WARNING, "Whisper model does not exist");
download_model_with_ui_dialog(
gf->whisper_model_path,
[gf](int download_status, const std::string &path) {
if (download_status == 0) {
obs_log(LOG_INFO,
"Model download complete");
start_whisper_thread_with_path(gf, path);
} else {
obs_log(LOG_ERROR, "Model download failed");
}
});
} else {
// Model exists, just load it
start_whisper_thread_with_path(gf, model_file_found);
}
} else {
// new model is external file, get file location from file property
std::string external_model_file_path =
obs_data_get_string(s, "whisper_model_path_external");
if (external_model_file_path.empty()) {
obs_log(LOG_WARNING, "External model file path is empty");
} else {
// check if the external model file is not currently loaded
if (gf->whisper_model_file_currently_loaded ==
external_model_file_path) {
obs_log(LOG_INFO, "External model file is already loaded");
return;
} else {
shutdown_whisper_thread(gf);
gf->whisper_model_path = bstrdup(new_model_path.c_str());
start_whisper_thread_with_path(gf,
external_model_file_path);
}
}
}
} else {
// model path did not change
obs_log(LOG_DEBUG, "model path did not change: %s == %s", gf->whisper_model_path,
new_model_path.c_str());
}
}
// Stop the background whisper worker and release everything it owns.
// Ordering is deliberate: free the context under the mutex and notify the
// condition variable first (presumably so whisper_loop can observe the null
// context and exit — TODO confirm against whisper_loop), then join the
// thread, then free the model path string.
void shutdown_whisper_thread(struct transcription_filter_data *gf)
{
	obs_log(gf->log_level, "shutdown_whisper_thread");
	if (gf->whisper_context != nullptr) {
		// acquire the mutex before freeing the context
		if (!gf->whisper_ctx_mutex || !gf->wshiper_thread_cv) {
			// NOTE(review): message mentions only the mutex, but this branch
			// also fires when the condition variable is null.
			obs_log(LOG_ERROR, "whisper_ctx_mutex is null");
			return;
		}
		std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
		whisper_free(gf->whisper_context);
		gf->whisper_context = nullptr;
		// wake any waiter on the condition variable
		gf->wshiper_thread_cv->notify_all();
	}
	if (gf->whisper_thread.joinable()) {
		gf->whisper_thread.join();
	}
	if (gf->whisper_model_path != nullptr) {
		bfree(gf->whisper_model_path);
		gf->whisper_model_path = nullptr;
	}
}
// Initialize a whisper context from the model file at `path` and launch the
// transcription worker thread (whisper_loop). Logs an error and does nothing
// if the mutex is missing or a context is already loaded.
void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path)
{
	obs_log(gf->log_level, "start_whisper_thread_with_path: %s", path.c_str());
	if (!gf->whisper_ctx_mutex) {
		obs_log(LOG_ERROR, "cannot init whisper: whisper_ctx_mutex is null");
		return;
	}
	// hold the mutex while installing the new context and thread
	std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
	if (gf->whisper_context != nullptr) {
		obs_log(LOG_ERROR, "cannot init whisper: whisper_context is not null");
		return;
	}
	gf->whisper_context = init_whisper_context(path);
	gf->whisper_model_file_currently_loaded = path;
	// move the freshly started thread into gf->whisper_thread
	std::thread new_whisper_thread(whisper_loop, gf);
	gf->whisper_thread.swap(new_whisper_thread);
}

View File

@ -0,0 +1,14 @@
#ifndef WHISPER_UTILS_H
#define WHISPER_UTILS_H

#include "transcription-filter-data.h"

#include <obs.h>
#include <string>

// Reload the whisper model if the "whisper_model_path" setting in `s` differs
// from the currently loaded one, downloading it if it is missing on disk.
// (The "whsiper" typo is part of the established public name.)
void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t *s);
// Free the whisper context, then stop and join the transcription thread.
void shutdown_whisper_thread(struct transcription_filter_data *gf);
// Create a whisper context from the model file at `path` and start the
// transcription thread.
void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path);

#endif /* WHISPER_UTILS_H */