Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions elements/space.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ml_dtypes
import numpy as np


Expand All @@ -13,7 +14,7 @@ def __init__(self, dtype, shape=(), low=None, high=None):
self._high = self._infer_high(dtype, shape, low, high)
self._shape = self._infer_shape(dtype, shape, self._low, self._high)
self._discrete = (
np.issubdtype(self.dtype, np.integer) or self.dtype == bool)
issubdtype(self.dtype, np.integer) or self.dtype == bool)
self._random = np.random.RandomState()

@property
Expand Down Expand Up @@ -66,8 +67,8 @@ def __repr__(self):

def __contains__(self, value):
value = np.asarray(value)
if np.issubdtype(self.dtype, str):
return np.issubdtype(value.dtype, str)
if issubdtype(self.dtype, str):
return issubdtype(value.dtype, str)
if value.shape != self.shape:
return False
if (value > self.high).any():
Expand All @@ -80,43 +81,43 @@ def __contains__(self, value):

def sample(self):
low, high = self.low, self.high
if np.issubdtype(self.dtype, np.floating):
low = np.maximum(np.ones(self.shape) * np.finfo(self.dtype).min, low)
high = np.minimum(np.ones(self.shape) * np.finfo(self.dtype).max, high)
if issubdtype(self.dtype, np.floating):
low = np.maximum(np.ones(self.shape) * ml_dtypes.finfo(self.dtype).min, low)
high = np.minimum(np.ones(self.shape) * ml_dtypes.finfo(self.dtype).max, high)
return self._random.uniform(low, high, self.shape).astype(self.dtype)

def _infer_low(self, dtype, shape, low, high):
if np.issubdtype(dtype, str):
if issubdtype(dtype, str):
assert low is None, low
return None
if low is not None:
try:
return np.broadcast_to(low, shape)
except ValueError:
raise ValueError(f'Cannot broadcast {low} to shape {shape}')
elif np.issubdtype(dtype, np.floating):
elif issubdtype(dtype, np.floating):
return -np.inf * np.ones(shape)
elif np.issubdtype(dtype, np.integer):
elif issubdtype(dtype, np.integer):
return np.iinfo(dtype).min * np.ones(shape, dtype)
elif np.issubdtype(dtype, bool):
elif issubdtype(dtype, bool):
return np.zeros(shape, bool)
else:
raise ValueError('Cannot infer low bound from shape and dtype.')

def _infer_high(self, dtype, shape, low, high):
if np.issubdtype(dtype, str):
if issubdtype(dtype, str):
assert high is None, high
return None
if high is not None:
try:
return np.broadcast_to(high, shape)
except ValueError:
raise ValueError(f'Cannot broadcast {high} to shape {shape}')
elif np.issubdtype(dtype, np.floating):
elif issubdtype(dtype, np.floating):
return np.inf * np.ones(shape)
elif np.issubdtype(dtype, np.integer):
elif issubdtype(dtype, np.integer):
return np.iinfo(dtype).max * np.ones(shape, dtype)
elif np.issubdtype(dtype, bool):
elif issubdtype(dtype, bool):
return np.ones(shape, bool)
else:
raise ValueError('Cannot infer high bound from shape and dtype.')
Expand All @@ -130,3 +131,14 @@ def _infer_shape(self, dtype, shape, low, high):
shape = (shape,)
assert all(dim and dim > 0 for dim in shape), shape
return tuple(shape)


def issubdtype(a, b):
custom_float_dtypes = [getattr(ml_dtypes, name) for name in dir(ml_dtypes) if 'float' in name]
if a in custom_float_dtypes:
return b in {a, np.floating, np.inexact, np.number, np.generic}
if a in [ml_dtypes.int2, ml_dtypes.int4]:
return b in {a, np.signedinteger, np.integer, np.number, np.generic}
if a in [ml_dtypes.uint2, ml_dtypes.uint4]:
return b in {a, np.unsignedinteger, np.integer, np.number, np.generic}
return bool(np.issubdtype(a, b))
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ requires-python = ">=3.10"
dependencies = [
"google-auth",
"google-cloud-storage",
"ml_dtypes",
"numpy",
"psutil",
"ruamel.yaml",
Expand Down