15 changes: 15 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/nn/embedding.py
@@ -1353,6 +1353,17 @@ def init_embedding_variables(
return stacked_mapping


def _combiner_to_enum(combiner: str):
if combiner == "sum":
return embedding_spec_pb2.SUM
elif combiner == "mean":
return embedding_spec_pb2.MEAN
elif combiner == "sqrtn":
return embedding_spec_pb2.SQRTN
else:
return embedding_spec_pb2.COMBINER_UNSPECIFIED


def create_proto_from_feature_specs(
feature_specs: Nested[embedding_spec.FeatureSpec],
global_device_count: int | None,
@@ -1394,6 +1405,9 @@ def create_proto_from_feature_specs(
max_ids_per_partition=feature.table_spec.stacked_table_spec.max_ids_per_partition,
num_sparsecores=(num_sparsecore_per_device * global_device_count),
max_unique_ids_per_partition=feature.table_spec.stacked_table_spec.max_unique_ids_per_partition,
combiner=_combiner_to_enum(
feature.table_spec.stacked_table_spec.combiner
),
)
)
if current_table_name not in stack_to_table_specs[current_stack_name]:
@@ -1406,6 +1420,7 @@ def create_proto_from_feature_specs(
padded_embedding_dim=feature.table_spec.setting_in_stack.padded_embedding_dim,
row_offset_in_shard=feature.table_spec.setting_in_stack.row_offset_in_shard,
shard_rotation=feature.table_spec.setting_in_stack.shard_rotation,
combiner=_combiner_to_enum(feature.table_spec.combiner),
)
)
feature_spec = embedding_spec_pb2.FeatureSpecProto(
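For context, a minimal sketch of how the combiner-carrying proto produced above could be written to disk for later reuse. It assumes create_proto_from_feature_specs returns an EmbeddingSpecProto and that feature_specs and global_device_count are the only required arguments; the helper name and call pattern below are hypothetical, not part of this change.

from jax_tpu_embedding.sparsecore.lib.nn import embedding


def dump_embedding_spec_proto(feature_specs, path, global_device_count=1):
  """Hypothetical helper: serializes specs so they can be loaded from a file."""
  proto = embedding.create_proto_from_feature_specs(
      feature_specs, global_device_count
  )
  with open(path, "wb") as f:
    f.write(proto.SerializeToString())

The benchmark change further down reads such a file back with EmbeddingSpecProto.FromString.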
2 changes: 2 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD
@@ -45,6 +45,7 @@ pytype_strict_binary(
"//jax_tpu_embedding/sparsecore/lib/core:pybind_input_preprocessing",
"//jax_tpu_embedding/sparsecore/lib/nn:embedding",
"//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec",
"//jax_tpu_embedding/sparsecore/lib/proto:embedding_spec_py_pb2",
pypi_requirement("absl:app"),
pypi_requirement("absl/flags"),
pypi_requirement("google_benchmark"),
@@ -71,6 +72,7 @@ pytype_strict_contrib_test(
"//jax_tpu_embedding/sparsecore/lib/core:pybind_input_preprocessing",
"//jax_tpu_embedding/sparsecore/lib/nn:embedding",
"//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec",
"//jax_tpu_embedding/sparsecore/lib/proto:embedding_spec_py_pb2",
pypi_requirement("absl:app"),
pypi_requirement("absl/flags"),
pypi_requirement("google_benchmark"),
149 changes: 118 additions & 31 deletions jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py
@@ -54,10 +54,16 @@
from jax_tpu_embedding.sparsecore.lib.core import pybind_input_preprocessing
from jax_tpu_embedding.sparsecore.lib.nn import embedding
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
from jax_tpu_embedding.sparsecore.lib.proto import embedding_spec_pb2
import numpy as np
import portpicker


_EMBEDDING_SPEC_PATH = flags.DEFINE_string(
"embedding_spec_path",
None,
"Path to EmbeddingSpecProto to load table parameters from.",
)
_NUM_FEATURES = flags.DEFINE_integer("num_features", 100, "Number of features.")
_NUM_SAMPLES = flags.DEFINE_integer("num_samples", 16000, "Number of samples.")

@@ -71,6 +77,79 @@
_GLOBAL_RAGGED_DENSE_SHAPES = None


def _combiner_enum_to_string(combiner_enum):
if combiner_enum == embedding_spec_pb2.SUM:
return "sum"
elif combiner_enum == embedding_spec_pb2.MEAN:
return "mean"
elif combiner_enum == embedding_spec_pb2.SQRTN:
return "sqrtn"
else: # COMBINER_UNSPECIFIED or any other value
return "sum"


def load_feature_specs_from_proto(
path: str,
) -> list[embedding_spec.FeatureSpec]:
"""Loads feature specs from EmbeddingSpecProto."""
with open(path, "rb") as f:
proto = embedding_spec_pb2.EmbeddingSpecProto.FromString(f.read())
feature_specs = []
for stacked_table_pb in proto.stacked_table_specs:
stacked_table_spec = embedding_spec.StackedTableSpec(
stack_name=stacked_table_pb.stack_name,
stack_vocab_size=stacked_table_pb.stack_vocab_size,
stack_embedding_dim=stacked_table_pb.stack_embedding_dim,
optimizer=embedding_spec.SGDOptimizerSpec(
learning_rate=0.001,
),
combiner=_combiner_enum_to_string(stacked_table_pb.combiner),
total_sample_count=stacked_table_pb.total_sample_count,
max_ids_per_partition=stacked_table_pb.max_ids_per_partition,
max_unique_ids_per_partition=stacked_table_pb.max_unique_ids_per_partition,
)
for table_pb in stacked_table_pb.table_specs:
setting_in_stack = embedding_spec.TableSettingInStack(
stack_name=stacked_table_pb.stack_name,
padded_vocab_size=table_pb.padded_vocab_size,
padded_embedding_dim=table_pb.padded_embedding_dim,
row_offset_in_shard=table_pb.row_offset_in_shard,
shard_rotation=table_pb.shard_rotation,
)
table_spec = embedding_spec.TableSpec(
vocabulary_size=table_pb.vocab_size,
embedding_dim=table_pb.embedding_dim,
initializer=(
lambda sz=table_pb.vocab_size, dim=table_pb.embedding_dim: np.zeros(
(sz, dim), dtype=np.float32
)
),
optimizer=embedding_spec.SGDOptimizerSpec(
learning_rate=0.001,
),
combiner=_combiner_enum_to_string(table_pb.combiner),
name=table_pb.table_name,
max_ids_per_partition=stacked_table_pb.max_ids_per_partition,
max_unique_ids_per_partition=stacked_table_pb.max_unique_ids_per_partition,
_setting_in_stack=setting_in_stack,
_stacked_table_spec=stacked_table_spec,
)
for feature_pb in table_pb.feature_specs:
feature_spec = embedding_spec.FeatureSpec(
table_spec=table_spec,
input_shape=list(feature_pb.input_shape),
output_shape=list(feature_pb.output_shape),
name=feature_pb.feature_name,
_id_transformation=embedding_spec.FeatureIdTransformation(
row_offset=feature_pb.row_offset,
col_offset=feature_pb.col_offset,
col_shift=feature_pb.col_shift,
),
)
feature_specs.append(feature_spec)
return feature_specs


def generate_feature_specs(num_features: int, num_samples: int):
"""Generates feature specs for the given number of features."""
feature_specs = []
@@ -134,17 +213,16 @@ def generate_samples_for_feature_spec(feature_specs, num_samples, ragged=False):
all_features.append(features)
all_feature_weights.append(feature_weights)
else:
features = []
feature_weights = []
for _ in range(num_samples):
num_ids = np.random.randint(1, 32)
ids = np.random.randint(
table_spec.vocabulary_size,
size=(num_ids,),
dtype=np.int32,
)
features.append(ids)
feature_weights.append(np.ones((num_ids,), dtype=np.float32))
counts = np.random.randint(1, 32, size=num_samples)
total_ids = np.sum(counts)
ids_flat = np.random.randint(
table_spec.vocabulary_size,
size=(total_ids,),
dtype=np.int32,
)
split_indices = np.cumsum(counts)[:-1]
features = np.split(ids_flat, split_indices)
feature_weights = [np.ones((c,), dtype=np.float32) for c in counts]
all_features.append(np.array(features, dtype=object))
all_feature_weights.append(np.array(feature_weights, dtype=object))
return all_features, all_feature_weights
@@ -160,18 +238,19 @@ def generate_sparse_coo_inputs_for_feature_spec(

for feature_spec in feature_specs:
table_spec = feature_spec.table_spec
indices_tensors = []
values_tensors = []
for i in range(num_samples):
num_ids = np.random.randint(1, 32)
for j in range(num_ids):
indices_tensors.append([i, j])
for _ in range(num_ids):
values_tensors.append(np.random.randint(table_spec.vocabulary_size))
all_indices_tensors.append(np.array(indices_tensors, dtype=np.int64))
all_values_tensors.append(np.array(values_tensors, dtype=np.int32))
counts = np.random.randint(1, 32, size=num_samples)
total_ids = np.sum(counts)
values = np.random.randint(
table_spec.vocabulary_size, size=total_ids, dtype=np.int32
)
row_indices = np.repeat(np.arange(num_samples), counts)
col_indices = np.concatenate([np.arange(c) for c in counts])
indices = np.stack([row_indices, col_indices], axis=1)

all_indices_tensors.append(indices.astype(np.int64))
all_values_tensors.append(values)
all_dense_shape_tensors.append(
np.array([num_samples, vocab_size], dtype=np.int64)
np.array([num_samples, np.max(counts)], dtype=np.int64)
)
return all_indices_tensors, all_values_tensors, all_dense_shape_tensors

@@ -307,7 +386,9 @@ def worker(host_id: int, batch_number: int):
enable_minibatching=True,
all_reduce_interface=all_reduce_interfaces[host_id],
batch_number=batch_number,
allow_id_dropping=batch_number == 0,
# Loaded specs may drop ids even with minibatching.
allow_id_dropping=_EMBEDDING_SPEC_PATH.value is not None
or batch_number == 0,
)

with concurrent.futures.ThreadPoolExecutor(max_workers=num_hosts) as executor:
@@ -333,12 +414,14 @@ def worker(host_id: int, batch_number: int):
return
if batch_num == 0:
assert stats_cc is not None
apply_fdo_stats(stats_cc, fdo_headroom=0.5, buffer_size_headroom=1.0)
# Force minibatch creation, and provide enough buffer space.
apply_fdo_stats(stats_cc, fdo_headroom=0.5, buffer_size_headroom=3.0)
state.resume_timing()
else:
      # Make sure we are benchmarking multiple minibatches (guaranteed by
      # initial seed).
assert num_minibatches >= 5, num_minibatches
if not _EMBEDDING_SPEC_PATH.value:
        # Make sure we are benchmarking multiple minibatches (guaranteed by
        # initial seed).
assert num_minibatches >= 5, num_minibatches
batch_num += 1


@@ -352,11 +435,15 @@ def main(_):
global _GLOBAL_RAGGED_VALUES
global _GLOBAL_RAGGED_DENSE_SHAPES

# Total local batch size that is measured is 16000x100 = 1,600,000.
np.random.seed(0)
_GLOBAL_SPECS = generate_feature_specs(
_NUM_FEATURES.value, _NUM_SAMPLES.value
)
if _EMBEDDING_SPEC_PATH.value:
_GLOBAL_SPECS = load_feature_specs_from_proto(_EMBEDDING_SPEC_PATH.value)
else:
# Total local batch size that is measured is 16000 samples x 100 features =
# 1,600,000.
_GLOBAL_SPECS = generate_feature_specs(
_NUM_FEATURES.value, _NUM_SAMPLES.value
)
_GLOBAL_RAGGED_FEATURES, _GLOBAL_RAGGED_WEIGHTS = (
generate_samples_for_feature_spec(
_GLOBAL_SPECS, _NUM_SAMPLES.value, ragged=True
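The sample-generation rewrite above replaces per-sample Python loops with a single vectorized draw; a self-contained sketch of the same pattern, with illustrative sizes, follows. The sparse-COO path builds its indices analogously with np.repeat and a per-row np.arange.

import numpy as np

np.random.seed(0)
num_samples, vocab_size = 4, 1000  # illustrative values only

# One id count per sample, one flat draw of all ids, then split back into
# per-sample arrays at the cumulative counts.
counts = np.random.randint(1, 32, size=num_samples)
ids_flat = np.random.randint(vocab_size, size=counts.sum(), dtype=np.int32)
features = np.split(ids_flat, np.cumsum(counts)[:-1])
feature_weights = [np.ones(c, dtype=np.float32) for c in counts]

assert [len(f) for f in features] == list(counts)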
11 changes: 11 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/proto/embedding_spec.proto
@@ -35,6 +35,13 @@ message FeatureSpecProto {
int64 col_shift = 6;
}

enum Combiner {
COMBINER_UNSPECIFIED = 0;
SUM = 1;
MEAN = 2;
SQRTN = 3;
}

message TableSpecProto {
// The name of the table.
string table_name = 1;
@@ -66,6 +73,8 @@ message TableSpecProto {
int64 shard_rotation = 10;
// The list of features that point to this table.
repeated FeatureSpecProto feature_specs = 11;
// Combiner strategy for the table.
Combiner combiner = 12;
}

message StackedTableSpecProto {
@@ -93,6 +102,8 @@ message StackedTableSpecProto {
int64 num_sparsecores = 9;
// Specs for the table that are stacked.
repeated TableSpecProto table_specs = 10;
// Combiner strategy for the stack.
Combiner combiner = 11;
}

message EmbeddingSpecProto {
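As a quick illustration of the generated Python bindings for the new field (the table name below is made up), the combiner can be set from the file-level enum values, and an unset field reads back as COMBINER_UNSPECIFIED, which the benchmark loader above treats as "sum".

from jax_tpu_embedding.sparsecore.lib.proto import embedding_spec_pb2

table = embedding_spec_pb2.TableSpecProto(
    table_name="table_0",  # illustrative
    combiner=embedding_spec_pb2.MEAN,
)
assert table.combiner == embedding_spec_pb2.MEAN

# proto3 default for the newly added field.
assert (
    embedding_spec_pb2.TableSpecProto().combiner
    == embedding_spec_pb2.COMBINER_UNSPECIFIED
)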