diff --git a/docs/advanced.rst b/docs/advanced.rst index ee1319ca..fcfeb758 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -268,6 +268,23 @@ Similarly, pre-V1.0 output formatting can be re-estated by using ``polymorphic_showfield_old_format = True``. +Creating Subclass Objects from Existing Superclass Objects +------------------------------------------------------------ + +You can create an instance of a subclass from an existing instance of a superclass using the +:meth:`~polymorphic.managers.PolymorphicManager.create_from_super` method +of the subclass's manager. For example: + +.. code-block:: python + + super_instance = ModelA.objects.get(id=1) + sub_instance = ModelB.objects.create_from_super(super_instance, field2='value2') + +The restriction is that ``super_instance`` must be an instance of the direct superclass of +``ModelB``, and any required fields of ``ModelB`` must be provided as keyword arguments. If multiple +levels of subclassing are involved, you must call this method multiple times to "promote" each +level. + .. _restrictions: Restrictions & Caveats diff --git a/src/polymorphic/managers.py b/src/polymorphic/managers.py index 0f8fd71c..50bfc192 100644 --- a/src/polymorphic/managers.py +++ b/src/polymorphic/managers.py @@ -2,7 +2,8 @@ The manager class for use in the models. """ -from django.db import models +from django.contrib.contenttypes.models import ContentType +from django.db import DEFAULT_DB_ALIAS, models from polymorphic.query import PolymorphicQuerySet @@ -49,3 +50,45 @@ def not_instance_of(self, *args): def get_real_instances(self, base_result_objects=None): return self.all().get_real_instances(base_result_objects=base_result_objects) + + def create_from_super(self, obj, **kwargs): + """ + Create an instance of this manager's model class from the given instance of a + parent class. + + This is useful when "promoting" an instance down the inheritance chain. + + :param obj: An instance of a parent class of the manager's model class. + :param kwargs: Additional fields to set on the new instance. + :return: The newly created instance. + """ + from .models import PolymorphicModel + + # ensure we have the most derived real instance + if isinstance(obj, PolymorphicModel): + obj = obj.get_real_instance() + + parent_ptr = self.model._meta.parents.get(type(obj), None) + + if not parent_ptr: + raise TypeError( + f"{obj.__class__.__name__} is not a direct parent of {self.model.__name__}" + ) + kwargs[parent_ptr.get_attname()] = obj.pk + + # create the new base class with only fields that apply to it. + ctype = ContentType.objects.db_manager( + using=(obj._state.db or DEFAULT_DB_ALIAS) + ).get_for_model(self.model) + nobj = self.model(**kwargs, polymorphic_ctype=ctype) + nobj.save_base(raw=True, using=obj._state.db or DEFAULT_DB_ALIAS, force_insert=True) + # force update the content type, but first we need to + # retrieve a clean copy from the db to fill in the null + # fields otherwise they would be overwritten. + if isinstance(obj, PolymorphicModel): + parent = obj.__class__.objects.using(obj._state.db or DEFAULT_DB_ALIAS).get(pk=obj.pk) + parent.polymorphic_ctype = ctype + parent.save() + + nobj.refresh_from_db() # cast to cls + return nobj diff --git a/src/polymorphic/tests/migrations/0001_initial.py b/src/polymorphic/tests/migrations/0001_initial.py index f25287f6..081a3440 100644 --- a/src/polymorphic/tests/migrations/0001_initial.py +++ b/src/polymorphic/tests/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2 on 2025-12-13 10:57 +# Generated by Django 4.2 on 2025-12-13 22:56 from django.conf import settings from django.db import migrations, models @@ -13,8 +13,8 @@ class Migration(migrations.Migration): initial = True dependencies = [ - ('auth', '0012_alter_user_first_name_max_length'), ('contenttypes', '0002_remove_content_type_name'), + ('auth', '0012_alter_user_first_name_max_length'), ] operations = [ diff --git a/src/polymorphic/tests/test_multidb.py b/src/polymorphic/tests/test_multidb.py index bd54a6f5..df754db0 100644 --- a/src/polymorphic/tests/test_multidb.py +++ b/src/polymorphic/tests/test_multidb.py @@ -137,3 +137,142 @@ def run(): # Ensure no queries are made using the default database. self.assertNumQueries(0, run) + + def test_create_from_super(self): + # run create test 3 times because initial implementation + # would fail after first success. + from polymorphic.tests.models import ( + NormalBase, + NormalExtension, + PolyExtension, + PolyExtChild, + ) + + nb = NormalBase.objects.db_manager("secondary").create(nb_field=1) + ne = NormalExtension.objects.db_manager("secondary").create(nb_field=2, ne_field="ne2") + + with self.assertRaises(TypeError): + PolyExtension.objects.db_manager("secondary").create_from_super(nb, poly_ext_field=3) + + pe = PolyExtension.objects.db_manager("secondary").create_from_super(ne, poly_ext_field=3) + + ne.refresh_from_db() + self.assertEqual(type(ne), NormalExtension) + self.assertEqual(type(pe), PolyExtension) + self.assertEqual(pe.pk, ne.pk) + + self.assertEqual(pe.nb_field, 2) + self.assertEqual(pe.ne_field, "ne2") + self.assertEqual(pe.poly_ext_field, 3) + pe.refresh_from_db() + self.assertEqual(pe.nb_field, 2) + self.assertEqual(pe.ne_field, "ne2") + self.assertEqual(pe.poly_ext_field, 3) + + pc = PolyExtChild.objects.db_manager("secondary").create_from_super( + pe, poly_child_field="pcf6" + ) + + pe.refresh_from_db() + ne.refresh_from_db() + self.assertEqual(type(ne), NormalExtension) + self.assertEqual(type(pe), PolyExtension) + self.assertEqual(pe.pk, ne.pk) + self.assertEqual(pe.pk, pc.pk) + + self.assertEqual(pc.nb_field, 2) + self.assertEqual(pc.ne_field, "ne2") + self.assertEqual(pc.poly_ext_field, 3) + pc.refresh_from_db() + self.assertEqual(pc.nb_field, 2) + self.assertEqual(pc.ne_field, "ne2") + self.assertEqual(pc.poly_ext_field, 3) + self.assertEqual(pc.poly_child_field, "pcf6") + + self.assertEqual( + pe.polymorphic_ctype, + ContentType.objects.db_manager("secondary").get_for_model(PolyExtChild), + ) + self.assertEqual( + pc.polymorphic_ctype, + ContentType.objects.db_manager("secondary").get_for_model(PolyExtChild), + ) + + self.assertEqual(set(PolyExtension.objects.db_manager("secondary").all()), {pc}) + + a1 = Model2A.objects.db_manager("secondary").create(field1="A1a") + a2 = Model2A.objects.db_manager("secondary").create(field1="A1b") + + b1 = Model2B.objects.db_manager("secondary").create(field1="B1a", field2="B2a") + b2 = Model2B.objects.db_manager("secondary").create(field1="B1b", field2="B2b") + + c1 = Model2C.objects.db_manager("secondary").create( + field1="C1a", field2="C2a", field3="C3a" + ) + c2 = Model2C.objects.db_manager("secondary").create( + field1="C1b", field2="C2b", field3="C3b" + ) + + d1 = Model2D.objects.db_manager("secondary").create( + field1="D1a", field2="D2a", field3="D3a", field4="D4a" + ) + d2 = Model2D.objects.db_manager("secondary").create( + field1="D1b", field2="D2b", field3="D3b", field4="D4b" + ) + + with self.assertRaises(TypeError): + Model2D.objects.db_manager("secondary").create_from_super( + b1, field3="D3x", field4="D4x" + ) + + b1_of_c = Model2B.objects.db_manager("secondary").non_polymorphic().get(pk=c1.pk) + with self.assertRaises(TypeError): + Model2C.objects.db_manager("secondary").create_from_super(b1_of_c, field3="C3x") + + self.assertEqual( + c1.polymorphic_ctype, + ContentType.objects.db_manager("secondary").get_for_model(Model2C), + ) + dfs1 = Model2D.objects.db_manager("secondary").create_from_super(b1_of_c, field4="D4x") + self.assertEqual(type(dfs1), Model2D) + self.assertEqual(dfs1.pk, c1.pk) + self.assertEqual(dfs1.field1, "C1a") + self.assertEqual(dfs1.field2, "C2a") + self.assertEqual(dfs1.field3, "C3a") + self.assertEqual(dfs1.field4, "D4x") + self.assertEqual( + dfs1.polymorphic_ctype, + ContentType.objects.db_manager("secondary").get_for_model(Model2D), + ) + c1.refresh_from_db() + self.assertEqual( + c1.polymorphic_ctype, + ContentType.objects.db_manager("secondary").get_for_model(Model2D), + ) + + self.assertEqual( + b2.polymorphic_ctype, + ContentType.objects.db_manager("secondary").get_for_model(Model2B), + ) + cfs1 = Model2C.objects.db_manager("secondary").create_from_super(b2, field3="C3y") + self.assertEqual(type(cfs1), Model2C) + self.assertEqual(cfs1.pk, b2.pk) + self.assertEqual(cfs1.field1, "B1b") + self.assertEqual(cfs1.field2, "B2b") + self.assertEqual(cfs1.field3, "C3y") + b2.refresh_from_db() + self.assertEqual( + b2.polymorphic_ctype, + ContentType.objects.db_manager("secondary").get_for_model(Model2C), + ) + self.assertEqual( + cfs1.polymorphic_ctype, + ContentType.objects.db_manager("secondary").get_for_model(Model2C), + ) + + self.assertEqual( + set(Model2A.objects.db_manager("secondary").all()), + {a1, a2, b1, dfs1, cfs1, c2, d1, d2}, + ) + + self.assertEqual(Model2A.objects.count(), 0) diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index e0641237..6160ab75 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -24,6 +24,7 @@ CustomPkBase, CustomPkInherit, Enhance_Base, + Enhance_Plain, Enhance_Inherit, InlineParent, InlineModelA, @@ -1769,3 +1770,101 @@ def test_manytomany_with_through_field(self): self.assertEqual(lake.ducks.count(), 2) self.assertIsInstance(lake.ducks.all()[0], RubberDuck) self.assertIsInstance(lake.ducks.all()[1], RedheadDuck) + + def test_create_from_super(self): + # run create test 3 times because initial implementation + # would fail after first success. + from polymorphic.tests.models import ( + NormalBase, + NormalExtension, + PolyExtension, + PolyExtChild, + ) + + nb = NormalBase.objects.create(nb_field=1) + ne = NormalExtension.objects.create(nb_field=2, ne_field="ne2") + + with self.assertRaises(TypeError): + PolyExtension.objects.create_from_super(nb, poly_ext_field=3) + + pe = PolyExtension.objects.create_from_super(ne, poly_ext_field=3) + + ne.refresh_from_db() + self.assertEqual(type(ne), NormalExtension) + self.assertEqual(type(pe), PolyExtension) + self.assertEqual(pe.pk, ne.pk) + + self.assertEqual(pe.nb_field, 2) + self.assertEqual(pe.ne_field, "ne2") + self.assertEqual(pe.poly_ext_field, 3) + pe.refresh_from_db() + self.assertEqual(pe.nb_field, 2) + self.assertEqual(pe.ne_field, "ne2") + self.assertEqual(pe.poly_ext_field, 3) + + pc = PolyExtChild.objects.create_from_super(pe, poly_child_field="pcf6") + + pe.refresh_from_db() + ne.refresh_from_db() + self.assertEqual(type(ne), NormalExtension) + self.assertEqual(type(pe), PolyExtension) + self.assertEqual(pe.pk, ne.pk) + self.assertEqual(pe.pk, pc.pk) + + self.assertEqual(pc.nb_field, 2) + self.assertEqual(pc.ne_field, "ne2") + self.assertEqual(pc.poly_ext_field, 3) + pc.refresh_from_db() + self.assertEqual(pc.nb_field, 2) + self.assertEqual(pc.ne_field, "ne2") + self.assertEqual(pc.poly_ext_field, 3) + self.assertEqual(pc.poly_child_field, "pcf6") + + self.assertEqual(pe.polymorphic_ctype, ContentType.objects.get_for_model(PolyExtChild)) + self.assertEqual(pc.polymorphic_ctype, ContentType.objects.get_for_model(PolyExtChild)) + + self.assertEqual(set(PolyExtension.objects.all()), {pc}) + + a1 = Model2A.objects.create(field1="A1a") + a2 = Model2A.objects.create(field1="A1b") + + b1 = Model2B.objects.create(field1="B1a", field2="B2a") + b2 = Model2B.objects.create(field1="B1b", field2="B2b") + + c1 = Model2C.objects.create(field1="C1a", field2="C2a", field3="C3a") + c2 = Model2C.objects.create(field1="C1b", field2="C2b", field3="C3b") + + d1 = Model2D.objects.create(field1="D1a", field2="D2a", field3="D3a", field4="D4a") + d2 = Model2D.objects.create(field1="D1b", field2="D2b", field3="D3b", field4="D4b") + + with self.assertRaises(TypeError): + Model2D.objects.create_from_super(b1, field3="D3x", field4="D4x") + + b1_of_c = Model2B.objects.non_polymorphic().get(pk=c1.pk) + with self.assertRaises(TypeError): + Model2C.objects.create_from_super(b1_of_c, field3="C3x") + + self.assertEqual(c1.polymorphic_ctype, ContentType.objects.get_for_model(Model2C)) + dfs1 = Model2D.objects.create_from_super(b1_of_c, field4="D4x") + self.assertEqual(type(dfs1), Model2D) + self.assertEqual(dfs1.pk, c1.pk) + self.assertEqual(dfs1.field1, "C1a") + self.assertEqual(dfs1.field2, "C2a") + self.assertEqual(dfs1.field3, "C3a") + self.assertEqual(dfs1.field4, "D4x") + self.assertEqual(dfs1.polymorphic_ctype, ContentType.objects.get_for_model(Model2D)) + c1.refresh_from_db() + self.assertEqual(c1.polymorphic_ctype, ContentType.objects.get_for_model(Model2D)) + + self.assertEqual(b2.polymorphic_ctype, ContentType.objects.get_for_model(Model2B)) + cfs1 = Model2C.objects.create_from_super(b2, field3="C3y") + self.assertEqual(type(cfs1), Model2C) + self.assertEqual(cfs1.pk, b2.pk) + self.assertEqual(cfs1.field1, "B1b") + self.assertEqual(cfs1.field2, "B2b") + self.assertEqual(cfs1.field3, "C3y") + b2.refresh_from_db() + self.assertEqual(b2.polymorphic_ctype, ContentType.objects.get_for_model(Model2C)) + self.assertEqual(cfs1.polymorphic_ctype, ContentType.objects.get_for_model(Model2C)) + + self.assertEqual(set(Model2A.objects.all()), {a1, a2, b1, dfs1, cfs1, c2, d1, d2})