This commit enhances the training pipeline with more flexible hardware
support and a test-run feature, addressing your feedback on seamless
operation across different setups and easier debugging.
Key changes:
1. **Flexible Fabric Configuration:**
* `Fabric` initialization in `picolm.py` now adapts to the
number of devices specified:
* `devices=0`: Uses CPU (`accelerator="cpu"`, `strategy="auto"`).
* `devices=1`: Uses a single GPU (`accelerator="cuda"`, `strategy="auto"`).
* `devices>1`: Uses multiple GPUs with FSDP
(`strategy=FSDPStrategy(state_dict_type="full")`).
* Precision and the `torch._dynamo.config.optimize_ddp` setting are
applied conditionally based on this configuration (see the first
sketch after this list).
2. **Test Run Feature (`--test-run`):**
* A `--test-run` CLI flag has been added to `picolm.py`.
* When enabled, the training uses a minimal, in-memory
`IterableDataset` and a highly restricted `TrainingMeta`
(e.g., few steps, small batch size) for a quick pipeline check.
* W&B logging is disabled (`mode="disabled"`) during test runs. A
combined sketch of the test-run pieces follows this description.
3. **DataLoader Adaption for Test Runs:**
* In `src/pico/train.py`, `DataLoader` `num_workers` are
adjusted for test runs: 0 for CPU, 1 for CUDA, to ensure
compatibility and stability with in-memory datasets.
4. **Serialization Consistency:**
* The existing `pico.serialization.load` remains in use for the
inference path (`run_command`); it stays compatible with checkpoints
written by `fabric.save` because FSDP saving is consolidated
(`state_dict_type="full"`).
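
For reference, a minimal sketch of the device dispatch in item 1; the helper name `make_fabric` and the `bf16-mixed` precision string are illustrative assumptions, not the exact code in `picolm.py`:

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

def make_fabric(devices: int) -> Fabric:
    """Map the CLI device count onto a Fabric configuration."""
    if devices == 0:
        return Fabric(accelerator="cpu", strategy="auto")
    if devices == 1:
        return Fabric(accelerator="cuda", devices=1, strategy="auto",
                      precision="bf16-mixed")  # precision value assumed
    # Multi-GPU: shard with FSDP and save consolidated ("full") checkpoints
    torch._dynamo.config.optimize_ddp = False  # value assumed; the PR sets this conditionally
    return Fabric(accelerator="cuda", devices=devices,
                  precision="bf16-mixed",
                  strategy=FSDPStrategy(state_dict_type="full"))
```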
These changes allow the training pipeline to run on CPU, single GPU, or
multi-GPU (FSDP) setups seamlessly and provide a convenient way to
perform quick, local tests of the training logic.
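
And a condensed sketch of how the test-run pieces (items 2 and 3) fit together; `TinyDataset`, `make_loader`, and the non-test `num_workers` default are hypothetical names and values:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

class TinyDataset(IterableDataset):
    """Hypothetical in-memory stand-in for the --test-run dataset."""
    def __iter__(self):
        for _ in range(8):  # a few constant samples suffice for a pipeline check
            yield {"x": torch.zeros(16, dtype=torch.long),
                   "y": torch.zeros(16, dtype=torch.long)}

def make_loader(dataset, fabric, test_run: bool, batch_size: int) -> DataLoader:
    if test_run:
        # Mirrors the adaptation in src/pico/train.py: 0 workers on CPU, 1 on CUDA
        num_workers = 0 if fabric.device.type == "cpu" else 1
    else:
        num_workers = 4  # assumed production default
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

# During test runs, W&B is silenced as the description states:
# wandb.init(mode="disabled")
```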
Pull Request Overview
This PR introduces flexible test-run functionality and updates the training pipeline to use Fabric for distributed training, replacing the legacy DataParallel/GradScaler approach.
- Replace device-specific model setup with Fabric-based setup.
- Introduce a test_run mode that adjusts DataLoader settings and training metadata.
- Update dependency management and training command configuration in picolm.py.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| src/pico/train.py | Refactored training loop to utilize Fabric and removed legacy device handling using DataParallel and GradScaler; adjusted DataLoader num_workers logic based on test run and Fabric accelerator. |
| pyproject.toml | Added the lightning dependency required for Fabric integration. |
| picolm.py | Updated training command to configure Fabric based on device count and implemented test run overrides with minimal dataset and training metadata. |
```diff
  y = data["y"]

- with torch.autocast(device.type, dtype=torch.bfloat16):
+ with torch.autocast("cuda", dtype=torch.bfloat16):
```
The autocast context is hard-coded to 'cuda', which may cause issues when running on CPU. Consider using fabric.accelerator to dynamically set the device context to ensure compatibility with both CPU and GPU runs.
Suggested change:

```diff
- with torch.autocast("cuda", dtype=torch.bfloat16):
+ fabric = Fabric()
+ with torch.autocast(fabric.device.type, dtype=torch.bfloat16):
```
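
If the training loop already holds a `fabric` handle (plausible after this PR's Fabric refactor), the suggestion can avoid constructing a second `Fabric()` and just reuse it; a minimal sketch, assuming `fabric`, `model`, `x`, and `y` are in scope:

```python
# torch.autocast also accepts "cpu" with torch.bfloat16, so this covers both backends.
with torch.autocast(fabric.device.type, dtype=torch.bfloat16):
    loss = model(x, y)  # model/x/y are placeholders for the actual training-step code
```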
```python
    train_dataset,
    validation_dataset=validation_dataset,
    training_meta=training_meta,
    devices=devices,  # This argument is no longer used by pico.train.train but is kept for compatibility with fabric init
```
[nitpick] Since the 'devices' argument is not used in the pico.train.train function, consider removing it from the call to reduce potential confusion and improve maintainability.
Suggested change:

```diff
- devices=devices,  # This argument is no longer used by pico.train.train but is kept for compatibility with fabric init
```
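
Applied, the call would shrink to the following (a sketch assuming the call target is `pico.train.train`, as named in the comment):

```python
train(
    train_dataset,
    validation_dataset=validation_dataset,
    training_meta=training_meta,
)
```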