diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index 851a2d6103e..700d581d773 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -1017,11 +1017,11 @@ def _aten_bucketize(input, def _aten_conv2d( input, weight, - bias, - stride, - padding, - dilation, - groups, + bias=None, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + groups=1, ): return _aten_convolution( input,