mirror of
https://github.com/TabbyML/tabby
synced 2024-11-21 07:50:13 +00:00
feat(webserver): add access_policy service (#3117)
This commit is contained in:
parent
1a21fbcbe4
commit
553fc8518b
@ -1,6 +1,7 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use sqlx::query;
|
||||
|
||||
use crate::DbConn;
|
||||
use crate::{DbConn, UserGroupDAO};
|
||||
|
||||
impl DbConn {
|
||||
pub async fn allow_read_source(&self, user_id: i64, source_id: &str) -> anyhow::Result<bool> {
|
||||
@ -76,4 +77,60 @@ DELETE FROM source_id_read_access_policies WHERE source_id = ? AND user_group_id
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_unused_source_id_read_access_policy(
|
||||
&self,
|
||||
active_source_ids: &[String],
|
||||
) -> anyhow::Result<usize> {
|
||||
let in_clause = active_source_ids
|
||||
.iter()
|
||||
.map(|s| format!("'{}'", s))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
|
||||
let rows_deleted = sqlx::query(&format!(
|
||||
"DELETE FROM source_id_read_access_policies WHERE source_id NOT IN ({in_clause})"
|
||||
))
|
||||
.execute(&self.pool)
|
||||
.await?
|
||||
.rows_affected();
|
||||
|
||||
Ok(rows_deleted as usize)
|
||||
}
|
||||
|
||||
pub async fn list_source_id_read_access_user_groups(
|
||||
&self,
|
||||
source_id: &str,
|
||||
) -> anyhow::Result<Vec<UserGroupDAO>> {
|
||||
let user_groups = sqlx::query_as!(
|
||||
UserGroupDAO,
|
||||
r#"SELECT
|
||||
user_groups.id as "id",
|
||||
name,
|
||||
user_groups.created_at as "created_at: DateTime<Utc>",
|
||||
user_groups.updated_at as "updated_at: DateTime<Utc>"
|
||||
FROM source_id_read_access_policies INNER JOIN user_groups ON (source_id_read_access_policies.user_group_id = user_groups.id)
|
||||
WHERE source_id = ?
|
||||
"#,
|
||||
source_id
|
||||
)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
Ok(user_groups)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::DbConn;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_delete_unused_source_id_read_access_policy() {
|
||||
let db = DbConn::new_in_memory().await.unwrap();
|
||||
let rows_deleted = db
|
||||
.delete_unused_source_id_read_access_policy(&["test1".into()])
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(rows_deleted, 0);
|
||||
}
|
||||
}
|
@ -24,6 +24,7 @@ pub use user_groups::{UserGroupDAO, UserGroupMembershipDAO};
|
||||
pub use users::UserDAO;
|
||||
pub use web_documents::WebDocumentDAO;
|
||||
|
||||
mod access_policy;
|
||||
pub mod cache;
|
||||
mod email_setting;
|
||||
mod integrations;
|
||||
@ -33,7 +34,6 @@ mod job_runs;
|
||||
mod migration_tests;
|
||||
mod oauth_credential;
|
||||
mod password_reset;
|
||||
mod policy;
|
||||
mod provided_repositories;
|
||||
mod refresh_tokens;
|
||||
mod repositories;
|
||||
|
@ -560,6 +560,8 @@ type Mutation {
|
||||
deleteUserGroup(id: ID!): Boolean!
|
||||
upsertUserGroupMembership(input: UpsertUserGroupMembershipInput!): Boolean!
|
||||
deleteUserGroupMembership(userGroupId: ID!, userId: ID!): Boolean!
|
||||
grantSourceIdReadAccess(sourceId: String!, userGroupId: ID!): Boolean!
|
||||
revokeSourceIdReadAccess(sourceId: String!, userGroupId: ID!): Boolean!
|
||||
}
|
||||
|
||||
type NetworkSetting {
|
||||
@ -682,6 +684,7 @@ type Query {
|
||||
When the requesting user is an admin, all user groups will be returned. Otherwise, they can only see groups they are a member of.
|
||||
"""
|
||||
userGroups: [UserGroup!]!
|
||||
sourceIdAccessPolicies(sourceId: String!): SourceIdAccessPolicy!
|
||||
}
|
||||
|
||||
type RefreshTokenResponse {
|
||||
@ -732,6 +735,11 @@ type ServerInfo {
|
||||
isDemoMode: Boolean!
|
||||
}
|
||||
|
||||
type SourceIdAccessPolicy {
|
||||
sourceId: String!
|
||||
read: [UserGroup!]!
|
||||
}
|
||||
|
||||
type Subscription {
|
||||
createThreadAndRun(input: CreateThreadAndRunInput!): ThreadRunItem!
|
||||
createThreadRun(input: CreateThreadRunInput!): ThreadRunItem!
|
||||
|
19
ee/tabby-schema/src/schema/access_policy.rs
Normal file
19
ee/tabby-schema/src/schema/access_policy.rs
Normal file
@ -0,0 +1,19 @@
|
||||
use async_trait::async_trait;
|
||||
use juniper::{GraphQLObject, ID};
|
||||
|
||||
use super::{user_group::UserGroup, Context, Result};
|
||||
|
||||
#[derive(GraphQLObject)]
|
||||
#[graphql(context = Context)]
|
||||
pub struct SourceIdAccessPolicy {
|
||||
pub source_id: String,
|
||||
pub read: Vec<UserGroup>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait AccessPolicyService: Sync + Send {
|
||||
async fn list_source_id_read_access(&self, source_id: &str) -> Result<Vec<UserGroup>>;
|
||||
async fn grant_source_id_read_access(&self, source_id: &str, user_group_id: &ID) -> Result<()>;
|
||||
async fn revoke_source_id_read_access(&self, source_id: &str, user_group_id: &ID)
|
||||
-> Result<()>;
|
||||
}
|
@ -1,3 +1,4 @@
|
||||
pub mod access_policy;
|
||||
pub mod analytic;
|
||||
pub mod auth;
|
||||
pub mod constants;
|
||||
@ -17,6 +18,7 @@ pub mod worker;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use access_policy::{AccessPolicyService, SourceIdAccessPolicy};
|
||||
use auth::{
|
||||
AuthenticationService, Invitation, RefreshTokenResponse, RegisterResponse, TokenAuthResponse,
|
||||
UserSecured,
|
||||
@ -83,6 +85,7 @@ pub trait ServiceLocator: Send + Sync {
|
||||
fn thread(&self) -> Arc<dyn ThreadService>;
|
||||
fn context(&self) -> Arc<dyn ContextService>;
|
||||
fn user_group(&self) -> Arc<dyn UserGroupService>;
|
||||
fn access_policy(&self) -> Arc<dyn AccessPolicyService>;
|
||||
}
|
||||
|
||||
pub struct Context {
|
||||
@ -658,6 +661,20 @@ impl Query {
|
||||
let user = check_user(ctx).await?;
|
||||
ctx.locator.user_group().list(&user.policy).await
|
||||
}
|
||||
|
||||
async fn source_id_access_policies(
|
||||
ctx: &Context,
|
||||
source_id: String,
|
||||
) -> Result<SourceIdAccessPolicy> {
|
||||
check_admin(ctx).await?;
|
||||
let read = ctx
|
||||
.locator
|
||||
.access_policy()
|
||||
.list_source_id_read_access(&source_id)
|
||||
.await?;
|
||||
|
||||
Ok(SourceIdAccessPolicy { source_id, read })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(GraphQLObject)]
|
||||
@ -1117,6 +1134,32 @@ impl Mutation {
|
||||
.await?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn grant_source_id_read_access(
|
||||
ctx: &Context,
|
||||
source_id: String,
|
||||
user_group_id: ID,
|
||||
) -> Result<bool> {
|
||||
check_admin(ctx).await?;
|
||||
ctx.locator
|
||||
.access_policy()
|
||||
.grant_source_id_read_access(&source_id, &user_group_id)
|
||||
.await?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
async fn revoke_source_id_read_access(
|
||||
ctx: &Context,
|
||||
source_id: String,
|
||||
user_group_id: ID,
|
||||
) -> Result<bool> {
|
||||
check_admin(ctx).await?;
|
||||
ctx.locator
|
||||
.access_policy()
|
||||
.revoke_source_id_read_access(&source_id, &user_group_id)
|
||||
.await?;
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
fn from_validation_errors<S: ScalarValue>(error: ValidationErrors) -> FieldError<S> {
|
||||
|
59
ee/tabby-webserver/src/service/access_policy.rs
Normal file
59
ee/tabby-webserver/src/service/access_policy.rs
Normal file
@ -0,0 +1,59 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use juniper::ID;
|
||||
use tabby_db::DbConn;
|
||||
use tabby_schema::{
|
||||
access_policy::AccessPolicyService, bail, context::ContextService, user_group::UserGroup,
|
||||
AsRowid, Result,
|
||||
};
|
||||
|
||||
use super::UserGroupExt;
|
||||
|
||||
struct AccessPolicyServiceImpl {
|
||||
db: DbConn,
|
||||
context: Arc<dyn ContextService>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl AccessPolicyService for AccessPolicyServiceImpl {
|
||||
async fn list_source_id_read_access(&self, source_id: &str) -> Result<Vec<UserGroup>> {
|
||||
let mut user_groups = Vec::new();
|
||||
for x in self
|
||||
.db
|
||||
.list_source_id_read_access_user_groups(source_id)
|
||||
.await?
|
||||
{
|
||||
user_groups.push(UserGroup::new(self.db.clone(), x).await?)
|
||||
}
|
||||
|
||||
Ok(user_groups)
|
||||
}
|
||||
|
||||
async fn grant_source_id_read_access(&self, source_id: &str, user_group_id: &ID) -> Result<()> {
|
||||
let context_info = self.context.read(None).await?;
|
||||
let helper = context_info.helper();
|
||||
if !helper.can_access_source_id(source_id) {
|
||||
bail!("source_id {} not found", source_id)
|
||||
}
|
||||
|
||||
self.db
|
||||
.upsert_source_id_read_access_policy(source_id, user_group_id.as_rowid()?)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn revoke_source_id_read_access(
|
||||
&self,
|
||||
source_id: &str,
|
||||
user_group_id: &ID,
|
||||
) -> Result<()> {
|
||||
self.db
|
||||
.delete_source_id_read_access_policy(source_id, user_group_id.as_rowid()?)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create(db: DbConn, context: Arc<dyn ContextService>) -> impl AccessPolicyService {
|
||||
AccessPolicyServiceImpl { db, context }
|
||||
}
|
@ -1,6 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tabby_db::DbConn;
|
||||
use tabby_schema::context::ContextService;
|
||||
|
||||
use super::helper::Job;
|
||||
|
||||
@ -12,12 +15,26 @@ impl Job for DbMaintainanceJob {
|
||||
}
|
||||
|
||||
impl DbMaintainanceJob {
|
||||
pub async fn cron(_now: DateTime<Utc>, db: DbConn) -> tabby_schema::Result<()> {
|
||||
pub async fn cron(
|
||||
_now: DateTime<Utc>,
|
||||
context: Arc<dyn ContextService>,
|
||||
db: DbConn,
|
||||
) -> tabby_schema::Result<()> {
|
||||
db.delete_expired_token().await?;
|
||||
db.delete_expired_password_resets().await?;
|
||||
db.delete_expired_ephemeral_threads().await?;
|
||||
|
||||
// FIXME(meng): add maintainance job for source_id_read_access_policies
|
||||
// Read all active sources
|
||||
let active_source_ids = context
|
||||
.read(None)
|
||||
.await?
|
||||
.sources
|
||||
.into_iter()
|
||||
.map(|source| source.source_id)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
db.delete_unused_source_id_read_access_policy(&active_source_ids)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -120,7 +120,7 @@ pub async fn start(
|
||||
debug!("Background job {} completed", job.id);
|
||||
},
|
||||
Some(now) = hourly.next() => {
|
||||
if let Err(err) = DbMaintainanceJob::cron(now, db.clone()).await {
|
||||
if let Err(err) = DbMaintainanceJob::cron(now, context_service.clone(), db.clone()).await {
|
||||
warn!("Database maintainance failed: {:?}", err);
|
||||
}
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
mod access_policy;
|
||||
mod analytic;
|
||||
pub mod answer;
|
||||
mod auth;
|
||||
@ -19,6 +20,7 @@ pub mod web_documents;
|
||||
use std::sync::Arc;
|
||||
|
||||
use answer::AnswerService;
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
@ -32,14 +34,16 @@ use tabby_common::{
|
||||
api::{code::CodeSearch, event::EventLogger},
|
||||
constants::USER_HEADER_FIELD_NAME,
|
||||
};
|
||||
use tabby_db::{DbConn, UserDAO};
|
||||
use tabby_db::{DbConn, UserDAO, UserGroupDAO};
|
||||
use tabby_inference::Embedding;
|
||||
use tabby_schema::{
|
||||
access_policy::AccessPolicyService,
|
||||
analytic::AnalyticService,
|
||||
auth::AuthenticationService,
|
||||
auth::{AuthenticationService, UserSecured},
|
||||
context::ContextService,
|
||||
email::EmailService,
|
||||
integration::IntegrationService,
|
||||
interface::UserValue,
|
||||
is_demo_mode,
|
||||
job::JobService,
|
||||
license::{IsLicenseValid, LicenseService},
|
||||
@ -48,7 +52,7 @@ use tabby_schema::{
|
||||
setting::SettingService,
|
||||
thread::ThreadService,
|
||||
user_event::UserEventService,
|
||||
user_group::UserGroupService,
|
||||
user_group::{UserGroup, UserGroupMembership, UserGroupService},
|
||||
web_documents::WebDocumentService,
|
||||
worker::WorkerService,
|
||||
AsID, AsRowid, CoreError, Result, ServiceLocator,
|
||||
@ -70,6 +74,7 @@ struct ServerContext {
|
||||
thread: Arc<dyn ThreadService>,
|
||||
context: Arc<dyn ContextService>,
|
||||
user_group: Arc<dyn UserGroupService>,
|
||||
access_policy: Arc<dyn AccessPolicyService>,
|
||||
|
||||
logger: Arc<dyn EventLogger>,
|
||||
code: Arc<dyn CodeSearch>,
|
||||
@ -107,6 +112,7 @@ impl ServerContext {
|
||||
let setting = Arc::new(setting::create(db_conn.clone()));
|
||||
let thread = Arc::new(thread::create(db_conn.clone(), answer.clone()));
|
||||
let user_group = Arc::new(user_group::create(db_conn.clone()));
|
||||
let access_policy = Arc::new(access_policy::create(db_conn.clone(), context.clone()));
|
||||
|
||||
background_job::start(
|
||||
db_conn.clone(),
|
||||
@ -140,6 +146,7 @@ impl ServerContext {
|
||||
code,
|
||||
setting,
|
||||
user_group,
|
||||
access_policy,
|
||||
db_conn,
|
||||
is_chat_enabled_locally,
|
||||
}
|
||||
@ -297,6 +304,10 @@ impl ServiceLocator for ArcServerContext {
|
||||
fn user_group(&self) -> Arc<dyn UserGroupService> {
|
||||
self.0.user_group.clone()
|
||||
}
|
||||
|
||||
fn access_policy(&self) -> Arc<dyn AccessPolicyService> {
|
||||
self.0.access_policy.clone()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_service_locator(
|
||||
@ -391,3 +402,36 @@ impl UserSecuredExt for tabby_schema::auth::UserSecured {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
trait UserGroupExt {
|
||||
async fn new(db: DbConn, val: UserGroupDAO) -> Result<UserGroup>;
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl UserGroupExt for UserGroup {
|
||||
async fn new(db: DbConn, val: UserGroupDAO) -> Result<UserGroup> {
|
||||
let mut members = Vec::new();
|
||||
for x in db.list_user_group_memberships(val.id, None).await? {
|
||||
members.push(UserGroupMembership {
|
||||
is_group_admin: x.is_group_admin,
|
||||
created_at: x.created_at,
|
||||
updated_at: x.updated_at,
|
||||
user: UserValue::UserSecured(UserSecured::new(
|
||||
db.clone(),
|
||||
db.get_user(x.user_id)
|
||||
.await?
|
||||
.context("User doesn't exists")?,
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(UserGroup {
|
||||
id: val.id.as_id(),
|
||||
name: val.name,
|
||||
created_at: val.created_at,
|
||||
updated_at: val.updated_at,
|
||||
members,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -1,18 +1,14 @@
|
||||
use anyhow::Context;
|
||||
use juniper::ID;
|
||||
use tabby_db::DbConn;
|
||||
use tabby_schema::{
|
||||
auth::UserSecured,
|
||||
interface::UserValue,
|
||||
policy::AccessPolicy,
|
||||
user_group::{
|
||||
CreateUserGroupInput, UpsertUserGroupMembershipInput, UserGroup, UserGroupMembership,
|
||||
UserGroupService,
|
||||
CreateUserGroupInput, UpsertUserGroupMembershipInput, UserGroup, UserGroupService,
|
||||
},
|
||||
AsID, AsRowid, Result,
|
||||
};
|
||||
|
||||
use super::UserSecuredExt;
|
||||
use super::{UserGroupExt, UserSecuredExt};
|
||||
|
||||
struct UserGroupServiceImpl {
|
||||
db: DbConn,
|
||||
@ -28,29 +24,7 @@ impl UserGroupService for UserGroupServiceImpl {
|
||||
|
||||
let mut user_groups = Vec::new();
|
||||
for x in self.db.list_user_groups(user_id).await? {
|
||||
let mut members = Vec::new();
|
||||
for x in self.db.list_user_group_memberships(x.id, None).await? {
|
||||
members.push(UserGroupMembership {
|
||||
is_group_admin: x.is_group_admin,
|
||||
created_at: x.created_at,
|
||||
updated_at: x.updated_at,
|
||||
user: UserValue::UserSecured(UserSecured::new(
|
||||
self.db.clone(),
|
||||
self.db
|
||||
.get_user(x.user_id)
|
||||
.await?
|
||||
.context("User doesn't exists")?,
|
||||
)),
|
||||
});
|
||||
}
|
||||
|
||||
user_groups.push(UserGroup {
|
||||
id: x.id.as_id(),
|
||||
name: x.name,
|
||||
created_at: x.created_at,
|
||||
updated_at: x.updated_at,
|
||||
members,
|
||||
});
|
||||
user_groups.push(UserGroup::new(self.db.clone(), x).await?);
|
||||
}
|
||||
Ok(user_groups)
|
||||
}
|
||||
@ -91,6 +65,7 @@ pub fn create(db: DbConn) -> impl UserGroupService {
|
||||
mod tests {
|
||||
use assert_matches::assert_matches;
|
||||
use tabby_db::testutils;
|
||||
use tabby_schema::interface::UserValue;
|
||||
|
||||
use super::*;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user