# example.py
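"""Example: sharing a Python dict across DDP ranks with torchunique.Unique.

A LightningModule caches per-batch metadata in a torchunique.Unique-wrapped dict
during predict_step; at the end of the predict epoch the dict is synchronized so
every rank sees the same contents.
"""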
import pytorch_lightning as pl
import torch
import torchunique
from pytorch_lightning import Trainer
from torch import distributed as dist
from torch.utils.data import DataLoader, Dataset, random_split
class RandomDataset(Dataset):
    """Synthetic dataset yielding random feature vectors with binary labels."""

    def __init__(self, size, length):
        self.size = size      # feature dimension of each sample
        self.length = length  # number of samples in the dataset

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        x = torch.randn(self.size)
        y = torch.randint(0, 2, (1,)).float()
        return x, y
#####################################################
# Example Usage Beginning
class TestLightningModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Wrap a plain dict in torchunique.Unique so it can be shared across DDP ranks.
        self.cache_dict = torchunique.Unique(dict())

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # Record which rank handled which batch index.
        self.cache_dict.setdefault("fname", []).append(
            f"rank{dist.get_rank()} batch_idx{batch_idx:02}"
        )

    def on_predict_epoch_end(self):
        # Force DDP synchronization with a barrier to prevent data inconsistencies across ranks.
        self.cache_dict.barrier()
        # Use to_here() to fetch the underlying dict, then read the cached entries.
        cache_dict = self.cache_dict.to_here()["fname"]
        print(sorted(cache_dict), "epoch_end", dist.get_rank())  # Same result on every rank.
######################################################
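# Expected behavior (a sketch, assuming two ranks and Lightning's default
# distributed sampler): each rank appends only its own "rank{r} batch_idx{i}"
# strings during predict_step, yet after barrier() and to_here() the sorted list
# printed in on_predict_epoch_end contains the entries from both ranks and is
# identical on each of them.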
def prepare_data():
    dataset = RandomDataset(size=32, length=32)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_dataset, batch_size=1, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=1, num_workers=0)
    return train_loader, val_loader
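# With length=32, random_split yields 25 training and 7 validation samples
# (train_size = int(0.8 * 32) = 25); only the training loader is passed to
# trainer.predict below.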
if __name__ == "__main__":
    train_loader, val_loader = prepare_data()
    model = TestLightningModule()
    trainer = Trainer(
        accelerator="gpu",
        devices=[0, 1],
        strategy="ddp",
        max_epochs=1,
        log_every_n_steps=10,
    )
    trainer.predict(model, train_loader)
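# A minimal way to launch this sketch (assuming two visible CUDA GPUs):
#   python example.py
# With strategy="ddp", Lightning starts one process per entry in `devices`,
# so dist.get_rank() returns 0 in one process and 1 in the other.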