add ctranslate2-bindings / tabby rust packages (#146)

* add ctranslate2-bindings

* add fixme for linux build

* turn off shared lib

* add tabby-cli
Meng Zhang 2023-05-25 14:05:28 -07:00 committed by GitHub
parent c08f5acf26
commit a2476af373
16 changed files with 3326 additions and 0 deletions

.gitmodules vendored Normal file

@@ -0,0 +1,3 @@
[submodule "crates/ctranslate2-bindings/CTranslate2"]
path = crates/ctranslate2-bindings/CTranslate2
url = https://github.com/OpenNMT/CTranslate2.git

crates/ctranslate2-bindings/.gitignore vendored Normal file

@@ -0,0 +1,2 @@
/target
/Cargo.lock

crates/ctranslate2-bindings/Cargo.toml Normal file

@@ -0,0 +1,14 @@
[package]
name = "ctranslate2-bindings"
version = "0.1.0"
edition = "2021"
[dependencies]
cxx = "1.0"
derive_builder = "0.12.0"
tokenizers = "0.13.3"
[build-dependencies]
bindgen = "0.53.1"
cxx-build = "1.0"
cmake = "0.1"

crates/ctranslate2-bindings/build.rs Normal file

@@ -0,0 +1,32 @@
use cmake::Config;
fn main() {
let dst = Config::new("CTranslate2")
// Default flags.
.define("CMAKE_BUILD_TYPE", "Release")
.define("BUILD_CLI", "OFF")
.define("CMAKE_INSTALL_RPATH_USE_LINK_PATH", "ON")
// FIXME(meng): support linux build.
// OSX flags.
.define("CMAKE_OSX_ARCHITECTURES", "arm64")
.define("WITH_ACCELERATE", "ON")
.define("WITH_MKL", "OFF")
.define("OPENMP_RUNTIME", "NONE")
.define("WITH_RUY", "ON")
.build();
println!("cargo:rustc-link-search=native={}", dst.join("lib").display());
println!("cargo:rustc-link-lib=ctranslate2");
// Tell cargo to invalidate the built crate whenever the wrapper changes
println!("cargo:rerun-if-changed=include/ctranslate2.h");
println!("cargo:rerun-if-changed=src/ctranslate2.cc");
println!("cargo:rerun-if-changed=src/lib.rs");
cxx_build::bridge("src/lib.rs")
.file("src/ctranslate2.cc")
.flag_if_supported("-std=c++17")
.flag_if_supported(&format!("-I{}", dst.join("include").display()))
.compile("cxxbridge");
}

@@ -0,0 +1 @@
Subproject commit 692fb607ab67573fa5cf6e410aec24e8655844f8

crates/ctranslate2-bindings/include/ctranslate2.h Normal file

@@ -0,0 +1,19 @@
#pragma once
#include "rust/cxx.h"
namespace tabby {
class TextInferenceEngine {
public:
virtual ~TextInferenceEngine();
virtual rust::Vec<rust::String> inference(
rust::Slice<const rust::String> tokens,
size_t max_decoding_length,
float sampling_temperature,
size_t beam_size
) const = 0;
};
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path);
} // namespace

crates/ctranslate2-bindings/src/ctranslate2.cc Normal file

@@ -0,0 +1,47 @@
#include "ctranslate2-bindings/include/ctranslate2.h"
#include "ctranslate2/translator.h"
namespace tabby {
TextInferenceEngine::~TextInferenceEngine() {}
class TextInferenceEngineImpl : public TextInferenceEngine {
public:
TextInferenceEngineImpl(const std::string& model_path) {
ctranslate2::models::ModelLoader loader(model_path);
translator_ = std::make_unique<ctranslate2::Translator>(loader);
}
~TextInferenceEngineImpl() {}
rust::Vec<rust::String> inference(
rust::Slice<const rust::String> tokens,
size_t max_decoding_length,
float sampling_temperature,
size_t beam_size
) const {
// Create options.
ctranslate2::TranslationOptions options;
options.max_decoding_length = max_decoding_length;
options.sampling_temperature = sampling_temperature;
options.beam_size = beam_size;
// Inference.
std::vector<std::string> input_tokens(tokens.begin(), tokens.end());
ctranslate2::TranslationResult result = translator_->translate_batch({ input_tokens }, options)[0];
const auto& output_tokens = result.output();
// Convert to rust vec.
rust::Vec<rust::String> output;
output.reserve(output_tokens.size());
std::copy(output_tokens.begin(), output_tokens.end(), std::back_inserter(output));
return output;
}
private:
std::unique_ptr<ctranslate2::Translator> translator_;
};
std::unique_ptr<TextInferenceEngine> create_engine(rust::Str model_path) {
return std::make_unique<TextInferenceEngineImpl>(std::string(model_path));
}
} // namespace tabby

crates/ctranslate2-bindings/src/lib.rs Normal file

@@ -0,0 +1,69 @@
use std::sync::Mutex;
use tokenizers::tokenizer::{Model, Tokenizer};
#[macro_use]
extern crate derive_builder;
#[cxx::bridge(namespace = "tabby")]
mod ffi {
unsafe extern "C++" {
include!("ctranslate2-bindings/include/ctranslate2.h");
type TextInferenceEngine;
fn create_engine(model_path: &str) -> UniquePtr<TextInferenceEngine>;
fn inference(
&self,
tokens: &[String],
max_decoding_length: usize,
sampling_temperature: f32,
beam_size: usize,
) -> Vec<String>;
}
}
#[derive(Builder, Debug)]
pub struct TextInferenceOptions {
#[builder(default = "256")]
max_decoding_length: usize,
#[builder(default = "1.0")]
sampling_temperature: f32,
#[builder(default = "2")]
beam_size: usize,
}
pub struct TextInferenceEngine {
engine: Mutex<cxx::UniquePtr<ffi::TextInferenceEngine>>,
tokenizer: Tokenizer,
}
unsafe impl Send for TextInferenceEngine {}
unsafe impl Sync for TextInferenceEngine {}
impl TextInferenceEngine {
pub fn create(model_path: &str, tokenizer_path: &str) -> Self {
return TextInferenceEngine {
engine: Mutex::new(ffi::create_engine(model_path)),
tokenizer: Tokenizer::from_file(tokenizer_path).unwrap(),
};
}
pub fn inference(&self, prompt: &str, options: TextInferenceOptions) -> String {
let encoding = self.tokenizer.encode(prompt, true).unwrap();
let output_tokens = self.engine.lock().unwrap().inference(
encoding.get_tokens(),
options.max_decoding_length,
options.sampling_temperature,
options.beam_size,
);
let model = self.tokenizer.get_model();
let output_ids: Vec<u32> = output_tokens
.iter()
.map(|x| model.token_to_id(x).unwrap())
.collect();
self.tokenizer.decode(output_ids, true).unwrap()
}
}
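
The following is a minimal usage sketch of the Rust API exposed above; it is not part of this commit. The model/tokenizer paths and the prompt are placeholders, assuming a converted CTranslate2 model directory and a tokenizer.json are available.

// Hypothetical downstream usage of ctranslate2-bindings.
use ctranslate2_bindings::{TextInferenceEngine, TextInferenceOptionsBuilder};

fn main() {
    // Placeholder paths to a converted CTranslate2 model and its tokenizer.
    let engine = TextInferenceEngine::create(
        "./models/example/cpu",
        "./models/example/tokenizer.json",
    );
    // Builder defaults come from TextInferenceOptions (256 / 1.0 / 2); override as needed.
    let options = TextInferenceOptionsBuilder::default()
        .max_decoding_length(64)
        .sampling_temperature(0.2)
        .build()
        .unwrap();
    let completion = engine.inference("def fib(n):", options);
    println!("{}", completion);
}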

crates/tabby/.gitignore vendored Normal file

@@ -0,0 +1 @@
/target

crates/tabby/Cargo.lock generated Normal file

File diff suppressed because it is too large.

crates/tabby/Cargo.toml Normal file

@@ -0,0 +1,29 @@
[package]
name = "tabby"
version = "0.1.0"
edition = "2021"
[dependencies]
axum = "0.6"
hyper = { version = "0.14", features = ["full"] }
tokio = { version = "1.17", features = ["full"] }
tower = "0.4"
utoipa = { version = "3.3", features = ["axum_extras", "preserve_order"] }
utoipa-swagger-ui = { version = "3.1", features = ["axum"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
env_logger = "0.10.0"
log = "0.4"
ctranslate2-bindings = { path = "../ctranslate2-bindings" }
tower-http = { version = "0.4.0", features = ["cors"] }
clap = { version = "4.3.0", features = ["derive"] }
regex = "1.8.3"
lazy_static = "1.4.0"
[dependencies.uuid]
version = "1.3.3"
features = [
"v4", # Lets you generate random UUIDs
"fast-rng", # Use a faster (but still sufficiently random) RNG
"macro-diagnostics", # Enable better diagnostics for compile-time UUIDs
]

crates/tabby/src/main.rs Normal file

@@ -0,0 +1,36 @@
use clap::{Parser, Subcommand};
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
#[command(propagate_version = true)]
struct Cli {
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand)]
pub enum Commands {
/// Serve the model
Serve {
/// path to model for serving
#[clap(long)]
model: String,
},
}
mod serve;
#[tokio::main]
async fn main() {
let cli = Cli::parse();
// You can check for the existence of subcommands, and if found use their
// matches just as you would the top level cmd
match &cli.command {
Commands::Serve { model } => {
serve::main(model)
.await
.expect("Error happens during the serve");
}
}
}

crates/tabby/src/serve/completions.rs Normal file

@@ -0,0 +1,81 @@
use axum::{extract::State, Json};
use ctranslate2_bindings::{TextInferenceEngine, TextInferenceOptionsBuilder};
use serde::{Deserialize, Serialize};
use std::{path::Path, sync::Arc};
use utoipa::ToSchema;
mod languages;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct CompletionRequest {
/// https://code.visualstudio.com/docs/languages/identifiers
#[schema(example = "python")]
language: String,
#[schema(example = "def fib(n):")]
prompt: String,
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct Choice {
index: u32,
text: String,
}
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct CompletionResponse {
id: String,
created: u64,
choices: Vec<Choice>,
}
#[utoipa::path(
post,
path = "/v1/completions",
request_body = CompletionRequest ,
)]
pub async fn completion(
State(state): State<Arc<CompletionState>>,
Json(request): Json<CompletionRequest>,
) -> Json<CompletionResponse> {
let options = TextInferenceOptionsBuilder::default()
.max_decoding_length(64)
.sampling_temperature(0.2)
.build()
.unwrap();
let text = state.engine.inference(&request.prompt, options);
let filtered_text = languages::remove_stop_words(&request.language, &text);
Json(CompletionResponse {
id: format!("cmpl-{}", uuid::Uuid::new_v4()),
created: timestamp(),
choices: [Choice {
index: 0,
text: filtered_text.to_string(),
}]
.to_vec(),
})
}
pub struct CompletionState {
engine: TextInferenceEngine,
}
impl CompletionState {
pub fn new(model: &str) -> Self {
let engine = TextInferenceEngine::create(
Path::new(model).join("cpu").to_str().unwrap(),
Path::new(model).join("tokenizer.json").to_str().unwrap(),
);
return Self { engine: engine };
}
}
fn timestamp() -> u64 {
use std::time::{SystemTime, UNIX_EPOCH};
let start = SystemTime::now();
start
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs()
}

crates/tabby/src/serve/completions/languages.rs Normal file

@@ -0,0 +1,25 @@
use lazy_static::lazy_static;
use regex::Regex;
use std::collections::HashMap;
lazy_static! {
static ref DEFAULT: Regex = Regex::new(r"(?m)^\n\n").unwrap();
static ref LANGUAGES: HashMap<&'static str, Regex> = {
let mut map = HashMap::new();
map.insert(
"python",
Regex::new(r"(?m)^(\n\n|def|#|from|class)").unwrap(),
);
map
};
}
pub fn remove_stop_words<'a>(language: &'a str, text: &'a str) -> &'a str {
let re = LANGUAGES.get(language).unwrap_or(&DEFAULT);
let position = re.find_iter(&text).next();
if let Some(m) = position {
&text[..m.start()]
} else {
&text
}
}
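
As an illustration only (not part of the commit), a test sketch like the one below captures the intended behavior of remove_stop_words: for Python the completion is cut at the first line that starts a new top-level construct, and unknown languages fall back to the DEFAULT pattern.

#[cfg(test)]
mod tests {
    use super::remove_stop_words;

    #[test]
    fn truncates_python_at_next_top_level_item() {
        let text = "    return fib(n - 1) + fib(n - 2)\ndef main():\n    pass";
        // Everything from the line beginning with `def` onwards is dropped.
        assert_eq!(
            remove_stop_words("python", text),
            "    return fib(n - 1) + fib(n - 2)\n"
        );
    }

    #[test]
    fn unknown_language_falls_back_to_default() {
        // DEFAULT only matches at a blank line; no match here, so the text passes through unchanged.
        let text = "fn add(a: i32, b: i32) -> i32 { a + b }\n";
        assert_eq!(remove_stop_words("rust", text), text);
    }
}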

crates/tabby/src/serve/events.rs Normal file

@@ -0,0 +1,22 @@
use axum::Json;
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct LogEventRequest {
#[serde(rename = "type")]
event_type: String,
completion_id: String,
choice_index: u32,
}
#[utoipa::path(
post,
path = "/v1/events",
request_body = LogEventRequest,
)]
pub async fn log_event(Json(request): Json<LogEventRequest>) -> StatusCode {
println!("log_event: {:?}", request);
StatusCode::OK
}

crates/tabby/src/serve/mod.rs Normal file

@@ -0,0 +1,44 @@
use std::{
net::{Ipv4Addr, SocketAddr},
sync::Arc,
};
use axum::{response::Redirect, routing, Router, Server};
use hyper::Error;
use tower_http::cors::CorsLayer;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
mod completions;
mod events;
#[derive(OpenApi)]
#[openapi(
paths(events::log_event, completions::completion,),
components(schemas(
events::LogEventRequest,
completions::CompletionRequest,
completions::CompletionResponse,
completions::Choice
))
)]
struct ApiDoc;
pub async fn main(model: &str) -> Result<(), Error> {
let completions_state = Arc::new(completions::CompletionState::new(model));
let app = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
.route("/v1/events", routing::post(events::log_event))
.route("/v1/completions", routing::post(completions::completion))
.with_state(completions_state)
.route(
"/",
routing::get(|| async { Redirect::temporary("/swagger-ui") }),
)
.layer(CorsLayer::permissive());
let address = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 8080));
println!("Listening at {}", address);
Server::bind(&address).serve(app.into_make_service()).await
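
For a quick smoke test of the server wired up above (not part of this commit), a client along the following lines could POST to the completions route on port 8080; the reqwest (with its json feature), tokio, and serde_json dependencies and the example prompt are assumptions made purely for illustration.

// Hypothetical client for the /v1/completions endpoint started by `tabby serve`.
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let request = serde_json::json!({
        "language": "python",
        "prompt": "def fib(n):"
    });
    let response = reqwest::Client::new()
        .post("http://127.0.0.1:8080/v1/completions")
        .json(&request)
        .send()
        .await?
        .text()
        .await?;
    // Expected shape: {"id": "cmpl-...", "created": ..., "choices": [{"index": 0, "text": "..."}]}
    println!("{}", response);
    Ok(())
}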
}