From a294e31aac200e95f9c3a2d24252e4043612bd86 Mon Sep 17 00:00:00 2001 From: Oded BD Date: Thu, 24 Feb 2022 16:20:18 +0200 Subject: [PATCH] add token validation to enforcer and local routes --- horizon/enforcer/api.py | 6 ++- horizon/local/api.py | 5 ++- horizon/token_utils.py | 87 +++++++++++++++++++++++++++++++++++++++++ requirements.txt | 3 +- 4 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 horizon/token_utils.py diff --git a/horizon/enforcer/api.py b/horizon/enforcer/api.py index 8d67a567..fa49c5f8 100644 --- a/horizon/enforcer/api.py +++ b/horizon/enforcer/api.py @@ -1,11 +1,12 @@ import json from typing import Optional, Dict -from fastapi import APIRouter, status, Response +from fastapi import APIRouter, Depends, status, Response from opal_client.policy_store import BasePolicyStoreClient, DEFAULT_POLICY_STORE_GETTER from opal_client.policy_store.opa_client import fail_silently from opal_client.logger import logger from horizon.config import sidecar_config +from horizon.token_utils import JWTBearer from horizon.enforcer.schemas import AuthorizationQuery, AuthorizationResult @@ -66,7 +67,8 @@ def log_query_and_result(query: AuthorizationQuery, response: Response): ) - @router.post("/allowed", response_model=AuthorizationResult, status_code=status.HTTP_200_OK, response_model_exclude_none=True) + @router.post("/allowed", response_model=AuthorizationResult, status_code=status.HTTP_200_OK, response_model_exclude_none=True, + dependencies=[Depends(JWTBearer())]) async def is_allowed(query: AuthorizationQuery): async def _is_allowed(): return await policy_store.get_data_with_input(path="rbac", input=query) diff --git a/horizon/local/api.py b/horizon/local/api.py index 343cbce5..1fddde32 100644 --- a/horizon/local/api.py +++ b/horizon/local/api.py @@ -1,13 +1,14 @@ from typing import Dict, Any, List, Optional -from fastapi import APIRouter, status, HTTPException +from fastapi import APIRouter, Depends, status, HTTPException from opal_client.policy_store import BasePolicyStoreClient, DEFAULT_POLICY_STORE_GETTER from horizon.local.schemas import Message, SyncedRole, SyncedUser +from horizon.token_utils import JWTBearer def init_local_cache_api_router(policy_store:BasePolicyStoreClient=None): policy_store = policy_store or DEFAULT_POLICY_STORE_GETTER() - router = APIRouter() + router = APIRouter(dependencies=[Depends(JWTBearer())]) def error_message(msg: str): return { diff --git a/horizon/token_utils.py b/horizon/token_utils.py new file mode 100644 index 00000000..84a60621 --- /dev/null +++ b/horizon/token_utils.py @@ -0,0 +1,87 @@ +import requests +from fastapi import Request, HTTPException +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from jose import jwt, JWTError + + +AUTH0_DOMAIN = "acalla-dev.us.auth0.com" +API_AUDIENCE = f"https://api.acalla.com/v1/" +ALGORITHMS = ["RS256"] +JWK_TOKENS = requests.get(f"https://{AUTH0_DOMAIN}/.well-known/jwks.json").json() + +class JWTBearer(HTTPBearer): + def __init__(self, auto_error: bool = True): + super(JWTBearer, self).__init__(auto_error=auto_error) + async def __call__(self, request: Request): + credentials: HTTPAuthorizationCredentials = await super( + JWTBearer, self + ).__call__(request) + if credentials: + if not credentials.scheme == "Bearer": + raise HTTPException( + status_code=403, detail="Invalid authentication scheme." + ) + if not type(self)._verify_jwt(credentials.credentials): + raise HTTPException( + status_code=403, detail="Invalid token or expired token." + ) + return credentials.credentials + else: + raise HTTPException(status_code=403, detail="Invalid authorization code.") + + @classmethod + def _verify_jwt(cls, jwtoken: str) -> bool: + # is_token_valid: bool = False + + payload = cls._decode_jwt(jwtoken) + return payload + # except HTTPException: + # payload = None + # if payload: + # is_token_valid = True + # return is_token_valid + + @classmethod + def _decode_jwt(cls, jwtoken: str) -> dict: + rsa_key = cls._get_rsa_key(jwtoken) + try: + return jwt.decode( + jwtoken, + rsa_key, + algorithms=ALGORITHMS, + audience=API_AUDIENCE, + issuer=f"https://{AUTH0_DOMAIN}/", + ) + except jwt.ExpiredSignatureError: + raise HTTPException( + status_code=401, + detail="token is expired", + ) + except jwt.JWTClaimsError: + raise HTTPException( + status_code=401, + detail="incorrect claims, please check the audience and issuer", + ) + except Exception as e: + print(e) + raise HTTPException( + status_code=401, + detail="Unable to parse authentication token." + str(e), + ) + + @staticmethod + def _get_rsa_key(jwtoken: str) -> dict: + unverified_header = jwt.get_unverified_header(jwtoken) + for key in JWK_TOKENS["keys"]: + if key["kid"] == unverified_header["kid"]: + return { + "kty": key["kty"], + "kid": key["kid"], + "use": key["use"], + "n": key["n"], + "e": key["e"], + } + raise HTTPException( + status_code=401, + detail="Unable to find appropriate key", + ) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 22876e16..525cae72 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ tenacity==6.3.1 Jinja2==3.0.3 logzio-python-handler rook -ddtrace \ No newline at end of file +ddtrace +jose \ No newline at end of file