Skip to content

Commit 26def0f

Browse files
authored
Misc fixes: (#9491)
1 parent 55b7d02 commit 26def0f

File tree

4 files changed

+27
-6
lines changed

4 files changed

+27
-6
lines changed

torchax/test/test_interop.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torchax import interop, jax_device
66
import torchax
77
import jax
8+
import jax.numpy as jnp
89

910

1011
def is_tpu_available():
@@ -171,6 +172,17 @@ def test_to_jax_device(self):
171172
self.assertEqual(d.jax_device.platform, "cpu")
172173
self.assertEqual(d.device.type, "jax")
173174

175+
def test_torch_jax_view_dtype(self):
176+
dtype = torch.float32
177+
self.assertEqual(interop.jax_view(dtype), jnp.float32.dtype)
178+
self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype)
179+
dtype = torch.bfloat16
180+
self.assertEqual(interop.jax_view(dtype), jnp.bfloat16.dtype)
181+
self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype)
182+
dtype = torch.int32
183+
self.assertEqual(interop.jax_view(dtype), jnp.int32.dtype)
184+
self.assertEqual(interop.torch_view(interop.jax_view(dtype)), dtype)
185+
174186

175187
if __name__ == '__main__':
176188
unittest.main()

torchax/torchax/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,11 @@ def extract_jax(mod: torch.nn.Module, env=None):
4949
states = env.t2j_copy(states)
5050

5151
#@jax.jit
52-
def jax_func(states, inputs):
53-
(states, inputs) = env.j2t_iso((states, inputs))
52+
def jax_func(states, args, kwargs=None):
53+
(states, args, kwargs) = env.j2t_iso((states, args, kwargs))
5454
with env:
55-
res = torch.func.functional_call(mod, states, inputs, tie_weights=False)
55+
res = torch.func.functional_call(
56+
mod, states, args, kwargs, tie_weights=False)
5657
return env.t2j_iso(res)
5758

5859
return states, jax_func

torchax/torchax/device_module.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import torch
2+
3+
14
def _is_in_bad_fork():
25
return False
36

@@ -24,3 +27,7 @@ def is_available():
2427

2528
def current_device():
2629
return 0
30+
31+
32+
def get_amp_supported_dtype():
33+
return [torch.float16, torch.bfloat16]

torchax/torchax/interop.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from jax.experimental.shard_map import shard_map
1212
from torchax import tensor
1313
from torchax import util
14+
from torchax.ops import mappings
1415
import torchax
1516

1617
from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable
@@ -183,8 +184,8 @@ def _torch_view(t: JaxValue) -> TorchValue:
183184
if isinstance(t, jax.Array):
184185
# TODO
185186
return tensor.Tensor(t, torchax.default_env())
186-
if isinstance(t, type(jnp.int32)):
187-
return tensor.t2j_type(t)
187+
if isinstance(t, jnp.dtype):
188+
return mappings.j2t_dtype(t)
188189
if callable(t): # t is a JaxCallable
189190
return functools.partial(call_jax, t)
190191
# regular types are not changed
@@ -201,7 +202,7 @@ def _jax_view(t: TorchValue) -> JaxValue:
201202
assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t)
202203
return t.jax()
203204
if isinstance(t, type(torch.int32)):
204-
return tensor.t2j_dtype(t)
205+
return mappings.t2j_dtype(t)
205206

206207
# torch.nn.Module needs special handling
207208
if not isinstance(t, torch.nn.Module) and callable(t): # t is a TorchCallable

0 commit comments

Comments
 (0)