feat: add support for OpenAI completion API (#2604)

* feat: add support for OpenAI completion endpoint

* add openai/completion example in model configuration documentation
Mehdi CHTAYTI 2024-07-10 11:40:51 +02:00 committed by GitHub
parent dcc91d1d6c
commit 510a63c095
3 changed files with 114 additions and 0 deletions


@@ -1,10 +1,12 @@
mod llama;
mod mistral;
mod openai;

use std::sync::Arc;

use llama::LlamaCppEngine;
use mistral::MistralFIMEngine;
use openai::OpenAICompletionEngine;
use tabby_common::config::HttpModelConfig;
use tabby_inference::CompletionStream;
@@ -24,6 +26,14 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
            );
            Arc::new(engine)
        }
        "openai/completion" => {
            let engine = OpenAICompletionEngine::create(
                model.model_name.clone(),
                &model.api_endpoint,
                model.api_key.clone(),
            );
            Arc::new(engine)
        }
        unsupported_kind => panic!(
            "Unsupported model kind for http completion: {}",


@@ -0,0 +1,92 @@
use async_stream::stream;
use async_trait::async_trait;
use futures::{stream::BoxStream, StreamExt};
use reqwest_eventsource::{Event, EventSource};
use serde::{Deserialize, Serialize};
use tabby_inference::{CompletionOptions, CompletionStream};

pub struct OpenAICompletionEngine {
    client: reqwest::Client,
    model_name: String,
    api_endpoint: String,
    api_key: Option<String>,
}

impl OpenAICompletionEngine {
    pub fn create(model_name: Option<String>, api_endpoint: &str, api_key: Option<String>) -> Self {
        // A model name is mandatory for this engine; fail fast with a clear message.
        let model_name = model_name.expect("model_name is required for openai/completion");
        let client = reqwest::Client::new();

        Self {
            client,
            model_name,
            // The endpoint is the configured base URL with `/completions` appended.
            api_endpoint: format!("{}/completions", api_endpoint),
            api_key,
        }
    }
}

#[derive(Serialize)]
struct CompletionRequest {
    model: String,
    prompt: String,
    max_tokens: i32,
    temperature: f32,
    stream: bool,
    presence_penalty: f32,
}

#[derive(Deserialize)]
struct CompletionResponseChunk {
    choices: Vec<CompletionResponseChoice>,
}

#[derive(Deserialize)]
struct CompletionResponseChoice {
    text: String,
    finish_reason: Option<String>,
}

#[async_trait]
impl CompletionStream for OpenAICompletionEngine {
    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
        let request = CompletionRequest {
            model: self.model_name.clone(),
            prompt: prompt.to_owned(),
            max_tokens: options.max_decoding_tokens,
            temperature: options.sampling_temperature,
            stream: true,
            presence_penalty: options.presence_penalty,
        };

        let mut request = self.client.post(&self.api_endpoint).json(&request);
        if let Some(api_key) = &self.api_key {
            request = request.bearer_auth(api_key);
        }

        let s = stream! {
            let mut es = EventSource::new(request).expect("Failed to create event source");
            while let Some(event) = es.next().await {
                match event {
                    Ok(Event::Open) => {}
                    Ok(Event::Message(message)) => {
                        let x: CompletionResponseChunk = serde_json::from_str(&message.data).expect("Failed to parse response");
                        if let Some(choice) = x.choices.first() {
                            yield choice.text.clone();
                            if choice.finish_reason.is_some() {
                                break;
                            }
                        }
                    }
                    Err(_) => {
                        // The event source errors when the server closes the
                        // stream (stream end); stop reading further events.
                        break;
                    }
                }
            }
        };

        Box::pin(s)
    }
}
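For reference, the deserializer above expects each SSE `data:` payload to be an OpenAI-style completion chunk. Below is a minimal test sketch (not part of this commit; the sample JSON is an assumption based on OpenAI's documented streaming format):

```rust
#[cfg(test)]
mod tests {
    use super::*;

    // Hypothetical streaming chunk: a JSON object whose `choices` array
    // carries the generated text and an optional finish reason.
    #[test]
    fn parses_streaming_chunk() {
        let data = r#"{"choices":[{"text":"fn main() {","finish_reason":null}]}"#;
        let chunk: CompletionResponseChunk =
            serde_json::from_str(data).expect("Failed to parse response");
        assert_eq!(chunk.choices[0].text, "fn main() {");
        assert!(chunk.choices[0].finish_reason.is_none());
    }
}
```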


@@ -49,6 +49,18 @@ api_endpoint = "https://api.mistral.ai"
api_key = "secret-api-key"
```
#### [openai completion](https://platform.openai.com/docs/api-reference/completions)
Configure Tabby with an OpenAI-compatible completion model (`/v1/completions`), served either by an online service or by a self-hosted backend such as vLLM, NVIDIA NIM, or LocalAI. Note that the engine appends `/completions` to the configured `api_endpoint`, so the endpoint should already include any version prefix (e.g. end in `/v1`):
```toml
[model.completion.http]
kind = "openai/completion"
model_name = "your_model"
api_endpoint = "https://url_to_your_backend_or_service"
api_key = "secret-api-key"
```
### Chat Model
Chat models adhere to the standard interface specified by OpenAI's `/chat/completions` API.
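A minimal configuration sketch follows, assuming the `openai/chat` kind and the same placeholder values as the completion example above; adapt the model name, endpoint, and key to your deployment:

```toml
[model.chat.http]
kind = "openai/chat"
model_name = "your_model"
api_endpoint = "https://url_to_your_backend_or_service"
api_key = "secret-api-key"
```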