From 8c02e0c3fc2534f9f480b5d8dbbd9cf947134742 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 20 Nov 2023 09:18:06 -0500 Subject: [PATCH] Fix CUDA build, shuffle whisper files around (#58) * fix CUDA build, shuffle whisper files around * lint --- CMakeLists.txt | 11 +- cmake/BuildWhispercpp.cmake | 5 +- src/transcription-filter.cpp | 108 +---------------- src/{ => whisper-utils}/whisper-language.h | 0 .../whisper-processing.cpp | 0 src/{ => whisper-utils}/whisper-processing.h | 0 src/whisper-utils/whisper-utils.cpp | 110 ++++++++++++++++++ src/whisper-utils/whisper-utils.h | 14 +++ 8 files changed, 139 insertions(+), 109 deletions(-) rename src/{ => whisper-utils}/whisper-language.h (100%) rename src/{ => whisper-utils}/whisper-processing.cpp (100%) rename src/{ => whisper-utils}/whisper-processing.h (100%) create mode 100644 src/whisper-utils/whisper-utils.cpp create mode 100644 src/whisper-utils/whisper-utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 6737289..21cdb9d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,7 +41,7 @@ set(USE_SYSTEM_CURL CACHE STRING "Use system cURL") if(LOCALVOCAL_WITH_CUDA) - add_compile_definitions(-DLOCALVOCAL_WITH_CUDA) + add_compile_definitions("LOCALVOCAL_WITH_CUDA") endif() if(USE_SYSTEM_CURL) @@ -58,7 +58,12 @@ target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE Whispercpp) target_sources( ${CMAKE_PROJECT_NAME} - PRIVATE src/plugin-main.c src/transcription-filter.cpp src/transcription-filter.c src/whisper-processing.cpp - src/model-utils/model-downloader.cpp src/model-utils/model-downloader-ui.cpp) + PRIVATE src/plugin-main.c + src/transcription-filter.cpp + src/transcription-filter.c + src/whisper-utils/whisper-processing.cpp + src/model-utils/model-downloader.cpp + src/model-utils/model-downloader-ui.cpp + src/whisper-utils/whisper-utils.cpp) set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name}) diff --git a/cmake/BuildWhispercpp.cmake b/cmake/BuildWhispercpp.cmake index 76dab06..4098de4 100644 --- a/cmake/BuildWhispercpp.cmake +++ b/cmake/BuildWhispercpp.cmake @@ -26,11 +26,12 @@ if(WIN32) # Build with CUDA Check that CUDA_TOOLKIT_ROOT_DIR is set if(NOT DEFINED CUDA_TOOLKIT_ROOT_DIR) message(FATAL_ERROR "CUDA_TOOLKIT_ROOT_DIR is not set. Please set it to the root directory of your CUDA " - "installation.") + "installation, e.g. `C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.4`") endif(NOT DEFINED CUDA_TOOLKIT_ROOT_DIR) set(WHISPER_ADDITIONAL_ENV "CUDAToolkit_ROOT=${CUDA_TOOLKIT_ROOT_DIR}") - set(WHISPER_ADDITIONAL_CMAKE_ARGS -DWHISPER_CUBLAS=ON -DCMAKE_GENERATOR_TOOLSET=cuda=${CUDA_TOOLKIT_ROOT_DIR}) + set(WHISPER_ADDITIONAL_CMAKE_ARGS -DWHISPER_CUBLAS=ON -DWHISPER_OPENBLAS=OFF + -DCMAKE_GENERATOR_TOOLSET=cuda=${CUDA_TOOLKIT_ROOT_DIR}) else() # Build with OpenBLAS set(OpenBLAS_URL "https://github.com/xianyi/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x64.zip") diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 4eaefd8..860401b 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -4,9 +4,10 @@ #include "plugin-support.h" #include "transcription-filter.h" #include "transcription-filter-data.h" -#include "whisper-processing.h" -#include "whisper-language.h" #include "model-utils/model-downloader.h" +#include "whisper-utils/whisper-processing.h" +#include "whisper-utils/whisper-language.h" +#include "whisper-utils/whisper-utils.h" #include #include @@ -364,47 +365,6 @@ void set_text_callback(struct transcription_filter_data *gf, } }; -void shutdown_whisper_thread(struct transcription_filter_data *gf) -{ - obs_log(gf->log_level, "shutdown_whisper_thread"); - if (gf->whisper_context != nullptr) { - // acquire the mutex before freeing the context - if (!gf->whisper_ctx_mutex || !gf->wshiper_thread_cv) { - obs_log(LOG_ERROR, "whisper_ctx_mutex is null"); - return; - } - std::lock_guard lock(*gf->whisper_ctx_mutex); - whisper_free(gf->whisper_context); - gf->whisper_context = nullptr; - gf->wshiper_thread_cv->notify_all(); - } - if (gf->whisper_thread.joinable()) { - gf->whisper_thread.join(); - } - if (gf->whisper_model_path != nullptr) { - bfree(gf->whisper_model_path); - gf->whisper_model_path = nullptr; - } -} - -void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path) -{ - obs_log(gf->log_level, "start_whisper_thread_with_path: %s", path.c_str()); - if (!gf->whisper_ctx_mutex) { - obs_log(LOG_ERROR, "cannot init whisper: whisper_ctx_mutex is null"); - return; - } - std::lock_guard lock(*gf->whisper_ctx_mutex); - if (gf->whisper_context != nullptr) { - obs_log(LOG_ERROR, "cannot init whisper: whisper_context is not null"); - return; - } - gf->whisper_context = init_whisper_context(path); - gf->whisper_model_file_currently_loaded = path; - std::thread new_whisper_thread(whisper_loop, gf); - gf->whisper_thread.swap(new_whisper_thread); -} - void transcription_filter_update(void *data, obs_data_t *s) { struct transcription_filter_data *gf = @@ -489,67 +449,7 @@ void transcription_filter_update(void *data, obs_data_t *s) } obs_log(gf->log_level, "transcription_filter: update whisper model"); - // update the whisper model path - std::string new_model_path = obs_data_get_string(s, "whisper_model_path"); - const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos; - - if (gf->whisper_model_path == nullptr || - strcmp(new_model_path.c_str(), gf->whisper_model_path) != 0 || is_external_model) { - // model path changed, reload the model - obs_log(gf->log_level, "model path changed from %s to %s", gf->whisper_model_path, - new_model_path.c_str()); - - // check if the new model is external file - if (!is_external_model) { - // new model is not external file - shutdown_whisper_thread(gf); - - gf->whisper_model_path = bstrdup(new_model_path.c_str()); - - // check if the model exists, if not, download it - std::string model_file_found = find_model_file(gf->whisper_model_path); - if (model_file_found == "") { - obs_log(LOG_WARNING, "Whisper model does not exist"); - download_model_with_ui_dialog( - gf->whisper_model_path, - [gf](int download_status, const std::string &path) { - if (download_status == 0) { - obs_log(LOG_INFO, - "Model download complete"); - start_whisper_thread_with_path(gf, path); - } else { - obs_log(LOG_ERROR, "Model download failed"); - } - }); - } else { - // Model exists, just load it - start_whisper_thread_with_path(gf, model_file_found); - } - } else { - // new model is external file, get file location from file property - std::string external_model_file_path = - obs_data_get_string(s, "whisper_model_path_external"); - if (external_model_file_path.empty()) { - obs_log(LOG_WARNING, "External model file path is empty"); - } else { - // check if the external model file is not currently loaded - if (gf->whisper_model_file_currently_loaded == - external_model_file_path) { - obs_log(LOG_INFO, "External model file is already loaded"); - return; - } else { - shutdown_whisper_thread(gf); - gf->whisper_model_path = bstrdup(new_model_path.c_str()); - start_whisper_thread_with_path(gf, - external_model_file_path); - } - } - } - } else { - // model path did not change - obs_log(LOG_DEBUG, "model path did not change: %s == %s", gf->whisper_model_path, - new_model_path.c_str()); - } + update_whsiper_model_path(gf, s); if (!gf->whisper_ctx_mutex) { obs_log(LOG_ERROR, "whisper_ctx_mutex is null"); diff --git a/src/whisper-language.h b/src/whisper-utils/whisper-language.h similarity index 100% rename from src/whisper-language.h rename to src/whisper-utils/whisper-language.h diff --git a/src/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp similarity index 100% rename from src/whisper-processing.cpp rename to src/whisper-utils/whisper-processing.cpp diff --git a/src/whisper-processing.h b/src/whisper-utils/whisper-processing.h similarity index 100% rename from src/whisper-processing.h rename to src/whisper-utils/whisper-processing.h diff --git a/src/whisper-utils/whisper-utils.cpp b/src/whisper-utils/whisper-utils.cpp new file mode 100644 index 0000000..865d513 --- /dev/null +++ b/src/whisper-utils/whisper-utils.cpp @@ -0,0 +1,110 @@ +#include "whisper-utils.h" +#include "plugin-support.h" +#include "model-utils/model-downloader.h" +#include "whisper-processing.h" + +void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t *s) +{ + // update the whisper model path + std::string new_model_path = obs_data_get_string(s, "whisper_model_path"); + const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos; + + if (gf->whisper_model_path == nullptr || + strcmp(new_model_path.c_str(), gf->whisper_model_path) != 0 || is_external_model) { + // model path changed, reload the model + obs_log(gf->log_level, "model path changed from %s to %s", gf->whisper_model_path, + new_model_path.c_str()); + + // check if the new model is external file + if (!is_external_model) { + // new model is not external file + shutdown_whisper_thread(gf); + + gf->whisper_model_path = bstrdup(new_model_path.c_str()); + + // check if the model exists, if not, download it + std::string model_file_found = find_model_file(gf->whisper_model_path); + if (model_file_found == "") { + obs_log(LOG_WARNING, "Whisper model does not exist"); + download_model_with_ui_dialog( + gf->whisper_model_path, + [gf](int download_status, const std::string &path) { + if (download_status == 0) { + obs_log(LOG_INFO, + "Model download complete"); + start_whisper_thread_with_path(gf, path); + } else { + obs_log(LOG_ERROR, "Model download failed"); + } + }); + } else { + // Model exists, just load it + start_whisper_thread_with_path(gf, model_file_found); + } + } else { + // new model is external file, get file location from file property + std::string external_model_file_path = + obs_data_get_string(s, "whisper_model_path_external"); + if (external_model_file_path.empty()) { + obs_log(LOG_WARNING, "External model file path is empty"); + } else { + // check if the external model file is not currently loaded + if (gf->whisper_model_file_currently_loaded == + external_model_file_path) { + obs_log(LOG_INFO, "External model file is already loaded"); + return; + } else { + shutdown_whisper_thread(gf); + gf->whisper_model_path = bstrdup(new_model_path.c_str()); + start_whisper_thread_with_path(gf, + external_model_file_path); + } + } + } + } else { + // model path did not change + obs_log(LOG_DEBUG, "model path did not change: %s == %s", gf->whisper_model_path, + new_model_path.c_str()); + } +} + +void shutdown_whisper_thread(struct transcription_filter_data *gf) +{ + obs_log(gf->log_level, "shutdown_whisper_thread"); + if (gf->whisper_context != nullptr) { + // acquire the mutex before freeing the context + if (!gf->whisper_ctx_mutex || !gf->wshiper_thread_cv) { + obs_log(LOG_ERROR, "whisper_ctx_mutex is null"); + return; + } + std::lock_guard lock(*gf->whisper_ctx_mutex); + whisper_free(gf->whisper_context); + gf->whisper_context = nullptr; + gf->wshiper_thread_cv->notify_all(); + } + if (gf->whisper_thread.joinable()) { + gf->whisper_thread.join(); + } + if (gf->whisper_model_path != nullptr) { + bfree(gf->whisper_model_path); + gf->whisper_model_path = nullptr; + } +} + +void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path) +{ + obs_log(gf->log_level, "start_whisper_thread_with_path: %s", path.c_str()); + if (!gf->whisper_ctx_mutex) { + obs_log(LOG_ERROR, "cannot init whisper: whisper_ctx_mutex is null"); + return; + } + std::lock_guard lock(*gf->whisper_ctx_mutex); + if (gf->whisper_context != nullptr) { + obs_log(LOG_ERROR, "cannot init whisper: whisper_context is not null"); + return; + } + gf->whisper_context = init_whisper_context(path); + gf->whisper_model_file_currently_loaded = path; + std::thread new_whisper_thread(whisper_loop, gf); + gf->whisper_thread.swap(new_whisper_thread); +} diff --git a/src/whisper-utils/whisper-utils.h b/src/whisper-utils/whisper-utils.h new file mode 100644 index 0000000..6e80b2f --- /dev/null +++ b/src/whisper-utils/whisper-utils.h @@ -0,0 +1,14 @@ +#ifndef WHISPER_UTILS_H +#define WHISPER_UTILS_H + +#include "transcription-filter-data.h" + +#include + +#include + +void update_whsiper_model_path(struct transcription_filter_data *gf, obs_data_t *s); +void shutdown_whisper_thread(struct transcription_filter_data *gf); +void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path); + +#endif /* WHISPER_UTILS_H */