Skip to content

Commit f7371d3

Browse files
committed
Merge pull request #23 from manuelnaranjo/master
__in filter
2 parents d3a6585 + 652f8fe commit f7371d3

File tree

4 files changed

+70
-0
lines changed

4 files changed

+70
-0
lines changed

rest_framework_filters/fields.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from django import forms
2+
3+
class ArrayDecimalField(forms.DecimalField):
4+
def clean(self, value):
5+
if value is None:
6+
return None
7+
8+
out = []
9+
for val in value.split(','):
10+
out.append(super(ArrayDecimalField, self).clean(val))
11+
return out

rest_framework_filters/filters.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import django_filters
99
from django_filters.filters import *
1010

11+
from . import fields
1112

1213
def subsitute_iso8601(date_type):
1314
from rest_framework import ISO_8601
@@ -80,3 +81,16 @@ class TimeFilter(django_filters.DateTimeFilter):
8081
def __init__(self, *args, **kwargs):
8182
super(TimeFilter, self).__init__(*args, **kwargs)
8283
self.extra.update({'input_formats': TIME_INPUT_FORMATS})
84+
85+
86+
class InSetNumberFilter(NumberFilter):
87+
field_class = fields.ArrayDecimalField
88+
89+
def filter(self, qs, value):
90+
if value in ([], (), {}, None, ''):
91+
return qs
92+
method = qs.exclude if self.exclude else qs.filter
93+
qs = method(**{self.name: value})
94+
if self.distinct:
95+
qs = qs.distinct()
96+
return qs

rest_framework_filters/filterset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ def fix_filter_field(self, f):
7373
lookup_type = f.lookup_type
7474
if lookup_type == 'isnull':
7575
return filters.BooleanFilter(name=("%s%sisnull" % (f.name, LOOKUP_SEP)))
76+
if lookup_type == 'in' and type(f) in [filters.NumberFilter]:
77+
return filters.InSetNumberFilter(name=("%s%sin" % (f.name, LOOKUP_SEP)))
7678
return f
7779

7880
def populate_from_filterset(self, filterset, filter_, name):

rest_framework_filters/tests.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,11 @@ class ExplicitLookupsPersonDateFilter(FilterSet):
231231
class Meta:
232232
model = Person
233233

234+
class InSetLookupPersonFilter(FilterSet):
235+
pk = AllLookupsFilter('id')
236+
237+
class Meta:
238+
model = Person
234239

235240
class TestFilterSets(TestCase):
236241
def setUp(self):
@@ -626,3 +631,41 @@ class Meta:
626631
self.assertEqual(len(list(f)), 1)
627632
p = list(f)[0]
628633
self.assertEqual(p.name, "John")
634+
635+
def test_inset_filter(self):
636+
p1 = Person.objects.get(name="John").pk
637+
p2 = Person.objects.get(name="Mark").pk
638+
639+
ALL_GET = {
640+
'pk__in': '{:d},{:d}'.format(p1, p2),
641+
}
642+
f = InSetLookupPersonFilter(ALL_GET, queryset=Person.objects.all())
643+
f = [x.pk for x in f]
644+
self.assertEqual(len(f), 2)
645+
self.assertIn(p1, f)
646+
self.assertIn(p2, f)
647+
648+
649+
INVALID_GET = {
650+
'pk__in': '{:d},c{:d}'.format(p1, p2)
651+
}
652+
f = InSetLookupPersonFilter(INVALID_GET, queryset=Person.objects.all())
653+
self.assertEqual(len(list(f)), 0)
654+
655+
EXTRA_GET = {
656+
'pk__in': '{:d},{:d},{:d}'.format(p1, p2, p1*p2)
657+
}
658+
f = InSetLookupPersonFilter(EXTRA_GET, queryset=Person.objects.all())
659+
f = [x.pk for x in f]
660+
self.assertEqual(len(f), 2)
661+
self.assertIn(p1, f)
662+
self.assertIn(p2, f)
663+
664+
DISORDERED_GET = {
665+
'pk__in': '{:d},{:d},{:d}'.format(p2, p2*p1, p1)
666+
}
667+
f = InSetLookupPersonFilter(DISORDERED_GET, queryset=Person.objects.all())
668+
f = [x.pk for x in f]
669+
self.assertEqual(len(f), 2)
670+
self.assertIn(p1, f)
671+
self.assertIn(p2, f)

0 commit comments

Comments
 (0)