A multi-task learning model for predicting various material properties.
The FlexibleMultiTaskModel is designed with a modular and extensible architecture. At its core, it features:
- A Foundation Encoder that processes input features (formula-based, and optionally structure-based) to generate shared representations. This encoder includes mechanisms for multi-modal fusion if structural data is provided.
- A Tanh Activation that is uniformly applied to latent representations at the model level, providing bounded outputs to task heads.
- A collection of Task-specific Heads that take Tanh-activated latent representations from the foundation encoder to make predictions for various tasks, such as:
- Regression (e.g., predicting band gap)
- Classification (e.g., predicting material stability)
- Sequence Prediction (e.g., predicting density of states curves)
Below is a high-level overview of the architecture:
graph TD
%% ---------- Inputs (same level) ----------
subgraph InputsLayer["Inputs"]
direction TB
GeneralInputs["Formula / Structure<br/>(x_formula, x_structure*)<br/>*optional"]
SequenceDataInputs["Sequence Data<br/>(task_sequence_* data)<br/>*optional"]
end
%% ---------- Foundation encoder ----------
FE["Foundation Encoder<br/>(Shared MLP, Fusion*, Deposit)<br/>*optional"]
%% ---------- Task heads ----------
NonSeqHeads["Regression / Classification Heads"]
SeqHeads["Sequence Heads"]
%% ---------- Edges ----------
GeneralInputs --> FE
FE -- "h_task (for Reg/Class)" --> NonSeqHeads
FE -- "h_task (for Seq)" --> SeqHeads
SequenceDataInputs --> SeqHeads
NonSeqHeads --> Outputs["Outputs (Dictionary)"]
SeqHeads --> Outputs
%% ---------- Styles ----------
classDef io fill:#E0EFFF,stroke:#5C9DFF,stroke-width:2px,color:#000;
classDef main fill:#DFF0D8,stroke:#77B55A,stroke-width:2px,color:#000;
classDef heads fill:#FCF8E3,stroke:#F0AD4E,stroke-width:2px,color:#000;
%% ---------- Class assignments ----------
class GeneralInputs,SequenceDataInputs io
class FE main
class NonSeqHeads,SeqHeads heads
class Outputs io
For a more detailed diagram and in-depth explanation of each component, data flow, and dimensionality, please refer to the Model Architecture Documentation (ARCHITECTURE.md).
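The data flow above (shared encoder, Tanh-bounded latent, per-task heads returning a dictionary) can be sketched in PyTorch. All class names and dimensions below are illustrative, not the real FlexibleMultiTaskModel API:

```python
import torch
import torch.nn as nn


class MultiTaskSketch(nn.Module):
    """Minimal sketch of the encoder -> Tanh -> task-heads flow (illustrative)."""

    def __init__(self, in_dim: int = 16, latent_dim: int = 32):
        super().__init__()
        # Shared foundation encoder: formula features -> latent representation
        self.encoder = nn.Sequential(nn.Linear(in_dim, latent_dim), nn.ReLU())
        # Task heads consume the Tanh-activated latent
        self.heads = nn.ModuleDict({
            "band_gap": nn.Linear(latent_dim, 1),    # regression
            "stability": nn.Linear(latent_dim, 2),   # classification logits
        })

    def forward(self, x_formula: torch.Tensor) -> dict:
        h = torch.tanh(self.encoder(x_formula))  # bounded latent, as described above
        return {name: head(h) for name, head in self.heads.items()}


model = MultiTaskSketch()
outputs = model(torch.randn(4, 16))  # a dictionary of outputs, one entry per task
```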
- Clone the repository:
git clone https://github.com/yourusername/foundation_model.git
cd foundation_model

- Install the package using uv:

uv sync --frozen --all-groups

This installs all dependencies defined in the pyproject.toml and uv.lock files, including both production and development dependencies, with exact version matching. This method is preferred for reproducible installations.
If you need to add additional dependencies, use:
uv add <package_name>
# or for development dependencies
uv add --dev <package_name>

The primary way to use this model is through the train.py script, which leverages PyTorch Lightning's CLI. This allows for flexible configuration via YAML files and command-line overrides.
To train the model, you will typically use a command like:
# From the project root directory
python -m foundation_model.scripts.train --config path/to/your/config.yaml [OTHER_CLI_OVERRIDES]

Or, if you are in src/foundation_model/scripts/:

python train.py --config path/to/your/config.yaml [OTHER_CLI_OVERRIDES]

- Replace path/to/your/config.yaml with the path to your experiment's configuration file. [OTHER_CLI_OVERRIDES] can be used to override specific parameters in your YAML file (e.g., --trainer.max_epochs=50).
Model configuration is primarily handled through YAML files. These files define the model architecture (FlexibleMultiTaskModel), data loading (CompoundDataModule), PyTorch Lightning trainer settings, and any callbacks.
You can find examples of configuration files in the samples/generated_configs/ directory (e.g., generated_model_config.yaml) and more specific model component configurations in configs/model_configs/ (e.g., base_model.yaml).
For detailed examples of different configurations (such as pre-training, fine-tuning, or using specific model components like different sequence heads) and how to effectively use command-line overrides, please refer to the Quick Examples section below.
- Multi‑task learning for material property prediction
- Dual‑modality support: formula descriptors + optional structure descriptors
- Pre‑training & downstream in one model
- Pre‑train losses: contrastive, cross‑reconstruction, masked‑feature, property supervision
- --pretrain flag toggles extra losses; the same architecture is used for fine-tuning
- Flexible sequence heads: rnn, vec, transformer, tcn, hybrid (Flash‑Attention inside)
- Encoder control: --freeze_encoder to lock shared layers
- Handles missing values via masking & modality dropout
- Comprehensive logging and visualization tools
- Configurable data splitting strategies
- Early stopping and model checkpointing
To train the FlexibleMultiTaskModel on supervised tasks with different loss scales, we rely on a learnable uncertainty term inspired by Kendall, Gal, and Cipolla (CVPR 2018):
- Task heads produce raw losses. Each supervised task $t$ supplies the head-specific loss $\mathcal{L}_t$ (e.g., MSE or cross-entropy).
- Per-task static scaling. Each task configuration exposes loss_weight (default 1.0) to scale that task's raw loss before further combination.
- Optional learnable uncertainty. When enable_learnable_loss_balancer is True, the model maintains a per-task parameter $\log \sigma_t$ and scales the contribution as $\mathcal{L}'_{t} = \tfrac{1}{2}\,\texttt{loss\_weight}_t\,\exp(-2 \log \sigma_t)\,\mathcal{L}_t + \log \sigma_t$. This lets the model down-weight noisier objectives while respecting explicit task priorities.
- Fallback when disabled. If the balancer is disabled or a task does not expose $\log \sigma_t$, the contribution becomes $\mathcal{L}'_{t} = \texttt{loss\_weight}_t \cdot \mathcal{L}_t$.
- Total loss. The overall objective is the sum of all task contributions.
See ARCHITECTURE.md for a deeper walk-through of the loss pipeline and implementation hooks.
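The combination rule above can be sketched as a small helper. Function and argument names here are illustrative (the real hooks live in the model; see ARCHITECTURE.md):

```python
import torch


def combine_task_losses(raw_losses, loss_weights, log_sigmas=None):
    """Sketch of the loss combination described above (illustrative names).

    raw_losses / loss_weights: dicts keyed by task name.
    log_sigmas: optional dict of learnable per-task log-sigma parameters.
    """
    total = torch.zeros(())
    for name, loss in raw_losses.items():
        w = loss_weights.get(name, 1.0)
        if log_sigmas is not None and name in log_sigmas:
            log_sigma = log_sigmas[name]
            # Learnable uncertainty weighting (Kendall, Gal & Cipolla, 2018):
            # L'_t = 0.5 * w_t * exp(-2 log sigma_t) * L_t + log sigma_t
            total = total + 0.5 * w * torch.exp(-2.0 * log_sigma) * loss + log_sigma
        else:
            # Fallback: static scaling only, L'_t = w_t * L_t
            total = total + w * loss
    return total


raw = {"band_gap": torch.tensor(2.0), "stability": torch.tensor(1.0)}
weights = {"band_gap": 0.8}
sigmas = {"band_gap": torch.zeros(())}  # log sigma = 0 -> factor 0.5 * w
total = combine_task_losses(raw, weights, sigmas)  # 0.5*0.8*2.0 + 0 + 1.0 = 1.8
```

With the balancer disabled (log_sigmas=None), the same call reduces to the plain weighted sum.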
- Supports multiple material properties
- Handles missing values through masking
- Configurable data splitting ratios
- Property-specific sampling fractions
When providing data through the attributes_source (typically an attributes.csv file or a Pandas DataFrame), it is essential to configure your tasks to point to the correct data columns. This is done via the data_column field in each task's configuration and, for sequence tasks, the optional steps_column.
1. Primary Data Column (data_column):
- Applies to: RegressionTaskConfig, ClassificationTaskConfig, SequenceTaskConfig.
- Field in Config: data_column: str
- Purpose: Specifies the exact name of the column in your attributes_source file/DataFrame that contains the primary data for the task.
  - For regression and classification tasks: this column holds the target values.
  - For sequence tasks: this column holds the main sequence data (e.g., the y-values of a spectrum or time series).
- Example:

# In your task configuration list:
# - name: "band_gap_prediction"
#   type: REGRESSION
#   data_column: "actual_band_gap_values"  # Points to 'actual_band_gap_values' column in attributes.csv
#   ... other task parameters
#
# - name: "xrd_pattern_analysis"
#   type: SEQUENCE
#   data_column: "xrd_intensity_series"  # Points to 'xrd_intensity_series' column for sequence y-values
#   steps_column: "xrd_two_theta_angles"  # See below
#   ... other task parameters
2. Sequence Steps Column (steps_column):
- Applies to: SequenceTaskConfig only.
- Field in Config: steps_column: str (optional, defaults to "")
- Purpose: Specifies the exact name of the column in your attributes_source that contains the steps or x-axis values corresponding to the sequence data (e.g., temperature points, time steps, 2-theta angles).
- Behavior:
  - If specified and the column exists: the data is loaded and passed to the sequence task head (available in temps_dict within CompoundDataset).
  - If specified but the column does not exist in attributes_source: a ValueError is raised during data loading, since this explicitly requested data is missing.
  - If left as an empty string (default): no specific steps column is loaded. CompoundDataset provides a placeholder (e.g., zeros) for the steps data, and the sequence model head should be prepared to handle this (e.g., by assuming a default step interval like torch.arange(sequence_length)).
- Example:

# - name: "temperature_dependent_property"
#   type: SEQUENCE
#   data_column: "property_vs_temp_series"
#   steps_column: "temperature_points"  # Points to 'temperature_points' column for x-axis of the sequence
#   ...
Important Considerations:
- Exact Column Names: The values provided for data_column and steps_column must exactly match the column headers in your attributes_source data file.
- Data Format in CSV: If your attributes_source is a CSV file and a column contains list-like data (e.g., sequence series, multi-dimensional regression targets, or sequence steps), store it as strings that Python's ast.literal_eval can parse (e.g., "[1.0, 2.5, 3.0]").
- Missing data_column Data: If a data_column is specified in a task config but the column is not found in attributes_source, or the column exists but contains many NaNs, the corresponding samples for that task are masked out (i.e., not used for training or loss calculation for that specific task). Placeholders (e.g., zeros, or -1 for classification) are used for the target values in y_dict.
- If attributes_source is None (typically for prediction scenarios where only input features like formula_desc_source are given):
  - Any task specifying a data_column will have its targets treated as placeholders by CompoundDataset.
  - If a SequenceTaskConfig specifies a non-empty steps_column, CompoundDataModule raises a ValueError, because attributes_source is required to load this essential steps data, even for prediction.
This explicit column mapping approach provides clarity and flexibility in defining how your task configurations link to your data.
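The string-encoded list format can be illustrated with the standard library. The row contents and column names below reuse the hypothetical examples above:

```python
import ast

# Hypothetical row from an attributes.csv where list-like columns are
# stored as strings (column names are from the examples above).
row = {
    "id": "mat_1",
    "xrd_intensity_series": "[1.0, 2.5, 3.0]",
    "xrd_two_theta_angles": "[10, 20, 30]",
}

# Parse the string-encoded lists the way the loader is described to do it
intensities = ast.literal_eval(row["xrd_intensity_series"])
angles = ast.literal_eval(row["xrd_two_theta_angles"])
# intensities -> [1.0, 2.5, 3.0]; angles -> [10, 20, 30]
```

ast.literal_eval only evaluates Python literals, so malformed or non-literal strings raise an error instead of executing arbitrary code.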
The train.py script utilizes PyTorch Lightning's CLI (see official documentation). This allows for comprehensive configuration of the model (FlexibleMultiTaskModel) and data module (CompoundDataModule) through YAML files, with parameters passed directly to their __init__ methods via an init_args block. You can also override these YAML settings using command-line arguments.
You can also adjust tasks programmatically. For example, to swap in two new heads after loading a checkpoint:
model.remove_tasks("old_regression")
model.add_task(new_reg_cfg, new_cls_cfg)  # accepts multiple configs in one call

It's recommended to start with a base YAML configuration (e.g., samples/generated_configs/generated_model_config.yaml or configs/model_configs/base_model.yaml adapted to the init_args structure) and then customize it.
Command-Line Overrides: To override a parameter, you specify its full path. For example:
--model.init_args.shared_block_optimizer.freeze_parameters=True
--trainer.max_epochs=50
Note: Low-Rank Adaptation (LoRA) support has been removed from the codebase. Any legacy configuration keys such as lora_rank or lora_enabled are currently ignored by the model.
This example runs standard supervised training.
python -m foundation_model.scripts.train --config path/to/your/config.yaml \
--trainer.max_epochs 60

Corresponding YAML snippet (config.yaml):
model:
class_path: foundation_model.models.FlexibleMultiTaskModel
init_args:
# ... other shared_block_dims ...
task_configs:
- name: example_task_1
type: REGRESSION
dims: [128, 64, 1]
data_column: my_property
loss_weight: 0.8 # Optional per-task scaling (defaults to 1.0)
# - name: another_task
# ...
# loss_weight: 1.0
trainer:
  max_epochs: 60

This example demonstrates fine-tuning where the main encoder is frozen. This is achieved by setting freeze_parameters: true in the shared_block_optimizer configuration. A sequence task (e.g., 'temp_curve') uses an RNN head.
# Assumes config.yaml is set for fine-tuning and includes a sequence task configured with subtype "rnn".
python -m foundation_model.scripts.train --config path/to/your/config.yaml \
--model.init_args.shared_block_optimizer.freeze_parameters=True

YAML snippet (config.yaml):
# In your config.yaml
# ...
model:
class_path: foundation_model.models.FlexibleMultiTaskModel
init_args:
# ...
shared_block_optimizer:
# ...
freeze_parameters: true # This freezes the shared encoder
task_configs:
- name: "temp_curve" # Example sequence task
type: "SEQUENCE"
subtype: "rnn"
# ... other settings for temp_curve ...
# ... other tasks ...
# ...

Full fine-tune: the encoder is not frozen (freeze_parameters: false). A sequence task uses a Transformer head, configured in YAML.
# Assumes config.yaml is set for fine-tuning.
# The relevant sequence task should be configured with subtype "transformer" in YAML.
python -m foundation_model.scripts.train --config path/to/your/transformer_encoder.yaml \
--model.init_args.shared_block_optimizer.freeze_parameters=False

YAML snippet (transformer_encoder.yaml):
# In transformer_encoder.yaml
# ...
model:
class_path: foundation_model.models.FlexibleMultiTaskModel
init_args:
# ...
shared_block_dims: [128, 256] # Input dimension -> fallback latent dimension
encoder_config:
type: transformer
d_model: 256
num_layers: 4
nhead: 4
dropout: 0.1
use_cls_token: true
apply_layer_norm: true
shared_block_optimizer:
# ...
freeze_parameters: false # Encoder is trainable
task_configs:
- name: "temp_dos_transformer" # Example sequence task
type: "SEQUENCE"
subtype: "transformer" # Key: Use Transformer head
d_in: 256 # Input dimension (Tanh-activated latent from encoder)
d_model: 256 # Transformer d_model for the head
nhead: 4 # Transformer nhead
# ... other transformer parameters (num_encoder_layers, dim_feedforward, etc.)
# ... other settings for this task ...
# ... other tasks ...
# ...

ℹ️ How the Transformer encoder trains tokens
- With use_cls_token: true, the task heads consume the contextualised [CLS] embedding. Even though the other feature tokens are not pooled explicitly, they still receive gradients through the attention connections to the classifier query, because their keys and values inform every [CLS] update.
- Setting use_cls_token: false switches to mean pooling, so every token is exposed directly to the supervised loss without relying on masked pre-training; gradients are distributed evenly across the sequence length.
- Both aggregation modes therefore keep all feature tokens in play for supervised objectives; choose the variant that best matches your task assumptions.
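The two aggregation modes amount to a one-line difference. This is an illustrative sketch, not the model's implementation:

```python
import torch


def aggregate_tokens(tokens: torch.Tensor, use_cls_token: bool) -> torch.Tensor:
    """Sketch of the two aggregation modes described above.

    tokens: (batch, seq_len, d_model); with use_cls_token=True the first
    position is assumed to hold the contextualised [CLS] token.
    """
    if use_cls_token:
        return tokens[:, 0]       # heads consume the [CLS] embedding
    return tokens.mean(dim=1)     # mean pooling exposes every token directly


tokens = torch.randn(2, 5, 8)
cls_out = aggregate_tokens(tokens, use_cls_token=True)    # shape (2, 8)
mean_out = aggregate_tokens(tokens, use_cls_token=False)  # shape (2, 8)
```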
Similar to full fine-tune (encoder trainable). A sequence task uses a 'vector' head, configured in YAML.
# Assumes config.yaml is set for fine-tuning.
# The relevant sequence task should be configured with subtype "vec" in YAML.
python -m foundation_model.scripts.train --config path/to/your/vec_head_config.yaml \
--model.init_args.shared_block_optimizer.freeze_parameters=False

YAML snippet (vec_head_config.yaml):
# In vec_head_config.yaml
# ...
model:
class_path: foundation_model.models.FlexibleMultiTaskModel
init_args:
# ...
shared_block_optimizer:
# ...
freeze_parameters: false # Encoder is trainable
task_configs:
- name: "temp_dos_vector" # Example sequence task
type: "SEQUENCE"
subtype: "vec" # Key: Use fixed vector output head
d_in: 512 # Input dimension (Tanh-activated latent from encoder)
seq_len: 256 # Desired output sequence length for the vector
# ... other vec head parameters ...
# ... other settings for this task ...
# ...

These examples should provide an accurate reflection of how to use train.py with the LightningCLI setup.
This section demonstrates how to train the FlexibleMultiTaskModel using local data files (CSV) and a YAML configuration, highlighting how to explore scaling laws by adjusting data availability for specific tasks using CompoundDataModule's task_masking_ratios.
1. Prepare Dummy Data Files:
Create the following CSV files in your project, for example, under an examples/data/ directory:
- examples/data/dummy_formula_descriptors.csv:

id,comp_feat_1,comp_feat_2
mat_1,0.1,0.5
mat_2,0.2,0.6
mat_3,0.3,0.7
mat_4,0.4,0.8
mat_5,0.5,0.9
mat_6,0.15,0.55
mat_7,0.25,0.65
mat_8,0.35,0.75
mat_9,0.45,0.85
mat_10,0.55,0.95
- examples/data/dummy_attributes.csv: This file defines the tasks, their target values, and the train/validation/test split. Column names are generic; the mapping to tasks is done in the YAML configuration.

id,target_A,series_B_y,series_B_x,split
mat_1,1.0,"[0.1,0.2,0.3]","[10,20,30]",train
mat_2,2.0,"[0.4,0.5,0.6]","[10,20,30]",train
mat_3,3.0,"[0.7,0.8,0.9]","[10,20,30]",train
mat_4,1.5,"[0.15,0.25,0.35]","[10,20,30]",train
mat_5,2.5,"[0.45,0.55,0.65]","[10,20,30]",train
mat_6,3.5,"[0.75,0.85,0.95]","[10,20,30]",train
mat_7,4.0,"[0.9,1.0,1.1]","[10,20,30]",val
mat_8,4.5,"[1.1,1.2,1.3]","[10,20,30]",val
mat_9,5.0,"[1.2,1.3,1.4]","[10,20,30]",test
mat_10,5.5,"[1.3,1.4,1.5]","[10,20,30]",test

Note: For sequence tasks, series data and x-axis (steps) data are represented as strings of lists. CompoundDataset will parse these.
2. Create YAML Configuration File:
Create a YAML file, for example, examples/configs/demo_scaling_law.yaml. This example uses the init_args structure expected by LightningCLI.
# examples/configs/demo_scaling_law.yaml
experiment_name: "scaling_law_demo" # Can be overridden by CLI
seed_everything: 42 # For reproducibility
# --- Model Configuration (for FlexibleMultiTaskModel) ---
model:
class_path: foundation_model.models.FlexibleMultiTaskModel
init_args:
shared_block_dims: [2, 128, 256] # Input (dummy_formula_descriptors.csv has 2 features) -> hidden -> latent
task_configs:
- name: "task_A"
type: "REGRESSION"
data_column: "target_A" # Maps to 'target_A' column in dummy_attributes.csv
dims: [256, 64, 1] # Tanh-activated latent_dim -> hidden -> output
optimizer: { lr: 0.001, scheduler_type: "None" }
- name: "task_B"
type: "SEQUENCE"
subtype: "rnn"
data_column: "series_B_y" # Maps to 'series_B_y' for y-values
steps_column: "series_B_x" # Maps to 'series_B_x' for x-values (steps)
d_in: 256 # Should match the model's latent_dim for sequence heads
hidden: 64
cell: "gru"
optimizer: { lr: 0.001, scheduler_type: "None" }
# Add other model.init_args as needed, e.g.:
# encoder_config:
# type: mlp
# hidden_dims: [128, 256]
# norm: true
# residual: false
shared_block_optimizer: { lr: 0.001, scheduler_type: "None", freeze_parameters: false }
# --- Data Module Configuration (for CompoundDataModule) ---
data: # Renamed from datamodule for consistency with LightningCLI v2.0+ common practice
class_path: foundation_model.data.CompoundDataModule
init_args:
formula_desc_source: "examples/data/dummy_formula_descriptors.csv"
attributes_source: "examples/data/dummy_attributes.csv"
task_configs: ${model.init_args.task_configs} # Dynamically uses task_configs from model
batch_size: 2
num_workers: 0
# train_ratio, val_ratio, test_split are used if 'split' column is NOT in attributes_source
# val_split: 0.1
# test_split: 0.1
# random_seed: 42
# task_masking_ratios: 0.9 # Or provide {"task_A": 0.9, "task_B": 0.5} for per-task control
task_masking_ratios:
task_A: 1.0 # Experiment with this: 1.0, 0.5, 0.25 etc. for task_A
# task_B: 1.0 # Can also apply to other tasks
# --- Trainer Configuration (PyTorch Lightning) ---
trainer:
default_root_dir: "results/logs/${experiment_name}" # Organizes logs by experiment name
max_epochs: 20 # Adjust as needed for a meaningful demo
accelerator: "cpu"
devices: 1
logger:
- class_path: lightning.pytorch.loggers.CSVLogger
init_args:
save_dir: "${trainer.default_root_dir}"
name: "" # Logs will be in ${trainer.default_root_dir}/version_X
- class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
save_dir: "${trainer.default_root_dir}"
name: ""
# callbacks:
# - class_path: lightning.pytorch.callbacks.ModelCheckpoint
# init_args:
# monitor: "val_total_loss" # Or a specific task's validation loss
# mode: "min"
# - class_path: lightning.pytorch.callbacks.EarlyStopping
# init_args:
# monitor: "val_total_loss"
# patience: 5
# mode: "min"

The train.py script, using LightningCLI, will parse this YAML. Ensure train.py is set up to correctly pass init_args to the respective classes.
3. Run Training:
Assuming you have a training script (e.g., train_flexible.py) that uses PyTorch Lightning's CLI or a similar mechanism to parse the YAML and CLI arguments:
python src/foundation_model/scripts/train_flexible.py --config examples/configs/demo_scaling_law.yaml

(If using the existing train.py, it would require significant modification to load FlexibleMultiTaskModel and CompoundDataModule from such a YAML configuration.)
4. Demonstrating Scaling Law with task_masking_ratios:
The task_masking_ratios parameter in CompoundDataModule (set via the YAML) controls the fraction of valid (non-NaN) samples used for each specified task during training. A ratio of 1.0 uses all valid samples, 0.5 uses 50%, and so on. This allows you to simulate different dataset sizes for specific tasks.
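The per-task subsampling can be sketched as follows. The real logic lives inside CompoundDataModule; the function below is an assumption for illustration, not its API:

```python
import numpy as np


def subsample_task_mask(valid_mask: np.ndarray, ratio: float, seed: int = 42) -> np.ndarray:
    """Illustrative sketch of what a task_masking_ratios entry does for one task.

    valid_mask: boolean array marking non-NaN training targets for the task.
    ratio: fraction of those valid samples to keep.
    """
    rng = np.random.default_rng(seed)
    valid_idx = np.flatnonzero(valid_mask)
    n_keep = int(round(len(valid_idx) * ratio))
    keep = rng.choice(valid_idx, size=n_keep, replace=False)
    out = np.zeros_like(valid_mask)  # everything masked out by default
    out[keep] = True                 # re-enable only the sampled subset
    return out


mask = np.array([True] * 8 + [False] * 2)    # 8 valid samples out of 10
half = subsample_task_mask(mask, ratio=0.5)  # keeps 4 of the 8 valid samples
```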
To observe a scaling law for task_A:
- Run 1 (Full Data for task_A): In demo_scaling_law.yaml, ensure task_masking_ratios: { task_A: 1.0 }. Train the model and note the final validation loss for task_A.
- Run 2 (Reduced Data for task_A): Modify demo_scaling_law.yaml to task_masking_ratios: { task_A: 0.5 } (using 50% of task_A's valid training data). Retrain the model (preferably from scratch, to ensure a fair comparison) and note the final validation loss for task_A.
- Run 3 (Further Reduced Data for task_A): Modify to task_masking_ratios: { task_A: 0.2 } (using 20% of task_A's valid training data). Retrain and note the loss.
Expected Observation: Generally, as the task_masking_ratios for task_A decreases (less data used), the final validation loss for task_A is expected to be higher, demonstrating the scaling law principle that model performance often improves with more data. Plotting these losses against the data fraction (1.0, 0.5, 0.2) can visualize this relationship.
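One way to summarize the three runs is a power-law fit in log-log space. The loss values below are purely illustrative placeholders, not real results from this model:

```python
import numpy as np

# Hypothetical final validation losses for task_A from the three runs above
# (illustrative numbers only, not measured results).
fractions = np.array([1.0, 0.5, 0.2])
val_losses = np.array([0.10, 0.14, 0.22])

# Fit loss ~ c * fraction**(-alpha) in log-log space; a positive alpha
# means the loss falls as more training data becomes available.
slope, log_c = np.polyfit(np.log(fractions), np.log(val_losses), 1)
alpha = -slope
```

Plotting val_losses against fractions on log-log axes, with this fitted line overlaid, gives a quick visual check of how close the relationship is to a power law.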
This setup provides a controlled way to study the impact of data quantity on individual task performance within a multi-task learning framework.
Update history has been moved to CHANGES.md.