"""Quantize ``openai/gpt-oss-20b`` to NVFP4 with one-shot calibration.

The script:

1. loads the model and tokenizer from the Hugging Face Hub,
2. builds a small calibration set from ``HuggingFaceH4/ultrachat_200k``,
3. applies the ``NVFP4`` scheme to every ``Linear`` target except ``lm_head``,
4. saves the compressed model and tokenizer to ``<model-name>-NVFP4``.
"""

import os

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

model_id = "openai/gpt-oss-20b"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# -----------------------------
# Create calibration dataset
# -----------------------------
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 32
MAX_SEQUENCE_LENGTH = 2048

# Only the first NUM_CALIBRATION_SAMPLES rows are downloaded/selected; the
# shuffle therefore only reorders that fixed subset (seeded for repeatability).
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    """Render one conversation into a single prompt string.

    Args:
        example: A dataset row holding the chat turns under ``"messages"``.

    Returns:
        A dict with the chat-templated, untokenized text under ``"text"``.
    """
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


def tokenize(sample):
    """Tokenize a rendered prompt, truncating to ``MAX_SEQUENCE_LENGTH``.

    Args:
        sample: A dataset row holding the rendered prompt under ``"text"``.

    Returns:
        The tokenizer output for the prompt (unpadded).
    """
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        # The chat template already inserted the special tokens.
        add_special_tokens=False,
    )


# Drop the raw text/message columns so only tokenizer outputs remain.
ds = ds.map(tokenize, remove_columns=ds.column_names)

# -----------------------------
# Quantization recipe
# -----------------------------
recipe = QuantizationModifier(
    targets="Linear",
    scheme="NVFP4",
    ignore=["lm_head"],
)

SAVE_DIR = f"{model_id.split('/')[-1]}-NVFP4"

# Pass the calibration settings explicitly so oneshot uses the same sample
# count and sequence length as the dataset prepared above instead of its own
# defaults. No output_dir is given here: the checkpoint is written once,
# explicitly, below rather than being saved twice to the same directory.
oneshot(
    model=model,
    tokenizer=tokenizer,
    recipe=recipe,
    dataset=ds,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
)

# Save compressed
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)