33from django .core .exceptions import ValidationError as InternalValidationError
44from django .core .exceptions import ImproperlyConfigured
55from django .db .models import Q , Prefetch , Manager
6+ from django .db .models .expressions import RawSQL , OrderBy
67import six
78from rest_framework import serializers
89from rest_framework .exceptions import ValidationError
9- from rest_framework .fields import BooleanField , NullBooleanField
10+ from rest_framework .fields import BooleanField , NullBooleanField , JSONField
1011from rest_framework .filters import BaseFilterBackend , OrderingFilter
11-
1212from dynamic_rest .utils import is_truthy
1313from dynamic_rest .conf import settings
1414from dynamic_rest .datastructures import TreeMap
@@ -127,6 +127,15 @@ def generate_query_key(self, serializer):
127127
128128 # Recurse into nested field
129129 s = getattr (field , 'serializer' , None )
130+ if isinstance (field , JSONField ):
131+ # If a json field is found, append any terms following
132+ j = i + 1
133+ while j < len (self .field ):
134+ rewritten .append (self .field [j ])
135+ j += 1
136+ if self .operator :
137+ rewritten .append (self .operator )
138+ return ('__' .join (rewritten ), field )
130139 if isinstance (s , serializers .ListSerializer ):
131140 s = s .child
132141 if not s :
@@ -294,33 +303,41 @@ def _filters_to_query(self, includes, excludes, serializer, q=None):
294303 q: Q() object (optional)
295304
296305 Returns:
297- Q() instance or None if no inclusion or exclusion filters
298- were specified.
306+ Tuple of:
307+ * Q() instance or None if no inclusion or exclusion filters
308+ were specified.
309+ * dictionary of {(field,): (operator, value)} for any json fields
299310 """
300311
301312 def rewrite_filters (filters , serializer ):
302313 out = {}
314+ json_out = {}
303315 for k , node in six .iteritems (filters ):
304316 filter_key , field = node .generate_query_key (serializer )
305317 if isinstance (field , (BooleanField , NullBooleanField )):
306318 node .value = is_truthy (node .value )
307- out [filter_key ] = node .value
308319
309- return out
320+ if isinstance (field , JSONField ):
321+ json_out [tuple (node .field )] = (node .operator , node .value )
322+ else :
323+ out [filter_key ] = node .value
324+ return out , json_out
310325
311326 q = q or Q ()
312327
328+ json_extras = None
329+
313330 if not includes and not excludes :
314- return None
331+ return None , None
315332
316333 if includes :
317- includes = rewrite_filters (includes , serializer )
334+ includes , json_extras = rewrite_filters (includes , serializer )
318335 q &= Q (** includes )
319336 if excludes :
320- excludes = rewrite_filters (excludes , serializer )
337+ excludes , json_extras = rewrite_filters (excludes , serializer )
321338 for k , v in six .iteritems (excludes ):
322339 q &= ~ Q (** {k : v })
323- return q
340+ return q , json_extras
324341
325342 def _create_prefetch (self , source , queryset ):
326343 return Prefetch (source , queryset = queryset )
@@ -569,7 +586,7 @@ def _build_queryset(
569586 queryset = queryset .only (* only )
570587
571588 # add request filters
572- query = self ._filters_to_query (
589+ query , json_extras = self ._filters_to_query (
573590 includes = filters .get ('_include' ),
574591 excludes = filters .get ('_exclude' ),
575592 serializer = serializer
@@ -579,12 +596,16 @@ def _build_queryset(
579596 if extra_filters :
580597 query = extra_filters if not query else extra_filters & query
581598
582- if query :
599+ if query or json_extras :
583600 # Convert internal django ValidationError to
584601 # APIException-based one in order to resolve validation error
585602 # from 500 status code to 400.
586603 try :
587604 queryset = queryset .filter (query )
605+
606+ if json_extras :
607+ extra_queries = self ._get_json_queries (json_extras )
608+ queryset = queryset .extra (where = extra_queries )
588609 except InternalValidationError as e :
589610 raise ValidationError (
590611 dict (e ) if hasattr (e , 'error_dict' ) else list (e )
@@ -620,6 +641,52 @@ def _build_queryset(
620641 queryset ._using_prefetches = prefetches
621642 return queryset
622643
644+ def _get_json_queries (self , json_extras ):
645+ extra_queries = []
646+
647+ for json_field_names , (operator , value ) in six .iteritems (json_extras ):
648+ if not operator :
649+ query_operator = '='
650+ value = "'{}'" .format (value )
651+ elif operator in ('startswith' , 'istartswith' ):
652+ query_operator = 'ILIKE' if operator [0 ] == 'i' else 'LIKE'
653+ value = "'{}%%'" .format (value )
654+ elif operator in ('endswith' , 'iendswith' ):
655+ query_operator = 'ILIKE' if operator [0 ] == 'i' else 'LIKE'
656+ value = "'%%{}'" .format (value )
657+ elif operator in ('contains' , 'icontains' ):
658+ query_operator = 'ILIKE' if operator [0 ] == 'i' else 'LIKE'
659+ value = "'%%{}%%'" .format (value )
660+
661+ else :
662+ raise InternalValidationError (
663+ f"""Unsupported filter operation for nested JSON fields:
664+ { operator } """
665+ )
666+
667+ extra_query = []
668+
669+ for idx , k in enumerate (json_field_names ):
670+ if idx == 0 :
671+ extra_query .append (k )
672+ else :
673+ extra_query .append ("'{}'" .format (k ))
674+
675+ if idx == len (json_field_names ) - 1 :
676+ continue
677+ # the ->> operator returns a raw value
678+ elif idx == len (json_field_names ) - 2 :
679+ extra_query .append ('->>' )
680+ # the -> operator returns JSON
681+ else :
682+ extra_query .append ('->' )
683+
684+ extra_query .append (query_operator )
685+ extra_query .append (value )
686+ extra_queries .append (' ' .join (extra_query ))
687+
688+ return extra_queries
689+
623690
624691class FastDynamicFilterBackend (DynamicFilterBackend ):
625692 def _create_prefetch (self , source , queryset ):
@@ -665,7 +732,16 @@ def filter_queryset(self, request, queryset, view):
665732 """
666733 self .ordering_param = view .SORT
667734
668- ordering = self .get_ordering (request , queryset , view )
735+ ordering , nested = self .get_ordering (request , queryset , view )
736+ if ordering and nested :
737+ ordering_str = '' .join (ordering )
738+ if ordering_str .startswith ('-' ):
739+ return queryset .order_by (
740+ OrderBy (RawSQL ('LOWER( %s )' % (ordering_str [1 :]), nested ),
741+ descending = True ))
742+ return queryset .order_by (
743+ OrderBy (RawSQL ('LOWER(%s)' % (ordering_str ), nested ),
744+ descending = False ))
669745 if ordering :
670746 queryset = queryset .order_by (* ordering )
671747 if any (['__' in o for o in ordering ]):
@@ -681,11 +757,13 @@ def get_ordering(self, request, queryset, view):
681757 This method overwrites the DRF default so it can parse the array.
682758 """
683759 params = view .get_request_feature (view .SORT )
760+ nested = []
684761 if params :
685762 fields = [param .strip () for param in params ]
686- valid_ordering , invalid_ordering = self .remove_invalid_fields (
687- queryset , fields , view
688- )
763+ valid_ordering , invalid_ordering , nested = \
764+ self .remove_invalid_fields (
765+ queryset , fields , view
766+ )
689767
690768 # if any of the sort fields are invalid, throw an error.
691769 # else return the ordering
@@ -694,10 +772,10 @@ def get_ordering(self, request, queryset, view):
694772 "Invalid filter field: %s" % invalid_ordering
695773 )
696774 else :
697- return valid_ordering
775+ return valid_ordering , nested
698776
699777 # No sorting was included
700- return self .get_default_ordering (view )
778+ return self .get_default_ordering (view ), nested
701779
702780 def remove_invalid_fields (self , queryset , fields , view ):
703781 """Remove invalid fields from an ordering.
@@ -715,14 +793,14 @@ def remove_invalid_fields(self, queryset, fields, view):
715793 stripped_term = term .lstrip ('-' )
716794 # add back the '-' add the end if necessary
717795 reverse_sort_term = '' if len (stripped_term ) is len (term ) else '-'
718- ordering = self .ordering_for (stripped_term , view )
796+ ordering , nested = self .ordering_for (stripped_term , view )
719797
720798 if ordering :
721799 valid_orderings .append (reverse_sort_term + ordering )
722800 else :
723801 invalid_orderings .append (term )
724802
725- return valid_orderings , invalid_orderings
803+ return valid_orderings , invalid_orderings , nested
726804
727805 def ordering_for (self , term , view ):
728806 """
@@ -732,7 +810,7 @@ def ordering_for(self, term, view):
732810 Raise ImproperlyConfigured if serializer_class not set on view
733811 """
734812 if not self ._is_allowed_term (term , view ):
735- return None
813+ return None , None
736814
737815 serializer = self ._get_serializer_class (view )()
738816 serializer_chain = term .split ('.' )
@@ -742,9 +820,27 @@ def ordering_for(self, term, view):
742820 for segment in serializer_chain [:- 1 ]:
743821 field = serializer .get_all_fields ().get (segment )
744822
823+ # If its a JSONField, construct a RawSQL command in the form
824+ # of 'jsonField->{}'.format('nestedField')' or
825+ # 'jsonField->>{}->{}'.format('nested','doubleNested')
826+ if field and isinstance (field , JSONField ):
827+ json_chain_start = str (segment )
828+ json_chain = ''
829+ nested = []
830+ first = True
831+ for nterm in serializer_chain [1 :]:
832+ if first :
833+ json_chain += '->>%s'
834+ first = False
835+ else :
836+ json_chain = '->%s' + json_chain
837+ nested .append (nterm )
838+ json_chain = json_chain_start + json_chain
839+ return json_chain , nested
840+
745841 if not (field and field .source != '*' and
746842 isinstance (field , DynamicRelationField )):
747- return None
843+ return None , None
748844
749845 model_chain .append (field .source or segment )
750846
@@ -754,11 +850,11 @@ def ordering_for(self, term, view):
754850 last_field = serializer .get_all_fields ().get (last_segment )
755851
756852 if not last_field or last_field .source == '*' :
757- return None
853+ return None , None
758854
759855 model_chain .append (last_field .source or last_segment )
760856
761- return '__' .join (model_chain )
857+ return '__' .join (model_chain ), None
762858
763859 def _is_allowed_term (self , term , view ):
764860 valid_fields = getattr (view , 'ordering_fields' , self .ordering_fields )
0 commit comments