Skip to content

Commit cd881dd

Browse files
made distributed structure proper
1 parent aaf6e20 commit cd881dd

File tree

6 files changed

+10
-5
lines changed

6 files changed

+10
-5
lines changed

keras/src/backend/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,9 @@
3636
# Import backend functions.
3737
if backend() == "tensorflow":
3838
from keras.src.backend.tensorflow import * # noqa: F403
39-
from keras.src.backend.tensorflow import distribution_lib
4039
from keras.src.backend.tensorflow.core import Variable as BackendVariable
4140
elif backend() == "jax":
4241
from keras.src.backend.jax import * # noqa: F403
43-
from keras.src.backend.jax import distribution_lib
4442
from keras.src.backend.jax.core import Variable as BackendVariable
4543
elif backend() == "torch":
4644
from keras.src.backend.torch import * # noqa: F403

keras/src/backend/jax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from keras.src.backend.config import is_nnx_enabled
22
from keras.src.backend.jax import core
3+
from keras.src.backend.jax import distribution_lib
34
from keras.src.backend.jax import image
45
from keras.src.backend.jax import linalg
56
from keras.src.backend.jax import math

keras/src/backend/jax/distribution_lib.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ def num_processes():
193193
return jax.process_count()
194194

195195

196+
def process_id():
197+
"""Return the current process ID for the distribution setting."""
198+
return jax.process_index()
199+
200+
196201
def _to_backend_device(device_name):
197202
if isinstance(device_name, jax.Device):
198203
return device_name

keras/src/backend/numpy/__init__.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,3 @@
2424
from keras.src.backend.numpy.rnn import gru
2525
from keras.src.backend.numpy.rnn import lstm
2626
from keras.src.backend.numpy.rnn import rnn
27-
28-
# Numpy backend does not support distribution
29-
distribution_lib = None

keras/src/backend/tensorflow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from keras.src.backend.tensorflow import core
2+
from keras.src.backend.tensorflow import distribution_lib
23
from keras.src.backend.tensorflow import image
34
from keras.src.backend.tensorflow import linalg
45
from keras.src.backend.tensorflow import math

keras/src/wrappers/sklearn_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ def use_floatx(x):
120120
"not an issue in sklearn>=1.6"
121121
),
122122
"check_pipeline_consistency": "Neural networks are non-deterministic",
123+
"check_transformer_data_not_an_array": "Neural networks are "
124+
"non-deterministic",
125+
"check_transformer_general": "Neural networks are non-deterministic",
123126
},
124127
}
125128

0 commit comments

Comments
 (0)