loss.backward() raises "RuntimeError: Found dtype Double but expected Float"

The fix is to convert the offending tensor to the dtype the loss expects, torch.float32.

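Specifically: before passing the targets to the loss functions (mse_loss, cross_loss), convert the regression target to torch.float32. The error typically means that the target is float64 ("Double", e.g. because it was built from a NumPy array) while the model output is float32 ("Float"). The mismatch then surfaces as this error during backward(). The classification target for CrossEntropyLoss is a different case: it holds class indices and must be int64 (long), which is why the code below calls .long() on it.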
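A minimal sketch of where the "Double" comes from, assuming the labels were loaded through NumPy (the label values and the Linear layer here are hypothetical). NumPy defaults to float64, while PyTorch layers default to float32, so the target has to be cast before it meets the model output in the loss:

```python
import numpy as np
import torch

# Hypothetical age labels loaded with NumPy: NumPy's default float dtype is float64
ages_np = np.array([23.0, 35.0, 41.0])
age_batch = torch.from_numpy(ages_np)
print(age_batch.dtype)  # torch.float64, reported as "Double" in the error

# A layer's parameters, and therefore its outputs, default to float32 ("Float")
layer = torch.nn.Linear(4, 1)
out = layer(torch.randn(3, 4))
print(out.dtype)  # torch.float32

# Reshape to (N, 1) to match the output and cast to float32 before computing the loss
age_batch = age_batch.view(-1, 1).to(torch.float32)
loss = torch.nn.MSELoss()(out, age_batch)
loss.backward()  # target and output now share the same dtype
```

Calling `age_batch.float()` is an equivalent shorthand for `.to(torch.float32)`.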

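```python
import torch

# Loss functions
mse_loss = torch.nn.MSELoss()             # regression loss for age
cross_loss = torch.nn.CrossEntropyLoss()  # classification loss for gender

# Inside the training loop (model, images_batch, age_batch and gender_batch come from your own setup)
# forward pass: compute predicted outputs by passing inputs to the model
m_age_out_, m_gender_out_ = model(images_batch)
# MSELoss target must match the float32 output: reshape to (N, 1) and cast float64 -> float32
age_batch = age_batch.view(-1, 1).to(torch.float32)
# CrossEntropyLoss expects class indices as int64 (long)
gender_batch = gender_batch.long()

# Compute the losses
l1 = mse_loss(m_age_out_, age_batch)
l2 = cross_loss(m_gender_out_, gender_batch)
loss = l1 + l2
loss.backward()
```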
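For a version that runs end to end, here is a sketch with a hypothetical two-head model (`AgeGenderNet`) and a random batch. The model architecture, the batch values, and the SGD optimizer are assumptions for illustration only; the dtype handling mirrors the snippet above:

```python
import torch
import torch.nn as nn

class AgeGenderNet(nn.Module):
    """Hypothetical two-head model: one age regression output and two gender logits."""

    def __init__(self, in_features: int = 8):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_features, 16), nn.ReLU())
        self.age_head = nn.Linear(16, 1)     # (N, 1), float32
        self.gender_head = nn.Linear(16, 2)  # (N, 2) logits, float32

    def forward(self, x):
        h = self.backbone(x)
        return self.age_head(h), self.gender_head(h)

torch.manual_seed(0)
model = AgeGenderNet()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
mse_loss = nn.MSELoss()
cross_loss = nn.CrossEntropyLoss()

# Hypothetical batch: labels often arrive as float64 / int32 from NumPy or pandas
images_batch = torch.randn(4, 8)
age_batch = torch.tensor([25.0, 31.0, 47.0, 19.0], dtype=torch.float64)  # shape (N,)
gender_batch = torch.tensor([0, 1, 1, 0], dtype=torch.int32)

optimizer.zero_grad()
m_age_out_, m_gender_out_ = model(images_batch)
age_batch = age_batch.view(-1, 1).to(torch.float32)  # (N,) float64 -> (N, 1) float32
gender_batch = gender_batch.long()                    # class indices as int64

l1 = mse_loss(m_age_out_, age_batch)
l2 = cross_loss(m_gender_out_, gender_batch)
loss = l1 + l2
loss.backward()
optimizer.step()
print(f"loss = {loss.item():.4f}")
```

Casting the labels once inside the Dataset (or right after loading them) is an alternative that avoids repeating the conversion in every iteration.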


Source: blog.csdn.net/thequitesunshine007/article/details/119223783