@@ -134,17 +134,16 @@ def generate_samples_for_feature_spec(feature_specs, num_samples, ragged=False):
134134 all_features .append (features )
135135 all_feature_weights .append (feature_weights )
136136 else :
137- features = []
138- feature_weights = []
139- for _ in range (num_samples ):
140- num_ids = np .random .randint (1 , 32 )
141- ids = np .random .randint (
142- table_spec .vocabulary_size ,
143- size = (num_ids ,),
144- dtype = np .int32 ,
145- )
146- features .append (ids )
147- feature_weights .append (np .ones ((num_ids ,), dtype = np .float32 ))
137+ counts = np .random .randint (1 , 32 , size = num_samples )
138+ total_ids = np .sum (counts )
139+ ids_flat = np .random .randint (
140+ table_spec .vocabulary_size ,
141+ size = (total_ids ,),
142+ dtype = np .int32 ,
143+ )
144+ split_indices = np .cumsum (counts )[:- 1 ]
145+ features = np .split (ids_flat , split_indices )
146+ feature_weights = [np .ones ((c ,), dtype = np .float32 ) for c in counts ]
148147 all_features .append (np .array (features , dtype = object ))
149148 all_feature_weights .append (np .array (feature_weights , dtype = object ))
150149 return all_features , all_feature_weights
@@ -160,18 +159,19 @@ def generate_sparse_coo_inputs_for_feature_spec(
160159
161160 for feature_spec in feature_specs :
162161 table_spec = feature_spec .table_spec
163- indices_tensors = []
164- values_tensors = []
165- for i in range (num_samples ):
166- num_ids = np .random .randint (1 , 32 )
167- for j in range (num_ids ):
168- indices_tensors .append ([i , j ])
169- for _ in range (num_ids ):
170- values_tensors .append (np .random .randint (table_spec .vocabulary_size ))
171- all_indices_tensors .append (np .array (indices_tensors , dtype = np .int64 ))
172- all_values_tensors .append (np .array (values_tensors , dtype = np .int32 ))
162+ counts = np .random .randint (1 , 32 , size = num_samples )
163+ total_ids = np .sum (counts )
164+ values = np .random .randint (
165+ table_spec .vocabulary_size , size = total_ids , dtype = np .int32
166+ )
167+ row_indices = np .repeat (np .arange (num_samples ), counts )
168+ col_indices = np .concatenate ([np .arange (c ) for c in counts ])
169+ indices = np .stack ([row_indices , col_indices ], axis = 1 )
170+
171+ all_indices_tensors .append (indices .astype (np .int64 ))
172+ all_values_tensors .append (values )
173173 all_dense_shape_tensors .append (
174- np .array ([num_samples , vocab_size ], dtype = np .int64 )
174+ np .array ([num_samples , np . max ( counts ) ], dtype = np .int64 )
175175 )
176176 return all_indices_tensors , all_values_tensors , all_dense_shape_tensors
177177
0 commit comments