Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit c5b2a0c

Browse files
authored
Optimize dseries.rolling.mean() (#611)
* Optimize dseries.rolling.mean()
1 parent d917971 commit c5b2a0c

File tree

3 files changed

+49
-13
lines changed

3 files changed

+49
-13
lines changed

sdc/datatypes/hpat_pandas_series_rolling_functions.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,6 @@ def apply_minp(arr, ddof, minp):
293293
gen_hpat_pandas_series_rolling_impl(arr_kurt))
294294
hpat_pandas_rolling_series_max_impl = register_jitable(
295295
gen_hpat_pandas_series_rolling_impl(arr_max))
296-
hpat_pandas_rolling_series_mean_impl = register_jitable(
297-
gen_hpat_pandas_series_rolling_impl(arr_mean))
298296
hpat_pandas_rolling_series_median_impl = register_jitable(
299297
gen_hpat_pandas_series_rolling_impl(arr_median))
300298
hpat_pandas_rolling_series_min_impl = register_jitable(
@@ -336,7 +334,17 @@ def result_or_nan(nfinite, minp, result):
336334
return result
337335

338336

339-
def gen_sdc_pandas_series_rolling_impl(pop, put, init_result=numpy.nan):
337+
@sdc_register_jitable
def mean_result_or_nan(nfinite, minp, result):
    """Turn an accumulated window result into a mean, honouring min periods.

    Returns ``result / nfinite`` when ``nfinite`` is non-zero and not less
    than ``minp``; otherwise returns NaN.
    """
    has_enough_values = nfinite != 0 and nfinite >= minp
    if has_enough_values:
        return result / nfinite

    return numpy.nan
344+
345+
346+
def gen_sdc_pandas_series_rolling_impl(pop, put, get_result=result_or_nan,
347+
init_result=numpy.nan):
340348
"""Generate series rolling methods implementations based on pop/put funcs"""
341349
def impl(self):
342350
win = self._window
@@ -366,22 +374,24 @@ def impl(self):
366374
for idx in range(interlude_start, interlude_stop):
367375
value = input_arr[idx]
368376
nfinite, result = put(value, nfinite, result)
369-
output_arr[idx] = result_or_nan(nfinite, minp, result)
377+
output_arr[idx] = get_result(nfinite, minp, result)
370378

371379
for idx in range(interlude_stop, chunk.stop):
372380
put_value = input_arr[idx]
373381
pop_value = input_arr[idx - win]
374382
nfinite, result = put(put_value, nfinite, result)
375383
nfinite, result = pop(pop_value, nfinite, result)
376-
output_arr[idx] = result_or_nan(nfinite, minp, result)
384+
output_arr[idx] = get_result(nfinite, minp, result)
377385

378386
return pandas.Series(output_arr, input_series._index,
379387
name=input_series._name)
380388
return impl
381389

382390

383-
sdc_pandas_series_rolling_sum_impl = register_jitable(
384-
gen_sdc_pandas_series_rolling_impl(pop_sum, put_sum, init_result=0.))
391+
# Rolling mean and rolling sum are generated from the same pop_sum/put_sum
# pair with a zero initial result; they differ only in the finalizer that
# turns (nfinite, minp, result) into the output value.
sdc_pandas_series_rolling_mean_impl = gen_sdc_pandas_series_rolling_impl(
    pop_sum, put_sum, get_result=mean_result_or_nan, init_result=0.0)
sdc_pandas_series_rolling_sum_impl = gen_sdc_pandas_series_rolling_impl(
    pop_sum, put_sum, get_result=result_or_nan, init_result=0.0)
385395

386396

387397
@sdc_rolling_overload(SeriesRollingType, 'apply')
@@ -552,7 +562,30 @@ def _impl(self, other=None, pairwise=None, ddof=1):
552562
bias_adj = count / (count - ddof)
553563

554564
def mean(series):
    # The rolling mean is computed window by window here rather than via
    # series.rolling(win, min_periods=minp).mean(), because the new and old
    # implementations round floats differently.
    # TODO: fix this during optimizing of covariance
    values = series._data
    size = len(values)
    means = numpy.empty(size, dtype=float64)

    def finite_mean_or_nan(window_values, min_count):
        # Only finite entries count towards min periods and the mean.
        finite_values = window_values[numpy.isfinite(window_values)]
        if len(finite_values) >= min_count:
            return arr_mean(finite_values)

        return numpy.nan

    # Positions before `first_full` use a window that starts at index 0.
    first_full = min(win, size)
    for pos in prange(first_full):
        means[pos] = finite_mean_or_nan(values[:pos + 1], minp)

    # Remaining positions use a window of `win` elements ending at `pos`.
    for pos in prange(first_full, size):
        means[pos] = finite_mean_or_nan(values[pos + 1 - win:pos + 1], minp)

    return pandas.Series(means, series._index, name=series._name)
556589

557590
return (mean(main_aligned * other_aligned) - mean(main_aligned) * mean(other_aligned)) * bias_adj
558591

@@ -593,13 +626,13 @@ def hpat_pandas_series_rolling_max(self):
593626
return hpat_pandas_rolling_series_max_impl
594627

595628

596-
@sdc_rolling_overload(SeriesRollingType, 'mean')
629+
@sdc_overload_method(SeriesRollingType, 'mean')
def hpat_pandas_series_rolling_mean(self):

    # Check that the receiver is a SeriesRollingType before handing back
    # the rolling mean implementation.
    TypeChecker('Method rolling.mean().').check(self, SeriesRollingType)

    return sdc_pandas_series_rolling_mean_impl
603636

604637

605638
@sdc_rolling_overload(SeriesRollingType, 'median')

sdc/tests/test_rolling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -715,8 +715,8 @@ def test_impl(obj, window, min_periods):
715715
hpat_func = self.jit(test_impl)
716716
assert_equal = self._get_assert_equal(obj)
717717

718-
for window in range(0, len(obj) + 3, 2):
719-
for min_periods in range(0, window + 1, 2):
718+
for window in range(len(obj) + 2):
719+
for min_periods in range(window):
720720
with self.subTest(obj=obj, window=window,
721721
min_periods=min_periods):
722722
jit_result = hpat_func(obj, window, min_periods)

sdc/tests/tests_perf/test_perf_series_rolling.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class TestSeriesRollingMethods(TestBase):
8585
def setUpClass(cls):
8686
super().setUpClass()
8787
cls.map_ncalls_dlength = {
88+
'mean': (100, [8 * 10 ** 5]),
8889
'sum': (100, [8 * 10 ** 5]),
8990
}
9091

@@ -124,6 +125,9 @@ def _test_series_rolling_method(self, name, rolling_params=None,
124125
data_num += len(extra_usecase_params.split(', '))
125126
self._test_case(usecase, name, total_data_length, data_num=data_num)
126127

128+
def test_series_rolling_mean(self):
    """Run the shared series rolling perf test for the 'mean' method."""
    self._test_series_rolling_method(name='mean')
130+
127131
def test_series_rolling_sum(self):
    """Run the shared series rolling perf test for the 'sum' method."""
    self._test_series_rolling_method(name='sum')
129133

@@ -135,7 +139,6 @@ def test_series_rolling_sum(self):
135139
TC(name='cov', size=[10 ** 7]),
136140
TC(name='kurt', size=[10 ** 7]),
137141
TC(name='max', size=[10 ** 7]),
138-
TC(name='mean', size=[10 ** 7]),
139142
TC(name='median', size=[10 ** 7]),
140143
TC(name='min', size=[10 ** 7]),
141144
TC(name='quantile', size=[10 ** 7], params='0.2'),

0 commit comments

Comments
 (0)