A research framework for pretraining and evaluating ECG neural networks. Supports multiple architectures, training objectives, and data representations with distributed training out of the box. Prepare datasets with ecg_preprocess before use.
Status: Beta.
We use PyTorch 2.9 with CUDA 12.8 and primarily train on H100 GPUs.
```
git clone https://github.com/ELM-Research/ecg_encoder.git
cd ecg_encoder && uv sync
```

For the BPE symbolic representation with ECG-Byte, compile the Rust tokenizer:

```
cd src/dataloaders/data_representation/bpe
maturin develop --release
```

If Rust is not installed:

```
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- --default-toolchain=1.82.0 -y
```
Set `DATA_DIR` in `src/configs/constants.py` to your preprocessed data directory, which contains one or more of the following subdirectories: `mimic_iv`, `ptb_xl`, `code15`, `cpsc`, `csn`. These can be produced with the ecg_preprocess repository.
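A quick way to sanity-check the expected layout before launching a run (a sketch only: the `/tmp` path and the `mkdir` calls merely simulate ecg_preprocess output for illustration):

```shell
# Simulate a preprocessed data directory (stand-in for ecg_preprocess output),
# then count which expected dataset subdirectories are present.
DATA_DIR=/tmp/ecg_data_example   # example path, not prescribed by the repo
mkdir -p "$DATA_DIR/mimic_iv" "$DATA_DIR/ptb_xl"
found=0
for d in mimic_iv ptb_xl code15 cpsc csn; do
  if [ -d "$DATA_DIR/$d" ]; then
    echo "found: $d"
    found=$((found + 1))
  fi
done
echo "datasets present: $found"
```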
| Dataset | Key |
|---|---|
| PTB-XL | ptb_xl |
| MIMIC-IV | mimic_iv |
| CODE 15 | code15 |
| CPSC | cpsc |
| CSN | csn |
| `--data_representation` | Description |
|---|---|
| `signal` | Raw ECG matrix |
| `bpe_symbolic` | BPE-tokenized symbolic sequence |
| Model | `--neural_network` | `--objective` | Representation |
|---|---|---|---|
| DiT | `trans_continuous_dit` | `ddpm`, `rectified_flow` | `signal` |
| NEPA | `trans_continuous_nepa` | `autoregressive` | `signal` |
| Decoder Transformer | `trans_discrete_decoder` | `autoregressive` | `bpe_symbolic` |
| MAE ViT | `mae_vit` | `mae` | `signal` |
| MERL | `merl` | `merl` | `signal` |
| MLAE | `mlae` | `mlae` | `signal` |
| MTAE | `mtae` | `mtae` | `signal` |
| ST-MEM | `st_mem` | `st_mem` | `signal` |
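The valid model/objective pairings in the table can be captured as a small lookup, e.g. for wrapper scripts (a convenience sketch; the repository itself does not ship this helper):

```shell
# Map each --neural_network value to its valid --objective values,
# per the table above. Illustrative only, not part of the repo.
valid_objectives() {
  case "$1" in
    trans_continuous_dit)   echo "ddpm rectified_flow" ;;
    trans_continuous_nepa)  echo "autoregressive" ;;
    trans_discrete_decoder) echo "autoregressive" ;;
    mae_vit)                echo "mae" ;;
    merl)                   echo "merl" ;;
    mlae)                   echo "mlae" ;;
    mtae)                   echo "mtae" ;;
    st_mem)                 echo "st_mem" ;;
    *)                      return 1 ;;
  esac
}
valid_objectives trans_continuous_dit   # -> ddpm rectified_flow
```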
DiT with rectified flow on 8 GPUs:

```
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
uv run torchrun --standalone --nproc_per_node=8 \
src/pretrain_encoder.py \
--data mimic_iv \
--data_representation signal \
--objective rectified_flow \
--neural_network trans_continuous_dit \
--task pretrain \
--batch_size 64 --distributed --ema
```

Decoder transformer with BPE tokens:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 \
uv run torchrun --standalone --nproc_per_node=4 \
src/pretrain_encoder.py \
--data mimic_iv \
--data_representation bpe_symbolic \
--objective autoregressive \
--neural_network trans_discrete_decoder \
--task pretrain \
--batch_size 64 --distributed
```

Generation:
```
uv run src/eval_encoder.py \
--data mimic_iv \
--nn_ckpt path/to/checkpoint.pt \
--data_representation signal \
--objective rectified_flow \
--neural_network trans_continuous_dit \
--task generation --ema
```

Reconstruction:
```
uv run src/eval_encoder.py \
--data mimic_iv \
--nn_ckpt path/to/checkpoint.pt \
--data_representation signal \
--objective rectified_flow \
--neural_network trans_continuous_dit \
--task reconstruction --ema
```

Forecasting:
```
uv run src/eval_encoder.py \
--data mimic_iv \
--nn_ckpt path/to/checkpoint.pt \
--data_representation bpe_symbolic \
--objective autoregressive \
--neural_network trans_discrete_decoder \
--task forecasting --forecast_ratio 0.5
```

```
# Lead-conditioned
... --condition lead --condition_lead 1

# Text-conditioned
... --condition text --text_feature_extractor Qwen/Qwen3-0.6B
```

Text-conditioned 12-lead ECG generation (DiT + DDPM):
ECG forecasting (decoder transformer + BPE tokens, 50% context):
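The conditioning flags append to the commands shown earlier; as a sketch, a lead-conditioned generation invocation could be assembled like this (the command is composed and printed rather than executed here, and `path/to/checkpoint.pt` is the same placeholder used above):

```shell
# Compose a lead-conditioned generation command from the pieces above.
# Printed, not executed; the checkpoint path is a placeholder.
CMD="uv run src/eval_encoder.py \
  --data mimic_iv \
  --nn_ckpt path/to/checkpoint.pt \
  --data_representation signal \
  --objective rectified_flow \
  --neural_network trans_continuous_dit \
  --task generation --ema \
  --condition lead --condition_lead 1"
echo "$CMD"
```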
| Flag | Description |
|---|---|
| `--distributed` | Multi-GPU DDP training |
| `--ema` | Exponential moving average |
| `--torch_compile` | `torch.compile` the model |
| `--wandb` | Log to Weights & Biases |
| `--augment` | ECG data augmentation |
| `--optimizer` | `adam`, `adamw`, `muon` |
| `--lr_schedule` | `constant`, `cosine`, `inv_sqrt` |
| `--condition` | `text`, `lead`, or omit for unconditional |
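As a sketch of how these flags combine, here is the earlier 8-GPU DiT pretraining command with the optional flags from the table filled in (composed and printed rather than launched; the optimizer and schedule choices are just examples from the listed options):

```shell
# Hedged sketch: one possible combination of the optional flags with the
# DiT pretraining command shown earlier. Printed, not executed.
FLAGS="--distributed --ema --torch_compile --wandb --augment"
FLAGS="$FLAGS --optimizer adamw --lr_schedule cosine"
CMD="uv run torchrun --standalone --nproc_per_node=8 \
  src/pretrain_encoder.py \
  --data mimic_iv --data_representation signal \
  --objective rectified_flow --neural_network trans_continuous_dit \
  --task pretrain --batch_size 64 $FLAGS"
echo "$CMD"
```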
MIT, except `st_mem.py`, `mlae.py`, and `mtae.py`, which are CC BY-NC 4.0.

