From 1632dd054e0aa9776916d4f2a17d389f394e7b1b Mon Sep 17 00:00:00 2001
From: Wei Zhang
Date: Sun, 11 Aug 2024 13:51:13 +0800
Subject: [PATCH] feat(llama-cpp-server): reuse local server for completion and chat (#2812)

* :sparkles: support reusing the same server if the model ids are equal

Signed-off-by: Wei Zhang

* :art: fix review

Signed-off-by: Wei Zhang

* :hammer: should compare models in llama-cpp-server

Signed-off-by: Wei Zhang

* match case for creating completion and chat

Signed-off-by: Wei Zhang

* if-let case for creating completion and chat

Signed-off-by: Wei Zhang

* add changelog

* rebase && make fix

---------

Signed-off-by: Wei Zhang
Co-authored-by: Meng Zhang
---
 ...ixed and Improvements-20240810-221045.yaml |  4 +
 crates/llama-cpp-server/src/lib.rs            | 59 ++++++++++++-
 crates/tabby-common/src/config.rs             |  2 +-
 crates/tabby/src/serve.rs                     | 36 ++++----
 crates/tabby/src/services/completion.rs       | 42 +++++----
 crates/tabby/src/services/model/mod.rs        | 85 +++++++++++++------
 6 files changed, 164 insertions(+), 64 deletions(-)
 create mode 100644 .changes/unreleased/Fixed and Improvements-20240810-221045.yaml

diff --git a/.changes/unreleased/Fixed and Improvements-20240810-221045.yaml b/.changes/unreleased/Fixed and Improvements-20240810-221045.yaml
new file mode 100644
index 000000000..714f6c014
--- /dev/null
+++ b/.changes/unreleased/Fixed and Improvements-20240810-221045.yaml
@@ -0,0 +1,4 @@
+kind: Fixed and Improvements
+body: When reusing a model for Chat / Completion together (e.g. Codestral-22B) with
+  the local model backend, reuse the llama-server subprocess.
+time: 2024-08-10T22:10:45.454278-07:00
diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs
index 20d63532f..25e612a5a 100644
--- a/crates/llama-cpp-server/src/lib.rs
+++ b/crates/llama-cpp-server/src/lib.rs
@@ -66,7 +66,7 @@ impl Embedding for EmbeddingServer {
 
 struct CompletionServer {
     #[allow(unused)]
-    server: LlamaCppSupervisor,
+    server: Arc<LlamaCppSupervisor>,
     completion: Arc<dyn CompletionStream>,
 }
 
@@ -89,6 +89,10 @@ impl CompletionServer {
             context_size,
         );
         server.start().await;
+        Self::new_with_supervisor(Arc::new(server)).await
+    }
+
+    async fn new_with_supervisor(server: Arc<LlamaCppSupervisor>) -> Self {
         let config = HttpModelConfigBuilder::default()
             .api_endpoint(Some(api_endpoint(server.port())))
             .kind("llama.cpp/completion".to_string())
@@ -108,7 +112,7 @@ impl CompletionStream for CompletionServer {
 
 struct ChatCompletionServer {
     #[allow(unused)]
-    server: LlamaCppSupervisor,
+    server: Arc<LlamaCppSupervisor>,
     chat_completion: Arc<dyn ChatCompletionStream>,
 }
 
@@ -132,6 +136,10 @@ impl ChatCompletionServer {
             context_size,
         );
         server.start().await;
+        Self::new_with_supervisor(Arc::new(server)).await
+    }
+
+    async fn new_with_supervisor(server: Arc<LlamaCppSupervisor>) -> Self {
         let config = HttpModelConfigBuilder::default()
             .api_endpoint(Some(api_endpoint(server.port())))
             .kind("openai/chat".to_string())
@@ -202,6 +210,53 @@ pub async fn create_completion(
     (stream, prompt_info)
 }
 
+pub async fn create_completion_and_chat(
+    completion_model: &LocalModelConfig,
+    chat_model: &LocalModelConfig,
+) -> (
+    Arc<dyn CompletionStream>,
+    PromptInfo,
+    Arc<dyn ChatCompletionStream>,
+) {
+    let chat_model_path = resolve_model_path(&chat_model.model_id).await;
+    let chat_template = resolve_prompt_info(&chat_model.model_id)
+        .await
+        .chat_template
+        .unwrap_or_else(|| panic!("Chat model requires specifying prompt template"));
+
+    let model_path = resolve_model_path(&completion_model.model_id).await;
+    let prompt_info = resolve_prompt_info(&completion_model.model_id).await;
+
+    let server = Arc::new(LlamaCppSupervisor::new(
+        "chat",
+        chat_model.num_gpu_layers,
+        false,
+        &chat_model_path,
+        chat_model.parallelism,
+        Some(chat_template),
+        chat_model.enable_fast_attention.unwrap_or_default(),
+        chat_model.context_size,
+    ));
+    server.start().await;
+
+    let chat = ChatCompletionServer::new_with_supervisor(server.clone()).await;
+
+    let completion = if completion_model == chat_model {
+        CompletionServer::new_with_supervisor(server).await
+    } else {
+        CompletionServer::new(
+            completion_model.num_gpu_layers,
+            &model_path,
+            completion_model.parallelism,
+            completion_model.enable_fast_attention.unwrap_or_default(),
+            completion_model.context_size,
+        )
+        .await
+    };
+
+    (Arc::new(completion), prompt_info, Arc::new(chat))
+}
+
 pub async fn create_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
     match config {
         ModelConfig::Http(http) => http_api_bindings::create_embedding(http).await,
diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs
index 14cd5e999..53639099f 100644
--- a/crates/tabby-common/src/config.rs
+++ b/crates/tabby-common/src/config.rs
@@ -250,7 +250,7 @@ pub struct HttpModelConfig {
     pub chat_template: Option<String>,
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
 pub struct LocalModelConfig {
     pub model_id: String,
 
diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs
index f161c9fd5..ce7e68670 100644
--- a/crates/tabby/src/serve.rs
+++ b/crates/tabby/src/serve.rs
@@ -24,11 +24,11 @@ use crate::{
     services::{
         self,
         code::create_code_search,
-        completion::{self, create_completion_service},
+        completion::{self, create_completion_service_and_chat, CompletionService},
         embedding,
         event::create_event_logger,
         health,
-        model::{self, download_model_if_needed},
+        model::download_model_if_needed,
         tantivy::IndexReaderProvider,
     },
     to_local_config, Device,
@@ -171,17 +171,22 @@ pub async fn main(config: &Config, args: &ServeArgs) {
         index_reader_provider.clone(),
     ));
 
-    let chat = if let Some(chat) = &config.model.chat {
-        Some(model::load_chat_completion(chat).await)
-    } else {
-        None
-    };
+    let model = &config.model;
+    let (completion, chat) = create_completion_service_and_chat(
+        &config.completion,
+        code.clone(),
+        logger.clone(),
+        model.completion.clone(),
+        model.chat.clone(),
+    )
+    .await;
 
     let mut api = api_router(
         args,
         &config,
         logger.clone(),
         code.clone(),
+        completion,
         chat.clone(),
         webserver,
     )
@@ -227,24 +232,15 @@ async fn api_router(
     args: &ServeArgs,
     config: &Config,
     logger: Arc<dyn EventLogger>,
-    code: Arc<dyn CodeSearch>,
+    _code: Arc<dyn CodeSearch>,
+    completion_state: Option<CompletionService>,
     chat_state: Option<Arc<dyn ChatCompletionStream>>,
     webserver: Option<bool>,
 ) -> Router {
-    let model = &config.model;
-    let completion_state = if let Some(completion) = &model.completion {
-        Some(Arc::new(
-            create_completion_service(&config.completion, code.clone(), logger.clone(), completion)
-                .await,
-        ))
-    } else {
-        None
-    };
-
     let mut routers = vec![];
 
     let health_state = Arc::new(health::HealthState::new(
-        model,
+        &config.model,
         &args.device,
         args.chat_model
             .as_deref()
@@ -273,7 +269,7 @@ async fn api_router(
             Router::new()
                 .route(
                     "/v1/completions",
-                    routing::post(routes::completions).with_state(completion_state),
+                    routing::post(routes::completions).with_state(Arc::new(completion_state)),
                 )
                 .layer(TimeoutLayer::new(Duration::from_secs(
                     config.server.completion_timeout,
diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs
index 80c10dbd8..f6131b91a 100644
--- a/crates/tabby/src/services/completion.rs
+++ b/crates/tabby/src/services/completion.rs
@@ -12,7 +12,9 @@ use tabby_common::{
     config::{CompletionConfig, ModelConfig},
     languages::get_language,
 };
-use tabby_inference::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder};
+use tabby_inference::{
+    ChatCompletionStream, CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder,
+};
 use thiserror::Error;
 use utoipa::ToSchema;
 
@@ -352,26 +354,32 @@ impl CompletionService {
     }
 }
 
-pub async fn create_completion_service(
+pub async fn create_completion_service_and_chat(
     config: &CompletionConfig,
     code: Arc<dyn CodeSearch>,
     logger: Arc<dyn EventLogger>,
-    model: &ModelConfig,
-) -> CompletionService {
-    let (
-        engine,
-        model::PromptInfo {
-            prompt_template, ..
-        },
-    ) = model::load_code_generation(model).await;
+    completion: Option<ModelConfig>,
+    chat: Option<ModelConfig>,
+) -> (
+    Option<CompletionService>,
+    Option<Arc<dyn ChatCompletionStream>>,
+) {
+    let (code_generation, prompt, chat) =
+        model::load_code_generation_and_chat(completion, chat).await;
 
-    CompletionService::new(
-        config.to_owned(),
-        engine.clone(),
-        code,
-        logger,
-        prompt_template,
-    )
+    let completion = code_generation.map(|code_generation| {
+        CompletionService::new(
+            config.to_owned(),
+            code_generation.clone(),
+            code,
+            logger,
+            prompt
+                .unwrap_or_else(|| panic!("Prompt template is required for code completion"))
+                .prompt_template,
+        )
+    });
+
+    (completion, chat)
 }
 
 #[cfg(test)]
diff --git a/crates/tabby/src/services/model/mod.rs b/crates/tabby/src/services/model/mod.rs
index 690597022..d72aba0d7 100644
--- a/crates/tabby/src/services/model/mod.rs
+++ b/crates/tabby/src/services/model/mod.rs
@@ -6,37 +6,74 @@ use tabby_download::download_model;
 use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream, Embedding};
 use tracing::info;
 
-pub async fn load_chat_completion(chat: &ModelConfig) -> Arc<dyn ChatCompletionStream> {
-    match chat {
-        ModelConfig::Http(http) => http_api_bindings::create_chat(http).await,
-        ModelConfig::Local(llama) => llama_cpp_server::create_chat_completion(llama).await,
-    }
-}
-
 pub async fn load_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
     llama_cpp_server::create_embedding(config).await
 }
 
-pub async fn load_code_generation(model: &ModelConfig) -> (Arc<CodeGeneration>, PromptInfo) {
-    let (engine, prompt_info) = load_completion(model).await;
-    (Arc::new(CodeGeneration::new(engine)), prompt_info)
+pub async fn load_code_generation_and_chat(
+    completion_model: Option<ModelConfig>,
+    chat_model: Option<ModelConfig>,
+) -> (
+    Option<Arc<CodeGeneration>>,
+    Option<PromptInfo>,
+    Option<Arc<dyn ChatCompletionStream>>,
+) {
+    let (engine, prompt_info, chat) = load_completion_and_chat(completion_model, chat_model).await;
+    let code = engine.map(|engine| Arc::new(CodeGeneration::new(engine)));
+    (code, prompt_info, chat)
 }
 
-async fn load_completion(model: &ModelConfig) -> (Arc<dyn CompletionStream>, PromptInfo) {
-    match model {
-        ModelConfig::Http(http) => {
-            let engine = http_api_bindings::create(http).await;
-            let (prompt_template, chat_template) = http_api_bindings::build_completion_prompt(http);
-            (
-                engine,
-                PromptInfo {
-                    prompt_template,
-                    chat_template,
-                },
-            )
-        }
-        ModelConfig::Local(llama) => llama_cpp_server::create_completion(llama).await,
+async fn load_completion_and_chat(
+    completion_model: Option<ModelConfig>,
+    chat_model: Option<ModelConfig>,
+) -> (
+    Option<Arc<dyn CompletionStream>>,
+    Option<PromptInfo>,
+    Option<Arc<dyn ChatCompletionStream>>,
+) {
+    if let (Some(ModelConfig::Local(completion)), Some(ModelConfig::Local(chat))) =
+        (&completion_model, &chat_model)
+    {
+        let (completion, prompt, chat) =
+            llama_cpp_server::create_completion_and_chat(completion, chat).await;
+        return (Some(completion), Some(prompt), Some(chat));
     }
+
+    let (completion, prompt) = if let Some(completion_model) = completion_model {
+        match completion_model {
+            ModelConfig::Http(http) => {
+                let engine = http_api_bindings::create(&http).await;
+                let (prompt_template, chat_template) =
+                    http_api_bindings::build_completion_prompt(&http);
+                (
+                    Some(engine),
+                    Some(PromptInfo {
+                        prompt_template,
+                        chat_template,
+                    }),
+                )
+            }
+            ModelConfig::Local(llama) => {
+                let (stream, prompt) = llama_cpp_server::create_completion(&llama).await;
+                (Some(stream), Some(prompt))
+            }
+        }
+    } else {
+        (None, None)
+    };
+
+    let chat = if let Some(chat_model) = chat_model {
+        match chat_model {
+            ModelConfig::Http(http) => Some(http_api_bindings::create_chat(&http).await),
+            ModelConfig::Local(llama) => {
+                Some(llama_cpp_server::create_chat_completion(&llama).await)
+            }
+        }
+    } else {
+        None
+    };
+
+    (completion, prompt, chat)
 }
 
 pub async fn download_model_if_needed(model: &str) {
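
Usage sketch (supplementary, not part of the patch): the reuse path is taken when the
completion and chat entries resolve to equal LocalModelConfig values (PartialEq is derived
for this in config.rs above). Assuming the usual `tabby serve` flags referenced in this diff
(`--chat-model` via args.chat_model, `--device` via args.device, plus `--model`), pointing
completion and chat at the same local model, for example

    tabby serve --model Codestral-22B --chat-model Codestral-22B --device metal

should now start a single llama-server subprocess shared by both services. If the two local
configs differ in any field, the `else` branch in create_completion_and_chat still spawns a
dedicated completion server, so mixed setups behave as before.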