diff --git a/elements/space.py b/elements/space.py index ec74cea..8e7e12a 100644 --- a/elements/space.py +++ b/elements/space.py @@ -1,3 +1,4 @@ +import ml_dtypes import numpy as np @@ -13,7 +14,7 @@ def __init__(self, dtype, shape=(), low=None, high=None): self._high = self._infer_high(dtype, shape, low, high) self._shape = self._infer_shape(dtype, shape, self._low, self._high) self._discrete = ( - np.issubdtype(self.dtype, np.integer) or self.dtype == bool) + issubdtype(self.dtype, np.integer) or self.dtype == bool) self._random = np.random.RandomState() @property @@ -66,8 +67,8 @@ def __repr__(self): def __contains__(self, value): value = np.asarray(value) - if np.issubdtype(self.dtype, str): - return np.issubdtype(value.dtype, str) + if issubdtype(self.dtype, str): + return issubdtype(value.dtype, str) if value.shape != self.shape: return False if (value > self.high).any(): @@ -80,13 +81,13 @@ def __contains__(self, value): def sample(self): low, high = self.low, self.high - if np.issubdtype(self.dtype, np.floating): - low = np.maximum(np.ones(self.shape) * np.finfo(self.dtype).min, low) - high = np.minimum(np.ones(self.shape) * np.finfo(self.dtype).max, high) + if issubdtype(self.dtype, np.floating): + low = np.maximum(np.ones(self.shape) * ml_dtypes.finfo(self.dtype).min, low) + high = np.minimum(np.ones(self.shape) * ml_dtypes.finfo(self.dtype).max, high) return self._random.uniform(low, high, self.shape).astype(self.dtype) def _infer_low(self, dtype, shape, low, high): - if np.issubdtype(dtype, str): + if issubdtype(dtype, str): assert low is None, low return None if low is not None: @@ -94,17 +95,17 @@ def _infer_low(self, dtype, shape, low, high): return np.broadcast_to(low, shape) except ValueError: raise ValueError(f'Cannot broadcast {low} to shape {shape}') - elif np.issubdtype(dtype, np.floating): + elif issubdtype(dtype, np.floating): return -np.inf * np.ones(shape) - elif np.issubdtype(dtype, np.integer): + elif issubdtype(dtype, np.integer): return np.iinfo(dtype).min * np.ones(shape, dtype) - elif np.issubdtype(dtype, bool): + elif issubdtype(dtype, bool): return np.zeros(shape, bool) else: raise ValueError('Cannot infer low bound from shape and dtype.') def _infer_high(self, dtype, shape, low, high): - if np.issubdtype(dtype, str): + if issubdtype(dtype, str): assert high is None, high return None if high is not None: @@ -112,11 +113,11 @@ def _infer_high(self, dtype, shape, low, high): return np.broadcast_to(high, shape) except ValueError: raise ValueError(f'Cannot broadcast {high} to shape {shape}') - elif np.issubdtype(dtype, np.floating): + elif issubdtype(dtype, np.floating): return np.inf * np.ones(shape) - elif np.issubdtype(dtype, np.integer): + elif issubdtype(dtype, np.integer): return np.iinfo(dtype).max * np.ones(shape, dtype) - elif np.issubdtype(dtype, bool): + elif issubdtype(dtype, bool): return np.ones(shape, bool) else: raise ValueError('Cannot infer high bound from shape and dtype.') @@ -130,3 +131,14 @@ def _infer_shape(self, dtype, shape, low, high): shape = (shape,) assert all(dim and dim > 0 for dim in shape), shape return tuple(shape) + + +def issubdtype(a, b): + custom_float_dtypes = [getattr(ml_dtypes, name) for name in dir(ml_dtypes) if 'float' in name] + if a in custom_float_dtypes: + return b in {a, np.floating, np.inexact, np.number, np.generic} + if a in [ml_dtypes.int2, ml_dtypes.int4]: + return b in {a, np.signedinteger, np.integer, np.number, np.generic} + if a in [ml_dtypes.uint2, ml_dtypes.uint4]: + return b in {a, np.unsignedinteger, np.integer, np.number, np.generic} + return bool(np.issubdtype(a, b)) diff --git a/pyproject.toml b/pyproject.toml index 7f2a75b..a25ae0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ requires-python = ">=3.10" dependencies = [ "google-auth", "google-cloud-storage", + "ml_dtypes", "numpy", "psutil", "ruamel.yaml",