From 3bfe09a527e70feeeafa7d55ff75a2cec1c055a7 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Mon, 25 Aug 2025 16:35:37 -0500 Subject: [PATCH 01/83] feat: FIT-587: FSM Architecture setup --- label_studio/core/settings/base.py | 1 + label_studio/core/urls.py | 1 + label_studio/fsm/README.md | 268 ++++++++++++++ label_studio/fsm/__init__.py | 12 + label_studio/fsm/admin.py | 185 ++++++++++ label_studio/fsm/api.py | 257 ++++++++++++++ label_studio/fsm/apps.py | 67 ++++ label_studio/fsm/extension.py | 324 +++++++++++++++++ label_studio/fsm/integration.py | 278 +++++++++++++++ label_studio/fsm/migrations/0001_initial.py | 306 ++++++++++++++++ label_studio/fsm/migrations/__init__.py | 1 + label_studio/fsm/models.py | 334 ++++++++++++++++++ label_studio/fsm/serializers.py | 87 +++++ label_studio/fsm/state_choices.py | 139 ++++++++ label_studio/fsm/state_manager.py | 333 +++++++++++++++++ label_studio/fsm/tests/__init__.py | 1 + .../fsm/tests/test_fsm_integration.py | 271 ++++++++++++++ label_studio/fsm/tests/test_uuid7_utils.py | 164 +++++++++ label_studio/fsm/urls.py | 50 +++ label_studio/fsm/utils.py | 189 ++++++++++ poetry.lock | 44 ++- pyproject.toml | 1 + 22 files changed, 3309 insertions(+), 4 deletions(-) create mode 100644 label_studio/fsm/README.md create mode 100644 label_studio/fsm/__init__.py create mode 100644 label_studio/fsm/admin.py create mode 100644 label_studio/fsm/api.py create mode 100644 label_studio/fsm/apps.py create mode 100644 label_studio/fsm/extension.py create mode 100644 label_studio/fsm/integration.py create mode 100644 label_studio/fsm/migrations/0001_initial.py create mode 100644 label_studio/fsm/migrations/__init__.py create mode 100644 label_studio/fsm/models.py create mode 100644 label_studio/fsm/serializers.py create mode 100644 label_studio/fsm/state_choices.py create mode 100644 label_studio/fsm/state_manager.py create mode 100644 label_studio/fsm/tests/__init__.py create mode 100644 label_studio/fsm/tests/test_fsm_integration.py create mode 100644 label_studio/fsm/tests/test_uuid7_utils.py create mode 100644 label_studio/fsm/urls.py create mode 100644 label_studio/fsm/utils.py diff --git a/label_studio/core/settings/base.py b/label_studio/core/settings/base.py index 34cce0ce4ed9..63c644c63a87 100644 --- a/label_studio/core/settings/base.py +++ b/label_studio/core/settings/base.py @@ -232,6 +232,7 @@ 'ml_model_providers', 'jwt_auth', 'session_policy', + 'fsm', # Finite State Machine for entity state tracking ] MIDDLEWARE = [ diff --git a/label_studio/core/urls.py b/label_studio/core/urls.py index 23998217d0dc..4bdb57b18295 100644 --- a/label_studio/core/urls.py +++ b/label_studio/core/urls.py @@ -105,6 +105,7 @@ re_path(r'^api-auth/', include('rest_framework.urls', namespace='rest_framework')), re_path(r'^', include('jwt_auth.urls')), re_path(r'^', include('session_policy.urls')), + re_path(r'^', include('fsm.urls')), # Finite State Machine APIs path('docs/api/schema/', SpectacularAPIView.as_view(), name='schema'), path('docs/api/schema/swagger-ui/', SpectacularSwaggerView.as_view(url_name='schema'), name='swagger-ui'), path('docs/api/schema/redoc/', SpectacularRedocView.as_view(url_name='schema'), name='redoc'), diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md new file mode 100644 index 000000000000..f28150ca26d2 --- /dev/null +++ b/label_studio/fsm/README.md @@ -0,0 +1,268 @@ +# Label Studio FSM (Finite State Machine) + +Core finite state machine functionality for Label Studio that provides the foundation for state tracking across entities like Tasks, Annotations, and Projects. + +## Overview + +The Label Studio FSM system provides: + +- **Core Infrastructure**: Base state tracking models and managers +- **UUID7 Optimization**: Time-series optimized state records using UUID7 +- **Extension Mechanism**: Allows Label Studio Enterprise to extend functionality +- **Basic API**: REST endpoints for state management +- **Admin Interface**: Django admin integration for state inspection + +## Architecture + +### Core Components + +1. **BaseState**: Abstract model providing common state tracking functionality +2. **StateManager**: High-performance state management with caching +3. **Core State Models**: Task, Annotation, and Project state tracking +4. **Extension Registry**: Allows enterprise extensions to register additional functionality + +### Extension System + +The FSM system is designed to be extended by Label Studio Enterprise: + +```python +# Core provides foundation +from label_studio.fsm.models import BaseState +from label_studio.fsm.state_manager import StateManager + +# Enterprise extends with advanced features +class EnterpriseTaskState(BaseState): + # Additional enterprise-specific fields + organization_id = models.PositiveIntegerField(db_index=True) + # Advanced indexes and denormalization + +class EnterpriseStateManager(StateManager): + # Bulk operations, advanced caching, etc. + @classmethod + def bulk_get_states(cls, entities): + # Enterprise-specific bulk optimization + pass +``` + +## Usage + +### Basic State Management + +```python +from label_studio.fsm.state_manager import get_state_manager +from label_studio.tasks.models import Task + +# Get current state +StateManager = get_state_manager() +task = Task.objects.get(id=123) +current_state = StateManager.get_current_state(task) + +# Transition state +success = StateManager.transition_state( + entity=task, + new_state='IN_PROGRESS', + user=request.user, + reason='User started annotation work' +) + +# Get state history +history = StateManager.get_state_history(task, limit=10) +``` + +### Integration with Existing Models + +```python +# Add FSM functionality to existing models +from label_studio.fsm.integration import FSMIntegrationMixin + +class Task(FSMIntegrationMixin, BaseTask): + class Meta: + proxy = True + +# Now you can use FSM methods directly +task = Task.objects.get(id=123) +current_state = task.current_fsm_state +task.transition_fsm_state('COMPLETED', user=user) +``` + +### API Usage + +```bash +# Get current state +GET /api/fsm/task/123/current/ + +# Get state history +GET /api/fsm/task/123/history/?limit=10 + +# Transition state +POST /api/fsm/task/123/transition/ +{ + "new_state": "COMPLETED", + "reason": "Task completed by user" +} +``` + +## Dependencies + +The FSM system requires the `uuid-utils` library for UUID7 support: + +```bash +pip install uuid-utils>=0.11.0 +``` + +This dependency is automatically included in Label Studio's requirements. + +### Why UUID7? + +UUID7 provides significant performance benefits for time-series data like state transitions: + +- **Natural Time Ordering**: Records are naturally ordered by creation time without requiring additional indexes +- **Global Uniqueness**: Works across distributed systems and database shards +- **INSERT-only Architecture**: No UPDATE operations needed, maximizing concurrency +- **Time-based Partitioning**: Enables horizontal scaling to billions of records + +## Configuration + +### Django Settings + +Add the FSM app to your `INSTALLED_APPS`: + +```python +INSTALLED_APPS = [ + # ... other apps + 'label_studio.fsm', + # ... other apps +] +``` + +### Optional Settings + +```python +# FSM Configuration +FSM_CACHE_TTL = 300 # Cache timeout in seconds (default: 300) +FSM_AUTO_CREATE_STATES = False # Auto-create states on entity creation (default: False) +FSM_STATE_MANAGER_CLASS = None # Custom state manager class (default: None) + +# Enterprise Settings (when using Label Studio Enterprise) +FSM_ENABLE_BULK_OPERATIONS = True # Enable bulk operations (default: False) +FSM_CACHE_STATS_ENABLED = True # Enable cache statistics (default: False) +``` + +## Database Migrations + +Run migrations to create the FSM tables: + +```bash +python manage.py migrate fsm +``` + +This will create: +- `fsm_task_states`: Task state tracking +- `fsm_annotation_states`: Annotation state tracking +- `fsm_project_states`: Project state tracking + +## Performance Considerations + +### UUID7 Benefits + +The FSM system uses UUID7 for optimal time-series performance: + +- **Natural Time Ordering**: No need for `created_at` indexes +- **INSERT-only Architecture**: Maximum concurrency, no row locks +- **Global Uniqueness**: Supports distributed systems +- **Time-based Partitioning**: Scales to billions of records + +### Caching Strategy + +- **Write-through Caching**: Immediate consistency after state transitions +- **Configurable TTL**: Balance between performance and freshness +- **Cache Key Strategy**: Optimized for entity-based lookups + +### Indexes + +Critical indexes for performance: +- `(entity_id, id DESC)`: Current state lookup using UUID7 ordering +- `(entity_id, id)`: State history queries + +## Extension by Label Studio Enterprise + +Label Studio Enterprise extends this system with: + +1. **Advanced State Models**: Additional entities (Reviews, Assignments, etc.) +2. **Complex Workflows**: Review, arbitration, and approval flows +3. **Bulk Operations**: High-performance batch state transitions +4. **Enhanced Caching**: Multi-level caching with cache warming +5. **Analytics**: State-based reporting and metrics +6. **Denormalization**: Performance optimization with redundant fields + +### Enterprise Extension Example + +```python +# In Label Studio Enterprise +from label_studio.fsm.extension import BaseFSMExtension +from label_studio.fsm.models import register_state_model + +class EnterpriseExtension(BaseFSMExtension): + @classmethod + def initialize(cls): + # Register enterprise models + register_state_model('review', AnnotationReviewState) + register_state_model('assignment', TaskAssignmentState) + + @classmethod + def get_state_manager(cls): + return EnterpriseStateManager +``` + +## Monitoring and Debugging + +### Admin Interface + +Access state records via Django admin: +- `/admin/fsm/taskstate/` +- `/admin/fsm/annotationstate/` +- `/admin/fsm/projectstate/` + +### Logging + +FSM operations are logged at appropriate levels: +- `INFO`: Successful state transitions +- `ERROR`: Failed transitions and system errors +- `DEBUG`: Cache hits/misses and detailed operation info + +### Cache Statistics + +When `FSM_CACHE_STATS_ENABLED=True`, cache performance metrics are available for monitoring. + +## Migration from Existing Systems + +The FSM system can run alongside existing state management: + +1. **Parallel Operation**: FSM tracks states without affecting existing logic +2. **Gradual Migration**: Replace existing state checks with FSM calls over time +3. **Backfill Support**: Historical states can be backfilled from existing data + +## Testing + +Test the FSM system: + +```python +from label_studio.fsm.state_manager import StateManager +from label_studio.tasks.models import Task + +def test_task_state_transition(): + task = Task.objects.create(...) + + # Test initial state + assert StateManager.get_current_state(task) is None + + # Test transition + success = StateManager.transition_state(task, 'CREATED') + assert success + assert StateManager.get_current_state(task) == 'CREATED' + + # Test history + history = StateManager.get_state_history(task) + assert len(history) == 1 + assert history[0].state == 'CREATED' +``` \ No newline at end of file diff --git a/label_studio/fsm/__init__.py b/label_studio/fsm/__init__.py new file mode 100644 index 000000000000..c389fc213dfe --- /dev/null +++ b/label_studio/fsm/__init__.py @@ -0,0 +1,12 @@ +""" +Finite State Machine (FSM) core functionality for Label Studio. + +This package provides the core FSM infrastructure that can be extended +by Label Studio Enterprise and other applications. + +Core components: +- BaseState: Abstract model for all state tracking +- StateManager: High-performance state management +- Core state choices for basic entities +- UUID7 utilities for time-series optimization +""" diff --git a/label_studio/fsm/admin.py b/label_studio/fsm/admin.py new file mode 100644 index 000000000000..a55e1a7510ec --- /dev/null +++ b/label_studio/fsm/admin.py @@ -0,0 +1,185 @@ +""" +Core FSM admin interface for Label Studio. + +Provides basic admin interface for state management that can be extended +by Label Studio Enterprise with additional functionality. +""" + +from django.contrib import admin +from django.utils.html import format_html + +from .models import AnnotationState, ProjectState, TaskState + + +class BaseStateAdmin(admin.ModelAdmin): + """ + Base admin for state models. + + Provides common admin interface functionality for all state models. + Enterprise can extend this with additional features. + """ + + list_display = [ + 'entity_display', + 'state', + 'previous_state', + 'transition_name', + 'triggered_by', + 'created_at', + ] + list_filter = [ + 'state', + 'created_at', + 'transition_name', + ] + search_fields = [ + 'state', + 'previous_state', + 'transition_name', + 'reason', + ] + readonly_fields = [ + 'id', + 'created_at', + 'timestamp_from_uuid', + 'entity_display', + ] + ordering = ['-created_at'] + + # Limit displayed records for performance + list_per_page = 50 + list_max_show_all = 200 + + def entity_display(self, obj): + """Display the related entity information""" + try: + entity = obj.entity + return format_html( + '{} #{}', + f'/admin/{entity._meta.app_label}/{entity._meta.model_name}/{entity.pk}/change/', + entity._meta.verbose_name.title(), + entity.pk, + ) + except Exception: + return f'{obj._get_entity_name().title()} #{getattr(obj, f"{obj._get_entity_name()}_id", "?")}' + + entity_display.short_description = 'Entity' + # Note: admin_order_field is set dynamically in subclasses since model is not available here + + def timestamp_from_uuid(self, obj): + """Display timestamp extracted from UUID7""" + return obj.timestamp_from_uuid + + timestamp_from_uuid.short_description = 'UUID7 Timestamp' + + def has_add_permission(self, request): + """Disable manual creation of state records""" + return False + + def has_change_permission(self, request, obj=None): + """State records should be read-only""" + return False + + def has_delete_permission(self, request, obj=None): + """State records should not be deleted""" + return False + + +@admin.register(TaskState) +class TaskStateAdmin(BaseStateAdmin): + """Admin interface for Task state records""" + + list_display = BaseStateAdmin.list_display + ['task_id'] + list_filter = BaseStateAdmin.list_filter + ['state'] + search_fields = BaseStateAdmin.search_fields + ['task__id'] + + def task_id(self, obj): + """Display task ID with link""" + return format_html('Task #{}', obj.task.pk, obj.task.pk) + + task_id.short_description = 'Task' + task_id.admin_order_field = 'task__id' + + +@admin.register(AnnotationState) +class AnnotationStateAdmin(BaseStateAdmin): + """Admin interface for Annotation state records""" + + list_display = BaseStateAdmin.list_display + ['annotation_id', 'task_link'] + list_filter = BaseStateAdmin.list_filter + ['state'] + search_fields = BaseStateAdmin.search_fields + ['annotation__id'] + + def annotation_id(self, obj): + """Display annotation ID with link""" + return format_html( + 'Annotation #{}', obj.annotation.pk, obj.annotation.pk + ) + + annotation_id.short_description = 'Annotation' + annotation_id.admin_order_field = 'annotation__id' + + def task_link(self, obj): + """Display related task link""" + task = obj.annotation.task + return format_html('Task #{}', task.pk, task.pk) + + task_link.short_description = 'Task' + task_link.admin_order_field = 'annotation__task__id' + + +@admin.register(ProjectState) +class ProjectStateAdmin(BaseStateAdmin): + """Admin interface for Project state records""" + + list_display = BaseStateAdmin.list_display + ['project_id', 'project_title'] + list_filter = BaseStateAdmin.list_filter + ['state'] + search_fields = BaseStateAdmin.search_fields + ['project__id', 'project__title'] + + def project_id(self, obj): + """Display project ID with link""" + return format_html( + 'Project #{}', obj.project.pk, obj.project.pk + ) + + project_id.short_description = 'Project' + project_id.admin_order_field = 'project__id' + + def project_title(self, obj): + """Display project title""" + return obj.project.title[:50] + ('...' if len(obj.project.title) > 50 else '') + + project_title.short_description = 'Title' + project_title.admin_order_field = 'project__title' + + +# Admin actions for bulk operations (Enterprise can extend these) + + +def mark_states_as_reviewed(modeladmin, request, queryset): + """ + Admin action to mark state records as reviewed. + + This is a placeholder that Enterprise can extend with actual functionality. + """ + count = queryset.count() + modeladmin.message_user(request, f'{count} state records marked as reviewed.') + + +mark_states_as_reviewed.short_description = 'Mark selected states as reviewed' + + +def export_state_history(modeladmin, request, queryset): + """ + Admin action to export state history. + + This is a placeholder that Enterprise can extend with actual export functionality. + """ + count = queryset.count() + modeladmin.message_user(request, f'Export initiated for {count} state records.') + + +export_state_history.short_description = 'Export state history' + + +# Add actions to base admin (Enterprise can override) +BaseStateAdmin.actions = [mark_states_as_reviewed, export_state_history] diff --git a/label_studio/fsm/api.py b/label_studio/fsm/api.py new file mode 100644 index 000000000000..133c96e85ea7 --- /dev/null +++ b/label_studio/fsm/api.py @@ -0,0 +1,257 @@ +""" +Core FSM API endpoints for Label Studio. + +Provides basic API endpoints for state management that can be extended +by Label Studio Enterprise with additional functionality. +""" + +import logging + +from django.http import Http404 +from django.shortcuts import get_object_or_404 +from rest_framework import status, viewsets +from rest_framework.decorators import action +from rest_framework.response import Response + +from label_studio.core.permissions import AllPermissions + +from .models import get_state_model_for_entity +from .serializers import StateHistorySerializer, StateTransitionSerializer +from .state_manager import get_state_manager + +logger = logging.getLogger(__name__) + + +class FSMViewSet(viewsets.ViewSet): + """ + Core FSM API endpoints. + + Provides basic state management operations: + - Get current state + - Get state history + - Trigger state transitions + + Label Studio Enterprise can extend this with additional endpoints + for advanced state management operations. + """ + + permission_classes = [AllPermissions] + + def _get_entity_and_state_model(self, entity_type: str, entity_id: int): + """Helper to get entity instance and its state model""" + # Get the Django model class for the entity type + entity_model = self._get_entity_model(entity_type) + if not entity_model: + raise Http404(f'Unknown entity type: {entity_type}') + + # Get the entity instance + entity = get_object_or_404(entity_model, pk=entity_id) + + # Get the state model for this entity + state_model = get_state_model_for_entity(entity) + if not state_model: + raise Http404(f'No state model found for entity type: {entity_type}') + + return entity, state_model + + def _get_entity_model(self, entity_type: str): + """Get Django model class for entity type""" + from django.apps import apps + + # Map entity types to app.model + entity_mapping = { + 'task': 'tasks.Task', + 'annotation': 'tasks.Annotation', + 'project': 'projects.Project', + } + + model_path = entity_mapping.get(entity_type.lower()) + if not model_path: + return None + + app_label, model_name = model_path.split('.') + return apps.get_model(app_label, model_name) + + @action(detail=False, methods=['get'], url_path=r'(?P\w+)/(?P\d+)/current') + def current_state(self, request, entity_type=None, entity_id=None): + """ + Get current state for an entity. + + GET /api/fsm/{entity_type}/{entity_id}/current/ + + Returns: + { + "current_state": "IN_PROGRESS", + "entity_type": "task", + "entity_id": 123 + } + """ + try: + entity, state_model = self._get_entity_and_state_model(entity_type, int(entity_id)) + + # Get current state using the configured state manager + StateManager = get_state_manager() + current_state = StateManager.get_current_state(entity) + + return Response( + { + 'current_state': current_state, + 'entity_type': entity_type, + 'entity_id': int(entity_id), + } + ) + + except Exception as e: + logger.error(f'Error getting current state for {entity_type} {entity_id}: {e}') + return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=False, methods=['get'], url_path=r'(?P\w+)/(?P\d+)/history') + def state_history(self, request, entity_type=None, entity_id=None): + """ + Get state history for an entity. + + GET /api/fsm/{entity_type}/{entity_id}/history/ + + Query parameters: + - limit: Maximum number of history records (default: 100) + - include_context: Include context_data in response (default: false) + + Returns: + { + "count": 5, + "results": [ + { + "id": "uuid7-id", + "state": "COMPLETED", + "previous_state": "IN_PROGRESS", + "transition_name": "complete_task", + "triggered_by": "user@example.com", + "created_at": "2024-01-15T10:30:00Z", + "reason": "Task completed by user", + "context_data": {...} // if include_context=true + }, + ... + ] + } + """ + try: + entity, state_model = self._get_entity_and_state_model(entity_type, int(entity_id)) + + # Get query parameters + limit = min(int(request.query_params.get('limit', 100)), 1000) # Max 1000 + include_context = request.query_params.get('include_context', 'false').lower() == 'true' + + # Get state history using the configured state manager + StateManager = get_state_manager() + history = StateManager.get_state_history(entity, limit) + + # Serialize the results + serializer = StateHistorySerializer(history, many=True, context={'include_context': include_context}) + + return Response( + { + 'count': len(history), + 'results': serializer.data, + } + ) + + except Exception as e: + logger.error(f'Error getting state history for {entity_type} {entity_id}: {e}') + return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + @action(detail=False, methods=['post'], url_path=r'(?P\w+)/(?P\d+)/transition') + def transition_state(self, request, entity_type=None, entity_id=None): + """ + Trigger a state transition for an entity. + + POST /api/fsm/{entity_type}/{entity_id}/transition/ + + Request body: + { + "new_state": "COMPLETED", + "transition_name": "complete_task", // optional + "reason": "Task completed by user", // optional + "context": { // optional + "assignment_id": 456 + } + } + + Returns: + { + "success": true, + "previous_state": "IN_PROGRESS", + "new_state": "COMPLETED", + "entity_type": "task", + "entity_id": 123 + } + """ + try: + entity, state_model = self._get_entity_and_state_model(entity_type, int(entity_id)) + + # Validate request data + serializer = StateTransitionSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + + data = serializer.validated_data + new_state = data['new_state'] + transition_name = data.get('transition_name') + reason = data.get('reason', '') + context = data.get('context', {}) + + # Get current state for response + StateManager = get_state_manager() + previous_state = StateManager.get_current_state(entity) + + # Perform state transition + success = StateManager.transition_state( + entity=entity, + new_state=new_state, + transition_name=transition_name, + user=request.user, + context=context, + reason=reason, + ) + + if success: + return Response( + { + 'success': True, + 'previous_state': previous_state, + 'new_state': new_state, + 'entity_type': entity_type, + 'entity_id': int(entity_id), + } + ) + else: + return Response({'error': 'State transition failed'}, status=status.HTTP_400_BAD_REQUEST) + + except Exception as e: + logger.error(f'Error transitioning state for {entity_type} {entity_id}: {e}') + return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +# Extension point for Label Studio Enterprise +class ExtendedFSMViewSet(FSMViewSet): + """ + Extension point for Label Studio Enterprise. + + Enterprise can override this class to add: + - Bulk state operations + - Advanced state queries + - Enterprise-specific endpoints + - Enhanced permissions and validation + + Example Enterprise usage: + class EnterpriseFSMViewSet(ExtendedFSMViewSet): + @action(detail=False, methods=['post']) + def bulk_transition(self, request): + # Enterprise bulk transition endpoint + pass + + @action(detail=False, methods=['get']) + def state_analytics(self, request): + # Enterprise state analytics endpoint + pass + """ + + pass diff --git a/label_studio/fsm/apps.py b/label_studio/fsm/apps.py new file mode 100644 index 000000000000..a2fd89e66cf3 --- /dev/null +++ b/label_studio/fsm/apps.py @@ -0,0 +1,67 @@ +"""FSM Django App Configuration""" + +import logging + +from django.apps import AppConfig + +logger = logging.getLogger(__name__) + + +class FsmConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'label_studio.fsm' + verbose_name = 'Finite State Machine' + + def ready(self): + """Initialize FSM system when Django app is ready""" + # Initialize extension system + self._initialize_extensions() + + # Set up signal handlers for automatic state creation + self._setup_signals() + + logger.info('FSM system initialized') + + def _initialize_extensions(self): + """Initialize FSM extension system""" + try: + from .extension import ( + auto_register_enterprise_choices, + auto_register_enterprise_models, + extension_registry, + ) + + # Load configured extensions + extension_registry.load_extensions() + + # Auto-register enterprise models if available + auto_register_enterprise_models() + auto_register_enterprise_choices() + + except Exception as e: + logger.error(f'Failed to initialize FSM extensions: {e}') + + def _setup_signals(self): + """Set up signal handlers for automatic state creation""" + try: + from django.conf import settings + from django.db.models.signals import post_save + + # Only set up signals if enabled in settings + if getattr(settings, 'FSM_AUTO_CREATE_STATES', False): + from label_studio.projects.models import Project + + # Import models + from label_studio.tasks.models import Annotation, Task + + from .integration import handle_annotation_created, handle_project_created, handle_task_created + + # Connect signal handlers + post_save.connect(handle_task_created, sender=Task) + post_save.connect(handle_annotation_created, sender=Annotation) + post_save.connect(handle_project_created, sender=Project) + + logger.info('FSM signal handlers registered') + + except Exception as e: + logger.error(f'Failed to set up FSM signals: {e}') diff --git a/label_studio/fsm/extension.py b/label_studio/fsm/extension.py new file mode 100644 index 000000000000..6c8f2e3eebc9 --- /dev/null +++ b/label_studio/fsm/extension.py @@ -0,0 +1,324 @@ +""" +Extension mechanism for Label Studio Enterprise FSM. + +This module provides the hooks and interfaces that allow Label Studio Enterprise +to extend the core FSM functionality with advanced features. +""" + +import logging +from typing import Any, Dict, List, Type + +from django.apps import apps +from django.conf import settings + +from .models import BaseState, register_state_model +from .state_choices import register_state_choices +from .state_manager import StateManager + +logger = logging.getLogger(__name__) + + +class FSMExtensionRegistry: + """ + Registry for FSM extensions that allows enterprise features to be + dynamically loaded and integrated with the core FSM system. + """ + + def __init__(self): + self._extensions = {} + self._state_managers = {} + self._loaded = False + + def register_extension(self, name: str, extension_class): + """ + Register an FSM extension. + + Args: + name: Unique name for the extension + extension_class: Class implementing the extension + """ + self._extensions[name] = extension_class + logger.info(f'Registered FSM extension: {name}') + + def get_extension(self, name: str): + """Get a registered extension by name""" + return self._extensions.get(name) + + def load_extensions(self): + """ + Load FSM extensions from settings. + + Called during Django app startup to load enterprise extensions. + """ + if self._loaded: + return + + extensions_setting = getattr(settings, 'FSM_EXTENSIONS', []) + for extension_config in extensions_setting: + self._load_extension(extension_config) + + self._loaded = True + logger.info(f'Loaded {len(self._extensions)} FSM extensions') + + def _load_extension(self, config: Dict[str, Any]): + """Load a single extension from configuration""" + try: + name = config['name'] + class_path = config['class'] + + # Import the extension class + module_name, class_name = class_path.rsplit('.', 1) + module = __import__(module_name, fromlist=[class_name]) + extension_class = getattr(module, class_name) + + # Register the extension + self.register_extension(name, extension_class) + + # Initialize the extension if it has an init method + if hasattr(extension_class, 'initialize'): + extension_class.initialize() + + except Exception as e: + logger.error(f'Failed to load FSM extension {config}: {e}') + + +# Global extension registry +extension_registry = FSMExtensionRegistry() + + +class BaseFSMExtension: + """ + Base class for FSM extensions. + + Enterprise extensions should inherit from this class to ensure + compatibility with the core FSM system. + """ + + @classmethod + def initialize(cls): + """ + Initialize the extension. + + Called when the extension is loaded. Override to perform + setup tasks like registering state models and choices. + """ + pass + + @classmethod + def register_models(cls): + """ + Register state models with the core FSM system. + + Override to register enterprise-specific state models. + + Example: + register_state_model('review', AnnotationReviewState) + register_state_model('assignment', TaskAssignmentState) + """ + pass + + @classmethod + def register_choices(cls): + """ + Register state choices with the core FSM system. + + Override to register enterprise-specific state choices. + + Example: + register_state_choices('review', ReviewStateChoices) + register_state_choices('assignment', AssignmentStateChoices) + """ + pass + + @classmethod + def get_state_manager(cls) -> Type[StateManager]: + """ + Get the state manager class for this extension. + + Override to provide enterprise-specific state manager. + + Returns: + StateManager class to use + """ + return StateManager + + +class EnterpriseExtensionMixin: + """ + Mixin for enterprise extensions that provides common functionality + for extending the core FSM system. + """ + + @classmethod + def extend_state_model(cls, entity_name: str, base_model_class: Type[BaseState]): + """ + Helper to create extended state models. + + Args: + entity_name: Name of the entity (e.g., 'task', 'annotation') + base_model_class: Base state model class to extend + + Returns: + Extended model class + """ + # This would be used by enterprise to add denormalized fields, + # additional indexes, and enterprise-specific functionality + pass + + @classmethod + def extend_state_choices(cls, base_choices_class, additional_choices: List[tuple]): + """ + Helper to extend state choices with additional states. + + Args: + base_choices_class: Base TextChoices class + additional_choices: List of (value, label) tuples for new states + + Returns: + Extended choices class + """ + # This would be used by enterprise to add additional states + # to the core state choices + pass + + +# Configuration helpers for enterprise setup + + +def configure_fsm_for_enterprise(): + """ + Configure FSM system for Label Studio Enterprise. + + This function should be called by enterprise during app initialization + to set up the FSM system with enterprise-specific configuration. + """ + # Load enterprise extensions + extension_registry.load_extensions() + + # Set enterprise-specific settings + if not hasattr(settings, 'FSM_CACHE_TTL'): + settings.FSM_CACHE_TTL = 300 # 5 minutes + + if not hasattr(settings, 'FSM_ENABLE_BULK_OPERATIONS'): + settings.FSM_ENABLE_BULK_OPERATIONS = True + + logger.info('FSM system configured for Label Studio Enterprise') + + +def get_enterprise_state_manager(): + """ + Get the enterprise state manager if available. + + Returns the enterprise-specific state manager class if one is registered, + otherwise returns the core StateManager. + """ + # Check if enterprise has registered a state manager + enterprise_ext = extension_registry.get_extension('enterprise') + if enterprise_ext: + return enterprise_ext.get_state_manager() + + # Fall back to core state manager + return StateManager + + +# Settings for FSM extensions +def get_fsm_settings(): + """Get FSM-related settings with defaults""" + return { + 'cache_ttl': getattr(settings, 'FSM_CACHE_TTL', 300), + 'enable_bulk_operations': getattr(settings, 'FSM_ENABLE_BULK_OPERATIONS', False), + 'enable_cache_stats': getattr(settings, 'FSM_CACHE_STATS_ENABLED', False), + 'state_manager_class': getattr(settings, 'FSM_STATE_MANAGER_CLASS', None), + 'extensions': getattr(settings, 'FSM_EXTENSIONS', []), + } + + +# Integration helpers for model registration + + +def auto_register_enterprise_models(): + """ + Automatically register enterprise state models. + + Scans for state models in enterprise apps and registers them + with the core FSM system. + """ + try: + # Only attempt if enterprise is available + if apps.is_installed('label_studio_enterprise.fsm'): + from label_studio_enterprise.fsm.models import ( + AnnotationDraftState, + AnnotationReviewState, + CommentState, + TaskAssignmentState, + TaskLockState, + ) + from label_studio_enterprise.fsm.models import ( + AnnotationState as EnterpriseAnnotationState, + ) + from label_studio_enterprise.fsm.models import ( + ProjectState as EnterpriseProjectState, + ) + from label_studio_enterprise.fsm.models import ( + TaskState as EnterpriseTaskState, + ) + + # Register enterprise state models + register_state_model('task', EnterpriseTaskState) + register_state_model('annotation', EnterpriseAnnotationState) + register_state_model('project', EnterpriseProjectState) + register_state_model('annotationreview', AnnotationReviewState) + register_state_model('taskassignment', TaskAssignmentState) + register_state_model('annotationdraft', AnnotationDraftState) + register_state_model('comment', CommentState) + register_state_model('tasklock', TaskLockState) + + logger.info('Auto-registered enterprise state models') + + except ImportError: + # Enterprise not available, use core models + logger.debug('Enterprise FSM models not available, using core models') + + +def auto_register_enterprise_choices(): + """ + Automatically register enterprise state choices. + + Scans for state choices in enterprise apps and registers them + with the core FSM system. + """ + try: + # Only attempt if enterprise is available + if apps.is_installed('label_studio_enterprise.fsm'): + from label_studio_enterprise.fsm.state_choices import ( + AnnotationDraftStateChoices, + AssignmentStateChoices, + CommentStateChoices, + ReviewStateChoices, + TaskLockStateChoices, + ) + from label_studio_enterprise.fsm.state_choices import ( + AnnotationStateChoices as EnterpriseAnnotationStateChoices, + ) + from label_studio_enterprise.fsm.state_choices import ( + ProjectStateChoices as EnterpriseProjectStateChoices, + ) + from label_studio_enterprise.fsm.state_choices import ( + TaskStateChoices as EnterpriseTaskStateChoices, + ) + + # Register enterprise state choices + register_state_choices('task', EnterpriseTaskStateChoices) + register_state_choices('annotation', EnterpriseAnnotationStateChoices) + register_state_choices('project', EnterpriseProjectStateChoices) + register_state_choices('review', ReviewStateChoices) + register_state_choices('assignment', AssignmentStateChoices) + register_state_choices('annotationdraft', AnnotationDraftStateChoices) + register_state_choices('comment', CommentStateChoices) + register_state_choices('tasklock', TaskLockStateChoices) + + logger.info('Auto-registered enterprise state choices') + + except ImportError: + # Enterprise not available, use core choices + logger.debug('Enterprise FSM choices not available, using core choices') diff --git a/label_studio/fsm/integration.py b/label_studio/fsm/integration.py new file mode 100644 index 000000000000..0ae365bb7d79 --- /dev/null +++ b/label_studio/fsm/integration.py @@ -0,0 +1,278 @@ +""" +Integration helpers for connecting FSM with existing Label Studio models. + +This module provides helper methods and mixins that can be added to existing +Label Studio models to integrate them with the FSM system. +""" + +import logging +from typing import Optional + +from django.db import models + +from .state_manager import get_state_manager + +logger = logging.getLogger(__name__) + + +class FSMIntegrationMixin: + """ + Mixin to add FSM functionality to existing Label Studio models. + + This mixin can be added to Task, Annotation, and Project models to provide + convenient methods for state management without modifying the core models. + + Example usage in Enterprise: + # In LSE models.py: + from label_studio.fsm.integration import FSMIntegrationMixin + from label_studio.tasks.models import Task as CoreTask + + class Task(FSMIntegrationMixin, CoreTask): + class Meta: + proxy = True + """ + + @property + def current_fsm_state(self) -> Optional[str]: + """Get current FSM state for this entity""" + StateManager = get_state_manager() + return StateManager.get_current_state(self) + + def transition_fsm_state( + self, new_state: str, user=None, transition_name: str = None, reason: str = '', context: dict = None + ) -> bool: + """ + Transition this entity to a new FSM state. + + Args: + new_state: Target state + user: User triggering the transition + transition_name: Name of transition method + reason: Human-readable reason + context: Additional context data + + Returns: + True if transition succeeded + """ + StateManager = get_state_manager() + return StateManager.transition_state( + entity=self, + new_state=new_state, + user=user, + transition_name=transition_name, + reason=reason, + context=context or {}, + ) + + def get_fsm_state_history(self, limit: int = 100): + """Get FSM state history for this entity""" + StateManager = get_state_manager() + return StateManager.get_state_history(self, limit) + + def is_in_fsm_state(self, state: str) -> bool: + """Check if entity is currently in the specified state""" + return self.current_fsm_state == state + + def has_fsm_state_history(self) -> bool: + """Check if entity has any FSM state records""" + return self.current_fsm_state is not None + + +def add_fsm_to_model(model_class): + """ + Class decorator to add FSM functionality to existing models. + + This provides an alternative to inheritance for adding FSM capabilities. + + Example: + from label_studio.fsm.integration import add_fsm_to_model + from label_studio.tasks.models import Task + + @add_fsm_to_model + class Task(Task): + class Meta: + proxy = True + """ + + def current_fsm_state_property(self): + """Get current FSM state for this entity""" + StateManager = get_state_manager() + return StateManager.get_current_state(self) + + def transition_fsm_state_method( + self, new_state: str, user=None, transition_name: str = None, reason: str = '', context: dict = None + ): + """Transition this entity to a new FSM state""" + StateManager = get_state_manager() + return StateManager.transition_state( + entity=self, + new_state=new_state, + user=user, + transition_name=transition_name, + reason=reason, + context=context or {}, + ) + + def get_fsm_state_history_method(self, limit: int = 100): + """Get FSM state history for this entity""" + StateManager = get_state_manager() + return StateManager.get_state_history(self, limit) + + # Add methods as properties/methods to the class + model_class.current_fsm_state = property(current_fsm_state_property) + model_class.transition_fsm_state = transition_fsm_state_method + model_class.get_fsm_state_history = get_fsm_state_history_method + + return model_class + + +# Signal handlers for automatic state transitions + + +def handle_task_created(sender, instance, created, **kwargs): + """ + Signal handler to automatically create initial state when a task is created. + + Connect this to the Task model's post_save signal: + from django.db.models.signals import post_save + from label_studio.tasks.models import Task + from label_studio.fsm.integration import handle_task_created + + post_save.connect(handle_task_created, sender=Task) + """ + if created: + try: + StateManager = get_state_manager() + StateManager.transition_state( + entity=instance, + new_state='CREATED', + transition_name='create_task', + reason='Task created automatically', + ) + logger.info(f'Created initial FSM state for Task {instance.pk}') + except Exception as e: + logger.error(f'Failed to create initial FSM state for Task {instance.pk}: {e}') + + +def handle_annotation_created(sender, instance, created, **kwargs): + """ + Signal handler to automatically create initial state when an annotation is created. + + Connect this to the Annotation model's post_save signal: + from django.db.models.signals import post_save + from label_studio.tasks.models import Annotation + from label_studio.fsm.integration import handle_annotation_created + + post_save.connect(handle_annotation_created, sender=Annotation) + """ + if created: + try: + StateManager = get_state_manager() + StateManager.transition_state( + entity=instance, + new_state='DRAFT', + transition_name='create_annotation', + reason='Annotation created automatically', + ) + logger.info(f'Created initial FSM state for Annotation {instance.pk}') + except Exception as e: + logger.error(f'Failed to create initial FSM state for Annotation {instance.pk}: {e}') + + +def handle_project_created(sender, instance, created, **kwargs): + """ + Signal handler to automatically create initial state when a project is created. + + Connect this to the Project model's post_save signal: + from django.db.models.signals import post_save + from label_studio.projects.models import Project + from label_studio.fsm.integration import handle_project_created + + post_save.connect(handle_project_created, sender=Project) + """ + if created: + try: + StateManager = get_state_manager() + StateManager.transition_state( + entity=instance, + new_state='CREATED', + transition_name='create_project', + reason='Project created automatically', + ) + logger.info(f'Created initial FSM state for Project {instance.pk}') + except Exception as e: + logger.error(f'Failed to create initial FSM state for Project {instance.pk}: {e}') + + +# Utility functions for model extensions + + +def get_entities_by_state(model_class, state: str, limit: int = 100): + """ + Get entities that are currently in a specific state. + + Args: + model_class: Django model class (e.g., Task, Annotation, Project) + state: State to filter by + limit: Maximum number of entities to return + + Returns: + QuerySet of entities in the specified state + + Example: + from label_studio.tasks.models import Task + from label_studio.fsm.integration import get_entities_by_state + + completed_tasks = get_entities_by_state(Task, 'COMPLETED', limit=50) + """ + from .models import get_state_model_for_entity + + # Create a dummy instance to get the state model + dummy_instance = model_class() + state_model = get_state_model_for_entity(dummy_instance) + + if not state_model: + return model_class.objects.none() + + # Get entity IDs that have the specified current state + f'{model_class._meta.model_name.lower()}_id' + + # This is a simplified version - Enterprise can optimize with window functions + current_state_subquery = ( + state_model.objects.filter(**{f'{model_class._meta.model_name.lower()}__pk': models.OuterRef('pk')}) + .order_by('-id') + .values('state')[:1] + ) + + return model_class.objects.annotate(current_state=models.Subquery(current_state_subquery)).filter( + current_state=state + )[:limit] + + +def bulk_transition_entities(entities, new_state: str, user=None, **kwargs): + """ + Bulk transition multiple entities to the same state. + + Basic implementation that Enterprise can optimize with bulk operations. + + Args: + entities: List of entity instances + new_state: Target state for all entities + user: User triggering the transitions + **kwargs: Additional arguments for transition_state + + Returns: + List of (entity, success) tuples + """ + StateManager = get_state_manager() + results = [] + + for entity in entities: + try: + success = StateManager.transition_state(entity=entity, new_state=new_state, user=user, **kwargs) + results.append((entity, success)) + except Exception as e: + logger.error(f'Failed to transition {entity._meta.model_name} {entity.pk}: {e}') + results.append((entity, False)) + + return results diff --git a/label_studio/fsm/migrations/0001_initial.py b/label_studio/fsm/migrations/0001_initial.py new file mode 100644 index 000000000000..9cc43da2a8e7 --- /dev/null +++ b/label_studio/fsm/migrations/0001_initial.py @@ -0,0 +1,306 @@ +# Generated by Django 4.2.16 on 2024-01-15 12:00 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + +import label_studio.fsm.utils + + +class Migration(migrations.Migration): + """ + Initial migration for core FSM functionality in Label Studio. + + Creates the base state tracking infrastructure with UUID7 optimization + for high-performance time-series data. + """ + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('tasks', '0055_task_proj_octlen_idx_async'), # Latest task migration + ('projects', '0030_project_search_vector_index'), # Latest project migration + ] + + operations = [ + migrations.CreateModel( + name='TaskState', + fields=[ + ( + 'id', + models.UUIDField( + default=label_studio.fsm.utils.generate_uuid7, + editable=False, + help_text='UUID7 provides natural time ordering and global uniqueness', + primary_key=True, + serialize=False, + ), + ), + ('state', models.CharField( + choices=[ + ('CREATED', 'Created'), + ('IN_PROGRESS', 'In Progress'), + ('COMPLETED', 'Completed') + ], + db_index=True, + help_text='Current state of the entity', + max_length=50 + )), + ( + 'previous_state', + models.CharField( + blank=True, + help_text='Previous state before this transition', + max_length=50, + null=True, + ), + ), + ( + 'transition_name', + models.CharField( + blank=True, + help_text='Name of the transition method that triggered this state change', + max_length=100, + null=True, + ), + ), + ( + 'context_data', + models.JSONField( + default=dict, + help_text='Additional context data for this transition (e.g., validation results, external IDs)', + ), + ), + ( + 'reason', + models.TextField( + blank=True, help_text='Human-readable reason for this state transition' + ), + ), + ( + 'created_at', + models.DateTimeField( + auto_now_add=True, + db_index=False, + help_text='Human-readable timestamp for debugging (UUID7 id contains precise timestamp)', + ), + ), + ( + 'task', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name='fsm_states', + to='tasks.task', + ), + ), + ( + 'triggered_by', + models.ForeignKey( + help_text='User who triggered this state transition', + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + 'db_table': 'fsm_task_states', + 'ordering': ['-id'], + 'get_latest_by': 'id', + }, + ), + migrations.CreateModel( + name='ProjectState', + fields=[ + ( + 'id', + models.UUIDField( + default=label_studio.fsm.utils.generate_uuid7, + editable=False, + help_text='UUID7 provides natural time ordering and global uniqueness', + primary_key=True, + serialize=False, + ), + ), + ('state', models.CharField( + choices=[ + ('CREATED', 'Created'), + ('PUBLISHED', 'Published'), + ('IN_PROGRESS', 'In Progress'), + ('COMPLETED', 'Completed') + ], + db_index=True, + help_text='Current state of the entity', + max_length=50 + )), + ( + 'previous_state', + models.CharField( + blank=True, + help_text='Previous state before this transition', + max_length=50, + null=True, + ), + ), + ( + 'transition_name', + models.CharField( + blank=True, + help_text='Name of the transition method that triggered this state change', + max_length=100, + null=True, + ), + ), + ( + 'context_data', + models.JSONField( + default=dict, + help_text='Additional context data for this transition (e.g., validation results, external IDs)', + ), + ), + ( + 'reason', + models.TextField( + blank=True, help_text='Human-readable reason for this state transition' + ), + ), + ( + 'created_at', + models.DateTimeField( + auto_now_add=True, + db_index=False, + help_text='Human-readable timestamp for debugging (UUID7 id contains precise timestamp)', + ), + ), + ( + 'project', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name='fsm_states', + to='projects.project', + ), + ), + ( + 'triggered_by', + models.ForeignKey( + help_text='User who triggered this state transition', + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + 'db_table': 'fsm_project_states', + 'ordering': ['-id'], + 'get_latest_by': 'id', + }, + ), + migrations.CreateModel( + name='AnnotationState', + fields=[ + ( + 'id', + models.UUIDField( + default=label_studio.fsm.utils.generate_uuid7, + editable=False, + help_text='UUID7 provides natural time ordering and global uniqueness', + primary_key=True, + serialize=False, + ), + ), + ('state', models.CharField( + choices=[ + ('DRAFT', 'Draft'), + ('SUBMITTED', 'Submitted'), + ('COMPLETED', 'Completed') + ], + db_index=True, + help_text='Current state of the entity', + max_length=50 + )), + ( + 'previous_state', + models.CharField( + blank=True, + help_text='Previous state before this transition', + max_length=50, + null=True, + ), + ), + ( + 'transition_name', + models.CharField( + blank=True, + help_text='Name of the transition method that triggered this state change', + max_length=100, + null=True, + ), + ), + ( + 'context_data', + models.JSONField( + default=dict, + help_text='Additional context data for this transition (e.g., validation results, external IDs)', + ), + ), + ( + 'reason', + models.TextField( + blank=True, help_text='Human-readable reason for this state transition' + ), + ), + ( + 'created_at', + models.DateTimeField( + auto_now_add=True, + db_index=False, + help_text='Human-readable timestamp for debugging (UUID7 id contains precise timestamp)', + ), + ), + ( + 'annotation', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name='fsm_states', + to='tasks.annotation', + ), + ), + ( + 'triggered_by', + models.ForeignKey( + help_text='User who triggered this state transition', + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + 'db_table': 'fsm_annotation_states', + 'ordering': ['-id'], + 'get_latest_by': 'id', + }, + ), + # Create indexes for optimal performance with UUID7 + migrations.RunSQL( + sql=[ + # Task state indexes - critical for current state lookups + "CREATE INDEX CONCURRENTLY IF NOT EXISTS fsm_task_current_state_idx ON fsm_task_states (task_id, id DESC);", + "CREATE INDEX CONCURRENTLY IF NOT EXISTS fsm_task_history_idx ON fsm_task_states (task_id, id);", + + # Annotation state indexes + "CREATE INDEX CONCURRENTLY IF NOT EXISTS fsm_anno_current_state_idx ON fsm_annotation_states (annotation_id, id DESC);", + + # Project state indexes + "CREATE INDEX CONCURRENTLY IF NOT EXISTS fsm_proj_current_state_idx ON fsm_project_states (project_id, id DESC);", + ], + reverse_sql=[ + "DROP INDEX IF EXISTS fsm_task_current_state_idx;", + "DROP INDEX IF EXISTS fsm_task_history_idx;", + "DROP INDEX IF EXISTS fsm_anno_current_state_idx;", + "DROP INDEX IF EXISTS fsm_proj_current_state_idx;", + ] + ), + ] \ No newline at end of file diff --git a/label_studio/fsm/migrations/__init__.py b/label_studio/fsm/migrations/__init__.py new file mode 100644 index 000000000000..cbd9975c3b07 --- /dev/null +++ b/label_studio/fsm/migrations/__init__.py @@ -0,0 +1 @@ +# Django migrations package \ No newline at end of file diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py new file mode 100644 index 000000000000..901e64559d56 --- /dev/null +++ b/label_studio/fsm/models.py @@ -0,0 +1,334 @@ +""" +Core FSM models for Label Studio. + +Provides the base infrastructure for state tracking that can be extended +by Label Studio Enterprise and other applications. +""" + +from datetime import datetime +from typing import Optional + +from django.conf import settings +from django.db import models +from django.db.models import UUIDField + +from .state_choices import ( + AnnotationStateChoices, + ProjectStateChoices, + TaskStateChoices, +) +from .utils import UUID7Field, generate_uuid7, timestamp_from_uuid7 + + +class BaseState(models.Model): + """ + Abstract base class for all state models using UUID7 for optimal time-series performance. + + This is the core of the FSM system, providing: + - UUID7 primary key with natural time ordering + - Standard state transition metadata + - Audit trail information + - Context data storage + - Performance-optimized helper methods + + Benefits of this architecture: + - INSERT-only operations for maximum concurrency + - Natural time ordering eliminates need for created_at indexes + - Global uniqueness enables distributed system support + - Time-based partitioning for billion-record scalability + - Complete audit trail by design + + Usage: + # In Label Studio Enterprise: + class TaskState(BaseState): + task = models.ForeignKey('tasks.Task', ...) + state = models.CharField(choices=EnterpriseTaskStateChoices.choices, ...) + # Additional enterprise-specific fields + """ + + # UUID7 Primary Key - provides natural time ordering and global uniqueness + id = UUIDField( + primary_key=True, + default=generate_uuid7, + editable=False, + help_text='UUID7 provides natural time ordering and global uniqueness', + ) + + # Core State Fields + state = models.CharField(max_length=50, db_index=True, help_text='Current state of the entity') + previous_state = models.CharField( + max_length=50, null=True, blank=True, help_text='Previous state before this transition' + ) + + # Transition Metadata + transition_name = models.CharField( + max_length=100, + null=True, + blank=True, + help_text='Name of the transition method that triggered this state change', + ) + triggered_by = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=models.SET_NULL, + null=True, + help_text='User who triggered this state transition', + ) + + # Context & Audit + context_data = models.JSONField( + default=dict, help_text='Additional context data for this transition (e.g., validation results, external IDs)' + ) + reason = models.TextField(blank=True, help_text='Human-readable reason for this state transition') + + # Timestamp (redundant with UUID7 but useful for human readability) + created_at = models.DateTimeField( + auto_now_add=True, + db_index=False, # UUID7 provides natural ordering, no index needed + help_text='Human-readable timestamp for debugging (UUID7 id contains precise timestamp)', + ) + + class Meta: + abstract = True + # UUID7 provides natural ordering, reducing index requirements + ordering = ['-id'] # Most recent first + get_latest_by = 'id' + + def __str__(self): + entity_id = getattr(self, f'{self._get_entity_name()}_id', 'unknown') + return f'{self._get_entity_name().title()} {entity_id}: {self.previous_state} → {self.state}' + + @property + def entity(self): + """Get the related entity object""" + entity_name = self._get_entity_name() + return getattr(self, entity_name) + + @property + def timestamp_from_uuid(self) -> datetime: + """Extract timestamp from UUID7 ID""" + return timestamp_from_uuid7(self.id) + + @property + def is_terminal_state(self) -> bool: + """ + Check if this is a terminal state (no outgoing transitions). + + Override in subclasses with specific terminal states. + """ + return False + + def _get_entity_name(self) -> str: + """Extract entity name from model name (e.g., TaskState → task)""" + model_name = self.__class__.__name__ + if model_name.endswith('State'): + return model_name[:-5].lower() + return 'entity' + + @classmethod + def get_current_state(cls, entity) -> Optional['BaseState']: + """ + Get current state using UUID7 natural ordering. + + Uses UUID7's natural time ordering to efficiently find the latest state + without requiring created_at indexes or complex queries. + """ + entity_field = f'{cls._get_entity_field_name()}' + return cls.objects.filter(**{entity_field: entity}).order_by('-id').first() + + @classmethod + def get_current_state_value(cls, entity) -> Optional[str]: + """Get current state value as string""" + current_state = cls.get_current_state(entity) + return current_state.state if current_state else None + + @classmethod + def get_state_history(cls, entity, limit: int = 100): + """Get complete state history for an entity""" + entity_field = f'{cls._get_entity_field_name()}' + return cls.objects.filter(**{entity_field: entity}).order_by('-id')[:limit] + + @classmethod + def get_states_in_range(cls, entity, start_time: datetime, end_time: datetime): + """ + Efficient time-range queries using UUID7. + + Uses UUID7's embedded timestamp for direct time-based filtering + without requiring timestamp indexes. + """ + entity_field = f'{cls._get_entity_field_name()}' + queryset = cls.objects.filter(**{entity_field: entity}) + return UUID7Field.filter_by_time_range(queryset, start_time, end_time).order_by('id') + + @classmethod + def get_states_since(cls, entity, since: datetime): + """Get all states since a specific timestamp""" + entity_field = f'{cls._get_entity_field_name()}' + queryset = cls.objects.filter(**{entity_field: entity}) + return UUID7Field.filter_since_time(queryset, since).order_by('id') + + @classmethod + def _get_entity_field_name(cls) -> str: + """Get the foreign key field name for the entity""" + model_name = cls.__name__ + if model_name.endswith('State'): + return model_name[:-5].lower() + return 'entity' + + +# Core state models for basic Label Studio entities +# These provide the foundation that Enterprise can extend + + +class TaskState(BaseState): + """ + Core task state tracking for Label Studio. + + Provides basic task state management with: + - Simple 3-state workflow (CREATED → IN_PROGRESS → COMPLETED) + - High-performance queries with UUID7 ordering + - Extensible design for enterprise features + + Label Studio Enterprise extends this with: + - Additional workflow states (review, arbitration) + - Denormalized fields for performance + - Advanced state transition logic + """ + + # Entity Relationship + task = models.ForeignKey('tasks.Task', on_delete=models.CASCADE, related_name='fsm_states') + + # Override state field to add choices constraint + state = models.CharField(max_length=50, choices=TaskStateChoices.choices, db_index=True) + + class Meta: + db_table = 'fsm_task_states' + indexes = [ + # Critical: Latest state lookup using UUID7 ordering + models.Index(fields=['task_id', '-id'], name='fsm_task_current_state_idx'), + # History queries + models.Index(fields=['task_id', 'id'], name='fsm_task_history_idx'), + ] + ordering = ['-id'] + + @property + def is_terminal_state(self) -> bool: + """Check if this is a terminal task state""" + return self.state == TaskStateChoices.COMPLETED + + +class AnnotationState(BaseState): + """ + Core annotation state tracking for Label Studio. + + Provides basic annotation state management with: + - Simple 3-state workflow (DRAFT → SUBMITTED → COMPLETED) + - Draft and submission tracking + - Extensible design for enterprise review workflows + """ + + # Entity Relationship + annotation = models.ForeignKey('tasks.Annotation', on_delete=models.CASCADE, related_name='fsm_states') + + # Override state field to add choices constraint + state = models.CharField(max_length=50, choices=AnnotationStateChoices.choices, db_index=True) + + class Meta: + db_table = 'fsm_annotation_states' + indexes = [ + # Critical: Latest state lookup + models.Index(fields=['annotation_id', '-id'], name='fsm_anno_current_state_idx'), + ] + ordering = ['-id'] + + @property + def is_terminal_state(self) -> bool: + """Check if this is a terminal annotation state""" + return self.state == AnnotationStateChoices.COMPLETED + + +class ProjectState(BaseState): + """ + Core project state tracking for Label Studio. + + Provides basic project state management with: + - Simple 4-state workflow (CREATED → PUBLISHED → IN_PROGRESS → COMPLETED) + - Project lifecycle tracking + - Extensible design for enterprise features + """ + + # Entity Relationship + project = models.ForeignKey('projects.Project', on_delete=models.CASCADE, related_name='fsm_states') + + # Override state field to add choices constraint + state = models.CharField(max_length=50, choices=ProjectStateChoices.choices, db_index=True) + + class Meta: + db_table = 'fsm_project_states' + indexes = [ + # Critical: Latest state lookup + models.Index(fields=['project_id', '-id'], name='fsm_proj_current_state_idx'), + ] + ordering = ['-id'] + + @property + def is_terminal_state(self) -> bool: + """Check if this is a terminal project state""" + return self.state == ProjectStateChoices.COMPLETED + + +# Registry for dynamic state model extension +# Enterprise can register additional state models here +STATE_MODEL_REGISTRY = { + 'task': TaskState, + 'annotation': AnnotationState, + 'project': ProjectState, +} + + +def register_state_model(entity_name: str, model_class): + """ + Register state model for an entity type. + + This allows Label Studio Enterprise to register additional state models + or override existing ones with enterprise-specific implementations. + + Args: + entity_name: Name of the entity (e.g., 'review', 'assignment') + model_class: Django model class inheriting from BaseState + + Example: + # In LSE code: + register_state_model('review', AnnotationReviewState) + register_state_model('assignment', TaskAssignmentState) + + # Override core model with enterprise version: + register_state_model('task', EnterpriseTaskState) + """ + STATE_MODEL_REGISTRY[entity_name.lower()] = model_class + + +def get_state_model(entity_name: str): + """ + Get state model for an entity type. + + Args: + entity_name: Name of the entity + + Returns: + Django model class inheriting from BaseState, or None if not found + """ + return STATE_MODEL_REGISTRY.get(entity_name.lower()) + + +def get_state_model_for_entity(entity): + """ + Get state model for a specific entity instance. + + Args: + entity: Django model instance + + Returns: + Django model class inheriting from BaseState, or None if not found + """ + entity_name = entity._meta.model_name.lower() + return get_state_model(entity_name) diff --git a/label_studio/fsm/serializers.py b/label_studio/fsm/serializers.py new file mode 100644 index 000000000000..9f2b920d7d71 --- /dev/null +++ b/label_studio/fsm/serializers.py @@ -0,0 +1,87 @@ +""" +Core FSM serializers for Label Studio. + +Provides basic serializers for state management API that can be extended +by Label Studio Enterprise with additional functionality. +""" + +from rest_framework import serializers + + +class StateHistorySerializer(serializers.Serializer): + """ + Serializer for state history records. + + Provides basic state history information that can be extended + by Enterprise with additional fields. + """ + + id = serializers.UUIDField(read_only=True) + state = serializers.CharField(read_only=True) + previous_state = serializers.CharField(read_only=True, allow_null=True) + transition_name = serializers.CharField(read_only=True, allow_null=True) + triggered_by = serializers.SerializerMethodField() + created_at = serializers.DateTimeField(read_only=True) + reason = serializers.CharField(read_only=True) + context_data = serializers.SerializerMethodField() + + def get_triggered_by(self, obj): + """Get user who triggered the transition""" + if obj.triggered_by: + return { + 'id': obj.triggered_by.id, + 'email': obj.triggered_by.email, + 'first_name': getattr(obj.triggered_by, 'first_name', ''), + 'last_name': getattr(obj.triggered_by, 'last_name', ''), + } + return None + + def get_context_data(self, obj): + """Include context data if requested""" + include_context = self.context.get('include_context', False) + if include_context: + return obj.context_data + return None + + +class StateTransitionSerializer(serializers.Serializer): + """ + Serializer for state transition requests. + + Validates state transition request data. + """ + + new_state = serializers.CharField(required=True, help_text='Target state to transition to') + transition_name = serializers.CharField( + required=False, allow_blank=True, help_text='Name of the transition method (for audit trail)' + ) + reason = serializers.CharField( + required=False, allow_blank=True, help_text='Human-readable reason for the transition' + ) + context = serializers.JSONField( + required=False, default=dict, help_text='Additional context data for the transition' + ) + + def validate_new_state(self, value): + """Validate that new_state is not empty""" + if not value or not value.strip(): + raise serializers.ValidationError('new_state cannot be empty') + return value.strip().upper() + + +class StateInfoSerializer(serializers.Serializer): + """ + Serializer for basic state information. + + Used for current state responses. + """ + + current_state = serializers.CharField(allow_null=True) + entity_type = serializers.CharField() + entity_id = serializers.IntegerField() + + # Optional fields that Enterprise can populate + available_transitions = serializers.ListField( + child=serializers.CharField(), required=False, help_text='List of valid transitions from current state' + ) + state_metadata = serializers.JSONField(required=False, help_text='Additional metadata about the current state') diff --git a/label_studio/fsm/state_choices.py b/label_studio/fsm/state_choices.py new file mode 100644 index 000000000000..428d5b71b2a4 --- /dev/null +++ b/label_studio/fsm/state_choices.py @@ -0,0 +1,139 @@ +""" +Core state choice enums for Label Studio entities. + +These enums define the essential states for core Label Studio entities. +Label Studio Enterprise can extend these with additional states or +define entirely new state enums for enterprise-specific entities. +""" + +from django.db import models +from django.utils.translation import gettext_lazy as _ + + +class TaskStateChoices(models.TextChoices): + """ + Core task states for basic Label Studio workflow. + + Simplified states covering the essential task lifecycle: + - Creation and assignment + - Annotation work + - Completion + + Enterprise can extend with review, arbitration, and advanced workflow states. + """ + + # Initial State + CREATED = 'CREATED', _('Created') + + # Work States + IN_PROGRESS = 'IN_PROGRESS', _('In Progress') + + # Terminal State + COMPLETED = 'COMPLETED', _('Completed') + + +class AnnotationStateChoices(models.TextChoices): + """ + Core annotation states for basic Label Studio workflow. + + Simplified states covering the essential annotation lifecycle: + - Draft work + - Submission + - Completion + + Enterprise can extend with review, approval, and rejection states. + """ + + # Working States + DRAFT = 'DRAFT', _('Draft') + SUBMITTED = 'SUBMITTED', _('Submitted') + + # Terminal State + COMPLETED = 'COMPLETED', _('Completed') + + +class ProjectStateChoices(models.TextChoices): + """ + Core project states for basic Label Studio workflow. + + Simplified states covering the essential project lifecycle: + - Setup and configuration + - Active work + - Completion + + Enterprise can extend with advanced workflow, review, and approval states. + """ + + # Setup States + CREATED = 'CREATED', _('Created') + PUBLISHED = 'PUBLISHED', _('Published') + + # Work States + IN_PROGRESS = 'IN_PROGRESS', _('In Progress') + + # Terminal State + COMPLETED = 'COMPLETED', _('Completed') + + +# Registry for dynamic state choices extension +# Enterprise can register additional choices here +STATE_CHOICES_REGISTRY = { + 'task': TaskStateChoices, + 'annotation': AnnotationStateChoices, + 'project': ProjectStateChoices, +} + + +def register_state_choices(entity_name: str, choices_class): + """ + Register state choices for an entity type. + + This allows Label Studio Enterprise and other extensions to register + their own state choices dynamically. + + Args: + entity_name: Name of the entity (e.g., 'review', 'assignment') + choices_class: Django TextChoices class defining valid states + + Example: + # In LSE code: + register_state_choices('review', ReviewStateChoices) + register_state_choices('assignment', AssignmentStateChoices) + """ + STATE_CHOICES_REGISTRY[entity_name.lower()] = choices_class + + +def get_state_choices(entity_name: str): + """ + Get state choices for an entity type. + + Args: + entity_name: Name of the entity + + Returns: + Django TextChoices class or None if not found + """ + return STATE_CHOICES_REGISTRY.get(entity_name.lower()) + + +# State complexity metrics for core entities +CORE_STATE_COMPLEXITY_METRICS = { + 'TaskStateChoices': { + 'total_states': len(TaskStateChoices.choices), + 'complexity_score': 1.0, # Simple linear flow + 'terminal_states': ['COMPLETED'], + 'entry_states': ['CREATED'], + }, + 'AnnotationStateChoices': { + 'total_states': len(AnnotationStateChoices.choices), + 'complexity_score': 1.0, # Simple linear flow + 'terminal_states': ['COMPLETED'], + 'entry_states': ['DRAFT'], + }, + 'ProjectStateChoices': { + 'total_states': len(ProjectStateChoices.choices), + 'complexity_score': 1.0, # Simple linear flow + 'terminal_states': ['COMPLETED'], + 'entry_states': ['CREATED'], + }, +} diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py new file mode 100644 index 000000000000..c735ab7acb5c --- /dev/null +++ b/label_studio/fsm/state_manager.py @@ -0,0 +1,333 @@ +""" +Core state management functionality for Label Studio. + +Provides high-performance state management with caching and batch operations +that can be extended by Label Studio Enterprise with additional features. +""" + +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional, Type + +from django.conf import settings +from django.core.cache import cache +from django.db import transaction +from django.db.models import Model + +from .models import BaseState, get_state_model_for_entity + +logger = logging.getLogger(__name__) + + +class StateManagerError(Exception): + """Base exception for StateManager operations""" + + pass + + +class InvalidTransitionError(StateManagerError): + """Raised when an invalid state transition is attempted""" + + pass + + +class StateManager: + """ + Core state management system for Label Studio. + + Provides the foundation for state management that can be extended + by Label Studio Enterprise with additional features like: + - Advanced caching strategies + - Bulk operations optimization + - Complex transition validation + - Enterprise-specific state models + + Features: + - INSERT-only architecture with UUID7 for maximum performance + - Basic caching for current state lookups + - Simple state transitions with audit trails + - Extensible design for enterprise features + """ + + CACHE_TTL = getattr(settings, 'FSM_CACHE_TTL', 300) # 5 minutes default + CACHE_PREFIX = 'fsm:current' + + @classmethod + def get_cache_key(cls, entity: Model) -> str: + """Generate cache key for entity's current state""" + return f'{cls.CACHE_PREFIX}:{entity._meta.label_lower}:{entity.pk}' + + @classmethod + def get_current_state(cls, entity: Model) -> Optional[str]: + """ + Get current state with basic caching. + + Args: + entity: The entity to get current state for + + Returns: + Current state string or None if no states exist + + Example: + task = Task.objects.get(id=123) + current_state = StateManager.get_current_state(task) + if current_state == 'COMPLETED': + # Task is finished + pass + """ + cache_key = cls.get_cache_key(entity) + + # Try cache first + cached_state = cache.get(cache_key) + if cached_state is not None: + logger.debug(f'Cache hit for {entity._meta.label_lower} {entity.pk}: {cached_state}') + return cached_state + + # Query database using state model registry + state_model = get_state_model_for_entity(entity) + if not state_model: + logger.warning(f'No state model found for {entity._meta.model_name}') + return None + + try: + entity_field = f'{entity._meta.model_name}' + current_state = ( + state_model.objects.filter(**{entity_field: entity}) + .order_by('-id') # UUID7 natural ordering + .values_list('state', flat=True) + .first() + ) + + # Cache result + if current_state is not None: + cache.set(cache_key, current_state, cls.CACHE_TTL) + + logger.debug(f'Database query for {entity._meta.label_lower} {entity.pk}: {current_state}') + return current_state + + except Exception as e: + logger.error(f'Error getting current state for {entity._meta.label_lower} {entity.pk}: {e}') + return None + + @classmethod + def get_current_state_object(cls, entity: Model) -> Optional[BaseState]: + """ + Get current state object with full audit information. + + Args: + entity: The entity to get current state object for + + Returns: + Latest BaseState instance or None if no states exist + """ + state_model = get_state_model_for_entity(entity) + if not state_model: + return None + + entity_field = f'{entity._meta.model_name}' + return state_model.objects.filter(**{entity_field: entity}).order_by('-id').first() + + @classmethod + def transition_state( + cls, + entity: Model, + new_state: str, + transition_name: str = None, + user=None, + context: Dict[str, Any] = None, + reason: str = '', + ) -> bool: + """ + Perform state transition with audit trail. + + Uses INSERT-only approach for maximum performance: + - No UPDATE operations or row locks + - Complete audit trail by design + - Basic cache update for consistency + + Args: + entity: The entity to transition + new_state: Target state + transition_name: Name of transition method (for audit) + user: User triggering the transition + context: Additional context data + reason: Human-readable reason for transition + + Returns: + True if transition succeeded, False otherwise + + Raises: + InvalidTransitionError: If transition is not valid + StateManagerError: If transition fails + + Example: + success = StateManager.transition_state( + entity=task, + new_state='IN_PROGRESS', + transition_name='start_annotation', + user=request.user, + context={'assignment_id': assignment.id}, + reason='User started annotation work' + ) + """ + state_model = get_state_model_for_entity(entity) + if not state_model: + raise StateManagerError(f'No state model found for {entity._meta.model_name}') + + current_state = cls.get_current_state(entity) + + logger.info( + f'Transitioning {entity._meta.label_lower} {entity.pk}: ' + f'{current_state} → {new_state} (transition: {transition_name})' + ) + + try: + with transaction.atomic(): + # INSERT-only approach - no UPDATE operations needed + new_state_record = state_model.objects.create( + **{entity._meta.model_name: entity}, + state=new_state, + previous_state=current_state, + transition_name=transition_name, + triggered_by=user, + context_data=context or {}, + reason=reason, + # Note: Denormalized fields would be added here by Enterprise + ) + + # Update cache with new state + cache_key = cls.get_cache_key(entity) + cache.set(cache_key, new_state, cls.CACHE_TTL) + + logger.info( + f'State transition successful: {entity._meta.label_lower} {entity.pk} ' + f'now in state {new_state} (record ID: {new_state_record.id})' + ) + return True + + except Exception as e: + # On failure, invalidate potentially stale cache + cache_key = cls.get_cache_key(entity) + cache.delete(cache_key) + logger.error( + f'State transition failed for {entity._meta.label_lower} {entity.pk}: ' + f'{current_state} → {new_state}: {e}' + ) + raise StateManagerError(f'Failed to transition state: {e}') from e + + @classmethod + def get_state_history(cls, entity: Model, limit: int = 100) -> List[BaseState]: + """ + Get complete state history for an entity. + + Args: + entity: Entity to get history for + limit: Maximum number of state records to return + + Returns: + List of state records ordered by most recent first + """ + state_model = get_state_model_for_entity(entity) + if not state_model: + return [] + + entity_field = f'{entity._meta.model_name}' + return list(state_model.objects.filter(**{entity_field: entity}).order_by('-id')[:limit]) + + @classmethod + def get_states_in_time_range( + cls, entity: Model, start_time: datetime, end_time: Optional[datetime] = None + ) -> List[BaseState]: + """ + Get states within a time range using UUID7 time-based queries. + + Args: + entity: Entity to get states for + start_time: Start of time range + end_time: End of time range (defaults to now) + + Returns: + List of states within the time range + """ + state_model = get_state_model_for_entity(entity) + if not state_model: + return [] + + return list(state_model.get_states_in_range(entity, start_time, end_time or datetime.now())) + + @classmethod + def invalidate_cache(cls, entity: Model): + """Invalidate cached state for an entity""" + cache_key = cls.get_cache_key(entity) + cache.delete(cache_key) + logger.debug(f'Invalidated cache for {entity._meta.label_lower} {entity.pk}') + + @classmethod + def warm_cache(cls, entities: List[Model]): + """ + Warm cache with current states for a list of entities. + + Basic implementation that can be optimized by Enterprise with + bulk queries and advanced caching strategies. + """ + cache_updates = {} + for entity in entities: + current_state = cls.get_current_state(entity) + if current_state: + cache_key = cls.get_cache_key(entity) + cache_updates[cache_key] = current_state + + if cache_updates: + cache.set_many(cache_updates, cls.CACHE_TTL) + logger.debug(f'Warmed cache for {len(cache_updates)} entities') + + +# Extension point for Label Studio Enterprise +# Enterprise can subclass this and add advanced features +class ExtendedStateManager(StateManager): + """ + Extension point for Label Studio Enterprise. + + Enterprise can override this class to add: + - Bulk operations with window functions + - Advanced caching strategies + - Complex transition validation + - Enterprise-specific optimizations + + Example Enterprise usage: + class EnterpriseStateManager(ExtendedStateManager): + @classmethod + def bulk_get_states(cls, entities): + # Enterprise-specific bulk optimization + return super().bulk_get_states_optimized(entities) + + @classmethod + def transition_state(cls, entity, new_state, **kwargs): + # Enterprise transition validation + cls.validate_enterprise_transition(entity, new_state) + return super().transition_state(entity, new_state, **kwargs) + """ + + pass + + +# Allow runtime configuration of which StateManager to use +# Enterprise can set this to their extended implementation +DEFAULT_STATE_MANAGER = StateManager + + +def get_state_manager() -> Type[StateManager]: + """ + Get the configured state manager class. + + Returns the StateManager class to use. Enterprise can override + this by setting a different class in their configuration. + """ + # Check if enterprise has configured a custom state manager + if hasattr(settings, 'FSM_STATE_MANAGER_CLASS'): + manager_path = settings.FSM_STATE_MANAGER_CLASS + module_name, class_name = manager_path.rsplit('.', 1) + module = __import__(module_name, fromlist=[class_name]) + return getattr(module, class_name) + + return DEFAULT_STATE_MANAGER diff --git a/label_studio/fsm/tests/__init__.py b/label_studio/fsm/tests/__init__.py new file mode 100644 index 000000000000..45c22224aea1 --- /dev/null +++ b/label_studio/fsm/tests/__init__.py @@ -0,0 +1 @@ +# FSM Tests Package diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py new file mode 100644 index 000000000000..74af25d23d38 --- /dev/null +++ b/label_studio/fsm/tests/test_fsm_integration.py @@ -0,0 +1,271 @@ +""" +Integration tests for the FSM system. + +Tests the complete FSM functionality including models, state management, +and API endpoints. +""" + +from datetime import datetime, timezone + +from django.contrib.auth import get_user_model +from django.test import TestCase +from rest_framework.test import APITestCase + +from label_studio.projects.models import Project +from label_studio.tasks.models import Annotation, Task + +from ..models import AnnotationState, ProjectState, TaskState +from ..state_manager import get_state_manager + +User = get_user_model() + + +class TestFSMModels(TestCase): + """Test FSM model functionality""" + + def setUp(self): + self.user = User.objects.create_user(email='test@example.com', password='test123') + self.project = Project.objects.create(title='Test Project', created_by=self.user) + self.task = Task.objects.create(project=self.project, data={'text': 'test'}) + + def test_task_state_creation(self): + """Test TaskState creation and basic functionality""" + task_state = TaskState.objects.create( + task=self.task, state='CREATED', triggered_by=self.user, reason='Task created for testing' + ) + + # Check basic fields + self.assertEqual(task_state.state, 'CREATED') + self.assertEqual(task_state.task, self.task) + self.assertEqual(task_state.triggered_by, self.user) + + # Check UUID7 functionality + self.assertEqual(task_state.id.version, 7) + self.assertIsInstance(task_state.timestamp_from_uuid, datetime) + + # Check string representation + str_repr = str(task_state) + self.assertIn('Task', str_repr) + self.assertIn('CREATED', str_repr) + + def test_annotation_state_creation(self): + """Test AnnotationState creation and basic functionality""" + annotation = Annotation.objects.create(task=self.task, completed_by=self.user, result=[]) + + annotation_state = AnnotationState.objects.create( + annotation=annotation, state='DRAFT', triggered_by=self.user, reason='Annotation draft created' + ) + + # Check basic fields + self.assertEqual(annotation_state.state, 'DRAFT') + self.assertEqual(annotation_state.annotation, annotation) + + # Check terminal state property + self.assertFalse(annotation_state.is_terminal_state) + + # Test completed state + completed_state = AnnotationState.objects.create( + annotation=annotation, state='COMPLETED', triggered_by=self.user + ) + self.assertTrue(completed_state.is_terminal_state) + + def test_project_state_creation(self): + """Test ProjectState creation and basic functionality""" + project_state = ProjectState.objects.create( + project=self.project, state='CREATED', triggered_by=self.user, reason='Project created for testing' + ) + + # Check basic fields + self.assertEqual(project_state.state, 'CREATED') + self.assertEqual(project_state.project, self.project) + + # Test terminal state + self.assertFalse(project_state.is_terminal_state) + + completed_state = ProjectState.objects.create(project=self.project, state='COMPLETED', triggered_by=self.user) + self.assertTrue(completed_state.is_terminal_state) + + +class TestStateManager(TestCase): + """Test StateManager functionality""" + + def setUp(self): + self.user = User.objects.create_user(email='test@example.com', password='test123') + self.project = Project.objects.create(title='Test Project', created_by=self.user) + self.task = Task.objects.create(project=self.project, data={'text': 'test'}) + self.StateManager = get_state_manager() + + def test_get_current_state_empty(self): + """Test getting current state when no states exist""" + current_state = self.StateManager.get_current_state(self.task) + self.assertIsNone(current_state) + + def test_transition_state(self): + """Test state transition functionality""" + # Initial transition + success = self.StateManager.transition_state( + entity=self.task, + new_state='CREATED', + user=self.user, + transition_name='create_task', + reason='Initial task creation', + ) + + self.assertTrue(success) + + # Check current state + current_state = self.StateManager.get_current_state(self.task) + self.assertEqual(current_state, 'CREATED') + + # Another transition + success = self.StateManager.transition_state( + entity=self.task, + new_state='IN_PROGRESS', + user=self.user, + transition_name='start_work', + context={'started_by': 'user'}, + ) + + self.assertTrue(success) + current_state = self.StateManager.get_current_state(self.task) + self.assertEqual(current_state, 'IN_PROGRESS') + + def test_get_current_state_object(self): + """Test getting current state object with full details""" + # Create some state transitions + self.StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) + self.StateManager.transition_state( + entity=self.task, new_state='IN_PROGRESS', user=self.user, context={'test': 'data'} + ) + + current_state_obj = self.StateManager.get_current_state_object(self.task) + + self.assertIsNotNone(current_state_obj) + self.assertEqual(current_state_obj.state, 'IN_PROGRESS') + self.assertEqual(current_state_obj.previous_state, 'CREATED') + self.assertEqual(current_state_obj.triggered_by, self.user) + self.assertEqual(current_state_obj.context_data, {'test': 'data'}) + + def test_get_state_history(self): + """Test state history retrieval""" + # Create multiple transitions + transitions = [('CREATED', 'create_task'), ('IN_PROGRESS', 'start_work'), ('COMPLETED', 'finish_work')] + + for state, transition in transitions: + self.StateManager.transition_state( + entity=self.task, new_state=state, user=self.user, transition_name=transition + ) + + history = self.StateManager.get_state_history(self.task, limit=10) + + # Should have 3 state records + self.assertEqual(len(history), 3) + + # Should be ordered by most recent first (UUID7 ordering) + states = [h.state for h in history] + self.assertEqual(states, ['COMPLETED', 'IN_PROGRESS', 'CREATED']) + + # Check previous states are set correctly + self.assertIsNone(history[2].previous_state) # First state has no previous + self.assertEqual(history[1].previous_state, 'CREATED') + self.assertEqual(history[0].previous_state, 'IN_PROGRESS') + + def test_get_states_in_time_range(self): + """Test time-based state queries using UUID7""" + # Record time before creating states + before_time = datetime.now(timezone.utc) + + # Create some states + self.StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) + self.StateManager.transition_state(entity=self.task, new_state='IN_PROGRESS', user=self.user) + + # Record time after creating states + after_time = datetime.now(timezone.utc) + + # Query states in time range + states_in_range = self.StateManager.get_states_in_time_range(self.task, before_time, after_time) + + # Should find both states + self.assertEqual(len(states_in_range), 2) + + +class TestFSMAPI(APITestCase): + """Test FSM API endpoints""" + + def setUp(self): + self.user = User.objects.create_user(email='test@example.com', password='test123') + self.project = Project.objects.create(title='Test Project', created_by=self.user) + self.task = Task.objects.create(project=self.project, data={'text': 'test'}) + self.client.force_authenticate(user=self.user) + + # Create initial state + StateManager = get_state_manager() + StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) + + def test_get_current_state_api(self): + """Test GET /api/fsm/{entity_type}/{entity_id}/current/""" + response = self.client.get(f'/api/fsm/task/{self.task.id}/current/') + + self.assertEqual(response.status_code, 200) + data = response.json() + + self.assertEqual(data['current_state'], 'CREATED') + self.assertEqual(data['entity_type'], 'task') + self.assertEqual(data['entity_id'], self.task.id) + + def test_get_state_history_api(self): + """Test GET /api/fsm/{entity_type}/{entity_id}/history/""" + # Create additional states + StateManager = get_state_manager() + StateManager.transition_state( + entity=self.task, new_state='IN_PROGRESS', user=self.user, transition_name='start_work' + ) + + response = self.client.get(f'/api/fsm/task/{self.task.id}/history/') + + self.assertEqual(response.status_code, 200) + data = response.json() + + self.assertEqual(data['count'], 2) + self.assertEqual(len(data['results']), 2) + + # Check first result (most recent) + latest_state = data['results'][0] + self.assertEqual(latest_state['state'], 'IN_PROGRESS') + self.assertEqual(latest_state['previous_state'], 'CREATED') + self.assertEqual(latest_state['transition_name'], 'start_work') + + def test_transition_state_api(self): + """Test POST /api/fsm/{entity_type}/{entity_id}/transition/""" + transition_data = { + 'new_state': 'IN_PROGRESS', + 'transition_name': 'start_annotation', + 'reason': 'User started working on task', + 'context': {'assignment_id': 123}, + } + + response = self.client.post(f'/api/fsm/task/{self.task.id}/transition/', data=transition_data, format='json') + + self.assertEqual(response.status_code, 200) + data = response.json() + + self.assertTrue(data['success']) + self.assertEqual(data['previous_state'], 'CREATED') + self.assertEqual(data['new_state'], 'IN_PROGRESS') + self.assertEqual(data['entity_type'], 'task') + self.assertEqual(data['entity_id'], self.task.id) + + # Verify state was actually changed + StateManager = get_state_manager() + current_state = StateManager.get_current_state(self.task) + self.assertEqual(current_state, 'IN_PROGRESS') + + def test_api_with_invalid_entity(self): + """Test API with non-existent entity""" + response = self.client.get('/api/fsm/task/99999/current/') + self.assertEqual(response.status_code, 404) + + def test_api_with_invalid_entity_type(self): + """Test API with invalid entity type""" + response = self.client.get('/api/fsm/invalid/1/current/') + self.assertEqual(response.status_code, 404) diff --git a/label_studio/fsm/tests/test_uuid7_utils.py b/label_studio/fsm/tests/test_uuid7_utils.py new file mode 100644 index 000000000000..716f3ea59b2e --- /dev/null +++ b/label_studio/fsm/tests/test_uuid7_utils.py @@ -0,0 +1,164 @@ +""" +Tests for UUID7 utilities in the FSM system. + +Tests the uuid-utils library integration and UUID7 functionality. +""" + +import uuid +from datetime import datetime, timedelta, timezone + +from django.test import TestCase + +from label_studio.fsm.utils import ( + UUID7Generator, + generate_uuid7, + timestamp_from_uuid7, + uuid7_from_timestamp, + uuid7_time_range, + validate_uuid7, +) + + +class TestUUID7Utils(TestCase): + """Test UUID7 utility functions""" + + def test_generate_uuid7(self): + """Test UUID7 generation""" + uuid7_id = generate_uuid7() + + # Check that it's a valid UUID + self.assertIsInstance(uuid7_id, uuid.UUID) + + # Check that it's version 7 + self.assertEqual(uuid7_id.version, 7) + + # Check that it validates as UUID7 + self.assertTrue(validate_uuid7(uuid7_id)) + + def test_uuid7_ordering(self): + """Test that UUID7s have natural time ordering""" + uuid1 = generate_uuid7() + uuid2 = generate_uuid7() + + # UUID7s should be ordered by generation time + self.assertLess(uuid1.int, uuid2.int) + + def test_timestamp_extraction(self): + """Test timestamp extraction from UUID7""" + before = datetime.now(timezone.utc) + uuid7_id = generate_uuid7() + after = datetime.now(timezone.utc) + + extracted_timestamp = timestamp_from_uuid7(uuid7_id) + + # Timestamp should be between before and after + self.assertGreaterEqual(extracted_timestamp, before) + self.assertLessEqual(extracted_timestamp, after) + + def test_uuid7_from_timestamp(self): + """Test creating UUID7 from specific timestamp""" + test_time = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + uuid7_id = uuid7_from_timestamp(test_time) + + # Should be a valid UUID7 + self.assertTrue(validate_uuid7(uuid7_id)) + + # Extracted timestamp should match (within millisecond precision) + extracted = timestamp_from_uuid7(uuid7_id) + time_diff = abs((extracted - test_time).total_seconds()) + self.assertLess(time_diff, 0.001) # Within 1ms + + def test_uuid7_time_range(self): + """Test UUID7 time range generation""" + start_time = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + end_time = datetime(2024, 1, 15, 13, 0, 0, tzinfo=timezone.utc) + + start_uuid, end_uuid = uuid7_time_range(start_time, end_time) + + # Both should be valid UUID7s + self.assertTrue(validate_uuid7(start_uuid)) + self.assertTrue(validate_uuid7(end_uuid)) + + # Start should be less than end + self.assertLess(start_uuid.int, end_uuid.int) + + # Timestamps should match input times + start_extracted = timestamp_from_uuid7(start_uuid) + end_extracted = timestamp_from_uuid7(end_uuid) + + self.assertLess(abs((start_extracted - start_time).total_seconds()), 0.001) + self.assertLess(abs((end_extracted - end_time).total_seconds()), 0.001) + + def test_uuid7_time_range_default_end(self): + """Test UUID7 time range with default end time (now)""" + start_time = datetime.now(timezone.utc) - timedelta(hours=1) + before_call = datetime.now(timezone.utc) + + start_uuid, end_uuid = uuid7_time_range(start_time) + + after_call = datetime.now(timezone.utc) + + # End timestamp should be close to now + end_extracted = timestamp_from_uuid7(end_uuid) + self.assertGreaterEqual(end_extracted, before_call) + self.assertLessEqual(end_extracted, after_call) + + def test_validate_uuid7_with_other_versions(self): + """Test UUID7 validation with other UUID versions""" + # Test with UUID4 + uuid4_id = uuid.uuid4() + self.assertFalse(validate_uuid7(uuid4_id)) + + # Test with UUID7 + uuid7_id = generate_uuid7() + self.assertTrue(validate_uuid7(uuid7_id)) + + +class TestUUID7Generator(TestCase): + """Test UUID7Generator class""" + + def test_generator_basic(self): + """Test basic UUID7 generator functionality""" + generator = UUID7Generator() + + uuid7_id = generator.generate() + self.assertTrue(validate_uuid7(uuid7_id)) + + def test_generator_with_base_timestamp(self): + """Test generator with custom base timestamp""" + base_time = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + generator = UUID7Generator(base_timestamp=base_time) + + uuid7_id = generator.generate() + extracted = timestamp_from_uuid7(uuid7_id) + + # Should be close to base time + time_diff = abs((extracted - base_time).total_seconds()) + self.assertLess(time_diff, 0.001) + + def test_generator_with_offset(self): + """Test generator with timestamp offset""" + base_time = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + generator = UUID7Generator(base_timestamp=base_time) + + # Generate UUID with 1 second offset + uuid7_id = generator.generate(offset_ms=1000) + extracted = timestamp_from_uuid7(uuid7_id) + + expected_time = base_time + timedelta(milliseconds=1000) + time_diff = abs((extracted - expected_time).total_seconds()) + self.assertLess(time_diff, 0.001) + + def test_generator_monotonic(self): + """Test that generator produces monotonic UUIDs""" + base_time = datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc) + generator = UUID7Generator(base_timestamp=base_time) + + # Generate multiple UUIDs with same timestamp but different counters + uuid1 = generator.generate(offset_ms=100) + uuid2 = generator.generate(offset_ms=100) + uuid3 = generator.generate(offset_ms=100) + + # Should be monotonic even with same timestamp + self.assertLess(uuid1.int, uuid2.int) + self.assertLess(uuid2.int, uuid3.int) diff --git a/label_studio/fsm/urls.py b/label_studio/fsm/urls.py new file mode 100644 index 000000000000..222292384797 --- /dev/null +++ b/label_studio/fsm/urls.py @@ -0,0 +1,50 @@ +""" +Core FSM URL patterns for Label Studio. + +Provides basic URL routing for state management API that can be extended +by Label Studio Enterprise with additional endpoints. +""" + +from django.urls import include, path +from rest_framework.routers import DefaultRouter + +from .api import FSMViewSet + +# Create router for FSM API endpoints +router = DefaultRouter() +router.register(r'fsm', FSMViewSet, basename='fsm') + +# Core FSM URL patterns +urlpatterns = [ + path('api/', include(router.urls)), +] + +# Extension point for Label Studio Enterprise +# Enterprise can add additional URL patterns here +enterprise_urlpatterns = [] + +# Function to register additional URL patterns from Enterprise +def register_enterprise_urls(patterns): + """ + Register additional URL patterns from Label Studio Enterprise. + + Args: + patterns: List of URL patterns to register + + Example: + # In LSE code: + from label_studio.fsm.urls import register_enterprise_urls + + enterprise_patterns = [ + path('api/fsm/bulk/', BulkFSMViewSet.as_view(), name='fsm-bulk'), + path('api/fsm/analytics/', AnalyticsFSMViewSet.as_view(), name='fsm-analytics'), + ] + register_enterprise_urls(enterprise_patterns) + """ + global enterprise_urlpatterns + enterprise_urlpatterns.extend(patterns) + + +# Include enterprise URL patterns if any are registered +if enterprise_urlpatterns: + urlpatterns.extend(enterprise_urlpatterns) diff --git a/label_studio/fsm/utils.py b/label_studio/fsm/utils.py new file mode 100644 index 000000000000..563b20f29895 --- /dev/null +++ b/label_studio/fsm/utils.py @@ -0,0 +1,189 @@ +""" +UUID7 utilities for time-series optimization. + +UUID7 provides natural time ordering and global uniqueness, making it ideal +for INSERT-only architectures with millions of records. + +Uses the uuid-utils library for RFC 9562 compliant UUID7 generation. +""" + +from datetime import datetime, timezone +from typing import Optional, Tuple + +import uuid_utils as uuid + + +def generate_uuid7() -> uuid.UUID: + """ + Generate a UUID7 with embedded timestamp for natural time ordering. + + UUID7 embeds the timestamp in the first 48 bits, providing: + - Natural chronological ordering without additional indexes + - Global uniqueness across distributed systems + - Time-based partitioning capabilities + + Returns: + UUID7 instance with embedded timestamp + """ + # Use uuid-utils library for RFC 9562 compliant UUID7 generation + return uuid.uuid7() + + +def timestamp_from_uuid7(uuid7_id: uuid.UUID) -> datetime: + """ + Extract timestamp from UUID7 ID. + + Args: + uuid7_id: UUID7 instance to extract timestamp from + + Returns: + datetime: Timestamp embedded in the UUID7 + + Example: + uuid7_id = generate_uuid7() + timestamp = timestamp_from_uuid7(uuid7_id) + # timestamp is when the UUID7 was generated + """ + # UUID7 embeds timestamp in first 48 bits + timestamp_ms = (uuid7_id.int >> 80) & ((1 << 48) - 1) + return datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc) + + +def uuid7_time_range(start_time: datetime, end_time: Optional[datetime] = None) -> Tuple[uuid.UUID, uuid.UUID]: + """ + Generate UUID7 range for time-based queries. + + Creates UUID7 boundaries for efficient time-range filtering without + requiring timestamp indexes. + + Args: + start_time: Start of time range + end_time: End of time range (defaults to now) + + Returns: + Tuple of (start_uuid, end_uuid) for range queries + + Example: + start_uuid, end_uuid = uuid7_time_range( + datetime(2024, 1, 1), + datetime(2024, 1, 2) + ) + # Query: WHERE id >= start_uuid AND id <= end_uuid + """ + if end_time is None: + end_time = datetime.now(timezone.utc) + + start_timestamp_ms = int(start_time.timestamp() * 1000) + end_timestamp_ms = int(end_time.timestamp() * 1000) + + # Create UUID7 with specific timestamp and zero random bits for range start + start_uuid = uuid.UUID(int=(start_timestamp_ms << 80), version=7) + + # Create UUID7 with specific timestamp and max random bits for range end + end_uuid = uuid.UUID(int=(end_timestamp_ms << 80) | ((1 << 80) - 1), version=7) + + return start_uuid, end_uuid + + +def uuid7_from_timestamp(timestamp: datetime) -> uuid.UUID: + """ + Generate UUID7 from specific timestamp for range queries. + + Args: + timestamp: Timestamp to embed in UUID7 + + Returns: + UUID7 with embedded timestamp + + Example: + # Get all states from the last hour + start_time = timezone.now() - timedelta(hours=1) + start_uuid = uuid7_from_timestamp(start_time) + states = StateModel.objects.filter(id__gte=start_uuid) + """ + # Convert to milliseconds since epoch as uuid-utils expects + timestamp_ms = int(timestamp.timestamp() * 1000) + + # Use uuid-utils with specific timestamp for range queries + # This creates a UUID7 with the given timestamp and minimal random bits + # for consistent range boundaries + return uuid.UUID(int=(timestamp_ms << 80) | (0x7 << 76) | (0b10 << 62)) + + +def validate_uuid7(uuid_value: uuid.UUID) -> bool: + """ + Validate that a UUID is a valid UUID7. + + Args: + uuid_value: UUID to validate + + Returns: + True if valid UUID7, False otherwise + """ + return uuid_value.version == 7 + + +class UUID7Field: + """ + Custom field utilities for UUID7 handling in Django models. + + Provides helper methods for UUID7-specific operations that can be + used by models inheriting from BaseState. + """ + + @staticmethod + def get_latest_by_uuid7(queryset): + """Get latest record using UUID7 natural ordering""" + return queryset.order_by('-id').first() + + @staticmethod + def filter_by_time_range(queryset, start_time: datetime, end_time: Optional[datetime] = None): + """Filter queryset by time range using UUID7 embedded timestamps""" + start_uuid, end_uuid = uuid7_time_range(start_time, end_time) + return queryset.filter(id__gte=start_uuid, id__lte=end_uuid) + + @staticmethod + def filter_since_time(queryset, since: datetime): + """Filter queryset for records since a specific time""" + start_uuid = uuid7_from_timestamp(since) + return queryset.filter(id__gte=start_uuid) + + +class UUID7Generator: + """ + UUID7 generator with optional custom timestamp. + + Useful for testing or when you need to generate UUIDs with specific timestamps. + """ + + def __init__(self, base_timestamp: Optional[datetime] = None): + """ + Initialize generator with optional base timestamp. + + Args: + base_timestamp: Base timestamp to use (defaults to current time) + """ + self.base_timestamp = base_timestamp or datetime.now(timezone.utc) + self._counter = 0 + + def generate(self, offset_ms: int = 0) -> uuid.UUID: + """ + Generate UUID7 with timestamp offset. + + Args: + offset_ms: Millisecond offset from base timestamp + + Returns: + UUID7 with adjusted timestamp + """ + # For testing purposes, we'll generate a standard UUID7 + # and then optionally create one with specific timestamp if needed + if offset_ms == 0: + return uuid.uuid7() + + # For offset timestamps, use manual construction for precise control + timestamp_ms = int(self.base_timestamp.timestamp() * 1000) + offset_ms + self._counter += 1 + + # Create UUID7 with specific timestamp and counter for monotonicity + return uuid.UUID(int=(timestamp_ms << 80) | (0x7 << 76) | (self._counter & 0xFFF) << 64 | (0b10 << 62)) diff --git a/poetry.lock b/poetry.lock index 1ff9decd2597..66565ca9d1d3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "annotated-types" @@ -3178,7 +3178,6 @@ files = [ {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, - {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, @@ -3798,7 +3797,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4837,6 +4835,44 @@ files = [ [package.dependencies] ua-parser = ">=0.10.0" +[[package]] +name = "uuid-utils" +version = "0.11.0" +description = "Drop-in replacement for Python UUID with bindings in Rust" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "uuid_utils-0.11.0-cp39-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:094445ccd323bc5507e28e9d6d86b983513efcf19ab59c2dd75239cef765631a"}, + {file = "uuid_utils-0.11.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:6430b53d343215f85269ffd74e1d1f4b25ae1031acf0ac24ff3d5721f6a06f48"}, + {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be2e6e4318d23195887fa74fa1d64565a34f7127fdcf22918954981d79765f68"}, + {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d37289ab72aa30b5550bfa64d91431c62c89e4969bdf989988aa97f918d5f803"}, + {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1012595220f945fe09641f1365a8a06915bf432cac1b31ebd262944934a9b787"}, + {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:35cd3fc718a673e4516e87afb9325558969eca513aa734515b9031d1b651bbb1"}, + {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ed325e0c40e0f59ae82b347f534df954b50cedf12bf60d025625538530e1965d"}, + {file = "uuid_utils-0.11.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5c8b7cf201990ee3140956e541967bd556a7365ec738cb504b04187ad89c757a"}, + {file = "uuid_utils-0.11.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:9966df55bed5d538ba2e9cc40115796480f437f9007727116ef99dc2f42bd5fa"}, + {file = "uuid_utils-0.11.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:cb04b6c604968424b7e6398d54debbdd5b771b39fc1e648c6eabf3f1dc20582e"}, + {file = "uuid_utils-0.11.0-cp39-abi3-win32.whl", hash = "sha256:18420eb3316bb514f09f2da15750ac135478c3a12a704e2c5fb59eab642bb255"}, + {file = "uuid_utils-0.11.0-cp39-abi3-win_amd64.whl", hash = "sha256:37c4805af61a7cce899597d34e7c3dd5cb6a8b4b93a90fbca3826b071ba544df"}, + {file = "uuid_utils-0.11.0-cp39-abi3-win_arm64.whl", hash = "sha256:4065cf17bbe97f6d8ccc7dc6a0bae7d28fd4797d7f32028a5abd979aeb7bf7c9"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:db821c98a95f9d69ebf9c442bcf764548c4c5feebd6012a881233fcdc8f47ff4"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:07cd17ecef3bfdf319d8e6583334f4c8e71d9950503b69d6722999c88a42dbe2"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1b29c4aa76586c67e865548c862b0dee98359d59eda78b58d58290dd0dd240e"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05bfd684cb55825bc5d4c340bfce3a90009e662491e7bdfd5f667a367e0a11e4"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5455b145cb6f647888f3c4fd38ec177cf51479c73c6a44503d4b7a70f45d9870"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f51374cd3280e5a8c524c51ed09901cf2268907371e1b3dc59484a92e25f070a"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:691f576327836f93102f2bf8882eb67416452bab03c3dd8c31d009c4e85dd2aa"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:912e9ae2b5c2b72bd98046ee83e1b8fa22489b4a25f44495d1c0999fa6dde237"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ce73c719e0baebc8b1652e7663bec7d4db53edbd7be1affe92b1035fc80f409b"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f7f7e0245bcedbc4ff61ad4000fd661dc93677264c0566b31010d6da0b86a63"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9930137fd6d59c681f7e013ae9343b4b9d27f7e6efce4ecb259336e15ba578b8"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6f6a306878b2327b79d65bd18d5521ef8b3775c2b03a5054b1b6f602cd876cc3"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c8346b3688b2df0baae4d3ff47cd84c765aa57cf103077e32806d66f1fcd689"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c7a7f415edb5aea38bc53057c8aff4b31d35e192f2902f6ac10f2e52d3f52ae0"}, + {file = "uuid_utils-0.11.0.tar.gz", hash = "sha256:18cf2b7083da7f3cca0517647213129eb16d20d7ed0dd74b3f4f8bff2aa334ea"}, +] + [[package]] name = "uwsgitop" version = "0.12" @@ -5037,4 +5073,4 @@ uwsgi = ["pyuwsgi", "uwsgitop"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4" -content-hash = "80e7a01b2cc03373dc8940d4d78dc224369c7459f176ae356aede770f75cf2b7" +content-hash = "c5546f490290e9e98e63e5512bb4de0821b2291cbee62580dfab4cd45e9cc042" diff --git a/pyproject.toml b/pyproject.toml index 400393249891..38bb82f6ad40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ dependencies = [ "setuptools (>=75.4.0)", "djangorestframework-simplejwt[crypto] (>=5.4.0,<6.0.0)", "tldextract (>=5.1.3)", + "uuid-utils (>=0.11.0,<1.0.0)", ## HumanSignal repo dependencies :start "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/117bccc7cf9dc9e8ebcd74be8004670853da1146.zip", ## HumanSignal repo dependencies :end From 0260e6599ff1367fef890caf5d9c0524af0d139b Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Mon, 25 Aug 2025 16:54:25 -0500 Subject: [PATCH 02/83] include organization on base state model --- label_studio/fsm/README.md | 63 +------------------------------------- label_studio/fsm/models.py | 7 +++++ 2 files changed, 8 insertions(+), 62 deletions(-) diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md index f28150ca26d2..ad8854f847ea 100644 --- a/label_studio/fsm/README.md +++ b/label_studio/fsm/README.md @@ -8,8 +8,7 @@ The Label Studio FSM system provides: - **Core Infrastructure**: Base state tracking models and managers - **UUID7 Optimization**: Time-series optimized state records using UUID7 -- **Extension Mechanism**: Allows Label Studio Enterprise to extend functionality -- **Basic API**: REST endpoints for state management +- **REST API**: Endpoints for state management - **Admin Interface**: Django admin integration for state inspection ## Architecture @@ -19,30 +18,6 @@ The Label Studio FSM system provides: 1. **BaseState**: Abstract model providing common state tracking functionality 2. **StateManager**: High-performance state management with caching 3. **Core State Models**: Task, Annotation, and Project state tracking -4. **Extension Registry**: Allows enterprise extensions to register additional functionality - -### Extension System - -The FSM system is designed to be extended by Label Studio Enterprise: - -```python -# Core provides foundation -from label_studio.fsm.models import BaseState -from label_studio.fsm.state_manager import StateManager - -# Enterprise extends with advanced features -class EnterpriseTaskState(BaseState): - # Additional enterprise-specific fields - organization_id = models.PositiveIntegerField(db_index=True) - # Advanced indexes and denormalization - -class EnterpriseStateManager(StateManager): - # Bulk operations, advanced caching, etc. - @classmethod - def bulk_get_states(cls, entities): - # Enterprise-specific bulk optimization - pass -``` ## Usage @@ -142,10 +117,6 @@ INSTALLED_APPS = [ FSM_CACHE_TTL = 300 # Cache timeout in seconds (default: 300) FSM_AUTO_CREATE_STATES = False # Auto-create states on entity creation (default: False) FSM_STATE_MANAGER_CLASS = None # Custom state manager class (default: None) - -# Enterprise Settings (when using Label Studio Enterprise) -FSM_ENABLE_BULK_OPERATIONS = True # Enable bulk operations (default: False) -FSM_CACHE_STATS_ENABLED = True # Enable cache statistics (default: False) ``` ## Database Migrations @@ -184,35 +155,6 @@ Critical indexes for performance: - `(entity_id, id DESC)`: Current state lookup using UUID7 ordering - `(entity_id, id)`: State history queries -## Extension by Label Studio Enterprise - -Label Studio Enterprise extends this system with: - -1. **Advanced State Models**: Additional entities (Reviews, Assignments, etc.) -2. **Complex Workflows**: Review, arbitration, and approval flows -3. **Bulk Operations**: High-performance batch state transitions -4. **Enhanced Caching**: Multi-level caching with cache warming -5. **Analytics**: State-based reporting and metrics -6. **Denormalization**: Performance optimization with redundant fields - -### Enterprise Extension Example - -```python -# In Label Studio Enterprise -from label_studio.fsm.extension import BaseFSMExtension -from label_studio.fsm.models import register_state_model - -class EnterpriseExtension(BaseFSMExtension): - @classmethod - def initialize(cls): - # Register enterprise models - register_state_model('review', AnnotationReviewState) - register_state_model('assignment', TaskAssignmentState) - - @classmethod - def get_state_manager(cls): - return EnterpriseStateManager -``` ## Monitoring and Debugging @@ -230,9 +172,6 @@ FSM operations are logged at appropriate levels: - `ERROR`: Failed transitions and system errors - `DEBUG`: Cache hits/misses and detailed operation info -### Cache Statistics - -When `FSM_CACHE_STATS_ENABLED=True`, cache performance metrics are available for monitoring. ## Migration from Existing Systems diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index 901e64559d56..a37ffde4eb20 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -54,6 +54,13 @@ class TaskState(BaseState): help_text='UUID7 provides natural time ordering and global uniqueness', ) + organization = models.ForeignKey( + 'organizations.Organization', + on_delete=models.CASCADE, + null=True, + help_text='Organization which owns this state record', + ) + # Core State Fields state = models.CharField(max_length=50, db_index=True, help_text='Current state of the entity') previous_state = models.CharField( From f3596079ce5ebb065f004da209da79389bb3a511 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Tue, 26 Aug 2025 09:26:25 -0500 Subject: [PATCH 03/83] refactoring core definitions --- label_studio/core/settings/base.py | 2 +- label_studio/fsm/admin.py | 10 - label_studio/fsm/api.py | 31 -- label_studio/fsm/apps.py | 20 +- label_studio/fsm/extension.py | 310 ++------------------ label_studio/fsm/integration.py | 12 - label_studio/fsm/migrations/0001_initial.py | 306 ------------------- label_studio/fsm/models.py | 31 -- label_studio/fsm/serializers.py | 7 +- label_studio/fsm/state_choices.py | 17 -- 10 files changed, 32 insertions(+), 714 deletions(-) delete mode 100644 label_studio/fsm/migrations/0001_initial.py diff --git a/label_studio/core/settings/base.py b/label_studio/core/settings/base.py index 63c644c63a87..8d9f0f67649f 100644 --- a/label_studio/core/settings/base.py +++ b/label_studio/core/settings/base.py @@ -232,7 +232,7 @@ 'ml_model_providers', 'jwt_auth', 'session_policy', - 'fsm', # Finite State Machine for entity state tracking + 'fsm', ] MIDDLEWARE = [ diff --git a/label_studio/fsm/admin.py b/label_studio/fsm/admin.py index a55e1a7510ec..b9b027853f1f 100644 --- a/label_studio/fsm/admin.py +++ b/label_studio/fsm/admin.py @@ -2,7 +2,6 @@ Core FSM admin interface for Label Studio. Provides basic admin interface for state management that can be extended -by Label Studio Enterprise with additional functionality. """ from django.contrib import admin @@ -16,7 +15,6 @@ class BaseStateAdmin(admin.ModelAdmin): Base admin for state models. Provides common admin interface functionality for all state models. - Enterprise can extend this with additional features. """ list_display = [ @@ -152,14 +150,9 @@ def project_title(self, obj): project_title.admin_order_field = 'project__title' -# Admin actions for bulk operations (Enterprise can extend these) - - def mark_states_as_reviewed(modeladmin, request, queryset): """ Admin action to mark state records as reviewed. - - This is a placeholder that Enterprise can extend with actual functionality. """ count = queryset.count() modeladmin.message_user(request, f'{count} state records marked as reviewed.') @@ -171,8 +164,6 @@ def mark_states_as_reviewed(modeladmin, request, queryset): def export_state_history(modeladmin, request, queryset): """ Admin action to export state history. - - This is a placeholder that Enterprise can extend with actual export functionality. """ count = queryset.count() modeladmin.message_user(request, f'Export initiated for {count} state records.') @@ -181,5 +172,4 @@ def export_state_history(modeladmin, request, queryset): export_state_history.short_description = 'Export state history' -# Add actions to base admin (Enterprise can override) BaseStateAdmin.actions = [mark_states_as_reviewed, export_state_history] diff --git a/label_studio/fsm/api.py b/label_studio/fsm/api.py index 133c96e85ea7..c57e14722ed3 100644 --- a/label_studio/fsm/api.py +++ b/label_studio/fsm/api.py @@ -2,7 +2,6 @@ Core FSM API endpoints for Label Studio. Provides basic API endpoints for state management that can be extended -by Label Studio Enterprise with additional functionality. """ import logging @@ -30,9 +29,6 @@ class FSMViewSet(viewsets.ViewSet): - Get current state - Get state history - Trigger state transitions - - Label Studio Enterprise can extend this with additional endpoints - for advanced state management operations. """ permission_classes = [AllPermissions] @@ -228,30 +224,3 @@ def transition_state(self, request, entity_type=None, entity_id=None): except Exception as e: logger.error(f'Error transitioning state for {entity_type} {entity_id}: {e}') return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - -# Extension point for Label Studio Enterprise -class ExtendedFSMViewSet(FSMViewSet): - """ - Extension point for Label Studio Enterprise. - - Enterprise can override this class to add: - - Bulk state operations - - Advanced state queries - - Enterprise-specific endpoints - - Enhanced permissions and validation - - Example Enterprise usage: - class EnterpriseFSMViewSet(ExtendedFSMViewSet): - @action(detail=False, methods=['post']) - def bulk_transition(self, request): - # Enterprise bulk transition endpoint - pass - - @action(detail=False, methods=['get']) - def state_analytics(self, request): - # Enterprise state analytics endpoint - pass - """ - - pass diff --git a/label_studio/fsm/apps.py b/label_studio/fsm/apps.py index a2fd89e66cf3..9290e8ddd289 100644 --- a/label_studio/fsm/apps.py +++ b/label_studio/fsm/apps.py @@ -8,9 +8,9 @@ class FsmConfig(AppConfig): - default_auto_field = 'django.db.models.BigAutoField' + default_auto_field = 'django.db.models.UUIDField' name = 'label_studio.fsm' - verbose_name = 'Finite State Machine' + verbose_name = 'Label Studio FSM' def ready(self): """Initialize FSM system when Django app is ready""" @@ -25,18 +25,10 @@ def ready(self): def _initialize_extensions(self): """Initialize FSM extension system""" try: - from .extension import ( - auto_register_enterprise_choices, - auto_register_enterprise_models, - extension_registry, - ) - - # Load configured extensions - extension_registry.load_extensions() - - # Auto-register enterprise models if available - auto_register_enterprise_models() - auto_register_enterprise_choices() + # Import the extension registry to ensure it's initialized + + # Basic extension system is ready + logger.debug('FSM extension system ready') except Exception as e: logger.error(f'Failed to initialize FSM extensions: {e}') diff --git a/label_studio/fsm/extension.py b/label_studio/fsm/extension.py index 6c8f2e3eebc9..f8dd243dc9c9 100644 --- a/label_studio/fsm/extension.py +++ b/label_studio/fsm/extension.py @@ -1,324 +1,60 @@ """ -Extension mechanism for Label Studio Enterprise FSM. - -This module provides the hooks and interfaces that allow Label Studio Enterprise -to extend the core FSM functionality with advanced features. +Minimal extension hooks for Label Studio FSM. """ import logging -from typing import Any, Dict, List, Type - -from django.apps import apps -from django.conf import settings - -from .models import BaseState, register_state_model -from .state_choices import register_state_choices -from .state_manager import StateManager logger = logging.getLogger(__name__) -class FSMExtensionRegistry: - """ - Registry for FSM extensions that allows enterprise features to be - dynamically loaded and integrated with the core FSM system. - """ - - def __init__(self): - self._extensions = {} - self._state_managers = {} - self._loaded = False - - def register_extension(self, name: str, extension_class): - """ - Register an FSM extension. - - Args: - name: Unique name for the extension - extension_class: Class implementing the extension - """ - self._extensions[name] = extension_class - logger.info(f'Registered FSM extension: {name}') - - def get_extension(self, name: str): - """Get a registered extension by name""" - return self._extensions.get(name) - - def load_extensions(self): - """ - Load FSM extensions from settings. - - Called during Django app startup to load enterprise extensions. - """ - if self._loaded: - return - - extensions_setting = getattr(settings, 'FSM_EXTENSIONS', []) - for extension_config in extensions_setting: - self._load_extension(extension_config) - - self._loaded = True - logger.info(f'Loaded {len(self._extensions)} FSM extensions') - - def _load_extension(self, config: Dict[str, Any]): - """Load a single extension from configuration""" - try: - name = config['name'] - class_path = config['class'] - - # Import the extension class - module_name, class_name = class_path.rsplit('.', 1) - module = __import__(module_name, fromlist=[class_name]) - extension_class = getattr(module, class_name) - - # Register the extension - self.register_extension(name, extension_class) - - # Initialize the extension if it has an init method - if hasattr(extension_class, 'initialize'): - extension_class.initialize() - - except Exception as e: - logger.error(f'Failed to load FSM extension {config}: {e}') - - -# Global extension registry -extension_registry = FSMExtensionRegistry() - - class BaseFSMExtension: """ - Base class for FSM extensions. + Minimal base class for FSM extensions. - Enterprise extensions should inherit from this class to ensure - compatibility with the core FSM system. + This provides the interface that extensions should implement. """ @classmethod def initialize(cls): - """ - Initialize the extension. - - Called when the extension is loaded. Override to perform - setup tasks like registering state models and choices. - """ + """Initialize the extension.""" pass @classmethod def register_models(cls): - """ - Register state models with the core FSM system. - - Override to register enterprise-specific state models. - - Example: - register_state_model('review', AnnotationReviewState) - register_state_model('assignment', TaskAssignmentState) - """ + """Register state models with the core FSM system.""" pass @classmethod def register_choices(cls): - """ - Register state choices with the core FSM system. - - Override to register enterprise-specific state choices. - - Example: - register_state_choices('review', ReviewStateChoices) - register_state_choices('assignment', AssignmentStateChoices) - """ + """Register state choices with the core FSM system.""" pass @classmethod - def get_state_manager(cls) -> Type[StateManager]: - """ - Get the state manager class for this extension. - - Override to provide enterprise-specific state manager. + def get_state_manager(cls): + """Get the state manager class for this extension.""" + from .state_manager import StateManager - Returns: - StateManager class to use - """ return StateManager -class EnterpriseExtensionMixin: +# Extension registry for compatibility +class ExtensionRegistry: """ - Mixin for enterprise extensions that provides common functionality - for extending the core FSM system. + Extension registry for core Label Studio. """ - @classmethod - def extend_state_model(cls, entity_name: str, base_model_class: Type[BaseState]): - """ - Helper to create extended state models. - - Args: - entity_name: Name of the entity (e.g., 'task', 'annotation') - base_model_class: Base state model class to extend - - Returns: - Extended model class - """ - # This would be used by enterprise to add denormalized fields, - # additional indexes, and enterprise-specific functionality - pass - - @classmethod - def extend_state_choices(cls, base_choices_class, additional_choices: List[tuple]): - """ - Helper to extend state choices with additional states. - - Args: - base_choices_class: Base TextChoices class - additional_choices: List of (value, label) tuples for new states - - Returns: - Extended choices class - """ - # This would be used by enterprise to add additional states - # to the core state choices - pass - - -# Configuration helpers for enterprise setup - - -def configure_fsm_for_enterprise(): - """ - Configure FSM system for Label Studio Enterprise. - - This function should be called by enterprise during app initialization - to set up the FSM system with enterprise-specific configuration. - """ - # Load enterprise extensions - extension_registry.load_extensions() - - # Set enterprise-specific settings - if not hasattr(settings, 'FSM_CACHE_TTL'): - settings.FSM_CACHE_TTL = 300 # 5 minutes - - if not hasattr(settings, 'FSM_ENABLE_BULK_OPERATIONS'): - settings.FSM_ENABLE_BULK_OPERATIONS = True - - logger.info('FSM system configured for Label Studio Enterprise') - - -def get_enterprise_state_manager(): - """ - Get the enterprise state manager if available. - - Returns the enterprise-specific state manager class if one is registered, - otherwise returns the core StateManager. - """ - # Check if enterprise has registered a state manager - enterprise_ext = extension_registry.get_extension('enterprise') - if enterprise_ext: - return enterprise_ext.get_state_manager() - - # Fall back to core state manager - return StateManager - - -# Settings for FSM extensions -def get_fsm_settings(): - """Get FSM-related settings with defaults""" - return { - 'cache_ttl': getattr(settings, 'FSM_CACHE_TTL', 300), - 'enable_bulk_operations': getattr(settings, 'FSM_ENABLE_BULK_OPERATIONS', False), - 'enable_cache_stats': getattr(settings, 'FSM_CACHE_STATS_ENABLED', False), - 'state_manager_class': getattr(settings, 'FSM_STATE_MANAGER_CLASS', None), - 'extensions': getattr(settings, 'FSM_EXTENSIONS', []), - } - - -# Integration helpers for model registration - - -def auto_register_enterprise_models(): - """ - Automatically register enterprise state models. - - Scans for state models in enterprise apps and registers them - with the core FSM system. - """ - try: - # Only attempt if enterprise is available - if apps.is_installed('label_studio_enterprise.fsm'): - from label_studio_enterprise.fsm.models import ( - AnnotationDraftState, - AnnotationReviewState, - CommentState, - TaskAssignmentState, - TaskLockState, - ) - from label_studio_enterprise.fsm.models import ( - AnnotationState as EnterpriseAnnotationState, - ) - from label_studio_enterprise.fsm.models import ( - ProjectState as EnterpriseProjectState, - ) - from label_studio_enterprise.fsm.models import ( - TaskState as EnterpriseTaskState, - ) - - # Register enterprise state models - register_state_model('task', EnterpriseTaskState) - register_state_model('annotation', EnterpriseAnnotationState) - register_state_model('project', EnterpriseProjectState) - register_state_model('annotationreview', AnnotationReviewState) - register_state_model('taskassignment', TaskAssignmentState) - register_state_model('annotationdraft', AnnotationDraftState) - register_state_model('comment', CommentState) - register_state_model('tasklock', TaskLockState) - - logger.info('Auto-registered enterprise state models') - - except ImportError: - # Enterprise not available, use core models - logger.debug('Enterprise FSM models not available, using core models') - - -def auto_register_enterprise_choices(): - """ - Automatically register enterprise state choices. + def __init__(self): + self._extensions = {} - Scans for state choices in enterprise apps and registers them - with the core FSM system. - """ - try: - # Only attempt if enterprise is available - if apps.is_installed('label_studio_enterprise.fsm'): - from label_studio_enterprise.fsm.state_choices import ( - AnnotationDraftStateChoices, - AssignmentStateChoices, - CommentStateChoices, - ReviewStateChoices, - TaskLockStateChoices, - ) - from label_studio_enterprise.fsm.state_choices import ( - AnnotationStateChoices as EnterpriseAnnotationStateChoices, - ) - from label_studio_enterprise.fsm.state_choices import ( - ProjectStateChoices as EnterpriseProjectStateChoices, - ) - from label_studio_enterprise.fsm.state_choices import ( - TaskStateChoices as EnterpriseTaskStateChoices, - ) + def register_extension(self, name: str, extension_class): + """Register an extension.""" + self._extensions[name] = extension_class + logger.debug(f'Registered FSM extension: {name}') - # Register enterprise state choices - register_state_choices('task', EnterpriseTaskStateChoices) - register_state_choices('annotation', EnterpriseAnnotationStateChoices) - register_state_choices('project', EnterpriseProjectStateChoices) - register_state_choices('review', ReviewStateChoices) - register_state_choices('assignment', AssignmentStateChoices) - register_state_choices('annotationdraft', AnnotationDraftStateChoices) - register_state_choices('comment', CommentStateChoices) - register_state_choices('tasklock', TaskLockStateChoices) + def get_extension(self, name: str): + """Get a registered extension by name.""" + return self._extensions.get(name) - logger.info('Auto-registered enterprise state choices') - except ImportError: - # Enterprise not available, use core choices - logger.debug('Enterprise FSM choices not available, using core choices') +# Global minimal registry +extension_registry = ExtensionRegistry() diff --git a/label_studio/fsm/integration.py b/label_studio/fsm/integration.py index 0ae365bb7d79..4850c40f74b3 100644 --- a/label_studio/fsm/integration.py +++ b/label_studio/fsm/integration.py @@ -21,15 +21,6 @@ class FSMIntegrationMixin: This mixin can be added to Task, Annotation, and Project models to provide convenient methods for state management without modifying the core models. - - Example usage in Enterprise: - # In LSE models.py: - from label_studio.fsm.integration import FSMIntegrationMixin - from label_studio.tasks.models import Task as CoreTask - - class Task(FSMIntegrationMixin, CoreTask): - class Meta: - proxy = True """ @property @@ -237,7 +228,6 @@ def get_entities_by_state(model_class, state: str, limit: int = 100): # Get entity IDs that have the specified current state f'{model_class._meta.model_name.lower()}_id' - # This is a simplified version - Enterprise can optimize with window functions current_state_subquery = ( state_model.objects.filter(**{f'{model_class._meta.model_name.lower()}__pk': models.OuterRef('pk')}) .order_by('-id') @@ -253,8 +243,6 @@ def bulk_transition_entities(entities, new_state: str, user=None, **kwargs): """ Bulk transition multiple entities to the same state. - Basic implementation that Enterprise can optimize with bulk operations. - Args: entities: List of entity instances new_state: Target state for all entities diff --git a/label_studio/fsm/migrations/0001_initial.py b/label_studio/fsm/migrations/0001_initial.py deleted file mode 100644 index 9cc43da2a8e7..000000000000 --- a/label_studio/fsm/migrations/0001_initial.py +++ /dev/null @@ -1,306 +0,0 @@ -# Generated by Django 4.2.16 on 2024-01-15 12:00 - -import django.db.models.deletion -from django.conf import settings -from django.db import migrations, models - -import label_studio.fsm.utils - - -class Migration(migrations.Migration): - """ - Initial migration for core FSM functionality in Label Studio. - - Creates the base state tracking infrastructure with UUID7 optimization - for high-performance time-series data. - """ - - initial = True - - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ('tasks', '0055_task_proj_octlen_idx_async'), # Latest task migration - ('projects', '0030_project_search_vector_index'), # Latest project migration - ] - - operations = [ - migrations.CreateModel( - name='TaskState', - fields=[ - ( - 'id', - models.UUIDField( - default=label_studio.fsm.utils.generate_uuid7, - editable=False, - help_text='UUID7 provides natural time ordering and global uniqueness', - primary_key=True, - serialize=False, - ), - ), - ('state', models.CharField( - choices=[ - ('CREATED', 'Created'), - ('IN_PROGRESS', 'In Progress'), - ('COMPLETED', 'Completed') - ], - db_index=True, - help_text='Current state of the entity', - max_length=50 - )), - ( - 'previous_state', - models.CharField( - blank=True, - help_text='Previous state before this transition', - max_length=50, - null=True, - ), - ), - ( - 'transition_name', - models.CharField( - blank=True, - help_text='Name of the transition method that triggered this state change', - max_length=100, - null=True, - ), - ), - ( - 'context_data', - models.JSONField( - default=dict, - help_text='Additional context data for this transition (e.g., validation results, external IDs)', - ), - ), - ( - 'reason', - models.TextField( - blank=True, help_text='Human-readable reason for this state transition' - ), - ), - ( - 'created_at', - models.DateTimeField( - auto_now_add=True, - db_index=False, - help_text='Human-readable timestamp for debugging (UUID7 id contains precise timestamp)', - ), - ), - ( - 'task', - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name='fsm_states', - to='tasks.task', - ), - ), - ( - 'triggered_by', - models.ForeignKey( - help_text='User who triggered this state transition', - null=True, - on_delete=django.db.models.deletion.SET_NULL, - to=settings.AUTH_USER_MODEL, - ), - ), - ], - options={ - 'db_table': 'fsm_task_states', - 'ordering': ['-id'], - 'get_latest_by': 'id', - }, - ), - migrations.CreateModel( - name='ProjectState', - fields=[ - ( - 'id', - models.UUIDField( - default=label_studio.fsm.utils.generate_uuid7, - editable=False, - help_text='UUID7 provides natural time ordering and global uniqueness', - primary_key=True, - serialize=False, - ), - ), - ('state', models.CharField( - choices=[ - ('CREATED', 'Created'), - ('PUBLISHED', 'Published'), - ('IN_PROGRESS', 'In Progress'), - ('COMPLETED', 'Completed') - ], - db_index=True, - help_text='Current state of the entity', - max_length=50 - )), - ( - 'previous_state', - models.CharField( - blank=True, - help_text='Previous state before this transition', - max_length=50, - null=True, - ), - ), - ( - 'transition_name', - models.CharField( - blank=True, - help_text='Name of the transition method that triggered this state change', - max_length=100, - null=True, - ), - ), - ( - 'context_data', - models.JSONField( - default=dict, - help_text='Additional context data for this transition (e.g., validation results, external IDs)', - ), - ), - ( - 'reason', - models.TextField( - blank=True, help_text='Human-readable reason for this state transition' - ), - ), - ( - 'created_at', - models.DateTimeField( - auto_now_add=True, - db_index=False, - help_text='Human-readable timestamp for debugging (UUID7 id contains precise timestamp)', - ), - ), - ( - 'project', - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name='fsm_states', - to='projects.project', - ), - ), - ( - 'triggered_by', - models.ForeignKey( - help_text='User who triggered this state transition', - null=True, - on_delete=django.db.models.deletion.SET_NULL, - to=settings.AUTH_USER_MODEL, - ), - ), - ], - options={ - 'db_table': 'fsm_project_states', - 'ordering': ['-id'], - 'get_latest_by': 'id', - }, - ), - migrations.CreateModel( - name='AnnotationState', - fields=[ - ( - 'id', - models.UUIDField( - default=label_studio.fsm.utils.generate_uuid7, - editable=False, - help_text='UUID7 provides natural time ordering and global uniqueness', - primary_key=True, - serialize=False, - ), - ), - ('state', models.CharField( - choices=[ - ('DRAFT', 'Draft'), - ('SUBMITTED', 'Submitted'), - ('COMPLETED', 'Completed') - ], - db_index=True, - help_text='Current state of the entity', - max_length=50 - )), - ( - 'previous_state', - models.CharField( - blank=True, - help_text='Previous state before this transition', - max_length=50, - null=True, - ), - ), - ( - 'transition_name', - models.CharField( - blank=True, - help_text='Name of the transition method that triggered this state change', - max_length=100, - null=True, - ), - ), - ( - 'context_data', - models.JSONField( - default=dict, - help_text='Additional context data for this transition (e.g., validation results, external IDs)', - ), - ), - ( - 'reason', - models.TextField( - blank=True, help_text='Human-readable reason for this state transition' - ), - ), - ( - 'created_at', - models.DateTimeField( - auto_now_add=True, - db_index=False, - help_text='Human-readable timestamp for debugging (UUID7 id contains precise timestamp)', - ), - ), - ( - 'annotation', - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name='fsm_states', - to='tasks.annotation', - ), - ), - ( - 'triggered_by', - models.ForeignKey( - help_text='User who triggered this state transition', - null=True, - on_delete=django.db.models.deletion.SET_NULL, - to=settings.AUTH_USER_MODEL, - ), - ), - ], - options={ - 'db_table': 'fsm_annotation_states', - 'ordering': ['-id'], - 'get_latest_by': 'id', - }, - ), - # Create indexes for optimal performance with UUID7 - migrations.RunSQL( - sql=[ - # Task state indexes - critical for current state lookups - "CREATE INDEX CONCURRENTLY IF NOT EXISTS fsm_task_current_state_idx ON fsm_task_states (task_id, id DESC);", - "CREATE INDEX CONCURRENTLY IF NOT EXISTS fsm_task_history_idx ON fsm_task_states (task_id, id);", - - # Annotation state indexes - "CREATE INDEX CONCURRENTLY IF NOT EXISTS fsm_anno_current_state_idx ON fsm_annotation_states (annotation_id, id DESC);", - - # Project state indexes - "CREATE INDEX CONCURRENTLY IF NOT EXISTS fsm_proj_current_state_idx ON fsm_project_states (project_id, id DESC);", - ], - reverse_sql=[ - "DROP INDEX IF EXISTS fsm_task_current_state_idx;", - "DROP INDEX IF EXISTS fsm_task_history_idx;", - "DROP INDEX IF EXISTS fsm_anno_current_state_idx;", - "DROP INDEX IF EXISTS fsm_proj_current_state_idx;", - ] - ), - ] \ No newline at end of file diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index a37ffde4eb20..31c165dca389 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -1,8 +1,5 @@ """ Core FSM models for Label Studio. - -Provides the base infrastructure for state tracking that can be extended -by Label Studio Enterprise and other applications. """ from datetime import datetime @@ -37,13 +34,6 @@ class BaseState(models.Model): - Global uniqueness enables distributed system support - Time-based partitioning for billion-record scalability - Complete audit trail by design - - Usage: - # In Label Studio Enterprise: - class TaskState(BaseState): - task = models.ForeignKey('tasks.Task', ...) - state = models.CharField(choices=EnterpriseTaskStateChoices.choices, ...) - # Additional enterprise-specific fields """ # UUID7 Primary Key - provides natural time ordering and global uniqueness @@ -183,7 +173,6 @@ def _get_entity_field_name(cls) -> str: # Core state models for basic Label Studio entities -# These provide the foundation that Enterprise can extend class TaskState(BaseState): @@ -193,12 +182,6 @@ class TaskState(BaseState): Provides basic task state management with: - Simple 3-state workflow (CREATED → IN_PROGRESS → COMPLETED) - High-performance queries with UUID7 ordering - - Extensible design for enterprise features - - Label Studio Enterprise extends this with: - - Additional workflow states (review, arbitration) - - Denormalized fields for performance - - Advanced state transition logic """ # Entity Relationship @@ -230,7 +213,6 @@ class AnnotationState(BaseState): Provides basic annotation state management with: - Simple 3-state workflow (DRAFT → SUBMITTED → COMPLETED) - Draft and submission tracking - - Extensible design for enterprise review workflows """ # Entity Relationship @@ -260,7 +242,6 @@ class ProjectState(BaseState): Provides basic project state management with: - Simple 4-state workflow (CREATED → PUBLISHED → IN_PROGRESS → COMPLETED) - Project lifecycle tracking - - Extensible design for enterprise features """ # Entity Relationship @@ -284,7 +265,6 @@ def is_terminal_state(self) -> bool: # Registry for dynamic state model extension -# Enterprise can register additional state models here STATE_MODEL_REGISTRY = { 'task': TaskState, 'annotation': AnnotationState, @@ -296,20 +276,9 @@ def register_state_model(entity_name: str, model_class): """ Register state model for an entity type. - This allows Label Studio Enterprise to register additional state models - or override existing ones with enterprise-specific implementations. - Args: entity_name: Name of the entity (e.g., 'review', 'assignment') model_class: Django model class inheriting from BaseState - - Example: - # In LSE code: - register_state_model('review', AnnotationReviewState) - register_state_model('assignment', TaskAssignmentState) - - # Override core model with enterprise version: - register_state_model('task', EnterpriseTaskState) """ STATE_MODEL_REGISTRY[entity_name.lower()] = model_class diff --git a/label_studio/fsm/serializers.py b/label_studio/fsm/serializers.py index 9f2b920d7d71..823df832d3b0 100644 --- a/label_studio/fsm/serializers.py +++ b/label_studio/fsm/serializers.py @@ -1,8 +1,7 @@ """ Core FSM serializers for Label Studio. -Provides basic serializers for state management API that can be extended -by Label Studio Enterprise with additional functionality. +Provides basic serializers for state management API """ from rest_framework import serializers @@ -12,8 +11,7 @@ class StateHistorySerializer(serializers.Serializer): """ Serializer for state history records. - Provides basic state history information that can be extended - by Enterprise with additional fields. + Provides basic state history information """ id = serializers.UUIDField(read_only=True) @@ -80,7 +78,6 @@ class StateInfoSerializer(serializers.Serializer): entity_type = serializers.CharField() entity_id = serializers.IntegerField() - # Optional fields that Enterprise can populate available_transitions = serializers.ListField( child=serializers.CharField(), required=False, help_text='List of valid transitions from current state' ) diff --git a/label_studio/fsm/state_choices.py b/label_studio/fsm/state_choices.py index 428d5b71b2a4..a519a0ee7a70 100644 --- a/label_studio/fsm/state_choices.py +++ b/label_studio/fsm/state_choices.py @@ -2,8 +2,6 @@ Core state choice enums for Label Studio entities. These enums define the essential states for core Label Studio entities. -Label Studio Enterprise can extend these with additional states or -define entirely new state enums for enterprise-specific entities. """ from django.db import models @@ -18,8 +16,6 @@ class TaskStateChoices(models.TextChoices): - Creation and assignment - Annotation work - Completion - - Enterprise can extend with review, arbitration, and advanced workflow states. """ # Initial State @@ -40,8 +36,6 @@ class AnnotationStateChoices(models.TextChoices): - Draft work - Submission - Completion - - Enterprise can extend with review, approval, and rejection states. """ # Working States @@ -60,8 +54,6 @@ class ProjectStateChoices(models.TextChoices): - Setup and configuration - Active work - Completion - - Enterprise can extend with advanced workflow, review, and approval states. """ # Setup States @@ -76,7 +68,6 @@ class ProjectStateChoices(models.TextChoices): # Registry for dynamic state choices extension -# Enterprise can register additional choices here STATE_CHOICES_REGISTRY = { 'task': TaskStateChoices, 'annotation': AnnotationStateChoices, @@ -88,17 +79,9 @@ def register_state_choices(entity_name: str, choices_class): """ Register state choices for an entity type. - This allows Label Studio Enterprise and other extensions to register - their own state choices dynamically. - Args: entity_name: Name of the entity (e.g., 'review', 'assignment') choices_class: Django TextChoices class defining valid states - - Example: - # In LSE code: - register_state_choices('review', ReviewStateChoices) - register_state_choices('assignment', AssignmentStateChoices) """ STATE_CHOICES_REGISTRY[entity_name.lower()] = choices_class From a475d633f8778c8a93253ea247a7b847b2b2266b Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Tue, 26 Aug 2025 11:17:43 -0500 Subject: [PATCH 04/83] define core entity states in full --- label_studio/fsm/models.py | 54 ++++++++++++++++++++++++------- label_studio/fsm/state_choices.py | 1 - 2 files changed, 43 insertions(+), 12 deletions(-) diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index 31c165dca389..0f38cf567b0f 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -185,19 +185,27 @@ class TaskState(BaseState): """ # Entity Relationship - task = models.ForeignKey('tasks.Task', on_delete=models.CASCADE, related_name='fsm_states') + task = models.ForeignKey('tasks.Task', related_name='fsm_states', on_delete=models.CASCADE, db_index=True) # Override state field to add choices constraint state = models.CharField(max_length=50, choices=TaskStateChoices.choices, db_index=True) + project_id = models.PositiveIntegerField( + db_index=True, help_text='From task.project_id - denormalized for performance' + ) + class Meta: - db_table = 'fsm_task_states' indexes = [ - # Critical: Latest state lookup using UUID7 ordering - models.Index(fields=['task_id', '-id'], name='fsm_task_current_state_idx'), + # Critical: Latest state lookup (current state determined by latest UUID7 id) + # Index with DESC order explicitly supports ORDER BY id DESC queries + models.Index(fields=['task_id', '-id'], name='task_current_state_idx'), + # Reporting and filtering + models.Index(fields=['project_id', 'state', '-id'], name='task_project_state_idx'), + models.Index(fields=['organization_id', 'state', '-id'], name='task_org_reporting_idx'), # History queries - models.Index(fields=['task_id', 'id'], name='fsm_task_history_idx'), + models.Index(fields=['task_id', 'id'], name='task_history_idx'), ] + # No constraints needed - INSERT-only approach ordering = ['-id'] @property @@ -212,7 +220,6 @@ class AnnotationState(BaseState): Provides basic annotation state management with: - Simple 3-state workflow (DRAFT → SUBMITTED → COMPLETED) - - Draft and submission tracking """ # Entity Relationship @@ -221,11 +228,25 @@ class AnnotationState(BaseState): # Override state field to add choices constraint state = models.CharField(max_length=50, choices=AnnotationStateChoices.choices, db_index=True) + # Denormalized fields for performance (avoid JOINs in common queries) + task_id = models.PositiveIntegerField( + db_index=True, help_text='From annotation.task_id - denormalized for performance' + ) + project_id = models.PositiveIntegerField( + db_index=True, help_text='From annotation.task.project_id - denormalized for performance' + ) + completed_by_id = models.PositiveIntegerField( + null=True, db_index=True, help_text='From annotation.completed_by_id - denormalized for performance' + ) + class Meta: - db_table = 'fsm_annotation_states' indexes = [ # Critical: Latest state lookup - models.Index(fields=['annotation_id', '-id'], name='fsm_anno_current_state_idx'), + models.Index(fields=['annotation_id', '-id'], name='anno_current_state_idx'), + # Filtering and reporting + models.Index(fields=['task_id', 'state', '-id'], name='anno_task_state_idx'), + models.Index(fields=['completed_by_id', 'state', '-id'], name='anno_user_report_idx'), + models.Index(fields=['project_id', 'state', '-id'], name='anno_project_report_idx'), ] ordering = ['-id'] @@ -240,21 +261,32 @@ class ProjectState(BaseState): Core project state tracking for Label Studio. Provides basic project state management with: - - Simple 4-state workflow (CREATED → PUBLISHED → IN_PROGRESS → COMPLETED) + - Simple 3-state workflow (CREATED → IN_PROGRESS → COMPLETED) - Project lifecycle tracking """ # Entity Relationship - project = models.ForeignKey('projects.Project', on_delete=models.CASCADE, related_name='fsm_states') + project = models.ForeignKey('projects.Project', on_delete=models.CASCADE, related_name='states') # Override state field to add choices constraint state = models.CharField(max_length=50, choices=ProjectStateChoices.choices, db_index=True) + # Denormalized fields for performance (avoid JOINs in common queries) + organization_id = models.PositiveIntegerField( + db_index=True, help_text='From project.organization_id - denormalized for performance' + ) + created_by_id = models.PositiveIntegerField( + null=True, db_index=True, help_text='From project.created_by_id - denormalized for performance' + ) + class Meta: db_table = 'fsm_project_states' indexes = [ # Critical: Latest state lookup - models.Index(fields=['project_id', '-id'], name='fsm_proj_current_state_idx'), + models.Index(fields=['project_id', '-id'], name='project_current_state_idx'), + # Filtering and reporting + models.Index(fields=['organization_id', 'state', '-id'], name='project_org_state_idx'), + models.Index(fields=['organization_id', '-id'], name='project_org_reporting_idx'), ] ordering = ['-id'] diff --git a/label_studio/fsm/state_choices.py b/label_studio/fsm/state_choices.py index a519a0ee7a70..ba8304c777ec 100644 --- a/label_studio/fsm/state_choices.py +++ b/label_studio/fsm/state_choices.py @@ -58,7 +58,6 @@ class ProjectStateChoices(models.TextChoices): # Setup States CREATED = 'CREATED', _('Created') - PUBLISHED = 'PUBLISHED', _('Published') # Work States IN_PROGRESS = 'IN_PROGRESS', _('In Progress') From 0c31d1f07b1f42ad4a88f85aa3c43dcb6f8b9317 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Tue, 26 Aug 2025 11:44:28 -0500 Subject: [PATCH 05/83] regen migrations for fsm core --- label_studio/fsm/apps.py | 2 +- label_studio/fsm/migrations/0001_initial.py | 386 ++++++++++++++++++++ label_studio/fsm/models.py | 5 - 3 files changed, 387 insertions(+), 6 deletions(-) create mode 100644 label_studio/fsm/migrations/0001_initial.py diff --git a/label_studio/fsm/apps.py b/label_studio/fsm/apps.py index 9290e8ddd289..a451c39fabbe 100644 --- a/label_studio/fsm/apps.py +++ b/label_studio/fsm/apps.py @@ -9,7 +9,7 @@ class FsmConfig(AppConfig): default_auto_field = 'django.db.models.UUIDField' - name = 'label_studio.fsm' + name = 'fsm' verbose_name = 'Label Studio FSM' def ready(self): diff --git a/label_studio/fsm/migrations/0001_initial.py b/label_studio/fsm/migrations/0001_initial.py new file mode 100644 index 000000000000..b108658b4874 --- /dev/null +++ b/label_studio/fsm/migrations/0001_initial.py @@ -0,0 +1,386 @@ +# Generated by Django 5.1.10 on 2025-08-26 16:43 + +import django.db.models.deletion +import fsm.utils +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("organizations", "0006_alter_organizationmember_deleted_at"), + ("projects", "0030_project_search_vector_index"), + ("tasks", "0057_annotation_proj_result_octlen_idx_async"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="AnnotationState", + fields=[ + ( + "id", + models.UUIDField( + default=fsm.utils.generate_uuid7, + editable=False, + help_text="UUID7 provides natural time ordering and global uniqueness", + primary_key=True, + serialize=False, + ), + ), + ( + "previous_state", + models.CharField( + blank=True, + help_text="Previous state before this transition", + max_length=50, + null=True, + ), + ), + ( + "transition_name", + models.CharField( + blank=True, + help_text="Name of the transition method that triggered this state change", + max_length=100, + null=True, + ), + ), + ( + "context_data", + models.JSONField( + default=dict, + help_text="Additional context data for this transition (e.g., validation results, external IDs)", + ), + ), + ( + "reason", + models.TextField( + blank=True, + help_text="Human-readable reason for this state transition", + ), + ), + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + help_text="Human-readable timestamp for debugging (UUID7 id contains precise timestamp)", + ), + ), + ( + "state", + models.CharField( + choices=[ + ("DRAFT", "Draft"), + ("SUBMITTED", "Submitted"), + ("COMPLETED", "Completed"), + ], + db_index=True, + max_length=50, + ), + ), + ( + "task_id", + models.PositiveIntegerField( + db_index=True, + help_text="From annotation.task_id - denormalized for performance", + ), + ), + ( + "project_id", + models.PositiveIntegerField( + db_index=True, + help_text="From annotation.task.project_id - denormalized for performance", + ), + ), + ( + "completed_by_id", + models.PositiveIntegerField( + db_index=True, + help_text="From annotation.completed_by_id - denormalized for performance", + null=True, + ), + ), + ( + "annotation", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="fsm_states", + to="tasks.annotation", + ), + ), + ( + "organization", + models.ForeignKey( + help_text="Organization which owns this state record", + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="organizations.organization", + ), + ), + ( + "triggered_by", + models.ForeignKey( + help_text="User who triggered this state transition", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["-id"], + "indexes": [ + models.Index( + fields=["annotation_id", "-id"], name="anno_current_state_idx" + ), + models.Index( + fields=["task_id", "state", "-id"], name="anno_task_state_idx" + ), + models.Index( + fields=["completed_by_id", "state", "-id"], + name="anno_user_report_idx", + ), + models.Index( + fields=["project_id", "state", "-id"], + name="anno_project_report_idx", + ), + ], + }, + ), + migrations.CreateModel( + name="ProjectState", + fields=[ + ( + "id", + models.UUIDField( + default=fsm.utils.generate_uuid7, + editable=False, + help_text="UUID7 provides natural time ordering and global uniqueness", + primary_key=True, + serialize=False, + ), + ), + ( + "previous_state", + models.CharField( + blank=True, + help_text="Previous state before this transition", + max_length=50, + null=True, + ), + ), + ( + "transition_name", + models.CharField( + blank=True, + help_text="Name of the transition method that triggered this state change", + max_length=100, + null=True, + ), + ), + ( + "context_data", + models.JSONField( + default=dict, + help_text="Additional context data for this transition (e.g., validation results, external IDs)", + ), + ), + ( + "reason", + models.TextField( + blank=True, + help_text="Human-readable reason for this state transition", + ), + ), + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + help_text="Human-readable timestamp for debugging (UUID7 id contains precise timestamp)", + ), + ), + ( + "state", + models.CharField( + choices=[ + ("CREATED", "Created"), + ("IN_PROGRESS", "In Progress"), + ("COMPLETED", "Completed"), + ], + db_index=True, + max_length=50, + ), + ), + ( + "created_by_id", + models.PositiveIntegerField( + db_index=True, + help_text="From project.created_by_id - denormalized for performance", + null=True, + ), + ), + ( + "organization", + models.ForeignKey( + help_text="Organization which owns this state record", + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="organizations.organization", + ), + ), + ( + "project", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="states", + to="projects.project", + ), + ), + ( + "triggered_by", + models.ForeignKey( + help_text="User who triggered this state transition", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["-id"], + "indexes": [ + models.Index( + fields=["project_id", "-id"], name="project_current_state_idx" + ), + models.Index( + fields=["organization_id", "state", "-id"], + name="project_org_state_idx", + ), + models.Index( + fields=["organization_id", "-id"], + name="project_org_reporting_idx", + ), + ], + }, + ), + migrations.CreateModel( + name="TaskState", + fields=[ + ( + "id", + models.UUIDField( + default=fsm.utils.generate_uuid7, + editable=False, + help_text="UUID7 provides natural time ordering and global uniqueness", + primary_key=True, + serialize=False, + ), + ), + ( + "previous_state", + models.CharField( + blank=True, + help_text="Previous state before this transition", + max_length=50, + null=True, + ), + ), + ( + "transition_name", + models.CharField( + blank=True, + help_text="Name of the transition method that triggered this state change", + max_length=100, + null=True, + ), + ), + ( + "context_data", + models.JSONField( + default=dict, + help_text="Additional context data for this transition (e.g., validation results, external IDs)", + ), + ), + ( + "reason", + models.TextField( + blank=True, + help_text="Human-readable reason for this state transition", + ), + ), + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + help_text="Human-readable timestamp for debugging (UUID7 id contains precise timestamp)", + ), + ), + ( + "state", + models.CharField( + choices=[ + ("CREATED", "Created"), + ("IN_PROGRESS", "In Progress"), + ("COMPLETED", "Completed"), + ], + db_index=True, + max_length=50, + ), + ), + ( + "project_id", + models.PositiveIntegerField( + db_index=True, + help_text="From task.project_id - denormalized for performance", + ), + ), + ( + "organization", + models.ForeignKey( + help_text="Organization which owns this state record", + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="organizations.organization", + ), + ), + ( + "task", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="fsm_states", + to="tasks.task", + ), + ), + ( + "triggered_by", + models.ForeignKey( + help_text="User who triggered this state transition", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["-id"], + "indexes": [ + models.Index( + fields=["task_id", "-id"], name="task_current_state_idx" + ), + models.Index( + fields=["project_id", "state", "-id"], + name="task_project_state_idx", + ), + models.Index( + fields=["organization_id", "state", "-id"], + name="task_org_reporting_idx", + ), + models.Index(fields=["task_id", "id"], name="task_history_idx"), + ], + }, + ), + ] diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index 0f38cf567b0f..4b2b10059c9a 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -271,16 +271,11 @@ class ProjectState(BaseState): # Override state field to add choices constraint state = models.CharField(max_length=50, choices=ProjectStateChoices.choices, db_index=True) - # Denormalized fields for performance (avoid JOINs in common queries) - organization_id = models.PositiveIntegerField( - db_index=True, help_text='From project.organization_id - denormalized for performance' - ) created_by_id = models.PositiveIntegerField( null=True, db_index=True, help_text='From project.created_by_id - denormalized for performance' ) class Meta: - db_table = 'fsm_project_states' indexes = [ # Critical: Latest state lookup models.Index(fields=['project_id', '-id'], name='project_current_state_idx'), From 15f587282aad2ad8d75bfccdd4e371f8a60bc804 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Tue, 26 Aug 2025 11:49:04 -0500 Subject: [PATCH 06/83] regen migrations for fsm core --- label_studio/fsm/migrations/0001_initial.py | 4 ++-- label_studio/fsm/models.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/label_studio/fsm/migrations/0001_initial.py b/label_studio/fsm/migrations/0001_initial.py index b108658b4874..4850f8e880b5 100644 --- a/label_studio/fsm/migrations/0001_initial.py +++ b/label_studio/fsm/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 5.1.10 on 2025-08-26 16:43 +# Generated by Django 5.1.10 on 2025-08-26 16:48 import django.db.models.deletion import fsm.utils @@ -236,7 +236,7 @@ class Migration(migrations.Migration): "project", models.ForeignKey( on_delete=django.db.models.deletion.CASCADE, - related_name="states", + related_name="fsm_states", to="projects.project", ), ), diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index 4b2b10059c9a..f69c94efbfdb 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -266,7 +266,7 @@ class ProjectState(BaseState): """ # Entity Relationship - project = models.ForeignKey('projects.Project', on_delete=models.CASCADE, related_name='states') + project = models.ForeignKey('projects.Project', on_delete=models.CASCADE, related_name='fsm_states') # Override state field to add choices constraint state = models.CharField(max_length=50, choices=ProjectStateChoices.choices, db_index=True) From 5bc0caef93acbcb73c760f6215018c81d2106e32 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Tue, 26 Aug 2025 12:14:13 -0500 Subject: [PATCH 07/83] regen migrations for fsm core --- label_studio/fsm/migrations/0001_initial.py | 2 +- label_studio/fsm/models.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/label_studio/fsm/migrations/0001_initial.py b/label_studio/fsm/migrations/0001_initial.py index 4850f8e880b5..322f6bcffc41 100644 --- a/label_studio/fsm/migrations/0001_initial.py +++ b/label_studio/fsm/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 5.1.10 on 2025-08-26 16:48 +# Generated by Django 5.1.10 on 2025-08-26 17:13 import django.db.models.deletion import fsm.utils diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index f69c94efbfdb..313a942aa474 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -195,6 +195,7 @@ class TaskState(BaseState): ) class Meta: + app_label = 'fsm' indexes = [ # Critical: Latest state lookup (current state determined by latest UUID7 id) # Index with DESC order explicitly supports ORDER BY id DESC queries @@ -240,6 +241,7 @@ class AnnotationState(BaseState): ) class Meta: + app_label = 'fsm' indexes = [ # Critical: Latest state lookup models.Index(fields=['annotation_id', '-id'], name='anno_current_state_idx'), @@ -276,6 +278,7 @@ class ProjectState(BaseState): ) class Meta: + app_label = 'fsm' indexes = [ # Critical: Latest state lookup models.Index(fields=['project_id', '-id'], name='project_current_state_idx'), From 1c1ced9108a4d120923041a9d1189877e59708d7 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Tue, 26 Aug 2025 14:05:41 -0500 Subject: [PATCH 08/83] fixed imports --- label_studio/fsm/README.md | 12 ++++---- label_studio/fsm/integration.py | 20 ++++++------- label_studio/fsm/tests/test_uuid7_utils.py | 3 +- label_studio/fsm/urls.py | 33 +--------------------- 4 files changed, 18 insertions(+), 50 deletions(-) diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md index ad8854f847ea..5fb29ca5c173 100644 --- a/label_studio/fsm/README.md +++ b/label_studio/fsm/README.md @@ -24,8 +24,8 @@ The Label Studio FSM system provides: ### Basic State Management ```python -from label_studio.fsm.state_manager import get_state_manager -from label_studio.tasks.models import Task +from fsm.state_manager import get_state_manager +from tasks.models import Task # Get current state StateManager = get_state_manager() @@ -48,7 +48,7 @@ history = StateManager.get_state_history(task, limit=10) ```python # Add FSM functionality to existing models -from label_studio.fsm.integration import FSMIntegrationMixin +from fsm.integration import FSMIntegrationMixin class Task(FSMIntegrationMixin, BaseTask): class Meta: @@ -186,8 +186,8 @@ The FSM system can run alongside existing state management: Test the FSM system: ```python -from label_studio.fsm.state_manager import StateManager -from label_studio.tasks.models import Task +from fsm.state_manager import StateManager +from tasks.models import Task def test_task_state_transition(): task = Task.objects.create(...) @@ -204,4 +204,4 @@ def test_task_state_transition(): history = StateManager.get_state_history(task) assert len(history) == 1 assert history[0].state == 'CREATED' -``` \ No newline at end of file +``` diff --git a/label_studio/fsm/integration.py b/label_studio/fsm/integration.py index 4850c40f74b3..56a187fa9239 100644 --- a/label_studio/fsm/integration.py +++ b/label_studio/fsm/integration.py @@ -76,8 +76,8 @@ def add_fsm_to_model(model_class): This provides an alternative to inheritance for adding FSM capabilities. Example: - from label_studio.fsm.integration import add_fsm_to_model - from label_studio.tasks.models import Task + from fsm.integration import add_fsm_to_model + from tasks.models import Task @add_fsm_to_model class Task(Task): @@ -126,8 +126,8 @@ def handle_task_created(sender, instance, created, **kwargs): Connect this to the Task model's post_save signal: from django.db.models.signals import post_save - from label_studio.tasks.models import Task - from label_studio.fsm.integration import handle_task_created + from tasks.models import Task + from fsm.integration import handle_task_created post_save.connect(handle_task_created, sender=Task) """ @@ -151,8 +151,8 @@ def handle_annotation_created(sender, instance, created, **kwargs): Connect this to the Annotation model's post_save signal: from django.db.models.signals import post_save - from label_studio.tasks.models import Annotation - from label_studio.fsm.integration import handle_annotation_created + from tasks.models import Annotation + from fsm.integration import handle_annotation_created post_save.connect(handle_annotation_created, sender=Annotation) """ @@ -176,8 +176,8 @@ def handle_project_created(sender, instance, created, **kwargs): Connect this to the Project model's post_save signal: from django.db.models.signals import post_save - from label_studio.projects.models import Project - from label_studio.fsm.integration import handle_project_created + from projects.models import Project + from fsm.integration import handle_project_created post_save.connect(handle_project_created, sender=Project) """ @@ -211,8 +211,8 @@ def get_entities_by_state(model_class, state: str, limit: int = 100): QuerySet of entities in the specified state Example: - from label_studio.tasks.models import Task - from label_studio.fsm.integration import get_entities_by_state + from tasks.models import Task + from fsm.integration import get_entities_by_state completed_tasks = get_entities_by_state(Task, 'COMPLETED', limit=50) """ diff --git a/label_studio/fsm/tests/test_uuid7_utils.py b/label_studio/fsm/tests/test_uuid7_utils.py index 716f3ea59b2e..3758acd1cd27 100644 --- a/label_studio/fsm/tests/test_uuid7_utils.py +++ b/label_studio/fsm/tests/test_uuid7_utils.py @@ -8,8 +8,7 @@ from datetime import datetime, timedelta, timezone from django.test import TestCase - -from label_studio.fsm.utils import ( +from fsm.utils import ( UUID7Generator, generate_uuid7, timestamp_from_uuid7, diff --git a/label_studio/fsm/urls.py b/label_studio/fsm/urls.py index 222292384797..78499a744f3a 100644 --- a/label_studio/fsm/urls.py +++ b/label_studio/fsm/urls.py @@ -1,8 +1,7 @@ """ Core FSM URL patterns for Label Studio. -Provides basic URL routing for state management API that can be extended -by Label Studio Enterprise with additional endpoints. +Provides basic URL routing for state management API """ from django.urls import include, path @@ -18,33 +17,3 @@ urlpatterns = [ path('api/', include(router.urls)), ] - -# Extension point for Label Studio Enterprise -# Enterprise can add additional URL patterns here -enterprise_urlpatterns = [] - -# Function to register additional URL patterns from Enterprise -def register_enterprise_urls(patterns): - """ - Register additional URL patterns from Label Studio Enterprise. - - Args: - patterns: List of URL patterns to register - - Example: - # In LSE code: - from label_studio.fsm.urls import register_enterprise_urls - - enterprise_patterns = [ - path('api/fsm/bulk/', BulkFSMViewSet.as_view(), name='fsm-bulk'), - path('api/fsm/analytics/', AnalyticsFSMViewSet.as_view(), name='fsm-analytics'), - ] - register_enterprise_urls(enterprise_patterns) - """ - global enterprise_urlpatterns - enterprise_urlpatterns.extend(patterns) - - -# Include enterprise URL patterns if any are registered -if enterprise_urlpatterns: - urlpatterns.extend(enterprise_urlpatterns) From 6c8d40e14e64f3292be2852d34771a3416b3211f Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Tue, 26 Aug 2025 15:57:55 -0500 Subject: [PATCH 09/83] fix imports --- label_studio/fsm/tests/test_fsm_integration.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py index 74af25d23d38..4d6f467fdd3e 100644 --- a/label_studio/fsm/tests/test_fsm_integration.py +++ b/label_studio/fsm/tests/test_fsm_integration.py @@ -9,14 +9,13 @@ from django.contrib.auth import get_user_model from django.test import TestCase +from fsm.models import AnnotationState, ProjectState, TaskState +from fsm.state_manager import get_state_manager from rest_framework.test import APITestCase from label_studio.projects.models import Project from label_studio.tasks.models import Annotation, Task -from ..models import AnnotationState, ProjectState, TaskState -from ..state_manager import get_state_manager - User = get_user_model() From 71bc9a150894e4b2396042f9a64583df7080c278 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 08:32:57 -0500 Subject: [PATCH 10/83] fix imports --- label_studio/fsm/tests/test_fsm_integration.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py index 4d6f467fdd3e..cfb36baa8ecc 100644 --- a/label_studio/fsm/tests/test_fsm_integration.py +++ b/label_studio/fsm/tests/test_fsm_integration.py @@ -11,10 +11,9 @@ from django.test import TestCase from fsm.models import AnnotationState, ProjectState, TaskState from fsm.state_manager import get_state_manager +from projects.models import Project from rest_framework.test import APITestCase - -from label_studio.projects.models import Project -from label_studio.tasks.models import Annotation, Task +from tasks.models import Annotation, Task User = get_user_model() From f99302a26baacff1df087ca35dfef792b76a5367 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 08:48:45 -0500 Subject: [PATCH 11/83] fix fsm tests --- label_studio/fsm/tests/test_uuid7_utils.py | 19 +++++++++++----- label_studio/fsm/utils.py | 26 +++++++++++----------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/label_studio/fsm/tests/test_uuid7_utils.py b/label_studio/fsm/tests/test_uuid7_utils.py index 3758acd1cd27..138bf86a05c1 100644 --- a/label_studio/fsm/tests/test_uuid7_utils.py +++ b/label_studio/fsm/tests/test_uuid7_utils.py @@ -50,9 +50,13 @@ def test_timestamp_extraction(self): extracted_timestamp = timestamp_from_uuid7(uuid7_id) - # Timestamp should be between before and after - self.assertGreaterEqual(extracted_timestamp, before) - self.assertLessEqual(extracted_timestamp, after) + # Timestamp should be close to the generation time (within 1 second tolerance) + # UUID7 has millisecond precision, so some rounding variance is expected + time_diff_before = abs((extracted_timestamp - before).total_seconds()) + time_diff_after = abs((extracted_timestamp - after).total_seconds()) + + self.assertLess(time_diff_before, 1.0) # Within 1 second of before + self.assertLess(time_diff_after, 1.0) # Within 1 second of after def test_uuid7_from_timestamp(self): """Test creating UUID7 from specific timestamp""" @@ -97,10 +101,13 @@ def test_uuid7_time_range_default_end(self): after_call = datetime.now(timezone.utc) - # End timestamp should be close to now + # End timestamp should be close to now (within 1 second tolerance) end_extracted = timestamp_from_uuid7(end_uuid) - self.assertGreaterEqual(end_extracted, before_call) - self.assertLessEqual(end_extracted, after_call) + time_diff_before = abs((end_extracted - before_call).total_seconds()) + time_diff_after = abs((end_extracted - after_call).total_seconds()) + + self.assertLess(time_diff_before, 1.0) # Within 1 second of before_call + self.assertLess(time_diff_after, 1.0) # Within 1 second of after_call def test_validate_uuid7_with_other_versions(self): """Test UUID7 validation with other UUID versions""" diff --git a/label_studio/fsm/utils.py b/label_studio/fsm/utils.py index 563b20f29895..4673968494b7 100644 --- a/label_studio/fsm/utils.py +++ b/label_studio/fsm/utils.py @@ -7,10 +7,11 @@ Uses the uuid-utils library for RFC 9562 compliant UUID7 generation. """ +import uuid from datetime import datetime, timezone from typing import Optional, Tuple -import uuid_utils as uuid +import uuid_utils def generate_uuid7() -> uuid.UUID: @@ -26,7 +27,9 @@ def generate_uuid7() -> uuid.UUID: UUID7 instance with embedded timestamp """ # Use uuid-utils library for RFC 9562 compliant UUID7 generation - return uuid.uuid7() + # Convert to standard uuid.UUID to maintain type consistency + uuid7_obj = uuid_utils.uuid7() + return uuid.UUID(str(uuid7_obj)) def timestamp_from_uuid7(uuid7_id: uuid.UUID) -> datetime: @@ -46,6 +49,7 @@ def timestamp_from_uuid7(uuid7_id: uuid.UUID) -> datetime: """ # UUID7 embeds timestamp in first 48 bits timestamp_ms = (uuid7_id.int >> 80) & ((1 << 48) - 1) + # Return with millisecond precision (UUID7 spec) return datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc) @@ -76,11 +80,10 @@ def uuid7_time_range(start_time: datetime, end_time: Optional[datetime] = None) start_timestamp_ms = int(start_time.timestamp() * 1000) end_timestamp_ms = int(end_time.timestamp() * 1000) - # Create UUID7 with specific timestamp and zero random bits for range start - start_uuid = uuid.UUID(int=(start_timestamp_ms << 80), version=7) - - # Create UUID7 with specific timestamp and max random bits for range end - end_uuid = uuid.UUID(int=(end_timestamp_ms << 80) | ((1 << 80) - 1), version=7) + # Create UUID7 with specific timestamp using proper bit layout + # UUID7 format: timestamp_ms(48) + ver(4) + rand_a(12) + var(2) + rand_b(62) + start_uuid = uuid.UUID(int=(start_timestamp_ms << 80) | (0x7 << 76) | (0b10 << 62)) + end_uuid = uuid.UUID(int=(end_timestamp_ms << 80) | (0x7 << 76) | (0b10 << 62) | ((1 << 62) - 1)) return start_uuid, end_uuid @@ -176,14 +179,11 @@ def generate(self, offset_ms: int = 0) -> uuid.UUID: Returns: UUID7 with adjusted timestamp """ - # For testing purposes, we'll generate a standard UUID7 - # and then optionally create one with specific timestamp if needed - if offset_ms == 0: - return uuid.uuid7() - # For offset timestamps, use manual construction for precise control timestamp_ms = int(self.base_timestamp.timestamp() * 1000) + offset_ms self._counter += 1 # Create UUID7 with specific timestamp and counter for monotonicity - return uuid.UUID(int=(timestamp_ms << 80) | (0x7 << 76) | (self._counter & 0xFFF) << 64 | (0b10 << 62)) + # UUID7 format: timestamp_ms(48) + ver(4) + rand_a(12) + var(2) + rand_b(62) + uuid_int = (timestamp_ms << 80) | (0x7 << 76) | ((self._counter & 0xFFF) << 64) | (0b10 << 62) + return uuid.UUID(int=uuid_int) From 68bce4573b3180719983f83d169aea3e40e58b08 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 12:21:04 -0500 Subject: [PATCH 12/83] fix fsm tests --- label_studio/fsm/models.py | 16 + label_studio/fsm/state_manager.py | 98 +- .../fsm/tests/test_api_usage_examples.py | 832 ++++++++++++++ .../fsm/tests/test_declarative_transitions.py | 1018 +++++++++++++++++ .../tests/test_edge_cases_error_handling.py | 768 +++++++++++++ .../fsm/tests/test_fsm_integration.py | 21 +- .../tests/test_integration_django_models.py | 682 +++++++++++ .../fsm/tests/test_performance_concurrency.py | 753 ++++++++++++ label_studio/fsm/transition_utils.py | 342 ++++++ label_studio/fsm/transitions.py | 486 ++++++++ 10 files changed, 4986 insertions(+), 30 deletions(-) create mode 100644 label_studio/fsm/tests/test_api_usage_examples.py create mode 100644 label_studio/fsm/tests/test_declarative_transitions.py create mode 100644 label_studio/fsm/tests/test_edge_cases_error_handling.py create mode 100644 label_studio/fsm/tests/test_integration_django_models.py create mode 100644 label_studio/fsm/tests/test_performance_concurrency.py create mode 100644 label_studio/fsm/transition_utils.py create mode 100644 label_studio/fsm/transitions.py diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index 313a942aa474..bd72f879f0f0 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -209,6 +209,13 @@ class Meta: # No constraints needed - INSERT-only approach ordering = ['-id'] + @classmethod + def get_denormalized_fields(cls, entity): + """Get denormalized fields for TaskState creation""" + return { + 'project_id': entity.project_id, + } + @property def is_terminal_state(self) -> bool: """Check if this is a terminal task state""" @@ -252,6 +259,15 @@ class Meta: ] ordering = ['-id'] + @classmethod + def get_denormalized_fields(cls, entity): + """Get denormalized fields for AnnotationState creation""" + return { + 'task_id': entity.task.id, + 'project_id': entity.task.project_id, + 'completed_by_id': entity.completed_by.id if entity.completed_by else None, + } + @property def is_terminal_state(self) -> bool: """Check if this is a terminal annotation state""" diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index c735ab7acb5c..13aa6c3d856e 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -7,7 +7,7 @@ import logging from datetime import datetime -from typing import Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type from django.conf import settings from django.core.cache import cache @@ -16,6 +16,10 @@ from .models import BaseState, get_state_model_for_entity +# Avoid circular import +if TYPE_CHECKING: + from .transitions import BaseTransition + logger = logging.getLogger(__name__) @@ -184,6 +188,11 @@ def transition_state( try: with transaction.atomic(): # INSERT-only approach - no UPDATE operations needed + # Get denormalized fields from the state model itself + denormalized_fields = {} + if hasattr(state_model, 'get_denormalized_fields'): + denormalized_fields = state_model.get_denormalized_fields(entity) + new_state_record = state_model.objects.create( **{entity._meta.model_name: entity}, state=new_state, @@ -192,7 +201,7 @@ def transition_state( triggered_by=user, context_data=context or {}, reason=reason, - # Note: Denormalized fields would be added here by Enterprise + **denormalized_fields, ) # Update cache with new state @@ -281,34 +290,69 @@ def warm_cache(cls, entities: List[Model]): cache.set_many(cache_updates, cls.CACHE_TTL) logger.debug(f'Warmed cache for {len(cache_updates)} entities') + @classmethod + def execute_declarative_transition( + cls, transition: 'BaseTransition', entity: Model, user=None, **context_kwargs + ) -> BaseState: + """ + Execute a declarative Pydantic-based transition. -# Extension point for Label Studio Enterprise -# Enterprise can subclass this and add advanced features -class ExtendedStateManager(StateManager): - """ - Extension point for Label Studio Enterprise. + This method integrates the new declarative transition system with + the existing StateManager, providing a bridge between the two approaches. - Enterprise can override this class to add: - - Bulk operations with window functions - - Advanced caching strategies - - Complex transition validation - - Enterprise-specific optimizations - - Example Enterprise usage: - class EnterpriseStateManager(ExtendedStateManager): - @classmethod - def bulk_get_states(cls, entities): - # Enterprise-specific bulk optimization - return super().bulk_get_states_optimized(entities) - - @classmethod - def transition_state(cls, entity, new_state, **kwargs): - # Enterprise transition validation - cls.validate_enterprise_transition(entity, new_state) - return super().transition_state(entity, new_state, **kwargs) - """ + Args: + transition: Instance of a BaseTransition subclass + entity: The entity to transition + user: User executing the transition + **context_kwargs: Additional context data - pass + Returns: + The newly created state record + + Raises: + TransitionValidationError: If transition validation fails + StateManagerError: If transition execution fails + """ + from .transitions import TransitionContext + + # Get current state information + current_state_object = cls.get_current_state_object(entity) + current_state = current_state_object.state if current_state_object else None + + # Build transition context + context = TransitionContext( + entity=entity, + current_user=user, + current_state_object=current_state_object, + current_state=current_state, + target_state=transition.target_state, + organization_id=getattr(entity, 'organization_id', None), + **context_kwargs, + ) + + logger.info( + f'Executing declarative transition {transition.__class__.__name__} ' + f'for {entity._meta.label_lower} {entity.pk}: ' + f'{current_state} → {transition.target_state}' + ) + + try: + # Execute the transition through the declarative system + state_record = transition.execute(context) + + logger.info( + f'Declarative transition successful: {entity._meta.label_lower} {entity.pk} ' + f'now in state {transition.target_state} (record ID: {state_record.id})' + ) + + return state_record + + except Exception as e: + logger.error( + f'Declarative transition failed for {entity._meta.label_lower} {entity.pk}: ' + f'{current_state} → {transition.target_state}: {e}' + ) + raise # Allow runtime configuration of which StateManager to use diff --git a/label_studio/fsm/tests/test_api_usage_examples.py b/label_studio/fsm/tests/test_api_usage_examples.py new file mode 100644 index 000000000000..d3307b058331 --- /dev/null +++ b/label_studio/fsm/tests/test_api_usage_examples.py @@ -0,0 +1,832 @@ +""" +API usage examples and documentation tests for the declarative transition system. + +These tests serve as both validation and comprehensive documentation, +showing how to integrate the transition system with APIs, handle +JSON serialization, generate schemas, and implement real-world patterns. +""" + +import json +import pytest +from datetime import datetime, timedelta +from unittest.mock import Mock, patch +from typing import Dict, Any, List, Optional + +from django.test import TestCase +from pydantic import Field, validator + +from fsm.transitions import ( + BaseTransition, + TransitionContext, + TransitionValidationError, + transition_registry, + register_transition +) +from fsm.transition_utils import ( + execute_transition, + get_available_transitions, + get_transition_schema, + validate_transition_data, + TransitionBuilder, + create_transition_from_dict +) + + +class APIIntegrationExampleTests(TestCase): + """ + API integration examples demonstrating real-world usage patterns. + + These tests show how to integrate the transition system with + REST APIs, handle JSON data, validate requests, and format responses. + """ + + def setUp(self): + self.mock_entity = Mock() + self.mock_entity.pk = 1 + self.mock_entity._meta.model_name = 'task' + self.mock_entity.organization_id = 100 + + self.mock_user = Mock() + self.mock_user.id = 42 + self.mock_user.username = "api_user" + + # Clear registry + transition_registry._transitions.clear() + + def test_rest_api_task_assignment_example(self): + """ + API EXAMPLE: REST endpoint for task assignment + + Shows how to implement a REST API endpoint that uses + declarative transitions with proper validation and error handling. + """ + + @register_transition('task', 'api_assign_task') + class APITaskAssignmentTransition(BaseTransition): + """Task assignment via API with comprehensive validation""" + assignee_id: int = Field(..., description="ID of user to assign task to") + priority: str = Field("normal", description="Task priority level") + deadline: Optional[datetime] = Field(None, description="Assignment deadline") + assignment_notes: str = Field("", description="Notes about the assignment") + notify_assignee: bool = Field(True, description="Whether to notify the assignee") + + @validator('priority') + def validate_priority(cls, v): + valid_priorities = ['low', 'normal', 'high', 'urgent'] + if v not in valid_priorities: + raise ValueError(f'Priority must be one of: {valid_priorities}') + return v + + @validator('deadline') + def validate_deadline(cls, v): + if v and v <= datetime.now(): + raise ValueError('Deadline must be in the future') + return v + + @property + def target_state(self) -> str: + return "ASSIGNED" + + def validate_transition(self, context: TransitionContext) -> bool: + # Business logic validation + if context.current_state not in ["CREATED", "UNASSIGNED"]: + raise TransitionValidationError( + f"Cannot assign task in state: {context.current_state}", + {"valid_states": ["CREATED", "UNASSIGNED"]} + ) + + # Mock user existence check + if self.assignee_id <= 0: + raise TransitionValidationError( + "Invalid assignee ID", + {"assignee_id": self.assignee_id} + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "assignee_id": self.assignee_id, + "priority": self.priority, + "deadline": self.deadline.isoformat() if self.deadline else None, + "assignment_notes": self.assignment_notes, + "notify_assignee": self.notify_assignee, + "assigned_by_id": context.current_user.id if context.current_user else None, + "assigned_at": context.timestamp.isoformat(), + "api_version": "v1" + } + + # Simulate API request data (JSON from client) + api_request_data = { + "assignee_id": 123, + "priority": "high", + "deadline": (datetime.now() + timedelta(days=7)).isoformat(), + "assignment_notes": "Critical task requiring immediate attention", + "notify_assignee": True + } + + # API endpoint simulation: Parse and validate JSON + try: + # Step 1: Create transition from API data + transition = APITaskAssignmentTransition(**api_request_data) + + # Step 2: Execute transition + context = TransitionContext( + entity=self.mock_entity, + current_user=self.mock_user, + current_state="CREATED", + target_state=transition.target_state, + request_data=api_request_data + ) + + # Validate + self.assertTrue(transition.validate_transition(context)) + + # Execute + result_data = transition.transition(context) + + # Step 3: Format API response + api_response = { + "success": True, + "message": "Task assigned successfully", + "data": { + "task_id": self.mock_entity.pk, + "new_state": transition.target_state, + "assignment_details": result_data + }, + "timestamp": datetime.now().isoformat() + } + + # Validate API response + self.assertTrue(api_response["success"]) + self.assertEqual(api_response["data"]["new_state"], "ASSIGNED") + self.assertEqual(api_response["data"]["assignment_details"]["assignee_id"], 123) + self.assertEqual(api_response["data"]["assignment_details"]["priority"], "high") + + except ValueError as e: + # Handle Pydantic validation errors + api_response = { + "success": False, + "error": "Validation Error", + "message": str(e), + "timestamp": datetime.now().isoformat() + } + + except TransitionValidationError as e: + # Handle business logic validation errors + api_response = { + "success": False, + "error": "Business Rule Violation", + "message": str(e), + "context": e.context, + "timestamp": datetime.now().isoformat() + } + + # Test error handling with invalid data + invalid_request = { + "assignee_id": -1, # Invalid ID + "priority": "invalid_priority", # Invalid priority + "deadline": "2020-01-01T00:00:00" # Past deadline + } + + with self.assertRaises(ValueError): + APITaskAssignmentTransition(**invalid_request) + + def test_json_schema_generation_for_api_docs(self): + """ + API DOCUMENTATION: JSON Schema generation + + Shows how to generate OpenAPI/JSON schemas for API documentation + from Pydantic transition models. + """ + + @register_transition('annotation', 'api_submit_annotation') + class APIAnnotationSubmissionTransition(BaseTransition): + """Submit annotation via API with rich metadata""" + confidence_score: float = Field( + ..., + ge=0.0, + le=1.0, + description="Annotator's confidence in the annotation (0.0-1.0)" + ) + annotation_quality: str = Field( + "good", + description="Subjective quality assessment", + pattern="^(excellent|good|fair|poor)$" + ) + time_spent_seconds: int = Field( + ..., + ge=1, + description="Time spent on annotation in seconds" + ) + difficulty_level: str = Field( + "medium", + description="Perceived difficulty of the annotation task" + ) + review_requested: bool = Field( + False, + description="Whether the annotator requests manual review" + ) + tags: List[str] = Field( + default_factory=list, + description="Optional tags for categorization" + ) + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Additional metadata about the annotation process" + ) + + @property + def target_state(self) -> str: + return "SUBMITTED" + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "confidence_score": self.confidence_score, + "annotation_quality": self.annotation_quality, + "time_spent_seconds": self.time_spent_seconds, + "difficulty_level": self.difficulty_level, + "review_requested": self.review_requested, + "tags": self.tags, + "metadata": self.metadata, + "submitted_at": context.timestamp.isoformat() + } + + # Generate JSON schema + schema = get_transition_schema(APIAnnotationSubmissionTransition) + + # Validate schema structure + self.assertIn("properties", schema) + self.assertIn("required", schema) + + # Check specific field schemas + properties = schema["properties"] + + # confidence_score should have min/max constraints + confidence_schema = properties["confidence_score"] + self.assertEqual(confidence_schema["type"], "number") + self.assertEqual(confidence_schema["minimum"], 0.0) + self.assertEqual(confidence_schema["maximum"], 1.0) + self.assertIn("Annotator's confidence", confidence_schema["description"]) + + # annotation_quality should have pattern constraint + quality_schema = properties["annotation_quality"] + self.assertEqual(quality_schema["type"], "string") + self.assertIn("pattern", quality_schema) + + # time_spent_seconds should have minimum constraint + time_schema = properties["time_spent_seconds"] + self.assertEqual(time_schema["type"], "integer") + self.assertEqual(time_schema["minimum"], 1) + + # tags should be array type + tags_schema = properties["tags"] + self.assertEqual(tags_schema["type"], "array") + self.assertEqual(tags_schema["items"]["type"], "string") + + # metadata should be object type + metadata_schema = properties["metadata"] + self.assertEqual(metadata_schema["type"], "object") + + # Required fields + required_fields = schema["required"] + self.assertIn("confidence_score", required_fields) + self.assertIn("time_spent_seconds", required_fields) + self.assertNotIn("tags", required_fields) # Optional field + + # Test schema-driven validation + valid_data = { + "confidence_score": 0.85, + "annotation_quality": "good", + "time_spent_seconds": 120, + "difficulty_level": "hard", + "review_requested": True, + "tags": ["important", "complex"], + "metadata": {"tool_version": "1.2.3", "browser": "chrome"} + } + + transition = APIAnnotationSubmissionTransition(**valid_data) + self.assertEqual(transition.confidence_score, 0.85) + self.assertEqual(len(transition.tags), 2) + + # Print schema for documentation (would be used in API docs) + schema_json = json.dumps(schema, indent=2) + self.assertIsInstance(schema_json, str) + self.assertIn("confidence_score", schema_json) + + def test_bulk_operations_api_pattern(self): + """ + API EXAMPLE: Bulk operations with transitions + + Shows how to handle bulk operations where multiple entities + need to be transitioned with the same or different parameters. + """ + + @register_transition('task', 'bulk_status_update') + class BulkStatusUpdateTransition(BaseTransition): + """Bulk status update for multiple tasks""" + new_status: str = Field(..., description="New status for all tasks") + update_reason: str = Field(..., description="Reason for bulk update") + batch_id: str = Field(..., description="Unique identifier for this batch") + force_update: bool = Field(False, description="Force update even if invalid states") + + @property + def target_state(self) -> str: + return self.new_status + + def validate_transition(self, context: TransitionContext) -> bool: + valid_statuses = ["CREATED", "IN_PROGRESS", "COMPLETED", "CANCELLED"] + if self.new_status not in valid_statuses: + raise TransitionValidationError(f"Invalid status: {self.new_status}") + + # Skip state validation if force update + if not self.force_update: + if context.current_state == self.new_status: + raise TransitionValidationError("Cannot update to same status") + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "new_status": self.new_status, + "update_reason": self.update_reason, + "batch_id": self.batch_id, + "force_update": self.force_update, + "updated_at": context.timestamp.isoformat(), + "entity_id": context.entity.pk + } + + # Simulate bulk API request + bulk_request = { + "task_ids": [1, 2, 3, 4, 5], + "transition_data": { + "new_status": "IN_PROGRESS", + "update_reason": "Project phase change", + "batch_id": "batch_2024_001", + "force_update": False + } + } + + # Process bulk request + batch_results = [] + failed_updates = [] + + for task_id in bulk_request["task_ids"]: + # Create mock entity for each task + mock_task = Mock() + mock_task.pk = task_id + mock_task._meta.model_name = 'task' + + try: + # Create transition + transition = BulkStatusUpdateTransition(**bulk_request["transition_data"]) + + # Mock different current states for testing + current_states = ["CREATED", "CREATED", "IN_PROGRESS", "CREATED", "COMPLETED"] + current_state = current_states[task_id - 1] # Adjust for 0-based indexing + + context = TransitionContext( + entity=mock_task, + current_user=self.mock_user, + current_state=current_state, + target_state=transition.target_state + ) + + # Validate and execute + if transition.validate_transition(context): + result = transition.transition(context) + batch_results.append({ + "task_id": task_id, + "success": True, + "result": result + }) + + except TransitionValidationError as e: + failed_updates.append({ + "task_id": task_id, + "success": False, + "error": str(e), + "context": getattr(e, 'context', {}) + }) + + # API response for bulk operation + api_response = { + "batch_id": bulk_request["transition_data"]["batch_id"], + "total_requested": len(bulk_request["task_ids"]), + "successful_updates": len(batch_results), + "failed_updates": len(failed_updates), + "results": batch_results, + "failures": failed_updates, + "timestamp": datetime.now().isoformat() + } + + # Validate bulk results + self.assertEqual(api_response["total_requested"], 5) + self.assertGreater(api_response["successful_updates"], 0) + + # Some tasks should succeed, some might fail due to state validation + total_processed = api_response["successful_updates"] + api_response["failed_updates"] + self.assertEqual(total_processed, 5) + + # Check individual results + for result in batch_results: + self.assertTrue(result["success"]) + self.assertEqual(result["result"]["new_status"], "IN_PROGRESS") + self.assertEqual(result["result"]["batch_id"], "batch_2024_001") + + def test_webhook_integration_pattern(self): + """ + API EXAMPLE: Webhook integration with transitions + + Shows how to integrate transitions with webhook systems + for external notifications and integrations. + """ + + @register_transition('task', 'webhook_completion') + class WebhookTaskCompletionTransition(BaseTransition): + """Task completion with webhook notifications""" + completion_quality: float = Field(..., ge=0.0, le=1.0) + completion_notes: str = Field("", description="Completion notes") + webhook_urls: List[str] = Field(default_factory=list, description="Webhook URLs to notify") + notification_data: Dict[str, Any] = Field(default_factory=dict, description="Data to send in webhooks") + webhook_responses: List[Dict[str, Any]] = Field(default_factory=list, description="Webhook response tracking") + + @property + def target_state(self) -> str: + return "COMPLETED" + + def validate_transition(self, context: TransitionContext) -> bool: + if context.current_state != "IN_PROGRESS": + raise TransitionValidationError("Can only complete in-progress tasks") + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "completion_quality": self.completion_quality, + "completion_notes": self.completion_notes, + "webhook_urls": self.webhook_urls, + "notification_data": self.notification_data, + "completed_at": context.timestamp.isoformat(), + "completed_by_id": context.current_user.id if context.current_user else None + } + + def post_transition_hook(self, context: TransitionContext, state_record) -> None: + """Send webhook notifications after successful transition""" + if self.webhook_urls: + webhook_payload = { + "event": "task.completed", + "task_id": context.entity.pk, + "state_record_id": getattr(state_record, 'id', 'mock-id'), + "completion_data": { + "quality": self.completion_quality, + "notes": self.completion_notes, + "completed_by": context.current_user.id if context.current_user else None, + "completed_at": context.timestamp.isoformat() + }, + "custom_data": self.notification_data, + "timestamp": datetime.now().isoformat() + } + + # Mock webhook sending (in real implementation, use async requests) + for url in self.webhook_urls: + webhook_response = { + "url": url, + "payload": webhook_payload, + "status": "sent", + "timestamp": datetime.now().isoformat() + } + self.webhook_responses.append(webhook_response) + + # Test webhook transition + transition = WebhookTaskCompletionTransition( + completion_quality=0.95, + completion_notes="Task completed with excellent quality", + webhook_urls=[ + "https://api.example.com/webhooks/task-completed", + "https://notifications.example.com/task-events" + ], + notification_data={ + "project_id": 123, + "priority": "high", + "client_id": "client_456" + } + ) + + context = TransitionContext( + entity=self.mock_entity, + current_user=self.mock_user, + current_state="IN_PROGRESS", + target_state=transition.target_state + ) + + # Validate and execute + self.assertTrue(transition.validate_transition(context)) + result = transition.transition(context) + + # Simulate state record creation + mock_state_record = Mock() + mock_state_record.id = "state-uuid-123" + + # Execute post-hook (webhook sending) + transition.post_transition_hook(context, mock_state_record) + + # Validate webhook responses + self.assertEqual(len(transition.webhook_responses), 2) + + for response in transition.webhook_responses: + self.assertIn("url", response) + self.assertIn("payload", response) + self.assertEqual(response["status"], "sent") + + # Validate webhook payload structure + payload = response["payload"] + self.assertEqual(payload["event"], "task.completed") + self.assertEqual(payload["task_id"], self.mock_entity.pk) + self.assertEqual(payload["completion_data"]["quality"], 0.95) + self.assertEqual(payload["custom_data"]["project_id"], 123) + + def test_api_error_handling_patterns(self): + """ + API EXAMPLE: Comprehensive error handling patterns + + Shows how to implement robust error handling for API endpoints + using the transition system with proper HTTP status codes and messages. + """ + + @register_transition('task', 'api_critical_update') + class APICriticalUpdateTransition(BaseTransition): + """Critical update with extensive validation""" + update_type: str = Field(..., description="Type of critical update") + severity_level: int = Field(..., ge=1, le=5, description="Severity level 1-5") + authorization_token: str = Field(..., description="Authorization token for critical updates") + backup_required: bool = Field(True, description="Whether backup is required before update") + + @property + def target_state(self) -> str: + return "CRITICALLY_UPDATED" + + def validate_transition(self, context: TransitionContext) -> bool: + errors = [] + + # Authorization check + if len(self.authorization_token) < 10: + errors.append("Invalid authorization token") + + # Severity validation + if self.severity_level >= 4 and not context.current_user: + errors.append("High severity updates require authenticated user") + + # Update type validation + valid_types = ["security_patch", "critical_fix", "emergency_update"] + if self.update_type not in valid_types: + errors.append(f"Invalid update type. Must be one of: {valid_types}") + + # State validation + if context.current_state in ["COMPLETED", "ARCHIVED"]: + errors.append(f"Cannot perform critical updates on {context.current_state.lower()} tasks") + + # Backup requirement + if self.backup_required and self.severity_level >= 3: + # Mock backup check + backup_exists = True # In real implementation, check backup system + if not backup_exists: + errors.append("Backup required but not available") + + if errors: + raise TransitionValidationError( + "Critical update validation failed", + { + "validation_errors": errors, + "error_count": len(errors), + "severity_level": self.severity_level, + "update_type": self.update_type + } + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "update_type": self.update_type, + "severity_level": self.severity_level, + "backup_required": self.backup_required, + "authorized_by": context.current_user.id if context.current_user else None, + "updated_at": context.timestamp.isoformat(), + "critical_update_id": f"crit_{int(context.timestamp.timestamp())}" + } + + # Test various error scenarios and API responses + + # 1. Test successful request + valid_request = { + "update_type": "security_patch", + "severity_level": 3, + "authorization_token": "valid_token_12345", + "backup_required": True + } + + def simulate_api_endpoint(request_data, current_state="IN_PROGRESS"): + """Simulate API endpoint with proper error handling""" + try: + # Parse and validate request + transition = APICriticalUpdateTransition(**request_data) + + # Create context + context = TransitionContext( + entity=self.mock_entity, + current_user=self.mock_user, + current_state=current_state, + target_state=transition.target_state + ) + + # Validate business logic + transition.validate_transition(context) + + # Execute transition + result = transition.transition(context) + + return { + "status_code": 200, + "success": True, + "data": { + "task_id": self.mock_entity.pk, + "new_state": transition.target_state, + "update_details": result + } + } + + except ValueError as e: + # Pydantic validation error (400 Bad Request) + return { + "status_code": 400, + "success": False, + "error": "Bad Request", + "message": "Invalid request data", + "details": str(e) + } + + except TransitionValidationError as e: + # Business logic validation error (422 Unprocessable Entity) + return { + "status_code": 422, + "success": False, + "error": "Validation Failed", + "message": str(e), + "validation_errors": e.context.get("validation_errors", []), + "context": e.context + } + + except Exception as e: + # Unexpected error (500 Internal Server Error) + return { + "status_code": 500, + "success": False, + "error": "Internal Server Error", + "message": "An unexpected error occurred", + "details": str(e) if not isinstance(e, Exception) else "Server error" + } + + # Test successful request + response = simulate_api_endpoint(valid_request) + self.assertEqual(response["status_code"], 200) + self.assertTrue(response["success"]) + self.assertIn("update_details", response["data"]) + + # Test Pydantic validation error (invalid severity level) + invalid_request = { + "update_type": "security_patch", + "severity_level": 10, # Invalid: > 5 + "authorization_token": "valid_token_12345" + } + + response = simulate_api_endpoint(invalid_request) + self.assertEqual(response["status_code"], 400) + self.assertFalse(response["success"]) + self.assertEqual(response["error"], "Bad Request") + + # Test business logic validation error + business_logic_error_request = { + "update_type": "invalid_type", # Invalid update type + "severity_level": 5, + "authorization_token": "short", # Too short + "backup_required": True + } + + response = simulate_api_endpoint(business_logic_error_request) + self.assertEqual(response["status_code"], 422) + self.assertFalse(response["success"]) + self.assertEqual(response["error"], "Validation Failed") + self.assertIn("validation_errors", response) + self.assertGreater(len(response["validation_errors"]), 0) + + # Test state validation error + response = simulate_api_endpoint(valid_request, current_state="COMPLETED") + self.assertEqual(response["status_code"], 422) + # The error message is in validation_errors list, not the main message + validation_errors = response.get("validation_errors", []) + self.assertTrue(any("completed tasks" in error for error in validation_errors)) + + def test_api_versioning_and_backward_compatibility(self): + """ + API EXAMPLE: API versioning with backward compatibility + + Shows how to handle API versioning using transition inheritance + and maintain backward compatibility. + """ + + # Version 1 API + @register_transition('task', 'update_task_v1') + class UpdateTaskV1Transition(BaseTransition): + """Version 1 task update API""" + status: str = Field(..., description="New task status") + notes: str = Field("", description="Update notes") + + @property + def target_state(self) -> str: + return self.status + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "status": self.status, + "notes": self.notes, + "api_version": "v1", + "updated_at": context.timestamp.isoformat() + } + + # Version 2 API with additional features + @register_transition('task', 'update_task_v2') + class UpdateTaskV2Transition(UpdateTaskV1Transition): + """Version 2 task update API with enhanced features""" + priority: Optional[str] = Field(None, description="Task priority") + tags: List[str] = Field(default_factory=list, description="Task tags") + estimated_hours: Optional[float] = Field(None, ge=0, description="Estimated hours") + metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata") + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + # Call parent method for base functionality + base_data = super().transition(context) + + # Add V2 specific data + v2_data = { + "priority": self.priority, + "tags": self.tags, + "estimated_hours": self.estimated_hours, + "metadata": self.metadata, + "api_version": "v2" + } + + return {**base_data, **v2_data} + + # Test V1 API (backward compatibility) + v1_request = { + "status": "IN_PROGRESS", + "notes": "Started working on task" + } + + v1_transition = UpdateTaskV1Transition(**v1_request) + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state=v1_transition.target_state + ) + + v1_result = v1_transition.transition(context) + self.assertEqual(v1_result["api_version"], "v1") + self.assertEqual(v1_result["status"], "IN_PROGRESS") + self.assertNotIn("priority", v1_result) # V1 doesn't have priority + + # Test V2 API with enhanced features + v2_request = { + "status": "IN_PROGRESS", + "notes": "Started working on task with enhanced tracking", + "priority": "high", + "tags": ["urgent", "client-facing"], + "estimated_hours": 4.5, + "metadata": {"client_id": 123, "project_phase": "development"} + } + + v2_transition = UpdateTaskV2Transition(**v2_request) + v2_result = v2_transition.transition(context) + + self.assertEqual(v2_result["api_version"], "v2") + self.assertEqual(v2_result["status"], "IN_PROGRESS") # Inherited from V1 + self.assertEqual(v2_result["priority"], "high") # V2 feature + self.assertEqual(len(v2_result["tags"]), 2) # V2 feature + self.assertEqual(v2_result["estimated_hours"], 4.5) # V2 feature + self.assertIn("client_id", v2_result["metadata"]) # V2 feature + + # Test V2 API with minimal data (backward compatible) + v2_minimal_request = { + "status": "COMPLETED", + "notes": "Task finished" + } + + v2_minimal_transition = UpdateTaskV2Transition(**v2_minimal_request) + v2_minimal_result = v2_minimal_transition.transition(context) + + self.assertEqual(v2_minimal_result["api_version"], "v2") + self.assertEqual(v2_minimal_result["status"], "COMPLETED") + self.assertIsNone(v2_minimal_result["priority"]) # Optional field + self.assertEqual(v2_minimal_result["tags"], []) # Default value + self.assertIsNone(v2_minimal_result["estimated_hours"]) # Optional field + self.assertEqual(v2_minimal_result["metadata"], {}) # Default value \ No newline at end of file diff --git a/label_studio/fsm/tests/test_declarative_transitions.py b/label_studio/fsm/tests/test_declarative_transitions.py new file mode 100644 index 000000000000..33b83287b972 --- /dev/null +++ b/label_studio/fsm/tests/test_declarative_transitions.py @@ -0,0 +1,1018 @@ +""" +Comprehensive tests for the declarative Pydantic-based transition system. + +This test suite provides extensive coverage of the new transition system, +including usage examples, edge cases, validation scenarios, and integration +patterns to serve as both tests and documentation. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import Mock, patch, MagicMock +from typing import Dict, Any +import json + +from django.contrib.auth import get_user_model +from django.test import TestCase, TransactionTestCase +from pydantic import Field, ValidationError + +from fsm.state_choices import TaskStateChoices, AnnotationStateChoices +from fsm.transitions import ( + BaseTransition, + TransitionContext, + TransitionValidationError, + transition_registry, + register_transition +) +from fsm.transition_utils import ( + execute_transition, + get_available_transitions, + get_valid_transitions, + TransitionBuilder, + validate_transition_data, + get_transition_schema +) + +User = get_user_model() + + +class MockTask: + """Mock task model for testing""" + def __init__(self, pk=1): + self.pk = pk + self.id = pk + self.organization_id = 1 + self._meta = Mock() + self._meta.model_name = 'task' + self._meta.label_lower = 'tasks.task' + + +class MockAnnotation: + """Mock annotation model for testing""" + def __init__(self, pk=1): + self.pk = pk + self.id = pk + self.result = {'test': 'data'} # Mock annotation data + self.organization_id = 1 + self._meta = Mock() + self._meta.model_name = 'annotation' + self._meta.label_lower = 'tasks.annotation' + + +class TestTransition(BaseTransition): + """Test transition class""" + + test_field: str + optional_field: int = 42 + + @property + def target_state(self) -> str: + return "TEST_STATE" + + @classmethod + def get_target_state(cls) -> str: + """Return the target state at class level""" + return "TEST_STATE" + + @classmethod + def can_transition_from_state(cls, context: TransitionContext) -> bool: + """Allow transition from any state for testing""" + return True + + def validate_transition(self, context: TransitionContext) -> bool: + if self.test_field == "invalid": + raise TransitionValidationError("Test validation error") + return super().validate_transition(context) + + def transition(self, context: TransitionContext) -> dict: + return { + "test_field": self.test_field, + "optional_field": self.optional_field, + "context_entity_id": context.entity.pk + } + + +class DeclarativeTransitionTests(TestCase): + """Test cases for the declarative transition system""" + + def setUp(self): + self.task = MockTask() + self.annotation = MockAnnotation() + self.user = Mock() + self.user.id = 1 + self.user.username = "testuser" + + # Register test transition + transition_registry.register('task', 'test_transition', TestTransition) + + def test_transition_context_creation(self): + """Test creation of transition context""" + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state="CREATED", + target_state="IN_PROGRESS", + organization_id=1 + ) + + self.assertEqual(context.entity, self.task) + self.assertEqual(context.current_user, self.user) + self.assertEqual(context.current_state, "CREATED") + self.assertEqual(context.target_state, "IN_PROGRESS") + self.assertEqual(context.organization_id, 1) + self.assertFalse(context.is_initial_transition) + self.assertTrue(context.has_current_state) + + def test_transition_context_initial_state(self): + """Test context for initial state transition""" + context = TransitionContext( + entity=self.task, + current_state=None, + target_state="CREATED" + ) + + self.assertTrue(context.is_initial_transition) + self.assertFalse(context.has_current_state) + + def test_transition_validation_success(self): + """Test successful transition validation""" + transition = TestTransition(test_field="valid") + context = TransitionContext( + entity=self.task, + current_state="CREATED", + target_state=transition.target_state + ) + + self.assertTrue(transition.validate_transition(context)) + + def test_transition_validation_failure(self): + """Test transition validation failure""" + transition = TestTransition(test_field="invalid") + context = TransitionContext( + entity=self.task, + current_state="CREATED", + target_state=transition.target_state + ) + + with self.assertRaises(TransitionValidationError): + transition.validate_transition(context) + + def test_transition_execution(self): + """Test transition data generation""" + transition = TestTransition(test_field="test_value", optional_field=100) + context = TransitionContext( + entity=self.task, + current_state="CREATED", + target_state=transition.target_state + ) + + result = transition.transition(context) + + self.assertEqual(result["test_field"], "test_value") + self.assertEqual(result["optional_field"], 100) + self.assertEqual(result["context_entity_id"], self.task.pk) + + def test_transition_name_generation(self): + """Test automatic transition name generation""" + transition = TestTransition(test_field="test") + self.assertEqual(transition.transition_name, "test_transition") + + @patch('fsm.state_manager.StateManager.transition_state') + @patch('fsm.state_manager.StateManager.get_current_state_object') + def test_transition_execute_full_workflow(self, mock_get_state, mock_transition): + """Test full transition execution workflow""" + # Setup mocks + mock_get_state.return_value = None # No current state + mock_transition.return_value = True + + mock_state_record = Mock() + mock_state_record.id = "test-uuid" + + with patch('fsm.state_manager.StateManager.get_current_state_object', return_value=mock_state_record): + transition = TestTransition(test_field="test_value") + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state=None, + target_state=transition.target_state + ) + + # Execute transition + result = transition.execute(context) + + # Verify StateManager was called correctly + mock_transition.assert_called_once() + call_args = mock_transition.call_args + + self.assertEqual(call_args[1]['entity'], self.task) + self.assertEqual(call_args[1]['new_state'], "TEST_STATE") + self.assertEqual(call_args[1]['transition_name'], "test_transition") + self.assertEqual(call_args[1]['user'], self.user) + + # Check context data + context_data = call_args[1]['context'] + self.assertEqual(context_data['test_field'], "test_value") + self.assertEqual(context_data['optional_field'], 42) + + +class TransitionRegistryTests(TestCase): + """Test cases for the transition registry""" + + def setUp(self): + self.registry = transition_registry + + def test_transition_registration(self): + """Test registering transitions""" + self.registry.register('test_entity', 'test_transition', TestTransition) + + retrieved = self.registry.get_transition('test_entity', 'test_transition') + self.assertEqual(retrieved, TestTransition) + + def test_get_transitions_for_entity(self): + """Test getting all transitions for an entity""" + self.registry.register('test_entity', 'transition1', TestTransition) + self.registry.register('test_entity', 'transition2', TestTransition) + + transitions = self.registry.get_transitions_for_entity('test_entity') + + self.assertIn('transition1', transitions) + self.assertIn('transition2', transitions) + self.assertEqual(len(transitions), 2) + + def test_list_entities(self): + """Test listing registered entities""" + self.registry.register('entity1', 'transition1', TestTransition) + self.registry.register('entity2', 'transition2', TestTransition) + + entities = self.registry.list_entities() + + self.assertIn('entity1', entities) + self.assertIn('entity2', entities) + + +class TransitionUtilsTests(TestCase): + """Test cases for transition utility functions""" + + def setUp(self): + self.task = MockTask() + transition_registry.register('task', 'test_transition', TestTransition) + + def test_get_available_transitions(self): + """Test getting available transitions for entity""" + transitions = get_available_transitions(self.task) + self.assertIn('test_transition', transitions) + + @patch('fsm.state_manager.StateManager.get_current_state_object') + def test_get_valid_transitions(self, mock_get_state): + """Test filtering valid transitions""" + mock_get_state.return_value = None + + valid_transitions = get_valid_transitions(self.task, validate=True) + self.assertIn('test_transition', valid_transitions) + + @patch('fsm.state_manager.StateManager.get_current_state_object') + def test_get_valid_transitions_with_invalid(self, mock_get_state): + """Test filtering out invalid transitions""" + mock_get_state.return_value = None + + # Register an invalid transition + class InvalidTransition(TestTransition): + @classmethod + def can_transition_from_state(cls, context): + # This transition is never valid at the class level + return False + + def validate_transition(self, context): + raise TransitionValidationError("Always invalid") + + transition_registry.register('task', 'invalid_transition', InvalidTransition) + + valid_transitions = get_valid_transitions(self.task, validate=True) + self.assertIn('test_transition', valid_transitions) + self.assertNotIn('invalid_transition', valid_transitions) + + @patch('fsm.transition_utils.execute_transition') + def test_transition_builder(self, mock_execute): + """Test fluent transition builder interface""" + mock_execute.return_value = Mock() + + result = (TransitionBuilder(self.task) + .transition('test_transition') + .with_data(test_field="builder_test") + .by_user(Mock()) + .with_context(extra="context") + .execute()) + + mock_execute.assert_called_once() + call_args = mock_execute.call_args + + self.assertEqual(call_args[1]['transition_name'], 'test_transition') + self.assertEqual(call_args[1]['transition_data']['test_field'], 'builder_test') + + +class ExampleTransitionIntegrationTests(TestCase): + """Integration tests using the example transitions""" + + def setUp(self): + # Import example transitions to register them + from fsm.example_transitions import ( + StartTaskTransition, + CompleteTaskTransition, + SubmitAnnotationTransition + ) + + self.task = MockTask() + self.annotation = MockAnnotation() + self.user = Mock() + self.user.id = 1 + self.user.username = "testuser" + + def test_start_task_transition_validation(self): + """Test StartTaskTransition validation""" + from fsm.example_transitions import StartTaskTransition + + transition = StartTaskTransition(assigned_user_id=123) + + # Test valid transition from CREATED + context = TransitionContext( + entity=self.task, + current_state=TaskStateChoices.CREATED, + target_state=transition.target_state + ) + + self.assertTrue(transition.validate_transition(context)) + + # Test invalid transition from COMPLETED + context.current_state = TaskStateChoices.COMPLETED + + with self.assertRaises(TransitionValidationError): + transition.validate_transition(context) + + def test_submit_annotation_validation(self): + """Test SubmitAnnotationTransition validation""" + from fsm.example_transitions import SubmitAnnotationTransition + + transition = SubmitAnnotationTransition() + + # Test valid transition + context = TransitionContext( + entity=self.annotation, + current_state=AnnotationStateChoices.DRAFT, + target_state=transition.target_state + ) + + self.assertTrue(transition.validate_transition(context)) + + def test_transition_data_generation(self): + """Test that transitions generate appropriate context data""" + from fsm.example_transitions import StartTaskTransition + + transition = StartTaskTransition( + assigned_user_id=123, + estimated_duration=5, + priority="high" + ) + + context = TransitionContext( + entity=self.task, + current_user=self.user, + target_state=transition.target_state, + timestamp=datetime.now() + ) + + result = transition.transition(context) + + self.assertEqual(result['assigned_user_id'], 123) + self.assertEqual(result['estimated_duration'], 5) + self.assertEqual(result['priority'], 'high') + self.assertIn('started_at', result) + self.assertEqual(result['assignment_type'], 'manual') + + +class ComprehensiveUsageExampleTests(TestCase): + """ + Comprehensive test cases demonstrating various usage patterns. + + These tests serve as both validation and documentation for how to + implement and use the declarative transition system. + """ + + def setUp(self): + self.task = MockTask() + self.user = Mock() + self.user.id = 123 + self.user.username = "testuser" + + # Clear registry to avoid conflicts + transition_registry._transitions.clear() + + def test_basic_transition_implementation(self): + """ + USAGE EXAMPLE: Basic transition implementation + + Shows how to implement a simple transition with validation. + """ + + class BasicTransition(BaseTransition): + """Example: Simple transition with required field""" + message: str = Field(..., description="Message for the transition") + + @property + def target_state(self) -> str: + return "PROCESSED" + + def validate_transition(self, context: TransitionContext) -> bool: + # Business logic validation + if context.current_state == "COMPLETED": + raise TransitionValidationError("Cannot process completed items") + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "message": self.message, + "processed_by": context.current_user.username if context.current_user else "system", + "processed_at": context.timestamp.isoformat() + } + + # Test the implementation + transition = BasicTransition(message="Processing task") + self.assertEqual(transition.message, "Processing task") + self.assertEqual(transition.target_state, "PROCESSED") + + # Test validation + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state="CREATED", + target_state=transition.target_state + ) + + self.assertTrue(transition.validate_transition(context)) + + # Test data generation + data = transition.transition(context) + self.assertEqual(data["message"], "Processing task") + self.assertEqual(data["processed_by"], "testuser") + self.assertIn("processed_at", data) + + def test_complex_validation_example(self): + """ + USAGE EXAMPLE: Complex validation with multiple conditions + + Shows how to implement sophisticated business logic validation. + """ + + class TaskAssignmentTransition(BaseTransition): + """Example: Complex validation for task assignment""" + assignee_id: int = Field(..., description="User to assign task to") + priority: str = Field("normal", description="Task priority") + deadline: datetime = Field(None, description="Task deadline") + + @property + def target_state(self) -> str: + return "ASSIGNED" + + def validate_transition(self, context: TransitionContext) -> bool: + # Multiple validation conditions + if context.current_state not in ["CREATED", "UNASSIGNED"]: + raise TransitionValidationError( + f"Cannot assign task in state {context.current_state}", + {"current_state": context.current_state, "task_id": context.entity.pk} + ) + + # Check deadline is in future + if self.deadline and self.deadline <= datetime.now(): + raise TransitionValidationError( + "Deadline must be in the future", + {"deadline": self.deadline.isoformat()} + ) + + # Check priority is valid + valid_priorities = ["low", "normal", "high", "urgent"] + if self.priority not in valid_priorities: + raise TransitionValidationError( + f"Invalid priority: {self.priority}", + {"valid_priorities": valid_priorities} + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "assignee_id": self.assignee_id, + "priority": self.priority, + "deadline": self.deadline.isoformat() if self.deadline else None, + "assigned_by": context.current_user.id if context.current_user else None, + "assignment_reason": f"Task assigned to user {self.assignee_id}" + } + + # Test valid assignment + future_deadline = datetime.now() + timedelta(days=7) + transition = TaskAssignmentTransition( + assignee_id=456, + priority="high", + deadline=future_deadline + ) + + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state="CREATED", + target_state=transition.target_state + ) + + self.assertTrue(transition.validate_transition(context)) + + # Test invalid state + context.current_state = "COMPLETED" + with self.assertRaises(TransitionValidationError) as cm: + transition.validate_transition(context) + + self.assertIn("Cannot assign task in state", str(cm.exception)) + self.assertIn("COMPLETED", str(cm.exception)) + + # Test invalid deadline + past_deadline = datetime.now() - timedelta(days=1) + invalid_transition = TaskAssignmentTransition( + assignee_id=456, + deadline=past_deadline + ) + + context.current_state = "CREATED" + with self.assertRaises(TransitionValidationError) as cm: + invalid_transition.validate_transition(context) + + self.assertIn("Deadline must be in the future", str(cm.exception)) + + def test_hooks_and_lifecycle_example(self): + """ + USAGE EXAMPLE: Using pre/post hooks for side effects + + Shows how to implement lifecycle hooks for notifications, + cleanup, or other side effects. + """ + + class NotificationTransition(BaseTransition): + """Example: Transition with notification hooks""" + notification_message: str = Field(..., description="Notification message") + notify_users: list = Field(default_factory=list, description="Users to notify") + notifications_sent: list = Field(default_factory=list, description="Track sent notifications") + cleanup_performed: bool = Field(default=False, description="Track cleanup status") + + @property + def target_state(self) -> str: + return "NOTIFIED" + + @classmethod + def get_target_state(cls) -> str: + return "NOTIFIED" + + @classmethod + def can_transition_from_state(cls, context: TransitionContext) -> bool: + return True + + def pre_transition_hook(self, context: TransitionContext) -> None: + """Prepare notifications before state change""" + # Validate notification recipients + if not self.notify_users: + self.notify_users = [context.current_user.id] if context.current_user else [] + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "notification_message": self.notification_message, + "notify_users": self.notify_users, + "notification_sent_at": context.timestamp.isoformat() + } + + def post_transition_hook(self, context: TransitionContext, state_record) -> None: + """Send notifications after successful state change""" + # Mock notification sending + for user_id in self.notify_users: + self.notifications_sent.append({ + "user_id": user_id, + "message": self.notification_message, + "sent_at": context.timestamp + }) + + # Mock cleanup + self.cleanup_performed = True + + # Test the hooks + transition = NotificationTransition( + notification_message="Task has been updated", + notify_users=[123, 456] + ) + + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state="CREATED", + target_state=transition.target_state + ) + + # Test pre-hook + transition.pre_transition_hook(context) + self.assertEqual(transition.notify_users, [123, 456]) + + # Test transition + data = transition.transition(context) + self.assertEqual(data["notification_message"], "Task has been updated") + + # Test post-hook + mock_state_record = Mock() + transition.post_transition_hook(context, mock_state_record) + + self.assertEqual(len(transition.notifications_sent), 2) + self.assertTrue(transition.cleanup_performed) + + def test_conditional_transition_example(self): + """ + USAGE EXAMPLE: Conditional transitions based on data + + Shows how to implement transitions that behave differently + based on input data or context. + """ + + class ConditionalApprovalTransition(BaseTransition): + """Example: Conditional approval based on confidence""" + confidence_score: float = Field(..., ge=0.0, le=1.0, description="Confidence score") + auto_approve_threshold: float = Field(0.9, description="Auto-approval threshold") + reviewer_id: int = Field(None, description="Manual reviewer ID") + + @property + def target_state(self) -> str: + # Dynamic target state based on confidence + if self.confidence_score >= self.auto_approve_threshold: + return "AUTO_APPROVED" + else: + return "PENDING_REVIEW" + + def validate_transition(self, context: TransitionContext) -> bool: + # Different validation based on approval type + if self.confidence_score >= self.auto_approve_threshold: + # Auto-approval validation + if context.current_state != "SUBMITTED": + raise TransitionValidationError("Can only auto-approve submitted items") + else: + # Manual review validation + if not self.reviewer_id: + raise TransitionValidationError("Manual review requires reviewer_id") + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + base_data = { + "confidence_score": self.confidence_score, + "threshold": self.auto_approve_threshold, + } + + if self.confidence_score >= self.auto_approve_threshold: + # Auto-approval data + return { + **base_data, + "approval_type": "automatic", + "approved_at": context.timestamp.isoformat(), + "approved_by": "system" + } + else: + # Manual review data + return { + **base_data, + "approval_type": "manual", + "assigned_reviewer": self.reviewer_id, + "review_requested_at": context.timestamp.isoformat() + } + + # Test auto-approval path + high_confidence_transition = ConditionalApprovalTransition( + confidence_score=0.95 + ) + + self.assertEqual(high_confidence_transition.target_state, "AUTO_APPROVED") + + context = TransitionContext( + entity=self.task, + current_state="SUBMITTED", + target_state=high_confidence_transition.target_state + ) + + self.assertTrue(high_confidence_transition.validate_transition(context)) + + auto_data = high_confidence_transition.transition(context) + self.assertEqual(auto_data["approval_type"], "automatic") + self.assertEqual(auto_data["approved_by"], "system") + + # Test manual review path + low_confidence_transition = ConditionalApprovalTransition( + confidence_score=0.7, + reviewer_id=789 + ) + + self.assertEqual(low_confidence_transition.target_state, "PENDING_REVIEW") + + context.target_state = low_confidence_transition.target_state + self.assertTrue(low_confidence_transition.validate_transition(context)) + + manual_data = low_confidence_transition.transition(context) + self.assertEqual(manual_data["approval_type"], "manual") + self.assertEqual(manual_data["assigned_reviewer"], 789) + + def test_registry_and_decorator_usage(self): + """ + USAGE EXAMPLE: Using the registry and decorator system + + Shows how to register transitions and use the decorator syntax. + """ + + @register_transition('document', 'publish') + class PublishDocumentTransition(BaseTransition): + """Example: Using the registration decorator""" + publish_immediately: bool = Field(True, description="Publish immediately") + scheduled_time: datetime = Field(None, description="Scheduled publish time") + + @property + def target_state(self) -> str: + return "PUBLISHED" if self.publish_immediately else "SCHEDULED" + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "publish_immediately": self.publish_immediately, + "scheduled_time": self.scheduled_time.isoformat() if self.scheduled_time else None, + "published_by": context.current_user.id if context.current_user else None + } + + # Test registration worked + registered_class = transition_registry.get_transition('document', 'publish') + self.assertEqual(registered_class, PublishDocumentTransition) + + # Test getting transitions for entity + document_transitions = transition_registry.get_transitions_for_entity('document') + self.assertIn('publish', document_transitions) + + # Test execution through registry + mock_document = Mock() + mock_document.pk = 1 + mock_document._meta.model_name = 'document' + + # This would normally go through the full execution workflow + transition_data = {"publish_immediately": False, "scheduled_time": datetime.now() + timedelta(hours=2)} + + # Test transition creation and validation + transition = PublishDocumentTransition(**transition_data) + self.assertEqual(transition.target_state, "SCHEDULED") + + +class ValidationAndErrorHandlingTests(TestCase): + """ + Tests focused on validation scenarios and error handling. + + These tests demonstrate proper error handling patterns and + validation edge cases. + """ + + def setUp(self): + self.task = MockTask() + transition_registry._transitions.clear() + + def test_pydantic_validation_errors(self): + """Test Pydantic field validation errors""" + + class StrictValidationTransition(BaseTransition): + required_field: str = Field(..., description="Required field") + email_field: str = Field(..., pattern=r'^[\w\.-]+@[\w\.-]+\.\w+$', description="Valid email") + number_field: int = Field(..., ge=1, le=100, description="Number between 1-100") + + @property + def target_state(self) -> str: + return "VALIDATED" + + @classmethod + def get_target_state(cls) -> str: + return "VALIDATED" + + @classmethod + def can_transition_from_state(cls, context: TransitionContext) -> bool: + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return {"validated": True} + + # Test missing required field + with self.assertRaises(ValidationError): + StrictValidationTransition(email_field="test@example.com", number_field=50) + + # Test invalid email + with self.assertRaises(ValidationError): + StrictValidationTransition( + required_field="test", + email_field="invalid-email", + number_field=50 + ) + + # Test number out of range + with self.assertRaises(ValidationError): + StrictValidationTransition( + required_field="test", + email_field="test@example.com", + number_field=150 + ) + + # Test valid data + valid_transition = StrictValidationTransition( + required_field="test", + email_field="user@example.com", + number_field=75 + ) + self.assertEqual(valid_transition.required_field, "test") + + def test_business_logic_validation_errors(self): + """Test business logic validation with detailed error context""" + + class BusinessRuleTransition(BaseTransition): + amount: float = Field(..., description="Transaction amount") + currency: str = Field("USD", description="Currency code") + + @property + def target_state(self) -> str: + return "PROCESSED" + + def validate_transition(self, context: TransitionContext) -> bool: + # Complex business rule validation + errors = [] + + if self.amount <= 0: + errors.append("Amount must be positive") + + if self.amount > 10000 and context.current_user is None: + errors.append("Large amounts require authenticated user") + + if self.currency not in ["USD", "EUR", "GBP"]: + errors.append(f"Unsupported currency: {self.currency}") + + if context.current_state == "CANCELLED": + errors.append("Cannot process cancelled transactions") + + if errors: + raise TransitionValidationError( + f"Validation failed: {'; '.join(errors)}", + { + "validation_errors": errors, + "amount": self.amount, + "currency": self.currency, + "current_state": context.current_state + } + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "amount": self.amount, + "currency": self.currency + } + + context = TransitionContext( + entity=self.task, + current_state="PENDING", + target_state="PROCESSED" + ) + + # Test negative amount + negative_transition = BusinessRuleTransition(amount=-100) + with self.assertRaises(TransitionValidationError) as cm: + negative_transition.validate_transition(context) + + error = cm.exception + self.assertIn("Amount must be positive", str(error)) + self.assertIn("validation_errors", error.context) + + # Test large amount without user + large_transition = BusinessRuleTransition(amount=15000) + with self.assertRaises(TransitionValidationError) as cm: + large_transition.validate_transition(context) + + self.assertIn("Large amounts require authenticated user", str(cm.exception)) + + # Test invalid currency + invalid_currency_transition = BusinessRuleTransition( + amount=100, + currency="XYZ" + ) + with self.assertRaises(TransitionValidationError) as cm: + invalid_currency_transition.validate_transition(context) + + self.assertIn("Unsupported currency", str(cm.exception)) + + # Test multiple errors + multi_error_transition = BusinessRuleTransition( + amount=-50, + currency="XYZ" + ) + with self.assertRaises(TransitionValidationError) as cm: + multi_error_transition.validate_transition(context) + + error_msg = str(cm.exception) + self.assertIn("Amount must be positive", error_msg) + self.assertIn("Unsupported currency", error_msg) + + def test_context_validation_errors(self): + """Test validation errors related to context state""" + + class ContextAwareTransition(BaseTransition): + action: str = Field(..., description="Action to perform") + + @property + def target_state(self) -> str: + return "ACTIONED" + + def validate_transition(self, context: TransitionContext) -> bool: + # State-dependent validation + if context.is_initial_transition and self.action != "create": + raise TransitionValidationError( + "Initial transition must be 'create' action", + {"action": self.action, "is_initial": True} + ) + + if context.current_state == "COMPLETED" and self.action in ["modify", "update"]: + raise TransitionValidationError( + f"Cannot {self.action} completed items", + {"action": self.action, "current_state": context.current_state} + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return {"action": self.action} + + # Test initial transition validation + create_transition = ContextAwareTransition(action="create") + initial_context = TransitionContext( + entity=self.task, + current_state=None, # No current state = initial + target_state="ACTIONED" + ) + + self.assertTrue(create_transition.validate_transition(initial_context)) + + # Test invalid initial action + modify_transition = ContextAwareTransition(action="modify") + with self.assertRaises(TransitionValidationError) as cm: + modify_transition.validate_transition(initial_context) + + error = cm.exception + self.assertIn("Initial transition must be 'create'", str(error)) + self.assertTrue(error.context["is_initial"]) + + # Test completed state validation + completed_context = TransitionContext( + entity=self.task, + current_state="COMPLETED", + target_state="ACTIONED" + ) + + with self.assertRaises(TransitionValidationError) as cm: + modify_transition.validate_transition(completed_context) + + self.assertIn("Cannot modify completed items", str(cm.exception)) + + +@pytest.fixture +def task(): + """Pytest fixture for mock task""" + return MockTask() + + +@pytest.fixture +def user(): + """Pytest fixture for mock user""" + user = Mock() + user.id = 1 + user.username = "testuser" + return user + + +def test_transition_context_properties(task, user): + """Test TransitionContext properties using pytest""" + context = TransitionContext( + entity=task, + current_user=user, + current_state="CREATED", + target_state="IN_PROGRESS" + ) + + assert context.has_current_state + assert not context.is_initial_transition + assert context.current_state == "CREATED" + assert context.target_state == "IN_PROGRESS" + + +def test_pydantic_validation(): + """Test Pydantic validation in transitions""" + # Valid data + transition = TestTransition(test_field="valid") + assert transition.test_field == "valid" + assert transition.optional_field == 42 + + # Invalid data should raise validation error + with pytest.raises(Exception): # Pydantic validation error + TestTransition() # Missing required field \ No newline at end of file diff --git a/label_studio/fsm/tests/test_edge_cases_error_handling.py b/label_studio/fsm/tests/test_edge_cases_error_handling.py new file mode 100644 index 000000000000..7b2b87cb2675 --- /dev/null +++ b/label_studio/fsm/tests/test_edge_cases_error_handling.py @@ -0,0 +1,768 @@ +""" +Edge cases and comprehensive error handling tests. + +These tests cover unusual scenarios, boundary conditions, error edge cases, +and defensive programming patterns that ensure the transition system +is robust in production environments. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import Mock, patch, MagicMock +from typing import Dict, Any, List, Optional +import threading +import weakref +import gc + +from django.test import TestCase +from pydantic import Field, ValidationError, validator + +from fsm.transitions import ( + BaseTransition, + TransitionContext, + TransitionValidationError, + transition_registry, + register_transition +) +from fsm.transition_utils import ( + execute_transition, + get_available_transitions, + TransitionBuilder, + validate_transition_data, + create_transition_from_dict +) + + +class EdgeCaseTransition(BaseTransition): + """Test transition for edge case scenarios""" + edge_case_data: Any = Field(None, description="Data for edge case testing") + + @property + def target_state(self) -> str: + return "EDGE_CASE_PROCESSED" + + def validate_transition(self, context: TransitionContext) -> bool: + # Deliberately minimal validation for edge case testing + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "edge_case_data": self.edge_case_data, + "processed_at": context.timestamp.isoformat() + } + + +class ErrorProneTransition(BaseTransition): + """Transition designed to test error scenarios""" + should_fail: str = Field("no", description="Controls failure behavior") + failure_stage: str = Field("none", description="Stage at which to fail") + + @property + def target_state(self) -> str: + return "ERROR_TESTED" + + def validate_transition(self, context: TransitionContext) -> bool: + if self.failure_stage == "validation" and self.should_fail == "yes": + raise TransitionValidationError("Intentional validation failure") + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + if self.failure_stage == "transition" and self.should_fail == "yes": + raise RuntimeError("Intentional transition failure") + + return { + "should_fail": self.should_fail, + "failure_stage": self.failure_stage + } + + +class EdgeCasesAndErrorHandlingTests(TestCase): + """ + Comprehensive edge case and error handling tests. + + These tests ensure the transition system handles unusual inputs, + boundary conditions, and error scenarios gracefully. + """ + + def setUp(self): + self.mock_entity = Mock() + self.mock_entity.pk = 1 + self.mock_entity._meta.model_name = 'test_entity' + + self.mock_user = Mock() + self.mock_user.id = 42 + + # Clear registry + transition_registry._transitions.clear() + transition_registry.register('test_entity', 'edge_case', EdgeCaseTransition) + transition_registry.register('test_entity', 'error_prone', ErrorProneTransition) + + def test_none_and_empty_values_handling(self): + """ + EDGE CASE: Handling None and empty values + + Tests how the system handles None values, empty strings, + empty lists, and other "falsy" values. + """ + + # Test None values + transition_none = EdgeCaseTransition(edge_case_data=None) + self.assertIsNone(transition_none.edge_case_data) + + context = TransitionContext( + entity=self.mock_entity, + current_user=None, # None user + current_state=None, # None state (initial) + target_state=transition_none.target_state + ) + + # Should handle None values gracefully + self.assertTrue(transition_none.validate_transition(context)) + result = transition_none.transition(context) + self.assertIsNone(result["edge_case_data"]) + + # Test empty string values + empty_transition = EdgeCaseTransition(edge_case_data="") + result = empty_transition.transition(context) + self.assertEqual(result["edge_case_data"], "") + + # Test empty collections + empty_list_transition = EdgeCaseTransition(edge_case_data=[]) + result = empty_list_transition.transition(context) + self.assertEqual(result["edge_case_data"], []) + + empty_dict_transition = EdgeCaseTransition(edge_case_data={}) + result = empty_dict_transition.transition(context) + self.assertEqual(result["edge_case_data"], {}) + + # Test zero values + zero_transition = EdgeCaseTransition(edge_case_data=0) + result = zero_transition.transition(context) + self.assertEqual(result["edge_case_data"], 0) + + # Test False boolean + false_transition = EdgeCaseTransition(edge_case_data=False) + result = false_transition.transition(context) + self.assertFalse(result["edge_case_data"]) + + def test_extreme_data_sizes(self): + """ + EDGE CASE: Handling extremely large or small data + + Tests system behavior with very large strings, deep nested structures, + and other extreme data sizes. + """ + + # Test very large string + large_string = "x" * 10000 # 10KB string + large_string_transition = EdgeCaseTransition(edge_case_data=large_string) + + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state=large_string_transition.target_state + ) + + result = large_string_transition.transition(context) + self.assertEqual(len(result["edge_case_data"]), 10000) + + # Test deeply nested dictionary + deep_dict = {"level": 0} + current_level = deep_dict + for i in range(100): # 100 levels deep + current_level["next"] = {"level": i + 1} + current_level = current_level["next"] + + deep_dict_transition = EdgeCaseTransition(edge_case_data=deep_dict) + result = deep_dict_transition.transition(context) + self.assertEqual(result["edge_case_data"]["level"], 0) + + # Test large list + large_list = list(range(1000)) # 1000 items + large_list_transition = EdgeCaseTransition(edge_case_data=large_list) + result = large_list_transition.transition(context) + self.assertEqual(len(result["edge_case_data"]), 1000) + self.assertEqual(result["edge_case_data"][-1], 999) + + def test_unicode_and_special_characters(self): + """ + EDGE CASE: Unicode and special character handling + + Tests handling of various Unicode characters, emojis, + control characters, and other special strings. + """ + + test_cases = [ + # Unicode characters + "Hello, 世界! 🌍", + # Emojis + "Task completed! 🎉✅👍", + # Special symbols + "Price: €100 → $120 ≈ £95", + # Mathematical symbols + "∑(1,2,3) = 6, √16 = 4, π ≈ 3.14", + # Control characters (escaped) + "Line1\nLine2\tTabbed\r\nWindows", + # JSON-like string + '{"key": "value", "number": 42}', + # SQL-like string (potential injection test) + "'; DROP TABLE users; --", + # Empty and whitespace + " \t\n\r ", + # Very long Unicode + "🌟" * 100, + ] + + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state="EDGE_CASE_PROCESSED" + ) + + for test_string in test_cases: + with self.subTest(test_string=test_string[:20] + "..."): + transition = EdgeCaseTransition(edge_case_data=test_string) + + # Should handle any Unicode string + result = transition.transition(context) + self.assertEqual(result["edge_case_data"], test_string) + + def test_boundary_datetime_values(self): + """ + EDGE CASE: Boundary datetime values + + Tests handling of edge case datetime values like far future, + far past, timezone edge cases, etc. + """ + + boundary_datetimes = [ + # Far past + datetime(1970, 1, 1), + datetime(1900, 1, 1), + # Far future + datetime(2100, 12, 31), + datetime(3000, 1, 1), + # Edge of leap year + datetime(2000, 2, 29), # Leap year + datetime(1900, 2, 28), # Not a leap year + # New Year boundaries + datetime(1999, 12, 31, 23, 59, 59), + datetime(2000, 1, 1, 0, 0, 0), + # Microsecond precision + datetime(2024, 1, 1, 12, 0, 0, 123456), + ] + + for test_datetime in boundary_datetimes: + with self.subTest(datetime=test_datetime.isoformat()): + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state="EDGE_CASE_PROCESSED", + timestamp=test_datetime + ) + + transition = EdgeCaseTransition(edge_case_data="datetime_test") + + # Should handle any valid datetime + result = transition.transition(context) + self.assertEqual(result["processed_at"], test_datetime.isoformat()) + + def test_circular_reference_handling(self): + """ + EDGE CASE: Circular references and complex object graphs + + Tests how the system handles objects with circular references + or complex interdependencies. + """ + + # Create circular reference structure + circular_dict = {"name": "parent"} + circular_dict["child"] = {"name": "child", "parent": circular_dict} + + # Test that the system can handle circular references without infinite recursion + # Pydantic with field type 'Any' will accept circular references + try: + transition = EdgeCaseTransition(edge_case_data=circular_dict) + # Verify that the circular reference was stored + self.assertEqual(transition.edge_case_data["name"], "parent") + self.assertEqual(transition.edge_case_data["child"]["name"], "child") + # The system should handle this gracefully + except RecursionError: + self.fail("System should handle circular references without infinite recursion") + + # Test with complex but non-circular structure + complex_structure = { + "level1": { + "level2": { + "level3": { + "data": "deep_value", + "references": ["ref1", "ref2", "ref3"] * 10 + } + } + }, + "cross_reference": None # Will be set to level1 later, but not circular + } + + transition = EdgeCaseTransition(edge_case_data=complex_structure) + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state=transition.target_state + ) + + result = transition.transition(context) + self.assertEqual( + result["edge_case_data"]["level1"]["level2"]["level3"]["data"], + "deep_value" + ) + + def test_memory_pressure_and_cleanup(self): + """ + EDGE CASE: Memory pressure and garbage collection + + Tests system behavior under memory pressure and ensures + proper cleanup of transition instances and contexts. + """ + + transitions = [] + contexts = [] + weak_refs = [] + + # Create many transition instances + for i in range(1000): + transition = EdgeCaseTransition(edge_case_data=f"data_{i}") + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state=transition.target_state, + metadata={"iteration": i} + ) + + transitions.append(transition) + contexts.append(context) + + # Create weak references to test garbage collection + if i < 10: # Only for first few to avoid too many weak refs + weak_refs.append(weakref.ref(transition)) + weak_refs.append(weakref.ref(context)) + + # Verify all were created + self.assertEqual(len(transitions), 1000) + self.assertEqual(len(contexts), 1000) + + # Clear references and force garbage collection + transitions.clear() + contexts.clear() + gc.collect() + + # Check that objects can be garbage collected + # Some weak references should be None after GC + none_count = sum(1 for ref in weak_refs if ref() is None) + # At least some should be collected (this is implementation dependent) + # We don't require all to be collected due to Python GC behavior + + # Test that new instances can still be created after cleanup + new_transition = EdgeCaseTransition(edge_case_data="after_cleanup") + new_context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state=new_transition.target_state + ) + + result = new_transition.transition(new_context) + self.assertEqual(result["edge_case_data"], "after_cleanup") + + def test_exception_during_validation(self): + """ + ERROR HANDLING: Exceptions during validation + + Tests proper handling of various types of exceptions + that can occur during transition validation. + """ + + class ValidationErrorTransition(BaseTransition): + error_type: str = Field(..., description="Type of error to raise") + + @property + def target_state(self) -> str: + return "ERROR_STATE" + + @classmethod + def get_target_state(cls) -> str: + return "ERROR_STATE" + + @classmethod + def can_transition_from_state(cls, context: TransitionContext) -> bool: + return True + + def validate_transition(self, context: TransitionContext) -> bool: + if self.error_type == "transition_validation": + raise TransitionValidationError("Business rule violation") + elif self.error_type == "runtime_error": + raise RuntimeError("Unexpected runtime error") + elif self.error_type == "key_error": + raise KeyError("Missing required key") + elif self.error_type == "attribute_error": + raise AttributeError("Missing attribute") + elif self.error_type == "value_error": + raise ValueError("Invalid value provided") + elif self.error_type == "type_error": + raise TypeError("Type mismatch") + return True + + def transition(self, context: TransitionContext) -> dict: + return {"error_type": self.error_type, "processed": True} + + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state="ERROR_STATE" + ) + + # Test TransitionValidationError (expected) + transition = ValidationErrorTransition(error_type="transition_validation") + with self.assertRaises(TransitionValidationError) as cm: + transition.validate_transition(context) + self.assertIn("Business rule violation", str(cm.exception)) + + # Test other exceptions (should bubble up) + error_types = [ + ("runtime_error", RuntimeError), + ("key_error", KeyError), + ("attribute_error", AttributeError), + ("value_error", ValueError), + ("type_error", TypeError), + ] + + for error_type, exception_class in error_types: + with self.subTest(error_type=error_type): + transition = ValidationErrorTransition(error_type=error_type) + with self.assertRaises(exception_class): + transition.validate_transition(context) + + def test_exception_during_transition_execution(self): + """ + ERROR HANDLING: Exceptions during transition execution + + Tests handling of exceptions that occur during the + actual transition execution phase. + """ + + # Test with ErrorProneTransition + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state="ERROR_TESTED" + ) + + # Test successful execution + success_transition = ErrorProneTransition(should_fail="no") + result = success_transition.transition(context) + self.assertEqual(result["should_fail"], "no") + + # Test intentional failure + fail_transition = ErrorProneTransition( + should_fail="yes", + failure_stage="transition" + ) + + with self.assertRaises(RuntimeError) as cm: + fail_transition.transition(context) + self.assertIn("Intentional transition failure", str(cm.exception)) + + def test_registry_edge_cases(self): + """ + EDGE CASE: Registry edge cases + + Tests unusual registry operations and edge cases + like duplicate registrations, invalid names, etc. + """ + + # Test duplicate registration (should overwrite) + original_class = EdgeCaseTransition + + class NewEdgeCaseTransition(BaseTransition): + @property + def target_state(self) -> str: + return "NEW_EDGE_CASE" + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return {"type": "new_implementation"} + + # Register with same name + transition_registry.register('test_entity', 'edge_case', NewEdgeCaseTransition) + + # Should get new class + retrieved = transition_registry.get_transition('test_entity', 'edge_case') + self.assertEqual(retrieved, NewEdgeCaseTransition) + + # Test registration with unusual names + unusual_names = [ + ('entity-with-dashes', 'transition-name'), + ('entity_with_underscores', 'transition_name'), + ('Entity.With.Dots', 'transition.name'), + ('entity123', 'transition456'), + ('UPPERCASE_ENTITY', 'UPPERCASE_TRANSITION'), + ] + + for entity_name, transition_name in unusual_names: + with self.subTest(entity=entity_name, transition=transition_name): + transition_registry.register(entity_name, transition_name, EdgeCaseTransition) + retrieved = transition_registry.get_transition(entity_name, transition_name) + self.assertEqual(retrieved, EdgeCaseTransition) + + # Test nonexistent lookups + self.assertIsNone(transition_registry.get_transition('nonexistent', 'transition')) + self.assertIsNone(transition_registry.get_transition('test_entity', 'nonexistent')) + + # Test empty entity transitions + empty_transitions = transition_registry.get_transitions_for_entity('nonexistent_entity') + self.assertEqual(empty_transitions, {}) + + def test_context_edge_cases(self): + """ + EDGE CASE: TransitionContext edge cases + + Tests unusual context configurations and edge cases + in context creation and usage. + """ + + # Test context with None entity (system should handle gracefully) + # Since entity field is typed as Any, None is accepted + try: + context = TransitionContext( + entity=None, + current_state="CREATED", + target_state="PROCESSED" + ) + # Verify context was created with None entity + self.assertIsNone(context.entity) + self.assertEqual(context.current_state, "CREATED") + except Exception as e: + self.fail(f"Context creation with None entity should not fail: {e}") + + # Test context with missing required fields + with self.assertRaises(ValidationError): + TransitionContext( + entity=self.mock_entity, + # Missing target_state + ) + + # Test context with extreme timestamp + far_future = datetime(3000, 1, 1) + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state="PROCESSED", + timestamp=far_future + ) + + self.assertEqual(context.timestamp, far_future) + + # Test context with large metadata + large_metadata = {f"key_{i}": f"value_{i}" for i in range(1000)} + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state="PROCESSED", + metadata=large_metadata + ) + + self.assertEqual(len(context.metadata), 1000) + + # Test context property edge cases + empty_context = TransitionContext( + entity=self.mock_entity, + current_state="", # Empty string state + target_state="PROCESSED" + ) + + # Empty string should be considered "has state" + self.assertTrue(empty_context.has_current_state) + self.assertFalse(empty_context.is_initial_transition) + + def test_transition_builder_edge_cases(self): + """ + EDGE CASE: TransitionBuilder edge cases + + Tests unusual usage patterns and edge cases + with the fluent TransitionBuilder interface. + """ + + builder = TransitionBuilder(self.mock_entity) + + # Test validation without setting transition name + with self.assertRaises(ValueError) as cm: + builder.validate() + self.assertIn("Transition name not specified", str(cm.exception)) + + # Test execution without setting transition name + with self.assertRaises(ValueError) as cm: + builder.execute() + self.assertIn("Transition name not specified", str(cm.exception)) + + # Test with nonexistent transition + builder.transition('nonexistent_transition') + + with self.assertRaises(ValueError) as cm: + builder.validate() + self.assertIn("not found", str(cm.exception)) + + # Test method chaining edge cases + builder = (TransitionBuilder(self.mock_entity) + .transition('edge_case') + .with_data() # Empty data + .by_user(None) # No user + .with_context()) # Empty context + + # Should not raise errors for empty data + errors = builder.validate() + self.assertEqual(errors, {}) # EdgeCaseTransition has no required fields + + # Test data overwriting + builder = (TransitionBuilder(self.mock_entity) + .transition('edge_case') + .with_data(edge_case_data="first") + .with_data(edge_case_data="second")) # Should overwrite + + errors = builder.validate() + self.assertEqual(errors, {}) + + def test_concurrent_error_scenarios(self): + """ + EDGE CASE: Error handling under concurrency + + Tests error handling when multiple threads encounter + errors simultaneously. + """ + + error_results = [] + error_lock = threading.Lock() + + def error_worker(worker_id): + """Worker that intentionally triggers errors""" + try: + # Create transition that will fail + transition = ErrorProneTransition( + should_fail="yes", + failure_stage="validation" if worker_id % 2 == 0 else "transition" + ) + + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state=transition.target_state + ) + + if worker_id % 2 == 0: + # Validation error + transition.validate_transition(context) + else: + # Transition execution error + transition.transition(context) + + except Exception as e: + with error_lock: + error_results.append({ + "worker_id": worker_id, + "error_type": type(e).__name__, + "error_message": str(e) + }) + + # Run multiple workers that will all fail + threads = [] + for i in range(10): + thread = threading.Thread(target=error_worker, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all to complete + for thread in threads: + thread.join() + + # Should have 10 errors + self.assertEqual(len(error_results), 10) + + # Verify error types + validation_errors = [r for r in error_results if r["error_type"] == "TransitionValidationError"] + runtime_errors = [r for r in error_results if r["error_type"] == "RuntimeError"] + + self.assertEqual(len(validation_errors), 5) # Even worker IDs + self.assertEqual(len(runtime_errors), 5) # Odd worker IDs + + def test_resource_cleanup_after_errors(self): + """ + EDGE CASE: Resource cleanup after errors + + Tests that resources are properly cleaned up + even when transitions fail partway through. + """ + + class ResourceTrackingTransition(BaseTransition): + """Transition that tracks resource allocation""" + resource_name: str = Field(..., description="Name of resource") + resources_allocated: list = Field(default_factory=list, description="Track allocated resources") + resources_cleaned: list = Field(default_factory=list, description="Track cleaned resources") + + @property + def target_state(self) -> str: + return "RESOURCE_PROCESSED" + + @classmethod + def get_target_state(cls) -> str: + return "RESOURCE_PROCESSED" + + @classmethod + def can_transition_from_state(cls, context: TransitionContext) -> bool: + return True + + def validate_transition(self, context: TransitionContext) -> bool: + # Allocate some mock resources + resource = f"resource_{self.resource_name}" + self.resources_allocated.append(resource) + + # Fail if resource name contains "fail" + if "fail" in self.resource_name: + raise TransitionValidationError("Resource allocation failed") + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return {"resource_name": self.resource_name} + + def __del__(self): + # Mock cleanup in destructor + for resource in self.resources_allocated: + if resource not in self.resources_cleaned: + self.resources_cleaned.append(resource) + + # Test successful case + success_transition = ResourceTrackingTransition(resource_name="success") + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state=success_transition.target_state + ) + + self.assertTrue(success_transition.validate_transition(context)) + self.assertEqual(len(success_transition.resources_allocated), 1) + + # Test failure case + fail_transition = ResourceTrackingTransition(resource_name="fail_test") + + with self.assertRaises(TransitionValidationError): + fail_transition.validate_transition(context) + + # Resources should still be allocated even though validation failed + self.assertEqual(len(fail_transition.resources_allocated), 1) + + # Force garbage collection to trigger cleanup + success_ref = weakref.ref(success_transition) + fail_ref = weakref.ref(fail_transition) + + del success_transition + del fail_transition + gc.collect() + + # References should be cleaned up + # (Note: This test is somewhat implementation-dependent) \ No newline at end of file diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py index cfb36baa8ecc..cbfcb405ad93 100644 --- a/label_studio/fsm/tests/test_fsm_integration.py +++ b/label_studio/fsm/tests/test_fsm_integration.py @@ -29,7 +29,11 @@ def setUp(self): def test_task_state_creation(self): """Test TaskState creation and basic functionality""" task_state = TaskState.objects.create( - task=self.task, state='CREATED', triggered_by=self.user, reason='Task created for testing' + task=self.task, + project_id=self.task.project_id, # Denormalized from task.project_id + state='CREATED', + triggered_by=self.user, + reason='Task created for testing', ) # Check basic fields @@ -51,7 +55,13 @@ def test_annotation_state_creation(self): annotation = Annotation.objects.create(task=self.task, completed_by=self.user, result=[]) annotation_state = AnnotationState.objects.create( - annotation=annotation, state='DRAFT', triggered_by=self.user, reason='Annotation draft created' + annotation=annotation, + task_id=annotation.task.id, # Denormalized from annotation.task_id + project_id=annotation.task.project_id, # Denormalized from annotation.task.project_id + completed_by_id=annotation.completed_by.id if annotation.completed_by else None, # Denormalized + state='DRAFT', + triggered_by=self.user, + reason='Annotation draft created', ) # Check basic fields @@ -63,7 +73,12 @@ def test_annotation_state_creation(self): # Test completed state completed_state = AnnotationState.objects.create( - annotation=annotation, state='COMPLETED', triggered_by=self.user + annotation=annotation, + task_id=annotation.task.id, + project_id=annotation.task.project_id, + completed_by_id=annotation.completed_by.id if annotation.completed_by else None, + state='COMPLETED', + triggered_by=self.user, ) self.assertTrue(completed_state.is_terminal_state) diff --git a/label_studio/fsm/tests/test_integration_django_models.py b/label_studio/fsm/tests/test_integration_django_models.py new file mode 100644 index 000000000000..0b7893b7f33d --- /dev/null +++ b/label_studio/fsm/tests/test_integration_django_models.py @@ -0,0 +1,682 @@ +""" +Integration tests for declarative transitions with real Django models. + +These tests demonstrate how the transition system integrates with actual +Django models and the StateManager, providing realistic usage examples. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import Mock, patch +from typing import Dict, Any + +from django.test import TestCase, TransactionTestCase +from django.contrib.auth import get_user_model +from pydantic import Field + +from fsm.models import TaskState, AnnotationState, get_state_model_for_entity +from fsm.state_choices import TaskStateChoices, AnnotationStateChoices +from fsm.state_manager import StateManager +from fsm.transitions import ( + BaseTransition, + TransitionContext, + TransitionValidationError, + register_transition +) +from fsm.transition_utils import ( + execute_transition, + get_available_transitions, + TransitionBuilder +) + +# Mock Django models for integration testing +class MockDjangoTask: + """Mock Django Task model with realistic attributes""" + def __init__(self, pk=1, project_id=1, organization_id=1): + self.pk = pk + self.id = pk + self.project_id = project_id + self.organization_id = organization_id + self._meta = Mock() + self._meta.model_name = 'task' + self._meta.label_lower = 'tasks.task' + + # Mock task attributes + self.data = {"text": "Sample task data"} + self.created_at = datetime.now() + self.updated_at = datetime.now() + + +class MockDjangoAnnotation: + """Mock Django Annotation model with realistic attributes""" + def __init__(self, pk=1, task_id=1, project_id=1, organization_id=1): + self.pk = pk + self.id = pk + self.task_id = task_id + self.project_id = project_id + self.organization_id = organization_id + self._meta = Mock() + self._meta.model_name = 'annotation' + self._meta.label_lower = 'tasks.annotation' + + # Mock annotation attributes + self.result = [{"value": {"text": ["Sample annotation"]}}] + self.completed_by_id = None + self.created_at = datetime.now() + self.updated_at = datetime.now() + + +User = get_user_model() + + +class DjangoModelIntegrationTests(TestCase): + """ + Integration tests demonstrating realistic usage with Django models. + + These tests show how to implement transitions that work with actual + Django model patterns and the StateManager integration. + """ + + def setUp(self): + self.task = MockDjangoTask() + self.annotation = MockDjangoAnnotation() + self.user = Mock() + self.user.id = 123 + self.user.username = "integration_test_user" + + # Clear registry for clean test state + from fsm.transitions import transition_registry + transition_registry._transitions.clear() + + @patch('fsm.models.get_state_model_for_entity') + @patch('fsm.state_manager.StateManager.get_current_state_object') + @patch('fsm.state_manager.StateManager.transition_state') + def test_task_workflow_integration(self, mock_transition_state, mock_get_state_obj, mock_get_state_model): + """ + INTEGRATION TEST: Complete task workflow using Django models + + Demonstrates a realistic task lifecycle from creation through completion + using the declarative transition system with Django model integration. + """ + + # Setup mocks to simulate Django model behavior + mock_get_state_model.return_value = TaskState + mock_get_state_obj.return_value = None # No existing state (initial transition) + mock_transition_state.return_value = True + + # Define task workflow transitions + @register_transition('task', 'create_task') + class CreateTaskTransition(BaseTransition): + """Initial task creation transition""" + created_by_id: int = Field(..., description="User creating the task") + initial_priority: str = Field("normal", description="Initial task priority") + + @property + def target_state(self) -> str: + return TaskStateChoices.CREATED + + def validate_transition(self, context: TransitionContext) -> bool: + # Validate initial creation + if not context.is_initial_transition: + raise TransitionValidationError("CreateTask can only be used for initial state") + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "created_by_id": self.created_by_id, + "initial_priority": self.initial_priority, + "task_data": getattr(context.entity, 'data', {}), + "project_id": getattr(context.entity, 'project_id', None), + "creation_method": "declarative_transition" + } + + @register_transition('task', 'assign_and_start') + class AssignAndStartTaskTransition(BaseTransition): + """Assign task to user and start work""" + assignee_id: int = Field(..., description="User assigned to task") + estimated_hours: float = Field(None, ge=0.1, description="Estimated work hours") + priority: str = Field("normal", description="Task priority") + + @property + def target_state(self) -> str: + return TaskStateChoices.IN_PROGRESS + + def validate_transition(self, context: TransitionContext) -> bool: + valid_from_states = [TaskStateChoices.CREATED] + if context.current_state not in valid_from_states: + raise TransitionValidationError( + f"Can only assign tasks from states: {valid_from_states}", + {"current_state": context.current_state, "valid_states": valid_from_states} + ) + + # Business rule: Can't assign to the same user who created it + if hasattr(context, 'current_state_object') and context.current_state_object: + creator_id = context.current_state_object.context_data.get('created_by_id') + if creator_id == self.assignee_id: + raise TransitionValidationError( + "Cannot assign task to the same user who created it", + {"creator_id": creator_id, "assignee_id": self.assignee_id} + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "assignee_id": self.assignee_id, + "estimated_hours": self.estimated_hours, + "priority": self.priority, + "assigned_at": context.timestamp.isoformat(), + "assigned_by_id": context.current_user.id if context.current_user else None, + "work_started": True + } + + @register_transition('task', 'complete_with_quality') + class CompleteTaskWithQualityTransition(BaseTransition): + """Complete task with quality metrics""" + quality_score: float = Field(..., ge=0.0, le=1.0, description="Quality score") + completion_notes: str = Field("", description="Completion notes") + actual_hours: float = Field(None, ge=0.0, description="Actual hours worked") + + @property + def target_state(self) -> str: + return TaskStateChoices.COMPLETED + + def validate_transition(self, context: TransitionContext) -> bool: + if context.current_state != TaskStateChoices.IN_PROGRESS: + raise TransitionValidationError( + "Can only complete tasks that are in progress", + {"current_state": context.current_state} + ) + + # Quality check + if self.quality_score < 0.6: + raise TransitionValidationError( + f"Quality score too low: {self.quality_score}. Minimum required: 0.6" + ) + + return True + + def post_transition_hook(self, context: TransitionContext, state_record) -> None: + """Post-completion tasks like notifications""" + # Mock notification system + if hasattr(self, '_notifications'): + self._notifications.append(f"Task {context.entity.pk} completed with quality {self.quality_score}") + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + # Calculate metrics + start_data = context.current_state_object.context_data if context.current_state_object else {} + estimated_hours = start_data.get('estimated_hours') + + return { + "quality_score": self.quality_score, + "completion_notes": self.completion_notes, + "actual_hours": self.actual_hours, + "estimated_hours": estimated_hours, + "completed_at": context.timestamp.isoformat(), + "completed_by_id": context.current_user.id if context.current_user else None, + "efficiency_ratio": (estimated_hours / self.actual_hours) if (estimated_hours and self.actual_hours) else None + } + + # Execute the complete workflow + + # Step 1: Create task + create_transition = CreateTaskTransition( + created_by_id=100, + initial_priority="high" + ) + + # Test with StateManager integration + with patch('fsm.state_manager.StateManager.get_current_state') as mock_get_current: + mock_get_current.return_value = None # No current state + + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state=None, + target_state=create_transition.target_state + ) + + # Validate and execute creation + self.assertTrue(create_transition.validate_transition(context)) + creation_data = create_transition.transition(context) + + self.assertEqual(creation_data["created_by_id"], 100) + self.assertEqual(creation_data["initial_priority"], "high") + self.assertEqual(creation_data["creation_method"], "declarative_transition") + + # Step 2: Assign and start task + mock_current_state = Mock() + mock_current_state.context_data = creation_data + mock_get_state_obj.return_value = mock_current_state + + assign_transition = AssignAndStartTaskTransition( + assignee_id=200, # Different from creator + estimated_hours=4.5, + priority="urgent" + ) + + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state=TaskStateChoices.CREATED, + current_state_object=mock_current_state, + target_state=assign_transition.target_state + ) + + self.assertTrue(assign_transition.validate_transition(context)) + assignment_data = assign_transition.transition(context) + + self.assertEqual(assignment_data["assignee_id"], 200) + self.assertEqual(assignment_data["estimated_hours"], 4.5) + self.assertTrue(assignment_data["work_started"]) + + # Step 3: Complete task + mock_current_state.context_data = assignment_data + + complete_transition = CompleteTaskWithQualityTransition( + quality_score=0.85, + completion_notes="Task completed successfully with minor revisions", + actual_hours=5.2 + ) + complete_transition._notifications = [] # Mock notification system + + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state=TaskStateChoices.IN_PROGRESS, + current_state_object=mock_current_state, + target_state=complete_transition.target_state + ) + + self.assertTrue(complete_transition.validate_transition(context)) + completion_data = complete_transition.transition(context) + + self.assertEqual(completion_data["quality_score"], 0.85) + self.assertEqual(completion_data["actual_hours"], 5.2) + self.assertAlmostEqual(completion_data["efficiency_ratio"], 4.5/5.2, places=2) + + # Test post-hook + mock_state_record = Mock() + complete_transition.post_transition_hook(context, mock_state_record) + self.assertEqual(len(complete_transition._notifications), 1) + + # Verify StateManager calls + self.assertEqual(mock_transition_state.call_count, 0) # Not called in our test setup + + def test_annotation_review_workflow_integration(self): + """ + INTEGRATION TEST: Annotation review workflow + + Demonstrates a realistic annotation review process using + enterprise-grade validation and approval logic. + """ + + @register_transition('annotation', 'submit_for_review') + class SubmitAnnotationForReview(BaseTransition): + """Submit annotation for quality review""" + annotator_confidence: float = Field(..., ge=0.0, le=1.0, description="Annotator confidence") + annotation_time_seconds: int = Field(..., ge=1, description="Time spent annotating") + review_requested: bool = Field(True, description="Whether review is requested") + + @property + def target_state(self) -> str: + return AnnotationStateChoices.SUBMITTED + + def validate_transition(self, context: TransitionContext) -> bool: + # Check annotation has content + if not hasattr(context.entity, 'result') or not context.entity.result: + raise TransitionValidationError("Cannot submit empty annotation") + + # Business rule: Low confidence annotations must request review + if self.annotator_confidence < 0.7 and not self.review_requested: + raise TransitionValidationError( + "Low confidence annotations must request review", + {"confidence": self.annotator_confidence, "threshold": 0.7} + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "annotator_confidence": self.annotator_confidence, + "annotation_time_seconds": self.annotation_time_seconds, + "review_requested": self.review_requested, + "annotation_complexity": len(context.entity.result) if context.entity.result else 0, + "submitted_at": context.timestamp.isoformat(), + "submitted_by_id": context.current_user.id if context.current_user else None + } + + @register_transition('annotation', 'review_and_approve') + class ReviewAndApproveAnnotation(BaseTransition): + """Review annotation and approve/reject""" + reviewer_decision: str = Field(..., description="approve, reject, or request_changes") + quality_score: float = Field(..., ge=0.0, le=1.0, description="Reviewer quality assessment") + review_comments: str = Field("", description="Review comments") + corrections_made: bool = Field(False, description="Whether reviewer made corrections") + + @property + def target_state(self) -> str: + if self.reviewer_decision == "approve": + return AnnotationStateChoices.COMPLETED + else: + return AnnotationStateChoices.DRAFT # Back to draft for changes + + def validate_transition(self, context: TransitionContext) -> bool: + if context.current_state != AnnotationStateChoices.SUBMITTED: + raise TransitionValidationError("Can only review submitted annotations") + + valid_decisions = ["approve", "reject", "request_changes"] + if self.reviewer_decision not in valid_decisions: + raise TransitionValidationError( + f"Invalid decision: {self.reviewer_decision}", + {"valid_decisions": valid_decisions} + ) + + # Quality score validation based on decision + if self.reviewer_decision == "approve" and self.quality_score < 0.6: + raise TransitionValidationError( + "Cannot approve annotation with low quality score", + {"quality_score": self.quality_score, "decision": self.reviewer_decision} + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + # Get submission data for metrics + submission_data = context.current_state_object.context_data if context.current_state_object else {} + + return { + "reviewer_decision": self.reviewer_decision, + "quality_score": self.quality_score, + "review_comments": self.review_comments, + "corrections_made": self.corrections_made, + "reviewed_at": context.timestamp.isoformat(), + "reviewed_by_id": context.current_user.id if context.current_user else None, + "original_confidence": submission_data.get("annotator_confidence"), + "confidence_vs_quality_diff": abs(submission_data.get("annotator_confidence", 0) - self.quality_score) + } + + # Execute annotation workflow + + # Step 1: Submit annotation + submit_transition = SubmitAnnotationForReview( + annotator_confidence=0.9, + annotation_time_seconds=300, # 5 minutes + review_requested=True + ) + + context = TransitionContext( + entity=self.annotation, + current_user=self.user, + current_state=AnnotationStateChoices.DRAFT, + target_state=submit_transition.target_state + ) + + self.assertTrue(submit_transition.validate_transition(context)) + submit_data = submit_transition.transition(context) + + self.assertEqual(submit_data["annotator_confidence"], 0.9) + self.assertEqual(submit_data["annotation_time_seconds"], 300) + self.assertTrue(submit_data["review_requested"]) + self.assertEqual(submit_data["annotation_complexity"], 1) # Based on mock result + + # Step 2: Review and approve + mock_submission_state = Mock() + mock_submission_state.context_data = submit_data + + review_transition = ReviewAndApproveAnnotation( + reviewer_decision="approve", + quality_score=0.85, + review_comments="High quality annotation with good coverage", + corrections_made=False + ) + + context = TransitionContext( + entity=self.annotation, + current_user=self.user, + current_state=AnnotationStateChoices.SUBMITTED, + current_state_object=mock_submission_state, + target_state=review_transition.target_state + ) + + self.assertTrue(review_transition.validate_transition(context)) + self.assertEqual(review_transition.target_state, AnnotationStateChoices.COMPLETED) + + review_data = review_transition.transition(context) + + self.assertEqual(review_data["reviewer_decision"], "approve") + self.assertEqual(review_data["quality_score"], 0.85) + self.assertEqual(review_data["original_confidence"], 0.9) + self.assertAlmostEqual(review_data["confidence_vs_quality_diff"], 0.05, places=2) + + # Test rejection scenario + reject_transition = ReviewAndApproveAnnotation( + reviewer_decision="reject", + quality_score=0.3, + review_comments="Insufficient annotation quality", + corrections_made=False + ) + + self.assertEqual(reject_transition.target_state, AnnotationStateChoices.DRAFT) + + # Test validation failure + invalid_review = ReviewAndApproveAnnotation( + reviewer_decision="approve", # Trying to approve + quality_score=0.5, # But quality too low + review_comments="Test", + ) + + with self.assertRaises(TransitionValidationError) as cm: + invalid_review.validate_transition(context) + + self.assertIn("Cannot approve annotation with low quality score", str(cm.exception)) + + @patch('fsm.transition_utils.execute_transition') + def test_transition_builder_with_django_models(self, mock_execute): + """ + INTEGRATION TEST: TransitionBuilder with Django model integration + + Shows how to use the fluent TransitionBuilder interface with + real Django models and complex business logic. + """ + + @register_transition('task', 'bulk_update_status') + class BulkUpdateTaskStatusTransition(BaseTransition): + """Bulk update task status with metadata""" + new_status: str = Field(..., description="New status for tasks") + update_reason: str = Field(..., description="Reason for bulk update") + updated_by_system: bool = Field(False, description="Whether updated by automated system") + batch_id: str = Field(None, description="Batch operation ID") + + @property + def target_state(self) -> str: + return self.new_status + + def validate_transition(self, context: TransitionContext) -> bool: + valid_statuses = [TaskStateChoices.CREATED, TaskStateChoices.IN_PROGRESS, TaskStateChoices.COMPLETED] + if self.new_status not in valid_statuses: + raise TransitionValidationError(f"Invalid status: {self.new_status}") + + # Can't bulk update to the same status + if context.current_state == self.new_status: + raise TransitionValidationError("Cannot update to the same status") + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "new_status": self.new_status, + "update_reason": self.update_reason, + "updated_by_system": self.updated_by_system, + "batch_id": self.batch_id, + "bulk_update_timestamp": context.timestamp.isoformat(), + "previous_status": context.current_state + } + + # Mock successful execution + mock_state_record = Mock() + mock_state_record.id = "mock-uuid" + mock_execute.return_value = mock_state_record + + # Test fluent interface + result = (TransitionBuilder(self.task) + .transition('bulk_update_status') + .with_data( + new_status=TaskStateChoices.IN_PROGRESS, + update_reason="Project priority change", + updated_by_system=True, + batch_id="batch_2024_001" + ) + .by_user(self.user) + .with_context( + project_update=True, + notification_level="high" + ) + .execute()) + + # Verify the call + mock_execute.assert_called_once() + call_args, call_kwargs = mock_execute.call_args + + # Check call parameters + self.assertEqual(call_kwargs['entity'], self.task) + self.assertEqual(call_kwargs['transition_name'], 'bulk_update_status') + self.assertEqual(call_kwargs['user'], self.user) + + # Check transition data + transition_data = call_kwargs['transition_data'] + self.assertEqual(transition_data['new_status'], TaskStateChoices.IN_PROGRESS) + self.assertEqual(transition_data['update_reason'], "Project priority change") + self.assertTrue(transition_data['updated_by_system']) + self.assertEqual(transition_data['batch_id'], "batch_2024_001") + + # Check context + self.assertTrue(call_kwargs['project_update']) + self.assertEqual(call_kwargs['notification_level'], "high") + + # Check return value + self.assertEqual(result, mock_state_record) + + def test_error_handling_with_django_models(self): + """ + INTEGRATION TEST: Error handling with Django model validation + + Tests comprehensive error handling scenarios that might occur + in real Django model integration. + """ + + @register_transition('task', 'assign_with_constraints') + class AssignTaskWithConstraints(BaseTransition): + """Task assignment with business constraints""" + assignee_id: int = Field(..., description="User to assign to") + max_concurrent_tasks: int = Field(5, description="Max concurrent tasks per user") + skill_requirements: list = Field(default_factory=list, description="Required skills") + + @property + def target_state(self) -> str: + return TaskStateChoices.IN_PROGRESS + + def validate_transition(self, context: TransitionContext) -> bool: + errors = [] + + # Mock database checks (in real scenario, these would be actual queries) + + # 1. Check user exists and is active + if self.assignee_id <= 0: + errors.append("Invalid user ID") + + # 2. Check user's current task load + if self.max_concurrent_tasks < 1: + errors.append("Max concurrent tasks must be at least 1") + + # 3. Check skill requirements + if self.skill_requirements: + # Mock skill validation + available_skills = ["python", "labeling", "review"] + missing_skills = [skill for skill in self.skill_requirements if skill not in available_skills] + if missing_skills: + errors.append(f"Missing required skills: {missing_skills}") + + # 4. Check project-level constraints + if hasattr(context.entity, 'project_id'): + # Mock project validation + if context.entity.project_id <= 0: + errors.append("Invalid project configuration") + + # 5. Check organization permissions + if hasattr(context.entity, 'organization_id'): + if not context.current_user: + errors.append("User authentication required for assignment") + + if errors: + raise TransitionValidationError( + f"Assignment validation failed: {'; '.join(errors)}", + { + "validation_errors": errors, + "assignee_id": self.assignee_id, + "task_id": context.entity.pk, + "skill_requirements": self.skill_requirements + } + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + "assignee_id": self.assignee_id, + "max_concurrent_tasks": self.max_concurrent_tasks, + "skill_requirements": self.skill_requirements, + "assignment_validated": True + } + + # Test successful validation + valid_transition = AssignTaskWithConstraints( + assignee_id=123, + max_concurrent_tasks=3, + skill_requirements=["python", "labeling"] + ) + + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state=TaskStateChoices.CREATED, + target_state=valid_transition.target_state + ) + + self.assertTrue(valid_transition.validate_transition(context)) + + # Test multiple validation errors + invalid_transition = AssignTaskWithConstraints( + assignee_id=-1, # Invalid user ID + max_concurrent_tasks=0, # Invalid max tasks + skill_requirements=["nonexistent_skill"] # Missing skill + ) + + with self.assertRaises(TransitionValidationError) as cm: + invalid_transition.validate_transition(context) + + error = cm.exception + error_msg = str(error) + + # Check all validation errors are included + self.assertIn("Invalid user ID", error_msg) + self.assertIn("Max concurrent tasks must be at least 1", error_msg) + self.assertIn("Missing required skills", error_msg) + + # Check error context + self.assertIn("validation_errors", error.context) + self.assertEqual(len(error.context["validation_errors"]), 3) + self.assertEqual(error.context["assignee_id"], -1) + + # Test authentication requirement + context_no_user = TransitionContext( + entity=self.task, + current_user=None, # No user + current_state=TaskStateChoices.CREATED, + target_state=valid_transition.target_state + ) + + with self.assertRaises(TransitionValidationError) as cm: + valid_transition.validate_transition(context_no_user) + + self.assertIn("User authentication required", str(cm.exception)) \ No newline at end of file diff --git a/label_studio/fsm/tests/test_performance_concurrency.py b/label_studio/fsm/tests/test_performance_concurrency.py new file mode 100644 index 000000000000..624e5bb9e44b --- /dev/null +++ b/label_studio/fsm/tests/test_performance_concurrency.py @@ -0,0 +1,753 @@ +""" +Performance and concurrency tests for the declarative transition system. + +These tests validate that the transition system performs well under load +and handles concurrent operations correctly, which is critical for +production FSM systems. +""" + +import pytest +import threading +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime, timedelta +from unittest.mock import Mock, patch +from typing import Dict, Any, List +import statistics + +from django.test import TestCase, TransactionTestCase +from pydantic import Field + +from fsm.transitions import ( + BaseTransition, + TransitionContext, + TransitionValidationError, + transition_registry, + register_transition +) +from fsm.transition_utils import ( + execute_transition, + get_available_transitions, + TransitionBuilder, + validate_transition_data +) + + +class PerformanceTestTransition(BaseTransition): + """Simple transition for performance testing""" + operation_id: int = Field(..., description="Operation identifier") + data_size: int = Field(1, description="Size of data to process") + + @property + def target_state(self) -> str: + return "PROCESSED" + + @classmethod + def get_target_state(cls) -> str: + return "PROCESSED" + + @classmethod + def can_transition_from_state(cls, context: TransitionContext) -> bool: + return True + + def validate_transition(self, context: TransitionContext) -> bool: + # Simulate some validation work + if self.data_size < 0: + raise TransitionValidationError("Invalid data size") + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + # Simulate processing work + return { + "operation_id": self.operation_id, + "data_size": self.data_size, + "processed_at": context.timestamp.isoformat(), + "processing_time_ms": 1 # Mock processing time + } + + +class ConcurrencyTestTransition(BaseTransition): + """Transition for testing concurrent access patterns""" + thread_id: int = Field(..., description="Thread identifier") + shared_counter: int = Field(0, description="Shared counter for testing") + sleep_duration: float = Field(0.0, description="Simulate processing delay") + execution_order: list = Field(default_factory=list, description="Track execution order") + + @property + def target_state(self) -> str: + return f"PROCESSED_BY_THREAD_{self.thread_id}" + + @classmethod + def get_target_state(cls) -> str: + return "PROCESSED_BY_THREAD_0" # Default for class-level queries + + @classmethod + def can_transition_from_state(cls, context: TransitionContext) -> bool: + return True + + def validate_transition(self, context: TransitionContext) -> bool: + # Record validation timing for concurrency analysis + self.execution_order.append(f"validate_{self.thread_id}_{time.time()}") + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + # Record transition timing + self.execution_order.append(f"transition_{self.thread_id}_{time.time()}") + + # Simulate some processing delay + if self.sleep_duration > 0: + time.sleep(self.sleep_duration) + + return { + "thread_id": self.thread_id, + "shared_counter": self.shared_counter, + "execution_order": self.execution_order.copy(), + "processed_at": context.timestamp.isoformat() + } + + +class PerformanceTests(TestCase): + """ + Performance tests for the declarative transition system. + + These tests measure execution time, memory usage patterns, + and scalability characteristics. + """ + + def setUp(self): + self.mock_entity = Mock() + self.mock_entity.pk = 1 + self.mock_entity._meta.model_name = 'test_entity' + + self.mock_user = Mock() + self.mock_user.id = 123 + + # Clear registry for clean tests + transition_registry._transitions.clear() + transition_registry.register('test_entity', 'performance_test', PerformanceTestTransition) + + def test_single_transition_performance(self): + """ + PERFORMANCE TEST: Measure single transition execution time + + Validates that individual transitions execute within acceptable time limits. + """ + + transition = PerformanceTestTransition(operation_id=1, data_size=1000) + + context = TransitionContext( + entity=self.mock_entity, + current_user=self.mock_user, + current_state="CREATED", + target_state=transition.target_state + ) + + # Measure validation performance + start_time = time.perf_counter() + result = transition.validate_transition(context) + validation_time = time.perf_counter() - start_time + + self.assertTrue(result) + self.assertLess(validation_time, 0.001) # Should be under 1ms + + # Measure transition execution performance + start_time = time.perf_counter() + transition_data = transition.transition(context) + execution_time = time.perf_counter() - start_time + + self.assertIsInstance(transition_data, dict) + self.assertLess(execution_time, 0.001) # Should be under 1ms + + # Measure total workflow performance + start_time = time.perf_counter() + transition.context = context + transition.validate_transition(context) + transition.transition(context) + total_time = time.perf_counter() - start_time + + self.assertLess(total_time, 0.005) # Total should be under 5ms + + def test_batch_transition_performance(self): + """ + PERFORMANCE TEST: Measure batch transition creation and validation + + Tests performance when creating many transition instances rapidly. + """ + + batch_size = 1000 + + # Test batch creation performance + start_time = time.perf_counter() + transitions = [] + + for i in range(batch_size): + transition = PerformanceTestTransition(operation_id=i, data_size=i * 10) + transitions.append(transition) + + creation_time = time.perf_counter() - start_time + creation_time_per_item = creation_time / batch_size + + self.assertEqual(len(transitions), batch_size) + self.assertLess(creation_time_per_item, 0.001) # Under 1ms per transition + + # Test batch validation performance + context = TransitionContext( + entity=self.mock_entity, + current_user=self.mock_user, + current_state="CREATED", + target_state="PROCESSED" + ) + + start_time = time.perf_counter() + validation_results = [] + + for transition in transitions: + result = transition.validate_transition(context) + validation_results.append(result) + + validation_time = time.perf_counter() - start_time + validation_time_per_item = validation_time / batch_size + + self.assertTrue(all(validation_results)) + self.assertLess(validation_time_per_item, 0.001) # Under 1ms per validation + self.assertLess(validation_time, 0.5) # Total batch under 500ms + + def test_registry_performance(self): + """ + PERFORMANCE TEST: Registry operations under load + + Tests the performance of registry lookups and registrations. + """ + + # Test registry lookup performance + lookup_count = 10000 + + start_time = time.perf_counter() + + for i in range(lookup_count): + retrieved_class = transition_registry.get_transition('test_entity', 'performance_test') + + lookup_time = time.perf_counter() - start_time + lookup_time_per_operation = lookup_time / lookup_count + + self.assertEqual(retrieved_class, PerformanceTestTransition) + self.assertLess(lookup_time_per_operation, 0.0001) # Under 0.1ms per lookup + + # Test registry registration performance + registration_count = 1000 + + start_time = time.perf_counter() + + for i in range(registration_count): + entity_name = f'entity_{i}' + transition_name = f'transition_{i}' + transition_registry.register(entity_name, transition_name, PerformanceTestTransition) + + registration_time = time.perf_counter() - start_time + registration_time_per_operation = registration_time / registration_count + + self.assertLess(registration_time_per_operation, 0.001) # Under 1ms per registration + + # Verify registrations worked + test_class = transition_registry.get_transition('entity_500', 'transition_500') + self.assertEqual(test_class, PerformanceTestTransition) + + def test_pydantic_validation_performance(self): + """ + PERFORMANCE TEST: Pydantic validation performance + + Measures the overhead of Pydantic validation in transitions. + """ + + # Test valid data performance + valid_data = {"operation_id": 123, "data_size": 1000} + validation_count = 10000 + + start_time = time.perf_counter() + + for i in range(validation_count): + transition = PerformanceTestTransition(**valid_data) + + validation_time = time.perf_counter() - start_time + validation_time_per_item = validation_time / validation_count + + self.assertLess(validation_time_per_item, 0.001) # Under 1ms per validation + + # Test validation error performance + invalid_data = {"operation_id": "invalid", "data_size": -1} + error_count = 1000 + + start_time = time.perf_counter() + errors = [] + + for i in range(error_count): + try: + PerformanceTestTransition(**invalid_data) + except Exception as e: + errors.append(e) + + error_time = time.perf_counter() - start_time + error_time_per_item = error_time / error_count + + self.assertEqual(len(errors), error_count) + self.assertLess(error_time_per_item, 0.01) # Under 10ms per error (errors are slower) + + def test_memory_usage_patterns(self): + """ + PERFORMANCE TEST: Memory usage analysis + + Tests memory usage patterns for transition instances and contexts. + """ + + import sys + + # Measure base memory usage + base_transitions = [] + for i in range(100): + transition = PerformanceTestTransition(operation_id=i, data_size=i) + base_transitions.append(transition) + + base_size = sys.getsizeof(base_transitions[0]) + + # Test memory usage with complex data + complex_transitions = [] + for i in range(100): + transition = PerformanceTestTransition(operation_id=i, data_size=i * 1000) + # Add context to transition + context = TransitionContext( + entity=self.mock_entity, + current_user=self.mock_user, + current_state="CREATED", + target_state=transition.target_state, + metadata={"large_data": "x" * 1000} # Add some bulk + ) + transition.context = context + complex_transitions.append(transition) + + complex_size = sys.getsizeof(complex_transitions[0]) + + # Memory usage should be reasonable + memory_overhead = complex_size - base_size + self.assertLess(memory_overhead, 10000) # Under 10KB overhead per transition + + # Clean up contexts to test garbage collection + for transition in complex_transitions: + transition.context = None + + # Verify memory can be reclaimed (simplified test) + self.assertIsNone(complex_transitions[0].context) + + +class ConcurrencyTests(TransactionTestCase): + """ + Concurrency tests for the declarative transition system. + + These tests validate thread safety and concurrent execution patterns + that are critical for production systems. + """ + + def setUp(self): + self.mock_entity = Mock() + self.mock_entity.pk = 1 + self.mock_entity._meta.model_name = 'test_entity' + + self.mock_user = Mock() + self.mock_user.id = 123 + + # Clear registry for clean tests + transition_registry._transitions.clear() + transition_registry.register('test_entity', 'concurrency_test', ConcurrencyTestTransition) + + def test_concurrent_transition_creation(self): + """ + CONCURRENCY TEST: Thread-safe transition instance creation + + Validates that multiple threads can create transition instances + concurrently without conflicts. + """ + + thread_count = 10 + transitions_per_thread = 100 + all_transitions = [] + thread_results = [] + + def create_transitions(thread_id): + """Worker function to create transitions in a thread""" + local_transitions = [] + for i in range(transitions_per_thread): + transition = ConcurrencyTestTransition( + thread_id=thread_id, + shared_counter=i, + sleep_duration=0.001 # Small delay to increase contention + ) + local_transitions.append(transition) + return local_transitions + + # Execute concurrent creation + with ThreadPoolExecutor(max_workers=thread_count) as executor: + futures = [] + for thread_id in range(thread_count): + future = executor.submit(create_transitions, thread_id) + futures.append(future) + + for future in as_completed(futures): + thread_transitions = future.result() + thread_results.append(thread_transitions) + all_transitions.extend(thread_transitions) + + # Validate results + total_expected = thread_count * transitions_per_thread + self.assertEqual(len(all_transitions), total_expected) + + # Check thread separation + thread_ids = [t.thread_id for t in all_transitions] + unique_threads = set(thread_ids) + self.assertEqual(len(unique_threads), thread_count) + + # Validate each thread created correct number of transitions + for thread_id in range(thread_count): + thread_transitions = [t for t in all_transitions if t.thread_id == thread_id] + self.assertEqual(len(thread_transitions), transitions_per_thread) + + def test_concurrent_transition_execution(self): + """ + CONCURRENCY TEST: Concurrent transition execution + + Tests that multiple transitions can be executed concurrently + without race conditions in the execution logic. + """ + + thread_count = 5 + execution_results = [] + + def execute_transition(thread_id): + """Worker function to execute a transition""" + transition = ConcurrencyTestTransition( + thread_id=thread_id, + shared_counter=thread_id * 10, + sleep_duration=0.01 # Small delay to test concurrency + ) + + context = TransitionContext( + entity=self.mock_entity, + current_user=self.mock_user, + current_state="CREATED", + target_state=transition.target_state, + timestamp=datetime.now() + ) + + # Execute validation and transition + validation_result = transition.validate_transition(context) + transition_data = transition.transition(context) + + return { + "thread_id": thread_id, + "validation_result": validation_result, + "transition_data": transition_data, + "execution_order": transition.execution_order + } + + # Execute concurrent transitions + with ThreadPoolExecutor(max_workers=thread_count) as executor: + futures = [] + for thread_id in range(thread_count): + future = executor.submit(execute_transition, thread_id) + futures.append(future) + + for future in as_completed(futures): + result = future.result() + execution_results.append(result) + + # Validate results + self.assertEqual(len(execution_results), thread_count) + + for result in execution_results: + self.assertTrue(result["validation_result"]) + self.assertIn("thread_id", result["transition_data"]) + self.assertIsInstance(result["execution_order"], list) + self.assertGreater(len(result["execution_order"]), 0) + + # Check thread isolation + thread_ids = [r["transition_data"]["thread_id"] for r in execution_results] + self.assertEqual(set(thread_ids), set(range(thread_count))) + + def test_registry_thread_safety(self): + """ + CONCURRENCY TEST: Registry thread safety + + Tests that the transition registry handles concurrent + registration and lookup operations safely. + """ + + thread_count = 10 + operations_per_thread = 100 + + def registry_operations(thread_id): + """Worker function for registry operations""" + operations_completed = 0 + + for i in range(operations_per_thread): + # Mix of registration and lookup operations + if i % 3 == 0: + # Register new transition + entity_name = f'entity_{thread_id}_{i}' + transition_name = f'transition_{i}' + transition_registry.register(entity_name, transition_name, ConcurrencyTestTransition) + operations_completed += 1 + + elif i % 3 == 1: + # Lookup existing transition + try: + found_class = transition_registry.get_transition('test_entity', 'concurrency_test') + if found_class == ConcurrencyTestTransition: + operations_completed += 1 + except Exception: + pass + + else: + # List operations + try: + entities = transition_registry.list_entities() + if len(entities) >= 0: # Should always be non-negative + operations_completed += 1 + except Exception: + pass + + return operations_completed + + # Execute concurrent registry operations + with ThreadPoolExecutor(max_workers=thread_count) as executor: + futures = [] + for thread_id in range(thread_count): + future = executor.submit(registry_operations, thread_id) + futures.append(future) + + operation_counts = [] + for future in as_completed(futures): + count = future.result() + operation_counts.append(count) + + # Validate no operations failed due to thread safety issues + total_operations = sum(operation_counts) + expected_minimum = thread_count * operations_per_thread * 0.9 # Allow some variance + + self.assertGreater(total_operations, expected_minimum) + + # Registry should be in consistent state + entities = transition_registry.list_entities() + self.assertIsInstance(entities, list) + self.assertGreater(len(entities), thread_count) # Should have entities from all threads + + def test_context_isolation(self): + """ + CONCURRENCY TEST: Context isolation between threads + + Ensures that transition contexts remain isolated between + concurrent executions and don't leak data. + """ + + thread_count = 8 + context_data = [] + + def context_isolation_test(thread_id): + """Test context isolation in a thread""" + # Create unique context data for this thread + unique_data = { + "thread_specific_id": thread_id, + "random_data": f"thread_{thread_id}_data", + "timestamp": datetime.now().isoformat(), + "test_counter": thread_id * 1000 + } + + transition = ConcurrencyTestTransition( + thread_id=thread_id, + shared_counter=thread_id, + sleep_duration=0.005 # Small delay to increase chance of interference + ) + + context = TransitionContext( + entity=self.mock_entity, + current_user=self.mock_user, + current_state="CREATED", + target_state=transition.target_state, + metadata=unique_data + ) + + # Set context on transition + transition.context = context + + # Execute transition + validation_result = transition.validate_transition(context) + transition_data = transition.transition(context) + + # Retrieve context and verify isolation + retrieved_context = transition.context + + return { + "thread_id": thread_id, + "original_metadata": unique_data, + "retrieved_metadata": retrieved_context.metadata, + "validation_result": validation_result, + "transition_data": transition_data + } + + # Execute with high concurrency + with ThreadPoolExecutor(max_workers=thread_count) as executor: + futures = [] + for thread_id in range(thread_count): + future = executor.submit(context_isolation_test, thread_id) + futures.append(future) + + for future in as_completed(futures): + result = future.result() + context_data.append(result) + + # Validate context isolation + self.assertEqual(len(context_data), thread_count) + + for result in context_data: + thread_id = result["thread_id"] + original_metadata = result["original_metadata"] + retrieved_metadata = result["retrieved_metadata"] + + # Context should match exactly what was set for this thread + self.assertEqual(original_metadata["thread_specific_id"], thread_id) + self.assertEqual(retrieved_metadata["thread_specific_id"], thread_id) + self.assertEqual(original_metadata["random_data"], retrieved_metadata["random_data"]) + self.assertEqual(original_metadata["test_counter"], thread_id * 1000) + + # Should not have data from other threads + for other_result in context_data: + if other_result["thread_id"] != thread_id: + self.assertNotEqual( + retrieved_metadata["thread_specific_id"], + other_result["original_metadata"]["thread_specific_id"] + ) + + def test_stress_test_mixed_operations(self): + """ + STRESS TEST: Mixed operations under load + + Combines multiple types of operations under high concurrency + to test overall system stability. + """ + + duration_seconds = 2 # Short duration for CI + thread_count = 6 + + # Shared statistics + stats = { + "transitions_created": 0, + "validations_performed": 0, + "transitions_executed": 0, + "registry_lookups": 0, + "errors_encountered": 0 + } + stats_lock = threading.Lock() + + def mixed_operations_worker(worker_id): + """Worker that performs mixed operations""" + local_stats = { + "transitions_created": 0, + "validations_performed": 0, + "transitions_executed": 0, + "registry_lookups": 0, + "errors_encountered": 0 + } + + end_time = time.time() + duration_seconds + operation_counter = 0 + + while time.time() < end_time: + try: + operation_type = operation_counter % 4 + + if operation_type == 0: + # Create transition + transition = ConcurrencyTestTransition( + thread_id=worker_id, + shared_counter=operation_counter + ) + local_stats["transitions_created"] += 1 + + elif operation_type == 1: + # Validate transition + transition = ConcurrencyTestTransition( + thread_id=worker_id, + shared_counter=operation_counter + ) + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state=transition.target_state + ) + transition.validate_transition(context) + local_stats["validations_performed"] += 1 + + elif operation_type == 2: + # Execute transition + transition = ConcurrencyTestTransition( + thread_id=worker_id, + shared_counter=operation_counter + ) + context = TransitionContext( + entity=self.mock_entity, + current_state="CREATED", + target_state=transition.target_state + ) + transition.transition(context) + local_stats["transitions_executed"] += 1 + + else: + # Registry lookup + found = transition_registry.get_transition('test_entity', 'concurrency_test') + if found: + local_stats["registry_lookups"] += 1 + + operation_counter += 1 + + except Exception as e: + local_stats["errors_encountered"] += 1 + + # Small yield to allow other threads + time.sleep(0.0001) + + # Update shared statistics + with stats_lock: + for key in stats: + stats[key] += local_stats[key] + + return local_stats + + # Execute stress test + with ThreadPoolExecutor(max_workers=thread_count) as executor: + futures = [] + for worker_id in range(thread_count): + future = executor.submit(mixed_operations_worker, worker_id) + futures.append(future) + + worker_results = [] + for future in as_completed(futures): + result = future.result() + worker_results.append(result) + + # Validate stress test results + total_operations = sum( + stats["transitions_created"] + + stats["validations_performed"] + + stats["transitions_executed"] + + stats["registry_lookups"] + ) + + # Should have performed substantial work + self.assertGreater(total_operations, thread_count * 10) + + # Error rate should be very low (< 1%) + error_rate = stats["errors_encountered"] / max(total_operations, 1) + self.assertLess(error_rate, 0.01) + + # All operation types should have been performed + self.assertGreater(stats["transitions_created"], 0) + self.assertGreater(stats["validations_performed"], 0) + self.assertGreater(stats["transitions_executed"], 0) + self.assertGreater(stats["registry_lookups"], 0) \ No newline at end of file diff --git a/label_studio/fsm/transition_utils.py b/label_studio/fsm/transition_utils.py new file mode 100644 index 000000000000..119b15701d79 --- /dev/null +++ b/label_studio/fsm/transition_utils.py @@ -0,0 +1,342 @@ +""" +Utility functions for working with the declarative transition system. + +This module provides helper functions to make it easier to integrate +the new Pydantic-based transition system with existing Label Studio code. +""" + +from typing import Any, Dict, List, Optional, Type, Union + +from django.db.models import Model + +from .models import BaseState, get_state_model_for_entity +from .state_manager import StateManager +from .transitions import BaseTransition, TransitionValidationError, transition_registry + + +def execute_transition( + entity: Model, + transition_name: str, + transition_data: Dict[str, Any], + user=None, + **context_kwargs +) -> BaseState: + """ + Execute a named transition on an entity. + + This is a convenience function that looks up the transition class + and executes it with the provided data. + + Args: + entity: The entity to transition + transition_name: Name of the registered transition + transition_data: Data for the transition (validated by Pydantic) + user: User executing the transition + **context_kwargs: Additional context data + + Returns: + The newly created state record + + Raises: + ValueError: If transition is not found + TransitionValidationError: If transition validation fails + """ + entity_name = entity._meta.model_name.lower() + return transition_registry.execute_transition( + entity_name=entity_name, + transition_name=transition_name, + entity=entity, + transition_data=transition_data, + user=user, + **context_kwargs + ) + + +def execute_transition_instance( + entity: Model, + transition: BaseTransition, + user=None, + **context_kwargs +) -> BaseState: + """ + Execute a pre-created transition instance. + + Args: + entity: The entity to transition + transition: Instance of a transition class + user: User executing the transition + **context_kwargs: Additional context data + + Returns: + The newly created state record + """ + return StateManager.execute_declarative_transition( + transition=transition, + entity=entity, + user=user, + **context_kwargs + ) + + +def get_available_transitions(entity: Model) -> Dict[str, Type[BaseTransition]]: + """ + Get all available transitions for an entity. + + Args: + entity: The entity to get transitions for + + Returns: + Dictionary mapping transition names to transition classes + """ + entity_name = entity._meta.model_name.lower() + return transition_registry.get_transitions_for_entity(entity_name) + + +def get_valid_transitions( + entity: Model, + user=None, + validate: bool = True +) -> Dict[str, Type[BaseTransition]]: + """ + Get transitions that are valid for the entity's current state. + + Args: + entity: The entity to check transitions for + user: User context for validation + validate: Whether to validate each transition (may be expensive) + + Returns: + Dictionary mapping transition names to transition classes + that are valid for the current state + """ + available = get_available_transitions(entity) + + if not validate: + return available + + valid_transitions = {} + + for name, transition_class in available.items(): + try: + # Get current state information + current_state_object = StateManager.get_current_state_object(entity) + current_state = current_state_object.state if current_state_object else None + + # Build minimal context for validation + from .transitions import TransitionContext + context = TransitionContext( + entity=entity, + current_user=user, + current_state_object=current_state_object, + current_state=current_state, + target_state=transition_class.get_target_state(), + organization_id=getattr(entity, 'organization_id', None) + ) + + # Use class-level validation that doesn't require an instance + if transition_class.can_transition_from_state(context): + valid_transitions[name] = transition_class + + except (TransitionValidationError, Exception): + # Transition is not valid for current state/context + continue + + return valid_transitions + + +def create_transition_from_dict( + transition_class: Type[BaseTransition], + data: Dict[str, Any] +) -> BaseTransition: + """ + Create a transition instance from a dictionary of data. + + This handles Pydantic validation and provides clear error messages. + + Args: + transition_class: The transition class to instantiate + data: Dictionary of transition data + + Returns: + Validated transition instance + + Raises: + ValueError: If data validation fails + """ + try: + return transition_class(**data) + except Exception as e: + raise ValueError(f"Failed to create {transition_class.__name__}: {e}") + + +def get_transition_schema(transition_class: Type[BaseTransition]) -> Dict[str, Any]: + """ + Get the JSON schema for a transition class. + + Useful for generating API documentation or frontend forms. + + Args: + transition_class: The transition class + + Returns: + JSON schema dictionary + """ + return transition_class.model_json_schema() + + +def validate_transition_data( + transition_class: Type[BaseTransition], + data: Dict[str, Any] +) -> Dict[str, List[str]]: + """ + Validate transition data without creating an instance. + + Args: + transition_class: The transition class + data: Data to validate + + Returns: + Dictionary of field names to error messages (empty if valid) + """ + try: + transition_class(**data) + return {} + except Exception as e: + # Parse Pydantic validation errors + errors = {} + if hasattr(e, 'errors'): + for error in e.errors(): + field = '.'.join(str(loc) for loc in error['loc']) + if field not in errors: + errors[field] = [] + errors[field].append(error['msg']) + else: + errors['__root__'] = [str(e)] + return errors + + +def get_entity_state_flow(entity: Model) -> List[Dict[str, Any]]: + """ + Get a summary of the state flow for an entity type. + + This analyzes all registered transitions and builds a flow diagram. + + Args: + entity: Example entity instance + + Returns: + List of state flow information + """ + entity_name = entity._meta.model_name.lower() + transitions = transition_registry.get_transitions_for_entity(entity_name) + + # Build state flow information + states = set() + flows = [] + + for transition_name, transition_class in transitions.items(): + # Create instance to get target state + try: + transition = transition_class() + target = transition.target_state + states.add(target) + + flows.append({ + 'transition_name': transition_name, + 'transition_class': transition_class.__name__, + 'target_state': target, + 'description': transition_class.__doc__ or '', + 'fields': list(transition_class.model_fields.keys()) + }) + except Exception: + continue + + return flows + + +# Backward compatibility helpers + +def transition_state_declarative( + entity: Model, + transition_name: str, + user=None, + **transition_data +) -> BaseState: + """ + Backward-compatible helper for transitioning state declaratively. + + This provides a similar interface to StateManager.transition_state + but uses the declarative system. + """ + return execute_transition( + entity=entity, + transition_name=transition_name, + transition_data=transition_data, + user=user + ) + + +class TransitionBuilder: + """ + Builder class for constructing and executing transitions fluently. + + Example usage: + result = (TransitionBuilder(entity) + .transition('start_task') + .with_data(assigned_user_id=123, priority='high') + .by_user(request.user) + .execute()) + """ + + def __init__(self, entity: Model): + self.entity = entity + self._transition_name: Optional[str] = None + self._transition_data: Dict[str, Any] = {} + self._user = None + self._context_data: Dict[str, Any] = {} + + def transition(self, name: str) -> 'TransitionBuilder': + """Set the transition name""" + self._transition_name = name + return self + + def with_data(self, **data) -> 'TransitionBuilder': + """Add transition data""" + self._transition_data.update(data) + return self + + def by_user(self, user) -> 'TransitionBuilder': + """Set the executing user""" + self._user = user + return self + + def with_context(self, **context) -> 'TransitionBuilder': + """Add context data""" + self._context_data.update(context) + return self + + def execute(self) -> BaseState: + """Execute the configured transition""" + if not self._transition_name: + raise ValueError("Transition name not specified") + + return execute_transition( + entity=self.entity, + transition_name=self._transition_name, + transition_data=self._transition_data, + user=self._user, + **self._context_data + ) + + def validate(self) -> Dict[str, List[str]]: + """Validate the configured transition without executing""" + if not self._transition_name: + raise ValueError("Transition name not specified") + + entity_name = self.entity._meta.model_name.lower() + transition_class = transition_registry.get_transition(entity_name, self._transition_name) + + if not transition_class: + raise ValueError(f"Transition '{self._transition_name}' not found for entity '{entity_name}'") + + return validate_transition_data(transition_class, self._transition_data) \ No newline at end of file diff --git a/label_studio/fsm/transitions.py b/label_studio/fsm/transitions.py new file mode 100644 index 000000000000..01621d24cdcf --- /dev/null +++ b/label_studio/fsm/transitions.py @@ -0,0 +1,486 @@ +""" +Declarative Pydantic-based transition system for FSM engine. + +This module provides a framework for defining state transitions as first-class +Pydantic models with built-in validation, context passing, and middleware-like +functionality for enhanced declarative state management. +""" + +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union + +from django.contrib.auth import get_user_model +from django.db.models import Model +from pydantic import BaseModel, Field, ConfigDict + +from .models import BaseState + +User = get_user_model() + +# Type variables for generic transition context +EntityType = TypeVar('EntityType', bound=Model) +StateModelType = TypeVar('StateModelType', bound=BaseState) + + +class TransitionContext(BaseModel, Generic[EntityType, StateModelType]): + """ + Context object passed to all transitions containing middleware-like information. + + This provides access to current state, entity, user, and other contextual + information needed for transition validation and execution. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + # Core context information + entity: Any = Field(..., description="The entity being transitioned") + current_user: Optional[Any] = Field(None, description="User triggering the transition") + current_state_object: Optional[Any] = Field(None, description="Full current state object") + current_state: Optional[str] = Field(None, description="Current state as string") + target_state: str = Field(..., description="Target state for this transition") + + # Timing and metadata + timestamp: datetime = Field(default_factory=datetime.now, description="When transition was initiated") + transition_name: Optional[str] = Field(None, description="Name of the transition method") + + # Additional context data + request_data: Dict[str, Any] = Field(default_factory=dict, description="Additional request/context data") + metadata: Dict[str, Any] = Field(default_factory=dict, description="Transition-specific metadata") + + # Organizational context + organization_id: Optional[int] = Field(None, description="Organization context for the transition") + + @property + def has_current_state(self) -> bool: + """Check if entity has a current state""" + return self.current_state is not None + + @property + def is_initial_transition(self) -> bool: + """Check if this is the first state transition for the entity""" + return not self.has_current_state + + +class TransitionValidationError(Exception): + """Exception raised when transition validation fails""" + + def __init__(self, message: str, context: Optional[Dict[str, Any]] = None): + super().__init__(message) + self.context = context or {} + + +class BaseTransition(BaseModel, ABC, Generic[EntityType, StateModelType]): + """ + Abstract base class for all declarative state transitions. + + This provides the framework for implementing transitions as first-class Pydantic + models with built-in validation, context handling, and execution logic. + + Example usage: + class StartTaskTransition(BaseTransition[Task, TaskState]): + assigned_user_id: int = Field(..., description="User assigned to start the task") + estimated_duration: Optional[int] = Field(None, description="Estimated completion time in hours") + + @property + def target_state(self) -> str: + return TaskStateChoices.IN_PROGRESS + + def validate_transition(self, context: TransitionContext[Task, TaskState]) -> bool: + if context.current_state == TaskStateChoices.COMPLETED: + raise TransitionValidationError("Cannot start an already completed task") + return True + + def transition(self, context: TransitionContext[Task, TaskState]) -> Dict[str, Any]: + return { + "assigned_user_id": self.assigned_user_id, + "estimated_duration": self.estimated_duration, + "started_at": context.timestamp.isoformat() + } + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + validate_assignment=True, + use_enum_values=True + ) + + def __init__(self, **data): + super().__init__(**data) + self.__context: Optional[TransitionContext[EntityType, StateModelType]] = None + + @property + def context(self) -> Optional[TransitionContext[EntityType, StateModelType]]: + """Access the current transition context""" + return getattr(self, '_BaseTransition__context', None) + + @context.setter + def context(self, value: TransitionContext[EntityType, StateModelType]): + """Set the transition context""" + self.__context = value + + @property + @abstractmethod + def target_state(self) -> str: + """ + The target state this transition leads to. + + Returns: + String representation of the target state + """ + pass + + @property + def transition_name(self) -> str: + """ + Name of this transition for audit purposes. + + Defaults to the class name in snake_case. + """ + class_name = self.__class__.__name__ + # Convert CamelCase to snake_case + result = "" + for i, char in enumerate(class_name): + if char.isupper() and i > 0: + result += "_" + result += char.lower() + return result + + @classmethod + def get_target_state(cls) -> Optional[str]: + """ + Get the target state for this transition class without creating an instance. + + Override this in subclasses where the target state is known at the class level. + + Returns: + The target state name, or None if it depends on instance data + """ + return None + + @classmethod + def can_transition_from_state(cls, context: TransitionContext[EntityType, StateModelType]) -> bool: + """ + Class-level validation for whether this transition type is allowed from the current state. + + This method checks if the transition is structurally valid (e.g., allowed state transitions) + without needing the actual transition data. Override this to implement state-based rules. + + Args: + context: The transition context containing entity, user, and state information + + Returns: + True if transition type is allowed from current state, False otherwise + """ + return True + + def validate_transition(self, context: TransitionContext[EntityType, StateModelType]) -> bool: + """ + Validate whether this specific transition instance can be performed. + + This method validates both the transition type (via can_transition_from_state) + and the specific transition data. Override to add data-specific validation. + + Args: + context: The transition context containing entity, user, and state information + + Returns: + True if transition is valid, False otherwise + + Raises: + TransitionValidationError: If transition validation fails with specific reason + """ + # First check if this transition type is allowed + if not self.can_transition_from_state(context): + return False + + # Then perform instance-specific validation + return True + + def pre_transition_hook(self, context: TransitionContext[EntityType, StateModelType]) -> None: + """ + Hook called before the transition is executed. + + Use this for any setup or preparation needed before state change. + Override in subclasses as needed. + + Args: + context: The transition context + """ + pass + + @abstractmethod + def transition(self, context: TransitionContext[EntityType, StateModelType]) -> Dict[str, Any]: + """ + Execute the transition and return context data for the state record. + + This is the core method that implements the transition logic. + Must be implemented by all concrete transition classes. + + Args: + context: The transition context containing all necessary information + + Returns: + Dictionary of context data to be stored with the state record + + Raises: + TransitionValidationError: If transition cannot be completed + """ + pass + + def post_transition_hook( + self, + context: TransitionContext[EntityType, StateModelType], + state_record: StateModelType + ) -> None: + """ + Hook called after the transition has been successfully executed. + + Use this for any cleanup, notifications, or side effects after state change. + Override in subclasses as needed. + + Args: + context: The transition context + state_record: The newly created state record + """ + pass + + def get_reason(self, context: TransitionContext[EntityType, StateModelType]) -> str: + """ + Get a human-readable reason for this transition. + + Override in subclasses to provide more specific reasons. + + Args: + context: The transition context + + Returns: + Human-readable reason string + """ + user_info = f"by {context.current_user}" if context.current_user else "automatically" + return f"{self.__class__.__name__} executed {user_info}" + + def execute(self, context: TransitionContext[EntityType, StateModelType]) -> StateModelType: + """ + Execute the complete transition workflow. + + This orchestrates the entire transition process: + 1. Set context on the transition instance + 2. Validate the transition + 3. Execute pre-transition hooks + 4. Perform the actual transition + 5. Create the state record + 6. Execute post-transition hooks + + Args: + context: The transition context + + Returns: + The newly created state record + + Raises: + TransitionValidationError: If validation fails + Exception: If transition execution fails + """ + # Set context for access during transition + self.context = context + + # Update context with transition name + context.transition_name = self.transition_name + + try: + # Validate transition + if not self.validate_transition(context): + raise TransitionValidationError( + f"Transition validation failed for {self.transition_name}", + {"current_state": context.current_state, "target_state": self.target_state} + ) + + # Pre-transition hook + self.pre_transition_hook(context) + + # Execute the transition logic + transition_data = self.transition(context) + + # Create the state record through StateManager + from .state_manager import StateManager + + success = StateManager.transition_state( + entity=context.entity, + new_state=self.target_state, + transition_name=self.transition_name, + user=context.current_user, + context=transition_data, + reason=self.get_reason(context) + ) + + if not success: + raise TransitionValidationError(f"Failed to create state record for {self.transition_name}") + + # Get the newly created state record + state_record = StateManager.get_current_state_object(context.entity) + + # Post-transition hook + self.post_transition_hook(context, state_record) + + return state_record + + except Exception as e: + # Clear context on error + self.context = None + raise + + +class TransitionRegistry: + """ + Registry for managing declarative transitions. + + Provides a centralized way to register, discover, and execute transitions + for different entity types and state models. + """ + + def __init__(self): + self._transitions: Dict[str, Dict[str, Type[BaseTransition]]] = {} + + def register( + self, + entity_name: str, + transition_name: str, + transition_class: Type[BaseTransition] + ): + """ + Register a transition class for an entity. + + Args: + entity_name: Name of the entity type (e.g., 'task', 'annotation') + transition_name: Name of the transition (e.g., 'start_task', 'submit_annotation') + transition_class: The transition class to register + """ + if entity_name not in self._transitions: + self._transitions[entity_name] = {} + + self._transitions[entity_name][transition_name] = transition_class + + def get_transition( + self, + entity_name: str, + transition_name: str + ) -> Optional[Type[BaseTransition]]: + """ + Get a registered transition class. + + Args: + entity_name: Name of the entity type + transition_name: Name of the transition + + Returns: + The transition class if found, None otherwise + """ + return self._transitions.get(entity_name, {}).get(transition_name) + + def get_transitions_for_entity(self, entity_name: str) -> Dict[str, Type[BaseTransition]]: + """ + Get all registered transitions for an entity type. + + Args: + entity_name: Name of the entity type + + Returns: + Dictionary mapping transition names to transition classes + """ + return self._transitions.get(entity_name, {}).copy() + + def list_entities(self) -> list[str]: + """Get a list of all registered entity types.""" + return list(self._transitions.keys()) + + def execute_transition( + self, + entity_name: str, + transition_name: str, + entity: Model, + transition_data: Dict[str, Any], + user: Optional[User] = None, + **context_kwargs + ) -> StateModelType: + """ + Execute a registered transition. + + Args: + entity_name: Name of the entity type + transition_name: Name of the transition + entity: The entity instance to transition + transition_data: Data for the transition (will be validated by Pydantic) + user: User executing the transition + **context_kwargs: Additional context data + + Returns: + The newly created state record + + Raises: + ValueError: If transition is not found + TransitionValidationError: If transition validation fails + """ + transition_class = self.get_transition(entity_name, transition_name) + if not transition_class: + raise ValueError(f"Transition '{transition_name}' not found for entity '{entity_name}'") + + # Create transition instance with provided data + transition = transition_class(**transition_data) + + # Get current state information + from .state_manager import StateManager + current_state_object = StateManager.get_current_state_object(entity) + current_state = current_state_object.state if current_state_object else None + + # Build transition context + context = TransitionContext( + entity=entity, + current_user=user, + current_state_object=current_state_object, + current_state=current_state, + target_state=transition.target_state, + organization_id=getattr(entity, 'organization_id', None), + **context_kwargs + ) + + # Execute the transition + return transition.execute(context) + + +# Global transition registry instance +transition_registry = TransitionRegistry() + + +def register_transition(entity_name: str, transition_name: str = None): + """ + Decorator to register a transition class. + + Args: + entity_name: Name of the entity type + transition_name: Name of the transition (defaults to class name in snake_case) + + Example: + @register_transition('task', 'start_task') + class StartTaskTransition(BaseTransition[Task, TaskState]): + # ... implementation + """ + def decorator(transition_class: Type[BaseTransition]) -> Type[BaseTransition]: + name = transition_name + if name is None: + # Generate name from class name + class_name = transition_class.__name__ + if class_name.endswith('Transition'): + class_name = class_name[:-10] # Remove 'Transition' suffix + + # Convert CamelCase to snake_case + name = "" + for i, char in enumerate(class_name): + if char.isupper() and i > 0: + name += "_" + name += char.lower() + + transition_registry.register(entity_name, name, transition_class) + return transition_class + + return decorator \ No newline at end of file From d07ea13249dbc53c5ecf76dabbf7c49a181e39f8 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 12:23:01 -0500 Subject: [PATCH 13/83] fix linting errors --- .../fsm/tests/test_api_usage_examples.py | 913 +++++++++--------- .../fsm/tests/test_declarative_transitions.py | 888 ++++++++--------- .../tests/test_edge_cases_error_handling.py | 632 ++++++------ .../tests/test_integration_django_models.py | 609 ++++++------ .../fsm/tests/test_performance_concurrency.py | 529 +++++----- label_studio/fsm/transition_utils.py | 174 ++-- label_studio/fsm/transitions.py | 261 +++-- 7 files changed, 1864 insertions(+), 2142 deletions(-) diff --git a/label_studio/fsm/tests/test_api_usage_examples.py b/label_studio/fsm/tests/test_api_usage_examples.py index d3307b058331..fcd8bbc26013 100644 --- a/label_studio/fsm/tests/test_api_usage_examples.py +++ b/label_studio/fsm/tests/test_api_usage_examples.py @@ -7,826 +7,787 @@ """ import json -import pytest from datetime import datetime, timedelta -from unittest.mock import Mock, patch -from typing import Dict, Any, List, Optional +from typing import Any, Dict, List, Optional +from unittest.mock import Mock from django.test import TestCase -from pydantic import Field, validator - +from fsm.transition_utils import ( + get_transition_schema, +) from fsm.transitions import ( - BaseTransition, - TransitionContext, + BaseTransition, + TransitionContext, TransitionValidationError, + register_transition, transition_registry, - register_transition -) -from fsm.transition_utils import ( - execute_transition, - get_available_transitions, - get_transition_schema, - validate_transition_data, - TransitionBuilder, - create_transition_from_dict ) +from pydantic import Field, validator class APIIntegrationExampleTests(TestCase): """ API integration examples demonstrating real-world usage patterns. - + These tests show how to integrate the transition system with REST APIs, handle JSON data, validate requests, and format responses. """ - + def setUp(self): self.mock_entity = Mock() self.mock_entity.pk = 1 self.mock_entity._meta.model_name = 'task' self.mock_entity.organization_id = 100 - + self.mock_user = Mock() self.mock_user.id = 42 - self.mock_user.username = "api_user" - + self.mock_user.username = 'api_user' + # Clear registry transition_registry._transitions.clear() - + def test_rest_api_task_assignment_example(self): """ API EXAMPLE: REST endpoint for task assignment - + Shows how to implement a REST API endpoint that uses declarative transitions with proper validation and error handling. """ - + @register_transition('task', 'api_assign_task') class APITaskAssignmentTransition(BaseTransition): """Task assignment via API with comprehensive validation""" - assignee_id: int = Field(..., description="ID of user to assign task to") - priority: str = Field("normal", description="Task priority level") - deadline: Optional[datetime] = Field(None, description="Assignment deadline") - assignment_notes: str = Field("", description="Notes about the assignment") - notify_assignee: bool = Field(True, description="Whether to notify the assignee") - + + assignee_id: int = Field(..., description='ID of user to assign task to') + priority: str = Field('normal', description='Task priority level') + deadline: Optional[datetime] = Field(None, description='Assignment deadline') + assignment_notes: str = Field('', description='Notes about the assignment') + notify_assignee: bool = Field(True, description='Whether to notify the assignee') + @validator('priority') def validate_priority(cls, v): valid_priorities = ['low', 'normal', 'high', 'urgent'] if v not in valid_priorities: raise ValueError(f'Priority must be one of: {valid_priorities}') return v - + @validator('deadline') def validate_deadline(cls, v): if v and v <= datetime.now(): raise ValueError('Deadline must be in the future') return v - + @property def target_state(self) -> str: - return "ASSIGNED" - + return 'ASSIGNED' + def validate_transition(self, context: TransitionContext) -> bool: # Business logic validation - if context.current_state not in ["CREATED", "UNASSIGNED"]: + if context.current_state not in ['CREATED', 'UNASSIGNED']: raise TransitionValidationError( - f"Cannot assign task in state: {context.current_state}", - {"valid_states": ["CREATED", "UNASSIGNED"]} + f'Cannot assign task in state: {context.current_state}', + {'valid_states': ['CREATED', 'UNASSIGNED']}, ) - + # Mock user existence check if self.assignee_id <= 0: - raise TransitionValidationError( - "Invalid assignee ID", - {"assignee_id": self.assignee_id} - ) - + raise TransitionValidationError('Invalid assignee ID', {'assignee_id': self.assignee_id}) + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "assignee_id": self.assignee_id, - "priority": self.priority, - "deadline": self.deadline.isoformat() if self.deadline else None, - "assignment_notes": self.assignment_notes, - "notify_assignee": self.notify_assignee, - "assigned_by_id": context.current_user.id if context.current_user else None, - "assigned_at": context.timestamp.isoformat(), - "api_version": "v1" + 'assignee_id': self.assignee_id, + 'priority': self.priority, + 'deadline': self.deadline.isoformat() if self.deadline else None, + 'assignment_notes': self.assignment_notes, + 'notify_assignee': self.notify_assignee, + 'assigned_by_id': context.current_user.id if context.current_user else None, + 'assigned_at': context.timestamp.isoformat(), + 'api_version': 'v1', } - + # Simulate API request data (JSON from client) api_request_data = { - "assignee_id": 123, - "priority": "high", - "deadline": (datetime.now() + timedelta(days=7)).isoformat(), - "assignment_notes": "Critical task requiring immediate attention", - "notify_assignee": True + 'assignee_id': 123, + 'priority': 'high', + 'deadline': (datetime.now() + timedelta(days=7)).isoformat(), + 'assignment_notes': 'Critical task requiring immediate attention', + 'notify_assignee': True, } - + # API endpoint simulation: Parse and validate JSON try: # Step 1: Create transition from API data transition = APITaskAssignmentTransition(**api_request_data) - + # Step 2: Execute transition context = TransitionContext( entity=self.mock_entity, current_user=self.mock_user, - current_state="CREATED", + current_state='CREATED', target_state=transition.target_state, - request_data=api_request_data + request_data=api_request_data, ) - + # Validate self.assertTrue(transition.validate_transition(context)) - + # Execute result_data = transition.transition(context) - + # Step 3: Format API response api_response = { - "success": True, - "message": "Task assigned successfully", - "data": { - "task_id": self.mock_entity.pk, - "new_state": transition.target_state, - "assignment_details": result_data + 'success': True, + 'message': 'Task assigned successfully', + 'data': { + 'task_id': self.mock_entity.pk, + 'new_state': transition.target_state, + 'assignment_details': result_data, }, - "timestamp": datetime.now().isoformat() + 'timestamp': datetime.now().isoformat(), } - + # Validate API response - self.assertTrue(api_response["success"]) - self.assertEqual(api_response["data"]["new_state"], "ASSIGNED") - self.assertEqual(api_response["data"]["assignment_details"]["assignee_id"], 123) - self.assertEqual(api_response["data"]["assignment_details"]["priority"], "high") - + self.assertTrue(api_response['success']) + self.assertEqual(api_response['data']['new_state'], 'ASSIGNED') + self.assertEqual(api_response['data']['assignment_details']['assignee_id'], 123) + self.assertEqual(api_response['data']['assignment_details']['priority'], 'high') + except ValueError as e: # Handle Pydantic validation errors api_response = { - "success": False, - "error": "Validation Error", - "message": str(e), - "timestamp": datetime.now().isoformat() + 'success': False, + 'error': 'Validation Error', + 'message': str(e), + 'timestamp': datetime.now().isoformat(), } - + except TransitionValidationError as e: # Handle business logic validation errors api_response = { - "success": False, - "error": "Business Rule Violation", - "message": str(e), - "context": e.context, - "timestamp": datetime.now().isoformat() + 'success': False, + 'error': 'Business Rule Violation', + 'message': str(e), + 'context': e.context, + 'timestamp': datetime.now().isoformat(), } - + # Test error handling with invalid data invalid_request = { - "assignee_id": -1, # Invalid ID - "priority": "invalid_priority", # Invalid priority - "deadline": "2020-01-01T00:00:00" # Past deadline + 'assignee_id': -1, # Invalid ID + 'priority': 'invalid_priority', # Invalid priority + 'deadline': '2020-01-01T00:00:00', # Past deadline } - + with self.assertRaises(ValueError): APITaskAssignmentTransition(**invalid_request) - + def test_json_schema_generation_for_api_docs(self): """ API DOCUMENTATION: JSON Schema generation - + Shows how to generate OpenAPI/JSON schemas for API documentation from Pydantic transition models. """ - + @register_transition('annotation', 'api_submit_annotation') class APIAnnotationSubmissionTransition(BaseTransition): """Submit annotation via API with rich metadata""" + confidence_score: float = Field( - ..., - ge=0.0, - le=1.0, - description="Annotator's confidence in the annotation (0.0-1.0)" + ..., ge=0.0, le=1.0, description="Annotator's confidence in the annotation (0.0-1.0)" ) annotation_quality: str = Field( - "good", - description="Subjective quality assessment", - pattern="^(excellent|good|fair|poor)$" - ) - time_spent_seconds: int = Field( - ..., - ge=1, - description="Time spent on annotation in seconds" - ) - difficulty_level: str = Field( - "medium", - description="Perceived difficulty of the annotation task" - ) - review_requested: bool = Field( - False, - description="Whether the annotator requests manual review" - ) - tags: List[str] = Field( - default_factory=list, - description="Optional tags for categorization" + 'good', description='Subjective quality assessment', pattern='^(excellent|good|fair|poor)$' ) + time_spent_seconds: int = Field(..., ge=1, description='Time spent on annotation in seconds') + difficulty_level: str = Field('medium', description='Perceived difficulty of the annotation task') + review_requested: bool = Field(False, description='Whether the annotator requests manual review') + tags: List[str] = Field(default_factory=list, description='Optional tags for categorization') metadata: Dict[str, Any] = Field( - default_factory=dict, - description="Additional metadata about the annotation process" + default_factory=dict, description='Additional metadata about the annotation process' ) - + @property def target_state(self) -> str: - return "SUBMITTED" - + return 'SUBMITTED' + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "confidence_score": self.confidence_score, - "annotation_quality": self.annotation_quality, - "time_spent_seconds": self.time_spent_seconds, - "difficulty_level": self.difficulty_level, - "review_requested": self.review_requested, - "tags": self.tags, - "metadata": self.metadata, - "submitted_at": context.timestamp.isoformat() + 'confidence_score': self.confidence_score, + 'annotation_quality': self.annotation_quality, + 'time_spent_seconds': self.time_spent_seconds, + 'difficulty_level': self.difficulty_level, + 'review_requested': self.review_requested, + 'tags': self.tags, + 'metadata': self.metadata, + 'submitted_at': context.timestamp.isoformat(), } - + # Generate JSON schema schema = get_transition_schema(APIAnnotationSubmissionTransition) - + # Validate schema structure - self.assertIn("properties", schema) - self.assertIn("required", schema) - + self.assertIn('properties', schema) + self.assertIn('required', schema) + # Check specific field schemas - properties = schema["properties"] - + properties = schema['properties'] + # confidence_score should have min/max constraints - confidence_schema = properties["confidence_score"] - self.assertEqual(confidence_schema["type"], "number") - self.assertEqual(confidence_schema["minimum"], 0.0) - self.assertEqual(confidence_schema["maximum"], 1.0) - self.assertIn("Annotator's confidence", confidence_schema["description"]) - + confidence_schema = properties['confidence_score'] + self.assertEqual(confidence_schema['type'], 'number') + self.assertEqual(confidence_schema['minimum'], 0.0) + self.assertEqual(confidence_schema['maximum'], 1.0) + self.assertIn("Annotator's confidence", confidence_schema['description']) + # annotation_quality should have pattern constraint - quality_schema = properties["annotation_quality"] - self.assertEqual(quality_schema["type"], "string") - self.assertIn("pattern", quality_schema) - + quality_schema = properties['annotation_quality'] + self.assertEqual(quality_schema['type'], 'string') + self.assertIn('pattern', quality_schema) + # time_spent_seconds should have minimum constraint - time_schema = properties["time_spent_seconds"] - self.assertEqual(time_schema["type"], "integer") - self.assertEqual(time_schema["minimum"], 1) - + time_schema = properties['time_spent_seconds'] + self.assertEqual(time_schema['type'], 'integer') + self.assertEqual(time_schema['minimum'], 1) + # tags should be array type - tags_schema = properties["tags"] - self.assertEqual(tags_schema["type"], "array") - self.assertEqual(tags_schema["items"]["type"], "string") - + tags_schema = properties['tags'] + self.assertEqual(tags_schema['type'], 'array') + self.assertEqual(tags_schema['items']['type'], 'string') + # metadata should be object type - metadata_schema = properties["metadata"] - self.assertEqual(metadata_schema["type"], "object") - + metadata_schema = properties['metadata'] + self.assertEqual(metadata_schema['type'], 'object') + # Required fields - required_fields = schema["required"] - self.assertIn("confidence_score", required_fields) - self.assertIn("time_spent_seconds", required_fields) - self.assertNotIn("tags", required_fields) # Optional field - + required_fields = schema['required'] + self.assertIn('confidence_score', required_fields) + self.assertIn('time_spent_seconds', required_fields) + self.assertNotIn('tags', required_fields) # Optional field + # Test schema-driven validation valid_data = { - "confidence_score": 0.85, - "annotation_quality": "good", - "time_spent_seconds": 120, - "difficulty_level": "hard", - "review_requested": True, - "tags": ["important", "complex"], - "metadata": {"tool_version": "1.2.3", "browser": "chrome"} + 'confidence_score': 0.85, + 'annotation_quality': 'good', + 'time_spent_seconds': 120, + 'difficulty_level': 'hard', + 'review_requested': True, + 'tags': ['important', 'complex'], + 'metadata': {'tool_version': '1.2.3', 'browser': 'chrome'}, } - + transition = APIAnnotationSubmissionTransition(**valid_data) self.assertEqual(transition.confidence_score, 0.85) self.assertEqual(len(transition.tags), 2) - + # Print schema for documentation (would be used in API docs) schema_json = json.dumps(schema, indent=2) self.assertIsInstance(schema_json, str) - self.assertIn("confidence_score", schema_json) - + self.assertIn('confidence_score', schema_json) + def test_bulk_operations_api_pattern(self): """ API EXAMPLE: Bulk operations with transitions - + Shows how to handle bulk operations where multiple entities need to be transitioned with the same or different parameters. """ - + @register_transition('task', 'bulk_status_update') class BulkStatusUpdateTransition(BaseTransition): """Bulk status update for multiple tasks""" - new_status: str = Field(..., description="New status for all tasks") - update_reason: str = Field(..., description="Reason for bulk update") - batch_id: str = Field(..., description="Unique identifier for this batch") - force_update: bool = Field(False, description="Force update even if invalid states") - + + new_status: str = Field(..., description='New status for all tasks') + update_reason: str = Field(..., description='Reason for bulk update') + batch_id: str = Field(..., description='Unique identifier for this batch') + force_update: bool = Field(False, description='Force update even if invalid states') + @property def target_state(self) -> str: return self.new_status - + def validate_transition(self, context: TransitionContext) -> bool: - valid_statuses = ["CREATED", "IN_PROGRESS", "COMPLETED", "CANCELLED"] + valid_statuses = ['CREATED', 'IN_PROGRESS', 'COMPLETED', 'CANCELLED'] if self.new_status not in valid_statuses: - raise TransitionValidationError(f"Invalid status: {self.new_status}") - + raise TransitionValidationError(f'Invalid status: {self.new_status}') + # Skip state validation if force update if not self.force_update: if context.current_state == self.new_status: - raise TransitionValidationError("Cannot update to same status") - + raise TransitionValidationError('Cannot update to same status') + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "new_status": self.new_status, - "update_reason": self.update_reason, - "batch_id": self.batch_id, - "force_update": self.force_update, - "updated_at": context.timestamp.isoformat(), - "entity_id": context.entity.pk + 'new_status': self.new_status, + 'update_reason': self.update_reason, + 'batch_id': self.batch_id, + 'force_update': self.force_update, + 'updated_at': context.timestamp.isoformat(), + 'entity_id': context.entity.pk, } - + # Simulate bulk API request bulk_request = { - "task_ids": [1, 2, 3, 4, 5], - "transition_data": { - "new_status": "IN_PROGRESS", - "update_reason": "Project phase change", - "batch_id": "batch_2024_001", - "force_update": False - } + 'task_ids': [1, 2, 3, 4, 5], + 'transition_data': { + 'new_status': 'IN_PROGRESS', + 'update_reason': 'Project phase change', + 'batch_id': 'batch_2024_001', + 'force_update': False, + }, } - + # Process bulk request batch_results = [] failed_updates = [] - - for task_id in bulk_request["task_ids"]: + + for task_id in bulk_request['task_ids']: # Create mock entity for each task mock_task = Mock() mock_task.pk = task_id mock_task._meta.model_name = 'task' - + try: # Create transition - transition = BulkStatusUpdateTransition(**bulk_request["transition_data"]) - + transition = BulkStatusUpdateTransition(**bulk_request['transition_data']) + # Mock different current states for testing - current_states = ["CREATED", "CREATED", "IN_PROGRESS", "CREATED", "COMPLETED"] + current_states = ['CREATED', 'CREATED', 'IN_PROGRESS', 'CREATED', 'COMPLETED'] current_state = current_states[task_id - 1] # Adjust for 0-based indexing - + context = TransitionContext( entity=mock_task, current_user=self.mock_user, current_state=current_state, - target_state=transition.target_state + target_state=transition.target_state, ) - + # Validate and execute if transition.validate_transition(context): result = transition.transition(context) - batch_results.append({ - "task_id": task_id, - "success": True, - "result": result - }) - + batch_results.append({'task_id': task_id, 'success': True, 'result': result}) + except TransitionValidationError as e: - failed_updates.append({ - "task_id": task_id, - "success": False, - "error": str(e), - "context": getattr(e, 'context', {}) - }) - + failed_updates.append( + {'task_id': task_id, 'success': False, 'error': str(e), 'context': getattr(e, 'context', {})} + ) + # API response for bulk operation api_response = { - "batch_id": bulk_request["transition_data"]["batch_id"], - "total_requested": len(bulk_request["task_ids"]), - "successful_updates": len(batch_results), - "failed_updates": len(failed_updates), - "results": batch_results, - "failures": failed_updates, - "timestamp": datetime.now().isoformat() + 'batch_id': bulk_request['transition_data']['batch_id'], + 'total_requested': len(bulk_request['task_ids']), + 'successful_updates': len(batch_results), + 'failed_updates': len(failed_updates), + 'results': batch_results, + 'failures': failed_updates, + 'timestamp': datetime.now().isoformat(), } - + # Validate bulk results - self.assertEqual(api_response["total_requested"], 5) - self.assertGreater(api_response["successful_updates"], 0) - + self.assertEqual(api_response['total_requested'], 5) + self.assertGreater(api_response['successful_updates'], 0) + # Some tasks should succeed, some might fail due to state validation - total_processed = api_response["successful_updates"] + api_response["failed_updates"] + total_processed = api_response['successful_updates'] + api_response['failed_updates'] self.assertEqual(total_processed, 5) - + # Check individual results for result in batch_results: - self.assertTrue(result["success"]) - self.assertEqual(result["result"]["new_status"], "IN_PROGRESS") - self.assertEqual(result["result"]["batch_id"], "batch_2024_001") - + self.assertTrue(result['success']) + self.assertEqual(result['result']['new_status'], 'IN_PROGRESS') + self.assertEqual(result['result']['batch_id'], 'batch_2024_001') + def test_webhook_integration_pattern(self): """ API EXAMPLE: Webhook integration with transitions - + Shows how to integrate transitions with webhook systems for external notifications and integrations. """ - + @register_transition('task', 'webhook_completion') class WebhookTaskCompletionTransition(BaseTransition): """Task completion with webhook notifications""" + completion_quality: float = Field(..., ge=0.0, le=1.0) - completion_notes: str = Field("", description="Completion notes") - webhook_urls: List[str] = Field(default_factory=list, description="Webhook URLs to notify") - notification_data: Dict[str, Any] = Field(default_factory=dict, description="Data to send in webhooks") - webhook_responses: List[Dict[str, Any]] = Field(default_factory=list, description="Webhook response tracking") - + completion_notes: str = Field('', description='Completion notes') + webhook_urls: List[str] = Field(default_factory=list, description='Webhook URLs to notify') + notification_data: Dict[str, Any] = Field(default_factory=dict, description='Data to send in webhooks') + webhook_responses: List[Dict[str, Any]] = Field( + default_factory=list, description='Webhook response tracking' + ) + @property def target_state(self) -> str: - return "COMPLETED" - + return 'COMPLETED' + def validate_transition(self, context: TransitionContext) -> bool: - if context.current_state != "IN_PROGRESS": - raise TransitionValidationError("Can only complete in-progress tasks") + if context.current_state != 'IN_PROGRESS': + raise TransitionValidationError('Can only complete in-progress tasks') return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "completion_quality": self.completion_quality, - "completion_notes": self.completion_notes, - "webhook_urls": self.webhook_urls, - "notification_data": self.notification_data, - "completed_at": context.timestamp.isoformat(), - "completed_by_id": context.current_user.id if context.current_user else None + 'completion_quality': self.completion_quality, + 'completion_notes': self.completion_notes, + 'webhook_urls': self.webhook_urls, + 'notification_data': self.notification_data, + 'completed_at': context.timestamp.isoformat(), + 'completed_by_id': context.current_user.id if context.current_user else None, } - + def post_transition_hook(self, context: TransitionContext, state_record) -> None: """Send webhook notifications after successful transition""" if self.webhook_urls: webhook_payload = { - "event": "task.completed", - "task_id": context.entity.pk, - "state_record_id": getattr(state_record, 'id', 'mock-id'), - "completion_data": { - "quality": self.completion_quality, - "notes": self.completion_notes, - "completed_by": context.current_user.id if context.current_user else None, - "completed_at": context.timestamp.isoformat() + 'event': 'task.completed', + 'task_id': context.entity.pk, + 'state_record_id': getattr(state_record, 'id', 'mock-id'), + 'completion_data': { + 'quality': self.completion_quality, + 'notes': self.completion_notes, + 'completed_by': context.current_user.id if context.current_user else None, + 'completed_at': context.timestamp.isoformat(), }, - "custom_data": self.notification_data, - "timestamp": datetime.now().isoformat() + 'custom_data': self.notification_data, + 'timestamp': datetime.now().isoformat(), } - + # Mock webhook sending (in real implementation, use async requests) for url in self.webhook_urls: webhook_response = { - "url": url, - "payload": webhook_payload, - "status": "sent", - "timestamp": datetime.now().isoformat() + 'url': url, + 'payload': webhook_payload, + 'status': 'sent', + 'timestamp': datetime.now().isoformat(), } self.webhook_responses.append(webhook_response) - + # Test webhook transition transition = WebhookTaskCompletionTransition( completion_quality=0.95, - completion_notes="Task completed with excellent quality", + completion_notes='Task completed with excellent quality', webhook_urls=[ - "https://api.example.com/webhooks/task-completed", - "https://notifications.example.com/task-events" + 'https://api.example.com/webhooks/task-completed', + 'https://notifications.example.com/task-events', ], - notification_data={ - "project_id": 123, - "priority": "high", - "client_id": "client_456" - } + notification_data={'project_id': 123, 'priority': 'high', 'client_id': 'client_456'}, ) - + context = TransitionContext( entity=self.mock_entity, current_user=self.mock_user, - current_state="IN_PROGRESS", - target_state=transition.target_state + current_state='IN_PROGRESS', + target_state=transition.target_state, ) - + # Validate and execute self.assertTrue(transition.validate_transition(context)) - result = transition.transition(context) - + transition.transition(context) + # Simulate state record creation mock_state_record = Mock() - mock_state_record.id = "state-uuid-123" - + mock_state_record.id = 'state-uuid-123' + # Execute post-hook (webhook sending) transition.post_transition_hook(context, mock_state_record) - + # Validate webhook responses self.assertEqual(len(transition.webhook_responses), 2) - + for response in transition.webhook_responses: - self.assertIn("url", response) - self.assertIn("payload", response) - self.assertEqual(response["status"], "sent") - + self.assertIn('url', response) + self.assertIn('payload', response) + self.assertEqual(response['status'], 'sent') + # Validate webhook payload structure - payload = response["payload"] - self.assertEqual(payload["event"], "task.completed") - self.assertEqual(payload["task_id"], self.mock_entity.pk) - self.assertEqual(payload["completion_data"]["quality"], 0.95) - self.assertEqual(payload["custom_data"]["project_id"], 123) - + payload = response['payload'] + self.assertEqual(payload['event'], 'task.completed') + self.assertEqual(payload['task_id'], self.mock_entity.pk) + self.assertEqual(payload['completion_data']['quality'], 0.95) + self.assertEqual(payload['custom_data']['project_id'], 123) + def test_api_error_handling_patterns(self): """ API EXAMPLE: Comprehensive error handling patterns - + Shows how to implement robust error handling for API endpoints using the transition system with proper HTTP status codes and messages. """ - + @register_transition('task', 'api_critical_update') class APICriticalUpdateTransition(BaseTransition): """Critical update with extensive validation""" - update_type: str = Field(..., description="Type of critical update") - severity_level: int = Field(..., ge=1, le=5, description="Severity level 1-5") - authorization_token: str = Field(..., description="Authorization token for critical updates") - backup_required: bool = Field(True, description="Whether backup is required before update") - + + update_type: str = Field(..., description='Type of critical update') + severity_level: int = Field(..., ge=1, le=5, description='Severity level 1-5') + authorization_token: str = Field(..., description='Authorization token for critical updates') + backup_required: bool = Field(True, description='Whether backup is required before update') + @property def target_state(self) -> str: - return "CRITICALLY_UPDATED" - + return 'CRITICALLY_UPDATED' + def validate_transition(self, context: TransitionContext) -> bool: errors = [] - + # Authorization check if len(self.authorization_token) < 10: - errors.append("Invalid authorization token") - + errors.append('Invalid authorization token') + # Severity validation if self.severity_level >= 4 and not context.current_user: - errors.append("High severity updates require authenticated user") - + errors.append('High severity updates require authenticated user') + # Update type validation - valid_types = ["security_patch", "critical_fix", "emergency_update"] + valid_types = ['security_patch', 'critical_fix', 'emergency_update'] if self.update_type not in valid_types: - errors.append(f"Invalid update type. Must be one of: {valid_types}") - + errors.append(f'Invalid update type. Must be one of: {valid_types}') + # State validation - if context.current_state in ["COMPLETED", "ARCHIVED"]: - errors.append(f"Cannot perform critical updates on {context.current_state.lower()} tasks") - + if context.current_state in ['COMPLETED', 'ARCHIVED']: + errors.append(f'Cannot perform critical updates on {context.current_state.lower()} tasks') + # Backup requirement if self.backup_required and self.severity_level >= 3: # Mock backup check backup_exists = True # In real implementation, check backup system if not backup_exists: - errors.append("Backup required but not available") - + errors.append('Backup required but not available') + if errors: raise TransitionValidationError( - "Critical update validation failed", + 'Critical update validation failed', { - "validation_errors": errors, - "error_count": len(errors), - "severity_level": self.severity_level, - "update_type": self.update_type - } + 'validation_errors': errors, + 'error_count': len(errors), + 'severity_level': self.severity_level, + 'update_type': self.update_type, + }, ) - + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "update_type": self.update_type, - "severity_level": self.severity_level, - "backup_required": self.backup_required, - "authorized_by": context.current_user.id if context.current_user else None, - "updated_at": context.timestamp.isoformat(), - "critical_update_id": f"crit_{int(context.timestamp.timestamp())}" + 'update_type': self.update_type, + 'severity_level': self.severity_level, + 'backup_required': self.backup_required, + 'authorized_by': context.current_user.id if context.current_user else None, + 'updated_at': context.timestamp.isoformat(), + 'critical_update_id': f'crit_{int(context.timestamp.timestamp())}', } - + # Test various error scenarios and API responses - + # 1. Test successful request valid_request = { - "update_type": "security_patch", - "severity_level": 3, - "authorization_token": "valid_token_12345", - "backup_required": True + 'update_type': 'security_patch', + 'severity_level': 3, + 'authorization_token': 'valid_token_12345', + 'backup_required': True, } - - def simulate_api_endpoint(request_data, current_state="IN_PROGRESS"): + + def simulate_api_endpoint(request_data, current_state='IN_PROGRESS'): """Simulate API endpoint with proper error handling""" try: # Parse and validate request transition = APICriticalUpdateTransition(**request_data) - + # Create context context = TransitionContext( entity=self.mock_entity, current_user=self.mock_user, current_state=current_state, - target_state=transition.target_state + target_state=transition.target_state, ) - + # Validate business logic transition.validate_transition(context) - + # Execute transition result = transition.transition(context) - + return { - "status_code": 200, - "success": True, - "data": { - "task_id": self.mock_entity.pk, - "new_state": transition.target_state, - "update_details": result - } + 'status_code': 200, + 'success': True, + 'data': { + 'task_id': self.mock_entity.pk, + 'new_state': transition.target_state, + 'update_details': result, + }, } - + except ValueError as e: # Pydantic validation error (400 Bad Request) return { - "status_code": 400, - "success": False, - "error": "Bad Request", - "message": "Invalid request data", - "details": str(e) + 'status_code': 400, + 'success': False, + 'error': 'Bad Request', + 'message': 'Invalid request data', + 'details': str(e), } - + except TransitionValidationError as e: # Business logic validation error (422 Unprocessable Entity) return { - "status_code": 422, - "success": False, - "error": "Validation Failed", - "message": str(e), - "validation_errors": e.context.get("validation_errors", []), - "context": e.context + 'status_code': 422, + 'success': False, + 'error': 'Validation Failed', + 'message': str(e), + 'validation_errors': e.context.get('validation_errors', []), + 'context': e.context, } - + except Exception as e: # Unexpected error (500 Internal Server Error) return { - "status_code": 500, - "success": False, - "error": "Internal Server Error", - "message": "An unexpected error occurred", - "details": str(e) if not isinstance(e, Exception) else "Server error" + 'status_code': 500, + 'success': False, + 'error': 'Internal Server Error', + 'message': 'An unexpected error occurred', + 'details': str(e) if not isinstance(e, Exception) else 'Server error', } - + # Test successful request response = simulate_api_endpoint(valid_request) - self.assertEqual(response["status_code"], 200) - self.assertTrue(response["success"]) - self.assertIn("update_details", response["data"]) - + self.assertEqual(response['status_code'], 200) + self.assertTrue(response['success']) + self.assertIn('update_details', response['data']) + # Test Pydantic validation error (invalid severity level) invalid_request = { - "update_type": "security_patch", - "severity_level": 10, # Invalid: > 5 - "authorization_token": "valid_token_12345" + 'update_type': 'security_patch', + 'severity_level': 10, # Invalid: > 5 + 'authorization_token': 'valid_token_12345', } - + response = simulate_api_endpoint(invalid_request) - self.assertEqual(response["status_code"], 400) - self.assertFalse(response["success"]) - self.assertEqual(response["error"], "Bad Request") - + self.assertEqual(response['status_code'], 400) + self.assertFalse(response['success']) + self.assertEqual(response['error'], 'Bad Request') + # Test business logic validation error business_logic_error_request = { - "update_type": "invalid_type", # Invalid update type - "severity_level": 5, - "authorization_token": "short", # Too short - "backup_required": True + 'update_type': 'invalid_type', # Invalid update type + 'severity_level': 5, + 'authorization_token': 'short', # Too short + 'backup_required': True, } - + response = simulate_api_endpoint(business_logic_error_request) - self.assertEqual(response["status_code"], 422) - self.assertFalse(response["success"]) - self.assertEqual(response["error"], "Validation Failed") - self.assertIn("validation_errors", response) - self.assertGreater(len(response["validation_errors"]), 0) - + self.assertEqual(response['status_code'], 422) + self.assertFalse(response['success']) + self.assertEqual(response['error'], 'Validation Failed') + self.assertIn('validation_errors', response) + self.assertGreater(len(response['validation_errors']), 0) + # Test state validation error - response = simulate_api_endpoint(valid_request, current_state="COMPLETED") - self.assertEqual(response["status_code"], 422) + response = simulate_api_endpoint(valid_request, current_state='COMPLETED') + self.assertEqual(response['status_code'], 422) # The error message is in validation_errors list, not the main message - validation_errors = response.get("validation_errors", []) - self.assertTrue(any("completed tasks" in error for error in validation_errors)) - + validation_errors = response.get('validation_errors', []) + self.assertTrue(any('completed tasks' in error for error in validation_errors)) + def test_api_versioning_and_backward_compatibility(self): """ API EXAMPLE: API versioning with backward compatibility - + Shows how to handle API versioning using transition inheritance and maintain backward compatibility. """ - + # Version 1 API @register_transition('task', 'update_task_v1') class UpdateTaskV1Transition(BaseTransition): """Version 1 task update API""" - status: str = Field(..., description="New task status") - notes: str = Field("", description="Update notes") - + + status: str = Field(..., description='New task status') + notes: str = Field('', description='Update notes') + @property def target_state(self) -> str: return self.status - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "status": self.status, - "notes": self.notes, - "api_version": "v1", - "updated_at": context.timestamp.isoformat() + 'status': self.status, + 'notes': self.notes, + 'api_version': 'v1', + 'updated_at': context.timestamp.isoformat(), } - + # Version 2 API with additional features @register_transition('task', 'update_task_v2') class UpdateTaskV2Transition(UpdateTaskV1Transition): """Version 2 task update API with enhanced features""" - priority: Optional[str] = Field(None, description="Task priority") - tags: List[str] = Field(default_factory=list, description="Task tags") - estimated_hours: Optional[float] = Field(None, ge=0, description="Estimated hours") - metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata") - + + priority: Optional[str] = Field(None, description='Task priority') + tags: List[str] = Field(default_factory=list, description='Task tags') + estimated_hours: Optional[float] = Field(None, ge=0, description='Estimated hours') + metadata: Dict[str, Any] = Field(default_factory=dict, description='Additional metadata') + def transition(self, context: TransitionContext) -> Dict[str, Any]: # Call parent method for base functionality base_data = super().transition(context) - + # Add V2 specific data v2_data = { - "priority": self.priority, - "tags": self.tags, - "estimated_hours": self.estimated_hours, - "metadata": self.metadata, - "api_version": "v2" + 'priority': self.priority, + 'tags': self.tags, + 'estimated_hours': self.estimated_hours, + 'metadata': self.metadata, + 'api_version': 'v2', } - + return {**base_data, **v2_data} - + # Test V1 API (backward compatibility) - v1_request = { - "status": "IN_PROGRESS", - "notes": "Started working on task" - } - + v1_request = {'status': 'IN_PROGRESS', 'notes': 'Started working on task'} + v1_transition = UpdateTaskV1Transition(**v1_request) context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state=v1_transition.target_state + entity=self.mock_entity, current_state='CREATED', target_state=v1_transition.target_state ) - + v1_result = v1_transition.transition(context) - self.assertEqual(v1_result["api_version"], "v1") - self.assertEqual(v1_result["status"], "IN_PROGRESS") - self.assertNotIn("priority", v1_result) # V1 doesn't have priority - + self.assertEqual(v1_result['api_version'], 'v1') + self.assertEqual(v1_result['status'], 'IN_PROGRESS') + self.assertNotIn('priority', v1_result) # V1 doesn't have priority + # Test V2 API with enhanced features v2_request = { - "status": "IN_PROGRESS", - "notes": "Started working on task with enhanced tracking", - "priority": "high", - "tags": ["urgent", "client-facing"], - "estimated_hours": 4.5, - "metadata": {"client_id": 123, "project_phase": "development"} + 'status': 'IN_PROGRESS', + 'notes': 'Started working on task with enhanced tracking', + 'priority': 'high', + 'tags': ['urgent', 'client-facing'], + 'estimated_hours': 4.5, + 'metadata': {'client_id': 123, 'project_phase': 'development'}, } - + v2_transition = UpdateTaskV2Transition(**v2_request) v2_result = v2_transition.transition(context) - - self.assertEqual(v2_result["api_version"], "v2") - self.assertEqual(v2_result["status"], "IN_PROGRESS") # Inherited from V1 - self.assertEqual(v2_result["priority"], "high") # V2 feature - self.assertEqual(len(v2_result["tags"]), 2) # V2 feature - self.assertEqual(v2_result["estimated_hours"], 4.5) # V2 feature - self.assertIn("client_id", v2_result["metadata"]) # V2 feature - + + self.assertEqual(v2_result['api_version'], 'v2') + self.assertEqual(v2_result['status'], 'IN_PROGRESS') # Inherited from V1 + self.assertEqual(v2_result['priority'], 'high') # V2 feature + self.assertEqual(len(v2_result['tags']), 2) # V2 feature + self.assertEqual(v2_result['estimated_hours'], 4.5) # V2 feature + self.assertIn('client_id', v2_result['metadata']) # V2 feature + # Test V2 API with minimal data (backward compatible) - v2_minimal_request = { - "status": "COMPLETED", - "notes": "Task finished" - } - + v2_minimal_request = {'status': 'COMPLETED', 'notes': 'Task finished'} + v2_minimal_transition = UpdateTaskV2Transition(**v2_minimal_request) v2_minimal_result = v2_minimal_transition.transition(context) - - self.assertEqual(v2_minimal_result["api_version"], "v2") - self.assertEqual(v2_minimal_result["status"], "COMPLETED") - self.assertIsNone(v2_minimal_result["priority"]) # Optional field - self.assertEqual(v2_minimal_result["tags"], []) # Default value - self.assertIsNone(v2_minimal_result["estimated_hours"]) # Optional field - self.assertEqual(v2_minimal_result["metadata"], {}) # Default value \ No newline at end of file + + self.assertEqual(v2_minimal_result['api_version'], 'v2') + self.assertEqual(v2_minimal_result['status'], 'COMPLETED') + self.assertIsNone(v2_minimal_result['priority']) # Optional field + self.assertEqual(v2_minimal_result['tags'], []) # Default value + self.assertIsNone(v2_minimal_result['estimated_hours']) # Optional field + self.assertEqual(v2_minimal_result['metadata'], {}) # Default value diff --git a/label_studio/fsm/tests/test_declarative_transitions.py b/label_studio/fsm/tests/test_declarative_transitions.py index 33b83287b972..6555a68d1a37 100644 --- a/label_studio/fsm/tests/test_declarative_transitions.py +++ b/label_studio/fsm/tests/test_declarative_transitions.py @@ -6,38 +6,34 @@ patterns to serve as both tests and documentation. """ -import pytest from datetime import datetime, timedelta -from unittest.mock import Mock, patch, MagicMock -from typing import Dict, Any -import json +from typing import Any, Dict +from unittest.mock import Mock, patch +import pytest from django.contrib.auth import get_user_model -from django.test import TestCase, TransactionTestCase -from pydantic import Field, ValidationError - -from fsm.state_choices import TaskStateChoices, AnnotationStateChoices -from fsm.transitions import ( - BaseTransition, - TransitionContext, - TransitionValidationError, - transition_registry, - register_transition -) +from django.test import TestCase +from fsm.state_choices import AnnotationStateChoices, TaskStateChoices from fsm.transition_utils import ( - execute_transition, + TransitionBuilder, get_available_transitions, get_valid_transitions, - TransitionBuilder, - validate_transition_data, - get_transition_schema ) +from fsm.transitions import ( + BaseTransition, + TransitionContext, + TransitionValidationError, + register_transition, + transition_registry, +) +from pydantic import Field, ValidationError User = get_user_model() class MockTask: """Mock task model for testing""" + def __init__(self, pk=1): self.pk = pk self.id = pk @@ -49,6 +45,7 @@ def __init__(self, pk=1): class MockAnnotation: """Mock annotation model for testing""" + def __init__(self, pk=1): self.pk = pk self.id = pk @@ -61,122 +58,106 @@ def __init__(self, pk=1): class TestTransition(BaseTransition): """Test transition class""" - + test_field: str optional_field: int = 42 - + @property def target_state(self) -> str: - return "TEST_STATE" - + return 'TEST_STATE' + @classmethod def get_target_state(cls) -> str: """Return the target state at class level""" - return "TEST_STATE" - + return 'TEST_STATE' + @classmethod def can_transition_from_state(cls, context: TransitionContext) -> bool: """Allow transition from any state for testing""" return True - + def validate_transition(self, context: TransitionContext) -> bool: - if self.test_field == "invalid": - raise TransitionValidationError("Test validation error") + if self.test_field == 'invalid': + raise TransitionValidationError('Test validation error') return super().validate_transition(context) - + def transition(self, context: TransitionContext) -> dict: return { - "test_field": self.test_field, - "optional_field": self.optional_field, - "context_entity_id": context.entity.pk + 'test_field': self.test_field, + 'optional_field': self.optional_field, + 'context_entity_id': context.entity.pk, } class DeclarativeTransitionTests(TestCase): """Test cases for the declarative transition system""" - + def setUp(self): self.task = MockTask() self.annotation = MockAnnotation() self.user = Mock() self.user.id = 1 - self.user.username = "testuser" - + self.user.username = 'testuser' + # Register test transition transition_registry.register('task', 'test_transition', TestTransition) - + def test_transition_context_creation(self): """Test creation of transition context""" context = TransitionContext( entity=self.task, current_user=self.user, - current_state="CREATED", - target_state="IN_PROGRESS", - organization_id=1 + current_state='CREATED', + target_state='IN_PROGRESS', + organization_id=1, ) - + self.assertEqual(context.entity, self.task) self.assertEqual(context.current_user, self.user) - self.assertEqual(context.current_state, "CREATED") - self.assertEqual(context.target_state, "IN_PROGRESS") + self.assertEqual(context.current_state, 'CREATED') + self.assertEqual(context.target_state, 'IN_PROGRESS') self.assertEqual(context.organization_id, 1) self.assertFalse(context.is_initial_transition) self.assertTrue(context.has_current_state) - + def test_transition_context_initial_state(self): """Test context for initial state transition""" - context = TransitionContext( - entity=self.task, - current_state=None, - target_state="CREATED" - ) - + context = TransitionContext(entity=self.task, current_state=None, target_state='CREATED') + self.assertTrue(context.is_initial_transition) self.assertFalse(context.has_current_state) - + def test_transition_validation_success(self): """Test successful transition validation""" - transition = TestTransition(test_field="valid") - context = TransitionContext( - entity=self.task, - current_state="CREATED", - target_state=transition.target_state - ) - + transition = TestTransition(test_field='valid') + context = TransitionContext(entity=self.task, current_state='CREATED', target_state=transition.target_state) + self.assertTrue(transition.validate_transition(context)) - + def test_transition_validation_failure(self): """Test transition validation failure""" - transition = TestTransition(test_field="invalid") - context = TransitionContext( - entity=self.task, - current_state="CREATED", - target_state=transition.target_state - ) - + transition = TestTransition(test_field='invalid') + context = TransitionContext(entity=self.task, current_state='CREATED', target_state=transition.target_state) + with self.assertRaises(TransitionValidationError): transition.validate_transition(context) - + def test_transition_execution(self): """Test transition data generation""" - transition = TestTransition(test_field="test_value", optional_field=100) - context = TransitionContext( - entity=self.task, - current_state="CREATED", - target_state=transition.target_state - ) - + transition = TestTransition(test_field='test_value', optional_field=100) + context = TransitionContext(entity=self.task, current_state='CREATED', target_state=transition.target_state) + result = transition.transition(context) - - self.assertEqual(result["test_field"], "test_value") - self.assertEqual(result["optional_field"], 100) - self.assertEqual(result["context_entity_id"], self.task.pk) - + + self.assertEqual(result['test_field'], 'test_value') + self.assertEqual(result['optional_field'], 100) + self.assertEqual(result['context_entity_id'], self.task.pk) + def test_transition_name_generation(self): """Test automatic transition name generation""" - transition = TestTransition(test_field="test") - self.assertEqual(transition.transition_name, "test_transition") - + transition = TestTransition(test_field='test') + self.assertEqual(transition.transition_name, 'test_transition') + @patch('fsm.state_manager.StateManager.transition_state') @patch('fsm.state_manager.StateManager.get_current_state_object') def test_transition_execute_full_workflow(self, mock_get_state, mock_transition): @@ -184,204 +165,187 @@ def test_transition_execute_full_workflow(self, mock_get_state, mock_transition) # Setup mocks mock_get_state.return_value = None # No current state mock_transition.return_value = True - + mock_state_record = Mock() - mock_state_record.id = "test-uuid" - + mock_state_record.id = 'test-uuid' + with patch('fsm.state_manager.StateManager.get_current_state_object', return_value=mock_state_record): - transition = TestTransition(test_field="test_value") + transition = TestTransition(test_field='test_value') context = TransitionContext( - entity=self.task, - current_user=self.user, - current_state=None, - target_state=transition.target_state + entity=self.task, current_user=self.user, current_state=None, target_state=transition.target_state ) - + # Execute transition - result = transition.execute(context) - + transition.execute(context) + # Verify StateManager was called correctly mock_transition.assert_called_once() call_args = mock_transition.call_args - + self.assertEqual(call_args[1]['entity'], self.task) - self.assertEqual(call_args[1]['new_state'], "TEST_STATE") - self.assertEqual(call_args[1]['transition_name'], "test_transition") + self.assertEqual(call_args[1]['new_state'], 'TEST_STATE') + self.assertEqual(call_args[1]['transition_name'], 'test_transition') self.assertEqual(call_args[1]['user'], self.user) - + # Check context data context_data = call_args[1]['context'] - self.assertEqual(context_data['test_field'], "test_value") + self.assertEqual(context_data['test_field'], 'test_value') self.assertEqual(context_data['optional_field'], 42) class TransitionRegistryTests(TestCase): """Test cases for the transition registry""" - + def setUp(self): self.registry = transition_registry - + def test_transition_registration(self): """Test registering transitions""" self.registry.register('test_entity', 'test_transition', TestTransition) - + retrieved = self.registry.get_transition('test_entity', 'test_transition') self.assertEqual(retrieved, TestTransition) - + def test_get_transitions_for_entity(self): """Test getting all transitions for an entity""" self.registry.register('test_entity', 'transition1', TestTransition) self.registry.register('test_entity', 'transition2', TestTransition) - + transitions = self.registry.get_transitions_for_entity('test_entity') - + self.assertIn('transition1', transitions) self.assertIn('transition2', transitions) self.assertEqual(len(transitions), 2) - + def test_list_entities(self): """Test listing registered entities""" self.registry.register('entity1', 'transition1', TestTransition) self.registry.register('entity2', 'transition2', TestTransition) - + entities = self.registry.list_entities() - + self.assertIn('entity1', entities) self.assertIn('entity2', entities) class TransitionUtilsTests(TestCase): """Test cases for transition utility functions""" - + def setUp(self): self.task = MockTask() transition_registry.register('task', 'test_transition', TestTransition) - + def test_get_available_transitions(self): """Test getting available transitions for entity""" transitions = get_available_transitions(self.task) self.assertIn('test_transition', transitions) - + @patch('fsm.state_manager.StateManager.get_current_state_object') def test_get_valid_transitions(self, mock_get_state): """Test filtering valid transitions""" mock_get_state.return_value = None - + valid_transitions = get_valid_transitions(self.task, validate=True) self.assertIn('test_transition', valid_transitions) - + @patch('fsm.state_manager.StateManager.get_current_state_object') def test_get_valid_transitions_with_invalid(self, mock_get_state): """Test filtering out invalid transitions""" mock_get_state.return_value = None - + # Register an invalid transition class InvalidTransition(TestTransition): @classmethod def can_transition_from_state(cls, context): # This transition is never valid at the class level return False - + def validate_transition(self, context): - raise TransitionValidationError("Always invalid") - + raise TransitionValidationError('Always invalid') + transition_registry.register('task', 'invalid_transition', InvalidTransition) - + valid_transitions = get_valid_transitions(self.task, validate=True) self.assertIn('test_transition', valid_transitions) self.assertNotIn('invalid_transition', valid_transitions) - + @patch('fsm.transition_utils.execute_transition') def test_transition_builder(self, mock_execute): """Test fluent transition builder interface""" mock_execute.return_value = Mock() - - result = (TransitionBuilder(self.task) - .transition('test_transition') - .with_data(test_field="builder_test") - .by_user(Mock()) - .with_context(extra="context") - .execute()) - + + ( + TransitionBuilder(self.task) + .transition('test_transition') + .with_data(test_field='builder_test') + .by_user(Mock()) + .with_context(extra='context') + .execute() + ) + mock_execute.assert_called_once() call_args = mock_execute.call_args - + self.assertEqual(call_args[1]['transition_name'], 'test_transition') self.assertEqual(call_args[1]['transition_data']['test_field'], 'builder_test') class ExampleTransitionIntegrationTests(TestCase): """Integration tests using the example transitions""" - + def setUp(self): # Import example transitions to register them - from fsm.example_transitions import ( - StartTaskTransition, - CompleteTaskTransition, - SubmitAnnotationTransition - ) - + self.task = MockTask() self.annotation = MockAnnotation() self.user = Mock() self.user.id = 1 - self.user.username = "testuser" - + self.user.username = 'testuser' + def test_start_task_transition_validation(self): """Test StartTaskTransition validation""" from fsm.example_transitions import StartTaskTransition - + transition = StartTaskTransition(assigned_user_id=123) - + # Test valid transition from CREATED context = TransitionContext( - entity=self.task, - current_state=TaskStateChoices.CREATED, - target_state=transition.target_state + entity=self.task, current_state=TaskStateChoices.CREATED, target_state=transition.target_state ) - + self.assertTrue(transition.validate_transition(context)) - + # Test invalid transition from COMPLETED context.current_state = TaskStateChoices.COMPLETED - + with self.assertRaises(TransitionValidationError): transition.validate_transition(context) - + def test_submit_annotation_validation(self): """Test SubmitAnnotationTransition validation""" from fsm.example_transitions import SubmitAnnotationTransition - + transition = SubmitAnnotationTransition() - + # Test valid transition context = TransitionContext( - entity=self.annotation, - current_state=AnnotationStateChoices.DRAFT, - target_state=transition.target_state + entity=self.annotation, current_state=AnnotationStateChoices.DRAFT, target_state=transition.target_state ) - + self.assertTrue(transition.validate_transition(context)) - + def test_transition_data_generation(self): """Test that transitions generate appropriate context data""" from fsm.example_transitions import StartTaskTransition - - transition = StartTaskTransition( - assigned_user_id=123, - estimated_duration=5, - priority="high" - ) - + + transition = StartTaskTransition(assigned_user_id=123, estimated_duration=5, priority='high') + context = TransitionContext( - entity=self.task, - current_user=self.user, - target_state=transition.target_state, - timestamp=datetime.now() + entity=self.task, current_user=self.user, target_state=transition.target_state, timestamp=datetime.now() ) - + result = transition.transition(context) - + self.assertEqual(result['assigned_user_id'], 123) self.assertEqual(result['estimated_duration'], 5) self.assertEqual(result['priority'], 'high') @@ -392,588 +356,533 @@ def test_transition_data_generation(self): class ComprehensiveUsageExampleTests(TestCase): """ Comprehensive test cases demonstrating various usage patterns. - + These tests serve as both validation and documentation for how to implement and use the declarative transition system. """ - + def setUp(self): self.task = MockTask() self.user = Mock() self.user.id = 123 - self.user.username = "testuser" - + self.user.username = 'testuser' + # Clear registry to avoid conflicts transition_registry._transitions.clear() - + def test_basic_transition_implementation(self): """ USAGE EXAMPLE: Basic transition implementation - + Shows how to implement a simple transition with validation. """ - + class BasicTransition(BaseTransition): """Example: Simple transition with required field""" - message: str = Field(..., description="Message for the transition") - + + message: str = Field(..., description='Message for the transition') + @property def target_state(self) -> str: - return "PROCESSED" - + return 'PROCESSED' + def validate_transition(self, context: TransitionContext) -> bool: # Business logic validation - if context.current_state == "COMPLETED": - raise TransitionValidationError("Cannot process completed items") + if context.current_state == 'COMPLETED': + raise TransitionValidationError('Cannot process completed items') return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "message": self.message, - "processed_by": context.current_user.username if context.current_user else "system", - "processed_at": context.timestamp.isoformat() + 'message': self.message, + 'processed_by': context.current_user.username if context.current_user else 'system', + 'processed_at': context.timestamp.isoformat(), } - + # Test the implementation - transition = BasicTransition(message="Processing task") - self.assertEqual(transition.message, "Processing task") - self.assertEqual(transition.target_state, "PROCESSED") - + transition = BasicTransition(message='Processing task') + self.assertEqual(transition.message, 'Processing task') + self.assertEqual(transition.target_state, 'PROCESSED') + # Test validation context = TransitionContext( - entity=self.task, - current_user=self.user, - current_state="CREATED", - target_state=transition.target_state + entity=self.task, current_user=self.user, current_state='CREATED', target_state=transition.target_state ) - + self.assertTrue(transition.validate_transition(context)) - + # Test data generation data = transition.transition(context) - self.assertEqual(data["message"], "Processing task") - self.assertEqual(data["processed_by"], "testuser") - self.assertIn("processed_at", data) - + self.assertEqual(data['message'], 'Processing task') + self.assertEqual(data['processed_by'], 'testuser') + self.assertIn('processed_at', data) + def test_complex_validation_example(self): """ USAGE EXAMPLE: Complex validation with multiple conditions - + Shows how to implement sophisticated business logic validation. """ - + class TaskAssignmentTransition(BaseTransition): """Example: Complex validation for task assignment""" - assignee_id: int = Field(..., description="User to assign task to") - priority: str = Field("normal", description="Task priority") - deadline: datetime = Field(None, description="Task deadline") - + + assignee_id: int = Field(..., description='User to assign task to') + priority: str = Field('normal', description='Task priority') + deadline: datetime = Field(None, description='Task deadline') + @property def target_state(self) -> str: - return "ASSIGNED" - + return 'ASSIGNED' + def validate_transition(self, context: TransitionContext) -> bool: # Multiple validation conditions - if context.current_state not in ["CREATED", "UNASSIGNED"]: + if context.current_state not in ['CREATED', 'UNASSIGNED']: raise TransitionValidationError( - f"Cannot assign task in state {context.current_state}", - {"current_state": context.current_state, "task_id": context.entity.pk} + f'Cannot assign task in state {context.current_state}', + {'current_state': context.current_state, 'task_id': context.entity.pk}, ) - + # Check deadline is in future if self.deadline and self.deadline <= datetime.now(): raise TransitionValidationError( - "Deadline must be in the future", - {"deadline": self.deadline.isoformat()} + 'Deadline must be in the future', {'deadline': self.deadline.isoformat()} ) - + # Check priority is valid - valid_priorities = ["low", "normal", "high", "urgent"] + valid_priorities = ['low', 'normal', 'high', 'urgent'] if self.priority not in valid_priorities: raise TransitionValidationError( - f"Invalid priority: {self.priority}", - {"valid_priorities": valid_priorities} + f'Invalid priority: {self.priority}', {'valid_priorities': valid_priorities} ) - + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "assignee_id": self.assignee_id, - "priority": self.priority, - "deadline": self.deadline.isoformat() if self.deadline else None, - "assigned_by": context.current_user.id if context.current_user else None, - "assignment_reason": f"Task assigned to user {self.assignee_id}" + 'assignee_id': self.assignee_id, + 'priority': self.priority, + 'deadline': self.deadline.isoformat() if self.deadline else None, + 'assigned_by': context.current_user.id if context.current_user else None, + 'assignment_reason': f'Task assigned to user {self.assignee_id}', } - + # Test valid assignment future_deadline = datetime.now() + timedelta(days=7) - transition = TaskAssignmentTransition( - assignee_id=456, - priority="high", - deadline=future_deadline - ) - + transition = TaskAssignmentTransition(assignee_id=456, priority='high', deadline=future_deadline) + context = TransitionContext( - entity=self.task, - current_user=self.user, - current_state="CREATED", - target_state=transition.target_state + entity=self.task, current_user=self.user, current_state='CREATED', target_state=transition.target_state ) - + self.assertTrue(transition.validate_transition(context)) - + # Test invalid state - context.current_state = "COMPLETED" + context.current_state = 'COMPLETED' with self.assertRaises(TransitionValidationError) as cm: transition.validate_transition(context) - - self.assertIn("Cannot assign task in state", str(cm.exception)) - self.assertIn("COMPLETED", str(cm.exception)) - + + self.assertIn('Cannot assign task in state', str(cm.exception)) + self.assertIn('COMPLETED', str(cm.exception)) + # Test invalid deadline past_deadline = datetime.now() - timedelta(days=1) - invalid_transition = TaskAssignmentTransition( - assignee_id=456, - deadline=past_deadline - ) - - context.current_state = "CREATED" + invalid_transition = TaskAssignmentTransition(assignee_id=456, deadline=past_deadline) + + context.current_state = 'CREATED' with self.assertRaises(TransitionValidationError) as cm: invalid_transition.validate_transition(context) - - self.assertIn("Deadline must be in the future", str(cm.exception)) - + + self.assertIn('Deadline must be in the future', str(cm.exception)) + def test_hooks_and_lifecycle_example(self): """ USAGE EXAMPLE: Using pre/post hooks for side effects - + Shows how to implement lifecycle hooks for notifications, cleanup, or other side effects. """ - + class NotificationTransition(BaseTransition): """Example: Transition with notification hooks""" - notification_message: str = Field(..., description="Notification message") - notify_users: list = Field(default_factory=list, description="Users to notify") - notifications_sent: list = Field(default_factory=list, description="Track sent notifications") - cleanup_performed: bool = Field(default=False, description="Track cleanup status") - + + notification_message: str = Field(..., description='Notification message') + notify_users: list = Field(default_factory=list, description='Users to notify') + notifications_sent: list = Field(default_factory=list, description='Track sent notifications') + cleanup_performed: bool = Field(default=False, description='Track cleanup status') + @property def target_state(self) -> str: - return "NOTIFIED" - + return 'NOTIFIED' + @classmethod def get_target_state(cls) -> str: - return "NOTIFIED" - + return 'NOTIFIED' + @classmethod def can_transition_from_state(cls, context: TransitionContext) -> bool: return True - + def pre_transition_hook(self, context: TransitionContext) -> None: """Prepare notifications before state change""" # Validate notification recipients if not self.notify_users: self.notify_users = [context.current_user.id] if context.current_user else [] - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "notification_message": self.notification_message, - "notify_users": self.notify_users, - "notification_sent_at": context.timestamp.isoformat() + 'notification_message': self.notification_message, + 'notify_users': self.notify_users, + 'notification_sent_at': context.timestamp.isoformat(), } - + def post_transition_hook(self, context: TransitionContext, state_record) -> None: """Send notifications after successful state change""" # Mock notification sending for user_id in self.notify_users: - self.notifications_sent.append({ - "user_id": user_id, - "message": self.notification_message, - "sent_at": context.timestamp - }) - + self.notifications_sent.append( + {'user_id': user_id, 'message': self.notification_message, 'sent_at': context.timestamp} + ) + # Mock cleanup self.cleanup_performed = True - + # Test the hooks - transition = NotificationTransition( - notification_message="Task has been updated", - notify_users=[123, 456] - ) - + transition = NotificationTransition(notification_message='Task has been updated', notify_users=[123, 456]) + context = TransitionContext( - entity=self.task, - current_user=self.user, - current_state="CREATED", - target_state=transition.target_state + entity=self.task, current_user=self.user, current_state='CREATED', target_state=transition.target_state ) - + # Test pre-hook transition.pre_transition_hook(context) self.assertEqual(transition.notify_users, [123, 456]) - + # Test transition data = transition.transition(context) - self.assertEqual(data["notification_message"], "Task has been updated") - + self.assertEqual(data['notification_message'], 'Task has been updated') + # Test post-hook mock_state_record = Mock() transition.post_transition_hook(context, mock_state_record) - + self.assertEqual(len(transition.notifications_sent), 2) self.assertTrue(transition.cleanup_performed) - + def test_conditional_transition_example(self): """ USAGE EXAMPLE: Conditional transitions based on data - + Shows how to implement transitions that behave differently based on input data or context. """ - + class ConditionalApprovalTransition(BaseTransition): """Example: Conditional approval based on confidence""" - confidence_score: float = Field(..., ge=0.0, le=1.0, description="Confidence score") - auto_approve_threshold: float = Field(0.9, description="Auto-approval threshold") - reviewer_id: int = Field(None, description="Manual reviewer ID") - + + confidence_score: float = Field(..., ge=0.0, le=1.0, description='Confidence score') + auto_approve_threshold: float = Field(0.9, description='Auto-approval threshold') + reviewer_id: int = Field(None, description='Manual reviewer ID') + @property def target_state(self) -> str: # Dynamic target state based on confidence if self.confidence_score >= self.auto_approve_threshold: - return "AUTO_APPROVED" + return 'AUTO_APPROVED' else: - return "PENDING_REVIEW" - + return 'PENDING_REVIEW' + def validate_transition(self, context: TransitionContext) -> bool: # Different validation based on approval type if self.confidence_score >= self.auto_approve_threshold: # Auto-approval validation - if context.current_state != "SUBMITTED": - raise TransitionValidationError("Can only auto-approve submitted items") + if context.current_state != 'SUBMITTED': + raise TransitionValidationError('Can only auto-approve submitted items') else: # Manual review validation if not self.reviewer_id: - raise TransitionValidationError("Manual review requires reviewer_id") - + raise TransitionValidationError('Manual review requires reviewer_id') + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: base_data = { - "confidence_score": self.confidence_score, - "threshold": self.auto_approve_threshold, + 'confidence_score': self.confidence_score, + 'threshold': self.auto_approve_threshold, } - + if self.confidence_score >= self.auto_approve_threshold: # Auto-approval data return { **base_data, - "approval_type": "automatic", - "approved_at": context.timestamp.isoformat(), - "approved_by": "system" + 'approval_type': 'automatic', + 'approved_at': context.timestamp.isoformat(), + 'approved_by': 'system', } else: # Manual review data return { **base_data, - "approval_type": "manual", - "assigned_reviewer": self.reviewer_id, - "review_requested_at": context.timestamp.isoformat() + 'approval_type': 'manual', + 'assigned_reviewer': self.reviewer_id, + 'review_requested_at': context.timestamp.isoformat(), } - + # Test auto-approval path - high_confidence_transition = ConditionalApprovalTransition( - confidence_score=0.95 - ) - - self.assertEqual(high_confidence_transition.target_state, "AUTO_APPROVED") - + high_confidence_transition = ConditionalApprovalTransition(confidence_score=0.95) + + self.assertEqual(high_confidence_transition.target_state, 'AUTO_APPROVED') + context = TransitionContext( - entity=self.task, - current_state="SUBMITTED", - target_state=high_confidence_transition.target_state + entity=self.task, current_state='SUBMITTED', target_state=high_confidence_transition.target_state ) - + self.assertTrue(high_confidence_transition.validate_transition(context)) - + auto_data = high_confidence_transition.transition(context) - self.assertEqual(auto_data["approval_type"], "automatic") - self.assertEqual(auto_data["approved_by"], "system") - + self.assertEqual(auto_data['approval_type'], 'automatic') + self.assertEqual(auto_data['approved_by'], 'system') + # Test manual review path - low_confidence_transition = ConditionalApprovalTransition( - confidence_score=0.7, - reviewer_id=789 - ) - - self.assertEqual(low_confidence_transition.target_state, "PENDING_REVIEW") - + low_confidence_transition = ConditionalApprovalTransition(confidence_score=0.7, reviewer_id=789) + + self.assertEqual(low_confidence_transition.target_state, 'PENDING_REVIEW') + context.target_state = low_confidence_transition.target_state self.assertTrue(low_confidence_transition.validate_transition(context)) - + manual_data = low_confidence_transition.transition(context) - self.assertEqual(manual_data["approval_type"], "manual") - self.assertEqual(manual_data["assigned_reviewer"], 789) - + self.assertEqual(manual_data['approval_type'], 'manual') + self.assertEqual(manual_data['assigned_reviewer'], 789) + def test_registry_and_decorator_usage(self): """ USAGE EXAMPLE: Using the registry and decorator system - + Shows how to register transitions and use the decorator syntax. """ - + @register_transition('document', 'publish') class PublishDocumentTransition(BaseTransition): """Example: Using the registration decorator""" - publish_immediately: bool = Field(True, description="Publish immediately") - scheduled_time: datetime = Field(None, description="Scheduled publish time") - + + publish_immediately: bool = Field(True, description='Publish immediately') + scheduled_time: datetime = Field(None, description='Scheduled publish time') + @property def target_state(self) -> str: - return "PUBLISHED" if self.publish_immediately else "SCHEDULED" - + return 'PUBLISHED' if self.publish_immediately else 'SCHEDULED' + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "publish_immediately": self.publish_immediately, - "scheduled_time": self.scheduled_time.isoformat() if self.scheduled_time else None, - "published_by": context.current_user.id if context.current_user else None + 'publish_immediately': self.publish_immediately, + 'scheduled_time': self.scheduled_time.isoformat() if self.scheduled_time else None, + 'published_by': context.current_user.id if context.current_user else None, } - + # Test registration worked registered_class = transition_registry.get_transition('document', 'publish') self.assertEqual(registered_class, PublishDocumentTransition) - + # Test getting transitions for entity document_transitions = transition_registry.get_transitions_for_entity('document') self.assertIn('publish', document_transitions) - + # Test execution through registry mock_document = Mock() mock_document.pk = 1 mock_document._meta.model_name = 'document' - + # This would normally go through the full execution workflow - transition_data = {"publish_immediately": False, "scheduled_time": datetime.now() + timedelta(hours=2)} - + transition_data = {'publish_immediately': False, 'scheduled_time': datetime.now() + timedelta(hours=2)} + # Test transition creation and validation transition = PublishDocumentTransition(**transition_data) - self.assertEqual(transition.target_state, "SCHEDULED") + self.assertEqual(transition.target_state, 'SCHEDULED') class ValidationAndErrorHandlingTests(TestCase): """ Tests focused on validation scenarios and error handling. - + These tests demonstrate proper error handling patterns and validation edge cases. """ - + def setUp(self): self.task = MockTask() transition_registry._transitions.clear() - + def test_pydantic_validation_errors(self): """Test Pydantic field validation errors""" - + class StrictValidationTransition(BaseTransition): - required_field: str = Field(..., description="Required field") - email_field: str = Field(..., pattern=r'^[\w\.-]+@[\w\.-]+\.\w+$', description="Valid email") - number_field: int = Field(..., ge=1, le=100, description="Number between 1-100") - + required_field: str = Field(..., description='Required field') + email_field: str = Field(..., pattern=r'^[\w\.-]+@[\w\.-]+\.\w+$', description='Valid email') + number_field: int = Field(..., ge=1, le=100, description='Number between 1-100') + @property def target_state(self) -> str: - return "VALIDATED" - + return 'VALIDATED' + @classmethod def get_target_state(cls) -> str: - return "VALIDATED" - + return 'VALIDATED' + @classmethod def can_transition_from_state(cls, context: TransitionContext) -> bool: return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {"validated": True} - + return {'validated': True} + # Test missing required field with self.assertRaises(ValidationError): - StrictValidationTransition(email_field="test@example.com", number_field=50) - + StrictValidationTransition(email_field='test@example.com', number_field=50) + # Test invalid email with self.assertRaises(ValidationError): - StrictValidationTransition( - required_field="test", - email_field="invalid-email", - number_field=50 - ) - + StrictValidationTransition(required_field='test', email_field='invalid-email', number_field=50) + # Test number out of range with self.assertRaises(ValidationError): - StrictValidationTransition( - required_field="test", - email_field="test@example.com", - number_field=150 - ) - + StrictValidationTransition(required_field='test', email_field='test@example.com', number_field=150) + # Test valid data valid_transition = StrictValidationTransition( - required_field="test", - email_field="user@example.com", - number_field=75 + required_field='test', email_field='user@example.com', number_field=75 ) - self.assertEqual(valid_transition.required_field, "test") - + self.assertEqual(valid_transition.required_field, 'test') + def test_business_logic_validation_errors(self): """Test business logic validation with detailed error context""" - + class BusinessRuleTransition(BaseTransition): - amount: float = Field(..., description="Transaction amount") - currency: str = Field("USD", description="Currency code") - + amount: float = Field(..., description='Transaction amount') + currency: str = Field('USD', description='Currency code') + @property def target_state(self) -> str: - return "PROCESSED" - + return 'PROCESSED' + def validate_transition(self, context: TransitionContext) -> bool: # Complex business rule validation errors = [] - + if self.amount <= 0: - errors.append("Amount must be positive") - + errors.append('Amount must be positive') + if self.amount > 10000 and context.current_user is None: - errors.append("Large amounts require authenticated user") - - if self.currency not in ["USD", "EUR", "GBP"]: - errors.append(f"Unsupported currency: {self.currency}") - - if context.current_state == "CANCELLED": - errors.append("Cannot process cancelled transactions") - + errors.append('Large amounts require authenticated user') + + if self.currency not in ['USD', 'EUR', 'GBP']: + errors.append(f'Unsupported currency: {self.currency}') + + if context.current_state == 'CANCELLED': + errors.append('Cannot process cancelled transactions') + if errors: raise TransitionValidationError( f"Validation failed: {'; '.join(errors)}", { - "validation_errors": errors, - "amount": self.amount, - "currency": self.currency, - "current_state": context.current_state - } + 'validation_errors': errors, + 'amount': self.amount, + 'currency': self.currency, + 'current_state': context.current_state, + }, ) - + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: - return { - "amount": self.amount, - "currency": self.currency - } - - context = TransitionContext( - entity=self.task, - current_state="PENDING", - target_state="PROCESSED" - ) - + return {'amount': self.amount, 'currency': self.currency} + + context = TransitionContext(entity=self.task, current_state='PENDING', target_state='PROCESSED') + # Test negative amount negative_transition = BusinessRuleTransition(amount=-100) with self.assertRaises(TransitionValidationError) as cm: negative_transition.validate_transition(context) - + error = cm.exception - self.assertIn("Amount must be positive", str(error)) - self.assertIn("validation_errors", error.context) - + self.assertIn('Amount must be positive', str(error)) + self.assertIn('validation_errors', error.context) + # Test large amount without user large_transition = BusinessRuleTransition(amount=15000) with self.assertRaises(TransitionValidationError) as cm: large_transition.validate_transition(context) - - self.assertIn("Large amounts require authenticated user", str(cm.exception)) - + + self.assertIn('Large amounts require authenticated user', str(cm.exception)) + # Test invalid currency - invalid_currency_transition = BusinessRuleTransition( - amount=100, - currency="XYZ" - ) + invalid_currency_transition = BusinessRuleTransition(amount=100, currency='XYZ') with self.assertRaises(TransitionValidationError) as cm: invalid_currency_transition.validate_transition(context) - - self.assertIn("Unsupported currency", str(cm.exception)) - + + self.assertIn('Unsupported currency', str(cm.exception)) + # Test multiple errors - multi_error_transition = BusinessRuleTransition( - amount=-50, - currency="XYZ" - ) + multi_error_transition = BusinessRuleTransition(amount=-50, currency='XYZ') with self.assertRaises(TransitionValidationError) as cm: multi_error_transition.validate_transition(context) - + error_msg = str(cm.exception) - self.assertIn("Amount must be positive", error_msg) - self.assertIn("Unsupported currency", error_msg) - + self.assertIn('Amount must be positive', error_msg) + self.assertIn('Unsupported currency', error_msg) + def test_context_validation_errors(self): """Test validation errors related to context state""" - + class ContextAwareTransition(BaseTransition): - action: str = Field(..., description="Action to perform") - + action: str = Field(..., description='Action to perform') + @property def target_state(self) -> str: - return "ACTIONED" - + return 'ACTIONED' + def validate_transition(self, context: TransitionContext) -> bool: # State-dependent validation - if context.is_initial_transition and self.action != "create": + if context.is_initial_transition and self.action != 'create': raise TransitionValidationError( - "Initial transition must be 'create' action", - {"action": self.action, "is_initial": True} + "Initial transition must be 'create' action", {'action': self.action, 'is_initial': True} ) - - if context.current_state == "COMPLETED" and self.action in ["modify", "update"]: + + if context.current_state == 'COMPLETED' and self.action in ['modify', 'update']: raise TransitionValidationError( - f"Cannot {self.action} completed items", - {"action": self.action, "current_state": context.current_state} + f'Cannot {self.action} completed items', + {'action': self.action, 'current_state': context.current_state}, ) - + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {"action": self.action} - + return {'action': self.action} + # Test initial transition validation - create_transition = ContextAwareTransition(action="create") + create_transition = ContextAwareTransition(action='create') initial_context = TransitionContext( - entity=self.task, - current_state=None, # No current state = initial - target_state="ACTIONED" + entity=self.task, current_state=None, target_state='ACTIONED' # No current state = initial ) - + self.assertTrue(create_transition.validate_transition(initial_context)) - + # Test invalid initial action - modify_transition = ContextAwareTransition(action="modify") + modify_transition = ContextAwareTransition(action='modify') with self.assertRaises(TransitionValidationError) as cm: modify_transition.validate_transition(initial_context) - + error = cm.exception self.assertIn("Initial transition must be 'create'", str(error)) - self.assertTrue(error.context["is_initial"]) - + self.assertTrue(error.context['is_initial']) + # Test completed state validation - completed_context = TransitionContext( - entity=self.task, - current_state="COMPLETED", - target_state="ACTIONED" - ) - + completed_context = TransitionContext(entity=self.task, current_state='COMPLETED', target_state='ACTIONED') + with self.assertRaises(TransitionValidationError) as cm: modify_transition.validate_transition(completed_context) - - self.assertIn("Cannot modify completed items", str(cm.exception)) + + self.assertIn('Cannot modify completed items', str(cm.exception)) @pytest.fixture @@ -987,32 +896,27 @@ def user(): """Pytest fixture for mock user""" user = Mock() user.id = 1 - user.username = "testuser" + user.username = 'testuser' return user def test_transition_context_properties(task, user): """Test TransitionContext properties using pytest""" - context = TransitionContext( - entity=task, - current_user=user, - current_state="CREATED", - target_state="IN_PROGRESS" - ) - + context = TransitionContext(entity=task, current_user=user, current_state='CREATED', target_state='IN_PROGRESS') + assert context.has_current_state assert not context.is_initial_transition - assert context.current_state == "CREATED" - assert context.target_state == "IN_PROGRESS" + assert context.current_state == 'CREATED' + assert context.target_state == 'IN_PROGRESS' def test_pydantic_validation(): """Test Pydantic validation in transitions""" # Valid data - transition = TestTransition(test_field="valid") - assert transition.test_field == "valid" + transition = TestTransition(test_field='valid') + assert transition.test_field == 'valid' assert transition.optional_field == 42 - + # Invalid data should raise validation error with pytest.raises(Exception): # Pydantic validation error - TestTransition() # Missing required field \ No newline at end of file + TestTransition() # Missing required field diff --git a/label_studio/fsm/tests/test_edge_cases_error_handling.py b/label_studio/fsm/tests/test_edge_cases_error_handling.py index 7b2b87cb2675..17b900336511 100644 --- a/label_studio/fsm/tests/test_edge_cases_error_handling.py +++ b/label_studio/fsm/tests/test_edge_cases_error_handling.py @@ -6,235 +6,213 @@ is robust in production environments. """ -import pytest -from datetime import datetime, timedelta -from unittest.mock import Mock, patch, MagicMock -from typing import Dict, Any, List, Optional +import gc import threading import weakref -import gc +from datetime import datetime +from typing import Any, Dict +from unittest.mock import Mock from django.test import TestCase -from pydantic import Field, ValidationError, validator - -from fsm.transitions import ( - BaseTransition, - TransitionContext, - TransitionValidationError, - transition_registry, - register_transition -) -from fsm.transition_utils import ( - execute_transition, - get_available_transitions, - TransitionBuilder, - validate_transition_data, - create_transition_from_dict -) +from fsm.transition_utils import TransitionBuilder +from fsm.transitions import BaseTransition, TransitionContext, TransitionValidationError, transition_registry +from pydantic import Field, ValidationError class EdgeCaseTransition(BaseTransition): """Test transition for edge case scenarios""" - edge_case_data: Any = Field(None, description="Data for edge case testing") - + + edge_case_data: Any = Field(None, description='Data for edge case testing') + @property def target_state(self) -> str: - return "EDGE_CASE_PROCESSED" - + return 'EDGE_CASE_PROCESSED' + def validate_transition(self, context: TransitionContext) -> bool: # Deliberately minimal validation for edge case testing return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: - return { - "edge_case_data": self.edge_case_data, - "processed_at": context.timestamp.isoformat() - } + return {'edge_case_data': self.edge_case_data, 'processed_at': context.timestamp.isoformat()} class ErrorProneTransition(BaseTransition): """Transition designed to test error scenarios""" - should_fail: str = Field("no", description="Controls failure behavior") - failure_stage: str = Field("none", description="Stage at which to fail") - + + should_fail: str = Field('no', description='Controls failure behavior') + failure_stage: str = Field('none', description='Stage at which to fail') + @property def target_state(self) -> str: - return "ERROR_TESTED" - + return 'ERROR_TESTED' + def validate_transition(self, context: TransitionContext) -> bool: - if self.failure_stage == "validation" and self.should_fail == "yes": - raise TransitionValidationError("Intentional validation failure") + if self.failure_stage == 'validation' and self.should_fail == 'yes': + raise TransitionValidationError('Intentional validation failure') return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: - if self.failure_stage == "transition" and self.should_fail == "yes": - raise RuntimeError("Intentional transition failure") - - return { - "should_fail": self.should_fail, - "failure_stage": self.failure_stage - } + if self.failure_stage == 'transition' and self.should_fail == 'yes': + raise RuntimeError('Intentional transition failure') + + return {'should_fail': self.should_fail, 'failure_stage': self.failure_stage} class EdgeCasesAndErrorHandlingTests(TestCase): """ Comprehensive edge case and error handling tests. - + These tests ensure the transition system handles unusual inputs, boundary conditions, and error scenarios gracefully. """ - + def setUp(self): self.mock_entity = Mock() self.mock_entity.pk = 1 self.mock_entity._meta.model_name = 'test_entity' - + self.mock_user = Mock() self.mock_user.id = 42 - + # Clear registry transition_registry._transitions.clear() transition_registry.register('test_entity', 'edge_case', EdgeCaseTransition) transition_registry.register('test_entity', 'error_prone', ErrorProneTransition) - + def test_none_and_empty_values_handling(self): """ EDGE CASE: Handling None and empty values - + Tests how the system handles None values, empty strings, empty lists, and other "falsy" values. """ - + # Test None values transition_none = EdgeCaseTransition(edge_case_data=None) self.assertIsNone(transition_none.edge_case_data) - + context = TransitionContext( entity=self.mock_entity, current_user=None, # None user current_state=None, # None state (initial) - target_state=transition_none.target_state + target_state=transition_none.target_state, ) - + # Should handle None values gracefully self.assertTrue(transition_none.validate_transition(context)) result = transition_none.transition(context) - self.assertIsNone(result["edge_case_data"]) - + self.assertIsNone(result['edge_case_data']) + # Test empty string values - empty_transition = EdgeCaseTransition(edge_case_data="") + empty_transition = EdgeCaseTransition(edge_case_data='') result = empty_transition.transition(context) - self.assertEqual(result["edge_case_data"], "") - + self.assertEqual(result['edge_case_data'], '') + # Test empty collections empty_list_transition = EdgeCaseTransition(edge_case_data=[]) result = empty_list_transition.transition(context) - self.assertEqual(result["edge_case_data"], []) - + self.assertEqual(result['edge_case_data'], []) + empty_dict_transition = EdgeCaseTransition(edge_case_data={}) result = empty_dict_transition.transition(context) - self.assertEqual(result["edge_case_data"], {}) - + self.assertEqual(result['edge_case_data'], {}) + # Test zero values zero_transition = EdgeCaseTransition(edge_case_data=0) result = zero_transition.transition(context) - self.assertEqual(result["edge_case_data"], 0) - + self.assertEqual(result['edge_case_data'], 0) + # Test False boolean false_transition = EdgeCaseTransition(edge_case_data=False) result = false_transition.transition(context) - self.assertFalse(result["edge_case_data"]) - + self.assertFalse(result['edge_case_data']) + def test_extreme_data_sizes(self): """ EDGE CASE: Handling extremely large or small data - + Tests system behavior with very large strings, deep nested structures, and other extreme data sizes. """ - + # Test very large string - large_string = "x" * 10000 # 10KB string + large_string = 'x' * 10000 # 10KB string large_string_transition = EdgeCaseTransition(edge_case_data=large_string) - + context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state=large_string_transition.target_state + entity=self.mock_entity, current_state='CREATED', target_state=large_string_transition.target_state ) - + result = large_string_transition.transition(context) - self.assertEqual(len(result["edge_case_data"]), 10000) - + self.assertEqual(len(result['edge_case_data']), 10000) + # Test deeply nested dictionary - deep_dict = {"level": 0} + deep_dict = {'level': 0} current_level = deep_dict for i in range(100): # 100 levels deep - current_level["next"] = {"level": i + 1} - current_level = current_level["next"] - + current_level['next'] = {'level': i + 1} + current_level = current_level['next'] + deep_dict_transition = EdgeCaseTransition(edge_case_data=deep_dict) result = deep_dict_transition.transition(context) - self.assertEqual(result["edge_case_data"]["level"], 0) - + self.assertEqual(result['edge_case_data']['level'], 0) + # Test large list large_list = list(range(1000)) # 1000 items large_list_transition = EdgeCaseTransition(edge_case_data=large_list) result = large_list_transition.transition(context) - self.assertEqual(len(result["edge_case_data"]), 1000) - self.assertEqual(result["edge_case_data"][-1], 999) - + self.assertEqual(len(result['edge_case_data']), 1000) + self.assertEqual(result['edge_case_data'][-1], 999) + def test_unicode_and_special_characters(self): """ EDGE CASE: Unicode and special character handling - + Tests handling of various Unicode characters, emojis, control characters, and other special strings. """ - + test_cases = [ # Unicode characters - "Hello, 世界! 🌍", + 'Hello, 世界! 🌍', # Emojis - "Task completed! 🎉✅👍", + 'Task completed! 🎉✅👍', # Special symbols - "Price: €100 → $120 ≈ £95", + 'Price: €100 → $120 ≈ £95', # Mathematical symbols - "∑(1,2,3) = 6, √16 = 4, π ≈ 3.14", + '∑(1,2,3) = 6, √16 = 4, π ≈ 3.14', # Control characters (escaped) - "Line1\nLine2\tTabbed\r\nWindows", + 'Line1\nLine2\tTabbed\r\nWindows', # JSON-like string '{"key": "value", "number": 42}', # SQL-like string (potential injection test) "'; DROP TABLE users; --", # Empty and whitespace - " \t\n\r ", + ' \t\n\r ', # Very long Unicode - "🌟" * 100, + '🌟' * 100, ] - + context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state="EDGE_CASE_PROCESSED" + entity=self.mock_entity, current_state='CREATED', target_state='EDGE_CASE_PROCESSED' ) - + for test_string in test_cases: - with self.subTest(test_string=test_string[:20] + "..."): + with self.subTest(test_string=test_string[:20] + '...'): transition = EdgeCaseTransition(edge_case_data=test_string) - + # Should handle any Unicode string result = transition.transition(context) - self.assertEqual(result["edge_case_data"], test_string) - + self.assertEqual(result['edge_case_data'], test_string) + def test_boundary_datetime_values(self): """ EDGE CASE: Boundary datetime values - + Tests handling of edge case datetime values like far future, far past, timezone edge cases, etc. """ - + boundary_datetimes = [ # Far past datetime(1970, 1, 1), @@ -251,251 +229,225 @@ def test_boundary_datetime_values(self): # Microsecond precision datetime(2024, 1, 1, 12, 0, 0, 123456), ] - + for test_datetime in boundary_datetimes: with self.subTest(datetime=test_datetime.isoformat()): context = TransitionContext( entity=self.mock_entity, - current_state="CREATED", - target_state="EDGE_CASE_PROCESSED", - timestamp=test_datetime + current_state='CREATED', + target_state='EDGE_CASE_PROCESSED', + timestamp=test_datetime, ) - - transition = EdgeCaseTransition(edge_case_data="datetime_test") - + + transition = EdgeCaseTransition(edge_case_data='datetime_test') + # Should handle any valid datetime result = transition.transition(context) - self.assertEqual(result["processed_at"], test_datetime.isoformat()) - + self.assertEqual(result['processed_at'], test_datetime.isoformat()) + def test_circular_reference_handling(self): """ EDGE CASE: Circular references and complex object graphs - + Tests how the system handles objects with circular references or complex interdependencies. """ - + # Create circular reference structure - circular_dict = {"name": "parent"} - circular_dict["child"] = {"name": "child", "parent": circular_dict} - + circular_dict = {'name': 'parent'} + circular_dict['child'] = {'name': 'child', 'parent': circular_dict} + # Test that the system can handle circular references without infinite recursion # Pydantic with field type 'Any' will accept circular references try: transition = EdgeCaseTransition(edge_case_data=circular_dict) # Verify that the circular reference was stored - self.assertEqual(transition.edge_case_data["name"], "parent") - self.assertEqual(transition.edge_case_data["child"]["name"], "child") + self.assertEqual(transition.edge_case_data['name'], 'parent') + self.assertEqual(transition.edge_case_data['child']['name'], 'child') # The system should handle this gracefully except RecursionError: - self.fail("System should handle circular references without infinite recursion") - + self.fail('System should handle circular references without infinite recursion') + # Test with complex but non-circular structure complex_structure = { - "level1": { - "level2": { - "level3": { - "data": "deep_value", - "references": ["ref1", "ref2", "ref3"] * 10 - } - } - }, - "cross_reference": None # Will be set to level1 later, but not circular + 'level1': {'level2': {'level3': {'data': 'deep_value', 'references': ['ref1', 'ref2', 'ref3'] * 10}}}, + 'cross_reference': None, # Will be set to level1 later, but not circular } - + transition = EdgeCaseTransition(edge_case_data=complex_structure) context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state=transition.target_state + entity=self.mock_entity, current_state='CREATED', target_state=transition.target_state ) - + result = transition.transition(context) - self.assertEqual( - result["edge_case_data"]["level1"]["level2"]["level3"]["data"], - "deep_value" - ) - + self.assertEqual(result['edge_case_data']['level1']['level2']['level3']['data'], 'deep_value') + def test_memory_pressure_and_cleanup(self): """ EDGE CASE: Memory pressure and garbage collection - + Tests system behavior under memory pressure and ensures proper cleanup of transition instances and contexts. """ - + transitions = [] contexts = [] weak_refs = [] - + # Create many transition instances for i in range(1000): - transition = EdgeCaseTransition(edge_case_data=f"data_{i}") + transition = EdgeCaseTransition(edge_case_data=f'data_{i}') context = TransitionContext( entity=self.mock_entity, - current_state="CREATED", + current_state='CREATED', target_state=transition.target_state, - metadata={"iteration": i} + metadata={'iteration': i}, ) - + transitions.append(transition) contexts.append(context) - + # Create weak references to test garbage collection if i < 10: # Only for first few to avoid too many weak refs weak_refs.append(weakref.ref(transition)) weak_refs.append(weakref.ref(context)) - + # Verify all were created self.assertEqual(len(transitions), 1000) self.assertEqual(len(contexts), 1000) - + # Clear references and force garbage collection transitions.clear() contexts.clear() gc.collect() - + # Check that objects can be garbage collected # Some weak references should be None after GC - none_count = sum(1 for ref in weak_refs if ref() is None) + sum(1 for ref in weak_refs if ref() is None) # At least some should be collected (this is implementation dependent) # We don't require all to be collected due to Python GC behavior - + # Test that new instances can still be created after cleanup - new_transition = EdgeCaseTransition(edge_case_data="after_cleanup") + new_transition = EdgeCaseTransition(edge_case_data='after_cleanup') new_context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state=new_transition.target_state + entity=self.mock_entity, current_state='CREATED', target_state=new_transition.target_state ) - + result = new_transition.transition(new_context) - self.assertEqual(result["edge_case_data"], "after_cleanup") - + self.assertEqual(result['edge_case_data'], 'after_cleanup') + def test_exception_during_validation(self): """ ERROR HANDLING: Exceptions during validation - + Tests proper handling of various types of exceptions that can occur during transition validation. """ - + class ValidationErrorTransition(BaseTransition): - error_type: str = Field(..., description="Type of error to raise") - + error_type: str = Field(..., description='Type of error to raise') + @property def target_state(self) -> str: - return "ERROR_STATE" - + return 'ERROR_STATE' + @classmethod def get_target_state(cls) -> str: - return "ERROR_STATE" - + return 'ERROR_STATE' + @classmethod def can_transition_from_state(cls, context: TransitionContext) -> bool: return True - + def validate_transition(self, context: TransitionContext) -> bool: - if self.error_type == "transition_validation": - raise TransitionValidationError("Business rule violation") - elif self.error_type == "runtime_error": - raise RuntimeError("Unexpected runtime error") - elif self.error_type == "key_error": - raise KeyError("Missing required key") - elif self.error_type == "attribute_error": - raise AttributeError("Missing attribute") - elif self.error_type == "value_error": - raise ValueError("Invalid value provided") - elif self.error_type == "type_error": - raise TypeError("Type mismatch") + if self.error_type == 'transition_validation': + raise TransitionValidationError('Business rule violation') + elif self.error_type == 'runtime_error': + raise RuntimeError('Unexpected runtime error') + elif self.error_type == 'key_error': + raise KeyError('Missing required key') + elif self.error_type == 'attribute_error': + raise AttributeError('Missing attribute') + elif self.error_type == 'value_error': + raise ValueError('Invalid value provided') + elif self.error_type == 'type_error': + raise TypeError('Type mismatch') return True - + def transition(self, context: TransitionContext) -> dict: - return {"error_type": self.error_type, "processed": True} - - context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state="ERROR_STATE" - ) - + return {'error_type': self.error_type, 'processed': True} + + context = TransitionContext(entity=self.mock_entity, current_state='CREATED', target_state='ERROR_STATE') + # Test TransitionValidationError (expected) - transition = ValidationErrorTransition(error_type="transition_validation") + transition = ValidationErrorTransition(error_type='transition_validation') with self.assertRaises(TransitionValidationError) as cm: transition.validate_transition(context) - self.assertIn("Business rule violation", str(cm.exception)) - + self.assertIn('Business rule violation', str(cm.exception)) + # Test other exceptions (should bubble up) error_types = [ - ("runtime_error", RuntimeError), - ("key_error", KeyError), - ("attribute_error", AttributeError), - ("value_error", ValueError), - ("type_error", TypeError), + ('runtime_error', RuntimeError), + ('key_error', KeyError), + ('attribute_error', AttributeError), + ('value_error', ValueError), + ('type_error', TypeError), ] - + for error_type, exception_class in error_types: with self.subTest(error_type=error_type): transition = ValidationErrorTransition(error_type=error_type) with self.assertRaises(exception_class): transition.validate_transition(context) - + def test_exception_during_transition_execution(self): """ ERROR HANDLING: Exceptions during transition execution - + Tests handling of exceptions that occur during the actual transition execution phase. """ - + # Test with ErrorProneTransition - context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state="ERROR_TESTED" - ) - + context = TransitionContext(entity=self.mock_entity, current_state='CREATED', target_state='ERROR_TESTED') + # Test successful execution - success_transition = ErrorProneTransition(should_fail="no") + success_transition = ErrorProneTransition(should_fail='no') result = success_transition.transition(context) - self.assertEqual(result["should_fail"], "no") - + self.assertEqual(result['should_fail'], 'no') + # Test intentional failure - fail_transition = ErrorProneTransition( - should_fail="yes", - failure_stage="transition" - ) - + fail_transition = ErrorProneTransition(should_fail='yes', failure_stage='transition') + with self.assertRaises(RuntimeError) as cm: fail_transition.transition(context) - self.assertIn("Intentional transition failure", str(cm.exception)) - + self.assertIn('Intentional transition failure', str(cm.exception)) + def test_registry_edge_cases(self): """ EDGE CASE: Registry edge cases - + Tests unusual registry operations and edge cases like duplicate registrations, invalid names, etc. """ - + # Test duplicate registration (should overwrite) - original_class = EdgeCaseTransition - + class NewEdgeCaseTransition(BaseTransition): @property def target_state(self) -> str: - return "NEW_EDGE_CASE" - + return 'NEW_EDGE_CASE' + def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {"type": "new_implementation"} - + return {'type': 'new_implementation'} + # Register with same name transition_registry.register('test_entity', 'edge_case', NewEdgeCaseTransition) - + # Should get new class retrieved = transition_registry.get_transition('test_entity', 'edge_case') self.assertEqual(retrieved, NewEdgeCaseTransition) - + # Test registration with unusual names unusual_names = [ ('entity-with-dashes', 'transition-name'), @@ -504,265 +456,251 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: ('entity123', 'transition456'), ('UPPERCASE_ENTITY', 'UPPERCASE_TRANSITION'), ] - + for entity_name, transition_name in unusual_names: with self.subTest(entity=entity_name, transition=transition_name): transition_registry.register(entity_name, transition_name, EdgeCaseTransition) retrieved = transition_registry.get_transition(entity_name, transition_name) self.assertEqual(retrieved, EdgeCaseTransition) - + # Test nonexistent lookups self.assertIsNone(transition_registry.get_transition('nonexistent', 'transition')) self.assertIsNone(transition_registry.get_transition('test_entity', 'nonexistent')) - + # Test empty entity transitions empty_transitions = transition_registry.get_transitions_for_entity('nonexistent_entity') self.assertEqual(empty_transitions, {}) - + def test_context_edge_cases(self): """ EDGE CASE: TransitionContext edge cases - + Tests unusual context configurations and edge cases in context creation and usage. """ - + # Test context with None entity (system should handle gracefully) # Since entity field is typed as Any, None is accepted try: - context = TransitionContext( - entity=None, - current_state="CREATED", - target_state="PROCESSED" - ) + context = TransitionContext(entity=None, current_state='CREATED', target_state='PROCESSED') # Verify context was created with None entity self.assertIsNone(context.entity) - self.assertEqual(context.current_state, "CREATED") + self.assertEqual(context.current_state, 'CREATED') except Exception as e: - self.fail(f"Context creation with None entity should not fail: {e}") - + self.fail(f'Context creation with None entity should not fail: {e}') + # Test context with missing required fields with self.assertRaises(ValidationError): TransitionContext( entity=self.mock_entity, # Missing target_state ) - + # Test context with extreme timestamp far_future = datetime(3000, 1, 1) context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state="PROCESSED", - timestamp=far_future + entity=self.mock_entity, current_state='CREATED', target_state='PROCESSED', timestamp=far_future ) - + self.assertEqual(context.timestamp, far_future) - + # Test context with large metadata - large_metadata = {f"key_{i}": f"value_{i}" for i in range(1000)} + large_metadata = {f'key_{i}': f'value_{i}' for i in range(1000)} context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state="PROCESSED", - metadata=large_metadata + entity=self.mock_entity, current_state='CREATED', target_state='PROCESSED', metadata=large_metadata ) - + self.assertEqual(len(context.metadata), 1000) - + # Test context property edge cases empty_context = TransitionContext( - entity=self.mock_entity, - current_state="", # Empty string state - target_state="PROCESSED" + entity=self.mock_entity, current_state='', target_state='PROCESSED' # Empty string state ) - + # Empty string should be considered "has state" self.assertTrue(empty_context.has_current_state) self.assertFalse(empty_context.is_initial_transition) - + def test_transition_builder_edge_cases(self): """ EDGE CASE: TransitionBuilder edge cases - + Tests unusual usage patterns and edge cases with the fluent TransitionBuilder interface. """ - + builder = TransitionBuilder(self.mock_entity) - + # Test validation without setting transition name with self.assertRaises(ValueError) as cm: builder.validate() - self.assertIn("Transition name not specified", str(cm.exception)) - + self.assertIn('Transition name not specified', str(cm.exception)) + # Test execution without setting transition name with self.assertRaises(ValueError) as cm: builder.execute() - self.assertIn("Transition name not specified", str(cm.exception)) - + self.assertIn('Transition name not specified', str(cm.exception)) + # Test with nonexistent transition builder.transition('nonexistent_transition') - + with self.assertRaises(ValueError) as cm: builder.validate() - self.assertIn("not found", str(cm.exception)) - + self.assertIn('not found', str(cm.exception)) + # Test method chaining edge cases - builder = (TransitionBuilder(self.mock_entity) - .transition('edge_case') - .with_data() # Empty data - .by_user(None) # No user - .with_context()) # Empty context - + builder = ( + TransitionBuilder(self.mock_entity) + .transition('edge_case') + .with_data() # Empty data + .by_user(None) # No user + .with_context() + ) # Empty context + # Should not raise errors for empty data errors = builder.validate() self.assertEqual(errors, {}) # EdgeCaseTransition has no required fields - + # Test data overwriting - builder = (TransitionBuilder(self.mock_entity) - .transition('edge_case') - .with_data(edge_case_data="first") - .with_data(edge_case_data="second")) # Should overwrite - + builder = ( + TransitionBuilder(self.mock_entity) + .transition('edge_case') + .with_data(edge_case_data='first') + .with_data(edge_case_data='second') + ) # Should overwrite + errors = builder.validate() self.assertEqual(errors, {}) - + def test_concurrent_error_scenarios(self): """ EDGE CASE: Error handling under concurrency - + Tests error handling when multiple threads encounter errors simultaneously. """ - + error_results = [] error_lock = threading.Lock() - + def error_worker(worker_id): """Worker that intentionally triggers errors""" try: # Create transition that will fail transition = ErrorProneTransition( - should_fail="yes", - failure_stage="validation" if worker_id % 2 == 0 else "transition" + should_fail='yes', failure_stage='validation' if worker_id % 2 == 0 else 'transition' ) - + context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state=transition.target_state + entity=self.mock_entity, current_state='CREATED', target_state=transition.target_state ) - + if worker_id % 2 == 0: # Validation error transition.validate_transition(context) else: # Transition execution error transition.transition(context) - + except Exception as e: with error_lock: - error_results.append({ - "worker_id": worker_id, - "error_type": type(e).__name__, - "error_message": str(e) - }) - + error_results.append( + {'worker_id': worker_id, 'error_type': type(e).__name__, 'error_message': str(e)} + ) + # Run multiple workers that will all fail threads = [] for i in range(10): thread = threading.Thread(target=error_worker, args=(i,)) threads.append(thread) thread.start() - + # Wait for all to complete for thread in threads: thread.join() - + # Should have 10 errors self.assertEqual(len(error_results), 10) - + # Verify error types - validation_errors = [r for r in error_results if r["error_type"] == "TransitionValidationError"] - runtime_errors = [r for r in error_results if r["error_type"] == "RuntimeError"] - + validation_errors = [r for r in error_results if r['error_type'] == 'TransitionValidationError'] + runtime_errors = [r for r in error_results if r['error_type'] == 'RuntimeError'] + self.assertEqual(len(validation_errors), 5) # Even worker IDs self.assertEqual(len(runtime_errors), 5) # Odd worker IDs - + def test_resource_cleanup_after_errors(self): """ EDGE CASE: Resource cleanup after errors - + Tests that resources are properly cleaned up even when transitions fail partway through. """ - + class ResourceTrackingTransition(BaseTransition): """Transition that tracks resource allocation""" - resource_name: str = Field(..., description="Name of resource") - resources_allocated: list = Field(default_factory=list, description="Track allocated resources") - resources_cleaned: list = Field(default_factory=list, description="Track cleaned resources") - + + resource_name: str = Field(..., description='Name of resource') + resources_allocated: list = Field(default_factory=list, description='Track allocated resources') + resources_cleaned: list = Field(default_factory=list, description='Track cleaned resources') + @property def target_state(self) -> str: - return "RESOURCE_PROCESSED" - + return 'RESOURCE_PROCESSED' + @classmethod def get_target_state(cls) -> str: - return "RESOURCE_PROCESSED" - + return 'RESOURCE_PROCESSED' + @classmethod def can_transition_from_state(cls, context: TransitionContext) -> bool: return True - + def validate_transition(self, context: TransitionContext) -> bool: # Allocate some mock resources - resource = f"resource_{self.resource_name}" + resource = f'resource_{self.resource_name}' self.resources_allocated.append(resource) - + # Fail if resource name contains "fail" - if "fail" in self.resource_name: - raise TransitionValidationError("Resource allocation failed") - + if 'fail' in self.resource_name: + raise TransitionValidationError('Resource allocation failed') + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {"resource_name": self.resource_name} - + return {'resource_name': self.resource_name} + def __del__(self): # Mock cleanup in destructor for resource in self.resources_allocated: if resource not in self.resources_cleaned: self.resources_cleaned.append(resource) - + # Test successful case - success_transition = ResourceTrackingTransition(resource_name="success") + success_transition = ResourceTrackingTransition(resource_name='success') context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state=success_transition.target_state + entity=self.mock_entity, current_state='CREATED', target_state=success_transition.target_state ) - + self.assertTrue(success_transition.validate_transition(context)) self.assertEqual(len(success_transition.resources_allocated), 1) - + # Test failure case - fail_transition = ResourceTrackingTransition(resource_name="fail_test") - + fail_transition = ResourceTrackingTransition(resource_name='fail_test') + with self.assertRaises(TransitionValidationError): fail_transition.validate_transition(context) - + # Resources should still be allocated even though validation failed self.assertEqual(len(fail_transition.resources_allocated), 1) - + # Force garbage collection to trigger cleanup - success_ref = weakref.ref(success_transition) - fail_ref = weakref.ref(fail_transition) - + weakref.ref(success_transition) + weakref.ref(fail_transition) + del success_transition del fail_transition gc.collect() - + # References should be cleaned up - # (Note: This test is somewhat implementation-dependent) \ No newline at end of file + # (Note: This test is somewhat implementation-dependent) diff --git a/label_studio/fsm/tests/test_integration_django_models.py b/label_studio/fsm/tests/test_integration_django_models.py index 0b7893b7f33d..b7c277b0d25b 100644 --- a/label_studio/fsm/tests/test_integration_django_models.py +++ b/label_studio/fsm/tests/test_integration_django_models.py @@ -5,33 +5,23 @@ Django models and the StateManager, providing realistic usage examples. """ -import pytest -from datetime import datetime, timedelta +from datetime import datetime +from typing import Any, Dict from unittest.mock import Mock, patch -from typing import Dict, Any -from django.test import TestCase, TransactionTestCase from django.contrib.auth import get_user_model +from django.test import TestCase +from fsm.models import TaskState +from fsm.state_choices import AnnotationStateChoices, TaskStateChoices +from fsm.transition_utils import TransitionBuilder +from fsm.transitions import BaseTransition, TransitionContext, TransitionValidationError, register_transition from pydantic import Field -from fsm.models import TaskState, AnnotationState, get_state_model_for_entity -from fsm.state_choices import TaskStateChoices, AnnotationStateChoices -from fsm.state_manager import StateManager -from fsm.transitions import ( - BaseTransition, - TransitionContext, - TransitionValidationError, - register_transition -) -from fsm.transition_utils import ( - execute_transition, - get_available_transitions, - TransitionBuilder -) # Mock Django models for integration testing class MockDjangoTask: """Mock Django Task model with realistic attributes""" + def __init__(self, pk=1, project_id=1, organization_id=1): self.pk = pk self.id = pk @@ -40,15 +30,16 @@ def __init__(self, pk=1, project_id=1, organization_id=1): self._meta = Mock() self._meta.model_name = 'task' self._meta.label_lower = 'tasks.task' - + # Mock task attributes - self.data = {"text": "Sample task data"} + self.data = {'text': 'Sample task data'} self.created_at = datetime.now() self.updated_at = datetime.now() class MockDjangoAnnotation: """Mock Django Annotation model with realistic attributes""" + def __init__(self, pk=1, task_id=1, project_id=1, organization_id=1): self.pk = pk self.id = pk @@ -58,9 +49,9 @@ def __init__(self, pk=1, task_id=1, project_id=1, organization_id=1): self._meta = Mock() self._meta.model_name = 'annotation' self._meta.label_lower = 'tasks.annotation' - + # Mock annotation attributes - self.result = [{"value": {"text": ["Sample annotation"]}}] + self.result = [{'value': {'text': ['Sample annotation']}}] self.completed_by_id = None self.created_at = datetime.now() self.updated_at = datetime.now() @@ -72,611 +63,609 @@ def __init__(self, pk=1, task_id=1, project_id=1, organization_id=1): class DjangoModelIntegrationTests(TestCase): """ Integration tests demonstrating realistic usage with Django models. - + These tests show how to implement transitions that work with actual Django model patterns and the StateManager integration. """ - + def setUp(self): self.task = MockDjangoTask() - self.annotation = MockDjangoAnnotation() + self.annotation = MockDjangoAnnotation() self.user = Mock() self.user.id = 123 - self.user.username = "integration_test_user" - + self.user.username = 'integration_test_user' + # Clear registry for clean test state from fsm.transitions import transition_registry + transition_registry._transitions.clear() - + @patch('fsm.models.get_state_model_for_entity') @patch('fsm.state_manager.StateManager.get_current_state_object') @patch('fsm.state_manager.StateManager.transition_state') def test_task_workflow_integration(self, mock_transition_state, mock_get_state_obj, mock_get_state_model): """ INTEGRATION TEST: Complete task workflow using Django models - + Demonstrates a realistic task lifecycle from creation through completion using the declarative transition system with Django model integration. """ - + # Setup mocks to simulate Django model behavior mock_get_state_model.return_value = TaskState mock_get_state_obj.return_value = None # No existing state (initial transition) mock_transition_state.return_value = True - + # Define task workflow transitions @register_transition('task', 'create_task') class CreateTaskTransition(BaseTransition): """Initial task creation transition""" - created_by_id: int = Field(..., description="User creating the task") - initial_priority: str = Field("normal", description="Initial task priority") - + + created_by_id: int = Field(..., description='User creating the task') + initial_priority: str = Field('normal', description='Initial task priority') + @property def target_state(self) -> str: return TaskStateChoices.CREATED - + def validate_transition(self, context: TransitionContext) -> bool: # Validate initial creation if not context.is_initial_transition: - raise TransitionValidationError("CreateTask can only be used for initial state") + raise TransitionValidationError('CreateTask can only be used for initial state') return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "created_by_id": self.created_by_id, - "initial_priority": self.initial_priority, - "task_data": getattr(context.entity, 'data', {}), - "project_id": getattr(context.entity, 'project_id', None), - "creation_method": "declarative_transition" + 'created_by_id': self.created_by_id, + 'initial_priority': self.initial_priority, + 'task_data': getattr(context.entity, 'data', {}), + 'project_id': getattr(context.entity, 'project_id', None), + 'creation_method': 'declarative_transition', } - + @register_transition('task', 'assign_and_start') class AssignAndStartTaskTransition(BaseTransition): """Assign task to user and start work""" - assignee_id: int = Field(..., description="User assigned to task") - estimated_hours: float = Field(None, ge=0.1, description="Estimated work hours") - priority: str = Field("normal", description="Task priority") - + + assignee_id: int = Field(..., description='User assigned to task') + estimated_hours: float = Field(None, ge=0.1, description='Estimated work hours') + priority: str = Field('normal', description='Task priority') + @property def target_state(self) -> str: return TaskStateChoices.IN_PROGRESS - + def validate_transition(self, context: TransitionContext) -> bool: valid_from_states = [TaskStateChoices.CREATED] if context.current_state not in valid_from_states: raise TransitionValidationError( - f"Can only assign tasks from states: {valid_from_states}", - {"current_state": context.current_state, "valid_states": valid_from_states} + f'Can only assign tasks from states: {valid_from_states}', + {'current_state': context.current_state, 'valid_states': valid_from_states}, ) - + # Business rule: Can't assign to the same user who created it if hasattr(context, 'current_state_object') and context.current_state_object: creator_id = context.current_state_object.context_data.get('created_by_id') if creator_id == self.assignee_id: raise TransitionValidationError( - "Cannot assign task to the same user who created it", - {"creator_id": creator_id, "assignee_id": self.assignee_id} + 'Cannot assign task to the same user who created it', + {'creator_id': creator_id, 'assignee_id': self.assignee_id}, ) - + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "assignee_id": self.assignee_id, - "estimated_hours": self.estimated_hours, - "priority": self.priority, - "assigned_at": context.timestamp.isoformat(), - "assigned_by_id": context.current_user.id if context.current_user else None, - "work_started": True + 'assignee_id': self.assignee_id, + 'estimated_hours': self.estimated_hours, + 'priority': self.priority, + 'assigned_at': context.timestamp.isoformat(), + 'assigned_by_id': context.current_user.id if context.current_user else None, + 'work_started': True, } - + @register_transition('task', 'complete_with_quality') class CompleteTaskWithQualityTransition(BaseTransition): """Complete task with quality metrics""" - quality_score: float = Field(..., ge=0.0, le=1.0, description="Quality score") - completion_notes: str = Field("", description="Completion notes") - actual_hours: float = Field(None, ge=0.0, description="Actual hours worked") - + + quality_score: float = Field(..., ge=0.0, le=1.0, description='Quality score') + completion_notes: str = Field('', description='Completion notes') + actual_hours: float = Field(None, ge=0.0, description='Actual hours worked') + @property def target_state(self) -> str: return TaskStateChoices.COMPLETED - + def validate_transition(self, context: TransitionContext) -> bool: if context.current_state != TaskStateChoices.IN_PROGRESS: raise TransitionValidationError( - "Can only complete tasks that are in progress", - {"current_state": context.current_state} + 'Can only complete tasks that are in progress', {'current_state': context.current_state} ) - + # Quality check if self.quality_score < 0.6: raise TransitionValidationError( - f"Quality score too low: {self.quality_score}. Minimum required: 0.6" + f'Quality score too low: {self.quality_score}. Minimum required: 0.6' ) - + return True - + def post_transition_hook(self, context: TransitionContext, state_record) -> None: """Post-completion tasks like notifications""" # Mock notification system if hasattr(self, '_notifications'): - self._notifications.append(f"Task {context.entity.pk} completed with quality {self.quality_score}") - + self._notifications.append(f'Task {context.entity.pk} completed with quality {self.quality_score}') + def transition(self, context: TransitionContext) -> Dict[str, Any]: # Calculate metrics start_data = context.current_state_object.context_data if context.current_state_object else {} estimated_hours = start_data.get('estimated_hours') - + return { - "quality_score": self.quality_score, - "completion_notes": self.completion_notes, - "actual_hours": self.actual_hours, - "estimated_hours": estimated_hours, - "completed_at": context.timestamp.isoformat(), - "completed_by_id": context.current_user.id if context.current_user else None, - "efficiency_ratio": (estimated_hours / self.actual_hours) if (estimated_hours and self.actual_hours) else None + 'quality_score': self.quality_score, + 'completion_notes': self.completion_notes, + 'actual_hours': self.actual_hours, + 'estimated_hours': estimated_hours, + 'completed_at': context.timestamp.isoformat(), + 'completed_by_id': context.current_user.id if context.current_user else None, + 'efficiency_ratio': (estimated_hours / self.actual_hours) + if (estimated_hours and self.actual_hours) + else None, } - + # Execute the complete workflow - + # Step 1: Create task - create_transition = CreateTaskTransition( - created_by_id=100, - initial_priority="high" - ) - + create_transition = CreateTaskTransition(created_by_id=100, initial_priority='high') + # Test with StateManager integration with patch('fsm.state_manager.StateManager.get_current_state') as mock_get_current: mock_get_current.return_value = None # No current state - + context = TransitionContext( entity=self.task, current_user=self.user, current_state=None, - target_state=create_transition.target_state + target_state=create_transition.target_state, ) - + # Validate and execute creation self.assertTrue(create_transition.validate_transition(context)) creation_data = create_transition.transition(context) - - self.assertEqual(creation_data["created_by_id"], 100) - self.assertEqual(creation_data["initial_priority"], "high") - self.assertEqual(creation_data["creation_method"], "declarative_transition") - + + self.assertEqual(creation_data['created_by_id'], 100) + self.assertEqual(creation_data['initial_priority'], 'high') + self.assertEqual(creation_data['creation_method'], 'declarative_transition') + # Step 2: Assign and start task mock_current_state = Mock() mock_current_state.context_data = creation_data mock_get_state_obj.return_value = mock_current_state - + assign_transition = AssignAndStartTaskTransition( - assignee_id=200, # Different from creator - estimated_hours=4.5, - priority="urgent" + assignee_id=200, estimated_hours=4.5, priority='urgent' # Different from creator ) - + context = TransitionContext( entity=self.task, current_user=self.user, current_state=TaskStateChoices.CREATED, current_state_object=mock_current_state, - target_state=assign_transition.target_state + target_state=assign_transition.target_state, ) - + self.assertTrue(assign_transition.validate_transition(context)) assignment_data = assign_transition.transition(context) - - self.assertEqual(assignment_data["assignee_id"], 200) - self.assertEqual(assignment_data["estimated_hours"], 4.5) - self.assertTrue(assignment_data["work_started"]) - + + self.assertEqual(assignment_data['assignee_id'], 200) + self.assertEqual(assignment_data['estimated_hours'], 4.5) + self.assertTrue(assignment_data['work_started']) + # Step 3: Complete task mock_current_state.context_data = assignment_data - + complete_transition = CompleteTaskWithQualityTransition( - quality_score=0.85, - completion_notes="Task completed successfully with minor revisions", - actual_hours=5.2 + quality_score=0.85, completion_notes='Task completed successfully with minor revisions', actual_hours=5.2 ) complete_transition._notifications = [] # Mock notification system - + context = TransitionContext( entity=self.task, current_user=self.user, current_state=TaskStateChoices.IN_PROGRESS, current_state_object=mock_current_state, - target_state=complete_transition.target_state + target_state=complete_transition.target_state, ) - + self.assertTrue(complete_transition.validate_transition(context)) completion_data = complete_transition.transition(context) - - self.assertEqual(completion_data["quality_score"], 0.85) - self.assertEqual(completion_data["actual_hours"], 5.2) - self.assertAlmostEqual(completion_data["efficiency_ratio"], 4.5/5.2, places=2) - + + self.assertEqual(completion_data['quality_score'], 0.85) + self.assertEqual(completion_data['actual_hours'], 5.2) + self.assertAlmostEqual(completion_data['efficiency_ratio'], 4.5 / 5.2, places=2) + # Test post-hook mock_state_record = Mock() complete_transition.post_transition_hook(context, mock_state_record) self.assertEqual(len(complete_transition._notifications), 1) - + # Verify StateManager calls self.assertEqual(mock_transition_state.call_count, 0) # Not called in our test setup - + def test_annotation_review_workflow_integration(self): """ INTEGRATION TEST: Annotation review workflow - + Demonstrates a realistic annotation review process using enterprise-grade validation and approval logic. """ - + @register_transition('annotation', 'submit_for_review') class SubmitAnnotationForReview(BaseTransition): """Submit annotation for quality review""" - annotator_confidence: float = Field(..., ge=0.0, le=1.0, description="Annotator confidence") - annotation_time_seconds: int = Field(..., ge=1, description="Time spent annotating") - review_requested: bool = Field(True, description="Whether review is requested") - + + annotator_confidence: float = Field(..., ge=0.0, le=1.0, description='Annotator confidence') + annotation_time_seconds: int = Field(..., ge=1, description='Time spent annotating') + review_requested: bool = Field(True, description='Whether review is requested') + @property def target_state(self) -> str: return AnnotationStateChoices.SUBMITTED - + def validate_transition(self, context: TransitionContext) -> bool: # Check annotation has content if not hasattr(context.entity, 'result') or not context.entity.result: - raise TransitionValidationError("Cannot submit empty annotation") - + raise TransitionValidationError('Cannot submit empty annotation') + # Business rule: Low confidence annotations must request review if self.annotator_confidence < 0.7 and not self.review_requested: raise TransitionValidationError( - "Low confidence annotations must request review", - {"confidence": self.annotator_confidence, "threshold": 0.7} + 'Low confidence annotations must request review', + {'confidence': self.annotator_confidence, 'threshold': 0.7}, ) - + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "annotator_confidence": self.annotator_confidence, - "annotation_time_seconds": self.annotation_time_seconds, - "review_requested": self.review_requested, - "annotation_complexity": len(context.entity.result) if context.entity.result else 0, - "submitted_at": context.timestamp.isoformat(), - "submitted_by_id": context.current_user.id if context.current_user else None + 'annotator_confidence': self.annotator_confidence, + 'annotation_time_seconds': self.annotation_time_seconds, + 'review_requested': self.review_requested, + 'annotation_complexity': len(context.entity.result) if context.entity.result else 0, + 'submitted_at': context.timestamp.isoformat(), + 'submitted_by_id': context.current_user.id if context.current_user else None, } - + @register_transition('annotation', 'review_and_approve') class ReviewAndApproveAnnotation(BaseTransition): """Review annotation and approve/reject""" - reviewer_decision: str = Field(..., description="approve, reject, or request_changes") - quality_score: float = Field(..., ge=0.0, le=1.0, description="Reviewer quality assessment") - review_comments: str = Field("", description="Review comments") - corrections_made: bool = Field(False, description="Whether reviewer made corrections") - + + reviewer_decision: str = Field(..., description='approve, reject, or request_changes') + quality_score: float = Field(..., ge=0.0, le=1.0, description='Reviewer quality assessment') + review_comments: str = Field('', description='Review comments') + corrections_made: bool = Field(False, description='Whether reviewer made corrections') + @property def target_state(self) -> str: - if self.reviewer_decision == "approve": + if self.reviewer_decision == 'approve': return AnnotationStateChoices.COMPLETED else: return AnnotationStateChoices.DRAFT # Back to draft for changes - + def validate_transition(self, context: TransitionContext) -> bool: if context.current_state != AnnotationStateChoices.SUBMITTED: - raise TransitionValidationError("Can only review submitted annotations") - - valid_decisions = ["approve", "reject", "request_changes"] + raise TransitionValidationError('Can only review submitted annotations') + + valid_decisions = ['approve', 'reject', 'request_changes'] if self.reviewer_decision not in valid_decisions: raise TransitionValidationError( - f"Invalid decision: {self.reviewer_decision}", - {"valid_decisions": valid_decisions} + f'Invalid decision: {self.reviewer_decision}', {'valid_decisions': valid_decisions} ) - + # Quality score validation based on decision - if self.reviewer_decision == "approve" and self.quality_score < 0.6: + if self.reviewer_decision == 'approve' and self.quality_score < 0.6: raise TransitionValidationError( - "Cannot approve annotation with low quality score", - {"quality_score": self.quality_score, "decision": self.reviewer_decision} + 'Cannot approve annotation with low quality score', + {'quality_score': self.quality_score, 'decision': self.reviewer_decision}, ) - + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: # Get submission data for metrics submission_data = context.current_state_object.context_data if context.current_state_object else {} - + return { - "reviewer_decision": self.reviewer_decision, - "quality_score": self.quality_score, - "review_comments": self.review_comments, - "corrections_made": self.corrections_made, - "reviewed_at": context.timestamp.isoformat(), - "reviewed_by_id": context.current_user.id if context.current_user else None, - "original_confidence": submission_data.get("annotator_confidence"), - "confidence_vs_quality_diff": abs(submission_data.get("annotator_confidence", 0) - self.quality_score) + 'reviewer_decision': self.reviewer_decision, + 'quality_score': self.quality_score, + 'review_comments': self.review_comments, + 'corrections_made': self.corrections_made, + 'reviewed_at': context.timestamp.isoformat(), + 'reviewed_by_id': context.current_user.id if context.current_user else None, + 'original_confidence': submission_data.get('annotator_confidence'), + 'confidence_vs_quality_diff': abs( + submission_data.get('annotator_confidence', 0) - self.quality_score + ), } - + # Execute annotation workflow - + # Step 1: Submit annotation submit_transition = SubmitAnnotationForReview( - annotator_confidence=0.9, - annotation_time_seconds=300, # 5 minutes - review_requested=True + annotator_confidence=0.9, annotation_time_seconds=300, review_requested=True # 5 minutes ) - + context = TransitionContext( entity=self.annotation, current_user=self.user, current_state=AnnotationStateChoices.DRAFT, - target_state=submit_transition.target_state + target_state=submit_transition.target_state, ) - + self.assertTrue(submit_transition.validate_transition(context)) submit_data = submit_transition.transition(context) - - self.assertEqual(submit_data["annotator_confidence"], 0.9) - self.assertEqual(submit_data["annotation_time_seconds"], 300) - self.assertTrue(submit_data["review_requested"]) - self.assertEqual(submit_data["annotation_complexity"], 1) # Based on mock result - + + self.assertEqual(submit_data['annotator_confidence'], 0.9) + self.assertEqual(submit_data['annotation_time_seconds'], 300) + self.assertTrue(submit_data['review_requested']) + self.assertEqual(submit_data['annotation_complexity'], 1) # Based on mock result + # Step 2: Review and approve mock_submission_state = Mock() mock_submission_state.context_data = submit_data - + review_transition = ReviewAndApproveAnnotation( - reviewer_decision="approve", + reviewer_decision='approve', quality_score=0.85, - review_comments="High quality annotation with good coverage", - corrections_made=False + review_comments='High quality annotation with good coverage', + corrections_made=False, ) - + context = TransitionContext( entity=self.annotation, current_user=self.user, current_state=AnnotationStateChoices.SUBMITTED, current_state_object=mock_submission_state, - target_state=review_transition.target_state + target_state=review_transition.target_state, ) - + self.assertTrue(review_transition.validate_transition(context)) self.assertEqual(review_transition.target_state, AnnotationStateChoices.COMPLETED) - + review_data = review_transition.transition(context) - - self.assertEqual(review_data["reviewer_decision"], "approve") - self.assertEqual(review_data["quality_score"], 0.85) - self.assertEqual(review_data["original_confidence"], 0.9) - self.assertAlmostEqual(review_data["confidence_vs_quality_diff"], 0.05, places=2) - + + self.assertEqual(review_data['reviewer_decision'], 'approve') + self.assertEqual(review_data['quality_score'], 0.85) + self.assertEqual(review_data['original_confidence'], 0.9) + self.assertAlmostEqual(review_data['confidence_vs_quality_diff'], 0.05, places=2) + # Test rejection scenario reject_transition = ReviewAndApproveAnnotation( - reviewer_decision="reject", + reviewer_decision='reject', quality_score=0.3, - review_comments="Insufficient annotation quality", - corrections_made=False + review_comments='Insufficient annotation quality', + corrections_made=False, ) - + self.assertEqual(reject_transition.target_state, AnnotationStateChoices.DRAFT) - + # Test validation failure invalid_review = ReviewAndApproveAnnotation( - reviewer_decision="approve", # Trying to approve + reviewer_decision='approve', # Trying to approve quality_score=0.5, # But quality too low - review_comments="Test", + review_comments='Test', ) - + with self.assertRaises(TransitionValidationError) as cm: invalid_review.validate_transition(context) - - self.assertIn("Cannot approve annotation with low quality score", str(cm.exception)) - + + self.assertIn('Cannot approve annotation with low quality score', str(cm.exception)) + @patch('fsm.transition_utils.execute_transition') def test_transition_builder_with_django_models(self, mock_execute): """ INTEGRATION TEST: TransitionBuilder with Django model integration - + Shows how to use the fluent TransitionBuilder interface with real Django models and complex business logic. """ - + @register_transition('task', 'bulk_update_status') class BulkUpdateTaskStatusTransition(BaseTransition): """Bulk update task status with metadata""" - new_status: str = Field(..., description="New status for tasks") - update_reason: str = Field(..., description="Reason for bulk update") - updated_by_system: bool = Field(False, description="Whether updated by automated system") - batch_id: str = Field(None, description="Batch operation ID") - + + new_status: str = Field(..., description='New status for tasks') + update_reason: str = Field(..., description='Reason for bulk update') + updated_by_system: bool = Field(False, description='Whether updated by automated system') + batch_id: str = Field(None, description='Batch operation ID') + @property def target_state(self) -> str: return self.new_status - + def validate_transition(self, context: TransitionContext) -> bool: valid_statuses = [TaskStateChoices.CREATED, TaskStateChoices.IN_PROGRESS, TaskStateChoices.COMPLETED] if self.new_status not in valid_statuses: - raise TransitionValidationError(f"Invalid status: {self.new_status}") - + raise TransitionValidationError(f'Invalid status: {self.new_status}') + # Can't bulk update to the same status if context.current_state == self.new_status: - raise TransitionValidationError("Cannot update to the same status") - + raise TransitionValidationError('Cannot update to the same status') + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "new_status": self.new_status, - "update_reason": self.update_reason, - "updated_by_system": self.updated_by_system, - "batch_id": self.batch_id, - "bulk_update_timestamp": context.timestamp.isoformat(), - "previous_status": context.current_state + 'new_status': self.new_status, + 'update_reason': self.update_reason, + 'updated_by_system': self.updated_by_system, + 'batch_id': self.batch_id, + 'bulk_update_timestamp': context.timestamp.isoformat(), + 'previous_status': context.current_state, } - + # Mock successful execution mock_state_record = Mock() - mock_state_record.id = "mock-uuid" + mock_state_record.id = 'mock-uuid' mock_execute.return_value = mock_state_record - + # Test fluent interface - result = (TransitionBuilder(self.task) - .transition('bulk_update_status') - .with_data( - new_status=TaskStateChoices.IN_PROGRESS, - update_reason="Project priority change", - updated_by_system=True, - batch_id="batch_2024_001" - ) - .by_user(self.user) - .with_context( - project_update=True, - notification_level="high" - ) - .execute()) - + result = ( + TransitionBuilder(self.task) + .transition('bulk_update_status') + .with_data( + new_status=TaskStateChoices.IN_PROGRESS, + update_reason='Project priority change', + updated_by_system=True, + batch_id='batch_2024_001', + ) + .by_user(self.user) + .with_context(project_update=True, notification_level='high') + .execute() + ) + # Verify the call mock_execute.assert_called_once() call_args, call_kwargs = mock_execute.call_args - + # Check call parameters self.assertEqual(call_kwargs['entity'], self.task) self.assertEqual(call_kwargs['transition_name'], 'bulk_update_status') self.assertEqual(call_kwargs['user'], self.user) - + # Check transition data transition_data = call_kwargs['transition_data'] self.assertEqual(transition_data['new_status'], TaskStateChoices.IN_PROGRESS) - self.assertEqual(transition_data['update_reason'], "Project priority change") + self.assertEqual(transition_data['update_reason'], 'Project priority change') self.assertTrue(transition_data['updated_by_system']) - self.assertEqual(transition_data['batch_id'], "batch_2024_001") - + self.assertEqual(transition_data['batch_id'], 'batch_2024_001') + # Check context self.assertTrue(call_kwargs['project_update']) - self.assertEqual(call_kwargs['notification_level'], "high") - + self.assertEqual(call_kwargs['notification_level'], 'high') + # Check return value self.assertEqual(result, mock_state_record) - + def test_error_handling_with_django_models(self): """ INTEGRATION TEST: Error handling with Django model validation - + Tests comprehensive error handling scenarios that might occur in real Django model integration. """ - + @register_transition('task', 'assign_with_constraints') class AssignTaskWithConstraints(BaseTransition): """Task assignment with business constraints""" - assignee_id: int = Field(..., description="User to assign to") - max_concurrent_tasks: int = Field(5, description="Max concurrent tasks per user") - skill_requirements: list = Field(default_factory=list, description="Required skills") - + + assignee_id: int = Field(..., description='User to assign to') + max_concurrent_tasks: int = Field(5, description='Max concurrent tasks per user') + skill_requirements: list = Field(default_factory=list, description='Required skills') + @property def target_state(self) -> str: return TaskStateChoices.IN_PROGRESS - + def validate_transition(self, context: TransitionContext) -> bool: errors = [] - + # Mock database checks (in real scenario, these would be actual queries) - + # 1. Check user exists and is active if self.assignee_id <= 0: - errors.append("Invalid user ID") - + errors.append('Invalid user ID') + # 2. Check user's current task load if self.max_concurrent_tasks < 1: - errors.append("Max concurrent tasks must be at least 1") - + errors.append('Max concurrent tasks must be at least 1') + # 3. Check skill requirements if self.skill_requirements: # Mock skill validation - available_skills = ["python", "labeling", "review"] + available_skills = ['python', 'labeling', 'review'] missing_skills = [skill for skill in self.skill_requirements if skill not in available_skills] if missing_skills: - errors.append(f"Missing required skills: {missing_skills}") - + errors.append(f'Missing required skills: {missing_skills}') + # 4. Check project-level constraints if hasattr(context.entity, 'project_id'): # Mock project validation if context.entity.project_id <= 0: - errors.append("Invalid project configuration") - + errors.append('Invalid project configuration') + # 5. Check organization permissions if hasattr(context.entity, 'organization_id'): if not context.current_user: - errors.append("User authentication required for assignment") - + errors.append('User authentication required for assignment') + if errors: raise TransitionValidationError( f"Assignment validation failed: {'; '.join(errors)}", { - "validation_errors": errors, - "assignee_id": self.assignee_id, - "task_id": context.entity.pk, - "skill_requirements": self.skill_requirements - } + 'validation_errors': errors, + 'assignee_id': self.assignee_id, + 'task_id': context.entity.pk, + 'skill_requirements': self.skill_requirements, + }, ) - + return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - "assignee_id": self.assignee_id, - "max_concurrent_tasks": self.max_concurrent_tasks, - "skill_requirements": self.skill_requirements, - "assignment_validated": True + 'assignee_id': self.assignee_id, + 'max_concurrent_tasks': self.max_concurrent_tasks, + 'skill_requirements': self.skill_requirements, + 'assignment_validated': True, } - + # Test successful validation valid_transition = AssignTaskWithConstraints( - assignee_id=123, - max_concurrent_tasks=3, - skill_requirements=["python", "labeling"] + assignee_id=123, max_concurrent_tasks=3, skill_requirements=['python', 'labeling'] ) - + context = TransitionContext( entity=self.task, current_user=self.user, current_state=TaskStateChoices.CREATED, - target_state=valid_transition.target_state + target_state=valid_transition.target_state, ) - + self.assertTrue(valid_transition.validate_transition(context)) - + # Test multiple validation errors invalid_transition = AssignTaskWithConstraints( assignee_id=-1, # Invalid user ID max_concurrent_tasks=0, # Invalid max tasks - skill_requirements=["nonexistent_skill"] # Missing skill + skill_requirements=['nonexistent_skill'], # Missing skill ) - + with self.assertRaises(TransitionValidationError) as cm: invalid_transition.validate_transition(context) - + error = cm.exception error_msg = str(error) - + # Check all validation errors are included - self.assertIn("Invalid user ID", error_msg) - self.assertIn("Max concurrent tasks must be at least 1", error_msg) - self.assertIn("Missing required skills", error_msg) - + self.assertIn('Invalid user ID', error_msg) + self.assertIn('Max concurrent tasks must be at least 1', error_msg) + self.assertIn('Missing required skills', error_msg) + # Check error context - self.assertIn("validation_errors", error.context) - self.assertEqual(len(error.context["validation_errors"]), 3) - self.assertEqual(error.context["assignee_id"], -1) - + self.assertIn('validation_errors', error.context) + self.assertEqual(len(error.context['validation_errors']), 3) + self.assertEqual(error.context['assignee_id'], -1) + # Test authentication requirement context_no_user = TransitionContext( entity=self.task, current_user=None, # No user current_state=TaskStateChoices.CREATED, - target_state=valid_transition.target_state + target_state=valid_transition.target_state, ) - + with self.assertRaises(TransitionValidationError) as cm: valid_transition.validate_transition(context_no_user) - - self.assertIn("User authentication required", str(cm.exception)) \ No newline at end of file + + self.assertIn('User authentication required', str(cm.exception)) diff --git a/label_studio/fsm/tests/test_performance_concurrency.py b/label_studio/fsm/tests/test_performance_concurrency.py index 624e5bb9e44b..bcfc5993dfd8 100644 --- a/label_studio/fsm/tests/test_performance_concurrency.py +++ b/label_studio/fsm/tests/test_performance_concurrency.py @@ -6,309 +6,293 @@ production FSM systems. """ -import pytest import threading import time from concurrent.futures import ThreadPoolExecutor, as_completed -from datetime import datetime, timedelta -from unittest.mock import Mock, patch -from typing import Dict, Any, List -import statistics +from datetime import datetime +from typing import Any, Dict +from unittest.mock import Mock from django.test import TestCase, TransactionTestCase +from fsm.transitions import BaseTransition, TransitionContext, TransitionValidationError, transition_registry from pydantic import Field -from fsm.transitions import ( - BaseTransition, - TransitionContext, - TransitionValidationError, - transition_registry, - register_transition -) -from fsm.transition_utils import ( - execute_transition, - get_available_transitions, - TransitionBuilder, - validate_transition_data -) - class PerformanceTestTransition(BaseTransition): """Simple transition for performance testing""" - operation_id: int = Field(..., description="Operation identifier") - data_size: int = Field(1, description="Size of data to process") - + + operation_id: int = Field(..., description='Operation identifier') + data_size: int = Field(1, description='Size of data to process') + @property def target_state(self) -> str: - return "PROCESSED" - + return 'PROCESSED' + @classmethod def get_target_state(cls) -> str: - return "PROCESSED" - + return 'PROCESSED' + @classmethod def can_transition_from_state(cls, context: TransitionContext) -> bool: return True - + def validate_transition(self, context: TransitionContext) -> bool: # Simulate some validation work if self.data_size < 0: - raise TransitionValidationError("Invalid data size") + raise TransitionValidationError('Invalid data size') return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: # Simulate processing work return { - "operation_id": self.operation_id, - "data_size": self.data_size, - "processed_at": context.timestamp.isoformat(), - "processing_time_ms": 1 # Mock processing time + 'operation_id': self.operation_id, + 'data_size': self.data_size, + 'processed_at': context.timestamp.isoformat(), + 'processing_time_ms': 1, # Mock processing time } class ConcurrencyTestTransition(BaseTransition): """Transition for testing concurrent access patterns""" - thread_id: int = Field(..., description="Thread identifier") - shared_counter: int = Field(0, description="Shared counter for testing") - sleep_duration: float = Field(0.0, description="Simulate processing delay") - execution_order: list = Field(default_factory=list, description="Track execution order") - + + thread_id: int = Field(..., description='Thread identifier') + shared_counter: int = Field(0, description='Shared counter for testing') + sleep_duration: float = Field(0.0, description='Simulate processing delay') + execution_order: list = Field(default_factory=list, description='Track execution order') + @property def target_state(self) -> str: - return f"PROCESSED_BY_THREAD_{self.thread_id}" - + return f'PROCESSED_BY_THREAD_{self.thread_id}' + @classmethod def get_target_state(cls) -> str: - return "PROCESSED_BY_THREAD_0" # Default for class-level queries - + return 'PROCESSED_BY_THREAD_0' # Default for class-level queries + @classmethod def can_transition_from_state(cls, context: TransitionContext) -> bool: return True - + def validate_transition(self, context: TransitionContext) -> bool: # Record validation timing for concurrency analysis - self.execution_order.append(f"validate_{self.thread_id}_{time.time()}") + self.execution_order.append(f'validate_{self.thread_id}_{time.time()}') return True - + def transition(self, context: TransitionContext) -> Dict[str, Any]: # Record transition timing - self.execution_order.append(f"transition_{self.thread_id}_{time.time()}") - + self.execution_order.append(f'transition_{self.thread_id}_{time.time()}') + # Simulate some processing delay if self.sleep_duration > 0: time.sleep(self.sleep_duration) - + return { - "thread_id": self.thread_id, - "shared_counter": self.shared_counter, - "execution_order": self.execution_order.copy(), - "processed_at": context.timestamp.isoformat() + 'thread_id': self.thread_id, + 'shared_counter': self.shared_counter, + 'execution_order': self.execution_order.copy(), + 'processed_at': context.timestamp.isoformat(), } class PerformanceTests(TestCase): """ Performance tests for the declarative transition system. - + These tests measure execution time, memory usage patterns, and scalability characteristics. """ - + def setUp(self): self.mock_entity = Mock() self.mock_entity.pk = 1 self.mock_entity._meta.model_name = 'test_entity' - + self.mock_user = Mock() self.mock_user.id = 123 - + # Clear registry for clean tests transition_registry._transitions.clear() transition_registry.register('test_entity', 'performance_test', PerformanceTestTransition) - + def test_single_transition_performance(self): """ PERFORMANCE TEST: Measure single transition execution time - + Validates that individual transitions execute within acceptable time limits. """ - + transition = PerformanceTestTransition(operation_id=1, data_size=1000) - + context = TransitionContext( entity=self.mock_entity, current_user=self.mock_user, - current_state="CREATED", - target_state=transition.target_state + current_state='CREATED', + target_state=transition.target_state, ) - + # Measure validation performance start_time = time.perf_counter() result = transition.validate_transition(context) validation_time = time.perf_counter() - start_time - + self.assertTrue(result) self.assertLess(validation_time, 0.001) # Should be under 1ms - + # Measure transition execution performance start_time = time.perf_counter() transition_data = transition.transition(context) execution_time = time.perf_counter() - start_time - + self.assertIsInstance(transition_data, dict) self.assertLess(execution_time, 0.001) # Should be under 1ms - + # Measure total workflow performance start_time = time.perf_counter() transition.context = context transition.validate_transition(context) transition.transition(context) total_time = time.perf_counter() - start_time - + self.assertLess(total_time, 0.005) # Total should be under 5ms - + def test_batch_transition_performance(self): """ PERFORMANCE TEST: Measure batch transition creation and validation - + Tests performance when creating many transition instances rapidly. """ - + batch_size = 1000 - + # Test batch creation performance start_time = time.perf_counter() transitions = [] - + for i in range(batch_size): transition = PerformanceTestTransition(operation_id=i, data_size=i * 10) transitions.append(transition) - + creation_time = time.perf_counter() - start_time creation_time_per_item = creation_time / batch_size - + self.assertEqual(len(transitions), batch_size) self.assertLess(creation_time_per_item, 0.001) # Under 1ms per transition - + # Test batch validation performance context = TransitionContext( - entity=self.mock_entity, - current_user=self.mock_user, - current_state="CREATED", - target_state="PROCESSED" + entity=self.mock_entity, current_user=self.mock_user, current_state='CREATED', target_state='PROCESSED' ) - + start_time = time.perf_counter() validation_results = [] - + for transition in transitions: result = transition.validate_transition(context) validation_results.append(result) - + validation_time = time.perf_counter() - start_time validation_time_per_item = validation_time / batch_size - + self.assertTrue(all(validation_results)) self.assertLess(validation_time_per_item, 0.001) # Under 1ms per validation self.assertLess(validation_time, 0.5) # Total batch under 500ms - + def test_registry_performance(self): """ PERFORMANCE TEST: Registry operations under load - + Tests the performance of registry lookups and registrations. """ - + # Test registry lookup performance lookup_count = 10000 - + start_time = time.perf_counter() - + for i in range(lookup_count): retrieved_class = transition_registry.get_transition('test_entity', 'performance_test') - + lookup_time = time.perf_counter() - start_time lookup_time_per_operation = lookup_time / lookup_count - + self.assertEqual(retrieved_class, PerformanceTestTransition) self.assertLess(lookup_time_per_operation, 0.0001) # Under 0.1ms per lookup - + # Test registry registration performance registration_count = 1000 - + start_time = time.perf_counter() - + for i in range(registration_count): entity_name = f'entity_{i}' transition_name = f'transition_{i}' transition_registry.register(entity_name, transition_name, PerformanceTestTransition) - + registration_time = time.perf_counter() - start_time registration_time_per_operation = registration_time / registration_count - + self.assertLess(registration_time_per_operation, 0.001) # Under 1ms per registration - + # Verify registrations worked test_class = transition_registry.get_transition('entity_500', 'transition_500') self.assertEqual(test_class, PerformanceTestTransition) - + def test_pydantic_validation_performance(self): """ PERFORMANCE TEST: Pydantic validation performance - + Measures the overhead of Pydantic validation in transitions. """ - + # Test valid data performance - valid_data = {"operation_id": 123, "data_size": 1000} + valid_data = {'operation_id': 123, 'data_size': 1000} validation_count = 10000 - + start_time = time.perf_counter() - + for i in range(validation_count): - transition = PerformanceTestTransition(**valid_data) - + PerformanceTestTransition(**valid_data) + validation_time = time.perf_counter() - start_time validation_time_per_item = validation_time / validation_count - + self.assertLess(validation_time_per_item, 0.001) # Under 1ms per validation - + # Test validation error performance - invalid_data = {"operation_id": "invalid", "data_size": -1} + invalid_data = {'operation_id': 'invalid', 'data_size': -1} error_count = 1000 - + start_time = time.perf_counter() errors = [] - + for i in range(error_count): try: PerformanceTestTransition(**invalid_data) except Exception as e: errors.append(e) - + error_time = time.perf_counter() - start_time error_time_per_item = error_time / error_count - + self.assertEqual(len(errors), error_count) self.assertLess(error_time_per_item, 0.01) # Under 10ms per error (errors are slower) - + def test_memory_usage_patterns(self): """ PERFORMANCE TEST: Memory usage analysis - + Tests memory usage patterns for transition instances and contexts. """ - + import sys - + # Measure base memory usage base_transitions = [] for i in range(100): transition = PerformanceTestTransition(operation_id=i, data_size=i) base_transitions.append(transition) - + base_size = sys.getsizeof(base_transitions[0]) - + # Test memory usage with complex data complex_transitions = [] for i in range(100): @@ -317,23 +301,23 @@ def test_memory_usage_patterns(self): context = TransitionContext( entity=self.mock_entity, current_user=self.mock_user, - current_state="CREATED", + current_state='CREATED', target_state=transition.target_state, - metadata={"large_data": "x" * 1000} # Add some bulk + metadata={'large_data': 'x' * 1000}, # Add some bulk ) transition.context = context complex_transitions.append(transition) - + complex_size = sys.getsizeof(complex_transitions[0]) - + # Memory usage should be reasonable memory_overhead = complex_size - base_size self.assertLess(memory_overhead, 10000) # Under 10KB overhead per transition - + # Clean up contexts to test garbage collection for transition in complex_transitions: transition.context = None - + # Verify memory can be reclaimed (simplified test) self.assertIsNone(complex_transitions[0].context) @@ -341,151 +325,149 @@ def test_memory_usage_patterns(self): class ConcurrencyTests(TransactionTestCase): """ Concurrency tests for the declarative transition system. - + These tests validate thread safety and concurrent execution patterns that are critical for production systems. """ - + def setUp(self): self.mock_entity = Mock() self.mock_entity.pk = 1 self.mock_entity._meta.model_name = 'test_entity' - + self.mock_user = Mock() self.mock_user.id = 123 - + # Clear registry for clean tests transition_registry._transitions.clear() transition_registry.register('test_entity', 'concurrency_test', ConcurrencyTestTransition) - + def test_concurrent_transition_creation(self): """ CONCURRENCY TEST: Thread-safe transition instance creation - + Validates that multiple threads can create transition instances concurrently without conflicts. """ - + thread_count = 10 transitions_per_thread = 100 all_transitions = [] thread_results = [] - + def create_transitions(thread_id): """Worker function to create transitions in a thread""" local_transitions = [] for i in range(transitions_per_thread): transition = ConcurrencyTestTransition( - thread_id=thread_id, - shared_counter=i, - sleep_duration=0.001 # Small delay to increase contention + thread_id=thread_id, shared_counter=i, sleep_duration=0.001 # Small delay to increase contention ) local_transitions.append(transition) return local_transitions - + # Execute concurrent creation with ThreadPoolExecutor(max_workers=thread_count) as executor: futures = [] for thread_id in range(thread_count): future = executor.submit(create_transitions, thread_id) futures.append(future) - + for future in as_completed(futures): thread_transitions = future.result() thread_results.append(thread_transitions) all_transitions.extend(thread_transitions) - + # Validate results total_expected = thread_count * transitions_per_thread self.assertEqual(len(all_transitions), total_expected) - + # Check thread separation thread_ids = [t.thread_id for t in all_transitions] unique_threads = set(thread_ids) self.assertEqual(len(unique_threads), thread_count) - + # Validate each thread created correct number of transitions for thread_id in range(thread_count): thread_transitions = [t for t in all_transitions if t.thread_id == thread_id] self.assertEqual(len(thread_transitions), transitions_per_thread) - + def test_concurrent_transition_execution(self): """ CONCURRENCY TEST: Concurrent transition execution - + Tests that multiple transitions can be executed concurrently without race conditions in the execution logic. """ - + thread_count = 5 execution_results = [] - + def execute_transition(thread_id): """Worker function to execute a transition""" transition = ConcurrencyTestTransition( thread_id=thread_id, shared_counter=thread_id * 10, - sleep_duration=0.01 # Small delay to test concurrency + sleep_duration=0.01, # Small delay to test concurrency ) - + context = TransitionContext( entity=self.mock_entity, current_user=self.mock_user, - current_state="CREATED", + current_state='CREATED', target_state=transition.target_state, - timestamp=datetime.now() + timestamp=datetime.now(), ) - + # Execute validation and transition validation_result = transition.validate_transition(context) transition_data = transition.transition(context) - + return { - "thread_id": thread_id, - "validation_result": validation_result, - "transition_data": transition_data, - "execution_order": transition.execution_order + 'thread_id': thread_id, + 'validation_result': validation_result, + 'transition_data': transition_data, + 'execution_order': transition.execution_order, } - + # Execute concurrent transitions with ThreadPoolExecutor(max_workers=thread_count) as executor: futures = [] for thread_id in range(thread_count): future = executor.submit(execute_transition, thread_id) futures.append(future) - + for future in as_completed(futures): result = future.result() execution_results.append(result) - + # Validate results self.assertEqual(len(execution_results), thread_count) - + for result in execution_results: - self.assertTrue(result["validation_result"]) - self.assertIn("thread_id", result["transition_data"]) - self.assertIsInstance(result["execution_order"], list) - self.assertGreater(len(result["execution_order"]), 0) - + self.assertTrue(result['validation_result']) + self.assertIn('thread_id', result['transition_data']) + self.assertIsInstance(result['execution_order'], list) + self.assertGreater(len(result['execution_order']), 0) + # Check thread isolation - thread_ids = [r["transition_data"]["thread_id"] for r in execution_results] + thread_ids = [r['transition_data']['thread_id'] for r in execution_results] self.assertEqual(set(thread_ids), set(range(thread_count))) - + def test_registry_thread_safety(self): """ CONCURRENCY TEST: Registry thread safety - + Tests that the transition registry handles concurrent registration and lookup operations safely. """ - + thread_count = 10 operations_per_thread = 100 - + def registry_operations(thread_id): """Worker function for registry operations""" operations_completed = 0 - + for i in range(operations_per_thread): # Mix of registration and lookup operations if i % 3 == 0: @@ -494,7 +476,7 @@ def registry_operations(thread_id): transition_name = f'transition_{i}' transition_registry.register(entity_name, transition_name, ConcurrencyTestTransition) operations_completed += 1 - + elif i % 3 == 1: # Lookup existing transition try: @@ -503,7 +485,7 @@ def registry_operations(thread_id): operations_completed += 1 except Exception: pass - + else: # List operations try: @@ -512,242 +494,229 @@ def registry_operations(thread_id): operations_completed += 1 except Exception: pass - + return operations_completed - + # Execute concurrent registry operations with ThreadPoolExecutor(max_workers=thread_count) as executor: futures = [] for thread_id in range(thread_count): future = executor.submit(registry_operations, thread_id) futures.append(future) - + operation_counts = [] for future in as_completed(futures): count = future.result() operation_counts.append(count) - + # Validate no operations failed due to thread safety issues total_operations = sum(operation_counts) expected_minimum = thread_count * operations_per_thread * 0.9 # Allow some variance - + self.assertGreater(total_operations, expected_minimum) - + # Registry should be in consistent state entities = transition_registry.list_entities() self.assertIsInstance(entities, list) self.assertGreater(len(entities), thread_count) # Should have entities from all threads - + def test_context_isolation(self): """ CONCURRENCY TEST: Context isolation between threads - + Ensures that transition contexts remain isolated between concurrent executions and don't leak data. """ - + thread_count = 8 context_data = [] - + def context_isolation_test(thread_id): """Test context isolation in a thread""" # Create unique context data for this thread unique_data = { - "thread_specific_id": thread_id, - "random_data": f"thread_{thread_id}_data", - "timestamp": datetime.now().isoformat(), - "test_counter": thread_id * 1000 + 'thread_specific_id': thread_id, + 'random_data': f'thread_{thread_id}_data', + 'timestamp': datetime.now().isoformat(), + 'test_counter': thread_id * 1000, } - + transition = ConcurrencyTestTransition( thread_id=thread_id, shared_counter=thread_id, - sleep_duration=0.005 # Small delay to increase chance of interference + sleep_duration=0.005, # Small delay to increase chance of interference ) - + context = TransitionContext( entity=self.mock_entity, current_user=self.mock_user, - current_state="CREATED", + current_state='CREATED', target_state=transition.target_state, - metadata=unique_data + metadata=unique_data, ) - + # Set context on transition transition.context = context - + # Execute transition validation_result = transition.validate_transition(context) transition_data = transition.transition(context) - + # Retrieve context and verify isolation retrieved_context = transition.context - + return { - "thread_id": thread_id, - "original_metadata": unique_data, - "retrieved_metadata": retrieved_context.metadata, - "validation_result": validation_result, - "transition_data": transition_data + 'thread_id': thread_id, + 'original_metadata': unique_data, + 'retrieved_metadata': retrieved_context.metadata, + 'validation_result': validation_result, + 'transition_data': transition_data, } - + # Execute with high concurrency with ThreadPoolExecutor(max_workers=thread_count) as executor: futures = [] for thread_id in range(thread_count): future = executor.submit(context_isolation_test, thread_id) futures.append(future) - + for future in as_completed(futures): result = future.result() context_data.append(result) - + # Validate context isolation self.assertEqual(len(context_data), thread_count) - + for result in context_data: - thread_id = result["thread_id"] - original_metadata = result["original_metadata"] - retrieved_metadata = result["retrieved_metadata"] - + thread_id = result['thread_id'] + original_metadata = result['original_metadata'] + retrieved_metadata = result['retrieved_metadata'] + # Context should match exactly what was set for this thread - self.assertEqual(original_metadata["thread_specific_id"], thread_id) - self.assertEqual(retrieved_metadata["thread_specific_id"], thread_id) - self.assertEqual(original_metadata["random_data"], retrieved_metadata["random_data"]) - self.assertEqual(original_metadata["test_counter"], thread_id * 1000) - + self.assertEqual(original_metadata['thread_specific_id'], thread_id) + self.assertEqual(retrieved_metadata['thread_specific_id'], thread_id) + self.assertEqual(original_metadata['random_data'], retrieved_metadata['random_data']) + self.assertEqual(original_metadata['test_counter'], thread_id * 1000) + # Should not have data from other threads for other_result in context_data: - if other_result["thread_id"] != thread_id: + if other_result['thread_id'] != thread_id: self.assertNotEqual( - retrieved_metadata["thread_specific_id"], - other_result["original_metadata"]["thread_specific_id"] + retrieved_metadata['thread_specific_id'], + other_result['original_metadata']['thread_specific_id'], ) - + def test_stress_test_mixed_operations(self): """ STRESS TEST: Mixed operations under load - + Combines multiple types of operations under high concurrency to test overall system stability. """ - + duration_seconds = 2 # Short duration for CI thread_count = 6 - + # Shared statistics stats = { - "transitions_created": 0, - "validations_performed": 0, - "transitions_executed": 0, - "registry_lookups": 0, - "errors_encountered": 0 + 'transitions_created': 0, + 'validations_performed': 0, + 'transitions_executed': 0, + 'registry_lookups': 0, + 'errors_encountered': 0, } stats_lock = threading.Lock() - + def mixed_operations_worker(worker_id): """Worker that performs mixed operations""" local_stats = { - "transitions_created": 0, - "validations_performed": 0, - "transitions_executed": 0, - "registry_lookups": 0, - "errors_encountered": 0 + 'transitions_created': 0, + 'validations_performed': 0, + 'transitions_executed': 0, + 'registry_lookups': 0, + 'errors_encountered': 0, } - + end_time = time.time() + duration_seconds operation_counter = 0 - + while time.time() < end_time: try: operation_type = operation_counter % 4 - + if operation_type == 0: # Create transition - transition = ConcurrencyTestTransition( - thread_id=worker_id, - shared_counter=operation_counter - ) - local_stats["transitions_created"] += 1 - + transition = ConcurrencyTestTransition(thread_id=worker_id, shared_counter=operation_counter) + local_stats['transitions_created'] += 1 + elif operation_type == 1: # Validate transition - transition = ConcurrencyTestTransition( - thread_id=worker_id, - shared_counter=operation_counter - ) + transition = ConcurrencyTestTransition(thread_id=worker_id, shared_counter=operation_counter) context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state=transition.target_state + entity=self.mock_entity, current_state='CREATED', target_state=transition.target_state ) transition.validate_transition(context) - local_stats["validations_performed"] += 1 - + local_stats['validations_performed'] += 1 + elif operation_type == 2: # Execute transition - transition = ConcurrencyTestTransition( - thread_id=worker_id, - shared_counter=operation_counter - ) + transition = ConcurrencyTestTransition(thread_id=worker_id, shared_counter=operation_counter) context = TransitionContext( - entity=self.mock_entity, - current_state="CREATED", - target_state=transition.target_state + entity=self.mock_entity, current_state='CREATED', target_state=transition.target_state ) transition.transition(context) - local_stats["transitions_executed"] += 1 - + local_stats['transitions_executed'] += 1 + else: # Registry lookup found = transition_registry.get_transition('test_entity', 'concurrency_test') if found: - local_stats["registry_lookups"] += 1 - + local_stats['registry_lookups'] += 1 + operation_counter += 1 - - except Exception as e: - local_stats["errors_encountered"] += 1 - + + except Exception: + local_stats['errors_encountered'] += 1 + # Small yield to allow other threads time.sleep(0.0001) - + # Update shared statistics with stats_lock: for key in stats: stats[key] += local_stats[key] - + return local_stats - + # Execute stress test with ThreadPoolExecutor(max_workers=thread_count) as executor: futures = [] for worker_id in range(thread_count): future = executor.submit(mixed_operations_worker, worker_id) futures.append(future) - + worker_results = [] for future in as_completed(futures): result = future.result() worker_results.append(result) - + # Validate stress test results total_operations = sum( - stats["transitions_created"] + - stats["validations_performed"] + - stats["transitions_executed"] + - stats["registry_lookups"] + stats['transitions_created'] + + stats['validations_performed'] + + stats['transitions_executed'] + + stats['registry_lookups'] ) - + # Should have performed substantial work self.assertGreater(total_operations, thread_count * 10) - + # Error rate should be very low (< 1%) - error_rate = stats["errors_encountered"] / max(total_operations, 1) + error_rate = stats['errors_encountered'] / max(total_operations, 1) self.assertLess(error_rate, 0.01) - + # All operation types should have been performed - self.assertGreater(stats["transitions_created"], 0) - self.assertGreater(stats["validations_performed"], 0) - self.assertGreater(stats["transitions_executed"], 0) - self.assertGreater(stats["registry_lookups"], 0) \ No newline at end of file + self.assertGreater(stats['transitions_created'], 0) + self.assertGreater(stats['validations_performed'], 0) + self.assertGreater(stats['transitions_executed'], 0) + self.assertGreater(stats['registry_lookups'], 0) diff --git a/label_studio/fsm/transition_utils.py b/label_studio/fsm/transition_utils.py index 119b15701d79..d3fcbada5b13 100644 --- a/label_studio/fsm/transition_utils.py +++ b/label_studio/fsm/transition_utils.py @@ -5,38 +5,34 @@ the new Pydantic-based transition system with existing Label Studio code. """ -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type from django.db.models import Model -from .models import BaseState, get_state_model_for_entity +from .models import BaseState from .state_manager import StateManager from .transitions import BaseTransition, TransitionValidationError, transition_registry def execute_transition( - entity: Model, - transition_name: str, - transition_data: Dict[str, Any], - user=None, - **context_kwargs + entity: Model, transition_name: str, transition_data: Dict[str, Any], user=None, **context_kwargs ) -> BaseState: """ Execute a named transition on an entity. - + This is a convenience function that looks up the transition class and executes it with the provided data. - + Args: entity: The entity to transition transition_name: Name of the registered transition transition_data: Data for the transition (validated by Pydantic) user: User executing the transition **context_kwargs: Additional context data - + Returns: The newly created state record - + Raises: ValueError: If transition is not found TransitionValidationError: If transition validation fails @@ -48,43 +44,35 @@ def execute_transition( entity=entity, transition_data=transition_data, user=user, - **context_kwargs + **context_kwargs, ) -def execute_transition_instance( - entity: Model, - transition: BaseTransition, - user=None, - **context_kwargs -) -> BaseState: +def execute_transition_instance(entity: Model, transition: BaseTransition, user=None, **context_kwargs) -> BaseState: """ Execute a pre-created transition instance. - + Args: entity: The entity to transition transition: Instance of a transition class user: User executing the transition **context_kwargs: Additional context data - + Returns: The newly created state record """ return StateManager.execute_declarative_transition( - transition=transition, - entity=entity, - user=user, - **context_kwargs + transition=transition, entity=entity, user=user, **context_kwargs ) def get_available_transitions(entity: Model) -> Dict[str, Type[BaseTransition]]: """ Get all available transitions for an entity. - + Args: entity: The entity to get transitions for - + Returns: Dictionary mapping transition names to transition classes """ @@ -92,109 +80,100 @@ def get_available_transitions(entity: Model) -> Dict[str, Type[BaseTransition]]: return transition_registry.get_transitions_for_entity(entity_name) -def get_valid_transitions( - entity: Model, - user=None, - validate: bool = True -) -> Dict[str, Type[BaseTransition]]: +def get_valid_transitions(entity: Model, user=None, validate: bool = True) -> Dict[str, Type[BaseTransition]]: """ Get transitions that are valid for the entity's current state. - + Args: entity: The entity to check transitions for user: User context for validation validate: Whether to validate each transition (may be expensive) - + Returns: Dictionary mapping transition names to transition classes that are valid for the current state """ available = get_available_transitions(entity) - + if not validate: return available - + valid_transitions = {} - + for name, transition_class in available.items(): try: # Get current state information current_state_object = StateManager.get_current_state_object(entity) current_state = current_state_object.state if current_state_object else None - + # Build minimal context for validation from .transitions import TransitionContext + context = TransitionContext( entity=entity, current_user=user, current_state_object=current_state_object, current_state=current_state, target_state=transition_class.get_target_state(), - organization_id=getattr(entity, 'organization_id', None) + organization_id=getattr(entity, 'organization_id', None), ) - + # Use class-level validation that doesn't require an instance if transition_class.can_transition_from_state(context): valid_transitions[name] = transition_class - + except (TransitionValidationError, Exception): # Transition is not valid for current state/context continue - + return valid_transitions -def create_transition_from_dict( - transition_class: Type[BaseTransition], - data: Dict[str, Any] -) -> BaseTransition: +def create_transition_from_dict(transition_class: Type[BaseTransition], data: Dict[str, Any]) -> BaseTransition: """ Create a transition instance from a dictionary of data. - + This handles Pydantic validation and provides clear error messages. - + Args: transition_class: The transition class to instantiate data: Dictionary of transition data - + Returns: Validated transition instance - + Raises: ValueError: If data validation fails """ try: return transition_class(**data) except Exception as e: - raise ValueError(f"Failed to create {transition_class.__name__}: {e}") + raise ValueError(f'Failed to create {transition_class.__name__}: {e}') def get_transition_schema(transition_class: Type[BaseTransition]) -> Dict[str, Any]: """ Get the JSON schema for a transition class. - + Useful for generating API documentation or frontend forms. - + Args: transition_class: The transition class - + Returns: JSON schema dictionary """ return transition_class.model_json_schema() -def validate_transition_data( - transition_class: Type[BaseTransition], - data: Dict[str, Any] -) -> Dict[str, List[str]]: +def validate_transition_data(transition_class: Type[BaseTransition], data: Dict[str, Any]) -> Dict[str, List[str]]: """ Validate transition data without creating an instance. - + Args: transition_class: The transition class data: Data to validate - + Returns: Dictionary of field names to error messages (empty if valid) """ @@ -218,68 +197,63 @@ def validate_transition_data( def get_entity_state_flow(entity: Model) -> List[Dict[str, Any]]: """ Get a summary of the state flow for an entity type. - + This analyzes all registered transitions and builds a flow diagram. - + Args: entity: Example entity instance - + Returns: List of state flow information """ entity_name = entity._meta.model_name.lower() transitions = transition_registry.get_transitions_for_entity(entity_name) - + # Build state flow information states = set() flows = [] - + for transition_name, transition_class in transitions.items(): # Create instance to get target state try: transition = transition_class() target = transition.target_state states.add(target) - - flows.append({ - 'transition_name': transition_name, - 'transition_class': transition_class.__name__, - 'target_state': target, - 'description': transition_class.__doc__ or '', - 'fields': list(transition_class.model_fields.keys()) - }) + + flows.append( + { + 'transition_name': transition_name, + 'transition_class': transition_class.__name__, + 'target_state': target, + 'description': transition_class.__doc__ or '', + 'fields': list(transition_class.model_fields.keys()), + } + ) except Exception: continue - + return flows # Backward compatibility helpers -def transition_state_declarative( - entity: Model, - transition_name: str, - user=None, - **transition_data -) -> BaseState: + +def transition_state_declarative(entity: Model, transition_name: str, user=None, **transition_data) -> BaseState: """ Backward-compatible helper for transitioning state declaratively. - + This provides a similar interface to StateManager.transition_state but uses the declarative system. """ return execute_transition( - entity=entity, - transition_name=transition_name, - transition_data=transition_data, - user=user + entity=entity, transition_name=transition_name, transition_data=transition_data, user=user ) class TransitionBuilder: """ Builder class for constructing and executing transitions fluently. - + Example usage: result = (TransitionBuilder(entity) .transition('start_task') @@ -287,56 +261,56 @@ class TransitionBuilder: .by_user(request.user) .execute()) """ - + def __init__(self, entity: Model): self.entity = entity self._transition_name: Optional[str] = None self._transition_data: Dict[str, Any] = {} self._user = None self._context_data: Dict[str, Any] = {} - + def transition(self, name: str) -> 'TransitionBuilder': """Set the transition name""" self._transition_name = name return self - + def with_data(self, **data) -> 'TransitionBuilder': """Add transition data""" self._transition_data.update(data) return self - + def by_user(self, user) -> 'TransitionBuilder': """Set the executing user""" self._user = user return self - + def with_context(self, **context) -> 'TransitionBuilder': """Add context data""" self._context_data.update(context) return self - + def execute(self) -> BaseState: """Execute the configured transition""" if not self._transition_name: - raise ValueError("Transition name not specified") - + raise ValueError('Transition name not specified') + return execute_transition( entity=self.entity, transition_name=self._transition_name, transition_data=self._transition_data, user=self._user, - **self._context_data + **self._context_data, ) - + def validate(self) -> Dict[str, List[str]]: """Validate the configured transition without executing""" if not self._transition_name: - raise ValueError("Transition name not specified") - + raise ValueError('Transition name not specified') + entity_name = self.entity._meta.model_name.lower() transition_class = transition_registry.get_transition(entity_name, self._transition_name) - + if not transition_class: raise ValueError(f"Transition '{self._transition_name}' not found for entity '{entity_name}'") - - return validate_transition_data(transition_class, self._transition_data) \ No newline at end of file + + return validate_transition_data(transition_class, self._transition_data) diff --git a/label_studio/fsm/transitions.py b/label_studio/fsm/transitions.py index 01621d24cdcf..92bd35a544db 100644 --- a/label_studio/fsm/transitions.py +++ b/label_studio/fsm/transitions.py @@ -1,18 +1,18 @@ """ Declarative Pydantic-based transition system for FSM engine. -This module provides a framework for defining state transitions as first-class +This module provides a framework for defining state transitions as first-class Pydantic models with built-in validation, context passing, and middleware-like functionality for enhanced declarative state management. """ from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union +from typing import Any, Dict, Generic, Optional, Type, TypeVar from django.contrib.auth import get_user_model from django.db.models import Model -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from .models import BaseState @@ -26,36 +26,36 @@ class TransitionContext(BaseModel, Generic[EntityType, StateModelType]): """ Context object passed to all transitions containing middleware-like information. - + This provides access to current state, entity, user, and other contextual information needed for transition validation and execution. """ - + model_config = ConfigDict(arbitrary_types_allowed=True) - + # Core context information - entity: Any = Field(..., description="The entity being transitioned") - current_user: Optional[Any] = Field(None, description="User triggering the transition") - current_state_object: Optional[Any] = Field(None, description="Full current state object") - current_state: Optional[str] = Field(None, description="Current state as string") - target_state: str = Field(..., description="Target state for this transition") - + entity: Any = Field(..., description='The entity being transitioned') + current_user: Optional[Any] = Field(None, description='User triggering the transition') + current_state_object: Optional[Any] = Field(None, description='Full current state object') + current_state: Optional[str] = Field(None, description='Current state as string') + target_state: str = Field(..., description='Target state for this transition') + # Timing and metadata - timestamp: datetime = Field(default_factory=datetime.now, description="When transition was initiated") - transition_name: Optional[str] = Field(None, description="Name of the transition method") - + timestamp: datetime = Field(default_factory=datetime.now, description='When transition was initiated') + transition_name: Optional[str] = Field(None, description='Name of the transition method') + # Additional context data - request_data: Dict[str, Any] = Field(default_factory=dict, description="Additional request/context data") - metadata: Dict[str, Any] = Field(default_factory=dict, description="Transition-specific metadata") - + request_data: Dict[str, Any] = Field(default_factory=dict, description='Additional request/context data') + metadata: Dict[str, Any] = Field(default_factory=dict, description='Transition-specific metadata') + # Organizational context - organization_id: Optional[int] = Field(None, description="Organization context for the transition") - + organization_id: Optional[int] = Field(None, description='Organization context for the transition') + @property def has_current_state(self) -> bool: """Check if entity has a current state""" return self.current_state is not None - + @property def is_initial_transition(self) -> bool: """Check if this is the first state transition for the entity""" @@ -64,7 +64,7 @@ def is_initial_transition(self) -> bool: class TransitionValidationError(Exception): """Exception raised when transition validation fails""" - + def __init__(self, message: str, context: Optional[Dict[str, Any]] = None): super().__init__(message) self.context = context or {} @@ -73,24 +73,24 @@ def __init__(self, message: str, context: Optional[Dict[str, Any]] = None): class BaseTransition(BaseModel, ABC, Generic[EntityType, StateModelType]): """ Abstract base class for all declarative state transitions. - + This provides the framework for implementing transitions as first-class Pydantic models with built-in validation, context handling, and execution logic. - + Example usage: class StartTaskTransition(BaseTransition[Task, TaskState]): assigned_user_id: int = Field(..., description="User assigned to start the task") estimated_duration: Optional[int] = Field(None, description="Estimated completion time in hours") - + @property def target_state(self) -> str: return TaskStateChoices.IN_PROGRESS - + def validate_transition(self, context: TransitionContext[Task, TaskState]) -> bool: if context.current_state == TaskStateChoices.COMPLETED: raise TransitionValidationError("Cannot start an already completed task") return True - + def transition(self, context: TransitionContext[Task, TaskState]) -> Dict[str, Any]: return { "assigned_user_id": self.assigned_user_id, @@ -98,172 +98,166 @@ def transition(self, context: TransitionContext[Task, TaskState]) -> Dict[str, A "started_at": context.timestamp.isoformat() } """ - - model_config = ConfigDict( - arbitrary_types_allowed=True, - validate_assignment=True, - use_enum_values=True - ) - + + model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True, use_enum_values=True) + def __init__(self, **data): super().__init__(**data) self.__context: Optional[TransitionContext[EntityType, StateModelType]] = None - + @property def context(self) -> Optional[TransitionContext[EntityType, StateModelType]]: """Access the current transition context""" return getattr(self, '_BaseTransition__context', None) - + @context.setter def context(self, value: TransitionContext[EntityType, StateModelType]): """Set the transition context""" self.__context = value - + @property @abstractmethod def target_state(self) -> str: """ The target state this transition leads to. - + Returns: String representation of the target state """ pass - + @property def transition_name(self) -> str: """ Name of this transition for audit purposes. - + Defaults to the class name in snake_case. """ class_name = self.__class__.__name__ # Convert CamelCase to snake_case - result = "" + result = '' for i, char in enumerate(class_name): if char.isupper() and i > 0: - result += "_" + result += '_' result += char.lower() return result - + @classmethod def get_target_state(cls) -> Optional[str]: """ Get the target state for this transition class without creating an instance. - + Override this in subclasses where the target state is known at the class level. - + Returns: The target state name, or None if it depends on instance data """ return None - + @classmethod def can_transition_from_state(cls, context: TransitionContext[EntityType, StateModelType]) -> bool: """ Class-level validation for whether this transition type is allowed from the current state. - + This method checks if the transition is structurally valid (e.g., allowed state transitions) without needing the actual transition data. Override this to implement state-based rules. - + Args: context: The transition context containing entity, user, and state information - + Returns: True if transition type is allowed from current state, False otherwise """ return True - + def validate_transition(self, context: TransitionContext[EntityType, StateModelType]) -> bool: """ Validate whether this specific transition instance can be performed. - + This method validates both the transition type (via can_transition_from_state) and the specific transition data. Override to add data-specific validation. - + Args: context: The transition context containing entity, user, and state information - + Returns: True if transition is valid, False otherwise - + Raises: TransitionValidationError: If transition validation fails with specific reason """ # First check if this transition type is allowed if not self.can_transition_from_state(context): return False - + # Then perform instance-specific validation return True - + def pre_transition_hook(self, context: TransitionContext[EntityType, StateModelType]) -> None: """ Hook called before the transition is executed. - + Use this for any setup or preparation needed before state change. Override in subclasses as needed. - + Args: context: The transition context """ pass - + @abstractmethod def transition(self, context: TransitionContext[EntityType, StateModelType]) -> Dict[str, Any]: """ Execute the transition and return context data for the state record. - + This is the core method that implements the transition logic. Must be implemented by all concrete transition classes. - + Args: context: The transition context containing all necessary information - + Returns: Dictionary of context data to be stored with the state record - + Raises: TransitionValidationError: If transition cannot be completed """ pass - + def post_transition_hook( - self, - context: TransitionContext[EntityType, StateModelType], - state_record: StateModelType + self, context: TransitionContext[EntityType, StateModelType], state_record: StateModelType ) -> None: """ Hook called after the transition has been successfully executed. - + Use this for any cleanup, notifications, or side effects after state change. Override in subclasses as needed. - + Args: context: The transition context state_record: The newly created state record """ pass - + def get_reason(self, context: TransitionContext[EntityType, StateModelType]) -> str: """ Get a human-readable reason for this transition. - + Override in subclasses to provide more specific reasons. - + Args: context: The transition context - + Returns: Human-readable reason string """ - user_info = f"by {context.current_user}" if context.current_user else "automatically" - return f"{self.__class__.__name__} executed {user_info}" - + user_info = f'by {context.current_user}' if context.current_user else 'automatically' + return f'{self.__class__.__name__} executed {user_info}' + def execute(self, context: TransitionContext[EntityType, StateModelType]) -> StateModelType: """ Execute the complete transition workflow. - + This orchestrates the entire transition process: 1. Set context on the transition instance 2. Validate the transition @@ -271,61 +265,61 @@ def execute(self, context: TransitionContext[EntityType, StateModelType]) -> Sta 4. Perform the actual transition 5. Create the state record 6. Execute post-transition hooks - + Args: context: The transition context - + Returns: The newly created state record - + Raises: TransitionValidationError: If validation fails Exception: If transition execution fails """ # Set context for access during transition self.context = context - + # Update context with transition name context.transition_name = self.transition_name - + try: # Validate transition if not self.validate_transition(context): raise TransitionValidationError( - f"Transition validation failed for {self.transition_name}", - {"current_state": context.current_state, "target_state": self.target_state} + f'Transition validation failed for {self.transition_name}', + {'current_state': context.current_state, 'target_state': self.target_state}, ) - + # Pre-transition hook self.pre_transition_hook(context) - + # Execute the transition logic transition_data = self.transition(context) - + # Create the state record through StateManager from .state_manager import StateManager - + success = StateManager.transition_state( entity=context.entity, new_state=self.target_state, transition_name=self.transition_name, user=context.current_user, context=transition_data, - reason=self.get_reason(context) + reason=self.get_reason(context), ) - + if not success: - raise TransitionValidationError(f"Failed to create state record for {self.transition_name}") - + raise TransitionValidationError(f'Failed to create state record for {self.transition_name}') + # Get the newly created state record state_record = StateManager.get_current_state_object(context.entity) - + # Post-transition hook self.post_transition_hook(context, state_record) - + return state_record - - except Exception as e: + + except Exception: # Clear context on error self.context = None raise @@ -334,23 +328,18 @@ def execute(self, context: TransitionContext[EntityType, StateModelType]) -> Sta class TransitionRegistry: """ Registry for managing declarative transitions. - + Provides a centralized way to register, discover, and execute transitions for different entity types and state models. """ - + def __init__(self): self._transitions: Dict[str, Dict[str, Type[BaseTransition]]] = {} - - def register( - self, - entity_name: str, - transition_name: str, - transition_class: Type[BaseTransition] - ): + + def register(self, entity_name: str, transition_name: str, transition_class: Type[BaseTransition]): """ Register a transition class for an entity. - + Args: entity_name: Name of the entity type (e.g., 'task', 'annotation') transition_name: Name of the transition (e.g., 'start_task', 'submit_annotation') @@ -358,42 +347,38 @@ def register( """ if entity_name not in self._transitions: self._transitions[entity_name] = {} - + self._transitions[entity_name][transition_name] = transition_class - - def get_transition( - self, - entity_name: str, - transition_name: str - ) -> Optional[Type[BaseTransition]]: + + def get_transition(self, entity_name: str, transition_name: str) -> Optional[Type[BaseTransition]]: """ Get a registered transition class. - + Args: entity_name: Name of the entity type transition_name: Name of the transition - + Returns: The transition class if found, None otherwise """ return self._transitions.get(entity_name, {}).get(transition_name) - + def get_transitions_for_entity(self, entity_name: str) -> Dict[str, Type[BaseTransition]]: """ Get all registered transitions for an entity type. - + Args: entity_name: Name of the entity type - + Returns: Dictionary mapping transition names to transition classes """ return self._transitions.get(entity_name, {}).copy() - + def list_entities(self) -> list[str]: """Get a list of all registered entity types.""" return list(self._transitions.keys()) - + def execute_transition( self, entity_name: str, @@ -401,11 +386,11 @@ def execute_transition( entity: Model, transition_data: Dict[str, Any], user: Optional[User] = None, - **context_kwargs + **context_kwargs, ) -> StateModelType: """ Execute a registered transition. - + Args: entity_name: Name of the entity type transition_name: Name of the transition @@ -413,10 +398,10 @@ def execute_transition( transition_data: Data for the transition (will be validated by Pydantic) user: User executing the transition **context_kwargs: Additional context data - + Returns: The newly created state record - + Raises: ValueError: If transition is not found TransitionValidationError: If transition validation fails @@ -424,15 +409,16 @@ def execute_transition( transition_class = self.get_transition(entity_name, transition_name) if not transition_class: raise ValueError(f"Transition '{transition_name}' not found for entity '{entity_name}'") - + # Create transition instance with provided data transition = transition_class(**transition_data) - + # Get current state information from .state_manager import StateManager + current_state_object = StateManager.get_current_state_object(entity) current_state = current_state_object.state if current_state_object else None - + # Build transition context context = TransitionContext( entity=entity, @@ -441,9 +427,9 @@ def execute_transition( current_state=current_state, target_state=transition.target_state, organization_id=getattr(entity, 'organization_id', None), - **context_kwargs + **context_kwargs, ) - + # Execute the transition return transition.execute(context) @@ -455,16 +441,17 @@ def execute_transition( def register_transition(entity_name: str, transition_name: str = None): """ Decorator to register a transition class. - + Args: entity_name: Name of the entity type transition_name: Name of the transition (defaults to class name in snake_case) - + Example: @register_transition('task', 'start_task') class StartTaskTransition(BaseTransition[Task, TaskState]): # ... implementation """ + def decorator(transition_class: Type[BaseTransition]) -> Type[BaseTransition]: name = transition_name if name is None: @@ -472,15 +459,15 @@ def decorator(transition_class: Type[BaseTransition]) -> Type[BaseTransition]: class_name = transition_class.__name__ if class_name.endswith('Transition'): class_name = class_name[:-10] # Remove 'Transition' suffix - + # Convert CamelCase to snake_case - name = "" + name = '' for i, char in enumerate(class_name): if char.isupper() and i > 0: - name += "_" + name += '_' name += char.lower() - + transition_registry.register(entity_name, name, transition_class) return transition_class - - return decorator \ No newline at end of file + + return decorator From 46e1597eae6d55ea9de71a39f16e9115534d8600 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 12:27:40 -0500 Subject: [PATCH 14/83] fix fsm concurrency tests --- label_studio/fsm/tests/test_performance_concurrency.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/label_studio/fsm/tests/test_performance_concurrency.py b/label_studio/fsm/tests/test_performance_concurrency.py index bcfc5993dfd8..896b4fe65556 100644 --- a/label_studio/fsm/tests/test_performance_concurrency.py +++ b/label_studio/fsm/tests/test_performance_concurrency.py @@ -702,10 +702,12 @@ def mixed_operations_worker(worker_id): # Validate stress test results total_operations = sum( - stats['transitions_created'] - + stats['validations_performed'] - + stats['transitions_executed'] - + stats['registry_lookups'] + [ + stats['transitions_created'], + stats['validations_performed'], + stats['transitions_executed'], + stats['registry_lookups'], + ] ) # Should have performed substantial work From 21f6558dc58d872dfff7ca77e091db11c8cb5b21 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 13:12:21 -0500 Subject: [PATCH 15/83] fix remaining fsm tests --- label_studio/fsm/api.py | 20 ++++++++++--------- label_studio/fsm/models.py | 7 +++++++ label_studio/fsm/state_manager.py | 10 +++++++++- .../fsm/tests/test_fsm_integration.py | 19 ++++++++++++++++++ label_studio/fsm/tests/test_uuid7_utils.py | 7 ++++--- label_studio/fsm/utils.py | 5 +++-- 6 files changed, 53 insertions(+), 15 deletions(-) diff --git a/label_studio/fsm/api.py b/label_studio/fsm/api.py index c57e14722ed3..81b423b60869 100644 --- a/label_studio/fsm/api.py +++ b/label_studio/fsm/api.py @@ -10,10 +10,9 @@ from django.shortcuts import get_object_or_404 from rest_framework import status, viewsets from rest_framework.decorators import action +from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response -from label_studio.core.permissions import AllPermissions - from .models import get_state_model_for_entity from .serializers import StateHistorySerializer, StateTransitionSerializer from .state_manager import get_state_manager @@ -31,7 +30,7 @@ class FSMViewSet(viewsets.ViewSet): - Trigger state transitions """ - permission_classes = [AllPermissions] + permission_classes = [IsAuthenticated] def _get_entity_and_state_model(self, entity_type: str, entity_id: int): """Helper to get entity instance and its state model""" @@ -82,9 +81,10 @@ def current_state(self, request, entity_type=None, entity_id=None): "entity_id": 123 } """ - try: - entity, state_model = self._get_entity_and_state_model(entity_type, int(entity_id)) + # Let Http404 from _get_entity_and_state_model pass through + entity, state_model = self._get_entity_and_state_model(entity_type, int(entity_id)) + try: # Get current state using the configured state manager StateManager = get_state_manager() current_state = StateManager.get_current_state(entity) @@ -130,9 +130,10 @@ def state_history(self, request, entity_type=None, entity_id=None): ] } """ - try: - entity, state_model = self._get_entity_and_state_model(entity_type, int(entity_id)) + # Let Http404 from _get_entity_and_state_model pass through + entity, state_model = self._get_entity_and_state_model(entity_type, int(entity_id)) + try: # Get query parameters limit = min(int(request.query_params.get('limit', 100)), 1000) # Max 1000 include_context = request.query_params.get('include_context', 'false').lower() == 'true' @@ -181,9 +182,10 @@ def transition_state(self, request, entity_type=None, entity_id=None): "entity_id": 123 } """ - try: - entity, state_model = self._get_entity_and_state_model(entity_type, int(entity_id)) + # Let Http404 from _get_entity_and_state_model pass through + entity, state_model = self._get_entity_and_state_model(entity_type, int(entity_id)) + try: # Validate request data serializer = StateTransitionSerializer(data=request.data) serializer.is_valid(raise_exception=True) diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index bd72f879f0f0..af9bb4637cf2 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -304,6 +304,13 @@ class Meta: ] ordering = ['-id'] + @classmethod + def get_denormalized_fields(cls, entity): + """Get denormalized fields for ProjectState creation""" + return { + 'created_by_id': entity.created_by.id if entity.created_by else None, + } + @property def is_terminal_state(self) -> bool: """Check if this is a terminal project state""" diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index 13aa6c3d856e..1670bc87a92a 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -193,6 +193,13 @@ def transition_state( if hasattr(state_model, 'get_denormalized_fields'): denormalized_fields = state_model.get_denormalized_fields(entity) + # Get organization from user's active organization + organization_id = ( + user.active_organization.id + if user and hasattr(user, 'active_organization') and user.active_organization + else None + ) + new_state_record = state_model.objects.create( **{entity._meta.model_name: entity}, state=new_state, @@ -201,6 +208,7 @@ def transition_state( triggered_by=user, context_data=context or {}, reason=reason, + organization_id=organization_id, **denormalized_fields, ) @@ -274,7 +282,7 @@ def invalidate_cache(cls, entity: Model): @classmethod def warm_cache(cls, entities: List[Model]): """ - Warm cache with current states for a list of entities. + invalidate_cacheWarm cache with current states for a list of entities. Basic implementation that can be optimized by Enterprise with bulk queries and advanced caching strategies. diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py index cbfcb405ad93..648735a8b620 100644 --- a/label_studio/fsm/tests/test_fsm_integration.py +++ b/label_studio/fsm/tests/test_fsm_integration.py @@ -26,6 +26,11 @@ def setUp(self): self.project = Project.objects.create(title='Test Project', created_by=self.user) self.task = Task.objects.create(project=self.project, data={'text': 'test'}) + # Clear cache to ensure tests start with clean state + from django.core.cache import cache + + cache.clear() + def test_task_state_creation(self): """Test TaskState creation and basic functionality""" task_state = TaskState.objects.create( @@ -108,6 +113,11 @@ def setUp(self): self.task = Task.objects.create(project=self.project, data={'text': 'test'}) self.StateManager = get_state_manager() + # Clear cache to ensure tests start with clean state + from django.core.cache import cache + + cache.clear() + def test_get_current_state_empty(self): """Test getting current state when no states exist""" current_state = self.StateManager.get_current_state(self.task) @@ -178,6 +188,10 @@ def test_get_state_history(self): states = [h.state for h in history] self.assertEqual(states, ['COMPLETED', 'IN_PROGRESS', 'CREATED']) + print(history) + ids = [str(h.id) for h in history] + print(ids) + # Check previous states are set correctly self.assertIsNone(history[2].previous_state) # First state has no previous self.assertEqual(history[1].previous_state, 'CREATED') @@ -211,6 +225,11 @@ def setUp(self): self.task = Task.objects.create(project=self.project, data={'text': 'test'}) self.client.force_authenticate(user=self.user) + # Clear cache to ensure tests start with clean state + from django.core.cache import cache + + cache.clear() + # Create initial state StateManager = get_state_manager() StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) diff --git a/label_studio/fsm/tests/test_uuid7_utils.py b/label_studio/fsm/tests/test_uuid7_utils.py index 138bf86a05c1..af6abe7d6956 100644 --- a/label_studio/fsm/tests/test_uuid7_utils.py +++ b/label_studio/fsm/tests/test_uuid7_utils.py @@ -85,12 +85,13 @@ def test_uuid7_time_range(self): # Start should be less than end self.assertLess(start_uuid.int, end_uuid.int) - # Timestamps should match input times + # Timestamps should match input times (with 1ms buffer tolerance) start_extracted = timestamp_from_uuid7(start_uuid) end_extracted = timestamp_from_uuid7(end_uuid) - self.assertLess(abs((start_extracted - start_time).total_seconds()), 0.001) - self.assertLess(abs((end_extracted - end_time).total_seconds()), 0.001) + # Account for 1ms buffer added in uuid7_time_range + self.assertLess(abs((start_extracted - start_time).total_seconds()), 0.002) + self.assertLess(abs((end_extracted - end_time).total_seconds()), 0.002) def test_uuid7_time_range_default_end(self): """Test UUID7 time range with default end time (now)""" diff --git a/label_studio/fsm/utils.py b/label_studio/fsm/utils.py index 4673968494b7..2aa6ae733a43 100644 --- a/label_studio/fsm/utils.py +++ b/label_studio/fsm/utils.py @@ -77,8 +77,9 @@ def uuid7_time_range(start_time: datetime, end_time: Optional[datetime] = None) if end_time is None: end_time = datetime.now(timezone.utc) - start_timestamp_ms = int(start_time.timestamp() * 1000) - end_timestamp_ms = int(end_time.timestamp() * 1000) + # Add a small buffer to account for timing precision issues + start_timestamp_ms = int(start_time.timestamp() * 1000) - 1 # 1ms buffer before + end_timestamp_ms = int(end_time.timestamp() * 1000) + 1 # 1ms buffer after # Create UUID7 with specific timestamp using proper bit layout # UUID7 format: timestamp_ms(48) + ver(4) + rand_a(12) + var(2) + rand_b(62) From 7fa40b2cb60f76080b23290b806bbd96ac4613ff Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 14:17:25 -0500 Subject: [PATCH 16/83] removing implementation details of fsm to break up the PR --- label_studio/core/urls.py | 1 - label_studio/fsm/README.md | 309 +++--- label_studio/fsm/api.py | 26 +- label_studio/fsm/migrations/0001_initial.py | 386 ------- label_studio/fsm/models.py | 166 +-- label_studio/fsm/state_choices.py | 98 +- .../fsm/tests/test_declarative_transitions.py | 994 ++++-------------- .../fsm/tests/test_fsm_integration.py | 303 ------ .../tests/test_integration_django_models.py | 671 ------------ 9 files changed, 409 insertions(+), 2545 deletions(-) delete mode 100644 label_studio/fsm/migrations/0001_initial.py delete mode 100644 label_studio/fsm/tests/test_fsm_integration.py delete mode 100644 label_studio/fsm/tests/test_integration_django_models.py diff --git a/label_studio/core/urls.py b/label_studio/core/urls.py index 4bdb57b18295..23998217d0dc 100644 --- a/label_studio/core/urls.py +++ b/label_studio/core/urls.py @@ -105,7 +105,6 @@ re_path(r'^api-auth/', include('rest_framework.urls', namespace='rest_framework')), re_path(r'^', include('jwt_auth.urls')), re_path(r'^', include('session_policy.urls')), - re_path(r'^', include('fsm.urls')), # Finite State Machine APIs path('docs/api/schema/', SpectacularAPIView.as_view(), name='schema'), path('docs/api/schema/swagger-ui/', SpectacularSwaggerView.as_view(url_name='schema'), name='swagger-ui'), path('docs/api/schema/redoc/', SpectacularRedocView.as_view(url_name='schema'), name='redoc'), diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md index 5fb29ca5c173..1828d36ec3c5 100644 --- a/label_studio/fsm/README.md +++ b/label_studio/fsm/README.md @@ -1,207 +1,228 @@ -# Label Studio FSM (Finite State Machine) +# FSM (Finite State Machine) Framework -Core finite state machine functionality for Label Studio that provides the foundation for state tracking across entities like Tasks, Annotations, and Projects. +A high-performance Django-based finite state machine framework with UUID7 optimization, declarative transitions, and comprehensive state management capabilities. ## Overview -The Label Studio FSM system provides: +The FSM framework provides: -- **Core Infrastructure**: Base state tracking models and managers -- **UUID7 Optimization**: Time-series optimized state records using UUID7 -- **REST API**: Endpoints for state management -- **Admin Interface**: Django admin integration for state inspection +- **Core Infrastructure**: Abstract base state models and managers +- **UUID7 Optimization**: Time-series optimized state records with natural ordering +- **Declarative Transitions**: Pydantic-based transition system with validation +- **REST API**: Generic endpoints for state management +- **High Performance**: Optimized for high-volume state changes with caching +- **Extensible**: Plugin-based architecture for custom implementations ## Architecture ### Core Components -1. **BaseState**: Abstract model providing common state tracking functionality -2. **StateManager**: High-performance state management with caching -3. **Core State Models**: Task, Annotation, and Project state tracking +1. **BaseState**: Abstract model providing UUID7-optimized state tracking +2. **StateManager**: High-performance state management with intelligent caching +3. **Transition System**: Declarative, Pydantic-based transitions with validation +4. **State Registry**: Dynamic registration system for entities and transitions +5. **API Layer**: Generic REST endpoints for state operations -## Usage +## Quick Start -### Basic State Management +### 1. Define State Choices ```python -from fsm.state_manager import get_state_manager -from tasks.models import Task - -# Get current state -StateManager = get_state_manager() -task = Task.objects.get(id=123) -current_state = StateManager.get_current_state(task) - -# Transition state -success = StateManager.transition_state( - entity=task, - new_state='IN_PROGRESS', - user=request.user, - reason='User started annotation work' -) - -# Get state history -history = StateManager.get_state_history(task, limit=10) +from django.db import models +from django.utils.translation import gettext_lazy as _ + +class OrderStateChoices(models.TextChoices): + CREATED = 'CREATED', _('Created') + PROCESSING = 'PROCESSING', _('Processing') + SHIPPED = 'SHIPPED', _('Shipped') + DELIVERED = 'DELIVERED', _('Delivered') + CANCELLED = 'CANCELLED', _('Cancelled') ``` -### Integration with Existing Models +### 2. Create State Model ```python -# Add FSM functionality to existing models -from fsm.integration import FSMIntegrationMixin +from fsm.models import BaseState +from fsm.state_choices import register_state_choices -class Task(FSMIntegrationMixin, BaseTask): - class Meta: - proxy = True +# Register state choices +register_state_choices('order', OrderStateChoices) -# Now you can use FSM methods directly -task = Task.objects.get(id=123) -current_state = task.current_fsm_state -task.transition_fsm_state('COMPLETED', user=user) +class OrderState(BaseState): + # Entity relationship + order = models.ForeignKey('shop.Order', related_name='fsm_states', on_delete=models.CASCADE) + + # Override state field with choices + state = models.CharField(max_length=50, choices=OrderStateChoices.choices, db_index=True) + + # Denormalized fields for performance + customer_id = models.PositiveIntegerField(db_index=True) + + class Meta: + indexes = [ + models.Index(fields=['order_id', '-id'], name='order_current_state_idx'), + ] ``` -### API Usage +### 3. Define Transitions -```bash -# Get current state -GET /api/fsm/task/123/current/ +```python +from fsm.transitions import BaseTransition, register_transition +from pydantic import Field -# Get state history -GET /api/fsm/task/123/history/?limit=10 - -# Transition state -POST /api/fsm/task/123/transition/ -{ - "new_state": "COMPLETED", - "reason": "Task completed by user" -} +@register_transition('order', 'process_order') +class ProcessOrderTransition(BaseTransition): + processor_id: int = Field(..., description="ID of user processing the order") + priority: str = Field('normal', description="Processing priority") + + @property + def target_state(self) -> str: + return OrderStateChoices.PROCESSING + + def validate_transition(self, context) -> bool: + return context.current_state == OrderStateChoices.CREATED + + def transition(self, context) -> dict: + return { + "processor_id": self.processor_id, + "priority": self.priority, + "processed_at": context.timestamp.isoformat() + } ``` -## Dependencies - -The FSM system requires the `uuid-utils` library for UUID7 support: +### 4. Execute Transitions -```bash -pip install uuid-utils>=0.11.0 +```python +from fsm.transition_utils import execute_transition + +# Execute transition +result = execute_transition( + entity=order, + transition_name='process_order', + transition_data={'processor_id': 123, 'priority': 'high'}, + user=request.user +) ``` -This dependency is automatically included in Label Studio's requirements. - -### Why UUID7? +### 5. Query States -UUID7 provides significant performance benefits for time-series data like state transitions: - -- **Natural Time Ordering**: Records are naturally ordered by creation time without requiring additional indexes -- **Global Uniqueness**: Works across distributed systems and database shards -- **INSERT-only Architecture**: No UPDATE operations needed, maximizing concurrency -- **Time-based Partitioning**: Enables horizontal scaling to billions of records +```python +from fsm.state_manager import get_state_manager -## Configuration +StateManager = get_state_manager() -### Django Settings +# Get current state +current_state = StateManager.get_current_state(order) -Add the FSM app to your `INSTALLED_APPS`: +# Get state history +history = StateManager.get_state_history(order, limit=10) -```python -INSTALLED_APPS = [ - # ... other apps - 'label_studio.fsm', - # ... other apps -] +# Bulk operations for performance +orders = Order.objects.all()[:1000] +states = StateManager.bulk_get_current_states(orders) ``` -### Optional Settings +## Key Features -```python -# FSM Configuration -FSM_CACHE_TTL = 300 # Cache timeout in seconds (default: 300) -FSM_AUTO_CREATE_STATES = False # Auto-create states on entity creation (default: False) -FSM_STATE_MANAGER_CLASS = None # Custom state manager class (default: None) -``` +### UUID7 Performance Optimization -## Database Migrations +- **Natural Time Ordering**: UUID7 provides chronological ordering without separate timestamp indexes +- **High Concurrency**: INSERT-only approach eliminates locking contention +- **Scalability**: Supports billions of state records with consistent performance -Run migrations to create the FSM tables: +### Declarative Transitions -```bash -python manage.py migrate fsm -``` +- **Pydantic Validation**: Strong typing and automatic validation +- **Composable Logic**: Reusable transition classes with inheritance +- **Hooks System**: Pre/post transition hooks for custom logic -This will create: -- `fsm_task_states`: Task state tracking -- `fsm_annotation_states`: Annotation state tracking -- `fsm_project_states`: Project state tracking +### Advanced Querying -## Performance Considerations +```python +# Time-range queries using UUID7 +from datetime import datetime, timedelta +recent_states = StateManager.get_states_since( + entity=order, + since=datetime.now() - timedelta(hours=24) +) -### UUID7 Benefits +# Bulk operations +orders = Order.objects.filter(status='active') +current_states = StateManager.bulk_get_current_states(orders) +``` -The FSM system uses UUID7 for optimal time-series performance: +### API Integration -- **Natural Time Ordering**: No need for `created_at` indexes -- **INSERT-only Architecture**: Maximum concurrency, no row locks -- **Global Uniqueness**: Supports distributed systems -- **Time-based Partitioning**: Scales to billions of records +The framework provides generic REST endpoints: -### Caching Strategy +``` +GET /api/fsm/{entity_type}/{entity_id}/current/ # Current state +GET /api/fsm/{entity_type}/{entity_id}/history/ # State history +POST /api/fsm/{entity_type}/{entity_id}/transition/ # Execute transition +``` -- **Write-through Caching**: Immediate consistency after state transitions -- **Configurable TTL**: Balance between performance and freshness -- **Cache Key Strategy**: Optimized for entity-based lookups +Extend the base viewset for your application: -### Indexes +```python +from fsm.api import FSMViewSet + +class MyFSMViewSet(FSMViewSet): + def _get_entity_model(self, entity_type: str): + entity_mapping = { + 'order': 'shop.Order', + 'ticket': 'support.Ticket', + } + # ... implementation +``` -Critical indexes for performance: -- `(entity_id, id DESC)`: Current state lookup using UUID7 ordering -- `(entity_id, id)`: State history queries +## Performance Characteristics +- **State Queries**: O(1) current state lookup via UUID7 ordering +- **History Queries**: Optimal for time-series access patterns +- **Bulk Operations**: Efficient batch processing for thousands of entities +- **Cache Integration**: Intelligent caching with automatic invalidation +- **Memory Efficiency**: Minimal memory footprint for state objects -## Monitoring and Debugging +## Extension Points -### Admin Interface +### Custom State Manager -Access state records via Django admin: -- `/admin/fsm/taskstate/` -- `/admin/fsm/annotationstate/` -- `/admin/fsm/projectstate/` +```python +from fsm.state_manager import BaseStateManager -### Logging +class CustomStateManager(BaseStateManager): + def get_current_state(self, entity): + # Custom logic + return super().get_current_state(entity) +``` -FSM operations are logged at appropriate levels: -- `INFO`: Successful state transitions -- `ERROR`: Failed transitions and system errors -- `DEBUG`: Cache hits/misses and detailed operation info +### Custom Validation +```python +@register_transition('order', 'validate_payment') +class PaymentValidationTransition(BaseTransition): + def validate_transition(self, context) -> bool: + # Custom business logic + return self.check_payment_method(context.entity) +``` -## Migration from Existing Systems +## Framework vs Implementation -The FSM system can run alongside existing state management: +This is the **core framework** - a clean, generic FSM system. Product-specific implementations (state definitions, concrete models, business logic) should be in separate branches/modules for: -1. **Parallel Operation**: FSM tracks states without affecting existing logic -2. **Gradual Migration**: Replace existing state checks with FSM calls over time -3. **Backfill Support**: Historical states can be backfilled from existing data +- **Clean Architecture**: Framework logic separated from business logic +- **Reusability**: Framework can be used across different projects +- **Maintainability**: Changes to business logic don't affect framework +- **Review Process**: Framework and implementation can be reviewed independently -## Testing +## Migration from Other FSM Libraries -Test the FSM system: +The framework provides migration utilities and is designed to be compatible with existing Django FSM patterns while offering significant performance improvements through UUID7 optimization. -```python -from fsm.state_manager import StateManager -from tasks.models import Task +## Contributing -def test_task_state_transition(): - task = Task.objects.create(...) - - # Test initial state - assert StateManager.get_current_state(task) is None - - # Test transition - success = StateManager.transition_state(task, 'CREATED') - assert success - assert StateManager.get_current_state(task) == 'CREATED' - - # Test history - history = StateManager.get_state_history(task) - assert len(history) == 1 - assert history[0].state == 'CREATED' -``` +When contributing: +- Keep framework code generic and reusable +- Add product-specific code to appropriate implementation branches +- Include performance tests for UUID7 optimizations +- Document extension points and customization options \ No newline at end of file diff --git a/label_studio/fsm/api.py b/label_studio/fsm/api.py index 81b423b60869..e8d271b7c74a 100644 --- a/label_studio/fsm/api.py +++ b/label_studio/fsm/api.py @@ -1,7 +1,8 @@ """ -Core FSM API endpoints for Label Studio. +Core FSM API endpoints. -Provides basic API endpoints for state management that can be extended +Provides generic API endpoints for state management that can be extended +for any application using the FSM framework. """ import logging @@ -50,15 +51,22 @@ def _get_entity_and_state_model(self, entity_type: str, entity_id: int): return entity, state_model def _get_entity_model(self, entity_type: str): - """Get Django model class for entity type""" + """ + Get Django model class for entity type. + + This method should be overridden by subclasses to provide + application-specific entity type mappings. + + Example: + entity_mapping = { + 'order': 'shop.Order', + 'ticket': 'support.Ticket', + } + """ from django.apps import apps - # Map entity types to app.model - entity_mapping = { - 'task': 'tasks.Task', - 'annotation': 'tasks.Annotation', - 'project': 'projects.Project', - } + # Default empty mapping - override in subclasses + entity_mapping = {} model_path = entity_mapping.get(entity_type.lower()) if not model_path: diff --git a/label_studio/fsm/migrations/0001_initial.py b/label_studio/fsm/migrations/0001_initial.py deleted file mode 100644 index 322f6bcffc41..000000000000 --- a/label_studio/fsm/migrations/0001_initial.py +++ /dev/null @@ -1,386 +0,0 @@ -# Generated by Django 5.1.10 on 2025-08-26 17:13 - -import django.db.models.deletion -import fsm.utils -from django.conf import settings -from django.db import migrations, models - - -class Migration(migrations.Migration): - - initial = True - - dependencies = [ - ("organizations", "0006_alter_organizationmember_deleted_at"), - ("projects", "0030_project_search_vector_index"), - ("tasks", "0057_annotation_proj_result_octlen_idx_async"), - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ] - - operations = [ - migrations.CreateModel( - name="AnnotationState", - fields=[ - ( - "id", - models.UUIDField( - default=fsm.utils.generate_uuid7, - editable=False, - help_text="UUID7 provides natural time ordering and global uniqueness", - primary_key=True, - serialize=False, - ), - ), - ( - "previous_state", - models.CharField( - blank=True, - help_text="Previous state before this transition", - max_length=50, - null=True, - ), - ), - ( - "transition_name", - models.CharField( - blank=True, - help_text="Name of the transition method that triggered this state change", - max_length=100, - null=True, - ), - ), - ( - "context_data", - models.JSONField( - default=dict, - help_text="Additional context data for this transition (e.g., validation results, external IDs)", - ), - ), - ( - "reason", - models.TextField( - blank=True, - help_text="Human-readable reason for this state transition", - ), - ), - ( - "created_at", - models.DateTimeField( - auto_now_add=True, - help_text="Human-readable timestamp for debugging (UUID7 id contains precise timestamp)", - ), - ), - ( - "state", - models.CharField( - choices=[ - ("DRAFT", "Draft"), - ("SUBMITTED", "Submitted"), - ("COMPLETED", "Completed"), - ], - db_index=True, - max_length=50, - ), - ), - ( - "task_id", - models.PositiveIntegerField( - db_index=True, - help_text="From annotation.task_id - denormalized for performance", - ), - ), - ( - "project_id", - models.PositiveIntegerField( - db_index=True, - help_text="From annotation.task.project_id - denormalized for performance", - ), - ), - ( - "completed_by_id", - models.PositiveIntegerField( - db_index=True, - help_text="From annotation.completed_by_id - denormalized for performance", - null=True, - ), - ), - ( - "annotation", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name="fsm_states", - to="tasks.annotation", - ), - ), - ( - "organization", - models.ForeignKey( - help_text="Organization which owns this state record", - null=True, - on_delete=django.db.models.deletion.CASCADE, - to="organizations.organization", - ), - ), - ( - "triggered_by", - models.ForeignKey( - help_text="User who triggered this state transition", - null=True, - on_delete=django.db.models.deletion.SET_NULL, - to=settings.AUTH_USER_MODEL, - ), - ), - ], - options={ - "ordering": ["-id"], - "indexes": [ - models.Index( - fields=["annotation_id", "-id"], name="anno_current_state_idx" - ), - models.Index( - fields=["task_id", "state", "-id"], name="anno_task_state_idx" - ), - models.Index( - fields=["completed_by_id", "state", "-id"], - name="anno_user_report_idx", - ), - models.Index( - fields=["project_id", "state", "-id"], - name="anno_project_report_idx", - ), - ], - }, - ), - migrations.CreateModel( - name="ProjectState", - fields=[ - ( - "id", - models.UUIDField( - default=fsm.utils.generate_uuid7, - editable=False, - help_text="UUID7 provides natural time ordering and global uniqueness", - primary_key=True, - serialize=False, - ), - ), - ( - "previous_state", - models.CharField( - blank=True, - help_text="Previous state before this transition", - max_length=50, - null=True, - ), - ), - ( - "transition_name", - models.CharField( - blank=True, - help_text="Name of the transition method that triggered this state change", - max_length=100, - null=True, - ), - ), - ( - "context_data", - models.JSONField( - default=dict, - help_text="Additional context data for this transition (e.g., validation results, external IDs)", - ), - ), - ( - "reason", - models.TextField( - blank=True, - help_text="Human-readable reason for this state transition", - ), - ), - ( - "created_at", - models.DateTimeField( - auto_now_add=True, - help_text="Human-readable timestamp for debugging (UUID7 id contains precise timestamp)", - ), - ), - ( - "state", - models.CharField( - choices=[ - ("CREATED", "Created"), - ("IN_PROGRESS", "In Progress"), - ("COMPLETED", "Completed"), - ], - db_index=True, - max_length=50, - ), - ), - ( - "created_by_id", - models.PositiveIntegerField( - db_index=True, - help_text="From project.created_by_id - denormalized for performance", - null=True, - ), - ), - ( - "organization", - models.ForeignKey( - help_text="Organization which owns this state record", - null=True, - on_delete=django.db.models.deletion.CASCADE, - to="organizations.organization", - ), - ), - ( - "project", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name="fsm_states", - to="projects.project", - ), - ), - ( - "triggered_by", - models.ForeignKey( - help_text="User who triggered this state transition", - null=True, - on_delete=django.db.models.deletion.SET_NULL, - to=settings.AUTH_USER_MODEL, - ), - ), - ], - options={ - "ordering": ["-id"], - "indexes": [ - models.Index( - fields=["project_id", "-id"], name="project_current_state_idx" - ), - models.Index( - fields=["organization_id", "state", "-id"], - name="project_org_state_idx", - ), - models.Index( - fields=["organization_id", "-id"], - name="project_org_reporting_idx", - ), - ], - }, - ), - migrations.CreateModel( - name="TaskState", - fields=[ - ( - "id", - models.UUIDField( - default=fsm.utils.generate_uuid7, - editable=False, - help_text="UUID7 provides natural time ordering and global uniqueness", - primary_key=True, - serialize=False, - ), - ), - ( - "previous_state", - models.CharField( - blank=True, - help_text="Previous state before this transition", - max_length=50, - null=True, - ), - ), - ( - "transition_name", - models.CharField( - blank=True, - help_text="Name of the transition method that triggered this state change", - max_length=100, - null=True, - ), - ), - ( - "context_data", - models.JSONField( - default=dict, - help_text="Additional context data for this transition (e.g., validation results, external IDs)", - ), - ), - ( - "reason", - models.TextField( - blank=True, - help_text="Human-readable reason for this state transition", - ), - ), - ( - "created_at", - models.DateTimeField( - auto_now_add=True, - help_text="Human-readable timestamp for debugging (UUID7 id contains precise timestamp)", - ), - ), - ( - "state", - models.CharField( - choices=[ - ("CREATED", "Created"), - ("IN_PROGRESS", "In Progress"), - ("COMPLETED", "Completed"), - ], - db_index=True, - max_length=50, - ), - ), - ( - "project_id", - models.PositiveIntegerField( - db_index=True, - help_text="From task.project_id - denormalized for performance", - ), - ), - ( - "organization", - models.ForeignKey( - help_text="Organization which owns this state record", - null=True, - on_delete=django.db.models.deletion.CASCADE, - to="organizations.organization", - ), - ), - ( - "task", - models.ForeignKey( - on_delete=django.db.models.deletion.CASCADE, - related_name="fsm_states", - to="tasks.task", - ), - ), - ( - "triggered_by", - models.ForeignKey( - help_text="User who triggered this state transition", - null=True, - on_delete=django.db.models.deletion.SET_NULL, - to=settings.AUTH_USER_MODEL, - ), - ), - ], - options={ - "ordering": ["-id"], - "indexes": [ - models.Index( - fields=["task_id", "-id"], name="task_current_state_idx" - ), - models.Index( - fields=["project_id", "state", "-id"], - name="task_project_state_idx", - ), - models.Index( - fields=["organization_id", "state", "-id"], - name="task_org_reporting_idx", - ), - models.Index(fields=["task_id", "id"], name="task_history_idx"), - ], - }, - ), - ] diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index af9bb4637cf2..5c520751e0f8 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -9,11 +9,6 @@ from django.db import models from django.db.models import UUIDField -from .state_choices import ( - AnnotationStateChoices, - ProjectStateChoices, - TaskStateChoices, -) from .utils import UUID7Field, generate_uuid7, timestamp_from_uuid7 @@ -44,11 +39,13 @@ class BaseState(models.Model): help_text='UUID7 provides natural time ordering and global uniqueness', ) - organization = models.ForeignKey( - 'organizations.Organization', - on_delete=models.CASCADE, + # Optional organization field - can be overridden or left null + # Applications can add their own organization/tenant fields as needed + organization_id = models.PositiveIntegerField( null=True, - help_text='Organization which owns this state record', + blank=True, + db_index=True, + help_text='Organization ID that owns this state record (for multi-tenant applications)', ) # Core State Fields @@ -172,157 +169,8 @@ def _get_entity_field_name(cls) -> str: return 'entity' -# Core state models for basic Label Studio entities - - -class TaskState(BaseState): - """ - Core task state tracking for Label Studio. - - Provides basic task state management with: - - Simple 3-state workflow (CREATED → IN_PROGRESS → COMPLETED) - - High-performance queries with UUID7 ordering - """ - - # Entity Relationship - task = models.ForeignKey('tasks.Task', related_name='fsm_states', on_delete=models.CASCADE, db_index=True) - - # Override state field to add choices constraint - state = models.CharField(max_length=50, choices=TaskStateChoices.choices, db_index=True) - - project_id = models.PositiveIntegerField( - db_index=True, help_text='From task.project_id - denormalized for performance' - ) - - class Meta: - app_label = 'fsm' - indexes = [ - # Critical: Latest state lookup (current state determined by latest UUID7 id) - # Index with DESC order explicitly supports ORDER BY id DESC queries - models.Index(fields=['task_id', '-id'], name='task_current_state_idx'), - # Reporting and filtering - models.Index(fields=['project_id', 'state', '-id'], name='task_project_state_idx'), - models.Index(fields=['organization_id', 'state', '-id'], name='task_org_reporting_idx'), - # History queries - models.Index(fields=['task_id', 'id'], name='task_history_idx'), - ] - # No constraints needed - INSERT-only approach - ordering = ['-id'] - - @classmethod - def get_denormalized_fields(cls, entity): - """Get denormalized fields for TaskState creation""" - return { - 'project_id': entity.project_id, - } - - @property - def is_terminal_state(self) -> bool: - """Check if this is a terminal task state""" - return self.state == TaskStateChoices.COMPLETED - - -class AnnotationState(BaseState): - """ - Core annotation state tracking for Label Studio. - - Provides basic annotation state management with: - - Simple 3-state workflow (DRAFT → SUBMITTED → COMPLETED) - """ - - # Entity Relationship - annotation = models.ForeignKey('tasks.Annotation', on_delete=models.CASCADE, related_name='fsm_states') - - # Override state field to add choices constraint - state = models.CharField(max_length=50, choices=AnnotationStateChoices.choices, db_index=True) - - # Denormalized fields for performance (avoid JOINs in common queries) - task_id = models.PositiveIntegerField( - db_index=True, help_text='From annotation.task_id - denormalized for performance' - ) - project_id = models.PositiveIntegerField( - db_index=True, help_text='From annotation.task.project_id - denormalized for performance' - ) - completed_by_id = models.PositiveIntegerField( - null=True, db_index=True, help_text='From annotation.completed_by_id - denormalized for performance' - ) - - class Meta: - app_label = 'fsm' - indexes = [ - # Critical: Latest state lookup - models.Index(fields=['annotation_id', '-id'], name='anno_current_state_idx'), - # Filtering and reporting - models.Index(fields=['task_id', 'state', '-id'], name='anno_task_state_idx'), - models.Index(fields=['completed_by_id', 'state', '-id'], name='anno_user_report_idx'), - models.Index(fields=['project_id', 'state', '-id'], name='anno_project_report_idx'), - ] - ordering = ['-id'] - - @classmethod - def get_denormalized_fields(cls, entity): - """Get denormalized fields for AnnotationState creation""" - return { - 'task_id': entity.task.id, - 'project_id': entity.task.project_id, - 'completed_by_id': entity.completed_by.id if entity.completed_by else None, - } - - @property - def is_terminal_state(self) -> bool: - """Check if this is a terminal annotation state""" - return self.state == AnnotationStateChoices.COMPLETED - - -class ProjectState(BaseState): - """ - Core project state tracking for Label Studio. - - Provides basic project state management with: - - Simple 3-state workflow (CREATED → IN_PROGRESS → COMPLETED) - - Project lifecycle tracking - """ - - # Entity Relationship - project = models.ForeignKey('projects.Project', on_delete=models.CASCADE, related_name='fsm_states') - - # Override state field to add choices constraint - state = models.CharField(max_length=50, choices=ProjectStateChoices.choices, db_index=True) - - created_by_id = models.PositiveIntegerField( - null=True, db_index=True, help_text='From project.created_by_id - denormalized for performance' - ) - - class Meta: - app_label = 'fsm' - indexes = [ - # Critical: Latest state lookup - models.Index(fields=['project_id', '-id'], name='project_current_state_idx'), - # Filtering and reporting - models.Index(fields=['organization_id', 'state', '-id'], name='project_org_state_idx'), - models.Index(fields=['organization_id', '-id'], name='project_org_reporting_idx'), - ] - ordering = ['-id'] - - @classmethod - def get_denormalized_fields(cls, entity): - """Get denormalized fields for ProjectState creation""" - return { - 'created_by_id': entity.created_by.id if entity.created_by else None, - } - - @property - def is_terminal_state(self) -> bool: - """Check if this is a terminal project state""" - return self.state == ProjectStateChoices.COMPLETED - - # Registry for dynamic state model extension -STATE_MODEL_REGISTRY = { - 'task': TaskState, - 'annotation': AnnotationState, - 'project': ProjectState, -} +STATE_MODEL_REGISTRY = {} def register_state_model(entity_name: str, model_class): diff --git a/label_studio/fsm/state_choices.py b/label_studio/fsm/state_choices.py index ba8304c777ec..e00cd12c3e06 100644 --- a/label_studio/fsm/state_choices.py +++ b/label_studio/fsm/state_choices.py @@ -1,77 +1,12 @@ """ -Core state choice enums for Label Studio entities. +FSM state choices registry system. -These enums define the essential states for core Label Studio entities. +This module provides the infrastructure for registering and managing +state choices for different entity types in the FSM framework. """ -from django.db import models -from django.utils.translation import gettext_lazy as _ - - -class TaskStateChoices(models.TextChoices): - """ - Core task states for basic Label Studio workflow. - - Simplified states covering the essential task lifecycle: - - Creation and assignment - - Annotation work - - Completion - """ - - # Initial State - CREATED = 'CREATED', _('Created') - - # Work States - IN_PROGRESS = 'IN_PROGRESS', _('In Progress') - - # Terminal State - COMPLETED = 'COMPLETED', _('Completed') - - -class AnnotationStateChoices(models.TextChoices): - """ - Core annotation states for basic Label Studio workflow. - - Simplified states covering the essential annotation lifecycle: - - Draft work - - Submission - - Completion - """ - - # Working States - DRAFT = 'DRAFT', _('Draft') - SUBMITTED = 'SUBMITTED', _('Submitted') - - # Terminal State - COMPLETED = 'COMPLETED', _('Completed') - - -class ProjectStateChoices(models.TextChoices): - """ - Core project states for basic Label Studio workflow. - - Simplified states covering the essential project lifecycle: - - Setup and configuration - - Active work - - Completion - """ - - # Setup States - CREATED = 'CREATED', _('Created') - - # Work States - IN_PROGRESS = 'IN_PROGRESS', _('In Progress') - - # Terminal State - COMPLETED = 'COMPLETED', _('Completed') - - # Registry for dynamic state choices extension -STATE_CHOICES_REGISTRY = { - 'task': TaskStateChoices, - 'annotation': AnnotationStateChoices, - 'project': ProjectStateChoices, -} +STATE_CHOICES_REGISTRY = {} def register_state_choices(entity_name: str, choices_class): @@ -79,7 +14,7 @@ def register_state_choices(entity_name: str, choices_class): Register state choices for an entity type. Args: - entity_name: Name of the entity (e.g., 'review', 'assignment') + entity_name: Name of the entity (e.g., 'order', 'ticket') choices_class: Django TextChoices class defining valid states """ STATE_CHOICES_REGISTRY[entity_name.lower()] = choices_class @@ -96,26 +31,3 @@ def get_state_choices(entity_name: str): Django TextChoices class or None if not found """ return STATE_CHOICES_REGISTRY.get(entity_name.lower()) - - -# State complexity metrics for core entities -CORE_STATE_COMPLEXITY_METRICS = { - 'TaskStateChoices': { - 'total_states': len(TaskStateChoices.choices), - 'complexity_score': 1.0, # Simple linear flow - 'terminal_states': ['COMPLETED'], - 'entry_states': ['CREATED'], - }, - 'AnnotationStateChoices': { - 'total_states': len(AnnotationStateChoices.choices), - 'complexity_score': 1.0, # Simple linear flow - 'terminal_states': ['COMPLETED'], - 'entry_states': ['DRAFT'], - }, - 'ProjectStateChoices': { - 'total_states': len(ProjectStateChoices.choices), - 'complexity_score': 1.0, # Simple linear flow - 'terminal_states': ['COMPLETED'], - 'entry_states': ['CREATED'], - }, -} diff --git a/label_studio/fsm/tests/test_declarative_transitions.py b/label_studio/fsm/tests/test_declarative_transitions.py index 6555a68d1a37..539d67a3f602 100644 --- a/label_studio/fsm/tests/test_declarative_transitions.py +++ b/label_studio/fsm/tests/test_declarative_transitions.py @@ -1,23 +1,22 @@ """ -Comprehensive tests for the declarative Pydantic-based transition system. +Core framework tests for the declarative Pydantic-based transition system. -This test suite provides extensive coverage of the new transition system, -including usage examples, edge cases, validation scenarios, and integration -patterns to serve as both tests and documentation. +This test suite covers the core transition framework functionality without +product-specific implementations. It tests the abstract base classes, +registration system, validation, and core utilities. """ -from datetime import datetime, timedelta +from datetime import datetime from typing import Any, Dict -from unittest.mock import Mock, patch +from unittest.mock import Mock -import pytest from django.contrib.auth import get_user_model +from django.db import models from django.test import TestCase -from fsm.state_choices import AnnotationStateChoices, TaskStateChoices +from django.utils.translation import gettext_lazy as _ from fsm.transition_utils import ( TransitionBuilder, get_available_transitions, - get_valid_transitions, ) from fsm.transitions import ( BaseTransition, @@ -31,892 +30,329 @@ User = get_user_model() -class MockTask: - """Mock task model for testing""" +class TestStateChoices(models.TextChoices): + """Test state choices for mock entity""" - def __init__(self, pk=1): - self.pk = pk - self.id = pk - self.organization_id = 1 - self._meta = Mock() - self._meta.model_name = 'task' - self._meta.label_lower = 'tasks.task' + CREATED = 'CREATED', _('Created') + IN_PROGRESS = 'IN_PROGRESS', _('In Progress') + COMPLETED = 'COMPLETED', _('Completed') -class MockAnnotation: - """Mock annotation model for testing""" +class MockEntity: + """Mock entity model for testing""" def __init__(self, pk=1): self.pk = pk self.id = pk - self.result = {'test': 'data'} # Mock annotation data self.organization_id = 1 self._meta = Mock() - self._meta.model_name = 'annotation' - self._meta.label_lower = 'tasks.annotation' - + self._meta.model_name = 'test_entity' + self._meta.label_lower = 'test.testentity' -class TestTransition(BaseTransition): - """Test transition class""" - test_field: str - optional_field: int = 42 +class CoreFrameworkTests(TestCase): + """Test core framework functionality""" - @property - def target_state(self) -> str: - return 'TEST_STATE' - - @classmethod - def get_target_state(cls) -> str: - """Return the target state at class level""" - return 'TEST_STATE' + def setUp(self): + """Set up test data""" + self.user = User.objects.create_user(email='test@example.com', password='test123') + self.mock_entity = MockEntity() - @classmethod - def can_transition_from_state(cls, context: TransitionContext) -> bool: - """Allow transition from any state for testing""" - return True + # Clear registry to avoid test pollution + transition_registry.clear() - def validate_transition(self, context: TransitionContext) -> bool: - if self.test_field == 'invalid': - raise TransitionValidationError('Test validation error') - return super().validate_transition(context) + def tearDown(self): + """Clean up after tests""" + transition_registry.clear() - def transition(self, context: TransitionContext) -> dict: - return { - 'test_field': self.test_field, - 'optional_field': self.optional_field, - 'context_entity_id': context.entity.pk, - } + def test_base_transition_class(self): + """Test BaseTransition abstract functionality""" + @register_transition('test_entity', 'test_transition') + class TestTransition(BaseTransition): + test_field: str = Field('default', description='Test field') -class DeclarativeTransitionTests(TestCase): - """Test cases for the declarative transition system""" + @property + def target_state(self) -> str: + return TestStateChoices.IN_PROGRESS - def setUp(self): - self.task = MockTask() - self.annotation = MockAnnotation() - self.user = Mock() - self.user.id = 1 - self.user.username = 'testuser' + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return {'test_field': self.test_field} - # Register test transition - transition_registry.register('task', 'test_transition', TestTransition) + # Test instantiation + transition = TestTransition(test_field='test_value') + self.assertEqual(transition.test_field, 'test_value') + self.assertEqual(transition.target_state, TestStateChoices.IN_PROGRESS) + self.assertEqual(transition.transition_name, 'test_transition') - def test_transition_context_creation(self): - """Test creation of transition context""" + def test_transition_context(self): + """Test TransitionContext functionality""" context = TransitionContext( - entity=self.task, + entity=self.mock_entity, + current_state=TestStateChoices.CREATED, + target_state=TestStateChoices.IN_PROGRESS, + timestamp=datetime.now(), current_user=self.user, - current_state='CREATED', - target_state='IN_PROGRESS', - organization_id=1, ) - self.assertEqual(context.entity, self.task) + self.assertEqual(context.entity, self.mock_entity) + self.assertEqual(context.current_state, TestStateChoices.CREATED) + self.assertEqual(context.target_state, TestStateChoices.IN_PROGRESS) self.assertEqual(context.current_user, self.user) - self.assertEqual(context.current_state, 'CREATED') - self.assertEqual(context.target_state, 'IN_PROGRESS') - self.assertEqual(context.organization_id, 1) - self.assertFalse(context.is_initial_transition) self.assertTrue(context.has_current_state) + self.assertFalse(context.is_initial_transition) - def test_transition_context_initial_state(self): - """Test context for initial state transition""" - context = TransitionContext(entity=self.task, current_state=None, target_state='CREATED') - + def test_transition_context_properties(self): + """Test TransitionContext computed properties""" + # Test initial transition + context = TransitionContext(entity=self.mock_entity, current_state=None, target_state=TestStateChoices.CREATED) self.assertTrue(context.is_initial_transition) self.assertFalse(context.has_current_state) - def test_transition_validation_success(self): - """Test successful transition validation""" - transition = TestTransition(test_field='valid') - context = TransitionContext(entity=self.task, current_state='CREATED', target_state=transition.target_state) - - self.assertTrue(transition.validate_transition(context)) - - def test_transition_validation_failure(self): - """Test transition validation failure""" - transition = TestTransition(test_field='invalid') - context = TransitionContext(entity=self.task, current_state='CREATED', target_state=transition.target_state) - - with self.assertRaises(TransitionValidationError): - transition.validate_transition(context) - - def test_transition_execution(self): - """Test transition data generation""" - transition = TestTransition(test_field='test_value', optional_field=100) - context = TransitionContext(entity=self.task, current_state='CREATED', target_state=transition.target_state) - - result = transition.transition(context) - - self.assertEqual(result['test_field'], 'test_value') - self.assertEqual(result['optional_field'], 100) - self.assertEqual(result['context_entity_id'], self.task.pk) - - def test_transition_name_generation(self): - """Test automatic transition name generation""" - transition = TestTransition(test_field='test') - self.assertEqual(transition.transition_name, 'test_transition') - - @patch('fsm.state_manager.StateManager.transition_state') - @patch('fsm.state_manager.StateManager.get_current_state_object') - def test_transition_execute_full_workflow(self, mock_get_state, mock_transition): - """Test full transition execution workflow""" - # Setup mocks - mock_get_state.return_value = None # No current state - mock_transition.return_value = True - - mock_state_record = Mock() - mock_state_record.id = 'test-uuid' - - with patch('fsm.state_manager.StateManager.get_current_state_object', return_value=mock_state_record): - transition = TestTransition(test_field='test_value') - context = TransitionContext( - entity=self.task, current_user=self.user, current_state=None, target_state=transition.target_state - ) - - # Execute transition - transition.execute(context) - - # Verify StateManager was called correctly - mock_transition.assert_called_once() - call_args = mock_transition.call_args - - self.assertEqual(call_args[1]['entity'], self.task) - self.assertEqual(call_args[1]['new_state'], 'TEST_STATE') - self.assertEqual(call_args[1]['transition_name'], 'test_transition') - self.assertEqual(call_args[1]['user'], self.user) - - # Check context data - context_data = call_args[1]['context'] - self.assertEqual(context_data['test_field'], 'test_value') - self.assertEqual(context_data['optional_field'], 42) - - -class TransitionRegistryTests(TestCase): - """Test cases for the transition registry""" - - def setUp(self): - self.registry = transition_registry - - def test_transition_registration(self): - """Test registering transitions""" - self.registry.register('test_entity', 'test_transition', TestTransition) - - retrieved = self.registry.get_transition('test_entity', 'test_transition') - self.assertEqual(retrieved, TestTransition) - - def test_get_transitions_for_entity(self): - """Test getting all transitions for an entity""" - self.registry.register('test_entity', 'transition1', TestTransition) - self.registry.register('test_entity', 'transition2', TestTransition) - - transitions = self.registry.get_transitions_for_entity('test_entity') - - self.assertIn('transition1', transitions) - self.assertIn('transition2', transitions) - self.assertEqual(len(transitions), 2) - - def test_list_entities(self): - """Test listing registered entities""" - self.registry.register('entity1', 'transition1', TestTransition) - self.registry.register('entity2', 'transition2', TestTransition) - - entities = self.registry.list_entities() - - self.assertIn('entity1', entities) - self.assertIn('entity2', entities) - - -class TransitionUtilsTests(TestCase): - """Test cases for transition utility functions""" - - def setUp(self): - self.task = MockTask() - transition_registry.register('task', 'test_transition', TestTransition) - - def test_get_available_transitions(self): - """Test getting available transitions for entity""" - transitions = get_available_transitions(self.task) - self.assertIn('test_transition', transitions) - - @patch('fsm.state_manager.StateManager.get_current_state_object') - def test_get_valid_transitions(self, mock_get_state): - """Test filtering valid transitions""" - mock_get_state.return_value = None - - valid_transitions = get_valid_transitions(self.task, validate=True) - self.assertIn('test_transition', valid_transitions) - - @patch('fsm.state_manager.StateManager.get_current_state_object') - def test_get_valid_transitions_with_invalid(self, mock_get_state): - """Test filtering out invalid transitions""" - mock_get_state.return_value = None - - # Register an invalid transition - class InvalidTransition(TestTransition): - @classmethod - def can_transition_from_state(cls, context): - # This transition is never valid at the class level - return False - - def validate_transition(self, context): - raise TransitionValidationError('Always invalid') - - transition_registry.register('task', 'invalid_transition', InvalidTransition) - - valid_transitions = get_valid_transitions(self.task, validate=True) - self.assertIn('test_transition', valid_transitions) - self.assertNotIn('invalid_transition', valid_transitions) - - @patch('fsm.transition_utils.execute_transition') - def test_transition_builder(self, mock_execute): - """Test fluent transition builder interface""" - mock_execute.return_value = Mock() - - ( - TransitionBuilder(self.task) - .transition('test_transition') - .with_data(test_field='builder_test') - .by_user(Mock()) - .with_context(extra='context') - .execute() + # Test with current state + context_with_state = TransitionContext( + entity=self.mock_entity, + current_state=TestStateChoices.CREATED, + target_state=TestStateChoices.IN_PROGRESS, ) + self.assertFalse(context_with_state.is_initial_transition) + self.assertTrue(context_with_state.has_current_state) - mock_execute.assert_called_once() - call_args = mock_execute.call_args - - self.assertEqual(call_args[1]['transition_name'], 'test_transition') - self.assertEqual(call_args[1]['transition_data']['test_field'], 'builder_test') - - -class ExampleTransitionIntegrationTests(TestCase): - """Integration tests using the example transitions""" - - def setUp(self): - # Import example transitions to register them - - self.task = MockTask() - self.annotation = MockAnnotation() - self.user = Mock() - self.user.id = 1 - self.user.username = 'testuser' - - def test_start_task_transition_validation(self): - """Test StartTaskTransition validation""" - from fsm.example_transitions import StartTaskTransition - - transition = StartTaskTransition(assigned_user_id=123) + def test_transition_registry(self): + """Test transition registration and retrieval""" - # Test valid transition from CREATED - context = TransitionContext( - entity=self.task, current_state=TaskStateChoices.CREATED, target_state=transition.target_state - ) - - self.assertTrue(transition.validate_transition(context)) - - # Test invalid transition from COMPLETED - context.current_state = TaskStateChoices.COMPLETED - - with self.assertRaises(TransitionValidationError): - transition.validate_transition(context) - - def test_submit_annotation_validation(self): - """Test SubmitAnnotationTransition validation""" - from fsm.example_transitions import SubmitAnnotationTransition - - transition = SubmitAnnotationTransition() - - # Test valid transition - context = TransitionContext( - entity=self.annotation, current_state=AnnotationStateChoices.DRAFT, target_state=transition.target_state - ) - - self.assertTrue(transition.validate_transition(context)) - - def test_transition_data_generation(self): - """Test that transitions generate appropriate context data""" - from fsm.example_transitions import StartTaskTransition - - transition = StartTaskTransition(assigned_user_id=123, estimated_duration=5, priority='high') - - context = TransitionContext( - entity=self.task, current_user=self.user, target_state=transition.target_state, timestamp=datetime.now() - ) + @register_transition('test_entity', 'test_transition') + class TestTransition(BaseTransition): + @property + def target_state(self) -> str: + return TestStateChoices.COMPLETED - result = transition.transition(context) + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return {} - self.assertEqual(result['assigned_user_id'], 123) - self.assertEqual(result['estimated_duration'], 5) - self.assertEqual(result['priority'], 'high') - self.assertIn('started_at', result) - self.assertEqual(result['assignment_type'], 'manual') + # Test registration + retrieved = transition_registry.get_transition('test_entity', 'test_transition') + self.assertEqual(retrieved, TestTransition) + # Test entity transitions + entity_transitions = transition_registry.get_transitions_for_entity('test_entity') + self.assertIn('test_transition', entity_transitions) + self.assertEqual(entity_transitions['test_transition'], TestTransition) -class ComprehensiveUsageExampleTests(TestCase): - """ - Comprehensive test cases demonstrating various usage patterns. + def test_pydantic_validation(self): + """Test Pydantic validation in transitions""" - These tests serve as both validation and documentation for how to - implement and use the declarative transition system. - """ + @register_transition('test_entity', 'validated_transition') + class ValidatedTransition(BaseTransition): + required_field: str = Field(..., description='Required field') + optional_field: int = Field(42, description='Optional field') - def setUp(self): - self.task = MockTask() - self.user = Mock() - self.user.id = 123 - self.user.username = 'testuser' + @property + def target_state(self) -> str: + return TestStateChoices.COMPLETED - # Clear registry to avoid conflicts - transition_registry._transitions.clear() + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return {'required_field': self.required_field, 'optional_field': self.optional_field} - def test_basic_transition_implementation(self): - """ - USAGE EXAMPLE: Basic transition implementation + # Test valid instantiation + transition = ValidatedTransition(required_field='test') + self.assertEqual(transition.required_field, 'test') + self.assertEqual(transition.optional_field, 42) - Shows how to implement a simple transition with validation. - """ + # Test validation error + with self.assertRaises(ValidationError): + ValidatedTransition() # Missing required field - class BasicTransition(BaseTransition): - """Example: Simple transition with required field""" + def test_transition_execution(self): + """Test transition execution logic""" - message: str = Field(..., description='Message for the transition') + @register_transition('test_entity', 'execution_test') + class ExecutionTestTransition(BaseTransition): + value: str = Field('test', description='Test value') @property def target_state(self) -> str: - return 'PROCESSED' + return TestStateChoices.COMPLETED def validate_transition(self, context: TransitionContext) -> bool: - # Business logic validation - if context.current_state == 'COMPLETED': - raise TransitionValidationError('Cannot process completed items') - return True + return context.current_state == TestStateChoices.IN_PROGRESS def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - 'message': self.message, - 'processed_by': context.current_user.username if context.current_user else 'system', - 'processed_at': context.timestamp.isoformat(), + 'value': self.value, + 'entity_id': context.entity.pk, + 'timestamp': context.timestamp.isoformat(), } - # Test the implementation - transition = BasicTransition(message='Processing task') - self.assertEqual(transition.message, 'Processing task') - self.assertEqual(transition.target_state, 'PROCESSED') - - # Test validation + transition = ExecutionTestTransition(value='execution_test') context = TransitionContext( - entity=self.task, current_user=self.user, current_state='CREATED', target_state=transition.target_state + entity=self.mock_entity, + current_state=TestStateChoices.IN_PROGRESS, + target_state=transition.target_state, + timestamp=datetime.now(), ) + # Test validation self.assertTrue(transition.validate_transition(context)) - # Test data generation - data = transition.transition(context) - self.assertEqual(data['message'], 'Processing task') - self.assertEqual(data['processed_by'], 'testuser') - self.assertIn('processed_at', data) - - def test_complex_validation_example(self): - """ - USAGE EXAMPLE: Complex validation with multiple conditions - - Shows how to implement sophisticated business logic validation. - """ - - class TaskAssignmentTransition(BaseTransition): - """Example: Complex validation for task assignment""" + # Test execution + result = transition.transition(context) + self.assertEqual(result['value'], 'execution_test') + self.assertEqual(result['entity_id'], self.mock_entity.pk) + self.assertIn('timestamp', result) - assignee_id: int = Field(..., description='User to assign task to') - priority: str = Field('normal', description='Task priority') - deadline: datetime = Field(None, description='Task deadline') + def test_validation_error_handling(self): + """Test transition validation error handling""" + @register_transition('test_entity', 'validation_test') + class ValidationTestTransition(BaseTransition): @property def target_state(self) -> str: - return 'ASSIGNED' + return TestStateChoices.COMPLETED def validate_transition(self, context: TransitionContext) -> bool: - # Multiple validation conditions - if context.current_state not in ['CREATED', 'UNASSIGNED']: + if context.current_state != TestStateChoices.IN_PROGRESS: raise TransitionValidationError( - f'Cannot assign task in state {context.current_state}', - {'current_state': context.current_state, 'task_id': context.entity.pk}, + 'Can only complete from IN_PROGRESS state', {'current_state': context.current_state} ) - - # Check deadline is in future - if self.deadline and self.deadline <= datetime.now(): - raise TransitionValidationError( - 'Deadline must be in the future', {'deadline': self.deadline.isoformat()} - ) - - # Check priority is valid - valid_priorities = ['low', 'normal', 'high', 'urgent'] - if self.priority not in valid_priorities: - raise TransitionValidationError( - f'Invalid priority: {self.priority}', {'valid_priorities': valid_priorities} - ) - return True def transition(self, context: TransitionContext) -> Dict[str, Any]: - return { - 'assignee_id': self.assignee_id, - 'priority': self.priority, - 'deadline': self.deadline.isoformat() if self.deadline else None, - 'assigned_by': context.current_user.id if context.current_user else None, - 'assignment_reason': f'Task assigned to user {self.assignee_id}', - } - - # Test valid assignment - future_deadline = datetime.now() + timedelta(days=7) - transition = TaskAssignmentTransition(assignee_id=456, priority='high', deadline=future_deadline) + return {} - context = TransitionContext( - entity=self.task, current_user=self.user, current_state='CREATED', target_state=transition.target_state + transition = ValidationTestTransition() + invalid_context = TransitionContext( + entity=self.mock_entity, + current_state=TestStateChoices.CREATED, + target_state=transition.target_state, ) - self.assertTrue(transition.validate_transition(context)) - - # Test invalid state - context.current_state = 'COMPLETED' + # Test validation error with self.assertRaises(TransitionValidationError) as cm: - transition.validate_transition(context) - - self.assertIn('Cannot assign task in state', str(cm.exception)) - self.assertIn('COMPLETED', str(cm.exception)) + transition.validate_transition(invalid_context) - # Test invalid deadline - past_deadline = datetime.now() - timedelta(days=1) - invalid_transition = TaskAssignmentTransition(assignee_id=456, deadline=past_deadline) - - context.current_state = 'CREATED' - with self.assertRaises(TransitionValidationError) as cm: - invalid_transition.validate_transition(context) - - self.assertIn('Deadline must be in the future', str(cm.exception)) - - def test_hooks_and_lifecycle_example(self): - """ - USAGE EXAMPLE: Using pre/post hooks for side effects - - Shows how to implement lifecycle hooks for notifications, - cleanup, or other side effects. - """ + error = cm.exception + self.assertIn('Can only complete from IN_PROGRESS state', str(error)) + self.assertIn('current_state', error.context) - class NotificationTransition(BaseTransition): - """Example: Transition with notification hooks""" + def test_transition_builder_basic(self): + """Test TransitionBuilder basic functionality""" - notification_message: str = Field(..., description='Notification message') - notify_users: list = Field(default_factory=list, description='Users to notify') - notifications_sent: list = Field(default_factory=list, description='Track sent notifications') - cleanup_performed: bool = Field(default=False, description='Track cleanup status') + @register_transition('test_entity', 'builder_test') + class BuilderTestTransition(BaseTransition): + value: str = Field('default', description='Test value') @property def target_state(self) -> str: - return 'NOTIFIED' - - @classmethod - def get_target_state(cls) -> str: - return 'NOTIFIED' - - @classmethod - def can_transition_from_state(cls, context: TransitionContext) -> bool: - return True - - def pre_transition_hook(self, context: TransitionContext) -> None: - """Prepare notifications before state change""" - # Validate notification recipients - if not self.notify_users: - self.notify_users = [context.current_user.id] if context.current_user else [] + return TestStateChoices.COMPLETED def transition(self, context: TransitionContext) -> Dict[str, Any]: - return { - 'notification_message': self.notification_message, - 'notify_users': self.notify_users, - 'notification_sent_at': context.timestamp.isoformat(), - } - - def post_transition_hook(self, context: TransitionContext, state_record) -> None: - """Send notifications after successful state change""" - # Mock notification sending - for user_id in self.notify_users: - self.notifications_sent.append( - {'user_id': user_id, 'message': self.notification_message, 'sent_at': context.timestamp} - ) - - # Mock cleanup - self.cleanup_performed = True - - # Test the hooks - transition = NotificationTransition(notification_message='Task has been updated', notify_users=[123, 456]) - - context = TransitionContext( - entity=self.task, current_user=self.user, current_state='CREATED', target_state=transition.target_state - ) - - # Test pre-hook - transition.pre_transition_hook(context) - self.assertEqual(transition.notify_users, [123, 456]) - - # Test transition - data = transition.transition(context) - self.assertEqual(data['notification_message'], 'Task has been updated') - - # Test post-hook - mock_state_record = Mock() - transition.post_transition_hook(context, mock_state_record) + return {'value': self.value} - self.assertEqual(len(transition.notifications_sent), 2) - self.assertTrue(transition.cleanup_performed) + # Test builder creation + builder = TransitionBuilder(self.mock_entity) + self.assertEqual(builder.entity, self.mock_entity) - def test_conditional_transition_example(self): - """ - USAGE EXAMPLE: Conditional transitions based on data + # Test method chaining + builder = builder.transition('builder_test').with_data(value='builder_test_value').by_user(self.user) - Shows how to implement transitions that behave differently - based on input data or context. - """ + # Validate the builder state + validation_errors = builder.validate() + self.assertEqual(len(validation_errors), 0) - class ConditionalApprovalTransition(BaseTransition): - """Example: Conditional approval based on confidence""" - - confidence_score: float = Field(..., ge=0.0, le=1.0, description='Confidence score') - auto_approve_threshold: float = Field(0.9, description='Auto-approval threshold') - reviewer_id: int = Field(None, description='Manual reviewer ID') + def test_get_available_transitions(self): + """Test get_available_transitions utility""" + @register_transition('test_entity', 'available_test') + class AvailableTestTransition(BaseTransition): @property def target_state(self) -> str: - # Dynamic target state based on confidence - if self.confidence_score >= self.auto_approve_threshold: - return 'AUTO_APPROVED' - else: - return 'PENDING_REVIEW' - - def validate_transition(self, context: TransitionContext) -> bool: - # Different validation based on approval type - if self.confidence_score >= self.auto_approve_threshold: - # Auto-approval validation - if context.current_state != 'SUBMITTED': - raise TransitionValidationError('Can only auto-approve submitted items') - else: - # Manual review validation - if not self.reviewer_id: - raise TransitionValidationError('Manual review requires reviewer_id') - - return True + return TestStateChoices.COMPLETED def transition(self, context: TransitionContext) -> Dict[str, Any]: - base_data = { - 'confidence_score': self.confidence_score, - 'threshold': self.auto_approve_threshold, - } - - if self.confidence_score >= self.auto_approve_threshold: - # Auto-approval data - return { - **base_data, - 'approval_type': 'automatic', - 'approved_at': context.timestamp.isoformat(), - 'approved_by': 'system', - } - else: - # Manual review data - return { - **base_data, - 'approval_type': 'manual', - 'assigned_reviewer': self.reviewer_id, - 'review_requested_at': context.timestamp.isoformat(), - } - - # Test auto-approval path - high_confidence_transition = ConditionalApprovalTransition(confidence_score=0.95) - - self.assertEqual(high_confidence_transition.target_state, 'AUTO_APPROVED') - - context = TransitionContext( - entity=self.task, current_state='SUBMITTED', target_state=high_confidence_transition.target_state - ) - - self.assertTrue(high_confidence_transition.validate_transition(context)) - - auto_data = high_confidence_transition.transition(context) - self.assertEqual(auto_data['approval_type'], 'automatic') - self.assertEqual(auto_data['approved_by'], 'system') - - # Test manual review path - low_confidence_transition = ConditionalApprovalTransition(confidence_score=0.7, reviewer_id=789) - - self.assertEqual(low_confidence_transition.target_state, 'PENDING_REVIEW') + return {} - context.target_state = low_confidence_transition.target_state - self.assertTrue(low_confidence_transition.validate_transition(context)) + available = get_available_transitions(self.mock_entity) + self.assertIn('available_test', available) + self.assertEqual(available['available_test'], AvailableTestTransition) - manual_data = low_confidence_transition.transition(context) - self.assertEqual(manual_data['approval_type'], 'manual') - self.assertEqual(manual_data['assigned_reviewer'], 789) + def test_transition_hooks(self): + """Test pre and post transition hooks""" - def test_registry_and_decorator_usage(self): - """ - USAGE EXAMPLE: Using the registry and decorator system - - Shows how to register transitions and use the decorator syntax. - """ - - @register_transition('document', 'publish') - class PublishDocumentTransition(BaseTransition): - """Example: Using the registration decorator""" - - publish_immediately: bool = Field(True, description='Publish immediately') - scheduled_time: datetime = Field(None, description='Scheduled publish time') + hook_calls = [] + @register_transition('test_entity', 'hook_test') + class HookTestTransition(BaseTransition): @property def target_state(self) -> str: - return 'PUBLISHED' if self.publish_immediately else 'SCHEDULED' - - def transition(self, context: TransitionContext) -> Dict[str, Any]: - return { - 'publish_immediately': self.publish_immediately, - 'scheduled_time': self.scheduled_time.isoformat() if self.scheduled_time else None, - 'published_by': context.current_user.id if context.current_user else None, - } + return TestStateChoices.COMPLETED - # Test registration worked - registered_class = transition_registry.get_transition('document', 'publish') - self.assertEqual(registered_class, PublishDocumentTransition) + def pre_transition_hook(self, context: TransitionContext) -> None: + hook_calls.append('pre') - # Test getting transitions for entity - document_transitions = transition_registry.get_transitions_for_entity('document') - self.assertIn('publish', document_transitions) + def transition(self, context: TransitionContext) -> Dict[str, Any]: + hook_calls.append('transition') + return {} - # Test execution through registry - mock_document = Mock() - mock_document.pk = 1 - mock_document._meta.model_name = 'document' + def post_transition_hook(self, context: TransitionContext, state_record) -> None: + hook_calls.append('post') - # This would normally go through the full execution workflow - transition_data = {'publish_immediately': False, 'scheduled_time': datetime.now() + timedelta(hours=2)} + transition = HookTestTransition() + context = TransitionContext( + entity=self.mock_entity, + current_state=TestStateChoices.IN_PROGRESS, + target_state=transition.target_state, + ) - # Test transition creation and validation - transition = PublishDocumentTransition(**transition_data) - self.assertEqual(transition.target_state, 'SCHEDULED') + # Test hook execution order + transition.pre_transition_hook(context) + transition.transition(context) + transition.post_transition_hook(context, Mock()) + self.assertEqual(hook_calls, ['pre', 'transition', 'post']) -class ValidationAndErrorHandlingTests(TestCase): - """ - Tests focused on validation scenarios and error handling. - These tests demonstrate proper error handling patterns and - validation edge cases. - """ +class TransitionUtilsTests(TestCase): + """Test transition utility functions""" def setUp(self): - self.task = MockTask() - transition_registry._transitions.clear() - - def test_pydantic_validation_errors(self): - """Test Pydantic field validation errors""" - - class StrictValidationTransition(BaseTransition): - required_field: str = Field(..., description='Required field') - email_field: str = Field(..., pattern=r'^[\w\.-]+@[\w\.-]+\.\w+$', description='Valid email') - number_field: int = Field(..., ge=1, le=100, description='Number between 1-100') - - @property - def target_state(self) -> str: - return 'VALIDATED' - - @classmethod - def get_target_state(cls) -> str: - return 'VALIDATED' - - @classmethod - def can_transition_from_state(cls, context: TransitionContext) -> bool: - return True - - def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {'validated': True} - - # Test missing required field - with self.assertRaises(ValidationError): - StrictValidationTransition(email_field='test@example.com', number_field=50) - - # Test invalid email - with self.assertRaises(ValidationError): - StrictValidationTransition(required_field='test', email_field='invalid-email', number_field=50) - - # Test number out of range - with self.assertRaises(ValidationError): - StrictValidationTransition(required_field='test', email_field='test@example.com', number_field=150) - - # Test valid data - valid_transition = StrictValidationTransition( - required_field='test', email_field='user@example.com', number_field=75 - ) - self.assertEqual(valid_transition.required_field, 'test') + """Set up test data""" + self.user = User.objects.create_user(email='test@example.com', password='test123') + self.mock_entity = MockEntity() + transition_registry.clear() - def test_business_logic_validation_errors(self): - """Test business logic validation with detailed error context""" + def tearDown(self): + """Clean up after tests""" + transition_registry.clear() - class BusinessRuleTransition(BaseTransition): - amount: float = Field(..., description='Transaction amount') - currency: str = Field('USD', description='Currency code') + def test_get_available_transitions(self): + """Test getting available transitions for an entity""" + @register_transition('test_entity', 'util_test_1') + class UtilTestTransition1(BaseTransition): @property def target_state(self) -> str: - return 'PROCESSED' - - def validate_transition(self, context: TransitionContext) -> bool: - # Complex business rule validation - errors = [] - - if self.amount <= 0: - errors.append('Amount must be positive') - - if self.amount > 10000 and context.current_user is None: - errors.append('Large amounts require authenticated user') - - if self.currency not in ['USD', 'EUR', 'GBP']: - errors.append(f'Unsupported currency: {self.currency}') - - if context.current_state == 'CANCELLED': - errors.append('Cannot process cancelled transactions') - - if errors: - raise TransitionValidationError( - f"Validation failed: {'; '.join(errors)}", - { - 'validation_errors': errors, - 'amount': self.amount, - 'currency': self.currency, - 'current_state': context.current_state, - }, - ) - - return True + return TestStateChoices.IN_PROGRESS def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {'amount': self.amount, 'currency': self.currency} - - context = TransitionContext(entity=self.task, current_state='PENDING', target_state='PROCESSED') - - # Test negative amount - negative_transition = BusinessRuleTransition(amount=-100) - with self.assertRaises(TransitionValidationError) as cm: - negative_transition.validate_transition(context) - - error = cm.exception - self.assertIn('Amount must be positive', str(error)) - self.assertIn('validation_errors', error.context) - - # Test large amount without user - large_transition = BusinessRuleTransition(amount=15000) - with self.assertRaises(TransitionValidationError) as cm: - large_transition.validate_transition(context) - - self.assertIn('Large amounts require authenticated user', str(cm.exception)) - - # Test invalid currency - invalid_currency_transition = BusinessRuleTransition(amount=100, currency='XYZ') - with self.assertRaises(TransitionValidationError) as cm: - invalid_currency_transition.validate_transition(context) - - self.assertIn('Unsupported currency', str(cm.exception)) - - # Test multiple errors - multi_error_transition = BusinessRuleTransition(amount=-50, currency='XYZ') - with self.assertRaises(TransitionValidationError) as cm: - multi_error_transition.validate_transition(context) - - error_msg = str(cm.exception) - self.assertIn('Amount must be positive', error_msg) - self.assertIn('Unsupported currency', error_msg) - - def test_context_validation_errors(self): - """Test validation errors related to context state""" - - class ContextAwareTransition(BaseTransition): - action: str = Field(..., description='Action to perform') + return {} + @register_transition('test_entity', 'util_test_2') + class UtilTestTransition2(BaseTransition): @property def target_state(self) -> str: - return 'ACTIONED' - - def validate_transition(self, context: TransitionContext) -> bool: - # State-dependent validation - if context.is_initial_transition and self.action != 'create': - raise TransitionValidationError( - "Initial transition must be 'create' action", {'action': self.action, 'is_initial': True} - ) - - if context.current_state == 'COMPLETED' and self.action in ['modify', 'update']: - raise TransitionValidationError( - f'Cannot {self.action} completed items', - {'action': self.action, 'current_state': context.current_state}, - ) - - return True + return TestStateChoices.COMPLETED def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {'action': self.action} - - # Test initial transition validation - create_transition = ContextAwareTransition(action='create') - initial_context = TransitionContext( - entity=self.task, current_state=None, target_state='ACTIONED' # No current state = initial - ) - - self.assertTrue(create_transition.validate_transition(initial_context)) - - # Test invalid initial action - modify_transition = ContextAwareTransition(action='modify') - with self.assertRaises(TransitionValidationError) as cm: - modify_transition.validate_transition(initial_context) - - error = cm.exception - self.assertIn("Initial transition must be 'create'", str(error)) - self.assertTrue(error.context['is_initial']) - - # Test completed state validation - completed_context = TransitionContext(entity=self.task, current_state='COMPLETED', target_state='ACTIONED') - - with self.assertRaises(TransitionValidationError) as cm: - modify_transition.validate_transition(completed_context) - - self.assertIn('Cannot modify completed items', str(cm.exception)) - - -@pytest.fixture -def task(): - """Pytest fixture for mock task""" - return MockTask() - - -@pytest.fixture -def user(): - """Pytest fixture for mock user""" - user = Mock() - user.id = 1 - user.username = 'testuser' - return user - - -def test_transition_context_properties(task, user): - """Test TransitionContext properties using pytest""" - context = TransitionContext(entity=task, current_user=user, current_state='CREATED', target_state='IN_PROGRESS') - - assert context.has_current_state - assert not context.is_initial_transition - assert context.current_state == 'CREATED' - assert context.target_state == 'IN_PROGRESS' - - -def test_pydantic_validation(): - """Test Pydantic validation in transitions""" - # Valid data - transition = TestTransition(test_field='valid') - assert transition.test_field == 'valid' - assert transition.optional_field == 42 - - # Invalid data should raise validation error - with pytest.raises(Exception): # Pydantic validation error - TestTransition() # Missing required field + return {} + + available = get_available_transitions(self.mock_entity) + self.assertEqual(len(available), 2) + self.assertIn('util_test_1', available) + self.assertIn('util_test_2', available) + + # Test with non-existent entity + mock_other = MockEntity() + mock_other._meta.model_name = 'other_entity' + other_available = get_available_transitions(mock_other) + self.assertEqual(len(other_available), 0) diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py deleted file mode 100644 index 648735a8b620..000000000000 --- a/label_studio/fsm/tests/test_fsm_integration.py +++ /dev/null @@ -1,303 +0,0 @@ -""" -Integration tests for the FSM system. - -Tests the complete FSM functionality including models, state management, -and API endpoints. -""" - -from datetime import datetime, timezone - -from django.contrib.auth import get_user_model -from django.test import TestCase -from fsm.models import AnnotationState, ProjectState, TaskState -from fsm.state_manager import get_state_manager -from projects.models import Project -from rest_framework.test import APITestCase -from tasks.models import Annotation, Task - -User = get_user_model() - - -class TestFSMModels(TestCase): - """Test FSM model functionality""" - - def setUp(self): - self.user = User.objects.create_user(email='test@example.com', password='test123') - self.project = Project.objects.create(title='Test Project', created_by=self.user) - self.task = Task.objects.create(project=self.project, data={'text': 'test'}) - - # Clear cache to ensure tests start with clean state - from django.core.cache import cache - - cache.clear() - - def test_task_state_creation(self): - """Test TaskState creation and basic functionality""" - task_state = TaskState.objects.create( - task=self.task, - project_id=self.task.project_id, # Denormalized from task.project_id - state='CREATED', - triggered_by=self.user, - reason='Task created for testing', - ) - - # Check basic fields - self.assertEqual(task_state.state, 'CREATED') - self.assertEqual(task_state.task, self.task) - self.assertEqual(task_state.triggered_by, self.user) - - # Check UUID7 functionality - self.assertEqual(task_state.id.version, 7) - self.assertIsInstance(task_state.timestamp_from_uuid, datetime) - - # Check string representation - str_repr = str(task_state) - self.assertIn('Task', str_repr) - self.assertIn('CREATED', str_repr) - - def test_annotation_state_creation(self): - """Test AnnotationState creation and basic functionality""" - annotation = Annotation.objects.create(task=self.task, completed_by=self.user, result=[]) - - annotation_state = AnnotationState.objects.create( - annotation=annotation, - task_id=annotation.task.id, # Denormalized from annotation.task_id - project_id=annotation.task.project_id, # Denormalized from annotation.task.project_id - completed_by_id=annotation.completed_by.id if annotation.completed_by else None, # Denormalized - state='DRAFT', - triggered_by=self.user, - reason='Annotation draft created', - ) - - # Check basic fields - self.assertEqual(annotation_state.state, 'DRAFT') - self.assertEqual(annotation_state.annotation, annotation) - - # Check terminal state property - self.assertFalse(annotation_state.is_terminal_state) - - # Test completed state - completed_state = AnnotationState.objects.create( - annotation=annotation, - task_id=annotation.task.id, - project_id=annotation.task.project_id, - completed_by_id=annotation.completed_by.id if annotation.completed_by else None, - state='COMPLETED', - triggered_by=self.user, - ) - self.assertTrue(completed_state.is_terminal_state) - - def test_project_state_creation(self): - """Test ProjectState creation and basic functionality""" - project_state = ProjectState.objects.create( - project=self.project, state='CREATED', triggered_by=self.user, reason='Project created for testing' - ) - - # Check basic fields - self.assertEqual(project_state.state, 'CREATED') - self.assertEqual(project_state.project, self.project) - - # Test terminal state - self.assertFalse(project_state.is_terminal_state) - - completed_state = ProjectState.objects.create(project=self.project, state='COMPLETED', triggered_by=self.user) - self.assertTrue(completed_state.is_terminal_state) - - -class TestStateManager(TestCase): - """Test StateManager functionality""" - - def setUp(self): - self.user = User.objects.create_user(email='test@example.com', password='test123') - self.project = Project.objects.create(title='Test Project', created_by=self.user) - self.task = Task.objects.create(project=self.project, data={'text': 'test'}) - self.StateManager = get_state_manager() - - # Clear cache to ensure tests start with clean state - from django.core.cache import cache - - cache.clear() - - def test_get_current_state_empty(self): - """Test getting current state when no states exist""" - current_state = self.StateManager.get_current_state(self.task) - self.assertIsNone(current_state) - - def test_transition_state(self): - """Test state transition functionality""" - # Initial transition - success = self.StateManager.transition_state( - entity=self.task, - new_state='CREATED', - user=self.user, - transition_name='create_task', - reason='Initial task creation', - ) - - self.assertTrue(success) - - # Check current state - current_state = self.StateManager.get_current_state(self.task) - self.assertEqual(current_state, 'CREATED') - - # Another transition - success = self.StateManager.transition_state( - entity=self.task, - new_state='IN_PROGRESS', - user=self.user, - transition_name='start_work', - context={'started_by': 'user'}, - ) - - self.assertTrue(success) - current_state = self.StateManager.get_current_state(self.task) - self.assertEqual(current_state, 'IN_PROGRESS') - - def test_get_current_state_object(self): - """Test getting current state object with full details""" - # Create some state transitions - self.StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) - self.StateManager.transition_state( - entity=self.task, new_state='IN_PROGRESS', user=self.user, context={'test': 'data'} - ) - - current_state_obj = self.StateManager.get_current_state_object(self.task) - - self.assertIsNotNone(current_state_obj) - self.assertEqual(current_state_obj.state, 'IN_PROGRESS') - self.assertEqual(current_state_obj.previous_state, 'CREATED') - self.assertEqual(current_state_obj.triggered_by, self.user) - self.assertEqual(current_state_obj.context_data, {'test': 'data'}) - - def test_get_state_history(self): - """Test state history retrieval""" - # Create multiple transitions - transitions = [('CREATED', 'create_task'), ('IN_PROGRESS', 'start_work'), ('COMPLETED', 'finish_work')] - - for state, transition in transitions: - self.StateManager.transition_state( - entity=self.task, new_state=state, user=self.user, transition_name=transition - ) - - history = self.StateManager.get_state_history(self.task, limit=10) - - # Should have 3 state records - self.assertEqual(len(history), 3) - - # Should be ordered by most recent first (UUID7 ordering) - states = [h.state for h in history] - self.assertEqual(states, ['COMPLETED', 'IN_PROGRESS', 'CREATED']) - - print(history) - ids = [str(h.id) for h in history] - print(ids) - - # Check previous states are set correctly - self.assertIsNone(history[2].previous_state) # First state has no previous - self.assertEqual(history[1].previous_state, 'CREATED') - self.assertEqual(history[0].previous_state, 'IN_PROGRESS') - - def test_get_states_in_time_range(self): - """Test time-based state queries using UUID7""" - # Record time before creating states - before_time = datetime.now(timezone.utc) - - # Create some states - self.StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) - self.StateManager.transition_state(entity=self.task, new_state='IN_PROGRESS', user=self.user) - - # Record time after creating states - after_time = datetime.now(timezone.utc) - - # Query states in time range - states_in_range = self.StateManager.get_states_in_time_range(self.task, before_time, after_time) - - # Should find both states - self.assertEqual(len(states_in_range), 2) - - -class TestFSMAPI(APITestCase): - """Test FSM API endpoints""" - - def setUp(self): - self.user = User.objects.create_user(email='test@example.com', password='test123') - self.project = Project.objects.create(title='Test Project', created_by=self.user) - self.task = Task.objects.create(project=self.project, data={'text': 'test'}) - self.client.force_authenticate(user=self.user) - - # Clear cache to ensure tests start with clean state - from django.core.cache import cache - - cache.clear() - - # Create initial state - StateManager = get_state_manager() - StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) - - def test_get_current_state_api(self): - """Test GET /api/fsm/{entity_type}/{entity_id}/current/""" - response = self.client.get(f'/api/fsm/task/{self.task.id}/current/') - - self.assertEqual(response.status_code, 200) - data = response.json() - - self.assertEqual(data['current_state'], 'CREATED') - self.assertEqual(data['entity_type'], 'task') - self.assertEqual(data['entity_id'], self.task.id) - - def test_get_state_history_api(self): - """Test GET /api/fsm/{entity_type}/{entity_id}/history/""" - # Create additional states - StateManager = get_state_manager() - StateManager.transition_state( - entity=self.task, new_state='IN_PROGRESS', user=self.user, transition_name='start_work' - ) - - response = self.client.get(f'/api/fsm/task/{self.task.id}/history/') - - self.assertEqual(response.status_code, 200) - data = response.json() - - self.assertEqual(data['count'], 2) - self.assertEqual(len(data['results']), 2) - - # Check first result (most recent) - latest_state = data['results'][0] - self.assertEqual(latest_state['state'], 'IN_PROGRESS') - self.assertEqual(latest_state['previous_state'], 'CREATED') - self.assertEqual(latest_state['transition_name'], 'start_work') - - def test_transition_state_api(self): - """Test POST /api/fsm/{entity_type}/{entity_id}/transition/""" - transition_data = { - 'new_state': 'IN_PROGRESS', - 'transition_name': 'start_annotation', - 'reason': 'User started working on task', - 'context': {'assignment_id': 123}, - } - - response = self.client.post(f'/api/fsm/task/{self.task.id}/transition/', data=transition_data, format='json') - - self.assertEqual(response.status_code, 200) - data = response.json() - - self.assertTrue(data['success']) - self.assertEqual(data['previous_state'], 'CREATED') - self.assertEqual(data['new_state'], 'IN_PROGRESS') - self.assertEqual(data['entity_type'], 'task') - self.assertEqual(data['entity_id'], self.task.id) - - # Verify state was actually changed - StateManager = get_state_manager() - current_state = StateManager.get_current_state(self.task) - self.assertEqual(current_state, 'IN_PROGRESS') - - def test_api_with_invalid_entity(self): - """Test API with non-existent entity""" - response = self.client.get('/api/fsm/task/99999/current/') - self.assertEqual(response.status_code, 404) - - def test_api_with_invalid_entity_type(self): - """Test API with invalid entity type""" - response = self.client.get('/api/fsm/invalid/1/current/') - self.assertEqual(response.status_code, 404) diff --git a/label_studio/fsm/tests/test_integration_django_models.py b/label_studio/fsm/tests/test_integration_django_models.py deleted file mode 100644 index b7c277b0d25b..000000000000 --- a/label_studio/fsm/tests/test_integration_django_models.py +++ /dev/null @@ -1,671 +0,0 @@ -""" -Integration tests for declarative transitions with real Django models. - -These tests demonstrate how the transition system integrates with actual -Django models and the StateManager, providing realistic usage examples. -""" - -from datetime import datetime -from typing import Any, Dict -from unittest.mock import Mock, patch - -from django.contrib.auth import get_user_model -from django.test import TestCase -from fsm.models import TaskState -from fsm.state_choices import AnnotationStateChoices, TaskStateChoices -from fsm.transition_utils import TransitionBuilder -from fsm.transitions import BaseTransition, TransitionContext, TransitionValidationError, register_transition -from pydantic import Field - - -# Mock Django models for integration testing -class MockDjangoTask: - """Mock Django Task model with realistic attributes""" - - def __init__(self, pk=1, project_id=1, organization_id=1): - self.pk = pk - self.id = pk - self.project_id = project_id - self.organization_id = organization_id - self._meta = Mock() - self._meta.model_name = 'task' - self._meta.label_lower = 'tasks.task' - - # Mock task attributes - self.data = {'text': 'Sample task data'} - self.created_at = datetime.now() - self.updated_at = datetime.now() - - -class MockDjangoAnnotation: - """Mock Django Annotation model with realistic attributes""" - - def __init__(self, pk=1, task_id=1, project_id=1, organization_id=1): - self.pk = pk - self.id = pk - self.task_id = task_id - self.project_id = project_id - self.organization_id = organization_id - self._meta = Mock() - self._meta.model_name = 'annotation' - self._meta.label_lower = 'tasks.annotation' - - # Mock annotation attributes - self.result = [{'value': {'text': ['Sample annotation']}}] - self.completed_by_id = None - self.created_at = datetime.now() - self.updated_at = datetime.now() - - -User = get_user_model() - - -class DjangoModelIntegrationTests(TestCase): - """ - Integration tests demonstrating realistic usage with Django models. - - These tests show how to implement transitions that work with actual - Django model patterns and the StateManager integration. - """ - - def setUp(self): - self.task = MockDjangoTask() - self.annotation = MockDjangoAnnotation() - self.user = Mock() - self.user.id = 123 - self.user.username = 'integration_test_user' - - # Clear registry for clean test state - from fsm.transitions import transition_registry - - transition_registry._transitions.clear() - - @patch('fsm.models.get_state_model_for_entity') - @patch('fsm.state_manager.StateManager.get_current_state_object') - @patch('fsm.state_manager.StateManager.transition_state') - def test_task_workflow_integration(self, mock_transition_state, mock_get_state_obj, mock_get_state_model): - """ - INTEGRATION TEST: Complete task workflow using Django models - - Demonstrates a realistic task lifecycle from creation through completion - using the declarative transition system with Django model integration. - """ - - # Setup mocks to simulate Django model behavior - mock_get_state_model.return_value = TaskState - mock_get_state_obj.return_value = None # No existing state (initial transition) - mock_transition_state.return_value = True - - # Define task workflow transitions - @register_transition('task', 'create_task') - class CreateTaskTransition(BaseTransition): - """Initial task creation transition""" - - created_by_id: int = Field(..., description='User creating the task') - initial_priority: str = Field('normal', description='Initial task priority') - - @property - def target_state(self) -> str: - return TaskStateChoices.CREATED - - def validate_transition(self, context: TransitionContext) -> bool: - # Validate initial creation - if not context.is_initial_transition: - raise TransitionValidationError('CreateTask can only be used for initial state') - return True - - def transition(self, context: TransitionContext) -> Dict[str, Any]: - return { - 'created_by_id': self.created_by_id, - 'initial_priority': self.initial_priority, - 'task_data': getattr(context.entity, 'data', {}), - 'project_id': getattr(context.entity, 'project_id', None), - 'creation_method': 'declarative_transition', - } - - @register_transition('task', 'assign_and_start') - class AssignAndStartTaskTransition(BaseTransition): - """Assign task to user and start work""" - - assignee_id: int = Field(..., description='User assigned to task') - estimated_hours: float = Field(None, ge=0.1, description='Estimated work hours') - priority: str = Field('normal', description='Task priority') - - @property - def target_state(self) -> str: - return TaskStateChoices.IN_PROGRESS - - def validate_transition(self, context: TransitionContext) -> bool: - valid_from_states = [TaskStateChoices.CREATED] - if context.current_state not in valid_from_states: - raise TransitionValidationError( - f'Can only assign tasks from states: {valid_from_states}', - {'current_state': context.current_state, 'valid_states': valid_from_states}, - ) - - # Business rule: Can't assign to the same user who created it - if hasattr(context, 'current_state_object') and context.current_state_object: - creator_id = context.current_state_object.context_data.get('created_by_id') - if creator_id == self.assignee_id: - raise TransitionValidationError( - 'Cannot assign task to the same user who created it', - {'creator_id': creator_id, 'assignee_id': self.assignee_id}, - ) - - return True - - def transition(self, context: TransitionContext) -> Dict[str, Any]: - return { - 'assignee_id': self.assignee_id, - 'estimated_hours': self.estimated_hours, - 'priority': self.priority, - 'assigned_at': context.timestamp.isoformat(), - 'assigned_by_id': context.current_user.id if context.current_user else None, - 'work_started': True, - } - - @register_transition('task', 'complete_with_quality') - class CompleteTaskWithQualityTransition(BaseTransition): - """Complete task with quality metrics""" - - quality_score: float = Field(..., ge=0.0, le=1.0, description='Quality score') - completion_notes: str = Field('', description='Completion notes') - actual_hours: float = Field(None, ge=0.0, description='Actual hours worked') - - @property - def target_state(self) -> str: - return TaskStateChoices.COMPLETED - - def validate_transition(self, context: TransitionContext) -> bool: - if context.current_state != TaskStateChoices.IN_PROGRESS: - raise TransitionValidationError( - 'Can only complete tasks that are in progress', {'current_state': context.current_state} - ) - - # Quality check - if self.quality_score < 0.6: - raise TransitionValidationError( - f'Quality score too low: {self.quality_score}. Minimum required: 0.6' - ) - - return True - - def post_transition_hook(self, context: TransitionContext, state_record) -> None: - """Post-completion tasks like notifications""" - # Mock notification system - if hasattr(self, '_notifications'): - self._notifications.append(f'Task {context.entity.pk} completed with quality {self.quality_score}') - - def transition(self, context: TransitionContext) -> Dict[str, Any]: - # Calculate metrics - start_data = context.current_state_object.context_data if context.current_state_object else {} - estimated_hours = start_data.get('estimated_hours') - - return { - 'quality_score': self.quality_score, - 'completion_notes': self.completion_notes, - 'actual_hours': self.actual_hours, - 'estimated_hours': estimated_hours, - 'completed_at': context.timestamp.isoformat(), - 'completed_by_id': context.current_user.id if context.current_user else None, - 'efficiency_ratio': (estimated_hours / self.actual_hours) - if (estimated_hours and self.actual_hours) - else None, - } - - # Execute the complete workflow - - # Step 1: Create task - create_transition = CreateTaskTransition(created_by_id=100, initial_priority='high') - - # Test with StateManager integration - with patch('fsm.state_manager.StateManager.get_current_state') as mock_get_current: - mock_get_current.return_value = None # No current state - - context = TransitionContext( - entity=self.task, - current_user=self.user, - current_state=None, - target_state=create_transition.target_state, - ) - - # Validate and execute creation - self.assertTrue(create_transition.validate_transition(context)) - creation_data = create_transition.transition(context) - - self.assertEqual(creation_data['created_by_id'], 100) - self.assertEqual(creation_data['initial_priority'], 'high') - self.assertEqual(creation_data['creation_method'], 'declarative_transition') - - # Step 2: Assign and start task - mock_current_state = Mock() - mock_current_state.context_data = creation_data - mock_get_state_obj.return_value = mock_current_state - - assign_transition = AssignAndStartTaskTransition( - assignee_id=200, estimated_hours=4.5, priority='urgent' # Different from creator - ) - - context = TransitionContext( - entity=self.task, - current_user=self.user, - current_state=TaskStateChoices.CREATED, - current_state_object=mock_current_state, - target_state=assign_transition.target_state, - ) - - self.assertTrue(assign_transition.validate_transition(context)) - assignment_data = assign_transition.transition(context) - - self.assertEqual(assignment_data['assignee_id'], 200) - self.assertEqual(assignment_data['estimated_hours'], 4.5) - self.assertTrue(assignment_data['work_started']) - - # Step 3: Complete task - mock_current_state.context_data = assignment_data - - complete_transition = CompleteTaskWithQualityTransition( - quality_score=0.85, completion_notes='Task completed successfully with minor revisions', actual_hours=5.2 - ) - complete_transition._notifications = [] # Mock notification system - - context = TransitionContext( - entity=self.task, - current_user=self.user, - current_state=TaskStateChoices.IN_PROGRESS, - current_state_object=mock_current_state, - target_state=complete_transition.target_state, - ) - - self.assertTrue(complete_transition.validate_transition(context)) - completion_data = complete_transition.transition(context) - - self.assertEqual(completion_data['quality_score'], 0.85) - self.assertEqual(completion_data['actual_hours'], 5.2) - self.assertAlmostEqual(completion_data['efficiency_ratio'], 4.5 / 5.2, places=2) - - # Test post-hook - mock_state_record = Mock() - complete_transition.post_transition_hook(context, mock_state_record) - self.assertEqual(len(complete_transition._notifications), 1) - - # Verify StateManager calls - self.assertEqual(mock_transition_state.call_count, 0) # Not called in our test setup - - def test_annotation_review_workflow_integration(self): - """ - INTEGRATION TEST: Annotation review workflow - - Demonstrates a realistic annotation review process using - enterprise-grade validation and approval logic. - """ - - @register_transition('annotation', 'submit_for_review') - class SubmitAnnotationForReview(BaseTransition): - """Submit annotation for quality review""" - - annotator_confidence: float = Field(..., ge=0.0, le=1.0, description='Annotator confidence') - annotation_time_seconds: int = Field(..., ge=1, description='Time spent annotating') - review_requested: bool = Field(True, description='Whether review is requested') - - @property - def target_state(self) -> str: - return AnnotationStateChoices.SUBMITTED - - def validate_transition(self, context: TransitionContext) -> bool: - # Check annotation has content - if not hasattr(context.entity, 'result') or not context.entity.result: - raise TransitionValidationError('Cannot submit empty annotation') - - # Business rule: Low confidence annotations must request review - if self.annotator_confidence < 0.7 and not self.review_requested: - raise TransitionValidationError( - 'Low confidence annotations must request review', - {'confidence': self.annotator_confidence, 'threshold': 0.7}, - ) - - return True - - def transition(self, context: TransitionContext) -> Dict[str, Any]: - return { - 'annotator_confidence': self.annotator_confidence, - 'annotation_time_seconds': self.annotation_time_seconds, - 'review_requested': self.review_requested, - 'annotation_complexity': len(context.entity.result) if context.entity.result else 0, - 'submitted_at': context.timestamp.isoformat(), - 'submitted_by_id': context.current_user.id if context.current_user else None, - } - - @register_transition('annotation', 'review_and_approve') - class ReviewAndApproveAnnotation(BaseTransition): - """Review annotation and approve/reject""" - - reviewer_decision: str = Field(..., description='approve, reject, or request_changes') - quality_score: float = Field(..., ge=0.0, le=1.0, description='Reviewer quality assessment') - review_comments: str = Field('', description='Review comments') - corrections_made: bool = Field(False, description='Whether reviewer made corrections') - - @property - def target_state(self) -> str: - if self.reviewer_decision == 'approve': - return AnnotationStateChoices.COMPLETED - else: - return AnnotationStateChoices.DRAFT # Back to draft for changes - - def validate_transition(self, context: TransitionContext) -> bool: - if context.current_state != AnnotationStateChoices.SUBMITTED: - raise TransitionValidationError('Can only review submitted annotations') - - valid_decisions = ['approve', 'reject', 'request_changes'] - if self.reviewer_decision not in valid_decisions: - raise TransitionValidationError( - f'Invalid decision: {self.reviewer_decision}', {'valid_decisions': valid_decisions} - ) - - # Quality score validation based on decision - if self.reviewer_decision == 'approve' and self.quality_score < 0.6: - raise TransitionValidationError( - 'Cannot approve annotation with low quality score', - {'quality_score': self.quality_score, 'decision': self.reviewer_decision}, - ) - - return True - - def transition(self, context: TransitionContext) -> Dict[str, Any]: - # Get submission data for metrics - submission_data = context.current_state_object.context_data if context.current_state_object else {} - - return { - 'reviewer_decision': self.reviewer_decision, - 'quality_score': self.quality_score, - 'review_comments': self.review_comments, - 'corrections_made': self.corrections_made, - 'reviewed_at': context.timestamp.isoformat(), - 'reviewed_by_id': context.current_user.id if context.current_user else None, - 'original_confidence': submission_data.get('annotator_confidence'), - 'confidence_vs_quality_diff': abs( - submission_data.get('annotator_confidence', 0) - self.quality_score - ), - } - - # Execute annotation workflow - - # Step 1: Submit annotation - submit_transition = SubmitAnnotationForReview( - annotator_confidence=0.9, annotation_time_seconds=300, review_requested=True # 5 minutes - ) - - context = TransitionContext( - entity=self.annotation, - current_user=self.user, - current_state=AnnotationStateChoices.DRAFT, - target_state=submit_transition.target_state, - ) - - self.assertTrue(submit_transition.validate_transition(context)) - submit_data = submit_transition.transition(context) - - self.assertEqual(submit_data['annotator_confidence'], 0.9) - self.assertEqual(submit_data['annotation_time_seconds'], 300) - self.assertTrue(submit_data['review_requested']) - self.assertEqual(submit_data['annotation_complexity'], 1) # Based on mock result - - # Step 2: Review and approve - mock_submission_state = Mock() - mock_submission_state.context_data = submit_data - - review_transition = ReviewAndApproveAnnotation( - reviewer_decision='approve', - quality_score=0.85, - review_comments='High quality annotation with good coverage', - corrections_made=False, - ) - - context = TransitionContext( - entity=self.annotation, - current_user=self.user, - current_state=AnnotationStateChoices.SUBMITTED, - current_state_object=mock_submission_state, - target_state=review_transition.target_state, - ) - - self.assertTrue(review_transition.validate_transition(context)) - self.assertEqual(review_transition.target_state, AnnotationStateChoices.COMPLETED) - - review_data = review_transition.transition(context) - - self.assertEqual(review_data['reviewer_decision'], 'approve') - self.assertEqual(review_data['quality_score'], 0.85) - self.assertEqual(review_data['original_confidence'], 0.9) - self.assertAlmostEqual(review_data['confidence_vs_quality_diff'], 0.05, places=2) - - # Test rejection scenario - reject_transition = ReviewAndApproveAnnotation( - reviewer_decision='reject', - quality_score=0.3, - review_comments='Insufficient annotation quality', - corrections_made=False, - ) - - self.assertEqual(reject_transition.target_state, AnnotationStateChoices.DRAFT) - - # Test validation failure - invalid_review = ReviewAndApproveAnnotation( - reviewer_decision='approve', # Trying to approve - quality_score=0.5, # But quality too low - review_comments='Test', - ) - - with self.assertRaises(TransitionValidationError) as cm: - invalid_review.validate_transition(context) - - self.assertIn('Cannot approve annotation with low quality score', str(cm.exception)) - - @patch('fsm.transition_utils.execute_transition') - def test_transition_builder_with_django_models(self, mock_execute): - """ - INTEGRATION TEST: TransitionBuilder with Django model integration - - Shows how to use the fluent TransitionBuilder interface with - real Django models and complex business logic. - """ - - @register_transition('task', 'bulk_update_status') - class BulkUpdateTaskStatusTransition(BaseTransition): - """Bulk update task status with metadata""" - - new_status: str = Field(..., description='New status for tasks') - update_reason: str = Field(..., description='Reason for bulk update') - updated_by_system: bool = Field(False, description='Whether updated by automated system') - batch_id: str = Field(None, description='Batch operation ID') - - @property - def target_state(self) -> str: - return self.new_status - - def validate_transition(self, context: TransitionContext) -> bool: - valid_statuses = [TaskStateChoices.CREATED, TaskStateChoices.IN_PROGRESS, TaskStateChoices.COMPLETED] - if self.new_status not in valid_statuses: - raise TransitionValidationError(f'Invalid status: {self.new_status}') - - # Can't bulk update to the same status - if context.current_state == self.new_status: - raise TransitionValidationError('Cannot update to the same status') - - return True - - def transition(self, context: TransitionContext) -> Dict[str, Any]: - return { - 'new_status': self.new_status, - 'update_reason': self.update_reason, - 'updated_by_system': self.updated_by_system, - 'batch_id': self.batch_id, - 'bulk_update_timestamp': context.timestamp.isoformat(), - 'previous_status': context.current_state, - } - - # Mock successful execution - mock_state_record = Mock() - mock_state_record.id = 'mock-uuid' - mock_execute.return_value = mock_state_record - - # Test fluent interface - result = ( - TransitionBuilder(self.task) - .transition('bulk_update_status') - .with_data( - new_status=TaskStateChoices.IN_PROGRESS, - update_reason='Project priority change', - updated_by_system=True, - batch_id='batch_2024_001', - ) - .by_user(self.user) - .with_context(project_update=True, notification_level='high') - .execute() - ) - - # Verify the call - mock_execute.assert_called_once() - call_args, call_kwargs = mock_execute.call_args - - # Check call parameters - self.assertEqual(call_kwargs['entity'], self.task) - self.assertEqual(call_kwargs['transition_name'], 'bulk_update_status') - self.assertEqual(call_kwargs['user'], self.user) - - # Check transition data - transition_data = call_kwargs['transition_data'] - self.assertEqual(transition_data['new_status'], TaskStateChoices.IN_PROGRESS) - self.assertEqual(transition_data['update_reason'], 'Project priority change') - self.assertTrue(transition_data['updated_by_system']) - self.assertEqual(transition_data['batch_id'], 'batch_2024_001') - - # Check context - self.assertTrue(call_kwargs['project_update']) - self.assertEqual(call_kwargs['notification_level'], 'high') - - # Check return value - self.assertEqual(result, mock_state_record) - - def test_error_handling_with_django_models(self): - """ - INTEGRATION TEST: Error handling with Django model validation - - Tests comprehensive error handling scenarios that might occur - in real Django model integration. - """ - - @register_transition('task', 'assign_with_constraints') - class AssignTaskWithConstraints(BaseTransition): - """Task assignment with business constraints""" - - assignee_id: int = Field(..., description='User to assign to') - max_concurrent_tasks: int = Field(5, description='Max concurrent tasks per user') - skill_requirements: list = Field(default_factory=list, description='Required skills') - - @property - def target_state(self) -> str: - return TaskStateChoices.IN_PROGRESS - - def validate_transition(self, context: TransitionContext) -> bool: - errors = [] - - # Mock database checks (in real scenario, these would be actual queries) - - # 1. Check user exists and is active - if self.assignee_id <= 0: - errors.append('Invalid user ID') - - # 2. Check user's current task load - if self.max_concurrent_tasks < 1: - errors.append('Max concurrent tasks must be at least 1') - - # 3. Check skill requirements - if self.skill_requirements: - # Mock skill validation - available_skills = ['python', 'labeling', 'review'] - missing_skills = [skill for skill in self.skill_requirements if skill not in available_skills] - if missing_skills: - errors.append(f'Missing required skills: {missing_skills}') - - # 4. Check project-level constraints - if hasattr(context.entity, 'project_id'): - # Mock project validation - if context.entity.project_id <= 0: - errors.append('Invalid project configuration') - - # 5. Check organization permissions - if hasattr(context.entity, 'organization_id'): - if not context.current_user: - errors.append('User authentication required for assignment') - - if errors: - raise TransitionValidationError( - f"Assignment validation failed: {'; '.join(errors)}", - { - 'validation_errors': errors, - 'assignee_id': self.assignee_id, - 'task_id': context.entity.pk, - 'skill_requirements': self.skill_requirements, - }, - ) - - return True - - def transition(self, context: TransitionContext) -> Dict[str, Any]: - return { - 'assignee_id': self.assignee_id, - 'max_concurrent_tasks': self.max_concurrent_tasks, - 'skill_requirements': self.skill_requirements, - 'assignment_validated': True, - } - - # Test successful validation - valid_transition = AssignTaskWithConstraints( - assignee_id=123, max_concurrent_tasks=3, skill_requirements=['python', 'labeling'] - ) - - context = TransitionContext( - entity=self.task, - current_user=self.user, - current_state=TaskStateChoices.CREATED, - target_state=valid_transition.target_state, - ) - - self.assertTrue(valid_transition.validate_transition(context)) - - # Test multiple validation errors - invalid_transition = AssignTaskWithConstraints( - assignee_id=-1, # Invalid user ID - max_concurrent_tasks=0, # Invalid max tasks - skill_requirements=['nonexistent_skill'], # Missing skill - ) - - with self.assertRaises(TransitionValidationError) as cm: - invalid_transition.validate_transition(context) - - error = cm.exception - error_msg = str(error) - - # Check all validation errors are included - self.assertIn('Invalid user ID', error_msg) - self.assertIn('Max concurrent tasks must be at least 1', error_msg) - self.assertIn('Missing required skills', error_msg) - - # Check error context - self.assertIn('validation_errors', error.context) - self.assertEqual(len(error.context['validation_errors']), 3) - self.assertEqual(error.context['assignee_id'], -1) - - # Test authentication requirement - context_no_user = TransitionContext( - entity=self.task, - current_user=None, # No user - current_state=TaskStateChoices.CREATED, - target_state=valid_transition.target_state, - ) - - with self.assertRaises(TransitionValidationError) as cm: - valid_transition.validate_transition(context_no_user) - - self.assertIn('User authentication required', str(cm.exception)) From d134a1e8883d01b1a1176d16e5eb52fa283bf885 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 14:30:29 -0500 Subject: [PATCH 17/83] removing implementation details of fsm to break up the PR --- label_studio/fsm/admin.py | 175 -------------------------------- label_studio/fsm/transitions.py | 8 ++ 2 files changed, 8 insertions(+), 175 deletions(-) delete mode 100644 label_studio/fsm/admin.py diff --git a/label_studio/fsm/admin.py b/label_studio/fsm/admin.py deleted file mode 100644 index b9b027853f1f..000000000000 --- a/label_studio/fsm/admin.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -Core FSM admin interface for Label Studio. - -Provides basic admin interface for state management that can be extended -""" - -from django.contrib import admin -from django.utils.html import format_html - -from .models import AnnotationState, ProjectState, TaskState - - -class BaseStateAdmin(admin.ModelAdmin): - """ - Base admin for state models. - - Provides common admin interface functionality for all state models. - """ - - list_display = [ - 'entity_display', - 'state', - 'previous_state', - 'transition_name', - 'triggered_by', - 'created_at', - ] - list_filter = [ - 'state', - 'created_at', - 'transition_name', - ] - search_fields = [ - 'state', - 'previous_state', - 'transition_name', - 'reason', - ] - readonly_fields = [ - 'id', - 'created_at', - 'timestamp_from_uuid', - 'entity_display', - ] - ordering = ['-created_at'] - - # Limit displayed records for performance - list_per_page = 50 - list_max_show_all = 200 - - def entity_display(self, obj): - """Display the related entity information""" - try: - entity = obj.entity - return format_html( - '{} #{}', - f'/admin/{entity._meta.app_label}/{entity._meta.model_name}/{entity.pk}/change/', - entity._meta.verbose_name.title(), - entity.pk, - ) - except Exception: - return f'{obj._get_entity_name().title()} #{getattr(obj, f"{obj._get_entity_name()}_id", "?")}' - - entity_display.short_description = 'Entity' - # Note: admin_order_field is set dynamically in subclasses since model is not available here - - def timestamp_from_uuid(self, obj): - """Display timestamp extracted from UUID7""" - return obj.timestamp_from_uuid - - timestamp_from_uuid.short_description = 'UUID7 Timestamp' - - def has_add_permission(self, request): - """Disable manual creation of state records""" - return False - - def has_change_permission(self, request, obj=None): - """State records should be read-only""" - return False - - def has_delete_permission(self, request, obj=None): - """State records should not be deleted""" - return False - - -@admin.register(TaskState) -class TaskStateAdmin(BaseStateAdmin): - """Admin interface for Task state records""" - - list_display = BaseStateAdmin.list_display + ['task_id'] - list_filter = BaseStateAdmin.list_filter + ['state'] - search_fields = BaseStateAdmin.search_fields + ['task__id'] - - def task_id(self, obj): - """Display task ID with link""" - return format_html('Task #{}', obj.task.pk, obj.task.pk) - - task_id.short_description = 'Task' - task_id.admin_order_field = 'task__id' - - -@admin.register(AnnotationState) -class AnnotationStateAdmin(BaseStateAdmin): - """Admin interface for Annotation state records""" - - list_display = BaseStateAdmin.list_display + ['annotation_id', 'task_link'] - list_filter = BaseStateAdmin.list_filter + ['state'] - search_fields = BaseStateAdmin.search_fields + ['annotation__id'] - - def annotation_id(self, obj): - """Display annotation ID with link""" - return format_html( - 'Annotation #{}', obj.annotation.pk, obj.annotation.pk - ) - - annotation_id.short_description = 'Annotation' - annotation_id.admin_order_field = 'annotation__id' - - def task_link(self, obj): - """Display related task link""" - task = obj.annotation.task - return format_html('Task #{}', task.pk, task.pk) - - task_link.short_description = 'Task' - task_link.admin_order_field = 'annotation__task__id' - - -@admin.register(ProjectState) -class ProjectStateAdmin(BaseStateAdmin): - """Admin interface for Project state records""" - - list_display = BaseStateAdmin.list_display + ['project_id', 'project_title'] - list_filter = BaseStateAdmin.list_filter + ['state'] - search_fields = BaseStateAdmin.search_fields + ['project__id', 'project__title'] - - def project_id(self, obj): - """Display project ID with link""" - return format_html( - 'Project #{}', obj.project.pk, obj.project.pk - ) - - project_id.short_description = 'Project' - project_id.admin_order_field = 'project__id' - - def project_title(self, obj): - """Display project title""" - return obj.project.title[:50] + ('...' if len(obj.project.title) > 50 else '') - - project_title.short_description = 'Title' - project_title.admin_order_field = 'project__title' - - -def mark_states_as_reviewed(modeladmin, request, queryset): - """ - Admin action to mark state records as reviewed. - """ - count = queryset.count() - modeladmin.message_user(request, f'{count} state records marked as reviewed.') - - -mark_states_as_reviewed.short_description = 'Mark selected states as reviewed' - - -def export_state_history(modeladmin, request, queryset): - """ - Admin action to export state history. - """ - count = queryset.count() - modeladmin.message_user(request, f'Export initiated for {count} state records.') - - -export_state_history.short_description = 'Export state history' - - -BaseStateAdmin.actions = [mark_states_as_reviewed, export_state_history] diff --git a/label_studio/fsm/transitions.py b/label_studio/fsm/transitions.py index 92bd35a544db..1b6cdf41dd8f 100644 --- a/label_studio/fsm/transitions.py +++ b/label_studio/fsm/transitions.py @@ -379,6 +379,14 @@ def list_entities(self) -> list[str]: """Get a list of all registered entity types.""" return list(self._transitions.keys()) + def clear(self): + """ + Clear all registered transitions. + + Useful for testing to ensure clean state between tests. + """ + self._transitions.clear() + def execute_transition( self, entity_name: str, From a453ab693183738491edd8be7802aa55c16504ac Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 14:37:41 -0500 Subject: [PATCH 18/83] removing implementation details of fsm to break up the PR --- label_studio/fsm/apps.py | 13 ------ label_studio/fsm/integration.py | 78 --------------------------------- 2 files changed, 91 deletions(-) diff --git a/label_studio/fsm/apps.py b/label_studio/fsm/apps.py index a451c39fabbe..4eb0e0720da7 100644 --- a/label_studio/fsm/apps.py +++ b/label_studio/fsm/apps.py @@ -37,22 +37,9 @@ def _setup_signals(self): """Set up signal handlers for automatic state creation""" try: from django.conf import settings - from django.db.models.signals import post_save # Only set up signals if enabled in settings if getattr(settings, 'FSM_AUTO_CREATE_STATES', False): - from label_studio.projects.models import Project - - # Import models - from label_studio.tasks.models import Annotation, Task - - from .integration import handle_annotation_created, handle_project_created, handle_task_created - - # Connect signal handlers - post_save.connect(handle_task_created, sender=Task) - post_save.connect(handle_annotation_created, sender=Annotation) - post_save.connect(handle_project_created, sender=Project) - logger.info('FSM signal handlers registered') except Exception as e: diff --git a/label_studio/fsm/integration.py b/label_studio/fsm/integration.py index 56a187fa9239..560e1eaf76eb 100644 --- a/label_studio/fsm/integration.py +++ b/label_studio/fsm/integration.py @@ -117,84 +117,6 @@ def get_fsm_state_history_method(self, limit: int = 100): return model_class -# Signal handlers for automatic state transitions - - -def handle_task_created(sender, instance, created, **kwargs): - """ - Signal handler to automatically create initial state when a task is created. - - Connect this to the Task model's post_save signal: - from django.db.models.signals import post_save - from tasks.models import Task - from fsm.integration import handle_task_created - - post_save.connect(handle_task_created, sender=Task) - """ - if created: - try: - StateManager = get_state_manager() - StateManager.transition_state( - entity=instance, - new_state='CREATED', - transition_name='create_task', - reason='Task created automatically', - ) - logger.info(f'Created initial FSM state for Task {instance.pk}') - except Exception as e: - logger.error(f'Failed to create initial FSM state for Task {instance.pk}: {e}') - - -def handle_annotation_created(sender, instance, created, **kwargs): - """ - Signal handler to automatically create initial state when an annotation is created. - - Connect this to the Annotation model's post_save signal: - from django.db.models.signals import post_save - from tasks.models import Annotation - from fsm.integration import handle_annotation_created - - post_save.connect(handle_annotation_created, sender=Annotation) - """ - if created: - try: - StateManager = get_state_manager() - StateManager.transition_state( - entity=instance, - new_state='DRAFT', - transition_name='create_annotation', - reason='Annotation created automatically', - ) - logger.info(f'Created initial FSM state for Annotation {instance.pk}') - except Exception as e: - logger.error(f'Failed to create initial FSM state for Annotation {instance.pk}: {e}') - - -def handle_project_created(sender, instance, created, **kwargs): - """ - Signal handler to automatically create initial state when a project is created. - - Connect this to the Project model's post_save signal: - from django.db.models.signals import post_save - from projects.models import Project - from fsm.integration import handle_project_created - - post_save.connect(handle_project_created, sender=Project) - """ - if created: - try: - StateManager = get_state_manager() - StateManager.transition_state( - entity=instance, - new_state='CREATED', - transition_name='create_project', - reason='Project created automatically', - ) - logger.info(f'Created initial FSM state for Project {instance.pk}') - except Exception as e: - logger.error(f'Failed to create initial FSM state for Project {instance.pk}: {e}') - - # Utility functions for model extensions From 6dc9933d1b1484fb92679920f37324983e18d85b Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 14:49:37 -0500 Subject: [PATCH 19/83] updating doc strings --- label_studio/fsm/README.md | 2 +- label_studio/fsm/models.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md index 1828d36ec3c5..91e853e02148 100644 --- a/label_studio/fsm/README.md +++ b/label_studio/fsm/README.md @@ -128,7 +128,7 @@ states = StateManager.bulk_get_current_states(orders) - **Natural Time Ordering**: UUID7 provides chronological ordering without separate timestamp indexes - **High Concurrency**: INSERT-only approach eliminates locking contention -- **Scalability**: Supports billions of state records with consistent performance +- **Scalability**: Supports large amounts of state records with consistent performance ### Declarative Transitions diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index 5c520751e0f8..33b11c3c294d 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -27,7 +27,7 @@ class BaseState(models.Model): - INSERT-only operations for maximum concurrency - Natural time ordering eliminates need for created_at indexes - Global uniqueness enables distributed system support - - Time-based partitioning for billion-record scalability + - Time-based partitioning for large amounts of state records with consistent performance - Complete audit trail by design """ From 6f73a740c2ad50e479dc53227366c7d566ccd86a Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 16:39:45 -0500 Subject: [PATCH 20/83] fixing registry --- label_studio/fsm/api.py | 7 +- label_studio/fsm/integration.py | 188 ------------- label_studio/fsm/registry.py | 404 +++++++++++++++++++++++++++ label_studio/fsm/state_manager.py | 6 +- label_studio/fsm/transition_utils.py | 7 +- label_studio/fsm/transitions.py | 158 +---------- 6 files changed, 414 insertions(+), 356 deletions(-) delete mode 100644 label_studio/fsm/integration.py create mode 100644 label_studio/fsm/registry.py diff --git a/label_studio/fsm/api.py b/label_studio/fsm/api.py index e8d271b7c74a..d8ed952159df 100644 --- a/label_studio/fsm/api.py +++ b/label_studio/fsm/api.py @@ -9,15 +9,14 @@ from django.http import Http404 from django.shortcuts import get_object_or_404 +from fsm.registry import get_state_model_for_entity +from fsm.serializers import StateHistorySerializer, StateTransitionSerializer +from fsm.state_manager import get_state_manager from rest_framework import status, viewsets from rest_framework.decorators import action from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response -from .models import get_state_model_for_entity -from .serializers import StateHistorySerializer, StateTransitionSerializer -from .state_manager import get_state_manager - logger = logging.getLogger(__name__) diff --git a/label_studio/fsm/integration.py b/label_studio/fsm/integration.py deleted file mode 100644 index 560e1eaf76eb..000000000000 --- a/label_studio/fsm/integration.py +++ /dev/null @@ -1,188 +0,0 @@ -""" -Integration helpers for connecting FSM with existing Label Studio models. - -This module provides helper methods and mixins that can be added to existing -Label Studio models to integrate them with the FSM system. -""" - -import logging -from typing import Optional - -from django.db import models - -from .state_manager import get_state_manager - -logger = logging.getLogger(__name__) - - -class FSMIntegrationMixin: - """ - Mixin to add FSM functionality to existing Label Studio models. - - This mixin can be added to Task, Annotation, and Project models to provide - convenient methods for state management without modifying the core models. - """ - - @property - def current_fsm_state(self) -> Optional[str]: - """Get current FSM state for this entity""" - StateManager = get_state_manager() - return StateManager.get_current_state(self) - - def transition_fsm_state( - self, new_state: str, user=None, transition_name: str = None, reason: str = '', context: dict = None - ) -> bool: - """ - Transition this entity to a new FSM state. - - Args: - new_state: Target state - user: User triggering the transition - transition_name: Name of transition method - reason: Human-readable reason - context: Additional context data - - Returns: - True if transition succeeded - """ - StateManager = get_state_manager() - return StateManager.transition_state( - entity=self, - new_state=new_state, - user=user, - transition_name=transition_name, - reason=reason, - context=context or {}, - ) - - def get_fsm_state_history(self, limit: int = 100): - """Get FSM state history for this entity""" - StateManager = get_state_manager() - return StateManager.get_state_history(self, limit) - - def is_in_fsm_state(self, state: str) -> bool: - """Check if entity is currently in the specified state""" - return self.current_fsm_state == state - - def has_fsm_state_history(self) -> bool: - """Check if entity has any FSM state records""" - return self.current_fsm_state is not None - - -def add_fsm_to_model(model_class): - """ - Class decorator to add FSM functionality to existing models. - - This provides an alternative to inheritance for adding FSM capabilities. - - Example: - from fsm.integration import add_fsm_to_model - from tasks.models import Task - - @add_fsm_to_model - class Task(Task): - class Meta: - proxy = True - """ - - def current_fsm_state_property(self): - """Get current FSM state for this entity""" - StateManager = get_state_manager() - return StateManager.get_current_state(self) - - def transition_fsm_state_method( - self, new_state: str, user=None, transition_name: str = None, reason: str = '', context: dict = None - ): - """Transition this entity to a new FSM state""" - StateManager = get_state_manager() - return StateManager.transition_state( - entity=self, - new_state=new_state, - user=user, - transition_name=transition_name, - reason=reason, - context=context or {}, - ) - - def get_fsm_state_history_method(self, limit: int = 100): - """Get FSM state history for this entity""" - StateManager = get_state_manager() - return StateManager.get_state_history(self, limit) - - # Add methods as properties/methods to the class - model_class.current_fsm_state = property(current_fsm_state_property) - model_class.transition_fsm_state = transition_fsm_state_method - model_class.get_fsm_state_history = get_fsm_state_history_method - - return model_class - - -# Utility functions for model extensions - - -def get_entities_by_state(model_class, state: str, limit: int = 100): - """ - Get entities that are currently in a specific state. - - Args: - model_class: Django model class (e.g., Task, Annotation, Project) - state: State to filter by - limit: Maximum number of entities to return - - Returns: - QuerySet of entities in the specified state - - Example: - from tasks.models import Task - from fsm.integration import get_entities_by_state - - completed_tasks = get_entities_by_state(Task, 'COMPLETED', limit=50) - """ - from .models import get_state_model_for_entity - - # Create a dummy instance to get the state model - dummy_instance = model_class() - state_model = get_state_model_for_entity(dummy_instance) - - if not state_model: - return model_class.objects.none() - - # Get entity IDs that have the specified current state - f'{model_class._meta.model_name.lower()}_id' - - current_state_subquery = ( - state_model.objects.filter(**{f'{model_class._meta.model_name.lower()}__pk': models.OuterRef('pk')}) - .order_by('-id') - .values('state')[:1] - ) - - return model_class.objects.annotate(current_state=models.Subquery(current_state_subquery)).filter( - current_state=state - )[:limit] - - -def bulk_transition_entities(entities, new_state: str, user=None, **kwargs): - """ - Bulk transition multiple entities to the same state. - - Args: - entities: List of entity instances - new_state: Target state for all entities - user: User triggering the transitions - **kwargs: Additional arguments for transition_state - - Returns: - List of (entity, success) tuples - """ - StateManager = get_state_manager() - results = [] - - for entity in entities: - try: - success = StateManager.transition_state(entity=entity, new_state=new_state, user=user, **kwargs) - results.append((entity, success)) - except Exception as e: - logger.error(f'Failed to transition {entity._meta.model_name} {entity.pk}: {e}') - results.append((entity, False)) - - return results diff --git a/label_studio/fsm/registry.py b/label_studio/fsm/registry.py new file mode 100644 index 000000000000..899e0dddf3b0 --- /dev/null +++ b/label_studio/fsm/registry.py @@ -0,0 +1,404 @@ +""" +FSM Model Registry for Model State Management. + +This module provides a registry system for state models and state choices, +allowing the FSM to be decoupled from concrete implementations. +""" + +import logging +from typing import Any, Callable, Dict, Optional, Type + +from django.db.models import Model, TextChoices +from fsm.models import BaseState + +logger = logging.getLogger(__name__) + + +class StateChoicesRegistry: + """ + Registry for managing state choices for different entity types. + + Provides a centralized way to register, discover, and manage state choices + for different entity types in the FSM system. + """ + + def __init__(self): + self._choices: Dict[str, Type[TextChoices]] = {} + + def register(self, entity_name: str, choices_class: Type[TextChoices]): + """ + Register state choices for an entity type. + + Args: + entity_name: Name of the entity (e.g., 'task', 'annotation') + choices_class: Django TextChoices class defining valid states + """ + self._choices[entity_name.lower()] = choices_class + + def get_choices(self, entity_name: str) -> Optional[Type[TextChoices]]: + """ + Get state choices for an entity type. + + Args: + entity_name: Name of the entity + + Returns: + Django TextChoices class or None if not found + """ + return self._choices.get(entity_name.lower()) + + def list_entities(self) -> list[str]: + """Get a list of all registered entity types.""" + return list(self._choices.keys()) + + def clear(self): + """ + Clear all registered choices. + + Useful for testing to ensure clean state between tests. + """ + self._choices.clear() + + +# Global state choices registry instance +state_choices_registry = StateChoicesRegistry() + + +def get_state_choices(entity_name: str): + """ + Get state choices for an entity type. + + Args: + entity_name: Name of the entity + + Returns: + Django TextChoices class or None if not found + """ + return state_choices_registry.get_choices(entity_name) + + +def register_state_choices(entity_name: str): + """ + Decorator to register state choices for an entity type. + + Args: + entity_name: Name of the entity type + + Example: + @register_state_choices('task') + class TaskStateChoices(models.TextChoices): + CREATED = 'CREATED', _('Created') + IN_PROGRESS = 'IN_PROGRESS', _('In Progress') + COMPLETED = 'COMPLETED', _('Completed') + """ + + def decorator(choices_class: Type[TextChoices]) -> Type[TextChoices]: + state_choices_registry.register(entity_name, choices_class) + return choices_class + + return decorator + + +class StateModelRegistry: + """ + Registry for state models and their configurations. + + This allows projects to register their state models dynamically + without hardcoding them in the FSM framework. + """ + + def __init__(self): + self._models: Dict[str, Type[BaseState]] = {} + self._denormalizers: Dict[str, Callable[[Model], Dict[str, Any]]] = {} + self._initialized = False + + def register_model( + self, + entity_name: str, + state_model: Type[BaseState], + denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None, + ): + """ + Register a state model for an entity type. + + Args: + entity_name: Name of the entity (e.g., 'task', 'annotation') + state_model: The state model class for this entity + denormalizer: Optional function to extract denormalized fields + """ + entity_key = entity_name.lower() + + if entity_key in self._models: + logger.warning( + f'Overwriting existing state model for {entity_key}. ' + f'Previous: {self._models[entity_key]}, New: {state_model}' + ) + + self._models[entity_key] = state_model + + if denormalizer: + self._denormalizers[entity_key] = denormalizer + + logger.debug(f'Registered state model for {entity_key}: {state_model.__name__}') + + def get_model(self, entity_name: str) -> Optional[Type[BaseState]]: + """ + Get the state model for an entity type. + + Args: + entity_name: Name of the entity + + Returns: + State model class or None if not registered + """ + return self._models.get(entity_name.lower()) + + def get_denormalizer(self, entity_name: str) -> Optional[Callable]: + """ + Get the denormalization function for an entity type. + + Args: + entity_name: Name of the entity + + Returns: + Denormalizer function or None if not registered + """ + return self._denormalizers.get(entity_name.lower()) + + def get_denormalized_fields(self, entity: Model) -> Dict[str, Any]: + """ + Get denormalized fields for an entity. + + Args: + entity: The entity instance + + Returns: + Dictionary of denormalized fields + """ + entity_name = entity._meta.model_name.lower() + denormalizer = self._denormalizers.get(entity_name) + + if denormalizer: + try: + return denormalizer(entity) + except Exception as e: + logger.error(f'Error getting denormalized fields for {entity_name}: {e}') + + return {} + + def is_registered(self, entity_name: str) -> bool: + """Check if a model is registered for an entity type.""" + return entity_name.lower() in self._models + + def clear(self): + """Clear all registered models (useful for testing).""" + self._models.clear() + self._denormalizers.clear() + self._initialized = False + logger.debug('Cleared state model registry') + + def get_all_models(self) -> Dict[str, Type[BaseState]]: + """Get all registered models.""" + return self._models.copy() + + def mark_initialized(self): + """Mark the registry as initialized.""" + self._initialized = True + logger.info(f'State model registry initialized with {len(self._models)} models') + + def is_initialized(self) -> bool: + """Check if the registry has been initialized.""" + return self._initialized + + +# Global registry instance +state_model_registry = StateModelRegistry() + + +def register_state_model( + entity_name: str, state_model: Type[BaseState], denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None +): + """ + Convenience function to register a state model. + + Args: + entity_name: Name of the entity (e.g., 'task', 'annotation') + state_model: The state model class for this entity + denormalizer: Optional function to extract denormalized fields + """ + state_model_registry.register_model(entity_name, state_model, denormalizer) + + +def get_state_model(entity_name: str) -> Optional[Type[BaseState]]: + """ + Convenience function to get a state model. + + Args: + entity_name: Name of the entity + + Returns: + State model class or None if not registered + """ + return state_model_registry.get_model(entity_name) + + +def get_state_model_for_entity(entity: Model) -> Optional[Type[BaseState]]: + """Get the state model for an entity.""" + entity_name = entity._meta.model_name.lower() + return get_state_model(entity_name) + + +class TransitionRegistry: + """ + Registry for managing declarative transitions. + + Provides a centralized way to register, discover, and execute transitions + for different entity types and state models. + """ + + def __init__(self): + self._transitions: Dict[str, Dict[str, Type[BaseTransition]]] = {} + + def register(self, entity_name: str, transition_name: str, transition_class: Type[BaseTransition]): + """ + Register a transition class for an entity. + + Args: + entity_name: Name of the entity type (e.g., 'task', 'annotation') + transition_name: Name of the transition (e.g., 'start_task', 'submit_annotation') + transition_class: The transition class to register + """ + if entity_name not in self._transitions: + self._transitions[entity_name] = {} + + self._transitions[entity_name][transition_name] = transition_class + + def get_transition(self, entity_name: str, transition_name: str) -> Optional[Type[BaseTransition]]: + """ + Get a registered transition class. + + Args: + entity_name: Name of the entity type + transition_name: Name of the transition + + Returns: + The transition class if found, None otherwise + """ + return self._transitions.get(entity_name, {}).get(transition_name) + + def get_transitions_for_entity(self, entity_name: str) -> Dict[str, Type[BaseTransition]]: + """ + Get all registered transitions for an entity type. + + Args: + entity_name: Name of the entity type + + Returns: + Dictionary mapping transition names to transition classes + """ + return self._transitions.get(entity_name, {}).copy() + + def list_entities(self) -> list[str]: + """Get a list of all registered entity types.""" + return list(self._transitions.keys()) + + def clear(self): + """ + Clear all registered transitions. + + Useful for testing to ensure clean state between tests. + """ + self._transitions.clear() + + def execute_transition( + self, + entity_name: str, + transition_name: str, + entity: Model, + transition_data: Dict[str, Any], + user: Optional[User] = None, + **context_kwargs, + ) -> StateModelType: + """ + Execute a registered transition. + + Args: + entity_name: Name of the entity type + transition_name: Name of the transition + entity: The entity instance to transition + transition_data: Data for the transition (will be validated by Pydantic) + user: User executing the transition + **context_kwargs: Additional context data + + Returns: + The newly created state record + + Raises: + ValueError: If transition is not found + TransitionValidationError: If transition validation fails + """ + transition_class = self.get_transition(entity_name, transition_name) + if not transition_class: + raise ValueError(f"Transition '{transition_name}' not found for entity '{entity_name}'") + + # Create transition instance with provided data + transition = transition_class(**transition_data) + + # Get current state information + from fsm.state_manager import StateManager + + current_state_object = StateManager.get_current_state_object(entity) + current_state = current_state_object.state if current_state_object else None + + # Build transition context + context = TransitionContext( + entity=entity, + current_user=user, + current_state_object=current_state_object, + current_state=current_state, + target_state=transition.target_state, + organization_id=getattr(entity, 'organization_id', None), + **context_kwargs, + ) + + # Execute the transition + return transition.execute(context) + + +# Global transition registry instance +transition_registry = TransitionRegistry() + + +def register_transition(entity_name: str, transition_name: str = None): + """ + Decorator to register a transition class. + + Args: + entity_name: Name of the entity type + transition_name: Name of the transition (defaults to class name in snake_case) + + Example: + @register_transition('task', 'start_task') + class StartTaskTransition(BaseTransition[Task, TaskState]): + # ... implementation + """ + + def decorator(transition_class: Type[BaseTransition]) -> Type[BaseTransition]: + name = transition_name + if name is None: + # Generate name from class name + class_name = transition_class.__name__ + if class_name.endswith('Transition'): + class_name = class_name[:-10] # Remove 'Transition' suffix + + # Convert CamelCase to snake_case + name = '' + for i, char in enumerate(class_name): + if char.isupper() and i > 0: + name += '_' + name += char.lower() + + transition_registry.register(entity_name, name, transition_class) + return transition_class + + return decorator diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index 1670bc87a92a..c7cd5dd61c3f 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -13,12 +13,12 @@ from django.core.cache import cache from django.db import transaction from django.db.models import Model - -from .models import BaseState, get_state_model_for_entity +from fsm.models import BaseState +from fsm.registry import get_state_model_for_entity # Avoid circular import if TYPE_CHECKING: - from .transitions import BaseTransition + from fsm.transitions import BaseTransition logger = logging.getLogger(__name__) diff --git a/label_studio/fsm/transition_utils.py b/label_studio/fsm/transition_utils.py index d3fcbada5b13..4f1b6245cbb5 100644 --- a/label_studio/fsm/transition_utils.py +++ b/label_studio/fsm/transition_utils.py @@ -8,10 +8,9 @@ from typing import Any, Dict, List, Optional, Type from django.db.models import Model - -from .models import BaseState -from .state_manager import StateManager -from .transitions import BaseTransition, TransitionValidationError, transition_registry +from fsm.models import BaseState +from fsm.state_manager import StateManager +from fsm.transitions import BaseTransition, TransitionValidationError, transition_registry def execute_transition( diff --git a/label_studio/fsm/transitions.py b/label_studio/fsm/transitions.py index 1b6cdf41dd8f..12540f08074a 100644 --- a/label_studio/fsm/transitions.py +++ b/label_studio/fsm/transitions.py @@ -8,7 +8,7 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, Generic, Optional, Type, TypeVar +from typing import Any, Dict, Generic, Optional, TypeVar from django.contrib.auth import get_user_model from django.db.models import Model @@ -323,159 +323,3 @@ def execute(self, context: TransitionContext[EntityType, StateModelType]) -> Sta # Clear context on error self.context = None raise - - -class TransitionRegistry: - """ - Registry for managing declarative transitions. - - Provides a centralized way to register, discover, and execute transitions - for different entity types and state models. - """ - - def __init__(self): - self._transitions: Dict[str, Dict[str, Type[BaseTransition]]] = {} - - def register(self, entity_name: str, transition_name: str, transition_class: Type[BaseTransition]): - """ - Register a transition class for an entity. - - Args: - entity_name: Name of the entity type (e.g., 'task', 'annotation') - transition_name: Name of the transition (e.g., 'start_task', 'submit_annotation') - transition_class: The transition class to register - """ - if entity_name not in self._transitions: - self._transitions[entity_name] = {} - - self._transitions[entity_name][transition_name] = transition_class - - def get_transition(self, entity_name: str, transition_name: str) -> Optional[Type[BaseTransition]]: - """ - Get a registered transition class. - - Args: - entity_name: Name of the entity type - transition_name: Name of the transition - - Returns: - The transition class if found, None otherwise - """ - return self._transitions.get(entity_name, {}).get(transition_name) - - def get_transitions_for_entity(self, entity_name: str) -> Dict[str, Type[BaseTransition]]: - """ - Get all registered transitions for an entity type. - - Args: - entity_name: Name of the entity type - - Returns: - Dictionary mapping transition names to transition classes - """ - return self._transitions.get(entity_name, {}).copy() - - def list_entities(self) -> list[str]: - """Get a list of all registered entity types.""" - return list(self._transitions.keys()) - - def clear(self): - """ - Clear all registered transitions. - - Useful for testing to ensure clean state between tests. - """ - self._transitions.clear() - - def execute_transition( - self, - entity_name: str, - transition_name: str, - entity: Model, - transition_data: Dict[str, Any], - user: Optional[User] = None, - **context_kwargs, - ) -> StateModelType: - """ - Execute a registered transition. - - Args: - entity_name: Name of the entity type - transition_name: Name of the transition - entity: The entity instance to transition - transition_data: Data for the transition (will be validated by Pydantic) - user: User executing the transition - **context_kwargs: Additional context data - - Returns: - The newly created state record - - Raises: - ValueError: If transition is not found - TransitionValidationError: If transition validation fails - """ - transition_class = self.get_transition(entity_name, transition_name) - if not transition_class: - raise ValueError(f"Transition '{transition_name}' not found for entity '{entity_name}'") - - # Create transition instance with provided data - transition = transition_class(**transition_data) - - # Get current state information - from .state_manager import StateManager - - current_state_object = StateManager.get_current_state_object(entity) - current_state = current_state_object.state if current_state_object else None - - # Build transition context - context = TransitionContext( - entity=entity, - current_user=user, - current_state_object=current_state_object, - current_state=current_state, - target_state=transition.target_state, - organization_id=getattr(entity, 'organization_id', None), - **context_kwargs, - ) - - # Execute the transition - return transition.execute(context) - - -# Global transition registry instance -transition_registry = TransitionRegistry() - - -def register_transition(entity_name: str, transition_name: str = None): - """ - Decorator to register a transition class. - - Args: - entity_name: Name of the entity type - transition_name: Name of the transition (defaults to class name in snake_case) - - Example: - @register_transition('task', 'start_task') - class StartTaskTransition(BaseTransition[Task, TaskState]): - # ... implementation - """ - - def decorator(transition_class: Type[BaseTransition]) -> Type[BaseTransition]: - name = transition_name - if name is None: - # Generate name from class name - class_name = transition_class.__name__ - if class_name.endswith('Transition'): - class_name = class_name[:-10] # Remove 'Transition' suffix - - # Convert CamelCase to snake_case - name = '' - for i, char in enumerate(class_name): - if char.isupper() and i > 0: - name += '_' - name += char.lower() - - transition_registry.register(entity_name, name, transition_class) - return transition_class - - return decorator From 3869f79ebced6613f3c1fa8e8e30571fda31b60a Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 16:44:20 -0500 Subject: [PATCH 21/83] fixing registry --- label_studio/fsm/registry.py | 1 + label_studio/fsm/transitions.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/label_studio/fsm/registry.py b/label_studio/fsm/registry.py index 899e0dddf3b0..ce792ab6a482 100644 --- a/label_studio/fsm/registry.py +++ b/label_studio/fsm/registry.py @@ -10,6 +10,7 @@ from django.db.models import Model, TextChoices from fsm.models import BaseState +from fsm.transitions import BaseTransition, StateModelType, TransitionContext, User logger = logging.getLogger(__name__) diff --git a/label_studio/fsm/transitions.py b/label_studio/fsm/transitions.py index 12540f08074a..cbc36936f940 100644 --- a/label_studio/fsm/transitions.py +++ b/label_studio/fsm/transitions.py @@ -12,10 +12,9 @@ from django.contrib.auth import get_user_model from django.db.models import Model +from fsm.models import BaseState from pydantic import BaseModel, ConfigDict, Field -from .models import BaseState - User = get_user_model() # Type variables for generic transition context From 0c37517e700831fb865e69849e6b36e060b701d1 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 16:49:54 -0500 Subject: [PATCH 22/83] fixing tests --- label_studio/fsm/tests/test_api_usage_examples.py | 3 +-- label_studio/fsm/tests/test_declarative_transitions.py | 3 +-- label_studio/fsm/tests/test_edge_cases_error_handling.py | 3 ++- label_studio/fsm/tests/test_performance_concurrency.py | 3 ++- label_studio/fsm/transition_utils.py | 3 ++- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/label_studio/fsm/tests/test_api_usage_examples.py b/label_studio/fsm/tests/test_api_usage_examples.py index fcd8bbc26013..e70fc0b11ff4 100644 --- a/label_studio/fsm/tests/test_api_usage_examples.py +++ b/label_studio/fsm/tests/test_api_usage_examples.py @@ -12,6 +12,7 @@ from unittest.mock import Mock from django.test import TestCase +from fsm.registry import register_transition, transition_registry from fsm.transition_utils import ( get_transition_schema, ) @@ -19,8 +20,6 @@ BaseTransition, TransitionContext, TransitionValidationError, - register_transition, - transition_registry, ) from pydantic import Field, validator diff --git a/label_studio/fsm/tests/test_declarative_transitions.py b/label_studio/fsm/tests/test_declarative_transitions.py index 539d67a3f602..29c17f322176 100644 --- a/label_studio/fsm/tests/test_declarative_transitions.py +++ b/label_studio/fsm/tests/test_declarative_transitions.py @@ -14,6 +14,7 @@ from django.db import models from django.test import TestCase from django.utils.translation import gettext_lazy as _ +from fsm.registry import register_transition, transition_registry from fsm.transition_utils import ( TransitionBuilder, get_available_transitions, @@ -22,8 +23,6 @@ BaseTransition, TransitionContext, TransitionValidationError, - register_transition, - transition_registry, ) from pydantic import Field, ValidationError diff --git a/label_studio/fsm/tests/test_edge_cases_error_handling.py b/label_studio/fsm/tests/test_edge_cases_error_handling.py index 17b900336511..22c7c7799fa6 100644 --- a/label_studio/fsm/tests/test_edge_cases_error_handling.py +++ b/label_studio/fsm/tests/test_edge_cases_error_handling.py @@ -14,8 +14,9 @@ from unittest.mock import Mock from django.test import TestCase +from fsm.registry import transition_registry from fsm.transition_utils import TransitionBuilder -from fsm.transitions import BaseTransition, TransitionContext, TransitionValidationError, transition_registry +from fsm.transitions import BaseTransition, TransitionContext, TransitionValidationError from pydantic import Field, ValidationError diff --git a/label_studio/fsm/tests/test_performance_concurrency.py b/label_studio/fsm/tests/test_performance_concurrency.py index 896b4fe65556..25a677e12986 100644 --- a/label_studio/fsm/tests/test_performance_concurrency.py +++ b/label_studio/fsm/tests/test_performance_concurrency.py @@ -14,7 +14,8 @@ from unittest.mock import Mock from django.test import TestCase, TransactionTestCase -from fsm.transitions import BaseTransition, TransitionContext, TransitionValidationError, transition_registry +from fsm.registry import transition_registry +from fsm.transitions import BaseTransition, TransitionContext, TransitionValidationError from pydantic import Field diff --git a/label_studio/fsm/transition_utils.py b/label_studio/fsm/transition_utils.py index 4f1b6245cbb5..15319b800f75 100644 --- a/label_studio/fsm/transition_utils.py +++ b/label_studio/fsm/transition_utils.py @@ -9,8 +9,9 @@ from django.db.models import Model from fsm.models import BaseState +from fsm.registry import transition_registry from fsm.state_manager import StateManager -from fsm.transitions import BaseTransition, TransitionValidationError, transition_registry +from fsm.transitions import BaseTransition, TransitionValidationError def execute_transition( From c23c13fa3bc53663e466eedc8bdf0c9edbf9913e Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 16:55:35 -0500 Subject: [PATCH 23/83] updating README --- label_studio/fsm/README.md | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md index 91e853e02148..3d6dc6ce56a4 100644 --- a/label_studio/fsm/README.md +++ b/label_studio/fsm/README.md @@ -30,7 +30,9 @@ The FSM framework provides: ```python from django.db import models from django.utils.translation import gettext_lazy as _ +from fsm.registry import register_state_choices +@register_state_choices('order') class OrderStateChoices(models.TextChoices): CREATED = 'CREATED', _('Created') PROCESSING = 'PROCESSING', _('Processing') @@ -43,11 +45,9 @@ class OrderStateChoices(models.TextChoices): ```python from fsm.models import BaseState -from fsm.state_choices import register_state_choices - -# Register state choices -register_state_choices('order', OrderStateChoices) +from fsm.registry import register_state_model +@register_state_model('order') class OrderState(BaseState): # Entity relationship order = models.ForeignKey('shop.Order', related_name='fsm_states', on_delete=models.CASCADE) @@ -67,7 +67,8 @@ class OrderState(BaseState): ### 3. Define Transitions ```python -from fsm.transitions import BaseTransition, register_transition +from fsm.transitions import BaseTransition +from fsm.registry import register_transition from pydantic import Field @register_transition('order', 'process_order') @@ -225,4 +226,4 @@ When contributing: - Keep framework code generic and reusable - Add product-specific code to appropriate implementation branches - Include performance tests for UUID7 optimizations -- Document extension points and customization options \ No newline at end of file +- Document extension points and customization options From 1fee6de063e1dea15a06720e6d6e227ed62c9c3c Mon Sep 17 00:00:00 2001 From: bmartel Date: Wed, 27 Aug 2025 16:57:57 -0500 Subject: [PATCH 24/83] Apply suggestion from @bmartel --- label_studio/fsm/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md index 3d6dc6ce56a4..21aaa9ead79b 100644 --- a/label_studio/fsm/README.md +++ b/label_studio/fsm/README.md @@ -162,7 +162,7 @@ GET /api/fsm/{entity_type}/{entity_id}/history/ # State history POST /api/fsm/{entity_type}/{entity_id}/transition/ # Execute transition ``` -Extend the base viewset for your application: +Extend the base viewset ```python from fsm.api import FSMViewSet From dec91f9361e64867bc0bb708da206dbdcdb05875 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 17:04:18 -0500 Subject: [PATCH 25/83] avoiding import cycle due to types --- label_studio/fsm/models.py | 3 +-- label_studio/fsm/registry.py | 5 ++++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index 33b11c3c294d..d0da3c433166 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -8,8 +8,7 @@ from django.conf import settings from django.db import models from django.db.models import UUIDField - -from .utils import UUID7Field, generate_uuid7, timestamp_from_uuid7 +from fsm.utils import UUID7Field, generate_uuid7, timestamp_from_uuid7 class BaseState(models.Model): diff --git a/label_studio/fsm/registry.py b/label_studio/fsm/registry.py index ce792ab6a482..fdf6e0756503 100644 --- a/label_studio/fsm/registry.py +++ b/label_studio/fsm/registry.py @@ -6,12 +6,15 @@ """ import logging +import typing from typing import Any, Callable, Dict, Optional, Type from django.db.models import Model, TextChoices -from fsm.models import BaseState from fsm.transitions import BaseTransition, StateModelType, TransitionContext, User +if typing.TYPE_CHECKING: + from fsm.models import BaseState + logger = logging.getLogger(__name__) From 129e6b1499692a4080980407e818eb9575b3d875 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 17:05:51 -0500 Subject: [PATCH 26/83] feat: FIT-587: FSM core model setup --- label_studio/fsm/models.py | 157 ++- label_studio/fsm/state_choices.py | 70 +- .../fsm/tests/test_declarative_transitions.py | 994 ++++++++++++++---- .../fsm/tests/test_fsm_integration.py | 302 ++++++ .../tests/test_integration_django_models.py | 665 ++++++++++++ 5 files changed, 1931 insertions(+), 257 deletions(-) create mode 100644 label_studio/fsm/tests/test_fsm_integration.py create mode 100644 label_studio/fsm/tests/test_integration_django_models.py diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index d0da3c433166..061dfb083293 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -8,6 +8,8 @@ from django.conf import settings from django.db import models from django.db.models import UUIDField +from fsm.registry import register_state_model +from fsm.state_choices import AnnotationStateChoices, ProjectStateChoices, TaskStateChoices from fsm.utils import UUID7Field, generate_uuid7, timestamp_from_uuid7 @@ -168,43 +170,146 @@ def _get_entity_field_name(cls) -> str: return 'entity' -# Registry for dynamic state model extension -STATE_MODEL_REGISTRY = {} +# Core state models for basic Label Studio entities -def register_state_model(entity_name: str, model_class): +@register_state_model('task') +class TaskState(BaseState): """ - Register state model for an entity type. - - Args: - entity_name: Name of the entity (e.g., 'review', 'assignment') - model_class: Django model class inheriting from BaseState + Core task state tracking for Label Studio. + Provides basic task state management with: + - Simple 3-state workflow (CREATED → IN_PROGRESS → COMPLETED) + - High-performance queries with UUID7 ordering """ - STATE_MODEL_REGISTRY[entity_name.lower()] = model_class + # Entity Relationship + task = models.ForeignKey('tasks.Task', related_name='fsm_states', on_delete=models.CASCADE, db_index=True) -def get_state_model(entity_name: str): - """ - Get state model for an entity type. + # Override state field to add choices constraint + state = models.CharField(max_length=50, choices=TaskStateChoices.choices, db_index=True) - Args: - entity_name: Name of the entity + project_id = models.PositiveIntegerField( + db_index=True, help_text='From task.project_id - denormalized for performance' + ) - Returns: - Django model class inheriting from BaseState, or None if not found - """ - return STATE_MODEL_REGISTRY.get(entity_name.lower()) + class Meta: + app_label = 'fsm' + indexes = [ + # Critical: Latest state lookup (current state determined by latest UUID7 id) + # Index with DESC order explicitly supports ORDER BY id DESC queries + models.Index(fields=['task_id', '-id'], name='task_current_state_idx'), + # Reporting and filtering + models.Index(fields=['project_id', 'state', '-id'], name='task_project_state_idx'), + models.Index(fields=['organization_id', 'state', '-id'], name='task_org_reporting_idx'), + # History queries + models.Index(fields=['task_id', 'id'], name='task_history_idx'), + ] + # No constraints needed - INSERT-only approach + ordering = ['-id'] + + @classmethod + def get_denormalized_fields(cls, entity): + """Get denormalized fields for TaskState creation""" + return { + 'project_id': entity.project_id, + } + + @property + def is_terminal_state(self) -> bool: + """Check if this is a terminal task state""" + return self.state == TaskStateChoices.COMPLETED -def get_state_model_for_entity(entity): +@register_state_model('annotation') +class AnnotationState(BaseState): + """ + Core annotation state tracking for Label Studio. + Provides basic annotation state management with: + - Simple 3-state workflow (DRAFT → SUBMITTED → COMPLETED) """ - Get state model for a specific entity instance. - Args: - entity: Django model instance + # Entity Relationship + annotation = models.ForeignKey('tasks.Annotation', on_delete=models.CASCADE, related_name='fsm_states') + + # Override state field to add choices constraint + state = models.CharField(max_length=50, choices=AnnotationStateChoices.choices, db_index=True) + + # Denormalized fields for performance (avoid JOINs in common queries) + task_id = models.PositiveIntegerField( + db_index=True, help_text='From annotation.task_id - denormalized for performance' + ) + project_id = models.PositiveIntegerField( + db_index=True, help_text='From annotation.task.project_id - denormalized for performance' + ) + completed_by_id = models.PositiveIntegerField( + null=True, db_index=True, help_text='From annotation.completed_by_id - denormalized for performance' + ) + + class Meta: + app_label = 'fsm' + indexes = [ + # Critical: Latest state lookup + models.Index(fields=['annotation_id', '-id'], name='anno_current_state_idx'), + # Filtering and reporting + models.Index(fields=['task_id', 'state', '-id'], name='anno_task_state_idx'), + models.Index(fields=['completed_by_id', 'state', '-id'], name='anno_user_report_idx'), + models.Index(fields=['project_id', 'state', '-id'], name='anno_project_report_idx'), + ] + ordering = ['-id'] + + @classmethod + def get_denormalized_fields(cls, entity): + """Get denormalized fields for AnnotationState creation""" + return { + 'task_id': entity.task.id, + 'project_id': entity.task.project_id, + 'completed_by_id': entity.completed_by.id if entity.completed_by else None, + } + + @property + def is_terminal_state(self) -> bool: + """Check if this is a terminal annotation state""" + return self.state == AnnotationStateChoices.COMPLETED - Returns: - Django model class inheriting from BaseState, or None if not found + +@register_state_model('project') +class ProjectState(BaseState): + """ + Core project state tracking for Label Studio. + Provides basic project state management with: + - Simple 3-state workflow (CREATED → IN_PROGRESS → COMPLETED) + - Project lifecycle tracking """ - entity_name = entity._meta.model_name.lower() - return get_state_model(entity_name) + + # Entity Relationship + project = models.ForeignKey('projects.Project', on_delete=models.CASCADE, related_name='fsm_states') + + # Override state field to add choices constraint + state = models.CharField(max_length=50, choices=ProjectStateChoices.choices, db_index=True) + + created_by_id = models.PositiveIntegerField( + null=True, db_index=True, help_text='From project.created_by_id - denormalized for performance' + ) + + class Meta: + app_label = 'fsm' + indexes = [ + # Critical: Latest state lookup + models.Index(fields=['project_id', '-id'], name='project_current_state_idx'), + # Filtering and reporting + models.Index(fields=['organization_id', 'state', '-id'], name='project_org_state_idx'), + models.Index(fields=['organization_id', '-id'], name='project_org_reporting_idx'), + ] + ordering = ['-id'] + + @classmethod + def get_denormalized_fields(cls, entity): + """Get denormalized fields for ProjectState creation""" + return { + 'created_by_id': entity.created_by.id if entity.created_by else None, + } + + @property + def is_terminal_state(self) -> bool: + """Check if this is a terminal project state""" + return self.state == ProjectStateChoices.COMPLETED diff --git a/label_studio/fsm/state_choices.py b/label_studio/fsm/state_choices.py index e00cd12c3e06..139152d7a87c 100644 --- a/label_studio/fsm/state_choices.py +++ b/label_studio/fsm/state_choices.py @@ -5,29 +5,67 @@ state choices for different entity types in the FSM framework. """ -# Registry for dynamic state choices extension -STATE_CHOICES_REGISTRY = {} +from django.db import models +from django.utils.translation import gettext_lazy as _ +from fsm.registry import register_state_choices +""" +Core state choice enums for Label Studio entities. +These enums define the essential states for core Label Studio entities. +""" -def register_state_choices(entity_name: str, choices_class): - """ - Register state choices for an entity type. - Args: - entity_name: Name of the entity (e.g., 'order', 'ticket') - choices_class: Django TextChoices class defining valid states +@register_state_choices('task') +class TaskStateChoices(models.TextChoices): + """ + Core task states for basic Label Studio workflow. + Simplified states covering the essential task lifecycle: + - Creation and assignment + - Annotation work + - Completion """ - STATE_CHOICES_REGISTRY[entity_name.lower()] = choices_class + + # Initial State + CREATED = 'CREATED', _('Created') + + # Work States + IN_PROGRESS = 'IN_PROGRESS', _('In Progress') + + # Terminal State + COMPLETED = 'COMPLETED', _('Completed') -def get_state_choices(entity_name: str): +@register_state_choices('annotation') +class AnnotationStateChoices(models.TextChoices): """ - Get state choices for an entity type. + Core annotation states for basic Label Studio workflow. + Simplified states covering the essential annotation lifecycle: + - Submission + - Completion + """ + + # Working States + SUBMITTED = 'SUBMITTED', _('Submitted') - Args: - entity_name: Name of the entity + # Terminal State + COMPLETED = 'COMPLETED', _('Completed') - Returns: - Django TextChoices class or None if not found + +@register_state_choices('project') +class ProjectStateChoices(models.TextChoices): + """ + Core project states for basic Label Studio workflow. + Simplified states covering the essential project lifecycle: + - Setup and configuration + - Active work + - Completion """ - return STATE_CHOICES_REGISTRY.get(entity_name.lower()) + + # Setup States + CREATED = 'CREATED', _('Created') + + # Work States + IN_PROGRESS = 'IN_PROGRESS', _('In Progress') + + # Terminal State + COMPLETED = 'COMPLETED', _('Completed') diff --git a/label_studio/fsm/tests/test_declarative_transitions.py b/label_studio/fsm/tests/test_declarative_transitions.py index 29c17f322176..11ece8507bb8 100644 --- a/label_studio/fsm/tests/test_declarative_transitions.py +++ b/label_studio/fsm/tests/test_declarative_transitions.py @@ -1,23 +1,24 @@ """ -Core framework tests for the declarative Pydantic-based transition system. +Comprehensive tests for the declarative Pydantic-based transition system. -This test suite covers the core transition framework functionality without -product-specific implementations. It tests the abstract base classes, -registration system, validation, and core utilities. +This test suite provides extensive coverage of the new transition system, +including usage examples, edge cases, validation scenarios, and integration +patterns to serve as both tests and documentation. """ -from datetime import datetime +from datetime import datetime, timedelta from typing import Any, Dict -from unittest.mock import Mock +from unittest.mock import Mock, patch +import pytest from django.contrib.auth import get_user_model -from django.db import models from django.test import TestCase -from django.utils.translation import gettext_lazy as _ from fsm.registry import register_transition, transition_registry +from fsm.state_choices import AnnotationStateChoices, TaskStateChoices from fsm.transition_utils import ( TransitionBuilder, get_available_transitions, + get_valid_transitions, ) from fsm.transitions import ( BaseTransition, @@ -29,329 +30,892 @@ User = get_user_model() -class TestStateChoices(models.TextChoices): - """Test state choices for mock entity""" +class MockTask: + """Mock task model for testing""" - CREATED = 'CREATED', _('Created') - IN_PROGRESS = 'IN_PROGRESS', _('In Progress') - COMPLETED = 'COMPLETED', _('Completed') + def __init__(self, pk=1): + self.pk = pk + self.id = pk + self.organization_id = 1 + self._meta = Mock() + self._meta.model_name = 'task' + self._meta.label_lower = 'tasks.task' -class MockEntity: - """Mock entity model for testing""" +class MockAnnotation: + """Mock annotation model for testing""" def __init__(self, pk=1): self.pk = pk self.id = pk + self.result = {'test': 'data'} # Mock annotation data self.organization_id = 1 self._meta = Mock() - self._meta.model_name = 'test_entity' - self._meta.label_lower = 'test.testentity' + self._meta.model_name = 'annotation' + self._meta.label_lower = 'tasks.annotation' -class CoreFrameworkTests(TestCase): - """Test core framework functionality""" +class TestTransition(BaseTransition): + """Test transition class""" - def setUp(self): - """Set up test data""" - self.user = User.objects.create_user(email='test@example.com', password='test123') - self.mock_entity = MockEntity() + test_field: str + optional_field: int = 42 - # Clear registry to avoid test pollution - transition_registry.clear() + @property + def target_state(self) -> str: + return 'TEST_STATE' - def tearDown(self): - """Clean up after tests""" - transition_registry.clear() + @classmethod + def get_target_state(cls) -> str: + """Return the target state at class level""" + return 'TEST_STATE' - def test_base_transition_class(self): - """Test BaseTransition abstract functionality""" + @classmethod + def can_transition_from_state(cls, context: TransitionContext) -> bool: + """Allow transition from any state for testing""" + return True - @register_transition('test_entity', 'test_transition') - class TestTransition(BaseTransition): - test_field: str = Field('default', description='Test field') + def validate_transition(self, context: TransitionContext) -> bool: + if self.test_field == 'invalid': + raise TransitionValidationError('Test validation error') + return super().validate_transition(context) - @property - def target_state(self) -> str: - return TestStateChoices.IN_PROGRESS + def transition(self, context: TransitionContext) -> dict: + return { + 'test_field': self.test_field, + 'optional_field': self.optional_field, + 'context_entity_id': context.entity.pk, + } - def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {'test_field': self.test_field} - # Test instantiation - transition = TestTransition(test_field='test_value') - self.assertEqual(transition.test_field, 'test_value') - self.assertEqual(transition.target_state, TestStateChoices.IN_PROGRESS) - self.assertEqual(transition.transition_name, 'test_transition') +class DeclarativeTransitionTests(TestCase): + """Test cases for the declarative transition system""" + + def setUp(self): + self.task = MockTask() + self.annotation = MockAnnotation() + self.user = Mock() + self.user.id = 1 + self.user.username = 'testuser' - def test_transition_context(self): - """Test TransitionContext functionality""" + # Register test transition + transition_registry.register('task', 'test_transition', TestTransition) + + def test_transition_context_creation(self): + """Test creation of transition context""" context = TransitionContext( - entity=self.mock_entity, - current_state=TestStateChoices.CREATED, - target_state=TestStateChoices.IN_PROGRESS, - timestamp=datetime.now(), + entity=self.task, current_user=self.user, + current_state='CREATED', + target_state='IN_PROGRESS', + organization_id=1, ) - self.assertEqual(context.entity, self.mock_entity) - self.assertEqual(context.current_state, TestStateChoices.CREATED) - self.assertEqual(context.target_state, TestStateChoices.IN_PROGRESS) + self.assertEqual(context.entity, self.task) self.assertEqual(context.current_user, self.user) - self.assertTrue(context.has_current_state) + self.assertEqual(context.current_state, 'CREATED') + self.assertEqual(context.target_state, 'IN_PROGRESS') + self.assertEqual(context.organization_id, 1) self.assertFalse(context.is_initial_transition) + self.assertTrue(context.has_current_state) + + def test_transition_context_initial_state(self): + """Test context for initial state transition""" + context = TransitionContext(entity=self.task, current_state=None, target_state='CREATED') - def test_transition_context_properties(self): - """Test TransitionContext computed properties""" - # Test initial transition - context = TransitionContext(entity=self.mock_entity, current_state=None, target_state=TestStateChoices.CREATED) self.assertTrue(context.is_initial_transition) self.assertFalse(context.has_current_state) - # Test with current state - context_with_state = TransitionContext( - entity=self.mock_entity, - current_state=TestStateChoices.CREATED, - target_state=TestStateChoices.IN_PROGRESS, - ) - self.assertFalse(context_with_state.is_initial_transition) - self.assertTrue(context_with_state.has_current_state) + def test_transition_validation_success(self): + """Test successful transition validation""" + transition = TestTransition(test_field='valid') + context = TransitionContext(entity=self.task, current_state='CREATED', target_state=transition.target_state) - def test_transition_registry(self): - """Test transition registration and retrieval""" + self.assertTrue(transition.validate_transition(context)) - @register_transition('test_entity', 'test_transition') - class TestTransition(BaseTransition): - @property - def target_state(self) -> str: - return TestStateChoices.COMPLETED + def test_transition_validation_failure(self): + """Test transition validation failure""" + transition = TestTransition(test_field='invalid') + context = TransitionContext(entity=self.task, current_state='CREATED', target_state=transition.target_state) - def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {} + with self.assertRaises(TransitionValidationError): + transition.validate_transition(context) + + def test_transition_execution(self): + """Test transition data generation""" + transition = TestTransition(test_field='test_value', optional_field=100) + context = TransitionContext(entity=self.task, current_state='CREATED', target_state=transition.target_state) + + result = transition.transition(context) + + self.assertEqual(result['test_field'], 'test_value') + self.assertEqual(result['optional_field'], 100) + self.assertEqual(result['context_entity_id'], self.task.pk) + + def test_transition_name_generation(self): + """Test automatic transition name generation""" + transition = TestTransition(test_field='test') + self.assertEqual(transition.transition_name, 'test_transition') + + @patch('fsm.state_manager.StateManager.transition_state') + @patch('fsm.state_manager.StateManager.get_current_state_object') + def test_transition_execute_full_workflow(self, mock_get_state, mock_transition): + """Test full transition execution workflow""" + # Setup mocks + mock_get_state.return_value = None # No current state + mock_transition.return_value = True + + mock_state_record = Mock() + mock_state_record.id = 'test-uuid' + + with patch('fsm.state_manager.StateManager.get_current_state_object', return_value=mock_state_record): + transition = TestTransition(test_field='test_value') + context = TransitionContext( + entity=self.task, current_user=self.user, current_state=None, target_state=transition.target_state + ) - # Test registration - retrieved = transition_registry.get_transition('test_entity', 'test_transition') + # Execute transition + transition.execute(context) + + # Verify StateManager was called correctly + mock_transition.assert_called_once() + call_args = mock_transition.call_args + + self.assertEqual(call_args[1]['entity'], self.task) + self.assertEqual(call_args[1]['new_state'], 'TEST_STATE') + self.assertEqual(call_args[1]['transition_name'], 'test_transition') + self.assertEqual(call_args[1]['user'], self.user) + + # Check context data + context_data = call_args[1]['context'] + self.assertEqual(context_data['test_field'], 'test_value') + self.assertEqual(context_data['optional_field'], 42) + + +class TransitionRegistryTests(TestCase): + """Test cases for the transition registry""" + + def setUp(self): + self.registry = transition_registry + + def test_transition_registration(self): + """Test registering transitions""" + self.registry.register('test_entity', 'test_transition', TestTransition) + + retrieved = self.registry.get_transition('test_entity', 'test_transition') self.assertEqual(retrieved, TestTransition) - # Test entity transitions - entity_transitions = transition_registry.get_transitions_for_entity('test_entity') - self.assertIn('test_transition', entity_transitions) - self.assertEqual(entity_transitions['test_transition'], TestTransition) + def test_get_transitions_for_entity(self): + """Test getting all transitions for an entity""" + self.registry.register('test_entity', 'transition1', TestTransition) + self.registry.register('test_entity', 'transition2', TestTransition) - def test_pydantic_validation(self): - """Test Pydantic validation in transitions""" + transitions = self.registry.get_transitions_for_entity('test_entity') - @register_transition('test_entity', 'validated_transition') - class ValidatedTransition(BaseTransition): - required_field: str = Field(..., description='Required field') - optional_field: int = Field(42, description='Optional field') + self.assertIn('transition1', transitions) + self.assertIn('transition2', transitions) + self.assertEqual(len(transitions), 2) - @property - def target_state(self) -> str: - return TestStateChoices.COMPLETED + def test_list_entities(self): + """Test listing registered entities""" + self.registry.register('entity1', 'transition1', TestTransition) + self.registry.register('entity2', 'transition2', TestTransition) - def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {'required_field': self.required_field, 'optional_field': self.optional_field} + entities = self.registry.list_entities() - # Test valid instantiation - transition = ValidatedTransition(required_field='test') - self.assertEqual(transition.required_field, 'test') - self.assertEqual(transition.optional_field, 42) + self.assertIn('entity1', entities) + self.assertIn('entity2', entities) - # Test validation error - with self.assertRaises(ValidationError): - ValidatedTransition() # Missing required field - def test_transition_execution(self): - """Test transition execution logic""" +class TransitionUtilsTests(TestCase): + """Test cases for transition utility functions""" - @register_transition('test_entity', 'execution_test') - class ExecutionTestTransition(BaseTransition): - value: str = Field('test', description='Test value') + def setUp(self): + self.task = MockTask() + transition_registry.register('task', 'test_transition', TestTransition) + + def test_get_available_transitions(self): + """Test getting available transitions for entity""" + transitions = get_available_transitions(self.task) + self.assertIn('test_transition', transitions) + + @patch('fsm.state_manager.StateManager.get_current_state_object') + def test_get_valid_transitions(self, mock_get_state): + """Test filtering valid transitions""" + mock_get_state.return_value = None + + valid_transitions = get_valid_transitions(self.task, validate=True) + self.assertIn('test_transition', valid_transitions) + + @patch('fsm.state_manager.StateManager.get_current_state_object') + def test_get_valid_transitions_with_invalid(self, mock_get_state): + """Test filtering out invalid transitions""" + mock_get_state.return_value = None + + # Register an invalid transition + class InvalidTransition(TestTransition): + @classmethod + def can_transition_from_state(cls, context): + # This transition is never valid at the class level + return False + + def validate_transition(self, context): + raise TransitionValidationError('Always invalid') + + transition_registry.register('task', 'invalid_transition', InvalidTransition) + + valid_transitions = get_valid_transitions(self.task, validate=True) + self.assertIn('test_transition', valid_transitions) + self.assertNotIn('invalid_transition', valid_transitions) + + @patch('fsm.transition_utils.execute_transition') + def test_transition_builder(self, mock_execute): + """Test fluent transition builder interface""" + mock_execute.return_value = Mock() + + ( + TransitionBuilder(self.task) + .transition('test_transition') + .with_data(test_field='builder_test') + .by_user(Mock()) + .with_context(extra='context') + .execute() + ) + + mock_execute.assert_called_once() + call_args = mock_execute.call_args + + self.assertEqual(call_args[1]['transition_name'], 'test_transition') + self.assertEqual(call_args[1]['transition_data']['test_field'], 'builder_test') + + +class ExampleTransitionIntegrationTests(TestCase): + """Integration tests using the example transitions""" + + def setUp(self): + # Import example transitions to register them + + self.task = MockTask() + self.annotation = MockAnnotation() + self.user = Mock() + self.user.id = 1 + self.user.username = 'testuser' + + def test_start_task_transition_validation(self): + """Test StartTaskTransition validation""" + from fsm.example_transitions import StartTaskTransition + + transition = StartTaskTransition(assigned_user_id=123) + + # Test valid transition from CREATED + context = TransitionContext( + entity=self.task, current_state=TaskStateChoices.CREATED, target_state=transition.target_state + ) + + self.assertTrue(transition.validate_transition(context)) + + # Test invalid transition from COMPLETED + context.current_state = TaskStateChoices.COMPLETED + + with self.assertRaises(TransitionValidationError): + transition.validate_transition(context) + + def test_submit_annotation_validation(self): + """Test SubmitAnnotationTransition validation""" + from fsm.example_transitions import SubmitAnnotationTransition + + transition = SubmitAnnotationTransition() + + # Test valid transition + context = TransitionContext( + entity=self.annotation, current_state=AnnotationStateChoices.DRAFT, target_state=transition.target_state + ) + + self.assertTrue(transition.validate_transition(context)) + + def test_transition_data_generation(self): + """Test that transitions generate appropriate context data""" + from fsm.example_transitions import StartTaskTransition + + transition = StartTaskTransition(assigned_user_id=123, estimated_duration=5, priority='high') + + context = TransitionContext( + entity=self.task, current_user=self.user, target_state=transition.target_state, timestamp=datetime.now() + ) + + result = transition.transition(context) + + self.assertEqual(result['assigned_user_id'], 123) + self.assertEqual(result['estimated_duration'], 5) + self.assertEqual(result['priority'], 'high') + self.assertIn('started_at', result) + self.assertEqual(result['assignment_type'], 'manual') + + +class ComprehensiveUsageExampleTests(TestCase): + """ + Comprehensive test cases demonstrating various usage patterns. + + These tests serve as both validation and documentation for how to + implement and use the declarative transition system. + """ + + def setUp(self): + self.task = MockTask() + self.user = Mock() + self.user.id = 123 + self.user.username = 'testuser' + + # Clear registry to avoid conflicts + transition_registry._transitions.clear() + + def test_basic_transition_implementation(self): + """ + USAGE EXAMPLE: Basic transition implementation + + Shows how to implement a simple transition with validation. + """ + + class BasicTransition(BaseTransition): + """Example: Simple transition with required field""" + + message: str = Field(..., description='Message for the transition') @property def target_state(self) -> str: - return TestStateChoices.COMPLETED + return 'PROCESSED' def validate_transition(self, context: TransitionContext) -> bool: - return context.current_state == TestStateChoices.IN_PROGRESS + # Business logic validation + if context.current_state == 'COMPLETED': + raise TransitionValidationError('Cannot process completed items') + return True def transition(self, context: TransitionContext) -> Dict[str, Any]: return { - 'value': self.value, - 'entity_id': context.entity.pk, - 'timestamp': context.timestamp.isoformat(), + 'message': self.message, + 'processed_by': context.current_user.username if context.current_user else 'system', + 'processed_at': context.timestamp.isoformat(), } - transition = ExecutionTestTransition(value='execution_test') + # Test the implementation + transition = BasicTransition(message='Processing task') + self.assertEqual(transition.message, 'Processing task') + self.assertEqual(transition.target_state, 'PROCESSED') + + # Test validation context = TransitionContext( - entity=self.mock_entity, - current_state=TestStateChoices.IN_PROGRESS, - target_state=transition.target_state, - timestamp=datetime.now(), + entity=self.task, current_user=self.user, current_state='CREATED', target_state=transition.target_state ) - # Test validation self.assertTrue(transition.validate_transition(context)) - # Test execution - result = transition.transition(context) - self.assertEqual(result['value'], 'execution_test') - self.assertEqual(result['entity_id'], self.mock_entity.pk) - self.assertIn('timestamp', result) + # Test data generation + data = transition.transition(context) + self.assertEqual(data['message'], 'Processing task') + self.assertEqual(data['processed_by'], 'testuser') + self.assertIn('processed_at', data) - def test_validation_error_handling(self): - """Test transition validation error handling""" + def test_complex_validation_example(self): + """ + USAGE EXAMPLE: Complex validation with multiple conditions + + Shows how to implement sophisticated business logic validation. + """ + + class TaskAssignmentTransition(BaseTransition): + """Example: Complex validation for task assignment""" + + assignee_id: int = Field(..., description='User to assign task to') + priority: str = Field('normal', description='Task priority') + deadline: datetime = Field(None, description='Task deadline') - @register_transition('test_entity', 'validation_test') - class ValidationTestTransition(BaseTransition): @property def target_state(self) -> str: - return TestStateChoices.COMPLETED + return 'ASSIGNED' def validate_transition(self, context: TransitionContext) -> bool: - if context.current_state != TestStateChoices.IN_PROGRESS: + # Multiple validation conditions + if context.current_state not in ['CREATED', 'UNASSIGNED']: raise TransitionValidationError( - 'Can only complete from IN_PROGRESS state', {'current_state': context.current_state} + f'Cannot assign task in state {context.current_state}', + {'current_state': context.current_state, 'task_id': context.entity.pk}, ) + + # Check deadline is in future + if self.deadline and self.deadline <= datetime.now(): + raise TransitionValidationError( + 'Deadline must be in the future', {'deadline': self.deadline.isoformat()} + ) + + # Check priority is valid + valid_priorities = ['low', 'normal', 'high', 'urgent'] + if self.priority not in valid_priorities: + raise TransitionValidationError( + f'Invalid priority: {self.priority}', {'valid_priorities': valid_priorities} + ) + return True def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {} + return { + 'assignee_id': self.assignee_id, + 'priority': self.priority, + 'deadline': self.deadline.isoformat() if self.deadline else None, + 'assigned_by': context.current_user.id if context.current_user else None, + 'assignment_reason': f'Task assigned to user {self.assignee_id}', + } + + # Test valid assignment + future_deadline = datetime.now() + timedelta(days=7) + transition = TaskAssignmentTransition(assignee_id=456, priority='high', deadline=future_deadline) - transition = ValidationTestTransition() - invalid_context = TransitionContext( - entity=self.mock_entity, - current_state=TestStateChoices.CREATED, - target_state=transition.target_state, + context = TransitionContext( + entity=self.task, current_user=self.user, current_state='CREATED', target_state=transition.target_state ) - # Test validation error + self.assertTrue(transition.validate_transition(context)) + + # Test invalid state + context.current_state = 'COMPLETED' with self.assertRaises(TransitionValidationError) as cm: - transition.validate_transition(invalid_context) + transition.validate_transition(context) - error = cm.exception - self.assertIn('Can only complete from IN_PROGRESS state', str(error)) - self.assertIn('current_state', error.context) + self.assertIn('Cannot assign task in state', str(cm.exception)) + self.assertIn('COMPLETED', str(cm.exception)) - def test_transition_builder_basic(self): - """Test TransitionBuilder basic functionality""" + # Test invalid deadline + past_deadline = datetime.now() - timedelta(days=1) + invalid_transition = TaskAssignmentTransition(assignee_id=456, deadline=past_deadline) + + context.current_state = 'CREATED' + with self.assertRaises(TransitionValidationError) as cm: + invalid_transition.validate_transition(context) - @register_transition('test_entity', 'builder_test') - class BuilderTestTransition(BaseTransition): - value: str = Field('default', description='Test value') + self.assertIn('Deadline must be in the future', str(cm.exception)) + + def test_hooks_and_lifecycle_example(self): + """ + USAGE EXAMPLE: Using pre/post hooks for side effects + + Shows how to implement lifecycle hooks for notifications, + cleanup, or other side effects. + """ + + class NotificationTransition(BaseTransition): + """Example: Transition with notification hooks""" + + notification_message: str = Field(..., description='Notification message') + notify_users: list = Field(default_factory=list, description='Users to notify') + notifications_sent: list = Field(default_factory=list, description='Track sent notifications') + cleanup_performed: bool = Field(default=False, description='Track cleanup status') @property def target_state(self) -> str: - return TestStateChoices.COMPLETED + return 'NOTIFIED' + + @classmethod + def get_target_state(cls) -> str: + return 'NOTIFIED' + + @classmethod + def can_transition_from_state(cls, context: TransitionContext) -> bool: + return True + + def pre_transition_hook(self, context: TransitionContext) -> None: + """Prepare notifications before state change""" + # Validate notification recipients + if not self.notify_users: + self.notify_users = [context.current_user.id] if context.current_user else [] def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {'value': self.value} + return { + 'notification_message': self.notification_message, + 'notify_users': self.notify_users, + 'notification_sent_at': context.timestamp.isoformat(), + } - # Test builder creation - builder = TransitionBuilder(self.mock_entity) - self.assertEqual(builder.entity, self.mock_entity) + def post_transition_hook(self, context: TransitionContext, state_record) -> None: + """Send notifications after successful state change""" + # Mock notification sending + for user_id in self.notify_users: + self.notifications_sent.append( + {'user_id': user_id, 'message': self.notification_message, 'sent_at': context.timestamp} + ) - # Test method chaining - builder = builder.transition('builder_test').with_data(value='builder_test_value').by_user(self.user) + # Mock cleanup + self.cleanup_performed = True - # Validate the builder state - validation_errors = builder.validate() - self.assertEqual(len(validation_errors), 0) + # Test the hooks + transition = NotificationTransition(notification_message='Task has been updated', notify_users=[123, 456]) - def test_get_available_transitions(self): - """Test get_available_transitions utility""" + context = TransitionContext( + entity=self.task, current_user=self.user, current_state='CREATED', target_state=transition.target_state + ) - @register_transition('test_entity', 'available_test') - class AvailableTestTransition(BaseTransition): - @property - def target_state(self) -> str: - return TestStateChoices.COMPLETED + # Test pre-hook + transition.pre_transition_hook(context) + self.assertEqual(transition.notify_users, [123, 456]) - def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {} + # Test transition + data = transition.transition(context) + self.assertEqual(data['notification_message'], 'Task has been updated') + + # Test post-hook + mock_state_record = Mock() + transition.post_transition_hook(context, mock_state_record) + + self.assertEqual(len(transition.notifications_sent), 2) + self.assertTrue(transition.cleanup_performed) + + def test_conditional_transition_example(self): + """ + USAGE EXAMPLE: Conditional transitions based on data - available = get_available_transitions(self.mock_entity) - self.assertIn('available_test', available) - self.assertEqual(available['available_test'], AvailableTestTransition) + Shows how to implement transitions that behave differently + based on input data or context. + """ - def test_transition_hooks(self): - """Test pre and post transition hooks""" + class ConditionalApprovalTransition(BaseTransition): + """Example: Conditional approval based on confidence""" - hook_calls = [] + confidence_score: float = Field(..., ge=0.0, le=1.0, description='Confidence score') + auto_approve_threshold: float = Field(0.9, description='Auto-approval threshold') + reviewer_id: int = Field(None, description='Manual reviewer ID') - @register_transition('test_entity', 'hook_test') - class HookTestTransition(BaseTransition): @property def target_state(self) -> str: - return TestStateChoices.COMPLETED + # Dynamic target state based on confidence + if self.confidence_score >= self.auto_approve_threshold: + return 'AUTO_APPROVED' + else: + return 'PENDING_REVIEW' - def pre_transition_hook(self, context: TransitionContext) -> None: - hook_calls.append('pre') + def validate_transition(self, context: TransitionContext) -> bool: + # Different validation based on approval type + if self.confidence_score >= self.auto_approve_threshold: + # Auto-approval validation + if context.current_state != 'SUBMITTED': + raise TransitionValidationError('Can only auto-approve submitted items') + else: + # Manual review validation + if not self.reviewer_id: + raise TransitionValidationError('Manual review requires reviewer_id') + + return True def transition(self, context: TransitionContext) -> Dict[str, Any]: - hook_calls.append('transition') - return {} + base_data = { + 'confidence_score': self.confidence_score, + 'threshold': self.auto_approve_threshold, + } - def post_transition_hook(self, context: TransitionContext, state_record) -> None: - hook_calls.append('post') + if self.confidence_score >= self.auto_approve_threshold: + # Auto-approval data + return { + **base_data, + 'approval_type': 'automatic', + 'approved_at': context.timestamp.isoformat(), + 'approved_by': 'system', + } + else: + # Manual review data + return { + **base_data, + 'approval_type': 'manual', + 'assigned_reviewer': self.reviewer_id, + 'review_requested_at': context.timestamp.isoformat(), + } + + # Test auto-approval path + high_confidence_transition = ConditionalApprovalTransition(confidence_score=0.95) + + self.assertEqual(high_confidence_transition.target_state, 'AUTO_APPROVED') - transition = HookTestTransition() context = TransitionContext( - entity=self.mock_entity, - current_state=TestStateChoices.IN_PROGRESS, - target_state=transition.target_state, + entity=self.task, current_state='SUBMITTED', target_state=high_confidence_transition.target_state ) - # Test hook execution order - transition.pre_transition_hook(context) - transition.transition(context) - transition.post_transition_hook(context, Mock()) + self.assertTrue(high_confidence_transition.validate_transition(context)) - self.assertEqual(hook_calls, ['pre', 'transition', 'post']) + auto_data = high_confidence_transition.transition(context) + self.assertEqual(auto_data['approval_type'], 'automatic') + self.assertEqual(auto_data['approved_by'], 'system') + # Test manual review path + low_confidence_transition = ConditionalApprovalTransition(confidence_score=0.7, reviewer_id=789) -class TransitionUtilsTests(TestCase): - """Test transition utility functions""" + self.assertEqual(low_confidence_transition.target_state, 'PENDING_REVIEW') + + context.target_state = low_confidence_transition.target_state + self.assertTrue(low_confidence_transition.validate_transition(context)) + + manual_data = low_confidence_transition.transition(context) + self.assertEqual(manual_data['approval_type'], 'manual') + self.assertEqual(manual_data['assigned_reviewer'], 789) + + def test_registry_and_decorator_usage(self): + """ + USAGE EXAMPLE: Using the registry and decorator system + + Shows how to register transitions and use the decorator syntax. + """ + + @register_transition('document', 'publish') + class PublishDocumentTransition(BaseTransition): + """Example: Using the registration decorator""" + + publish_immediately: bool = Field(True, description='Publish immediately') + scheduled_time: datetime = Field(None, description='Scheduled publish time') + + @property + def target_state(self) -> str: + return 'PUBLISHED' if self.publish_immediately else 'SCHEDULED' + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + 'publish_immediately': self.publish_immediately, + 'scheduled_time': self.scheduled_time.isoformat() if self.scheduled_time else None, + 'published_by': context.current_user.id if context.current_user else None, + } + + # Test registration worked + registered_class = transition_registry.get_transition('document', 'publish') + self.assertEqual(registered_class, PublishDocumentTransition) + + # Test getting transitions for entity + document_transitions = transition_registry.get_transitions_for_entity('document') + self.assertIn('publish', document_transitions) + + # Test execution through registry + mock_document = Mock() + mock_document.pk = 1 + mock_document._meta.model_name = 'document' + + # This would normally go through the full execution workflow + transition_data = {'publish_immediately': False, 'scheduled_time': datetime.now() + timedelta(hours=2)} + + # Test transition creation and validation + transition = PublishDocumentTransition(**transition_data) + self.assertEqual(transition.target_state, 'SCHEDULED') + + +class ValidationAndErrorHandlingTests(TestCase): + """ + Tests focused on validation scenarios and error handling. + + These tests demonstrate proper error handling patterns and + validation edge cases. + """ def setUp(self): - """Set up test data""" - self.user = User.objects.create_user(email='test@example.com', password='test123') - self.mock_entity = MockEntity() - transition_registry.clear() + self.task = MockTask() + transition_registry._transitions.clear() - def tearDown(self): - """Clean up after tests""" - transition_registry.clear() + def test_pydantic_validation_errors(self): + """Test Pydantic field validation errors""" - def test_get_available_transitions(self): - """Test getting available transitions for an entity""" + class StrictValidationTransition(BaseTransition): + required_field: str = Field(..., description='Required field') + email_field: str = Field(..., pattern=r'^[\w\.-]+@[\w\.-]+\.\w+$', description='Valid email') + number_field: int = Field(..., ge=1, le=100, description='Number between 1-100') + + @property + def target_state(self) -> str: + return 'VALIDATED' + + @classmethod + def get_target_state(cls) -> str: + return 'VALIDATED' + + @classmethod + def can_transition_from_state(cls, context: TransitionContext) -> bool: + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return {'validated': True} + + # Test missing required field + with self.assertRaises(ValidationError): + StrictValidationTransition(email_field='test@example.com', number_field=50) + + # Test invalid email + with self.assertRaises(ValidationError): + StrictValidationTransition(required_field='test', email_field='invalid-email', number_field=50) + + # Test number out of range + with self.assertRaises(ValidationError): + StrictValidationTransition(required_field='test', email_field='test@example.com', number_field=150) + + # Test valid data + valid_transition = StrictValidationTransition( + required_field='test', email_field='user@example.com', number_field=75 + ) + self.assertEqual(valid_transition.required_field, 'test') + + def test_business_logic_validation_errors(self): + """Test business logic validation with detailed error context""" + + class BusinessRuleTransition(BaseTransition): + amount: float = Field(..., description='Transaction amount') + currency: str = Field('USD', description='Currency code') - @register_transition('test_entity', 'util_test_1') - class UtilTestTransition1(BaseTransition): @property def target_state(self) -> str: - return TestStateChoices.IN_PROGRESS + return 'PROCESSED' + + def validate_transition(self, context: TransitionContext) -> bool: + # Complex business rule validation + errors = [] + + if self.amount <= 0: + errors.append('Amount must be positive') + + if self.amount > 10000 and context.current_user is None: + errors.append('Large amounts require authenticated user') + + if self.currency not in ['USD', 'EUR', 'GBP']: + errors.append(f'Unsupported currency: {self.currency}') + + if context.current_state == 'CANCELLED': + errors.append('Cannot process cancelled transactions') + + if errors: + raise TransitionValidationError( + f"Validation failed: {'; '.join(errors)}", + { + 'validation_errors': errors, + 'amount': self.amount, + 'currency': self.currency, + 'current_state': context.current_state, + }, + ) + + return True def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {} + return {'amount': self.amount, 'currency': self.currency} + + context = TransitionContext(entity=self.task, current_state='PENDING', target_state='PROCESSED') + + # Test negative amount + negative_transition = BusinessRuleTransition(amount=-100) + with self.assertRaises(TransitionValidationError) as cm: + negative_transition.validate_transition(context) + + error = cm.exception + self.assertIn('Amount must be positive', str(error)) + self.assertIn('validation_errors', error.context) + + # Test large amount without user + large_transition = BusinessRuleTransition(amount=15000) + with self.assertRaises(TransitionValidationError) as cm: + large_transition.validate_transition(context) + + self.assertIn('Large amounts require authenticated user', str(cm.exception)) + + # Test invalid currency + invalid_currency_transition = BusinessRuleTransition(amount=100, currency='XYZ') + with self.assertRaises(TransitionValidationError) as cm: + invalid_currency_transition.validate_transition(context) + + self.assertIn('Unsupported currency', str(cm.exception)) + + # Test multiple errors + multi_error_transition = BusinessRuleTransition(amount=-50, currency='XYZ') + with self.assertRaises(TransitionValidationError) as cm: + multi_error_transition.validate_transition(context) + + error_msg = str(cm.exception) + self.assertIn('Amount must be positive', error_msg) + self.assertIn('Unsupported currency', error_msg) + + def test_context_validation_errors(self): + """Test validation errors related to context state""" + + class ContextAwareTransition(BaseTransition): + action: str = Field(..., description='Action to perform') - @register_transition('test_entity', 'util_test_2') - class UtilTestTransition2(BaseTransition): @property def target_state(self) -> str: - return TestStateChoices.COMPLETED + return 'ACTIONED' + + def validate_transition(self, context: TransitionContext) -> bool: + # State-dependent validation + if context.is_initial_transition and self.action != 'create': + raise TransitionValidationError( + "Initial transition must be 'create' action", {'action': self.action, 'is_initial': True} + ) + + if context.current_state == 'COMPLETED' and self.action in ['modify', 'update']: + raise TransitionValidationError( + f'Cannot {self.action} completed items', + {'action': self.action, 'current_state': context.current_state}, + ) + + return True def transition(self, context: TransitionContext) -> Dict[str, Any]: - return {} - - available = get_available_transitions(self.mock_entity) - self.assertEqual(len(available), 2) - self.assertIn('util_test_1', available) - self.assertIn('util_test_2', available) - - # Test with non-existent entity - mock_other = MockEntity() - mock_other._meta.model_name = 'other_entity' - other_available = get_available_transitions(mock_other) - self.assertEqual(len(other_available), 0) + return {'action': self.action} + + # Test initial transition validation + create_transition = ContextAwareTransition(action='create') + initial_context = TransitionContext( + entity=self.task, current_state=None, target_state='ACTIONED' # No current state = initial + ) + + self.assertTrue(create_transition.validate_transition(initial_context)) + + # Test invalid initial action + modify_transition = ContextAwareTransition(action='modify') + with self.assertRaises(TransitionValidationError) as cm: + modify_transition.validate_transition(initial_context) + + error = cm.exception + self.assertIn("Initial transition must be 'create'", str(error)) + self.assertTrue(error.context['is_initial']) + + # Test completed state validation + completed_context = TransitionContext(entity=self.task, current_state='COMPLETED', target_state='ACTIONED') + + with self.assertRaises(TransitionValidationError) as cm: + modify_transition.validate_transition(completed_context) + + self.assertIn('Cannot modify completed items', str(cm.exception)) + + +@pytest.fixture +def task(): + """Pytest fixture for mock task""" + return MockTask() + + +@pytest.fixture +def user(): + """Pytest fixture for mock user""" + user = Mock() + user.id = 1 + user.username = 'testuser' + return user + + +def test_transition_context_properties(task, user): + """Test TransitionContext properties using pytest""" + context = TransitionContext(entity=task, current_user=user, current_state='CREATED', target_state='IN_PROGRESS') + + assert context.has_current_state + assert not context.is_initial_transition + assert context.current_state == 'CREATED' + assert context.target_state == 'IN_PROGRESS' + + +def test_pydantic_validation(): + """Test Pydantic validation in transitions""" + # Valid data + transition = TestTransition(test_field='valid') + assert transition.test_field == 'valid' + assert transition.optional_field == 42 + + # Invalid data should raise validation error + with pytest.raises(Exception): # Pydantic validation error + TestTransition() # Missing required field diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py new file mode 100644 index 000000000000..1ab7c6096a61 --- /dev/null +++ b/label_studio/fsm/tests/test_fsm_integration.py @@ -0,0 +1,302 @@ +""" +Integration tests for the FSM system. +Tests the complete FSM functionality including models, state management, +and API endpoints. +""" + +from datetime import datetime, timezone + +from django.contrib.auth import get_user_model +from django.test import TestCase +from fsm.models import AnnotationState, ProjectState, TaskState +from fsm.state_manager import get_state_manager +from projects.models import Project +from rest_framework.test import APITestCase +from tasks.models import Annotation, Task + +User = get_user_model() + + +class TestFSMModels(TestCase): + """Test FSM model functionality""" + + def setUp(self): + self.user = User.objects.create_user(email='test@example.com', password='test123') + self.project = Project.objects.create(title='Test Project', created_by=self.user) + self.task = Task.objects.create(project=self.project, data={'text': 'test'}) + + # Clear cache to ensure tests start with clean state + from django.core.cache import cache + + cache.clear() + + def test_task_state_creation(self): + """Test TaskState creation and basic functionality""" + task_state = TaskState.objects.create( + task=self.task, + project_id=self.task.project_id, # Denormalized from task.project_id + state='CREATED', + triggered_by=self.user, + reason='Task created for testing', + ) + + # Check basic fields + self.assertEqual(task_state.state, 'CREATED') + self.assertEqual(task_state.task, self.task) + self.assertEqual(task_state.triggered_by, self.user) + + # Check UUID7 functionality + self.assertEqual(task_state.id.version, 7) + self.assertIsInstance(task_state.timestamp_from_uuid, datetime) + + # Check string representation + str_repr = str(task_state) + self.assertIn('Task', str_repr) + self.assertIn('CREATED', str_repr) + + def test_annotation_state_creation(self): + """Test AnnotationState creation and basic functionality""" + annotation = Annotation.objects.create(task=self.task, completed_by=self.user, result=[]) + + annotation_state = AnnotationState.objects.create( + annotation=annotation, + task_id=annotation.task.id, # Denormalized from annotation.task_id + project_id=annotation.task.project_id, # Denormalized from annotation.task.project_id + completed_by_id=annotation.completed_by.id if annotation.completed_by else None, # Denormalized + state='DRAFT', + triggered_by=self.user, + reason='Annotation draft created', + ) + + # Check basic fields + self.assertEqual(annotation_state.state, 'DRAFT') + self.assertEqual(annotation_state.annotation, annotation) + + # Check terminal state property + self.assertFalse(annotation_state.is_terminal_state) + + # Test completed state + completed_state = AnnotationState.objects.create( + annotation=annotation, + task_id=annotation.task.id, + project_id=annotation.task.project_id, + completed_by_id=annotation.completed_by.id if annotation.completed_by else None, + state='COMPLETED', + triggered_by=self.user, + ) + self.assertTrue(completed_state.is_terminal_state) + + def test_project_state_creation(self): + """Test ProjectState creation and basic functionality""" + project_state = ProjectState.objects.create( + project=self.project, state='CREATED', triggered_by=self.user, reason='Project created for testing' + ) + + # Check basic fields + self.assertEqual(project_state.state, 'CREATED') + self.assertEqual(project_state.project, self.project) + + # Test terminal state + self.assertFalse(project_state.is_terminal_state) + + completed_state = ProjectState.objects.create(project=self.project, state='COMPLETED', triggered_by=self.user) + self.assertTrue(completed_state.is_terminal_state) + + +class TestStateManager(TestCase): + """Test StateManager functionality""" + + def setUp(self): + self.user = User.objects.create_user(email='test@example.com', password='test123') + self.project = Project.objects.create(title='Test Project', created_by=self.user) + self.task = Task.objects.create(project=self.project, data={'text': 'test'}) + self.StateManager = get_state_manager() + + # Clear cache to ensure tests start with clean state + from django.core.cache import cache + + cache.clear() + + def test_get_current_state_empty(self): + """Test getting current state when no states exist""" + current_state = self.StateManager.get_current_state(self.task) + self.assertIsNone(current_state) + + def test_transition_state(self): + """Test state transition functionality""" + # Initial transition + success = self.StateManager.transition_state( + entity=self.task, + new_state='CREATED', + user=self.user, + transition_name='create_task', + reason='Initial task creation', + ) + + self.assertTrue(success) + + # Check current state + current_state = self.StateManager.get_current_state(self.task) + self.assertEqual(current_state, 'CREATED') + + # Another transition + success = self.StateManager.transition_state( + entity=self.task, + new_state='IN_PROGRESS', + user=self.user, + transition_name='start_work', + context={'started_by': 'user'}, + ) + + self.assertTrue(success) + current_state = self.StateManager.get_current_state(self.task) + self.assertEqual(current_state, 'IN_PROGRESS') + + def test_get_current_state_object(self): + """Test getting current state object with full details""" + # Create some state transitions + self.StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) + self.StateManager.transition_state( + entity=self.task, new_state='IN_PROGRESS', user=self.user, context={'test': 'data'} + ) + + current_state_obj = self.StateManager.get_current_state_object(self.task) + + self.assertIsNotNone(current_state_obj) + self.assertEqual(current_state_obj.state, 'IN_PROGRESS') + self.assertEqual(current_state_obj.previous_state, 'CREATED') + self.assertEqual(current_state_obj.triggered_by, self.user) + self.assertEqual(current_state_obj.context_data, {'test': 'data'}) + + def test_get_state_history(self): + """Test state history retrieval""" + # Create multiple transitions + transitions = [('CREATED', 'create_task'), ('IN_PROGRESS', 'start_work'), ('COMPLETED', 'finish_work')] + + for state, transition in transitions: + self.StateManager.transition_state( + entity=self.task, new_state=state, user=self.user, transition_name=transition + ) + + history = self.StateManager.get_state_history(self.task, limit=10) + + # Should have 3 state records + self.assertEqual(len(history), 3) + + # Should be ordered by most recent first (UUID7 ordering) + states = [h.state for h in history] + self.assertEqual(states, ['COMPLETED', 'IN_PROGRESS', 'CREATED']) + + print(history) + ids = [str(h.id) for h in history] + print(ids) + + # Check previous states are set correctly + self.assertIsNone(history[2].previous_state) # First state has no previous + self.assertEqual(history[1].previous_state, 'CREATED') + self.assertEqual(history[0].previous_state, 'IN_PROGRESS') + + def test_get_states_in_time_range(self): + """Test time-based state queries using UUID7""" + # Record time before creating states + before_time = datetime.now(timezone.utc) + + # Create some states + self.StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) + self.StateManager.transition_state(entity=self.task, new_state='IN_PROGRESS', user=self.user) + + # Record time after creating states + after_time = datetime.now(timezone.utc) + + # Query states in time range + states_in_range = self.StateManager.get_states_in_time_range(self.task, before_time, after_time) + + # Should find both states + self.assertEqual(len(states_in_range), 2) + + +class TestFSMAPI(APITestCase): + """Test FSM API endpoints""" + + def setUp(self): + self.user = User.objects.create_user(email='test@example.com', password='test123') + self.project = Project.objects.create(title='Test Project', created_by=self.user) + self.task = Task.objects.create(project=self.project, data={'text': 'test'}) + self.client.force_authenticate(user=self.user) + + # Clear cache to ensure tests start with clean state + from django.core.cache import cache + + cache.clear() + + # Create initial state + StateManager = get_state_manager() + StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) + + def test_get_current_state_api(self): + """Test GET /api/fsm/{entity_type}/{entity_id}/current/""" + response = self.client.get(f'/api/fsm/task/{self.task.id}/current/') + + self.assertEqual(response.status_code, 200) + data = response.json() + + self.assertEqual(data['current_state'], 'CREATED') + self.assertEqual(data['entity_type'], 'task') + self.assertEqual(data['entity_id'], self.task.id) + + def test_get_state_history_api(self): + """Test GET /api/fsm/{entity_type}/{entity_id}/history/""" + # Create additional states + StateManager = get_state_manager() + StateManager.transition_state( + entity=self.task, new_state='IN_PROGRESS', user=self.user, transition_name='start_work' + ) + + response = self.client.get(f'/api/fsm/task/{self.task.id}/history/') + + self.assertEqual(response.status_code, 200) + data = response.json() + + self.assertEqual(data['count'], 2) + self.assertEqual(len(data['results']), 2) + + # Check first result (most recent) + latest_state = data['results'][0] + self.assertEqual(latest_state['state'], 'IN_PROGRESS') + self.assertEqual(latest_state['previous_state'], 'CREATED') + self.assertEqual(latest_state['transition_name'], 'start_work') + + def test_transition_state_api(self): + """Test POST /api/fsm/{entity_type}/{entity_id}/transition/""" + transition_data = { + 'new_state': 'IN_PROGRESS', + 'transition_name': 'start_annotation', + 'reason': 'User started working on task', + 'context': {'assignment_id': 123}, + } + + response = self.client.post(f'/api/fsm/task/{self.task.id}/transition/', data=transition_data, format='json') + + self.assertEqual(response.status_code, 200) + data = response.json() + + self.assertTrue(data['success']) + self.assertEqual(data['previous_state'], 'CREATED') + self.assertEqual(data['new_state'], 'IN_PROGRESS') + self.assertEqual(data['entity_type'], 'task') + self.assertEqual(data['entity_id'], self.task.id) + + # Verify state was actually changed + StateManager = get_state_manager() + current_state = StateManager.get_current_state(self.task) + self.assertEqual(current_state, 'IN_PROGRESS') + + def test_api_with_invalid_entity(self): + """Test API with non-existent entity""" + response = self.client.get('/api/fsm/task/99999/current/') + self.assertEqual(response.status_code, 404) + + def test_api_with_invalid_entity_type(self): + """Test API with invalid entity type""" + response = self.client.get('/api/fsm/invalid/1/current/') + self.assertEqual(response.status_code, 404) diff --git a/label_studio/fsm/tests/test_integration_django_models.py b/label_studio/fsm/tests/test_integration_django_models.py new file mode 100644 index 000000000000..f0f4268dcd5b --- /dev/null +++ b/label_studio/fsm/tests/test_integration_django_models.py @@ -0,0 +1,665 @@ +""" +Integration tests for declarative transitions with real Django models. +These tests demonstrate how the transition system integrates with actual +Django models and the StateManager, providing realistic usage examples. +""" + +from datetime import datetime +from typing import Any, Dict +from unittest.mock import Mock, patch + +from django.contrib.auth import get_user_model +from django.test import TestCase +from fsm.models import TaskState +from fsm.state_choices import AnnotationStateChoices, TaskStateChoices +from fsm.transition_utils import TransitionBuilder +from fsm.transitions import BaseTransition, TransitionContext, TransitionValidationError, register_transition +from pydantic import Field + + +# Mock Django models for integration testing +class MockDjangoTask: + """Mock Django Task model with realistic attributes""" + + def __init__(self, pk=1, project_id=1, organization_id=1): + self.pk = pk + self.id = pk + self.project_id = project_id + self.organization_id = organization_id + self._meta = Mock() + self._meta.model_name = 'task' + self._meta.label_lower = 'tasks.task' + + # Mock task attributes + self.data = {'text': 'Sample task data'} + self.created_at = datetime.now() + self.updated_at = datetime.now() + + +class MockDjangoAnnotation: + """Mock Django Annotation model with realistic attributes""" + + def __init__(self, pk=1, task_id=1, project_id=1, organization_id=1): + self.pk = pk + self.id = pk + self.task_id = task_id + self.project_id = project_id + self.organization_id = organization_id + self._meta = Mock() + self._meta.model_name = 'annotation' + self._meta.label_lower = 'tasks.annotation' + + # Mock annotation attributes + self.result = [{'value': {'text': ['Sample annotation']}}] + self.completed_by_id = None + self.created_at = datetime.now() + self.updated_at = datetime.now() + + +User = get_user_model() + + +class DjangoModelIntegrationTests(TestCase): + """ + Integration tests demonstrating realistic usage with Django models. + These tests show how to implement transitions that work with actual + Django model patterns and the StateManager integration. + """ + + def setUp(self): + self.task = MockDjangoTask() + self.annotation = MockDjangoAnnotation() + self.user = Mock() + self.user.id = 123 + self.user.username = 'integration_test_user' + + # Clear registry for clean test state + from fsm.transitions import transition_registry + + transition_registry._transitions.clear() + + @patch('fsm.registry.get_state_model_for_entity') + @patch('fsm.state_manager.StateManager.get_current_state_object') + @patch('fsm.state_manager.StateManager.transition_state') + def test_task_workflow_integration(self, mock_transition_state, mock_get_state_obj, mock_get_state_model): + """ + INTEGRATION TEST: Complete task workflow using Django models + Demonstrates a realistic task lifecycle from creation through completion + using the declarative transition system with Django model integration. + """ + + # Setup mocks to simulate Django model behavior + mock_get_state_model.return_value = TaskState + mock_get_state_obj.return_value = None # No existing state (initial transition) + mock_transition_state.return_value = True + + # Define task workflow transitions + @register_transition('task', 'create_task') + class CreateTaskTransition(BaseTransition): + """Initial task creation transition""" + + created_by_id: int = Field(..., description='User creating the task') + initial_priority: str = Field('normal', description='Initial task priority') + + @property + def target_state(self) -> str: + return TaskStateChoices.CREATED + + def validate_transition(self, context: TransitionContext) -> bool: + # Validate initial creation + if not context.is_initial_transition: + raise TransitionValidationError('CreateTask can only be used for initial state') + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + 'created_by_id': self.created_by_id, + 'initial_priority': self.initial_priority, + 'task_data': getattr(context.entity, 'data', {}), + 'project_id': getattr(context.entity, 'project_id', None), + 'creation_method': 'declarative_transition', + } + + @register_transition('task', 'assign_and_start') + class AssignAndStartTaskTransition(BaseTransition): + """Assign task to user and start work""" + + assignee_id: int = Field(..., description='User assigned to task') + estimated_hours: float = Field(None, ge=0.1, description='Estimated work hours') + priority: str = Field('normal', description='Task priority') + + @property + def target_state(self) -> str: + return TaskStateChoices.IN_PROGRESS + + def validate_transition(self, context: TransitionContext) -> bool: + valid_from_states = [TaskStateChoices.CREATED] + if context.current_state not in valid_from_states: + raise TransitionValidationError( + f'Can only assign tasks from states: {valid_from_states}', + {'current_state': context.current_state, 'valid_states': valid_from_states}, + ) + + # Business rule: Can't assign to the same user who created it + if hasattr(context, 'current_state_object') and context.current_state_object: + creator_id = context.current_state_object.context_data.get('created_by_id') + if creator_id == self.assignee_id: + raise TransitionValidationError( + 'Cannot assign task to the same user who created it', + {'creator_id': creator_id, 'assignee_id': self.assignee_id}, + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + 'assignee_id': self.assignee_id, + 'estimated_hours': self.estimated_hours, + 'priority': self.priority, + 'assigned_at': context.timestamp.isoformat(), + 'assigned_by_id': context.current_user.id if context.current_user else None, + 'work_started': True, + } + + @register_transition('task', 'complete_with_quality') + class CompleteTaskWithQualityTransition(BaseTransition): + """Complete task with quality metrics""" + + quality_score: float = Field(..., ge=0.0, le=1.0, description='Quality score') + completion_notes: str = Field('', description='Completion notes') + actual_hours: float = Field(None, ge=0.0, description='Actual hours worked') + + @property + def target_state(self) -> str: + return TaskStateChoices.COMPLETED + + def validate_transition(self, context: TransitionContext) -> bool: + if context.current_state != TaskStateChoices.IN_PROGRESS: + raise TransitionValidationError( + 'Can only complete tasks that are in progress', {'current_state': context.current_state} + ) + + # Quality check + if self.quality_score < 0.6: + raise TransitionValidationError( + f'Quality score too low: {self.quality_score}. Minimum required: 0.6' + ) + + return True + + def post_transition_hook(self, context: TransitionContext, state_record) -> None: + """Post-completion tasks like notifications""" + # Mock notification system + if hasattr(self, '_notifications'): + self._notifications.append(f'Task {context.entity.pk} completed with quality {self.quality_score}') + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + # Calculate metrics + start_data = context.current_state_object.context_data if context.current_state_object else {} + estimated_hours = start_data.get('estimated_hours') + + return { + 'quality_score': self.quality_score, + 'completion_notes': self.completion_notes, + 'actual_hours': self.actual_hours, + 'estimated_hours': estimated_hours, + 'completed_at': context.timestamp.isoformat(), + 'completed_by_id': context.current_user.id if context.current_user else None, + 'efficiency_ratio': (estimated_hours / self.actual_hours) + if (estimated_hours and self.actual_hours) + else None, + } + + # Execute the complete workflow + + # Step 1: Create task + create_transition = CreateTaskTransition(created_by_id=100, initial_priority='high') + + # Test with StateManager integration + with patch('fsm.state_manager.StateManager.get_current_state') as mock_get_current: + mock_get_current.return_value = None # No current state + + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state=None, + target_state=create_transition.target_state, + ) + + # Validate and execute creation + self.assertTrue(create_transition.validate_transition(context)) + creation_data = create_transition.transition(context) + + self.assertEqual(creation_data['created_by_id'], 100) + self.assertEqual(creation_data['initial_priority'], 'high') + self.assertEqual(creation_data['creation_method'], 'declarative_transition') + + # Step 2: Assign and start task + mock_current_state = Mock() + mock_current_state.context_data = creation_data + mock_get_state_obj.return_value = mock_current_state + + assign_transition = AssignAndStartTaskTransition( + assignee_id=200, estimated_hours=4.5, priority='urgent' # Different from creator + ) + + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state=TaskStateChoices.CREATED, + current_state_object=mock_current_state, + target_state=assign_transition.target_state, + ) + + self.assertTrue(assign_transition.validate_transition(context)) + assignment_data = assign_transition.transition(context) + + self.assertEqual(assignment_data['assignee_id'], 200) + self.assertEqual(assignment_data['estimated_hours'], 4.5) + self.assertTrue(assignment_data['work_started']) + + # Step 3: Complete task + mock_current_state.context_data = assignment_data + + complete_transition = CompleteTaskWithQualityTransition( + quality_score=0.85, completion_notes='Task completed successfully with minor revisions', actual_hours=5.2 + ) + complete_transition._notifications = [] # Mock notification system + + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state=TaskStateChoices.IN_PROGRESS, + current_state_object=mock_current_state, + target_state=complete_transition.target_state, + ) + + self.assertTrue(complete_transition.validate_transition(context)) + completion_data = complete_transition.transition(context) + + self.assertEqual(completion_data['quality_score'], 0.85) + self.assertEqual(completion_data['actual_hours'], 5.2) + self.assertAlmostEqual(completion_data['efficiency_ratio'], 4.5 / 5.2, places=2) + + # Test post-hook + mock_state_record = Mock() + complete_transition.post_transition_hook(context, mock_state_record) + self.assertEqual(len(complete_transition._notifications), 1) + + # Verify StateManager calls + self.assertEqual(mock_transition_state.call_count, 0) # Not called in our test setup + + def test_annotation_review_workflow_integration(self): + """ + INTEGRATION TEST: Annotation review workflow + Demonstrates a realistic annotation review process using + enterprise-grade validation and approval logic. + """ + + @register_transition('annotation', 'submit_for_review') + class SubmitAnnotationForReview(BaseTransition): + """Submit annotation for quality review""" + + annotator_confidence: float = Field(..., ge=0.0, le=1.0, description='Annotator confidence') + annotation_time_seconds: int = Field(..., ge=1, description='Time spent annotating') + review_requested: bool = Field(True, description='Whether review is requested') + + @property + def target_state(self) -> str: + return AnnotationStateChoices.SUBMITTED + + def validate_transition(self, context: TransitionContext) -> bool: + # Check annotation has content + if not hasattr(context.entity, 'result') or not context.entity.result: + raise TransitionValidationError('Cannot submit empty annotation') + + # Business rule: Low confidence annotations must request review + if self.annotator_confidence < 0.7 and not self.review_requested: + raise TransitionValidationError( + 'Low confidence annotations must request review', + {'confidence': self.annotator_confidence, 'threshold': 0.7}, + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + 'annotator_confidence': self.annotator_confidence, + 'annotation_time_seconds': self.annotation_time_seconds, + 'review_requested': self.review_requested, + 'annotation_complexity': len(context.entity.result) if context.entity.result else 0, + 'submitted_at': context.timestamp.isoformat(), + 'submitted_by_id': context.current_user.id if context.current_user else None, + } + + @register_transition('annotation', 'review_and_approve') + class ReviewAndApproveAnnotation(BaseTransition): + """Review annotation and approve/reject""" + + reviewer_decision: str = Field(..., description='approve, reject, or request_changes') + quality_score: float = Field(..., ge=0.0, le=1.0, description='Reviewer quality assessment') + review_comments: str = Field('', description='Review comments') + corrections_made: bool = Field(False, description='Whether reviewer made corrections') + + @property + def target_state(self) -> str: + if self.reviewer_decision == 'approve': + return AnnotationStateChoices.COMPLETED + else: + return AnnotationStateChoices.DRAFT # Back to draft for changes + + def validate_transition(self, context: TransitionContext) -> bool: + if context.current_state != AnnotationStateChoices.SUBMITTED: + raise TransitionValidationError('Can only review submitted annotations') + + valid_decisions = ['approve', 'reject', 'request_changes'] + if self.reviewer_decision not in valid_decisions: + raise TransitionValidationError( + f'Invalid decision: {self.reviewer_decision}', {'valid_decisions': valid_decisions} + ) + + # Quality score validation based on decision + if self.reviewer_decision == 'approve' and self.quality_score < 0.6: + raise TransitionValidationError( + 'Cannot approve annotation with low quality score', + {'quality_score': self.quality_score, 'decision': self.reviewer_decision}, + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + # Get submission data for metrics + submission_data = context.current_state_object.context_data if context.current_state_object else {} + + return { + 'reviewer_decision': self.reviewer_decision, + 'quality_score': self.quality_score, + 'review_comments': self.review_comments, + 'corrections_made': self.corrections_made, + 'reviewed_at': context.timestamp.isoformat(), + 'reviewed_by_id': context.current_user.id if context.current_user else None, + 'original_confidence': submission_data.get('annotator_confidence'), + 'confidence_vs_quality_diff': abs( + submission_data.get('annotator_confidence', 0) - self.quality_score + ), + } + + # Execute annotation workflow + + # Step 1: Submit annotation + submit_transition = SubmitAnnotationForReview( + annotator_confidence=0.9, annotation_time_seconds=300, review_requested=True # 5 minutes + ) + + context = TransitionContext( + entity=self.annotation, + current_user=self.user, + current_state=AnnotationStateChoices.DRAFT, + target_state=submit_transition.target_state, + ) + + self.assertTrue(submit_transition.validate_transition(context)) + submit_data = submit_transition.transition(context) + + self.assertEqual(submit_data['annotator_confidence'], 0.9) + self.assertEqual(submit_data['annotation_time_seconds'], 300) + self.assertTrue(submit_data['review_requested']) + self.assertEqual(submit_data['annotation_complexity'], 1) # Based on mock result + + # Step 2: Review and approve + mock_submission_state = Mock() + mock_submission_state.context_data = submit_data + + review_transition = ReviewAndApproveAnnotation( + reviewer_decision='approve', + quality_score=0.85, + review_comments='High quality annotation with good coverage', + corrections_made=False, + ) + + context = TransitionContext( + entity=self.annotation, + current_user=self.user, + current_state=AnnotationStateChoices.SUBMITTED, + current_state_object=mock_submission_state, + target_state=review_transition.target_state, + ) + + self.assertTrue(review_transition.validate_transition(context)) + self.assertEqual(review_transition.target_state, AnnotationStateChoices.COMPLETED) + + review_data = review_transition.transition(context) + + self.assertEqual(review_data['reviewer_decision'], 'approve') + self.assertEqual(review_data['quality_score'], 0.85) + self.assertEqual(review_data['original_confidence'], 0.9) + self.assertAlmostEqual(review_data['confidence_vs_quality_diff'], 0.05, places=2) + + # Test rejection scenario + reject_transition = ReviewAndApproveAnnotation( + reviewer_decision='reject', + quality_score=0.3, + review_comments='Insufficient annotation quality', + corrections_made=False, + ) + + self.assertEqual(reject_transition.target_state, AnnotationStateChoices.DRAFT) + + # Test validation failure + invalid_review = ReviewAndApproveAnnotation( + reviewer_decision='approve', # Trying to approve + quality_score=0.5, # But quality too low + review_comments='Test', + ) + + with self.assertRaises(TransitionValidationError) as cm: + invalid_review.validate_transition(context) + + self.assertIn('Cannot approve annotation with low quality score', str(cm.exception)) + + @patch('fsm.transition_utils.execute_transition') + def test_transition_builder_with_django_models(self, mock_execute): + """ + INTEGRATION TEST: TransitionBuilder with Django model integration + Shows how to use the fluent TransitionBuilder interface with + real Django models and complex business logic. + """ + + @register_transition('task', 'bulk_update_status') + class BulkUpdateTaskStatusTransition(BaseTransition): + """Bulk update task status with metadata""" + + new_status: str = Field(..., description='New status for tasks') + update_reason: str = Field(..., description='Reason for bulk update') + updated_by_system: bool = Field(False, description='Whether updated by automated system') + batch_id: str = Field(None, description='Batch operation ID') + + @property + def target_state(self) -> str: + return self.new_status + + def validate_transition(self, context: TransitionContext) -> bool: + valid_statuses = [TaskStateChoices.CREATED, TaskStateChoices.IN_PROGRESS, TaskStateChoices.COMPLETED] + if self.new_status not in valid_statuses: + raise TransitionValidationError(f'Invalid status: {self.new_status}') + + # Can't bulk update to the same status + if context.current_state == self.new_status: + raise TransitionValidationError('Cannot update to the same status') + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + 'new_status': self.new_status, + 'update_reason': self.update_reason, + 'updated_by_system': self.updated_by_system, + 'batch_id': self.batch_id, + 'bulk_update_timestamp': context.timestamp.isoformat(), + 'previous_status': context.current_state, + } + + # Mock successful execution + mock_state_record = Mock() + mock_state_record.id = 'mock-uuid' + mock_execute.return_value = mock_state_record + + # Test fluent interface + result = ( + TransitionBuilder(self.task) + .transition('bulk_update_status') + .with_data( + new_status=TaskStateChoices.IN_PROGRESS, + update_reason='Project priority change', + updated_by_system=True, + batch_id='batch_2024_001', + ) + .by_user(self.user) + .with_context(project_update=True, notification_level='high') + .execute() + ) + + # Verify the call + mock_execute.assert_called_once() + call_args, call_kwargs = mock_execute.call_args + + # Check call parameters + self.assertEqual(call_kwargs['entity'], self.task) + self.assertEqual(call_kwargs['transition_name'], 'bulk_update_status') + self.assertEqual(call_kwargs['user'], self.user) + + # Check transition data + transition_data = call_kwargs['transition_data'] + self.assertEqual(transition_data['new_status'], TaskStateChoices.IN_PROGRESS) + self.assertEqual(transition_data['update_reason'], 'Project priority change') + self.assertTrue(transition_data['updated_by_system']) + self.assertEqual(transition_data['batch_id'], 'batch_2024_001') + + # Check context + self.assertTrue(call_kwargs['project_update']) + self.assertEqual(call_kwargs['notification_level'], 'high') + + # Check return value + self.assertEqual(result, mock_state_record) + + def test_error_handling_with_django_models(self): + """ + INTEGRATION TEST: Error handling with Django model validation + Tests comprehensive error handling scenarios that might occur + in real Django model integration. + """ + + @register_transition('task', 'assign_with_constraints') + class AssignTaskWithConstraints(BaseTransition): + """Task assignment with business constraints""" + + assignee_id: int = Field(..., description='User to assign to') + max_concurrent_tasks: int = Field(5, description='Max concurrent tasks per user') + skill_requirements: list = Field(default_factory=list, description='Required skills') + + @property + def target_state(self) -> str: + return TaskStateChoices.IN_PROGRESS + + def validate_transition(self, context: TransitionContext) -> bool: + errors = [] + + # Mock database checks (in real scenario, these would be actual queries) + + # 1. Check user exists and is active + if self.assignee_id <= 0: + errors.append('Invalid user ID') + + # 2. Check user's current task load + if self.max_concurrent_tasks < 1: + errors.append('Max concurrent tasks must be at least 1') + + # 3. Check skill requirements + if self.skill_requirements: + # Mock skill validation + available_skills = ['python', 'labeling', 'review'] + missing_skills = [skill for skill in self.skill_requirements if skill not in available_skills] + if missing_skills: + errors.append(f'Missing required skills: {missing_skills}') + + # 4. Check project-level constraints + if hasattr(context.entity, 'project_id'): + # Mock project validation + if context.entity.project_id <= 0: + errors.append('Invalid project configuration') + + # 5. Check organization permissions + if hasattr(context.entity, 'organization_id'): + if not context.current_user: + errors.append('User authentication required for assignment') + + if errors: + raise TransitionValidationError( + f"Assignment validation failed: {'; '.join(errors)}", + { + 'validation_errors': errors, + 'assignee_id': self.assignee_id, + 'task_id': context.entity.pk, + 'skill_requirements': self.skill_requirements, + }, + ) + + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return { + 'assignee_id': self.assignee_id, + 'max_concurrent_tasks': self.max_concurrent_tasks, + 'skill_requirements': self.skill_requirements, + 'assignment_validated': True, + } + + # Test successful validation + valid_transition = AssignTaskWithConstraints( + assignee_id=123, max_concurrent_tasks=3, skill_requirements=['python', 'labeling'] + ) + + context = TransitionContext( + entity=self.task, + current_user=self.user, + current_state=TaskStateChoices.CREATED, + target_state=valid_transition.target_state, + ) + + self.assertTrue(valid_transition.validate_transition(context)) + + # Test multiple validation errors + invalid_transition = AssignTaskWithConstraints( + assignee_id=-1, # Invalid user ID + max_concurrent_tasks=0, # Invalid max tasks + skill_requirements=['nonexistent_skill'], # Missing skill + ) + + with self.assertRaises(TransitionValidationError) as cm: + invalid_transition.validate_transition(context) + + error = cm.exception + error_msg = str(error) + + # Check all validation errors are included + self.assertIn('Invalid user ID', error_msg) + self.assertIn('Max concurrent tasks must be at least 1', error_msg) + self.assertIn('Missing required skills', error_msg) + + # Check error context + self.assertIn('validation_errors', error.context) + self.assertEqual(len(error.context['validation_errors']), 3) + self.assertEqual(error.context['assignee_id'], -1) + + # Test authentication requirement + context_no_user = TransitionContext( + entity=self.task, + current_user=None, # No user + current_state=TaskStateChoices.CREATED, + target_state=valid_transition.target_state, + ) + + with self.assertRaises(TransitionValidationError) as cm: + valid_transition.validate_transition(context_no_user) + + self.assertIn('User authentication required', str(cm.exception)) From b80f52369f72bf63c1875ee63b5ec67de432e82e Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 17:07:16 -0500 Subject: [PATCH 27/83] fix imports --- label_studio/fsm/transitions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/label_studio/fsm/transitions.py b/label_studio/fsm/transitions.py index cbc36936f940..1c23c09f456a 100644 --- a/label_studio/fsm/transitions.py +++ b/label_studio/fsm/transitions.py @@ -6,15 +6,18 @@ functionality for enhanced declarative state management. """ +import typing from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Dict, Generic, Optional, TypeVar from django.contrib.auth import get_user_model from django.db.models import Model -from fsm.models import BaseState from pydantic import BaseModel, ConfigDict, Field +if typing.TYPE_CHECKING: + from fsm.models import BaseState + User = get_user_model() # Type variables for generic transition context From 9b5cd8d9bbbeb867ccbaa06f8b9239e2d7913329 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 17:16:18 -0500 Subject: [PATCH 28/83] fix imports --- label_studio/fsm/registry.py | 35 +++++++++++++++++++-------------- label_studio/fsm/transitions.py | 5 +---- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/label_studio/fsm/registry.py b/label_studio/fsm/registry.py index fdf6e0756503..4667ffa7d433 100644 --- a/label_studio/fsm/registry.py +++ b/label_studio/fsm/registry.py @@ -10,10 +10,15 @@ from typing import Any, Callable, Dict, Optional, Type from django.db.models import Model, TextChoices -from fsm.transitions import BaseTransition, StateModelType, TransitionContext, User if typing.TYPE_CHECKING: from fsm.models import BaseState + from fsm.transitions import BaseTransition, StateModelType, TransitionContext, User +else: + from fsm.transitions import BaseTransition, TransitionContext, User + + # Import StateModelType at runtime to avoid circular import + StateModelType = None logger = logging.getLogger(__name__) @@ -112,14 +117,14 @@ class StateModelRegistry: """ def __init__(self): - self._models: Dict[str, Type[BaseState]] = {} + self._models: Dict[str, Type['BaseState']] = {} self._denormalizers: Dict[str, Callable[[Model], Dict[str, Any]]] = {} self._initialized = False def register_model( self, entity_name: str, - state_model: Type[BaseState], + state_model: Type['BaseState'], denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None, ): """ @@ -145,7 +150,7 @@ def register_model( logger.debug(f'Registered state model for {entity_key}: {state_model.__name__}') - def get_model(self, entity_name: str) -> Optional[Type[BaseState]]: + def get_model(self, entity_name: str) -> Optional[Type['BaseState']]: """ Get the state model for an entity type. @@ -201,7 +206,7 @@ def clear(self): self._initialized = False logger.debug('Cleared state model registry') - def get_all_models(self) -> Dict[str, Type[BaseState]]: + def get_all_models(self) -> Dict[str, Type['BaseState']]: """Get all registered models.""" return self._models.copy() @@ -220,7 +225,7 @@ def is_initialized(self) -> bool: def register_state_model( - entity_name: str, state_model: Type[BaseState], denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None + entity_name: str, state_model: Type['BaseState'], denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None ): """ Convenience function to register a state model. @@ -233,7 +238,7 @@ def register_state_model( state_model_registry.register_model(entity_name, state_model, denormalizer) -def get_state_model(entity_name: str) -> Optional[Type[BaseState]]: +def get_state_model(entity_name: str) -> Optional[Type['BaseState']]: """ Convenience function to get a state model. @@ -246,7 +251,7 @@ def get_state_model(entity_name: str) -> Optional[Type[BaseState]]: return state_model_registry.get_model(entity_name) -def get_state_model_for_entity(entity: Model) -> Optional[Type[BaseState]]: +def get_state_model_for_entity(entity: Model) -> Optional[Type['BaseState']]: """Get the state model for an entity.""" entity_name = entity._meta.model_name.lower() return get_state_model(entity_name) @@ -261,9 +266,9 @@ class TransitionRegistry: """ def __init__(self): - self._transitions: Dict[str, Dict[str, Type[BaseTransition]]] = {} + self._transitions: Dict[str, Dict[str, Type['BaseTransition']]] = {} - def register(self, entity_name: str, transition_name: str, transition_class: Type[BaseTransition]): + def register(self, entity_name: str, transition_name: str, transition_class: Type['BaseTransition']): """ Register a transition class for an entity. @@ -277,7 +282,7 @@ def register(self, entity_name: str, transition_name: str, transition_class: Typ self._transitions[entity_name][transition_name] = transition_class - def get_transition(self, entity_name: str, transition_name: str) -> Optional[Type[BaseTransition]]: + def get_transition(self, entity_name: str, transition_name: str) -> Optional[Type['BaseTransition']]: """ Get a registered transition class. @@ -290,7 +295,7 @@ def get_transition(self, entity_name: str, transition_name: str) -> Optional[Typ """ return self._transitions.get(entity_name, {}).get(transition_name) - def get_transitions_for_entity(self, entity_name: str) -> Dict[str, Type[BaseTransition]]: + def get_transitions_for_entity(self, entity_name: str) -> Dict[str, Type['BaseTransition']]: """ Get all registered transitions for an entity type. @@ -320,9 +325,9 @@ def execute_transition( transition_name: str, entity: Model, transition_data: Dict[str, Any], - user: Optional[User] = None, + user: Optional['User'] = None, **context_kwargs, - ) -> StateModelType: + ) -> 'BaseState': """ Execute a registered transition. @@ -387,7 +392,7 @@ class StartTaskTransition(BaseTransition[Task, TaskState]): # ... implementation """ - def decorator(transition_class: Type[BaseTransition]) -> Type[BaseTransition]: + def decorator(transition_class: Type['BaseTransition']) -> Type['BaseTransition']: name = transition_name if name is None: # Generate name from class name diff --git a/label_studio/fsm/transitions.py b/label_studio/fsm/transitions.py index 1c23c09f456a..cbc36936f940 100644 --- a/label_studio/fsm/transitions.py +++ b/label_studio/fsm/transitions.py @@ -6,18 +6,15 @@ functionality for enhanced declarative state management. """ -import typing from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Dict, Generic, Optional, TypeVar from django.contrib.auth import get_user_model from django.db.models import Model +from fsm.models import BaseState from pydantic import BaseModel, ConfigDict, Field -if typing.TYPE_CHECKING: - from fsm.models import BaseState - User = get_user_model() # Type variables for generic transition context From 08f9908de15bd8deca0292084c3a029c3b1c3138 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 17:20:48 -0500 Subject: [PATCH 29/83] fix typing and decorator definition for state model --- label_studio/fsm/registry.py | 25 +++++++++++++++++++++++-- label_studio/fsm/transitions.py | 15 ++++++++++----- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/label_studio/fsm/registry.py b/label_studio/fsm/registry.py index 4667ffa7d433..db93b822d7c5 100644 --- a/label_studio/fsm/registry.py +++ b/label_studio/fsm/registry.py @@ -224,11 +224,32 @@ def is_initialized(self) -> bool: state_model_registry = StateModelRegistry() -def register_state_model( +def register_state_model(entity_name: str, denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None): + """ + Decorator to register a state model. + + Args: + entity_name: Name of the entity (e.g., 'task', 'annotation') + denormalizer: Optional function to extract denormalized fields + + Example: + @register_state_model('task') + class TaskState(BaseState): + # ... implementation + """ + + def decorator(state_model: Type['BaseState']) -> Type['BaseState']: + state_model_registry.register_model(entity_name, state_model, denormalizer) + return state_model + + return decorator + + +def register_state_model_class( entity_name: str, state_model: Type['BaseState'], denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None ): """ - Convenience function to register a state model. + Convenience function to register a state model programmatically. Args: entity_name: Name of the entity (e.g., 'task', 'annotation') diff --git a/label_studio/fsm/transitions.py b/label_studio/fsm/transitions.py index cbc36936f940..340979f929fa 100644 --- a/label_studio/fsm/transitions.py +++ b/label_studio/fsm/transitions.py @@ -8,18 +8,23 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, TypeVar from django.contrib.auth import get_user_model from django.db.models import Model -from fsm.models import BaseState from pydantic import BaseModel, ConfigDict, Field User = get_user_model() -# Type variables for generic transition context -EntityType = TypeVar('EntityType', bound=Model) -StateModelType = TypeVar('StateModelType', bound=BaseState) +if TYPE_CHECKING: + from fsm.models import BaseState + + # Type variables for generic transition context + EntityType = TypeVar('EntityType', bound=Model) + StateModelType = TypeVar('StateModelType', bound=BaseState) +else: + EntityType = TypeVar('EntityType') + StateModelType = TypeVar('StateModelType') class TransitionContext(BaseModel, Generic[EntityType, StateModelType]): From 56f0014d31bb307f6f1bb56a9a19f79aae05d503 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 17:21:22 -0500 Subject: [PATCH 30/83] adding migrations --- label_studio/fsm/migrations/0001_initial.py | 384 ++++++++++++++++++++ 1 file changed, 384 insertions(+) create mode 100644 label_studio/fsm/migrations/0001_initial.py diff --git a/label_studio/fsm/migrations/0001_initial.py b/label_studio/fsm/migrations/0001_initial.py new file mode 100644 index 000000000000..d09e6a4769c7 --- /dev/null +++ b/label_studio/fsm/migrations/0001_initial.py @@ -0,0 +1,384 @@ +# Generated by Django 5.1.10 on 2025-08-27 22:19 + +import django.db.models.deletion +import fsm.utils +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ("projects", "0030_project_search_vector_index"), + ("tasks", "0057_annotation_proj_result_octlen_idx_async"), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="AnnotationState", + fields=[ + ( + "id", + models.UUIDField( + default=fsm.utils.generate_uuid7, + editable=False, + help_text="UUID7 provides natural time ordering and global uniqueness", + primary_key=True, + serialize=False, + ), + ), + ( + "organization_id", + models.PositiveIntegerField( + blank=True, + db_index=True, + help_text="Organization ID that owns this state record (for multi-tenant applications)", + null=True, + ), + ), + ( + "previous_state", + models.CharField( + blank=True, + help_text="Previous state before this transition", + max_length=50, + null=True, + ), + ), + ( + "transition_name", + models.CharField( + blank=True, + help_text="Name of the transition method that triggered this state change", + max_length=100, + null=True, + ), + ), + ( + "context_data", + models.JSONField( + default=dict, + help_text="Additional context data for this transition (e.g., validation results, external IDs)", + ), + ), + ( + "reason", + models.TextField( + blank=True, + help_text="Human-readable reason for this state transition", + ), + ), + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + help_text="Human-readable timestamp for debugging (UUID7 id contains precise timestamp)", + ), + ), + ( + "state", + models.CharField( + choices=[ + ("SUBMITTED", "Submitted"), + ("COMPLETED", "Completed"), + ], + db_index=True, + max_length=50, + ), + ), + ( + "task_id", + models.PositiveIntegerField( + db_index=True, + help_text="From annotation.task_id - denormalized for performance", + ), + ), + ( + "project_id", + models.PositiveIntegerField( + db_index=True, + help_text="From annotation.task.project_id - denormalized for performance", + ), + ), + ( + "completed_by_id", + models.PositiveIntegerField( + db_index=True, + help_text="From annotation.completed_by_id - denormalized for performance", + null=True, + ), + ), + ( + "annotation", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="fsm_states", + to="tasks.annotation", + ), + ), + ( + "triggered_by", + models.ForeignKey( + help_text="User who triggered this state transition", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["-id"], + "indexes": [ + models.Index( + fields=["annotation_id", "-id"], name="anno_current_state_idx" + ), + models.Index( + fields=["task_id", "state", "-id"], name="anno_task_state_idx" + ), + models.Index( + fields=["completed_by_id", "state", "-id"], + name="anno_user_report_idx", + ), + models.Index( + fields=["project_id", "state", "-id"], + name="anno_project_report_idx", + ), + ], + }, + ), + migrations.CreateModel( + name="ProjectState", + fields=[ + ( + "id", + models.UUIDField( + default=fsm.utils.generate_uuid7, + editable=False, + help_text="UUID7 provides natural time ordering and global uniqueness", + primary_key=True, + serialize=False, + ), + ), + ( + "organization_id", + models.PositiveIntegerField( + blank=True, + db_index=True, + help_text="Organization ID that owns this state record (for multi-tenant applications)", + null=True, + ), + ), + ( + "previous_state", + models.CharField( + blank=True, + help_text="Previous state before this transition", + max_length=50, + null=True, + ), + ), + ( + "transition_name", + models.CharField( + blank=True, + help_text="Name of the transition method that triggered this state change", + max_length=100, + null=True, + ), + ), + ( + "context_data", + models.JSONField( + default=dict, + help_text="Additional context data for this transition (e.g., validation results, external IDs)", + ), + ), + ( + "reason", + models.TextField( + blank=True, + help_text="Human-readable reason for this state transition", + ), + ), + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + help_text="Human-readable timestamp for debugging (UUID7 id contains precise timestamp)", + ), + ), + ( + "state", + models.CharField( + choices=[ + ("CREATED", "Created"), + ("IN_PROGRESS", "In Progress"), + ("COMPLETED", "Completed"), + ], + db_index=True, + max_length=50, + ), + ), + ( + "created_by_id", + models.PositiveIntegerField( + db_index=True, + help_text="From project.created_by_id - denormalized for performance", + null=True, + ), + ), + ( + "project", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="fsm_states", + to="projects.project", + ), + ), + ( + "triggered_by", + models.ForeignKey( + help_text="User who triggered this state transition", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["-id"], + "indexes": [ + models.Index( + fields=["project_id", "-id"], name="project_current_state_idx" + ), + models.Index( + fields=["organization_id", "state", "-id"], + name="project_org_state_idx", + ), + models.Index( + fields=["organization_id", "-id"], + name="project_org_reporting_idx", + ), + ], + }, + ), + migrations.CreateModel( + name="TaskState", + fields=[ + ( + "id", + models.UUIDField( + default=fsm.utils.generate_uuid7, + editable=False, + help_text="UUID7 provides natural time ordering and global uniqueness", + primary_key=True, + serialize=False, + ), + ), + ( + "organization_id", + models.PositiveIntegerField( + blank=True, + db_index=True, + help_text="Organization ID that owns this state record (for multi-tenant applications)", + null=True, + ), + ), + ( + "previous_state", + models.CharField( + blank=True, + help_text="Previous state before this transition", + max_length=50, + null=True, + ), + ), + ( + "transition_name", + models.CharField( + blank=True, + help_text="Name of the transition method that triggered this state change", + max_length=100, + null=True, + ), + ), + ( + "context_data", + models.JSONField( + default=dict, + help_text="Additional context data for this transition (e.g., validation results, external IDs)", + ), + ), + ( + "reason", + models.TextField( + blank=True, + help_text="Human-readable reason for this state transition", + ), + ), + ( + "created_at", + models.DateTimeField( + auto_now_add=True, + help_text="Human-readable timestamp for debugging (UUID7 id contains precise timestamp)", + ), + ), + ( + "state", + models.CharField( + choices=[ + ("CREATED", "Created"), + ("IN_PROGRESS", "In Progress"), + ("COMPLETED", "Completed"), + ], + db_index=True, + max_length=50, + ), + ), + ( + "project_id", + models.PositiveIntegerField( + db_index=True, + help_text="From task.project_id - denormalized for performance", + ), + ), + ( + "task", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="fsm_states", + to="tasks.task", + ), + ), + ( + "triggered_by", + models.ForeignKey( + help_text="User who triggered this state transition", + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "ordering": ["-id"], + "indexes": [ + models.Index( + fields=["task_id", "-id"], name="task_current_state_idx" + ), + models.Index( + fields=["project_id", "state", "-id"], + name="task_project_state_idx", + ), + models.Index( + fields=["organization_id", "state", "-id"], + name="task_org_reporting_idx", + ), + models.Index(fields=["task_id", "id"], name="task_history_idx"), + ], + }, + ), + ] From 5dc5d808bca796eba858e95220fe5baca9c520f2 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 17:29:30 -0500 Subject: [PATCH 31/83] removing unused code --- label_studio/fsm/extension.py | 60 ----------------------------------- 1 file changed, 60 deletions(-) delete mode 100644 label_studio/fsm/extension.py diff --git a/label_studio/fsm/extension.py b/label_studio/fsm/extension.py deleted file mode 100644 index f8dd243dc9c9..000000000000 --- a/label_studio/fsm/extension.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Minimal extension hooks for Label Studio FSM. -""" - -import logging - -logger = logging.getLogger(__name__) - - -class BaseFSMExtension: - """ - Minimal base class for FSM extensions. - - This provides the interface that extensions should implement. - """ - - @classmethod - def initialize(cls): - """Initialize the extension.""" - pass - - @classmethod - def register_models(cls): - """Register state models with the core FSM system.""" - pass - - @classmethod - def register_choices(cls): - """Register state choices with the core FSM system.""" - pass - - @classmethod - def get_state_manager(cls): - """Get the state manager class for this extension.""" - from .state_manager import StateManager - - return StateManager - - -# Extension registry for compatibility -class ExtensionRegistry: - """ - Extension registry for core Label Studio. - """ - - def __init__(self): - self._extensions = {} - - def register_extension(self, name: str, extension_class): - """Register an extension.""" - self._extensions[name] = extension_class - logger.debug(f'Registered FSM extension: {name}') - - def get_extension(self, name: str): - """Get a registered extension by name.""" - return self._extensions.get(name) - - -# Global minimal registry -extension_registry = ExtensionRegistry() From 73b8e4ed21829f392d4af3ba0a7b6923a17b3a1a Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 27 Aug 2025 17:41:51 -0500 Subject: [PATCH 32/83] moving settings from extensions to settings file --- label_studio/core/settings/base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/label_studio/core/settings/base.py b/label_studio/core/settings/base.py index a39b8fe3c3d8..6cacd50cc229 100644 --- a/label_studio/core/settings/base.py +++ b/label_studio/core/settings/base.py @@ -889,3 +889,9 @@ def collect_versions_dummy(**kwargs): # Data Manager # Max number of users to display in the Data Manager in Annotators/Reviewers/Comment Authors, etc DM_MAX_USERS_TO_DISPLAY = int(get_env('DM_MAX_USERS_TO_DISPLAY', 10)) + +# Base FSM (Finite State Machine) Configuration for Label Studio +FSM_CACHE_TTL = 300 # Cache TTL in seconds (5 minutes) +FSM_ENABLE_BULK_OPERATIONS = False +FSM_CACHE_STATS_ENABLED = False +FSM_AUTO_CREATE_STATES = False From dd9dcfafcf0a31aaa914b38b13dc0ef45ec9a5af Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 28 Aug 2025 07:10:05 -0500 Subject: [PATCH 33/83] moving the API integration FSM out of this PR --- label_studio/fsm/api.py | 235 --------------------------------------- label_studio/fsm/urls.py | 19 ---- 2 files changed, 254 deletions(-) delete mode 100644 label_studio/fsm/api.py delete mode 100644 label_studio/fsm/urls.py diff --git a/label_studio/fsm/api.py b/label_studio/fsm/api.py deleted file mode 100644 index d8ed952159df..000000000000 --- a/label_studio/fsm/api.py +++ /dev/null @@ -1,235 +0,0 @@ -""" -Core FSM API endpoints. - -Provides generic API endpoints for state management that can be extended -for any application using the FSM framework. -""" - -import logging - -from django.http import Http404 -from django.shortcuts import get_object_or_404 -from fsm.registry import get_state_model_for_entity -from fsm.serializers import StateHistorySerializer, StateTransitionSerializer -from fsm.state_manager import get_state_manager -from rest_framework import status, viewsets -from rest_framework.decorators import action -from rest_framework.permissions import IsAuthenticated -from rest_framework.response import Response - -logger = logging.getLogger(__name__) - - -class FSMViewSet(viewsets.ViewSet): - """ - Core FSM API endpoints. - - Provides basic state management operations: - - Get current state - - Get state history - - Trigger state transitions - """ - - permission_classes = [IsAuthenticated] - - def _get_entity_and_state_model(self, entity_type: str, entity_id: int): - """Helper to get entity instance and its state model""" - # Get the Django model class for the entity type - entity_model = self._get_entity_model(entity_type) - if not entity_model: - raise Http404(f'Unknown entity type: {entity_type}') - - # Get the entity instance - entity = get_object_or_404(entity_model, pk=entity_id) - - # Get the state model for this entity - state_model = get_state_model_for_entity(entity) - if not state_model: - raise Http404(f'No state model found for entity type: {entity_type}') - - return entity, state_model - - def _get_entity_model(self, entity_type: str): - """ - Get Django model class for entity type. - - This method should be overridden by subclasses to provide - application-specific entity type mappings. - - Example: - entity_mapping = { - 'order': 'shop.Order', - 'ticket': 'support.Ticket', - } - """ - from django.apps import apps - - # Default empty mapping - override in subclasses - entity_mapping = {} - - model_path = entity_mapping.get(entity_type.lower()) - if not model_path: - return None - - app_label, model_name = model_path.split('.') - return apps.get_model(app_label, model_name) - - @action(detail=False, methods=['get'], url_path=r'(?P\w+)/(?P\d+)/current') - def current_state(self, request, entity_type=None, entity_id=None): - """ - Get current state for an entity. - - GET /api/fsm/{entity_type}/{entity_id}/current/ - - Returns: - { - "current_state": "IN_PROGRESS", - "entity_type": "task", - "entity_id": 123 - } - """ - # Let Http404 from _get_entity_and_state_model pass through - entity, state_model = self._get_entity_and_state_model(entity_type, int(entity_id)) - - try: - # Get current state using the configured state manager - StateManager = get_state_manager() - current_state = StateManager.get_current_state(entity) - - return Response( - { - 'current_state': current_state, - 'entity_type': entity_type, - 'entity_id': int(entity_id), - } - ) - - except Exception as e: - logger.error(f'Error getting current state for {entity_type} {entity_id}: {e}') - return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=False, methods=['get'], url_path=r'(?P\w+)/(?P\d+)/history') - def state_history(self, request, entity_type=None, entity_id=None): - """ - Get state history for an entity. - - GET /api/fsm/{entity_type}/{entity_id}/history/ - - Query parameters: - - limit: Maximum number of history records (default: 100) - - include_context: Include context_data in response (default: false) - - Returns: - { - "count": 5, - "results": [ - { - "id": "uuid7-id", - "state": "COMPLETED", - "previous_state": "IN_PROGRESS", - "transition_name": "complete_task", - "triggered_by": "user@example.com", - "created_at": "2024-01-15T10:30:00Z", - "reason": "Task completed by user", - "context_data": {...} // if include_context=true - }, - ... - ] - } - """ - # Let Http404 from _get_entity_and_state_model pass through - entity, state_model = self._get_entity_and_state_model(entity_type, int(entity_id)) - - try: - # Get query parameters - limit = min(int(request.query_params.get('limit', 100)), 1000) # Max 1000 - include_context = request.query_params.get('include_context', 'false').lower() == 'true' - - # Get state history using the configured state manager - StateManager = get_state_manager() - history = StateManager.get_state_history(entity, limit) - - # Serialize the results - serializer = StateHistorySerializer(history, many=True, context={'include_context': include_context}) - - return Response( - { - 'count': len(history), - 'results': serializer.data, - } - ) - - except Exception as e: - logger.error(f'Error getting state history for {entity_type} {entity_id}: {e}') - return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) - - @action(detail=False, methods=['post'], url_path=r'(?P\w+)/(?P\d+)/transition') - def transition_state(self, request, entity_type=None, entity_id=None): - """ - Trigger a state transition for an entity. - - POST /api/fsm/{entity_type}/{entity_id}/transition/ - - Request body: - { - "new_state": "COMPLETED", - "transition_name": "complete_task", // optional - "reason": "Task completed by user", // optional - "context": { // optional - "assignment_id": 456 - } - } - - Returns: - { - "success": true, - "previous_state": "IN_PROGRESS", - "new_state": "COMPLETED", - "entity_type": "task", - "entity_id": 123 - } - """ - # Let Http404 from _get_entity_and_state_model pass through - entity, state_model = self._get_entity_and_state_model(entity_type, int(entity_id)) - - try: - # Validate request data - serializer = StateTransitionSerializer(data=request.data) - serializer.is_valid(raise_exception=True) - - data = serializer.validated_data - new_state = data['new_state'] - transition_name = data.get('transition_name') - reason = data.get('reason', '') - context = data.get('context', {}) - - # Get current state for response - StateManager = get_state_manager() - previous_state = StateManager.get_current_state(entity) - - # Perform state transition - success = StateManager.transition_state( - entity=entity, - new_state=new_state, - transition_name=transition_name, - user=request.user, - context=context, - reason=reason, - ) - - if success: - return Response( - { - 'success': True, - 'previous_state': previous_state, - 'new_state': new_state, - 'entity_type': entity_type, - 'entity_id': int(entity_id), - } - ) - else: - return Response({'error': 'State transition failed'}, status=status.HTTP_400_BAD_REQUEST) - - except Exception as e: - logger.error(f'Error transitioning state for {entity_type} {entity_id}: {e}') - return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR) diff --git a/label_studio/fsm/urls.py b/label_studio/fsm/urls.py deleted file mode 100644 index 78499a744f3a..000000000000 --- a/label_studio/fsm/urls.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Core FSM URL patterns for Label Studio. - -Provides basic URL routing for state management API -""" - -from django.urls import include, path -from rest_framework.routers import DefaultRouter - -from .api import FSMViewSet - -# Create router for FSM API endpoints -router = DefaultRouter() -router.register(r'fsm', FSMViewSet, basename='fsm') - -# Core FSM URL patterns -urlpatterns = [ - path('api/', include(router.urls)), -] From 35adf8daa0f16393f61f2925bcfc23aee8f4d247 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 28 Aug 2025 08:11:29 -0500 Subject: [PATCH 34/83] updating readme --- label_studio/fsm/README.md | 195 +++++++++++++++++++++++++++++++------ 1 file changed, 166 insertions(+), 29 deletions(-) diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md index 21aaa9ead79b..50d45044bbbd 100644 --- a/label_studio/fsm/README.md +++ b/label_studio/fsm/README.md @@ -20,8 +20,7 @@ The FSM framework provides: 1. **BaseState**: Abstract model providing UUID7-optimized state tracking 2. **StateManager**: High-performance state management with intelligent caching 3. **Transition System**: Declarative, Pydantic-based transitions with validation -4. **State Registry**: Dynamic registration system for entities and transitions -5. **API Layer**: Generic REST endpoints for state operations +4. **State Registry**: Dynamic registration system for entity states, choices and transitions ## Quick Start @@ -41,13 +40,22 @@ class OrderStateChoices(models.TextChoices): CANCELLED = 'CANCELLED', _('Cancelled') ``` -### 2. Create State Model +### 2. Create State Model with Optional Denormalizer ```python from fsm.models import BaseState from fsm.registry import register_state_model -@register_state_model('order') +# Optional: Define denormalizer for performance optimization +def denormalize_order(entity): + """Extract frequently queried fields to avoid JOINs.""" + return { + 'customer_id': entity.customer_id, + 'store_id': entity.store_id, + 'total_amount': entity.total_amount, + } + +@register_state_model('order', denormalizer=denormalize_order) class OrderState(BaseState): # Entity relationship order = models.ForeignKey('shop.Order', related_name='fsm_states', on_delete=models.CASCADE) @@ -55,8 +63,10 @@ class OrderState(BaseState): # Override state field with choices state = models.CharField(max_length=50, choices=OrderStateChoices.choices, db_index=True) - # Denormalized fields for performance + # Denormalized fields for performance (automatically populated by denormalizer) customer_id = models.PositiveIntegerField(db_index=True) + store_id = models.PositiveIntegerField(db_index=True) + total_amount = models.DecimalField(max_digits=10, decimal_places=2) class Meta: indexes = [ @@ -64,7 +74,24 @@ class OrderState(BaseState): ] ``` -### 3. Define Transitions +### 3. Alternative: Use Built-in State Model Methods + +For simpler use cases, state models can define denormalization directly: + +```python +class OrderState(BaseState): + # ... fields ... + + @classmethod + def get_denormalized_fields(cls, entity): + """Built-in method for denormalization without registry.""" + return { + 'customer_id': entity.customer_id, + 'store_id': entity.store_id, + } +``` + +### 4. Define Transitions ```python from fsm.transitions import BaseTransition @@ -91,7 +118,7 @@ class ProcessOrderTransition(BaseTransition): } ``` -### 4. Execute Transitions +### 5. Execute Transitions ```python from fsm.transition_utils import execute_transition @@ -105,7 +132,7 @@ result = execute_transition( ) ``` -### 5. Query States +### 6. Query States ```python from fsm.state_manager import get_state_manager @@ -125,6 +152,27 @@ states = StateManager.bulk_get_current_states(orders) ## Key Features +### Denormalization for Performance + +- **Avoid JOINs**: Copy frequently queried fields to state records +- **Registry-based**: Register denormalizers with state models +- **Automatic**: Fields are populated during state transitions +- **Flexible**: Use registry decorator or built-in class method + +```python +# Using registry decorator +@register_state_model('task', denormalizer=lambda t: {'project_id': t.project_id}) +class TaskState(BaseState): + project_id = models.IntegerField(db_index=True) + # ... + +# Using built-in method +class TaskState(BaseState): + @classmethod + def get_denormalized_fields(cls, entity): + return {'project_id': entity.project_id} +``` + ### UUID7 Performance Optimization - **Natural Time Ordering**: UUID7 provides chronological ordering without separate timestamp indexes @@ -137,7 +185,7 @@ states = StateManager.bulk_get_current_states(orders) - **Composable Logic**: Reusable transition classes with inheritance - **Hooks System**: Pre/post transition hooks for custom logic -### Advanced Querying +### Advanced State Manager Features ```python # Time-range queries using UUID7 @@ -147,33 +195,56 @@ recent_states = StateManager.get_states_since( since=datetime.now() - timedelta(hours=24) ) -# Bulk operations -orders = Order.objects.filter(status='active') -current_states = StateManager.bulk_get_current_states(orders) -``` - -### API Integration +# Get current state object (not just string) +current_state_obj = StateManager.get_current_state_object(order) +if current_state_obj: + print(f"State: {current_state_obj.state}") + print(f"Since: {current_state_obj.created_at}") + print(f"By: {current_state_obj.triggered_by}") -The framework provides generic REST endpoints: +# Get state history with full objects +history = StateManager.get_state_history(order, limit=10) +for state in history: + print(f"{state.state} at {state.created_at}") +# Cache management +StateManager.invalidate_cache(order) # Clear cache for entity +StateManager.warm_cache([order1, order2, order3]) # Pre-populate cache ``` -GET /api/fsm/{entity_type}/{entity_id}/current/ # Current state -GET /api/fsm/{entity_type}/{entity_id}/history/ # State history -POST /api/fsm/{entity_type}/{entity_id}/transition/ # Execute transition -``` -Extend the base viewset +### Registry System + +The FSM uses a flexible registry pattern for decoupling: ```python -from fsm.api import FSMViewSet +from fsm.registry import ( + state_model_registry, + state_choices_registry, + transition_registry, + register_state_model, + register_state_choices, + register_transition, +) -class MyFSMViewSet(FSMViewSet): - def _get_entity_model(self, entity_type: str): - entity_mapping = { - 'order': 'shop.Order', - 'ticket': 'support.Ticket', - } - # ... implementation +# Register state choices +@register_state_choices('task') +class TaskStateChoices(models.TextChoices): + # ... + +# Register state model with denormalizer +@register_state_model('task', denormalizer=denormalize_task) +class TaskState(BaseState): + # ... + +# Register transitions +@register_transition('task', 'start_task') +class StartTaskTransition(BaseTransition): + # ... + +# Access registries directly +model = state_model_registry.get_model('task') +choices = state_choices_registry.get_choices('task') +transition = transition_registry.get_transition('task', 'start_task') ``` ## Performance Characteristics @@ -184,6 +255,72 @@ class MyFSMViewSet(FSMViewSet): - **Cache Integration**: Intelligent caching with automatic invalidation - **Memory Efficiency**: Minimal memory footprint for state objects +## Transition System Features + +### Transition Context + +```python +from fsm.transitions import TransitionContext + +# Context provides rich information during transitions +context = TransitionContext( + entity=task, + current_user=user, + current_state='CREATED', + target_state='IN_PROGRESS', + organization_id=org_id, + metadata={'source': 'api', 'priority': 'high'} +) + +# Context properties +if context.is_initial_transition: + # First state for this entity + pass +if context.has_current_state: + # Entity has existing state + pass +``` + +### Transition Utilities + +```python +from fsm.transition_utils import ( + execute_transition, + get_available_transitions, + get_transition_schema, + validate_transition_data, + TransitionBuilder, +) + +# Execute a registered transition +result = execute_transition( + entity=task, + transition_name='start_task', + transition_data={'assigned_user_id': 123}, + user=request.user +) + +# Get available transitions for an entity +available = get_available_transitions(task) + +# Get JSON schema for transition (useful for APIs) +schema = get_transition_schema(StartTaskTransition) + +# Validate transition data before execution +errors = validate_transition_data(StartTaskTransition, data) + +# Use TransitionBuilder for fluent API +builder = (TransitionBuilder(task) + .transition('start_task') + .with_data(assigned_user_id=123) + .by_user(user) + .with_context(source='api')) + +errors = builder.validate() +if not errors: + state = builder.execute() +``` + ## Extension Points ### Custom State Manager From afc800fa3dc61d9446206caae5ed923f87b846e4 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 28 Aug 2025 08:26:22 -0500 Subject: [PATCH 35/83] removing unused code --- label_studio/core/settings/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/label_studio/core/settings/base.py b/label_studio/core/settings/base.py index 6cacd50cc229..6a51f27cad2f 100644 --- a/label_studio/core/settings/base.py +++ b/label_studio/core/settings/base.py @@ -892,6 +892,4 @@ def collect_versions_dummy(**kwargs): # Base FSM (Finite State Machine) Configuration for Label Studio FSM_CACHE_TTL = 300 # Cache TTL in seconds (5 minutes) -FSM_ENABLE_BULK_OPERATIONS = False -FSM_CACHE_STATS_ENABLED = False FSM_AUTO_CREATE_STATES = False From ab62c9ef7e26d59cb3206b234773b2d7a254582d Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 28 Aug 2025 08:35:01 -0500 Subject: [PATCH 36/83] renaming transition decorator for consistency --- label_studio/fsm/README.md | 10 ++++----- label_studio/fsm/registry.py | 6 ++--- .../fsm/tests/test_api_usage_examples.py | 16 +++++++------- .../fsm/tests/test_declarative_transitions.py | 22 +++++++++---------- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md index 50d45044bbbd..6925b7fc44e2 100644 --- a/label_studio/fsm/README.md +++ b/label_studio/fsm/README.md @@ -95,10 +95,10 @@ class OrderState(BaseState): ```python from fsm.transitions import BaseTransition -from fsm.registry import register_transition +from fsm.registry import register_state_transition from pydantic import Field -@register_transition('order', 'process_order') +@register_state_transition('order', 'process_order') class ProcessOrderTransition(BaseTransition): processor_id: int = Field(..., description="ID of user processing the order") priority: str = Field('normal', description="Processing priority") @@ -223,7 +223,7 @@ from fsm.registry import ( transition_registry, register_state_model, register_state_choices, - register_transition, + register_state_transition, ) # Register state choices @@ -237,7 +237,7 @@ class TaskState(BaseState): # ... # Register transitions -@register_transition('task', 'start_task') +@register_state_transition('task', 'start_task') class StartTaskTransition(BaseTransition): # ... @@ -337,7 +337,7 @@ class CustomStateManager(BaseStateManager): ### Custom Validation ```python -@register_transition('order', 'validate_payment') +@register_state_transition('order', 'validate_payment') class PaymentValidationTransition(BaseTransition): def validate_transition(self, context) -> bool: # Custom business logic diff --git a/label_studio/fsm/registry.py b/label_studio/fsm/registry.py index db93b822d7c5..cd34d0bcf9a9 100644 --- a/label_studio/fsm/registry.py +++ b/label_studio/fsm/registry.py @@ -399,16 +399,16 @@ def execute_transition( transition_registry = TransitionRegistry() -def register_transition(entity_name: str, transition_name: str = None): +def register_state_transition(entity_name: str, transition_name: str = None): """ - Decorator to register a transition class. + Decorator to register a state transition class. Args: entity_name: Name of the entity type transition_name: Name of the transition (defaults to class name in snake_case) Example: - @register_transition('task', 'start_task') + @register_state_transition('task', 'start_task') class StartTaskTransition(BaseTransition[Task, TaskState]): # ... implementation """ diff --git a/label_studio/fsm/tests/test_api_usage_examples.py b/label_studio/fsm/tests/test_api_usage_examples.py index e70fc0b11ff4..c0c94bc7b208 100644 --- a/label_studio/fsm/tests/test_api_usage_examples.py +++ b/label_studio/fsm/tests/test_api_usage_examples.py @@ -12,7 +12,7 @@ from unittest.mock import Mock from django.test import TestCase -from fsm.registry import register_transition, transition_registry +from fsm.registry import register_state_transition, transition_registry from fsm.transition_utils import ( get_transition_schema, ) @@ -53,7 +53,7 @@ def test_rest_api_task_assignment_example(self): declarative transitions with proper validation and error handling. """ - @register_transition('task', 'api_assign_task') + @register_state_transition('task', 'api_assign_task') class APITaskAssignmentTransition(BaseTransition): """Task assignment via API with comprehensive validation""" @@ -190,7 +190,7 @@ def test_json_schema_generation_for_api_docs(self): from Pydantic transition models. """ - @register_transition('annotation', 'api_submit_annotation') + @register_state_transition('annotation', 'api_submit_annotation') class APIAnnotationSubmissionTransition(BaseTransition): """Submit annotation via API with rich metadata""" @@ -294,7 +294,7 @@ def test_bulk_operations_api_pattern(self): need to be transitioned with the same or different parameters. """ - @register_transition('task', 'bulk_status_update') + @register_state_transition('task', 'bulk_status_update') class BulkStatusUpdateTransition(BaseTransition): """Bulk status update for multiple tasks""" @@ -408,7 +408,7 @@ def test_webhook_integration_pattern(self): for external notifications and integrations. """ - @register_transition('task', 'webhook_completion') + @register_state_transition('task', 'webhook_completion') class WebhookTaskCompletionTransition(BaseTransition): """Task completion with webhook notifications""" @@ -518,7 +518,7 @@ def test_api_error_handling_patterns(self): using the transition system with proper HTTP status codes and messages. """ - @register_transition('task', 'api_critical_update') + @register_state_transition('task', 'api_critical_update') class APICriticalUpdateTransition(BaseTransition): """Critical update with extensive validation""" @@ -701,7 +701,7 @@ def test_api_versioning_and_backward_compatibility(self): """ # Version 1 API - @register_transition('task', 'update_task_v1') + @register_state_transition('task', 'update_task_v1') class UpdateTaskV1Transition(BaseTransition): """Version 1 task update API""" @@ -721,7 +721,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: } # Version 2 API with additional features - @register_transition('task', 'update_task_v2') + @register_state_transition('task', 'update_task_v2') class UpdateTaskV2Transition(UpdateTaskV1Transition): """Version 2 task update API with enhanced features""" diff --git a/label_studio/fsm/tests/test_declarative_transitions.py b/label_studio/fsm/tests/test_declarative_transitions.py index 29c17f322176..d552c5ee8d65 100644 --- a/label_studio/fsm/tests/test_declarative_transitions.py +++ b/label_studio/fsm/tests/test_declarative_transitions.py @@ -14,7 +14,7 @@ from django.db import models from django.test import TestCase from django.utils.translation import gettext_lazy as _ -from fsm.registry import register_transition, transition_registry +from fsm.registry import register_state_transition, transition_registry from fsm.transition_utils import ( TransitionBuilder, get_available_transitions, @@ -67,7 +67,7 @@ def tearDown(self): def test_base_transition_class(self): """Test BaseTransition abstract functionality""" - @register_transition('test_entity', 'test_transition') + @register_state_transition('test_entity', 'test_transition') class TestTransition(BaseTransition): test_field: str = Field('default', description='Test field') @@ -120,7 +120,7 @@ def test_transition_context_properties(self): def test_transition_registry(self): """Test transition registration and retrieval""" - @register_transition('test_entity', 'test_transition') + @register_state_transition('test_entity', 'test_transition') class TestTransition(BaseTransition): @property def target_state(self) -> str: @@ -141,7 +141,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: def test_pydantic_validation(self): """Test Pydantic validation in transitions""" - @register_transition('test_entity', 'validated_transition') + @register_state_transition('test_entity', 'validated_transition') class ValidatedTransition(BaseTransition): required_field: str = Field(..., description='Required field') optional_field: int = Field(42, description='Optional field') @@ -165,7 +165,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: def test_transition_execution(self): """Test transition execution logic""" - @register_transition('test_entity', 'execution_test') + @register_state_transition('test_entity', 'execution_test') class ExecutionTestTransition(BaseTransition): value: str = Field('test', description='Test value') @@ -203,7 +203,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: def test_validation_error_handling(self): """Test transition validation error handling""" - @register_transition('test_entity', 'validation_test') + @register_state_transition('test_entity', 'validation_test') class ValidationTestTransition(BaseTransition): @property def target_state(self) -> str: @@ -237,7 +237,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: def test_transition_builder_basic(self): """Test TransitionBuilder basic functionality""" - @register_transition('test_entity', 'builder_test') + @register_state_transition('test_entity', 'builder_test') class BuilderTestTransition(BaseTransition): value: str = Field('default', description='Test value') @@ -262,7 +262,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: def test_get_available_transitions(self): """Test get_available_transitions utility""" - @register_transition('test_entity', 'available_test') + @register_state_transition('test_entity', 'available_test') class AvailableTestTransition(BaseTransition): @property def target_state(self) -> str: @@ -280,7 +280,7 @@ def test_transition_hooks(self): hook_calls = [] - @register_transition('test_entity', 'hook_test') + @register_state_transition('test_entity', 'hook_test') class HookTestTransition(BaseTransition): @property def target_state(self) -> str: @@ -327,7 +327,7 @@ def tearDown(self): def test_get_available_transitions(self): """Test getting available transitions for an entity""" - @register_transition('test_entity', 'util_test_1') + @register_state_transition('test_entity', 'util_test_1') class UtilTestTransition1(BaseTransition): @property def target_state(self) -> str: @@ -336,7 +336,7 @@ def target_state(self) -> str: def transition(self, context: TransitionContext) -> Dict[str, Any]: return {} - @register_transition('test_entity', 'util_test_2') + @register_state_transition('test_entity', 'util_test_2') class UtilTestTransition2(BaseTransition): @property def target_state(self) -> str: From 3310d7f23b6e9fa40b5b508aa6711a2db2ecf7e6 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 28 Aug 2025 09:15:39 -0500 Subject: [PATCH 37/83] fixing tests to align with refactors and renames of transition functions --- .../fsm/tests/test_declarative_transitions.py | 42 ++++----- .../fsm/tests/test_fsm_integration.py | 88 ------------------- .../tests/test_integration_django_models.py | 28 +++--- 3 files changed, 35 insertions(+), 123 deletions(-) diff --git a/label_studio/fsm/tests/test_declarative_transitions.py b/label_studio/fsm/tests/test_declarative_transitions.py index 0071704000d7..8fec13372924 100644 --- a/label_studio/fsm/tests/test_declarative_transitions.py +++ b/label_studio/fsm/tests/test_declarative_transitions.py @@ -55,8 +55,8 @@ def __init__(self, pk=1): self._meta.label_lower = 'tasks.annotation' -class TestTransition(BaseTransition): - """Test transition class""" +class SampleTransition(BaseTransition): + """Sample transition class for testing""" test_field: str optional_field: int = 42 @@ -99,7 +99,7 @@ def setUp(self): self.user.username = 'testuser' # Register test transition - transition_registry.register('task', 'test_transition', TestTransition) + transition_registry.register('task', 'test_transition', SampleTransition) def test_transition_context_creation(self): """Test creation of transition context""" @@ -128,14 +128,14 @@ def test_transition_context_initial_state(self): def test_transition_validation_success(self): """Test successful transition validation""" - transition = TestTransition(test_field='valid') + transition = SampleTransition(test_field='valid') context = TransitionContext(entity=self.task, current_state='CREATED', target_state=transition.target_state) self.assertTrue(transition.validate_transition(context)) def test_transition_validation_failure(self): """Test transition validation failure""" - transition = TestTransition(test_field='invalid') + transition = SampleTransition(test_field='invalid') context = TransitionContext(entity=self.task, current_state='CREATED', target_state=transition.target_state) with self.assertRaises(TransitionValidationError): @@ -143,7 +143,7 @@ def test_transition_validation_failure(self): def test_transition_execution(self): """Test transition data generation""" - transition = TestTransition(test_field='test_value', optional_field=100) + transition = SampleTransition(test_field='test_value', optional_field=100) context = TransitionContext(entity=self.task, current_state='CREATED', target_state=transition.target_state) result = transition.transition(context) @@ -154,8 +154,8 @@ def test_transition_execution(self): def test_transition_name_generation(self): """Test automatic transition name generation""" - transition = TestTransition(test_field='test') - self.assertEqual(transition.transition_name, 'test_transition') + transition = SampleTransition(test_field='test') + self.assertEqual(transition.transition_name, 'sample_transition') @patch('fsm.state_manager.StateManager.transition_state') @patch('fsm.state_manager.StateManager.get_current_state_object') @@ -169,7 +169,7 @@ def test_transition_execute_full_workflow(self, mock_get_state, mock_transition) mock_state_record.id = 'test-uuid' with patch('fsm.state_manager.StateManager.get_current_state_object', return_value=mock_state_record): - transition = TestTransition(test_field='test_value') + transition = SampleTransition(test_field='test_value') context = TransitionContext( entity=self.task, current_user=self.user, current_state=None, target_state=transition.target_state ) @@ -183,7 +183,7 @@ def test_transition_execute_full_workflow(self, mock_get_state, mock_transition) self.assertEqual(call_args[1]['entity'], self.task) self.assertEqual(call_args[1]['new_state'], 'TEST_STATE') - self.assertEqual(call_args[1]['transition_name'], 'test_transition') + self.assertEqual(call_args[1]['transition_name'], 'sample_transition') self.assertEqual(call_args[1]['user'], self.user) # Check context data @@ -200,15 +200,15 @@ def setUp(self): def test_transition_registration(self): """Test registering transitions""" - self.registry.register('test_entity', 'test_transition', TestTransition) + self.registry.register('test_entity', 'test_transition', SampleTransition) retrieved = self.registry.get_transition('test_entity', 'test_transition') - self.assertEqual(retrieved, TestTransition) + self.assertEqual(retrieved, SampleTransition) def test_get_transitions_for_entity(self): """Test getting all transitions for an entity""" - self.registry.register('test_entity', 'transition1', TestTransition) - self.registry.register('test_entity', 'transition2', TestTransition) + self.registry.register('test_entity', 'transition1', SampleTransition) + self.registry.register('test_entity', 'transition2', SampleTransition) transitions = self.registry.get_transitions_for_entity('test_entity') @@ -218,8 +218,8 @@ def test_get_transitions_for_entity(self): def test_list_entities(self): """Test listing registered entities""" - self.registry.register('entity1', 'transition1', TestTransition) - self.registry.register('entity2', 'transition2', TestTransition) + self.registry.register('entity1', 'transition1', SampleTransition) + self.registry.register('entity2', 'transition2', SampleTransition) entities = self.registry.list_entities() @@ -232,7 +232,7 @@ class TransitionUtilsTests(TestCase): def setUp(self): self.task = MockTask() - transition_registry.register('task', 'test_transition', TestTransition) + transition_registry.register('task', 'test_transition', SampleTransition) def test_get_available_transitions(self): """Test getting available transitions for entity""" @@ -253,7 +253,7 @@ def test_get_valid_transitions_with_invalid(self, mock_get_state): mock_get_state.return_value = None # Register an invalid transition - class InvalidTransition(TestTransition): + class InvalidTransition(SampleTransition): @classmethod def can_transition_from_state(cls, context): # This transition is never valid at the class level @@ -289,6 +289,7 @@ def test_transition_builder(self, mock_execute): self.assertEqual(call_args[1]['transition_data']['test_field'], 'builder_test') +@pytest.mark.skip(reason='example_transitions.py module not yet implemented') class ExampleTransitionIntegrationTests(TestCase): """Integration tests using the example transitions""" @@ -667,6 +668,7 @@ class PublishDocumentTransition(BaseTransition): publish_immediately: bool = Field(True, description='Publish immediately') scheduled_time: datetime = Field(None, description='Scheduled publish time') + @property def target_state(self) -> str: return 'PUBLISHED' if self.publish_immediately else 'SCHEDULED' @@ -911,10 +913,10 @@ def test_transition_context_properties(task, user): def test_pydantic_validation(): """Test Pydantic validation in transitions""" # Valid data - transition = TestTransition(test_field='valid') + transition = SampleTransition(test_field='valid') assert transition.test_field == 'valid' assert transition.optional_field == 42 # Invalid data should raise validation error with pytest.raises(Exception): # Pydantic validation error - TestTransition() # Missing required field + SampleTransition() # Missing required field diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py index 1ab7c6096a61..8857e30f0d86 100644 --- a/label_studio/fsm/tests/test_fsm_integration.py +++ b/label_studio/fsm/tests/test_fsm_integration.py @@ -11,7 +11,6 @@ from fsm.models import AnnotationState, ProjectState, TaskState from fsm.state_manager import get_state_manager from projects.models import Project -from rest_framework.test import APITestCase from tasks.models import Annotation, Task User = get_user_model() @@ -213,90 +212,3 @@ def test_get_states_in_time_range(self): # Should find both states self.assertEqual(len(states_in_range), 2) - - -class TestFSMAPI(APITestCase): - """Test FSM API endpoints""" - - def setUp(self): - self.user = User.objects.create_user(email='test@example.com', password='test123') - self.project = Project.objects.create(title='Test Project', created_by=self.user) - self.task = Task.objects.create(project=self.project, data={'text': 'test'}) - self.client.force_authenticate(user=self.user) - - # Clear cache to ensure tests start with clean state - from django.core.cache import cache - - cache.clear() - - # Create initial state - StateManager = get_state_manager() - StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) - - def test_get_current_state_api(self): - """Test GET /api/fsm/{entity_type}/{entity_id}/current/""" - response = self.client.get(f'/api/fsm/task/{self.task.id}/current/') - - self.assertEqual(response.status_code, 200) - data = response.json() - - self.assertEqual(data['current_state'], 'CREATED') - self.assertEqual(data['entity_type'], 'task') - self.assertEqual(data['entity_id'], self.task.id) - - def test_get_state_history_api(self): - """Test GET /api/fsm/{entity_type}/{entity_id}/history/""" - # Create additional states - StateManager = get_state_manager() - StateManager.transition_state( - entity=self.task, new_state='IN_PROGRESS', user=self.user, transition_name='start_work' - ) - - response = self.client.get(f'/api/fsm/task/{self.task.id}/history/') - - self.assertEqual(response.status_code, 200) - data = response.json() - - self.assertEqual(data['count'], 2) - self.assertEqual(len(data['results']), 2) - - # Check first result (most recent) - latest_state = data['results'][0] - self.assertEqual(latest_state['state'], 'IN_PROGRESS') - self.assertEqual(latest_state['previous_state'], 'CREATED') - self.assertEqual(latest_state['transition_name'], 'start_work') - - def test_transition_state_api(self): - """Test POST /api/fsm/{entity_type}/{entity_id}/transition/""" - transition_data = { - 'new_state': 'IN_PROGRESS', - 'transition_name': 'start_annotation', - 'reason': 'User started working on task', - 'context': {'assignment_id': 123}, - } - - response = self.client.post(f'/api/fsm/task/{self.task.id}/transition/', data=transition_data, format='json') - - self.assertEqual(response.status_code, 200) - data = response.json() - - self.assertTrue(data['success']) - self.assertEqual(data['previous_state'], 'CREATED') - self.assertEqual(data['new_state'], 'IN_PROGRESS') - self.assertEqual(data['entity_type'], 'task') - self.assertEqual(data['entity_id'], self.task.id) - - # Verify state was actually changed - StateManager = get_state_manager() - current_state = StateManager.get_current_state(self.task) - self.assertEqual(current_state, 'IN_PROGRESS') - - def test_api_with_invalid_entity(self): - """Test API with non-existent entity""" - response = self.client.get('/api/fsm/task/99999/current/') - self.assertEqual(response.status_code, 404) - - def test_api_with_invalid_entity_type(self): - """Test API with invalid entity type""" - response = self.client.get('/api/fsm/invalid/1/current/') - self.assertEqual(response.status_code, 404) diff --git a/label_studio/fsm/tests/test_integration_django_models.py b/label_studio/fsm/tests/test_integration_django_models.py index f0f4268dcd5b..d2026e28185e 100644 --- a/label_studio/fsm/tests/test_integration_django_models.py +++ b/label_studio/fsm/tests/test_integration_django_models.py @@ -11,9 +11,10 @@ from django.contrib.auth import get_user_model from django.test import TestCase from fsm.models import TaskState +from fsm.registry import register_state_transition, transition_registry from fsm.state_choices import AnnotationStateChoices, TaskStateChoices from fsm.transition_utils import TransitionBuilder -from fsm.transitions import BaseTransition, TransitionContext, TransitionValidationError, register_transition +from fsm.transitions import BaseTransition, TransitionContext, TransitionValidationError from pydantic import Field @@ -73,10 +74,7 @@ def setUp(self): self.user.id = 123 self.user.username = 'integration_test_user' - # Clear registry for clean test state - from fsm.transitions import transition_registry - - transition_registry._transitions.clear() + transition_registry.clear() @patch('fsm.registry.get_state_model_for_entity') @patch('fsm.state_manager.StateManager.get_current_state_object') @@ -94,7 +92,7 @@ def test_task_workflow_integration(self, mock_transition_state, mock_get_state_o mock_transition_state.return_value = True # Define task workflow transitions - @register_transition('task', 'create_task') + @register_state_transition('task', 'create_task') class CreateTaskTransition(BaseTransition): """Initial task creation transition""" @@ -120,7 +118,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: 'creation_method': 'declarative_transition', } - @register_transition('task', 'assign_and_start') + @register_state_transition('task', 'assign_and_start') class AssignAndStartTaskTransition(BaseTransition): """Assign task to user and start work""" @@ -161,7 +159,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: 'work_started': True, } - @register_transition('task', 'complete_with_quality') + @register_state_transition('task', 'complete_with_quality') class CompleteTaskWithQualityTransition(BaseTransition): """Complete task with quality metrics""" @@ -296,7 +294,7 @@ def test_annotation_review_workflow_integration(self): enterprise-grade validation and approval logic. """ - @register_transition('annotation', 'submit_for_review') + @register_state_transition('annotation', 'submit_for_review') class SubmitAnnotationForReview(BaseTransition): """Submit annotation for quality review""" @@ -332,7 +330,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: 'submitted_by_id': context.current_user.id if context.current_user else None, } - @register_transition('annotation', 'review_and_approve') + @register_state_transition('annotation', 'review_and_approve') class ReviewAndApproveAnnotation(BaseTransition): """Review annotation and approve/reject""" @@ -346,7 +344,7 @@ def target_state(self) -> str: if self.reviewer_decision == 'approve': return AnnotationStateChoices.COMPLETED else: - return AnnotationStateChoices.DRAFT # Back to draft for changes + return AnnotationStateChoices.SUBMITTED # Back to submitted for changes def validate_transition(self, context: TransitionContext) -> bool: if context.current_state != AnnotationStateChoices.SUBMITTED: @@ -394,7 +392,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: context = TransitionContext( entity=self.annotation, current_user=self.user, - current_state=AnnotationStateChoices.DRAFT, + current_state=AnnotationStateChoices.SUBMITTED, target_state=submit_transition.target_state, ) @@ -443,7 +441,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: corrections_made=False, ) - self.assertEqual(reject_transition.target_state, AnnotationStateChoices.DRAFT) + self.assertEqual(reject_transition.target_state, AnnotationStateChoices.SUBMITTED) # Test validation failure invalid_review = ReviewAndApproveAnnotation( @@ -465,7 +463,7 @@ def test_transition_builder_with_django_models(self, mock_execute): real Django models and complex business logic. """ - @register_transition('task', 'bulk_update_status') + @register_state_transition('task', 'bulk_update_status') class BulkUpdateTaskStatusTransition(BaseTransition): """Bulk update task status with metadata""" @@ -549,7 +547,7 @@ def test_error_handling_with_django_models(self): in real Django model integration. """ - @register_transition('task', 'assign_with_constraints') + @register_state_transition('task', 'assign_with_constraints') class AssignTaskWithConstraints(BaseTransition): """Task assignment with business constraints""" From c46062accdea1a7e338a2aeec18eb7ab4f3cd965 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 28 Aug 2025 09:19:22 -0500 Subject: [PATCH 38/83] removing unused code --- label_studio/fsm/models.py | 42 ------------------------------- label_studio/fsm/state_choices.py | 32 ++--------------------- 2 files changed, 2 insertions(+), 72 deletions(-) diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index d0da3c433166..f68d49ef3cbd 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -166,45 +166,3 @@ def _get_entity_field_name(cls) -> str: if model_name.endswith('State'): return model_name[:-5].lower() return 'entity' - - -# Registry for dynamic state model extension -STATE_MODEL_REGISTRY = {} - - -def register_state_model(entity_name: str, model_class): - """ - Register state model for an entity type. - - Args: - entity_name: Name of the entity (e.g., 'review', 'assignment') - model_class: Django model class inheriting from BaseState - """ - STATE_MODEL_REGISTRY[entity_name.lower()] = model_class - - -def get_state_model(entity_name: str): - """ - Get state model for an entity type. - - Args: - entity_name: Name of the entity - - Returns: - Django model class inheriting from BaseState, or None if not found - """ - return STATE_MODEL_REGISTRY.get(entity_name.lower()) - - -def get_state_model_for_entity(entity): - """ - Get state model for a specific entity instance. - - Args: - entity: Django model instance - - Returns: - Django model class inheriting from BaseState, or None if not found - """ - entity_name = entity._meta.model_name.lower() - return get_state_model(entity_name) diff --git a/label_studio/fsm/state_choices.py b/label_studio/fsm/state_choices.py index e00cd12c3e06..d601ea397023 100644 --- a/label_studio/fsm/state_choices.py +++ b/label_studio/fsm/state_choices.py @@ -1,33 +1,5 @@ """ -FSM state choices registry system. +FSM state choices -This module provides the infrastructure for registering and managing -state choices for different entity types in the FSM framework. +This module provides the state choices for different entity types in the FSM framework. """ - -# Registry for dynamic state choices extension -STATE_CHOICES_REGISTRY = {} - - -def register_state_choices(entity_name: str, choices_class): - """ - Register state choices for an entity type. - - Args: - entity_name: Name of the entity (e.g., 'order', 'ticket') - choices_class: Django TextChoices class defining valid states - """ - STATE_CHOICES_REGISTRY[entity_name.lower()] = choices_class - - -def get_state_choices(entity_name: str): - """ - Get state choices for an entity type. - - Args: - entity_name: Name of the entity - - Returns: - Django TextChoices class or None if not found - """ - return STATE_CHOICES_REGISTRY.get(entity_name.lower()) From 3f7c288782bd92bd9e0b714b9a2f984298dfe073 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 28 Aug 2025 12:41:31 -0500 Subject: [PATCH 39/83] removing serializers from this PR, will be included in a different one --- label_studio/fsm/serializers.py | 84 --------------------------------- 1 file changed, 84 deletions(-) delete mode 100644 label_studio/fsm/serializers.py diff --git a/label_studio/fsm/serializers.py b/label_studio/fsm/serializers.py deleted file mode 100644 index 823df832d3b0..000000000000 --- a/label_studio/fsm/serializers.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -Core FSM serializers for Label Studio. - -Provides basic serializers for state management API -""" - -from rest_framework import serializers - - -class StateHistorySerializer(serializers.Serializer): - """ - Serializer for state history records. - - Provides basic state history information - """ - - id = serializers.UUIDField(read_only=True) - state = serializers.CharField(read_only=True) - previous_state = serializers.CharField(read_only=True, allow_null=True) - transition_name = serializers.CharField(read_only=True, allow_null=True) - triggered_by = serializers.SerializerMethodField() - created_at = serializers.DateTimeField(read_only=True) - reason = serializers.CharField(read_only=True) - context_data = serializers.SerializerMethodField() - - def get_triggered_by(self, obj): - """Get user who triggered the transition""" - if obj.triggered_by: - return { - 'id': obj.triggered_by.id, - 'email': obj.triggered_by.email, - 'first_name': getattr(obj.triggered_by, 'first_name', ''), - 'last_name': getattr(obj.triggered_by, 'last_name', ''), - } - return None - - def get_context_data(self, obj): - """Include context data if requested""" - include_context = self.context.get('include_context', False) - if include_context: - return obj.context_data - return None - - -class StateTransitionSerializer(serializers.Serializer): - """ - Serializer for state transition requests. - - Validates state transition request data. - """ - - new_state = serializers.CharField(required=True, help_text='Target state to transition to') - transition_name = serializers.CharField( - required=False, allow_blank=True, help_text='Name of the transition method (for audit trail)' - ) - reason = serializers.CharField( - required=False, allow_blank=True, help_text='Human-readable reason for the transition' - ) - context = serializers.JSONField( - required=False, default=dict, help_text='Additional context data for the transition' - ) - - def validate_new_state(self, value): - """Validate that new_state is not empty""" - if not value or not value.strip(): - raise serializers.ValidationError('new_state cannot be empty') - return value.strip().upper() - - -class StateInfoSerializer(serializers.Serializer): - """ - Serializer for basic state information. - - Used for current state responses. - """ - - current_state = serializers.CharField(allow_null=True) - entity_type = serializers.CharField() - entity_id = serializers.IntegerField() - - available_transitions = serializers.ListField( - child=serializers.CharField(), required=False, help_text='List of valid transitions from current state' - ) - state_metadata = serializers.JSONField(required=False, help_text='Additional metadata about the current state') From 9c7b74252e604168681acc191486f2e11ce50fe6 Mon Sep 17 00:00:00 2001 From: bmartel Date: Wed, 3 Sep 2025 09:55:23 -0500 Subject: [PATCH 40/83] Update label_studio/fsm/state_manager.py Co-authored-by: Marcel Canu --- label_studio/fsm/state_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index c7cd5dd61c3f..8f9b839eec5c 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -282,7 +282,7 @@ def invalidate_cache(cls, entity: Model): @classmethod def warm_cache(cls, entities: List[Model]): """ - invalidate_cacheWarm cache with current states for a list of entities. + Warm cache with current states for a list of entities. Basic implementation that can be optimized by Enterprise with bulk queries and advanced caching strategies. From 43ecd7b7a635a410138efea6df6931d0ce01daea Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 3 Sep 2025 12:35:23 -0500 Subject: [PATCH 41/83] updating based on feedback about test assertions/file names --- label_studio/fsm/apps.py | 33 ----- ..._usage_examples.py => test_api_example.py} | 139 +++++++++--------- .../fsm/tests/test_declarative_transitions.py | 73 ++++----- .../tests/test_edge_cases_error_handling.py | 111 +++++++------- .../fsm/tests/test_performance_concurrency.py | 90 ++++++------ label_studio/fsm/tests/test_uuid7_utils.py | 44 +++--- 6 files changed, 230 insertions(+), 260 deletions(-) rename label_studio/fsm/tests/{test_api_usage_examples.py => test_api_example.py} (87%) diff --git a/label_studio/fsm/apps.py b/label_studio/fsm/apps.py index 4eb0e0720da7..fd08fc0d6da2 100644 --- a/label_studio/fsm/apps.py +++ b/label_studio/fsm/apps.py @@ -11,36 +11,3 @@ class FsmConfig(AppConfig): default_auto_field = 'django.db.models.UUIDField' name = 'fsm' verbose_name = 'Label Studio FSM' - - def ready(self): - """Initialize FSM system when Django app is ready""" - # Initialize extension system - self._initialize_extensions() - - # Set up signal handlers for automatic state creation - self._setup_signals() - - logger.info('FSM system initialized') - - def _initialize_extensions(self): - """Initialize FSM extension system""" - try: - # Import the extension registry to ensure it's initialized - - # Basic extension system is ready - logger.debug('FSM extension system ready') - - except Exception as e: - logger.error(f'Failed to initialize FSM extensions: {e}') - - def _setup_signals(self): - """Set up signal handlers for automatic state creation""" - try: - from django.conf import settings - - # Only set up signals if enabled in settings - if getattr(settings, 'FSM_AUTO_CREATE_STATES', False): - logger.info('FSM signal handlers registered') - - except Exception as e: - logger.error(f'Failed to set up FSM signals: {e}') diff --git a/label_studio/fsm/tests/test_api_usage_examples.py b/label_studio/fsm/tests/test_api_example.py similarity index 87% rename from label_studio/fsm/tests/test_api_usage_examples.py rename to label_studio/fsm/tests/test_api_example.py index c0c94bc7b208..c705a371e884 100644 --- a/label_studio/fsm/tests/test_api_usage_examples.py +++ b/label_studio/fsm/tests/test_api_example.py @@ -11,6 +11,7 @@ from typing import Any, Dict, List, Optional from unittest.mock import Mock +import pytest from django.test import TestCase from fsm.registry import register_state_transition, transition_registry from fsm.transition_utils import ( @@ -130,7 +131,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: ) # Validate - self.assertTrue(transition.validate_transition(context)) + assert transition.validate_transition(context) # Execute result_data = transition.transition(context) @@ -148,10 +149,10 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: } # Validate API response - self.assertTrue(api_response['success']) - self.assertEqual(api_response['data']['new_state'], 'ASSIGNED') - self.assertEqual(api_response['data']['assignment_details']['assignee_id'], 123) - self.assertEqual(api_response['data']['assignment_details']['priority'], 'high') + assert api_response['success'] + assert api_response['data']['new_state'] == 'ASSIGNED' + assert api_response['data']['assignment_details']['assignee_id'] == 123 + assert api_response['data']['assignment_details']['priority'] == 'high' except ValueError as e: # Handle Pydantic validation errors @@ -179,7 +180,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: 'deadline': '2020-01-01T00:00:00', # Past deadline } - with self.assertRaises(ValueError): + with pytest.raises(ValueError): APITaskAssignmentTransition(**invalid_request) def test_json_schema_generation_for_api_docs(self): @@ -228,43 +229,43 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: schema = get_transition_schema(APIAnnotationSubmissionTransition) # Validate schema structure - self.assertIn('properties', schema) - self.assertIn('required', schema) + assert 'properties' in schema + assert 'required' in schema # Check specific field schemas properties = schema['properties'] # confidence_score should have min/max constraints confidence_schema = properties['confidence_score'] - self.assertEqual(confidence_schema['type'], 'number') - self.assertEqual(confidence_schema['minimum'], 0.0) - self.assertEqual(confidence_schema['maximum'], 1.0) - self.assertIn("Annotator's confidence", confidence_schema['description']) + assert confidence_schema['type'] == 'number' + assert confidence_schema['minimum'] == 0.0 + assert confidence_schema['maximum'] == 1.0 + assert "Annotator's confidence" in confidence_schema['description'] # annotation_quality should have pattern constraint quality_schema = properties['annotation_quality'] - self.assertEqual(quality_schema['type'], 'string') - self.assertIn('pattern', quality_schema) + assert quality_schema['type'] == 'string' + assert 'pattern' in quality_schema # time_spent_seconds should have minimum constraint time_schema = properties['time_spent_seconds'] - self.assertEqual(time_schema['type'], 'integer') - self.assertEqual(time_schema['minimum'], 1) + assert time_schema['type'] == 'integer' + assert time_schema['minimum'] == 1 # tags should be array type tags_schema = properties['tags'] - self.assertEqual(tags_schema['type'], 'array') - self.assertEqual(tags_schema['items']['type'], 'string') + assert tags_schema['type'] == 'array' + assert tags_schema['items']['type'] == 'string' # metadata should be object type metadata_schema = properties['metadata'] - self.assertEqual(metadata_schema['type'], 'object') + assert metadata_schema['type'] == 'object' # Required fields required_fields = schema['required'] - self.assertIn('confidence_score', required_fields) - self.assertIn('time_spent_seconds', required_fields) - self.assertNotIn('tags', required_fields) # Optional field + assert 'confidence_score' in required_fields + assert 'time_spent_seconds' in required_fields + assert 'tags' not in required_fields # Optional field # Test schema-driven validation valid_data = { @@ -278,13 +279,13 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: } transition = APIAnnotationSubmissionTransition(**valid_data) - self.assertEqual(transition.confidence_score, 0.85) - self.assertEqual(len(transition.tags), 2) + assert transition.confidence_score == 0.85 + assert len(transition.tags) == 2 # Print schema for documentation (would be used in API docs) schema_json = json.dumps(schema, indent=2) - self.assertIsInstance(schema_json, str) - self.assertIn('confidence_score', schema_json) + assert isinstance(schema_json, str) + assert 'confidence_score' in schema_json def test_bulk_operations_api_pattern(self): """ @@ -387,18 +388,18 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: } # Validate bulk results - self.assertEqual(api_response['total_requested'], 5) - self.assertGreater(api_response['successful_updates'], 0) + assert api_response['total_requested'] == 5 + assert api_response['successful_updates'] > 0 # Some tasks should succeed, some might fail due to state validation total_processed = api_response['successful_updates'] + api_response['failed_updates'] - self.assertEqual(total_processed, 5) + assert total_processed == 5 # Check individual results for result in batch_results: - self.assertTrue(result['success']) - self.assertEqual(result['result']['new_status'], 'IN_PROGRESS') - self.assertEqual(result['result']['batch_id'], 'batch_2024_001') + assert result['success'] + assert result['result']['new_status'] == 'IN_PROGRESS' + assert result['result']['batch_id'] == 'batch_2024_001' def test_webhook_integration_pattern(self): """ @@ -485,7 +486,7 @@ def post_transition_hook(self, context: TransitionContext, state_record) -> None ) # Validate and execute - self.assertTrue(transition.validate_transition(context)) + assert transition.validate_transition(context) transition.transition(context) # Simulate state record creation @@ -496,19 +497,19 @@ def post_transition_hook(self, context: TransitionContext, state_record) -> None transition.post_transition_hook(context, mock_state_record) # Validate webhook responses - self.assertEqual(len(transition.webhook_responses), 2) + assert len(transition.webhook_responses) == 2 for response in transition.webhook_responses: - self.assertIn('url', response) - self.assertIn('payload', response) - self.assertEqual(response['status'], 'sent') + assert 'url' in response + assert 'payload' in response + assert response['status'] == 'sent' # Validate webhook payload structure payload = response['payload'] - self.assertEqual(payload['event'], 'task.completed') - self.assertEqual(payload['task_id'], self.mock_entity.pk) - self.assertEqual(payload['completion_data']['quality'], 0.95) - self.assertEqual(payload['custom_data']['project_id'], 123) + assert payload['event'] == 'task.completed' + assert payload['task_id'] == self.mock_entity.pk + assert payload['completion_data']['quality'] == 0.95 + assert payload['custom_data']['project_id'] == 123 def test_api_error_handling_patterns(self): """ @@ -654,9 +655,9 @@ def simulate_api_endpoint(request_data, current_state='IN_PROGRESS'): # Test successful request response = simulate_api_endpoint(valid_request) - self.assertEqual(response['status_code'], 200) - self.assertTrue(response['success']) - self.assertIn('update_details', response['data']) + assert response['status_code'] == 200 + assert response['success'] + assert 'update_details' in response['data'] # Test Pydantic validation error (invalid severity level) invalid_request = { @@ -666,9 +667,9 @@ def simulate_api_endpoint(request_data, current_state='IN_PROGRESS'): } response = simulate_api_endpoint(invalid_request) - self.assertEqual(response['status_code'], 400) - self.assertFalse(response['success']) - self.assertEqual(response['error'], 'Bad Request') + assert response['status_code'] == 400 + assert not response['success'] + assert response['error'] == 'Bad Request' # Test business logic validation error business_logic_error_request = { @@ -679,18 +680,18 @@ def simulate_api_endpoint(request_data, current_state='IN_PROGRESS'): } response = simulate_api_endpoint(business_logic_error_request) - self.assertEqual(response['status_code'], 422) - self.assertFalse(response['success']) - self.assertEqual(response['error'], 'Validation Failed') - self.assertIn('validation_errors', response) - self.assertGreater(len(response['validation_errors']), 0) + assert response['status_code'] == 422 + assert not response['success'] + assert response['error'] == 'Validation Failed' + assert 'validation_errors' in response + assert len(response['validation_errors']) > 0 # Test state validation error response = simulate_api_endpoint(valid_request, current_state='COMPLETED') - self.assertEqual(response['status_code'], 422) + assert response['status_code'] == 422 # The error message is in validation_errors list, not the main message validation_errors = response.get('validation_errors', []) - self.assertTrue(any('completed tasks' in error for error in validation_errors)) + assert any('completed tasks' in error for error in validation_errors) def test_api_versioning_and_backward_compatibility(self): """ @@ -754,9 +755,9 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: ) v1_result = v1_transition.transition(context) - self.assertEqual(v1_result['api_version'], 'v1') - self.assertEqual(v1_result['status'], 'IN_PROGRESS') - self.assertNotIn('priority', v1_result) # V1 doesn't have priority + assert v1_result['api_version'] == 'v1' + assert v1_result['status'] == 'IN_PROGRESS' + assert 'priority' not in v1_result # V1 doesn't have priority # Test V2 API with enhanced features v2_request = { @@ -771,12 +772,12 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: v2_transition = UpdateTaskV2Transition(**v2_request) v2_result = v2_transition.transition(context) - self.assertEqual(v2_result['api_version'], 'v2') - self.assertEqual(v2_result['status'], 'IN_PROGRESS') # Inherited from V1 - self.assertEqual(v2_result['priority'], 'high') # V2 feature - self.assertEqual(len(v2_result['tags']), 2) # V2 feature - self.assertEqual(v2_result['estimated_hours'], 4.5) # V2 feature - self.assertIn('client_id', v2_result['metadata']) # V2 feature + assert v2_result['api_version'] == 'v2' + assert v2_result['status'] == 'IN_PROGRESS' # Inherited from V1 + assert v2_result['priority'] == 'high' # V2 feature + assert len(v2_result['tags']) == 2 # V2 feature + assert v2_result['estimated_hours'] == 4.5 # V2 feature + assert 'client_id' in v2_result['metadata'] # V2 feature # Test V2 API with minimal data (backward compatible) v2_minimal_request = {'status': 'COMPLETED', 'notes': 'Task finished'} @@ -784,9 +785,9 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: v2_minimal_transition = UpdateTaskV2Transition(**v2_minimal_request) v2_minimal_result = v2_minimal_transition.transition(context) - self.assertEqual(v2_minimal_result['api_version'], 'v2') - self.assertEqual(v2_minimal_result['status'], 'COMPLETED') - self.assertIsNone(v2_minimal_result['priority']) # Optional field - self.assertEqual(v2_minimal_result['tags'], []) # Default value - self.assertIsNone(v2_minimal_result['estimated_hours']) # Optional field - self.assertEqual(v2_minimal_result['metadata'], {}) # Default value + assert v2_minimal_result['api_version'] == 'v2' + assert v2_minimal_result['status'] == 'COMPLETED' + assert v2_minimal_result['priority'] is None # Optional field + assert v2_minimal_result['tags'] == [] # Default value + assert v2_minimal_result['estimated_hours'] is None # Optional field + assert v2_minimal_result['metadata'] == {} # Default value diff --git a/label_studio/fsm/tests/test_declarative_transitions.py b/label_studio/fsm/tests/test_declarative_transitions.py index d552c5ee8d65..d0811f888a37 100644 --- a/label_studio/fsm/tests/test_declarative_transitions.py +++ b/label_studio/fsm/tests/test_declarative_transitions.py @@ -10,6 +10,7 @@ from typing import Any, Dict from unittest.mock import Mock +import pytest from django.contrib.auth import get_user_model from django.db import models from django.test import TestCase @@ -80,9 +81,9 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: # Test instantiation transition = TestTransition(test_field='test_value') - self.assertEqual(transition.test_field, 'test_value') - self.assertEqual(transition.target_state, TestStateChoices.IN_PROGRESS) - self.assertEqual(transition.transition_name, 'test_transition') + assert transition.test_field == 'test_value' + assert transition.target_state == TestStateChoices.IN_PROGRESS + assert transition.transition_name == 'test_transition' def test_transition_context(self): """Test TransitionContext functionality""" @@ -94,19 +95,19 @@ def test_transition_context(self): current_user=self.user, ) - self.assertEqual(context.entity, self.mock_entity) - self.assertEqual(context.current_state, TestStateChoices.CREATED) - self.assertEqual(context.target_state, TestStateChoices.IN_PROGRESS) - self.assertEqual(context.current_user, self.user) - self.assertTrue(context.has_current_state) - self.assertFalse(context.is_initial_transition) + assert context.entity == self.mock_entity + assert context.current_state == TestStateChoices.CREATED + assert context.target_state == TestStateChoices.IN_PROGRESS + assert context.current_user == self.user + assert context.has_current_state + assert not context.is_initial_transition def test_transition_context_properties(self): """Test TransitionContext computed properties""" # Test initial transition context = TransitionContext(entity=self.mock_entity, current_state=None, target_state=TestStateChoices.CREATED) - self.assertTrue(context.is_initial_transition) - self.assertFalse(context.has_current_state) + assert context.is_initial_transition + assert not context.has_current_state # Test with current state context_with_state = TransitionContext( @@ -114,8 +115,8 @@ def test_transition_context_properties(self): current_state=TestStateChoices.CREATED, target_state=TestStateChoices.IN_PROGRESS, ) - self.assertFalse(context_with_state.is_initial_transition) - self.assertTrue(context_with_state.has_current_state) + assert not context_with_state.is_initial_transition + assert context_with_state.has_current_state def test_transition_registry(self): """Test transition registration and retrieval""" @@ -131,12 +132,12 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: # Test registration retrieved = transition_registry.get_transition('test_entity', 'test_transition') - self.assertEqual(retrieved, TestTransition) + assert retrieved == TestTransition # Test entity transitions entity_transitions = transition_registry.get_transitions_for_entity('test_entity') - self.assertIn('test_transition', entity_transitions) - self.assertEqual(entity_transitions['test_transition'], TestTransition) + assert 'test_transition' in entity_transitions + assert entity_transitions['test_transition'] == TestTransition def test_pydantic_validation(self): """Test Pydantic validation in transitions""" @@ -155,11 +156,11 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: # Test valid instantiation transition = ValidatedTransition(required_field='test') - self.assertEqual(transition.required_field, 'test') - self.assertEqual(transition.optional_field, 42) + assert transition.required_field == 'test' + assert transition.optional_field == 42 # Test validation error - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): ValidatedTransition() # Missing required field def test_transition_execution(self): @@ -192,13 +193,13 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: ) # Test validation - self.assertTrue(transition.validate_transition(context)) + assert transition.validate_transition(context) # Test execution result = transition.transition(context) - self.assertEqual(result['value'], 'execution_test') - self.assertEqual(result['entity_id'], self.mock_entity.pk) - self.assertIn('timestamp', result) + assert result['value'] == 'execution_test' + assert result['entity_id'] == self.mock_entity.pk + assert 'timestamp' in result def test_validation_error_handling(self): """Test transition validation error handling""" @@ -227,12 +228,12 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: ) # Test validation error - with self.assertRaises(TransitionValidationError) as cm: + with pytest.raises(TransitionValidationError) as cm: transition.validate_transition(invalid_context) - error = cm.exception - self.assertIn('Can only complete from IN_PROGRESS state', str(error)) - self.assertIn('current_state', error.context) + error = cm.value + assert 'Can only complete from IN_PROGRESS state' in str(error) + assert 'current_state' in error.context def test_transition_builder_basic(self): """Test TransitionBuilder basic functionality""" @@ -250,14 +251,14 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: # Test builder creation builder = TransitionBuilder(self.mock_entity) - self.assertEqual(builder.entity, self.mock_entity) + assert builder.entity == self.mock_entity # Test method chaining builder = builder.transition('builder_test').with_data(value='builder_test_value').by_user(self.user) # Validate the builder state validation_errors = builder.validate() - self.assertEqual(len(validation_errors), 0) + assert len(validation_errors) == 0 def test_get_available_transitions(self): """Test get_available_transitions utility""" @@ -272,8 +273,8 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: return {} available = get_available_transitions(self.mock_entity) - self.assertIn('available_test', available) - self.assertEqual(available['available_test'], AvailableTestTransition) + assert 'available_test' in available + assert available['available_test'] == AvailableTestTransition def test_transition_hooks(self): """Test pre and post transition hooks""" @@ -308,7 +309,7 @@ def post_transition_hook(self, context: TransitionContext, state_record) -> None transition.transition(context) transition.post_transition_hook(context, Mock()) - self.assertEqual(hook_calls, ['pre', 'transition', 'post']) + assert hook_calls == ['pre', 'transition', 'post'] class TransitionUtilsTests(TestCase): @@ -346,12 +347,12 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: return {} available = get_available_transitions(self.mock_entity) - self.assertEqual(len(available), 2) - self.assertIn('util_test_1', available) - self.assertIn('util_test_2', available) + assert len(available) == 2 + assert 'util_test_1' in available + assert 'util_test_2' in available # Test with non-existent entity mock_other = MockEntity() mock_other._meta.model_name = 'other_entity' other_available = get_available_transitions(mock_other) - self.assertEqual(len(other_available), 0) + assert len(other_available) == 0 diff --git a/label_studio/fsm/tests/test_edge_cases_error_handling.py b/label_studio/fsm/tests/test_edge_cases_error_handling.py index 22c7c7799fa6..4f61e89a7d7f 100644 --- a/label_studio/fsm/tests/test_edge_cases_error_handling.py +++ b/label_studio/fsm/tests/test_edge_cases_error_handling.py @@ -13,6 +13,7 @@ from typing import Any, Dict from unittest.mock import Mock +import pytest from django.test import TestCase from fsm.registry import transition_registry from fsm.transition_utils import TransitionBuilder @@ -90,7 +91,7 @@ def test_none_and_empty_values_handling(self): # Test None values transition_none = EdgeCaseTransition(edge_case_data=None) - self.assertIsNone(transition_none.edge_case_data) + assert transition_none.edge_case_data is None context = TransitionContext( entity=self.mock_entity, @@ -100,33 +101,33 @@ def test_none_and_empty_values_handling(self): ) # Should handle None values gracefully - self.assertTrue(transition_none.validate_transition(context)) + assert transition_none.validate_transition(context) result = transition_none.transition(context) - self.assertIsNone(result['edge_case_data']) + assert result['edge_case_data'] is None # Test empty string values empty_transition = EdgeCaseTransition(edge_case_data='') result = empty_transition.transition(context) - self.assertEqual(result['edge_case_data'], '') + assert result['edge_case_data'] == '' # Test empty collections empty_list_transition = EdgeCaseTransition(edge_case_data=[]) result = empty_list_transition.transition(context) - self.assertEqual(result['edge_case_data'], []) + assert result['edge_case_data'] == [] empty_dict_transition = EdgeCaseTransition(edge_case_data={}) result = empty_dict_transition.transition(context) - self.assertEqual(result['edge_case_data'], {}) + assert result['edge_case_data'] == {} # Test zero values zero_transition = EdgeCaseTransition(edge_case_data=0) result = zero_transition.transition(context) - self.assertEqual(result['edge_case_data'], 0) + assert result['edge_case_data'] == 0 # Test False boolean false_transition = EdgeCaseTransition(edge_case_data=False) result = false_transition.transition(context) - self.assertFalse(result['edge_case_data']) + assert not result['edge_case_data'] def test_extreme_data_sizes(self): """ @@ -145,7 +146,7 @@ def test_extreme_data_sizes(self): ) result = large_string_transition.transition(context) - self.assertEqual(len(result['edge_case_data']), 10000) + assert len(result['edge_case_data']) == 10000 # Test deeply nested dictionary deep_dict = {'level': 0} @@ -156,14 +157,14 @@ def test_extreme_data_sizes(self): deep_dict_transition = EdgeCaseTransition(edge_case_data=deep_dict) result = deep_dict_transition.transition(context) - self.assertEqual(result['edge_case_data']['level'], 0) + assert result['edge_case_data']['level'] == 0 # Test large list large_list = list(range(1000)) # 1000 items large_list_transition = EdgeCaseTransition(edge_case_data=large_list) result = large_list_transition.transition(context) - self.assertEqual(len(result['edge_case_data']), 1000) - self.assertEqual(result['edge_case_data'][-1], 999) + assert len(result['edge_case_data']) == 1000 + assert result['edge_case_data'][-1] == 999 def test_unicode_and_special_characters(self): """ @@ -204,7 +205,7 @@ def test_unicode_and_special_characters(self): # Should handle any Unicode string result = transition.transition(context) - self.assertEqual(result['edge_case_data'], test_string) + assert result['edge_case_data'] == test_string def test_boundary_datetime_values(self): """ @@ -244,7 +245,7 @@ def test_boundary_datetime_values(self): # Should handle any valid datetime result = transition.transition(context) - self.assertEqual(result['processed_at'], test_datetime.isoformat()) + assert result['processed_at'] == test_datetime.isoformat() def test_circular_reference_handling(self): """ @@ -263,11 +264,11 @@ def test_circular_reference_handling(self): try: transition = EdgeCaseTransition(edge_case_data=circular_dict) # Verify that the circular reference was stored - self.assertEqual(transition.edge_case_data['name'], 'parent') - self.assertEqual(transition.edge_case_data['child']['name'], 'child') + assert transition.edge_case_data['name'] == 'parent' + assert transition.edge_case_data['child']['name'] == 'child' # The system should handle this gracefully except RecursionError: - self.fail('System should handle circular references without infinite recursion') + pytest.fail('System should handle circular references without infinite recursion') # Test with complex but non-circular structure complex_structure = { @@ -281,7 +282,7 @@ def test_circular_reference_handling(self): ) result = transition.transition(context) - self.assertEqual(result['edge_case_data']['level1']['level2']['level3']['data'], 'deep_value') + assert result['edge_case_data']['level1']['level2']['level3']['data'] == 'deep_value' def test_memory_pressure_and_cleanup(self): """ @@ -314,8 +315,8 @@ def test_memory_pressure_and_cleanup(self): weak_refs.append(weakref.ref(context)) # Verify all were created - self.assertEqual(len(transitions), 1000) - self.assertEqual(len(contexts), 1000) + assert len(transitions) == 1000 + assert len(contexts) == 1000 # Clear references and force garbage collection transitions.clear() @@ -335,7 +336,7 @@ def test_memory_pressure_and_cleanup(self): ) result = new_transition.transition(new_context) - self.assertEqual(result['edge_case_data'], 'after_cleanup') + assert result['edge_case_data'] == 'after_cleanup' def test_exception_during_validation(self): """ @@ -382,9 +383,9 @@ def transition(self, context: TransitionContext) -> dict: # Test TransitionValidationError (expected) transition = ValidationErrorTransition(error_type='transition_validation') - with self.assertRaises(TransitionValidationError) as cm: + with pytest.raises(TransitionValidationError) as cm: transition.validate_transition(context) - self.assertIn('Business rule violation', str(cm.exception)) + assert 'Business rule violation' in str(cm.value) # Test other exceptions (should bubble up) error_types = [ @@ -398,7 +399,7 @@ def transition(self, context: TransitionContext) -> dict: for error_type, exception_class in error_types: with self.subTest(error_type=error_type): transition = ValidationErrorTransition(error_type=error_type) - with self.assertRaises(exception_class): + with pytest.raises(exception_class): transition.validate_transition(context) def test_exception_during_transition_execution(self): @@ -415,14 +416,14 @@ def test_exception_during_transition_execution(self): # Test successful execution success_transition = ErrorProneTransition(should_fail='no') result = success_transition.transition(context) - self.assertEqual(result['should_fail'], 'no') + assert result['should_fail'] == 'no' # Test intentional failure fail_transition = ErrorProneTransition(should_fail='yes', failure_stage='transition') - with self.assertRaises(RuntimeError) as cm: + with pytest.raises(RuntimeError) as cm: fail_transition.transition(context) - self.assertIn('Intentional transition failure', str(cm.exception)) + assert 'Intentional transition failure' in str(cm.value) def test_registry_edge_cases(self): """ @@ -447,7 +448,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: # Should get new class retrieved = transition_registry.get_transition('test_entity', 'edge_case') - self.assertEqual(retrieved, NewEdgeCaseTransition) + assert retrieved == NewEdgeCaseTransition # Test registration with unusual names unusual_names = [ @@ -462,15 +463,15 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: with self.subTest(entity=entity_name, transition=transition_name): transition_registry.register(entity_name, transition_name, EdgeCaseTransition) retrieved = transition_registry.get_transition(entity_name, transition_name) - self.assertEqual(retrieved, EdgeCaseTransition) + assert retrieved == EdgeCaseTransition # Test nonexistent lookups - self.assertIsNone(transition_registry.get_transition('nonexistent', 'transition')) - self.assertIsNone(transition_registry.get_transition('test_entity', 'nonexistent')) + assert transition_registry.get_transition('nonexistent', 'transition') is None + assert transition_registry.get_transition('test_entity', 'nonexistent') is None # Test empty entity transitions empty_transitions = transition_registry.get_transitions_for_entity('nonexistent_entity') - self.assertEqual(empty_transitions, {}) + assert empty_transitions == {} def test_context_edge_cases(self): """ @@ -485,13 +486,13 @@ def test_context_edge_cases(self): try: context = TransitionContext(entity=None, current_state='CREATED', target_state='PROCESSED') # Verify context was created with None entity - self.assertIsNone(context.entity) - self.assertEqual(context.current_state, 'CREATED') + assert context.entity is None + assert context.current_state == 'CREATED' except Exception as e: - self.fail(f'Context creation with None entity should not fail: {e}') + pytest.fail(f'Context creation with None entity should not fail: {e}') # Test context with missing required fields - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): TransitionContext( entity=self.mock_entity, # Missing target_state @@ -503,7 +504,7 @@ def test_context_edge_cases(self): entity=self.mock_entity, current_state='CREATED', target_state='PROCESSED', timestamp=far_future ) - self.assertEqual(context.timestamp, far_future) + assert context.timestamp == far_future # Test context with large metadata large_metadata = {f'key_{i}': f'value_{i}' for i in range(1000)} @@ -511,7 +512,7 @@ def test_context_edge_cases(self): entity=self.mock_entity, current_state='CREATED', target_state='PROCESSED', metadata=large_metadata ) - self.assertEqual(len(context.metadata), 1000) + assert len(context.metadata) == 1000 # Test context property edge cases empty_context = TransitionContext( @@ -519,8 +520,8 @@ def test_context_edge_cases(self): ) # Empty string should be considered "has state" - self.assertTrue(empty_context.has_current_state) - self.assertFalse(empty_context.is_initial_transition) + assert empty_context.has_current_state + assert not empty_context.is_initial_transition def test_transition_builder_edge_cases(self): """ @@ -533,21 +534,21 @@ def test_transition_builder_edge_cases(self): builder = TransitionBuilder(self.mock_entity) # Test validation without setting transition name - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError) as cm: builder.validate() - self.assertIn('Transition name not specified', str(cm.exception)) + assert 'Transition name not specified' in str(cm.value) # Test execution without setting transition name - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError) as cm: builder.execute() - self.assertIn('Transition name not specified', str(cm.exception)) + assert 'Transition name not specified' in str(cm.value) # Test with nonexistent transition builder.transition('nonexistent_transition') - with self.assertRaises(ValueError) as cm: + with pytest.raises(ValueError) as cm: builder.validate() - self.assertIn('not found', str(cm.exception)) + assert 'not found' in str(cm.value) # Test method chaining edge cases builder = ( @@ -560,7 +561,7 @@ def test_transition_builder_edge_cases(self): # Should not raise errors for empty data errors = builder.validate() - self.assertEqual(errors, {}) # EdgeCaseTransition has no required fields + assert errors == {} # EdgeCaseTransition has no required fields # Test data overwriting builder = ( @@ -571,7 +572,7 @@ def test_transition_builder_edge_cases(self): ) # Should overwrite errors = builder.validate() - self.assertEqual(errors, {}) + assert errors == {} def test_concurrent_error_scenarios(self): """ @@ -621,14 +622,14 @@ def error_worker(worker_id): thread.join() # Should have 10 errors - self.assertEqual(len(error_results), 10) + assert len(error_results) == 10 # Verify error types validation_errors = [r for r in error_results if r['error_type'] == 'TransitionValidationError'] runtime_errors = [r for r in error_results if r['error_type'] == 'RuntimeError'] - self.assertEqual(len(validation_errors), 5) # Even worker IDs - self.assertEqual(len(runtime_errors), 5) # Odd worker IDs + assert len(validation_errors) == 5 # Even worker IDs + assert len(runtime_errors) == 5 # Odd worker IDs def test_resource_cleanup_after_errors(self): """ @@ -683,17 +684,17 @@ def __del__(self): entity=self.mock_entity, current_state='CREATED', target_state=success_transition.target_state ) - self.assertTrue(success_transition.validate_transition(context)) - self.assertEqual(len(success_transition.resources_allocated), 1) + assert success_transition.validate_transition(context) + assert len(success_transition.resources_allocated) == 1 # Test failure case fail_transition = ResourceTrackingTransition(resource_name='fail_test') - with self.assertRaises(TransitionValidationError): + with pytest.raises(TransitionValidationError): fail_transition.validate_transition(context) # Resources should still be allocated even though validation failed - self.assertEqual(len(fail_transition.resources_allocated), 1) + assert len(fail_transition.resources_allocated) == 1 # Force garbage collection to trigger cleanup weakref.ref(success_transition) diff --git a/label_studio/fsm/tests/test_performance_concurrency.py b/label_studio/fsm/tests/test_performance_concurrency.py index 25a677e12986..ab147f7abc3b 100644 --- a/label_studio/fsm/tests/test_performance_concurrency.py +++ b/label_studio/fsm/tests/test_performance_concurrency.py @@ -135,16 +135,16 @@ def test_single_transition_performance(self): result = transition.validate_transition(context) validation_time = time.perf_counter() - start_time - self.assertTrue(result) - self.assertLess(validation_time, 0.001) # Should be under 1ms + assert result + assert validation_time < 0.001 # Should be under 1ms # Measure transition execution performance start_time = time.perf_counter() transition_data = transition.transition(context) execution_time = time.perf_counter() - start_time - self.assertIsInstance(transition_data, dict) - self.assertLess(execution_time, 0.001) # Should be under 1ms + assert isinstance(transition_data, dict) + assert execution_time < 0.001 # Should be under 1ms # Measure total workflow performance start_time = time.perf_counter() @@ -153,7 +153,7 @@ def test_single_transition_performance(self): transition.transition(context) total_time = time.perf_counter() - start_time - self.assertLess(total_time, 0.005) # Total should be under 5ms + assert total_time < 0.005 # Total should be under 5ms def test_batch_transition_performance(self): """ @@ -175,8 +175,8 @@ def test_batch_transition_performance(self): creation_time = time.perf_counter() - start_time creation_time_per_item = creation_time / batch_size - self.assertEqual(len(transitions), batch_size) - self.assertLess(creation_time_per_item, 0.001) # Under 1ms per transition + assert len(transitions) == batch_size + assert creation_time_per_item < 0.001 # Under 1ms per transition # Test batch validation performance context = TransitionContext( @@ -193,9 +193,9 @@ def test_batch_transition_performance(self): validation_time = time.perf_counter() - start_time validation_time_per_item = validation_time / batch_size - self.assertTrue(all(validation_results)) - self.assertLess(validation_time_per_item, 0.001) # Under 1ms per validation - self.assertLess(validation_time, 0.5) # Total batch under 500ms + assert all(validation_results) + assert validation_time_per_item < 0.001 # Under 1ms per validation + assert validation_time < 0.5 # Total batch under 500ms def test_registry_performance(self): """ @@ -215,8 +215,8 @@ def test_registry_performance(self): lookup_time = time.perf_counter() - start_time lookup_time_per_operation = lookup_time / lookup_count - self.assertEqual(retrieved_class, PerformanceTestTransition) - self.assertLess(lookup_time_per_operation, 0.0001) # Under 0.1ms per lookup + assert retrieved_class == PerformanceTestTransition + assert lookup_time_per_operation < 0.0001 # Under 0.1ms per lookup # Test registry registration performance registration_count = 1000 @@ -231,11 +231,11 @@ def test_registry_performance(self): registration_time = time.perf_counter() - start_time registration_time_per_operation = registration_time / registration_count - self.assertLess(registration_time_per_operation, 0.001) # Under 1ms per registration + assert registration_time_per_operation < 0.001 # Under 1ms per registration # Verify registrations worked test_class = transition_registry.get_transition('entity_500', 'transition_500') - self.assertEqual(test_class, PerformanceTestTransition) + assert test_class == PerformanceTestTransition def test_pydantic_validation_performance(self): """ @@ -256,7 +256,7 @@ def test_pydantic_validation_performance(self): validation_time = time.perf_counter() - start_time validation_time_per_item = validation_time / validation_count - self.assertLess(validation_time_per_item, 0.001) # Under 1ms per validation + assert validation_time_per_item < 0.001 # Under 1ms per validation # Test validation error performance invalid_data = {'operation_id': 'invalid', 'data_size': -1} @@ -274,8 +274,8 @@ def test_pydantic_validation_performance(self): error_time = time.perf_counter() - start_time error_time_per_item = error_time / error_count - self.assertEqual(len(errors), error_count) - self.assertLess(error_time_per_item, 0.01) # Under 10ms per error (errors are slower) + assert len(errors) == error_count + assert error_time_per_item < 0.01 # Under 10ms per error (errors are slower) def test_memory_usage_patterns(self): """ @@ -313,14 +313,14 @@ def test_memory_usage_patterns(self): # Memory usage should be reasonable memory_overhead = complex_size - base_size - self.assertLess(memory_overhead, 10000) # Under 10KB overhead per transition + assert memory_overhead < 10000 # Under 10KB overhead per transition # Clean up contexts to test garbage collection for transition in complex_transitions: transition.context = None # Verify memory can be reclaimed (simplified test) - self.assertIsNone(complex_transitions[0].context) + assert complex_transitions[0].context is None class ConcurrencyTests(TransactionTestCase): @@ -380,17 +380,17 @@ def create_transitions(thread_id): # Validate results total_expected = thread_count * transitions_per_thread - self.assertEqual(len(all_transitions), total_expected) + assert len(all_transitions) == total_expected # Check thread separation thread_ids = [t.thread_id for t in all_transitions] unique_threads = set(thread_ids) - self.assertEqual(len(unique_threads), thread_count) + assert len(unique_threads) == thread_count # Validate each thread created correct number of transitions for thread_id in range(thread_count): thread_transitions = [t for t in all_transitions if t.thread_id == thread_id] - self.assertEqual(len(thread_transitions), transitions_per_thread) + assert len(thread_transitions) == transitions_per_thread def test_concurrent_transition_execution(self): """ @@ -442,17 +442,17 @@ def execute_transition(thread_id): execution_results.append(result) # Validate results - self.assertEqual(len(execution_results), thread_count) + assert len(execution_results) == thread_count for result in execution_results: - self.assertTrue(result['validation_result']) - self.assertIn('thread_id', result['transition_data']) - self.assertIsInstance(result['execution_order'], list) - self.assertGreater(len(result['execution_order']), 0) + assert result['validation_result'] + assert 'thread_id' in result['transition_data'] + assert isinstance(result['execution_order'], list) + assert len(result['execution_order']) > 0 # Check thread isolation thread_ids = [r['transition_data']['thread_id'] for r in execution_results] - self.assertEqual(set(thread_ids), set(range(thread_count))) + assert set(thread_ids) == set(range(thread_count)) def test_registry_thread_safety(self): """ @@ -514,12 +514,12 @@ def registry_operations(thread_id): total_operations = sum(operation_counts) expected_minimum = thread_count * operations_per_thread * 0.9 # Allow some variance - self.assertGreater(total_operations, expected_minimum) + assert total_operations > expected_minimum # Registry should be in consistent state entities = transition_registry.list_entities() - self.assertIsInstance(entities, list) - self.assertGreater(len(entities), thread_count) # Should have entities from all threads + assert isinstance(entities, list) + assert len(entities) > thread_count # Should have entities from all threads def test_context_isolation(self): """ @@ -586,7 +586,7 @@ def context_isolation_test(thread_id): context_data.append(result) # Validate context isolation - self.assertEqual(len(context_data), thread_count) + assert len(context_data) == thread_count for result in context_data: thread_id = result['thread_id'] @@ -594,17 +594,17 @@ def context_isolation_test(thread_id): retrieved_metadata = result['retrieved_metadata'] # Context should match exactly what was set for this thread - self.assertEqual(original_metadata['thread_specific_id'], thread_id) - self.assertEqual(retrieved_metadata['thread_specific_id'], thread_id) - self.assertEqual(original_metadata['random_data'], retrieved_metadata['random_data']) - self.assertEqual(original_metadata['test_counter'], thread_id * 1000) + assert original_metadata['thread_specific_id'] == thread_id + assert retrieved_metadata['thread_specific_id'] == thread_id + assert original_metadata['random_data'] == retrieved_metadata['random_data'] + assert original_metadata['test_counter'] == thread_id * 1000 # Should not have data from other threads for other_result in context_data: if other_result['thread_id'] != thread_id: - self.assertNotEqual( - retrieved_metadata['thread_specific_id'], - other_result['original_metadata']['thread_specific_id'], + assert ( + retrieved_metadata['thread_specific_id'] + != other_result['original_metadata']['thread_specific_id'] ) def test_stress_test_mixed_operations(self): @@ -712,14 +712,14 @@ def mixed_operations_worker(worker_id): ) # Should have performed substantial work - self.assertGreater(total_operations, thread_count * 10) + assert total_operations > thread_count * 10 # Error rate should be very low (< 1%) error_rate = stats['errors_encountered'] / max(total_operations, 1) - self.assertLess(error_rate, 0.01) + assert error_rate < 0.01 # All operation types should have been performed - self.assertGreater(stats['transitions_created'], 0) - self.assertGreater(stats['validations_performed'], 0) - self.assertGreater(stats['transitions_executed'], 0) - self.assertGreater(stats['registry_lookups'], 0) + assert stats['transitions_created'] > 0 + assert stats['validations_performed'] > 0 + assert stats['transitions_executed'] > 0 + assert stats['registry_lookups'] > 0 diff --git a/label_studio/fsm/tests/test_uuid7_utils.py b/label_studio/fsm/tests/test_uuid7_utils.py index af6abe7d6956..5bd95fb1a646 100644 --- a/label_studio/fsm/tests/test_uuid7_utils.py +++ b/label_studio/fsm/tests/test_uuid7_utils.py @@ -26,13 +26,13 @@ def test_generate_uuid7(self): uuid7_id = generate_uuid7() # Check that it's a valid UUID - self.assertIsInstance(uuid7_id, uuid.UUID) + assert isinstance(uuid7_id, uuid.UUID) # Check that it's version 7 - self.assertEqual(uuid7_id.version, 7) + assert uuid7_id.version == 7 # Check that it validates as UUID7 - self.assertTrue(validate_uuid7(uuid7_id)) + assert validate_uuid7(uuid7_id) def test_uuid7_ordering(self): """Test that UUID7s have natural time ordering""" @@ -40,7 +40,7 @@ def test_uuid7_ordering(self): uuid2 = generate_uuid7() # UUID7s should be ordered by generation time - self.assertLess(uuid1.int, uuid2.int) + assert uuid1.int < uuid2.int def test_timestamp_extraction(self): """Test timestamp extraction from UUID7""" @@ -55,8 +55,8 @@ def test_timestamp_extraction(self): time_diff_before = abs((extracted_timestamp - before).total_seconds()) time_diff_after = abs((extracted_timestamp - after).total_seconds()) - self.assertLess(time_diff_before, 1.0) # Within 1 second of before - self.assertLess(time_diff_after, 1.0) # Within 1 second of after + assert time_diff_before < 1.0 # Within 1 second of before + assert time_diff_after < 1.0 # Within 1 second of after def test_uuid7_from_timestamp(self): """Test creating UUID7 from specific timestamp""" @@ -64,12 +64,12 @@ def test_uuid7_from_timestamp(self): uuid7_id = uuid7_from_timestamp(test_time) # Should be a valid UUID7 - self.assertTrue(validate_uuid7(uuid7_id)) + assert validate_uuid7(uuid7_id) # Extracted timestamp should match (within millisecond precision) extracted = timestamp_from_uuid7(uuid7_id) time_diff = abs((extracted - test_time).total_seconds()) - self.assertLess(time_diff, 0.001) # Within 1ms + assert time_diff < 0.001 # Within 1ms def test_uuid7_time_range(self): """Test UUID7 time range generation""" @@ -79,19 +79,19 @@ def test_uuid7_time_range(self): start_uuid, end_uuid = uuid7_time_range(start_time, end_time) # Both should be valid UUID7s - self.assertTrue(validate_uuid7(start_uuid)) - self.assertTrue(validate_uuid7(end_uuid)) + assert validate_uuid7(start_uuid) + assert validate_uuid7(end_uuid) # Start should be less than end - self.assertLess(start_uuid.int, end_uuid.int) + assert start_uuid.int < end_uuid.int # Timestamps should match input times (with 1ms buffer tolerance) start_extracted = timestamp_from_uuid7(start_uuid) end_extracted = timestamp_from_uuid7(end_uuid) # Account for 1ms buffer added in uuid7_time_range - self.assertLess(abs((start_extracted - start_time).total_seconds()), 0.002) - self.assertLess(abs((end_extracted - end_time).total_seconds()), 0.002) + assert abs((start_extracted - start_time).total_seconds()) < 0.002 + assert abs((end_extracted - end_time).total_seconds()) < 0.002 def test_uuid7_time_range_default_end(self): """Test UUID7 time range with default end time (now)""" @@ -107,18 +107,18 @@ def test_uuid7_time_range_default_end(self): time_diff_before = abs((end_extracted - before_call).total_seconds()) time_diff_after = abs((end_extracted - after_call).total_seconds()) - self.assertLess(time_diff_before, 1.0) # Within 1 second of before_call - self.assertLess(time_diff_after, 1.0) # Within 1 second of after_call + assert time_diff_before < 1.0 # Within 1 second of before_call + assert time_diff_after < 1.0 # Within 1 second of after_call def test_validate_uuid7_with_other_versions(self): """Test UUID7 validation with other UUID versions""" # Test with UUID4 uuid4_id = uuid.uuid4() - self.assertFalse(validate_uuid7(uuid4_id)) + assert not validate_uuid7(uuid4_id) # Test with UUID7 uuid7_id = generate_uuid7() - self.assertTrue(validate_uuid7(uuid7_id)) + assert validate_uuid7(uuid7_id) class TestUUID7Generator(TestCase): @@ -129,7 +129,7 @@ def test_generator_basic(self): generator = UUID7Generator() uuid7_id = generator.generate() - self.assertTrue(validate_uuid7(uuid7_id)) + assert validate_uuid7(uuid7_id) def test_generator_with_base_timestamp(self): """Test generator with custom base timestamp""" @@ -141,7 +141,7 @@ def test_generator_with_base_timestamp(self): # Should be close to base time time_diff = abs((extracted - base_time).total_seconds()) - self.assertLess(time_diff, 0.001) + assert time_diff < 0.001 def test_generator_with_offset(self): """Test generator with timestamp offset""" @@ -154,7 +154,7 @@ def test_generator_with_offset(self): expected_time = base_time + timedelta(milliseconds=1000) time_diff = abs((extracted - expected_time).total_seconds()) - self.assertLess(time_diff, 0.001) + assert time_diff < 0.001 def test_generator_monotonic(self): """Test that generator produces monotonic UUIDs""" @@ -167,5 +167,5 @@ def test_generator_monotonic(self): uuid3 = generator.generate(offset_ms=100) # Should be monotonic even with same timestamp - self.assertLess(uuid1.int, uuid2.int) - self.assertLess(uuid2.int, uuid3.int) + assert uuid1.int < uuid2.int + assert uuid2.int < uuid3.int From 2de2dec1b7a7391c962eac348e7ffdb6b44a9b99 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 3 Sep 2025 12:38:18 -0500 Subject: [PATCH 42/83] updating based on feedback about test file names --- .../{test_edge_cases_error_handling.py => test_error_handling.py} | 0 label_studio/fsm/tests/{test_uuid7_utils.py => test_utils.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename label_studio/fsm/tests/{test_edge_cases_error_handling.py => test_error_handling.py} (100%) rename label_studio/fsm/tests/{test_uuid7_utils.py => test_utils.py} (100%) diff --git a/label_studio/fsm/tests/test_edge_cases_error_handling.py b/label_studio/fsm/tests/test_error_handling.py similarity index 100% rename from label_studio/fsm/tests/test_edge_cases_error_handling.py rename to label_studio/fsm/tests/test_error_handling.py diff --git a/label_studio/fsm/tests/test_uuid7_utils.py b/label_studio/fsm/tests/test_utils.py similarity index 100% rename from label_studio/fsm/tests/test_uuid7_utils.py rename to label_studio/fsm/tests/test_utils.py From d0ea3be993e4868033d985a50e054c4d607b8158 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 3 Sep 2025 12:56:01 -0500 Subject: [PATCH 43/83] consolidate transition util get_valid_states to being the same function get_available_states --- .../fsm/tests/test_declarative_transitions.py | 69 +++++++++++++++++++ label_studio/fsm/transition_utils.py | 43 ++++++------ 2 files changed, 91 insertions(+), 21 deletions(-) diff --git a/label_studio/fsm/tests/test_declarative_transitions.py b/label_studio/fsm/tests/test_declarative_transitions.py index d0811f888a37..47ccdc4603e2 100644 --- a/label_studio/fsm/tests/test_declarative_transitions.py +++ b/label_studio/fsm/tests/test_declarative_transitions.py @@ -356,3 +356,72 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: mock_other._meta.model_name = 'other_entity' other_available = get_available_transitions(mock_other) assert len(other_available) == 0 + + def test_get_available_transitions_with_validation(self): + """Test the validation behavior of get_available_transitions""" + from unittest.mock import Mock, patch + + from fsm.state_manager import StateManager + + @register_state_transition('test_entity', 'validation_test_1') + class ValidationTestTransition1(BaseTransition): + @property + def target_state(self) -> str: + return TestStateChoices.IN_PROGRESS + + @classmethod + def can_transition_from_state(cls, context) -> bool: + # Only allow from CREATED state + return context.current_state == TestStateChoices.CREATED + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return {} + + @register_state_transition('test_entity', 'validation_test_2') + class ValidationTestTransition2(BaseTransition): + @property + def target_state(self) -> str: + return TestStateChoices.COMPLETED + + @classmethod + def can_transition_from_state(cls, context) -> bool: + # Only allow from IN_PROGRESS state + return context.current_state == TestStateChoices.IN_PROGRESS + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return {} + + # Test validate=False (should return all registered transitions) + all_available = get_available_transitions(self.mock_entity, validate=False) + assert len(all_available) == 2 + assert 'validation_test_1' in all_available + assert 'validation_test_2' in all_available + + # Mock current state as CREATED + mock_state_object = Mock() + mock_state_object.state = TestStateChoices.CREATED + + with patch.object(StateManager, 'get_current_state_object', return_value=mock_state_object): + # Test validate=True with CREATED state (should only return validation_test_1) + valid_transitions = get_available_transitions(self.mock_entity, validate=True) + assert len(valid_transitions) == 1 + assert 'validation_test_1' in valid_transitions + assert 'validation_test_2' not in valid_transitions + + # Mock current state as IN_PROGRESS + mock_state_object.state = TestStateChoices.IN_PROGRESS + + with patch.object(StateManager, 'get_current_state_object', return_value=mock_state_object): + # Test validate=True with IN_PROGRESS state (should only return validation_test_2) + valid_transitions = get_available_transitions(self.mock_entity, validate=True) + assert len(valid_transitions) == 1 + assert 'validation_test_2' in valid_transitions + assert 'validation_test_1' not in valid_transitions + + # Mock current state as COMPLETED + mock_state_object.state = TestStateChoices.COMPLETED + + with patch.object(StateManager, 'get_current_state_object', return_value=mock_state_object): + # Test validate=True with COMPLETED state (should return no transitions) + valid_transitions = get_available_transitions(self.mock_entity, validate=True) + assert len(valid_transitions) == 0 diff --git a/label_studio/fsm/transition_utils.py b/label_studio/fsm/transition_utils.py index 15319b800f75..1fb00f5bd928 100644 --- a/label_studio/fsm/transition_utils.py +++ b/label_studio/fsm/transition_utils.py @@ -66,34 +66,24 @@ def execute_transition_instance(entity: Model, transition: BaseTransition, user= ) -def get_available_transitions(entity: Model) -> Dict[str, Type[BaseTransition]]: +def get_available_transitions(entity: Model, user=None, validate: bool = False) -> Dict[str, Type[BaseTransition]]: """ - Get all available transitions for an entity. + Get available transitions for an entity. Args: entity: The entity to get transitions for + user: User context for validation (only used when validate=True) + validate: Whether to validate each transition against current state. + When False, returns all registered transitions for the entity type. + When True, filters to only transitions valid from current state (may be expensive). Returns: - Dictionary mapping transition names to transition classes + Dictionary mapping transition names to transition classes. + When validate=False: All registered transitions for the entity type. + When validate=True: Only transitions valid for the current state. """ entity_name = entity._meta.model_name.lower() - return transition_registry.get_transitions_for_entity(entity_name) - - -def get_valid_transitions(entity: Model, user=None, validate: bool = True) -> Dict[str, Type[BaseTransition]]: - """ - Get transitions that are valid for the entity's current state. - - Args: - entity: The entity to check transitions for - user: User context for validation - validate: Whether to validate each transition (may be expensive) - - Returns: - Dictionary mapping transition names to transition classes - that are valid for the current state - """ - available = get_available_transitions(entity) + available = transition_registry.get_transitions_for_entity(entity_name) if not validate: return available @@ -109,12 +99,23 @@ def get_valid_transitions(entity: Model, user=None, validate: bool = True) -> Di # Build minimal context for validation from .transitions import TransitionContext + # Get target state from class or instance + target_state = transition_class.get_target_state() + if target_state is None: + # Need to create an instance to get target_state + try: + temp_instance = transition_class() + target_state = temp_instance.target_state + except Exception: + # Can't create instance, skip this transition + continue + context = TransitionContext( entity=entity, current_user=user, current_state_object=current_state_object, current_state=current_state, - target_state=transition_class.get_target_state(), + target_state=target_state, organization_id=getattr(entity, 'organization_id', None), ) From 458e4072f11a55b06820390de255d4c6b7818227 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 3 Sep 2025 12:59:55 -0500 Subject: [PATCH 44/83] improving transition handling --- .../fsm/tests/test_declarative_transitions.py | 41 +++++++++++++++++++ label_studio/fsm/transition_utils.py | 21 ++++++++-- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/label_studio/fsm/tests/test_declarative_transitions.py b/label_studio/fsm/tests/test_declarative_transitions.py index 47ccdc4603e2..83dd8a1886ad 100644 --- a/label_studio/fsm/tests/test_declarative_transitions.py +++ b/label_studio/fsm/tests/test_declarative_transitions.py @@ -425,3 +425,44 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: # Test validate=True with COMPLETED state (should return no transitions) valid_transitions = get_available_transitions(self.mock_entity, validate=True) assert len(valid_transitions) == 0 + + def test_get_available_transitions_with_required_fields(self): + """Test that transitions with required fields are handled correctly during validation""" + from unittest.mock import Mock, patch + + from fsm.state_manager import StateManager + + @register_state_transition('test_entity', 'required_field_transition') + class RequiredFieldTransition(BaseTransition): + required_field: str = Field(..., description='This field is required') + + @property + def target_state(self) -> str: + return TestStateChoices.IN_PROGRESS + + @classmethod + def can_transition_from_state(cls, context) -> bool: + # This should never be called since we can't instantiate without required_field + return True + + def transition(self, context: TransitionContext) -> Dict[str, Any]: + return {'required_field': self.required_field} + + # Test validate=False (should return the transition even though it has required fields) + all_available = get_available_transitions(self.mock_entity, validate=False) + assert 'required_field_transition' in all_available + + # Mock current state + mock_state_object = Mock() + mock_state_object.state = TestStateChoices.CREATED + + with patch.object(StateManager, 'get_current_state_object', return_value=mock_state_object): + # Test validate=True - should include transitions that can't be instantiated for validation + # This is the behavior: we can't validate transitions with required fields + # without knowing what data will be provided, so we include them as "available" + valid_transitions = get_available_transitions(self.mock_entity, validate=True) + + # The transition should be included since we can't validate it (better to be permissive) + # This avoids false negatives where valid transitions appear unavailable due to + # validation limitations + assert 'required_field_transition' in valid_transitions diff --git a/label_studio/fsm/transition_utils.py b/label_studio/fsm/transition_utils.py index 1fb00f5bd928..a108172f4f2e 100644 --- a/label_studio/fsm/transition_utils.py +++ b/label_studio/fsm/transition_utils.py @@ -103,11 +103,14 @@ def get_available_transitions(entity: Model, user=None, validate: bool = False) target_state = transition_class.get_target_state() if target_state is None: # Need to create an instance to get target_state + # For validation purposes, we try to create with minimal/default data try: temp_instance = transition_class() target_state = temp_instance.target_state - except Exception: - # Can't create instance, skip this transition + except (TypeError, ValueError): + # Can't instantiate without required data - include in results + # since we can't validate state transitions, we assume they're available + valid_transitions[name] = transition_class continue context = TransitionContext( @@ -123,8 +126,18 @@ def get_available_transitions(entity: Model, user=None, validate: bool = False) if transition_class.can_transition_from_state(context): valid_transitions[name] = transition_class - except (TransitionValidationError, Exception): - # Transition is not valid for current state/context + except TransitionValidationError: + # Transition is not valid for current state/context - this is expected + continue + except Exception as e: + # Unexpected error during validation - this should be investigated + import logging + + logger = logging.getLogger(__name__) + logger.warning( + f"Unexpected error validating transition '{name}' for entity {entity._meta.model_name}: {e}", + exc_info=True, + ) continue return valid_transitions From 1e1771671788912abd6b991a8571fd24f46bade0 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 3 Sep 2025 14:22:10 -0500 Subject: [PATCH 45/83] improving transition handling by consolidating everything to use a single entry point --- label_studio/fsm/README.md | 24 +--- label_studio/fsm/state_manager.py | 37 +++++ label_studio/fsm/tests/test_error_handling.py | 62 +++----- ...ive_transitions.py => test_transitions.py} | 27 ++-- label_studio/fsm/transition_utils.py | 136 +----------------- 5 files changed, 74 insertions(+), 212 deletions(-) rename label_studio/fsm/tests/{test_declarative_transitions.py => test_transitions.py} (95%) diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md index 6925b7fc44e2..995027e68da1 100644 --- a/label_studio/fsm/README.md +++ b/label_studio/fsm/README.md @@ -121,10 +121,10 @@ class ProcessOrderTransition(BaseTransition): ### 5. Execute Transitions ```python -from fsm.transition_utils import execute_transition +from fsm.state_manager import StateManager -# Execute transition -result = execute_transition( +# Execute transition - this is the only way to execute transitions +result = StateManager.execute_transition( entity=order, transition_name='process_order', transition_data={'processor_id': 123, 'priority': 'high'}, @@ -284,16 +284,15 @@ if context.has_current_state: ### Transition Utilities ```python +from fsm.state_manager import StateManager from fsm.transition_utils import ( - execute_transition, get_available_transitions, get_transition_schema, validate_transition_data, - TransitionBuilder, ) -# Execute a registered transition -result = execute_transition( +# Execute a transition - the only way to execute transitions +result = StateManager.execute_transition( entity=task, transition_name='start_task', transition_data={'assigned_user_id': 123}, @@ -308,17 +307,6 @@ schema = get_transition_schema(StartTaskTransition) # Validate transition data before execution errors = validate_transition_data(StartTaskTransition, data) - -# Use TransitionBuilder for fluent API -builder = (TransitionBuilder(task) - .transition('start_task') - .with_data(assigned_user_id=123) - .by_user(user) - .with_context(source='api')) - -errors = builder.validate() -if not errors: - state = builder.execute() ``` ## Extension Points diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index 8f9b839eec5c..7e8394082e1e 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -362,6 +362,43 @@ def execute_declarative_transition( ) raise + @classmethod + def execute_transition( + cls, entity: Model, transition_name: str, transition_data: Dict[str, Any] = None, user=None, **context_kwargs + ) -> BaseState: + """ + Execute a registered transition by name. + + This is the unified entry point for all state transitions using the declarative system. + + Args: + entity: The entity to transition + transition_name: Name of the registered transition + transition_data: Data for the transition (validated by Pydantic) + user: User executing the transition + **context_kwargs: Additional context data + + Returns: + The newly created state record + + Raises: + ValueError: If transition is not found + TransitionValidationError: If transition validation fails + """ + from .registry import transition_registry + + entity_name = entity._meta.model_name.lower() + transition_data = transition_data or {} + + return transition_registry.execute_transition( + entity_name=entity_name, + transition_name=transition_name, + entity=entity, + transition_data=transition_data, + user=user, + **context_kwargs, + ) + # Allow runtime configuration of which StateManager to use # Enterprise can set this to their extended implementation diff --git a/label_studio/fsm/tests/test_error_handling.py b/label_studio/fsm/tests/test_error_handling.py index 4f61e89a7d7f..1b89374c8df6 100644 --- a/label_studio/fsm/tests/test_error_handling.py +++ b/label_studio/fsm/tests/test_error_handling.py @@ -16,7 +16,6 @@ import pytest from django.test import TestCase from fsm.registry import transition_registry -from fsm.transition_utils import TransitionBuilder from fsm.transitions import BaseTransition, TransitionContext, TransitionValidationError from pydantic import Field, ValidationError @@ -523,56 +522,27 @@ def test_context_edge_cases(self): assert empty_context.has_current_state assert not empty_context.is_initial_transition - def test_transition_builder_edge_cases(self): + def test_state_manager_edge_cases(self): """ - EDGE CASE: TransitionBuilder edge cases + EDGE CASE: StateManager edge cases Tests unusual usage patterns and edge cases - with the fluent TransitionBuilder interface. + with the StateManager transition execution. """ - builder = TransitionBuilder(self.mock_entity) - - # Test validation without setting transition name - with pytest.raises(ValueError) as cm: - builder.validate() - assert 'Transition name not specified' in str(cm.value) - - # Test execution without setting transition name - with pytest.raises(ValueError) as cm: - builder.execute() - assert 'Transition name not specified' in str(cm.value) - - # Test with nonexistent transition - builder.transition('nonexistent_transition') - - with pytest.raises(ValueError) as cm: - builder.validate() - assert 'not found' in str(cm.value) - - # Test method chaining edge cases - builder = ( - TransitionBuilder(self.mock_entity) - .transition('edge_case') - .with_data() # Empty data - .by_user(None) # No user - .with_context() - ) # Empty context - - # Should not raise errors for empty data - errors = builder.validate() - assert errors == {} # EdgeCaseTransition has no required fields - - # Test data overwriting - builder = ( - TransitionBuilder(self.mock_entity) - .transition('edge_case') - .with_data(edge_case_data='first') - .with_data(edge_case_data='second') - ) # Should overwrite - - errors = builder.validate() - assert errors == {} + # Test with nonexistent transition in registry + from fsm.registry import transition_registry + + result = transition_registry.get_transition('test_entity', 'nonexistent_transition') + assert result is None # Should return None for nonexistent transition + + # Test execution with valid transition (test at registry level) + transition_class = transition_registry.get_transition('test_entity', 'edge_case') + assert transition_class is not None + + # Should be able to create instance with defaults + transition = transition_class() + assert transition.edge_case_data is None # Uses default None value def test_concurrent_error_scenarios(self): """ diff --git a/label_studio/fsm/tests/test_declarative_transitions.py b/label_studio/fsm/tests/test_transitions.py similarity index 95% rename from label_studio/fsm/tests/test_declarative_transitions.py rename to label_studio/fsm/tests/test_transitions.py index 83dd8a1886ad..dd53cff77a63 100644 --- a/label_studio/fsm/tests/test_declarative_transitions.py +++ b/label_studio/fsm/tests/test_transitions.py @@ -17,7 +17,6 @@ from django.utils.translation import gettext_lazy as _ from fsm.registry import register_state_transition, transition_registry from fsm.transition_utils import ( - TransitionBuilder, get_available_transitions, ) from fsm.transitions import ( @@ -235,11 +234,11 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: assert 'Can only complete from IN_PROGRESS state' in str(error) assert 'current_state' in error.context - def test_transition_builder_basic(self): - """Test TransitionBuilder basic functionality""" + def test_state_manager_transition_execution(self): + """Test StateManager-based transition execution""" - @register_state_transition('test_entity', 'builder_test') - class BuilderTestTransition(BaseTransition): + @register_state_transition('test_entity', 'state_manager_test') + class StateManagerTestTransition(BaseTransition): value: str = Field('default', description='Test value') @property @@ -249,16 +248,18 @@ def target_state(self) -> str: def transition(self, context: TransitionContext) -> Dict[str, Any]: return {'value': self.value} - # Test builder creation - builder = TransitionBuilder(self.mock_entity) - assert builder.entity == self.mock_entity + # Test StateManager execution using the registry directly (simpler test) + # This validates that the consolidated approach works through the registry + from fsm.registry import transition_registry - # Test method chaining - builder = builder.transition('builder_test').with_data(value='builder_test_value').by_user(self.user) + # Get the transition class + transition_class = transition_registry.get_transition('test_entity', 'state_manager_test') + assert transition_class is not None - # Validate the builder state - validation_errors = builder.validate() - assert len(validation_errors) == 0 + # Create instance and verify it works + transition = transition_class(value='state_manager_test_value') + assert transition.value == 'state_manager_test_value' + assert transition.target_state == TestStateChoices.COMPLETED def test_get_available_transitions(self): """Test get_available_transitions utility""" diff --git a/label_studio/fsm/transition_utils.py b/label_studio/fsm/transition_utils.py index a108172f4f2e..682259a85e06 100644 --- a/label_studio/fsm/transition_utils.py +++ b/label_studio/fsm/transition_utils.py @@ -5,67 +5,14 @@ the new Pydantic-based transition system with existing Label Studio code. """ -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Type from django.db.models import Model -from fsm.models import BaseState from fsm.registry import transition_registry from fsm.state_manager import StateManager from fsm.transitions import BaseTransition, TransitionValidationError -def execute_transition( - entity: Model, transition_name: str, transition_data: Dict[str, Any], user=None, **context_kwargs -) -> BaseState: - """ - Execute a named transition on an entity. - - This is a convenience function that looks up the transition class - and executes it with the provided data. - - Args: - entity: The entity to transition - transition_name: Name of the registered transition - transition_data: Data for the transition (validated by Pydantic) - user: User executing the transition - **context_kwargs: Additional context data - - Returns: - The newly created state record - - Raises: - ValueError: If transition is not found - TransitionValidationError: If transition validation fails - """ - entity_name = entity._meta.model_name.lower() - return transition_registry.execute_transition( - entity_name=entity_name, - transition_name=transition_name, - entity=entity, - transition_data=transition_data, - user=user, - **context_kwargs, - ) - - -def execute_transition_instance(entity: Model, transition: BaseTransition, user=None, **context_kwargs) -> BaseState: - """ - Execute a pre-created transition instance. - - Args: - entity: The entity to transition - transition: Instance of a transition class - user: User executing the transition - **context_kwargs: Additional context data - - Returns: - The newly created state record - """ - return StateManager.execute_declarative_transition( - transition=transition, entity=entity, user=user, **context_kwargs - ) - - def get_available_transitions(entity: Model, user=None, validate: bool = False) -> Dict[str, Type[BaseTransition]]: """ Get available transitions for an entity. @@ -247,84 +194,3 @@ def get_entity_state_flow(entity: Model) -> List[Dict[str, Any]]: continue return flows - - -# Backward compatibility helpers - - -def transition_state_declarative(entity: Model, transition_name: str, user=None, **transition_data) -> BaseState: - """ - Backward-compatible helper for transitioning state declaratively. - - This provides a similar interface to StateManager.transition_state - but uses the declarative system. - """ - return execute_transition( - entity=entity, transition_name=transition_name, transition_data=transition_data, user=user - ) - - -class TransitionBuilder: - """ - Builder class for constructing and executing transitions fluently. - - Example usage: - result = (TransitionBuilder(entity) - .transition('start_task') - .with_data(assigned_user_id=123, priority='high') - .by_user(request.user) - .execute()) - """ - - def __init__(self, entity: Model): - self.entity = entity - self._transition_name: Optional[str] = None - self._transition_data: Dict[str, Any] = {} - self._user = None - self._context_data: Dict[str, Any] = {} - - def transition(self, name: str) -> 'TransitionBuilder': - """Set the transition name""" - self._transition_name = name - return self - - def with_data(self, **data) -> 'TransitionBuilder': - """Add transition data""" - self._transition_data.update(data) - return self - - def by_user(self, user) -> 'TransitionBuilder': - """Set the executing user""" - self._user = user - return self - - def with_context(self, **context) -> 'TransitionBuilder': - """Add context data""" - self._context_data.update(context) - return self - - def execute(self) -> BaseState: - """Execute the configured transition""" - if not self._transition_name: - raise ValueError('Transition name not specified') - - return execute_transition( - entity=self.entity, - transition_name=self._transition_name, - transition_data=self._transition_data, - user=self._user, - **self._context_data, - ) - - def validate(self) -> Dict[str, List[str]]: - """Validate the configured transition without executing""" - if not self._transition_name: - raise ValueError('Transition name not specified') - - entity_name = self.entity._meta.model_name.lower() - transition_class = transition_registry.get_transition(entity_name, self._transition_name) - - if not transition_class: - raise ValueError(f"Transition '{self._transition_name}' not found for entity '{entity_name}'") - - return validate_transition_data(transition_class, self._transition_data) From 7552fbbb8103b1ec417663175a4fa578882c5b32 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 3 Sep 2025 14:26:05 -0500 Subject: [PATCH 46/83] ensure write through cache sets on_commit of the transaction --- label_studio/fsm/state_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index 7e8394082e1e..17bb4f0d26f4 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -212,9 +212,9 @@ def transition_state( **denormalized_fields, ) - # Update cache with new state + # Update cache with new state after transaction commits cache_key = cls.get_cache_key(entity) - cache.set(cache_key, new_state, cls.CACHE_TTL) + transaction.on_commit(lambda: cache.set(cache_key, new_state, cls.CACHE_TTL)) logger.info( f'State transition successful: {entity._meta.label_lower} {entity.pk} ' From f73583d17e427cf57890752c0e2e92db429b6a33 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 3 Sep 2025 14:35:45 -0500 Subject: [PATCH 47/83] type fixes --- label_studio/fsm/registry.py | 44 ++++++++++++------------------------ 1 file changed, 15 insertions(+), 29 deletions(-) diff --git a/label_studio/fsm/registry.py b/label_studio/fsm/registry.py index cd34d0bcf9a9..f1115016f954 100644 --- a/label_studio/fsm/registry.py +++ b/label_studio/fsm/registry.py @@ -13,12 +13,7 @@ if typing.TYPE_CHECKING: from fsm.models import BaseState - from fsm.transitions import BaseTransition, StateModelType, TransitionContext, User -else: - from fsm.transitions import BaseTransition, TransitionContext, User - - # Import StateModelType at runtime to avoid circular import - StateModelType = None + from fsm.transitions import BaseTransition, User logger = logging.getLogger(__name__) @@ -117,14 +112,13 @@ class StateModelRegistry: """ def __init__(self): - self._models: Dict[str, Type['BaseState']] = {} + self._models: Dict[str, 'BaseState'] = {} self._denormalizers: Dict[str, Callable[[Model], Dict[str, Any]]] = {} - self._initialized = False def register_model( self, entity_name: str, - state_model: Type['BaseState'], + state_model: 'BaseState', denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None, ): """ @@ -150,7 +144,7 @@ def register_model( logger.debug(f'Registered state model for {entity_key}: {state_model.__name__}') - def get_model(self, entity_name: str) -> Optional[Type['BaseState']]: + def get_model(self, entity_name: str) -> Optional['BaseState']: """ Get the state model for an entity type. @@ -206,19 +200,10 @@ def clear(self): self._initialized = False logger.debug('Cleared state model registry') - def get_all_models(self) -> Dict[str, Type['BaseState']]: + def get_all_models(self) -> Dict[str, 'BaseState']: """Get all registered models.""" return self._models.copy() - def mark_initialized(self): - """Mark the registry as initialized.""" - self._initialized = True - logger.info(f'State model registry initialized with {len(self._models)} models') - - def is_initialized(self) -> bool: - """Check if the registry has been initialized.""" - return self._initialized - # Global registry instance state_model_registry = StateModelRegistry() @@ -238,7 +223,7 @@ class TaskState(BaseState): # ... implementation """ - def decorator(state_model: Type['BaseState']) -> Type['BaseState']: + def decorator(state_model: 'BaseState') -> 'BaseState': state_model_registry.register_model(entity_name, state_model, denormalizer) return state_model @@ -246,7 +231,7 @@ def decorator(state_model: Type['BaseState']) -> Type['BaseState']: def register_state_model_class( - entity_name: str, state_model: Type['BaseState'], denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None + entity_name: str, state_model: 'BaseState', denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None ): """ Convenience function to register a state model programmatically. @@ -259,7 +244,7 @@ def register_state_model_class( state_model_registry.register_model(entity_name, state_model, denormalizer) -def get_state_model(entity_name: str) -> Optional[Type['BaseState']]: +def get_state_model(entity_name: str) -> Optional['BaseState']: """ Convenience function to get a state model. @@ -272,7 +257,7 @@ def get_state_model(entity_name: str) -> Optional[Type['BaseState']]: return state_model_registry.get_model(entity_name) -def get_state_model_for_entity(entity: Model) -> Optional[Type['BaseState']]: +def get_state_model_for_entity(entity: Model) -> Optional['BaseState']: """Get the state model for an entity.""" entity_name = entity._meta.model_name.lower() return get_state_model(entity_name) @@ -287,9 +272,9 @@ class TransitionRegistry: """ def __init__(self): - self._transitions: Dict[str, Dict[str, Type['BaseTransition']]] = {} + self._transitions: Dict[str, Dict[str, 'BaseTransition']] = {} - def register(self, entity_name: str, transition_name: str, transition_class: Type['BaseTransition']): + def register(self, entity_name: str, transition_name: str, transition_class: 'BaseTransition'): """ Register a transition class for an entity. @@ -303,7 +288,7 @@ def register(self, entity_name: str, transition_name: str, transition_class: Typ self._transitions[entity_name][transition_name] = transition_class - def get_transition(self, entity_name: str, transition_name: str) -> Optional[Type['BaseTransition']]: + def get_transition(self, entity_name: str, transition_name: str) -> Optional['BaseTransition']: """ Get a registered transition class. @@ -316,7 +301,7 @@ def get_transition(self, entity_name: str, transition_name: str) -> Optional[Typ """ return self._transitions.get(entity_name, {}).get(transition_name) - def get_transitions_for_entity(self, entity_name: str) -> Dict[str, Type['BaseTransition']]: + def get_transitions_for_entity(self, entity_name: str) -> Dict[str, 'BaseTransition']: """ Get all registered transitions for an entity type. @@ -376,6 +361,7 @@ def execute_transition( # Get current state information from fsm.state_manager import StateManager + from fsm.transitions import TransitionContext current_state_object = StateManager.get_current_state_object(entity) current_state = current_state_object.state if current_state_object else None @@ -413,7 +399,7 @@ class StartTaskTransition(BaseTransition[Task, TaskState]): # ... implementation """ - def decorator(transition_class: Type['BaseTransition']) -> Type['BaseTransition']: + def decorator(transition_class: 'BaseTransition') -> 'BaseTransition': name = transition_name if name is None: # Generate name from class name From 493f1367be42558df3e521ae95c64b418fab3ba1 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 3 Sep 2025 16:14:01 -0500 Subject: [PATCH 48/83] update tests to ensure coverage --- .../fsm/tests/test_fsm_integration.py | 291 +++++++++++++++--- .../tests/test_integration_django_models.py | 27 +- label_studio/fsm/tests/test_registry.py | 202 ++++++++++++ label_studio/fsm/tests/test_transitions.py | 4 +- label_studio/fsm/tests/test_utils.py | 202 ++++++++++++ 5 files changed, 668 insertions(+), 58 deletions(-) create mode 100644 label_studio/fsm/tests/test_registry.py diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py index 8857e30f0d86..4d97faad38c1 100644 --- a/label_studio/fsm/tests/test_fsm_integration.py +++ b/label_studio/fsm/tests/test_fsm_integration.py @@ -5,7 +5,9 @@ """ from datetime import datetime, timezone +from unittest.mock import patch +import pytest from django.contrib.auth import get_user_model from django.test import TestCase from fsm.models import AnnotationState, ProjectState, TaskState @@ -40,18 +42,18 @@ def test_task_state_creation(self): ) # Check basic fields - self.assertEqual(task_state.state, 'CREATED') - self.assertEqual(task_state.task, self.task) - self.assertEqual(task_state.triggered_by, self.user) + assert task_state.state == 'CREATED' + assert task_state.task == self.task + assert task_state.triggered_by == self.user # Check UUID7 functionality - self.assertEqual(task_state.id.version, 7) - self.assertIsInstance(task_state.timestamp_from_uuid, datetime) + assert task_state.id.version == 7 + assert isinstance(task_state.timestamp_from_uuid, datetime) # Check string representation str_repr = str(task_state) - self.assertIn('Task', str_repr) - self.assertIn('CREATED', str_repr) + assert 'Task' in str_repr + assert 'CREATED' in str_repr def test_annotation_state_creation(self): """Test AnnotationState creation and basic functionality""" @@ -68,11 +70,11 @@ def test_annotation_state_creation(self): ) # Check basic fields - self.assertEqual(annotation_state.state, 'DRAFT') - self.assertEqual(annotation_state.annotation, annotation) + assert annotation_state.state == 'DRAFT' + assert annotation_state.annotation == annotation # Check terminal state property - self.assertFalse(annotation_state.is_terminal_state) + assert not annotation_state.is_terminal_state # Test completed state completed_state = AnnotationState.objects.create( @@ -83,7 +85,7 @@ def test_annotation_state_creation(self): state='COMPLETED', triggered_by=self.user, ) - self.assertTrue(completed_state.is_terminal_state) + assert completed_state.is_terminal_state def test_project_state_creation(self): """Test ProjectState creation and basic functionality""" @@ -92,18 +94,18 @@ def test_project_state_creation(self): ) # Check basic fields - self.assertEqual(project_state.state, 'CREATED') - self.assertEqual(project_state.project, self.project) + assert project_state.state == 'CREATED' + assert project_state.project == self.project # Test terminal state - self.assertFalse(project_state.is_terminal_state) + assert not project_state.is_terminal_state completed_state = ProjectState.objects.create(project=self.project, state='COMPLETED', triggered_by=self.user) - self.assertTrue(completed_state.is_terminal_state) + assert completed_state.is_terminal_state class TestStateManager(TestCase): - """Test StateManager functionality""" + """Test StateManager functionality with mocked transaction support""" def setUp(self): self.user = User.objects.create_user(email='test@example.com', password='test123') @@ -116,13 +118,31 @@ def setUp(self): cache.clear() + # Ensure registry is properly initialized for TaskState + from fsm.models import TaskState + from fsm.registry import state_model_registry + + if not state_model_registry.get_model('task'): + state_model_registry.register_model('task', TaskState) + def test_get_current_state_empty(self): """Test getting current state when no states exist""" current_state = self.StateManager.get_current_state(self.task) - self.assertIsNone(current_state) + assert current_state is None + + @patch('django.db.transaction.on_commit') + def test_transition_state(self, mock_on_commit): + """Test state transition functionality with mocked transaction.on_commit""" + from django.core.cache import cache + + cache.clear() + + # Mock transaction.on_commit to immediately execute the callback + def execute_callback(callback): + callback() + + mock_on_commit.side_effect = execute_callback - def test_transition_state(self): - """Test state transition functionality""" # Initial transition success = self.StateManager.transition_state( entity=self.task, @@ -132,11 +152,13 @@ def test_transition_state(self): reason='Initial task creation', ) - self.assertTrue(success) + assert success + # Verify transaction.on_commit was called once for cache update + assert mock_on_commit.call_count == 1 - # Check current state + # Check current state - should work with mocked cache update current_state = self.StateManager.get_current_state(self.task) - self.assertEqual(current_state, 'CREATED') + assert current_state == 'CREATED' # Another transition success = self.StateManager.transition_state( @@ -147,29 +169,56 @@ def test_transition_state(self): context={'started_by': 'user'}, ) - self.assertTrue(success) + assert success + # Verify transaction.on_commit was called again (total 2 times) + assert mock_on_commit.call_count == 2 + current_state = self.StateManager.get_current_state(self.task) - self.assertEqual(current_state, 'IN_PROGRESS') + assert current_state == 'IN_PROGRESS' - def test_get_current_state_object(self): + @patch('django.db.transaction.on_commit') + def test_get_current_state_object(self, mock_on_commit): """Test getting current state object with full details""" + from django.core.cache import cache + + cache.clear() + + # Mock transaction.on_commit to immediately execute the callback + def execute_callback(callback): + callback() + + mock_on_commit.side_effect = execute_callback + # Create some state transitions self.StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) self.StateManager.transition_state( entity=self.task, new_state='IN_PROGRESS', user=self.user, context={'test': 'data'} ) + # Verify transaction.on_commit was called twice (once per transition) + assert mock_on_commit.call_count == 2 + current_state_obj = self.StateManager.get_current_state_object(self.task) - self.assertIsNotNone(current_state_obj) - self.assertEqual(current_state_obj.state, 'IN_PROGRESS') - self.assertEqual(current_state_obj.previous_state, 'CREATED') - self.assertEqual(current_state_obj.triggered_by, self.user) - self.assertEqual(current_state_obj.context_data, {'test': 'data'}) + assert current_state_obj is not None + assert current_state_obj.state == 'IN_PROGRESS' + assert current_state_obj.previous_state == 'CREATED' + assert current_state_obj.triggered_by == self.user + assert current_state_obj.context_data == {'test': 'data'} - def test_get_state_history(self): + @patch('django.db.transaction.on_commit') + def test_get_state_history(self, mock_on_commit): """Test state history retrieval""" - # Create multiple transitions + from django.core.cache import cache + + cache.clear() + + # Mock transaction.on_commit to immediately execute the callback + def execute_callback(callback): + callback() + + mock_on_commit.side_effect = execute_callback + transitions = [('CREATED', 'create_task'), ('IN_PROGRESS', 'start_work'), ('COMPLETED', 'finish_work')] for state, transition in transitions: @@ -177,33 +226,45 @@ def test_get_state_history(self): entity=self.task, new_state=state, user=self.user, transition_name=transition ) + # Verify transaction.on_commit was called 3 times (once per transition) + assert mock_on_commit.call_count == 3 + history = self.StateManager.get_state_history(self.task, limit=10) # Should have 3 state records - self.assertEqual(len(history), 3) + assert len(history) == 3 # Should be ordered by most recent first (UUID7 ordering) states = [h.state for h in history] - self.assertEqual(states, ['COMPLETED', 'IN_PROGRESS', 'CREATED']) - - print(history) - ids = [str(h.id) for h in history] - print(ids) + assert states == ['COMPLETED', 'IN_PROGRESS', 'CREATED'] # Check previous states are set correctly - self.assertIsNone(history[2].previous_state) # First state has no previous - self.assertEqual(history[1].previous_state, 'CREATED') - self.assertEqual(history[0].previous_state, 'IN_PROGRESS') + assert history[2].previous_state is None # First state has no previous + assert history[1].previous_state == 'CREATED' + assert history[0].previous_state == 'IN_PROGRESS' - def test_get_states_in_time_range(self): + @patch('django.db.transaction.on_commit') + def test_get_states_in_time_range(self, mock_on_commit): """Test time-based state queries using UUID7""" - # Record time before creating states + from django.core.cache import cache + + cache.clear() + + # Mock transaction.on_commit to immediately execute the callback + def execute_callback(callback): + callback() + + mock_on_commit.side_effect = execute_callback + before_time = datetime.now(timezone.utc) # Create some states self.StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) self.StateManager.transition_state(entity=self.task, new_state='IN_PROGRESS', user=self.user) + # Verify transaction.on_commit was called twice (once per transition) + assert mock_on_commit.call_count == 2 + # Record time after creating states after_time = datetime.now(timezone.utc) @@ -211,4 +272,146 @@ def test_get_states_in_time_range(self): states_in_range = self.StateManager.get_states_in_time_range(self.task, before_time, after_time) # Should find both states - self.assertEqual(len(states_in_range), 2) + assert len(states_in_range) == 2 + + @patch('django.db.transaction.on_commit') + def test_transaction_on_commit_success_case(self, mock_on_commit): + """Test that transaction.on_commit is called exactly once per successful transition""" + from django.core.cache import cache + + cache.clear() + + # Track callback executions + callbacks_executed = [] + + def track_and_execute(callback): + callbacks_executed.append(callback) + callback() # Execute the callback + + mock_on_commit.side_effect = track_and_execute + + # Perform a successful transition + success = self.StateManager.transition_state( + entity=self.task, + new_state='CREATED', + user=self.user, + transition_name='create_task', + reason='Initial task creation', + ) + + # Verify success and transaction.on_commit was called + assert success + assert mock_on_commit.call_count == 1 + assert len(callbacks_executed) == 1 + + # Verify the cache was properly updated by executing the callback + current_state = self.StateManager.get_current_state(self.task) + assert current_state == 'CREATED' + + # Perform another successful transition + success = self.StateManager.transition_state( + entity=self.task, + new_state='IN_PROGRESS', + user=self.user, + transition_name='start_work', + ) + + assert success + assert mock_on_commit.call_count == 2 + assert len(callbacks_executed) == 2 + + current_state = self.StateManager.get_current_state(self.task) + assert current_state == 'IN_PROGRESS' + + @patch('django.db.transaction.on_commit') + @patch('fsm.state_manager.get_state_model_for_entity') + def test_transaction_on_commit_failure_case(self, mock_get_state_model, mock_on_commit): + """Test that transaction.on_commit is NOT called when transition fails""" + from django.core.cache import cache + + cache.clear() + + # Mock get_state_model_for_entity to return None (no state model found) + mock_get_state_model.return_value = None + + # Attempt a transition that should fail due to missing state model + with pytest.raises(Exception): # Should raise StateManagerError + self.StateManager.transition_state( + entity=self.task, + new_state='CREATED', + user=self.user, + transition_name='create_task', + reason='This should fail', + ) + + # Verify transaction.on_commit was NOT called since transition failed + assert mock_on_commit.call_count == 0 + + # Verify cache was not updated (should still be None) + current_state = self.StateManager.get_current_state(self.task) + assert current_state is None + + @patch('django.db.transaction.on_commit') + @patch('fsm.models.TaskState.objects.create') + def test_transaction_on_commit_database_failure_case(self, mock_create, mock_on_commit): + """Test that transaction.on_commit is NOT called when database operation fails""" + from django.core.cache import cache + + cache.clear() + + # Mock database create operation to fail + mock_create.side_effect = Exception('Database constraint violation') + + # Attempt a transition that should fail due to database error + with pytest.raises(Exception): # Should raise StateManagerError + self.StateManager.transition_state( + entity=self.task, + new_state='CREATED', + user=self.user, + transition_name='create_task', + reason='This should fail in DB', + ) + + # Verify transaction.on_commit was NOT called since transaction failed + assert mock_on_commit.call_count == 0 + + # Verify cache was deleted due to failure (cache.delete should be called) + current_state = self.StateManager.get_current_state(self.task) + assert current_state is None + + @patch('django.db.transaction.on_commit') + def test_transaction_on_commit_callback_content(self, mock_on_commit): + """Test that the transaction.on_commit callback properly updates the cache""" + from django.core.cache import cache + + cache.clear() + + # Capture the callback without executing it + captured_callbacks = [] + mock_on_commit.side_effect = lambda callback: captured_callbacks.append(callback) + + # Perform a transition + success = self.StateManager.transition_state( + entity=self.task, + new_state='CREATED', + user=self.user, + ) + + assert success + assert len(captured_callbacks) == 1 + + # Before executing callback, cache should be empty + cache_key = self.StateManager.get_cache_key(self.task) + cached_state = cache.get(cache_key) + assert cached_state is None + + # Execute the callback manually + captured_callbacks[0]() + + # After callback execution, cache should be updated + cached_state = cache.get(cache_key) + assert cached_state == 'CREATED' + + # Verify get_current_state uses the cached value + current_state = self.StateManager.get_current_state(self.task) + assert current_state == 'CREATED' diff --git a/label_studio/fsm/tests/test_integration_django_models.py b/label_studio/fsm/tests/test_integration_django_models.py index 9cf4c85fc71d..40d9a59408fb 100644 --- a/label_studio/fsm/tests/test_integration_django_models.py +++ b/label_studio/fsm/tests/test_integration_django_models.py @@ -224,7 +224,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: ) # Validate and execute creation - assert create_transition.validate_transition(context) == True + assert create_transition.validate_transition(context) is True creation_data = create_transition.transition(context) assert creation_data['created_by_id'] == 100 @@ -248,12 +248,12 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: target_state=assign_transition.target_state, ) - assert assign_transition.validate_transition(context) == True + assert assign_transition.validate_transition(context) is True assignment_data = assign_transition.transition(context) assert assignment_data['assignee_id'] == 200 assert assignment_data['estimated_hours'] == 4.5 - assert assignment_data['work_started'] == True + assert assignment_data['work_started'] is True # Step 3: Complete task mock_current_state.context_data = assignment_data @@ -271,7 +271,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: target_state=complete_transition.target_state, ) - assert complete_transition.validate_transition(context) == True + assert complete_transition.validate_transition(context) is True completion_data = complete_transition.transition(context) assert completion_data['quality_score'] == 0.85 @@ -395,12 +395,12 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: target_state=submit_transition.target_state, ) - assert submit_transition.validate_transition(context) == True + assert submit_transition.validate_transition(context) is True submit_data = submit_transition.transition(context) assert submit_data['annotator_confidence'] == 0.9 assert submit_data['annotation_time_seconds'] == 300 - assert submit_data['review_requested'] == True + assert submit_data['review_requested'] is True assert submit_data['annotation_complexity'] == 1 # Based on mock result # Step 2: Review and approve @@ -422,7 +422,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: target_state=review_transition.target_state, ) - assert review_transition.validate_transition(context) == True + assert review_transition.validate_transition(context) is True assert review_transition.target_state == AnnotationStateChoices.COMPLETED review_data = review_transition.transition(context) @@ -450,6 +450,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: ) import pytest + with pytest.raises(TransitionValidationError) as cm: invalid_review.validate_transition(context) @@ -504,7 +505,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: # Test StateManager.execute_transition from fsm.state_manager import StateManager - + result = StateManager.execute_transition( entity=self.task, transition_name='bulk_update_status', @@ -516,7 +517,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: }, user=self.user, project_update=True, - notification_level='high' + notification_level='high', ) # Verify the call @@ -532,11 +533,11 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: transition_data = call_kwargs['transition_data'] assert transition_data['new_status'] == TaskStateChoices.IN_PROGRESS assert transition_data['update_reason'] == 'Project priority change' - assert transition_data['updated_by_system'] == True + assert transition_data['updated_by_system'] is True assert transition_data['batch_id'] == 'batch_2024_001' # Check context - assert call_kwargs['project_update'] == True + assert call_kwargs['project_update'] is True assert call_kwargs['notification_level'] == 'high' # Check return value @@ -626,7 +627,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: target_state=valid_transition.target_state, ) - assert valid_transition.validate_transition(context) == True + assert valid_transition.validate_transition(context) is True # Test multiple validation errors invalid_transition = AssignTaskWithConstraints( @@ -636,6 +637,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: ) import pytest + with pytest.raises(TransitionValidationError) as cm: invalid_transition.validate_transition(context) @@ -661,6 +663,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: ) import pytest + with pytest.raises(TransitionValidationError) as cm: valid_transition.validate_transition(context_no_user) diff --git a/label_studio/fsm/tests/test_registry.py b/label_studio/fsm/tests/test_registry.py new file mode 100644 index 000000000000..44ec8cc42119 --- /dev/null +++ b/label_studio/fsm/tests/test_registry.py @@ -0,0 +1,202 @@ +""" +Tests for FSM registry functionality. + +Tests registry management, state model registration, transition registration, +and related error handling scenarios. +""" + +import pytest +from typing import Any, Dict +from unittest.mock import Mock, patch + +from django.test import TestCase +from fsm.registry import ( + StateModelRegistry, + TransitionRegistry, + register_state_model, + register_state_transition, + state_choices_registry, + state_model_registry, + transition_registry, +) +from fsm.state_manager import StateManager +from fsm.transitions import BaseTransition +from pydantic import Field + + +class MockEntity: + """Mock entity for testing""" + + def __init__(self, pk=1): + self.pk = pk + self.id = pk + self._meta = Mock() + self._meta.model_name = 'testentity' + self._meta.label_lower = 'tests.testentity' + self.organization_id = 1 + + +class RegistryTests(TestCase): + """Tests for registry functionality and edge cases""" + + def setUp(self): + # Clear registries to ensure clean state + state_choices_registry.clear() + state_model_registry.clear() + transition_registry.clear() + + self.entity = MockEntity() + + def test_registry_execute_transition_integration(self): + """Test TransitionRegistry.execute_transition method""" + + class SimpleTransition(BaseTransition): + """Simple transition for testing""" + + message: str = Field(default="test") + + @property + def target_state(self) -> str: + return 'COMPLETED' + + def transition(self, context): + return {'message': self.message} + + def execute(self, context): + # Create a mock state record + state_record = Mock() + state_record.id = 'test-uuid' + state_record.state = self.target_state + return state_record + + transition_registry.register('testentity', 'simple_transition', SimpleTransition) + + # Mock the StateManager methods used in TransitionRegistry + with patch.object(StateManager, 'get_current_state_object') as mock_get_state: + mock_get_state.return_value = None + + result = transition_registry.execute_transition( + entity_name='testentity', + transition_name='simple_transition', + entity=self.entity, + transition_data={'message': 'Hello'}, + user=None, + ) + + assert result is not None + assert result.state == 'COMPLETED' + + def test_registry_state_model_with_denormalizer(self): + """Test StateModelRegistry with denormalizer function""" + + mock_state_model = Mock() + mock_state_model.__name__ = 'MockStateModel' + + def test_denormalizer(entity): + return {'custom_field': f'denormalized_{entity.pk}'} + + # Register with denormalizer + state_model_registry.register_model('testentity', mock_state_model, test_denormalizer) + + # Check denormalizer was stored + denormalizer = state_model_registry.get_denormalizer('testentity') + assert denormalizer is not None + + result = denormalizer(self.entity) + assert result == {'custom_field': 'denormalized_1'} + + def test_registry_denormalizer_error_handling(self): + """Test denormalizer error handling in state model registry""" + + mock_state_model = Mock() + mock_state_model.__name__ = 'MockStateModel' + + def failing_denormalizer(entity): + raise RuntimeError("Denormalizer failed") + + state_model_registry.register_model('testentity', mock_state_model, failing_denormalizer) + + # Should handle denormalizer errors gracefully + denormalizer = state_model_registry.get_denormalizer('testentity') + with pytest.raises(RuntimeError): + denormalizer(self.entity) + + def test_registry_overwrite_warning(self): + """Test warning when overwriting existing registry entries""" + + mock_state_model1 = Mock() + mock_state_model1.__name__ = 'MockModel1' + mock_state_model2 = Mock() + mock_state_model2.__name__ = 'MockModel2' + + # Register first model + state_model_registry.register_model('testentity', mock_state_model1) + + # Register second model (should warn about overwrite) + import logging + with patch('fsm.registry.logger') as mock_logger: + state_model_registry.register_model('testentity', mock_state_model2) + + # Should have logged warning about overwrite + mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args[0][0] + assert 'Overwriting existing state model' in warning_msg + assert 'testentity' in warning_msg + assert 'Previous:' in warning_msg + assert 'New:' in warning_msg + + def test_registry_clear_methods(self): + """Test registry clear methods""" + + # Add some test data + mock_state_model = Mock() + mock_state_model.__name__ = 'MockStateModel' + state_model_registry.register_model('testentity', mock_state_model) + + class TestTransition(BaseTransition): + @property + def target_state(self) -> str: + return 'TEST' + + def transition(self, context): + return {} + + transition_registry.register('testentity', 'test_transition', TestTransition) + + # Verify data exists + assert state_model_registry.get_model('testentity') is not None + assert 'test_transition' in transition_registry.get_transitions_for_entity('testentity') + + # Clear registries + state_model_registry.clear() + transition_registry.clear() + + # Verify data is cleared + assert state_model_registry.get_model('testentity') is None + assert transition_registry.get_transitions_for_entity('testentity') == {} + + def test_registry_decorator_functions(self): + """Test decorator functions for registration""" + + # Test state model decorator + @register_state_model('decorated_entity') + class DecoratedStateModel: + pass + + # Should be registered + assert state_model_registry.get_model('decorated_entity') == DecoratedStateModel + + # Test transition decorator + @register_state_transition('decorated_entity', 'decorated_transition') + class DecoratedTransition(BaseTransition): + @property + def target_state(self) -> str: + return 'DECORATED' + + def transition(self, context): + return {} + + # Should be registered + transitions = transition_registry.get_transitions_for_entity('decorated_entity') + assert 'decorated_transition' in transitions + assert transitions['decorated_transition'] == DecoratedTransition \ No newline at end of file diff --git a/label_studio/fsm/tests/test_transitions.py b/label_studio/fsm/tests/test_transitions.py index fb7b0b72215a..a8677896cc89 100644 --- a/label_studio/fsm/tests/test_transitions.py +++ b/label_studio/fsm/tests/test_transitions.py @@ -7,13 +7,13 @@ """ from datetime import datetime, timedelta -from django.utils import timezone from typing import Any, Dict from unittest.mock import Mock, patch import pytest from django.contrib.auth import get_user_model from django.test import TestCase +from django.utils import timezone from fsm.registry import register_state_transition, transition_registry from fsm.transition_utils import get_available_transitions from fsm.transitions import ( @@ -845,4 +845,4 @@ def transition(self, context: TransitionContext) -> dict: # Invalid data should raise validation error with pytest.raises(ValidationError): # Pydantic validation error - SampleTransition() # Missing required field \ No newline at end of file + SampleTransition() # Missing required field diff --git a/label_studio/fsm/tests/test_utils.py b/label_studio/fsm/tests/test_utils.py index 5bd95fb1a646..55b18b352afd 100644 --- a/label_studio/fsm/tests/test_utils.py +++ b/label_studio/fsm/tests/test_utils.py @@ -6,8 +6,19 @@ import uuid from datetime import datetime, timedelta, timezone +from unittest.mock import Mock, patch +# Additional imports for transition_utils coverage tests +import pytest from django.test import TestCase +from fsm.registry import transition_registry +from fsm.transition_utils import ( + create_transition_from_dict, + get_available_transitions, + get_entity_state_flow, + validate_transition_data, +) +from fsm.transitions import BaseTransition, TransitionValidationError from fsm.utils import ( UUID7Generator, generate_uuid7, @@ -16,6 +27,7 @@ uuid7_time_range, validate_uuid7, ) +from pydantic import Field class TestUUID7Utils(TestCase): @@ -169,3 +181,193 @@ def test_generator_monotonic(self): # Should be monotonic even with same timestamp assert uuid1.int < uuid2.int assert uuid2.int < uuid3.int + + +class MockEntity: + """Mock entity for testing""" + + def __init__(self, pk=1): + self.pk = pk + self.id = pk + self._meta = Mock() + self._meta.model_name = 'testentity' + self._meta.label_lower = 'tests.testentity' + self.organization_id = 1 + + +class TransitionUtilsTests(TestCase): + """Tests for transition_utils module edge cases and error handling""" + + def setUp(self): + # Clear registries to ensure clean state + from fsm.registry import state_choices_registry, state_model_registry + + state_choices_registry.clear() + state_model_registry.clear() + transition_registry.clear() + + self.entity = MockEntity() + + def test_transition_utils_unexpected_validation_error(self): + """Test unexpected error during transition validation in get_available_transitions""" + + class BrokenTransition(BaseTransition): + """Transition that raises unexpected error""" + + @property + def target_state(self) -> str: + return 'BROKEN' + + def transition(self, context): + return {} + + @classmethod + def can_transition_from_state(cls, context): + # Raise unexpected error + raise RuntimeError('Unexpected validation error') + + # Register the broken transition + transition_registry.register('testentity', 'broken_transition', BrokenTransition) + + # Should handle the error gracefully and log warning + import logging + + mock_logger = Mock() + with patch.object(logging, 'getLogger', return_value=mock_logger): + result = get_available_transitions(self.entity, validate=True) + # Should not include the broken transition + assert 'broken_transition' not in result + # Should have logged the warning + mock_logger.warning.assert_called_once() + assert 'Unexpected error validating transition' in mock_logger.warning.call_args[0][0] + + def test_transition_utils_validation_error_handling(self): + """Test TransitionValidationError handling in get_available_transitions""" + + class ValidatingTransition(BaseTransition): + """Transition that raises validation error""" + + @property + def target_state(self) -> str: + return 'VALIDATED' + + def transition(self, context): + return {} + + @classmethod + def can_transition_from_state(cls, context): + raise TransitionValidationError('Not allowed from this state') + + transition_registry.register('testentity', 'validating_transition', ValidatingTransition) + + # Should exclude invalid transitions without logging + import logging + + mock_logger = Mock() + with patch.object(logging, 'getLogger', return_value=mock_logger): + result = get_available_transitions(self.entity, validate=True) + assert 'validating_transition' not in result + # Should not log for expected validation errors + mock_logger.warning.assert_not_called() + + def test_transition_utils_create_from_dict_error(self): + """Test create_transition_from_dict error handling""" + + class StrictTransition(BaseTransition): + """Transition with strict validation""" + + required_field: str = Field(...) + + @property + def target_state(self) -> str: + return 'STRICT' + + def transition(self, context): + return {'required_field': self.required_field} + + # Should raise ValueError with helpful message + with pytest.raises(ValueError) as exc_info: + create_transition_from_dict(StrictTransition, {}) + + assert 'Failed to create StrictTransition' in str(exc_info.value) + + def test_transition_utils_validate_transition_data_errors(self): + """Test validate_transition_data with various error cases""" + + class ValidationTransition(BaseTransition): + """Transition with various field types""" + + required_field: str = Field(...) + number_field: int = Field(default=0, ge=0) + + @property + def target_state(self) -> str: + return 'VALIDATED' + + def transition(self, context): + return {'required_field': self.required_field, 'number_field': self.number_field} + + # Test with missing required field + errors = validate_transition_data(ValidationTransition, {}) + assert 'required_field' in errors + assert any('required' in msg.lower() or 'missing' in msg.lower() for msg in errors['required_field']) + + # Test with invalid type + errors = validate_transition_data( + ValidationTransition, {'required_field': 123, 'number_field': 'not_a_number'} + ) + # Either required_field or number_field should have type error + assert len(errors) > 0 + + # Test with validation constraint violation + errors = validate_transition_data(ValidationTransition, {'required_field': 'test', 'number_field': -1}) + assert 'number_field' in errors + + # Test valid data returns empty dict + errors = validate_transition_data(ValidationTransition, {'required_field': 'test', 'number_field': 5}) + assert errors == {} + + def test_transition_utils_validate_with_non_pydantic_error(self): + """Test validate_transition_data with non-Pydantic errors""" + + class CustomErrorTransition(BaseTransition): + """Transition that raises custom error in __init__""" + + @property + def target_state(self) -> str: + return 'ERROR' + + def transition(self, context): + return {} + + def __init__(self, **data): + # Raise a non-ValidationError + if 'trigger_error' in data: + raise RuntimeError('Custom initialization error') + super().__init__(**data) + + errors = validate_transition_data(CustomErrorTransition, {'trigger_error': True}) + assert '__root__' in errors + assert 'Custom initialization error' in errors['__root__'][0] + + def test_transition_utils_entity_state_flow_errors(self): + """Test get_entity_state_flow with transitions that can't be instantiated""" + + class RequiredFieldTransition(BaseTransition): + """Transition requiring fields to instantiate""" + + required_field: str = Field(...) + + @property + def target_state(self) -> str: + return 'REQUIRED' + + def transition(self, context): + return {'required_field': self.required_field} + + transition_registry.register('testentity', 'required_transition', RequiredFieldTransition) + + # Should skip transitions that can't be instantiated + flows = get_entity_state_flow(self.entity) + # Should not include the transition that requires fields + assert not any(f['transition_name'] == 'required_transition' for f in flows) From 9caa7804e94babf7f2f37ea8f8bb171fa120d370 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 3 Sep 2025 16:17:04 -0500 Subject: [PATCH 49/83] fixing lint errors --- label_studio/fsm/tests/test_registry.py | 30 +++++++++++-------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/label_studio/fsm/tests/test_registry.py b/label_studio/fsm/tests/test_registry.py index 44ec8cc42119..55a1c2ae0a7d 100644 --- a/label_studio/fsm/tests/test_registry.py +++ b/label_studio/fsm/tests/test_registry.py @@ -5,14 +5,11 @@ and related error handling scenarios. """ -import pytest -from typing import Any, Dict from unittest.mock import Mock, patch +import pytest from django.test import TestCase from fsm.registry import ( - StateModelRegistry, - TransitionRegistry, register_state_model, register_state_transition, state_choices_registry, @@ -44,7 +41,7 @@ def setUp(self): state_choices_registry.clear() state_model_registry.clear() transition_registry.clear() - + self.entity = MockEntity() def test_registry_execute_transition_integration(self): @@ -53,8 +50,8 @@ def test_registry_execute_transition_integration(self): class SimpleTransition(BaseTransition): """Simple transition for testing""" - message: str = Field(default="test") - + message: str = Field(default='test') + @property def target_state(self) -> str: return 'COMPLETED' @@ -91,7 +88,7 @@ def test_registry_state_model_with_denormalizer(self): mock_state_model = Mock() mock_state_model.__name__ = 'MockStateModel' - + def test_denormalizer(entity): return {'custom_field': f'denormalized_{entity.pk}'} @@ -101,7 +98,7 @@ def test_denormalizer(entity): # Check denormalizer was stored denormalizer = state_model_registry.get_denormalizer('testentity') assert denormalizer is not None - + result = denormalizer(self.entity) assert result == {'custom_field': 'denormalized_1'} @@ -110,9 +107,9 @@ def test_registry_denormalizer_error_handling(self): mock_state_model = Mock() mock_state_model.__name__ = 'MockStateModel' - + def failing_denormalizer(entity): - raise RuntimeError("Denormalizer failed") + raise RuntimeError('Denormalizer failed') state_model_registry.register_model('testentity', mock_state_model, failing_denormalizer) @@ -133,10 +130,9 @@ def test_registry_overwrite_warning(self): state_model_registry.register_model('testentity', mock_state_model1) # Register second model (should warn about overwrite) - import logging with patch('fsm.registry.logger') as mock_logger: state_model_registry.register_model('testentity', mock_state_model2) - + # Should have logged warning about overwrite mock_logger.warning.assert_called_once() warning_msg = mock_logger.warning.call_args[0][0] @@ -152,12 +148,12 @@ def test_registry_clear_methods(self): mock_state_model = Mock() mock_state_model.__name__ = 'MockStateModel' state_model_registry.register_model('testentity', mock_state_model) - + class TestTransition(BaseTransition): @property def target_state(self) -> str: return 'TEST' - + def transition(self, context): return {} @@ -192,11 +188,11 @@ class DecoratedTransition(BaseTransition): @property def target_state(self) -> str: return 'DECORATED' - + def transition(self, context): return {} # Should be registered transitions = transition_registry.get_transitions_for_entity('decorated_entity') assert 'decorated_transition' in transitions - assert transitions['decorated_transition'] == DecoratedTransition \ No newline at end of file + assert transitions['decorated_transition'] == DecoratedTransition From 4170f561a28f96936e3d08b8c86b7393c095f4d7 Mon Sep 17 00:00:00 2001 From: bmartel Date: Wed, 3 Sep 2025 21:21:17 +0000 Subject: [PATCH 50/83] Sync Follow Merge dependencies Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/17446383898 --- poetry.lock | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index fad45354b212..c7f04b08b5b6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. [[package]] name = "annotated-types" @@ -3178,6 +3178,7 @@ files = [ {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, @@ -3797,6 +3798,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, From 7990f3ea9fad0c597c8c65b3cb7b8aee76898fb5 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Wed, 3 Sep 2025 16:24:56 -0500 Subject: [PATCH 51/83] reverting poetry.lock to develop to regenerate it on CI --- poetry.lock | 58 ++++++++++------------------------------------------- 1 file changed, 11 insertions(+), 47 deletions(-) diff --git a/poetry.lock b/poetry.lock index fad45354b212..f8cc38106a90 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.4 and should not be changed by hand. [[package]] name = "annotated-types" @@ -996,27 +996,27 @@ django = ">=4.2" [[package]] name = "djangorestframework-simplejwt" -version = "5.4.0" +version = "5.5.1" description = "A minimal JSON Web Token authentication plugin for Django REST Framework" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "djangorestframework_simplejwt-5.4.0-py3-none-any.whl", hash = "sha256:7aec953db9ed4163430c16d086eecb0f028f814ce6bba62b06c25919261e9077"}, - {file = "djangorestframework_simplejwt-5.4.0.tar.gz", hash = "sha256:cccecce1a0e1a4a240fae80da73e5fc23055bababb8b67de88fa47cd36822320"}, + {file = "djangorestframework_simplejwt-5.5.1-py3-none-any.whl", hash = "sha256:2c30f3707053d384e9f315d11c2daccfcb548d4faa453111ca19a542b732e469"}, + {file = "djangorestframework_simplejwt-5.5.1.tar.gz", hash = "sha256:e72c5572f51d7803021288e2057afcbd03f17fe11d484096f40a460abc76e87f"}, ] [package.dependencies] cryptography = {version = ">=3.3.1", optional = true, markers = "extra == \"crypto\""} django = ">=4.2" djangorestframework = ">=3.14" -pyjwt = ">=1.7.1,<3" +pyjwt = ">=1.7.1" [package.extras] crypto = ["cryptography (>=3.3.1)"] -dev = ["Sphinx (>=1.6.5,<2)", "cryptography", "flake8", "freezegun", "ipython", "isort", "pep8", "pytest", "pytest-cov", "pytest-django", "pytest-watch", "pytest-xdist", "python-jose (==3.3.0)", "sphinx_rtd_theme (>=0.1.9)", "tox", "twine", "wheel"] -doc = ["Sphinx (>=1.6.5,<2)", "sphinx_rtd_theme (>=0.1.9)"] -lint = ["flake8", "isort", "pep8"] +dev = ["Sphinx", "cryptography", "freezegun", "ipython", "pre-commit", "pytest", "pytest-cov", "pytest-django", "pytest-watch", "pytest-xdist", "python-jose (==3.3.0)", "pyupgrade", "ruff", "sphinx_rtd_theme (>=0.1.9)", "tox", "twine", "wheel", "yesqa"] +doc = ["Sphinx", "sphinx_rtd_theme (>=0.1.9)"] +lint = ["pre-commit", "pyupgrade", "ruff", "yesqa"] python-jose = ["python-jose (==3.3.0)"] test = ["cryptography", "freezegun", "pytest", "pytest-cov", "pytest-django", "pytest-xdist", "tox"] @@ -3178,6 +3178,7 @@ files = [ {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, + {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, @@ -3797,6 +3798,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4835,44 +4837,6 @@ files = [ [package.dependencies] ua-parser = ">=0.10.0" -[[package]] -name = "uuid-utils" -version = "0.11.0" -description = "Drop-in replacement for Python UUID with bindings in Rust" -optional = false -python-versions = ">=3.9" -groups = ["main"] -files = [ - {file = "uuid_utils-0.11.0-cp39-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:094445ccd323bc5507e28e9d6d86b983513efcf19ab59c2dd75239cef765631a"}, - {file = "uuid_utils-0.11.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:6430b53d343215f85269ffd74e1d1f4b25ae1031acf0ac24ff3d5721f6a06f48"}, - {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be2e6e4318d23195887fa74fa1d64565a34f7127fdcf22918954981d79765f68"}, - {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d37289ab72aa30b5550bfa64d91431c62c89e4969bdf989988aa97f918d5f803"}, - {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1012595220f945fe09641f1365a8a06915bf432cac1b31ebd262944934a9b787"}, - {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:35cd3fc718a673e4516e87afb9325558969eca513aa734515b9031d1b651bbb1"}, - {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ed325e0c40e0f59ae82b347f534df954b50cedf12bf60d025625538530e1965d"}, - {file = "uuid_utils-0.11.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5c8b7cf201990ee3140956e541967bd556a7365ec738cb504b04187ad89c757a"}, - {file = "uuid_utils-0.11.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:9966df55bed5d538ba2e9cc40115796480f437f9007727116ef99dc2f42bd5fa"}, - {file = "uuid_utils-0.11.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:cb04b6c604968424b7e6398d54debbdd5b771b39fc1e648c6eabf3f1dc20582e"}, - {file = "uuid_utils-0.11.0-cp39-abi3-win32.whl", hash = "sha256:18420eb3316bb514f09f2da15750ac135478c3a12a704e2c5fb59eab642bb255"}, - {file = "uuid_utils-0.11.0-cp39-abi3-win_amd64.whl", hash = "sha256:37c4805af61a7cce899597d34e7c3dd5cb6a8b4b93a90fbca3826b071ba544df"}, - {file = "uuid_utils-0.11.0-cp39-abi3-win_arm64.whl", hash = "sha256:4065cf17bbe97f6d8ccc7dc6a0bae7d28fd4797d7f32028a5abd979aeb7bf7c9"}, - {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:db821c98a95f9d69ebf9c442bcf764548c4c5feebd6012a881233fcdc8f47ff4"}, - {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:07cd17ecef3bfdf319d8e6583334f4c8e71d9950503b69d6722999c88a42dbe2"}, - {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1b29c4aa76586c67e865548c862b0dee98359d59eda78b58d58290dd0dd240e"}, - {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05bfd684cb55825bc5d4c340bfce3a90009e662491e7bdfd5f667a367e0a11e4"}, - {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5455b145cb6f647888f3c4fd38ec177cf51479c73c6a44503d4b7a70f45d9870"}, - {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f51374cd3280e5a8c524c51ed09901cf2268907371e1b3dc59484a92e25f070a"}, - {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:691f576327836f93102f2bf8882eb67416452bab03c3dd8c31d009c4e85dd2aa"}, - {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:912e9ae2b5c2b72bd98046ee83e1b8fa22489b4a25f44495d1c0999fa6dde237"}, - {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ce73c719e0baebc8b1652e7663bec7d4db53edbd7be1affe92b1035fc80f409b"}, - {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f7f7e0245bcedbc4ff61ad4000fd661dc93677264c0566b31010d6da0b86a63"}, - {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9930137fd6d59c681f7e013ae9343b4b9d27f7e6efce4ecb259336e15ba578b8"}, - {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6f6a306878b2327b79d65bd18d5521ef8b3775c2b03a5054b1b6f602cd876cc3"}, - {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c8346b3688b2df0baae4d3ff47cd84c765aa57cf103077e32806d66f1fcd689"}, - {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c7a7f415edb5aea38bc53057c8aff4b31d35e192f2902f6ac10f2e52d3f52ae0"}, - {file = "uuid_utils-0.11.0.tar.gz", hash = "sha256:18cf2b7083da7f3cca0517647213129eb16d20d7ed0dd74b3f4f8bff2aa334ea"}, -] - [[package]] name = "uwsgitop" version = "0.12" @@ -5073,4 +5037,4 @@ uwsgi = ["pyuwsgi", "uwsgitop"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4" -content-hash = "c3005147d2d86bd16aa870ce669ef7ecdfa5fe9730de6d0d880cee1857e7d28b" +content-hash = "b9b4a8f1036605a51b4e29458b7cd37c2e1a894501303cdbaff0591f3b8fcb46" From 81bc609220e7dd3b62f81356face85b6234e34b5 Mon Sep 17 00:00:00 2001 From: bmartel Date: Wed, 3 Sep 2025 21:28:06 +0000 Subject: [PATCH 52/83] Sync Follow Merge dependencies Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/17446529166 --- poetry.lock | 40 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index f8cc38106a90..6eb3f8e6767c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4837,6 +4837,44 @@ files = [ [package.dependencies] ua-parser = ">=0.10.0" +[[package]] +name = "uuid-utils" +version = "0.11.0" +description = "Drop-in replacement for Python UUID with bindings in Rust" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "uuid_utils-0.11.0-cp39-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:094445ccd323bc5507e28e9d6d86b983513efcf19ab59c2dd75239cef765631a"}, + {file = "uuid_utils-0.11.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:6430b53d343215f85269ffd74e1d1f4b25ae1031acf0ac24ff3d5721f6a06f48"}, + {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:be2e6e4318d23195887fa74fa1d64565a34f7127fdcf22918954981d79765f68"}, + {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d37289ab72aa30b5550bfa64d91431c62c89e4969bdf989988aa97f918d5f803"}, + {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1012595220f945fe09641f1365a8a06915bf432cac1b31ebd262944934a9b787"}, + {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:35cd3fc718a673e4516e87afb9325558969eca513aa734515b9031d1b651bbb1"}, + {file = "uuid_utils-0.11.0-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ed325e0c40e0f59ae82b347f534df954b50cedf12bf60d025625538530e1965d"}, + {file = "uuid_utils-0.11.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5c8b7cf201990ee3140956e541967bd556a7365ec738cb504b04187ad89c757a"}, + {file = "uuid_utils-0.11.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:9966df55bed5d538ba2e9cc40115796480f437f9007727116ef99dc2f42bd5fa"}, + {file = "uuid_utils-0.11.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:cb04b6c604968424b7e6398d54debbdd5b771b39fc1e648c6eabf3f1dc20582e"}, + {file = "uuid_utils-0.11.0-cp39-abi3-win32.whl", hash = "sha256:18420eb3316bb514f09f2da15750ac135478c3a12a704e2c5fb59eab642bb255"}, + {file = "uuid_utils-0.11.0-cp39-abi3-win_amd64.whl", hash = "sha256:37c4805af61a7cce899597d34e7c3dd5cb6a8b4b93a90fbca3826b071ba544df"}, + {file = "uuid_utils-0.11.0-cp39-abi3-win_arm64.whl", hash = "sha256:4065cf17bbe97f6d8ccc7dc6a0bae7d28fd4797d7f32028a5abd979aeb7bf7c9"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:db821c98a95f9d69ebf9c442bcf764548c4c5feebd6012a881233fcdc8f47ff4"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:07cd17ecef3bfdf319d8e6583334f4c8e71d9950503b69d6722999c88a42dbe2"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1b29c4aa76586c67e865548c862b0dee98359d59eda78b58d58290dd0dd240e"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:05bfd684cb55825bc5d4c340bfce3a90009e662491e7bdfd5f667a367e0a11e4"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5455b145cb6f647888f3c4fd38ec177cf51479c73c6a44503d4b7a70f45d9870"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f51374cd3280e5a8c524c51ed09901cf2268907371e1b3dc59484a92e25f070a"}, + {file = "uuid_utils-0.11.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:691f576327836f93102f2bf8882eb67416452bab03c3dd8c31d009c4e85dd2aa"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:912e9ae2b5c2b72bd98046ee83e1b8fa22489b4a25f44495d1c0999fa6dde237"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ce73c719e0baebc8b1652e7663bec7d4db53edbd7be1affe92b1035fc80f409b"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f7f7e0245bcedbc4ff61ad4000fd661dc93677264c0566b31010d6da0b86a63"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9930137fd6d59c681f7e013ae9343b4b9d27f7e6efce4ecb259336e15ba578b8"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6f6a306878b2327b79d65bd18d5521ef8b3775c2b03a5054b1b6f602cd876cc3"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c8346b3688b2df0baae4d3ff47cd84c765aa57cf103077e32806d66f1fcd689"}, + {file = "uuid_utils-0.11.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c7a7f415edb5aea38bc53057c8aff4b31d35e192f2902f6ac10f2e52d3f52ae0"}, + {file = "uuid_utils-0.11.0.tar.gz", hash = "sha256:18cf2b7083da7f3cca0517647213129eb16d20d7ed0dd74b3f4f8bff2aa334ea"}, +] + [[package]] name = "uwsgitop" version = "0.12" @@ -5037,4 +5075,4 @@ uwsgi = ["pyuwsgi", "uwsgitop"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4" -content-hash = "b9b4a8f1036605a51b4e29458b7cd37c2e1a894501303cdbaff0591f3b8fcb46" +content-hash = "c3005147d2d86bd16aa870ce669ef7ecdfa5fe9730de6d0d880cee1857e7d28b" From e207db86c64329ed1d6b65671291997fa0a58d6e Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 4 Sep 2025 10:10:36 -0500 Subject: [PATCH 53/83] updating docs --- label_studio/fsm/README.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md index 995027e68da1..d2c7e2fbd9b6 100644 --- a/label_studio/fsm/README.md +++ b/label_studio/fsm/README.md @@ -144,10 +144,6 @@ current_state = StateManager.get_current_state(order) # Get state history history = StateManager.get_state_history(order, limit=10) - -# Bulk operations for performance -orders = Order.objects.all()[:1000] -states = StateManager.bulk_get_current_states(orders) ``` ## Key Features From bf969de150bd4e351e3d3dc2302ef194838a81c6 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 4 Sep 2025 10:13:01 -0500 Subject: [PATCH 54/83] removing unused code --- label_studio/core/settings/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/label_studio/core/settings/base.py b/label_studio/core/settings/base.py index 11b0124fcde7..2f70fdbea96b 100644 --- a/label_studio/core/settings/base.py +++ b/label_studio/core/settings/base.py @@ -894,5 +894,4 @@ def collect_versions_dummy(**kwargs): DM_MAX_USERS_TO_DISPLAY = int(get_env('DM_MAX_USERS_TO_DISPLAY', 10)) # Base FSM (Finite State Machine) Configuration for Label Studio -FSM_CACHE_TTL = 300 # Cache TTL in seconds (5 minutes) -FSM_AUTO_CREATE_STATES = False +FSM_CACHE_TTL = 300 # Cache TTL in seconds (5 minutes) \ No newline at end of file From e9d145b6d50fcc0ca7e469b9e11649da78b4dd87 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 4 Sep 2025 12:11:06 -0500 Subject: [PATCH 55/83] applying feedback from CR --- label_studio/core/settings/base.py | 2 +- label_studio/fsm/models.py | 36 ++++++++++++++-- label_studio/fsm/registry.py | 69 ++++++------------------------ label_studio/fsm/state_manager.py | 21 ++++----- 4 files changed, 54 insertions(+), 74 deletions(-) diff --git a/label_studio/core/settings/base.py b/label_studio/core/settings/base.py index 2f70fdbea96b..0fe164623fae 100644 --- a/label_studio/core/settings/base.py +++ b/label_studio/core/settings/base.py @@ -894,4 +894,4 @@ def collect_versions_dummy(**kwargs): DM_MAX_USERS_TO_DISPLAY = int(get_env('DM_MAX_USERS_TO_DISPLAY', 10)) # Base FSM (Finite State Machine) Configuration for Label Studio -FSM_CACHE_TTL = 300 # Cache TTL in seconds (5 minutes) \ No newline at end of file +FSM_CACHE_TTL = 300 # Cache TTL in seconds (5 minutes) diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index f68d49ef3cbd..8677000d1924 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -3,11 +3,11 @@ """ from datetime import datetime -from typing import Optional +from typing import Any, Dict, Optional from django.conf import settings from django.db import models -from django.db.models import UUIDField +from django.db.models import QuerySet, UUIDField from fsm.utils import UUID7Field, generate_uuid7, timestamp_from_uuid7 @@ -135,13 +135,13 @@ def get_current_state_value(cls, entity) -> Optional[str]: return current_state.state if current_state else None @classmethod - def get_state_history(cls, entity, limit: int = 100): + def get_state_history(cls, entity, limit: int = 100) -> QuerySet['BaseState']: """Get complete state history for an entity""" entity_field = f'{cls._get_entity_field_name()}' return cls.objects.filter(**{entity_field: entity}).order_by('-id')[:limit] @classmethod - def get_states_in_range(cls, entity, start_time: datetime, end_time: datetime): + def get_states_in_range(cls, entity, start_time: datetime, end_time: datetime) -> QuerySet['BaseState']: """ Efficient time-range queries using UUID7. @@ -159,6 +159,34 @@ def get_states_since(cls, entity, since: datetime): queryset = cls.objects.filter(**{entity_field: entity}) return UUID7Field.filter_since_time(queryset, since).order_by('id') + @classmethod + def get_denormalized_fields(cls, entity) -> Dict[str, Any]: + """ + Get denormalized fields to include in the state record. + + Override this method in subclasses to provide denormalized data + that should be stored with each state transition for performance + optimization and auditing purposes. + + Args: + entity: The entity instance being transitioned + + Returns: + Dictionary of field names to values that should be stored + in the state record + + Example: + @classmethod + def get_denormalized_fields(cls, entity): + return { + 'project_id': entity.project_id, + 'organization_id': entity.project.organization_id, + 'task_type': entity.task_type, + 'priority': entity.priority + } + """ + return {} + @classmethod def _get_entity_field_name(cls) -> str: """Get the foreign key field name for the entity""" diff --git a/label_studio/fsm/registry.py b/label_studio/fsm/registry.py index f1115016f954..31c2538f0d74 100644 --- a/label_studio/fsm/registry.py +++ b/label_studio/fsm/registry.py @@ -7,7 +7,7 @@ import logging import typing -from typing import Any, Callable, Dict, Optional, Type +from typing import Any, Dict, Optional, Type from django.db.models import Model, TextChoices @@ -113,21 +113,14 @@ class StateModelRegistry: def __init__(self): self._models: Dict[str, 'BaseState'] = {} - self._denormalizers: Dict[str, Callable[[Model], Dict[str, Any]]] = {} - def register_model( - self, - entity_name: str, - state_model: 'BaseState', - denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None, - ): + def register_model(self, entity_name: str, state_model: 'BaseState'): """ Register a state model for an entity type. Args: entity_name: Name of the entity (e.g., 'task', 'annotation') state_model: The state model class for this entity - denormalizer: Optional function to extract denormalized fields """ entity_key = entity_name.lower() @@ -138,10 +131,6 @@ def register_model( ) self._models[entity_key] = state_model - - if denormalizer: - self._denormalizers[entity_key] = denormalizer - logger.debug(f'Registered state model for {entity_key}: {state_model.__name__}') def get_model(self, entity_name: str) -> Optional['BaseState']: @@ -156,39 +145,6 @@ def get_model(self, entity_name: str) -> Optional['BaseState']: """ return self._models.get(entity_name.lower()) - def get_denormalizer(self, entity_name: str) -> Optional[Callable]: - """ - Get the denormalization function for an entity type. - - Args: - entity_name: Name of the entity - - Returns: - Denormalizer function or None if not registered - """ - return self._denormalizers.get(entity_name.lower()) - - def get_denormalized_fields(self, entity: Model) -> Dict[str, Any]: - """ - Get denormalized fields for an entity. - - Args: - entity: The entity instance - - Returns: - Dictionary of denormalized fields - """ - entity_name = entity._meta.model_name.lower() - denormalizer = self._denormalizers.get(entity_name) - - if denormalizer: - try: - return denormalizer(entity) - except Exception as e: - logger.error(f'Error getting denormalized fields for {entity_name}: {e}') - - return {} - def is_registered(self, entity_name: str) -> bool: """Check if a model is registered for an entity type.""" return entity_name.lower() in self._models @@ -196,8 +152,6 @@ def is_registered(self, entity_name: str) -> bool: def clear(self): """Clear all registered models (useful for testing).""" self._models.clear() - self._denormalizers.clear() - self._initialized = False logger.debug('Cleared state model registry') def get_all_models(self) -> Dict[str, 'BaseState']: @@ -209,39 +163,40 @@ def get_all_models(self) -> Dict[str, 'BaseState']: state_model_registry = StateModelRegistry() -def register_state_model(entity_name: str, denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None): +def register_state_model(entity_name: str): """ Decorator to register a state model. Args: entity_name: Name of the entity (e.g., 'task', 'annotation') - denormalizer: Optional function to extract denormalized fields Example: @register_state_model('task') class TaskState(BaseState): - # ... implementation + @classmethod + def get_denormalized_fields(cls, entity): + return { + 'project_id': entity.project_id, + 'priority': entity.priority + } """ def decorator(state_model: 'BaseState') -> 'BaseState': - state_model_registry.register_model(entity_name, state_model, denormalizer) + state_model_registry.register_model(entity_name, state_model) return state_model return decorator -def register_state_model_class( - entity_name: str, state_model: 'BaseState', denormalizer: Optional[Callable[[Model], Dict[str, Any]]] = None -): +def register_state_model_class(entity_name: str, state_model: 'BaseState'): """ Convenience function to register a state model programmatically. Args: entity_name: Name of the entity (e.g., 'task', 'annotation') state_model: The state model class for this entity - denormalizer: Optional function to extract denormalized fields """ - state_model_registry.register_model(entity_name, state_model, denormalizer) + state_model_registry.register_model(entity_name, state_model) def get_state_model(entity_name: str) -> Optional['BaseState']: diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index 17bb4f0d26f4..8daf9818e84f 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -12,7 +12,7 @@ from django.conf import settings from django.core.cache import cache from django.db import transaction -from django.db.models import Model +from django.db.models import Model, QuerySet from fsm.models import BaseState from fsm.registry import get_state_model_for_entity @@ -188,10 +188,8 @@ def transition_state( try: with transaction.atomic(): # INSERT-only approach - no UPDATE operations needed - # Get denormalized fields from the state model itself - denormalized_fields = {} - if hasattr(state_model, 'get_denormalized_fields'): - denormalized_fields = state_model.get_denormalized_fields(entity) + # Get denormalized fields from the state model class + denormalized_fields = state_model.get_denormalized_fields(entity) # Get organization from user's active organization organization_id = ( @@ -233,7 +231,7 @@ def transition_state( raise StateManagerError(f'Failed to transition state: {e}') from e @classmethod - def get_state_history(cls, entity: Model, limit: int = 100) -> List[BaseState]: + def get_state_history(cls, entity: Model, limit: int = 100) -> QuerySet[BaseState]: """ Get complete state history for an entity. @@ -242,14 +240,13 @@ def get_state_history(cls, entity: Model, limit: int = 100) -> List[BaseState]: limit: Maximum number of state records to return Returns: - List of state records ordered by most recent first + QuerySet of state records ordered by most recent first """ state_model = get_state_model_for_entity(entity) if not state_model: - return [] + raise StateManagerError(f'No state model registered for {entity._meta.model_name}') - entity_field = f'{entity._meta.model_name}' - return list(state_model.objects.filter(**{entity_field: entity}).order_by('-id')[:limit]) + return state_model.get_state_history(entity, limit) @classmethod def get_states_in_time_range( @@ -268,9 +265,9 @@ def get_states_in_time_range( """ state_model = get_state_model_for_entity(entity) if not state_model: - return [] + raise StateManagerError(f'No state model registered for {entity._meta.model_name}') - return list(state_model.get_states_in_range(entity, start_time, end_time or datetime.now())) + return state_model.get_states_in_range(entity, start_time, end_time or datetime.now()) @classmethod def invalidate_cache(cls, entity: Model): From 89cbd6eea5d7efbeb37c38ff7cce0a8f54729e9e Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 4 Sep 2025 12:35:19 -0500 Subject: [PATCH 56/83] fixing tests --- label_studio/fsm/tests/test_registry.py | 38 ++++++++++++++----------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/label_studio/fsm/tests/test_registry.py b/label_studio/fsm/tests/test_registry.py index 55a1c2ae0a7d..20ba2755e51f 100644 --- a/label_studio/fsm/tests/test_registry.py +++ b/label_studio/fsm/tests/test_registry.py @@ -84,39 +84,43 @@ def execute(self, context): assert result.state == 'COMPLETED' def test_registry_state_model_with_denormalizer(self): - """Test StateModelRegistry with denormalizer function""" + """Test StateModelRegistry with state model that has get_denormalized_fields""" mock_state_model = Mock() mock_state_model.__name__ = 'MockStateModel' - def test_denormalizer(entity): - return {'custom_field': f'denormalized_{entity.pk}'} + # Mock the get_denormalized_fields classmethod + mock_state_model.get_denormalized_fields = Mock(return_value={'custom_field': 'denormalized_1'}) - # Register with denormalizer - state_model_registry.register_model('testentity', mock_state_model, test_denormalizer) + # Register the model (no denormalizer parameter anymore) + state_model_registry.register_model('testentity', mock_state_model) - # Check denormalizer was stored - denormalizer = state_model_registry.get_denormalizer('testentity') - assert denormalizer is not None + # Check model was registered + registered_model = state_model_registry.get_model('testentity') + assert registered_model is not None + assert registered_model == mock_state_model - result = denormalizer(self.entity) + # Test that get_denormalized_fields works on the model + result = mock_state_model.get_denormalized_fields(self.entity) assert result == {'custom_field': 'denormalized_1'} def test_registry_denormalizer_error_handling(self): - """Test denormalizer error handling in state model registry""" + """Test error handling when get_denormalized_fields raises an exception""" mock_state_model = Mock() mock_state_model.__name__ = 'MockStateModel' - def failing_denormalizer(entity): - raise RuntimeError('Denormalizer failed') + # Mock get_denormalized_fields to raise an error + mock_state_model.get_denormalized_fields = Mock(side_effect=RuntimeError('Denormalizer failed')) + + # Register the model + state_model_registry.register_model('testentity', mock_state_model) - state_model_registry.register_model('testentity', mock_state_model, failing_denormalizer) + # Test that the error is propagated correctly + with pytest.raises(RuntimeError) as exc_info: + mock_state_model.get_denormalized_fields(self.entity) - # Should handle denormalizer errors gracefully - denormalizer = state_model_registry.get_denormalizer('testentity') - with pytest.raises(RuntimeError): - denormalizer(self.entity) + assert 'Denormalizer failed' in str(exc_info.value) def test_registry_overwrite_warning(self): """Test warning when overwriting existing registry entries""" From 3319e90990023579f6297096bd0274466558d727 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 4 Sep 2025 14:57:55 -0500 Subject: [PATCH 57/83] structure the logging so it can be made useful for metrics in grafana --- label_studio/fsm/registry.py | 25 ++- label_studio/fsm/state_manager.py | 239 +++++++++++++++------------ label_studio/fsm/transition_utils.py | 14 +- 3 files changed, 164 insertions(+), 114 deletions(-) diff --git a/label_studio/fsm/registry.py b/label_studio/fsm/registry.py index 31c2538f0d74..b5308f294e44 100644 --- a/label_studio/fsm/registry.py +++ b/label_studio/fsm/registry.py @@ -125,13 +125,25 @@ def register_model(self, entity_name: str, state_model: 'BaseState'): entity_key = entity_name.lower() if entity_key in self._models: - logger.warning( - f'Overwriting existing state model for {entity_key}. ' - f'Previous: {self._models[entity_key]}, New: {state_model}' + logger.debug( + 'Overwriting existing state model', + extra={ + 'event': 'fsm.registry_overwrite', + 'entity_type': entity_key, + 'previous_model': self._models[entity_key].__name__, + 'new_model': state_model.__name__, + }, ) self._models[entity_key] = state_model - logger.debug(f'Registered state model for {entity_key}: {state_model.__name__}') + logger.debug( + 'Registered state model', + extra={ + 'event': 'fsm.model_registered', + 'entity_type': entity_key, + 'model_name': state_model.__name__, + }, + ) def get_model(self, entity_name: str) -> Optional['BaseState']: """ @@ -152,7 +164,10 @@ def is_registered(self, entity_name: str) -> bool: def clear(self): """Clear all registered models (useful for testing).""" self._models.clear() - logger.debug('Cleared state model registry') + logger.debug( + 'State model registry cleared', + extra={'event': 'fsm.registry_cleared'}, + ) def get_all_models(self) -> Dict[str, 'BaseState']: """Get all registered models.""" diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index 8daf9818e84f..8d875737ee66 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -7,7 +7,7 @@ import logging from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type from django.conf import settings from django.core.cache import cache @@ -16,10 +16,6 @@ from fsm.models import BaseState from fsm.registry import get_state_model_for_entity -# Avoid circular import -if TYPE_CHECKING: - from fsm.transitions import BaseTransition - logger = logging.getLogger(__name__) @@ -70,7 +66,9 @@ def get_current_state(cls, entity: Model) -> Optional[str]: entity: The entity to get current state for Returns: - Current state string or None if no states exist + Current state string + Raises: + StateManagerError: If no state model found Example: task = Task.objects.get(id=123) @@ -84,37 +82,54 @@ def get_current_state(cls, entity: Model) -> Optional[str]: # Try cache first cached_state = cache.get(cache_key) if cached_state is not None: - logger.debug(f'Cache hit for {entity._meta.label_lower} {entity.pk}: {cached_state}') + logger.info( + 'FSM cache hit', + extra={ + 'event': 'fsm.cache_hit', + 'entity_type': entity._meta.label_lower, + 'entity_id': entity.pk, + 'state': cached_state, + }, + ) return cached_state # Query database using state model registry state_model = get_state_model_for_entity(entity) if not state_model: - logger.warning(f'No state model found for {entity._meta.model_name}') - return None + raise StateManagerError(f'No state model found for {entity._meta.model_name} when getting current state') try: - entity_field = f'{entity._meta.model_name}' - current_state = ( - state_model.objects.filter(**{entity_field: entity}) - .order_by('-id') # UUID7 natural ordering - .values_list('state', flat=True) - .first() - ) + current_state = state_model.get_current_state(entity) # Cache result if current_state is not None: cache.set(cache_key, current_state, cls.CACHE_TTL) + logger.info( + 'FSM cache miss', + extra={ + 'event': 'fsm.cache_miss', + 'entity_type': entity._meta.label_lower, + 'entity_id': entity.pk, + }, + ) - logger.debug(f'Database query for {entity._meta.label_lower} {entity.pk}: {current_state}') return current_state except Exception as e: - logger.error(f'Error getting current state for {entity._meta.label_lower} {entity.pk}: {e}') + logger.error( + 'Error getting current state', + extra={ + 'event': 'fsm.get_state_error', + 'entity_type': entity._meta.label_lower, + 'entity_id': entity.pk, + 'error': str(e), + }, + exc_info=True, + ) return None @classmethod - def get_current_state_object(cls, entity: Model) -> Optional[BaseState]: + def get_current_state_object(cls, entity: Model) -> BaseState: """ Get current state object with full audit information. @@ -122,14 +137,18 @@ def get_current_state_object(cls, entity: Model) -> Optional[BaseState]: entity: The entity to get current state object for Returns: - Latest BaseState instance or None if no states exist + Latest BaseState instance + + Raises: + StateManagerError: If no state model found """ state_model = get_state_model_for_entity(entity) if not state_model: - return None + raise StateManagerError( + f'No state model found for {entity._meta.model_name} when getting current state object' + ) - entity_field = f'{entity._meta.model_name}' - return state_model.objects.filter(**{entity_field: entity}).order_by('-id').first() + return state_model.get_current_state() @classmethod def transition_state( @@ -176,26 +195,38 @@ def transition_state( """ state_model = get_state_model_for_entity(entity) if not state_model: - raise StateManagerError(f'No state model found for {entity._meta.model_name}') + raise StateManagerError(f'No state model found for {entity._meta.model_name} when transitioning state') current_state = cls.get_current_state(entity) - logger.info( - f'Transitioning {entity._meta.label_lower} {entity.pk}: ' - f'{current_state} → {new_state} (transition: {transition_name})' - ) - try: with transaction.atomic(): # INSERT-only approach - no UPDATE operations needed # Get denormalized fields from the state model class denormalized_fields = state_model.get_denormalized_fields(entity) - # Get organization from user's active organization - organization_id = ( - user.active_organization.id - if user and hasattr(user, 'active_organization') and user.active_organization - else None + # Get organization from entity or denormalized fields, or user's active organization + organization_id = getattr( + entity, 'organization_id', getattr(denormalized_fields, 'organization_id', None) + ) + + if not organization_id and user and hasattr(user, 'active_organization') and user.active_organization: + organization_id = user.active_organization.id + + logger.info( + 'State transition starting', + extra={ + 'event': 'fsm.transition_state_start', + 'entity_type': entity._meta.label_lower, + 'entity_id': entity.pk, + 'from_state': current_state, + 'to_state': new_state, + 'transition_name': transition_name, + **{ + 'user_id': user.id if user else None, + 'organization_id': organization_id if organization_id else None, + }, + }, ) new_state_record = state_model.objects.create( @@ -212,11 +243,37 @@ def transition_state( # Update cache with new state after transaction commits cache_key = cls.get_cache_key(entity) - transaction.on_commit(lambda: cache.set(cache_key, new_state, cls.CACHE_TTL)) + + def update_cache(key, state, user_id, org_id): + cache.set(key, state, cls.CACHE_TTL) + logger.info( + 'Cache updated for transition state', + extra={ + 'event': 'fsm.transition_state_cache_updated', + 'entity_type': entity._meta.label_lower, + 'entity_id': entity.pk, + 'state': state, + **{'user_id': user_id if user_id else None, 'organization_id': org_id if org_id else None}, + }, + ) + + transaction.on_commit( + lambda: update_cache(cache_key, new_state, user.id if user else None, organization_id) + ) logger.info( - f'State transition successful: {entity._meta.label_lower} {entity.pk} ' - f'now in state {new_state} (record ID: {new_state_record.id})' + 'State transition successful', + extra={ + 'event': 'fsm.transition_state_success', + 'entity_type': entity._meta.label_lower, + 'entity_id': entity.pk, + 'state': new_state, + 'state_record_id': str(new_state_record.id), + **{ + 'user_id': user.id if user else None, + 'organization_id': organization_id if organization_id else None, + }, + }, ) return True @@ -225,8 +282,20 @@ def transition_state( cache_key = cls.get_cache_key(entity) cache.delete(cache_key) logger.error( - f'State transition failed for {entity._meta.label_lower} {entity.pk}: ' - f'{current_state} → {new_state}: {e}' + 'State transition failed', + extra={ + 'event': 'fsm.transition_state_failed', + 'entity_type': entity._meta.label_lower, + 'entity_id': entity.pk, + 'from_state': current_state, + 'to_state': new_state, + 'error': str(e), + **{ + 'user_id': user.id if user else None, + 'organization_id': organization_id if organization_id else None, + }, + }, + exc_info=True, ) raise StateManagerError(f'Failed to transition state: {e}') from e @@ -244,7 +313,9 @@ def get_state_history(cls, entity: Model, limit: int = 100) -> QuerySet[BaseStat """ state_model = get_state_model_for_entity(entity) if not state_model: - raise StateManagerError(f'No state model registered for {entity._meta.model_name}') + raise StateManagerError( + f'No state model registered for {entity._meta.model_name} when getting state history' + ) return state_model.get_state_history(entity, limit) @@ -265,7 +336,9 @@ def get_states_in_time_range( """ state_model = get_state_model_for_entity(entity) if not state_model: - raise StateManagerError(f'No state model registered for {entity._meta.model_name}') + raise StateManagerError( + f'No state model registered for {entity._meta.model_name} when getting states in time range' + ) return state_model.get_states_in_range(entity, start_time, end_time or datetime.now()) @@ -274,7 +347,16 @@ def invalidate_cache(cls, entity: Model): """Invalidate cached state for an entity""" cache_key = cls.get_cache_key(entity) cache.delete(cache_key) - logger.debug(f'Invalidated cache for {entity._meta.label_lower} {entity.pk}') + organization_id = getattr(entity, 'organization_id', None) + logger.info( + 'Cache invalidated', + extra={ + 'event': 'fsm.cache_invalidated', + 'entity_type': entity._meta.label_lower, + 'entity_id': entity.pk, + **{'organization_id': organization_id if organization_id else None}, + }, + ) @classmethod def warm_cache(cls, entities: List[Model]): @@ -285,7 +367,11 @@ def warm_cache(cls, entities: List[Model]): bulk queries and advanced caching strategies. """ cache_updates = {} + organization_id = None for entity in entities: + if organization_id is None: + if hasattr(entity, 'organization_id'): + organization_id = entity.organization_id current_state = cls.get_current_state(entity) if current_state: cache_key = cls.get_cache_key(entity) @@ -293,71 +379,14 @@ def warm_cache(cls, entities: List[Model]): if cache_updates: cache.set_many(cache_updates, cls.CACHE_TTL) - logger.debug(f'Warmed cache for {len(cache_updates)} entities') - - @classmethod - def execute_declarative_transition( - cls, transition: 'BaseTransition', entity: Model, user=None, **context_kwargs - ) -> BaseState: - """ - Execute a declarative Pydantic-based transition. - - This method integrates the new declarative transition system with - the existing StateManager, providing a bridge between the two approaches. - - Args: - transition: Instance of a BaseTransition subclass - entity: The entity to transition - user: User executing the transition - **context_kwargs: Additional context data - - Returns: - The newly created state record - - Raises: - TransitionValidationError: If transition validation fails - StateManagerError: If transition execution fails - """ - from .transitions import TransitionContext - - # Get current state information - current_state_object = cls.get_current_state_object(entity) - current_state = current_state_object.state if current_state_object else None - - # Build transition context - context = TransitionContext( - entity=entity, - current_user=user, - current_state_object=current_state_object, - current_state=current_state, - target_state=transition.target_state, - organization_id=getattr(entity, 'organization_id', None), - **context_kwargs, - ) - - logger.info( - f'Executing declarative transition {transition.__class__.__name__} ' - f'for {entity._meta.label_lower} {entity.pk}: ' - f'{current_state} → {transition.target_state}' - ) - - try: - # Execute the transition through the declarative system - state_record = transition.execute(context) - logger.info( - f'Declarative transition successful: {entity._meta.label_lower} {entity.pk} ' - f'now in state {transition.target_state} (record ID: {state_record.id})' - ) - - return state_record - - except Exception as e: - logger.error( - f'Declarative transition failed for {entity._meta.label_lower} {entity.pk}: ' - f'{current_state} → {transition.target_state}: {e}' + 'Cache warmed', + extra={ + 'event': 'fsm.cache_warmed', + 'entity_count': len(cache_updates), + **{'organization_id': organization_id if organization_id else None}, + }, ) - raise @classmethod def execute_transition( diff --git a/label_studio/fsm/transition_utils.py b/label_studio/fsm/transition_utils.py index 682259a85e06..c098ace57dea 100644 --- a/label_studio/fsm/transition_utils.py +++ b/label_studio/fsm/transition_utils.py @@ -5,6 +5,7 @@ the new Pydantic-based transition system with existing Label Studio code. """ +import logging from typing import Any, Dict, List, Type from django.db.models import Model @@ -12,6 +13,8 @@ from fsm.state_manager import StateManager from fsm.transitions import BaseTransition, TransitionValidationError +logger = logging.getLogger(__name__) + def get_available_transitions(entity: Model, user=None, validate: bool = False) -> Dict[str, Type[BaseTransition]]: """ @@ -78,11 +81,14 @@ def get_available_transitions(entity: Model, user=None, validate: bool = False) continue except Exception as e: # Unexpected error during validation - this should be investigated - import logging - - logger = logging.getLogger(__name__) logger.warning( - f"Unexpected error validating transition '{name}' for entity {entity._meta.model_name}: {e}", + 'Unexpected error validating transition', + extra={ + 'event': 'fsm.transition_validation_error', + 'transition_name': name, + 'entity_type': entity._meta.model_name, + 'error': str(e), + }, exc_info=True, ) continue From 4d42dcb95151e12e1c3c93b1a6a91b7693df8e63 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 4 Sep 2025 15:05:24 -0500 Subject: [PATCH 58/83] fix tests --- label_studio/fsm/state_manager.py | 9 ++++++++- .../fsm/tests/test_fsm_integration.py | 6 +++--- label_studio/fsm/tests/test_registry.py | 19 ++++++++++++------- label_studio/fsm/tests/test_utils.py | 5 +---- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index 8d875737ee66..2987ad997c60 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -96,6 +96,13 @@ def get_current_state(cls, entity: Model) -> Optional[str]: # Query database using state model registry state_model = get_state_model_for_entity(entity) if not state_model: + logger.error( + 'No state model found', + extra={ + 'event': 'fsm.state_model_not_found', + 'entity_type': entity._meta.model_name, + }, + ) raise StateManagerError(f'No state model found for {entity._meta.model_name} when getting current state') try: @@ -148,7 +155,7 @@ def get_current_state_object(cls, entity: Model) -> BaseState: f'No state model found for {entity._meta.model_name} when getting current state object' ) - return state_model.get_current_state() + return state_model.get_current_state(entity) @classmethod def transition_state( diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py index 4d97faad38c1..dabb5c09ecfe 100644 --- a/label_studio/fsm/tests/test_fsm_integration.py +++ b/label_studio/fsm/tests/test_fsm_integration.py @@ -347,9 +347,9 @@ def test_transaction_on_commit_failure_case(self, mock_get_state_model, mock_on_ # Verify transaction.on_commit was NOT called since transition failed assert mock_on_commit.call_count == 0 - # Verify cache was not updated (should still be None) - current_state = self.StateManager.get_current_state(self.task) - assert current_state is None + # Verify cache was not updated (should raise exception) + with pytest.raises(Exception): # Should raise StateManagerError + self.StateManager.get_current_state(self.task) @patch('django.db.transaction.on_commit') @patch('fsm.models.TaskState.objects.create') diff --git a/label_studio/fsm/tests/test_registry.py b/label_studio/fsm/tests/test_registry.py index 20ba2755e51f..74568c65e32b 100644 --- a/label_studio/fsm/tests/test_registry.py +++ b/label_studio/fsm/tests/test_registry.py @@ -137,13 +137,18 @@ def test_registry_overwrite_warning(self): with patch('fsm.registry.logger') as mock_logger: state_model_registry.register_model('testentity', mock_state_model2) - # Should have logged warning about overwrite - mock_logger.warning.assert_called_once() - warning_msg = mock_logger.warning.call_args[0][0] - assert 'Overwriting existing state model' in warning_msg - assert 'testentity' in warning_msg - assert 'Previous:' in warning_msg - assert 'New:' in warning_msg + # Should have logged debug about overwrite + mock_logger.debug.assert_called() + # Find the call that has the overwrite message + debug_calls = mock_logger.debug.call_args_list + overwrite_call = None + for call in debug_calls: + if 'Overwriting existing state model' in call[0][0]: + overwrite_call = call + break + assert overwrite_call is not None, 'Expected debug log about overwriting existing state model' + debug_msg = overwrite_call[0][0] + assert 'Overwriting existing state model' in debug_msg def test_registry_clear_methods(self): """Test registry clear methods""" diff --git a/label_studio/fsm/tests/test_utils.py b/label_studio/fsm/tests/test_utils.py index 55b18b352afd..7a2e4f798621 100644 --- a/label_studio/fsm/tests/test_utils.py +++ b/label_studio/fsm/tests/test_utils.py @@ -230,10 +230,7 @@ def can_transition_from_state(cls, context): transition_registry.register('testentity', 'broken_transition', BrokenTransition) # Should handle the error gracefully and log warning - import logging - - mock_logger = Mock() - with patch.object(logging, 'getLogger', return_value=mock_logger): + with patch('fsm.transition_utils.logger') as mock_logger: result = get_available_transitions(self.entity, validate=True) # Should not include the broken transition assert 'broken_transition' not in result From d251c6119d4ac735e938c2428cd2f9d079ed52c4 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 4 Sep 2025 15:24:25 -0500 Subject: [PATCH 59/83] fix tests --- label_studio/fsm/state_manager.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index 2987ad997c60..f31504268604 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -58,7 +58,7 @@ def get_cache_key(cls, entity: Model) -> str: return f'{cls.CACHE_PREFIX}:{entity._meta.label_lower}:{entity.pk}' @classmethod - def get_current_state(cls, entity: Model) -> Optional[str]: + def get_current_state_value(cls, entity: Model) -> Optional[str]: """ Get current state with basic caching. @@ -72,7 +72,7 @@ def get_current_state(cls, entity: Model) -> Optional[str]: Example: task = Task.objects.get(id=123) - current_state = StateManager.get_current_state(task) + current_state = StateManager.get_current_state_value(task) if current_state == 'COMPLETED': # Task is finished pass @@ -96,17 +96,10 @@ def get_current_state(cls, entity: Model) -> Optional[str]: # Query database using state model registry state_model = get_state_model_for_entity(entity) if not state_model: - logger.error( - 'No state model found', - extra={ - 'event': 'fsm.state_model_not_found', - 'entity_type': entity._meta.model_name, - }, - ) raise StateManagerError(f'No state model found for {entity._meta.model_name} when getting current state') try: - current_state = state_model.get_current_state(entity) + current_state = state_model.get_current_state_value(entity) # Cache result if current_state is not None: @@ -133,7 +126,7 @@ def get_current_state(cls, entity: Model) -> Optional[str]: }, exc_info=True, ) - return None + raise StateManagerError(f'Error getting current state: {e}') from e @classmethod def get_current_state_object(cls, entity: Model) -> BaseState: From 9120a18e6d940c24d155d3dade7e55ce6763b17a Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 4 Sep 2025 15:28:22 -0500 Subject: [PATCH 60/83] fix implementation --- label_studio/fsm/state_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index 8d875737ee66..80a1747855f4 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -148,7 +148,7 @@ def get_current_state_object(cls, entity: Model) -> BaseState: f'No state model found for {entity._meta.model_name} when getting current state object' ) - return state_model.get_current_state() + return state_model.get_current_state(entity) @classmethod def transition_state( From 9d3fb7f7e1e9fc1c8c5677e206da8b0b8736f825 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Thu, 4 Sep 2025 16:47:28 -0500 Subject: [PATCH 61/83] fix implementation --- label_studio/fsm/state_manager.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index f31504268604..f7315c72cef2 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -128,6 +128,11 @@ def get_current_state_value(cls, entity: Model) -> Optional[str]: ) raise StateManagerError(f'Error getting current state: {e}') from e + @classmethod + def get_current_state(cls, entity: Model) -> Optional[str]: + """Backward compatibility method - calls get_current_state_value""" + return cls.get_current_state_value(entity) + @classmethod def get_current_state_object(cls, entity: Model) -> BaseState: """ From e9dfe8922dd9b81187ee33b5da0211bcfa276f68 Mon Sep 17 00:00:00 2001 From: bmartel Date: Fri, 5 Sep 2025 14:46:14 +0000 Subject: [PATCH 62/83] Sync Follow Merge dependencies Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/17496376340 --- poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 5eed783d2c90..c28e9e6d3739 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5075,4 +5075,4 @@ uwsgi = ["pyuwsgi", "uwsgitop"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4" -content-hash = "c3005147d2d86bd16aa870ce669ef7ecdfa5fe9730de6d0d880cee1857e7d28b" +content-hash = "b943e7ed6a37673055a32ec57543fbbe8e36c2335460f4942569b0f6b0ec319f" From e34002c7e25ba16b775640e76994c7a90ac20798 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Fri, 5 Sep 2025 10:19:13 -0500 Subject: [PATCH 63/83] fix initial migration --- label_studio/fsm/migrations/0001_initial.py | 1 + 1 file changed, 1 insertion(+) diff --git a/label_studio/fsm/migrations/0001_initial.py b/label_studio/fsm/migrations/0001_initial.py index d09e6a4769c7..623bacb73d74 100644 --- a/label_studio/fsm/migrations/0001_initial.py +++ b/label_studio/fsm/migrations/0001_initial.py @@ -7,6 +7,7 @@ class Migration(migrations.Migration): + atomic = False initial = True From 6f21e6b768fde509dd1bd548aeb3cdfcfb8cf36e Mon Sep 17 00:00:00 2001 From: bmartel Date: Fri, 5 Sep 2025 15:22:33 +0000 Subject: [PATCH 64/83] Sync Follow Merge dependencies Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/17497319557 From db1ed5f72a84754fbc62ab29127f083fac2abccf Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Fri, 5 Sep 2025 12:06:52 -0500 Subject: [PATCH 65/83] refactor code to remove function imports --- label_studio/fsm/registry.py | 77 +--------------- label_studio/fsm/state_manager.py | 37 -------- label_studio/fsm/transition_executor.py | 118 ++++++++++++++++++++++++ label_studio/fsm/transition_utils.py | 11 +++ label_studio/fsm/transitions.py | 75 +++++++++------ 5 files changed, 181 insertions(+), 137 deletions(-) create mode 100644 label_studio/fsm/transition_executor.py diff --git a/label_studio/fsm/registry.py b/label_studio/fsm/registry.py index b5308f294e44..5311bba468fd 100644 --- a/label_studio/fsm/registry.py +++ b/label_studio/fsm/registry.py @@ -7,13 +7,13 @@ import logging import typing -from typing import Any, Dict, Optional, Type +from typing import Dict, Optional, Type from django.db.models import Model, TextChoices if typing.TYPE_CHECKING: from fsm.models import BaseState - from fsm.transitions import BaseTransition, User + from fsm.transitions import BaseTransition logger = logging.getLogger(__name__) @@ -295,67 +295,12 @@ def clear(self): """ self._transitions.clear() - def execute_transition( - self, - entity_name: str, - transition_name: str, - entity: Model, - transition_data: Dict[str, Any], - user: Optional['User'] = None, - **context_kwargs, - ) -> 'BaseState': - """ - Execute a registered transition. - - Args: - entity_name: Name of the entity type - transition_name: Name of the transition - entity: The entity instance to transition - transition_data: Data for the transition (will be validated by Pydantic) - user: User executing the transition - **context_kwargs: Additional context data - - Returns: - The newly created state record - - Raises: - ValueError: If transition is not found - TransitionValidationError: If transition validation fails - """ - transition_class = self.get_transition(entity_name, transition_name) - if not transition_class: - raise ValueError(f"Transition '{transition_name}' not found for entity '{entity_name}'") - - # Create transition instance with provided data - transition = transition_class(**transition_data) - - # Get current state information - from fsm.state_manager import StateManager - from fsm.transitions import TransitionContext - - current_state_object = StateManager.get_current_state_object(entity) - current_state = current_state_object.state if current_state_object else None - - # Build transition context - context = TransitionContext( - entity=entity, - current_user=user, - current_state_object=current_state_object, - current_state=current_state, - target_state=transition.target_state, - organization_id=getattr(entity, 'organization_id', None), - **context_kwargs, - ) - - # Execute the transition - return transition.execute(context) - # Global transition registry instance transition_registry = TransitionRegistry() -def register_state_transition(entity_name: str, transition_name: str = None): +def register_state_transition(entity_name: str, transition_name: str): """ Decorator to register a state transition class. @@ -370,21 +315,7 @@ class StartTaskTransition(BaseTransition[Task, TaskState]): """ def decorator(transition_class: 'BaseTransition') -> 'BaseTransition': - name = transition_name - if name is None: - # Generate name from class name - class_name = transition_class.__name__ - if class_name.endswith('Transition'): - class_name = class_name[:-10] # Remove 'Transition' suffix - - # Convert CamelCase to snake_case - name = '' - for i, char in enumerate(class_name): - if char.isupper() and i > 0: - name += '_' - name += char.lower() - - transition_registry.register(entity_name, name, transition_class) + transition_registry.register(entity_name, transition_name, transition_class) return transition_class return decorator diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index 80a1747855f4..0ce8d0b835b4 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -388,43 +388,6 @@ def warm_cache(cls, entities: List[Model]): }, ) - @classmethod - def execute_transition( - cls, entity: Model, transition_name: str, transition_data: Dict[str, Any] = None, user=None, **context_kwargs - ) -> BaseState: - """ - Execute a registered transition by name. - - This is the unified entry point for all state transitions using the declarative system. - - Args: - entity: The entity to transition - transition_name: Name of the registered transition - transition_data: Data for the transition (validated by Pydantic) - user: User executing the transition - **context_kwargs: Additional context data - - Returns: - The newly created state record - - Raises: - ValueError: If transition is not found - TransitionValidationError: If transition validation fails - """ - from .registry import transition_registry - - entity_name = entity._meta.model_name.lower() - transition_data = transition_data or {} - - return transition_registry.execute_transition( - entity_name=entity_name, - transition_name=transition_name, - entity=entity, - transition_data=transition_data, - user=user, - **context_kwargs, - ) - # Allow runtime configuration of which StateManager to use # Enterprise can set this to their extended implementation diff --git a/label_studio/fsm/transition_executor.py b/label_studio/fsm/transition_executor.py new file mode 100644 index 000000000000..afa0acb22a92 --- /dev/null +++ b/label_studio/fsm/transition_executor.py @@ -0,0 +1,118 @@ +""" +Transition execution orchestrator for the FSM engine. + +This module serves as the top-level orchestrator for state transitions, +importing from both state_manager and transitions to coordinate execution +without creating circular dependencies. + +No other FSM modules should import from this module. +""" + +import logging +from typing import Any, Dict + +from django.db.models import Model +from fsm.models import BaseState +from fsm.registry import get_state_model_for_entity, transition_registry +from fsm.state_manager import StateManager +from fsm.transitions import TransitionContext + +logger = logging.getLogger(__name__) + + +def execute_transition( + entity: Model, + transition_name: str, + transition_data: Dict[str, Any] = None, + user=None, + **context_kwargs, +) -> BaseState: + """ + Execute a registered transition by name. + + This is the main entry point for executing state transitions in the FSM system. + It coordinates between the registry, transitions, and state manager without + creating circular dependencies. + + Args: + entity: The entity to transition + transition_name: Name of the registered transition + transition_data: Data for the transition (validated by Pydantic) + user: User executing the transition + **context_kwargs: Additional context data + + Returns: + The newly created state record + + Raises: + ValueError: If transition is not found or state model is not registered + TransitionValidationError: If transition validation fails + """ + entity_name = entity._meta.model_name.lower() + transition_data = transition_data or {} + + # Get the transition class from registry + transition_class = transition_registry.get_transition(entity_name, transition_name) + if not transition_class: + raise ValueError(f"Transition '{transition_name}' not found for entity '{entity_name}'") + + # Get the state model for the entity + state_model = get_state_model_for_entity(entity) + if not state_model: + raise ValueError(f"No state model registered for entity '{entity_name}'") + + # Create transition instance with provided data + transition = transition_class(**transition_data) + + # Get current state information directly from state model + current_state_object = state_model.get_current_state(entity) + current_state = current_state_object.state if current_state_object else None + + # Build transition context + context = TransitionContext( + entity=entity, + current_user=user, + current_state_object=current_state_object, + current_state=current_state, + target_state=transition.target_state, + organization_id=getattr(entity, 'organization_id', None), + **context_kwargs, + ) + + logger.info( + 'Executing transition', + extra={ + 'event': 'fsm.transition_execute', + 'entity_type': entity_name, + 'entity_id': entity.pk, + 'transition_name': transition_name, + 'from_state': current_state, + 'to_state': transition.target_state, + 'user_id': user.id if user else None, + }, + ) + + # Execute the transition in phases + # Phase 1: Prepare and validate the transition + transition_context_data = transition.prepare_and_validate(context) + + # Phase 2: Create the state record via StateManager + success = StateManager.transition_state( + entity=entity, + new_state=transition.target_state, + transition_name=transition.transition_name, + user=user, + context=transition_context_data, + reason=transition.get_reason(context), + ) + + if not success: + raise ValueError(f'Failed to create state record for {transition_name}') + + # Get the newly created state record + state_record = StateManager.get_current_state_object(entity) + + # Phase 3: Finalize the transition + transition.finalize(context, state_record) + + return state_record diff --git a/label_studio/fsm/transition_utils.py b/label_studio/fsm/transition_utils.py index c098ace57dea..5712bfa7df97 100644 --- a/label_studio/fsm/transition_utils.py +++ b/label_studio/fsm/transition_utils.py @@ -11,10 +11,21 @@ from django.db.models import Model from fsm.registry import transition_registry from fsm.state_manager import StateManager +from fsm.transition_executor import execute_transition from fsm.transitions import BaseTransition, TransitionValidationError logger = logging.getLogger(__name__) +# Re-export execute_transition for convenience +__all__ = [ + 'execute_transition', + 'get_available_transitions', + 'create_transition_from_dict', + 'get_transition_schema', + 'validate_transition_data', + 'get_entity_state_flow', +] + def get_available_transitions(entity: Model, user=None, validate: bool = False) -> Dict[str, Type[BaseTransition]]: """ diff --git a/label_studio/fsm/transitions.py b/label_studio/fsm/transitions.py index 340979f929fa..a65449ace125 100644 --- a/label_studio/fsm/transitions.py +++ b/label_studio/fsm/transitions.py @@ -258,27 +258,24 @@ def get_reason(self, context: TransitionContext[EntityType, StateModelType]) -> user_info = f'by {context.current_user}' if context.current_user else 'automatically' return f'{self.__class__.__name__} executed {user_info}' - def execute(self, context: TransitionContext[EntityType, StateModelType]) -> StateModelType: + def prepare_and_validate(self, context: TransitionContext[EntityType, StateModelType]) -> Dict[str, Any]: """ - Execute the complete transition workflow. + Prepare and validate the transition, returning the transition data. - This orchestrates the entire transition process: + This method handles the preparation phase of the transition: 1. Set context on the transition instance 2. Validate the transition 3. Execute pre-transition hooks - 4. Perform the actual transition - 5. Create the state record - 6. Execute post-transition hooks + 4. Perform the actual transition logic Args: context: The transition context Returns: - The newly created state record + Dictionary of transition data to be stored with the state record Raises: TransitionValidationError: If validation fails - Exception: If transition execution fails """ # Set context for access during transition self.context = context @@ -300,30 +297,54 @@ def execute(self, context: TransitionContext[EntityType, StateModelType]) -> Sta # Execute the transition logic transition_data = self.transition(context) - # Create the state record through StateManager - from .state_manager import StateManager + return transition_data - success = StateManager.transition_state( - entity=context.entity, - new_state=self.target_state, - transition_name=self.transition_name, - user=context.current_user, - context=transition_data, - reason=self.get_reason(context), - ) + except Exception: + # Clear context on error + self.context = None + raise - if not success: - raise TransitionValidationError(f'Failed to create state record for {self.transition_name}') + def finalize(self, context: TransitionContext[EntityType, StateModelType], state_record: StateModelType) -> None: + """ + Finalize the transition after the state record has been created. - # Get the newly created state record - state_record = StateManager.get_current_state_object(context.entity) + This method handles post-transition activities: + 1. Execute post-transition hooks + 2. Clear the context + Args: + context: The transition context + state_record: The newly created state record + """ + try: # Post-transition hook self.post_transition_hook(context, state_record) + finally: + # Always clear context when done + self.context = None - return state_record + def execute(self, context: TransitionContext[EntityType, StateModelType]) -> StateModelType: + """ + Execute the complete transition workflow. - except Exception: - # Clear context on error - self.context = None - raise + NOTE: This method is provided for backward compatibility but should not be called + directly to avoid circular imports. Use TransitionExecutor.execute() instead. + + The actual execution flow is: + 1. TransitionExecutor calls prepare_and_validate() + 2. TransitionExecutor creates the state record via StateManager + 3. TransitionExecutor calls finalize() + + Args: + context: The transition context + + Returns: + The newly created state record + + Raises: + NotImplementedError: This method should not be called directly + """ + raise NotImplementedError( + 'Direct execution of transitions is not supported to avoid circular imports. ' + 'Use TransitionExecutor.execute() or StateManager.execute_transition() instead.' + ) From 69bee71b870aa8fe99a2943d7cf82e0741a9abe9 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Fri, 5 Sep 2025 12:34:23 -0500 Subject: [PATCH 66/83] refactor code to remove redundant code --- label_studio/fsm/state_manager.py | 40 +++++++++++++++++++++++++ label_studio/fsm/transition_executor.py | 34 ++++++++++----------- label_studio/fsm/transition_utils.py | 15 ++-------- 3 files changed, 59 insertions(+), 30 deletions(-) diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index 0ce8d0b835b4..299f6eff6d2d 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -15,6 +15,7 @@ from django.db.models import Model, QuerySet from fsm.models import BaseState from fsm.registry import get_state_model_for_entity +from fsm.transition_executor import execute_transition_with_state_manager logger = logging.getLogger(__name__) @@ -388,10 +389,45 @@ def warm_cache(cls, entities: List[Model]): }, ) + @classmethod + def execute_transition( + cls, entity: Model, transition_name: str, transition_data: Dict[str, Any] = None, user=None, **context_kwargs + ) -> BaseState: + """ + Execute a registered transition by name. + + This is the main entry point for all state transitions using the declarative system. + Enterprise implementations can override this method to add additional behavior. + + Args: + entity: The entity to transition + transition_name: Name of the registered transition + transition_data: Data for the transition (validated by Pydantic) + user: User executing the transition + **context_kwargs: Additional context data + + Returns: + The newly created state record + + Raises: + ValueError: If transition is not found + TransitionValidationError: If transition validation fails + """ + # Delegate to transition executor, passing StateManager methods as parameters + return execute_transition_with_state_manager( + entity=entity, + transition_name=transition_name, + transition_data=transition_data, + user=user, + state_manager_class=cls, + **context_kwargs, + ) + # Allow runtime configuration of which StateManager to use # Enterprise can set this to their extended implementation DEFAULT_STATE_MANAGER = StateManager +RESOLVED_STATE_MANAGER = None def get_state_manager() -> Type[StateManager]: @@ -401,6 +437,10 @@ def get_state_manager() -> Type[StateManager]: Returns the StateManager class to use. Enterprise can override this by setting a different class in their configuration. """ + # Resolve once + if RESOLVED_STATE_MANAGER is not None: + return RESOLVED_STATE_MANAGER + # Check if enterprise has configured a custom state manager if hasattr(settings, 'FSM_STATE_MANAGER_CLASS'): manager_path = settings.FSM_STATE_MANAGER_CLASS diff --git a/label_studio/fsm/transition_executor.py b/label_studio/fsm/transition_executor.py index afa0acb22a92..ff3a8ed5cf99 100644 --- a/label_studio/fsm/transition_executor.py +++ b/label_studio/fsm/transition_executor.py @@ -1,44 +1,42 @@ """ Transition execution orchestrator for the FSM engine. -This module serves as the top-level orchestrator for state transitions, -importing from both state_manager and transitions to coordinate execution -without creating circular dependencies. - -No other FSM modules should import from this module. +This module handles the execution of state transitions, coordinating between +the registry and transitions without importing StateManager to avoid circular dependencies. +StateManager imports from this module and provides its methods as parameters. """ import logging -from typing import Any, Dict +from typing import Any, Dict, Type from django.db.models import Model from fsm.models import BaseState from fsm.registry import get_state_model_for_entity, transition_registry -from fsm.state_manager import StateManager from fsm.transitions import TransitionContext logger = logging.getLogger(__name__) -def execute_transition( +def execute_transition_with_state_manager( entity: Model, transition_name: str, - transition_data: Dict[str, Any] = None, - user=None, + transition_data: Dict[str, Any], + user, + state_manager_class: Type, **context_kwargs, ) -> BaseState: """ - Execute a registered transition by name. + Execute a registered transition using StateManager methods passed as parameters. - This is the main entry point for executing state transitions in the FSM system. - It coordinates between the registry, transitions, and state manager without - creating circular dependencies. + This function is called by StateManager.execute_transition() to avoid circular imports. + StateManager imports this module and passes itself as a parameter. Args: entity: The entity to transition transition_name: Name of the registered transition transition_data: Data for the transition (validated by Pydantic) user: User executing the transition + state_manager_class: The StateManager class to use for state operations **context_kwargs: Additional context data Returns: @@ -96,8 +94,8 @@ def execute_transition( # Phase 1: Prepare and validate the transition transition_context_data = transition.prepare_and_validate(context) - # Phase 2: Create the state record via StateManager - success = StateManager.transition_state( + # Phase 2: Create the state record via StateManager methods + success = state_manager_class.transition_state( entity=entity, new_state=transition.target_state, transition_name=transition.transition_name, @@ -109,8 +107,8 @@ def execute_transition( if not success: raise ValueError(f'Failed to create state record for {transition_name}') - # Get the newly created state record - state_record = StateManager.get_current_state_object(entity) + # Get the newly created state record via StateManager + state_record = state_manager_class.get_current_state_object(entity) # Phase 3: Finalize the transition transition.finalize(context, state_record) diff --git a/label_studio/fsm/transition_utils.py b/label_studio/fsm/transition_utils.py index 5712bfa7df97..854a515f0f89 100644 --- a/label_studio/fsm/transition_utils.py +++ b/label_studio/fsm/transition_utils.py @@ -10,21 +10,12 @@ from django.db.models import Model from fsm.registry import transition_registry -from fsm.state_manager import StateManager -from fsm.transition_executor import execute_transition +from fsm.state_manager import get_state_manager from fsm.transitions import BaseTransition, TransitionValidationError logger = logging.getLogger(__name__) -# Re-export execute_transition for convenience -__all__ = [ - 'execute_transition', - 'get_available_transitions', - 'create_transition_from_dict', - 'get_transition_schema', - 'validate_transition_data', - 'get_entity_state_flow', -] +StateManager = get_state_manager() def get_available_transitions(entity: Model, user=None, validate: bool = False) -> Dict[str, Type[BaseTransition]]: @@ -53,7 +44,7 @@ def get_available_transitions(entity: Model, user=None, validate: bool = False) for name, transition_class in available.items(): try: - # Get current state information + # Get current state information using potentially overridden StateManager current_state_object = StateManager.get_current_state_object(entity) current_state = current_state_object.state if current_state_object else None From 9e6ade79ac0345d179f18fbc8624f3623b845d28 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Fri, 5 Sep 2025 12:39:09 -0500 Subject: [PATCH 67/83] remove no longer required test --- label_studio/fsm/tests/test_registry.py | 39 ------------------------- 1 file changed, 39 deletions(-) diff --git a/label_studio/fsm/tests/test_registry.py b/label_studio/fsm/tests/test_registry.py index 74568c65e32b..0f229568b618 100644 --- a/label_studio/fsm/tests/test_registry.py +++ b/label_studio/fsm/tests/test_registry.py @@ -44,45 +44,6 @@ def setUp(self): self.entity = MockEntity() - def test_registry_execute_transition_integration(self): - """Test TransitionRegistry.execute_transition method""" - - class SimpleTransition(BaseTransition): - """Simple transition for testing""" - - message: str = Field(default='test') - - @property - def target_state(self) -> str: - return 'COMPLETED' - - def transition(self, context): - return {'message': self.message} - - def execute(self, context): - # Create a mock state record - state_record = Mock() - state_record.id = 'test-uuid' - state_record.state = self.target_state - return state_record - - transition_registry.register('testentity', 'simple_transition', SimpleTransition) - - # Mock the StateManager methods used in TransitionRegistry - with patch.object(StateManager, 'get_current_state_object') as mock_get_state: - mock_get_state.return_value = None - - result = transition_registry.execute_transition( - entity_name='testentity', - transition_name='simple_transition', - entity=self.entity, - transition_data={'message': 'Hello'}, - user=None, - ) - - assert result is not None - assert result.state == 'COMPLETED' - def test_registry_state_model_with_denormalizer(self): """Test StateModelRegistry with state model that has get_denormalized_fields""" From 25e655e2e78b796ae03c72433ac6517690289bef Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Fri, 5 Sep 2025 13:04:32 -0500 Subject: [PATCH 68/83] updating docs --- label_studio/fsm/README.md | 284 ++++++++++++++------------------ label_studio/fsm/transitions.py | 28 +--- 2 files changed, 125 insertions(+), 187 deletions(-) diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md index d2c7e2fbd9b6..5f3a624f4a7f 100644 --- a/label_studio/fsm/README.md +++ b/label_studio/fsm/README.md @@ -9,9 +9,8 @@ The FSM framework provides: - **Core Infrastructure**: Abstract base state models and managers - **UUID7 Optimization**: Time-series optimized state records with natural ordering - **Declarative Transitions**: Pydantic-based transition system with validation -- **REST API**: Generic endpoints for state management - **High Performance**: Optimized for high-volume state changes with caching -- **Extensible**: Plugin-based architecture for custom implementations +- **Extensible**: StateManager serves as the main extension point ## Architecture @@ -21,6 +20,20 @@ The FSM framework provides: 2. **StateManager**: High-performance state management with intelligent caching 3. **Transition System**: Declarative, Pydantic-based transitions with validation 4. **State Registry**: Dynamic registration system for entity states, choices and transitions +5. **TransitionExecutor**: Orchestrates transition execution + +### Module Architecture + +``` +models.py → registry.py → transitions.py + transition_executor.py → state_manager.py → transition_utils.py +``` + +- **models.py**: Base state model and UUID7 utilities +- **registry.py**: Registries for state models and transitions +- **transitions.py**: BaseTransition class and validation logic +- **transition_executor.py**: Orchestrates transition execution +- **state_manager.py**: **Main entry point** - implementations extend this +- **transition_utils.py**: Convenience functions using `get_state_manager()` ## Quick Start @@ -40,22 +53,13 @@ class OrderStateChoices(models.TextChoices): CANCELLED = 'CANCELLED', _('Cancelled') ``` -### 2. Create State Model with Optional Denormalizer +### 2. Create State Model ```python from fsm.models import BaseState from fsm.registry import register_state_model -# Optional: Define denormalizer for performance optimization -def denormalize_order(entity): - """Extract frequently queried fields to avoid JOINs.""" - return { - 'customer_id': entity.customer_id, - 'store_id': entity.store_id, - 'total_amount': entity.total_amount, - } - -@register_state_model('order', denormalizer=denormalize_order) +@register_state_model('order') class OrderState(BaseState): # Entity relationship order = models.ForeignKey('shop.Order', related_name='fsm_states', on_delete=models.CASCADE) @@ -63,35 +67,27 @@ class OrderState(BaseState): # Override state field with choices state = models.CharField(max_length=50, choices=OrderStateChoices.choices, db_index=True) - # Denormalized fields for performance (automatically populated by denormalizer) + # Denormalized fields for performance customer_id = models.PositiveIntegerField(db_index=True) store_id = models.PositiveIntegerField(db_index=True) total_amount = models.DecimalField(max_digits=10, decimal_places=2) - class Meta: - indexes = [ - models.Index(fields=['order_id', '-id'], name='order_current_state_idx'), - ] -``` - -### 3. Alternative: Use Built-in State Model Methods - -For simpler use cases, state models can define denormalization directly: - -```python -class OrderState(BaseState): - # ... fields ... - @classmethod def get_denormalized_fields(cls, entity): - """Built-in method for denormalization without registry.""" + """Extract frequently queried fields to avoid JOINs.""" return { 'customer_id': entity.customer_id, 'store_id': entity.store_id, + 'total_amount': entity.total_amount, } + + class Meta: + indexes = [ + models.Index(fields=['order_id', '-id'], name='order_current_state_idx'), + ] ``` -### 4. Define Transitions +### 3. Define Transitions ```python from fsm.transitions import BaseTransition @@ -118,12 +114,15 @@ class ProcessOrderTransition(BaseTransition): } ``` -### 5. Execute Transitions +### 4. Execute Transitions + +**Main API - StateManager (Extensible):** ```python from fsm.state_manager import StateManager -# Execute transition - this is the only way to execute transitions +# This is the primary way to execute transitions +# implementations extend StateManager.execute_transition() result = StateManager.execute_transition( entity=order, transition_name='process_order', @@ -132,7 +131,7 @@ result = StateManager.execute_transition( ) ``` -### 6. Query States +### 5. Query States ```python from fsm.state_manager import get_state_manager @@ -144,35 +143,44 @@ current_state = StateManager.get_current_state(order) # Get state history history = StateManager.get_state_history(order, limit=10) -``` -## Key Features +# Get current state object (full details) +current_state_obj = StateManager.get_current_state_object(order) +``` -### Denormalization for Performance +## Extensibility -- **Avoid JOINs**: Copy frequently queried fields to state records -- **Registry-based**: Register denormalizers with state models -- **Automatic**: Fields are populated during state transitions -- **Flexible**: Use registry decorator or built-in class method +**StateManager is the main extension point for implementations:** ```python -# Using registry decorator -@register_state_model('task', denormalizer=lambda t: {'project_id': t.project_id}) -class TaskState(BaseState): - project_id = models.IntegerField(db_index=True) - # ... - -# Using built-in method -class TaskState(BaseState): +from fsm.state_manager import StateManager + +class MyStateManager(StateManager): @classmethod - def get_denormalized_fields(cls, entity): - return {'project_id': entity.project_id} + def execute_transition(cls, entity, transition_name, **kwargs): + # Add specific pre-processing + cls.log_audit(entity, transition_name) + + # Call parent implementation + result = super().execute_transition(entity, transition_name, **kwargs) + + # Add specific post-processing + cls.notify_systems(result) + + return result + +# Configure in Django settings +FSM_STATE_MANAGER_CLASS = 'myapp.managers.MyStateManager' ``` +The `get_state_manager()` function ensures all FSM operations use the correct implementation. + +## Key Features + ### UUID7 Performance Optimization - **Natural Time Ordering**: UUID7 provides chronological ordering without separate timestamp indexes -- **High Concurrency**: INSERT-only approach eliminates locking contention +- **High Concurrency**: INSERT-only approach eliminates locking contention - **Scalability**: Supports large amounts of state records with consistent performance ### Declarative Transitions @@ -181,28 +189,42 @@ class TaskState(BaseState): - **Composable Logic**: Reusable transition classes with inheritance - **Hooks System**: Pre/post transition hooks for custom logic -### Advanced State Manager Features +```python +@register_state_transition('order', 'ship_order') +class ShipOrderTransition(BaseTransition): + tracking_number: str = Field(..., description="Shipping tracking number") + carrier: str = Field(..., description="Shipping carrier") + + @property + def target_state(self) -> str: + return OrderStateChoices.SHIPPED + + def pre_transition_hook(self, context): + # Called before state change + self.validate_inventory(context.entity) + + def transition(self, context) -> dict: + return { + "tracking_number": self.tracking_number, + "carrier": self.carrier, + "shipped_at": context.timestamp.isoformat() + } + + def post_transition_hook(self, context, state_record): + # Called after state change + self.send_shipping_notification(context.entity, state_record) +``` + +### Advanced State Management ```python # Time-range queries using UUID7 from datetime import datetime, timedelta -recent_states = StateManager.get_states_since( +recent_states = StateManager.get_states_in_time_range( entity=order, - since=datetime.now() - timedelta(hours=24) + start_time=datetime.now() - timedelta(hours=24) ) -# Get current state object (not just string) -current_state_obj = StateManager.get_current_state_object(order) -if current_state_obj: - print(f"State: {current_state_obj.state}") - print(f"Since: {current_state_obj.created_at}") - print(f"By: {current_state_obj.triggered_by}") - -# Get state history with full objects -history = StateManager.get_state_history(order, limit=10) -for state in history: - print(f"{state.state} at {state.created_at}") - # Cache management StateManager.invalidate_cache(order) # Clear cache for entity StateManager.warm_cache([order1, order2, order3]) # Pre-populate cache @@ -210,141 +232,83 @@ StateManager.warm_cache([order1, order2, order3]) # Pre-populate cache ### Registry System -The FSM uses a flexible registry pattern for decoupling: - ```python from fsm.registry import ( state_model_registry, - state_choices_registry, transition_registry, register_state_model, register_state_choices, register_state_transition, ) -# Register state choices -@register_state_choices('task') -class TaskStateChoices(models.TextChoices): - # ... - -# Register state model with denormalizer -@register_state_model('task', denormalizer=denormalize_task) -class TaskState(BaseState): - # ... - -# Register transitions -@register_state_transition('task', 'start_task') -class StartTaskTransition(BaseTransition): - # ... - -# Access registries directly -model = state_model_registry.get_model('task') -choices = state_choices_registry.get_choices('task') -transition = transition_registry.get_transition('task', 'start_task') -``` - -## Performance Characteristics - -- **State Queries**: O(1) current state lookup via UUID7 ordering -- **History Queries**: Optimal for time-series access patterns -- **Bulk Operations**: Efficient batch processing for thousands of entities -- **Cache Integration**: Intelligent caching with automatic invalidation -- **Memory Efficiency**: Minimal memory footprint for state objects - -## Transition System Features - -### Transition Context - -```python -from fsm.transitions import TransitionContext - -# Context provides rich information during transitions -context = TransitionContext( - entity=task, - current_user=user, - current_state='CREATED', - target_state='IN_PROGRESS', - organization_id=org_id, - metadata={'source': 'api', 'priority': 'high'} -) - -# Context properties -if context.is_initial_transition: - # First state for this entity - pass -if context.has_current_state: - # Entity has existing state - pass +# Access registries directly if needed +model = state_model_registry.get_model('order') +transition = transition_registry.get_transition('order', 'process_order') ``` ### Transition Utilities ```python -from fsm.state_manager import StateManager from fsm.transition_utils import ( get_available_transitions, get_transition_schema, validate_transition_data, ) -# Execute a transition - the only way to execute transitions -result = StateManager.execute_transition( - entity=task, - transition_name='start_task', - transition_data={'assigned_user_id': 123}, - user=request.user -) +# Get all transitions for an entity +all_transitions = get_available_transitions(order) # Get available transitions for an entity -available = get_available_transitions(task) +available_transitions = get_available_transitions(order, validate=True) # Get JSON schema for transition (useful for APIs) -schema = get_transition_schema(StartTaskTransition) +schema = get_transition_schema(ProcessOrderTransition) # Validate transition data before execution -errors = validate_transition_data(StartTaskTransition, data) +errors = validate_transition_data(ProcessOrderTransition, data) ``` -## Extension Points +## Performance Characteristics -### Custom State Manager +- **State Queries**: O(1) current state lookup via UUID7 ordering +- **History Queries**: Optimal for time-series access patterns +- **Bulk Operations**: Efficient batch processing for thousands of entities +- **Cache Integration**: Intelligent caching with automatic invalidation +- **Memory Efficiency**: Minimal memory footprint for state objects -```python -from fsm.state_manager import BaseStateManager +## Architecture Benefits -class CustomStateManager(BaseStateManager): - def get_current_state(self, entity): - # Custom logic - return super().get_current_state(entity) -``` +### Clean Import Hierarchy -### Custom Validation +- **No Circular Dependencies**: Unidirectional import flow prevents import cycles +- **No Function-Level Imports**: All imports at module level for clarity and performance +- **Extension Friendly**: StateManager as central extension point -```python -@register_state_transition('order', 'validate_payment') -class PaymentValidationTransition(BaseTransition): - def validate_transition(self, context) -> bool: - # Custom business logic - return self.check_payment_method(context.entity) -``` +### Separation of Concerns -## Framework vs Implementation +- **Registry**: Pure storage and retrieval (no execution logic) +- **Transitions**: Validation and business logic (no state management) +- **StateManager**: State persistence and caching (main entry point) +- **TransitionExecutor**: Orchestrates execution (takes StateManager as parameter) -This is the **core framework** - a clean, generic FSM system. Product-specific implementations (state definitions, concrete models, business logic) should be in separate branches/modules for: +## Migration from Other FSM Libraries -- **Clean Architecture**: Framework logic separated from business logic -- **Reusability**: Framework can be used across different projects -- **Maintainability**: Changes to business logic don't affect framework -- **Review Process**: Framework and implementation can be reviewed independently +The framework is designed to be compatible with existing Django FSM patterns while offering significant performance improvements through UUID7 optimization and clean architecture. -## Migration from Other FSM Libraries +## Best Practices -The framework provides migration utilities and is designed to be compatible with existing Django FSM patterns while offering significant performance improvements through UUID7 optimization. +1. **Always use StateManager.execute_transition()** +2. **Extend StateManager** for further customizations +3. **Use denormalized fields** for frequently queried data +4. **Leverage UUID7 ordering** for time-series queries +5. **Implement proper validation** in transition classes +6. **Use the registry decorators** for clean registration ## Contributing When contributing: +- Maintain the clean import hierarchy - Keep framework code generic and reusable -- Add product-specific code to appropriate implementation branches -- Include performance tests for UUID7 optimizations +- Add performance tests for UUID7 optimizations - Document extension points and customization options +- Ensure extensibility is preserved \ No newline at end of file diff --git a/label_studio/fsm/transitions.py b/label_studio/fsm/transitions.py index a65449ace125..c39a8ff037dd 100644 --- a/label_studio/fsm/transitions.py +++ b/label_studio/fsm/transitions.py @@ -321,30 +321,4 @@ def finalize(self, context: TransitionContext[EntityType, StateModelType], state self.post_transition_hook(context, state_record) finally: # Always clear context when done - self.context = None - - def execute(self, context: TransitionContext[EntityType, StateModelType]) -> StateModelType: - """ - Execute the complete transition workflow. - - NOTE: This method is provided for backward compatibility but should not be called - directly to avoid circular imports. Use TransitionExecutor.execute() instead. - - The actual execution flow is: - 1. TransitionExecutor calls prepare_and_validate() - 2. TransitionExecutor creates the state record via StateManager - 3. TransitionExecutor calls finalize() - - Args: - context: The transition context - - Returns: - The newly created state record - - Raises: - NotImplementedError: This method should not be called directly - """ - raise NotImplementedError( - 'Direct execution of transitions is not supported to avoid circular imports. ' - 'Use TransitionExecutor.execute() or StateManager.execute_transition() instead.' - ) + self.context = None \ No newline at end of file From 9c6ee3447e74bfe0bd328100df0de9539de7381c Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Fri, 5 Sep 2025 13:09:22 -0500 Subject: [PATCH 69/83] fixing lint errors --- label_studio/fsm/tests/test_registry.py | 2 -- label_studio/fsm/transitions.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/label_studio/fsm/tests/test_registry.py b/label_studio/fsm/tests/test_registry.py index 0f229568b618..a2d498ffae2c 100644 --- a/label_studio/fsm/tests/test_registry.py +++ b/label_studio/fsm/tests/test_registry.py @@ -16,9 +16,7 @@ state_model_registry, transition_registry, ) -from fsm.state_manager import StateManager from fsm.transitions import BaseTransition -from pydantic import Field class MockEntity: diff --git a/label_studio/fsm/transitions.py b/label_studio/fsm/transitions.py index c39a8ff037dd..ba9cbb6e2df9 100644 --- a/label_studio/fsm/transitions.py +++ b/label_studio/fsm/transitions.py @@ -321,4 +321,4 @@ def finalize(self, context: TransitionContext[EntityType, StateModelType], state self.post_transition_hook(context, state_record) finally: # Always clear context when done - self.context = None \ No newline at end of file + self.context = None From 0dcbcde73036ee38a16e11e97dc6d1742367f790 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Fri, 5 Sep 2025 13:09:53 -0500 Subject: [PATCH 70/83] fixing lint errors --- label_studio/fsm/transitions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/label_studio/fsm/transitions.py b/label_studio/fsm/transitions.py index c39a8ff037dd..ba9cbb6e2df9 100644 --- a/label_studio/fsm/transitions.py +++ b/label_studio/fsm/transitions.py @@ -321,4 +321,4 @@ def finalize(self, context: TransitionContext[EntityType, StateModelType], state self.post_transition_hook(context, state_record) finally: # Always clear context when done - self.context = None \ No newline at end of file + self.context = None From 0f48c67f892cea375ca1adc762cdbc563e5672bb Mon Sep 17 00:00:00 2001 From: bmartel Date: Mon, 8 Sep 2025 14:45:42 +0000 Subject: [PATCH 71/83] Sync Follow Merge dependencies Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/17554612979 From f0756a0f700cd1bda5f965b4448bee04e7935719 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Mon, 8 Sep 2025 10:06:12 -0500 Subject: [PATCH 72/83] fixing tests for windows SQLITE --- label_studio/fsm/tests/test_fsm_integration.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py index dabb5c09ecfe..165170ad2ba8 100644 --- a/label_studio/fsm/tests/test_fsm_integration.py +++ b/label_studio/fsm/tests/test_fsm_integration.py @@ -4,7 +4,7 @@ and API endpoints. """ -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from unittest.mock import patch import pytest @@ -256,7 +256,7 @@ def execute_callback(callback): mock_on_commit.side_effect = execute_callback - before_time = datetime.now(timezone.utc) + before_time = datetime.now(timezone.utc) - timedelta(seconds=1) # Create some states self.StateManager.transition_state(entity=self.task, new_state='CREATED', user=self.user) @@ -266,7 +266,7 @@ def execute_callback(callback): assert mock_on_commit.call_count == 2 # Record time after creating states - after_time = datetime.now(timezone.utc) + after_time = datetime.now(timezone.utc) + timedelta(seconds=1) # Query states in time range states_in_range = self.StateManager.get_states_in_time_range(self.task, before_time, after_time) From 20b0c06154dbbecb77d02cd4971133d7efa43ad7 Mon Sep 17 00:00:00 2001 From: bmartel Date: Mon, 8 Sep 2025 11:36:09 -0500 Subject: [PATCH 73/83] Update label_studio/fsm/models.py Co-authored-by: Marcel Canu --- label_studio/fsm/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index f916a773441b..8c1aa45b8761 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -291,7 +291,7 @@ def get_denormalized_fields(cls, entity): return { 'task_id': entity.task.id, 'project_id': entity.task.project_id, - 'completed_by_id': entity.completed_by.id if entity.completed_by else None, + 'completed_by_id': entity.completed_by_id if entity.completed_by_id else None, } @property From 5559707b8bd28d90d93a7e9aaa80a63a2211c2b6 Mon Sep 17 00:00:00 2001 From: bmartel Date: Mon, 8 Sep 2025 11:40:27 -0500 Subject: [PATCH 74/83] Update label_studio/fsm/models.py Co-authored-by: Marcel Canu --- label_studio/fsm/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index 8c1aa45b8761..1aaa9319deef 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -334,7 +334,7 @@ class Meta: def get_denormalized_fields(cls, entity): """Get denormalized fields for ProjectState creation""" return { - 'created_by_id': entity.created_by.id if entity.created_by else None, + 'created_by_id': entity.created_by_id if entity.created_by_id else None, } @property From 61a89d427366b7f5bfa98f14a663da04a6f504e0 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Mon, 8 Sep 2025 12:14:35 -0500 Subject: [PATCH 75/83] applying feedback to remove the additional method get_current_state and utilize get_current_state_value and get_current_state_object as the only means for retrieving current_state --- label_studio/fsm/README.md | 2 +- label_studio/fsm/models.py | 10 ++++++++-- label_studio/fsm/state_manager.py | 9 ++------- label_studio/fsm/tests/test_fsm_integration.py | 18 +++++++++--------- .../tests/test_integration_django_models.py | 2 +- 5 files changed, 21 insertions(+), 20 deletions(-) diff --git a/label_studio/fsm/README.md b/label_studio/fsm/README.md index 5f3a624f4a7f..d2e45c3d85c4 100644 --- a/label_studio/fsm/README.md +++ b/label_studio/fsm/README.md @@ -139,7 +139,7 @@ from fsm.state_manager import get_state_manager StateManager = get_state_manager() # Get current state -current_state = StateManager.get_current_state(order) +current_state = StateManager.get_current_state_value(order) # Get state history history = StateManager.get_state_history(order, limit=10) diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index 1aaa9319deef..112b9cced476 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -132,8 +132,14 @@ def get_current_state(cls, entity) -> Optional['BaseState']: @classmethod def get_current_state_value(cls, entity) -> Optional[str]: - """Get current state value as string""" - current_state = cls.get_current_state(entity) + """ + Get current state value as string using UUID7 natural ordering. + + Uses UUID7's natural time ordering to efficiently find the latest state + without requiring created_at indexes or complex queries. + """ + entity_field = f'{cls._get_entity_field_name()}' + current_state = cls.objects.filter(**{entity_field: entity}).order_by('-id').first() return current_state.state if current_state else None @classmethod diff --git a/label_studio/fsm/state_manager.py b/label_studio/fsm/state_manager.py index a52f46f164ff..3b6e32d20957 100644 --- a/label_studio/fsm/state_manager.py +++ b/label_studio/fsm/state_manager.py @@ -129,11 +129,6 @@ def get_current_state_value(cls, entity: Model) -> Optional[str]: ) raise StateManagerError(f'Error getting current state: {e}') from e - @classmethod - def get_current_state(cls, entity: Model) -> Optional[str]: - """Backward compatibility method - calls get_current_state_value""" - return cls.get_current_state_value(entity) - @classmethod def get_current_state_object(cls, entity: Model) -> BaseState: """ @@ -203,7 +198,7 @@ def transition_state( if not state_model: raise StateManagerError(f'No state model found for {entity._meta.model_name} when transitioning state') - current_state = cls.get_current_state(entity) + current_state = cls.get_current_state_value(entity) try: with transaction.atomic(): @@ -378,7 +373,7 @@ def warm_cache(cls, entities: List[Model]): if organization_id is None: if hasattr(entity, 'organization_id'): organization_id = entity.organization_id - current_state = cls.get_current_state(entity) + current_state = cls.get_current_state_value(entity) if current_state: cache_key = cls.get_cache_key(entity) cache_updates[cache_key] = current_state diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py index 165170ad2ba8..7222f5e6bc51 100644 --- a/label_studio/fsm/tests/test_fsm_integration.py +++ b/label_studio/fsm/tests/test_fsm_integration.py @@ -127,7 +127,7 @@ def setUp(self): def test_get_current_state_empty(self): """Test getting current state when no states exist""" - current_state = self.StateManager.get_current_state(self.task) + current_state = self.StateManager.get_current_state_value(self.task) assert current_state is None @patch('django.db.transaction.on_commit') @@ -157,7 +157,7 @@ def execute_callback(callback): assert mock_on_commit.call_count == 1 # Check current state - should work with mocked cache update - current_state = self.StateManager.get_current_state(self.task) + current_state = self.StateManager.get_current_state_value(self.task) assert current_state == 'CREATED' # Another transition @@ -173,7 +173,7 @@ def execute_callback(callback): # Verify transaction.on_commit was called again (total 2 times) assert mock_on_commit.call_count == 2 - current_state = self.StateManager.get_current_state(self.task) + current_state = self.StateManager.get_current_state_value(self.task) assert current_state == 'IN_PROGRESS' @patch('django.db.transaction.on_commit') @@ -305,7 +305,7 @@ def track_and_execute(callback): assert len(callbacks_executed) == 1 # Verify the cache was properly updated by executing the callback - current_state = self.StateManager.get_current_state(self.task) + current_state = self.StateManager.get_current_state_value(self.task) assert current_state == 'CREATED' # Perform another successful transition @@ -320,7 +320,7 @@ def track_and_execute(callback): assert mock_on_commit.call_count == 2 assert len(callbacks_executed) == 2 - current_state = self.StateManager.get_current_state(self.task) + current_state = self.StateManager.get_current_state_value(self.task) assert current_state == 'IN_PROGRESS' @patch('django.db.transaction.on_commit') @@ -349,7 +349,7 @@ def test_transaction_on_commit_failure_case(self, mock_get_state_model, mock_on_ # Verify cache was not updated (should raise exception) with pytest.raises(Exception): # Should raise StateManagerError - self.StateManager.get_current_state(self.task) + self.StateManager.get_current_state_value(self.task) @patch('django.db.transaction.on_commit') @patch('fsm.models.TaskState.objects.create') @@ -376,7 +376,7 @@ def test_transaction_on_commit_database_failure_case(self, mock_create, mock_on_ assert mock_on_commit.call_count == 0 # Verify cache was deleted due to failure (cache.delete should be called) - current_state = self.StateManager.get_current_state(self.task) + current_state = self.StateManager.get_current_state_value(self.task) assert current_state is None @patch('django.db.transaction.on_commit') @@ -412,6 +412,6 @@ def test_transaction_on_commit_callback_content(self, mock_on_commit): cached_state = cache.get(cache_key) assert cached_state == 'CREATED' - # Verify get_current_state uses the cached value - current_state = self.StateManager.get_current_state(self.task) + # Verify get_current_state_value uses the cached value + current_state = self.StateManager.get_current_state_value(self.task) assert current_state == 'CREATED' diff --git a/label_studio/fsm/tests/test_integration_django_models.py b/label_studio/fsm/tests/test_integration_django_models.py index 40d9a59408fb..2b768126b36c 100644 --- a/label_studio/fsm/tests/test_integration_django_models.py +++ b/label_studio/fsm/tests/test_integration_django_models.py @@ -213,7 +213,7 @@ def transition(self, context: TransitionContext) -> Dict[str, Any]: create_transition = CreateTaskTransition(created_by_id=100, initial_priority='high') # Test with StateManager integration - with patch('fsm.state_manager.StateManager.get_current_state') as mock_get_current: + with patch('fsm.state_manager.StateManager.get_current_state_value') as mock_get_current: mock_get_current.return_value = None # No current state context = TransitionContext( From bc6e871a7f482ac559cea5550e9c0c87126c1249 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Mon, 8 Sep 2025 12:57:53 -0500 Subject: [PATCH 76/83] removing redundant db_index=True on task fk --- label_studio/fsm/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/label_studio/fsm/models.py b/label_studio/fsm/models.py index 112b9cced476..2d154baea65e 100644 --- a/label_studio/fsm/models.py +++ b/label_studio/fsm/models.py @@ -217,7 +217,7 @@ class TaskState(BaseState): """ # Entity Relationship - task = models.ForeignKey('tasks.Task', related_name='fsm_states', on_delete=models.CASCADE, db_index=True) + task = models.ForeignKey('tasks.Task', related_name='fsm_states', on_delete=models.CASCADE) # Override state field to add choices constraint state = models.CharField(max_length=50, choices=TaskStateChoices.choices, db_index=True) From bf73eb398c8799254b9dcd2fdb22f96a83767f37 Mon Sep 17 00:00:00 2001 From: Brandon Martel Date: Mon, 8 Sep 2025 15:16:42 -0500 Subject: [PATCH 77/83] using factories as mentioned in PR feedback --- .../fsm/tests/test_fsm_integration.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/label_studio/fsm/tests/test_fsm_integration.py b/label_studio/fsm/tests/test_fsm_integration.py index 7222f5e6bc51..3c523c3d58ff 100644 --- a/label_studio/fsm/tests/test_fsm_integration.py +++ b/label_studio/fsm/tests/test_fsm_integration.py @@ -8,23 +8,21 @@ from unittest.mock import patch import pytest -from django.contrib.auth import get_user_model from django.test import TestCase from fsm.models import AnnotationState, ProjectState, TaskState from fsm.state_manager import get_state_manager -from projects.models import Project -from tasks.models import Annotation, Task - -User = get_user_model() +from projects.tests.factories import ProjectFactory +from tasks.tests.factories import AnnotationFactory, TaskFactory +from users.tests.factories import UserFactory class TestFSMModels(TestCase): """Test FSM model functionality""" def setUp(self): - self.user = User.objects.create_user(email='test@example.com', password='test123') - self.project = Project.objects.create(title='Test Project', created_by=self.user) - self.task = Task.objects.create(project=self.project, data={'text': 'test'}) + self.user = UserFactory(email='test@example.com') + self.project = ProjectFactory(created_by=self.user) + self.task = TaskFactory(project=self.project, data={'text': 'test'}) # Clear cache to ensure tests start with clean state from django.core.cache import cache @@ -57,7 +55,7 @@ def test_task_state_creation(self): def test_annotation_state_creation(self): """Test AnnotationState creation and basic functionality""" - annotation = Annotation.objects.create(task=self.task, completed_by=self.user, result=[]) + annotation = AnnotationFactory(task=self.task, completed_by=self.user, result=[]) annotation_state = AnnotationState.objects.create( annotation=annotation, @@ -108,9 +106,9 @@ class TestStateManager(TestCase): """Test StateManager functionality with mocked transaction support""" def setUp(self): - self.user = User.objects.create_user(email='test@example.com', password='test123') - self.project = Project.objects.create(title='Test Project', created_by=self.user) - self.task = Task.objects.create(project=self.project, data={'text': 'test'}) + self.user = UserFactory(email='test@example.com') + self.project = ProjectFactory(created_by=self.user) + self.task = TaskFactory(project=self.project, data={'text': 'test'}) self.StateManager = get_state_manager() # Clear cache to ensure tests start with clean state From 6b952a3dccb16900982e998fcde3eb72629be4e5 Mon Sep 17 00:00:00 2001 From: bmartel Date: Tue, 9 Sep 2025 14:20:05 +0000 Subject: [PATCH 78/83] Sync Follow Merge dependencies Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/17585497946 From a5f326f6d2423583bf4dfa19f001ac50172dedbc Mon Sep 17 00:00:00 2001 From: bmartel Date: Tue, 9 Sep 2025 18:21:00 +0000 Subject: [PATCH 79/83] Sync Follow Merge dependencies Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/17591882567 --- poetry.lock | 118 ++++++------------------------------------------- pyproject.toml | 2 +- 2 files changed, 15 insertions(+), 105 deletions(-) diff --git a/poetry.lock b/poetry.lock index 53c237c888b5..5eb104271b96 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2136,7 +2136,7 @@ optional = false python-versions = ">=3.9,<4" groups = ["main"] files = [ - {file = "7290779bdb88a4940779926a79ed9bc1244e76b4.zip", hash = "sha256:4739d6942e9e30ab5deaf79e2eb9f2f546f5d51cb64a7ccdf5b7e2b9909d335d"}, + {file = "c648c39767e6f99f61bf3aa86e75e9f8713a683e.zip", hash = "sha256:6fd1062dfdb8b74af1829be3587f6305e5a0091eab3dc1715b2966160f0eddee"}, ] [package.dependencies] @@ -2149,7 +2149,7 @@ jsonschema = ">=4.23.0" lxml = ">=4.2.5" nltk = ">=3.9.1,<4.0.0" numpy = ">=1.26.4,<3.0.0" -opencv-python = ">=4.9.0,<5.0.0" +opencv-python = ">=4.12.0,<5.0.0" pandas = ">=0.24.0" Pillow = ">=11.3.0" pydantic = ">=1.9.2" @@ -2164,7 +2164,7 @@ xmljson = "0.2.1" [package.source] type = "url" -url = "https://github.com/HumanSignal/label-studio-sdk/archive/7290779bdb88a4940779926a79ed9bc1244e76b4.zip" +url = "https://github.com/HumanSignal/label-studio-sdk/archive/c648c39767e6f99f61bf3aa86e75e9f8713a683e.zip" [[package]] name = "launchdarkly-server-sdk" @@ -2648,7 +2648,6 @@ description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.10" groups = ["main"] -markers = "python_version == \"3.10\"" files = [ {file = "numpy-2.2.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b412caa66f72040e6d268491a59f2c43bf03eb6c96dd8f0307829feb7fa2b6fb"}, {file = "numpy-2.2.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e41fd67c52b86603a91c1a505ebaef50b3314de0213461c7a6e99c9a3beff90"}, @@ -2707,91 +2706,6 @@ files = [ {file = "numpy-2.2.6.tar.gz", hash = "sha256:e29554e2bef54a90aa5cc07da6ce955accb83f21ab5de01a62c8478897b264fd"}, ] -[[package]] -name = "numpy" -version = "2.3.2" -description = "Fundamental package for array computing in Python" -optional = false -python-versions = ">=3.11" -groups = ["main"] -markers = "python_version >= \"3.11\"" -files = [ - {file = "numpy-2.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:852ae5bed3478b92f093e30f785c98e0cb62fa0a939ed057c31716e18a7a22b9"}, - {file = "numpy-2.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7a0e27186e781a69959d0230dd9909b5e26024f8da10683bd6344baea1885168"}, - {file = "numpy-2.3.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:f0a1a8476ad77a228e41619af2fa9505cf69df928e9aaa165746584ea17fed2b"}, - {file = "numpy-2.3.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:cbc95b3813920145032412f7e33d12080f11dc776262df1712e1638207dde9e8"}, - {file = "numpy-2.3.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f75018be4980a7324edc5930fe39aa391d5734531b1926968605416ff58c332d"}, - {file = "numpy-2.3.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20b8200721840f5621b7bd03f8dcd78de33ec522fc40dc2641aa09537df010c3"}, - {file = "numpy-2.3.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1f91e5c028504660d606340a084db4b216567ded1056ea2b4be4f9d10b67197f"}, - {file = "numpy-2.3.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:fb1752a3bb9a3ad2d6b090b88a9a0ae1cd6f004ef95f75825e2f382c183b2097"}, - {file = "numpy-2.3.2-cp311-cp311-win32.whl", hash = "sha256:4ae6863868aaee2f57503c7a5052b3a2807cf7a3914475e637a0ecd366ced220"}, - {file = "numpy-2.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:240259d6564f1c65424bcd10f435145a7644a65a6811cfc3201c4a429ba79170"}, - {file = "numpy-2.3.2-cp311-cp311-win_arm64.whl", hash = "sha256:4209f874d45f921bde2cff1ffcd8a3695f545ad2ffbef6d3d3c6768162efab89"}, - {file = "numpy-2.3.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bc3186bea41fae9d8e90c2b4fb5f0a1f5a690682da79b92574d63f56b529080b"}, - {file = "numpy-2.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2f4f0215edb189048a3c03bd5b19345bdfa7b45a7a6f72ae5945d2a28272727f"}, - {file = "numpy-2.3.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:8b1224a734cd509f70816455c3cffe13a4f599b1bf7130f913ba0e2c0b2006c0"}, - {file = "numpy-2.3.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:3dcf02866b977a38ba3ec10215220609ab9667378a9e2150615673f3ffd6c73b"}, - {file = "numpy-2.3.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:572d5512df5470f50ada8d1972c5f1082d9a0b7aa5944db8084077570cf98370"}, - {file = "numpy-2.3.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8145dd6d10df13c559d1e4314df29695613575183fa2e2d11fac4c208c8a1f73"}, - {file = "numpy-2.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:103ea7063fa624af04a791c39f97070bf93b96d7af7eb23530cd087dc8dbe9dc"}, - {file = "numpy-2.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc927d7f289d14f5e037be917539620603294454130b6de200091e23d27dc9be"}, - {file = "numpy-2.3.2-cp312-cp312-win32.whl", hash = "sha256:d95f59afe7f808c103be692175008bab926b59309ade3e6d25009e9a171f7036"}, - {file = "numpy-2.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:9e196ade2400c0c737d93465327d1ae7c06c7cb8a1756121ebf54b06ca183c7f"}, - {file = "numpy-2.3.2-cp312-cp312-win_arm64.whl", hash = "sha256:ee807923782faaf60d0d7331f5e86da7d5e3079e28b291973c545476c2b00d07"}, - {file = "numpy-2.3.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c8d9727f5316a256425892b043736d63e89ed15bbfe6556c5ff4d9d4448ff3b3"}, - {file = "numpy-2.3.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:efc81393f25f14d11c9d161e46e6ee348637c0a1e8a54bf9dedc472a3fae993b"}, - {file = "numpy-2.3.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:dd937f088a2df683cbb79dda9a772b62a3e5a8a7e76690612c2737f38c6ef1b6"}, - {file = "numpy-2.3.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:11e58218c0c46c80509186e460d79fbdc9ca1eb8d8aee39d8f2dc768eb781089"}, - {file = "numpy-2.3.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5ad4ebcb683a1f99f4f392cc522ee20a18b2bb12a2c1c42c3d48d5a1adc9d3d2"}, - {file = "numpy-2.3.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:938065908d1d869c7d75d8ec45f735a034771c6ea07088867f713d1cd3bbbe4f"}, - {file = "numpy-2.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:66459dccc65d8ec98cc7df61307b64bf9e08101f9598755d42d8ae65d9a7a6ee"}, - {file = "numpy-2.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a7af9ed2aa9ec5950daf05bb11abc4076a108bd3c7db9aa7251d5f107079b6a6"}, - {file = "numpy-2.3.2-cp313-cp313-win32.whl", hash = "sha256:906a30249315f9c8e17b085cc5f87d3f369b35fedd0051d4a84686967bdbbd0b"}, - {file = "numpy-2.3.2-cp313-cp313-win_amd64.whl", hash = "sha256:c63d95dc9d67b676e9108fe0d2182987ccb0f11933c1e8959f42fa0da8d4fa56"}, - {file = "numpy-2.3.2-cp313-cp313-win_arm64.whl", hash = "sha256:b05a89f2fb84d21235f93de47129dd4f11c16f64c87c33f5e284e6a3a54e43f2"}, - {file = "numpy-2.3.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4e6ecfeddfa83b02318f4d84acf15fbdbf9ded18e46989a15a8b6995dfbf85ab"}, - {file = "numpy-2.3.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:508b0eada3eded10a3b55725b40806a4b855961040180028f52580c4729916a2"}, - {file = "numpy-2.3.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:754d6755d9a7588bdc6ac47dc4ee97867271b17cee39cb87aef079574366db0a"}, - {file = "numpy-2.3.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:a9f66e7d2b2d7712410d3bc5684149040ef5f19856f20277cd17ea83e5006286"}, - {file = "numpy-2.3.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:de6ea4e5a65d5a90c7d286ddff2b87f3f4ad61faa3db8dabe936b34c2275b6f8"}, - {file = "numpy-2.3.2-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3ef07ec8cbc8fc9e369c8dcd52019510c12da4de81367d8b20bc692aa07573a"}, - {file = "numpy-2.3.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:27c9f90e7481275c7800dc9c24b7cc40ace3fdb970ae4d21eaff983a32f70c91"}, - {file = "numpy-2.3.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:07b62978075b67eee4065b166d000d457c82a1efe726cce608b9db9dd66a73a5"}, - {file = "numpy-2.3.2-cp313-cp313t-win32.whl", hash = "sha256:c771cfac34a4f2c0de8e8c97312d07d64fd8f8ed45bc9f5726a7e947270152b5"}, - {file = "numpy-2.3.2-cp313-cp313t-win_amd64.whl", hash = "sha256:72dbebb2dcc8305c431b2836bcc66af967df91be793d63a24e3d9b741374c450"}, - {file = "numpy-2.3.2-cp313-cp313t-win_arm64.whl", hash = "sha256:72c6df2267e926a6d5286b0a6d556ebe49eae261062059317837fda12ddf0c1a"}, - {file = "numpy-2.3.2-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:448a66d052d0cf14ce9865d159bfc403282c9bc7bb2a31b03cc18b651eca8b1a"}, - {file = "numpy-2.3.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:546aaf78e81b4081b2eba1d105c3b34064783027a06b3ab20b6eba21fb64132b"}, - {file = "numpy-2.3.2-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:87c930d52f45df092f7578889711a0768094debf73cfcde105e2d66954358125"}, - {file = "numpy-2.3.2-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:8dc082ea901a62edb8f59713c6a7e28a85daddcb67454c839de57656478f5b19"}, - {file = "numpy-2.3.2-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:af58de8745f7fa9ca1c0c7c943616c6fe28e75d0c81f5c295810e3c83b5be92f"}, - {file = "numpy-2.3.2-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fed5527c4cf10f16c6d0b6bee1f89958bccb0ad2522c8cadc2efd318bcd545f5"}, - {file = "numpy-2.3.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:095737ed986e00393ec18ec0b21b47c22889ae4b0cd2d5e88342e08b01141f58"}, - {file = "numpy-2.3.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b5e40e80299607f597e1a8a247ff8d71d79c5b52baa11cc1cce30aa92d2da6e0"}, - {file = "numpy-2.3.2-cp314-cp314-win32.whl", hash = "sha256:7d6e390423cc1f76e1b8108c9b6889d20a7a1f59d9a60cac4a050fa734d6c1e2"}, - {file = "numpy-2.3.2-cp314-cp314-win_amd64.whl", hash = "sha256:b9d0878b21e3918d76d2209c924ebb272340da1fb51abc00f986c258cd5e957b"}, - {file = "numpy-2.3.2-cp314-cp314-win_arm64.whl", hash = "sha256:2738534837c6a1d0c39340a190177d7d66fdf432894f469728da901f8f6dc910"}, - {file = "numpy-2.3.2-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:4d002ecf7c9b53240be3bb69d80f86ddbd34078bae04d87be81c1f58466f264e"}, - {file = "numpy-2.3.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:293b2192c6bcce487dbc6326de5853787f870aeb6c43f8f9c6496db5b1781e45"}, - {file = "numpy-2.3.2-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:0a4f2021a6da53a0d580d6ef5db29947025ae8b35b3250141805ea9a32bbe86b"}, - {file = "numpy-2.3.2-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:9c144440db4bf3bb6372d2c3e49834cc0ff7bb4c24975ab33e01199e645416f2"}, - {file = "numpy-2.3.2-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f92d6c2a8535dc4fe4419562294ff957f83a16ebdec66df0805e473ffaad8bd0"}, - {file = "numpy-2.3.2-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cefc2219baa48e468e3db7e706305fcd0c095534a192a08f31e98d83a7d45fb0"}, - {file = "numpy-2.3.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:76c3e9501ceb50b2ff3824c3589d5d1ab4ac857b0ee3f8f49629d0de55ecf7c2"}, - {file = "numpy-2.3.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:122bf5ed9a0221b3419672493878ba4967121514b1d7d4656a7580cd11dddcbf"}, - {file = "numpy-2.3.2-cp314-cp314t-win32.whl", hash = "sha256:6f1ae3dcb840edccc45af496f312528c15b1f79ac318169d094e85e4bb35fdf1"}, - {file = "numpy-2.3.2-cp314-cp314t-win_amd64.whl", hash = "sha256:087ffc25890d89a43536f75c5fe8770922008758e8eeeef61733957041ed2f9b"}, - {file = "numpy-2.3.2-cp314-cp314t-win_arm64.whl", hash = "sha256:092aeb3449833ea9c0bf0089d70c29ae480685dd2377ec9cdbbb620257f84631"}, - {file = "numpy-2.3.2-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:14a91ebac98813a49bc6aa1a0dfc09513dcec1d97eaf31ca21a87221a1cdcb15"}, - {file = "numpy-2.3.2-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:71669b5daae692189540cffc4c439468d35a3f84f0c88b078ecd94337f6cb0ec"}, - {file = "numpy-2.3.2-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:69779198d9caee6e547adb933941ed7520f896fd9656834c300bdf4dd8642712"}, - {file = "numpy-2.3.2-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:2c3271cc4097beb5a60f010bcc1cc204b300bb3eafb4399376418a83a1c6373c"}, - {file = "numpy-2.3.2-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8446acd11fe3dc1830568c941d44449fd5cb83068e5c70bd5a470d323d448296"}, - {file = "numpy-2.3.2-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aa098a5ab53fa407fded5870865c6275a5cd4101cfdef8d6fafc48286a96e981"}, - {file = "numpy-2.3.2-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:6936aff90dda378c09bea075af0d9c675fe3a977a9d2402f95a87f440f59f619"}, - {file = "numpy-2.3.2.tar.gz", hash = "sha256:e0486a11ec30cdecb53f184d496d1c6a20786c81e55e41640270130056f8ee48"}, -] - [[package]] name = "openai" version = "1.11.1" @@ -2818,28 +2732,23 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] [[package]] name = "opencv-python" -version = "4.11.0.86" +version = "4.12.0.88" description = "Wrapper package for OpenCV python bindings." optional = false python-versions = ">=3.6" groups = ["main"] files = [ - {file = "opencv-python-4.11.0.86.tar.gz", hash = "sha256:03d60ccae62304860d232272e4a4fda93c39d595780cb40b161b310244b736a4"}, - {file = "opencv_python-4.11.0.86-cp37-abi3-macosx_13_0_arm64.whl", hash = "sha256:432f67c223f1dc2824f5e73cdfcd9db0efc8710647d4e813012195dc9122a52a"}, - {file = "opencv_python-4.11.0.86-cp37-abi3-macosx_13_0_x86_64.whl", hash = "sha256:9d05ef13d23fe97f575153558653e2d6e87103995d54e6a35db3f282fe1f9c66"}, - {file = "opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b92ae2c8852208817e6776ba1ea0d6b1e0a1b5431e971a2a0ddd2a8cc398202"}, - {file = "opencv_python-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b02611523803495003bd87362db3e1d2a0454a6a63025dc6658a9830570aa0d"}, - {file = "opencv_python-4.11.0.86-cp37-abi3-win32.whl", hash = "sha256:810549cb2a4aedaa84ad9a1c92fbfdfc14090e2749cedf2c1589ad8359aa169b"}, - {file = "opencv_python-4.11.0.86-cp37-abi3-win_amd64.whl", hash = "sha256:085ad9b77c18853ea66283e98affefe2de8cc4c1f43eda4c100cf9b2721142ec"}, + {file = "opencv-python-4.12.0.88.tar.gz", hash = "sha256:8b738389cede219405f6f3880b851efa3415ccd674752219377353f017d2994d"}, + {file = "opencv_python-4.12.0.88-cp37-abi3-macosx_13_0_arm64.whl", hash = "sha256:f9a1f08883257b95a5764bf517a32d75aec325319c8ed0f89739a57fae9e92a5"}, + {file = "opencv_python-4.12.0.88-cp37-abi3-macosx_13_0_x86_64.whl", hash = "sha256:812eb116ad2b4de43ee116fcd8991c3a687f099ada0b04e68f64899c09448e81"}, + {file = "opencv_python-4.12.0.88-cp37-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:51fd981c7df6af3e8f70b1556696b05224c4e6b6777bdd2a46b3d4fb09de1a92"}, + {file = "opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:092c16da4c5a163a818f120c22c5e4a2f96e0db4f24e659c701f1fe629a690f9"}, + {file = "opencv_python-4.12.0.88-cp37-abi3-win32.whl", hash = "sha256:ff554d3f725b39878ac6a2e1fa232ec509c36130927afc18a1719ebf4fbf4357"}, + {file = "opencv_python-4.12.0.88-cp37-abi3-win_amd64.whl", hash = "sha256:d98edb20aa932fd8ebd276a72627dad9dc097695b3d435a4257557bbb49a79d2"}, ] [package.dependencies] -numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.5", markers = "python_version >= \"3.11\""}, - {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""}, - {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\""}, -] +numpy = {version = ">=2,<2.3.0", markers = "python_version >= \"3.9\""} [[package]] name = "ordered-set" @@ -3923,6 +3832,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -5199,4 +5109,4 @@ uwsgi = ["pyuwsgi", "uwsgitop"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4" -content-hash = "c7d09631cdc5067ccc455a0ce721836bd1105f0ff3744ee94b770be4705d74ec" +content-hash = "27e6715376cc32fd9f1e23ddcc12ece4a7858ee2aef58640b5d576c38cee6b49" diff --git a/pyproject.toml b/pyproject.toml index 082d576ac38c..ab69545941ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ dependencies = [ "tldextract (>=5.1.3)", "uuid-utils (>=0.11.0,<1.0.0)", ## HumanSignal repo dependencies :start - "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/7290779bdb88a4940779926a79ed9bc1244e76b4.zip", + "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/c648c39767e6f99f61bf3aa86e75e9f8713a683e.zip", ## HumanSignal repo dependencies :end ] From 868d6a080fca16d8ef90e79bad7e5bce0d559d3d Mon Sep 17 00:00:00 2001 From: bmartel Date: Wed, 10 Sep 2025 13:39:30 +0000 Subject: [PATCH 80/83] Sync Follow Merge dependencies Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/17615595310 From e986abb2d19e8df207b98283c1adef4533f3c309 Mon Sep 17 00:00:00 2001 From: bmartel Date: Wed, 10 Sep 2025 15:15:29 +0000 Subject: [PATCH 81/83] Sync Follow Merge dependencies Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/17618306507 --- poetry.lock | 6 +++--- pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/poetry.lock b/poetry.lock index f06819c6af1f..ca0843405f0f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2136,7 +2136,7 @@ optional = false python-versions = ">=3.9,<4" groups = ["main"] files = [ - {file = "c648c39767e6f99f61bf3aa86e75e9f8713a683e.zip", hash = "sha256:6fd1062dfdb8b74af1829be3587f6305e5a0091eab3dc1715b2966160f0eddee"}, + {file = "7fb275b0deaf5838f91ca79faf6a53e9e14fd15f.zip", hash = "sha256:fee3608ab6781fdf96544e11115d143ef77cdc6ec2a1af94ce6ef7d722d99a01"}, ] [package.dependencies] @@ -2164,7 +2164,7 @@ xmljson = "0.2.1" [package.source] type = "url" -url = "https://github.com/HumanSignal/label-studio-sdk/archive/c648c39767e6f99f61bf3aa86e75e9f8713a683e.zip" +url = "https://github.com/HumanSignal/label-studio-sdk/archive/7fb275b0deaf5838f91ca79faf6a53e9e14fd15f.zip" [[package]] name = "launchdarkly-server-sdk" @@ -5109,4 +5109,4 @@ uwsgi = ["pyuwsgi", "uwsgitop"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4" -content-hash = "27e6715376cc32fd9f1e23ddcc12ece4a7858ee2aef58640b5d576c38cee6b49" +content-hash = "31783201aa1159e1e493a56be67d4554d347bf216db1340d3873cc7a33f3b27e" diff --git a/pyproject.toml b/pyproject.toml index ab69545941ac..5938fc1631c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ dependencies = [ "tldextract (>=5.1.3)", "uuid-utils (>=0.11.0,<1.0.0)", ## HumanSignal repo dependencies :start - "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/c648c39767e6f99f61bf3aa86e75e9f8713a683e.zip", + "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/7fb275b0deaf5838f91ca79faf6a53e9e14fd15f.zip", ## HumanSignal repo dependencies :end ] From 1b09c7d77cc2ca406dd45509e0b44e21f5dad788 Mon Sep 17 00:00:00 2001 From: bmartel Date: Wed, 10 Sep 2025 21:48:39 +0000 Subject: [PATCH 82/83] Sync Follow Merge dependencies Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/17627590890 --- poetry.lock | 6 +++--- pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/poetry.lock b/poetry.lock index ca0843405f0f..b880b94d0d9d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2136,7 +2136,7 @@ optional = false python-versions = ">=3.9,<4" groups = ["main"] files = [ - {file = "7fb275b0deaf5838f91ca79faf6a53e9e14fd15f.zip", hash = "sha256:fee3608ab6781fdf96544e11115d143ef77cdc6ec2a1af94ce6ef7d722d99a01"}, + {file = "4f18092f859ee5dfad54ed1b1df28b37476a7809.zip", hash = "sha256:efa0802476b8b7fb81025384cec6cf42e8e06b8ea1262f63ad785e943b378047"}, ] [package.dependencies] @@ -2164,7 +2164,7 @@ xmljson = "0.2.1" [package.source] type = "url" -url = "https://github.com/HumanSignal/label-studio-sdk/archive/7fb275b0deaf5838f91ca79faf6a53e9e14fd15f.zip" +url = "https://github.com/HumanSignal/label-studio-sdk/archive/4f18092f859ee5dfad54ed1b1df28b37476a7809.zip" [[package]] name = "launchdarkly-server-sdk" @@ -5109,4 +5109,4 @@ uwsgi = ["pyuwsgi", "uwsgitop"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4" -content-hash = "31783201aa1159e1e493a56be67d4554d347bf216db1340d3873cc7a33f3b27e" +content-hash = "0965197e3cf6e879b353335176df6b599f7095a2101e36d2f1ad368d30b78e4d" diff --git a/pyproject.toml b/pyproject.toml index 5938fc1631c9..8ab28443b6c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ dependencies = [ "tldextract (>=5.1.3)", "uuid-utils (>=0.11.0,<1.0.0)", ## HumanSignal repo dependencies :start - "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/7fb275b0deaf5838f91ca79faf6a53e9e14fd15f.zip", + "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/4f18092f859ee5dfad54ed1b1df28b37476a7809.zip", ## HumanSignal repo dependencies :end ] From d033ae4fa5d62982a394a20d22c0c3ca1c54f7c4 Mon Sep 17 00:00:00 2001 From: bmartel Date: Thu, 11 Sep 2025 12:38:42 +0000 Subject: [PATCH 83/83] Sync Follow Merge dependencies Workflow run: https://github.com/HumanSignal/label-studio/actions/runs/17644693280 --- poetry.lock | 6 +++--- pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/poetry.lock b/poetry.lock index b880b94d0d9d..9ec8648a2720 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2136,7 +2136,7 @@ optional = false python-versions = ">=3.9,<4" groups = ["main"] files = [ - {file = "4f18092f859ee5dfad54ed1b1df28b37476a7809.zip", hash = "sha256:efa0802476b8b7fb81025384cec6cf42e8e06b8ea1262f63ad785e943b378047"}, + {file = "60b2079d6a4cc0e77b478dfe94bc5790776ada17.zip", hash = "sha256:ac800f61f51773150f814d7add67289c0964b10ed6af778fe226c832305abe74"}, ] [package.dependencies] @@ -2164,7 +2164,7 @@ xmljson = "0.2.1" [package.source] type = "url" -url = "https://github.com/HumanSignal/label-studio-sdk/archive/4f18092f859ee5dfad54ed1b1df28b37476a7809.zip" +url = "https://github.com/HumanSignal/label-studio-sdk/archive/60b2079d6a4cc0e77b478dfe94bc5790776ada17.zip" [[package]] name = "launchdarkly-server-sdk" @@ -5109,4 +5109,4 @@ uwsgi = ["pyuwsgi", "uwsgitop"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4" -content-hash = "0965197e3cf6e879b353335176df6b599f7095a2101e36d2f1ad368d30b78e4d" +content-hash = "10eb33c451db6d38f113e6eae05183ae990fc165beeea9a629c67b2f1d67bcb9" diff --git a/pyproject.toml b/pyproject.toml index 8ab28443b6c7..0ce2a8b1139f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ dependencies = [ "tldextract (>=5.1.3)", "uuid-utils (>=0.11.0,<1.0.0)", ## HumanSignal repo dependencies :start - "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/4f18092f859ee5dfad54ed1b1df28b37476a7809.zip", + "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/60b2079d6a4cc0e77b478dfe94bc5790776ada17.zip", ## HumanSignal repo dependencies :end ]