diff --git a/crates/tabby-common/src/events.rs b/crates/tabby-common/src/events.rs index 88e683b93..7f8e2161d 100644 --- a/crates/tabby-common/src/events.rs +++ b/crates/tabby-common/src/events.rs @@ -56,15 +56,15 @@ pub enum Event<'a> { completion_id: &'a str, language: &'a str, prompt: &'a str, - segments: &'a Segments<'a>, + segments: &'a Option, choices: Vec>, user: Option<&'a str>, }, } #[derive(Serialize)] -pub struct Segments<'a> { - pub prefix: &'a str, - pub suffix: Option<&'a str>, +pub struct Segments { + pub prefix: String, + pub suffix: Option, } #[derive(Serialize)] diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index f9f5b0500..8dd6e3eac 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -21,10 +21,6 @@ use super::search::IndexServer; } }))] pub struct CompletionRequest { - #[deprecated] - #[schema(example = "def fib(n):")] - prompt: Option, - /// Language identifier, full list is maintained at /// https://code.visualstudio.com/docs/languages/identifiers #[schema(example = "python")] @@ -42,9 +38,18 @@ pub struct CompletionRequest { #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct DebugOptions { - /// When true, returns debug_data in completion response. + /// When `raw_prompt` is specified, it will be passed directly to the inference engine for completion. `segments` field in `CompletionRequest` will be ignored. + /// + /// This is useful for certain requests that aim to test the tabby's e2e quality. + raw_prompt: Option, + + /// When true, returns `snippets` in `debug_data`. #[serde(default = "default_false")] - enabled: bool, + return_snippets: bool, + + /// When true, returns `prompt` in `debug_data`. + #[serde(default = "default_false")] + return_prompt: bool, /// When true, disable retrieval augmented code completion. #[serde(default = "default_false")] @@ -92,10 +97,11 @@ pub struct CompletionResponse { #[derive(Serialize, Deserialize, ToSchema, Clone, Debug)] pub struct DebugData { - #[serde(skip_serializing_if = "Vec::is_empty")] - snippets: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + snippets: Option>, - prompt: String, + #[serde(skip_serializing_if = "Option::is_none")] + prompt: Option, } #[utoipa::path( @@ -123,42 +129,34 @@ pub async fn completions( .build() .unwrap(); - let segments = if let Some(segments) = request.segments { - segments - } else if let Some(prompt) = request.prompt { - Segments { - prefix: prompt, - suffix: None, - } + let (prompt, segments, snippets) = if let Some(prompt) = request + .debug_options + .as_ref() + .and_then(|x| x.raw_prompt.clone()) + { + (prompt, None, vec![]) + } else if let Some(segments) = request.segments { + debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix); + let (prompt, snippets) = build_prompt(&state, &request.debug_options, &language, &segments); + (prompt, Some(segments), snippets) } else { return Err(StatusCode::BAD_REQUEST); }; - - debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix); - let snippets = if !request - .debug_options - .as_ref() - .is_some_and(|x| x.disable_retrieval_augmented_code_completion) - { - state.prompt_builder.collect(&language, &segments) - } else { - vec![] - }; - let prompt = state - .prompt_builder - .build(&language, segments.clone(), &snippets); debug!("PROMPT: {}", prompt); + let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); let text = state.engine.generate(&prompt, options).await; + let segments = segments.map(|x| tabby_common::events::Segments { + prefix: x.prefix, + suffix: x.suffix, + }); + events::Event::Completion { completion_id: &completion_id, language: &language, prompt: &prompt, - segments: &tabby_common::events::Segments { - prefix: &segments.prefix, - suffix: segments.suffix.as_deref(), - }, + segments: &segments, choices: vec![events::Choice { index: 0, text: &text, @@ -167,19 +165,43 @@ pub async fn completions( } .log(); - let debug_data = DebugData { snippets, prompt }; + let debug_data = request + .debug_options + .as_ref() + .map(|debug_options| DebugData { + snippets: debug_options.return_snippets.then_some(snippets), + prompt: debug_options.return_prompt.then_some(prompt), + }); Ok(Json(CompletionResponse { id: completion_id, choices: vec![Choice { index: 0, text }], - debug_data: if request.debug_options.is_some_and(|x| x.enabled) { - Some(debug_data) - } else { - None - }, + debug_data, })) } +fn build_prompt( + state: &Arc, + debug_options: &Option, + language: &str, + segments: &Segments, +) -> (String, Vec) { + let snippets = if !debug_options + .as_ref() + .is_some_and(|x| x.disable_retrieval_augmented_code_completion) + { + state.prompt_builder.collect(language, segments) + } else { + vec![] + }; + ( + state + .prompt_builder + .build(language, segments.clone(), &snippets), + snippets, + ) +} + pub struct CompletionState { engine: Arc>, prompt_builder: prompt::PromptBuilder, diff --git a/experimental/eval/tabby.py b/experimental/eval/tabby.py index 019ac05fc..99de85bc8 100644 --- a/experimental/eval/tabby.py +++ b/experimental/eval/tabby.py @@ -82,10 +82,10 @@ class Model: @method() async def complete(self, language: str, prompt: str): from tabby_client.api.v1 import completion - from tabby_client.models import CompletionRequest, CompletionResponse, Segments + from tabby_client.models import CompletionRequest, DebugOptions, CompletionResponse, Segments request = CompletionRequest( - language=language, prompt=prompt + language=language, debug_options=DebugOptions(raw_prompt=prompt) ) resp: CompletionResponse = await completion.asyncio(client=self.client, json_body=request) return resp.choices[0].text diff --git a/experimental/scheduler/completion.py b/experimental/scheduler/completion.py index 3142a4abd..3b49d1a3c 100644 --- a/experimental/scheduler/completion.py +++ b/experimental/scheduler/completion.py @@ -11,7 +11,7 @@ language = st.text_input("Language", "rust") query = st.text_area("Query", "to_owned") if query: - r = requests.post("http://localhost:8080/v1/completions", json=dict(segments=dict(prefix=query), language=language, debug_options=dict(enabled=True))) + r = requests.post("http://localhost:8080/v1/completions", json=dict(segments=dict(prefix=query), language=language, debug_options=dict(return_snippets=True, return_prompt=True))) json = r.json() debug = json["debug_data"] snippets = debug.get("snippets", []) @@ -25,4 +25,4 @@ if query: for x in snippets: st.write(f"**{x['filepath']}**: {x['score']}") st.write(f"Length: {len(x['body'])}") - st.code(x['body']) \ No newline at end of file + st.code(x['body'])