diff --git a/.changes/unreleased/Features-20240805-162901.yaml b/.changes/unreleased/Features-20240805-162901.yaml new file mode 100644 index 000000000..b572bba09 --- /dev/null +++ b/.changes/unreleased/Features-20240805-162901.yaml @@ -0,0 +1,3 @@ +kind: Features +body: Support persisted thread discussion +time: 2024-08-05T16:29:01.227967-07:00 diff --git a/Cargo.lock b/Cargo.lock index 0d87558c9..612074c6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5206,7 +5206,6 @@ name = "tabby-schema" version = "0.16.0-dev.0" dependencies = [ "anyhow", - "async-stream", "async-trait", "axum", "base64 0.22.1", diff --git a/ee/tabby-schema/Cargo.toml b/ee/tabby-schema/Cargo.toml index 8df6cceee..a375a7d58 100644 --- a/ee/tabby-schema/Cargo.toml +++ b/ee/tabby-schema/Cargo.toml @@ -28,7 +28,6 @@ validator = { version = "0.18.1", features = ["derive"] } regex.workspace = true hash-ids.workspace = true url.workspace = true -async-stream.workspace = true [[example]] name = "update-schema" diff --git a/ee/tabby-schema/graphql/schema.graphql b/ee/tabby-schema/graphql/schema.graphql index ea82c3ad8..ffd974c7e 100644 --- a/ee/tabby-schema/graphql/schema.graphql +++ b/ee/tabby-schema/graphql/schema.graphql @@ -78,6 +78,18 @@ enum RepositoryKind { GITLAB_SELF_HOSTED } +enum Role { + USER + ASSISTANT +} + +input CodeQueryInput { + gitUrl: String! + filepath: String + language: String + content: String! +} + input CreateIntegrationInput { displayName: String! accessToken: String! @@ -85,10 +97,35 @@ input CreateIntegrationInput { apiBase: String } +input CreateMessageInput { + role: Role! + content: String! + attachments: MessageAttachmentInput +} + +input CreateThreadAndRunInput { + thread: CreateThreadInput! + options: ThreadRunOptionsInput! = {codeQuery: null, docQuery: null, generateRelevantQuestions: false} +} + +input CreateThreadInput { + messages: [CreateMessageInput!]! +} + +input CreateThreadRunInput { + threadId: ID! + additionalMessages: [CreateMessageInput!]! + options: ThreadRunOptionsInput! = {codeQuery: null, docQuery: null, generateRelevantQuestions: false} +} + input CreateWebCrawlerUrlInput { url: String! } +input DocQueryInput { + content: String! +} + input EmailSettingInput { smtpUsername: String! fromAddress: String! @@ -99,6 +136,15 @@ input EmailSettingInput { smtpPassword: String } +input MessageAttachmentCodeInput { + filepath: String + content: String! +} + +input MessageAttachmentInput { + code: [MessageAttachmentCodeInput!]! +} + input NetworkSettingInput { externalUrl: String! } @@ -128,6 +174,12 @@ input SecuritySettingInput { disableClientSideTelemetry: Boolean! } +input ThreadRunOptionsInput { + docQuery: DocQueryInput = null + codeQuery: CodeQueryInput = null + generateRelevantQuestions: Boolean! = false +} + input UpdateIntegrationInput { id: ID! displayName: String! @@ -317,6 +369,17 @@ type LicenseInfo { expiresAt: DateTime } +type MessageAttachmentCode { + filepath: String + content: String! +} + +type MessageAttachmentDoc { + title: String! + link: String! + content: String! +} + type Mutation { resetRegistrationToken: String! requestInvitationEmail(input: RequestInvitationInput!): Invitation! @@ -490,7 +553,24 @@ type ServerInfo { } type Subscription { - count: Int! + createThreadAndRun(input: CreateThreadAndRunInput!): ThreadRunItem! + createThreadRun(input: CreateThreadRunInput!): ThreadRunItem! +} + +""" + Schema of thread run stream. + + The event's order is kept as same as the order defined in the struct fields. + Apart from `thread_message_content_delta`, all other items will only appear once in the stream. +""" +type ThreadRunItem { + threadCreated: ID + threadMessageCreated: ID + threadMessageAttachmentsCode: [MessageAttachmentCode!] + threadMessageAttachmentsDoc: [MessageAttachmentDoc!] + threadMessageRelevantQuestions: [String!] + threadMessageContentDelta: String + threadMessageCompleted: ID } type TokenAuthResponse { diff --git a/ee/tabby-schema/src/schema/mod.rs b/ee/tabby-schema/src/schema/mod.rs index fae06a9ef..32918a1bd 100644 --- a/ee/tabby-schema/src/schema/mod.rs +++ b/ee/tabby-schema/src/schema/mod.rs @@ -7,20 +7,19 @@ pub mod job; pub mod license; pub mod repository; pub mod setting; +pub mod thread; pub mod user_event; pub mod web_crawler; pub mod worker; use std::sync::Arc; -use async_stream::stream; use auth::{ AuthenticationService, Invitation, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, User, }; use base64::Engine; use chrono::{DateTime, Utc}; -use futures::stream::BoxStream; use job::{JobRun, JobService}; use juniper::{ graphql_object, graphql_subscription, graphql_value, FieldError, GraphQLObject, IntoFieldError, @@ -28,6 +27,7 @@ use juniper::{ }; use repository::RepositoryGrepOutput; use tabby_common::api::{code::CodeSearch, event::EventLogger}; +use thread::{CreateThreadAndRunInput, CreateThreadRunInput, ThreadRunStream, ThreadService}; use tracing::error; use validator::{Validate, ValidationErrors}; use worker::WorkerService; @@ -71,6 +71,7 @@ pub trait ServiceLocator: Send + Sync { fn analytic(&self) -> Arc; fn user_event(&self) -> Arc; fn web_crawler(&self) -> Arc; + fn thread(&self) -> Arc; } pub struct Context { @@ -914,22 +915,35 @@ fn from_validation_errors(error: ValidationErrors) -> FieldError #[derive(Clone, Copy, Debug)] pub struct Subscription; -type NumberStream = BoxStream<'static, Result>; - #[graphql_subscription] impl Subscription { - // FIXME(meng): This is a temporary subscription to test the subscription feature, we should remove it later. - async fn count(ctx: &Context) -> Result { + async fn create_thread_and_run( + ctx: &Context, + input: CreateThreadAndRunInput, + ) -> Result { check_user(ctx).await?; - let mut value = 0; - let s = stream! { - loop { - value += 1; - yield Ok(value); - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - } - }; - Ok(Box::pin(s)) + input.validate()?; + + let thread = ctx.locator.thread(); + + let thread_id = thread.create(&input.thread).await?; + + thread.create_run(&thread_id, &input.options).await + } + + async fn create_thread_run( + ctx: &Context, + input: CreateThreadRunInput, + ) -> Result { + // check_user(ctx).await?; + input.validate()?; + + let thread = ctx.locator.thread(); + thread + .append_messages(&input.thread_id, &input.additional_messages) + .await?; + + thread.create_run(&input.thread_id, &input.options).await } } diff --git a/ee/tabby-schema/src/schema/thread.rs b/ee/tabby-schema/src/schema/thread.rs new file mode 100644 index 000000000..cbcb356c8 --- /dev/null +++ b/ee/tabby-schema/src/schema/thread.rs @@ -0,0 +1,45 @@ +use async_trait::async_trait; +use futures::stream::BoxStream; +use juniper::ID; + +use crate::schema::Result; + +mod types; +pub use types::*; + +mod inputs; +pub use inputs::*; + +pub type ThreadRunStream = BoxStream<'static, Result>; + +#[async_trait] +pub trait ThreadService: Send + Sync { + /// Create a new thread + async fn create(&self, input: &CreateThreadInput) -> Result; + + /// Create a new thread run + async fn create_run(&self, id: &ID, options: &ThreadRunOptionsInput) + -> Result; + + /// Append messages to an existing thread + async fn append_messages(&self, id: &ID, messages: &[CreateMessageInput]) -> Result<()>; + + // /// Delete a thread by ID + // async fn delete(&self, id: ID) -> Result<()>; + + // /// Create a new message in a thread + // async fn create_message(&self, input: CreateMessageInput) -> Result; + + // /// Delete a message by ID + // async fn delete_message(&self, id: ID) -> Result<()>; + + // /// Query messages in a thread + // async fn list_messages( + // &self, + // thread_id: ID, + // after: Option, + // before: Option, + // first: Option, + // last: Option, + // ) -> Result>; +} diff --git a/ee/tabby-schema/src/schema/thread/inputs.rs b/ee/tabby-schema/src/schema/thread/inputs.rs new file mode 100644 index 000000000..e54973fb5 --- /dev/null +++ b/ee/tabby-schema/src/schema/thread/inputs.rs @@ -0,0 +1,112 @@ +use juniper::{GraphQLInputObject, ID}; +use validator::{Validate, ValidationError}; + +use super::Role; + +#[derive(GraphQLInputObject, Validate)] +#[validate(schema(function = "validate_message_input", skip_on_field_errors = false))] +pub struct CreateMessageInput { + role: Role, + + content: String, + + #[validate(nested)] + attachments: Option, +} + +#[derive(GraphQLInputObject, Validate)] +#[validate(schema(function = "validate_thread_input", skip_on_field_errors = false))] +pub struct CreateThreadInput { + #[validate(nested)] + messages: Vec, +} + +#[derive(GraphQLInputObject, Validate)] +pub struct CreateThreadAndRunInput { + #[validate(nested)] + pub thread: CreateThreadInput, + + #[validate(nested)] + #[graphql(default)] + pub options: ThreadRunOptionsInput, +} + +#[derive(GraphQLInputObject, Validate, Clone)] +pub struct DocQueryInput { + pub content: String, +} + +#[derive(GraphQLInputObject, Validate, Clone)] +pub struct CodeQueryInput { + pub git_url: String, + pub filepath: Option, + pub language: Option, + pub content: String, +} + +#[derive(GraphQLInputObject, Validate, Default, Clone)] +pub struct ThreadRunOptionsInput { + #[validate(nested)] + #[graphql(default)] + pub doc_query: Option, + + #[validate(nested)] + #[graphql(default)] + pub code_query: Option, + + #[graphql(default)] + pub generate_relevant_questions: bool, +} + +#[derive(GraphQLInputObject, Validate)] +pub struct CreateThreadRunInput { + pub thread_id: ID, + + #[validate(nested)] + pub additional_messages: Vec, + + #[validate(nested)] + #[graphql(default)] + pub options: ThreadRunOptionsInput, +} + +#[derive(GraphQLInputObject, Validate)] +pub struct MessageAttachmentInput { + #[validate(nested)] + code: Vec, +} + +#[derive(GraphQLInputObject, Validate)] +pub struct MessageAttachmentCodeInput { + pub filepath: Option, + + pub content: String, +} + +fn validate_message_input(input: &CreateMessageInput) -> Result<(), ValidationError> { + if let Role::Assistant = input.role { + if input.attachments.is_some() { + return Err(ValidationError::new( + "Attachments are not allowed for assistants", + )); + } + } + + Ok(()) +} + +fn validate_thread_input(input: &CreateThreadInput) -> Result<(), ValidationError> { + let messages = &input.messages; + let length = messages.len(); + + for (i, message) in messages.iter().enumerate() { + let is_last = i == length - 1; + if !is_last && message.attachments.is_some() { + return Err(ValidationError::new( + "Attachments are only allowed on the last message", + )); + } + } + + Ok(()) +} diff --git a/ee/tabby-schema/src/schema/thread/types.rs b/ee/tabby-schema/src/schema/thread/types.rs new file mode 100644 index 000000000..01c4e40f5 --- /dev/null +++ b/ee/tabby-schema/src/schema/thread/types.rs @@ -0,0 +1,102 @@ +use juniper::{GraphQLEnum, GraphQLObject, ID}; +use serde::Serialize; + +#[derive(GraphQLEnum, Serialize, Clone)] +pub enum Role { + User, + Assistant, +} + +#[derive(GraphQLObject, Clone)] +pub struct Message { + pub id: ID, + pub thread_id: ID, + pub role: Role, + pub content: String, + + pub attachments: Option, +} + +#[derive(GraphQLObject, Clone)] +pub struct MessageAttachment { + pub code: Vec, + pub doc: Vec, +} + +#[derive(GraphQLObject, Clone)] +pub struct MessageAttachmentCode { + pub filepath: Option, + pub content: String, +} + +#[derive(GraphQLObject, Clone)] +pub struct MessageAttachmentDoc { + pub title: String, + pub link: String, + pub content: String, +} + +/// Schema of thread run stream. +/// +/// The event's order is kept as same as the order defined in the struct fields. +/// Apart from `thread_message_content_delta`, all other items will only appear once in the stream. +#[derive(GraphQLObject)] +pub struct ThreadRunItem { + thread_created: Option, + thread_message_created: Option, + thread_message_attachments_code: Option>, + thread_message_attachments_doc: Option>, + thread_message_relevant_questions: Option>, + thread_message_content_delta: Option, + thread_message_completed: Option, +} + +impl ThreadRunItem { + pub fn thread_message_attachments_code(code: Vec) -> Self { + Self { + thread_created: None, + thread_message_created: None, + thread_message_attachments_code: Some(code), + thread_message_attachments_doc: None, + thread_message_relevant_questions: None, + thread_message_content_delta: None, + thread_message_completed: None, + } + } + + pub fn thread_message_relevant_questions(questions: Vec) -> Self { + Self { + thread_created: None, + thread_message_created: None, + thread_message_attachments_code: None, + thread_message_attachments_doc: None, + thread_message_relevant_questions: Some(questions), + thread_message_content_delta: None, + thread_message_completed: None, + } + } + + pub fn thread_message_attachments_doc(doc: Vec) -> Self { + Self { + thread_created: None, + thread_message_created: None, + thread_message_attachments_code: None, + thread_message_attachments_doc: Some(doc), + thread_message_relevant_questions: None, + thread_message_content_delta: None, + thread_message_completed: None, + } + } + + pub fn thread_message_content_delta(delta: String) -> Self { + Self { + thread_created: None, + thread_message_created: None, + thread_message_attachments_code: None, + thread_message_attachments_doc: None, + thread_message_relevant_questions: None, + thread_message_content_delta: Some(delta), + thread_message_completed: None, + } + } +} diff --git a/ee/tabby-webserver/src/service/answer.rs b/ee/tabby-webserver/src/service/answer.rs index 725c75181..13d398c70 100644 --- a/ee/tabby-webserver/src/service/answer.rs +++ b/ee/tabby-webserver/src/service/answer.rs @@ -1,20 +1,30 @@ -use core::panic; use std::sync::Arc; -use async_openai::types::{ - ChatCompletionRequestMessage, ChatCompletionRequestUserMessageArgs, - CreateChatCompletionRequestArgs, +use anyhow::anyhow; +use async_openai::{ + error::OpenAIError, + types::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, Role, + }, }; use async_stream::stream; use futures::stream::BoxStream; use tabby_common::api::{ - answer::{AnswerCodeSnippet, AnswerRequest, AnswerResponseChunk}, + answer::{AnswerRequest, AnswerResponseChunk}, code::{CodeSearch, CodeSearchError, CodeSearchHit, CodeSearchQuery}, doc::{DocSearch, DocSearchError, DocSearchHit}, }; use tabby_inference::ChatCompletionStream; -use tabby_schema::{repository::RepositoryService, web_crawler::WebCrawlerService}; -use tracing::{debug, warn}; +use tabby_schema::{ + repository::RepositoryService, + thread::{ + self, CodeQueryInput, DocQueryInput, MessageAttachmentCode, ThreadRunItem, + ThreadRunOptionsInput, + }, + web_crawler::WebCrawlerService, +}; +use tracing::{debug, error, warn}; pub struct AnswerService { chat: Arc, @@ -54,6 +64,7 @@ impl AnswerService { } } + #[deprecated(note = "This shall be removed after the migration to v2 is done.")] pub async fn answer<'a>( self: Arc, mut req: AnswerRequest, @@ -82,7 +93,14 @@ impl AnswerService { // Code snippet is extended to the query. self.override_query_with_code_query(query, &code_query).await; } - self.collect_relevant_code(code_query).await + + let code_query = CodeQueryInput { + git_url: code_query.git_url, + filepath: code_query.filepath, + language: code_query.language, + content: code_query.content, + }; + self.collect_relevant_code(&code_query).await } else { vec![] }; @@ -94,7 +112,10 @@ impl AnswerService { // 2. Collect relevant docs if needed. let relevant_docs = if req.doc_query { - self.collect_relevant_docs(git_url.as_deref(), get_content(query)).await + let query = DocQueryInput { + content: get_content(query).to_owned(), + }; + self.collect_relevant_docs(git_url.as_deref(), &query).await } else { vec![] }; @@ -110,6 +131,10 @@ impl AnswerService { yield AnswerResponseChunk::RelevantQuestions(relevant_questions); } + let code_snippets: Vec = code_snippets.iter().map(|x| MessageAttachmentCode { + filepath: x.filepath.clone(), + content: x.content.clone(), + }).collect(); // 4. Generate override prompt from the query set_content(query, self.generate_prompt(&code_snippets, &relevant_code, &relevant_docs, get_content(query)).await); @@ -153,7 +178,158 @@ impl AnswerService { Box::pin(s) } - async fn collect_relevant_code(&self, query: CodeSearchQuery) -> Vec { + pub async fn answer_v2<'a>( + self: Arc, + messages: &[tabby_schema::thread::Message], + options: &ThreadRunOptionsInput, + ) -> tabby_schema::Result>> { + let messages = messages.to_vec(); + let options = options.clone(); + + let s = stream! { + let query = match messages.last() { + Some(query) => query, + None => { + yield Err(anyhow!("No query found in the request").into()); + return; + } + }; + + let git_url = options.code_query.as_ref().map(|x| x.git_url.clone()); + + // 1. Collect relevant code if needed. + let relevant_code = if let Some(code_query) = options.code_query.as_ref() { + self.collect_relevant_code(code_query).await + } else { + vec![] + }; + + relevant_code.is_empty(); + + // 2. Collect relevant docs if needed. + let relevant_docs = if let Some(doc_query) = options.doc_query.as_ref() { + self.collect_relevant_docs(git_url.as_deref(), doc_query) + .await + } else { + vec![] + }; + + if !relevant_docs.is_empty() { + yield Ok(ThreadRunItem::thread_message_attachments_code( + relevant_code + .iter() + .map(|x| MessageAttachmentCode { + filepath: Some(x.doc.filepath.clone()), + content: x.doc.body.clone(), + }) + .collect(), + )); + } + + // 3. Generate relevant questions. + if options.generate_relevant_questions { + let questions = self + .generate_relevant_questions(&relevant_code, &relevant_docs, &query.content) + .await; + yield Ok(ThreadRunItem::thread_message_relevant_questions(questions)); + } + + // 4. Prepare requesting LLM + let request = { + let empty = Vec::default(); + let code_snippets: &[MessageAttachmentCode] = query + .attachments + .as_ref() + .map(|x| &x.code) + .unwrap_or(&empty); + + let override_user_prompt = if !code_snippets.is_empty() + || !relevant_code.is_empty() + || !relevant_docs.is_empty() + { + self.generate_prompt( + code_snippets, + &relevant_code, + &relevant_docs, + &query.content, + ) + .await + } else { + query.content.clone() + }; + + // Convert `messages` to CreateChatCompletionRequest + let chat_messages: Vec<_> = messages + .iter() + .enumerate() + .map(|(i, x)| { + let role = match x.role { + thread::Role::Assistant => Role::Assistant, + thread::Role::User => Role::User, + }; + + let is_last = i == messages.len() - 1; + let content = if is_last { + override_user_prompt.clone() + } else { + x.content.clone() + }; + + ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { + content, + role, + name: None, + }) + }) + .collect(); + + CreateChatCompletionRequestArgs::default() + .messages(chat_messages) + .presence_penalty(PRESENCE_PENALTY) + .build() + .expect("Failed to build chat completion request") + }; + + + let s = match self.chat.chat_stream(request).await { + Ok(s) => s, + Err(err) => { + warn!("Failed to create chat completion stream: {:?}", err); + return; + } + }; + + for await chunk in s { + let chunk = match chunk { + Ok(chunk) => chunk, + Err(err) => { + if let OpenAIError::StreamError(content) = err { + if content == "Stream ended" { + break; + } + } else { + error!("Failed to get chat completion chunk: {:?}", err); + } + break; + } + }; + + if let Some(content) = chunk.choices[0].delta.content.as_deref() { + yield Ok(ThreadRunItem::thread_message_content_delta(content.to_owned())); + } + } + }; + + Ok(Box::pin(s)) + } + + async fn collect_relevant_code(&self, query: &CodeQueryInput) -> Vec { + let query = CodeSearchQuery { + git_url: query.git_url.clone(), + filepath: query.filepath.clone(), + language: query.language.clone(), + content: query.content.clone(), + }; match self.code.search_in_language(query, 20).await { Ok(docs) => docs.hits, Err(err) => { @@ -170,7 +346,7 @@ impl AnswerService { async fn collect_relevant_docs( &self, code_query_git_url: Option<&str>, - content: &str, + doc_query: &DocQueryInput, ) -> Vec { let source_ids = { // 1. By default only web sources are considered. @@ -203,7 +379,7 @@ impl AnswerService { // 1. Collect relevant docs from the tantivy doc search. let mut hits = vec![]; - let doc_hits = match self.doc.search(&source_ids, content, 5).await { + let doc_hits = match self.doc.search(&source_ids, &doc_query.content, 5).await { Ok(docs) => docs.hits, Err(err) => { if let DocSearchError::NotReady = err { @@ -218,7 +394,7 @@ impl AnswerService { // 2. If serper is available, we also collect from serper if let Some(serper) = self.serper.as_ref() { - let serper_hits = match serper.search(&[], content, 5).await { + let serper_hits = match serper.search(&[], &doc_query.content, 5).await { Ok(docs) => docs.hits, Err(err) => { warn!("Failed to search serper: {:?}", err); @@ -308,7 +484,7 @@ Remember, based on the original question and related contexts, suggest three suc async fn generate_prompt( &self, - code_snippets: &[AnswerCodeSnippet], + code_snippets: &[MessageAttachmentCode], relevant_code: &[CodeSearchHit], relevant_docs: &[DocSearchHit], question: &str, diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index ed09460e4..b4a71dde6 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -9,11 +9,13 @@ pub mod job; mod license; pub mod repository; mod setting; +mod thread; mod user_event; pub mod web_crawler; use std::sync::Arc; +use answer::AnswerService; use async_trait::async_trait; use axum::{ body::Body, @@ -38,6 +40,7 @@ use tabby_schema::{ license::{IsLicenseValid, LicenseService}, repository::RepositoryService, setting::SettingService, + thread::ThreadService, user_event::UserEventService, web_crawler::WebCrawlerService, worker::WorkerService, @@ -57,6 +60,7 @@ struct ServerContext { user_event: Arc, job: Arc, web_crawler: Arc, + thread: Arc, logger: Arc, code: Arc, @@ -74,6 +78,7 @@ impl ServerContext { integration: Arc, web_crawler: Arc, job: Arc, + answer: Option>, db_conn: DbConn, is_chat_enabled_locally: bool, ) -> Self { @@ -89,6 +94,7 @@ impl ServerContext { ); let user_event = Arc::new(user_event::create(db_conn.clone())); let setting = Arc::new(setting::create(db_conn.clone())); + let thread = Arc::new(thread::create(answer)); Self { mail: mail.clone(), @@ -99,6 +105,7 @@ impl ServerContext { setting.clone(), )), web_crawler, + thread, license, repository, integration, @@ -252,6 +259,10 @@ impl ServiceLocator for ArcServerContext { fn web_crawler(&self) -> Arc { self.0.web_crawler.clone() } + + fn thread(&self) -> Arc { + self.0.thread.clone() + } } pub async fn create_service_locator( @@ -261,6 +272,7 @@ pub async fn create_service_locator( integration: Arc, web_crawler: Arc, job: Arc, + answer: Option>, db: DbConn, is_chat_enabled: bool, ) -> Arc { @@ -272,6 +284,7 @@ pub async fn create_service_locator( integration, web_crawler, job, + answer, db, is_chat_enabled, ) diff --git a/ee/tabby-webserver/src/service/thread.rs b/ee/tabby-webserver/src/service/thread.rs new file mode 100644 index 000000000..7b1922268 --- /dev/null +++ b/ee/tabby-webserver/src/service/thread.rs @@ -0,0 +1,53 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use juniper::ID; +use tabby_schema::{ + bail, + thread::{ + self, CreateMessageInput, CreateThreadInput, ThreadRunOptionsInput, ThreadRunStream, + ThreadService, + }, + Result, +}; + +use super::answer::AnswerService; + +struct ThreadServiceImpl { + answer: Option>, +} + +#[async_trait] +impl ThreadService for ThreadServiceImpl { + async fn create(&self, _input: &CreateThreadInput) -> Result { + Ok(ID::new("message:1")) + } + + async fn create_run( + &self, + _id: &ID, + options: &ThreadRunOptionsInput, + ) -> Result { + let Some(answer) = self.answer.clone() else { + bail!("Answer service is not available"); + }; + + // FIXME(meng): actual lookup messages from database. + let messages = vec![thread::Message { + id: ID::new("message:1"), + thread_id: ID::new("thread:1"), + role: thread::Role::User, + content: "Hello, world!".to_string(), + attachments: None, + }]; + answer.answer_v2(&messages, options).await + } + + async fn append_messages(&self, _id: &ID, _messages: &[CreateMessageInput]) -> Result<()> { + Ok(()) + } +} + +pub fn create(answer: Option>) -> impl ThreadService { + ThreadServiceImpl { answer } +} diff --git a/ee/tabby-webserver/src/webserver.rs b/ee/tabby-webserver/src/webserver.rs index c0d290583..16e794711 100644 --- a/ee/tabby-webserver/src/webserver.rs +++ b/ee/tabby-webserver/src/webserver.rs @@ -101,6 +101,17 @@ impl Webserver { docsearch: Arc, serper_factory_fn: impl Fn(&str) -> Box, ) -> (Router, Router) { + let answer = chat.as_ref().map(|chat| { + Arc::new(crate::service::answer::create( + chat.clone(), + code.clone(), + docsearch.clone(), + self.web_crawler.clone(), + self.repository.clone(), + serper_factory_fn, + )) + }); + let is_chat_enabled = chat.is_some(); let ctx = create_service_locator( self.logger(), @@ -109,22 +120,12 @@ impl Webserver { self.integration.clone(), self.web_crawler.clone(), self.job.clone(), + answer.clone(), self.db.clone(), is_chat_enabled, ) .await; - let answer = chat.as_ref().map(|chat| { - Arc::new(crate::service::answer::create( - chat.clone(), - code.clone(), - docsearch.clone(), - ctx.web_crawler().clone(), - ctx.repository().clone(), - serper_factory_fn, - )) - }); - routes::create(ctx, api, ui, answer) } } diff --git a/rules/validate-requires-code.yml b/rules/validate-requires-code.yml index aab8c6b10..ca74dfffd 100644 --- a/rules/validate-requires-code.yml +++ b/rules/validate-requires-code.yml @@ -17,6 +17,13 @@ rule: stopBy: end pattern: message - not: - has: - stopBy: end - pattern: custom + any: + - has: + stopBy: end + pattern: custom + - has: + stopBy: end + pattern: nested + - has: + stopBy: end + pattern: schema