Mirror of https://github.com/TabbyML/tabby, synced 2024-11-22 00:08:06 +00:00
refactor: support http bindings in config, remove ExperimentalHttp (#2153)
* refactor: support http bindings in config, remove ExperimentalHttp
* update naming
This commit is contained in: parent 3280618d7b, commit 19f3b9eb74
Cargo.lock (generated): 2 changed lines
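With this change, HTTP-backed models are described by the typed HttpModelConfig / ModelConfig types in tabby-common instead of an ad-hoc JSON "model string" passed together with the hidden --device experimental_http flag. A minimal sketch of the new wiring, assuming only the types and builder methods visible in the diff below (the endpoint and kind values are placeholders):

    use tabby_common::config::{HttpModelConfigBuilder, ModelConfig, ModelConfigGroup};

    fn example_model_group() -> ModelConfigGroup {
        // Chat served by an OpenAI-compatible HTTP endpoint; completion and
        // embedding left unset so the rest of the config decides them.
        let chat = HttpModelConfigBuilder::default()
            .kind("openai-chat".to_string())
            .api_endpoint("http://localhost:8080/v1".to_string())
            .build()
            .expect("valid http model config");

        ModelConfigGroup {
            chat: Some(ModelConfig::Http(chat)),
            ..Default::default()
        }
    }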
Cargo.lock (generated):

@@ -2682,7 +2682,7 @@ dependencies = [
  "http-api-bindings",
  "omnicopy_to_output",
  "reqwest 0.12.4",
- "serde_json",
+ "tabby-common",
  "tabby-inference",
  "tokio",
  "tracing",
http-api-bindings (chat bindings):

@@ -3,19 +3,16 @@ mod openai_chat;
 use std::sync::Arc;
 
 use openai_chat::OpenAIChatEngine;
+use tabby_common::config::HttpModelConfig;
 use tabby_inference::ChatCompletionStream;
 
-use crate::{get_optional_param, get_param};
-
-pub fn create(model: &str) -> Arc<dyn ChatCompletionStream> {
-    let params = serde_json::from_str(model).expect("Failed to parse model string");
-    let kind = get_param(&params, "kind");
-    if kind == "openai-chat" {
-        let model_name = get_optional_param(&params, "model_name").unwrap_or_default();
-        let api_endpoint = get_param(&params, "api_endpoint");
-        let api_key = get_optional_param(&params, "api_key");
-
-        let engine = OpenAIChatEngine::create(&api_endpoint, &model_name, api_key);
+pub fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
+    if model.kind == "openai-chat" {
+        let engine = OpenAIChatEngine::create(
+            &model.api_endpoint,
+            model.model_name.as_deref().unwrap_or_default(),
+            model.api_key.clone(),
+        );
         Arc::new(engine)
     } else {
         panic!("Only openai-chat are supported for http chat");
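Usage sketch for the new chat factory above (hypothetical values; the endpoint is a placeholder, and model_name may be left unset since the factory falls back to an empty string):

    use std::sync::Arc;

    use tabby_common::config::HttpModelConfigBuilder;
    use tabby_inference::ChatCompletionStream;

    fn openai_chat_engine() -> Arc<dyn ChatCompletionStream> {
        let config = HttpModelConfigBuilder::default()
            .kind("openai-chat".to_string())
            .api_endpoint("https://api.openai.com/v1".to_string())
            .build()
            .expect("valid http model config");
        // `create` panics for any kind other than "openai-chat".
        http_api_bindings::create_chat(&config)
    }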
http-api-bindings (completion bindings):

@@ -1,33 +1,16 @@
 mod llama;
-mod openai;
 
 use std::sync::Arc;
 
 use llama::LlamaCppEngine;
-use openai::OpenAIEngine;
+use tabby_common::config::HttpModelConfig;
 use tabby_inference::CompletionStream;
 
-use crate::{get_optional_param, get_param};
-
-pub fn create(model: &str) -> (Arc<dyn CompletionStream>, Option<String>, Option<String>) {
-    let params = serde_json::from_str(model).expect("Failed to parse model string");
-    let kind = get_param(&params, "kind");
-    if kind == "openai" {
-        let model_name = get_optional_param(&params, "model_name").unwrap_or_default();
-        let api_endpoint = get_param(&params, "api_endpoint");
-        let api_key = get_optional_param(&params, "api_key");
-        let prompt_template = get_optional_param(&params, "prompt_template");
-        let chat_template = get_optional_param(&params, "chat_template");
-        let engine = OpenAIEngine::create(&api_endpoint, &model_name, api_key);
-        (Arc::new(engine), prompt_template, chat_template)
-    } else if kind == "llama" {
-        let api_endpoint = get_param(&params, "api_endpoint");
-        let api_key = get_optional_param(&params, "api_key");
-        let prompt_template = get_optional_param(&params, "prompt_template");
-        let chat_template = get_optional_param(&params, "chat_template");
-        let engine = LlamaCppEngine::create(&api_endpoint, api_key);
-        (Arc::new(engine), prompt_template, chat_template)
+pub fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
+    if model.kind == "llama.cpp/completion" {
+        let engine = LlamaCppEngine::create(&model.api_endpoint, model.api_key.clone());
+        Arc::new(engine)
     } else {
-        panic!("Only openai are supported for http completion");
+        panic!("Unsupported model kind: {}", model.kind);
     }
 }
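The completion factory no longer returns the (engine, prompt_template, chat_template) triple; the templates now travel inside HttpModelConfig itself. A hedged sketch (endpoint and template strings are placeholders, and the Option-taking setter shape is assumed from the derive_builder defaults):

    use std::sync::Arc;

    use tabby_common::config::HttpModelConfigBuilder;
    use tabby_inference::CompletionStream;

    fn llama_completion_engine() -> Arc<dyn CompletionStream> {
        let config = HttpModelConfigBuilder::default()
            .kind("llama.cpp/completion".to_string())
            .api_endpoint("http://localhost:8008".to_string())
            // FIM template consumed later by the completion service, not here.
            .prompt_template(Some("<PRE>{prefix}<SUF>{suffix}<MID>".to_string()))
            .build()
            .expect("valid http model config");
        http_api_bindings::create(&config)
    }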
http-api-bindings (OpenAI completion engine, file removed):

@@ -1,72 +0,0 @@
-use async_openai::{config::OpenAIConfig, error::OpenAIError, types::CreateCompletionRequestArgs};
-use async_stream::stream;
-use async_trait::async_trait;
-use futures::stream::BoxStream;
-use tabby_inference::{CompletionOptions, CompletionStream};
-use tracing::warn;
-
-pub struct OpenAIEngine {
-    client: async_openai::Client<OpenAIConfig>,
-    model_name: String,
-}
-
-impl OpenAIEngine {
-    pub fn create(api_endpoint: &str, model_name: &str, api_key: Option<String>) -> Self {
-        let config = OpenAIConfig::default()
-            .with_api_base(api_endpoint)
-            .with_api_key(api_key.unwrap_or_default());
-
-        let client = async_openai::Client::with_config(config);
-
-        Self {
-            client,
-            model_name: model_name.to_owned(),
-        }
-    }
-}
-
-#[async_trait]
-impl CompletionStream for OpenAIEngine {
-    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
-        let request = CreateCompletionRequestArgs::default()
-            .model(&self.model_name)
-            .temperature(options.sampling_temperature)
-            .max_tokens(options.max_decoding_tokens as u16)
-            .stream(true)
-            .prompt(prompt)
-            .build();
-
-        let s = stream! {
-            let request = match request {
-                Ok(x) => x,
-                Err(e) => {
-                    warn!("Failed to build completion request {:?}", e);
-                    return;
-                }
-            };
-
-            let s = match self.client.completions().create_stream(request).await {
-                Ok(x) => x,
-                Err(e) => {
-                    warn!("Failed to create completion request {:?}", e);
-                    return;
-                }
-            };
-
-            for await x in s {
-                match x {
-                    Ok(x) => {
-                        yield x.choices[0].text.clone();
-                    },
-                    Err(OpenAIError::StreamError(_)) => break,
-                    Err(e) => {
-                        warn!("Failed to stream response: {}", e);
-                        break;
-                    }
-                };
-            }
-        };
-
-        Box::pin(s)
-    }
-}
http-api-bindings (embedding bindings):

@@ -3,17 +3,12 @@ mod llama;
 use std::sync::Arc;
 
 use llama::LlamaCppEngine;
+use tabby_common::config::HttpModelConfig;
 use tabby_inference::Embedding;
 
-use crate::{get_optional_param, get_param};
-
-pub fn create(model: &str) -> Arc<dyn Embedding> {
-    let params = serde_json::from_str(model).expect("Failed to parse model string");
-    let kind = get_param(&params, "kind");
-    if kind == "llama" {
-        let api_endpoint = get_param(&params, "api_endpoint");
-        let api_key = get_optional_param(&params, "api_key");
-        let engine = LlamaCppEngine::create(&api_endpoint, api_key);
+pub fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
+    if config.kind == "llama.cpp/embedding" {
+        let engine = LlamaCppEngine::create(&config.api_endpoint, config.api_key.clone());
         Arc::new(engine)
     } else {
         panic!("Only llama are supported for http embedding");
http-api-bindings (crate root; JSON "model string" helpers removed):

@@ -5,19 +5,3 @@ mod embedding;
 pub use chat::create as create_chat;
 pub use completion::create;
 pub use embedding::create as create_embedding;
-use serde_json::Value;
-
-pub(crate) fn get_param(params: &Value, key: &str) -> String {
-    params
-        .get(key)
-        .unwrap_or_else(|| panic!("Missing {} field", key))
-        .as_str()
-        .expect("Type unmatched")
-        .to_owned()
-}
-
-pub(crate) fn get_optional_param(params: &Value, key: &str) -> Option<String> {
-    params
-        .get(key)
-        .map(|x| x.as_str().expect("Type unmatched").to_owned())
-}
llama-cpp-server (Cargo.toml):

@@ -15,8 +15,8 @@ vulkan = ["binary"]
 futures.workspace = true
 http-api-bindings = { path = "../http-api-bindings" }
 reqwest.workspace = true
-serde_json.workspace = true
 tabby-inference = { path = "../tabby-inference" }
+tabby-common = { path = "../tabby-common" }
 tracing.workspace = true
 async-trait.workspace = true
 tokio = { workspace = true, features = ["process"] }
@@ -25,4 +25,4 @@ which = "6"
 
 [build-dependencies]
 cmake = "0.1"
-omnicopy_to_output = "0.1.1"
+omnicopy_to_output = "0.1.1"
llama-cpp-server (lib.rs):

@@ -5,8 +5,8 @@ use std::sync::Arc;
 use anyhow::Result;
 use async_trait::async_trait;
 use futures::stream::BoxStream;
-use serde_json::json;
 use supervisor::LlamaCppSupervisor;
+use tabby_common::config::HttpModelConfigBuilder;
 use tabby_inference::{CompletionOptions, CompletionStream, Embedding};
 
 fn api_endpoint(port: u16) -> String {
@@ -20,18 +20,19 @@ struct EmbeddingServer {
 }
 
 impl EmbeddingServer {
-    async fn new(use_gpu: bool, model_path: &str, parallelism: u8) -> EmbeddingServer {
-        let server = LlamaCppSupervisor::new(use_gpu, true, model_path, parallelism);
+    async fn new(num_gpu_layers: u16, model_path: &str, parallelism: u8) -> EmbeddingServer {
+        let server = LlamaCppSupervisor::new(num_gpu_layers, true, model_path, parallelism);
         server.start().await;
 
-        let model_spec: String = serde_json::to_string(&json!({
-            "kind": "llama",
-            "api_endpoint": api_endpoint(server.port()),
-        }))
-        .expect("Failed to serialize model spec");
+        let config = HttpModelConfigBuilder::default()
+            .api_endpoint(api_endpoint(server.port()))
+            .kind("llama.cpp/embedding".to_string())
+            .build()
+            .expect("Failed to create HttpModelConfig");
 
         Self {
             server,
-            embedding: http_api_bindings::create_embedding(&model_spec),
+            embedding: http_api_bindings::create_embedding(&config),
         }
     }
 }
@@ -50,14 +51,14 @@ struct CompletionServer {
 }
 
 impl CompletionServer {
-    async fn new(use_gpu: bool, model_path: &str, parallelism: u8) -> Self {
-        let server = LlamaCppSupervisor::new(use_gpu, false, model_path, parallelism);
-        let model_spec: String = serde_json::to_string(&json!({
-            "kind": "llama",
-            "api_endpoint": api_endpoint(server.port()),
-        }))
-        .expect("Failed to serialize model spec");
-        let (completion, _, _) = http_api_bindings::create(&model_spec);
+    async fn new(num_gpu_layers: u16, model_path: &str, parallelism: u8) -> Self {
+        let server = LlamaCppSupervisor::new(num_gpu_layers, false, model_path, parallelism);
+        let config = HttpModelConfigBuilder::default()
+            .api_endpoint(api_endpoint(server.port()))
+            .kind("llama.cpp/completion".to_string())
+            .build()
+            .expect("Failed to create HttpModelConfig");
+        let completion = http_api_bindings::create(&config);
         Self { server, completion }
     }
 }
@@ -70,17 +71,17 @@ impl CompletionStream for CompletionServer {
 }
 
 pub async fn create_embedding(
-    use_gpu: bool,
+    num_gpu_layers: u16,
     model_path: &str,
     parallelism: u8,
 ) -> Arc<dyn Embedding> {
-    Arc::new(EmbeddingServer::new(use_gpu, model_path, parallelism).await)
+    Arc::new(EmbeddingServer::new(num_gpu_layers, model_path, parallelism).await)
 }
 
 pub async fn create_completion(
-    use_gpu: bool,
+    num_gpu_layers: u16,
     model_path: &str,
     parallelism: u8,
 ) -> Arc<dyn CompletionStream> {
-    Arc::new(CompletionServer::new(use_gpu, model_path, parallelism).await)
+    Arc::new(CompletionServer::new(num_gpu_layers, model_path, parallelism).await)
 }
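Callers of llama-cpp-server now pass the layer-offload count directly instead of a use_gpu flag. A small usage sketch (the model path is a placeholder; per to_local_config later in this diff, 9999 is used as the "offload all layers" default):

    use tabby_inference::CompletionStream;

    async fn local_engine() -> std::sync::Arc<dyn CompletionStream> {
        // 0 keeps inference on the CPU; any positive value is forwarded as -ngl.
        llama_cpp_server::create_completion(9999, "/path/to/model.gguf", 1).await
    }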
llama-cpp-server (supervisor):

@@ -13,7 +13,7 @@ pub struct LlamaCppSupervisor {
 
 impl LlamaCppSupervisor {
     pub fn new(
-        use_gpu: bool,
+        num_gpu_layers: u16,
         embedding: bool,
         model_path: &str,
         parallelism: u8,
@@ -56,10 +56,8 @@ impl LlamaCppSupervisor {
             command.arg("-t").arg(n_threads);
         }
 
-        if use_gpu {
-            let num_gpu_layers =
-                std::env::var("LLAMA_CPP_N_GPU_LAYERS").unwrap_or("9999".into());
-            command.arg("-ngl").arg(&num_gpu_layers);
+        if num_gpu_layers > 0 {
+            command.arg("-ngl").arg(num_gpu_layers.to_string());
         }
 
         if embedding {
tabby-common (config):

@@ -2,6 +2,7 @@ use std::{collections::HashSet, path::PathBuf};
 
 use anyhow::{anyhow, Context, Result};
 use async_trait::async_trait;
+use derive_builder::Builder;
 use serde::{Deserialize, Serialize};
 
 use crate::{
@@ -9,13 +10,16 @@ use crate::{
     terminal::{HeaderFormat, InfoMessage},
 };
 
-#[derive(Serialize, Deserialize, Default)]
+#[derive(Serialize, Deserialize, Default, Clone)]
 pub struct Config {
     #[serde(default)]
     pub repositories: Vec<RepositoryConfig>,
 
     #[serde(default)]
     pub server: ServerConfig,
+
+    #[serde(default)]
+    pub model: ModelConfigGroup,
 }
 
 impl Config {
@@ -115,7 +119,7 @@ fn sanitize_name(s: &str) -> String {
     sanitized.into_iter().collect()
 }
 
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Clone)]
 pub struct ServerConfig {
     /// The timeout in seconds for the /v1/completion api.
     pub completion_timeout: u64,
@@ -129,6 +133,56 @@ impl Default for ServerConfig {
     }
 }
 
+#[derive(Serialize, Deserialize, Default, Clone)]
+pub struct ModelConfigGroup {
+    pub completion: Option<ModelConfig>,
+    pub chat: Option<ModelConfig>,
+    pub embedding: Option<ModelConfig>,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+#[serde(rename_all = "snake_case")]
+pub enum ModelConfig {
+    Http(HttpModelConfig),
+    Local(LocalModelConfig),
+}
+
+#[derive(Serialize, Deserialize, Builder, Clone)]
+pub struct HttpModelConfig {
+    pub api_endpoint: String,
+    pub kind: String,
+
+    #[builder(default)]
+    pub api_key: Option<String>,
+
+    /// Used by chat http endpoint to select model.
+    #[builder(default)]
+    pub model_name: Option<String>,
+
+    /// Used by completion http endpoint to construct FIM prompt.
+    #[builder(default)]
+    pub prompt_template: Option<String>,
+
+    /// Used by completion http endpoint to construct Chat prompt.
+    #[builder(default)]
+    pub chat_template: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct LocalModelConfig {
+    pub model_id: String,
+
+    #[serde(default = "default_parallelism")]
+    pub parallelism: u8,
+
+    #[serde(default)]
+    pub num_gpu_layers: u16,
+}
+
+fn default_parallelism() -> u8 {
+    1
+}
+
 #[async_trait]
 pub trait RepositoryAccess: Send + Sync {
     async fn list_repositories(&self) -> Result<Vec<RepositoryConfig>>;
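Given these serde derives, a config.toml can now declare model backends directly. A hedged illustration (the table names follow from the externally tagged ModelConfig enum and are otherwise an assumption, as are the concrete values; the toml crate is used here purely for illustration):

    use tabby_common::config::Config;

    fn parse_example_config() -> Config {
        // Completion runs locally, chat goes through an OpenAI-compatible HTTP API.
        let text = r#"
            [model.completion.local]
            model_id = "StarCoder-1B"
            num_gpu_layers = 9999

            [model.chat.http]
            kind = "openai-chat"
            api_endpoint = "https://api.openai.com/v1"
            model_name = "gpt-4o-mini"
        "#;
        toml::from_str(text).expect("example config should deserialize")
    }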
Doc index example (crawl pipeline):

@@ -15,7 +15,7 @@ async fn main() {
     let registry = ModelRegistry::new(registry).await;
     let model_path = registry.get_model_path(name).display().to_string();
 
-    let mut doc_index = DocIndex::new(create_embedding(false, &model_path, 1).await);
+    let mut doc_index = DocIndex::new(create_embedding(0, &model_path, 1).await);
     let mut cnt = 0;
     stream! {
         for await doc in crawl_pipeline("https://tabby.tabbyml.com/").await {
tabby (main.rs):

@@ -11,7 +11,7 @@ mod worker;
 use std::os::unix::fs::PermissionsExt;
 
 use clap::{Parser, Subcommand};
-use tabby_common::config::{Config, ConfigRepositoryAccess};
+use tabby_common::config::{Config, ConfigRepositoryAccess, LocalModelConfig, ModelConfig};
 use tracing::level_filters::LevelFilter;
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer};
 
@@ -72,10 +72,6 @@ pub enum Device {
 
     #[strum(serialize = "vulkan")]
     Vulkan,
-
-    #[strum(serialize = "experimental_http")]
-    #[clap(hide = true)]
-    ExperimentalHttp,
 }
 
 #[tokio::main]
@@ -103,11 +99,16 @@ async fn main() {
         }
         #[cfg(feature = "ee")]
         Commands::WorkerCompletion(ref args) => {
-            worker::main(tabby_webserver::public::WorkerKind::Completion, args).await
+            worker::main(
+                &config,
+                tabby_webserver::public::WorkerKind::Completion,
+                args,
+            )
+            .await
         }
         #[cfg(feature = "ee")]
         Commands::WorkerChat(ref args) => {
-            worker::main(tabby_webserver::public::WorkerKind::Chat, args).await
+            worker::main(&config, tabby_webserver::public::WorkerKind::Chat, args).await
         }
     }
 }
@@ -158,3 +159,20 @@ fn init_logging() {
         .with(env_filter)
         .init();
 }
+
+fn to_local_config(model: &str, parallelism: u8, device: &Device) -> ModelConfig {
+    let num_gpu_layers = if *device != Device::Cpu {
+        std::env::var("LLAMA_CPP_N_GPU_LAYERS")
+            .map(|s| s.parse::<u16>().ok())
+            .ok()
+            .flatten()
+            .unwrap_or(9999)
+    } else {
+        9999
+    };
+    ModelConfig::Local(LocalModelConfig {
+        model_id: model.to_owned(),
+        parallelism,
+        num_gpu_layers,
+    })
+}
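A quick behavioral sketch of the flag-to-config bridge added above (assumes LLAMA_CPP_N_GPU_LAYERS is unset; the model id is a placeholder):

    fn demo_to_local_config() {
        let config = to_local_config("StarCoder-1B", 1, &Device::Cpu);
        if let ModelConfig::Local(local) = config {
            assert_eq!(local.model_id, "StarCoder-1B");
            assert_eq!(local.parallelism, 1);
            // Defaults to 9999 ("all layers"); on non-CPU devices the
            // LLAMA_CPP_N_GPU_LAYERS environment variable can override it.
            assert_eq!(local.num_gpu_layers, 9999);
        }
    }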
tabby (serve.rs):

@@ -4,9 +4,8 @@ use axum::{routing, Router};
 use clap::Args;
 use hyper::StatusCode;
 use tabby_common::{
-    api,
-    api::{code::CodeSearch, event::EventLogger},
-    config::{Config, ConfigRepositoryAccess, RepositoryAccess},
+    api::{self, code::CodeSearch, event::EventLogger},
+    config::{Config, ConfigRepositoryAccess, ModelConfig, RepositoryAccess},
     usage,
 };
 use tokio::time::sleep;
@@ -30,7 +29,7 @@ use crate::{
         health,
         model::download_model_if_needed,
     },
-    Device,
+    to_local_config, Device,
 };
 
 #[derive(OpenApi)]
@@ -92,9 +91,6 @@ pub struct ServeArgs {
     #[clap(long)]
     chat_model: Option<String>,
 
-    #[clap(long, hide = true)]
-    embedding_model: Option<String>,
-
     #[clap(long, default_value = "0.0.0.0")]
     host: IpAddr,
 
@@ -125,7 +121,9 @@ pub struct ServeArgs {
 }
 
 pub async fn main(config: &Config, args: &ServeArgs) {
-    load_model(args).await;
+    let config = merge_args(config, args);
+
+    load_model(&config).await;
 
     debug!("Starting server, this might take a few minutes...");
 
@@ -160,7 +158,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
     }
 
     let code = Arc::new(create_code_search(repository_access));
-    let mut api = api_router(args, config, logger.clone(), code.clone(), webserver).await;
+    let mut api = api_router(args, &config, logger.clone(), code.clone(), webserver).await;
     let mut ui = Router::new()
         .merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", ApiDoc::openapi()))
         .fallback(|| async { axum::response::Redirect::temporary("/swagger-ui") });
@@ -172,26 +170,21 @@ pub async fn main(config: &Config, args: &ServeArgs) {
         ui = new_ui;
     };
 
-    start_heartbeat(args, webserver);
+    start_heartbeat(args, &config, webserver);
     run_app(api, Some(ui), args.host, args.port).await
 }
 
-async fn load_model(args: &ServeArgs) {
-    if args.device != Device::ExperimentalHttp {
-        if let Some(model) = &args.model {
-            download_model_if_needed(model).await;
-        }
+async fn load_model(config: &Config) {
+    if let Some(ModelConfig::Local(ref model)) = config.model.completion {
+        download_model_if_needed(&model.model_id).await;
     }
 
-    let chat_device = args.chat_device.as_ref().unwrap_or(&args.device);
-    if chat_device != &Device::ExperimentalHttp {
-        if let Some(chat_model) = &args.chat_model {
-            download_model_if_needed(chat_model).await
-        }
+    if let Some(ModelConfig::Local(ref model)) = config.model.chat {
+        download_model_if_needed(&model.model_id).await;
     }
 
-    if let Some(embedding_model) = &args.embedding_model {
-        download_model_if_needed(embedding_model).await
+    if let Some(ModelConfig::Local(ref model)) = config.model.embedding {
+        download_model_if_needed(&model.model_id).await;
     }
 }
 
@@ -202,37 +195,23 @@ async fn api_router(
     code: Arc<dyn CodeSearch>,
     webserver: Option<bool>,
 ) -> Router {
-    let completion_state = if let Some(model) = &args.model {
+    let model = &config.model;
+    let completion_state = if let Some(completion) = &model.completion {
         Some(Arc::new(
-            create_completion_service(
-                code.clone(),
-                logger.clone(),
-                model,
-                &args.device,
-                args.parallelism,
-            )
-            .await,
+            create_completion_service(code.clone(), logger.clone(), completion).await,
         ))
     } else {
         None
     };
 
-    let chat_state = if let Some(chat_model) = &args.chat_model {
-        Some(Arc::new(
-            create_chat_service(
-                logger.clone(),
-                chat_model,
-                args.chat_device.as_ref().unwrap_or(&args.device),
-                args.parallelism,
-            )
-            .await,
-        ))
+    let chat_state = if let Some(chat) = &model.chat {
+        Some(Arc::new(create_chat_service(logger.clone(), chat).await))
     } else {
         None
     };
 
-    let docsearch_state = if let Some(embedding_model) = &args.embedding_model {
-        let embedding = embedding::create(embedding_model, &args.device).await;
+    let docsearch_state = if let Some(embedding) = &model.embedding {
+        let embedding = embedding::create(embedding).await;
         Some(Arc::new(services::doc::create(embedding)))
     } else {
         None
@@ -249,9 +228,8 @@ async fn api_router(
     let mut routers = vec![];
 
     let health_state = Arc::new(health::HealthState::new(
-        args.model.as_deref(),
+        model,
         &args.device,
-        args.chat_model.as_deref(),
         args.chat_model
             .as_deref()
             .map(|_| args.chat_device.as_ref().unwrap_or(&args.device)),
@@ -379,16 +357,15 @@ async fn api_router(
     root
 }
 
-fn start_heartbeat(args: &ServeArgs, webserver: Option<bool>) {
-    let state = health::HealthState::new(
-        args.model.as_deref(),
+fn start_heartbeat(args: &ServeArgs, config: &Config, webserver: Option<bool>) {
+    let state = Arc::new(health::HealthState::new(
+        &config.model,
         &args.device,
-        args.chat_model.as_deref(),
         args.chat_model
             .as_deref()
             .map(|_| args.chat_device.as_ref().unwrap_or(&args.device)),
         webserver,
-    );
+    ));
     tokio::spawn(async move {
         loop {
             usage::capture("ServeHealth", &state).await;
@@ -414,3 +391,20 @@ impl Modify for SecurityAddon {
         }
     }
 }
+
+fn merge_args(config: &Config, args: &ServeArgs) -> Config {
+    let mut config = (*config).clone();
+    if let Some(model) = &args.model {
+        config.model.completion = Some(to_local_config(model, args.parallelism, &args.device));
+    };
+
+    if let Some(chat_model) = &args.chat_model {
+        config.model.chat = Some(to_local_config(
+            chat_model,
+            args.parallelism,
+            args.chat_device.as_ref().unwrap_or(&args.device),
+        ));
+    }
+
+    config
+}
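CLI flags therefore still take precedence: merge_args only overwrites the completion and chat slots when --model or --chat-model were given, leaving anything declared in config.toml untouched otherwise. Sketch (assumed to run inside serve.rs, where merge_args is visible):

    fn merged_completion(config: &Config, args: &ServeArgs) -> Option<ModelConfig> {
        // --model (if given) replaces config.model.completion with a local entry;
        // otherwise whatever the config file declared is kept.
        merge_args(config, args).model.completion
    }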
tabby (services/chat):

@@ -4,9 +4,12 @@ use async_stream::stream;
 use derive_builder::Builder;
 use futures::stream::BoxStream;
 use serde::{Deserialize, Serialize};
-use tabby_common::api::{
-    chat::Message,
-    event::{Event, EventLogger},
+use tabby_common::{
+    api::{
+        chat::Message,
+        event::{Event, EventLogger},
+    },
+    config::ModelConfig,
 };
 use tabby_inference::{ChatCompletionOptionsBuilder, ChatCompletionStream};
 use tracing::warn;
@@ -14,7 +17,6 @@ use utoipa::ToSchema;
 use uuid::Uuid;
 
 use super::model;
-use crate::Device;
 
 #[derive(Serialize, Deserialize, ToSchema, Clone, Builder, Debug)]
 #[schema(example=json!({
@@ -156,13 +158,8 @@ fn convert_messages(input: &[Message]) -> Vec<tabby_common::api::event::Message>
         .collect()
 }
 
-pub async fn create_chat_service(
-    logger: Arc<dyn EventLogger>,
-    model: &str,
-    device: &Device,
-    parallelism: u8,
-) -> ChatService {
-    let engine = model::load_chat_completion(model, device, parallelism).await;
+pub async fn create_chat_service(logger: Arc<dyn EventLogger>, chat: &ModelConfig) -> ChatService {
+    let engine = model::load_chat_completion(chat).await;
 
     ChatService::new(engine, logger)
 }
tabby (services/completion):

@@ -4,11 +4,12 @@ use std::sync::Arc;
 
 use serde::{Deserialize, Serialize};
 use tabby_common::{
-    api,
     api::{
+        self,
         code::CodeSearch,
         event::{Event, EventLogger},
     },
+    config::ModelConfig,
     languages::get_language,
 };
 use tabby_inference::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder};
@@ -16,7 +17,6 @@ use thiserror::Error;
 use utoipa::ToSchema;
 
 use super::model;
-use crate::Device;
 
 #[derive(Error, Debug)]
 pub enum CompletionError {
@@ -342,16 +342,14 @@ impl CompletionService {
 pub async fn create_completion_service(
     code: Arc<dyn CodeSearch>,
     logger: Arc<dyn EventLogger>,
-    model: &str,
-    device: &Device,
-    parallelism: u8,
+    model: &ModelConfig,
 ) -> CompletionService {
     let (
         engine,
         model::PromptInfo {
             prompt_template, ..
         },
-    ) = model::load_code_generation(model, device, parallelism).await;
+    ) = model::load_code_generation(model).await;
 
     CompletionService::new(engine.clone(), code, logger, prompt_template)
 }
tabby (services/embedding):

@@ -1,10 +1,10 @@
 use std::sync::Arc;
 
+use tabby_common::config::ModelConfig;
 use tabby_inference::Embedding;
 
 use super::model;
-use crate::Device;
 
-pub async fn create(model: &str, device: &Device) -> Arc<dyn Embedding> {
-    model::load_embedding(model, device).await
+pub async fn create(config: &ModelConfig) -> Arc<dyn Embedding> {
+    model::load_embedding(config).await
 }
tabby (services/health):

@@ -4,6 +4,7 @@ use anyhow::Result;
 use nvml_wrapper::Nvml;
 use serde::{Deserialize, Serialize};
 use sysinfo::{CpuExt, System, SystemExt};
+use tabby_common::config::{ModelConfig, ModelConfigGroup};
 use utoipa::ToSchema;
 
 use crate::Device;
@@ -27,9 +28,8 @@ pub struct HealthState {
 
 impl HealthState {
     pub fn new(
-        model: Option<&str>,
+        model_config: &ModelConfigGroup,
         device: &Device,
-        chat_model: Option<&str>,
         chat_device: Option<&Device>,
         webserver: Option<bool>,
     ) -> Self {
@@ -40,24 +40,9 @@ impl HealthState {
             Err(_) => vec![],
         };
 
-        let http_model_name = Some("Remote");
-        let is_model_http = device == &Device::ExperimentalHttp;
-        let model = if is_model_http {
-            http_model_name
-        } else {
-            model
-        };
-
-        let is_chat_model_http = chat_device == Some(&Device::ExperimentalHttp);
-        let chat_model = if is_chat_model_http {
-            http_model_name
-        } else {
-            chat_model
-        };
-
         Self {
-            model: model.map(|x| x.to_string()),
-            chat_model: chat_model.map(|x| x.to_owned()),
+            model: to_model_name(&model_config.completion),
+            chat_model: to_model_name(&model_config.chat),
             chat_device: chat_device.map(|x| x.to_string()),
             device: device.to_string(),
             arch: ARCH.to_string(),
@@ -70,6 +55,17 @@ impl HealthState {
     }
 }
 
+fn to_model_name(model: &Option<ModelConfig>) -> Option<String> {
+    if let Some(model) = model {
+        match model {
+            ModelConfig::Http(_http) => Some("Remote".to_owned()),
+            ModelConfig::Local(llama) => Some(llama.model_id.clone()),
+        }
+    } else {
+        None
+    }
+}
+
 pub fn read_cpu_info() -> (String, usize) {
     let mut system = System::new_all();
     system.refresh_cpu();
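Behavior sketch for to_model_name above (the model id is a placeholder): HTTP-backed models are reported as "Remote" in the health response, local models report their model_id.

    fn demo_to_model_name() {
        let local = ModelConfig::Local(LocalModelConfig {
            model_id: "StarCoder-1B".to_owned(),
            parallelism: 1,
            num_gpu_layers: 0,
        });
        assert_eq!(to_model_name(&Some(local)), Some("StarCoder-1B".to_owned()));
        assert_eq!(to_model_name(&None), None);
    }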
tabby (services/model):

@@ -3,98 +3,101 @@ mod chat;
 use std::{fs, path::PathBuf, sync::Arc};
 
 use serde::Deserialize;
-use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
+use tabby_common::{
+    config::ModelConfig,
+    registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH},
+};
 use tabby_download::download_model;
 use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream, Embedding};
 use tracing::info;
 
-use crate::{fatal, Device};
+use crate::fatal;
 
-pub async fn load_chat_completion(
-    model_id: &str,
-    device: &Device,
-    parallelism: u8,
-) -> Arc<dyn ChatCompletionStream> {
-    if device == &Device::ExperimentalHttp {
-        return http_api_bindings::create_chat(model_id);
-    }
+pub async fn load_chat_completion(chat: &ModelConfig) -> Arc<dyn ChatCompletionStream> {
+    match chat {
+        ModelConfig::Http(http) => http_api_bindings::create_chat(http),
 
-    let (engine, PromptInfo { chat_template, .. }) =
-        load_completion(model_id, device, parallelism).await;
+        ModelConfig::Local(_) => {
+            let (engine, PromptInfo { chat_template, .. }) = load_completion(chat).await;
 
-    let Some(chat_template) = chat_template else {
-        fatal!("Chat model requires specifying prompt template");
-    };
+            let Some(chat_template) = chat_template else {
+                fatal!("Chat model requires specifying prompt template");
+            };
 
-    Arc::new(chat::make_chat_completion(engine, chat_template))
-}
-
-pub async fn load_embedding(model_id: &str, device: &Device) -> Arc<dyn Embedding> {
-    if device == &Device::ExperimentalHttp {
-        return http_api_bindings::create_embedding(model_id);
-    }
-
-    if fs::metadata(model_id).is_ok() {
-        let path = PathBuf::from(model_id);
-        let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
-        create_ggml_embedding_engine(model_path.display().to_string().as_str()).await
-    } else {
-        let (registry, name) = parse_model_id(model_id);
-        let registry = ModelRegistry::new(registry).await;
-        let model_path = registry.get_model_path(name).display().to_string();
-        create_ggml_embedding_engine(&model_path).await
+            Arc::new(chat::make_chat_completion(engine, chat_template))
+        }
     }
 }
 
-pub async fn load_code_generation(
-    model_id: &str,
-    device: &Device,
-    parallelism: u8,
-) -> (Arc<CodeGeneration>, PromptInfo) {
-    let (engine, prompt_info) = load_completion(model_id, device, parallelism).await;
+pub async fn load_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
+    match config {
+        ModelConfig::Http(http) => http_api_bindings::create_embedding(http),
+        ModelConfig::Local(llama) => {
+            if fs::metadata(&llama.model_id).is_ok() {
+                let path = PathBuf::from(&llama.model_id);
+                let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
+                create_ggml_embedding_engine(
+                    model_path.display().to_string().as_str(),
+                    llama.parallelism,
+                    llama.num_gpu_layers,
+                )
+                .await
+            } else {
+                let (registry, name) = parse_model_id(&llama.model_id);
+                let registry = ModelRegistry::new(registry).await;
+                let model_path = registry.get_model_path(name).display().to_string();
+                create_ggml_embedding_engine(&model_path, llama.parallelism, llama.num_gpu_layers)
+                    .await
+            }
+        }
+    }
+}
+
+pub async fn load_code_generation(model: &ModelConfig) -> (Arc<CodeGeneration>, PromptInfo) {
+    let (engine, prompt_info) = load_completion(model).await;
     (Arc::new(CodeGeneration::new(engine)), prompt_info)
 }
 
-async fn load_completion(
-    model_id: &str,
-    device: &Device,
-    parallelism: u8,
-) -> (Arc<dyn CompletionStream>, PromptInfo) {
-    if device == &Device::ExperimentalHttp {
-        let (engine, prompt_template, chat_template) = http_api_bindings::create(model_id);
-        return (
-            engine,
-            PromptInfo {
-                prompt_template,
-                chat_template,
-            },
-        );
-    }
-
-    if fs::metadata(model_id).is_ok() {
-        let path = PathBuf::from(model_id);
-        let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
-        let engine = create_ggml_engine(
-            device,
-            model_path.display().to_string().as_str(),
-            parallelism,
-        )
-        .await;
-        let engine_info = PromptInfo::read(path.join("tabby.json"));
-        (engine, engine_info)
-    } else {
-        let (registry, name) = parse_model_id(model_id);
-        let registry = ModelRegistry::new(registry).await;
-        let model_path = registry.get_model_path(name).display().to_string();
-        let model_info = registry.get_model_info(name);
-        let engine = create_ggml_engine(device, &model_path, parallelism).await;
-        (
-            engine,
-            PromptInfo {
-                prompt_template: model_info.prompt_template.clone(),
-                chat_template: model_info.chat_template.clone(),
-            },
-        )
+async fn load_completion(model: &ModelConfig) -> (Arc<dyn CompletionStream>, PromptInfo) {
+    match model {
+        ModelConfig::Http(http) => {
+            let engine = http_api_bindings::create(http);
+            (
+                engine,
+                PromptInfo {
+                    prompt_template: http.prompt_template.clone(),
+                    chat_template: http.chat_template.clone(),
+                },
+            )
+        }
+        ModelConfig::Local(llama) => {
+            if fs::metadata(&llama.model_id).is_ok() {
+                let path = PathBuf::from(&llama.model_id);
+                let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
+                let engine = create_ggml_engine(
+                    llama.num_gpu_layers,
+                    model_path.display().to_string().as_str(),
+                    llama.parallelism,
+                )
+                .await;
+                let engine_info = PromptInfo::read(path.join("tabby.json"));
+                (engine, engine_info)
+            } else {
+                let (registry, name) = parse_model_id(&llama.model_id);
+                let registry = ModelRegistry::new(registry).await;
+                let model_path = registry.get_model_path(name).display().to_string();
+                let model_info = registry.get_model_info(name);
+                let engine =
+                    create_ggml_engine(llama.num_gpu_layers, &model_path, llama.parallelism).await;
+                (
+                    engine,
+                    PromptInfo {
+                        prompt_template: model_info.prompt_template.clone(),
+                        chat_template: model_info.chat_template.clone(),
+                    },
+                )
+            }
+        }
+    }
 }
 
@@ -112,16 +115,20 @@ impl PromptInfo {
 }
 
 async fn create_ggml_engine(
-    device: &Device,
+    num_gpu_layers: u16,
     model_path: &str,
     parallelism: u8,
 ) -> Arc<dyn CompletionStream> {
-    llama_cpp_server::create_completion(device != &Device::Cpu, model_path, parallelism).await
+    llama_cpp_server::create_completion(num_gpu_layers, model_path, parallelism).await
 }
 
-async fn create_ggml_embedding_engine(model_path: &str) -> Arc<dyn Embedding> {
-    // By default, embedding always use CPU device with 1 parallelism.
-    llama_cpp_server::create_embedding(false, model_path, 1).await
+async fn create_ggml_embedding_engine(
+    model_path: &str,
+    parallelism: u8,
+    num_gpu_layers: u16,
+) -> Arc<dyn Embedding> {
+    llama_cpp_server::create_embedding(num_gpu_layers, model_path, parallelism).await
 }
 
 pub async fn download_model_if_needed(model: &str) {
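Usage sketch of the reworked loader (the model id is a placeholder): engine selection is now driven entirely by the ModelConfig variant, with the old Device and parallelism arguments folded into LocalModelConfig's num_gpu_layers and parallelism fields.

    use tabby_common::config::{LocalModelConfig, ModelConfig};

    async fn demo_load_code_generation() {
        let model = ModelConfig::Local(LocalModelConfig {
            model_id: "TabbyML/StarCoder-1B".to_owned(),
            parallelism: 1,
            num_gpu_layers: 9999,
        });
        // Returns the engine plus the prompt/chat templates read from the
        // model registry or the local tabby.json.
        let (_engine, _prompt_info) = load_code_generation(&model).await;
    }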
tabby (worker.rs):

@@ -2,7 +2,10 @@ use std::{env::consts::ARCH, net::IpAddr, sync::Arc};
 
 use axum::{routing, Router};
 use clap::Args;
-use tabby_common::api::{code::CodeSearch, event::EventLogger};
+use tabby_common::{
+    api::{code::CodeSearch, event::EventLogger},
+    config::Config,
+};
 use tabby_webserver::public::{RegisterWorkerRequest, WorkerClient, WorkerKind};
 use tracing::info;
 
@@ -14,7 +17,7 @@ use crate::{
         health::{read_cpu_info, read_cuda_devices},
         model::download_model_if_needed,
     },
-    Device,
+    to_local_config, Device,
 };
 
 #[derive(Args)]
@@ -47,9 +50,18 @@ pub struct WorkerArgs {
     parallelism: u8,
 }
 
-async fn make_chat_route(logger: Arc<dyn EventLogger>, args: &WorkerArgs) -> Router {
-    let chat_state =
-        Arc::new(create_chat_service(logger, &args.model, &args.device, args.parallelism).await);
+async fn make_chat_route(logger: Arc<dyn EventLogger>, config: &Config) -> Router {
+    let chat_state = Arc::new(
+        create_chat_service(
+            logger,
+            config
+                .model
+                .chat
+                .as_ref()
+                .expect("Chat model config is missing"),
+        )
+        .await,
+    );
 
     Router::new()
         .route(
@@ -65,10 +77,19 @@ async fn make_chat_route(logger: Arc<dyn EventLogger>, args: &WorkerArgs) -> Router {
 async fn make_completion_route(
     code: Arc<dyn CodeSearch>,
     logger: Arc<dyn EventLogger>,
-    args: &WorkerArgs,
+    config: &Config,
 ) -> Router {
     let completion_state = Arc::new(
-        create_completion_service(code, logger, &args.model, &args.device, args.parallelism).await,
+        create_completion_service(
+            code,
+            logger,
+            config
+                .model
+                .completion
+                .as_ref()
+                .expect("Completion model config is missing"),
+        )
+        .await,
     );
 
     Router::new().route(
@@ -77,8 +98,9 @@ async fn make_completion_route(
     )
 }
 
-pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
+pub async fn main(config: &Config, kind: WorkerKind, args: &WorkerArgs) {
     download_model_if_needed(&args.model).await;
+    let config = merge_args(config, args, &kind);
 
     info!("Starting worker, this might take a few minutes...");
 
@@ -87,8 +109,8 @@ pub async fn main(kind: WorkerKind, args: &WorkerArgs) {
     let logger = code.clone();
 
     let app = match kind {
-        WorkerKind::Completion => make_completion_route(code, logger, args).await,
-        WorkerKind::Chat => make_chat_route(logger.clone(), args).await,
+        WorkerKind::Completion => make_completion_route(code, logger, &config).await,
+        WorkerKind::Chat => make_chat_route(logger.clone(), &config).await,
     };
 
     run_app(app, None, args.host, args.port).await
@@ -122,3 +144,15 @@ impl WorkerContext {
         }
     }
 }
+
+fn merge_args(config: &Config, args: &WorkerArgs, kind: &WorkerKind) -> Config {
+    let mut config = (*config).clone();
+    let override_config = Some(to_local_config(&args.model, args.parallelism, &args.device));
+
+    match kind {
+        WorkerKind::Chat => config.model.chat = override_config,
+        WorkerKind::Completion => config.model.completion = override_config,
+    }
+
+    config
+}
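The same flag-to-config bridge is applied per worker kind: a chat worker only overrides the chat slot, a completion worker only the completion slot. Sketch (assumed to run inside worker.rs, where merge_args is visible):

    use tabby_common::config::ModelConfig;

    fn demo_worker_config(config: &Config, args: &WorkerArgs) {
        let merged = merge_args(config, args, &WorkerKind::Chat);
        // --model now populates config.model.chat; the completion slot is untouched.
        assert!(matches!(merged.model.chat, Some(ModelConfig::Local(_))));
    }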