Skip to content

Commit b5ccd8a

Browse files
author
Ryan P Kilby
authored
Merge pull request #161 from rpkilby/fix-request-filtering
Fix request-based filtering
2 parents dd1237f + 72f2853 commit b5ccd8a

File tree

4 files changed

+66
-25
lines changed

4 files changed

+66
-25
lines changed

rest_framework_filters/backends.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11

2-
from django.template import Template, TemplateDoesNotExist, loader
3-
from rest_framework import compat
2+
from contextlib import contextmanager
43
from django_filters.rest_framework import backends
54

65
from .filterset import FilterSet
@@ -9,30 +8,27 @@
98
class DjangoFilterBackend(backends.DjangoFilterBackend):
109
default_filter_set = FilterSet
1110

12-
def filter_queryset(self, request, queryset, view):
13-
filter_class = self.get_filter_class(view, queryset)
14-
15-
if filter_class:
16-
if hasattr(filter_class, 'get_subset'):
17-
filter_class = filter_class.get_subset(request.query_params)
18-
return filter_class(request.query_params, queryset=queryset).qs
11+
@contextmanager
12+
def patched_filter_class(self, request):
13+
"""
14+
Patch `get_filter_class()` to get the subset based on the request params
15+
"""
16+
original = self.get_filter_class
1917

20-
return queryset
18+
def get_subset_class(view, queryset=None):
19+
filter_class = original(view, queryset)
2120

22-
def to_html(self, request, queryset, view):
23-
filter_class = self.get_filter_class(view, queryset)
24-
if not filter_class:
25-
return None
26-
filter_instance = filter_class(request.query_params, queryset=queryset)
21+
if filter_class and hasattr(filter_class, 'get_subset'):
22+
filter_class = filter_class.get_subset(request.query_params)
2723

28-
# forces `form` evaluation before `qs` is called. This prevents an empty form from being cached.
29-
filter_instance.form
24+
return filter_class
3025

31-
try:
32-
template = loader.get_template(self.template)
33-
except TemplateDoesNotExist:
34-
template = Template(backends.template_default)
26+
self.get_filter_class = get_subset_class
27+
yield
28+
self.get_filter_class = original
3529

36-
return compat.template_render(template, context={
37-
'filter': filter_instance
38-
})
30+
def filter_queryset(self, request, queryset, view):
31+
# patching the behavior of `get_filter_class()` in this method allows
32+
# us to avoid maintenance issues with code duplication.
33+
with self.patched_filter_class(request):
34+
return super(DjangoFilterBackend, self).filter_queryset(request, queryset, view)

rest_framework_filters/filterset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def expand_filters(self):
123123
if isinstance(f, filters.RelatedFilter) and filter_name in related_data:
124124
subset_data = related_data[filter_name]
125125
subset_class = f.filterset.get_subset(subset_data)
126-
filterset = subset_class(data=subset_data)
126+
filterset = subset_class(data=subset_data, request=self.request)
127127

128128
# modify filter names to account for relationship
129129
for related_name, related_f in six.iteritems(filterset.expand_filters()):

tests/test_backends.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
from rest_framework.test import APITestCase, APIRequestFactory
3+
from rest_framework_filters import FilterSet
34

45
from .testapp import models, views
56

@@ -58,3 +59,25 @@ class SimpleViewSet(views.FilterFieldsUserViewSet):
5859
<button type="submit" class="btn btn-primary">Submit</button>
5960
</form>
6061
""")
62+
63+
def test_request_obj_is_passed(self):
64+
"""
65+
Ensure that the request object is passed from the backend to the filterset.
66+
See: https://github.com/philipn/django-rest-framework-filters/issues/149
67+
"""
68+
class RequestCheck(FilterSet):
69+
def __init__(self, *args, **kwargs):
70+
super(RequestCheck, self).__init__(*args, **kwargs)
71+
assert self.request is not None
72+
73+
class Meta:
74+
model = models.User
75+
fields = ['username']
76+
77+
class ViewSet(views.FilterFieldsUserViewSet):
78+
filter_class = RequestCheck
79+
80+
view = ViewSet(action_map={})
81+
backend = view.filter_backends[0]
82+
request = view.initialize_request(factory.get('/'))
83+
backend().filter_queryset(request, view.get_queryset(), view)

tests/test_filtering.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,28 @@ class Meta:
349349
msg = str(excinfo.exception)
350350
self.assertEqual("Expected `.get_queryset()` to return a `QuerySet`, but got `None`.", msg)
351351

352+
def test_relatedfilter_request_is_passed(self):
353+
class RequestCheck(FilterSet):
354+
def __init__(self, *args, **kwargs):
355+
super(RequestCheck, self).__init__(*args, **kwargs)
356+
assert self.request is not None
357+
358+
class Meta:
359+
model = User
360+
fields = ['username']
361+
362+
class NoteFilter(FilterSet):
363+
author = filters.RelatedFilter(RequestCheck, name='author')
364+
365+
class Meta:
366+
model = Note
367+
fields = []
368+
369+
GET = {'author__username': 'user2'}
370+
371+
# should pass
372+
NoteFilter(GET, queryset=Note.objects.all(), request=object()).qs
373+
352374

353375
class MiscTests(TestCase):
354376
def test_multiwidget_incompatibility(self):

0 commit comments

Comments
 (0)