We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 26db621 commit baf057fCopy full SHA for baf057f
tpu_inference/envs.py
@@ -8,7 +8,6 @@
8
9
if TYPE_CHECKING:
10
JAX_PLATFORMS: str = ""
11
- JAX_RANDOM_WEIGHTS: bool = False
12
TPU_ACCELERATOR_TYPE: str | None = None
13
TPU_NAME: str | None = None
14
TPU_WORKER_ID: str | None = None
@@ -28,9 +27,6 @@
28
27
# JAX platform selection (e.g., "tpu", "cpu", "proxy")
29
"JAX_PLATFORMS":
30
lambda: os.getenv("JAX_PLATFORMS", ""),
31
- # Initialize model weights randomly instead of loading from checkpoint
32
- "JAX_RANDOM_WEIGHTS":
33
- lambda: bool(int(os.getenv("JAX_RANDOM_WEIGHTS", "0"))),
34
# TPU accelerator type (e.g., "v5litepod-16", "v4-8")
35
"TPU_ACCELERATOR_TYPE":
36
lambda: os.getenv("TPU_ACCELERATOR_TYPE", None),
0 commit comments