- [2026/03] We have created a KDFlow WeChat group! Welcome to join us for discussion and communication!
- [2026/03] KDFlow v0.1.1 released! Now supports vision-language (multimodal) models and the Qwen3.5 series (as the teacher model).
- News
- Key Features
- Quick Start
- Configuration Reference
- Extending KDFlow
- Design Highlights
- Acknowledgement
- Citation
- License
- WeChat Group
- Star History
- Decoupled Infrastructure - Uses SGLang for teacher inference and FSDP2 for student training.
- Off-Policy Knowledge Distillation - Distill from pre-collected teacher hidden states on static datasets.
- On-Policy Knowledge Distillation - Student-generated rollout responses are used for teacher forwarding and distillation training in a closed loop.
- Cross-Tokenizer Distillation - Native support for distilling between models with different tokenizers (e.g., Llama ↔ Qwen).
- SFT Training (Black-box KD) - Supervised fine-tuning on a collected dataset.
- Multimodal Support - Supports distillation with vision-language models (e.g., Qwen3-VL).
- Colocate Mode - Teacher and student models share the same GPUs via a sleep/wakeup mechanism, maximizing GPU utilization.
- Teacher on SGLang - Teacher inference is powered by the SGLang Engine, enabling high-throughput prefilling and flexible parallel strategies.
- Pluggable KD Algorithms - Built-in support for Vanilla KD and DSKD (Dual-Space Knowledge Distillation), with easy registration of custom algorithms.
- Multiple Loss Functions - Torch-compiled KL divergence, reverse KL divergence, JS divergence, adaptive KL (AKL), etc.
- LoRA Support - Optional LoRA fine-tuning for the student model.
- W&B Integration - Built-in wandb logging for experiment tracking.
- High Training Efficiency - Achieves 1.4x to 6x faster distillation compared to mainstream knowledge distillation frameworks.
```bash
git clone https://github.com/songmzhang/KDFlow.git
cd KDFlow
pip install -e .
```

Since SGLang 0.5.9 does not support transformers v5, please use transformers v4.57.1 to ensure correct teacher computation.
Off-policy KD:

LLMs:

```bash
bash ./examples/off_policy_kd/run_qwen3_30b_a3b_to_4b.sh
```

VLMs:

```bash
bash ./examples/off_policy_kd/run_qwen3_vl_30b_a3b_to_4b.sh
```

On-policy KD:

LLMs:

```bash
bash ./examples/on_policy_kd/run_qwen3_30b_a3b_to_4b.sh
```

VLMs:

```bash
bash ./examples/on_policy_kd/run_qwen3_vl_30b_a3b_to_4b.sh
```

Off-policy cross-tokenizer KD, using SimpleCrossTokenizerKD (suggested):

```bash
bash ./examples/cross_tokenizer_kd/run_qwen3_30b_a3b_to_llama3_2_3b_offpolicy_simple_ctkd.sh
```

or DSKD:

```bash
bash ./examples/cross_tokenizer_kd/run_qwen3_30b_a3b_to_llama3_2_3b_offpolicy.sh
```

On-policy cross-tokenizer KD, using SimpleCrossTokenizerKD (suggested):

```bash
bash ./examples/cross_tokenizer_kd/run_qwen3_30b_a3b_to_llama3_2_3b_onpolicy_simple_ctkd.sh
```

or DSKD:

```bash
bash ./examples/cross_tokenizer_kd/run_qwen3_30b_a3b_to_llama3_2_3b_onpolicy.sh
```

SFT:

```bash
bash ./examples/sft/run_qwen3_4b.sh
```

| Argument | Default | Description |
|---|---|---|
| `--student_name_or_path` | None | Student model name or path |
| `--teacher_name_or_path` | None | Teacher model name or path |
| `--attn_implementation` | `flash_attention_2` | Attention implementation |
| `--use_liger_kernel` | False | Use Liger Kernel for the student model |
| `--lora_rank` | 0 | LoRA rank (0 = disabled) |
| `--lora_alpha` | 16 | LoRA alpha |
| `--target_modules` | `all-linear` | LoRA target modules |
| `--lora_dropout` | 0.0 | LoRA dropout |
| `--ring_attn_size` | 1 | Ring attention group size for context parallelism |
| `--enable_thinking` | False | Enable thinking mode |
| `--disable_fast_tokenizer` | False | Disable fast tokenizer |
| Argument | Default | Description |
|---|---|---|
| `--num_nodes` | 1 | Number of training nodes |
| `--num_gpus_per_node` | 8 | GPUs per node |
| `--num_epochs` | 1 | Number of training epochs |
| `--train_batch_size` | 128 | Global training batch size |
| `--micro_train_batch_size` | 1 | Per-GPU micro batch size |
| `--learning_rate` | 1e-6 | Learning rate |
| `--lr_scheduler` | cosine_with_min_lr | LR scheduler type |
| `--lr_warmup_ratio` | 0.05 | Warmup ratio |
| `--min_lr` | 1e-8 | Minimum learning rate |
| `--max_norm` | 1.0 | Gradient clipping max norm |
| `--weight_decay` | 0.0 | Weight decay |
| `--adam_betas` | (0.9, 0.98) | Adam optimizer betas |
| `--backend` | fsdp2 | Training backend |
| `--gradient_checkpointing` | False | Enable gradient checkpointing |
| `--enable_sleep` | False | Enable sleep mode for all components (student, teacher, rollout) |
| `--eval_steps` | -1 | Evaluate every N steps (-1 = disabled) |
| `--save_steps` | -1 | Save checkpoint every N steps (-1 = disabled) |
| `--save_path` | ./ckpt/ | Model save path |
| `--ckpt_path` | ./ckpt/checkpoints_distill | Checkpoint save path |
| `--seed` | 42 | Random seed |
| `--bf16` | False | Enable bfloat16 training |
| Argument | Default | Description |
|---|---|---|
| `--fsdp_size` | -1 | FSDP shard size for HSDP (-1 = full sharding) |
| `--cpu_offload` | False | Offload Adam optimizer states to CPU |
| Argument | Default | Description |
|---|---|---|
| `--kd_ratio` | 0.5 | KD loss weight: loss = (1 - kd_ratio) * CE + kd_ratio * KD |
| `--kd_temperature` | 1.0 | Temperature for softmax in KD |
| `--kd_algorithm` | vanilla_kd | KD algorithm (vanilla_kd / dskd) |
| `--kd_loss_fn` | kl | Divergence function (kl / rkl / jsd / akl) |
| `--teacher_tp_size` | 8 | Teacher tensor parallel size |
| `--teacher_ep_size` | 1 | Teacher expert parallel size (MoE models) |
| `--teacher_pp_size` | 1 | Teacher pipeline parallel size |
| `--teacher_dp_size` | 1 | Teacher data parallel size |
| `--teacher_forward_n_batches` | 1 | Forward N batches at once on the teacher |
| `--teacher_mem_fraction_static` | 0.4 | SGLang static memory fraction for the teacher |
| `--teacher_offload_tags` | all | Offload tags for SGLang |
| `--teacher_quantization` | None | Teacher model quantization |
| `--dskd_token_align` | eta | Token alignment strategy for DSKD (eta / cma) |
| `--dskd_topk_vocab` | -1 | Top-k vocab tokens for DSKD projector init (-1 = all) |
| `--dskd_projector_lr` | 1e-4 | Learning rate for DSKD projectors |
| `--jsd_beta` | 0.5 | Beta for Jensen-Shannon divergence |
| `--skew_lambda` | 0.1 | Lambda for skewed KL/RKL |
| `--adaptive_alpha` | 0.5 | Alpha for adaptive KL divergence |
| `--hrl_topk` | 5 | Top-k for hierarchical ranking loss |
| Argument | Default | Description |
|---|---|---|
| `--rollout_num_engines` | 0 | Number of SGLang rollout engines (0 = off-policy) |
| `--rollout_tp_size` | 1 | Tensor parallel size per rollout engine |
| `--rollout_batch_size` | 32 | Prompts per rollout iteration |
| `--n_samples_per_prompt` | 1 | Number of responses per prompt |
| `--generate_max_len` | 2048 | Max generation length |
| `--temperature` | 1.0 | Sampling temperature |
| `--top_p` | 1.0 | Top-p sampling |
| `--rollout_mem_fraction_static` | 0.6 | GPU memory utilization per rollout engine |
| `--print_rollout_sample` | False | Print a rollout sample after each rollout |
| Argument | Default | Description |
|---|---|---|
| `--train_dataset_path` | None | Training dataset path |
| `--train_dataset_probs` | None | Sampling probabilities for multiple datasets |
| `--train_split` | train | Train split name |
| `--eval_dataset_path` | None | Evaluation dataset path |
| `--eval_split` | eval | Eval split name |
| `--input_key` | messages | Dataset input key |
| `--output_key` | None | Dataset output key |
| `--image_key` | None | Image key for multimodal datasets |
| `--teacher_input_key` | None | Input key for teacher prompt (for self-distillation / context distillation) |
| `--label_key` | None | Label key in dataset |
| `--apply_chat_template` | True | Apply tokenizer chat template |
| `--max_len` | 4096 | Max sequence length |
| `--prompt_max_len` | 2048 | Max prompt length |
| `--max_samples` | 1e8 | Max number of samples to load |
| `--packing_samples` | False | Pack sequences for efficiency |
| `--preprocess_num_workers` | 8 | Number of workers for data preprocessing |
| Argument | Default | Description |
|---|---|---|
| `--logging_steps` | 10 | Log every N steps |
| `--use_wandb` | False | Enable W&B logging |
| `--wandb_org` | None | W&B organization name |
| `--wandb_project` | None | W&B project name |
| `--wandb_group` | None | W&B group name |
| `--wandb_run_name` | None | W&B run name |
| `--wandb_mode` | online | W&B mode (online / offline / disabled) |
| `--wandb_dir` | None | Directory to store W&B offline logs |
Create a new file in `kdflow/algorithms/` and register it:

```python
import torch

from kdflow.loss import LOSS_DICT
from kdflow.algorithms import register_algorithm


@register_algorithm("my_custom_kd")
class MyCustomKD:
    def __init__(self, strategy, student_model, teacher_lm_head, **kwargs):
        self.strategy = strategy
        self.student = student_model
        self.teacher_lm_head = teacher_lm_head
        self.loss_fn = LOSS_DICT[strategy.args.kd.loss_fn]

    def training_step(self, micro_batch):
        # Access student inputs
        student_input_ids = micro_batch["stu_input_ids"]
        student_attn_mask = micro_batch["stu_attn_mask"]
        student_loss_mask = micro_batch["stu_loss_mask"].bool()
        teacher_hiddens = micro_batch["teacher_hiddens"]
        avg_token_num = micro_batch["avg_micro_batch_token_num"]

        # Student forward
        output = self.student(student_input_ids, attention_mask=student_attn_mask, return_output=True)
        student_logits = output["logits"][student_loss_mask]

        # Teacher logits from hidden states + lm_head
        teacher_logits = self.teacher_lm_head(teacher_hiddens.to(self.teacher_lm_head.weight))

        # Compute your custom loss
        kd_loss = self.loss_fn(student_logits, teacher_logits, temperature=1.0)
        kd_loss = kd_loss.sum() / avg_token_num
        return {"loss": kd_loss, "kd_loss": kd_loss}
```

Then use it with `--kd_algorithm my_custom_kd`.
Create a new file in `kdflow/loss/` and register it:

```python
import torch
import torch.nn.functional as F

from kdflow.loss import register_loss


@register_loss("my_custom_loss")
@torch.compile()
def compute_kl_div(
    student_logits,
    teacher_logits,
    temperature=1.0,
    reduction="none",
    **kwargs
):
    student_logits = student_logits / temperature
    teacher_logits = teacher_logits / temperature
    log_probs = torch.log_softmax(student_logits, -1, dtype=torch.float32)
    target_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)
    kl_div = F.kl_div(log_probs, target_probs, reduction=reduction).sum(-1)
    return kl_div
```

Then use it with `--kd_loss_fn my_custom_loss`.
KDFlow enables teacher and student to share the same GPUs through a sleep/wakeup mechanism:
- Teacher phase: Teacher model weights are loaded on GPU, student optimizer states are offloaded to CPU.
- Student phase: Student optimizer states are reloaded to GPU, teacher model weights are offloaded to CPU.
This allows running large teacher models (e.g., 200B+ parameters) on the same hardware as the student without requiring separate GPU pools.
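The residency swap can be illustrated with a toy sketch. The names below are illustrative only; KDFlow's real mechanism relies on SGLang sleep/wakeup and FSDP offloading rather than manual `.to()` calls:

```python
import torch

def move(state, device):
    """Move every tensor in a state dict to the given device."""
    return {name: tensor.to(device) for name, tensor in state.items()}

device = "cuda" if torch.cuda.is_available() else "cpu"
teacher_weights = {"w": torch.randn(1024, 1024)}
student_opt_state = {"exp_avg": torch.zeros(1024, 1024)}

# Teacher phase: teacher weights on the accelerator, optimizer state on CPU.
teacher_weights = move(teacher_weights, device)
student_opt_state = move(student_opt_state, "cpu")
# ... run teacher forward passes ...

# Student phase: swap residency before the training steps.
teacher_weights = move(teacher_weights, "cpu")
student_opt_state = move(student_opt_state, device)
# ... run student optimizer steps ...
```

At any point, only one of the two large states occupies the accelerator, which is what lets a large teacher and the student share one GPU pool.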
Hidden States Transfer via Shared Memory
Instead of transferring full teacher logits (which can be enormous for large vocabularies), KDFlow:
- Extracts hidden states from the teacher's last layer via SGLang.
- Transfers them to the student via shared memory (zero-copy).
- Computes teacher logits on the student side using only the teacher's `lm_head` weights.
This dramatically reduces memory and communication overhead.
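The saving can be seen from shapes alone. In this sketch (sizes are illustrative, not tied to any specific model), only `hidden_size` floats per token cross the process boundary, and the full-vocabulary logits are recovered on the student side:

```python
import torch

# Illustrative sizes.
num_tokens, hidden_size, vocab_size = 8, 64, 4096
teacher_hiddens = torch.randn(num_tokens, hidden_size)

# The lm_head is a plain linear projection from hidden states to vocab logits.
lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

# Recover full teacher logits on the student side from the transferred hiddens.
teacher_logits = lm_head(teacher_hiddens)
assert teacher_logits.shape == (num_tokens, vocab_size)

# Per-token transfer cost: hidden_size floats instead of vocab_size floats,
# a vocab_size / hidden_size reduction.
print(vocab_size // hidden_size)  # 64
```

For a real LLM with a ~150k vocabulary and a few thousand hidden dimensions, this is a one-to-two-order-of-magnitude reduction in transfer volume per token.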
The TeacherActorGroup uses a greedy token-based load balancing strategy to distribute micro-batches across teacher actors, ensuring even workload distribution when sequence lengths vary.
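The idea can be sketched as a classic greedy balancing pass (this code is illustrative; KDFlow's actual implementation may differ): sort micro-batches by token count, then always hand the next batch to the currently least-loaded actor:

```python
import heapq

def greedy_assign(token_counts, num_actors):
    """Assign micro-batches (largest first) to the least-loaded actor.
    Returns {actor_id: (total_tokens, [micro_batch_indices])}."""
    heap = [(0, actor_id, []) for actor_id in range(num_actors)]
    heapq.heapify(heap)
    for tokens, batch_idx in sorted(
        ((t, i) for i, t in enumerate(token_counts)), reverse=True
    ):
        load, actor_id, batches = heapq.heappop(heap)  # least-loaded actor
        batches.append(batch_idx)
        heapq.heappush(heap, (load + tokens, actor_id, batches))
    return {actor_id: (load, batches) for load, actor_id, batches in heap}

assignment = greedy_assign([900, 100, 800, 200, 700, 300], num_actors=3)
print(sorted(load for load, _ in assignment.values()))  # [1000, 1000, 1000]
```

Balancing on token counts rather than batch counts is what keeps teacher actors evenly busy when sequence lengths vary widely.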
KDFlow is built upon the shoulders of outstanding open-source projects. We sincerely thank:
- SGLang - We deeply appreciate its support for extracting hidden states during model inference and its exceptional inference efficiency, which are critical to KDFlow's teacher inference pipeline.
- OpenRLHF - We gratefully adopt its well-designed abstractions for model wrapping and distributed training strategies, which form the foundation of our training infrastructure.
- slime - We appreciate its elegant implementation of Ray placement group initialization and the weight update mechanism for SGLang, which greatly inspired our design of on-policy distillation.
If you find KDFlow useful in your research or work, please consider citing our paper:
```bibtex
@article{zhang2026kdflow,
  title={KDFlow: A User-Friendly and Efficient Knowledge Distillation Framework for Large Language Models},
  author={Songming Zhang and Xue Zhang and Tong Zhang and Bojie Hu and Yufeng Chen and Jinan Xu},
  year={2026},
  eprint={2603.01875},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2603.01875},
}
```

This project is licensed under the MIT License.
Welcome to join our WeChat group for discussion and communication!



