@@ -87,16 +87,7 @@ def expand_boolean_indices(
87
87
# nonzero returns shape [N, dims], we need to extract dim i
88
88
if len (indices ) == 1 :
89
89
# x[mask] — 1D mask
90
- squeeze_layer = ctx .net .add_shuffle (nonzero_indices )
91
- squeeze_layer .reshape_dims = (- 1 ,)
92
- set_layer_name (
93
- squeeze_layer ,
94
- target ,
95
- name + f"_bool_nonzero_squeeze_{ i } " ,
96
- source_ir ,
97
- )
98
- squeezed_index = squeeze_layer .get_output (0 )
99
- new_indices .append (squeezed_index )
90
+ to_squeeze = nonzero_indices
100
91
else :
101
92
# Advanced multi-axis mask: extract index i from shape [N, D]
102
93
gather_axis = 1 # dim index
@@ -108,11 +99,17 @@ def expand_boolean_indices(
108
99
set_layer_name (
109
100
gather_layer , target , name + f"_bool_nonzero_extract_{ i } " , source_ir
110
101
)
111
- extracted_index = gather_layer .get_output (0 )
112
- squeeze_layer = ctx .net .add_shuffle (extracted_index )
113
- squeeze_layer .reshape_dims = (- 1 ,)
114
- squeezed_index = squeeze_layer .get_output (0 )
115
- new_indices .append (squeezed_index )
102
+ to_squeeze = gather_layer .get_output (0 )
103
+ squeeze_layer = ctx .net .add_shuffle (to_squeeze )
104
+ squeeze_layer .reshape_dims = (- 1 ,)
105
+ set_layer_name (
106
+ squeeze_layer ,
107
+ target ,
108
+ name + f"_bool_mask_squeeze_{ i } " ,
109
+ source_ir ,
110
+ )
111
+ squeezed_index = squeeze_layer .get_output (0 )
112
+ new_indices .append (squeezed_index )
116
113
else :
117
114
new_indices .append (ind )
118
115
return new_indices
0 commit comments