@@ -1030,9 +1030,9 @@ def quantized_conv2d_nhwc_meta(
10301030 get_conv1d_output_size (
10311031 in_size ,
10321032 out_channels ,
1033- stride [1 ],
1034- padding [1 ],
1035- dilation [1 ],
1033+ stride [- 1 ],
1034+ padding [- 1 ],
1035+ dilation [- 1 ],
10361036 kernel_size [0 ],
10371037 True ,
10381038 )
@@ -1074,9 +1074,9 @@ def quantized_conv2d_nchw_meta(
10741074 get_conv1d_output_size (
10751075 in_size ,
10761076 out_channels ,
1077- stride [1 ],
1078- padding [1 ],
1079- dilation [1 ],
1077+ stride [- 1 ],
1078+ padding [- 1 ],
1079+ dilation [- 1 ],
10801080 kernel_size [0 ],
10811081 False ,
10821082 )
@@ -1118,9 +1118,9 @@ def quantized_conv2d_nchw_per_tensor_meta(
11181118 get_conv1d_output_size (
11191119 in_size ,
11201120 out_channels ,
1121- stride [1 ],
1122- padding [1 ],
1123- dilation [1 ],
1121+ stride [- 1 ],
1122+ padding [- 1 ],
1123+ dilation [- 1 ],
11241124 kernel_size [0 ],
11251125 False ,
11261126 )
@@ -1162,9 +1162,9 @@ def quantized_conv2d_nhwc_per_tensor_meta(
11621162 get_conv1d_output_size (
11631163 in_size ,
11641164 out_channels ,
1165- stride [1 ],
1166- padding [1 ],
1167- dilation [1 ],
1165+ stride [- 1 ],
1166+ padding [- 1 ],
1167+ dilation [- 1 ],
11681168 kernel_size [0 ],
11691169 True ,
11701170 )
@@ -1211,9 +1211,9 @@ def quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_meta(
12111211 get_conv1d_output_size (
12121212 in_size ,
12131213 out_channels ,
1214- stride [1 ],
1215- padding [1 ],
1216- dilation [1 ],
1214+ stride [- 1 ],
1215+ padding [- 1 ],
1216+ dilation [- 1 ],
12171217 kernel_size [0 ],
12181218 False ,
12191219 )
@@ -1260,9 +1260,9 @@ def quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_meta(
12601260 get_conv1d_output_size (
12611261 in_size ,
12621262 out_channels ,
1263- stride [1 ],
1264- padding [1 ],
1265- dilation [1 ],
1263+ stride [- 1 ],
1264+ padding [- 1 ],
1265+ dilation [- 1 ],
12661266 kernel_size [0 ],
12671267 False ,
12681268 )
@@ -1309,9 +1309,9 @@ def quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_meta(
13091309 get_conv1d_output_size (
13101310 in_size ,
13111311 out_channels ,
1312- stride [1 ],
1313- padding [1 ],
1314- dilation [1 ],
1312+ stride [- 1 ],
1313+ padding [- 1 ],
1314+ dilation [- 1 ],
13151315 kernel_size [0 ],
13161316 True ,
13171317 )
@@ -1358,9 +1358,9 @@ def quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_meta(
13581358 get_conv1d_output_size (
13591359 in_size ,
13601360 out_channels ,
1361- stride [1 ],
1362- padding [1 ],
1363- dilation [1 ],
1361+ stride [- 1 ],
1362+ padding [- 1 ],
1363+ dilation [- 1 ],
13641364 kernel_size [0 ],
13651365 True ,
13661366 )
@@ -1407,9 +1407,9 @@ def quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_meta(
14071407 get_conv1d_output_size (
14081408 in_size ,
14091409 out_channels ,
1410- stride [1 ],
1411- padding [1 ],
1412- dilation [1 ],
1410+ stride [- 1 ],
1411+ padding [- 1 ],
1412+ dilation [- 1 ],
14131413 kernel_size [0 ],
14141414 False ,
14151415 )
@@ -1456,9 +1456,9 @@ def quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
14561456 get_conv1d_output_size (
14571457 in_size ,
14581458 out_channels ,
1459- stride [1 ],
1460- padding [1 ],
1461- dilation [1 ],
1459+ stride [- 1 ],
1460+ padding [- 1 ],
1461+ dilation [- 1 ],
14621462 kernel_size [0 ],
14631463 False ,
14641464 )
@@ -1505,9 +1505,9 @@ def quantized_conv2d_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_meta(
15051505 get_conv1d_output_size (
15061506 in_size ,
15071507 out_channels ,
1508- stride [1 ],
1509- padding [1 ],
1510- dilation [1 ],
1508+ stride [- 1 ],
1509+ padding [- 1 ],
1510+ dilation [- 1 ],
15111511 kernel_size [0 ],
15121512 True ,
15131513 )
@@ -1554,9 +1554,9 @@ def quantized_conv2d_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
15541554 get_conv1d_output_size (
15551555 in_size ,
15561556 out_channels ,
1557- stride [1 ],
1558- padding [1 ],
1559- dilation [1 ],
1557+ stride [- 1 ],
1558+ padding [- 1 ],
1559+ dilation [- 1 ],
15601560 kernel_size [0 ],
15611561 True ,
15621562 )
@@ -1605,9 +1605,9 @@ def quantized_conv2d_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
16051605 get_conv1d_output_size (
16061606 in_size ,
16071607 out_channels ,
1608- stride [1 ],
1609- padding [1 ],
1610- dilation [1 ],
1608+ stride [- 1 ],
1609+ padding [- 1 ],
1610+ dilation [- 1 ],
16111611 kernel_size [0 ],
16121612 False ,
16131613 )
@@ -1656,9 +1656,9 @@ def quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
16561656 get_conv1d_output_size (
16571657 in_size ,
16581658 out_channels ,
1659- stride [1 ],
1660- padding [1 ],
1661- dilation [1 ],
1659+ stride [- 1 ],
1660+ padding [- 1 ],
1661+ dilation [- 1 ],
16621662 kernel_size [0 ],
16631663 False ,
16641664 )
@@ -1707,9 +1707,9 @@ def quantized_conv2d_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
17071707 get_conv1d_output_size (
17081708 in_size ,
17091709 out_channels ,
1710- stride [1 ],
1711- padding [1 ],
1712- dilation [1 ],
1710+ stride [- 1 ],
1711+ padding [- 1 ],
1712+ dilation [- 1 ],
17131713 kernel_size [0 ],
17141714 True ,
17151715 )
@@ -1758,9 +1758,9 @@ def quantized_conv2d_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
17581758 get_conv1d_output_size (
17591759 in_size ,
17601760 out_channels ,
1761- stride [1 ],
1762- padding [1 ],
1763- dilation [1 ],
1761+ stride [- 1 ],
1762+ padding [- 1 ],
1763+ dilation [- 1 ],
17641764 kernel_size [0 ],
17651765 True ,
17661766 )
@@ -2178,15 +2178,31 @@ def conv1d_meta(
21782178 dilation : Tuple [int ],
21792179 groups : int ,
21802180) -> torch .Tensor :
2181+ # Validate tensor dimensions
2182+ assert len (input .shape ) == 3 , f"Conv1d expects 3D input, got { len (input .shape )} D"
2183+ assert len (weight .shape ) == 3 , f"Conv1d expects 3D weight, got { len (weight .shape )} D"
2184+
2185+ # Extract dimensions
2186+ batch_size , in_channels , length = input .shape
2187+ out_channels , weight_in_channels , kernel_size = weight .shape
2188+
2189+ # Validate groups parameter and channel consistency
2190+ assert groups > 0 , f"groups must be positive, got { groups } "
21812191 assert (
2182- len (weight .shape ) == 3
2183- ), f"Conv1d expects a 3D weight, got { len (weight .shape )} D"
2184- out_channels , _ , kernel_size = weight .shape
2185- in_size = input .shape
2186- assert len (in_size ) == 3 , f"conv1d expects 3D input, got { len (in_size )} D"
2192+ in_channels % groups == 0
2193+ ), f"in_channels ({ in_channels } ) must be divisible by groups ({ groups } )"
2194+ assert (
2195+ out_channels % groups == 0
2196+ ), f"out_channels ({ out_channels } ) must be divisible by groups ({ groups } )"
2197+
2198+ # Validate weight channels match input channels divided by groups
2199+ expected_weight_in_channels = in_channels // groups
2200+ assert (
2201+ weight_in_channels == expected_weight_in_channels
2202+ ), f"Expected weight to have { expected_weight_in_channels } input channels (in_channels/groups), but got { weight_in_channels } "
21872203
21882204 output_size = get_conv1d_output_size (
2189- in_size ,
2205+ input . shape ,
21902206 out_channels ,
21912207 stride [0 ],
21922208 padding [0 ],
0 commit comments