Mirror of https://github.com/occ-ai/obs-localvocal (synced 2024-11-07 18:57:14 +00:00)
Start and stop based on filter enable status (#111)

* refactor: Add initial_creation flag to transcription filter data
* refactor: Improve caption duration calculation in set_text_callback

Parent: 91c2842009
Commit: 2aa151eb22
@@ -17,8 +17,8 @@
#include "transcription-utils.h"
#include "translation/translation.h"
#include "translation/translation-includes.h"

#define SEND_TIMED_METADATA_URL "http://localhost:8080/timed-metadata"
#include "whisper-utils/whisper-utils.h"
#include "whisper-utils/whisper-model-utils.h"

void send_caption_to_source(const std::string &target_source_name, const std::string &caption,
struct transcription_filter_data *gf)
@@ -130,7 +130,13 @@ void set_text_callback(struct transcription_filter_data *gf,
if (gf->caption_to_stream) {
obs_output_t *streaming_output = obs_frontend_get_streaming_output();
if (streaming_output) {
obs_output_output_caption_text1(streaming_output, str_copy.c_str());
// calculate the duration in seconds
const uint64_t duration =
result.end_timestamp_ms - result.start_timestamp_ms;
obs_log(gf->log_level, "Sending caption to streaming output: %s",
str_copy.c_str());
obs_output_output_caption_text2(streaming_output, str_copy.c_str(),
(double)duration / 1000.0);
obs_output_release(streaming_output);
}
}
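The change above replaces the plain obs_output_output_caption_text1 call with obs_output_output_caption_text2, which also takes a display duration in seconds derived from the segment timestamps. A minimal sketch of that calculation, assuming the transcription result carries millisecond timestamps as in the hunk (the helper name and parameters are illustrative):

```cpp
// Sketch only: shows the ms-to-seconds duration passed to
// obs_output_output_caption_text2; not the plugin's actual function.
#include <cstdint>
#include <string>
#include <obs.h>
#include <obs-frontend-api.h>

static void send_caption_with_duration(const std::string &text,
				       uint64_t start_timestamp_ms,
				       uint64_t end_timestamp_ms)
{
	obs_output_t *streaming_output = obs_frontend_get_streaming_output();
	if (!streaming_output)
		return;

	// duration of the transcribed segment, converted from ms to seconds
	const uint64_t duration_ms = end_timestamp_ms - start_timestamp_ms;
	obs_output_output_caption_text2(streaming_output, text.c_str(),
					(double)duration_ms / 1000.0);
	obs_output_release(streaming_output); // the frontend getter returns a reference
}
```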
@@ -285,3 +291,23 @@ void media_stopped_callback(void *data_, calldata_t *cd)
gf_->active = false;
reset_caption_state(gf_);
}

void enable_callback(void *data_, calldata_t *cd)
{
transcription_filter_data *gf_ = static_cast<struct transcription_filter_data *>(data_);
bool enable = calldata_bool(cd, "enabled");
if (enable) {
obs_log(gf_->log_level, "enable_callback: enable");
gf_->active = true;
reset_caption_state(gf_);
// get filter settings from gf_->context
obs_data_t *settings = obs_source_get_settings(gf_->context);
update_whisper_model(gf_, settings);
obs_data_release(settings);
} else {
obs_log(gf_->log_level, "enable_callback: disable");
gf_->active = false;
reset_caption_state(gf_);
shutdown_whisper_thread(gf_);
}
}
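This is the core of the new start/stop behavior: the filter now reacts to its own "enable" signal rather than only to media events, starting the whisper thread when enabled and shutting it down when disabled. A hedged sketch of the signal wiring, with the plugin's data struct abstracted behind a void pointer:

```cpp
// Sketch of connecting a filter to its "enable" signal, mirroring the
// enable_callback added above; the handler body is only commented.
#include <obs.h>

static void on_enable(void *data, calldata_t *cd)
{
	// OBS raises "enable" on a source with an "enabled" bool in the calldata
	const bool enabled = calldata_bool(cd, "enabled");
	// ... start whisper / reset captions when enabled, shut down when not ...
	(void)data;
	(void)enabled;
}

static void connect_enable_signal(obs_source_t *filter, void *filter_data)
{
	signal_handler_t *sh = obs_source_get_signal_handler(filter);
	signal_handler_connect(sh, "enable", on_enable, filter_data);
}

static void disconnect_enable_signal(obs_source_t *filter, void *filter_data)
{
	// paired disconnect, as this commit does in transcription_filter_destroy below
	signal_handler_t *sh = obs_source_get_signal_handler(filter);
	signal_handler_disconnect(sh, "enable", on_enable, filter_data);
}
```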
@@ -22,5 +22,6 @@ void media_started_callback(void *data_, calldata_t *cd);
void media_pause_callback(void *data_, calldata_t *cd);
void media_restart_callback(void *data_, calldata_t *cd);
void media_stopped_callback(void *data_, calldata_t *cd);
void enable_callback(void *data_, calldata_t *cd);

#endif /* TRANSCRIPTION_FILTER_CALLBACKS_H */
@@ -79,6 +79,7 @@ struct transcription_filter_data {
bool fix_utf8 = true;
bool enable_audio_chunks_callback = false;
bool source_signals_set = false;
bool initial_creation = true;

// Last transcription result
std::string last_text;
@@ -14,4 +14,6 @@ struct obs_source_info transcription_filter_info = {
.deactivate = transcription_filter_deactivate,
.filter_audio = transcription_filter_filter_audio,
.filter_remove = transcription_filter_remove,
.show = transcription_filter_show,
.hide = transcription_filter_hide,
};
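The .show and .hide members are optional obs_source_info callbacks that libobs invokes when the source becomes visible or hidden in the UI. For context, a minimal sketch of registering such callbacks; the id and handler names are placeholders, not the plugin's real ones:

```cpp
// Illustrative only: a bare obs_source_info wired with show/hide handlers.
#include <obs-module.h>

static void example_show(void *data)
{
	UNUSED_PARAMETER(data); // called when the parent source is shown
}

static void example_hide(void *data)
{
	UNUSED_PARAMETER(data); // called when the parent source is hidden
}

static struct obs_source_info make_example_filter_info()
{
	struct obs_source_info info = {};
	info.id = "example_audio_filter";
	info.type = OBS_SOURCE_TYPE_FILTER;
	info.output_flags = OBS_SOURCE_AUDIO;
	info.show = example_show;
	info.hide = example_hide;
	return info;
}
```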
@@ -57,6 +57,7 @@ struct obs_audio_data *transcription_filter_filter_audio(void *data, struct obs_
if (!audio) {
return nullptr;
}

if (data == nullptr) {
return audio;
}
@@ -137,6 +138,9 @@ void transcription_filter_destroy(void *data)
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);

signal_handler_t *sh_filter = obs_source_get_signal_handler(gf->context);
signal_handler_disconnect(sh_filter, "enable", enable_callback, gf);

obs_log(gf->log_level, "filter destroy");
shutdown_whisper_thread(gf);
@@ -167,7 +171,7 @@ void transcription_filter_update(void *data, obs_data_t *s)
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);

gf->log_level = (int)obs_data_get_int(s, "log_level");
gf->log_level = LOG_INFO; //(int)obs_data_get_int(s, "log_level");
gf->vad_enabled = obs_data_get_bool(s, "vad_enabled");
gf->log_words = obs_data_get_bool(s, "log_words");
gf->caption_to_stream = obs_data_get_bool(s, "caption_to_stream");
@@ -293,51 +297,61 @@ void transcription_filter_update(void *data, obs_data_t *s)
gf->text_source_name = new_text_source_name;
}

obs_log(gf->log_level, "update whisper model");
update_whisper_model(gf, s);

obs_log(gf->log_level, "update whisper params");
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
{
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);

gf->sentence_psum_accept_thresh =
(float)obs_data_get_double(s, "sentence_psum_accept_thresh");
gf->sentence_psum_accept_thresh =
(float)obs_data_get_double(s, "sentence_psum_accept_thresh");

gf->whisper_params = whisper_full_default_params(
(whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method"));
gf->whisper_params.duration_ms = (int)obs_data_get_int(s, "buffer_size_msec");
if (!new_translate || gf->translation_model_index != "whisper-based-translation") {
gf->whisper_params.language = obs_data_get_string(s, "whisper_language_select");
} else {
// take the language from gf->target_lang
gf->whisper_params.language = language_codes_2_reverse[gf->target_lang].c_str();
gf->whisper_params = whisper_full_default_params(
(whisper_sampling_strategy)obs_data_get_int(s, "whisper_sampling_method"));
gf->whisper_params.duration_ms = (int)obs_data_get_int(s, "buffer_size_msec");
if (!new_translate || gf->translation_model_index != "whisper-based-translation") {
gf->whisper_params.language =
obs_data_get_string(s, "whisper_language_select");
} else {
// take the language from gf->target_lang
gf->whisper_params.language =
language_codes_2_reverse[gf->target_lang].c_str();
}
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
gf->whisper_params.translate = obs_data_get_bool(s, "whisper_translate");
gf->whisper_params.no_context = obs_data_get_bool(s, "no_context");
gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment");
gf->whisper_params.print_special = obs_data_get_bool(s, "print_special");
gf->whisper_params.print_progress = obs_data_get_bool(s, "print_progress");
gf->whisper_params.print_realtime = obs_data_get_bool(s, "print_realtime");
gf->whisper_params.print_timestamps = obs_data_get_bool(s, "print_timestamps");
gf->whisper_params.token_timestamps = obs_data_get_bool(s, "token_timestamps");
gf->whisper_params.thold_pt = (float)obs_data_get_double(s, "thold_pt");
gf->whisper_params.thold_ptsum = (float)obs_data_get_double(s, "thold_ptsum");
gf->whisper_params.max_len = (int)obs_data_get_int(s, "max_len");
gf->whisper_params.split_on_word = obs_data_get_bool(s, "split_on_word");
gf->whisper_params.max_tokens = (int)obs_data_get_int(s, "max_tokens");
gf->whisper_params.speed_up = obs_data_get_bool(s, "speed_up");
gf->whisper_params.suppress_blank = obs_data_get_bool(s, "suppress_blank");
gf->whisper_params.suppress_non_speech_tokens =
obs_data_get_bool(s, "suppress_non_speech_tokens");
gf->whisper_params.temperature = (float)obs_data_get_double(s, "temperature");
gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts");
gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty");

if (gf->vad_enabled && gf->vad) {
const float vad_threshold = (float)obs_data_get_double(s, "vad_threshold");
gf->vad->set_threshold(vad_threshold);
}
}
gf->whisper_params.initial_prompt = obs_data_get_string(s, "initial_prompt");
gf->whisper_params.n_threads = (int)obs_data_get_int(s, "n_threads");
gf->whisper_params.n_max_text_ctx = (int)obs_data_get_int(s, "n_max_text_ctx");
gf->whisper_params.translate = obs_data_get_bool(s, "whisper_translate");
gf->whisper_params.no_context = obs_data_get_bool(s, "no_context");
gf->whisper_params.single_segment = obs_data_get_bool(s, "single_segment");
gf->whisper_params.print_special = obs_data_get_bool(s, "print_special");
gf->whisper_params.print_progress = obs_data_get_bool(s, "print_progress");
gf->whisper_params.print_realtime = obs_data_get_bool(s, "print_realtime");
gf->whisper_params.print_timestamps = obs_data_get_bool(s, "print_timestamps");
gf->whisper_params.token_timestamps = obs_data_get_bool(s, "token_timestamps");
gf->whisper_params.thold_pt = (float)obs_data_get_double(s, "thold_pt");
gf->whisper_params.thold_ptsum = (float)obs_data_get_double(s, "thold_ptsum");
gf->whisper_params.max_len = (int)obs_data_get_int(s, "max_len");
gf->whisper_params.split_on_word = obs_data_get_bool(s, "split_on_word");
gf->whisper_params.max_tokens = (int)obs_data_get_int(s, "max_tokens");
gf->whisper_params.speed_up = obs_data_get_bool(s, "speed_up");
gf->whisper_params.suppress_blank = obs_data_get_bool(s, "suppress_blank");
gf->whisper_params.suppress_non_speech_tokens =
obs_data_get_bool(s, "suppress_non_speech_tokens");
gf->whisper_params.temperature = (float)obs_data_get_double(s, "temperature");
gf->whisper_params.max_initial_ts = (float)obs_data_get_double(s, "max_initial_ts");
gf->whisper_params.length_penalty = (float)obs_data_get_double(s, "length_penalty");

if (gf->vad_enabled && gf->vad) {
const float vad_threshold = (float)obs_data_get_double(s, "vad_threshold");
gf->vad->set_threshold(vad_threshold);
if (gf->initial_creation && obs_source_enabled(gf->context)) {
// source was enabled on creation
obs_data_t *settings = obs_source_get_settings(gf->context);
update_whisper_model(gf, settings);
obs_data_release(settings);
gf->active = true;
gf->initial_creation = false;
}
}
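Two things appear to happen in this hunk: the whisper-parameter writes move into an explicit block so the std::lock_guard on whisper_ctx_mutex is released as soon as the parameters are set, and an initial_creation check performs the first model load when the filter was already enabled at creation time. A small sketch of that scoping pattern, with the data struct reduced to the members relevant here:

```cpp
// Sketch of the block-scoped lock used above; `example_ctx` stands in for
// transcription_filter_data and only models the relevant members.
#include <mutex>

struct example_ctx {
	std::mutex whisper_ctx_mutex;
	bool initial_creation = true;
	bool active = false;
};

static void example_update(example_ctx *gf, bool source_enabled)
{
	{
		std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
		// ... write whisper/VAD parameters while holding the lock ...
	} // lock_guard leaves scope: mutex released here

	if (gf->initial_creation && source_enabled) {
		// first update after creation on an already-enabled filter:
		// heavier setup (model load, thread start) runs without the lock
		gf->active = true;
		gf->initial_creation = false;
	}
}
```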
@@ -421,12 +435,13 @@ void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter)
gf->whisper_model_path = std::string(""); // The update function will set the model path
gf->whisper_context = nullptr;

signal_handler_t *sh_filter = obs_source_get_signal_handler(gf->context);
signal_handler_connect(sh_filter, "enable", enable_callback, gf);

obs_log(gf->log_level, "run update");
// get the settings updated on the filter data struct
transcription_filter_update(gf, settings);

gf->active = true;

// handle the event OBS_FRONTEND_EVENT_RECORDING_STARTING to reset the srt sentence number
// to match the subtitles with the recording
obs_frontend_add_event_callback(recording_state_callback, gf);
@@ -466,6 +481,20 @@ void transcription_filter_deactivate(void *data)
gf->active = false;
}

void transcription_filter_show(void *data)
{
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
obs_log(gf->log_level, "filter show");
}

void transcription_filter_hide(void *data)
{
struct transcription_filter_data *gf =
static_cast<struct transcription_filter_data *>(data);
obs_log(gf->log_level, "filter hide");
}

void transcription_filter_defaults(obs_data_t *s)
{
obs_log(LOG_INFO, "filter defaults");
@@ -586,11 +615,11 @@ obs_properties_t *transcription_filter_properties(void *data)
whisper_model_path_external,
[](void *data_, obs_properties_t *props, obs_property_t *property,
obs_data_t *settings) {
obs_log(LOG_INFO, "whisper_model_path_external modified");
UNUSED_PARAMETER(property);
UNUSED_PARAMETER(props);
struct transcription_filter_data *gf_ =
static_cast<struct transcription_filter_data *>(data_);
obs_log(gf_->log_level, "whisper_model_path_external modified");
transcription_filter_update(gf_, settings);
return true;
},
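The lambda above has the shape of a property "modified" callback (obs_property_modified2_t); the commit swaps the hard-coded LOG_INFO log for the filter's own log level and re-runs transcription_filter_update so an external model path change takes effect immediately. A hedged sketch of attaching such a callback with obs_property_set_modified_callback2; the property name is taken from the hunk, the function names are illustrative:

```cpp
// Sketch of an obs_property_modified2_t callback; the update call is only
// indicated in a comment since the plugin's own function lives elsewhere.
#include <obs-module.h>

static bool on_model_path_modified(void *priv, obs_properties_t *props,
				   obs_property_t *property, obs_data_t *settings)
{
	UNUSED_PARAMETER(props);
	UNUSED_PARAMETER(property);
	UNUSED_PARAMETER(settings);
	// `priv` is whatever pointer was passed when attaching the callback,
	// e.g. the filter data struct; re-run the filter's update here.
	UNUSED_PARAMETER(priv);
	return true; // returning true refreshes the properties view
}

static void attach_model_path_callback(obs_properties_t *props, void *filter_data)
{
	obs_property_t *p = obs_properties_get(props, "whisper_model_path_external");
	if (p)
		obs_property_set_modified_callback2(p, on_model_path_modified, filter_data);
}
```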
@@ -16,6 +16,8 @@ void transcription_filter_deactivate(void *data);
void transcription_filter_defaults(obs_data_t *s);
obs_properties_t *transcription_filter_properties(void *data);
void transcription_filter_remove(void *data, obs_source_t *source);
void transcription_filter_show(void *data);
void transcription_filter_hide(void *data);

const char *const PLUGIN_INFO_TEMPLATE =
"<a href=\"https://github.com/occ-ai/obs-localvocal/\">LocalVocal</a> (%1) by "
@@ -20,6 +20,9 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
obs_log(LOG_ERROR, "Cannot find Silero VAD model file");
return;
}
obs_log(gf->log_level, "Silero VAD model file: %s", silero_vad_model_file);
std::string silero_vad_model_file_str = std::string(silero_vad_model_file);
bfree(silero_vad_model_file);

if (gf->whisper_model_path.empty() || gf->whisper_model_path != new_model_path ||
is_external_model) {
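The important detail here is a lifetime fix: silero_vad_model_file is a bmalloc'd C string, so it is copied into silero_vad_model_file_str and freed right away, and the std::string copy is what the later calls and the download lambda below use. A sketch of the copy-then-bfree pattern, assuming the path comes from an API such as obs_module_file() that returns a bmalloc'd string; the model path is illustrative:

```cpp
// Sketch of copying a bmalloc'd path into an owned std::string before bfree.
#include <obs-module.h>
#include <string>

static std::string find_vad_model_path()
{
	char *raw = obs_module_file("models/silero_vad.onnx"); // bmalloc'd or NULL
	if (!raw)
		return {};
	std::string path(raw); // own a copy before releasing the raw buffer
	bfree(raw);            // safe: `path` no longer references `raw`
	return path;           // callers and lambda captures use the copy
}
```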
@@ -49,14 +52,15 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
obs_log(LOG_WARNING, "Whisper model does not exist");
download_model_with_ui_dialog(
model_info,
[gf, new_model_path, silero_vad_model_file](
[gf, new_model_path, silero_vad_model_file_str](
int download_status, const std::string &path) {
if (download_status == 0) {
obs_log(LOG_INFO,
"Model download complete");
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(
gf, path, silero_vad_model_file);
gf, path,
silero_vad_model_file_str.c_str());
} else {
obs_log(LOG_ERROR, "Model download failed");
}
@@ -65,7 +69,7 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
// Model exists, just load it
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(gf, model_file_found,
silero_vad_model_file);
silero_vad_model_file_str.c_str());
}
} else {
// new model is external file, get file location from file property
@@ -82,8 +86,9 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
} else {
shutdown_whisper_thread(gf);
gf->whisper_model_path = new_model_path;
start_whisper_thread_with_path(gf, external_model_file_path,
silero_vad_model_file);
start_whisper_thread_with_path(
gf, external_model_file_path,
silero_vad_model_file_str.c_str());
}
}
}
@@ -101,6 +106,7 @@ void update_whisper_model(struct transcription_filter_data *gf, obs_data_t *s)
gf->enable_token_ts_dtw, new_dtw_timestamps);
gf->enable_token_ts_dtw = obs_data_get_bool(s, "dtw_token_timestamps");
shutdown_whisper_thread(gf);
start_whisper_thread_with_path(gf, gf->whisper_model_path, silero_vad_model_file);
start_whisper_thread_with_path(gf, gf->whisper_model_path,
silero_vad_model_file_str.c_str());
}
}
@@ -5,6 +5,10 @@

#include <obs-module.h>

#ifdef _WIN32
#include <Windows.h>
#endif

void shutdown_whisper_thread(struct transcription_filter_data *gf)
{
obs_log(gf->log_level, "shutdown_whisper_thread");
@@ -27,7 +31,8 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf,
const std::string &whisper_model_path,
const char *silero_vad_model_file)
{
obs_log(gf->log_level, "start_whisper_thread_with_path: %s", whisper_model_path.c_str());
obs_log(gf->log_level, "start_whisper_thread_with_path: %s, silero model path: %s",
whisper_model_path.c_str(), silero_vad_model_file);
std::lock_guard<std::mutex> lock(gf->whisper_ctx_mutex);
if (gf->whisper_context != nullptr) {
obs_log(LOG_ERROR, "cannot init whisper: whisper_context is not null");
@@ -36,16 +41,22 @@ void start_whisper_thread_with_path(struct transcription_filter_data *gf,

// initialize Silero VAD
#ifdef _WIN32
std::wstring silero_vad_model_path;
silero_vad_model_path.assign(silero_vad_model_file,
silero_vad_model_file + strlen(silero_vad_model_file));
// convert mbstring to wstring
int count = MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file,
strlen(silero_vad_model_file), NULL, 0);
std::wstring silero_vad_model_path(count, 0);
MultiByteToWideChar(CP_UTF8, 0, silero_vad_model_file, strlen(silero_vad_model_file),
&silero_vad_model_path[0], count);
obs_log(gf->log_level, "Create silero VAD: %S", silero_vad_model_path.c_str());
#else
std::string silero_vad_model_path = silero_vad_model_file;
obs_log(gf->log_level, "Create silero VAD: %s", silero_vad_model_path.c_str());
#endif
// roughly following https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
// for silero vad parameters
gf->vad.reset(new VadIterator(silero_vad_model_path, WHISPER_SAMPLE_RATE));

obs_log(gf->log_level, "Create whisper context");
gf->whisper_context = init_whisper_context(whisper_model_path, gf);
if (gf->whisper_context == nullptr) {
obs_log(LOG_ERROR, "Failed to initialize whisper context");
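The old Windows path built the std::wstring by copying bytes one-to-one, which mangles any non-ASCII characters in the model path; the replacement converts UTF-8 to UTF-16 with the standard two-call MultiByteToWideChar pattern shown in the hunk. A self-contained sketch of that conversion (the helper name is illustrative):

```cpp
// Sketch of the UTF-8 to UTF-16 conversion used for the Silero VAD path on Windows.
#ifdef _WIN32
#include <Windows.h>
#include <cstring>
#include <string>

static std::wstring utf8_to_wstring(const char *utf8)
{
	const int len = (int)strlen(utf8);
	// first call: query the required number of wide characters
	const int count = MultiByteToWideChar(CP_UTF8, 0, utf8, len, NULL, 0);
	std::wstring wide(count, 0);
	// second call: perform the conversion into the allocated buffer
	MultiByteToWideChar(CP_UTF8, 0, utf8, len, &wide[0], count);
	return wide;
}
#endif
```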