Skip to content

Commit 04a9832

Browse files
author
sa-mmendivil
committed
allow for sorting/filtering of JSON objects when using PostgreSQL #247
1 parent cdf804f commit 04a9832

File tree

8 files changed

+276
-30
lines changed

8 files changed

+276
-30
lines changed

dynamic_rest/filters.py

Lines changed: 120 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from django.core.exceptions import ValidationError as InternalValidationError
44
from django.core.exceptions import ImproperlyConfigured
55
from django.db.models import Q, Prefetch, Manager
6+
from django.db.models.expressions import RawSQL, OrderBy
67
import six
78
from rest_framework import serializers
89
from rest_framework.exceptions import ValidationError
9-
from rest_framework.fields import BooleanField, NullBooleanField
10+
from rest_framework.fields import BooleanField, NullBooleanField, JSONField
1011
from rest_framework.filters import BaseFilterBackend, OrderingFilter
11-
1212
from dynamic_rest.utils import is_truthy
1313
from dynamic_rest.conf import settings
1414
from dynamic_rest.datastructures import TreeMap
@@ -127,6 +127,15 @@ def generate_query_key(self, serializer):
127127

128128
# Recurse into nested field
129129
s = getattr(field, 'serializer', None)
130+
if isinstance(field, JSONField):
131+
# If a json field is found, append any terms following
132+
j = i+1
133+
while j < len(self.field):
134+
rewritten.append(self.field[j])
135+
j += 1
136+
if self.operator:
137+
rewritten.append(self.operator)
138+
return ('__'.join(rewritten), field)
130139
if isinstance(s, serializers.ListSerializer):
131140
s = s.child
132141
if not s:
@@ -294,33 +303,41 @@ def _filters_to_query(self, includes, excludes, serializer, q=None):
294303
q: Q() object (optional)
295304
296305
Returns:
297-
Q() instance or None if no inclusion or exclusion filters
298-
were specified.
306+
Tuple of:
307+
* Q() instance or None if no inclusion or exclusion filters
308+
were specified.
309+
* dictionary of {(field,): (operator, value)} for any json fields
299310
"""
300311

301312
def rewrite_filters(filters, serializer):
302313
out = {}
314+
json_out = {}
303315
for k, node in six.iteritems(filters):
304316
filter_key, field = node.generate_query_key(serializer)
305317
if isinstance(field, (BooleanField, NullBooleanField)):
306318
node.value = is_truthy(node.value)
307-
out[filter_key] = node.value
308319

309-
return out
320+
if isinstance(field, JSONField):
321+
json_out[tuple(node.field)] = (node.operator, node.value)
322+
else:
323+
out[filter_key] = node.value
324+
return out, json_out
310325

311326
q = q or Q()
312327

328+
json_extras = None
329+
313330
if not includes and not excludes:
314-
return None
331+
return None, None
315332

316333
if includes:
317-
includes = rewrite_filters(includes, serializer)
334+
includes, json_extras = rewrite_filters(includes, serializer)
318335
q &= Q(**includes)
319336
if excludes:
320-
excludes = rewrite_filters(excludes, serializer)
337+
excludes, json_extras = rewrite_filters(excludes, serializer)
321338
for k, v in six.iteritems(excludes):
322339
q &= ~Q(**{k: v})
323-
return q
340+
return q, json_extras
324341

325342
def _create_prefetch(self, source, queryset):
326343
return Prefetch(source, queryset=queryset)
@@ -569,7 +586,7 @@ def _build_queryset(
569586
queryset = queryset.only(*only)
570587

571588
# add request filters
572-
query = self._filters_to_query(
589+
query, json_extras = self._filters_to_query(
573590
includes=filters.get('_include'),
574591
excludes=filters.get('_exclude'),
575592
serializer=serializer
@@ -579,12 +596,16 @@ def _build_queryset(
579596
if extra_filters:
580597
query = extra_filters if not query else extra_filters & query
581598

582-
if query:
599+
if query or json_extras:
583600
# Convert internal django ValidationError to
584601
# APIException-based one in order to resolve validation error
585602
# from 500 status code to 400.
586603
try:
587604
queryset = queryset.filter(query)
605+
606+
if json_extras:
607+
extra_queries = self._get_json_queries(json_extras)
608+
queryset = queryset.extra(where=extra_queries)
588609
except InternalValidationError as e:
589610
raise ValidationError(
590611
dict(e) if hasattr(e, 'error_dict') else list(e)
@@ -620,6 +641,52 @@ def _build_queryset(
620641
queryset._using_prefetches = prefetches
621642
return queryset
622643

644+
def _get_json_queries(self, json_extras):
645+
extra_queries = []
646+
647+
for json_field_names, (operator, value) in six.iteritems(json_extras):
648+
if not operator:
649+
query_operator = '='
650+
value = "'{}'".format(value)
651+
elif operator in ('startswith', 'istartswith'):
652+
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
653+
value = "'{}%%'".format(value)
654+
elif operator in ('endswith', 'iendswith'):
655+
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
656+
value = "'%%{}'".format(value)
657+
elif operator in ('contains', 'icontains'):
658+
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
659+
value = "'%%{}%%'".format(value)
660+
661+
else:
662+
raise InternalValidationError(
663+
f"""Unsupported filter operation for nested JSON fields:
664+
{operator}"""
665+
)
666+
667+
extra_query = []
668+
669+
for idx, k in enumerate(json_field_names):
670+
if idx == 0:
671+
extra_query.append(k)
672+
else:
673+
extra_query.append("'{}'".format(k))
674+
675+
if idx == len(json_field_names) - 1:
676+
continue
677+
# the ->> operator returns a raw value
678+
elif idx == len(json_field_names) - 2:
679+
extra_query.append('->>')
680+
# the -> operator returns JSON
681+
else:
682+
extra_query.append('->')
683+
684+
extra_query.append(query_operator)
685+
extra_query.append(value)
686+
extra_queries.append(' '.join(extra_query))
687+
688+
return extra_queries
689+
623690

624691
class FastDynamicFilterBackend(DynamicFilterBackend):
625692
def _create_prefetch(self, source, queryset):
@@ -665,7 +732,16 @@ def filter_queryset(self, request, queryset, view):
665732
"""
666733
self.ordering_param = view.SORT
667734

668-
ordering = self.get_ordering(request, queryset, view)
735+
ordering, nested = self.get_ordering(request, queryset, view)
736+
if ordering and nested:
737+
ordering_str = ''.join(ordering)
738+
if ordering_str.startswith('-'):
739+
return queryset.order_by(
740+
OrderBy(RawSQL('LOWER( %s )' % (ordering_str[1:]), nested),
741+
descending=True))
742+
return queryset.order_by(
743+
OrderBy(RawSQL('LOWER(%s)' % (ordering_str), nested),
744+
descending=False))
669745
if ordering:
670746
queryset = queryset.order_by(*ordering)
671747
if any(['__' in o for o in ordering]):
@@ -681,11 +757,13 @@ def get_ordering(self, request, queryset, view):
681757
This method overwrites the DRF default so it can parse the array.
682758
"""
683759
params = view.get_request_feature(view.SORT)
760+
nested = []
684761
if params:
685762
fields = [param.strip() for param in params]
686-
valid_ordering, invalid_ordering = self.remove_invalid_fields(
687-
queryset, fields, view
688-
)
763+
valid_ordering, invalid_ordering, nested = \
764+
self.remove_invalid_fields(
765+
queryset, fields, view
766+
)
689767

690768
# if any of the sort fields are invalid, throw an error.
691769
# else return the ordering
@@ -694,10 +772,10 @@ def get_ordering(self, request, queryset, view):
694772
"Invalid filter field: %s" % invalid_ordering
695773
)
696774
else:
697-
return valid_ordering
775+
return valid_ordering, nested
698776

699777
# No sorting was included
700-
return self.get_default_ordering(view)
778+
return self.get_default_ordering(view), nested
701779

702780
def remove_invalid_fields(self, queryset, fields, view):
703781
"""Remove invalid fields from an ordering.
@@ -715,14 +793,14 @@ def remove_invalid_fields(self, queryset, fields, view):
715793
stripped_term = term.lstrip('-')
716794
# add back the '-' add the end if necessary
717795
reverse_sort_term = '' if len(stripped_term) is len(term) else '-'
718-
ordering = self.ordering_for(stripped_term, view)
796+
ordering, nested = self.ordering_for(stripped_term, view)
719797

720798
if ordering:
721799
valid_orderings.append(reverse_sort_term + ordering)
722800
else:
723801
invalid_orderings.append(term)
724802

725-
return valid_orderings, invalid_orderings
803+
return valid_orderings, invalid_orderings, nested
726804

727805
def ordering_for(self, term, view):
728806
"""
@@ -732,7 +810,7 @@ def ordering_for(self, term, view):
732810
Raise ImproperlyConfigured if serializer_class not set on view
733811
"""
734812
if not self._is_allowed_term(term, view):
735-
return None
813+
return None, None
736814

737815
serializer = self._get_serializer_class(view)()
738816
serializer_chain = term.split('.')
@@ -742,9 +820,27 @@ def ordering_for(self, term, view):
742820
for segment in serializer_chain[:-1]:
743821
field = serializer.get_all_fields().get(segment)
744822

823+
# If its a JSONField, construct a RawSQL command in the form
824+
# of 'jsonField->{}'.format('nestedField')' or
825+
# 'jsonField->>{}->{}'.format('nested','doubleNested')
826+
if field and isinstance(field, JSONField):
827+
json_chain_start = str(segment)
828+
json_chain = ''
829+
nested = []
830+
first = True
831+
for nterm in serializer_chain[1:]:
832+
if first:
833+
json_chain += '->>%s'
834+
first = False
835+
else:
836+
json_chain = '->%s' + json_chain
837+
nested.append(nterm)
838+
json_chain = json_chain_start + json_chain
839+
return json_chain, nested
840+
745841
if not (field and field.source != '*' and
746842
isinstance(field, DynamicRelationField)):
747-
return None
843+
return None, None
748844

749845
model_chain.append(field.source or segment)
750846

@@ -754,11 +850,11 @@ def ordering_for(self, term, view):
754850
last_field = serializer.get_all_fields().get(last_segment)
755851

756852
if not last_field or last_field.source == '*':
757-
return None
853+
return None, None
758854

759855
model_chain.append(last_field.source or last_segment)
760856

761-
return '__'.join(model_chain)
857+
return '__'.join(model_chain), None
762858

763859
def _is_allowed_term(self, term, view):
764860
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import unicode_literals
3+
4+
from django.db import migrations, models
5+
from django.contrib.postgres.fields import JSONField
6+
7+
8+
class Migration(migrations.Migration):
9+
10+
dependencies = [
11+
('tests', '0006_auto_20210921_1026'),
12+
]
13+
14+
operations = [
15+
migrations.CreateModel(
16+
name='recipe',
17+
fields=[
18+
('name', models.CharField(max_length=60)),
19+
('ingredients', JSONField(null=True))
20+
]
21+
),
22+
]

tests/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from django.contrib.contenttypes.fields import GenericForeignKey
22
from django.contrib.contenttypes.models import ContentType
3+
from django.contrib.postgres.fields import JSONField
34
from django.db import models
45

56

@@ -137,3 +138,8 @@ class Part(models.Model):
137138
car = models.ForeignKey(Car, on_delete=models.CASCADE)
138139
name = models.CharField(max_length=60)
139140
country = models.ForeignKey(Country, on_delete=models.CASCADE)
141+
142+
143+
class Recipe(models.Model):
144+
name = models.CharField(max_length=60)
145+
ingredients = JSONField(null=True)

tests/serializers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
Part,
2323
Permission,
2424
Profile,
25+
Recipe,
2526
User,
2627
Zebra,
2728
)
@@ -323,3 +324,9 @@ class Meta:
323324
model = Car
324325
fields = ('id', 'name', 'country', 'parts')
325326
deferred_fields = ('name', 'country', 'parts')
327+
328+
329+
class RecipeSerializer(DynamicModelSerializer):
330+
class Meta:
331+
model = Recipe
332+
fields = ('name', 'ingredients')

0 commit comments

Comments
 (0)