# PyTorch：产生 loss 的计算图 (building the computation graph of a loss)

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """LeNet-style CNN: two conv + max-pool stages followed by three linear layers.

    ``fc1`` expects ``16 * 5 * 5`` input features, which is exactly what the
    conv/pool stack yields for a 1-channel 32x32 image
    (32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5). ``forward`` therefore
    expects input of shape (N, 1, 32, 32) and returns a tensor of shape (N, 10).
    """

    def __init__(self):
        super().__init__()
        # Conv2d(in_channels, out_channels, kernel_size)
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Linear(in_features, out_features)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on ``x``; the output of ``fc3`` is returned without activation."""
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten every dimension except the batch dimension.
        # (The original printed the whole tensor here on every call — removed debug output.)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
# Build the network and push one random single-channel 32x32 image through it.
# Parameters can be inspected with net.named_parameters() or print(net).
net = Net()
# torch.autograd.Variable is deprecated; plain tensors participate in autograd directly.
input = torch.randn(1, 1, 32, 32)
output = net(input)  # shape (1, 10)
# MSELoss needs a floating-point target with the same shape as the output.
# torch.arange returns int64 by default, so request float32 and reshape to (1, 10).
target = torch.arange(0, 10, dtype=torch.float32).view(1, 10)
criterion = nn.MSELoss()
loss = criterion(output, target)
# grad_fn is the autograd node that produced `loss` (the head of its computation
# graph). A bare expression is only echoed in a REPL, so print it explicitly.
print(loss.grad_fn)

# 转载自 (reposted from): https://blog.csdn.net/zouxiaolv/article/details/83033188