From 2b5969b3f3b2d02cec4b48fad670bfe66241dc69 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Thu, 16 May 2024 17:51:28 -0700 Subject: [PATCH] feat: support retrieve documents with serper.dev api (#2154) --- .../unreleased/Features-20240516-173446.yaml | 3 + crates/tabby/Cargo.toml | 1 + crates/tabby/src/services/answer.rs | 43 ++++++++-- crates/tabby/src/services/doc/mod.rs | 15 ++++ crates/tabby/src/services/doc/serper.rs | 86 +++++++++++++++++++ .../src/services/{doc.rs => doc/tantivy.rs} | 8 +- 6 files changed, 142 insertions(+), 14 deletions(-) create mode 100644 .changes/unreleased/Features-20240516-173446.yaml create mode 100644 crates/tabby/src/services/doc/mod.rs create mode 100644 crates/tabby/src/services/doc/serper.rs rename crates/tabby/src/services/{doc.rs => doc/tantivy.rs} (96%) diff --git a/.changes/unreleased/Features-20240516-173446.yaml b/.changes/unreleased/Features-20240516-173446.yaml new file mode 100644 index 000000000..75276cf1f --- /dev/null +++ b/.changes/unreleased/Features-20240516-173446.yaml @@ -0,0 +1,3 @@ +kind: Features +body: support retrieve documents with serper.dev api +time: 2024-05-16T17:34:46.89429-07:00 diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index b08471d66..66baa9719 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -59,6 +59,7 @@ cached = { workspace = true, features = ["async"] } parse-git-url = "0.5.1" color-eyre = { version = "0.6.3" } derive_builder.workspace = true +reqwest.workspace = true [dependencies.openssl] optional = true diff --git a/crates/tabby/src/services/answer.rs b/crates/tabby/src/services/answer.rs index 1afaf1e9b..9ebfbde9f 100644 --- a/crates/tabby/src/services/answer.rs +++ b/crates/tabby/src/services/answer.rs @@ -7,7 +7,7 @@ use tabby_common::api::{ chat::Message, doc::{DocSearch, DocSearchDocument}, }; -use tracing::warn; +use tracing::{debug, warn}; use utoipa::ToSchema; use crate::services::chat::{ChatCompletionRequestBuilder, ChatService}; @@ -32,11 +32,26 @@ pub enum AnswerResponseChunk { pub struct AnswerService { chat: Arc, doc: Arc, + serper: Option>, } impl AnswerService { fn new(chat: Arc, doc: Arc) -> Self { - Self { chat, doc } + if let Ok(api_key) = std::env::var("SERPER_API_KEY") { + debug!("Serper API key found, enabling serper..."); + let serper = Box::new(super::doc::create_serper(api_key.as_str())); + Self { + chat, + doc, + serper: Some(serper), + } + } else { + Self { + chat, + doc, + serper: None, + } + } } pub async fn answer<'a>( @@ -55,18 +70,30 @@ impl AnswerService { // 2. Generate relevant docs from the query // For now we only collect from DocSearch. - let serp = match self.doc.search(&query.content, 20, 0).await { - Ok(docs) => docs, + let mut hits = match self.doc.search(&query.content, 5, 0).await { + Ok(docs) => docs.hits, Err(err) => { - warn!("Failed to search docs: {:?}", err); - return; + warn!("Failed to search tantivy docs: {:?}", err); + vec![] } }; - yield AnswerResponseChunk::RelevantDocuments(serp.hits.iter().map(|hit| hit.doc.clone()).collect()); + // If serper is available, we also collect from serper + if let Some(serper) = self.serper.as_ref() { + let serper_hits = match serper.search(&query.content, 5, 0).await { + Ok(docs) => docs.hits, + Err(err) => { + warn!("Failed to search serper: {:?}", err); + vec![] + } + }; + hits.extend(serper_hits); + } + + yield AnswerResponseChunk::RelevantDocuments(hits.iter().map(|hit| hit.doc.clone()).collect()); // 3. Generate relevant answers from the query - let snippets = serp.hits.iter().map(|hit| hit.doc.snippet.as_str()).collect::>(); + let snippets = hits.iter().map(|hit| hit.doc.snippet.as_str()).collect::>(); yield AnswerResponseChunk::RelevantQuestions(self.generate_relevant_questions(&snippets, &query.content).await); // 4. Generate override prompt from the query diff --git a/crates/tabby/src/services/doc/mod.rs b/crates/tabby/src/services/doc/mod.rs new file mode 100644 index 000000000..9c636647e --- /dev/null +++ b/crates/tabby/src/services/doc/mod.rs @@ -0,0 +1,15 @@ +mod serper; +mod tantivy; + +use std::sync::Arc; + +use tabby_common::api::doc::DocSearch; +use tabby_inference::Embedding; + +pub fn create(embedding: Arc) -> impl DocSearch { + tantivy::DocSearchService::new(embedding) +} + +pub fn create_serper(api_key: &str) -> impl DocSearch { + serper::SerperService::new(api_key) +} diff --git a/crates/tabby/src/services/doc/serper.rs b/crates/tabby/src/services/doc/serper.rs new file mode 100644 index 000000000..7de1507f9 --- /dev/null +++ b/crates/tabby/src/services/doc/serper.rs @@ -0,0 +1,86 @@ +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use tabby_common::api::doc::{ + DocSearch, DocSearchDocument, DocSearchError, DocSearchHit, DocSearchResponse, +}; + +#[derive(Debug, Serialize)] +struct SerperRequest { + q: String, + num: usize, + page: usize, +} + +#[derive(Debug, Deserialize)] +struct SerperResponse { + organic: Vec, +} + +#[derive(Debug, Deserialize)] +struct SerperOrganicHit { + title: String, + snippet: String, + link: String, +} + +pub struct SerperService { + client: reqwest::Client, +} + +impl SerperService { + pub fn new(api_key: &str) -> Self { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + "X-API-KEY", + api_key.parse().expect("Failed to parse Serper API key"), + ); + Self { + client: reqwest::Client::builder() + .default_headers(headers) + .build() + .expect("Failed to create reqwest client"), + } + } +} + +#[async_trait] +impl DocSearch for SerperService { + async fn search( + &self, + q: &str, + limit: usize, + offset: usize, + ) -> Result { + let page = offset / limit; + let request = SerperRequest { + q: q.to_string(), + num: limit, + page, + }; + let response = self + .client + .post("https://google.serper.dev/search") + .json(&request) + .send() + .await + .map_err(|e| DocSearchError::Other(e.into()))? + .json::() + .await + .map_err(|e| DocSearchError::Other(e.into()))?; + + let hits = response + .organic + .into_iter() + .map(|hit| DocSearchHit { + score: 0.0, + doc: DocSearchDocument { + title: hit.title, + link: hit.link, + snippet: hit.snippet, + }, + }) + .collect(); + + Ok(DocSearchResponse { hits }) + } +} diff --git a/crates/tabby/src/services/doc.rs b/crates/tabby/src/services/doc/tantivy.rs similarity index 96% rename from crates/tabby/src/services/doc.rs rename to crates/tabby/src/services/doc/tantivy.rs index 0921f8d11..fd00bfbf9 100644 --- a/crates/tabby/src/services/doc.rs +++ b/crates/tabby/src/services/doc/tantivy.rs @@ -106,7 +106,7 @@ fn get_text(doc: &TantivyDocument, field: schema::Field) -> &str { doc.get_first(field).unwrap().as_str().unwrap() } -struct DocSearchService { +pub struct DocSearchService { search: Arc>>, loader: tokio::task::JoinHandle<()>, } @@ -120,7 +120,7 @@ impl Drop for DocSearchService { } impl DocSearchService { - fn new(embedding: Arc) -> Self { + pub fn new(embedding: Arc) -> Self { let search = Arc::new(RwLock::new(None)); let cloned_search = search.clone(); let loader = tokio::spawn(async move { @@ -150,7 +150,3 @@ impl DocSearch for DocSearchService { } } } - -pub fn create(embedding: Arc) -> impl DocSearch { - DocSearchService::new(embedding) -}