refactor(webserver): extract BackgroundJob service (#2022)

* refactor(webserver): extract BackgroundJob service

* update
This commit is contained in:
Meng Zhang 2024-05-01 16:57:00 -07:00 committed by GitHub
parent 5536adcf93
commit 03c0a6d9ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 1041 additions and 1189 deletions

285
Cargo.lock generated
View File

@ -145,6 +145,91 @@ version = "1.0.71"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8"
[[package]]
name = "apalis"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3661d27ed090fb120a887a8416f648343a8e6e864791b36f6175a72b2ab3df39"
dependencies = [
"apalis-core",
"apalis-cron",
"apalis-redis",
"apalis-sql",
"futures",
"pin-project-lite",
"serde",
"thiserror",
"tokio",
"tower",
"tracing",
"tracing-futures",
]
[[package]]
name = "apalis-core"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d82227972a1bb6f5f5c4444b8228aaed79e28d6ad411e5f88ad46dc04cf066bb"
dependencies = [
"async-oneshot",
"async-timer",
"futures",
"pin-project-lite",
"serde",
"serde_json",
"thiserror",
"tower",
"ulid",
]
[[package]]
name = "apalis-cron"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d11c4150f1088c1237cfde2d5cd3b045c17b3ed605c52bb3346641e18f2e1f77"
dependencies = [
"apalis-core",
"async-stream",
"chrono",
"cron",
"futures",
"tower",
]
[[package]]
name = "apalis-redis"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd6f0968397ad66d4628a3c8022e201d3edc58eb44a522b5c76b5efd334b9fdd"
dependencies = [
"apalis-core",
"async-stream",
"async-trait",
"chrono",
"futures",
"log",
"redis",
"serde",
"tokio",
]
[[package]]
name = "apalis-sql"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99eaea6cf256a5d0fce59c68608ba16e3ea9f01cb4a45e5c7fa5709ea44dacd1"
dependencies = [
"apalis-core",
"async-stream",
"futures",
"futures-lite",
"log",
"serde",
"serde_json",
"sqlx",
"tokio",
]
[[package]]
name = "arc-swap"
version = "1.6.0"
@ -194,6 +279,15 @@ dependencies = [
"async-trait",
]
[[package]]
name = "async-oneshot"
version = "0.5.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae47de2a02d543205f3f5457a90b6ecbc9494db70557bd29590ec8f1ddff5463"
dependencies = [
"futures-micro",
]
[[package]]
name = "async-openai"
version = "0.18.3"
@ -242,10 +336,21 @@ dependencies = [
]
[[package]]
name = "async-trait"
version = "0.1.74"
name = "async-timer"
version = "1.0.0-beta.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9"
checksum = "54a18932baa05100f01c9980d03e330f95a8f2dee1a7576969fa507bdce3b568"
dependencies = [
"error-code",
"libc",
"wasm-bindgen",
]
[[package]]
name = "async-trait"
version = "0.1.80"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
dependencies = [
"proc-macro2",
"quote",
@ -261,16 +366,6 @@ dependencies = [
"num-traits",
]
[[package]]
name = "atomic-write-file"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edcdbedc2236483ab103a53415653d6b4442ea6141baf1ffa85df29635e88436"
dependencies = [
"nix",
"rand 0.8.5",
]
[[package]]
name = "auto_enums"
version = "0.8.5"
@ -704,6 +799,20 @@ dependencies = [
"unreachable",
]
[[package]]
name = "combine"
version = "4.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
dependencies = [
"bytes",
"futures-core",
"memchr",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]
name = "console"
version = "0.15.7"
@ -774,9 +883,9 @@ dependencies = [
[[package]]
name = "cron"
version = "0.12.0"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ff76b51e4c068c52bfd2866e1567bee7c567ae8f24ada09fd4307019e25eab7"
checksum = "6f8c3e73077b4b4a6ab1ea5047c37c57aee77657bc8ecd6f29b0af082d0b0c07"
dependencies = [
"chrono",
"nom",
@ -1203,6 +1312,12 @@ dependencies = [
"libc",
]
[[package]]
name = "error-code"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0474425d51df81997e2f90a21591180b38eccf27292d755f3e30750225c175b"
[[package]]
name = "etcetera"
version = "0.8.0"
@ -1430,6 +1545,19 @@ version = "0.3.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa"
[[package]]
name = "futures-lite"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52527eb5074e35e9339c6b4e8d12600c7128b68fb25dcb9fa9dec18f7c25f3a5"
dependencies = [
"fastrand 2.0.1",
"futures-core",
"futures-io",
"parking",
"pin-project-lite",
]
[[package]]
name = "futures-macro"
version = "0.3.29"
@ -1441,6 +1569,15 @@ dependencies = [
"syn 2.0.52",
]
[[package]]
name = "futures-micro"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b460264b3593d68b16a7bc35f7bc226ddfebdf9a1c8db1ed95d5cc6b7168c826"
dependencies = [
"pin-project-lite",
]
[[package]]
name = "futures-sink"
version = "0.3.29"
@ -1527,8 +1664,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
@ -1622,7 +1761,7 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1abd4ce5247dfc04a03ccde70f87a048458c9356c7e41d21ad8c407b3dde6f2"
dependencies = [
"combine",
"combine 3.8.1",
"thiserror",
]
@ -1632,7 +1771,7 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2ebc8013b4426d5b81a4364c419a95ed0b404af2b82e2457de52d9348f0e474"
dependencies = [
"combine",
"combine 3.8.1",
"thiserror",
]
@ -2831,17 +2970,6 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9a91b326434fca226707ed8ec1fd22d4e1c96801abdf10c412afdc7d97116e0"
[[package]]
name = "nix"
version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053"
dependencies = [
"bitflags 2.4.0",
"cfg-if",
"libc",
]
[[package]]
name = "nom"
version = "7.1.3"
@ -3251,6 +3379,12 @@ dependencies = [
"tracing",
]
[[package]]
name = "parking"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
[[package]]
name = "parking_lot"
version = "0.11.2"
@ -3728,6 +3862,29 @@ dependencies = [
"num_cpus",
]
[[package]]
name = "redis"
version = "0.25.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6472825949c09872e8f2c50bde59fcefc17748b6be5c90fd67cd8b4daca73bfd"
dependencies = [
"arc-swap",
"async-trait",
"bytes",
"combine 4.6.7",
"futures",
"futures-util",
"itoa",
"percent-encoding",
"pin-project-lite",
"ryu",
"sha1_smol",
"tokio",
"tokio-retry",
"tokio-util",
"url",
]
[[package]]
name = "redox_syscall"
version = "0.2.16"
@ -4426,6 +4583,12 @@ dependencies = [
"digest",
]
[[package]]
name = "sha1_smol"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
[[package]]
name = "sha2"
version = "0.10.8"
@ -4641,9 +4804,9 @@ dependencies = [
[[package]]
name = "sqlx"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf"
checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa"
dependencies = [
"sqlx-core",
"sqlx-macros",
@ -4654,9 +4817,9 @@ dependencies = [
[[package]]
name = "sqlx-core"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd"
checksum = "24ba59a9342a3d9bab6c56c118be528b27c9b60e490080e9711a04dccac83ef6"
dependencies = [
"ahash",
"atoi",
@ -4665,7 +4828,6 @@ dependencies = [
"chrono",
"crc",
"crossbeam-queue",
"dotenvy",
"either",
"event-listener",
"futures-channel",
@ -4681,6 +4843,8 @@ dependencies = [
"once_cell",
"paste",
"percent-encoding",
"rustls 0.21.10",
"rustls-pemfile 1.0.4",
"serde",
"serde_json",
"sha2",
@ -4691,13 +4855,14 @@ dependencies = [
"tokio-stream",
"tracing",
"url",
"webpki-roots",
]
[[package]]
name = "sqlx-macros"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5"
checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127"
dependencies = [
"proc-macro2",
"quote",
@ -4708,11 +4873,10 @@ dependencies = [
[[package]]
name = "sqlx-macros-core"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841"
checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8"
dependencies = [
"atomic-write-file",
"dotenvy",
"either",
"heck",
@ -4735,9 +4899,9 @@ dependencies = [
[[package]]
name = "sqlx-mysql"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4"
checksum = "1ed31390216d20e538e447a7a9b959e06ed9fc51c37b514b46eb758016ecd418"
dependencies = [
"atoi",
"base64 0.21.5",
@ -4778,9 +4942,9 @@ dependencies = [
[[package]]
name = "sqlx-postgres"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24"
checksum = "7c824eb80b894f926f89a0b9da0c7f435d27cdd35b8c655b114e58223918577e"
dependencies = [
"atoi",
"base64 0.21.5",
@ -4806,7 +4970,6 @@ dependencies = [
"rand 0.8.5",
"serde",
"serde_json",
"sha1",
"sha2",
"smallvec",
"sqlx-core",
@ -4818,9 +4981,9 @@ dependencies = [
[[package]]
name = "sqlx-sqlite"
version = "0.7.3"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490"
checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa"
dependencies = [
"atoi",
"chrono",
@ -5173,6 +5336,7 @@ name = "tabby-webserver"
version = "0.11.0-dev.0"
dependencies = [
"anyhow",
"apalis",
"argon2",
"assert_matches",
"async-trait",
@ -5195,7 +5359,6 @@ dependencies = [
"octocrab",
"pin-project",
"querystring",
"rand 0.8.5",
"regex",
"reqwest",
"rust-embed 8.0.0",
@ -5210,7 +5373,6 @@ dependencies = [
"temp_testdir",
"thiserror",
"tokio",
"tokio-cron-scheduler",
"tokio-tungstenite",
"tower",
"tower-http 0.4.0",
@ -5458,18 +5620,18 @@ checksum = "d321c8576c2b47e43953e9cce236550d4cd6af0a6ce518fe084340082ca6037b"
[[package]]
name = "thiserror"
version = "1.0.49"
version = "1.0.59"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4"
checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.49"
version = "1.0.59"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc"
checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66"
dependencies = [
"proc-macro2",
"quote",
@ -6204,6 +6366,17 @@ version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9"
[[package]]
name = "ulid"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34778c17965aa2a08913b57e1f34db9b4a63f5de31768b55bf20d2795f921259"
dependencies = [
"getrandom 0.2.11",
"rand 0.8.5",
"web-time",
]
[[package]]
name = "unicase"
version = "2.7.0"
@ -6579,6 +6752,16 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "webpki-roots"
version = "0.25.4"

View File

@ -143,7 +143,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
#[cfg(feature = "ee")]
let ws = if !args.no_webserver {
Some(tabby_webserver::public::WebserverHandle::new(create_event_logger()).await)
Some(tabby_webserver::public::WebserverHandle::new(create_event_logger(), args.port).await)
} else {
None
};
@ -166,7 +166,7 @@ pub async fn main(config: &Config, args: &ServeArgs) {
#[cfg(feature = "ee")]
if let Some(ws) = &ws {
let (new_api, new_ui) = ws
.attach_webserver(api, ui, code, args.chat_model.is_some(), args.port)
.attach_webserver(api, ui, code, args.chat_model.is_some())
.await;
api = new_api;
ui = new_ui;

View File

@ -38,9 +38,8 @@ tabby-db = { path = "../../ee/tabby-db" }
tarpc = { version = "0.33.0", features = ["serde-transport"] }
thiserror.workspace = true
tokio = { workspace = true, features = ["fs", "process"] }
tokio-cron-scheduler = { workspace = true }
tokio-tungstenite = "0.20.1"
tower = { version = "0.4", features = ["util"] }
tower = { version = "0.4", features = ["util", "limit"] }
tower-http = { version = "0.4.0", features = ["fs", "trace"] }
tracing.workspace = true
unicase = "2.7.0"
@ -53,7 +52,7 @@ tabby-search = { path = "../tabby-search" }
octocrab = "0.38.0"
fs_extra = "1.3.0"
gitlab = "0.1610.0"
rand = "0.8.5"
apalis = { version = "0.5.1", features = ["sqlite", "cron" ] }
[dev-dependencies]
assert_matches = "1.5.0"

View File

@ -1,256 +0,0 @@
use std::{collections::HashMap, pin::Pin, sync::Arc};
use chrono::Utc;
use futures::Future;
use tabby_db::DbConn;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio_cron_scheduler::{Job, JobScheduler};
use tracing::{debug, info, warn};
/// Drives scheduled background jobs for the webserver.
///
/// Owns a cron scheduler, a database handle (used to record job runs), a
/// registry mapping static job names to async entry points, and a channel
/// on which a job name can be sent to request an immediate run.
pub struct JobController {
    scheduler: JobScheduler,
    db: DbConn,
    // Maps a job's static name to a factory producing the job's future.
    job_registry: HashMap<
        &'static str,
        Arc<dyn Fn() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync + 'static>,
    >,
    // Sending a registered job name here asks the worker loop to run it.
    event_sender: UnboundedSender<String>,
}
impl JobController {
    /// Creates a controller backed by `db`.
    ///
    /// First finalizes job runs left dangling by a previous process, then
    /// builds a fresh cron scheduler. Panics if either step fails — the
    /// webserver cannot run background jobs without them.
    pub async fn new(db: DbConn, event_sender: UnboundedSender<String>) -> Self {
        db.finalize_stale_job_runs()
            .await
            .expect("failed to cleanup stale jobs");
        let scheduler = JobScheduler::new()
            .await
            .expect("failed to create job scheduler");
        Self {
            scheduler,
            db,
            job_registry: HashMap::default(),
            event_sender,
        }
    }

    /// Spawns the registered job `name` onto its own task and returns the
    /// join handle. Panics if `name` was never registered.
    fn run_job(&self, name: &str) -> tokio::task::JoinHandle<()> {
        let func = self
            .job_registry
            .get(name)
            .expect("failed to get job")
            .clone();
        // Spawn onto a separate task for panic isolation: a panicking job
        // aborts only its own task, not the caller.
        tokio::task::spawn(async move {
            func().await;
        })
    }

    /// Start the worker that listens for job events and runs the jobs.
    ///
    /// 1. Only one instance of the job will be run at a time (the handle is
    ///    awaited before the next event is consumed).
    /// 2. Jobs are deduplicated within a time window (120 seconds).
    pub fn start_worker(self: &Arc<Self>, mut event_receiver: UnboundedReceiver<String>) {
        const JOB_DEDUPE_WINDOW_SECS: i64 = 120;
        let controller = self.clone();
        tokio::spawn(async move {
            // Sleep for 5 seconds to allow the webserver to start.
            tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
            let mut last_timestamps = HashMap::new();
            // NOTE(review): the outer `loop` is redundant — if the sender is
            // dropped, `recv()` keeps returning `None` and this becomes a
            // busy-spin. Confirm the sender is held for the process lifetime.
            loop {
                while let Some(name) = event_receiver.recv().await {
                    if let Some(last_timestamp) = last_timestamps.get(&name) {
                        if Utc::now()
                            .signed_duration_since(*last_timestamp)
                            .num_seconds()
                            < JOB_DEDUPE_WINDOW_SECS
                        {
                            info!("Job `{name}` last ran less than {JOB_DEDUPE_WINDOW_SECS} seconds ago (@{last_timestamp}), skipped");
                            continue;
                        }
                    }
                    last_timestamps.insert(name.clone(), Utc::now());
                    // Awaiting the join handle serializes job execution.
                    let _ = controller.run_job(&name).await;
                }
            }
        });
    }

    /// Starts cron-driven scheduling.
    ///
    /// With `TABBY_WEBSERVER_CONTROLLER_ONESHOT` set (debug aid), skips the
    /// scheduler and fires every registered job exactly once via the event
    /// channel instead.
    pub async fn start_cron(&self) {
        if std::env::var("TABBY_WEBSERVER_CONTROLLER_ONESHOT").is_ok() {
            warn!(
                "Running controller job as oneshot, this should only be used for debugging purpose..."
            );
            for name in self.job_registry.keys() {
                let _ = self.event_sender.send(name.to_string());
            }
        } else {
            self.scheduler
                .start()
                .await
                .expect("failed to start job scheduler")
        }
    }

    /// Register a new job with the scheduler, the job will be displayed in Jobs dashboard.
    ///
    /// `func` receives a [`JobContext`] for streaming stdout/stderr and must
    /// return the job's exit code.
    pub async fn register_public<T>(&mut self, name: &'static str, schedule: &str, func: T)
    where
        T: FnMut(&JobContext) -> Pin<Box<dyn Future<Output = anyhow::Result<i32>> + Send>>
            + Send
            + Sync
            + Clone
            + 'static,
    {
        self.register_impl(true, name, schedule, func).await;
    }

    /// Register a new job with the scheduler, the job will NOT be displayed in Jobs dashboard.
    ///
    /// Adapts the context-less closure to the public shape, mapping success
    /// to exit code 0.
    pub async fn register<T>(&mut self, name: &'static str, schedule: &str, func: T)
    where
        T: FnMut() -> Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send>>
            + Send
            + Sync
            + Clone
            + 'static,
    {
        self.register_impl(false, name, schedule, move |_| {
            let mut func = func.clone();
            Box::pin(async move {
                func().await?;
                Ok(0)
            })
        })
        .await;
    }

    /// Shared registration path: wraps `func` so each run creates a
    /// [`JobContext`], records the resulting exit code (-1 on error), then
    /// adds the job to the cron schedule under `schedule`.
    async fn register_impl<F>(
        &mut self,
        is_public: bool,
        name: &'static str,
        schedule: &str,
        func: F,
    ) where
        F: FnMut(&JobContext) -> Pin<Box<dyn Future<Output = anyhow::Result<i32>> + Send>>
            + Send
            + Sync
            + Clone
            + 'static,
    {
        let db = self.db.clone();
        self.job_registry.insert(
            name,
            Arc::new(move || {
                let db = db.clone();
                let mut func = func.clone();
                Box::pin(async move {
                    debug!("Running job `{}`", name);
                    let context = JobContext::new(is_public, name, db.clone()).await;
                    match func(&context).await {
                        Ok(exit_code) => {
                            debug!("Job `{}` completed with exit code {}", name, exit_code);
                            context.complete(exit_code).await;
                        }
                        Err(e) => {
                            warn!("Job `{}` failed: {}", name, e);
                            // -1 marks a failed run in the dashboard.
                            context.complete(-1).await;
                        }
                    };
                })
            }),
        );
        self.add_to_schedule(name, schedule).await
    }

    /// Adds a cron entry that, on each tick, asks the worker loop (via the
    /// event channel) to run job `name`. The cron job itself never executes
    /// the work — it only emits the event, keeping execution serialized.
    async fn add_to_schedule(&mut self, name: &'static str, schedule: &str) {
        let event_sender = self.event_sender.clone();
        let job = Job::new_async(schedule, move |uuid, mut scheduler| {
            let event_sender = event_sender.clone();
            Box::pin(async move {
                if let Err(err) = event_sender.send(name.to_owned()) {
                    warn!("Failed to schedule job `{}`: {}", &name, err);
                } else {
                    debug!("Scheduling job `{}`", &name);
                }
                if let Ok(Some(next_tick)) = scheduler.next_tick_for_job(uuid).await {
                    debug!(
                        "Next time for job `{}` is {:?}",
                        &name,
                        next_tick.with_timezone(&chrono::Local)
                    );
                }
            })
        })
        .expect("failed to create job");
        self.scheduler.add(job).await.expect("failed to add job");
    }
}
/// Handle passed to a running job for reporting output and completion.
///
/// Public (dashboard-visible) runs carry the database row id of their job
/// run; private runs use the sentinel id `-1` and discard all output.
#[derive(Clone)]
pub struct JobContext {
    id: i64,
    db: DbConn,
}
impl JobContext {
    /// Creates a context for one run of job `name`. Public jobs get a fresh
    /// job-run row (panics if creation fails); private jobs use id `-1`.
    async fn new(public: bool, name: &'static str, db: DbConn) -> Self {
        let id = if public {
            db.create_job_run(name.to_owned())
                .await
                .expect("failed to create job")
        } else {
            -1
        };
        Self { id, db }
    }

    // Private (non-dashboard) runs are identified by the sentinel id -1.
    fn is_private(&self) -> bool {
        self.id < 0
    }

    /// Appends `stdout` plus a trailing newline to the run's recorded
    /// stdout. No-op for private runs; database failures are only logged.
    pub async fn stdout_writeline(&self, stdout: String) {
        if self.is_private() {
            return;
        }
        let stdout = stdout + "\n";
        match self.db.update_job_stdout(self.id, stdout).await {
            Ok(_) => (),
            Err(_) => {
                warn!("Failed to write stdout to job `{}`", self.id);
            }
        }
    }

    /// Appends `stderr` plus a trailing newline to the run's recorded
    /// stderr. No-op for private runs; database failures are only logged.
    pub async fn stderr_writeline(&self, stderr: String) {
        if self.is_private() {
            return;
        }
        let stderr = stderr + "\n";
        match self.db.update_job_stderr(self.id, stderr).await {
            Ok(_) => (),
            Err(_) => {
                warn!("Failed to write stderr to job `{}`", self.id);
            }
        }
    }

    /// Records `exit_code` for the run, marking it complete. No-op for
    /// private runs; database failures are only logged.
    async fn complete(&self, exit_code: i32) {
        if self.is_private() {
            return;
        }
        match self.db.update_job_status(self.id, exit_code).await {
            Ok(_) => (),
            Err(_) => {
                warn!("Failed to complete job `{}`", self.id);
            }
        }
    }
}

View File

@ -1,132 +0,0 @@
use std::sync::Arc;
use anyhow::Result;
use chrono::Utc;
use juniper::ID;
use octocrab::{models::Repository, GitHubError, Octocrab};
use crate::{
cron::controller::JobContext,
schema::repository::{GithubRepositoryProvider, GithubRepositoryService},
warn_stderr,
};
/// Job entry point: refreshes the repository list for every configured
/// GitHub provider, then prunes repositories no longer present upstream.
///
/// Returns the job exit code (0 on success); any provider failure aborts
/// the whole pass via `?`.
pub async fn refresh_all_repositories(
    context: JobContext,
    service: Arc<dyn GithubRepositoryService>,
) -> Result<i32> {
    for provider in service
        .list_providers(vec![], None, None, None, None)
        .await?
    {
        // Timestamp taken before the refresh: anything not upserted after
        // `start` is treated as outdated and deleted below.
        let start = Utc::now();
        context
            .stdout_writeline(format!(
                "Refreshing repositories for provider: {}\n",
                provider.display_name
            ))
            .await;
        refresh_repositories_for_provider(context.clone(), service.clone(), provider.id.clone())
            .await?;
        service
            .delete_outdated_repositories(provider.id, start)
            .await?;
    }
    Ok(0)
}
/// Imports all repositories visible to one GitHub provider.
///
/// On a GitHub client error (4xx — typically expired/invalid credentials)
/// the provider is flagged unhealthy before the error propagates; on
/// success the provider is flagged healthy after all upserts.
async fn refresh_repositories_for_provider(
    context: JobContext,
    service: Arc<dyn GithubRepositoryService>,
    provider_id: ID,
) -> Result<()> {
    let provider = service.get_provider(provider_id).await?;
    let repos = match fetch_all_repos(&provider).await {
        Ok(repos) => repos,
        Err(octocrab::Error::GitHub {
            source: source @ GitHubError { .. },
            ..
        }) if source.status_code.is_client_error() => {
            service
                .update_provider_status(provider.id.clone(), false)
                .await?;
            warn_stderr!(
                context,
                "GitHub credentials for provider {} are expired or invalid",
                provider.display_name
            );
            return Err(source.into());
        }
        Err(e) => {
            warn_stderr!(context, "Failed to fetch repositories from github: {e}");
            return Err(e.into());
        }
    };
    for repo in repos {
        context
            .stdout_writeline(format!(
                "Importing: {}",
                repo.full_name.as_deref().unwrap_or(&repo.name)
            ))
            .await;
        let id = repo.id.to_string();
        // Repositories without a git URL cannot be cloned — skip them.
        let Some(url) = repo.git_url else {
            continue;
        };
        let url = url.to_string();
        // Normalize: rewrite the legacy git:// scheme to https:// and drop
        // a trailing ".git" so stored URLs are in canonical form.
        let url = url
            .strip_prefix("git://")
            .map(|url| format!("https://{url}"))
            .unwrap_or(url);
        let url = url.strip_suffix(".git").unwrap_or(&url);
        service
            .upsert_repository(
                provider.id.clone(),
                id,
                repo.full_name.unwrap_or(repo.name),
                url.to_string(),
            )
            .await?;
    }
    service
        .update_provider_status(provider.id.clone(), true)
        .await?;
    Ok(())
}
// FIXME(wsxiaoys): Convert to async stream
/// Fetches every repository visible to the provider's access token,
/// paginating through GitHub's "repos for authenticated user" endpoint.
///
/// Returns an empty list when the provider has no token configured.
/// NOTE(review): the page count is cast to `u8` — presumably because
/// octocrab paginates with `u8` page numbers — so listings beyond 255
/// pages would be truncated. Confirm against the octocrab API.
async fn fetch_all_repos(
    provider: &GithubRepositoryProvider,
) -> Result<Vec<Repository>, octocrab::Error> {
    let Some(token) = &provider.access_token else {
        return Ok(vec![]);
    };
    let octocrab = Octocrab::builder()
        .user_access_token(token.to_string())
        .build()?;
    let mut page = 1;
    let mut repos = vec![];
    loop {
        let response = octocrab
            .current()
            .list_repos_for_authenticated_user()
            .visibility("all")
            .page(page)
            .send()
            .await?;
        let pages = response.number_of_pages().unwrap_or_default() as u8;
        repos.extend(response.items);
        page += 1;
        // Stop once past the last reported page; when pagination info is
        // absent, `pages` defaults to 0 and the loop ends after one fetch.
        if page > pages {
            break;
        }
    }
    Ok(repos)
}

View File

@ -1,141 +0,0 @@
use std::sync::Arc;
use anyhow::Result;
use chrono::Utc;
use gitlab::{
api::{projects::Projects, ApiError, AsyncQuery, Pagination},
GitlabBuilder,
};
use juniper::ID;
use serde::Deserialize;
use crate::{
cron::controller::JobContext,
schema::repository::{GitlabRepositoryProvider, GitlabRepositoryService},
warn_stderr,
};
/// Job entry point: refreshes the repository list for every configured
/// GitLab provider, then prunes repositories no longer present upstream.
///
/// Returns the job exit code (0 on success); any provider failure aborts
/// the whole pass via `?`.
pub async fn refresh_all_repositories(
    context: JobContext,
    service: Arc<dyn GitlabRepositoryService>,
) -> Result<i32> {
    for provider in service
        .list_providers(vec![], None, None, None, None)
        .await?
    {
        // Timestamp taken before the refresh: anything not upserted after
        // `start` is treated as outdated and deleted below.
        let start = Utc::now();
        context
            .stdout_writeline(format!(
                "Refreshing repositories for provider: {}\n",
                provider.display_name
            ))
            .await;
        refresh_repositories_for_provider(context.clone(), service.clone(), provider.id.clone())
            .await?;
        service
            .delete_outdated_repositories(provider.id, start)
            .await?;
    }
    Ok(0)
}
/// Imports all repositories visible to one GitLab provider.
///
/// On a client-side error (expired/invalid credentials or a GitLab API
/// error — see [`GitlabError::is_client_error`]) the provider is flagged
/// unhealthy before the error propagates; on success the provider is
/// flagged healthy after all upserts.
async fn refresh_repositories_for_provider(
    context: JobContext,
    service: Arc<dyn GitlabRepositoryService>,
    provider_id: ID,
) -> Result<()> {
    let provider = service.get_provider(provider_id).await?;
    let repos = match fetch_all_repos(&provider).await {
        Ok(repos) => repos,
        Err(e) if e.is_client_error() => {
            service
                .update_provider_status(provider.id.clone(), false)
                .await?;
            warn_stderr!(
                context,
                "GitLab credentials for provider {} are expired or invalid",
                provider.display_name
            );
            return Err(e.into());
        }
        Err(e) => {
            warn_stderr!(context, "Failed to fetch repositories from gitlab: {e}");
            return Err(e.into());
        }
    };
    for repo in repos {
        context
            .stdout_writeline(format!("Importing: {}", &repo.path_with_namespace))
            .await;
        let id = repo.id.to_string();
        // Normalize the clone URL by dropping a trailing ".git".
        let url = repo.http_url_to_repo;
        let url = url.strip_suffix(".git").unwrap_or(&url);
        service
            .upsert_repository(
                provider.id.clone(),
                id,
                repo.path_with_namespace,
                url.to_string(),
            )
            .await?;
    }
    service
        .update_provider_status(provider.id.clone(), true)
        .await?;
    Ok(())
}
/// Minimal projection of a GitLab project, deserialized from the
/// `/projects` API response — only the fields this job needs.
#[derive(Deserialize)]
struct Repository {
    id: u128,
    // e.g. "group/project"; used as the repository's display name.
    path_with_namespace: String,
    // HTTPS clone URL, typically ending in ".git".
    http_url_to_repo: String,
}
/// Unified error type for the GitLab refresh job, wrapping the three
/// failure sources: REST transport errors, client-construction errors,
/// and project-query builder errors.
#[derive(thiserror::Error, Debug)]
enum GitlabError {
    #[error(transparent)]
    Rest(#[from] gitlab::api::ApiError<gitlab::RestError>),
    #[error(transparent)]
    Gitlab(#[from] gitlab::GitlabError),
    #[error(transparent)]
    Projects(#[from] gitlab::api::projects::ProjectsBuilderError),
}
impl GitlabError {
    /// Whether the failure was caused by the client side (authentication
    /// failure or a GitLab-reported API error) rather than e.g. transport —
    /// used by the caller to mark a provider's credentials invalid.
    fn is_client_error(&self) -> bool {
        match self {
            GitlabError::Rest(source)
            | GitlabError::Gitlab(gitlab::GitlabError::Api { source }) => {
                matches!(
                    source,
                    ApiError::Auth { .. }
                        | ApiError::Client {
                            source: gitlab::RestError::AuthError { .. }
                        }
                        | ApiError::Gitlab { .. }
                )
            }
            // Builder errors and other construction failures are not
            // attributed to bad credentials.
            _ => false,
        }
    }
}
/// Fetches every GitLab project the provider's token is a member of,
/// following all pages of the `/projects?membership=true` listing.
///
/// Returns an empty list when the provider has no token configured.
async fn fetch_all_repos(
    provider: &GitlabRepositoryProvider,
) -> Result<Vec<Repository>, GitlabError> {
    let Some(token) = &provider.access_token else {
        return Ok(vec![]);
    };
    // NOTE(review): host is hard-coded to gitlab.com — self-hosted GitLab
    // instances are not reachable through this job; confirm intended.
    let gitlab = GitlabBuilder::new("gitlab.com", token)
        .build_async()
        .await?;
    Ok(gitlab::api::paged(
        Projects::builder().membership(true).build()?,
        Pagination::All,
    )
    .query_async(&gitlab)
    .await?)
}

View File

@ -1,67 +0,0 @@
//! db maintenance jobs
mod github;
mod gitlab;
use std::sync::Arc;
use super::{controller::JobController, every_two_hours};
use crate::schema::{
auth::AuthenticationService,
repository::{GithubRepositoryService, GitlabRepositoryService},
};
/// Registers all db-maintenance cron jobs with the controller.
///
/// Token/password-reset cleanup jobs are private (hidden from the Jobs
/// dashboard); the GitHub/GitLab import jobs are public. All run on a
/// randomized every-two-hours schedule.
pub async fn register_jobs(
    controller: &mut JobController,
    auth: Arc<dyn AuthenticationService>,
    github: Arc<dyn GithubRepositoryService>,
    gitlab: Arc<dyn GitlabRepositoryService>,
) {
    let cloned_auth = auth.clone();
    controller
        .register(
            "remove_staled_refresh_token",
            &every_two_hours(),
            move || {
                let auth = cloned_auth.clone();
                Box::pin(async move { Ok(auth.delete_expired_token().await?) })
            },
        )
        .await;
    let cloned_auth = auth.clone();
    controller
        .register(
            "remove_staled_password_reset",
            &every_two_hours(),
            move || {
                let auth = cloned_auth.clone();
                Box::pin(async move { Ok(auth.delete_expired_password_resets().await?) })
            },
        )
        .await;
    controller
        .register_public(
            "import_github_repositories",
            &every_two_hours(),
            move |context| {
                let context = context.clone();
                let github = github.clone();
                Box::pin(async move { github::refresh_all_repositories(context, github).await })
            },
        )
        .await;
    controller
        .register_public(
            "import_gitlab_repositories",
            &every_two_hours(),
            move |context| {
                let gitlab = gitlab.clone();
                let context = context.clone();
                Box::pin(async move { gitlab::refresh_all_repositories(context, gitlab).await })
            },
        )
        .await;
}

View File

@ -1,59 +0,0 @@
mod controller;
mod db;
mod scheduler;
use std::sync::Arc;
use rand::Rng;
use tabby_db::DbConn;
use crate::schema::{
auth::AuthenticationService, repository::RepositoryService, worker::WorkerService,
};
/// Logs a warning through `tracing` and mirrors the same formatted message
/// into the job's recorded stderr. Expands to an `.await`, so it may only
/// be used inside an async context.
#[macro_export]
macro_rules! warn_stderr {
    ($ctx:expr, $($params:tt)+) => {
        tracing::warn!($($params)+);
        $ctx.stderr_writeline(format!($($params)+)).await;
    }
}
/// Wires up and starts all background cron jobs for the webserver.
///
/// Builds a `JobController` backed by `db`, registers the db-maintenance
/// jobs and the repository-scheduler job, then starts the worker loop
/// (consuming schedule events from `schedule_event_receiver`) and finally
/// the cron scheduler itself.
pub async fn run_cron(
    schedule_event_sender: tokio::sync::mpsc::UnboundedSender<String>,
    schedule_event_receiver: tokio::sync::mpsc::UnboundedReceiver<String>,
    db: DbConn,
    auth: Arc<dyn AuthenticationService>,
    worker: Arc<dyn WorkerService>,
    repository: Arc<dyn RepositoryService>,
    local_port: u16,
) {
    let mut controller = controller::JobController::new(db, schedule_event_sender).await;
    db::register_jobs(
        &mut controller,
        auth,
        repository.github(),
        repository.gitlab(),
    )
    .await;
    scheduler::register(&mut controller, worker, local_port).await;
    let controller = Arc::new(controller);
    controller.start_worker(schedule_event_receiver);
    controller.start_cron().await
}
/// Builds a 6-field cron expression ("sec min hour dom mon dow") that fires
/// once every two hours, at a randomized second and minute so jobs sharing
/// this schedule do not all fire at the same instant.
fn every_two_hours() -> String {
    let mut rng = rand::thread_rng();
    // `gen_range` takes a half-open range: 0..60 yields the full 0-59 set.
    // The previous bound of 0..59 could never produce 59.
    format!(
        "{} {} */2 * * *",
        rng.gen_range(0..60),
        rng.gen_range(0..60)
    )
}
/// Builds a 6-field cron expression that fires every ten minutes at a
/// randomized second, spreading load across jobs on the same cadence.
fn every_ten_minutes() -> String {
    let mut rng = rand::thread_rng();
    // Half-open range 0..60 covers seconds 0-59; the previous 0..59 bound
    // excluded 59.
    format!("{} */10 * * * *", rng.gen_range(0..60))
}

View File

@ -1,74 +0,0 @@
use std::{process::Stdio, sync::Arc};
use anyhow::{Context, Result};
use tokio::io::AsyncBufReadExt;
use super::{
controller::{JobContext, JobController},
every_ten_minutes,
};
use crate::schema::worker::WorkerService;
/// Registers the public "scheduler" job, which shells out to the
/// `tabby scheduler --now` subcommand every ten minutes (at a randomized
/// second) against the local webserver on `local_port`.
pub async fn register(
    controller: &mut JobController,
    worker: Arc<dyn WorkerService>,
    local_port: u16,
) {
    controller
        .register_public("scheduler", &every_ten_minutes(), move |context| {
            let context = context.clone();
            let worker = worker.clone();
            Box::pin(async move { run_scheduler_now(context, worker, local_port).await })
        })
        .await;
}
/// Runs the current executable's `scheduler --now` subcommand as a child
/// process, streaming its stdout/stderr lines into the job context, and
/// returns the child's exit code.
///
/// Returns -1 when the child could not be waited on or was terminated
/// without an exit code (e.g. by a signal).
async fn run_scheduler_now(
    context: JobContext,
    worker: Arc<dyn WorkerService>,
    local_port: u16,
) -> Result<i32> {
    let exe = std::env::current_exe()?;
    let mut child = tokio::process::Command::new(exe)
        .arg("scheduler")
        .arg("--now")
        .arg("--url")
        .arg(format!("localhost:{local_port}"))
        .arg("--token")
        .arg(worker.read_registration_token().await?)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()?;
    {
        // Pipe stdout into the job's recorded stdout.
        let stdout = child.stdout.take().context("Failed to acquire stdout")?;
        let ctx = context.clone();
        tokio::spawn(async move {
            let mut lines = tokio::io::BufReader::new(stdout).lines();
            while let Ok(Some(line)) = lines.next_line().await {
                let _ = ctx.stdout_writeline(line).await;
            }
        });
    }
    {
        // Pipe stderr into the job's recorded stderr. (The reader variable
        // was previously named `stdout`, which was misleading.)
        let stderr = child.stderr.take().context("Failed to acquire stderr")?;
        let ctx = context.clone();
        tokio::spawn(async move {
            let mut lines = tokio::io::BufReader::new(stderr).lines();
            while let Ok(Some(line)) = lines.next_line().await {
                let _ = ctx.stderr_writeline(line).await;
            }
        });
    }
    if let Some(exit_code) = child.wait().await.ok().and_then(|s| s.code()) {
        Ok(exit_code)
    } else {
        Ok(-1)
    }
}

View File

@ -23,7 +23,6 @@ use tracing::{error, warn};
use crate::{
axum::{extract::AuthBearer, graphql},
cron,
hub::{self, HubState},
oauth,
path::db_file,
@ -32,7 +31,9 @@ use crate::{
auth::AuthenticationService, create_schema, repository::RepositoryService, Schema,
ServiceLocator,
},
service::{create_service_locator, event_logger::create_event_logger, repository},
service::{
background_job, create_service_locator, event_logger::create_event_logger, repository,
},
ui,
};
@ -43,12 +44,13 @@ pub struct WebserverHandle {
}
impl WebserverHandle {
pub async fn new(logger1: impl EventLogger + 'static) -> Self {
pub async fn new(logger1: impl EventLogger + 'static, local_port: u16) -> Self {
let db = DbConn::new(db_file().as_path())
.await
.expect("Must be able to initialize db");
let repository = repository::create(db.clone());
let background_job = background_job::create(db.clone(), local_port).await;
let repository = repository::create(db.clone(), background_job);
let logger2 = create_event_logger(db.clone());
let logger = Arc::new(ComposedLogger::new(logger1, logger2));
@ -73,28 +75,13 @@ impl WebserverHandle {
ui: Router,
code: Arc<dyn CodeSearch>,
is_chat_enabled: bool,
local_port: u16,
) -> (Router, Router) {
let (schedule_event_sender, schedule_event_receiver) =
tokio::sync::mpsc::unbounded_channel();
let ctx = create_service_locator(
self.logger(),
code,
self.repository.clone(),
self.db.clone(),
is_chat_enabled,
schedule_event_sender.clone(),
)
.await;
cron::run_cron(
schedule_event_sender,
schedule_event_receiver,
self.db.clone(),
ctx.auth(),
ctx.worker(),
ctx.repository(),
local_port,
)
.await;

View File

@ -1,7 +1,6 @@
//! Defines behavior for the tabby webserver which allows users to interact with enterprise features.
//! Using the web interface (e.g chat playground) requires using this module with the `--webserver` flag on the command line.
mod axum;
mod cron;
mod env;
mod handler;
mod hub;
@ -42,3 +41,11 @@ macro_rules! bail {
return std::result::Result::Err(anyhow::anyhow!($fmt, $($arg)*).into())
};
}
/// Emits a warning through `tracing` and mirrors the same formatted message
/// into the job's recorded stderr via `$ctx.stderr_writeline`.
/// Must be expanded inside an async context (the expansion contains `.await`).
#[macro_export]
macro_rules! warn_stderr {
    ($ctx:expr, $($params:tt)+) => {
        tracing::warn!($($params)+);
        $ctx.stderr_writeline(format!($($params)+)).await;
    }
}

View File

@ -16,3 +16,7 @@ pub fn db_file() -> PathBuf {
tabby_ee_root().join("dev-db.sqlite")
}
}
/// Returns the path of the SQLite database file backing the background-job queue.
pub fn job_queue() -> PathBuf {
    let root = tabby_ee_root();
    root.join("job_queue.sqlite")
}

View File

@ -401,8 +401,6 @@ pub trait AuthenticationService: Send + Sync {
async fn token_auth(&self, email: String, password: String) -> Result<TokenAuthResponse>;
async fn refresh_token(&self, refresh_token: String) -> Result<RefreshTokenResponse>;
async fn delete_expired_token(&self) -> Result<()>;
async fn delete_expired_password_resets(&self) -> Result<()>;
async fn verify_access_token(&self, access_token: &str) -> Result<JWTPayload>;
async fn is_admin_initialized(&self) -> Result<bool>;
async fn get_user_by_email(&self, email: &str) -> Result<User>;

View File

@ -47,9 +47,6 @@ impl relay::NodeType for JobRun {
#[async_trait]
pub trait JobService: Send + Sync {
// Schedule one job immediately.
fn schedule(&self, name: &str);
async fn list(
&self,
ids: Option<Vec<ID>>,

View File

@ -789,7 +789,6 @@ impl Mutation {
.github()
.create_provider(input.display_name, input.access_token)
.await?;
ctx.locator.job().schedule("import_github_repositories");
Ok(id)
}
@ -814,7 +813,6 @@ impl Mutation {
.github()
.update_provider(input.id, input.display_name, input.access_token)
.await?;
ctx.locator.job().schedule("import_github_repositories");
Ok(true)
}
@ -843,7 +841,6 @@ impl Mutation {
.gitlab()
.create_provider(input.display_name, input.access_token)
.await?;
ctx.locator.job().schedule("import_gitlab_repositories");
Ok(id)
}
@ -868,7 +865,6 @@ impl Mutation {
.gitlab()
.update_provider(input.id, input.display_name, input.access_token)
.await?;
ctx.locator.job().schedule("import_gitlab_repositories");
Ok(true)
}

View File

@ -1,5 +1,4 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use juniper::{GraphQLObject, ID};
use super::{RepositoryProvider, RepositoryProviderStatus};
@ -66,7 +65,6 @@ impl NodeType for GithubProvidedRepository {
#[async_trait]
pub trait GithubRepositoryService: Send + Sync + RepositoryProvider {
async fn create_provider(&self, display_name: String, access_token: String) -> Result<ID>;
async fn get_provider(&self, id: ID) -> Result<GithubRepositoryProvider>;
async fn delete_provider(&self, id: ID) -> Result<()>;
async fn update_provider(
&self,
@ -74,7 +72,6 @@ pub trait GithubRepositoryService: Send + Sync + RepositoryProvider {
display_name: String,
access_token: String,
) -> Result<()>;
async fn update_provider_status(&self, id: ID, success: bool) -> Result<()>;
async fn list_providers(
&self,
@ -95,18 +92,6 @@ pub trait GithubRepositoryService: Send + Sync + RepositoryProvider {
last: Option<usize>,
) -> Result<Vec<GithubProvidedRepository>>;
async fn upsert_repository(
&self,
provider_id: ID,
vendor_id: String,
display_name: String,
git_url: String,
) -> Result<()>;
async fn update_repository_active(&self, id: ID, active: bool) -> Result<()>;
async fn delete_outdated_repositories(
&self,
provider_id: ID,
cutoff_timestamp: DateTime<Utc>,
) -> Result<()>;
async fn list_active_git_urls(&self) -> Result<Vec<String>>;
}

View File

@ -1,5 +1,4 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use juniper::{GraphQLObject, ID};
use super::{RepositoryProvider, RepositoryProviderStatus};
@ -66,7 +65,6 @@ impl NodeType for GitlabProvidedRepository {
#[async_trait]
pub trait GitlabRepositoryService: Send + Sync + RepositoryProvider {
async fn create_provider(&self, display_name: String, access_token: String) -> Result<ID>;
async fn get_provider(&self, id: ID) -> Result<GitlabRepositoryProvider>;
async fn delete_provider(&self, id: ID) -> Result<()>;
async fn update_provider(
&self,
@ -74,7 +72,6 @@ pub trait GitlabRepositoryService: Send + Sync + RepositoryProvider {
display_name: String,
access_token: String,
) -> Result<()>;
async fn update_provider_status(&self, id: ID, success: bool) -> Result<()>;
async fn list_providers(
&self,
@ -95,18 +92,6 @@ pub trait GitlabRepositoryService: Send + Sync + RepositoryProvider {
last: Option<usize>,
) -> Result<Vec<GitlabProvidedRepository>>;
async fn upsert_repository(
&self,
provider_id: ID,
vendor_id: String,
display_name: String,
git_url: String,
) -> Result<()>;
async fn update_repository_active(&self, id: ID, active: bool) -> Result<()>;
async fn list_active_git_urls(&self) -> Result<Vec<String>>;
async fn delete_outdated_repositories(
&self,
provider_id: ID,
cutoff_timestamp: DateTime<Utc>,
) -> Result<()>;
}

View File

@ -265,16 +265,6 @@ impl AuthenticationService for AuthenticationServiceImpl {
Ok(resp)
}
async fn delete_expired_token(&self) -> Result<()> {
self.db.delete_expired_token().await?;
Ok(())
}
async fn delete_expired_password_resets(&self) -> Result<()> {
self.db.delete_expired_password_resets().await?;
Ok(())
}
async fn verify_access_token(&self, access_token: &str) -> Result<JWTPayload> {
let claims = validate_jwt(access_token).map_err(anyhow::Error::new)?;
Ok(claims)
@ -1114,19 +1104,6 @@ mod tests {
.password_reset(&reset.code, "newpass")
.await
.is_err());
service
.db
.mark_password_reset_expired(&reset.code)
.await
.unwrap();
service.delete_expired_password_resets().await.unwrap();
assert!(service
.db
.get_password_reset_by_code(&reset.code)
.await
.unwrap()
.is_none());
}
#[tokio::test]

View File

@ -0,0 +1,38 @@
use std::str::FromStr;
use apalis::{
cron::{CronStream, Schedule},
prelude::{Data, Job, Monitor, WorkerBuilder, WorkerFactoryFn},
utils::TokioExecutor,
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tabby_db::DbConn;
use tracing::debug;
/// Periodic background job that performs routine database cleanup.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DbMaintainanceJob;
impl Job for DbMaintainanceJob {
    // NOTE(review): "maintainance" is a misspelling of "maintenance"; kept
    // as-is because the string is the job's persisted identifier.
    const NAME: &'static str = "db_maintainance";
}
impl DbMaintainanceJob {
    /// Cron entry point: purges expired auth tokens and expired password resets.
    async fn cron(_now: DateTime<Utc>, db: Data<DbConn>) -> crate::schema::Result<()> {
        debug!("Running db maintainance job");
        db.delete_expired_token().await?;
        db.delete_expired_password_resets().await?;
        Ok(())
    }
    /// Registers this job on the monitor with an hourly cron schedule.
    pub fn register(monitor: Monitor<TokioExecutor>, db: DbConn) -> Monitor<TokioExecutor> {
        let schedule = Schedule::from_str("@hourly").expect("unable to parse cron schedule");
        monitor.register(
            WorkerBuilder::new(DbMaintainanceJob::NAME)
                .stream(CronStream::new(schedule).into_stream())
                .data(db)
                .build_fn(DbMaintainanceJob::cron),
        )
    }
}

View File

@ -0,0 +1,182 @@
use std::str::FromStr;
use anyhow::Result;
use apalis::{
cron::{CronStream, Schedule},
prelude::{Data, Job, Monitor, Storage, WorkerBuilder, WorkerFactoryFn},
sqlite::{SqlitePool, SqliteStorage},
utils::TokioExecutor,
};
use chrono::{DateTime, Utc};
use octocrab::{models::Repository, GitHubError, Octocrab};
use serde::{Deserialize, Serialize};
use tabby_db::{DbConn, GithubRepositoryProviderDAO};
use tower::limit::ConcurrencyLimitLayer;
use tracing::debug;
use super::logger::{JobLogLayer, JobLogger};
use crate::warn_stderr;
/// Background job that imports the repositories of one GitHub provider,
/// identified by its database row id.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SyncGithubJob {
    provider_id: i64,
}
impl SyncGithubJob {
    /// Creates a sync job for the provider with the given row id.
    pub fn new(provider_id: i64) -> Self {
        Self { provider_id }
    }
}
impl Job for SyncGithubJob {
    const NAME: &'static str = "import_github_repositories";
}
impl SyncGithubJob {
    /// Queue-worker entry point: syncs the single provider this job was queued for.
    async fn run(self, logger: Data<JobLogger>, db: Data<DbConn>) -> crate::schema::Result<()> {
        refresh_repositories_for_provider((*logger).clone(), (*db).clone(), self.provider_id)
            .await?;
        Ok(())
    }
    /// Hourly cron entry point: enqueues one sync job per configured GitHub provider.
    async fn cron(
        _now: DateTime<Utc>,
        storage: Data<SqliteStorage<SyncGithubJob>>,
        db: Data<DbConn>,
    ) -> crate::schema::Result<()> {
        debug!("Syncing all github providers");
        let mut storage = (*storage).clone();
        for provider in db
            .list_github_repository_providers(vec![], None, None, false)
            .await?
        {
            storage
                .push(SyncGithubJob::new(provider.id))
                .await
                .expect("unable to push job");
        }
        Ok(())
    }
    /// Registers two workers on the monitor: a queue consumer limited to one
    /// concurrent job (with per-run log capture) and the hourly cron producer.
    /// Returns the storage handle so callers can enqueue jobs on demand.
    pub fn register(
        monitor: Monitor<TokioExecutor>,
        pool: SqlitePool,
        db: DbConn,
    ) -> (SqliteStorage<SyncGithubJob>, Monitor<TokioExecutor>) {
        let storage = SqliteStorage::new(pool);
        let schedule = Schedule::from_str("@hourly").expect("unable to parse cron schedule");
        let monitor = monitor
            .register(
                WorkerBuilder::new(Self::NAME)
                    .with_storage(storage.clone())
                    .layer(ConcurrencyLimitLayer::new(1))
                    .layer(JobLogLayer::new(db.clone(), Self::NAME))
                    .data(db.clone())
                    .build_fn(Self::run),
            )
            .register(
                WorkerBuilder::new(format!("{}-cron", Self::NAME))
                    .stream(CronStream::new(schedule).into_stream())
                    .data(storage.clone())
                    .data(db.clone())
                    .build_fn(Self::cron),
            );
        (storage, monitor)
    }
}
/// Fetches all repositories for one GitHub provider and mirrors them into the
/// local database, updating the provider's sync status on the way out.
async fn refresh_repositories_for_provider(
    context: JobLogger,
    db: DbConn,
    provider_id: i64,
) -> Result<()> {
    let provider = db.get_github_provider(provider_id).await?;
    let repos = match fetch_all_repos(&provider).await {
        Ok(repos) => repos,
        // A 4xx from GitHub means the stored credentials are bad; mark the
        // provider's sync as failed before bailing out.
        Err(octocrab::Error::GitHub {
            source: source @ GitHubError { .. },
            ..
        }) if source.status_code.is_client_error() => {
            db.update_github_provider_sync_status(provider_id, false)
                .await?;
            warn_stderr!(
                context,
                "GitHub credentials for provider {} are expired or invalid",
                provider.display_name
            );
            return Err(source.into());
        }
        Err(e) => {
            warn_stderr!(context, "Failed to fetch repositories from github: {e}");
            return Err(e.into());
        }
    };
    for repo in repos {
        context
            .stdout_writeline(format!(
                "importing: {}",
                repo.full_name.as_deref().unwrap_or(&repo.name)
            ))
            .await;
        let id = repo.id.to_string();
        // Repositories without a git URL cannot be cloned; skip them.
        let Some(url) = repo.git_url else {
            continue;
        };
        let url = url.to_string();
        // Normalize `git://host/...` to `https://host/...` and drop a trailing `.git`.
        let url = url
            .strip_prefix("git://")
            .map(|url| format!("https://{url}"))
            .unwrap_or(url);
        let url = url.strip_suffix(".git").unwrap_or(&url);
        db.upsert_github_provided_repository(
            provider_id,
            id,
            repo.full_name.unwrap_or(repo.name),
            url.to_string(),
        )
        .await?;
    }
    db.update_github_provider_sync_status(provider_id, true)
        .await?;
    // NOTE(review): unlike the GitLab sync, repositories deleted upstream are
    // never pruned here — confirm whether delete_outdated_github_repositories
    // should also be called.
    Ok(())
}
// FIXME(wsxiaoys): Convert to async stream
/// Fetches every repository visible to the provider's access token, following
/// GitHub's pagination. Returns an empty list when no token is configured.
async fn fetch_all_repos(
    provider: &GithubRepositoryProviderDAO,
) -> Result<Vec<Repository>, octocrab::Error> {
    let Some(token) = &provider.access_token else {
        return Ok(vec![]);
    };
    let octocrab = Octocrab::builder()
        .user_access_token(token.to_string())
        .build()?;
    let mut page: u8 = 1;
    let mut repos = vec![];
    loop {
        let response = octocrab
            .current()
            .list_repos_for_authenticated_user()
            .visibility("all")
            .page(page)
            .send()
            .await?;
        // NOTE(review): the u8 counter caps pagination at 255 pages — confirm
        // against the octocrab API whether a wider integer can be passed.
        let pages = response.number_of_pages().unwrap_or_default() as u8;
        repos.extend(response.items);
        // Compare before incrementing so `page += 1` cannot overflow the u8
        // after the final (possibly 255th) page has been fetched.
        if page >= pages {
            break;
        }
        page += 1;
    }
    Ok(repos)
}

View File

@ -0,0 +1,194 @@
use std::str::FromStr;
use anyhow::Result;
use apalis::{
cron::{CronStream, Schedule},
prelude::{Data, Job, Monitor, Storage, WorkerBuilder, WorkerFactoryFn},
sqlite::{SqlitePool, SqliteStorage},
utils::TokioExecutor,
};
use chrono::{DateTime, Utc};
use gitlab::{
api::{projects::Projects, ApiError, AsyncQuery, Pagination},
GitlabBuilder,
};
use serde::{Deserialize, Serialize};
use tabby_db::{DbConn, GitlabRepositoryProviderDAO};
use tower::limit::ConcurrencyLimitLayer;
use tracing::debug;
use super::logger::{JobLogLayer, JobLogger};
use crate::warn_stderr;
/// Background job that imports the projects of one GitLab provider,
/// identified by its database row id.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SyncGitlabJob {
    provider_id: i64,
}
impl SyncGitlabJob {
    /// Creates a sync job for the provider with the given row id.
    pub fn new(provider_id: i64) -> Self {
        Self { provider_id }
    }
}
impl Job for SyncGitlabJob {
    const NAME: &'static str = "import_gitlab_repositories";
}
impl SyncGitlabJob {
    /// Queue-worker entry point: syncs the single provider this job was queued for.
    async fn run(self, logger: Data<JobLogger>, db: Data<DbConn>) -> crate::schema::Result<()> {
        refresh_repositories_for_provider((*logger).clone(), (*db).clone(), self.provider_id)
            .await?;
        Ok(())
    }
    /// Hourly cron entry point: enqueues one sync job per configured GitLab provider.
    async fn cron(
        _now: DateTime<Utc>,
        storage: Data<SqliteStorage<SyncGitlabJob>>,
        db: Data<DbConn>,
    ) -> crate::schema::Result<()> {
        debug!("Syncing all gitlab providers");
        let mut storage = (*storage).clone();
        for provider in db
            .list_gitlab_repository_providers(vec![], None, None, false)
            .await?
        {
            storage
                .push(SyncGitlabJob::new(provider.id))
                .await
                .expect("unable to push job");
        }
        Ok(())
    }
    /// Registers two workers on the monitor: a queue consumer limited to one
    /// concurrent job (with per-run log capture) and the hourly cron producer.
    /// Returns the storage handle so callers can enqueue jobs on demand.
    pub fn register(
        monitor: Monitor<TokioExecutor>,
        pool: SqlitePool,
        db: DbConn,
    ) -> (SqliteStorage<SyncGitlabJob>, Monitor<TokioExecutor>) {
        let storage = SqliteStorage::new(pool);
        let schedule = Schedule::from_str("@hourly").expect("unable to parse cron schedule");
        let monitor = monitor
            .register(
                WorkerBuilder::new(Self::NAME)
                    .with_storage(storage.clone())
                    .layer(ConcurrencyLimitLayer::new(1))
                    .layer(JobLogLayer::new(db.clone(), Self::NAME))
                    .data(db.clone())
                    .build_fn(Self::run),
            )
            .register(
                WorkerBuilder::new(format!("{}-cron", Self::NAME))
                    .stream(CronStream::new(schedule).into_stream())
                    .data(storage.clone())
                    .data(db.clone())
                    .build_fn(Self::cron),
            );
        (storage, monitor)
    }
}
}
/// Fetches all projects for one GitLab provider and mirrors them into the
/// local database, pruning entries that disappeared upstream.
async fn refresh_repositories_for_provider(logger: JobLogger, db: DbConn, id: i64) -> Result<()> {
    let provider = db.get_gitlab_provider(id).await?;
    logger
        .stdout_writeline(format!(
            "Refreshing repositories for provider: {}\n",
            provider.display_name
        ))
        .await;
    // Timestamp taken before the fetch: anything not re-upserted after this
    // instant is treated as outdated and deleted below.
    let start = Utc::now();
    let repos = match fetch_all_repos(&provider).await {
        Ok(repos) => repos,
        // Auth-style failures mean the stored token is bad; flag the provider.
        Err(e) if e.is_client_error() => {
            db.update_gitlab_provider_sync_status(id, false).await?;
            warn_stderr!(
                logger,
                "GitLab credentials for provider {} are expired or invalid",
                provider.display_name
            );
            return Err(e.into());
        }
        Err(e) => {
            warn_stderr!(logger, "Failed to fetch repositories from gitlab: {e}");
            return Err(e.into());
        }
    };
    for repo in repos {
        logger
            .stdout_writeline(format!("importing: {}", &repo.path_with_namespace))
            .await;
        let id = repo.id.to_string();
        // Drop a trailing `.git` from the clone URL.
        let url = repo.http_url_to_repo;
        let url = url.strip_suffix(".git").unwrap_or(&url);
        db.upsert_gitlab_provided_repository(
            provider.id,
            id,
            repo.path_with_namespace,
            url.to_string(),
        )
        .await?;
    }
    // NOTE(review): `id` here is the *provider* row id, yet this call appears
    // to update a repository's active flag — confirm the intended argument
    // (the GitHub counterpart sets the provider sync status instead).
    db.update_gitlab_provided_repository_active(id, true)
        .await?;
    db.delete_outdated_gitlab_repositories(id, start.into())
        .await?;
    Ok(())
}
/// Minimal projection of a GitLab project from the REST API; only the fields
/// needed for import are deserialized.
#[derive(Deserialize)]
struct Repository {
    id: u128,
    path_with_namespace: String,
    http_url_to_repo: String,
}
/// Aggregates the distinct gitlab-crate error types that can surface while
/// building the client, constructing the query, or executing it.
#[derive(thiserror::Error, Debug)]
enum GitlabError {
    #[error(transparent)]
    Rest(#[from] gitlab::api::ApiError<gitlab::RestError>),
    #[error(transparent)]
    Gitlab(#[from] gitlab::GitlabError),
    #[error(transparent)]
    Projects(#[from] gitlab::api::projects::ProjectsBuilderError),
}
impl GitlabError {
    /// True when the failure is caller-side (bad or expired credentials,
    /// or an API-reported error) rather than transient, so retrying without
    /// reconfiguration is pointless.
    fn is_client_error(&self) -> bool {
        // Both variants ultimately wrap the same ApiError; extract it first.
        let source = match self {
            GitlabError::Rest(source) => source,
            GitlabError::Gitlab(gitlab::GitlabError::Api { source }) => source,
            _ => return false,
        };
        matches!(
            source,
            ApiError::Auth { .. }
                | ApiError::Client {
                    source: gitlab::RestError::AuthError { .. }
                }
                | ApiError::Gitlab { .. }
        )
    }
}
/// Lists every GitLab project the provider's token has membership in.
/// Returns an empty list when the provider has no access token configured.
async fn fetch_all_repos(
    provider: &GitlabRepositoryProviderDAO,
) -> Result<Vec<Repository>, GitlabError> {
    let token = match &provider.access_token {
        Some(token) => token,
        None => return Ok(vec![]),
    };
    let client = GitlabBuilder::new("gitlab.com", token)
        .build_async()
        .await?;
    let endpoint = Projects::builder().membership(true).build()?;
    let repos = gitlab::api::paged(endpoint, Pagination::All)
        .query_async(&client)
        .await?;
    Ok(repos)
}

View File

@ -0,0 +1,143 @@
use std::{
fmt::Debug,
task::{Context, Poll},
};
use apalis::prelude::Request;
use futures::{future::BoxFuture, FutureExt};
use tabby_db::DbConn;
use tower::{Layer, Service};
use tracing::{debug, warn};
/// Records a single job run's stdout, stderr, and exit status into the
/// database, keyed by the job-run row created at construction time.
#[derive(Clone)]
pub struct JobLogger {
    id: i64,
    db: DbConn,
}
impl JobLogger {
    /// Inserts a new job-run row for `name` and returns a logger bound to it.
    async fn new(name: &'static str, db: DbConn) -> Self {
        let id = db
            .create_job_run(name.to_owned())
            .await
            .expect("failed to create job");
        Self { id, db }
    }
    /// Appends one newline-terminated line to the job's recorded stdout.
    pub async fn stdout_writeline(&self, stdout: String) {
        let line = stdout + "\n";
        if self.db.update_job_stdout(self.id, line).await.is_err() {
            warn!("Failed to write stdout to job `{}`", self.id);
        }
    }
    /// Appends one newline-terminated line to the job's recorded stderr.
    pub async fn stderr_writeline(&self, stderr: String) {
        let line = stderr + "\n";
        if self.db.update_job_stderr(self.id, line).await.is_err() {
            warn!("Failed to write stderr to job `{}`", self.id);
        }
    }
    /// Records the job's final exit code.
    async fn complete(&mut self, exit_code: i32) {
        if self.db.update_job_status(self.id, exit_code).await.is_err() {
            warn!("Failed to complete job `{}`", self.id);
        }
    }
}
/// Tower layer that wraps a job worker with [`JobLogService`], giving each
/// job execution a database-backed log and a recorded exit status.
pub struct JobLogLayer {
    db: DbConn,
    name: &'static str,
}
impl JobLogLayer {
    /// `name` is the job name new job-run rows will be created under.
    pub fn new(db: DbConn, name: &'static str) -> Self {
        Self { db, name }
    }
}
impl<S> Layer<S> for JobLogLayer {
    type Service = JobLogService<S>;
    fn layer(&self, service: S) -> Self::Service {
        JobLogService {
            db: self.db.clone(),
            name: self.name,
            service,
        }
    }
}
/// Middleware that creates a [`JobLogger`] per request and records completion.
#[derive(Clone)]
pub struct JobLogService<S> {
    db: DbConn,
    name: &'static str,
    service: S,
}
/// Maps a job handler's return value onto a numeric exit code, mirroring
/// process exit-status conventions (0 means success).
pub trait ExitCode {
    fn into_exit_code(self) -> i32;
}
/// A unit result carries no failure information, so it maps to success (0).
impl ExitCode for () {
    fn into_exit_code(self) -> i32 {
        0
    }
}
/// An integer result is already an exit code and is passed through unchanged.
impl ExitCode for i32 {
    fn into_exit_code(self) -> i32 {
        self
    }
}
impl<S, Req> Service<Request<Req>> for JobLogService<S>
where
    S: Service<Request<Req>> + Clone,
    Request<Req>: Send + 'static,
    S: Send + 'static,
    S::Future: Send + 'static,
    S::Response: Send + ExitCode + 'static,
    S::Error: Send + Debug + 'static,
{
    // The inner response is consumed to derive an exit code; callers only
    // observe success or failure.
    type Response = ();
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx)
    }
    /// Runs the inner service surrounded by job bookkeeping: a job-run row is
    /// created before the call (and exposed to the handler via the request
    /// extensions), and the exit status is recorded afterwards (-1 on error).
    fn call(&mut self, mut request: Request<Req>) -> Self::Future {
        debug!("Starting job `{}`", self.name);
        let name = self.name;
        let db = self.db.clone();
        // Take the instance that `poll_ready` was driven on and leave a fresh
        // clone in its place: per the `tower::Service` docs a clone is not
        // guaranteed to be ready, so the original must be the one called.
        let clone = self.service.clone();
        let mut service = std::mem::replace(&mut self.service, clone);
        let fut_with_log = async move {
            let mut logger = JobLogger::new(name, db).await;
            request.insert(logger.clone());
            match service.call(request).await {
                Ok(res) => {
                    debug!("Job `{}` completed", name);
                    logger.complete(res.into_exit_code()).await;
                    Ok(())
                }
                Err(e) => {
                    warn!("Job `{}` failed: {:?}", name, e);
                    logger.complete(-1).await;
                    Err(e)
                }
            }
        };
        fut_with_log.boxed()
    }
}

View File

@ -0,0 +1,84 @@
mod db;
mod github;
mod gitlab;
mod logger;
mod scheduler;
use std::sync::Arc;
use apalis::{
prelude::{Monitor, Storage},
sqlite::{SqlitePool, SqliteStorage},
};
use async_trait::async_trait;
use tabby_db::DbConn;
use self::{
db::DbMaintainanceJob, github::SyncGithubJob, gitlab::SyncGitlabJob, scheduler::SchedulerJob,
};
use crate::path::job_queue;
/// Handle for triggering asynchronous background work from request handlers.
#[async_trait]
pub trait BackgroundJob: Send + Sync {
    /// Enqueues a repository sync for the GitHub provider with this row id.
    async fn trigger_sync_github(&self, provider_id: i64);
    /// Enqueues a repository sync for the GitLab provider with this row id.
    async fn trigger_sync_gitlab(&self, provider_id: i64);
}
/// Production implementation backed by apalis SQLite job-queue storages.
struct BackgroundJobImpl {
    gitlab: SqliteStorage<SyncGitlabJob>,
    github: SqliteStorage<SyncGithubJob>,
}
/// Builds the production [`BackgroundJob`] handle.
///
/// Opens the SQLite job-queue database (creating it on first run), registers
/// the maintenance, scheduler, GitLab-sync and GitHub-sync workers on a single
/// apalis monitor, spawns that monitor as a detached tokio task, and returns a
/// handle for enqueueing sync jobs on demand.
///
/// Panics if the queue database cannot be opened or migrated.
pub async fn create(db: DbConn, local_port: u16) -> Arc<dyn BackgroundJob> {
    // `mode=rwc` asks SQLite to create the database file if it does not exist.
    let path = format!("sqlite://{}?mode=rwc", job_queue().display());
    let pool = SqlitePool::connect(&path)
        .await
        .expect("unable to create sqlite pool");
    SqliteStorage::setup(&pool)
        .await
        .expect("unable to run migrations for sqlite");
    let monitor = Monitor::new();
    let monitor = DbMaintainanceJob::register(monitor, db.clone());
    let monitor = SchedulerJob::register(monitor, db.clone(), local_port);
    let (gitlab, monitor) = SyncGitlabJob::register(monitor, pool.clone(), db.clone());
    let (github, monitor) = SyncGithubJob::register(monitor, pool.clone(), db.clone());
    tokio::spawn(async move {
        monitor.run().await.expect("failed to start worker");
    });
    Arc::new(BackgroundJobImpl { gitlab, github })
}
/// No-op implementation for unit tests; trigger calls are silently dropped.
struct FakeBackgroundJob;
#[async_trait]
impl BackgroundJob for FakeBackgroundJob {
    async fn trigger_sync_github(&self, _provider_id: i64) {}
    async fn trigger_sync_gitlab(&self, _provider_id: i64) {}
}
// NOTE(review): only the constructor is gated on cfg(test); the struct and
// impl above also compile into non-test builds — confirm this is intended.
#[cfg(test)]
pub fn create_fake() -> Arc<dyn BackgroundJob> {
    Arc::new(FakeBackgroundJob)
}
#[async_trait]
impl BackgroundJob for BackgroundJobImpl {
    /// Pushes a GitHub sync job onto the persistent queue; panics if the
    /// queue database rejects the write.
    async fn trigger_sync_github(&self, provider_id: i64) {
        let mut queue = self.github.clone();
        queue
            .push(SyncGithubJob::new(provider_id))
            .await
            .expect("unable to push job");
    }
    /// Pushes a GitLab sync job onto the persistent queue; panics if the
    /// queue database rejects the write.
    async fn trigger_sync_gitlab(&self, provider_id: i64) {
        let mut queue = self.gitlab.clone();
        queue
            .push(SyncGitlabJob::new(provider_id))
            .await
            .expect("unable to push job");
    }
}

View File

@ -0,0 +1,101 @@
use std::{process::Stdio, str::FromStr};
use anyhow::Context;
use apalis::{
cron::{CronStream, Schedule},
prelude::{Data, Job, Monitor, WorkerBuilder, WorkerFactoryFn},
utils::TokioExecutor,
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tabby_db::DbConn;
use tokio::io::AsyncBufReadExt;
use super::logger::{JobLogLayer, JobLogger};
/// Cron-driven job that re-runs the `scheduler` subcommand of the current
/// executable against the locally running server.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct SchedulerJob {}
impl Job for SchedulerJob {
    const NAME: &'static str = "scheduler";
}
impl SchedulerJob {
    /// Spawns `<current-exe> scheduler --now` as a child process, streams its
    /// stdout/stderr into the job log, and returns the child's exit code
    /// (-1 when the process terminated without one, e.g. killed by a signal).
    async fn run_impl(
        self,
        job_logger: Data<JobLogger>,
        db: Data<DbConn>,
        local_port: Data<u16>,
    ) -> anyhow::Result<i32> {
        let local_port = *local_port;
        let exe = std::env::current_exe()?;
        // The registration token from the db authenticates the child against
        // the local server.
        let mut child = tokio::process::Command::new(exe)
            .arg("scheduler")
            .arg("--now")
            .arg("--url")
            .arg(format!("localhost:{local_port}"))
            .arg("--token")
            .arg(db.read_registration_token().await?)
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()?;
        {
            // Pipe stdout
            let stdout = child.stdout.take().context("Failed to acquire stdout")?;
            let logger = job_logger.clone();
            tokio::spawn(async move {
                let stdout = tokio::io::BufReader::new(stdout);
                let mut stdout = stdout.lines();
                while let Ok(Some(line)) = stdout.next_line().await {
                    let _ = logger.stdout_writeline(line).await;
                }
            });
        }
        {
            // Pipe stderr (the reader variable is named `stdout` but reads stderr).
            let stderr = child.stderr.take().context("Failed to acquire stderr")?;
            let logger = job_logger.clone();
            tokio::spawn(async move {
                let stderr = tokio::io::BufReader::new(stderr);
                let mut stdout = stderr.lines();
                while let Ok(Some(line)) = stdout.next_line().await {
                    let _ = logger.stderr_writeline(line).await;
                }
            });
        }
        if let Some(exit_code) = child.wait().await.ok().and_then(|s| s.code()) {
            Ok(exit_code)
        } else {
            Ok(-1)
        }
    }
    /// Cron entry point; thin wrapper that constructs the job and delegates.
    async fn cron(
        _now: DateTime<Utc>,
        logger: Data<JobLogger>,
        db: Data<DbConn>,
        local_port: Data<u16>,
    ) -> crate::schema::Result<i32> {
        let job = SchedulerJob {};
        Ok(job.run_impl(logger, db, local_port).await?)
    }
    /// Registers the job to run every 10 minutes, with per-run log capture.
    pub fn register(
        monitor: Monitor<TokioExecutor>,
        db: DbConn,
        local_port: u16,
    ) -> Monitor<TokioExecutor> {
        let schedule = Schedule::from_str("0 */10 * * * *").expect("unable to parse cron schedule");
        monitor.register(
            WorkerBuilder::new(SchedulerJob::NAME)
                .stream(CronStream::new(schedule).into_stream())
                .layer(JobLogLayer::new(db.clone(), SchedulerJob::NAME))
                .data(db)
                .data(local_port)
                .build_fn(SchedulerJob::cron),
        )
    }
}

View File

@ -1,7 +1,6 @@
use async_trait::async_trait;
use juniper::ID;
use tabby_db::DbConn;
use tracing::{debug, error};
use super::{graphql_pagination_to_filter, AsRowid};
use crate::schema::{
@ -11,22 +10,14 @@ use crate::schema::{
struct JobControllerImpl {
db: DbConn,
sender: tokio::sync::mpsc::UnboundedSender<String>,
}
pub fn create(db: DbConn, sender: tokio::sync::mpsc::UnboundedSender<String>) -> impl JobService {
JobControllerImpl { db, sender }
pub async fn create(db: DbConn) -> impl JobService {
JobControllerImpl { db }
}
#[async_trait]
impl JobService for JobControllerImpl {
fn schedule(&self, name: &str) {
debug!("scheduling job: {}", name);
if let Err(e) = self.sender.send(name.to_owned()) {
error!("failed to send job to scheduler: {}", e);
}
}
async fn list(
&self,
ids: Option<Vec<ID>>,

View File

@ -1,5 +1,6 @@
mod analytic;
mod auth;
pub mod background_job;
mod dao;
mod email;
pub mod event_logger;
@ -74,7 +75,6 @@ impl ServerContext {
repository: Arc<dyn RepositoryService>,
db_conn: DbConn,
is_chat_enabled_locally: bool,
schedule_event_sender: tokio::sync::mpsc::UnboundedSender<String>,
) -> Self {
let mail = Arc::new(
new_email_service(db_conn.clone())
@ -87,7 +87,7 @@ impl ServerContext {
.expect("failed to initialize license service"),
);
let user_event = Arc::new(user_event::create(db_conn.clone()));
let job = Arc::new(job::create(db_conn.clone(), schedule_event_sender));
let job = Arc::new(job::create(db_conn.clone()).await);
Self {
client: Client::default(),
completion: worker::WorkerGroup::default(),
@ -326,18 +326,9 @@ pub async fn create_service_locator(
repository: Arc<dyn RepositoryService>,
db: DbConn,
is_chat_enabled: bool,
schedule_event_sender: tokio::sync::mpsc::UnboundedSender<String>,
) -> Arc<dyn ServiceLocator> {
Arc::new(Arc::new(
ServerContext::new(
logger,
code,
repository,
db,
is_chat_enabled,
schedule_event_sender,
)
.await,
ServerContext::new(logger, code, repository, db, is_chat_enabled).await,
))
}

View File

@ -1,7 +1,9 @@
use std::collections::{HashMap, HashSet};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use juniper::ID;
use tabby_db::DbConn;
use url::Url;
@ -14,15 +16,16 @@ use crate::{
},
Result,
},
service::{graphql_pagination_to_filter, AsID, AsRowid},
service::{background_job::BackgroundJob, graphql_pagination_to_filter, AsID, AsRowid},
};
struct GithubRepositoryProviderServiceImpl {
db: DbConn,
background: Arc<dyn BackgroundJob>,
}
pub fn create(db: DbConn) -> impl GithubRepositoryService {
GithubRepositoryProviderServiceImpl { db }
pub fn create(db: DbConn, background: Arc<dyn BackgroundJob>) -> impl GithubRepositoryService {
GithubRepositoryProviderServiceImpl { db, background }
}
#[async_trait]
@ -32,14 +35,10 @@ impl GithubRepositoryService for GithubRepositoryProviderServiceImpl {
.db
.create_github_provider(display_name, access_token)
.await?;
self.background.trigger_sync_github(id).await;
Ok(id.as_id())
}
async fn get_provider(&self, id: ID) -> Result<GithubRepositoryProvider> {
let provider = self.db.get_github_provider(id.as_rowid()?).await?;
Ok(provider.into())
}
async fn delete_provider(&self, id: ID) -> Result<()> {
self.db.delete_github_provider(id.as_rowid()?).await?;
Ok(())
@ -95,24 +94,6 @@ impl GithubRepositoryService for GithubRepositoryProviderServiceImpl {
.collect())
}
async fn upsert_repository(
&self,
provider_id: ID,
vendor_id: String,
display_name: String,
git_url: String,
) -> Result<()> {
self.db
.upsert_github_provided_repository(
provider_id.as_rowid()?,
vendor_id,
display_name,
git_url,
)
.await?;
Ok(())
}
async fn update_repository_active(&self, id: ID, active: bool) -> Result<()> {
self.db
.update_github_provided_repository_active(id.as_rowid()?, active)
@ -126,9 +107,11 @@ impl GithubRepositoryService for GithubRepositoryProviderServiceImpl {
display_name: String,
access_token: String,
) -> Result<()> {
let id = id.as_rowid()?;
self.db
.update_github_provider(id.as_rowid()?, display_name, access_token)
.update_github_provider(id, display_name, access_token)
.await?;
self.background.trigger_sync_github(id).await;
Ok(())
}
@ -158,24 +141,6 @@ impl GithubRepositoryService for GithubRepositoryProviderServiceImpl {
Ok(urls)
}
async fn delete_outdated_repositories(
&self,
provider_id: ID,
cutoff_timestamp: DateTime<Utc>,
) -> Result<()> {
self.db
.delete_outdated_github_repositories(provider_id.as_rowid()?, cutoff_timestamp.into())
.await?;
Ok(())
}
async fn update_provider_status(&self, id: ID, success: bool) -> Result<()> {
self.db
.update_github_provider_sync_status(id.as_rowid()?, success)
.await?;
Ok(())
}
}
#[async_trait]
@ -206,15 +171,14 @@ fn deduplicate_github_repositories(repositories: &mut Vec<GithubProvidedReposito
#[cfg(test)]
mod tests {
use chrono::Duration;
use super::*;
use crate::{schema::repository::RepositoryProviderStatus, service::AsID};
use crate::{background_job::create_fake, service::AsID};
#[tokio::test]
async fn test_github_provided_repositories() {
let db = DbConn::new_in_memory().await.unwrap();
let service = create(db.clone());
let service = create(db.clone(), create_fake());
let provider_id1 = db
.create_github_provider("test_id1".into(), "test_secret".into())
@ -301,25 +265,13 @@ mod tests {
#[tokio::test]
async fn test_github_repository_provider_crud() {
let db = DbConn::new_in_memory().await.unwrap();
let service = super::create(db.clone());
let service = create(db.clone(), create_fake());
let id = service
.create_provider("id".into(), "secret".into())
.await
.unwrap();
// Test retrieving github provider by ID
let provider1 = service.get_provider(id.clone()).await.unwrap();
assert_eq!(
provider1,
GithubRepositoryProvider {
id: id.clone(),
display_name: "id".into(),
access_token: Some("secret".into()),
status: RepositoryProviderStatus::Pending,
}
);
// Test listing github providers
let providers = service
.list_providers(vec![], None, None, None, None)
@ -344,7 +296,7 @@ mod tests {
#[tokio::test]
async fn test_provided_git_urls() {
let db = DbConn::new_in_memory().await.unwrap();
let service = create(db.clone());
let service = create(db.clone(), create_fake());
let provider_id = db
.create_github_provider("provider1".into(), "token".into())
@ -371,87 +323,4 @@ mod tests {
["https://token@github.com/TabbyML/tabby".to_string()]
);
}
#[tokio::test]
async fn test_sync_status() {
let db = DbConn::new_in_memory().await.unwrap();
let service = create(db.clone());
let provider_id = db
.create_github_provider("provider1".into(), "token".into())
.await
.unwrap();
service
.update_provider_status(provider_id.as_id(), true)
.await
.unwrap();
let provider = db.get_github_provider(provider_id).await.unwrap();
assert!(provider.access_token.is_some());
assert!(provider.synced_at.is_some());
service
.update_provider_status(provider_id.as_id(), false)
.await
.unwrap();
let provider = db.get_github_provider(provider_id).await.unwrap();
assert!(provider.access_token.is_none());
assert!(provider.synced_at.is_none());
}
#[tokio::test]
async fn test_delete_outdated_repos() {
let db = DbConn::new_in_memory().await.unwrap();
let service = create(db.clone());
let time = Utc::now();
let provider_id = db
.create_github_provider("provider1".into(), "secret1".into())
.await
.unwrap();
let _repo_id = db
.upsert_github_provided_repository(
provider_id,
"vendor_id1".into(),
"test_repo".into(),
"https://github.com/TabbyML/tabby".into(),
)
.await
.unwrap();
service
.delete_outdated_repositories(provider_id.as_id(), time)
.await
.unwrap();
assert_eq!(
1,
service
.list_repositories(vec![], None, None, None, None, None)
.await
.unwrap()
.len()
);
let time = time + Duration::minutes(1);
service
.delete_outdated_repositories(provider_id.as_id(), time)
.await
.unwrap();
assert_eq!(
0,
service
.list_repositories(vec![], None, None, None, None, None)
.await
.unwrap()
.len()
);
}
}

View File

@ -1,7 +1,9 @@
use std::collections::{HashMap, HashSet};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use juniper::ID;
use tabby_db::DbConn;
use url::Url;
@ -14,15 +16,16 @@ use crate::{
},
Result,
},
service::{graphql_pagination_to_filter, AsID, AsRowid},
service::{background_job::BackgroundJob, graphql_pagination_to_filter, AsID, AsRowid},
};
struct GitlabRepositoryProviderServiceImpl {
db: DbConn,
background_job: Arc<dyn BackgroundJob>,
}
pub fn create(db: DbConn) -> impl GitlabRepositoryService {
GitlabRepositoryProviderServiceImpl { db }
pub fn create(db: DbConn, background_job: Arc<dyn BackgroundJob>) -> impl GitlabRepositoryService {
GitlabRepositoryProviderServiceImpl { db, background_job }
}
#[async_trait]
@ -32,14 +35,10 @@ impl GitlabRepositoryService for GitlabRepositoryProviderServiceImpl {
.db
.create_gitlab_provider(display_name, access_token)
.await?;
self.background_job.trigger_sync_gitlab(id).await;
Ok(id.as_id())
}
async fn get_provider(&self, id: ID) -> Result<GitlabRepositoryProvider> {
let provider = self.db.get_gitlab_provider(id.as_rowid()?).await?;
Ok(provider.into())
}
async fn delete_provider(&self, id: ID) -> Result<()> {
self.db.delete_gitlab_provider(id.as_rowid()?).await?;
Ok(())
@ -95,24 +94,6 @@ impl GitlabRepositoryService for GitlabRepositoryProviderServiceImpl {
.collect())
}
async fn upsert_repository(
&self,
provider_id: ID,
vendor_id: String,
display_name: String,
git_url: String,
) -> Result<()> {
self.db
.upsert_gitlab_provided_repository(
provider_id.as_rowid()?,
vendor_id,
display_name,
git_url,
)
.await?;
Ok(())
}
async fn update_repository_active(&self, id: ID, active: bool) -> Result<()> {
self.db
.update_gitlab_provided_repository_active(id.as_rowid()?, active)
@ -126,9 +107,11 @@ impl GitlabRepositoryService for GitlabRepositoryProviderServiceImpl {
display_name: String,
access_token: String,
) -> Result<()> {
let id = id.as_rowid()?;
self.db
.update_gitlab_provider(id.as_rowid()?, display_name, access_token)
.update_gitlab_provider(id, display_name, access_token)
.await?;
self.background_job.trigger_sync_gitlab(id).await;
Ok(())
}
@ -161,24 +144,6 @@ impl GitlabRepositoryService for GitlabRepositoryProviderServiceImpl {
Ok(urls)
}
async fn delete_outdated_repositories(
&self,
provider_id: ID,
cutoff_timestamp: DateTime<Utc>,
) -> Result<()> {
self.db
.delete_outdated_gitlab_repositories(provider_id.as_rowid()?, cutoff_timestamp.into())
.await?;
Ok(())
}
async fn update_provider_status(&self, id: ID, success: bool) -> Result<()> {
self.db
.update_gitlab_provider_sync_status(id.as_rowid()?, success)
.await?;
Ok(())
}
}
fn deduplicate_gitlab_repositories(repositories: &mut Vec<GitlabProvidedRepository>) {
@ -209,15 +174,14 @@ impl RepositoryProvider for GitlabRepositoryProviderServiceImpl {
#[cfg(test)]
mod tests {
use chrono::Duration;
use super::*;
use crate::{schema::repository::RepositoryProviderStatus, service::AsID};
use crate::{background_job::create_fake, service::AsID};
#[tokio::test]
async fn test_gitlab_provided_repositories() {
let db = DbConn::new_in_memory().await.unwrap();
let service = create(db.clone());
let service = create(db.clone(), create_fake());
let provider_id1 = db
.create_gitlab_provider("test_id1".into(), "test_secret".into())
@ -304,25 +268,13 @@ mod tests {
#[tokio::test]
async fn test_gitlab_repository_provider_crud() {
let db = DbConn::new_in_memory().await.unwrap();
let service = super::create(db.clone());
let service = create(db.clone(), create_fake());
let id = service
.create_provider("id".into(), "secret".into())
.await
.unwrap();
// Test retrieving gitlab provider by ID
let provider1 = service.get_provider(id.clone()).await.unwrap();
assert_eq!(
provider1,
GitlabRepositoryProvider {
id: id.clone(),
display_name: "id".into(),
access_token: Some("secret".into()),
status: RepositoryProviderStatus::Pending
}
);
// Test listing gitlab providers
let providers = service
.list_providers(vec![], None, None, None, None)
@ -344,41 +296,10 @@ mod tests {
);
}
#[tokio::test]
async fn test_sync_status() {
let db = DbConn::new_in_memory().await.unwrap();
let service = create(db.clone());
let provider_id = db
.create_gitlab_provider("provider1".into(), "token".into())
.await
.unwrap();
service
.update_provider_status(provider_id.as_id(), true)
.await
.unwrap();
let provider = db.get_gitlab_provider(provider_id).await.unwrap();
assert!(provider.access_token.is_some());
assert!(provider.synced_at.is_some());
service
.update_provider_status(provider_id.as_id(), false)
.await
.unwrap();
let provider = db.get_gitlab_provider(provider_id).await.unwrap();
assert!(provider.access_token.is_none());
assert!(provider.synced_at.is_none());
}
#[tokio::test]
async fn test_provided_git_urls() {
let db = DbConn::new_in_memory().await.unwrap();
let service = create(db.clone());
let service = super::create(db.clone(), create_fake());
let provider_id = db
.create_gitlab_provider("provider1".into(), "token".into())
@ -405,56 +326,4 @@ mod tests {
["https://oauth2:token@gitlab.com/TabbyML/tabby".to_string()]
);
}
#[tokio::test]
async fn test_delete_outdated_repos() {
let db = DbConn::new_in_memory().await.unwrap();
let service = create(db.clone());
let time = Utc::now();
let provider_id = db
.create_gitlab_provider("provider1".into(), "secret1".into())
.await
.unwrap();
let _repo_id = db
.upsert_gitlab_provided_repository(
provider_id,
"vendor_id1".into(),
"test_repo".into(),
"https://gitlab.com/TabbyML/tabby".into(),
)
.await
.unwrap();
service
.delete_outdated_repositories(provider_id.as_id(), time)
.await
.unwrap();
assert_eq!(
1,
service
.list_repositories(vec![], None, None, None, None, None)
.await
.unwrap()
.len()
);
let time = time + Duration::minutes(1);
service
.delete_outdated_repositories(provider_id.as_id(), time)
.await
.unwrap();
assert_eq!(
0,
service
.list_repositories(vec![], None, None, None, None, None)
.await
.unwrap()
.len()
);
}
}

View File

@ -9,6 +9,7 @@ use juniper::ID;
use tabby_common::config::{RepositoryAccess, RepositoryConfig};
use tabby_db::DbConn;
use super::background_job::BackgroundJob;
use crate::schema::{
repository::{
FileEntrySearchResult, GitRepositoryService, GithubRepositoryService,
@ -23,11 +24,11 @@ struct RepositoryServiceImpl {
gitlab: Arc<dyn GitlabRepositoryService>,
}
pub fn create(db: DbConn) -> Arc<dyn RepositoryService> {
pub fn create(db: DbConn, background: Arc<dyn BackgroundJob>) -> Arc<dyn RepositoryService> {
Arc::new(RepositoryServiceImpl {
git: Arc::new(db.clone()),
github: Arc::new(github::create(db.clone())),
gitlab: Arc::new(gitlab::create(db.clone())),
github: Arc::new(github::create(db.clone(), background.clone())),
gitlab: Arc::new(gitlab::create(db, background)),
})
}
@ -129,11 +130,12 @@ mod tests {
use tabby_db::DbConn;
use super::*;
use crate::background_job::create_fake;
#[tokio::test]
async fn test_list_repositories() {
let db = DbConn::new_in_memory().await.unwrap();
let service = create(db.clone());
let service = create(db.clone(), create_fake());
service
.git()
.create("test_git_repo".into(), "http://test_git_repo".into())

View File

@ -6,7 +6,6 @@ files:
- ./ee/tabby-webserver/src/**
ignores:
- ./ee/tabby-webserver/src/service/**
- ./ee/tabby-webserver/src/cron/**
- ./ee/tabby-webserver/src/handler.rs
rule:
pattern: tabby_db