mirror of
https://github.com/occ-ai/obs-localvocal
synced 2024-11-07 18:57:14 +00:00
Offline transcription accuracy tests (#96)
* Update translation-utils.h, transcription-filter.h, whisper-model-utils.h, model-find-utils.h, and model-downloader.h * Update create_context function to include ct2ModelFolder parameter * fix: add fix_utf8 flag to transcription_filter_data struct * Update create_context function to include ct2ModelFolder parameter * Update read_text_from_file function to include join_sentences parameter * fix: Update VadIterator::reset_states to include reset_hc parameter * Update create_context function to include whisper_sampling_method parameter * Update tests README with additional configuration options * feat: Add function to find file in folder by regex expression * refactor: Improve text conditioning logic in transcription-filter.cpp * refactor: Improve text conditioning logic in transcription-filter.cpp * chore: Update ctranslate2 dependency to version 1.2.0 * refactor: Improve text conditioning logic in transcription-filter.cpp * chore: Update cmake BuildCTranslate2.cmake to disable -Wno-comma warning * refactor: Update translation context in whisper-processing.cpp and translation-utils.cpp
This commit is contained in:
parent
2e83300fbb
commit
31c41a9574
@ -6,6 +6,7 @@ project(${_name} VERSION ${_version})
|
||||
|
||||
option(ENABLE_FRONTEND_API "Use obs-frontend-api for UI functionality" ON)
|
||||
option(ENABLE_QT "Use Qt functionality" ON)
|
||||
option(ENABLE_TESTS "Enable tests" OFF)
|
||||
|
||||
include(compilerconfig)
|
||||
include(defaults)
|
||||
@ -90,11 +91,41 @@ target_sources(
|
||||
src/model-utils/model-downloader.cpp
|
||||
src/model-utils/model-downloader-ui.cpp
|
||||
src/model-utils/model-infos.cpp
|
||||
src/model-utils/model-find-utils.cpp
|
||||
src/whisper-utils/whisper-processing.cpp
|
||||
src/whisper-utils/whisper-utils.cpp
|
||||
src/whisper-utils/whisper-model-utils.cpp
|
||||
src/whisper-utils/silero-vad-onnx.cpp
|
||||
src/whisper-utils/token-buffer-thread.cpp
|
||||
src/translation/translation.cpp
|
||||
src/translation/translation-utils.cpp
|
||||
src/utils.cpp)
|
||||
|
||||
set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name})
|
||||
|
||||
if(ENABLE_TESTS)
|
||||
add_executable(${CMAKE_PROJECT_NAME}-tests)
|
||||
|
||||
include(cmake/FindLibAvObs.cmake)
|
||||
|
||||
target_sources(
|
||||
${CMAKE_PROJECT_NAME}-tests
|
||||
PRIVATE src/tests/localvocal-offline-test.cpp
|
||||
src/transcription-utils.cpp
|
||||
src/model-utils/model-infos.cpp
|
||||
src/model-utils/model-find-utils.cpp
|
||||
src/whisper-utils/whisper-processing.cpp
|
||||
src/whisper-utils/whisper-utils.cpp
|
||||
src/whisper-utils/silero-vad-onnx.cpp
|
||||
src/whisper-utils/token-buffer-thread.cpp
|
||||
src/translation/translation.cpp
|
||||
src/utils.cpp)
|
||||
|
||||
find_libav(${CMAKE_PROJECT_NAME}-tests)
|
||||
|
||||
target_link_libraries(${CMAKE_PROJECT_NAME}-tests PRIVATE ct2 sentencepiece Whispercpp Ort OBS::libobs)
|
||||
target_include_directories(${CMAKE_PROJECT_NAME}-tests PRIVATE src)
|
||||
|
||||
# install the tests to the release/test directory
|
||||
install(TARGETS ${CMAKE_PROJECT_NAME}-tests DESTINATION test)
|
||||
endif()
|
||||
|
@ -7,15 +7,15 @@ if(APPLE)
|
||||
|
||||
FetchContent_Declare(
|
||||
ctranslate2_fetch
|
||||
URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/libctranslate2-macos-Release-1.1.1.tar.gz
|
||||
URL_HASH SHA256=da04d88ecc1ea105f8ee672e4eab33af96e50c999c5cc8170e105e110392182b)
|
||||
URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-macos-Release-1.2.0.tar.gz
|
||||
URL_HASH SHA256=9029F19B0F50E5EDC14473479EDF0A983F7D6FA00BE61DC1B01BF8AA7F1CDB1B)
|
||||
FetchContent_MakeAvailable(ctranslate2_fetch)
|
||||
|
||||
add_library(ct2 INTERFACE)
|
||||
target_link_libraries(ct2 INTERFACE "-framework Accelerate" ${ctranslate2_fetch_SOURCE_DIR}/lib/libctranslate2.a
|
||||
${ctranslate2_fetch_SOURCE_DIR}/lib/libcpu_features.a)
|
||||
set_target_properties(ct2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${ctranslate2_fetch_SOURCE_DIR}/include)
|
||||
target_compile_options(ct2 INTERFACE -Wno-shorten-64-to-32)
|
||||
target_compile_options(ct2 INTERFACE -Wno-shorten-64-to-32 -Wno-comma)
|
||||
|
||||
elseif(WIN32)
|
||||
|
||||
|
48
cmake/FindLibAvObs.cmake
Normal file
48
cmake/FindLibAvObs.cmake
Normal file
@ -0,0 +1,48 @@
|
||||
# Find LibAV from the OBS dependencies
|
||||
|
||||
function(find_libav TARGET)
|
||||
if(UNIX AND NOT APPLE)
|
||||
find_package(PkgConfig REQUIRED)
|
||||
pkg_check_modules(
|
||||
FFMPEG
|
||||
REQUIRED
|
||||
IMPORTED_TARGET
|
||||
libavformat
|
||||
libavcodec
|
||||
libavutil
|
||||
libswresample)
|
||||
if(FFMPEG_FOUND)
|
||||
target_link_libraries(${TARGET} PRIVATE PkgConfig::FFMPEG)
|
||||
else()
|
||||
message(FATAL_ERROR "FFMPEG not found!")
|
||||
endif()
|
||||
return()
|
||||
endif()
|
||||
|
||||
if(NOT buildspec)
|
||||
file(READ "${CMAKE_CURRENT_SOURCE_DIR}/buildspec.json" buildspec)
|
||||
endif()
|
||||
string(
|
||||
JSON
|
||||
version
|
||||
GET
|
||||
${buildspec}
|
||||
dependencies
|
||||
prebuilt
|
||||
version)
|
||||
|
||||
if(MSVC)
|
||||
set(arch ${CMAKE_GENERATOR_PLATFORM})
|
||||
elseif(APPLE)
|
||||
set(arch universal)
|
||||
endif()
|
||||
set(deps_root "${CMAKE_CURRENT_SOURCE_DIR}/.deps/obs-deps-${version}-${arch}")
|
||||
|
||||
target_include_directories(${TARGET} PRIVATE "${deps_root}/include")
|
||||
target_link_libraries(
|
||||
${TARGET}
|
||||
PRIVATE "${deps_root}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}avcodec${CMAKE_STATIC_LIBRARY_SUFFIX}"
|
||||
"${deps_root}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}avformat${CMAKE_STATIC_LIBRARY_SUFFIX}"
|
||||
"${deps_root}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}avutil${CMAKE_STATIC_LIBRARY_SUFFIX}"
|
||||
"${deps_root}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}swresample${CMAKE_STATIC_LIBRARY_SUFFIX}")
|
||||
endfunction(find_libav)
|
@ -55,3 +55,5 @@ suppress_sentences="Suppress sentences (each line)"
|
||||
translate_output="Translation output"
|
||||
dtw_token_timestamps="DTW token timestamps"
|
||||
buffered_output="Buffered output (Experimental)"
|
||||
translate_model="Translation Model"
|
||||
Whisper-Based-Translation="Whisper-Based Translation"
|
||||
|
@ -1,44 +1,11 @@
|
||||
#include "model-downloader.h"
|
||||
#include "plugin-support.h"
|
||||
#include "model-downloader-ui.h"
|
||||
#include "model-find-utils.h"
|
||||
|
||||
#include <obs-module.h>
|
||||
#include <obs-frontend-api.h>
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include <curl/curl.h>
|
||||
|
||||
std::string find_file_in_folder_by_name(const std::string &folder_path,
|
||||
const std::string &file_name)
|
||||
{
|
||||
for (const auto &entry : std::filesystem::directory_iterator(folder_path)) {
|
||||
if (entry.path().filename() == file_name) {
|
||||
return entry.path().string();
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string find_bin_file_in_folder(const std::string &model_local_folder_path)
|
||||
{
|
||||
// find .bin file in folder
|
||||
for (const auto &entry : std::filesystem::directory_iterator(model_local_folder_path)) {
|
||||
if (entry.path().extension() == ".bin") {
|
||||
const std::string bin_file_path = entry.path().string();
|
||||
obs_log(LOG_INFO, "Model bin file found in folder: %s",
|
||||
bin_file_path.c_str());
|
||||
return bin_file_path;
|
||||
}
|
||||
}
|
||||
obs_log(LOG_ERROR, "Model bin file not found in folder: %s",
|
||||
model_local_folder_path.c_str());
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string find_model_folder(const ModelInfo &model_info)
|
||||
{
|
||||
char *data_folder_models = obs_module_file("models");
|
||||
|
@ -2,13 +2,9 @@
|
||||
#define MODEL_DOWNLOADER_H
|
||||
|
||||
#include <string>
|
||||
#include <functional>
|
||||
|
||||
#include "model-downloader-types.h"
|
||||
|
||||
std::string find_file_in_folder_by_name(const std::string &folder_path,
|
||||
const std::string &file_name);
|
||||
std::string find_bin_file_in_folder(const std::string &path);
|
||||
std::string find_model_folder(const ModelInfo &model_info);
|
||||
std::string find_model_bin_file(const ModelInfo &model_info);
|
||||
|
||||
|
50
src/model-utils/model-find-utils.cpp
Normal file
50
src/model-utils/model-find-utils.cpp
Normal file
@ -0,0 +1,50 @@
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <regex>
|
||||
|
||||
#include <obs-module.h>
|
||||
|
||||
#include "model-find-utils.h"
|
||||
#include "plugin-support.h"
|
||||
|
||||
std::string find_file_in_folder_by_name(const std::string &folder_path,
|
||||
const std::string &file_name)
|
||||
{
|
||||
for (const auto &entry : std::filesystem::directory_iterator(folder_path)) {
|
||||
if (entry.path().filename() == file_name) {
|
||||
return entry.path().string();
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
// Find a file in a folder by expression
|
||||
std::string find_file_in_folder_by_regex_expression(const std::string &folder_path,
|
||||
const std::string &file_name_regex)
|
||||
{
|
||||
for (const auto &entry : std::filesystem::directory_iterator(folder_path)) {
|
||||
if (std::regex_match(entry.path().filename().string(),
|
||||
std::regex(file_name_regex))) {
|
||||
return entry.path().string();
|
||||
}
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
std::string find_bin_file_in_folder(const std::string &model_local_folder_path)
|
||||
{
|
||||
// find .bin file in folder
|
||||
for (const auto &entry : std::filesystem::directory_iterator(model_local_folder_path)) {
|
||||
if (entry.path().extension() == ".bin") {
|
||||
const std::string bin_file_path = entry.path().string();
|
||||
obs_log(LOG_INFO, "Model bin file found in folder: %s",
|
||||
bin_file_path.c_str());
|
||||
return bin_file_path;
|
||||
}
|
||||
}
|
||||
obs_log(LOG_ERROR, "Model bin file not found in folder: %s",
|
||||
model_local_folder_path.c_str());
|
||||
return "";
|
||||
}
|
14
src/model-utils/model-find-utils.h
Normal file
14
src/model-utils/model-find-utils.h
Normal file
@ -0,0 +1,14 @@
|
||||
#ifndef MODEL_FIND_UTILS_H
|
||||
#define MODEL_FIND_UTILS_H
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "model-downloader-types.h"
|
||||
|
||||
std::string find_file_in_folder_by_name(const std::string &folder_path,
|
||||
const std::string &file_name);
|
||||
std::string find_bin_file_in_folder(const std::string &path);
|
||||
std::string find_file_in_folder_by_regex_expression(const std::string &folder_path,
|
||||
const std::string &file_name_regex);
|
||||
|
||||
#endif // MODEL_FIND_UTILS_H
|
@ -23,6 +23,28 @@ std::map<std::string, ModelInfo> models_info = {{
|
||||
"B6E77E474AEEA8F441363ACA7614317C06381F3EACFE10FB9856D5081D1074CC"},
|
||||
{"https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true",
|
||||
"D8F7C76ED2A5E0822BE39F0A4F95A55EB19C78F4593CE609E2EDBC2AEA4D380A"}}}},
|
||||
{"M2M-100 1.2B (1.25Gb)",
|
||||
{"M2M-100 1.2BM",
|
||||
"m2m-100-1_2B",
|
||||
MODEL_TYPE_TRANSLATION,
|
||||
{{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/model.bin?download=true",
|
||||
"C97DF052A558895317312470E1FF7CB8EAE5416F7AE16214A2983C6853DD3CE5"},
|
||||
{
|
||||
"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/config.json?download=true",
|
||||
"4244772990E30069563E3DDFB4AD6DC95BDFD2AC3DE667EA8858C9B0A8433FA8",
|
||||
},
|
||||
{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/generation_config.json?download=true",
|
||||
"AED76366507333DDBB8BD49960F23C82FE6446B3319A46A54BEFDB45324CCF61"},
|
||||
{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/shared_vocabulary.json?download=true",
|
||||
"7EB5D0FF184C6095C7C10F9911C0AEA492250ABD12854F9C3D787C64B1C6397E"},
|
||||
{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/special_tokens_map.json?download=true",
|
||||
"C1A4F86C3874D279AE1B2A05162858DB5DD6C61665D84223ED886CBCFF08FDA6"},
|
||||
{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/tokenizer_config.json?download=true",
|
||||
"1566A6CFA4F541A55594C9D5E090F530812D5DE7C94882EA3AF156962D9933AE"},
|
||||
{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/vocab.json?download=true",
|
||||
"B6E77E474AEEA8F441363ACA7614317C06381F3EACFE10FB9856D5081D1074CC"},
|
||||
{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true",
|
||||
"D8F7C76ED2A5E0822BE39F0A4F95A55EB19C78F4593CE609E2EDBC2AEA4D380A"}}}},
|
||||
{"Whisper Base q5 (57Mb)",
|
||||
{"Whisper Base q5",
|
||||
"whisper-base-q5",
|
||||
|
182
src/tests/README.md
Normal file
182
src/tests/README.md
Normal file
@ -0,0 +1,182 @@
|
||||
# Building and Using the Offline Testing tool
|
||||
|
||||
The offline testing tool provides a way to run the internal core transcription+translation algorithm of the OBS plugin without running OBS, in effect simulating how it would run within OBS. However, all the audio is pre-cached in memory so it runs faster than real-time (e.g. it doesn't simulate the audio input timing).
|
||||
The tool is useful for automating tests to measure performance.
|
||||
|
||||
## Building
|
||||
|
||||
The tool was tested on Windows with CUDA, so this guide focuses on this setup.
|
||||
However, there's nothing preventing the tool from successfully building and running on Mac as well.
|
||||
Linux unfortunately is not supported at the moment.
|
||||
|
||||
Start by cloning the repo.
|
||||
|
||||
Proceed to build the plugin regularly, e.g.
|
||||
```powershell
|
||||
obs-localvocal> $env:CPU_OR_CUDA="12.2.0"
|
||||
obs-localvocal> .\.github\scripts\Build-Windows.ps1 -Configuration Release
|
||||
```
|
||||
|
||||
Then run CMake to enable the test tool in the build
|
||||
```powershell
|
||||
obs-localvocal> cmake -S . -B .\build_x64\ -DENABLE_TESTS=ON
|
||||
```
|
||||
|
||||
Once that is done you can build the test target `.exe` only (instead of building everything) e.g.
|
||||
```powershell
|
||||
obs-localvocal> cmake --build .\build_x64\ --target obs-localvocal-tests --config Release
|
||||
obs-localvocal> copy-item -Force ".\build_x64\Release\obs-localvocal-tests.exe" -Destination ".\release\Release\test"
|
||||
```
|
||||
|
||||
Also in the above we're copying the result `.exe` to the `./release/Release/test` folder.
|
||||
|
||||
Next, a few `.dll` files need to be collected and placed alongside the `obs-localvocal-tests.exe` file in the `./release/Release/test` folder. Fortunately all `dll`s are available in the plugin's build folders.
|
||||
|
||||
For an automatic step to copy `.dll`s run the script: (run from any location, it will orient itself)
|
||||
```powershell
|
||||
obs-localvocal> &"src\tests\copy_dlls.ps1"
|
||||
```
|
||||
|
||||
For manual copying follow below:
|
||||
|
||||
From `.\release\Release\obs-plugins\64bit` copy:
|
||||
|
||||
- ctranslate2.dll
|
||||
- cublas64_12.dll
|
||||
- cublasLt64_12.dll
|
||||
- cudart64_12.dll
|
||||
- libopenblas.dll
|
||||
- obs-localvocal.dll
|
||||
- onnxruntime_providers_shared.dll
|
||||
- onnxruntime.dll
|
||||
- whisper.dll
|
||||
|
||||
From `.deps\obs-deps-2023-11-03-x64\bin` copy:
|
||||
|
||||
- avcodec-60.dll
|
||||
- avdevice-60.dll
|
||||
- avfilter-9.dll
|
||||
- avformat-60.dll
|
||||
- avutil-58.dll
|
||||
- libx264-164.dll
|
||||
- swresample-4.dll
|
||||
- swscale-7.dll
|
||||
- zlib.dll
|
||||
|
||||
Finally, from `.deps\obs-studio-30.0.2\build_x64\rundir\Debug\bin\64bit` copy:
|
||||
|
||||
- obs-frontend-api.dll
|
||||
- obs.dll
|
||||
- w32-pthreads.dll
|
||||
|
||||
With all the `.dll`s in place in the `.\release\Release\test` folder the test tool should run.
|
||||
|
||||
## Using the test tool
|
||||
|
||||
The tool expects the following arguments:
|
||||
|
||||
- audio/video file
|
||||
- configuration file in JSON
|
||||
|
||||
For example, this is a valid run command:
|
||||
|
||||
```powershell
|
||||
obs-localvocal> .\release\Release\test\obs-localvocal-tests.exe "C:\Users\roysh\Downloads\audio.mp3" ".\config.json"
|
||||
```
|
||||
### Configuration
|
||||
|
||||
The tool must receive configuration to test different parameters of the algorithm.
|
||||
|
||||
- whisper language
|
||||
- translation source language (or `none`)
|
||||
- translation target language (or `none`)
|
||||
- whisper model `.bin` file
|
||||
- silero VAD model file e.g. `silero_vad.onnx`
|
||||
- CT2 model *folder* (within which the model and json files can be found)
|
||||
- fix UTF8 characters
|
||||
- suppress sentences
|
||||
- overlap in milliseconds
|
||||
- log level (debug, info, warning, error)
|
||||
- whisper sampling strategy (0 = greedy, 1 = beam)
|
||||
|
||||
The Whisper languages are listed in [whisper-language.h](../whisper-utils/whisper-language.h) and the CT2 language codes are listed in [language_codes.h](../translation/language_codes.h). They roughly match except CT2 has underscores e.g. `ko` -> `__ko__`, `ja` -> `__ja__`.
|
||||
|
||||
|
||||
The JSON config file can look e.g. like
|
||||
```
|
||||
{
|
||||
"whisper_language": "ko",
|
||||
"source_language": "none",
|
||||
"target_language": "none",
|
||||
"whisper_model_path": ".../obs-localvocal/models/ggml-model-whisper-small/ggml-model-whisper-small.bin",
|
||||
"silero_vad_model_file": ".../obs-localvocal/data/models/silero-vad/silero_vad.onnx",
|
||||
"ct2_model_folder": ".../obs-localvocal/models/m2m-100-418M",
|
||||
"fix_utf8": true,
|
||||
"suppress_sentences": "끝까지 시청해주셔서 감사합니다/n구독과 좋아요 부탁드립니다!/nMBC 뉴스 안영백입니다./nMBC 뉴스 이덕영입니다/n구독과 좋아요 눌러주세요!/n구독과 좋아요 부탁드",
|
||||
"overlap_ms": 150,
|
||||
"log_level": "debug",
|
||||
"whisper_sampling_method": 0
|
||||
}
|
||||
```
|
||||
|
||||
If you've used the OBS plugin to download a Whisper model and a CT2 model then you would find those in the OBS plugin config folders as visible above. It is recommended to do so.
|
||||
|
||||
Give the path to this file to the tool.
|
||||
|
||||
### Output
|
||||
|
||||
The tool will write an `output.txt` file in the running directory.
|
||||
|
||||
It will also output a verbose running log to the console, e.g.
|
||||
```
|
||||
[02:07:25.148] [UNKNOWN] found 59539456 bytes, 14884864 frames in input buffer, need >= 576000
|
||||
[02:07:25.150] [UNKNOWN] processing audio from buffer, 0 existing frames, 144000 frames needed to full segment (144000 frames)
|
||||
[02:07:25.150] [UNKNOWN] with 144000 remaining to full segment, popped 143360 frames from info buffer, pushed at 0 (overlap)
|
||||
[02:07:25.151] [UNKNOWN] first segment, no overlap exists, 143360 frames to process
|
||||
[02:07:25.151] [UNKNOWN] processing 143360 frames (2986 ms), start timestamp 85
|
||||
[02:07:25.154] [UNKNOWN] 2 channels, 47770 frames, 2985.625000 ms
|
||||
[02:07:25.168] [UNKNOWN] VAD detected speech from 29696 to 47770 (18074 frames, 1129 ms)
|
||||
[02:07:25.169] [UNKNOWN] run_whisper_inference: processing 18074 samples, 1.130 sec, 4 threads
|
||||
[02:07:26.700] [UNKNOWN] Token 0: 50364, [_BEG_], p: 1.000, dtw: -1 [keep: 0]
|
||||
```
|
||||
|
||||
### Translation
|
||||
|
||||
To translate with Whisper, set the whisper output language to your desired output and the CT2 languages to `none`.
|
||||
For example this would be a Japanese translation with Whisper:
|
||||
|
||||
```json
|
||||
{
|
||||
"whisper_language": "ja",
|
||||
"source_language": "none",
|
||||
"target_language": "none",
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
To translate with CT2, make sure the Whisper output is in the spoken language and that it matches the source language.
|
||||
For example this would be a Korean-to-Japanese translation with CT2 M2M100:
|
||||
|
||||
```json
|
||||
{
|
||||
"whisper_language": "ko",
|
||||
"source_language": "__ko__",
|
||||
"target_language": "__ja__",
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
## Evaluation of the results
|
||||
|
||||
The provided [python script](evaluate_output.py) can run WER/CER evaluation on the results.
|
||||
|
||||
Example of running the evaluation script:
|
||||
|
||||
```powershell
|
||||
obs-localvocal> python .\src\tests\evaluate_output.py ".\ground_truth.txt" ".\output.txt"
|
||||
```
|
||||
|
||||
It requires installing a couple of packages:
|
||||
```powershell
|
||||
pip install Levenshtein diff_match_patch
|
||||
```
|
40
src/tests/copy_dlls.ps1
Normal file
40
src/tests/copy_dlls.ps1
Normal file
@ -0,0 +1,40 @@
|
||||
|
||||
# change into the root directory of the repository from the location of this script
|
||||
$scriptPath = Split-Path -Parent $MyInvocation.MyCommand.Definition
|
||||
Set-Location "$scriptPath\..\.."
|
||||
|
||||
$testToolPath = ".\release\Release\test"
|
||||
|
||||
# make sure the test tool directory exists
|
||||
if (-not (Test-Path $testToolPath)) {
|
||||
New-Item -ItemType Directory -Path $testToolPath | Out-Null
|
||||
}
|
||||
|
||||
# copy the required DLLs to the test tool directory
|
||||
$obsDlls = @(
|
||||
".\release\Release\obs-plugins\64bit\ctranslate2.dll",
|
||||
".\release\Release\obs-plugins\64bit\cublas64_12.dll",
|
||||
".\release\Release\obs-plugins\64bit\cublasLt64_12.dll",
|
||||
".\release\Release\obs-plugins\64bit\cudart64_12.dll",
|
||||
".\release\Release\obs-plugins\64bit\libopenblas.dll",
|
||||
".\release\Release\obs-plugins\64bit\onnxruntime_providers_shared.dll",
|
||||
".\release\Release\obs-plugins\64bit\onnxruntime.dll",
|
||||
".\release\Release\obs-plugins\64bit\whisper.dll",
|
||||
".deps\obs-deps-2023-11-03-x64\bin\avcodec-60.dll",
|
||||
".deps\obs-deps-2023-11-03-x64\bin\avdevice-60.dll",
|
||||
".deps\obs-deps-2023-11-03-x64\bin\avfilter-9.dll",
|
||||
".deps\obs-deps-2023-11-03-x64\bin\avformat-60.dll",
|
||||
".deps\obs-deps-2023-11-03-x64\bin\avutil-58.dll",
|
||||
".deps\obs-deps-2023-11-03-x64\bin\libx264-164.dll",
|
||||
".deps\obs-deps-2023-11-03-x64\bin\swresample-4.dll",
|
||||
".deps\obs-deps-2023-11-03-x64\bin\swscale-7.dll",
|
||||
".deps\obs-deps-2023-11-03-x64\bin\zlib.dll"
|
||||
".deps\obs-studio-30.0.2\build_x64\rundir\Debug\bin\64bit\obs-frontend-api.dll",
|
||||
".deps\obs-studio-30.0.2\build_x64\rundir\Debug\bin\64bit\obs.dll",
|
||||
".deps\obs-studio-30.0.2\build_x64\rundir\Debug\bin\64bit\w32-pthreads.dll"
|
||||
)
|
||||
|
||||
$obsDlls | ForEach-Object {
|
||||
Copy-Item -Force -Path $_ -Destination $testToolPath
|
||||
}
|
||||
|
68
src/tests/evaluate_output.py
Normal file
68
src/tests/evaluate_output.py
Normal file
@ -0,0 +1,68 @@
|
||||
import Levenshtein
|
||||
import argparse
|
||||
from diff_match_patch import diff_match_patch
|
||||
|
||||
def visualize_differences(ref_text, hyp_text):
|
||||
dmp = diff_match_patch()
|
||||
diffs = dmp.diff_main(hyp_text, ref_text, checklines=True)
|
||||
html = dmp.diff_prettyHtml(diffs)
|
||||
return html
|
||||
|
||||
def calculate_wer(ref_text, hyp_text):
|
||||
ref_words = ref_text.split()
|
||||
hyp_words = hyp_text.split()
|
||||
|
||||
distance = Levenshtein.distance(ref_words, hyp_words)
|
||||
wer = distance / len(ref_words)
|
||||
return wer
|
||||
|
||||
def calculate_cer(ref_text, hyp_text):
|
||||
distance = Levenshtein.distance(ref_text, hyp_text)
|
||||
cer = distance / len(ref_text)
|
||||
return cer
|
||||
|
||||
def compare_tokens(ref_tokens, hyp_tokens):
|
||||
comparisons = []
|
||||
for ref_token, hyp_token in zip(ref_tokens, hyp_tokens):
|
||||
distance = Levenshtein.distance(ref_token, hyp_token)
|
||||
comparison = {'ref_token': ref_token, 'hyp_token': hyp_token, 'error_rate': distance / len(ref_token)}
|
||||
comparisons.append(comparison)
|
||||
return comparisons
|
||||
|
||||
def read_text_from_file(file_path, join_sentences=True):
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as file:
|
||||
sentences = file.readlines()
|
||||
sentences = [sentence.strip() for sentence in sentences]
|
||||
# merge into a single string
|
||||
if join_sentences:
|
||||
return ' '.join(sentences)
|
||||
return sentences
|
||||
|
||||
parser = argparse.ArgumentParser(description='Evaluate output')
|
||||
parser.add_argument('ref_file_path', type=str, help='Path to the reference file')
|
||||
parser.add_argument('hyp_file_path', type=str, help='Path to the hypothesis file')
|
||||
args = parser.parse_args()
|
||||
|
||||
ref_text = read_text_from_file(args.ref_file_path)
|
||||
hyp_text = read_text_from_file(args.hyp_file_path)
|
||||
wer = calculate_wer(ref_text, hyp_text)
|
||||
cer = calculate_cer(ref_text, hyp_text)
|
||||
print("Word Error Rate (WER):", wer)
|
||||
print("Character Error Rate (CER):", cer)
|
||||
|
||||
ref_text = '\n'.join(read_text_from_file(args.ref_file_path, join_sentences=False))
|
||||
hyp_text = '\n'.join(read_text_from_file(args.hyp_file_path, join_sentences=False))
|
||||
html_diff = visualize_differences(ref_text, hyp_text)
|
||||
with open("diff_visualization.html", "w", encoding="utf-8") as file:
|
||||
file.write(html_diff)
|
||||
|
||||
from Bio.Align import PairwiseAligner
|
||||
|
||||
aligner = PairwiseAligner()
|
||||
|
||||
alignments = aligner.align(ref_text, hyp_text)
|
||||
|
||||
# write the first alignment to a file
|
||||
with open("alignment.txt", "w", encoding="utf-8") as file:
|
||||
file.write(alignments[0].format())
|
||||
|
586
src/tests/localvocal-offline-test.cpp
Normal file
586
src/tests/localvocal-offline-test.cpp
Normal file
@ -0,0 +1,586 @@
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <codecvt>
|
||||
#include <vector>
|
||||
|
||||
#include "transcription-filter-data.h"
|
||||
#include "transcription-filter.h"
|
||||
#include "transcription-utils.h"
|
||||
#include "whisper-utils/whisper-utils.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdarg.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
|
||||
void obs_log(int log_level, const char *format, ...)
|
||||
{
|
||||
if (log_level == LOG_DEBUG) {
|
||||
return;
|
||||
}
|
||||
// print timestamp in format [HH:MM:SS.mmm], use std::chrono::system_clock
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto now_ms = std::chrono::time_point_cast<std::chrono::milliseconds>(now);
|
||||
auto epoch = now_ms.time_since_epoch();
|
||||
|
||||
// convert to std::time_t in order to convert to std::tm
|
||||
std::time_t now_time_t = std::chrono::system_clock::to_time_t(now);
|
||||
std::tm now_tm = *std::localtime(&now_time_t);
|
||||
|
||||
// print timestamp
|
||||
printf("[%02d:%02d:%02d.%03d] ", now_tm.tm_hour, now_tm.tm_min, now_tm.tm_sec,
|
||||
(int)(epoch.count() % 1000));
|
||||
|
||||
// print log level
|
||||
switch (log_level) {
|
||||
case LOG_DEBUG:
|
||||
printf("[DEBUG] ");
|
||||
break;
|
||||
case LOG_INFO:
|
||||
printf("[INFO] ");
|
||||
break;
|
||||
case LOG_WARNING:
|
||||
printf("[WARNING] ");
|
||||
break;
|
||||
case LOG_ERROR:
|
||||
printf("[ERROR] ");
|
||||
break;
|
||||
default:
|
||||
printf("[UNKNOWN] ");
|
||||
break;
|
||||
}
|
||||
// convert format to wstring
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
||||
std::wstring wformat = converter.from_bytes(format);
|
||||
|
||||
// print format with arguments with utf-8 support
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
vprintf(format, args);
|
||||
va_end(args);
|
||||
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
#if defined(_WIN32) || defined(__APPLE__)
|
||||
|
||||
extern "C" {
|
||||
#include <libavformat/avformat.h>
|
||||
#include <libavcodec/avcodec.h>
|
||||
#include <libavutil/frame.h>
|
||||
#include <libavutil/mem.h>
|
||||
#include <libavutil/opt.h>
|
||||
#include <libswresample/swresample.h>
|
||||
}
|
||||
|
||||
std::vector<std::vector<uint8_t>>
|
||||
read_audio_file(const char *filename, std::function<void(int, int)> initialization_callback)
|
||||
{
|
||||
obs_log(LOG_INFO, "Reading audio file %s", filename);
|
||||
|
||||
AVFormatContext *formatContext = nullptr;
|
||||
int ret = avformat_open_input(&formatContext, filename, nullptr, nullptr);
|
||||
if (ret != 0) {
|
||||
char errbuf[AV_ERROR_MAX_STRING_SIZE];
|
||||
av_make_error_string(errbuf, AV_ERROR_MAX_STRING_SIZE, ret);
|
||||
obs_log(LOG_ERROR, "Error opening file: %s", errbuf);
|
||||
return {};
|
||||
}
|
||||
|
||||
if (avformat_find_stream_info(formatContext, nullptr) < 0) {
|
||||
obs_log(LOG_ERROR, "Error finding stream information");
|
||||
return {};
|
||||
}
|
||||
|
||||
int audioStreamIndex = -1;
|
||||
for (unsigned int i = 0; i < formatContext->nb_streams; i++) {
|
||||
if (formatContext->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
|
||||
audioStreamIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (audioStreamIndex == -1) {
|
||||
obs_log(LOG_ERROR, "No audio stream found");
|
||||
return {};
|
||||
}
|
||||
|
||||
// print information about the file
|
||||
av_dump_format(formatContext, 0, filename, 0);
|
||||
|
||||
// if the sample format is not float, return
|
||||
if (formatContext->streams[audioStreamIndex]->codecpar->format != AV_SAMPLE_FMT_FLTP) {
|
||||
obs_log(LOG_ERROR,
|
||||
"Sample format is not float (it is %s). Encode the audio file with float sample format."
|
||||
" For example, use the command 'ffmpeg -i input.mp3 -c:a pcm_f32le output.wav' to "
|
||||
"convert the audio file to float format.",
|
||||
av_get_sample_fmt_name(
|
||||
(AVSampleFormat)formatContext->streams[audioStreamIndex]
|
||||
->codecpar->format));
|
||||
return {};
|
||||
}
|
||||
|
||||
initialization_callback(formatContext->streams[audioStreamIndex]->codecpar->sample_rate,
|
||||
formatContext->streams[audioStreamIndex]->codecpar->channels);
|
||||
|
||||
AVCodecParameters *codecParams = formatContext->streams[audioStreamIndex]->codecpar;
|
||||
const AVCodec *codec = avcodec_find_decoder(codecParams->codec_id);
|
||||
if (!codec) {
|
||||
obs_log(LOG_ERROR, "Decoder not found");
|
||||
return {};
|
||||
}
|
||||
|
||||
AVCodecContext *codecContext = avcodec_alloc_context3(codec);
|
||||
if (!codecContext) {
|
||||
obs_log(LOG_ERROR, "Failed to allocate codec context");
|
||||
return {};
|
||||
}
|
||||
|
||||
if (avcodec_parameters_to_context(codecContext, codecParams) < 0) {
|
||||
obs_log(LOG_ERROR, "Failed to copy codec parameters to codec context");
|
||||
return {};
|
||||
}
|
||||
|
||||
if (avcodec_open2(codecContext, codec, nullptr) < 0) {
|
||||
obs_log(LOG_ERROR, "Failed to open codec");
|
||||
return {};
|
||||
}
|
||||
|
||||
AVFrame *frame = av_frame_alloc();
|
||||
AVPacket packet;
|
||||
|
||||
std::vector<std::vector<uint8_t>> buffer(
|
||||
formatContext->streams[audioStreamIndex]->codecpar->channels);
|
||||
|
||||
while (av_read_frame(formatContext, &packet) >= 0) {
|
||||
if (packet.stream_index == audioStreamIndex) {
|
||||
if (avcodec_send_packet(codecContext, &packet) == 0) {
|
||||
while (avcodec_receive_frame(codecContext, frame) == 0) {
|
||||
// push data to the buffer
|
||||
for (int j = 0; j < codecContext->channels; j++) {
|
||||
buffer[j].insert(buffer[j].end(), frame->data[j],
|
||||
frame->data[j] +
|
||||
frame->linesize[0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
av_packet_unref(&packet);
|
||||
}
|
||||
|
||||
av_frame_free(&frame);
|
||||
avcodec_free_context(&codecContext);
|
||||
avformat_close_input(&formatContext);
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
/**
 * Build a transcription_filter_data context for offline (non-OBS) testing.
 *
 * @param sample_rate             input audio sample rate in Hz
 * @param channels                number of input audio channels
 * @param whisper_model_path      path to the whisper model file to load
 * @param silero_vad_model_file   path to the Silero VAD ONNX model file
 * @param ct2ModelFolder          CTranslate2 model folder; accepted for API
 *                                symmetry — translation is configured
 *                                separately by the caller (not used here)
 * @param whisper_sampling_method whisper decoding strategy (default: greedy)
 * @return newly allocated context; release with release_context()
 */
transcription_filter_data *
create_context(int sample_rate, int channels, const std::string &whisper_model_path,
	       const std::string &silero_vad_model_file, const std::string &ct2ModelFolder,
	       const whisper_sampling_strategy whisper_sampling_method = WHISPER_SAMPLING_GREEDY)
{
	struct transcription_filter_data *gf = new transcription_filter_data();

	gf->log_level = LOG_DEBUG;
	gf->channels = channels;
	gf->sample_rate = sample_rate;
	// buffer 3000 ms worth of frames at the input sample rate
	gf->frames = (size_t)((float)gf->sample_rate / (1000.0f / 3000.0));
	gf->last_num_frames = 0;
	gf->step_size_msec = 3000;
	gf->min_sub_duration = 3000;
	gf->last_sub_render_time = 0;
	gf->save_srt = false;
	gf->truncate_output_file = false;
	gf->save_only_while_recording = false;
	gf->rename_file_to_match_recording = false;
	gf->process_while_muted = false;
	gf->buffered_output = false;
	gf->fix_utf8 = true;

	for (size_t i = 0; i < gf->channels; i++) {
		circlebuf_init(&gf->input_buffers[i]);
	}
	circlebuf_init(&gf->info_buffer);

	// allocate one contiguous copy buffer; per-channel pointers index into it
	gf->copy_buffers[0] =
		static_cast<float *>(malloc(gf->channels * gf->frames * sizeof(float)));
	for (size_t c = 1; c < gf->channels; c++) { // set the channel pointers
		gf->copy_buffers[c] = gf->copy_buffers[0] + c * gf->frames;
	}

	gf->overlap_ms = 150;
	gf->overlap_frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms));
	obs_log(gf->log_level, "channels %d, frames %d, sample_rate %d", (int)gf->channels,
		(int)gf->frames, gf->sample_rate);

	// resample the planar-float input down to whisper's expected mono rate
	obs_log(gf->log_level, "setup audio resampler");
	struct resample_info src, dst;
	src.samples_per_sec = gf->sample_rate;
	src.format = AUDIO_FORMAT_FLOAT_PLANAR;
	src.speakers = convert_speaker_layout((uint8_t)gf->channels);

	dst.samples_per_sec = WHISPER_SAMPLE_RATE;
	dst.format = AUDIO_FORMAT_FLOAT_PLANAR;
	dst.speakers = convert_speaker_layout((uint8_t)1);

	gf->resampler_to_whisper = audio_resampler_create(&dst, &src);

	gf->whisper_model_file_currently_loaded = "";
	gf->output_file_path = std::string("output.txt");
	gf->whisper_model_path = std::string(""); // The update function will set the model path
	gf->whisper_context = nullptr;

	gf->vad_enabled = true;
	gf->log_words = true;
	gf->caption_to_stream = false;
	gf->start_timestamp_ms = now_ms();
	gf->sentence_number = 1;

	gf->source_lang = "";
	gf->target_lang = "";
	gf->translation_ctx.add_context = true;
	gf->translation_output = "";
	gf->suppress_sentences = "";
	gf->translate = false;

	// whisper inference parameters for the offline run
	gf->whisper_params = whisper_full_default_params(whisper_sampling_method);
	gf->whisper_params.duration_ms = 3000;
	gf->whisper_params.language = "en";
	gf->whisper_params.initial_prompt = "";
	gf->whisper_params.n_threads = 4;
	gf->whisper_params.n_max_text_ctx = 16384;
	gf->whisper_params.translate = false;
	gf->whisper_params.no_context = false;
	gf->whisper_params.single_segment = true;
	gf->whisper_params.print_special = false;
	gf->whisper_params.print_progress = false;
	gf->whisper_params.print_realtime = false;
	gf->whisper_params.print_timestamps = false;
	gf->whisper_params.token_timestamps = false;
	gf->whisper_params.thold_pt = 0.01;
	gf->whisper_params.thold_ptsum = 0.01;
	gf->whisper_params.max_len = 0;
	gf->whisper_params.split_on_word = false;
	gf->whisper_params.max_tokens = 32;
	gf->whisper_params.speed_up = false;
	gf->whisper_params.suppress_blank = true;
	gf->whisper_params.suppress_non_speech_tokens = true;
	gf->whisper_params.temperature = 0.5;
	gf->whisper_params.max_initial_ts = 1.0;
	gf->whisper_params.length_penalty = -1;
	gf->active = true;

	start_whisper_thread_with_path(gf, whisper_model_path, silero_vad_model_file.c_str());

	obs_log(gf->log_level, "context created");

	return gf;
}
|
||||
|
||||
void set_text_callback(struct transcription_filter_data *gf,
|
||||
const DetectionResultWithText &resultIn)
|
||||
{
|
||||
DetectionResultWithText result = resultIn;
|
||||
|
||||
if (!result.text.empty() && result.result == DETECTION_RESULT_SPEECH) {
|
||||
std::string str_copy = result.text;
|
||||
if (gf->fix_utf8) {
|
||||
str_copy = fix_utf8(str_copy);
|
||||
}
|
||||
str_copy = remove_leading_trailing_nonalpha(str_copy);
|
||||
|
||||
// if suppression is enabled, check if the text is in the suppression list
|
||||
if (!gf->suppress_sentences.empty()) {
|
||||
// split the suppression list by newline into individual sentences
|
||||
std::vector<std::string> suppress_sentences_list =
|
||||
split(gf->suppress_sentences, '\n');
|
||||
// check if the text is in the suppression list
|
||||
for (const std::string &suppress_sentence : suppress_sentences_list) {
|
||||
// check if str_copy starts with the suppress sentence
|
||||
if (str_copy.find(suppress_sentence) == 0) {
|
||||
obs_log(LOG_INFO, "Suppressed sentence: '%s'",
|
||||
str_copy.c_str());
|
||||
gf->last_text = str_copy;
|
||||
return; // do not process the sentence
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (gf->translate) {
|
||||
obs_log(gf->log_level, "Translating text. %s -> %s",
|
||||
gf->source_lang.c_str(), gf->target_lang.c_str());
|
||||
std::string translated_text;
|
||||
if (translate(gf->translation_ctx, str_copy, gf->source_lang,
|
||||
gf->target_lang,
|
||||
translated_text) == OBS_POLYGLOT_TRANSLATION_SUCCESS) {
|
||||
if (gf->log_words) {
|
||||
obs_log(LOG_INFO, "Translation: '%s' -> '%s'",
|
||||
str_copy.c_str(), translated_text.c_str());
|
||||
}
|
||||
// overwrite the original text with the translated text
|
||||
str_copy = str_copy + " -> " + translated_text;
|
||||
} else {
|
||||
obs_log(gf->log_level, "Failed to translate text");
|
||||
}
|
||||
}
|
||||
|
||||
std::ofstream output_file(gf->output_file_path, std::ios::app);
|
||||
output_file << str_copy << std::endl;
|
||||
output_file.close();
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
|
||||
if (gf->buffered_output) {
|
||||
gf->captions_monitor.addWords(result.tokens);
|
||||
}
|
||||
|
||||
if (gf->output_file_path != "" && gf->text_source_name.empty()) {
|
||||
// Check if we should save the sentence
|
||||
// should the file be truncated?
|
||||
std::ios_base::openmode openmode = std::ios::out;
|
||||
if (gf->truncate_output_file) {
|
||||
openmode |= std::ios::trunc;
|
||||
} else {
|
||||
openmode |= std::ios::app;
|
||||
}
|
||||
if (!gf->save_srt) {
|
||||
// Write raw sentence to file
|
||||
std::ofstream output_file(gf->output_file_path, openmode);
|
||||
output_file << str_copy << std::endl;
|
||||
output_file.close();
|
||||
} else {
|
||||
obs_log(gf->log_level, "Saving sentence to file %s, sentence #%d",
|
||||
gf->output_file_path.c_str(), gf->sentence_number);
|
||||
// Append sentence to file in .srt format
|
||||
std::ofstream output_file(gf->output_file_path, openmode);
|
||||
output_file << gf->sentence_number << std::endl;
|
||||
// use the start and end timestamps to calculate the start and end time in srt format
|
||||
auto format_ts_for_srt = [&output_file](uint64_t ts) {
|
||||
uint64_t time_s = ts / 1000;
|
||||
uint64_t time_m = time_s / 60;
|
||||
uint64_t time_h = time_m / 60;
|
||||
uint64_t time_ms_rem = ts % 1000;
|
||||
uint64_t time_s_rem = time_s % 60;
|
||||
uint64_t time_m_rem = time_m % 60;
|
||||
uint64_t time_h_rem = time_h % 60;
|
||||
output_file << std::setfill('0') << std::setw(2) << time_h_rem
|
||||
<< ":" << std::setfill('0') << std::setw(2)
|
||||
<< time_m_rem << ":" << std::setfill('0')
|
||||
<< std::setw(2) << time_s_rem << ","
|
||||
<< std::setfill('0') << std::setw(3) << time_ms_rem;
|
||||
};
|
||||
format_ts_for_srt(result.start_timestamp_ms);
|
||||
output_file << " --> ";
|
||||
format_ts_for_srt(result.end_timestamp_ms);
|
||||
output_file << std::endl;
|
||||
|
||||
output_file << str_copy << std::endl;
|
||||
output_file << std::endl;
|
||||
output_file.close();
|
||||
gf->sentence_number++;
|
||||
}
|
||||
}
|
||||
*/
|
||||
};
|
||||
|
||||
/**
 * Tear down a context created by create_context(): stop the whisper worker,
 * destroy the resampler, release the audio buffers, and delete the context.
 *
 * @param gf context to destroy; invalid after this call
 */
void release_context(transcription_filter_data *gf)
{
	obs_log(LOG_INFO, "destroy");

	// stop the worker first so nothing touches the buffers below
	shutdown_whisper_thread(gf);

	if (gf->resampler_to_whisper) {
		audio_resampler_destroy(gf->resampler_to_whisper);
	}

	{
		// hold the buffer mutex while freeing audio storage
		std::lock_guard<std::mutex> buf_lock(gf->whisper_buf_mutex);
		free(gf->copy_buffers[0]);
		gf->copy_buffers[0] = nullptr;
		for (size_t ch = 0; ch < gf->channels; ch++) {
			circlebuf_free(&gf->input_buffers[ch]);
		}
	}
	circlebuf_free(&gf->info_buffer);

	delete gf;
}
|
||||
|
||||
int wmain(int argc, wchar_t *argv[])
|
||||
{
|
||||
if (argc < 3) {
|
||||
std::cout << "Usage: localvocal-offline-test <audio-file> <config_json_file>"
|
||||
<< std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
// Set console output to UTF-8
|
||||
SetConsoleOutputCP(CP_UTF8);
|
||||
#endif
|
||||
|
||||
std::wstring file = argv[1];
|
||||
std::wstring configJsonFile = argv[2];
|
||||
|
||||
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
|
||||
std::string filenameStr = converter.to_bytes(file);
|
||||
|
||||
// read the configuration json file
|
||||
std::ifstream config_stream(configJsonFile);
|
||||
if (!config_stream.is_open()) {
|
||||
std::cout << "Failed to open config file" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
nlohmann::json config;
|
||||
config_stream >> config;
|
||||
config_stream.close();
|
||||
|
||||
// get the configuration values
|
||||
std::string whisperModelPathStr = config["whisper_model_path"];
|
||||
std::string sileroVadModelFileStr = config["silero_vad_model_file"];
|
||||
std::string sourceLanguageStr = config["source_language"];
|
||||
std::string targetLanguageStr = config["target_language"];
|
||||
std::string whisperLanguageStr = config["whisper_language"];
|
||||
std::string ct2ModelFolderStr = config["ct2_model_folder"];
|
||||
std::string logLevelStr = config["log_level"];
|
||||
whisper_sampling_strategy whisper_sampling_method = config["whisper_sampling_method"];
|
||||
|
||||
std::cout << "LocalVocal Offline Test" << std::endl;
|
||||
transcription_filter_data *gf = nullptr;
|
||||
|
||||
std::vector<std::vector<uint8_t>> audio =
|
||||
read_audio_file(filenameStr.c_str(), [&](int sample_rate, int channels) {
|
||||
gf = create_context(sample_rate, channels, whisperModelPathStr,
|
||||
sileroVadModelFileStr, ct2ModelFolderStr,
|
||||
whisper_sampling_method);
|
||||
if (sourceLanguageStr.empty() || targetLanguageStr.empty() ||
|
||||
sourceLanguageStr == "none" || targetLanguageStr == "none") {
|
||||
obs_log(LOG_INFO,
|
||||
"Source or target translation language are empty or disabled");
|
||||
} else {
|
||||
obs_log(LOG_INFO, "Setting translation languages");
|
||||
gf->source_lang = sourceLanguageStr;
|
||||
gf->target_lang = targetLanguageStr;
|
||||
build_and_enable_translation(gf, ct2ModelFolderStr.c_str());
|
||||
}
|
||||
gf->whisper_params.language = whisperLanguageStr.c_str();
|
||||
if (config.contains("fix_utf8")) {
|
||||
obs_log(LOG_INFO, "Setting fix_utf8 to %s",
|
||||
config["fix_utf8"] ? "true" : "false");
|
||||
gf->fix_utf8 = config["fix_utf8"];
|
||||
}
|
||||
if (config.contains("suppress_sentences")) {
|
||||
obs_log(LOG_INFO, "Setting suppress_sentences to %ls",
|
||||
config["suppress_sentences"].get<std::string>().c_str());
|
||||
gf->suppress_sentences =
|
||||
config["suppress_sentences"].get<std::string>();
|
||||
}
|
||||
if (config.contains("overlap_ms")) {
|
||||
obs_log(LOG_INFO, "Setting overlap_ms to %d",
|
||||
config["overlap_ms"].get<int>());
|
||||
gf->overlap_ms = config["overlap_ms"];
|
||||
gf->overlap_frames = (size_t)((float)gf->sample_rate /
|
||||
(1000.0f / (float)gf->overlap_ms));
|
||||
}
|
||||
// set log level
|
||||
if (logLevelStr == "debug") {
|
||||
gf->log_level = LOG_DEBUG;
|
||||
} else if (logLevelStr == "info") {
|
||||
gf->log_level = LOG_INFO;
|
||||
} else if (logLevelStr == "warning") {
|
||||
gf->log_level = LOG_WARNING;
|
||||
} else if (logLevelStr == "error") {
|
||||
gf->log_level = LOG_ERROR;
|
||||
}
|
||||
});
|
||||
|
||||
if (gf == nullptr) {
|
||||
std::cout << "Failed to create context" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
if (audio.empty()) {
|
||||
std::cout << "Failed to read audio file" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
// truncate the output file
|
||||
obs_log(LOG_INFO, "Truncating output file");
|
||||
std::ofstream output_file(gf->output_file_path, std::ios::trunc);
|
||||
output_file.close();
|
||||
|
||||
// fill up the whisper buffer
|
||||
{
|
||||
obs_log(LOG_INFO, "Filling up whisper buffer");
|
||||
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex); // scoped lock
|
||||
int frames = 4096;
|
||||
const int frame_size_bytes = sizeof(float);
|
||||
int frames_size_bytes = frames * frame_size_bytes;
|
||||
int frames_count = 0;
|
||||
while (true) {
|
||||
// check if there are enough frames left in the audio buffer
|
||||
if ((frames_count + frames) > (audio[0].size() / frame_size_bytes)) {
|
||||
// only take the remaining frames
|
||||
frames = audio[0].size() / frame_size_bytes - frames_count;
|
||||
frames_size_bytes = frames * frame_size_bytes;
|
||||
}
|
||||
// push back current audio data to input circlebuf
|
||||
for (size_t c = 0; c < gf->channels; c++) {
|
||||
circlebuf_push_back(&gf->input_buffers[c],
|
||||
audio[c].data() +
|
||||
frames_count * frame_size_bytes,
|
||||
frames_size_bytes);
|
||||
}
|
||||
// push audio packet info (timestamp/frame count) to info circlebuf
|
||||
struct transcription_filter_audio_info info = {0};
|
||||
info.frames = frames; // number of frames in this packet
|
||||
// make a timestamp from the current frame count
|
||||
info.timestamp = frames_count * 1000 / gf->sample_rate;
|
||||
circlebuf_push_back(&gf->info_buffer, &info, sizeof(info));
|
||||
frames_count += frames;
|
||||
if (frames_count >= audio[0].size() / frame_size_bytes) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// wait for processing to finish
|
||||
obs_log(LOG_INFO, "Waiting for processing to finish");
|
||||
while (true) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
// check the input circlebuf has more data
|
||||
size_t input_buf_size = 0;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);
|
||||
input_buf_size = gf->input_buffers[0].size;
|
||||
}
|
||||
const size_t step_size_frames = gf->step_size_msec * gf->sample_rate / 1000;
|
||||
const size_t segment_size = step_size_frames * sizeof(float);
|
||||
|
||||
if (input_buf_size < segment_size) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
release_context(gf);
|
||||
|
||||
obs_log(LOG_INFO, "LocalVocal Offline Test Done");
|
||||
return 0;
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
#ifndef TRANSCRIPTION_FILTER_DATA_H
|
||||
#define TRANSCRIPTION_FILTER_DATA_H
|
||||
|
||||
#include <obs.h>
|
||||
#include <util/circlebuf.h>
|
||||
#include <util/darray.h>
|
||||
#include <media-io/audio-resampler.h>
|
||||
@ -22,8 +21,6 @@
|
||||
|
||||
#define MAX_PREPROC_CHANNELS 10
|
||||
|
||||
#define MT_ obs_module_text
|
||||
|
||||
struct transcription_filter_data {
|
||||
obs_source_t *context; // obs filter source (this filter)
|
||||
size_t channels; // number of channels
|
||||
@ -82,6 +79,7 @@ struct transcription_filter_data {
|
||||
bool buffered_output = false;
|
||||
bool enable_token_ts_dtw = false;
|
||||
std::string suppress_sentences;
|
||||
bool fix_utf8 = true;
|
||||
|
||||
// Last transcription result
|
||||
std::string last_text;
|
||||
@ -97,17 +95,18 @@ struct transcription_filter_data {
|
||||
// Use std for thread and mutex
|
||||
std::thread whisper_thread;
|
||||
|
||||
std::mutex *whisper_buf_mutex;
|
||||
std::mutex *whisper_ctx_mutex;
|
||||
std::condition_variable *wshiper_thread_cv;
|
||||
std::mutex whisper_buf_mutex;
|
||||
std::mutex whisper_ctx_mutex;
|
||||
std::condition_variable wshiper_thread_cv;
|
||||
|
||||
// translation context
|
||||
struct translation_context translation_ctx;
|
||||
std::string translation_model_index;
|
||||
|
||||
TokenBufferThread captions_monitor;
|
||||
|
||||
// ctor
|
||||
transcription_filter_data()
|
||||
transcription_filter_data() : whisper_buf_mutex(), whisper_ctx_mutex(), wshiper_thread_cv()
|
||||
{
|
||||
// initialize all pointers to nullptr
|
||||
for (size_t i = 0; i < MAX_PREPROC_CHANNELS; i++) {
|
||||
@ -117,9 +116,6 @@ struct transcription_filter_data {
|
||||
resampler_to_whisper = nullptr;
|
||||
whisper_model_path = "";
|
||||
whisper_context = nullptr;
|
||||
whisper_buf_mutex = nullptr;
|
||||
whisper_ctx_mutex = nullptr;
|
||||
wshiper_thread_cv = nullptr;
|
||||
output_file_path = "";
|
||||
whisper_model_file_currently_loaded = "";
|
||||
}
|
||||
|
@ -8,8 +8,10 @@
|
||||
#include "model-utils/model-downloader.h"
|
||||
#include "whisper-utils/whisper-processing.h"
|
||||
#include "whisper-utils/whisper-language.h"
|
||||
#include "whisper-utils/whisper-model-utils.h"
|
||||
#include "whisper-utils/whisper-utils.h"
|
||||
#include "translation/language_codes.h"
|
||||
#include "translation/translation-utils.h"
|
||||
#include "translation/translation.h"
|
||||
#include "utils.h"
|
||||
|
||||
@ -24,37 +26,6 @@
|
||||
|
||||
#include <QString>
|
||||
|
||||
inline enum speaker_layout convert_speaker_layout(uint8_t channels)
|
||||
{
|
||||
switch (channels) {
|
||||
case 0:
|
||||
return SPEAKERS_UNKNOWN;
|
||||
case 1:
|
||||
return SPEAKERS_MONO;
|
||||
case 2:
|
||||
return SPEAKERS_STEREO;
|
||||
case 3:
|
||||
return SPEAKERS_2POINT1;
|
||||
case 4:
|
||||
return SPEAKERS_4POINT0;
|
||||
case 5:
|
||||
return SPEAKERS_4POINT1;
|
||||
case 6:
|
||||
return SPEAKERS_5POINT1;
|
||||
case 8:
|
||||
return SPEAKERS_7POINT1;
|
||||
default:
|
||||
return SPEAKERS_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
inline uint64_t now_ms()
|
||||
{
|
||||
return std::chrono::duration_cast<std::chrono::milliseconds>(
|
||||
std::chrono::system_clock::now().time_since_epoch())
|
||||
.count();
|
||||
}
|
||||
|
||||
bool add_sources_to_list(void *list_property, obs_source_t *source)
|
||||
{
|
||||
auto source_id = obs_source_get_id(source);
|
||||
@ -97,13 +68,8 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_
|
||||
return audio;
|
||||
}
|
||||
|
||||
if (!gf->whisper_buf_mutex || !gf->whisper_ctx_mutex) {
|
||||
obs_log(LOG_ERROR, "whisper mutexes are null");
|
||||
return audio;
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_buf_mutex); // scoped lock
|
||||
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex); // scoped lock
|
||||
// push back current audio data to input circlebuf
|
||||
for (size_t c = 0; c < gf->channels; c++) {
|
||||
circlebuf_push_back(&gf->input_buffers[c], audio->data[c],
|
||||
@ -138,7 +104,7 @@ void transcription_filter_destroy(void *data)
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lockbuf(*gf->whisper_buf_mutex);
|
||||
std::lock_guard<std::mutex> lockbuf(gf->whisper_buf_mutex);
|
||||
bfree(gf->copy_buffers[0]);
|
||||
gf->copy_buffers[0] = nullptr;
|
||||
for (size_t i = 0; i < gf->channels; i++) {
|
||||
@ -147,11 +113,7 @@ void transcription_filter_destroy(void *data)
|
||||
}
|
||||
circlebuf_free(&gf->info_buffer);
|
||||
|
||||
delete gf->whisper_buf_mutex;
|
||||
delete gf->whisper_ctx_mutex;
|
||||
delete gf->wshiper_thread_cv;
|
||||
|
||||
delete gf;
|
||||
bfree(gf);
|
||||
}
|
||||
|
||||
void send_caption_to_source(const std::string &target_source_name, const std::string &str_copy,
|
||||
@ -185,9 +147,15 @@ void set_text_callback(struct transcription_filter_data *gf,
|
||||
}
|
||||
gf->last_sub_render_time = now;
|
||||
|
||||
// recondition the text
|
||||
std::string str_copy = fix_utf8(result.text);
|
||||
str_copy = remove_leading_trailing_nonalpha(str_copy);
|
||||
std::string str_copy = result.text;
|
||||
|
||||
// recondition the text - only if the output is not English
|
||||
if (gf->whisper_params.language != nullptr &&
|
||||
strcmp(gf->whisper_params.language, "en") != 0) {
|
||||
str_copy = fix_utf8(str_copy);
|
||||
} else {
|
||||
str_copy = remove_leading_trailing_nonalpha(str_copy);
|
||||
}
|
||||
|
||||
// if suppression is enabled, check if the text is in the suppression list
|
||||
if (!gf->suppress_sentences.empty()) {
|
||||
@ -340,10 +308,17 @@ void transcription_filter_update(void *data, obs_data_t *s)
|
||||
gf->translation_ctx.add_context = obs_data_get_bool(s, "translate_add_context");
|
||||
gf->translation_output = obs_data_get_string(s, "translate_output");
|
||||
gf->suppress_sentences = obs_data_get_string(s, "suppress_sentences");
|
||||
gf->translation_model_index = obs_data_get_string(s, "translate_model");
|
||||
|
||||
if (new_translate != gf->translate) {
|
||||
if (new_translate) {
|
||||
start_translation(gf);
|
||||
if (gf->translation_model_index != "whisper-based-translation") {
|
||||
start_translation(gf);
|
||||
} else {
|
||||
// whisper-based translation
|
||||
obs_log(gf->log_level, "Starting whisper-based translation...");
|
||||
gf->translate = false;
|
||||
}
|
||||
} else {
|
||||
gf->translate = false;
|
||||
}
|
||||
@ -381,21 +356,21 @@ void transcription_filter_update(void *data, obs_data_t *s)
|
||||
obs_weak_source_release(old_weak_text_source);
|
||||
}
|
||||
|
||||
if (gf->whisper_ctx_mutex == nullptr) {
|
||||
obs_log(LOG_ERROR, "whisper_ctx_mutex is null");
|
||||
return;
|
||||
}
|
||||
|
||||
obs_log(gf->log_level, "update whisper model");
|
||||
update_whisper_model(gf, s);
|
||||
|
||||
obs_log(gf->log_level, "update whisper params");
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
|
||||
|
||||
gf->whisper_params = whisper_full_default_params(
|
||||
(whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method"));
|
||||
gf->whisper_params.duration_ms = (int)obs_data_get_int(s, "buffer_size_msec");
|
||||
gf->whisper_params.language = obs_data_get_string(s, "whisper_language_select");
|
||||
if (!new_translate || gf->translation_model_index != "whisper-based-translation") {
|
||||
gf->whisper_params.language = obs_data_get_string(s, "whisper_language_select");
|
||||
} else {
|
||||
// take the language from gf->target_lang
|
||||
gf->whisper_params.language = language_codes_2_reverse[gf->target_lang].c_str();
|
||||
}
|
||||
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
|
||||
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
|
||||
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
|
||||
@ -425,7 +400,8 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
|
||||
{
|
||||
obs_log(LOG_INFO, "LocalVocal filter create");
|
||||
|
||||
struct transcription_filter_data *gf = new transcription_filter_data();
|
||||
void *data = bmalloc(sizeof(struct transcription_filter_data));
|
||||
struct transcription_filter_data *gf = new (data) transcription_filter_data();
|
||||
|
||||
// Get the number of channels for the input source
|
||||
gf->channels = audio_output_get_channels(obs_get_audio());
|
||||
@ -479,10 +455,6 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
|
||||
|
||||
gf->resampler_to_whisper = audio_resampler_create(&dst, &src);
|
||||
|
||||
obs_log(gf->log_level, "setup mutexes and condition variables");
|
||||
gf->whisper_buf_mutex = new std::mutex();
|
||||
gf->whisper_ctx_mutex = new std::mutex();
|
||||
gf->wshiper_thread_cv = new std::condition_variable();
|
||||
obs_log(gf->log_level, "clear text source data");
|
||||
const char *subtitle_sources = obs_data_get_string(settings, "subtitle_sources");
|
||||
if (subtitle_sources == nullptr || strcmp(subtitle_sources, "none") == 0 ||
|
||||
@ -675,6 +647,7 @@ void transcription_filter_defaults(obs_data_t *s)
|
||||
obs_data_set_default_string(s, "translate_target_language", "__es__");
|
||||
obs_data_set_default_string(s, "translate_source_language", "__en__");
|
||||
obs_data_set_default_bool(s, "translate_add_context", true);
|
||||
obs_data_set_default_string(s, "translate_model", "whisper-based-translation");
|
||||
obs_data_set_default_string(s, "suppress_sentences", SUPPRESS_SENTENCES_DEFAULT);
|
||||
|
||||
// Whisper parameters
|
||||
@ -788,6 +761,21 @@ obs_properties_t *transcription_filter_properties(void *data)
|
||||
obs_properties_t *translation_group = obs_properties_create();
|
||||
obs_property_t *translation_group_prop = obs_properties_add_group(
|
||||
ppts, "translate", MT_("translate"), OBS_GROUP_CHECKABLE, translation_group);
|
||||
|
||||
// add translatio model selection
|
||||
obs_property_t *prop_translate_model = obs_properties_add_list(
|
||||
translation_group, "translate_model", MT_("translate_model"), OBS_COMBO_TYPE_LIST,
|
||||
OBS_COMBO_FORMAT_STRING);
|
||||
// Populate the dropdown with the translation models
|
||||
// add "Whisper-Based Translation" option
|
||||
obs_property_list_add_string(prop_translate_model, MT_("Whisper-Based-Translation"),
|
||||
"whisper-based-translation");
|
||||
for (const auto &model_info : models_info) {
|
||||
if (model_info.second.type == MODEL_TYPE_TRANSLATION) {
|
||||
obs_property_list_add_string(prop_translate_model, model_info.first.c_str(),
|
||||
model_info.first.c_str());
|
||||
}
|
||||
}
|
||||
// add target language selection
|
||||
obs_property_t *prop_tgt = obs_properties_add_list(
|
||||
translation_group, "translate_target_language", MT_("target_language"),
|
||||
@ -822,8 +810,9 @@ obs_properties_t *transcription_filter_properties(void *data)
|
||||
UNUSED_PARAMETER(property);
|
||||
// Show/Hide the translation group
|
||||
const bool translate_enabled = obs_data_get_bool(settings, "translate");
|
||||
for (const auto &prop : {"translate_target_language", "translate_source_language",
|
||||
"translate_add_context", "translate_output"}) {
|
||||
for (const auto &prop :
|
||||
{"translate_target_language", "translate_source_language",
|
||||
"translate_add_context", "translate_output", "translate_model"}) {
|
||||
obs_property_set_visible(obs_properties_get(props, prop),
|
||||
translate_enabled);
|
||||
}
|
||||
|
@ -4,6 +4,8 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define MT_ obs_module_text
|
||||
|
||||
void transcription_filter_activate(void *data);
|
||||
void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter);
|
||||
void transcription_filter_update(void *data, obs_data_t *s);
|
||||
@ -19,7 +21,10 @@ const char *const PLUGIN_INFO_TEMPLATE =
|
||||
"<a href=\"https://github.com/occ-ai\">OCC AI</a> ❤️ "
|
||||
"<a href=\"https://www.patreon.com/RoyShilkrot\">Support & Follow</a>";
|
||||
|
||||
const char *const SUPPRESS_SENTENCES_DEFAULT = "Thank you for watching\nThank you";
|
||||
const char *const SUPPRESS_SENTENCES_DEFAULT =
|
||||
"Thank you for watching\nPlease like and subscribe\n"
|
||||
"Check out my other videos\nFollow me on social media\n"
|
||||
"Please consider supporting me";
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
@ -90,17 +90,18 @@ std::string remove_leading_trailing_nonalpha(const std::string &str)
|
||||
{
|
||||
std::string str_copy = str;
|
||||
// remove trailing spaces, newlines, tabs or punctuation
|
||||
str_copy.erase(std::find_if(str_copy.rbegin(), str_copy.rend(),
|
||||
[](unsigned char ch) {
|
||||
return !std::isspace(ch) || !std::ispunct(ch);
|
||||
})
|
||||
.base(),
|
||||
str_copy.end());
|
||||
auto last_non_space =
|
||||
std::find_if(str_copy.rbegin(), str_copy.rend(), [](unsigned char ch) {
|
||||
return !std::isspace(ch) || !std::ispunct(ch);
|
||||
}).base();
|
||||
str_copy.erase(last_non_space, str_copy.end());
|
||||
// remove leading spaces, newlines, tabs or punctuation
|
||||
str_copy.erase(str_copy.begin(),
|
||||
std::find_if(str_copy.begin(), str_copy.end(), [](unsigned char ch) {
|
||||
return !std::isspace(ch) || !std::ispunct(ch);
|
||||
}));
|
||||
auto first_non_space = std::find_if(str_copy.begin(), str_copy.end(),
|
||||
[](unsigned char ch) {
|
||||
return !std::isspace(ch) || !std::ispunct(ch);
|
||||
}) +
|
||||
1;
|
||||
str_copy.erase(str_copy.begin(), first_non_space);
|
||||
return str_copy;
|
||||
}
|
||||
|
||||
|
@ -3,9 +3,42 @@
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
#include <media-io/audio-io.h>
|
||||
|
||||
std::string fix_utf8(const std::string &str);
|
||||
std::string remove_leading_trailing_nonalpha(const std::string &str);
|
||||
std::vector<std::string> split(const std::string &string, char delimiter);
|
||||
|
||||
inline enum speaker_layout convert_speaker_layout(uint8_t channels)
|
||||
{
|
||||
switch (channels) {
|
||||
case 0:
|
||||
return SPEAKERS_UNKNOWN;
|
||||
case 1:
|
||||
return SPEAKERS_MONO;
|
||||
case 2:
|
||||
return SPEAKERS_STEREO;
|
||||
case 3:
|
||||
return SPEAKERS_2POINT1;
|
||||
case 4:
|
||||
return SPEAKERS_4POINT0;
|
||||
case 5:
|
||||
return SPEAKERS_4POINT1;
|
||||
case 6:
|
||||
return SPEAKERS_5POINT1;
|
||||
case 8:
|
||||
return SPEAKERS_7POINT1;
|
||||
default:
|
||||
return SPEAKERS_UNKNOWN;
|
||||
}
|
||||
}
|
||||
|
||||
// Current wall-clock time, expressed as milliseconds since the Unix epoch.
inline uint64_t now_ms()
{
	using std::chrono::duration_cast;
	using std::chrono::milliseconds;
	using std::chrono::system_clock;
	const auto since_epoch = system_clock::now().time_since_epoch();
	return duration_cast<milliseconds>(since_epoch).count();
}
|
||||
|
||||
#endif // TRANSCRIPTION_UTILS_H
|
||||
|
@ -203,3 +203,57 @@ std::map<std::string, std::string> language_codes_reverse = {{"Afrikaans", "__af
|
||||
{"Yoruba", "__yo__"},
|
||||
{"Chinese", "__zh__"},
|
||||
{"Zulu", "__zu__"}};
|
||||
|
||||
std::map<std::string, std::string> language_codes_2 = {
|
||||
{"af", "__af__"}, {"am", "__am__"}, {"ar", "__ar__"}, {"ast", "__ast__"},
|
||||
{"az", "__az__"}, {"ba", "__ba__"}, {"be", "__be__"}, {"bg", "__bg__"},
|
||||
{"bn", "__bn__"}, {"br", "__br__"}, {"bs", "__bs__"}, {"ca", "__ca__"},
|
||||
{"ceb", "__ceb__"}, {"cs", "__cs__"}, {"cy", "__cy__"}, {"da", "__da__"},
|
||||
{"de", "__de__"}, {"el", "__el__"}, {"en", "__en__"}, {"es", "__es__"},
|
||||
{"et", "__et__"}, {"fa", "__fa__"}, {"ff", "__ff__"}, {"fi", "__fi__"},
|
||||
{"fr", "__fr__"}, {"fy", "__fy__"}, {"ga", "__ga__"}, {"gd", "__gd__"},
|
||||
{"gl", "__gl__"}, {"gu", "__gu__"}, {"ha", "__ha__"}, {"he", "__he__"},
|
||||
{"hi", "__hi__"}, {"hr", "__hr__"}, {"ht", "__ht__"}, {"hu", "__hu__"},
|
||||
{"hy", "__hy__"}, {"id", "__id__"}, {"ig", "__ig__"}, {"ilo", "__ilo__"},
|
||||
{"is", "__is__"}, {"it", "__it__"}, {"ja", "__ja__"}, {"jv", "__jv__"},
|
||||
{"ka", "__ka__"}, {"kk", "__kk__"}, {"km", "__km__"}, {"kn", "__kn__"},
|
||||
{"ko", "__ko__"}, {"lb", "__lb__"}, {"lg", "__lg__"}, {"ln", "__ln__"},
|
||||
{"lo", "__lo__"}, {"lt", "__lt__"}, {"lv", "__lv__"}, {"mg", "__mg__"},
|
||||
{"mk", "__mk__"}, {"ml", "__ml__"}, {"mn", "__mn__"}, {"mr", "__mr__"},
|
||||
{"ms", "__ms__"}, {"my", "__my__"}, {"ne", "__ne__"}, {"nl", "__nl__"},
|
||||
{"no", "__no__"}, {"ns", "__ns__"}, {"oc", "__oc__"}, {"or", "__or__"},
|
||||
{"pa", "__pa__"}, {"pl", "__pl__"}, {"ps", "__ps__"}, {"pt", "__pt__"},
|
||||
{"ro", "__ro__"}, {"ru", "__ru__"}, {"sd", "__sd__"}, {"si", "__si__"},
|
||||
{"sk", "__sk__"}, {"sl", "__sl__"}, {"so", "__so__"}, {"sq", "__sq__"},
|
||||
{"sr", "__sr__"}, {"ss", "__ss__"}, {"su", "__su__"}, {"sv", "__sv__"},
|
||||
{"sw", "__sw__"}, {"ta", "__ta__"}, {"th", "__th__"}, {"tl", "__tl__"},
|
||||
{"tn", "__tn__"}, {"tr", "__tr__"}, {"uk", "__uk__"}, {"ur", "__ur__"},
|
||||
{"uz", "__uz__"}, {"vi", "__vi__"}, {"wo", "__wo__"}, {"xh", "__xh__"},
|
||||
{"yi", "__yi__"}, {"yo", "__yo__"}, {"zh", "__zh__"}, {"zu", "__zu__"}};
|
||||
|
||||
std::map<std::string, std::string> language_codes_2_reverse = {
|
||||
{"__af__", "af"}, {"__am__", "am"}, {"__ar__", "ar"}, {"__ast__", "_st"},
|
||||
{"__az__", "az"}, {"__ba__", "ba"}, {"__be__", "be"}, {"__bg__", "bg"},
|
||||
{"__bn__", "bn"}, {"__br__", "br"}, {"__bs__", "bs"}, {"__ca__", "ca"},
|
||||
{"__ceb__", "_eb"}, {"__cs__", "cs"}, {"__cy__", "cy"}, {"__da__", "da"},
|
||||
{"__de__", "de"}, {"__el__", "el"}, {"__en__", "en"}, {"__es__", "es"},
|
||||
{"__et__", "et"}, {"__fa__", "fa"}, {"__ff__", "ff"}, {"__fi__", "fi"},
|
||||
{"__fr__", "fr"}, {"__fy__", "fy"}, {"__ga__", "ga"}, {"__gd__", "gd"},
|
||||
{"__gl__", "gl"}, {"__gu__", "gu"}, {"__ha__", "ha"}, {"__he__", "he"},
|
||||
{"__hi__", "hi"}, {"__hr__", "hr"}, {"__ht__", "ht"}, {"__hu__", "hu"},
|
||||
{"__hy__", "hy"}, {"__id__", "id"}, {"__ig__", "ig"}, {"__ilo__", "_lo"},
|
||||
{"__is__", "is"}, {"__it__", "it"}, {"__ja__", "ja"}, {"__jv__", "jv"},
|
||||
{"__ka__", "ka"}, {"__kk__", "kk"}, {"__km__", "km"}, {"__kn__", "kn"},
|
||||
{"__ko__", "ko"}, {"__lb__", "lb"}, {"__lg__", "lg"}, {"__ln__", "ln"},
|
||||
{"__lo__", "lo"}, {"__lt__", "lt"}, {"__lv__", "lv"}, {"__mg__", "mg"},
|
||||
{"__mk__", "mk"}, {"__ml__", "ml"}, {"__mn__", "mn"}, {"__mr__", "mr"},
|
||||
{"__ms__", "ms"}, {"__my__", "my"}, {"__ne__", "ne"}, {"__nl__", "nl"},
|
||||
{"__no__", "no"}, {"__ns__", "ns"}, {"__oc__", "oc"}, {"__or__", "or"},
|
||||
{"__pa__", "pa"}, {"__pl__", "pl"}, {"__ps__", "ps"}, {"__pt__", "pt"},
|
||||
{"__ro__", "ro"}, {"__ru__", "ru"}, {"__sd__", "sd"}, {"__si__", "si"},
|
||||
{"__sk__", "sk"}, {"__sl__", "sl"}, {"__so__", "so"}, {"__sq__", "sq"},
|
||||
{"__sr__", "sr"}, {"__ss__", "ss"}, {"__su__", "su"}, {"__sv__", "sv"},
|
||||
{"__sw__", "sw"}, {"__ta__", "ta"}, {"__th__", "th"}, {"__tl__", "tl"},
|
||||
{"__tn__", "tn"}, {"__tr__", "tr"}, {"__uk__", "uk"}, {"__ur__", "ur"},
|
||||
{"__uz__", "uz"}, {"__vi__", "vi"}, {"__wo__", "wo"}, {"__xh__", "xh"},
|
||||
{"__yi__", "yi"}, {"__yo__", "yo"}, {"__zh__", "zh"}, {"__zu__", "zu"}};
|
||||
|
33
src/translation/translation-utils.cpp
Normal file
33
src/translation/translation-utils.cpp
Normal file
@ -0,0 +1,33 @@
|
||||
|
||||
#include "translation-utils.h"
|
||||
|
||||
#include "translation.h"
|
||||
#include "plugin-support.h"
|
||||
#include "model-utils/model-downloader.h"
|
||||
|
||||
#include <obs-module.h>
|
||||
|
||||
void start_translation(struct transcription_filter_data *gf)
|
||||
{
|
||||
obs_log(LOG_INFO, "Starting translation...");
|
||||
|
||||
const ModelInfo &translation_model_info = models_info[gf->translation_model_index];
|
||||
std::string model_file_found = find_model_folder(translation_model_info);
|
||||
if (model_file_found == "") {
|
||||
obs_log(LOG_INFO, "Translation CT2 model does not exist. Downloading...");
|
||||
download_model_with_ui_dialog(
|
||||
translation_model_info,
|
||||
[gf, model_file_found](int download_status, const std::string &path) {
|
||||
if (download_status == 0) {
|
||||
obs_log(LOG_INFO, "CT2 model download complete");
|
||||
build_and_enable_translation(gf, path);
|
||||
} else {
|
||||
obs_log(LOG_ERROR, "Model download failed");
|
||||
gf->translate = false;
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Model exists, just load it
|
||||
build_and_enable_translation(gf, model_file_found);
|
||||
}
|
||||
}
|
4
src/translation/translation-utils.h
Normal file
4
src/translation/translation-utils.h
Normal file
@ -0,0 +1,4 @@
|
||||
|
||||
#include "transcription-filter-data.h"
|
||||
|
||||
void start_translation(struct transcription_filter_data *gf);
|
@ -1,6 +1,6 @@
|
||||
#include "translation.h"
|
||||
#include "plugin-support.h"
|
||||
#include "model-utils/model-downloader.h"
|
||||
#include "model-utils/model-find-utils.h"
|
||||
#include "transcription-filter-data.h"
|
||||
|
||||
#include <ctranslate2/translator.h>
|
||||
@ -11,11 +11,7 @@
|
||||
void build_and_enable_translation(struct transcription_filter_data *gf,
|
||||
const std::string &model_file_path)
|
||||
{
|
||||
if (gf->whisper_ctx_mutex == nullptr) {
|
||||
obs_log(LOG_ERROR, "Whisper context mutex is null");
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
|
||||
|
||||
gf->translation_ctx.local_model_folder_path = model_file_path;
|
||||
if (build_translation_context(gf->translation_ctx) ==
|
||||
@ -28,38 +24,13 @@ void build_and_enable_translation(struct transcription_filter_data *gf,
|
||||
}
|
||||
}
|
||||
|
||||
void start_translation(struct transcription_filter_data *gf)
|
||||
{
|
||||
obs_log(LOG_INFO, "Starting translation...");
|
||||
|
||||
const ModelInfo &translation_model_info = models_info["M2M-100 418M (495Mb)"];
|
||||
std::string model_file_found = find_model_folder(translation_model_info);
|
||||
if (model_file_found == "") {
|
||||
obs_log(LOG_INFO, "Translation CT2 model does not exist. Downloading...");
|
||||
download_model_with_ui_dialog(
|
||||
translation_model_info,
|
||||
[gf, model_file_found](int download_status, const std::string &path) {
|
||||
if (download_status == 0) {
|
||||
obs_log(LOG_INFO, "CT2 model download complete");
|
||||
build_and_enable_translation(gf, path);
|
||||
} else {
|
||||
obs_log(LOG_ERROR, "Model download failed");
|
||||
gf->translate = false;
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Model exists, just load it
|
||||
build_and_enable_translation(gf, model_file_found);
|
||||
}
|
||||
}
|
||||
|
||||
int build_translation_context(struct translation_context &translation_ctx)
|
||||
{
|
||||
std::string local_model_path = translation_ctx.local_model_folder_path;
|
||||
obs_log(LOG_INFO, "Building translation context from '%s'...", local_model_path.c_str());
|
||||
// find the SPM file in the model folder
|
||||
std::string local_spm_path =
|
||||
find_file_in_folder_by_name(local_model_path, "sentencepiece.bpe.model");
|
||||
std::string local_spm_path = find_file_in_folder_by_regex_expression(
|
||||
local_model_path, "(sentencepiece|spm).*?\\.model");
|
||||
|
||||
try {
|
||||
obs_log(LOG_INFO, "Loading SPM from %s", local_spm_path.c_str());
|
||||
|
@ -19,8 +19,9 @@ struct translation_context {
|
||||
bool add_context;
|
||||
};
|
||||
|
||||
void start_translation(struct transcription_filter_data *gf);
|
||||
int build_translation_context(struct translation_context &translation_ctx);
|
||||
void build_and_enable_translation(struct transcription_filter_data *gf,
|
||||
const std::string &model_file_path);
|
||||
|
||||
int translate(struct translation_context &translation_ctx, const std::string &text,
|
||||
const std::string &source_lang, const std::string &target_lang, std::string &result);
|
||||
|
@ -90,12 +90,13 @@ void VadIterator::init_onnx_model(const SileroString &model_path)
|
||||
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
|
||||
};
|
||||
|
||||
void VadIterator::reset_states()
|
||||
void VadIterator::reset_states(bool reset_hc)
|
||||
{
|
||||
// Call reset before each audio start
|
||||
std::memset(_h.data(), 0.0f, _h.size() * sizeof(float));
|
||||
std::memset(_c.data(), 0.0f, _c.size() * sizeof(float));
|
||||
triggered = false;
|
||||
if (reset_hc) {
|
||||
std::memset(_h.data(), 0.0f, _h.size() * sizeof(float));
|
||||
std::memset(_c.data(), 0.0f, _c.size() * sizeof(float));
|
||||
triggered = false;
|
||||
}
|
||||
temp_end = 0;
|
||||
current_sample = 0;
|
||||
|
||||
@ -105,7 +106,7 @@ void VadIterator::reset_states()
|
||||
current_speech = timestamp_t();
|
||||
};
|
||||
|
||||
void VadIterator::predict(const std::vector<float> &data)
|
||||
float VadIterator::predict_one(const std::vector<float> &data)
|
||||
{
|
||||
// Infer
|
||||
// Create ort tensors
|
||||
@ -138,6 +139,13 @@ void VadIterator::predict(const std::vector<float> &data)
|
||||
float *cn = ort_outputs[2].GetTensorMutableData<float>();
|
||||
std::memcpy(_c.data(), cn, size_hc * sizeof(float));
|
||||
|
||||
return speech_prob;
|
||||
}
|
||||
|
||||
void VadIterator::predict(const std::vector<float> &data)
|
||||
{
|
||||
const float speech_prob = predict_one(data);
|
||||
|
||||
// Push forward sample index
|
||||
current_sample += (unsigned int)window_size_samples;
|
||||
|
||||
@ -254,9 +262,9 @@ void VadIterator::predict(const std::vector<float> &data)
|
||||
}
|
||||
};
|
||||
|
||||
void VadIterator::process(const std::vector<float> &input_wav)
|
||||
void VadIterator::process(const std::vector<float> &input_wav, bool reset_hc)
|
||||
{
|
||||
reset_states();
|
||||
reset_states(reset_hc);
|
||||
|
||||
audio_length_samples = (int)input_wav.size();
|
||||
|
||||
@ -280,7 +288,7 @@ void VadIterator::process(const std::vector<float> &input_wav)
|
||||
|
||||
void VadIterator::process(const std::vector<float> &input_wav, std::vector<float> &output_wav)
|
||||
{
|
||||
process(input_wav);
|
||||
process(input_wav, true);
|
||||
collect_chunks(input_wav, output_wav);
|
||||
}
|
||||
|
||||
|
@ -43,11 +43,12 @@ private:
|
||||
private:
|
||||
void init_engine_threads(int inter_threads, int intra_threads);
|
||||
void init_onnx_model(const SileroString &model_path);
|
||||
void reset_states();
|
||||
void reset_states(bool reset_hc);
|
||||
float predict_one(const std::vector<float> &data);
|
||||
void predict(const std::vector<float> &data);
|
||||
|
||||
public:
|
||||
void process(const std::vector<float> &input_wav);
|
||||
void process(const std::vector<float> &input_wav, bool reset_hc = true);
|
||||
void process(const std::vector<float> &input_wav, std::vector<float> &output_wav);
|
||||
void collect_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
|
||||
const std::vector<timestamp_t> get_speech_timestamps() const;
|
||||
|
106
src/whisper-utils/whisper-model-utils.cpp
Normal file
106
src/whisper-utils/whisper-model-utils.cpp
Normal file
@ -0,0 +1,106 @@
|
||||
#include "whisper-utils.h"
|
||||
#include "plugin-support.h"
|
||||
#include "model-utils/model-downloader.h"
|
||||
#include "whisper-processing.h"
|
||||
|
||||
#include <obs-module.h>
|
||||
|
||||
void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
|
||||
{
|
||||
// update the whisper model path
|
||||
std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
|
||||
const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos;
|
||||
|
||||
char *silero_vad_model_file = obs_module_file("models/silero-vad/silero_vad.onnx");
|
||||
if (silero_vad_model_file == nullptr) {
|
||||
obs_log(LOG_ERROR, "Cannot find Silero VAD model file");
|
||||
return;
|
||||
}
|
||||
|
||||
if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path ||
|
||||
is_external_model) {
|
||||
|
||||
if (gf->whisper_model_path != new_model_path) {
|
||||
// model path changed
|
||||
obs_log(gf->log_level, "model path changed from %s to %s",
|
||||
gf->whisper_model_path.c_str(), new_model_path.c_str());
|
||||
}
|
||||
|
||||
// check if the new model is external file
|
||||
if (!is_external_model) {
|
||||
// new model is not external file
|
||||
shutdown_whisper_thread(gf);
|
||||
|
||||
if (models_info.count(new_model_path) == 0) {
|
||||
obs_log(LOG_WARNING, "Model '%s' does not exist",
|
||||
new_model_path.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
const ModelInfo &model_info = models_info[new_model_path];
|
||||
|
||||
// check if the model exists, if not, download it
|
||||
std::string model_file_found = find_model_bin_file(model_info);
|
||||
if (model_file_found == "") {
|
||||
obs_log(LOG_WARNING, "Whisper model does not exist");
|
||||
download_model_with_ui_dialog(
|
||||
model_info,
|
||||
[gf, new_model_path, silero_vad_model_file](
|
||||
int download_status, const std::string &path) {
|
||||
if (download_status == 0) {
|
||||
obs_log(LOG_INFO,
|
||||
"Model download complete");
|
||||
gf->whisper_model_path = new_model_path;
|
||||
start_whisper_thread_with_path(
|
||||
gf, path, silero_vad_model_file);
|
||||
} else {
|
||||
obs_log(LOG_ERROR, "Model download failed");
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Model exists, just load it
|
||||
gf->whisper_model_path = new_model_path;
|
||||
start_whisper_thread_with_path(gf, model_file_found,
|
||||
silero_vad_model_file);
|
||||
}
|
||||
} else {
|
||||
// new model is external file, get file location from file property
|
||||
std::string external_model_file_path =
|
||||
obs_data_get_string(s, "whisper_model_path_external");
|
||||
if (external_model_file_path.empty()) {
|
||||
obs_log(LOG_WARNING, "External model file path is empty");
|
||||
} else {
|
||||
// check if the external model file is not currently loaded
|
||||
if (gf->whisper_model_file_currently_loaded ==
|
||||
external_model_file_path) {
|
||||
obs_log(LOG_INFO, "External model file is already loaded");
|
||||
return;
|
||||
} else {
|
||||
shutdown_whisper_thread(gf);
|
||||
gf->whisper_model_path = new_model_path;
|
||||
start_whisper_thread_with_path(gf, external_model_file_path,
|
||||
silero_vad_model_file);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// model path did not change
|
||||
obs_log(gf->log_level, "Model path did not change: %s == %s",
|
||||
gf->whisper_model_path.c_str(), new_model_path.c_str());
|
||||
}
|
||||
|
||||
const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps");
|
||||
|
||||
if (new_dtw_timestamps != gf->enable_token_ts_dtw) {
|
||||
// dtw_token_timestamps changed
|
||||
obs_log(gf->log_level, "dtw_token_timestamps changed from %d to %d",
|
||||
gf->enable_token_ts_dtw, new_dtw_timestamps);
|
||||
gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps");
|
||||
shutdown_whisper_thread(gf);
|
||||
start_whisper_thread_with_path(gf, gf->whisper_model_path, silero_vad_model_file);
|
||||
} else {
|
||||
// dtw_token_timestamps did not change
|
||||
obs_log(gf->log_level, "dtw_token_timestamps did not change: %d == %d",
|
||||
gf->enable_token_ts_dtw, new_dtw_timestamps);
|
||||
}
|
||||
}
|
10
src/whisper-utils/whisper-model-utils.h
Normal file
10
src/whisper-utils/whisper-model-utils.h
Normal file
@ -0,0 +1,10 @@
|
||||
#ifndef WHISPER_MODEL_UTILS_H
|
||||
#define WHISPER_MODEL_UTILS_H
|
||||
|
||||
#include <obs.h>
|
||||
|
||||
#include "transcription-filter-data.h"
|
||||
|
||||
void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s);
|
||||
|
||||
#endif // WHISPER_MODEL_UTILS_H
|
@ -11,12 +11,14 @@
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <cfloat>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <fstream>
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
#include "model-utils/model-downloader.h"
|
||||
#include "model-utils/model-find-utils.h"
|
||||
|
||||
#define VAD_THOLD 0.0001f
|
||||
#define FREQ_THOLD 100.0f
|
||||
@ -229,7 +231,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
|
||||
int(pcm32f_size), float(pcm32f_size) / WHISPER_SAMPLE_RATE,
|
||||
gf->whisper_params.n_threads);
|
||||
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
|
||||
if (gf->whisper_context == nullptr) {
|
||||
obs_log(LOG_WARNING, "whisper context is null");
|
||||
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
|
||||
@ -339,7 +341,7 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
|
||||
|
||||
{
|
||||
// scoped lock the buffer mutex
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_buf_mutex);
|
||||
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);
|
||||
|
||||
// We need (gf->frames - gf->last_num_frames) new frames for a full segment,
|
||||
const size_t remaining_frames_to_full_segment = gf->frames - gf->last_num_frames;
|
||||
@ -419,7 +421,7 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
|
||||
if (gf->vad_enabled) {
|
||||
std::vector<float> vad_input(resampled_16khz[0],
|
||||
resampled_16khz[0] + resampled_16khz_frames);
|
||||
gf->vad->process(vad_input);
|
||||
gf->vad->process(vad_input, false);
|
||||
|
||||
std::vector<timestamp_t> stamps = gf->vad->get_speech_timestamps();
|
||||
if (stamps.size() == 0) {
|
||||
@ -526,25 +528,12 @@ void whisper_loop(void *data)
|
||||
struct transcription_filter_data *gf =
|
||||
static_cast<struct transcription_filter_data *>(data);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
if (gf->whisper_context == nullptr) {
|
||||
obs_log(LOG_WARNING,
|
||||
"Whisper context is null. Whisper thread cannot start");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
obs_log(LOG_INFO, "starting whisper thread");
|
||||
|
||||
// Thread main loop
|
||||
while (true) {
|
||||
{
|
||||
if (gf->whisper_ctx_mutex == nullptr) {
|
||||
obs_log(LOG_WARNING, "whisper_ctx_mutex is null, exiting thread");
|
||||
break;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
|
||||
if (gf->whisper_context == nullptr) {
|
||||
obs_log(LOG_WARNING, "Whisper context is null, exiting thread");
|
||||
break;
|
||||
@ -555,7 +544,7 @@ void whisper_loop(void *data)
|
||||
while (true) {
|
||||
size_t input_buf_size = 0;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_buf_mutex);
|
||||
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);
|
||||
input_buf_size = gf->input_buffers[0].size;
|
||||
}
|
||||
const size_t step_size_frames = gf->step_size_msec * gf->sample_rate / 1000;
|
||||
@ -563,10 +552,9 @@ void whisper_loop(void *data)
|
||||
|
||||
if (input_buf_size >= segment_size) {
|
||||
obs_log(gf->log_level,
|
||||
"found %lu bytes, %lu frames in input buffer, need >= %lu, processing",
|
||||
"found %lu bytes, %lu frames in input buffer, need >= %lu",
|
||||
input_buf_size, (size_t)(input_buf_size / sizeof(float)),
|
||||
segment_size);
|
||||
|
||||
// Process the audio. This will also remove the processed data from the input buffer.
|
||||
// Mutex is locked inside process_audio_from_buffer.
|
||||
process_audio_from_buffer(gf);
|
||||
@ -577,8 +565,8 @@ void whisper_loop(void *data)
|
||||
// Sleep for 10 ms using the condition variable wshiper_thread_cv
|
||||
// This will wake up the thread if there is new data in the input buffer
|
||||
// or if the whisper context is null
|
||||
std::unique_lock<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
gf->wshiper_thread_cv->wait_for(lock, std::chrono::milliseconds(10));
|
||||
std::unique_lock<std::mutex> lock(gf->whisper_ctx_mutex);
|
||||
gf->wshiper_thread_cv.wait_for(lock, std::chrono::milliseconds(10));
|
||||
}
|
||||
|
||||
obs_log(LOG_INFO, "exiting whisper thread");
|
||||
|
@ -6,9 +6,9 @@
|
||||
// buffer size in msec
|
||||
#define DEFAULT_BUFFER_SIZE_MSEC 3000
|
||||
// overlap in msec
|
||||
#define DEFAULT_OVERLAP_SIZE_MSEC 150
|
||||
#define DEFAULT_OVERLAP_SIZE_MSEC 125
|
||||
#define MAX_OVERLAP_SIZE_MSEC 1000
|
||||
#define MIN_OVERLAP_SIZE_MSEC 150
|
||||
#define MIN_OVERLAP_SIZE_MSEC 125
|
||||
|
||||
enum DetectionResult {
|
||||
DETECTION_RESULT_UNKNOWN = 0,
|
||||
|
@ -5,110 +5,15 @@
|
||||
|
||||
#include <obs-module.h>
|
||||
|
||||
void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
|
||||
{
|
||||
// update the whisper model path
|
||||
std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
|
||||
const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos;
|
||||
|
||||
if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path ||
|
||||
is_external_model) {
|
||||
|
||||
if (gf->whisper_model_path != new_model_path) {
|
||||
// model path changed
|
||||
obs_log(gf->log_level, "model path changed from %s to %s",
|
||||
gf->whisper_model_path.c_str(), new_model_path.c_str());
|
||||
}
|
||||
|
||||
// check if the new model is external file
|
||||
if (!is_external_model) {
|
||||
// new model is not external file
|
||||
shutdown_whisper_thread(gf);
|
||||
|
||||
if (models_info.count(new_model_path) == 0) {
|
||||
obs_log(LOG_WARNING, "Model '%s' does not exist",
|
||||
new_model_path.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
const ModelInfo &model_info = models_info[new_model_path];
|
||||
|
||||
// check if the model exists, if not, download it
|
||||
std::string model_file_found = find_model_bin_file(model_info);
|
||||
if (model_file_found == "") {
|
||||
obs_log(LOG_WARNING, "Whisper model does not exist");
|
||||
download_model_with_ui_dialog(
|
||||
model_info, [gf, new_model_path](int download_status,
|
||||
const std::string &path) {
|
||||
if (download_status == 0) {
|
||||
obs_log(LOG_INFO,
|
||||
"Model download complete");
|
||||
gf->whisper_model_path = new_model_path;
|
||||
start_whisper_thread_with_path(gf, path);
|
||||
} else {
|
||||
obs_log(LOG_ERROR, "Model download failed");
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// Model exists, just load it
|
||||
gf->whisper_model_path = new_model_path;
|
||||
start_whisper_thread_with_path(gf, model_file_found);
|
||||
}
|
||||
} else {
|
||||
// new model is external file, get file location from file property
|
||||
std::string external_model_file_path =
|
||||
obs_data_get_string(s, "whisper_model_path_external");
|
||||
if (external_model_file_path.empty()) {
|
||||
obs_log(LOG_WARNING, "External model file path is empty");
|
||||
} else {
|
||||
// check if the external model file is not currently loaded
|
||||
if (gf->whisper_model_file_currently_loaded ==
|
||||
external_model_file_path) {
|
||||
obs_log(LOG_INFO, "External model file is already loaded");
|
||||
return;
|
||||
} else {
|
||||
shutdown_whisper_thread(gf);
|
||||
gf->whisper_model_path = new_model_path;
|
||||
start_whisper_thread_with_path(gf,
|
||||
external_model_file_path);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// model path did not change
|
||||
obs_log(gf->log_level, "Model path did not change: %s == %s",
|
||||
gf->whisper_model_path.c_str(), new_model_path.c_str());
|
||||
}
|
||||
|
||||
const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps");
|
||||
|
||||
if (new_dtw_timestamps != gf->enable_token_ts_dtw) {
|
||||
// dtw_token_timestamps changed
|
||||
obs_log(gf->log_level, "dtw_token_timestamps changed from %d to %d",
|
||||
gf->enable_token_ts_dtw, new_dtw_timestamps);
|
||||
gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps");
|
||||
shutdown_whisper_thread(gf);
|
||||
start_whisper_thread_with_path(gf, gf->whisper_model_path);
|
||||
} else {
|
||||
// dtw_token_timestamps did not change
|
||||
obs_log(gf->log_level, "dtw_token_timestamps did not change: %d == %d",
|
||||
gf->enable_token_ts_dtw, new_dtw_timestamps);
|
||||
}
|
||||
}
|
||||
|
||||
void shutdown_whisper_thread(struct transcription_filter_data *gf)
|
||||
{
|
||||
obs_log(gf->log_level, "shutdown_whisper_thread");
|
||||
if (gf->whisper_context != nullptr) {
|
||||
// acquire the mutex before freeing the context
|
||||
if (!gf->whisper_ctx_mutex || !gf->wshiper_thread_cv) {
|
||||
obs_log(LOG_ERROR, "whisper_ctx_mutex is null");
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
|
||||
whisper_free(gf->whisper_context);
|
||||
gf->whisper_context = nullptr;
|
||||
gf->wshiper_thread_cv->notify_all();
|
||||
gf->wshiper_thread_cv.notify_all();
|
||||
}
|
||||
if (gf->whisper_thread.joinable()) {
|
||||
gf->whisper_thread.join();
|
||||
@ -118,21 +23,18 @@ void shutdown_whisper_thread(struct transcription_filter_data *gf)
|
||||
}
|
||||
}
|
||||
|
||||
void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path)
|
||||
void start_whisper_thread_with_path(struct transcription_filter_data *gf,
|
||||
const std::string &whisper_model_path,
|
||||
const char *silero_vad_model_file)
|
||||
{
|
||||
obs_log(gf->log_level, "start_whisper_thread_with_path: %s", path.c_str());
|
||||
if (gf->whisper_ctx_mutex == nullptr) {
|
||||
obs_log(LOG_ERROR, "cannot init whisper: whisper_ctx_mutex is null");
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
|
||||
obs_log(gf->log_level, "start_whisper_thread_with_path: %s", whisper_model_path.c_str());
|
||||
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
|
||||
if (gf->whisper_context != nullptr) {
|
||||
obs_log(LOG_ERROR, "cannot init whisper: whisper_context is not null");
|
||||
return;
|
||||
}
|
||||
|
||||
// initialize Silero VAD
|
||||
char *silero_vad_model_file = obs_module_file("models/silero-vad/silero_vad.onnx");
|
||||
#ifdef _WIN32
|
||||
std::wstring silero_vad_model_path;
|
||||
silero_vad_model_path.assign(silero_vad_model_file,
|
||||
@ -140,18 +42,17 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, const
|
||||
#else
|
||||
std::string silero_vad_model_path = silero_vad_model_file;
|
||||
#endif
|
||||
bfree(silero_vad_model_file);
|
||||
// roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
|
||||
// for silero vad parameters
|
||||
gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE, 64, 0.5f, 1000,
|
||||
gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE, 64, 0.5f, 500,
|
||||
200, 250));
|
||||
|
||||
gf->whisper_context = init_whisper_context(path, gf);
|
||||
gf->whisper_context = init_whisper_context(whisper_model_path, gf);
|
||||
if (gf->whisper_context == nullptr) {
|
||||
obs_log(LOG_ERROR, "Failed to initialize whisper context");
|
||||
return;
|
||||
}
|
||||
gf->whisper_model_file_currently_loaded = path;
|
||||
gf->whisper_model_file_currently_loaded = whisper_model_path;
|
||||
std::thread new_whisper_thread(whisper_loop, gf);
|
||||
gf->whisper_thread.swap(new_whisper_thread);
|
||||
}
|
||||
|
@ -3,13 +3,11 @@
|
||||
|
||||
#include "transcription-filter-data.h"
|
||||
|
||||
#include <obs.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s);
|
||||
void shutdown_whisper_thread(struct transcription_filter_data *gf);
|
||||
void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path);
|
||||
void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path,
|
||||
const char *silero_vad_model_file);
|
||||
|
||||
std::pair<int, int> findStartOfOverlap(const std::vector<whisper_token_data> &seq1,
|
||||
const std::vector<whisper_token_data> &seq2);
|
||||
|
Loading…
Reference in New Issue
Block a user