Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,23 @@ Similarly, pre-V1.0 output formatting can be re-estated by using
``polymorphic_showfield_old_format = True``.


Creating Subclass Objects from Existing Superclass Objects
------------------------------------------------------------

You can create an instance of a subclass from an existing instance of a superclass using the
:meth:`~polymorphic.managers.PolymorphicManager.create_from_super` method
of the subclass's manager. For example:

.. code-block:: python
super_instance = ModelA.objects.get(id=1)
sub_instance = ModelB.objects.create_from_super(super_instance, field2='value2')
The restriction is that ``super_instance`` must be an instance of the direct superclass of
``ModelB``, and any required fields of ``ModelB`` must be provided as keyword arguments. If multiple
levels of subclassing are involved, you must call this method multiple times to "promote" each
level.

.. _restrictions:

Restrictions & Caveats
Expand Down
45 changes: 44 additions & 1 deletion src/polymorphic/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
The manager class for use in the models.
"""

from django.db import models
from django.contrib.contenttypes.models import ContentType
from django.db import DEFAULT_DB_ALIAS, models

from polymorphic.query import PolymorphicQuerySet

Expand Down Expand Up @@ -49,3 +50,45 @@ def not_instance_of(self, *args):

def get_real_instances(self, base_result_objects=None):
return self.all().get_real_instances(base_result_objects=base_result_objects)

def create_from_super(self, obj, **kwargs):
"""
Create an instance of this manager's model class from the given instance of a
parent class.

This is useful when "promoting" an instance down the inheritance chain.

:param obj: An instance of a parent class of the manager's model class.
:param kwargs: Additional fields to set on the new instance.
:return: The newly created instance.
"""
from .models import PolymorphicModel

# ensure we have the most derived real instance
if isinstance(obj, PolymorphicModel):
obj = obj.get_real_instance()

parent_ptr = self.model._meta.parents.get(type(obj), None)

if not parent_ptr:
raise TypeError(
f"{obj.__class__.__name__} is not a direct parent of {self.model.__name__}"
)
kwargs[parent_ptr.get_attname()] = obj.pk

# create the new base class with only fields that apply to it.
ctype = ContentType.objects.db_manager(
using=(obj._state.db or DEFAULT_DB_ALIAS)
).get_for_model(self.model)
nobj = self.model(**kwargs, polymorphic_ctype=ctype)
nobj.save_base(raw=True, using=obj._state.db or DEFAULT_DB_ALIAS, force_insert=True)
# force update the content type, but first we need to
# retrieve a clean copy from the db to fill in the null
# fields otherwise they would be overwritten.
if isinstance(obj, PolymorphicModel):
parent = obj.__class__.objects.using(obj._state.db or DEFAULT_DB_ALIAS).get(pk=obj.pk)
parent.polymorphic_ctype = ctype
parent.save()

nobj.refresh_from_db() # cast to cls
return nobj
4 changes: 2 additions & 2 deletions src/polymorphic/tests/migrations/0001_initial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 4.2 on 2025-12-13 10:57
# Generated by Django 4.2 on 2025-12-13 22:56

from django.conf import settings
from django.db import migrations, models
Expand All @@ -13,8 +13,8 @@ class Migration(migrations.Migration):
initial = True

dependencies = [
('auth', '0012_alter_user_first_name_max_length'),
('contenttypes', '0002_remove_content_type_name'),
('auth', '0012_alter_user_first_name_max_length'),
]

operations = [
Expand Down
139 changes: 139 additions & 0 deletions src/polymorphic/tests/test_multidb.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,142 @@ def run():

# Ensure no queries are made using the default database.
self.assertNumQueries(0, run)

def test_create_from_super(self):
# run create test 3 times because initial implementation
# would fail after first success.
from polymorphic.tests.models import (
NormalBase,
NormalExtension,
PolyExtension,
PolyExtChild,
)

nb = NormalBase.objects.db_manager("secondary").create(nb_field=1)
ne = NormalExtension.objects.db_manager("secondary").create(nb_field=2, ne_field="ne2")

with self.assertRaises(TypeError):
PolyExtension.objects.db_manager("secondary").create_from_super(nb, poly_ext_field=3)

pe = PolyExtension.objects.db_manager("secondary").create_from_super(ne, poly_ext_field=3)

ne.refresh_from_db()
self.assertEqual(type(ne), NormalExtension)
self.assertEqual(type(pe), PolyExtension)
self.assertEqual(pe.pk, ne.pk)

self.assertEqual(pe.nb_field, 2)
self.assertEqual(pe.ne_field, "ne2")
self.assertEqual(pe.poly_ext_field, 3)
pe.refresh_from_db()
self.assertEqual(pe.nb_field, 2)
self.assertEqual(pe.ne_field, "ne2")
self.assertEqual(pe.poly_ext_field, 3)

pc = PolyExtChild.objects.db_manager("secondary").create_from_super(
pe, poly_child_field="pcf6"
)

pe.refresh_from_db()
ne.refresh_from_db()
self.assertEqual(type(ne), NormalExtension)
self.assertEqual(type(pe), PolyExtension)
self.assertEqual(pe.pk, ne.pk)
self.assertEqual(pe.pk, pc.pk)

self.assertEqual(pc.nb_field, 2)
self.assertEqual(pc.ne_field, "ne2")
self.assertEqual(pc.poly_ext_field, 3)
pc.refresh_from_db()
self.assertEqual(pc.nb_field, 2)
self.assertEqual(pc.ne_field, "ne2")
self.assertEqual(pc.poly_ext_field, 3)
self.assertEqual(pc.poly_child_field, "pcf6")

self.assertEqual(
pe.polymorphic_ctype,
ContentType.objects.db_manager("secondary").get_for_model(PolyExtChild),
)
self.assertEqual(
pc.polymorphic_ctype,
ContentType.objects.db_manager("secondary").get_for_model(PolyExtChild),
)

self.assertEqual(set(PolyExtension.objects.db_manager("secondary").all()), {pc})

a1 = Model2A.objects.db_manager("secondary").create(field1="A1a")
a2 = Model2A.objects.db_manager("secondary").create(field1="A1b")

b1 = Model2B.objects.db_manager("secondary").create(field1="B1a", field2="B2a")
b2 = Model2B.objects.db_manager("secondary").create(field1="B1b", field2="B2b")

c1 = Model2C.objects.db_manager("secondary").create(
field1="C1a", field2="C2a", field3="C3a"
)
c2 = Model2C.objects.db_manager("secondary").create(
field1="C1b", field2="C2b", field3="C3b"
)

d1 = Model2D.objects.db_manager("secondary").create(
field1="D1a", field2="D2a", field3="D3a", field4="D4a"
)
d2 = Model2D.objects.db_manager("secondary").create(
field1="D1b", field2="D2b", field3="D3b", field4="D4b"
)

with self.assertRaises(TypeError):
Model2D.objects.db_manager("secondary").create_from_super(
b1, field3="D3x", field4="D4x"
)

b1_of_c = Model2B.objects.db_manager("secondary").non_polymorphic().get(pk=c1.pk)
with self.assertRaises(TypeError):
Model2C.objects.db_manager("secondary").create_from_super(b1_of_c, field3="C3x")

self.assertEqual(
c1.polymorphic_ctype,
ContentType.objects.db_manager("secondary").get_for_model(Model2C),
)
dfs1 = Model2D.objects.db_manager("secondary").create_from_super(b1_of_c, field4="D4x")
self.assertEqual(type(dfs1), Model2D)
self.assertEqual(dfs1.pk, c1.pk)
self.assertEqual(dfs1.field1, "C1a")
self.assertEqual(dfs1.field2, "C2a")
self.assertEqual(dfs1.field3, "C3a")
self.assertEqual(dfs1.field4, "D4x")
self.assertEqual(
dfs1.polymorphic_ctype,
ContentType.objects.db_manager("secondary").get_for_model(Model2D),
)
c1.refresh_from_db()
self.assertEqual(
c1.polymorphic_ctype,
ContentType.objects.db_manager("secondary").get_for_model(Model2D),
)

self.assertEqual(
b2.polymorphic_ctype,
ContentType.objects.db_manager("secondary").get_for_model(Model2B),
)
cfs1 = Model2C.objects.db_manager("secondary").create_from_super(b2, field3="C3y")
self.assertEqual(type(cfs1), Model2C)
self.assertEqual(cfs1.pk, b2.pk)
self.assertEqual(cfs1.field1, "B1b")
self.assertEqual(cfs1.field2, "B2b")
self.assertEqual(cfs1.field3, "C3y")
b2.refresh_from_db()
self.assertEqual(
b2.polymorphic_ctype,
ContentType.objects.db_manager("secondary").get_for_model(Model2C),
)
self.assertEqual(
cfs1.polymorphic_ctype,
ContentType.objects.db_manager("secondary").get_for_model(Model2C),
)

self.assertEqual(
set(Model2A.objects.db_manager("secondary").all()),
{a1, a2, b1, dfs1, cfs1, c2, d1, d2},
)

self.assertEqual(Model2A.objects.count(), 0)
99 changes: 99 additions & 0 deletions src/polymorphic/tests/test_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
CustomPkBase,
CustomPkInherit,
Enhance_Base,
Enhance_Plain,
Enhance_Inherit,
InlineParent,
InlineModelA,
Expand Down Expand Up @@ -1769,3 +1770,101 @@ def test_manytomany_with_through_field(self):
self.assertEqual(lake.ducks.count(), 2)
self.assertIsInstance(lake.ducks.all()[0], RubberDuck)
self.assertIsInstance(lake.ducks.all()[1], RedheadDuck)

def test_create_from_super(self):
# run create test 3 times because initial implementation
# would fail after first success.
from polymorphic.tests.models import (
NormalBase,
NormalExtension,
PolyExtension,
PolyExtChild,
)

nb = NormalBase.objects.create(nb_field=1)
ne = NormalExtension.objects.create(nb_field=2, ne_field="ne2")

with self.assertRaises(TypeError):
PolyExtension.objects.create_from_super(nb, poly_ext_field=3)

pe = PolyExtension.objects.create_from_super(ne, poly_ext_field=3)

ne.refresh_from_db()
self.assertEqual(type(ne), NormalExtension)
self.assertEqual(type(pe), PolyExtension)
self.assertEqual(pe.pk, ne.pk)

self.assertEqual(pe.nb_field, 2)
self.assertEqual(pe.ne_field, "ne2")
self.assertEqual(pe.poly_ext_field, 3)
pe.refresh_from_db()
self.assertEqual(pe.nb_field, 2)
self.assertEqual(pe.ne_field, "ne2")
self.assertEqual(pe.poly_ext_field, 3)

pc = PolyExtChild.objects.create_from_super(pe, poly_child_field="pcf6")

pe.refresh_from_db()
ne.refresh_from_db()
self.assertEqual(type(ne), NormalExtension)
self.assertEqual(type(pe), PolyExtension)
self.assertEqual(pe.pk, ne.pk)
self.assertEqual(pe.pk, pc.pk)

self.assertEqual(pc.nb_field, 2)
self.assertEqual(pc.ne_field, "ne2")
self.assertEqual(pc.poly_ext_field, 3)
pc.refresh_from_db()
self.assertEqual(pc.nb_field, 2)
self.assertEqual(pc.ne_field, "ne2")
self.assertEqual(pc.poly_ext_field, 3)
self.assertEqual(pc.poly_child_field, "pcf6")

self.assertEqual(pe.polymorphic_ctype, ContentType.objects.get_for_model(PolyExtChild))
self.assertEqual(pc.polymorphic_ctype, ContentType.objects.get_for_model(PolyExtChild))

self.assertEqual(set(PolyExtension.objects.all()), {pc})

a1 = Model2A.objects.create(field1="A1a")
a2 = Model2A.objects.create(field1="A1b")

b1 = Model2B.objects.create(field1="B1a", field2="B2a")
b2 = Model2B.objects.create(field1="B1b", field2="B2b")

c1 = Model2C.objects.create(field1="C1a", field2="C2a", field3="C3a")
c2 = Model2C.objects.create(field1="C1b", field2="C2b", field3="C3b")

d1 = Model2D.objects.create(field1="D1a", field2="D2a", field3="D3a", field4="D4a")
d2 = Model2D.objects.create(field1="D1b", field2="D2b", field3="D3b", field4="D4b")

with self.assertRaises(TypeError):
Model2D.objects.create_from_super(b1, field3="D3x", field4="D4x")

b1_of_c = Model2B.objects.non_polymorphic().get(pk=c1.pk)
with self.assertRaises(TypeError):
Model2C.objects.create_from_super(b1_of_c, field3="C3x")

self.assertEqual(c1.polymorphic_ctype, ContentType.objects.get_for_model(Model2C))
dfs1 = Model2D.objects.create_from_super(b1_of_c, field4="D4x")
self.assertEqual(type(dfs1), Model2D)
self.assertEqual(dfs1.pk, c1.pk)
self.assertEqual(dfs1.field1, "C1a")
self.assertEqual(dfs1.field2, "C2a")
self.assertEqual(dfs1.field3, "C3a")
self.assertEqual(dfs1.field4, "D4x")
self.assertEqual(dfs1.polymorphic_ctype, ContentType.objects.get_for_model(Model2D))
c1.refresh_from_db()
self.assertEqual(c1.polymorphic_ctype, ContentType.objects.get_for_model(Model2D))

self.assertEqual(b2.polymorphic_ctype, ContentType.objects.get_for_model(Model2B))
cfs1 = Model2C.objects.create_from_super(b2, field3="C3y")
self.assertEqual(type(cfs1), Model2C)
self.assertEqual(cfs1.pk, b2.pk)
self.assertEqual(cfs1.field1, "B1b")
self.assertEqual(cfs1.field2, "B2b")
self.assertEqual(cfs1.field3, "C3y")
b2.refresh_from_db()
self.assertEqual(b2.polymorphic_ctype, ContentType.objects.get_for_model(Model2C))
self.assertEqual(cfs1.polymorphic_ctype, ContentType.objects.get_for_model(Model2C))

self.assertEqual(set(Model2A.objects.all()), {a1, a2, b1, dfs1, cfs1, c2, d1, d2})
Loading