More offline test improvements (#153)

* Protect logging with a mutex

Without this, output from the main thread and the worker thread could
get interleaved.

* Move segments.json saving to a different thread

This was taking a considerable amount of time, especially for longer
input files, reducing overall utilization.

* Check whether offline test can push more data before waiting

* Fix offline test with large files

In
```
circlebuf_push_back(
  &gf->input_buffers[c],
  audio[c].data() +
    frames_count * frame_size_bytes,
  frames_size_bytes);
```
`frames_count * frame_size_bytes` would overflow with `int` on
a 4-hour file; using `size_t` (on a 64-bit platform) fixes that.
This commit is contained in:
Ruwen Hahn 2024-08-14 15:28:33 +02:00 committed by GitHub
parent 6cc88b1ead
commit bdab41cafc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -46,6 +46,8 @@ void obs_log(int log_level, const char *format, ...)
auto diff = now - start; auto diff = now - start;
static std::mutex log_mutex;
auto lock = std::lock_guard(log_mutex);
// print timestamp // print timestamp
printf("[%02d:%02d:%02d.%03d] [%02d:%02lld.%03lld] ", now_tm.tm_hour, now_tm.tm_min, printf("[%02d:%02d:%02d.%03d] [%02d:%02lld.%03lld] ", now_tm.tm_hour, now_tm.tm_min,
now_tm.tm_sec, (int)(epoch.count() % 1000), now_tm.tm_sec, (int)(epoch.count() % 1000),
@ -194,6 +196,11 @@ create_context(int sample_rate, int channels, const std::string &whisper_model_p
return gf; return gf;
} }
std::mutex json_segments_input_mutex;
std::condition_variable json_segments_input_cv;
std::vector<nlohmann::json> json_segments_input;
bool json_segments_input_finished = false;
void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm32f_data, void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm32f_data,
size_t frames, int vad_state, const DetectionResultWithText &result) size_t frames, int vad_state, const DetectionResultWithText &result)
{ {
@ -214,25 +221,47 @@ void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm
// obs_log(gf->log_level, "Saving %lu frames to %s", frames, filename.c_str()); // obs_log(gf->log_level, "Saving %lu frames to %s", frames, filename.c_str());
// write_audio_wav_file(filename.c_str(), pcm32f_data, frames); // write_audio_wav_file(filename.c_str(), pcm32f_data, frames);
// append a row to the array in the segments.json file
std::string segments_filename = "segments.json";
nlohmann::json segments_json;
// Read existing segments from file
std::ifstream segments_file(segments_filename);
if (segments_file.is_open()) {
segments_file >> segments_json;
segments_file.close();
}
// Create a new segment object // Create a new segment object
nlohmann::json segment; nlohmann::json segment;
segment["start_time"] = result.start_timestamp_ms / 1000.0; segment["start_time"] = result.start_timestamp_ms / 1000.0;
segment["end_time"] = result.end_timestamp_ms / 1000.0; segment["end_time"] = result.end_timestamp_ms / 1000.0;
segment["segment_label"] = result.text; segment["segment_label"] = result.text;
{
auto lock = std::lock_guard(json_segments_input_mutex);
// Add the new segment to the segments array // Add the new segment to the segments array
segments_json.push_back(segment); json_segments_input.push_back(segment);
}
json_segments_input_cv.notify_one();
}
void json_segments_saver_thread_function()
{
std::string segments_filename = "segments.json";
nlohmann::json segments_json;
decltype(json_segments_input) json_segments_input_local;
for (;;) {
{
auto lock = std::unique_lock(json_segments_input_mutex);
while (json_segments_input.empty()) {
if (json_segments_input_finished)
return;
json_segments_input_cv.wait(lock, [&] {
return json_segments_input_finished ||
!json_segments_input.empty();
});
}
std::swap(json_segments_input, json_segments_input_local);
json_segments_input.clear();
}
for (auto &elem : json_segments_input_local) {
segments_json.push_back(std::move(elem));
}
// Write the updated segments back to the file // Write the updated segments back to the file
std::ofstream segments_file_out(segments_filename); std::ofstream segments_file_out(segments_filename);
@ -240,7 +269,8 @@ void audio_chunk_callback(struct transcription_filter_data *gf, const float *pcm
segments_file_out << std::setw(4) << segments_json << std::endl; segments_file_out << std::setw(4) << segments_json << std::endl;
segments_file_out.close(); segments_file_out.close();
} else { } else {
obs_log(gf->log_level, "Failed to open %s", segments_filename.c_str()); obs_log(LOG_INFO, "Failed to open %s", segments_filename.c_str());
}
} }
} }
@ -361,6 +391,7 @@ int wmain(int argc, wchar_t *argv[])
std::cout << "LocalVocal Offline Test" << std::endl; std::cout << "LocalVocal Offline Test" << std::endl;
transcription_filter_data *gf = nullptr; transcription_filter_data *gf = nullptr;
std::optional<std::thread> audio_chunk_saver_thread;
std::vector<std::vector<uint8_t>> audio = std::vector<std::vector<uint8_t>> audio =
read_audio_file(filenameStr.c_str(), [&](int sample_rate, int channels) { read_audio_file(filenameStr.c_str(), [&](int sample_rate, int channels) {
@ -419,6 +450,10 @@ int wmain(int argc, wchar_t *argv[])
return 1; return 1;
} }
if (gf->enable_audio_chunks_callback) {
audio_chunk_saver_thread.emplace(json_segments_saver_thread_function);
}
// truncate the output file // truncate the output file
obs_log(LOG_INFO, "Truncating output file"); obs_log(LOG_INFO, "Truncating output file");
std::ofstream output_file(gf->output_file_path, std::ios::trunc); std::ofstream output_file(gf->output_file_path, std::ios::trunc);
@ -437,10 +472,10 @@ int wmain(int argc, wchar_t *argv[])
obs_log(LOG_INFO, "Sending samples to whisper buffer"); obs_log(LOG_INFO, "Sending samples to whisper buffer");
// 25 ms worth of frames // 25 ms worth of frames
int frames = gf->sample_rate * window_size_in_ms.count() / 1000; size_t frames = gf->sample_rate * window_size_in_ms.count() / 1000;
const int frame_size_bytes = sizeof(float); const int frame_size_bytes = sizeof(float);
int frames_size_bytes = frames * frame_size_bytes; size_t frames_size_bytes = frames * frame_size_bytes;
int frames_count = 0; size_t frames_count = 0;
int64_t start_time = std::chrono::duration_cast<std::chrono::nanoseconds>( int64_t start_time = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::system_clock::now().time_since_epoch()) std::chrono::system_clock::now().time_since_epoch())
.count(); .count();
@ -464,12 +499,13 @@ int wmain(int argc, wchar_t *argv[])
if (false && now > max_wait) if (false && now > max_wait)
break; break;
gf->input_cv->wait_for(
lock, std::chrono::milliseconds(10), [&] {
return gf->input_buffers->size == 0;
});
if (gf->input_buffers->size == 0) if (gf->input_buffers->size == 0)
break; break;
gf->input_cv->wait_for(
lock, std::chrono::milliseconds(1), [&] {
return gf->input_buffers->size == 0;
});
} }
// push back current audio data to input circlebuf // push back current audio data to input circlebuf
for (size_t c = 0; c < gf->channels; c++) { for (size_t c = 0; c < gf->channels; c++) {
@ -533,6 +569,15 @@ int wmain(int argc, wchar_t *argv[])
} }
} }
if (audio_chunk_saver_thread.has_value()) {
{
auto lock = std::lock_guard(json_segments_input_mutex);
json_segments_input_finished = true;
}
json_segments_input_cv.notify_one();
audio_chunk_saver_thread->join();
}
release_context(gf); release_context(gf);
obs_log(LOG_INFO, "LocalVocal Offline Test Done"); obs_log(LOG_INFO, "LocalVocal Offline Test Done");