diff --git a/ami/users/management/__init__.py b/ami/users/management/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ami/users/management/commands/__init__.py b/ami/users/management/commands/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ami/users/management/commands/update_roles.py b/ami/users/management/commands/update_roles.py new file mode 100644 index 000000000..ebbf2360d --- /dev/null +++ b/ami/users/management/commands/update_roles.py @@ -0,0 +1,88 @@ +import logging + +from django.core.management.base import BaseCommand +from django.db import transaction + +from ami.main.models import Project +from ami.users.models import RoleSchemaVersion +from ami.users.roles import create_roles_for_project + +logger = logging.getLogger(__name__) + + +class Command(BaseCommand): + help = "Update roles and permissions for all projects or a specific project" + + def add_arguments(self, parser): + parser.add_argument( + "--project-id", + type=int, + help="Update roles for a specific project by ID", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Preview changes without applying them", + ) + parser.add_argument( + "--force", + action="store_true", + default=True, + help="Force update even if groups already exist (default: True)", + ) + + def handle(self, *args, **options): + project_id = options.get("project_id") + dry_run = options.get("dry_run", False) + force = options.get("force", True) + + if dry_run: + self.stdout.write(self.style.WARNING("DRY RUN MODE - No changes will be made")) + + # Get projects to update + if project_id: + try: + projects = [Project.objects.get(pk=project_id)] + self.stdout.write(f"Updating roles for project {project_id}") + except Project.DoesNotExist: + self.stderr.write(self.style.ERROR(f"Project with ID {project_id} does not exist")) + return + else: + projects = Project.objects.all() + project_count = projects.count() + self.stdout.write(f"Updating roles for {project_count} projects") + + success = 0 + failed = 0 + + for project in projects: + try: + if dry_run: + self.stdout.write(f" Would update roles for project {project.pk} ({project.name})") + else: + with transaction.atomic(): + create_roles_for_project(project, force_update=force) + self.stdout.write( + self.style.SUCCESS(f" ✓ Updated roles for project {project.pk} ({project.name})") + ) + success += 1 + except Exception as e: + self.stderr.write(self.style.ERROR(f" ✗ Failed to update project {project.pk} ({project.name}): {e}")) + failed += 1 + logger.exception(f"Error updating roles for project {project.pk}") + + # Summary + self.stdout.write("\n" + "=" * 50) + if dry_run: + self.stdout.write(self.style.WARNING(f"DRY RUN COMPLETE: Would update {success} projects")) + else: + self.stdout.write(self.style.SUCCESS(f"Successfully updated: {success} projects")) + + if failed > 0: + self.stdout.write(self.style.ERROR(f"Failed: {failed} projects")) + + # Update schema version if successful and not dry run + if success > 0 and not project_id: + RoleSchemaVersion.mark_updated(description="Manual update via management command") + current_version = RoleSchemaVersion.get_current_version() + self.stdout.write(f"Schema version updated to: {current_version}") diff --git a/ami/users/migrations/0004_roleschemaversion.py b/ami/users/migrations/0004_roleschemaversion.py new file mode 100644 index 000000000..e595682c3 --- /dev/null +++ b/ami/users/migrations/0004_roleschemaversion.py @@ -0,0 +1,24 @@ +# Generated by Django 4.2.10 on 2026-01-22 03:36 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("users", "0003_lowercase_existing_emails"), + ] + + operations = [ + migrations.CreateModel( + name="RoleSchemaVersion", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("version", models.CharField(max_length=100, unique=True)), + ("description", models.TextField()), + ("updated_at", models.DateTimeField(auto_now=True)), + ], + options={ + "ordering": ["-updated_at"], + }, + ), + ] diff --git a/ami/users/models.py b/ami/users/models.py index 1b321a66c..4eff004ee 100644 --- a/ami/users/models.py +++ b/ami/users/models.py @@ -42,3 +42,52 @@ def get_absolute_url(self) -> str: """ # @TODO return frontend URL, not API URL return reverse("api:user-detail", kwargs={"id": self.pk}) + + +class RoleSchemaVersion(models.Model): + """ + Tracks the current role/permission schema version. + Updated when Role classes or Project.Permissions change. + """ + + version = models.CharField(max_length=100, unique=True) + description = models.TextField() + updated_at = models.DateTimeField(auto_now=True) + + class Meta: + ordering = ["-updated_at"] + + def __str__(self): + return f"RoleSchemaVersion {self.version}" + + @classmethod + def get_current_version(cls): + """Get the current schema version from code.""" + import hashlib + + from ami.users.roles import Role + + role_data = [] + for role_class in sorted(Role.__subclasses__(), key=lambda r: r.__name__): + perms = sorted(role_class.permissions) + role_data.append(f"{role_class.__name__}:{','.join(perms)}") + + schema_str = "|".join(role_data) + return hashlib.md5(schema_str.encode()).hexdigest()[:16] + + @classmethod + def needs_update(cls): + """Check if roles need updating based on schema version.""" + current = cls.get_current_version() + try: + latest = cls.objects.first() + return latest is None or latest.version != current + except Exception: + # Table doesn't exist yet (first migration) + return False + + @classmethod + def mark_updated(cls, description="Schema updated"): + """Mark schema as updated to current version.""" + current = cls.get_current_version() + cls.objects.create(version=current, description=description) diff --git a/ami/users/roles.py b/ami/users/roles.py index 69862229c..ac4aeef53 100644 --- a/ami/users/roles.py +++ b/ami/users/roles.py @@ -2,7 +2,6 @@ from django.contrib.auth.models import Group, Permission from django.contrib.contenttypes.models import ContentType -from guardian.shortcuts import assign_perm, get_perms, remove_perm from ami.main.models import Project @@ -194,29 +193,68 @@ class ProjectManager(Role): ) -def create_roles_for_project(project): - """Creates role-based permission groups for a given project.""" +def create_roles_for_project(project, force_update=False): + """ + Creates role-based permission groups for a given project. + + Args: + project: The project to create roles for + force_update: If False, skip updates for existing groups (default: False) + If True, always update permissions even if group exists + """ + from guardian.models import GroupObjectPermission + project_ct = ContentType.objects.get_for_model(Project) + # Pre-fetch all permissions we might need (single query) + all_perm_codenames = set() + for role_class in Role.__subclasses__(): + all_perm_codenames.update(role_class.permissions) + + existing_perms = { + perm.codename: perm + for perm in Permission.objects.filter(codename__in=all_perm_codenames, content_type=project_ct) + } + + # Create any missing permissions + missing_perms = [] + for codename in all_perm_codenames: + if codename not in existing_perms: + missing_perms.append( + Permission(codename=codename, content_type=project_ct, name=f"Can {codename.replace('_', ' ')}") + ) + + if missing_perms: + Permission.objects.bulk_create(missing_perms, ignore_conflicts=True) + # Refresh existing_perms dict after bulk create + existing_perms = { + perm.codename: perm + for perm in Permission.objects.filter(codename__in=all_perm_codenames, content_type=project_ct) + } + for role_class in Role.__subclasses__(): role_name = f"{project.pk}_{project.name}_{role_class.__name__}" permissions = role_class.permissions group, created = Group.objects.get_or_create(name=role_name) + if created: logger.debug(f"Role created {role_class} for project {project}") - else: - # Reset permissions to make sure permissions are updated - # every time we call this function - group.permissions.clear() - assigned_perms = get_perms(group, project) - for perm_codename in assigned_perms: - remove_perm(perm_codename, group, project) - for perm_codename in permissions: - permission, perm_created = Permission.objects.get_or_create( - codename=perm_codename, - content_type=project_ct, - defaults={"name": f"Can {perm_codename.replace('_', ' ')}"}, - ) + elif not force_update: + # Skip updates for existing groups unless force_update=True + continue - group.permissions.add(permission) # Assign the permission group to the project - assign_perm(perm_codename, group, project) + # Use set() instead of clear() + add() loop (single query) + role_perm_objects = [existing_perms[codename] for codename in permissions] + group.permissions.set(role_perm_objects) + + # Bulk update Guardian object permissions + # Remove all existing, then bulk create new ones + GroupObjectPermission.objects.filter(group=group, content_type=project_ct, object_pk=project.pk).delete() + + group_obj_perms = [ + GroupObjectPermission( + group=group, permission=existing_perms[codename], content_type=project_ct, object_pk=project.pk + ) + for codename in permissions + ] + GroupObjectPermission.objects.bulk_create(group_obj_perms) diff --git a/ami/users/signals.py b/ami/users/signals.py index 6ce8998eb..c90f1860f 100644 --- a/ami/users/signals.py +++ b/ami/users/signals.py @@ -12,20 +12,41 @@ def create_roles(sender, **kwargs): - """Creates predefined roles with specific permissions .""" + """ + Creates predefined roles with specific permissions. + Only runs when role schema version has changed. + """ + from ami.users.models import RoleSchemaVersion + + # Quick check - does schema need updating? + if not RoleSchemaVersion.needs_update(): + logger.debug("Role schema is up to date, skipping role creation") + return + + logger.info("Role schema version changed - updating roles for all projects") + project_count = Project.objects.count() + + if project_count > 100: + logger.warning( + f"Updating roles for {project_count} projects. " + f"This may take a while. Consider running 'python manage.py update_roles' " + f"separately for better control." + ) - logger.info("Creating roles for all projects") try: for project in Project.objects.all(): try: - create_roles_for_project(project) + create_roles_for_project(project, force_update=True) except Exception as e: logger.warning(f"Failed to create roles for project {project.pk} ({project.name}): {e}") continue + + # Mark schema as updated + RoleSchemaVersion.mark_updated(description="Post-migration role update") + logger.info(f"Successfully updated roles for {project_count} projects") + except Exception as e: - logger.warning( - f"Failed to create roles during migration: {e}. This can be run manually via management command." - ) + logger.error(f"Failed to create roles during migration: {e}") @receiver(m2m_changed, sender=Group.user_set.through)