My variational autoencoder seems to work for MNIST, but fails on slightly "harder" data.
By "fails" I mean there are at least two apparent problems:
- Very poor reconstruction; for example, sample reconstructions from the last epoch on the validation set, without any regularization at all. The last losses reported to the console are val_loss=9.57e-5 and train_loss=9.83e-5, which I thought would imply exact reconstructions (see the loss-scale sanity check after this list).
- Validation loss is low (which does not seem to reflect the poor reconstructions) and is always lower than the training loss, which is very suspicious.
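To make sure I am reading these numbers correctly, here is roughly the sanity check I did to relate a mean-reduced MSE of that magnitude to a per-element error (the tensors below are just stand-ins with values on a [0, 1] scale, not my actual validation data):
import torch
import torch.nn.functional as F

# stand-in batch and a reconstruction that is off by ~0.01 per element
x = torch.rand(64, 192)
recon = x + 0.01 * torch.randn_like(x)

# same reduction as self._recon_function = nn.MSELoss(reduction='mean') in the code below
mse = F.mse_loss(recon, x, reduction="mean")
print(mse.item())          # ~1e-4, the same order of magnitude as my reported losses
print(mse.sqrt().item())   # ~0.01 per-element RMSE, which is why I expected near-exact reconstructions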
For MNIST everything looks fine (with fewer layers!).
I will give as much information as I can, since I am not sure what I should provide to help anyone help me.
First, here is the full code.
You will notice that the loss calculation and logging are very simple and straightforward, and I can't seem to find what's wrong.
import torch
from torch import nn
import torch.nn.functional as F
from typing import List, Optional, Any
from pytorch_lightning.core.lightning import LightningModule
from Testing.Research.config.ConfigProvider import ConfigProvider
from pytorch_lightning import Trainer, seed_everything
from torch import optim
import os
from pytorch_lightning.loggers import TensorBoardLogger
# import tfmpl
import matplotlib.pyplot as plt
import matplotlib
from Testing.Research.data_modules.MyDataModule import MyDataModule
from Testing.Research.data_modules.MNISTDataModule import MNISTDataModule
from Testing.Research.data_modules.CaseDataModule import CaseDataModule
import torchvision
from Testing.Research.config.paths import tb_logs_folder
from Testing.Research.config.paths import vae_checkpoints_path
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
class VAEFC(LightningModule):
# see https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73
# for possible upgrades, see https://arxiv.org/pdf/1602.02282.pdf
    # https://stats.stackexchange.com/questions/332179/how-to-weight-kld-loss-vs-reconstruction-loss-in-variational-auto-encoder
def __init__(self, encoder_layer_sizes: List, decoder_layer_sizes: List, config):
super(VAEFC, self).__init__()
self._config = config
self.logger: Optional[TensorBoardLogger] = None
self.save_hyperparameters()
assert len(encoder_layer_sizes) >= 3, "must have at least 3 layers (2 hidden)"
# encoder layers
self._encoder_layers = nn.ModuleList()
for i in range(1, len(encoder_layer_sizes) - 1):
enc_layer = nn.Linear(encoder_layer_sizes[i - 1], encoder_layer_sizes[i])
self._encoder_layers.append(enc_layer)
        # predict the mean and log-variance vectors of the latent distribution
        self._mean_layer = nn.Linear(encoder_layer_sizes[-2], encoder_layer_sizes[-1])
        self._logvar_layer = nn.Linear(encoder_layer_sizes[-2], encoder_layer_sizes[-1])
# decoder layers
self._decoder_layers = nn.ModuleList()
for i in range(1, len(decoder_layer_sizes)):
dec_layer = nn.Linear(decoder_layer_sizes[i - 1], decoder_layer_sizes[i])
self._decoder_layers.append(dec_layer)
self._recon_function = nn.MSELoss(reduction='mean')
self._last_val_batch = {}
def _encode(self, x):
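        # pass the input through the hidden encoder layers, then predict the mean and log-variance of q(z|x)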
for i in range(len(self._encoder_layers)):
layer = self._encoder_layers[i]
x = F.relu(layer(x))
mean_output = self._mean_layer(x)
logvar_output = self._logvar_layer(x)
return mean_output, logvar_output
def _reparametrize(self, mu, logvar):
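        # reparameterization trick: z = mu + std * eps, so sampling stays differentiable w.r.t. mu and logvar
        # (at eval time the mean is returned deterministically)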
if not self.training:
return mu
std = logvar.mul(0.5).exp_()
if std.is_cuda:
eps = torch.FloatTensor(std.size()).cuda().normal_()
else:
eps = torch.FloatTensor(std.size()).normal_()
reparameterized = eps.mul(std).add_(mu)
return reparameterized
def _decode(self, z):
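        # pass the latent sample through the hidden decoder layers; the output layer is linear (no activation)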
for i in range(len(self._decoder_layers) - 1):
layer = self._decoder_layers[i]
z = F.relu((layer(z)))
decoded = self._decoder_layers[len(self._decoder_layers) - 1](z)
# decoded = F.sigmoid(self._decoder_layers[len(self._decoder_layers)-1](z))
return decoded
def _loss_function(self, recon_x, x, mu, logvar, reconstruction_function):
"""
recon_x: generating images
x: origin images
mu: latent mean
logvar: latent log variance
"""
        reconstruction_loss = reconstruction_function(recon_x, x)  # MSE loss; TODO decide between MSE and cross entropy
        # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kld_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
        kld = torch.sum(kld_element).mul_(-0.5)
        # The Kullback-Leibler divergence is the regularization term of the VAE.
        # It measures how different two probability distributions are.
        # We are trying to pull the latent distribution towards the prior while keeping the reconstruction loss low.
        # see https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73
        # read on weighting the regularization term here:
        # https://stats.stackexchange.com/questions/332179/how-to-weight-kld-loss-vs-reconstruction-loss-in-variational-auto-encoder
        return reconstruction_loss + kld * self._config.regularization_factor
def _parse_batch_by_dataset(self, batch, batch_index):
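        # unpack a batch according to the configured dataset; noisy_batch is the model input,
        # orig_batch is the reconstruction target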
if self._config.dataset == "toy":
(orig_batch, noisy_batch), label_batch = batch
# TODO put in the noise here and not in the dataset?
elif self._config.dataset == "mnist":
orig_batch, label_batch = batch
orig_batch = orig_batch.reshape(-1, 28 * 28)
noisy_batch = orig_batch
elif self._config.dataset == "case":
orig_batch, label_batch = batch
orig_batch = orig_batch.float().reshape(
-1,
len(self._config.case.feature_list) * self._config.case.frames_per_pd_sample
)
noisy_batch = orig_batch
else:
raise ValueError("invalid dataset")
noisy_batch = noisy_batch.view(noisy_batch.size(0), -1)
return orig_batch, noisy_batch, label_batch
def training_step(self, batch, batch_idx):
orig_batch, noisy_batch, label_batch = self._parse_batch_by_dataset(batch, batch_idx)
recon_batch, mu, logvar = self.forward(noisy_batch)
loss = self._loss_function(
recon_batch,
orig_batch, mu, logvar,
reconstruction_function=self._recon_function
)
# self.logger.experiment.add_scalars("losses", {"train_loss": loss})
tb = self.logger.experiment
tb.add_scalars("losses", {"train_loss": loss}, global_step=self.current_epoch)
# self.logger.experiment.add_scalar("train_loss", loss, self.current_epoch)
if batch_idx == len(self.train_dataloader()) - 2:
# https://pytorch.org/docs/stable/_modules/torch/utils/tensorboard/writer.html#SummaryWriter.add_embedding
# noisy_batch = noisy_batch.detach()
# recon_batch = recon_batch.detach()
# last_batch_plt = matplotlib.figure.Figure() # read https://github.com/wookayin/tensorflow-plot
# ax = last_batch_plt.add_subplot(1, 1, 1)
# ax.scatter(orig_batch[:, 0], orig_batch[:, 1], label="original")
# ax.scatter(noisy_batch[:, 0], noisy_batch[:, 1], label="noisy")
# ax.scatter(recon_batch[:, 0], recon_batch[:, 1], label="reconstructed")
# ax.legend(loc="upper left")
# self.logger.experiment.add_figure(f"original last batch, epoch {self.current_epoch}", last_batch_plt)
# tb.add_embedding(orig_batch, global_step=self.current_epoch, metadata=label_batch)
pass
self.logger.experiment.flush()
self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
return loss
def _plot_batches(self, orig_batch, noisy_batch, label_batch, batch_idx, recon_batch, mu, logvar):
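        # visualize original vs. reconstructed batches: TensorBoard image grids for MNIST,
        # matplotlib feature plots for the case dataset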
# orig_batch_view = orig_batch.reshape(-1, self._config.case.frames_per_pd_sample,
# len(self._config.case.feature_list))
#
# plt.figure()
# plt.plot(orig_batch_view[11, :, 0].detach().cpu().numpy(), label="feature 0")
# plt.legend(loc="upper left")
# plt.show()
tb = self.logger.experiment
if self._config.dataset == "mnist":
orig_batch -= orig_batch.min()
orig_batch /= orig_batch.max()
recon_batch -= recon_batch.min()
recon_batch /= recon_batch.max()
orig_grid = torchvision.utils.make_grid(orig_batch.view(-1, 1, 28, 28))
val_recon_grid = torchvision.utils.make_grid(recon_batch.view(-1, 1, 28, 28))
tb.add_image("original_val", orig_grid, global_step=self.current_epoch)
tb.add_image("reconstruction_val", val_recon_grid, global_step=self.current_epoch)
label_img = orig_batch.view(-1, 1, 28, 28)
pass
elif self._config.dataset == "case":
orig_batch_view = orig_batch.reshape(-1, self._config.case.frames_per_pd_sample,
len(self._config.case.feature_list)).transpose(1, 2)
recon_batch_view = recon_batch.reshape(-1, self._config.case.frames_per_pd_sample,
len(self._config.case.feature_list)).transpose(1, 2)
# plt.figure()
# plt.plot(orig_batch_view[11, 0, :].detach().cpu().numpy())
# plt.show()
# pass
n_samples = orig_batch_view.shape[0]
n_plots = min(n_samples, 4)
first_sample_idx = 0
# TODO either plotting or data problem
fig, axs = plt.subplots(n_plots, 1)
for sample_idx in range(n_plots):
for feature_idx, (orig_feature, recon_feature) in enumerate(