From e3c69518a762394ec9f29d0bc3cedbaddb4569c0 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Fri, 6 Sep 2024 10:27:05 -0400 Subject: [PATCH] Fix hangups and VAD segmentation (#157) * Fix hangups and VAD segmentation * feat: Add max_sub_duration field to transcription filter data * chore: Update VAD parameters for better segmentation accuracy * feat: Add segment_duration field to transcription filter data * feat: Optimize VAD processing for better performance * feat: Refactor token buffer thread and whisper processing The code changes involve refactoring the token buffer thread and whisper processing. The token buffer thread now uses the variable name `word_token` instead of `word` for better clarity. In the whisper processing, the log message format has been updated to include the segment number and token number. These changes aim to improve the performance and accuracy of VAD processing, as well as add new fields to the transcription filter data. * Refactor token buffer thread and whisper processing * refactor: Update translation context in transcription filter The code changes in this commit update the translation context in the transcription filter. The `translate_add_context` property has been changed from a boolean to an integer slider, allowing the user to specify the number of context lines to add to the translation. This change aims to provide more flexibility in controlling the context for translation and improve the accuracy of the translation output. * refactor: Update last_text variable name in transcription filter callbacks * feat: Add translation language utilities This commit adds a new file, `translation-language-utils.h`, which contains utility functions for handling translation languages. The `remove_start_punctuation` function removes any leading punctuation from a given string. This utility will be used in the translation process to improve the quality of the translated output. * feat: Update ICU library configuration and dependencies This commit updates the configuration and dependencies of the ICU library. The `BuildICU.cmake` file has been modified to use the `INSTALL_DIR` variable instead of the `ICU_INSTALL_DIR` variable for setting the ICU library paths. Additionally, the `ICU_IN_LIBRARY` variable has been renamed to `ICU_IN_LIBRARY` for better clarity. These changes aim to improve the build process and ensure proper linking of the ICU library. * refactor: Update ICU library configuration and dependencies * refactor: Update ICU library configuration and dependencies * refactor: Update ICU library configuration and dependencies * refactor: Update ICU library configuration and dependencies * refactor: Update ICU library configuration and dependencies * refactor: Update ICU library configuration and dependencies * refactor: Update ICU library configuration and dependencies This commit updates the `BuildICU.cmake` file to set the `CFLAGS`, `CXXFLAGS`, and `LDFLAGS` environment variables to `-fPIC` for Linux platforms. This change aims to ensure that the ICU library is built with position-independent code, improving compatibility and security. Additionally, the `icuin` library has been renamed to `icui18n` to align with the naming convention. These updates enhance the build process and maintain consistency in the ICU library configuration. --- CMakeLists.txt | 13 +- cmake/BuildICU.cmake | 101 +++++ data/locale/en-US.ini | 12 +- src/tests/localvocal-offline-test.cpp | 5 +- src/transcription-filter-callbacks.cpp | 52 ++- src/transcription-filter-data.h | 13 +- src/transcription-filter-properties.cpp | 127 +++++- src/transcription-filter.cpp | 82 +--- .../translation-language-utils.cpp | 33 ++ src/translation/translation-language-utils.h | 8 + src/translation/translation.cpp | 64 ++- src/translation/translation.h | 9 +- src/whisper-utils/token-buffer-thread.cpp | 92 ++++- src/whisper-utils/token-buffer-thread.h | 28 +- src/whisper-utils/vad-processing.cpp | 377 ++++++++++++++++++ src/whisper-utils/vad-processing.h | 18 + src/whisper-utils/whisper-processing.cpp | 279 ++----------- src/whisper-utils/whisper-processing.h | 4 +- src/whisper-utils/whisper-utils.cpp | 21 +- 19 files changed, 927 insertions(+), 411 deletions(-) create mode 100644 cmake/BuildICU.cmake create mode 100644 src/translation/translation-language-utils.cpp create mode 100644 src/translation/translation-language-utils.h create mode 100644 src/whisper-utils/vad-processing.cpp create mode 100644 src/whisper-utils/vad-processing.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 7eff873..e64f45c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -96,6 +96,11 @@ else() include(cmake/FetchOnnxruntime.cmake) endif() +include(cmake/BuildICU.cmake) +# Add ICU to the target +target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ICU) +target_include_directories(${CMAKE_PROJECT_NAME} SYSTEM PUBLIC ${ICU_INCLUDE_DIR}) + target_sources( ${CMAKE_PROJECT_NAME} PRIVATE src/plugin-main.c @@ -114,9 +119,11 @@ target_sources( src/whisper-utils/whisper-model-utils.cpp src/whisper-utils/silero-vad-onnx.cpp src/whisper-utils/token-buffer-thread.cpp + src/whisper-utils/vad-processing.cpp src/translation/language_codes.cpp src/translation/translation.cpp src/translation/translation-utils.cpp + src/translation/translation-language-utils.cpp src/ui/filter-replace-dialog.cpp) set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name}) @@ -137,12 +144,14 @@ if(ENABLE_TESTS) src/whisper-utils/whisper-utils.cpp src/whisper-utils/silero-vad-onnx.cpp src/whisper-utils/token-buffer-thread.cpp + src/whisper-utils/vad-processing.cpp src/translation/language_codes.cpp - src/translation/translation.cpp) + src/translation/translation.cpp + src/translation/translation-language-utils.cpp) find_libav(${CMAKE_PROJECT_NAME}-tests) - target_link_libraries(${CMAKE_PROJECT_NAME}-tests PRIVATE ct2 sentencepiece Whispercpp Ort OBS::libobs) + target_link_libraries(${CMAKE_PROJECT_NAME}-tests PRIVATE ct2 sentencepiece Whispercpp Ort OBS::libobs ICU) target_include_directories(${CMAKE_PROJECT_NAME}-tests PRIVATE src) # install the tests to the release/test directory diff --git a/cmake/BuildICU.cmake b/cmake/BuildICU.cmake new file mode 100644 index 0000000..a3c575d --- /dev/null +++ b/cmake/BuildICU.cmake @@ -0,0 +1,101 @@ +include(FetchContent) +include(ExternalProject) + +set(ICU_VERSION "75.1") +set(ICU_VERSION_UNDERSCORE "75_1") +set(ICU_VERSION_DASH "75-1") +set(ICU_VERSION_NO_MINOR "75") + +if(WIN32) + set(ICU_URL + "https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION_DASH}/icu4c-${ICU_VERSION_UNDERSCORE}-Win64-MSVC2022.zip" + ) + set(ICU_HASH "SHA256=7ac9c0dc6ccc1ec809c7d5689b8d831c5b8f6b11ecf70fdccc55f7ae8731ac8f") + + FetchContent_Declare( + ICU_build + URL ${ICU_URL} + URL_HASH ${ICU_HASH}) + + FetchContent_MakeAvailable(ICU_build) + + # Assuming the ZIP structure, adjust paths as necessary + set(ICU_INCLUDE_DIR "${icu_build_SOURCE_DIR}/include") + set(ICU_LIBRARY_DIR "${icu_build_SOURCE_DIR}/lib64") + set(ICU_BINARY_DIR "${icu_build_SOURCE_DIR}/bin64") + + # Define the library names + set(ICU_LIBRARIES icudt icuuc icuin) + + foreach(lib ${ICU_LIBRARIES}) + # Add ICU library + find_library( + ICU_LIB_${lib} + NAMES ${lib} + PATHS ${ICU_LIBRARY_DIR} + NO_DEFAULT_PATH REQUIRED) + # find the dll + find_file( + ICU_DLL_${lib} + NAMES ${lib}${ICU_VERSION_NO_MINOR}.dll + PATHS ${ICU_BINARY_DIR} + NO_DEFAULT_PATH) + # Copy the DLLs to the output directory + install(FILES ${ICU_DLL_${lib}} DESTINATION "obs-plugins/64bit") + # add the library + add_library(ICU::${lib} SHARED IMPORTED GLOBAL) + set_target_properties(ICU::${lib} PROPERTIES IMPORTED_LOCATION "${ICU_LIB_${lib}}" IMPORTED_IMPLIB + "${ICU_LIB_${lib}}") + endforeach() +else() + set(ICU_URL + "https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION_DASH}/icu4c-${ICU_VERSION_UNDERSCORE}-src.tgz" + ) + set(ICU_HASH "SHA256=cb968df3e4d2e87e8b11c49a5d01c787bd13b9545280fc6642f826527618caef") + if(APPLE) + set(ICU_PLATFORM "MacOSX") + set(TARGET_ARCH -arch\ $ENV{MACOS_ARCH}) + set(ICU_BUILD_ENV_VARS CFLAGS=${TARGET_ARCH} CXXFLAGS=${TARGET_ARCH} LDFLAGS=${TARGET_ARCH}) + else() + set(ICU_PLATFORM "Linux") + set(ICU_BUILD_ENV_VARS CFLAGS=-fPIC CXXFLAGS=-fPIC LDFLAGS=-fPIC) + endif() + + ExternalProject_Add( + ICU_build + DOWNLOAD_EXTRACT_TIMESTAMP true + GIT_REPOSITORY "https://github.com/unicode-org/icu.git" + GIT_TAG "release-${ICU_VERSION_DASH}" + CONFIGURE_COMMAND ${CMAKE_COMMAND} -E env ${ICU_BUILD_ENV_VARS} /icu4c/source/runConfigureICU + ${ICU_PLATFORM} --prefix= --enable-static --disable-shared + BUILD_COMMAND make -j4 + BUILD_BYPRODUCTS + /lib/${CMAKE_STATIC_LIBRARY_PREFIX}icudata${CMAKE_STATIC_LIBRARY_SUFFIX} + /lib/${CMAKE_STATIC_LIBRARY_PREFIX}icuuc${CMAKE_STATIC_LIBRARY_SUFFIX} + /lib/${CMAKE_STATIC_LIBRARY_PREFIX}icui18n${CMAKE_STATIC_LIBRARY_SUFFIX} + INSTALL_COMMAND make install + BUILD_IN_SOURCE 1) + + ExternalProject_Get_Property(ICU_build INSTALL_DIR) + + set(ICU_INCLUDE_DIR "${INSTALL_DIR}/include") + set(ICU_LIBRARY_DIR "${INSTALL_DIR}/lib") + + set(ICU_LIBRARIES icudata icuuc icui18n) + + foreach(lib ${ICU_LIBRARIES}) + add_library(ICU::${lib} STATIC IMPORTED GLOBAL) + add_dependencies(ICU::${lib} ICU_build) + set(ICU_LIBRARY "${ICU_LIBRARY_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}${lib}${CMAKE_STATIC_LIBRARY_SUFFIX}") + set_target_properties(ICU::${lib} PROPERTIES IMPORTED_LOCATION "${ICU_LIBRARY}" INTERFACE_INCLUDE_DIRECTORIES + "${ICU_INCLUDE_DIR}") + endforeach(lib ${ICU_LIBRARIES}) +endif() + +# Create an interface target for ICU +add_library(ICU INTERFACE) +add_dependencies(ICU ICU_build) +foreach(lib ${ICU_LIBRARIES}) + target_link_libraries(ICU INTERFACE ICU::${lib}) +endforeach() +target_include_directories(ICU SYSTEM INTERFACE $) diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini index 0f7b661..9ef4d18 100644 --- a/data/locale/en-US.ini +++ b/data/locale/en-US.ini @@ -1,12 +1,9 @@ LocalVocalPlugin="LocalVocal Plugin" transcription_filterAudioFilter="LocalVocal Transcription" -vad_enabled="VAD Enabled" vad_threshold="VAD Threshold" log_level="Internal Log Level" log_words="Log Output to Console" caption_to_stream="Stream Captions" -step_by_step_processing="Step-by-step processing (⚠️ increased processing)" -step_size_msec="Step size (ms)" subtitle_sources="Output Destination" none_no_output="None / No output" file_output_enable="Save to File" @@ -51,7 +48,6 @@ translate="Translation" translate_add_context="Translate with context" whisper_translate="Translate to English (Whisper)" buffer_size_msec="Buffer size (ms)" -overlap_size_msec="Overlap size (ms)" suppress_sentences="Suppress sentences (each line)" translate_output="Output Destination" dtw_token_timestamps="DTW token timestamps" @@ -85,4 +81,10 @@ buffered_output_parameters="Buffered Output Configuration" file_output_info="Note: Translation output will be saved to a file in the same directory with the target language added to the name, e.g. 'output_es.srt'." partial_transcription="Enable Partial Transcription" partial_transcription_info="Partial transcription will increase processing load on your machine to transcribe content in real-time, which may impact performance." -partial_latency="Latency (ms)" \ No newline at end of file +partial_latency="Latency (ms)" +vad_mode="VAD Mode" +Active_VAD="Active VAD" +Hybrid_VAD="Hybrid VAD" +translate_only_full_sentences="Translate only full sentences" +duration_filter_threshold="Duration filter" +segment_duration="Segment duration" \ No newline at end of file diff --git a/src/tests/localvocal-offline-test.cpp b/src/tests/localvocal-offline-test.cpp index 8fec08b..ee936af 100644 --- a/src/tests/localvocal-offline-test.cpp +++ b/src/tests/localvocal-offline-test.cpp @@ -17,6 +17,7 @@ #include "transcription-filter.h" #include "transcription-utils.h" #include "whisper-utils/whisper-utils.h" +#include "whisper-utils/vad-processing.h" #include "audio-file-utils.h" #include "translation/language_codes.h" @@ -148,7 +149,7 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p // }, // 30, std::chrono::seconds(10)); - gf->vad_enabled = true; + gf->vad_mode = VAD_MODE_ACTIVE; gf->log_words = true; gf->caption_to_stream = false; gf->start_timestamp_ms = now_ms(); @@ -157,7 +158,7 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p gf->buffered_output = false; gf->target_lang = ""; - gf->translation_ctx.add_context = true; + gf->translation_ctx.add_context = 1; gf->translation_output = ""; gf->translate = false; gf->sentence_psum_accept_thresh = 0.4; diff --git a/src/transcription-filter-callbacks.cpp b/src/transcription-filter-callbacks.cpp index f5c2209..7b8208f 100644 --- a/src/transcription-filter-callbacks.cpp +++ b/src/transcription-filter-callbacks.cpp @@ -53,8 +53,8 @@ std::string send_sentence_to_translation(const std::string &sentence, struct transcription_filter_data *gf, const std::string &source_language) { - const std::string last_text = gf->last_text; - gf->last_text = sentence; + const std::string last_text = gf->last_text_for_translation; + gf->last_text_for_translation = sentence; if (gf->translate && !sentence.empty()) { obs_log(gf->log_level, "Translating text. %s -> %s", source_language.c_str(), gf->target_lang.c_str()); @@ -199,11 +199,6 @@ void set_text_callback(struct transcription_filter_data *gf, const DetectionResultWithText &resultIn) { DetectionResultWithText result = resultIn; - if (!result.text.empty() && (result.result == DETECTION_RESULT_SPEECH || - result.result == DETECTION_RESULT_PARTIAL)) { - gf->last_sub_render_time = now_ms(); - gf->cleared_last_sub = false; - } std::string str_copy = result.text; @@ -233,9 +228,12 @@ void set_text_callback(struct transcription_filter_data *gf, } } + bool should_translate = + gf->translate_only_full_sentences ? result.result == DETECTION_RESULT_SPEECH : true; + // send the sentence to translation (if enabled) std::string translated_sentence = - send_sentence_to_translation(str_copy, gf, result.language); + should_translate ? send_sentence_to_translation(str_copy, gf, result.language) : ""; if (gf->translate) { if (gf->translation_output == "none") { @@ -243,10 +241,12 @@ void set_text_callback(struct transcription_filter_data *gf, str_copy = translated_sentence; } else { if (gf->buffered_output) { - if (result.result == DETECTION_RESULT_SPEECH) { - // buffered output - add the sentence to the monitor - gf->translation_monitor.addSentence(translated_sentence); - } + // buffered output - add the sentence to the monitor + gf->translation_monitor.addSentenceFromStdString( + translated_sentence, + get_time_point_from_ms(result.start_timestamp_ms), + get_time_point_from_ms(result.end_timestamp_ms), + result.result == DETECTION_RESULT_PARTIAL); } else { // non-buffered output - send the sentence to the selected source send_caption_to_source(gf->translation_output, translated_sentence, @@ -256,9 +256,10 @@ void set_text_callback(struct transcription_filter_data *gf, } if (gf->buffered_output) { - if (result.result == DETECTION_RESULT_SPEECH) { - gf->captions_monitor.addSentence(str_copy); - } + gf->captions_monitor.addSentenceFromStdString( + str_copy, get_time_point_from_ms(result.start_timestamp_ms), + get_time_point_from_ms(result.end_timestamp_ms), + result.result == DETECTION_RESULT_PARTIAL); } else { // non-buffered output - send the sentence to the selected source send_caption_to_source(gf->text_source_name, str_copy, gf); @@ -273,6 +274,21 @@ void set_text_callback(struct transcription_filter_data *gf, result.result == DETECTION_RESULT_SPEECH) { send_sentence_to_file(gf, result, str_copy, translated_sentence); } + + if (!result.text.empty() && (result.result == DETECTION_RESULT_SPEECH || + result.result == DETECTION_RESULT_PARTIAL)) { + gf->last_sub_render_time = now_ms(); + gf->cleared_last_sub = false; + if (result.result == DETECTION_RESULT_SPEECH) { + // save the last subtitle if it was a full sentence + gf->last_transcription_sentence.push_back(result.text); + // remove the oldest sentence if the buffer is too long + while (gf->last_transcription_sentence.size() > + (size_t)gf->n_context_sentences) { + gf->last_transcription_sentence.pop_front(); + } + } + } }; void recording_state_callback(enum obs_frontend_event event, void *data) @@ -314,6 +330,12 @@ void reset_caption_state(transcription_filter_data *gf_) } send_caption_to_source(gf_->text_source_name, "", gf_); send_caption_to_source(gf_->translation_output, "", gf_); + // reset translation context + gf_->last_text_for_translation = ""; + gf_->last_text_translation = ""; + gf_->translation_ctx.last_input_tokens.clear(); + gf_->translation_ctx.last_translation_tokens.clear(); + gf_->last_transcription_sentence.clear(); // flush the buffer { std::lock_guard lock(gf_->whisper_buf_mutex); diff --git a/src/transcription-filter-data.h b/src/transcription-filter-data.h index 4b16d13..e1af694 100644 --- a/src/transcription-filter-data.h +++ b/src/transcription-filter-data.h @@ -36,6 +36,8 @@ struct transcription_filter_data { size_t sentence_number; // Minimal subtitle duration in ms size_t min_sub_duration; + // Maximal subtitle duration in ms + size_t max_sub_duration; // Last time a subtitle was rendered uint64_t last_sub_render_time; bool cleared_last_sub; @@ -62,7 +64,7 @@ struct transcription_filter_data { float sentence_psum_accept_thresh; bool do_silence; - bool vad_enabled; + int vad_mode; int log_level = LOG_DEBUG; bool log_words; bool caption_to_stream; @@ -84,11 +86,17 @@ struct transcription_filter_data { bool initial_creation = true; bool partial_transcription = false; int partial_latency = 1000; + float duration_filter_threshold = 2.25f; + int segment_duration = 7000; // Last transcription result - std::string last_text; + std::string last_text_for_translation; std::string last_text_translation; + // Transcription context sentences + int n_context_sentences; + std::deque last_transcription_sentence; + // Text source to output the subtitles std::string text_source_name; // Callback to set the text in the output text source (subtitles) @@ -110,6 +118,7 @@ struct transcription_filter_data { struct translation_context translation_ctx; std::string translation_model_index; std::string translation_model_path_external; + bool translate_only_full_sentences; bool buffered_output = false; TokenBufferThread captions_monitor; diff --git a/src/transcription-filter-properties.cpp b/src/transcription-filter-properties.cpp index 523bbf8..4a3693f 100644 --- a/src/transcription-filter-properties.cpp +++ b/src/transcription-filter-properties.cpp @@ -7,6 +7,7 @@ #include "transcription-filter.h" #include "transcription-filter-utils.h" #include "whisper-utils/whisper-language.h" +#include "whisper-utils/vad-processing.h" #include "model-utils/model-downloader-types.h" #include "translation/language_codes.h" #include "ui/filter-replace-dialog.h" @@ -212,8 +213,12 @@ void add_translation_group_properties(obs_properties_t *ppts) obs_property_t *prop_tgt = obs_properties_add_list( translation_group, "translate_target_language", MT_("target_language"), OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING); - obs_properties_add_bool(translation_group, "translate_add_context", - MT_("translate_add_context")); + + // add slider for number of context lines to add to the translation + obs_properties_add_int_slider(translation_group, "translate_add_context", + MT_("translate_add_context"), 0, 5, 1); + obs_properties_add_bool(translation_group, "translate_only_full_sentences", + MT_("translate_only_full_sentences")); // Populate the dropdown with the language codes for (const auto &language : language_codes) { @@ -290,6 +295,31 @@ void add_buffered_output_group_properties(obs_properties_t *ppts) OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT); obs_property_list_add_int(buffer_type_list, "Character", SEGMENTATION_TOKEN); obs_property_list_add_int(buffer_type_list, "Word", SEGMENTATION_WORD); + obs_property_list_add_int(buffer_type_list, "Sentence", SEGMENTATION_SENTENCE); + // add callback to the segmentation selection to set default values + obs_property_set_modified_callback(buffer_type_list, [](obs_properties_t *props, + obs_property_t *property, + obs_data_t *settings) { + UNUSED_PARAMETER(property); + UNUSED_PARAMETER(props); + const int segmentation_type = (int)obs_data_get_int(settings, "buffer_output_type"); + // set default values for the number of lines and characters per line + switch (segmentation_type) { + case SEGMENTATION_TOKEN: + obs_data_set_int(settings, "buffer_num_lines", 2); + obs_data_set_int(settings, "buffer_num_chars_per_line", 30); + break; + case SEGMENTATION_WORD: + obs_data_set_int(settings, "buffer_num_lines", 2); + obs_data_set_int(settings, "buffer_num_chars_per_line", 10); + break; + case SEGMENTATION_SENTENCE: + obs_data_set_int(settings, "buffer_num_lines", 2); + obs_data_set_int(settings, "buffer_num_chars_per_line", 2); + break; + } + return true; + }); // add buffer lines parameter obs_properties_add_int_slider(buffered_output_group, "buffer_num_lines", MT_("buffer_num_lines"), 1, 5, 1); @@ -310,16 +340,29 @@ void add_advanced_group_properties(obs_properties_t *ppts, struct transcription_ obs_properties_add_int_slider(advanced_config_group, "min_sub_duration", MT_("min_sub_duration"), 1000, 5000, 50); + obs_properties_add_int_slider(advanced_config_group, "max_sub_duration", + MT_("max_sub_duration"), 1000, 5000, 50); obs_properties_add_float_slider(advanced_config_group, "sentence_psum_accept_thresh", MT_("sentence_psum_accept_thresh"), 0.0, 1.0, 0.05); obs_properties_add_bool(advanced_config_group, "process_while_muted", MT_("process_while_muted")); - obs_properties_add_bool(advanced_config_group, "vad_enabled", MT_("vad_enabled")); + // add selection for Active VAD vs Hybrid VAD + obs_property_t *vad_mode_list = + obs_properties_add_list(advanced_config_group, "vad_mode", MT_("vad_mode"), + OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT); + obs_property_list_add_int(vad_mode_list, MT_("Active_VAD"), VAD_MODE_ACTIVE); + obs_property_list_add_int(vad_mode_list, MT_("Hybrid_VAD"), VAD_MODE_HYBRID); // add vad threshold slider obs_properties_add_float_slider(advanced_config_group, "vad_threshold", MT_("vad_threshold"), 0.0, 1.0, 0.05); + // add duration filter threshold slider + obs_properties_add_float_slider(advanced_config_group, "duration_filter_threshold", + MT_("duration_filter_threshold"), 0.1, 3.0, 0.05); + // add segment duration slider + obs_properties_add_int_slider(advanced_config_group, "segment_duration", + MT_("segment_duration"), 3000, 15000, 100); // add button to open filter and replace UI dialog obs_properties_add_button2( @@ -371,6 +414,10 @@ void add_whisper_params_group_properties(obs_properties_t *ppts) WHISPER_SAMPLING_BEAM_SEARCH); obs_property_list_add_int(whisper_sampling_method_list, "Greedy", WHISPER_SAMPLING_GREEDY); + // add int slider for context sentences + obs_properties_add_int_slider(whisper_params_group, "n_context_sentences", + MT_("n_context_sentences"), 0, 5, 1); + // int n_threads; obs_properties_add_int_slider(whisper_params_group, "n_threads", MT_("n_threads"), 1, 8, 1); // int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder @@ -507,3 +554,77 @@ obs_properties_t *transcription_filter_properties(void *data) UNUSED_PARAMETER(data); return ppts; } + +void transcription_filter_defaults(obs_data_t *s) +{ + obs_log(LOG_DEBUG, "filter defaults"); + + obs_data_set_default_bool(s, "buffered_output", false); + obs_data_set_default_int(s, "buffer_num_lines", 2); + obs_data_set_default_int(s, "buffer_num_chars_per_line", 30); + obs_data_set_default_int(s, "buffer_output_type", + (int)TokenBufferSegmentation::SEGMENTATION_TOKEN); + + obs_data_set_default_bool(s, "vad_mode", VAD_MODE_ACTIVE); + obs_data_set_default_double(s, "vad_threshold", 0.65); + obs_data_set_default_double(s, "duration_filter_threshold", 2.25); + obs_data_set_default_int(s, "segment_duration", 7000); + obs_data_set_default_int(s, "log_level", LOG_DEBUG); + obs_data_set_default_bool(s, "log_words", false); + obs_data_set_default_bool(s, "caption_to_stream", false); + obs_data_set_default_string(s, "whisper_model_path", "Whisper Tiny English (74Mb)"); + obs_data_set_default_string(s, "whisper_language_select", "en"); + obs_data_set_default_string(s, "subtitle_sources", "none"); + obs_data_set_default_bool(s, "process_while_muted", false); + obs_data_set_default_bool(s, "subtitle_save_srt", false); + obs_data_set_default_bool(s, "truncate_output_file", false); + obs_data_set_default_bool(s, "only_while_recording", false); + obs_data_set_default_bool(s, "rename_file_to_match_recording", true); + obs_data_set_default_int(s, "min_sub_duration", 1000); + obs_data_set_default_int(s, "max_sub_duration", 3000); + obs_data_set_default_bool(s, "advanced_settings", false); + obs_data_set_default_bool(s, "translate", false); + obs_data_set_default_string(s, "translate_target_language", "__es__"); + obs_data_set_default_int(s, "translate_add_context", 1); + obs_data_set_default_bool(s, "translate_only_full_sentences", true); + obs_data_set_default_string(s, "translate_model", "whisper-based-translation"); + obs_data_set_default_string(s, "translation_model_path_external", ""); + obs_data_set_default_int(s, "translate_input_tokenization_style", INPUT_TOKENIZAION_M2M100); + obs_data_set_default_double(s, "sentence_psum_accept_thresh", 0.4); + obs_data_set_default_bool(s, "partial_group", false); + obs_data_set_default_int(s, "partial_latency", 1100); + + // translation options + obs_data_set_default_double(s, "translation_sampling_temperature", 0.1); + obs_data_set_default_double(s, "translation_repetition_penalty", 2.0); + obs_data_set_default_int(s, "translation_beam_size", 1); + obs_data_set_default_int(s, "translation_max_decoding_length", 65); + obs_data_set_default_int(s, "translation_no_repeat_ngram_size", 1); + obs_data_set_default_int(s, "translation_max_input_length", 65); + + // Whisper parameters + obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH); + obs_data_set_default_int(s, "n_context_sentences", 0); + obs_data_set_default_string(s, "initial_prompt", ""); + obs_data_set_default_int(s, "n_threads", 4); + obs_data_set_default_int(s, "n_max_text_ctx", 16384); + obs_data_set_default_bool(s, "whisper_translate", false); + obs_data_set_default_bool(s, "no_context", true); + obs_data_set_default_bool(s, "single_segment", true); + obs_data_set_default_bool(s, "print_special", false); + obs_data_set_default_bool(s, "print_progress", false); + obs_data_set_default_bool(s, "print_realtime", false); + obs_data_set_default_bool(s, "print_timestamps", false); + obs_data_set_default_bool(s, "token_timestamps", false); + obs_data_set_default_bool(s, "dtw_token_timestamps", false); + obs_data_set_default_double(s, "thold_pt", 0.01); + obs_data_set_default_double(s, "thold_ptsum", 0.01); + obs_data_set_default_int(s, "max_len", 0); + obs_data_set_default_bool(s, "split_on_word", true); + obs_data_set_default_int(s, "max_tokens", 0); + obs_data_set_default_bool(s, "suppress_blank", false); + obs_data_set_default_bool(s, "suppress_non_speech_tokens", true); + obs_data_set_default_double(s, "temperature", 0.1); + obs_data_set_default_double(s, "max_initial_ts", 1.0); + obs_data_set_default_double(s, "length_penalty", -1.0); +} diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index 3683c18..657fea6 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -174,7 +174,7 @@ void transcription_filter_update(void *data, obs_data_t *s) obs_log(gf->log_level, "LocalVocal filter update"); gf->log_level = (int)obs_data_get_int(s, "log_level"); - gf->vad_enabled = obs_data_get_bool(s, "vad_enabled"); + gf->vad_mode = (int)obs_data_get_int(s, "vad_mode"); gf->log_words = obs_data_get_bool(s, "log_words"); gf->caption_to_stream = obs_data_get_bool(s, "caption_to_stream"); gf->save_to_file = obs_data_get_bool(s, "file_output_enable"); @@ -187,7 +187,10 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->sentence_number = 1; gf->process_while_muted = obs_data_get_bool(s, "process_while_muted"); gf->min_sub_duration = (int)obs_data_get_int(s, "min_sub_duration"); + gf->max_sub_duration = (int)obs_data_get_int(s, "max_sub_duration"); gf->last_sub_render_time = now_ms(); + gf->duration_filter_threshold = (float)obs_data_get_double(s, "duration_filter_threshold"); + gf->segment_duration = (int)obs_data_get_int(s, "segment_duration"); gf->partial_transcription = obs_data_get_bool(s, "partial_group"); gf->partial_latency = (int)obs_data_get_int(s, "partial_latency"); bool new_buffered_output = obs_data_get_bool(s, "buffered_output"); @@ -281,9 +284,10 @@ void transcription_filter_update(void *data, obs_data_t *s) bool new_translate = obs_data_get_bool(s, "translate"); gf->target_lang = obs_data_get_string(s, "translate_target_language"); - gf->translation_ctx.add_context = obs_data_get_bool(s, "translate_add_context"); + gf->translation_ctx.add_context = (int)obs_data_get_int(s, "translate_add_context"); gf->translation_ctx.input_tokenization_style = (InputTokenizationStyle)obs_data_get_int(s, "translate_input_tokenization_style"); + gf->translate_only_full_sentences = obs_data_get_bool(s, "translate_only_full_sentences"); gf->translation_output = obs_data_get_string(s, "translate_output"); std::string new_translate_model_index = obs_data_get_string(s, "translate_model"); std::string new_translation_model_path_external = @@ -342,6 +346,8 @@ void transcription_filter_update(void *data, obs_data_t *s) { std::lock_guard lock(gf->whisper_ctx_mutex); + gf->n_context_sentences = (int)obs_data_get_int(s, "n_context_sentences"); + gf->sentence_psum_accept_thresh = (float)obs_data_get_double(s, "sentence_psum_accept_thresh"); @@ -390,7 +396,7 @@ void transcription_filter_update(void *data, obs_data_t *s) gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts"); gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty"); - if (gf->vad_enabled && gf->vad) { + if (gf->vad) { const float vad_threshold = (float)obs_data_get_double(s, "vad_threshold"); gf->vad->set_threshold(vad_threshold); } @@ -431,6 +437,7 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter) gf->frames = (size_t)((float)gf->sample_rate / (1000.0f / MAX_MS_WORK_BUFFER)); gf->last_num_frames = 0; gf->min_sub_duration = (int)obs_data_get_int(settings, "min_sub_duration"); + gf->max_sub_duration = (int)obs_data_get_int(settings, "max_sub_duration"); gf->last_sub_render_time = now_ms(); gf->log_level = (int)obs_data_get_int(settings, "log_level"); gf->save_srt = obs_data_get_bool(settings, "subtitle_save_srt"); @@ -551,72 +558,3 @@ void transcription_filter_hide(void *data) static_cast(data); obs_log(gf->log_level, "filter hide"); } - -void transcription_filter_defaults(obs_data_t *s) -{ - obs_log(LOG_DEBUG, "filter defaults"); - - obs_data_set_default_bool(s, "buffered_output", false); - obs_data_set_default_int(s, "buffer_num_lines", 2); - obs_data_set_default_int(s, "buffer_num_chars_per_line", 30); - obs_data_set_default_int(s, "buffer_output_type", - (int)TokenBufferSegmentation::SEGMENTATION_TOKEN); - - obs_data_set_default_bool(s, "vad_enabled", true); - obs_data_set_default_double(s, "vad_threshold", 0.65); - obs_data_set_default_int(s, "log_level", LOG_DEBUG); - obs_data_set_default_bool(s, "log_words", false); - obs_data_set_default_bool(s, "caption_to_stream", false); - obs_data_set_default_string(s, "whisper_model_path", "Whisper Tiny English (74Mb)"); - obs_data_set_default_string(s, "whisper_language_select", "en"); - obs_data_set_default_string(s, "subtitle_sources", "none"); - obs_data_set_default_bool(s, "process_while_muted", false); - obs_data_set_default_bool(s, "subtitle_save_srt", false); - obs_data_set_default_bool(s, "truncate_output_file", false); - obs_data_set_default_bool(s, "only_while_recording", false); - obs_data_set_default_bool(s, "rename_file_to_match_recording", true); - obs_data_set_default_int(s, "min_sub_duration", 3000); - obs_data_set_default_bool(s, "advanced_settings", false); - obs_data_set_default_bool(s, "translate", false); - obs_data_set_default_string(s, "translate_target_language", "__es__"); - obs_data_set_default_bool(s, "translate_add_context", true); - obs_data_set_default_string(s, "translate_model", "whisper-based-translation"); - obs_data_set_default_string(s, "translation_model_path_external", ""); - obs_data_set_default_int(s, "translate_input_tokenization_style", INPUT_TOKENIZAION_M2M100); - obs_data_set_default_double(s, "sentence_psum_accept_thresh", 0.4); - obs_data_set_default_bool(s, "partial_group", false); - obs_data_set_default_int(s, "partial_latency", 1100); - - // translation options - obs_data_set_default_double(s, "translation_sampling_temperature", 0.1); - obs_data_set_default_double(s, "translation_repetition_penalty", 2.0); - obs_data_set_default_int(s, "translation_beam_size", 1); - obs_data_set_default_int(s, "translation_max_decoding_length", 65); - obs_data_set_default_int(s, "translation_no_repeat_ngram_size", 1); - obs_data_set_default_int(s, "translation_max_input_length", 65); - - // Whisper parameters - obs_data_set_default_int(s, "whisper_sampling_method", WHISPER_SAMPLING_BEAM_SEARCH); - obs_data_set_default_string(s, "initial_prompt", ""); - obs_data_set_default_int(s, "n_threads", 4); - obs_data_set_default_int(s, "n_max_text_ctx", 16384); - obs_data_set_default_bool(s, "whisper_translate", false); - obs_data_set_default_bool(s, "no_context", true); - obs_data_set_default_bool(s, "single_segment", true); - obs_data_set_default_bool(s, "print_special", false); - obs_data_set_default_bool(s, "print_progress", false); - obs_data_set_default_bool(s, "print_realtime", false); - obs_data_set_default_bool(s, "print_timestamps", false); - obs_data_set_default_bool(s, "token_timestamps", false); - obs_data_set_default_bool(s, "dtw_token_timestamps", false); - obs_data_set_default_double(s, "thold_pt", 0.01); - obs_data_set_default_double(s, "thold_ptsum", 0.01); - obs_data_set_default_int(s, "max_len", 0); - obs_data_set_default_bool(s, "split_on_word", true); - obs_data_set_default_int(s, "max_tokens", 0); - obs_data_set_default_bool(s, "suppress_blank", false); - obs_data_set_default_bool(s, "suppress_non_speech_tokens", true); - obs_data_set_default_double(s, "temperature", 0.1); - obs_data_set_default_double(s, "max_initial_ts", 1.0); - obs_data_set_default_double(s, "length_penalty", -1.0); -} diff --git a/src/translation/translation-language-utils.cpp b/src/translation/translation-language-utils.cpp new file mode 100644 index 0000000..685ca1a --- /dev/null +++ b/src/translation/translation-language-utils.cpp @@ -0,0 +1,33 @@ +#include "translation-language-utils.h" + +#include +#include + +std::string remove_start_punctuation(const std::string &text) +{ + if (text.empty()) { + return text; + } + + // Convert the input string to ICU's UnicodeString + icu::UnicodeString ustr = icu::UnicodeString::fromUTF8(text); + + // Find the index of the first non-punctuation character + int32_t start = 0; + while (start < ustr.length()) { + UChar32 ch = ustr.char32At(start); + if (!u_ispunct(ch)) { + break; + } + start += U16_LENGTH(ch); + } + + // Create a new UnicodeString with punctuation removed from the start + icu::UnicodeString result = ustr.tempSubString(start); + + // Convert the result back to UTF-8 + std::string output; + result.toUTF8String(output); + + return output; +} diff --git a/src/translation/translation-language-utils.h b/src/translation/translation-language-utils.h new file mode 100644 index 0000000..44b450a --- /dev/null +++ b/src/translation/translation-language-utils.h @@ -0,0 +1,8 @@ +#ifndef TRANSLATION_LANGUAGE_UTILS_H +#define TRANSLATION_LANGUAGE_UTILS_H + +#include + +std::string remove_start_punctuation(const std::string &text); + +#endif // TRANSLATION_LANGUAGE_UTILS_H \ No newline at end of file diff --git a/src/translation/translation.cpp b/src/translation/translation.cpp index e11f072..0701d95 100644 --- a/src/translation/translation.cpp +++ b/src/translation/translation.cpp @@ -3,6 +3,7 @@ #include "model-utils/model-find-utils.h" #include "transcription-filter-data.h" #include "language_codes.h" +#include "translation-language-utils.h" #include #include @@ -114,31 +115,53 @@ int translate(struct translation_context &translation_ctx, const std::string &te if (translation_ctx.input_tokenization_style == INPUT_TOKENIZAION_M2M100) { // set input tokens std::vector input_tokens = {source_lang, ""}; - if (translation_ctx.add_context && + if (translation_ctx.add_context > 0 && translation_ctx.last_input_tokens.size() > 0) { - input_tokens.insert(input_tokens.end(), - translation_ctx.last_input_tokens.begin(), - translation_ctx.last_input_tokens.end()); + // add the last input tokens sentences to the input tokens + for (const auto &tokens : translation_ctx.last_input_tokens) { + input_tokens.insert(input_tokens.end(), tokens.begin(), + tokens.end()); + } } std::vector new_input_tokens = translation_ctx.tokenizer(text); input_tokens.insert(input_tokens.end(), new_input_tokens.begin(), new_input_tokens.end()); input_tokens.push_back(""); - translation_ctx.last_input_tokens = new_input_tokens; + // log the input tokens + std::string input_tokens_str; + for (const auto &token : input_tokens) { + input_tokens_str += token + ", "; + } + obs_log(LOG_INFO, "Input tokens: %s", input_tokens_str.c_str()); + + translation_ctx.last_input_tokens.push_back(new_input_tokens); + // remove the oldest input tokens + while (translation_ctx.last_input_tokens.size() > + (size_t)translation_ctx.add_context) { + translation_ctx.last_input_tokens.pop_front(); + } const std::vector> batch = {input_tokens}; // get target prefix target_prefix = {target_lang}; - if (translation_ctx.add_context && + // add the last translation tokens to the target prefix + if (translation_ctx.add_context > 0 && translation_ctx.last_translation_tokens.size() > 0) { - target_prefix.insert( - target_prefix.end(), - translation_ctx.last_translation_tokens.begin(), - translation_ctx.last_translation_tokens.end()); + for (const auto &tokens : translation_ctx.last_translation_tokens) { + target_prefix.insert(target_prefix.end(), tokens.begin(), + tokens.end()); + } } + // log the target prefix + std::string target_prefix_str; + for (const auto &token : target_prefix) { + target_prefix_str += token + ","; + } + obs_log(LOG_INFO, "Target prefix: %s", target_prefix_str.c_str()); + const std::vector> target_prefix_batch = { target_prefix}; results = translation_ctx.translator->translate_batch( @@ -161,9 +184,26 @@ int translate(struct translation_context &translation_ctx, const std::string &te std::vector translation_tokens( tokens_result.begin() + target_prefix.size(), tokens_result.end()); - translation_ctx.last_translation_tokens = translation_tokens; + // log the translation tokens + std::string translation_tokens_str; + for (const auto &token : translation_tokens) { + translation_tokens_str += token + ", "; + } + obs_log(LOG_INFO, "Translation tokens: %s", translation_tokens_str.c_str()); + + // save the translation tokens + translation_ctx.last_translation_tokens.push_back(translation_tokens); + // remove the oldest translation tokens + while (translation_ctx.last_translation_tokens.size() > + (size_t)translation_ctx.add_context) { + translation_ctx.last_translation_tokens.pop_front(); + } + obs_log(LOG_INFO, "Last translation tokens deque size: %d", + (int)translation_ctx.last_translation_tokens.size()); + // detokenize - result = translation_ctx.detokenizer(translation_tokens); + const std::string result_ = translation_ctx.detokenizer(translation_tokens); + result = remove_start_punctuation(result_); } catch (std::exception &e) { obs_log(LOG_ERROR, "Error: %s", e.what()); return OBS_POLYGLOT_TRANSLATION_FAIL; diff --git a/src/translation/translation.h b/src/translation/translation.h index 0d45080..c740726 100644 --- a/src/translation/translation.h +++ b/src/translation/translation.h @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -25,10 +26,10 @@ struct translation_context { std::unique_ptr options; std::function(const std::string &)> tokenizer; std::function &)> detokenizer; - std::vector last_input_tokens; - std::vector last_translation_tokens; - // Use the last translation as context for the next translation - bool add_context; + std::deque> last_input_tokens; + std::deque> last_translation_tokens; + // How many sentences to use as context for the next translation + int add_context; InputTokenizationStyle input_tokenization_style; }; diff --git a/src/whisper-utils/token-buffer-thread.cpp b/src/whisper-utils/token-buffer-thread.cpp index ac34534..3e3b002 100644 --- a/src/whisper-utils/token-buffer-thread.cpp +++ b/src/whisper-utils/token-buffer-thread.cpp @@ -6,6 +6,9 @@ #include "whisper-utils.h" #include "transcription-utils.h" +#include +#include + #include #ifdef _WIN32 @@ -75,37 +78,74 @@ void TokenBufferThread::log_token_vector(const std::vector &tokens) obs_log(LOG_INFO, "TokenBufferThread::log_token_vector: '%s'", output.c_str()); } -void TokenBufferThread::addSentence(const std::string &sentence) +void TokenBufferThread::addSentenceFromStdString(const std::string &sentence, + TokenBufferTimePoint start_time, + TokenBufferTimePoint end_time, bool is_partial) { + if (sentence.empty()) { + return; + } #ifdef _WIN32 // on windows convert from multibyte to wide char int count = MultiByteToWideChar(CP_UTF8, 0, sentence.c_str(), (int)sentence.length(), NULL, 0); - std::wstring sentence_ws(count, 0); + TokenBufferString sentence_ws(count, 0); MultiByteToWideChar(CP_UTF8, 0, sentence.c_str(), (int)sentence.length(), &sentence_ws[0], count); #else - std::string sentence_ws = sentence; + TokenBufferString sentence_ws = sentence; #endif - // split to characters - std::vector characters; - for (const auto &c : sentence_ws) { - characters.push_back(TokenBufferString(1, c)); + + TokenBufferSentence sentence_for_add; + sentence_for_add.start_time = start_time; + sentence_for_add.end_time = end_time; + + if (this->segmentation == SEGMENTATION_WORD) { + // split the sentence to words + std::vector words; + std::basic_istringstream iss(sentence_ws); + TokenBufferString word_token; + while (iss >> word_token) { + words.push_back(word_token); + } + // add the words to a sentence + for (const auto &word : words) { + sentence_for_add.tokens.push_back({word, is_partial}); + sentence_for_add.tokens.push_back({SPACE, is_partial}); + } + } else if (this->segmentation == SEGMENTATION_TOKEN) { + // split to characters + std::vector characters; + for (const auto &c : sentence_ws) { + characters.push_back(TokenBufferString(1, c)); + } + // add the characters to a sentece + for (const auto &character : characters) { + sentence_for_add.tokens.push_back({character, is_partial}); + } + } else { + // add the whole sentence as a single token + sentence_for_add.tokens.push_back({sentence_ws, is_partial}); + sentence_for_add.tokens.push_back({SPACE, is_partial}); } + addSentence(sentence_for_add); +} - std::lock_guard lock(inputQueueMutex); +void TokenBufferThread::addSentence(const TokenBufferSentence &sentence) +{ + std::lock_guard lock(this->inputQueueMutex); - // add the characters to the inputQueue - for (const auto &character : characters) { + // add the tokens to the inputQueue + for (const auto &character : sentence.tokens) { inputQueue.push_back(character); } - inputQueue.push_back(SPACE); + inputQueue.push_back({SPACE, sentence.tokens.back().is_partial}); // add to the contribution queue as well - for (const auto &character : characters) { + for (const auto &character : sentence.tokens) { contributionQueue.push_back(character); } - contributionQueue.push_back(SPACE); + contributionQueue.push_back({SPACE, sentence.tokens.back().is_partial}); this->lastContributionTime = std::chrono::steady_clock::now(); } @@ -148,7 +188,7 @@ void TokenBufferThread::monitor() if (this->segmentation == SEGMENTATION_TOKEN) { // pop tokens until a space is found while (!presentationQueue.empty() && - presentationQueue.front() != SPACE) { + presentationQueue.front().token != SPACE) { presentationQueue.pop_front(); } } @@ -158,6 +198,13 @@ void TokenBufferThread::monitor() std::lock_guard lock(inputQueueMutex); if (!inputQueue.empty()) { + // if the input on the inputQueue is partial - first remove all partials + // from the end of the presentation queue + while (!presentationQueue.empty() && + presentationQueue.back().is_partial) { + presentationQueue.pop_back(); + } + // if there are token on the input queue // then add to the presentation queue based on the segmentation if (this->segmentation == SEGMENTATION_SENTENCE) { @@ -171,16 +218,17 @@ void TokenBufferThread::monitor() presentationQueue.push_back(inputQueue.front()); inputQueue.pop_front(); } else { + // SEGMENTATION_WORD // skip spaces in the beginning of the input queue while (!inputQueue.empty() && - inputQueue.front() == SPACE) { + inputQueue.front().token == SPACE) { inputQueue.pop_front(); } // add one word to the presentation queue - TokenBufferString word; + TokenBufferToken word; while (!inputQueue.empty() && - inputQueue.front() != SPACE) { - word += inputQueue.front(); + inputQueue.front().token != SPACE) { + word = inputQueue.front(); inputQueue.pop_front(); } presentationQueue.push_back(word); @@ -200,7 +248,7 @@ void TokenBufferThread::monitor() size_t wordsInSentence = 0; for (size_t i = 0; i < presentationQueue.size(); i++) { const auto &word = presentationQueue[i]; - sentences.back() += word + SPACE; + sentences.back() += word.token + SPACE; wordsInSentence++; if (wordsInSentence == this->numPerSentence) { sentences.push_back(TokenBufferString()); @@ -211,12 +259,12 @@ void TokenBufferThread::monitor() for (size_t i = 0; i < presentationQueue.size(); i++) { const auto &token = presentationQueue[i]; // skip spaces in the beginning of a sentence (tokensInSentence == 0) - if (token == SPACE && + if (token.token == SPACE && sentences.back().length() == 0) { continue; } - sentences.back() += token; + sentences.back() += token.token; if (sentences.back().length() == this->numPerSentence) { // if the next character is not a space - this is a broken word @@ -280,7 +328,7 @@ void TokenBufferThread::monitor() // take the contribution queue and send it to the output TokenBufferString contribution; for (const auto &token : contributionQueue) { - contribution += token; + contribution += token.token; } contributionQueue.clear(); #ifdef _WIN32 diff --git a/src/whisper-utils/token-buffer-thread.h b/src/whisper-utils/token-buffer-thread.h index 13be208..7666669 100644 --- a/src/whisper-utils/token-buffer-thread.h +++ b/src/whisper-utils/token-buffer-thread.h @@ -16,8 +16,10 @@ #ifdef _WIN32 typedef std::wstring TokenBufferString; +typedef wchar_t TokenBufferChar; #else typedef std::string TokenBufferString; +typedef char TokenBufferChar; #endif struct transcription_filter_data; @@ -27,6 +29,22 @@ enum TokenBufferSpeed { SPEED_SLOW = 0, SPEED_NORMAL, SPEED_FAST }; typedef std::chrono::time_point TokenBufferTimePoint; +inline std::chrono::time_point get_time_point_from_ms(uint64_t ms) +{ + return std::chrono::time_point(std::chrono::milliseconds(ms)); +} + +struct TokenBufferToken { + TokenBufferString token; + bool is_partial; +}; + +struct TokenBufferSentence { + std::vector tokens; + TokenBufferTimePoint start_time; + TokenBufferTimePoint end_time; +}; + class TokenBufferThread { public: // default constructor @@ -40,7 +58,9 @@ public: std::chrono::seconds maxTime_, TokenBufferSegmentation segmentation_ = SEGMENTATION_TOKEN); - void addSentence(const std::string &sentence); + void addSentenceFromStdString(const std::string &sentence, TokenBufferTimePoint start_time, + TokenBufferTimePoint end_time, bool is_partial = false); + void addSentence(const TokenBufferSentence &sentence); void clear(); void stopThread(); @@ -59,9 +79,9 @@ private: void log_token_vector(const std::vector &tokens); int getWaitTime(TokenBufferSpeed speed) const; struct transcription_filter_data *gf; - std::deque inputQueue; - std::deque presentationQueue; - std::deque contributionQueue; + std::deque inputQueue; + std::deque presentationQueue; + std::deque contributionQueue; std::thread workerThread; std::mutex inputQueueMutex; std::mutex presentationQueueMutex; diff --git a/src/whisper-utils/vad-processing.cpp b/src/whisper-utils/vad-processing.cpp new file mode 100644 index 0000000..0e9c744 --- /dev/null +++ b/src/whisper-utils/vad-processing.cpp @@ -0,0 +1,377 @@ + +#include + +#include "transcription-filter-data.h" + +#include "vad-processing.h" + +#ifdef _WIN32 +#define NOMINMAX +#include +#endif + +int get_data_from_buf_and_resample(transcription_filter_data *gf, + uint64_t &start_timestamp_offset_ns, + uint64_t &end_timestamp_offset_ns) +{ + uint32_t num_frames_from_infos = 0; + + { + // scoped lock the buffer mutex + std::lock_guard lock(gf->whisper_buf_mutex); + + if (gf->input_buffers[0].size == 0) { + return 1; + } + + obs_log(gf->log_level, + "segmentation: currently %lu bytes in the audio input buffer", + gf->input_buffers[0].size); + + // max number of frames is 10 seconds worth of audio + const size_t max_num_frames = gf->sample_rate * 10; + + // pop all infos from the info buffer and mark the beginning timestamp from the first + // info as the beginning timestamp of the segment + struct transcription_filter_audio_info info_from_buf = {0}; + const size_t size_of_audio_info = sizeof(transcription_filter_audio_info); + while (gf->info_buffer.size >= size_of_audio_info) { + circlebuf_pop_front(&gf->info_buffer, &info_from_buf, size_of_audio_info); + num_frames_from_infos += info_from_buf.frames; + if (start_timestamp_offset_ns == 0) { + start_timestamp_offset_ns = info_from_buf.timestamp_offset_ns; + } + // Check if we're within the needed segment length + if (num_frames_from_infos > max_num_frames) { + // too big, push the last info into the buffer's front where it was + num_frames_from_infos -= info_from_buf.frames; + circlebuf_push_front(&gf->info_buffer, &info_from_buf, + size_of_audio_info); + break; + } + } + // calculate the end timestamp from the info plus the number of frames in the packet + end_timestamp_offset_ns = info_from_buf.timestamp_offset_ns + + info_from_buf.frames * 1000000000 / gf->sample_rate; + + if (start_timestamp_offset_ns > end_timestamp_offset_ns) { + // this may happen when the incoming media has a timestamp reset + // in this case, we should figure out the start timestamp from the end timestamp + // and the number of frames + start_timestamp_offset_ns = + end_timestamp_offset_ns - + num_frames_from_infos * 1000000000 / gf->sample_rate; + } + + for (size_t c = 0; c < gf->channels; c++) { + // zero the rest of copy_buffers + memset(gf->copy_buffers[c], 0, gf->frames * sizeof(float)); + } + + /* Pop from input circlebuf */ + for (size_t c = 0; c < gf->channels; c++) { + // Push the new data to copy_buffers[c] + circlebuf_pop_front(&gf->input_buffers[c], gf->copy_buffers[c], + num_frames_from_infos * sizeof(float)); + } + } + + obs_log(gf->log_level, "found %d frames from info buffer.", num_frames_from_infos); + gf->last_num_frames = num_frames_from_infos; + + { + // resample to 16kHz + float *resampled_16khz[MAX_PREPROC_CHANNELS]; + uint32_t resampled_16khz_frames; + uint64_t ts_offset; + { + ProfileScope("resample"); + audio_resampler_resample(gf->resampler_to_whisper, + (uint8_t **)resampled_16khz, + &resampled_16khz_frames, &ts_offset, + (const uint8_t **)gf->copy_buffers, + (uint32_t)num_frames_from_infos); + } + + circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0], + resampled_16khz_frames * sizeof(float)); + obs_log(gf->log_level, + "resampled: %d channels, %d frames, %f ms, current size: %lu bytes", + (int)gf->channels, (int)resampled_16khz_frames, + (float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f, + gf->resampled_buffer.size); + } + + return 0; +} + +vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_vad_state) +{ + // get data from buffer and resample + uint64_t start_timestamp_offset_ns = 0; + uint64_t end_timestamp_offset_ns = 0; + + const int ret = get_data_from_buf_and_resample(gf, start_timestamp_offset_ns, + end_timestamp_offset_ns); + if (ret != 0) { + return last_vad_state; + } + + const size_t vad_window_size_samples = gf->vad->get_window_size_samples() * sizeof(float); + const size_t min_vad_buffer_size = vad_window_size_samples * 8; + if (gf->resampled_buffer.size < min_vad_buffer_size) + return last_vad_state; + + size_t vad_num_windows = gf->resampled_buffer.size / vad_window_size_samples; + + std::vector vad_input; + vad_input.resize(vad_num_windows * gf->vad->get_window_size_samples()); + circlebuf_pop_front(&gf->resampled_buffer, vad_input.data(), + vad_input.size() * sizeof(float)); + + obs_log(gf->log_level, "sending %d frames to vad, %d windows, reset state? %s", + vad_input.size(), vad_num_windows, (!last_vad_state.vad_on) ? "yes" : "no"); + { + ProfileScope("vad->process"); + gf->vad->process(vad_input, !last_vad_state.vad_on); + } + + const uint64_t start_ts_offset_ms = start_timestamp_offset_ns / 1000000; + const uint64_t end_ts_offset_ms = end_timestamp_offset_ns / 1000000; + + vad_state current_vad_state = {false, start_ts_offset_ms, end_ts_offset_ms, + last_vad_state.last_partial_segment_end_ts}; + + std::vector stamps = gf->vad->get_speech_timestamps(); + if (stamps.size() == 0) { + obs_log(gf->log_level, "VAD detected no speech in %u frames", vad_input.size()); + if (last_vad_state.vad_on) { + obs_log(gf->log_level, "Last VAD was ON: segment end -> send to inference"); + run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, + last_vad_state.end_ts_offset_ms, + VAD_STATE_WAS_ON); + current_vad_state.last_partial_segment_end_ts = 0; + } + + if (gf->enable_audio_chunks_callback) { + audio_chunk_callback(gf, vad_input.data(), vad_input.size(), + VAD_STATE_IS_OFF, + {DETECTION_RESULT_SILENCE, + "[silence]", + current_vad_state.start_ts_offest_ms, + current_vad_state.end_ts_offset_ms, + {}}); + } + + return current_vad_state; + } + + // process vad segments + for (size_t i = 0; i < stamps.size(); i++) { + int start_frame = stamps[i].start; + if (i > 0) { + // if this is not the first segment, start from the end of the previous segment + start_frame = stamps[i - 1].end; + } else { + // take at least 100ms of audio before the first speech segment, if available + start_frame = std::max(0, start_frame - WHISPER_SAMPLE_RATE / 10); + } + + int end_frame = stamps[i].end; + // if (i == stamps.size() - 1 && stamps[i].end < (int)vad_input.size()) { + // // take at least 100ms of audio after the last speech segment, if available + // end_frame = std::min(end_frame + WHISPER_SAMPLE_RATE / 10, + // (int)vad_input.size()); + // } + + const int number_of_frames = end_frame - start_frame; + + // push the data into gf-whisper_buffer + circlebuf_push_back(&gf->whisper_buffer, vad_input.data() + start_frame, + number_of_frames * sizeof(float)); + + obs_log(gf->log_level, + "VAD segment %d/%d. pushed %d to %d (%d frames / %lu ms). current size: %lu bytes / %lu frames / %lu ms", + i, (stamps.size() - 1), start_frame, end_frame, number_of_frames, + number_of_frames * 1000 / WHISPER_SAMPLE_RATE, gf->whisper_buffer.size, + gf->whisper_buffer.size / sizeof(float), + gf->whisper_buffer.size / sizeof(float) * 1000 / WHISPER_SAMPLE_RATE); + + // segment "end" is in the middle of the buffer, send it to inference + if (stamps[i].end < (int)vad_input.size()) { + // new "ending" segment (not up to the end of the buffer) + obs_log(gf->log_level, "VAD segment end -> send to inference"); + // find the end timestamp of the segment + const uint64_t segment_end_ts = + start_ts_offset_ms + end_frame * 1000 / WHISPER_SAMPLE_RATE; + run_inference_and_callbacks( + gf, last_vad_state.start_ts_offest_ms, segment_end_ts, + last_vad_state.vad_on ? VAD_STATE_WAS_ON : VAD_STATE_WAS_OFF); + current_vad_state.vad_on = false; + current_vad_state.start_ts_offest_ms = current_vad_state.end_ts_offset_ms; + current_vad_state.end_ts_offset_ms = 0; + current_vad_state.last_partial_segment_end_ts = 0; + last_vad_state = current_vad_state; + continue; + } + + // end not reached - speech is ongoing + current_vad_state.vad_on = true; + if (last_vad_state.vad_on) { + obs_log(gf->log_level, + "last vad state was: ON, start ts: %llu, end ts: %llu", + last_vad_state.start_ts_offest_ms, last_vad_state.end_ts_offset_ms); + current_vad_state.start_ts_offest_ms = last_vad_state.start_ts_offest_ms; + } else { + obs_log(gf->log_level, + "last vad state was: OFF, start ts: %llu, end ts: %llu. start_ts_offset_ms: %llu, start_frame: %d", + last_vad_state.start_ts_offest_ms, last_vad_state.end_ts_offset_ms, + start_ts_offset_ms, start_frame); + current_vad_state.start_ts_offest_ms = + start_ts_offset_ms + start_frame * 1000 / WHISPER_SAMPLE_RATE; + } + current_vad_state.end_ts_offset_ms = + start_ts_offset_ms + end_frame * 1000 / WHISPER_SAMPLE_RATE; + obs_log(gf->log_level, + "end not reached. vad state: ON, start ts: %llu, end ts: %llu", + current_vad_state.start_ts_offest_ms, current_vad_state.end_ts_offset_ms); + + last_vad_state = current_vad_state; + + // if partial transcription is enabled, check if we should send a partial segment + if (!gf->partial_transcription) { + continue; + } + + // current length of audio in buffer + const uint64_t current_length_ms = + (current_vad_state.end_ts_offset_ms > 0 + ? current_vad_state.end_ts_offset_ms + : current_vad_state.start_ts_offest_ms) - + (current_vad_state.last_partial_segment_end_ts > 0 + ? current_vad_state.last_partial_segment_end_ts + : current_vad_state.start_ts_offest_ms); + obs_log(gf->log_level, "current buffer length after last partial (%lu): %lu ms", + current_vad_state.last_partial_segment_end_ts, current_length_ms); + + if (current_length_ms > (uint64_t)gf->partial_latency) { + current_vad_state.last_partial_segment_end_ts = + current_vad_state.end_ts_offset_ms; + // send partial segment to inference + obs_log(gf->log_level, "Partial segment -> send to inference"); + run_inference_and_callbacks(gf, current_vad_state.start_ts_offest_ms, + current_vad_state.end_ts_offset_ms, + VAD_STATE_PARTIAL); + } + } + + return current_vad_state; +} + +vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_vad_state) +{ + // get data from buffer and resample + uint64_t start_timestamp_offset_ns = 0; + uint64_t end_timestamp_offset_ns = 0; + + if (get_data_from_buf_and_resample(gf, start_timestamp_offset_ns, + end_timestamp_offset_ns) != 0) { + return last_vad_state; + } + + last_vad_state.end_ts_offset_ms = end_timestamp_offset_ns / 1000000; + + // extract the data from the resampled buffer with circlebuf_pop_front into a temp buffer + // and then push it into the whisper buffer + const size_t resampled_buffer_size = gf->resampled_buffer.size; + std::vector temp_buffer; + temp_buffer.resize(resampled_buffer_size); + circlebuf_pop_front(&gf->resampled_buffer, temp_buffer.data(), resampled_buffer_size); + circlebuf_push_back(&gf->whisper_buffer, temp_buffer.data(), resampled_buffer_size); + + obs_log(gf->log_level, "whisper buffer size: %lu bytes", gf->whisper_buffer.size); + + // use last_vad_state timestamps to calculate the duration of the current segment + if (last_vad_state.end_ts_offset_ms - last_vad_state.start_ts_offest_ms >= + (uint64_t)gf->segment_duration) { + obs_log(gf->log_level, "%d seconds worth of audio -> send to inference", + gf->segment_duration); + run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, + last_vad_state.end_ts_offset_ms, VAD_STATE_WAS_ON); + last_vad_state.start_ts_offest_ms = end_timestamp_offset_ns / 1000000; + last_vad_state.last_partial_segment_end_ts = 0; + return last_vad_state; + } + + // if partial transcription is enabled, check if we should send a partial segment + if (gf->partial_transcription) { + // current length of audio in buffer + const uint64_t current_length_ms = + (last_vad_state.end_ts_offset_ms > 0 ? last_vad_state.end_ts_offset_ms + : last_vad_state.start_ts_offest_ms) - + (last_vad_state.last_partial_segment_end_ts > 0 + ? last_vad_state.last_partial_segment_end_ts + : last_vad_state.start_ts_offest_ms); + obs_log(gf->log_level, "current buffer length after last partial (%lu): %lu ms", + last_vad_state.last_partial_segment_end_ts, current_length_ms); + + if (current_length_ms > (uint64_t)gf->partial_latency) { + // send partial segment to inference + obs_log(gf->log_level, "Partial segment -> send to inference"); + last_vad_state.last_partial_segment_end_ts = + last_vad_state.end_ts_offset_ms; + + // run vad on the current buffer + std::vector vad_input; + vad_input.resize(gf->whisper_buffer.size / sizeof(float)); + circlebuf_peek_front(&gf->whisper_buffer, vad_input.data(), + vad_input.size() * sizeof(float)); + + obs_log(gf->log_level, "sending %d frames to vad, %.1f ms", + vad_input.size(), + (float)vad_input.size() * 1000.0f / (float)WHISPER_SAMPLE_RATE); + { + ProfileScope("vad->process"); + gf->vad->process(vad_input, true); + } + + if (gf->vad->get_speech_timestamps().size() > 0) { + // VAD detected speech in the partial segment + run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, + last_vad_state.end_ts_offset_ms, + VAD_STATE_PARTIAL); + } else { + // VAD detected silence in the partial segment + obs_log(gf->log_level, "VAD detected silence in partial segment"); + // pop the partial segment from the whisper buffer, save some audio for the next segment + const size_t num_bytes_to_keep = + (WHISPER_SAMPLE_RATE / 4) * sizeof(float); + circlebuf_pop_front(&gf->whisper_buffer, nullptr, + gf->whisper_buffer.size - num_bytes_to_keep); + } + } + } + + return last_vad_state; +} + +void initialize_vad(transcription_filter_data *gf, const char *silero_vad_model_file) +{ + // initialize Silero VAD +#ifdef _WIN32 + // convert mbstring to wstring + int count = MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, + strlen(silero_vad_model_file), NULL, 0); + std::wstring silero_vad_model_path(count, 0); + MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, strlen(silero_vad_model_file), + &silero_vad_model_path[0], count); + obs_log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str()); +#else + std::string silero_vad_model_path = silero_vad_model_file; + obs_log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str()); +#endif + // roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py + // for silero vad parameters + gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE, 32, 0.5f, 100, + 100, 100)); +} diff --git a/src/whisper-utils/vad-processing.h b/src/whisper-utils/vad-processing.h new file mode 100644 index 0000000..996002b --- /dev/null +++ b/src/whisper-utils/vad-processing.h @@ -0,0 +1,18 @@ +#ifndef VAD_PROCESSING_H +#define VAD_PROCESSING_H + +enum VadState { VAD_STATE_WAS_ON = 0, VAD_STATE_WAS_OFF, VAD_STATE_IS_OFF, VAD_STATE_PARTIAL }; +enum VadMode { VAD_MODE_ACTIVE = 0, VAD_MODE_HYBRID, VAD_MODE_DISABLED }; + +struct vad_state { + bool vad_on; + uint64_t start_ts_offest_ms; + uint64_t end_ts_offset_ms; + uint64_t last_partial_segment_end_ts; +}; + +vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_vad_state); +vad_state hybrid_vad_segmentation(transcription_filter_data *gf, vad_state last_vad_state); +void initialize_vad(transcription_filter_data *gf, const char *silero_vad_model_file); + +#endif // VAD_PROCESSING_H diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index 6d2d76e..6da91d9 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -17,18 +17,12 @@ #endif #include "model-utils/model-find-utils.h" +#include "vad-processing.h" #include #include #include -struct vad_state { - bool vad_on; - uint64_t start_ts_offest_ms; - uint64_t end_ts_offset_ms; - uint64_t last_partial_segment_end_ts; -}; - struct whisper_context *init_whisper_context(const std::string &model_path_in, struct transcription_filter_data *gf) { @@ -161,6 +155,10 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter float *pcm32f_data = (float *)pcm32f_data_; size_t pcm32f_size = pcm32f_num_samples; + // incoming duration in ms + const uint64_t incoming_duration_ms = + (uint64_t)(pcm32f_num_samples * 1000 / WHISPER_SAMPLE_RATE); + if (pcm32f_num_samples < WHISPER_SAMPLE_RATE) { obs_log(gf->log_level, "Speech segment is less than 1 second, padding with zeros to 1 second"); @@ -175,7 +173,7 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter } // duration in ms - const uint64_t duration_ms = (uint64_t)(pcm32f_size * 1000 / WHISPER_SAMPLE_RATE); + const uint64_t whisper_duration_ms = (uint64_t)(pcm32f_size * 1000 / WHISPER_SAMPLE_RATE); std::lock_guard lock(gf->whisper_ctx_mutex); if (gf->whisper_context == nullptr) { @@ -183,9 +181,19 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter return {DETECTION_RESULT_UNKNOWN, "", t0, t1, {}, ""}; } + if (gf->n_context_sentences > 0 && !gf->last_transcription_sentence.empty()) { + // set the initial prompt to the last transcription sentences (concatenated) + std::string initial_prompt = gf->last_transcription_sentence[0]; + for (size_t i = 1; i < gf->last_transcription_sentence.size(); ++i) { + initial_prompt += " " + gf->last_transcription_sentence[i]; + } + gf->whisper_params.initial_prompt = initial_prompt.c_str(); + obs_log(gf->log_level, "Initial prompt: %s", gf->whisper_params.initial_prompt); + } + // run the inference int whisper_full_result = -1; - gf->whisper_params.duration_ms = (int)(duration_ms); + gf->whisper_params.duration_ms = (int)(whisper_duration_ms); try { whisper_full_result = whisper_full(gf->whisper_context, gf->whisper_params, pcm32f_data, (int)pcm32f_size); @@ -243,13 +251,13 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter // token ids https://huggingface.co/openai/whisper-large-v3/raw/main/tokenizer.json if (token.id > 50365 && token.id <= 51865) { const float time = ((float)token.id - 50365.0f) * 0.02f; - const float duration_s = (float)duration_ms / 1000.0f; - const float ratio = - std::max(time, duration_s) / std::min(time, duration_s); + const float duration_s = (float)incoming_duration_ms / 1000.0f; + const float ratio = time / duration_s; obs_log(gf->log_level, - "Time token found %d -> %.3f. Duration: %.3f. Ratio: %.3f.", - token.id, time, duration_s, ratio); - if (ratio > 3.0f) { + "Time token found %d -> %.3f. Duration: %.3f. Ratio: %.3f. Threshold %.2f", + token.id, time, duration_s, ratio, + gf->duration_filter_threshold); + if (ratio > gf->duration_filter_threshold) { // ratio is too high, skip this detection obs_log(gf->log_level, "Time token ratio too high, skipping"); @@ -263,8 +271,8 @@ struct DetectionResultWithText run_whisper_inference(struct transcription_filter text += token_str; tokens.push_back(token); } - obs_log(gf->log_level, "S %d, Token %d: %d\t%s\tp: %.3f [keep: %d]", - n_segment, j, token.id, token_str, token.p, keep); + obs_log(gf->log_level, "S %d, T %d: %d\t%s\tp: %.3f [keep: %d]", n_segment, + j, token.id, token_str, token.p, keep); } } sentence_p /= (float)tokens.size(); @@ -327,233 +335,6 @@ void run_inference_and_callbacks(transcription_filter_data *gf, uint64_t start_o bfree(pcm32f_data); } -vad_state vad_based_segmentation(transcription_filter_data *gf, vad_state last_vad_state) -{ - uint32_t num_frames_from_infos = 0; - uint64_t start_timestamp_offset_ns = 0; - uint64_t end_timestamp_offset_ns = 0; - size_t overlap_size = 0; - - for (size_t c = 0; c < gf->channels; c++) { - // zero the rest of copy_buffers - memset(gf->copy_buffers[c] + overlap_size, 0, - (gf->frames - overlap_size) * sizeof(float)); - } - - { - // scoped lock the buffer mutex - std::lock_guard lock(gf->whisper_buf_mutex); - - obs_log(gf->log_level, - "vad based segmentation. currently %lu bytes in the audio input buffer", - gf->input_buffers[0].size); - - // max number of frames is 10 seconds worth of audio - const size_t max_num_frames = gf->sample_rate * 10; - - // pop all infos from the info buffer and mark the beginning timestamp from the first - // info as the beginning timestamp of the segment - struct transcription_filter_audio_info info_from_buf = {0}; - const size_t size_of_audio_info = sizeof(transcription_filter_audio_info); - while (gf->info_buffer.size >= size_of_audio_info) { - circlebuf_pop_front(&gf->info_buffer, &info_from_buf, size_of_audio_info); - num_frames_from_infos += info_from_buf.frames; - if (start_timestamp_offset_ns == 0) { - start_timestamp_offset_ns = info_from_buf.timestamp_offset_ns; - } - // Check if we're within the needed segment length - if (num_frames_from_infos > max_num_frames) { - // too big, push the last info into the buffer's front where it was - num_frames_from_infos -= info_from_buf.frames; - circlebuf_push_front(&gf->info_buffer, &info_from_buf, - size_of_audio_info); - break; - } - } - end_timestamp_offset_ns = info_from_buf.timestamp_offset_ns; - - if (start_timestamp_offset_ns > end_timestamp_offset_ns) { - // this may happen when the incoming media has a timestamp reset - // in this case, we should figure out the start timestamp from the end timestamp - // and the number of frames - start_timestamp_offset_ns = - end_timestamp_offset_ns - - num_frames_from_infos * 1000000000 / gf->sample_rate; - } - - /* Pop from input circlebuf */ - for (size_t c = 0; c < gf->channels; c++) { - // Push the new data to copy_buffers[c] - circlebuf_pop_front(&gf->input_buffers[c], - gf->copy_buffers[c] + overlap_size, - num_frames_from_infos * sizeof(float)); - } - } - - obs_log(gf->log_level, "found %d frames from info buffer. %lu in overlap", - num_frames_from_infos, overlap_size); - gf->last_num_frames = num_frames_from_infos + overlap_size; - - { - // resample to 16kHz - float *resampled_16khz[MAX_PREPROC_CHANNELS]; - uint32_t resampled_16khz_frames; - uint64_t ts_offset; - { - ProfileScope("resample"); - audio_resampler_resample(gf->resampler_to_whisper, - (uint8_t **)resampled_16khz, - &resampled_16khz_frames, &ts_offset, - (const uint8_t **)gf->copy_buffers, - (uint32_t)num_frames_from_infos); - } - - obs_log(gf->log_level, "resampled: %d channels, %d frames, %f ms", - (int)gf->channels, (int)resampled_16khz_frames, - (float)resampled_16khz_frames / WHISPER_SAMPLE_RATE * 1000.0f); - circlebuf_push_back(&gf->resampled_buffer, resampled_16khz[0], - resampled_16khz_frames * sizeof(float)); - } - - if (gf->resampled_buffer.size < (gf->vad->get_window_size_samples() * sizeof(float))) - return last_vad_state; - - size_t len = - gf->resampled_buffer.size / (gf->vad->get_window_size_samples() * sizeof(float)); - - std::vector vad_input; - vad_input.resize(len * gf->vad->get_window_size_samples()); - circlebuf_pop_front(&gf->resampled_buffer, vad_input.data(), - vad_input.size() * sizeof(float)); - - obs_log(gf->log_level, "sending %d frames to vad", vad_input.size()); - { - ProfileScope("vad->process"); - gf->vad->process(vad_input, !last_vad_state.vad_on); - } - - const uint64_t start_ts_offset_ms = start_timestamp_offset_ns / 1000000; - const uint64_t end_ts_offset_ms = end_timestamp_offset_ns / 1000000; - - vad_state current_vad_state = {false, start_ts_offset_ms, end_ts_offset_ms, - last_vad_state.last_partial_segment_end_ts}; - - std::vector stamps = gf->vad->get_speech_timestamps(); - if (stamps.size() == 0) { - obs_log(gf->log_level, "VAD detected no speech in %u frames", vad_input.size()); - if (last_vad_state.vad_on) { - obs_log(gf->log_level, "Last VAD was ON: segment end -> send to inference"); - run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms, - last_vad_state.end_ts_offset_ms, - VAD_STATE_WAS_ON); - current_vad_state.last_partial_segment_end_ts = 0; - } - - if (gf->enable_audio_chunks_callback) { - audio_chunk_callback(gf, vad_input.data(), vad_input.size(), - VAD_STATE_IS_OFF, - {DETECTION_RESULT_SILENCE, - "[silence]", - current_vad_state.start_ts_offest_ms, - current_vad_state.end_ts_offset_ms, - {}}); - } - - return current_vad_state; - } - - // process vad segments - for (size_t i = 0; i < stamps.size(); i++) { - int start_frame = stamps[i].start; - if (i > 0) { - // if this is not the first segment, start from the end of the previous segment - start_frame = stamps[i - 1].end; - } else { - // take at least 100ms of audio before the first speech segment, if available - start_frame = std::max(0, start_frame - WHISPER_SAMPLE_RATE / 10); - } - - int end_frame = stamps[i].end; - if (i == stamps.size() - 1 && stamps[i].end < (int)vad_input.size()) { - // take at least 100ms of audio after the last speech segment, if available - end_frame = std::min(end_frame + WHISPER_SAMPLE_RATE / 10, - (int)vad_input.size()); - } - - const int number_of_frames = end_frame - start_frame; - - // push the data into gf-whisper_buffer - circlebuf_push_back(&gf->whisper_buffer, vad_input.data() + start_frame, - number_of_frames * sizeof(float)); - - obs_log(gf->log_level, - "VAD segment %d. pushed %d to %d (%d frames / %lu ms). current size: %lu bytes / %lu frames / %lu ms", - i, start_frame, end_frame, number_of_frames, - number_of_frames * 1000 / WHISPER_SAMPLE_RATE, gf->whisper_buffer.size, - gf->whisper_buffer.size / sizeof(float), - gf->whisper_buffer.size / sizeof(float) * 1000 / WHISPER_SAMPLE_RATE); - - // segment "end" is in the middle of the buffer, send it to inference - if (stamps[i].end < (int)vad_input.size()) { - // new "ending" segment (not up to the end of the buffer) - obs_log(gf->log_level, "VAD segment end -> send to inference"); - // find the end timestamp of the segment - const uint64_t segment_end_ts = - start_ts_offset_ms + end_frame * 1000 / WHISPER_SAMPLE_RATE; - run_inference_and_callbacks( - gf, last_vad_state.start_ts_offest_ms, segment_end_ts, - last_vad_state.vad_on ? VAD_STATE_WAS_ON : VAD_STATE_WAS_OFF); - current_vad_state.vad_on = false; - current_vad_state.start_ts_offest_ms = current_vad_state.end_ts_offset_ms; - current_vad_state.end_ts_offset_ms = 0; - current_vad_state.last_partial_segment_end_ts = 0; - last_vad_state = current_vad_state; - continue; - } - - // end not reached - speech is ongoing - current_vad_state.vad_on = true; - if (last_vad_state.vad_on) { - current_vad_state.start_ts_offest_ms = last_vad_state.start_ts_offest_ms; - } else { - current_vad_state.start_ts_offest_ms = - start_ts_offset_ms + start_frame * 1000 / WHISPER_SAMPLE_RATE; - } - obs_log(gf->log_level, "end not reached. vad state: start ts: %llu, end ts: %llu", - current_vad_state.start_ts_offest_ms, current_vad_state.end_ts_offset_ms); - - last_vad_state = current_vad_state; - - // if partial transcription is enabled, check if we should send a partial segment - if (!gf->partial_transcription) { - continue; - } - - // current length of audio in buffer - const uint64_t current_length_ms = - (current_vad_state.end_ts_offset_ms > 0 - ? current_vad_state.end_ts_offset_ms - : current_vad_state.start_ts_offest_ms) - - (current_vad_state.last_partial_segment_end_ts > 0 - ? current_vad_state.last_partial_segment_end_ts - : current_vad_state.start_ts_offest_ms); - obs_log(gf->log_level, "current buffer length after last partial (%lu): %lu ms", - current_vad_state.last_partial_segment_end_ts, current_length_ms); - - if (current_length_ms > (uint64_t)gf->partial_latency) { - current_vad_state.last_partial_segment_end_ts = - current_vad_state.end_ts_offset_ms; - // send partial segment to inference - obs_log(gf->log_level, "Partial segment -> send to inference"); - run_inference_and_callbacks(gf, current_vad_state.start_ts_offest_ms, - current_vad_state.end_ts_offset_ms, - VAD_STATE_PARTIAL); - } - } - - return current_vad_state; -} - void whisper_loop(void *data) { if (data == nullptr) { @@ -566,7 +347,7 @@ void whisper_loop(void *data) obs_log(gf->log_level, "Starting whisper thread"); - vad_state current_vad_state = {false, 0, 0, 0}; + vad_state current_vad_state = {false, now_ms(), 0, 0}; const char *whisper_loop_name = "Whisper loop"; profile_register_root(whisper_loop_name, 50 * 1000 * 1000); @@ -584,12 +365,16 @@ void whisper_loop(void *data) } } - current_vad_state = vad_based_segmentation(gf, current_vad_state); + if (gf->vad_mode == VAD_MODE_HYBRID) { + current_vad_state = hybrid_vad_segmentation(gf, current_vad_state); + } else if (gf->vad_mode == VAD_MODE_ACTIVE) { + current_vad_state = vad_based_segmentation(gf, current_vad_state); + } if (!gf->cleared_last_sub) { // check if we should clear the current sub depending on the minimum subtitle duration uint64_t now = now_ms(); - if ((now - gf->last_sub_render_time) > gf->min_sub_duration) { + if ((now - gf->last_sub_render_time) > gf->max_sub_duration) { // clear the current sub, call the callback with an empty string obs_log(gf->log_level, "Clearing current subtitle. now: %lu ms, last: %lu ms", now, diff --git a/src/whisper-utils/whisper-processing.h b/src/whisper-utils/whisper-processing.h index 5bc162b..a00f7cb 100644 --- a/src/whisper-utils/whisper-processing.h +++ b/src/whisper-utils/whisper-processing.h @@ -29,10 +29,10 @@ struct DetectionResultWithText { std::string language; }; -enum VadState { VAD_STATE_WAS_ON = 0, VAD_STATE_WAS_OFF, VAD_STATE_IS_OFF, VAD_STATE_PARTIAL }; - void whisper_loop(void *data); struct whisper_context *init_whisper_context(const std::string &model_path, struct transcription_filter_data *gf); +void run_inference_and_callbacks(transcription_filter_data *gf, uint64_t start_offset_ms, + uint64_t end_offset_ms, int vad_state); #endif // WHISPER_PROCESSING_H diff --git a/src/whisper-utils/whisper-utils.cpp b/src/whisper-utils/whisper-utils.cpp index c2e4929..84f3b0a 100644 --- a/src/whisper-utils/whisper-utils.cpp +++ b/src/whisper-utils/whisper-utils.cpp @@ -2,13 +2,10 @@ #include "plugin-support.h" #include "model-utils/model-downloader.h" #include "whisper-processing.h" +#include "vad-processing.h" #include -#ifdef _WIN32 -#include -#endif - void shutdown_whisper_thread(struct transcription_filter_data *gf) { obs_log(gf->log_level, "shutdown_whisper_thread"); @@ -40,21 +37,7 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf, } // initialize Silero VAD -#ifdef _WIN32 - // convert mbstring to wstring - int count = MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, - strlen(silero_vad_model_file), NULL, 0); - std::wstring silero_vad_model_path(count, 0); - MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, strlen(silero_vad_model_file), - &silero_vad_model_path[0], count); - obs_log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str()); -#else - std::string silero_vad_model_path = silero_vad_model_file; - obs_log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str()); -#endif - // roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py - // for silero vad parameters - gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE)); + initialize_vad(gf, silero_vad_model_file); obs_log(gf->log_level, "Create whisper context"); gf->whisper_context = init_whisper_context(whisper_model_path, gf);