
Commit 9a765ef

jemmajeffree, dcherian, and pre-commit-ci[bot] authored
More stable algorithm for variance, standard deviation (#456)
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <deepak@cherian.net>
1 parent f6c87fc commit 9a765ef
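
For context, a minimal NumPy sketch (illustrative only, not flox code) of the precision problem the commit title refers to: the old intermediates were (sum of squares, sum, count) with the "textbook" finalize (sumsq - sum**2 / count) / count, which cancels catastrophically when the mean is much larger than the spread, whereas the new chunk function accumulates squared deviations about each chunk's mean.

import numpy as np

x = np.full(3, 1e8) + np.array([1.0, 2.0, 3.0])  # large offset, small spread

sumsq, s, n = (x**2).sum(), x.sum(), x.size
textbook = (sumsq - s**2 / n) / n           # old-style finalize; loses precision here
deviations = ((x - s / n) ** 2).sum() / n   # squared deviations about the mean, as in var_chunk

print(textbook, deviations, np.var(x))      # the first value is inaccurate; the latter two agree at the true 2/3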

File tree: 10 files changed (+483, -51 lines)


.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ jobs:
           cache-dependency-glob: "pyproject.toml"
       - name: Install xarray and dependencies
         run: |
-          uv add --dev .[complete] pint>=0.22
+          uv add --dev ".[complete]" "pint>=0.22"
       - name: Install upstream flox
         run: |
           uv add git+https://github.com/dcherian/flox.git@${{ github.ref }}
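
The quoting is presumably needed because, unquoted, the shell treats `>` in `pint>=0.22` as an output redirection and may glob-expand `.[complete]`, so `uv add` would not receive the intended requirement specifiers.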

flox/aggregations.py

Lines changed: 153 additions & 25 deletions
@@ -15,6 +15,8 @@
 from . import aggregate_flox, aggregate_npg, xrutils
 from . import xrdtypes as dtypes
 from .lib import dask_array_type, sparse_array_type
+from .multiarray import MultiArray
+from .xrutils import notnull

 if TYPE_CHECKING:
     FuncTuple = tuple[Callable | str, ...]
@@ -161,8 +163,8 @@ def __init__(
         self,
         name: str,
         *,
-        numpy: str | None = None,
-        chunk: str | FuncTuple | None,
+        numpy: partial | str | None = None,
+        chunk: partial | str | FuncTuple | None,
         combine: str | FuncTuple | None,
         preprocess: Callable | None = None,
         finalize: Callable | None = None,
@@ -343,57 +345,183 @@ def _mean_finalize(sum_, count):
 )


-# TODO: fix this for complex numbers
-def _var_finalize(sumsq, sum_, count, ddof=0):
+def var_chunk(
+    group_idx, array, *, skipna: bool, engine: str, axis=-1, size=None, fill_value=None, dtype=None
+):
+    # Calculate length and sum - important for the adjustment terms to sum squared deviations
+    array_lens = generic_aggregate(
+        group_idx,
+        array,
+        func="nanlen",
+        engine=engine,
+        axis=axis,
+        size=size,
+        fill_value=0,  # Unpack fill value bc it's currently defined for multiarray
+        dtype=dtype,
+    )
+
+    array_sums = generic_aggregate(
+        group_idx,
+        array,
+        func="nansum" if skipna else "sum",
+        engine=engine,
+        axis=axis,
+        size=size,
+        fill_value=0,  # Unpack fill value bc it's currently defined for multiarray
+        dtype=dtype,
+    )
+
+    # Calculate sum squared deviations - the main part of variance sum
     with np.errstate(invalid="ignore", divide="ignore"):
-        result = (sumsq - (sum_**2 / count)) / (count - ddof)
-    result[count <= ddof] = np.nan
-    return result
+        array_means = array_sums / array_lens
+
+    sum_squared_deviations = generic_aggregate(
+        group_idx,
+        (array - array_means[..., group_idx]) ** 2,
+        func="nansum" if skipna else "sum",
+        engine=engine,
+        axis=axis,
+        size=size,
+        fill_value=0,  # Unpack fill value bc it's currently defined for multiarray
+        dtype=dtype,
+    )
+
+    return MultiArray((sum_squared_deviations, array_sums, array_lens))
+
+
+def _var_combine(array, axis, keepdims=True):
+    def clip_last(array, ax, n=1):
+        """Return array except the last element along axis
+        Purely included to tidy up the adj_terms line
+        """
+        assert n > 0, "Clipping nothing off the end isn't implemented"
+        not_last = [slice(None, None) for i in range(array.ndim)]
+        not_last[ax] = slice(None, -n)
+        return array[*not_last]
+
+    def clip_first(array, ax, n=1):
+        """Return array except the first element along axis
+        Purely included to tidy up the adj_terms line
+        """
+        not_first = [slice(None, None) for i in range(array.ndim)]
+        not_first[ax] = slice(n, None)
+        return array[*not_first]
+
+    for ax in axis:
+        if array.shape[ax] == 1:
+            continue
+
+        sum_deviations, sum_X, sum_len = array.arrays
+
+        # Calculate parts needed for cascading combination
+        cumsum_X = np.cumsum(sum_X, axis=ax)
+        cumsum_len = np.cumsum(sum_len, axis=ax)
+
+        # There will be instances in which one or both chunks being merged are empty
+        # In which case, the adjustment term should be zero, but will throw a divide-by-zero error
+        # We're going to add a constant to the bottom of the adjustment term equation on those instances
+        # and count on the zeros on the top making our adjustment term still zero
+        zero_denominator = (clip_last(cumsum_len, ax) == 0) | (clip_first(sum_len, ax) == 0)
+
+        # Adjustment terms to tweak the sum of squared deviations because not every chunk has the same mean
+        with np.errstate(invalid="ignore", divide="ignore"):
+            adj_terms = (
+                clip_last(cumsum_len, ax) * clip_first(sum_X, ax)
+                - clip_first(sum_len, ax) * clip_last(cumsum_X, ax)
+            ) ** 2 / (
+                clip_last(cumsum_len, ax)
+                * clip_first(sum_len, ax)
+                * (clip_last(cumsum_len, ax) + clip_first(sum_len, ax))
+                + zero_denominator.astype(int)
+            )
+
+        check = adj_terms * zero_denominator
+        assert np.all(check[notnull(check)] == 0), (
+            "Instances where we add something to the denominator must come out to zero"
+        )
+
+        array = MultiArray(
+            (
+                np.sum(sum_deviations, axis=ax, keepdims=keepdims)
+                + np.sum(adj_terms, axis=ax, keepdims=keepdims),  # sum of squared deviations
+                np.sum(sum_X, axis=ax, keepdims=keepdims),  # sum of array items
+                np.sum(sum_len, axis=ax, keepdims=keepdims),  # sum of array lengths
+            )
+        )
+    return array
+
+
+def is_var_chunk_reduction(agg: Callable) -> bool:
+    if isinstance(agg, partial):
+        agg = agg.func
+    return agg is blockwise_or_numpy_var or agg is var_chunk
+
+
+def _var_finalize(multiarray, ddof=0):
+    den = multiarray.arrays[2]
+    den -= ddof
+    # preserve nans for groups with 0 obs; so these values are -ddof
+    with np.errstate(invalid="ignore", divide="ignore"):
+        ret = multiarray.arrays[0]
+        ret /= den
+    ret[den < 0] = np.nan
+    return ret


-def _std_finalize(sumsq, sum_, count, ddof=0):
-    return np.sqrt(_var_finalize(sumsq, sum_, count, ddof))
+def _std_finalize(multiarray, ddof=0):
+    return np.sqrt(_var_finalize(multiarray, ddof))
+
+
+def blockwise_or_numpy_var(*args, skipna: bool, ddof=0, std=False, **kwargs):
+    res = _var_finalize(var_chunk(*args, skipna=skipna, **kwargs), ddof)
+    return np.sqrt(res) if std else res


 # var, std always promote to float, so we set nan
 var = Aggregation(
     "var",
-    chunk=("sum_of_squares", "sum", "nanlen"),
-    combine=("sum", "sum", "sum"),
+    chunk=partial(var_chunk, skipna=False),
+    numpy=partial(blockwise_or_numpy_var, skipna=False),
+    combine=(_var_combine,),
     finalize=_var_finalize,
-    fill_value=0,
+    fill_value=((0, 0, 0),),
     final_fill_value=np.nan,
-    dtypes=(None, None, np.intp),
+    dtypes=(None,),
     final_dtype=np.floating,
 )
+
 nanvar = Aggregation(
     "nanvar",
-    chunk=("nansum_of_squares", "nansum", "nanlen"),
-    combine=("sum", "sum", "sum"),
+    chunk=partial(var_chunk, skipna=True),
+    numpy=partial(blockwise_or_numpy_var, skipna=True),
+    combine=(_var_combine,),
     finalize=_var_finalize,
-    fill_value=0,
+    fill_value=((0, 0, 0),),
     final_fill_value=np.nan,
-    dtypes=(None, None, np.intp),
+    dtypes=(None,),
     final_dtype=np.floating,
 )
+
 std = Aggregation(
     "std",
-    chunk=("sum_of_squares", "sum", "nanlen"),
-    combine=("sum", "sum", "sum"),
+    chunk=partial(var_chunk, skipna=False),
+    numpy=partial(blockwise_or_numpy_var, skipna=False, std=True),
+    combine=(_var_combine,),
     finalize=_std_finalize,
-    fill_value=0,
+    fill_value=((0, 0, 0),),
     final_fill_value=np.nan,
-    dtypes=(None, None, np.intp),
+    dtypes=(None,),
     final_dtype=np.floating,
 )
 nanstd = Aggregation(
     "nanstd",
-    chunk=("nansum_of_squares", "nansum", "nanlen"),
-    combine=("sum", "sum", "sum"),
+    chunk=partial(var_chunk, skipna=True),
+    numpy=partial(blockwise_or_numpy_var, skipna=True, std=True),
+    combine=(_var_combine,),
     finalize=_std_finalize,
-    fill_value=0,
+    fill_value=((0, 0, 0),),
     final_fill_value=np.nan,
-    dtypes=(None, None, np.intp),
+    dtypes=(None,),
     final_dtype=np.floating,
 )

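For reference, a small self-contained NumPy sketch (hypothetical names, not flox's internal API) of the pairwise update `_var_combine` performs when merging two chunks' (sum of squared deviations, sum, count) intermediates; the adjustment term is the same one computed as `adj_terms` above.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=1e6, scale=1.0, size=1000)
chunk_a, chunk_b = x[:400], x[400:]

def partials(chunk):
    n = chunk.size
    s = chunk.sum()
    ssd = ((chunk - s / n) ** 2).sum()  # squared deviations about the chunk mean
    return ssd, s, n

ssd_a, s_a, n_a = partials(chunk_a)
ssd_b, s_b, n_b = partials(chunk_b)

# Correct for the two chunks having different means, then pool the deviations.
adj = (n_a * s_b - n_b * s_a) ** 2 / (n_a * n_b * (n_a + n_b))
pooled_ssd = ssd_a + ssd_b + adj

print(np.allclose(pooled_ssd / x.size, np.var(x)))  # True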

flox/core.py

Lines changed: 8 additions & 1 deletion
@@ -44,6 +44,7 @@
     _atleast_1d,
     _initialize_aggregation,
     generic_aggregate,
+    is_var_chunk_reduction,
     quantile_new_dims_func,
 )
 from .cache import memoize
@@ -1289,7 +1290,8 @@ def chunk_reduce(
     # optimize that out.
     previous_reduction: T_Func = ""
     for reduction, fv, kw, dt in zip(funcs, fill_values, kwargss, dtypes):
-        if empty:
+        # UGLY! but this is because the `var` breaks our design assumptions
+        if empty and not is_var_chunk_reduction(reduction):
             result = np.full(shape=final_array_shape, fill_value=fv, like=array)
         elif is_nanlen(reduction) and is_nanlen(previous_reduction):
             result = results["intermediates"][-1]
@@ -1298,6 +1300,10 @@ def chunk_reduce(
             kw_func = dict(size=size, dtype=dt, fill_value=fv)
             kw_func.update(kw)

+            # UGLY! but this is because the `var` breaks our design assumptions
+            if is_var_chunk_reduction(reduction):
+                kw_func.update(engine=engine)
+
             if callable(reduction):
                 # passing a custom reduction for npg to apply per-group is really slow!
                 # So this `reduction` has to do the groupby-aggregation
@@ -2785,6 +2791,7 @@ def groupby_reduce(
         array = array.view(np.int64)
     elif is_cftime:
         offset = array.min()
+        assert offset is not None
         array = datetime_to_numeric(array, offset, datetime_unit="us")

     if nax == 1 and by_.ndim > 1 and expected_ is None:
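
Both guards accommodate the new variance path: `var_chunk` runs `generic_aggregate` internally, so the `engine` keyword has to be forwarded to it, and even an empty chunk has to go through the chunk function so that the intermediate comes back as a `MultiArray` rather than a plain `np.full` array.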

flox/multiarray.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+from collections.abc import Callable
+from typing import Self
+
+import numpy as np
+
+MULTIARRAY_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
+
+
+class MultiArray:
+    arrays: tuple[np.ndarray, ...]
+
+    def __init__(self, arrays):
+        self.arrays = arrays
+        assert all(arrays[0].shape == a.shape for a in arrays), "Expect all arrays to have the same shape"
+
+    def astype(self, dt, **kwargs) -> Self:
+        return type(self)(tuple(array.astype(dt, **kwargs) for array in self.arrays))
+
+    def reshape(self, shape, **kwargs) -> Self:
+        return type(self)(tuple(array.reshape(shape, **kwargs) for array in self.arrays))
+
+    def squeeze(self, axis=None) -> Self:
+        return type(self)(tuple(array.squeeze(axis) for array in self.arrays))
+
+    def __setitem__(self, key, value) -> None:
+        assert len(value) == len(self.arrays)
+        for array, val in zip(self.arrays, value):
+            array[key] = val
+
+    def __array_function__(self, func, types, args, kwargs):
+        if func not in MULTIARRAY_HANDLED_FUNCTIONS:
+            return NotImplemented
+        # Note: this allows subclasses that don't override
+        # __array_function__ to handle MyArray objects
+        # if not all(issubclass(t, MyArray) for t in types):  # I can't see this being relevant at all for this code, but maybe it's safer to leave it in?
+        #     return NotImplemented
+        return MULTIARRAY_HANDLED_FUNCTIONS[func](*args, **kwargs)
+
+    # Shape is needed, seems likely that the other two might be
+    # Making some strong assumptions here that all the arrays are the same shape, and I don't really like this
+    @property
+    def dtype(self) -> np.dtype:
+        return self.arrays[0].dtype
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        return self.arrays[0].shape
+
+    @property
+    def ndim(self) -> int:
+        return self.arrays[0].ndim
+
+    def __getitem__(self, key) -> Self:
+        return type(self)([array[key] for array in self.arrays])
+
+
+def implements(numpy_function):
+    """Register an __array_function__ implementation for MyArray objects."""
+
+    def decorator(func):
+        MULTIARRAY_HANDLED_FUNCTIONS[numpy_function] = func
+        return func
+
+    return decorator
+
+
+@implements(np.expand_dims)
+def expand_dims(multiarray, axis) -> MultiArray:
+    return MultiArray(tuple(np.expand_dims(a, axis) for a in multiarray.arrays))
+
+
+@implements(np.concatenate)
+def concatenate(multiarrays, axis) -> MultiArray:
+    n_arrays = len(multiarrays[0].arrays)
+    for ma in multiarrays[1:]:
+        assert len(ma.arrays) == n_arrays
+    return MultiArray(
+        tuple(np.concatenate(tuple(ma.arrays[i] for ma in multiarrays), axis) for i in range(n_arrays))
+    )
+
+
+@implements(np.transpose)
+def transpose(multiarray, axes) -> MultiArray:
+    return MultiArray(tuple(np.transpose(a, axes) for a in multiarray.arrays))
+
+
+@implements(np.squeeze)
+def squeeze(multiarray, axis) -> MultiArray:
+    return MultiArray(tuple(np.squeeze(a, axis) for a in multiarray.arrays))
+
+
+@implements(np.full)
+def full(shape, fill_values, *args, **kwargs) -> MultiArray:
+    """All arguments except fill_value are shared by each array in the MultiArray.
+    Iterate over fill_values to create arrays
+    """
+    return MultiArray(tuple(np.full(shape, fv, *args, **kwargs) for fv in fill_values))
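
A small usage sketch of the wrapper (assuming this commit is installed so that `flox.multiarray` is importable): NumPy functions registered via `implements` are dispatched through `__array_function__` and applied to each wrapped array independently.

import numpy as np
from flox.multiarray import MultiArray

a = MultiArray((np.zeros((2, 3)), np.ones((2, 3)), np.full((2, 3), 2.0)))
b = MultiArray((np.zeros((2, 1)), np.ones((2, 1)), np.full((2, 1), 2.0)))

# np.concatenate is routed to the registered implementation above,
# which concatenates the three wrapped arrays one by one.
combined = np.concatenate((a, b), axis=1)
print(combined.shape)        # (2, 4), reported from the first wrapped array
print(len(combined.arrays))  # 3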

flox/xrutils.py

Lines changed: 6 additions & 0 deletions
@@ -147,6 +147,9 @@ def is_scalar(value: Any, include_0d: bool = True) -> bool:


 def notnull(data):
+    if isinstance(data, tuple) and len(data) == 3 and data == (0, 0, 0):
+        # boo: another special case for Var
+        return True
     if not is_duck_array(data):
         data = np.asarray(data)

@@ -164,6 +167,9 @@ def notnull(data):


 def isnull(data: Any):
+    if isinstance(data, tuple) and len(data) == 3 and data == (0, 0, 0):
+        # boo: another special case for Var
+        return False
     if data is None:
         return False
     if not is_duck_array(data):
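
The `(0, 0, 0)` special case matches the fill value declared for the new var/std aggregations (one zero per wrapped array); it is not a duck array, so `notnull`/`isnull` answer for it directly rather than attempting `np.asarray` on a tuple.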
