A PyTorch-based implementation for training smaller student models to mimic larger teacher models for code generation, with a focus on Python code. This project implements knowledge distillation techniques to create efficient models capable of speculative decoding.
This project aims to create a lightweight student model that can perform speculative decoding by learning from a larger teacher model (CodeLlama). The core idea is to train a smaller, faster model that can generate draft sequences which are then validated by the larger teacher model, potentially speeding up inference.
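The draft-then-verify loop can be sketched with toy callables. `draft_next` and `verify` below are illustrative stand-ins for the student and teacher (not this project's API), and real speculative decoding accepts draft tokens probabilistically rather than greedily, but a greedy version shows the idea:

```python
def speculative_decode(draft_next, verify, prompt, k=4, max_new=16):
    """Greedy sketch of speculative decoding.

    draft_next(tokens) -> next token proposed by the small student model.
    verify(tokens)     -> the token the teacher would emit after `tokens`.
    """
    tokens = list(prompt)
    while len(tokens) < len(prompt) + max_new:
        # Student drafts up to k tokens cheaply.
        draft = []
        for _ in range(k):
            draft.append(draft_next(tokens + draft))
        # Teacher checks each drafted token; keep the agreeing prefix.
        accepted = 0
        for i in range(len(draft)):
            if verify(tokens + draft[:i]) == draft[i]:
                accepted += 1
            else:
                break
        tokens.extend(draft[:accepted])
        if accepted < len(draft):
            # On the first mismatch, take the teacher's token instead.
            tokens.append(verify(tokens))
    return tokens
```

The speedup comes from the accept rate: each teacher call can validate up to `k` cheap draft tokens, so the better the student mimics the teacher, the fewer teacher forward passes are needed per generated token.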
- Student Model: Custom transformer architecture with multi-headed attention
- Teacher Model: CodeLlama-7b-Python-hf (quantized for efficiency)
- Knowledge Distillation: Custom loss combining cross-entropy and KL divergence
- Dataset: Python code from DaniilOr/humanized_cleaned_code
├── distilation_model/
│ └── studentmodel.py # Custom transformer student model
├── tutor_model/
│ └── codellama.py # CodeLlama teacher model wrapper
├── loss/
│ └── customloss.py # Knowledge distillation loss function
├── dataset/
│ └── dataset.py # Python code dataset handler
├── metric_manager/
│ └── metric_manager.py # Training metrics tracking
├── research/ # Experimental notebooks and scripts
└── training.py # Main training loop
The student model (`StudentModel`) implements:
- Multi-headed attention mechanism with parallel computation
- Positional embeddings using sinusoidal encoding
- Custom transformer layers optimized for code generation
- Significantly smaller parameter count compared to CodeLlama
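As one illustration, the sinusoidal positional encoding used in the student model can be computed as follows (a generic sketch of the standard scheme, not necessarily the exact code in `studentmodel.py`):

```python
import math

import torch


def sinusoidal_positions(max_len: int, d_model: int) -> torch.Tensor:
    """Standard sinusoidal positional encoding: sin on even dims, cos on odd."""
    position = torch.arange(max_len).unsqueeze(1).float()
    # Geometric progression of inverse frequencies across the embedding dims.
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # shape: (max_len, d_model)
```

Because the encoding is fixed rather than learned, it adds no parameters, which suits the goal of keeping the student small.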
- Model: codellama/CodeLlama-7b-Python-hf
- Quantization: 4-bit quantization for memory efficiency
- Frozen parameters: Used only for generating target logits
- Custom generation: Implements temperature-controlled generation
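At its core, temperature-controlled generation just scales the logits before sampling; a minimal sketch (function name and defaults are illustrative, not the wrapper's actual API):

```python
import torch


def sample_next_token(logits: torch.Tensor, temperature: float = 1.0) -> int:
    """Sample one token id from a 1-D logits vector.

    temperature < 1 sharpens the distribution; > 1 flattens it;
    0 is treated as greedy argmax decoding.
    """
    if temperature == 0:
        return int(torch.argmax(logits))
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))
```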
The training uses a weighted combination of:
- Cross-entropy loss: Traditional supervised learning on ground truth
- KL divergence loss: Knowledge transfer from teacher to student
- Temperature scaling: Softens probability distributions for better knowledge transfer
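A minimal PyTorch sketch of this weighted objective (the signature and defaults are illustrative, not the exact code in `customloss.py`):

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, targets,
                      alpha: float = 0.5, temperature: float = 2.0):
    # Hard-label term: standard cross-entropy against ground-truth tokens.
    ce = F.cross_entropy(student_logits, targets)
    # Soft-label term: KL divergence between temperature-softened
    # distributions; the T^2 factor keeps gradient magnitudes comparable.
    student_soft = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_soft = F.softmax(teacher_logits / temperature, dim=-1)
    kl = F.kl_div(student_soft, teacher_soft, reduction="batchmean") * temperature ** 2
    return alpha * ce + (1 - alpha) * kl
```

With `alpha=1.0` this reduces to plain supervised cross-entropy; lowering `alpha` shifts weight toward matching the teacher's softened distribution.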
loss = α × CrossEntropy(student_logits, targets) + (1 − α) × KL(student_soft, teacher_soft)

The project tracks:
- Total loss: Combined distillation loss
- Cross-entropy loss: Task-specific loss
- KL divergence loss: Knowledge transfer loss
- Perplexity: Language modeling performance
- Sample outputs: Input/output comparisons via TensorBoard
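Of these, perplexity is derived directly from the loss: it is the exponential of the mean per-token cross-entropy (in nats):

```python
import math


def perplexity(mean_cross_entropy: float) -> float:
    # exp of the average per-token cross-entropy (natural-log base).
    return math.exp(mean_cross_entropy)
```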
The research/ directory contains experimental work:
- Student model architecture with multi-headed attention
- Teacher model integration with CodeLlama
- Knowledge distillation loss implementation
- Dataset pipeline for Python code
- Training loop with checkpointing
- Metrics tracking and logging
- Memory-efficient training with quantization
- Hyperparameter tuning and optimization
- Speculative decoding inference pipeline
- Implement speculative decoding inference: Core functionality for using student model as draft generator
- Hyperparameter optimization: Learning rate scheduling, better batch sizes
- Dataset diversity: Include more programming languages or code types
- Inference optimization: Model quantization for student model
- Benchmarking suite: Compare against other code generation models
- Configuration management: YAML/JSON config files instead of hardcoded values
- Distributed training: Multi-GPU support
- Model serving: REST API for inference
- Documentation: API documentation and tutorials
- Memory constraints: Current batch size is limited to 4 due to GPU memory
- Path dependencies: Hardcoded paths need to be made configurable
- Limited metrics: Only basic loss metrics, missing code-specific evaluations
- Onyxia: Providing GPU infrastructure for model training and experimentation
This project implements knowledge distillation for code generation models with the goal of enabling speculative decoding for faster inference.