Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

computer vision - ValueError: Target size (torch.Size([10, 1])) must be the same as input size (torch.Size([10, 2]))

A binary classification problem with a batch size of 10. I am trying to use torch.nn.BCEWithLogitsLoss() and get this error:

~\Anaconda3\envs\notebook\lib\site-packages\torch\nn\functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
   2578 
   2579     if not (target.size() == input.size()):
-> 2580         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   2581 
   2582     return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)

ValueError: Target size (torch.Size([1, 10])) must be the same as input size (torch.Size([10, 2]))
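The mismatch can be reproduced without the model or the data loaders. The sketch below assumes, as the error message suggests, that the model emits two logits per sample and that `labels` arrives as a 1-D tensor of 10 integer class labels; random tensors stand in for both.

```python
import torch
import torch.nn as nn

loss_fn = nn.BCEWithLogitsLoss()
outputs = torch.randn(10, 2)          # two logits per sample, shape (10, 2)
labels = torch.randint(0, 2, (10,))   # one integer label per sample, shape (10,)
target = labels.unsqueeze(0).float()  # unsqueeze(0) -> shape (1, 10)

try:
    loss_fn(outputs, target)
    message = None
except ValueError as err:
    # The size check runs before any loss is computed.
    message = str(err)

# The message reports a target of size (1, 10) against an input of size (10, 2).
print(message)
```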

Here is my training code:

def train(epochs):
    print('Starting training..')
    for e in range(0, epochs):
        exp_lr_scheduler.step()
        print('='*20)
        print(f'Starting epoch {e + 1}/{epochs}')
        print('='*20)
        train_loss = 0.
        val_loss = 0.
        resnet18.train() # set model to training phase
        for train_step, (images, labels) in enumerate(dl_train):
            optimizer.zero_grad()
            outputs = resnet18(images)
            outputs = outputs.float()
            loss = loss_fn(outputs, labels.unsqueeze(0))
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            if train_step % 20 == 0:
                print('Evaluating at step', train_step)
                accuracy = 0
                resnet18.eval() # set model to eval phase
                for val_step, (images, labels) in enumerate(dl_val):
                    outputs = resnet18(images)
                    outputs = outputs.float()
                    loss = loss_fn(outputs, labels.unsqueeze(0))
                    val_loss += loss.item()
                    _, preds = torch.max(outputs, 1)
                    accuracy += sum((preds == labels).numpy())
                val_loss /= (val_step + 1)
                accuracy = accuracy/len(val_dataset)
                print(f'Validation Loss: {val_loss:.4f}, Accuracy: {accuracy:.4f}')
                show_preds()
                resnet18.train() #set model to training phase
                if accuracy >= 0.95:
                    print('Performance condition satisfied, stopping..')
                    return
        train_loss /= (train_step + 1)
        print(f'Training Loss: {train_loss:.4f}')
    print('Training complete..')

train(epochs=30)
Question from: https://stackoverflow.com/questions/66053295/valueerror-target-size-torch-size10-1-must-be-the-same-as-input-size-to

Answer:


Target size (torch.Size([1, 10])) must be the same as input size (torch.Size([10, 2]))

It seems to me you have two issues:

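  1. The target (a.k.a. the ground-truth tensor) should have the batch on the first axis, i.e. a shape of (10, 1) rather than (1, 10). labels.unsqueeze(0) inserts the new axis in front of the batch axis; labels.unsqueeze(1) inserts it after, which gives (10, 1).

  2. From what you've described, you are dealing with a binary classification task, not a multi-label (2-class) one. Therefore the input (a.k.a. the model's output) should have a shape of (10, 1), not (10, 2).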
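A minimal sketch of both fixes, assuming `labels` is a 1-D tensor of 10 integer labels and that `resnet18` is torchvision's model (whose final layer `fc` takes 512 input features). To keep it self-contained, a plain `nn.Linear` with random features stands in for the network's last layer.

```python
import torch
import torch.nn as nn

batch_size = 10
labels = torch.randint(0, 2, (batch_size,))  # integer labels, shape (10,)

wrong = labels.unsqueeze(0)  # shape (1, 10): new axis in front of the batch axis
right = labels.unsqueeze(1)  # shape (10, 1): batch stays on the first axis

# Stand-in for the network's last layer with a single output neuron.
# For torchvision's resnet18 the assumed equivalent change would be:
#   resnet18.fc = nn.Linear(resnet18.fc.in_features, 1)
head = nn.Linear(512, 1)
features = torch.randn(batch_size, 512)  # hypothetical penultimate-layer features
logits = head(features)                  # shape (10, 1)

target = right.float()  # BCEWithLogitsLoss expects a float target
loss = nn.BCEWithLogitsLoss()(logits, target)
print(wrong.shape, right.shape, logits.shape, loss.item())
```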


In a binary classification task you should only have a single logit coming out of your model, i.e. your last nn.Linear layer should have a single output neuron. That one value determines the predicted class: a logit above 0 corresponds to a sigmoid probability above 0.5, i.e. class 1. This also means torch.max(outputs, 1) can no longer be used for predictions, since with a single column it always returns index 0; threshold the logit instead. Since you are using nn.BCEWithLogitsLoss, the loss input should be the raw output (the loss applies the sigmoid internally, cf. the documentation) with shape (batch_size=10, 1). The target tensor must have the same shape, (batch_size=10, 1), and contain 0s and 1s as floating-point values.
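The sketch below checks that BCEWithLogitsLoss applies the sigmoid itself, then adapts the training and validation steps to a single-logit model. `train_step`, `evaluate`, the tiny `nn.Sequential` model and the synthetic batch are hypothetical stand-ins for the question's `resnet18`, `dl_train`/`dl_val` and optimizer, not the asker's actual code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
loss_fn = nn.BCEWithLogitsLoss()

# 1) The loss applies the sigmoid itself, so it must receive raw logits.
logits = torch.randn(10, 1)                    # shape (batch_size, 1)
target = torch.randint(0, 2, (10, 1)).float()  # 0s and 1s, same shape, float dtype
same = torch.allclose(loss_fn(logits, target),
                      nn.BCELoss()(torch.sigmoid(logits), target))

# 2) Hypothetical training/validation steps for a single-logit model.
def train_step(model, images, labels, optimizer):
    model.train()
    optimizer.zero_grad()
    outputs = model(images)  # (batch_size, 1) raw logits
    loss = loss_fn(outputs, labels.unsqueeze(1).float())
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()  # no autograd graph is built during evaluation
def evaluate(model, loader):
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0  # reset on every call
    for images, labels in loader:
        outputs = model(images)
        total_loss += loss_fn(outputs, labels.unsqueeze(1).float()).item()
        # logit > 0 is equivalent to sigmoid(logit) > 0.5
        preds = (outputs > 0).long().squeeze(1)
        correct += (preds == labels).sum().item()
        seen += labels.size(0)
    return total_loss / len(loader), correct / seen

# Usage with a tiny stand-in model and one synthetic batch of 10 "images".
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
images = torch.randn(10, 3, 8, 8)
labels = torch.randint(0, 2, (10,))
first_loss = train_step(model, images, labels, optimizer)
val_loss, accuracy = evaluate(model, [(images, labels)])
print(f"train loss {first_loss:.4f}, val loss {val_loss:.4f}, accuracy {accuracy:.2f}")
```

Resetting the totals inside `evaluate` is a separate design choice from the shape fix: in the question's loop, `val_loss` is only reset at the start of each epoch, so it keeps accumulating across the evaluations run every 20 steps.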



...