songmzhang/KDFlow

A User-friendly and Efficient Framework for LLM Knowledge Distillation

🔥 News

  • [2026/03] 💬 We have created a KDFlow WeChat group. You are welcome to join us for discussion.
  • [2026/03] 🎉 KDFlow v0.1.1 released! It now supports vision-language (multimodal) models and the Qwen3.5 series (as the teacher model).

✨ Key Features

  • Decoupled Infrastructure: Uses SGLang for teacher inference and FSDP2 for student training.
  • Off-Policy Knowledge Distillation: Distill from pre-collected teacher hidden states on static datasets.
  • On-Policy Knowledge Distillation: Student-generated rollout responses are fed to the teacher forward pass and used for distillation training in a closed loop.
  • Cross-Tokenizer Distillation: Native support for distilling between models with different tokenizers (e.g., Llama → Qwen).
  • SFT Training (Black-box KD): Supervised fine-tuning on a collected dataset.
  • Multimodal Support: Distillation with vision-language models (e.g., Qwen3-VL).
  • Colocate Mode: Teacher and student models share the same GPUs via a sleep/wakeup mechanism, maximizing GPU utilization.
  • Teacher on SGLang: Teacher inference is powered by the SGLang Engine, enabling high-throughput prefilling and flexible parallel strategies.
  • Pluggable KD Algorithms: Built-in support for Vanilla KD and DSKD (Dual-Space Knowledge Distillation), with easy registration of custom algorithms.
  • Multiple Loss Functions: Torch-compiled KL divergence, reverse KL divergence, JS divergence, Adaptive KL (AKL), etc.
  • LoRA Support: Optional LoRA fine-tuning for the student model.
  • W&B Integration: Built-in Weights & Biases (W&B) logging for experiment tracking.
  • High Training Efficiency: Achieves 1.4x to 6x faster distillation compared to mainstream knowledge distillation frameworks.

[Figure: KDFlow architecture]


🚀 Quick Start

Installation

git clone https://github.com/songmzhang/KDFlow.git
cd KDFlow
pip install -e ./

Since SGLang 0.5.9 does not support transformers v5, please use transformers v4.57.1 to ensure correct teacher computation.
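Before launching a run, it can help to confirm which transformers version is installed. The snippet below is a minimal sketch that only uses the standard library; the helper names `major_version` and `transformers_is_v4` are hypothetical and not part of KDFlow.

```python
from importlib import metadata


def major_version(version: str) -> int:
    """Return the leading integer of a version string such as '4.57.1'."""
    return int(version.split(".")[0])


def transformers_is_v4(version: str) -> bool:
    """True when the given transformers version is on the 4.x line."""
    return major_version(version) == 4


if __name__ == "__main__":
    try:
        installed = metadata.version("transformers")
    except metadata.PackageNotFoundError:
        print("transformers is not installed")
    else:
        status = "ok" if transformers_is_v4(installed) else "not on the 4.x line"
        print(f"transformers {installed}: {status}")
```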

Off-Policy Knowledge Distillation

LLMs:

bash ./examples/off_policy_kd/run_qwen3_30b_a3b_to_4b.sh

VLMs:

bash ./examples/off_policy_kd/run_qwen3_vl_30b_a3b_to_4b.sh

On-Policy Knowledge Distillation

LLMs:

bash ./examples/on_policy_kd/run_qwen3_30b_a3b_to_4b.sh

VLMs:

bash ./examples/on_policy_kd/run_qwen3_vl_30b_a3b_to_4b.sh

Cross-Tokenizer Knowledge Distillation

Off-Policy

Use SimpleCrossTokenizerKD (suggested):

bash ./examples/cross_tokenizer_kd/run_qwen3_30b_a3b_to_llama3_2_3b_offpolicy_simple_ctkd.sh

or DSKD:

bash ./examples/cross_tokenizer_kd/run_qwen3_30b_a3b_to_llama3_2_3b_offpolicy.sh

On-Policy

Use SimpleCrossTokenizerKD (suggested):

bash ./examples/cross_tokenizer_kd/run_qwen3_30b_a3b_to_llama3_2_3b_onpolicy_simple_ctkd.sh

or DSKD:

bash ./examples/cross_tokenizer_kd/run_qwen3_30b_a3b_to_llama3_2_3b_onpolicy.sh

Supervised Fine-Tuning (SFT)

bash ./examples/sft/run_qwen3_4b.sh

⚙️ Arguments

Model Arguments

| Argument | Default | Description |
| --- | --- | --- |
| `--student_name_or_path` | `None` | Student model name or path |
| `--teacher_name_or_path` | `None` | Teacher model name or path |
| `--attn_implementation` | `flash_attention_2` | Attention implementation |
| `--use_liger_kernel` | `False` | Use Liger Kernel for student model |
| `--lora_rank` | `0` | LoRA rank (0 = disabled) |
| `--lora_alpha` | `16` | LoRA alpha |
| `--target_modules` | `all-linear` | LoRA target modules |
| `--lora_dropout` | `0.0` | LoRA dropout |
| `--ring_attn_size` | `1` | Ring attention group size for context parallelism |
| `--enable_thinking` | `False` | Enable thinking mode |
| `--disable_fast_tokenizer` | `False` | Disable fast tokenizer |

Training Arguments

| Argument | Default | Description |
| --- | --- | --- |
| `--num_nodes` | `1` | Number of training nodes |
| `--num_gpus_per_node` | `8` | GPUs per node |
| `--num_epochs` | `1` | Number of training epochs |
| `--train_batch_size` | `128` | Global training batch size |
| `--micro_train_batch_size` | `1` | Per-GPU micro batch size |
| `--learning_rate` | `1e-6` | Learning rate |
| `--lr_scheduler` | `cosine_with_min_lr` | LR scheduler type |
| `--lr_warmup_ratio` | `0.05` | Warmup ratio |
| `--min_lr` | `1e-8` | Minimum learning rate |
| `--max_norm` | `1.0` | Gradient clipping max norm |
| `--weight_decay` | `0.0` | Weight decay |
| `--adam_betas` | `(0.9, 0.98)` | Adam optimizer betas |
| `--backend` | `fsdp2` | Training backend |
| `--gradient_checkpointing` | `False` | Enable gradient checkpointing |
| `--enable_sleep` | `False` | Enable sleep mode for all components (student, teacher, rollout) |
| `--eval_steps` | `-1` | Evaluate every N steps (-1 = disabled) |
| `--save_steps` | `-1` | Save checkpoint every N steps (-1 = disabled) |
| `--save_path` | `./ckpt/` | Model save path |
| `--ckpt_path` | `./ckpt/checkpoints_distill` | Checkpoint save path |
| `--seed` | `42` | Random seed |
| `--bf16` | `False` | Enable bfloat16 training |
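To make the batch-size and scheduler arguments concrete, here is a small sketch of the arithmetic they imply. It assumes pure data parallelism (it ignores `--ring_attn_size`) and a generic linear-warmup plus cosine-decay shape; both helper functions are hypothetical illustrations, and KDFlow's own scheduler may differ in detail.

```python
import math


def grad_accum_steps(train_batch_size=128, micro_train_batch_size=1,
                     num_nodes=1, num_gpus_per_node=8):
    """Micro-steps per optimizer step, assuming every GPU is a data-parallel rank."""
    samples_per_micro_step = micro_train_batch_size * num_nodes * num_gpus_per_node
    if train_batch_size % samples_per_micro_step != 0:
        raise ValueError("train_batch_size must be divisible by micro batch * world size")
    return train_batch_size // samples_per_micro_step


def cosine_with_min_lr(step, total_steps, peak_lr=1e-6, min_lr=1e-8, warmup_ratio=0.05):
    """Linear warmup to peak_lr, then cosine decay down to min_lr (assumed shape)."""
    warmup_steps = int(total_steps * warmup_ratio)
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (peak_lr - min_lr) * cosine
```

With the defaults in the table above, one optimizer step would accumulate 16 micro-batches per GPU under these assumptions.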

FSDP Arguments

| Argument | Default | Description |
| --- | --- | --- |
| `--fsdp_size` | `-1` | FSDP shard size for HSDP (-1 = full sharding) |
| `--cpu_offload` | `False` | Offload Adam optimizer states to CPU |
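As a rough illustration of what `--fsdp_size` means under hybrid sharding (HSDP), the sketch below computes a two-dimensional (replicate, shard) layout from the world size. The function `hsdp_mesh_shape` is hypothetical and only shows the arithmetic; it is not KDFlow's device-mesh code.

```python
def hsdp_mesh_shape(world_size: int, fsdp_size: int = -1) -> tuple[int, int]:
    """Return (replica_groups, shards_per_group) for a given world size."""
    if fsdp_size == -1:
        # Full sharding: a single group spanning every rank.
        return (1, world_size)
    if fsdp_size <= 0 or world_size % fsdp_size != 0:
        raise ValueError("fsdp_size must be a positive divisor of world_size")
    # Parameters are sharded inside each group and replicated across groups.
    return (world_size // fsdp_size, fsdp_size)
```

For example, 16 GPUs with `--fsdp_size 8` would give two replica groups of eight shards each.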

Distillation Arguments

| Argument | Default | Description |
| --- | --- | --- |
| `--kd_ratio` | `0.5` | KD loss weight: `loss = (1 - kd_ratio) * CE + kd_ratio * KD` |
| `--kd_temperature` | `1.0` | Temperature for softmax in KD |
| `--kd_algorithm` | `vanilla_kd` | KD algorithm (vanilla_kd / dskd) |
| `--kd_loss_fn` | `kl` | Divergence function (kl / rkl / jsd / akl) |
| `--teacher_tp_size` | `8` | Teacher tensor parallel size |
| `--teacher_ep_size` | `1` | Teacher expert parallel size (MoE models) |
| `--teacher_pp_size` | `1` | Teacher pipeline parallel size |
| `--teacher_dp_size` | `1` | Teacher data parallel size |
| `--teacher_forward_n_batches` | `1` | Teacher forward N batches at once |
| `--teacher_mem_fraction_static` | `0.4` | SGLang static memory fraction for teacher |
| `--teacher_offload_tags` | `all` | Offload tags for SGLang |
| `--teacher_quantization` | `None` | Teacher model quantization |
| `--dskd_token_align` | `eta` | Token alignment strategy for DSKD (eta / cma) |
| `--dskd_topk_vocab` | `-1` | Top-k vocab tokens for DSKD projector init (-1 = all) |
| `--dskd_projector_lr` | `1e-4` | Learning rate for DSKD projectors |
| `--jsd_beta` | `0.5` | Beta for Jensen-Shannon Divergence |
| `--skew_lambda` | `0.1` | Lambda for Skewed KL/RKL |
| `--adaptive_alpha` | `0.5` | Alpha for Adaptive KL Divergence |
| `--hrl_topk` | `5` | Top-k for Hierarchical Ranking Loss |
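The interplay of `--kd_ratio`, `--kd_temperature`, and `--kd_loss_fn` can be sketched for a single token in plain Python. This is an illustration under stated assumptions, not KDFlow's implementation: `kl` is taken as KL(teacher || student), `rkl` as KL(student || teacher), and the `jsd_beta` weighting follows one common generalized-JSD convention that may not match the library. All function names here are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for two discrete distributions with full support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kd_divergence(student_logits, teacher_logits, loss_fn="kl",
                  temperature=1.0, jsd_beta=0.5):
    s = softmax(student_logits, temperature)
    t = softmax(teacher_logits, temperature)
    if loss_fn == "kl":   # forward KL: expectation under the teacher
        return kl(t, s)
    if loss_fn == "rkl":  # reverse KL: expectation under the student
        return kl(s, t)
    if loss_fn == "jsd":  # assumed convention: mixture weighted by jsd_beta
        mix = [jsd_beta * ti + (1 - jsd_beta) * si for ti, si in zip(t, s)]
        return jsd_beta * kl(t, mix) + (1 - jsd_beta) * kl(s, mix)
    raise ValueError(f"unknown loss_fn: {loss_fn}")


def total_loss(ce_loss, kd_loss, kd_ratio=0.5):
    """loss = (1 - kd_ratio) * CE + kd_ratio * KD, as in the table above."""
    return (1 - kd_ratio) * ce_loss + kd_ratio * kd_loss
```

Setting `kd_ratio` to 0 reduces the objective to plain cross-entropy, and setting it to 1 trains on the distillation term alone.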

Rollout Arguments (On-Policy)

| Argument | Default | Description |
| --- | --- | --- |
| `--rollout_num_engines` | `0` | Number of SGLang rollout engines (0 = off-policy) |
| `--rollout_tp_size` | `1` | Tensor parallel per rollout engine |
| `--rollout_batch_size` | `32` | Prompts per rollout iteration |
| `--n_samples_per_prompt` | `1` | Number of responses per prompt |
| `--generate_max_len` | `2048` | Max generation length |
| `--temperature` | `1.0` | Sampling temperature |
| `--top_p` | `1.0` | Top-p sampling |
| `--rollout_mem_fraction_static` | `0.6` | GPU memory utilization per rollout engine |
| `--print_rollout_sample` | `False` | Print a rollout sample after each rollout |
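The rollout arguments determine how many responses each iteration produces. The sketch below shows that arithmetic and how the sampling options might be bundled into a dictionary. The function `rollout_plan` is hypothetical, and the dictionary keys (`temperature`, `top_p`, `max_new_tokens`, `n`) are assumed to follow SGLang's sampling-parameter naming; check the SGLang documentation before relying on them.

```python
def rollout_plan(rollout_batch_size=32, n_samples_per_prompt=1,
                 generate_max_len=2048, temperature=1.0, top_p=1.0):
    """Summarize one rollout iteration from the command-line defaults."""
    responses = rollout_batch_size * n_samples_per_prompt
    sampling_params = {  # key names are an assumption, see the note above
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": generate_max_len,
        "n": n_samples_per_prompt,
    }
    return {
        "responses_per_iteration": responses,
        # Upper bound only: responses usually stop before the length limit.
        "max_generated_tokens": responses * generate_max_len,
        "sampling_params": sampling_params,
    }
```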

Data Arguments

| Argument | Default | Description |
| --- | --- | --- |
| `--train_dataset_path` | `None` | Training dataset path |
| `--train_dataset_probs` | `None` | Sampling probabilities for multiple datasets |
| `--train_split` | `train` | Train split name |
| `--eval_dataset_path` | `None` | Evaluation dataset path |
| `--eval_split` | `eval` | Eval split name |
| `--input_key` | `messages` | Dataset input key |
| `--output_key` | `None` | Dataset output key |
| `--image_key` | `None` | Image key for multimodal datasets |
| `--teacher_input_key` | `None` | Input key for teacher prompt (for self-distillation/context distillation) |
| `--label_key` | `None` | Label key in dataset |
| `--apply_chat_template` | `True` | Apply tokenizer chat template |
| `--max_len` | `4096` | Max sequence length |
| `--prompt_max_len` | `2048` | Max prompt length |
| `--max_samples` | `1e8` | Max number of samples to load |
| `--packing_samples` | `False` | Pack sequences for efficiency |
| `--preprocess_num_workers` | `8` | Number of workers for data preprocessing |
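With the default `--input_key messages`, a training sample is presumably a chat-style record. The sketch below shows what one JSONL line could look like and how a prompt/response split might work. The `role`/`content` schema is the usual chat-template convention and is an assumption here, as is the helper `split_prompt_and_response`; consult the example datasets in the repository for the exact format.

```python
import json

# Hypothetical sample; the role/content layout is assumed, not confirmed.
record = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
    ]
}


def split_prompt_and_response(sample, input_key="messages"):
    """Treat a trailing assistant turn as the response, the rest as the prompt."""
    messages = sample[input_key]
    if messages and messages[-1]["role"] == "assistant":
        return messages[:-1], messages[-1]["content"]
    return messages, None


line = json.dumps(record)  # one line of a JSONL dataset file
prompt, response = split_prompt_and_response(json.loads(line))
```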

Logging Arguments

| Argument | Default | Description |
| --- | --- | --- |
| `--logging_steps` | `10` | Log every N steps |
| `--use_wandb` | `False` | Enable W&B logging |
| `--wandb_org` | `None` | W&B organization name |
| `--wandb_project` | `None` | W&B project name |
| `--wandb_group` | `None` | W&B group name |
| `--wandb_run_name` | `None` | W&B run name |
| `--wandb_mode` | `online` | W&B mode (online / offline / disabled) |
| `--wandb_dir` | `None` | Directory to store W&B offline logs |

🧩 Extending KDFlow

Adding a Custom KD Algorithm

Create a new file in kdflow/algorithms/ and register it:

import torch
from kdflow.loss import LOSS_DICT
from kdflow.algorithms import register_algorithm


@register_algorithm("my_custom_kd")
class MyCustomKD:
    def __init__(self, strategy, student_model, teacher_lm_head, **kwargs):
        self.strategy = strategy
        self.student = student_model
        self.teacher_lm_head = teacher_lm_head
        self.loss_fn = LOSS_DICT[strategy.args.kd.loss_fn]

    def training_step(self, micro_batch):
        # Access student inputs
        student_input_ids = micro_batch["stu_input_ids"]
        student_attn_mask = micro_batch["stu_attn_mask"]
        student_loss_mask = micro_batch["stu_loss_mask"].bool()
        teacher_hiddens = micro_batch["teacher_hiddens"]
        avg_token_num = micro_batch["avg_micro_batch_token_num"]

        # Student forward
        output = self.student(student_input_ids, attention_mask=student_attn_mask, return_output=True)
        student_logits = output["logits"][student_loss_mask]

        # Teacher logits from hidden states + lm_head
        teacher_logits = self.teacher_lm_head(teacher_hiddens.to(self.teacher_lm_head.weight))

        # Compute your custom loss
        kd_loss = self.loss_fn(student_logits, teacher_logits, temperature=1.0)
        kd_loss = kd_loss.sum() / avg_token_num

        return {"loss": kd_loss, "kd_loss": kd_loss}

Then use it with --kd_algorithm my_custom_kd.
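For readers unfamiliar with decorator-based registration, the pattern behind a call like `@register_algorithm("my_custom_kd")` can be sketched as follows. This is a generic illustration with hypothetical names (`REGISTRY`, `register`, `build`, `ToyKD`); it does not reproduce KDFlow's actual registry code.

```python
from typing import Callable, Dict

# Hypothetical stand-in for a name-to-object registry such as LOSS_DICT.
REGISTRY: Dict[str, Callable] = {}


def register(name: str) -> Callable[[Callable], Callable]:
    """Return a decorator that stores the decorated object under `name`."""
    def decorator(obj: Callable) -> Callable:
        if name in REGISTRY:
            raise KeyError(f"{name!r} is already registered")
        REGISTRY[name] = obj
        return obj  # hand the object back unchanged
    return decorator


@register("toy_kd")
class ToyKD:
    """Placeholder algorithm exposing the same training_step entry point."""

    def training_step(self, micro_batch):
        return {"loss": 0.0}


def build(name: str, *args, **kwargs):
    """Look up a registered class or function by name and call it."""
    return REGISTRY[name](*args, **kwargs)
```

A command-line value such as `--kd_algorithm toy_kd` would then only need to be passed to a lookup like `build` to select the implementation.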

Adding a Custom KD Loss

Create a new file in kdflow/loss/ and register it:

import torch
import torch.nn.functional as F 

from kdflow.loss import register_loss


@register_loss("my_custom_loss")
@torch.compile()
def compute_kl_div(
    student_logits,
    teacher_logits, 
    temperature=1.0,
    reduction="none",
    **kwargs
):
    student_logits = student_logits / temperature
    teacher_logits = teacher_logits / temperature
    log_probs = torch.log_softmax(student_logits, -1, dtype=torch.float32)
    target_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)
    kl_div = F.kl_div(log_probs, target_probs, reduction=reduction).sum(-1)
    
    return kl_div

Then use it with --kd_loss_fn my_custom_loss.


🔑 Design Highlights

GPU Co-location via Sleep/Wakeup

KDFlow enables teacher and student to share the same GPUs through a sleep/wakeup mechanism:

  1. Teacher phase: Teacher model weights are loaded on GPU, student optimizer states are offloaded to CPU.
  2. Student phase: Student optimizer states are reloaded to GPU, teacher model weights are offloaded to CPU.

This allows running large teacher models (e.g., 200B+ parameters) on the same hardware as the student without requiring separate GPU pools.
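The alternation between the two phases can be pictured with a toy state machine. The sketch below only tracks where each component lives; it moves no real tensors, and every name in it (`Placement`, `enter_teacher_phase`, `enter_student_phase`, `run_steps`) is hypothetical rather than part of KDFlow.

```python
from dataclasses import dataclass


@dataclass
class Placement:
    """Where the two large memory consumers currently reside."""
    teacher_weights: str = "cpu"
    student_optimizer: str = "gpu"


def enter_teacher_phase(p: Placement) -> None:
    # Free GPU memory first, then bring the teacher in.
    p.student_optimizer = "cpu"
    p.teacher_weights = "gpu"


def enter_student_phase(p: Placement) -> None:
    # Put the teacher to sleep, then restore the optimizer states.
    p.teacher_weights = "cpu"
    p.student_optimizer = "gpu"


def run_steps(num_steps: int):
    """Return the placement observed in each phase of each training step."""
    p = Placement()
    trace = []
    for _ in range(num_steps):
        enter_teacher_phase(p)
        trace.append(("teacher", p.teacher_weights, p.student_optimizer))
        enter_student_phase(p)
        trace.append(("student", p.teacher_weights, p.student_optimizer))
    return trace
```

In this model the teacher weights and the student optimizer states are never on the GPU at the same time, which is the property that lets both share one device.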

Hidden States Transfer via Shared Memory

[Figure: Knowledge transfer cost]

Instead of transferring full teacher logits (which can be enormous for large vocabularies), KDFlow:

  1. Extracts hidden states from the teacher's last layer via SGLang.
  2. Transfers them to the student via shared memory (zero-copy).
  3. Computes teacher logits on the student side using only the teacher's lm_head weights.

Because a model's hidden size is typically far smaller than its vocabulary size, this greatly reduces memory and communication overhead.
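A quick back-of-the-envelope calculation shows the scale of the saving. The numbers below (hidden size 4096, vocabulary size 131072, 2 bytes per value) are illustrative assumptions, not measurements from KDFlow or properties of any specific model, and the helper names are hypothetical.

```python
def bytes_per_token(width: int, bytes_per_value: int = 2) -> int:
    """Bytes needed to send one token's vector of the given width."""
    return width * bytes_per_value


def transfer_ratio(hidden_size: int, vocab_size: int) -> float:
    """How many times larger a logits vector is than a hidden-state vector."""
    return vocab_size / hidden_size


HIDDEN_SIZE = 4096    # illustrative assumption
VOCAB_SIZE = 131072   # illustrative assumption
TOKENS = 4096         # tokens in one hypothetical batch

hidden_mib = bytes_per_token(HIDDEN_SIZE) * TOKENS / 2**20
logits_mib = bytes_per_token(VOCAB_SIZE) * TOKENS / 2**20
```

Under these assumptions the batch needs 32 MiB as hidden states versus 1024 MiB as full logits, a 32x reduction. The teacher's lm_head weights still have to reach the student side, but that is a one-time cost rather than a per-batch one.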

Token-Based Teacher Load Balancing

The TeacherActorGroup uses a greedy token-based load balancing strategy to distribute micro-batches across teacher actors, ensuring even workload distribution when sequence lengths vary.
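One common greedy scheme for this kind of balancing is to sort micro-batches by token count, longest first, and always hand the next one to the least-loaded actor. The sketch below implements that scheme with a heap. It is an assumption about the general idea, not the actual `TeacherActorGroup` code, and `balance_by_tokens` is a hypothetical name.

```python
import heapq


def balance_by_tokens(token_counts, num_actors):
    """Assign micro-batch indices to actors, longest first, least-loaded actor next."""
    heap = [(0, actor) for actor in range(num_actors)]  # (current load, actor id)
    heapq.heapify(heap)
    assignment = [[] for _ in range(num_actors)]
    order = sorted(range(len(token_counts)),
                   key=lambda i: token_counts[i], reverse=True)
    for idx in order:
        load, actor = heapq.heappop(heap)  # actor with the fewest tokens so far
        assignment[actor].append(idx)
        heapq.heappush(heap, (load + token_counts[idx], actor))
    loads = [sum(token_counts[i] for i in batch) for batch in assignment]
    return assignment, loads


assignment, loads = balance_by_tokens([900, 100, 500, 400, 300, 200], num_actors=2)
```

In this example both actors end up with 1200 tokens, whereas splitting the list in order (first three, last three) would give 1500 and 900.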


🙏 Acknowledgement

KDFlow stands on the shoulders of outstanding open-source projects. We sincerely thank:

  • SGLang: We deeply appreciate its support for extracting hidden states during model inference and its exceptional inference efficiency, both of which are critical to KDFlow's teacher inference pipeline.
  • OpenRLHF: We gratefully adopt its well-designed abstractions for model wrapping and distributed training strategy, which form the foundation of our training infrastructure.
  • slime: We appreciate its elegant implementation of Ray placement group initialization and the weight update mechanism for SGLang, which greatly inspired our design of on-policy distillation.

📖 Citation

If you find KDFlow useful in your research or work, please consider citing our paper:

@article{zhang2026kdflow,
      title={KDFlow: A User-Friendly and Efficient Knowledge Distillation Framework for Large Language Models}, 
      author={Songming Zhang and Xue Zhang and Tong Zhang and Bojie Hu and Yufeng Chen and Jinan Xu},
      year={2026},
      eprint={2603.01875},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2603.01875}, 
}

📄 License

This project is licensed under the MIT License.


💬 WeChat Group

You are welcome to join our WeChat group for discussion.
