-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocessing.py
More file actions
41 lines (30 loc) · 1.73 KB
/
preprocessing.py
File metadata and controls
41 lines (30 loc) · 1.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from torchvision.transforms.v2 import RandomAffine
def get_transform(aug_type='standard'):
    """Build the preprocessing pipeline applied to each MNIST image.

    Args:
        aug_type: ``'augmented'`` prepends a random rotation (up to 30
            degrees) and a random translation (up to 10% of the image on
            each axis). Any other value yields the plain pipeline with no
            random augmentation.

    Returns:
        A ``transforms.Compose`` whose final steps are always tensor
        conversion followed by single-channel normalization with
        mean 0.5 / std 0.5.
    """
    # Common tail for every variant: convert to a tensor, then normalize
    # (ToTensor scales 8-bit images to [0, 1], so this maps roughly to [-1, 1]).
    tail = [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    if aug_type != 'augmented':
        return transforms.Compose(tail)
    head = [
        transforms.RandomRotation(30),
        # RandomAffine is the torchvision.transforms.v2 class imported at
        # module level; degrees=0 means translation only.
        RandomAffine(degrees=0, translate=(0.1, 0.1)),
    ]
    return transforms.Compose(head + tail)
def load_datasets(config, transform_type='standard'):
    """Load MNIST and return fixed random subsets of its train and test splits.

    Args:
        config: Mapping read as ``config['data']['subset_train']`` and
            ``config['data']['subset_test']``. These give the number of
            samples kept from each split (presumably ints; since the indices
            are sliced with ``[:k]``, a ``k`` larger than the split keeps
            the whole split).
        transform_type: Passed to ``get_transform`` for the TRAINING split
            only. ``'augmented'`` enables random rotation/translation; the
            test split always uses the standard, deterministic transform.

    Returns:
        ``(train_subset, test_subset)`` as ``torch.utils.data.Subset``
        objects, each with a ``transform_type`` attribute set to the
        requested ``transform_type``.

    Side effects:
        Downloads MNIST into ``./data`` if it is not already present.
    """
    # Fixed seed so the same subset indices are selected on every run.
    g = torch.Generator().manual_seed(42)
    train_transform = get_transform(transform_type)
    # Evaluation must be deterministic and comparable across runs, so the
    # test split never receives random augmentation.
    test_transform = get_transform('standard')
    train_dataset = datasets.MNIST(
        root="./data", train=True, download=True, transform=train_transform
    )
    test_dataset = datasets.MNIST(
        root="./data", train=False, download=True, transform=test_transform
    )
    # Use random indices to avoid sampling bias from ordered datasets.
    # Keep the train draw before the test draw: both share one generator,
    # so swapping them would change which indices each split receives.
    train_indices = torch.randperm(len(train_dataset), generator=g)[:config['data']['subset_train']]
    test_indices = torch.randperm(len(test_dataset), generator=g)[:config['data']['subset_test']]
    train_subset = Subset(train_dataset, train_indices)
    test_subset = Subset(test_dataset, test_indices)
    # NOTE(review): presumably read downstream as a run label; kept equal to
    # the requested setting on both subsets for caller compatibility, even
    # though the test split's actual transform is always 'standard'.
    train_subset.transform_type = transform_type
    test_subset.transform_type = transform_type
    return train_subset, test_subset
def get_dataloaders(train_subset, test_subset, batch_size, *, seed=None):
    """Wrap the train and test datasets in DataLoaders.

    Args:
        train_subset: Dataset used for training; iterated in shuffled order.
        test_subset: Dataset used for evaluation; iterated in stored order.
        batch_size: Samples per batch for both loaders (the final batch may
            be smaller, since ``drop_last`` is left at its default).
        seed: Optional keyword-only int. When given, the training shuffle
            is driven by a dedicated ``torch.Generator`` seeded with it, so
            batch order is reproducible across runs. ``None`` (the default)
            keeps the previous behavior, where the shuffle draws from torch's
            global RNG.

    Returns:
        ``(train_loader, test_loader)``.
    """
    # generator=None is DataLoader's own default, so omitting seed is
    # behaviorally identical to not passing a generator at all.
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    train_loader = DataLoader(
        train_subset, batch_size=batch_size, shuffle=True, generator=generator
    )
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader