diff --git a/examples/multimodal_vision/gemma3_example.py b/examples/multimodal_vision/gemma3_example.py
index dce35b7b83..7f9b7cfabe 100644
--- a/examples/multimodal_vision/gemma3_example.py
+++ b/examples/multimodal_vision/gemma3_example.py
@@ -1,55 +1,69 @@
 import requests
-import torch
 from PIL import Image
-from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+from transformers import (
+    AutoProcessor,
+    DataCollatorWithPadding,
+    Gemma3ForConditionalGeneration,
+)
 
 from llmcompressor import oneshot
-from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.utils import dispatch_for_generation
 
 # Load model.
 model_id = "google/gemma-3-4b-it"
 model = Gemma3ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
 processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+collator = DataCollatorWithPadding(processor.tokenizer)
 
 # Oneshot arguments
-DATASET_ID = "flickr30k"
-DATASET_SPLIT = {"calibration": "test[:512]"}
 NUM_CALIBRATION_SAMPLES = 512
 MAX_SEQUENCE_LENGTH = 2048
+BATCH_SIZE = 512
+DATASET_ID = "flickr30k"
+DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}
 
 
-# Define a oneshot data collator for multimodal inputs.
-def data_collator(batch):
-    assert len(batch) == 1
-    return {key: torch.tensor(value) for key, value in batch[0].items()}
+# Define a oneshot data collator for multimodal processors
+# remove extra dim added by vision processor
+def data_collator(features: list[dict[str, object]]):
+    features = [{key: feature[key][0] for key in feature} for feature in features]
+    return collator(features)
 
 
 # Recipe
 recipe = [
-    GPTQModifier(
-        targets="Linear",
-        scheme="W4A16",
-        ignore=[
-            "lm_head",
-            r"re:model\.vision_tower.*",
-            r"re:model\.multi_modal_projector.*",
-        ],
-    ),
+    # GPTQModifier(
+    #     targets="Linear",
+    #     scheme="W4A16",
+    #     ignore=[
+    #         "lm_head",
+    #         r"re:model\.vision_tower.*",
+    #         r"re:model\.multi_modal_projector.*",
+    #     ],
+    # ),
 ]
 
-# Perform oneshot
-oneshot(
-    model=model,
-    tokenizer=model_id,
-    dataset=DATASET_ID,
-    splits=DATASET_SPLIT,
-    recipe=recipe,
-    max_seq_length=MAX_SEQUENCE_LENGTH,
-    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
-    trust_remote_code_model=True,
-    data_collator=data_collator,
-)
+from pttp import TensorProfiler
+
+with TensorProfiler() as prof:
+    # Perform oneshot
+    oneshot(
+        model=model,
+        tokenizer=model_id,
+        dataset=DATASET_ID,
+        splits=DATASET_SPLIT,
+        recipe=recipe,
+        batch_size=BATCH_SIZE,
+        max_seq_length=MAX_SEQUENCE_LENGTH,
+        num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+        data_collator=data_collator,
+        trust_remote_code_model=True,
+        pipeline="sequential",
+    )
+import torch
+del prof._memory.timeline[torch.device("cpu")]
+prof.save_memory_timeline("with_disable.png")
+exit(0)
 
 # Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============") diff --git a/examples/multimodal_vision/idefics3_example.py b/examples/multimodal_vision/idefics3_example.py index 2fdaeb1a4a..a3c75722d6 100644 --- a/examples/multimodal_vision/idefics3_example.py +++ b/examples/multimodal_vision/idefics3_example.py @@ -1,5 +1,4 @@ import requests -import torch from datasets import load_dataset from PIL import Image from transformers import AutoProcessor, Idefics3ForConditionalGeneration @@ -14,16 +13,11 @@ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) # Oneshot arguments -DATASET_ID = "lmms-lab/flickr30k" -DATASET_SPLIT = "test[:512]" NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 4096 # Seems to be required here - - -# Define a oneshot data collator for multimodal inputs. -def data_collator(batch): - assert len(batch) == 1 - return {key: torch.tensor(value) for key, value in batch[0].items()} +MAX_SEQUENCE_LENGTH = 4096 +BATCH_SIZE = 512 +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = f"test[:{NUM_CALIBRATION_SAMPLES}]" # Recipe @@ -69,7 +63,7 @@ def preprocess(example): # Tokenize inputs. def tokenize(sample): - return processor( + features = processor( text=sample["text"], images=sample["images"], padding=False, @@ -77,6 +71,9 @@ def tokenize(sample): truncation=True, ) + # remove extra dim added by vision processor + return [{key: feature[key][0] for key in feature} for feature in features] + # avoid errors with writer_batch_size ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names) @@ -86,10 +83,9 @@ def tokenize(sample): model=model, dataset=ds, recipe=recipe, + batch_size=BATCH_SIZE, max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, - data_collator=data_collator, sequential_targets=["LlamaDecoderLayer"], ) diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index b03aacee35..e970d41eb8 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -7,6 +7,7 @@ # Select model and load it. model_id = "meta-llama/Meta-Llama-3-8B-Instruct" +#model_id = "meta-llama/Llama-3.2-1B-Instruct" model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -16,8 +17,9 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 +NUM_CALIBRATION_SAMPLES = 32 MAX_SEQUENCE_LENGTH = 2048 +BATCH_SIZE = 16 # Load dataset and preprocess. ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") @@ -58,9 +60,12 @@ def tokenize(sample): model=model, dataset=ds, recipe=recipe, + batch_size=BATCH_SIZE, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, + pipeline="sequential", ) +exit(0) # Confirm generations of the quantized model look sane. 
print("\n\n") diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index 2618b90197..b66fcfa275 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -8,9 +8,7 @@ """ from dataclasses import dataclass, field -from typing import Any, Callable - -from transformers import DefaultDataCollator +from typing import Callable, Optional @dataclass @@ -69,9 +67,25 @@ class CustomDatasetArguments(DVCDatasetArguments): }, ) - data_collator: Callable[[Any], Any] = field( - default_factory=lambda: DefaultDataCollator(), - metadata={"help": "The function to used to form a batch from the dataset"}, + data_collator: Optional[Callable] = field( + default=None, + metadata={ + "help": ( + "The function to used to form a batch from the dataset. Defaults to " + "`DataCollatorWithPadding(processor)`." + ) + }, + ) + + batch_size: int = field( + default=1, + metadata={ + "help": ( + "Calibration batch size. During calibration, LLM Compressor disables " + "lm_head output computations to reduce memory usage from large " + "calibration matches" + ) + }, ) diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py index 2b80b1ed9a..dc11088cd4 100644 --- a/src/llmcompressor/datasets/utils.py +++ b/src/llmcompressor/datasets/utils.py @@ -7,6 +7,7 @@ one-shot calibration workflows. """ +import math import multiprocessing import re from typing import Any, Callable @@ -14,8 +15,8 @@ import torch from datasets import Dataset from loguru import logger -from torch.utils.data import DataLoader, RandomSampler, SequentialSampler -from transformers.data import default_data_collator +from torch.utils.data import DataLoader, SequentialSampler +from transformers.data import DataCollatorWithPadding from llmcompressor.args import DatasetArguments from llmcompressor.transformers.data import TextGenerationDataset @@ -115,44 +116,53 @@ def get_calibration_dataloader( ) calibration_dataset = datasets.get("calibration") + tokenizer = getattr(processor, "tokenizer", processor) + collate_fn = dataset_args.data_collator or DataCollatorWithPadding(tokenizer) + if tokenizer.pad_token is None or tokenizer.pad_token_id < 0: + logger.debug("Could not find padding token. 
Setting PAD token to EOS token") + tokenizer.pad_token = tokenizer.eos_token return format_calibration_data( tokenized_dataset=calibration_dataset, + collate_fn=collate_fn, + batch_size=dataset_args.batch_size, num_calibration_samples=dataset_args.num_calibration_samples, do_shuffle=dataset_args.shuffle_calibration_samples, - collate_fn=dataset_args.data_collator, ) def format_calibration_data( tokenized_dataset: Dataset, + collate_fn: Callable, + batch_size: int = 1, num_calibration_samples: int | None = None, do_shuffle: bool = True, - collate_fn: Callable = default_data_collator, ) -> list[torch.Tensor]: """ Creates a dataloader out of the calibration dataset split, trimming it to the desired number of calibration samples :param tokenized_dataset: dataset to convert to dataloader - :param num_calibration_samples: number of data samples to convert + :param num_calibration_samples: number of batches to convert :param do_shuffle: whether to shuffle the dataset before selecting calibration samples, true by default :param collate_fn: optional custom collate function, or use default :return: list of trimmed calibration data tensors """ - safe_calibration_samples = len(tokenized_dataset) + # (1) shuffle dataset + if do_shuffle: + tokenized_dataset = tokenized_dataset.shuffle() + + # (2) truncate dataset if num_calibration_samples is not None: - safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples) - if safe_calibration_samples != num_calibration_samples: + if num_calibration_samples > len(tokenized_dataset): logger.warning( - f"Requested {num_calibration_samples} calibration samples but " - f"the provided dataset only has {safe_calibration_samples}. " + f"Requested {num_calibration_samples} calibration samples but the " + f"provided dataset only has {len(tokenized_dataset)} samples." ) + num_calibration_samples = len(tokenized_dataset) + tokenized_dataset = tokenized_dataset.select(range(num_calibration_samples)) - if do_shuffle: - tokenized_dataset = tokenized_dataset.shuffle() - tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples)) - + # (3) infer number of workers MAX_DATALOADER_WORKERS = 8 try: num_workers = min(MAX_DATALOADER_WORKERS, multiprocessing.cpu_count() // 2) @@ -161,19 +171,16 @@ def format_calibration_data( "Could not determine number of CPUs, defaulting to 0 dataloader workers." 
         )
         num_workers = 0
+
+    # (4) create dataloader
     dataloader_params = {
-        "batch_size": 1,
-        "sampler": RandomSampler(tokenized_calibration)
-        if do_shuffle
-        else SequentialSampler(tokenized_calibration),
+        "batch_size": batch_size,
+        "sampler": SequentialSampler(tokenized_dataset),
         "collate_fn": collate_fn,
-        "pin_memory": True,
+        "pin_memory": False,
         "num_workers": num_workers,
     }
-
-    calibration_dataloader = DataLoader(tokenized_calibration, **dataloader_params)
-
-    return calibration_dataloader
+    return DataLoader(tokenized_dataset, **dataloader_params)
 
 
 def make_dataset_splits(
diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py
index c2b29aa97c..fba7dc6688 100644
--- a/src/llmcompressor/entrypoints/oneshot.py
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -12,7 +12,7 @@
 import os
 from datetime import datetime
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Optional
 
 from loguru import logger
 from torch.utils.data import DataLoader
@@ -248,6 +248,8 @@ def oneshot(
     dataset_config_name: str | None = None,
     dataset_path: str | None = None,
     splits: str | list[str] | dict[str, str] | None = None,
+    batch_size: int = 1,
+    data_collator: Optional[Callable] = None,
     num_calibration_samples: int = 512,
     shuffle_calibration_samples: bool = True,
     max_seq_length: int = 384,
diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py
index fddcafd829..ecd63d4239 100644
--- a/src/llmcompressor/utils/helpers.py
+++ b/src/llmcompressor/utils/helpers.py
@@ -1073,7 +1073,7 @@ def disable_lm_head(model: torch.nn.Module):
     does not untie parameters and restores the model proper loading upon exit
     """
     _, lm_head = get_embeddings(model)
-    if lm_head is not None:
+    if lm_head is None:
        logger.warning(
            f"Attempted to disable lm_head of instance {model.__class__.__name__}, "
            "but was unable to to find lm_head. This may lead to unexpected OOM."
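
For reference, a minimal sketch (not part of the patch) of the batched collation path that get_calibration_dataloader now defaults to: tokenized samples of different lengths are padded into a single batch by DataCollatorWithPadding and served sequentially, mirroring the new dataloader_params above. The "gpt2" tokenizer and the toy two-sample dataset are illustrative assumptions only.

# sketch: default batched calibration collation (assumed toy inputs)
from datasets import Dataset
from torch.utils.data import DataLoader, SequentialSampler
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model choice
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # same fallback as get_calibration_dataloader

ds = Dataset.from_dict({"text": ["a short sample", "a slightly longer calibration sample"]})
ds = ds.map(lambda s: tokenizer(s["text"]), remove_columns=ds.column_names)

loader = DataLoader(
    ds,
    batch_size=2,
    sampler=SequentialSampler(ds),
    collate_fn=DataCollatorWithPadding(tokenizer),
    pin_memory=False,
)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # both samples padded to the longest in the batch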