mirror of
https://github.com/TabbyML/tabby
synced 2024-11-21 16:03:07 +00:00
feat(webserver): support persisted thread in answer engine (#2793)
* init commit * [autofix.ci] apply automated fixes * draft create_thread_run_api * cleanup * update * update * update * update * update * update * connect streaming * update * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * add back answer api for ease of migration * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
13a5f8bcd9
commit
5c9b62faa3
3
.changes/unreleased/Features-20240805-162901.yaml
Normal file
3
.changes/unreleased/Features-20240805-162901.yaml
Normal file
@ -0,0 +1,3 @@
|
||||
kind: Features
|
||||
body: Support persisted thread discussion
|
||||
time: 2024-08-05T16:29:01.227967-07:00
|
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -5206,7 +5206,6 @@ name = "tabby-schema"
|
||||
version = "0.16.0-dev.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"base64 0.22.1",
|
||||
|
@ -28,7 +28,6 @@ validator = { version = "0.18.1", features = ["derive"] }
|
||||
regex.workspace = true
|
||||
hash-ids.workspace = true
|
||||
url.workspace = true
|
||||
async-stream.workspace = true
|
||||
|
||||
[[example]]
|
||||
name = "update-schema"
|
||||
|
@ -78,6 +78,18 @@ enum RepositoryKind {
|
||||
GITLAB_SELF_HOSTED
|
||||
}
|
||||
|
||||
enum Role {
|
||||
USER
|
||||
ASSISTANT
|
||||
}
|
||||
|
||||
input CodeQueryInput {
|
||||
gitUrl: String!
|
||||
filepath: String
|
||||
language: String
|
||||
content: String!
|
||||
}
|
||||
|
||||
input CreateIntegrationInput {
|
||||
displayName: String!
|
||||
accessToken: String!
|
||||
@ -85,10 +97,35 @@ input CreateIntegrationInput {
|
||||
apiBase: String
|
||||
}
|
||||
|
||||
input CreateMessageInput {
|
||||
role: Role!
|
||||
content: String!
|
||||
attachments: MessageAttachmentInput
|
||||
}
|
||||
|
||||
input CreateThreadAndRunInput {
|
||||
thread: CreateThreadInput!
|
||||
options: ThreadRunOptionsInput! = {codeQuery: null, docQuery: null, generateRelevantQuestions: false}
|
||||
}
|
||||
|
||||
input CreateThreadInput {
|
||||
messages: [CreateMessageInput!]!
|
||||
}
|
||||
|
||||
input CreateThreadRunInput {
|
||||
threadId: ID!
|
||||
additionalMessages: [CreateMessageInput!]!
|
||||
options: ThreadRunOptionsInput! = {codeQuery: null, docQuery: null, generateRelevantQuestions: false}
|
||||
}
|
||||
|
||||
input CreateWebCrawlerUrlInput {
|
||||
url: String!
|
||||
}
|
||||
|
||||
input DocQueryInput {
|
||||
content: String!
|
||||
}
|
||||
|
||||
input EmailSettingInput {
|
||||
smtpUsername: String!
|
||||
fromAddress: String!
|
||||
@ -99,6 +136,15 @@ input EmailSettingInput {
|
||||
smtpPassword: String
|
||||
}
|
||||
|
||||
input MessageAttachmentCodeInput {
|
||||
filepath: String
|
||||
content: String!
|
||||
}
|
||||
|
||||
input MessageAttachmentInput {
|
||||
code: [MessageAttachmentCodeInput!]!
|
||||
}
|
||||
|
||||
input NetworkSettingInput {
|
||||
externalUrl: String!
|
||||
}
|
||||
@ -128,6 +174,12 @@ input SecuritySettingInput {
|
||||
disableClientSideTelemetry: Boolean!
|
||||
}
|
||||
|
||||
input ThreadRunOptionsInput {
|
||||
docQuery: DocQueryInput = null
|
||||
codeQuery: CodeQueryInput = null
|
||||
generateRelevantQuestions: Boolean! = false
|
||||
}
|
||||
|
||||
input UpdateIntegrationInput {
|
||||
id: ID!
|
||||
displayName: String!
|
||||
@ -317,6 +369,17 @@ type LicenseInfo {
|
||||
expiresAt: DateTime
|
||||
}
|
||||
|
||||
type MessageAttachmentCode {
|
||||
filepath: String
|
||||
content: String!
|
||||
}
|
||||
|
||||
type MessageAttachmentDoc {
|
||||
title: String!
|
||||
link: String!
|
||||
content: String!
|
||||
}
|
||||
|
||||
type Mutation {
|
||||
resetRegistrationToken: String!
|
||||
requestInvitationEmail(input: RequestInvitationInput!): Invitation!
|
||||
@ -490,7 +553,24 @@ type ServerInfo {
|
||||
}
|
||||
|
||||
type Subscription {
|
||||
count: Int!
|
||||
createThreadAndRun(input: CreateThreadAndRunInput!): ThreadRunItem!
|
||||
createThreadRun(input: CreateThreadRunInput!): ThreadRunItem!
|
||||
}
|
||||
|
||||
"""
|
||||
Schema of thread run stream.
|
||||
|
||||
The event's order is kept as same as the order defined in the struct fields.
|
||||
Apart from `thread_message_content_delta`, all other items will only appear once in the stream.
|
||||
"""
|
||||
type ThreadRunItem {
|
||||
threadCreated: ID
|
||||
threadMessageCreated: ID
|
||||
threadMessageAttachmentsCode: [MessageAttachmentCode!]
|
||||
threadMessageAttachmentsDoc: [MessageAttachmentDoc!]
|
||||
threadMessageRelevantQuestions: [String!]
|
||||
threadMessageContentDelta: String
|
||||
threadMessageCompleted: ID
|
||||
}
|
||||
|
||||
type TokenAuthResponse {
|
||||
|
@ -7,20 +7,19 @@ pub mod job;
|
||||
pub mod license;
|
||||
pub mod repository;
|
||||
pub mod setting;
|
||||
pub mod thread;
|
||||
pub mod user_event;
|
||||
pub mod web_crawler;
|
||||
pub mod worker;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_stream::stream;
|
||||
use auth::{
|
||||
AuthenticationService, Invitation, RefreshTokenResponse, RegisterResponse, TokenAuthResponse,
|
||||
User,
|
||||
};
|
||||
use base64::Engine;
|
||||
use chrono::{DateTime, Utc};
|
||||
use futures::stream::BoxStream;
|
||||
use job::{JobRun, JobService};
|
||||
use juniper::{
|
||||
graphql_object, graphql_subscription, graphql_value, FieldError, GraphQLObject, IntoFieldError,
|
||||
@ -28,6 +27,7 @@ use juniper::{
|
||||
};
|
||||
use repository::RepositoryGrepOutput;
|
||||
use tabby_common::api::{code::CodeSearch, event::EventLogger};
|
||||
use thread::{CreateThreadAndRunInput, CreateThreadRunInput, ThreadRunStream, ThreadService};
|
||||
use tracing::error;
|
||||
use validator::{Validate, ValidationErrors};
|
||||
use worker::WorkerService;
|
||||
@ -71,6 +71,7 @@ pub trait ServiceLocator: Send + Sync {
|
||||
fn analytic(&self) -> Arc<dyn AnalyticService>;
|
||||
fn user_event(&self) -> Arc<dyn UserEventService>;
|
||||
fn web_crawler(&self) -> Arc<dyn WebCrawlerService>;
|
||||
fn thread(&self) -> Arc<dyn ThreadService>;
|
||||
}
|
||||
|
||||
pub struct Context {
|
||||
@ -914,22 +915,35 @@ fn from_validation_errors<S: ScalarValue>(error: ValidationErrors) -> FieldError
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Subscription;
|
||||
|
||||
type NumberStream = BoxStream<'static, Result<i32, FieldError>>;
|
||||
|
||||
#[graphql_subscription]
|
||||
impl Subscription {
|
||||
// FIXME(meng): This is a temporary subscription to test the subscription feature, we should remove it later.
|
||||
async fn count(ctx: &Context) -> Result<NumberStream> {
|
||||
async fn create_thread_and_run(
|
||||
ctx: &Context,
|
||||
input: CreateThreadAndRunInput,
|
||||
) -> Result<ThreadRunStream> {
|
||||
check_user(ctx).await?;
|
||||
let mut value = 0;
|
||||
let s = stream! {
|
||||
loop {
|
||||
value += 1;
|
||||
yield Ok(value);
|
||||
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
input.validate()?;
|
||||
|
||||
let thread = ctx.locator.thread();
|
||||
|
||||
let thread_id = thread.create(&input.thread).await?;
|
||||
|
||||
thread.create_run(&thread_id, &input.options).await
|
||||
}
|
||||
};
|
||||
Ok(Box::pin(s))
|
||||
|
||||
async fn create_thread_run(
|
||||
ctx: &Context,
|
||||
input: CreateThreadRunInput,
|
||||
) -> Result<ThreadRunStream> {
|
||||
// check_user(ctx).await?;
|
||||
input.validate()?;
|
||||
|
||||
let thread = ctx.locator.thread();
|
||||
thread
|
||||
.append_messages(&input.thread_id, &input.additional_messages)
|
||||
.await?;
|
||||
|
||||
thread.create_run(&input.thread_id, &input.options).await
|
||||
}
|
||||
}
|
||||
|
||||
|
45
ee/tabby-schema/src/schema/thread.rs
Normal file
45
ee/tabby-schema/src/schema/thread.rs
Normal file
@ -0,0 +1,45 @@
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::BoxStream;
|
||||
use juniper::ID;
|
||||
|
||||
use crate::schema::Result;
|
||||
|
||||
mod types;
|
||||
pub use types::*;
|
||||
|
||||
mod inputs;
|
||||
pub use inputs::*;
|
||||
|
||||
pub type ThreadRunStream = BoxStream<'static, Result<ThreadRunItem>>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait ThreadService: Send + Sync {
|
||||
/// Create a new thread
|
||||
async fn create(&self, input: &CreateThreadInput) -> Result<ID>;
|
||||
|
||||
/// Create a new thread run
|
||||
async fn create_run(&self, id: &ID, options: &ThreadRunOptionsInput)
|
||||
-> Result<ThreadRunStream>;
|
||||
|
||||
/// Append messages to an existing thread
|
||||
async fn append_messages(&self, id: &ID, messages: &[CreateMessageInput]) -> Result<()>;
|
||||
|
||||
// /// Delete a thread by ID
|
||||
// async fn delete(&self, id: ID) -> Result<()>;
|
||||
|
||||
// /// Create a new message in a thread
|
||||
// async fn create_message(&self, input: CreateMessageInput) -> Result<ID>;
|
||||
|
||||
// /// Delete a message by ID
|
||||
// async fn delete_message(&self, id: ID) -> Result<()>;
|
||||
|
||||
// /// Query messages in a thread
|
||||
// async fn list_messages(
|
||||
// &self,
|
||||
// thread_id: ID,
|
||||
// after: Option<String>,
|
||||
// before: Option<String>,
|
||||
// first: Option<usize>,
|
||||
// last: Option<usize>,
|
||||
// ) -> Result<Vec<Message>>;
|
||||
}
|
112
ee/tabby-schema/src/schema/thread/inputs.rs
Normal file
112
ee/tabby-schema/src/schema/thread/inputs.rs
Normal file
@ -0,0 +1,112 @@
|
||||
use juniper::{GraphQLInputObject, ID};
|
||||
use validator::{Validate, ValidationError};
|
||||
|
||||
use super::Role;
|
||||
|
||||
#[derive(GraphQLInputObject, Validate)]
|
||||
#[validate(schema(function = "validate_message_input", skip_on_field_errors = false))]
|
||||
pub struct CreateMessageInput {
|
||||
role: Role,
|
||||
|
||||
content: String,
|
||||
|
||||
#[validate(nested)]
|
||||
attachments: Option<MessageAttachmentInput>,
|
||||
}
|
||||
|
||||
#[derive(GraphQLInputObject, Validate)]
|
||||
#[validate(schema(function = "validate_thread_input", skip_on_field_errors = false))]
|
||||
pub struct CreateThreadInput {
|
||||
#[validate(nested)]
|
||||
messages: Vec<CreateMessageInput>,
|
||||
}
|
||||
|
||||
#[derive(GraphQLInputObject, Validate)]
|
||||
pub struct CreateThreadAndRunInput {
|
||||
#[validate(nested)]
|
||||
pub thread: CreateThreadInput,
|
||||
|
||||
#[validate(nested)]
|
||||
#[graphql(default)]
|
||||
pub options: ThreadRunOptionsInput,
|
||||
}
|
||||
|
||||
#[derive(GraphQLInputObject, Validate, Clone)]
|
||||
pub struct DocQueryInput {
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(GraphQLInputObject, Validate, Clone)]
|
||||
pub struct CodeQueryInput {
|
||||
pub git_url: String,
|
||||
pub filepath: Option<String>,
|
||||
pub language: Option<String>,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(GraphQLInputObject, Validate, Default, Clone)]
|
||||
pub struct ThreadRunOptionsInput {
|
||||
#[validate(nested)]
|
||||
#[graphql(default)]
|
||||
pub doc_query: Option<DocQueryInput>,
|
||||
|
||||
#[validate(nested)]
|
||||
#[graphql(default)]
|
||||
pub code_query: Option<CodeQueryInput>,
|
||||
|
||||
#[graphql(default)]
|
||||
pub generate_relevant_questions: bool,
|
||||
}
|
||||
|
||||
#[derive(GraphQLInputObject, Validate)]
|
||||
pub struct CreateThreadRunInput {
|
||||
pub thread_id: ID,
|
||||
|
||||
#[validate(nested)]
|
||||
pub additional_messages: Vec<CreateMessageInput>,
|
||||
|
||||
#[validate(nested)]
|
||||
#[graphql(default)]
|
||||
pub options: ThreadRunOptionsInput,
|
||||
}
|
||||
|
||||
#[derive(GraphQLInputObject, Validate)]
|
||||
pub struct MessageAttachmentInput {
|
||||
#[validate(nested)]
|
||||
code: Vec<MessageAttachmentCodeInput>,
|
||||
}
|
||||
|
||||
#[derive(GraphQLInputObject, Validate)]
|
||||
pub struct MessageAttachmentCodeInput {
|
||||
pub filepath: Option<String>,
|
||||
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
fn validate_message_input(input: &CreateMessageInput) -> Result<(), ValidationError> {
|
||||
if let Role::Assistant = input.role {
|
||||
if input.attachments.is_some() {
|
||||
return Err(ValidationError::new(
|
||||
"Attachments are not allowed for assistants",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_thread_input(input: &CreateThreadInput) -> Result<(), ValidationError> {
|
||||
let messages = &input.messages;
|
||||
let length = messages.len();
|
||||
|
||||
for (i, message) in messages.iter().enumerate() {
|
||||
let is_last = i == length - 1;
|
||||
if !is_last && message.attachments.is_some() {
|
||||
return Err(ValidationError::new(
|
||||
"Attachments are only allowed on the last message",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
102
ee/tabby-schema/src/schema/thread/types.rs
Normal file
102
ee/tabby-schema/src/schema/thread/types.rs
Normal file
@ -0,0 +1,102 @@
|
||||
use juniper::{GraphQLEnum, GraphQLObject, ID};
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(GraphQLEnum, Serialize, Clone)]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
#[derive(GraphQLObject, Clone)]
|
||||
pub struct Message {
|
||||
pub id: ID,
|
||||
pub thread_id: ID,
|
||||
pub role: Role,
|
||||
pub content: String,
|
||||
|
||||
pub attachments: Option<MessageAttachment>,
|
||||
}
|
||||
|
||||
#[derive(GraphQLObject, Clone)]
|
||||
pub struct MessageAttachment {
|
||||
pub code: Vec<MessageAttachmentCode>,
|
||||
pub doc: Vec<MessageAttachmentDoc>,
|
||||
}
|
||||
|
||||
#[derive(GraphQLObject, Clone)]
|
||||
pub struct MessageAttachmentCode {
|
||||
pub filepath: Option<String>,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(GraphQLObject, Clone)]
|
||||
pub struct MessageAttachmentDoc {
|
||||
pub title: String,
|
||||
pub link: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// Schema of thread run stream.
|
||||
///
|
||||
/// The event's order is kept as same as the order defined in the struct fields.
|
||||
/// Apart from `thread_message_content_delta`, all other items will only appear once in the stream.
|
||||
#[derive(GraphQLObject)]
|
||||
pub struct ThreadRunItem {
|
||||
thread_created: Option<ID>,
|
||||
thread_message_created: Option<ID>,
|
||||
thread_message_attachments_code: Option<Vec<MessageAttachmentCode>>,
|
||||
thread_message_attachments_doc: Option<Vec<MessageAttachmentDoc>>,
|
||||
thread_message_relevant_questions: Option<Vec<String>>,
|
||||
thread_message_content_delta: Option<String>,
|
||||
thread_message_completed: Option<ID>,
|
||||
}
|
||||
|
||||
impl ThreadRunItem {
|
||||
pub fn thread_message_attachments_code(code: Vec<MessageAttachmentCode>) -> Self {
|
||||
Self {
|
||||
thread_created: None,
|
||||
thread_message_created: None,
|
||||
thread_message_attachments_code: Some(code),
|
||||
thread_message_attachments_doc: None,
|
||||
thread_message_relevant_questions: None,
|
||||
thread_message_content_delta: None,
|
||||
thread_message_completed: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn thread_message_relevant_questions(questions: Vec<String>) -> Self {
|
||||
Self {
|
||||
thread_created: None,
|
||||
thread_message_created: None,
|
||||
thread_message_attachments_code: None,
|
||||
thread_message_attachments_doc: None,
|
||||
thread_message_relevant_questions: Some(questions),
|
||||
thread_message_content_delta: None,
|
||||
thread_message_completed: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn thread_message_attachments_doc(doc: Vec<MessageAttachmentDoc>) -> Self {
|
||||
Self {
|
||||
thread_created: None,
|
||||
thread_message_created: None,
|
||||
thread_message_attachments_code: None,
|
||||
thread_message_attachments_doc: Some(doc),
|
||||
thread_message_relevant_questions: None,
|
||||
thread_message_content_delta: None,
|
||||
thread_message_completed: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn thread_message_content_delta(delta: String) -> Self {
|
||||
Self {
|
||||
thread_created: None,
|
||||
thread_message_created: None,
|
||||
thread_message_attachments_code: None,
|
||||
thread_message_attachments_doc: None,
|
||||
thread_message_relevant_questions: None,
|
||||
thread_message_content_delta: Some(delta),
|
||||
thread_message_completed: None,
|
||||
}
|
||||
}
|
||||
}
|
@ -1,20 +1,30 @@
|
||||
use core::panic;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_openai::types::{
|
||||
ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
|
||||
CreateChatCompletionRequestArgs,
|
||||
use anyhow::anyhow;
|
||||
use async_openai::{
|
||||
error::OpenAIError,
|
||||
types::{
|
||||
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
|
||||
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, Role,
|
||||
},
|
||||
};
|
||||
use async_stream::stream;
|
||||
use futures::stream::BoxStream;
|
||||
use tabby_common::api::{
|
||||
answer::{AnswerCodeSnippet, AnswerRequest, AnswerResponseChunk},
|
||||
answer::{AnswerRequest, AnswerResponseChunk},
|
||||
code::{CodeSearch, CodeSearchError, CodeSearchHit, CodeSearchQuery},
|
||||
doc::{DocSearch, DocSearchError, DocSearchHit},
|
||||
};
|
||||
use tabby_inference::ChatCompletionStream;
|
||||
use tabby_schema::{repository::RepositoryService, web_crawler::WebCrawlerService};
|
||||
use tracing::{debug, warn};
|
||||
use tabby_schema::{
|
||||
repository::RepositoryService,
|
||||
thread::{
|
||||
self, CodeQueryInput, DocQueryInput, MessageAttachmentCode, ThreadRunItem,
|
||||
ThreadRunOptionsInput,
|
||||
},
|
||||
web_crawler::WebCrawlerService,
|
||||
};
|
||||
use tracing::{debug, error, warn};
|
||||
|
||||
pub struct AnswerService {
|
||||
chat: Arc<dyn ChatCompletionStream>,
|
||||
@ -54,6 +64,7 @@ impl AnswerService {
|
||||
}
|
||||
}
|
||||
|
||||
#[deprecated(note = "This shall be removed after the migration to v2 is done.")]
|
||||
pub async fn answer<'a>(
|
||||
self: Arc<Self>,
|
||||
mut req: AnswerRequest,
|
||||
@ -82,7 +93,14 @@ impl AnswerService {
|
||||
// Code snippet is extended to the query.
|
||||
self.override_query_with_code_query(query, &code_query).await;
|
||||
}
|
||||
self.collect_relevant_code(code_query).await
|
||||
|
||||
let code_query = CodeQueryInput {
|
||||
git_url: code_query.git_url,
|
||||
filepath: code_query.filepath,
|
||||
language: code_query.language,
|
||||
content: code_query.content,
|
||||
};
|
||||
self.collect_relevant_code(&code_query).await
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
@ -94,7 +112,10 @@ impl AnswerService {
|
||||
|
||||
// 2. Collect relevant docs if needed.
|
||||
let relevant_docs = if req.doc_query {
|
||||
self.collect_relevant_docs(git_url.as_deref(), get_content(query)).await
|
||||
let query = DocQueryInput {
|
||||
content: get_content(query).to_owned(),
|
||||
};
|
||||
self.collect_relevant_docs(git_url.as_deref(), &query).await
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
@ -110,6 +131,10 @@ impl AnswerService {
|
||||
yield AnswerResponseChunk::RelevantQuestions(relevant_questions);
|
||||
}
|
||||
|
||||
let code_snippets: Vec<MessageAttachmentCode> = code_snippets.iter().map(|x| MessageAttachmentCode {
|
||||
filepath: x.filepath.clone(),
|
||||
content: x.content.clone(),
|
||||
}).collect();
|
||||
|
||||
// 4. Generate override prompt from the query
|
||||
set_content(query, self.generate_prompt(&code_snippets, &relevant_code, &relevant_docs, get_content(query)).await);
|
||||
@ -153,7 +178,158 @@ impl AnswerService {
|
||||
Box::pin(s)
|
||||
}
|
||||
|
||||
async fn collect_relevant_code(&self, query: CodeSearchQuery) -> Vec<CodeSearchHit> {
|
||||
pub async fn answer_v2<'a>(
|
||||
self: Arc<Self>,
|
||||
messages: &[tabby_schema::thread::Message],
|
||||
options: &ThreadRunOptionsInput,
|
||||
) -> tabby_schema::Result<BoxStream<'a, tabby_schema::Result<ThreadRunItem>>> {
|
||||
let messages = messages.to_vec();
|
||||
let options = options.clone();
|
||||
|
||||
let s = stream! {
|
||||
let query = match messages.last() {
|
||||
Some(query) => query,
|
||||
None => {
|
||||
yield Err(anyhow!("No query found in the request").into());
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let git_url = options.code_query.as_ref().map(|x| x.git_url.clone());
|
||||
|
||||
// 1. Collect relevant code if needed.
|
||||
let relevant_code = if let Some(code_query) = options.code_query.as_ref() {
|
||||
self.collect_relevant_code(code_query).await
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
relevant_code.is_empty();
|
||||
|
||||
// 2. Collect relevant docs if needed.
|
||||
let relevant_docs = if let Some(doc_query) = options.doc_query.as_ref() {
|
||||
self.collect_relevant_docs(git_url.as_deref(), doc_query)
|
||||
.await
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
if !relevant_docs.is_empty() {
|
||||
yield Ok(ThreadRunItem::thread_message_attachments_code(
|
||||
relevant_code
|
||||
.iter()
|
||||
.map(|x| MessageAttachmentCode {
|
||||
filepath: Some(x.doc.filepath.clone()),
|
||||
content: x.doc.body.clone(),
|
||||
})
|
||||
.collect(),
|
||||
));
|
||||
}
|
||||
|
||||
// 3. Generate relevant questions.
|
||||
if options.generate_relevant_questions {
|
||||
let questions = self
|
||||
.generate_relevant_questions(&relevant_code, &relevant_docs, &query.content)
|
||||
.await;
|
||||
yield Ok(ThreadRunItem::thread_message_relevant_questions(questions));
|
||||
}
|
||||
|
||||
// 4. Prepare requesting LLM
|
||||
let request = {
|
||||
let empty = Vec::default();
|
||||
let code_snippets: &[MessageAttachmentCode] = query
|
||||
.attachments
|
||||
.as_ref()
|
||||
.map(|x| &x.code)
|
||||
.unwrap_or(&empty);
|
||||
|
||||
let override_user_prompt = if !code_snippets.is_empty()
|
||||
|| !relevant_code.is_empty()
|
||||
|| !relevant_docs.is_empty()
|
||||
{
|
||||
self.generate_prompt(
|
||||
code_snippets,
|
||||
&relevant_code,
|
||||
&relevant_docs,
|
||||
&query.content,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
query.content.clone()
|
||||
};
|
||||
|
||||
// Convert `messages` to CreateChatCompletionRequest
|
||||
let chat_messages: Vec<_> = messages
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, x)| {
|
||||
let role = match x.role {
|
||||
thread::Role::Assistant => Role::Assistant,
|
||||
thread::Role::User => Role::User,
|
||||
};
|
||||
|
||||
let is_last = i == messages.len() - 1;
|
||||
let content = if is_last {
|
||||
override_user_prompt.clone()
|
||||
} else {
|
||||
x.content.clone()
|
||||
};
|
||||
|
||||
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
|
||||
content,
|
||||
role,
|
||||
name: None,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
CreateChatCompletionRequestArgs::default()
|
||||
.messages(chat_messages)
|
||||
.presence_penalty(PRESENCE_PENALTY)
|
||||
.build()
|
||||
.expect("Failed to build chat completion request")
|
||||
};
|
||||
|
||||
|
||||
let s = match self.chat.chat_stream(request).await {
|
||||
Ok(s) => s,
|
||||
Err(err) => {
|
||||
warn!("Failed to create chat completion stream: {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
for await chunk in s {
|
||||
let chunk = match chunk {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
if let OpenAIError::StreamError(content) = err {
|
||||
if content == "Stream ended" {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
error!("Failed to get chat completion chunk: {:?}", err);
|
||||
}
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(content) = chunk.choices[0].delta.content.as_deref() {
|
||||
yield Ok(ThreadRunItem::thread_message_content_delta(content.to_owned()));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Box::pin(s))
|
||||
}
|
||||
|
||||
async fn collect_relevant_code(&self, query: &CodeQueryInput) -> Vec<CodeSearchHit> {
|
||||
let query = CodeSearchQuery {
|
||||
git_url: query.git_url.clone(),
|
||||
filepath: query.filepath.clone(),
|
||||
language: query.language.clone(),
|
||||
content: query.content.clone(),
|
||||
};
|
||||
match self.code.search_in_language(query, 20).await {
|
||||
Ok(docs) => docs.hits,
|
||||
Err(err) => {
|
||||
@ -170,7 +346,7 @@ impl AnswerService {
|
||||
async fn collect_relevant_docs(
|
||||
&self,
|
||||
code_query_git_url: Option<&str>,
|
||||
content: &str,
|
||||
doc_query: &DocQueryInput,
|
||||
) -> Vec<DocSearchHit> {
|
||||
let source_ids = {
|
||||
// 1. By default only web sources are considered.
|
||||
@ -203,7 +379,7 @@ impl AnswerService {
|
||||
|
||||
// 1. Collect relevant docs from the tantivy doc search.
|
||||
let mut hits = vec![];
|
||||
let doc_hits = match self.doc.search(&source_ids, content, 5).await {
|
||||
let doc_hits = match self.doc.search(&source_ids, &doc_query.content, 5).await {
|
||||
Ok(docs) => docs.hits,
|
||||
Err(err) => {
|
||||
if let DocSearchError::NotReady = err {
|
||||
@ -218,7 +394,7 @@ impl AnswerService {
|
||||
|
||||
// 2. If serper is available, we also collect from serper
|
||||
if let Some(serper) = self.serper.as_ref() {
|
||||
let serper_hits = match serper.search(&[], content, 5).await {
|
||||
let serper_hits = match serper.search(&[], &doc_query.content, 5).await {
|
||||
Ok(docs) => docs.hits,
|
||||
Err(err) => {
|
||||
warn!("Failed to search serper: {:?}", err);
|
||||
@ -308,7 +484,7 @@ Remember, based on the original question and related contexts, suggest three suc
|
||||
|
||||
async fn generate_prompt(
|
||||
&self,
|
||||
code_snippets: &[AnswerCodeSnippet],
|
||||
code_snippets: &[MessageAttachmentCode],
|
||||
relevant_code: &[CodeSearchHit],
|
||||
relevant_docs: &[DocSearchHit],
|
||||
question: &str,
|
||||
|
@ -9,11 +9,13 @@ pub mod job;
|
||||
mod license;
|
||||
pub mod repository;
|
||||
mod setting;
|
||||
mod thread;
|
||||
mod user_event;
|
||||
pub mod web_crawler;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use answer::AnswerService;
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
@ -38,6 +40,7 @@ use tabby_schema::{
|
||||
license::{IsLicenseValid, LicenseService},
|
||||
repository::RepositoryService,
|
||||
setting::SettingService,
|
||||
thread::ThreadService,
|
||||
user_event::UserEventService,
|
||||
web_crawler::WebCrawlerService,
|
||||
worker::WorkerService,
|
||||
@ -57,6 +60,7 @@ struct ServerContext {
|
||||
user_event: Arc<dyn UserEventService>,
|
||||
job: Arc<dyn JobService>,
|
||||
web_crawler: Arc<dyn WebCrawlerService>,
|
||||
thread: Arc<dyn ThreadService>,
|
||||
|
||||
logger: Arc<dyn EventLogger>,
|
||||
code: Arc<dyn CodeSearch>,
|
||||
@ -74,6 +78,7 @@ impl ServerContext {
|
||||
integration: Arc<dyn IntegrationService>,
|
||||
web_crawler: Arc<dyn WebCrawlerService>,
|
||||
job: Arc<dyn JobService>,
|
||||
answer: Option<Arc<AnswerService>>,
|
||||
db_conn: DbConn,
|
||||
is_chat_enabled_locally: bool,
|
||||
) -> Self {
|
||||
@ -89,6 +94,7 @@ impl ServerContext {
|
||||
);
|
||||
let user_event = Arc::new(user_event::create(db_conn.clone()));
|
||||
let setting = Arc::new(setting::create(db_conn.clone()));
|
||||
let thread = Arc::new(thread::create(answer));
|
||||
|
||||
Self {
|
||||
mail: mail.clone(),
|
||||
@ -99,6 +105,7 @@ impl ServerContext {
|
||||
setting.clone(),
|
||||
)),
|
||||
web_crawler,
|
||||
thread,
|
||||
license,
|
||||
repository,
|
||||
integration,
|
||||
@ -252,6 +259,10 @@ impl ServiceLocator for ArcServerContext {
|
||||
fn web_crawler(&self) -> Arc<dyn WebCrawlerService> {
|
||||
self.0.web_crawler.clone()
|
||||
}
|
||||
|
||||
fn thread(&self) -> Arc<dyn ThreadService> {
|
||||
self.0.thread.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_service_locator(
|
||||
@ -261,6 +272,7 @@ pub async fn create_service_locator(
|
||||
integration: Arc<dyn IntegrationService>,
|
||||
web_crawler: Arc<dyn WebCrawlerService>,
|
||||
job: Arc<dyn JobService>,
|
||||
answer: Option<Arc<AnswerService>>,
|
||||
db: DbConn,
|
||||
is_chat_enabled: bool,
|
||||
) -> Arc<dyn ServiceLocator> {
|
||||
@ -272,6 +284,7 @@ pub async fn create_service_locator(
|
||||
integration,
|
||||
web_crawler,
|
||||
job,
|
||||
answer,
|
||||
db,
|
||||
is_chat_enabled,
|
||||
)
|
||||
|
53
ee/tabby-webserver/src/service/thread.rs
Normal file
53
ee/tabby-webserver/src/service/thread.rs
Normal file
@ -0,0 +1,53 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use juniper::ID;
|
||||
use tabby_schema::{
|
||||
bail,
|
||||
thread::{
|
||||
self, CreateMessageInput, CreateThreadInput, ThreadRunOptionsInput, ThreadRunStream,
|
||||
ThreadService,
|
||||
},
|
||||
Result,
|
||||
};
|
||||
|
||||
use super::answer::AnswerService;
|
||||
|
||||
struct ThreadServiceImpl {
|
||||
answer: Option<Arc<AnswerService>>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ThreadService for ThreadServiceImpl {
|
||||
async fn create(&self, _input: &CreateThreadInput) -> Result<ID> {
|
||||
Ok(ID::new("message:1"))
|
||||
}
|
||||
|
||||
async fn create_run(
|
||||
&self,
|
||||
_id: &ID,
|
||||
options: &ThreadRunOptionsInput,
|
||||
) -> Result<ThreadRunStream> {
|
||||
let Some(answer) = self.answer.clone() else {
|
||||
bail!("Answer service is not available");
|
||||
};
|
||||
|
||||
// FIXME(meng): actual lookup messages from database.
|
||||
let messages = vec![thread::Message {
|
||||
id: ID::new("message:1"),
|
||||
thread_id: ID::new("thread:1"),
|
||||
role: thread::Role::User,
|
||||
content: "Hello, world!".to_string(),
|
||||
attachments: None,
|
||||
}];
|
||||
answer.answer_v2(&messages, options).await
|
||||
}
|
||||
|
||||
async fn append_messages(&self, _id: &ID, _messages: &[CreateMessageInput]) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create(answer: Option<Arc<AnswerService>>) -> impl ThreadService {
|
||||
ThreadServiceImpl { answer }
|
||||
}
|
@ -101,6 +101,17 @@ impl Webserver {
|
||||
docsearch: Arc<dyn DocSearch>,
|
||||
serper_factory_fn: impl Fn(&str) -> Box<dyn DocSearch>,
|
||||
) -> (Router, Router) {
|
||||
let answer = chat.as_ref().map(|chat| {
|
||||
Arc::new(crate::service::answer::create(
|
||||
chat.clone(),
|
||||
code.clone(),
|
||||
docsearch.clone(),
|
||||
self.web_crawler.clone(),
|
||||
self.repository.clone(),
|
||||
serper_factory_fn,
|
||||
))
|
||||
});
|
||||
|
||||
let is_chat_enabled = chat.is_some();
|
||||
let ctx = create_service_locator(
|
||||
self.logger(),
|
||||
@ -109,22 +120,12 @@ impl Webserver {
|
||||
self.integration.clone(),
|
||||
self.web_crawler.clone(),
|
||||
self.job.clone(),
|
||||
answer.clone(),
|
||||
self.db.clone(),
|
||||
is_chat_enabled,
|
||||
)
|
||||
.await;
|
||||
|
||||
let answer = chat.as_ref().map(|chat| {
|
||||
Arc::new(crate::service::answer::create(
|
||||
chat.clone(),
|
||||
code.clone(),
|
||||
docsearch.clone(),
|
||||
ctx.web_crawler().clone(),
|
||||
ctx.repository().clone(),
|
||||
serper_factory_fn,
|
||||
))
|
||||
});
|
||||
|
||||
routes::create(ctx, api, ui, answer)
|
||||
}
|
||||
}
|
||||
|
@ -17,6 +17,13 @@ rule:
|
||||
stopBy: end
|
||||
pattern: message
|
||||
- not:
|
||||
has:
|
||||
any:
|
||||
- has:
|
||||
stopBy: end
|
||||
pattern: custom
|
||||
- has:
|
||||
stopBy: end
|
||||
pattern: nested
|
||||
- has:
|
||||
stopBy: end
|
||||
pattern: schema
|
||||
|
Loading…
Reference in New Issue
Block a user