diff --git a/ic_python_db/db_engine.py b/ic_python_db/db_engine.py index 14f12d4..3f22c28 100644 --- a/ic_python_db/db_engine.py +++ b/ic_python_db/db_engine.py @@ -36,6 +36,12 @@ def init( if cls._instance: raise RuntimeError("Database instance already exists") cls._instance = cls(audit_enabled, db_storage, db_audit) + + # Flush any Entity subclasses that were defined before Database existed + from .entity import Entity + + Entity._flush_deferred_types() + return cls._instance def __init__( diff --git a/ic_python_db/entity.py b/ic_python_db/entity.py index 04e15bb..fa84a45 100644 --- a/ic_python_db/entity.py +++ b/ic_python_db/entity.py @@ -91,10 +91,37 @@ class AdminUser(Entity): _entity_type = None # To be defined in subclasses _context: Set["Entity"] = set() # Set of entities in current context + _deferred_types: List[Type["Entity"]] = [] # Types defined before DB exists _do_not_save = False __version__ = 1 # Default schema version __namespace__: Optional[str] = None # Optional namespace for entity type + def __init_subclass__(cls, **kwargs): + """Auto-register Entity subclasses with the Database at class definition time.""" + super().__init_subclass__(**kwargs) + db = Database._instance + if db is not None: + db.register_entity_type(cls, cls.get_full_type_name()) + else: + # Database not initialized yet — defer registration + Entity._deferred_types.append(cls) + + @classmethod + def _flush_deferred_types(cls): + """Register any Entity subclasses that were defined before Database existed.""" + if not cls._deferred_types: + return + try: + db = Database.get_instance() + except Exception: + return + for deferred_cls in cls._deferred_types: + try: + db.register_entity_type(deferred_cls, deferred_cls.get_full_type_name()) + except Exception: + pass + cls._deferred_types.clear() + def __init__(self, **kwargs): """Initialize a new entity. diff --git a/tests/run_test.sh b/tests/run_test.sh index 63b7c5d..023e5af 100755 --- a/tests/run_test.sh +++ b/tests/run_test.sh @@ -8,7 +8,7 @@ cd src exit_code=0 -TEST_IDS=("example_1" "example_2" "entity" "mixins" "properties" "alias_and_properties" "relationships" "enhanced_relations" "serialization" "namespaces" "migrations" "database" "audit") +TEST_IDS=("example_1" "example_2" "entity" "mixins" "properties" "alias_and_properties" "relationships" "enhanced_relations" "serialization" "namespaces" "migrations" "database" "audit" "auto_register") # Check if a specific test ID is provided as an argument if [ "$1" ]; then diff --git a/tests/src/tests/test_auto_register.py b/tests/src/tests/test_auto_register.py new file mode 100644 index 0000000..f683d09 --- /dev/null +++ b/tests/src/tests/test_auto_register.py @@ -0,0 +1,129 @@ +"""Tests for automatic Entity type registration at class definition time. + +Verifies that Entity subclasses are registered in Database._entity_types +when the class is defined, not just when instances are created. + +See: https://github.com/smart-social-contracts/ic-python-db/issues/6 +""" + +from tester import Tester + +from ic_python_db import Database, Entity, Integer, String + + +class TestAutoRegister: + def setUp(self): + """Clear database before each test.""" + Database.get_instance().clear() + + def test_type_registered_at_definition_time(self): + """Entity subclass should be in _entity_types immediately after class definition.""" + db = Database.get_instance() + + class Dog(Entity): + __alias__ = "name" + name = String(max_length=50) + + assert ( + "Dog" in db._entity_types + ), f"Dog not in _entity_types: {list(db._entity_types.keys())}" + assert db._entity_types["Dog"] is Dog + + def test_multiple_types_registered(self): + """Multiple Entity subclasses should all be registered.""" + db = Database.get_instance() + + class Cat(Entity): + name = String(max_length=50) + + class Fish(Entity): + name = String(max_length=50) + + assert "Cat" in db._entity_types + assert "Fish" in db._entity_types + + def test_type_registered_without_creating_instances(self): + """Type should be registered even if no instances exist.""" + db = Database.get_instance() + + class Bird(Entity): + __alias__ = "name" + name = String(max_length=50) + + assert "Bird" in db._entity_types + assert Bird.count() == 0 + + def test_type_still_works_after_instance_creation(self): + """Creating instances should not break or duplicate registration.""" + db = Database.get_instance() + + class Horse(Entity): + __alias__ = "name" + name = String(max_length=50) + legs = Integer(default=4) + + assert "Horse" in db._entity_types + horse = Horse(name="Spirit", legs=4) + assert "Horse" in db._entity_types + assert db._entity_types["Horse"] is Horse + loaded = Horse[horse._id] + assert loaded.name == "Spirit" + + def test_namespaced_entity_registered(self): + """Entity with __namespace__ should register under full type name.""" + db = Database.get_instance() + + class MyExtEntity(Entity): + __namespace__ = "ext_test" + name = String(max_length=50) + + full_name = MyExtEntity.get_full_type_name() + assert full_name == "ext_test::MyExtEntity" + assert ( + full_name in db._entity_types + ), f"{full_name} not in _entity_types: {list(db._entity_types.keys())}" + + def test_clear_preserves_type_registration(self): + """Database.clear() should not lose entity type registrations.""" + db = Database.get_instance() + + class Lizard(Entity): + name = String(max_length=50) + + assert "Lizard" in db._entity_types + db.clear() + assert "Lizard" in db._entity_types + + def test_flush_deferred_types_is_idempotent(self): + """Calling _flush_deferred_types when list is empty is a no-op.""" + Entity._flush_deferred_types() + assert Entity._deferred_types == [] + + def test_deferred_types_flushed_on_db_init(self): + """Types defined before Database exists should register when DB is created.""" + old_instance = Database._instance + Database._instance = None + + class Frog(Entity): + name = String(max_length=50) + + assert Frog in Entity._deferred_types + + Database._instance = None + db = Database.init(audit_enabled=False) + assert "Frog" in db._entity_types + assert Entity._deferred_types == [] + + # Restore original DB instance for other tests + Database._instance = None + Database._instance = old_instance + + +def run(test_name: str = None, test_var: str = None): + tester = Tester(TestAutoRegister) + return tester.run_tests() + + +if __name__ == "__main__": + Database.get_instance().clear() + exit(run())