def _is_colocated_python_enabled() -> bool:
  """Returns whether colocated python checkpointing is enabled.

  The decision is driven by the environment variable
  ENABLE_COLOCATED_PYTHON_CHECKPOINTING, which is read exactly once per call:
  "1" enables the feature, while "0" or an unset variable disables it.

  Returns:
    True if colocated python checkpointing is enabled, False otherwise.

  Raises:
    ValueError: If the variable is set to anything other than "1" or "0"
      (this includes the empty string).
  """
  # One lookup instead of a membership test followed by repeated indexing.
  value = os.environ.get("ENABLE_COLOCATED_PYTHON_CHECKPOINTING")
  if value is None or value == "0":
    return False
  if value == "1":
    return True
  raise ValueError(
      "ENABLE_COLOCATED_PYTHON_CHECKPOINTING must be set to 1/0 or"
      f" unset, got: {value}"
  )
def register_pathways_handlers(
    read_timeout: datetime.timedelta | None = None,
    use_colocated_python: bool = False,
):
  """Registers the jax.Array type handler; call before saving or restoring.

  Exactly one handler is registered for jax.Array, overriding any handler
  that was registered for that type before.

  Args:
    read_timeout: Passed to CloudPathwaysArrayHandler. It is not forwarded
      anywhere when use_colocated_python is True.
    use_colocated_python: When True, a ColocatedPythonArrayHandler is
      registered; otherwise a CloudPathwaysArrayHandler is registered.
  """
  # Log first, then build the handler, so the debug line is emitted even if
  # handler construction fails.
  if use_colocated_python:
    logger.debug("Registering ColocatedPythonArrayHandler.")
    array_handler = ColocatedPythonArrayHandler()
  else:
    logger.debug(
        "Registering CloudPathwaysArrayHandler (Pathways Persistence API)."
    )
    array_handler = CloudPathwaysArrayHandler(read_timeout=read_timeout)

  type_handlers.register_type_handler(jax.Array, array_handler, override=True)
def test_colocated_python_enabled(self):
  """Checks how the colocated python flag is parsed from the environment.

  Covers "1" (enabled), "0" (disabled), the empty string (rejected with
  ValueError) and an unset variable (disabled). Each scenario runs against a
  cleared copy of os.environ so the surrounding environment cannot leak in.
  """
  flag = "ENABLE_COLOCATED_PYTHON_CHECKPOINTING"
  is_enabled = pathwaysutils._is_colocated_python_enabled

  with mock.patch.dict(os.environ, {flag: "1"}, clear=True):
    self.assertTrue(is_enabled())

  with mock.patch.dict(os.environ, {flag: "0"}, clear=True):
    self.assertFalse(is_enabled())

  with mock.patch.dict(os.environ, {flag: ""}, clear=True):
    with self.assertRaises(ValueError):
      is_enabled()

  with mock.patch.dict(os.environ, {}, clear=True):
    self.assertFalse(is_enabled())