Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - Data Augmentation in PyTorch

I am a little confused about how data augmentation is performed in PyTorch. As far as I know, when we perform data augmentation we KEEP our original dataset and then add other versions of it (flipped, cropped, etc.). But that does not seem to be what happens in PyTorch. As far as I understand from the references, when we use transforms (from torchvision) in PyTorch, they are applied one by one. For example:

from torchvision import transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

Here, for training, we first randomly crop the image and resize it to shape (224,224). Then we take these (224,224) images and flip them horizontally. Therefore, our dataset now contains ONLY the horizontally flipped images, so the original images are lost in this case.

Is this understanding correct? If not, where in the code above (taken from the official documentation) do we tell PyTorch to keep the original images and resize them to the expected shape (224,224)?

Thanks


1 Reply


I assume you are asking whether these data augmentation transforms (e.g. RandomHorizontalFlip) actually increase the size of the dataset, or whether they are applied to each item in the dataset one by one without adding to its size.

Running the following simple code snippet shows that the latter is true: if you have a dataset of 8 images and create a PyTorch dataset object for it, then when you iterate through the dataset the transformations are called on each data point, and the transformed data point is returned. So, for example, with random flipping some of the data points are returned as original and some are returned flipped (e.g. 4 flipped and 4 original). In other words, one iteration through the dataset gives you 8 data points, some flipped and some not. This is at odds with the conventional understanding of augmenting a dataset, which in this case would mean having 16 data points in the augmented dataset. Because the random choice is made again every time an item is fetched, a different subset is flipped on each pass, so over many epochs the model still sees both the original and the flipped versions.

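import torch
from torch.utils.data import Dataset
from torchvision import transforms


class experimental_dataset(Dataset):

    def __init__(self, data, transform):
        self.data = data
        self.transform = transform

    def __len__(self):
        # Number of items = size of the first (batch) dimension
        return self.data.shape[0]

    def __getitem__(self, idx):
        # The transform runs every time an item is fetched
        item = self.data[idx]
        item = self.transform(item)
        return item


# Defined at module level so it can be passed to the dataset below
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

x = torch.rand(8, 1, 2, 2)
print(x)

dataset = experimental_dataset(x, transform)

for item in dataset:
    print(item)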

Results (the small differences in the floating-point values are caused by converting to a PIL image and back):
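The size of those conversion differences can be sanity-checked by hand. The sketch below is only an approximation of the round trip: it assumes the float values are scaled to 0..255, truncated to 8-bit integers, and scaled back, which is consistent with the numbers printed below. The helper name through_uint8 is made up for this illustration, and the four inputs are the values of the first item of the dummy dataset.

```python
def through_uint8(value):
    # Scale to 0..255, truncate to an integer, then scale back to 0..1.
    return int(value * 255) / 255


# Values of the first item in the original dummy dataset.
for v in (0.1872, 0.5518, 0.5733, 0.6593):
    print(f"{v:.4f} -> {through_uint8(v):.4f}")
```

Under that assumption the four values come out as 0.1843, 0.5490, 0.5725 and 0.6588, matching the first transformed item.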

Original dummy dataset:

tensor([[[[0.1872, 0.5518],
          [0.5733, 0.6593]]],


        [[[0.6570, 0.6487],
          [0.4415, 0.5883]]],


        [[[0.5682, 0.3294],
          [0.9346, 0.1243]]],


        [[[0.1829, 0.5607],
          [0.3661, 0.6277]]],


        [[[0.1201, 0.1574],
          [0.4224, 0.6146]]],


        [[[0.9301, 0.3369],
          [0.9210, 0.9616]]],


        [[[0.8567, 0.2297],
          [0.1789, 0.8954]]],


        [[[0.0068, 0.8932],
          [0.9971, 0.3548]]]])

Transformed dataset (items 4, 5 and 8 are returned flipped, the other five unchanged):

tensor([[[0.1843, 0.5490],
         [0.5725, 0.6588]]])
tensor([[[0.6549, 0.6471],
         [0.4392, 0.5882]]])
tensor([[[0.5647, 0.3255],
         [0.9333, 0.1216]]])
tensor([[[0.5569, 0.1804],
         [0.6275, 0.3647]]])
tensor([[[0.1569, 0.1176],
         [0.6118, 0.4196]]])
tensor([[[0.9294, 0.3333],
         [0.9176, 0.9608]]])
tensor([[[0.8549, 0.2275],
         [0.1765, 0.8941]]])
tensor([[[0.8902, 0.0039],
         [0.3529, 0.9961]]])
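
To make the contrast between the two notions of augmentation concrete, here is a minimal sketch rather than a definitive recipe. TransformedDataset, flip and random_flip are hypothetical helpers written for this illustration; random_flip is a plain-torch stand-in for transforms.RandomHorizontalFlip, while ConcatDataset is the standard class from torch.utils.data. The first part shows that on-the-fly augmentation keeps the length at 8 while the flipped subset changes from pass to pass; the second part builds the "conventional" 16-item dataset by concatenating the originals with an always-flipped copy.

```python
import torch
from torch.utils.data import ConcatDataset, Dataset


class TransformedDataset(Dataset):
    """Hypothetical helper: applies `transform` (if any) on every access."""

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        item = self.data[idx]
        return self.transform(item) if self.transform is not None else item


def flip(img):
    # Horizontal flip: reverse the last (width) dimension.
    return torch.flip(img, dims=[-1])


def random_flip(img, p=0.5):
    # Stand-in for transforms.RandomHorizontalFlip: flip with probability p.
    return flip(img) if torch.rand(1).item() < p else img


x = torch.rand(8, 1, 2, 2)

# On-the-fly augmentation: the length never changes, but each pass
# draws fresh random decisions, so the flipped subset varies by epoch.
on_the_fly = TransformedDataset(x, random_flip)
for epoch in range(3):
    flipped_idx = [i for i in range(len(on_the_fly))
                   if not torch.equal(on_the_fly[i], x[i])]
    print(f"epoch {epoch}: {len(on_the_fly)} items, flipped indices {flipped_idx}")

# Conventional "enlarged" dataset: originals plus an always-flipped copy.
augmented = ConcatDataset([TransformedDataset(x), TransformedDataset(x, flip)])
print(len(augmented))  # 16
```

Which form to use is a design choice: the on-the-fly version keeps the number of items per epoch fixed and relies on repeated passes to expose both variants, while the concatenated version doubles the number of items seen in every epoch.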
