Skip to content

Commit efe6d16

Browse files
committed
Fix support for default values when upserting
Fixes #20
1 parent a963b75 commit efe6d16

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

psqlextra/manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from django.db import models, transaction
77
from django.db.models.sql import UpdateQuery
88
from django.db.models.sql.constants import CURSOR
9+
from django.db.models.fields import NOT_PROVIDED
910

1011
from . import signals
1112
from .compiler import (PostgresReturningUpdateCompiler,
@@ -323,10 +324,14 @@ def _get_upsert_fields(self, kwargs):
323324
update_fields = []
324325

325326
for field in model_instance._meta.local_concrete_fields:
326-
if field.name in kwargs or field.column in kwargs:
327+
has_default = field.default != NOT_PROVIDED
328+
if (field.name in kwargs or field.column in kwargs):
327329
insert_fields.append(field)
328330
update_fields.append(field)
329331
continue
332+
elif has_default:
333+
insert_fields.append(field)
334+
continue
330335

331336
# special handling for 'pk' which always refers to
332337
# the primary key, so if we the user specifies `pk`

tests/test_on_conflict.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,57 @@ def test_on_conflict_pk_conflict_target(conflict_action):
333333
assert obj1.id == obj2.id
334334
assert obj1.id == 0
335335
assert obj2.id == 0
336+
337+
338+
def test_on_conflict_default_value():
339+
"""Tests whether setting a default for a field and
340+
not specifying it explicitely when upserting properly
341+
causes the default value to be used."""
342+
343+
model = get_fake_model({
344+
'title': models.CharField(max_length=255, default='great')
345+
})
346+
347+
obj1 = (
348+
model.objects
349+
.on_conflict(['id'], ConflictAction.UPDATE)
350+
.insert_and_get(id=0)
351+
)
352+
353+
assert obj1.title == 'great'
354+
355+
obj2 = (
356+
model.objects
357+
.on_conflict(['id'], ConflictAction.UPDATE)
358+
.insert_and_get(id=0)
359+
)
360+
361+
assert obj1.id == obj2.id
362+
assert obj2.title == 'great'
363+
364+
365+
def test_on_conflict_default_value_no_overwrite():
366+
"""Tests whether setting a default for a field, inserting
367+
a non-default value and then trying to update it without
368+
specifying that field doesn't result in it being overwritten."""
369+
370+
model = get_fake_model({
371+
'title': models.CharField(max_length=255, default='great')
372+
})
373+
374+
obj1 = (
375+
model.objects
376+
.on_conflict(['id'], ConflictAction.UPDATE)
377+
.insert_and_get(id=0, title='mytitle')
378+
)
379+
380+
assert obj1.title == 'mytitle'
381+
382+
obj2 = (
383+
model.objects
384+
.on_conflict(['id'], ConflictAction.UPDATE)
385+
.insert_and_get(id=0)
386+
)
387+
388+
assert obj1.id == obj2.id
389+
assert obj2.title == 'mytitle'

0 commit comments

Comments
 (0)