feat(webserver): support persisted thread in answer engine (#2793)

* init commit

* [autofix.ci] apply automated fixes

* draft create_thread_run_api

* cleanup

* update

* update

* update

* update

* update

* update

* connect streaming

* update

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* add back answer api for ease of migration

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Meng Zhang 2024-08-09 15:29:19 -07:00 committed by GitHub
parent 13a5f8bcd9
commit 5c9b62faa3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 650 additions and 46 deletions

View File

@ -0,0 +1,3 @@
kind: Features
body: Support persisted thread discussion
time: 2024-08-05T16:29:01.227967-07:00

1
Cargo.lock generated
View File

@ -5206,7 +5206,6 @@ name = "tabby-schema"
version = "0.16.0-dev.0"
dependencies = [
"anyhow",
"async-stream",
"async-trait",
"axum",
"base64 0.22.1",

View File

@ -28,7 +28,6 @@ validator = { version = "0.18.1", features = ["derive"] }
regex.workspace = true
hash-ids.workspace = true
url.workspace = true
async-stream.workspace = true
[[example]]
name = "update-schema"

View File

@ -78,6 +78,18 @@ enum RepositoryKind {
GITLAB_SELF_HOSTED
}
enum Role {
USER
ASSISTANT
}
input CodeQueryInput {
gitUrl: String!
filepath: String
language: String
content: String!
}
input CreateIntegrationInput {
displayName: String!
accessToken: String!
@ -85,10 +97,35 @@ input CreateIntegrationInput {
apiBase: String
}
input CreateMessageInput {
role: Role!
content: String!
attachments: MessageAttachmentInput
}
input CreateThreadAndRunInput {
thread: CreateThreadInput!
options: ThreadRunOptionsInput! = {codeQuery: null, docQuery: null, generateRelevantQuestions: false}
}
input CreateThreadInput {
messages: [CreateMessageInput!]!
}
input CreateThreadRunInput {
threadId: ID!
additionalMessages: [CreateMessageInput!]!
options: ThreadRunOptionsInput! = {codeQuery: null, docQuery: null, generateRelevantQuestions: false}
}
input CreateWebCrawlerUrlInput {
url: String!
}
input DocQueryInput {
content: String!
}
input EmailSettingInput {
smtpUsername: String!
fromAddress: String!
@ -99,6 +136,15 @@ input EmailSettingInput {
smtpPassword: String
}
input MessageAttachmentCodeInput {
filepath: String
content: String!
}
input MessageAttachmentInput {
code: [MessageAttachmentCodeInput!]!
}
input NetworkSettingInput {
externalUrl: String!
}
@ -128,6 +174,12 @@ input SecuritySettingInput {
disableClientSideTelemetry: Boolean!
}
input ThreadRunOptionsInput {
docQuery: DocQueryInput = null
codeQuery: CodeQueryInput = null
generateRelevantQuestions: Boolean! = false
}
input UpdateIntegrationInput {
id: ID!
displayName: String!
@ -317,6 +369,17 @@ type LicenseInfo {
expiresAt: DateTime
}
type MessageAttachmentCode {
filepath: String
content: String!
}
type MessageAttachmentDoc {
title: String!
link: String!
content: String!
}
type Mutation {
resetRegistrationToken: String!
requestInvitationEmail(input: RequestInvitationInput!): Invitation!
@ -490,7 +553,24 @@ type ServerInfo {
}
type Subscription {
count: Int!
createThreadAndRun(input: CreateThreadAndRunInput!): ThreadRunItem!
createThreadRun(input: CreateThreadRunInput!): ThreadRunItem!
}
"""
Schema of thread run stream.
The event's order is kept as same as the order defined in the struct fields.
Apart from `thread_message_content_delta`, all other items will only appear once in the stream.
"""
type ThreadRunItem {
threadCreated: ID
threadMessageCreated: ID
threadMessageAttachmentsCode: [MessageAttachmentCode!]
threadMessageAttachmentsDoc: [MessageAttachmentDoc!]
threadMessageRelevantQuestions: [String!]
threadMessageContentDelta: String
threadMessageCompleted: ID
}
type TokenAuthResponse {

View File

@ -7,20 +7,19 @@ pub mod job;
pub mod license;
pub mod repository;
pub mod setting;
pub mod thread;
pub mod user_event;
pub mod web_crawler;
pub mod worker;
use std::sync::Arc;
use async_stream::stream;
use auth::{
AuthenticationService, Invitation, RefreshTokenResponse, RegisterResponse, TokenAuthResponse,
User,
};
use base64::Engine;
use chrono::{DateTime, Utc};
use futures::stream::BoxStream;
use job::{JobRun, JobService};
use juniper::{
graphql_object, graphql_subscription, graphql_value, FieldError, GraphQLObject, IntoFieldError,
@ -28,6 +27,7 @@ use juniper::{
};
use repository::RepositoryGrepOutput;
use tabby_common::api::{code::CodeSearch, event::EventLogger};
use thread::{CreateThreadAndRunInput, CreateThreadRunInput, ThreadRunStream, ThreadService};
use tracing::error;
use validator::{Validate, ValidationErrors};
use worker::WorkerService;
@ -71,6 +71,7 @@ pub trait ServiceLocator: Send + Sync {
fn analytic(&self) -> Arc<dyn AnalyticService>;
fn user_event(&self) -> Arc<dyn UserEventService>;
fn web_crawler(&self) -> Arc<dyn WebCrawlerService>;
fn thread(&self) -> Arc<dyn ThreadService>;
}
pub struct Context {
@ -914,22 +915,35 @@ fn from_validation_errors<S: ScalarValue>(error: ValidationErrors) -> FieldError
#[derive(Clone, Copy, Debug)]
pub struct Subscription;
type NumberStream = BoxStream<'static, Result<i32, FieldError>>;
#[graphql_subscription]
impl Subscription {
// FIXME(meng): This is a temporary subscription to test the subscription feature, we should remove it later.
async fn count(ctx: &Context) -> Result<NumberStream> {
async fn create_thread_and_run(
ctx: &Context,
input: CreateThreadAndRunInput,
) -> Result<ThreadRunStream> {
check_user(ctx).await?;
let mut value = 0;
let s = stream! {
loop {
value += 1;
yield Ok(value);
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
input.validate()?;
let thread = ctx.locator.thread();
let thread_id = thread.create(&input.thread).await?;
thread.create_run(&thread_id, &input.options).await
}
};
Ok(Box::pin(s))
async fn create_thread_run(
ctx: &Context,
input: CreateThreadRunInput,
) -> Result<ThreadRunStream> {
// check_user(ctx).await?;
input.validate()?;
let thread = ctx.locator.thread();
thread
.append_messages(&input.thread_id, &input.additional_messages)
.await?;
thread.create_run(&input.thread_id, &input.options).await
}
}

View File

@ -0,0 +1,45 @@
use async_trait::async_trait;
use futures::stream::BoxStream;
use juniper::ID;
use crate::schema::Result;
mod types;
pub use types::*;
mod inputs;
pub use inputs::*;
pub type ThreadRunStream = BoxStream<'static, Result<ThreadRunItem>>;
/// Service interface for persisted discussion threads.
///
/// Implementations are shared behind `Arc<dyn ThreadService>`, hence the `Send + Sync` bound.
/// All methods report failures through the schema-level `Result` type.
#[async_trait]
pub trait ThreadService: Send + Sync {
/// Create a new thread from `input` and return the ID of the created thread.
async fn create(&self, input: &CreateThreadInput) -> Result<ID>;
/// Create a new run on the thread identified by `id`.
///
/// `options` selects the optional steps of the run (code query, doc query, relevant
/// questions). On success the returned stream yields `Result<ThreadRunItem>` events.
async fn create_run(&self, id: &ID, options: &ThreadRunOptionsInput)
-> Result<ThreadRunStream>;
/// Append `messages` to the existing thread identified by `id`.
async fn append_messages(&self, id: &ID, messages: &[CreateMessageInput]) -> Result<()>;
// NOTE(review): the commented-out signatures below sketch operations that are not part of
// the trait yet; they are kept as-is for reference.
// /// Delete a thread by ID
// async fn delete(&self, id: ID) -> Result<()>;
// /// Create a new message in a thread
// async fn create_message(&self, input: CreateMessageInput) -> Result<ID>;
// /// Delete a message by ID
// async fn delete_message(&self, id: ID) -> Result<()>;
// /// Query messages in a thread
// async fn list_messages(
// &self,
// thread_id: ID,
// after: Option<String>,
// before: Option<String>,
// first: Option<usize>,
// last: Option<usize>,
// ) -> Result<Vec<Message>>;
}

View File

@ -0,0 +1,112 @@
use juniper::{GraphQLInputObject, ID};
use validator::{Validate, ValidationError};
use super::Role;
// Client-supplied message to be stored in a thread.
//
// Besides the nested validation of `attachments`, the whole struct is checked by
// `validate_message_input`, which rejects attachments on assistant messages.
// `skip_on_field_errors = false` asks the validator to run that struct-level check even
// when a field-level check has already failed.
//
// Kept as plain `//` comments so the generated GraphQL descriptions are not affected.
#[derive(GraphQLInputObject, Validate)]
#[validate(schema(function = "validate_message_input", skip_on_field_errors = false))]
pub struct CreateMessageInput {
// Author of the message (user or assistant).
role: Role,
// Text body of the message.
content: String,
// Optional attachments; only accepted when `role` is not `Role::Assistant`.
#[validate(nested)]
attachments: Option<MessageAttachmentInput>,
}
// Input for creating a thread: the initial list of messages.
//
// Each message is validated individually (`nested`), and `validate_thread_input` additionally
// requires that only the last message carries attachments.
//
// Kept as plain `//` comments so the generated GraphQL descriptions are not affected.
#[derive(GraphQLInputObject, Validate)]
#[validate(schema(function = "validate_thread_input", skip_on_field_errors = false))]
pub struct CreateThreadInput {
// Initial messages of the thread, in order.
#[validate(nested)]
messages: Vec<CreateMessageInput>,
}
#[derive(GraphQLInputObject, Validate)]
pub struct CreateThreadAndRunInput {
#[validate(nested)]
pub thread: CreateThreadInput,
#[validate(nested)]
#[graphql(default)]
pub options: ThreadRunOptionsInput,
}
#[derive(GraphQLInputObject, Validate, Clone)]
pub struct DocQueryInput {
pub content: String,
}
#[derive(GraphQLInputObject, Validate, Clone)]
pub struct CodeQueryInput {
pub git_url: String,
pub filepath: Option<String>,
pub language: Option<String>,
pub content: String,
}
// Optional steps to perform during a thread run.
//
// The derived `Default` leaves both queries unset and `generate_relevant_questions` false.
//
// Kept as plain `//` comments so the generated GraphQL descriptions are not affected.
#[derive(GraphQLInputObject, Validate, Default, Clone)]
pub struct ThreadRunOptionsInput {
// When set, a document search is performed with this query.
#[validate(nested)]
#[graphql(default)]
pub doc_query: Option<DocQueryInput>,
// When set, a code search is performed with this query.
#[validate(nested)]
#[graphql(default)]
pub code_query: Option<CodeQueryInput>,
// Whether follow-up questions should be generated for the run.
#[graphql(default)]
pub generate_relevant_questions: bool,
}
#[derive(GraphQLInputObject, Validate)]
pub struct CreateThreadRunInput {
pub thread_id: ID,
#[validate(nested)]
pub additional_messages: Vec<CreateMessageInput>,
#[validate(nested)]
#[graphql(default)]
pub options: ThreadRunOptionsInput,
}
#[derive(GraphQLInputObject, Validate)]
pub struct MessageAttachmentInput {
#[validate(nested)]
code: Vec<MessageAttachmentCodeInput>,
}
#[derive(GraphQLInputObject, Validate)]
pub struct MessageAttachmentCodeInput {
pub filepath: Option<String>,
pub content: String,
}
/// Struct-level validation for [`CreateMessageInput`].
///
/// Fails when a message authored by the assistant carries attachments; every other
/// combination of role and attachments is accepted.
fn validate_message_input(input: &CreateMessageInput) -> Result<(), ValidationError> {
    let from_assistant = matches!(input.role, Role::Assistant);

    if from_assistant && input.attachments.is_some() {
        Err(ValidationError::new(
            "Attachments are not allowed for assistants",
        ))
    } else {
        Ok(())
    }
}
/// Struct-level validation for [`CreateThreadInput`].
///
/// Attachments may only appear on the final message of the thread; an empty message list
/// passes trivially.
fn validate_thread_input(input: &CreateThreadInput) -> Result<(), ValidationError> {
    // Separate the final message from everything that precedes it.
    let Some((_final_message, earlier)) = input.messages.split_last() else {
        return Ok(());
    };

    let misplaced_attachment = earlier
        .iter()
        .any(|message| message.attachments.is_some());

    if misplaced_attachment {
        return Err(ValidationError::new(
            "Attachments are only allowed on the last message",
        ));
    }
    Ok(())
}

View File

@ -0,0 +1,102 @@
use juniper::{GraphQLEnum, GraphQLObject, ID};
use serde::Serialize;
#[derive(GraphQLEnum, Serialize, Clone)]
pub enum Role {
User,
Assistant,
}
#[derive(GraphQLObject, Clone)]
pub struct Message {
pub id: ID,
pub thread_id: ID,
pub role: Role,
pub content: String,
pub attachments: Option<MessageAttachment>,
}
#[derive(GraphQLObject, Clone)]
pub struct MessageAttachment {
pub code: Vec<MessageAttachmentCode>,
pub doc: Vec<MessageAttachmentDoc>,
}
#[derive(GraphQLObject, Clone)]
pub struct MessageAttachmentCode {
pub filepath: Option<String>,
pub content: String,
}
#[derive(GraphQLObject, Clone)]
pub struct MessageAttachmentDoc {
pub title: String,
pub link: String,
pub content: String,
}
/// Schema of thread run stream.
///
/// The event's order is kept as same as the order defined in the struct fields.
/// Apart from `thread_message_content_delta`, all other items will only appear once in the stream.
#[derive(GraphQLObject)]
pub struct ThreadRunItem {
thread_created: Option<ID>,
thread_message_created: Option<ID>,
thread_message_attachments_code: Option<Vec<MessageAttachmentCode>>,
thread_message_attachments_doc: Option<Vec<MessageAttachmentDoc>>,
thread_message_relevant_questions: Option<Vec<String>>,
thread_message_content_delta: Option<String>,
thread_message_completed: Option<ID>,
}
impl ThreadRunItem {
    /// Item with every event slot unset; the public constructors fill in exactly one slot.
    fn empty() -> Self {
        Self {
            thread_created: None,
            thread_message_created: None,
            thread_message_attachments_code: None,
            thread_message_attachments_doc: None,
            thread_message_relevant_questions: None,
            thread_message_content_delta: None,
            thread_message_completed: None,
        }
    }

    /// Builds an item carrying only the code attachments of a message.
    pub fn thread_message_attachments_code(code: Vec<MessageAttachmentCode>) -> Self {
        Self {
            thread_message_attachments_code: Some(code),
            ..Self::empty()
        }
    }

    /// Builds an item carrying only the relevant follow-up questions.
    pub fn thread_message_relevant_questions(questions: Vec<String>) -> Self {
        Self {
            thread_message_relevant_questions: Some(questions),
            ..Self::empty()
        }
    }

    /// Builds an item carrying only the document attachments of a message.
    pub fn thread_message_attachments_doc(doc: Vec<MessageAttachmentDoc>) -> Self {
        Self {
            thread_message_attachments_doc: Some(doc),
            ..Self::empty()
        }
    }

    /// Builds an item carrying only an incremental piece of message content.
    pub fn thread_message_content_delta(delta: String) -> Self {
        Self {
            thread_message_content_delta: Some(delta),
            ..Self::empty()
        }
    }
}

View File

@ -1,20 +1,30 @@
use core::panic;
use std::sync::Arc;
use async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs,
CreateChatCompletionRequestArgs,
use anyhow::anyhow;
use async_openai::{
error::OpenAIError,
types::{
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, Role,
},
};
use async_stream::stream;
use futures::stream::BoxStream;
use tabby_common::api::{
answer::{AnswerCodeSnippet, AnswerRequest, AnswerResponseChunk},
answer::{AnswerRequest, AnswerResponseChunk},
code::{CodeSearch, CodeSearchError, CodeSearchHit, CodeSearchQuery},
doc::{DocSearch, DocSearchError, DocSearchHit},
};
use tabby_inference::ChatCompletionStream;
use tabby_schema::{repository::RepositoryService, web_crawler::WebCrawlerService};
use tracing::{debug, warn};
use tabby_schema::{
repository::RepositoryService,
thread::{
self, CodeQueryInput, DocQueryInput, MessageAttachmentCode, ThreadRunItem,
ThreadRunOptionsInput,
},
web_crawler::WebCrawlerService,
};
use tracing::{debug, error, warn};
pub struct AnswerService {
chat: Arc<dyn ChatCompletionStream>,
@ -54,6 +64,7 @@ impl AnswerService {
}
}
#[deprecated(note = "This shall be removed after the migration to v2 is done.")]
pub async fn answer<'a>(
self: Arc<Self>,
mut req: AnswerRequest,
@ -82,7 +93,14 @@ impl AnswerService {
// Code snippet is extended to the query.
self.override_query_with_code_query(query, &code_query).await;
}
self.collect_relevant_code(code_query).await
let code_query = CodeQueryInput {
git_url: code_query.git_url,
filepath: code_query.filepath,
language: code_query.language,
content: code_query.content,
};
self.collect_relevant_code(&code_query).await
} else {
vec![]
};
@ -94,7 +112,10 @@ impl AnswerService {
// 2. Collect relevant docs if needed.
let relevant_docs = if req.doc_query {
self.collect_relevant_docs(git_url.as_deref(), get_content(query)).await
let query = DocQueryInput {
content: get_content(query).to_owned(),
};
self.collect_relevant_docs(git_url.as_deref(), &query).await
} else {
vec![]
};
@ -110,6 +131,10 @@ impl AnswerService {
yield AnswerResponseChunk::RelevantQuestions(relevant_questions);
}
let code_snippets: Vec<MessageAttachmentCode> = code_snippets.iter().map(|x| MessageAttachmentCode {
filepath: x.filepath.clone(),
content: x.content.clone(),
}).collect();
// 4. Generate override prompt from the query
set_content(query, self.generate_prompt(&code_snippets, &relevant_code, &relevant_docs, get_content(query)).await);
@ -153,7 +178,158 @@ impl AnswerService {
Box::pin(s)
}
async fn collect_relevant_code(&self, query: CodeSearchQuery) -> Vec<CodeSearchHit> {
/// Answers the latest message of a thread, streaming the result as [`ThreadRunItem`]s.
///
/// The last element of `messages` is treated as the query. Depending on `options`, relevant
/// code and documents are searched first and folded into the prompt sent to the chat model;
/// the model's reply is then forwarded as `thread_message_content_delta` items.
///
/// An empty `messages` slice produces a stream holding a single error. A failure to open the
/// chat completion stream is logged and ends the stream without further items.
pub async fn answer_v2<'a>(
    self: Arc<Self>,
    messages: &[tabby_schema::thread::Message],
    options: &ThreadRunOptionsInput,
) -> tabby_schema::Result<BoxStream<'a, tabby_schema::Result<ThreadRunItem>>> {
    // The stream is returned to the caller, so it works on owned copies of its inputs.
    let messages = messages.to_vec();
    let options = options.clone();

    let s = stream! {
        let query = match messages.last() {
            Some(query) => query,
            None => {
                yield Err(anyhow!("No query found in the request").into());
                return;
            }
        };

        let git_url = options.code_query.as_ref().map(|x| x.git_url.clone());

        // 1. Collect relevant code if needed.
        let relevant_code = if let Some(code_query) = options.code_query.as_ref() {
            self.collect_relevant_code(code_query).await
        } else {
            vec![]
        };

        // 2. Collect relevant docs if needed.
        let relevant_docs = if let Some(doc_query) = options.doc_query.as_ref() {
            self.collect_relevant_docs(git_url.as_deref(), doc_query)
                .await
        } else {
            vec![]
        };

        // Code attachments depend on the code search result, not on the doc search result.
        if !relevant_code.is_empty() {
            yield Ok(ThreadRunItem::thread_message_attachments_code(
                relevant_code
                    .iter()
                    .map(|x| MessageAttachmentCode {
                        filepath: Some(x.doc.filepath.clone()),
                        content: x.doc.body.clone(),
                    })
                    .collect(),
            ));
        }

        // NOTE(review): no `thread_message_attachments_doc` item is emitted for
        // `relevant_docs` yet; the layout of the doc search hits is not visible from this
        // function, so that is left for a follow-up.

        // 3. Generate relevant questions.
        if options.generate_relevant_questions {
            let questions = self
                .generate_relevant_questions(&relevant_code, &relevant_docs, &query.content)
                .await;
            yield Ok(ThreadRunItem::thread_message_relevant_questions(questions));
        }

        // 4. Prepare requesting LLM
        let request = {
            let empty = Vec::default();
            let code_snippets: &[MessageAttachmentCode] = query
                .attachments
                .as_ref()
                .map(|x| &x.code)
                .unwrap_or(&empty);

            // Only rewrite the user's prompt when there is some context to inject.
            let override_user_prompt = if !code_snippets.is_empty()
                || !relevant_code.is_empty()
                || !relevant_docs.is_empty()
            {
                self.generate_prompt(
                    code_snippets,
                    &relevant_code,
                    &relevant_docs,
                    &query.content,
                )
                .await
            } else {
                query.content.clone()
            };

            // Convert `messages` to CreateChatCompletionRequest
            let chat_messages: Vec<_> = messages
                .iter()
                .enumerate()
                .map(|(i, x)| {
                    let role = match x.role {
                        thread::Role::Assistant => Role::Assistant,
                        thread::Role::User => Role::User,
                    };

                    // The closure only runs for existing elements, so `len() - 1` cannot
                    // underflow here.
                    let is_last = i == messages.len() - 1;
                    let content = if is_last {
                        override_user_prompt.clone()
                    } else {
                        x.content.clone()
                    };

                    // NOTE(review): every message is wrapped in the `System` variant while
                    // carrying its own `role` — presumably to reuse its plain-string
                    // content payload; confirm this is intended.
                    ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
                        content,
                        role,
                        name: None,
                    })
                })
                .collect();

            CreateChatCompletionRequestArgs::default()
                .messages(chat_messages)
                .presence_penalty(PRESENCE_PENALTY)
                .build()
                .expect("Failed to build chat completion request")
        };

        let s = match self.chat.chat_stream(request).await {
            Ok(s) => s,
            Err(err) => {
                warn!("Failed to create chat completion stream: {:?}", err);
                return;
            }
        };

        for await chunk in s {
            let chunk = match chunk {
                Ok(chunk) => chunk,
                Err(err) => {
                    // A "Stream ended" stream error is treated as regular termination and
                    // is not logged; every other error is reported before stopping.
                    let is_regular_end = matches!(
                        &err,
                        OpenAIError::StreamError(content) if content.as_str() == "Stream ended"
                    );
                    if !is_regular_end {
                        error!("Failed to get chat completion chunk: {:?}", err);
                    }
                    break;
                }
            };

            // A chunk without any choice is skipped instead of panicking on `choices[0]`.
            let delta = chunk
                .choices
                .first()
                .and_then(|choice| choice.delta.content.as_deref());
            if let Some(content) = delta {
                yield Ok(ThreadRunItem::thread_message_content_delta(content.to_owned()));
            }
        }
    };

    Ok(Box::pin(s))
}
async fn collect_relevant_code(&self, query: &CodeQueryInput) -> Vec<CodeSearchHit> {
let query = CodeSearchQuery {
git_url: query.git_url.clone(),
filepath: query.filepath.clone(),
language: query.language.clone(),
content: query.content.clone(),
};
match self.code.search_in_language(query, 20).await {
Ok(docs) => docs.hits,
Err(err) => {
@ -170,7 +346,7 @@ impl AnswerService {
async fn collect_relevant_docs(
&self,
code_query_git_url: Option<&str>,
content: &str,
doc_query: &DocQueryInput,
) -> Vec<DocSearchHit> {
let source_ids = {
// 1. By default only web sources are considered.
@ -203,7 +379,7 @@ impl AnswerService {
// 1. Collect relevant docs from the tantivy doc search.
let mut hits = vec![];
let doc_hits = match self.doc.search(&source_ids, content, 5).await {
let doc_hits = match self.doc.search(&source_ids, &doc_query.content, 5).await {
Ok(docs) => docs.hits,
Err(err) => {
if let DocSearchError::NotReady = err {
@ -218,7 +394,7 @@ impl AnswerService {
// 2. If serper is available, we also collect from serper
if let Some(serper) = self.serper.as_ref() {
let serper_hits = match serper.search(&[], content, 5).await {
let serper_hits = match serper.search(&[], &doc_query.content, 5).await {
Ok(docs) => docs.hits,
Err(err) => {
warn!("Failed to search serper: {:?}", err);
@ -308,7 +484,7 @@ Remember, based on the original question and related contexts, suggest three suc
async fn generate_prompt(
&self,
code_snippets: &[AnswerCodeSnippet],
code_snippets: &[MessageAttachmentCode],
relevant_code: &[CodeSearchHit],
relevant_docs: &[DocSearchHit],
question: &str,

View File

@ -9,11 +9,13 @@ pub mod job;
mod license;
pub mod repository;
mod setting;
mod thread;
mod user_event;
pub mod web_crawler;
use std::sync::Arc;
use answer::AnswerService;
use async_trait::async_trait;
use axum::{
body::Body,
@ -38,6 +40,7 @@ use tabby_schema::{
license::{IsLicenseValid, LicenseService},
repository::RepositoryService,
setting::SettingService,
thread::ThreadService,
user_event::UserEventService,
web_crawler::WebCrawlerService,
worker::WorkerService,
@ -57,6 +60,7 @@ struct ServerContext {
user_event: Arc<dyn UserEventService>,
job: Arc<dyn JobService>,
web_crawler: Arc<dyn WebCrawlerService>,
thread: Arc<dyn ThreadService>,
logger: Arc<dyn EventLogger>,
code: Arc<dyn CodeSearch>,
@ -74,6 +78,7 @@ impl ServerContext {
integration: Arc<dyn IntegrationService>,
web_crawler: Arc<dyn WebCrawlerService>,
job: Arc<dyn JobService>,
answer: Option<Arc<AnswerService>>,
db_conn: DbConn,
is_chat_enabled_locally: bool,
) -> Self {
@ -89,6 +94,7 @@ impl ServerContext {
);
let user_event = Arc::new(user_event::create(db_conn.clone()));
let setting = Arc::new(setting::create(db_conn.clone()));
let thread = Arc::new(thread::create(answer));
Self {
mail: mail.clone(),
@ -99,6 +105,7 @@ impl ServerContext {
setting.clone(),
)),
web_crawler,
thread,
license,
repository,
integration,
@ -252,6 +259,10 @@ impl ServiceLocator for ArcServerContext {
fn web_crawler(&self) -> Arc<dyn WebCrawlerService> {
self.0.web_crawler.clone()
}
fn thread(&self) -> Arc<dyn ThreadService> {
self.0.thread.clone()
}
}
pub async fn create_service_locator(
@ -261,6 +272,7 @@ pub async fn create_service_locator(
integration: Arc<dyn IntegrationService>,
web_crawler: Arc<dyn WebCrawlerService>,
job: Arc<dyn JobService>,
answer: Option<Arc<AnswerService>>,
db: DbConn,
is_chat_enabled: bool,
) -> Arc<dyn ServiceLocator> {
@ -272,6 +284,7 @@ pub async fn create_service_locator(
integration,
web_crawler,
job,
answer,
db,
is_chat_enabled,
)

View File

@ -0,0 +1,53 @@
use std::sync::Arc;
use async_trait::async_trait;
use juniper::ID;
use tabby_schema::{
bail,
thread::{
self, CreateMessageInput, CreateThreadInput, ThreadRunOptionsInput, ThreadRunStream,
ThreadService,
},
Result,
};
use super::answer::AnswerService;
/// `ThreadService` implementation backed by the answer engine.
struct ThreadServiceImpl {
// Answer engine used to produce runs; `None` makes `create_run` fail with an error.
answer: Option<Arc<AnswerService>>,
}
// NOTE: this implementation is still a placeholder. Nothing is persisted yet: `create`
// returns a fixed ID, `append_messages` does nothing, and `create_run` answers a hard-coded
// message instead of the thread's stored messages.
#[async_trait]
impl ThreadService for ThreadServiceImpl {
// Placeholder: ignores the input and always returns the fixed ID "message:1".
async fn create(&self, _input: &CreateThreadInput) -> Result<ID> {
Ok(ID::new("message:1"))
}
// Starts a run through the answer service and returns its stream.
// The thread ID is currently unused.
async fn create_run(
&self,
_id: &ID,
options: &ThreadRunOptionsInput,
) -> Result<ThreadRunStream> {
// Without an answer service there is nothing to run; return an error.
let Some(answer) = self.answer.clone() else {
bail!("Answer service is not available");
};
// FIXME(meng): actual lookup messages from database.
// Until then a single hard-coded user message stands in for the thread's content.
let messages = vec![thread::Message {
id: ID::new("message:1"),
thread_id: ID::new("thread:1"),
role: thread::Role::User,
content: "Hello, world!".to_string(),
attachments: None,
}];
answer.answer_v2(&messages, options).await
}
// Placeholder: accepts the messages but does not store them.
async fn append_messages(&self, _id: &ID, _messages: &[CreateMessageInput]) -> Result<()> {
Ok(())
}
}
/// Builds the thread service around an optional answer engine.
///
/// With `None`, the returned service fails every `create_run` call with an error.
pub fn create(answer: Option<Arc<AnswerService>>) -> impl ThreadService {
ThreadServiceImpl { answer }
}

View File

@ -101,6 +101,17 @@ impl Webserver {
docsearch: Arc<dyn DocSearch>,
serper_factory_fn: impl Fn(&str) -> Box<dyn DocSearch>,
) -> (Router, Router) {
let answer = chat.as_ref().map(|chat| {
Arc::new(crate::service::answer::create(
chat.clone(),
code.clone(),
docsearch.clone(),
self.web_crawler.clone(),
self.repository.clone(),
serper_factory_fn,
))
});
let is_chat_enabled = chat.is_some();
let ctx = create_service_locator(
self.logger(),
@ -109,22 +120,12 @@ impl Webserver {
self.integration.clone(),
self.web_crawler.clone(),
self.job.clone(),
answer.clone(),
self.db.clone(),
is_chat_enabled,
)
.await;
let answer = chat.as_ref().map(|chat| {
Arc::new(crate::service::answer::create(
chat.clone(),
code.clone(),
docsearch.clone(),
ctx.web_crawler().clone(),
ctx.repository().clone(),
serper_factory_fn,
))
});
routes::create(ctx, api, ui, answer)
}
}

View File

@ -17,6 +17,13 @@ rule:
stopBy: end
pattern: message
- not:
has:
any:
- has:
stopBy: end
pattern: custom
- has:
stopBy: end
pattern: nested
- has:
stopBy: end
pattern: schema