Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

Categories

0 votes
434 views
in Technique[技术] by (71.8m points)

neural-network - pytorch优化器不更新网络权重时如何跟踪问题?(How can I trace the issue when a pytorch optimizer does not update the network weights?)

import torch.nn as nn
import torch

    class Classifier(nn.Module):
        def __init__(self, input_size, hidden_size, num_classes):
            super(Classifier, self).__init__()
            self.relu = nn.ReLU()
            self.soft = nn.Softmax(dim=0)
            self.fc1 = nn.Linear(input_size, hidden_size)
            self.fc2 = nn.Linear(hidden_size, hidden_size)
            self.fc3 = nn.Linear(hidden_size, num_classes)

        def forward(self, x):
            out = self.fc1(x)
            out = self.relu(out)
            out = self.fc2(out)
            out = self.relu(out)
            out = self.fc3(out)
            out = self.soft(out)
            return out

    class Discriminator(nn.Module):
        def __init__(self, input_size, hidden_size):
            super(Discriminator, self).__init__()
            self.relu = nn.LeakyReLU(0.2, inplace=True)
            self.soft = nn.Softmax(dim=0)
            self.fc1 = nn.Linear(input_size, hidden_size)
            self.fc2 = nn.Linear(hidden_size, hidden_size)
            self.fc3 = nn.Linear(hidden_size, 1)

        def forward(self, x):
            out = self.fc1(x)
            out = self.relu(out)
            out = self.fc2(out)
            out = self.relu(out)
            out = self.fc3(out)
            out = self.soft(out)
            return out

    input_size = 15
    output_size = 9
    D = Discriminator(input_size=input_size+output_size, hidden_size=128)
    C = Classifier(input_size=input_size, hidden_size=128, num_classes=output_size)

    optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_C = torch.optim.Adam(C.parameters(), lr=0.0002, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for i, X1, Y1 in enumerate(dataloader, 1):

            C1 = torch.cat((X1,Y1),dim=1)
            R1 = floatTensor(C1.shape[0], 1).fill_(1.0)

            optimizer_D.zero_grad()
            A1 = D(C1)

            loss = Loss(A1, R1)
            running_loss_D += loss.item()
            loss.backward()
            optimizer_D.step()


            optimizer_C.zero_grad()
            P1 = C(X1)

            loss = Loss(P1, Y1)
            running_loss_C += loss.item()
            loss.backward()
            optimizer_C.step()

I'm in the process of building a GAN architecture consisting of Generator (G), Discriminator (D), and Classifier (C).

(我正在构建由生成器(G),鉴别器(D)和分类器(C)组成的GAN架构。)

The generator is not implemented yet.

(生成器尚未实现。)

The classifier takes data samples (X1) as input and predicts labels (P1) as output.

(分类器将数据样本(X1)作为输入,并预测标签(P1)作为输出。)

The discriminator takes data-label pairs (C1) as input and predicts (A1) if they're real (part of the original dataset) or not (labelled by the classifier or generated by the generator).

(鉴别器将数据标签对(C1)作为输入,并预测(A1)它们是真实的(原始数据集的一部分)还是不真实的(由分类器标记或由生成器生成)。)

At the moment, the classifier is trained by predicting the labels of known data samples (X1 -> Y1), which works fine.

(目前,通过预测已知数据样本的标签(X1-> Y1)来训练分类器,效果很好。)

The discriminator is only given real data samples, so it should be fairly easy since it only has to predict 1s.

(仅向鉴别器提供真实数据样本,因此它应该相当容易,因为它仅需预测1s。)

However, during the training, the loss of the discriminator stays consistently between 4.0 and 4.1 and does not improve at all.

(但是,在训练过程中,鉴别器的损失始终保持在4.0到4.1之间,并且完全没有改善。)

I thought that it has something to do with the graphs of the combined data C1, but I read that "cat" data can be used as input.

(我以为这与组合数据C1的图形有关,但是我读到“猫”数据可以用作输入。)

I have to admit that PyTorch is still quite a black box for me, so I'd be grateful for any advise about how to find the source of the issue.

(我必须承认,PyTorch对我来说仍然是一个黑匣子,因此,对于任何如何找到问题根源的建议,我将不胜感激。)

  ask by hasl3r translate from so

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
Welcome To Ask or Share your Answers For Others

1 Reply

0 votes
by (71.8m points)
等待大神答复

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
OGeek|极客中国-欢迎来到极客的世界,一个免费开放的程序员编程交流平台!开放,进步,分享!让技术改变生活,让极客改变未来! Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Click Here to Ask a Question

1.4m articles

1.4m replys

5 comments

56.8k users

...