Skip to content

Commit d08aea0

Browse files
[JAX SC] Vectorize ragged input array generation (~16x speedup).
This is particularly bad with a lot of tables and large batch sizes, wasting up to ~63% of the benchmark time in generation; this vectorization brings it down to about 10%. PiperOrigin-RevId: 836946285
1 parent 3ddc44d commit d08aea0

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -134,17 +134,16 @@ def generate_samples_for_feature_spec(feature_specs, num_samples, ragged=False):
134134
all_features.append(features)
135135
all_feature_weights.append(feature_weights)
136136
else:
137-
features = []
138-
feature_weights = []
139-
for _ in range(num_samples):
140-
num_ids = np.random.randint(1, 32)
141-
ids = np.random.randint(
142-
table_spec.vocabulary_size,
143-
size=(num_ids,),
144-
dtype=np.int32,
145-
)
146-
features.append(ids)
147-
feature_weights.append(np.ones((num_ids,), dtype=np.float32))
137+
counts = np.random.randint(1, 32, size=num_samples)
138+
total_ids = np.sum(counts)
139+
ids_flat = np.random.randint(
140+
table_spec.vocabulary_size,
141+
size=(total_ids,),
142+
dtype=np.int32,
143+
)
144+
split_indices = np.cumsum(counts)[:-1]
145+
features = np.split(ids_flat, split_indices)
146+
feature_weights = [np.ones((c,), dtype=np.float32) for c in counts]
148147
all_features.append(np.array(features, dtype=object))
149148
all_feature_weights.append(np.array(feature_weights, dtype=object))
150149
return all_features, all_feature_weights
@@ -160,18 +159,19 @@ def generate_sparse_coo_inputs_for_feature_spec(
160159

161160
for feature_spec in feature_specs:
162161
table_spec = feature_spec.table_spec
163-
indices_tensors = []
164-
values_tensors = []
165-
for i in range(num_samples):
166-
num_ids = np.random.randint(1, 32)
167-
for j in range(num_ids):
168-
indices_tensors.append([i, j])
169-
for _ in range(num_ids):
170-
values_tensors.append(np.random.randint(table_spec.vocabulary_size))
171-
all_indices_tensors.append(np.array(indices_tensors, dtype=np.int64))
172-
all_values_tensors.append(np.array(values_tensors, dtype=np.int32))
162+
counts = np.random.randint(1, 32, size=num_samples)
163+
total_ids = np.sum(counts)
164+
values = np.random.randint(
165+
table_spec.vocabulary_size, size=total_ids, dtype=np.int32
166+
)
167+
row_indices = np.repeat(np.arange(num_samples), counts)
168+
col_indices = np.concatenate([np.arange(c) for c in counts])
169+
indices = np.stack([row_indices, col_indices], axis=1)
170+
171+
all_indices_tensors.append(indices.astype(np.int64))
172+
all_values_tensors.append(values)
173173
all_dense_shape_tensors.append(
174-
np.array([num_samples, vocab_size], dtype=np.int64)
174+
np.array([num_samples, np.max(counts)], dtype=np.int64)
175175
)
176176
return all_indices_tensors, all_values_tensors, all_dense_shape_tensors
177177

0 commit comments

Comments
 (0)