Mirror of https://github.com/TabbyML/tabby (synced 2024-11-22 00:08:06 +00:00)
feat(llama-cpp-server): reuse local server for completion and chat (#2812)
* ✨ support reusing the same server if the model id is equal
* 🎨 fix review comments
* 🔨 compare models in the llama-cpp server
* use a match expression for creating completion and chat
* use an if-let for creating completion and chat
* add changelog
* rebase && make fix

---------

Signed-off-by: Wei Zhang <kweizh@gmail.com>
Co-authored-by: Meng Zhang <meng@tabbyml.com>
parent 946037ffde
commit 1632dd054e
@@ -0,0 +1,4 @@
+kind: Fixed and Improvements
+body: When reusing a model for Chat / Completion together (e.g. Codestral-22B) with
+  local model backend, reuse the llama-server subprocess.
+time: 2024-08-10T22:10:45.454278-07:00
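In short, the change adds a reuse gate: when the chat and completion entries resolve to the same local model, the completion wrapper is built on top of the already-running chat supervisor instead of spawning a second llama-server process. A minimal sketch of that decision, using names introduced in the diff below (surrounding setup elided):

// Sketch only: both wrappers hold an Arc to one supervisor, so the llama-server
// child process stays alive as long as either the chat or the completion side needs it.
let server = Arc::new(LlamaCppSupervisor::new(/* chat model settings */));
server.start().await;

let chat = ChatCompletionServer::new_with_supervisor(server.clone()).await;
let completion = if completion_model == chat_model {
    // Same LocalModelConfig for both roles: reuse the running subprocess.
    CompletionServer::new_with_supervisor(server).await
} else {
    // Different models: spawn a dedicated completion server, as before.
    CompletionServer::new(/* completion model settings */).await
};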
@@ -66,7 +66,7 @@ impl Embedding for EmbeddingServer {
 
 struct CompletionServer {
     #[allow(unused)]
-    server: LlamaCppSupervisor,
+    server: Arc<LlamaCppSupervisor>,
     completion: Arc<dyn CompletionStream>,
 }
 
@@ -89,6 +89,10 @@ impl CompletionServer {
             context_size,
         );
         server.start().await;
+        Self::new_with_supervisor(Arc::new(server)).await
+    }
+
+    async fn new_with_supervisor(server: Arc<LlamaCppSupervisor>) -> Self {
         let config = HttpModelConfigBuilder::default()
             .api_endpoint(Some(api_endpoint(server.port())))
             .kind("llama.cpp/completion".to_string())
@@ -108,7 +112,7 @@ impl CompletionStream for CompletionServer {
 
 struct ChatCompletionServer {
     #[allow(unused)]
-    server: LlamaCppSupervisor,
+    server: Arc<LlamaCppSupervisor>,
     chat_completion: Arc<dyn ChatCompletionStream>,
 }
 
@@ -132,6 +136,10 @@ impl ChatCompletionServer {
             context_size,
         );
         server.start().await;
+        Self::new_with_supervisor(Arc::new(server)).await
+    }
+
+    async fn new_with_supervisor(server: Arc<LlamaCppSupervisor>) -> Self {
         let config = HttpModelConfigBuilder::default()
             .api_endpoint(Some(api_endpoint(server.port())))
             .kind("openai/chat".to_string())
@@ -202,6 +210,53 @@ pub async fn create_completion(
     (stream, prompt_info)
 }
 
+pub async fn create_completion_and_chat(
+    completion_model: &LocalModelConfig,
+    chat_model: &LocalModelConfig,
+) -> (
+    Arc<dyn CompletionStream>,
+    PromptInfo,
+    Arc<dyn ChatCompletionStream>,
+) {
+    let chat_model_path = resolve_model_path(&chat_model.model_id).await;
+    let chat_template = resolve_prompt_info(&chat_model.model_id)
+        .await
+        .chat_template
+        .unwrap_or_else(|| panic!("Chat model requires specifying prompt template"));
+
+    let model_path = resolve_model_path(&completion_model.model_id).await;
+    let prompt_info = resolve_prompt_info(&completion_model.model_id).await;
+
+    let server = Arc::new(LlamaCppSupervisor::new(
+        "chat",
+        chat_model.num_gpu_layers,
+        false,
+        &chat_model_path,
+        chat_model.parallelism,
+        Some(chat_template),
+        chat_model.enable_fast_attention.unwrap_or_default(),
+        chat_model.context_size,
+    ));
+    server.start().await;
+
+    let chat = ChatCompletionServer::new_with_supervisor(server.clone()).await;
+
+    let completion = if completion_model == chat_model {
+        CompletionServer::new_with_supervisor(server).await
+    } else {
+        CompletionServer::new(
+            completion_model.num_gpu_layers,
+            &model_path,
+            completion_model.parallelism,
+            completion_model.enable_fast_attention.unwrap_or_default(),
+            completion_model.context_size,
+        )
+        .await
+    };
+
+    (Arc::new(completion), prompt_info, Arc::new(chat))
+}
+
 pub async fn create_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
     match config {
         ModelConfig::Http(http) => http_api_bindings::create_embedding(http).await,
@@ -250,7 +250,7 @@ pub struct HttpModelConfig {
     pub chat_template: Option<String>,
 }
 
-#[derive(Serialize, Deserialize, Debug, Clone)]
+#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
 pub struct LocalModelConfig {
     pub model_id: String,
 
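The PartialEq derive above is what makes the reuse gate possible: two LocalModelConfig values can now be compared structurally, so the local backend can tell whether chat and completion point at the same model. A hedged illustration:

// With PartialEq derived, the reuse check is a plain structural comparison.
if completion_model == chat_model {
    // Same model_id, gpu layers, parallelism, etc.: share one llama-server.
}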
@@ -24,11 +24,11 @@ use crate::{
     services::{
         self,
         code::create_code_search,
-        completion::{self, create_completion_service},
+        completion::{self, create_completion_service_and_chat, CompletionService},
         embedding,
         event::create_event_logger,
         health,
-        model::{self, download_model_if_needed},
+        model::download_model_if_needed,
         tantivy::IndexReaderProvider,
     },
     to_local_config, Device,
@@ -171,17 +171,22 @@ pub async fn main(config: &Config, args: &ServeArgs) {
         index_reader_provider.clone(),
     ));
 
-    let chat = if let Some(chat) = &config.model.chat {
-        Some(model::load_chat_completion(chat).await)
-    } else {
-        None
-    };
+    let model = &config.model;
+    let (completion, chat) = create_completion_service_and_chat(
+        &config.completion,
+        code.clone(),
+        logger.clone(),
+        model.completion.clone(),
+        model.chat.clone(),
+    )
+    .await;
 
     let mut api = api_router(
         args,
         &config,
         logger.clone(),
         code.clone(),
+        completion,
         chat.clone(),
         webserver,
     )
@@ -227,24 +232,15 @@ async fn api_router(
     args: &ServeArgs,
     config: &Config,
     logger: Arc<dyn EventLogger>,
-    code: Arc<dyn CodeSearch>,
+    _code: Arc<dyn CodeSearch>,
+    completion_state: Option<CompletionService>,
     chat_state: Option<Arc<dyn ChatCompletionStream>>,
     webserver: Option<bool>,
 ) -> Router {
-    let model = &config.model;
-    let completion_state = if let Some(completion) = &model.completion {
-        Some(Arc::new(
-            create_completion_service(&config.completion, code.clone(), logger.clone(), completion)
-                .await,
-        ))
-    } else {
-        None
-    };
-
     let mut routers = vec![];
 
     let health_state = Arc::new(health::HealthState::new(
-        model,
+        &config.model,
         &args.device,
         args.chat_model
             .as_deref()
@@ -273,7 +269,7 @@ async fn api_router(
             Router::new()
                 .route(
                     "/v1/completions",
-                    routing::post(routes::completions).with_state(completion_state),
+                    routing::post(routes::completions).with_state(Arc::new(completion_state)),
                 )
                 .layer(TimeoutLayer::new(Duration::from_secs(
                     config.server.completion_timeout,
@@ -12,7 +12,9 @@ use tabby_common::{
     config::{CompletionConfig, ModelConfig},
     languages::get_language,
 };
-use tabby_inference::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder};
+use tabby_inference::{
+    ChatCompletionStream, CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder,
+};
 use thiserror::Error;
 use utoipa::ToSchema;
 
@@ -352,26 +354,32 @@ impl CompletionService {
     }
 }
 
-pub async fn create_completion_service(
+pub async fn create_completion_service_and_chat(
     config: &CompletionConfig,
     code: Arc<dyn CodeSearch>,
     logger: Arc<dyn EventLogger>,
-    model: &ModelConfig,
-) -> CompletionService {
-    let (
-        engine,
-        model::PromptInfo {
-            prompt_template, ..
-        },
-    ) = model::load_code_generation(model).await;
-
-    CompletionService::new(
-        config.to_owned(),
-        engine.clone(),
-        code,
-        logger,
-        prompt_template,
-    )
+    completion: Option<ModelConfig>,
+    chat: Option<ModelConfig>,
+) -> (
+    Option<CompletionService>,
+    Option<Arc<dyn ChatCompletionStream>>,
+) {
+    let (code_generation, prompt, chat) =
+        model::load_code_generation_and_chat(completion, chat).await;
+
+    let completion = code_generation.map(|code_generation| {
+        CompletionService::new(
+            config.to_owned(),
+            code_generation.clone(),
+            code,
+            logger,
+            prompt
+                .unwrap_or_else(|| panic!("Prompt template is required for code completion"))
+                .prompt_template,
+        )
+    });
+
+    (completion, chat)
 }
 
 #[cfg(test)]
@@ -6,37 +6,74 @@ use tabby_download::download_model;
 use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream, Embedding};
 use tracing::info;
 
-pub async fn load_chat_completion(chat: &ModelConfig) -> Arc<dyn ChatCompletionStream> {
-    match chat {
-        ModelConfig::Http(http) => http_api_bindings::create_chat(http).await,
-        ModelConfig::Local(llama) => llama_cpp_server::create_chat_completion(llama).await,
-    }
-}
-
 pub async fn load_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
     llama_cpp_server::create_embedding(config).await
 }
 
-pub async fn load_code_generation(model: &ModelConfig) -> (Arc<CodeGeneration>, PromptInfo) {
-    let (engine, prompt_info) = load_completion(model).await;
-    (Arc::new(CodeGeneration::new(engine)), prompt_info)
+pub async fn load_code_generation_and_chat(
+    completion_model: Option<ModelConfig>,
+    chat_model: Option<ModelConfig>,
+) -> (
+    Option<Arc<CodeGeneration>>,
+    Option<PromptInfo>,
+    Option<Arc<dyn ChatCompletionStream>>,
+) {
+    let (engine, prompt_info, chat) = load_completion_and_chat(completion_model, chat_model).await;
+    let code = engine.map(|engine| Arc::new(CodeGeneration::new(engine)));
+    (code, prompt_info, chat)
 }
 
-async fn load_completion(model: &ModelConfig) -> (Arc<dyn CompletionStream>, PromptInfo) {
-    match model {
-        ModelConfig::Http(http) => {
-            let engine = http_api_bindings::create(http).await;
-            let (prompt_template, chat_template) = http_api_bindings::build_completion_prompt(http);
-            (
-                engine,
-                PromptInfo {
-                    prompt_template,
-                    chat_template,
-                },
-            )
-        }
-        ModelConfig::Local(llama) => llama_cpp_server::create_completion(llama).await,
-    }
+async fn load_completion_and_chat(
+    completion_model: Option<ModelConfig>,
+    chat_model: Option<ModelConfig>,
+) -> (
+    Option<Arc<dyn CompletionStream>>,
+    Option<PromptInfo>,
+    Option<Arc<dyn ChatCompletionStream>>,
+) {
+    if let (Some(ModelConfig::Local(completion)), Some(ModelConfig::Local(chat))) =
+        (&completion_model, &chat_model)
+    {
+        let (completion, prompt, chat) =
+            llama_cpp_server::create_completion_and_chat(completion, chat).await;
+        return (Some(completion), Some(prompt), Some(chat));
+    }
+
+    let (completion, prompt) = if let Some(completion_model) = completion_model {
+        match completion_model {
+            ModelConfig::Http(http) => {
+                let engine = http_api_bindings::create(&http).await;
+                let (prompt_template, chat_template) =
+                    http_api_bindings::build_completion_prompt(&http);
+                (
+                    Some(engine),
+                    Some(PromptInfo {
+                        prompt_template,
+                        chat_template,
+                    }),
+                )
+            }
+            ModelConfig::Local(llama) => {
+                let (stream, prompt) = llama_cpp_server::create_completion(&llama).await;
+                (Some(stream), Some(prompt))
+            }
+        }
+    } else {
+        (None, None)
+    };
+
+    let chat = if let Some(chat_model) = chat_model {
+        match chat_model {
+            ModelConfig::Http(http) => Some(http_api_bindings::create_chat(&http).await),
+            ModelConfig::Local(llama) => {
+                Some(llama_cpp_server::create_chat_completion(&llama).await)
+            }
+        }
+    } else {
+        None
+    };
+
+    (completion, prompt, chat)
 }
 
 pub async fn download_model_if_needed(model: &str) {