From 6de61f45bbb735d5812f8c9e00ac932e2fc41ec5 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sat, 3 Jun 2023 23:23:31 -0700 Subject: [PATCH] chore: mark thread safety [TAB-52] (#186) * mark thread safety * use shared_ptr to ensure thread safety * fmt --- Cargo.lock | 5 ++-- Cargo.toml | 1 + crates/ctranslate2-bindings/Cargo.toml | 1 + .../include/ctranslate2.h | 2 +- .../ctranslate2-bindings/src/ctranslate2.cc | 2 +- crates/ctranslate2-bindings/src/lib.rs | 29 +++++++++++-------- crates/tabby/Cargo.toml | 2 +- crates/tabby/src/serve/completions.rs | 2 +- 8 files changed, 26 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 73dc84178..5ace27147 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -490,6 +490,7 @@ dependencies = [ "derive_builder", "rust-cxx-cmake-bridge", "tokenizers", + "tokio", ] [[package]] @@ -2294,9 +2295,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.28.1" +version = "1.28.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0aa32867d44e6f2ce3385e89dceb990188b8bb0fb25b0cf576647a6f98ac5105" +checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2" dependencies = [ "autocfg", "bytes", diff --git a/Cargo.toml b/Cargo.toml index ac233434e..648891176 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,4 @@ homepage = "https://github.com/TabbyML/tabby" lazy_static = "1.4.0" serde = { version = "1.0", features = ["derive"] } serdeconv = "0.4.1" +tokio = "1.28" diff --git a/crates/ctranslate2-bindings/Cargo.toml b/crates/ctranslate2-bindings/Cargo.toml index 1b7e82075..511c9f575 100644 --- a/crates/ctranslate2-bindings/Cargo.toml +++ b/crates/ctranslate2-bindings/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" cxx = "1.0" derive_builder = "0.12.0" tokenizers = "0.13.3" +tokio = { workspace = true, features = ["rt"] } [build-dependencies] cxx-build = "1.0" diff --git a/crates/ctranslate2-bindings/include/ctranslate2.h b/crates/ctranslate2-bindings/include/ctranslate2.h index 514ba573b..9e127d8b1 100644 --- a/crates/ctranslate2-bindings/include/ctranslate2.h +++ b/crates/ctranslate2-bindings/include/ctranslate2.h @@ -16,7 +16,7 @@ class TextInferenceEngine { ) const = 0; }; -std::unique_ptr create_engine( +std::shared_ptr create_engine( rust::Str model_path, rust::Str model_type, rust::Str device, diff --git a/crates/ctranslate2-bindings/src/ctranslate2.cc b/crates/ctranslate2-bindings/src/ctranslate2.cc index 915e3673b..1191b81d8 100644 --- a/crates/ctranslate2-bindings/src/ctranslate2.cc +++ b/crates/ctranslate2-bindings/src/ctranslate2.cc @@ -77,7 +77,7 @@ class DecoderImpl: public TextInferenceEngine { std::unique_ptr generator_; }; -std::unique_ptr create_engine( +std::shared_ptr create_engine( rust::Str model_path, rust::Str model_type, rust::Str device, diff --git a/crates/ctranslate2-bindings/src/lib.rs b/crates/ctranslate2-bindings/src/lib.rs index 14649d1ae..55e10e768 100644 --- a/crates/ctranslate2-bindings/src/lib.rs +++ b/crates/ctranslate2-bindings/src/lib.rs @@ -16,7 +16,7 @@ mod ffi { device: &str, device_indices: &[i32], num_replicas_per_device: usize, - ) -> UniquePtr; + ) -> SharedPtr; fn inference( &self, @@ -28,6 +28,9 @@ mod ffi { } } +unsafe impl Send for ffi::TextInferenceEngine {} +unsafe impl Sync for ffi::TextInferenceEngine {} + #[derive(Builder, Debug)] pub struct TextInferenceEngineCreateOptions { model_path: String, @@ -56,13 +59,10 @@ pub struct TextInferenceOptions { } pub struct TextInferenceEngine { - engine: cxx::UniquePtr, + engine: cxx::SharedPtr, tokenizer: Tokenizer, } -unsafe impl Send for TextInferenceEngine {} -unsafe impl Sync for TextInferenceEngine {} - impl TextInferenceEngine { pub fn create(options: TextInferenceEngineCreateOptions) -> Self where { let engine = ffi::create_engine( @@ -78,14 +78,19 @@ impl TextInferenceEngine { }; } - pub fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String { + pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String { let encoding = self.tokenizer.encode(prompt, true).unwrap(); - let output_tokens = self.engine.inference( - encoding.get_tokens(), - options.max_decoding_length, - options.sampling_temperature, - options.beam_size, - ); + let engine = self.engine.clone(); + let output_tokens = tokio::task::spawn_blocking(move || { + engine.inference( + encoding.get_tokens(), + options.max_decoding_length, + options.sampling_temperature, + options.beam_size, + ) + }) + .await + .expect("Inference failed"); let output_ids: Vec = output_tokens .iter() .filter_map(|x| match self.tokenizer.token_to_id(x) { diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index ea0118f70..e0aa9ddd9 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] axum = "0.6" hyper = { version = "0.14", features = ["full"] } -tokio = { version = "1.17", features = ["full"] } +tokio = { workspace = true, features = ["full"] } tower = "0.4" utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] } utoipa-swagger-ui = { version = "3.1", features = ["axum"] } diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 4ed567ff3..bcff23d59 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -80,7 +80,7 @@ pub async fn completion( request.prompt.expect("No prompt is set") }; - let text = state.engine.inference(&prompt, options); + let text = state.engine.inference(&prompt, options).await; let language = request.language.unwrap_or("unknown".into()); let filtered_text = languages::remove_stop_words(&language, &text);