diff --git a/Cargo.lock b/Cargo.lock index 0ee0fa14b..1eef5a96d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2134,6 +2134,7 @@ dependencies = [ "serde_json", "tabby-common", "tabby-inference", + "tokio", "tracing", ] diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml index 12642f29b..eae43f6ba 100644 --- a/crates/http-api-bindings/Cargo.toml +++ b/crates/http-api-bindings/Cargo.toml @@ -18,3 +18,6 @@ serde_json = { workspace = true } tabby-common = { path = "../tabby-common" } tabby-inference = { path = "../tabby-inference" } tracing.workspace = true + +[dev-dependencies] +tokio ={ workspace = true, features = ["rt", "macros"]} \ No newline at end of file diff --git a/crates/http-api-bindings/src/chat/mod.rs b/crates/http-api-bindings/src/chat/mod.rs new file mode 100644 index 000000000..f2577da23 --- /dev/null +++ b/crates/http-api-bindings/src/chat/mod.rs @@ -0,0 +1,23 @@ +mod openai_chat; + +use std::sync::Arc; + +use openai_chat::OpenAIChatEngine; +use tabby_inference::ChatCompletionStream; + +use crate::{get_optional_param, get_param}; + +pub fn create(model: &str) -> Arc { + let params = serde_json::from_str(model).expect("Failed to parse model string"); + let kind = get_param(¶ms, "kind"); + if kind == "openai-chat" { + let model_name = get_optional_param(¶ms, "model_name").unwrap_or_default(); + let api_endpoint = get_param(¶ms, "api_endpoint"); + let api_key = get_optional_param(¶ms, "api_key"); + + let engine = OpenAIChatEngine::create(&api_endpoint, &model_name, api_key); + Arc::new(engine) + } else { + panic!("Only openai-chat are supported for http chat"); + } +} diff --git a/crates/http-api-bindings/src/openai_chat.rs b/crates/http-api-bindings/src/chat/openai_chat.rs similarity index 100% rename from crates/http-api-bindings/src/openai_chat.rs rename to crates/http-api-bindings/src/chat/openai_chat.rs diff --git a/crates/http-api-bindings/src/llama.rs b/crates/http-api-bindings/src/completion/llama.rs similarity index 100% rename from crates/http-api-bindings/src/llama.rs rename to crates/http-api-bindings/src/completion/llama.rs diff --git a/crates/http-api-bindings/src/completion/mod.rs b/crates/http-api-bindings/src/completion/mod.rs new file mode 100644 index 000000000..b35be4280 --- /dev/null +++ b/crates/http-api-bindings/src/completion/mod.rs @@ -0,0 +1,33 @@ +mod llama; +mod openai; + +use std::sync::Arc; + +use llama::LlamaCppEngine; +use openai::OpenAIEngine; +use tabby_inference::CompletionStream; + +use crate::{get_optional_param, get_param}; + +pub fn create(model: &str) -> (Arc, Option, Option) { + let params = serde_json::from_str(model).expect("Failed to parse model string"); + let kind = get_param(¶ms, "kind"); + if kind == "openai" { + let model_name = get_optional_param(¶ms, "model_name").unwrap_or_default(); + let api_endpoint = get_param(¶ms, "api_endpoint"); + let api_key = get_optional_param(¶ms, "api_key"); + let prompt_template = get_optional_param(¶ms, "prompt_template"); + let chat_template = get_optional_param(¶ms, "chat_template"); + let engine = OpenAIEngine::create(&api_endpoint, &model_name, api_key); + (Arc::new(engine), prompt_template, chat_template) + } else if kind == "llama" { + let api_endpoint = get_param(¶ms, "api_endpoint"); + let api_key = get_optional_param(¶ms, "api_key"); + let prompt_template = get_optional_param(¶ms, "prompt_template"); + let chat_template = get_optional_param(¶ms, "chat_template"); + let engine = LlamaCppEngine::create(&api_endpoint, api_key); + (Arc::new(engine), prompt_template, chat_template) + } else { + panic!("Only openai are supported for http completion"); + } +} diff --git a/crates/http-api-bindings/src/openai.rs b/crates/http-api-bindings/src/completion/openai.rs similarity index 100% rename from crates/http-api-bindings/src/openai.rs rename to crates/http-api-bindings/src/completion/openai.rs diff --git a/crates/http-api-bindings/src/embedding/llama.rs b/crates/http-api-bindings/src/embedding/llama.rs new file mode 100644 index 000000000..87cfb0418 --- /dev/null +++ b/crates/http-api-bindings/src/embedding/llama.rs @@ -0,0 +1,64 @@ +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use tabby_inference::Embedding; + +pub struct LlamaCppEngine { + client: reqwest::Client, + api_endpoint: String, + api_key: Option, +} + +impl LlamaCppEngine { + pub fn create(api_endpoint: &str, api_key: Option) -> Self { + let client = reqwest::Client::new(); + + Self { + client, + api_endpoint: format!("{}/embeddings", api_endpoint), + api_key, + } + } +} + +#[derive(Serialize)] +struct EmbeddingRequest { + content: String, +} + +#[derive(Deserialize)] +struct EmbeddingResponse { + embedding: Vec, +} + +#[async_trait] +impl Embedding for LlamaCppEngine { + async fn embed(&self, prompt: &str) -> anyhow::Result> { + let request = EmbeddingRequest { + content: prompt.to_owned(), + }; + + let mut request = self.client.post(&self.api_endpoint).json(&request); + if let Some(api_key) = &self.api_key { + request = request.bearer_auth(api_key); + } + + let response = request.send().await?.json::().await?; + Ok(response.embedding) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// This unit test should only run manually when the server is running + /// curl -L https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF/resolve/main/nomic-embed-text-v1.5.Q8_0.gguf -o ./models/nomic.gguf + /// ./server -m ./models/nomic.gguf --port 8000 --embedding + #[tokio::test] + #[ignore] + async fn test_embedding() { + let engine = LlamaCppEngine::create("http://localhost:8000", None); + let embedding = engine.embed("hello").await.unwrap(); + assert_eq!(embedding.len(), 768); + } +} diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs new file mode 100644 index 000000000..cf45f9cee --- /dev/null +++ b/crates/http-api-bindings/src/embedding/mod.rs @@ -0,0 +1,21 @@ +mod llama; + +use std::sync::Arc; + +use llama::LlamaCppEngine; +use tabby_inference::Embedding; + +use crate::{get_optional_param, get_param}; + +pub fn create(model: &str) -> Arc { + let params = serde_json::from_str(model).expect("Failed to parse model string"); + let kind = get_param(¶ms, "kind"); + if kind == "llama" { + let api_endpoint = get_param(¶ms, "api_endpoint"); + let api_key = get_optional_param(¶ms, "api_key"); + let engine = LlamaCppEngine::create(&api_endpoint, api_key); + Arc::new(engine) + } else { + panic!("Only llama are supported for http embedding"); + } +} diff --git a/crates/http-api-bindings/src/lib.rs b/crates/http-api-bindings/src/lib.rs index cc905227c..75994356b 100644 --- a/crates/http-api-bindings/src/lib.rs +++ b/crates/http-api-bindings/src/lib.rs @@ -1,53 +1,13 @@ -mod llama; -mod openai; -mod openai_chat; +mod chat; +mod completion; +mod embedding; -use std::sync::Arc; - -use openai::OpenAIEngine; -use openai_chat::OpenAIChatEngine; +pub use chat::create as create_chat; +pub use completion::create; +pub use embedding::create as create_embedding; use serde_json::Value; -use tabby_inference::{ChatCompletionStream, CompletionStream}; -pub fn create(model: &str) -> (Arc, Option, Option) { - let params = serde_json::from_str(model).expect("Failed to parse model string"); - let kind = get_param(¶ms, "kind"); - if kind == "openai" { - let model_name = get_optional_param(¶ms, "model_name").unwrap_or_default(); - let api_endpoint = get_param(¶ms, "api_endpoint"); - let api_key = get_optional_param(¶ms, "api_key"); - let prompt_template = get_optional_param(¶ms, "prompt_template"); - let chat_template = get_optional_param(¶ms, "chat_template"); - let engine = OpenAIEngine::create(&api_endpoint, &model_name, api_key); - (Arc::new(engine), prompt_template, chat_template) - } else if kind == "llama" { - let api_endpoint = get_param(¶ms, "api_endpoint"); - let api_key = get_optional_param(¶ms, "api_key"); - let prompt_template = get_optional_param(¶ms, "prompt_template"); - let chat_template = get_optional_param(¶ms, "chat_template"); - let engine = llama::LlamaCppEngine::create(&api_endpoint, api_key); - (Arc::new(engine), prompt_template, chat_template) - } else { - panic!("Only openai are supported for http completion"); - } -} - -pub fn create_chat(model: &str) -> Arc { - let params = serde_json::from_str(model).expect("Failed to parse model string"); - let kind = get_param(¶ms, "kind"); - if kind == "openai-chat" { - let model_name = get_optional_param(¶ms, "model_name").unwrap_or_default(); - let api_endpoint = get_param(¶ms, "api_endpoint"); - let api_key = get_optional_param(¶ms, "api_key"); - - let engine = OpenAIChatEngine::create(&api_endpoint, &model_name, api_key); - Arc::new(engine) - } else { - panic!("Only openai-chat are supported for http chat"); - } -} - -fn get_param(params: &Value, key: &str) -> String { +pub(crate) fn get_param(params: &Value, key: &str) -> String { params .get(key) .unwrap_or_else(|| panic!("Missing {} field", key)) @@ -56,7 +16,7 @@ fn get_param(params: &Value, key: &str) -> String { .to_owned() } -fn get_optional_param(params: &Value, key: &str) -> Option { +pub(crate) fn get_optional_param(params: &Value, key: &str) -> Option { params .get(key) .map(|x| x.as_str().expect("Type unmatched").to_owned()) diff --git a/crates/tabby-inference/src/embedding.rs b/crates/tabby-inference/src/embedding.rs new file mode 100644 index 000000000..428a4195f --- /dev/null +++ b/crates/tabby-inference/src/embedding.rs @@ -0,0 +1,6 @@ +use async_trait::async_trait; + +#[async_trait] +pub trait Embedding: Sync + Send { + async fn embed(&self, prompt: &str) -> anyhow::Result>; +} diff --git a/crates/tabby-inference/src/lib.rs b/crates/tabby-inference/src/lib.rs index b498d2408..0402531b3 100644 --- a/crates/tabby-inference/src/lib.rs +++ b/crates/tabby-inference/src/lib.rs @@ -3,10 +3,12 @@ mod chat; mod code; mod completion; mod decoding; +mod embedding; pub use chat::{ChatCompletionOptions, ChatCompletionOptionsBuilder, ChatCompletionStream}; pub use code::{CodeGeneration, CodeGenerationOptions, CodeGenerationOptionsBuilder}; pub use completion::{CompletionOptions, CompletionOptionsBuilder, CompletionStream}; +pub use embedding::Embedding; fn default_seed() -> u64 { std::time::SystemTime::now()