Fixing IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

Problem description

While reproducing some code, I set batch_size to 1, and softmax in the model below then raised the following error:

# The Super Sketch Network links an RNN and a CNN together with an attention layer in the last layer.
class SSN(nn.Module):

    def __init__(self, cnn_model_name, rnn_model_name, d_frozen=True, num_classes=40):
        # CNN/RNN backbones and the attention parameter are set up here (omitted).
        pass

    def forward(self, images, strokes):
        cnn_output, cnn_f = self.cnn(images)
        rnn_output, rnn_f = self.rnn(strokes, None)

        # Attention layer linking the RNN and CNN outputs.
        output = torch.stack([cnn_output, rnn_output], dim=1)

        # Get the center feature.
        ssn_feat = torch.cat((cnn_f, rnn_f), dim=1)
        att_score = torch.matmul(output, self.attention).squeeze()
        att_score = F.softmax(att_score, dim=1).view(output.size(0), output.size(1), 1)
        score = output * att_score

        score = torch.sum(score, dim=1)

        return score, ssn_feat

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

Solution

When calling softmax here, the tensor needs to have at least two dimensions. With batch_size = 1, however, the attention scores produced before the squeeze have shape

[batch_size = 1, maxlen, 1],

and calling .squeeze() without specifying a dimension removes every size-1 axis, leaving a 1-D tensor of shape [maxlen]. F.softmax(att_score, dim=1) then asks for dimension 1 on a 1-D tensor, which raises the error.
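A minimal reproduction of the shape collapse (the sizes 1 and 5 here are made up for illustration, not taken from the original model):

import torch
import torch.nn.functional as F

att_score = torch.randn(1, 5, 1)   # [batch_size = 1, maxlen = 5, 1]
collapsed = att_score.squeeze()    # removes every size-1 dim -> shape [5]
print(collapsed.shape)             # torch.Size([5])

# This now fails, because a 1-D tensor has no dimension 1:
# F.softmax(collapsed, dim=1)      # IndexError: Dimension out of range ...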

The solution is to manually add the missing dimension back. In PyTorch, torch.unsqueeze does exactly this: it inserts a dimension of size 1 at the specified position, increasing the tensor's number of dimensions by one. For example, for a tensor of shape (3, 4), torch.unsqueeze(input, dim=0) returns a tensor of shape (1, 3, 4). Basic usage:

import torch

# Suppose the input tensor has shape (n,)
tensor = torch.tensor([1, 2, 3])

# Add a leading dimension, giving shape (1, n)
tensor = torch.unsqueeze(tensor, dim=0)
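As a quick check that the extra dimension makes softmax(dim=1) valid again (a small sketch; float values are used because softmax expects a floating-point tensor):

import torch
import torch.nn.functional as F

t = torch.tensor([1.0, 2.0, 3.0])   # shape (3,): softmax(t, dim=1) would fail
t = torch.unsqueeze(t, dim=0)       # shape (1, 3)
print(F.softmax(t, dim=1))          # works: approximately [[0.0900, 0.2447, 0.6652]]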

Applied to the code above, the fix is:

# The Super Sketch Network links an RNN and a CNN together with an attention layer in the last layer.
class SSN(nn.Module):

    def __init__(self, cnn_model_name, rnn_model_name, d_frozen=True, num_classes=40):
        # CNN/RNN backbones and the attention parameter are set up here (omitted).
        pass

    def forward(self, images, strokes):
        cnn_output, cnn_f = self.cnn(images)
        rnn_output, rnn_f = self.rnn(strokes, None)

        # Attention layer linking the RNN and CNN outputs.
        output = torch.stack([cnn_output, rnn_output], dim=1)

        # Get the center feature.
        ssn_feat = torch.cat((cnn_f, rnn_f), dim=1)
        att_score = torch.matmul(output, self.attention).squeeze()
        # Restore the batch dimension that squeeze() removed when batch_size == 1.
        att_score = torch.unsqueeze(att_score, dim=0)
        att_score = F.softmax(att_score, dim=1).view(output.size(0), output.size(1), 1)
        score = output * att_score

        score = torch.sum(score, dim=1)

        return score, ssn_feat
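Note that this unsqueeze-based fix assumes batch_size == 1, the case described above; with larger batches, squeeze() already returns a 2-D tensor, and the extra leading dimension would shift the softmax onto the wrong axis. An alternative sketch (not from the original post; the feature size 64 is made up) is to squeeze only the trailing dimension, so the batch axis is never removed in the first place:

import torch
import torch.nn.functional as F

output = torch.randn(1, 2, 64)      # [batch_size = 1, 2 branches, feature_dim]
attention = torch.randn(64, 1)      # stand-in for self.attention

att_score = torch.matmul(output, attention).squeeze(-1)   # [1, 2]: batch dim preserved
att_score = F.softmax(att_score, dim=1).view(output.size(0), output.size(1), 1)
print(att_score.shape)              # torch.Size([1, 2, 1]), for any batch size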

Reposted from blog.csdn.net/qq_54708219/article/details/130232096