mirror of https://github.com/TabbyML/tabby
synced 2024-11-22 00:08:06 +00:00
refactor(config): make HttpModelConfig.api_endpoint optional, so a default value can be used for certain model kinds (#2760)
This commit is contained in:
parent bd83870668
commit 48b4ec476f
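In short: `HttpModelConfig.api_endpoint` changes from `String` to `Option<String>`. A backend with a well-known hosted endpoint (currently Voyage) falls back to a default when the field is unset, while every other backend still requires it and now fails fast with an explicit message. A minimal sketch of the two consumption patterns, using a simplified stand-in for the real config struct (the default URL shown is illustrative):

// Simplified stand-in for tabby's HttpModelConfig; the field names match the
// diff below, everything else is trimmed for illustration.
struct HttpModelConfig {
    kind: String,
    api_endpoint: Option<String>,
}

// Illustrative default; the real constant lives in the voyage module.
const DEFAULT_VOYAGE_API_ENDPOINT: &str = "https://api.voyageai.com";

fn resolve_endpoint(config: &HttpModelConfig) -> &str {
    match config.kind.as_str() {
        // Kind with a known default: fall back when unset.
        "voyage/embedding" => config
            .api_endpoint
            .as_deref()
            .unwrap_or(DEFAULT_VOYAGE_API_ENDPOINT),
        // Every other kind: the endpoint is still mandatory.
        _ => config
            .api_endpoint
            .as_deref()
            .expect("api_endpoint is required"),
    }
}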
@@ -7,8 +7,12 @@ use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};
 use crate::create_reqwest_client;
 
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
+    let api_endpoint = model
+        .api_endpoint
+        .as_deref()
+        .expect("api_endpoint is required");
     let config = OpenAIConfig::default()
-        .with_api_base(model.api_endpoint.clone())
+        .with_api_base(api_endpoint)
         .with_api_key(model.api_key.clone().unwrap_or_default());
 
     let mut builder = ExtendedOpenAIConfig::builder();
@@ -28,6 +32,6 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
 
     Arc::new(
         async_openai::Client::with_config(config)
-            .with_http_client(create_reqwest_client(&model.api_endpoint)),
+            .with_http_client(create_reqwest_client(api_endpoint)),
     )
 }
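For OpenAI-compatible chat there is no sensible default endpoint, so the new Option is unwrapped once at the top of create and reused for both the API base and the reqwest client. A small hedged sketch of the failure mode when the field is left unset:

fn chat_endpoint_is_mandatory() {
    // HttpModelConfig here is the simplified stand-in from the sketch above.
    let model = HttpModelConfig {
        kind: "openai/chat".to_string(),
        api_endpoint: None,
    };
    // Panics with "api_endpoint is required" instead of building a client
    // against an empty base URL.
    let _api_endpoint = model
        .api_endpoint
        .as_deref()
        .expect("api_endpoint is required");
}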
@@ -13,14 +13,23 @@ use tabby_inference::CompletionStream;
 pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
     match model.kind.as_str() {
         "llama.cpp/completion" => {
-            let engine = LlamaCppEngine::create(&model.api_endpoint, model.api_key.clone());
+            let engine = LlamaCppEngine::create(
+                model
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
+                model.api_key.clone(),
+            );
             Arc::new(engine)
         }
         "ollama/completion" => ollama_api_bindings::create_completion(model).await,
 
         "mistral/completion" => {
             let engine = MistralFIMEngine::create(
-                &model.api_endpoint,
+                model
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
                 model.api_key.clone(),
                 model.model_name.clone(),
             );
@@ -29,7 +38,10 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
         "openai/completion" => {
             let engine = OpenAICompletionEngine::create(
                 model.model_name.clone(),
-                &model.api_endpoint,
+                model
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
                 model.api_key.clone(),
             );
             Arc::new(engine)
@@ -14,12 +14,21 @@ use self::{openai::OpenAIEmbeddingEngine, voyage::VoyageEmbeddingEngine};
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
     match config.kind.as_str() {
         "llama.cpp/embedding" => {
-            let engine = LlamaCppEngine::create(&config.api_endpoint, config.api_key.clone());
+            let engine = LlamaCppEngine::create(
+                config
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
+                config.api_key.clone(),
+            );
             Arc::new(engine)
         }
         "openai/embedding" => {
             let engine = OpenAIEmbeddingEngine::create(
-                &config.api_endpoint,
+                config
+                    .api_endpoint
+                    .as_deref()
+                    .expect("api_endpoint is required"),
                 config.model_name.as_deref().unwrap_or_default(),
                 config.api_key.as_deref(),
             );
@@ -28,7 +37,7 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
         "ollama/embedding" => ollama_api_bindings::create_embedding(config).await,
         "voyage/embedding" => {
             let engine = VoyageEmbeddingEngine::create(
-                &config.api_endpoint,
+                config.api_endpoint.as_deref(),
                 config
                     .model_name
                     .as_deref()
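Note the asymmetry in the embedding factory above: the llama.cpp and OpenAI arms still expect an endpoint, while the voyage arm forwards `config.api_endpoint.as_deref()` as an `Option<&str>` and lets the engine choose. A hypothetical helper showing the engine-side contract (the default URL again illustrative):

// Hypothetical illustration of the engine-side contract: pick the default
// when unset, then append the embeddings path.
fn voyage_embeddings_url(api_endpoint: Option<&str>) -> String {
    let base = api_endpoint.unwrap_or("https://api.voyageai.com");
    format!("{}/v1/embeddings", base)
}

#[test]
fn default_applies_only_when_unset() {
    assert_eq!(
        voyage_embeddings_url(None),
        "https://api.voyageai.com/v1/embeddings"
    );
    assert_eq!(
        voyage_embeddings_url(Some("http://localhost:8080")),
        "http://localhost:8080/v1/embeddings"
    );
}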
@@ -14,16 +14,12 @@ pub struct VoyageEmbeddingEngine {
 }
 
 impl VoyageEmbeddingEngine {
-    pub fn create(api_endpoint: &str, model_name: &str, api_key: String) -> Self {
-        let endpoint = if api_endpoint.is_empty() {
-            DEFAULT_VOYAGE_API_ENDPOINT
-        } else {
-            api_endpoint
-        };
+    pub fn create(api_endpoint: Option<&str>, model_name: &str, api_key: String) -> Self {
+        let api_endpoint = api_endpoint.unwrap_or(DEFAULT_VOYAGE_API_ENDPOINT);
         let client = Client::new();
         Self {
             client,
-            api_endpoint: format!("{}/v1/embeddings", endpoint),
+            api_endpoint: format!("{}/v1/embeddings", api_endpoint),
             api_key,
             model_name: model_name.to_owned(),
         }
@@ -62,9 +58,10 @@ impl Embedding for VoyageEmbeddingEngine {
             .bearer_auth(&self.api_key);
 
         let response = request_builder.send().await?;
-        if response.status().is_server_error() {
+        if !response.status().is_success() {
+            let status = response.status();
             let error = response.text().await?;
-            return Err(anyhow::anyhow!("Error from server: {}", error));
+            return Err(anyhow::anyhow!("Error {}: {}", status.as_u16(), error));
        }
 
         let response_body = response
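The response check also broadens from 5xx-only to any non-success status, so 4xx errors (bad API key, unknown model name) now surface with their status code instead of being treated as success. A hedged sketch of the pattern as a free function, with reqwest and anyhow assumed as in the surrounding code:

use anyhow::anyhow;

// Any non-2xx response becomes an error carrying the numeric status plus the
// server's body text; previously only 5xx responses were reported.
async fn ensure_success(response: reqwest::Response) -> anyhow::Result<reqwest::Response> {
    if !response.status().is_success() {
        let status = response.status();
        let error = response.text().await?;
        return Err(anyhow!("Error {}: {}", status.as_u16(), error));
    }
    Ok(response)
}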
@@ -85,13 +82,12 @@ impl Embedding for VoyageEmbeddingEngine {
 mod tests {
     use super::*;
 
     /// Make sure you have set the VOYAGE_API_KEY environment variable before running the test
     /// VOYAGE_API_KEY=xxx cargo test test_voyage_embedding -- --ignored
     #[tokio::test]
     #[ignore]
     async fn test_voyage_embedding() {
         let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY must be set");
-        let engine =
-            VoyageEmbeddingEngine::create(DEFAULT_VOYAGE_API_ENDPOINT, "voyage-code-2", api_key);
+        let engine = VoyageEmbeddingEngine::create(None, "voyage-code-2", api_key);
         let embedding = engine.embed("Hello, world!").await.unwrap();
         assert_eq!(embedding.len(), 1536);
     }
@@ -45,7 +45,7 @@ impl EmbeddingServer {
         server.start().await;
 
         let config = HttpModelConfigBuilder::default()
-            .api_endpoint(api_endpoint(server.port()))
+            .api_endpoint(Some(api_endpoint(server.port())))
             .kind("llama.cpp/embedding".to_string())
             .build()
             .expect("Failed to create HttpModelConfig");
@@ -90,7 +90,7 @@ impl CompletionServer {
         );
         server.start().await;
         let config = HttpModelConfigBuilder::default()
-            .api_endpoint(api_endpoint(server.port()))
+            .api_endpoint(Some(api_endpoint(server.port())))
             .kind("llama.cpp/completion".to_string())
             .build()
             .expect("Failed to create HttpModelConfig");
@@ -133,7 +133,7 @@ impl ChatCompletionServer {
         );
         server.start().await;
         let config = HttpModelConfigBuilder::default()
-            .api_endpoint(api_endpoint(server.port()))
+            .api_endpoint(Some(api_endpoint(server.port())))
             .kind("openai/chat".to_string())
             .model_name(Some("local".into()))
             .build()
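Because the field type changed, the builder setter now takes an `Option<String>`, hence the `Some(...)` wrapping in the test fixtures above. A minimal sketch under the assumption that the real struct derives derive_builder's `Builder` with `#[builder(default)]` on the optional field:

use derive_builder::Builder;

// Pared-down stand-in; the real HttpModelConfig has more fields. The
// #[builder(default)] attribute on api_endpoint is an assumption made so
// the field can be omitted entirely.
#[derive(Builder, Debug)]
struct HttpModelConfig {
    kind: String,
    #[builder(default)]
    api_endpoint: Option<String>,
}

fn demo() -> Result<(), HttpModelConfigBuilderError> {
    // The generated setter mirrors the field type, so callers wrap in Some(...).
    let with_endpoint = HttpModelConfigBuilder::default()
        .kind("llama.cpp/embedding".to_string())
        .api_endpoint(Some("http://localhost:8080".to_string()))
        .build()?;
    assert!(with_endpoint.api_endpoint.is_some());

    // With #[builder(default)], skipping the setter yields None.
    let without_endpoint = HttpModelConfigBuilder::default()
        .kind("voyage/embedding".to_string())
        .build()?;
    assert!(without_endpoint.api_endpoint.is_none());
    Ok(())
}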
@@ -59,7 +59,7 @@ impl CompletionStream for OllamaCompletion {
 }
 
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn CompletionStream> {
-    let connection = Ollama::try_new(config.api_endpoint.to_owned())
+    let connection = Ollama::try_new(config.api_endpoint.as_deref().unwrap().to_owned())
         .expect("Failed to create connection to Ollama, URL invalid");
 
     let model = connection.select_model_or_default(config).await.unwrap();
@@ -27,7 +27,7 @@ impl Embedding for OllamaCompletion {
 }
 
 pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
-    let connection = Ollama::try_new(config.api_endpoint.to_owned())
+    let connection = Ollama::try_new(config.api_endpoint.as_deref().unwrap().to_owned())
         .expect("Failed to create connection to Ollama, URL invalid");
 
     let model = connection.select_model_or_default(config).await.unwrap();
@@ -232,7 +232,7 @@ pub struct HttpModelConfig {
     /// - llama.cpp/embedding: llama.cpp `/embedding` API.
     pub kind: String,
 
-    pub api_endpoint: String,
+    pub api_endpoint: Option<String>,
 
     #[builder(default)]
     pub api_key: Option<String>,
@@ -248,9 +248,6 @@ pub struct HttpModelConfig {
     /// Used by Completion API to construct a chat model.
     #[builder(default)]
     pub chat_template: Option<String>,
-
-    #[builder(default)]
-    pub max_input_length: usize,
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
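On the configuration side, an `Option<String>` field lets the config file omit `api_endpoint` entirely. A round-trip sketch with serde, where the `toml` crate and the pared-down struct are assumptions for illustration:

use serde::Deserialize;

// Pared-down stand-in for illustration; serde deserializes a missing
// Option field as None without extra attributes.
#[derive(Deserialize, Debug)]
struct HttpModelConfig {
    kind: String,
    api_endpoint: Option<String>,
}

fn main() {
    let with_endpoint: HttpModelConfig = toml::from_str(
        "kind = \"llama.cpp/completion\"\napi_endpoint = \"http://localhost:8080\"",
    )
    .unwrap();
    assert!(with_endpoint.api_endpoint.is_some());

    // No api_endpoint key: the field is None and the engine's default applies.
    let without: HttpModelConfig = toml::from_str("kind = \"voyage/embedding\"").unwrap();
    assert!(without.api_endpoint.is_none());
}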
@@ -77,7 +77,7 @@ impl CodeIntelligence {
             return None;
         }
         let relative_path = path
-            .strip_prefix(&config.dir())
+            .strip_prefix(config.dir())
             .expect("Paths always begin with the prefix");
 
         let Some(ext) = relative_path.extension() else {
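The final hunk is an unrelated tidy-up: `Path::strip_prefix` is generic over `AsRef<Path>`, so the extra borrow on `config.dir()` is unnecessary. A quick sketch, assuming `dir()` returns a `PathBuf`:

use std::path::{Path, PathBuf};

fn demo() {
    let path = Path::new("/repo/src/main.rs");
    let dir = PathBuf::from("/repo");
    // strip_prefix takes impl AsRef<Path>: &PathBuf and PathBuf both work,
    // so dropping the borrow is purely cosmetic.
    assert_eq!(path.strip_prefix(&dir).unwrap(), Path::new("src/main.rs"));
    assert_eq!(path.strip_prefix(dir).unwrap(), Path::new("src/main.rs"));
}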