
Commit 81ee731

Reimplement advanced prefix control
Signed-off-by: Samuel Monson <smonson@redhat.com>
1 parent: a635030

2 files changed: 67 additions, 8 deletions

src/guidellm/data/deserializers/__init__.py (2 additions, 0 deletions)

@@ -25,6 +25,7 @@
     SyntheticTextDatasetConfig,
     SyntheticTextDatasetDeserializer,
     SyntheticTextGenerator,
+    SyntheticTextPrefixBucketConfig,
 )

 __all__ = [
@@ -46,6 +47,7 @@
     "SyntheticTextDatasetConfig",
     "SyntheticTextDatasetDeserializer",
     "SyntheticTextGenerator",
+    "SyntheticTextPrefixBucketConfig",
     "TarFileDatasetDeserializer",
     "TextFileDatasetDeserializer",
 ]

src/guidellm/data/deserializers/synthetic.py (65 additions, 8 deletions)

@@ -1,8 +1,10 @@
 from __future__ import annotations

 from collections.abc import Iterator
+from math import gcd
 from pathlib import Path
-from typing import Any, Callable
+from random import Random
+from typing import Any, Callable, ClassVar

 import yaml
 from datasets import Features, IterableDataset, Value
@@ -21,15 +23,33 @@
     "SyntheticTextDatasetConfig",
     "SyntheticTextDatasetDeserializer",
     "SyntheticTextGenerator",
+    "SyntheticTextPrefixBucketConfig",
 ]


-class SyntheticTextDatasetConfig(StandardBaseModel):
+class SyntheticTextPrefixBucketConfig(StandardBaseModel):
+    bucket_weight: int = Field(
+        description="Weight of this bucket in the overall distribution.",
+        gt=0,
+        default=100,
+    )
+    prefix_count: int = Field(
+        description="The number of unique prefixs to generate for this bucket.",
+        ge=1,
+        default=1,
+    )
     prefix_tokens: int = Field(
-        description="The number of shared prefix tokens to prepend to each prompt.",
+        description="The number of prefix tokens per-prompt for this bucket.",
         ge=0,
         default=0,
     )
+
+
+class SyntheticTextDatasetConfig(StandardBaseModel):
+    prefix_buckets: list[SyntheticTextPrefixBucketConfig] | None = Field(
+        description="Buckets for the prefix tokens distribution.",
+        default=None,
+    )
     prompt_tokens: int = Field(
         description="The average number of text tokens generated for prompts.",
         gt=0,
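
As a rough usage sketch (not part of this commit), the bucket-based configuration added above could be built in Python as follows. The class and field names come straight from the diff; `output_tokens` is assumed to be a pre-existing field of `SyntheticTextDatasetConfig` that this diff does not touch, and all values are illustrative.

```python
# Illustrative sketch only; values are made up and output_tokens is an
# assumed pre-existing field that this diff does not modify.
from guidellm.data.deserializers import (
    SyntheticTextDatasetConfig,
    SyntheticTextPrefixBucketConfig,
)

config = SyntheticTextDatasetConfig(
    prefix_buckets=[
        # Roughly 3/4 of prompts share a single long prefix.
        SyntheticTextPrefixBucketConfig(
            bucket_weight=75, prefix_count=1, prefix_tokens=2048
        ),
        # The remaining 1/4 draw from four distinct short prefixes.
        SyntheticTextPrefixBucketConfig(
            bucket_weight=25, prefix_count=4, prefix_tokens=128
        ),
    ],
    prompt_tokens=256,
    output_tokens=128,
)
```

Leaving `prefix_buckets` unset keeps the previous behavior of an empty prefix for every sample, as the generator changes below show.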
@@ -75,6 +95,8 @@ class SyntheticTextDatasetConfig(StandardBaseModel):


 class SyntheticTextGenerator:
+    PREFIX_DISTRIBUTION_PRECISION: ClassVar[int] = 1000
+
     def __init__(
         self,
         config: SyntheticTextDatasetConfig,
@@ -110,17 +132,15 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
         )

         # Create a shared prefix if specified
-        if self.config.prefix_tokens > 0:
-            prefix = self._create_prompt(self.config.prefix_tokens, faker)
-        else:
-            prefix = ""  # Always have a prefix key for consistency
+        rand = Random(self.random_seed + 3)
+        prefix_iter = self._create_prefix_iter(faker, rand)

         while True:
             prompt_tokens_count = next(prompt_tokens_sampler)
             output_tokens_count = next(output_tokens_sampler)

             yield {
-                "prefix": prefix,
+                "prefix": next(prefix_iter),
                 "prompt": self._create_prompt(
                     prompt_tokens_count, faker, f"{samples_generated} "
                 ),
@@ -149,6 +169,43 @@ def _create_prompt(
             prompt_token_ids[:prompt_tokens_count], skip_special_tokens=True
         )

+    def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
+        if not self.config.prefix_buckets:
+            while True:
+                yield ""
+
+        total_weight = sum(
+            bucket.bucket_weight for bucket in self.config.prefix_buckets
+        )
+        if total_weight <= 0:
+            raise ValueError("Total weight of prefix buckets must be greater than 0.")
+
+        # Calculate the divisor needed to achieve the minimum
+        # number of prompts given the weight ratios
+        percents = [
+            int(
+                self.PREFIX_DISTRIBUTION_PRECISION
+                * bucket.bucket_weight
+                / bucket.prefix_count
+                / total_weight
+            )
+            for bucket in self.config.prefix_buckets
+        ]
+        common_divisor = gcd(*percents)
+
+        # Create prefix list maintaining the correct distribution
+        prefixes = []
+        for bucket, percent in zip(self.config.prefix_buckets, percents):
+            bucket_prefixes = [
+                self._create_prompt(bucket.prefix_tokens, faker)
+                for _ in range(bucket.prefix_count)
+            ]
+            sample_count = percent // common_divisor
+            prefixes.extend([bucket_prefixes] * sample_count)
+
+        while True:
+            yield rand.choice(prefixes)
+

 @DatasetDeserializerFactory.register("synthetic_text")
 class SyntheticTextDatasetDeserializer(DatasetDeserializer):
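
For intuition, the pool-sizing arithmetic in `_create_prefix_iter` can be reproduced standalone: each bucket's weight is scaled to an integer per-prefix share, and the gcd of those shares gives the smallest repetition counts that preserve the requested ratios. A minimal sketch with hypothetical buckets (it mirrors the math above but does not import the committed code):

```python
# Standalone illustration of the pool-sizing math in _create_prefix_iter.
# The buckets below are hypothetical; PRECISION mirrors PREFIX_DISTRIBUTION_PRECISION.
from math import gcd

PRECISION = 1000
buckets = [
    {"bucket_weight": 75, "prefix_count": 1},  # one heavily shared prefix
    {"bucket_weight": 25, "prefix_count": 4},  # four lighter-weight prefixes
]

total_weight = sum(b["bucket_weight"] for b in buckets)

# Integer share per prefix in each bucket, scaled by PRECISION.
percents = [
    int(PRECISION * b["bucket_weight"] / b["prefix_count"] / total_weight)
    for b in buckets
]                                  # -> [750, 62]
common_divisor = gcd(*percents)    # -> gcd(750, 62) = 2

# Repetition count for each bucket's prefixes in the sampling pool.
repeats = [p // common_divisor for p in percents]  # -> [375, 31]
print(percents, common_divisor, repeats)
```

A larger `PREFIX_DISTRIBUTION_PRECISION` makes the integer rounding of the shares finer, at the cost of a potentially larger prefix pool.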
