feat: add server.completion_timeout to control timeout of /v1/completion (#637)

* feat: add server.completion_timeout to control timeout of /v1/completion

* Update config.rs
This commit is contained in:
Meng Zhang 2023-10-25 15:05:23 -07:00 committed by GitHub
parent d6296bb121
commit 21ec60eddf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 16 deletions

View File

@ -11,7 +11,10 @@ use crate::path::{config_file, repositories_dir};
#[derive(Serialize, Deserialize, Default)]
pub struct Config {
#[serde(default)]
pub repositories: Vec<Repository>,
pub repositories: Vec<RepositoryConfig>,
#[serde(default)]
pub server: ServerConfig,
}
impl Config {
@ -37,11 +40,11 @@ impl Config {
}
#[derive(Serialize, Deserialize)]
pub struct Repository {
pub struct RepositoryConfig {
pub git_url: String,
}
impl Repository {
impl RepositoryConfig {
pub fn dir(&self) -> PathBuf {
if self.is_local_dir() {
let path = self.git_url.strip_prefix("file://").unwrap();
@ -56,9 +59,23 @@ impl Repository {
}
}
#[derive(Serialize, Deserialize)]
pub struct ServerConfig {
/// The timeout in seconds for the /v1/completion api.
pub completion_timeout: u64,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
completion_timeout: 30,
}
}
}
#[cfg(test)]
mod tests {
use super::{Config, Repository};
use super::{Config, RepositoryConfig};
#[test]
fn it_parses_empty_config() {
@ -68,13 +85,13 @@ mod tests {
#[test]
fn it_parses_local_dir() {
let repo = Repository {
let repo = RepositoryConfig {
git_url: "file:///home/user".to_owned(),
};
assert!(repo.is_local_dir());
assert_eq!(repo.dir().display().to_string(), "/home/user");
let repo = Repository {
let repo = RepositoryConfig {
git_url: "https://github.com/TabbyML/tabby".to_owned(),
};
assert!(!repo.is_local_dir());

View File

@ -11,7 +11,7 @@ use ignore::{DirEntry, Walk};
use lazy_static::lazy_static;
use serde_jsonlines::WriteExt;
use tabby_common::{
config::{Config, Repository},
config::{Config, RepositoryConfig},
path::dataset_dir,
SourceFile,
};
@ -22,7 +22,7 @@ trait RepositoryExt {
fn create_dataset(&self, writer: &mut impl Write) -> Result<()>;
}
impl RepositoryExt for Repository {
impl RepositoryExt for RepositoryConfig {
fn create_dataset(&self, writer: &mut impl Write) -> Result<()> {
let dir = self.dir();

View File

@ -1,7 +1,7 @@
use std::process::Command;
use anyhow::{anyhow, Result};
use tabby_common::config::{Config, Repository};
use tabby_common::config::{Config, RepositoryConfig};
trait ConfigExt {
fn sync_repositories(&self) -> Result<()>;
@ -27,7 +27,7 @@ trait RepositoryExt {
fn sync(&self) -> Result<()>;
}
impl RepositoryExt for Repository {
impl RepositoryExt for RepositoryConfig {
fn sync(&self) -> Result<()> {
let dir = self.dir();
let dir_string = dir.display().to_string();

View File

@ -3,7 +3,7 @@ mod tests {
use std::fs::create_dir_all;
use tabby_common::{
config::{Config, Repository},
config::{Config, RepositoryConfig, ServerConfig},
path::set_tabby_root,
};
use temp_testdir::*;
@ -17,9 +17,10 @@ mod tests {
set_tabby_root(root.to_path_buf());
let config = Config {
repositories: vec![Repository {
repositories: vec![RepositoryConfig {
git_url: "https://github.com/TabbyML/interview-questions".to_owned(),
}],
server: ServerConfig::default(),
};
config.save();

View File

@ -125,7 +125,7 @@ fn should_download_ggml_files(device: &Device) -> bool {
*device == Device::Metal
}
pub async fn main(_config: &Config, args: &ServeArgs) {
pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args);
if args.device != Device::ExperimentalHttp {
@ -146,7 +146,7 @@ pub async fn main(_config: &Config, args: &ServeArgs) {
.route("/", routing::get(playground::handler))
.route("/index.txt", routing::get(playground::handler))
.route("/_next/*path", routing::get(playground::handler))
.merge(api_router(args))
.merge(api_router(args, config))
.merge(SwaggerUi::new("/swagger-ui").url("/api-docs/openapi.json", doc));
let app = if args.chat_model.is_some() {
@ -166,7 +166,7 @@ pub async fn main(_config: &Config, args: &ServeArgs) {
.unwrap_or_else(|err| fatal!("Error happens during serving: {}", err))
}
fn api_router(args: &ServeArgs) -> Router {
fn api_router(args: &ServeArgs, config: &Config) -> Router {
let index_server = Arc::new(IndexServer::new());
let completion_state = {
let (
@ -218,7 +218,9 @@ fn api_router(args: &ServeArgs) -> Router {
"/v1/completions",
routing::post(completions::completions).with_state(completion_state),
)
.layer(TimeoutLayer::new(Duration::from_secs(3)))
.layer(TimeoutLayer::new(Duration::from_secs(
config.server.completion_timeout,
)))
});
if let Some(chat_state) = chat_state {