feat: support retrieve documents with serper.dev api (#2154)

This commit is contained in:
Meng Zhang 2024-05-16 17:51:28 -07:00 committed by GitHub
parent 19f3b9eb74
commit 2b5969b3f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 142 additions and 14 deletions

View File

@ -0,0 +1,3 @@
kind: Features
body: support retrieve documents with serper.dev api
time: 2024-05-16T17:34:46.89429-07:00

View File

@ -59,6 +59,7 @@ cached = { workspace = true, features = ["async"] }
parse-git-url = "0.5.1"
color-eyre = { version = "0.6.3" }
derive_builder.workspace = true
reqwest.workspace = true
[dependencies.openssl]
optional = true

View File

@ -7,7 +7,7 @@ use tabby_common::api::{
chat::Message,
doc::{DocSearch, DocSearchDocument},
};
use tracing::warn;
use tracing::{debug, warn};
use utoipa::ToSchema;
use crate::services::chat::{ChatCompletionRequestBuilder, ChatService};
@ -32,11 +32,26 @@ pub enum AnswerResponseChunk {
pub struct AnswerService {
chat: Arc<ChatService>,
doc: Arc<dyn DocSearch>,
serper: Option<Box<dyn DocSearch>>,
}
impl AnswerService {
fn new(chat: Arc<ChatService>, doc: Arc<dyn DocSearch>) -> Self {
Self { chat, doc }
if let Ok(api_key) = std::env::var("SERPER_API_KEY") {
debug!("Serper API key found, enabling serper...");
let serper = Box::new(super::doc::create_serper(api_key.as_str()));
Self {
chat,
doc,
serper: Some(serper),
}
} else {
Self {
chat,
doc,
serper: None,
}
}
}
pub async fn answer<'a>(
@ -55,18 +70,30 @@ impl AnswerService {
// 2. Generate relevant docs from the query
// For now we only collect from DocSearch.
let serp = match self.doc.search(&query.content, 20, 0).await {
Ok(docs) => docs,
let mut hits = match self.doc.search(&query.content, 5, 0).await {
Ok(docs) => docs.hits,
Err(err) => {
warn!("Failed to search docs: {:?}", err);
return;
warn!("Failed to search tantivy docs: {:?}", err);
vec![]
}
};
yield AnswerResponseChunk::RelevantDocuments(serp.hits.iter().map(|hit| hit.doc.clone()).collect());
// If serper is available, we also collect from serper
if let Some(serper) = self.serper.as_ref() {
let serper_hits = match serper.search(&query.content, 5, 0).await {
Ok(docs) => docs.hits,
Err(err) => {
warn!("Failed to search serper: {:?}", err);
vec![]
}
};
hits.extend(serper_hits);
}
yield AnswerResponseChunk::RelevantDocuments(hits.iter().map(|hit| hit.doc.clone()).collect());
// 3. Generate relevant answers from the query
let snippets = serp.hits.iter().map(|hit| hit.doc.snippet.as_str()).collect::<Vec<_>>();
let snippets = hits.iter().map(|hit| hit.doc.snippet.as_str()).collect::<Vec<_>>();
yield AnswerResponseChunk::RelevantQuestions(self.generate_relevant_questions(&snippets, &query.content).await);
// 4. Generate override prompt from the query

View File

@ -0,0 +1,15 @@
mod serper;
mod tantivy;
use std::sync::Arc;
use tabby_common::api::doc::DocSearch;
use tabby_inference::Embedding;
pub fn create(embedding: Arc<dyn Embedding>) -> impl DocSearch {
tantivy::DocSearchService::new(embedding)
}
pub fn create_serper(api_key: &str) -> impl DocSearch {
serper::SerperService::new(api_key)
}

View File

@ -0,0 +1,86 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tabby_common::api::doc::{
DocSearch, DocSearchDocument, DocSearchError, DocSearchHit, DocSearchResponse,
};
#[derive(Debug, Serialize)]
struct SerperRequest {
q: String,
num: usize,
page: usize,
}
#[derive(Debug, Deserialize)]
struct SerperResponse {
organic: Vec<SerperOrganicHit>,
}
#[derive(Debug, Deserialize)]
struct SerperOrganicHit {
title: String,
snippet: String,
link: String,
}
pub struct SerperService {
client: reqwest::Client,
}
impl SerperService {
pub fn new(api_key: &str) -> Self {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"X-API-KEY",
api_key.parse().expect("Failed to parse Serper API key"),
);
Self {
client: reqwest::Client::builder()
.default_headers(headers)
.build()
.expect("Failed to create reqwest client"),
}
}
}
#[async_trait]
impl DocSearch for SerperService {
async fn search(
&self,
q: &str,
limit: usize,
offset: usize,
) -> Result<DocSearchResponse, DocSearchError> {
let page = offset / limit;
let request = SerperRequest {
q: q.to_string(),
num: limit,
page,
};
let response = self
.client
.post("https://google.serper.dev/search")
.json(&request)
.send()
.await
.map_err(|e| DocSearchError::Other(e.into()))?
.json::<SerperResponse>()
.await
.map_err(|e| DocSearchError::Other(e.into()))?;
let hits = response
.organic
.into_iter()
.map(|hit| DocSearchHit {
score: 0.0,
doc: DocSearchDocument {
title: hit.title,
link: hit.link,
snippet: hit.snippet,
},
})
.collect();
Ok(DocSearchResponse { hits })
}
}

View File

@ -106,7 +106,7 @@ fn get_text(doc: &TantivyDocument, field: schema::Field) -> &str {
doc.get_first(field).unwrap().as_str().unwrap()
}
struct DocSearchService {
pub struct DocSearchService {
search: Arc<RwLock<Option<DocSearchImpl>>>,
loader: tokio::task::JoinHandle<()>,
}
@ -120,7 +120,7 @@ impl Drop for DocSearchService {
}
impl DocSearchService {
fn new(embedding: Arc<dyn Embedding>) -> Self {
pub fn new(embedding: Arc<dyn Embedding>) -> Self {
let search = Arc::new(RwLock::new(None));
let cloned_search = search.clone();
let loader = tokio::spawn(async move {
@ -150,7 +150,3 @@ impl DocSearch for DocSearchService {
}
}
}
pub fn create(embedding: Arc<dyn Embedding>) -> impl DocSearch {
DocSearchService::new(embedding)
}