diff --git a/server/digi_server/app_server.py b/server/digi_server/app_server.py index 39867d32..3acfd61f 100644 --- a/server/digi_server/app_server.py +++ b/server/digi_server/app_server.py @@ -249,6 +249,7 @@ async def _configure_logging(self): ) def _configure_rbac(self): + self._db.register_delete_hook(self.rbac.rbac_db.check_object_deletion) self.rbac.add_mapping(User, Show, [Show.id, Show.name]) self.rbac.add_mapping(User, CueType, [CueType.id, CueType.prefix]) self.rbac.add_mapping(User, Script, [Script.id]) diff --git a/server/rbac/rbac.py b/server/rbac/rbac.py index 6d73f70a..2829ff43 100644 --- a/server/rbac/rbac.py +++ b/server/rbac/rbac.py @@ -17,6 +17,10 @@ def __init__(self, app: "DigiScriptServer"): self._rbac_db = RBACDatabase(app.get_db(), app) self._display_fields = {} + @property + def rbac_db(self): + return self._rbac_db + def add_mapping( self, actor: type, resource: type, display_fields: Optional[List] = None ) -> None: diff --git a/server/rbac/rbac_db.py b/server/rbac/rbac_db.py index 4f65363c..9fc2821c 100644 --- a/server/rbac/rbac_db.py +++ b/server/rbac/rbac_db.py @@ -1,7 +1,7 @@ import functools from collections import defaultdict from copy import deepcopy -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional from anytree import Node from sqlalchemy import Column, ForeignKey, Integer, Table, TypeDecorator, inspect @@ -13,7 +13,7 @@ from rbac.exceptions import RBACException from rbac.role import Role from utils import tree -from utils.database import DigiSQLAlchemy +from utils.database import DigiDBSession, DigiSQLAlchemy if TYPE_CHECKING: from digi_server.app_server import DigiScriptServer @@ -133,6 +133,35 @@ def add_mapping(self, actor: type, resource: type) -> None: self._resource_mappings[actor_inspect.mapped_table.fullname].append(resource) logger.info(f"Created RBAC mapping {table_name}") + @property + def mapped_resource_tables(self) -> List[str]: + mapped_tables = set() + for _actor_table, resource_tables in self._resource_mappings.items(): + for resource_table in resource_tables: + resource_inspect = inspect(resource_table) + mapped_tables.add(resource_inspect.mapped_table.fullname) + return list(mapped_tables) + + @property + def resource_table_mappings(self) -> Dict[str, List[str]]: + output = defaultdict(set) + for actor_table, resource_tables in self._resource_mappings.items(): + for resource_table in resource_tables: + resource_inspect = inspect(resource_table) + output[resource_inspect.mapped_table.fullname].add(actor_table) + real_output = {} + for resource_table in output: + real_output[resource_table] = list(output[resource_table]) + return real_output + + def check_object_deletion(self, _session: DigiDBSession, delete_object: db.Model): + obj_inspect = inspect(delete_object) + table_name = obj_inspect.mapper.mapped_table.fullname + if table_name in self._resource_mappings: + self.delete_actor(delete_object) + if table_name in self.mapped_resource_tables: + self.delete_resource(delete_object) + def _validate_mapping(self, actor: db.Model, resource: db.Model) -> str: if not isinstance(actor, db.Model): raise RBACException("actor must be class instance, not object") @@ -206,6 +235,17 @@ def get_all_roles(self, actor: db.Model) -> Dict: ) return roles + def _delete_from_rbac_db(self, table_name: str, cols: Dict[str, Any]): + if table_name not in self._mappings: + RBACException("Could not get table for actor/resource") + with self._db.sessionmaker() as session: + rbac_assignments = ( + session.query(self._mappings[table_name]).filter_by(**cols).all() + ) + for rbac_assignment in rbac_assignments: + session.delete(rbac_assignment) + session.commit() + def delete_actor(self, actor: db.Model) -> None: actor_inspect = inspect(actor) actor_cols = _get_mapping_columns(actor=actor, resource=None) @@ -218,17 +258,19 @@ def delete_actor(self, actor: db.Model) -> None: f"rbac_{actor_inspect.mapper.mapped_table.fullname}_" f"{resource_inspect.mapper.mapped_table.fullname}" ) - if table_name not in self._mappings: - RBACException("Could not get table for actor/resource") - with self._db.sessionmaker() as session: - rbac_assignments = ( - session.query(self._mappings[table_name]) - .filter_by(**actor_cols) - .all() - ) - for rbac_assignment in rbac_assignments: - session.delete(rbac_assignment) - session.commit() + self._delete_from_rbac_db(table_name, actor_cols) + + def delete_resource(self, resource: db.Model): + resource_inspect = inspect(resource) + resource_cols = _get_mapping_columns(actor=None, resource=resource) + actor_mappings = self.resource_table_mappings.get( + resource_inspect.mapper.mapped_table.fullname, [] + ) + for actor in actor_mappings: + table_name = ( + f"rbac_{actor}_" f"{resource_inspect.mapper.mapped_table.fullname}" + ) + self._delete_from_rbac_db(table_name, resource_cols) @functools.lru_cache() def _has_link_to_show(self, table: Table): diff --git a/server/utils/database.py b/server/utils/database.py index 63f087df..bbe91633 100644 --- a/server/utils/database.py +++ b/server/utils/database.py @@ -1,4 +1,5 @@ import functools +from typing import Callable, List from sqlalchemy import MetaData, event from sqlalchemy.orm import declarative_base, sessionmaker @@ -16,6 +17,8 @@ def post_delete(self, session: "DigiDBSession"): class DigiDBSession(SessionEx): def _delete_impl(self, state, obj, head): + for hook in self.db.delete_hooks: + hook(self, obj) if isinstance(obj, DeleteMixin): obj.pre_delete(self) super()._delete_impl(state, obj, head) @@ -30,6 +33,7 @@ def __init__(self, url=None, binds=None, session_options=None, engine_options=No self.sessionmaker = None # Store the original create_engine method original_create_engine = self.create_engine + self._delete_hooks: List[Callable] = [] # Override create_engine to add event listener for SQLite def create_engine_with_fk_support(*args, **kwargs): @@ -75,3 +79,11 @@ def make_declarative_base(self): } metadata = MetaData(naming_convention=convention) return declarative_base(metaclass=BindMeta, metadata=metadata) + + @property + def delete_hooks(self): + return self._delete_hooks + + def register_delete_hook(self, hook: Callable): + if hook not in self._delete_hooks: + self._delete_hooks.append(hook)