Open
Description
Hi, I was getting the following error when executing this code:
import torch
from torch.utils.data import Dataset
from sklearn.datasets import fetch_openml

X, y = fetch_openml("mnist_784", version=1, return_X_y=True)

class SimpleDataset(Dataset):
    def __init__(self, X, y):
        super(SimpleDataset, self).__init__()
        self.X = X
        self.y = y

    def __getitem__(self, index):
        # NumPy-style indexing: assumes self.X is a 2-D array
        inputs = torch.tensor(self.X[index, :], dtype=torch.float32)
        targets = torch.tensor(int(self.y[index]), dtype=torch.int64)
        return inputs, targets

    def __len__(self):
        return self.X.shape[0]

dataset = SimpleDataset(X, y)
example, label = dataset[0]

InvalidIndexError: (tensor(0), slice(None, None, None))

The error went away when I changed the fetch_openml call to:

X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

The problem was that without as_frame=False, scikit-learn loads the data as a pandas DataFrame and no longer as a NumPy array, so the NumPy-style indexing self.X[index, :] fails.
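For anyone who wants to reproduce the difference without downloading MNIST, here is a minimal sketch. The small `frame` table and the `first_row` helper are made up for illustration and are not part of scikit-learn or PyTorch. It shows that `data[0, :]` fails on a DataFrame but works once the data is a NumPy array, and that `.iloc` is an alternative if you prefer to keep the DataFrame.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the MNIST feature table: 3 rows x 4 columns.
frame = pd.DataFrame(np.arange(12, dtype=np.float32).reshape(3, 4))


def first_row(data):
    """Return row 0 via NumPy-style indexing, or the exception name on failure."""
    try:
        return np.asarray(data[0, :])
    except Exception as exc:
        # Recent pandas versions typically raise InvalidIndexError here.
        return type(exc).__name__


# A DataFrame does not accept (row, slice) inside plain square brackets.
frame_result = first_row(frame)

# After converting to a NumPy array, the same indexing works.
array_result = first_row(frame.to_numpy())

# Alternative: keep the DataFrame and index by position with .iloc.
iloc_result = frame.iloc[0, :].to_numpy()

print(frame_result)
print(array_result)
print(iloc_result)
```

So besides passing as_frame=False, calling X.to_numpy() before building the dataset, or using self.X.iloc[index, :] inside __getitem__, should also avoid the error.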