28 changes: 19 additions & 9 deletions config_files/training/config_lorem_ipsum_long_fsdp2.yaml
@@ -51,12 +51,22 @@ settings:
num_seen_samples: 0
last_step: -1

collate_fn:
component_key: collate_fn
variant_key: gpt_2_llm_collator
collator:
component_key: collator
variant_key: default_wrapping_collator
config:
sample_key: ${settings.referencing_keys.sample_key}
target_key: ${settings.referencing_keys.target_key}
input_keys:
- ${settings.referencing_keys.sample_key}
sample_keys:
- ${settings.referencing_keys.sample_key}
target_keys:
- ${settings.referencing_keys.target_key}
collate_fns:
- component_key: collate_fn
variant_key: autoregressive
config:
sample_key: ${settings.referencing_keys.sample_key}
target_key: ${settings.referencing_keys.target_key}

train_dataset:
component_key: dataset
@@ -95,8 +105,8 @@ train_dataloader:
seed: 42
drop_last: true
skip_num_global_samples: ${settings.training_progress.num_seen_samples}
collate_fn:
instance_key: collate_fn
collator:
instance_key: collator
pass_type: BY_REFERENCE

test_dataset:
@@ -134,8 +144,8 @@ test_dataloader:
dataset:
instance_key: test_dataset
pass_type: BY_REFERENCE
collate_fn:
instance_key: collate_fn
collator:
instance_key: collator
pass_type: BY_REFERENCE

eval_dataloaders:
89 changes: 71 additions & 18 deletions src/modalities/checkpointing/stateful/app_state.py
@@ -15,6 +15,8 @@
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from modalities.utils.logging import get_logger


class StatefulComponents(Enum):
MODEL = "model"
@@ -34,13 +36,19 @@ class AppState(Stateful):
https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
"""

def __init__(self, model: nn.Module, optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None):
def __init__(
self,
model: Optional[nn.Module] = None,
optimizer: Optional[Optimizer] = None,
lr_scheduler: Optional[LRScheduler] = None,
):
"""Initializes the AppState object.

Args:
model (nn.Module): The model can be either a non-sharded model, FSDP1 or FSDP2 model.
optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
model (nn.Module, optional): The model can be either a non-sharded model, FSDP1 or FSDP2 model.
optimizer (Optimizer, optional): The optimizer can be either a non-sharded optimizer,
FSDP1 or FSDP2 optimizer.
lr_scheduler (LRScheduler, optional): The lr scheduler used during training. Defaults to None.
"""
self._model = model
self._optimizer = optimizer
@@ -59,14 +67,41 @@ def is_loaded(self) -> bool:
def model(self) -> nn.Module:
return self._model

@model.setter
def model(self, model: nn.Module) -> None:
"""Sets the model in the AppState object.

Args:
model (nn.Module): The model to set in the AppState object.
"""
self._model = model

@property
def optimizer(self) -> Optimizer:
return self._optimizer

@optimizer.setter
def optimizer(self, optimizer: Optimizer) -> None:
"""Sets the optimizer in the AppState object.

Args:
optimizer (Optimizer): The optimizer to set in the AppState object.
"""
self._optimizer = optimizer

@property
def lr_scheduler(self) -> LRScheduler:
return self._lr_scheduler

@lr_scheduler.setter
def lr_scheduler(self, lr_scheduler: LRScheduler) -> None:
"""Sets the learning rate scheduler in the AppState object.

Args:
lr_scheduler (LRScheduler): The learning rate scheduler to set in the AppState object.
"""
self._lr_scheduler = lr_scheduler

def state_dict(self) -> dict[str, Any]:
"""Returns the state dict of the AppState object.

@@ -76,12 +111,13 @@ def state_dict(self) -> dict[str, Any]:
# this line automatically manages FSDP FQN's, as well as sets the default
# state dict type to FSDP.SHARDED_STATE_DICT
# model_state_dict, optimizer_state_dict = get_state_dict(self._model, self._optimizer)
sd = {
StatefulComponents.MODEL.value: ModelStateRetriever.get_state_dict(app_state=self),
StatefulComponents.OPTIMIZER.value: OptimizerStateRetriever.get_state_dict(
app_state=self,
),
}
sd = {}
if self._model is not None:
sd[StatefulComponents.MODEL.value] = ModelStateRetriever.get_state_dict(app_state=self)

if self._optimizer is not None:
sd[StatefulComponents.OPTIMIZER.value] = OptimizerStateRetriever.get_state_dict(app_state=self)

if self._lr_scheduler is not None:
sd[StatefulComponents.LR_SCHEDULER.value] = LRSchedulerStateRetriever.get_state_dict(app_state=self)
return sd
@@ -101,15 +137,32 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded."
)

ModelStateRetriever.load_state_dict_(app_state=self, state_dict=state_dict[StatefulComponents.MODEL.value])
OptimizerStateRetriever.load_state_dict_(
app_state=self,
state_dict=state_dict[StatefulComponents.OPTIMIZER.value],
)
if self._model is not None:
ModelStateRetriever.load_state_dict_(app_state=self, state_dict=state_dict[StatefulComponents.MODEL.value])

if self._optimizer is not None:
if StatefulComponents.OPTIMIZER.value in state_dict:
OptimizerStateRetriever.load_state_dict_(
app_state=self,
state_dict=state_dict[StatefulComponents.OPTIMIZER.value],
)
else:
get_logger(name="app_state").warning(
"Did not load optimizer checkpoint! "
f"Optimizer state dict not found in state_dict: {state_dict.keys()}."
)

if self._lr_scheduler is not None:
LRSchedulerStateRetriever.load_state_dict_(
app_state=self, state_dict=state_dict[StatefulComponents.LR_SCHEDULER.value]
)
if StatefulComponents.LR_SCHEDULER.value in state_dict:
LRSchedulerStateRetriever.load_state_dict_(
app_state=self, state_dict=state_dict[StatefulComponents.LR_SCHEDULER.value]
)
else:
get_logger(name="app_state").warning(
"Did not load lr scheduler checkpoint! "
f"LR scheduler state dict not found in state_dict: {state_dict.keys()}."
)

self._is_loaded = True


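The AppState changes above make the model, optimizer, and lr scheduler optional, so `state_dict()` and `load_state_dict()` only handle the components that are actually set and skip (or warn about) the rest. A minimal sketch of the resulting behaviour, assuming the state retrievers also accept plain, non-sharded modules as the docstring indicates; the toy model, optimizer, and scheduler below are illustrative only:

```python
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

from modalities.checkpointing.stateful.app_state import AppState

# Illustrative toy model and optimizer; in practice these are the
# (possibly FSDP-wrapped) training components.
model = nn.Linear(16, 16)
optimizer = AdamW(model.parameters(), lr=1e-4)

# Model-only AppState: state_dict() contains only the model entry;
# the optimizer and lr scheduler sections are simply omitted.
model_only = AppState(model=model)
print(model_only.state_dict().keys())

# Components can also be attached after construction via the new setters.
full = AppState(model=model, optimizer=optimizer)
full.lr_scheduler = StepLR(optimizer, step_size=10)
print(full.state_dict().keys())  # now contains model, optimizer, and lr scheduler entries
```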
19 changes: 18 additions & 1 deletion src/modalities/checkpointing/stateful/app_state_factory.py
@@ -35,6 +35,9 @@ def get_raw_app_state(
def get_dcp_checkpointed_app_state_(
raw_app_state: AppState,
checkpoint_dir_path: Path,
load_model_checkpoint: bool = True,
load_optimizer_checkpoint: bool = True,
load_lr_scheduler_checkpoint: bool = True,
) -> AppState:
"""Loads the checkpointed state dict into the raw AppState object
(i.e., non-checkpoint loaded AppState) in-place.
@@ -54,5 +57,19 @@ def get_dcp_checkpointed_app_state_(
"Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded."
)
cp_loading = DCPCheckpointLoading(global_rank=dist.get_rank())
cp_loading.load_checkpoint_(app_state=raw_app_state, checkpoint_dir_path=checkpoint_dir_path)

tmp_app_state = AppStateFactory.get_raw_app_state(
model=raw_app_state.model if load_model_checkpoint else None,
optimizer=raw_app_state.optimizer if load_optimizer_checkpoint else None,
lr_scheduler=raw_app_state.lr_scheduler if load_lr_scheduler_checkpoint else None,
)

cp_loading.load_checkpoint_(app_state=tmp_app_state, checkpoint_dir_path=checkpoint_dir_path)
raw_app_state.model = tmp_app_state.model if tmp_app_state.model is not None else raw_app_state.model
raw_app_state.optimizer = (
tmp_app_state.optimizer if tmp_app_state.optimizer is not None else raw_app_state.optimizer
)
raw_app_state.lr_scheduler = (
tmp_app_state.lr_scheduler if tmp_app_state.lr_scheduler is not None else raw_app_state.lr_scheduler
)
return raw_app_state
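With the temporary AppState introduced above, only the selected components are restored from the DCP checkpoint, while the excluded ones keep their freshly initialized state. A hedged usage sketch, e.g. warm-starting the model but resetting optimizer and LR-scheduler state; the checkpoint path is a placeholder and `raw_app_state` is assumed to be a not-yet-loaded AppState built via `get_raw_app_state`:

```python
from pathlib import Path

from modalities.checkpointing.stateful.app_state_factory import AppStateFactory

# raw_app_state: an AppState built from freshly instantiated model/optimizer/scheduler,
# e.g. via AppStateFactory.get_raw_app_state(...); its construction is not shown here.
app_state = AppStateFactory.get_dcp_checkpointed_app_state_(
    raw_app_state=raw_app_state,
    checkpoint_dir_path=Path("/checkpoints/step_1000"),  # placeholder path
    load_model_checkpoint=True,          # restore model weights
    load_optimizer_checkpoint=False,     # keep the fresh optimizer state
    load_lr_scheduler_checkpoint=False,  # keep the fresh LR scheduler state
)
```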
17 changes: 10 additions & 7 deletions src/modalities/config/config.py
@@ -16,7 +16,7 @@
PydanticAppStateType,
PydanticCheckpointSavingExecutionIFType,
PydanticCheckpointSavingStrategyIFType,
PydanticCollateFnIFType,
PydanticCollatorIFType,
PydanticDatasetIFType,
PydanticDeviceMeshIFType,
PydanticFSDP1CheckpointLoadingIFType,
@@ -313,6 +313,9 @@ class RawAppStateConfig(BaseModel):
class DCPAppStateConfig(BaseModel):
raw_app_state: PydanticAppStateType
checkpoint_dir_path: Path
load_model_checkpoint: bool = True
load_optimizer_checkpoint: bool = True
load_lr_scheduler_checkpoint: bool = True


class PreTrainedHFTokenizerConfig(BaseModel):
@@ -366,6 +369,11 @@ class PackedMemMapDatasetContinuousConfig(BaseModel):
reuse_last_target: bool = Field(default=True)


class MemMapDatasetIterativeConfig(BaseModel):
raw_data_path: Path
sample_key: str


class PackedMemMapDatasetMegatronConfig(BaseModel):
raw_data_path: Path
block_size: Annotated[int, Field(strict=True, gt=1)]
@@ -382,16 +390,11 @@ class BatchSamplerConfig(BaseModel):
drop_last: Literal[True] = True


class GPT2LLMCollateFnConfig(BaseModel):
sample_key: str
target_key: str


class LLMDataLoaderConfig(BaseModel):
dataloader_tag: str
dataset: PydanticDatasetIFType
batch_sampler: PydanticSamplerIFType
collate_fn: Optional[PydanticCollateFnIFType] = None
collator: Optional[PydanticCollatorIFType] = None
num_workers: Annotated[int, Field(strict=True, ge=0)]
pin_memory: bool

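The three new flags on DCPAppStateConfig expose the selective loading in the YAML-driven component config; they default to True, so existing configs continue to restore all components. A small sketch of the Pydantic model itself, assuming the annotated AppState type accepts an already constructed instance; the empty AppState and the path are placeholders:

```python
from pathlib import Path

from modalities.checkpointing.stateful.app_state import AppState
from modalities.config.config import DCPAppStateConfig

config = DCPAppStateConfig(
    raw_app_state=AppState(),  # placeholder; normally the fully built raw AppState
    checkpoint_dir_path=Path("/checkpoints/step_1000"),  # placeholder path
    load_optimizer_checkpoint=False,  # e.g. fine-tuning with a freshly created optimizer
)
print(config.load_model_checkpoint, config.load_lr_scheduler_checkpoint)  # True True (defaults)
```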
3 changes: 2 additions & 1 deletion src/modalities/config/pydantic_if_types.py
@@ -16,7 +16,7 @@
from modalities.checkpointing.checkpoint_saving import CheckpointSaving, CheckpointSavingExecutionABC
from modalities.checkpointing.checkpoint_saving_strategies import CheckpointSavingStrategyIF
from modalities.checkpointing.stateful.app_state import AppState
from modalities.dataloader.collate_fns.collate_if import CollateFnIF
from modalities.dataloader.collate_fns.collate_if import CollateFnIF, CollatorIF
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.inference.text.inference_component import TextInferenceComponent
from modalities.logging_broker.subscriber import MessageSubscriberIF
@@ -67,6 +67,7 @@ def __get_pydantic_core_schema__(
PydanticDatasetIFType = Annotated[Dataset, PydanticThirdPartyTypeIF(Dataset)]
PydanticSamplerIFType = Annotated[Sampler, PydanticThirdPartyTypeIF(Sampler)]
PydanticCollateFnIFType = Annotated[CollateFnIF, PydanticThirdPartyTypeIF(CollateFnIF)]
PydanticCollatorIFType = Annotated[CollatorIF, PydanticThirdPartyTypeIF(CollatorIF)]
PydanticLLMDataLoaderIFType = Annotated[LLMDataLoader, PydanticThirdPartyTypeIF(LLMDataLoader)]
PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)]
PydanticLRSchedulerIFType = Annotated[LRScheduler, PydanticThirdPartyTypeIF(LRScheduler)]
1 change: 1 addition & 0 deletions src/modalities/dataloader/collate_fns/__init__.py
@@ -0,0 +1 @@

39 changes: 39 additions & 0 deletions src/modalities/dataloader/collate_fns/autoregressive_collate_fn.py
@@ -0,0 +1,39 @@
import torch
from pydantic import BaseModel

from modalities.dataloader.collate_fns.collate_if import CollateFnIF


class AutoregressiveCollateFnConfig(BaseModel):
sample_key: str
target_key: str


class AutoregressiveCollateFn(CollateFnIF):
"""AutoregressiveCollateFn class to define a collate function for language modeling."""

def __init__(self, sample_key: str, target_key: str):
"""
Initializes the Collator object.

Args:
sample_key (str): The key for accessing the sample data.
target_key (str): The key for accessing the target data.
"""
self.sample_key = sample_key
self.target_key = target_key

def __call__(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""
Process a batch of data.

Args:
batch (dict[str, torch.Tensor]): A dictionary containing tensors of the batch.

Returns:
dict[str, torch.Tensor]: The processed batch with sample and target tensors.
"""
sample_tensor = batch[self.sample_key]
batch[self.sample_key] = sample_tensor[:, :-1]
batch[self.target_key] = sample_tensor[:, 1:]
return batch
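The new AutoregressiveCollateFn no longer assembles the batch itself; it only derives next-token targets from an already collated sample tensor by shifting it one position. A short usage sketch (the key names and token ids are illustrative):

```python
import torch

from modalities.dataloader.collate_fns.autoregressive_collate_fn import AutoregressiveCollateFn

collate_fn = AutoregressiveCollateFn(sample_key="input_ids", target_key="target_ids")

# One already stacked batch of token ids (batch_size=2, sequence_length=4).
batch = {"input_ids": torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])}
batch = collate_fn(batch)

print(batch["input_ids"])   # tensor([[1, 2, 3], [5, 6, 7]])  -- inputs drop the last token
print(batch["target_ids"])  # tensor([[2, 3, 4], [6, 7, 8]])  -- targets are shifted left by one
```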
19 changes: 18 additions & 1 deletion src/modalities/dataloader/collate_fns/collate_if.py
@@ -5,11 +5,28 @@
from modalities.batch import DatasetBatch


class CollatorIF(ABC):
def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
"""
Process a batch of data.

Args:
batch (list[dict[str, torch.Tensor]]): A list of dictionaries containing 1-dim tensors.

Returns:
DatasetBatch: The processed batch of data.

Raises:
NotImplementedError: This abstract method should be implemented in a subclass.
"""
raise NotImplementedError


class CollateFnIF(ABC):
"""CollateFnIF class to define a collate function interface."""

@abstractmethod
def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
def __call__(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""
Process a batch of data.

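Taken together, the interface split works as follows: a CollatorIF turns the list of per-sample dicts coming from the dataloader into a DatasetBatch, while CollateFnIF (with its changed signature) transforms an already collated dict of tensors, which is what lets the YAML config chain several collate_fns inside one wrapping collator. The sketch below only illustrates that contract and is not the PR's default_wrapping_collator implementation; the class name and constructor arguments mirror the config keys but are assumptions, and DatasetBatch is assumed to take samples and targets dicts:

```python
import torch

from modalities.batch import DatasetBatch
from modalities.dataloader.collate_fns.collate_if import CollateFnIF, CollatorIF


class WrappingCollator(CollatorIF):
    """Illustrative collator: stacks the per-sample tensors, applies a chain of
    CollateFnIF transforms, and wraps the result into a DatasetBatch."""

    def __init__(
        self,
        input_keys: list[str],
        sample_keys: list[str],
        target_keys: list[str],
        collate_fns: list[CollateFnIF],
    ):
        self.input_keys = input_keys
        self.sample_keys = sample_keys
        self.target_keys = target_keys
        self.collate_fns = collate_fns

    def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
        # Stack the 1-dim per-sample tensors into 2-dim batch tensors.
        collated = {key: torch.stack([sample[key] for sample in batch]) for key in self.input_keys}
        # Let each collate_fn (e.g. AutoregressiveCollateFn) transform the dict in turn.
        for collate_fn in self.collate_fns:
            collated = collate_fn(collated)
        samples = {key: collated[key] for key in self.sample_keys}
        targets = {key: collated[key] for key in self.target_keys}
        return DatasetBatch(samples=samples, targets=targets)
```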