Custom Dataset class on Chapter_1 #6

@LucianoBatista

Hi, I was getting the following error when executing this code:

import torch
from torch.utils.data import Dataset
from sklearn.datasets import fetch_openml

X, y = fetch_openml("mnist_784", version=1, return_X_y=True)

class SimpleDataset(Dataset):
    def __init__(self, X, y):
        super(SimpleDataset, self).__init__()
        self.X = X
        self.y = y
    
    def __getitem__(self, index):
        inputs = torch.tensor(self.X[index, :], dtype=torch.float32)
        targets = torch.tensor(int(self.y[index]), dtype=torch.int64)
        return inputs, targets

    def __len__(self):
        return self.X.shape[0]

dataset = SimpleDataset(X, y)
example, label = dataset[0]

InvalidIndexError: (tensor(0), slice(None, None, None))

The error went away when I changed the fetch_openml call to:

X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

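The problem is that without as_frame=False, recent versions of scikit-learn return the data as a pandas DataFrame instead of a NumPy array, and a DataFrame does not support positional indexing like self.X[index, :].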
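As an alternative to changing the fetch_openml call, the dataset class itself could convert its inputs to NumPy arrays. Below is a minimal sketch of that idea, not the book's code: the np.asarray conversion is my own addition, and the small DataFrame is hypothetical stand-in data (4 samples, 3 features, string labels) so the example runs without downloading MNIST.

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset


class SimpleDataset(Dataset):
    """Variant of the dataset above that accepts DataFrames or arrays."""

    def __init__(self, X, y):
        super().__init__()
        # np.asarray turns a DataFrame/Series into an ndarray and leaves an
        # existing ndarray as an ndarray, so X[index, :] works either way.
        self.X = np.asarray(X, dtype=np.float32)
        # Labels may arrive as strings (hence the int(...) cast in the
        # original code), so convert them to integers once, up front.
        self.y = np.asarray(y).astype(np.int64)

    def __getitem__(self, index):
        inputs = torch.tensor(self.X[index, :], dtype=torch.float32)
        targets = torch.tensor(self.y[index], dtype=torch.int64)
        return inputs, targets

    def __len__(self):
        return self.X.shape[0]


# Hypothetical stand-in for the MNIST DataFrame and its string labels.
X_df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=["a", "b", "c"])
y_sr = pd.Series(["5", "0", "4", "1"])

dataset = SimpleDataset(X_df, y_sr)
example, label = dataset[0]
```

With this variant the same class should work whether fetch_openml is called with as_frame=False or not, at the cost of copying the data into a float32 array when the dataset is created.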
