Skip to content

Commit 77a318a

Browse files
author
Ryan P Kilby
committed
Rework MethodFilter to works across relationships
1 parent 058cfbe commit 77a318a

File tree

5 files changed

+115
-1
lines changed

5 files changed

+115
-1
lines changed

rest_framework_filters/filters.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,47 @@ class InSetNumberFilter(InSetFilterBase, NumberFilter):
8787

8888
class InSetCharFilter(InSetFilterBase, NumberFilter):
8989
field_class = fields.ArrayCharField
90+
91+
92+
class MethodFilter(Filter):
93+
"""
94+
This filter will allow you to run a method that exists on the filterset class
95+
"""
96+
97+
def __init__(self, *args, **kwargs):
98+
self.action = kwargs.pop('action', '')
99+
super(MethodFilter, self).__init__(*args, **kwargs)
100+
101+
def resolve_action(self):
102+
"""
103+
This method provides a hook for the parent FilterSet to resolve the filter's
104+
action after initialization. This is necessary, as the filter name may change
105+
as it's expanded across related filtersets.
106+
107+
ie, `is_published` might become `post__is_published`.
108+
"""
109+
# noop if a function was provided as the action
110+
if callable(self.action):
111+
return
112+
113+
# otherwise, action is a string representing an action to be called on
114+
# the parent FilterSet.
115+
parent_action = self.action or 'filter_{0}'.format(self.name)
116+
117+
parent = getattr(self, 'parent', None)
118+
self.action = getattr(parent, parent_action, None)
119+
120+
assert callable(self.action), (
121+
'Expected parent FilterSet `%s.%s` to have a `.%s()` method.' %
122+
(parent.__class__.__module__, parent.__class__.__name__, parent_action)
123+
)
124+
125+
def filter(self, qs, value):
126+
"""
127+
This filter method will act as a proxy for the actual method we want to
128+
call.
129+
It will try to find the method on the parent filterset,
130+
if not it attempts to search for the method `field_{{attribute_name}}`.
131+
Otherwise it defaults to just returning the queryset.
132+
"""
133+
return self.action(self.name, qs, value)

rest_framework_filters/filterset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def __init__(self, *args, **kwargs):
7474
if isnull not in self.filters:
7575
self.filters[isnull] = filters.BooleanFilter(name=isnull)
7676

77+
elif isinstance(filter_, filters.MethodFilter):
78+
filter_.resolve_action()
79+
7780
def get_filters(self):
7881
"""
7982
Build a set of filters based on the requested data. The resulting set

tests/filters.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
from rest_framework_filters import filters
33
from rest_framework_filters.filters import RelatedFilter, AllLookupsFilter
4-
from rest_framework_filters.filterset import FilterSet
4+
from rest_framework_filters.filterset import FilterSet, LOOKUP_SEP
55

66

77
from .models import (
@@ -63,6 +63,32 @@ class Meta:
6363
model = Post
6464

6565

66+
class PostFilterWithMethod(FilterSet):
67+
note = RelatedFilter(NoteFilterWithRelatedAll, name='note')
68+
is_published = filters.MethodFilter()
69+
70+
class Meta:
71+
model = Post
72+
73+
def filter_is_published(self, name, qs, value):
74+
null = value.lower() != 'true'
75+
76+
# 'post', 'is_published'
77+
name, _ = name.rsplit(LOOKUP_SEP, 1)
78+
79+
return qs.filter(**{
80+
LOOKUP_SEP.join([name, 'date_published__isnull']): null
81+
})
82+
83+
84+
class CoverFilterWithRelatedMethodFilter(FilterSet):
85+
comment = filters.CharFilter(name='comment')
86+
post = RelatedFilter(PostFilterWithMethod, name='post')
87+
88+
class Meta:
89+
model = Cover
90+
91+
6692
class CoverFilterWithRelated(FilterSet):
6793
comment = filters.CharFilter(name='comment')
6894
post = RelatedFilter(PostFilterWithRelated, name='post')

tests/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ class Note(models.Model):
1212
class Post(models.Model):
1313
note = models.ForeignKey(Note)
1414
content = models.TextField()
15+
date_published = models.DateField(null=True)
1516

1617

1718
class Cover(models.Model):

tests/test_filterset.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
NoteFilterWithRelatedAll,
2323
NoteFilterWithRelatedAllDifferentFilterName,
2424
PostFilterWithRelated,
25+
# PostFilterWithMethod,
26+
CoverFilterWithRelatedMethodFilter,
2527
CoverFilterWithRelated,
2628
# PageFilterWithRelated,
2729
TagFilter,
@@ -324,6 +326,44 @@ def test_get_filterset_subset(self):
324326
self.assertEqual(len(filterset_class.base_filters), 1)
325327

326328

329+
class MethodFilterTests(TestCase):
330+
331+
if django.VERSION >= (1, 8):
332+
@classmethod
333+
def setUpTestData(cls):
334+
cls.generateTestData()
335+
336+
else:
337+
def setUp(self):
338+
self.generateTestData()
339+
340+
@classmethod
341+
def generateTestData(cls):
342+
user = User.objects.create(username="user1", email="user1@example.org")
343+
344+
note1 = Note.objects.create(title="Test 1", content="Test content 1", author=user)
345+
note2 = Note.objects.create(title="Test 2", content="Test content 2", author=user)
346+
347+
post1 = Post.objects.create(note=note1, content="Test content in post 1")
348+
post2 = Post.objects.create(note=note2, content="Test content in post 4", date_published=datetime.date.today())
349+
350+
Cover.objects.create(post=post1, comment="Cover 1")
351+
Cover.objects.create(post=post2, comment="Cover 2")
352+
353+
def test_related_method_filter(self):
354+
"""
355+
Missing MethodFilter filter methods are silently ignored, returning
356+
the unfiltered queryset.
357+
"""
358+
GET = {
359+
'post__is_published': 'true'
360+
}
361+
filterset = CoverFilterWithRelatedMethodFilter(GET, queryset=Cover.objects.all())
362+
results = list(filterset)
363+
self.assertEqual(len(results), 1)
364+
self.assertEqual(results[0].comment, "Cover 2")
365+
366+
327367
class DatetimeTests(TestCase):
328368

329369
if django.VERSION >= (1, 8):

0 commit comments

Comments
 (0)