From 041d1498a1e2261c465aae9a71a8ff1cf7e35677 Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Wed, 23 Apr 2025 00:23:30 +0100 Subject: [PATCH 1/5] Allow users to be deleted --- client/src/store/modules/user/user.js | 17 +++ client/src/store/store.js | 6 +- .../src/vue_components/config/ConfigUsers.vue | 16 +- .../versions/29471f7cf7d2_user_deletion.py | 75 +++++++++ server/controllers/api/auth.py | 143 ++++++++++++------ server/models/session.py | 4 +- server/models/user.py | 2 +- server/rbac/rbac.py | 3 + server/rbac/rbac_db.py | 66 +++++--- server/registry/named_locks.py | 36 +++++ 10 files changed, 295 insertions(+), 73 deletions(-) create mode 100644 server/alembic_config/versions/29471f7cf7d2_user_deletion.py create mode 100644 server/registry/named_locks.py diff --git a/client/src/store/modules/user/user.js b/client/src/store/modules/user/user.js index f5665e30..ec333ea2 100644 --- a/client/src/store/modules/user/user.js +++ b/client/src/store/modules/user/user.js @@ -52,12 +52,29 @@ export default { body: JSON.stringify(user), }); if (response.ok) { + await context.dispatch('GET_USERS'); Vue.$toast.success('User created!'); } else { log.error('Unable to create user'); Vue.$toast.error('Unable to create user'); } }, + async DELETE_USER(context, userId) { + const response = await fetch(`${makeURL('/api/v1/auth/delete')}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ id: userId }), + }); + if (response.ok) { + await context.dispatch('GET_USERS'); + Vue.$toast.success('User deleted!'); + } else { + log.error('Unable to delete user'); + Vue.$toast.error('Unable to delete user'); + } + }, async USER_LOGIN(context, user) { const response = await fetch(`${makeURL('/api/v1/auth/login')}`, { method: 'POST', diff --git a/client/src/store/store.js b/client/src/store/store.js index 05fadbc8..1b59a0fd 100644 --- a/client/src/store/store.js +++ b/client/src/store/store.js @@ -50,10 +50,8 @@ export default new Vuex.Store({ }, async SHOW_CHANGED(context) { if (context.rootGetters.CURRENT_USER != null) { - const response = await fetch(`${makeURL('/api/v1/auth/validate')}`); - if (response.status === 401) { - await context.dispatch('USER_LOGOUT'); - } + await context.dispatch('GET_CURRENT_USER'); + await context.dispatch('GET_CURRENT_RBAC'); } window.location.reload(); }, diff --git a/client/src/vue_components/config/ConfigUsers.vue b/client/src/vue_components/config/ConfigUsers.vue index 581fd096..4c7252b7 100644 --- a/client/src/vue_components/config/ConfigUsers.vue +++ b/client/src/vue_components/config/ConfigUsers.vue @@ -29,6 +29,13 @@ > RBAC + + Delete + @@ -89,7 +96,14 @@ export default { setEditUser(userId) { this.editUser = userId; }, - ...mapActions(['GET_USERS']), + async deleteUser(data) { + const msg = `Are you sure you want to delete ${data.item.username}?`; + const action = await this.$bvModal.msgBoxConfirm(msg, {}); + if (action === true) { + await this.DELETE_USER(data.item.id); + } + }, + ...mapActions(['GET_USERS', 'DELETE_USER']), }, computed: { ...mapGetters(['SHOW_USERS', 'CURRENT_SHOW']), diff --git a/server/alembic_config/versions/29471f7cf7d2_user_deletion.py b/server/alembic_config/versions/29471f7cf7d2_user_deletion.py new file mode 100644 index 00000000..0275c241 --- /dev/null +++ b/server/alembic_config/versions/29471f7cf7d2_user_deletion.py @@ -0,0 +1,75 @@ +"""User deletion + +Revision ID: 29471f7cf7d2 +Revises: be353176c064 +Create Date: 2025-04-23 00:01:32.182458 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "29471f7cf7d2" +down_revision: Union[str, None] = "be353176c064" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("sessions", schema=None) as batch_op: + batch_op.create_foreign_key( + batch_op.f("fk_sessions_user_id_user"), + "user", + ["user_id"], + ["id"], + ondelete="SET NULL", + ) + + with op.batch_alter_table("showsession", schema=None) as batch_op: + batch_op.create_foreign_key( + batch_op.f("fk_showsession_user_id_user"), + "user", + ["user_id"], + ["id"], + ondelete="SET NULL", + ) + + with op.batch_alter_table("user_settings", schema=None) as batch_op: + batch_op.create_foreign_key( + batch_op.f("fk_user_settings_user_id_user"), + "user", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("user_settings", schema=None) as batch_op: + batch_op.drop_constraint( + batch_op.f("fk_user_settings_user_id_user"), type_="foreignkey" + ) + batch_op.create_foreign_key( + "fk_user_settings_user_id_user", "user", ["user_id"], ["id"] + ) + + with op.batch_alter_table("showsession", schema=None) as batch_op: + batch_op.drop_constraint( + batch_op.f("fk_showsession_user_id_user"), type_="foreignkey" + ) + batch_op.create_foreign_key(None, "user", ["user_id"], ["id"]) + + with op.batch_alter_table("sessions", schema=None) as batch_op: + batch_op.drop_constraint( + batch_op.f("fk_sessions_user_id_user"), type_="foreignkey" + ) + batch_op.create_foreign_key(None, "user", ["user_id"], ["id"]) + + # ### end Alembic commands ### diff --git a/server/controllers/api/auth.py b/server/controllers/api/auth.py index 76d59284..30759c7e 100644 --- a/server/controllers/api/auth.py +++ b/server/controllers/api/auth.py @@ -1,15 +1,16 @@ from datetime import datetime import bcrypt -from tornado import escape, web +from tornado import escape, gen, web from tornado.ioloop import IOLoop from models.session import Session from models.user import User +from registry.named_locks import NamedLockRegistry from schemas.schemas import UserSchema from utils.web.base_controller import BaseAPIController from utils.web.route import ApiRoute, ApiVersion -from utils.web.web_decorators import require_admin, requires_show +from utils.web.web_decorators import no_live_session, require_admin, requires_show @ApiRoute("auth/create", ApiVersion.V1) @@ -69,6 +70,73 @@ async def post(self): await self.finish({"message": "Successfully created user"}) +@ApiRoute("auth/delete", ApiVersion.V1) +class UserDeleteController(BaseAPIController): + @web.authenticated + @require_admin + @no_live_session + async def post(self): + data = escape.json_decode(self.request.body) + user_to_delete = data.get("id", None) + if not user_to_delete: + self.set_status(400) + await self.finish({"message": "Id missing"}) + return + + with self.make_session() as session: + user_to_delete: User = session.query(User).get(int(user_to_delete)) + if not user_to_delete: + self.set_status(400) + await self.finish({"message": "Could not find user to delete"}) + return + + if user_to_delete.is_admin: + self.set_status(400) + await self.finish({"message": "Cannot delete admin user"}) + return + + async with NamedLockRegistry.acquire( + f"UserLock::{user_to_delete.username}" + ): + # First, log out all sessions for this user + await self.application.ws_send_to_user( + user_to_delete.id, "NOOP", "USER_LOGOUT", {} + ) + + # Then really make sure we have logged out the user for all sessions (basically, + # wait for the websocket ops to finish) + session_logout_attempts = 0 + user_sessions = ( + session.query(Session) + .filter(Session.user_id == user_to_delete.id) + .all() + ) + while user_sessions and session_logout_attempts < 5: + for user_session in user_sessions: + ws_session = self.application.get_ws(user_session.internal_id) + await ws_session.write_message( + {"OP": "NOOP", "DATA": "{}", "ACTION": "USER_LOGOUT"} + ) + await gen.sleep(0.2) + user_sessions = ( + session.query(Session) + .filter(Session.user_id == user_to_delete.id) + .all() + ) + session_logout_attempts += 1 + + # Delete all RBAC associations for this user + self.application.rbac.delete_actor(user_to_delete) + + # Then we can delete the user + session.delete(user_to_delete) + session.commit() + + self.set_status(200) + await self.application.ws_send_to_all("NOOP", "GET_USERS", {}) + await self.finish({"message": "Successfully deleted user"}) + + @ApiRoute("auth/login", ApiVersion.V1) class LoginHandler(BaseAPIController): async def post(self): @@ -87,34 +155,35 @@ async def post(self): return with self.make_session() as session: - user = session.query(User).filter(User.username == username).first() - if not user: - self.set_status(401) - await self.finish({"message": "Invalid username/password"}) - return - - password_equal = await IOLoop.current().run_in_executor( - None, - bcrypt.checkpw, - escape.utf8(password), - escape.utf8(user.password), - ) + async with NamedLockRegistry.acquire(f"UserLock::{username}"): + user = session.query(User).filter(User.username == username).first() + if not user: + self.set_status(401) + await self.finish({"message": "Invalid username/password"}) + return + + password_equal = await IOLoop.current().run_in_executor( + None, + bcrypt.checkpw, + escape.utf8(password), + escape.utf8(user.password), + ) - if password_equal: - session_id = data.get("session_id", "") - if session_id: - ws_session: Session = session.query(Session).get(session_id) - if ws_session: - ws_session.user = user - user.last_login = datetime.utcnow() - session.commit() + if password_equal: + session_id = data.get("session_id", "") + if session_id: + ws_session: Session = session.query(Session).get(session_id) + if ws_session: + ws_session.user = user + user.last_login = datetime.utcnow() + session.commit() - self.set_secure_cookie("digiscript_user_id", str(user.id)) - self.set_status(200) - await self.finish({"message": "Successful log in"}) - else: - self.set_status(401) - await self.finish({"message": "Invalid username/password"}) + self.set_secure_cookie("digiscript_user_id", str(user.id)) + self.set_status(200) + await self.finish({"message": "Successful log in"}) + else: + self.set_status(401) + await self.finish({"message": "Invalid username/password"}) @ApiRoute("auth/logout", ApiVersion.V1) @@ -140,24 +209,6 @@ async def post(self): await self.finish({"message": "No user logged in"}) -@ApiRoute("auth/validate", ApiVersion.V1) -class AuthValidationHandler(BaseAPIController): - @web.authenticated - async def get(self): - if self.current_user["is_admin"]: - self.set_status(200) - await self.finish({"message": "OK"}) - elif ( - self.current_show - and self.current_user["show_id"] == self.current_show["id"] - ): - self.set_status(200) - await self.finish({"message": "OK"}) - else: - self.set_status(401) - self.write({"message": "Not Authenticated"}) - - @ApiRoute("auth/users", ApiVersion.V1) class UsersHandler(BaseAPIController): @web.authenticated diff --git a/server/models/session.py b/server/models/session.py index c59414df..ff9a41f9 100644 --- a/server/models/session.py +++ b/server/models/session.py @@ -12,7 +12,7 @@ class Session(db.Model): last_ping = Column(Float()) last_pong = Column(Float()) is_editor = Column(Boolean(), default=False, index=True) - user_id = Column(Integer, ForeignKey("user.id"), index=True) + user_id = Column(Integer, ForeignKey("user.id", ondelete="SET NULL"), index=True) user = relationship( "User", uselist=False, backref=backref("sessions", uselist=True) @@ -27,7 +27,7 @@ class ShowSession(db.Model): start_date_time = Column(DateTime()) end_date_time = Column(DateTime()) - user_id = Column(Integer, ForeignKey("user.id"), index=True) + user_id = Column(Integer, ForeignKey("user.id", ondelete="SET NULL"), index=True) client_internal_id = Column(String(255), ForeignKey("sessions.internal_id")) last_client_internal_id = Column(String(255)) latest_line_ref = Column(String) diff --git a/server/models/user.py b/server/models/user.py index cd1eed7a..4f5f91bf 100644 --- a/server/models/user.py +++ b/server/models/user.py @@ -23,7 +23,7 @@ class UserSettings(db.Model): __tablename__ = "user_settings" id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(Integer, ForeignKey("user.id"), index=True) + user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE"), index=True) settings_type = Column(String, index=True) settings = Column(Text) diff --git a/server/rbac/rbac.py b/server/rbac/rbac.py index 6e762bb0..6d73f70a 100644 --- a/server/rbac/rbac.py +++ b/server/rbac/rbac.py @@ -31,6 +31,9 @@ def add_mapping( self._rbac_db.add_mapping(actor, resource) self._display_fields[resource] = [field.key for field in display_fields] + def delete_actor(self, actor: db.Model) -> None: + self._rbac_db.delete_actor(actor) + def give_role(self, actor: db.Model, resource: db.Model, role: Role) -> None: self._rbac_db.give_role(actor, resource, role) diff --git a/server/rbac/rbac_db.py b/server/rbac/rbac_db.py index 36ee566b..4f65363c 100644 --- a/server/rbac/rbac_db.py +++ b/server/rbac/rbac_db.py @@ -48,26 +48,30 @@ def process_result_value(self, value, dialect): return Role(value) -def _get_mapping_columns(actor: db.Model, resource: db.Model) -> dict: - actor_inspect = inspect(actor) - resource_inspect = inspect(resource) +def _get_mapping_columns( + actor: Optional[db.Model], resource: Optional[db.Model] +) -> dict: cols = {} - cols.update( - { - f"{actor_inspect.mapper.mapped_table.fullname}_{col.key}": getattr( - actor, col.key - ) - for col in actor_inspect.mapper.primary_key - } - ) - cols.update( - { - f"{resource_inspect.mapper.mapped_table.fullname}_{col.key}": getattr( - resource, col.key - ) - for col in resource_inspect.mapper.primary_key - } - ) + if actor: + actor_inspect = inspect(actor) + cols.update( + { + f"{actor_inspect.mapper.mapped_table.fullname}_{col.key}": getattr( + actor, col.key + ) + for col in actor_inspect.mapper.primary_key + } + ) + if resource: + resource_inspect = inspect(resource) + cols.update( + { + f"{resource_inspect.mapper.mapped_table.fullname}_{col.key}": getattr( + resource, col.key + ) + for col in resource_inspect.mapper.primary_key + } + ) return cols @@ -202,6 +206,30 @@ def get_all_roles(self, actor: db.Model) -> Dict: ) return roles + def delete_actor(self, actor: db.Model) -> None: + actor_inspect = inspect(actor) + actor_cols = _get_mapping_columns(actor=actor, resource=None) + resource_mappings = self._resource_mappings.get( + actor_inspect.mapper.mapped_table.fullname, [] + ) + for resource in resource_mappings: + resource_inspect = inspect(resource) + table_name = ( + f"rbac_{actor_inspect.mapper.mapped_table.fullname}_" + f"{resource_inspect.mapper.mapped_table.fullname}" + ) + if table_name not in self._mappings: + RBACException("Could not get table for actor/resource") + with self._db.sessionmaker() as session: + rbac_assignments = ( + session.query(self._mappings[table_name]) + .filter_by(**actor_cols) + .all() + ) + for rbac_assignment in rbac_assignments: + session.delete(rbac_assignment) + session.commit() + @functools.lru_cache() def _has_link_to_show(self, table: Table): return self.__has_link_to_show(table) diff --git a/server/registry/named_locks.py b/server/registry/named_locks.py new file mode 100644 index 00000000..33582a63 --- /dev/null +++ b/server/registry/named_locks.py @@ -0,0 +1,36 @@ +import threading +from contextlib import asynccontextmanager +from typing import Dict + +import tornado.locks + + +class NamedLockRegistry: + _locks: Dict[str, tornado.locks.Lock] = {} + _registry_lock = threading.Lock() + + @classmethod + def get_lock(cls, name: str) -> tornado.locks.Lock: + if name in cls._locks: + return cls._locks[name] + + with cls._registry_lock: + if name not in cls._locks: + cls._locks[name] = tornado.locks.Lock() + return cls._locks[name] + + @classmethod + @asynccontextmanager + async def acquire(cls, name: str): + lock = cls.get_lock(name) + await lock.acquire() + try: + yield + finally: + lock.release() + + +@asynccontextmanager +async def acquire_lock(name: str): + async with NamedLockRegistry.acquire(name): + yield From 145825abe6a448409e9677ed00e8fbc1c50a682a Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Wed, 23 Apr 2025 00:41:51 +0100 Subject: [PATCH 2/5] Add unit tests for named lock registry --- server/.pylintrc | 2 +- server/test/test_named_lock_registry.py | 163 ++++++++++++++++++++++++ server/test_requirements.txt | 3 +- 3 files changed, 166 insertions(+), 2 deletions(-) create mode 100644 server/test/test_named_lock_registry.py diff --git a/server/.pylintrc b/server/.pylintrc index 22bfce7c..b517ac6c 100644 --- a/server/.pylintrc +++ b/server/.pylintrc @@ -1,6 +1,6 @@ [MASTER] init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))" -ignore-paths=^alembic_config/versions/.*$, +ignore-paths=^alembic_config/versions/.*$,^test/.*$ [MESSAGES CONTROL] disable= diff --git a/server/test/test_named_lock_registry.py b/server/test/test_named_lock_registry.py new file mode 100644 index 00000000..0ac08d6b --- /dev/null +++ b/server/test/test_named_lock_registry.py @@ -0,0 +1,163 @@ +import asyncio +import threading + +import pytest +import tornado.locks + +from registry.named_locks import NamedLockRegistry, acquire_lock + + +@pytest.fixture +def reset_registry(): + """Reset the NamedLockRegistry between tests""" + NamedLockRegistry._locks = {} + yield + NamedLockRegistry._locks = {} + + +@pytest.mark.asyncio +async def test_get_lock_returns_same_lock_for_same_name(reset_registry): + """Test that get_lock returns the same lock object for the same name""" + lock1 = NamedLockRegistry.get_lock("resource1") + lock2 = NamedLockRegistry.get_lock("resource1") + + assert lock1 is lock2 + assert isinstance(lock1, tornado.locks.Lock) + + +@pytest.mark.asyncio +async def test_get_lock_returns_different_locks_for_different_names(reset_registry): + """Test that get_lock returns different lock objects for different names""" + lock1 = NamedLockRegistry.get_lock("resource1") + lock2 = NamedLockRegistry.get_lock("resource2") + + assert lock1 is not lock2 + assert isinstance(lock1, tornado.locks.Lock) + assert isinstance(lock2, tornado.locks.Lock) + + +@pytest.mark.asyncio +async def test_acquire_lock_context_manager(reset_registry): + """Test that the acquire_lock context manager acquires and releases the lock""" + test_name = "test_resource" + + # Get the lock to manipulate it directly + lock = NamedLockRegistry.get_lock(test_name) + + # Verify it's unlocked initially + assert lock._block._value == 1 + + # Use the context manager + async with acquire_lock(test_name): + # Lock should be acquired within the context + assert lock._block._value == 0 + + # Lock should be released after the context + assert lock._block._value == 1 + + +@pytest.mark.asyncio +async def test_lock_prevents_concurrent_access(reset_registry): + """Test that the lock prevents concurrent access to a critical section""" + test_name = "concurrent_resource" + shared_counter = 0 + iterations = 100 + num_tasks = 10 + + async def increment_counter(): + nonlocal shared_counter + for _ in range(iterations): + async with acquire_lock(test_name): + # Store the current value + current = shared_counter + # Simulate some processing time that could lead to race conditions + await asyncio.sleep(0.001) + # Increment + shared_counter = current + 1 + + # Create and run multiple tasks that try to increment the counter + tasks = [asyncio.create_task(increment_counter()) for _ in range(num_tasks)] + await asyncio.gather(*tasks) + + # Without proper locking, we'd expect the counter to be less than iterations * num_tasks + # due to race conditions, but with locking it should be exactly iterations * num_tasks + assert shared_counter == iterations * num_tasks + + +@pytest.mark.asyncio +async def test_different_locks_dont_block_each_other(reset_registry): + """Test that different named locks don't block each other""" + resource1 = "resource1" + resource2 = "resource2" + + # We'll use these events to control and verify the execution order + resource1_acquired = asyncio.Event() + resource2_accessed = asyncio.Event() + + async def task1(): + async with acquire_lock(resource1): + # Signal that resource1 is locked + resource1_acquired.set() + # Wait for task2 to access resource2 + await resource2_accessed.wait() + # If we get here, it means task2 was able to access resource2 while resource1 was locked + + async def task2(): + # Wait until task1 has acquired the lock on resource1 + await resource1_acquired.wait() + # Try to acquire resource2 - this should not be blocked + async with acquire_lock(resource2): + # Signal that we've accessed resource2 + resource2_accessed.set() + + # Run both tasks concurrently and wait for them to complete + await asyncio.gather(task1(), task2()) + + # If we get here without deadlock, the test passed + assert resource1_acquired.is_set() and resource2_accessed.is_set() + + +@pytest.mark.asyncio +async def test_thread_safety_of_get_lock(reset_registry): + """Test that the get_lock method is thread-safe when creating new locks""" + test_name = "thread_safety_test" + NUM_THREADS = 20 + results = [] + + def get_lock_from_thread(): + # Get the lock from a separate thread + lock = NamedLockRegistry.get_lock(test_name) + results.append(lock) + + # Create and start multiple threads + threads = [ + threading.Thread(target=get_lock_from_thread) for _ in range(NUM_THREADS) + ] + for thread in threads: + thread.start() + + # Wait for all threads to finish + for thread in threads: + thread.join() + + # All threads should have got the same lock object + assert len(results) == NUM_THREADS + for lock in results: + assert lock is results[0] + + +@pytest.mark.asyncio +async def test_acquire_exceptions_release_lock(reset_registry): + """Test that the lock is released even if an exception occurs in the context""" + test_name = "exception_test" + lock = NamedLockRegistry.get_lock(test_name) + + try: + async with NamedLockRegistry.acquire(test_name): + assert lock._block._value == 0 + raise ValueError("Test exception") + except ValueError: + pass + + # Lock should be released even though an exception was raised + assert lock._block._value == 1 diff --git a/server/test_requirements.txt b/server/test_requirements.txt index 0d44c0b2..081b54f2 100644 --- a/server/test_requirements.txt +++ b/server/test_requirements.txt @@ -1 +1,2 @@ -pytest<8.4 \ No newline at end of file +pytest<8.4 +pytest-asyncio From 861f2598922bd73abbcb2980b2898f8ba892a566 Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Wed, 23 Apr 2025 22:49:51 +0100 Subject: [PATCH 3/5] Add delete hook into RBAC database --- server/digi_server/app_server.py | 1 + server/rbac/rbac.py | 4 ++ server/rbac/rbac_db.py | 68 ++++++++++++++++++++++++++------ server/utils/database.py | 12 ++++++ 4 files changed, 72 insertions(+), 13 deletions(-) diff --git a/server/digi_server/app_server.py b/server/digi_server/app_server.py index 39867d32..3acfd61f 100644 --- a/server/digi_server/app_server.py +++ b/server/digi_server/app_server.py @@ -249,6 +249,7 @@ async def _configure_logging(self): ) def _configure_rbac(self): + self._db.register_delete_hook(self.rbac.rbac_db.check_object_deletion) self.rbac.add_mapping(User, Show, [Show.id, Show.name]) self.rbac.add_mapping(User, CueType, [CueType.id, CueType.prefix]) self.rbac.add_mapping(User, Script, [Script.id]) diff --git a/server/rbac/rbac.py b/server/rbac/rbac.py index 6d73f70a..2829ff43 100644 --- a/server/rbac/rbac.py +++ b/server/rbac/rbac.py @@ -17,6 +17,10 @@ def __init__(self, app: "DigiScriptServer"): self._rbac_db = RBACDatabase(app.get_db(), app) self._display_fields = {} + @property + def rbac_db(self): + return self._rbac_db + def add_mapping( self, actor: type, resource: type, display_fields: Optional[List] = None ) -> None: diff --git a/server/rbac/rbac_db.py b/server/rbac/rbac_db.py index 4f65363c..41794c37 100644 --- a/server/rbac/rbac_db.py +++ b/server/rbac/rbac_db.py @@ -1,7 +1,7 @@ import functools from collections import defaultdict from copy import deepcopy -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional from anytree import Node from sqlalchemy import Column, ForeignKey, Integer, Table, TypeDecorator, inspect @@ -13,7 +13,7 @@ from rbac.exceptions import RBACException from rbac.role import Role from utils import tree -from utils.database import DigiSQLAlchemy +from utils.database import DigiDBSession, DigiSQLAlchemy if TYPE_CHECKING: from digi_server.app_server import DigiScriptServer @@ -133,6 +133,35 @@ def add_mapping(self, actor: type, resource: type) -> None: self._resource_mappings[actor_inspect.mapped_table.fullname].append(resource) logger.info(f"Created RBAC mapping {table_name}") + @property + def mapped_resource_tables(self) -> List[str]: + mapped_tables = set() + for actor_table, resource_tables in self._resource_mappings.items(): + for resource_table in resource_tables: + resource_inspect = inspect(resource_table) + mapped_tables.add(resource_inspect.mapped_table.fullname) + return list(mapped_tables) + + @property + def resource_table_mappings(self) -> Dict[str, List[str]]: + output = defaultdict(set) + for actor_table, resource_tables in self._resource_mappings.items(): + for resource_table in resource_tables: + resource_inspect = inspect(resource_table) + output[resource_inspect.mapped_table.fullname].add(actor_table) + real_output = {} + for resource_table in output: + real_output[resource_table] = list(output[resource_table]) + return real_output + + def check_object_deletion(self, _session: DigiDBSession, delete_object: db.Model): + obj_inspect = inspect(delete_object) + table_name = obj_inspect.mapper.mapped_table.fullname + if table_name in self._resource_mappings: + self.delete_actor(delete_object) + if table_name in self.mapped_resource_tables: + self.delete_resource(delete_object) + def _validate_mapping(self, actor: db.Model, resource: db.Model) -> str: if not isinstance(actor, db.Model): raise RBACException("actor must be class instance, not object") @@ -206,6 +235,17 @@ def get_all_roles(self, actor: db.Model) -> Dict: ) return roles + def _delete_from_rbac_db(self, table_name: str, cols: Dict[str, Any]): + if table_name not in self._mappings: + RBACException("Could not get table for actor/resource") + with self._db.sessionmaker() as session: + rbac_assignments = ( + session.query(self._mappings[table_name]).filter_by(**cols).all() + ) + for rbac_assignment in rbac_assignments: + session.delete(rbac_assignment) + session.commit() + def delete_actor(self, actor: db.Model) -> None: actor_inspect = inspect(actor) actor_cols = _get_mapping_columns(actor=actor, resource=None) @@ -218,17 +258,19 @@ def delete_actor(self, actor: db.Model) -> None: f"rbac_{actor_inspect.mapper.mapped_table.fullname}_" f"{resource_inspect.mapper.mapped_table.fullname}" ) - if table_name not in self._mappings: - RBACException("Could not get table for actor/resource") - with self._db.sessionmaker() as session: - rbac_assignments = ( - session.query(self._mappings[table_name]) - .filter_by(**actor_cols) - .all() - ) - for rbac_assignment in rbac_assignments: - session.delete(rbac_assignment) - session.commit() + self._delete_from_rbac_db(table_name, actor_cols) + + def delete_resource(self, resource: db.Model): + resource_inspect = inspect(resource) + resource_cols = _get_mapping_columns(actor=None, resource=resource) + actor_mappings = self.resource_table_mappings.get( + resource_inspect.mapper.mapped_table.fullname, [] + ) + for actor in actor_mappings: + table_name = ( + f"rbac_{actor}_" f"{resource_inspect.mapper.mapped_table.fullname}" + ) + self._delete_from_rbac_db(table_name, resource_cols) @functools.lru_cache() def _has_link_to_show(self, table: Table): diff --git a/server/utils/database.py b/server/utils/database.py index 63f087df..bbe91633 100644 --- a/server/utils/database.py +++ b/server/utils/database.py @@ -1,4 +1,5 @@ import functools +from typing import Callable, List from sqlalchemy import MetaData, event from sqlalchemy.orm import declarative_base, sessionmaker @@ -16,6 +17,8 @@ def post_delete(self, session: "DigiDBSession"): class DigiDBSession(SessionEx): def _delete_impl(self, state, obj, head): + for hook in self.db.delete_hooks: + hook(self, obj) if isinstance(obj, DeleteMixin): obj.pre_delete(self) super()._delete_impl(state, obj, head) @@ -30,6 +33,7 @@ def __init__(self, url=None, binds=None, session_options=None, engine_options=No self.sessionmaker = None # Store the original create_engine method original_create_engine = self.create_engine + self._delete_hooks: List[Callable] = [] # Override create_engine to add event listener for SQLite def create_engine_with_fk_support(*args, **kwargs): @@ -75,3 +79,11 @@ def make_declarative_base(self): } metadata = MetaData(naming_convention=convention) return declarative_base(metaclass=BindMeta, metadata=metadata) + + @property + def delete_hooks(self): + return self._delete_hooks + + def register_delete_hook(self, hook: Callable): + if hook not in self._delete_hooks: + self._delete_hooks.append(hook) From e8adf20d490ab9aa1817d727f9b6cfdf2ff03a75 Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Wed, 23 Apr 2025 22:52:05 +0100 Subject: [PATCH 4/5] Fix pylint --- server/rbac/rbac_db.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/rbac/rbac_db.py b/server/rbac/rbac_db.py index 41794c37..9fc2821c 100644 --- a/server/rbac/rbac_db.py +++ b/server/rbac/rbac_db.py @@ -136,7 +136,7 @@ def add_mapping(self, actor: type, resource: type) -> None: @property def mapped_resource_tables(self) -> List[str]: mapped_tables = set() - for actor_table, resource_tables in self._resource_mappings.items(): + for _actor_table, resource_tables in self._resource_mappings.items(): for resource_table in resource_tables: resource_inspect = inspect(resource_table) mapped_tables.add(resource_inspect.mapped_table.fullname) From 299a304f627fe9f8d0795bca5cb91ec60fa41135 Mon Sep 17 00:00:00 2001 From: Tim Bradgate Date: Wed, 23 Apr 2025 23:00:27 +0100 Subject: [PATCH 5/5] Bump version to 0.11.0 --- client/package-lock.json | 4 ++-- client/package.json | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/client/package-lock.json b/client/package-lock.json index 68a5879a..6d071952 100644 --- a/client/package-lock.json +++ b/client/package-lock.json @@ -1,12 +1,12 @@ { "name": "client", - "version": "0.10.1", + "version": "0.11.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "client", - "version": "0.10.1", + "version": "0.11.0", "dependencies": { "bootstrap": "4.6.2", "bootstrap-vue": "2.23.1", diff --git a/client/package.json b/client/package.json index 06e685c4..75f60df4 100644 --- a/client/package.json +++ b/client/package.json @@ -1,6 +1,6 @@ { "name": "client", - "version": "0.10.1", + "version": "0.11.0", "private": true, "scripts": { "build": "vite build",