I would like the following not to fail with PyTorch:
```python
>>> import array_api_compat.torch as xp
>>> data = xp.linspace(0, 1, num=5, device="mps")
>>> xp.clip(data, 0.1, 0.9)
Traceback (most recent call last):
  Cell In[4], line 1
    xp.clip(data, 0.1, 0.9)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/_internal.py:28 in wrapped_f
    return f(*args, xp=xp, **kwargs)
  File ~/miniforge3/envs/dev/lib/python3.11/site-packages/array_api_compat/common/_aliases.py:317 in clip
    ia = (out < a) | xp.isnan(a)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, mps:0 and cpu!
```
At the moment, we have to be overly verbose to use `xp.clip` with PyTorch on non-CPU tensors, wrapping each bound in an array on the input's device:
```python
>>> from array_api_compat import device
>>> device_ = device(data)
>>> xp.clip(data, xp.asarray(0.1, device=device_), xp.asarray(0.9, device=device_))
tensor([0.1000, 0.2500, 0.5000, 0.7500, 0.9000], device='mps:0')
```
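Until `xp.clip` handles scalar bounds on non-CPU devices itself, the workaround above can be wrapped in a small helper. This is a hypothetical sketch, not part of `array_api_compat`: `clip_same_device` is an illustrative name, and it assumes the namespace's `asarray` accepts `dtype=` and `device=` keywords as the array API standard specifies. It wraps only Python scalar bounds, placing them on `x.device` with `x.dtype`. The demo uses NumPy so it runs on CPU; for an MPS tensor you would pass `xp=array_api_compat.torch` instead.

```python
import numpy as np


def clip_same_device(x, min=None, max=None, *, xp):
    """Hypothetical helper: clip ``x`` after placing Python scalar bounds
    on the same device (and with the same dtype) as ``x``."""
    # Arrays without a ``device`` attribute (e.g. NumPy < 2.0) fall back
    # to the namespace's default placement.
    device = getattr(x, "device", None)
    extra = {} if device is None else {"device": device}

    def as_bound(b):
        # Leave ``None`` and existing arrays untouched; only wrap scalars.
        if b is None or not isinstance(b, (int, float)):
            return b
        # Using x.dtype keeps the output dtype equal to the input dtype.
        # Note: for an integer ``x``, a float bound is truncated by this cast.
        return xp.asarray(b, dtype=x.dtype, **extra)

    return xp.clip(x, as_bound(min), as_bound(max))


# CPU-only demo with NumPy; with array_api_compat.torch, pass xp=xp instead.
demo = clip_same_device(np.linspace(0, 1, num=5), 0.1, 0.9, xp=np)
print(demo)
```

Wrapping only scalars leaves array bounds, which the caller may already have placed on the right device, under the caller's control.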