From f91aa9d23be53e2fc89e19cf7f0b19695201e5d7 Mon Sep 17 00:00:00 2001 From: Aditya Gupta Date: Tue, 25 Nov 2025 13:47:05 -0800 Subject: [PATCH] [JAX SC] Add support for loading `EmbeddingSpecProto` for benchmarking. * Added combiner to proto to load as well. * Vectorized numpy ragged input array generation (~16x speedup) to avoid wasting time (63%) in input generation for benchmark (to 10%). PiperOrigin-RevId: 836798575 --- .../sparsecore/lib/nn/embedding.py | 15 ++ .../sparsecore/lib/nn/tests/BUILD | 2 + .../nn/tests/preprocess_input_benchmarks.py | 149 ++++++++++++++---- .../sparsecore/lib/proto/embedding_spec.proto | 11 ++ 4 files changed, 146 insertions(+), 31 deletions(-) diff --git a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py index 7e3fc75a..51ea197e 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/embedding.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/embedding.py @@ -1353,6 +1353,17 @@ def init_embedding_variables( return stacked_mapping +def _combiner_to_enum(combiner: str): + if combiner == "sum": + return embedding_spec_pb2.SUM + elif combiner == "mean": + return embedding_spec_pb2.MEAN + elif combiner == "sqrtn": + return embedding_spec_pb2.SQRTN + else: + return embedding_spec_pb2.COMBINER_UNSPECIFIED + + def create_proto_from_feature_specs( feature_specs: Nested[embedding_spec.FeatureSpec], global_device_count: int | None, @@ -1394,6 +1405,9 @@ def create_proto_from_feature_specs( max_ids_per_partition=feature.table_spec.stacked_table_spec.max_ids_per_partition, num_sparsecores=(num_sparsecore_per_device * global_device_count), max_unique_ids_per_partition=feature.table_spec.stacked_table_spec.max_unique_ids_per_partition, + combiner=_combiner_to_enum( + feature.table_spec.stacked_table_spec.combiner + ), ) ) if current_table_name not in stack_to_table_specs[current_stack_name]: @@ -1406,6 +1420,7 @@ def create_proto_from_feature_specs( padded_embedding_dim=feature.table_spec.setting_in_stack.padded_embedding_dim, row_offset_in_shard=feature.table_spec.setting_in_stack.row_offset_in_shard, shard_rotation=feature.table_spec.setting_in_stack.shard_rotation, + combiner=_combiner_to_enum(feature.table_spec.combiner), ) ) feature_spec = embedding_spec_pb2.FeatureSpecProto( diff --git a/jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD b/jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD index 2960a8b6..9d6f50d5 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD +++ b/jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD @@ -45,6 +45,7 @@ pytype_strict_binary( "//jax_tpu_embedding/sparsecore/lib/core:pybind_input_preprocessing", "//jax_tpu_embedding/sparsecore/lib/nn:embedding", "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec", + "//jax_tpu_embedding/sparsecore/lib/proto:embedding_spec_py_pb2", pypi_requirement("absl:app"), pypi_requirement("absl/flags"), pypi_requirement("google_benchmark"), @@ -71,6 +72,7 @@ pytype_strict_contrib_test( "//jax_tpu_embedding/sparsecore/lib/core:pybind_input_preprocessing", "//jax_tpu_embedding/sparsecore/lib/nn:embedding", "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec", + "//jax_tpu_embedding/sparsecore/lib/proto:embedding_spec_py_pb2", pypi_requirement("absl:app"), pypi_requirement("absl/flags"), pypi_requirement("google_benchmark"), diff --git a/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py b/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py index b7ac4e53..a1143ef4 100644 --- 
a/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py @@ -54,10 +54,16 @@ from jax_tpu_embedding.sparsecore.lib.core import pybind_input_preprocessing from jax_tpu_embedding.sparsecore.lib.nn import embedding from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec +from jax_tpu_embedding.sparsecore.lib.proto import embedding_spec_pb2 import numpy as np import portpicker +_EMBEDDING_SPEC_PATH = flags.DEFINE_string( + "embedding_spec_path", + None, + "Path to EmbeddingSpecProto to load table parameters from.", +) _NUM_FEATURES = flags.DEFINE_integer("num_features", 100, "Number of features.") _NUM_SAMPLES = flags.DEFINE_integer("num_samples", 16000, "Number of samples.") @@ -71,6 +77,79 @@ _GLOBAL_RAGGED_DENSE_SHAPES = None +def _combiner_enum_to_string(combiner_enum): + if combiner_enum == embedding_spec_pb2.SUM: + return "sum" + elif combiner_enum == embedding_spec_pb2.MEAN: + return "mean" + elif combiner_enum == embedding_spec_pb2.SQRTN: + return "sqrtn" + else: # COMBINER_UNSPECIFIED or any other value + return "sum" + + +def load_feature_specs_from_proto( + path: str, +) -> list[embedding_spec.FeatureSpec]: + """Loads feature specs from EmbeddingSpecProto.""" + with open(path, "rb") as f: + proto = embedding_spec_pb2.EmbeddingSpecProto.FromString(f.read()) + feature_specs = [] + for stacked_table_pb in proto.stacked_table_specs: + stacked_table_spec = embedding_spec.StackedTableSpec( + stack_name=stacked_table_pb.stack_name, + stack_vocab_size=stacked_table_pb.stack_vocab_size, + stack_embedding_dim=stacked_table_pb.stack_embedding_dim, + optimizer=embedding_spec.SGDOptimizerSpec( + learning_rate=0.001, + ), + combiner=_combiner_enum_to_string(stacked_table_pb.combiner), + total_sample_count=stacked_table_pb.total_sample_count, + max_ids_per_partition=stacked_table_pb.max_ids_per_partition, + max_unique_ids_per_partition=stacked_table_pb.max_unique_ids_per_partition, + ) + for table_pb in stacked_table_pb.table_specs: + setting_in_stack = embedding_spec.TableSettingInStack( + stack_name=stacked_table_pb.stack_name, + padded_vocab_size=table_pb.padded_vocab_size, + padded_embedding_dim=table_pb.padded_embedding_dim, + row_offset_in_shard=table_pb.row_offset_in_shard, + shard_rotation=table_pb.shard_rotation, + ) + table_spec = embedding_spec.TableSpec( + vocabulary_size=table_pb.vocab_size, + embedding_dim=table_pb.embedding_dim, + initializer=( + lambda sz=table_pb.vocab_size, dim=table_pb.embedding_dim: np.zeros( + (sz, dim), dtype=np.float32 + ) + ), + optimizer=embedding_spec.SGDOptimizerSpec( + learning_rate=0.001, + ), + combiner=_combiner_enum_to_string(table_pb.combiner), + name=table_pb.table_name, + max_ids_per_partition=stacked_table_pb.max_ids_per_partition, + max_unique_ids_per_partition=stacked_table_pb.max_unique_ids_per_partition, + _setting_in_stack=setting_in_stack, + _stacked_table_spec=stacked_table_spec, + ) + for feature_pb in table_pb.feature_specs: + feature_spec = embedding_spec.FeatureSpec( + table_spec=table_spec, + input_shape=list(feature_pb.input_shape), + output_shape=list(feature_pb.output_shape), + name=feature_pb.feature_name, + _id_transformation=embedding_spec.FeatureIdTransformation( + row_offset=feature_pb.row_offset, + col_offset=feature_pb.col_offset, + col_shift=feature_pb.col_shift, + ), + ) + feature_specs.append(feature_spec) + return feature_specs + + def generate_feature_specs(num_features: int, num_samples: int): """Generates feature 
specs for the given number of features.""" feature_specs = [] @@ -134,17 +213,16 @@ def generate_samples_for_feature_spec(feature_specs, num_samples, ragged=False): all_features.append(features) all_feature_weights.append(feature_weights) else: - features = [] - feature_weights = [] - for _ in range(num_samples): - num_ids = np.random.randint(1, 32) - ids = np.random.randint( - table_spec.vocabulary_size, - size=(num_ids,), - dtype=np.int32, - ) - features.append(ids) - feature_weights.append(np.ones((num_ids,), dtype=np.float32)) + counts = np.random.randint(1, 32, size=num_samples) + total_ids = np.sum(counts) + ids_flat = np.random.randint( + table_spec.vocabulary_size, + size=(total_ids,), + dtype=np.int32, + ) + split_indices = np.cumsum(counts)[:-1] + features = np.split(ids_flat, split_indices) + feature_weights = [np.ones((c,), dtype=np.float32) for c in counts] all_features.append(np.array(features, dtype=object)) all_feature_weights.append(np.array(feature_weights, dtype=object)) return all_features, all_feature_weights @@ -160,18 +238,19 @@ def generate_sparse_coo_inputs_for_feature_spec( for feature_spec in feature_specs: table_spec = feature_spec.table_spec - indices_tensors = [] - values_tensors = [] - for i in range(num_samples): - num_ids = np.random.randint(1, 32) - for j in range(num_ids): - indices_tensors.append([i, j]) - for _ in range(num_ids): - values_tensors.append(np.random.randint(table_spec.vocabulary_size)) - all_indices_tensors.append(np.array(indices_tensors, dtype=np.int64)) - all_values_tensors.append(np.array(values_tensors, dtype=np.int32)) + counts = np.random.randint(1, 32, size=num_samples) + total_ids = np.sum(counts) + values = np.random.randint( + table_spec.vocabulary_size, size=total_ids, dtype=np.int32 + ) + row_indices = np.repeat(np.arange(num_samples), counts) + col_indices = np.concatenate([np.arange(c) for c in counts]) + indices = np.stack([row_indices, col_indices], axis=1) + + all_indices_tensors.append(indices.astype(np.int64)) + all_values_tensors.append(values) all_dense_shape_tensors.append( - np.array([num_samples, vocab_size], dtype=np.int64) + np.array([num_samples, np.max(counts)], dtype=np.int64) ) return all_indices_tensors, all_values_tensors, all_dense_shape_tensors @@ -307,7 +386,9 @@ def worker(host_id: int, batch_number: int): enable_minibatching=True, all_reduce_interface=all_reduce_interfaces[host_id], batch_number=batch_number, - allow_id_dropping=batch_number == 0, + # Loaded specs may drop ids even with minibatching. + allow_id_dropping=_EMBEDDING_SPEC_PATH.value is not None + or batch_number == 0, ) with concurrent.futures.ThreadPoolExecutor(max_workers=num_hosts) as executor: @@ -333,12 +414,14 @@ def worker(host_id: int, batch_number: int): return if batch_num == 0: assert stats_cc is not None - apply_fdo_stats(stats_cc, fdo_headroom=0.5, buffer_size_headroom=1.0) + # Force minibatch creation, and provide enough buffer space. + apply_fdo_stats(stats_cc, fdo_headroom=0.5, buffer_size_headroom=3.0) state.resume_timing() else: - # Make sure we are benchmarking multiple minibatches (guranteed by - # initial seed). - assert num_minibatches >= 5, num_minibatches + if not _EMBEDDING_SPEC_PATH.value: + # Make sure we are benchmarking multiple minibatches (guranteed by + # initial seed). + assert num_minibatches >= 5, num_minibatches batch_num += 1 @@ -352,11 +435,15 @@ def main(_): global _GLOBAL_RAGGED_VALUES global _GLOBAL_RAGGED_DENSE_SHAPES - # Total local batch size that is measured is 16000x100 = 1,600,000. 
np.random.seed(0) - _GLOBAL_SPECS = generate_feature_specs( - _NUM_FEATURES.value, _NUM_SAMPLES.value - ) + if _EMBEDDING_SPEC_PATH.value: + _GLOBAL_SPECS = load_feature_specs_from_proto(_EMBEDDING_SPEC_PATH.value) + else: + # Total local batch size that is measured is 16000 samples x 100 features = + # 1,600,000. + _GLOBAL_SPECS = generate_feature_specs( + _NUM_FEATURES.value, _NUM_SAMPLES.value + ) _GLOBAL_RAGGED_FEATURES, _GLOBAL_RAGGED_WEIGHTS = ( generate_samples_for_feature_spec( _GLOBAL_SPECS, _NUM_SAMPLES.value, ragged=True diff --git a/jax_tpu_embedding/sparsecore/lib/proto/embedding_spec.proto b/jax_tpu_embedding/sparsecore/lib/proto/embedding_spec.proto index 5381e093..d1f0b5cd 100644 --- a/jax_tpu_embedding/sparsecore/lib/proto/embedding_spec.proto +++ b/jax_tpu_embedding/sparsecore/lib/proto/embedding_spec.proto @@ -35,6 +35,13 @@ message FeatureSpecProto { int64 col_shift = 6; } +enum Combiner { + COMBINER_UNSPECIFIED = 0; + SUM = 1; + MEAN = 2; + SQRTN = 3; +} + message TableSpecProto { // The name of the table. string table_name = 1; @@ -66,6 +73,8 @@ message TableSpecProto { int64 shard_rotation = 10; // The list of features that point to this table. repeated FeatureSpecProto feature_specs = 11; + // Combiner strategy for the table. + Combiner combiner = 12; } message StackedTableSpecProto { @@ -93,6 +102,8 @@ message StackedTableSpecProto { int64 num_sparsecores = 9; // Specs for the table that are stacked. repeated TableSpecProto table_specs = 10; + // Combiner strategy for the stack. + Combiner combiner = 11; } message EmbeddingSpecProto {
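
Usage note: to exercise the new --embedding_spec_path flag, an `EmbeddingSpecProto` can be serialized from existing feature specs and handed to the benchmark. A rough sketch follows, assuming the standard protobuf serialization API; the output path is a placeholder, and create_proto_from_feature_specs may take additional arguments (e.g. sparsecores per device) beyond global_device_count.

import jax
from jax_tpu_embedding.sparsecore.lib.nn import embedding

# `feature_specs` is a Nested[embedding_spec.FeatureSpec] built elsewhere.
spec_proto = embedding.create_proto_from_feature_specs(
    feature_specs, global_device_count=jax.device_count()
)

# Placeholder path; the benchmark reads it back with
# EmbeddingSpecProto.FromString().
with open("/tmp/embedding_spec.pb", "wb") as f:
  f.write(spec_proto.SerializeToString())

# preprocess_input_benchmarks --embedding_spec_path=/tmp/embedding_spec.pb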
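
The ~16x speedup cited in the commit message comes from replacing the per-sample Python loop in ragged input generation with one bulk draw of ids followed by np.split at the cumulative counts. A minimal self-contained sketch of the technique (the function name and sizes here are illustrative, not the benchmark's actual entry points):

import numpy as np

def generate_ragged_ids(num_samples, vocab_size):
  # Draw all per-sample counts and all ids in two vectorized calls, then
  # split the flat id array at the cumulative counts to recover the ragged
  # structure (one sub-array of ids per sample).
  counts = np.random.randint(1, 32, size=num_samples)
  ids_flat = np.random.randint(
      vocab_size, size=int(counts.sum()), dtype=np.int32
  )
  return np.split(ids_flat, np.cumsum(counts)[:-1])

np.random.seed(0)
ragged = generate_ragged_ids(num_samples=16000, vocab_size=1_000_000)
assert len(ragged) == 16000

np.split returns views into the flat id array, so the ragged structure is rebuilt without per-sample RNG calls or copies.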