1. Initializing hidden states
In your source code you are using the init_hidden_encoder and init_hidden_decoder functions to zero the hidden states of both recurrent units in every forward pass.
In PyTorch you don't have to do that: if no initial hidden state is passed to a recurrent layer (be it LSTM, GRU or RNN, the ones currently available by default in PyTorch), it is implicitly fed with zeros.
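A quick way to convince yourself of this is to compare both calls on the same layer. This is only a minimal sketch; the layer sizes and the random input are arbitrary choices for illustration, not values taken from your code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

lstm = nn.LSTM(input_size=1, hidden_size=20, num_layers=1)
x = torch.randn(5, 1, 1)  # sequence x batch x dimension

# No initial state passed: PyTorch uses zeros internally
out_implicit, (h_implicit, c_implicit) = lstm(x)

# Explicit zero states, shape (num_layers, batch, hidden_size)
h0 = torch.zeros(1, 1, 20)
c0 = torch.zeros(1, 1, 20)
out_explicit, (h_explicit, c_explicit) = lstm(x, (h0, c0))

same_output = torch.allclose(out_implicit, out_explicit)
same_hidden = torch.allclose(h_implicit, h_explicit)
print(same_output, same_hidden)
```

Both comparisons should agree, which is why the init_hidden_* helpers can be dropped.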
To obtain code equivalent to your initial solution (which simplifies the next parts), I will scrap the unneeded parts, which leaves us with the model seen below:

    class LSTM(nn.Module):
        def __init__(self, input_dim, latent_dim, num_layers):
            super(LSTM, self).__init__()
            self.input_dim = input_dim
            self.latent_dim = latent_dim
            self.num_layers = num_layers
            self.encoder = nn.LSTM(self.input_dim, self.latent_dim, self.num_layers)
            self.decoder = nn.LSTM(self.latent_dim, self.input_dim, self.num_layers)

        def forward(self, input):
            # Encode: keep only the last hidden state,
            # shape (num_layers, batch, latent_dim)
            _, (last_hidden, _) = self.encoder(input)
            # Repeat it once per time step (5 here)
            encoded = last_hidden.repeat(5, 1, 1)
            # Decode
            y, _ = self.decoder(encoded)
            return torch.squeeze(y)

Addition of torch.squeeze
We don't need any superfluous dimensions (like the trailing 1s in [5, 1, 1]).
Actually, those dimensions are the clue to why your results all equal 0.2.
Furthermore, I left the input reshape out of the network (in my opinion, a network should be fed input that is ready to be processed), to strictly separate the two tasks (input preparation and the model itself).
This approach gives us the following setup code and training loop:

    model = LSTM(input_dim=1, latent_dim=20, num_layers=1)
    loss_function = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    y = torch.Tensor([[0.0], [0.1], [0.2], [0.3], [0.4]])
    # Sequence x batch x dimension
    x = y.view(len(y), 1, -1)

    while True:
        y_pred = model(x)
        optimizer.zero_grad()
        loss = loss_function(y_pred, y)
        loss.backward()
        optimizer.step()
        print(y_pred)

The whole network is identical to yours (for now), except that it is more succinct and readable.
2. What we want, describing network changes
As the Keras code you provided indicates, what we want to do (and you are actually doing it correctly) is obtain the last hidden state from the encoder (it encodes our entire sequence) and decode the sequence from this state to recover the original one.
BTW, this approach is called sequence to sequence, or seq2seq for short (often used in tasks like language translation). Well, maybe it is a variation of that approach, but I would classify it as that anyway.
PyTorch provides the last hidden state as a separate return value from the RNN family.
I would advise against your encoded[-1]. The reason is the bidirectional and multilayered case. Say you wanted to sum the bidirectional output; that would mean code along these lines:

    # batch_size and hidden_size would have to be inferred, cluttering the code further
    encoded[-1].view(batch_size, 2, hidden_size).sum(dim=1)

And that's why the line _, (last_hidden, _) = self.encoder(input) was used.
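To illustrate the difference, the sketch below compares the encoder output with the returned hidden state for a bidirectional LSTM. The sizes and variable names are made up for the example and are not taken from your code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, batch_size, input_dim, hidden_size = 5, 1, 1, 20
encoder = nn.LSTM(input_dim, hidden_size, num_layers=1, bidirectional=True)
x = torch.randn(seq_len, batch_size, input_dim)

output, (last_hidden, _) = encoder(x)
# output:      (seq_len, batch, 2 * hidden_size), both directions concatenated
# last_hidden: (num_directions, batch, hidden_size), one final state per direction

# The forward direction finishes at the last time step ...
forward_matches = torch.allclose(last_hidden[0], output[-1, :, :hidden_size])
# ... but the backward direction finishes at the first time step,
# so output[-1] does not hold its final state.
backward_matches = torch.allclose(last_hidden[1], output[0, :, hidden_size:])

# Summing both directions needs no reshaping when starting from last_hidden
summed = last_hidden.sum(dim=0)  # (batch, hidden_size)
print(forward_matches, backward_matches, summed.shape)
```

Starting from last_hidden, there is no need to infer batch_size or hidden_size just to pull the two directions apart.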
3. Why does the output converge to 0.2?
Actually, it was a mistake on your side, and only in the last part.
Output shapes of your predictions and targets:

    # Your output
    torch.Size([5, 1, 1])
    # Your target
    torch.Size([5, 1])

Given those shapes, MSELoss broadcasts both tensors to a common shape ([5, 5, 1]) and, with its default reduction="mean" (size_average=True in older versions), averages over every output-target pair. Each output is therefore compared against all five targets, and that loss is smallest when every output equals the average of your targets, which is 0.2.
So the network converges correctly, but the shape of your targets is wrong.
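The effect can be reproduced without any training. The sketch below mimics the broadcasting by hand (it is an illustration, not MSELoss's actual implementation; broadcast_mse is a helper name I made up) and shows that, with mismatched shapes, predicting 0.2 everywhere scores better than predicting the targets exactly.

```python
import torch

target = torch.tensor([[0.0], [0.1], [0.2], [0.3], [0.4]])  # shape (5, 1)

def broadcast_mse(pred, target):
    # Mean squared error over the broadcast shape of pred and target
    return ((pred - target) ** 2).mean()

perfect = target.view(5, 1, 1)         # the right values, but shape (5, 1, 1)
constant = torch.full((5, 1, 1), 0.2)  # every output equal to the target mean

broadcast_shape = torch.broadcast_tensors(perfect, target)[0].shape
loss_perfect = broadcast_mse(perfect, target)    # mismatched shapes
loss_constant = broadcast_mse(constant, target)  # mismatched shapes
loss_matched = broadcast_mse(perfect.view(5, 1), target)  # matching shapes

print(broadcast_shape)
print(loss_perfect.item(), loss_constant.item(), loss_matched.item())
```

With mismatched shapes the constant prediction has the lower loss, so that is where the optimizer ends up; once the shapes match, the exact prediction has zero loss.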
3.1 First and wrong solution
Provide MSELoss with the argument reduction="sum", though it's really a temporary fix that works by accident.
At first the network will try to make all of the outputs equal to the sum (0 + 0.1 + 0.2 + 0.3 + 0.4 = 1.0), starting from semi-random outputs; after a while it will converge to what you want, but not for the reasons you want!
The identity function is the easiest choice here, even for summation (as your input data is really simple).
3.2 Second and correct solution
Just pass appropriately shaped tensors to the loss function, e.g. batch x outputs. In your case, the final part would look like this:

    model = LSTM(input_dim=1, latent_dim=20, num_layers=1)
    loss_function = nn.MSELoss()
    optimizer = optim.Adam(model.parameters())

    y = torch.Tensor([0.0, 0.1, 0.2, 0.3, 0.4])
    x = y.view(len(y), 1, -1)

    while True:
        y_pred = model(x)
        optimizer.zero_grad()
        loss = loss_function(y_pred, y)
        loss.backward()
        optimizer.step()
        print(y_pred)

Your target is one-dimensional (as the batch is of size 1) and so is your output (after squeezing the unnecessary dimensions).
I changed Adam's parameters to the defaults, as it converges faster that way.
4. Final working code
For brevity, here is the code and results:

    import torch
    import torch.nn as nn
    import torch.optim as optim


    class LSTM(nn.Module):
        def __init__(self, input_dim, latent_dim, num_layers):
            super(LSTM, self).__init__()
            self.input_dim = input_dim
            self.latent_dim = latent_dim
            self.num_layers = num_layers
            self.encoder = nn.LSTM(self.input_dim, self.latent_dim, self.num_layers)
            self.decoder = nn.LSTM(self.latent_dim, self.input_dim, self.num_layers)

        def forward(self, input):
            # Encode
            _, (last_hidden, _) = self.encoder(input)
            # Take the repeat counts from the input instead of hard-coding 5
            encoded = last_hidden.repeat(input.shape)
            # Decode
            y, _ = self.decoder(encoded)
            return torch.squeeze(y)


    model = LSTM(input_dim=1, latent_dim=20, num_layers=1)
    loss_function = nn.MSELoss()
    optimizer = optim.Adam(model.parameters())

    # Target and prediction are both one-dimensional, shape (5,)
    y = torch.Tensor([0.0, 0.1, 0.2, 0.3, 0.4])
    # Sequence x batch x dimension
    x = y.view(len(y), 1, -1)

    while True:
        y_pred = model(x)
        optimizer.zero_grad()
        loss = loss_function(y_pred, y)
        loss.backward()
        optimizer.step()
        print(y_pred)

And here are the results after ~60k steps (it actually gets stuck after ~20k steps; you may want to improve your optimization and play around with the hidden size for better results):

    step=59682
    tensor([0.0260, 0.0886, 0.1976, 0.3079, 0.3962], grad_fn=<SqueezeBackward0>)

Additionally, L1Loss (a.k.a. Mean Absolute Error) may get better results in this case:

    step=10645
    tensor([0.0405, 0.1049, 0.1986, 0.3098, 0.4027], grad_fn=<SqueezeBackward0>)

Tuning and correct batching of this network are left to you. I hope you get the idea and have some fun now. :)
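PS. I take the repeat counts from the input's shape instead of hard-coding 5, so the sequence length is no longer fixed. Note that repeating by the whole input.shape is only correct while the batch size and input_dim are both 1 (repeat tiles every dimension by the given count); for real batches, repeat along the sequence dimension only.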
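
For reference, here is one way the repeat could be made batch-safe. This is a sketch under my own assumptions, not part of the code above: the class name Seq2SeqAutoencoder is hypothetical, it uses only the top layer's last hidden state, and it drops the squeeze so that the output keeps the input's shape.

```python
import torch
import torch.nn as nn


class Seq2SeqAutoencoder(nn.Module):  # hypothetical name
    def __init__(self, input_dim, latent_dim, num_layers):
        super().__init__()
        self.encoder = nn.LSTM(input_dim, latent_dim, num_layers)
        self.decoder = nn.LSTM(latent_dim, input_dim, num_layers)

    def forward(self, input):
        # input: (seq_len, batch, input_dim)
        _, (last_hidden, _) = self.encoder(input)
        # Top layer only: (1, batch, latent_dim)
        top = last_hidden[-1:]
        # Repeat along the sequence dimension only: (seq_len, batch, latent_dim)
        encoded = top.repeat(input.shape[0], 1, 1)
        y, _ = self.decoder(encoded)
        return y  # (seq_len, batch, input_dim), same shape as the input


model = Seq2SeqAutoencoder(input_dim=2, latent_dim=20, num_layers=2)
x = torch.randn(5, 3, 2)  # 5 steps, batch of 3, 2 features
y_pred = model(x)
loss = nn.MSELoss()(y_pred, x)  # shapes match, so nothing is broadcast
print(y_pred.shape)
```

Because the prediction and the target have identical shapes here, the loss compares each output only with its own target.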