Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ic_python_db/db_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ def init(
if cls._instance:
raise RuntimeError("Database instance already exists")
cls._instance = cls(audit_enabled, db_storage, db_audit)

# Flush any Entity subclasses that were defined before Database existed
from .entity import Entity

Entity._flush_deferred_types()

return cls._instance

def __init__(
Expand Down
27 changes: 27 additions & 0 deletions ic_python_db/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,37 @@ class AdminUser(Entity):

_entity_type = None # To be defined in subclasses
_context: Set["Entity"] = set() # Set of entities in current context
_deferred_types: List[Type["Entity"]] = [] # Types defined before DB exists
_do_not_save = False
__version__ = 1 # Default schema version
__namespace__: Optional[str] = None # Optional namespace for entity type

def __init_subclass__(cls, **kwargs):
"""Auto-register Entity subclasses with the Database at class definition time."""
super().__init_subclass__(**kwargs)
db = Database._instance
if db is not None:
db.register_entity_type(cls, cls.get_full_type_name())
else:
# Database not initialized yet — defer registration
Entity._deferred_types.append(cls)

@classmethod
def _flush_deferred_types(cls):
"""Register any Entity subclasses that were defined before Database existed."""
if not cls._deferred_types:
return
try:
db = Database.get_instance()
except Exception:
return
for deferred_cls in cls._deferred_types:
try:
db.register_entity_type(deferred_cls, deferred_cls.get_full_type_name())
except Exception:
pass
cls._deferred_types.clear()

def __init__(self, **kwargs):
"""Initialize a new entity.

Expand Down
2 changes: 1 addition & 1 deletion tests/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ cd src

exit_code=0

TEST_IDS=("example_1" "example_2" "entity" "mixins" "properties" "alias_and_properties" "relationships" "enhanced_relations" "serialization" "namespaces" "migrations" "database" "audit")
TEST_IDS=("example_1" "example_2" "entity" "mixins" "properties" "alias_and_properties" "relationships" "enhanced_relations" "serialization" "namespaces" "migrations" "database" "audit" "auto_register")

# Check if a specific test ID is provided as an argument
if [ "$1" ]; then
Expand Down
129 changes: 129 additions & 0 deletions tests/src/tests/test_auto_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Tests for automatic Entity type registration at class definition time.

Verifies that Entity subclasses are registered in Database._entity_types
when the class is defined, not just when instances are created.

See: https://github.com/smart-social-contracts/ic-python-db/issues/6
"""

from tester import Tester

from ic_python_db import Database, Entity, Integer, String


class TestAutoRegister:
def setUp(self):
"""Clear database before each test."""
Database.get_instance().clear()

def test_type_registered_at_definition_time(self):
"""Entity subclass should be in _entity_types immediately after class definition."""
db = Database.get_instance()

class Dog(Entity):
__alias__ = "name"
name = String(max_length=50)

assert (
"Dog" in db._entity_types
), f"Dog not in _entity_types: {list(db._entity_types.keys())}"
assert db._entity_types["Dog"] is Dog

def test_multiple_types_registered(self):
"""Multiple Entity subclasses should all be registered."""
db = Database.get_instance()

class Cat(Entity):
name = String(max_length=50)

class Fish(Entity):
name = String(max_length=50)

assert "Cat" in db._entity_types
assert "Fish" in db._entity_types

def test_type_registered_without_creating_instances(self):
"""Type should be registered even if no instances exist."""
db = Database.get_instance()

class Bird(Entity):
__alias__ = "name"
name = String(max_length=50)

assert "Bird" in db._entity_types
assert Bird.count() == 0

def test_type_still_works_after_instance_creation(self):
"""Creating instances should not break or duplicate registration."""
db = Database.get_instance()

class Horse(Entity):
__alias__ = "name"
name = String(max_length=50)
legs = Integer(default=4)

assert "Horse" in db._entity_types
horse = Horse(name="Spirit", legs=4)
assert "Horse" in db._entity_types
assert db._entity_types["Horse"] is Horse
loaded = Horse[horse._id]
assert loaded.name == "Spirit"

def test_namespaced_entity_registered(self):
"""Entity with __namespace__ should register under full type name."""
db = Database.get_instance()

class MyExtEntity(Entity):
__namespace__ = "ext_test"
name = String(max_length=50)

full_name = MyExtEntity.get_full_type_name()
assert full_name == "ext_test::MyExtEntity"
assert (
full_name in db._entity_types
), f"{full_name} not in _entity_types: {list(db._entity_types.keys())}"

def test_clear_preserves_type_registration(self):
"""Database.clear() should not lose entity type registrations."""
db = Database.get_instance()

class Lizard(Entity):
name = String(max_length=50)

assert "Lizard" in db._entity_types
db.clear()
assert "Lizard" in db._entity_types

def test_flush_deferred_types_is_idempotent(self):
"""Calling _flush_deferred_types when list is empty is a no-op."""
Entity._flush_deferred_types()
assert Entity._deferred_types == []

def test_deferred_types_flushed_on_db_init(self):
"""Types defined before Database exists should register when DB is created."""
old_instance = Database._instance
Database._instance = None

class Frog(Entity):
name = String(max_length=50)

assert Frog in Entity._deferred_types

Database._instance = None
db = Database.init(audit_enabled=False)
assert "Frog" in db._entity_types
assert Entity._deferred_types == []

# Restore original DB instance for other tests
Database._instance = None
Database._instance = old_instance


def run(test_name: str = None, test_var: str = None):
tester = Tester(TestAutoRegister)
return tester.run_tests()


if __name__ == "__main__":
Database.get_instance().clear()
exit(run())
Loading