diff --git a/crates/http-api-bindings/src/chat/mod.rs b/crates/http-api-bindings/src/chat/mod.rs
index ed55a8d39..50ad471b8 100644
--- a/crates/http-api-bindings/src/chat/mod.rs
+++ b/crates/http-api-bindings/src/chat/mod.rs
@@ -7,8 +7,12 @@ use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};
 use crate::create_reqwest_client;
 
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
+    let api_endpoint = model
+        .api_endpoint
+        .as_deref()
+        .expect("api_endpoint is required");
     let config = OpenAIConfig::default()
-        .with_api_base(model.api_endpoint.clone())
+        .with_api_base(api_endpoint)
         .with_api_key(model.api_key.clone().unwrap_or_default());
 
     let mut builder = ExtendedOpenAIConfig::builder();
@@ -28,6 +32,6 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
 
     Arc::new(
         async_openai::Client::with_config(config)
-            .with_http_client(create_reqwest_client(&model.api_endpoint)),
+            .with_http_client(create_reqwest_client(api_endpoint)),
     )
 }
diff --git a/crates/http-api-bindings/src/completion/mod.rs b/crates/http-api-bindings/src/completion/mod.rs
index d1a59ebd1..c49b1db75 100644
--- a/crates/http-api-bindings/src/completion/mod.rs
+++ b/crates/http-api-bindings/src/completion/mod.rs
@@ -13,14 +13,23 @@ use tabby_inference::CompletionStream;
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
     match model.kind.as_str() {
         "llama.cpp/completion" => {
-            let engine = LlamaCppEngine::create(&model.api_endpoint, model.api_key.clone());
+            let engine = LlamaCppEngine::create(
+                model
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
+                model.api_key.clone(),
+            );
             Arc::new(engine)
         }
         "ollama/completion" => ollama_api_bindings::create_completion(model).await,
 
         "mistral/completion" => {
             let engine = MistralFIMEngine::create(
-                &model.api_endpoint,
+                model
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
                 model.api_key.clone(),
                 model.model_name.clone(),
             );
@@ -29,7 +38,10 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
         "openai/completion" => {
             let engine = OpenAICompletionEngine::create(
                 model.model_name.clone(),
-                &model.api_endpoint,
+                model
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
                 model.api_key.clone(),
             );
             Arc::new(engine)
diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs
index 464b29c95..fc42d123e 100644
--- a/crates/http-api-bindings/src/embedding/mod.rs
+++ b/crates/http-api-bindings/src/embedding/mod.rs
@@ -14,12 +14,21 @@ use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine};
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
     match config.kind.as_str() {
         "llama.cpp/embedding" => {
-            let engine = LlamaCppEngine::create(&config.api_endpoint, config.api_key.clone());
+            let engine = LlamaCppEngine::create(
+                config
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
+                config.api_key.clone(),
+            );
             Arc::new(engine)
         }
         "openai/embedding" => {
             let engine = OpenAIEmbeddingEngine::create(
-                &config.api_endpoint,
+                config
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
                 config.model_name.as_deref().unwrap_or_default(),
                 config.api_key.as_deref(),
             );
@@ -28,7 +37,7 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
         "ollama/embedding" => ollama_api_bindings::create_embedding(config).await,
         "voyage/embedding" => {
             let engine = VoyageEmbeddingEngine::create(
-                &config.api_endpoint,
+                config.api_endpoint.as_deref(),
                 config
                     .model_name
                     .as_deref()
diff --git a/crates/http-api-bindings/src/embedding/voyage.rs b/crates/http-api-bindings/src/embedding/voyage.rs
index 784215d09..e403d7900 100644
--- a/crates/http-api-bindings/src/embedding/voyage.rs
+++ b/crates/http-api-bindings/src/embedding/voyage.rs
@@ -14,16 +14,12 @@ pub struct VoyageEmbeddingEngine {
 }
 
 impl VoyageEmbeddingEngine {
-    pub fn create(api_endpoint: &str, model_name: &str, api_key: String) -> Self {
-        let endpoint = if api_endpoint.is_empty() {
-            DEFAULT_VOYAGE_API_ENDPOINT
-        } else {
-            api_endpoint
-        };
+    pub fn create(api_endpoint: Option<&str>, model_name: &str, api_key: String) -> Self {
+        let api_endpoint = api_endpoint.unwrap_or(DEFAULT_VOYAGE_API_ENDPOINT);
         let client = Client::new();
         Self {
             client,
-            api_endpoint: format!("{}/v1/embeddings", endpoint),
+            api_endpoint: format!("{}/v1/embeddings", api_endpoint),
             api_key,
             model_name: model_name.to_owned(),
         }
@@ -62,9 +58,10 @@ impl Embedding for VoyageEmbeddingEngine {
             .bearer_auth(&self.api_key);
 
         let response = request_builder.send().await?;
-        if response.status().is_server_error() {
+        if !response.status().is_success() {
+            let status = response.status();
             let error = response.text().await?;
-            return Err(anyhow::anyhow!("Error from server: {}", error));
+            return Err(anyhow::anyhow!("Error {}: {}", status.as_u16(), error));
         }
 
         let response_body = response
@@ -85,13 +82,12 @@ mod tests {
     use super::*;
 
-    /// Make sure you have set the VOYAGE_API_KEY environment variable before running the test
+    /// VOYAGE_API_KEY=xxx cargo test test_voyage_embedding -- --ignored
     #[tokio::test]
     #[ignore]
     async fn test_voyage_embedding() {
         let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY must be set");
-        let engine =
-            VoyageEmbeddingEngine::create(DEFAULT_VOYAGE_API_ENDPOINT, "voyage-code-2", api_key);
+        let engine = VoyageEmbeddingEngine::create(None, "voyage-code-2", api_key);
         let embedding = engine.embed("Hello, world!").await.unwrap();
         assert_eq!(embedding.len(), 1536);
     }
 }
diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs
index b96d9ec70..20d63532f 100644
--- a/crates/llama-cpp-server/src/lib.rs
+++ b/crates/llama-cpp-server/src/lib.rs
@@ -45,7 +45,7 @@ impl EmbeddingServer {
         server.start().await;
 
         let config = HttpModelConfigBuilder::default()
-            .api_endpoint(api_endpoint(server.port()))
+            .api_endpoint(Some(api_endpoint(server.port())))
             .kind("llama.cpp/embedding".to_string())
             .build()
             .expect("Failed to create HttpModelConfig");
@@ -90,7 +90,7 @@ impl CompletionServer {
         );
         server.start().await;
         let config = HttpModelConfigBuilder::default()
-            .api_endpoint(api_endpoint(server.port()))
+            .api_endpoint(Some(api_endpoint(server.port())))
             .kind("llama.cpp/completion".to_string())
             .build()
             .expect("Failed to create HttpModelConfig");
@@ -133,7 +133,7 @@ impl ChatCompletionServer {
         );
         server.start().await;
         let config = HttpModelConfigBuilder::default()
-            .api_endpoint(api_endpoint(server.port()))
+            .api_endpoint(Some(api_endpoint(server.port())))
             .kind("openai/chat".to_string())
             .model_name(Some("local".into()))
             .build()
diff --git a/crates/ollama-api-bindings/src/completion.rs b/crates/ollama-api-bindings/src/completion.rs
index 4daf7c973..fe405e5c6 100644
--- a/crates/ollama-api-bindings/src/completion.rs
+++ b/crates/ollama-api-bindings/src/completion.rs
@@ -59,7 +59,7 @@ impl CompletionStream for OllamaCompletion {
 }
 
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn CompletionStream> {
-    let connection = Ollama::try_new(config.api_endpoint.to_owned())
+    let connection = Ollama::try_new(config.api_endpoint.as_deref().unwrap().to_owned())
         .expect("Failed to create connection to Ollama, URL invalid");
 
     let model = connection.select_model_or_default(config).await.unwrap();
diff --git a/crates/ollama-api-bindings/src/embedding.rs b/crates/ollama-api-bindings/src/embedding.rs
index 153a460cb..56050919d 100644
--- a/crates/ollama-api-bindings/src/embedding.rs
+++ b/crates/ollama-api-bindings/src/embedding.rs
@@ -27,7 +27,7 @@ impl Embedding for OllamaCompletion {
 }
 
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
-    let connection = Ollama::try_new(config.api_endpoint.to_owned())
+    let connection = Ollama::try_new(config.api_endpoint.as_deref().unwrap().to_owned())
         .expect("Failed to create connection to Ollama, URL invalid");
 
     let model = connection.select_model_or_default(config).await.unwrap();
diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs
index 28c07d3fc..537ff3d1e 100644
--- a/crates/tabby-common/src/config.rs
+++ b/crates/tabby-common/src/config.rs
@@ -232,7 +232,7 @@ pub struct HttpModelConfig {
     /// - llama.cpp/embedding: llama.cpp `/embedding` API.
     pub kind: String,
 
-    pub api_endpoint: String,
+    pub api_endpoint: Option<String>,
 
     #[builder(default)]
     pub api_key: Option<String>,
@@ -248,9 +248,6 @@ pub struct HttpModelConfig {
     /// Used by Completion API to construct a chat model.
     #[builder(default)]
     pub chat_template: Option<String>,
-
-    #[builder(default)]
-    pub max_input_length: usize,
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
diff --git a/crates/tabby-index/src/code/intelligence.rs b/crates/tabby-index/src/code/intelligence.rs
index a53e224ee..bad17f583 100644
--- a/crates/tabby-index/src/code/intelligence.rs
+++ b/crates/tabby-index/src/code/intelligence.rs
@@ -77,7 +77,7 @@ impl CodeIntelligence {
             return None;
         }
         let relative_path = path
-            .strip_prefix(&config.dir())
+            .strip_prefix(config.dir())
             .expect("Paths always begin with the prefix");
 
         let Some(ext) = relative_path.extension() else {
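With api_endpoint now Option<String>, only a backend that ships a default endpoint can be configured without one: voyage/embedding falls back to DEFAULT_VOYAGE_API_ENDPOINT, while the other backends panic with "api_endpoint is required". The sketch below is not part of the patch; it shows one way a caller might build such a config after this change. The setters used (kind, api_endpoint, model_name, api_key) are the ones visible in the diff, but the assumption that every field left unset carries #[builder(default)] is mine.

use tabby_common::config::{HttpModelConfig, HttpModelConfigBuilder};

// Illustrative only: a Voyage embedding config with no explicit endpoint, so
// VoyageEmbeddingEngine::create(None, ..) falls back to DEFAULT_VOYAGE_API_ENDPOINT.
// Assumes the fields left unset here have #[builder(default)], as api_key does.
fn voyage_config(api_key: String) -> HttpModelConfig {
    HttpModelConfigBuilder::default()
        .kind("voyage/embedding".to_string())
        .api_endpoint(None)
        .model_name(Some("voyage-code-2".into()))
        .api_key(Some(api_key))
        .build()
        .expect("Failed to create HttpModelConfig")
}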