Skip to content

Commit 294d811

Browse files
author
Ryan P Kilby
committed
Merge branch 'master' into fix-filter-complexity
Conflicts: rest_framework_filters/filters.py rest_framework_filters/filterset.py rest_framework_filters/tests.py
2 parents 8258761 + 28728fd commit 294d811

File tree

5 files changed

+147
-72
lines changed

5 files changed

+147
-72
lines changed

.travis.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ env:
1212
- DJANGO="Django>=1.8,<1.9"
1313
install:
1414
- travis_retry pip install -q $DJANGO
15-
- pip install py-dateutil
1615
- python setup.py install
1716
script: python manage.py test rest_framework_filters
1817

rest_framework_filters/fields.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
from django import forms
22

3+
4+
# https://code.djangoproject.com/ticket/19917
5+
class Django14TimeField(forms.TimeField):
6+
input_formats = ['%H:%M:%S', '%H:%M:%S.%f', '%H:%M']
7+
8+
39
class ArrayDecimalField(forms.DecimalField):
410
def clean(self, value):
511
if value is None:
@@ -9,3 +15,14 @@ def clean(self, value):
915
for val in value.split(','):
1016
out.append(super(ArrayDecimalField, self).clean(val))
1117
return out
18+
19+
20+
class ArrayCharField(forms.CharField):
21+
def clean(self, value):
22+
if value is None:
23+
return None
24+
25+
out = []
26+
for val in value.split(','):
27+
out.append(super(ArrayCharField, self).clean(val))
28+
return out

rest_framework_filters/filters.py

Lines changed: 13 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from collections import OrderedDict
55
from django.utils import six
66

7-
from rest_framework.settings import api_settings
8-
import django_filters
7+
import django
98
from django_filters.filters import *
109

1110
from . import fields
@@ -17,35 +16,6 @@ def _import_class(path):
1716
return getattr(module, class_name)
1817

1918

20-
def subsitute_iso8601(date_type):
21-
from rest_framework import ISO_8601
22-
23-
if date_type == 'datetime':
24-
strptime_iso8601 = '%Y-%m-%dT%H:%M:%S.%f'
25-
formats = api_settings.DATETIME_INPUT_FORMATS
26-
elif date_type == 'date':
27-
strptime_iso8601 = '%Y-%m-%d'
28-
formats = api_settings.DATE_INPUT_FORMATS
29-
elif date_type == 'time':
30-
strptime_iso8601 = '%H:%M:%S.%f'
31-
formats = api_settings.TIME_INPUT_FORMATS
32-
33-
new_formats = []
34-
for f in formats:
35-
if f == ISO_8601:
36-
new_formats.append(strptime_iso8601)
37-
else:
38-
new_formats.append(f)
39-
return new_formats
40-
41-
42-
# In order to support ISO-8601 -- which is the default output for
43-
# DRF -- we need to set up custom date/time input formats.
44-
TIME_INPUT_FORMATS = subsitute_iso8601('time')
45-
DATE_INPUT_FORMATS = subsitute_iso8601('date')
46-
DATETIME_INPUT_FORMATS = subsitute_iso8601('datetime')
47-
48-
4919
class RelatedFilter(ModelChoiceFilter):
5020
def __init__(self, filterset, *args, **kwargs):
5121
self.filterset = filterset
@@ -97,27 +67,12 @@ class AllLookupsFilter(Filter):
9767
# Fixed-up versions of some of the default filters
9868
###################################################
9969

100-
class DateFilter(django_filters.DateFilter):
101-
def __init__(self, *args, **kwargs):
102-
super(DateFilter, self).__init__(*args, **kwargs)
103-
self.extra.update({'input_formats': DATE_INPUT_FORMATS})
104-
105-
106-
class DateTimeFilter(django_filters.DateTimeFilter):
107-
def __init__(self, *args, **kwargs):
108-
super(DateTimeFilter, self).__init__(*args, **kwargs)
109-
self.extra.update({'input_formats': DATETIME_INPUT_FORMATS})
110-
70+
class TimeFilter(TimeFilter):
71+
if django.VERSION < (1, 6):
72+
field_class = fields.Django14TimeField
11173

112-
class TimeFilter(django_filters.DateTimeFilter):
113-
def __init__(self, *args, **kwargs):
114-
super(TimeFilter, self).__init__(*args, **kwargs)
115-
self.extra.update({'input_formats': TIME_INPUT_FORMATS})
116-
117-
118-
class InSetNumberFilter(NumberFilter):
119-
field_class = fields.ArrayDecimalField
12074

75+
class InSetFilterBase(object):
12176
def filter(self, qs, value):
12277
if value in ([], (), {}, None, ''):
12378
return qs
@@ -126,3 +81,11 @@ def filter(self, qs, value):
12681
if self.distinct:
12782
qs = qs.distinct()
12883
return qs
84+
85+
86+
class InSetNumberFilter(InSetFilterBase, NumberFilter):
87+
field_class = fields.ArrayDecimalField
88+
89+
90+
class InSetCharFilter(InSetFilterBase, NumberFilter):
91+
field_class = fields.ArrayCharField

rest_framework_filters/filterset.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,16 @@ def __new__(cls, name, bases, attrs):
5454

5555

5656
class FilterSet(six.with_metaclass(FilterSetMetaclass, filterset.FilterSet)):
57-
# In order to support ISO-8601 -- which is the default output for
58-
# DRF -- we need to set up custom date/time input formats.
5957
filter_overrides = {
58+
59+
# In order to support ISO-8601 -- which is the default output for
60+
# DRF -- we need to use django-filter's IsoDateTimeFilter
6061
models.DateTimeField: {
61-
'filter_class': filters.DateTimeFilter,
62-
},
63-
models.DateField: {
64-
'filter_class': filters.DateFilter,
65-
},
62+
'filter_class': filters.IsoDateTimeFilter,
63+
},
64+
65+
# Django < 1.6 time input formats did not account for microseconds
66+
# https://code.djangoproject.com/ticket/19917
6667
models.TimeField: {
6768
'filter_class': filters.TimeFilter,
6869
},
@@ -183,6 +184,8 @@ def fix_filter_field(cls, f):
183184
lookup_type = f.lookup_type
184185
if lookup_type == 'isnull':
185186
return filters.BooleanFilter(name=("%s%sisnull" % (f.name, LOOKUP_SEP)))
186-
if lookup_type == 'in' and type(f) in [filters.NumberFilter]:
187+
if lookup_type == 'in' and type(f) == filters.NumberFilter:
187188
return filters.InSetNumberFilter(name=("%s%sin" % (f.name, LOOKUP_SEP)))
189+
if lookup_type == 'in' and type(f) == filters.CharFilter:
190+
return filters.InSetCharFilter(name=("%s%sin" % (f.name, LOOKUP_SEP)))
188191
return f

rest_framework_filters/tests.py

Lines changed: 106 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
import time
66
import datetime
77

8-
from dateutil.parser import parse as date_parse
8+
from django.utils.dateparse import parse_time, parse_datetime
99

10+
import django
1011
from django.db import models
1112
from django.test import TestCase
1213
from django.contrib.auth.models import User
@@ -16,6 +17,12 @@
1617
from .filterset import FilterSet
1718
from .backends import DjangoFilterBackend
1819

20+
try:
21+
from django.test import override_settings
22+
except ImportError:
23+
# TODO: Remove this once Django 1.6 is EOL.
24+
from django.test.utils import override_settings
25+
1926

2027
class Note(models.Model):
2128
title = models.CharField(max_length=100)
@@ -231,14 +238,34 @@ class ExplicitLookupsPersonDateFilter(FilterSet):
231238
class Meta:
232239
model = Person
233240

234-
class InSetLookupPersonFilter(FilterSet):
241+
242+
class InSetLookupPersonIDFilter(FilterSet):
235243
pk = AllLookupsFilter('id')
236244

237245
class Meta:
238246
model = Person
239247

248+
249+
class InSetLookupPersonNameFilter(FilterSet):
250+
name = AllLookupsFilter('name')
251+
252+
class Meta:
253+
model = Person
254+
255+
240256
class TestFilterSets(TestCase):
241-
def setUp(self):
257+
258+
if django.VERSION >= (1, 8):
259+
@classmethod
260+
def setUpTestData(cls):
261+
cls.generateTestData()
262+
263+
else:
264+
def setUp(self):
265+
self.generateTestData()
266+
267+
@classmethod
268+
def generateTestData(cls):
242269
#######################
243270
# Create users
244271
#######################
@@ -262,7 +289,7 @@ def setUp(self):
262289
n.save()
263290

264291
#######################
265-
# Create notes
292+
# Create notes
266293
#######################
267294
n = Note(
268295
title="Test 2",
@@ -286,7 +313,7 @@ def setUp(self):
286313
n.save()
287314

288315
#######################
289-
# Create posts
316+
# Create posts
290317
#######################
291318
post = Post(
292319
note=Note.objects.get(title="Test 1"),
@@ -360,7 +387,7 @@ def setUp(self):
360387
)
361388
blogpost.save()
362389
blogpost.tags = [Tag.objects.get(name="house")]
363-
390+
364391
################################
365392
# Recursive relations
366393
################################
@@ -600,10 +627,10 @@ class Meta:
600627
date_str = JSONRenderer().render(data['date_joined']).decode('utf-8').strip('"')
601628

602629
# Adjust for imprecise rendering of time
603-
datetime_str = JSONRenderer().render(date_parse(data['datetime_joined']) + datetime.timedelta(seconds=0.6)).decode('utf-8').strip('"')
630+
datetime_str = JSONRenderer().render(parse_datetime(data['datetime_joined']) + datetime.timedelta(seconds=0.6)).decode('utf-8').strip('"')
604631

605632
# Adjust for imprecise rendering of time
606-
dt = datetime.datetime.combine(datetime.date.today(), date_parse(data['time_joined']).time()) + datetime.timedelta(seconds=0.6)
633+
dt = datetime.datetime.combine(datetime.date.today(), parse_time(data['time_joined'])) + datetime.timedelta(seconds=0.6)
607634
time_str = JSONRenderer().render(dt.time()).decode('utf-8').strip('"')
608635

609636
# DateField
@@ -632,14 +659,43 @@ class Meta:
632659
p = list(f)[0]
633660
self.assertEqual(p.name, "John")
634661

635-
def test_inset_filter(self):
662+
@override_settings(USE_TZ=True)
663+
def test_datetime_timezone_awareness(self):
664+
# Addresses issue #24 - ensure that datetime strings terminating
665+
# in 'Z' are correctly handled.
666+
from rest_framework import serializers
667+
from rest_framework.renderers import JSONRenderer
668+
669+
class PersonSerializer(serializers.ModelSerializer):
670+
class Meta:
671+
model = Person
672+
673+
# Figure out what the date strings should look like based on the
674+
# serializer output.
675+
john = Person.objects.get(name="John")
676+
data = PersonSerializer(john).data
677+
datetime_str = JSONRenderer().render(parse_datetime(data['datetime_joined']) + datetime.timedelta(seconds=0.6)).decode('utf-8').strip('"')
678+
679+
# This is more for documentation - DRF appends a 'Z' to timezone aware UTC datetimes when rendering:
680+
# https://github.com/tomchristie/django-rest-framework/blob/3.2.0/rest_framework/fields.py#L1002-L1006
681+
self.assertTrue(datetime_str.endswith('Z'))
682+
683+
GET = {
684+
'datetime_joined__lte': datetime_str,
685+
}
686+
f = AllLookupsPersonDateFilter(GET, queryset=Person.objects.all())
687+
self.assertEqual(len(list(f)), 1)
688+
p = list(f)[0]
689+
self.assertEqual(p.name, "John")
690+
691+
def test_inset_number_filter(self):
636692
p1 = Person.objects.get(name="John").pk
637693
p2 = Person.objects.get(name="Mark").pk
638694

639695
ALL_GET = {
640696
'pk__in': '{:d},{:d}'.format(p1, p2),
641697
}
642-
f = InSetLookupPersonFilter(ALL_GET, queryset=Person.objects.all())
698+
f = InSetLookupPersonIDFilter(ALL_GET, queryset=Person.objects.all())
643699
f = [x.pk for x in f]
644700
self.assertEqual(len(f), 2)
645701
self.assertIn(p1, f)
@@ -649,13 +705,13 @@ def test_inset_filter(self):
649705
INVALID_GET = {
650706
'pk__in': '{:d},c{:d}'.format(p1, p2)
651707
}
652-
f = InSetLookupPersonFilter(INVALID_GET, queryset=Person.objects.all())
708+
f = InSetLookupPersonIDFilter(INVALID_GET, queryset=Person.objects.all())
653709
self.assertEqual(len(list(f)), 0)
654710

655711
EXTRA_GET = {
656712
'pk__in': '{:d},{:d},{:d}'.format(p1, p2, p1*p2)
657713
}
658-
f = InSetLookupPersonFilter(EXTRA_GET, queryset=Person.objects.all())
714+
f = InSetLookupPersonIDFilter(EXTRA_GET, queryset=Person.objects.all())
659715
f = [x.pk for x in f]
660716
self.assertEqual(len(f), 2)
661717
self.assertIn(p1, f)
@@ -664,7 +720,7 @@ def test_inset_filter(self):
664720
DISORDERED_GET = {
665721
'pk__in': '{:d},{:d},{:d}'.format(p2, p2*p1, p1)
666722
}
667-
f = InSetLookupPersonFilter(DISORDERED_GET, queryset=Person.objects.all())
723+
f = InSetLookupPersonIDFilter(DISORDERED_GET, queryset=Person.objects.all())
668724
f = [x.pk for x in f]
669725
self.assertEqual(len(f), 2)
670726
self.assertIn(p1, f)
@@ -680,3 +736,40 @@ def test_get_filterset_subset(self):
680736
# ensure that the FilterSet subset only contains the requested fields
681737
self.assertIn('email', filterset_class.base_filters)
682738
self.assertEqual(len(filterset_class.base_filters), 1)
739+
740+
def test_inset_char_filter(self):
741+
p1 = Person.objects.get(name="John").name
742+
p2 = Person.objects.get(name="Mark").name
743+
744+
ALL_GET = {
745+
'name__in': '{},{}'.format(p1, p2),
746+
}
747+
f = InSetLookupPersonNameFilter(ALL_GET, queryset=Person.objects.all())
748+
f = [x.name for x in f]
749+
self.assertEqual(len(f), 2)
750+
self.assertIn(p1, f)
751+
self.assertIn(p2, f)
752+
753+
NONEXISTENT_GET = {
754+
'name__in': '{},Foo{}'.format(p1, p2)
755+
}
756+
f = InSetLookupPersonNameFilter(NONEXISTENT_GET, queryset=Person.objects.all())
757+
self.assertEqual(len(list(f)), 1)
758+
759+
EXTRA_GET = {
760+
'name__in': '{},{},{}'.format(p1, p2, p1+p2)
761+
}
762+
f = InSetLookupPersonNameFilter(EXTRA_GET, queryset=Person.objects.all())
763+
f = [x.name for x in f]
764+
self.assertEqual(len(f), 2)
765+
self.assertIn(p1, f)
766+
self.assertIn(p2, f)
767+
768+
DISORDERED_GET = {
769+
'name__in': '{},{},{}'.format(p2, p2+p1, p1)
770+
}
771+
f = InSetLookupPersonNameFilter(DISORDERED_GET, queryset=Person.objects.all())
772+
f = [x.name for x in f]
773+
self.assertEqual(len(f), 2)
774+
self.assertIn(p1, f)
775+
self.assertIn(p2, f)

0 commit comments

Comments
 (0)