Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions server/digi_server/app_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ async def _configure_logging(self):
)

def _configure_rbac(self):
self._db.register_delete_hook(self.rbac.rbac_db.check_object_deletion)
self.rbac.add_mapping(User, Show, [Show.id, Show.name])
self.rbac.add_mapping(User, CueType, [CueType.id, CueType.prefix])
self.rbac.add_mapping(User, Script, [Script.id])
Expand Down
4 changes: 4 additions & 0 deletions server/rbac/rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ def __init__(self, app: "DigiScriptServer"):
self._rbac_db = RBACDatabase(app.get_db(), app)
self._display_fields = {}

@property
def rbac_db(self):
return self._rbac_db

def add_mapping(
self, actor: type, resource: type, display_fields: Optional[List] = None
) -> None:
Expand Down
68 changes: 55 additions & 13 deletions server/rbac/rbac_db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
from collections import defaultdict
from copy import deepcopy
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from anytree import Node
from sqlalchemy import Column, ForeignKey, Integer, Table, TypeDecorator, inspect
Expand All @@ -13,7 +13,7 @@
from rbac.exceptions import RBACException
from rbac.role import Role
from utils import tree
from utils.database import DigiSQLAlchemy
from utils.database import DigiDBSession, DigiSQLAlchemy

if TYPE_CHECKING:
from digi_server.app_server import DigiScriptServer
Expand Down Expand Up @@ -133,6 +133,35 @@ def add_mapping(self, actor: type, resource: type) -> None:
self._resource_mappings[actor_inspect.mapped_table.fullname].append(resource)
logger.info(f"Created RBAC mapping {table_name}")

@property
def mapped_resource_tables(self) -> List[str]:
mapped_tables = set()
for _actor_table, resource_tables in self._resource_mappings.items():
for resource_table in resource_tables:
resource_inspect = inspect(resource_table)
mapped_tables.add(resource_inspect.mapped_table.fullname)
return list(mapped_tables)

@property
def resource_table_mappings(self) -> Dict[str, List[str]]:
output = defaultdict(set)
for actor_table, resource_tables in self._resource_mappings.items():
for resource_table in resource_tables:
resource_inspect = inspect(resource_table)
output[resource_inspect.mapped_table.fullname].add(actor_table)
real_output = {}
for resource_table in output:
real_output[resource_table] = list(output[resource_table])
return real_output

def check_object_deletion(self, _session: DigiDBSession, delete_object: db.Model):
obj_inspect = inspect(delete_object)
table_name = obj_inspect.mapper.mapped_table.fullname
if table_name in self._resource_mappings:
self.delete_actor(delete_object)
if table_name in self.mapped_resource_tables:
self.delete_resource(delete_object)

def _validate_mapping(self, actor: db.Model, resource: db.Model) -> str:
if not isinstance(actor, db.Model):
raise RBACException("actor must be class instance, not object")
Expand Down Expand Up @@ -206,6 +235,17 @@ def get_all_roles(self, actor: db.Model) -> Dict:
)
return roles

def _delete_from_rbac_db(self, table_name: str, cols: Dict[str, Any]):
if table_name not in self._mappings:
RBACException("Could not get table for actor/resource")
with self._db.sessionmaker() as session:
rbac_assignments = (
session.query(self._mappings[table_name]).filter_by(**cols).all()
)
for rbac_assignment in rbac_assignments:
session.delete(rbac_assignment)
session.commit()

def delete_actor(self, actor: db.Model) -> None:
actor_inspect = inspect(actor)
actor_cols = _get_mapping_columns(actor=actor, resource=None)
Expand All @@ -218,17 +258,19 @@ def delete_actor(self, actor: db.Model) -> None:
f"rbac_{actor_inspect.mapper.mapped_table.fullname}_"
f"{resource_inspect.mapper.mapped_table.fullname}"
)
if table_name not in self._mappings:
RBACException("Could not get table for actor/resource")
with self._db.sessionmaker() as session:
rbac_assignments = (
session.query(self._mappings[table_name])
.filter_by(**actor_cols)
.all()
)
for rbac_assignment in rbac_assignments:
session.delete(rbac_assignment)
session.commit()
self._delete_from_rbac_db(table_name, actor_cols)

def delete_resource(self, resource: db.Model):
resource_inspect = inspect(resource)
resource_cols = _get_mapping_columns(actor=None, resource=resource)
actor_mappings = self.resource_table_mappings.get(
resource_inspect.mapper.mapped_table.fullname, []
)
for actor in actor_mappings:
table_name = (
f"rbac_{actor}_" f"{resource_inspect.mapper.mapped_table.fullname}"
)
self._delete_from_rbac_db(table_name, resource_cols)

@functools.lru_cache()
def _has_link_to_show(self, table: Table):
Expand Down
12 changes: 12 additions & 0 deletions server/utils/database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
from typing import Callable, List

from sqlalchemy import MetaData, event
from sqlalchemy.orm import declarative_base, sessionmaker
Expand All @@ -16,6 +17,8 @@ def post_delete(self, session: "DigiDBSession"):
class DigiDBSession(SessionEx):

def _delete_impl(self, state, obj, head):
for hook in self.db.delete_hooks:
hook(self, obj)
if isinstance(obj, DeleteMixin):
obj.pre_delete(self)
super()._delete_impl(state, obj, head)
Expand All @@ -30,6 +33,7 @@ def __init__(self, url=None, binds=None, session_options=None, engine_options=No
self.sessionmaker = None
# Store the original create_engine method
original_create_engine = self.create_engine
self._delete_hooks: List[Callable] = []

# Override create_engine to add event listener for SQLite
def create_engine_with_fk_support(*args, **kwargs):
Expand Down Expand Up @@ -75,3 +79,11 @@ def make_declarative_base(self):
}
metadata = MetaData(naming_convention=convention)
return declarative_base(metaclass=BindMeta, metadata=metadata)

@property
def delete_hooks(self):
return self._delete_hooks

def register_delete_hook(self, hook: Callable):
if hook not in self._delete_hooks:
self._delete_hooks.append(hook)
Loading