Skip to content

Commit e8efe11

Browse files
author
Ryan P Kilby
committed
Move caching logic into 'get_subset'
1 parent 2a0d0c6 commit e8efe11

File tree

2 files changed

+71
-46
lines changed

2 files changed

+71
-46
lines changed

rest_framework_filters/filterset.py

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class FilterSet(six.with_metaclass(FilterSetMetaclass, filterset.FilterSet)):
5858
'filter_class': filters.IsoDateTimeFilter,
5959
},
6060
}
61+
_subset_cache = {}
6162

6263
def __init__(self, *args, **kwargs):
6364
self._related_filterset_cache = kwargs.pop('cache', {})
@@ -116,21 +117,7 @@ def get_filters(self):
116117
if not isinstance(related_filter, filters.RelatedFilter):
117118
continue
118119

119-
# get known filter names
120-
filterset_class = related_filter.filterset
121-
filter_names = [filterset_class.get_filter_name(param) for param in rel_data.keys()]
122-
123-
# filter out empty values - indicates an unknown field (author__foobar__isnull)
124-
filter_names = [f for f in filter_names if f is not None]
125-
126-
# attempt to retrieve related filterset subset from the cache
127-
key = self.cache_key(filterset_class, filter_names)
128-
subset_class = self.cache_get(key)
129-
130-
# otherwise build and insert it into the cache
131-
if subset_class is None:
132-
subset_class = related_filter.filterset.get_subset(filter_names)
133-
self.cache_set(key, subset_class)
120+
subset_class = related_filter.filterset.get_subset(rel_data)
134121

135122
# initialize and copy filters
136123
filterset = subset_class(data=rel_data)
@@ -163,33 +150,56 @@ def get_filter_name(cls, param):
163150
return related_param
164151

165152
@classmethod
166-
def get_subset(cls, filter_names):
153+
def get_subset(cls, params):
167154
"""
168-
Returns a FilterSet subclass that contains the subset of filters
169-
specified in `filter_names`. This is useful for creating FilterSets
170-
used across relationships, as it minimizes the deepcopy overhead
171-
incurred when instantiating the FilterSet.
155+
Returns a FilterSubset class that contains the subset of filters
156+
specified in the requested `params`. This is useful for creating
157+
FilterSets that traverse relationships, as it helps to minimize
158+
the deepcopy overhead incurred when instantiating the FilterSet.
172159
"""
173-
class FilterSetSubset(cls):
160+
# Determine names of filters from query params and remove empty values.
161+
# param names that traverse relations are translated to just the local
162+
# filter names. eg, `author__username` => `author`. Empty values are
163+
# removed, as they indicate an unknown field eg, author__foobar__isnull
164+
filter_names = [cls.get_filter_name(param) for param in params]
165+
filter_names = [f for f in filter_names if f is not None]
166+
167+
# attempt to retrieve related filterset subset from the cache
168+
key = cls.cache_key(filter_names)
169+
subset_class = cls.cache_get(key)
170+
171+
# if no cached subset, then derive base_filters and create new subset
172+
if subset_class is not None:
173+
return subset_class
174+
175+
class FilterSubsetMetaclass(FilterSetMetaclass):
176+
def __new__(cls, name, bases, attrs):
177+
new_class = super(FilterSubsetMetaclass, cls).__new__(cls, name, bases, attrs)
178+
new_class.base_filters = OrderedDict([
179+
(name, f)
180+
for name, f in six.iteritems(new_class.base_filters)
181+
if name in filter_names
182+
])
183+
return new_class
184+
185+
class FilterSubset(six.with_metaclass(FilterSubsetMetaclass, cls)):
174186
pass
175187

176-
FilterSetSubset.__name__ = str('%sSubset' % (cls.__name__))
177-
FilterSetSubset.base_filters = OrderedDict([
178-
(name, f)
179-
for name, f in six.iteritems(cls.base_filters)
180-
if name in filter_names
181-
])
182-
183-
return FilterSetSubset
188+
FilterSubset.__name__ = str('%sSubset' % (cls.__name__, ))
189+
cls.cache_set(key, FilterSubset)
190+
return FilterSubset
184191

185-
def cache_key(self, filterset, filter_names):
186-
return '%sSubset-%s' % (filterset.__name__, '-'.join(sorted(filter_names)), )
192+
@classmethod
193+
def cache_key(cls, filter_names):
194+
return '%sSubset-%s' % (cls.__name__, '-'.join(sorted(filter_names)), )
187195

188-
def cache_get(self, key):
189-
return self._related_filterset_cache.get(key)
196+
@classmethod
197+
def cache_get(cls, key):
198+
return cls._subset_cache.get(key)
190199

191-
def cache_set(self, key, value):
192-
self._related_filterset_cache[key] = value
200+
@classmethod
201+
def cache_set(cls, key, value):
202+
cls._subset_cache[key] = value
193203

194204
@property
195205
def qs(self):

tests/test_filterset.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -304,17 +304,6 @@ def test_m2m_relation(self):
304304
titles = set([p.title for p in f])
305305
self.assertEqual(titles, set(["First post", "Second post"]))
306306

307-
def test_get_subset(self):
308-
related_filter = NoteFilterWithRelated.base_filters['author']
309-
filterset_class = related_filter.filterset.get_subset(['email'])
310-
311-
# ensure that the class name is useful when debugging
312-
self.assertEqual(filterset_class.__name__, 'UserFilterSubset')
313-
314-
# ensure that the FilterSet subset only contains the requested fields
315-
self.assertIn('email', filterset_class.base_filters)
316-
self.assertEqual(len(filterset_class.base_filters), 1)
317-
318307
def test_nonexistent_related_field(self):
319308
"""
320309
Invalid filter keys (including those on related filters) are invalid
@@ -335,6 +324,32 @@ def test_nonexistent_related_field(self):
335324
self.assertEqual(len(list(f)), 4)
336325

337326

327+
class FilterSubsetTests(TestCase):
328+
329+
def test_get_subset(self):
330+
filterset_class = UserFilter.get_subset(['email'])
331+
332+
# ensure that the class name is useful when debugging
333+
self.assertEqual(filterset_class.__name__, 'UserFilterSubset')
334+
335+
# ensure that the FilterSet subset only contains the requested fields
336+
self.assertIn('email', filterset_class.base_filters)
337+
self.assertEqual(len(filterset_class.base_filters), 1)
338+
339+
def test_related_subset(self):
340+
# related filters should only return the local RelatedFilter
341+
filterset_class = NoteFilterWithRelated.get_subset(['title', 'author__email'])
342+
343+
self.assertIn('title', filterset_class.base_filters)
344+
self.assertIn('author', filterset_class.base_filters)
345+
self.assertEqual(len(filterset_class.base_filters), 2)
346+
347+
def test_non_filter_subset(self):
348+
# non-filter params should be ignored
349+
filterset_class = NoteFilterWithRelated.get_subset(['foobar'])
350+
self.assertEqual(len(filterset_class.base_filters), 0)
351+
352+
338353
class MethodFilterTests(TestCase):
339354

340355
@classmethod

0 commit comments

Comments
 (0)