diff --git a/src/pscpy/psc.py b/src/pscpy/psc.py index 54c2bb6..bf1e606 100644 --- a/src/pscpy/psc.py +++ b/src/pscpy/psc.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Generator, Hashable, Iterable from typing import Any import numpy as np @@ -50,52 +50,16 @@ def __repr__(self) -> str: return f"Psc(gdims={self.gdims}, length={self.length}, corner={self.corner})" -def get_field_to_component(species_names: Iterable[str]) -> dict[str, dict[str, int]]: - field_to_component: dict[str, dict[str, int]] = {} - field_to_component["jeh"] = { - "jx_ec": 0, - "jy_ec": 1, - "jz_ec": 2, - "ex_ec": 3, - "ey_ec": 4, - "ez_ec": 5, - "hx_fc": 6, - "hy_fc": 7, - "hz_fc": 8, - } - field_to_component["dive"] = {"dive": 0} - field_to_component["rho"] = {"rho": 0} - field_to_component["d_rho"] = {"d_rho": 0} - field_to_component["dt_divj"] = {"dt_divj": 0} - - # keeping 'all_1st' for backwards compatibility - field_to_component["all_1st"] = {} - field_to_component["all_1st_cc"] = {} - moments = [ - "rho", - "jx", - "jy", - "jz", - "px", - "py", - "pz", - "txx", - "tyy", - "tzz", - "txy", - "tyz", - "tzx", - ] - for species_idx, species_name in enumerate(species_names): - for moment_idx, moment in enumerate(moments): - field_to_component["all_1st"][f"{moment}_{species_name}"] = ( - moment_idx + 13 * species_idx - ) - field_to_component["all_1st_cc"][f"{moment}_{species_name}"] = ( - moment_idx + 13 * species_idx - ) - - return field_to_component +def iter_components(field: Hashable, species_names: Iterable[str]) -> Generator[str]: + if field == "jeh": + yield from ["jx_ec", "jy_ec", "jz_ec", "ex_ec", "ey_ec", "ez_ec", "hx_fc", "hy_fc", "hz_fc"] # fmt: off + elif field in ["dive", "rho", "d_rho", "dt_divj"]: + yield str(field) + elif field in ["all_1st", "all_1st_cc"]: + moments = ["rho", "jx", "jy", "jz", "px", "py", "pz", "txx", "tyy", "tzz", "txy", "tyz", "tzx"] # fmt: off + for species_name in species_names: + for moment in moments: + yield f"{moment}_{species_name}" def decode_psc( @@ -118,15 +82,13 @@ def decode_psc( } ) ds = ds.squeeze("step") - field_to_component = get_field_to_component(species_names) - data_vars = {} for var_name in ds: - if var_name in field_to_component: - for field, component_idx in field_to_component[var_name].items(): # type: ignore[index] - data_vars[field] = ds[var_name][component_idx, :, :, :] + for component_idx, component in enumerate( + iter_components(var_name, species_names) + ): + ds = ds.assign({component: ds[var_name][component_idx, :, :, :]}) ds = ds.drop_vars([var_name]) - ds = ds.assign(data_vars) run_info = RunInfo(ds, length=length, corner=corner) coords = {