Skip to content

Commit 46adc8e

Browse files
Jake VanderPlascopybara-github
authored andcommitted
Internal change
PiperOrigin-RevId: 532516183
1 parent ec67350 commit 46adc8e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

clu/deterministic_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def create_dataset(dataset_builder: DatasetBuilder,
430430
if isinstance(rng, tf.Tensor):
431431
rngs = [x.numpy() for x in tf.random.experimental.stateless_split(rng, 3)]
432432
else:
433-
rngs = list(jax.random.split(rng, 3))
433+
rngs = list(jax.random.key_data(jax.random.split(rng, 3)))
434434
else:
435435
rngs = 3 * [[None, None]]
436436

0 commit comments

Comments
 (0)