Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 97 additions & 29 deletions dynamic_rest/filters.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""This module contains custom filter backends."""

from django.core.exceptions import ValidationError as InternalValidationError
from django.core.exceptions import ImproperlyConfigured
from django.db.models import Q, Prefetch, Manager
from django.db.models.expressions import RawSQL, OrderBy
from django.utils import six
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, NullBooleanField
from rest_framework.fields import BooleanField, NullBooleanField, JSONField
from rest_framework.filters import BaseFilterBackend, OrderingFilter

from dynamic_rest.utils import is_truthy
Expand Down Expand Up @@ -127,6 +127,15 @@ def generate_query_key(self, serializer):

# Recurse into nested field
s = getattr(field, 'serializer', None)
if isinstance(field, JSONField):
# If a json field is found, append any terms following
j = i+1
while j < len(self.field):
rewritten.append(self.field[j])
j += 1
if self.operator:
rewritten.append(self.operator)
return ('__'.join(rewritten), field)
if isinstance(s, serializers.ListSerializer):
s = s.child
if not s:
Expand Down Expand Up @@ -192,14 +201,12 @@ def filter_queryset(self, request, queryset, view):
# after this is called may not behave as expected
extra_filters = self.view.get_extra_filters(request)

disable_prefetches = self.view.is_update()

self.DEBUG = settings.DEBUG

return self._build_queryset(
queryset=queryset,
extra_filters=extra_filters,
disable_prefetches=disable_prefetches,
disable_prefetches=False,
)

"""
Expand Down Expand Up @@ -228,7 +235,6 @@ def _get_requested_filters(self, **kwargs):
out = TreeMap()

for spec, value in six.iteritems(filters_map):

# Inclusion or exclusion?
if spec[0] == '-':
spec = spec[1:]
Expand Down Expand Up @@ -293,33 +299,40 @@ def _filters_to_query(self, includes, excludes, serializer, q=None):
q: Q() object (optional)

Returns:
Q() instance or None if no inclusion or exclusion filters
were specified.
Tuple of:
* Q() instance or None if no inclusion or exclusion filters
were specified.
* dictionary of {(field,): (operator, value)} for any json fields
"""

def rewrite_filters(filters, serializer):
out = {}
json_out = {}
for k, node in six.iteritems(filters):
filter_key, field = node.generate_query_key(serializer)
if isinstance(field, (BooleanField, NullBooleanField)):
node.value = is_truthy(node.value)
out[filter_key] = node.value

return out
if isinstance(field, JSONField):
json_out[tuple(node.field)] = (node.operator, node.value)
else:
out[filter_key] = node.value
return out, json_out

q = q or Q()

json_extras = None

if not includes and not excludes:
return None
return None, None

if includes:
includes = rewrite_filters(includes, serializer)
includes, json_extras = rewrite_filters(includes, serializer)
q &= Q(**includes)
if excludes:
excludes = rewrite_filters(excludes, serializer)
excludes, json_extras = rewrite_filters(excludes, serializer)
for k, v in six.iteritems(excludes):
q &= ~Q(**{k: v})
return q
return q, json_extras

def _create_prefetch(self, source, queryset):
return Prefetch(source, queryset=queryset)
Expand Down Expand Up @@ -553,7 +566,7 @@ def _build_queryset(
queryset = queryset.only(*only)

# add request filters
query = self._filters_to_query(
query, json_extras = self._filters_to_query(
includes=filters.get('_include'),
excludes=filters.get('_exclude'),
serializer=serializer
Expand All @@ -563,12 +576,38 @@ def _build_queryset(
if extra_filters:
query = extra_filters if not query else extra_filters & query

if query:
if query or json_extras:
# Convert internal django ValidationError to
# APIException-based one in order to resolve validation error
# from 500 status code to 400.
try:
queryset = queryset.filter(query)

if json_extras:
extra_queries = []
for json_field_names, (operator, value) in six.iteritems(json_extras):
if not operator:
query_operator = '='
value = "'{}'".format(value)
elif operator in ('startswith', 'istartswith'):
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
value = "'{}%%'".format(value)
elif operator in ('endswith', 'iendswith'):
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
value = "'%%{}'".format(value)
elif operator in ('contains', 'icontains'):
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
value = "'%%{}%%'".format(value)
else:
raise InternalValidationError('Unsupported filter operation for nested JSON fields: {}'.format(operator))

extra_query = []
extra_query.append(json_field_names[0] + '->>')
extra_query.append('->>'.join(["'{}'".format(k) for k in json_field_names[1:]]))
extra_query.append(query_operator)
extra_query.append(value)
extra_queries.append(' '.join(extra_query))
queryset = queryset.extra(where=extra_queries)
except InternalValidationError as e:
raise ValidationError(
dict(e) if hasattr(e, 'error_dict') else list(e)
Expand Down Expand Up @@ -643,7 +682,16 @@ def filter_queryset(self, request, queryset, view):
"""
self.ordering_param = view.SORT

ordering = self.get_ordering(request, queryset, view)
ordering, nested = self.get_ordering(request, queryset, view)
if ordering and nested:
ordering_str = ''.join(ordering)
if ordering_str.startswith('-'):
return queryset.order_by(
OrderBy(RawSQL('LOWER( %s )' % (ordering_str[1:]), nested),
descending=True))
return queryset.order_by(
OrderBy(RawSQL('LOWER(%s)' % (ordering_str), nested),
descending=False))
if ordering:
return queryset.order_by(*ordering)

Expand All @@ -656,11 +704,13 @@ def get_ordering(self, request, queryset, view):
This method overwrites the DRF default so it can parse the array.
"""
params = view.get_request_feature(view.SORT)
nested = []
if params:
fields = [param.strip() for param in params]
valid_ordering, invalid_ordering = self.remove_invalid_fields(
queryset, fields, view
)
valid_ordering, invalid_ordering, nested = \
self.remove_invalid_fields(
queryset, fields, view
)

# if any of the sort fields are invalid, throw an error.
# else return the ordering
Expand All @@ -669,10 +719,10 @@ def get_ordering(self, request, queryset, view):
"Invalid filter field: %s" % invalid_ordering
)
else:
return valid_ordering
return valid_ordering, nested

# No sorting was included
return self.get_default_ordering(view)
return self.get_default_ordering(view), nested

def remove_invalid_fields(self, queryset, fields, view):
"""Remove invalid fields from an ordering.
Expand All @@ -690,14 +740,14 @@ def remove_invalid_fields(self, queryset, fields, view):
stripped_term = term.lstrip('-')
# add back the '-' add the end if necessary
reverse_sort_term = '' if len(stripped_term) is len(term) else '-'
ordering = self.ordering_for(stripped_term, view)
ordering, nested = self.ordering_for(stripped_term, view)

if ordering:
valid_orderings.append(reverse_sort_term + ordering)
else:
invalid_orderings.append(term)

return valid_orderings, invalid_orderings
return valid_orderings, invalid_orderings, nested

def ordering_for(self, term, view):
"""
Expand All @@ -707,7 +757,7 @@ def ordering_for(self, term, view):
Raise ImproperlyConfigured if serializer_class not set on view
"""
if not self._is_allowed_term(term, view):
return None
return None, None

serializer = self._get_serializer_class(view)()
serializer_chain = term.split('.')
Expand All @@ -717,9 +767,27 @@ def ordering_for(self, term, view):
for segment in serializer_chain[:-1]:
field = serializer.get_all_fields().get(segment)

# If its a JSONField, construct a RawSQL command in the form
# of 'jsonField->{}'.format('nestedField')' or
# 'jsonField->{}->>{}'.format('nested','doubleNested')
if field and isinstance(field, JSONField):
json_chain_start = str(segment)
json_chain = ''
nested = []
first = True
for nterm in serializer_chain[1:]:
if first:
json_chain += '->>%s'
first = False
else:
json_chain = '->%s'+json_chain
nested.append(nterm)
json_chain = json_chain_start + json_chain
return json_chain, nested

if not (field and field.source != '*' and
isinstance(field, DynamicRelationField)):
return None
return None, None

model_chain.append(field.source or segment)

Expand All @@ -729,11 +797,11 @@ def ordering_for(self, term, view):
last_field = serializer.get_all_fields().get(last_segment)

if not last_field or last_field.source == '*':
return None
return None, None

model_chain.append(last_field.source or last_segment)

return '__'.join(model_chain)
return '__'.join(model_chain), None

def _is_allowed_term(self, term, view):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
Expand Down