Skip to content

Commit 34be8d6

Browse files
committed
Experimental support for custom JOIN conditions
1 parent badf827 commit 34be8d6

File tree

5 files changed

+130
-4
lines changed

5 files changed

+130
-4
lines changed

psqlextra/datastructures.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import List, Tuple, Any
2+
3+
from django.db.models.sql.datastructures import Join
4+
5+
6+
class ConditionalJoin(Join):
7+
"""A custom JOIN statement that allows attaching
8+
extra conditions."""
9+
10+
def __init__(self, *args, **kwargs):
11+
"""Initializes a new instance of :see:ConditionalJoin."""
12+
13+
super().__init__(*args, **kwargs)
14+
self.join_type = 'LEFT OUTER JOIN'
15+
self.extra_conditions = []
16+
17+
def add_condition(self, field, value: Any) -> None:
18+
"""Adds an extra condition to this join.
19+
20+
Arguments:
21+
field:
22+
The field that the condition will apply to.
23+
24+
value:
25+
The value to compare.
26+
"""
27+
28+
self.extra_conditions.append((field, value))
29+
30+
def as_sql(self, compiler, connection) -> Tuple[str, List[Any]]:
31+
"""Compiles this JOIN into a SQL string."""
32+
33+
sql, params = super().as_sql(compiler, connection)
34+
qn = compiler.quote_name_unless_alias
35+
36+
# generate the extra conditions
37+
extra_conditions = ' AND '.join([
38+
'{}.{} = %s'.format(
39+
qn(self.table_name),
40+
qn(field.column)
41+
)
42+
for field, value in self.extra_conditions
43+
])
44+
45+
# add to the existing params, so the connector will
46+
# actually nicely format the value for us
47+
for _, value in self.extra_conditions:
48+
params.append(value)
49+
50+
# rewrite the sql to include the extra conditions
51+
rewritten_sql = sql.replace(')', ' AND {})'.format(extra_conditions))
52+
return rewritten_sql, params
53+
54+
@classmethod
55+
def from_join(cls, join: Join) -> 'ConditionalJoin':
56+
"""Creates a new :see:ConditionalJoin from the
57+
specified :see:Join object.
58+
59+
Arguments:
60+
join:
61+
The :see:Join object to create the
62+
:see:ConditionalJoin object from.
63+
64+
Returns:
65+
A :see:ConditionalJoin object created from
66+
the :see:Join object.
67+
"""
68+
69+
return cls(
70+
join.table_name,
71+
join.parent_alias,
72+
join.table_alias,
73+
join.join_type,
74+
join.join_field,
75+
join.nullable
76+
)

psqlextra/expressions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def resolve_expression(self, *args, **kwargs) -> HStoreColumn:
8787
)
8888
return expression
8989

90+
9091
class NonGroupableFunc(expressions.Func):
9192
"""A version of Django's :see:Func expression that
9293
is _never_ included in the GROUP BY clause."""

psqlextra/manager.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ def __init__(self, model=None, query=None, using=None, hints=None):
2626
self.conflict_target = None
2727
self.conflict_action = None
2828

29+
def join(self, **conditions):
30+
"""Adds extra conditions to existing joins.
31+
32+
WARNING: This is an extremely experimental feature.
33+
DO NOT USE unless you know what you're doing.
34+
"""
35+
36+
self.query.add_join_conditions(conditions)
37+
return self
38+
2939
def update(self, **fields):
3040
"""Updates all rows that match the filter."""
3141

@@ -35,7 +45,7 @@ def update(self, **fields):
3545
query._annotations = None
3646
query.add_update_values(fields)
3747

38-
# build the compiler for form the query
48+
# build the compiler for for the query
3949
connection = django.db.connections[self.db]
4050
compiler = PostgresReturningUpdateCompiler(query, connection, self.db)
4151

psqlextra/query.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
from typing import List, Tuple, Optional
1+
from typing import List, Tuple, Optional, Dict, Any
22
from enum import Enum
33

44
from django.db import models
55
from django.db.models import sql
6-
76
from django.db.models.constants import LOOKUP_SEP
7+
from django.core.exceptions import SuspiciousOperation
88

99
from .fields import HStoreField
1010
from .expressions import HStoreColumn
11+
from .datastructures import ConditionalJoin
1112

1213

1314
class ConflictAction(Enum):
@@ -18,6 +19,43 @@ class ConflictAction(Enum):
1819

1920

2021
class PostgresQuery(sql.Query):
22+
def add_join_conditions(self, conditions: Dict[str, Any]) -> None:
23+
"""Adds an extra condition to an existing JOIN.
24+
25+
This allows you to for example do:
26+
27+
INNER JOIN othertable ON (mytable.id = othertable.other_id AND [extra conditions])
28+
29+
This does not work if nothing else in your query doesn't already generate the
30+
initial join in the first place.
31+
"""
32+
33+
alias = self.get_initial_alias()
34+
opts = self.get_meta()
35+
36+
for name, value in conditions.items():
37+
parts = name.split(LOOKUP_SEP)
38+
_, targets, _, joins, path = self.setup_joins(parts, opts, alias, allow_many=True)
39+
self.trim_joins(targets, joins, path)
40+
41+
target_table = joins[-1]
42+
field = targets[-1]
43+
join = self.alias_map.get(target_table)
44+
45+
if not join:
46+
raise SuspiciousOperation((
47+
'Cannot add an extra join condition for "%s", there\'s no'
48+
'existing join to add it to.'
49+
) % target_table)
50+
51+
# convert the Join object into a ConditionalJoin object, which
52+
# allows us to add the extra condition
53+
if not isinstance(join, ConditionalJoin):
54+
self.alias_map[target_table] = ConditionalJoin.from_join(join)
55+
join = self.alias_map[target_table]
56+
57+
join.add_condition(field, value)
58+
2159
def add_fields(self, field_names: List[str], allow_m2m: bool=True) -> bool:
2260
"""
2361
Adds the given (model) fields to the select set. The field names are

tests/test_hstore_field.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def test_values():
5050

5151
result = list(model.objects.values_list('title__en', 'title__ar'))[0]
5252

53+
5354
def test_annotate_ref():
5455
"""Tests whether annotating using a :see:HStoreRef expression
5556
works correctly.
@@ -65,7 +66,7 @@ def test_annotate_ref():
6566
})
6667

6768
fk = model_fk.objects.create(title={'en': 'english', 'ar': 'arabic'})
68-
obj = model.objects.create(fk=fk)
69+
model.objects.create(fk=fk)
6970

7071
queryset = (
7172
model.objects

0 commit comments

Comments
 (0)