feat(webserver): add access_policy service (#3117)

This commit is contained in:
Meng Zhang 2024-09-10 21:10:17 -07:00 committed by GitHub
parent 1a21fbcbe4
commit 553fc8518b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 259 additions and 37 deletions

View File

@ -1,6 +1,7 @@
use chrono::{DateTime, Utc};
use sqlx::query; use sqlx::query;
use crate::DbConn; use crate::{DbConn, UserGroupDAO};
impl DbConn { impl DbConn {
pub async fn allow_read_source(&self, user_id: i64, source_id: &str) -> anyhow::Result<bool> { pub async fn allow_read_source(&self, user_id: i64, source_id: &str) -> anyhow::Result<bool> {
@ -76,4 +77,60 @@ DELETE FROM source_id_read_access_policies WHERE source_id = ? AND user_group_id
)) ))
} }
} }
pub async fn delete_unused_source_id_read_access_policy(
&self,
active_source_ids: &[String],
) -> anyhow::Result<usize> {
let in_clause = active_source_ids
.iter()
.map(|s| format!("'{}'", s))
.collect::<Vec<_>>()
.join(",");
let rows_deleted = sqlx::query(&format!(
"DELETE FROM source_id_read_access_policies WHERE source_id NOT IN ({in_clause})"
))
.execute(&self.pool)
.await?
.rows_affected();
Ok(rows_deleted as usize)
}
pub async fn list_source_id_read_access_user_groups(
&self,
source_id: &str,
) -> anyhow::Result<Vec<UserGroupDAO>> {
let user_groups = sqlx::query_as!(
UserGroupDAO,
r#"SELECT
user_groups.id as "id",
name,
user_groups.created_at as "created_at: DateTime<Utc>",
user_groups.updated_at as "updated_at: DateTime<Utc>"
FROM source_id_read_access_policies INNER JOIN user_groups ON (source_id_read_access_policies.user_group_id = user_groups.id)
WHERE source_id = ?
"#,
source_id
)
.fetch_all(&self.pool)
.await?;
Ok(user_groups)
}
}
#[cfg(test)]
mod tests {
use crate::DbConn;
#[tokio::test]
async fn test_delete_unused_source_id_read_access_policy() {
let db = DbConn::new_in_memory().await.unwrap();
let rows_deleted = db
.delete_unused_source_id_read_access_policy(&["test1".into()])
.await
.unwrap();
assert_eq!(rows_deleted, 0);
}
} }

View File

@ -24,6 +24,7 @@ pub use user_groups::{UserGroupDAO, UserGroupMembershipDAO};
pub use users::UserDAO; pub use users::UserDAO;
pub use web_documents::WebDocumentDAO; pub use web_documents::WebDocumentDAO;
mod access_policy;
pub mod cache; pub mod cache;
mod email_setting; mod email_setting;
mod integrations; mod integrations;
@ -33,7 +34,6 @@ mod job_runs;
mod migration_tests; mod migration_tests;
mod oauth_credential; mod oauth_credential;
mod password_reset; mod password_reset;
mod policy;
mod provided_repositories; mod provided_repositories;
mod refresh_tokens; mod refresh_tokens;
mod repositories; mod repositories;

View File

@ -560,6 +560,8 @@ type Mutation {
deleteUserGroup(id: ID!): Boolean! deleteUserGroup(id: ID!): Boolean!
upsertUserGroupMembership(input: UpsertUserGroupMembershipInput!): Boolean! upsertUserGroupMembership(input: UpsertUserGroupMembershipInput!): Boolean!
deleteUserGroupMembership(userGroupId: ID!, userId: ID!): Boolean! deleteUserGroupMembership(userGroupId: ID!, userId: ID!): Boolean!
grantSourceIdReadAccess(sourceId: String!, userGroupId: ID!): Boolean!
revokeSourceIdReadAccess(sourceId: String!, userGroupId: ID!): Boolean!
} }
type NetworkSetting { type NetworkSetting {
@ -682,6 +684,7 @@ type Query {
When the requesting user is an admin, all user groups will be returned. Otherwise, they can only see groups they are a member of. When the requesting user is an admin, all user groups will be returned. Otherwise, they can only see groups they are a member of.
""" """
userGroups: [UserGroup!]! userGroups: [UserGroup!]!
sourceIdAccessPolicies(sourceId: String!): SourceIdAccessPolicy!
} }
type RefreshTokenResponse { type RefreshTokenResponse {
@ -732,6 +735,11 @@ type ServerInfo {
isDemoMode: Boolean! isDemoMode: Boolean!
} }
type SourceIdAccessPolicy {
sourceId: String!
read: [UserGroup!]!
}
type Subscription { type Subscription {
createThreadAndRun(input: CreateThreadAndRunInput!): ThreadRunItem! createThreadAndRun(input: CreateThreadAndRunInput!): ThreadRunItem!
createThreadRun(input: CreateThreadRunInput!): ThreadRunItem! createThreadRun(input: CreateThreadRunInput!): ThreadRunItem!

View File

@ -0,0 +1,19 @@
use async_trait::async_trait;
use juniper::{GraphQLObject, ID};
use super::{user_group::UserGroup, Context, Result};
#[derive(GraphQLObject)]
#[graphql(context = Context)]
pub struct SourceIdAccessPolicy {
pub source_id: String,
pub read: Vec<UserGroup>,
}
#[async_trait]
pub trait AccessPolicyService: Sync + Send {
async fn list_source_id_read_access(&self, source_id: &str) -> Result<Vec<UserGroup>>;
async fn grant_source_id_read_access(&self, source_id: &str, user_group_id: &ID) -> Result<()>;
async fn revoke_source_id_read_access(&self, source_id: &str, user_group_id: &ID)
-> Result<()>;
}

View File

@ -1,3 +1,4 @@
pub mod access_policy;
pub mod analytic; pub mod analytic;
pub mod auth; pub mod auth;
pub mod constants; pub mod constants;
@ -17,6 +18,7 @@ pub mod worker;
use std::sync::Arc; use std::sync::Arc;
use access_policy::{AccessPolicyService, SourceIdAccessPolicy};
use auth::{ use auth::{
AuthenticationService, Invitation, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, AuthenticationService, Invitation, RefreshTokenResponse, RegisterResponse, TokenAuthResponse,
UserSecured, UserSecured,
@ -83,6 +85,7 @@ pub trait ServiceLocator: Send + Sync {
fn thread(&self) -> Arc<dyn ThreadService>; fn thread(&self) -> Arc<dyn ThreadService>;
fn context(&self) -> Arc<dyn ContextService>; fn context(&self) -> Arc<dyn ContextService>;
fn user_group(&self) -> Arc<dyn UserGroupService>; fn user_group(&self) -> Arc<dyn UserGroupService>;
fn access_policy(&self) -> Arc<dyn AccessPolicyService>;
} }
pub struct Context { pub struct Context {
@ -658,6 +661,20 @@ impl Query {
let user = check_user(ctx).await?; let user = check_user(ctx).await?;
ctx.locator.user_group().list(&user.policy).await ctx.locator.user_group().list(&user.policy).await
} }
async fn source_id_access_policies(
ctx: &Context,
source_id: String,
) -> Result<SourceIdAccessPolicy> {
check_admin(ctx).await?;
let read = ctx
.locator
.access_policy()
.list_source_id_read_access(&source_id)
.await?;
Ok(SourceIdAccessPolicy { source_id, read })
}
} }
#[derive(GraphQLObject)] #[derive(GraphQLObject)]
@ -1117,6 +1134,32 @@ impl Mutation {
.await?; .await?;
Ok(true) Ok(true)
} }
async fn grant_source_id_read_access(
ctx: &Context,
source_id: String,
user_group_id: ID,
) -> Result<bool> {
check_admin(ctx).await?;
ctx.locator
.access_policy()
.grant_source_id_read_access(&source_id, &user_group_id)
.await?;
Ok(true)
}
async fn revoke_source_id_read_access(
ctx: &Context,
source_id: String,
user_group_id: ID,
) -> Result<bool> {
check_admin(ctx).await?;
ctx.locator
.access_policy()
.revoke_source_id_read_access(&source_id, &user_group_id)
.await?;
Ok(true)
}
} }
fn from_validation_errors<S: ScalarValue>(error: ValidationErrors) -> FieldError<S> { fn from_validation_errors<S: ScalarValue>(error: ValidationErrors) -> FieldError<S> {

View File

@ -0,0 +1,59 @@
use std::sync::Arc;
use juniper::ID;
use tabby_db::DbConn;
use tabby_schema::{
access_policy::AccessPolicyService, bail, context::ContextService, user_group::UserGroup,
AsRowid, Result,
};
use super::UserGroupExt;
struct AccessPolicyServiceImpl {
db: DbConn,
context: Arc<dyn ContextService>,
}
#[async_trait::async_trait]
impl AccessPolicyService for AccessPolicyServiceImpl {
async fn list_source_id_read_access(&self, source_id: &str) -> Result<Vec<UserGroup>> {
let mut user_groups = Vec::new();
for x in self
.db
.list_source_id_read_access_user_groups(source_id)
.await?
{
user_groups.push(UserGroup::new(self.db.clone(), x).await?)
}
Ok(user_groups)
}
async fn grant_source_id_read_access(&self, source_id: &str, user_group_id: &ID) -> Result<()> {
let context_info = self.context.read(None).await?;
let helper = context_info.helper();
if !helper.can_access_source_id(source_id) {
bail!("source_id {} not found", source_id)
}
self.db
.upsert_source_id_read_access_policy(source_id, user_group_id.as_rowid()?)
.await?;
Ok(())
}
async fn revoke_source_id_read_access(
&self,
source_id: &str,
user_group_id: &ID,
) -> Result<()> {
self.db
.delete_source_id_read_access_policy(source_id, user_group_id.as_rowid()?)
.await?;
Ok(())
}
}
pub fn create(db: DbConn, context: Arc<dyn ContextService>) -> impl AccessPolicyService {
AccessPolicyServiceImpl { db, context }
}

View File

@ -1,6 +1,9 @@
use std::sync::Arc;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tabby_db::DbConn; use tabby_db::DbConn;
use tabby_schema::context::ContextService;
use super::helper::Job; use super::helper::Job;
@ -12,12 +15,26 @@ impl Job for DbMaintainanceJob {
} }
impl DbMaintainanceJob { impl DbMaintainanceJob {
pub async fn cron(_now: DateTime<Utc>, db: DbConn) -> tabby_schema::Result<()> { pub async fn cron(
_now: DateTime<Utc>,
context: Arc<dyn ContextService>,
db: DbConn,
) -> tabby_schema::Result<()> {
db.delete_expired_token().await?; db.delete_expired_token().await?;
db.delete_expired_password_resets().await?; db.delete_expired_password_resets().await?;
db.delete_expired_ephemeral_threads().await?; db.delete_expired_ephemeral_threads().await?;
// FIXME(meng): add maintainance job for source_id_read_access_policies // Read all active sources
let active_source_ids = context
.read(None)
.await?
.sources
.into_iter()
.map(|source| source.source_id)
.collect::<Vec<_>>();
db.delete_unused_source_id_read_access_policy(&active_source_ids)
.await?;
Ok(()) Ok(())
} }
} }

View File

@ -120,7 +120,7 @@ pub async fn start(
debug!("Background job {} completed", job.id); debug!("Background job {} completed", job.id);
}, },
Some(now) = hourly.next() => { Some(now) = hourly.next() => {
if let Err(err) = DbMaintainanceJob::cron(now, db.clone()).await { if let Err(err) = DbMaintainanceJob::cron(now, context_service.clone(), db.clone()).await {
warn!("Database maintainance failed: {:?}", err); warn!("Database maintainance failed: {:?}", err);
} }

View File

@ -1,3 +1,4 @@
mod access_policy;
mod analytic; mod analytic;
pub mod answer; pub mod answer;
mod auth; mod auth;
@ -19,6 +20,7 @@ pub mod web_documents;
use std::sync::Arc; use std::sync::Arc;
use answer::AnswerService; use answer::AnswerService;
use anyhow::Context;
use async_trait::async_trait; use async_trait::async_trait;
use axum::{ use axum::{
body::Body, body::Body,
@ -32,14 +34,16 @@ use tabby_common::{
api::{code::CodeSearch, event::EventLogger}, api::{code::CodeSearch, event::EventLogger},
constants::USER_HEADER_FIELD_NAME, constants::USER_HEADER_FIELD_NAME,
}; };
use tabby_db::{DbConn, UserDAO}; use tabby_db::{DbConn, UserDAO, UserGroupDAO};
use tabby_inference::Embedding; use tabby_inference::Embedding;
use tabby_schema::{ use tabby_schema::{
access_policy::AccessPolicyService,
analytic::AnalyticService, analytic::AnalyticService,
auth::AuthenticationService, auth::{AuthenticationService, UserSecured},
context::ContextService, context::ContextService,
email::EmailService, email::EmailService,
integration::IntegrationService, integration::IntegrationService,
interface::UserValue,
is_demo_mode, is_demo_mode,
job::JobService, job::JobService,
license::{IsLicenseValid, LicenseService}, license::{IsLicenseValid, LicenseService},
@ -48,7 +52,7 @@ use tabby_schema::{
setting::SettingService, setting::SettingService,
thread::ThreadService, thread::ThreadService,
user_event::UserEventService, user_event::UserEventService,
user_group::UserGroupService, user_group::{UserGroup, UserGroupMembership, UserGroupService},
web_documents::WebDocumentService, web_documents::WebDocumentService,
worker::WorkerService, worker::WorkerService,
AsID, AsRowid, CoreError, Result, ServiceLocator, AsID, AsRowid, CoreError, Result, ServiceLocator,
@ -70,6 +74,7 @@ struct ServerContext {
thread: Arc<dyn ThreadService>, thread: Arc<dyn ThreadService>,
context: Arc<dyn ContextService>, context: Arc<dyn ContextService>,
user_group: Arc<dyn UserGroupService>, user_group: Arc<dyn UserGroupService>,
access_policy: Arc<dyn AccessPolicyService>,
logger: Arc<dyn EventLogger>, logger: Arc<dyn EventLogger>,
code: Arc<dyn CodeSearch>, code: Arc<dyn CodeSearch>,
@ -107,6 +112,7 @@ impl ServerContext {
let setting = Arc::new(setting::create(db_conn.clone())); let setting = Arc::new(setting::create(db_conn.clone()));
let thread = Arc::new(thread::create(db_conn.clone(), answer.clone())); let thread = Arc::new(thread::create(db_conn.clone(), answer.clone()));
let user_group = Arc::new(user_group::create(db_conn.clone())); let user_group = Arc::new(user_group::create(db_conn.clone()));
let access_policy = Arc::new(access_policy::create(db_conn.clone(), context.clone()));
background_job::start( background_job::start(
db_conn.clone(), db_conn.clone(),
@ -140,6 +146,7 @@ impl ServerContext {
code, code,
setting, setting,
user_group, user_group,
access_policy,
db_conn, db_conn,
is_chat_enabled_locally, is_chat_enabled_locally,
} }
@ -297,6 +304,10 @@ impl ServiceLocator for ArcServerContext {
fn user_group(&self) -> Arc<dyn UserGroupService> { fn user_group(&self) -> Arc<dyn UserGroupService> {
self.0.user_group.clone() self.0.user_group.clone()
} }
fn access_policy(&self) -> Arc<dyn AccessPolicyService> {
self.0.access_policy.clone()
}
} }
pub async fn create_service_locator( pub async fn create_service_locator(
@ -391,3 +402,36 @@ impl UserSecuredExt for tabby_schema::auth::UserSecured {
} }
} }
} }
#[async_trait::async_trait]
trait UserGroupExt {
async fn new(db: DbConn, val: UserGroupDAO) -> Result<UserGroup>;
}
#[async_trait::async_trait]
impl UserGroupExt for UserGroup {
async fn new(db: DbConn, val: UserGroupDAO) -> Result<UserGroup> {
let mut members = Vec::new();
for x in db.list_user_group_memberships(val.id, None).await? {
members.push(UserGroupMembership {
is_group_admin: x.is_group_admin,
created_at: x.created_at,
updated_at: x.updated_at,
user: UserValue::UserSecured(UserSecured::new(
db.clone(),
db.get_user(x.user_id)
.await?
.context("User doesn't exists")?,
)),
});
}
Ok(UserGroup {
id: val.id.as_id(),
name: val.name,
created_at: val.created_at,
updated_at: val.updated_at,
members,
})
}
}

View File

@ -1,18 +1,14 @@
use anyhow::Context;
use juniper::ID; use juniper::ID;
use tabby_db::DbConn; use tabby_db::DbConn;
use tabby_schema::{ use tabby_schema::{
auth::UserSecured,
interface::UserValue,
policy::AccessPolicy, policy::AccessPolicy,
user_group::{ user_group::{
CreateUserGroupInput, UpsertUserGroupMembershipInput, UserGroup, UserGroupMembership, CreateUserGroupInput, UpsertUserGroupMembershipInput, UserGroup, UserGroupService,
UserGroupService,
}, },
AsID, AsRowid, Result, AsID, AsRowid, Result,
}; };
use super::UserSecuredExt; use super::{UserGroupExt, UserSecuredExt};
struct UserGroupServiceImpl { struct UserGroupServiceImpl {
db: DbConn, db: DbConn,
@ -28,29 +24,7 @@ impl UserGroupService for UserGroupServiceImpl {
let mut user_groups = Vec::new(); let mut user_groups = Vec::new();
for x in self.db.list_user_groups(user_id).await? { for x in self.db.list_user_groups(user_id).await? {
let mut members = Vec::new(); user_groups.push(UserGroup::new(self.db.clone(), x).await?);
for x in self.db.list_user_group_memberships(x.id, None).await? {
members.push(UserGroupMembership {
is_group_admin: x.is_group_admin,
created_at: x.created_at,
updated_at: x.updated_at,
user: UserValue::UserSecured(UserSecured::new(
self.db.clone(),
self.db
.get_user(x.user_id)
.await?
.context("User doesn't exists")?,
)),
});
}
user_groups.push(UserGroup {
id: x.id.as_id(),
name: x.name,
created_at: x.created_at,
updated_at: x.updated_at,
members,
});
} }
Ok(user_groups) Ok(user_groups)
} }
@ -91,6 +65,7 @@ pub fn create(db: DbConn) -> impl UserGroupService {
mod tests { mod tests {
use assert_matches::assert_matches; use assert_matches::assert_matches;
use tabby_db::testutils; use tabby_db::testutils;
use tabby_schema::interface::UserValue;
use super::*; use super::*;