Mirror of https://github.com/TabbyML/tabby (synced 2024-11-22 00:08:06 +00:00)
feat(llama-cpp-server): reuse local server for completion and chat (#2812)
* ✨ support reusing the same server if the model id is equal
* 🎨 fix review
* 🔨 should compare model in llama cpp server
* match case for creating completion and chat
* if let case for creating completion and chat
* add changelog
* rebase && make fix

Signed-off-by: Wei Zhang <kweizh@gmail.com>
Co-authored-by: Meng Zhang <meng@tabbyml.com>
Parent: 946037ffde · Commit: 1632dd054e
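The core idea of the change: wrap the llama-server supervisor in an Arc, derive PartialEq on LocalModelConfig, and hand the same supervisor handle to both the chat and the completion backend when the two configs compare equal. A minimal, self-contained sketch of that sharing pattern follows; the types below are simplified stand-ins for illustration only, not the actual crate APIs.

```rust
use std::sync::Arc;

// Simplified stand-ins for the types touched by this commit; the real
// LlamaCppSupervisor spawns and supervises a llama-server subprocess.
#[derive(Clone, PartialEq)]
struct LocalModelConfig {
    model_id: String,
}

struct LlamaCppSupervisor {
    model_id: String,
}

impl LlamaCppSupervisor {
    fn spawn(config: &LocalModelConfig) -> Self {
        // The real implementation starts the subprocess here.
        Self { model_id: config.model_id.clone() }
    }
}

struct CompletionServer {
    server: Arc<LlamaCppSupervisor>,
}

struct ChatCompletionServer {
    server: Arc<LlamaCppSupervisor>,
}

fn create_completion_and_chat(
    completion: &LocalModelConfig,
    chat: &LocalModelConfig,
) -> (CompletionServer, ChatCompletionServer) {
    // One server is always started for chat.
    let server = Arc::new(LlamaCppSupervisor::spawn(chat));
    let chat_server = ChatCompletionServer { server: server.clone() };

    // Reuse that subprocess when both roles point at the same model,
    // otherwise spawn a second server for completion.
    let completion_server = if completion == chat {
        CompletionServer { server }
    } else {
        CompletionServer { server: Arc::new(LlamaCppSupervisor::spawn(completion)) }
    };

    (completion_server, chat_server)
}

fn main() {
    let model = LocalModelConfig { model_id: "Codestral-22B".into() };
    let (completion, chat) = create_completion_and_chat(&model, &model);
    // Both handles point at the same supervisor, so only one subprocess exists.
    assert!(Arc::ptr_eq(&completion.server, &chat.server));
    println!("reused llama-server for {}", chat.server.model_id);
}
```

When both roles name the same model (the Codestral-22B case from the changelog), the two handles share a single supervisor, so only one llama-server process is launched.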
@@ -0,0 +1,4 @@
+kind: Fixed and Improvements
+body: When reusing a model for Chat / Completion together (e.g. Codestral-22B) with
+  the local model backend, reuse the llama-server sub-process.
+time: 2024-08-10T22:10:45.454278-07:00
@@ -66,7 +66,7 @@ impl Embedding for EmbeddingServer {
 
 struct CompletionServer {
     #[allow(unused)]
-    server: LlamaCppSupervisor,
+    server: Arc<LlamaCppSupervisor>,
     completion: Arc<dyn CompletionStream>,
 }
 
@@ -89,6 +89,10 @@ impl CompletionServer {
             context_size,
         );
         server.start().await;
+        Self::new_with_supervisor(Arc::new(server)).await
+    }
+
+    async fn new_with_supervisor(server: Arc<LlamaCppSupervisor>) -> Self {
         let config = HttpModelConfigBuilder::default()
             .api_endpoint(Some(api_endpoint(server.port())))
             .kind("llama.cpp/completion".to_string())
@@ -108,7 +112,7 @@ impl CompletionStream for CompletionServer {
 
 struct ChatCompletionServer {
     #[allow(unused)]
-    server: LlamaCppSupervisor,
+    server: Arc<LlamaCppSupervisor>,
     chat_completion: Arc<dyn ChatCompletionStream>,
 }
 
@@ -132,6 +136,10 @@ impl ChatCompletionServer {
             context_size,
         );
         server.start().await;
+        Self::new_with_supervisor(Arc::new(server)).await
+    }
+
+    async fn new_with_supervisor(server: Arc<LlamaCppSupervisor>) -> Self {
         let config = HttpModelConfigBuilder::default()
             .api_endpoint(Some(api_endpoint(server.port())))
             .kind("openai/chat".to_string())
@@ -202,6 +210,53 @@ pub async fn create_completion(
     (stream, prompt_info)
 }
 
+pub async fn create_completion_and_chat(
+    completion_model: &LocalModelConfig,
+    chat_model: &LocalModelConfig,
+) -> (
+    Arc<dyn CompletionStream>,
+    PromptInfo,
+    Arc<dyn ChatCompletionStream>,
+) {
+    let chat_model_path = resolve_model_path(&chat_model.model_id).await;
+    let chat_template = resolve_prompt_info(&chat_model.model_id)
+        .await
+        .chat_template
+        .unwrap_or_else(|| panic!("Chat model requires specifying prompt template"));
+
+    let model_path = resolve_model_path(&completion_model.model_id).await;
+    let prompt_info = resolve_prompt_info(&completion_model.model_id).await;
+
+    let server = Arc::new(LlamaCppSupervisor::new(
+        "chat",
+        chat_model.num_gpu_layers,
+        false,
+        &chat_model_path,
+        chat_model.parallelism,
+        Some(chat_template),
+        chat_model.enable_fast_attention.unwrap_or_default(),
+        chat_model.context_size,
+    ));
+    server.start().await;
+
+    let chat = ChatCompletionServer::new_with_supervisor(server.clone()).await;
+
+    let completion = if completion_model == chat_model {
+        CompletionServer::new_with_supervisor(server).await
+    } else {
+        CompletionServer::new(
+            completion_model.num_gpu_layers,
+            &model_path,
+            completion_model.parallelism,
+            completion_model.enable_fast_attention.unwrap_or_default(),
+            completion_model.context_size,
+        )
+        .await
+    };
+
+    (Arc::new(completion), prompt_info, Arc::new(chat))
+}
+
 pub async fn create_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
     match config {
         ModelConfig::Http(http) => http_api_bindings::create_embedding(http).await,
@@ -250,7 +250,7 @@ pub struct HttpModelConfig {
     pub chat_template: Option<String>,
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
 pub struct LocalModelConfig {
     pub model_id: String,
 
@@ -24,11 +24,11 @@ use crate::{
     services::{
         self,
         code::create_code_search,
-        completion::{self, create_completion_service},
+        completion::{self, create_completion_service_and_chat, CompletionService},
         embedding,
         event::create_event_logger,
         health,
-        model::{self, download_model_if_needed},
+        model::download_model_if_needed,
         tantivy::IndexReaderProvider,
     },
     to_local_config, Device,
@@ -171,17 +171,22 @@ pub async fn main(config: &Config, args: &ServeArgs) {
         index_reader_provider.clone(),
     ));
 
-    let chat = if let Some(chat) = &config.model.chat {
-        Some(model::load_chat_completion(chat).await)
-    } else {
-        None
-    };
+    let model = &config.model;
+    let (completion, chat) = create_completion_service_and_chat(
+        &config.completion,
+        code.clone(),
+        logger.clone(),
+        model.completion.clone(),
+        model.chat.clone(),
+    )
+    .await;
 
     let mut api = api_router(
         args,
         &config,
         logger.clone(),
         code.clone(),
+        completion,
         chat.clone(),
         webserver,
     )
@@ -227,24 +232,15 @@ async fn api_router(
     args: &ServeArgs,
     config: &Config,
     logger: Arc<dyn EventLogger>,
-    code: Arc<dyn CodeSearch>,
+    _code: Arc<dyn CodeSearch>,
+    completion_state: Option<CompletionService>,
     chat_state: Option<Arc<dyn ChatCompletionStream>>,
     webserver: Option<bool>,
 ) -> Router {
-    let model = &config.model;
-    let completion_state = if let Some(completion) = &model.completion {
-        Some(Arc::new(
-            create_completion_service(&config.completion, code.clone(), logger.clone(), completion)
-                .await,
-        ))
-    } else {
-        None
-    };
-
     let mut routers = vec![];
 
     let health_state = Arc::new(health::HealthState::new(
-        model,
+        &config.model,
         &args.device,
         args.chat_model
             .as_deref()
@@ -273,7 +269,7 @@ async fn api_router(
             Router::new()
                 .route(
                     "/v1/completions",
-                    routing::post(routes::completions).with_state(completion_state),
+                    routing::post(routes::completions).with_state(Arc::new(completion_state)),
                )
                .layer(TimeoutLayer::new(Duration::from_secs(
                    config.server.completion_timeout,
@@ -12,7 +12,9 @@ use tabby_common::{
     config::{CompletionConfig, ModelConfig},
     languages::get_language,
 };
-use tabby_inference::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder};
+use tabby_inference::{
+    ChatCompletionStream, CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder,
+};
 use thiserror::Error;
 use utoipa::ToSchema;
 
@@ -352,26 +354,32 @@ impl CompletionService {
     }
 }
 
-pub async fn create_completion_service(
+pub async fn create_completion_service_and_chat(
     config: &CompletionConfig,
     code: Arc<dyn CodeSearch>,
     logger: Arc<dyn EventLogger>,
-    model: &ModelConfig,
-) -> CompletionService {
-    let (
-        engine,
-        model::PromptInfo {
-            prompt_template, ..
-        },
-    ) = model::load_code_generation(model).await;
+    completion: Option<ModelConfig>,
+    chat: Option<ModelConfig>,
+) -> (
+    Option<CompletionService>,
+    Option<Arc<dyn ChatCompletionStream>>,
+) {
+    let (code_generation, prompt, chat) =
+        model::load_code_generation_and_chat(completion, chat).await;
 
+    let completion = code_generation.map(|code_generation| {
     CompletionService::new(
         config.to_owned(),
-        engine.clone(),
+        code_generation.clone(),
         code,
         logger,
-        prompt_template,
+        prompt
+            .unwrap_or_else(|| panic!("Prompt template is required for code completion"))
+            .prompt_template,
     )
+    });
+
+    (completion, chat)
 }
 
 #[cfg(test)]
@@ -6,37 +6,74 @@ use tabby_download::download_model;
 use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream, Embedding};
 use tracing::info;
 
-pub async fn load_chat_completion(chat: &ModelConfig) -> Arc<dyn ChatCompletionStream> {
-    match chat {
-        ModelConfig::Http(http) => http_api_bindings::create_chat(http).await,
-        ModelConfig::Local(llama) => llama_cpp_server::create_chat_completion(llama).await,
-    }
-}
-
 pub async fn load_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
     llama_cpp_server::create_embedding(config).await
 }
 
-pub async fn load_code_generation(model: &ModelConfig) -> (Arc<CodeGeneration>, PromptInfo) {
-    let (engine, prompt_info) = load_completion(model).await;
-    (Arc::new(CodeGeneration::new(engine)), prompt_info)
+pub async fn load_code_generation_and_chat(
+    completion_model: Option<ModelConfig>,
+    chat_model: Option<ModelConfig>,
+) -> (
+    Option<Arc<CodeGeneration>>,
+    Option<PromptInfo>,
+    Option<Arc<dyn ChatCompletionStream>>,
+) {
+    let (engine, prompt_info, chat) = load_completion_and_chat(completion_model, chat_model).await;
+    let code = engine.map(|engine| Arc::new(CodeGeneration::new(engine)));
+    (code, prompt_info, chat)
 }
 
-async fn load_completion(model: &ModelConfig) -> (Arc<dyn CompletionStream>, PromptInfo) {
-    match model {
+async fn load_completion_and_chat(
+    completion_model: Option<ModelConfig>,
+    chat_model: Option<ModelConfig>,
+) -> (
+    Option<Arc<dyn CompletionStream>>,
+    Option<PromptInfo>,
+    Option<Arc<dyn ChatCompletionStream>>,
+) {
+    if let (Some(ModelConfig::Local(completion)), Some(ModelConfig::Local(chat))) =
+        (&completion_model, &chat_model)
+    {
+        let (completion, prompt, chat) =
+            llama_cpp_server::create_completion_and_chat(completion, chat).await;
+        return (Some(completion), Some(prompt), Some(chat));
+    }
+
+    let (completion, prompt) = if let Some(completion_model) = completion_model {
+        match completion_model {
         ModelConfig::Http(http) => {
-            let engine = http_api_bindings::create(http).await;
-            let (prompt_template, chat_template) = http_api_bindings::build_completion_prompt(http);
+            let engine = http_api_bindings::create(&http).await;
+            let (prompt_template, chat_template) =
+                http_api_bindings::build_completion_prompt(&http);
             (
-                engine,
-                PromptInfo {
+                Some(engine),
+                Some(PromptInfo {
                     prompt_template,
                     chat_template,
-                },
+                }),
             )
         }
-        ModelConfig::Local(llama) => llama_cpp_server::create_completion(llama).await,
+        ModelConfig::Local(llama) => {
+            let (stream, prompt) = llama_cpp_server::create_completion(&llama).await;
+            (Some(stream), Some(prompt))
+        }
     }
+    } else {
+        (None, None)
+    };
+
+    let chat = if let Some(chat_model) = chat_model {
+        match chat_model {
+            ModelConfig::Http(http) => Some(http_api_bindings::create_chat(&http).await),
+            ModelConfig::Local(llama) => {
+                Some(llama_cpp_server::create_chat_completion(&llama).await)
+            }
+        }
+    } else {
+        None
+    };
+
+    (completion, prompt, chat)
 }
 
 pub async fn download_model_if_needed(model: &str) {
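The interesting part of load_completion_and_chat above is the early return: only when both the completion model and the chat model are local configs does the shared-server path run; an HTTP backend or a missing model falls back to wiring each role up independently. A rough sketch of that dispatch is shown below; it uses simplified placeholder types (strings instead of the real engine handles), so it illustrates the control flow only, not the actual API.

```rust
// Hypothetical, trimmed-down dispatch mirroring load_completion_and_chat:
// the real code returns engine trait objects; strings stand in for them here.
#[derive(Clone)]
enum ModelConfig {
    Http(String),  // stand-in for the HTTP backend config
    Local(String), // stand-in for LocalModelConfig (model id only)
}

fn load_completion_and_chat(
    completion: Option<ModelConfig>,
    chat: Option<ModelConfig>,
) -> (Option<String>, Option<String>) {
    // Shared path: both roles are local models, so one llama-server can serve both.
    if let (Some(ModelConfig::Local(c)), Some(ModelConfig::Local(ch))) = (&completion, &chat) {
        let backend = if c == ch {
            format!("shared llama-server for {c}")
        } else {
            format!("separate llama-servers for {c} and {ch}")
        };
        return (Some(format!("completion: {backend}")), Some(format!("chat: {backend}")));
    }

    // Fallback: each role is wired up on its own (HTTP or local).
    let completion = completion.map(|m| match m {
        ModelConfig::Http(url) => format!("completion: http backend at {url}"),
        ModelConfig::Local(id) => format!("completion: local llama-server for {id}"),
    });
    let chat = chat.map(|m| match m {
        ModelConfig::Http(url) => format!("chat: http backend at {url}"),
        ModelConfig::Local(id) => format!("chat: local llama-server for {id}"),
    });
    (completion, chat)
}

fn main() {
    let model = ModelConfig::Local("Codestral-22B".into());
    let (completion, chat) = load_completion_and_chat(Some(model.clone()), Some(model));
    println!("{}", completion.unwrap());
    println!("{}", chat.unwrap());
}
```

Called with the same local model for both roles, the first branch reports a shared llama-server; with differing model ids or an HTTP endpoint it falls through to the independent setup.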