feat(tabby-inference, http-api-bindings): support llama.cpp server embedding interface. (#2094)

* refactor: move chat / completion into corresponding dir

* add embedding

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Author: Meng Zhang, 2024-05-10 22:22:17 -07:00 (committed by GitHub)
Commit: 174bbae43e (parent: bef5629a38)
12 changed files with 161 additions and 48 deletions

Cargo.lock (generated)

@@ -2134,6 +2134,7 @@ dependencies = [
  "serde_json",
  "tabby-common",
  "tabby-inference",
+ "tokio",
  "tracing",
 ]


@@ -18,3 +18,6 @@ serde_json = { workspace = true }
 tabby-common = { path = "../tabby-common" }
 tabby-inference = { path = "../tabby-inference" }
 tracing.workspace = true
+
+[dev-dependencies]
+tokio = { workspace = true, features = ["rt", "macros"] }


@@ -0,0 +1,23 @@
mod openai_chat;

use std::sync::Arc;

use openai_chat::OpenAIChatEngine;
use tabby_inference::ChatCompletionStream;

use crate::{get_optional_param, get_param};

pub fn create(model: &str) -> Arc<dyn ChatCompletionStream> {
    let params = serde_json::from_str(model).expect("Failed to parse model string");
    let kind = get_param(&params, "kind");
    if kind == "openai-chat" {
        let model_name = get_optional_param(&params, "model_name").unwrap_or_default();
        let api_endpoint = get_param(&params, "api_endpoint");
        let api_key = get_optional_param(&params, "api_key");

        let engine = OpenAIChatEngine::create(&api_endpoint, &model_name, api_key);
        Arc::new(engine)
    } else {
        panic!("Only openai-chat are supported for http chat");
    }
}
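For illustration only, a minimal sketch of how this factory might be invoked. The JSON model string and its values are hypothetical; the field names mirror the get_param / get_optional_param lookups above:

    // Hypothetical model string; only the "openai-chat" kind is accepted here.
    let model = r#"{"kind": "openai-chat", "api_endpoint": "https://api.openai.com/v1", "model_name": "gpt-3.5-turbo"}"#;
    let chat: Arc<dyn ChatCompletionStream> = create(model);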


@@ -0,0 +1,33 @@
mod llama;
mod openai;

use std::sync::Arc;

use llama::LlamaCppEngine;
use openai::OpenAIEngine;
use tabby_inference::CompletionStream;

use crate::{get_optional_param, get_param};

pub fn create(model: &str) -> (Arc<dyn CompletionStream>, Option<String>, Option<String>) {
    let params = serde_json::from_str(model).expect("Failed to parse model string");
    let kind = get_param(&params, "kind");
    if kind == "openai" {
        let model_name = get_optional_param(&params, "model_name").unwrap_or_default();
        let api_endpoint = get_param(&params, "api_endpoint");
        let api_key = get_optional_param(&params, "api_key");
        let prompt_template = get_optional_param(&params, "prompt_template");
        let chat_template = get_optional_param(&params, "chat_template");
        let engine = OpenAIEngine::create(&api_endpoint, &model_name, api_key);
        (Arc::new(engine), prompt_template, chat_template)
    } else if kind == "llama" {
        let api_endpoint = get_param(&params, "api_endpoint");
        let api_key = get_optional_param(&params, "api_key");
        let prompt_template = get_optional_param(&params, "prompt_template");
        let chat_template = get_optional_param(&params, "chat_template");
        let engine = LlamaCppEngine::create(&api_endpoint, api_key);
        (Arc::new(engine), prompt_template, chat_template)
    } else {
        panic!("Only openai are supported for http completion");
    }
}
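As a sketch (the endpoint and template values are hypothetical), the "llama" branch above would be selected by a model string like the following; the optional templates are simply parsed out and returned for the caller to apply:

    // Hypothetical model string selecting the llama.cpp completion backend.
    let model = r#"{"kind": "llama", "api_endpoint": "http://localhost:8080", "prompt_template": "<PRE> {prefix} <SUF>{suffix} <MID>"}"#;
    let (engine, prompt_template, chat_template) = create(model);
    assert!(chat_template.is_none()); // no chat_template was provided in the model string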


@@ -0,0 +1,64 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tabby_inference::Embedding;

pub struct LlamaCppEngine {
    client: reqwest::Client,
    api_endpoint: String,
    api_key: Option<String>,
}

impl LlamaCppEngine {
    pub fn create(api_endpoint: &str, api_key: Option<String>) -> Self {
        let client = reqwest::Client::new();

        Self {
            client,
            api_endpoint: format!("{}/embeddings", api_endpoint),
            api_key,
        }
    }
}

#[derive(Serialize)]
struct EmbeddingRequest {
    content: String,
}

#[derive(Deserialize)]
struct EmbeddingResponse {
    embedding: Vec<f32>,
}

#[async_trait]
impl Embedding for LlamaCppEngine {
    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
        let request = EmbeddingRequest {
            content: prompt.to_owned(),
        };

        let mut request = self.client.post(&self.api_endpoint).json(&request);
        if let Some(api_key) = &self.api_key {
            request = request.bearer_auth(api_key);
        }

        let response = request.send().await?.json::<EmbeddingResponse>().await?;
        Ok(response.embedding)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// This unit test should only run manually when the server is running
    /// curl -L https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF/resolve/main/nomic-embed-text-v1.5.Q8_0.gguf -o ./models/nomic.gguf
    /// ./server -m ./models/nomic.gguf --port 8000 --embedding
    #[tokio::test]
    #[ignore]
    async fn test_embedding() {
        let engine = LlamaCppEngine::create("http://localhost:8000", None);
        let embedding = engine.embed("hello").await.unwrap();
        assert_eq!(embedding.len(), 768);
    }
}
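For reference, a sketch of the wire format implied by the request/response structs above; the response shape is what this code assumes llama.cpp returns, not something verified in this diff:

    // The engine POSTs to "{api_endpoint}/embeddings", optionally adding an
    // "Authorization: Bearer <api_key>" header, and reads an "embedding" array back.
    let body = serde_json::to_string(&EmbeddingRequest { content: "hello".into() }).unwrap();
    assert_eq!(body, r#"{"content":"hello"}"#);
    // Assumed response shape: {"embedding": [0.1, -0.2, ...]}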


@@ -0,0 +1,21 @@
mod llama;

use std::sync::Arc;

use llama::LlamaCppEngine;
use tabby_inference::Embedding;

use crate::{get_optional_param, get_param};

pub fn create(model: &str) -> Arc<dyn Embedding> {
    let params = serde_json::from_str(model).expect("Failed to parse model string");
    let kind = get_param(&params, "kind");
    if kind == "llama" {
        let api_endpoint = get_param(&params, "api_endpoint");
        let api_key = get_optional_param(&params, "api_key");
        let engine = LlamaCppEngine::create(&api_endpoint, api_key);
        Arc::new(engine)
    } else {
        panic!("Only llama are supported for http embedding");
    }
}
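Putting it together, a hedged sketch of end-to-end use; the endpoint value is hypothetical and the snippet must run inside an async function returning anyhow::Result (e.g. under a tokio runtime):

    // Build the embedding engine from a JSON model string, then embed a snippet.
    let embedding = create(r#"{"kind": "llama", "api_endpoint": "http://localhost:8000"}"#);
    let vector: Vec<f32> = embedding.embed("fn main() {}").await?;
    println!("embedding dimension: {}", vector.len());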


@@ -1,53 +1,13 @@
-mod llama;
-mod openai;
-mod openai_chat;
+mod chat;
+mod completion;
+mod embedding;
 
-use std::sync::Arc;
-
-use openai::OpenAIEngine;
-use openai_chat::OpenAIChatEngine;
+pub use chat::create as create_chat;
+pub use completion::create;
+pub use embedding::create as create_embedding;
 use serde_json::Value;
-use tabby_inference::{ChatCompletionStream, CompletionStream};
 
-pub fn create(model: &str) -> (Arc<dyn CompletionStream>, Option<String>, Option<String>) {
-    let params = serde_json::from_str(model).expect("Failed to parse model string");
-    let kind = get_param(&params, "kind");
-    if kind == "openai" {
-        let model_name = get_optional_param(&params, "model_name").unwrap_or_default();
-        let api_endpoint = get_param(&params, "api_endpoint");
-        let api_key = get_optional_param(&params, "api_key");
-        let prompt_template = get_optional_param(&params, "prompt_template");
-        let chat_template = get_optional_param(&params, "chat_template");
-        let engine = OpenAIEngine::create(&api_endpoint, &model_name, api_key);
-        (Arc::new(engine), prompt_template, chat_template)
-    } else if kind == "llama" {
-        let api_endpoint = get_param(&params, "api_endpoint");
-        let api_key = get_optional_param(&params, "api_key");
-        let prompt_template = get_optional_param(&params, "prompt_template");
-        let chat_template = get_optional_param(&params, "chat_template");
-        let engine = llama::LlamaCppEngine::create(&api_endpoint, api_key);
-        (Arc::new(engine), prompt_template, chat_template)
-    } else {
-        panic!("Only openai are supported for http completion");
-    }
-}
-
-pub fn create_chat(model: &str) -> Arc<dyn ChatCompletionStream> {
-    let params = serde_json::from_str(model).expect("Failed to parse model string");
-    let kind = get_param(&params, "kind");
-    if kind == "openai-chat" {
-        let model_name = get_optional_param(&params, "model_name").unwrap_or_default();
-        let api_endpoint = get_param(&params, "api_endpoint");
-        let api_key = get_optional_param(&params, "api_key");
-
-        let engine = OpenAIChatEngine::create(&api_endpoint, &model_name, api_key);
-        Arc::new(engine)
-    } else {
-        panic!("Only openai-chat are supported for http chat");
-    }
-}
-
-fn get_param(params: &Value, key: &str) -> String {
+pub(crate) fn get_param(params: &Value, key: &str) -> String {
     params
         .get(key)
         .unwrap_or_else(|| panic!("Missing {} field", key))
@@ -56,7 +16,7 @@ fn get_param(params: &Value, key: &str) -> String {
         .to_owned()
 }
 
-fn get_optional_param(params: &Value, key: &str) -> Option<String> {
+pub(crate) fn get_optional_param(params: &Value, key: &str) -> Option<String> {
     params
         .get(key)
         .map(|x| x.as_str().expect("Type unmatched").to_owned())


@@ -0,0 +1,6 @@
use async_trait::async_trait;

#[async_trait]
pub trait Embedding: Sync + Send {
    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>>;
}
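A minimal sketch of a downstream consumer, assuming only what the trait declares; the helper name is illustrative, not part of this commit:

    use std::sync::Arc;

    // Any backend (e.g. the llama.cpp HTTP binding) can be used through the trait object.
    async fn embed_snippet(engine: Arc<dyn Embedding>, text: &str) -> anyhow::Result<Vec<f32>> {
        engine.embed(text).await
    }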


@@ -3,10 +3,12 @@ mod chat;
 mod code;
 mod completion;
 mod decoding;
+mod embedding;
 
 pub use chat::{ChatCompletionOptions, ChatCompletionOptionsBuilder, ChatCompletionStream};
 pub use code::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder};
 pub use completion::{CompletionOptions, CompletionOptionsBuilder, CompletionStream};
+pub use embedding::Embedding;
 
 fn default_seed() -> u64 {
     std::time::SystemTime::now()