feat: add support for OpenAI completion API (#2604)
* feat: add support for OpenAI completion endpoint
* add openai/completion example in model configuration documentation
parent dcc91d1d6c
commit 510a63c095
crates/http-api-bindings/src/completion/mod.rs

```diff
@@ -1,10 +1,12 @@
 mod llama;
 mod mistral;
+mod openai;
 
 use std::sync::Arc;
 
 use llama::LlamaCppEngine;
 use mistral::MistralFIMEngine;
+use openai::OpenAICompletionEngine;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::CompletionStream;
 
@@ -24,6 +26,14 @@ pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
             );
             Arc::new(engine)
         }
+        "openai/completion" => {
+            let engine = OpenAICompletionEngine::create(
+                model.model_name.clone(),
+                &model.api_endpoint,
+                model.api_key.clone(),
+            );
+            Arc::new(engine)
+        }
 
         unsupported_kind => panic!(
            "Unsupported model kind for http completion: {}",
```
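The new match arm simply forwards three fields from `HttpModelConfig` into the engine constructor. A minimal sketch of the equivalent direct call, written as if from inside the completion module (values are illustrative, not from the commit):

```rust
// Sketch: the direct equivalent of the new "openai/completion" factory arm.
fn build_openai_engine() -> OpenAICompletionEngine {
    OpenAICompletionEngine::create(
        Some("your_model".to_string()),     // model_name is required: create() unwraps it
        "https://api.openai.com/v1",        // "/completions" is appended by the constructor
        Some("secret-api-key".to_string()), // optional; sent as a Bearer token when set
    )
}
```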
crates/http-api-bindings/src/completion/openai.rs (new file, 92 lines)

```rust
use async_stream::stream;
use async_trait::async_trait;
use futures::{stream::BoxStream, StreamExt};
use reqwest_eventsource::{Event, EventSource};
use serde::{Deserialize, Serialize};
use tabby_inference::{CompletionOptions, CompletionStream};

pub struct OpenAICompletionEngine {
    client: reqwest::Client,
    model_name: String,
    api_endpoint: String,
    api_key: Option<String>,
}

impl OpenAICompletionEngine {
    pub fn create(model_name: Option<String>, api_endpoint: &str, api_key: Option<String>) -> Self {
        let model_name = model_name.unwrap();
        let client = reqwest::Client::new();

        Self {
            client,
            model_name,
            api_endpoint: format!("{}/completions", api_endpoint),
            api_key,
        }
    }
}

#[derive(Serialize)]
struct CompletionRequest {
    model: String,
    prompt: String,
    max_tokens: i32,
    temperature: f32,
    stream: bool,
    presence_penalty: f32,
}

#[derive(Deserialize)]
struct CompletionResponseChunk {
    choices: Vec<CompletionResponseChoice>,
}

#[derive(Deserialize)]
struct CompletionResponseChoice {
    text: String,
    finish_reason: Option<String>,
}

#[async_trait]
impl CompletionStream for OpenAICompletionEngine {
    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
        let request = CompletionRequest {
            model: self.model_name.clone(),
            prompt: prompt.to_owned(),
            max_tokens: options.max_decoding_tokens,
            temperature: options.sampling_temperature,
            stream: true,
            presence_penalty: options.presence_penalty,
        };

        let mut request = self.client.post(&self.api_endpoint).json(&request);
        if let Some(api_key) = &self.api_key {
            request = request.bearer_auth(api_key);
        }

        let s = stream! {
            let mut es = EventSource::new(request).expect("Failed to create event source");
            while let Some(event) = es.next().await {
                match event {
                    Ok(Event::Open) => {}
                    Ok(Event::Message(message)) => {
                        let x: CompletionResponseChunk = serde_json::from_str(&message.data).expect("Failed to parse response");
                        if let Some(choice) = x.choices.first() {
                            yield choice.text.clone();

                            if choice.finish_reason.is_some() {
                                break;
                            }
                        }
                    }
                    Err(_) => {
                        // StreamEnd
                        break;
                    }
                }
            }
        };

        Box::pin(s)
    }
}
```
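As a usage note, here is a minimal, hypothetical driver for the engine above. `CompletionOptions` is taken as a parameter because its full field set is defined in `tabby_inference`; `generate` only reads `max_decoding_tokens`, `sampling_temperature`, and `presence_penalty`:

```rust
use futures::StreamExt;
use tabby_inference::{CompletionOptions, CompletionStream};

// Hypothetical driver: stream a completion and print each SSE delta as it arrives.
async fn demo(engine: &OpenAICompletionEngine, options: CompletionOptions) {
    let mut stream = engine.generate("fn fib(n: u64) -> u64 {", options).await;
    while let Some(chunk) = stream.next().await {
        print!("{chunk}"); // each item is one `choices[0].text` delta
    }
}
```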
|
12
website/docs/administration/model.md
vendored
12
website/docs/administration/model.md
vendored
@ -49,6 +49,18 @@ api_endpoint = "https://api.mistral.ai"
|
||||
api_key = "secret-api-key"
|
||||
```
|
||||
|
||||
#### [openai completion](https://platform.openai.com/docs/api-reference/completions)
|
||||
|
||||
Configure Tabby with an OpenAI-compatible completion model (`/v1/completions`) using an online service or a self-hosted backend (vLLM, Nvidia NIM, LocalAI, ...) as follows:
|
||||
|
||||
```toml
|
||||
[model.completion.http]
|
||||
kind = "openai/completion"
|
||||
model_name = "your_model"
|
||||
api_endpoint = "https://url_to_your_backend_or_service"
|
||||
api_key = "secret-api-key"
|
||||
```
|
||||
|
||||
### Chat Model
|
||||
|
||||
Chat models adhere to the standard interface specified by OpenAI's `/chat/completions` API.
|
||||
|
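Two details of this config are visible in `openai.rs` above: the constructor appends `/completions` to `api_endpoint`, so the endpoint should be the versioned base URL (e.g. ending in `/v1`) rather than the full completions path, and every request is sent with `stream: true`. A sketch of the JSON body produced for this config (assuming `CompletionRequest` were in scope; values illustrative):

```rust
// Sketch: the body the engine POSTs to `{api_endpoint}/completions`.
fn main() {
    let request = CompletionRequest {
        model: "your_model".to_string(),
        prompt: "def hello():".to_string(), // the code context Tabby sends
        max_tokens: 64,                     // from CompletionOptions::max_decoding_tokens
        temperature: 0.5,                   // from CompletionOptions::sampling_temperature
        stream: true,                       // the engine always streams over SSE
        presence_penalty: 0.0,
    };
    // Prints: {"model":"your_model","prompt":"def hello():","max_tokens":64,
    //          "temperature":0.5,"stream":true,"presence_penalty":0.0}
    println!("{}", serde_json::to_string(&request).unwrap());
}
```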