mirror of
https://github.com/TabbyML/tabby
synced 2024-11-23 10:05:08 +00:00
feat(tabby-inference, http-api-bindings): support llama.cpp server embedding interface. (#2094)
* refactor: move chat / completion into corresponding dir * add embedding * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
bef5629a38
commit
174bbae43e
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -2134,6 +2134,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"tabby-common",
|
||||
"tabby-inference",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
|
@ -18,3 +18,6 @@ serde_json = { workspace = true }
|
||||
tabby-common = { path = "../tabby-common" }
|
||||
tabby-inference = { path = "../tabby-inference" }
|
||||
tracing.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
tokio ={ workspace = true, features = ["rt", "macros"]}
|
23
crates/http-api-bindings/src/chat/mod.rs
Normal file
23
crates/http-api-bindings/src/chat/mod.rs
Normal file
@ -0,0 +1,23 @@
|
||||
mod openai_chat;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use openai_chat::OpenAIChatEngine;
|
||||
use tabby_inference::ChatCompletionStream;
|
||||
|
||||
use crate::{get_optional_param, get_param};
|
||||
|
||||
pub fn create(model: &str) -> Arc<dyn ChatCompletionStream> {
|
||||
let params = serde_json::from_str(model).expect("Failed to parse model string");
|
||||
let kind = get_param(¶ms, "kind");
|
||||
if kind == "openai-chat" {
|
||||
let model_name = get_optional_param(¶ms, "model_name").unwrap_or_default();
|
||||
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||
let api_key = get_optional_param(¶ms, "api_key");
|
||||
|
||||
let engine = OpenAIChatEngine::create(&api_endpoint, &model_name, api_key);
|
||||
Arc::new(engine)
|
||||
} else {
|
||||
panic!("Only openai-chat are supported for http chat");
|
||||
}
|
||||
}
|
33
crates/http-api-bindings/src/completion/mod.rs
Normal file
33
crates/http-api-bindings/src/completion/mod.rs
Normal file
@ -0,0 +1,33 @@
|
||||
mod llama;
|
||||
mod openai;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use llama::LlamaCppEngine;
|
||||
use openai::OpenAIEngine;
|
||||
use tabby_inference::CompletionStream;
|
||||
|
||||
use crate::{get_optional_param, get_param};
|
||||
|
||||
pub fn create(model: &str) -> (Arc<dyn CompletionStream>, Option<String>, Option<String>) {
|
||||
let params = serde_json::from_str(model).expect("Failed to parse model string");
|
||||
let kind = get_param(¶ms, "kind");
|
||||
if kind == "openai" {
|
||||
let model_name = get_optional_param(¶ms, "model_name").unwrap_or_default();
|
||||
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||
let api_key = get_optional_param(¶ms, "api_key");
|
||||
let prompt_template = get_optional_param(¶ms, "prompt_template");
|
||||
let chat_template = get_optional_param(¶ms, "chat_template");
|
||||
let engine = OpenAIEngine::create(&api_endpoint, &model_name, api_key);
|
||||
(Arc::new(engine), prompt_template, chat_template)
|
||||
} else if kind == "llama" {
|
||||
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||
let api_key = get_optional_param(¶ms, "api_key");
|
||||
let prompt_template = get_optional_param(¶ms, "prompt_template");
|
||||
let chat_template = get_optional_param(¶ms, "chat_template");
|
||||
let engine = LlamaCppEngine::create(&api_endpoint, api_key);
|
||||
(Arc::new(engine), prompt_template, chat_template)
|
||||
} else {
|
||||
panic!("Only openai are supported for http completion");
|
||||
}
|
||||
}
|
64
crates/http-api-bindings/src/embedding/llama.rs
Normal file
64
crates/http-api-bindings/src/embedding/llama.rs
Normal file
@ -0,0 +1,64 @@
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_inference::Embedding;
|
||||
|
||||
pub struct LlamaCppEngine {
|
||||
client: reqwest::Client,
|
||||
api_endpoint: String,
|
||||
api_key: Option<String>,
|
||||
}
|
||||
|
||||
impl LlamaCppEngine {
|
||||
pub fn create(api_endpoint: &str, api_key: Option<String>) -> Self {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
Self {
|
||||
client,
|
||||
api_endpoint: format!("{}/embeddings", api_endpoint),
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct EmbeddingRequest {
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct EmbeddingResponse {
|
||||
embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Embedding for LlamaCppEngine {
|
||||
async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
|
||||
let request = EmbeddingRequest {
|
||||
content: prompt.to_owned(),
|
||||
};
|
||||
|
||||
let mut request = self.client.post(&self.api_endpoint).json(&request);
|
||||
if let Some(api_key) = &self.api_key {
|
||||
request = request.bearer_auth(api_key);
|
||||
}
|
||||
|
||||
let response = request.send().await?.json::<EmbeddingResponse>().await?;
|
||||
Ok(response.embedding)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// This unit test should only run manually when the server is running
|
||||
/// curl -L https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF/resolve/main/nomic-embed-text-v1.5.Q8_0.gguf -o ./models/nomic.gguf
|
||||
/// ./server -m ./models/nomic.gguf --port 8000 --embedding
|
||||
#[tokio::test]
|
||||
#[ignore]
|
||||
async fn test_embedding() {
|
||||
let engine = LlamaCppEngine::create("http://localhost:8000", None);
|
||||
let embedding = engine.embed("hello").await.unwrap();
|
||||
assert_eq!(embedding.len(), 768);
|
||||
}
|
||||
}
|
21
crates/http-api-bindings/src/embedding/mod.rs
Normal file
21
crates/http-api-bindings/src/embedding/mod.rs
Normal file
@ -0,0 +1,21 @@
|
||||
mod llama;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use llama::LlamaCppEngine;
|
||||
use tabby_inference::Embedding;
|
||||
|
||||
use crate::{get_optional_param, get_param};
|
||||
|
||||
pub fn create(model: &str) -> Arc<dyn Embedding> {
|
||||
let params = serde_json::from_str(model).expect("Failed to parse model string");
|
||||
let kind = get_param(¶ms, "kind");
|
||||
if kind == "llama" {
|
||||
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||
let api_key = get_optional_param(¶ms, "api_key");
|
||||
let engine = LlamaCppEngine::create(&api_endpoint, api_key);
|
||||
Arc::new(engine)
|
||||
} else {
|
||||
panic!("Only llama are supported for http embedding");
|
||||
}
|
||||
}
|
@ -1,53 +1,13 @@
|
||||
mod llama;
|
||||
mod openai;
|
||||
mod openai_chat;
|
||||
mod chat;
|
||||
mod completion;
|
||||
mod embedding;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use openai::OpenAIEngine;
|
||||
use openai_chat::OpenAIChatEngine;
|
||||
pub use chat::create as create_chat;
|
||||
pub use completion::create;
|
||||
pub use embedding::create as create_embedding;
|
||||
use serde_json::Value;
|
||||
use tabby_inference::{ChatCompletionStream, CompletionStream};
|
||||
|
||||
pub fn create(model: &str) -> (Arc<dyn CompletionStream>, Option<String>, Option<String>) {
|
||||
let params = serde_json::from_str(model).expect("Failed to parse model string");
|
||||
let kind = get_param(¶ms, "kind");
|
||||
if kind == "openai" {
|
||||
let model_name = get_optional_param(¶ms, "model_name").unwrap_or_default();
|
||||
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||
let api_key = get_optional_param(¶ms, "api_key");
|
||||
let prompt_template = get_optional_param(¶ms, "prompt_template");
|
||||
let chat_template = get_optional_param(¶ms, "chat_template");
|
||||
let engine = OpenAIEngine::create(&api_endpoint, &model_name, api_key);
|
||||
(Arc::new(engine), prompt_template, chat_template)
|
||||
} else if kind == "llama" {
|
||||
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||
let api_key = get_optional_param(¶ms, "api_key");
|
||||
let prompt_template = get_optional_param(¶ms, "prompt_template");
|
||||
let chat_template = get_optional_param(¶ms, "chat_template");
|
||||
let engine = llama::LlamaCppEngine::create(&api_endpoint, api_key);
|
||||
(Arc::new(engine), prompt_template, chat_template)
|
||||
} else {
|
||||
panic!("Only openai are supported for http completion");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_chat(model: &str) -> Arc<dyn ChatCompletionStream> {
|
||||
let params = serde_json::from_str(model).expect("Failed to parse model string");
|
||||
let kind = get_param(¶ms, "kind");
|
||||
if kind == "openai-chat" {
|
||||
let model_name = get_optional_param(¶ms, "model_name").unwrap_or_default();
|
||||
let api_endpoint = get_param(¶ms, "api_endpoint");
|
||||
let api_key = get_optional_param(¶ms, "api_key");
|
||||
|
||||
let engine = OpenAIChatEngine::create(&api_endpoint, &model_name, api_key);
|
||||
Arc::new(engine)
|
||||
} else {
|
||||
panic!("Only openai-chat are supported for http chat");
|
||||
}
|
||||
}
|
||||
|
||||
fn get_param(params: &Value, key: &str) -> String {
|
||||
pub(crate) fn get_param(params: &Value, key: &str) -> String {
|
||||
params
|
||||
.get(key)
|
||||
.unwrap_or_else(|| panic!("Missing {} field", key))
|
||||
@ -56,7 +16,7 @@ fn get_param(params: &Value, key: &str) -> String {
|
||||
.to_owned()
|
||||
}
|
||||
|
||||
fn get_optional_param(params: &Value, key: &str) -> Option<String> {
|
||||
pub(crate) fn get_optional_param(params: &Value, key: &str) -> Option<String> {
|
||||
params
|
||||
.get(key)
|
||||
.map(|x| x.as_str().expect("Type unmatched").to_owned())
|
||||
|
6
crates/tabby-inference/src/embedding.rs
Normal file
6
crates/tabby-inference/src/embedding.rs
Normal file
@ -0,0 +1,6 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Embedding: Sync + Send {
|
||||
async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>>;
|
||||
}
|
@ -3,10 +3,12 @@ mod chat;
|
||||
mod code;
|
||||
mod completion;
|
||||
mod decoding;
|
||||
mod embedding;
|
||||
|
||||
pub use chat::{ChatCompletionOptions, ChatCompletionOptionsBuilder, ChatCompletionStream};
|
||||
pub use code::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder};
|
||||
pub use completion::{CompletionOptions, CompletionOptionsBuilder, CompletionStream};
|
||||
pub use embedding::Embedding;
|
||||
|
||||
fn default_seed() -> u64 {
|
||||
std::time::SystemTime::now()
|
||||
|
Loading…
Reference in New Issue
Block a user