Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/axum/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ async fn main() {
let api_router = OpenApiRouter::new()
.route("/protected", get(async || "Protected"))
.route_layer(from_fn(auth_required::<User>))
.nest("/auth", AuthRoutes::openapi_router::<User, ()>());
.nest("/auth", AuthRoutes::new(shield).openapi_router());

// Initialize router
let (router, openapi) = OpenApiRouter::new()
Expand Down
7 changes: 2 additions & 5 deletions examples/leptos-axum/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@ async fn main() {
use std::sync::Arc;

use axum::{Router, middleware::from_fn, routing::get};
use leptos::{
config::{LeptosOptions, get_configuration},
context::provide_context,
};
use leptos::{config::get_configuration, context::provide_context};
use leptos_axum::{LeptosRoutes, generate_route_list};
use shield::{Shield, ShieldOptions};
use shield_bootstrap::BootstrapLeptosStyle;
Expand Down Expand Up @@ -62,7 +59,7 @@ async fn main() {
let router = Router::new()
.route("/api/protected", get(async || "Protected"))
.route_layer(from_fn(auth_required::<User>))
.nest("/api/auth", AuthRoutes::router::<User, LeptosOptions>())
.nest("/api/auth", AuthRoutes::new(shield).router())
.leptos_routes_with_context(
&leptos_options,
routes,
Expand Down
1 change: 1 addition & 0 deletions packages/core/shield/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ utoipa = ["dep:utoipa"]
async-trait.workspace = true
bon.workspace = true
chrono = { workspace = true, features = ["serde"] }
convert_case = "0.8.0"
futures.workspace = true
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
Expand Down
59 changes: 56 additions & 3 deletions packages/core/shield/src/action.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use std::any::Any;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::{
error::ShieldError,
form::Form,
Expand All @@ -11,6 +8,38 @@ use crate::{
response::Response,
session::{BaseSession, MethodSession},
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
#[cfg(feature = "utoipa")]
use utoipa::openapi::HttpMethod;

#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ActionMethod {
Get,
Post,
Put,
Delete,
Options,
Head,
Patch,
Trace,
}

#[cfg(feature = "utoipa")]
impl From<ActionMethod> for HttpMethod {
fn from(value: ActionMethod) -> Self {
match value {
ActionMethod::Get => Self::Get,
ActionMethod::Post => Self::Post,
ActionMethod::Put => Self::Put,
ActionMethod::Delete => Self::Delete,
ActionMethod::Options => Self::Options,
ActionMethod::Head => Self::Head,
ActionMethod::Patch => Self::Patch,
ActionMethod::Trace => Self::Trace,
}
}
}

// TODO: Think of a better name.
#[derive(Clone, Debug, Deserialize, Serialize)]
Expand Down Expand Up @@ -46,6 +75,12 @@ pub trait Action<P: Provider, S>: ErasedAction + Send + Sync {

fn name(&self) -> String;

fn openapi_summary(&self) -> &'static str;

fn openapi_description(&self) -> &'static str;

fn method(&self) -> ActionMethod;

fn condition(&self, _provider: &P, _session: &MethodSession<S>) -> Result<bool, ShieldError> {
Ok(true)
}
Expand All @@ -66,6 +101,12 @@ pub trait ErasedAction: Send + Sync {

fn erased_name(&self) -> String;

fn erased_openapi_summary(&self) -> &'static str;

fn erased_openapi_description(&self) -> &'static str;

fn erased_method(&self) -> ActionMethod;

fn erased_condition(
&self,
provider: &(dyn Any + Send + Sync),
Expand Down Expand Up @@ -100,6 +141,18 @@ macro_rules! erased_action {
self.name()
}

fn erased_openapi_summary(&self) -> &'static str {
self.openapi_summary()
}

fn erased_openapi_description(&self) -> &'static str {
self.openapi_description()
}

fn erased_method(&self) -> $crate::ActionMethod {
self.method()
}

fn erased_condition(
&self,
provider: &(dyn std::any::Any + Send + Sync),
Expand Down
3 changes: 3 additions & 0 deletions packages/core/shield/src/actions/sign_out.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use crate::{Form, Input, InputType, InputTypeSubmit, MethodSession, Provider, Sh
const ACTION_ID: &str = "sign-out";
const ACTION_NAME: &str = "Sign out";

// TODO: Sign out should be a global action that is independent of the method.
// TODO: Add hooks, so the method can still perform custom sign out.

pub struct SignOutAction;

impl SignOutAction {
Expand Down
2 changes: 2 additions & 0 deletions packages/core/shield/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod error;
mod form;
mod method;
mod options;
mod path;
mod provider;
mod request;
mod response;
Expand All @@ -19,6 +20,7 @@ pub use error::*;
pub use form::*;
pub use method::*;
pub use options::*;
pub use path::*;
pub use provider::*;
pub use request::*;
pub use response::*;
Expand Down
60 changes: 55 additions & 5 deletions packages/core/shield/src/shield.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
use std::{any::Any, collections::HashMap, sync::Arc};

#[cfg(feature = "utoipa")]
use convert_case::{Case, Casing};
use futures::future::try_join_all;
use tracing::{debug, warn};
use tracing::warn;
#[cfg(feature = "utoipa")]
use utoipa::{
IntoParams,
openapi::{
OpenApi, PathItem, Paths,
path::{Operation, ParameterIn},
},
};

use crate::{
action::{ActionForms, ActionMethodForm, ActionProviderForm},
error::{ActionError, MethodError, ProviderError, SessionError, ShieldError},
method::ErasedMethod,
options::ShieldOptions,
path::ActionPathParams,
request::Request,
response::ResponseType,
session::Session,
Expand Down Expand Up @@ -184,16 +195,12 @@ impl<U: User> Shield<U> {
.erased_call(provider, &base_session, &*method_session, request)
.await?;

debug!("response {:#?}", response);

for session_action in &response.session_actions {
session_action
.call(method_id, provider_id, &session)
.await?;
}

debug!("session actions processed");

Ok(response.r#type)
}

Expand Down Expand Up @@ -232,6 +239,49 @@ impl<U: User> Shield<U> {
None => Ok(None),
}
}

#[cfg(feature = "utoipa")]
pub fn openapi(&self) -> OpenApi {
let mut paths = Paths::builder();

for method in self.methods.values() {
for action in method.erased_actions() {
use utoipa::openapi::Response;

let method_id = method.erased_id();
let action_id = action.erased_id();

// TODO: Query, request body, responses.

paths = paths.path(
format!("/{}/{}/{{providerId}}", method_id, action_id),
PathItem::builder()
.operation(
action.erased_method().into(),
Operation::builder()
.operation_id(Some(format!(
"{}{}",
action_id.to_case(Case::Camel),
method_id.to_case(Case::UpperCamel)
)))
.summary(Some(action.erased_openapi_summary()))
.description(Some(action.erased_openapi_description()))
.tag("auth")
.parameters(Some(ActionPathParams::into_params(|| {
Some(ParameterIn::Path)
})))
.response(
"500",
Response::builder().description("Internal server error."),
),
)
.build(),
);
}
}

OpenApi::builder().paths(paths.build()).build()
}
}

#[cfg(test)]
Expand Down
1 change: 0 additions & 1 deletion packages/integrations/shield-axum/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
mod error;
mod extract;
mod middleware;
mod path;
mod router;
mod routes;

Expand Down
32 changes: 20 additions & 12 deletions packages/integrations/shield-axum/src/router.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,40 @@
use axum::{
Router,
routing::{get, post},
routing::{any, get, post},
};
use shield::User;
use shield::{Shield, User};
#[cfg(feature = "utoipa")]
use utoipa::OpenApi;
#[cfg(feature = "utoipa")]
use utoipa_axum::router::OpenApiRouter;

use crate::routes::*;

#[cfg(feature = "utoipa")]
#[cfg_attr(feature = "utoipa", derive(utoipa::OpenApi))]
#[cfg_attr(feature = "utoipa", openapi(paths(action, forms, user)))]
pub struct AuthRoutes;
struct BaseOpenApi;

impl AuthRoutes {
pub fn router<U: User + Clone + 'static, S: Clone + Send + Sync + 'static>() -> Router<S> {
pub struct AuthRoutes<U: User> {
shield: Shield<U>,
}

impl<U: Clone + User + 'static> AuthRoutes<U> {
pub fn new(shield: Shield<U>) -> Self {
Self { shield }
}

pub fn router<S: Clone + Send + Sync + 'static>(&self) -> Router<S> {
Router::new()
.route("/user", get(user::<U>))
.route("/forms/{actionId}", get(forms::<U>))
.route("/{methodId}/{actionId}", get(action::<U>))
.route("/{methodId}/{actionId}", post(action::<U>))
.route("/{methodId}/{actionId}/{providerId}", get(action::<U>))
.route("/{methodId}/{actionId}/{providerId}", post(action::<U>))
.route("/{methodId}/{actionId}", any(action::<U>))
.route("/{methodId}/{actionId}/{providerId}", any(action::<U>))
}

#[cfg(feature = "utoipa")]
pub fn openapi_router<U: User + Clone + 'static, S: Clone + Send + Sync + 'static>()
-> OpenApiRouter<S> {
OpenApiRouter::with_openapi(AuthRoutes::openapi())
pub fn openapi_router<S: Clone + Send + Sync + 'static>(&self) -> OpenApiRouter<S> {
OpenApiRouter::with_openapi(BaseOpenApi::openapi().merge_from(self.shield.openapi()))
.route("/user", get(user::<U>))
.route("/forms/{actionId}", get(forms::<U>))
.route("/{methodId}/{actionId}", get(action::<U>))
Expand Down
4 changes: 2 additions & 2 deletions packages/integrations/shield-axum/src/routes/action.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use axum::{
response::{IntoResponse, Redirect, Response},
};
use serde_json::Value;
use shield::{Request, ResponseType, User};
use shield::{ActionPathParams, Request, ResponseType, User};

use crate::{ExtractSession, ExtractShield, RouteError, error::ErrorBody, path::ActionPathParams};
use crate::{ExtractSession, ExtractShield, RouteError, error::ErrorBody};

#[cfg_attr(
feature = "utoipa",
Expand Down
6 changes: 2 additions & 4 deletions packages/integrations/shield-axum/src/routes/forms.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use axum::{Json, extract::Path};
use shield::{ActionForms, User};
use shield::{ActionForms, ActionFormsPathParams, User};

use crate::{
ExtractSession, ExtractShield, RouteError, error::ErrorBody, path::ActionFormsPathParams,
};
use crate::{ExtractSession, ExtractShield, RouteError, error::ErrorBody};

#[cfg_attr(
feature = "utoipa",
Expand Down
16 changes: 14 additions & 2 deletions packages/methods/shield-credentials/src/actions/sign_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::sync::Arc;
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use shield::{
Action, Form, MethodSession, Request, Response, ResponseType, SessionAction, ShieldError,
SignInAction, User, erased_action,
Action, ActionMethod, Form, MethodSession, Request, Response, ResponseType, SessionAction,
ShieldError, SignInAction, User, erased_action,
};

use crate::{credentials::Credentials, provider::CredentialsProvider};
Expand All @@ -31,6 +31,18 @@ impl<U: User + 'static, D: DeserializeOwned + 'static> Action<CredentialsProvide
SignInAction::name()
}

fn openapi_summary(&self) -> &'static str {
"Sign in with credentials"
}

fn openapi_description(&self) -> &'static str {
"Sign in with credentials."
}

fn method(&self) -> ActionMethod {
ActionMethod::Post
}

async fn forms(&self, _provider: CredentialsProvider) -> Result<Vec<Form>, ShieldError> {
Ok(vec![self.credentials.form()])
}
Expand Down
Loading