|
15 | 15 | from . import aggregate_flox, aggregate_npg, xrutils
|
16 | 16 | from . import xrdtypes as dtypes
|
17 | 17 | from .lib import dask_array_type, sparse_array_type
|
| 18 | +from .multiarray import MultiArray |
| 19 | +from .xrutils import notnull |
18 | 20 |
|
19 | 21 | if TYPE_CHECKING:
|
20 | 22 | FuncTuple = tuple[Callable | str, ...]
|
@@ -161,8 +163,8 @@ def __init__(
|
161 | 163 | self,
|
162 | 164 | name: str,
|
163 | 165 | *,
|
164 |
| - numpy: str | None = None, |
165 |
| - chunk: str | FuncTuple | None, |
| 166 | + numpy: partial | str | None = None, |
| 167 | + chunk: partial | str | FuncTuple | None, |
166 | 168 | combine: str | FuncTuple | None,
|
167 | 169 | preprocess: Callable | None = None,
|
168 | 170 | finalize: Callable | None = None,
|
@@ -343,57 +345,183 @@ def _mean_finalize(sum_, count):
|
343 | 345 | )
|
344 | 346 |
|
345 | 347 |
|
346 |
| -# TODO: fix this for complex numbers |
347 |
| -def _var_finalize(sumsq, sum_, count, ddof=0): |
def var_chunk(
    group_idx, array, *, skipna: bool, engine: str, axis=-1, size=None, fill_value=None, dtype=None
):
    """Blockwise piece of the variance computation.

    Produces, per group, the triple needed to merge partial variances later:
    the sum of squared deviations from this block's group means, the group
    sums, and the group counts — packed into a single ``MultiArray``.

    Note: the incoming ``fill_value`` is not forwarded; it is defined for the
    MultiArray as a whole, so each component aggregation uses 0 instead.
    """
    sum_func = "nansum" if skipna else "sum"
    # Keyword arguments shared by every component aggregation below.
    agg_kwargs = dict(engine=engine, axis=axis, size=size, fill_value=0, dtype=dtype)

    # Group counts and sums — required later for the adjustment terms that
    # reconcile chunks whose group means differ.
    counts = generic_aggregate(group_idx, array, func="nanlen", **agg_kwargs)
    sums = generic_aggregate(group_idx, array, func=sum_func, **agg_kwargs)

    # Sum of squared deviations from each group's own mean — the main part of
    # the variance numerator. Empty groups divide 0/0; silence those warnings.
    with np.errstate(invalid="ignore", divide="ignore"):
        means = sums / counts

        sum_squared_deviations = generic_aggregate(
            group_idx,
            (array - means[..., group_idx]) ** 2,
            func=sum_func,
            **agg_kwargs,
        )

    return MultiArray((sum_squared_deviations, sums, counts))
def _var_combine(array, axis, keepdims=True):
    """Combine per-chunk (sum_sq_dev, sum, count) MultiArrays along ``axis``.

    The component sums and counts simply add across chunks; the sums of
    squared deviations additionally need adjustment terms because each chunk
    computed deviations from its *own* mean. The cumsum-based update below
    merges chunks left-to-right along each axis (this matches the standard
    pairwise-combination formula for parallel variance computation).
    Returns a new MultiArray with the combined components.
    """

    def clip_last(array, ax, n=1):
        """Return array except the last element along axis
        Purely included to tidy up the adj_terms line
        """
        assert n > 0, "Clipping nothing off the end isn't implemented"
        not_last = [slice(None, None) for i in range(array.ndim)]
        not_last[ax] = slice(None, -n)
        return array[*not_last]

    def clip_first(array, ax, n=1):
        """Return array except the first element along axis
        Purely included to tidy up the adj_terms line
        """
        not_first = [slice(None, None) for i in range(array.ndim)]
        not_first[ax] = slice(n, None)
        return array[*not_first]

    for ax in axis:
        # A length-1 axis has nothing to merge.
        if array.shape[ax] == 1:
            continue

        sum_deviations, sum_X, sum_len = array.arrays

        # Calculate parts needed for cascading combination
        cumsum_X = np.cumsum(sum_X, axis=ax)
        cumsum_len = np.cumsum(sum_len, axis=ax)

        # There will be instances in which one or both chunks being merged are empty
        # In which case, the adjustment term should be zero, but will throw a divide-by-zero error
        # We're going to add a constant to the bottom of the adjustment term equation on those instances
        # and count on the zeros on the top making our adjustment term still zero
        zero_denominator = (clip_last(cumsum_len, ax) == 0) | (clip_first(sum_len, ax) == 0)

        # Adjustment terms to tweak the sum of squared deviations because not every chunk has the same mean
        with np.errstate(invalid="ignore", divide="ignore"):
            adj_terms = (
                clip_last(cumsum_len, ax) * clip_first(sum_X, ax)
                - clip_first(sum_len, ax) * clip_last(cumsum_X, ax)
            ) ** 2 / (
                clip_last(cumsum_len, ax)
                * clip_first(sum_len, ax)
                * (clip_last(cumsum_len, ax) + clip_first(sum_len, ax))
                + zero_denominator.astype(int)
            )

        # Sanity check: wherever we padded the denominator, the numerator's
        # zeros must have made the whole adjustment term vanish.
        check = adj_terms * zero_denominator
        assert np.all(check[notnull(check)] == 0), (
            "Instances where we add something to the denominator must come out to zero"
        )

        # Rebind ``array`` so the next axis iteration combines the already
        # reduced components.
        array = MultiArray(
            (
                np.sum(sum_deviations, axis=ax, keepdims=keepdims)
                + np.sum(adj_terms, axis=ax, keepdims=keepdims),  # sum of squared deviations
                np.sum(sum_X, axis=ax, keepdims=keepdims),  # sum of array items
                np.sum(sum_len, axis=ax, keepdims=keepdims),  # sum of array lengths
            )
        )
    return array
| 452 | + |
| 453 | + |
def is_var_chunk_reduction(agg: Callable) -> bool:
    """Return True if ``agg`` is the variance chunk reduction, unwrapping any partial."""
    func = agg.func if isinstance(agg, partial) else agg
    return func is var_chunk or func is blockwise_or_numpy_var
| 459 | + |
| 460 | +def _var_finalize(multiarray, ddof=0): |
| 461 | + den = multiarray.arrays[2] |
| 462 | + den -= ddof |
| 463 | + # preserve nans for groups with 0 obs; so these values are -ddof |
| 464 | + with np.errstate(invalid="ignore", divide="ignore"): |
| 465 | + ret = multiarray.arrays[0] |
| 466 | + ret /= den |
| 467 | + ret[den < 0] = np.nan |
| 468 | + return ret |
352 | 469 |
|
353 | 470 |
|
354 |
| -def _std_finalize(sumsq, sum_, count, ddof=0): |
355 |
| - return np.sqrt(_var_finalize(sumsq, sum_, count, ddof)) |
def _std_finalize(multiarray, ddof=0):
    """Standard deviation: square root of the finalized variance."""
    variance = _var_finalize(multiarray, ddof)
    return np.sqrt(variance)
| 474 | + |
def blockwise_or_numpy_var(*args, skipna: bool, ddof=0, std=False, **kwargs):
    """Single-pass var/std for the blockwise and pure-numpy paths.

    Runs the chunk reduction and finalize back-to-back; with ``std=True``
    returns the standard deviation instead of the variance.
    """
    variance = _var_finalize(var_chunk(*args, skipna=skipna, **kwargs), ddof)
    if std:
        return np.sqrt(variance)
    return variance
356 | 478 |
|
357 | 479 |
|
# var, std always promote to float, so we set nan

# Shared config for the variance family: every intermediate is a
# (sum_sq_dev, sum, count) MultiArray, hence one zero fill per component
# and a single dtype slot.
_VAR_FILL_VALUE = ((0, 0, 0),)
_VAR_DTYPES = (None,)

var = Aggregation(
    "var",
    chunk=partial(var_chunk, skipna=False),
    numpy=partial(blockwise_or_numpy_var, skipna=False),
    combine=(_var_combine,),
    finalize=_var_finalize,
    fill_value=_VAR_FILL_VALUE,
    final_fill_value=np.nan,
    dtypes=_VAR_DTYPES,
    final_dtype=np.floating,
)

nanvar = Aggregation(
    "nanvar",
    chunk=partial(var_chunk, skipna=True),
    numpy=partial(blockwise_or_numpy_var, skipna=True),
    combine=(_var_combine,),
    finalize=_var_finalize,
    fill_value=_VAR_FILL_VALUE,
    final_fill_value=np.nan,
    dtypes=_VAR_DTYPES,
    final_dtype=np.floating,
)

std = Aggregation(
    "std",
    chunk=partial(var_chunk, skipna=False),
    numpy=partial(blockwise_or_numpy_var, skipna=False, std=True),
    combine=(_var_combine,),
    finalize=_std_finalize,
    fill_value=_VAR_FILL_VALUE,
    final_fill_value=np.nan,
    dtypes=_VAR_DTYPES,
    final_dtype=np.floating,
)

nanstd = Aggregation(
    "nanstd",
    chunk=partial(var_chunk, skipna=True),
    numpy=partial(blockwise_or_numpy_var, skipna=True, std=True),
    combine=(_var_combine,),
    finalize=_std_finalize,
    fill_value=_VAR_FILL_VALUE,
    final_fill_value=np.nan,
    dtypes=_VAR_DTYPES,
    final_dtype=np.floating,
)
|
399 | 527 |
|
|
0 commit comments