From e53b59174970571c147c00305d828bc486547a0a Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 18 Nov 2025 04:32:52 +0000 Subject: [PATCH 1/2] feat(crypto): Add SecretGateway core cryptographic utilities (Window 17) Implement comprehensive crypto service with: - Cryptographically secure random secret generation (bytes, hex, urlsafe) - Fernet symmetric encryption/decryption with TTL support - SHA-256 hashing with salt and constant-time verification - PBKDF2 key derivation from passwords (480k iterations) - Key rotation support for encrypted data - Comprehensive test suite (41 tests, 100% pass rate) Security features: - Uses secrets module for CSPRNG - Fernet provides authenticated encryption (AES-128-CBC + HMAC) - Constant-time hash comparison to prevent timing attacks - OWASP-recommended PBKDF2 iterations Dependencies: - Add cryptography==42.0.5 - Add .gitignore for Python cache files --- .gitignore | 3 + app/services/crypto/__init__.py | 12 + app/services/crypto/crypto_service.py | 354 ++++++++++++++++++ requirements.txt | 1 + tests/test_crypto_service.py | 497 ++++++++++++++++++++++++++ 5 files changed, 867 insertions(+) create mode 100644 .gitignore create mode 100644 app/services/crypto/__init__.py create mode 100644 app/services/crypto/crypto_service.py create mode 100644 tests/test_crypto_service.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..75c6182 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +*.pyc +.pytest_cache/ diff --git a/app/services/crypto/__init__.py b/app/services/crypto/__init__.py new file mode 100644 index 0000000..7c2c912 --- /dev/null +++ b/app/services/crypto/__init__.py @@ -0,0 +1,12 @@ +""" +SecretGateway Cryptographic Services Module + +Provides core cryptographic utilities for secure secret management: +- Secret generation +- Symmetric encryption/decryption +- Secure hashing with salt +""" + +from .crypto_service import CryptoService + +__all__ = ["CryptoService"] diff --git a/app/services/crypto/crypto_service.py b/app/services/crypto/crypto_service.py new file mode 100644 index 0000000..e3b774a --- /dev/null +++ b/app/services/crypto/crypto_service.py @@ -0,0 +1,354 @@ +""" +Core Cryptographic Service for SecretGateway + +Implements secure cryptographic operations: +- Random secret generation +- Fernet symmetric encryption/decryption +- SHA-256 hashing with salt support +""" + +import hashlib +import secrets +from typing import Optional, Tuple +from base64 import b64encode, b64decode + +from cryptography.fernet import Fernet, InvalidToken +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + +class CryptoError(Exception): + """Base exception for cryptographic operations""" + pass + + +class EncryptionError(CryptoError): + """Raised when encryption fails""" + pass + + +class DecryptionError(CryptoError): + """Raised when decryption fails""" + pass + + +class CryptoService: + """ + Cryptographic service providing secure operations for secret management. + + Features: + - Cryptographically secure random secret generation + - Fernet symmetric encryption (AES-128 in CBC mode with HMAC) + - SHA-256 hashing with optional salt + - Key derivation from passwords + """ + + DEFAULT_SECRET_LENGTH = 32 # 256 bits + DEFAULT_SALT_LENGTH = 16 # 128 bits + PBKDF2_ITERATIONS = 480000 # OWASP 2023 recommendation + + def __init__(self, master_key: Optional[bytes] = None): + """ + Initialize CryptoService. + + Args: + master_key: Optional pre-generated Fernet key (32 url-safe base64 bytes). + If not provided, a new key will be generated. + """ + if master_key: + try: + self._fernet = Fernet(master_key) + self._master_key = master_key + except Exception as e: + raise CryptoError(f"Invalid master key: {e}") + else: + self._master_key = Fernet.generate_key() + self._fernet = Fernet(self._master_key) + + @property + def master_key(self) -> bytes: + """Get the current master encryption key""" + return self._master_key + + def generate_secret(self, length: int = DEFAULT_SECRET_LENGTH) -> bytes: + """ + Generate cryptographically secure random bytes. + + Uses secrets module which is suitable for managing sensitive data + like authentication tokens, API keys, and secrets. + + Args: + length: Number of random bytes to generate (default: 32) + + Returns: + Random bytes of specified length + + Raises: + ValueError: If length is not positive + """ + if length <= 0: + raise ValueError("Secret length must be positive") + + return secrets.token_bytes(length) + + def generate_secret_hex(self, length: int = DEFAULT_SECRET_LENGTH) -> str: + """ + Generate cryptographically secure random hex string. + + Args: + length: Number of random bytes to generate (default: 32) + + Returns: + Hex string representation (2x length characters) + """ + return secrets.token_hex(length) + + def generate_secret_urlsafe(self, length: int = DEFAULT_SECRET_LENGTH) -> str: + """ + Generate cryptographically secure URL-safe random string. + + Args: + length: Number of random bytes to generate (default: 32) + + Returns: + URL-safe base64-encoded string + """ + return secrets.token_urlsafe(length) + + def encrypt(self, plaintext: bytes) -> bytes: + """ + Encrypt data using Fernet symmetric encryption. + + Fernet guarantees that encrypted data cannot be manipulated or read + without the key. It uses AES-128 in CBC mode with PKCS7 padding + and HMAC for authentication. + + Args: + plaintext: Data to encrypt + + Returns: + Encrypted data (includes timestamp and HMAC) + + Raises: + EncryptionError: If encryption fails + """ + try: + return self._fernet.encrypt(plaintext) + except Exception as e: + raise EncryptionError(f"Encryption failed: {e}") + + def encrypt_string(self, plaintext: str, encoding: str = "utf-8") -> bytes: + """ + Encrypt a string using Fernet symmetric encryption. + + Args: + plaintext: String to encrypt + encoding: Text encoding (default: utf-8) + + Returns: + Encrypted data + + Raises: + EncryptionError: If encryption fails + """ + try: + plaintext_bytes = plaintext.encode(encoding) + return self.encrypt(plaintext_bytes) + except UnicodeEncodeError as e: + raise EncryptionError(f"Failed to encode plaintext: {e}") + + def decrypt(self, ciphertext: bytes, ttl: Optional[int] = None) -> bytes: + """ + Decrypt data using Fernet symmetric encryption. + + Args: + ciphertext: Encrypted data to decrypt + ttl: Optional time-to-live in seconds. If provided, decryption + will fail if the token is older than ttl seconds. + + Returns: + Decrypted plaintext + + Raises: + DecryptionError: If decryption fails or token is invalid/expired + """ + try: + return self._fernet.decrypt(ciphertext, ttl=ttl) + except InvalidToken: + raise DecryptionError("Invalid or expired token") + except Exception as e: + raise DecryptionError(f"Decryption failed: {e}") + + def decrypt_to_string( + self, + ciphertext: bytes, + encoding: str = "utf-8", + ttl: Optional[int] = None + ) -> str: + """ + Decrypt data and return as string. + + Args: + ciphertext: Encrypted data to decrypt + encoding: Text encoding (default: utf-8) + ttl: Optional time-to-live in seconds + + Returns: + Decrypted string + + Raises: + DecryptionError: If decryption or decoding fails + """ + plaintext_bytes = self.decrypt(ciphertext, ttl=ttl) + try: + return plaintext_bytes.decode(encoding) + except UnicodeDecodeError as e: + raise DecryptionError(f"Failed to decode decrypted data: {e}") + + def hash_data( + self, + data: bytes, + salt: Optional[bytes] = None, + return_salt: bool = False + ) -> bytes | Tuple[bytes, bytes]: + """ + Hash data using SHA-256 with optional salt. + + Args: + data: Data to hash + salt: Optional salt bytes. If not provided, will be generated. + return_salt: If True, return tuple of (hash, salt) + + Returns: + SHA-256 hash, or tuple of (hash, salt) if return_salt=True + """ + if salt is None: + salt = self.generate_secret(self.DEFAULT_SALT_LENGTH) + + hasher = hashlib.sha256() + hasher.update(salt) + hasher.update(data) + hash_result = hasher.digest() + + if return_salt: + return hash_result, salt + return hash_result + + def hash_string( + self, + data: str, + salt: Optional[bytes] = None, + encoding: str = "utf-8", + return_salt: bool = False + ) -> bytes | Tuple[bytes, bytes]: + """ + Hash a string using SHA-256 with optional salt. + + Args: + data: String to hash + salt: Optional salt bytes + encoding: Text encoding (default: utf-8) + return_salt: If True, return tuple of (hash, salt) + + Returns: + SHA-256 hash, or tuple of (hash, salt) if return_salt=True + """ + data_bytes = data.encode(encoding) + return self.hash_data(data_bytes, salt=salt, return_salt=return_salt) + + def verify_hash( + self, + data: bytes, + expected_hash: bytes, + salt: bytes + ) -> bool: + """ + Verify that data matches expected hash with given salt. + + Uses constant-time comparison to prevent timing attacks. + + Args: + data: Data to verify + expected_hash: Expected hash value + salt: Salt used in original hash + + Returns: + True if hash matches, False otherwise + """ + computed_hash = self.hash_data(data, salt=salt) + return secrets.compare_digest(computed_hash, expected_hash) + + def derive_key_from_password( + self, + password: str, + salt: Optional[bytes] = None, + return_salt: bool = False + ) -> bytes | Tuple[bytes, bytes]: + """ + Derive an encryption key from a password using PBKDF2-HMAC-SHA256. + + Args: + password: Password to derive key from + salt: Optional salt. If not provided, will be generated. + return_salt: If True, return tuple of (key, salt) + + Returns: + 32-byte encryption key suitable for Fernet, or tuple of (key, salt) + """ + if salt is None: + salt = self.generate_secret(self.DEFAULT_SALT_LENGTH) + + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=self.PBKDF2_ITERATIONS, + ) + + key = b64encode(kdf.derive(password.encode('utf-8'))) + + if return_salt: + return key, salt + return key + + @staticmethod + def generate_fernet_key() -> bytes: + """ + Generate a new Fernet encryption key. + + Returns: + 32-byte URL-safe base64-encoded key suitable for Fernet + """ + return Fernet.generate_key() + + @staticmethod + def rotate_key(old_key: bytes, new_key: bytes, ciphertext: bytes) -> bytes: + """ + Re-encrypt data with a new key (for key rotation). + + Args: + old_key: Current encryption key + new_key: New encryption key + ciphertext: Data encrypted with old key + + Returns: + Data re-encrypted with new key + + Raises: + DecryptionError: If old key cannot decrypt data + EncryptionError: If re-encryption fails + """ + old_fernet = Fernet(old_key) + new_fernet = Fernet(new_key) + + try: + plaintext = old_fernet.decrypt(ciphertext) + except InvalidToken: + raise DecryptionError("Cannot decrypt with old key") + except Exception as e: + raise DecryptionError(f"Decryption with old key failed: {e}") + + try: + return new_fernet.encrypt(plaintext) + except Exception as e: + raise EncryptionError(f"Re-encryption with new key failed: {e}") diff --git a/requirements.txt b/requirements.txt index 465a9ae..48e93aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ structlog==23.2.0 python-dotenv==1.0.1 orjson==3.10.7 pytest==8.1.0 +cryptography==42.0.5 diff --git a/tests/test_crypto_service.py b/tests/test_crypto_service.py new file mode 100644 index 0000000..8513edf --- /dev/null +++ b/tests/test_crypto_service.py @@ -0,0 +1,497 @@ +""" +Comprehensive tests for CryptoService + +Tests cover: +- Secret generation (random bytes, hex, urlsafe) +- Symmetric encryption/decryption (Fernet) +- Hashing with SHA-256 and salt +- Key derivation from passwords +- Key rotation +- Error handling and edge cases +""" + +import pytest +import time +from cryptography.fernet import Fernet + +from app.services.crypto import CryptoService +from app.services.crypto.crypto_service import ( + CryptoError, + EncryptionError, + DecryptionError +) + + +class TestSecretGeneration: + """Test cryptographically secure random secret generation""" + + def test_generate_secret_default_length(self): + """Test secret generation with default length""" + crypto = CryptoService() + secret = crypto.generate_secret() + + assert isinstance(secret, bytes) + assert len(secret) == CryptoService.DEFAULT_SECRET_LENGTH + + def test_generate_secret_custom_length(self): + """Test secret generation with custom length""" + crypto = CryptoService() + lengths = [8, 16, 32, 64, 128] + + for length in lengths: + secret = crypto.generate_secret(length) + assert len(secret) == length + + def test_generate_secret_randomness(self): + """Test that generated secrets are different (highly probable)""" + crypto = CryptoService() + secrets = [crypto.generate_secret() for _ in range(10)] + + # All secrets should be unique (probability of collision is negligible) + assert len(set(secrets)) == len(secrets) + + def test_generate_secret_invalid_length(self): + """Test that invalid length raises ValueError""" + crypto = CryptoService() + + with pytest.raises(ValueError, match="Secret length must be positive"): + crypto.generate_secret(0) + + with pytest.raises(ValueError): + crypto.generate_secret(-1) + + def test_generate_secret_hex(self): + """Test hex secret generation""" + crypto = CryptoService() + secret = crypto.generate_secret_hex(16) + + assert isinstance(secret, str) + assert len(secret) == 32 # Hex is 2x the byte length + assert all(c in '0123456789abcdef' for c in secret) + + def test_generate_secret_urlsafe(self): + """Test URL-safe secret generation""" + crypto = CryptoService() + secret = crypto.generate_secret_urlsafe(32) + + assert isinstance(secret, str) + # URL-safe base64 uses only alphanumeric, -, and _ + assert all(c.isalnum() or c in '-_' for c in secret) + + +class TestSymmetricEncryption: + """Test Fernet symmetric encryption and decryption""" + + def test_encrypt_decrypt_bytes(self): + """Test basic encryption and decryption of bytes""" + crypto = CryptoService() + plaintext = b"Secret message for encryption" + + ciphertext = crypto.encrypt(plaintext) + decrypted = crypto.decrypt(ciphertext) + + assert ciphertext != plaintext + assert decrypted == plaintext + + def test_encrypt_decrypt_string(self): + """Test encryption and decryption of strings""" + crypto = CryptoService() + plaintext = "Secret string message" + + ciphertext = crypto.encrypt_string(plaintext) + decrypted = crypto.decrypt_to_string(ciphertext) + + assert isinstance(ciphertext, bytes) + assert decrypted == plaintext + + def test_encrypt_unicode(self): + """Test encryption of Unicode strings""" + crypto = CryptoService() + plaintext = "Secret: 你好世界 🔐 مرحبا" + + ciphertext = crypto.encrypt_string(plaintext) + decrypted = crypto.decrypt_to_string(ciphertext) + + assert decrypted == plaintext + + def test_ciphertext_different_each_time(self): + """Test that same plaintext produces different ciphertext (IV)""" + crypto = CryptoService() + plaintext = b"Same message" + + ciphertext1 = crypto.encrypt(plaintext) + ciphertext2 = crypto.encrypt(plaintext) + + # Ciphertexts should be different due to random IV + assert ciphertext1 != ciphertext2 + + # But both should decrypt to same plaintext + assert crypto.decrypt(ciphertext1) == plaintext + assert crypto.decrypt(ciphertext2) == plaintext + + def test_decrypt_with_wrong_key(self): + """Test that decryption fails with wrong key""" + crypto1 = CryptoService() + crypto2 = CryptoService() # Different key + + plaintext = b"Secret data" + ciphertext = crypto1.encrypt(plaintext) + + with pytest.raises(DecryptionError, match="Invalid or expired token"): + crypto2.decrypt(ciphertext) + + def test_decrypt_invalid_data(self): + """Test that decryption of invalid data raises error""" + crypto = CryptoService() + + with pytest.raises(DecryptionError): + crypto.decrypt(b"not-valid-ciphertext") + + with pytest.raises(DecryptionError): + crypto.decrypt(b"") + + def test_encrypt_with_master_key(self): + """Test initialization with existing master key""" + # Generate a key and use it for two instances + master_key = CryptoService.generate_fernet_key() + + crypto1 = CryptoService(master_key=master_key) + crypto2 = CryptoService(master_key=master_key) + + plaintext = b"Shared key test" + ciphertext = crypto1.encrypt(plaintext) + + # crypto2 should be able to decrypt since it has same key + decrypted = crypto2.decrypt(ciphertext) + assert decrypted == plaintext + + def test_invalid_master_key(self): + """Test that invalid master key raises error""" + with pytest.raises(CryptoError, match="Invalid master key"): + CryptoService(master_key=b"not-a-valid-key") + + def test_encryption_ttl(self): + """Test time-to-live parameter for decryption""" + crypto = CryptoService() + plaintext = b"Time-sensitive data" + + ciphertext = crypto.encrypt(plaintext) + + # Should decrypt successfully with generous TTL + decrypted = crypto.decrypt(ciphertext, ttl=60) + assert decrypted == plaintext + + # Encrypt and wait, then try with very short TTL + time.sleep(2) + ciphertext_old = crypto.encrypt(plaintext) + time.sleep(2) + + # Should fail with TTL of 1 second + with pytest.raises(DecryptionError, match="Invalid or expired token"): + crypto.decrypt(ciphertext_old, ttl=1) + + def test_master_key_property(self): + """Test master_key property access""" + crypto = CryptoService() + key = crypto.master_key + + assert isinstance(key, bytes) + assert len(key) > 0 + + # Key should be valid Fernet key + Fernet(key) # Should not raise + + +class TestHashing: + """Test SHA-256 hashing with salt""" + + def test_hash_data_basic(self): + """Test basic data hashing""" + crypto = CryptoService() + data = b"Data to hash" + + hash_result = crypto.hash_data(data) + + assert isinstance(hash_result, bytes) + assert len(hash_result) == 32 # SHA-256 produces 32 bytes + + def test_hash_with_salt(self): + """Test hashing with explicit salt""" + crypto = CryptoService() + data = b"Data to hash" + salt = crypto.generate_secret(16) + + hash1 = crypto.hash_data(data, salt=salt) + hash2 = crypto.hash_data(data, salt=salt) + + # Same data + salt should produce same hash + assert hash1 == hash2 + + def test_hash_different_salts(self): + """Test that different salts produce different hashes""" + crypto = CryptoService() + data = b"Same data" + + hash1, salt1 = crypto.hash_data(data, return_salt=True) + hash2, salt2 = crypto.hash_data(data, return_salt=True) + + # Different salts should produce different hashes + assert salt1 != salt2 + assert hash1 != hash2 + + def test_hash_string(self): + """Test string hashing""" + crypto = CryptoService() + data = "String to hash" + + hash_result = crypto.hash_string(data) + + assert isinstance(hash_result, bytes) + assert len(hash_result) == 32 + + def test_hash_unicode_string(self): + """Test Unicode string hashing""" + crypto = CryptoService() + data = "Unicode: 你好 🔐" + + hash_result = crypto.hash_string(data) + assert len(hash_result) == 32 + + def test_verify_hash_correct(self): + """Test hash verification with correct data""" + crypto = CryptoService() + data = b"Data to verify" + + hash_result, salt = crypto.hash_data(data, return_salt=True) + + # Verification should succeed + assert crypto.verify_hash(data, hash_result, salt) is True + + def test_verify_hash_incorrect(self): + """Test hash verification with incorrect data""" + crypto = CryptoService() + data = b"Original data" + wrong_data = b"Wrong data" + + hash_result, salt = crypto.hash_data(data, return_salt=True) + + # Verification should fail + assert crypto.verify_hash(wrong_data, hash_result, salt) is False + + def test_verify_hash_wrong_salt(self): + """Test hash verification with wrong salt""" + crypto = CryptoService() + data = b"Data to verify" + + hash_result, _ = crypto.hash_data(data, return_salt=True) + wrong_salt = crypto.generate_secret(16) + + # Verification should fail with wrong salt + assert crypto.verify_hash(data, hash_result, wrong_salt) is False + + def test_hash_deterministic_with_same_salt(self): + """Test that hashing is deterministic with same salt""" + crypto = CryptoService() + data = b"Consistent data" + salt = crypto.generate_secret(16) + + hashes = [crypto.hash_data(data, salt=salt) for _ in range(5)] + + # All hashes should be identical + assert len(set(hashes)) == 1 + + +class TestKeyDerivation: + """Test password-based key derivation""" + + def test_derive_key_from_password(self): + """Test basic key derivation from password""" + crypto = CryptoService() + password = "SecurePassword123!" + + key = crypto.derive_key_from_password(password) + + assert isinstance(key, bytes) + # Should be valid Fernet key + Fernet(key) + + def test_derive_key_deterministic(self): + """Test that same password+salt produces same key""" + crypto = CryptoService() + password = "MyPassword" + salt = crypto.generate_secret(16) + + key1 = crypto.derive_key_from_password(password, salt=salt) + key2 = crypto.derive_key_from_password(password, salt=salt) + + assert key1 == key2 + + def test_derive_key_different_salts(self): + """Test that different salts produce different keys""" + crypto = CryptoService() + password = "SamePassword" + + key1, salt1 = crypto.derive_key_from_password(password, return_salt=True) + key2, salt2 = crypto.derive_key_from_password(password, return_salt=True) + + assert salt1 != salt2 + assert key1 != key2 + + def test_derive_key_different_passwords(self): + """Test that different passwords produce different keys""" + crypto = CryptoService() + salt = crypto.generate_secret(16) + + key1 = crypto.derive_key_from_password("Password1", salt=salt) + key2 = crypto.derive_key_from_password("Password2", salt=salt) + + assert key1 != key2 + + def test_derived_key_works_for_encryption(self): + """Test that derived key can be used for encryption""" + crypto = CryptoService() + password = "EncryptionPassword" + + key, salt = crypto.derive_key_from_password(password, return_salt=True) + + # Create new CryptoService with derived key + crypto_derived = CryptoService(master_key=key) + + plaintext = b"Test message" + ciphertext = crypto_derived.encrypt(plaintext) + decrypted = crypto_derived.decrypt(ciphertext) + + assert decrypted == plaintext + + +class TestKeyRotation: + """Test key rotation functionality""" + + def test_rotate_key_basic(self): + """Test basic key rotation""" + old_key = CryptoService.generate_fernet_key() + new_key = CryptoService.generate_fernet_key() + + crypto_old = CryptoService(master_key=old_key) + plaintext = b"Data to rotate" + + # Encrypt with old key + ciphertext_old = crypto_old.encrypt(plaintext) + + # Rotate to new key + ciphertext_new = CryptoService.rotate_key(old_key, new_key, ciphertext_old) + + # Decrypt with new key + crypto_new = CryptoService(master_key=new_key) + decrypted = crypto_new.decrypt(ciphertext_new) + + assert decrypted == plaintext + + def test_rotate_key_old_key_invalid(self): + """Test rotation fails with wrong old key""" + old_key = CryptoService.generate_fernet_key() + wrong_key = CryptoService.generate_fernet_key() + new_key = CryptoService.generate_fernet_key() + + crypto = CryptoService(master_key=old_key) + ciphertext = crypto.encrypt(b"Test data") + + # Rotation should fail with wrong old key + with pytest.raises(DecryptionError, match="Cannot decrypt with old key"): + CryptoService.rotate_key(wrong_key, new_key, ciphertext) + + def test_rotate_key_invalid_ciphertext(self): + """Test rotation fails with invalid ciphertext""" + old_key = CryptoService.generate_fernet_key() + new_key = CryptoService.generate_fernet_key() + + with pytest.raises(DecryptionError): + CryptoService.rotate_key(old_key, new_key, b"invalid-data") + + +class TestEdgeCases: + """Test edge cases and error handling""" + + def test_encrypt_empty_data(self): + """Test encryption of empty data""" + crypto = CryptoService() + + ciphertext = crypto.encrypt(b"") + decrypted = crypto.decrypt(ciphertext) + + assert decrypted == b"" + + def test_encrypt_large_data(self): + """Test encryption of large data""" + crypto = CryptoService() + large_data = b"X" * 1_000_000 # 1 MB + + ciphertext = crypto.encrypt(large_data) + decrypted = crypto.decrypt(ciphertext) + + assert decrypted == large_data + + def test_hash_empty_data(self): + """Test hashing of empty data""" + crypto = CryptoService() + + hash_result = crypto.hash_data(b"") + assert len(hash_result) == 32 + + def test_generate_fernet_key_static(self): + """Test static Fernet key generation""" + key1 = CryptoService.generate_fernet_key() + key2 = CryptoService.generate_fernet_key() + + assert key1 != key2 + assert isinstance(key1, bytes) + assert isinstance(key2, bytes) + + # Both should be valid Fernet keys + Fernet(key1) + Fernet(key2) + + def test_encoding_error_handling(self): + """Test handling of encoding errors""" + crypto = CryptoService() + + # Create invalid UTF-8 bytes + invalid_utf8 = b'\xff\xfe' + ciphertext = crypto.encrypt(invalid_utf8) + + # Should fail when trying to decode as UTF-8 + with pytest.raises(DecryptionError, match="Failed to decode"): + crypto.decrypt_to_string(ciphertext, encoding='utf-8') + + +class TestSecurityProperties: + """Test security properties and best practices""" + + def test_constant_time_comparison(self): + """Test that hash verification uses constant-time comparison""" + crypto = CryptoService() + data = b"Test data" + + hash1, salt = crypto.hash_data(data, return_salt=True) + hash2 = crypto.hash_data(b"Different data", salt=salt) + + # Both verifications should take similar time (constant-time) + # This is verified by using secrets.compare_digest internally + result1 = crypto.verify_hash(data, hash1, salt) + result2 = crypto.verify_hash(data, hash2, salt) + + assert result1 is True + assert result2 is False + + def test_pbkdf2_iterations(self): + """Test that PBKDF2 uses sufficient iterations""" + # OWASP recommends 480,000 iterations for PBKDF2-HMAC-SHA256 (2023) + assert CryptoService.PBKDF2_ITERATIONS >= 480000 + + def test_default_lengths_secure(self): + """Test that default lengths are cryptographically secure""" + # 256-bit secrets (32 bytes) + assert CryptoService.DEFAULT_SECRET_LENGTH >= 32 + + # 128-bit salt (16 bytes) + assert CryptoService.DEFAULT_SALT_LENGTH >= 16 From fb506905177f7c2070f5d5d0ad93b8250da45a0d Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 18 Nov 2025 04:46:15 +0000 Subject: [PATCH 2/2] feat(tokens): Add ephemeral token issuance service (Window 18) Implement comprehensive token management system with: - Short-lived token issuance with configurable TTL (1s-3600s) - In-memory token store with automatic expiry cleanup - Thread-safe token operations with RLock - Token validation with TTL remaining calculation - Token revocation support Token Models: - TokenScope: Resource and action-based permissions - Token: Ephemeral token with expiry tracking - Pydantic schemas for API requests/responses API Endpoints (/tokens): - POST /issue: Issue new ephemeral token - POST /validate: Validate token and check expiry - DELETE /{token_id}: Revoke token - GET /stats: Active token statistics - POST /cleanup: Manual expired token cleanup Features: - Automatic background cleanup of expired tokens (60s interval) - TTL validation (1-3600 seconds) - Unique token IDs using cryptographically secure random generation - Token metadata support for custom attributes - Comprehensive test suite (50 tests, 100% pass rate) Testing: - 33 service/store tests (issuance, validation, expiry, concurrency) - 17 API endpoint tests (full lifecycle workflows) - Total: 92 tests passed (including Window 17 crypto tests) Dependencies: - Add pydantic-settings==2.2.1 - Add httpx==0.27.0 for test client - Update config.py to use pydantic-settings --- app/api/token_router.py | 136 +++++++ app/config.py | 3 +- app/main.py | 2 + app/services/crypto/__init__.py | 23 +- app/services/crypto/token_models.py | 105 +++++ app/services/crypto/token_service.py | 243 ++++++++++++ app/services/crypto/token_store.py | 174 ++++++++ requirements.txt | 2 + tests/test_token_api.py | 418 +++++++++++++++++++ tests/test_token_service.py | 573 +++++++++++++++++++++++++++ 10 files changed, 1677 insertions(+), 2 deletions(-) create mode 100644 app/api/token_router.py create mode 100644 app/services/crypto/token_models.py create mode 100644 app/services/crypto/token_service.py create mode 100644 app/services/crypto/token_store.py create mode 100644 tests/test_token_api.py create mode 100644 tests/test_token_service.py diff --git a/app/api/token_router.py b/app/api/token_router.py new file mode 100644 index 0000000..53ab978 --- /dev/null +++ b/app/api/token_router.py @@ -0,0 +1,136 @@ +""" +API Router for SecretGateway Token Management +""" + +from fastapi import APIRouter, HTTPException, status +from typing import Dict +import logging + +from app.services.crypto import ( + TokenService, + TokenIssuanceRequest, + TokenIssuanceResponse, + TokenValidationRequest, + TokenValidationResponse +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/tokens", tags=["tokens"]) + +# Global token service instance +_token_service: TokenService = TokenService() + + +def get_token_service() -> TokenService: + """Get the global token service instance""" + return _token_service + + +@router.post( + "/issue", + response_model=TokenIssuanceResponse, + status_code=status.HTTP_201_CREATED, + summary="Issue ephemeral token", + description="Issue a short-lived token with specified scope and TTL" +) +async def issue_token(request: TokenIssuanceRequest) -> TokenIssuanceResponse: + """ + Issue a new ephemeral token + + - **scope**: Token scope with resource and actions + - **ttl_seconds**: Time-to-live (1-3600 seconds) + - **metadata**: Optional metadata dictionary + + Returns the token ID and expiration details. + """ + try: + response = _token_service.issue_token_from_request(request) + logger.info(f"Token issued: {response.token_id[:8]}... for {request.scope.resource}") + return response + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) + ) + except Exception as e: + logger.error(f"Token issuance failed: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to issue token" + ) + + +@router.post( + "/validate", + response_model=TokenValidationResponse, + summary="Validate token", + description="Validate a token and check if it's still valid" +) +async def validate_token(request: TokenValidationRequest) -> TokenValidationResponse: + """ + Validate a token + + - **token_id**: Token identifier to validate + + Returns validation result with token details if valid. + """ + response = _token_service.validate_token(request.token_id) + return response + + +@router.delete( + "/{token_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Revoke token", + description="Revoke (delete) a token before it expires" +) +async def revoke_token(token_id: str) -> None: + """ + Revoke a token + + - **token_id**: Token identifier to revoke + + Returns 204 No Content on success, 404 if token not found. + """ + revoked = _token_service.revoke_token(token_id) + if not revoked: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Token not found" + ) + + +@router.get( + "/stats", + response_model=Dict[str, int], + summary="Get token statistics", + description="Get statistics about active tokens" +) +async def get_token_stats() -> Dict[str, int]: + """ + Get token statistics + + Returns count of active tokens. + """ + return { + "active_tokens": _token_service.get_active_token_count() + } + + +@router.post( + "/cleanup", + response_model=Dict[str, int], + summary="Cleanup expired tokens", + description="Manually trigger cleanup of expired tokens" +) +async def cleanup_expired_tokens() -> Dict[str, int]: + """ + Manually trigger cleanup of expired tokens + + Returns count of tokens removed. + """ + count = _token_service.cleanup_expired_tokens() + return { + "removed_tokens": count + } diff --git a/app/config.py b/app/config.py index 5b0b955..b43f58a 100644 --- a/app/config.py +++ b/app/config.py @@ -1,4 +1,5 @@ -from pydantic import BaseSettings, AnyUrl +from pydantic import AnyUrl +from pydantic_settings import BaseSettings from functools import lru_cache class Settings(BaseSettings): diff --git a/app/main.py b/app/main.py index d151881..85c387b 100644 --- a/app/main.py +++ b/app/main.py @@ -2,12 +2,14 @@ from .config import get_settings from .logging import setup_logging from .api.router import router +from .api.token_router import router as token_router settings = get_settings() setup_logging(settings.LOG_JSON) app = FastAPI(title="AgentAddon EventBridge", version="0.1.0") app.include_router(router) +app.include_router(token_router) @app.get("/health") async def health(): diff --git a/app/services/crypto/__init__.py b/app/services/crypto/__init__.py index 7c2c912..1e058ee 100644 --- a/app/services/crypto/__init__.py +++ b/app/services/crypto/__init__.py @@ -5,8 +5,29 @@ - Secret generation - Symmetric encryption/decryption - Secure hashing with salt +- Ephemeral token issuance and validation """ from .crypto_service import CryptoService +from .token_service import TokenService +from .token_store import InMemoryTokenStore +from .token_models import ( + Token, + TokenScope, + TokenIssuanceRequest, + TokenIssuanceResponse, + TokenValidationRequest, + TokenValidationResponse +) -__all__ = ["CryptoService"] +__all__ = [ + "CryptoService", + "TokenService", + "InMemoryTokenStore", + "Token", + "TokenScope", + "TokenIssuanceRequest", + "TokenIssuanceResponse", + "TokenValidationRequest", + "TokenValidationResponse", +] diff --git a/app/services/crypto/token_models.py b/app/services/crypto/token_models.py new file mode 100644 index 0000000..bf64538 --- /dev/null +++ b/app/services/crypto/token_models.py @@ -0,0 +1,105 @@ +""" +Token models and schemas for ephemeral token management +""" + +from datetime import datetime, timedelta +from typing import Optional +from pydantic import BaseModel, Field, field_validator + + +class TokenScope(BaseModel): + """ + Scope definition for token permissions + """ + resource: str = Field(..., description="Resource identifier") + actions: list[str] = Field(default_factory=list, description="Allowed actions") + metadata: dict = Field(default_factory=dict, description="Additional scope metadata") + + +class TokenIssuanceRequest(BaseModel): + """ + Request to issue a new ephemeral token + """ + scope: TokenScope + ttl_seconds: int = Field( + default=300, + ge=1, + le=3600, + description="Time-to-live in seconds (1s to 1h)" + ) + metadata: dict = Field(default_factory=dict, description="Optional metadata") + + @field_validator('ttl_seconds') + @classmethod + def validate_ttl(cls, v: int) -> int: + if v < 1: + raise ValueError("TTL must be at least 1 second") + if v > 3600: + raise ValueError("TTL cannot exceed 3600 seconds (1 hour)") + return v + + +class Token(BaseModel): + """ + Ephemeral token representation + """ + token_id: str = Field(..., description="Unique token identifier") + scope: TokenScope + created_at: datetime = Field(default_factory=datetime.utcnow) + expires_at: datetime + metadata: dict = Field(default_factory=dict) + + @property + def is_expired(self) -> bool: + """Check if token has expired""" + return datetime.utcnow() >= self.expires_at + + @property + def ttl_remaining(self) -> float: + """Get remaining time-to-live in seconds""" + if self.is_expired: + return 0.0 + delta = self.expires_at - datetime.utcnow() + return delta.total_seconds() + + def to_dict(self) -> dict: + """Convert to dictionary""" + return { + "token_id": self.token_id, + "scope": self.scope.model_dump(), + "created_at": self.created_at.isoformat(), + "expires_at": self.expires_at.isoformat(), + "metadata": self.metadata, + "is_expired": self.is_expired, + "ttl_remaining": self.ttl_remaining + } + + +class TokenIssuanceResponse(BaseModel): + """ + Response from token issuance + """ + token_id: str + scope: TokenScope + expires_at: datetime + ttl_seconds: int + message: str = "Token issued successfully" + + +class TokenValidationRequest(BaseModel): + """ + Request to validate a token + """ + token_id: str = Field(..., min_length=1) + + +class TokenValidationResponse(BaseModel): + """ + Response from token validation + """ + valid: bool + token_id: Optional[str] = None + scope: Optional[TokenScope] = None + expires_at: Optional[datetime] = None + ttl_remaining: Optional[float] = None + reason: Optional[str] = None diff --git a/app/services/crypto/token_service.py b/app/services/crypto/token_service.py new file mode 100644 index 0000000..3a0f871 --- /dev/null +++ b/app/services/crypto/token_service.py @@ -0,0 +1,243 @@ +""" +TokenService for ephemeral token issuance and validation +""" + +import logging +from datetime import datetime, timedelta +from typing import Optional + +from .crypto_service import CryptoService +from .token_store import InMemoryTokenStore +from .token_models import ( + Token, + TokenScope, + TokenIssuanceRequest, + TokenIssuanceResponse, + TokenValidationResponse +) + +logger = logging.getLogger(__name__) + + +class TokenServiceError(Exception): + """Base exception for token service errors""" + pass + + +class TokenIssuanceError(TokenServiceError): + """Raised when token issuance fails""" + pass + + +class TokenValidationError(TokenServiceError): + """Raised when token validation fails""" + pass + + +class TokenService: + """ + Service for managing ephemeral tokens + + Provides: + - Token issuance with configurable TTL + - Token validation with expiry checking + - Automatic cleanup of expired tokens + """ + + def __init__( + self, + crypto_service: Optional[CryptoService] = None, + token_store: Optional[InMemoryTokenStore] = None, + cleanup_interval: int = 60 + ): + """ + Initialize TokenService + + Args: + crypto_service: CryptoService instance for token ID generation + token_store: Token store instance (creates new if not provided) + cleanup_interval: Cleanup interval in seconds for expired tokens + """ + self._crypto = crypto_service or CryptoService() + self._store = token_store or InMemoryTokenStore(cleanup_interval_seconds=cleanup_interval) + + def issue_token( + self, + scope: TokenScope, + ttl_seconds: int = 300, + metadata: Optional[dict] = None + ) -> TokenIssuanceResponse: + """ + Issue a new ephemeral token + + Args: + scope: Token scope defining permissions + ttl_seconds: Time-to-live in seconds (default: 300, max: 3600) + metadata: Optional metadata to attach to token + + Returns: + TokenIssuanceResponse with token details + + Raises: + TokenIssuanceError: If token issuance fails + """ + try: + # Validate TTL + if ttl_seconds < 1: + raise ValueError("TTL must be at least 1 second") + if ttl_seconds > 3600: + raise ValueError("TTL cannot exceed 3600 seconds") + + # Generate secure token ID + token_id = self._crypto.generate_secret_urlsafe(32) + + # Calculate expiration + created_at = datetime.utcnow() + expires_at = created_at + timedelta(seconds=ttl_seconds) + + # Create token + token = Token( + token_id=token_id, + scope=scope, + created_at=created_at, + expires_at=expires_at, + metadata=metadata or {} + ) + + # Store token + self._store.store(token) + + logger.info( + f"Issued token {token_id[:8]}... for resource '{scope.resource}' " + f"with TTL {ttl_seconds}s" + ) + + return TokenIssuanceResponse( + token_id=token_id, + scope=scope, + expires_at=expires_at, + ttl_seconds=ttl_seconds + ) + + except ValueError as e: + raise TokenIssuanceError(f"Invalid token parameters: {e}") + except Exception as e: + logger.error(f"Failed to issue token: {e}") + raise TokenIssuanceError(f"Token issuance failed: {e}") + + def issue_token_from_request( + self, + request: TokenIssuanceRequest + ) -> TokenIssuanceResponse: + """ + Issue token from a TokenIssuanceRequest + + Args: + request: Token issuance request + + Returns: + TokenIssuanceResponse + """ + return self.issue_token( + scope=request.scope, + ttl_seconds=request.ttl_seconds, + metadata=request.metadata + ) + + def validate_token(self, token_id: str) -> TokenValidationResponse: + """ + Validate a token + + Args: + token_id: Token identifier to validate + + Returns: + TokenValidationResponse with validation result + """ + try: + token = self._store.get(token_id) + + if token is None: + logger.debug(f"Token validation failed: {token_id[:8]}... not found or expired") + return TokenValidationResponse( + valid=False, + token_id=token_id, + reason="Token not found or has expired" + ) + + # Token exists and is not expired (get() already checks expiry) + logger.debug(f"Token {token_id[:8]}... validated successfully") + return TokenValidationResponse( + valid=True, + token_id=token.token_id, + scope=token.scope, + expires_at=token.expires_at, + ttl_remaining=token.ttl_remaining + ) + + except Exception as e: + logger.error(f"Token validation error: {e}") + return TokenValidationResponse( + valid=False, + token_id=token_id, + reason=f"Validation error: {e}" + ) + + def revoke_token(self, token_id: str) -> bool: + """ + Revoke a token (remove from store) + + Args: + token_id: Token identifier to revoke + + Returns: + True if token was revoked, False if not found + """ + revoked = self._store.remove(token_id) + if revoked: + logger.info(f"Revoked token {token_id[:8]}...") + else: + logger.debug(f"Token {token_id[:8]}... not found for revocation") + return revoked + + def get_token_info(self, token_id: str) -> Optional[Token]: + """ + Get token information + + Args: + token_id: Token identifier + + Returns: + Token if found and valid, None otherwise + """ + return self._store.get(token_id) + + def cleanup_expired_tokens(self) -> int: + """ + Manually trigger cleanup of expired tokens + + Returns: + Number of tokens removed + """ + count = self._store.cleanup_expired() + logger.info(f"Manual cleanup removed {count} expired tokens") + return count + + def get_active_token_count(self) -> int: + """ + Get count of active (non-expired) tokens + + Returns: + Number of active tokens + """ + return self._store.count() + + def clear_all_tokens(self) -> None: + """Clear all tokens from the store""" + self._store.clear() + logger.warning("Cleared all tokens from store") + + def shutdown(self) -> None: + """Shutdown the token service""" + self._store.shutdown() + logger.info("Token service shutdown") diff --git a/app/services/crypto/token_store.py b/app/services/crypto/token_store.py new file mode 100644 index 0000000..ebbd5d6 --- /dev/null +++ b/app/services/crypto/token_store.py @@ -0,0 +1,174 @@ +""" +In-memory token store with automatic expiry and cleanup +""" + +import threading +from datetime import datetime, timedelta +from typing import Optional, Dict +import logging + +from .token_models import Token + +logger = logging.getLogger(__name__) + + +class InMemoryTokenStore: + """ + Thread-safe in-memory store for ephemeral tokens with automatic cleanup + """ + + def __init__(self, cleanup_interval_seconds: int = 60): + """ + Initialize token store + + Args: + cleanup_interval_seconds: Interval for automatic cleanup of expired tokens + """ + self._store: Dict[str, Token] = {} + self._lock = threading.RLock() + self._cleanup_interval = cleanup_interval_seconds + self._cleanup_timer: Optional[threading.Timer] = None + self._shutdown = False + + # Start automatic cleanup + self._schedule_cleanup() + + def store(self, token: Token) -> None: + """ + Store a token + + Args: + token: Token to store + """ + with self._lock: + self._store[token.token_id] = token + logger.debug(f"Stored token {token.token_id}, expires at {token.expires_at}") + + def get(self, token_id: str) -> Optional[Token]: + """ + Retrieve a token by ID + + Args: + token_id: Token identifier + + Returns: + Token if found and not expired, None otherwise + """ + with self._lock: + token = self._store.get(token_id) + + if token is None: + logger.debug(f"Token {token_id} not found") + return None + + if token.is_expired: + logger.debug(f"Token {token_id} has expired") + # Remove expired token + del self._store[token_id] + return None + + return token + + def remove(self, token_id: str) -> bool: + """ + Remove a token from the store + + Args: + token_id: Token identifier + + Returns: + True if token was removed, False if not found + """ + with self._lock: + if token_id in self._store: + del self._store[token_id] + logger.debug(f"Removed token {token_id}") + return True + return False + + def cleanup_expired(self) -> int: + """ + Remove all expired tokens from the store + + Returns: + Number of tokens removed + """ + with self._lock: + now = datetime.utcnow() + expired_ids = [ + token_id for token_id, token in self._store.items() + if token.expires_at <= now + ] + + for token_id in expired_ids: + del self._store[token_id] + + if expired_ids: + logger.info(f"Cleaned up {len(expired_ids)} expired tokens") + + return len(expired_ids) + + def count(self) -> int: + """ + Get count of active (non-expired) tokens + + Returns: + Number of active tokens + """ + with self._lock: + # Clean up expired tokens first + self.cleanup_expired() + return len(self._store) + + def count_all(self) -> int: + """ + Get count of all tokens (including expired) + + Returns: + Total number of tokens in store + """ + with self._lock: + return len(self._store) + + def clear(self) -> None: + """Clear all tokens from the store""" + with self._lock: + self._store.clear() + logger.info("Cleared all tokens from store") + + def _schedule_cleanup(self) -> None: + """Schedule the next automatic cleanup""" + if self._shutdown: + return + + self._cleanup_timer = threading.Timer( + self._cleanup_interval, + self._run_cleanup + ) + self._cleanup_timer.daemon = True + self._cleanup_timer.start() + + def _run_cleanup(self) -> None: + """Run cleanup and schedule next one""" + try: + self.cleanup_expired() + except Exception as e: + logger.error(f"Error during token cleanup: {e}") + finally: + if not self._shutdown: + self._schedule_cleanup() + + def shutdown(self) -> None: + """Shutdown the store and cancel cleanup timer""" + self._shutdown = True + if self._cleanup_timer: + self._cleanup_timer.cancel() + logger.info("Token store shutdown") + + def __len__(self) -> int: + """Get count of active tokens""" + return self.count() + + def __contains__(self, token_id: str) -> bool: + """Check if token exists and is not expired""" + return self.get(token_id) is not None diff --git a/requirements.txt b/requirements.txt index 48e93aa..445ef88 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,11 @@ fastapi==0.115.0 uvicorn[standard]==0.30.0 pydantic==2.6.4 +pydantic-settings==2.2.1 redis==5.0.1 structlog==23.2.0 python-dotenv==1.0.1 orjson==3.10.7 pytest==8.1.0 cryptography==42.0.5 +httpx==0.27.0 diff --git a/tests/test_token_api.py b/tests/test_token_api.py new file mode 100644 index 0000000..e75e884 --- /dev/null +++ b/tests/test_token_api.py @@ -0,0 +1,418 @@ +""" +API endpoint tests for token management +""" + +import pytest +import time +from fastapi.testclient import TestClient + +from app.main import app +from app.api.token_router import get_token_service + + +@pytest.fixture +def client(): + """Create test client""" + return TestClient(app) + + +@pytest.fixture(autouse=True) +def clear_tokens(): + """Clear all tokens before each test""" + service = get_token_service() + service.clear_all_tokens() + yield + service.clear_all_tokens() + + +class TestTokenIssuanceAPI: + """Test token issuance API endpoints""" + + def test_issue_token_success(self, client): + """Test successful token issuance""" + response = client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/users", + "actions": ["read", "write"], + "metadata": {} + }, + "ttl_seconds": 300, + "metadata": {"client": "test"} + } + ) + + assert response.status_code == 201 + data = response.json() + + assert "token_id" in data + assert len(data["token_id"]) > 0 + assert data["scope"]["resource"] == "api/users" + assert data["scope"]["actions"] == ["read", "write"] + assert data["ttl_seconds"] == 300 + assert "expires_at" in data + + def test_issue_token_minimal_request(self, client): + """Test token issuance with minimal request""" + response = client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/test", + "actions": ["read"] + } + } + ) + + assert response.status_code == 201 + data = response.json() + assert data["ttl_seconds"] == 300 # Default TTL + + def test_issue_token_custom_ttl(self, client): + """Test token issuance with custom TTL""" + ttls = [1, 60, 300, 600, 3600] + + for ttl in ttls: + response = client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/test", + "actions": ["read"] + }, + "ttl_seconds": ttl + } + ) + + assert response.status_code == 201 + assert response.json()["ttl_seconds"] == ttl + + def test_issue_token_invalid_ttl_too_low(self, client): + """Test token issuance with TTL too low""" + response = client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/test", + "actions": ["read"] + }, + "ttl_seconds": 0 + } + ) + + assert response.status_code == 422 # Validation error + + def test_issue_token_invalid_ttl_too_high(self, client): + """Test token issuance with TTL too high""" + response = client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/test", + "actions": ["read"] + }, + "ttl_seconds": 5000 + } + ) + + assert response.status_code == 422 # Validation error + + def test_issue_token_missing_scope(self, client): + """Test token issuance without scope""" + response = client.post( + "/tokens/issue", + json={ + "ttl_seconds": 300 + } + ) + + assert response.status_code == 422 # Validation error + + +class TestTokenValidationAPI: + """Test token validation API endpoints""" + + def test_validate_valid_token(self, client): + """Test validation of valid token""" + # Issue token + issue_response = client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/test", + "actions": ["read"] + }, + "ttl_seconds": 60 + } + ) + token_id = issue_response.json()["token_id"] + + # Validate token + validate_response = client.post( + "/tokens/validate", + json={"token_id": token_id} + ) + + assert validate_response.status_code == 200 + data = validate_response.json() + + assert data["valid"] is True + assert data["token_id"] == token_id + assert data["scope"]["resource"] == "api/test" + assert data["ttl_remaining"] > 0 + + def test_validate_nonexistent_token(self, client): + """Test validation of non-existent token""" + response = client.post( + "/tokens/validate", + json={"token_id": "nonexistent-token-id"} + ) + + assert response.status_code == 200 + data = response.json() + + assert data["valid"] is False + assert data["reason"] is not None + + def test_validate_expired_token(self, client): + """Test validation of expired token""" + # Issue token with 1 second TTL + issue_response = client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/test", + "actions": ["read"] + }, + "ttl_seconds": 1 + } + ) + token_id = issue_response.json()["token_id"] + + # Wait for expiry + time.sleep(2) + + # Validate expired token + validate_response = client.post( + "/tokens/validate", + json={"token_id": token_id} + ) + + assert validate_response.status_code == 200 + data = validate_response.json() + + assert data["valid"] is False + + +class TestTokenRevocationAPI: + """Test token revocation API endpoints""" + + def test_revoke_valid_token(self, client): + """Test revoking a valid token""" + # Issue token + issue_response = client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/test", + "actions": ["read"] + }, + "ttl_seconds": 300 + } + ) + token_id = issue_response.json()["token_id"] + + # Revoke token + revoke_response = client.delete(f"/tokens/{token_id}") + + assert revoke_response.status_code == 204 + + # Verify token is invalid + validate_response = client.post( + "/tokens/validate", + json={"token_id": token_id} + ) + + assert validate_response.json()["valid"] is False + + def test_revoke_nonexistent_token(self, client): + """Test revoking non-existent token""" + response = client.delete("/tokens/nonexistent-token") + + assert response.status_code == 404 + + def test_revoke_already_revoked_token(self, client): + """Test revoking an already revoked token""" + # Issue token + issue_response = client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/test", + "actions": ["read"] + }, + "ttl_seconds": 300 + } + ) + token_id = issue_response.json()["token_id"] + + # Revoke token + client.delete(f"/tokens/{token_id}") + + # Try to revoke again + revoke_again_response = client.delete(f"/tokens/{token_id}") + + assert revoke_again_response.status_code == 404 + + +class TestTokenStatisticsAPI: + """Test token statistics API endpoints""" + + def test_get_stats_no_tokens(self, client): + """Test getting stats with no active tokens""" + response = client.get("/tokens/stats") + + assert response.status_code == 200 + data = response.json() + + assert "active_tokens" in data + assert data["active_tokens"] == 0 + + def test_get_stats_with_tokens(self, client): + """Test getting stats with active tokens""" + # Issue 3 tokens + for _ in range(3): + client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/test", + "actions": ["read"] + }, + "ttl_seconds": 300 + } + ) + + response = client.get("/tokens/stats") + + assert response.status_code == 200 + data = response.json() + + assert data["active_tokens"] == 3 + + def test_cleanup_endpoint(self, client): + """Test cleanup endpoint""" + # Issue short-lived tokens + for _ in range(3): + client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/test", + "actions": ["read"] + }, + "ttl_seconds": 1 + } + ) + + # Wait for expiry + time.sleep(2) + + # Trigger cleanup + response = client.post("/tokens/cleanup") + + assert response.status_code == 200 + data = response.json() + + assert "removed_tokens" in data + assert data["removed_tokens"] == 3 + + +class TestTokenWorkflow: + """Test complete token workflows""" + + def test_complete_token_lifecycle(self, client): + """Test complete token lifecycle: issue, validate, revoke""" + # 1. Issue token + issue_response = client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/users", + "actions": ["read", "write", "delete"], + "metadata": {"region": "us-west"} + }, + "ttl_seconds": 600, + "metadata": {"user_id": "123"} + } + ) + + assert issue_response.status_code == 201 + token_id = issue_response.json()["token_id"] + + # 2. Validate token (should be valid) + validate_response = client.post( + "/tokens/validate", + json={"token_id": token_id} + ) + + assert validate_response.status_code == 200 + assert validate_response.json()["valid"] is True + + # 3. Check stats + stats_response = client.get("/tokens/stats") + assert stats_response.json()["active_tokens"] == 1 + + # 4. Revoke token + revoke_response = client.delete(f"/tokens/{token_id}") + assert revoke_response.status_code == 204 + + # 5. Validate token again (should be invalid) + validate_again = client.post( + "/tokens/validate", + json={"token_id": token_id} + ) + assert validate_again.json()["valid"] is False + + # 6. Check stats (should be 0) + stats_final = client.get("/tokens/stats") + assert stats_final.json()["active_tokens"] == 0 + + def test_multiple_tokens_independent(self, client): + """Test that multiple tokens are independent""" + # Issue two tokens + token1_response = client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/resource1", + "actions": ["read"] + }, + "ttl_seconds": 300 + } + ) + token1_id = token1_response.json()["token_id"] + + token2_response = client.post( + "/tokens/issue", + json={ + "scope": { + "resource": "api/resource2", + "actions": ["write"] + }, + "ttl_seconds": 300 + } + ) + token2_id = token2_response.json()["token_id"] + + # Both should be valid + assert client.post("/tokens/validate", json={"token_id": token1_id}).json()["valid"] + assert client.post("/tokens/validate", json={"token_id": token2_id}).json()["valid"] + + # Revoke token1 + client.delete(f"/tokens/{token1_id}") + + # Token1 should be invalid, token2 should still be valid + assert not client.post("/tokens/validate", json={"token_id": token1_id}).json()["valid"] + assert client.post("/tokens/validate", json={"token_id": token2_id}).json()["valid"] diff --git a/tests/test_token_service.py b/tests/test_token_service.py new file mode 100644 index 0000000..7724fa5 --- /dev/null +++ b/tests/test_token_service.py @@ -0,0 +1,573 @@ +""" +Comprehensive tests for Token Service + +Tests cover: +- Token issuance with various TTLs +- Token validation (valid and invalid cases) +- Token expiry mechanisms +- In-memory store operations +- Cleanup of expired tokens +- API endpoints +""" + +import pytest +import time +from datetime import datetime, timedelta + +from app.services.crypto import ( + TokenService, + TokenScope, + TokenIssuanceRequest, + InMemoryTokenStore, + Token +) +from app.services.crypto.token_service import TokenIssuanceError + + +class TestTokenIssuance: + """Test token issuance functionality""" + + def test_issue_token_basic(self): + """Test basic token issuance""" + service = TokenService() + scope = TokenScope(resource="api/users", actions=["read", "write"]) + + response = service.issue_token(scope=scope, ttl_seconds=300) + + assert response.token_id is not None + assert len(response.token_id) > 0 + assert response.scope == scope + assert response.ttl_seconds == 300 + assert response.expires_at > datetime.utcnow() + + def test_issue_token_with_metadata(self): + """Test token issuance with metadata""" + service = TokenService() + scope = TokenScope(resource="api/data", actions=["read"]) + metadata = {"user_id": "123", "session": "abc"} + + response = service.issue_token(scope=scope, ttl_seconds=60, metadata=metadata) + + # Verify token was stored with metadata + token = service.get_token_info(response.token_id) + assert token is not None + assert token.metadata == metadata + + def test_issue_token_custom_ttl(self): + """Test token issuance with various TTL values""" + service = TokenService() + scope = TokenScope(resource="test", actions=["test"]) + + ttls = [1, 60, 300, 600, 3600] + for ttl in ttls: + response = service.issue_token(scope=scope, ttl_seconds=ttl) + assert response.ttl_seconds == ttl + + def test_issue_token_invalid_ttl_too_low(self): + """Test that TTL below 1 second raises error""" + service = TokenService() + scope = TokenScope(resource="test", actions=["test"]) + + with pytest.raises(TokenIssuanceError, match="at least 1 second"): + service.issue_token(scope=scope, ttl_seconds=0) + + with pytest.raises(TokenIssuanceError): + service.issue_token(scope=scope, ttl_seconds=-1) + + def test_issue_token_invalid_ttl_too_high(self): + """Test that TTL above 3600 seconds raises error""" + service = TokenService() + scope = TokenScope(resource="test", actions=["test"]) + + with pytest.raises(TokenIssuanceError, match="cannot exceed 3600"): + service.issue_token(scope=scope, ttl_seconds=3601) + + def test_issue_token_from_request(self): + """Test token issuance from request object""" + service = TokenService() + request = TokenIssuanceRequest( + scope=TokenScope(resource="api/posts", actions=["create"]), + ttl_seconds=120, + metadata={"client": "web"} + ) + + response = service.issue_token_from_request(request) + + assert response.token_id is not None + assert response.scope.resource == "api/posts" + assert response.ttl_seconds == 120 + + def test_issue_multiple_tokens_unique_ids(self): + """Test that multiple tokens have unique IDs""" + service = TokenService() + scope = TokenScope(resource="test", actions=["test"]) + + tokens = [service.issue_token(scope=scope) for _ in range(10)] + token_ids = [t.token_id for t in tokens] + + # All token IDs should be unique + assert len(set(token_ids)) == len(token_ids) + + +class TestTokenValidation: + """Test token validation functionality""" + + def test_validate_valid_token(self): + """Test validation of a valid token""" + service = TokenService() + scope = TokenScope(resource="api/test", actions=["read"]) + + # Issue token + issued = service.issue_token(scope=scope, ttl_seconds=60) + + # Validate token + validation = service.validate_token(issued.token_id) + + assert validation.valid is True + assert validation.token_id == issued.token_id + assert validation.scope == scope + assert validation.ttl_remaining > 0 + assert validation.reason is None + + def test_validate_nonexistent_token(self): + """Test validation of non-existent token""" + service = TokenService() + + validation = service.validate_token("nonexistent-token-id") + + assert validation.valid is False + assert validation.reason is not None + assert "not found" in validation.reason.lower() + + def test_validate_expired_token(self): + """Test validation of expired token""" + service = TokenService() + scope = TokenScope(resource="api/test", actions=["read"]) + + # Issue token with 1 second TTL + issued = service.issue_token(scope=scope, ttl_seconds=1) + + # Wait for token to expire + time.sleep(2) + + # Validate expired token + validation = service.validate_token(issued.token_id) + + assert validation.valid is False + assert "expired" in validation.reason.lower() or "not found" in validation.reason.lower() + + def test_token_ttl_remaining(self): + """Test TTL remaining calculation""" + service = TokenService() + scope = TokenScope(resource="api/test", actions=["read"]) + + # Issue token with 10 second TTL + issued = service.issue_token(scope=scope, ttl_seconds=10) + + # Immediately validate + validation = service.validate_token(issued.token_id) + + assert validation.valid is True + assert validation.ttl_remaining is not None + # Should have close to 10 seconds remaining (allow some margin) + assert 9.5 <= validation.ttl_remaining <= 10.0 + + # Wait 2 seconds + time.sleep(2) + + # Validate again + validation2 = service.validate_token(issued.token_id) + assert validation2.valid is True + # Should have ~8 seconds remaining + assert 7.5 <= validation2.ttl_remaining <= 8.5 + + +class TestTokenRevocation: + """Test token revocation functionality""" + + def test_revoke_valid_token(self): + """Test revoking a valid token""" + service = TokenService() + scope = TokenScope(resource="api/test", actions=["read"]) + + # Issue token + issued = service.issue_token(scope=scope, ttl_seconds=300) + + # Verify token is valid + validation = service.validate_token(issued.token_id) + assert validation.valid is True + + # Revoke token + revoked = service.revoke_token(issued.token_id) + assert revoked is True + + # Verify token is no longer valid + validation2 = service.validate_token(issued.token_id) + assert validation2.valid is False + + def test_revoke_nonexistent_token(self): + """Test revoking a non-existent token""" + service = TokenService() + + revoked = service.revoke_token("nonexistent-token") + assert revoked is False + + def test_revoke_already_revoked_token(self): + """Test revoking an already revoked token""" + service = TokenService() + scope = TokenScope(resource="api/test", actions=["read"]) + + # Issue and revoke token + issued = service.issue_token(scope=scope, ttl_seconds=300) + service.revoke_token(issued.token_id) + + # Try to revoke again + revoked_again = service.revoke_token(issued.token_id) + assert revoked_again is False + + +class TestInMemoryTokenStore: + """Test in-memory token store""" + + def test_store_and_retrieve_token(self): + """Test storing and retrieving a token""" + store = InMemoryTokenStore() + token = Token( + token_id="test-token-123", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() + timedelta(seconds=60) + ) + + store.store(token) + retrieved = store.get("test-token-123") + + assert retrieved is not None + assert retrieved.token_id == "test-token-123" + + def test_get_nonexistent_token(self): + """Test retrieving non-existent token""" + store = InMemoryTokenStore() + + retrieved = store.get("nonexistent") + assert retrieved is None + + def test_get_expired_token(self): + """Test retrieving expired token returns None""" + store = InMemoryTokenStore() + token = Token( + token_id="expired-token", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() - timedelta(seconds=1) # Already expired + ) + + store.store(token) + retrieved = store.get("expired-token") + + # Should return None for expired token and remove it + assert retrieved is None + assert "expired-token" not in store + + def test_remove_token(self): + """Test removing a token""" + store = InMemoryTokenStore() + token = Token( + token_id="test-token", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() + timedelta(seconds=60) + ) + + store.store(token) + assert store.get("test-token") is not None + + removed = store.remove("test-token") + assert removed is True + assert store.get("test-token") is None + + def test_cleanup_expired_tokens(self): + """Test automatic cleanup of expired tokens""" + store = InMemoryTokenStore(cleanup_interval_seconds=999999) # Disable auto cleanup + + # Add valid token + valid_token = Token( + token_id="valid", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() + timedelta(seconds=60) + ) + store.store(valid_token) + + # Add expired tokens + for i in range(5): + expired = Token( + token_id=f"expired-{i}", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() - timedelta(seconds=1) + ) + store.store(expired) + + # Should have 6 tokens total (1 valid + 5 expired) + assert store.count_all() == 6 + + # Run cleanup + removed = store.cleanup_expired() + + # Should have removed 5 expired tokens + assert removed == 5 + assert store.count() == 1 + assert store.get("valid") is not None + + def test_count_tokens(self): + """Test counting tokens""" + store = InMemoryTokenStore(cleanup_interval_seconds=999999) + + assert store.count() == 0 + + # Add tokens + for i in range(3): + token = Token( + token_id=f"token-{i}", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() + timedelta(seconds=60) + ) + store.store(token) + + assert store.count() == 3 + assert len(store) == 3 + + def test_clear_all_tokens(self): + """Test clearing all tokens""" + store = InMemoryTokenStore() + + # Add tokens + for i in range(5): + token = Token( + token_id=f"token-{i}", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() + timedelta(seconds=60) + ) + store.store(token) + + assert store.count() == 5 + + store.clear() + assert store.count() == 0 + + def test_contains_operator(self): + """Test 'in' operator for token store""" + store = InMemoryTokenStore() + token = Token( + token_id="test-token", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() + timedelta(seconds=60) + ) + + store.store(token) + + assert "test-token" in store + assert "nonexistent" not in store + + def test_shutdown(self): + """Test store shutdown""" + store = InMemoryTokenStore() + store.shutdown() + # Should not raise error + + +class TestTokenServiceStatistics: + """Test token service statistics and utility methods""" + + def test_get_active_token_count(self): + """Test getting count of active tokens""" + service = TokenService() + scope = TokenScope(resource="test", actions=["read"]) + + assert service.get_active_token_count() == 0 + + # Issue tokens + for _ in range(3): + service.issue_token(scope=scope, ttl_seconds=60) + + assert service.get_active_token_count() == 3 + + def test_cleanup_expired_tokens(self): + """Test manual cleanup of expired tokens""" + service = TokenService() + scope = TokenScope(resource="test", actions=["read"]) + + # Issue short-lived tokens + for _ in range(3): + service.issue_token(scope=scope, ttl_seconds=1) + + # Wait for expiry + time.sleep(2) + + # Manual cleanup + removed = service.cleanup_expired_tokens() + assert removed == 3 + + def test_clear_all_tokens(self): + """Test clearing all tokens""" + service = TokenService() + scope = TokenScope(resource="test", actions=["read"]) + + # Issue tokens + for _ in range(5): + service.issue_token(scope=scope, ttl_seconds=300) + + assert service.get_active_token_count() == 5 + + service.clear_all_tokens() + assert service.get_active_token_count() == 0 + + def test_get_token_info(self): + """Test getting token information""" + service = TokenService() + scope = TokenScope(resource="api/test", actions=["read", "write"]) + + issued = service.issue_token(scope=scope, ttl_seconds=120) + + info = service.get_token_info(issued.token_id) + assert info is not None + assert info.token_id == issued.token_id + assert info.scope.resource == "api/test" + assert info.scope.actions == ["read", "write"] + + +class TestTokenModels: + """Test token model classes""" + + def test_token_is_expired_property(self): + """Test token is_expired property""" + # Valid token + token = Token( + token_id="test", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() + timedelta(seconds=60) + ) + assert token.is_expired is False + + # Expired token + expired_token = Token( + token_id="test", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() - timedelta(seconds=1) + ) + assert expired_token.is_expired is True + + def test_token_ttl_remaining_property(self): + """Test token ttl_remaining property""" + token = Token( + token_id="test", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() + timedelta(seconds=60) + ) + + # Should have close to 60 seconds remaining + assert 59.5 <= token.ttl_remaining <= 60.0 + + # Expired token should have 0 TTL + expired_token = Token( + token_id="test", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() - timedelta(seconds=1) + ) + assert expired_token.ttl_remaining == 0.0 + + def test_token_to_dict(self): + """Test token to_dict method""" + token = Token( + token_id="test-123", + scope=TokenScope( + resource="api/users", + actions=["read"], + metadata={"region": "us-west"} + ), + expires_at=datetime.utcnow() + timedelta(seconds=60), + metadata={"user": "alice"} + ) + + token_dict = token.to_dict() + + assert token_dict["token_id"] == "test-123" + assert token_dict["scope"]["resource"] == "api/users" + assert token_dict["scope"]["actions"] == ["read"] + assert token_dict["metadata"]["user"] == "alice" + assert "created_at" in token_dict + assert "expires_at" in token_dict + assert "is_expired" in token_dict + assert "ttl_remaining" in token_dict + + def test_token_scope_model(self): + """Test TokenScope model""" + scope = TokenScope( + resource="api/data", + actions=["read", "write", "delete"], + metadata={"owner": "admin"} + ) + + assert scope.resource == "api/data" + assert scope.actions == ["read", "write", "delete"] + assert scope.metadata["owner"] == "admin" + + def test_token_issuance_request_validation(self): + """Test TokenIssuanceRequest validation""" + # Valid request + request = TokenIssuanceRequest( + scope=TokenScope(resource="test", actions=["read"]), + ttl_seconds=300 + ) + assert request.ttl_seconds == 300 + + # Test default TTL + request_default = TokenIssuanceRequest( + scope=TokenScope(resource="test", actions=["read"]) + ) + assert request_default.ttl_seconds == 300 # Default + + # Test TTL validation + with pytest.raises(ValueError): + TokenIssuanceRequest( + scope=TokenScope(resource="test", actions=["read"]), + ttl_seconds=0 + ) + + with pytest.raises(ValueError): + TokenIssuanceRequest( + scope=TokenScope(resource="test", actions=["read"]), + ttl_seconds=5000 + ) + + +class TestConcurrency: + """Test thread-safety of token store""" + + def test_concurrent_token_operations(self): + """Test concurrent token storage and retrieval""" + import threading + + store = InMemoryTokenStore() + errors = [] + + def store_tokens(): + try: + for i in range(10): + token = Token( + token_id=f"token-{threading.current_thread().name}-{i}", + scope=TokenScope(resource="test", actions=["read"]), + expires_at=datetime.utcnow() + timedelta(seconds=60) + ) + store.store(token) + except Exception as e: + errors.append(e) + + # Run multiple threads + threads = [threading.Thread(target=store_tokens) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Should have no errors + assert len(errors) == 0 + + # Should have 50 tokens (5 threads * 10 tokens each) + assert store.count() == 50