
Commit dd7a4b8

[GuideLLM Refactor] Advanced Prefix Cache Controls (#382)
## TODO

- Docs
- ~~CSV arg string support~~ CSV arg string now supports a single bucket (see the last example). Might leave it at that for now.
- More validation

## Summary

This PR is a port of #287 to the v0.4.0 refactor branch. It adds controls for sharing one or more fixed prefixes between samples. See the examples below.

## Details

Adds a `prefix_buckets` argument to `SyntheticTextDatasetConfig`. Each bucket consists of a prefix count, a token count, and a bucket weight: the prefix count sets the number of unique prefixes to generate for the bucket, the token count is the length of each prefix in the bucket, and the bucket weight determines the proportion of requests the bucket applies to, relative to the sum of all bucket weights.

Here are a few examples.

Here we have one bucket of 32 prefixes of length 2048. Since there are 1024 total samples, each prefix applies to 32 samples. If there is only one bucket, the weight can be omitted, since the bucket applies to 100% of samples.

```yaml
data:
  prefix_buckets:
    - prefix_tokens: 2048
      prefix_count: 32
  prompt_tokens: 256
  output_tokens: 256
  samples: 1024
```

In this modified version of the first example, 16 of the prefixes have 2048 tokens while the other 16 have 1024 tokens.

```yaml
data:
  prefix_buckets:
    - prefix_tokens: 2048
      prefix_count: 16
      bucket_weight: 50
    - prefix_tokens: 1024
      prefix_count: 16
      bucket_weight: 50
  prompt_tokens: 256
  output_tokens: 256
  samples: 1024
```

The prefix tokens of a bucket can also be 0 to disable prefixes for those samples. Here is an example where 40% of the samples have a prefix of 2048 tokens while the other 60% have no prefix.

```yaml
data:
  prefix_buckets:
    - prefix_tokens: 2048
      bucket_weight: 40
    - prefix_tokens: 0
      bucket_weight: 60
  prompt_tokens: 256
  output_tokens: 256
  samples: 1000
```

If only a single bucket is needed, it can be set at the top level. This makes the changes backwards compatible with the previous interface and allows the CSV string format to work without parsing nested structures (at least for this use case).

```yaml
data:
  prefix_tokens: 128
  prefix_count: 10
  prompt_tokens: 256
  output_tokens: 256
  samples: 1000
```

## Test Plan

- The PR includes unit tests for all synthetic dataset changes (`pytest tests/unit/dataset`).
- The scenarios in the Details section can be run against a model server with prefix caching enabled; the cache hit rate can be confirmed by inspecting the console output.

## Related Issues

- Resolves #232
- Closes #287

---

- [x] "I certify that all code in this PR is my own, except as noted below."

## Use of AI

- [x] Includes AI-assisted code completion
- [ ] Includes code generated by an AI application
- [x] Includes AI-generated tests (NOTE: AI-written tests should have a docstring that includes `## WRITTEN BY AI ##`)

---------

Signed-off-by: Samuel Monson <smonson@redhat.com>
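To illustrate the backwards-compatible form, here is a minimal sketch (assuming `SyntheticTextDatasetConfig` and `SyntheticTextPrefixBucketConfig` are imported from `guidellm.data.deserializers`, as exported in this PR) showing the flat fields being promoted into an equivalent single-bucket configuration by the model validator:

```python
# Minimal sketch; assumes the classes are importable as exported in this PR.
from guidellm.data.deserializers import (
    SyntheticTextDatasetConfig,
    SyntheticTextPrefixBucketConfig,
)

# Explicit single-bucket form.
explicit = SyntheticTextDatasetConfig(
    prefix_buckets=[
        SyntheticTextPrefixBucketConfig(prefix_tokens=128, prefix_count=10)
    ],
    prompt_tokens=256,
    output_tokens=256,
)

# Backwards-compatible flat form: the model validator promotes the extra
# prefix_tokens/prefix_count fields into an equivalent single bucket.
flat = SyntheticTextDatasetConfig(
    prefix_tokens=128,
    prefix_count=10,
    prompt_tokens=256,
    output_tokens=256,
)

assert flat.prefix_buckets == explicit.prefix_buckets
```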
1 parent 730eeb1 commit dd7a4b8

File tree

11 files changed: +2035 −2020 lines


pylock.toml

Lines changed: 1331 additions & 1134 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 9 additions & 0 deletions
```diff
@@ -13,6 +13,13 @@ include = ["*"]
 [tool.pdm]
 distribution = true
 
+[[tool.pdm.source]]
+name = "torch"
+type = "find_links"
+#url = "https://download.pytorch.org/whl/cpu/torch_stable.html"
+url = "https://download.pytorch.org/whl/cpu/torch/"
+include_packages = ["torch"]
+
 
 # ************************************************
 # ********** Project Metadata **********
@@ -64,6 +71,8 @@ dependencies = [
     "sanic",
     "transformers",
     "uvloop>=0.18",
+    "librosa>=0.11.0",
+    "torch>=2.8.0",
 ]
 
 [project.optional-dependencies]
```

src/guidellm/data/deserializers/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@
     SyntheticTextDatasetConfig,
     SyntheticTextDatasetDeserializer,
     SyntheticTextGenerator,
+    SyntheticTextPrefixBucketConfig,
 )
 
 __all__ = [
@@ -46,6 +47,7 @@
     "SyntheticTextDatasetConfig",
     "SyntheticTextDatasetDeserializer",
     "SyntheticTextGenerator",
+    "SyntheticTextPrefixBucketConfig",
     "TarFileDatasetDeserializer",
     "TextFileDatasetDeserializer",
 ]
```

src/guidellm/data/deserializers/synthetic.py

Lines changed: 91 additions & 5 deletions
```diff
@@ -1,13 +1,15 @@
 from __future__ import annotations
 
+import math
 from collections.abc import Iterator
 from pathlib import Path
-from typing import Any, Callable
+from random import Random
+from typing import Any, Callable, Self
 
 import yaml
 from datasets import Features, IterableDataset, Value
 from faker import Faker
-from pydantic import Field
+from pydantic import ConfigDict, Field, model_validator
 from transformers import PreTrainedTokenizerBase
 
 from guidellm.data.deserializers.deserializer import (
@@ -21,10 +23,37 @@
     "SyntheticTextDatasetConfig",
     "SyntheticTextDatasetDeserializer",
     "SyntheticTextGenerator",
+    "SyntheticTextPrefixBucketConfig",
 ]
 
 
+class SyntheticTextPrefixBucketConfig(StandardBaseModel):
+    bucket_weight: int = Field(
+        description="Weight of this bucket in the overall distribution.",
+        gt=0,
+        default=100,
+    )
+    prefix_count: int = Field(
+        description="The number of unique prefixes to generate for this bucket.",
+        ge=1,
+        default=1,
+    )
+    prefix_tokens: int = Field(
+        description="The number of prefix tokens per-prompt for this bucket.",
+        ge=0,
+        default=0,
+    )
+
+
 class SyntheticTextDatasetConfig(StandardBaseModel):
+    model_config = ConfigDict(
+        extra="allow",
+    )
+
+    prefix_buckets: list[SyntheticTextPrefixBucketConfig] | None = Field(
+        description="Buckets for the prefix tokens distribution.",
+        default=None,
+    )
     prompt_tokens: int = Field(
         description="The average number of text tokens generated for prompts.",
         gt=0,
@@ -68,6 +97,26 @@ class SyntheticTextDatasetConfig(StandardBaseModel):
         default="data:prideandprejudice.txt.gz",
     )
 
+    @model_validator(mode="after")
+    def check_prefix_options(self) -> Self:
+        prefix_count = self.__pydantic_extra__.get("prefix_count", None)  # type: ignore[attr-defined]
+        prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None)  # type: ignore[attr-defined]
+        if prefix_count is not None or prefix_tokens is not None:
+            if self.prefix_buckets:
+                raise ValueError(
+                    "prefix_buckets is mutually exclusive"
+                    " with prefix_count and prefix_tokens"
+                )
+
+            self.prefix_buckets = [
+                SyntheticTextPrefixBucketConfig(
+                    prefix_count=prefix_count or 1,
+                    prefix_tokens=prefix_tokens or 0,
+                )
+            ]
+
+        return self
+
 
 class SyntheticTextGenerator:
     def __init__(
@@ -104,20 +153,27 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
             )
         )
 
+        # Create a shared prefix if specified
+        rand = Random(self.random_seed + 3)
+        prefix_iter = self._create_prefix_iter(faker, rand)
+
         while True:
             prompt_tokens_count = next(prompt_tokens_sampler)
             output_tokens_count = next(output_tokens_sampler)
 
             yield {
+                "prefix": next(prefix_iter),
                 "prompt": self._create_prompt(
-                    prompt_tokens_count, samples_generated, faker
+                    prompt_tokens_count, faker, f"{samples_generated} "
                 ),
                 "prompt_tokens_count": prompt_tokens_count,
                 "output_tokens_count": output_tokens_count,
             }
             samples_generated += 1
 
-    def _create_prompt(self, prompt_tokens_count: int, index: int, faker: Faker) -> str:
+    def _create_prompt(
+        self, prompt_tokens_count: int, faker: Faker, unique: str = ""
+    ) -> str:
         prompt_token_ids = []
         avg_chars_per_token = 5
         margin_of_safety = 1.5
@@ -128,13 +184,42 @@ def _create_prompt(self, prompt_tokens_count: int, index: int, faker: Faker) ->
             num_chars = (
                 prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts
             )
-            text = f"{index} " + faker.text(max_nb_chars=num_chars)
+            text = unique + faker.text(max_nb_chars=num_chars)
             prompt_token_ids = self.processor.encode(text)
 
         return self.processor.decode(
             prompt_token_ids[:prompt_tokens_count], skip_special_tokens=True
        )
 
+    def _create_prefix_iter(self, faker: Faker, rand: Random) -> Iterator[str]:
+        if not self.config.prefix_buckets:
+            while True:
+                yield ""
+
+        # Increase weights to ensure an integer number of samples per prefix
+        least_common_prefix_count = math.lcm(
+            *(bucket.prefix_count for bucket in self.config.prefix_buckets)
+        )
+        unnorm_weights = [
+            least_common_prefix_count * bucket.bucket_weight // bucket.prefix_count
+            for bucket in self.config.prefix_buckets
+        ]
+        # Use GCD to reduce the weights to the smallest integer ratio
+        common_divisor = math.gcd(*unnorm_weights)
+
+        # Create prefix list maintaining the correct distribution
+        prefixes = []
+        for bucket, weight in zip(self.config.prefix_buckets, unnorm_weights):
+            bucket_prefixes = [
+                self._create_prompt(bucket.prefix_tokens, faker)
+                for _ in range(bucket.prefix_count)
+            ]
+            sample_count = weight // common_divisor
+            prefixes.extend(bucket_prefixes * sample_count)
+
+        while True:
+            yield rand.choice(prefixes)
+
 
 @DatasetDeserializerFactory.register("synthetic_text")
 class SyntheticTextDatasetDeserializer(DatasetDeserializer):
@@ -166,6 +251,7 @@ def __call__(
             ),
             features=Features(
                 {
+                    "prefix": Value("string"),
                     "prompt": Value("string"),
                     "prompt_tokens_count": Value("int32"),
                     "output_tokens_count": Value("int32"),
```

src/guidellm/data/formatters/templates.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -22,11 +22,7 @@ class JinjaTemplatesRegistry(RegistryMixin[Union[Template, str]]):
         textwrap.dedent("""
             {% set obj = {
                 "json_body": {
-                    "prompt": (
-                        text_column[0]
-                        if text_column and text_column|length == 1
-                        else text_column
-                    )
+                    "prompt": prefix_column[0]|default("") + text_column[0]
                 }
             } %}
@@ -52,6 +48,10 @@ class JinjaTemplatesRegistry(RegistryMixin[Union[Template, str]]):
             {% set obj = {
                 "json_body": {
                     "messages": [
+                        {
+                            "role": "system",
+                            "content": prefix_column[0]|default("")
+                        },
                         {
                             "role": "user",
                             "content": []
@@ -61,11 +61,11 @@ class JinjaTemplatesRegistry(RegistryMixin[Union[Template, str]]):
             } %}
 
             {%- for item in text_column or [] %}
-            {% do obj["json_body"].messages[0].content.append({"type": "text", "text": item}) %}
+            {% do obj["json_body"].messages[1].content.append({"type": "text", "text": item}) %}
             {%- endfor %}
 
             {%- for item in image_column or [] %}
-            {% do obj["json_body"].messages[0].content.append({
+            {% do obj["json_body"].messages[1].content.append({
                 "type": "image_url",
                 "image_url": encode_image(
                     item,
@@ -78,7 +78,7 @@ class JinjaTemplatesRegistry(RegistryMixin[Union[Template, str]]):
             {%- endfor %}
 
             {%- for item in video_column or [] %}
-            {% do obj["json_body"].messages[0].content.append({
+            {% do obj["json_body"].messages[1].content.append({
                 "type": "video_url",
                 "video_url": encode_video(
                     item,
```
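For reference, the updated chat template now renders a request body shaped roughly like the sketch below (a hand-written illustration of the expected output, not produced by running the template; placeholder strings stand in for generated content):

```python
# Illustrative shape of the rendered chat request body after this change.
# Hand-written; placeholder strings stand in for generated content.
json_body = {
    "messages": [
        # messages[0] now always carries the (possibly empty) prefix.
        {"role": "system", "content": "<prefix text...>"},
        # User content moved to messages[1], hence the index bump in the
        # text/image/video append loops above.
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "<prompt text...>"},
            ],
        },
    ],
}
```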

src/guidellm/data/objects.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -31,6 +31,7 @@
 GenerativeDatasetColumnType = Literal[
     "prompt_tokens_count_column",
     "output_tokens_count_column",
+    "prefix_column",
     "text_column",
     "image_column",
     "video_column",
@@ -195,6 +196,7 @@ class GenerativeDatasetArgs(StandardBaseDict):
     split: str | None = None
     prompt_tokens_count_column: str | None = None
     output_tokens_count_column: str | None = None
+    prefix_column: str | None = None
     text_column: str | list[str] | None = None
     image_column: str | list[str] | None = None
     video_column: str | list[str] | None = None
```
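A quick usage sketch for the new column override (assuming `GenerativeDatasetArgs` is constructed directly, as its fields suggest; the column names here are hypothetical):

```python
from guidellm.data.objects import GenerativeDatasetArgs

# Map the prefix to a custom column when a dataset does not use one of
# the default aliases (system_prompt / system / prefix).
args = GenerativeDatasetArgs(
    split="train",
    prefix_column="shared_context",  # hypothetical column name
    text_column="question",          # hypothetical column name
)
```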

src/guidellm/data/utils.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -80,6 +80,11 @@
 DEFAULT_COLUMN_NAMES: dict[str, list[str]] = {
     "prompt_tokens_count": ["prompt_tokens_count", "input_tokens_count"],
     "output_tokens_count": ["output_tokens_count", "completion_tokens_count"],
+    "prefix_column": [
+        "system_prompt",
+        "system",
+        "prefix",
+    ],
     "text_column": [
         "prompt",
         "instruction",
```
File renamed without changes.

tests/unit/data/deserializers/__init__.py

Whitespace-only changes.
