mirror of
https://github.com/TabbyML/tabby
synced 2024-11-22 08:21:59 +00:00
chore: mark thread safety [TAB-52] (#186)
* mark thread safety * use shared_ptr to ensure thread safety * fmt
This commit is contained in:
parent
775576b53e
commit
6de61f45bb
5
Cargo.lock
generated
5
Cargo.lock
generated
@ -490,6 +490,7 @@ dependencies = [
|
||||
"derive_builder",
|
||||
"rust-cxx-cmake-bridge",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2294,9 +2295,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.28.1"
|
||||
version = "1.28.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0aa32867d44e6f2ce3385e89dceb990188b8bb0fb25b0cf576647a6f98ac5105"
|
||||
checksum = "94d7b1cfd2aa4011f2de74c2c4c63665e27a71006b0a192dcd2710272e73dfa2"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"bytes",
|
||||
|
@ -16,3 +16,4 @@ homepage = "https://github.com/TabbyML/tabby"
|
||||
lazy_static = "1.4.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serdeconv = "0.4.1"
|
||||
tokio = "1.28"
|
||||
|
@ -7,6 +7,7 @@ edition = "2021"
|
||||
cxx = "1.0"
|
||||
derive_builder = "0.12.0"
|
||||
tokenizers = "0.13.3"
|
||||
tokio = { workspace = true, features = ["rt"] }
|
||||
|
||||
[build-dependencies]
|
||||
cxx-build = "1.0"
|
||||
|
@ -16,7 +16,7 @@ class TextInferenceEngine {
|
||||
) const = 0;
|
||||
};
|
||||
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||
std::shared_ptr<TextInferenceEngine> create_engine(
|
||||
rust::Str model_path,
|
||||
rust::Str model_type,
|
||||
rust::Str device,
|
||||
|
@ -77,7 +77,7 @@ class DecoderImpl: public TextInferenceEngine {
|
||||
std::unique_ptr<ctranslate2::Generator> generator_;
|
||||
};
|
||||
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||
std::shared_ptr<TextInferenceEngine> create_engine(
|
||||
rust::Str model_path,
|
||||
rust::Str model_type,
|
||||
rust::Str device,
|
||||
|
@ -16,7 +16,7 @@ mod ffi {
|
||||
device: &str,
|
||||
device_indices: &[i32],
|
||||
num_replicas_per_device: usize,
|
||||
) -> UniquePtr<TextInferenceEngine>;
|
||||
) -> SharedPtr<TextInferenceEngine>;
|
||||
|
||||
fn inference(
|
||||
&self,
|
||||
@ -28,6 +28,9 @@ mod ffi {
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl Send for ffi::TextInferenceEngine {}
|
||||
unsafe impl Sync for ffi::TextInferenceEngine {}
|
||||
|
||||
#[derive(Builder, Debug)]
|
||||
pub struct TextInferenceEngineCreateOptions {
|
||||
model_path: String,
|
||||
@ -56,13 +59,10 @@ pub struct TextInferenceOptions {
|
||||
}
|
||||
|
||||
pub struct TextInferenceEngine {
|
||||
engine: cxx::UniquePtr<ffi::TextInferenceEngine>,
|
||||
engine: cxx::SharedPtr<ffi::TextInferenceEngine>,
|
||||
tokenizer: Tokenizer,
|
||||
}
|
||||
|
||||
unsafe impl Send for TextInferenceEngine {}
|
||||
unsafe impl Sync for TextInferenceEngine {}
|
||||
|
||||
impl TextInferenceEngine {
|
||||
pub fn create(options: TextInferenceEngineCreateOptions) -> Self where {
|
||||
let engine = ffi::create_engine(
|
||||
@ -78,14 +78,19 @@ impl TextInferenceEngine {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
|
||||
pub async fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
|
||||
let encoding = self.tokenizer.encode(prompt, true).unwrap();
|
||||
let output_tokens = self.engine.inference(
|
||||
encoding.get_tokens(),
|
||||
options.max_decoding_length,
|
||||
options.sampling_temperature,
|
||||
options.beam_size,
|
||||
);
|
||||
let engine = self.engine.clone();
|
||||
let output_tokens = tokio::task::spawn_blocking(move || {
|
||||
engine.inference(
|
||||
encoding.get_tokens(),
|
||||
options.max_decoding_length,
|
||||
options.sampling_temperature,
|
||||
options.beam_size,
|
||||
)
|
||||
})
|
||||
.await
|
||||
.expect("Inference failed");
|
||||
let output_ids: Vec<u32> = output_tokens
|
||||
.iter()
|
||||
.filter_map(|x| match self.tokenizer.token_to_id(x) {
|
||||
|
@ -6,7 +6,7 @@ edition = "2021"
|
||||
[dependencies]
|
||||
axum = "0.6"
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
tokio = { version = "1.17", features = ["full"] }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
tower = "0.4"
|
||||
utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] }
|
||||
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
|
||||
|
@ -80,7 +80,7 @@ pub async fn completion(
|
||||
request.prompt.expect("No prompt is set")
|
||||
};
|
||||
|
||||
let text = state.engine.inference(&prompt, options);
|
||||
let text = state.engine.inference(&prompt, options).await;
|
||||
let language = request.language.unwrap_or("unknown".into());
|
||||
let filtered_text = languages::remove_stop_words(&language, &text);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user