Offline transcription accuracy tests (#96)

* Update translation-utils.h, transcription-filter.h, whisper-model-utils.h, model-find-utils.h, and model-downloader.h

* Update create_context function to include ct2ModelFolder parameter

* fix: add fix_utf8 flag to transcription_filter_data struct

* Update create_context function to include ct2ModelFolder parameter

* Update read_text_from_file function to include join_sentences parameter

* fix: Update VadIterator::reset_states to include reset_hc parameter

* Update create_context function to include whisper_sampling_method parameter

* Update tests README with additional configuration options

* feat: Add function to find file in folder by regex expression

* refactor: Improve text conditioning logic in transcription-filter.cpp

* refactor: Improve text conditioning logic in transcription-filter.cpp

* chore: Update ctranslate2 dependency to version 1.2.0

* refactor: Improve text conditioning logic in transcription-filter.cpp

* chore: Update cmake BuildCTranslate2.cmake to disable -Wno-comma warning

* refactor: Update translation context in whisper-processing.cpp and translation-utils.cpp
Roy Shilkrot, 2024-05-10 17:37:09 -04:00, committed by GitHub
parent 2e83300fbb
commit 31c41a9574
31 changed files with 1411 additions and 306 deletions

CMakeLists.txt

@@ -6,6 +6,7 @@ project(${_name} VERSION ${_version})
option(ENABLE_FRONTEND_API "Use obs-frontend-api for UI functionality" ON)
option(ENABLE_QT "Use Qt functionality" ON)
option(ENABLE_TESTS "Enable tests" OFF)
include(compilerconfig)
include(defaults)
@@ -90,11 +91,41 @@ target_sources(
src/model-utils/model-downloader.cpp
src/model-utils/model-downloader-ui.cpp
src/model-utils/model-infos.cpp
src/model-utils/model-find-utils.cpp
src/whisper-utils/whisper-processing.cpp
src/whisper-utils/whisper-utils.cpp
src/whisper-utils/whisper-model-utils.cpp
src/whisper-utils/silero-vad-onnx.cpp
src/whisper-utils/token-buffer-thread.cpp
src/translation/translation.cpp
src/translation/translation-utils.cpp
src/utils.cpp)
set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name})
if(ENABLE_TESTS)
add_executable(${CMAKE_PROJECT_NAME}-tests)
include(cmake/FindLibAvObs.cmake)
target_sources(
${CMAKE_PROJECT_NAME}-tests
PRIVATE src/tests/localvocal-offline-test.cpp
src/transcription-utils.cpp
src/model-utils/model-infos.cpp
src/model-utils/model-find-utils.cpp
src/whisper-utils/whisper-processing.cpp
src/whisper-utils/whisper-utils.cpp
src/whisper-utils/silero-vad-onnx.cpp
src/whisper-utils/token-buffer-thread.cpp
src/translation/translation.cpp
src/utils.cpp)
find_libav(${CMAKE_PROJECT_NAME}-tests)
target_link_libraries(${CMAKE_PROJECT_NAME}-tests PRIVATE ct2 sentencepiece Whispercpp Ort OBS::libobs)
target_include_directories(${CMAKE_PROJECT_NAME}-tests PRIVATE src)
# install the tests to the release/test directory
install(TARGETS ${CMAKE_PROJECT_NAME}-tests DESTINATION test)
endif()

cmake/BuildCTranslate2.cmake

@@ -7,15 +7,15 @@ if(APPLE)
FetchContent_Declare(
ctranslate2_fetch
URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/libctranslate2-macos-Release-1.1.1.tar.gz
URL_HASH SHA256=da04d88ecc1ea105f8ee672e4eab33af96e50c999c5cc8170e105e110392182b)
URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-macos-Release-1.2.0.tar.gz
URL_HASH SHA256=9029F19B0F50E5EDC14473479EDF0A983F7D6FA00BE61DC1B01BF8AA7F1CDB1B)
FetchContent_MakeAvailable(ctranslate2_fetch)
add_library(ct2 INTERFACE)
target_link_libraries(ct2 INTERFACE "-framework Accelerate" ${ctranslate2_fetch_SOURCE_DIR}/lib/libctranslate2.a
${ctranslate2_fetch_SOURCE_DIR}/lib/libcpu_features.a)
set_target_properties(ct2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${ctranslate2_fetch_SOURCE_DIR}/include)
target_compile_options(ct2 INTERFACE -Wno-shorten-64-to-32)
target_compile_options(ct2 INTERFACE -Wno-shorten-64-to-32 -Wno-comma)
elseif(WIN32)

cmake/FindLibAvObs.cmake (new file, 48 lines)

@@ -0,0 +1,48 @@
# Find LibAV from the OBS dependencies
function(find_libav TARGET)
if(UNIX AND NOT APPLE)
find_package(PkgConfig REQUIRED)
pkg_check_modules(
FFMPEG
REQUIRED
IMPORTED_TARGET
libavformat
libavcodec
libavutil
libswresample)
if(FFMPEG_FOUND)
target_link_libraries(${TARGET} PRIVATE PkgConfig::FFMPEG)
else()
message(FATAL_ERROR "FFMPEG not found!")
endif()
return()
endif()
if(NOT buildspec)
file(READ "${CMAKE_CURRENT_SOURCE_DIR}/buildspec.json" buildspec)
endif()
string(
JSON
version
GET
${buildspec}
dependencies
prebuilt
version)
if(MSVC)
set(arch ${CMAKE_GENERATOR_PLATFORM})
elseif(APPLE)
set(arch universal)
endif()
set(deps_root "${CMAKE_CURRENT_SOURCE_DIR}/.deps/obs-deps-${version}-${arch}")
target_include_directories(${TARGET} PRIVATE "${deps_root}/include")
target_link_libraries(
${TARGET}
PRIVATE "${deps_root}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}avcodec${CMAKE_STATIC_LIBRARY_SUFFIX}"
"${deps_root}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}avformat${CMAKE_STATIC_LIBRARY_SUFFIX}"
"${deps_root}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}avutil${CMAKE_STATIC_LIBRARY_SUFFIX}"
"${deps_root}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}swresample${CMAKE_STATIC_LIBRARY_SUFFIX}")
endfunction(find_libav)

data/locale/en-US.ini

@@ -55,3 +55,5 @@ suppress_sentences="Suppress sentences (each line)"
translate_output="Translation output"
dtw_token_timestamps="DTW token timestamps"
buffered_output="Buffered output (Experimental)"
translate_model="Translation Model"
Whisper-Based-Translation="Whisper-Based Translation"

src/model-utils/model-downloader.cpp

@@ -1,44 +1,11 @@
#include "model-downloader.h"
#include "plugin-support.h"
#include "model-downloader-ui.h"
#include "model-find-utils.h"
#include <obs-module.h>
#include <obs-frontend-api.h>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <curl/curl.h>
std::string find_file_in_folder_by_name(const std::string &folder_path,
const std::string &file_name)
{
for (const auto &entry : std::filesystem::directory_iterator(folder_path)) {
if (entry.path().filename() == file_name) {
return entry.path().string();
}
}
return "";
}
std::string find_bin_file_in_folder(const std::string &model_local_folder_path)
{
// find .bin file in folder
for (const auto &entry : std::filesystem::directory_iterator(model_local_folder_path)) {
if (entry.path().extension() == ".bin") {
const std::string bin_file_path = entry.path().string();
obs_log(LOG_INFO, "Model bin file found in folder: %s",
bin_file_path.c_str());
return bin_file_path;
}
}
obs_log(LOG_ERROR, "Model bin file not found in folder: %s",
model_local_folder_path.c_str());
return "";
}
std::string find_model_folder(const ModelInfo &model_info)
{
char *data_folder_models = obs_module_file("models");

src/model-utils/model-downloader.h

@@ -2,13 +2,9 @@
#define MODEL_DOWNLOADER_H
#include <string>
#include <functional>
#include "model-downloader-types.h"
std::string find_file_in_folder_by_name(const std::string &folder_path,
const std::string &file_name);
std::string find_bin_file_in_folder(const std::string &path);
std::string find_model_folder(const ModelInfo &model_info);
std::string find_model_bin_file(const ModelInfo &model_info);

src/model-utils/model-find-utils.cpp (new file, 50 lines)

@@ -0,0 +1,50 @@
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <regex>
#include <obs-module.h>
#include "model-find-utils.h"
#include "plugin-support.h"
std::string find_file_in_folder_by_name(const std::string &folder_path,
const std::string &file_name)
{
for (const auto &entry : std::filesystem::directory_iterator(folder_path)) {
if (entry.path().filename() == file_name) {
return entry.path().string();
}
}
return "";
}
// Find a file in a folder by regex expression
std::string find_file_in_folder_by_regex_expression(const std::string &folder_path,
const std::string &file_name_regex)
{
for (const auto &entry : std::filesystem::directory_iterator(folder_path)) {
if (std::regex_match(entry.path().filename().string(),
std::regex(file_name_regex))) {
return entry.path().string();
}
}
return "";
}
std::string find_bin_file_in_folder(const std::string &model_local_folder_path)
{
// find .bin file in folder
for (const auto &entry : std::filesystem::directory_iterator(model_local_folder_path)) {
if (entry.path().extension() == ".bin") {
const std::string bin_file_path = entry.path().string();
obs_log(LOG_INFO, "Model bin file found in folder: %s",
bin_file_path.c_str());
return bin_file_path;
}
}
obs_log(LOG_ERROR, "Model bin file not found in folder: %s",
model_local_folder_path.c_str());
return "";
}
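As an illustration (not part of this diff), here is a minimal sketch of calling the new regex finder. The folder path is hypothetical; the regex is the pattern the translation code in this PR uses to locate the SentencePiece model file:

```cpp
#include <iostream>
#include <string>

#include "model-find-utils.h"

int main()
{
	// Hypothetical CT2 model folder; the pattern matches names like
	// "sentencepiece.bpe.model" or "spm.model" (std::regex_match requires
	// the whole filename to match).
	const std::string spm_path = find_file_in_folder_by_regex_expression(
		"models/m2m-100-418M", "(sentencepiece|spm).*?\\.model");
	if (spm_path.empty()) {
		std::cout << "No SPM model file found" << std::endl;
		return 1;
	}
	std::cout << "Found SPM model: " << spm_path << std::endl;
	return 0;
}
```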

src/model-utils/model-find-utils.h (new file, 14 lines)

@@ -0,0 +1,14 @@
#ifndef MODEL_FIND_UTILS_H
#define MODEL_FIND_UTILS_H
#include <string>
#include "model-downloader-types.h"
std::string find_file_in_folder_by_name(const std::string &folder_path,
const std::string &file_name);
std::string find_bin_file_in_folder(const std::string &path);
std::string find_file_in_folder_by_regex_expression(const std::string &folder_path,
const std::string &file_name_regex);
#endif // MODEL_FIND_UTILS_H

src/model-utils/model-infos.cpp

@@ -23,6 +23,28 @@ std::map<std::string, ModelInfo> models_info = {{
"B6E77E474AEEA8F441363ACA7614317C06381F3EACFE10FB9856D5081D1074CC"},
{"https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true",
"D8F7C76ED2A5E0822BE39F0A4F95A55EB19C78F4593CE609E2EDBC2AEA4D380A"}}}},
{"M2M-100 1.2B (1.25Gb)",
{"M2M-100 1.2BM",
"m2m-100-1_2B",
MODEL_TYPE_TRANSLATION,
{{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/model.bin?download=true",
"C97DF052A558895317312470E1FF7CB8EAE5416F7AE16214A2983C6853DD3CE5"},
{
"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/config.json?download=true",
"4244772990E30069563E3DDFB4AD6DC95BDFD2AC3DE667EA8858C9B0A8433FA8",
},
{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/generation_config.json?download=true",
"AED76366507333DDBB8BD49960F23C82FE6446B3319A46A54BEFDB45324CCF61"},
{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/shared_vocabulary.json?download=true",
"7EB5D0FF184C6095C7C10F9911C0AEA492250ABD12854F9C3D787C64B1C6397E"},
{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/special_tokens_map.json?download=true",
"C1A4F86C3874D279AE1B2A05162858DB5DD6C61665D84223ED886CBCFF08FDA6"},
{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/tokenizer_config.json?download=true",
"1566A6CFA4F541A55594C9D5E090F530812D5DE7C94882EA3AF156962D9933AE"},
{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/vocab.json?download=true",
"B6E77E474AEEA8F441363ACA7614317C06381F3EACFE10FB9856D5081D1074CC"},
{"https://huggingface.co/jncraton/m2m100_1.2B-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true",
"D8F7C76ED2A5E0822BE39F0A4F95A55EB19C78F4593CE609E2EDBC2AEA4D380A"}}}},
{"Whisper Base q5 (57Mb)",
{"Whisper Base q5",
"whisper-base-q5",

src/tests/README.md (new file, 182 lines)

@@ -0,0 +1,182 @@
# Building and Using the Offline Testing Tool
The offline testing tool runs the plugin's internal core transcription+translation algorithm without running OBS, in effect simulating how it would run within OBS. However, all the audio is pre-cached in memory, so the tool runs faster than real time (i.e. it does not simulate the timing of live audio input).
The tool is useful for automating tests that measure performance.
## Building
The tool was tested on Windows with CUDA, so this guide focuses on that setup.
However, nothing prevents the tool from building and running successfully on macOS as well.
Linux is unfortunately not supported at the moment.
Start by cloning the repo.
Then build the plugin as usual, e.g.
```powershell
obs-localvocal> $env:CPU_OR_CUDA="12.2.0"
obs-localvocal> .\.github\scripts\Build-Windows.ps1 -Configuration Release
```
Then run CMake to enable the test tool in the build:
```powershell
obs-localvocal> cmake -S . -B .\build_x64\ -DENABLE_TESTS=ON
```
Once that is done, you can build just the test target `.exe` (instead of building everything), e.g.
```powershell
obs-localvocal> cmake --build .\build_x64\ --target obs-localvocal-tests --config Release
obs-localvocal> copy-item -Force ".\build_x64\Release\obs-localvocal-tests.exe" -Destination ".\release\Release\test"
```
The second command above copies the resulting `.exe` to the `./release/Release/test` folder.
Next, a few `.dll` files need to be collected and placed alongside the `obs-localvocal-tests.exe` file in the `./release/Release/test` folder. Fortunately, all the `.dll`s are available in the plugin's build folders.
To copy the `.dll`s automatically, run the following script (from any location; it will orient itself):
```powershell
obs-localvocal> &"src\tests\copy_dlls.ps1"
```
For manual copying, follow the lists below.
From `.\release\Release\obs-plugins\64bit` copy:
- ctranslate2.dll
- cublas64_12.dll
- cublasLt64_12.dll
- cudart64_12.dll
- libopenblas.dll
- obs-localvocal.dll
- onnxruntime_providers_shared.dll
- onnxruntime.dll
- whisper.dll
From `.deps\obs-deps-2023-11-03-x64\bin` copy:
- avcodec-60.dll
- avdevice-60.dll
- avfilter-9.dll
- avformat-60.dll
- avutil-58.dll
- libx264-164.dll
- swresample-4.dll
- swscale-7.dll
- zlib.dll
Finally, from `.deps\obs-studio-30.0.2\build_x64\rundir\Debug\bin\64bit` copy:
- obs-frontend-api.dll
- obs.dll
- w32-pthreads.dll
With all the `.dll`s in place in the `.\release\Release\test` folder, the test tool should run.
## Using the test tool
The tool expects the following arguments:
- audio/video file
- configuration file in JSON
For example, this is a valid run command:
```powershell
obs-localvocal> .\release\Release\test\obs-localvocal-tests.exe "C:\Users\roysh\Downloads\audio.mp3" ".\config.json"
```
### Configuration
The tool requires a configuration file to control the parameters of the algorithm:
- whisper language
- translation source language (or `none`)
- translation target language (or `none`)
- whisper model `.bin` file
- silero VAD model file e.g. `silero_vad.onnx`
- CT2 model *folder* (within which the model and JSON files can be found)
- fix UTF8 characters
- suppress sentences
- overlap in milliseconds
- log level (debug, info, warning, error)
- whisper sampling strategy (0 = greedy, 1 = beam)
The Whisper languages are listed in [whisper-language.h](../whisper-utils/whisper-language.h) and the CT2 language codes are listed in [language_codes.h](../translation/language_codes.h). They roughly match, except that CT2 wraps the codes in underscores, e.g. `ko` -> `__ko__`, `ja` -> `__ja__`.
For example, the JSON config file can look like this:
```json
{
"whisper_language": "ko",
"source_language": "none",
"target_language": "none",
"whisper_model_path": ".../obs-localvocal/models/ggml-model-whisper-small/ggml-model-whisper-small.bin",
"silero_vad_model_file": ".../obs-localvocal/data/models/silero-vad/silero_vad.onnx",
"ct2_model_folder": ".../obs-localvocal/models/m2m-100-418M",
"fix_utf8": true,
"suppress_sentences": "끝까지 시청해주셔서 감사합니다/n구독과 좋아요 부탁드립니다!/nMBC 뉴스 안영백입니다./nMBC 뉴스 이덕영입니다/n구독과 좋아요 눌러주세요!/n구독과 좋아요 부탁드",
"overlap_ms": 150,
"log_level": "debug",
"whisper_sampling_method": 0
}
```
If you've used the OBS plugin to download a Whisper model and a CT2 model, you will find them in the OBS plugin config folders, as in the paths above; using those models is recommended.
Pass the path of this config file to the tool as its second argument.
### Output
The tool writes an `output.txt` file in the working directory.
It also prints a verbose log to the console, e.g.:
```
[02:07:25.148] [UNKNOWN] found 59539456 bytes, 14884864 frames in input buffer, need >= 576000
[02:07:25.150] [UNKNOWN] processing audio from buffer, 0 existing frames, 144000 frames needed to full segment (144000 frames)
[02:07:25.150] [UNKNOWN] with 144000 remaining to full segment, popped 143360 frames from info buffer, pushed at 0 (overlap)
[02:07:25.151] [UNKNOWN] first segment, no overlap exists, 143360 frames to process
[02:07:25.151] [UNKNOWN] processing 143360 frames (2986 ms), start timestamp 85
[02:07:25.154] [UNKNOWN] 2 channels, 47770 frames, 2985.625000 ms
[02:07:25.168] [UNKNOWN] VAD detected speech from 29696 to 47770 (18074 frames, 1129 ms)
[02:07:25.169] [UNKNOWN] run_whisper_inference: processing 18074 samples, 1.130 sec, 4 threads
[02:07:26.700] [UNKNOWN] Token 0: 50364, [_BEG_], p: 1.000, dtw: -1 [keep: 0]
```
### Translation
To translate with Whisper, set the Whisper output language to your desired output language and set both CT2 languages to `none`.
For example, this is a Japanese translation with Whisper:
```json
{
"whisper_language": "ja",
"source_language": "none",
"target_language": "none",
// ...
}
```
To translate with CT2, make sure the Whisper output language is the spoken language and that it matches the CT2 source language.
For example, this is a Korean-to-Japanese translation with CT2 M2M-100:
```json
{
"whisper_language": "ko",
"source_language": "__ko__",
"target_language": "__ja__",
// ...
}
```
## Evaluation of the results
The provided [python script](evaluate_output.py) can run WER/CER evaluation on the results.
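For reference, the script computes both metrics as Levenshtein edit distances normalized by the length of the reference:

$$\mathrm{WER} = \frac{\mathrm{Lev}(\text{ref words},\ \text{hyp words})}{|\text{ref words}|}\,,\qquad \mathrm{CER} = \frac{\mathrm{Lev}(\text{ref text},\ \text{hyp text})}{|\text{ref text}|}$$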
Example of running the evaluation script:
```powershell
obs-localvocal> python .\src\tests\evaluate_output.py ".\ground_truth.txt" ".\output.txt"
```
It requires installing a couple of packages:
```powershell
pip install Levenshtein diff_match_patch biopython
```

src/tests/copy_dlls.ps1 (new file, 40 lines)

@@ -0,0 +1,40 @@
# change into the root directory of the repository from the location of this script
$scriptPath = Split-Path -Parent $MyInvocation.MyCommand.Definition
Set-Location "$scriptPath\..\.."
$testToolPath = ".\release\Release\test"
# make sure the test tool directory exists
if (-not (Test-Path $testToolPath)) {
New-Item -ItemType Directory -Path $testToolPath | Out-Null
}
# copy the required DLLs to the test tool directory
$obsDlls = @(
".\release\Release\obs-plugins\64bit\ctranslate2.dll",
".\release\Release\obs-plugins\64bit\cublas64_12.dll",
".\release\Release\obs-plugins\64bit\cublasLt64_12.dll",
".\release\Release\obs-plugins\64bit\cudart64_12.dll",
".\release\Release\obs-plugins\64bit\libopenblas.dll",
".\release\Release\obs-plugins\64bit\onnxruntime_providers_shared.dll",
".\release\Release\obs-plugins\64bit\onnxruntime.dll",
".\release\Release\obs-plugins\64bit\whisper.dll",
".deps\obs-deps-2023-11-03-x64\bin\avcodec-60.dll",
".deps\obs-deps-2023-11-03-x64\bin\avdevice-60.dll",
".deps\obs-deps-2023-11-03-x64\bin\avfilter-9.dll",
".deps\obs-deps-2023-11-03-x64\bin\avformat-60.dll",
".deps\obs-deps-2023-11-03-x64\bin\avutil-58.dll",
".deps\obs-deps-2023-11-03-x64\bin\libx264-164.dll",
".deps\obs-deps-2023-11-03-x64\bin\swresample-4.dll",
".deps\obs-deps-2023-11-03-x64\bin\swscale-7.dll",
".deps\obs-deps-2023-11-03-x64\bin\zlib.dll"
".deps\obs-studio-30.0.2\build_x64\rundir\Debug\bin\64bit\obs-frontend-api.dll",
".deps\obs-studio-30.0.2\build_x64\rundir\Debug\bin\64bit\obs.dll",
".deps\obs-studio-30.0.2\build_x64\rundir\Debug\bin\64bit\w32-pthreads.dll"
)
$obsDlls | ForEach-Object {
Copy-Item -Force -Path $_ -Destination $testToolPath
}

src/tests/evaluate_output.py (new file, 68 lines)

@@ -0,0 +1,68 @@
import Levenshtein
import argparse
from diff_match_patch import diff_match_patch
def visualize_differences(ref_text, hyp_text):
dmp = diff_match_patch()
diffs = dmp.diff_main(hyp_text, ref_text, checklines=True)
html = dmp.diff_prettyHtml(diffs)
return html
def calculate_wer(ref_text, hyp_text):
ref_words = ref_text.split()
hyp_words = hyp_text.split()
distance = Levenshtein.distance(ref_words, hyp_words)
wer = distance / len(ref_words)
return wer
def calculate_cer(ref_text, hyp_text):
distance = Levenshtein.distance(ref_text, hyp_text)
cer = distance / len(ref_text)
return cer
def compare_tokens(ref_tokens, hyp_tokens):
comparisons = []
for ref_token, hyp_token in zip(ref_tokens, hyp_tokens):
distance = Levenshtein.distance(ref_token, hyp_token)
comparison = {'ref_token': ref_token, 'hyp_token': hyp_token, 'error_rate': distance / len(ref_token)}
comparisons.append(comparison)
return comparisons
def read_text_from_file(file_path, join_sentences=True):
with open(file_path, 'r', encoding='utf-8', errors='ignore') as file:
sentences = file.readlines()
sentences = [sentence.strip() for sentence in sentences]
# merge into a single string
if join_sentences:
return ' '.join(sentences)
return sentences
parser = argparse.ArgumentParser(description='Evaluate output')
parser.add_argument('ref_file_path', type=str, help='Path to the reference file')
parser.add_argument('hyp_file_path', type=str, help='Path to the hypothesis file')
args = parser.parse_args()
ref_text = read_text_from_file(args.ref_file_path)
hyp_text = read_text_from_file(args.hyp_file_path)
wer = calculate_wer(ref_text, hyp_text)
cer = calculate_cer(ref_text, hyp_text)
print("Word Error Rate (WER):", wer)
print("Character Error Rate (CER):", cer)
ref_text = '\n'.join(read_text_from_file(args.ref_file_path, join_sentences=False))
hyp_text = '\n'.join(read_text_from_file(args.hyp_file_path, join_sentences=False))
html_diff = visualize_differences(ref_text, hyp_text)
with open("diff_visualization.html", "w", encoding="utf-8") as file:
file.write(html_diff)
from Bio.Align import PairwiseAligner
aligner = PairwiseAligner()
alignments = aligner.align(ref_text, hyp_text)
# write the first alignment to a file
with open("alignment.txt", "w", encoding="utf-8") as file:
file.write(alignments[0].format())

src/tests/localvocal-offline-test.cpp (new file, 586 lines)

@@ -0,0 +1,586 @@
#include <iostream>
#include <string>
#include <codecvt>
#include <vector>
#include "transcription-filter-data.h"
#include "transcription-filter.h"
#include "transcription-utils.h"
#include "whisper-utils/whisper-utils.h"
#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string.h>
#ifdef _WIN32
#include <Windows.h>
#endif
void obs_log(int log_level, const char *format, ...)
{
if (log_level == LOG_DEBUG) {
return;
}
// print timestamp in format [HH:MM:SS.mmm], use std::chrono::system_clock
auto now = std::chrono::system_clock::now();
auto now_ms = std::chrono::time_point_cast<std::chrono::milliseconds>(now);
auto epoch = now_ms.time_since_epoch();
// convert to std::time_t in order to convert to std::tm
std::time_t now_time_t = std::chrono::system_clock::to_time_t(now);
std::tm now_tm = *std::localtime(&now_time_t);
// print timestamp
printf("[%02d:%02d:%02d.%03d] ", now_tm.tm_hour, now_tm.tm_min, now_tm.tm_sec,
(int)(epoch.count() % 1000));
// print log level
switch (log_level) {
case LOG_DEBUG:
printf("[DEBUG] ");
break;
case LOG_INFO:
printf("[INFO] ");
break;
case LOG_WARNING:
printf("[WARNING] ");
break;
case LOG_ERROR:
printf("[ERROR] ");
break;
default:
printf("[UNKNOWN] ");
break;
}
// convert format to wstring
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
std::wstring wformat = converter.from_bytes(format);
// print format with arguments with utf-8 support
va_list args;
va_start(args, format);
vprintf(format, args);
va_end(args);
printf("\n");
}
#if defined(_WIN32) || defined(__APPLE__)
extern "C" {
#include <libavformat/avformat.h>
#include <libavcodec/avcodec.h>
#include <libavutil/frame.h>
#include <libavutil/mem.h>
#include <libavutil/opt.h>
#include <libswresample/swresample.h>
}
std::vector<std::vector<uint8_t>>
read_audio_file(const char *filename, std::function<void(int, int)> initialization_callback)
{
obs_log(LOG_INFO, "Reading audio file %s", filename);
AVFormatContext *formatContext = nullptr;
int ret = avformat_open_input(&formatContext, filename, nullptr, nullptr);
if (ret != 0) {
char errbuf[AV_ERROR_MAX_STRING_SIZE];
av_make_error_string(errbuf, AV_ERROR_MAX_STRING_SIZE, ret);
obs_log(LOG_ERROR, "Error opening file: %s", errbuf);
return {};
}
if (avformat_find_stream_info(formatContext, nullptr) < 0) {
obs_log(LOG_ERROR, "Error finding stream information");
return {};
}
int audioStreamIndex = -1;
for (unsigned int i = 0; i < formatContext->nb_streams; i++) {
if (formatContext->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
audioStreamIndex = i;
break;
}
}
if (audioStreamIndex == -1) {
obs_log(LOG_ERROR, "No audio stream found");
return {};
}
// print information about the file
av_dump_format(formatContext, 0, filename, 0);
// if the sample format is not float, return
if (formatContext->streams[audioStreamIndex]->codecpar->format != AV_SAMPLE_FMT_FLTP) {
obs_log(LOG_ERROR,
"Sample format is not float (it is %s). Encode the audio file with float sample format."
" For example, use the command 'ffmpeg -i input.mp3 -c:a pcm_f32le output.wav' to "
"convert the audio file to float format.",
av_get_sample_fmt_name(
(AVSampleFormat)formatContext->streams[audioStreamIndex]
->codecpar->format));
return {};
}
initialization_callback(formatContext->streams[audioStreamIndex]->codecpar->sample_rate,
formatContext->streams[audioStreamIndex]->codecpar->channels);
AVCodecParameters *codecParams = formatContext->streams[audioStreamIndex]->codecpar;
const AVCodec *codec = avcodec_find_decoder(codecParams->codec_id);
if (!codec) {
obs_log(LOG_ERROR, "Decoder not found");
return {};
}
AVCodecContext *codecContext = avcodec_alloc_context3(codec);
if (!codecContext) {
obs_log(LOG_ERROR, "Failed to allocate codec context");
return {};
}
if (avcodec_parameters_to_context(codecContext, codecParams) < 0) {
obs_log(LOG_ERROR, "Failed to copy codec parameters to codec context");
return {};
}
if (avcodec_open2(codecContext, codec, nullptr) < 0) {
obs_log(LOG_ERROR, "Failed to open codec");
return {};
}
AVFrame *frame = av_frame_alloc();
AVPacket packet;
std::vector<std::vector<uint8_t>> buffer(
formatContext->streams[audioStreamIndex]->codecpar->channels);
while (av_read_frame(formatContext, &packet) >= 0) {
if (packet.stream_index == audioStreamIndex) {
if (avcodec_send_packet(codecContext, &packet) == 0) {
while (avcodec_receive_frame(codecContext, frame) == 0) {
// push data to the buffer
for (int j = 0; j < codecContext->channels; j++) {
buffer[j].insert(buffer[j].end(), frame->data[j],
frame->data[j] +
frame->linesize[0]);
}
}
}
}
av_packet_unref(&packet);
}
av_frame_free(&frame);
avcodec_free_context(&codecContext);
avformat_close_input(&formatContext);
return buffer;
}
#endif
transcription_filter_data *
create_context(int sample_rate, int channels, const std::string &whisper_model_path,
const std::string &silero_vad_model_file, const std::string &ct2ModelFolder,
const whisper_sampling_strategy whisper_sampling_method = WHISPER_SAMPLING_GREEDY)
{
struct transcription_filter_data *gf = new transcription_filter_data();
gf->log_level = LOG_DEBUG;
gf->channels = channels;
gf->sample_rate = sample_rate;
gf->frames = (size_t)((float)gf->sample_rate / (1000.0f / 3000.0));
gf->last_num_frames = 0;
gf->step_size_msec = 3000;
gf->min_sub_duration = 3000;
gf->last_sub_render_time = 0;
gf->save_srt = false;
gf->truncate_output_file = false;
gf->save_only_while_recording = false;
gf->rename_file_to_match_recording = false;
gf->process_while_muted = false;
gf->buffered_output = false;
gf->fix_utf8 = true;
for (size_t i = 0; i < gf->channels; i++) {
circlebuf_init(&gf->input_buffers[i]);
}
circlebuf_init(&gf->info_buffer);
// allocate copy buffers
gf->copy_buffers[0] =
static_cast<float *>(malloc(gf->channels * gf->frames * sizeof(float)));
for (size_t c = 1; c < gf->channels; c++) { // set the channel pointers
gf->copy_buffers[c] = gf->copy_buffers[0] + c * gf->frames;
}
gf->overlap_ms = 150;
gf->overlap_frames = (size_t)((float)gf->sample_rate / (1000.0f / (float)gf->overlap_ms));
obs_log(gf->log_level, "channels %d, frames %d, sample_rate %d", (int)gf->channels,
(int)gf->frames, gf->sample_rate);
obs_log(gf->log_level, "setup audio resampler");
struct resample_info src, dst;
src.samples_per_sec = gf->sample_rate;
src.format = AUDIO_FORMAT_FLOAT_PLANAR;
src.speakers = convert_speaker_layout((uint8_t)gf->channels);
dst.samples_per_sec = WHISPER_SAMPLE_RATE;
dst.format = AUDIO_FORMAT_FLOAT_PLANAR;
dst.speakers = convert_speaker_layout((uint8_t)1);
gf->resampler_to_whisper = audio_resampler_create(&dst, &src);
gf->whisper_model_file_currently_loaded = "";
gf->output_file_path = std::string("output.txt");
gf->whisper_model_path = std::string(""); // The update function will set the model path
gf->whisper_context = nullptr;
// gf->captions_monitor.initialize(
// gf,
// [gf](const std::string &text) {
// obs_log(LOG_INFO, "Captions: %s", text.c_str());
// },
// 30, std::chrono::seconds(10));
gf->vad_enabled = true;
gf->log_words = true;
gf->caption_to_stream = false;
gf->start_timestamp_ms = now_ms();
gf->sentence_number = 1;
gf->last_sub_render_time = 0;
gf->buffered_output = false;
gf->source_lang = "";
gf->target_lang = "";
gf->translation_ctx.add_context = true;
gf->translation_output = "";
gf->suppress_sentences = "";
gf->translate = false;
gf->whisper_params = whisper_full_default_params(whisper_sampling_method);
gf->whisper_params.duration_ms = 3000;
gf->whisper_params.language = "en";
gf->whisper_params.initial_prompt = "";
gf->whisper_params.n_threads = 4;
gf->whisper_params.n_max_text_ctx = 16384;
gf->whisper_params.translate = false;
gf->whisper_params.no_context = false;
gf->whisper_params.single_segment = true;
gf->whisper_params.print_special = false;
gf->whisper_params.print_progress = false;
gf->whisper_params.print_realtime = false;
gf->whisper_params.print_timestamps = false;
gf->whisper_params.token_timestamps = false;
gf->whisper_params.thold_pt = 0.01;
gf->whisper_params.thold_ptsum = 0.01;
gf->whisper_params.max_len = 0;
gf->whisper_params.split_on_word = false;
gf->whisper_params.max_tokens = 32;
gf->whisper_params.speed_up = false;
gf->whisper_params.suppress_blank = true;
gf->whisper_params.suppress_non_speech_tokens = true;
gf->whisper_params.temperature = 0.5;
gf->whisper_params.max_initial_ts = 1.0;
gf->whisper_params.length_penalty = -1;
gf->active = true;
start_whisper_thread_with_path(gf, whisper_model_path, silero_vad_model_file.c_str());
obs_log(gf->log_level, "context created");
return gf;
}
void set_text_callback(struct transcription_filter_data *gf,
const DetectionResultWithText &resultIn)
{
DetectionResultWithText result = resultIn;
if (!result.text.empty() && result.result == DETECTION_RESULT_SPEECH) {
std::string str_copy = result.text;
if (gf->fix_utf8) {
str_copy = fix_utf8(str_copy);
}
str_copy = remove_leading_trailing_nonalpha(str_copy);
// if suppression is enabled, check if the text is in the suppression list
if (!gf->suppress_sentences.empty()) {
// split the suppression list by newline into individual sentences
std::vector<std::string> suppress_sentences_list =
split(gf->suppress_sentences, '\n');
// check if the text is in the suppression list
for (const std::string &suppress_sentence : suppress_sentences_list) {
// check if str_copy starts with the suppress sentence
if (str_copy.find(suppress_sentence) == 0) {
obs_log(LOG_INFO, "Suppressed sentence: '%s'",
str_copy.c_str());
gf->last_text = str_copy;
return; // do not process the sentence
}
}
}
if (gf->translate) {
obs_log(gf->log_level, "Translating text. %s -> %s",
gf->source_lang.c_str(), gf->target_lang.c_str());
std::string translated_text;
if (translate(gf->translation_ctx, str_copy, gf->source_lang,
gf->target_lang,
translated_text) == OBS_POLYGLOT_TRANSLATION_SUCCESS) {
if (gf->log_words) {
obs_log(LOG_INFO, "Translation: '%s' -> '%s'",
str_copy.c_str(), translated_text.c_str());
}
// overwrite the original text with the translated text
str_copy = str_copy + " -> " + translated_text;
} else {
obs_log(gf->log_level, "Failed to translate text");
}
}
std::ofstream output_file(gf->output_file_path, std::ios::app);
output_file << str_copy << std::endl;
output_file.close();
}
/*
if (gf->buffered_output) {
gf->captions_monitor.addWords(result.tokens);
}
if (gf->output_file_path != "" && gf->text_source_name.empty()) {
// Check if we should save the sentence
// should the file be truncated?
std::ios_base::openmode openmode = std::ios::out;
if (gf->truncate_output_file) {
openmode |= std::ios::trunc;
} else {
openmode |= std::ios::app;
}
if (!gf->save_srt) {
// Write raw sentence to file
std::ofstream output_file(gf->output_file_path, openmode);
output_file << str_copy << std::endl;
output_file.close();
} else {
obs_log(gf->log_level, "Saving sentence to file %s, sentence #%d",
gf->output_file_path.c_str(), gf->sentence_number);
// Append sentence to file in .srt format
std::ofstream output_file(gf->output_file_path, openmode);
output_file << gf->sentence_number << std::endl;
// use the start and end timestamps to calculate the start and end time in srt format
auto format_ts_for_srt = [&output_file](uint64_t ts) {
uint64_t time_s = ts / 1000;
uint64_t time_m = time_s / 60;
uint64_t time_h = time_m / 60;
uint64_t time_ms_rem = ts % 1000;
uint64_t time_s_rem = time_s % 60;
uint64_t time_m_rem = time_m % 60;
uint64_t time_h_rem = time_h % 60;
output_file << std::setfill('0') << std::setw(2) << time_h_rem
<< ":" << std::setfill('0') << std::setw(2)
<< time_m_rem << ":" << std::setfill('0')
<< std::setw(2) << time_s_rem << ","
<< std::setfill('0') << std::setw(3) << time_ms_rem;
};
format_ts_for_srt(result.start_timestamp_ms);
output_file << " --> ";
format_ts_for_srt(result.end_timestamp_ms);
output_file << std::endl;
output_file << str_copy << std::endl;
output_file << std::endl;
output_file.close();
gf->sentence_number++;
}
}
*/
};
void release_context(transcription_filter_data *gf)
{
obs_log(LOG_INFO, "destroy");
shutdown_whisper_thread(gf);
if (gf->resampler_to_whisper) {
audio_resampler_destroy(gf->resampler_to_whisper);
}
{
std::lock_guard<std::mutex> lockbuf(gf->whisper_buf_mutex);
free(gf->copy_buffers[0]);
gf->copy_buffers[0] = nullptr;
for (size_t i = 0; i < gf->channels; i++) {
circlebuf_free(&gf->input_buffers[i]);
}
}
circlebuf_free(&gf->info_buffer);
delete gf;
}
int wmain(int argc, wchar_t *argv[])
{
if (argc < 3) {
std::cout << "Usage: localvocal-offline-test <audio-file> <config_json_file>"
<< std::endl;
return 1;
}
#ifdef _WIN32
// Set console output to UTF-8
SetConsoleOutputCP(CP_UTF8);
#endif
std::wstring file = argv[1];
std::wstring configJsonFile = argv[2];
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
std::string filenameStr = converter.to_bytes(file);
// read the configuration json file
std::ifstream config_stream(configJsonFile);
if (!config_stream.is_open()) {
std::cout << "Failed to open config file" << std::endl;
return 1;
}
nlohmann::json config;
config_stream >> config;
config_stream.close();
// get the configuration values
std::string whisperModelPathStr = config["whisper_model_path"];
std::string sileroVadModelFileStr = config["silero_vad_model_file"];
std::string sourceLanguageStr = config["source_language"];
std::string targetLanguageStr = config["target_language"];
std::string whisperLanguageStr = config["whisper_language"];
std::string ct2ModelFolderStr = config["ct2_model_folder"];
std::string logLevelStr = config["log_level"];
whisper_sampling_strategy whisper_sampling_method = config["whisper_sampling_method"];
std::cout << "LocalVocal Offline Test" << std::endl;
transcription_filter_data *gf = nullptr;
std::vector<std::vector<uint8_t>> audio =
read_audio_file(filenameStr.c_str(), [&](int sample_rate, int channels) {
gf = create_context(sample_rate, channels, whisperModelPathStr,
sileroVadModelFileStr, ct2ModelFolderStr,
whisper_sampling_method);
if (sourceLanguageStr.empty() || targetLanguageStr.empty() ||
sourceLanguageStr == "none" || targetLanguageStr == "none") {
obs_log(LOG_INFO,
"Source or target translation language are empty or disabled");
} else {
obs_log(LOG_INFO, "Setting translation languages");
gf->source_lang = sourceLanguageStr;
gf->target_lang = targetLanguageStr;
build_and_enable_translation(gf, ct2ModelFolderStr.c_str());
}
gf->whisper_params.language = whisperLanguageStr.c_str();
if (config.contains("fix_utf8")) {
obs_log(LOG_INFO, "Setting fix_utf8 to %s",
config["fix_utf8"] ? "true" : "false");
gf->fix_utf8 = config["fix_utf8"];
}
if (config.contains("suppress_sentences")) {
obs_log(LOG_INFO, "Setting suppress_sentences to %ls",
config["suppress_sentences"].get<std::string>().c_str());
gf->suppress_sentences =
config["suppress_sentences"].get<std::string>();
}
if (config.contains("overlap_ms")) {
obs_log(LOG_INFO, "Setting overlap_ms to %d",
config["overlap_ms"].get<int>());
gf->overlap_ms = config["overlap_ms"];
gf->overlap_frames = (size_t)((float)gf->sample_rate /
(1000.0f / (float)gf->overlap_ms));
}
// set log level
if (logLevelStr == "debug") {
gf->log_level = LOG_DEBUG;
} else if (logLevelStr == "info") {
gf->log_level = LOG_INFO;
} else if (logLevelStr == "warning") {
gf->log_level = LOG_WARNING;
} else if (logLevelStr == "error") {
gf->log_level = LOG_ERROR;
}
});
if (gf == nullptr) {
std::cout << "Failed to create context" << std::endl;
return 1;
}
if (audio.empty()) {
std::cout << "Failed to read audio file" << std::endl;
return 1;
}
// truncate the output file
obs_log(LOG_INFO, "Truncating output file");
std::ofstream output_file(gf->output_file_path, std::ios::trunc);
output_file.close();
// fill up the whisper buffer
{
obs_log(LOG_INFO, "Filling up whisper buffer");
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex); // scoped lock
int frames = 4096;
const int frame_size_bytes = sizeof(float);
int frames_size_bytes = frames * frame_size_bytes;
int frames_count = 0;
while (true) {
// check if there are enough frames left in the audio buffer
if ((frames_count + frames) > (audio[0].size() / frame_size_bytes)) {
// only take the remaining frames
frames = audio[0].size() / frame_size_bytes - frames_count;
frames_size_bytes = frames * frame_size_bytes;
}
// push back current audio data to input circlebuf
for (size_t c = 0; c < gf->channels; c++) {
circlebuf_push_back(&gf->input_buffers[c],
audio[c].data() +
frames_count * frame_size_bytes,
frames_size_bytes);
}
// push audio packet info (timestamp/frame count) to info circlebuf
struct transcription_filter_audio_info info = {0};
info.frames = frames; // number of frames in this packet
// make a timestamp from the current frame count
info.timestamp = frames_count * 1000 / gf->sample_rate;
circlebuf_push_back(&gf->info_buffer, &info, sizeof(info));
frames_count += frames;
if (frames_count >= audio[0].size() / frame_size_bytes) {
break;
}
}
}
// wait for processing to finish
obs_log(LOG_INFO, "Waiting for processing to finish");
while (true) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
// check the input circlebuf has more data
size_t input_buf_size = 0;
{
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);
input_buf_size = gf->input_buffers[0].size;
}
const size_t step_size_frames = gf->step_size_msec * gf->sample_rate / 1000;
const size_t segment_size = step_size_frames * sizeof(float);
if (input_buf_size < segment_size) {
break;
}
}
release_context(gf);
obs_log(LOG_INFO, "LocalVocal Offline Test Done");
return 0;
}

src/transcription-filter-data.h

@@ -1,7 +1,6 @@
#ifndef TRANSCRIPTION_FILTER_DATA_H
#define TRANSCRIPTION_FILTER_DATA_H
#include <obs.h>
#include <util/circlebuf.h>
#include <util/darray.h>
#include <media-io/audio-resampler.h>
@@ -22,8 +21,6 @@
#define MAX_PREPROC_CHANNELS 10
#define MT_ obs_module_text
struct transcription_filter_data {
obs_source_t *context; // obs filter source (this filter)
size_t channels; // number of channels
@@ -82,6 +79,7 @@ struct transcription_filter_data {
bool buffered_output = false;
bool enable_token_ts_dtw = false;
std::string suppress_sentences;
bool fix_utf8 = true;
// Last transcription result
std::string last_text;
@@ -97,17 +95,18 @@
// Use std for thread and mutex
std::thread whisper_thread;
std::mutex *whisper_buf_mutex;
std::mutex *whisper_ctx_mutex;
std::condition_variable *wshiper_thread_cv;
std::mutex whisper_buf_mutex;
std::mutex whisper_ctx_mutex;
std::condition_variable wshiper_thread_cv;
// translation context
struct translation_context translation_ctx;
std::string translation_model_index;
TokenBufferThread captions_monitor;
// ctor
transcription_filter_data()
transcription_filter_data() : whisper_buf_mutex(), whisper_ctx_mutex(), wshiper_thread_cv()
{
// initialize all pointers to nullptr
for (size_t i = 0; i < MAX_PREPROC_CHANNELS; i++) {
@@ -117,9 +116,6 @@ struct transcription_filter_data {
resampler_to_whisper = nullptr;
whisper_model_path = "";
whisper_context = nullptr;
whisper_buf_mutex = nullptr;
whisper_ctx_mutex = nullptr;
wshiper_thread_cv = nullptr;
output_file_path = "";
whisper_model_file_currently_loaded = "";
}

src/transcription-filter.cpp

@@ -8,8 +8,10 @@
#include "model-utils/model-downloader.h"
#include "whisper-utils/whisper-processing.h"
#include "whisper-utils/whisper-language.h"
#include "whisper-utils/whisper-model-utils.h"
#include "whisper-utils/whisper-utils.h"
#include "translation/language_codes.h"
#include "translation/translation-utils.h"
#include "translation/translation.h"
#include "utils.h"
@@ -24,37 +26,6 @@
#include <QString>
inline enum speaker_layout convert_speaker_layout(uint8_t channels)
{
switch (channels) {
case 0:
return SPEAKERS_UNKNOWN;
case 1:
return SPEAKERS_MONO;
case 2:
return SPEAKERS_STEREO;
case 3:
return SPEAKERS_2POINT1;
case 4:
return SPEAKERS_4POINT0;
case 5:
return SPEAKERS_4POINT1;
case 6:
return SPEAKERS_5POINT1;
case 8:
return SPEAKERS_7POINT1;
default:
return SPEAKERS_UNKNOWN;
}
}
inline uint64_t now_ms()
{
return std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
}
bool add_sources_to_list(void *list_property, obs_source_t *source)
{
auto source_id = obs_source_get_id(source);
@@ -97,13 +68,8 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_
return audio;
}
if (!gf->whisper_buf_mutex || !gf->whisper_ctx_mutex) {
obs_log(LOG_ERROR, "whisper mutexes are null");
return audio;
}
{
std::lock_guard<std::mutex> lock(*gf->whisper_buf_mutex); // scoped lock
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex); // scoped lock
// push back current audio data to input circlebuf
for (size_t c = 0; c < gf->channels; c++) {
circlebuf_push_back(&gf->input_buffers[c], audio->data[c],
@@ -138,7 +104,7 @@ void transcription_filter_destroy(void *data)
}
{
std::lock_guard<std::mutex> lockbuf(*gf->whisper_buf_mutex);
std::lock_guard<std::mutex> lockbuf(gf->whisper_buf_mutex);
bfree(gf->copy_buffers[0]);
gf->copy_buffers[0] = nullptr;
for (size_t i = 0; i < gf->channels; i++) {
@@ -147,11 +113,7 @@ void transcription_filter_destroy(void *data)
}
circlebuf_free(&gf->info_buffer);
delete gf->whisper_buf_mutex;
delete gf->whisper_ctx_mutex;
delete gf->wshiper_thread_cv;
delete gf;
bfree(gf);
}
void send_caption_to_source(const std::string &target_source_name, const std::string &str_copy,
@@ -185,9 +147,15 @@ void set_text_callback(struct transcription_filter_data *gf,
}
gf->last_sub_render_time = now;
// recondition the text
std::string str_copy = fix_utf8(result.text);
str_copy = remove_leading_trailing_nonalpha(str_copy);
std::string str_copy = result.text;
// recondition the text - only if the output is not English
if (gf->whisper_params.language != nullptr &&
strcmp(gf->whisper_params.language, "en") != 0) {
str_copy = fix_utf8(str_copy);
} else {
str_copy = remove_leading_trailing_nonalpha(str_copy);
}
// if suppression is enabled, check if the text is in the suppression list
if (!gf->suppress_sentences.empty()) {
@@ -340,10 +308,17 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->translation_ctx.add_context = obs_data_get_bool(s, "translate_add_context");
gf->translation_output = obs_data_get_string(s, "translate_output");
gf->suppress_sentences = obs_data_get_string(s, "suppress_sentences");
gf->translation_model_index = obs_data_get_string(s, "translate_model");
if (new_translate != gf->translate) {
if (new_translate) {
start_translation(gf);
if (gf->translation_model_index != "whisper-based-translation") {
start_translation(gf);
} else {
// whisper-based translation
obs_log(gf->log_level, "Starting whisper-based translation...");
gf->translate = false;
}
} else {
gf->translate = false;
}
@@ -381,21 +356,21 @@ void transcription_filter_update(void *data, obs_data_t *s)
obs_weak_source_release(old_weak_text_source);
}
if (gf->whisper_ctx_mutex == nullptr) {
obs_log(LOG_ERROR, "whisper_ctx_mutex is null");
return;
}
obs_log(gf->log_level, "update whisper model");
update_whisper_model(gf, s);
obs_log(gf->log_level, "update whisper params");
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
gf->whisper_params = whisper_full_default_params(
(whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method"));
gf->whisper_params.duration_ms = (int)obs_data_get_int(s, "buffer_size_msec");
gf->whisper_params.language = obs_data_get_string(s, "whisper_language_select");
if (!new_translate || gf->translation_model_index != "whisper-based-translation") {
gf->whisper_params.language = obs_data_get_string(s, "whisper_language_select");
} else {
// take the language from gf->target_lang
gf->whisper_params.language = language_codes_2_reverse[gf->target_lang].c_str();
}
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
@@ -425,7 +400,8 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
{
obs_log(LOG_INFO, "LocalVocal filter create");
struct transcription_filter_data *gf = new transcription_filter_data();
void *data = bmalloc(sizeof(struct transcription_filter_data));
struct transcription_filter_data *gf = new (data) transcription_filter_data();
// Get the number of channels for the input source
gf->channels = audio_output_get_channels(obs_get_audio());
@@ -479,10 +455,6 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
gf->resampler_to_whisper = audio_resampler_create(&dst, &src);
obs_log(gf->log_level, "setup mutexes and condition variables");
gf->whisper_buf_mutex = new std::mutex();
gf->whisper_ctx_mutex = new std::mutex();
gf->wshiper_thread_cv = new std::condition_variable();
obs_log(gf->log_level, "clear text source data");
const char *subtitle_sources = obs_data_get_string(settings, "subtitle_sources");
if (subtitle_sources == nullptr || strcmp(subtitle_sources, "none") == 0 ||
@@ -675,6 +647,7 @@ void transcription_filter_defaults(obs_data_t *s)
obs_data_set_default_string(s, "translate_target_language", "__es__");
obs_data_set_default_string(s, "translate_source_language", "__en__");
obs_data_set_default_bool(s, "translate_add_context", true);
obs_data_set_default_string(s, "translate_model", "whisper-based-translation");
obs_data_set_default_string(s, "suppress_sentences", SUPPRESS_SENTENCES_DEFAULT);
// Whisper parameters
@@ -788,6 +761,21 @@ obs_properties_t *transcription_filter_properties(void *data)
obs_properties_t *translation_group = obs_properties_create();
obs_property_t *translation_group_prop = obs_properties_add_group(
ppts, "translate", MT_("translate"), OBS_GROUP_CHECKABLE, translation_group);
// add translation model selection
obs_property_t *prop_translate_model = obs_properties_add_list(
translation_group, "translate_model", MT_("translate_model"), OBS_COMBO_TYPE_LIST,
OBS_COMBO_FORMAT_STRING);
// Populate the dropdown with the translation models
// add "Whisper-Based Translation" option
obs_property_list_add_string(prop_translate_model, MT_("Whisper-Based-Translation"),
"whisper-based-translation");
for (const auto &model_info : models_info) {
if (model_info.second.type == MODEL_TYPE_TRANSLATION) {
obs_property_list_add_string(prop_translate_model, model_info.first.c_str(),
model_info.first.c_str());
}
}
// add target language selection
obs_property_t *prop_tgt = obs_properties_add_list(
translation_group, "translate_target_language", MT_("target_language"),
@@ -822,8 +810,9 @@ obs_properties_t *transcription_filter_properties(void *data)
UNUSED_PARAMETER(property);
// Show/Hide the translation group
const bool translate_enabled = obs_data_get_bool(settings, "translate");
for (const auto &prop : {"translate_target_language", "translate_source_language",
"translate_add_context", "translate_output"}) {
for (const auto &prop :
{"translate_target_language", "translate_source_language",
"translate_add_context", "translate_output", "translate_model"}) {
obs_property_set_visible(obs_properties_get(props, prop),
translate_enabled);
}

src/transcription-filter.h

@@ -4,6 +4,8 @@
extern "C" {
#endif
#define MT_ obs_module_text
void transcription_filter_activate(void *data);
void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter);
void transcription_filter_update(void *data, obs_data_t *s);
@@ -19,7 +21,10 @@ const char *const PLUGIN_INFO_TEMPLATE =
"<a href=\"https://github.com/occ-ai\">OCC AI</a> ❤️ "
"<a href=\"https://www.patreon.com/RoyShilkrot\">Support & Follow</a>";
const char *const SUPPRESS_SENTENCES_DEFAULT = "Thank you for watching\nThank you";
const char *const SUPPRESS_SENTENCES_DEFAULT =
"Thank you for watching\nPlease like and subscribe\n"
"Check out my other videos\nFollow me on social media\n"
"Please consider supporting me";
#ifdef __cplusplus
}

src/transcription-utils.cpp

@@ -90,17 +90,18 @@ std::string remove_leading_trailing_nonalpha(const std::string &str)
{
std::string str_copy = str;
// remove trailing spaces, newlines, tabs or punctuation
str_copy.erase(std::find_if(str_copy.rbegin(), str_copy.rend(),
[](unsigned char ch) {
return !std::isspace(ch) || !std::ispunct(ch);
})
.base(),
str_copy.end());
auto last_non_space =
std::find_if(str_copy.rbegin(), str_copy.rend(), [](unsigned char ch) {
return !std::isspace(ch) || !std::ispunct(ch);
}).base();
str_copy.erase(last_non_space, str_copy.end());
// remove leading spaces, newlines, tabs or punctuation
str_copy.erase(str_copy.begin(),
std::find_if(str_copy.begin(), str_copy.end(), [](unsigned char ch) {
return !std::isspace(ch) || !std::ispunct(ch);
}));
auto first_non_space = std::find_if(str_copy.begin(), str_copy.end(),
[](unsigned char ch) {
return !std::isspace(ch) || !std::ispunct(ch);
}) +
1;
str_copy.erase(str_copy.begin(), first_non_space);
return str_copy;
}

src/transcription-utils.h

@@ -3,9 +3,42 @@
#include <string>
#include <vector>
#include <chrono>
#include <media-io/audio-io.h>
std::string fix_utf8(const std::string &str);
std::string remove_leading_trailing_nonalpha(const std::string &str);
std::vector<std::string> split(const std::string &string, char delimiter);
inline enum speaker_layout convert_speaker_layout(uint8_t channels)
{
switch (channels) {
case 0:
return SPEAKERS_UNKNOWN;
case 1:
return SPEAKERS_MONO;
case 2:
return SPEAKERS_STEREO;
case 3:
return SPEAKERS_2POINT1;
case 4:
return SPEAKERS_4POINT0;
case 5:
return SPEAKERS_4POINT1;
case 6:
return SPEAKERS_5POINT1;
case 8:
return SPEAKERS_7POINT1;
default:
return SPEAKERS_UNKNOWN;
}
}
inline uint64_t now_ms()
{
return std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
}
#endif // TRANSCRIPTION_UTILS_H

src/translation/language_codes.cpp

@@ -203,3 +203,57 @@ std::map<std::string, std::string> language_codes_reverse = {{"Afrikaans", "__af
{"Yoruba", "__yo__"},
{"Chinese", "__zh__"},
{"Zulu", "__zu__"}};
std::map<std::string, std::string> language_codes_2 = {
{"af", "__af__"}, {"am", "__am__"}, {"ar", "__ar__"}, {"ast", "__ast__"},
{"az", "__az__"}, {"ba", "__ba__"}, {"be", "__be__"}, {"bg", "__bg__"},
{"bn", "__bn__"}, {"br", "__br__"}, {"bs", "__bs__"}, {"ca", "__ca__"},
{"ceb", "__ceb__"}, {"cs", "__cs__"}, {"cy", "__cy__"}, {"da", "__da__"},
{"de", "__de__"}, {"el", "__el__"}, {"en", "__en__"}, {"es", "__es__"},
{"et", "__et__"}, {"fa", "__fa__"}, {"ff", "__ff__"}, {"fi", "__fi__"},
{"fr", "__fr__"}, {"fy", "__fy__"}, {"ga", "__ga__"}, {"gd", "__gd__"},
{"gl", "__gl__"}, {"gu", "__gu__"}, {"ha", "__ha__"}, {"he", "__he__"},
{"hi", "__hi__"}, {"hr", "__hr__"}, {"ht", "__ht__"}, {"hu", "__hu__"},
{"hy", "__hy__"}, {"id", "__id__"}, {"ig", "__ig__"}, {"ilo", "__ilo__"},
{"is", "__is__"}, {"it", "__it__"}, {"ja", "__ja__"}, {"jv", "__jv__"},
{"ka", "__ka__"}, {"kk", "__kk__"}, {"km", "__km__"}, {"kn", "__kn__"},
{"ko", "__ko__"}, {"lb", "__lb__"}, {"lg", "__lg__"}, {"ln", "__ln__"},
{"lo", "__lo__"}, {"lt", "__lt__"}, {"lv", "__lv__"}, {"mg", "__mg__"},
{"mk", "__mk__"}, {"ml", "__ml__"}, {"mn", "__mn__"}, {"mr", "__mr__"},
{"ms", "__ms__"}, {"my", "__my__"}, {"ne", "__ne__"}, {"nl", "__nl__"},
{"no", "__no__"}, {"ns", "__ns__"}, {"oc", "__oc__"}, {"or", "__or__"},
{"pa", "__pa__"}, {"pl", "__pl__"}, {"ps", "__ps__"}, {"pt", "__pt__"},
{"ro", "__ro__"}, {"ru", "__ru__"}, {"sd", "__sd__"}, {"si", "__si__"},
{"sk", "__sk__"}, {"sl", "__sl__"}, {"so", "__so__"}, {"sq", "__sq__"},
{"sr", "__sr__"}, {"ss", "__ss__"}, {"su", "__su__"}, {"sv", "__sv__"},
{"sw", "__sw__"}, {"ta", "__ta__"}, {"th", "__th__"}, {"tl", "__tl__"},
{"tn", "__tn__"}, {"tr", "__tr__"}, {"uk", "__uk__"}, {"ur", "__ur__"},
{"uz", "__uz__"}, {"vi", "__vi__"}, {"wo", "__wo__"}, {"xh", "__xh__"},
{"yi", "__yi__"}, {"yo", "__yo__"}, {"zh", "__zh__"}, {"zu", "__zu__"}};
std::map<std::string, std::string> language_codes_2_reverse = {
{"__af__", "af"}, {"__am__", "am"}, {"__ar__", "ar"}, {"__ast__", "_st"},
{"__az__", "az"}, {"__ba__", "ba"}, {"__be__", "be"}, {"__bg__", "bg"},
{"__bn__", "bn"}, {"__br__", "br"}, {"__bs__", "bs"}, {"__ca__", "ca"},
{"__ceb__", "_eb"}, {"__cs__", "cs"}, {"__cy__", "cy"}, {"__da__", "da"},
{"__de__", "de"}, {"__el__", "el"}, {"__en__", "en"}, {"__es__", "es"},
{"__et__", "et"}, {"__fa__", "fa"}, {"__ff__", "ff"}, {"__fi__", "fi"},
{"__fr__", "fr"}, {"__fy__", "fy"}, {"__ga__", "ga"}, {"__gd__", "gd"},
{"__gl__", "gl"}, {"__gu__", "gu"}, {"__ha__", "ha"}, {"__he__", "he"},
{"__hi__", "hi"}, {"__hr__", "hr"}, {"__ht__", "ht"}, {"__hu__", "hu"},
{"__hy__", "hy"}, {"__id__", "id"}, {"__ig__", "ig"}, {"__ilo__", "_lo"},
{"__is__", "is"}, {"__it__", "it"}, {"__ja__", "ja"}, {"__jv__", "jv"},
{"__ka__", "ka"}, {"__kk__", "kk"}, {"__km__", "km"}, {"__kn__", "kn"},
{"__ko__", "ko"}, {"__lb__", "lb"}, {"__lg__", "lg"}, {"__ln__", "ln"},
{"__lo__", "lo"}, {"__lt__", "lt"}, {"__lv__", "lv"}, {"__mg__", "mg"},
{"__mk__", "mk"}, {"__ml__", "ml"}, {"__mn__", "mn"}, {"__mr__", "mr"},
{"__ms__", "ms"}, {"__my__", "my"}, {"__ne__", "ne"}, {"__nl__", "nl"},
{"__no__", "no"}, {"__ns__", "ns"}, {"__oc__", "oc"}, {"__or__", "or"},
{"__pa__", "pa"}, {"__pl__", "pl"}, {"__ps__", "ps"}, {"__pt__", "pt"},
{"__ro__", "ro"}, {"__ru__", "ru"}, {"__sd__", "sd"}, {"__si__", "si"},
{"__sk__", "sk"}, {"__sl__", "sl"}, {"__so__", "so"}, {"__sq__", "sq"},
{"__sr__", "sr"}, {"__ss__", "ss"}, {"__su__", "su"}, {"__sv__", "sv"},
{"__sw__", "sw"}, {"__ta__", "ta"}, {"__th__", "th"}, {"__tl__", "tl"},
{"__tn__", "tn"}, {"__tr__", "tr"}, {"__uk__", "uk"}, {"__ur__", "ur"},
{"__uz__", "uz"}, {"__vi__", "vi"}, {"__wo__", "wo"}, {"__xh__", "xh"},
{"__yi__", "yi"}, {"__yo__", "yo"}, {"__zh__", "zh"}, {"__zu__", "zu"}};

src/translation/translation-utils.cpp (new file, 33 lines)

@@ -0,0 +1,33 @@
#include "translation-utils.h"
#include "translation.h"
#include "plugin-support.h"
#include "model-utils/model-downloader.h"
#include <obs-module.h>
void start_translation(struct transcription_filter_data *gf)
{
obs_log(LOG_INFO, "Starting translation...");
const ModelInfo &translation_model_info = models_info[gf->translation_model_index];
std::string model_file_found = find_model_folder(translation_model_info);
if (model_file_found == "") {
obs_log(LOG_INFO, "Translation CT2 model does not exist. Downloading...");
download_model_with_ui_dialog(
translation_model_info,
[gf, model_file_found](int download_status, const std::string &path) {
if (download_status == 0) {
obs_log(LOG_INFO, "CT2 model download complete");
build_and_enable_translation(gf, path);
} else {
obs_log(LOG_ERROR, "Model download failed");
gf->translate = false;
}
});
} else {
// Model exists, just load it
build_and_enable_translation(gf, model_file_found);
}
}

src/translation/translation-utils.h (new file, 4 lines)

@@ -0,0 +1,4 @@
#include "transcription-filter-data.h"
void start_translation(struct transcription_filter_data *gf);

src/translation/translation.cpp

@@ -1,6 +1,6 @@
#include "translation.h"
#include "plugin-support.h"
#include "model-utils/model-downloader.h"
#include "model-utils/model-find-utils.h"
#include "transcription-filter-data.h"
#include <ctranslate2/translator.h>
@@ -11,11 +11,7 @@
void build_and_enable_translation(struct transcription_filter_data *gf,
const std::string &model_file_path)
{
if (gf->whisper_ctx_mutex == nullptr) {
obs_log(LOG_ERROR, "Whisper context mutex is null");
return;
}
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
gf->translation_ctx.local_model_folder_path = model_file_path;
if (build_translation_context(gf->translation_ctx) ==
@@ -28,38 +24,13 @@ void build_and_enable_translation(struct transcription_filter_data *gf,
}
}
void start_translation(struct transcription_filter_data *gf)
{
obs_log(LOG_INFO, "Starting translation...");
const ModelInfo &translation_model_info = models_info["M2M-100 418M (495Mb)"];
std::string model_file_found = find_model_folder(translation_model_info);
if (model_file_found == "") {
obs_log(LOG_INFO, "Translation CT2 model does not exist. Downloading...");
download_model_with_ui_dialog(
translation_model_info,
[gf, model_file_found](int download_status, const std::string &path) {
if (download_status == 0) {
obs_log(LOG_INFO, "CT2 model download complete");
build_and_enable_translation(gf, path);
} else {
obs_log(LOG_ERROR, "Model download failed");
gf->translate = false;
}
});
} else {
// Model exists, just load it
build_and_enable_translation(gf, model_file_found);
}
}
int build_translation_context(struct translation_context &translation_ctx)
{
std::string local_model_path = translation_ctx.local_model_folder_path;
obs_log(LOG_INFO, "Building translation context from '%s'...", local_model_path.c_str());
// find the SPM file in the model folder
std::string local_spm_path =
find_file_in_folder_by_name(local_model_path, "sentencepiece.bpe.model");
std::string local_spm_path = find_file_in_folder_by_regex_expression(
local_model_path, "(sentencepiece|spm).*?\\.model");
try {
obs_log(LOG_INFO, "Loading SPM from %s", local_spm_path.c_str());

View File

@ -19,8 +19,9 @@ struct translation_context {
bool add_context;
};
void start_translation(struct transcription_filter_data *gf);
int build_translation_context(struct translation_context &translation_ctx);
void build_and_enable_translation(struct transcription_filter_data *gf,
const std::string &model_file_path);
int translate(struct translation_context &translation_ctx, const std::string &text,
const std::string &source_lang, const std::string &target_lang, std::string &result);
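A hedged usage sketch of translate(): the zero-means-success return convention is an assumption here, and the "__xx__" token form for the source/target languages follows the language-token map earlier in this commit:

std::string translated;
// Assumes a return value of 0 indicates success; error handling elided.
if (translate(gf->translation_ctx, "hello world", "__en__", "__es__", translated) == 0) {
	obs_log(LOG_INFO, "translation: %s", translated.c_str());
}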

View File

@ -90,12 +90,13 @@ void VadIterator::init_onnx_model(const SileroString &model_path)
session = std::make_shared<Ort::Session>(env, model_path.c_str(), session_options);
};
void VadIterator::reset_states()
void VadIterator::reset_states(bool reset_hc)
{
// Call reset before each audio start
std::memset(_h.data(), 0.0f, _h.size() * sizeof(float));
std::memset(_c.data(), 0.0f, _c.size() * sizeof(float));
triggered = false;
if (reset_hc) {
std::memset(_h.data(), 0.0f, _h.size() * sizeof(float));
std::memset(_c.data(), 0.0f, _c.size() * sizeof(float));
triggered = false;
}
temp_end = 0;
current_sample = 0;
@ -105,7 +106,7 @@ void VadIterator::reset_states()
current_speech = timestamp_t();
};
void VadIterator::predict(const std::vector<float> &data)
float VadIterator::predict_one(const std::vector<float> &data)
{
// Infer
// Create ort tensors
@ -138,6 +139,13 @@ void VadIterator::predict(const std::vector<float> &data)
float *cn = ort_outputs[2].GetTensorMutableData<float>();
std::memcpy(_c.data(), cn, size_hc * sizeof(float));
return speech_prob;
}
void VadIterator::predict(const std::vector<float> &data)
{
const float speech_prob = predict_one(data);
// Push forward sample index
current_sample += (unsigned int)window_size_samples;
@ -254,9 +262,9 @@ void VadIterator::predict(const std::vector<float> &data)
}
};
void VadIterator::process(const std::vector<float> &input_wav)
void VadIterator::process(const std::vector<float> &input_wav, bool reset_hc)
{
reset_states();
reset_states(reset_hc);
audio_length_samples = (int)input_wav.size();
@ -280,7 +288,7 @@ void VadIterator::process(const std::vector<float> &input_wav)
void VadIterator::process(const std::vector<float> &input_wav, std::vector<float> &output_wav)
{
process(input_wav);
process(input_wav, true);
collect_chunks(input_wav, output_wav);
}
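The new reset_hc flag lets callers keep the LSTM hidden/cell tensors (_h/_c) across successive process() calls, so a streaming caller preserves temporal context between chunks while the per-call sample counters still reset. A sketch of that pattern, matching the process(vad_input, false) call site in whisper-processing.cpp below — incoming_16khz_chunks and handle_speech_segment are placeholders, and the model-path argument (SileroString) is std::wstring on Windows:

VadIterator vad(model_path, 16000, 64, 0.5f, 500, 200, 250);

for (const std::vector<float> &chunk : incoming_16khz_chunks) {
	// reset_hc=false: counters restart for this chunk, but the hidden/cell
	// state carries over, preserving context across chunk boundaries.
	vad.process(chunk, false);
	for (const timestamp_t &t : vad.get_speech_timestamps())
		handle_speech_segment(t); // timestamps are relative to this chunk
}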

View File

@ -43,11 +43,12 @@ private:
private:
void init_engine_threads(int inter_threads, int intra_threads);
void init_onnx_model(const SileroString &model_path);
void reset_states();
void reset_states(bool reset_hc);
float predict_one(const std::vector<float> &data);
void predict(const std::vector<float> &data);
public:
void process(const std::vector<float> &input_wav);
void process(const std::vector<float> &input_wav, bool reset_hc = true);
void process(const std::vector<float> &input_wav, std::vector<float> &output_wav);
void collect_chunks(const std::vector<float> &input_wav, std::vector<float> &output_wav);
const std::vector<timestamp_t> get_speech_timestamps() const;

View File

@ -0,0 +1,106 @@
#include "whisper-utils.h"
#include "plugin-support.h"
#include "model-utils/model-downloader.h"
#include "whisper-processing.h"
#include <obs-module.h>
void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
{
// update the whisper model path
std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos;
char *silero_vad_model_file = obs_module_file("models/silero-vad/silero_vad.onnx");
if (silero_vad_model_file == nullptr) {
obs_log(LOG_ERROR, "Cannot find Silero VAD model file");
return;
}
if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path ||
is_external_model) {
if (gf->whisper_model_path != new_model_path) {
// model path changed
obs_log(gf->log_level, "model path changed from %s to %s",
gf->whisper_model_path.c_str(), new_model_path.c_str());
}
// check if the new model is external file
if (!is_external_model) {
// new model is not external file
shutdown_whisper_thread(gf);
if (models_info.count(new_model_path) == 0) {
obs_log(LOG_WARNING, "Model '%s' does not exist",
new_model_path.c_str());
return;
}
const ModelInfo &model_info = models_info[new_model_path];
// check if the model exists, if not, download it
std::string model_file_found = find_model_bin_file(model_info);
if (model_file_found == "") {
obs_log(LOG_WARNING, "Whisper model does not exist");
download_model_with_ui_dialog(
model_info,
[gf, new_model_path, silero_vad_model_file](
int download_status, const std::string &path) {
if (download_status == 0) {
obs_log(LOG_INFO,
"Model download complete");
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(
gf, path, silero_vad_model_file);
} else {
obs_log(LOG_ERROR, "Model download failed");
}
});
} else {
// Model exists, just load it
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(gf, model_file_found,
silero_vad_model_file);
}
} else {
// new model is external file, get file location from file property
std::string external_model_file_path =
obs_data_get_string(s, "whisper_model_path_external");
if (external_model_file_path.empty()) {
obs_log(LOG_WARNING, "External model file path is empty");
} else {
// check if the external model file is not currently loaded
if (gf->whisper_model_file_currently_loaded ==
external_model_file_path) {
obs_log(LOG_INFO, "External model file is already loaded");
return;
} else {
shutdown_whisper_thread(gf);
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(gf, external_model_file_path,
silero_vad_model_file);
}
}
}
} else {
// model path did not change
obs_log(gf->log_level, "Model path did not change: %s == %s",
gf->whisper_model_path.c_str(), new_model_path.c_str());
}
const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps");
if (new_dtw_timestamps != gf->enable_token_ts_dtw) {
// dtw_token_timestamps changed
obs_log(gf->log_level, "dtw_token_timestamps changed from %d to %d",
gf->enable_token_ts_dtw, new_dtw_timestamps);
gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps");
shutdown_whisper_thread(gf);
start_whisper_thread_with_path(gf, gf->whisper_model_path, silero_vad_model_file);
} else {
// dtw_token_timestamps did not change
obs_log(gf->log_level, "dtw_token_timestamps did not change: %d == %d",
gf->enable_token_ts_dtw, new_dtw_timestamps);
}
}
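One thing to watch in the function above: obs_module_file() returns a bmalloc'd string, and here it is only freed inside start_whisper_thread_with_path() (see the whisper-utils.cpp hunk below), so the early return when the model name is unknown and the return when the external model is already loaded both leak it. A hypothetical RAII wrapper — not part of this commit — that would make the ownership explicit:

#include <memory>
#include <obs-module.h>
#include <util/bmem.h>

// Hypothetical helper: tie the obs_module_file() string to scope so every
// return path releases it with bfree().
using obs_string_ptr = std::unique_ptr<char, decltype(&bfree)>;

static obs_string_ptr module_file_ptr(const char *rel_path)
{
	return obs_string_ptr(obs_module_file(rel_path), bfree);
}

// Usage: auto silero = module_file_ptr("models/silero-vad/silero_vad.onnx");
//        if (!silero) { ... }  // then pass silero.get() where needed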

View File

@ -0,0 +1,10 @@
#ifndef WHISPER_MODEL_UTILS_H
#define WHISPER_MODEL_UTILS_H
#include <obs.h>
#include "transcription-filter-data.h"
void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s);
#endif // WHISPER_MODEL_UTILS_H

View File

@ -11,12 +11,14 @@
#include <algorithm>
#include <cctype>
#include <cfloat>
#include <chrono>
#include <cmath>
#ifdef _WIN32
#include <fstream>
#include <Windows.h>
#endif
#include "model-utils/model-downloader.h"
#include "model-utils/model-find-utils.h"
#define VAD_THOLD 0.0001f
#define FREQ_THOLD 100.0f
@ -229,7 +231,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter
int(pcm32f_size), float(pcm32f_size) / WHISPER_SAMPLE_RATE,
gf->whisper_params.n_threads);
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
if (gf->whisper_context == nullptr) {
obs_log(LOG_WARNING, "whisper context is null");
return {DETECTION_RESULT_UNKNOWN, "", 0, 0, {}};
@ -339,7 +341,7 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
{
// scoped lock the buffer mutex
std::lock_guard<std::mutex> lock(*gf->whisper_buf_mutex);
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);
// We need (gf->frames - gf->last_num_frames) new frames for a full segment,
const size_t remaining_frames_to_full_segment = gf->frames - gf->last_num_frames;
@ -419,7 +421,7 @@ void process_audio_from_buffer(struct transcription_filter_data *gf)
if (gf->vad_enabled) {
std::vector<float> vad_input(resampled_16khz[0],
resampled_16khz[0] + resampled_16khz_frames);
gf->vad->process(vad_input);
gf->vad->process(vad_input, false);
std::vector<timestamp_t> stamps = gf->vad->get_speech_timestamps();
if (stamps.size() == 0) {
@ -526,25 +528,12 @@ void whisper_loop(void *data)
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
{
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
if (gf->whisper_context == nullptr) {
obs_log(LOG_WARNING,
"Whisper context is null. Whisper thread cannot start");
return;
}
}
obs_log(LOG_INFO, "starting whisper thread");
// Thread main loop
while (true) {
{
if (gf->whisper_ctx_mutex == nullptr) {
obs_log(LOG_WARNING, "whisper_ctx_mutex is null, exiting thread");
break;
}
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
if (gf->whisper_context == nullptr) {
obs_log(LOG_WARNING, "Whisper context is null, exiting thread");
break;
@ -555,7 +544,7 @@ void whisper_loop(void *data)
while (true) {
size_t input_buf_size = 0;
{
std::lock_guard<std::mutex> lock(*gf->whisper_buf_mutex);
std::lock_guard<std::mutex> lock(gf->whisper_buf_mutex);
input_buf_size = gf->input_buffers[0].size;
}
const size_t step_size_frames = gf->step_size_msec * gf->sample_rate / 1000;
@ -563,10 +552,9 @@ void whisper_loop(void *data)
if (input_buf_size >= segment_size) {
obs_log(gf->log_level,
"found %lu bytes, %lu frames in input buffer, need >= %lu, processing",
"found %lu bytes, %lu frames in input buffer, need >= %lu",
input_buf_size, (size_t)(input_buf_size / sizeof(float)),
segment_size);
// Process the audio. This will also remove the processed data from the input buffer.
// Mutex is locked inside process_audio_from_buffer.
process_audio_from_buffer(gf);
@ -577,8 +565,8 @@ void whisper_loop(void *data)
// Sleep for 10 ms using the condition variable wshiper_thread_cv
// This will wake up the thread if there is new data in the input buffer
// or if the whisper context is null
std::unique_lock<std::mutex> lock(*gf->whisper_ctx_mutex);
gf->wshiper_thread_cv->wait_for(lock, std::chrono::milliseconds(10));
std::unique_lock<std::mutex> lock(gf->whisper_ctx_mutex);
gf->wshiper_thread_cv.wait_for(lock, std::chrono::milliseconds(10));
}
obs_log(LOG_INFO, "exiting whisper thread");

View File

@ -6,9 +6,9 @@
// buffer size in msec
#define DEFAULT_BUFFER_SIZE_MSEC 3000
// overlap in msec
#define DEFAULT_OVERLAP_SIZE_MSEC 150
#define DEFAULT_OVERLAP_SIZE_MSEC 125
#define MAX_OVERLAP_SIZE_MSEC 1000
#define MIN_OVERLAP_SIZE_MSEC 150
#define MIN_OVERLAP_SIZE_MSEC 125
enum DetectionResult {
DETECTION_RESULT_UNKNOWN = 0,
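For scale, assuming the 16 kHz WHISPER_SAMPLE_RATE used elsewhere in this commit: overlap in samples = msec * sample_rate / 1000, so the new 125 ms default is 125 * 16000 / 1000 = 2000 samples of overlap per segment boundary, down from 2400 at the old 150 ms.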

View File

@ -5,110 +5,15 @@
#include <obs-module.h>
void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
{
// update the whisper model path
std::string new_model_path = obs_data_get_string(s, "whisper_model_path");
const bool is_external_model = new_model_path.find("!!!external!!!") != std::string::npos;
if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path ||
is_external_model) {
if (gf->whisper_model_path != new_model_path) {
// model path changed
obs_log(gf->log_level, "model path changed from %s to %s",
gf->whisper_model_path.c_str(), new_model_path.c_str());
}
// check if the new model is external file
if (!is_external_model) {
// new model is not external file
shutdown_whisper_thread(gf);
if (models_info.count(new_model_path) == 0) {
obs_log(LOG_WARNING, "Model '%s' does not exist",
new_model_path.c_str());
return;
}
const ModelInfo &model_info = models_info[new_model_path];
// check if the model exists, if not, download it
std::string model_file_found = find_model_bin_file(model_info);
if (model_file_found == "") {
obs_log(LOG_WARNING, "Whisper model does not exist");
download_model_with_ui_dialog(
model_info, [gf, new_model_path](int download_status,
const std::string &path) {
if (download_status == 0) {
obs_log(LOG_INFO,
"Model download complete");
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(gf, path);
} else {
obs_log(LOG_ERROR, "Model download failed");
}
});
} else {
// Model exists, just load it
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(gf, model_file_found);
}
} else {
// new model is external file, get file location from file property
std::string external_model_file_path =
obs_data_get_string(s, "whisper_model_path_external");
if (external_model_file_path.empty()) {
obs_log(LOG_WARNING, "External model file path is empty");
} else {
// check if the external model file is not currently loaded
if (gf->whisper_model_file_currently_loaded ==
external_model_file_path) {
obs_log(LOG_INFO, "External model file is already loaded");
return;
} else {
shutdown_whisper_thread(gf);
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(gf,
external_model_file_path);
}
}
}
} else {
// model path did not change
obs_log(gf->log_level, "Model path did not change: %s == %s",
gf->whisper_model_path.c_str(), new_model_path.c_str());
}
const bool new_dtw_timestamps = obs_data_get_bool(s, "dtw_token_timestamps");
if (new_dtw_timestamps != gf->enable_token_ts_dtw) {
// dtw_token_timestamps changed
obs_log(gf->log_level, "dtw_token_timestamps changed from %d to %d",
gf->enable_token_ts_dtw, new_dtw_timestamps);
gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps");
shutdown_whisper_thread(gf);
start_whisper_thread_with_path(gf, gf->whisper_model_path);
} else {
// dtw_token_timestamps did not change
obs_log(gf->log_level, "dtw_token_timestamps did not change: %d == %d",
gf->enable_token_ts_dtw, new_dtw_timestamps);
}
}
void shutdown_whisper_thread(struct transcription_filter_data *gf)
{
obs_log(gf->log_level, "shutdown_whisper_thread");
if (gf->whisper_context != nullptr) {
// acquire the mutex before freeing the context
if (!gf->whisper_ctx_mutex || !gf->wshiper_thread_cv) {
obs_log(LOG_ERROR, "whisper_ctx_mutex is null");
return;
}
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
whisper_free(gf->whisper_context);
gf->whisper_context = nullptr;
gf->wshiper_thread_cv->notify_all();
gf->wshiper_thread_cv.notify_all();
}
if (gf->whisper_thread.joinable()) {
gf->whisper_thread.join();
@ -118,21 +23,18 @@ void shutdown_whisper_thread(struct transcription_filter_data *gf)
}
}
void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path)
void start_whisper_thread_with_path(struct transcription_filter_data *gf,
const std::string &whisper_model_path,
const char *silero_vad_model_file)
{
obs_log(gf->log_level, "start_whisper_thread_with_path: %s", path.c_str());
if (gf->whisper_ctx_mutex == nullptr) {
obs_log(LOG_ERROR, "cannot init whisper: whisper_ctx_mutex is null");
return;
}
std::lock_guard<std::mutex> lock(*gf->whisper_ctx_mutex);
obs_log(gf->log_level, "start_whisper_thread_with_path: %s", whisper_model_path.c_str());
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
if (gf->whisper_context != nullptr) {
obs_log(LOG_ERROR, "cannot init whisper: whisper_context is not null");
return;
}
// initialize Silero VAD
char *silero_vad_model_file = obs_module_file("models/silero-vad/silero_vad.onnx");
#ifdef _WIN32
std::wstring silero_vad_model_path;
silero_vad_model_path.assign(silero_vad_model_file,
@ -140,18 +42,17 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, const
#else
std::string silero_vad_model_path = silero_vad_model_file;
#endif
bfree(silero_vad_model_file);
// roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
// for silero vad parameters
gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE, 64, 0.5f, 1000,
gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE, 64, 0.5f, 500,
200, 250));
gf->whisper_context = init_whisper_context(path, gf);
gf->whisper_context = init_whisper_context(whisper_model_path, gf);
if (gf->whisper_context == nullptr) {
obs_log(LOG_ERROR, "Failed to initialize whisper context");
return;
}
gf->whisper_model_file_currently_loaded = path;
gf->whisper_model_file_currently_loaded = whisper_model_path;
std::thread new_whisper_thread(whisper_loop, gf);
gf->whisper_thread.swap(new_whisper_thread);
}
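An annotated restatement of the VadIterator construction above; the parameter names follow the widely-circulated silero-vad C++ example this class appears to be based on, which is an assumption since neither hunk in this view names them:

gf->vad.reset(new VadIterator(silero_vad_model_path,
			      WHISPER_SAMPLE_RATE, // sample rate (16 kHz)
			      64,    // inference window size, ms
			      0.5f,  // speech probability threshold
			      500,   // min silence before closing a segment, ms (was 1000)
			      200,   // padding added around detected speech, ms
			      250)); // min speech duration, ms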

View File

@ -3,13 +3,11 @@
#include "transcription-filter-data.h"
#include <obs.h>
#include <string>
void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s);
void shutdown_whisper_thread(struct transcription_filter_data *gf);
void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path);
void start_whisper_thread_with_path(struct transcription_filter_data *gf, const std::string &path,
const char *silero_vad_model_file);
std::pair<int, int> findStartOfOverlap(const std::vector<whisper_token_data> &seq1,
const std::vector<whisper_token_data> &seq2);