feat(eval): add debug_options.raw_prompt to tabby api for evaluation purpose (#605)

* update eval

* feat: re-purpose  as raw input to LLM

* move prompt to Debug options

* Update crates/tabby/src/serve/completions.rs
This commit is contained in:
Meng Zhang 2023-10-21 13:47:44 -07:00 committed by GitHub
parent 049ebdf9a9
commit 8fca850037
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 70 additions and 48 deletions

View File

@ -56,15 +56,15 @@ pub enum Event<'a> {
completion_id: &'a str,
language: &'a str,
prompt: &'a str,
segments: &'a Segments<'a>,
segments: &'a Option<Segments>,
choices: Vec<Choice<'a>>,
user: Option<&'a str>,
},
}
#[derive(Serialize)]
pub struct Segments<'a> {
pub prefix: &'a str,
pub suffix: Option<&'a str>,
pub struct Segments {
pub prefix: String,
pub suffix: Option<String>,
}
#[derive(Serialize)]

View File

@ -21,10 +21,6 @@ use super::search::IndexServer;
}
}))]
pub struct CompletionRequest {
#[deprecated]
#[schema(example = "def fib(n):")]
prompt: Option<String>,
/// Language identifier, full list is maintained at
/// https://code.visualstudio.com/docs/languages/identifiers
#[schema(example = "python")]
@ -42,9 +38,18 @@ pub struct CompletionRequest {
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct DebugOptions {
/// When true, returns debug_data in completion response.
/// When `raw_prompt` is specified, it will be passed directly to the inference engine for completion. `segments` field in `CompletionRequest` will be ignored.
///
/// This is useful for certain requests that aim to test the tabby's e2e quality.
raw_prompt: Option<String>,
/// When true, returns `snippets` in `debug_data`.
#[serde(default = "default_false")]
enabled: bool,
return_snippets: bool,
/// When true, returns `prompt` in `debug_data`.
#[serde(default = "default_false")]
return_prompt: bool,
/// When true, disable retrieval augmented code completion.
#[serde(default = "default_false")]
@ -92,10 +97,11 @@ pub struct CompletionResponse {
#[derive(Serialize, Deserialize, ToSchema, Clone, Debug)]
pub struct DebugData {
#[serde(skip_serializing_if = "Vec::is_empty")]
snippets: Vec<Snippet>,
#[serde(skip_serializing_if = "Option::is_none")]
snippets: Option<Vec<Snippet>>,
prompt: String,
#[serde(skip_serializing_if = "Option::is_none")]
prompt: Option<String>,
}
#[utoipa::path(
@ -123,42 +129,34 @@ pub async fn completions(
.build()
.unwrap();
let segments = if let Some(segments) = request.segments {
segments
} else if let Some(prompt) = request.prompt {
Segments {
prefix: prompt,
suffix: None,
}
let (prompt, segments, snippets) = if let Some(prompt) = request
.debug_options
.as_ref()
.and_then(|x| x.raw_prompt.clone())
{
(prompt, None, vec![])
} else if let Some(segments) = request.segments {
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
let (prompt, snippets) = build_prompt(&state, &request.debug_options, &language, &segments);
(prompt, Some(segments), snippets)
} else {
return Err(StatusCode::BAD_REQUEST);
};
debug!("PREFIX: {}, SUFFIX: {:?}", segments.prefix, segments.suffix);
let snippets = if !request
.debug_options
.as_ref()
.is_some_and(|x| x.disable_retrieval_augmented_code_completion)
{
state.prompt_builder.collect(&language, &segments)
} else {
vec![]
};
let prompt = state
.prompt_builder
.build(&language, segments.clone(), &snippets);
debug!("PROMPT: {}", prompt);
let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4());
let text = state.engine.generate(&prompt, options).await;
let segments = segments.map(|x| tabby_common::events::Segments {
prefix: x.prefix,
suffix: x.suffix,
});
events::Event::Completion {
completion_id: &completion_id,
language: &language,
prompt: &prompt,
segments: &tabby_common::events::Segments {
prefix: &segments.prefix,
suffix: segments.suffix.as_deref(),
},
segments: &segments,
choices: vec![events::Choice {
index: 0,
text: &text,
@ -167,19 +165,43 @@ pub async fn completions(
}
.log();
let debug_data = DebugData { snippets, prompt };
let debug_data = request
.debug_options
.as_ref()
.map(|debug_options| DebugData {
snippets: debug_options.return_snippets.then_some(snippets),
prompt: debug_options.return_prompt.then_some(prompt),
});
Ok(Json(CompletionResponse {
id: completion_id,
choices: vec![Choice { index: 0, text }],
debug_data: if request.debug_options.is_some_and(|x| x.enabled) {
Some(debug_data)
} else {
None
},
debug_data,
}))
}
fn build_prompt(
state: &Arc<CompletionState>,
debug_options: &Option<DebugOptions>,
language: &str,
segments: &Segments,
) -> (String, Vec<Snippet>) {
let snippets = if !debug_options
.as_ref()
.is_some_and(|x| x.disable_retrieval_augmented_code_completion)
{
state.prompt_builder.collect(language, segments)
} else {
vec![]
};
(
state
.prompt_builder
.build(language, segments.clone(), &snippets),
snippets,
)
}
pub struct CompletionState {
engine: Arc<Box<dyn TextGeneration>>,
prompt_builder: prompt::PromptBuilder,

View File

@ -82,10 +82,10 @@ class Model:
@method()
async def complete(self, language: str, prompt: str):
from tabby_client.api.v1 import completion
from tabby_client.models import CompletionRequest, CompletionResponse, Segments
from tabby_client.models import CompletionRequest, DebugOptions, CompletionResponse, Segments
request = CompletionRequest(
language=language, prompt=prompt
language=language, debug_options=DebugOptions(raw_prompt=prompt)
)
resp: CompletionResponse = await completion.asyncio(client=self.client, json_body=request)
return resp.choices[0].text

View File

@ -11,7 +11,7 @@ language = st.text_input("Language", "rust")
query = st.text_area("Query", "to_owned")
if query:
r = requests.post("http://localhost:8080/v1/completions", json=dict(segments=dict(prefix=query), language=language, debug_options=dict(enabled=True)))
r = requests.post("http://localhost:8080/v1/completions", json=dict(segments=dict(prefix=query), language=language, debug_options=dict(return_snippets=True, return_prompt=True)))
json = r.json()
debug = json["debug_data"]
snippets = debug.get("snippets", [])
@ -25,4 +25,4 @@ if query:
for x in snippets:
st.write(f"**{x['filepath']}**: {x['score']}")
st.write(f"Length: {len(x['body'])}")
st.code(x['body'])
st.code(x['body'])