mirror of
https://github.com/TabbyML/tabby
synced 2024-11-22 00:08:06 +00:00
feat: support retrieve documents with serper.dev api (#2154)
This commit is contained in:
parent
19f3b9eb74
commit
2b5969b3f3
3
.changes/unreleased/Features-20240516-173446.yaml
Normal file
3
.changes/unreleased/Features-20240516-173446.yaml
Normal file
@ -0,0 +1,3 @@
|
||||
kind: Features
|
||||
body: support retrieve documents with serper.dev api
|
||||
time: 2024-05-16T17:34:46.89429-07:00
|
@ -59,6 +59,7 @@ cached = { workspace = true, features = ["async"] }
|
||||
parse-git-url = "0.5.1"
|
||||
color-eyre = { version = "0.6.3" }
|
||||
derive_builder.workspace = true
|
||||
reqwest.workspace = true
|
||||
|
||||
[dependencies.openssl]
|
||||
optional = true
|
||||
|
@ -7,7 +7,7 @@ use tabby_common::api::{
|
||||
chat::Message,
|
||||
doc::{DocSearch, DocSearchDocument},
|
||||
};
|
||||
use tracing::warn;
|
||||
use tracing::{debug, warn};
|
||||
use utoipa::ToSchema;
|
||||
|
||||
use crate::services::chat::{ChatCompletionRequestBuilder, ChatService};
|
||||
@ -32,11 +32,26 @@ pub enum AnswerResponseChunk {
|
||||
pub struct AnswerService {
|
||||
chat: Arc<ChatService>,
|
||||
doc: Arc<dyn DocSearch>,
|
||||
serper: Option<Box<dyn DocSearch>>,
|
||||
}
|
||||
|
||||
impl AnswerService {
|
||||
fn new(chat: Arc<ChatService>, doc: Arc<dyn DocSearch>) -> Self {
|
||||
Self { chat, doc }
|
||||
if let Ok(api_key) = std::env::var("SERPER_API_KEY") {
|
||||
debug!("Serper API key found, enabling serper...");
|
||||
let serper = Box::new(super::doc::create_serper(api_key.as_str()));
|
||||
Self {
|
||||
chat,
|
||||
doc,
|
||||
serper: Some(serper),
|
||||
}
|
||||
} else {
|
||||
Self {
|
||||
chat,
|
||||
doc,
|
||||
serper: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn answer<'a>(
|
||||
@ -55,18 +70,30 @@ impl AnswerService {
|
||||
|
||||
// 2. Generate relevant docs from the query
|
||||
// For now we only collect from DocSearch.
|
||||
let serp = match self.doc.search(&query.content, 20, 0).await {
|
||||
Ok(docs) => docs,
|
||||
let mut hits = match self.doc.search(&query.content, 5, 0).await {
|
||||
Ok(docs) => docs.hits,
|
||||
Err(err) => {
|
||||
warn!("Failed to search docs: {:?}", err);
|
||||
return;
|
||||
warn!("Failed to search tantivy docs: {:?}", err);
|
||||
vec![]
|
||||
}
|
||||
};
|
||||
|
||||
yield AnswerResponseChunk::RelevantDocuments(serp.hits.iter().map(|hit| hit.doc.clone()).collect());
|
||||
// If serper is available, we also collect from serper
|
||||
if let Some(serper) = self.serper.as_ref() {
|
||||
let serper_hits = match serper.search(&query.content, 5, 0).await {
|
||||
Ok(docs) => docs.hits,
|
||||
Err(err) => {
|
||||
warn!("Failed to search serper: {:?}", err);
|
||||
vec![]
|
||||
}
|
||||
};
|
||||
hits.extend(serper_hits);
|
||||
}
|
||||
|
||||
yield AnswerResponseChunk::RelevantDocuments(hits.iter().map(|hit| hit.doc.clone()).collect());
|
||||
|
||||
// 3. Generate relevant answers from the query
|
||||
let snippets = serp.hits.iter().map(|hit| hit.doc.snippet.as_str()).collect::<Vec<_>>();
|
||||
let snippets = hits.iter().map(|hit| hit.doc.snippet.as_str()).collect::<Vec<_>>();
|
||||
yield AnswerResponseChunk::RelevantQuestions(self.generate_relevant_questions(&snippets, &query.content).await);
|
||||
|
||||
// 4. Generate override prompt from the query
|
||||
|
15
crates/tabby/src/services/doc/mod.rs
Normal file
15
crates/tabby/src/services/doc/mod.rs
Normal file
@ -0,0 +1,15 @@
|
||||
mod serper;
|
||||
mod tantivy;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use tabby_common::api::doc::DocSearch;
|
||||
use tabby_inference::Embedding;
|
||||
|
||||
pub fn create(embedding: Arc<dyn Embedding>) -> impl DocSearch {
|
||||
tantivy::DocSearchService::new(embedding)
|
||||
}
|
||||
|
||||
pub fn create_serper(api_key: &str) -> impl DocSearch {
|
||||
serper::SerperService::new(api_key)
|
||||
}
|
86
crates/tabby/src/services/doc/serper.rs
Normal file
86
crates/tabby/src/services/doc/serper.rs
Normal file
@ -0,0 +1,86 @@
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_common::api::doc::{
|
||||
DocSearch, DocSearchDocument, DocSearchError, DocSearchHit, DocSearchResponse,
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct SerperRequest {
|
||||
q: String,
|
||||
num: usize,
|
||||
page: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SerperResponse {
|
||||
organic: Vec<SerperOrganicHit>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SerperOrganicHit {
|
||||
title: String,
|
||||
snippet: String,
|
||||
link: String,
|
||||
}
|
||||
|
||||
pub struct SerperService {
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl SerperService {
|
||||
pub fn new(api_key: &str) -> Self {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
"X-API-KEY",
|
||||
api_key.parse().expect("Failed to parse Serper API key"),
|
||||
);
|
||||
Self {
|
||||
client: reqwest::Client::builder()
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.expect("Failed to create reqwest client"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl DocSearch for SerperService {
|
||||
async fn search(
|
||||
&self,
|
||||
q: &str,
|
||||
limit: usize,
|
||||
offset: usize,
|
||||
) -> Result<DocSearchResponse, DocSearchError> {
|
||||
let page = offset / limit;
|
||||
let request = SerperRequest {
|
||||
q: q.to_string(),
|
||||
num: limit,
|
||||
page,
|
||||
};
|
||||
let response = self
|
||||
.client
|
||||
.post("https://google.serper.dev/search")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| DocSearchError::Other(e.into()))?
|
||||
.json::<SerperResponse>()
|
||||
.await
|
||||
.map_err(|e| DocSearchError::Other(e.into()))?;
|
||||
|
||||
let hits = response
|
||||
.organic
|
||||
.into_iter()
|
||||
.map(|hit| DocSearchHit {
|
||||
score: 0.0,
|
||||
doc: DocSearchDocument {
|
||||
title: hit.title,
|
||||
link: hit.link,
|
||||
snippet: hit.snippet,
|
||||
},
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(DocSearchResponse { hits })
|
||||
}
|
||||
}
|
@ -106,7 +106,7 @@ fn get_text(doc: &TantivyDocument, field: schema::Field) -> &str {
|
||||
doc.get_first(field).unwrap().as_str().unwrap()
|
||||
}
|
||||
|
||||
struct DocSearchService {
|
||||
pub struct DocSearchService {
|
||||
search: Arc<RwLock<Option<DocSearchImpl>>>,
|
||||
loader: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
@ -120,7 +120,7 @@ impl Drop for DocSearchService {
|
||||
}
|
||||
|
||||
impl DocSearchService {
|
||||
fn new(embedding: Arc<dyn Embedding>) -> Self {
|
||||
pub fn new(embedding: Arc<dyn Embedding>) -> Self {
|
||||
let search = Arc::new(RwLock::new(None));
|
||||
let cloned_search = search.clone();
|
||||
let loader = tokio::spawn(async move {
|
||||
@ -150,7 +150,3 @@ impl DocSearch for DocSearchService {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create(embedding: Arc<dyn Embedding>) -> impl DocSearch {
|
||||
DocSearchService::new(embedding)
|
||||
}
|
Loading…
Reference in New Issue
Block a user