feat(llama-cpp-server): reuse local server for completion and chat (#2812)

* support reusing the same server if the model ids are equal

Signed-off-by: Wei Zhang <kweizh@gmail.com>

* 🎨 address review feedback

Signed-off-by: Wei Zhang <kweizh@gmail.com>

* 🔨 should compare model in llama cpp server

Signed-off-by: Wei Zhang <kweizh@gmail.com>

* use a match expression for creating completion and chat

Signed-off-by: Wei Zhang <kweizh@gmail.com>

* use an if let for creating completion and chat

Signed-off-by: Wei Zhang <kweizh@gmail.com>

* add changelog

* rebase && make fix

---------

Signed-off-by: Wei Zhang <kweizh@gmail.com>
Co-authored-by: Meng Zhang <meng@tabbyml.com>
Wei Zhang 2024-08-11 13:51:13 +08:00 committed by GitHub
parent 946037ffde
commit 1632dd054e
6 changed files with 164 additions and 64 deletions

@@ -0,0 +1,4 @@
kind: Fixed and Improvements
body: When the same model is used for both Chat and Completion (e.g. Codestral-22B) with the
local model backend, reuse the llama-server subprocess.
time: 2024-08-10T22:10:45.454278-07:00
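
A minimal sketch of the reuse rule this commit introduces, using hypothetical stand-in types (LlamaServer below is only a placeholder for the real LlamaCppSupervisor, and build_servers is not part of the crate): when the completion and chat local-model configs compare equal, both servers share one Arc'd supervisor, so a single llama-server subprocess backs both endpoints.

use std::sync::Arc;

// Illustration only: LlamaServer stands in for LlamaCppSupervisor, and this
// LocalModelConfig carries just the model id rather than the full settings.
#[derive(PartialEq)]
struct LocalModelConfig {
    model_id: String,
}

struct LlamaServer; // represents one spawned llama-server subprocess

fn build_servers(
    completion: &LocalModelConfig,
    chat: &LocalModelConfig,
) -> (Arc<LlamaServer>, Arc<LlamaServer>) {
    let chat_server = Arc::new(LlamaServer);
    let completion_server = if completion == chat {
        // Same local model config: reuse the chat server's subprocess.
        chat_server.clone()
    } else {
        // Different models: spawn a separate subprocess, as before this change.
        Arc::new(LlamaServer)
    };
    (completion_server, chat_server)
}

fn main() {
    let codestral = LocalModelConfig { model_id: "Codestral-22B".into() };
    let (completion, chat) = build_servers(&codestral, &codestral);
    assert!(Arc::ptr_eq(&completion, &chat)); // one subprocess serves both
}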

@@ -66,7 +66,7 @@ impl Embedding for EmbeddingServer {
struct CompletionServer {
#[allow(unused)]
server: LlamaCppSupervisor,
server: Arc<LlamaCppSupervisor>,
completion: Arc<dyn CompletionStream>,
}
@@ -89,6 +89,10 @@ impl CompletionServer {
context_size,
);
server.start().await;
Self::new_with_supervisor(Arc::new(server)).await
}
async fn new_with_supervisor(server: Arc<LlamaCppSupervisor>) -> Self {
let config = HttpModelConfigBuilder::default()
.api_endpoint(Some(api_endpoint(server.port())))
.kind("llama.cpp/completion".to_string())
@@ -108,7 +112,7 @@ impl CompletionStream for CompletionServer {
struct ChatCompletionServer {
#[allow(unused)]
server: LlamaCppSupervisor,
server: Arc<LlamaCppSupervisor>,
chat_completion: Arc<dyn ChatCompletionStream>,
}
@@ -132,6 +136,10 @@ impl ChatCompletionServer {
context_size,
);
server.start().await;
Self::new_with_supervisor(Arc::new(server)).await
}
async fn new_with_supervisor(server: Arc<LlamaCppSupervisor>) -> Self {
let config = HttpModelConfigBuilder::default()
.api_endpoint(Some(api_endpoint(server.port())))
.kind("openai/chat".to_string())
@@ -202,6 +210,53 @@ pub async fn create_completion(
(stream, prompt_info)
}
pub async fn create_completion_and_chat(
completion_model: &LocalModelConfig,
chat_model: &LocalModelConfig,
) -> (
Arc<dyn CompletionStream>,
PromptInfo,
Arc<dyn ChatCompletionStream>,
) {
let chat_model_path = resolve_model_path(&chat_model.model_id).await;
let chat_template = resolve_prompt_info(&chat_model.model_id)
.await
.chat_template
.unwrap_or_else(|| panic!("Chat model requires specifying prompt template"));
let model_path = resolve_model_path(&completion_model.model_id).await;
let prompt_info = resolve_prompt_info(&completion_model.model_id).await;
let server = Arc::new(LlamaCppSupervisor::new(
"chat",
chat_model.num_gpu_layers,
false,
&chat_model_path,
chat_model.parallelism,
Some(chat_template),
chat_model.enable_fast_attention.unwrap_or_default(),
chat_model.context_size,
));
server.start().await;
let chat = ChatCompletionServer::new_with_supervisor(server.clone()).await;
let completion = if completion_model == chat_model {
CompletionServer::new_with_supervisor(server).await
} else {
CompletionServer::new(
completion_model.num_gpu_layers,
&model_path,
completion_model.parallelism,
completion_model.enable_fast_attention.unwrap_or_default(),
completion_model.context_size,
)
.await
};
(Arc::new(completion), prompt_info, Arc::new(chat))
}
pub async fn create_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
match config {
ModelConfig::Http(http) => http_api_bindings::create_embedding(http).await,

@@ -250,7 +250,7 @@ pub struct HttpModelConfig {
pub chat_template: Option<String>,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LocalModelConfig {
pub model_id: String,

@@ -24,11 +24,11 @@ use crate::{
services::{
self,
code::create_code_search,
completion::{self, create_completion_service},
completion::{self, create_completion_service_and_chat, CompletionService},
embedding,
event::create_event_logger,
health,
model::{self, download_model_if_needed},
model::download_model_if_needed,
tantivy::IndexReaderProvider,
},
to_local_config, Device,
@@ -171,17 +171,22 @@ pub async fn main(config: &Config, args: &ServeArgs) {
index_reader_provider.clone(),
));
let chat = if let Some(chat) = &config.model.chat {
Some(model::load_chat_completion(chat).await)
} else {
None
};
let model = &config.model;
let (completion, chat) = create_completion_service_and_chat(
&config.completion,
code.clone(),
logger.clone(),
model.completion.clone(),
model.chat.clone(),
)
.await;
let mut api = api_router(
args,
&config,
logger.clone(),
code.clone(),
completion,
chat.clone(),
webserver,
)
@@ -227,24 +232,15 @@ async fn api_router(
args: &ServeArgs,
config: &Config,
logger: Arc<dyn EventLogger>,
code: Arc<dyn CodeSearch>,
_code: Arc<dyn CodeSearch>,
completion_state: Option<CompletionService>,
chat_state: Option<Arc<dyn ChatCompletionStream>>,
webserver: Option<bool>,
) -> Router {
let model = &config.model;
let completion_state = if let Some(completion) = &model.completion {
Some(Arc::new(
create_completion_service(&config.completion, code.clone(), logger.clone(), completion)
.await,
))
} else {
None
};
let mut routers = vec![];
let health_state = Arc::new(health::HealthState::new(
model,
&config.model,
&args.device,
args.chat_model
.as_deref()
@@ -273,7 +269,7 @@
Router::new()
.route(
"/v1/completions",
routing::post(routes::completions).with_state(completion_state),
routing::post(routes::completions).with_state(Arc::new(completion_state)),
)
.layer(TimeoutLayer::new(Duration::from_secs(
config.server.completion_timeout,

@@ -12,7 +12,9 @@ use tabby_common::{
config::{CompletionConfig, ModelConfig},
languages::get_language,
};
use tabby_inference::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder};
use tabby_inference::{
ChatCompletionStream, CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder,
};
use thiserror::Error;
use utoipa::ToSchema;
@@ -352,26 +354,32 @@ impl CompletionService {
}
}
pub async fn create_completion_service(
pub async fn create_completion_service_and_chat(
config: &CompletionConfig,
code: Arc<dyn CodeSearch>,
logger: Arc<dyn EventLogger>,
model: &ModelConfig,
) -> CompletionService {
let (
engine,
model::PromptInfo {
prompt_template, ..
},
) = model::load_code_generation(model).await;
completion: Option<ModelConfig>,
chat: Option<ModelConfig>,
) -> (
Option<CompletionService>,
Option<Arc<dyn ChatCompletionStream>>,
) {
let (code_generation, prompt, chat) =
model::load_code_generation_and_chat(completion, chat).await;
let completion = code_generation.map(|code_generation| {
CompletionService::new(
config.to_owned(),
engine.clone(),
code_generation.clone(),
code,
logger,
prompt_template,
prompt
.unwrap_or_else(|| panic!("Prompt template is required for code completion"))
.prompt_template,
)
});
(completion, chat)
}
#[cfg(test)]

@@ -6,37 +6,74 @@ use tabby_download::download_model;
use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream, Embedding};
use tracing::info;
pub async fn load_chat_completion(chat: &ModelConfig) -> Arc<dyn ChatCompletionStream> {
match chat {
ModelConfig::Http(http) => http_api_bindings::create_chat(http).await,
ModelConfig::Local(llama) => llama_cpp_server::create_chat_completion(llama).await,
}
}
pub async fn load_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
llama_cpp_server::create_embedding(config).await
}
pub async fn load_code_generation(model: &ModelConfig) -> (Arc<CodeGeneration>, PromptInfo) {
let (engine, prompt_info) = load_completion(model).await;
(Arc::new(CodeGeneration::new(engine)), prompt_info)
pub async fn load_code_generation_and_chat(
completion_model: Option<ModelConfig>,
chat_model: Option<ModelConfig>,
) -> (
Option<Arc<CodeGeneration>>,
Option<PromptInfo>,
Option<Arc<dyn ChatCompletionStream>>,
) {
let (engine, prompt_info, chat) = load_completion_and_chat(completion_model, chat_model).await;
let code = engine.map(|engine| Arc::new(CodeGeneration::new(engine)));
(code, prompt_info, chat)
}
async fn load_completion(model: &ModelConfig) -> (Arc<dyn CompletionStream>, PromptInfo) {
match model {
async fn load_completion_and_chat(
completion_model: Option<ModelConfig>,
chat_model: Option<ModelConfig>,
) -> (
Option<Arc<dyn CompletionStream>>,
Option<PromptInfo>,
Option<Arc<dyn ChatCompletionStream>>,
) {
if let (Some(ModelConfig::Local(completion)), Some(ModelConfig::Local(chat))) =
(&completion_model, &chat_model)
{
let (completion, prompt, chat) =
llama_cpp_server::create_completion_and_chat(completion, chat).await;
return (Some(completion), Some(prompt), Some(chat));
}
let (completion, prompt) = if let Some(completion_model) = completion_model {
match completion_model {
ModelConfig::Http(http) => {
let engine = http_api_bindings::create(http).await;
let (prompt_template, chat_template) = http_api_bindings::build_completion_prompt(http);
let engine = http_api_bindings::create(&http).await;
let (prompt_template, chat_template) =
http_api_bindings::build_completion_prompt(&http);
(
engine,
PromptInfo {
Some(engine),
Some(PromptInfo {
prompt_template,
chat_template,
},
}),
)
}
ModelConfig::Local(llama) => llama_cpp_server::create_completion(llama).await,
ModelConfig::Local(llama) => {
let (stream, prompt) = llama_cpp_server::create_completion(&llama).await;
(Some(stream), Some(prompt))
}
}
} else {
(None, None)
};
let chat = if let Some(chat_model) = chat_model {
match chat_model {
ModelConfig::Http(http) => Some(http_api_bindings::create_chat(&http).await),
ModelConfig::Local(llama) => {
Some(llama_cpp_server::create_chat_completion(&llama).await)
}
}
} else {
None
};
(completion, prompt, chat)
}
pub async fn download_model_if_needed(model: &str) {