From f18c840509b21bb6b8793d879c2446020c1937ed Mon Sep 17 00:00:00 2001 From: maefall Date: Tue, 19 Nov 2024 20:31:15 +0100 Subject: [PATCH 01/14] Nix devshell: Fix the order of env vars (#9) --- flake.nix | 2 -- 1 file changed, 2 deletions(-) diff --git a/flake.nix b/flake.nix index c28b49e..7dc324d 100644 --- a/flake.nix +++ b/flake.nix @@ -54,8 +54,6 @@ export DATABASE_URL="postgres://$PGUSER:$PGPASSWORD@$PGHOST:$PGPORT/$db_name" export ROCKET_DATABASES="{roops={url=\"$DATABASE_URL\"}}" - - {db_name={url="postgres://user:password@localhost:5432/db_name"}} function start_dev_db { pg_ctl -D "$PGDATA" -l "$PGDATA/logfile" start From c825ba024973a25eeb1693330c284d81e361b906 Mon Sep 17 00:00:00 2001 From: maefall Date: Tue, 26 Nov 2024 12:24:20 +0100 Subject: [PATCH 02/14] (db): sessions & connections, user_id for users (#10) * Start work on the wrappers for database interactions * (db): sessions & connections, user_id for users * (flake): delete_dev_db() * refactor(field-name): Rename `database_connection` to `conn`, commonly used and straightforward * refactor(schema): Rename fields * refactor(db-models): Rename files from `db_interaction_models` to `models` for conciseness * refactor(db-wrappers): Rename methods * refactor(db-migrations): Remove indexes * refactor(dependencies): Uninstall `serde` and `serde_json` - Unnecessary * refactor(dependencies): Disable default features for `chrono` and `diesel` * refactor(http-errors): Catch all HTTP errors * chore(security): Set the `SameSite` attribute of the `session` cookie to `Strict` as it shouldn't involve any cross-site requests * refactor(roblox-oauth-callback): Update todo comment in accordance to recent implementation changes * chore(db): Update schema, migrations, and add model builders * refactor(sessions-db): Add token expiration timestamp * refactor(discord-connections-db): Make field names more concise * chore(discord-connections-db): Add nonce fields for encryption * refactor(fmt): Run formatter * chore(cipher): Implement encryption and decryption methods * chore(discord-connections-db): Implement token encryption method * refactor(fmt): Run formatter * refactor(tests): Correct placement of unit tests * refactor(env-example): Remove `DATABASE_URL` * refactor(env-example): Correct `openssl` example * fix(exports): Correct visibility of db exports * fix(discord-oauth-callback): Generate session using correct function * refactor(cipher): Cleanup * refactor(cipher): Use in-place encryption to avoid cloning the value and add tests for DiscordConnectionBuilder * refactor: Clean up file structure * fix(roblox-oauth): Import correct method * chore(discord-connections-db): Insert callback data into db * refactor(fmt): Run formatter * refactor(env): Add `DATABASE_URL` back and add db variables --------- Co-authored-by: nick <59822256+Archasion@users.noreply.github.com> --- .example.env | 21 +- Cargo.lock | 151 +++++++++++++ Cargo.toml | 6 +- flake.nix | 10 +- .../down.sql | 1 - .../up.sql | 12 - .../down.sql | 1 + .../up.sql | 11 + .../down.sql | 2 + .../up.sql | 14 ++ .../down.sql | 1 + .../2024-11-21-101635_create_sessions/up.sql | 8 + src/cipher.rs | 71 ++++++ src/database/mod.rs | 120 +--------- src/database/models.rs | 19 -- src/database/schema.rs | 7 - src/database/wrappers/account_links/mod.rs | 193 ++++++++++++++++ src/database/wrappers/account_links/models.rs | 72 ++++++ src/database/wrappers/account_links/schema.rs | 7 + .../wrappers/discord_connections/mod.rs | 66 ++++++ .../wrappers/discord_connections/models.rs | 206 ++++++++++++++++++ .../wrappers/discord_connections/schema.rs | 11 + src/database/wrappers/mod.rs | 3 + src/database/wrappers/sessions/mod.rs | 47 ++++ src/database/wrappers/sessions/models.rs | 71 ++++++ src/database/wrappers/sessions/schema.rs | 7 + src/lib.rs | 18 +- src/oauth/mod.rs | 23 +- src/oauth/rocket_routes/discord.rs | 66 ------ src/oauth/routes/discord.rs | 190 ++++++++++++++++ src/oauth/{rocket_routes => routes}/mod.rs | 0 src/oauth/{rocket_routes => routes}/roblox.rs | 13 +- .../{oauth_types.rs => types/discord.rs} | 200 +++-------------- src/oauth/types/mod.rs | 8 + src/oauth/types/roblox.rs | 145 ++++++++++++ src/oauth/{ => utils}/discord.rs | 12 +- src/oauth/utils/mod.rs | 18 ++ src/oauth/{ => utils}/pixy.rs | 10 +- src/oauth/{ => utils}/roblox.rs | 12 +- src/{util.rs => utils.rs} | 0 40 files changed, 1393 insertions(+), 460 deletions(-) delete mode 100644 migrations/2024-11-04-215323_create_users_table/down.sql delete mode 100644 migrations/2024-11-04-215323_create_users_table/up.sql create mode 100644 migrations/2024-11-21-101431_create_discord_connections/down.sql create mode 100644 migrations/2024-11-21-101431_create_discord_connections/up.sql create mode 100644 migrations/2024-11-21-101617_create_account_links/down.sql create mode 100644 migrations/2024-11-21-101617_create_account_links/up.sql create mode 100644 migrations/2024-11-21-101635_create_sessions/down.sql create mode 100644 migrations/2024-11-21-101635_create_sessions/up.sql create mode 100644 src/cipher.rs delete mode 100644 src/database/models.rs delete mode 100644 src/database/schema.rs create mode 100644 src/database/wrappers/account_links/mod.rs create mode 100644 src/database/wrappers/account_links/models.rs create mode 100644 src/database/wrappers/account_links/schema.rs create mode 100644 src/database/wrappers/discord_connections/mod.rs create mode 100644 src/database/wrappers/discord_connections/models.rs create mode 100644 src/database/wrappers/discord_connections/schema.rs create mode 100644 src/database/wrappers/mod.rs create mode 100644 src/database/wrappers/sessions/mod.rs create mode 100644 src/database/wrappers/sessions/models.rs create mode 100644 src/database/wrappers/sessions/schema.rs delete mode 100644 src/oauth/rocket_routes/discord.rs create mode 100644 src/oauth/routes/discord.rs rename src/oauth/{rocket_routes => routes}/mod.rs (100%) rename src/oauth/{rocket_routes => routes}/roblox.rs (81%) rename src/oauth/{oauth_types.rs => types/discord.rs} (52%) create mode 100644 src/oauth/types/mod.rs create mode 100644 src/oauth/types/roblox.rs rename src/oauth/{ => utils}/discord.rs (72%) create mode 100644 src/oauth/utils/mod.rs rename src/oauth/{ => utils}/pixy.rs (82%) rename src/oauth/{ => utils}/roblox.rs (71%) rename src/{util.rs => utils.rs} (100%) diff --git a/.example.env b/.example.env index b4eed5d..6fd477c 100644 --- a/.example.env +++ b/.example.env @@ -1,8 +1,19 @@ -# Make sure the db name matches the one set for DbConn in src/lib.rs -DATABASE_URL='postgres://user:password@localhost:5432/db_name' +# Discord OAuth2 client secret, you can get this from the Discord Developer Portal +DISCORD_CLIENT_SECRET="your_discord_client_secret" + +# A 32 byte encryption key, you can generate one with `openssl rand -base64 24` +ENCRYPTION_KEY="your_encryption_key" +# Database +DB_DRIVER='postgres' # Make sure the db name matches the one set for DbConn in src/lib.rs -ROCKET_DATABASES='{db_name={url="postgres://user:password@localhost:5432/db_name"}}' +DB_NAME='db_name' +DB_USER='user' +DB_PASSWORD='password' +DB_HOST='localhost' +DB_PORT='5432' + +# Use the variables above to construct the connection url +DATABASE_URL="${DB_DRIVER}://${DB_USER}:${DB_PASSWORD}@${DB_HOST}:${DB_PORT}/${DB_NAME}" +ROCKET_DATABASES='{${DB_NAME}={url="${DATABASE_URL}"}}' -# Discord OAuth2 client secret, you can get this from the Discord Developer Portal -DISCORD_CLIENT_SECRET='your_discord_client_secret' diff --git a/Cargo.lock b/Cargo.lock index ab2a4d2..3997621 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -61,6 +61,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -166,6 +181,12 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + [[package]] name = "bytemuck" version = "1.19.0" @@ -199,6 +220,21 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-targets 0.52.6", +] + [[package]] name = "cipher" version = "0.4.4" @@ -226,6 +262,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + [[package]] name = "cpufeatures" version = "0.2.14" @@ -340,10 +382,12 @@ checksum = "158fe8e2e68695bd615d7e4f3227c0727b151330d3e253b525086c348d055d5e" dependencies = [ "bitflags", "byteorder", + "chrono", "diesel_derives", "itoa", "pq-sys", "r2d2", + "serde_json", ] [[package]] @@ -696,6 +740,29 @@ dependencies = [ "want", ] +[[package]] +name = "iana-time-zone" +version = "0.1.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "ident_case" version = "1.0.1" @@ -755,6 +822,15 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "js-sys" +version = "0.3.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -896,6 +972,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.16.0" @@ -1318,7 +1403,9 @@ dependencies = [ name = "roops-internal-api" version = "0.0.0" dependencies = [ + "aes-gcm", "base64-compat", + "chrono", "diesel", "dotenvy", "minreq", @@ -1917,6 +2004,61 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" +dependencies = [ + "cfg-if", + "once_cell", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" + [[package]] name = "webpki-roots" version = "0.25.4" @@ -1954,6 +2096,15 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index 430682f..6e46659 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,6 @@ version = "0.0.0" edition = "2021" [dependencies] -diesel = { version = "2.2.4", default-features = false, features = ["postgres"] } dotenvy = { version = "0.15.7", default-features = false } rocket = { version = "0.5.1", default-features = false, features = ["json", "secrets"] } rocket_cors = { version = "0.6.0", default-features = false } @@ -13,4 +12,7 @@ base64-compat = { version = "1.0.0", default-features = false } rand = { version = "0.8.0", default-features = false } sha2 = { version = "0.10.7", default-features = false } secrecy = { version = "0.10.3", default-features = false } -minreq = { version = "2.12.0", default-features = false, features = ["json-using-serde", "https"] } \ No newline at end of file +minreq = { version = "2.12.0", default-features = false, features = ["json-using-serde", "https"] } +chrono = { version = "0.4.38", default-feature = false, features = ["serde"] } +diesel = { version = "2.1.4", default-feature = false, features = ["postgres", "serde_json", "chrono"] } +aes-gcm = { version = "0.10.3", default-features = false } diff --git a/flake.nix b/flake.nix index 7dc324d..3eb903c 100644 --- a/flake.nix +++ b/flake.nix @@ -63,7 +63,15 @@ pg_ctl -D "$PGDATA" -l "$PGDATA/logfile" stop } + function delete_dev_db { + stop_dev_db + + rm -rf "$PGDATA" + } + function init_dev_db { + delete_dev_db + mkdir "$PGDATA" initdb -D "$PGDATA" @@ -76,7 +84,7 @@ stop_dev_db } - ''; + ''; buildInputs = runtimeDeps; nativeBuildInputs = buildDeps ++ devDeps ++ [ rustc ]; }; diff --git a/migrations/2024-11-04-215323_create_users_table/down.sql b/migrations/2024-11-04-215323_create_users_table/down.sql deleted file mode 100644 index c99ddcd..0000000 --- a/migrations/2024-11-04-215323_create_users_table/down.sql +++ /dev/null @@ -1 +0,0 @@ -DROP TABLE IF EXISTS users; diff --git a/migrations/2024-11-04-215323_create_users_table/up.sql b/migrations/2024-11-04-215323_create_users_table/up.sql deleted file mode 100644 index fe826b1..0000000 --- a/migrations/2024-11-04-215323_create_users_table/up.sql +++ /dev/null @@ -1,12 +0,0 @@ --- create -CREATE TABLE IF NOT EXISTS users ( - discord_id BIGINT, - roblox_id BIGINT, - is_primary BOOLEAN NOT NULL, - - PRIMARY KEY (discord_id, roblox_id) -); - --- index -CREATE UNIQUE INDEX primary_account_link ON users (discord_id) -WHERE is_primary = TRUE; diff --git a/migrations/2024-11-21-101431_create_discord_connections/down.sql b/migrations/2024-11-21-101431_create_discord_connections/down.sql new file mode 100644 index 0000000..ce61721 --- /dev/null +++ b/migrations/2024-11-21-101431_create_discord_connections/down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS discord_connections; diff --git a/migrations/2024-11-21-101431_create_discord_connections/up.sql b/migrations/2024-11-21-101431_create_discord_connections/up.sql new file mode 100644 index 0000000..81b65e7 --- /dev/null +++ b/migrations/2024-11-21-101431_create_discord_connections/up.sql @@ -0,0 +1,11 @@ +-- create +CREATE TABLE IF NOT EXISTS discord_connections +( + uid TEXT PRIMARY KEY, + access_token TEXT NOT NULL, + access_token_nonce TEXT NOT NULL, + refresh_token TEXT NOT NULL, + refresh_token_nonce TEXT NOT NULL, + expires_at TIMESTAMP NOT NULL, + scope TEXT NOT NULL +); diff --git a/migrations/2024-11-21-101617_create_account_links/down.sql b/migrations/2024-11-21-101617_create_account_links/down.sql new file mode 100644 index 0000000..2f44ebc --- /dev/null +++ b/migrations/2024-11-21-101617_create_account_links/down.sql @@ -0,0 +1,2 @@ +DROP INDEX IF EXISTS primary_account_link; +DROP TABLE IF EXISTS account_links; diff --git a/migrations/2024-11-21-101617_create_account_links/up.sql b/migrations/2024-11-21-101617_create_account_links/up.sql new file mode 100644 index 0000000..88f60c3 --- /dev/null +++ b/migrations/2024-11-21-101617_create_account_links/up.sql @@ -0,0 +1,14 @@ +-- create +CREATE TABLE IF NOT EXISTS account_links +( + discord_uid TEXT REFERENCES discord_connections (uid) + ON DELETE CASCADE, + roblox_uid BIGINT, + is_primary BOOLEAN NOT NULL, + + PRIMARY KEY (discord_uid, roblox_uid) +); + +-- index +CREATE UNIQUE INDEX IF NOT EXISTS primary_account_link ON account_links (discord_uid) + WHERE is_primary = TRUE; diff --git a/migrations/2024-11-21-101635_create_sessions/down.sql b/migrations/2024-11-21-101635_create_sessions/down.sql new file mode 100644 index 0000000..63d205d --- /dev/null +++ b/migrations/2024-11-21-101635_create_sessions/down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS sessions; diff --git a/migrations/2024-11-21-101635_create_sessions/up.sql b/migrations/2024-11-21-101635_create_sessions/up.sql new file mode 100644 index 0000000..f4da9ba --- /dev/null +++ b/migrations/2024-11-21-101635_create_sessions/up.sql @@ -0,0 +1,8 @@ +-- create +CREATE TABLE IF NOT EXISTS sessions +( + session_token TEXT PRIMARY KEY, + discord_uid TEXT NOT NULL REFERENCES discord_connections (uid) + ON DELETE CASCADE, + expires_at TIMESTAMP NOT NULL +); diff --git a/src/cipher.rs b/src/cipher.rs new file mode 100644 index 0000000..8aa9b8a --- /dev/null +++ b/src/cipher.rs @@ -0,0 +1,71 @@ +use aes_gcm::aead::{Aead, Nonce, OsRng}; +use aes_gcm::{AeadCore, Aes256Gcm, Key, KeyInit}; + +pub(crate) struct EncryptedData { + /// Encrypted data represented as a hex string + pub(crate) data: String, + /// Nonce used to encrypt the data represented as a hex string + pub(crate) nonce: String, +} + +fn get_encryption_key() -> Result, String> { + // Get the encryption key from the environment + let key = std::env::var("ENCRYPTION_KEY").expect("ENCRYPTION_KEY must be set"); + assert_eq!(key.len(), 32, "Encryption key must be 32 bytes long"); + + Ok(*Key::::from_slice(key.as_bytes())) +} + +/// Encrypt a string using AES-256-GCM. +pub(crate) fn encrypt(data: &[u8]) -> Result { + let key = get_encryption_key()?; + let cipher = Aes256Gcm::new(&key); + let nonce = Aes256Gcm::generate_nonce(&mut OsRng); + + let encrypted_bytes = cipher + .encrypt(&nonce, data) + .map_err(|e| format!("Failed to encrypt data: {}", e))?; + + Ok(EncryptedData { + data: base64::encode(&encrypted_bytes), + nonce: base64::encode(nonce.as_slice()), + }) +} + +/// Decrypt an `EncryptedData` struct using AES-256-GCM. +pub(crate) fn decrypt(data: &EncryptedData) -> Result { + let key = get_encryption_key()?; + let cipher = Aes256Gcm::new(&key); + + // Convert the hex string to a vector of bytes + let encrypted_bytes = base64::decode(&data.data).expect("Failed to decode base64 data"); + let nonce = base64::decode(&data.nonce).expect("Failed to decode base64 nonce"); + let nonce = Nonce::::from_slice(nonce.as_slice()); + + let decrypted_bytes = cipher + .decrypt(nonce, encrypted_bytes.as_slice()) + .map_err(|e| format!("Failed to decrypt data: {}", e))?; + let decrypted_data = std::str::from_utf8(&decrypted_bytes) + .map_err(|e| format!("Failed to convert decrypted data to UTF-8: {}", e))?; + + Ok(decrypted_data.to_string()) +} + +#[cfg(test)] +mod tests { + use super::{decrypt, encrypt}; + + #[test] + fn test_encryption() { + // Set the encryption key in the environment + std::env::set_var("ENCRYPTION_KEY", "0123456789abcdef0123456789abcdef"); + + const DATA: &str = "Hello, world!"; + + let encrypted_data = encrypt(DATA.as_bytes()).unwrap(); + assert_ne!(encrypted_data.data, DATA); + + let decrypted_data = decrypt(&encrypted_data).unwrap(); + assert_eq!(decrypted_data, DATA); + } +} diff --git a/src/database/mod.rs b/src/database/mod.rs index 37ec4e9..17a932b 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,119 +1 @@ -pub mod models; -pub mod schema; - -use self::models::{NewUser, User}; -use diesel::prelude::*; - -/// Creates a new user in the database. -/// -/// # Arguments -/// -/// * `conn` - The connection to the database. -/// * `roblox_id` - The Roblox ID of the user. -/// -/// # Returns -/// -/// The newly created user if successful, `None` otherwise. -pub fn create_user( - conn: &mut PgConnection, - roblox_id: i64, - discord_id: i64, - is_primary: bool, -) -> Option { - use self::schema::users; - - let new_user = NewUser { - roblox_id, - discord_id, - is_primary, - }; - - diesel::insert_into(users::table) - .values(&new_user) - .returning(User::as_returning()) - .get_result(conn) - .ok() -} - -/// Deletes a user from the database. -/// -/// # Arguments -/// -/// * `conn` - The connection to the database. -/// * `target_id` - The ID of the user to delete. -/// -/// # Returns -/// -/// `true` if the user was successfully deleted, `false` otherwise. -pub fn delete_users(conn: &mut PgConnection, target_id: UserId) -> bool { - use schema::users::dsl::*; - - match target_id { - UserId::DiscordMarker(id) => diesel::delete(users.filter(discord_id.eq(id))) - .execute(conn) - .is_ok(), - UserId::RobloxMarker(id) => diesel::delete(users.filter(roblox_id.eq(id))) - .execute(conn) - .is_ok(), - } -} - -/// Retrieves a user from the database. -/// -/// # Arguments -/// -/// * `conn` - The connection to the database. -/// * `target_id` - The ID of the user to retrieve. -/// -/// # Returns -/// -/// The user if they exist, `None` otherwise. -pub fn get_primary_user(conn: &mut PgConnection, target_id: UserId) -> Option { - use schema::users::dsl::*; - - match target_id { - UserId::DiscordMarker(id) => users - .filter(discord_id.eq(id)) - .filter(is_primary.eq(true)) - .first::(conn) - .ok(), - UserId::RobloxMarker(id) => users - .filter(roblox_id.eq(id)) - .filter(is_primary.eq(true)) - .first::(conn) - .ok(), - } -} - -/// Retrieves all users from the database. -/// -/// # Arguments -/// -/// * `conn` - The connection to the database. -/// * `target_id` - The ID of the user to retrieve. -/// -/// # Returns -/// -/// The corresponding Discord/Roblox accounts associated with the user ID. -pub fn get_users(conn: &mut PgConnection, target_id: UserId) -> Option> { - use schema::users::dsl::*; - - match target_id { - UserId::DiscordMarker(id) => users - .filter(discord_id.eq(id)) - .load::(conn) - .ok(), - UserId::RobloxMarker(id) => users - .filter(roblox_id.eq(id)) - .load::(conn) - .ok(), - } -} - -/// User ID marker. Used to differentiate between different types of IDs. -pub enum UserId { - /// Discord user ID marker. - DiscordMarker(i64), - /// Roblox user ID marker. - RobloxMarker(i64), -} +pub(crate) mod wrappers; diff --git a/src/database/models.rs b/src/database/models.rs deleted file mode 100644 index 43961ac..0000000 --- a/src/database/models.rs +++ /dev/null @@ -1,19 +0,0 @@ -use super::schema; -use diesel::prelude::*; - -#[derive(Queryable, Selectable, Debug)] -#[diesel(table_name = schema::users)] -#[diesel(check_for_backend(diesel::pg::Pg))] -pub struct User { - pub roblox_id: i64, - pub discord_id: i64, - pub is_primary: bool, -} - -#[derive(Insertable)] -#[diesel(table_name = schema::users)] -pub struct NewUser { - pub roblox_id: i64, - pub discord_id: i64, - pub is_primary: bool, -} diff --git a/src/database/schema.rs b/src/database/schema.rs deleted file mode 100644 index 447a302..0000000 --- a/src/database/schema.rs +++ /dev/null @@ -1,7 +0,0 @@ -diesel::table! { - users(discord_id, roblox_id){ - discord_id -> BigInt, - roblox_id -> BigInt, - is_primary -> Bool, - } -} diff --git a/src/database/wrappers/account_links/mod.rs b/src/database/wrappers/account_links/mod.rs new file mode 100644 index 0000000..7dc3a63 --- /dev/null +++ b/src/database/wrappers/account_links/mod.rs @@ -0,0 +1,193 @@ +#![allow(unused_variables)] + +pub(crate) mod models; +mod schema; + +use self::models::{AccountLink, NewAccountLink}; +use diesel::prelude::*; + +pub(crate) struct AccountLinksDb; + +impl AccountLinksDb { + /// Creates a new account link in the database. + /// + /// # Arguments + /// + /// * `roblox_uid` - The Roblox ID of the user. + /// * `discord_uid` - The Discord ID of the user. + /// * `is_primary` - Whether the account is the primary account. + /// + /// # Returns + /// + /// The newly created user if successful, `None` otherwise. + pub(crate) fn insert_one( + new_account_link: &NewAccountLink, + conn: &mut PgConnection, + ) -> Result, diesel::result::Error> { + use self::schema::account_links::dsl::*; + + // Already exists + let pk = ( + new_account_link.discord_uid.as_str(), + new_account_link.roblox_uid, + ); + + if Self::find_one(pk, conn).is_some() { + rocket::warn!("Account link already exists: {:?}", new_account_link); + return Ok(None); + } + + // Unset the primary flag for all other account links + // if the new account link is primary + if new_account_link.is_primary { + conn.transaction(|conn| { + diesel::update(account_links) + .filter(discord_uid.eq(&new_account_link.discord_uid)) + .set(is_primary.eq(false)) + .execute(conn)?; + + diesel::insert_into(account_links) + .values(new_account_link) + .returning(AccountLink::as_returning()) + .get_result(conn) + }) + .map(Some) + } else { + diesel::insert_into(account_links) + .values(new_account_link) + .returning(AccountLink::as_returning()) + .get_result(conn) + .map(Some) + } + } + + /// Deletes account links associated with `target_id` from the database. + /// + /// # Arguments + /// + /// * `target_id` - ID of the account link(s) to delete. + /// + /// # Returns + /// + /// `true` if the account link(s) were successfully deleted, `false` otherwise. + pub(crate) fn delete_many(target_id: UserId, conn: &mut PgConnection) -> bool { + use schema::account_links::dsl::*; + + match target_id { + UserId::DiscordMarker(t_id) => diesel::delete(account_links) + .filter(discord_uid.eq(t_id)) + .execute(conn) + .is_ok(), + UserId::RobloxMarker(t_id) => diesel::delete(account_links) + .filter(roblox_uid.eq(t_id)) + .execute(conn) + .is_ok(), + } + } + + /// Retrieves an account link from the database. + /// + /// # Arguments + /// + /// * `discord_uid` - The Discord ID of the account link. + /// + /// # Returns + /// + /// The account link if it exists, `None` otherwise. + pub(crate) fn find_primary( + discord_uid: String, + conn: &mut PgConnection, + ) -> Option { + use schema::account_links::dsl::*; + + account_links + .filter(discord_uid.eq(discord_uid)) + .filter(is_primary.eq(true)) + .first::(conn) + .ok() + } + + /// Retrieve an account link by the primary key. + /// + /// # Arguments + /// + /// * `pk` - The primary key of the account link to retrieve (Discord ID, Roblox ID). + /// + /// # Returns + /// + /// The account link if it exists, `None` otherwise. + pub(crate) fn find_one(pk: (&str, i64), conn: &mut PgConnection) -> Option { + use schema::account_links::dsl::*; + let (pk_discord_uid, pk_roblox_uid) = pk; + + account_links + .filter(discord_uid.eq(pk_discord_uid)) + .filter(roblox_uid.eq(pk_roblox_uid)) + .first::(conn) + .ok() + } + + /// Retrieves all account links from the database. + /// + /// # Arguments + /// + /// * `target_id` - ID associated with the account links to retrieve. + /// + /// # Returns + /// + /// The corresponding account links associated with the user ID. + pub(crate) fn find_many( + target_id: UserId, + conn: &mut PgConnection, + ) -> Option> { + use schema::account_links::dsl::*; + + match target_id { + UserId::DiscordMarker(t_id) => account_links + .filter(discord_uid.eq(t_id)) + .load::(conn) + .ok(), + UserId::RobloxMarker(t_id) => account_links + .filter(roblox_uid.eq(t_id)) + .load::(conn) + .ok(), + } + } + + /// Updates the primary account link for a user. + /// + /// # Arguments + /// + /// * `pk` - The primary key of the account link to update (Discord ID, Roblox ID). + /// + /// # Returns + /// + /// `true` if the primary account link was successfully updated, `false` otherwise. + pub(crate) fn set_primary(pk: (&str, i64), conn: &mut PgConnection) -> bool { + use schema::account_links::dsl::*; + let (pk_discord_uid, pk_roblox_uid) = pk; + + conn.transaction(|conn| { + // Set all other account links to not primary + diesel::update(account_links) + .filter(discord_uid.eq(pk_discord_uid)) + .set(is_primary.eq(false)) + .execute(conn)?; + + // Set the new primary account link + diesel::update(account_links) + .filter(discord_uid.eq(pk_discord_uid)) + .filter(roblox_uid.eq(pk_roblox_uid)) + .set(is_primary.eq(true)) + .execute(conn)?; + + diesel::result::QueryResult::Ok(()) + }) + .is_ok() + } +} + +pub(crate) enum UserId { + DiscordMarker(String), + RobloxMarker(i64), +} diff --git a/src/database/wrappers/account_links/models.rs b/src/database/wrappers/account_links/models.rs new file mode 100644 index 0000000..153a8bb --- /dev/null +++ b/src/database/wrappers/account_links/models.rs @@ -0,0 +1,72 @@ +use super::schema; +use diesel::prelude::*; + +#[derive(Queryable, Selectable, Debug)] +#[diesel(table_name = schema::account_links)] +#[diesel(check_for_backend(diesel::pg::Pg))] +pub(crate) struct AccountLink { + pub(crate) roblox_uid: i64, + pub(crate) discord_uid: String, + pub(crate) is_primary: bool, +} + +#[derive(Insertable, Debug)] +#[diesel(table_name = schema::account_links)] +pub(crate) struct NewAccountLink { + pub(crate) roblox_uid: i64, + pub(crate) discord_uid: String, + pub(crate) is_primary: bool, +} + +#[derive(Debug)] +pub(crate) struct NewAccountLinkBuilder { + roblox_uid: Option, + discord_uid: Option, + is_primary: Option, +} + +impl Default for NewAccountLinkBuilder { + fn default() -> Self { + Self { + roblox_uid: None, + discord_uid: None, + is_primary: None, + } + } +} + +impl NewAccountLink { + pub(crate) fn build() -> NewAccountLinkBuilder { + NewAccountLinkBuilder::default() + } +} + +impl NewAccountLinkBuilder { + pub(crate) fn new() -> Self { + Self::default() + } + + pub(crate) fn roblox_uid(mut self, roblox_uid: i64) -> Self { + self.roblox_uid = Some(roblox_uid); + self + } + + pub(crate) fn discord_uid(mut self, discord_uid: String) -> Self { + self.discord_uid = Some(discord_uid); + self + } + + #[allow(clippy::wrong_self_convention)] + pub(crate) fn is_primary(mut self, is_primary: bool) -> Self { + self.is_primary = Some(is_primary); + self + } + + pub(crate) fn build(self) -> NewAccountLink { + NewAccountLink { + roblox_uid: self.roblox_uid.expect("roblox_uid is required"), + discord_uid: self.discord_uid.expect("discord_uid is required"), + is_primary: self.is_primary.expect("is_primary is required"), + } + } +} diff --git a/src/database/wrappers/account_links/schema.rs b/src/database/wrappers/account_links/schema.rs new file mode 100644 index 0000000..c73ea2a --- /dev/null +++ b/src/database/wrappers/account_links/schema.rs @@ -0,0 +1,7 @@ +diesel::table! { + account_links(discord_uid, roblox_uid){ + roblox_uid -> BigInt, + discord_uid -> Text, + is_primary -> Bool, + } +} diff --git a/src/database/wrappers/discord_connections/mod.rs b/src/database/wrappers/discord_connections/mod.rs new file mode 100644 index 0000000..af19f7f --- /dev/null +++ b/src/database/wrappers/discord_connections/mod.rs @@ -0,0 +1,66 @@ +#![allow(unused_variables)] + +pub(crate) mod models; +mod schema; + +use self::models::{DiscordConnection, NewDiscordConnection}; +use crate::database::wrappers::discord_connections::models::UpdateDiscordConnection; +use diesel::prelude::*; + +pub(crate) struct DiscordConnectionsDb; + +impl DiscordConnectionsDb { + pub(crate) fn insert_one( + new_conn: &NewDiscordConnection, + conn: &mut PgConnection, + ) -> Result, diesel::result::Error> { + use self::schema::discord_connections; + + // Already exists + if Self::find_one(&new_conn.uid, conn).is_some() { + rocket::warn!("Discord connection already exists: {:?}", new_conn.uid); + return Ok(None); + } + + diesel::insert_into(discord_connections::table) + .values(new_conn) + .returning(DiscordConnection::as_returning()) + .get_result(conn) + .map(Some) + } + + pub(crate) fn find_one(uid: &str, conn: &mut PgConnection) -> Option { + use schema::discord_connections::dsl::*; + + discord_connections + .filter(uid.eq(uid)) + .first::(conn) + .ok() + } + + pub(crate) fn delete_one(uid: &str, conn: &mut PgConnection) -> bool { + use schema::discord_connections::dsl::*; + + diesel::delete(discord_connections) + .filter(uid.eq(uid)) + .execute(conn) + .map_err(|e| rocket::error!("Failed to delete Discord connection: {:?}", e)) + .is_ok() + } + + pub(crate) fn update_one( + uid: &str, + new_conn: &UpdateDiscordConnection, + conn: &mut PgConnection, + ) -> Option { + use schema::discord_connections::dsl::*; + + diesel::update(discord_connections) + .filter(uid.eq(uid)) + .set(new_conn) + .returning(DiscordConnection::as_returning()) + .get_result(conn) + .map_err(|e| rocket::error!("Failed to update Discord connection: {:?}", e)) + .ok() + } +} diff --git a/src/database/wrappers/discord_connections/models.rs b/src/database/wrappers/discord_connections/models.rs new file mode 100644 index 0000000..44b178c --- /dev/null +++ b/src/database/wrappers/discord_connections/models.rs @@ -0,0 +1,206 @@ +use super::schema; +use diesel::prelude::*; + +#[derive(Queryable, Selectable)] +#[diesel(table_name = schema::discord_connections)] +#[diesel(check_for_backend(diesel::pg::Pg))] +pub(crate) struct DiscordConnection { + pub(crate) uid: String, + pub(crate) access_token: String, + pub(crate) access_token_nonce: String, + pub(crate) refresh_token: String, + pub(crate) refresh_token_nonce: String, + pub(crate) expires_at: chrono::NaiveDateTime, + pub(crate) scope: String, +} + +#[derive(Insertable, Debug)] +#[diesel(table_name = schema::discord_connections)] +pub(crate) struct NewDiscordConnection { + pub(crate) uid: String, + pub(crate) access_token: String, + access_token_nonce: String, + pub(crate) refresh_token: String, + refresh_token_nonce: String, + pub(crate) expires_at: chrono::NaiveDateTime, + pub(crate) scope: String, +} + +#[derive(AsChangeset, Debug)] +#[diesel(table_name = schema::discord_connections)] +pub(crate) struct UpdateDiscordConnection { + pub(crate) access_token: Option, + access_token_nonce: Option, + pub(crate) refresh_token: Option, + refresh_token_nonce: Option, + pub(crate) expires_at: Option, + pub(crate) scope: Option, +} + +pub(crate) struct DiscordConnectionBuilder { + uid: Option, + access_token: Option, + access_token_nonce: Option, + refresh_token: Option, + refresh_token_nonce: Option, + expires_at: Option, + scope: Option, +} + +impl Default for DiscordConnectionBuilder { + fn default() -> Self { + Self { + uid: None, + access_token: None, + access_token_nonce: None, + expires_at: None, + refresh_token: None, + refresh_token_nonce: None, + scope: None, + } + } +} + +impl NewDiscordConnection { + pub(crate) fn build() -> DiscordConnectionBuilder { + DiscordConnectionBuilder::default() + } +} + +impl DiscordConnectionBuilder { + pub(crate) fn new() -> Self { + Self::default() + } + + pub(crate) fn uid(mut self, uid: String) -> Self { + self.uid = Some(uid); + self + } + + pub(crate) fn access_token(mut self, access_token: String) -> Self { + self.access_token = Some(access_token); + self + } + + pub(crate) fn expires_at(mut self, expires_at: chrono::NaiveDateTime) -> Self { + self.expires_at = Some(expires_at); + self + } + + pub(crate) fn refresh_token(mut self, refresh_token: String) -> Self { + self.refresh_token = Some(refresh_token); + self + } + + pub(crate) fn scope(mut self, scope: String) -> Self { + self.scope = Some(scope); + self + } + + fn encrypt_tokens(&mut self) -> Result<(), String> { + // Encrypt the access token if it exists + if let Some(access_token) = &self.access_token { + let data = crate::cipher::encrypt(access_token.as_bytes())?; + + self.access_token_nonce = Some(data.nonce); + self.access_token = Some(data.data); + } + + // Encrypt the refresh token if it exists + if let Some(refresh_token) = &self.refresh_token { + let data = crate::cipher::encrypt(refresh_token.as_bytes())?; + + self.refresh_token_nonce = Some(data.nonce); + self.refresh_token = Some(data.data); + } + + Ok(()) + } + + pub(crate) fn build(mut self) -> Result { + self.encrypt_tokens()?; + + Ok(NewDiscordConnection { + uid: self.uid.expect("uid is required"), + access_token: self.access_token.expect("access_token is required"), + access_token_nonce: self + .access_token_nonce + .expect("access_token_nonce is required"), + expires_at: self.expires_at.expect("expires_at is required"), + refresh_token: self.refresh_token.expect("refresh_token is required"), + refresh_token_nonce: self + .refresh_token_nonce + .expect("refresh_token_nonce is required"), + scope: self.scope.expect("scope is required"), + }) + } + + pub(crate) fn build_update(mut self) -> Result { + self.encrypt_tokens()?; + + Ok(UpdateDiscordConnection { + access_token: self.access_token, + access_token_nonce: self.access_token_nonce, + expires_at: self.expires_at, + refresh_token: self.refresh_token, + refresh_token_nonce: self.refresh_token_nonce, + scope: self.scope, + }) + } +} + +#[cfg(test)] +mod tests { + use super::DiscordConnectionBuilder; + + const ACCESS_TOKEN: &str = "access_token"; + const REFRESH_TOKEN: &str = "refresh_token"; + const SCOPE: &str = "scope"; + const UID: &str = "1"; + + #[test] + fn test_build() { + let now = chrono::Utc::now().naive_utc(); + let conn = DiscordConnectionBuilder::new() + .uid(UID.to_string()) + .access_token(ACCESS_TOKEN.to_string()) + .expires_at(now) + .refresh_token(REFRESH_TOKEN.to_string()) + .scope(SCOPE.to_string()) + .build() + .expect("Failed to build NewDiscordConnection"); + + // The access token is encrypted, so we can't compare it directly + assert_ne!(conn.access_token, ACCESS_TOKEN); + + // The refresh token is encrypted, so we can't compare it directly + assert_ne!(conn.refresh_token, REFRESH_TOKEN); + + assert_eq!(conn.uid, UID); + assert_eq!(conn.expires_at, now); + assert_eq!(conn.scope, SCOPE); + } + + #[test] + fn test_build_update() { + let now = chrono::Utc::now().naive_utc(); + let conn = DiscordConnectionBuilder::new() + .access_token(ACCESS_TOKEN.to_string()) + .expires_at(now) + .refresh_token(REFRESH_TOKEN.to_string()) + .scope(SCOPE.to_string()) + .build_update() + .expect("Failed to build UpdateDiscordConnection"); + + // The access token is encrypted, so we can't compare it directly + assert_ne!(conn.access_token, Some(ACCESS_TOKEN.to_string())); + assert!(conn.access_token_nonce.is_some()); + + // The refresh token is encrypted, so we can't compare it directly + assert_ne!(conn.refresh_token, Some(REFRESH_TOKEN.to_string())); + assert!(conn.refresh_token_nonce.is_some()); + + assert_eq!(conn.expires_at, Some(now)); + assert_eq!(conn.scope, Some(SCOPE.to_string())); + } +} diff --git a/src/database/wrappers/discord_connections/schema.rs b/src/database/wrappers/discord_connections/schema.rs new file mode 100644 index 0000000..6c43b7e --- /dev/null +++ b/src/database/wrappers/discord_connections/schema.rs @@ -0,0 +1,11 @@ +diesel::table! { + discord_connections (uid) { + uid -> Text, + access_token -> Text, + access_token_nonce -> Text, + refresh_token -> Text, + refresh_token_nonce -> Text, + expires_at -> Timestamp, + scope -> Text, + } +} diff --git a/src/database/wrappers/mod.rs b/src/database/wrappers/mod.rs new file mode 100644 index 0000000..a55880f --- /dev/null +++ b/src/database/wrappers/mod.rs @@ -0,0 +1,3 @@ +pub(crate) mod account_links; +pub(crate) mod discord_connections; +pub(crate) mod sessions; diff --git a/src/database/wrappers/sessions/mod.rs b/src/database/wrappers/sessions/mod.rs new file mode 100644 index 0000000..0c1235c --- /dev/null +++ b/src/database/wrappers/sessions/mod.rs @@ -0,0 +1,47 @@ +#![allow(unused_variables)] + +pub(crate) mod models; +mod schema; + +use self::models::{NewSession, Session}; +use diesel::prelude::*; + +pub(crate) struct SessionsDb; + +impl SessionsDb { + pub(crate) fn insert_one( + new_session: &NewSession, + conn: &mut PgConnection, + ) -> Result, diesel::result::Error> { + use self::schema::sessions; + + // Already exists + if Self::find_one(&new_session.session_token, conn).is_some() { + rocket::warn!("Session already exists: {:?}", new_session.session_token); + return Ok(None); + } + + diesel::insert_into(sessions::table) + .values(new_session) + .returning(Session::as_returning()) + .get_result(conn) + .map(Some) + } + + pub(crate) fn find_one(session_token: &str, conn: &mut PgConnection) -> Option { + use schema::sessions::dsl::*; + + sessions + .filter(session_token.eq(session_token)) + .first::(conn) + .ok() + } + + pub(crate) fn delete_one(session_token: &str, conn: &mut PgConnection) -> bool { + use schema::sessions::dsl::*; + + diesel::delete(sessions.filter(session_token.eq(session_token))) + .execute(conn) + .is_ok() + } +} diff --git a/src/database/wrappers/sessions/models.rs b/src/database/wrappers/sessions/models.rs new file mode 100644 index 0000000..0c833ee --- /dev/null +++ b/src/database/wrappers/sessions/models.rs @@ -0,0 +1,71 @@ +use super::schema; +use diesel::prelude::*; + +#[derive(Queryable, Selectable, Debug)] +#[diesel(table_name = schema::sessions)] +#[diesel(check_for_backend(diesel::pg::Pg))] +pub(crate) struct Session { + pub(crate) session_token: String, + pub(crate) discord_uid: String, + pub(crate) expires_at: chrono::NaiveDateTime, +} + +#[derive(Insertable, Debug)] +#[diesel(table_name = schema::sessions)] +pub(crate) struct NewSession { + pub(crate) session_token: String, + pub(crate) discord_uid: String, + pub(crate) expires_at: chrono::NaiveDateTime, +} + +#[derive(Debug)] +pub(crate) struct NewSessionBuilder { + session_token: Option, + discord_uid: Option, + expires_at: Option, +} + +impl Default for NewSessionBuilder { + fn default() -> Self { + Self { + discord_uid: None, + session_token: None, + expires_at: None, + } + } +} + +impl NewSession { + pub(crate) fn build() -> NewSessionBuilder { + NewSessionBuilder::default() + } +} + +impl NewSessionBuilder { + pub(crate) fn new() -> Self { + Self::default() + } + + pub(crate) fn discord_uid(mut self, discord_uid: String) -> Self { + self.discord_uid = Some(discord_uid); + self + } + + pub(crate) fn session_token(mut self, session_token: String) -> Self { + self.session_token = Some(session_token); + self + } + + pub(crate) fn expires_at(mut self, expires_at: chrono::NaiveDateTime) -> Self { + self.expires_at = Some(expires_at); + self + } + + pub(crate) fn build(self) -> NewSession { + NewSession { + discord_uid: self.discord_uid.expect("discord_uid not set"), + session_token: self.session_token.expect("session_token not set"), + expires_at: self.expires_at.expect("expires_at not set"), + } + } +} diff --git a/src/database/wrappers/sessions/schema.rs b/src/database/wrappers/sessions/schema.rs new file mode 100644 index 0000000..cc6d22f --- /dev/null +++ b/src/database/wrappers/sessions/schema.rs @@ -0,0 +1,7 @@ +diesel::table! { + sessions (session_token) { + session_token -> Text, + discord_uid -> Text, + expires_at -> Timestamp, + } +} diff --git a/src/lib.rs b/src/lib.rs index 5feb014..30da20b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,13 @@ +mod cipher; mod database; mod oauth; -mod util; +mod utils; #[macro_use] pub(crate) extern crate rocket; +use rocket::http::Status; use rocket::serde::json::{json, Value}; +use rocket::Request; #[macro_use] extern crate rocket_sync_db_pools; @@ -14,7 +17,6 @@ use rocket_cors::{Cors, CorsOptions}; use dotenvy::dotenv; - #[database("roops")] pub(crate) struct DbConn(diesel::PgConnection); @@ -24,11 +26,11 @@ fn cors_fairing() -> Cors { .expect("Cors fairing cannot be created") } -#[catch(404)] -fn not_found() -> Value { +#[catch(default)] +fn default(status: Status, _req: &Request<'_>) -> Value { json!({ - "status": 404, - "message": "Not Found" + "code": status.code, + "message": status.to_string() }) } @@ -38,6 +40,6 @@ pub fn rocket() -> _ { rocket::build() .attach(cors_fairing()) .attach(DbConn::fairing()) - .mount("/v1/oauth2", oauth::rocket_routes::routes()) - .register("/", catchers![not_found]) + .mount("/v1/oauth2", oauth::routes::routes()) + .register("/", catchers![default]) } diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs index 3853294..3476fc5 100644 --- a/src/oauth/mod.rs +++ b/src/oauth/mod.rs @@ -1,22 +1,5 @@ -pub(crate) mod rocket_routes; -mod oauth_types; -mod pixy; -mod discord; -mod roblox; - -use super::oauth::pixy::generate_random_base64_url_safe_no_pad_string; -use rand::{thread_rng, Rng}; +pub(crate) mod routes; +mod types; +mod utils; const OAUTH_STATE_COOKIE_NAME: &str = "oauth_state"; - -fn generate_state() -> String { - let number_of_bytes = thread_rng().gen_range(10..=20); - - generate_random_base64_url_safe_no_pad_string(number_of_bytes) -} - -fn generate_session() -> String { - let number_of_bytes = thread_rng().gen_range(32..=64); - - generate_random_base64_url_safe_no_pad_string(number_of_bytes) -} diff --git a/src/oauth/rocket_routes/discord.rs b/src/oauth/rocket_routes/discord.rs deleted file mode 100644 index 77d7c27..0000000 --- a/src/oauth/rocket_routes/discord.rs +++ /dev/null @@ -1,66 +0,0 @@ -use crate::oauth::discord::{ - construct_discord_oauth_url, exchange_code, DISCORD_ACCESS_TOKEN_COOKIE_NAME, -}; -use crate::oauth::oauth_types::{DiscordOAuthScopeSet, DiscordOAuthScopes, OAuthCallback}; -use crate::oauth::{generate_state, OAUTH_STATE_COOKIE_NAME}; -use rocket::http::{Cookie, CookieJar, SameSite, Status}; -use rocket::response::Redirect; -use rocket::time::Duration; - -#[get("/initiate/discord?")] -pub(super) fn discord_oauth_initiate( - scope_set: String, - jar: &CookieJar<'_>, -) -> Result { - let scope_set = DiscordOAuthScopeSet::try_from(scope_set)?; - let state = generate_state(); - let redirect_uri = construct_discord_oauth_url(&DiscordOAuthScopes::from(&scope_set), &state); - - // Set the cookie lifetime to 5 minutes - let auth_cookie = Cookie::build((OAUTH_STATE_COOKIE_NAME, state)) - .path("/v1/oauth2/callback/discord") - .same_site(SameSite::Lax) - .max_age(Duration::minutes(5)); - - // Save the state as a cookie to match against the one that'll be returned from callback - jar.add_private(auth_cookie); - - Ok(Redirect::to(redirect_uri)) -} - -#[get("/callback/discord?")] -pub(super) async fn discord_oauth_callback( - callback: OAuthCallback, - jar: &CookieJar<'_>, -) -> Result { - // Verify the state against the one that was saved in the cookie - let is_authorized = jar - .get_pending(OAUTH_STATE_COOKIE_NAME) - .map_or(false, |cookie| cookie.value() == callback.state); - - if !is_authorized { - return Err(Status::Unauthorized); - } - - // Use the code to obtain the token - let token_response = exchange_code(&callback.code).map_err(|e| { - rocket::error!("Failed to obtain Discord token: {:?}", e); - Status::InternalServerError - })?; - - // Save the current session as a cookie - // This will be used to access the token through the database - let session = generate_state(); - - // TODO - Save the session along with the token in the database - - let session_cookie = Cookie::build(("session", session)) - .path("/") - .same_site(SameSite::Lax) - .max_age(Duration::seconds(token_response.expires_in)); - - jar.add_private(session_cookie); - - // Redirect back to the main page - Ok(Redirect::to(uri!("/"))) -} diff --git a/src/oauth/routes/discord.rs b/src/oauth/routes/discord.rs new file mode 100644 index 0000000..05fcde6 --- /dev/null +++ b/src/oauth/routes/discord.rs @@ -0,0 +1,190 @@ +use crate::database::wrappers::discord_connections::models::NewDiscordConnection; +use crate::database::wrappers::discord_connections::DiscordConnectionsDb; +use crate::database::wrappers::sessions::models::NewSession; +use crate::database::wrappers::sessions::SessionsDb; +use crate::oauth::types::discord::{DiscordOAuthScopeSet, DiscordOAuthScopes}; +use crate::oauth::types::OAuthCallback; +use crate::oauth::utils::discord::{ + construct_discord_oauth_url, exchange_code, DISCORD_AUTHORIZED_USER_ENDPOINT, +}; +use crate::oauth::utils::{generate_session, generate_state}; +use crate::oauth::OAUTH_STATE_COOKIE_NAME; +use crate::DbConn; +use diesel::Connection; +use rocket::http::{Cookie, CookieJar, SameSite, Status}; +use rocket::response::Redirect; +use rocket::serde::Deserialize; +use rocket::time::Duration; + +#[get("/initiate/discord?")] +pub(super) fn discord_oauth_initiate( + scope_set: String, + jar: &CookieJar<'_>, +) -> Result { + let scope_set = DiscordOAuthScopeSet::try_from(scope_set)?; + let state = generate_state(); + let redirect_uri = construct_discord_oauth_url(&DiscordOAuthScopes::from(&scope_set), &state); + + // Set the cookie lifetime to 5 minutes + let auth_cookie = Cookie::build((OAUTH_STATE_COOKIE_NAME, state)) + .path("/v1/oauth2/callback/discord") + .same_site(SameSite::Lax) + .max_age(Duration::minutes(5)); + + // Save the state as a cookie to match against the one that'll be returned from callback + jar.add_private(auth_cookie); + + Ok(Redirect::to(redirect_uri)) +} + +#[get("/callback/discord?")] +pub(super) async fn discord_oauth_callback( + callback: OAuthCallback, + jar: &CookieJar<'_>, + conn: DbConn, +) -> Result { + // Verify the state against the one that was saved in the cookie + let is_authorized = jar + .get_pending(OAUTH_STATE_COOKIE_NAME) + .map_or(false, |cookie| cookie.value() == callback.state); + + if !is_authorized { + return Err(Status::Unauthorized); + } + + // Use the code to obtain the token + let token_response = exchange_code(&callback.code).map_err(|e| { + rocket::error!("Failed to obtain Discord token: {:?}", e); + Status::InternalServerError + })?; + + // Fetch the authorized user + let authorized_user = minreq::get(DISCORD_AUTHORIZED_USER_ENDPOINT) + .with_header( + "Authorization", + format!("Bearer {}", token_response.access_token), + ) + .send() + .map_err(|e| { + rocket::error!("Failed to obtain Discord user info: {:?}", e); + Status::InternalServerError + })? + .json::() + .map(|r| r.user) + .map_err(|e| { + rocket::error!("Failed to parse Discord user info: {}", e); + Status::InternalServerError + })?; + + rocket::info!("Discord user info: {:?}", authorized_user); + + // Parse the user info + let discord_uid = authorized_user.expect("Missing 'identify' scope").id; + let session_token = generate_session(); + let token_expires_at = + chrono::Utc::now().naive_utc() + chrono::Duration::seconds(token_response.expires_in); + + let session = NewSession::build() + .discord_uid(discord_uid.clone()) + .session_token(session_token.clone()) + .expires_at(token_expires_at) + .build(); + + let discord_connection = NewDiscordConnection::build() + .uid(discord_uid) + .access_token(token_response.access_token) + .expires_at(token_expires_at) + .refresh_token(token_response.refresh_token) + .scope(token_response.scope) + .build() + .map_err(|e| { + rocket::error!("Failed to build Discord connection: {:?}", e); + Status::InternalServerError + })?; + + // Insert the session and Discord connection into the database + conn.run(move |conn| { + conn.transaction(|conn| { + rocket::info!("Inserting Discord connection: {:?}", discord_connection); + DiscordConnectionsDb::insert_one(&discord_connection, conn).map_err(|e| { + rocket::error!("Failed to insert Discord connection: {:?}", e); + e + })?; + + rocket::info!("Inserting session: {:?}", session); + SessionsDb::insert_one(&session, conn).map_err(|e| { + rocket::error!("Failed to insert session: {:?}", e); + e + })?; + + diesel::result::QueryResult::Ok(()) + }) + }) + .await + .map_err(|_| Status::InternalServerError)?; + + rocket::info!("Storing session cookie: {}", &session_token); + + // Save the current session as a cookie + // This will be used to access the token through the database + let session_cookie = Cookie::build(("session", session_token)) + .path("/") + .same_site(SameSite::Strict) + .max_age(Duration::seconds(token_response.expires_in)); + jar.add_private(session_cookie); + + Ok(Redirect::to(uri!("/"))) +} + +/// The response from the Discord OAuth2 @me endpoint +#[derive(Deserialize, Debug)] +#[serde(crate = "rocket::serde")] +struct DiscordAuthorizedUserResponse { + /// The user who has authorized + /// + /// ⚠️ Requires the `identify` scope + user: Option, +} + +/// The user object represents a user profile on Discord. +/// [Reference](https://discord.com/developers/docs/resources/user#user-object) +#[derive(Deserialize, Debug)] +#[serde(crate = "rocket::serde")] +struct DiscordUser { + /// The user's ID + id: String, + /// The user's username, not unique across the platform + username: String, + /// The user's 4-digit discord-tag + discriminator: String, + /// The user's display name, if it is set. For bots, this is the application name + global_name: Option, + /// The user's [avatar hash](https://discord.com/developers/docs/reference#image-formatting) + avatar: Option, + /// Whether the user belongs to an OAuth2 application + bot: Option, + /// Whether the user is an Official Discord System user (part of the urgent message system) + system: Option, + /// Whether the user has two factor enabled on their account + mfa_enabled: Option, + /// The user's [banner hash](https://discord.com/developers/docs/reference#image-formatting) + banner: Option, + /// The user's banner color encoded as an integer representation of hexadecimal color code + accent_color: Option, + /// The user's chosen [language option](https://discord.com/developers/docs/reference#locales) + locale: Option, + /// Whether the email on this account has been verified + /// + /// ⚠️ Requires the `email` scope + verified: Option, + /// The user's email + /// + /// ⚠️ Requires the `email` scope + email: Option, + /// The [flags](https://discord.com/developers/docs/resources/user#user-object-user-flags) on a user's account + flags: Option, + /// The [type of Nitro subscription](https://discord.com/developers/docs/resources/user#user-object-premium-types) on a user's account + premium_type: Option, + /// The public [flags](https://discord.com/developers/docs/resources/user#user-object-user-flags) on a user's account + public_flags: Option, +} diff --git a/src/oauth/rocket_routes/mod.rs b/src/oauth/routes/mod.rs similarity index 100% rename from src/oauth/rocket_routes/mod.rs rename to src/oauth/routes/mod.rs diff --git a/src/oauth/rocket_routes/roblox.rs b/src/oauth/routes/roblox.rs similarity index 81% rename from src/oauth/rocket_routes/roblox.rs rename to src/oauth/routes/roblox.rs index 5024049..02ad18a 100644 --- a/src/oauth/rocket_routes/roblox.rs +++ b/src/oauth/routes/roblox.rs @@ -1,7 +1,9 @@ -use crate::oauth::oauth_types::{OAuthCallback, RobloxOAuthScopeSet, RobloxOAuthScopes}; -use crate::oauth::pixy::Pixy; -use crate::oauth::roblox::construct_roblox_oauth_url; -use crate::oauth::{generate_state, OAUTH_STATE_COOKIE_NAME}; +use crate::oauth::types::roblox::{RobloxOAuthScopeSet, RobloxOAuthScopes}; +use crate::oauth::types::OAuthCallback; +use crate::oauth::utils::generate_state; +use crate::oauth::utils::pixy::Pixy; +use crate::oauth::utils::roblox::construct_roblox_oauth_url; +use crate::oauth::OAUTH_STATE_COOKIE_NAME; use rocket::http::{Cookie, CookieJar, Status}; use rocket::response::Redirect; use rocket::time::Duration; @@ -28,7 +30,8 @@ pub(super) fn roblox_oauth_callback( }; // TODO: Obtain token through https://apis.roblox.com/oauth/v1/token - // and store it as a cookie + // and store it in the database, associating it with the session cookie + // (which should already be in the cookie jar from discord auth) // Redirect back to the main page Ok(Redirect::to(uri!("/"))) diff --git a/src/oauth/oauth_types.rs b/src/oauth/types/discord.rs similarity index 52% rename from src/oauth/oauth_types.rs rename to src/oauth/types/discord.rs index 51e6cba..0eb6133 100644 --- a/src/oauth/oauth_types.rs +++ b/src/oauth/types/discord.rs @@ -1,43 +1,16 @@ +use crate::oauth::utils::discord::{DISCORD_OAUTH_APP_CLIENT_ID, DISCORD_OAUTH_REDIRECT_URI}; use crate::url; -use super::discord::{DISCORD_OAUTH_APP_CLIENT_ID, DISCORD_OAUTH_REDIRECT_URI}; use rocket::http::Status; use rocket::serde::{Deserialize, Serialize}; use std::fmt::{Display, Formatter, Result as FmtResult}; -pub(super) enum RobloxOAuthScope { - OpenID, - Profile, - AssetRead, - AssetWrite, - GroupRead, - GroupWrite, - LegacyBadgeManage, - LegacyDeveloperProductManage, - LegacyGamePassManage, - LegacyGroupManage, - LegacyTeamCollaborationManage, - LegacyUniverseManage, - LegacyUniverseBadgeWrite, - LegacyUniverseFollowingRead, - LegacyUniverseFollowingWrite, - LegacyUserManage, - UniverseMessagingServicePublish, - UniverseWrite, - UniversePlaceWrite, - UniverseSubscriptionProductSubscriptionRead, - UniverseUserRestrictionRead, - UniverseUserRestrictionWrite, - UserAdvancedRead, - UserCommerceItemRead, - UserCommerceItemWrite, - UserCommerceMerchantConnectionRead, - UserCommerceMerchantConnectionWrite, - UserInventoryItemRead, - UserSocialRead, - UserUserNotificationWrite, +pub(crate) enum DiscordOAuthScopeSet { + Verification, } -pub(super) enum DiscordOAuthScope { +pub(crate) struct DiscordOAuthScopes(pub(crate) Vec); + +pub(crate) enum DiscordOAuthScope { Identify, Guilds, GuildsChannelsRead, @@ -82,65 +55,24 @@ pub(super) enum DiscordOAuthScope { RPCVideoWrite, } -impl From<&RobloxOAuthScope> for String { - fn from(value: &RobloxOAuthScope) -> Self { - match value { - RobloxOAuthScope::OpenID => String::from("openid"), - RobloxOAuthScope::Profile => String::from("profile"), - RobloxOAuthScope::AssetRead => String::from("creator-store-product:read"), - RobloxOAuthScope::AssetWrite => String::from("creator-store-product:write"), - RobloxOAuthScope::GroupRead => String::from("group:read"), - RobloxOAuthScope::GroupWrite => String::from("group:write"), - RobloxOAuthScope::LegacyBadgeManage => String::from("legacy-badge:manage"), - RobloxOAuthScope::LegacyDeveloperProductManage => { - String::from("legacy-developer-product:manage") - } - RobloxOAuthScope::LegacyGamePassManage => String::from("legacy-game-pass:manage"), - RobloxOAuthScope::LegacyGroupManage => String::from("legacy-group:manage"), - RobloxOAuthScope::LegacyTeamCollaborationManage => { - String::from("legacy-team-collaboration:manage") - } - RobloxOAuthScope::LegacyUniverseManage => String::from("legacy-universe:manage"), - RobloxOAuthScope::LegacyUniverseBadgeWrite => { - String::from("legacy-universe.badge:write") - } - RobloxOAuthScope::LegacyUniverseFollowingRead => { - String::from("legacy-universe.following:read") - } - RobloxOAuthScope::LegacyUniverseFollowingWrite => { - String::from("legacy-universe.following:write") - } - RobloxOAuthScope::LegacyUserManage => String::from("legacy-user:manage"), - RobloxOAuthScope::UniverseMessagingServicePublish => { - String::from("universe-messaging-service:publish") - } - RobloxOAuthScope::UniverseWrite => String::from("universe:write"), - RobloxOAuthScope::UniversePlaceWrite => String::from("universe.place:write"), - RobloxOAuthScope::UniverseSubscriptionProductSubscriptionRead => { - String::from("universe.subscription-product.subscription:read") - } - RobloxOAuthScope::UniverseUserRestrictionRead => { - String::from("universe.user-restriction:read") - } - RobloxOAuthScope::UniverseUserRestrictionWrite => { - String::from("universe.user-restriction:write") - } - RobloxOAuthScope::UserAdvancedRead => String::from("user.advanced:read"), - RobloxOAuthScope::UserCommerceItemRead => String::from("user.commerce-item:read"), - RobloxOAuthScope::UserCommerceItemWrite => String::from("user.commerce-item:write"), - RobloxOAuthScope::UserCommerceMerchantConnectionRead => { - String::from("user.commerce-merchant-connection:read") - } - RobloxOAuthScope::UserCommerceMerchantConnectionWrite => { - String::from("user.commerce-merchant-connection:write") - } - RobloxOAuthScope::UserInventoryItemRead => String::from("user.inventory-item:read"), - RobloxOAuthScope::UserSocialRead => String::from("user.social:read"), - RobloxOAuthScope::UserUserNotificationWrite => { - String::from("user.user-notification:write") - } - } - } +#[derive(Debug, Serialize)] +#[serde(crate = "rocket::serde")] +pub(crate) struct DiscordAuthorizationCodeRequestBody<'a> { + client_id: &'a str, + client_secret: String, + grant_type: &'a str, + code: &'a str, + redirect_uri: &'a str, +} + +#[derive(Debug, Deserialize)] +#[serde(crate = "rocket::serde")] +pub(crate) struct DiscordAuthorizationCodeResponse { + pub(crate) access_token: String, + pub(crate) token_type: String, + pub(crate) expires_in: i64, + pub(crate) refresh_token: String, + pub(crate) scope: String, } impl From<&DiscordOAuthScope> for String { @@ -204,21 +136,6 @@ impl From<&DiscordOAuthScope> for String { } } -pub(super) struct RobloxOAuthScopes(pub(super) Vec); -pub(super) struct DiscordOAuthScopes(pub(super) Vec); - -impl From<&RobloxOAuthScopes> for String { - fn from(value: &RobloxOAuthScopes) -> Self { - let scopes = &value.0; - - scopes - .iter() - .map(String::from) - .collect::>() - .join("%20") - } -} - impl From<&DiscordOAuthScopes> for String { fn from(value: &DiscordOAuthScopes) -> Self { let scopes = &value.0; @@ -231,40 +148,12 @@ impl From<&DiscordOAuthScopes> for String { } } -impl Display for RobloxOAuthScopes { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - write!(f, "{}", String::from(self)) - } -} - impl Display for DiscordOAuthScopes { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { write!(f, "{}", String::from(self)) } } -pub(super) enum RobloxOAuthScopeSet { - Verification, -} - -pub(super) enum DiscordOAuthScopeSet { - Verification, -} - -impl From<&RobloxOAuthScopeSet> for RobloxOAuthScopes { - fn from(value: &RobloxOAuthScopeSet) -> Self { - match value { - RobloxOAuthScopeSet::Verification => RobloxOAuthScopes(vec![ - RobloxOAuthScope::OpenID, - RobloxOAuthScope::Profile, - RobloxOAuthScope::GroupRead, - RobloxOAuthScope::UserInventoryItemRead, - RobloxOAuthScope::UserAdvancedRead, - ]), - } - } -} - impl From<&DiscordOAuthScopeSet> for DiscordOAuthScopes { fn from(value: &DiscordOAuthScopeSet) -> Self { match value { @@ -278,17 +167,6 @@ impl From<&DiscordOAuthScopeSet> for DiscordOAuthScopes { } } -impl TryFrom for RobloxOAuthScopeSet { - type Error = Status; - - fn try_from(value: String) -> Result { - match value.as_str() { - "verification" => Ok(RobloxOAuthScopeSet::Verification), - _ => Err(Status::BadRequest), - } - } -} - impl TryFrom for DiscordOAuthScopeSet { type Error = Status; @@ -300,24 +178,8 @@ impl TryFrom for DiscordOAuthScopeSet { } } -#[derive(Debug, FromForm)] -pub(super) struct OAuthCallback { - pub(super) code: String, - pub(super) state: String, -} - -#[derive(Debug, Serialize)] -#[serde(crate = "rocket::serde")] -pub(super) struct DiscordAuthorizationCodeRequestBody<'a> { - client_id: &'a str, - client_secret: String, - grant_type: &'a str, - code: &'a str, - redirect_uri: &'a str, -} - impl<'a> DiscordAuthorizationCodeRequestBody<'a> { - pub(super) fn new(code: &'a str) -> Self { + pub(crate) fn new(code: &'a str) -> Self { Self { client_id: DISCORD_OAUTH_APP_CLIENT_ID, redirect_uri: DISCORD_OAUTH_REDIRECT_URI, @@ -328,7 +190,7 @@ impl<'a> DiscordAuthorizationCodeRequestBody<'a> { } } - pub(super) fn as_query_params(&self) -> String { + pub(crate) fn as_query_params(&self) -> String { url!([ ("client_id", self.client_id), ("client_secret", self.client_secret), @@ -338,13 +200,3 @@ impl<'a> DiscordAuthorizationCodeRequestBody<'a> { ]) } } - -#[derive(Debug, Deserialize)] -#[serde(crate = "rocket::serde")] -pub(super) struct DiscordAuthorizationCodeResponse { - pub(super) access_token: String, - pub(super) token_type: String, - pub(super) expires_in: i64, - pub(super) refresh_token: String, - pub(super) scope: String, -} diff --git a/src/oauth/types/mod.rs b/src/oauth/types/mod.rs new file mode 100644 index 0000000..52c0173 --- /dev/null +++ b/src/oauth/types/mod.rs @@ -0,0 +1,8 @@ +pub(super) mod discord; +pub(super) mod roblox; + +#[derive(Debug, FromForm)] +pub(super) struct OAuthCallback { + pub(crate) code: String, + pub(crate) state: String, +} diff --git a/src/oauth/types/roblox.rs b/src/oauth/types/roblox.rs new file mode 100644 index 0000000..2d22682 --- /dev/null +++ b/src/oauth/types/roblox.rs @@ -0,0 +1,145 @@ +use rocket::http::Status; +use std::fmt::{Display, Formatter, Result as FmtResult}; + +pub(crate) enum RobloxOAuthScopeSet { + Verification, +} + +pub(crate) struct RobloxOAuthScopes(pub(crate) Vec); + +pub(crate) enum RobloxOAuthScope { + OpenID, + Profile, + AssetRead, + AssetWrite, + GroupRead, + GroupWrite, + LegacyBadgeManage, + LegacyDeveloperProductManage, + LegacyGamePassManage, + LegacyGroupManage, + LegacyTeamCollaborationManage, + LegacyUniverseManage, + LegacyUniverseBadgeWrite, + LegacyUniverseFollowingRead, + LegacyUniverseFollowingWrite, + LegacyUserManage, + UniverseMessagingServicePublish, + UniverseWrite, + UniversePlaceWrite, + UniverseSubscriptionProductSubscriptionRead, + UniverseUserRestrictionRead, + UniverseUserRestrictionWrite, + UserAdvancedRead, + UserCommerceItemRead, + UserCommerceItemWrite, + UserCommerceMerchantConnectionRead, + UserCommerceMerchantConnectionWrite, + UserInventoryItemRead, + UserSocialRead, + UserUserNotificationWrite, +} + +impl From<&RobloxOAuthScope> for String { + fn from(value: &RobloxOAuthScope) -> Self { + match value { + RobloxOAuthScope::OpenID => String::from("openid"), + RobloxOAuthScope::Profile => String::from("profile"), + RobloxOAuthScope::AssetRead => String::from("creator-store-product:read"), + RobloxOAuthScope::AssetWrite => String::from("creator-store-product:write"), + RobloxOAuthScope::GroupRead => String::from("group:read"), + RobloxOAuthScope::GroupWrite => String::from("group:write"), + RobloxOAuthScope::LegacyBadgeManage => String::from("legacy-badge:manage"), + RobloxOAuthScope::LegacyDeveloperProductManage => { + String::from("legacy-developer-product:manage") + } + RobloxOAuthScope::LegacyGamePassManage => String::from("legacy-game-pass:manage"), + RobloxOAuthScope::LegacyGroupManage => String::from("legacy-group:manage"), + RobloxOAuthScope::LegacyTeamCollaborationManage => { + String::from("legacy-team-collaboration:manage") + } + RobloxOAuthScope::LegacyUniverseManage => String::from("legacy-universe:manage"), + RobloxOAuthScope::LegacyUniverseBadgeWrite => { + String::from("legacy-universe.badge:write") + } + RobloxOAuthScope::LegacyUniverseFollowingRead => { + String::from("legacy-universe.following:read") + } + RobloxOAuthScope::LegacyUniverseFollowingWrite => { + String::from("legacy-universe.following:write") + } + RobloxOAuthScope::LegacyUserManage => String::from("legacy-user:manage"), + RobloxOAuthScope::UniverseMessagingServicePublish => { + String::from("universe-messaging-service:publish") + } + RobloxOAuthScope::UniverseWrite => String::from("universe:write"), + RobloxOAuthScope::UniversePlaceWrite => String::from("universe.place:write"), + RobloxOAuthScope::UniverseSubscriptionProductSubscriptionRead => { + String::from("universe.subscription-product.subscription:read") + } + RobloxOAuthScope::UniverseUserRestrictionRead => { + String::from("universe.user-restriction:read") + } + RobloxOAuthScope::UniverseUserRestrictionWrite => { + String::from("universe.user-restriction:write") + } + RobloxOAuthScope::UserAdvancedRead => String::from("user.advanced:read"), + RobloxOAuthScope::UserCommerceItemRead => String::from("user.commerce-item:read"), + RobloxOAuthScope::UserCommerceItemWrite => String::from("user.commerce-item:write"), + RobloxOAuthScope::UserCommerceMerchantConnectionRead => { + String::from("user.commerce-merchant-connection:read") + } + RobloxOAuthScope::UserCommerceMerchantConnectionWrite => { + String::from("user.commerce-merchant-connection:write") + } + RobloxOAuthScope::UserInventoryItemRead => String::from("user.inventory-item:read"), + RobloxOAuthScope::UserSocialRead => String::from("user.social:read"), + RobloxOAuthScope::UserUserNotificationWrite => { + String::from("user.user-notification:write") + } + } + } +} + +impl From<&RobloxOAuthScopes> for String { + fn from(value: &RobloxOAuthScopes) -> Self { + let scopes = &value.0; + + scopes + .iter() + .map(String::from) + .collect::>() + .join("%20") + } +} + +impl Display for RobloxOAuthScopes { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + write!(f, "{}", String::from(self)) + } +} + +impl From<&RobloxOAuthScopeSet> for RobloxOAuthScopes { + fn from(value: &RobloxOAuthScopeSet) -> Self { + match value { + RobloxOAuthScopeSet::Verification => RobloxOAuthScopes(vec![ + RobloxOAuthScope::OpenID, + RobloxOAuthScope::Profile, + RobloxOAuthScope::GroupRead, + RobloxOAuthScope::UserInventoryItemRead, + RobloxOAuthScope::UserAdvancedRead, + ]), + } + } +} + +impl TryFrom for RobloxOAuthScopeSet { + type Error = Status; + + fn try_from(value: String) -> Result { + match value.as_str() { + "verification" => Ok(RobloxOAuthScopeSet::Verification), + _ => Err(Status::BadRequest), + } + } +} diff --git a/src/oauth/discord.rs b/src/oauth/utils/discord.rs similarity index 72% rename from src/oauth/discord.rs rename to src/oauth/utils/discord.rs index e4f1ae2..9765350 100644 --- a/src/oauth/discord.rs +++ b/src/oauth/utils/discord.rs @@ -1,16 +1,16 @@ -use crate::oauth::oauth_types::{ +use crate::oauth::types::discord::{ DiscordAuthorizationCodeRequestBody, DiscordAuthorizationCodeResponse, DiscordOAuthScopes, }; use crate::url; -pub(super) const DISCORD_OAUTH_REDIRECT_URI: &str = +pub(crate) const DISCORD_OAUTH_REDIRECT_URI: &str = "http://localhost:8000/v1/oauth2/callback/discord"; -pub(super) const DISCORD_OAUTH_APP_CLIENT_ID: &str = "1300538596348133499"; -pub(super) const DISCORD_ACCESS_TOKEN_COOKIE_NAME: &str = "discord_access_token"; +pub(crate) const DISCORD_OAUTH_APP_CLIENT_ID: &str = "1300538596348133499"; +pub(crate) const DISCORD_AUTHORIZED_USER_ENDPOINT: &str = "https://discord.com/api/v10/oauth2/@me"; const DISCORD_OAUTH_ENDPOINT_URL: &str = "https://discord.com/oauth2/authorize"; const DISCORD_TOKEN_ENDPOINT_URL: &str = "https://discord.com/api/v10/oauth2/token"; -pub(super) fn construct_discord_oauth_url(scopes: &DiscordOAuthScopes, state: &str) -> String { +pub(crate) fn construct_discord_oauth_url(scopes: &DiscordOAuthScopes, state: &str) -> String { url!( DISCORD_OAUTH_ENDPOINT_URL, [ @@ -23,7 +23,7 @@ pub(super) fn construct_discord_oauth_url(scopes: &DiscordOAuthScopes, state: &s ) } -pub(super) fn exchange_code(code: &str) -> Result { +pub(crate) fn exchange_code(code: &str) -> Result { let body = DiscordAuthorizationCodeRequestBody::new(code); let response = minreq::post(DISCORD_TOKEN_ENDPOINT_URL) .with_header("Content-Type", "application/x-www-form-urlencoded") diff --git a/src/oauth/utils/mod.rs b/src/oauth/utils/mod.rs new file mode 100644 index 0000000..e126af0 --- /dev/null +++ b/src/oauth/utils/mod.rs @@ -0,0 +1,18 @@ +pub(super) mod discord; +pub(super) mod pixy; +pub(super) mod roblox; + +use self::pixy::generate_random_base64_url_safe_no_pad_string; +use rand::{thread_rng, Rng}; + +pub(super) fn generate_state() -> String { + let number_of_bytes = thread_rng().gen_range(10..=20); + + generate_random_base64_url_safe_no_pad_string(number_of_bytes) +} + +pub(super) fn generate_session() -> String { + let number_of_bytes = thread_rng().gen_range(32..=64); + + generate_random_base64_url_safe_no_pad_string(number_of_bytes) +} diff --git a/src/oauth/pixy.rs b/src/oauth/utils/pixy.rs similarity index 82% rename from src/oauth/pixy.rs rename to src/oauth/utils/pixy.rs index 2c73c1f..cbb9776 100644 --- a/src/oauth/pixy.rs +++ b/src/oauth/utils/pixy.rs @@ -2,7 +2,7 @@ use rand::{thread_rng, Rng}; use secrecy::{ExposeSecret, SecretString}; use sha2::{Digest, Sha256}; -pub(super) fn generate_random_base64_url_safe_no_pad_string(number_of_bytes: usize) -> String { +pub(crate) fn generate_random_base64_url_safe_no_pad_string(number_of_bytes: usize) -> String { let mut random_bytes = vec![0u8; number_of_bytes]; thread_rng().fill(&mut random_bytes[..]); @@ -27,13 +27,13 @@ fn calculate_challenge_from_verifier(verifier: &SecretString) -> SecretString { )) } -pub(super) struct Pixy { +pub(crate) struct Pixy { challenge: SecretString, verifier: SecretString, } impl Pixy { - pub(super) fn new() -> Self { + pub(crate) fn new() -> Self { let verifier = generate_verifier(); let challenge = calculate_challenge_from_verifier(&verifier); @@ -43,11 +43,11 @@ impl Pixy { } } - pub(super) fn expose_verifier(&self) -> &str { + pub(crate) fn expose_verifier(&self) -> &str { self.verifier.expose_secret() } - pub(super) fn get_challenge(&self) -> &SecretString { + pub(crate) fn get_challenge(&self) -> &SecretString { &self.challenge } } diff --git a/src/oauth/roblox.rs b/src/oauth/utils/roblox.rs similarity index 71% rename from src/oauth/roblox.rs rename to src/oauth/utils/roblox.rs index 9687df2..d34668a 100644 --- a/src/oauth/roblox.rs +++ b/src/oauth/utils/roblox.rs @@ -1,14 +1,12 @@ +use crate::oauth::types::roblox::RobloxOAuthScopes; use crate::url; -use super::oauth_types::RobloxOAuthScopes; -use super::pixy::generate_random_base64_url_safe_no_pad_string; -use rand::{thread_rng, Rng}; use secrecy::{ExposeSecret, SecretString}; const ROBLOX_OAUTH_ENDPOINT_URL: &str = "https://apis.roblox.com/oauth/v1/authorize"; const ROBLOX_OAUTH_REDIRECT_URI: &str = "http://localhost:8000/oauth-callback/roblox"; const ROBLOX_OAUTH_APP_CLIENT_ID: u64 = 5083487294232976060; -pub(super) fn construct_roblox_oauth_url( +pub(crate) fn construct_roblox_oauth_url( code_challenge_secret: &SecretString, scopes: &RobloxOAuthScopes, state: &str, @@ -28,9 +26,3 @@ pub(super) fn construct_roblox_oauth_url( ] ) } - -pub fn generate_state() -> String { - let number_of_bytes = thread_rng().gen_range(10..=20); - - generate_random_base64_url_safe_no_pad_string(number_of_bytes) -} diff --git a/src/util.rs b/src/utils.rs similarity index 100% rename from src/util.rs rename to src/utils.rs From 491bc26c8bd2b6f80e6df2844e21e65ac3402aed Mon Sep 17 00:00:00 2001 From: nick <59822256+Archasion@users.noreply.github.com> Date: Tue, 26 Nov 2024 11:46:51 +0000 Subject: [PATCH 03/14] fix(env): Correct concat syntax --- .example.env | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.example.env b/.example.env index 6fd477c..adcae2d 100644 --- a/.example.env +++ b/.example.env @@ -15,5 +15,5 @@ DB_PORT='5432' # Use the variables above to construct the connection url DATABASE_URL="${DB_DRIVER}://${DB_USER}:${DB_PASSWORD}@${DB_HOST}:${DB_PORT}/${DB_NAME}" -ROCKET_DATABASES='{${DB_NAME}={url="${DATABASE_URL}"}}' +ROCKET_DATABASES="{${DB_NAME}={url=\"${DATABASE_URL}\"}}" From 94c1596d30c4a3197b4b7a1fd2e839b3d4add9f6 Mon Sep 17 00:00:00 2001 From: nick <59822256+Archasion@users.noreply.github.com> Date: Tue, 26 Nov 2024 11:47:21 +0000 Subject: [PATCH 04/14] refactor(init-db): Allow custom host and port --- init_db.sh | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/init_db.sh b/init_db.sh index b53f58a..7c80640 100644 --- a/init_db.sh +++ b/init_db.sh @@ -45,8 +45,18 @@ CREATE DATABASE roops WITH OWNER $USER; GRANT ALL PRIVILEGES ON DATABASE roops TO $USER; PSQL_SCRIPT +# Set HOST to localhost unless it's already set +if [ -z "$HOST" ]; then + HOST=localhost +fi + +# Set PORT to 5432 unless it's already set +if [ -z "$PORT" ]; then + PORT=5432 +fi + echo "PostgreSQL user and database created." echo "Username: $USER" echo "Database: roops" echo "Password: $PASSWORD" -echo "Connection URL: postgresql://$USER:$PASSWORD@localhost:5432/roops" +echo "Connection URL: postgresql://$USER:$PASSWORD@$HOST:$PORT/roops" From a1f38190b885f98c5937a2cd4936f0beead903a7 Mon Sep 17 00:00:00 2001 From: Nick <59822256+Archasion@users.noreply.github.com> Date: Sun, 1 Dec 2024 19:13:28 +0000 Subject: [PATCH 05/14] chore(discord-oauth2): Implement refresh token endpoint (#11) * chore(discord-oauth2): Implement token refresh endpoint * chore(discord-oauth2): Complete refresh token implementation * refactor(discord-refresh-token): Return correct error code * refactor: Remove index page * refactor(discord-refresh-token): Refresh token using session token * refactor(sessions): Add endpoint for refreshing session token * fix(db): Prevent parameter names from being overridden * refactor: Cleanup * refactor: Remove init db file * refactor(routes): Correct HTTP response statuses * chore(config): Add config file for CORS * fix(gitignore): Add Cargo.lock --- .gitignore | 2 + Cargo.lock | 2290 ----------------- Cargo.toml | 5 +- README.md | 17 + init_db.sh | 62 - .../up.sql | 12 +- .../2024-11-21-101635_create_sessions/up.sql | 6 +- src/config.rs | 73 + src/database/wrappers/account_links/mod.rs | 86 +- src/database/wrappers/account_links/models.rs | 20 +- .../wrappers/discord_connections/mod.rs | 46 +- .../wrappers/discord_connections/models.rs | 19 +- src/database/wrappers/sessions/mod.rs | 56 +- src/database/wrappers/sessions/models.rs | 50 +- src/database/wrappers/sessions/schema.rs | 4 +- src/lib.rs | 16 +- src/oauth/mod.rs | 1 + src/oauth/routes/discord.rs | 173 +- src/oauth/routes/mod.rs | 62 +- src/oauth/types/discord.rs | 91 +- src/oauth/types/mod.rs | 40 +- src/oauth/utils/discord.rs | 38 +- src/oauth/utils/mod.rs | 18 +- src/utils.rs | 2 +- 24 files changed, 614 insertions(+), 2575 deletions(-) delete mode 100644 Cargo.lock create mode 100644 README.md delete mode 100644 init_db.sh create mode 100644 src/config.rs diff --git a/.gitignore b/.gitignore index 2f32484..0e0b277 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ /target /.pg-dev-data +config.toml +Cargo.lock \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock deleted file mode 100644 index 3997621..0000000 --- a/Cargo.lock +++ /dev/null @@ -1,2290 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "addr2line" -version = "0.24.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" -dependencies = [ - "gimli", -] - -[[package]] -name = "adler2" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" - -[[package]] -name = "aead" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" -dependencies = [ - "crypto-common", - "generic-array", -] - -[[package]] -name = "aes" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" -dependencies = [ - "cfg-if", - "cipher", - "cpufeatures", -] - -[[package]] -name = "aes-gcm" -version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" -dependencies = [ - "aead", - "aes", - "cipher", - "ctr", - "ghash", - "subtle", -] - -[[package]] -name = "aho-corasick" -version = "1.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" -dependencies = [ - "memchr", -] - -[[package]] -name = "android-tzdata" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" - -[[package]] -name = "android_system_properties" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" -dependencies = [ - "libc", -] - -[[package]] -name = "async-stream" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" -dependencies = [ - "async-stream-impl", - "futures-core", - "pin-project-lite", -] - -[[package]] -name = "async-stream-impl" -version = "0.3.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "async-trait" -version = "0.1.83" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "atomic" -version = "0.5.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c59bdb34bc650a32731b31bd8f0829cc15d24a708ee31559e0bb34f2bc320cba" - -[[package]] -name = "atomic" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d818003e740b63afc82337e3160717f4f63078720a810b7b903e70a5d1d2994" -dependencies = [ - "bytemuck", -] - -[[package]] -name = "autocfg" -version = "1.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" - -[[package]] -name = "backtrace" -version = "0.3.74" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" -dependencies = [ - "addr2line", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", - "windows-targets 0.52.6", -] - -[[package]] -name = "base64" -version = "0.22.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" - -[[package]] -name = "base64-compat" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a8d4d2746f89841e49230dd26917df1876050f95abafafbe34f47cb534b88d7" -dependencies = [ - "byteorder", -] - -[[package]] -name = "binascii" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "383d29d513d8764dcdc42ea295d979eb99c3c9f00607b3692cf68a431f7dca72" - -[[package]] -name = "bitflags" -version = "2.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" - -[[package]] -name = "block-buffer" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" -dependencies = [ - "generic-array", -] - -[[package]] -name = "bumpalo" -version = "3.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" - -[[package]] -name = "bytemuck" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8334215b81e418a0a7bdb8ef0849474f40bb10c8b71f1c4ed315cff49f32494d" - -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - -[[package]] -name = "bytes" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" - -[[package]] -name = "cc" -version = "1.1.33" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3788d6ac30243803df38a3e9991cf37e41210232916d41a8222ae378f912624" -dependencies = [ - "shlex", -] - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "chrono" -version = "0.4.38" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" -dependencies = [ - "android-tzdata", - "iana-time-zone", - "js-sys", - "num-traits", - "serde", - "wasm-bindgen", - "windows-targets 0.52.6", -] - -[[package]] -name = "cipher" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" -dependencies = [ - "crypto-common", - "inout", -] - -[[package]] -name = "cookie" -version = "0.18.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" -dependencies = [ - "aes-gcm", - "base64", - "hkdf", - "percent-encoding", - "rand", - "sha2", - "subtle", - "time", - "version_check", -] - -[[package]] -name = "core-foundation-sys" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" - -[[package]] -name = "cpufeatures" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" -dependencies = [ - "libc", -] - -[[package]] -name = "crypto-common" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" -dependencies = [ - "generic-array", - "rand_core", - "typenum", -] - -[[package]] -name = "ctr" -version = "0.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" -dependencies = [ - "cipher", -] - -[[package]] -name = "darling" -version = "0.20.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" -dependencies = [ - "darling_core", - "darling_macro", -] - -[[package]] -name = "darling_core" -version = "0.20.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn", -] - -[[package]] -name = "darling_macro" -version = "0.20.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" -dependencies = [ - "darling_core", - "quote", - "syn", -] - -[[package]] -name = "deranged" -version = "0.3.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" -dependencies = [ - "powerfmt", -] - -[[package]] -name = "devise" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1d90b0c4c777a2cad215e3c7be59ac7c15adf45cf76317009b7d096d46f651d" -dependencies = [ - "devise_codegen", - "devise_core", -] - -[[package]] -name = "devise_codegen" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71b28680d8be17a570a2334922518be6adc3f58ecc880cbb404eaeb8624fd867" -dependencies = [ - "devise_core", - "quote", -] - -[[package]] -name = "devise_core" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b035a542cf7abf01f2e3c4d5a7acbaebfefe120ae4efc7bde3df98186e4b8af7" -dependencies = [ - "bitflags", - "proc-macro2", - "proc-macro2-diagnostics", - "quote", - "syn", -] - -[[package]] -name = "diesel" -version = "2.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "158fe8e2e68695bd615d7e4f3227c0727b151330d3e253b525086c348d055d5e" -dependencies = [ - "bitflags", - "byteorder", - "chrono", - "diesel_derives", - "itoa", - "pq-sys", - "r2d2", - "serde_json", -] - -[[package]] -name = "diesel_derives" -version = "2.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7f2c3de51e2ba6bf2a648285696137aaf0f5f487bcbea93972fe8a364e131a4" -dependencies = [ - "diesel_table_macro_syntax", - "dsl_auto_type", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "diesel_table_macro_syntax" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "209c735641a413bc68c4923a9d6ad4bcb3ca306b794edaa7eb0b3228a99ffb25" -dependencies = [ - "syn", -] - -[[package]] -name = "digest" -version = "0.10.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" -dependencies = [ - "block-buffer", - "crypto-common", - "subtle", -] - -[[package]] -name = "dotenvy" -version = "0.15.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" - -[[package]] -name = "dsl_auto_type" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5d9abe6314103864cc2d8901b7ae224e0ab1a103a0a416661b4097b0779b607" -dependencies = [ - "darling", - "either", - "heck", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "either" -version = "1.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" - -[[package]] -name = "encoding_rs" -version = "0.8.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "equivalent" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" - -[[package]] -name = "errno" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - -[[package]] -name = "fastrand" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" - -[[package]] -name = "figment" -version = "0.10.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cb01cd46b0cf372153850f4c6c272d9cbea2da513e07538405148f95bd789f3" -dependencies = [ - "atomic 0.6.0", - "pear", - "serde", - "toml", - "uncased", - "version_check", -] - -[[package]] -name = "fnv" -version = "1.0.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" - -[[package]] -name = "form_urlencoded" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" -dependencies = [ - "percent-encoding", -] - -[[package]] -name = "futures" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" -dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - -[[package]] -name = "futures-channel" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" -dependencies = [ - "futures-core", - "futures-sink", -] - -[[package]] -name = "futures-core" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" - -[[package]] -name = "futures-io" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" - -[[package]] -name = "futures-sink" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" - -[[package]] -name = "futures-task" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" - -[[package]] -name = "futures-util" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" -dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-sink", - "futures-task", - "memchr", - "pin-project-lite", - "pin-utils", - "slab", -] - -[[package]] -name = "generator" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cc16584ff22b460a382b7feec54b23d2908d858152e5739a120b949293bd74e" -dependencies = [ - "cc", - "libc", - "log", - "rustversion", - "windows", -] - -[[package]] -name = "generic-array" -version = "0.14.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" -dependencies = [ - "typenum", - "version_check", -] - -[[package]] -name = "getrandom" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" -dependencies = [ - "cfg-if", - "libc", - "wasi", -] - -[[package]] -name = "ghash" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" -dependencies = [ - "opaque-debug", - "polyval", -] - -[[package]] -name = "gimli" -version = "0.31.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" - -[[package]] -name = "glob" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" - -[[package]] -name = "hashbrown" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" - -[[package]] -name = "heck" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" - -[[package]] -name = "hermit-abi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" - -[[package]] -name = "hermit-abi" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" - -[[package]] -name = "hkdf" -version = "0.12.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" -dependencies = [ - "hmac", -] - -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest", -] - -[[package]] -name = "http" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - -[[package]] -name = "http" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - -[[package]] -name = "http-body" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" -dependencies = [ - "bytes", - "http 0.2.12", - "pin-project-lite", -] - -[[package]] -name = "httparse" -version = "1.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" - -[[package]] -name = "httpdate" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" - -[[package]] -name = "hyper" -version = "0.14.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" -dependencies = [ - "bytes", - "futures-channel", - "futures-core", - "futures-util", - "http 0.2.12", - "http-body", - "httparse", - "httpdate", - "itoa", - "pin-project-lite", - "socket2", - "tokio", - "tower-service", - "tracing", - "want", -] - -[[package]] -name = "iana-time-zone" -version = "0.1.61" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" -dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "wasm-bindgen", - "windows-core", -] - -[[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" -dependencies = [ - "cc", -] - -[[package]] -name = "ident_case" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" - -[[package]] -name = "idna" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" -dependencies = [ - "unicode-bidi", - "unicode-normalization", -] - -[[package]] -name = "indexmap" -version = "2.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" -dependencies = [ - "equivalent", - "hashbrown", - "serde", -] - -[[package]] -name = "inlinable_string" -version = "0.1.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" - -[[package]] -name = "inout" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" -dependencies = [ - "generic-array", -] - -[[package]] -name = "is-terminal" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" -dependencies = [ - "hermit-abi 0.4.0", - "libc", - "windows-sys 0.52.0", -] - -[[package]] -name = "itoa" -version = "1.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" - -[[package]] -name = "js-sys" -version = "0.3.72" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" -dependencies = [ - "wasm-bindgen", -] - -[[package]] -name = "lazy_static" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" - -[[package]] -name = "libc" -version = "0.2.161" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" - -[[package]] -name = "linux-raw-sys" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" - -[[package]] -name = "lock_api" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" -dependencies = [ - "autocfg", - "scopeguard", -] - -[[package]] -name = "log" -version = "0.4.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" - -[[package]] -name = "loom" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff50ecb28bb86013e935fb6683ab1f6d3a20016f123c76fd4c27470076ac30f5" -dependencies = [ - "cfg-if", - "generator", - "scoped-tls", - "serde", - "serde_json", - "tracing", - "tracing-subscriber", -] - -[[package]] -name = "matchers" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" -dependencies = [ - "regex-automata 0.1.10", -] - -[[package]] -name = "memchr" -version = "2.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" - -[[package]] -name = "mime" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" - -[[package]] -name = "miniz_oxide" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" -dependencies = [ - "adler2", -] - -[[package]] -name = "minreq" -version = "2.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "763d142cdff44aaadd9268bebddb156ef6c65a0e13486bb81673cf2d8739f9b0" -dependencies = [ - "log", - "once_cell", - "rustls", - "rustls-webpki", - "serde", - "serde_json", - "webpki-roots", -] - -[[package]] -name = "mio" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" -dependencies = [ - "hermit-abi 0.3.9", - "libc", - "wasi", - "windows-sys 0.52.0", -] - -[[package]] -name = "multer" -version = "3.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b" -dependencies = [ - "bytes", - "encoding_rs", - "futures-util", - "http 1.1.0", - "httparse", - "memchr", - "mime", - "spin", - "tokio", - "tokio-util", - "version_check", -] - -[[package]] -name = "nu-ansi-term" -version = "0.46.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" -dependencies = [ - "overload", - "winapi", -] - -[[package]] -name = "num-conv" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" - -[[package]] -name = "num-traits" -version = "0.2.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" -dependencies = [ - "autocfg", -] - -[[package]] -name = "num_cpus" -version = "1.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" -dependencies = [ - "hermit-abi 0.3.9", - "libc", -] - -[[package]] -name = "object" -version = "0.36.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" -dependencies = [ - "memchr", -] - -[[package]] -name = "once_cell" -version = "1.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" - -[[package]] -name = "opaque-debug" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" - -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - -[[package]] -name = "parking_lot" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets 0.52.6", -] - -[[package]] -name = "pear" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdeeaa00ce488657faba8ebf44ab9361f9365a97bd39ffb8a60663f57ff4b467" -dependencies = [ - "inlinable_string", - "pear_codegen", - "yansi", -] - -[[package]] -name = "pear_codegen" -version = "0.2.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bab5b985dc082b345f812b7df84e1bef27e7207b39e448439ba8bd69c93f147" -dependencies = [ - "proc-macro2", - "proc-macro2-diagnostics", - "quote", - "syn", -] - -[[package]] -name = "percent-encoding" -version = "2.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" - -[[package]] -name = "pin-project-lite" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" - -[[package]] -name = "pin-utils" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" - -[[package]] -name = "polyval" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" -dependencies = [ - "cfg-if", - "cpufeatures", - "opaque-debug", - "universal-hash", -] - -[[package]] -name = "powerfmt" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" - -[[package]] -name = "ppv-lite86" -version = "0.2.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" -dependencies = [ - "zerocopy", -] - -[[package]] -name = "pq-sys" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6cc05d7ea95200187117196eee9edd0644424911821aeb28a18ce60ea0b8793" -dependencies = [ - "vcpkg", -] - -[[package]] -name = "proc-macro2" -version = "1.0.89" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "proc-macro2-diagnostics" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "version_check", - "yansi", -] - -[[package]] -name = "quote" -version = "1.0.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "r2d2" -version = "0.8.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51de85fb3fb6524929c8a2eb85e6b6d363de4e8c48f9e2c2eac4944abc181c93" -dependencies = [ - "log", - "parking_lot", - "scheduled-thread-pool", -] - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom", -] - -[[package]] -name = "redox_syscall" -version = "0.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" -dependencies = [ - "bitflags", -] - -[[package]] -name = "ref-cast" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccf0a6f84d5f1d581da8b41b47ec8600871962f2a528115b542b362d4b744931" -dependencies = [ - "ref-cast-impl", -] - -[[package]] -name = "ref-cast-impl" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "regex" -version = "1.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" -dependencies = [ - "aho-corasick", - "memchr", - "regex-automata 0.4.8", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", -] - -[[package]] -name = "regex-automata" -version = "0.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - -[[package]] -name = "regex-syntax" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" - -[[package]] -name = "ring" -version = "0.17.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" -dependencies = [ - "cc", - "cfg-if", - "getrandom", - "libc", - "spin", - "untrusted", - "windows-sys 0.52.0", -] - -[[package]] -name = "rocket" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a516907296a31df7dc04310e7043b61d71954d703b603cc6867a026d7e72d73f" -dependencies = [ - "async-stream", - "async-trait", - "atomic 0.5.3", - "binascii", - "bytes", - "either", - "figment", - "futures", - "indexmap", - "log", - "memchr", - "multer", - "num_cpus", - "parking_lot", - "pin-project-lite", - "rand", - "ref-cast", - "rocket_codegen", - "rocket_http", - "serde", - "serde_json", - "state", - "tempfile", - "time", - "tokio", - "tokio-stream", - "tokio-util", - "ubyte", - "version_check", - "yansi", -] - -[[package]] -name = "rocket_codegen" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "575d32d7ec1a9770108c879fc7c47815a80073f96ca07ff9525a94fcede1dd46" -dependencies = [ - "devise", - "glob", - "indexmap", - "proc-macro2", - "quote", - "rocket_http", - "syn", - "unicode-xid", - "version_check", -] - -[[package]] -name = "rocket_cors" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfac3a1df83f8d4fc96aa41dba3b86c786417b7fc0f52ec76295df2ba781aa69" -dependencies = [ - "http 0.2.12", - "log", - "regex", - "rocket", - "unicase", - "url", -] - -[[package]] -name = "rocket_http" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e274915a20ee3065f611c044bd63c40757396b6dbc057d6046aec27f14f882b9" -dependencies = [ - "cookie", - "either", - "futures", - "http 0.2.12", - "hyper", - "indexmap", - "log", - "memchr", - "pear", - "percent-encoding", - "pin-project-lite", - "ref-cast", - "serde", - "smallvec", - "stable-pattern", - "state", - "time", - "tokio", - "uncased", -] - -[[package]] -name = "rocket_sync_db_pools" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d83f32721ed79509adac4328e97f817a8f55a47c4b64799f6fd6cc3adb6e42ff" -dependencies = [ - "diesel", - "r2d2", - "rocket", - "rocket_sync_db_pools_codegen", - "serde", - "tokio", - "version_check", -] - -[[package]] -name = "rocket_sync_db_pools_codegen" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cc890925dc79370c28eb15c9957677093fdb7e8c44966d189f38cedb995ee68" -dependencies = [ - "devise", - "quote", -] - -[[package]] -name = "roops-internal-api" -version = "0.0.0" -dependencies = [ - "aes-gcm", - "base64-compat", - "chrono", - "diesel", - "dotenvy", - "minreq", - "rand", - "rocket", - "rocket_cors", - "rocket_sync_db_pools", - "secrecy", - "sha2", -] - -[[package]] -name = "rustc-demangle" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" - -[[package]] -name = "rustix" -version = "0.38.38" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa260229e6538e52293eeb577aabd09945a09d6d9cc0fc550ed7529056c2e32a" -dependencies = [ - "bitflags", - "errno", - "libc", - "linux-raw-sys", - "windows-sys 0.52.0", -] - -[[package]] -name = "rustls" -version = "0.21.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" -dependencies = [ - "log", - "ring", - "rustls-webpki", - "sct", -] - -[[package]] -name = "rustls-webpki" -version = "0.101.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" -dependencies = [ - "ring", - "untrusted", -] - -[[package]] -name = "rustversion" -version = "1.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" - -[[package]] -name = "ryu" -version = "1.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" - -[[package]] -name = "scheduled-thread-pool" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" -dependencies = [ - "parking_lot", -] - -[[package]] -name = "scoped-tls" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "sct" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" -dependencies = [ - "ring", - "untrusted", -] - -[[package]] -name = "secrecy" -version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e891af845473308773346dc847b2c23ee78fe442e0472ac50e22a18a93d3ae5a" -dependencies = [ - "zeroize", -] - -[[package]] -name = "serde" -version = "1.0.214" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.214" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_json" -version = "1.0.132" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" -dependencies = [ - "itoa", - "memchr", - "ryu", - "serde", -] - -[[package]] -name = "serde_spanned" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" -dependencies = [ - "serde", -] - -[[package]] -name = "sha2" -version = "0.10.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - -[[package]] -name = "sharded-slab" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] - -[[package]] -name = "shlex" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" - -[[package]] -name = "signal-hook-registry" -version = "1.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" -dependencies = [ - "libc", -] - -[[package]] -name = "slab" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] - -[[package]] -name = "smallvec" -version = "1.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" - -[[package]] -name = "socket2" -version = "0.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - -[[package]] -name = "stable-pattern" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4564168c00635f88eaed410d5efa8131afa8d8699a612c80c455a0ba05c21045" -dependencies = [ - "memchr", -] - -[[package]] -name = "state" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b8c4a4445d81357df8b1a650d0d0d6fbbbfe99d064aa5e02f3e4022061476d8" -dependencies = [ - "loom", -] - -[[package]] -name = "strsim" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" - -[[package]] -name = "subtle" -version = "2.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" - -[[package]] -name = "syn" -version = "2.0.86" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89275301d38033efb81a6e60e3497e734dfcc62571f2854bf4b16690398824c" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "tempfile" -version = "3.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" -dependencies = [ - "cfg-if", - "fastrand", - "once_cell", - "rustix", - "windows-sys 0.59.0", -] - -[[package]] -name = "thread_local" -version = "1.1.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" -dependencies = [ - "cfg-if", - "once_cell", -] - -[[package]] -name = "time" -version = "0.3.36" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" -dependencies = [ - "deranged", - "itoa", - "num-conv", - "powerfmt", - "serde", - "time-core", - "time-macros", -] - -[[package]] -name = "time-core" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" - -[[package]] -name = "time-macros" -version = "0.2.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" -dependencies = [ - "num-conv", - "time-core", -] - -[[package]] -name = "tinyvec" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" -dependencies = [ - "tinyvec_macros", -] - -[[package]] -name = "tinyvec_macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" - -[[package]] -name = "tokio" -version = "1.41.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" -dependencies = [ - "backtrace", - "bytes", - "libc", - "mio", - "pin-project-lite", - "signal-hook-registry", - "socket2", - "tokio-macros", - "windows-sys 0.52.0", -] - -[[package]] -name = "tokio-macros" -version = "2.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tokio-stream" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f4e6ce100d0eb49a2734f8c0812bcd324cf357d21810932c5df6b96ef2b86f1" -dependencies = [ - "futures-core", - "pin-project-lite", - "tokio", -] - -[[package]] -name = "tokio-util" -version = "0.7.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" -dependencies = [ - "bytes", - "futures-core", - "futures-sink", - "pin-project-lite", - "tokio", -] - -[[package]] -name = "toml" -version = "0.8.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit", -] - -[[package]] -name = "toml_datetime" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" -dependencies = [ - "serde", -] - -[[package]] -name = "toml_edit" -version = "0.22.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" -dependencies = [ - "indexmap", - "serde", - "serde_spanned", - "toml_datetime", - "winnow", -] - -[[package]] -name = "tower-service" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" - -[[package]] -name = "tracing" -version = "0.1.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" -dependencies = [ - "pin-project-lite", - "tracing-attributes", - "tracing-core", -] - -[[package]] -name = "tracing-attributes" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tracing-core" -version = "0.1.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" -dependencies = [ - "once_cell", - "valuable", -] - -[[package]] -name = "tracing-log" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" -dependencies = [ - "log", - "once_cell", - "tracing-core", -] - -[[package]] -name = "tracing-subscriber" -version = "0.3.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" -dependencies = [ - "matchers", - "nu-ansi-term", - "once_cell", - "regex", - "sharded-slab", - "smallvec", - "thread_local", - "tracing", - "tracing-core", - "tracing-log", -] - -[[package]] -name = "try-lock" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" - -[[package]] -name = "typenum" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" - -[[package]] -name = "ubyte" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f720def6ce1ee2fc44d40ac9ed6d3a59c361c80a75a7aa8e75bb9baed31cf2ea" -dependencies = [ - "serde", -] - -[[package]] -name = "uncased" -version = "0.9.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b88fcfe09e89d3866a5c11019378088af2d24c3fbd4f0543f96b479ec90697" -dependencies = [ - "serde", - "version_check", -] - -[[package]] -name = "unicase" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df" - -[[package]] -name = "unicode-bidi" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" - -[[package]] -name = "unicode-ident" -version = "1.0.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" - -[[package]] -name = "unicode-normalization" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" -dependencies = [ - "tinyvec", -] - -[[package]] -name = "unicode-xid" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" - -[[package]] -name = "universal-hash" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" -dependencies = [ - "crypto-common", - "subtle", -] - -[[package]] -name = "untrusted" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" - -[[package]] -name = "url" -version = "2.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" -dependencies = [ - "form_urlencoded", - "idna", - "percent-encoding", -] - -[[package]] -name = "valuable" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" - -[[package]] -name = "vcpkg" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" - -[[package]] -name = "version_check" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" - -[[package]] -name = "want" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" -dependencies = [ - "try-lock", -] - -[[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" - -[[package]] -name = "wasm-bindgen" -version = "0.2.95" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" -dependencies = [ - "cfg-if", - "once_cell", - "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.95" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" -dependencies = [ - "bumpalo", - "log", - "once_cell", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.95" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.95" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-backend", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-shared" -version = "0.2.95" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" - -[[package]] -name = "webpki-roots" -version = "0.25.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" - -[[package]] -name = "winapi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" -dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", -] - -[[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - -[[package]] -name = "windows" -version = "0.48.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" -dependencies = [ - "windows-targets 0.48.5", -] - -[[package]] -name = "windows-core" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" -dependencies = [ - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-sys" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" -dependencies = [ - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-sys" -version = "0.59.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" -dependencies = [ - "windows-targets 0.52.6", -] - -[[package]] -name = "windows-targets" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" -dependencies = [ - "windows_aarch64_gnullvm 0.48.5", - "windows_aarch64_msvc 0.48.5", - "windows_i686_gnu 0.48.5", - "windows_i686_msvc 0.48.5", - "windows_x86_64_gnu 0.48.5", - "windows_x86_64_gnullvm 0.48.5", - "windows_x86_64_msvc 0.48.5", -] - -[[package]] -name = "windows-targets" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" -dependencies = [ - "windows_aarch64_gnullvm 0.52.6", - "windows_aarch64_msvc 0.52.6", - "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm", - "windows_i686_msvc 0.52.6", - "windows_x86_64_gnu 0.52.6", - "windows_x86_64_gnullvm 0.52.6", - "windows_x86_64_msvc 0.52.6", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" - -[[package]] -name = "windows_i686_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" - -[[package]] -name = "windows_i686_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.48.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" - -[[package]] -name = "winnow" -version = "0.6.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" -dependencies = [ - "memchr", -] - -[[package]] -name = "yansi" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" -dependencies = [ - "is-terminal", -] - -[[package]] -name = "zerocopy" -version = "0.7.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" -dependencies = [ - "byteorder", - "zerocopy-derive", -] - -[[package]] -name = "zerocopy-derive" -version = "0.7.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "zeroize" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" diff --git a/Cargo.toml b/Cargo.toml index 6e46659..35d7340 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ rand = { version = "0.8.0", default-features = false } sha2 = { version = "0.10.7", default-features = false } secrecy = { version = "0.10.3", default-features = false } minreq = { version = "2.12.0", default-features = false, features = ["json-using-serde", "https"] } -chrono = { version = "0.4.38", default-feature = false, features = ["serde"] } -diesel = { version = "2.1.4", default-feature = false, features = ["postgres", "serde_json", "chrono"] } +chrono = { version = "0.4.38", features = ["serde"] } +diesel = { version = "2.1.4", features = ["postgres", "serde_json", "chrono"] } aes-gcm = { version = "0.10.3", default-features = false } +toml = { version = "0.8.19", default-features = false } diff --git a/README.md b/README.md new file mode 100644 index 0000000..e6549ed --- /dev/null +++ b/README.md @@ -0,0 +1,17 @@ +## Config +A `config.toml` file can be created in the root directory of the project. The file is used to configure the server. + +### CORS +The `cors` section of the config file is used to configure the CORS settings for the server. If the request origin is not in the list of allowed origins, the server will respond with a 403 Forbidden status code. + +If the `allowed_methods` field is set, the server will respond with a 405 Method Not Allowed status code if the request method is not in the list of allowed methods. + +```toml +[cors] +# Allowed origins for CORS requests (exact match) +allowed_origins_exact = ["http://localhost:8000"] +# Allowed origins for CORS requests (regular expression) +allowed_origins_regex = ["^http://localhost:\\d{4}$"] +# Allowed methods for CORS requests +allowed_methods = ["GET", "POST", "PUT", "DELETE"] +``` \ No newline at end of file diff --git a/init_db.sh b/init_db.sh deleted file mode 100644 index 7c80640..0000000 --- a/init_db.sh +++ /dev/null @@ -1,62 +0,0 @@ -#!/bin/bash - -# Check if postgres is installed -if ! command -v psql >/dev/null 2>&1; then - echo "PostgreSQL is not installed" - exit 1 -fi - -# Check if postgres is running -if ! pg_isready >/dev/null 2>&1; then - echo "PostgreSQL is not running" - exit 1 -fi - -# Check if database roops already exists -if psql -lqt | cut -d \| -f 1 | grep -qw roops; then - echo "Database roops already exists" - exit 1 -fi - -USER=$1 -PASSWORD=$2 - -# Check if user is provided -if [ -z "$USER" ]; then - echo "Invalid arguments, usage: ./init_db.sh (password)" - exit 1 -fi - -# Create a password if one wasn't passed -if [ -z "$PASSWORD" ]; then - PASSWORD=`openssl rand -base64 8` -fi - -# Check if user already exists -if psql postgres -t -c "SELECT 1 FROM pg_roles WHERE rolname='$USER'" | grep -qw 1; then - echo "User $USER already exists" - exit 1 -fi - -# Initialize the user and database -psql -q postgres < Self { + toml::from_str(include_str!("../config.toml")).unwrap_or_default() + } + + /// Get CORS fairing + pub(crate) fn cors_fairing(&self) -> RocketCors { + let Some(cors) = &self.cors else { + return CorsOptions::default() + .to_cors() + .expect("CORS fairing cannot be created"); + }; + + CorsOptions::default() + .allow_credentials(true) + .allowed_origins(cors.allowed_origins()) + .allowed_methods(cors.allowed_methods()) + .to_cors() + .expect("CORS fairing cannot be created") + } +} + +impl Cors { + /// Get allowed origins for CORS requests + fn allowed_origins(&self) -> AllowedOrigins { + match (&self.allowed_origins_exact, &self.allowed_origins_regex) { + (None, None) => AllowedOrigins::all(), + (Some(exact), None) => AllowedOrigins::some_exact(exact), + (None, Some(regex)) => AllowedOrigins::some_regex(regex), + (Some(exact), Some(regex)) => AllowedOrigins::some(exact, regex), + } + } + + /// Get allowed methods for CORS requests + fn allowed_methods(&self) -> AllowedMethods { + let Some(methods) = &self.allowed_methods else { + return AllowedMethods::default(); + }; + + methods + .iter() + .map(|s| FromStr::from_str(s).expect("Failed to parse method")) + .collect() + } +} + +/// The server configuration +#[derive(Deserialize, Debug, Default)] +#[serde(crate = "rocket::serde")] +pub(crate) struct Config { + /// CORS configuration + pub(crate) cors: Option, +} + +/// CORS configuration +#[derive(Deserialize, Debug, Default)] +#[serde(crate = "rocket::serde")] +pub(crate) struct Cors { + /// Allowed origins for CORS requests (exact match) + #[serde(default)] + allowed_origins_exact: Option>, + /// Allowed origins for CORS requests (regular expression) + #[serde(default)] + allowed_origins_regex: Option>, + /// Allowed methods for CORS requests + #[serde(default)] + allowed_methods: Option>, +} diff --git a/src/database/wrappers/account_links/mod.rs b/src/database/wrappers/account_links/mod.rs index 7dc3a63..9996b3c 100644 --- a/src/database/wrappers/account_links/mod.rs +++ b/src/database/wrappers/account_links/mod.rs @@ -1,10 +1,9 @@ -#![allow(unused_variables)] - pub(crate) mod models; mod schema; use self::models::{AccountLink, NewAccountLink}; use diesel::prelude::*; +use std::borrow::Borrow; pub(crate) struct AccountLinksDb; @@ -20,18 +19,22 @@ impl AccountLinksDb { /// # Returns /// /// The newly created user if successful, `None` otherwise. - pub(crate) fn insert_one( - new_account_link: &NewAccountLink, + pub(crate) fn insert_one( + new_account_link: V, conn: &mut PgConnection, - ) -> Result, diesel::result::Error> { + ) -> Result, diesel::result::Error> + where + V: Borrow, + { use self::schema::account_links::dsl::*; + let new_account_link = new_account_link.borrow(); - // Already exists - let pk = ( - new_account_link.discord_uid.as_str(), - new_account_link.roblox_uid, - ); + let pk = AccountLinkPk { + discord_uid: new_account_link.discord_uid.as_str(), + roblox_uid: new_account_link.roblox_uid, + }; + // Already exists if Self::find_one(pk, conn).is_some() { rocket::warn!("Account link already exists: {:?}", new_account_link); return Ok(None); @@ -70,10 +73,13 @@ impl AccountLinksDb { /// # Returns /// /// `true` if the account link(s) were successfully deleted, `false` otherwise. - pub(crate) fn delete_many(target_id: UserId, conn: &mut PgConnection) -> bool { + pub(crate) fn delete_many<'a, PartialPK>(target_id: PartialPK, conn: &mut PgConnection) -> bool + where + PartialPK: Borrow>, + { use schema::account_links::dsl::*; - match target_id { + match target_id.borrow() { UserId::DiscordMarker(t_id) => diesel::delete(account_links) .filter(discord_uid.eq(t_id)) .execute(conn) @@ -94,14 +100,18 @@ impl AccountLinksDb { /// # Returns /// /// The account link if it exists, `None` otherwise. - pub(crate) fn find_primary( - discord_uid: String, + pub(crate) fn find_primary( + pk_discord_uid: PartialPK, conn: &mut PgConnection, - ) -> Option { + ) -> Option + where + PartialPK: AsRef, + { use schema::account_links::dsl::*; + let pk_discord_uid = pk_discord_uid.as_ref(); account_links - .filter(discord_uid.eq(discord_uid)) + .filter(discord_uid.eq(pk_discord_uid)) .filter(is_primary.eq(true)) .first::(conn) .ok() @@ -116,13 +126,16 @@ impl AccountLinksDb { /// # Returns /// /// The account link if it exists, `None` otherwise. - pub(crate) fn find_one(pk: (&str, i64), conn: &mut PgConnection) -> Option { + pub(crate) fn find_one<'a, PK>(pk: PK, conn: &mut PgConnection) -> Option + where + PK: Borrow>, + { use schema::account_links::dsl::*; - let (pk_discord_uid, pk_roblox_uid) = pk; + let pk = pk.borrow(); account_links - .filter(discord_uid.eq(pk_discord_uid)) - .filter(roblox_uid.eq(pk_roblox_uid)) + .filter(discord_uid.eq(pk.discord_uid)) + .filter(roblox_uid.eq(pk.roblox_uid)) .first::(conn) .ok() } @@ -136,13 +149,16 @@ impl AccountLinksDb { /// # Returns /// /// The corresponding account links associated with the user ID. - pub(crate) fn find_many( - target_id: UserId, + pub(crate) fn find_many<'a, PartialPK>( + target_id: PartialPK, conn: &mut PgConnection, - ) -> Option> { + ) -> Option> + where + PartialPK: Borrow>, + { use schema::account_links::dsl::*; - match target_id { + match target_id.borrow() { UserId::DiscordMarker(t_id) => account_links .filter(discord_uid.eq(t_id)) .load::(conn) @@ -163,21 +179,24 @@ impl AccountLinksDb { /// # Returns /// /// `true` if the primary account link was successfully updated, `false` otherwise. - pub(crate) fn set_primary(pk: (&str, i64), conn: &mut PgConnection) -> bool { + pub(crate) fn set_primary<'a, PK>(pk: PK, conn: &mut PgConnection) -> bool + where + PK: Borrow>, + { use schema::account_links::dsl::*; - let (pk_discord_uid, pk_roblox_uid) = pk; + let pk = pk.borrow(); conn.transaction(|conn| { // Set all other account links to not primary diesel::update(account_links) - .filter(discord_uid.eq(pk_discord_uid)) + .filter(discord_uid.eq(pk.discord_uid)) .set(is_primary.eq(false)) .execute(conn)?; // Set the new primary account link diesel::update(account_links) - .filter(discord_uid.eq(pk_discord_uid)) - .filter(roblox_uid.eq(pk_roblox_uid)) + .filter(discord_uid.eq(pk.discord_uid)) + .filter(roblox_uid.eq(pk.roblox_uid)) .set(is_primary.eq(true)) .execute(conn)?; @@ -187,7 +206,12 @@ impl AccountLinksDb { } } -pub(crate) enum UserId { - DiscordMarker(String), +pub(crate) enum UserId<'a> { + DiscordMarker(&'a str), RobloxMarker(i64), } + +pub(crate) struct AccountLinkPk<'a> { + pub(crate) discord_uid: &'a str, + pub(crate) roblox_uid: i64, +} diff --git a/src/database/wrappers/account_links/models.rs b/src/database/wrappers/account_links/models.rs index 153a8bb..32b8ccd 100644 --- a/src/database/wrappers/account_links/models.rs +++ b/src/database/wrappers/account_links/models.rs @@ -18,30 +18,20 @@ pub(crate) struct NewAccountLink { pub(crate) is_primary: bool, } -#[derive(Debug)] -pub(crate) struct NewAccountLinkBuilder { +#[derive(Debug, Default)] +pub(crate) struct AccountLinkBuilder { roblox_uid: Option, discord_uid: Option, is_primary: Option, } -impl Default for NewAccountLinkBuilder { - fn default() -> Self { - Self { - roblox_uid: None, - discord_uid: None, - is_primary: None, - } - } -} - impl NewAccountLink { - pub(crate) fn build() -> NewAccountLinkBuilder { - NewAccountLinkBuilder::default() + pub(crate) fn build() -> AccountLinkBuilder { + AccountLinkBuilder::default() } } -impl NewAccountLinkBuilder { +impl AccountLinkBuilder { pub(crate) fn new() -> Self { Self::default() } diff --git a/src/database/wrappers/discord_connections/mod.rs b/src/database/wrappers/discord_connections/mod.rs index af19f7f..f4b4437 100644 --- a/src/database/wrappers/discord_connections/mod.rs +++ b/src/database/wrappers/discord_connections/mod.rs @@ -1,20 +1,23 @@ -#![allow(unused_variables)] - pub(crate) mod models; mod schema; use self::models::{DiscordConnection, NewDiscordConnection}; use crate::database::wrappers::discord_connections::models::UpdateDiscordConnection; use diesel::prelude::*; +use std::borrow::Borrow; pub(crate) struct DiscordConnectionsDb; impl DiscordConnectionsDb { - pub(crate) fn insert_one( - new_conn: &NewDiscordConnection, + pub(crate) fn insert_one( + new_conn: V, conn: &mut PgConnection, - ) -> Result, diesel::result::Error> { + ) -> Result, diesel::result::Error> + where + V: Borrow, + { use self::schema::discord_connections; + let new_conn = new_conn.borrow(); // Already exists if Self::find_one(&new_conn.uid, conn).is_some() { @@ -29,34 +32,49 @@ impl DiscordConnectionsDb { .map(Some) } - pub(crate) fn find_one(uid: &str, conn: &mut PgConnection) -> Option { + pub(crate) fn find_one(pk_uid: PK, conn: &mut PgConnection) -> Option + where + PK: AsRef, + { use schema::discord_connections::dsl::*; + let pk_uid = pk_uid.as_ref(); discord_connections - .filter(uid.eq(uid)) + .filter(uid.eq(pk_uid)) .first::(conn) .ok() } - pub(crate) fn delete_one(uid: &str, conn: &mut PgConnection) -> bool { + pub(crate) fn delete_one(pk_uid: PK, conn: &mut PgConnection) -> bool + where + PK: AsRef, + { use schema::discord_connections::dsl::*; + let pk_uid = pk_uid.as_ref(); diesel::delete(discord_connections) - .filter(uid.eq(uid)) + .filter(uid.eq(pk_uid)) .execute(conn) .map_err(|e| rocket::error!("Failed to delete Discord connection: {:?}", e)) .is_ok() } - pub(crate) fn update_one( - uid: &str, - new_conn: &UpdateDiscordConnection, + pub(crate) fn update_one( + pk_uid: PK, + new_conn: V, conn: &mut PgConnection, - ) -> Option { + ) -> Option + where + PK: AsRef, + V: Borrow, + { use schema::discord_connections::dsl::*; + let pk_uid = pk_uid.as_ref(); + let new_conn = new_conn.borrow(); + diesel::update(discord_connections) - .filter(uid.eq(uid)) + .filter(uid.eq(pk_uid)) .set(new_conn) .returning(DiscordConnection::as_returning()) .get_result(conn) diff --git a/src/database/wrappers/discord_connections/models.rs b/src/database/wrappers/discord_connections/models.rs index 44b178c..c9b0c6d 100644 --- a/src/database/wrappers/discord_connections/models.rs +++ b/src/database/wrappers/discord_connections/models.rs @@ -1,7 +1,7 @@ use super::schema; use diesel::prelude::*; -#[derive(Queryable, Selectable)] +#[derive(Queryable, Selectable, Debug)] #[diesel(table_name = schema::discord_connections)] #[diesel(check_for_backend(diesel::pg::Pg))] pub(crate) struct DiscordConnection { @@ -37,6 +37,7 @@ pub(crate) struct UpdateDiscordConnection { pub(crate) scope: Option, } +#[derive(Default)] pub(crate) struct DiscordConnectionBuilder { uid: Option, access_token: Option, @@ -47,21 +48,13 @@ pub(crate) struct DiscordConnectionBuilder { scope: Option, } -impl Default for DiscordConnectionBuilder { - fn default() -> Self { - Self { - uid: None, - access_token: None, - access_token_nonce: None, - expires_at: None, - refresh_token: None, - refresh_token_nonce: None, - scope: None, - } +impl NewDiscordConnection { + pub(crate) fn build() -> DiscordConnectionBuilder { + DiscordConnectionBuilder::default() } } -impl NewDiscordConnection { +impl UpdateDiscordConnection { pub(crate) fn build() -> DiscordConnectionBuilder { DiscordConnectionBuilder::default() } diff --git a/src/database/wrappers/sessions/mod.rs b/src/database/wrappers/sessions/mod.rs index 0c1235c..a305a62 100644 --- a/src/database/wrappers/sessions/mod.rs +++ b/src/database/wrappers/sessions/mod.rs @@ -1,23 +1,27 @@ -#![allow(unused_variables)] - pub(crate) mod models; mod schema; use self::models::{NewSession, Session}; +use crate::database::wrappers::sessions::models::UpdateSession; use diesel::prelude::*; +use std::borrow::Borrow; pub(crate) struct SessionsDb; impl SessionsDb { - pub(crate) fn insert_one( - new_session: &NewSession, + pub(crate) fn insert_one( + new_session: V, conn: &mut PgConnection, - ) -> Result, diesel::result::Error> { + ) -> Result, diesel::result::Error> + where + V: Borrow, + { use self::schema::sessions; + let new_session = new_session.borrow(); // Already exists - if Self::find_one(&new_session.session_token, conn).is_some() { - rocket::warn!("Session already exists: {:?}", new_session.session_token); + if Self::find_one(&new_session.session_id, conn).is_some() { + rocket::warn!("Session already exists: {:?}", new_session.session_id); return Ok(None); } @@ -28,20 +32,50 @@ impl SessionsDb { .map(Some) } - pub(crate) fn find_one(session_token: &str, conn: &mut PgConnection) -> Option { + pub(crate) fn find_one(pk_session_id: PK, conn: &mut PgConnection) -> Option + where + PK: AsRef, + { use schema::sessions::dsl::*; + let pk_session_id = pk_session_id.as_ref(); sessions - .filter(session_token.eq(session_token)) + .filter(session_id.eq(pk_session_id)) .first::(conn) .ok() } - pub(crate) fn delete_one(session_token: &str, conn: &mut PgConnection) -> bool { + pub(crate) fn delete_one(pk_session_id: PK, conn: &mut PgConnection) -> bool + where + PK: AsRef, + { use schema::sessions::dsl::*; + let pk_session_id = pk_session_id.as_ref(); - diesel::delete(sessions.filter(session_token.eq(session_token))) + diesel::delete(sessions.filter(session_id.eq(pk_session_id))) .execute(conn) .is_ok() } + + pub(crate) fn update_one( + pk_session_id: PK, + updated_session: V, + conn: &mut PgConnection, + ) -> Option + where + PK: AsRef, + V: Borrow, + { + use schema::sessions::dsl::*; + + let pk_session_id = pk_session_id.as_ref(); + let updated_session = updated_session.borrow(); + + diesel::update(sessions) + .filter(session_id.eq(pk_session_id)) + .set(updated_session) + .returning(Session::as_returning()) + .get_result(conn) + .ok() + } } diff --git a/src/database/wrappers/sessions/models.rs b/src/database/wrappers/sessions/models.rs index 0c833ee..6094933 100644 --- a/src/database/wrappers/sessions/models.rs +++ b/src/database/wrappers/sessions/models.rs @@ -5,7 +5,7 @@ use diesel::prelude::*; #[diesel(table_name = schema::sessions)] #[diesel(check_for_backend(diesel::pg::Pg))] pub(crate) struct Session { - pub(crate) session_token: String, + pub(crate) session_id: String, pub(crate) discord_uid: String, pub(crate) expires_at: chrono::NaiveDateTime, } @@ -13,35 +13,39 @@ pub(crate) struct Session { #[derive(Insertable, Debug)] #[diesel(table_name = schema::sessions)] pub(crate) struct NewSession { - pub(crate) session_token: String, + pub(crate) session_id: String, pub(crate) discord_uid: String, pub(crate) expires_at: chrono::NaiveDateTime, } -#[derive(Debug)] -pub(crate) struct NewSessionBuilder { - session_token: Option, +#[derive(AsChangeset, Debug)] +#[diesel(table_name = schema::sessions)] +pub(crate) struct UpdateSession { + pub(crate) session_id: Option, + pub(crate) discord_uid: Option, + pub(crate) expires_at: Option, +} + +#[derive(Debug, Default)] +pub(crate) struct SessionBuilder { + session_id: Option, discord_uid: Option, expires_at: Option, } -impl Default for NewSessionBuilder { - fn default() -> Self { - Self { - discord_uid: None, - session_token: None, - expires_at: None, - } +impl NewSession { + pub(crate) fn build() -> SessionBuilder { + SessionBuilder::default() } } -impl NewSession { - pub(crate) fn build() -> NewSessionBuilder { - NewSessionBuilder::default() +impl UpdateSession { + pub(crate) fn build() -> SessionBuilder { + SessionBuilder::default() } } -impl NewSessionBuilder { +impl SessionBuilder { pub(crate) fn new() -> Self { Self::default() } @@ -51,8 +55,8 @@ impl NewSessionBuilder { self } - pub(crate) fn session_token(mut self, session_token: String) -> Self { - self.session_token = Some(session_token); + pub(crate) fn session_id(mut self, session_id: String) -> Self { + self.session_id = Some(session_id); self } @@ -64,8 +68,16 @@ impl NewSessionBuilder { pub(crate) fn build(self) -> NewSession { NewSession { discord_uid: self.discord_uid.expect("discord_uid not set"), - session_token: self.session_token.expect("session_token not set"), + session_id: self.session_id.expect("session_id not set"), expires_at: self.expires_at.expect("expires_at not set"), } } + + pub(crate) fn build_update(self) -> UpdateSession { + UpdateSession { + discord_uid: self.discord_uid, + session_id: self.session_id, + expires_at: self.expires_at, + } + } } diff --git a/src/database/wrappers/sessions/schema.rs b/src/database/wrappers/sessions/schema.rs index cc6d22f..7a0dc8e 100644 --- a/src/database/wrappers/sessions/schema.rs +++ b/src/database/wrappers/sessions/schema.rs @@ -1,6 +1,6 @@ diesel::table! { - sessions (session_token) { - session_token -> Text, + sessions (session_id) { + session_id -> Text, discord_uid -> Text, expires_at -> Timestamp, } diff --git a/src/lib.rs b/src/lib.rs index 30da20b..f9a096f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,7 @@ +#![allow(dead_code)] + mod cipher; +mod config; mod database; mod oauth; mod utils; @@ -12,20 +15,12 @@ use rocket::Request; #[macro_use] extern crate rocket_sync_db_pools; -extern crate rocket_cors; -use rocket_cors::{Cors, CorsOptions}; - +use crate::config::Config; use dotenvy::dotenv; #[database("roops")] pub(crate) struct DbConn(diesel::PgConnection); -fn cors_fairing() -> Cors { - CorsOptions::default() - .to_cors() - .expect("Cors fairing cannot be created") -} - #[catch(default)] fn default(status: Status, _req: &Request<'_>) -> Value { json!({ @@ -37,8 +32,9 @@ fn default(status: Status, _req: &Request<'_>) -> Value { #[launch] pub fn rocket() -> _ { dotenv().ok(); + let cfg = Config::load(); rocket::build() - .attach(cors_fairing()) + .attach(cfg.cors_fairing()) .attach(DbConn::fairing()) .mount("/v1/oauth2", oauth::routes::routes()) .register("/", catchers![default]) diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs index 3476fc5..4a8af42 100644 --- a/src/oauth/mod.rs +++ b/src/oauth/mod.rs @@ -3,3 +3,4 @@ mod types; mod utils; const OAUTH_STATE_COOKIE_NAME: &str = "oauth_state"; +const SESSION_COOKIE_NAME: &str = "sessionid"; diff --git a/src/oauth/routes/discord.rs b/src/oauth/routes/discord.rs index 05fcde6..d3d3841 100644 --- a/src/oauth/routes/discord.rs +++ b/src/oauth/routes/discord.rs @@ -1,11 +1,17 @@ -use crate::database::wrappers::discord_connections::models::NewDiscordConnection; +use crate::cipher::{decrypt, EncryptedData}; +use crate::database::wrappers::discord_connections::models::{ + NewDiscordConnection, UpdateDiscordConnection, +}; use crate::database::wrappers::discord_connections::DiscordConnectionsDb; use crate::database::wrappers::sessions::models::NewSession; use crate::database::wrappers::sessions::SessionsDb; -use crate::oauth::types::discord::{DiscordOAuthScopeSet, DiscordOAuthScopes}; +use crate::oauth::routes::SessionId; +use crate::oauth::types::discord::{ + DiscordAuthorizedUserResponse, DiscordOAuthScopeSet, DiscordOAuthScopes, +}; use crate::oauth::types::OAuthCallback; use crate::oauth::utils::discord::{ - construct_discord_oauth_url, exchange_code, DISCORD_AUTHORIZED_USER_ENDPOINT, + construct_discord_oauth_url, exchange_code, refresh_token, DISCORD_AUTHORIZED_USER_ENDPOINT, }; use crate::oauth::utils::{generate_session, generate_state}; use crate::oauth::OAUTH_STATE_COOKIE_NAME; @@ -13,7 +19,6 @@ use crate::DbConn; use diesel::Connection; use rocket::http::{Cookie, CookieJar, SameSite, Status}; use rocket::response::Redirect; -use rocket::serde::Deserialize; use rocket::time::Duration; #[get("/initiate/discord?")] @@ -25,7 +30,6 @@ pub(super) fn discord_oauth_initiate( let state = generate_state(); let redirect_uri = construct_discord_oauth_url(&DiscordOAuthScopes::from(&scope_set), &state); - // Set the cookie lifetime to 5 minutes let auth_cookie = Cookie::build((OAUTH_STATE_COOKIE_NAME, state)) .path("/v1/oauth2/callback/discord") .same_site(SameSite::Lax) @@ -44,30 +48,24 @@ pub(super) async fn discord_oauth_callback( conn: DbConn, ) -> Result { // Verify the state against the one that was saved in the cookie - let is_authorized = jar + let is_valid_state = jar .get_pending(OAUTH_STATE_COOKIE_NAME) .map_or(false, |cookie| cookie.value() == callback.state); - if !is_authorized { - return Err(Status::Unauthorized); + if !is_valid_state { + return Err(Status::BadRequest); } // Use the code to obtain the token - let token_response = exchange_code(&callback.code).map_err(|e| { - rocket::error!("Failed to obtain Discord token: {:?}", e); - Status::InternalServerError - })?; + let response = exchange_code(&callback.code)?; // Fetch the authorized user let authorized_user = minreq::get(DISCORD_AUTHORIZED_USER_ENDPOINT) - .with_header( - "Authorization", - format!("Bearer {}", token_response.access_token), - ) + .with_header("Authorization", format!("Bearer {}", response.access_token)) .send() .map_err(|e| { rocket::error!("Failed to obtain Discord user info: {:?}", e); - Status::InternalServerError + Status::BadGateway })? .json::() .map(|r| r.user) @@ -80,22 +78,22 @@ pub(super) async fn discord_oauth_callback( // Parse the user info let discord_uid = authorized_user.expect("Missing 'identify' scope").id; - let session_token = generate_session(); + let session = generate_session(); let token_expires_at = - chrono::Utc::now().naive_utc() + chrono::Duration::seconds(token_response.expires_in); + chrono::Utc::now().naive_utc() + chrono::Duration::seconds(response.expires_in); - let session = NewSession::build() + let new_session = NewSession::build() .discord_uid(discord_uid.clone()) - .session_token(session_token.clone()) + .session_id(session.session_id.clone()) .expires_at(token_expires_at) .build(); let discord_connection = NewDiscordConnection::build() .uid(discord_uid) - .access_token(token_response.access_token) + .access_token(response.access_token) .expires_at(token_expires_at) - .refresh_token(token_response.refresh_token) - .scope(token_response.scope) + .refresh_token(response.refresh_token) + .scope(response.scope) .build() .map_err(|e| { rocket::error!("Failed to build Discord connection: {:?}", e); @@ -103,16 +101,16 @@ pub(super) async fn discord_oauth_callback( })?; // Insert the session and Discord connection into the database - conn.run(move |conn| { + conn.run(|conn| { conn.transaction(|conn| { rocket::info!("Inserting Discord connection: {:?}", discord_connection); - DiscordConnectionsDb::insert_one(&discord_connection, conn).map_err(|e| { + DiscordConnectionsDb::insert_one(discord_connection, conn).map_err(|e| { rocket::error!("Failed to insert Discord connection: {:?}", e); e })?; - rocket::info!("Inserting session: {:?}", session); - SessionsDb::insert_one(&session, conn).map_err(|e| { + rocket::info!("Inserting session: {:?}", new_session); + SessionsDb::insert_one(new_session, conn).map_err(|e| { rocket::error!("Failed to insert session: {:?}", e); e })?; @@ -123,68 +121,75 @@ pub(super) async fn discord_oauth_callback( .await .map_err(|_| Status::InternalServerError)?; - rocket::info!("Storing session cookie: {}", &session_token); + rocket::info!("Storing session cookie: {}", session.session_id); // Save the current session as a cookie // This will be used to access the token through the database - let session_cookie = Cookie::build(("session", session_token)) - .path("/") - .same_site(SameSite::Strict) - .max_age(Duration::seconds(token_response.expires_in)); - jar.add_private(session_cookie); + jar.add_private(session.cookie); Ok(Redirect::to(uri!("/"))) } -/// The response from the Discord OAuth2 @me endpoint -#[derive(Deserialize, Debug)] -#[serde(crate = "rocket::serde")] -struct DiscordAuthorizedUserResponse { - /// The user who has authorized - /// - /// ⚠️ Requires the `identify` scope - user: Option, -} +#[post("/refresh-token/discord")] +pub(super) async fn discord_refresh_token( + conn: DbConn, + session_id: SessionId, +) -> Result { + let session_id = session_id.into_inner(); + rocket::info!("Session ID: {:?}", session_id); + + // Fetch the Discord UID from the session token + let discord_uid = conn + .run(|conn| SessionsDb::find_one(session_id, conn)) + .await + .map(|session| session.discord_uid) + .ok_or(Status::Unauthorized)?; + rocket::info!("Discord UID: {}", discord_uid); + + // Fetch the Discord connection from the database + let discord_connection = conn + .run(|conn| DiscordConnectionsDb::find_one(discord_uid, conn)) + .await + .ok_or(Status::InternalServerError)?; + rocket::info!("Discord connection: {:?}", discord_connection); + + // Decrypt the refresh token + let discord_refresh_token = decrypt(&EncryptedData { + data: discord_connection.refresh_token, + nonce: discord_connection.refresh_token_nonce, + }) + .map_err(|e| { + rocket::error!("Failed to decrypt refresh token: {:?}", e); + Status::InternalServerError + })?; -/// The user object represents a user profile on Discord. -/// [Reference](https://discord.com/developers/docs/resources/user#user-object) -#[derive(Deserialize, Debug)] -#[serde(crate = "rocket::serde")] -struct DiscordUser { - /// The user's ID - id: String, - /// The user's username, not unique across the platform - username: String, - /// The user's 4-digit discord-tag - discriminator: String, - /// The user's display name, if it is set. For bots, this is the application name - global_name: Option, - /// The user's [avatar hash](https://discord.com/developers/docs/reference#image-formatting) - avatar: Option, - /// Whether the user belongs to an OAuth2 application - bot: Option, - /// Whether the user is an Official Discord System user (part of the urgent message system) - system: Option, - /// Whether the user has two factor enabled on their account - mfa_enabled: Option, - /// The user's [banner hash](https://discord.com/developers/docs/reference#image-formatting) - banner: Option, - /// The user's banner color encoded as an integer representation of hexadecimal color code - accent_color: Option, - /// The user's chosen [language option](https://discord.com/developers/docs/reference#locales) - locale: Option, - /// Whether the email on this account has been verified - /// - /// ⚠️ Requires the `email` scope - verified: Option, - /// The user's email - /// - /// ⚠️ Requires the `email` scope - email: Option, - /// The [flags](https://discord.com/developers/docs/resources/user#user-object-user-flags) on a user's account - flags: Option, - /// The [type of Nitro subscription](https://discord.com/developers/docs/resources/user#user-object-premium-types) on a user's account - premium_type: Option, - /// The public [flags](https://discord.com/developers/docs/resources/user#user-object-user-flags) on a user's account - public_flags: Option, + // Refresh the token + let response = refresh_token(&discord_refresh_token)?; + let token_expires_at = + chrono::Utc::now().naive_utc() + chrono::Duration::seconds(response.expires_in); + + let new_discord_connection = UpdateDiscordConnection::build() + .access_token(response.access_token) + .refresh_token(response.refresh_token) + .expires_at(token_expires_at) + .build_update() + .map_err(|e| { + rocket::error!("Failed to build Discord connection update: {:?}", e); + Status::InternalServerError + })?; + + // Update the Discord connection in the database + let success = conn + .run(|conn| { + rocket::info!("Updating Discord connection: {:?}", new_discord_connection); + DiscordConnectionsDb::update_one(discord_connection.uid, new_discord_connection, conn) + .is_some() + }) + .await; + + if success { + Ok(Status::NoContent) + } else { + Err(Status::InternalServerError) + } } diff --git a/src/oauth/routes/mod.rs b/src/oauth/routes/mod.rs index 5b567fc..9039543 100644 --- a/src/oauth/routes/mod.rs +++ b/src/oauth/routes/mod.rs @@ -1,11 +1,69 @@ +use crate::database::wrappers::sessions::models::UpdateSession; +use crate::database::wrappers::sessions::SessionsDb; +use crate::oauth::types::SessionId; +use crate::oauth::utils::generate_session; +use crate::DbConn; +use rocket::http::{CookieJar, Status}; +use std::sync::Arc; + mod discord; mod roblox; +#[post("/refresh-session")] +async fn refresh_session( + jar: &CookieJar<'_>, + conn: DbConn, + session_id: SessionId, +) -> Result { + let session_id = Arc::new(session_id.into_inner()); + let session_id_find = Arc::clone(&session_id); + + rocket::info!("Session ID: {:?}", session_id); + + conn.run(move |conn| { + let Some(_) = SessionsDb::find_one(session_id_find.as_str(), conn) else { + return diesel::result::QueryResult::Err(diesel::result::Error::NotFound); + }; + + diesel::result::QueryResult::Ok(()) + }) + .await + .map_err(|e| { + rocket::warn!("Failed to find session: {:?}", e); + Status::Unauthorized + })?; + + let session = generate_session(); + let updated_session = UpdateSession::build() + .session_id(session.session_id.clone()) + .expires_at(session.expires_at) + .build_update(); + + conn.run(move |conn| { + let Some(session) = SessionsDb::update_one(session_id.as_str(), updated_session, conn) + else { + return diesel::result::QueryResult::Err(diesel::result::Error::NotFound); + }; + + diesel::result::QueryResult::Ok(session) + }) + .await + .map(|session| rocket::info!("Updated session: {:?}", session)) + .map_err(|e| { + rocket::warn!("Failed to update session: {:?}", e); + Status::InternalServerError + })?; + + jar.add_private(session.cookie); + + Ok(Status::NoContent) +} + pub(crate) fn routes() -> Vec { routes![ discord::discord_oauth_initiate, discord::discord_oauth_callback, - roblox::roblox_oauth_initiate, - roblox::roblox_oauth_callback + discord::discord_refresh_token, + refresh_session, ] } diff --git a/src/oauth/types/discord.rs b/src/oauth/types/discord.rs index 0eb6133..cee1de3 100644 --- a/src/oauth/types/discord.rs +++ b/src/oauth/types/discord.rs @@ -14,7 +14,7 @@ pub(crate) enum DiscordOAuthScope { Identify, Guilds, GuildsChannelsRead, - RPC, + Rpc, RPCVoiceWrite, RPCScreenshareRead, WebhookIncoming, @@ -55,7 +55,7 @@ pub(crate) enum DiscordOAuthScope { RPCVideoWrite, } -#[derive(Debug, Serialize)] +#[derive(Serialize)] #[serde(crate = "rocket::serde")] pub(crate) struct DiscordAuthorizationCodeRequestBody<'a> { client_id: &'a str, @@ -65,7 +65,16 @@ pub(crate) struct DiscordAuthorizationCodeRequestBody<'a> { redirect_uri: &'a str, } -#[derive(Debug, Deserialize)] +#[derive(Serialize)] +#[serde(crate = "rocket::serde")] +pub(crate) struct DiscordTokenRefreshBody<'a> { + grant_type: &'a str, + refresh_token: &'a str, + client_id: &'a str, + client_secret: String, +} + +#[derive(Deserialize)] #[serde(crate = "rocket::serde")] pub(crate) struct DiscordAuthorizationCodeResponse { pub(crate) access_token: String, @@ -75,13 +84,66 @@ pub(crate) struct DiscordAuthorizationCodeResponse { pub(crate) scope: String, } +/// The response from the Discord OAuth2 @me endpoint +#[derive(Deserialize, Debug)] +#[serde(crate = "rocket::serde")] +pub(crate) struct DiscordAuthorizedUserResponse { + /// The user who has authorized + /// + /// ⚠️ Requires the `identify` scope + pub(crate) user: Option, +} + +/// The user object represents a user profile on Discord. +/// [Reference](https://discord.com/developers/docs/resources/user#user-object) +#[derive(Deserialize, Debug)] +#[serde(crate = "rocket::serde")] +pub(crate) struct DiscordUser { + /// The user's ID + pub(crate) id: String, + /// The user's username, not unique across the platform + pub(crate) username: String, + /// The user's 4-digit discord-tag + pub(crate) discriminator: String, + /// The user's display name, if it is set. For bots, this is the application name + pub(crate) global_name: Option, + /// The user's [avatar hash](https://discord.com/developers/docs/reference#image-formatting) + pub(crate) avatar: Option, + /// Whether the user belongs to an OAuth2 application + pub(crate) bot: Option, + /// Whether the user is an Official Discord System user (part of the urgent message system) + pub(crate) system: Option, + /// Whether the user has two factor enabled on their account + pub(crate) mfa_enabled: Option, + /// The user's [banner hash](https://discord.com/developers/docs/reference#image-formatting) + pub(crate) banner: Option, + /// The user's banner color encoded as an integer representation of hexadecimal color code + pub(crate) accent_color: Option, + /// The user's chosen [language option](https://discord.com/developers/docs/reference#locales) + pub(crate) locale: Option, + /// Whether the email on this account has been verified + /// + /// ⚠️ Requires the `email` scope + pub(crate) verified: Option, + /// The user's email + /// + /// ⚠️ Requires the `email` scope + pub(crate) email: Option, + /// The [flags](https://discord.com/developers/docs/resources/user#user-object-user-flags) on a user's account + pub(crate) flags: Option, + /// The [type of Nitro subscription](https://discord.com/developers/docs/resources/user#user-object-premium-types) on a user's account + pub(crate) premium_type: Option, + /// The public [flags](https://discord.com/developers/docs/resources/user#user-object-user-flags) on a user's account + pub(crate) public_flags: Option, +} + impl From<&DiscordOAuthScope> for String { fn from(value: &DiscordOAuthScope) -> Self { match value { DiscordOAuthScope::Identify => String::from("identify"), DiscordOAuthScope::Guilds => String::from("guilds"), DiscordOAuthScope::GuildsChannelsRead => String::from("guilds.channels.read"), - DiscordOAuthScope::RPC => String::from("rpc"), + DiscordOAuthScope::Rpc => String::from("rpc"), DiscordOAuthScope::RPCVoiceWrite => String::from("rpc.voice.write"), DiscordOAuthScope::RPCScreenshareRead => String::from("rpc.screenshare.read"), DiscordOAuthScope::WebhookIncoming => String::from("webhook.incoming"), @@ -200,3 +262,24 @@ impl<'a> DiscordAuthorizationCodeRequestBody<'a> { ]) } } + +impl<'a> DiscordTokenRefreshBody<'a> { + pub(crate) fn new(refresh_token: &'a str) -> Self { + Self { + grant_type: "refresh_token", + refresh_token, + client_id: DISCORD_OAUTH_APP_CLIENT_ID, + client_secret: std::env::var("DISCORD_CLIENT_SECRET") + .expect("DISCORD_CLIENT_SECRET must be set"), + } + } + + pub(crate) fn as_query_params(&self) -> String { + url!([ + ("grant_type", self.grant_type), + ("refresh_token", self.refresh_token), + ("client_id", self.client_id), + ("client_secret", self.client_secret) + ]) + } +} diff --git a/src/oauth/types/mod.rs b/src/oauth/types/mod.rs index 52c0173..efd85b3 100644 --- a/src/oauth/types/mod.rs +++ b/src/oauth/types/mod.rs @@ -1,8 +1,44 @@ +use crate::oauth::SESSION_COOKIE_NAME; +use rocket::http::{Cookie, Status}; +use rocket::request::{FromRequest, Outcome}; +use rocket::Request; + pub(super) mod discord; pub(super) mod roblox; #[derive(Debug, FromForm)] pub(super) struct OAuthCallback { - pub(crate) code: String, - pub(crate) state: String, + pub(super) code: String, + pub(super) state: String, +} + +pub(super) struct GeneratedSession<'a> { + pub(super) cookie: Cookie<'a>, + pub(super) session_id: String, + pub(super) expires_at: chrono::NaiveDateTime, +} + +pub(super) struct SessionId(String); + +impl SessionId { + pub(super) fn into_inner(self) -> String { + self.0 + } +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for SessionId { + type Error = Status; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + let session_id = request + .cookies() + .get_pending(SESSION_COOKIE_NAME) + .map(|cookie| cookie.value().to_string()); + + match session_id { + None => Outcome::Error((Status::Unauthorized, Status::Unauthorized)), + Some(session_id) => Outcome::Success(SessionId(session_id)), + } + } } diff --git a/src/oauth/utils/discord.rs b/src/oauth/utils/discord.rs index 9765350..54a8dc1 100644 --- a/src/oauth/utils/discord.rs +++ b/src/oauth/utils/discord.rs @@ -1,7 +1,9 @@ use crate::oauth::types::discord::{ DiscordAuthorizationCodeRequestBody, DiscordAuthorizationCodeResponse, DiscordOAuthScopes, + DiscordTokenRefreshBody, }; use crate::url; +use rocket::http::Status; pub(crate) const DISCORD_OAUTH_REDIRECT_URI: &str = "http://localhost:8000/v1/oauth2/callback/discord"; @@ -23,12 +25,42 @@ pub(crate) fn construct_discord_oauth_url(scopes: &DiscordOAuthScopes, state: &s ) } -pub(crate) fn exchange_code(code: &str) -> Result { +pub(crate) fn exchange_code(code: &str) -> Result { let body = DiscordAuthorizationCodeRequestBody::new(code); let response = minreq::post(DISCORD_TOKEN_ENDPOINT_URL) .with_header("Content-Type", "application/x-www-form-urlencoded") .with_body(body.as_query_params()) - .send()?; + .send() + .map_err(|e| { + rocket::error!("Failed to exchange Discord code for token: {:?}", e); + Status::BadGateway + })?; - response.json::() + response + .json::() + .map_err(|e| { + rocket::error!("Failed to parse Discord token response: {:?}", e); + Status::InternalServerError + }) +} + +pub(crate) fn refresh_token( + refresh_token: &str, +) -> Result { + let body = DiscordTokenRefreshBody::new(refresh_token); + let response = minreq::post(DISCORD_TOKEN_ENDPOINT_URL) + .with_header("Content-Type", "application/x-www-form-urlencoded") + .with_body(body.as_query_params()) + .send() + .map_err(|e| { + rocket::error!("Failed to refresh Discord token: {:?}", e); + Status::BadGateway + })?; + + response + .json::() + .map_err(|e| { + rocket::error!("Failed to parse Discord token refresh response: {:?}", e); + Status::InternalServerError + }) } diff --git a/src/oauth/utils/mod.rs b/src/oauth/utils/mod.rs index e126af0..83759b7 100644 --- a/src/oauth/utils/mod.rs +++ b/src/oauth/utils/mod.rs @@ -3,7 +3,10 @@ pub(super) mod pixy; pub(super) mod roblox; use self::pixy::generate_random_base64_url_safe_no_pad_string; +use crate::oauth::types::GeneratedSession; +use crate::oauth::SESSION_COOKIE_NAME; use rand::{thread_rng, Rng}; +use rocket::http::Cookie; pub(super) fn generate_state() -> String { let number_of_bytes = thread_rng().gen_range(10..=20); @@ -11,8 +14,21 @@ pub(super) fn generate_state() -> String { generate_random_base64_url_safe_no_pad_string(number_of_bytes) } -pub(super) fn generate_session() -> String { +fn generate_session_id() -> String { let number_of_bytes = thread_rng().gen_range(32..=64); generate_random_base64_url_safe_no_pad_string(number_of_bytes) } + +pub(super) fn generate_session() -> GeneratedSession<'static> { + let session_id = generate_session_id(); + let expires_at = chrono::Utc::now() + chrono::Duration::days(30); + + GeneratedSession { + cookie: Cookie::build((SESSION_COOKIE_NAME, session_id.clone())) + .max_age(rocket::time::Duration::days(30)) + .build(), + session_id, + expires_at: expires_at.naive_utc(), + } +} diff --git a/src/utils.rs b/src/utils.rs index 7361fcc..41849b6 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -10,7 +10,7 @@ /// /// ``` /// # use roops_internal_api::url; -/// +/// # /// let url = url!("https://example.com", [("key1", "value1"), ("key2", "value2")]); /// assert_eq!(url, "https://example.com?key1=value1&key2=value2"); /// From 78b2dc1b9de325935416603516a1cc334e5690d7 Mon Sep 17 00:00:00 2001 From: nick <59822256+Archasion@users.noreply.github.com> Date: Sun, 1 Dec 2024 19:19:07 +0000 Subject: [PATCH 06/14] refactor(cfg): Remove unnecessary config file for diesel --- diesel.toml | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 diesel.toml diff --git a/diesel.toml b/diesel.toml deleted file mode 100644 index 19cfa66..0000000 --- a/diesel.toml +++ /dev/null @@ -1,2 +0,0 @@ -[migrations_directory] -dir = "migrations" From a0b8a03940b8063b9a8b90b720114dd852741720 Mon Sep 17 00:00:00 2001 From: nick <59822256+Archasion@users.noreply.github.com> Date: Sun, 1 Dec 2024 20:51:43 +0000 Subject: [PATCH 07/14] chore: Implement custom errors --- src/database/wrappers/account_links/mod.rs | 1 - .../wrappers/discord_connections/mod.rs | 3 - src/database/wrappers/sessions/mod.rs | 1 - src/lib.rs | 1 + src/oauth/routes/discord.rs | 107 ++++++++---------- src/oauth/routes/mod.rs | 24 +--- src/oauth/routes/roblox.rs | 24 ++-- src/oauth/types/discord.rs | 7 +- src/oauth/types/mod.rs | 8 +- src/oauth/types/roblox.rs | 7 +- src/oauth/utils/discord.rs | 52 +++++---- src/response.rs | 91 +++++++++++++++ 12 files changed, 202 insertions(+), 124 deletions(-) create mode 100644 src/response.rs diff --git a/src/database/wrappers/account_links/mod.rs b/src/database/wrappers/account_links/mod.rs index 9996b3c..64a189c 100644 --- a/src/database/wrappers/account_links/mod.rs +++ b/src/database/wrappers/account_links/mod.rs @@ -36,7 +36,6 @@ impl AccountLinksDb { // Already exists if Self::find_one(pk, conn).is_some() { - rocket::warn!("Account link already exists: {:?}", new_account_link); return Ok(None); } diff --git a/src/database/wrappers/discord_connections/mod.rs b/src/database/wrappers/discord_connections/mod.rs index f4b4437..ee2b066 100644 --- a/src/database/wrappers/discord_connections/mod.rs +++ b/src/database/wrappers/discord_connections/mod.rs @@ -21,7 +21,6 @@ impl DiscordConnectionsDb { // Already exists if Self::find_one(&new_conn.uid, conn).is_some() { - rocket::warn!("Discord connection already exists: {:?}", new_conn.uid); return Ok(None); } @@ -55,7 +54,6 @@ impl DiscordConnectionsDb { diesel::delete(discord_connections) .filter(uid.eq(pk_uid)) .execute(conn) - .map_err(|e| rocket::error!("Failed to delete Discord connection: {:?}", e)) .is_ok() } @@ -78,7 +76,6 @@ impl DiscordConnectionsDb { .set(new_conn) .returning(DiscordConnection::as_returning()) .get_result(conn) - .map_err(|e| rocket::error!("Failed to update Discord connection: {:?}", e)) .ok() } } diff --git a/src/database/wrappers/sessions/mod.rs b/src/database/wrappers/sessions/mod.rs index a305a62..6ed595a 100644 --- a/src/database/wrappers/sessions/mod.rs +++ b/src/database/wrappers/sessions/mod.rs @@ -21,7 +21,6 @@ impl SessionsDb { // Already exists if Self::find_one(&new_session.session_id, conn).is_some() { - rocket::warn!("Session already exists: {:?}", new_session.session_id); return Ok(None); } diff --git a/src/lib.rs b/src/lib.rs index f9a096f..fb3a92c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ mod cipher; mod config; mod database; mod oauth; +mod response; mod utils; #[macro_use] diff --git a/src/oauth/routes/discord.rs b/src/oauth/routes/discord.rs index d3d3841..3b489de 100644 --- a/src/oauth/routes/discord.rs +++ b/src/oauth/routes/discord.rs @@ -6,15 +6,14 @@ use crate::database::wrappers::discord_connections::DiscordConnectionsDb; use crate::database::wrappers::sessions::models::NewSession; use crate::database::wrappers::sessions::SessionsDb; use crate::oauth::routes::SessionId; -use crate::oauth::types::discord::{ - DiscordAuthorizedUserResponse, DiscordOAuthScopeSet, DiscordOAuthScopes, -}; +use crate::oauth::types::discord::{DiscordOAuthScopeSet, DiscordOAuthScopes}; use crate::oauth::types::OAuthCallback; use crate::oauth::utils::discord::{ - construct_discord_oauth_url, exchange_code, refresh_token, DISCORD_AUTHORIZED_USER_ENDPOINT, + construct_discord_oauth_url, exchange_code, get_authorized_user, refresh_token, }; use crate::oauth::utils::{generate_session, generate_state}; use crate::oauth::OAUTH_STATE_COOKIE_NAME; +use crate::response::{ApiError, ApiResponse, ApiResult}; use crate::DbConn; use diesel::Connection; use rocket::http::{Cookie, CookieJar, SameSite, Status}; @@ -25,8 +24,10 @@ use rocket::time::Duration; pub(super) fn discord_oauth_initiate( scope_set: String, jar: &CookieJar<'_>, -) -> Result { - let scope_set = DiscordOAuthScopeSet::try_from(scope_set)?; +) -> ApiResult { + let scope_set = DiscordOAuthScopeSet::try_from(scope_set) + .map_err(|e| ApiError::message(Status::BadRequest, e))?; + let state = generate_state(); let redirect_uri = construct_discord_oauth_url(&DiscordOAuthScopes::from(&scope_set), &state); @@ -46,38 +47,25 @@ pub(super) async fn discord_oauth_callback( callback: OAuthCallback, jar: &CookieJar<'_>, conn: DbConn, -) -> Result { +) -> ApiResult { // Verify the state against the one that was saved in the cookie let is_valid_state = jar .get_pending(OAUTH_STATE_COOKIE_NAME) .map_or(false, |cookie| cookie.value() == callback.state); if !is_valid_state { - return Err(Status::BadRequest); + return Err(ApiError::message(Status::BadRequest, "Invalid state")); } // Use the code to obtain the token let response = exchange_code(&callback.code)?; - // Fetch the authorized user - let authorized_user = minreq::get(DISCORD_AUTHORIZED_USER_ENDPOINT) - .with_header("Authorization", format!("Bearer {}", response.access_token)) - .send() - .map_err(|e| { - rocket::error!("Failed to obtain Discord user info: {:?}", e); - Status::BadGateway - })? - .json::() - .map(|r| r.user) - .map_err(|e| { - rocket::error!("Failed to parse Discord user info: {}", e); - Status::InternalServerError - })?; - - rocket::info!("Discord user info: {:?}", authorized_user); + let authorized_user = get_authorized_user(&response.access_token)?; // Parse the user info - let discord_uid = authorized_user.expect("Missing 'identify' scope").id; + let discord_uid = authorized_user + .expect("Failed to unwrap user, missing 'identify' scope") + .id; let session = generate_session(); let token_expires_at = chrono::Utc::now().naive_utc() + chrono::Duration::seconds(response.expires_in); @@ -95,33 +83,29 @@ pub(super) async fn discord_oauth_callback( .refresh_token(response.refresh_token) .scope(response.scope) .build() - .map_err(|e| { - rocket::error!("Failed to build Discord connection: {:?}", e); - Status::InternalServerError + .map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to build Discord connection", + ) })?; // Insert the session and Discord connection into the database conn.run(|conn| { conn.transaction(|conn| { - rocket::info!("Inserting Discord connection: {:?}", discord_connection); - DiscordConnectionsDb::insert_one(discord_connection, conn).map_err(|e| { - rocket::error!("Failed to insert Discord connection: {:?}", e); - e - })?; - - rocket::info!("Inserting session: {:?}", new_session); - SessionsDb::insert_one(new_session, conn).map_err(|e| { - rocket::error!("Failed to insert session: {:?}", e); - e - })?; + DiscordConnectionsDb::insert_one(discord_connection, conn)?; + SessionsDb::insert_one(new_session, conn)?; diesel::result::QueryResult::Ok(()) }) }) .await - .map_err(|_| Status::InternalServerError)?; - - rocket::info!("Storing session cookie: {}", session.session_id); + .map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to insert session or Discord connection into database", + ) + })?; // Save the current session as a cookie // This will be used to access the token through the database @@ -131,36 +115,35 @@ pub(super) async fn discord_oauth_callback( } #[post("/refresh-token/discord")] -pub(super) async fn discord_refresh_token( - conn: DbConn, - session_id: SessionId, -) -> Result { +pub(super) async fn discord_refresh_token(conn: DbConn, session_id: SessionId) -> ApiResult { let session_id = session_id.into_inner(); - rocket::info!("Session ID: {:?}", session_id); // Fetch the Discord UID from the session token let discord_uid = conn .run(|conn| SessionsDb::find_one(session_id, conn)) .await .map(|session| session.discord_uid) - .ok_or(Status::Unauthorized)?; - rocket::info!("Discord UID: {}", discord_uid); + .ok_or(ApiError::message(Status::Unauthorized, "Session not found"))?; // Fetch the Discord connection from the database let discord_connection = conn .run(|conn| DiscordConnectionsDb::find_one(discord_uid, conn)) .await - .ok_or(Status::InternalServerError)?; - rocket::info!("Discord connection: {:?}", discord_connection); + .ok_or(ApiError::message( + Status::InternalServerError, + "Discord connection not found", + ))?; // Decrypt the refresh token let discord_refresh_token = decrypt(&EncryptedData { data: discord_connection.refresh_token, nonce: discord_connection.refresh_token_nonce, }) - .map_err(|e| { - rocket::error!("Failed to decrypt refresh token: {:?}", e); - Status::InternalServerError + .map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to decrypt refresh token", + ) })?; // Refresh the token @@ -173,23 +156,27 @@ pub(super) async fn discord_refresh_token( .refresh_token(response.refresh_token) .expires_at(token_expires_at) .build_update() - .map_err(|e| { - rocket::error!("Failed to build Discord connection update: {:?}", e); - Status::InternalServerError + .map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to build Discord connection update", + ) })?; // Update the Discord connection in the database let success = conn .run(|conn| { - rocket::info!("Updating Discord connection: {:?}", new_discord_connection); DiscordConnectionsDb::update_one(discord_connection.uid, new_discord_connection, conn) .is_some() }) .await; if success { - Ok(Status::NoContent) + Ok(ApiResponse::status(Status::NoContent)) } else { - Err(Status::InternalServerError) + Err(ApiError::message( + Status::InternalServerError, + "Failed to update Discord connection", + )) } } diff --git a/src/oauth/routes/mod.rs b/src/oauth/routes/mod.rs index 9039543..f29eea3 100644 --- a/src/oauth/routes/mod.rs +++ b/src/oauth/routes/mod.rs @@ -2,6 +2,7 @@ use crate::database::wrappers::sessions::models::UpdateSession; use crate::database::wrappers::sessions::SessionsDb; use crate::oauth::types::SessionId; use crate::oauth::utils::generate_session; +use crate::response::{ApiError, ApiResponse, ApiResult}; use crate::DbConn; use rocket::http::{CookieJar, Status}; use std::sync::Arc; @@ -10,28 +11,18 @@ mod discord; mod roblox; #[post("/refresh-session")] -async fn refresh_session( - jar: &CookieJar<'_>, - conn: DbConn, - session_id: SessionId, -) -> Result { +async fn refresh_session(jar: &CookieJar<'_>, conn: DbConn, session_id: SessionId) -> ApiResult { let session_id = Arc::new(session_id.into_inner()); let session_id_find = Arc::clone(&session_id); - rocket::info!("Session ID: {:?}", session_id); - conn.run(move |conn| { let Some(_) = SessionsDb::find_one(session_id_find.as_str(), conn) else { return diesel::result::QueryResult::Err(diesel::result::Error::NotFound); }; - diesel::result::QueryResult::Ok(()) }) .await - .map_err(|e| { - rocket::warn!("Failed to find session: {:?}", e); - Status::Unauthorized - })?; + .map_err(|_| ApiError::message(Status::Unauthorized, "Session not found"))?; let session = generate_session(); let updated_session = UpdateSession::build() @@ -44,19 +35,14 @@ async fn refresh_session( else { return diesel::result::QueryResult::Err(diesel::result::Error::NotFound); }; - diesel::result::QueryResult::Ok(session) }) .await - .map(|session| rocket::info!("Updated session: {:?}", session)) - .map_err(|e| { - rocket::warn!("Failed to update session: {:?}", e); - Status::InternalServerError - })?; + .map_err(|_| ApiError::message(Status::InternalServerError, "Failed to update session"))?; jar.add_private(session.cookie); - Ok(Status::NoContent) + Ok(ApiResponse::status(Status::NoContent)) } pub(crate) fn routes() -> Vec { diff --git a/src/oauth/routes/roblox.rs b/src/oauth/routes/roblox.rs index 02ad18a..f73a230 100644 --- a/src/oauth/routes/roblox.rs +++ b/src/oauth/routes/roblox.rs @@ -4,6 +4,7 @@ use crate::oauth::utils::generate_state; use crate::oauth::utils::pixy::Pixy; use crate::oauth::utils::roblox::construct_roblox_oauth_url; use crate::oauth::OAUTH_STATE_COOKIE_NAME; +use crate::response::{ApiError, ApiResult}; use rocket::http::{Cookie, CookieJar, Status}; use rocket::response::Redirect; use rocket::time::Duration; @@ -14,19 +15,22 @@ const OAUTH_VERIFIER_COOKIE_NAME: &str = "oauth_code_verifier"; pub(super) fn roblox_oauth_callback( callback: OAuthCallback, jar: &CookieJar<'_>, -) -> Result { +) -> ApiResult { // Verify the state against the one that was saved in the cookie - let is_authorized = jar + let is_valid_state = jar .get_pending(OAUTH_STATE_COOKIE_NAME) .map_or(false, |cookie| cookie.value() == callback.state); - if !is_authorized { - return Err(Status::Unauthorized); + if !is_valid_state { + return Err(ApiError::message(Status::BadRequest, "Invalid state")); } // Use verifier to obtain token let Some(_verifier_cookie) = jar.get_pending(OAUTH_VERIFIER_COOKIE_NAME) else { - return Err(Status::InternalServerError); + return Err(ApiError::message( + Status::InternalServerError, + "Missing verifier cookie", + )); }; // TODO: Obtain token through https://apis.roblox.com/oauth/v1/token @@ -38,13 +42,11 @@ pub(super) fn roblox_oauth_callback( } #[get("/initiate/roblox?")] -pub(super) fn roblox_oauth_initiate( - scope_set: String, - jar: &CookieJar<'_>, -) -> Result { - let scope_set = RobloxOAuthScopeSet::try_from(scope_set)?; - let pixy = Pixy::new(); +pub(super) fn roblox_oauth_initiate(scope_set: String, jar: &CookieJar<'_>) -> ApiResult { + let scope_set = RobloxOAuthScopeSet::try_from(scope_set) + .map_err(|e| ApiError::message(Status::BadRequest, e))?; + let pixy = Pixy::new(); let challenge = pixy.get_challenge(); let verifier = pixy.expose_verifier().to_string(); let state = generate_state(); diff --git a/src/oauth/types/discord.rs b/src/oauth/types/discord.rs index cee1de3..788bf9f 100644 --- a/src/oauth/types/discord.rs +++ b/src/oauth/types/discord.rs @@ -1,6 +1,5 @@ use crate::oauth::utils::discord::{DISCORD_OAUTH_APP_CLIENT_ID, DISCORD_OAUTH_REDIRECT_URI}; use crate::url; -use rocket::http::Status; use rocket::serde::{Deserialize, Serialize}; use std::fmt::{Display, Formatter, Result as FmtResult}; @@ -230,12 +229,12 @@ impl From<&DiscordOAuthScopeSet> for DiscordOAuthScopes { } impl TryFrom for DiscordOAuthScopeSet { - type Error = Status; + type Error = String; - fn try_from(value: String) -> Result { + fn try_from(value: String) -> Result { match value.as_str() { "verification" => Ok(DiscordOAuthScopeSet::Verification), - _ => Err(Status::BadRequest), + _ => Err(format!("Invalid scope set: {}", value)), } } } diff --git a/src/oauth/types/mod.rs b/src/oauth/types/mod.rs index efd85b3..12ebd15 100644 --- a/src/oauth/types/mod.rs +++ b/src/oauth/types/mod.rs @@ -1,4 +1,5 @@ use crate::oauth::SESSION_COOKIE_NAME; +use crate::response::ApiError; use rocket::http::{Cookie, Status}; use rocket::request::{FromRequest, Outcome}; use rocket::Request; @@ -28,7 +29,7 @@ impl SessionId { #[rocket::async_trait] impl<'r> FromRequest<'r> for SessionId { - type Error = Status; + type Error = ApiError; async fn from_request(request: &'r Request<'_>) -> Outcome { let session_id = request @@ -37,7 +38,10 @@ impl<'r> FromRequest<'r> for SessionId { .map(|cookie| cookie.value().to_string()); match session_id { - None => Outcome::Error((Status::Unauthorized, Status::Unauthorized)), + None => { + let error = ApiError::message(Status::Unauthorized, "Missing sessionid cookie"); + Outcome::Error((Status::Unauthorized, error)) + } Some(session_id) => Outcome::Success(SessionId(session_id)), } } diff --git a/src/oauth/types/roblox.rs b/src/oauth/types/roblox.rs index 2d22682..dda40e6 100644 --- a/src/oauth/types/roblox.rs +++ b/src/oauth/types/roblox.rs @@ -1,4 +1,3 @@ -use rocket::http::Status; use std::fmt::{Display, Formatter, Result as FmtResult}; pub(crate) enum RobloxOAuthScopeSet { @@ -134,12 +133,12 @@ impl From<&RobloxOAuthScopeSet> for RobloxOAuthScopes { } impl TryFrom for RobloxOAuthScopeSet { - type Error = Status; + type Error = String; - fn try_from(value: String) -> Result { + fn try_from(value: String) -> Result { match value.as_str() { "verification" => Ok(RobloxOAuthScopeSet::Verification), - _ => Err(Status::BadRequest), + _ => Err(format!("Invalid scope set: {}", value)), } } } diff --git a/src/oauth/utils/discord.rs b/src/oauth/utils/discord.rs index 54a8dc1..a15fda2 100644 --- a/src/oauth/utils/discord.rs +++ b/src/oauth/utils/discord.rs @@ -1,14 +1,15 @@ use crate::oauth::types::discord::{ - DiscordAuthorizationCodeRequestBody, DiscordAuthorizationCodeResponse, DiscordOAuthScopes, - DiscordTokenRefreshBody, + DiscordAuthorizationCodeRequestBody, DiscordAuthorizationCodeResponse, + DiscordAuthorizedUserResponse, DiscordOAuthScopes, DiscordTokenRefreshBody, DiscordUser, }; +use crate::response::ApiError; use crate::url; use rocket::http::Status; pub(crate) const DISCORD_OAUTH_REDIRECT_URI: &str = "http://localhost:8000/v1/oauth2/callback/discord"; pub(crate) const DISCORD_OAUTH_APP_CLIENT_ID: &str = "1300538596348133499"; -pub(crate) const DISCORD_AUTHORIZED_USER_ENDPOINT: &str = "https://discord.com/api/v10/oauth2/@me"; +const DISCORD_AUTHORIZED_USER_ENDPOINT: &str = "https://discord.com/api/v10/oauth2/@me"; const DISCORD_OAUTH_ENDPOINT_URL: &str = "https://discord.com/oauth2/authorize"; const DISCORD_TOKEN_ENDPOINT_URL: &str = "https://discord.com/api/v10/oauth2/token"; @@ -25,42 +26,55 @@ pub(crate) fn construct_discord_oauth_url(scopes: &DiscordOAuthScopes, state: &s ) } -pub(crate) fn exchange_code(code: &str) -> Result { +pub(crate) fn exchange_code(code: &str) -> Result { let body = DiscordAuthorizationCodeRequestBody::new(code); let response = minreq::post(DISCORD_TOKEN_ENDPOINT_URL) .with_header("Content-Type", "application/x-www-form-urlencoded") .with_body(body.as_query_params()) .send() - .map_err(|e| { - rocket::error!("Failed to exchange Discord code for token: {:?}", e); - Status::BadGateway - })?; + .map_err(|_| ApiError::message(Status::BadGateway, "Failed to exchange code for token"))?; response .json::() - .map_err(|e| { - rocket::error!("Failed to parse Discord token response: {:?}", e); - Status::InternalServerError + .map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to parse token response", + ) }) } pub(crate) fn refresh_token( refresh_token: &str, -) -> Result { +) -> Result { let body = DiscordTokenRefreshBody::new(refresh_token); let response = minreq::post(DISCORD_TOKEN_ENDPOINT_URL) .with_header("Content-Type", "application/x-www-form-urlencoded") .with_body(body.as_query_params()) .send() - .map_err(|e| { - rocket::error!("Failed to refresh Discord token: {:?}", e); - Status::BadGateway - })?; + .map_err(|_| ApiError::message(Status::BadGateway, "Failed to refresh token"))?; response .json::() - .map_err(|e| { - rocket::error!("Failed to parse Discord token refresh response: {:?}", e); - Status::InternalServerError + .map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to parse token refresh response", + ) + }) +} + +pub(crate) fn get_authorized_user(access_token: &str) -> Result, ApiError> { + minreq::get(DISCORD_AUTHORIZED_USER_ENDPOINT) + .with_header("Authorization", format!("Bearer {}", access_token)) + .send() + .map_err(|_| ApiError::message(Status::BadGateway, "Failed to obtain Discord user info"))? + .json::() + .map(|r| r.user) + .map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to parse Discord user info", + ) }) } diff --git a/src/response.rs b/src/response.rs new file mode 100644 index 0000000..c273f6e --- /dev/null +++ b/src/response.rs @@ -0,0 +1,91 @@ +use rocket::http::{ContentType, Status}; +use rocket::response; +use rocket::serde::json::serde_json::json; +use rocket::serde::json::Value; +use rocket::serde::Serialize; +use rocket::{Request, Response}; + +pub(crate) type ApiResult = Result; + +#[derive(Serialize, Debug)] +#[serde(crate = "rocket::serde")] +pub(crate) struct ApiError { + code: u16, + message: Option, +} + +impl ApiError { + pub(crate) fn status(status: Status) -> Self { + Self { + code: status.code, + message: None, + } + } + + pub(crate) fn message(status: Status, message: M) -> Self + where + M: ToString, + { + Self { + code: status.code, + message: Some(message.to_string()), + } + } +} + +impl<'r> response::Responder<'r, 'r> for ApiError { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r> { + // We can assume that the status code is valid since it's coming from the Status enum + let status = Status::from_code(self.code).unwrap(); + rocket::info!(" >> Responding with error: {:?}", self); + let error = json!({ "error": self }); + + Response::build_from(error.respond_to(req)?) + .status(status) + .header(ContentType::JSON) + .ok() + } +} + +#[derive(Debug, Serialize)] +#[serde(crate = "rocket::serde")] +pub(crate) struct ApiResponse { + code: u16, + data: Option, +} + +impl ApiResponse { + pub(crate) fn status(status: Status) -> Self { + Self { + code: status.code, + data: None, + } + } + + pub(crate) fn ok(data: D) -> Self + where + D: Serialize, + { + let data = + rocket::serde::json::to_value(&data).expect("Failed to serialize data for ApiResponse"); + Self { + code: Status::Ok.code, + data: Some(data), + } + } +} + +impl<'r> response::Responder<'r, 'r> for ApiResponse { + fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r> { + // We can assume that the status code is valid since it's coming from the Status enum + let status = Status::from_code(self.code).unwrap(); + rocket::info!(" >> Responding with data: {:?}", self); + // We can assume that the data is valid JSON since it's coming from serde_json + let data = rocket::serde::json::to_value(self).unwrap(); + + Response::build_from(data.respond_to(req)?) + .status(status) + .header(ContentType::JSON) + .ok() + } +} From 0874b7f1c76dd9d962b8d1d2fa6fd25b3f6438b5 Mon Sep 17 00:00:00 2001 From: nick <59822256+Archasion@users.noreply.github.com> Date: Sun, 1 Dec 2024 21:16:13 +0000 Subject: [PATCH 08/14] refactor(error): Use custom error for uncaught statuses and default to the status reason if no error message is given --- src/lib.rs | 9 +++------ src/response.rs | 16 ++++++++++------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index fb3a92c..f905d94 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,24 +10,21 @@ mod utils; #[macro_use] pub(crate) extern crate rocket; use rocket::http::Status; -use rocket::serde::json::{json, Value}; use rocket::Request; #[macro_use] extern crate rocket_sync_db_pools; use crate::config::Config; +use crate::response::ApiError; use dotenvy::dotenv; #[database("roops")] pub(crate) struct DbConn(diesel::PgConnection); #[catch(default)] -fn default(status: Status, _req: &Request<'_>) -> Value { - json!({ - "code": status.code, - "message": status.to_string() - }) +fn default(status: Status, _req: &Request<'_>) -> ApiError { + ApiError::status(status) } #[launch] diff --git a/src/response.rs b/src/response.rs index c273f6e..f426bbf 100644 --- a/src/response.rs +++ b/src/response.rs @@ -11,14 +11,14 @@ pub(crate) type ApiResult = Result; #[serde(crate = "rocket::serde")] pub(crate) struct ApiError { code: u16, - message: Option, + message: String, } impl ApiError { pub(crate) fn status(status: Status) -> Self { Self { code: status.code, - message: None, + message: status.to_string(), } } @@ -28,7 +28,7 @@ impl ApiError { { Self { code: status.code, - message: Some(message.to_string()), + message: message.to_string(), } } } @@ -37,7 +37,7 @@ impl<'r> response::Responder<'r, 'r> for ApiError { fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r> { // We can assume that the status code is valid since it's coming from the Status enum let status = Status::from_code(self.code).unwrap(); - rocket::info!(" >> Responding with error: {:?}", self); + rocket::info!(" >> Error message: {}", self.message); let error = json!({ "error": self }); Response::build_from(error.respond_to(req)?) @@ -79,10 +79,14 @@ impl<'r> response::Responder<'r, 'r> for ApiResponse { fn respond_to(self, req: &'r Request<'_>) -> response::Result<'r> { // We can assume that the status code is valid since it's coming from the Status enum let status = Status::from_code(self.code).unwrap(); - rocket::info!(" >> Responding with data: {:?}", self); + + if let Some(data) = &self.data { + rocket::info!(" >> Response: {}", data); + } + // We can assume that the data is valid JSON since it's coming from serde_json let data = rocket::serde::json::to_value(self).unwrap(); - + Response::build_from(data.respond_to(req)?) .status(status) .header(ContentType::JSON) From 81926128dd70ead61b33691d3f679c89f90ac345 Mon Sep 17 00:00:00 2001 From: nick <59822256+Archasion@users.noreply.github.com> Date: Mon, 2 Dec 2024 12:54:28 +0000 Subject: [PATCH 09/14] chore: Document code, add tests, and add a config validation command --- Cargo.toml | 1 + README.md | 35 ++- src/cipher.rs | 160 ++++++++-- src/config.rs | 279 ++++++++++++++---- src/constants.rs | 47 +++ src/database/wrappers/account_links/models.rs | 7 +- .../wrappers/discord_connections/models.rs | 87 +++--- src/database/wrappers/sessions/models.rs | 14 +- src/lib.rs | 34 ++- src/main.rs | 16 +- src/oauth/mod.rs | 3 - src/oauth/routes/discord.rs | 23 +- src/oauth/routes/roblox.rs | 22 +- src/oauth/types/discord.rs | 23 +- src/oauth/types/mod.rs | 9 +- src/oauth/utils/discord.rs | 37 +-- src/oauth/utils/mod.rs | 4 +- src/oauth/utils/roblox.rs | 13 +- src/response.rs | 88 ++++++ src/utils.rs | 9 +- 20 files changed, 731 insertions(+), 180 deletions(-) create mode 100644 src/constants.rs diff --git a/Cargo.toml b/Cargo.toml index 35d7340..4222b7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,3 +17,4 @@ chrono = { version = "0.4.38", features = ["serde"] } diesel = { version = "2.1.4", features = ["postgres", "serde_json", "chrono"] } aes-gcm = { version = "0.10.3", default-features = false } toml = { version = "0.8.19", default-features = false } +serial_test = { version = "3.2.0", default-features = false, features = ["file_locks"] } diff --git a/README.md b/README.md index e6549ed..7e0f0eb 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,16 @@ ## Config -A `config.toml` file can be created in the root directory of the project. The file is used to configure the server. + +A `config.toml` file can be created in the root directory of the project. The file is used to +configure the server. ### CORS -The `cors` section of the config file is used to configure the CORS settings for the server. If the request origin is not in the list of allowed origins, the server will respond with a 403 Forbidden status code. -If the `allowed_methods` field is set, the server will respond with a 405 Method Not Allowed status code if the request method is not in the list of allowed methods. +The `cors` section of the config file is used to configure the CORS settings for the server. If the +request origin is not in the list of allowed origins, the server will respond with a 403 Forbidden +status code. + +If the `allowed_methods` field is set, the server will respond with a 405 Method Not Allowed status +code if the request method is not in the list of allowed methods. ```toml [cors] @@ -14,4 +20,27 @@ allowed_origins_exact = ["http://localhost:8000"] allowed_origins_regex = ["^http://localhost:\\d{4}$"] # Allowed methods for CORS requests allowed_methods = ["GET", "POST", "PUT", "DELETE"] +# Whether to allow credentials in CORS requests +allow_credentials = true +``` + +### OAuth2 + +The `oauth2` section of the config file is used to configure the OAuth2 settings for the server. Any +sensitive data should be stored in the environment variables. + +```toml +# Discord OAuth2 settings +[oauth.discord] +# Client ID +client_id = "CLIENT_ID" +# Callback URL +redirect_uri = "http://localhost:8000/v1/oauth2/callback/discord" + +# Roblox OAuth2 settings +[oauth.roblox] +# Client ID +client_id = "CLIENT_ID" +# Callback URL +redirect_uri = "http://localhost:8000/v1/oauth2/callback/roblox" ``` \ No newline at end of file diff --git a/src/cipher.rs b/src/cipher.rs index 8aa9b8a..01ba4d4 100644 --- a/src/cipher.rs +++ b/src/cipher.rs @@ -1,24 +1,59 @@ +//! Provides functions for encrypting and decrypting data using AES-256-GCM. +//! The encryption key is fetched from the `ENCRYPTION_KEY` environment variable. +//! The encrypted data and nonce are returned as base64 encoded strings. + +use crate::constants::env::{ENCRYPTION_KEY, ENCRYPTION_KEY_LENGTH}; use aes_gcm::aead::{Aead, Nonce, OsRng}; use aes_gcm::{AeadCore, Aes256Gcm, Key, KeyInit}; +/// Represents encrypted data with the encrypted string and the nonce used for encryption. pub(crate) struct EncryptedData { - /// Encrypted data represented as a hex string + /// The encrypted data as a base64 encoded string. pub(crate) data: String, - /// Nonce used to encrypt the data represented as a hex string + /// The nonce used for encryption as a base64 encoded string. pub(crate) nonce: String, } -fn get_encryption_key() -> Result, String> { - // Get the encryption key from the environment - let key = std::env::var("ENCRYPTION_KEY").expect("ENCRYPTION_KEY must be set"); - assert_eq!(key.len(), 32, "Encryption key must be 32 bytes long"); +/// Retrieves the encryption key from the environment. +/// +/// This function fetches the encryption key from the [`ENCRYPTION_KEY`] environment variable +/// and ensures it is [`ENCRYPTION_KEY_LENGTH`] bytes long. +/// +/// # Panics +/// +/// This function will panic if the `ENCRYPTION_KEY` environment variable is not set or if the key is not 32 bytes long. +/// +/// # Returns +/// +/// A [`Key`] instance containing the encryption key. +fn get_encryption_key() -> Key { + let key = + std::env::var(ENCRYPTION_KEY).unwrap_or_else(|_| panic!("{} must be set", ENCRYPTION_KEY)); + + if key.len() != ENCRYPTION_KEY_LENGTH { + panic!( + "Encryption key must be {} bytes long", + ENCRYPTION_KEY_LENGTH + ); + } - Ok(*Key::::from_slice(key.as_bytes())) + *Key::::from_slice(key.as_bytes()) } -/// Encrypt a string using AES-256-GCM. +/// Encrypts the given data using AES-256-GCM. +/// +/// This function encrypts the provided data using the AES-256-GCM algorithm and a randomly generated nonce. +/// The encrypted data and nonce are returned as base64 encoded strings. +/// +/// # Arguments +/// +/// * `data` - A byte slice of the data to be encrypted. +/// +/// # Returns +/// +/// A [`Result`] containing an [`EncryptedData`] struct on success, or a [`String`] error message on failure. pub(crate) fn encrypt(data: &[u8]) -> Result { - let key = get_encryption_key()?; + let key = get_encryption_key(); let cipher = Aes256Gcm::new(&key); let nonce = Aes256Gcm::generate_nonce(&mut OsRng); @@ -32,14 +67,27 @@ pub(crate) fn encrypt(data: &[u8]) -> Result { }) } -/// Decrypt an `EncryptedData` struct using AES-256-GCM. +/// Decrypts the given encrypted data using AES-256-GCM. +/// +/// This function decrypts the provided [`EncryptedData`] using the AES-256-GCM algorithm and the nonce. +/// The decrypted data is returned as a UTF-8 string. +/// +/// # Arguments +/// +/// * `data` - A reference to the [`EncryptedData`] struct containing the encrypted data and nonce. +/// +/// # Returns +/// +/// A [`Result`] containing the decrypted data as a [`String`] on success, or a [`String`] error message on failure. pub(crate) fn decrypt(data: &EncryptedData) -> Result { - let key = get_encryption_key()?; + let key = get_encryption_key(); let cipher = Aes256Gcm::new(&key); // Convert the hex string to a vector of bytes - let encrypted_bytes = base64::decode(&data.data).expect("Failed to decode base64 data"); - let nonce = base64::decode(&data.nonce).expect("Failed to decode base64 nonce"); + let encrypted_bytes = + base64::decode(&data.data).map_err(|e| format!("Failed to decode base64 data: {}", e))?; + let nonce = + base64::decode(&data.nonce).map_err(|e| format!("Failed to decode base64 nonce: {}", e))?; let nonce = Nonce::::from_slice(nonce.as_slice()); let decrypted_bytes = cipher @@ -52,20 +100,92 @@ pub(crate) fn decrypt(data: &EncryptedData) -> Result { } #[cfg(test)] +#[serial_test::file_serial(env)] mod tests { - use super::{decrypt, encrypt}; + use super::{decrypt, encrypt, get_encryption_key, EncryptedData}; + use crate::constants; + + const DATA: &str = "Hello, world!"; + + #[test] + fn encryption_key_is_32_bytes() { + std::env::set_var( + constants::env::ENCRYPTION_KEY, + constants::test::ENCRYPTION_KEY, + ); + let key = get_encryption_key(); + assert_eq!(key.as_slice().len(), 32); + std::env::remove_var(constants::env::ENCRYPTION_KEY); + } #[test] - fn test_encryption() { - // Set the encryption key in the environment - std::env::set_var("ENCRYPTION_KEY", "0123456789abcdef0123456789abcdef"); + fn encryption_key_not_set() { + let panic_result = std::panic::catch_unwind(get_encryption_key); + assert!(panic_result.is_err()); - const DATA: &str = "Hello, world!"; + let panic_message = panic_result.unwrap_err(); + let panic_message = panic_message.downcast_ref::().unwrap(); + let expected_message = format!("{} must be set", constants::env::ENCRYPTION_KEY); + assert_eq!(panic_message, &expected_message); + } - let encrypted_data = encrypt(DATA.as_bytes()).unwrap(); - assert_ne!(encrypted_data.data, DATA); + #[test] + fn encryption_key_wrong_length() { + std::env::set_var(constants::env::ENCRYPTION_KEY, "short_key"); + let panic_result = std::panic::catch_unwind(get_encryption_key); + assert!(panic_result.is_err()); + + let panic_message = panic_result.unwrap_err(); + let panic_message = panic_message.downcast_ref::().unwrap(); + let expected_message = format!( + "Encryption key must be {} bytes long", + constants::env::ENCRYPTION_KEY_LENGTH + ); + assert_eq!(panic_message, &expected_message); + std::env::remove_var(constants::env::ENCRYPTION_KEY); + } + #[test] + fn encrypt_and_decrypt_data() { + std::env::set_var( + constants::env::ENCRYPTION_KEY, + constants::test::ENCRYPTION_KEY, + ); + let encrypted_data = encrypt(DATA.as_bytes()).unwrap(); let decrypted_data = decrypt(&encrypted_data).unwrap(); + assert_eq!(decrypted_data, DATA); + std::env::remove_var(constants::env::ENCRYPTION_KEY); + } + + #[test] + fn decrypt_with_invalid_base64_data() { + std::env::set_var( + constants::env::ENCRYPTION_KEY, + constants::test::ENCRYPTION_KEY, + ); + let encrypted_data = EncryptedData { + data: "valid+data".to_string(), + nonce: "valid+nonce".to_string(), + }; + let result = decrypt(&encrypted_data); + + assert!(result.is_err()); + std::env::remove_var(constants::env::ENCRYPTION_KEY); + } + + #[test] + fn decrypt_with_invalid_nonce() { + std::env::set_var( + constants::env::ENCRYPTION_KEY, + constants::test::ENCRYPTION_KEY, + ); + + let mut encrypted_data = encrypt(DATA.as_bytes()).unwrap(); + encrypted_data.nonce = "@".to_string(); + let result = decrypt(&encrypted_data); + + assert!(result.is_err()); + std::env::remove_var(constants::env::ENCRYPTION_KEY); } } diff --git a/src/config.rs b/src/config.rs index ed5d9b4..096fdf4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,73 +1,252 @@ -use rocket::serde::Deserialize; +use rocket::serde::{Deserialize, Deserializer}; use rocket_cors::{AllowedMethods, AllowedOrigins, Cors as RocketCors, CorsOptions}; use std::str::FromStr; +/// Represents the configuration for the application. +#[derive(Debug)] +pub struct Config { + /// CORS fairing. + pub(crate) cors: RocketCors, + /// OAuth settings for the application. + pub(crate) oauth: OAuthTypes, +} + +/// Represents the OAuth settings for the application. +#[derive(Deserialize, Debug)] +#[serde(crate = "rocket::serde")] +pub(crate) struct OAuthTypes { + /// OAuth settings for Discord. + pub(crate) discord: OAuth, + /// OAuth settings for Roblox. + pub(crate) roblox: OAuth, +} + +/// Represents the OAuth settings for a specific provider. +#[derive(Deserialize, Debug)] +#[serde(crate = "rocket::serde")] +pub(crate) struct OAuth { + /// The client ID for the OAuth application. + pub(crate) client_id: String, + /// The callback URL for the OAuth application. + pub(crate) redirect_uri: String, +} + +/// Represents the CORS settings for the application. +#[derive(Debug, Default)] +struct Cors { + allowed_origins: AllowedOrigins, + /// A list of HTTP methods that are allowed. + allowed_methods: AllowedMethods, + /// Whether credentials are allowed. + allow_credentials: bool, +} + impl Config { - /// Load config.toml from root directory and parse it into Config struct - pub(crate) fn load() -> Self { - toml::from_str(include_str!("../config.toml")).unwrap_or_default() + /// Loads the configuration from a TOML file. + /// + /// This function reads the configuration from the `config.toml` file and + /// deserializes it into a [`Config`] struct. If the file cannot be read or + /// deserialized, it returns an error message. + /// + /// # Returns + /// + /// A [`Result`] containing a [`Config`] struct if the configuration was loaded successfully, + /// or a [`String`] error message if the configuration could not be loaded. + pub fn load(path: Option) -> Result { + let path = path.unwrap_or("config.toml".to_string()); + let content = std::fs::read_to_string(path) + .map_err(|e| format!("Failed to read configuration file: {}", e))?; + + toml::from_str(&content).map_err(|e| format!("Failed to parse TOML configuration: {}", e)) } +} + +impl<'de> Deserialize<'de> for Config { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // Deserialize the configuration into a helper struct. + #[derive(Deserialize)] + #[serde(crate = "rocket::serde")] + struct RawConfig { + cors: Option, + oauth: OAuthTypes, + } - /// Get CORS fairing - pub(crate) fn cors_fairing(&self) -> RocketCors { - let Some(cors) = &self.cors else { - return CorsOptions::default() + // Deserialize the raw configuration. + let raw = RawConfig::deserialize(deserializer)?; + + // Create the CORS fairing based on the configuration. + let cors = match raw.cors { + Some(cors) => CorsOptions::default() + .allow_credentials(cors.allow_credentials) + .allowed_origins(cors.allowed_origins) + .allowed_methods(cors.allowed_methods) + .to_cors() + .expect("CORS fairing cannot be created"), + None => CorsOptions::default() .to_cors() - .expect("CORS fairing cannot be created"); + .expect("CORS fairing cannot be created"), }; - CorsOptions::default() - .allow_credentials(true) - .allowed_origins(cors.allowed_origins()) - .allowed_methods(cors.allowed_methods()) - .to_cors() - .expect("CORS fairing cannot be created") + Ok(Config { + cors, + oauth: raw.oauth, + }) } } -impl Cors { - /// Get allowed origins for CORS requests - fn allowed_origins(&self) -> AllowedOrigins { - match (&self.allowed_origins_exact, &self.allowed_origins_regex) { - (None, None) => AllowedOrigins::all(), +impl<'de> Deserialize<'de> for Cors { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // Define a helper struct to hold the raw fields temporarily. + #[derive(Deserialize)] + #[serde(crate = "rocket::serde")] + struct RawCorsConfig { + allowed_origins_exact: Option>, + allowed_origins_regex: Option>, + allowed_methods: Option>, + #[serde(default)] + allow_credentials: bool, + } + + // Deserialize into the helper struct. + let raw = RawCorsConfig::deserialize(deserializer)?; + + // Create AllowedOrigins using rocket_cors' API. + let allowed_origins = match (&raw.allowed_origins_exact, &raw.allowed_origins_regex) { + (None, None) => AllowedOrigins::default(), (Some(exact), None) => AllowedOrigins::some_exact(exact), (None, Some(regex)) => AllowedOrigins::some_regex(regex), (Some(exact), Some(regex)) => AllowedOrigins::some(exact, regex), - } - } + }; - /// Get allowed methods for CORS requests - fn allowed_methods(&self) -> AllowedMethods { - let Some(methods) = &self.allowed_methods else { - return AllowedMethods::default(); + // Parse the methods from the configuration. + let allowed_methods = match raw.allowed_methods { + Some(methods) => methods + .iter() + .map(|s| FromStr::from_str(s).expect("Failed to parse method")) + .collect(), + None => AllowedMethods::default(), }; - methods - .iter() - .map(|s| FromStr::from_str(s).expect("Failed to parse method")) - .collect() + Ok(Cors { + allowed_origins, + allowed_methods, + allow_credentials: raw.allow_credentials, + }) } } -/// The server configuration -#[derive(Deserialize, Debug, Default)] -#[serde(crate = "rocket::serde")] -pub(crate) struct Config { - /// CORS configuration - pub(crate) cors: Option, -} +#[cfg(test)] +mod tests { + use super::*; + use rocket::http::Method; + use rocket::serde::json::serde_json; + use rocket_cors::AllowedOrigins; -/// CORS configuration -#[derive(Deserialize, Debug, Default)] -#[serde(crate = "rocket::serde")] -pub(crate) struct Cors { - /// Allowed origins for CORS requests (exact match) - #[serde(default)] - allowed_origins_exact: Option>, - /// Allowed origins for CORS requests (regular expression) - #[serde(default)] - allowed_origins_regex: Option>, - /// Allowed methods for CORS requests - #[serde(default)] - allowed_methods: Option>, + #[test] + fn deserialize_cors_with_defaults() { + let cors: Cors = serde_json::from_str("{}").unwrap(); + assert_eq!(cors.allowed_origins, AllowedOrigins::default()); + assert_eq!(cors.allowed_methods, AllowedMethods::default()); + assert!(!cors.allow_credentials); + } + + #[test] + fn deserialize_cors_with_exact_origins() { + let json = r#" + { + "allowed_origins_exact": ["https://example.com"] + } + "#; + let cors: Cors = serde_json::from_str(json).unwrap(); + assert_eq!( + cors.allowed_origins, + AllowedOrigins::some_exact(&["https://example.com"]) + ); + } + + #[test] + fn deserialize_cors_with_regex_origins() { + let json = r#" + { + "allowed_origins_regex": ["^https://.*\\.example\\.com$"] + } + "#; + let cors: Cors = serde_json::from_str(json).unwrap(); + assert_eq!( + cors.allowed_origins, + AllowedOrigins::some_regex(&["^https://.*\\.example\\.com$"]) + ); + } + + #[test] + fn deserialize_cors_with_both_origins() { + let json = r#" + { + "allowed_origins_exact": ["https://example.com"], + "allowed_origins_regex": ["^https://.*\\.example\\.com$"] + } + "#; + let cors: Cors = serde_json::from_str(json).unwrap(); + assert_eq!( + cors.allowed_origins, + AllowedOrigins::some(&["https://example.com"], &["^https://.*\\.example\\.com$"]) + ); + } + + #[test] + fn deserialize_cors_with_methods() { + let json = r#" + { + "allowed_methods": ["GET"] + } + "#; + let cors: Cors = serde_json::from_str(json).unwrap(); + assert!(cors.allowed_methods.contains(&Method::Get.into())); + } + + #[test] + fn deserialize_cors_with_credentials() { + let json = r#" + { + "allow_credentials": true + } + "#; + let cors: Cors = serde_json::from_str(json).unwrap(); + assert!(cors.allow_credentials); + } + + #[test] + fn deserialize_config_with_defaults() { + let json = r#" + { + "oauth": { + "discord": { + "client_id": "discord-client-id", + "redirect_uri": "https://example.com/discord/callback" + }, + "roblox": { + "client_id": "roblox-client-id", + "redirect_uri": "https://example.com/roblox/callback" + } + } + } + "#; + let config: Config = serde_json::from_str(json).unwrap(); + assert_eq!(config.oauth.discord.client_id, "discord-client-id"); + assert_eq!( + config.oauth.discord.redirect_uri, + "https://example.com/discord/callback" + ); + assert_eq!(config.oauth.roblox.client_id, "roblox-client-id"); + assert_eq!( + config.oauth.roblox.redirect_uri, + "https://example.com/roblox/callback" + ); + } } diff --git a/src/constants.rs b/src/constants.rs new file mode 100644 index 0000000..04b11f4 --- /dev/null +++ b/src/constants.rs @@ -0,0 +1,47 @@ +//! Constants used throughout the application. + +pub(crate) mod roblox_api { + /// The URL for authorizing with the Roblox API. + pub(crate) const AUTHORIZE_URL: &str = "https://apis.roblox.com/oauth/v1/authorize"; +} + +pub(crate) mod discord_api { + /// The URL for authorizing with the Discord API. + pub(crate) const AUTHORIZE_URL: &str = "https://discord.com/authorize"; + /// The URL for obtaining tokens from the Discord API. + pub(crate) const TOKEN_URL: &str = "https://discord.com/api/v10/oauth2/token"; + /// The URL for fetching the current user's information from the Discord API. + pub(crate) const USER_URL: &str = "https://discord.com/api/v10/users/@me"; +} + +pub(crate) mod cookie { + /// The name of the session ID cookie. + pub(crate) const SESSION_ID: &str = "sessionid"; + /// The name of the state cookie. + pub(crate) const STATE: &str = "state"; + /// The name of the Roblox OAuth verifier cookie. + pub(crate) const OAUTH_CODE_VERIFIER: &str = "oauth_code_verifier"; +} + +pub(crate) mod env { + /// The environment variable name for the encryption key. + pub(crate) const ENCRYPTION_KEY: &str = "ENCRYPTION_KEY"; + /// The required length of the encryption key. + pub(crate) const ENCRYPTION_KEY_LENGTH: usize = 32; + /// The environment variable name for the Discord client secret. + pub(crate) const DISCORD_CLIENT_SECRET: &str = "DISCORD_CLIENT_SECRET"; +} + +#[cfg(test)] +pub(crate) mod test { + /// A test encryption key. + pub(crate) const ENCRYPTION_KEY: &str = "0123456789abcdef0123456789abcdef"; + /// A test access token. + pub(crate) const ACCESS_TOKEN: &str = "access_token"; + /// A test refresh token. + pub(crate) const REFRESH_TOKEN: &str = "refresh_token"; + /// A test scope. + pub(crate) const SCOPE: &str = "scope"; + /// A test user ID. + pub(crate) const UID: &str = "1"; +} diff --git a/src/database/wrappers/account_links/models.rs b/src/database/wrappers/account_links/models.rs index 32b8ccd..1e3cd47 100644 --- a/src/database/wrappers/account_links/models.rs +++ b/src/database/wrappers/account_links/models.rs @@ -41,8 +41,11 @@ impl AccountLinkBuilder { self } - pub(crate) fn discord_uid(mut self, discord_uid: String) -> Self { - self.discord_uid = Some(discord_uid); + pub(crate) fn discord_uid(mut self, discord_uid: S) -> Self + where + S: Into, + { + self.discord_uid = Some(discord_uid.into()); self } diff --git a/src/database/wrappers/discord_connections/models.rs b/src/database/wrappers/discord_connections/models.rs index c9b0c6d..14fc634 100644 --- a/src/database/wrappers/discord_connections/models.rs +++ b/src/database/wrappers/discord_connections/models.rs @@ -65,13 +65,19 @@ impl DiscordConnectionBuilder { Self::default() } - pub(crate) fn uid(mut self, uid: String) -> Self { - self.uid = Some(uid); + pub(crate) fn uid(mut self, uid: S) -> Self + where + S: Into, + { + self.uid = Some(uid.into()); self } - pub(crate) fn access_token(mut self, access_token: String) -> Self { - self.access_token = Some(access_token); + pub(crate) fn access_token(mut self, access_token: S) -> Self + where + S: Into, + { + self.access_token = Some(access_token.into()); self } @@ -80,13 +86,19 @@ impl DiscordConnectionBuilder { self } - pub(crate) fn refresh_token(mut self, refresh_token: String) -> Self { - self.refresh_token = Some(refresh_token); + pub(crate) fn refresh_token(mut self, refresh_token: S) -> Self + where + S: Into, + { + self.refresh_token = Some(refresh_token.into()); self } - pub(crate) fn scope(mut self, scope: String) -> Self { - self.scope = Some(scope); + pub(crate) fn scope(mut self, scope: S) -> Self + where + S: Into, + { + self.scope = Some(scope.into()); self } @@ -145,55 +157,64 @@ impl DiscordConnectionBuilder { #[cfg(test)] mod tests { use super::DiscordConnectionBuilder; - - const ACCESS_TOKEN: &str = "access_token"; - const REFRESH_TOKEN: &str = "refresh_token"; - const SCOPE: &str = "scope"; - const UID: &str = "1"; + use crate::constants; + use serial_test::file_serial; #[test] + #[file_serial(env)] fn test_build() { + std::env::set_var( + constants::env::ENCRYPTION_KEY, + constants::test::ENCRYPTION_KEY, + ); + let now = chrono::Utc::now().naive_utc(); let conn = DiscordConnectionBuilder::new() - .uid(UID.to_string()) - .access_token(ACCESS_TOKEN.to_string()) + .uid(constants::test::UID) + .access_token(constants::test::ACCESS_TOKEN) .expires_at(now) - .refresh_token(REFRESH_TOKEN.to_string()) - .scope(SCOPE.to_string()) + .refresh_token(constants::test::REFRESH_TOKEN) + .scope(constants::test::SCOPE) .build() - .expect("Failed to build NewDiscordConnection"); + .unwrap(); // The access token is encrypted, so we can't compare it directly - assert_ne!(conn.access_token, ACCESS_TOKEN); - + assert_ne!(conn.access_token, constants::test::ACCESS_TOKEN); // The refresh token is encrypted, so we can't compare it directly - assert_ne!(conn.refresh_token, REFRESH_TOKEN); - - assert_eq!(conn.uid, UID); + assert_ne!(conn.refresh_token, constants::test::REFRESH_TOKEN); + assert_eq!(conn.uid, constants::test::UID); assert_eq!(conn.expires_at, now); - assert_eq!(conn.scope, SCOPE); + assert_eq!(conn.scope, constants::test::SCOPE); + + std::env::remove_var(constants::env::ENCRYPTION_KEY); } #[test] + #[file_serial(env)] fn test_build_update() { + std::env::set_var( + constants::env::ENCRYPTION_KEY, + constants::test::ENCRYPTION_KEY, + ); + let now = chrono::Utc::now().naive_utc(); let conn = DiscordConnectionBuilder::new() - .access_token(ACCESS_TOKEN.to_string()) + .access_token(constants::test::ACCESS_TOKEN) .expires_at(now) - .refresh_token(REFRESH_TOKEN.to_string()) - .scope(SCOPE.to_string()) + .refresh_token(constants::test::REFRESH_TOKEN) + .scope(constants::test::SCOPE) .build_update() - .expect("Failed to build UpdateDiscordConnection"); + .unwrap(); // The access token is encrypted, so we can't compare it directly - assert_ne!(conn.access_token, Some(ACCESS_TOKEN.to_string())); + assert_ne!(conn.access_token.unwrap(), constants::test::ACCESS_TOKEN); assert!(conn.access_token_nonce.is_some()); - // The refresh token is encrypted, so we can't compare it directly - assert_ne!(conn.refresh_token, Some(REFRESH_TOKEN.to_string())); + assert_ne!(conn.refresh_token.unwrap(), constants::test::REFRESH_TOKEN); assert!(conn.refresh_token_nonce.is_some()); + assert_eq!(conn.expires_at.unwrap(), now); + assert_eq!(conn.scope.unwrap(), constants::test::SCOPE); - assert_eq!(conn.expires_at, Some(now)); - assert_eq!(conn.scope, Some(SCOPE.to_string())); + std::env::remove_var(constants::env::ENCRYPTION_KEY); } } diff --git a/src/database/wrappers/sessions/models.rs b/src/database/wrappers/sessions/models.rs index 6094933..3cf3760 100644 --- a/src/database/wrappers/sessions/models.rs +++ b/src/database/wrappers/sessions/models.rs @@ -50,13 +50,19 @@ impl SessionBuilder { Self::default() } - pub(crate) fn discord_uid(mut self, discord_uid: String) -> Self { - self.discord_uid = Some(discord_uid); + pub(crate) fn discord_uid(mut self, discord_uid: S) -> Self + where + S: Into, + { + self.discord_uid = Some(discord_uid.into()); self } - pub(crate) fn session_id(mut self, session_id: String) -> Self { - self.session_id = Some(session_id); + pub(crate) fn session_id(mut self, session_id: S) -> Self + where + S: Into, + { + self.session_id = Some(session_id.into()); self } diff --git a/src/lib.rs b/src/lib.rs index f905d94..5c01fb0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ #![allow(dead_code)] mod cipher; -mod config; +pub mod config; +mod constants; mod database; mod oauth; mod response; @@ -19,21 +20,48 @@ use crate::config::Config; use crate::response::ApiError; use dotenvy::dotenv; +/// Represents a connection to the PostgreSQL database. +/// +/// This struct is used to manage a connection to the PostgreSQL database +/// using Diesel's `PgConnection`. #[database("roops")] pub(crate) struct DbConn(diesel::PgConnection); +/// Default error catcher for the Rocket application. +/// +/// This function handles all uncaught errors and converts them into an [`ApiError`] +/// with the appropriate status code. +/// +/// # Arguments +/// +/// * `status` - The HTTP status code of the error. +/// * `_req` - The request that caused the error. +/// +/// # Returns +/// +/// An [`ApiError`] instance with the appropriate status code. #[catch(default)] fn default(status: Status, _req: &Request<'_>) -> ApiError { ApiError::status(status) } +/// Launches the Rocket application. +/// +/// This function initializes the Rocket application, loads the configuration, +/// attaches the CORS fairing and database connection, mounts the OAuth2 routes, +/// and registers the default error catcher. +/// +/// # Returns +/// +/// A [`Rocket`](rocket::Rocket) instance in the build phase, configured with the application's settings. #[launch] pub fn rocket() -> _ { dotenv().ok(); - let cfg = Config::load(); + let cfg = Config::load(None).unwrap(); rocket::build() - .attach(cfg.cors_fairing()) + .attach(cfg.cors.clone()) .attach(DbConn::fairing()) .mount("/v1/oauth2", oauth::routes::routes()) .register("/", catchers![default]) + .manage(cfg) } diff --git a/src/main.rs b/src/main.rs index 0a1efe3..d612e60 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,19 @@ #[rocket::main] async fn main() -> Result<(), rocket::Error> { - roops_internal_api::rocket().launch().await?; + // Start the Rocket server if no arguments are provided + // Otherwise, validate the config if the validate-config argument is provided with an optional path + let Some(cmd) = std::env::args().nth(1) else { + roops_internal_api::rocket().launch().await?; + return Ok(()); + }; + + match cmd.as_str() { + "validate-config" => { + // Validate the config + let path = std::env::args().nth(2); + roops_internal_api::config::Config::load(path).unwrap(); + } + _ => panic!("Invalid command: {}", cmd), + } Ok(()) } diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs index 4a8af42..36ebc63 100644 --- a/src/oauth/mod.rs +++ b/src/oauth/mod.rs @@ -1,6 +1,3 @@ pub(crate) mod routes; mod types; mod utils; - -const OAUTH_STATE_COOKIE_NAME: &str = "oauth_state"; -const SESSION_COOKIE_NAME: &str = "sessionid"; diff --git a/src/oauth/routes/discord.rs b/src/oauth/routes/discord.rs index 3b489de..e948d49 100644 --- a/src/oauth/routes/discord.rs +++ b/src/oauth/routes/discord.rs @@ -1,4 +1,6 @@ use crate::cipher::{decrypt, EncryptedData}; +use crate::config::Config; +use crate::constants; use crate::database::wrappers::discord_connections::models::{ NewDiscordConnection, UpdateDiscordConnection, }; @@ -12,26 +14,28 @@ use crate::oauth::utils::discord::{ construct_discord_oauth_url, exchange_code, get_authorized_user, refresh_token, }; use crate::oauth::utils::{generate_session, generate_state}; -use crate::oauth::OAUTH_STATE_COOKIE_NAME; use crate::response::{ApiError, ApiResponse, ApiResult}; use crate::DbConn; use diesel::Connection; use rocket::http::{Cookie, CookieJar, SameSite, Status}; use rocket::response::Redirect; use rocket::time::Duration; +use rocket::State; #[get("/initiate/discord?")] pub(super) fn discord_oauth_initiate( scope_set: String, jar: &CookieJar<'_>, + cfg: &State, ) -> ApiResult { let scope_set = DiscordOAuthScopeSet::try_from(scope_set) .map_err(|e| ApiError::message(Status::BadRequest, e))?; let state = generate_state(); - let redirect_uri = construct_discord_oauth_url(&DiscordOAuthScopes::from(&scope_set), &state); + let redirect_uri = + construct_discord_oauth_url(&DiscordOAuthScopes::from(&scope_set), &state, cfg); - let auth_cookie = Cookie::build((OAUTH_STATE_COOKIE_NAME, state)) + let auth_cookie = Cookie::build((constants::cookie::STATE, state)) .path("/v1/oauth2/callback/discord") .same_site(SameSite::Lax) .max_age(Duration::minutes(5)); @@ -47,10 +51,11 @@ pub(super) async fn discord_oauth_callback( callback: OAuthCallback, jar: &CookieJar<'_>, conn: DbConn, + cfg: &State, ) -> ApiResult { // Verify the state against the one that was saved in the cookie let is_valid_state = jar - .get_pending(OAUTH_STATE_COOKIE_NAME) + .get_pending(constants::cookie::STATE) .map_or(false, |cookie| cookie.value() == callback.state); if !is_valid_state { @@ -58,7 +63,7 @@ pub(super) async fn discord_oauth_callback( } // Use the code to obtain the token - let response = exchange_code(&callback.code)?; + let response = exchange_code(&callback.code, cfg)?; // Fetch the authorized user let authorized_user = get_authorized_user(&response.access_token)?; @@ -115,7 +120,11 @@ pub(super) async fn discord_oauth_callback( } #[post("/refresh-token/discord")] -pub(super) async fn discord_refresh_token(conn: DbConn, session_id: SessionId) -> ApiResult { +pub(super) async fn discord_refresh_token( + conn: DbConn, + session_id: SessionId, + cfg: &State, +) -> ApiResult { let session_id = session_id.into_inner(); // Fetch the Discord UID from the session token @@ -147,7 +156,7 @@ pub(super) async fn discord_refresh_token(conn: DbConn, session_id: SessionId) - })?; // Refresh the token - let response = refresh_token(&discord_refresh_token)?; + let response = refresh_token(&discord_refresh_token, cfg)?; let token_expires_at = chrono::Utc::now().naive_utc() + chrono::Duration::seconds(response.expires_in); diff --git a/src/oauth/routes/roblox.rs b/src/oauth/routes/roblox.rs index f73a230..72c47f9 100644 --- a/src/oauth/routes/roblox.rs +++ b/src/oauth/routes/roblox.rs @@ -1,15 +1,15 @@ +use crate::config::Config; +use crate::constants; use crate::oauth::types::roblox::{RobloxOAuthScopeSet, RobloxOAuthScopes}; use crate::oauth::types::OAuthCallback; use crate::oauth::utils::generate_state; use crate::oauth::utils::pixy::Pixy; use crate::oauth::utils::roblox::construct_roblox_oauth_url; -use crate::oauth::OAUTH_STATE_COOKIE_NAME; use crate::response::{ApiError, ApiResult}; use rocket::http::{Cookie, CookieJar, Status}; use rocket::response::Redirect; use rocket::time::Duration; - -const OAUTH_VERIFIER_COOKIE_NAME: &str = "oauth_code_verifier"; +use rocket::State; #[get("/callback/roblox?")] pub(super) fn roblox_oauth_callback( @@ -18,7 +18,7 @@ pub(super) fn roblox_oauth_callback( ) -> ApiResult { // Verify the state against the one that was saved in the cookie let is_valid_state = jar - .get_pending(OAUTH_STATE_COOKIE_NAME) + .get_pending(constants::cookie::STATE) .map_or(false, |cookie| cookie.value() == callback.state); if !is_valid_state { @@ -26,7 +26,7 @@ pub(super) fn roblox_oauth_callback( } // Use verifier to obtain token - let Some(_verifier_cookie) = jar.get_pending(OAUTH_VERIFIER_COOKIE_NAME) else { + let Some(_verifier_cookie) = jar.get_pending(constants::cookie::OAUTH_CODE_VERIFIER) else { return Err(ApiError::message( Status::InternalServerError, "Missing verifier cookie", @@ -42,7 +42,11 @@ pub(super) fn roblox_oauth_callback( } #[get("/initiate/roblox?")] -pub(super) fn roblox_oauth_initiate(scope_set: String, jar: &CookieJar<'_>) -> ApiResult { +pub(super) fn roblox_oauth_initiate( + scope_set: String, + jar: &CookieJar<'_>, + cfg: &State, +) -> ApiResult { let scope_set = RobloxOAuthScopeSet::try_from(scope_set) .map_err(|e| ApiError::message(Status::BadRequest, e))?; @@ -52,15 +56,15 @@ pub(super) fn roblox_oauth_initiate(scope_set: String, jar: &CookieJar<'_>) -> A let state = generate_state(); let redirect_uri = - construct_roblox_oauth_url(challenge, &RobloxOAuthScopes::from(&scope_set), &state); + construct_roblox_oauth_url(challenge, &RobloxOAuthScopes::from(&scope_set), &state, cfg); - let auth_cookie = Cookie::build((OAUTH_STATE_COOKIE_NAME, state)) + let auth_cookie = Cookie::build((constants::cookie::STATE, state)) .path("/v1/oauth2/callback/roblox") .same_site(rocket::http::SameSite::Lax) .max_age(Duration::minutes(5)); jar.add_private(auth_cookie); - let verifier_cookie = Cookie::build((OAUTH_VERIFIER_COOKIE_NAME, verifier)) + let verifier_cookie = Cookie::build((constants::cookie::OAUTH_CODE_VERIFIER, verifier)) .path("/v1/oauth2/callback/roblox") .same_site(rocket::http::SameSite::Lax) .max_age(Duration::minutes(5)); diff --git a/src/oauth/types/discord.rs b/src/oauth/types/discord.rs index 788bf9f..595f095 100644 --- a/src/oauth/types/discord.rs +++ b/src/oauth/types/discord.rs @@ -1,4 +1,5 @@ -use crate::oauth::utils::discord::{DISCORD_OAUTH_APP_CLIENT_ID, DISCORD_OAUTH_REDIRECT_URI}; +use crate::config::Config; +use crate::constants; use crate::url; use rocket::serde::{Deserialize, Serialize}; use std::fmt::{Display, Formatter, Result as FmtResult}; @@ -240,13 +241,14 @@ impl TryFrom for DiscordOAuthScopeSet { } impl<'a> DiscordAuthorizationCodeRequestBody<'a> { - pub(crate) fn new(code: &'a str) -> Self { + pub(crate) fn new(code: &'a str, cfg: &'a Config) -> Self { Self { - client_id: DISCORD_OAUTH_APP_CLIENT_ID, - redirect_uri: DISCORD_OAUTH_REDIRECT_URI, + client_id: &cfg.oauth.discord.client_id, + redirect_uri: &cfg.oauth.discord.redirect_uri, grant_type: "authorization_code", - client_secret: std::env::var("DISCORD_CLIENT_SECRET") - .expect("DISCORD_CLIENT_SECRET must be set"), + client_secret: std::env::var(constants::env::DISCORD_CLIENT_SECRET).unwrap_or_else( + |_| panic!("{} must be set", constants::env::DISCORD_CLIENT_SECRET), + ), code, } } @@ -263,13 +265,14 @@ impl<'a> DiscordAuthorizationCodeRequestBody<'a> { } impl<'a> DiscordTokenRefreshBody<'a> { - pub(crate) fn new(refresh_token: &'a str) -> Self { + pub(crate) fn new(refresh_token: &'a str, cfg: &'a Config) -> Self { Self { grant_type: "refresh_token", refresh_token, - client_id: DISCORD_OAUTH_APP_CLIENT_ID, - client_secret: std::env::var("DISCORD_CLIENT_SECRET") - .expect("DISCORD_CLIENT_SECRET must be set"), + client_id: &cfg.oauth.discord.client_id, + client_secret: std::env::var(constants::env::DISCORD_CLIENT_SECRET).unwrap_or_else( + |_| panic!("{} must be set", constants::env::DISCORD_CLIENT_SECRET), + ), } } diff --git a/src/oauth/types/mod.rs b/src/oauth/types/mod.rs index 12ebd15..373f61f 100644 --- a/src/oauth/types/mod.rs +++ b/src/oauth/types/mod.rs @@ -1,4 +1,4 @@ -use crate::oauth::SESSION_COOKIE_NAME; +use crate::constants; use crate::response::ApiError; use rocket::http::{Cookie, Status}; use rocket::request::{FromRequest, Outcome}; @@ -34,12 +34,15 @@ impl<'r> FromRequest<'r> for SessionId { async fn from_request(request: &'r Request<'_>) -> Outcome { let session_id = request .cookies() - .get_pending(SESSION_COOKIE_NAME) + .get_pending(constants::cookie::SESSION_ID) .map(|cookie| cookie.value().to_string()); match session_id { None => { - let error = ApiError::message(Status::Unauthorized, "Missing sessionid cookie"); + let error = ApiError::message( + Status::Unauthorized, + format!("Missing {} cookie", constants::cookie::SESSION_ID), + ); Outcome::Error((Status::Unauthorized, error)) } Some(session_id) => Outcome::Success(SessionId(session_id)), diff --git a/src/oauth/utils/discord.rs b/src/oauth/utils/discord.rs index a15fda2..923a83c 100644 --- a/src/oauth/utils/discord.rs +++ b/src/oauth/utils/discord.rs @@ -1,3 +1,5 @@ +use crate::config::Config; +use crate::constants; use crate::oauth::types::discord::{ DiscordAuthorizationCodeRequestBody, DiscordAuthorizationCodeResponse, DiscordAuthorizedUserResponse, DiscordOAuthScopes, DiscordTokenRefreshBody, DiscordUser, @@ -6,19 +8,16 @@ use crate::response::ApiError; use crate::url; use rocket::http::Status; -pub(crate) const DISCORD_OAUTH_REDIRECT_URI: &str = - "http://localhost:8000/v1/oauth2/callback/discord"; -pub(crate) const DISCORD_OAUTH_APP_CLIENT_ID: &str = "1300538596348133499"; -const DISCORD_AUTHORIZED_USER_ENDPOINT: &str = "https://discord.com/api/v10/oauth2/@me"; -const DISCORD_OAUTH_ENDPOINT_URL: &str = "https://discord.com/oauth2/authorize"; -const DISCORD_TOKEN_ENDPOINT_URL: &str = "https://discord.com/api/v10/oauth2/token"; - -pub(crate) fn construct_discord_oauth_url(scopes: &DiscordOAuthScopes, state: &str) -> String { +pub(crate) fn construct_discord_oauth_url( + scopes: &DiscordOAuthScopes, + state: &str, + cfg: &Config, +) -> String { url!( - DISCORD_OAUTH_ENDPOINT_URL, + constants::discord_api::AUTHORIZE_URL, [ - ("redirect_uri", DISCORD_OAUTH_REDIRECT_URI), - ("client_id", DISCORD_OAUTH_APP_CLIENT_ID), + ("redirect_uri", cfg.oauth.discord.redirect_uri), + ("client_id", cfg.oauth.discord.client_id), ("response_type", "code"), ("scope", scopes), ("state", state) @@ -26,9 +25,12 @@ pub(crate) fn construct_discord_oauth_url(scopes: &DiscordOAuthScopes, state: &s ) } -pub(crate) fn exchange_code(code: &str) -> Result { - let body = DiscordAuthorizationCodeRequestBody::new(code); - let response = minreq::post(DISCORD_TOKEN_ENDPOINT_URL) +pub(crate) fn exchange_code( + code: &str, + cfg: &Config, +) -> Result { + let body = DiscordAuthorizationCodeRequestBody::new(code, cfg); + let response = minreq::post(constants::discord_api::TOKEN_URL) .with_header("Content-Type", "application/x-www-form-urlencoded") .with_body(body.as_query_params()) .send() @@ -46,9 +48,10 @@ pub(crate) fn exchange_code(code: &str) -> Result Result { - let body = DiscordTokenRefreshBody::new(refresh_token); - let response = minreq::post(DISCORD_TOKEN_ENDPOINT_URL) + let body = DiscordTokenRefreshBody::new(refresh_token, cfg); + let response = minreq::post(constants::discord_api::TOKEN_URL) .with_header("Content-Type", "application/x-www-form-urlencoded") .with_body(body.as_query_params()) .send() @@ -65,7 +68,7 @@ pub(crate) fn refresh_token( } pub(crate) fn get_authorized_user(access_token: &str) -> Result, ApiError> { - minreq::get(DISCORD_AUTHORIZED_USER_ENDPOINT) + minreq::get(constants::discord_api::USER_URL) .with_header("Authorization", format!("Bearer {}", access_token)) .send() .map_err(|_| ApiError::message(Status::BadGateway, "Failed to obtain Discord user info"))? diff --git a/src/oauth/utils/mod.rs b/src/oauth/utils/mod.rs index 83759b7..5137cca 100644 --- a/src/oauth/utils/mod.rs +++ b/src/oauth/utils/mod.rs @@ -3,8 +3,8 @@ pub(super) mod pixy; pub(super) mod roblox; use self::pixy::generate_random_base64_url_safe_no_pad_string; +use crate::constants; use crate::oauth::types::GeneratedSession; -use crate::oauth::SESSION_COOKIE_NAME; use rand::{thread_rng, Rng}; use rocket::http::Cookie; @@ -25,7 +25,7 @@ pub(super) fn generate_session() -> GeneratedSession<'static> { let expires_at = chrono::Utc::now() + chrono::Duration::days(30); GeneratedSession { - cookie: Cookie::build((SESSION_COOKIE_NAME, session_id.clone())) + cookie: Cookie::build((constants::cookie::SESSION_ID, session_id.clone())) .max_age(rocket::time::Duration::days(30)) .build(), session_id, diff --git a/src/oauth/utils/roblox.rs b/src/oauth/utils/roblox.rs index d34668a..a8da8c1 100644 --- a/src/oauth/utils/roblox.rs +++ b/src/oauth/utils/roblox.rs @@ -1,23 +1,22 @@ +use crate::config::Config; +use crate::constants; use crate::oauth::types::roblox::RobloxOAuthScopes; use crate::url; use secrecy::{ExposeSecret, SecretString}; -const ROBLOX_OAUTH_ENDPOINT_URL: &str = "https://apis.roblox.com/oauth/v1/authorize"; -const ROBLOX_OAUTH_REDIRECT_URI: &str = "http://localhost:8000/oauth-callback/roblox"; -const ROBLOX_OAUTH_APP_CLIENT_ID: u64 = 5083487294232976060; - pub(crate) fn construct_roblox_oauth_url( code_challenge_secret: &SecretString, scopes: &RobloxOAuthScopes, state: &str, + cfg: &Config, ) -> String { let code_challenge = code_challenge_secret.expose_secret(); url!( - ROBLOX_OAUTH_ENDPOINT_URL, + constants::roblox_api::AUTHORIZE_URL, [ - ("redirect_uri", ROBLOX_OAUTH_REDIRECT_URI), - ("client_id", ROBLOX_OAUTH_APP_CLIENT_ID), + ("redirect_uri", cfg.oauth.roblox.redirect_uri), + ("client_id", cfg.oauth.roblox.client_id), ("code_challenge_method", "S256"), ("code_challenge", code_challenge), ("response_type", "code"), diff --git a/src/response.rs b/src/response.rs index f426bbf..9a449b2 100644 --- a/src/response.rs +++ b/src/response.rs @@ -5,16 +5,33 @@ use rocket::serde::json::Value; use rocket::serde::Serialize; use rocket::{Request, Response}; +/// A type alias for API results, which can either be a successful response of type `R` +/// or an [`ApiError`]. Type `R` **must** implement the [`Serialize`] trait. +/// +/// The default type for `R` is [`ApiResponse`]. pub(crate) type ApiResult = Result; +/// Encapsulates error information that can be returned +/// by the API. It includes a status code and a message describing the error. #[derive(Serialize, Debug)] #[serde(crate = "rocket::serde")] pub(crate) struct ApiError { + /// The HTTP status code associated with the error. code: u16, + /// A message describing the error. message: String, } impl ApiError { + /// Creates a new [`ApiError`] with the given status code and its string representation as the message. + /// + /// # Arguments + /// + /// * `status` - The HTTP status to be used for the error. + /// + /// # Returns + /// + /// A new [`ApiError`] instance with the provided status code and message. pub(crate) fn status(status: Status) -> Self { Self { code: status.code, @@ -22,6 +39,16 @@ impl ApiError { } } + /// Creates a new [`ApiError`] with the given status code and a custom message. + /// + /// # Arguments + /// + /// * `status` - The HTTP status to be used for the error. + /// * `message` - A custom message describing the error. + /// + /// # Returns + /// + /// A new [`ApiError`] instance with the provided status code and custom message. pub(crate) fn message(status: Status, message: M) -> Self where M: ToString, @@ -47,14 +74,27 @@ impl<'r> response::Responder<'r, 'r> for ApiError { } } +/// Encapsulates a successful response that can be returned +/// by the API. It includes a status code and optional data. #[derive(Debug, Serialize)] #[serde(crate = "rocket::serde")] pub(crate) struct ApiResponse { + /// The HTTP status code associated with the response. code: u16, + /// Optional data included in the response. data: Option, } impl ApiResponse { + /// Creates a new [`ApiResponse`] with the given status code and no data. + /// + /// # Arguments + /// + /// * `status` - The HTTP status to be used for the response. + /// + /// # Returns + /// + /// A new [`ApiResponse`] instance with the provided status code and no data. pub(crate) fn status(status: Status) -> Self { Self { code: status.code, @@ -62,6 +102,15 @@ impl ApiResponse { } } + /// Creates a new [`ApiResponse`] with a status code of `200 OK` and the given data. + /// + /// # Arguments + /// + /// * `data` - The data to be included in the response. Must implement the [`Serialize`](rocket::serde::Serialize) trait. + /// + /// # Returns + /// + /// A new [`ApiResponse`] instance with a status code of `200 OK` and the provided data. pub(crate) fn ok(data: D) -> Self where D: Serialize, @@ -93,3 +142,42 @@ impl<'r> response::Responder<'r, 'r> for ApiResponse { .ok() } } + +#[cfg(test)] +mod tests { + use super::{ApiError, ApiResponse}; + use rocket::http::Status; + use rocket::serde::json::Value; + + #[test] + fn test_api_error_status() { + let status = Status::BadRequest; + let error = ApiError::status(status); + assert_eq!(error.code, status.code); + assert_eq!(error.message, status.to_string()); + } + + #[test] + fn test_api_error_message() { + let status = Status::NotFound; + let error = ApiError::message(status, "Resource not found"); + assert_eq!(error.code, status.code); + assert_eq!(error.message, "Resource not found"); + } + + #[test] + fn test_api_response_status() { + let status = Status::Created; + let response = ApiResponse::status(status); + assert_eq!(response.code, status.code); + assert!(response.data.is_none()); + } + + #[test] + fn test_api_response_ok() { + let data = "Hello, world!".to_string(); + let response = ApiResponse::ok(&data); + assert_eq!(response.code, Status::Ok.code); + assert_eq!(response.data, Some(Value::String(data))); + } +} diff --git a/src/utils.rs b/src/utils.rs index 41849b6..14734bc 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,10 +1,7 @@ -/// Construct a URL with query parameters. +/// Constructs a URL with query parameters. /// -/// # Arguments -/// -/// * `$domain` - The domain of the URL. -/// * `$key` - The key of the query parameter. -/// * `$value` - The value of the query parameter. +/// This macro can be used to create a URL with query parameters from a domain and a list of key-value pairs, +/// or just from a list of key-value pairs. /// /// # Examples /// From 3abb2557a464e266482843c53c9c0d111c826ef1 Mon Sep 17 00:00:00 2001 From: nick <59822256+Archasion@users.noreply.github.com> Date: Mon, 2 Dec 2024 13:00:24 +0000 Subject: [PATCH 10/14] refactor(readme): Specify requirement of fields --- README.md | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 7e0f0eb..9e073db 100644 --- a/README.md +++ b/README.md @@ -12,15 +12,19 @@ status code. If the `allowed_methods` field is set, the server will respond with a 405 Method Not Allowed status code if the request method is not in the list of allowed methods. +If neither the `allowed_origins_exact` nor the `allowed_origins_regex` fields are set, the server +will accept requests from any origin. + ```toml +# CORS settings (optional) [cors] -# Allowed origins for CORS requests (exact match) +# Allowed origins for CORS requests (exact match), default: [] allowed_origins_exact = ["http://localhost:8000"] -# Allowed origins for CORS requests (regular expression) +# Allowed origins for CORS requests (regular expression), default: [] allowed_origins_regex = ["^http://localhost:\\d{4}$"] -# Allowed methods for CORS requests +# Allowed methods for CORS requests, default: all allowed_methods = ["GET", "POST", "PUT", "DELETE"] -# Whether to allow credentials in CORS requests +# Whether to allow credentials in CORS requests, default: false allow_credentials = true ``` @@ -30,17 +34,17 @@ The `oauth2` section of the config file is used to configure the OAuth2 settings sensitive data should be stored in the environment variables. ```toml -# Discord OAuth2 settings +# Discord OAuth2 settings (required) [oauth.discord] -# Client ID +# Client ID (required) client_id = "CLIENT_ID" -# Callback URL +# Callback URL (required) redirect_uri = "http://localhost:8000/v1/oauth2/callback/discord" -# Roblox OAuth2 settings +# Roblox OAuth2 settings (required) [oauth.roblox] -# Client ID +# Client ID (required) client_id = "CLIENT_ID" -# Callback URL +# Callback URL (required) redirect_uri = "http://localhost:8000/v1/oauth2/callback/roblox" ``` \ No newline at end of file From b547af2afc07232e7ba8989595c606a6f2a2393b Mon Sep 17 00:00:00 2001 From: nick <59822256+Archasion@users.noreply.github.com> Date: Mon, 2 Dec 2024 17:14:14 +0000 Subject: [PATCH 11/14] refactor: Add more comments --- src/constants.rs | 2 + src/lib.rs | 3 +- src/oauth/routes/discord.rs | 78 ++++++++++++++---- src/oauth/routes/mod.rs | 17 ++++ src/oauth/routes/roblox.rs | 11 +-- src/oauth/types/discord.rs | 153 +++++++++++++++++++++--------------- src/oauth/types/mod.rs | 10 +++ src/oauth/types/roblox.rs | 94 ++++++++++------------ src/oauth/utils/discord.rs | 78 +++++++++++++----- src/oauth/utils/mod.rs | 16 +++- src/oauth/utils/pixy.rs | 5 ++ src/oauth/utils/roblox.rs | 11 +++ src/response.rs | 4 +- src/utils.rs | 12 +++ 14 files changed, 332 insertions(+), 162 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 04b11f4..09a7878 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,5 +1,7 @@ //! Constants used throughout the application. +pub(crate) const API_VERSION: &str = "v1"; + pub(crate) mod roblox_api { /// The URL for authorizing with the Roblox API. pub(crate) const AUTHORIZE_URL: &str = "https://apis.roblox.com/oauth/v1/authorize"; diff --git a/src/lib.rs b/src/lib.rs index 5c01fb0..732b455 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,7 @@ extern crate rocket_sync_db_pools; use crate::config::Config; use crate::response::ApiError; +use crate::utils::construct_api_route; use dotenvy::dotenv; /// Represents a connection to the PostgreSQL database. @@ -61,7 +62,7 @@ pub fn rocket() -> _ { rocket::build() .attach(cfg.cors.clone()) .attach(DbConn::fairing()) - .mount("/v1/oauth2", oauth::routes::routes()) + .mount(construct_api_route("/oauth2"), oauth::routes::routes()) .register("/", catchers![default]) .manage(cfg) } diff --git a/src/oauth/routes/discord.rs b/src/oauth/routes/discord.rs index e948d49..19df39b 100644 --- a/src/oauth/routes/discord.rs +++ b/src/oauth/routes/discord.rs @@ -15,6 +15,7 @@ use crate::oauth::utils::discord::{ }; use crate::oauth::utils::{generate_session, generate_state}; use crate::response::{ApiError, ApiResponse, ApiResult}; +use crate::utils::construct_api_route; use crate::DbConn; use diesel::Connection; use rocket::http::{Cookie, CookieJar, SameSite, Status}; @@ -22,31 +23,55 @@ use rocket::response::Redirect; use rocket::time::Duration; use rocket::State; -#[get("/initiate/discord?")] +/// Initiates the Discord OAuth flow by saving a randomly-generated state in a cookie +/// and redirecting the user to the Discord OAuth page. +/// +/// # Possible Responses +/// +/// - `303 See Other` with a redirect to the Discord OAuth page. +/// - `400 Bad Request` if the scope set is invalid. +#[get("/discord/initiate?")] pub(super) fn discord_oauth_initiate( scope_set: String, jar: &CookieJar<'_>, cfg: &State, ) -> ApiResult { - let scope_set = DiscordOAuthScopeSet::try_from(scope_set) + let scopes = DiscordOAuthScopeSet::try_from(scope_set.as_str()) .map_err(|e| ApiError::message(Status::BadRequest, e))?; - + let scopes = DiscordOAuthScopes::from(&scopes); let state = generate_state(); - let redirect_uri = - construct_discord_oauth_url(&DiscordOAuthScopes::from(&scope_set), &state, cfg); + let redirect_uri = construct_discord_oauth_url(&scopes, &state, cfg); + // Build a state cookie that will be used to verify the callback let auth_cookie = Cookie::build((constants::cookie::STATE, state)) - .path("/v1/oauth2/callback/discord") + .path(construct_api_route("/oauth2/callback/discord")) .same_site(SameSite::Lax) .max_age(Duration::minutes(5)); - - // Save the state as a cookie to match against the one that'll be returned from callback jar.add_private(auth_cookie); + // Redirect the user to the Discord OAuth page Ok(Redirect::to(redirect_uri)) } -#[get("/callback/discord?")] +/// Handles the Discord OAuth callback by verifying the state and exchanging the code for a token. +/// +/// - Verifies the state against the one saved in the cookie. +/// - Exchanges the code for a token. +/// - Fetches the authorized user. +/// - Inserts the session and Discord connection into the database. +/// - Saves the current session as a cookie. +/// +/// # Possible Responses +/// +/// - `303 See Other` with a redirect to the main page. +/// - `400 Bad Request` if the state is invalid. +/// - `500 Internal Server Error` +/// - If the code exchange response cannot be parsed +/// - If the authenticated user is missing the 'identify' scope. +/// - If the access token or refresh token cannot be encrypted. +/// - If the session or Discord connection cannot be inserted into the database. +/// - `502 Bad Gateway` if the code exchange fails. +#[get("/discord/callback?")] pub(super) async fn discord_oauth_callback( callback: OAuthCallback, jar: &CookieJar<'_>, @@ -66,21 +91,26 @@ pub(super) async fn discord_oauth_callback( let response = exchange_code(&callback.code, cfg)?; // Fetch the authorized user let authorized_user = get_authorized_user(&response.access_token)?; - - // Parse the user info + // Parse the user info to get the Discord UID let discord_uid = authorized_user - .expect("Failed to unwrap user, missing 'identify' scope") + .ok_or(ApiError::message( + Status::InternalServerError, + "Failed to unwrap user, missing 'identify' scope", + ))? .id; + let session = generate_session(); let token_expires_at = chrono::Utc::now().naive_utc() + chrono::Duration::seconds(response.expires_in); + // Session to insert into the database let new_session = NewSession::build() .discord_uid(discord_uid.clone()) .session_id(session.session_id.clone()) .expires_at(token_expires_at) .build(); + // Discord connection to insert into the database let discord_connection = NewDiscordConnection::build() .uid(discord_uid) .access_token(response.access_token) @@ -91,7 +121,7 @@ pub(super) async fn discord_oauth_callback( .map_err(|_| { ApiError::message( Status::InternalServerError, - "Failed to build Discord connection", + "Failed to encrypt access token and/or refresh token", ) })?; @@ -113,13 +143,29 @@ pub(super) async fn discord_oauth_callback( })?; // Save the current session as a cookie - // This will be used to access the token through the database + // This will be used to authenticate the user in future requests jar.add_private(session.cookie); + // Redirect back to the main page Ok(Redirect::to(uri!("/"))) } -#[post("/refresh-token/discord")] +/// Refreshes the Discord token by decrypting the refresh token, refreshing the token, and updating the database. +/// +/// # Possible Responses +/// +/// - `204 No Content` if the token was successfully refreshed. +/// - `401 Unauthorized` +/// - If the session cookie is missing. +/// - If the session is not found in the database. +/// - `500 Internal Server Error` +/// - If the Discord connection associated with the session is not found. +/// - If the refresh token cannot be decrypted. +/// - If the token refresh response cannot be parsed. +/// - If the access token or refresh token cannot be encrypted. +/// - If the Discord connection cannot be updated in the database. +/// - `502 Bad Gateway` if the token refresh fails. +#[post("/discord/refresh-token")] pub(super) async fn discord_refresh_token( conn: DbConn, session_id: SessionId, @@ -168,7 +214,7 @@ pub(super) async fn discord_refresh_token( .map_err(|_| { ApiError::message( Status::InternalServerError, - "Failed to build Discord connection update", + "Failed to encrypt access token and/or refresh token", ) })?; diff --git a/src/oauth/routes/mod.rs b/src/oauth/routes/mod.rs index f29eea3..2241092 100644 --- a/src/oauth/routes/mod.rs +++ b/src/oauth/routes/mod.rs @@ -10,11 +10,22 @@ use std::sync::Arc; mod discord; mod roblox; +/// Refreshes the session ID by updating the expiration date +/// +/// # Possible Responses +/// +/// - `204 No Content` if the session was successfully refreshed. +/// - `401 Unauthorized` +/// - If the session cookie is missing. +/// - If the session is not found in the database. +/// - `500 Internal Server Error` if the session cannot be updated in the database. #[post("/refresh-session")] async fn refresh_session(jar: &CookieJar<'_>, conn: DbConn, session_id: SessionId) -> ApiResult { + // Wrap the ID in an Arc to avoid cloning the value let session_id = Arc::new(session_id.into_inner()); let session_id_find = Arc::clone(&session_id); + // Check if the session exists conn.run(move |conn| { let Some(_) = SessionsDb::find_one(session_id_find.as_str(), conn) else { return diesel::result::QueryResult::Err(diesel::result::Error::NotFound); @@ -24,12 +35,14 @@ async fn refresh_session(jar: &CookieJar<'_>, conn: DbConn, session_id: SessionI .await .map_err(|_| ApiError::message(Status::Unauthorized, "Session not found"))?; + // Generate a new session ID let session = generate_session(); let updated_session = UpdateSession::build() .session_id(session.session_id.clone()) .expires_at(session.expires_at) .build_update(); + // Update the session in the database conn.run(move |conn| { let Some(session) = SessionsDb::update_one(session_id.as_str(), updated_session, conn) else { @@ -40,11 +53,15 @@ async fn refresh_session(jar: &CookieJar<'_>, conn: DbConn, session_id: SessionI .await .map_err(|_| ApiError::message(Status::InternalServerError, "Failed to update session"))?; + // Update the session cookie jar.add_private(session.cookie); Ok(ApiResponse::status(Status::NoContent)) } +/// Returns the routes for the OAuth module. +/// +/// These routes should be mounted on `/oauth2`. pub(crate) fn routes() -> Vec { routes![ discord::discord_oauth_initiate, diff --git a/src/oauth/routes/roblox.rs b/src/oauth/routes/roblox.rs index 72c47f9..d8dadc6 100644 --- a/src/oauth/routes/roblox.rs +++ b/src/oauth/routes/roblox.rs @@ -6,12 +6,13 @@ use crate::oauth::utils::generate_state; use crate::oauth::utils::pixy::Pixy; use crate::oauth::utils::roblox::construct_roblox_oauth_url; use crate::response::{ApiError, ApiResult}; +use crate::utils::construct_api_route; use rocket::http::{Cookie, CookieJar, Status}; use rocket::response::Redirect; use rocket::time::Duration; use rocket::State; -#[get("/callback/roblox?")] +#[get("/roblox/callback?")] pub(super) fn roblox_oauth_callback( callback: OAuthCallback, jar: &CookieJar<'_>, @@ -41,13 +42,13 @@ pub(super) fn roblox_oauth_callback( Ok(Redirect::to(uri!("/"))) } -#[get("/initiate/roblox?")] +#[get("/roblox/initiate?")] pub(super) fn roblox_oauth_initiate( scope_set: String, jar: &CookieJar<'_>, cfg: &State, ) -> ApiResult { - let scope_set = RobloxOAuthScopeSet::try_from(scope_set) + let scope_set = RobloxOAuthScopeSet::try_from(scope_set.as_str()) .map_err(|e| ApiError::message(Status::BadRequest, e))?; let pixy = Pixy::new(); @@ -59,13 +60,13 @@ pub(super) fn roblox_oauth_initiate( construct_roblox_oauth_url(challenge, &RobloxOAuthScopes::from(&scope_set), &state, cfg); let auth_cookie = Cookie::build((constants::cookie::STATE, state)) - .path("/v1/oauth2/callback/roblox") + .path(construct_api_route("/oauth2/callback/roblox")) .same_site(rocket::http::SameSite::Lax) .max_age(Duration::minutes(5)); jar.add_private(auth_cookie); let verifier_cookie = Cookie::build((constants::cookie::OAUTH_CODE_VERIFIER, verifier)) - .path("/v1/oauth2/callback/roblox") + .path(construct_api_route("/oauth2/callback/roblox")) .same_site(rocket::http::SameSite::Lax) .max_age(Duration::minutes(5)); jar.add_private(verifier_cookie); diff --git a/src/oauth/types/discord.rs b/src/oauth/types/discord.rs index 595f095..0b00b09 100644 --- a/src/oauth/types/discord.rs +++ b/src/oauth/types/discord.rs @@ -4,12 +4,21 @@ use crate::url; use rocket::serde::{Deserialize, Serialize}; use std::fmt::{Display, Formatter, Result as FmtResult}; +/// A set of scopes that will be requested by the web app. +/// +/// The internal API should parse these scope sets into individual scopes +/// that will be requested from the OAuth2 provider pub(crate) enum DiscordOAuthScopeSet { + /// The scopes required for verifying a user's Discord account Verification, } +/// A set of scopes that will be requested from the OAuth2 provider. pub(crate) struct DiscordOAuthScopes(pub(crate) Vec); +/// A scope that can be requested from the OAuth2 provider. +/// +/// [Reference](https://discord.com/developers/docs/topics/oauth2#shared-resources-oauth2-scopes) pub(crate) enum DiscordOAuthScope { Identify, Guilds, @@ -55,46 +64,74 @@ pub(crate) enum DiscordOAuthScope { RPCVideoWrite, } +/// The request body for the Discord OAuth2 token endpoint. +/// +/// [Reference](https://discord.com/developers/docs/topics/oauth2#authorization-code-grant) #[derive(Serialize)] #[serde(crate = "rocket::serde")] pub(crate) struct DiscordAuthorizationCodeRequestBody<'a> { + /// The client ID of the application client_id: &'a str, + /// The client secret of the application. + /// Must be an owned string as it is retrieved from the environment client_secret: String, + /// The grant type of the request. This field must contain the value `authorization_code` grant_type: &'a str, + /// The authorization code that was received from the authorization callback code: &'a str, + /// The callback URL that was used to request the authorization code redirect_uri: &'a str, } +/// The request body for the Discord OAuth2 token refresh endpoint. +/// +/// [Reference](https://discord.com/developers/docs/topics/oauth2#authorization-code-grant) #[derive(Serialize)] #[serde(crate = "rocket::serde")] -pub(crate) struct DiscordTokenRefreshBody<'a> { +pub(crate) struct DiscordRefreshTokenBody<'a> { + /// The grant type of the request. This field must contain the value `refresh_token` grant_type: &'a str, + /// The refresh token that was received from the token endpoint refresh_token: &'a str, + /// The client ID of the application client_id: &'a str, + /// The client secret of the application + /// Must be an owned string as it is retrieved from the environment client_secret: String, } +/// The response from the Discord OAuth2 token endpoint. +/// +/// [Reference](https://discord.com/developers/docs/topics/oauth2#authorization-code-grant-access-token-response) #[derive(Deserialize)] #[serde(crate = "rocket::serde")] -pub(crate) struct DiscordAuthorizationCodeResponse { +pub(crate) struct DiscordAccessTokenResponse { + /// The access token that can be used to authenticate requests pub(crate) access_token: String, + /// The type of token pub(crate) token_type: String, + /// The number of seconds until the token expires pub(crate) expires_in: i64, + /// The refresh token that can be used to refresh the access token pub(crate) refresh_token: String, + /// The scopes that the user has authorized pub(crate) scope: String, } -/// The response from the Discord OAuth2 @me endpoint +/// The response from the Discord OAuth2 @me endpoint. +/// +/// [Reference](https://discord.com/developers/docs/topics/oauth2#get-current-authorization-information-example-authorization-information) #[derive(Deserialize, Debug)] #[serde(crate = "rocket::serde")] pub(crate) struct DiscordAuthorizedUserResponse { - /// The user who has authorized + /// The user who has authorized the application /// /// ⚠️ Requires the `identify` scope pub(crate) user: Option, } /// The user object represents a user profile on Discord. +/// /// [Reference](https://discord.com/developers/docs/resources/user#user-object) #[derive(Deserialize, Debug)] #[serde(crate = "rocket::serde")] @@ -137,63 +174,53 @@ pub(crate) struct DiscordUser { pub(crate) public_flags: Option, } -impl From<&DiscordOAuthScope> for String { +impl From<&DiscordOAuthScope> for &str { fn from(value: &DiscordOAuthScope) -> Self { match value { - DiscordOAuthScope::Identify => String::from("identify"), - DiscordOAuthScope::Guilds => String::from("guilds"), - DiscordOAuthScope::GuildsChannelsRead => String::from("guilds.channels.read"), - DiscordOAuthScope::Rpc => String::from("rpc"), - DiscordOAuthScope::RPCVoiceWrite => String::from("rpc.voice.write"), - DiscordOAuthScope::RPCScreenshareRead => String::from("rpc.screenshare.read"), - DiscordOAuthScope::WebhookIncoming => String::from("webhook.incoming"), - DiscordOAuthScope::ApplicationsBuildsRead => String::from("applications.builds.read"), - DiscordOAuthScope::ApplicationsEntitlements => { - String::from("applications.entitlements") - } - DiscordOAuthScope::RelationshipsRead => String::from("relationships.read"), - DiscordOAuthScope::DMChannelsRead => String::from("dm.channels.read"), - DiscordOAuthScope::PresencesWrite => String::from("presences.write"), - DiscordOAuthScope::DMChannelsMessagesWrite => { - String::from("dm.channels.messages.write") - } - DiscordOAuthScope::PaymentSourcesCountryCode => { - String::from("payment.sources.country_code") - } + DiscordOAuthScope::Identify => "identify", + DiscordOAuthScope::Guilds => "guilds", + DiscordOAuthScope::GuildsChannelsRead => "guilds.channels.read", + DiscordOAuthScope::Rpc => "rpc", + DiscordOAuthScope::RPCVoiceWrite => "rpc.voice.write", + DiscordOAuthScope::RPCScreenshareRead => "rpc.screenshare.read", + DiscordOAuthScope::WebhookIncoming => "webhook.incoming", + DiscordOAuthScope::ApplicationsBuildsRead => "applications.builds.read", + DiscordOAuthScope::ApplicationsEntitlements => "applications.entitlements", + DiscordOAuthScope::RelationshipsRead => "relationships.read", + DiscordOAuthScope::DMChannelsRead => "dm.channels.read", + DiscordOAuthScope::PresencesWrite => "presences.write", + DiscordOAuthScope::DMChannelsMessagesWrite => "dm.channels.messages.write", + DiscordOAuthScope::PaymentSourcesCountryCode => "payment.sources.country_code", DiscordOAuthScope::ApplicationsCommandsPermissionsUpdate => { - String::from("applications.commands.permissions.update") + "applications.commands.permissions.update" } - DiscordOAuthScope::Email => String::from("email"), - DiscordOAuthScope::GuildsJoin => String::from("guilds.join"), - DiscordOAuthScope::GDMJoin => String::from("gdm.join"), - DiscordOAuthScope::RPCNotificationsRead => String::from("rpc.notifications.read"), - DiscordOAuthScope::RPCVideoRead => String::from("rpc.video.read"), - DiscordOAuthScope::RPCScreenshareWrite => String::from("rpc.screenshare.write"), - DiscordOAuthScope::MessagesRead => String::from("messages.read"), - DiscordOAuthScope::ApplicationsCommands => String::from("applications.commands"), - DiscordOAuthScope::ActivitiesRead => String::from("activities.read"), - DiscordOAuthScope::RelationshipsWrite => String::from("relationships.write"), - DiscordOAuthScope::RoleConnectionsWrite => String::from("role.connections.write"), - DiscordOAuthScope::OpenID => String::from("openid"), - DiscordOAuthScope::GatewayConnect => String::from("gateway.connect"), - DiscordOAuthScope::SDKSocialLayer => String::from("sdk.social_layer"), - DiscordOAuthScope::RPCActivitiesWrite => String::from("rpc.activities.write"), - DiscordOAuthScope::ApplicationsBuildsUpload => { - String::from("applications.builds.upload") - } - DiscordOAuthScope::ApplicationsStoreUpdate => String::from("applications.store.update"), - DiscordOAuthScope::ActivitiesWrite => String::from("activities.write"), - DiscordOAuthScope::Voice => String::from("voice"), - DiscordOAuthScope::PresencesRead => String::from("presences.read"), - DiscordOAuthScope::DMChannelsMessagesRead => String::from("dm.channels.messages.read"), - DiscordOAuthScope::AccountGlobalNameUpdate => { - String::from("account.global.name.update") - } - DiscordOAuthScope::Connections => String::from("connections"), - DiscordOAuthScope::GuildsMembersRead => String::from("guilds.members.read"), - DiscordOAuthScope::Bot => String::from("bot"), - DiscordOAuthScope::RPCVoiceRead => String::from("rpc.voice.read"), - DiscordOAuthScope::RPCVideoWrite => String::from("rpc.video.write"), + DiscordOAuthScope::Email => "email", + DiscordOAuthScope::GuildsJoin => "guilds.join", + DiscordOAuthScope::GDMJoin => "gdm.join", + DiscordOAuthScope::RPCNotificationsRead => "rpc.notifications.read", + DiscordOAuthScope::RPCVideoRead => "rpc.video.read", + DiscordOAuthScope::RPCScreenshareWrite => "rpc.screenshare.write", + DiscordOAuthScope::MessagesRead => "messages.read", + DiscordOAuthScope::ApplicationsCommands => "applications.commands", + DiscordOAuthScope::ActivitiesRead => "activities.read", + DiscordOAuthScope::RelationshipsWrite => "relationships.write", + DiscordOAuthScope::RoleConnectionsWrite => "role.connections.write", + DiscordOAuthScope::OpenID => "openid", + DiscordOAuthScope::GatewayConnect => "gateway.connect", + DiscordOAuthScope::SDKSocialLayer => "sdk.social_layer", + DiscordOAuthScope::RPCActivitiesWrite => "rpc.activities.write", + DiscordOAuthScope::ApplicationsBuildsUpload => "applications.builds.upload", + DiscordOAuthScope::ApplicationsStoreUpdate => "applications.store.update", + DiscordOAuthScope::ActivitiesWrite => "activities.write", + DiscordOAuthScope::Voice => "voice", + DiscordOAuthScope::PresencesRead => "presences.read", + DiscordOAuthScope::DMChannelsMessagesRead => "dm.channels.messages.read", + DiscordOAuthScope::AccountGlobalNameUpdate => "account.global.name.update", + DiscordOAuthScope::Connections => "connections", + DiscordOAuthScope::GuildsMembersRead => "guilds.members.read", + DiscordOAuthScope::Bot => "bot", + DiscordOAuthScope::RPCVoiceRead => "rpc.voice.read", + DiscordOAuthScope::RPCVideoWrite => "rpc.video.write", } } } @@ -204,8 +231,8 @@ impl From<&DiscordOAuthScopes> for String { scopes .iter() - .map(String::from) - .collect::>() + .map(|scope| scope.into()) + .collect::>() .join("+") } } @@ -229,11 +256,11 @@ impl From<&DiscordOAuthScopeSet> for DiscordOAuthScopes { } } -impl TryFrom for DiscordOAuthScopeSet { +impl TryFrom<&str> for DiscordOAuthScopeSet { type Error = String; - fn try_from(value: String) -> Result { - match value.as_str() { + fn try_from(value: &str) -> Result { + match value { "verification" => Ok(DiscordOAuthScopeSet::Verification), _ => Err(format!("Invalid scope set: {}", value)), } @@ -264,7 +291,7 @@ impl<'a> DiscordAuthorizationCodeRequestBody<'a> { } } -impl<'a> DiscordTokenRefreshBody<'a> { +impl<'a> DiscordRefreshTokenBody<'a> { pub(crate) fn new(refresh_token: &'a str, cfg: &'a Config) -> Self { Self { grant_type: "refresh_token", diff --git a/src/oauth/types/mod.rs b/src/oauth/types/mod.rs index 373f61f..6e0012a 100644 --- a/src/oauth/types/mod.rs +++ b/src/oauth/types/mod.rs @@ -7,21 +7,28 @@ use rocket::Request; pub(super) mod discord; pub(super) mod roblox; +/// The OAuth callback query parameters. #[derive(Debug, FromForm)] pub(super) struct OAuthCallback { pub(super) code: String, pub(super) state: String, } +/// Result of the [`generate_session`](crate::oauth::utils::generate_session) function. pub(super) struct GeneratedSession<'a> { + /// The session cookie. pub(super) cookie: Cookie<'a>, + /// The session ID. pub(super) session_id: String, + /// The expiration date of the session. pub(super) expires_at: chrono::NaiveDateTime, } +/// A rocket request guard that extracts the session ID from the session cookie. pub(super) struct SessionId(String); impl SessionId { + /// Consumes the [`SessionId`] and returns the inner session ID. pub(super) fn into_inner(self) -> String { self.0 } @@ -32,11 +39,14 @@ impl<'r> FromRequest<'r> for SessionId { type Error = ApiError; async fn from_request(request: &'r Request<'_>) -> Outcome { + // Extract the session ID from the session cookie. let session_id = request .cookies() .get_pending(constants::cookie::SESSION_ID) .map(|cookie| cookie.value().to_string()); + // Return an error if the session ID is missing. + // Otherwise, return the session ID. match session_id { None => { let error = ApiError::message( diff --git a/src/oauth/types/roblox.rs b/src/oauth/types/roblox.rs index dda40e6..a019894 100644 --- a/src/oauth/types/roblox.rs +++ b/src/oauth/types/roblox.rs @@ -1,11 +1,17 @@ use std::fmt::{Display, Formatter, Result as FmtResult}; +/// A set of scopes that will be requested by the web app. +/// +/// The internal API should parse these scope sets into individual scopes +/// that will be requested from the OAuth2 provider pub(crate) enum RobloxOAuthScopeSet { Verification, } +/// A set of scopes that will be requested from the OAuth2 provider. pub(crate) struct RobloxOAuthScopes(pub(crate) Vec); +/// A scope that will be requested from the OAuth2 provider. pub(crate) enum RobloxOAuthScope { OpenID, Profile, @@ -39,63 +45,47 @@ pub(crate) enum RobloxOAuthScope { UserUserNotificationWrite, } -impl From<&RobloxOAuthScope> for String { +impl From<&RobloxOAuthScope> for &str { fn from(value: &RobloxOAuthScope) -> Self { match value { - RobloxOAuthScope::OpenID => String::from("openid"), - RobloxOAuthScope::Profile => String::from("profile"), - RobloxOAuthScope::AssetRead => String::from("creator-store-product:read"), - RobloxOAuthScope::AssetWrite => String::from("creator-store-product:write"), - RobloxOAuthScope::GroupRead => String::from("group:read"), - RobloxOAuthScope::GroupWrite => String::from("group:write"), - RobloxOAuthScope::LegacyBadgeManage => String::from("legacy-badge:manage"), - RobloxOAuthScope::LegacyDeveloperProductManage => { - String::from("legacy-developer-product:manage") - } - RobloxOAuthScope::LegacyGamePassManage => String::from("legacy-game-pass:manage"), - RobloxOAuthScope::LegacyGroupManage => String::from("legacy-group:manage"), - RobloxOAuthScope::LegacyTeamCollaborationManage => { - String::from("legacy-team-collaboration:manage") - } - RobloxOAuthScope::LegacyUniverseManage => String::from("legacy-universe:manage"), - RobloxOAuthScope::LegacyUniverseBadgeWrite => { - String::from("legacy-universe.badge:write") - } - RobloxOAuthScope::LegacyUniverseFollowingRead => { - String::from("legacy-universe.following:read") - } - RobloxOAuthScope::LegacyUniverseFollowingWrite => { - String::from("legacy-universe.following:write") - } - RobloxOAuthScope::LegacyUserManage => String::from("legacy-user:manage"), + RobloxOAuthScope::OpenID => "openid", + RobloxOAuthScope::Profile => "profile", + RobloxOAuthScope::AssetRead => "creator-store-product:read", + RobloxOAuthScope::AssetWrite => "creator-store-product:write", + RobloxOAuthScope::GroupRead => "group:read", + RobloxOAuthScope::GroupWrite => "group:write", + RobloxOAuthScope::LegacyBadgeManage => "legacy-badge:manage", + RobloxOAuthScope::LegacyDeveloperProductManage => "legacy-developer-product:manage", + RobloxOAuthScope::LegacyGamePassManage => "legacy-game-pass:manage", + RobloxOAuthScope::LegacyGroupManage => "legacy-group:manage", + RobloxOAuthScope::LegacyTeamCollaborationManage => "legacy-team-collaboration:manage", + RobloxOAuthScope::LegacyUniverseManage => "legacy-universe:manage", + RobloxOAuthScope::LegacyUniverseBadgeWrite => "legacy-universe.badge:write", + RobloxOAuthScope::LegacyUniverseFollowingRead => "legacy-universe.following:read", + RobloxOAuthScope::LegacyUniverseFollowingWrite => "legacy-universe.following:write", + RobloxOAuthScope::LegacyUserManage => "legacy-user:manage", RobloxOAuthScope::UniverseMessagingServicePublish => { - String::from("universe-messaging-service:publish") + "universe-messaging-service:publish" } - RobloxOAuthScope::UniverseWrite => String::from("universe:write"), - RobloxOAuthScope::UniversePlaceWrite => String::from("universe.place:write"), + RobloxOAuthScope::UniverseWrite => "universe:write", + RobloxOAuthScope::UniversePlaceWrite => "universe.place:write", RobloxOAuthScope::UniverseSubscriptionProductSubscriptionRead => { - String::from("universe.subscription-product.subscription:read") + "universe.subscription-product.subscription:read" } - RobloxOAuthScope::UniverseUserRestrictionRead => { - String::from("universe.user-restriction:read") - } - RobloxOAuthScope::UniverseUserRestrictionWrite => { - String::from("universe.user-restriction:write") - } - RobloxOAuthScope::UserAdvancedRead => String::from("user.advanced:read"), - RobloxOAuthScope::UserCommerceItemRead => String::from("user.commerce-item:read"), - RobloxOAuthScope::UserCommerceItemWrite => String::from("user.commerce-item:write"), + RobloxOAuthScope::UniverseUserRestrictionRead => "universe.user-restriction:read", + RobloxOAuthScope::UniverseUserRestrictionWrite => "universe.user-restriction:write", + RobloxOAuthScope::UserAdvancedRead => "user.advanced:read", + RobloxOAuthScope::UserCommerceItemRead => "user.commerce-item:read", + RobloxOAuthScope::UserCommerceItemWrite => "user.commerce-item:write", RobloxOAuthScope::UserCommerceMerchantConnectionRead => { - String::from("user.commerce-merchant-connection:read") + "user.commerce-merchant-connection:read" } RobloxOAuthScope::UserCommerceMerchantConnectionWrite => { - String::from("user.commerce-merchant-connection:write") - } - RobloxOAuthScope::UserInventoryItemRead => String::from("user.inventory-item:read"), - RobloxOAuthScope::UserSocialRead => String::from("user.social:read"), - RobloxOAuthScope::UserUserNotificationWrite => { - String::from("user.user-notification:write") + "user.commerce-merchant-connection:write" } + RobloxOAuthScope::UserInventoryItemRead => "user.inventory-item:read", + RobloxOAuthScope::UserSocialRead => "user.social:read", + RobloxOAuthScope::UserUserNotificationWrite => "user.user-notification:write", } } } @@ -106,8 +96,8 @@ impl From<&RobloxOAuthScopes> for String { scopes .iter() - .map(String::from) - .collect::>() + .map(|scope| scope.into()) + .collect::>() .join("%20") } } @@ -132,11 +122,11 @@ impl From<&RobloxOAuthScopeSet> for RobloxOAuthScopes { } } -impl TryFrom for RobloxOAuthScopeSet { +impl TryFrom<&str> for RobloxOAuthScopeSet { type Error = String; - fn try_from(value: String) -> Result { - match value.as_str() { + fn try_from(value: &str) -> Result { + match value { "verification" => Ok(RobloxOAuthScopeSet::Verification), _ => Err(format!("Invalid scope set: {}", value)), } diff --git a/src/oauth/utils/discord.rs b/src/oauth/utils/discord.rs index 923a83c..25b5b8f 100644 --- a/src/oauth/utils/discord.rs +++ b/src/oauth/utils/discord.rs @@ -1,13 +1,24 @@ use crate::config::Config; use crate::constants; use crate::oauth::types::discord::{ - DiscordAuthorizationCodeRequestBody, DiscordAuthorizationCodeResponse, - DiscordAuthorizedUserResponse, DiscordOAuthScopes, DiscordTokenRefreshBody, DiscordUser, + DiscordAccessTokenResponse, DiscordAuthorizationCodeRequestBody, DiscordAuthorizedUserResponse, + DiscordOAuthScopes, DiscordRefreshTokenBody, DiscordUser, }; use crate::response::ApiError; use crate::url; use rocket::http::Status; +/// Constructs the Discord OAuth URL with the given scopes and state. +/// The scopes are joined by a plus sign (`+`). +/// +/// # Arguments +/// +/// * `scopes` - The scopes to request from the user. +/// * `state` - The state to send to the Discord OAuth server. +/// +/// # Returns +/// +/// The constructed Discord OAuth URL. pub(crate) fn construct_discord_oauth_url( scopes: &DiscordOAuthScopes, state: &str, @@ -25,10 +36,20 @@ pub(crate) fn construct_discord_oauth_url( ) } +/// Exchanges the given code for an access token. +/// +/// # Arguments +/// +/// * `code` - The code to exchange for an access token. +/// * `cfg` - The application configuration. +/// +/// # Returns +/// +/// The [`DiscordAccessTokenResponse`] struct if the exchange was successful, an [`ApiError`] otherwise. pub(crate) fn exchange_code( code: &str, cfg: &Config, -) -> Result { +) -> Result { let body = DiscordAuthorizationCodeRequestBody::new(code, cfg); let response = minreq::post(constants::discord_api::TOKEN_URL) .with_header("Content-Type", "application/x-www-form-urlencoded") @@ -36,37 +57,52 @@ pub(crate) fn exchange_code( .send() .map_err(|_| ApiError::message(Status::BadGateway, "Failed to exchange code for token"))?; - response - .json::() - .map_err(|_| { - ApiError::message( - Status::InternalServerError, - "Failed to parse token response", - ) - }) + response.json::().map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to parse token response", + ) + }) } +/// Refreshes the given refresh token. +/// +/// # Arguments +/// +/// * `refresh_token` - The refresh token to use. +/// * `cfg` - The application configuration. +/// +/// # Returns +/// +/// The [`DiscordAccessTokenResponse`] struct if the refresh was successful, an [`ApiError`] otherwise. pub(crate) fn refresh_token( refresh_token: &str, cfg: &Config, -) -> Result { - let body = DiscordTokenRefreshBody::new(refresh_token, cfg); +) -> Result { + let body = DiscordRefreshTokenBody::new(refresh_token, cfg); let response = minreq::post(constants::discord_api::TOKEN_URL) .with_header("Content-Type", "application/x-www-form-urlencoded") .with_body(body.as_query_params()) .send() .map_err(|_| ApiError::message(Status::BadGateway, "Failed to refresh token"))?; - response - .json::() - .map_err(|_| { - ApiError::message( - Status::InternalServerError, - "Failed to parse token refresh response", - ) - }) + response.json::().map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to parse token refresh response", + ) + }) } +/// Gets the authorized user from the Discord API. +/// +/// # Arguments +/// +/// * `access_token` - The access token to use. +/// +/// # Returns +/// +/// The [`DiscordUser`] struct if the request was successful, an [`ApiError`] otherwise. pub(crate) fn get_authorized_user(access_token: &str) -> Result, ApiError> { minreq::get(constants::discord_api::USER_URL) .with_header("Authorization", format!("Bearer {}", access_token)) diff --git a/src/oauth/utils/mod.rs b/src/oauth/utils/mod.rs index 5137cca..a87c98e 100644 --- a/src/oauth/utils/mod.rs +++ b/src/oauth/utils/mod.rs @@ -8,18 +8,30 @@ use crate::oauth::types::GeneratedSession; use rand::{thread_rng, Rng}; use rocket::http::Cookie; +/// Generates a random state. +/// +/// The state is a random base64 URL-safe string with no padding and a length between 10 and 20 bytes. pub(super) fn generate_state() -> String { let number_of_bytes = thread_rng().gen_range(10..=20); - generate_random_base64_url_safe_no_pad_string(number_of_bytes) } +/// Generates a random session ID. +/// +/// The session ID is a random base64 URL-safe string with no padding and a length between 43 and 86 bytes. fn generate_session_id() -> String { let number_of_bytes = thread_rng().gen_range(32..=64); - generate_random_base64_url_safe_no_pad_string(number_of_bytes) } +/// Generates a new session. +/// +/// - The session ID is a random base64 URL-safe string with no padding and a length between 43 and 86 bytes. +/// - The session cookie has the session ID as the value and a max age of 30 days. +/// +/// # Returns +/// +/// A [`GeneratedSession`] struct containing the session cookie, session ID, and the expiration date. pub(super) fn generate_session() -> GeneratedSession<'static> { let session_id = generate_session_id(); let expires_at = chrono::Utc::now() + chrono::Duration::days(30); diff --git a/src/oauth/utils/pixy.rs b/src/oauth/utils/pixy.rs index cbb9776..e512de1 100644 --- a/src/oauth/utils/pixy.rs +++ b/src/oauth/utils/pixy.rs @@ -2,6 +2,11 @@ use rand::{thread_rng, Rng}; use secrecy::{ExposeSecret, SecretString}; use sha2::{Digest, Sha256}; +/// Generates a random base64 URL-safe string with no padding. +/// +/// # Arguments +/// +/// - `number_of_bytes` - The number of random bytes to generate. pub(crate) fn generate_random_base64_url_safe_no_pad_string(number_of_bytes: usize) -> String { let mut random_bytes = vec![0u8; number_of_bytes]; diff --git a/src/oauth/utils/roblox.rs b/src/oauth/utils/roblox.rs index a8da8c1..4b9d1ee 100644 --- a/src/oauth/utils/roblox.rs +++ b/src/oauth/utils/roblox.rs @@ -4,6 +4,17 @@ use crate::oauth::types::roblox::RobloxOAuthScopes; use crate::url; use secrecy::{ExposeSecret, SecretString}; +/// Constructs the Roblox OAuth URL with the given scopes and state. +/// The scopes are joined by an encoded space (`%20`). +/// +/// # Arguments +/// +/// * `scopes` - The scopes to request from the user. +/// * `state` - The state to send to the Roblox OAuth server. +/// +/// # Returns +/// +/// The constructed Roblox OAuth URL. pub(crate) fn construct_roblox_oauth_url( code_challenge_secret: &SecretString, scopes: &RobloxOAuthScopes, diff --git a/src/response.rs b/src/response.rs index 9a449b2..6fdb60d 100644 --- a/src/response.rs +++ b/src/response.rs @@ -51,11 +51,11 @@ impl ApiError { /// A new [`ApiError`] instance with the provided status code and custom message. pub(crate) fn message(status: Status, message: M) -> Self where - M: ToString, + M: Into, { Self { code: status.code, - message: message.to_string(), + message: message.into(), } } } diff --git a/src/utils.rs b/src/utils.rs index 14734bc..2185835 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,5 @@ +use crate::constants; + /// Constructs a URL with query parameters. /// /// This macro can be used to create a URL with query parameters from a domain and a list of key-value pairs, @@ -23,3 +25,13 @@ macro_rules! url { vec![$(format!("{}={}", $key, $value)),*].join("&") }; } + +/// Constructs a route with the API version. +/// +/// Does not include a slash between the API version and the route. +pub(crate) fn construct_api_route(route: R) -> String +where + R: AsRef, +{ + format!("/{}{}", constants::API_VERSION, route.as_ref()) +} From c138fe4baca34ef5e2cc3b72bef6825e121dd5ce Mon Sep 17 00:00:00 2001 From: nick <59822256+Archasion@users.noreply.github.com> Date: Tue, 3 Dec 2024 23:59:59 +0000 Subject: [PATCH 12/14] refactor: Add remaining comments --- src/database/wrappers/account_links/mod.rs | 299 +++++++++++------- src/database/wrappers/account_links/models.rs | 25 +- src/database/wrappers/account_links/schema.rs | 10 +- .../wrappers/discord_connections/mod.rs | 86 +++-- .../wrappers/discord_connections/models.rs | 74 +++-- .../wrappers/discord_connections/schema.rs | 11 + src/database/wrappers/mod.rs | 3 + src/database/wrappers/sessions/mod.rs | 88 ++++-- src/database/wrappers/sessions/models.rs | 43 ++- src/database/wrappers/sessions/schema.rs | 8 + src/oauth/routes/discord.rs | 8 +- src/oauth/routes/mod.rs | 12 +- src/oauth/utils/mod.rs | 4 +- src/oauth/utils/pixy.rs | 19 ++ 14 files changed, 450 insertions(+), 240 deletions(-) diff --git a/src/database/wrappers/account_links/mod.rs b/src/database/wrappers/account_links/mod.rs index 64a189c..da1d9ec 100644 --- a/src/database/wrappers/account_links/mod.rs +++ b/src/database/wrappers/account_links/mod.rs @@ -3,36 +3,33 @@ mod schema; use self::models::{AccountLink, NewAccountLink}; use diesel::prelude::*; -use std::borrow::Borrow; +use std::marker::PhantomData; +/// A collection of methods for interacting with the `account_links` table. pub(crate) struct AccountLinksDb; impl AccountLinksDb { - /// Creates a new account link in the database. + /// Inserts a new account link into the database. /// /// # Arguments /// - /// * `roblox_uid` - The Roblox ID of the user. - /// * `discord_uid` - The Discord ID of the user. - /// * `is_primary` - Whether the account is the primary account. + /// * `new_account_link` - The new account link to insert. + /// * `conn` - The database connection to use. /// /// # Returns /// - /// The newly created user if successful, `None` otherwise. - pub(crate) fn insert_one( - new_account_link: V, + /// The newly inserted [`AccountLink`], if successful, [`None`] if the account link already exists, + /// or an error if the operation failed. + pub(crate) fn insert_one( + new_account_link: NewAccountLink, conn: &mut PgConnection, - ) -> Result, diesel::result::Error> - where - V: Borrow, - { - use self::schema::account_links::dsl::*; - let new_account_link = new_account_link.borrow(); - - let pk = AccountLinkPk { - discord_uid: new_account_link.discord_uid.as_str(), - roblox_uid: new_account_link.roblox_uid, - }; + ) -> Result, diesel::result::Error> { + use self::schema::account_links::dsl::{account_links, discord_uid, is_primary}; + + let pk = ( + AccountLinkUserId::new(new_account_link.discord_uid.as_str()), + AccountLinkUserId::new(new_account_link.roblox_uid), + ); // Already exists if Self::find_one(pk, conn).is_some() { @@ -63,154 +60,222 @@ impl AccountLinksDb { } } - /// Deletes account links associated with `target_id` from the database. + /// Finds an account link in the database by its primary key. /// /// # Arguments /// - /// * `target_id` - ID of the account link(s) to delete. + /// * `pk` - The primary key of the account link to find. + /// * `conn` - The database connection to use. /// /// # Returns /// - /// `true` if the account link(s) were successfully deleted, `false` otherwise. - pub(crate) fn delete_many<'a, PartialPK>(target_id: PartialPK, conn: &mut PgConnection) -> bool - where - PartialPK: Borrow>, - { - use schema::account_links::dsl::*; - - match target_id.borrow() { - UserId::DiscordMarker(t_id) => diesel::delete(account_links) - .filter(discord_uid.eq(t_id)) - .execute(conn) - .is_ok(), - UserId::RobloxMarker(t_id) => diesel::delete(account_links) - .filter(roblox_uid.eq(t_id)) - .execute(conn) - .is_ok(), - } + /// The [`AccountLink`], if found, or [`None`] if no account pink with the primary key exists. + pub(crate) fn find_one( + pk: ( + AccountLinkUserId, + AccountLinkUserId<'static, RobloxMarker>, + ), + conn: &mut PgConnection, + ) -> Option { + use self::schema::account_links::dsl::{account_links, discord_uid, roblox_uid}; + let (pk_discord_uid, pk_roblox_uid) = pk; + + account_links + .filter(discord_uid.eq(pk_discord_uid.id)) + .filter(roblox_uid.eq(pk_roblox_uid.id)) + .first::(conn) + .ok() } - /// Retrieves an account link from the database. + /// Updates the primary account link in the database. /// /// # Arguments /// - /// * `discord_uid` - The Discord ID of the account link. + /// * `pk` - The primary key of the account link to update. + /// * `conn` - The database connection to use. /// /// # Returns /// - /// The account link if it exists, `None` otherwise. - pub(crate) fn find_primary( - pk_discord_uid: PartialPK, + /// `true` if the account link was successfully updated, `false` if the account link did not exist. + pub(crate) fn set_primary( + pk: ( + AccountLinkUserId, + AccountLinkUserId<'static, RobloxMarker>, + ), conn: &mut PgConnection, - ) -> Option - where - PartialPK: AsRef, - { - use schema::account_links::dsl::*; - let pk_discord_uid = pk_discord_uid.as_ref(); + ) -> bool { + use self::schema::account_links::dsl::{ + account_links, discord_uid, is_primary, roblox_uid, + }; + let (pk_discord_uid, pk_roblox_uid) = pk; + + conn.transaction(|conn| { + // Set all other account links to not primary + diesel::update(account_links) + .filter(discord_uid.eq(&pk_discord_uid.id)) + .set(is_primary.eq(false)) + .execute(conn)?; + + // Set the new primary account link + diesel::update(account_links) + .filter(discord_uid.eq(pk_discord_uid.id)) + .filter(roblox_uid.eq(pk_roblox_uid.id)) + .set(is_primary.eq(true)) + .execute(conn)?; + + diesel::result::QueryResult::Ok(()) + }) + .is_ok() + } +} + +/// Marker for Discord users. +pub(crate) struct DiscordMarker; +/// Marker for Roblox users. +pub(crate) struct RobloxMarker; + +/// Marker trait for user types. +pub(crate) trait UserMarker<'a> { + /// The ID type for the user. + type Id; +} + +impl<'a> UserMarker<'a> for DiscordMarker { + type Id = &'a str; +} + +impl UserMarker<'static> for RobloxMarker { + type Id = i64; +} + +/// Wrapper for account link user IDs. +pub(crate) struct AccountLinkUserId<'a, Marker: UserMarker<'a>> { + /// Marker for the user type. + marker: PhantomData, + /// The user ID. + id: Marker::Id, +} + +impl<'a, Marker: UserMarker<'a>> AccountLinkUserId<'a, Marker> { + /// Create a new [`AccountLinkUserId`] with the given ID. + pub(crate) fn new(id: Marker::Id) -> Self { + Self { + marker: PhantomData, + id, + } + } +} + +impl AccountLinkUserId<'_, DiscordMarker> { + /// Find the primary account link by **Discord ID**. + /// + /// # Arguments + /// + /// * `conn` - The database connection. + /// + /// # Returns + /// + /// The primary [`AccountLink`] if it exists, [`None`] otherwise. + pub(crate) fn find_primary(&self, conn: &mut PgConnection) -> Option { + use self::schema::account_links::dsl::{account_links, discord_uid, is_primary}; account_links - .filter(discord_uid.eq(pk_discord_uid)) + .filter(discord_uid.eq(&self.id)) .filter(is_primary.eq(true)) .first::(conn) .ok() } - /// Retrieve an account link by the primary key. + /// Find all account links associated with the **Discord ID**. /// /// # Arguments /// - /// * `pk` - The primary key of the account link to retrieve (Discord ID, Roblox ID). + /// * `conn` - The database connection. /// /// # Returns /// - /// The account link if it exists, `None` otherwise. - pub(crate) fn find_one<'a, PK>(pk: PK, conn: &mut PgConnection) -> Option - where - PK: Borrow>, - { - use schema::account_links::dsl::*; - let pk = pk.borrow(); + /// A vector of [`AccountLink`]s associated with the **Discord ID**. + pub(crate) fn find_many(&self, conn: &mut PgConnection) -> Vec { + use self::schema::account_links::dsl::{account_links, discord_uid}; account_links - .filter(discord_uid.eq(pk.discord_uid)) - .filter(roblox_uid.eq(pk.roblox_uid)) - .first::(conn) - .ok() + .filter(discord_uid.eq(&self.id)) + .load::(conn) + .unwrap_or_default() } - /// Retrieves all account links from the database. + /// Delete all account links associated with the **Discord ID**. /// /// # Arguments /// - /// * `target_id` - ID associated with the account links to retrieve. + /// * `conn` - The database connection. /// /// # Returns /// - /// The corresponding account links associated with the user ID. - pub(crate) fn find_many<'a, PartialPK>( - target_id: PartialPK, - conn: &mut PgConnection, - ) -> Option> - where - PartialPK: Borrow>, - { - use schema::account_links::dsl::*; - - match target_id.borrow() { - UserId::DiscordMarker(t_id) => account_links - .filter(discord_uid.eq(t_id)) - .load::(conn) - .ok(), - UserId::RobloxMarker(t_id) => account_links - .filter(roblox_uid.eq(t_id)) - .load::(conn) - .ok(), - } + /// `true` if the account links were successfully deleted, `false` otherwise. + pub(crate) fn delete_many(&self, conn: &mut PgConnection) -> bool { + use self::schema::account_links::dsl::{account_links, discord_uid}; + + diesel::delete(account_links) + .filter(discord_uid.eq(&self.id)) + .execute(conn) + .is_ok() } +} - /// Updates the primary account link for a user. +impl AccountLinkUserId<'static, RobloxMarker> { + /// Find all account links associated with the **Roblox ID**. /// /// # Arguments /// - /// * `pk` - The primary key of the account link to update (Discord ID, Roblox ID). + /// * `conn` - The database connection. /// /// # Returns /// - /// `true` if the primary account link was successfully updated, `false` otherwise. - pub(crate) fn set_primary<'a, PK>(pk: PK, conn: &mut PgConnection) -> bool - where - PK: Borrow>, - { - use schema::account_links::dsl::*; - let pk = pk.borrow(); + /// A vector of [`AccountLink`]s associated with the **Roblox ID**. + pub(crate) fn find_many(&self, conn: &mut PgConnection) -> Vec { + use self::schema::account_links::dsl::{account_links, roblox_uid}; - conn.transaction(|conn| { - // Set all other account links to not primary - diesel::update(account_links) - .filter(discord_uid.eq(pk.discord_uid)) - .set(is_primary.eq(false)) - .execute(conn)?; + account_links + .filter(roblox_uid.eq(&self.id)) + .load::(conn) + .unwrap_or_default() + } - // Set the new primary account link - diesel::update(account_links) - .filter(discord_uid.eq(pk.discord_uid)) - .filter(roblox_uid.eq(pk.roblox_uid)) - .set(is_primary.eq(true)) - .execute(conn)?; + /// Find the primary account link by **Roblox ID**. + /// + /// # Arguments + /// + /// * `conn` - The database connection. + /// + /// # Returns + /// + /// The primary [`AccountLink`] if it exists, [`None`] otherwise. + pub(crate) fn find_primary(&self, conn: &mut PgConnection) -> Option { + use self::schema::account_links::dsl::{account_links, is_primary, roblox_uid}; - diesel::result::QueryResult::Ok(()) - }) - .is_ok() + account_links + .filter(roblox_uid.eq(&self.id)) + .filter(is_primary.eq(true)) + .first::(conn) + .ok() } -} -pub(crate) enum UserId<'a> { - DiscordMarker(&'a str), - RobloxMarker(i64), -} + /// Delete all account links associated with the **Roblox ID**. + /// + /// # Arguments + /// + /// * `conn` - The database connection. + /// + /// # Returns + /// + /// `true` if the account links were successfully deleted, `false` otherwise. + pub(crate) fn delete_many(&self, conn: &mut PgConnection) -> bool { + use self::schema::account_links::dsl::{account_links, roblox_uid}; -pub(crate) struct AccountLinkPk<'a> { - pub(crate) discord_uid: &'a str, - pub(crate) roblox_uid: i64, + diesel::delete(account_links) + .filter(roblox_uid.eq(&self.id)) + .execute(conn) + .is_ok() + } } diff --git a/src/database/wrappers/account_links/models.rs b/src/database/wrappers/account_links/models.rs index 1e3cd47..f6fb8a2 100644 --- a/src/database/wrappers/account_links/models.rs +++ b/src/database/wrappers/account_links/models.rs @@ -1,15 +1,25 @@ use super::schema; use diesel::prelude::*; +/// Represents a link between a Discord account and a Roblox account. #[derive(Queryable, Selectable, Debug)] #[diesel(table_name = schema::account_links)] #[diesel(check_for_backend(diesel::pg::Pg))] pub(crate) struct AccountLink { + /// The user's Roblox ID. + /// + /// Composite primary key with `discord_uid`. pub(crate) roblox_uid: i64, + /// The user's Discord ID. + /// + /// Composite primary key with `roblox_uid`. pub(crate) discord_uid: String, + /// Whether this link is the user's primary account. pub(crate) is_primary: bool, } +/// Represents a new account link to be inserted into the database. +/// See [`AccountLink`] for field definitions. #[derive(Insertable, Debug)] #[diesel(table_name = schema::account_links)] pub(crate) struct NewAccountLink { @@ -18,6 +28,8 @@ pub(crate) struct NewAccountLink { pub(crate) is_primary: bool, } +/// Represents an update to an account link in the database. +/// See [`AccountLink`] for field definitions. #[derive(Debug, Default)] pub(crate) struct AccountLinkBuilder { roblox_uid: Option, @@ -26,35 +38,38 @@ pub(crate) struct AccountLinkBuilder { } impl NewAccountLink { + /// Creates a new [`AccountLinkBuilder`] instance. pub(crate) fn build() -> AccountLinkBuilder { AccountLinkBuilder::default() } } impl AccountLinkBuilder { + /// Creates a new [`AccountLinkBuilder`] instance. pub(crate) fn new() -> Self { Self::default() } + /// Sets the Roblox UID for the account link. pub(crate) fn roblox_uid(mut self, roblox_uid: i64) -> Self { self.roblox_uid = Some(roblox_uid); self } - pub(crate) fn discord_uid(mut self, discord_uid: S) -> Self - where - S: Into, - { - self.discord_uid = Some(discord_uid.into()); + /// Sets the Discord UID for the account link. + pub(crate) fn discord_uid(mut self, discord_uid: String) -> Self { + self.discord_uid = Some(discord_uid); self } + /// Sets whether the account link is the user's primary account. #[allow(clippy::wrong_self_convention)] pub(crate) fn is_primary(mut self, is_primary: bool) -> Self { self.is_primary = Some(is_primary); self } + /// Builds the [`NewAccountLink`] instance. pub(crate) fn build(self) -> NewAccountLink { NewAccountLink { roblox_uid: self.roblox_uid.expect("roblox_uid is required"), diff --git a/src/database/wrappers/account_links/schema.rs b/src/database/wrappers/account_links/schema.rs index c73ea2a..09cd06f 100644 --- a/src/database/wrappers/account_links/schema.rs +++ b/src/database/wrappers/account_links/schema.rs @@ -1,7 +1,15 @@ diesel::table! { - account_links(discord_uid, roblox_uid){ + /// Represents a link between a Discord account and a Roblox account. + account_links (discord_uid, roblox_uid) { + /// The user's Roblox ID. + /// + /// Composite primary key with `discord_uid`. roblox_uid -> BigInt, + /// The user's Discord ID. + /// + /// Composite primary key with `roblox_uid`. discord_uid -> Text, + /// Whether this link is the user's primary account. is_primary -> Bool, } } diff --git a/src/database/wrappers/discord_connections/mod.rs b/src/database/wrappers/discord_connections/mod.rs index ee2b066..08c8e5e 100644 --- a/src/database/wrappers/discord_connections/mod.rs +++ b/src/database/wrappers/discord_connections/mod.rs @@ -4,20 +4,26 @@ mod schema; use self::models::{DiscordConnection, NewDiscordConnection}; use crate::database::wrappers::discord_connections::models::UpdateDiscordConnection; use diesel::prelude::*; -use std::borrow::Borrow; +/// A collection of methods for interacting with the `discord_connections` table. pub(crate) struct DiscordConnectionsDb; impl DiscordConnectionsDb { - pub(crate) fn insert_one( - new_conn: V, + /// Inserts a new discord connection into the database. + /// + /// # Arguments + /// + /// * `new_conn` - The new discord connection to insert. + /// * `conn` - The database connection to use. + /// + /// # Returns + /// + /// The newly inserted [`DiscordConnection`], if successful, [`None`] if the connection already exists, + pub(crate) fn insert_one( + new_conn: NewDiscordConnection, conn: &mut PgConnection, - ) -> Result, diesel::result::Error> - where - V: Borrow, - { + ) -> Result, diesel::result::Error> { use self::schema::discord_connections; - let new_conn = new_conn.borrow(); // Already exists if Self::find_one(&new_conn.uid, conn).is_some() { @@ -31,12 +37,18 @@ impl DiscordConnectionsDb { .map(Some) } - pub(crate) fn find_one(pk_uid: PK, conn: &mut PgConnection) -> Option - where - PK: AsRef, - { - use schema::discord_connections::dsl::*; - let pk_uid = pk_uid.as_ref(); + /// Finds a discord connection in the database by its primary key. + /// + /// # Arguments + /// + /// * `pk_uid` - The primary key of the discord connection to find. + /// * `conn` - The database connection to use. + /// + /// # Returns + /// + /// The [`DiscordConnection`], if found, or [`None`] if no connection with the primary key exists. + pub(crate) fn find_one(pk_uid: &str, conn: &mut PgConnection) -> Option { + use self::schema::discord_connections::dsl::{discord_connections, uid}; discord_connections .filter(uid.eq(pk_uid)) @@ -44,12 +56,18 @@ impl DiscordConnectionsDb { .ok() } - pub(crate) fn delete_one(pk_uid: PK, conn: &mut PgConnection) -> bool - where - PK: AsRef, - { - use schema::discord_connections::dsl::*; - let pk_uid = pk_uid.as_ref(); + /// Deletes a discord connection from the database by its primary key. + /// + /// # Arguments + /// + /// * `pk_uid` - The primary key of the discord connection to delete. + /// * `conn` - The database connection to use. + /// + /// # Returns + /// + /// `true` if the operation was successful, `false` otherwise. + pub(crate) fn delete_one(pk_uid: &str, conn: &mut PgConnection) -> bool { + use self::schema::discord_connections::dsl::*; diesel::delete(discord_connections) .filter(uid.eq(pk_uid)) @@ -57,19 +75,23 @@ impl DiscordConnectionsDb { .is_ok() } - pub(crate) fn update_one( - pk_uid: PK, - new_conn: V, + /// Updates a discord connection in the database by its primary key. + /// + /// # Arguments + /// + /// * `pk_uid` - The primary key of the discord connection to update. + /// * `new_conn` - The new discord connection to update. + /// * `conn` - The database connection to use. + /// + /// # Returns + /// + /// The updated [`DiscordConnection`], if successful, or [`None`] if the connection did not exist. + pub(crate) fn update_one( + pk_uid: &str, + new_conn: UpdateDiscordConnection, conn: &mut PgConnection, - ) -> Option - where - PK: AsRef, - V: Borrow, - { - use schema::discord_connections::dsl::*; - - let pk_uid = pk_uid.as_ref(); - let new_conn = new_conn.borrow(); + ) -> Option { + use self::schema::discord_connections::dsl::{discord_connections, uid}; diesel::update(discord_connections) .filter(uid.eq(pk_uid)) diff --git a/src/database/wrappers/discord_connections/models.rs b/src/database/wrappers/discord_connections/models.rs index 14fc634..d8b6ea9 100644 --- a/src/database/wrappers/discord_connections/models.rs +++ b/src/database/wrappers/discord_connections/models.rs @@ -1,19 +1,32 @@ use super::schema; use diesel::prelude::*; +/// Represents an authorized Discord connection in the database. #[derive(Queryable, Selectable, Debug)] #[diesel(table_name = schema::discord_connections)] #[diesel(check_for_backend(diesel::pg::Pg))] pub(crate) struct DiscordConnection { + /// The unique identifier for the connection. + /// This is the same as the user's Discord ID. + /// + /// Primary key. pub(crate) uid: String, + /// The user's access token. pub(crate) access_token: String, + /// The nonce that was used to encrypt the access token. pub(crate) access_token_nonce: String, + /// The user's refresh token. pub(crate) refresh_token: String, + /// The nonce that was used to encrypt the refresh token. pub(crate) refresh_token_nonce: String, + /// The time at which the access token expires. pub(crate) expires_at: chrono::NaiveDateTime, + /// The scopes granted by the user. pub(crate) scope: String, } +/// Represents a new Discord connection to be inserted into the database. +/// See [`DiscordConnection`] for field definitions. #[derive(Insertable, Debug)] #[diesel(table_name = schema::discord_connections)] pub(crate) struct NewDiscordConnection { @@ -26,6 +39,8 @@ pub(crate) struct NewDiscordConnection { pub(crate) scope: String, } +/// Represents an update to a Discord connection in the database. +/// See [`DiscordConnection`] for field definitions. #[derive(AsChangeset, Debug)] #[diesel(table_name = schema::discord_connections)] pub(crate) struct UpdateDiscordConnection { @@ -37,6 +52,10 @@ pub(crate) struct UpdateDiscordConnection { pub(crate) scope: Option, } +/// A builder for creating new or updating existing Discord connections. +/// See [`DiscordConnection`] for field definitions. +/// +/// This struct provides a builder pattern for constructing [`NewDiscordConnection`] and [`UpdateDiscordConnection`] instances. #[derive(Default)] pub(crate) struct DiscordConnectionBuilder { uid: Option, @@ -49,59 +68,58 @@ pub(crate) struct DiscordConnectionBuilder { } impl NewDiscordConnection { + /// Creates a new [`DiscordConnectionBuilder`] instance. pub(crate) fn build() -> DiscordConnectionBuilder { DiscordConnectionBuilder::default() } } impl UpdateDiscordConnection { + /// Creates a new [`DiscordConnectionBuilder`] instance. pub(crate) fn build() -> DiscordConnectionBuilder { DiscordConnectionBuilder::default() } } impl DiscordConnectionBuilder { + /// Creates a new [`DiscordConnectionBuilder`] instance. pub(crate) fn new() -> Self { Self::default() } - pub(crate) fn uid(mut self, uid: S) -> Self - where - S: Into, - { - self.uid = Some(uid.into()); + /// Sets the unique identifier for the connection. + /// This is the same as the user's Discord ID. + pub(crate) fn uid(mut self, uid: String) -> Self { + self.uid = Some(uid); self } - pub(crate) fn access_token(mut self, access_token: S) -> Self - where - S: Into, - { - self.access_token = Some(access_token.into()); + /// Sets the user's access token. + pub(crate) fn access_token(mut self, access_token: String) -> Self +where { + self.access_token = Some(access_token); self } + /// Sets the time at which the access token expires. pub(crate) fn expires_at(mut self, expires_at: chrono::NaiveDateTime) -> Self { self.expires_at = Some(expires_at); self } - pub(crate) fn refresh_token(mut self, refresh_token: S) -> Self - where - S: Into, - { - self.refresh_token = Some(refresh_token.into()); + /// Sets the user's refresh token. + pub(crate) fn refresh_token(mut self, refresh_token: String) -> Self { + self.refresh_token = Some(refresh_token); self } - pub(crate) fn scope(mut self, scope: S) -> Self - where - S: Into, - { - self.scope = Some(scope.into()); + /// Sets the scopes granted by the user. + pub(crate) fn scope(mut self, scope: String) -> Self { + self.scope = Some(scope); self } + /// Encrypts the access and refresh tokens. fn encrypt_tokens(&mut self) -> Result<(), String> { // Encrypt the access token if it exists if let Some(access_token) = &self.access_token { @@ -122,6 +140,7 @@ impl DiscordConnectionBuilder { Ok(()) } + /// Builds a new [`NewDiscordConnection`] instance. pub(crate) fn build(mut self) -> Result { self.encrypt_tokens()?; @@ -140,6 +159,7 @@ impl DiscordConnectionBuilder { }) } + /// Builds a new [`UpdateDiscordConnection`] instance. pub(crate) fn build_update(mut self) -> Result { self.encrypt_tokens()?; @@ -170,11 +190,11 @@ mod tests { let now = chrono::Utc::now().naive_utc(); let conn = DiscordConnectionBuilder::new() - .uid(constants::test::UID) - .access_token(constants::test::ACCESS_TOKEN) + .uid(constants::test::UID.to_string()) + .access_token(constants::test::ACCESS_TOKEN.to_string()) .expires_at(now) - .refresh_token(constants::test::REFRESH_TOKEN) - .scope(constants::test::SCOPE) + .refresh_token(constants::test::REFRESH_TOKEN.to_string()) + .scope(constants::test::SCOPE.to_string()) .build() .unwrap(); @@ -199,10 +219,10 @@ mod tests { let now = chrono::Utc::now().naive_utc(); let conn = DiscordConnectionBuilder::new() - .access_token(constants::test::ACCESS_TOKEN) + .access_token(constants::test::ACCESS_TOKEN.to_string()) .expires_at(now) - .refresh_token(constants::test::REFRESH_TOKEN) - .scope(constants::test::SCOPE) + .refresh_token(constants::test::REFRESH_TOKEN.to_string()) + .scope(constants::test::SCOPE.to_string()) .build_update() .unwrap(); diff --git a/src/database/wrappers/discord_connections/schema.rs b/src/database/wrappers/discord_connections/schema.rs index 6c43b7e..6dbda21 100644 --- a/src/database/wrappers/discord_connections/schema.rs +++ b/src/database/wrappers/discord_connections/schema.rs @@ -1,11 +1,22 @@ diesel::table! { + /// Represents an authorized Discord connection. discord_connections (uid) { + /// The unique identifier for the connection. + /// This is the same as the user's Discord ID. + /// + /// Primary key. uid -> Text, + /// The user's access token. access_token -> Text, + /// The nonce that was used to encrypt the access token. access_token_nonce -> Text, + /// The user's refresh token. refresh_token -> Text, + /// The nonce that was used to encrypt the refresh token. refresh_token_nonce -> Text, + /// The time at which the access token expires. expires_at -> Timestamp, + /// The scopes granted by the user. scope -> Text, } } diff --git a/src/database/wrappers/mod.rs b/src/database/wrappers/mod.rs index a55880f..f397409 100644 --- a/src/database/wrappers/mod.rs +++ b/src/database/wrappers/mod.rs @@ -1,3 +1,6 @@ pub(crate) mod account_links; pub(crate) mod discord_connections; pub(crate) mod sessions; + +/// The error type for operations on the database. +type Error = diesel::result::Error; diff --git a/src/database/wrappers/sessions/mod.rs b/src/database/wrappers/sessions/mod.rs index 6ed595a..c3fed42 100644 --- a/src/database/wrappers/sessions/mod.rs +++ b/src/database/wrappers/sessions/mod.rs @@ -4,20 +4,27 @@ mod schema; use self::models::{NewSession, Session}; use crate::database::wrappers::sessions::models::UpdateSession; use diesel::prelude::*; -use std::borrow::Borrow; +/// A collection of methods for interacting with the `sessions` table. pub(crate) struct SessionsDb; impl SessionsDb { - pub(crate) fn insert_one( - new_session: V, + /// Inserts a new session into the database. + /// + /// # Parameters + /// + /// - `new_session` - The new session to insert. + /// - `conn` - The database connection to use. + /// + /// # Returns + /// + /// The newly inserted [`Session`], if successful, [`None`] if the session already exists, + /// or an error if the operation failed. + pub(crate) fn insert_one( + new_session: NewSession, conn: &mut PgConnection, - ) -> Result, diesel::result::Error> - where - V: Borrow, - { + ) -> Result, super::Error> { use self::schema::sessions; - let new_session = new_session.borrow(); // Already exists if Self::find_one(&new_session.session_id, conn).is_some() { @@ -31,12 +38,18 @@ impl SessionsDb { .map(Some) } - pub(crate) fn find_one(pk_session_id: PK, conn: &mut PgConnection) -> Option - where - PK: AsRef, - { - use schema::sessions::dsl::*; - let pk_session_id = pk_session_id.as_ref(); + /// Finds a session in the database by its primary key. + /// + /// # Parameters + /// + /// - `pk_session_id` - The primary key of the session to find. + /// - `conn` - The database connection to use. + /// + /// # Returns + /// + /// The [`Session`], if found, or [`None`] if no session with the primary key exists. + pub(crate) fn find_one(pk_session_id: &str, conn: &mut PgConnection) -> Option { + use self::schema::sessions::dsl::{session_id, sessions}; sessions .filter(session_id.eq(pk_session_id)) @@ -44,37 +57,46 @@ impl SessionsDb { .ok() } - pub(crate) fn delete_one(pk_session_id: PK, conn: &mut PgConnection) -> bool - where - PK: AsRef, - { - use schema::sessions::dsl::*; - let pk_session_id = pk_session_id.as_ref(); + /// Deletes a session from the database by its primary key. + /// + /// # Parameters + /// + /// - `pk_session_id` - The primary key of the session to delete. + /// - `conn` - The database connection to use. + /// + /// # Returns + /// + /// `true` if the session was successfully deleted, `false` if the session did not exist. + pub(crate) fn delete_one(pk_session_id: &str, conn: &mut PgConnection) -> bool { + use self::schema::sessions::dsl::{session_id, sessions}; diesel::delete(sessions.filter(session_id.eq(pk_session_id))) .execute(conn) .is_ok() } - pub(crate) fn update_one( - pk_session_id: PK, - updated_session: V, + /// Updates a session in the database by its primary key. + /// + /// # Parameters + /// + /// - `pk_session_id` - The primary key of the session to update. + /// - `updated_session` - The updated session data. + /// - `conn` - The database connection to use. + /// + /// # Returns + /// + /// The updated session, if successful, or [`None`] if the session did not exist. + pub(crate) fn update_one( + pk_session_id: &str, + updated_session: UpdateSession, conn: &mut PgConnection, - ) -> Option - where - PK: AsRef, - V: Borrow, - { - use schema::sessions::dsl::*; - - let pk_session_id = pk_session_id.as_ref(); - let updated_session = updated_session.borrow(); + ) -> Result { + use self::schema::sessions::dsl::{session_id, sessions}; diesel::update(sessions) .filter(session_id.eq(pk_session_id)) .set(updated_session) .returning(Session::as_returning()) .get_result(conn) - .ok() } } diff --git a/src/database/wrappers/sessions/models.rs b/src/database/wrappers/sessions/models.rs index 3cf3760..854d98b 100644 --- a/src/database/wrappers/sessions/models.rs +++ b/src/database/wrappers/sessions/models.rs @@ -1,15 +1,25 @@ use super::schema; use diesel::prelude::*; +/// Represents a session in the database. #[derive(Queryable, Selectable, Debug)] #[diesel(table_name = schema::sessions)] #[diesel(check_for_backend(diesel::pg::Pg))] pub(crate) struct Session { + /// The unique identifier for the session. + /// + /// Primary key. pub(crate) session_id: String, + /// The Discord user ID associated with the session. + /// + /// Foreign key to the `discord_connections` table. pub(crate) discord_uid: String, + /// The expiration time of the session. pub(crate) expires_at: chrono::NaiveDateTime, } +/// Represents a new session to be inserted into the database. +/// See [`Session`] for field definitions. #[derive(Insertable, Debug)] #[diesel(table_name = schema::sessions)] pub(crate) struct NewSession { @@ -18,6 +28,8 @@ pub(crate) struct NewSession { pub(crate) expires_at: chrono::NaiveDateTime, } +/// Represents an update to an existing session in the database. +/// See [`Session`] for field definitions. #[derive(AsChangeset, Debug)] #[diesel(table_name = schema::sessions)] pub(crate) struct UpdateSession { @@ -26,6 +38,10 @@ pub(crate) struct UpdateSession { pub(crate) expires_at: Option, } +/// A builder for creating new or updating existing session records. +/// See [`Session`] for field definitions. +/// +/// This struct provides a builder pattern for constructing [`NewSession`] and [`UpdateSession`] instances. #[derive(Debug, Default)] pub(crate) struct SessionBuilder { session_id: Option, @@ -34,43 +50,49 @@ pub(crate) struct SessionBuilder { } impl NewSession { + /// Creates a new [`SessionBuilder`] instance. pub(crate) fn build() -> SessionBuilder { SessionBuilder::default() } } impl UpdateSession { + /// Creates a new [`SessionBuilder`] instance. pub(crate) fn build() -> SessionBuilder { SessionBuilder::default() } } impl SessionBuilder { + /// Creates a new [`SessionBuilder`] instance. pub(crate) fn new() -> Self { Self::default() } - pub(crate) fn discord_uid(mut self, discord_uid: S) -> Self - where - S: Into, - { - self.discord_uid = Some(discord_uid.into()); + /// Sets the Discord user ID for the session. + pub(crate) fn discord_uid(mut self, discord_uid: String) -> Self { + self.discord_uid = Some(discord_uid); self } - pub(crate) fn session_id(mut self, session_id: S) -> Self - where - S: Into, - { - self.session_id = Some(session_id.into()); + /// Sets the session ID for the session. + pub(crate) fn session_id(mut self, session_id: String) -> Self { + self.session_id = Some(session_id); self } + /// Sets the expiration time for the session. pub(crate) fn expires_at(mut self, expires_at: chrono::NaiveDateTime) -> Self { self.expires_at = Some(expires_at); self } + /// Builds the [`NewSession`] instance. + /// + /// # Panics + /// + /// Panics if any of the required fields are not set. + /// See [`NewSession`] for more information. pub(crate) fn build(self) -> NewSession { NewSession { discord_uid: self.discord_uid.expect("discord_uid not set"), @@ -79,6 +101,7 @@ impl SessionBuilder { } } + /// Builds the [`UpdateSession`] instance. pub(crate) fn build_update(self) -> UpdateSession { UpdateSession { discord_uid: self.discord_uid, diff --git a/src/database/wrappers/sessions/schema.rs b/src/database/wrappers/sessions/schema.rs index 7a0dc8e..b463e55 100644 --- a/src/database/wrappers/sessions/schema.rs +++ b/src/database/wrappers/sessions/schema.rs @@ -1,7 +1,15 @@ diesel::table! { + /// Represents a session in the database. sessions (session_id) { + /// The unique identifier for the session. + /// + /// Primary key. session_id -> Text, + /// The Discord user ID associated with the session. + /// + /// Foreign key to the `discord_connections` table. discord_uid -> Text, + /// The expiration time of the session. expires_at -> Timestamp, } } diff --git a/src/oauth/routes/discord.rs b/src/oauth/routes/discord.rs index 19df39b..c04eee8 100644 --- a/src/oauth/routes/discord.rs +++ b/src/oauth/routes/discord.rs @@ -175,14 +175,14 @@ pub(super) async fn discord_refresh_token( // Fetch the Discord UID from the session token let discord_uid = conn - .run(|conn| SessionsDb::find_one(session_id, conn)) + .run(move |conn| SessionsDb::find_one(&session_id, conn)) .await .map(|session| session.discord_uid) .ok_or(ApiError::message(Status::Unauthorized, "Session not found"))?; // Fetch the Discord connection from the database let discord_connection = conn - .run(|conn| DiscordConnectionsDb::find_one(discord_uid, conn)) + .run(move |conn| DiscordConnectionsDb::find_one(&discord_uid, conn)) .await .ok_or(ApiError::message( Status::InternalServerError, @@ -220,8 +220,8 @@ pub(super) async fn discord_refresh_token( // Update the Discord connection in the database let success = conn - .run(|conn| { - DiscordConnectionsDb::update_one(discord_connection.uid, new_discord_connection, conn) + .run(move |conn| { + DiscordConnectionsDb::update_one(&discord_connection.uid, new_discord_connection, conn) .is_some() }) .await; diff --git a/src/oauth/routes/mod.rs b/src/oauth/routes/mod.rs index 2241092..312f8c6 100644 --- a/src/oauth/routes/mod.rs +++ b/src/oauth/routes/mod.rs @@ -43,15 +43,9 @@ async fn refresh_session(jar: &CookieJar<'_>, conn: DbConn, session_id: SessionI .build_update(); // Update the session in the database - conn.run(move |conn| { - let Some(session) = SessionsDb::update_one(session_id.as_str(), updated_session, conn) - else { - return diesel::result::QueryResult::Err(diesel::result::Error::NotFound); - }; - diesel::result::QueryResult::Ok(session) - }) - .await - .map_err(|_| ApiError::message(Status::InternalServerError, "Failed to update session"))?; + conn.run(move |conn| SessionsDb::update_one(session_id.as_str(), updated_session, conn)) + .await + .map_err(|_| ApiError::message(Status::InternalServerError, "Failed to update session"))?; // Update the session cookie jar.add_private(session.cookie); diff --git a/src/oauth/utils/mod.rs b/src/oauth/utils/mod.rs index a87c98e..409cbed 100644 --- a/src/oauth/utils/mod.rs +++ b/src/oauth/utils/mod.rs @@ -18,7 +18,7 @@ pub(super) fn generate_state() -> String { /// Generates a random session ID. /// -/// The session ID is a random base64 URL-safe string with no padding and a length between 43 and 86 bytes. +/// The session ID is a random base64 URL-safe string with no padding and a length between 32 and 64 bytes. fn generate_session_id() -> String { let number_of_bytes = thread_rng().gen_range(32..=64); generate_random_base64_url_safe_no_pad_string(number_of_bytes) @@ -26,7 +26,7 @@ fn generate_session_id() -> String { /// Generates a new session. /// -/// - The session ID is a random base64 URL-safe string with no padding and a length between 43 and 86 bytes. +/// - The session ID is a random base64 URL-safe string with no padding and a length between 32 and 64 bytes. /// - The session cookie has the session ID as the value and a max age of 30 days. /// /// # Returns diff --git a/src/oauth/utils/pixy.rs b/src/oauth/utils/pixy.rs index e512de1..1467053 100644 --- a/src/oauth/utils/pixy.rs +++ b/src/oauth/utils/pixy.rs @@ -15,6 +15,8 @@ pub(crate) fn generate_random_base64_url_safe_no_pad_string(number_of_bytes: usi base64::encode_config(&random_bytes, base64::URL_SAFE_NO_PAD) } +/// Generates a random base64 URL-safe string with no padding and a length between 32 and 96 bytes. +/// The resulting string is wrapped in a [`SecretString`]. fn generate_verifier() -> SecretString { let number_of_bytes = thread_rng().gen_range(32..=96); @@ -23,6 +25,17 @@ fn generate_verifier() -> SecretString { )) } +/// Calculates the challenge from the verifier. +/// The challenge is the SHA-256 hash of the verifier. +/// The challenge is then base64 encoded with URL-safe characters and no padding. +/// +/// # Arguments +/// +/// - `verifier` - The verifier to calculate the challenge from. +/// +/// # Returns +/// +/// The challenge as a base64 URL-safe string with no padding. fn calculate_challenge_from_verifier(verifier: &SecretString) -> SecretString { let verifier_hash = Sha256::digest(verifier.expose_secret().as_bytes()); @@ -32,12 +45,16 @@ fn calculate_challenge_from_verifier(verifier: &SecretString) -> SecretString { )) } +/// A struct that represents the Pixy PKCE method. pub(crate) struct Pixy { + /// The challenge generated by the Pixy method. challenge: SecretString, + /// The verifier generated by the Pixy method. verifier: SecretString, } impl Pixy { + /// Creates a new Pixy instance. pub(crate) fn new() -> Self { let verifier = generate_verifier(); let challenge = calculate_challenge_from_verifier(&verifier); @@ -48,10 +65,12 @@ impl Pixy { } } + /// Exposes the verifier as a string slice pub(crate) fn expose_verifier(&self) -> &str { self.verifier.expose_secret() } + /// Reference getter for the challenge pub(crate) fn get_challenge(&self) -> &SecretString { &self.challenge } From 969cca735aa7ee24d48a2ddc5256a610026554d9 Mon Sep 17 00:00:00 2001 From: maefall Date: Sat, 7 Dec 2024 18:40:30 +0100 Subject: [PATCH 13/14] Roblox OAuth (#12) * stuff * Roblox OAuth * Remove an unnecessary dependency * Flake: delete_dev_db on exit * refactor(dependencies): Uninstall `secrecy` crate * refactor: Add comments * chore(cfg): Add example config * refactor(fmt): Run formatter * refactor(session-request-guard): Streamline implementation of the session request guard by verifying that the session ID is stored in the database and returning the database entry * fix(discord-api): Correct endpoints * refactor(env-example): Add example for `ROBLOX_CLIENT_SECRET` * chore(routes): Add roblox oauth2 routes to route method * refactor(roblox-connections): Make roblox_uid a foreign key * fix(roblox-token-request-body): Use roblox client ID instead of Discord client ID * fix(db): Add timestamp to migration * refactor(db-wrappers): Create upsert methods * docs(roblox-oauth): Document routes * refactor(discord-callback): Redirect user to success page on Ok * refactor(fmt): Run formatter --------- Co-authored-by: nick <59822256+Archasion@users.noreply.github.com> --- .example.config.toml | 13 + .example.env | 3 + Cargo.toml | 1 - README.md | 4 +- flake.nix | 6 +- .../down.sql | 1 + .../up.sql | 11 + .../up.sql | 3 +- src/constants.rs | 10 +- src/database/wrappers/account_links/mod.rs | 122 +++++---- src/database/wrappers/account_links/models.rs | 30 ++- src/database/wrappers/account_links/schema.rs | 2 +- .../wrappers/discord_connections/mod.rs | 39 ++- .../wrappers/discord_connections/models.rs | 16 +- src/database/wrappers/mod.rs | 1 + .../wrappers/roblox_connections/mod.rs | 120 +++++++++ .../wrappers/roblox_connections/models.rs | 243 ++++++++++++++++++ .../wrappers/roblox_connections/schema.rs | 21 ++ src/database/wrappers/sessions/mod.rs | 56 ++++ src/oauth/routes/discord.rs | 46 ++-- src/oauth/routes/mod.rs | 37 +-- src/oauth/routes/roblox.rs | 131 +++++++--- src/oauth/types/discord.rs | 8 +- src/oauth/types/mod.rs | 42 +-- src/oauth/types/roblox.rs | 80 +++++- src/oauth/utils/discord.rs | 2 +- src/oauth/utils/mod.rs | 3 +- src/oauth/utils/pixy.rs | 32 +-- src/oauth/utils/roblox.rs | 73 +++++- 29 files changed, 922 insertions(+), 234 deletions(-) create mode 100644 .example.config.toml create mode 100644 migrations/2024-11-21-101457_create_roblox_connections/down.sql create mode 100644 migrations/2024-11-21-101457_create_roblox_connections/up.sql create mode 100644 src/database/wrappers/roblox_connections/mod.rs create mode 100644 src/database/wrappers/roblox_connections/models.rs create mode 100644 src/database/wrappers/roblox_connections/schema.rs diff --git a/.example.config.toml b/.example.config.toml new file mode 100644 index 0000000..0ea4b14 --- /dev/null +++ b/.example.config.toml @@ -0,0 +1,13 @@ +[cors] +allowed_origins_exact = ["http://localhost:8000"] +allowed_origins_regex = ["^http://localhost:\\d{4}$"] +allowed_methods = ["GET", "POST"] +allow_credentials = true + +[oauth.discord] +client_id = "DISCORD_CLIENT_ID" +redirect_uri = "http://localhost:8000/v1/oauth2/discord/callback" + +[oauth.roblox] +client_id = "ROBLOX_CLIENT_ID" +redirect_uri = "http://localhost:8000/v1/oauth2/roblox/callback" \ No newline at end of file diff --git a/.example.env b/.example.env index adcae2d..d4b253a 100644 --- a/.example.env +++ b/.example.env @@ -1,6 +1,9 @@ # Discord OAuth2 client secret, you can get this from the Discord Developer Portal DISCORD_CLIENT_SECRET="your_discord_client_secret" +# Roblox OAuth2 client secret, you can get this from the Roblox Creator Portal +ROBLOX_CLIENT_SECRET="your_roblox_client_secret" + # A 32 byte encryption key, you can generate one with `openssl rand -base64 24` ENCRYPTION_KEY="your_encryption_key" diff --git a/Cargo.toml b/Cargo.toml index 4222b7a..ce0efc5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,6 @@ rocket_sync_db_pools = { version = "0.1.0", default-features = false, features = base64-compat = { version = "1.0.0", default-features = false } rand = { version = "0.8.0", default-features = false } sha2 = { version = "0.10.7", default-features = false } -secrecy = { version = "0.10.3", default-features = false } minreq = { version = "2.12.0", default-features = false, features = ["json-using-serde", "https"] } chrono = { version = "0.4.38", features = ["serde"] } diesel = { version = "2.1.4", features = ["postgres", "serde_json", "chrono"] } diff --git a/README.md b/README.md index 9e073db..3c60be5 100644 --- a/README.md +++ b/README.md @@ -39,12 +39,12 @@ sensitive data should be stored in the environment variables. # Client ID (required) client_id = "CLIENT_ID" # Callback URL (required) -redirect_uri = "http://localhost:8000/v1/oauth2/callback/discord" +redirect_uri = "http://localhost:8000/v1/oauth2/discord/callback" # Roblox OAuth2 settings (required) [oauth.roblox] # Client ID (required) client_id = "CLIENT_ID" # Callback URL (required) -redirect_uri = "http://localhost:8000/v1/oauth2/callback/roblox" +redirect_uri = "http://localhost:8000/v1/oauth2/roblox/callback" ``` \ No newline at end of file diff --git a/flake.nix b/flake.nix index 3eb903c..0fc1577 100644 --- a/flake.nix +++ b/flake.nix @@ -54,6 +54,8 @@ export DATABASE_URL="postgres://$PGUSER:$PGPASSWORD@$PGHOST:$PGPORT/$db_name" export ROCKET_DATABASES="{roops={url=\"$DATABASE_URL\"}}" + + alias dblens="npx dblens $DATABASE_URL" function start_dev_db { pg_ctl -D "$PGDATA" -l "$PGDATA/logfile" start @@ -84,7 +86,9 @@ stop_dev_db } - ''; + + trap "delete_dev_db" EXIT + ''; buildInputs = runtimeDeps; nativeBuildInputs = buildDeps ++ devDeps ++ [ rustc ]; }; diff --git a/migrations/2024-11-21-101457_create_roblox_connections/down.sql b/migrations/2024-11-21-101457_create_roblox_connections/down.sql new file mode 100644 index 0000000..e6737d6 --- /dev/null +++ b/migrations/2024-11-21-101457_create_roblox_connections/down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS roblox_connections; diff --git a/migrations/2024-11-21-101457_create_roblox_connections/up.sql b/migrations/2024-11-21-101457_create_roblox_connections/up.sql new file mode 100644 index 0000000..cd8bb04 --- /dev/null +++ b/migrations/2024-11-21-101457_create_roblox_connections/up.sql @@ -0,0 +1,11 @@ +-- create +CREATE TABLE IF NOT EXISTS roblox_connections +( + uid TEXT PRIMARY KEY, + access_token TEXT UNIQUE NOT NULL, + access_token_nonce TEXT UNIQUE NOT NULL, + refresh_token TEXT UNIQUE NOT NULL, + refresh_token_nonce TEXT UNIQUE NOT NULL, + expires_at TIMESTAMP NOT NULL, + scope TEXT NOT NULL +); diff --git a/migrations/2024-11-21-101617_create_account_links/up.sql b/migrations/2024-11-21-101617_create_account_links/up.sql index 88f60c3..ba367e3 100644 --- a/migrations/2024-11-21-101617_create_account_links/up.sql +++ b/migrations/2024-11-21-101617_create_account_links/up.sql @@ -3,7 +3,8 @@ CREATE TABLE IF NOT EXISTS account_links ( discord_uid TEXT REFERENCES discord_connections (uid) ON DELETE CASCADE, - roblox_uid BIGINT, + roblox_uid TEXT REFERENCES roblox_connections (uid) + ON DELETE CASCADE, is_primary BOOLEAN NOT NULL, PRIMARY KEY (discord_uid, roblox_uid) diff --git a/src/constants.rs b/src/constants.rs index 09a7878..a68c158 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -5,15 +5,19 @@ pub(crate) const API_VERSION: &str = "v1"; pub(crate) mod roblox_api { /// The URL for authorizing with the Roblox API. pub(crate) const AUTHORIZE_URL: &str = "https://apis.roblox.com/oauth/v1/authorize"; + /// The URL for obtaining tokens from the Roblox API. + pub(crate) const TOKEN_URL: &str = "https://apis.roblox.com/oauth/v1/token"; + /// The URL for fetching the current user's information from the Roblox API. + pub(crate) const USER_URL: &str = "https://apis.roblox.com/oauth/v1/userinfo"; } pub(crate) mod discord_api { /// The URL for authorizing with the Discord API. - pub(crate) const AUTHORIZE_URL: &str = "https://discord.com/authorize"; + pub(crate) const AUTHORIZE_URL: &str = "https://discord.com/api/v10/oauth2/authorize"; /// The URL for obtaining tokens from the Discord API. pub(crate) const TOKEN_URL: &str = "https://discord.com/api/v10/oauth2/token"; /// The URL for fetching the current user's information from the Discord API. - pub(crate) const USER_URL: &str = "https://discord.com/api/v10/users/@me"; + pub(crate) const USER_URL: &str = "https://discord.com/api/v10/oauth2/@me"; } pub(crate) mod cookie { @@ -32,6 +36,8 @@ pub(crate) mod env { pub(crate) const ENCRYPTION_KEY_LENGTH: usize = 32; /// The environment variable name for the Discord client secret. pub(crate) const DISCORD_CLIENT_SECRET: &str = "DISCORD_CLIENT_SECRET"; + /// The environment variable name for the Roblox client secret. + pub(crate) const ROBLOX_CLIENT_SECRET: &str = "ROBLOX_CLIENT_SECRET"; } #[cfg(test)] diff --git a/src/database/wrappers/account_links/mod.rs b/src/database/wrappers/account_links/mod.rs index da1d9ec..e35854c 100644 --- a/src/database/wrappers/account_links/mod.rs +++ b/src/database/wrappers/account_links/mod.rs @@ -2,6 +2,7 @@ pub(crate) mod models; mod schema; use self::models::{AccountLink, NewAccountLink}; +use crate::database::wrappers::account_links::models::UpdateAccountLink; use diesel::prelude::*; use std::marker::PhantomData; @@ -18,46 +19,53 @@ impl AccountLinksDb { /// /// # Returns /// - /// The newly inserted [`AccountLink`], if successful, [`None`] if the account link already exists, - /// or an error if the operation failed. + /// The newly inserted [`AccountLink`], if successful, or an error if the operation failed. pub(crate) fn insert_one( new_account_link: NewAccountLink, conn: &mut PgConnection, - ) -> Result, diesel::result::Error> { + ) -> Result { use self::schema::account_links::dsl::{account_links, discord_uid, is_primary}; - let pk = ( - AccountLinkUserId::new(new_account_link.discord_uid.as_str()), - AccountLinkUserId::new(new_account_link.roblox_uid), - ); - - // Already exists - if Self::find_one(pk, conn).is_some() { - return Ok(None); - } - - // Unset the primary flag for all other account links - // if the new account link is primary - if new_account_link.is_primary { - conn.transaction(|conn| { + conn.transaction(|conn| { + // Unset the primary flag for all other account links + // if the new account link is primary + if new_account_link.is_primary { diesel::update(account_links) .filter(discord_uid.eq(&new_account_link.discord_uid)) .set(is_primary.eq(false)) .execute(conn)?; + } - diesel::insert_into(account_links) - .values(new_account_link) - .returning(AccountLink::as_returning()) - .get_result(conn) - }) - .map(Some) - } else { diesel::insert_into(account_links) .values(new_account_link) .returning(AccountLink::as_returning()) .get_result(conn) - .map(Some) - } + }) + } + + pub(crate) fn upsert_one( + new_account_link: NewAccountLink, + conn: &mut PgConnection, + ) -> Result { + conn.transaction(|conn| { + let pk = ( + AccountLinkUserId::new(new_account_link.discord_uid.clone()), + AccountLinkUserId::new(new_account_link.roblox_uid.clone()), + ); + + // Update if already exists + if Self::find_one(pk, conn).is_some() { + let pk = ( + AccountLinkUserId::new(new_account_link.discord_uid.clone()), + AccountLinkUserId::new(new_account_link.roblox_uid.clone()), + ); + let update_account_link = UpdateAccountLink::from(new_account_link); + + Self::update_one(pk, update_account_link, conn) + } else { + Self::insert_one(new_account_link, conn) + } + }) } /// Finds an account link in the database by its primary key. @@ -73,7 +81,7 @@ impl AccountLinksDb { pub(crate) fn find_one( pk: ( AccountLinkUserId, - AccountLinkUserId<'static, RobloxMarker>, + AccountLinkUserId, ), conn: &mut PgConnection, ) -> Option { @@ -87,7 +95,7 @@ impl AccountLinksDb { .ok() } - /// Updates the primary account link in the database. + /// Updates the account link in the database by its primary key. /// /// # Arguments /// @@ -96,36 +104,48 @@ impl AccountLinksDb { /// /// # Returns /// - /// `true` if the account link was successfully updated, `false` if the account link did not exist. - pub(crate) fn set_primary( + /// The updated [`AccountLink`], if successful, or an error if the operation failed. + pub(crate) fn update_one( pk: ( AccountLinkUserId, - AccountLinkUserId<'static, RobloxMarker>, + AccountLinkUserId, ), + new_account_link: UpdateAccountLink, conn: &mut PgConnection, - ) -> bool { + ) -> Result { use self::schema::account_links::dsl::{ account_links, discord_uid, is_primary, roblox_uid, }; let (pk_discord_uid, pk_roblox_uid) = pk; conn.transaction(|conn| { - // Set all other account links to not primary - diesel::update(account_links) - .filter(discord_uid.eq(&pk_discord_uid.id)) - .set(is_primary.eq(false)) - .execute(conn)?; + if let Some(primary) = new_account_link.is_primary { + // Set all other account links to not primary + // if the new account link is primary + if primary { + diesel::update(account_links) + .filter(discord_uid.eq(&pk_discord_uid.id)) + .set(is_primary.eq(false)) + .execute(conn)?; + } + } // Set the new primary account link diesel::update(account_links) .filter(discord_uid.eq(pk_discord_uid.id)) .filter(roblox_uid.eq(pk_roblox_uid.id)) - .set(is_primary.eq(true)) - .execute(conn)?; - - diesel::result::QueryResult::Ok(()) + .set(new_account_link) + .returning(AccountLink::as_returning()) + .get_result(conn) }) - .is_ok() + } +} + +impl From for UpdateAccountLink { + fn from(new_account_link: NewAccountLink) -> Self { + Self { + is_primary: Some(new_account_link.is_primary), + } } } @@ -135,28 +155,28 @@ pub(crate) struct DiscordMarker; pub(crate) struct RobloxMarker; /// Marker trait for user types. -pub(crate) trait UserMarker<'a> { +pub(crate) trait UserMarker { /// The ID type for the user. type Id; } -impl<'a> UserMarker<'a> for DiscordMarker { - type Id = &'a str; +impl UserMarker for DiscordMarker { + type Id = String; } -impl UserMarker<'static> for RobloxMarker { - type Id = i64; +impl UserMarker for RobloxMarker { + type Id = String; } /// Wrapper for account link user IDs. -pub(crate) struct AccountLinkUserId<'a, Marker: UserMarker<'a>> { +pub(crate) struct AccountLinkUserId { /// Marker for the user type. marker: PhantomData, /// The user ID. id: Marker::Id, } -impl<'a, Marker: UserMarker<'a>> AccountLinkUserId<'a, Marker> { +impl AccountLinkUserId { /// Create a new [`AccountLinkUserId`] with the given ID. pub(crate) fn new(id: Marker::Id) -> Self { Self { @@ -166,7 +186,7 @@ impl<'a, Marker: UserMarker<'a>> AccountLinkUserId<'a, Marker> { } } -impl AccountLinkUserId<'_, DiscordMarker> { +impl AccountLinkUserId { /// Find the primary account link by **Discord ID**. /// /// # Arguments @@ -223,7 +243,7 @@ impl AccountLinkUserId<'_, DiscordMarker> { } } -impl AccountLinkUserId<'static, RobloxMarker> { +impl AccountLinkUserId { /// Find all account links associated with the **Roblox ID**. /// /// # Arguments diff --git a/src/database/wrappers/account_links/models.rs b/src/database/wrappers/account_links/models.rs index f6fb8a2..f0e5043 100644 --- a/src/database/wrappers/account_links/models.rs +++ b/src/database/wrappers/account_links/models.rs @@ -9,7 +9,7 @@ pub(crate) struct AccountLink { /// The user's Roblox ID. /// /// Composite primary key with `discord_uid`. - pub(crate) roblox_uid: i64, + pub(crate) roblox_uid: String, /// The user's Discord ID. /// /// Composite primary key with `roblox_uid`. @@ -23,16 +23,24 @@ pub(crate) struct AccountLink { #[derive(Insertable, Debug)] #[diesel(table_name = schema::account_links)] pub(crate) struct NewAccountLink { - pub(crate) roblox_uid: i64, + pub(crate) roblox_uid: String, pub(crate) discord_uid: String, pub(crate) is_primary: bool, } +/// Represents an update to an account link in the database. +/// See [`AccountLink`] for field definitions. +#[derive(AsChangeset, Debug)] +#[diesel(table_name = schema::account_links)] +pub(crate) struct UpdateAccountLink { + pub(crate) is_primary: Option, +} + /// Represents an update to an account link in the database. /// See [`AccountLink`] for field definitions. #[derive(Debug, Default)] pub(crate) struct AccountLinkBuilder { - roblox_uid: Option, + roblox_uid: Option, discord_uid: Option, is_primary: Option, } @@ -44,6 +52,13 @@ impl NewAccountLink { } } +impl UpdateAccountLink { + /// Creates a new [`AccountLinkBuilder`] instance. + pub(crate) fn build() -> AccountLinkBuilder { + AccountLinkBuilder::default() + } +} + impl AccountLinkBuilder { /// Creates a new [`AccountLinkBuilder`] instance. pub(crate) fn new() -> Self { @@ -51,7 +66,7 @@ impl AccountLinkBuilder { } /// Sets the Roblox UID for the account link. - pub(crate) fn roblox_uid(mut self, roblox_uid: i64) -> Self { + pub(crate) fn roblox_uid(mut self, roblox_uid: String) -> Self { self.roblox_uid = Some(roblox_uid); self } @@ -77,4 +92,11 @@ impl AccountLinkBuilder { is_primary: self.is_primary.expect("is_primary is required"), } } + + /// Builds the [`UpdateAccountLink`] instance. + pub(crate) fn build_update(self) -> UpdateAccountLink { + UpdateAccountLink { + is_primary: self.is_primary, + } + } } diff --git a/src/database/wrappers/account_links/schema.rs b/src/database/wrappers/account_links/schema.rs index 09cd06f..92e3779 100644 --- a/src/database/wrappers/account_links/schema.rs +++ b/src/database/wrappers/account_links/schema.rs @@ -4,7 +4,7 @@ diesel::table! { /// The user's Roblox ID. /// /// Composite primary key with `discord_uid`. - roblox_uid -> BigInt, + roblox_uid -> Text, /// The user's Discord ID. /// /// Composite primary key with `roblox_uid`. diff --git a/src/database/wrappers/discord_connections/mod.rs b/src/database/wrappers/discord_connections/mod.rs index 08c8e5e..04f99db 100644 --- a/src/database/wrappers/discord_connections/mod.rs +++ b/src/database/wrappers/discord_connections/mod.rs @@ -18,23 +18,41 @@ impl DiscordConnectionsDb { /// /// # Returns /// - /// The newly inserted [`DiscordConnection`], if successful, [`None`] if the connection already exists, + /// The newly inserted [`DiscordConnection`], if successful, or an error if the operation failed. pub(crate) fn insert_one( new_conn: NewDiscordConnection, conn: &mut PgConnection, - ) -> Result, diesel::result::Error> { + ) -> Result { use self::schema::discord_connections; - // Already exists - if Self::find_one(&new_conn.uid, conn).is_some() { - return Ok(None); - } - diesel::insert_into(discord_connections::table) .values(new_conn) .returning(DiscordConnection::as_returning()) .get_result(conn) - .map(Some) + } + + /// Upserts (inserts or updates) a discord connection in the database. + /// + /// # Arguments + /// + /// * `new_conn` - The new discord connection to upsert. + /// * `conn` - The database connection to use. + /// + /// # Returns + /// + /// The upserted [`DiscordConnection`], if successful, or an error if the operation failed. + pub(crate) fn upsert_one( + new_conn: NewDiscordConnection, + conn: &mut PgConnection, + ) -> Result { + // Update the connection if it already exists, otherwise insert a new one. + if Self::find_one(&new_conn.uid, conn).is_some() { + let uid = new_conn.uid.clone(); + let update_conn = UpdateDiscordConnection::from(new_conn); + Self::update_one(&uid, update_conn, conn) + } else { + Self::insert_one(new_conn, conn) + } } /// Finds a discord connection in the database by its primary key. @@ -85,12 +103,12 @@ impl DiscordConnectionsDb { /// /// # Returns /// - /// The updated [`DiscordConnection`], if successful, or [`None`] if the connection did not exist. + /// The updated [`DiscordConnection`], if successful, or an error if the operation failed. pub(crate) fn update_one( pk_uid: &str, new_conn: UpdateDiscordConnection, conn: &mut PgConnection, - ) -> Option { + ) -> Result { use self::schema::discord_connections::dsl::{discord_connections, uid}; diesel::update(discord_connections) @@ -98,6 +116,5 @@ impl DiscordConnectionsDb { .set(new_conn) .returning(DiscordConnection::as_returning()) .get_result(conn) - .ok() } } diff --git a/src/database/wrappers/discord_connections/models.rs b/src/database/wrappers/discord_connections/models.rs index d8b6ea9..bee1353 100644 --- a/src/database/wrappers/discord_connections/models.rs +++ b/src/database/wrappers/discord_connections/models.rs @@ -95,8 +95,7 @@ impl DiscordConnectionBuilder { } /// Sets the user's access token. - pub(crate) fn access_token(mut self, access_token: String) -> Self -where { + pub(crate) fn access_token(mut self, access_token: String) -> Self { self.access_token = Some(access_token); self } @@ -174,6 +173,19 @@ where { } } +impl From for UpdateDiscordConnection { + fn from(new_conn: NewDiscordConnection) -> Self { + UpdateDiscordConnection { + access_token: Some(new_conn.access_token), + access_token_nonce: Some(new_conn.access_token_nonce), + expires_at: Some(new_conn.expires_at), + refresh_token: Some(new_conn.refresh_token), + refresh_token_nonce: Some(new_conn.refresh_token_nonce), + scope: Some(new_conn.scope), + } + } +} + #[cfg(test)] mod tests { use super::DiscordConnectionBuilder; diff --git a/src/database/wrappers/mod.rs b/src/database/wrappers/mod.rs index f397409..29c3671 100644 --- a/src/database/wrappers/mod.rs +++ b/src/database/wrappers/mod.rs @@ -1,5 +1,6 @@ pub(crate) mod account_links; pub(crate) mod discord_connections; +pub(crate) mod roblox_connections; pub(crate) mod sessions; /// The error type for operations on the database. diff --git a/src/database/wrappers/roblox_connections/mod.rs b/src/database/wrappers/roblox_connections/mod.rs new file mode 100644 index 0000000..7dc6bcb --- /dev/null +++ b/src/database/wrappers/roblox_connections/mod.rs @@ -0,0 +1,120 @@ +pub(crate) mod models; +mod schema; + +use self::models::{NewRobloxConnection, RobloxConnection}; +use crate::database::wrappers::roblox_connections::models::UpdateRobloxConnection; +use diesel::prelude::*; + +/// A collection of methods for interacting with the `roblox_connections` table. +pub(crate) struct RobloxConnectionsDb; + +impl RobloxConnectionsDb { + /// Inserts a new roblox connection into the database. + /// + /// # Arguments + /// + /// * `new_conn` - The new roblox connection to insert. + /// * `conn` - The database connection to use. + /// + /// # Returns + /// + /// The newly inserted [`RobloxConnection`], if successful, or an error if the operation failed. + pub(crate) fn insert_one( + new_conn: NewRobloxConnection, + conn: &mut PgConnection, + ) -> Result { + use self::schema::roblox_connections; + + // Already exists + diesel::insert_into(roblox_connections::table) + .values(new_conn) + .returning(RobloxConnection::as_returning()) + .get_result(conn) + } + + /// Upserts (inserts or updates) a roblox connection into the database. + /// + /// # Arguments + /// + /// * `new_conn` - The new roblox connection to insert. + /// * `conn` - The database connection to use. + /// + /// # Returns + /// + /// The newly inserted (or updated) [`RobloxConnection`], if successful, or an error if the operation failed. + pub(crate) fn upsert_one( + new_conn: NewRobloxConnection, + conn: &mut PgConnection, + ) -> Result { + // Update if already exists + if Self::find_one(&new_conn.uid, conn).is_some() { + let uid = new_conn.uid.clone(); + let update_conn = UpdateRobloxConnection::from(new_conn); + Self::update_one(&uid, update_conn, conn) + } else { + Self::insert_one(new_conn, conn) + } + } + + /// Finds a roblox connection in the database by its primary key. + /// + /// # Arguments + /// + /// * `pk_uid` - The primary key of the roblox connection to find. + /// * `conn` - The database connection to use. + /// + /// # Returns + /// + /// The [`RobloxConnection`], if found, or [`None`] if no connection with the primary key exists. + pub(crate) fn find_one(pk_uid: &str, conn: &mut PgConnection) -> Option { + use schema::roblox_connections::dsl::*; + + roblox_connections + .filter(uid.eq(pk_uid)) + .first::(conn) + .ok() + } + + /// Deletes a roblox connection from the database by its primary key. + /// + /// # Arguments + /// + /// * `pk_uid` - The primary key of the roblox connection to delete. + /// * `conn` - The database connection to use. + /// + /// # Returns + /// + /// `true` if the operation was successful, `false` otherwise. + pub(crate) fn delete_one(pk_uid: &str, conn: &mut PgConnection) -> bool { + use schema::roblox_connections::dsl::*; + + diesel::delete(roblox_connections) + .filter(uid.eq(pk_uid)) + .execute(conn) + .is_ok() + } + + /// Updates a roblox connection in the database. + /// + /// # Arguments + /// + /// * `pk_uid` - The primary key of the roblox connection to update. + /// * `new_conn` - The new roblox connection to update. + /// + /// # Returns + /// + /// The updated [`RobloxConnection`], if successful, or an error if the operation failed. + pub(crate) fn update_one( + pk_uid: &str, + new_conn: UpdateRobloxConnection, + conn: &mut PgConnection, + ) -> Result { + use schema::roblox_connections::dsl::*; + + diesel::update(roblox_connections) + .filter(uid.eq(pk_uid)) + .set(new_conn) + .returning(RobloxConnection::as_returning()) + .get_result(conn) + } +} diff --git a/src/database/wrappers/roblox_connections/models.rs b/src/database/wrappers/roblox_connections/models.rs new file mode 100644 index 0000000..2eb93a7 --- /dev/null +++ b/src/database/wrappers/roblox_connections/models.rs @@ -0,0 +1,243 @@ +use super::schema; +use diesel::prelude::*; + +/// Represents an authorized Roblox connection in the database. +#[derive(Queryable, Selectable, Debug)] +#[diesel(table_name = schema::roblox_connections)] +#[diesel(check_for_backend(diesel::pg::Pg))] +pub(crate) struct RobloxConnection { + /// The unique identifier for the connection. + /// This is the same as the user's Roblox ID. + /// + /// Primary key. + pub(crate) uid: String, + /// The user's access token. + pub(crate) access_token: String, + /// The nonce that was used to encrypt the access token. + pub(crate) access_token_nonce: String, + /// The user's refresh token. + pub(crate) refresh_token: String, + /// The nonce that was used to encrypt the refresh token. + pub(crate) refresh_token_nonce: String, + /// The time at which the access token expires. + pub(crate) expires_at: chrono::NaiveDateTime, + /// The scopes granted by the user. + pub(crate) scope: String, +} + +/// Represents a new Roblox connection to be inserted into the database. +/// See [`RobloxConnection`] for field definitions. +#[derive(Insertable, Debug)] +#[diesel(table_name = schema::roblox_connections)] +pub(crate) struct NewRobloxConnection { + pub(crate) uid: String, + pub(crate) access_token: String, + access_token_nonce: String, + pub(crate) refresh_token: String, + refresh_token_nonce: String, + pub(crate) expires_at: chrono::NaiveDateTime, + pub(crate) scope: String, +} + +/// Represents an update to a Roblox connection in the database. +/// See [`RobloxConnection`] for field definitions. +#[derive(AsChangeset, Debug)] +#[diesel(table_name = schema::roblox_connections)] +pub(crate) struct UpdateRobloxConnection { + pub(crate) access_token: Option, + access_token_nonce: Option, + pub(crate) refresh_token: Option, + refresh_token_nonce: Option, + pub(crate) expires_at: Option, + pub(crate) scope: Option, +} + +/// A builder for creating new or updating existing Roblox connections. +/// See [`RobloxConnection`] for field definitions. +#[derive(Default)] +pub(crate) struct RobloxConnectionBuilder { + uid: Option, + access_token: Option, + access_token_nonce: Option, + refresh_token: Option, + refresh_token_nonce: Option, + expires_at: Option, + scope: Option, +} + +impl NewRobloxConnection { + /// Creates a new [`RobloxConnectionBuilder`] instance. + pub(crate) fn build() -> RobloxConnectionBuilder { + RobloxConnectionBuilder::default() + } +} + +impl RobloxConnectionBuilder { + /// Creates a new [`RobloxConnectionBuilder`] instance. + pub(crate) fn new() -> Self { + Self::default() + } + + /// Sets the unique identifier for the connection. + /// This is the same as the user's Roblox ID. + pub(crate) fn uid(mut self, uid: String) -> Self { + self.uid = Some(uid); + self + } + + /// Sets the user's access token. + pub(crate) fn access_token(mut self, access_token: String) -> Self { + self.access_token = Some(access_token); + self + } + + /// Sets the time at which the access token expires. + pub(crate) fn expires_at(mut self, expires_at: chrono::NaiveDateTime) -> Self { + self.expires_at = Some(expires_at); + self + } + + /// Sets the user's refresh token. + pub(crate) fn refresh_token(mut self, refresh_token: String) -> Self { + self.refresh_token = Some(refresh_token); + self + } + + /// Sets the scopes granted by the user. + pub(crate) fn scope(mut self, scope: String) -> Self { + self.scope = Some(scope); + self + } + + /// Encrypts the access and refresh tokens. + fn encrypt_tokens(&mut self) -> Result<(), String> { + // Encrypt the access token if it exists + if let Some(access_token) = &self.access_token { + let data = crate::cipher::encrypt(access_token.as_bytes())?; + + self.access_token_nonce = Some(data.nonce); + self.access_token = Some(data.data); + } + + // Encrypt the refresh token if it exists + if let Some(refresh_token) = &self.refresh_token { + let data = crate::cipher::encrypt(refresh_token.as_bytes())?; + + self.refresh_token_nonce = Some(data.nonce); + self.refresh_token = Some(data.data); + } + + Ok(()) + } + + /// Builds the [`NewRobloxConnection`] instance. + pub(crate) fn build(mut self) -> Result { + self.encrypt_tokens()?; + + Ok(NewRobloxConnection { + uid: self.uid.expect("uid is required"), + access_token: self.access_token.expect("access_token is required"), + access_token_nonce: self + .access_token_nonce + .expect("access_token_nonce is required"), + expires_at: self.expires_at.expect("expires_at is required"), + refresh_token: self.refresh_token.expect("refresh_token is required"), + refresh_token_nonce: self + .refresh_token_nonce + .expect("refresh_token_nonce is required"), + scope: self.scope.expect("scope is required"), + }) + } + + /// Builds the [`UpdateRobloxConnection`] instance. + pub(crate) fn build_update(mut self) -> Result { + self.encrypt_tokens()?; + + Ok(UpdateRobloxConnection { + access_token: self.access_token, + access_token_nonce: self.access_token_nonce, + expires_at: self.expires_at, + refresh_token: self.refresh_token, + refresh_token_nonce: self.refresh_token_nonce, + scope: self.scope, + }) + } +} + +impl From for UpdateRobloxConnection { + fn from(new_conn: NewRobloxConnection) -> Self { + UpdateRobloxConnection { + access_token: Some(new_conn.access_token), + access_token_nonce: Some(new_conn.access_token_nonce), + expires_at: Some(new_conn.expires_at), + refresh_token: Some(new_conn.refresh_token), + refresh_token_nonce: Some(new_conn.refresh_token_nonce), + scope: Some(new_conn.scope), + } + } +} + +#[cfg(test)] +mod tests { + use super::RobloxConnectionBuilder; + use crate::constants; + use serial_test::file_serial; + + #[test] + #[file_serial(env)] + fn test_build() { + std::env::set_var( + constants::env::ENCRYPTION_KEY, + constants::test::ENCRYPTION_KEY, + ); + + let now = chrono::Utc::now().naive_utc(); + let conn = RobloxConnectionBuilder::new() + .uid(constants::test::UID.to_string()) + .access_token(constants::test::ACCESS_TOKEN.to_string()) + .expires_at(now) + .refresh_token(constants::test::REFRESH_TOKEN.to_string()) + .scope(constants::test::SCOPE.to_string()) + .build() + .unwrap(); + + // The access token is encrypted, so we can't compare it directly + assert_ne!(conn.access_token, constants::test::ACCESS_TOKEN); + // The refresh token is encrypted, so we can't compare it directly + assert_ne!(conn.refresh_token, constants::test::REFRESH_TOKEN); + assert_eq!(conn.uid, constants::test::UID); + assert_eq!(conn.expires_at, now); + assert_eq!(conn.scope, constants::test::SCOPE); + + std::env::remove_var(constants::env::ENCRYPTION_KEY); + } + + #[test] + #[file_serial(env)] + fn test_build_update() { + std::env::set_var( + constants::env::ENCRYPTION_KEY, + constants::test::ENCRYPTION_KEY, + ); + + let now = chrono::Utc::now().naive_utc(); + let conn = RobloxConnectionBuilder::new() + .access_token(constants::test::ACCESS_TOKEN.to_string()) + .expires_at(now) + .refresh_token(constants::test::REFRESH_TOKEN.to_string()) + .scope(constants::test::SCOPE.to_string()) + .build_update() + .unwrap(); + + // The access token is encrypted, so we can't compare it directly + assert_ne!(conn.access_token.unwrap(), constants::test::ACCESS_TOKEN); + assert!(conn.access_token_nonce.is_some()); + // The refresh token is encrypted, so we can't compare it directly + assert_ne!(conn.refresh_token.unwrap(), constants::test::REFRESH_TOKEN); + assert!(conn.refresh_token_nonce.is_some()); + assert_eq!(conn.expires_at.unwrap(), now); + assert_eq!(conn.scope.unwrap(), constants::test::SCOPE); + + std::env::remove_var(constants::env::ENCRYPTION_KEY); + } +} diff --git a/src/database/wrappers/roblox_connections/schema.rs b/src/database/wrappers/roblox_connections/schema.rs new file mode 100644 index 0000000..36d5d51 --- /dev/null +++ b/src/database/wrappers/roblox_connections/schema.rs @@ -0,0 +1,21 @@ +diesel::table! { + roblox_connections (uid) { + /// The unique identifier for the connection. + /// This is the same as the user's Roblox ID. + /// + /// Primary key. + uid -> Text, + /// The user's access token. + access_token -> Text, + /// The nonce that was used to encrypt the access token. + access_token_nonce -> Text, + /// The user's refresh token. + refresh_token -> Text, + /// The nonce that was used to encrypt the refresh token. + refresh_token_nonce -> Text, + /// The time at which the access token expires. + expires_at -> Timestamp, + /// The scopes granted by the user. + scope -> Text, + } +} diff --git a/src/database/wrappers/sessions/mod.rs b/src/database/wrappers/sessions/mod.rs index c3fed42..d8fa241 100644 --- a/src/database/wrappers/sessions/mod.rs +++ b/src/database/wrappers/sessions/mod.rs @@ -3,7 +3,12 @@ mod schema; use self::models::{NewSession, Session}; use crate::database::wrappers::sessions::models::UpdateSession; +use crate::response::ApiError; +use crate::{constants, DbConn}; use diesel::prelude::*; +use rocket::http::Status; +use rocket::request::{FromRequest, Outcome}; +use rocket::Request; /// A collection of methods for interacting with the `sessions` table. pub(crate) struct SessionsDb; @@ -100,3 +105,54 @@ impl SessionsDb { .get_result(conn) } } + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for Session { + type Error = ApiError; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + // Extract the session ID from the session cookie. + let session_id = request + .cookies() + .get_pending(constants::cookie::SESSION_ID) + .map(|cookie| cookie.value().to_string()); + + // Ensure the session ID cookie exists. + let Some(session_id) = session_id else { + return Outcome::Error(( + Status::Unauthorized, + ApiError::message( + Status::Unauthorized, + format!("Missing {} cookie", constants::cookie::SESSION_ID), + ), + )); + }; + + // Get the database connection from the request guard. + let conn = request.guard::().await; + // Ensure the database connection was successfully retrieved. + let Outcome::Success(conn) = conn else { + return Outcome::Error(( + Status::InternalServerError, + ApiError::message( + Status::InternalServerError, + "Failed to access database request guard", + ), + )); + }; + + // Find the session in the database. + let session = conn + .run(move |conn| SessionsDb::find_one(&session_id, conn)) + .await; + + // Ensure the session was successfully retrieved. + match session { + Some(session) => Outcome::Success(session), + None => Outcome::Error(( + Status::Unauthorized, + ApiError::message(Status::Unauthorized, "Invalid session ID"), + )), + } + } +} diff --git a/src/oauth/routes/discord.rs b/src/oauth/routes/discord.rs index c04eee8..55efeca 100644 --- a/src/oauth/routes/discord.rs +++ b/src/oauth/routes/discord.rs @@ -5,9 +5,8 @@ use crate::database::wrappers::discord_connections::models::{ NewDiscordConnection, UpdateDiscordConnection, }; use crate::database::wrappers::discord_connections::DiscordConnectionsDb; -use crate::database::wrappers::sessions::models::NewSession; +use crate::database::wrappers::sessions::models::{NewSession, Session}; use crate::database::wrappers::sessions::SessionsDb; -use crate::oauth::routes::SessionId; use crate::oauth::types::discord::{DiscordOAuthScopeSet, DiscordOAuthScopes}; use crate::oauth::types::OAuthCallback; use crate::oauth::utils::discord::{ @@ -32,19 +31,19 @@ use rocket::State; /// - `400 Bad Request` if the scope set is invalid. #[get("/discord/initiate?")] pub(super) fn discord_oauth_initiate( - scope_set: String, + scope_set: &str, jar: &CookieJar<'_>, cfg: &State, ) -> ApiResult { - let scopes = DiscordOAuthScopeSet::try_from(scope_set.as_str()) + let scopes = DiscordOAuthScopeSet::try_from(scope_set) .map_err(|e| ApiError::message(Status::BadRequest, e))?; - let scopes = DiscordOAuthScopes::from(&scopes); + let scopes = DiscordOAuthScopes::from(scopes); let state = generate_state(); - let redirect_uri = construct_discord_oauth_url(&scopes, &state, cfg); + let redirect_uri = construct_discord_oauth_url(scopes, &state, cfg); // Build a state cookie that will be used to verify the callback let auth_cookie = Cookie::build((constants::cookie::STATE, state)) - .path(construct_api_route("/oauth2/callback/discord")) + .path(construct_api_route("/oauth2/discord/callback")) .same_site(SameSite::Lax) .max_age(Duration::minutes(5)); jar.add_private(auth_cookie); @@ -69,7 +68,7 @@ pub(super) fn discord_oauth_initiate( /// - If the code exchange response cannot be parsed /// - If the authenticated user is missing the 'identify' scope. /// - If the access token or refresh token cannot be encrypted. -/// - If the session or Discord connection cannot be inserted into the database. +/// - If the session cannot be inserted or the Discord connection cannot be upserted into the database. /// - `502 Bad Gateway` if the code exchange fails. #[get("/discord/callback?")] pub(super) async fn discord_oauth_callback( @@ -128,9 +127,8 @@ pub(super) async fn discord_oauth_callback( // Insert the session and Discord connection into the database conn.run(|conn| { conn.transaction(|conn| { - DiscordConnectionsDb::insert_one(discord_connection, conn)?; + DiscordConnectionsDb::upsert_one(discord_connection, conn)?; SessionsDb::insert_one(new_session, conn)?; - diesel::result::QueryResult::Ok(()) }) }) @@ -138,7 +136,7 @@ pub(super) async fn discord_oauth_callback( .map_err(|_| { ApiError::message( Status::InternalServerError, - "Failed to insert session or Discord connection into database", + "Failed to insert session or upsert Discord connection into database", ) })?; @@ -147,7 +145,7 @@ pub(super) async fn discord_oauth_callback( jar.add_private(session.cookie); // Redirect back to the main page - Ok(Redirect::to(uri!("/"))) + Ok(Redirect::to(uri!("/connections/discord/success"))) } /// Refreshes the Discord token by decrypting the refresh token, refreshing the token, and updating the database. @@ -168,21 +166,12 @@ pub(super) async fn discord_oauth_callback( #[post("/discord/refresh-token")] pub(super) async fn discord_refresh_token( conn: DbConn, - session_id: SessionId, + session: Session, cfg: &State, ) -> ApiResult { - let session_id = session_id.into_inner(); - - // Fetch the Discord UID from the session token - let discord_uid = conn - .run(move |conn| SessionsDb::find_one(&session_id, conn)) - .await - .map(|session| session.discord_uid) - .ok_or(ApiError::message(Status::Unauthorized, "Session not found"))?; - // Fetch the Discord connection from the database let discord_connection = conn - .run(move |conn| DiscordConnectionsDb::find_one(&discord_uid, conn)) + .run(move |conn| DiscordConnectionsDb::find_one(&session.discord_uid, conn)) .await .ok_or(ApiError::message( Status::InternalServerError, @@ -221,10 +210,15 @@ pub(super) async fn discord_refresh_token( // Update the Discord connection in the database let success = conn .run(move |conn| { - DiscordConnectionsDb::update_one(&discord_connection.uid, new_discord_connection, conn) - .is_some() + DiscordConnectionsDb::update_one( + &discord_connection.uid, + new_discord_connection, + conn, + )?; + diesel::result::QueryResult::Ok(()) }) - .await; + .await + .is_ok(); if success { Ok(ApiResponse::status(Status::NoContent)) diff --git a/src/oauth/routes/mod.rs b/src/oauth/routes/mod.rs index 312f8c6..f7fbb5c 100644 --- a/src/oauth/routes/mod.rs +++ b/src/oauth/routes/mod.rs @@ -1,11 +1,9 @@ -use crate::database::wrappers::sessions::models::UpdateSession; +use crate::database::wrappers::sessions::models::{Session, UpdateSession}; use crate::database::wrappers::sessions::SessionsDb; -use crate::oauth::types::SessionId; use crate::oauth::utils::generate_session; use crate::response::{ApiError, ApiResponse, ApiResult}; use crate::DbConn; use rocket::http::{CookieJar, Status}; -use std::sync::Arc; mod discord; mod roblox; @@ -20,35 +18,22 @@ mod roblox; /// - If the session is not found in the database. /// - `500 Internal Server Error` if the session cannot be updated in the database. #[post("/refresh-session")] -async fn refresh_session(jar: &CookieJar<'_>, conn: DbConn, session_id: SessionId) -> ApiResult { - // Wrap the ID in an Arc to avoid cloning the value - let session_id = Arc::new(session_id.into_inner()); - let session_id_find = Arc::clone(&session_id); - - // Check if the session exists - conn.run(move |conn| { - let Some(_) = SessionsDb::find_one(session_id_find.as_str(), conn) else { - return diesel::result::QueryResult::Err(diesel::result::Error::NotFound); - }; - diesel::result::QueryResult::Ok(()) - }) - .await - .map_err(|_| ApiError::message(Status::Unauthorized, "Session not found"))?; - - // Generate a new session ID - let session = generate_session(); - let updated_session = UpdateSession::build() - .session_id(session.session_id.clone()) - .expires_at(session.expires_at) +async fn refresh_session(jar: &CookieJar<'_>, conn: DbConn, session: Session) -> ApiResult { + // Generate a new session + let new_session = generate_session(); + let new_session_cookie = new_session.cookie; + let new_session = UpdateSession::build() + .session_id(new_session.session_id.clone()) + .expires_at(new_session.expires_at) .build_update(); // Update the session in the database - conn.run(move |conn| SessionsDb::update_one(session_id.as_str(), updated_session, conn)) + conn.run(move |conn| SessionsDb::update_one(&session.session_id, new_session, conn)) .await .map_err(|_| ApiError::message(Status::InternalServerError, "Failed to update session"))?; // Update the session cookie - jar.add_private(session.cookie); + jar.add_private(new_session_cookie); Ok(ApiResponse::status(Status::NoContent)) } @@ -61,6 +46,8 @@ pub(crate) fn routes() -> Vec { discord::discord_oauth_initiate, discord::discord_oauth_callback, discord::discord_refresh_token, + roblox::roblox_oauth_initiate, + roblox::roblox_oauth_callback, refresh_session, ] } diff --git a/src/oauth/routes/roblox.rs b/src/oauth/routes/roblox.rs index d8dadc6..5fb9a27 100644 --- a/src/oauth/routes/roblox.rs +++ b/src/oauth/routes/roblox.rs @@ -1,21 +1,53 @@ use crate::config::Config; use crate::constants; +use crate::database::wrappers::account_links::models::NewAccountLink; +use crate::database::wrappers::account_links::AccountLinksDb; +use crate::database::wrappers::roblox_connections::models::NewRobloxConnection; +use crate::database::wrappers::roblox_connections::RobloxConnectionsDb; +use crate::database::wrappers::sessions::models::Session; use crate::oauth::types::roblox::{RobloxOAuthScopeSet, RobloxOAuthScopes}; use crate::oauth::types::OAuthCallback; use crate::oauth::utils::generate_state; use crate::oauth::utils::pixy::Pixy; -use crate::oauth::utils::roblox::construct_roblox_oauth_url; +use crate::oauth::utils::roblox::{construct_roblox_oauth_url, exchange_code, get_authorized_user}; use crate::response::{ApiError, ApiResult}; use crate::utils::construct_api_route; -use rocket::http::{Cookie, CookieJar, Status}; +use crate::DbConn; +use diesel::Connection; +use rocket::http::{Cookie, CookieJar, SameSite, Status}; use rocket::response::Redirect; use rocket::time::Duration; use rocket::State; +/// Handles the Roblox OAuth callback by verifying the state and code verifier, +/// and exchanging the code for a token. +/// +/// - Verifies the state against the one saved in the cookie. +/// - Verifies the code verifier against the one saved in the cookie. +/// - Exchanges the code for a token. +/// - Fetches the authorized user. +/// - Inserts the account link and Roblox connection into the database. +/// +/// # Possible Responses +/// +/// - `303 See Other` with a redirect to the success page. +/// - `400 Bad Request` if the state is invalid. +/// - `401 Unauthorized` +/// - If the session cookie is missing. +/// - If the session is not found in the database. +/// - `500 Internal Server Error` +/// - If the code exchange response cannot be parsed +/// - If the authenticated user is missing the 'identify' scope. +/// - If the access token or refresh token cannot be encrypted. +/// - If the account link or Roblox connection cannot be upserted into the database. +/// - `502 Bad Gateway` if the code exchange fails. #[get("/roblox/callback?")] -pub(super) fn roblox_oauth_callback( +pub(super) async fn roblox_oauth_callback( callback: OAuthCallback, jar: &CookieJar<'_>, + conn: DbConn, + session: Session, + cfg: &State, ) -> ApiResult { // Verify the state against the one that was saved in the cookie let is_valid_state = jar @@ -26,48 +58,87 @@ pub(super) fn roblox_oauth_callback( return Err(ApiError::message(Status::BadRequest, "Invalid state")); } - // Use verifier to obtain token - let Some(_verifier_cookie) = jar.get_pending(constants::cookie::OAUTH_CODE_VERIFIER) else { - return Err(ApiError::message( - Status::InternalServerError, - "Missing verifier cookie", - )); - }; + let verifier_cookie = jar + .get_pending(constants::cookie::OAUTH_CODE_VERIFIER) + .ok_or_else(|| ApiError::message(Status::BadRequest, "Missing code verifier cookie"))?; + + let response = exchange_code(&callback.code, verifier_cookie.value(), cfg)?; + let roblox_uid = get_authorized_user(&response.access_token)?.id; - // TODO: Obtain token through https://apis.roblox.com/oauth/v1/token - // and store it in the database, associating it with the session cookie - // (which should already be in the cookie jar from discord auth) + let token_expires_at = + chrono::Utc::now().naive_utc() + chrono::Duration::seconds(response.expires_in); - // Redirect back to the main page - Ok(Redirect::to(uri!("/"))) + let roblox_connection = NewRobloxConnection::build() + .uid(roblox_uid.clone()) + .access_token(response.access_token) + .expires_at(token_expires_at) + .refresh_token(response.refresh_token) + .scope(response.scope) + .build() + .map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to encrypt access and/or refresh token", + ) + })?; + + let account_link = NewAccountLink::build() + .discord_uid(session.discord_uid) + .roblox_uid(roblox_uid.clone()) + .is_primary(true) + .build(); + + conn.run(|conn| { + conn.transaction(|conn| { + RobloxConnectionsDb::upsert_one(roblox_connection, conn)?; + AccountLinksDb::upsert_one(account_link, conn)?; + diesel::result::QueryResult::Ok(()) + }) + }) + .await + .map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to upsert account link or roblox connection into database", + ) + })?; + + Ok(Redirect::to(uri!("/connections/roblox/success"))) } +/// Initiates the Roblox OAuth flow by saving a randomly-generated state and code verifier in a cookie, +/// and redirecting the user to the Roblox OAuth page. +/// +/// # Possible Responses +/// +/// - `303 See Other` with a redirect to the Roblox OAuth page. +/// - `400 Bad Request` if the scope set is invalid. +/// - `401 Unauthorized` +/// - If the session cookie is missing. +/// - If the session is not found in the database. #[get("/roblox/initiate?")] -pub(super) fn roblox_oauth_initiate( - scope_set: String, +pub(super) async fn roblox_oauth_initiate( + scope_set: &str, jar: &CookieJar<'_>, + _session: Session, cfg: &State, ) -> ApiResult { - let scope_set = RobloxOAuthScopeSet::try_from(scope_set.as_str()) + let scopes = RobloxOAuthScopeSet::try_from(scope_set) .map_err(|e| ApiError::message(Status::BadRequest, e))?; - - let pixy = Pixy::new(); - let challenge = pixy.get_challenge(); - let verifier = pixy.expose_verifier().to_string(); + let scopes = RobloxOAuthScopes::from(scopes); let state = generate_state(); - - let redirect_uri = - construct_roblox_oauth_url(challenge, &RobloxOAuthScopes::from(&scope_set), &state, cfg); + let pixy = Pixy::new(); + let redirect_uri = construct_roblox_oauth_url(&pixy.challenge, scopes, &state, cfg); let auth_cookie = Cookie::build((constants::cookie::STATE, state)) - .path(construct_api_route("/oauth2/callback/roblox")) - .same_site(rocket::http::SameSite::Lax) + .path(construct_api_route("/oauth2/roblox/callback")) + .same_site(SameSite::Lax) .max_age(Duration::minutes(5)); jar.add_private(auth_cookie); - let verifier_cookie = Cookie::build((constants::cookie::OAUTH_CODE_VERIFIER, verifier)) - .path(construct_api_route("/oauth2/callback/roblox")) - .same_site(rocket::http::SameSite::Lax) + let verifier_cookie = Cookie::build((constants::cookie::OAUTH_CODE_VERIFIER, pixy.verifier)) + .path(construct_api_route("/oauth2/roblox/callback")) + .same_site(SameSite::Lax) .max_age(Duration::minutes(5)); jar.add_private(verifier_cookie); diff --git a/src/oauth/types/discord.rs b/src/oauth/types/discord.rs index 0b00b09..c380c0f 100644 --- a/src/oauth/types/discord.rs +++ b/src/oauth/types/discord.rs @@ -243,8 +243,8 @@ impl Display for DiscordOAuthScopes { } } -impl From<&DiscordOAuthScopeSet> for DiscordOAuthScopes { - fn from(value: &DiscordOAuthScopeSet) -> Self { +impl From for DiscordOAuthScopes { + fn from(value: DiscordOAuthScopeSet) -> Self { match value { DiscordOAuthScopeSet::Verification => DiscordOAuthScopes(vec![ DiscordOAuthScope::Identify, @@ -268,6 +268,7 @@ impl TryFrom<&str> for DiscordOAuthScopeSet { } impl<'a> DiscordAuthorizationCodeRequestBody<'a> { + /// Create a new [`DiscordAuthorizationCodeRequestBody`] with the given authorization code and configuration. pub(crate) fn new(code: &'a str, cfg: &'a Config) -> Self { Self { client_id: &cfg.oauth.discord.client_id, @@ -280,6 +281,7 @@ impl<'a> DiscordAuthorizationCodeRequestBody<'a> { } } + /// Converts the request body into a query parameter string. pub(crate) fn as_query_params(&self) -> String { url!([ ("client_id", self.client_id), @@ -292,6 +294,7 @@ impl<'a> DiscordAuthorizationCodeRequestBody<'a> { } impl<'a> DiscordRefreshTokenBody<'a> { + /// Create a new [`DiscordRefreshTokenBody`] with the given refresh token and configuration. pub(crate) fn new(refresh_token: &'a str, cfg: &'a Config) -> Self { Self { grant_type: "refresh_token", @@ -303,6 +306,7 @@ impl<'a> DiscordRefreshTokenBody<'a> { } } + /// Converts the request body into a query parameter string. pub(crate) fn as_query_params(&self) -> String { url!([ ("grant_type", self.grant_type), diff --git a/src/oauth/types/mod.rs b/src/oauth/types/mod.rs index 6e0012a..fbbf34e 100644 --- a/src/oauth/types/mod.rs +++ b/src/oauth/types/mod.rs @@ -1,8 +1,4 @@ -use crate::constants; -use crate::response::ApiError; -use rocket::http::{Cookie, Status}; -use rocket::request::{FromRequest, Outcome}; -use rocket::Request; +use rocket::http::Cookie; pub(super) mod discord; pub(super) mod roblox; @@ -23,39 +19,3 @@ pub(super) struct GeneratedSession<'a> { /// The expiration date of the session. pub(super) expires_at: chrono::NaiveDateTime, } - -/// A rocket request guard that extracts the session ID from the session cookie. -pub(super) struct SessionId(String); - -impl SessionId { - /// Consumes the [`SessionId`] and returns the inner session ID. - pub(super) fn into_inner(self) -> String { - self.0 - } -} - -#[rocket::async_trait] -impl<'r> FromRequest<'r> for SessionId { - type Error = ApiError; - - async fn from_request(request: &'r Request<'_>) -> Outcome { - // Extract the session ID from the session cookie. - let session_id = request - .cookies() - .get_pending(constants::cookie::SESSION_ID) - .map(|cookie| cookie.value().to_string()); - - // Return an error if the session ID is missing. - // Otherwise, return the session ID. - match session_id { - None => { - let error = ApiError::message( - Status::Unauthorized, - format!("Missing {} cookie", constants::cookie::SESSION_ID), - ); - Outcome::Error((Status::Unauthorized, error)) - } - Some(session_id) => Outcome::Success(SessionId(session_id)), - } - } -} diff --git a/src/oauth/types/roblox.rs b/src/oauth/types/roblox.rs index a019894..38d3cd3 100644 --- a/src/oauth/types/roblox.rs +++ b/src/oauth/types/roblox.rs @@ -1,3 +1,6 @@ +use crate::config::Config; +use crate::{constants, url}; +use rocket::serde::{Deserialize, Serialize}; use std::fmt::{Display, Formatter, Result as FmtResult}; /// A set of scopes that will be requested by the web app. @@ -45,6 +48,46 @@ pub(crate) enum RobloxOAuthScope { UserUserNotificationWrite, } +/// The request body for the Roblox OAuth token endpoint. +#[derive(Serialize)] +#[serde(crate = "rocket::serde")] +pub(crate) struct RobloxAuthorizedCodeRequestBody<'a> { + /// The authorization code received from the OAuth2 provider. + code: &'a str, + /// The code verifier used to generate the authorization code. + code_verifier: &'a str, + /// The grant type for the OAuth2 request. + grant_type: &'a str, + /// The client ID for the OAuth2 application. + client_id: &'a str, + /// The client secret for the OAuth2 application. + /// Must be an owned string as it is retrieved from the environment + client_secret: String, +} + +/// The response body for the Roblox OAuth token endpoint. +#[derive(Debug, Deserialize)] +#[serde(crate = "rocket::serde")] +pub(crate) struct RobloxAccessTokenResponse { + /// The access token for the user. + pub(crate) access_token: String, + /// The refresh token for the user. + pub(crate) refresh_token: String, + /// The time in seconds until the access token expires. + pub(crate) expires_in: i64, + /// The scopes that the user has authorized. + pub(crate) scope: String, +} + +/// The response body for the Roblox OAuth user info endpoint. +#[derive(Debug, Deserialize)] +#[serde(crate = "rocket::serde")] +pub(crate) struct RobloxUserInfoResponse { + /// The user's Roblox ID. + #[serde(rename = "sub")] + pub(crate) id: String, +} + impl From<&RobloxOAuthScope> for &str { fn from(value: &RobloxOAuthScope) -> Self { match value { @@ -108,8 +151,8 @@ impl Display for RobloxOAuthScopes { } } -impl From<&RobloxOAuthScopeSet> for RobloxOAuthScopes { - fn from(value: &RobloxOAuthScopeSet) -> Self { +impl From for RobloxOAuthScopes { + fn from(value: RobloxOAuthScopeSet) -> Self { match value { RobloxOAuthScopeSet::Verification => RobloxOAuthScopes(vec![ RobloxOAuthScope::OpenID, @@ -132,3 +175,36 @@ impl TryFrom<&str> for RobloxOAuthScopeSet { } } } + +// impl From<&RobloxOAuthScopeSet> for String { +// fn from(value: &RobloxOAuthScopeSet) -> Self { +// let scopes: RobloxOAuthScopes = value.into(); +// +// scopes.to_string() +// } +// } + +impl<'a> RobloxAuthorizedCodeRequestBody<'a> { + /// Creates a new [`RobloxAuthorizedCodeRequestBody`] with the given code, code verifier, and configuration. + pub(crate) fn new(code: &'a str, code_verifier: &'a str, cfg: &'a Config) -> Self { + Self { + code, + code_verifier, + grant_type: "authorization_code", + client_id: &cfg.oauth.roblox.client_id, + client_secret: std::env::var(constants::env::ROBLOX_CLIENT_SECRET) + .unwrap_or_else(|_| panic!("{} must be set", constants::env::ROBLOX_CLIENT_SECRET)), + } + } + + /// Converts the request body to a query parameter string. + pub(crate) fn as_query_params(&self) -> String { + url!([ + ("code", self.code), + ("code_verifier", self.code_verifier), + ("grant_type", self.grant_type), + ("client_id", self.client_id), + ("client_secret", self.client_secret) + ]) + } +} diff --git a/src/oauth/utils/discord.rs b/src/oauth/utils/discord.rs index 25b5b8f..3157e57 100644 --- a/src/oauth/utils/discord.rs +++ b/src/oauth/utils/discord.rs @@ -20,7 +20,7 @@ use rocket::http::Status; /// /// The constructed Discord OAuth URL. pub(crate) fn construct_discord_oauth_url( - scopes: &DiscordOAuthScopes, + scopes: DiscordOAuthScopes, state: &str, cfg: &Config, ) -> String { diff --git a/src/oauth/utils/mod.rs b/src/oauth/utils/mod.rs index 409cbed..54e3cfe 100644 --- a/src/oauth/utils/mod.rs +++ b/src/oauth/utils/mod.rs @@ -6,7 +6,7 @@ use self::pixy::generate_random_base64_url_safe_no_pad_string; use crate::constants; use crate::oauth::types::GeneratedSession; use rand::{thread_rng, Rng}; -use rocket::http::Cookie; +use rocket::http::{Cookie, SameSite}; /// Generates a random state. /// @@ -39,6 +39,7 @@ pub(super) fn generate_session() -> GeneratedSession<'static> { GeneratedSession { cookie: Cookie::build((constants::cookie::SESSION_ID, session_id.clone())) .max_age(rocket::time::Duration::days(30)) + .same_site(SameSite::Lax) .build(), session_id, expires_at: expires_at.naive_utc(), diff --git a/src/oauth/utils/pixy.rs b/src/oauth/utils/pixy.rs index 1467053..b3482a3 100644 --- a/src/oauth/utils/pixy.rs +++ b/src/oauth/utils/pixy.rs @@ -1,5 +1,4 @@ use rand::{thread_rng, Rng}; -use secrecy::{ExposeSecret, SecretString}; use sha2::{Digest, Sha256}; /// Generates a random base64 URL-safe string with no padding. @@ -17,12 +16,10 @@ pub(crate) fn generate_random_base64_url_safe_no_pad_string(number_of_bytes: usi /// Generates a random base64 URL-safe string with no padding and a length between 32 and 96 bytes. /// The resulting string is wrapped in a [`SecretString`]. -fn generate_verifier() -> SecretString { +fn generate_verifier() -> String { let number_of_bytes = thread_rng().gen_range(32..=96); - SecretString::from(generate_random_base64_url_safe_no_pad_string( - number_of_bytes, - )) + generate_random_base64_url_safe_no_pad_string(number_of_bytes) } /// Calculates the challenge from the verifier. @@ -36,21 +33,18 @@ fn generate_verifier() -> SecretString { /// # Returns /// /// The challenge as a base64 URL-safe string with no padding. -fn calculate_challenge_from_verifier(verifier: &SecretString) -> SecretString { - let verifier_hash = Sha256::digest(verifier.expose_secret().as_bytes()); +fn calculate_challenge_from_verifier(verifier: &str) -> String { + let verifier_hash = Sha256::digest(verifier.as_bytes()); - SecretString::from(base64::encode_config( - &verifier_hash, - base64::URL_SAFE_NO_PAD, - )) + base64::encode_config(&verifier_hash, base64::URL_SAFE_NO_PAD) } /// A struct that represents the Pixy PKCE method. pub(crate) struct Pixy { /// The challenge generated by the Pixy method. - challenge: SecretString, + pub(crate) challenge: String, /// The verifier generated by the Pixy method. - verifier: SecretString, + pub(crate) verifier: String, } impl Pixy { @@ -60,18 +54,8 @@ impl Pixy { let challenge = calculate_challenge_from_verifier(&verifier); Pixy { - verifier, challenge, + verifier, } } - - /// Exposes the verifier as a string slice - pub(crate) fn expose_verifier(&self) -> &str { - self.verifier.expose_secret() - } - - /// Reference getter for the challenge - pub(crate) fn get_challenge(&self) -> &SecretString { - &self.challenge - } } diff --git a/src/oauth/utils/roblox.rs b/src/oauth/utils/roblox.rs index 4b9d1ee..49e8251 100644 --- a/src/oauth/utils/roblox.rs +++ b/src/oauth/utils/roblox.rs @@ -1,8 +1,12 @@ use crate::config::Config; use crate::constants; -use crate::oauth::types::roblox::RobloxOAuthScopes; +use crate::oauth::types::roblox::{ + RobloxAccessTokenResponse, RobloxAuthorizedCodeRequestBody, RobloxOAuthScopes, + RobloxUserInfoResponse, +}; +use crate::response::ApiError; use crate::url; -use secrecy::{ExposeSecret, SecretString}; +use rocket::http::Status; /// Constructs the Roblox OAuth URL with the given scopes and state. /// The scopes are joined by an encoded space (`%20`). @@ -16,13 +20,11 @@ use secrecy::{ExposeSecret, SecretString}; /// /// The constructed Roblox OAuth URL. pub(crate) fn construct_roblox_oauth_url( - code_challenge_secret: &SecretString, - scopes: &RobloxOAuthScopes, + code_challenge: &str, + scopes: RobloxOAuthScopes, state: &str, cfg: &Config, ) -> String { - let code_challenge = code_challenge_secret.expose_secret(); - url!( constants::roblox_api::AUTHORIZE_URL, [ @@ -36,3 +38,62 @@ pub(crate) fn construct_roblox_oauth_url( ] ) } + +/// Exchanges the given code for an access token. +/// +/// # Arguments +/// +/// * `code` - The code to exchange for an access token. +/// * `code_verifier` - The code verifier used to generate the code challenge. +/// * `cfg` - The application configuration. +/// +/// # Returns +/// +/// The [`RobloxAccessTokenResponse`] struct if the exchange was successful, an [`ApiError`] otherwise. +pub(crate) fn exchange_code( + code: &str, + code_verifier: &str, + cfg: &Config, +) -> Result { + let body = RobloxAuthorizedCodeRequestBody::new(code, code_verifier, cfg); + let response = minreq::post(constants::roblox_api::TOKEN_URL) + .with_header("Content-Type", "application/x-www-form-urlencoded") + .with_body(body.as_query_params()) + .send() + .map_err(|_| ApiError::message(Status::BadGateway, "Failed to exchange code for token"))?; + + response.json::().map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to parse token response", + ) + }) +} + +/// Gets the authorized user from the Roblox API. +/// +/// # Arguments +/// +/// * `access_token` - The access token to use. +/// +/// # Returns +/// +/// The [`RobloxUserInfoResponse`] struct if the request was successful, an [`ApiError`] otherwise. +pub(crate) fn get_authorized_user(access_token: &str) -> Result { + minreq::get(constants::roblox_api::USER_URL) + .with_header("Authorization", format!("Bearer {access_token}")) + .send() + .map_err(|_| { + ApiError::message( + Status::BadGateway, + "Failed to get authorized user information", + ) + })? + .json::() + .map_err(|_| { + ApiError::message( + Status::InternalServerError, + "Failed to parse authorized user information", + ) + }) +} From 109600fe746770aa5e4e0ece04b2da73ad3bc31f Mon Sep 17 00:00:00 2001 From: nick <59822256+Archasion@users.noreply.github.com> Date: Sat, 7 Dec 2024 20:08:30 +0000 Subject: [PATCH 14/14] refactor: Add more tests and traits --- src/database/wrappers/account_links/mod.rs | 1 + src/oauth/types/discord.rs | 44 +++++++++++++++++ src/oauth/types/mod.rs | 1 + src/oauth/types/roblox.rs | 56 ++++++++++++++++++---- 4 files changed, 93 insertions(+), 9 deletions(-) diff --git a/src/database/wrappers/account_links/mod.rs b/src/database/wrappers/account_links/mod.rs index e35854c..3d8ee71 100644 --- a/src/database/wrappers/account_links/mod.rs +++ b/src/database/wrappers/account_links/mod.rs @@ -169,6 +169,7 @@ impl UserMarker for RobloxMarker { } /// Wrapper for account link user IDs. +#[derive(Debug)] pub(crate) struct AccountLinkUserId { /// Marker for the user type. marker: PhantomData, diff --git a/src/oauth/types/discord.rs b/src/oauth/types/discord.rs index c380c0f..408552d 100644 --- a/src/oauth/types/discord.rs +++ b/src/oauth/types/discord.rs @@ -8,17 +8,20 @@ use std::fmt::{Display, Formatter, Result as FmtResult}; /// /// The internal API should parse these scope sets into individual scopes /// that will be requested from the OAuth2 provider +#[derive(PartialEq, Debug)] pub(crate) enum DiscordOAuthScopeSet { /// The scopes required for verifying a user's Discord account Verification, } /// A set of scopes that will be requested from the OAuth2 provider. +#[derive(PartialEq, Debug)] pub(crate) struct DiscordOAuthScopes(pub(crate) Vec); /// A scope that can be requested from the OAuth2 provider. /// /// [Reference](https://discord.com/developers/docs/topics/oauth2#shared-resources-oauth2-scopes) +#[derive(PartialEq, Debug)] pub(crate) enum DiscordOAuthScope { Identify, Guilds, @@ -316,3 +319,44 @@ impl<'a> DiscordRefreshTokenBody<'a> { ]) } } + +#[cfg(test)] +mod tests { + use super::{DiscordOAuthScope, DiscordOAuthScopeSet, DiscordOAuthScopes}; + + #[test] + fn test_discord_oauth_scopes_display() { + let scopes = DiscordOAuthScopes(vec![ + DiscordOAuthScope::Identify, + DiscordOAuthScope::GuildsMembersRead, + DiscordOAuthScope::Guilds, + DiscordOAuthScope::Connections, + ]); + + assert_eq!( + format!("{}", scopes), + "identify+guilds.members.read+guilds+connections" + ); + } + + #[test] + fn test_discord_oauth_scope_set_try_from() { + assert_eq!( + DiscordOAuthScopeSet::try_from("verification").unwrap(), + DiscordOAuthScopeSet::Verification + ); + } + + #[test] + fn test_discord_oauth_scopes_from() { + assert_eq!( + DiscordOAuthScopes::from(DiscordOAuthScopeSet::Verification), + DiscordOAuthScopes(vec![ + DiscordOAuthScope::Identify, + DiscordOAuthScope::GuildsMembersRead, + DiscordOAuthScope::Guilds, + DiscordOAuthScope::Connections, + ]) + ); + } +} diff --git a/src/oauth/types/mod.rs b/src/oauth/types/mod.rs index fbbf34e..ca8f6a6 100644 --- a/src/oauth/types/mod.rs +++ b/src/oauth/types/mod.rs @@ -11,6 +11,7 @@ pub(super) struct OAuthCallback { } /// Result of the [`generate_session`](crate::oauth::utils::generate_session) function. +#[derive(Debug)] pub(super) struct GeneratedSession<'a> { /// The session cookie. pub(super) cookie: Cookie<'a>, diff --git a/src/oauth/types/roblox.rs b/src/oauth/types/roblox.rs index 38d3cd3..ad11b76 100644 --- a/src/oauth/types/roblox.rs +++ b/src/oauth/types/roblox.rs @@ -7,14 +7,17 @@ use std::fmt::{Display, Formatter, Result as FmtResult}; /// /// The internal API should parse these scope sets into individual scopes /// that will be requested from the OAuth2 provider +#[derive(PartialEq, Debug)] pub(crate) enum RobloxOAuthScopeSet { Verification, } /// A set of scopes that will be requested from the OAuth2 provider. +#[derive(PartialEq, Debug)] pub(crate) struct RobloxOAuthScopes(pub(crate) Vec); /// A scope that will be requested from the OAuth2 provider. +#[derive(PartialEq, Debug)] pub(crate) enum RobloxOAuthScope { OpenID, Profile, @@ -66,7 +69,7 @@ pub(crate) struct RobloxAuthorizedCodeRequestBody<'a> { } /// The response body for the Roblox OAuth token endpoint. -#[derive(Debug, Deserialize)] +#[derive(Deserialize)] #[serde(crate = "rocket::serde")] pub(crate) struct RobloxAccessTokenResponse { /// The access token for the user. @@ -176,14 +179,6 @@ impl TryFrom<&str> for RobloxOAuthScopeSet { } } -// impl From<&RobloxOAuthScopeSet> for String { -// fn from(value: &RobloxOAuthScopeSet) -> Self { -// let scopes: RobloxOAuthScopes = value.into(); -// -// scopes.to_string() -// } -// } - impl<'a> RobloxAuthorizedCodeRequestBody<'a> { /// Creates a new [`RobloxAuthorizedCodeRequestBody`] with the given code, code verifier, and configuration. pub(crate) fn new(code: &'a str, code_verifier: &'a str, cfg: &'a Config) -> Self { @@ -208,3 +203,46 @@ impl<'a> RobloxAuthorizedCodeRequestBody<'a> { ]) } } + +#[cfg(test)] +mod tests { + use super::{RobloxOAuthScope, RobloxOAuthScopeSet, RobloxOAuthScopes}; + + #[test] + fn test_roblox_oauth_scopes_display() { + let scopes = RobloxOAuthScopes(vec![ + RobloxOAuthScope::OpenID, + RobloxOAuthScope::Profile, + RobloxOAuthScope::GroupRead, + RobloxOAuthScope::UserInventoryItemRead, + RobloxOAuthScope::UserAdvancedRead, + ]); + + assert_eq!( + format!("{}", scopes), + "openid%20profile%20group:read%20user.inventory-item:read%20user.advanced:read" + ); + } + + #[test] + fn test_roblox_oauth_scope_set_try_from() { + assert_eq!( + RobloxOAuthScopeSet::try_from("verification").unwrap(), + RobloxOAuthScopeSet::Verification + ); + } + + #[test] + fn test_roblox_oauth_scopes_from() { + assert_eq!( + RobloxOAuthScopes::from(RobloxOAuthScopeSet::Verification), + RobloxOAuthScopes(vec![ + RobloxOAuthScope::OpenID, + RobloxOAuthScope::Profile, + RobloxOAuthScope::GroupRead, + RobloxOAuthScope::UserInventoryItemRead, + RobloxOAuthScope::UserAdvancedRead, + ]) + ); + } +}