chore: mark thread safety [TAB-52] (#186)

* mark thread safety

* use shared_ptr to ensure thread safety

* fmt
Meng Zhang 2023-06-03 23:23:31 -07:00 committed by GitHub
parent 775576b53e
commit 6de61f45bb
8 changed files with 26 additions and 18 deletions
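
In short: the C++ inference engine is now handed to Rust as a reference-counted shared pointer, marked Send + Sync, and each generation runs on tokio's blocking thread pool so async worker threads are never stalled by the C++ call. A minimal, self-contained sketch of that pattern is below; it assumes tokio with the "full" feature set, and `Engine` / `blocking_generate` are placeholder names rather than the tabby API:

```rust
use std::sync::Arc;

// Placeholder for the C++-backed engine; the real type in this commit is an
// opaque ffi::TextInferenceEngine held behind a cxx::SharedPtr.
struct Engine;

impl Engine {
    // Stand-in for the blocking CTranslate2 generation call.
    fn blocking_generate(&self, prompt: &str) -> String {
        format!("completion for: {prompt}")
    }
}

// Move the cheap, reference-counted handle onto tokio's blocking thread pool
// so async worker threads stay responsive while generation runs.
async fn generate(engine: Arc<Engine>, prompt: String) -> String {
    tokio::task::spawn_blocking(move || engine.blocking_generate(&prompt))
        .await
        .expect("blocking task panicked")
}

#[tokio::main]
async fn main() {
    let engine = Arc::new(Engine);
    println!("{}", generate(engine.clone(), "fn main() {".to_string()).await);
}
```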

Cargo.lock (generated)

@@ -490,6 +490,7 @@ dependencies = [
"derive_builder",
"rust-cxx-cmake-bridge",
"tokenizers",
+ "tokio",
]
[[package]]
@@ -2294,9 +2295,9 @@ dependencies = [
[[package]]
name = "tokio"
- version = "1.28.1"
+ version = "1.28.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "0aa32867d44e6f2ce3385e89dceb990188b8bb0fb25b0cf576647a6f98ac5105"
+ checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2"
dependencies = [
"autocfg",
"bytes",


@@ -16,3 +16,4 @@ homepage = "https://github.com/TabbyML/tabby"
lazy_static = "1.4.0"
serde = { version = "1.0", features = ["derive"] }
serdeconv = "0.4.1"
+ tokio = "1.28"


@@ -7,6 +7,7 @@ edition = "2021"
cxx = "1.0"
derive_builder = "0.12.0"
tokenizers = "0.13.3"
+ tokio = { workspace = true, features = ["rt"] }
[build-dependencies]
cxx-build = "1.0"


@@ -16,7 +16,7 @@ class TextInferenceEngine {
) const = 0;
};
- std::unique_ptr<TextInferenceEngine> create_engine(
+ std::shared_ptr<TextInferenceEngine> create_engine(
rust::Str model_path,
rust::Str model_type,
rust::Str device,


@@ -77,7 +77,7 @@ class DecoderImpl: public TextInferenceEngine {
std::unique_ptr<ctranslate2::Generator> generator_;
};
- std::unique_ptr<TextInferenceEngine> create_engine(
+ std::shared_ptr<TextInferenceEngine> create_engine(
rust::Str model_path,
rust::Str model_type,
rust::Str device,
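
Returning std::shared_ptr<TextInferenceEngine> instead of std::unique_ptr from create_engine is what surfaces on the Rust side as cxx::SharedPtr: unlike a UniquePtr, the shared handle can be cloned cheaply and kept alive by a background task, which the Rust changes below rely on.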


@@ -16,7 +16,7 @@ mod ffi {
device: &str,
device_indices: &[i32],
num_replicas_per_device: usize,
- ) -> UniquePtr<TextInferenceEngine>;
+ ) -> SharedPtr<TextInferenceEngine>;
fn inference(
&self,
@@ -28,6 +28,9 @@ mod ffi {
}
}
+ unsafe impl Send for ffi::TextInferenceEngine {}
+ unsafe impl Sync for ffi::TextInferenceEngine {}
#[derive(Builder, Debug)]
pub struct TextInferenceEngineCreateOptions {
model_path: String,
@@ -56,13 +59,10 @@ pub struct TextInferenceOptions {
}
pub struct TextInferenceEngine {
- engine: cxx::UniquePtr<ffi::TextInferenceEngine>,
+ engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
tokenizer: Tokenizer,
}
- unsafe impl Send for TextInferenceEngine {}
- unsafe impl Sync for TextInferenceEngine {}
impl TextInferenceEngine {
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
let engine = ffi::create_engine(
@@ -78,14 +78,19 @@ impl TextInferenceEngine {
};
}
- pub fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
+ pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
let encoding = self.tokenizer.encode(prompt, true).unwrap();
- let output_tokens = self.engine.inference(
- encoding.get_tokens(),
- options.max_decoding_length,
- options.sampling_temperature,
- options.beam_size,
- );
+ let engine = self.engine.clone();
+ let output_tokens = tokio::task::spawn_blocking(move || {
+ engine.inference(
+ encoding.get_tokens(),
+ options.max_decoding_length,
+ options.sampling_temperature,
+ options.beam_size,
+ )
+ })
+ .await
+ .expect("Inference failed");
let output_ids: Vec<u32> = output_tokens
.iter()
.filter_map(|x| match self.tokenizer.token_to_id(x) {
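
A note on the Send/Sync markers above: extern C++ types are opaque to Rust, so they never get the automatic Send/Sync impls, and spawn_blocking requires everything its closure captures to be Send + 'static. Marking ffi::TextInferenceEngine as Send + Sync is therefore what makes the cloned SharedPtr handle movable into the blocking task; it is also an unchecked promise that the underlying CTranslate2 engine tolerates calls from other threads. A rough, std-only illustration with placeholder names:

```rust
use std::sync::Arc;

// Stand-in for the opaque C++ engine: the raw pointer suppresses the automatic
// Send/Sync impls, much like an extern C++ type. The `unsafe impl`s mirror what
// this commit adds for ffi::TextInferenceEngine and promise that the C++ object
// really is safe to use from other threads.
struct OpaqueEngine(*mut std::ffi::c_void);
unsafe impl Send for OpaqueEngine {}
unsafe impl Sync for OpaqueEngine {}

// spawn_blocking needs its closure, and everything it captures, to be
// Send + 'static; a shared handle is only Send when the pointee is Send + Sync.
fn requires_send<T: Send + 'static>(_value: T) {}

fn main() {
    let engine = Arc::new(OpaqueEngine(std::ptr::null_mut()));
    requires_send(engine.clone()); // would not compile without the impls above
}
```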


@@ -6,7 +6,7 @@ edition = "2021"
[dependencies]
axum = "0.6"
hyper = { version = "0.14", features = ["full"] }
- tokio = { version = "1.17", features = ["full"] }
+ tokio = { workspace = true, features = ["full"] }
tower = "0.4"
utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] }
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }


@@ -80,7 +80,7 @@ pub async fn completion(
request.prompt.expect("No prompt is set")
};
- let text = state.engine.inference(&prompt, options);
+ let text = state.engine.inference(&prompt, options).await;
let language = request.language.unwrap_or("unknown".into());
let filtered_text = languages::remove_stop_words(&language, &text);
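
Since completion() is already an async axum handler, the call site only needs the added .await; while a generation runs on the blocking pool, the tokio worker that accepted the request is no longer pinned for the whole generation and can make progress on other tasks.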