Skip to content

Commit 8f2304f

Browse files
authored
Merge pull request #29 from SectorLabs/make-conditional-indexes-usable
Make ConditionalUniqueIndexes actually usable.
2 parents 6260551 + cd3dcff commit 8f2304f

File tree

3 files changed

+117
-10
lines changed

3 files changed

+117
-10
lines changed

psqlextra/compiler.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,26 @@ def _rewrite_insert_update(self, sql, params, returning):
103103
# for conflicts
104104
conflict_target = self._build_conflict_target()
105105

106+
index_predicate = self.query.index_predicate
107+
108+
sql_template = (
109+
'{insert} ON CONFLICT {conflict_target} DO UPDATE '
110+
'SET {update_columns} RETURNING {returning}'
111+
)
112+
113+
if index_predicate:
114+
sql_template = (
115+
'{insert} ON CONFLICT {conflict_target} WHERE {index_predicate} DO UPDATE '
116+
'SET {update_columns} RETURNING {returning}'
117+
)
118+
106119
return (
107-
(
108-
'{insert} ON CONFLICT {conflict_target} DO UPDATE'
109-
' SET {update_columns} RETURNING {returning}'
110-
).format(
120+
sql_template.format(
111121
insert=sql,
112122
conflict_target=conflict_target,
113123
update_columns=update_columns,
114-
returning=returning
124+
returning=returning,
125+
index_predicate=index_predicate,
115126
),
116127
params
117128
)

psqlextra/manager/manager.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self, model=None, query=None, using=None, hints=None):
2626

2727
self.conflict_target = None
2828
self.conflict_action = None
29+
self.index_predicate = None
2930

3031
def annotate(self, **annotations):
3132
"""Custom version of the standard annotate function
@@ -113,7 +114,7 @@ def update(self, **fields):
113114
# affected, let's do the same
114115
return len(rows)
115116

116-
def on_conflict(self, fields: List[Union[str, Tuple[str]]], action):
117+
def on_conflict(self, fields: List[Union[str, Tuple[str]]], action, index_predicate: str=None):
117118
"""Sets the action to take when conflicts arise when attempting
118119
to insert/create a new row.
119120
@@ -123,10 +124,16 @@ def on_conflict(self, fields: List[Union[str, Tuple[str]]], action):
123124
124125
action:
125126
The action to take when the conflict occurs.
127+
128+
index_predicate:
129+
The index predicate to satisfy an arbiter partial index (i.e. what partial index to use for checking
130+
conflicts)
126131
"""
127132

128133
self.conflict_target = fields
129134
self.conflict_action = action
135+
self.index_predicate = index_predicate
136+
130137
return self
131138

132139
def bulk_insert(self, rows):
@@ -216,7 +223,7 @@ def insert_and_get(self, **fields):
216223

217224
return self.model(**model_init_fields)
218225

219-
def upsert(self, conflict_target: List, fields: Dict) -> int:
226+
def upsert(self, conflict_target: List, fields: Dict, index_predicate: str=None) -> int:
220227
"""Creates a new record or updates the existing one
221228
with the specified data.
222229
@@ -227,11 +234,15 @@ def upsert(self, conflict_target: List, fields: Dict) -> int:
227234
fields:
228235
Fields to insert/update.
229236
237+
index_predicate:
238+
The index predicate to satisfy an arbiter partial index (i.e. what partial index to use for checking
239+
conflicts)
240+
230241
Returns:
231242
The primary key of the row that was created/updated.
232243
"""
233244

234-
self.on_conflict(conflict_target, ConflictAction.UPDATE)
245+
self.on_conflict(conflict_target, ConflictAction.UPDATE, index_predicate)
235246
return self.insert(**fields)
236247

237248
def upsert_and_get(self, conflict_target: List, fields: Dict):
@@ -307,6 +318,7 @@ def _build_insert_compiler(self, rows: List[Dict]):
307318
query = PostgresInsertQuery(self.model)
308319
query.conflict_action = self.conflict_action
309320
query.conflict_target = self.conflict_target
321+
query.index_predicate = self.index_predicate
310322
query.values(objs, insert_fields, update_fields)
311323

312324
# use the postgresql insert query compiler to transform the insert
@@ -466,7 +478,7 @@ def on_conflict(self, fields: List[Union[str, Tuple[str]]], action):
466478
"""
467479
return self.get_queryset().on_conflict(fields, action)
468480

469-
def upsert(self, conflict_target: List, fields: Dict) -> int:
481+
def upsert(self, conflict_target: List, fields: Dict, index_predicate: str=None) -> int:
470482
"""Creates a new record or updates the existing one
471483
with the specified data.
472484
@@ -477,11 +489,14 @@ def upsert(self, conflict_target: List, fields: Dict) -> int:
477489
fields:
478490
Fields to insert/update.
479491
492+
index_predicate:
493+
The index predicate to satisfy an arbiter partial index.
494+
480495
Returns:
481496
The primary key of the row that was created/updated.
482497
"""
483498

484-
return self.get_queryset().upsert(conflict_target, fields)
499+
return self.get_queryset().upsert(conflict_target, fields, index_predicate)
485500

486501
def upsert_and_get(self, conflict_target: List, fields: Dict):
487502
"""Creates a new record or updates the existing one

tests/test_conditional_unique_index.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from django.db import models, IntegrityError, transaction
77
from django.db.migrations import AddIndex, CreateModel
88

9+
from .util import get_fake_model
10+
911

1012
def test_deconstruct():
1113
"""Tests whether the :see:HStoreField's deconstruct()
@@ -73,3 +75,82 @@ def test_migrations():
7375
Model.objects.create(id=1, name=None, other_name="other_name")
7476
with pytest.raises(IntegrityError):
7577
Model.objects.create(id=2, name=None, other_name="other_name")
78+
79+
80+
def test_upserting():
81+
"""Tests upserting respects the :see:ConditionalUniqueIndex rules"""
82+
model = get_fake_model(
83+
fields={
84+
'a': models.IntegerField(),
85+
'b': models.IntegerField(null=True),
86+
'c': models.IntegerField(),
87+
},
88+
meta_options={
89+
'indexes': [
90+
ConditionalUniqueIndex(
91+
fields=['a', 'b'],
92+
condition='"b" IS NOT NULL'
93+
),
94+
ConditionalUniqueIndex(
95+
fields=['a'],
96+
condition='"b" IS NULL'
97+
)
98+
]
99+
}
100+
)
101+
102+
model.objects.upsert(conflict_target=['a'], index_predicate='"b" IS NULL', fields=dict(a=1, c=1))
103+
assert model.objects.all().count() == 1
104+
assert model.objects.filter(a=1, c=1).count() == 1
105+
106+
model.objects.upsert(conflict_target=['a'], index_predicate='"b" IS NULL', fields=dict(a=1, c=2))
107+
assert model.objects.all().count() == 1
108+
assert model.objects.filter(a=1, c=1).count() == 0
109+
assert model.objects.filter(a=1, c=2).count() == 1
110+
111+
model.objects.upsert(conflict_target=['a', 'b'], index_predicate='"b" IS NOT NULL', fields=dict(a=1, b=1, c=1))
112+
assert model.objects.all().count() == 2
113+
assert model.objects.filter(a=1, c=2).count() == 1
114+
assert model.objects.filter(a=1, b=1, c=1).count() == 1
115+
116+
model.objects.upsert(conflict_target=['a', 'b'], index_predicate='"b" IS NOT NULL', fields=dict(a=1, b=1, c=2))
117+
assert model.objects.all().count() == 2
118+
assert model.objects.filter(a=1, c=1).count() == 0
119+
assert model.objects.filter(a=1, b=1, c=2).count() == 1
120+
121+
122+
def test_inserting():
123+
"""Tests inserting respects the :see:ConditionalUniqueIndex rules"""
124+
125+
model = get_fake_model(
126+
fields={
127+
'a': models.IntegerField(),
128+
'b': models.IntegerField(null=True),
129+
'c': models.IntegerField(),
130+
},
131+
meta_options={
132+
'indexes': [
133+
ConditionalUniqueIndex(
134+
fields=['a', 'b'],
135+
condition='"b" IS NOT NULL'
136+
),
137+
ConditionalUniqueIndex(
138+
fields=['a'],
139+
condition='"b" IS NULL'
140+
)
141+
]
142+
}
143+
)
144+
145+
model.objects.create(a=1, c=1)
146+
with transaction.atomic():
147+
with pytest.raises(IntegrityError):
148+
model.objects.create(a=1, c=2)
149+
model.objects.create(a=2, c=1)
150+
151+
model.objects.create(a=1, b=1, c=1)
152+
with transaction.atomic():
153+
with pytest.raises(IntegrityError):
154+
model.objects.create(a=1, b=1, c=2)
155+
156+
model.objects.create(a=1, b=2, c=1)

0 commit comments

Comments
 (0)