Skip to content

Commit 6b5bae7

Browse files
cantoniosThe sparsecore Authors
authored andcommitted
Refactor CPU vs TPU tests.
TPU tests are marked by the tag `requires-tpu`. CPU tests are anything else. CPU tests need to be configured to disable all devices except CPU, otherwise the test may hang when searching for TPUs. For TPU tests, we add the required LIBTPU flags in .bazelrc. PiperOrigin-RevId: 715081513
1 parent 8fd28a4 commit 6b5bae7

File tree

10 files changed

+31
-11
lines changed

10 files changed

+31
-11
lines changed

.bazelrc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,10 @@ build:clang --copt=-Qunused-arguments
101101
##############################################################################
102102
# Test configurations.
103103
##############################################################################
104-
test:cpu --test_env=JAX_PLATFORMS=cpu --test_tag_filters=cpu
104+
# Configure TPU for SparseCore usage.
105+
test --test_env="LIBTPU_INIT_ARGS=--2a886c8_chip_config_name=megachip_tccontrol"
106+
# Show output for failing tests.
107+
test --test_output=errors
105108

106109
#############################################################################
107110
# Some configs to make getting some forms of debug builds. In general, the

.github/workflows/build_and_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,4 @@ jobs:
7070
7171
- name: Run CPU tests
7272
run: |
73-
bazel test --config=cpu --test_output=errors --keep_going //...
73+
bazel test --test_tag_filters=-requires-tpu --test_output=errors --keep_going //...

jax_tpu_embedding/sparsecore/examples/shakespeare/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ pytype_strict_contrib_test(
3434
main = "jax_sc_shakespeare_jit.py",
3535
tags = [
3636
"exclusive-if-local",
37+
"requires-tpu",
3738
],
3839
deps = [
3940
"//jax_tpu_embedding/sparsecore/examples/models/shakespeare:dataset",
@@ -105,6 +106,7 @@ pytype_strict_contrib_test(
105106
main = "jax_sc_shakespeare_pmap.py",
106107
tags = [
107108
"exclusive-if-local",
109+
"requires-tpu",
108110
],
109111
deps = [
110112
"//jax_tpu_embedding/sparsecore/examples/models/shakespeare:dataset",

jax_tpu_embedding/sparsecore/lib/checkpointing/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pytype_strict_library(
3131
pytype_strict_contrib_test(
3232
name = "checkpoint_utils_test",
3333
srcs = ["checkpoint_utils_test.py"],
34-
tags = ["cpu"],
34+
env = {"JAX_PLATFORMS": "cpu"},
3535
deps = [
3636
":checkpoint_utils",
3737
"//jax_tpu_embedding/sparsecore/lib/nn:embedding",

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ cc_library(
3636
cc_test(
3737
name = "input_preprocessing_threads_test",
3838
srcs = ["input_preprocessing_threads_test.cc"],
39-
tags = ["cpu"],
39+
env = {"JAX_PLATFORMS": "cpu"},
4040
deps = [
4141
":input_preprocessing_threads",
4242
"@com_google_googletest//:gtest_main",
@@ -65,7 +65,7 @@ cc_library(
6565
cc_test(
6666
name = "input_preprocessing_util_test",
6767
srcs = ["input_preprocessing_util_test.cc"],
68-
tags = ["cpu"],
68+
env = {"JAX_PLATFORMS": "cpu"},
6969
deps = [
7070
":input_preprocessing_util",
7171
"@com_google_googletest//:gtest_main",
@@ -118,9 +118,9 @@ pytype_strict_contrib_test(
118118
"input_preprocessing_test.py",
119119
],
120120
env = {
121+
"JAX_PLATFORMS": "cpu",
121122
"XLA_FLAGS": "--xla_dump_to=sponge",
122123
},
123-
tags = ["cpu"],
124124
deps = [
125125
":constants",
126126
":input_preprocessing",
@@ -136,9 +136,9 @@ pytype_strict_contrib_test(
136136
"input_preprocessing_cc_test.py",
137137
],
138138
env = {
139+
"JAX_PLATFORMS": "cpu",
139140
"XLA_FLAGS": "--xla_dump_to=sponge",
140141
},
141-
tags = ["cpu"],
142142
deps = [
143143
":constants",
144144
":input_preprocessing_cc",

jax_tpu_embedding/sparsecore/lib/core/primitives/tests/BUILD

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ pytype_strict_contrib_test(
3333
},
3434
tags = [
3535
"exclusive-if-local",
36+
"requires-tpu",
3637
],
3738
deps = [
3839
"//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing",
@@ -59,6 +60,7 @@ pytype_strict_contrib_test(
5960
},
6061
tags = [
6162
"exclusive-if-local",
63+
"requires-tpu",
6264
],
6365
deps = [
6466
"//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing",
@@ -85,6 +87,7 @@ pytype_strict_contrib_test(
8587
},
8688
tags = [
8789
"exclusive-if-local",
90+
"requires-tpu",
8891
],
8992
deps = [
9093
"//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_sgd_with_mini_batching",
@@ -111,6 +114,7 @@ pytype_strict_contrib_test(
111114
},
112115
tags = [
113116
"exclusive-if-local",
117+
"requires-tpu",
114118
],
115119
deps = [
116120
"//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_adagrad_with_mini_batching",
@@ -138,6 +142,7 @@ pytype_strict_contrib_test(
138142
},
139143
tags = [
140144
"exclusive-if-local",
145+
"requires-tpu",
141146
],
142147
deps = [
143148
"//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing",
@@ -164,6 +169,7 @@ pytype_strict_contrib_test(
164169
},
165170
tags = [
166171
"exclusive-if-local",
172+
"requires-tpu",
167173
],
168174
deps = [
169175
"//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing",
@@ -188,6 +194,7 @@ pytype_strict_contrib_test(
188194
],
189195
tags = [
190196
"exclusive-if-local",
197+
"requires-tpu",
191198
],
192199
deps = [
193200
"//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing",

jax_tpu_embedding/sparsecore/lib/fdo/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ pytype_strict_library(
3939
pytype_strict_contrib_test(
4040
name = "file_fdo_client_test",
4141
srcs = ["file_fdo_client_test.py"],
42-
tags = ["cpu"],
42+
env = {"JAX_PLATFORMS": "cpu"},
4343
deps = [
4444
":file_fdo_client",
4545
pypi_requirement("absl/testing:absltest"),

jax_tpu_embedding/sparsecore/lib/flax/tests/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ pytype_strict_contrib_test(
3030
],
3131
tags = [
3232
"exclusive-if-local",
33+
"requires-tpu",
3334
],
3435
deps = [
3536
"//jax_tpu_embedding/sparsecore/lib/flax:embed",
@@ -61,6 +62,7 @@ pytype_strict_contrib_test(
6162
},
6263
tags = [
6364
"exclusive-if-local",
65+
"requires-tpu",
6466
],
6567
deps = [
6668
"//jax_tpu_embedding/sparsecore/examples/models/shakespeare:dataset",

jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pytype_strict_contrib_test(
2525
srcs = [
2626
"preprocess_sparse_dense_matmul_input_test.py",
2727
],
28-
tags = ["cpu"],
28+
env = {"JAX_PLATFORMS": "cpu"},
2929
deps = [
3030
"//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing",
3131
"//jax_tpu_embedding/sparsecore/lib/nn:embedding",
@@ -64,7 +64,7 @@ pytype_strict_library(
6464
pytype_strict_contrib_test(
6565
name = "test_utils_test",
6666
srcs = ["test_utils_test.py"],
67-
tags = ["cpu"],
67+
env = {"JAX_PLATFORMS": "cpu"},
6868
deps = [
6969
":test_utils",
7070
pypi_requirement("absl/logging"),
@@ -89,6 +89,7 @@ pytype_strict_contrib_test(
8989
},
9090
tags = [
9191
"exclusive-if-local",
92+
"requires-tpu",
9293
],
9394
deps = [
9495
":test_utils",
@@ -118,6 +119,7 @@ pytype_strict_contrib_test(
118119
},
119120
tags = [
120121
"exclusive-if-local",
122+
"requires-tpu",
121123
],
122124
deps = [
123125
":test_utils",
@@ -143,6 +145,7 @@ pytype_strict_contrib_test(
143145
],
144146
tags = [
145147
"exclusive-if-local",
148+
"requires-tpu",
146149
],
147150
deps = [
148151
"//jax_tpu_embedding/sparsecore/lib/nn:embedding",
@@ -168,6 +171,7 @@ pytype_strict_contrib_test(
168171
],
169172
tags = [
170173
"exclusive-if-local",
174+
"requires-tpu",
171175
],
172176
deps = [
173177
":test_utils",
@@ -185,7 +189,7 @@ pytype_strict_contrib_test(
185189
pytype_strict_contrib_test(
186190
name = "embedding_spec_test",
187191
srcs = ["embedding_spec_test.py"],
188-
tags = ["cpu"],
192+
env = {"JAX_PLATFORMS": "cpu"},
189193
deps = [
190194
"//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec",
191195
pypi_requirement("absl/testing:absltest"),

jax_tpu_embedding/sparsecore/tests/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pytype_strict_contrib_test(
3131
},
3232
tags = [
3333
"exclusive-if-local",
34+
"requires-tpu",
3435
],
3536
deps = [
3637
"//jax_tpu_embedding/sparsecore/examples/models/shakespeare:dataset",
@@ -60,6 +61,7 @@ pytype_strict_contrib_test(
6061
},
6162
tags = [
6263
"exclusive-if-local",
64+
"requires-tpu",
6365
],
6466
deps = [
6567
"//jax_tpu_embedding/sparsecore/examples/models/shakespeare:dataset",

0 commit comments

Comments
 (0)