mirror of
https://github.com/TabbyML/tabby
synced 2024-11-23 10:05:08 +00:00
feat(webserver): Add graphql api for oauth credential management (#1177)
* feat(webserver): graphql api for oauth management * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * resolve comment * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
356d1b0751
commit
ef7674c29d
@ -31,31 +31,52 @@ impl DbConn {
|
||||
pub async fn update_github_oauth_credential(
|
||||
&self,
|
||||
client_id: &str,
|
||||
client_secret: &str,
|
||||
client_secret: Option<&str>,
|
||||
active: bool,
|
||||
) -> Result<()> {
|
||||
let client_id = client_id.to_string();
|
||||
let client_secret = client_secret.to_string();
|
||||
|
||||
self.conn
|
||||
.call(move |c| {
|
||||
let mut stmt = c.prepare(
|
||||
r#"INSERT INTO github_oauth_credential (id, client_id, client_secret)
|
||||
VALUES (:id, :cid, :secret) ON CONFLICT(id) DO UPDATE
|
||||
SET client_id = :cid, client_secret = :secret, active = :active, updated_at = datetime('now')
|
||||
WHERE id = :id"#,
|
||||
)?;
|
||||
stmt.insert(named_params! {
|
||||
if let Some(client_secret) = client_secret {
|
||||
let client_secret = client_secret.to_string();
|
||||
let sql = r#"INSERT INTO github_oauth_credential (id, client_id, client_secret, active)
|
||||
VALUES (:id, :cid, :secret, :active) ON CONFLICT(id) DO UPDATE
|
||||
SET client_id = :cid, client_secret = :secret, active = :active, updated_at = datetime('now')
|
||||
WHERE id = :id"#;
|
||||
self.conn
|
||||
.call(move |c| {
|
||||
let mut stmt = c.prepare(sql)?;
|
||||
stmt.insert(named_params! {
|
||||
":id": GITHUB_OAUTH_CREDENTIAL_ROW_ID,
|
||||
":cid": client_id,
|
||||
":secret": client_secret,
|
||||
":active": active,
|
||||
})?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
})?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
Ok(())
|
||||
} else {
|
||||
let sql = r#"
|
||||
UPDATE github_oauth_credential SET client_id = :cid, active = :active, updated_at = datetime('now')
|
||||
WHERE id = :id"#;
|
||||
let rows = self
|
||||
.conn
|
||||
.call(move |c| {
|
||||
let mut stmt = c.prepare(sql)?;
|
||||
let rows = stmt.execute(named_params! {
|
||||
":id": GITHUB_OAUTH_CREDENTIAL_ROW_ID,
|
||||
":cid": client_id,
|
||||
":active": active,
|
||||
})?;
|
||||
Ok(rows)
|
||||
})
|
||||
.await?;
|
||||
if rows != 1 {
|
||||
return Err(anyhow::anyhow!(
|
||||
"failed to update: github credential not found"
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_github_oauth_credential(&self) -> Result<Option<GithubOAuthCredentialDAO>> {
|
||||
@ -82,9 +103,16 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_update_github_oauth_credential() {
|
||||
// test insert
|
||||
let conn = DbConn::new_in_memory().await.unwrap();
|
||||
conn.update_github_oauth_credential("client_id", "client_secret", false)
|
||||
|
||||
// test update failure when no record exists
|
||||
let res = conn
|
||||
.update_github_oauth_credential("client_id", None, false)
|
||||
.await;
|
||||
assert!(res.is_err());
|
||||
|
||||
// test insert
|
||||
conn.update_github_oauth_credential("client_id", Some("client_secret"), true)
|
||||
.await
|
||||
.unwrap();
|
||||
let res = conn.read_github_oauth_credential().await.unwrap().unwrap();
|
||||
@ -93,12 +121,21 @@ mod tests {
|
||||
assert!(res.active);
|
||||
|
||||
// test update
|
||||
conn.update_github_oauth_credential("client_id", "client_secret_2", false)
|
||||
conn.update_github_oauth_credential("client_id", Some("client_secret_2"), false)
|
||||
.await
|
||||
.unwrap();
|
||||
let res = conn.read_github_oauth_credential().await.unwrap().unwrap();
|
||||
assert_eq!(res.client_id, "client_id");
|
||||
assert_eq!(res.client_secret, "client_secret_2");
|
||||
assert!(!res.active);
|
||||
|
||||
// test update without client_secret
|
||||
conn.update_github_oauth_credential("client_id_2", None, true)
|
||||
.await
|
||||
.unwrap();
|
||||
let res = conn.read_github_oauth_credential().await.unwrap().unwrap();
|
||||
assert_eq!(res.client_id, "client_id_2");
|
||||
assert_eq!(res.client_secret, "client_secret_2");
|
||||
assert!(res.active);
|
||||
}
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ type Mutation {
|
||||
createInvitation(email: String!): ID!
|
||||
deleteInvitation(id: Int!): Int! @deprecated
|
||||
deleteInvitationNext(id: ID!): ID!
|
||||
updateOauthCredential(provider: OAuthProvider!, clientId: String!, clientSecret: String, active: Boolean!): Boolean!
|
||||
}
|
||||
|
||||
"DateTime"
|
||||
@ -49,6 +50,7 @@ type Query {
|
||||
usersNext(after: String, before: String, first: Int, last: Int): UserConnection!
|
||||
invitationsNext(after: String, before: String, first: Int, last: Int): InvitationConnection!
|
||||
jobRuns(after: String, before: String, first: Int, last: Int): JobRunConnection!
|
||||
oauthCredential(provider: OAuthProvider!): OAuthCredential
|
||||
}
|
||||
|
||||
type UserEdge {
|
||||
@ -84,6 +86,14 @@ type UserConnection {
|
||||
pageInfo: PageInfo!
|
||||
}
|
||||
|
||||
type OAuthCredential {
|
||||
provider: OAuthProvider!
|
||||
clientId: String!
|
||||
active: Boolean!
|
||||
createdAt: DateTimeUtc!
|
||||
updatedAt: DateTimeUtc!
|
||||
}
|
||||
|
||||
type VerifyTokenResponse {
|
||||
claims: JWTPayload!
|
||||
}
|
||||
@ -103,6 +113,11 @@ type User {
|
||||
createdAt: DateTimeUtc!
|
||||
}
|
||||
|
||||
type TokenAuthResponse {
|
||||
accessToken: String!
|
||||
refreshToken: String!
|
||||
}
|
||||
|
||||
type Worker {
|
||||
kind: WorkerKind!
|
||||
name: String!
|
||||
@ -119,9 +134,8 @@ type InvitationEdge {
|
||||
cursor: String!
|
||||
}
|
||||
|
||||
type TokenAuthResponse {
|
||||
accessToken: String!
|
||||
refreshToken: String!
|
||||
enum OAuthProvider {
|
||||
GITHUB
|
||||
}
|
||||
|
||||
type PageInfo {
|
||||
|
@ -4,7 +4,7 @@ use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use jsonwebtoken as jwt;
|
||||
use juniper::{FieldError, GraphQLObject, IntoFieldError, ScalarValue, ID};
|
||||
use juniper::{FieldError, GraphQLEnum, GraphQLObject, IntoFieldError, ScalarValue, ID};
|
||||
use juniper_axum::relay;
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@ -334,6 +334,21 @@ impl relay::NodeType for InvitationNext {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(GraphQLEnum, Clone)]
|
||||
#[non_exhaustive]
|
||||
pub enum OAuthProvider {
|
||||
Github,
|
||||
}
|
||||
|
||||
#[derive(GraphQLObject)]
|
||||
pub struct OAuthCredential {
|
||||
pub provider: OAuthProvider,
|
||||
pub client_id: String,
|
||||
pub active: bool,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait AuthenticationService: Send + Sync {
|
||||
async fn register(
|
||||
@ -384,6 +399,19 @@ pub trait AuthenticationService: Send + Sync {
|
||||
code: String,
|
||||
client: Arc<GithubClient>,
|
||||
) -> std::result::Result<GithubAuthResponse, GithubAuthError>;
|
||||
|
||||
async fn read_oauth_credential(
|
||||
&self,
|
||||
provider: OAuthProvider,
|
||||
) -> Result<Option<OAuthCredential>>;
|
||||
|
||||
async fn update_oauth_credential(
|
||||
&self,
|
||||
provider: OAuthProvider,
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
active: bool,
|
||||
) -> Result<()>;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -1,6 +1,10 @@
|
||||
use tabby_db::{InvitationDAO, JobRunDAO, UserDAO};
|
||||
use tabby_db::{GithubOAuthCredentialDAO, InvitationDAO, JobRunDAO, UserDAO};
|
||||
|
||||
use crate::schema::{auth, job};
|
||||
use crate::schema::{
|
||||
auth,
|
||||
auth::{OAuthCredential, OAuthProvider},
|
||||
job,
|
||||
};
|
||||
|
||||
impl From<InvitationDAO> for auth::InvitationNext {
|
||||
fn from(val: InvitationDAO) -> Self {
|
||||
@ -38,3 +42,15 @@ impl From<UserDAO> for auth::User {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<GithubOAuthCredentialDAO> for OAuthCredential {
|
||||
fn from(val: GithubOAuthCredentialDAO) -> Self {
|
||||
OAuthCredential {
|
||||
provider: OAuthProvider::Github,
|
||||
client_id: val.client_id,
|
||||
active: val.active,
|
||||
created_at: val.created_at,
|
||||
updated_at: val.updated_at,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -21,6 +21,8 @@ use tracing::error;
|
||||
use validator::ValidationErrors;
|
||||
use worker::{Worker, WorkerService};
|
||||
|
||||
use crate::schema::auth::{OAuthCredential, OAuthProvider};
|
||||
|
||||
pub trait ServiceLocator: Send + Sync {
|
||||
fn auth(&self) -> Arc<dyn AuthenticationService>;
|
||||
fn worker(&self) -> Arc<dyn WorkerService>;
|
||||
@ -240,6 +242,20 @@ impl Query {
|
||||
"Only admin is able to query job runs",
|
||||
)))
|
||||
}
|
||||
|
||||
async fn oauth_credential(
|
||||
ctx: &Context,
|
||||
provider: OAuthProvider,
|
||||
) -> Result<Option<OAuthCredential>> {
|
||||
if let Some(claims) = &ctx.claims {
|
||||
if claims.is_admin {
|
||||
return Ok(ctx.locator.auth().read_oauth_credential(provider).await?);
|
||||
}
|
||||
}
|
||||
Err(CoreError::Unauthorized(
|
||||
"Only admin is able to query oauth credential",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@ -341,6 +357,27 @@ impl Mutation {
|
||||
"Only admin is able to delete invitation",
|
||||
))
|
||||
}
|
||||
|
||||
async fn update_oauth_credential(
|
||||
ctx: &Context,
|
||||
provider: OAuthProvider,
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
active: bool,
|
||||
) -> Result<bool> {
|
||||
if let Some(claims) = &ctx.claims {
|
||||
if claims.is_admin {
|
||||
ctx.locator
|
||||
.auth()
|
||||
.update_oauth_credential(provider, client_id, client_secret, active)
|
||||
.await?;
|
||||
return Ok(true);
|
||||
}
|
||||
}
|
||||
Err(CoreError::Unauthorized(
|
||||
"Only admin is able to update oauth credential",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn from_validation_errors<S: ScalarValue>(error: ValidationErrors) -> FieldError<S> {
|
||||
|
@ -15,9 +15,9 @@ use crate::{
|
||||
oauth::github::GithubClient,
|
||||
schema::auth::{
|
||||
generate_jwt, generate_refresh_token, validate_jwt, AuthenticationService, GithubAuthError,
|
||||
GithubAuthResponse, InvitationNext, JWTPayload, RefreshTokenError, RefreshTokenResponse,
|
||||
RegisterError, RegisterResponse, TokenAuthError, TokenAuthResponse, User,
|
||||
VerifyTokenResponse,
|
||||
GithubAuthResponse, InvitationNext, JWTPayload, OAuthCredential, OAuthProvider,
|
||||
RefreshTokenError, RefreshTokenResponse, RegisterError, RegisterResponse, TokenAuthError,
|
||||
TokenAuthResponse, User, VerifyTokenResponse,
|
||||
},
|
||||
};
|
||||
|
||||
@ -387,6 +387,31 @@ impl AuthenticationService for DbConn {
|
||||
};
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
async fn read_oauth_credential(
|
||||
&self,
|
||||
provider: OAuthProvider,
|
||||
) -> Result<Option<OAuthCredential>> {
|
||||
match provider {
|
||||
OAuthProvider::Github => {
|
||||
Ok(self.read_github_oauth_credential().await?.map(|x| x.into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn update_oauth_credential(
|
||||
&self,
|
||||
provider: OAuthProvider,
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
active: bool,
|
||||
) -> Result<()> {
|
||||
match provider {
|
||||
OAuthProvider::Github => Ok(self
|
||||
.update_github_oauth_credential(&client_id, client_secret.as_deref(), active)
|
||||
.await?),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn password_hash(raw: &str) -> password_hash::Result<String> {
|
||||
|
Loading…
Reference in New Issue
Block a user