diff --git a/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py b/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py
index b7ac4e53..9ab0f3f8 100644
--- a/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py
+++ b/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py
@@ -134,17 +134,16 @@ def generate_samples_for_feature_spec(feature_specs, num_samples, ragged=False):
       all_features.append(features)
       all_feature_weights.append(feature_weights)
     else:
-      features = []
-      feature_weights = []
-      for _ in range(num_samples):
-        num_ids = np.random.randint(1, 32)
-        ids = np.random.randint(
-            table_spec.vocabulary_size,
-            size=(num_ids,),
-            dtype=np.int32,
-        )
-        features.append(ids)
-        feature_weights.append(np.ones((num_ids,), dtype=np.float32))
+      counts = np.random.randint(1, 32, size=num_samples)
+      total_ids = np.sum(counts)
+      ids_flat = np.random.randint(
+          table_spec.vocabulary_size,
+          size=(total_ids,),
+          dtype=np.int32,
+      )
+      split_indices = np.cumsum(counts)[:-1]
+      features = np.split(ids_flat, split_indices)
+      feature_weights = [np.ones((c,), dtype=np.float32) for c in counts]
       all_features.append(np.array(features, dtype=object))
       all_feature_weights.append(np.array(feature_weights, dtype=object))
   return all_features, all_feature_weights
@@ -160,16 +159,17 @@ def generate_sparse_coo_inputs_for_feature_spec(
 
   for feature_spec in feature_specs:
     table_spec = feature_spec.table_spec
-    indices_tensors = []
-    values_tensors = []
-    for i in range(num_samples):
-      num_ids = np.random.randint(1, 32)
-      for j in range(num_ids):
-        indices_tensors.append([i, j])
-      for _ in range(num_ids):
-        values_tensors.append(np.random.randint(table_spec.vocabulary_size))
-    all_indices_tensors.append(np.array(indices_tensors, dtype=np.int64))
-    all_values_tensors.append(np.array(values_tensors, dtype=np.int32))
+    counts = np.random.randint(1, 32, size=num_samples)
+    total_ids = np.sum(counts)
+    values = np.random.randint(
+        table_spec.vocabulary_size, size=total_ids, dtype=np.int32
+    )
+    row_indices = np.repeat(np.arange(num_samples), counts)
+    col_indices = np.concatenate([np.arange(c) for c in counts])
+    indices = np.stack([row_indices, col_indices], axis=1)
+
+    all_indices_tensors.append(indices.astype(np.int64))
+    all_values_tensors.append(values)
     all_dense_shape_tensors.append(
         np.array([num_samples, vocab_size], dtype=np.int64)
     )
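
Both hunks apply the same vectorization pattern: draw all per-sample ID counts in one call, draw the flat pool of IDs in one call, then recover the ragged structure with `np.cumsum`/`np.split` (first hunk) or expand row/column indices with `np.repeat`/`np.concatenate` (second hunk). The standalone sketch below illustrates that pattern outside the benchmark file; the function names, RNG usage, and parameters are illustrative assumptions, not code from this repository.

```python
import numpy as np


def make_ragged_ids(num_samples, vocab_size, max_ids_per_sample=32, seed=0):
  """Vectorized generation of ragged per-sample id lists (illustrative sketch)."""
  rng = np.random.default_rng(seed)
  # One draw for every per-sample count instead of a Python-level loop.
  counts = rng.integers(1, max_ids_per_sample, size=num_samples)
  # Draw all ids in a single flat call, then split at the cumulative counts.
  ids_flat = rng.integers(vocab_size, size=int(counts.sum()), dtype=np.int32)
  features = np.split(ids_flat, np.cumsum(counts)[:-1])
  return counts, features


def make_coo(num_samples, counts, features):
  """Build COO (row, col) indices and values without nested loops."""
  # Row index i repeats counts[i] times; columns run 0..counts[i]-1 per row.
  rows = np.repeat(np.arange(num_samples), counts)
  cols = np.concatenate([np.arange(c) for c in counts])
  indices = np.stack([rows, cols], axis=1).astype(np.int64)
  values = np.concatenate(features)
  return indices, values


counts, features = make_ragged_ids(num_samples=4, vocab_size=100)
indices, values = make_coo(4, counts, features)
assert indices.shape[0] == values.shape[0] == counts.sum()
```

The design choice is the usual one for benchmark input generators: moving the per-sample randomness into a handful of bulk NumPy calls keeps the generated data distribution the same shape (1 to 31 IDs per sample) while removing O(num_samples) Python-loop overhead from test setup.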