From 9b98005bf20a049287fb8e3942838525808c7ae0 Mon Sep 17 00:00:00 2001 From: Dalen Hughes Date: Mon, 27 Aug 2018 14:44:08 -0700 Subject: [PATCH 1/3] allow for sorting/filtering of JSON objects when using PostgreSQL #247 --- dynamic_rest/filters.py | 69 +++++++++++++++++++++++++++++++---------- 1 file changed, 53 insertions(+), 16 deletions(-) diff --git a/dynamic_rest/filters.py b/dynamic_rest/filters.py index f1cd59b0..50bfe55e 100644 --- a/dynamic_rest/filters.py +++ b/dynamic_rest/filters.py @@ -3,10 +3,11 @@ from django.core.exceptions import ValidationError as InternalValidationError from django.core.exceptions import ImproperlyConfigured from django.db.models import Q, Prefetch, Manager +from django.db.models.expressions import RawSQL, OrderBy from django.utils import six from rest_framework import serializers from rest_framework.exceptions import ValidationError -from rest_framework.fields import BooleanField, NullBooleanField +from rest_framework.fields import BooleanField, NullBooleanField, JSONField from rest_framework.filters import BaseFilterBackend, OrderingFilter from dynamic_rest.utils import is_truthy @@ -127,6 +128,15 @@ def generate_query_key(self, serializer): # Recurse into nested field s = getattr(field, 'serializer', None) + if isinstance(field, JSONField): + # If a json field is found, append any terms following + j = i+1 + while j < len(self.field): + rewritten.append(self.field[j]) + j += 1 + if self.operator: + rewritten.append(self.operator) + return ('__'.join(rewritten), self.field) if isinstance(s, serializers.ListSerializer): s = s.child if not s: @@ -192,14 +202,12 @@ def filter_queryset(self, request, queryset, view): # after this is called may not behave as expected extra_filters = self.view.get_extra_filters(request) - disable_prefetches = self.view.is_update() - self.DEBUG = settings.DEBUG return self._build_queryset( queryset=queryset, extra_filters=extra_filters, - disable_prefetches=disable_prefetches, + disable_prefetches=False, ) """ @@ -643,7 +651,16 @@ def filter_queryset(self, request, queryset, view): """ self.ordering_param = view.SORT - ordering = self.get_ordering(request, queryset, view) + ordering, nested = self.get_ordering(request, queryset, view) + if ordering and nested: + ordering_str = ''.join(ordering) + if ordering_str.startswith('-'): + return queryset.order_by( + OrderBy(RawSQL('LOWER( %s )' % (ordering_str[1:]), nested), + descending=True)) + return queryset.order_by( + OrderBy(RawSQL('LOWER(%s)' % (ordering_str), nested), + descending=False)) if ordering: return queryset.order_by(*ordering) @@ -656,11 +673,13 @@ def get_ordering(self, request, queryset, view): This method overwrites the DRF default so it can parse the array. """ params = view.get_request_feature(view.SORT) + nested = [] if params: fields = [param.strip() for param in params] - valid_ordering, invalid_ordering = self.remove_invalid_fields( - queryset, fields, view - ) + valid_ordering, invalid_ordering, nested = \ + self.remove_invalid_fields( + queryset, fields, view + ) # if any of the sort fields are invalid, throw an error. # else return the ordering @@ -669,10 +688,10 @@ def get_ordering(self, request, queryset, view): "Invalid filter field: %s" % invalid_ordering ) else: - return valid_ordering + return valid_ordering, nested # No sorting was included - return self.get_default_ordering(view) + return self.get_default_ordering(view), nested def remove_invalid_fields(self, queryset, fields, view): """Remove invalid fields from an ordering. @@ -690,14 +709,14 @@ def remove_invalid_fields(self, queryset, fields, view): stripped_term = term.lstrip('-') # add back the '-' add the end if necessary reverse_sort_term = '' if len(stripped_term) is len(term) else '-' - ordering = self.ordering_for(stripped_term, view) + ordering, nested = self.ordering_for(stripped_term, view) if ordering: valid_orderings.append(reverse_sort_term + ordering) else: invalid_orderings.append(term) - return valid_orderings, invalid_orderings + return valid_orderings, invalid_orderings, nested def ordering_for(self, term, view): """ @@ -707,7 +726,7 @@ def ordering_for(self, term, view): Raise ImproperlyConfigured if serializer_class not set on view """ if not self._is_allowed_term(term, view): - return None + return None, None serializer = self._get_serializer_class(view)() serializer_chain = term.split('.') @@ -717,9 +736,27 @@ def ordering_for(self, term, view): for segment in serializer_chain[:-1]: field = serializer.get_all_fields().get(segment) + # If its a JSONField, construct a RawSQL command in the form + # of 'jsonField->{}'.format('nestedField')' or + # 'jsonField->{}->>{}'.format('nested','doubleNested') + if field and isinstance(field, JSONField): + json_chain_start = str(segment) + json_chain = '' + nested = [] + first = True + for nterm in serializer_chain[1:]: + if first: + json_chain += '->>%s' + first = False + else: + json_chain = '->%s'+json_chain + nested.append(nterm) + json_chain = json_chain_start + json_chain + return json_chain, nested + if not (field and field.source != '*' and isinstance(field, DynamicRelationField)): - return None + return None, None model_chain.append(field.source or segment) @@ -729,11 +766,11 @@ def ordering_for(self, term, view): last_field = serializer.get_all_fields().get(last_segment) if not last_field or last_field.source == '*': - return None + return None, None model_chain.append(last_field.source or last_segment) - return '__'.join(model_chain) + return '__'.join(model_chain), None def _is_allowed_term(self, term, view): valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) From 3c8f98191efa50532cb7dcbd98e8488755f63dc8 Mon Sep 17 00:00:00 2001 From: Sagar Pandya Date: Mon, 22 Oct 2018 13:57:38 -0700 Subject: [PATCH 2/3] Fixed by where filtering by nested JSONField failed if the value was not a string --- dynamic_rest/filters.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/dynamic_rest/filters.py b/dynamic_rest/filters.py index 50bfe55e..722a9e6f 100644 --- a/dynamic_rest/filters.py +++ b/dynamic_rest/filters.py @@ -1,5 +1,6 @@ """This module contains custom filter backends.""" +import json from django.core.exceptions import ValidationError as InternalValidationError from django.core.exceptions import ImproperlyConfigured from django.db.models import Q, Prefetch, Manager @@ -136,7 +137,7 @@ def generate_query_key(self, serializer): j += 1 if self.operator: rewritten.append(self.operator) - return ('__'.join(rewritten), self.field) + return ('__'.join(rewritten), field) if isinstance(s, serializers.ListSerializer): s = s.child if not s: @@ -304,15 +305,24 @@ def _filters_to_query(self, includes, excludes, serializer, q=None): Q() instance or None if no inclusion or exclusion filters were specified. """ - def rewrite_filters(filters, serializer): out = {} for k, node in six.iteritems(filters): filter_key, field = node.generate_query_key(serializer) if isinstance(field, (BooleanField, NullBooleanField)): node.value = is_truthy(node.value) - out[filter_key] = node.value + # Who knows what the type of node.value is if it's JSON? + # it'll always come to us as `unicode` type. Therefore, let's try to json parse it: + if isinstance(field, JSONField): + try: + node.value = json.loads(node.value) + # it's a numeric type! json.loads will return the proper type + except ValueError: + # it's a string (json.loads('some string') will fail + # just leave it as is (that is, a unicode string) + pass + out[filter_key] = node.value return out q = q or Q() From 81451fd7a2a274313f9fcc957e1d9e15e3401522 Mon Sep 17 00:00:00 2001 From: Sagar Pandya Date: Mon, 22 Oct 2018 19:08:51 -0700 Subject: [PATCH 3/3] More robust fix for JSON filtering. Also supports a subset of operators (contains, startswith, etc) --- dynamic_rest/filters.py | 65 +++++++++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/dynamic_rest/filters.py b/dynamic_rest/filters.py index 722a9e6f..df70d4c4 100644 --- a/dynamic_rest/filters.py +++ b/dynamic_rest/filters.py @@ -1,6 +1,4 @@ """This module contains custom filter backends.""" - -import json from django.core.exceptions import ValidationError as InternalValidationError from django.core.exceptions import ImproperlyConfigured from django.db.models import Q, Prefetch, Manager @@ -237,7 +235,6 @@ def _get_requested_filters(self, **kwargs): out = TreeMap() for spec, value in six.iteritems(filters_map): - # Inclusion or exclusion? if spec[0] == '-': spec = spec[1:] @@ -302,42 +299,40 @@ def _filters_to_query(self, includes, excludes, serializer, q=None): q: Q() object (optional) Returns: - Q() instance or None if no inclusion or exclusion filters - were specified. + Tuple of: + * Q() instance or None if no inclusion or exclusion filters + were specified. + * dictionary of {(field,): (operator, value)} for any json fields """ def rewrite_filters(filters, serializer): out = {} + json_out = {} for k, node in six.iteritems(filters): filter_key, field = node.generate_query_key(serializer) if isinstance(field, (BooleanField, NullBooleanField)): node.value = is_truthy(node.value) - # Who knows what the type of node.value is if it's JSON? - # it'll always come to us as `unicode` type. Therefore, let's try to json parse it: if isinstance(field, JSONField): - try: - node.value = json.loads(node.value) - # it's a numeric type! json.loads will return the proper type - except ValueError: - # it's a string (json.loads('some string') will fail - # just leave it as is (that is, a unicode string) - pass - out[filter_key] = node.value - return out + json_out[tuple(node.field)] = (node.operator, node.value) + else: + out[filter_key] = node.value + return out, json_out q = q or Q() + json_extras = None + if not includes and not excludes: - return None + return None, None if includes: - includes = rewrite_filters(includes, serializer) + includes, json_extras = rewrite_filters(includes, serializer) q &= Q(**includes) if excludes: - excludes = rewrite_filters(excludes, serializer) + excludes, json_extras = rewrite_filters(excludes, serializer) for k, v in six.iteritems(excludes): q &= ~Q(**{k: v}) - return q + return q, json_extras def _create_prefetch(self, source, queryset): return Prefetch(source, queryset=queryset) @@ -571,7 +566,7 @@ def _build_queryset( queryset = queryset.only(*only) # add request filters - query = self._filters_to_query( + query, json_extras = self._filters_to_query( includes=filters.get('_include'), excludes=filters.get('_exclude'), serializer=serializer @@ -581,12 +576,38 @@ def _build_queryset( if extra_filters: query = extra_filters if not query else extra_filters & query - if query: + if query or json_extras: # Convert internal django ValidationError to # APIException-based one in order to resolve validation error # from 500 status code to 400. try: queryset = queryset.filter(query) + + if json_extras: + extra_queries = [] + for json_field_names, (operator, value) in six.iteritems(json_extras): + if not operator: + query_operator = '=' + value = "'{}'".format(value) + elif operator in ('startswith', 'istartswith'): + query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE' + value = "'{}%%'".format(value) + elif operator in ('endswith', 'iendswith'): + query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE' + value = "'%%{}'".format(value) + elif operator in ('contains', 'icontains'): + query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE' + value = "'%%{}%%'".format(value) + else: + raise InternalValidationError('Unsupported filter operation for nested JSON fields: {}'.format(operator)) + + extra_query = [] + extra_query.append(json_field_names[0] + '->>') + extra_query.append('->>'.join(["'{}'".format(k) for k in json_field_names[1:]])) + extra_query.append(query_operator) + extra_query.append(value) + extra_queries.append(' '.join(extra_query)) + queryset = queryset.extra(where=extra_queries) except InternalValidationError as e: raise ValidationError( dict(e) if hasattr(e, 'error_dict') else list(e)