refactor(config): make HttpModelConfig.api_endpoint optional so that a default value can be used for certain model kinds (#2760)

Meng Zhang 2024-07-31 11:22:29 -07:00 committed by GitHub
parent bd83870668
commit 48b4ec476f
9 changed files with 48 additions and 30 deletions
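
The practical effect: model kinds that ship a well-known default endpoint (for example voyage/embedding) no longer need api_endpoint spelled out, while kinds that point at an arbitrary server still require it and now fail fast with the panic message "api_endpoint is required". A minimal sketch of the new call-site contract, using the HttpModelConfigBuilder seen in the test fixtures below (values and the chat URL are illustrative, not from this commit):

    // Kind with a built-in default: pass None and let the engine fall back
    // to DEFAULT_VOYAGE_API_ENDPOINT.
    let voyage = HttpModelConfigBuilder::default()
        .api_endpoint(None)
        .kind("voyage/embedding".to_string())
        .model_name(Some("voyage-code-2".into()))
        .build()
        .expect("Failed to create HttpModelConfig");

    // Kind without a default: the endpoint must still be provided, or
    // create() panics at startup. (Hypothetical URL.)
    let chat = HttpModelConfigBuilder::default()
        .api_endpoint(Some("http://localhost:8080/v1".into()))
        .kind("openai/chat".to_string())
        .build()
        .expect("Failed to create HttpModelConfig");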

View File

@@ -7,8 +7,12 @@ use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};
 use crate::create_reqwest_client;
 
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
+    let api_endpoint = model
+        .api_endpoint
+        .as_deref()
+        .expect("api_endpoint is required");
     let config = OpenAIConfig::default()
-        .with_api_base(model.api_endpoint.clone())
+        .with_api_base(api_endpoint)
         .with_api_key(model.api_key.clone().unwrap_or_default());
 
     let mut builder = ExtendedOpenAIConfig::builder();
@@ -28,6 +32,6 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
     Arc::new(
         async_openai::Client::with_config(config)
-            .with_http_client(create_reqwest_client(&model.api_endpoint)),
+            .with_http_client(create_reqwest_client(api_endpoint)),
     )
 }

View File

@@ -13,14 +13,23 @@ use tabby_inference::CompletionStream;
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
     match model.kind.as_str() {
         "llama.cpp/completion" => {
-            let engine = LlamaCppEngine::create(&model.api_endpoint, model.api_key.clone());
+            let engine = LlamaCppEngine::create(
+                model
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
+                model.api_key.clone(),
+            );
             Arc::new(engine)
         }
         "ollama/completion" => ollama_api_bindings::create_completion(model).await,
         "mistral/completion" => {
             let engine = MistralFIMEngine::create(
-                &model.api_endpoint,
+                model
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
                 model.api_key.clone(),
                 model.model_name.clone(),
             );
@@ -29,7 +38,10 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
         "openai/completion" => {
             let engine = OpenAICompletionEngine::create(
                 model.model_name.clone(),
-                &model.api_endpoint,
+                model
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
                 model.api_key.clone(),
             );
             Arc::new(engine)
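
The same three-line guard now appears at every required-endpoint call site across the chat, completion, and embedding bindings. A follow-up could hoist it into a small helper; a hypothetical sketch, not part of this commit:

    // Hypothetical helper: one place to turn the optional endpoint into a
    // &str, panicking with the same message used at the call sites above.
    fn required_api_endpoint(model: &HttpModelConfig) -> &str {
        model
            .api_endpoint
            .as_deref()
            .expect("api_endpoint is required")
    }

    // Call sites would then shrink to, e.g.:
    // let engine = MistralFIMEngine::create(
    //     required_api_endpoint(model),
    //     model.api_key.clone(),
    //     model.model_name.clone(),
    // );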

View File

@@ -14,12 +14,21 @@ use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine};
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
     match config.kind.as_str() {
         "llama.cpp/embedding" => {
-            let engine = LlamaCppEngine::create(&config.api_endpoint, config.api_key.clone());
+            let engine = LlamaCppEngine::create(
+                config
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
+                config.api_key.clone(),
+            );
             Arc::new(engine)
         }
         "openai/embedding" => {
             let engine = OpenAIEmbeddingEngine::create(
-                &config.api_endpoint,
+                config
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
                 config.model_name.as_deref().unwrap_or_default(),
                 config.api_key.as_deref(),
             );
@@ -28,7 +37,7 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
         "ollama/embedding" => ollama_api_bindings::create_embedding(config).await,
         "voyage/embedding" => {
             let engine = VoyageEmbeddingEngine::create(
-                &config.api_endpoint,
+                config.api_endpoint.as_deref(),
                 config
                     .model_name
                     .as_deref()

View File

@@ -14,16 +14,12 @@ pub struct VoyageEmbeddingEngine {
 }
 
 impl VoyageEmbeddingEngine {
-    pub fn create(api_endpoint: &str, model_name: &str, api_key: String) -> Self {
-        let endpoint = if api_endpoint.is_empty() {
-            DEFAULT_VOYAGE_API_ENDPOINT
-        } else {
-            api_endpoint
-        };
+    pub fn create(api_endpoint: Option<&str>, model_name: &str, api_key: String) -> Self {
+        let api_endpoint = api_endpoint.unwrap_or(DEFAULT_VOYAGE_API_ENDPOINT);
         let client = Client::new();
         Self {
             client,
-            api_endpoint: format!("{}/v1/embeddings", endpoint),
+            api_endpoint: format!("{}/v1/embeddings", api_endpoint),
             api_key,
             model_name: model_name.to_owned(),
         }
@@ -62,9 +58,10 @@ impl Embedding for VoyageEmbeddingEngine {
             .bearer_auth(&self.api_key);
 
         let response = request_builder.send().await?;
-        if response.status().is_server_error() {
+        if !response.status().is_success() {
+            let status = response.status();
             let error = response.text().await?;
-            return Err(anyhow::anyhow!("Error from server: {}", error));
+            return Err(anyhow::anyhow!("Error {}: {}", status.as_u16(), error));
         }
 
         let response_body = response
@@ -85,13 +82,12 @@ impl Embedding for VoyageEmbeddingEngine {
 mod tests {
     use super::*;
 
     /// Make sure you have set the VOYAGE_API_KEY environment variable before running the test
     /// VOYAGE_API_KEY=xxx cargo test test_voyage_embedding -- --ignored
     #[tokio::test]
     #[ignore]
     async fn test_voyage_embedding() {
         let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY must be set");
-        let engine =
-            VoyageEmbeddingEngine::create(DEFAULT_VOYAGE_API_ENDPOINT, "voyage-code-2", api_key);
+        let engine = VoyageEmbeddingEngine::create(None, "voyage-code-2", api_key);
         let embedding = engine.embed("Hello, world!").await.unwrap();
         assert_eq!(embedding.len(), 1536);
     }
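
Two behavioral changes land in this file besides the signature: the empty-string sentinel is replaced by Option, and the error check widens from 5xx-only (is_server_error) to any non-2xx status (!is_success), so 4xx failures such as a rejected API key are now surfaced with their status code instead of being parsed as a response body. Call sites look like this; the override URL is hypothetical:

    // Hosted Voyage API: rely on the built-in default endpoint.
    let hosted = VoyageEmbeddingEngine::create(None, "voyage-code-2", api_key.clone());

    // Proxy or self-hosted deployment: override explicitly (hypothetical URL).
    let proxied = VoyageEmbeddingEngine::create(
        Some("https://voyage-proxy.internal.example.com"),
        "voyage-code-2",
        api_key,
    );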

View File

@@ -45,7 +45,7 @@ impl EmbeddingServer {
         server.start().await;
 
         let config = HttpModelConfigBuilder::default()
-            .api_endpoint(api_endpoint(server.port()))
+            .api_endpoint(Some(api_endpoint(server.port())))
            .kind("llama.cpp/embedding".to_string())
            .build()
            .expect("Failed to create HttpModelConfig");
@@ -90,7 +90,7 @@ impl CompletionServer {
        );
        server.start().await;
        let config = HttpModelConfigBuilder::default()
-            .api_endpoint(api_endpoint(server.port()))
+            .api_endpoint(Some(api_endpoint(server.port())))
            .kind("llama.cpp/completion".to_string())
            .build()
            .expect("Failed to create HttpModelConfig");
@@ -133,7 +133,7 @@ impl ChatCompletionServer {
        );
        server.start().await;
        let config = HttpModelConfigBuilder::default()
-            .api_endpoint(api_endpoint(server.port()))
+            .api_endpoint(Some(api_endpoint(server.port())))
            .kind("openai/chat".to_string())
            .model_name(Some("local".into()))
            .build()

View File

@@ -59,7 +59,7 @@ impl CompletionStream for OllamaCompletion {
 }
 
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn CompletionStream> {
-    let connection = Ollama::try_new(config.api_endpoint.to_owned())
+    let connection = Ollama::try_new(config.api_endpoint.as_deref().unwrap().to_owned())
         .expect("Failed to create connection to Ollama, URL invalid");
     let model = connection.select_model_or_default(config).await.unwrap();

View File

@@ -27,7 +27,7 @@ impl Embedding for OllamaCompletion {
 }
 
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
-    let connection = Ollama::try_new(config.api_endpoint.to_owned())
+    let connection = Ollama::try_new(config.api_endpoint.as_deref().unwrap().to_owned())
         .expect("Failed to create connection to Ollama, URL invalid");
     let model = connection.select_model_or_default(config).await.unwrap();
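
Both Ollama bindings take the shortest route through the new Option: a bare unwrap(). Unlike the expect("api_endpoint is required") guard used elsewhere in this commit, that panics without a message when the endpoint is unset; a messaged variant would look like this (a sketch, not part of the change):

    // Same behavior, clearer failure when api_endpoint is missing.
    let endpoint = config
        .api_endpoint
        .as_deref()
        .expect("api_endpoint is required for ollama models")
        .to_owned();
    let connection = Ollama::try_new(endpoint)
        .expect("Failed to create connection to Ollama, URL invalid");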

View File

@@ -232,7 +232,7 @@ pub struct HttpModelConfig {
     /// - llama.cpp/embedding: llama.cpp `/embedding` API.
     pub kind: String,
 
-    pub api_endpoint: String,
+    pub api_endpoint: Option<String>,
 
     #[builder(default)]
     pub api_key: Option<String>,
@@ -248,9 +248,6 @@ pub struct HttpModelConfig {
     /// Used by Completion API to construct a chat model.
     #[builder(default)]
     pub chat_template: Option<String>,
-
-    #[builder(default)]
-    pub max_input_length: usize,
 }
#[derive(Serialize, Deserialize, Debug, Clone)]
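
On the serde side, Option<String> means a missing api_endpoint key in config.toml now deserializes to None instead of failing; on the builder side, the diff shows no #[builder(default)] on the field, which is why the test fixtures above still call the setter, now wrapping the URL in Some(...). The same hunk also drops the max_input_length field. A minimal sketch of the deserialization behavior, assuming the toml crate and that kind is the only other required field in scope:

    let raw = r#"
    kind = "voyage/embedding"
    model_name = "voyage-code-2"
    "#;
    let config: HttpModelConfig = toml::from_str(raw).expect("valid model config");
    assert_eq!(config.api_endpoint, None); // engine falls back to its default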

View File

@@ -77,7 +77,7 @@ impl CodeIntelligence {
             return None;
         }
         let relative_path = path
-            .strip_prefix(&config.dir())
+            .strip_prefix(config.dir())
             .expect("Paths always begin with the prefix");
         let Some(ext) = relative_path.extension() else {