I am training ResNet34 on the CIFAR-10 dataset. For a certain reason, I need to convert the dataset into a TensorDataset.
My solution is based on this: https://stackoverflow.com/a/44475689/15072863 with some differences (maybe they are critical, but I don't see why).
It looks like I'm not doing this correctly.
Train loader:
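import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, TensorDataset

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_ds = torchvision.datasets.CIFAR10('/files/', train=True, transform=transform_train, download=True)
# Collect every transformed (image, label) pair into in-memory lists
xs, ys = [], []
for x, y in train_ds:
    xs.append(x)
    ys.append(y)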
# 1) Standard Version
# cifar_train_loader = DataLoader(train_ds, batch_size=batch_size_train, shuffle=True, num_workers=num_workers)
# 2) TensorDataset version, seems to be incorrect
cifar_tensor_ds = TensorDataset(torch.stack(xs), torch.tensor(ys, dtype=torch.long))
cifar_train_loader = DataLoader(cifar_tensor_ds, batch_size=batch_size_train, shuffle=True, num_workers=num_workers)
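One difference between the two versions that may be worth checking: the loop over `train_ds` runs the random transforms once per sample, so the stacked tensors hold a single fixed crop/flip of each image, while version 1 draws a new crop/flip on every access. The sketch below illustrates this with a toy stand-in; `ToyAugmentedDataset` and `random_flip` are hypothetical names, not part of torchvision, and the "flip" is just a sign change.

```python
import random


class ToyAugmentedDataset:
    """Hypothetical stand-in for a torchvision dataset:
    the transform runs every time an item is accessed."""

    def __init__(self, values, transform):
        self.values = values
        self.transform = transform

    def __len__(self):
        return len(self.values)

    def __getitem__(self, idx):
        return self.transform(self.values[idx])


def random_flip(x):
    # Stand-in for RandomHorizontalFlip: negate with probability 0.5.
    return -x if random.random() < 0.5 else x


random.seed(0)
lazy_ds = ToyAugmentedDataset(list(range(1, 101)), random_flip)

# Materializing once (like the xs/ys loop) freezes one random draw per sample.
frozen = [lazy_ds[i] for i in range(len(lazy_ds))]

# Every "epoch" over the frozen copy sees exactly the same values...
frozen_epoch_1 = list(frozen)
frozen_epoch_2 = list(frozen)

# ...while each pass over the lazy dataset is augmented afresh.
lazy_epoch_1 = [lazy_ds[i] for i in range(len(lazy_ds))]
lazy_epoch_2 = [lazy_ds[i] for i in range(len(lazy_ds))]
```

If the same effect applies here, the TensorDataset version trains on 50,000 fixed images with no per-epoch augmentation, which would be consistent with higher train accuracy and lower test accuracy.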
I don't think it matters, but the test loader is defined as usual:
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
cifar_test_loader = DataLoader(
    torchvision.datasets.CIFAR10('/files/', train=False, transform=transform_test, download=True),
    batch_size=batch_size_test, shuffle=False, num_workers=num_workers)
I know that something is wrong with how I use TensorDataset, since:
- With TensorDataset I achieve 100% train accuracy and 80% test accuracy.
- With the standard Dataset I achieve 99% train accuracy (never 100%) and 90% test accuracy.
So, what am I doing wrong?
P.S.: My final goal is to split the dataset into 10 datasets, one per class. Is there a better way to do this? Of course, I can define my own subclass of Dataset, but manually splitting it and creating TensorDatasets seemed simpler.
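For the per-class split, one possible approach is to group sample indices by label and wrap each group in `torch.utils.data.Subset`, which leaves the original dataset (and its transform) untouched. This is only a sketch: `indices_by_class` and `split_by_class` are hypothetical helper names, and it assumes the dataset exposes its labels as a `targets` sequence, as torchvision's `CIFAR10` does.

```python
from collections import defaultdict


def indices_by_class(targets):
    """Map each label to the list of sample indices that carry it."""
    groups = defaultdict(list)
    for idx, label in enumerate(targets):
        groups[int(label)].append(idx)
    return dict(groups)


def split_by_class(dataset):
    """Return {label: Subset} for a dataset with a `targets` attribute.

    torch is imported lazily so the index grouping above can be
    used (and tested) without it.
    """
    from torch.utils.data import Subset

    groups = indices_by_class(dataset.targets)
    return {label: Subset(dataset, idxs) for label, idxs in groups.items()}


# Small check of the grouping logic on made-up labels:
example_groups = indices_by_class([0, 1, 0, 2, 1])
# example_groups == {0: [0, 2], 1: [1, 4], 2: [3]}
```

With the training set from above, `split_by_class(train_ds)[3]` would be a dataset containing only class 3, and it can be passed to a `DataLoader` like any other dataset. Because `Subset` indexes into the underlying dataset on each access, the random transforms would still be applied on every read.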
question from:
https://stackoverflow.com/questions/65925371/pytorch-convert-cifar-dataset-to-tensordataset