From 41e00376bec570487787403c21f3b087a4785e38 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 5 Nov 2025 17:09:13 +0000 Subject: [PATCH 1/2] added clever context manager xp decorator --- autoconf/xp_import.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 autoconf/xp_import.py diff --git a/autoconf/xp_import.py b/autoconf/xp_import.py new file mode 100644 index 0000000..ed68393 --- /dev/null +++ b/autoconf/xp_import.py @@ -0,0 +1,30 @@ +import contextlib +import numpy as np +import jax +import jax.numpy as jnp + + +def get_xp(*args, **kwargs): + for x in jax.tree_util.tree_leaves((args, kwargs)): + if isinstance(x, jax.core.Tracer): + return jnp + return np + +@contextlib.contextmanager +def _temporary_xp_binding(fn_globals, xp): + old_xp = fn_globals.get("xp", None) + fn_globals["xp"] = xp + try: + yield + finally: + if old_xp is None: + del fn_globals["xp"] + else: + fn_globals["xp"] = old_xp + +def auto_xp(fn): + def wrapped(*args, **kwargs): + xp = get_xp(args, kwargs) + with _temporary_xp_binding(fn.__globals__, xp): + return fn(*args, **kwargs) + return wrapped From 34d233c63b5e331c236252ae8942de853b332a3d Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 10 Nov 2025 13:20:35 +0000 Subject: [PATCH 2/2] xp namespace complete --- autoconf/xp_import.py | 30 ------------------------------ 1 file changed, 30 deletions(-) delete mode 100644 autoconf/xp_import.py diff --git a/autoconf/xp_import.py b/autoconf/xp_import.py deleted file mode 100644 index ed68393..0000000 --- a/autoconf/xp_import.py +++ /dev/null @@ -1,30 +0,0 @@ -import contextlib -import numpy as np -import jax -import jax.numpy as jnp - - -def get_xp(*args, **kwargs): - for x in jax.tree_util.tree_leaves((args, kwargs)): - if isinstance(x, jax.core.Tracer): - return jnp - return np - -@contextlib.contextmanager -def _temporary_xp_binding(fn_globals, xp): - old_xp = fn_globals.get("xp", None) - fn_globals["xp"] = xp - try: - yield - finally: - if old_xp is None: - del fn_globals["xp"] - else: - fn_globals["xp"] = old_xp - -def auto_xp(fn): - def wrapped(*args, **kwargs): - xp = get_xp(args, kwargs) - with _temporary_xp_binding(fn.__globals__, xp): - return fn(*args, **kwargs) - return wrapped