From 661c975ffcfee1dbf9085a0ac5061df8fa7f4179 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABlle=20Huisman?= Date: Fri, 14 Nov 2025 21:28:21 +0100 Subject: [PATCH] feat: add redirect origin --- Cargo.lock | 27 +++--- Cargo.toml | 1 + examples/axum/Cargo.toml | 1 + examples/axum/src/main.rs | 34 ++++--- .../shield/src/actions/sign_in_callback.rs | 2 +- packages/core/shield/src/actions/sign_out.rs | 11 ++- packages/core/shield/src/form.rs | 10 +- .../src/client/@tanstack/react-query.gen.ts | 90 +++++++++++++++++- .../shield-react/src/client/client/index.ts | 1 - .../src/client/client/types.gen.ts | 20 +--- .../src/client/client/utils.gen.ts | 14 +-- .../src/client/core/bodySerializer.gen.ts | 16 +++- .../src/client/core/params.gen.ts | 43 +++++++-- .../shield-react/src/client/sdk.gen.ts | 47 ++++++++++ .../shield-react/src/client/types.gen.ts | 92 ++++++++++++++++++- packages/methods/shield-oauth/Cargo.toml | 1 + .../shield-oauth/src/actions/sign_in.rs | 41 ++++++++- .../src/actions/sign_in_callback.rs | 1 + packages/methods/shield-oauth/src/method.rs | 2 +- packages/methods/shield-oauth/src/options.rs | 4 + packages/methods/shield-oauth/src/session.rs | 2 + packages/methods/shield-oidc/Cargo.toml | 1 + .../shield-oidc/src/actions/sign_in.rs | 62 +++++++++++-- .../src/actions/sign_in_callback.rs | 14 ++- packages/methods/shield-oidc/src/method.rs | 2 +- packages/methods/shield-oidc/src/options.rs | 4 + packages/methods/shield-oidc/src/session.rs | 2 + .../shield-workos/src/actions/index.rs | 26 ++++-- .../shield-workos/src/actions/sign_in.rs | 10 +- .../shield-workos/src/actions/sign_up.rs | 10 +- .../shield-bootstrap/src/dioxus/input.rs | 7 +- .../shield-bootstrap/src/leptos/input.rs | 7 +- .../src/components/style/input.tsx | 17 +++- 33 files changed, 513 insertions(+), 109 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a719580..184932f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2233,7 +2233,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2350,9 +2350,9 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" dependencies = [ "percent-encoding", ] @@ -3049,9 +3049,9 @@ checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" [[package]] name = "idna" -version = "1.0.3" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", "smallvec", @@ -3977,7 +3977,7 @@ version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51e219e79014df21a225b1860a479e2dcd7cbd9130f4defd4bd0e191ea31d67d" dependencies = [ - "base64 0.21.7", + "base64 0.22.1", "chrono", "getrandom 0.2.16", "http 1.3.1", @@ -4537,7 +4537,7 @@ dependencies = [ "once_cell", "socket2 0.5.10", "tracing", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4959,7 +4959,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.4.15", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4972,7 +4972,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.9.4", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -5650,6 +5650,7 @@ dependencies = [ "tower-sessions", "tracing", "tracing-subscriber", + "url", "utoipa-axum", "utoipa-scalar", ] @@ -5807,6 +5808,7 @@ dependencies = [ "serde", "serde_json", "shield", + "url", ] [[package]] @@ -5823,6 +5825,7 @@ dependencies = [ "serde_json", "shield", "tracing", + "url", ] [[package]] @@ -6425,7 +6428,7 @@ dependencies = [ "getrandom 0.3.3", "once_cell", "rustix 1.0.7", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -7025,9 +7028,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.4" +version = "2.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" dependencies = [ "form_urlencoded", "idna", diff --git a/Cargo.toml b/Cargo.toml index bea0ba7..b6581ae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,6 +82,7 @@ tower-service = "0.3.3" tower-sessions = "0.14.0" tracing = "0.1.41" tracing-subscriber = "0.3.19" +url = "2.5.7" utoipa = { version = "5.3.1", features = ["chrono", "uuid"] } utoipa-axum = "0.2.0" uuid = "1.11.0" diff --git a/examples/axum/Cargo.toml b/examples/axum/Cargo.toml index 5bac1ac..0e45fc0 100644 --- a/examples/axum/Cargo.toml +++ b/examples/axum/Cargo.toml @@ -20,5 +20,6 @@ tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } tower-sessions.workspace = true tracing.workspace = true tracing-subscriber.workspace = true +url.workspace = true utoipa-axum.workspace = true utoipa-scalar = { version = "0.3.0", features = ["axum"] } diff --git a/examples/axum/src/main.rs b/examples/axum/src/main.rs index 6bf5ff7..56e0b48 100644 --- a/examples/axum/src/main.rs +++ b/examples/axum/src/main.rs @@ -5,11 +5,12 @@ use axum::{Json, middleware::from_fn, routing::get}; use shield::{Shield, ShieldOptions}; use shield_axum::{AuthRoutes, ShieldLayer, auth_required}; use shield_memory::{MemoryStorage, User}; -use shield_oidc::{Keycloak, OidcMethod}; +use shield_oidc::{Keycloak, OidcMethod, OidcOptions}; use time::Duration; use tokio::net::TcpListener; use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; use tracing::{info, level_filters::LevelFilter}; +use url::Url; use utoipa_axum::router::OpenApiRouter; use utoipa_scalar::{Scalar, Servable}; @@ -34,17 +35,26 @@ async fn main() { let shield = Shield::new( storage.clone(), vec![Arc::new( - OidcMethod::new(storage).with_providers([Keycloak::builder( - "keycloak", - "http://localhost:18080/realms/Shield", - "client1", - ) - .client_secret("xcpQsaGbRILTljPtX4npjmYMBjKrariJ") - .redirect_url(format!( - "http://localhost:{}/api/auth/oidc/sign-in-callback/keycloak", - addr.port() - )) - .build()]), + OidcMethod::new(storage) + .with_providers([Keycloak::builder( + "keycloak", + "http://localhost:18080/realms/Shield", + "client1", + ) + .client_secret("xcpQsaGbRILTljPtX4npjmYMBjKrariJ") + .redirect_url(format!( + "http://localhost:{}/api/auth/oidc/sign-in-callback/keycloak", + addr.port() + )) + .build()]) + .with_options( + OidcOptions::builder() + .redirect_origins([ + Url::parse(&format!("http://localhost:{}", addr.port())).unwrap(), + Url::parse("http://localhost:5173").unwrap(), + ]) + .build(), + ), )], ShieldOptions::default(), ); diff --git a/packages/core/shield/src/actions/sign_in_callback.rs b/packages/core/shield/src/actions/sign_in_callback.rs index 76d743c..a7a0e9b 100644 --- a/packages/core/shield/src/actions/sign_in_callback.rs +++ b/packages/core/shield/src/actions/sign_in_callback.rs @@ -1,4 +1,4 @@ -use crate::{MethodSession, Provider, ShieldError}; +use crate::{error::ShieldError, provider::Provider, session::MethodSession}; const ACTION_ID: &str = "sign-in-callback"; const ACTION_NAME: &str = "Sign in callback"; diff --git a/packages/core/shield/src/actions/sign_out.rs b/packages/core/shield/src/actions/sign_out.rs index a2e46a9..ee5209d 100644 --- a/packages/core/shield/src/actions/sign_out.rs +++ b/packages/core/shield/src/actions/sign_out.rs @@ -1,4 +1,9 @@ -use crate::{Form, Input, InputType, InputTypeSubmit, MethodSession, Provider, ShieldError}; +use crate::{ + error::ShieldError, + form::{Form, Input, InputType, InputTypeSubmit, InputValue}, + provider::Provider, + session::MethodSession, +}; const ACTION_ID: &str = "sign-out"; const ACTION_NAME: &str = "Sign out"; @@ -37,7 +42,9 @@ impl SignOutAction { name: "submit".to_owned(), label: None, r#type: InputType::Submit(InputTypeSubmit {}), - value: Some(Self::name()), + value: Some(InputValue::String { + value: Self::name(), + }), }], }]) } diff --git a/packages/core/shield/src/form.rs b/packages/core/shield/src/form.rs index fa4dbd4..b60982d 100644 --- a/packages/core/shield/src/form.rs +++ b/packages/core/shield/src/form.rs @@ -14,7 +14,15 @@ pub struct Input { pub name: String, pub label: Option, pub r#type: InputType, - pub value: Option, + pub value: Option, +} + +#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(tag = "type", rename_all = "kebab-case")] +pub enum InputValue { + Origin, + String { value: String }, } #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)] diff --git a/packages/integrations/shield-react/src/client/@tanstack/react-query.gen.ts b/packages/integrations/shield-react/src/client/@tanstack/react-query.gen.ts index 7577dd3..653888b 100644 --- a/packages/integrations/shield-react/src/client/@tanstack/react-query.gen.ts +++ b/packages/integrations/shield-react/src/client/@tanstack/react-query.gen.ts @@ -1,9 +1,25 @@ // This file is auto-generated by @hey-api/openapi-ts -import { type UseMutationOptions, queryOptions } from '@tanstack/react-query'; +import { type DefaultError, type UseMutationOptions, queryOptions } from '@tanstack/react-query'; import { client } from '../client.gen.js'; -import { type Options, callAction, getActionForms, getCurrentUser } from '../sdk.gen.js'; -import type { CallActionData, CallActionError, GetActionFormsData, GetCurrentUserData } from '../types.gen.js'; +import { + type Options, + callAction, + getActionForms, + getCurrentUser, + signInCallbackOidc, + signInOidc, + signOutOidc, +} from '../sdk.gen.js'; +import type { + CallActionData, + CallActionError, + GetActionFormsData, + GetCurrentUserData, + SignInCallbackOidcData, + SignInOidcData, + SignOutOidcData, +} from '../types.gen.js'; export type QueryKey = [ Pick & { @@ -49,6 +65,7 @@ export const getActionFormsQueryKey = (options: Options) => /** * Get action forms + * * Get action forms. */ export const getActionFormsOptions = (options: Options) => { @@ -66,11 +83,77 @@ export const getActionFormsOptions = (options: Options) => { }); }; +export const signInCallbackOidcQueryKey = (options: Options) => + createQueryKey('signInCallbackOidc', options); + +/** + * Sign in callback for OpenID Connect + * + * Sign in callback for OpenID Connect. + */ +export const signInCallbackOidcOptions = (options: Options) => { + return queryOptions({ + queryFn: async ({ queryKey, signal }) => { + const { data } = await signInCallbackOidc({ + ...options, + ...queryKey[0], + signal, + throwOnError: true, + }); + return data; + }, + queryKey: signInCallbackOidcQueryKey(options), + }); +}; + +/** + * Sign in with OpenID Connect + * + * Sign in with OpenID Connect. + */ +export const signInOidcMutation = ( + options?: Partial>, +): UseMutationOptions> => { + const mutationOptions: UseMutationOptions> = { + mutationFn: async (fnOptions) => { + const { data } = await signInOidc({ + ...options, + ...fnOptions, + throwOnError: true, + }); + return data; + }, + }; + return mutationOptions; +}; + +/** + * Sign out with OpenID Connect + * + * Sign out with OpenID Connect. + */ +export const signOutOidcMutation = ( + options?: Partial>, +): UseMutationOptions> => { + const mutationOptions: UseMutationOptions> = { + mutationFn: async (fnOptions) => { + const { data } = await signOutOidc({ + ...options, + ...fnOptions, + throwOnError: true, + }); + return data; + }, + }; + return mutationOptions; +}; + export const getCurrentUserQueryKey = (options?: Options) => createQueryKey('getCurrentUser', options); /** * Get current user + * * Get the current user account. */ export const getCurrentUserOptions = (options?: Options) => { @@ -90,6 +173,7 @@ export const getCurrentUserOptions = (options?: Options) => /** * Call action + * * Call an action. */ export const callActionMutation = ( diff --git a/packages/integrations/shield-react/src/client/client/index.ts b/packages/integrations/shield-react/src/client/client/index.ts index 9fea7b5..35eb1f9 100644 --- a/packages/integrations/shield-react/src/client/client/index.ts +++ b/packages/integrations/shield-react/src/client/client/index.ts @@ -16,7 +16,6 @@ export type { Config, CreateClientConfig, Options, - OptionsLegacyParser, RequestOptions, RequestResult, ResolvedRequestOptions, diff --git a/packages/integrations/shield-react/src/client/client/types.gen.ts b/packages/integrations/shield-react/src/client/client/types.gen.ts index d58b116..756dbe3 100644 --- a/packages/integrations/shield-react/src/client/client/types.gen.ts +++ b/packages/integrations/shield-react/src/client/client/types.gen.ts @@ -163,7 +163,7 @@ type BuildUrlFn = < url: string; }, >( - options: Pick & Options, + options: TData & Options, ) => string; export type Client = CoreClient & { @@ -198,20 +198,4 @@ export type Options< TResponse = unknown, TResponseStyle extends ResponseStyle = 'fields', > = OmitKeys, 'body' | 'path' | 'query' | 'url'> & - Omit; - -export type OptionsLegacyParser< - TData = unknown, - ThrowOnError extends boolean = boolean, - TResponseStyle extends ResponseStyle = 'fields', -> = TData extends { body?: any } - ? TData extends { headers?: any } - ? OmitKeys, 'body' | 'headers' | 'url'> & TData - : OmitKeys, 'body' | 'url'> & - TData & - Pick, 'headers'> - : TData extends { headers?: any } - ? OmitKeys, 'headers' | 'url'> & - TData & - Pick, 'body'> - : OmitKeys, 'url'> & TData; + ([TData] extends [never] ? unknown : Omit); diff --git a/packages/integrations/shield-react/src/client/client/utils.gen.ts b/packages/integrations/shield-react/src/client/client/utils.gen.ts index 21f7a82..e722a6f 100644 --- a/packages/integrations/shield-react/src/client/client/utils.gen.ts +++ b/packages/integrations/shield-react/src/client/client/utils.gen.ts @@ -6,7 +6,7 @@ import { serializeArrayParam, serializeObjectParam, serializePrimitiveParam } fr import { getUrl } from '../core/utils.gen.js'; import type { Client, ClientOptions, Config, RequestOptions } from './types.gen.js'; -export const createQuerySerializer = ({ allowReserved, array, object }: QuerySerializerOptions = {}) => { +export const createQuerySerializer = ({ parameters = {}, ...args }: QuerySerializerOptions = {}) => { const querySerializer = (queryParams: T) => { const search: string[] = []; if (queryParams && typeof queryParams === 'object') { @@ -17,29 +17,31 @@ export const createQuerySerializer = ({ allowReserved, array, objec continue; } + const options = parameters[name] || args; + if (Array.isArray(value)) { const serializedArray = serializeArrayParam({ - allowReserved, + allowReserved: options.allowReserved, explode: true, name, style: 'form', value, - ...array, + ...options.array, }); if (serializedArray) search.push(serializedArray); } else if (typeof value === 'object') { const serializedObject = serializeObjectParam({ - allowReserved, + allowReserved: options.allowReserved, explode: true, name, style: 'deepObject', value: value as Record, - ...object, + ...options.object, }); if (serializedObject) search.push(serializedObject); } else { const serializedPrimitive = serializePrimitiveParam({ - allowReserved, + allowReserved: options.allowReserved, name, value: value as string, }); diff --git a/packages/integrations/shield-react/src/client/core/bodySerializer.gen.ts b/packages/integrations/shield-react/src/client/core/bodySerializer.gen.ts index fe51a1a..f81441e 100644 --- a/packages/integrations/shield-react/src/client/core/bodySerializer.gen.ts +++ b/packages/integrations/shield-react/src/client/core/bodySerializer.gen.ts @@ -5,11 +5,19 @@ export type QuerySerializer = (query: Record) => string; export type BodySerializer = (body: any) => any; -export interface QuerySerializerOptions { +type QuerySerializerOptionsObject = { allowReserved?: boolean; - array?: SerializerOptions; - object?: SerializerOptions; -} + array?: Partial>; + object?: Partial>; +}; + +export type QuerySerializerOptions = QuerySerializerOptionsObject & { + /** + * Per-parameter serialization overrides. When provided, these settings + * override the global array/object settings for specific parameter names. + */ + parameters?: Record; +}; const serializeFormDataPair = (data: FormData, key: string, value: unknown): void => { if (typeof value === 'string' || value instanceof Blob) { diff --git a/packages/integrations/shield-react/src/client/core/params.gen.ts b/packages/integrations/shield-react/src/client/core/params.gen.ts index 95b2705..78cafbb 100644 --- a/packages/integrations/shield-react/src/client/core/params.gen.ts +++ b/packages/integrations/shield-react/src/client/core/params.gen.ts @@ -22,6 +22,17 @@ export type Field = */ key?: string; map?: string; + } + | { + /** + * Field name. This is the name we want the user to see and use. + */ + key: string; + /** + * Field mapped name. This is the name we want to use in the request. + * If `in` is omitted, `map` aliases `key` to the transport layer. + */ + map: Slot; }; export interface Fields { @@ -41,10 +52,14 @@ const extraPrefixes = Object.entries(extraPrefixesMap); type KeyMap = Map< string, - { - in: Slot; - map?: string; - } + | { + in: Slot; + map?: string; + } + | { + in?: never; + map: Slot; + } >; const buildKeyMap = (fields: FieldsConfig, map?: KeyMap): KeyMap => { @@ -60,6 +75,10 @@ const buildKeyMap = (fields: FieldsConfig, map?: KeyMap): KeyMap => { map: config.map, }); } + } else if ('key' in config) { + map.set(config.key, { + map: config.map, + }); } else if (config.args) { buildKeyMap(config.args, map); } @@ -108,7 +127,9 @@ export const buildClientParams = (args: ReadonlyArray, fields: FieldsCo if (config.key) { const field = map.get(config.key)!; const name = field.map || config.key; - (params[field.in] as Record)[name] = arg; + if (field.in) { + (params[field.in] as Record)[name] = arg; + } } else { params.body = arg; } @@ -117,16 +138,20 @@ export const buildClientParams = (args: ReadonlyArray, fields: FieldsCo const field = map.get(key); if (field) { - const name = field.map || key; - (params[field.in] as Record)[name] = value; + if (field.in) { + const name = field.map || key; + (params[field.in] as Record)[name] = value; + } else { + params[field.map] = value; + } } else { const extra = extraPrefixes.find(([prefix]) => key.startsWith(prefix)); if (extra) { const [prefix, slot] = extra; (params[slot] as Record)[key.slice(prefix.length)] = value; - } else { - for (const [slot, allowed] of Object.entries(config.allowExtra ?? {})) { + } else if ('allowExtra' in config && config.allowExtra) { + for (const [slot, allowed] of Object.entries(config.allowExtra)) { if (allowed) { (params[slot as Slot] as Record)[key] = value; break; diff --git a/packages/integrations/shield-react/src/client/sdk.gen.ts b/packages/integrations/shield-react/src/client/sdk.gen.ts index 6b75304..445446b 100644 --- a/packages/integrations/shield-react/src/client/sdk.gen.ts +++ b/packages/integrations/shield-react/src/client/sdk.gen.ts @@ -15,6 +15,12 @@ import type { GetCurrentUserData, GetCurrentUserErrors, GetCurrentUserResponses, + SignInCallbackOidcData, + SignInCallbackOidcErrors, + SignInOidcData, + SignInOidcErrors, + SignOutOidcData, + SignOutOidcErrors, } from './types.gen.js'; export type Options = Options2< @@ -36,6 +42,7 @@ export type Options( @@ -47,8 +54,47 @@ export const getActionForms = ( }); }; +/** + * Sign in callback for OpenID Connect + * + * Sign in callback for OpenID Connect. + */ +export const signInCallbackOidc = ( + options: Options, +) => { + return (options.client ?? client).get({ + url: '/api/auth/oidc/sign-in-callback/{providerId}', + ...options, + }); +}; + +/** + * Sign in with OpenID Connect + * + * Sign in with OpenID Connect. + */ +export const signInOidc = (options: Options) => { + return (options.client ?? client).post({ + url: '/api/auth/oidc/sign-in/{providerId}', + ...options, + }); +}; + +/** + * Sign out with OpenID Connect + * + * Sign out with OpenID Connect. + */ +export const signOutOidc = (options: Options) => { + return (options.client ?? client).post({ + url: '/api/auth/oidc/sign-out/{providerId}', + ...options, + }); +}; + /** * Get current user + * * Get the current user account. */ export const getCurrentUser = ( @@ -62,6 +108,7 @@ export const getCurrentUser = ( /** * Call action + * * Call an action. */ export const callAction = (options: Options) => { diff --git a/packages/integrations/shield-react/src/client/types.gen.ts b/packages/integrations/shield-react/src/client/types.gen.ts index 0c49c3f..9091642 100644 --- a/packages/integrations/shield-react/src/client/types.gen.ts +++ b/packages/integrations/shield-react/src/client/types.gen.ts @@ -41,7 +41,7 @@ export type Input = { label?: string | null; name: string; type: InputType; - value?: string | null; + value?: null | InputValue; }; export type InputType = @@ -300,6 +300,15 @@ export type InputTypeWeek = { step?: string | null; }; +export type InputValue = + | { + type: 'origin'; + } + | { + type: 'string'; + value: string; + }; + export type User = { additional: unknown; emailAddresses: Array; @@ -337,6 +346,87 @@ export type GetActionFormsResponses = { export type GetActionFormsResponse = GetActionFormsResponses[keyof GetActionFormsResponses]; +export type SignInCallbackOidcData = { + body?: never; + path: { + /** + * ID of the method. + */ + methodId: string; + /** + * ID of the action. + */ + actionId: string; + /** + * ID of provider (optional). + */ + providerId: string | null; + }; + query?: never; + url: '/api/auth/oidc/sign-in-callback/{providerId}'; +}; + +export type SignInCallbackOidcErrors = { + /** + * Internal server error. + */ + 500: unknown; +}; + +export type SignInOidcData = { + body?: never; + path: { + /** + * ID of the method. + */ + methodId: string; + /** + * ID of the action. + */ + actionId: string; + /** + * ID of provider (optional). + */ + providerId: string | null; + }; + query?: never; + url: '/api/auth/oidc/sign-in/{providerId}'; +}; + +export type SignInOidcErrors = { + /** + * Internal server error. + */ + 500: unknown; +}; + +export type SignOutOidcData = { + body?: never; + path: { + /** + * ID of the method. + */ + methodId: string; + /** + * ID of the action. + */ + actionId: string; + /** + * ID of provider (optional). + */ + providerId: string | null; + }; + query?: never; + url: '/api/auth/oidc/sign-out/{providerId}'; +}; + +export type SignOutOidcErrors = { + /** + * Internal server error. + */ + 500: unknown; +}; + export type GetCurrentUserData = { body?: never; path?: never; diff --git a/packages/methods/shield-oauth/Cargo.toml b/packages/methods/shield-oauth/Cargo.toml index 3970739..d97906f 100644 --- a/packages/methods/shield-oauth/Cargo.toml +++ b/packages/methods/shield-oauth/Cargo.toml @@ -28,3 +28,4 @@ secrecy.workspace = true serde.workspace = true serde_json.workspace = true shield.workspace = true +url.workspace = true diff --git a/packages/methods/shield-oauth/src/actions/sign_in.rs b/packages/methods/shield-oauth/src/actions/sign_in.rs index 2e3719a..eb87254 100644 --- a/packages/methods/shield-oauth/src/actions/sign_in.rs +++ b/packages/methods/shield-oauth/src/actions/sign_in.rs @@ -1,17 +1,34 @@ use async_trait::async_trait; use oauth2::{CsrfToken, PkceCodeChallenge, Scope, url::form_urlencoded::parse}; +use serde::Deserialize; use shield::{ - Action, ActionMethod, ConfigurationError, Form, Input, InputType, InputTypeSubmit, + Action, ActionMethod, ConfigurationError, Form, Input, InputType, InputTypeSubmit, InputValue, MethodSession, Provider, Request, Response, ResponseType, SessionAction, ShieldError, SignInAction, erased_action, }; +use url::Url; use crate::{ + options::OauthOptions, provider::{OauthProvider, OauthProviderPkceCodeChallenge}, session::OauthSession, }; -pub struct OauthSignInAction; +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SignInData { + pub redirect_origin: Option, +} + +pub struct OauthSignInAction { + options: OauthOptions, +} + +impl OauthSignInAction { + pub fn new(options: OauthOptions) -> Self { + Self { options } + } +} #[async_trait] impl Action for OauthSignInAction { @@ -41,7 +58,9 @@ impl Action for OauthSignInAction { name: "submit".to_owned(), label: None, r#type: InputType::Submit(InputTypeSubmit::default()), - value: Some(format!("Sign in with {}", provider.name())), + value: Some(InputValue::String { + value: format!("Sign in with {}", provider.name()), + }), }], }]) } @@ -50,8 +69,21 @@ impl Action for OauthSignInAction { &self, provider: OauthProvider, _session: &MethodSession, - _request: Request, + request: Request, ) -> Result { + let data = serde_json::from_value::(request.form_data) + .map_err(|err| ShieldError::Validation(err.to_string()))?; + + let redirect_origin = if let Some(redirect_origins) = &self.options.redirect_origins + && let Some(redirect_origin) = data.redirect_origin + // TODO: Consider returning an error when redirect origin is not allowed. + && redirect_origins.contains(&redirect_origin) + { + Some(redirect_origin) + } else { + None + }; + let client = provider.oauth_client().await?; let mut authorization_request = client @@ -88,6 +120,7 @@ impl Action for OauthSignInAction { Ok(Response::new(ResponseType::Redirect(auth_url.to_string())) .session_action(SessionAction::Unauthenticate) .session_action(SessionAction::data(OauthSession { + redirect_origin, csrf: Some(csrf_token.secret().clone()), pkce_verifier: pkce_code_challenge .map(|(_, pkce_code_verifier)| pkce_code_verifier.secret().clone()), diff --git a/packages/methods/shield-oauth/src/actions/sign_in_callback.rs b/packages/methods/shield-oauth/src/actions/sign_in_callback.rs index 20a0617..0623d08 100644 --- a/packages/methods/shield-oauth/src/actions/sign_in_callback.rs +++ b/packages/methods/shield-oauth/src/actions/sign_in_callback.rs @@ -261,6 +261,7 @@ impl Action for OauthSignInCallb )) .session_action(SessionAction::authenticate(user)) .session_action(SessionAction::data(OauthSession { + redirect_origin: None, csrf: None, pkce_verifier: None, oauth_connection_id: Some(connection.id), diff --git a/packages/methods/shield-oauth/src/method.rs b/packages/methods/shield-oauth/src/method.rs index db1f044..329fda1 100644 --- a/packages/methods/shield-oauth/src/method.rs +++ b/packages/methods/shield-oauth/src/method.rs @@ -73,7 +73,7 @@ impl Method for OauthMethod { fn actions(&self) -> Vec>> { vec![ - Box::new(OauthSignInAction), + Box::new(OauthSignInAction::new(self.options.clone())), Box::new(OauthSignInCallbackAction::new( self.options.clone(), self.storage.clone(), diff --git a/packages/methods/shield-oauth/src/options.rs b/packages/methods/shield-oauth/src/options.rs index 5045afc..bf6b4b9 100644 --- a/packages/methods/shield-oauth/src/options.rs +++ b/packages/methods/shield-oauth/src/options.rs @@ -1,10 +1,14 @@ use bon::Builder; +use url::Url; #[derive(Builder, Clone, Debug)] #[builder(on(String, into), state_mod(vis = "pub(crate)"))] pub struct OauthOptions { #[builder(default = "/")] pub(crate) sign_in_redirect: String, + + #[builder(with = FromIterator::from_iter)] + pub(crate) redirect_origins: Option>, } impl Default for OauthOptions { diff --git a/packages/methods/shield-oauth/src/session.rs b/packages/methods/shield-oauth/src/session.rs index 94f0401..4a70be4 100644 --- a/packages/methods/shield-oauth/src/session.rs +++ b/packages/methods/shield-oauth/src/session.rs @@ -1,7 +1,9 @@ use serde::{Deserialize, Serialize}; +use url::Url; #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct OauthSession { + pub redirect_origin: Option, pub csrf: Option, pub pkce_verifier: Option, pub oauth_connection_id: Option, diff --git a/packages/methods/shield-oidc/Cargo.toml b/packages/methods/shield-oidc/Cargo.toml index fd1d729..58b1fb7 100644 --- a/packages/methods/shield-oidc/Cargo.toml +++ b/packages/methods/shield-oidc/Cargo.toml @@ -31,3 +31,4 @@ serde.workspace = true serde_json.workspace = true shield.workspace = true tracing.workspace = true +url.workspace = true diff --git a/packages/methods/shield-oidc/src/actions/sign_in.rs b/packages/methods/shield-oidc/src/actions/sign_in.rs index 9e3089d..25d9ccd 100644 --- a/packages/methods/shield-oidc/src/actions/sign_in.rs +++ b/packages/methods/shield-oidc/src/actions/sign_in.rs @@ -3,17 +3,35 @@ use openidconnect::{ CsrfToken, Nonce, PkceCodeChallenge, Scope, core::CoreAuthenticationFlow, url::form_urlencoded::parse, }; +use serde::Deserialize; use shield::{ - Action, ActionMethod, Form, Input, InputType, InputTypeSubmit, MethodSession, Provider, - Request, Response, ResponseType, SessionAction, ShieldError, SignInAction, erased_action, + Action, ActionMethod, Form, Input, InputType, InputTypeHidden, InputTypeSubmit, InputValue, + MethodSession, Provider, Request, Response, ResponseType, SessionAction, ShieldError, + SignInAction, erased_action, }; +use url::Url; use crate::{ + options::OidcOptions, provider::{OidcProvider, OidcProviderPkceCodeChallenge}, session::OidcSession, }; -pub struct OidcSignInAction; +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SignInData { + pub redirect_origin: Option, +} + +pub struct OidcSignInAction { + options: OidcOptions, +} + +impl OidcSignInAction { + pub fn new(options: OidcOptions) -> Self { + Self { options } + } +} #[async_trait] impl Action for OidcSignInAction { @@ -39,12 +57,22 @@ impl Action for OidcSignInAction { async fn forms(&self, provider: OidcProvider) -> Result, ShieldError> { Ok(vec![Form { - inputs: vec![Input { - name: "submit".to_owned(), - label: None, - r#type: InputType::Submit(InputTypeSubmit::default()), - value: Some(format!("Sign in with {}", provider.name())), - }], + inputs: vec![ + Input { + name: "redirectOrigin".to_owned(), + label: None, + r#type: InputType::Hidden(InputTypeHidden::default()), + value: Some(InputValue::Origin), + }, + Input { + name: "submit".to_owned(), + label: None, + r#type: InputType::Submit(InputTypeSubmit::default()), + value: Some(InputValue::String { + value: format!("Sign in with {}", provider.name()), + }), + }, + ], }]) } @@ -52,8 +80,21 @@ impl Action for OidcSignInAction { &self, provider: OidcProvider, _session: &MethodSession, - _request: Request, + request: Request, ) -> Result { + let data = serde_json::from_value::(request.form_data) + .map_err(|err| ShieldError::Validation(err.to_string()))?; + + let redirect_origin = if let Some(redirect_origins) = &self.options.redirect_origins + && let Some(redirect_origin) = data.redirect_origin + // TODO: Consider returning an error when redirect origin is not allowed. + && redirect_origins.contains(&redirect_origin) + { + Some(redirect_origin) + } else { + None + }; + let client = provider.oidc_client().await?; let mut authorization_request = client.authorize_url( @@ -92,6 +133,7 @@ impl Action for OidcSignInAction { Ok(Response::new(ResponseType::Redirect(auth_url.to_string())) .session_action(SessionAction::unauthenticate()) .session_action(SessionAction::data(OidcSession { + redirect_origin, csrf: Some(csrf_token.secret().clone()), nonce: Some(nonce.secret().clone()), pkce_verifier: pkce_code_challenge diff --git a/packages/methods/shield-oidc/src/actions/sign_in_callback.rs b/packages/methods/shield-oidc/src/actions/sign_in_callback.rs index ae70c28..2fc98d9 100644 --- a/packages/methods/shield-oidc/src/actions/sign_in_callback.rs +++ b/packages/methods/shield-oidc/src/actions/sign_in_callback.rs @@ -291,10 +291,22 @@ impl Action for OidcSignInCallback }; Ok(Response::new(ResponseType::Redirect( - self.options.sign_in_redirect.clone(), + session + .method + .redirect_origin + .as_ref() + .and_then(|redirect_origin| { + redirect_origin + .join(&self.options.sign_in_redirect) + .as_ref() + .map(ToString::to_string) + .ok() + }) + .unwrap_or_else(|| self.options.sign_in_redirect.clone()), )) .session_action(SessionAction::authenticate(user)) .session_action(SessionAction::data(OidcSession { + redirect_origin: None, csrf: None, nonce: None, pkce_verifier: None, diff --git a/packages/methods/shield-oidc/src/method.rs b/packages/methods/shield-oidc/src/method.rs index 5e1fc80..61bbd7b 100644 --- a/packages/methods/shield-oidc/src/method.rs +++ b/packages/methods/shield-oidc/src/method.rs @@ -73,7 +73,7 @@ impl Method for OidcMethod { fn actions(&self) -> Vec>> { vec![ - Box::new(OidcSignInAction), + Box::new(OidcSignInAction::new(self.options.clone())), Box::new(OidcSignInCallbackAction::new( self.options.clone(), self.storage.clone(), diff --git a/packages/methods/shield-oidc/src/options.rs b/packages/methods/shield-oidc/src/options.rs index a97ae45..0e2ce27 100644 --- a/packages/methods/shield-oidc/src/options.rs +++ b/packages/methods/shield-oidc/src/options.rs @@ -1,10 +1,14 @@ use bon::Builder; +use url::Url; #[derive(Builder, Clone, Debug)] #[builder(on(String, into), state_mod(vis = "pub(crate)"))] pub struct OidcOptions { #[builder(default = "/")] pub(crate) sign_in_redirect: String, + + #[builder(with = FromIterator::from_iter)] + pub(crate) redirect_origins: Option>, } impl Default for OidcOptions { diff --git a/packages/methods/shield-oidc/src/session.rs b/packages/methods/shield-oidc/src/session.rs index 743de46..5f57509 100644 --- a/packages/methods/shield-oidc/src/session.rs +++ b/packages/methods/shield-oidc/src/session.rs @@ -1,7 +1,9 @@ use serde::{Deserialize, Serialize}; +use url::Url; #[derive(Clone, Debug, Default, Deserialize, Serialize)] pub struct OidcSession { + pub redirect_origin: Option, pub csrf: Option, pub nonce: Option, pub pkce_verifier: Option, diff --git a/packages/methods/shield-workos/src/actions/index.rs b/packages/methods/shield-workos/src/actions/index.rs index fed5440..0f944fb 100644 --- a/packages/methods/shield-workos/src/actions/index.rs +++ b/packages/methods/shield-workos/src/actions/index.rs @@ -4,8 +4,8 @@ use async_trait::async_trait; use serde::Deserialize; use shield::{ Action, ActionMethod, Form, Input, InputType, InputTypeEmail, InputTypeHidden, InputTypeSubmit, - MethodSession, Request, Response, ResponseType, ShieldError, SignInAction, SignUpAction, - erased_action, + InputValue, MethodSession, Request, Response, ResponseType, ShieldError, SignInAction, + SignUpAction, erased_action, }; use workos::{ PaginationParams, @@ -93,7 +93,9 @@ impl Action for WorkosIndexAction { name: "submit".to_owned(), label: None, r#type: InputType::Submit(InputTypeSubmit::default()), - value: Some("Continue".to_owned()), + value: Some(InputValue::String { + value: "Continue".to_owned(), + }), }, ], }] @@ -111,14 +113,16 @@ impl Action for WorkosIndexAction { required: Some(true), ..Default::default() }), - value: Some(oauth_provider.to_string()), + value: Some(InputValue::String { + value: oauth_provider.to_string(), + }), }, Input { name: "submit".to_owned(), label: None, r#type: InputType::Submit(InputTypeSubmit::default()), - value: Some( - format!( + value: Some(InputValue::String { + value: format!( "Continue with {}", match oauth_provider { OauthProvider::AppleOAuth => "Apple", @@ -128,7 +132,7 @@ impl Action for WorkosIndexAction { } ) .to_owned(), - ), + }), }, ], }), @@ -142,13 +146,17 @@ impl Action for WorkosIndexAction { required: Some(true), ..Default::default() }), - value: Some(connection.id.to_string()), + value: Some(InputValue::String { + value: connection.id.to_string(), + }), }, Input { name: "submit".to_owned(), label: None, r#type: InputType::Submit(InputTypeSubmit::default()), - value: Some(format!("Continue with {}", connection.name).to_owned()), + value: Some(InputValue::String { + value: format!("Continue with {}", connection.name).to_owned(), + }), }, ], })) diff --git a/packages/methods/shield-workos/src/actions/sign_in.rs b/packages/methods/shield-workos/src/actions/sign_in.rs index 6f4a21c..6bbede8 100644 --- a/packages/methods/shield-workos/src/actions/sign_in.rs +++ b/packages/methods/shield-workos/src/actions/sign_in.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use async_trait::async_trait; use shield::{ Action, ActionMethod, Form, Input, InputType, InputTypeEmail, InputTypeHidden, - InputTypePassword, InputTypeSubmit, MethodSession, Request, Response, ResponseType, + InputTypePassword, InputTypeSubmit, InputValue, MethodSession, Request, Response, ResponseType, ShieldError, SignInAction, erased_action, }; @@ -76,7 +76,9 @@ impl Action for WorkosSignInAction { name: "submit".to_owned(), label: None, r#type: InputType::Submit(InputTypeSubmit::default()), - value: Some("Sign in".to_owned()), + value: Some(InputValue::String { + value: "Sign in".to_owned(), + }), }, ], }, @@ -95,7 +97,9 @@ impl Action for WorkosSignInAction { name: "submit".to_owned(), label: None, r#type: InputType::Submit(InputTypeSubmit::default()), - value: Some("Email sign-in code".to_owned()), + value: Some(InputValue::String { + value: "Email sign-in code".to_owned(), + }), }, ], }, diff --git a/packages/methods/shield-workos/src/actions/sign_up.rs b/packages/methods/shield-workos/src/actions/sign_up.rs index b91f9b7..86db882 100644 --- a/packages/methods/shield-workos/src/actions/sign_up.rs +++ b/packages/methods/shield-workos/src/actions/sign_up.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use async_trait::async_trait; use shield::{ Action, ActionMethod, Form, Input, InputType, InputTypeEmail, InputTypeHidden, - InputTypePassword, InputTypeSubmit, MethodSession, Request, Response, ResponseType, + InputTypePassword, InputTypeSubmit, InputValue, MethodSession, Request, Response, ResponseType, ShieldError, SignUpAction, erased_action, }; @@ -76,7 +76,9 @@ impl Action for WorkosSignUpAction { name: "submit".to_owned(), label: None, r#type: InputType::Submit(InputTypeSubmit::default()), - value: Some("Sign up".to_owned()), + value: Some(InputValue::String { + value: "Sign up".to_owned(), + }), }, ], }, @@ -95,7 +97,9 @@ impl Action for WorkosSignUpAction { name: "submit".to_owned(), label: None, r#type: InputType::Submit(InputTypeSubmit::default()), - value: Some("Email sign-up code".to_owned()), + value: Some(InputValue::String { + value: "Email sign-up code".to_owned(), + }), }, ], }, diff --git a/packages/styles/shield-bootstrap/src/dioxus/input.rs b/packages/styles/shield-bootstrap/src/dioxus/input.rs index 2579f4e..1be9dca 100644 --- a/packages/styles/shield-bootstrap/src/dioxus/input.rs +++ b/packages/styles/shield-bootstrap/src/dioxus/input.rs @@ -1,5 +1,5 @@ use dioxus::prelude::*; -use shield::Input; +use shield::{Input, InputValue}; #[derive(Clone, PartialEq, Props)] pub struct FormInputProps { @@ -26,7 +26,10 @@ pub fn FormInput(props: FormInputProps) -> Element { class: "form-control", name: props.input.name, type: props.input.r#type.as_str(), - value: props.input.value.clone(), + value: props.input.value.map(|value| match value { + InputValue::Origin => todo!(), + InputValue::String { value } => value.clone(), + }), placeholder: props.input.label, } } diff --git a/packages/styles/shield-bootstrap/src/leptos/input.rs b/packages/styles/shield-bootstrap/src/leptos/input.rs index 380b320..a6f016d 100644 --- a/packages/styles/shield-bootstrap/src/leptos/input.rs +++ b/packages/styles/shield-bootstrap/src/leptos/input.rs @@ -1,5 +1,5 @@ use leptos::prelude::*; -use shield::Input; +use shield::{Input, InputValue}; #[component] pub fn FormInput(input: Input) -> impl IntoView { @@ -28,7 +28,10 @@ fn Control(input: Input) -> impl IntoView { // TODO: Support nested data (`data[user[name]]` should instead be `data[user][name]`). name=format!("data[{}]", input.name) r#type=input.r#type.as_str() - value=input.value.clone() + value=input.value.map(|value| match value { + InputValue::Origin => todo!(), + InputValue::String { value } => value.clone(), + }) placeholder=input.label /> } diff --git a/packages/styles/shield-react-shadcn-ui/src/components/style/input.tsx b/packages/styles/shield-react-shadcn-ui/src/components/style/input.tsx index f59f996..40497c7 100644 --- a/packages/styles/shield-react-shadcn-ui/src/components/style/input.tsx +++ b/packages/styles/shield-react-shadcn-ui/src/components/style/input.tsx @@ -1,4 +1,4 @@ -import type { Input as ApiInput } from '@rustforweb/shield-react'; +import type { Input as ApiInput, InputValue as ApiInputValue } from '@rustforweb/shield-react'; import { useId, useMemo } from 'react'; // import { Controller } from 'react-hook-form'; @@ -7,6 +7,17 @@ import { Button } from '../ui/button.js'; import { Field, FieldLabel } from '../ui/field.js'; import { Input } from '../ui/input.js'; +const inputValue = (value: ApiInputValue) => { + switch (value.type) { + case 'origin': { + return window.location.origin; + } + case 'string': { + return value.value; + } + } +}; + export type StyleInputProps = { // control: Control; input: ApiInput; @@ -19,7 +30,7 @@ export const StyleInput = ({ input }: StyleInputProps) => { if (input.type.type === 'button' || input.type.type === 'reset' || input.type.type === 'submit') { return ( ); } @@ -52,7 +63,7 @@ export const StyleInput = ({ input }: StyleInputProps) => { name={input.name} type={input.type.type} placeholder={input.label ?? undefined} - value={input.value ?? undefined} + value={input.value ? inputValue(input.value) : undefined} /> );