|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import math |
3 | 4 | from collections.abc import Iterator
|
4 |
| -from math import gcd |
5 | 5 | from pathlib import Path
|
6 | 6 | from random import Random
|
7 |
| -from typing import Any, Callable, ClassVar |
| 7 | +from typing import Any, Callable |
8 | 8 |
|
9 | 9 | import yaml
|
10 | 10 | from datasets import Features, IterableDataset, Value
|
@@ -95,8 +95,6 @@ class SyntheticTextDatasetConfig(StandardBaseModel):
|
95 | 95 |
|
96 | 96 |
|
97 | 97 | class SyntheticTextGenerator:
|
98 |
| - PREFIX_DISTRIBUTION_PRECISION: ClassVar[int] = 1000 |
99 |
| - |
100 | 98 | def __init__(
|
101 | 99 | self,
|
102 | 100 | config: SyntheticTextDatasetConfig,
|
@@ -174,34 +172,26 @@ def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
|
174 | 172 | while True:
|
175 | 173 | yield ""
|
176 | 174 |
|
177 |
| - total_weight = sum( |
178 |
| - bucket.bucket_weight for bucket in self.config.prefix_buckets |
| 175 | + # Increase weights to ensure an integer number of samples per prefix |
| 176 | + least_common_prefix_count = math.lcm( |
| 177 | + *(bucket.prefix_count for bucket in self.config.prefix_buckets) |
179 | 178 | )
|
180 |
| - if total_weight <= 0: |
181 |
| - raise ValueError("Total weight of prefix buckets must be greater than 0.") |
182 |
| - |
183 |
| - # Calculate the divisor needed to achieve the minimum |
184 |
| - # number of prompts given the weight ratios |
185 |
| - percents = [ |
186 |
| - int( |
187 |
| - self.PREFIX_DISTRIBUTION_PRECISION |
188 |
| - * bucket.bucket_weight |
189 |
| - / bucket.prefix_count |
190 |
| - / total_weight |
191 |
| - ) |
| 179 | + unnorm_weights = [ |
| 180 | + least_common_prefix_count * bucket.bucket_weight // bucket.prefix_count |
192 | 181 | for bucket in self.config.prefix_buckets
|
193 | 182 | ]
|
194 |
| - common_divisor = gcd(*percents) |
| 183 | + # Use GCD to reduce the weights to smallest integer ratio |
| 184 | + common_divisor = math.gcd(*unnorm_weights) |
195 | 185 |
|
196 | 186 | # Create prefix list maintaining the correct distribution
|
197 | 187 | prefixes = []
|
198 |
| - for bucket, percent in zip(self.config.prefix_buckets, percents): |
| 188 | + for bucket, weight in zip(self.config.prefix_buckets, unnorm_weights): |
199 | 189 | bucket_prefixes = [
|
200 | 190 | self._create_prompt(bucket.prefix_tokens, faker)
|
201 | 191 | for _ in range(bucket.prefix_count)
|
202 | 192 | ]
|
203 |
| - sample_count = percent // common_divisor |
204 |
| - prefixes.extend([bucket_prefixes] * sample_count) |
| 193 | + sample_count = weight // common_divisor |
| 194 | + prefixes.extend(bucket_prefixes * sample_count) |
205 | 195 |
|
206 | 196 | while True:
|
207 | 197 | yield rand.choice(prefixes)
|
|
0 commit comments