Skip to content

Commit 64b9c10

Browse files
author
Ryan P Kilby
authored
Add additional lookups to RelatedFilter (#114)
1 parent 15cd1c1 commit 64b9c10

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

rest_framework_filters/filters.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ def _import_class(path):
2121

2222

2323
class RelatedFilter(ModelChoiceFilter):
24-
def __init__(self, filterset, *args, **kwargs):
24+
def __init__(self, filterset, lookups=None, *args, **kwargs):
2525
self.filterset = filterset
26+
self.lookups = lookups
2627
return super(RelatedFilter, self).__init__(*args, **kwargs)
2728

2829
def filterset():
@@ -45,7 +46,7 @@ def field(self):
4546

4647

4748
class AllLookupsFilter(Filter):
48-
pass
49+
lookups = '__all__'
4950

5051

5152
###################################################

rest_framework_filters/filterset.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,18 @@ def __new__(cls, name, bases, attrs):
5454
# Populate our FilterSet fields with all the possible
5555
# filters for the AllLookupsFilter field.
5656
for name, filter_ in six.iteritems(new_class.base_filters.copy()):
57-
if isinstance(filter_, filters.AllLookupsFilter):
57+
if isinstance(filter_, (filters.AllLookupsFilter, filters.RelatedFilter)):
5858
field = filterset.get_model_field(opts.model, filter_.name)
5959

60-
for lookup_expr in utils.lookups_for_field(field):
60+
lookups = filter_.lookups or []
61+
if lookups == '__all__':
62+
lookups = utils.lookups_for_field(field)
63+
64+
for lookup_expr in lookups:
65+
if isinstance(filter_, filters.RelatedFilter) and lookup_expr == 'exact':
66+
# Don't replace the RelatedFilter
67+
continue
68+
6169
if isinstance(field, ForeignObjectRel):
6270
f = new_class.filter_for_reverse_field(field, filter_.name)
6371
else:

tests/test_filterset.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,29 @@ class Meta:
189189
self.assertIsInstance(F.base_filters['author'], filters.RelatedFilter)
190190
self.assertIsInstance(F.base_filters['author__in'], BaseInFilter)
191191

192+
def test_relatedfilter_lookups(self):
193+
# ensure that related filter is compatible with __all__ lookups.
194+
class F(FilterSet):
195+
author = filters.RelatedFilter(UserFilter, lookups='__all__')
196+
197+
class Meta:
198+
model = Note
199+
200+
self.assertIsInstance(F.base_filters['author'], filters.RelatedFilter)
201+
self.assertIsInstance(F.base_filters['author__in'], BaseInFilter)
202+
203+
def test_relatedfilter_lookups_list(self):
204+
# ensure that related filter is compatible with __all__ lookups.
205+
class F(FilterSet):
206+
author = filters.RelatedFilter(UserFilter, lookups=['in'])
207+
208+
class Meta:
209+
model = Note
210+
211+
self.assertEqual(len([f for f in F.base_filters if f.startswith('author')]), 2)
212+
self.assertIsInstance(F.base_filters['author'], filters.RelatedFilter)
213+
self.assertIsInstance(F.base_filters['author__in'], BaseInFilter)
214+
192215
def test_filter_persistence_with__all__(self):
193216
# ensure that __all__ does not overwrite declared filters.
194217
class F(FilterSet):

0 commit comments

Comments
 (0)