refactor: support http bindings in config, remove ExperimentalHttp (#2153)

* refactor: support http bindings in config, remove ExperimentalHttp

* update naming
Authored by Meng Zhang on 2024-05-16 17:19:39 -07:00; committed by GitHub.
parent 3280618d7b
commit 19f3b9eb74
19 changed files with 335 additions and 351 deletions

Cargo.lock (generated)

@ -2682,7 +2682,7 @@ dependencies = [
"http-api-bindings",
"omnicopy_to_output",
"reqwest 0.12.4",
"serde_json",
"tabby-common",
"tabby-inference",
"tokio",
"tracing",


@ -3,19 +3,16 @@ mod openai_chat;
use std::sync::Arc;
use openai_chat::OpenAIChatEngine;
use tabby_common::config::HttpModelConfig;
use tabby_inference::ChatCompletionStream;
use crate::{get_optional_param, get_param};
pub fn create(model: &str) -> Arc<dyn ChatCompletionStream> {
let params = serde_json::from_str(model).expect("Failed to parse model string");
let kind = get_param(&params, "kind");
if kind == "openai-chat" {
let model_name = get_optional_param(&params, "model_name").unwrap_or_default();
let api_endpoint = get_param(&params, "api_endpoint");
let api_key = get_optional_param(&params, "api_key");
let engine = OpenAIChatEngine::create(&api_endpoint, &model_name, api_key);
pub fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
if model.kind == "openai-chat" {
let engine = OpenAIChatEngine::create(
&model.api_endpoint,
model.model_name.as_deref().unwrap_or_default(),
model.api_key.clone(),
);
Arc::new(engine)
} else {
panic!("Only openai-chat are supported for http chat");


@ -1,33 +1,16 @@
mod llama;
mod openai;
use std::sync::Arc;
use llama::LlamaCppEngine;
use openai::OpenAIEngine;
use tabby_common::config::HttpModelConfig;
use tabby_inference::CompletionStream;
use crate::{get_optional_param, get_param};
pub fn create(model: &str) -> (Arc<dyn CompletionStream>, Option<String>, Option<String>) {
let params = serde_json::from_str(model).expect("Failed to parse model string");
let kind = get_param(&params, "kind");
if kind == "openai" {
let model_name = get_optional_param(&params, "model_name").unwrap_or_default();
let api_endpoint = get_param(&params, "api_endpoint");
let api_key = get_optional_param(&params, "api_key");
let prompt_template = get_optional_param(&params, "prompt_template");
let chat_template = get_optional_param(&params, "chat_template");
let engine = OpenAIEngine::create(&api_endpoint, &model_name, api_key);
(Arc::new(engine), prompt_template, chat_template)
} else if kind == "llama" {
let api_endpoint = get_param(&params, "api_endpoint");
let api_key = get_optional_param(&params, "api_key");
let prompt_template = get_optional_param(&params, "prompt_template");
let chat_template = get_optional_param(&params, "chat_template");
let engine = LlamaCppEngine::create(&api_endpoint, api_key);
(Arc::new(engine), prompt_template, chat_template)
pub fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
if model.kind == "llama.cpp/completion" {
let engine = LlamaCppEngine::create(&model.api_endpoint, model.api_key.clone());
Arc::new(engine)
} else {
panic!("Only openai are supported for http completion");
panic!("Unsupported model kind: {}", model.kind);
}
}
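A matching sketch for the completion side. Note that `prompt_template`/`chat_template` are no longer returned by `create`; they now travel with `HttpModelConfig` and are read by the model loader hunk further down. The endpoint and FIM template here are assumptions.

```rust
use std::sync::Arc;

use tabby_common::config::HttpModelConfigBuilder;
use tabby_inference::CompletionStream;

fn llama_cpp_completion_engine() -> Arc<dyn CompletionStream> {
    // Assumed: a llama.cpp server on localhost; the FIM template is illustrative only.
    let config = HttpModelConfigBuilder::default()
        .kind("llama.cpp/completion".to_string())
        .api_endpoint("http://localhost:8080".to_string())
        .prompt_template(Some("<PRE>{prefix}<SUF>{suffix}<MID>".to_string()))
        .build()
        .expect("valid http model config");
    http_api_bindings::create(&config)
}
```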


@ -1,72 +0,0 @@
use async_openai::{config::OpenAIConfig, error::OpenAIError, types::CreateCompletionRequestArgs};
use async_stream::stream;
use async_trait::async_trait;
use futures::stream::BoxStream;
use tabby_inference::{CompletionOptions, CompletionStream};
use tracing::warn;
pub struct OpenAIEngine {
client: async_openai::Client<OpenAIConfig>,
model_name: String,
}
impl OpenAIEngine {
pub fn create(api_endpoint: &str, model_name: &str, api_key: Option<String>) -> Self {
let config = OpenAIConfig::default()
.with_api_base(api_endpoint)
.with_api_key(api_key.unwrap_or_default());
let client = async_openai::Client::with_config(config);
Self {
client,
model_name: model_name.to_owned(),
}
}
}
#[async_trait]
impl CompletionStream for OpenAIEngine {
async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
let request = CreateCompletionRequestArgs::default()
.model(&self.model_name)
.temperature(options.sampling_temperature)
.max_tokens(options.max_decoding_tokens as u16)
.stream(true)
.prompt(prompt)
.build();
let s = stream! {
let request = match request {
Ok(x) => x,
Err(e) => {
warn!("Failed to build completion request {:?}", e);
return;
}
};
let s = match self.client.completions().create_stream(request).await {
Ok(x) => x,
Err(e) => {
warn!("Failed to create completion request {:?}", e);
return;
}
};
for await x in s {
match x {
Ok(x) => {
yield x.choices[0].text.clone();
},
Err(OpenAIError::StreamError(_)) => break,
Err(e) => {
warn!("Failed to stream response: {}", e);
break;
}
};
}
};
Box::pin(s)
}
}


@ -3,17 +3,12 @@ mod llama;
use std::sync::Arc;
use llama::LlamaCppEngine;
use tabby_common::config::HttpModelConfig;
use tabby_inference::Embedding;
use crate::{get_optional_param, get_param};
pub fn create(model: &str) -> Arc<dyn Embedding> {
let params = serde_json::from_str(model).expect("Failed to parse model string");
let kind = get_param(&params, "kind");
if kind == "llama" {
let api_endpoint = get_param(&params, "api_endpoint");
let api_key = get_optional_param(&params, "api_key");
let engine = LlamaCppEngine::create(&api_endpoint, api_key);
pub fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
if config.kind == "llama.cpp/embedding" {
let engine = LlamaCppEngine::create(&config.api_endpoint, config.api_key.clone());
Arc::new(engine)
} else {
panic!("Only llama are supported for http embedding");


@ -5,19 +5,3 @@ mod embedding;
pub use chat::create as create_chat;
pub use completion::create;
pub use embedding::create as create_embedding;
use serde_json::Value;
pub(crate) fn get_param(params: &Value, key: &str) -> String {
params
.get(key)
.unwrap_or_else(|| panic!("Missing {} field", key))
.as_str()
.expect("Type unmatched")
.to_owned()
}
pub(crate) fn get_optional_param(params: &Value, key: &str) -> Option<String> {
params
.get(key)
.map(|x| x.as_str().expect("Type unmatched").to_owned())
}


@ -15,8 +15,8 @@ vulkan = ["binary"]
futures.workspace = true
http-api-bindings = { path = "../http-api-bindings" }
reqwest.workspace = true
serde_json.workspace = true
tabby-inference = { path = "../tabby-inference" }
tabby-common = { path = "../tabby-common" }
tracing.workspace = true
async-trait.workspace = true
tokio = { workspace = true, features = ["process"] }
@ -25,4 +25,4 @@ which = "6"
[build-dependencies]
cmake = "0.1"
omnicopy_to_output = "0.1.1"
omnicopy_to_output = "0.1.1"


@ -5,8 +5,8 @@ use std::sync::Arc;
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde_json::json;
use supervisor::LlamaCppSupervisor;
use tabby_common::config::HttpModelConfigBuilder;
use tabby_inference::{CompletionOptions, CompletionStream, Embedding};
fn api_endpoint(port: u16) -> String {
@ -20,18 +20,19 @@ struct EmbeddingServer {
}
impl EmbeddingServer {
async fn new(use_gpu: bool, model_path: &str, parallelism: u8) -> EmbeddingServer {
let server = LlamaCppSupervisor::new(use_gpu, true, model_path, parallelism);
async fn new(num_gpu_layers: u16, model_path: &str, parallelism: u8) -> EmbeddingServer {
let server = LlamaCppSupervisor::new(num_gpu_layers, true, model_path, parallelism);
server.start().await;
let model_spec: String = serde_json::to_string(&json!({
"kind": "llama",
"api_endpoint": api_endpoint(server.port()),
}))
.expect("Failed to serialize model spec");
let config = HttpModelConfigBuilder::default()
.api_endpoint(api_endpoint(server.port()))
.kind("llama.cpp/embedding".to_string())
.build()
.expect("Failed to create HttpModelConfig");
Self {
server,
embedding: http_api_bindings::create_embedding(&model_spec),
embedding: http_api_bindings::create_embedding(&config),
}
}
}
@ -50,14 +51,14 @@ struct CompletionServer {
}
impl CompletionServer {
async fn new(use_gpu: bool, model_path: &str, parallelism: u8) -> Self {
let server = LlamaCppSupervisor::new(use_gpu, false, model_path, parallelism);
let model_spec: String = serde_json::to_string(&json!({
"kind": "llama",
"api_endpoint": api_endpoint(server.port()),
}))
.expect("Failed to serialize model spec");
let (completion, _, _) = http_api_bindings::create(&model_spec);
async fn new(num_gpu_layers: u16, model_path: &str, parallelism: u8) -> Self {
let server = LlamaCppSupervisor::new(num_gpu_layers, false, model_path, parallelism);
let config = HttpModelConfigBuilder::default()
.api_endpoint(api_endpoint(server.port()))
.kind("llama.cpp/completion".to_string())
.build()
.expect("Failed to create HttpModelConfig");
let completion = http_api_bindings::create(&config);
Self { server, completion }
}
}
@ -70,17 +71,17 @@ impl CompletionStream for CompletionServer {
}
pub async fn create_embedding(
use_gpu: bool,
num_gpu_layers: u16,
model_path: &str,
parallelism: u8,
) -> Arc<dyn Embedding> {
Arc::new(EmbeddingServer::new(use_gpu, model_path, parallelism).await)
Arc::new(EmbeddingServer::new(num_gpu_layers, model_path, parallelism).await)
}
pub async fn create_completion(
use_gpu: bool,
num_gpu_layers: u16,
model_path: &str,
parallelism: u8,
) -> Arc<dyn CompletionStream> {
Arc::new(CompletionServer::new(use_gpu, model_path, parallelism).await)
Arc::new(CompletionServer::new(num_gpu_layers, model_path, parallelism).await)
}
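The `use_gpu: bool` flag (plus the `LLAMA_CPP_N_GPU_LAYERS` environment lookup) is replaced by an explicit `num_gpu_layers: u16`, where 0 means CPU-only. A usage sketch with assumed model paths; 9999 mirrors the offload-everything fallback used elsewhere in this commit.

```rust
use std::sync::Arc;

use tabby_inference::{CompletionStream, Embedding};

async fn spawn_local_engines() -> (Arc<dyn CompletionStream>, Arc<dyn Embedding>) {
    // Model paths are placeholders.
    let completion =
        llama_cpp_server::create_completion(9999, "/models/code/ggml/model.gguf", 1).await;
    let embedding =
        llama_cpp_server::create_embedding(0, "/models/embedding/ggml/model.gguf", 1).await;
    (completion, embedding)
}
```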


@ -13,7 +13,7 @@ pub struct LlamaCppSupervisor {
impl LlamaCppSupervisor {
pub fn new(
use_gpu: bool,
num_gpu_layers: u16,
embedding: bool,
model_path: &str,
parallelism: u8,
@ -56,10 +56,8 @@ impl LlamaCppSupervisor {
command.arg("-t").arg(n_threads);
}
if use_gpu {
let num_gpu_layers =
std::env::var("LLAMA_CPP_N_GPU_LAYERS").unwrap_or("9999".into());
command.arg("-ngl").arg(&num_gpu_layers);
if num_gpu_layers > 0 {
command.arg("-ngl").arg(num_gpu_layers.to_string());
}
if embedding {
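In effect, the supervisor's GPU handling reduces to the following standalone sketch (not code from the commit): zero layers means no `-ngl` flag at all, replacing the old `use_gpu` + `LLAMA_CPP_N_GPU_LAYERS` combination.

```rust
fn gpu_args(num_gpu_layers: u16) -> Vec<String> {
    // 0 now means "CPU only": the -ngl flag is simply omitted.
    if num_gpu_layers > 0 {
        vec!["-ngl".to_string(), num_gpu_layers.to_string()]
    } else {
        Vec::new()
    }
}

#[test]
fn gpu_args_examples() {
    assert_eq!(gpu_args(9999), vec!["-ngl", "9999"]);
    assert!(gpu_args(0).is_empty());
}
```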


@ -2,6 +2,7 @@ use std::{collections::HashSet, path::PathBuf};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use crate::{
@ -9,13 +10,16 @@ use crate::{
terminal::{HeaderFormat, InfoMessage},
};
#[derive(Serialize, Deserialize, Default)]
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct Config {
#[serde(default)]
pub repositories: Vec<RepositoryConfig>,
#[serde(default)]
pub server: ServerConfig,
#[serde(default)]
pub model: ModelConfigGroup,
}
impl Config {
@ -115,7 +119,7 @@ fn sanitize_name(s: &str) -> String {
sanitized.into_iter().collect()
}
#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Clone)]
pub struct ServerConfig {
/// The timeout in seconds for the /v1/completion api.
pub completion_timeout: u64,
@ -129,6 +133,56 @@ impl Default for ServerConfig {
}
}
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct ModelConfigGroup {
pub completion: Option<ModelConfig>,
pub chat: Option<ModelConfig>,
pub embedding: Option<ModelConfig>,
}
#[derive(Serialize, Deserialize, Clone)]
#[serde(rename_all = "snake_case")]
pub enum ModelConfig {
Http(HttpModelConfig),
Local(LocalModelConfig),
}
#[derive(Serialize, Deserialize, Builder, Clone)]
pub struct HttpModelConfig {
pub api_endpoint: String,
pub kind: String,
#[builder(default)]
pub api_key: Option<String>,
/// Used by chat http endpoint to select model.
#[builder(default)]
pub model_name: Option<String>,
/// Used by completion http endpoint to construct FIM prompt.
#[builder(default)]
pub prompt_template: Option<String>,
/// Used by completion http endpoint to construct Chat prompt.
#[builder(default)]
pub chat_template: Option<String>,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct LocalModelConfig {
pub model_id: String,
#[serde(default = "default_parallelism")]
pub parallelism: u8,
#[serde(default)]
pub num_gpu_layers: u16,
}
fn default_parallelism() -> u8 {
1
}
#[async_trait]
pub trait RepositoryAccess: Send + Sync {
async fn list_repositories(&self) -> Result<Vec<RepositoryConfig>>;
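Taken together, the new `ModelConfigGroup` lets HTTP and local backends be declared directly in Tabby's `config.toml`. A sketch of the resulting shape, parsed here with the `toml` crate purely for illustration; the endpoint, API key, templates, and model ids are placeholders.

```rust
use tabby_common::config::{Config, ModelConfig};

fn parse_example_config() -> Config {
    let source = r#"
[model.completion.http]
kind = "llama.cpp/completion"
api_endpoint = "http://localhost:8080"
prompt_template = "<PRE>{prefix}<SUF>{suffix}<MID>"

[model.chat.http]
kind = "openai-chat"
api_endpoint = "https://api.openai.com/v1"
api_key = "sk-placeholder"
model_name = "my-chat-model"

[model.embedding.local]
model_id = "my-embedding-model"
num_gpu_layers = 0
"#;
    let config: Config = toml::from_str(source).expect("valid config");
    assert!(matches!(config.model.completion, Some(ModelConfig::Http(_))));
    assert!(matches!(config.model.embedding, Some(ModelConfig::Local(_))));
    config
}
```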


@ -15,7 +15,7 @@ async fn main() {
let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string();
let mut doc_index = DocIndex::new(create_embedding(false, &model_path, 1).await);
let mut doc_index = DocIndex::new(create_embedding(0, &model_path, 1).await);
let mut cnt = 0;
stream! {
for await doc in crawl_pipeline("https://tabby.tabbyml.com/").await {


@ -11,7 +11,7 @@ mod worker;
use std::os::unix::fs::PermissionsExt;
use clap::{Parser, Subcommand};
use tabby_common::config::{Config, ConfigRepositoryAccess};
use tabby_common::config::{Config, ConfigRepositoryAccess, LocalModelConfig, ModelConfig};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
@ -72,10 +72,6 @@ pub enum Device {
#[strum(serialize = "vulkan")]
Vulkan,
#[strum(serialize = "experimental_http")]
#[clap(hide = true)]
ExperimentalHttp,
}
#[tokio::main]
@ -103,11 +99,16 @@ async fn main() {
}
#[cfg(feature = "ee")]
Commands::WorkerCompletion(ref args) => {
worker::main(tabby_webserver::public::WorkerKind::Completion, args).await
worker::main(
&config,
tabby_webserver::public::WorkerKind::Completion,
args,
)
.await
}
#[cfg(feature = "ee")]
Commands::WorkerChat(ref args) => {
worker::main(tabby_webserver::public::WorkerKind::Chat, args).await
worker::main(&config, tabby_webserver::public::WorkerKind::Chat, args).await
}
}
}
@ -158,3 +159,20 @@ fn init_logging() {
.with(env_filter)
.init();
}
fn to_local_config(model: &str, parallelism: u8, device: &Device) -> ModelConfig {
let num_gpu_layers = if *device != Device::Cpu {
std::env::var("LLAMA_CPP_N_GPU_LAYERS")
.map(|s| s.parse::<u16>().ok())
.ok()
.flatten()
.unwrap_or(9999)
} else {
9999
};
ModelConfig::Local(LocalModelConfig {
model_id: model.to_owned(),
parallelism,
num_gpu_layers,
})
}
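Together with the two `merge_args` helpers added further down, this means a `--model`/`--device` pair on the command line is translated into the same `ModelConfig::Local` value a config file entry would produce. A hedged sketch of that equivalence; the model id and the default parallelism of 1 are placeholders/assumptions.

```rust
use tabby_common::config::{LocalModelConfig, ModelConfig};

fn cli_equivalent_completion_config() -> ModelConfig {
    // Roughly what `--model <id> --device cuda` becomes when LLAMA_CPP_N_GPU_LAYERS is unset.
    ModelConfig::Local(LocalModelConfig {
        model_id: "TabbyML/StarCoder-1B".to_owned(), // placeholder model id
        parallelism: 1,                              // assumed CLI default
        num_gpu_layers: 9999,                        // fallback used by to_local_config
    })
}
```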


@ -4,9 +4,8 @@ use axum::{routing, Router};
use clap::Args;
use hyper::StatusCode;
use tabby_common::{
api,
api::{code::CodeSearch, event::EventLogger},
config::{Config, ConfigRepositoryAccess, RepositoryAccess},
api::{self, code::CodeSearch, event::EventLogger},
config::{Config, ConfigRepositoryAccess, ModelConfig, RepositoryAccess},
usage,
};
use tokio::time::sleep;
@ -30,7 +29,7 @@ use crate::{
health,
model::download_model_if_needed,
},
Device,
to_local_config, Device,
};
#[derive(OpenApi)]
@ -92,9 +91,6 @@ pub struct ServeArgs {
#[clap(long)]
chat_model: Option<String>,
#[clap(long, hide = true)]
embedding_model: Option<String>,
#[clap(long, default_value = "0.0.0.0")]
host: IpAddr,
@ -125,7 +121,9 @@ pub struct ServeArgs {
}
pub async fn main(config: &Config, args: &ServeArgs) {
load_model(args).await;
let config = merge_args(config, args);
load_model(&config).await;
debug!("Starting server, this might take a few minutes...");
@ -160,7 +158,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
}
let code = Arc::new(create_code_search(repository_access));
let mut api = api_router(args, config, logger.clone(), code.clone(), webserver).await;
let mut api = api_router(args, &config, logger.clone(), code.clone(), webserver).await;
let mut ui = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
.fallback(|| async { axum::response::Redirect::temporary("/swagger-ui") });
@ -172,26 +170,21 @@ pub async fn main(config: &Config, args: &ServeArgs) {
ui = new_ui;
};
start_heartbeat(args, webserver);
start_heartbeat(args, &config, webserver);
run_app(api, Some(ui), args.host, args.port).await
}
async fn load_model(args: &ServeArgs) {
if args.device != Device::ExperimentalHttp {
if let Some(model) = &args.model {
download_model_if_needed(model).await;
}
async fn load_model(config: &Config) {
if let Some(ModelConfig::Local(ref model)) = config.model.completion {
download_model_if_needed(&model.model_id).await;
}
let chat_device = args.chat_device.as_ref().unwrap_or(&args.device);
if chat_device != &Device::ExperimentalHttp {
if let Some(chat_model) = &args.chat_model {
download_model_if_needed(chat_model).await
}
if let Some(ModelConfig::Local(ref model)) = config.model.chat {
download_model_if_needed(&model.model_id).await;
}
if let Some(embedding_model) = &args.embedding_model {
download_model_if_needed(embedding_model).await
if let Some(ModelConfig::Local(ref model)) = config.model.embedding {
download_model_if_needed(&model.model_id).await;
}
}
@ -202,37 +195,23 @@ async fn api_router(
code: Arc<dyn CodeSearch>,
webserver: Option<bool>,
) -> Router {
let completion_state = if let Some(model) = &args.model {
let model = &config.model;
let completion_state = if let Some(completion) = &model.completion {
Some(Arc::new(
create_completion_service(
code.clone(),
logger.clone(),
model,
&args.device,
args.parallelism,
)
.await,
create_completion_service(code.clone(), logger.clone(), completion).await,
))
} else {
None
};
let chat_state = if let Some(chat_model) = &args.chat_model {
Some(Arc::new(
create_chat_service(
logger.clone(),
chat_model,
args.chat_device.as_ref().unwrap_or(&args.device),
args.parallelism,
)
.await,
))
let chat_state = if let Some(chat) = &model.chat {
Some(Arc::new(create_chat_service(logger.clone(), chat).await))
} else {
None
};
let docsearch_state = if let Some(embedding_model) = &args.embedding_model {
let embedding = embedding::create(embedding_model, &args.device).await;
let docsearch_state = if let Some(embedding) = &model.embedding {
let embedding = embedding::create(embedding).await;
Some(Arc::new(services::doc::create(embedding)))
} else {
None
@ -249,9 +228,8 @@ async fn api_router(
let mut routers = vec![];
let health_state = Arc::new(health::HealthState::new(
args.model.as_deref(),
model,
&args.device,
args.chat_model.as_deref(),
args.chat_model
.as_deref()
.map(|_| args.chat_device.as_ref().unwrap_or(&args.device)),
@ -379,16 +357,15 @@ async fn api_router(
root
}
fn start_heartbeat(args: &ServeArgs, webserver: Option<bool>) {
let state = health::HealthState::new(
args.model.as_deref(),
fn start_heartbeat(args: &ServeArgs, config: &Config, webserver: Option<bool>) {
let state = Arc::new(health::HealthState::new(
&config.model,
&args.device,
args.chat_model.as_deref(),
args.chat_model
.as_deref()
.map(|_| args.chat_device.as_ref().unwrap_or(&args.device)),
webserver,
);
));
tokio::spawn(async move {
loop {
usage::capture("ServeHealth", &state).await;
@ -414,3 +391,20 @@ impl Modify for SecurityAddon {
}
}
}
fn merge_args(config: &Config, args: &ServeArgs) -> Config {
let mut config = (*config).clone();
if let Some(model) = &args.model {
config.model.completion = Some(to_local_config(model, args.parallelism, &args.device));
};
if let Some(chat_model) = &args.chat_model {
config.model.chat = Some(to_local_config(
chat_model,
args.parallelism,
args.chat_device.as_ref().unwrap_or(&args.device),
));
}
config
}


@ -4,9 +4,12 @@ use async_stream::stream;
use derive_builder::Builder;
use futures::stream::BoxStream;
use serde::{Deserialize, Serialize};
use tabby_common::api::{
chat::Message,
event::{Event, EventLogger},
use tabby_common::{
api::{
chat::Message,
event::{Event, EventLogger},
},
config::ModelConfig,
};
use tabby_inference::{ChatCompletionOptionsBuilder, ChatCompletionStream};
use tracing::warn;
@ -14,7 +17,6 @@ use utoipa::ToSchema;
use uuid::Uuid;
use super::model;
use crate::Device;
#[derive(Serialize, Deserialize, ToSchema, Clone, Builder, Debug)]
#[schema(example=json!({
@ -156,13 +158,8 @@ fn convert_messages(input: &[Message]) -> Vec<tabby_common::api::event::Message>
.collect()
}
pub async fn create_chat_service(
logger: Arc<dyn EventLogger>,
model: &str,
device: &Device,
parallelism: u8,
) -> ChatService {
let engine = model::load_chat_completion(model, device, parallelism).await;
pub async fn create_chat_service(logger: Arc<dyn EventLogger>, chat: &ModelConfig) -> ChatService {
let engine = model::load_chat_completion(chat).await;
ChatService::new(engine, logger)
}


@ -4,11 +4,12 @@ use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tabby_common::{
api,
api::{
self,
code::CodeSearch,
event::{Event, EventLogger},
},
config::ModelConfig,
languages::get_language,
};
use tabby_inference::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder};
@ -16,7 +17,6 @@ use thiserror::Error;
use utoipa::ToSchema;
use super::model;
use crate::Device;
#[derive(Error, Debug)]
pub enum CompletionError {
@ -342,16 +342,14 @@ impl CompletionService {
pub async fn create_completion_service(
code: Arc<dyn CodeSearch>,
logger: Arc<dyn EventLogger>,
model: &str,
device: &Device,
parallelism: u8,
model: &ModelConfig,
) -> CompletionService {
let (
engine,
model::PromptInfo {
prompt_template, ..
},
) = model::load_code_generation(model, device, parallelism).await;
) = model::load_code_generation(model).await;
CompletionService::new(engine.clone(), code, logger, prompt_template)
}


@ -1,10 +1,10 @@
use std::sync::Arc;
use tabby_common::config::ModelConfig;
use tabby_inference::Embedding;
use super::model;
use crate::Device;
pub async fn create(model: &str, device: &Device) -> Arc<dyn Embedding> {
model::load_embedding(model, device).await
pub async fn create(config: &ModelConfig) -> Arc<dyn Embedding> {
model::load_embedding(config).await
}


@ -4,6 +4,7 @@ use anyhow::Result;
use nvml_wrapper::Nvml;
use serde::{Deserialize, Serialize};
use sysinfo::{CpuExt, System, SystemExt};
use tabby_common::config::{ModelConfig, ModelConfigGroup};
use utoipa::ToSchema;
use crate::Device;
@ -27,9 +28,8 @@ pub struct HealthState {
impl HealthState {
pub fn new(
model: Option<&str>,
model_config: &ModelConfigGroup,
device: &Device,
chat_model: Option<&str>,
chat_device: Option<&Device>,
webserver: Option<bool>,
) -> Self {
@ -40,24 +40,9 @@ impl HealthState {
Err(_) => vec![],
};
let http_model_name = Some("Remote");
let is_model_http = device == &Device::ExperimentalHttp;
let model = if is_model_http {
http_model_name
} else {
model
};
let is_chat_model_http = chat_device == Some(&Device::ExperimentalHttp);
let chat_model = if is_chat_model_http {
http_model_name
} else {
chat_model
};
Self {
model: model.map(|x| x.to_string()),
chat_model: chat_model.map(|x| x.to_owned()),
model: to_model_name(&model_config.completion),
chat_model: to_model_name(&model_config.chat),
chat_device: chat_device.map(|x| x.to_string()),
device: device.to_string(),
arch: ARCH.to_string(),
@ -70,6 +55,17 @@ impl HealthState {
}
}
fn to_model_name(model: &Option<ModelConfig>) -> Option<String> {
if let Some(model) = model {
match model {
ModelConfig::Http(_http) => Some("Remote".to_owned()),
ModelConfig::Local(llama) => Some(llama.model_id.clone()),
}
} else {
None
}
}
pub fn read_cpu_info() -> (String, usize) {
let mut system = System::new_all();
system.refresh_cpu();
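The `to_model_name` helper in this hunk could equally be written with `Option::map`; an equivalent sketch, not part of the commit:

```rust
use tabby_common::config::ModelConfig;

fn to_model_name(model: &Option<ModelConfig>) -> Option<String> {
    model.as_ref().map(|config| match config {
        ModelConfig::Http(_) => "Remote".to_owned(),
        ModelConfig::Local(local) => local.model_id.clone(),
    })
}
```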


@ -3,98 +3,101 @@ mod chat;
use std::{fs, path::PathBuf, sync::Arc};
use serde::Deserialize;
use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
use tabby_common::{
config::ModelConfig,
registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH},
};
use tabby_download::download_model;
use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream, Embedding};
use tracing::info;
use crate::{fatal, Device};
use crate::fatal;
pub async fn load_chat_completion(
model_id: &str,
device: &Device,
parallelism: u8,
) -> Arc<dyn ChatCompletionStream> {
if device == &Device::ExperimentalHttp {
return http_api_bindings::create_chat(model_id);
}
pub async fn load_chat_completion(chat: &ModelConfig) -> Arc<dyn ChatCompletionStream> {
match chat {
ModelConfig::Http(http) => http_api_bindings::create_chat(http),
let (engine, PromptInfo { chat_template, .. }) =
load_completion(model_id, device, parallelism).await;
ModelConfig::Local(_) => {
let (engine, PromptInfo { chat_template, .. }) = load_completion(chat).await;
let Some(chat_template) = chat_template else {
fatal!("Chat model requires specifying prompt template");
};
let Some(chat_template) = chat_template else {
fatal!("Chat model requires specifying prompt template");
};
Arc::new(chat::make_chat_completion(engine, chat_template))
}
pub async fn load_embedding(model_id: &str, device: &Device) -> Arc<dyn Embedding> {
if device == &Device::ExperimentalHttp {
return http_api_bindings::create_embedding(model_id);
}
if fs::metadata(model_id).is_ok() {
let path = PathBuf::from(model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
create_ggml_embedding_engine(model_path.display().to_string().as_str()).await
} else {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string();
create_ggml_embedding_engine(&model_path).await
Arc::new(chat::make_chat_completion(engine, chat_template))
}
}
}
pub async fn load_code_generation(
model_id: &str,
device: &Device,
parallelism: u8,
) -> (Arc<CodeGeneration>, PromptInfo) {
let (engine, prompt_info) = load_completion(model_id, device, parallelism).await;
pub async fn load_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
match config {
ModelConfig::Http(http) => http_api_bindings::create_embedding(http),
ModelConfig::Local(llama) => {
if fs::metadata(&llama.model_id).is_ok() {
let path = PathBuf::from(&llama.model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
create_ggml_embedding_engine(
model_path.display().to_string().as_str(),
llama.parallelism,
llama.num_gpu_layers,
)
.await
} else {
let (registry, name) = parse_model_id(&llama.model_id);
let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string();
create_ggml_embedding_engine(&model_path, llama.parallelism, llama.num_gpu_layers)
.await
}
}
}
}
pub async fn load_code_generation(model: &ModelConfig) -> (Arc<CodeGeneration>, PromptInfo) {
let (engine, prompt_info) = load_completion(model).await;
(Arc::new(CodeGeneration::new(engine)), prompt_info)
}
async fn load_completion(
model_id: &str,
device: &Device,
parallelism: u8,
) -> (Arc<dyn CompletionStream>, PromptInfo) {
if device == &Device::ExperimentalHttp {
let (engine, prompt_template, chat_template) = http_api_bindings::create(model_id);
return (
engine,
PromptInfo {
prompt_template,
chat_template,
},
);
}
if fs::metadata(model_id).is_ok() {
let path = PathBuf::from(model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
let engine = create_ggml_engine(
device,
model_path.display().to_string().as_str(),
parallelism,
)
.await;
let engine_info = PromptInfo::read(path.join("tabby.json"));
(engine, engine_info)
} else {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string();
let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(device, &model_path, parallelism).await;
(
engine,
PromptInfo {
prompt_template: model_info.prompt_template.clone(),
chat_template: model_info.chat_template.clone(),
},
)
async fn load_completion(model: &ModelConfig) -> (Arc<dyn CompletionStream>, PromptInfo) {
match model {
ModelConfig::Http(http) => {
let engine = http_api_bindings::create(http);
(
engine,
PromptInfo {
prompt_template: http.prompt_template.clone(),
chat_template: http.chat_template.clone(),
},
)
}
ModelConfig::Local(llama) => {
if fs::metadata(&llama.model_id).is_ok() {
let path = PathBuf::from(&llama.model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
let engine = create_ggml_engine(
llama.num_gpu_layers,
model_path.display().to_string().as_str(),
llama.parallelism,
)
.await;
let engine_info = PromptInfo::read(path.join("tabby.json"));
(engine, engine_info)
} else {
let (registry, name) = parse_model_id(&llama.model_id);
let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string();
let model_info = registry.get_model_info(name);
let engine =
create_ggml_engine(llama.num_gpu_layers, &model_path, llama.parallelism).await;
(
engine,
PromptInfo {
prompt_template: model_info.prompt_template.clone(),
chat_template: model_info.chat_template.clone(),
},
)
}
}
}
}
@ -112,16 +115,20 @@ impl PromptInfo {
}
async fn create_ggml_engine(
device: &Device,
num_gpu_layers: u16,
model_path: &str,
parallelism: u8,
) -> Arc<dyn CompletionStream> {
llama_cpp_server::create_completion(device != &Device::Cpu, model_path, parallelism).await
llama_cpp_server::create_completion(num_gpu_layers, model_path, parallelism).await
}
async fn create_ggml_embedding_engine(model_path: &str) -> Arc<dyn Embedding> {
async fn create_ggml_embedding_engine(
model_path: &str,
parallelism: u8,
num_gpu_layers: u16,
) -> Arc<dyn Embedding> {
// By default, embedding always use CPU device with 1 parallelism.
llama_cpp_server::create_embedding(false, model_path, 1).await
llama_cpp_server::create_embedding(num_gpu_layers, model_path, parallelism).await
}
pub async fn download_model_if_needed(model: &str) {


@ -2,7 +2,10 @@ use std::{env::consts::ARCH, net::IpAddr, sync::Arc};
use axum::{routing, Router};
use clap::Args;
use tabby_common::api::{code::CodeSearch, event::EventLogger};
use tabby_common::{
api::{code::CodeSearch, event::EventLogger},
config::Config,
};
use tabby_webserver::public::{RegisterWorkerRequest, WorkerClient, WorkerKind};
use tracing::info;
@ -14,7 +17,7 @@ use crate::{
health::{read_cpu_info, read_cuda_devices},
model::download_model_if_needed,
},
Device,
to_local_config, Device,
};
#[derive(Args)]
@ -47,9 +50,18 @@ pub struct WorkerArgs {
parallelism: u8,
}
async fn make_chat_route(logger: Arc<dyn EventLogger>, args: &WorkerArgs) -> Router {
let chat_state =
Arc::new(create_chat_service(logger, &args.model, &args.device, args.parallelism).await);
async fn make_chat_route(logger: Arc<dyn EventLogger>, config: &Config) -> Router {
let chat_state = Arc::new(
create_chat_service(
logger,
config
.model
.chat
.as_ref()
.expect("Chat model config is missing"),
)
.await,
);
Router::new()
.route(
@ -65,10 +77,19 @@ async fn make_chat_route(logger: Arc<dyn EventLogger>, args: &WorkerArgs) -> Rou
async fn make_completion_route(
code: Arc<dyn CodeSearch>,
logger: Arc<dyn EventLogger>,
args: &WorkerArgs,
config: &Config,
) -> Router {
let completion_state = Arc::new(
create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await,
create_completion_service(
code,
logger,
config
.model
.completion
.as_ref()
.expect("Completion model config is missing"),
)
.await,
);
Router::new().route(
@ -77,8 +98,9 @@ async fn make_completion_route(
)
}
pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
pub async fn main(config: &Config, kind: WorkerKind, args: &WorkerArgs) {
download_model_if_needed(&args.model).await;
let config = merge_args(config, args, &kind);
info!("Starting worker, this might take a few minutes...");
@ -87,8 +109,8 @@ pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
let logger = code.clone();
let app = match kind {
WorkerKind::Completion => make_completion_route(code, logger, args).await,
WorkerKind::Chat => make_chat_route(logger.clone(), args).await,
WorkerKind::Completion => make_completion_route(code, logger, &config).await,
WorkerKind::Chat => make_chat_route(logger.clone(), &config).await,
};
run_app(app, None, args.host, args.port).await
@ -122,3 +144,15 @@ impl WorkerContext {
}
}
}
fn merge_args(config: &Config, args: &WorkerArgs, kind: &WorkerKind) -> Config {
let mut config = (*config).clone();
let override_config = Some(to_local_config(&args.model, args.parallelism, &args.device));
match kind {
WorkerKind::Chat => config.model.chat = override_config,
WorkerKind::Completion => config.model.completion = override_config,
}
config
}