add ctranslate2-bindings / tabby rust packages (#146)
* add ctranslate2-bindings
* add fixme for linux build
* turn off shared lib
* add tabby-cli
This commit is contained in:
parent c08f5acf26
commit a2476af373
.gitmodules (vendored, new file, +3)
@@ -0,0 +1,3 @@
+[submodule "crates/ctranslate2-bindings/CTranslate2"]
+	path = crates/ctranslate2-bindings/CTranslate2
+	url = https://github.com/OpenNMT/CTranslate2.git
crates/ctranslate2-bindings/.gitignore (vendored, new file, +2)
@@ -0,0 +1,2 @@
+/target
+/Cargo.lock
crates/ctranslate2-bindings/Cargo.toml (new file, +14)
@@ -0,0 +1,14 @@
+[package]
+name = "ctranslate2-bindings"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+cxx = "1.0"
+derive_builder = "0.12.0"
+tokenizers = "0.13.3"
+
+[build-dependencies]
+bindgen = "0.53.1"
+cxx-build = "1.0"
+cmake = "0.1"
crates/ctranslate2-bindings/build.rs (new file, +32)
@@ -0,0 +1,32 @@
+use cmake::Config;
+
+fn main() {
+    let dst = Config::new("CTranslate2")
+        // Default flags.
+        .define("CMAKE_BUILD_TYPE", "Release")
+        .define("BUILD_CLI", "OFF")
+        .define("CMAKE_INSTALL_RPATH_USE_LINK_PATH", "ON")
+
+        // FIXME(meng): support linux build.
+        // OSX flags.
+        .define("CMAKE_OSX_ARCHITECTURES", "arm64")
+        .define("WITH_ACCELERATE", "ON")
+        .define("WITH_MKL", "OFF")
+        .define("OPENMP_RUNTIME", "NONE")
+        .define("WITH_RUY", "ON")
+        .build();
+
+    println!("cargo:rustc-link-search=native={}", dst.join("lib").display());
+    println!("cargo:rustc-link-lib=ctranslate2");
+
+    // Tell cargo to invalidate the built crate whenever the wrapper changes.
+    println!("cargo:rerun-if-changed=include/ctranslate2.h");
+    println!("cargo:rerun-if-changed=src/ctranslate2.cc");
+    println!("cargo:rerun-if-changed=src/lib.rs");
+
+    cxx_build::bridge("src/lib.rs")
+        .file("src/ctranslate2.cc")
+        .flag_if_supported("-std=c++17")
+        .flag_if_supported(&format!("-I{}", dst.join("include").display()))
+        .compile("cxxbridge");
+}
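The FIXME above leaves the Linux build path open. A possible shape for the platform switch, shown as a sketch only: the macOS branch mirrors the flags in this commit, while the Linux flag choices are assumptions and not part of the commit.

use cmake::Config;

// Sketch: split the hard-coded OSX flags by target OS.
fn configure() -> Config {
    let mut config = Config::new("CTranslate2");
    config
        .define("CMAKE_BUILD_TYPE", "Release")
        .define("BUILD_CLI", "OFF")
        .define("CMAKE_INSTALL_RPATH_USE_LINK_PATH", "ON");
    if cfg!(target_os = "macos") {
        // Same flags as the build.rs above.
        config
            .define("CMAKE_OSX_ARCHITECTURES", "arm64")
            .define("WITH_ACCELERATE", "ON");
    } else {
        // Assumed Linux choices: no Accelerate, keep ruy as the CPU
        // backend. Real values would need testing on a Linux host.
        config.define("WITH_RUY", "ON");
    }
    config
}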
crates/ctranslate2-bindings/CTranslate2 (new submodule, +1)
@@ -0,0 +1 @@
+Subproject commit 692fb607ab67573fa5cf6e410aec24e8655844f8
crates/ctranslate2-bindings/include/ctranslate2.h (new file, +19)
@@ -0,0 +1,19 @@
+#pragma once
+
+#include "rust/cxx.h"
+
+namespace tabby {
+
+class TextInferenceEngine {
+ public:
+  virtual ~TextInferenceEngine();
+  virtual rust::Vec<rust::String> inference(
+      rust::Slice<const rust::String> tokens,
+      size_t max_decoding_length,
+      float sampling_temperature,
+      size_t beam_size
+  ) const = 0;
+};
+
+std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
+}  // namespace tabby
crates/ctranslate2-bindings/src/ctranslate2.cc (new file, +47)
@@ -0,0 +1,47 @@
+#include "ctranslate2-bindings/include/ctranslate2.h"
+
+#include "ctranslate2/translator.h"
+
+namespace tabby {
+TextInferenceEngine::~TextInferenceEngine() {}
+
+class TextInferenceEngineImpl : public TextInferenceEngine {
+ public:
+  TextInferenceEngineImpl(const std::string& model_path) {
+    ctranslate2::models::ModelLoader loader(model_path);
+    translator_ = std::make_unique<ctranslate2::Translator>(loader);
+  }
+
+  ~TextInferenceEngineImpl() {}
+
+  rust::Vec<rust::String> inference(
+      rust::Slice<const rust::String> tokens,
+      size_t max_decoding_length,
+      float sampling_temperature,
+      size_t beam_size
+  ) const {
+    // Create options.
+    ctranslate2::TranslationOptions options;
+    options.max_decoding_length = max_decoding_length;
+    options.sampling_temperature = sampling_temperature;
+    options.beam_size = beam_size;
+
+    // Inference.
+    std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
+    ctranslate2::TranslationResult result = translator_->translate_batch({ input_tokens }, options)[0];
+    const auto& output_tokens = result.output();
+
+    // Convert to rust vec.
+    rust::Vec<rust::String> output;
+    output.reserve(output_tokens.size());
+    std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
+    return output;
+  }
+
+ private:
+  std::unique_ptr<ctranslate2::Translator> translator_;
+};
+
+std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
+  return std::make_unique<TextInferenceEngineImpl>(std::string(model_path));
+}
+}  // namespace tabby
crates/ctranslate2-bindings/src/lib.rs (new file, +69)
@@ -0,0 +1,69 @@
+use std::sync::Mutex;
+use tokenizers::tokenizer::{Model, Tokenizer};
+
+#[macro_use]
+extern crate derive_builder;
+
+#[cxx::bridge(namespace = "tabby")]
+mod ffi {
+    unsafe extern "C++" {
+        include!("ctranslate2-bindings/include/ctranslate2.h");
+
+        type TextInferenceEngine;
+
+        fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
+        fn inference(
+            &self,
+            tokens: &[String],
+            max_decoding_length: usize,
+            sampling_temperature: f32,
+            beam_size: usize,
+        ) -> Vec<String>;
+    }
+}
+
+#[derive(Builder, Debug)]
+pub struct TextInferenceOptions {
+    #[builder(default = "256")]
+    max_decoding_length: usize,
+
+    #[builder(default = "1.0")]
+    sampling_temperature: f32,
+
+    #[builder(default = "2")]
+    beam_size: usize,
+}
+
+pub struct TextInferenceEngine {
+    engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
+    tokenizer: Tokenizer,
+}
+
+unsafe impl Send for TextInferenceEngine {}
+unsafe impl Sync for TextInferenceEngine {}
+
+impl TextInferenceEngine {
+    pub fn create(model_path: &str, tokenizer_path: &str) -> Self {
+        TextInferenceEngine {
+            engine: Mutex::new(ffi::create_engine(model_path)),
+            tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
+        }
+    }
+
+    pub fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
+        let encoding = self.tokenizer.encode(prompt, true).unwrap();
+        let output_tokens = self.engine.lock().unwrap().inference(
+            encoding.get_tokens(),
+            options.max_decoding_length,
+            options.sampling_temperature,
+            options.beam_size,
+        );
+
+        let model = self.tokenizer.get_model();
+        let output_ids: Vec<u32> = output_tokens
+            .iter()
+            .map(|x| model.token_to_id(x).unwrap())
+            .collect();
+        self.tokenizer.decode(output_ids, true).unwrap()
+    }
+}
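Taken together, a caller drives these bindings roughly as follows. This is a usage sketch: the model and tokenizer paths are placeholders (the real layout used by this commit appears in completions.rs below), and error handling is elided as in the crate itself.

use ctranslate2_bindings::{TextInferenceEngine, TextInferenceOptionsBuilder};

fn demo() {
    // Placeholder paths; completions.rs builds them from a model dir
    // containing "cpu/" and "tokenizer.json".
    let engine = TextInferenceEngine::create(
        "/path/to/model/cpu",
        "/path/to/model/tokenizer.json",
    );
    let options = TextInferenceOptionsBuilder::default()
        .max_decoding_length(64)
        .sampling_temperature(0.2)
        .build()
        .unwrap();
    let text = engine.inference("def fib(n):", options);
    println!("{}", text);
}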
crates/tabby/.gitignore (vendored, new file, +1)
@@ -0,0 +1 @@
+/target
crates/tabby/Cargo.lock (generated, new file, +2901)
File diff suppressed because it is too large
crates/tabby/Cargo.toml (new file, +29)
@@ -0,0 +1,29 @@
+[package]
+name = "tabby"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+axum = "0.6"
+hyper = { version = "0.14", features = ["full"] }
+tokio = { version = "1.17", features = ["full"] }
+tower = "0.4"
+utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] }
+utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+env_logger = "0.10.0"
+log = "0.4"
+ctranslate2-bindings = { path = "../ctranslate2-bindings" }
+tower-http = { version = "0.4.0", features = ["cors"] }
+clap = { version = "4.3.0", features = ["derive"] }
+regex = "1.8.3"
+lazy_static = "1.4.0"
+
+[dependencies.uuid]
+version = "1.3.3"
+features = [
+    "v4",                # Lets you generate random UUIDs
+    "fast-rng",          # Use a faster (but still sufficiently random) RNG
+    "macro-diagnostics", # Enable better diagnostics for compile-time UUIDs
+]
crates/tabby/src/main.rs (new file, +36)
@@ -0,0 +1,36 @@
+use clap::{Parser, Subcommand};
+
+#[derive(Parser)]
+#[command(author, version, about, long_about = None)]
+#[command(propagate_version = true)]
+struct Cli {
+    #[command(subcommand)]
+    command: Commands,
+}
+
+#[derive(Subcommand)]
+pub enum Commands {
+    /// Serve the model
+    Serve {
+        /// Path to the model to serve
+        #[clap(long)]
+        model: String,
+    },
+}
+
+mod serve;
+
+#[tokio::main]
+async fn main() {
+    let cli = Cli::parse();
+
+    // Dispatch on the parsed subcommand.
+    match &cli.command {
+        Commands::Serve { model } => {
+            serve::main(model)
+                .await
+                .expect("Error happened during serve");
+        }
+    }
+}
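Given the clap derive above, the resulting invocation has the form "tabby serve --model <path>": the subcommand and flag names follow from the Subcommand variant and the #[clap(long)] attribute, and the binary name comes from Cargo.toml.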
crates/tabby/src/serve/completions.rs (new file, +81)
@@ -0,0 +1,81 @@
+use axum::{extract::State, Json};
+use ctranslate2_bindings::{TextInferenceEngine, TextInferenceOptionsBuilder};
+use serde::{Deserialize, Serialize};
+use std::{path::Path, sync::Arc};
+use utoipa::ToSchema;
+
+mod languages;
+
+#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
+pub struct CompletionRequest {
+    /// Language identifier, as defined at
+    /// https://code.visualstudio.com/docs/languages/identifiers
+    #[schema(example = "python")]
+    language: String,
+
+    #[schema(example = "def fib(n):")]
+    prompt: String,
+}
+
+#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
+pub struct Choice {
+    index: u32,
+    text: String,
+}
+
+#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
+pub struct CompletionResponse {
+    id: String,
+    created: u64,
+    choices: Vec<Choice>,
+}
+
+#[utoipa::path(
+    post,
+    path = "/v1/completions",
+    request_body = CompletionRequest,
+)]
+pub async fn completion(
+    State(state): State<Arc<CompletionState>>,
+    Json(request): Json<CompletionRequest>,
+) -> Json<CompletionResponse> {
+    let options = TextInferenceOptionsBuilder::default()
+        .max_decoding_length(64)
+        .sampling_temperature(0.2)
+        .build()
+        .unwrap();
+    let text = state.engine.inference(&request.prompt, options);
+    let filtered_text = languages::remove_stop_words(&request.language, &text);
+
+    Json(CompletionResponse {
+        id: format!("cmpl-{}", uuid::Uuid::new_v4()),
+        created: timestamp(),
+        choices: vec![Choice {
+            index: 0,
+            text: filtered_text.to_string(),
+        }],
+    })
+}
+
+pub struct CompletionState {
+    engine: TextInferenceEngine,
+}
+
+impl CompletionState {
+    pub fn new(model: &str) -> Self {
+        let engine = TextInferenceEngine::create(
+            Path::new(model).join("cpu").to_str().unwrap(),
+            Path::new(model).join("tokenizer.json").to_str().unwrap(),
+        );
+        Self { engine }
+    }
+}
+
+fn timestamp() -> u64 {
+    use std::time::{SystemTime, UNIX_EPOCH};
+    SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .expect("Time went backwards")
+        .as_secs()
+}
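For reference, the bodies this endpoint exchanges look like the following sketch, built from the #[schema(example)] values above; the id, timestamp, and completion text are illustrative, not real output. It uses the serde_json dependency already declared in Cargo.toml.

// Illustrative wire format for POST /v1/completions; values are made up.
let request = serde_json::json!({
    "language": "python",
    "prompt": "def fib(n):"
});
let response = serde_json::json!({
    "id": "cmpl-…",        // "cmpl-" plus a random v4 UUID
    "created": 1685000000, // seconds since the unix epoch
    "choices": [{ "index": 0, "text": "\n    if n <= 1:\n        return n" }]
});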
crates/tabby/src/serve/completions/languages.rs (new file, +25)
@@ -0,0 +1,25 @@
+use lazy_static::lazy_static;
+use regex::Regex;
+use std::collections::HashMap;
+
+lazy_static! {
+    static ref DEFAULT: Regex = Regex::new(r"(?m)^\n\n").unwrap();
+    static ref LANGUAGES: HashMap<&'static str, Regex> = {
+        let mut map = HashMap::new();
+        map.insert(
+            "python",
+            Regex::new(r"(?m)^(\n\n|def|#|from|class)").unwrap(),
+        );
+        map
+    };
+}
+
+pub fn remove_stop_words<'a>(language: &'a str, text: &'a str) -> &'a str {
+    let re = LANGUAGES.get(language).unwrap_or(&DEFAULT);
+    if let Some(m) = re.find(text) {
+        &text[..m.start()]
+    } else {
+        text
+    }
+}
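A small illustration of the truncation rule (a hypothetical check, not part of the commit): for Python, generated text is cut at the first line that starts a new top-level construct.

let text = "    return fib(n - 1) + fib(n - 2)\ndef main():\n    pass";
assert_eq!(
    remove_stop_words("python", text),
    "    return fib(n - 1) + fib(n - 2)\n"
);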
crates/tabby/src/serve/events.rs (new file, +22)
@@ -0,0 +1,22 @@
+use axum::Json;
+use hyper::StatusCode;
+use serde::{Deserialize, Serialize};
+use utoipa::ToSchema;
+
+#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
+pub struct LogEventRequest {
+    #[serde(rename = "type")]
+    event_type: String,
+    completion_id: String,
+    choice_index: u32,
+}
+
+#[utoipa::path(
+    post,
+    path = "/v1/events",
+    request_body = LogEventRequest,
+)]
+pub async fn log_event(Json(request): Json<LogEventRequest>) -> StatusCode {
+    println!("log_event: {:?}", request);
+    StatusCode::OK
+}
crates/tabby/src/serve/mod.rs (new file, +44)
@@ -0,0 +1,44 @@
+use std::{
+    net::{Ipv4Addr, SocketAddr},
+    sync::Arc,
+};
+
+use axum::{response::Redirect, routing, Router, Server};
+use hyper::Error;
+use tower_http::cors::CorsLayer;
+use utoipa::OpenApi;
+use utoipa_swagger_ui::SwaggerUi;
+
+mod completions;
+mod events;
+
+#[derive(OpenApi)]
+#[openapi(
+    paths(events::log_event, completions::completion),
+    components(schemas(
+        events::LogEventRequest,
+        completions::CompletionRequest,
+        completions::CompletionResponse,
+        completions::Choice
+    ))
+)]
+struct ApiDoc;
+
+pub async fn main(model: &str) -> Result<(), Error> {
+    let completions_state = Arc::new(completions::CompletionState::new(model));
+
+    let app = Router::new()
+        .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
+        .route("/v1/events", routing::post(events::log_event))
+        .route("/v1/completions", routing::post(completions::completion))
+        .with_state(completions_state)
+        .route(
+            "/",
+            routing::get(|| async { Redirect::temporary("/swagger-ui") }),
+        )
+        .layer(CorsLayer::permissive());
+
+    let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 8080));
+    println!("Listening at {}", address);
+    Server::bind(&address).serve(app.into_make_service()).await
+}