Skip to content

Commit 07da84f

Browse files
committed
Smarter weighing
Signed-off-by: Samuel Monson <smonson@redhat.com>
1 parent 81ee731 commit 07da84f

File tree

1 file changed

+12
-22
lines changed

1 file changed

+12
-22
lines changed

src/guidellm/data/deserializers/synthetic.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

3+
import math
34
from collections.abc import Iterator
4-
from math import gcd
55
from pathlib import Path
66
from random import Random
7-
from typing import Any, Callable, ClassVar
7+
from typing import Any, Callable
88

99
import yaml
1010
from datasets import Features, IterableDataset, Value
@@ -95,8 +95,6 @@ class SyntheticTextDatasetConfig(StandardBaseModel):
9595

9696

9797
class SyntheticTextGenerator:
98-
PREFIX_DISTRIBUTION_PRECISION: ClassVar[int] = 1000
99-
10098
def __init__(
10199
self,
102100
config: SyntheticTextDatasetConfig,
@@ -174,34 +172,26 @@ def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
174172
while True:
175173
yield ""
176174

177-
total_weight = sum(
178-
bucket.bucket_weight for bucket in self.config.prefix_buckets
175+
# Increase weights to ensure an integer number of samples per per-prefix
176+
least_common_prefix_count = math.lcm(
177+
*(bucket.prefix_count for bucket in self.config.prefix_buckets)
179178
)
180-
if total_weight <= 0:
181-
raise ValueError("Total weight of prefix buckets must be greater than 0.")
182-
183-
# Calculate the divisor needed to achieve the minimum
184-
# number of prompts given the weight ratios
185-
percents = [
186-
int(
187-
self.PREFIX_DISTRIBUTION_PRECISION
188-
* bucket.bucket_weight
189-
/ bucket.prefix_count
190-
/ total_weight
191-
)
179+
unnorm_weights = [
180+
least_common_prefix_count * bucket.bucket_weight // bucket.prefix_count
192181
for bucket in self.config.prefix_buckets
193182
]
194-
common_divisor = gcd(*percents)
183+
# Use GCD to reduce the weights to smallest integer ratio
184+
common_divisor = math.gcd(*unnorm_weights)
195185

196186
# Create prefix list maintaining the correct distribution
197187
prefixes = []
198-
for bucket, percent in zip(self.config.prefix_buckets, percents):
188+
for bucket, weight in zip(self.config.prefix_buckets, unnorm_weights):
199189
bucket_prefixes = [
200190
self._create_prompt(bucket.prefix_tokens, faker)
201191
for _ in range(bucket.prefix_count)
202192
]
203-
sample_count = percent // common_divisor
204-
prefixes.extend([bucket_prefixes] * sample_count)
193+
sample_count = weight // common_divisor
194+
prefixes.extend(bucket_prefixes * sample_count)
205195

206196
while True:
207197
yield rand.choice(prefixes)

0 commit comments

Comments
 (0)