74 changes: 44 additions & 30 deletions examples/multimodal_vision/gemma3_example.py
@@ -1,55 +1,69 @@
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from transformers import (
    AutoProcessor,
    DataCollatorWithPadding,
    Gemma3ForConditionalGeneration,
)

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.utils import dispatch_for_generation

# Load model.
model_id = "google/gemma-3-4b-it"
model = Gemma3ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
collator = DataCollatorWithPadding(processor.tokenizer)

# Oneshot arguments
DATASET_ID = "flickr30k"
DATASET_SPLIT = {"calibration": "test[:512]"}
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048
BATCH_SIZE = 512
DATASET_ID = "flickr30k"
DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}


# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}
# Define a oneshot data collator for multimodal processors
# remove extra dim added by vision processor
def data_collator(features: list[dict[str, object]]):
    features = [{key: feature[key][0] for key in feature} for feature in features]
    return collator(features)


# Recipe
recipe = [
    GPTQModifier(
        targets="Linear",
        scheme="W4A16",
        ignore=[
            "lm_head",
            r"re:model\.vision_tower.*",
            r"re:model\.multi_modal_projector.*",
        ],
    ),
    # GPTQModifier(
    #     targets="Linear",
    #     scheme="W4A16",
    #     ignore=[
    #         "lm_head",
    #         r"re:model\.vision_tower.*",
    #         r"re:model\.multi_modal_projector.*",
    #     ],
    # ),
]

# Perform oneshot
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
)
from pttp import TensorProfiler

with TensorProfiler() as prof:
    # Perform oneshot
    oneshot(
        model=model,
        tokenizer=model_id,
        dataset=DATASET_ID,
        splits=DATASET_SPLIT,
        recipe=recipe,
        batch_size=BATCH_SIZE,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
        data_collator=data_collator,
        trust_remote_code_model=True,
        pipeline="sequential",
    )
import torch
del prof._memory.timeline[torch.device("cpu")]
prof.save_memory_timeline("with_disable.png")
exit(0)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
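Note: the new collator in gemma3_example.py squeezes the extra leading batch dim added by the vision processor, then pads across the batch with DataCollatorWithPadding. A minimal standalone sketch of that pattern, assuming a generic stand-in tokenizer ("bert-base-uncased" in place of processor.tokenizer) and toy features:

from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorWithPadding(tokenizer)

def data_collator(features):
    # drop the extra dim added by the vision processor, e.g. [[ids]] -> [ids]
    features = [{key: feature[key][0] for key in feature} for feature in features]
    return collator(features)

batch = data_collator([
    {"input_ids": [[101, 7592, 102]], "attention_mask": [[1, 1, 1]]},
    {"input_ids": [[101, 2088, 999, 102]], "attention_mask": [[1, 1, 1, 1]]},
])
print(batch["input_ids"].shape)  # torch.Size([2, 4]) after padding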
22 changes: 9 additions & 13 deletions examples/multimodal_vision/idefics3_example.py
@@ -1,5 +1,4 @@
import requests
import torch
from datasets import load_dataset
from PIL import Image
from transformers import AutoProcessor, Idefics3ForConditionalGeneration
@@ -14,16 +13,11 @@
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
DATASET_ID = "lmms-lab/flickr30k"
DATASET_SPLIT = "test[:512]"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 4096 # Seems to be required here


# Define a oneshot data collator for multimodal inputs.
def data_collator(batch):
    assert len(batch) == 1
    return {key: torch.tensor(value) for key, value in batch[0].items()}
MAX_SEQUENCE_LENGTH = 4096
BATCH_SIZE = 512
DATASET_ID = "lmms-lab/flickr30k"
DATASET_SPLIT = f"test[:{NUM_CALIBRATION_SAMPLES}]"


# Recipe
@@ -69,14 +63,17 @@ def preprocess(example):

# Tokenize inputs.
def tokenize(sample):
    return processor(
    features = processor(
        text=sample["text"],
        images=sample["images"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
    )

    # remove extra dim added by vision processor
    return [{key: feature[key][0] for key in feature} for feature in features]


# avoid errors with writer_batch_size
ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names)
@@ -86,10 +83,9 @@ def tokenize(sample):
    model=model,
    dataset=ds,
    recipe=recipe,
    batch_size=BATCH_SIZE,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    data_collator=data_collator,
    sequential_targets=["LlamaDecoderLayer"],
)

7 changes: 6 additions & 1 deletion examples/quantization_w4a16/llama3_example.py
@@ -7,6 +7,7 @@

# Select model and load it.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
#model_id = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

@@ -16,8 +17,9 @@

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
NUM_CALIBRATION_SAMPLES = 32
MAX_SEQUENCE_LENGTH = 2048
BATCH_SIZE = 16

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
@@ -58,9 +60,12 @@ def tokenize(sample):
    model=model,
    dataset=ds,
    recipe=recipe,
    batch_size=BATCH_SIZE,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    pipeline="sequential",
)
exit(0)

# Confirm generations of the quantized model look sane.
print("\n\n")
26 changes: 20 additions & 6 deletions src/llmcompressor/args/dataset_arguments.py
@@ -8,9 +8,7 @@
"""

from dataclasses import dataclass, field
from typing import Any, Callable

from transformers import DefaultDataCollator
from typing import Callable, Optional


@dataclass
@@ -69,9 +67,25 @@ class CustomDatasetArguments(DVCDatasetArguments):
        },
    )

    data_collator: Callable[[Any], Any] = field(
        default_factory=lambda: DefaultDataCollator(),
        metadata={"help": "The function to used to form a batch from the dataset"},
    data_collator: Optional[Callable] = field(
        default=None,
        metadata={
            "help": (
                "The function used to form a batch from the dataset. Defaults to "
                "`DataCollatorWithPadding(processor)`."
            )
        },
    )

    batch_size: int = field(
        default=1,
        metadata={
            "help": (
                "Calibration batch size. During calibration, LLM Compressor disables "
                "lm_head output computations to reduce memory usage from large "
                "calibration batches."
            )
        },
    )


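Note: the new batch_size help text says lm_head output computations are disabled during calibration to save memory. A rough back-of-the-envelope illustration of why, not part of this diff, using made-up Gemma-like dimensions:

# Logits scale with batch * seq_len * vocab_size; hidden states only with hidden_size.
batch_size, seq_len, vocab_size, hidden_size = 16, 2048, 256_000, 2560
bytes_per_elem = 2  # fp16 / bf16
hidden_mb = batch_size * seq_len * hidden_size * bytes_per_elem / 1e6
logits_mb = batch_size * seq_len * vocab_size * bytes_per_elem / 1e6
print(f"hidden states: ~{hidden_mb:,.0f} MB, logits: ~{logits_mb:,.0f} MB")
# ~168 MB of hidden states vs ~16.8 GB of logits for this batch, so skipping
# the lm_head avoids materializing the much larger tensor.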
53 changes: 30 additions & 23 deletions src/llmcompressor/datasets/utils.py
@@ -7,15 +7,16 @@
one-shot calibration workflows.
"""

import math
import multiprocessing
import re
from typing import Any, Callable

import torch
from datasets import Dataset
from loguru import logger
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.data import default_data_collator
from torch.utils.data import DataLoader, SequentialSampler
from transformers.data import DataCollatorWithPadding

from llmcompressor.args import DatasetArguments
from llmcompressor.transformers.data import TextGenerationDataset
@@ -115,44 +116,53 @@ def get_calibration_dataloader(
    )

    calibration_dataset = datasets.get("calibration")
    tokenizer = getattr(processor, "tokenizer", processor)
    collate_fn = dataset_args.data_collator or DataCollatorWithPadding(tokenizer)
    if tokenizer.pad_token is None or tokenizer.pad_token_id < 0:
        logger.debug("Could not find padding token. Setting PAD token to EOS token")
        tokenizer.pad_token = tokenizer.eos_token

    return format_calibration_data(
        tokenized_dataset=calibration_dataset,
        collate_fn=collate_fn,
        batch_size=dataset_args.batch_size,
        num_calibration_samples=dataset_args.num_calibration_samples,
        do_shuffle=dataset_args.shuffle_calibration_samples,
        collate_fn=dataset_args.data_collator,
    )


def format_calibration_data(
    tokenized_dataset: Dataset,
    collate_fn: Callable,
    batch_size: int = 1,
    num_calibration_samples: int | None = None,
    do_shuffle: bool = True,
    collate_fn: Callable = default_data_collator,
) -> list[torch.Tensor]:
    """
    Creates a dataloader out of the calibration dataset split, trimming it to
    the desired number of calibration samples
    :param tokenized_dataset: dataset to convert to dataloader
    :param num_calibration_samples: number of data samples to convert
    :param num_calibration_samples: number of batches to convert
    :param do_shuffle: whether to shuffle the dataset before selecting calibration
        samples, true by default
    :param collate_fn: optional custom collate function, or use default
    :return: list of trimmed calibration data tensors
    """
    safe_calibration_samples = len(tokenized_dataset)
    # (1) shuffle dataset
    if do_shuffle:
        tokenized_dataset = tokenized_dataset.shuffle()

    # (2) truncate dataset
    if num_calibration_samples is not None:
        safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples)
        if safe_calibration_samples != num_calibration_samples:
        if num_calibration_samples > len(tokenized_dataset):
            logger.warning(
                f"Requested {num_calibration_samples} calibration samples but "
                f"the provided dataset only has {safe_calibration_samples}. "
                f"Requested {num_calibration_samples} calibration samples but the "
                f"provided dataset only has {len(tokenized_dataset)} samples."
            )
            num_calibration_samples = len(tokenized_dataset)
        tokenized_dataset = tokenized_dataset.select(range(num_calibration_samples))

    if do_shuffle:
        tokenized_dataset = tokenized_dataset.shuffle()
    tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples))

    # (3) infer number of workers
    MAX_DATALOADER_WORKERS = 8
    try:
        num_workers = min(MAX_DATALOADER_WORKERS, multiprocessing.cpu_count() // 2)
@@ -161,19 +171,16 @@ def format_calibration_data(
"Could not determine number of CPUs, defaulting to 0 dataloader workers."
)
num_workers = 0

# (4) create dataloader
dataloader_params = {
"batch_size": 1,
"sampler": RandomSampler(tokenized_calibration)
if do_shuffle
else SequentialSampler(tokenized_calibration),
"batch_size": batch_size,
"sampler": SequentialSampler(tokenized_dataset),
"collate_fn": collate_fn,
"pin_memory": True,
"pin_memory": False,
"num_workers": num_workers,
}

calibration_dataloader = DataLoader(tokenized_calibration, **dataloader_params)

return calibration_dataloader
return DataLoader(tokenized_dataset, **dataloader_params)


def make_dataset_splits(
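Note: format_calibration_data now runs shuffle -> truncate -> batched sequential DataLoader. A minimal standalone sketch of that flow (a toy tokenized dataset and a stand-in "bert-base-uncased" tokenizer; not the llmcompressor function itself):

from datasets import Dataset
from torch.utils.data import DataLoader, SequentialSampler
from transformers import AutoTokenizer
from transformers.data import DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
ds = Dataset.from_dict({
    "input_ids": [[101, 7592, 102]] * 8,
    "attention_mask": [[1, 1, 1]] * 8,
})

num_calibration_samples, batch_size = 4, 2
ds = ds.shuffle()                                             # (1) shuffle first
ds = ds.select(range(min(len(ds), num_calibration_samples)))  # (2) then truncate
# ((3) worker-count inference omitted in this sketch)
loader = DataLoader(                                          # (4) sequential, batched
    ds,
    batch_size=batch_size,
    sampler=SequentialSampler(ds),
    collate_fn=DataCollatorWithPadding(tokenizer),
    pin_memory=False,
)
print(len(loader))  # 2 batches of 2 samples each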
4 changes: 3 additions & 1 deletion src/llmcompressor/entrypoints/oneshot.py
@@ -12,7 +12,7 @@
import os
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, Optional

from loguru import logger
from torch.utils.data import DataLoader
@@ -248,6 +248,8 @@ def oneshot(
    dataset_config_name: str | None = None,
    dataset_path: str | None = None,
    splits: str | list[str] | dict[str, str] | None = None,
    batch_size: int = 1,
    data_collator: Optional[Callable] = None,
    num_calibration_samples: int = 512,
    shuffle_calibration_samples: bool = True,
    max_seq_length: int = 384,
2 changes: 1 addition & 1 deletion src/llmcompressor/utils/helpers.py
@@ -1073,7 +1073,7 @@ def disable_lm_head(model: torch.nn.Module):
    does not untie parameters and restores the model proper loading upon exit
    """
    _, lm_head = get_embeddings(model)
    if lm_head is not None:
    if lm_head is None:
        logger.warning(
            f"Attempted to disable lm_head of instance {model.__class__.__name__}, "
            "but was unable to find lm_head. This may lead to unexpected OOM."