diff --git a/Cargo.lock b/Cargo.lock index f3241eb..bc5ed00 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -413,15 +413,6 @@ version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" -[[package]] -name = "arbitrary" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" -dependencies = [ - "derive_arbitrary", -] - [[package]] name = "arrayvec" version = "0.7.6" @@ -536,6 +527,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" dependencies = [ "axum-core", + "axum-macros", "base64 0.22.1", "bytes", "form_urlencoded", @@ -587,6 +579,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -1355,17 +1358,6 @@ dependencies = [ "syn 2.0.104", ] -[[package]] -name = "derive_arbitrary" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.104", -] - [[package]] name = "derive_more" version = "2.0.1" @@ -2190,7 +2182,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a3d7db9596fecd151c5f638c0ee5d5bd487b6e0ea232e5dc96d5250f6f94b1d" dependencies = [ "crc32fast", - "libz-rs-sys", "miniz_oxide", ] @@ -3392,7 +3383,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -3412,15 +3403,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "libz-rs-sys" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "172a788537a2221661b480fee8dc5f96c580eb34fa88764d3205dc356c7e4221" -dependencies = [ - "zlib-rs", -] - [[package]] name = "linear-map" version = "1.2.0" @@ -3837,7 +3819,7 @@ version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51e219e79014df21a225b1860a479e2dcd7cbd9130f4defd4bd0e191ea31d67d" dependencies = [ - "base64 0.21.7", + "base64 0.22.1", "chrono", "getrandom 0.2.16", "http 1.3.1", @@ -4762,40 +4744,6 @@ dependencies = [ "thiserror 2.0.16", ] -[[package]] -name = "rust-embed" -version = "8.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "025908b8682a26ba8d12f6f2d66b987584a4a87bc024abc5bbc12553a8cd178a" -dependencies = [ - "rust-embed-impl", - "rust-embed-utils", - "walkdir", -] - -[[package]] -name = "rust-embed-impl" -version = "8.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6065f1a4392b71819ec1ea1df1120673418bf386f50de1d6f54204d836d4349c" -dependencies = [ - "proc-macro2", - "quote", - "rust-embed-utils", - "syn 2.0.104", - "walkdir", -] - -[[package]] -name = "rust-embed-utils" -version = "8.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6cc0c81648b20b70c491ff8cce00c1c3b223bb8ed2b5d41f0e54c6c4c0a3594" -dependencies = [ - "sha2", - "walkdir", -] - [[package]] name = "rust_decimal" version = "1.37.2" @@ -5469,6 +5417,7 @@ dependencies = [ "shield", "shield-tower", "utoipa", + "utoipa-axum", ] [[package]] @@ -5539,8 +5488,8 @@ dependencies = [ "tower-sessions", "tracing", "tracing-subscriber", - "utoipa", - "utoipa-swagger-ui", + "utoipa-axum", + "utoipa-scalar", ] [[package]] @@ -5604,8 +5553,6 @@ dependencies = [ "tower-sessions", "tracing", "tracing-subscriber", - "utoipa", - "utoipa-swagger-ui", "wasm-bindgen", "wasm-tracing", ] @@ -5791,12 +5738,6 @@ dependencies = [ "rand_core 0.6.4", ] -[[package]] -name = "simd-adler32" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" - [[package]] name = "simdutf8" version = "0.1.5" @@ -6975,6 +6916,19 @@ dependencies = [ "utoipa-gen", ] +[[package]] +name = "utoipa-axum" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c25bae5bccc842449ec0c5ddc5cbb6a3a1eaeac4503895dc105a1138f8234a0" +dependencies = [ + "axum", + "paste", + "tower-layer", + "tower-service", + "utoipa", +] + [[package]] name = "utoipa-gen" version = "5.4.0" @@ -6989,30 +6943,17 @@ dependencies = [ ] [[package]] -name = "utoipa-swagger-ui" -version = "9.0.2" +name = "utoipa-scalar" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d047458f1b5b65237c2f6dc6db136945667f40a7668627b3490b9513a3d43a55" +checksum = "59559e1509172f6b26c1cdbc7247c4ddd1ac6560fe94b584f81ee489b141f719" dependencies = [ "axum", - "base64 0.22.1", - "mime_guess", - "regex", - "rust-embed", "serde", "serde_json", - "url", "utoipa", - "utoipa-swagger-ui-vendored", - "zip", ] -[[package]] -name = "utoipa-swagger-ui-vendored" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2eebbbfe4093922c2b6734d7c679ebfebd704a0d7e56dfcb0d05818ce28977d" - [[package]] name = "uuid" version = "1.18.1" @@ -7251,7 +7192,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -7721,38 +7662,6 @@ dependencies = [ "syn 2.0.104", ] -[[package]] -name = "zip" -version = "3.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12598812502ed0105f607f941c386f43d441e00148fce9dec3ca5ffb0bde9308" -dependencies = [ - "arbitrary", - "crc32fast", - "flate2", - "indexmap 2.9.0", - "memchr", - "zopfli", -] - -[[package]] -name = "zlib-rs" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "626bd9fa9734751fc50d6060752170984d7053f5a39061f524cda68023d4db8a" - -[[package]] -name = "zopfli" -version = "0.8.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edfc5ee405f504cd4984ecc6f14d02d55cfda60fa4b689434ef4102aae150cd7" -dependencies = [ - "bumpalo", - "crc32fast", - "log", - "simd-adler32", -] - [[package]] name = "zstd" version = "0.13.3" diff --git a/Cargo.toml b/Cargo.toml index b572d15..9617da8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,24 @@ [workspace] -members = ["examples/*", "packages/*/*"] +members = [ + "examples/axum", + "examples/dioxus-axum", + "examples/leptos-actix", + "examples/leptos-axum", + "examples/sea-orm", + "examples/workos", + "packages/core/*", + "packages/integrations/shield-actix", + "packages/integrations/shield-axum", + "packages/integrations/shield-dioxus", + "packages/integrations/shield-dioxus-axum", + "packages/integrations/shield-leptos", + "packages/integrations/shield-leptos-actix", + "packages/integrations/shield-leptos-axum", + "packages/integrations/shield-tower", + "packages/methods/*", + "packages/storage/*", + "packages/styles/shield-bootstrap", +] resolver = "2" [workspace.package] @@ -63,6 +82,7 @@ tower-sessions = "0.14.0" tracing = "0.1.41" tracing-subscriber = "0.3.19" utoipa = { version = "5.3.1", features = ["chrono", "uuid"] } +utoipa-axum = "0.2.0" uuid = "1.11.0" wasm-bindgen = "0.2.100" wasm-tracing = "2.0.0" diff --git a/examples/axum/Cargo.toml b/examples/axum/Cargo.toml index ecced05..5bac1ac 100644 --- a/examples/axum/Cargo.toml +++ b/examples/axum/Cargo.toml @@ -20,5 +20,5 @@ tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } tower-sessions.workspace = true tracing.workspace = true tracing-subscriber.workspace = true -utoipa.workspace = true -utoipa-swagger-ui = { version = "9.0.0", features = ["axum", "vendored"] } +utoipa-axum.workspace = true +utoipa-scalar = { version = "0.3.0", features = ["axum"] } diff --git a/examples/axum/src/main.rs b/examples/axum/src/main.rs index 656f7d8..9268751 100644 --- a/examples/axum/src/main.rs +++ b/examples/axum/src/main.rs @@ -1,22 +1,20 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; + +use axum::{Json, middleware::from_fn, routing::get}; +use shield::{Shield, ShieldOptions}; +use shield_axum::{AuthRoutes, ShieldLayer, auth_required}; +use shield_memory::{MemoryStorage, User}; +use shield_oidc::{Keycloak, OidcMethod}; +use time::Duration; +use tokio::net::TcpListener; +use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; +use tracing::{info, level_filters::LevelFilter}; +use utoipa_axum::router::OpenApiRouter; +use utoipa_scalar::{Scalar, Servable}; #[tokio::main] async fn main() { - use std::sync::Arc; - - use axum::{Router, middleware::from_fn, routing::get}; - - use shield::{Shield, ShieldOptions}; - use shield_axum::{AuthRoutes, ShieldLayer, auth_required}; - use shield_memory::{MemoryStorage, User}; - use shield_oidc::{Keycloak, OidcMethod}; - use time::Duration; - use tokio::net::TcpListener; - use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; - use tracing::{info, level_filters::LevelFilter}; - use utoipa::OpenApi; - use utoipa_swagger_ui::SwaggerUi; - // Initialize tracing tracing_subscriber::fmt() .with_max_level(LevelFilter::DEBUG) @@ -52,21 +50,23 @@ async fn main() { ); let shield_layer = ShieldLayer::new(shield.clone()); - // Initialize OpenAPI specification (optional) - #[derive(OpenApi)] - #[openapi(nest( - (path = "/api/auth", api = AuthRoutes, tags = ["auth"]), - ))] - struct Docs; + // Initialize API router + let api_router = OpenApiRouter::new() + .route("/protected", get(async || "Protected")) + .route_layer(from_fn(auth_required::)) + .nest("/auth", AuthRoutes::openapi_router::()); // Initialize router - let router = Router::new() - .route("/api/protected", get(async || "Protected")) - .route_layer(from_fn(auth_required::)) - .nest("/api/auth", AuthRoutes::router::()) - .merge(SwaggerUi::new("/api-docs").url("/api/openapi.json", Docs::openapi())) + let (router, openapi) = OpenApiRouter::new() + .nest("/api", api_router) .layer(shield_layer) - .layer(session_layer); + .layer(session_layer) + .split_for_parts(); + + // Add Scalar and OpenAPI specification + let router = router + .merge(Scalar::with_url("/api/reference", openapi.clone())) + .route("/api/openapi.json", get(|| async { Json(openapi) })); // Start app info!("listening on http://{}", &addr); diff --git a/examples/leptos-axum/Cargo.toml b/examples/leptos-axum/Cargo.toml index c6d05ed..831416d 100644 --- a/examples/leptos-axum/Cargo.toml +++ b/examples/leptos-axum/Cargo.toml @@ -48,9 +48,7 @@ leptos_router.workspace = true shield.workspace = true shield-bootstrap = { workspace = true, features = ["leptos"] } shield-leptos.workspace = true -shield-leptos-axum = { workspace = true, features = [ - "utoipa", -], optional = true } +shield-leptos-axum = { workspace = true, optional = true } shield-memory = { workspace = true, optional = true } shield-oidc = { workspace = true, features = ["native-tls"], optional = true } time = "0.3.37" @@ -61,7 +59,5 @@ tokio = { workspace = true, features = [ tower-sessions = { workspace = true, optional = true } tracing.workspace = true tracing-subscriber.workspace = true -utoipa.workspace = true -utoipa-swagger-ui = { version = "9.0.0", features = ["axum", "vendored"] } wasm-bindgen.workspace = true wasm-tracing.workspace = true diff --git a/examples/leptos-axum/src/main.rs b/examples/leptos-axum/src/main.rs index 0e8d973..811fcfe 100644 --- a/examples/leptos-axum/src/main.rs +++ b/examples/leptos-axum/src/main.rs @@ -19,8 +19,6 @@ async fn main() { use tokio::net::TcpListener; use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; use tracing::{info, level_filters::LevelFilter}; - use utoipa::OpenApi; - use utoipa_swagger_ui::SwaggerUi; // Initialize tracing tracing_subscriber::fmt() @@ -60,19 +58,11 @@ async fn main() { ); let shield_layer = ShieldLayer::new(shield.clone()); - // Initialize OpenAPI specification (optional) - #[derive(OpenApi)] - #[openapi(nest( - (path = "/api/auth", api = AuthRoutes, tags = ["auth"]), - ))] - struct Docs; - // Initialize router let router = Router::new() .route("/api/protected", get(async || "Protected")) .route_layer(from_fn(auth_required::)) .nest("/api/auth", AuthRoutes::router::()) - .merge(SwaggerUi::new("/api-docs").url("/api/openapi.json", Docs::openapi())) .leptos_routes_with_context( &leptos_options, routes, diff --git a/packages/core/shield/src/action.rs b/packages/core/shield/src/action.rs index f21a69c..f9db85b 100644 --- a/packages/core/shield/src/action.rs +++ b/packages/core/shield/src/action.rs @@ -14,6 +14,8 @@ use crate::{ // TODO: Think of a better name. #[derive(Clone, Debug, Deserialize, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(rename_all = "camelCase")] pub struct ActionForms { pub id: String, pub name: String, @@ -22,6 +24,8 @@ pub struct ActionForms { // TODO: Think of a better name. #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(rename_all = "camelCase")] pub struct ActionMethodForm { pub id: String, pub provider_forms: Vec, @@ -29,6 +33,8 @@ pub struct ActionMethodForm { // TODO: Think of a better name. #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(rename_all = "camelCase")] pub struct ActionProviderForm { pub id: Option, pub form: Form, diff --git a/packages/core/shield/src/form.rs b/packages/core/shield/src/form.rs index d30802e..a0c557d 100644 --- a/packages/core/shield/src/form.rs +++ b/packages/core/shield/src/form.rs @@ -1,11 +1,15 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(rename_all = "camelCase")] pub struct Form { pub inputs: Vec, } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(rename_all = "camelCase")] pub struct Input { pub name: String, pub label: Option, @@ -14,6 +18,8 @@ pub struct Input { } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(rename_all = "kebab-case")] pub enum InputType { Button(InputTypeButton), Checkbox(InputTypeCheckbox), @@ -69,15 +75,18 @@ impl InputType { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeButton {} #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeCheckbox { pub checked: Option, pub required: Option, } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeColor { pub alpha: Option, pub autocomplete: Option, @@ -86,6 +95,7 @@ pub struct InputTypeColor { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeDate { pub autocomplete: Option, pub list: Option, @@ -97,6 +107,7 @@ pub struct InputTypeDate { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeDatetimeLocal { pub autocomplete: Option, pub list: Option, @@ -108,6 +119,7 @@ pub struct InputTypeDatetimeLocal { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeEmail { pub autocomplete: Option, pub list: Option, @@ -122,6 +134,7 @@ pub struct InputTypeEmail { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeFile { pub accept: Option, pub multiple: Option, @@ -129,12 +142,14 @@ pub struct InputTypeFile { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeHidden { pub autocomplete: Option, pub required: Option, } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeImage { pub alt: Option, pub height: Option, @@ -143,6 +158,7 @@ pub struct InputTypeImage { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeMonth { pub autocomplete: Option, pub list: Option, @@ -154,6 +170,7 @@ pub struct InputTypeMonth { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeNumber { pub autocomplete: Option, pub list: Option, @@ -166,6 +183,7 @@ pub struct InputTypeNumber { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypePassword { pub autocomplete: Option, pub maxlength: Option, @@ -178,12 +196,14 @@ pub struct InputTypePassword { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeRadio { pub checked: Option, pub required: Option, } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeRange { pub autocomplete: Option, pub list: Option, @@ -193,9 +213,11 @@ pub struct InputTypeRange { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeReset {} #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeSearch { pub autocomplete: Option, pub list: Option, @@ -209,9 +231,11 @@ pub struct InputTypeSearch { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeSubmit {} #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeTel { pub autocomplete: Option, pub list: Option, @@ -225,6 +249,7 @@ pub struct InputTypeTel { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeText { pub autocomplete: Option, pub list: Option, @@ -238,6 +263,7 @@ pub struct InputTypeText { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeTime { pub autocomplete: Option, pub list: Option, @@ -249,6 +275,7 @@ pub struct InputTypeTime { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeUrl { pub autocomplete: Option, pub list: Option, @@ -262,6 +289,7 @@ pub struct InputTypeUrl { } #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] pub struct InputTypeWeek { pub autocomplete: Option, pub list: Option, diff --git a/packages/core/shield/src/shield.rs b/packages/core/shield/src/shield.rs index 7227115..e731df3 100644 --- a/packages/core/shield/src/shield.rs +++ b/packages/core/shield/src/shield.rs @@ -196,6 +196,42 @@ impl Shield { Ok(response.r#type) } + + pub async fn user(&self, session: &Session) -> Result, ShieldError> { + let authentication = { + let session_data = session.data(); + let session_data = session_data + .lock() + .map_err(|err| SessionError::Lock(err.to_string()))?; + + session_data.base.authentication.clone() + }; + + match authentication { + Some(authentication) => { + if self + .provider_by_id( + &authentication.method_id, + authentication.provider_id.as_deref(), + ) + .await? + .is_none() + { + session.purge().await?; + return Ok(None); + } + + let user = self.storage().user_by_id(&authentication.user_id).await?; + + if user.is_none() { + session.purge().await?; + } + + Ok(user) + } + None => Ok(None), + } + } } #[cfg(test)] diff --git a/packages/integrations/shield-axum/Cargo.toml b/packages/integrations/shield-axum/Cargo.toml index 364bdd6..6a59942 100644 --- a/packages/integrations/shield-axum/Cargo.toml +++ b/packages/integrations/shield-axum/Cargo.toml @@ -13,9 +13,10 @@ default = [] utoipa = ["dep:utoipa", "shield/utoipa"] [dependencies] -axum.workspace = true +axum = { workspace = true, features = ["macros"] } serde.workspace = true serde_json.workspace = true shield.workspace = true shield-tower.workspace = true utoipa = { workspace = true, features = ["axum_extras"], optional = true } +utoipa-axum.workspace = true diff --git a/packages/integrations/shield-axum/src/path.rs b/packages/integrations/shield-axum/src/path.rs index f8a0c5c..2d14c0b 100644 --- a/packages/integrations/shield-axum/src/path.rs +++ b/packages/integrations/shield-axum/src/path.rs @@ -1,5 +1,13 @@ use serde::Deserialize; +#[derive(Deserialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))] +#[serde(rename_all = "camelCase")] +pub struct ActionFormsPathParams { + /// ID of the action. + pub action_id: String, +} + #[derive(Deserialize)] #[cfg_attr(feature = "utoipa", derive(utoipa::IntoParams))] #[serde(rename_all = "camelCase")] diff --git a/packages/integrations/shield-axum/src/router.rs b/packages/integrations/shield-axum/src/router.rs index 09e4c2e..eef7a6e 100644 --- a/packages/integrations/shield-axum/src/router.rs +++ b/packages/integrations/shield-axum/src/router.rs @@ -3,16 +3,32 @@ use axum::{ routing::{get, post}, }; use shield::User; +use utoipa::OpenApi; +use utoipa_axum::router::OpenApiRouter; use crate::routes::*; #[cfg_attr(feature = "utoipa", derive(utoipa::OpenApi))] -#[cfg_attr(feature = "utoipa", openapi(paths()))] +#[cfg_attr(feature = "utoipa", openapi(paths(action, forms, user)))] pub struct AuthRoutes; impl AuthRoutes { pub fn router() -> Router { Router::new() + .route("/user", get(user::)) + .route("/forms/{actionId}", get(forms::)) + .route("/{methodId}/{actionId}", get(action::)) + .route("/{methodId}/{actionId}", post(action::)) + .route("/{methodId}/{actionId}/{providerId}", get(action::)) + .route("/{methodId}/{actionId}/{providerId}", post(action::)) + } + + #[cfg(feature = "utoipa")] + pub fn openapi_router() + -> OpenApiRouter { + OpenApiRouter::with_openapi(AuthRoutes::openapi()) + .route("/user", get(user::)) + .route("/forms/{actionId}", get(forms::)) .route("/{methodId}/{actionId}", get(action::)) .route("/{methodId}/{actionId}", post(action::)) .route("/{methodId}/{actionId}/{providerId}", get(action::)) diff --git a/packages/integrations/shield-axum/src/routes.rs b/packages/integrations/shield-axum/src/routes.rs index 8bd911f..77712d6 100644 --- a/packages/integrations/shield-axum/src/routes.rs +++ b/packages/integrations/shield-axum/src/routes.rs @@ -1,3 +1,7 @@ mod action; +mod forms; +mod user; pub use action::*; +pub use forms::*; +pub use user::*; diff --git a/packages/integrations/shield-axum/src/routes/action.rs b/packages/integrations/shield-axum/src/routes/action.rs index e01c1ba..ecd0593 100644 --- a/packages/integrations/shield-axum/src/routes/action.rs +++ b/packages/integrations/shield-axum/src/routes/action.rs @@ -6,8 +6,27 @@ use axum::{ use serde_json::Value; use shield::{Request, ResponseType, User}; -use crate::{ExtractSession, ExtractShield, RouteError, path::ActionPathParams}; +use crate::{ExtractSession, ExtractShield, RouteError, error::ErrorBody, path::ActionPathParams}; +#[cfg_attr( + feature = "utoipa", + utoipa::path( + get, + post, + path = "/{methodId}/{actionId}/{providerId}", + operation_id = "callAction", + summary = "Call action", + description = "Call an action.", + tags = ["auth"], + params( + ActionPathParams + ), + responses( + (status = 302, description = "Redirect."), + (status = 500, description = "Internal server error.", body = ErrorBody), + ) + ) +)] pub async fn action( Path(ActionPathParams { method_id, diff --git a/packages/integrations/shield-axum/src/routes/forms.rs b/packages/integrations/shield-axum/src/routes/forms.rs new file mode 100644 index 0000000..a4d05de --- /dev/null +++ b/packages/integrations/shield-axum/src/routes/forms.rs @@ -0,0 +1,34 @@ +use axum::{Json, extract::Path}; +use shield::{ActionForms, User}; + +use crate::{ + ExtractSession, ExtractShield, RouteError, error::ErrorBody, path::ActionFormsPathParams, +}; + +#[cfg_attr( + feature = "utoipa", + utoipa::path( + get, + path = "/forms/{actionId}", + operation_id = "getActionForms", + summary = "Get action forms", + description = "Get action forms.", + tags = ["auth"], + params( + ActionFormsPathParams + ), + responses( + (status = 200, description = "The action forms.", body = ActionForms), + (status = 500, description = "Internal server error.", body = ErrorBody), + ) + ) +)] +pub async fn forms( + Path(ActionFormsPathParams { action_id, .. }): Path, + ExtractShield(shield): ExtractShield, + ExtractSession(session): ExtractSession, +) -> Result, RouteError> { + let forms = shield.action_forms(&action_id, session).await?; + + Ok(Json(forms)) +} diff --git a/packages/integrations/shield-axum/src/routes/user.rs b/packages/integrations/shield-axum/src/routes/user.rs new file mode 100644 index 0000000..e00d771 --- /dev/null +++ b/packages/integrations/shield-axum/src/routes/user.rs @@ -0,0 +1,56 @@ +use axum::Json; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use shield::{ConfigurationError, EmailAddress, ShieldError, User}; + +use crate::{RouteError, error::ErrorBody, extract::UserRequired}; + +#[derive(Deserialize, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[cfg_attr(feature = "utoipa", schema(as = User))] +#[serde(rename_all = "camelCase")] +pub struct UserBody { + id: String, + name: Option, + email_addresses: Vec, + additional: Value, +} + +impl UserBody { + async fn new(user: U) -> Result { + let email_addresses = user.email_addresses().await?; + + Ok(Self { + id: user.id(), + name: user.name(), + email_addresses, + additional: serde_json::to_value(user.additional()).map_err(|err| { + ConfigurationError::Invalid(format!( + "additional user data is not serializable: {err}" + )) + })?, + }) + } +} + +#[cfg_attr( + feature = "utoipa", + utoipa::path( + get, + path = "/user", + operation_id = "getCurrentUser", + summary = "Get current user", + description = "Get the current user account.", + tags = ["auth"], + responses( + (status = 200, description = "The current user account.", body = UserBody), + (status = 401, description = "No account signed in.", body = ErrorBody), + (status = 500, description = "Internal server error.", body = ErrorBody), + ) + ) +)] +pub async fn user( + UserRequired(user): UserRequired, +) -> Result, RouteError> { + Ok(Json(UserBody::new(user).await?)) +} diff --git a/packages/integrations/shield-tower/src/service.rs b/packages/integrations/shield-tower/src/service.rs index ee7345d..a0b40f4 100644 --- a/packages/integrations/shield-tower/src/service.rs +++ b/packages/integrations/shield-tower/src/service.rs @@ -74,14 +74,14 @@ where }; let shield_session = Session::new(session_storage); - // let user = match shield.user(&shield_session).await { - // Ok(user) => user, - // Err(_err) => return Ok(Self::internal_server_error()), - // }; + let user = match shield.user(&shield_session).await { + Ok(user) => user, + Err(_err) => return Ok(Self::internal_server_error()), + }; req.extensions_mut().insert(shield); req.extensions_mut().insert(shield_session); - // req.extensions_mut().insert(user); + req.extensions_mut().insert(user); inner.call(req).await })