A Transformer-based model for Genomic Prediction (GP) of maize grain yield, developed for the Genomes to Fields (G2F) Initiative prediction competition. The model predicts hybrid maize yield across diverse field environments by jointly learning from high-dimensional genotypic markers and environmental covariates.
The G2F competition challenges participants to predict grain yield for maize hybrids grown across North American field trials spanning multiple years (2014–2024). The key difficulty is Genotype × Environment (GxE) interaction: the same hybrid performs differently across environments due to soil, weather, management, and their interactions with genetic background. Models are evaluated on macro-averaged per-environment Pearson correlation coefficient (PCC), which rewards consistent within-environment ranking of hybrids rather than absolute yield prediction.
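For reference, the macro-averaged per-environment PCC can be computed roughly as in the following NumPy sketch (illustrative only; the official scoring may differ in edge-case handling, e.g., very small environments):

```python
import numpy as np

def macro_env_pcc(y_true, y_pred, env_ids):
    """Macro-averaged per-environment Pearson correlation (illustrative sketch)."""
    pccs = []
    for env in np.unique(env_ids):
        mask = env_ids == env
        if mask.sum() < 2:  # correlation is undefined for a single hybrid
            continue
        pccs.append(np.corrcoef(y_true[mask], y_pred[mask])[0, 1])
    return float(np.mean(pccs))
```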
| Split | Years | Purpose |
|---|---|---|
| Train | 2014–2022 | Model fitting |
| Validation | 2023 (or LEO holdout) | Early stopping & hyperparameter selection |
| Test | 2024 | Final evaluation |
Each sample consists of:
- Genotype (G): 2,024 SNP marker tokens (encoded as 0 / 0.5 / 1) or GRM-standardized continuous features
- Environment (E): 705 numeric covariates (weather, soil, management) + optional categorical features (irrigation, treatment, previous crop)
- Target: Grain yield (Mg/ha)
Data files are expected under data/ following the maize_data_* directory structure with X_train.csv, y_train.csv, X_test.csv, y_test.csv.
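A minimal loading sketch, assuming a pandas-readable layout; the concrete directory name and the `Year` column used for the split are hypothetical and only illustrate the year-based split from the table above:

```python
import pandas as pd

# Hypothetical directory name; real data follow the maize_data_* layout described above.
data_dir = "data/maize_data_gxe"

X_train = pd.read_csv(f"{data_dir}/X_train.csv")  # SNP tokens + env covariates
y_train = pd.read_csv(f"{data_dir}/y_train.csv")  # grain yield (Mg/ha)
X_test  = pd.read_csv(f"{data_dir}/X_test.csv")
y_test  = pd.read_csv(f"{data_dir}/y_test.csv")

# Year-based split: fit on <= 2022, validate on 2023 (the "Year" column is an assumption)
train_mask = y_train["Year"] <= 2022
val_mask   = y_train["Year"] == 2023
```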
| Model | Class | Description |
|---|---|---|
| FullTransformer | `FullTransformer` | Primary architecture. Tokenizes all inputs (CLS + SNP markers + env features) into a single sequence, applies sinusoidal positional encoding, and processes the sequence through Pre-LN transformer blocks. A CLS token pools the final representation for yield regression. |
| FullTransformerResidual | `FullTransformerResidual` | Extends FullTransformer with a residual decomposition head: predicts the environment-mean yield and a genotype-specific residual separately. |
| GxE_Transformer | `GxE_Transformer` | Three-prong encoder architecture with independent G, E, and LD encoders fused via learned weighted gating, then processed by a GxE interaction module (transformer, MLP, or CNN). |
| GxE_ResidualTransformer | `GxE_ResidualTransformer` | Adds residual yield decomposition to the three-prong architecture. |
| Component | Description |
|---|---|
| G_Encoder | Transformer encoder for genotype tokens with CLS pooling. Supports dense or MoE blocks. |
| E_Encoder | Multi-layer MLP with residual connections for environmental covariates. |
| LD_Encoder | 1D ResNet CNN processing one-hot encoded markers for linkage-disequilibrium patterns. |
| MoE Layer | Mixture-of-Experts with Top-K gating, optional shared expert, and load-balancing loss. |
| FlashAttention | Bidirectional self-attention via F.scaled_dot_product_attention. |
| Stochastic Depth | Linearly increasing drop-path rate across transformer blocks. |
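As a rough illustration of how the Pre-LN blocks, SDPA-backed attention, and stochastic depth fit together, here is a minimal sketch; it is not the code in models/transformer.py, and the default sizes just echo the hyperparameter table below:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PreLNBlock(nn.Module):
    """Minimal Pre-LN transformer block: SDPA attention + MLP with drop path.
    Illustrative sketch only; the repo's implementation may differ."""

    def __init__(self, dim=256, heads=4, dropout=0.15, drop_path=0.0):
        super().__init__()
        self.heads = heads
        self.norm1 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(4 * dim, dim),
        )
        self.drop_path = drop_path

    def _drop_path(self, x):
        # Stochastic depth: randomly drop the whole residual branch per sample
        if not self.training or self.drop_path == 0.0:
            return x
        keep = 1.0 - self.drop_path
        mask = x.new_empty((x.shape[0], 1, 1)).bernoulli_(keep) / keep
        return x * mask

    def forward(self, x):  # x: (batch, tokens, dim)
        b, t, d = x.shape
        h = self.norm1(x)
        q, k, v = self.qkv(h).chunk(3, dim=-1)
        # (batch, heads, tokens, head_dim) layout lets SDPA use its fused kernels
        q, k, v = (z.view(b, t, self.heads, d // self.heads).transpose(1, 2)
                   for z in (q, k, v))
        attn = F.scaled_dot_product_attention(q, k, v)  # bidirectional, no mask
        attn = attn.transpose(1, 2).reshape(b, t, d)
        x = x + self._drop_path(self.proj(attn))
        x = x + self._drop_path(self.mlp(self.norm2(x)))
        return x
```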
The three-prong model supports three GxE fusion strategies (set via GXE_ENC):
- `tf` — Transformer blocks over the concatenated G+E token sequence (default)
- `mlp` — Residual MLP blocks over mean-pooled representations
- `cnn` — 1D ResNet blocks over the concatenated sequence
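Before the selected GxE backend, the three prong outputs are combined by learned weighted gating when `WG=True`. A minimal sketch of one way such gating could look, assuming a simple softmax over per-prong weights (the actual fusion in models/model.py may differ):

```python
import torch
import torch.nn as nn

class WeightedGatedFusion(nn.Module):
    """Illustrative weighted gating over the G, E, and LD prong outputs."""

    def __init__(self, n_prongs=3):
        super().__init__()
        # One learnable gate logit per prong
        self.gate_logits = nn.Parameter(torch.zeros(n_prongs))

    def forward(self, g_emb, e_emb, ld_emb):                # each: (batch, dim)
        stacked = torch.stack([g_emb, e_emb, ld_emb], dim=1)  # (batch, 3, dim)
        weights = torch.softmax(self.gate_logits, dim=0)      # (3,)
        return (weights.view(1, -1, 1) * stacked).sum(dim=1)  # (batch, dim)
```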
Training uses PyTorch DDP across multiple AMD MI250X GPUs on ORNL Frontier. Key settings are configured as environment variables in train.slurm.
| Parameter | Variable | Default | Description |
|---|---|---|---|
| Global batch size | `GBS` | 1024 | Divided evenly across GPUs |
| Learning rate | `LR` | 1e-4 | AdamW with cosine schedule |
| Weight decay | `WEIGHT_DECAY` | 1e-5 | AdamW regularization |
| Embedding dim | `EMB_SIZE` | 256 | Transformer hidden dimension |
| Attention heads | `HEADS` | 4 | Multi-head attention |
| GxE layers | `GXE_LAYERS` | 1 | Transformer depth for GxE interaction |
| Dropout | `DROPOUT` | 0.15 | Applied throughout |
| Early stopping | `EARLY_STOP` | 200 | Patience (epochs) |
| Max epochs | `NUM_EPOCHS` | 3000 | Upper bound for training |
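The cosine schedule lives in utils/get_lr.py; as a hedged sketch, a plain cosine decay from `LR` down to a small floor over `NUM_EPOCHS` looks like this (warmup and the exact floor value are assumptions, not taken from the repo):

```python
import math

def cosine_lr(epoch, max_epochs=3000, base_lr=1e-4, min_lr=1e-6):
    """Cosine decay from base_lr to min_lr over max_epochs (illustrative only)."""
    progress = min(epoch / max_epochs, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```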
Composite losses are specified via LOSS (e.g., "envpcc", "mse+envpcc") with per-component weights in LOSS_WEIGHTS.
| Loss | Description |
|---|---|
| `envpcc` | 1 − macro-averaged per-environment Pearson correlation (with Fisher z-transform) |
| `pcc` | 1 − global Pearson correlation |
| `mse` | Mean squared error |
| `envspearman` | 1 − macro-averaged per-environment Spearman correlation (differentiable via soft ranks) |
| `spearman` | 1 − global differentiable Spearman correlation |
| `envmse` | Macro-averaged per-environment MSE |
| `ktau` | Differentiable Kendall-tau ranking loss |
| `xi` | Chatterjee's Xi coefficient loss |
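As an illustration of the primary objective, a differentiable `envpcc` sketch (1 − macro-averaged per-environment Pearson r, averaged in Fisher-z space); edge-case handling in utils/loss.py may differ:

```python
import torch

def envpcc_loss(pred, target, env_ids, eps=1e-8):
    """Sketch: 1 - macro-averaged per-environment Pearson correlation,
    averaged via the Fisher z-transform. Assumes every environment in the
    batch has >= 2 samples (cf. ENV_STRATIFIED batching)."""
    zs = []
    for env in torch.unique(env_ids):
        mask = env_ids == env
        if mask.sum() < 2:
            continue
        p = pred[mask] - pred[mask].mean()
        t = target[mask] - target[mask].mean()
        r = (p * t).sum() / (p.norm() * t.norm() + eps)
        r = r.clamp(-1 + 1e-6, 1 - 1e-6)
        zs.append(torch.atanh(r))                 # Fisher z-transform
    mean_r = torch.tanh(torch.stack(zs).mean())   # back-transform the mean
    return 1.0 - mean_r
```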
Auxiliary contrastive objectives align learned representations with known genetic/environmental similarity structures.
| Variable | Options | Description |
|---|---|---|
| `CONTRASTIVE_MODE` | `none`, `g`, `e`, `g+e` | Which contrastive heads to activate |
| `CONTRASTIVE_SIM_TYPE` | `grm`, `ibs` | Genetic similarity metric |
| `CONTRASTIVE_LOSS_TYPE` | `mse`, `cosine`, `kl` | How to match embedding vs. target similarity |
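As one plausible instantiation of the `mse` matching option, the genotype-side objective could align within-batch embedding similarity to the corresponding GRM entries; this is a hedged sketch, not the code in utils/loss.py:

```python
import torch
import torch.nn.functional as F

def grm_alignment_loss(embeddings, grm_block):
    """Sketch: make within-batch cosine similarity of genotype embeddings match
    the GRM entries for the same hybrids (grm_block: (batch, batch) tensor)."""
    z = F.normalize(embeddings, dim=-1)  # (batch, dim)
    emb_sim = z @ z.T                    # (batch, batch) cosine similarities
    return F.mse_loss(emb_sim, grm_block)
```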
| Strategy | Variable | Description |
|---|---|---|
| Year-based | default | Train ≤ 2022, validate on 2023 |
| LEO | `LEO_VAL=True` | Leave-Environment-Out: hold out 15% of environments (all years) |
| Env-stratified batching | `ENV_STRATIFIED=True` | Ensures each batch contains ≥ `MIN_SAMPLES_PER_ENV` samples per environment for stable env-wise loss computation |
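For the LEO strategy, holding out 15% of environments across all years could be sketched as follows (the exact selection logic in utils/dataset.py is not shown here and may differ):

```python
import numpy as np

def leo_split(env_ids, holdout_frac=0.15, seed=0):
    """Leave-Environment-Out sketch: hold out a random fraction of environments
    (across all years) for validation. Returns boolean train/val masks."""
    rng = np.random.default_rng(seed)
    envs = np.unique(env_ids)
    n_holdout = max(1, int(round(holdout_frac * len(envs))))
    val_envs = rng.choice(envs, size=n_holdout, replace=False)
    val_mask = np.isin(env_ids, val_envs)
    return ~val_mask, val_mask
```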
```bash
sbatch train.slurm
```

Edit the environment variables at the top of train.slurm to configure the model and hyperparameters. Training logs, checkpoints, and evaluation results are saved to logs/, checkpoints/, and data/results/ respectively. Metrics are tracked via Weights & Biases.
```bash
# Architecture selection
FULL_TRANSFORMER=True    # Use FullTransformer (vs. three-prong GxE)
G_ENC=True               # Enable genotype encoder (three-prong only)
E_ENC=True               # Enable environment encoder (three-prong only)
LD_ENC=True              # Enable LD encoder (three-prong only)
GXE_ENC=tf               # GxE backend: tf, mlp, cnn
WG=True                  # Weighted gating for three-prong fusion

# Genotype input
G_INPUT_TYPE=tokens      # "tokens" (discrete SNPs) or "grm" (continuous)

# MoE settings
G_ENCODER_TYPE=moe       # "dense" or "moe"
MOE_NUM_EXPERTS=8
MOE_TOP_K=2
MOE_SHARED_EXPERT=True
```

```
GP-Transformer/
├── models/
│ ├── config.py # Config dataclass (all hyperparameters)
│ ├── model.py # FullTransformer, GxE_Transformer, and residual variants
│ ├── transformer.py # Attention, MLP, TransformerBlock, MoE blocks, G_Encoder
│ ├── moe.py # Mixture-of-Experts layer (Top-K gating, shared expert)
│ ├── mlp.py # MLP blocks, E_Encoder
│ └── cnn.py # 1D ResNet blocks, LD_Encoder
├── scripts/
│ ├── train.py # DDP training loop with checkpointing
│ ├── eval.py # Checkpoint evaluation, metrics, and result export
│ ├── train_residual.py # Training script for residual model variants
│ └── train_rolling.py # Rolling-window temporal training
├── utils/
│ ├── dataset.py # GxE_Dataset, data loading, env-stratified sampling
│ ├── loss.py # All loss functions (envpcc, contrastive, ranking, etc.)
│ ├── utils.py # CLI parsing, run naming, samplers, helpers
│ └── get_lr.py # Cosine learning rate schedule
├── data/ # Competition datasets (not tracked)
├── train.slurm # Primary SLURM job script
├── best_train.slurm # Best-config training script
└── notebooks/ # Analysis notebooks
```
- Python 3.9+
- PyTorch 2.0+ (with ROCm support for AMD GPUs)
- `wandb`, `scipy`, `scikit-learn`, `python-dotenv`, `tqdm`
- Optional: `torchsort` (for the differentiable Spearman loss)
Developed and tested on ORNL Frontier — AMD MI250X GPUs with ROCm 6.3.1, NCCL/RCCL backend for distributed training.