From 8a36af21cd26b29571a0b5e3e3fb1a581ea3447a Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 16 Jun 2025 10:55:58 -0400 Subject: [PATCH 1/4] psc: replace field_to_component with func --- src/pscpy/psc.py | 65 +++++++++++------------------------------------- 1 file changed, 15 insertions(+), 50 deletions(-) diff --git a/src/pscpy/psc.py b/src/pscpy/psc.py index 54c2bb6..c6ed33a 100644 --- a/src/pscpy/psc.py +++ b/src/pscpy/psc.py @@ -50,52 +50,17 @@ def __repr__(self) -> str: return f"Psc(gdims={self.gdims}, length={self.length}, corner={self.corner})" -def get_field_to_component(species_names: Iterable[str]) -> dict[str, dict[str, int]]: - field_to_component: dict[str, dict[str, int]] = {} - field_to_component["jeh"] = { - "jx_ec": 0, - "jy_ec": 1, - "jz_ec": 2, - "ex_ec": 3, - "ey_ec": 4, - "ez_ec": 5, - "hx_fc": 6, - "hy_fc": 7, - "hz_fc": 8, - } - field_to_component["dive"] = {"dive": 0} - field_to_component["rho"] = {"rho": 0} - field_to_component["d_rho"] = {"d_rho": 0} - field_to_component["dt_divj"] = {"dt_divj": 0} - - # keeping 'all_1st' for backwards compatibility - field_to_component["all_1st"] = {} - field_to_component["all_1st_cc"] = {} - moments = [ - "rho", - "jx", - "jy", - "jz", - "px", - "py", - "pz", - "txx", - "tyy", - "tzz", - "txy", - "tyz", - "tzx", - ] - for species_idx, species_name in enumerate(species_names): - for moment_idx, moment in enumerate(moments): - field_to_component["all_1st"][f"{moment}_{species_name}"] = ( - moment_idx + 13 * species_idx - ) - field_to_component["all_1st_cc"][f"{moment}_{species_name}"] = ( - moment_idx + 13 * species_idx - ) - - return field_to_component +def get_components(field: str, species_names: Iterable[str]) -> list[str] | None: + # fmt: off + if field == "jeh": + return ["jx_ec", "jy_ec", "jz_ec", "ex_ec", "ey_ec", "ez_ec", "hx_fc", "hy_fc", "hz_fc"] + elif field in ["dive", "rho", "d_rho", "dt_divj"]: + return [field] + elif field in ["all_1st", "all_1st_cc"]: + moments = ["rho", "jx", "jy", "jz", "px", "py", "pz", "txx", "tyy", "tzz", "txy", "tyz", "tzx"] + return [f"{moment}_{species_name}" for species_name in species_names for moment in moments] + return None + # fmt: on def decode_psc( @@ -118,13 +83,13 @@ def decode_psc( } ) ds = ds.squeeze("step") - field_to_component = get_field_to_component(species_names) data_vars = {} for var_name in ds: - if var_name in field_to_component: - for field, component_idx in field_to_component[var_name].items(): # type: ignore[index] - data_vars[field] = ds[var_name][component_idx, :, :, :] + components = get_components(str(var_name), species_names) + if components is not None: + for component_idx, component in enumerate(components): + data_vars[component] = ds[var_name][component_idx, :, :, :] ds = ds.drop_vars([var_name]) ds = ds.assign(data_vars) From e458357a1fb5ab6b9414094eac68c06387ed8753 Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 16 Jun 2025 13:05:49 -0400 Subject: [PATCH 2/4] psc: yield components --- src/pscpy/psc.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/pscpy/psc.py b/src/pscpy/psc.py index c6ed33a..66eb0fe 100644 --- a/src/pscpy/psc.py +++ b/src/pscpy/psc.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable +from collections.abc import Generator, Hashable, Iterable from typing import Any import numpy as np @@ -50,16 +50,17 @@ def __repr__(self) -> str: return f"Psc(gdims={self.gdims}, length={self.length}, corner={self.corner})" -def get_components(field: str, species_names: Iterable[str]) -> list[str] | None: +def iter_components(field: Hashable, species_names: Iterable[str]) -> Generator[str]: # fmt: off if field == "jeh": - return ["jx_ec", "jy_ec", "jz_ec", "ex_ec", "ey_ec", "ez_ec", "hx_fc", "hy_fc", "hz_fc"] + yield from ["jx_ec", "jy_ec", "jz_ec", "ex_ec", "ey_ec", "ez_ec", "hx_fc", "hy_fc", "hz_fc"] elif field in ["dive", "rho", "d_rho", "dt_divj"]: - return [field] + yield str(field) elif field in ["all_1st", "all_1st_cc"]: moments = ["rho", "jx", "jy", "jz", "px", "py", "pz", "txx", "tyy", "tzz", "txy", "tyz", "tzx"] - return [f"{moment}_{species_name}" for species_name in species_names for moment in moments] - return None + for species_name in species_names: + for moment in moments: + yield f"{moment}_{species_name}" # fmt: on @@ -86,10 +87,10 @@ def decode_psc( data_vars = {} for var_name in ds: - components = get_components(str(var_name), species_names) - if components is not None: - for component_idx, component in enumerate(components): - data_vars[component] = ds[var_name][component_idx, :, :, :] + for component_idx, component in enumerate( + iter_components(var_name, species_names) + ): + data_vars[component] = ds[var_name][component_idx, :, :, :] ds = ds.drop_vars([var_name]) ds = ds.assign(data_vars) From 891dd61ca8d910afe57165de962aa7359d99eace Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 16 Jun 2025 13:12:34 -0400 Subject: [PATCH 3/4] psc: don't batch assignment I'm not sure if this significantly affects performance one way or another --- src/pscpy/psc.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/pscpy/psc.py b/src/pscpy/psc.py index 66eb0fe..6299358 100644 --- a/src/pscpy/psc.py +++ b/src/pscpy/psc.py @@ -85,14 +85,12 @@ def decode_psc( ) ds = ds.squeeze("step") - data_vars = {} for var_name in ds: for component_idx, component in enumerate( iter_components(var_name, species_names) ): - data_vars[component] = ds[var_name][component_idx, :, :, :] + ds = ds.assign({component: ds[var_name][component_idx, :, :, :]}) ds = ds.drop_vars([var_name]) - ds = ds.assign(data_vars) run_info = RunInfo(ds, length=length, corner=corner) coords = { From abf7815c6376b9c1469ec9ff12a2754f2bb4a0c1 Mon Sep 17 00:00:00 2001 From: James McClung Date: Mon, 16 Jun 2025 13:14:44 -0400 Subject: [PATCH 4/4] psc: turn off fmt for specific lines --- src/pscpy/psc.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/pscpy/psc.py b/src/pscpy/psc.py index 6299358..bf1e606 100644 --- a/src/pscpy/psc.py +++ b/src/pscpy/psc.py @@ -51,17 +51,15 @@ def __repr__(self) -> str: def iter_components(field: Hashable, species_names: Iterable[str]) -> Generator[str]: - # fmt: off if field == "jeh": - yield from ["jx_ec", "jy_ec", "jz_ec", "ex_ec", "ey_ec", "ez_ec", "hx_fc", "hy_fc", "hz_fc"] + yield from ["jx_ec", "jy_ec", "jz_ec", "ex_ec", "ey_ec", "ez_ec", "hx_fc", "hy_fc", "hz_fc"] # fmt: off elif field in ["dive", "rho", "d_rho", "dt_divj"]: yield str(field) elif field in ["all_1st", "all_1st_cc"]: - moments = ["rho", "jx", "jy", "jz", "px", "py", "pz", "txx", "tyy", "tzz", "txy", "tyz", "tzx"] + moments = ["rho", "jx", "jy", "jz", "px", "py", "pz", "txx", "tyy", "tzz", "txy", "tyz", "tzx"] # fmt: off for species_name in species_names: for moment in moments: yield f"{moment}_{species_name}" - # fmt: on def decode_psc(