Debugging network models [PyTorch version]

Test whether data passes through the network without errors

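import torch

if __name__ == '__main__':
    x = torch.randn(2, 3, 256, 256)  # dummy batch: 2 images, 3 channels, 256x256 pixels
    net = Union_Seg_1_v1()           # the network under test (defined elsewhere)
    print(net(x).shape)              # raises an error if any layer's shapes do not match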

Here, the arguments are torch.randn(batch_size, channel, size[0], size[1]):

batch_size: the number of samples fed to the network in one forward pass

channel: the number of input channels (for example, 3 for RGB images)

size: the spatial size of the input image, with size[0] the height and size[1] the width
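A quick way to see what this layout means in practice: PyTorch image layers expect tensors in (batch, channel, height, width) order, and the channel count must match the first layer's in_channels. The sketch below uses a single nn.Conv2d as an illustrative stand-in for the first layer of a real network; the layer sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn

batch_size, channel, height, width = 2, 3, 256, 256
x = torch.randn(batch_size, channel, height, width)  # NCHW layout
print(x.shape)  # torch.Size([2, 3, 256, 256])

# Illustrative first layer: it expects 3 input channels
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
print(conv(x).shape)  # torch.Size([2, 16, 256, 256]); padding=1 keeps 256x256

# An input with the wrong channel count fails at the first layer
bad = torch.randn(batch_size, 1, height, width)
try:
    conv(bad)
except RuntimeError as e:
    print("channel mismatch:", e)
```

If the smoke test fails right at the first layer, a channel mismatch like this one is a likely cause.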

First, create a dummy input tensor in the format the network expects.

Then, instantiate the network.

Finally, pass the input through the network and print the shape of the output.
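The three steps above can be wrapped in a small reusable check. Union_Seg_1_v1 is not shown in this post, so the sketch uses TinySegNet, a hypothetical stand-in (one conv encoder with pooling plus an upsampling decoder); smoke_test is likewise a hypothetical helper name. Switching to eval mode and torch.no_grad() are optional choices here: they skip gradient bookkeeping, which a pure shape check does not need.

```python
import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for Union_Seg_1_v1: conv encoder + upsampling decoder."""

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # halves height and width
        )
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(16, num_classes, kernel_size=1),  # per-pixel class scores
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


def smoke_test(net, input_shape):
    """Run one dummy batch through net and return the output shape."""
    x = torch.randn(*input_shape)  # step 1: dummy input
    net.eval()                     # put BatchNorm/Dropout (if any) in inference mode
    with torch.no_grad():          # no autograd graph needed for a shape check
        out = net(x)               # step 3: forward pass
    return out.shape


if __name__ == "__main__":
    shape = smoke_test(TinySegNet(), (2, 3, 256, 256))  # step 2: the network
    print(shape)  # torch.Size([2, 2, 256, 256])
```

For a segmentation network, a useful extra check is that the output keeps the input's height and width, as it does here.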

Print network structure

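import torch
from torchsummary import summary

# Use device to choose whether the network runs on the GPU or the CPU.
# torchsummary only accepts device='cuda' or 'cpu' and builds its dummy input on the
# current default CUDA device, so use 'cuda' (not e.g. 'cuda:1') unless you call torch.cuda.set_device first.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = DPNet_v1()
model = net.to(device)
# input_size = (channel, height, width), without the batch dimension
summary(model, input_size=(3, 256, 256), device=device.type)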

This requires the torchsummary package. Install it with pip install torchsummary or conda install -c ravelbio torchsummary.
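If torchsummary is not available, per-layer output shapes and parameter counts can also be collected with PyTorch's own forward hooks (nn.Module.register_forward_hook). This is a minimal sketch, not torchsummary's implementation: layer_summary is a hypothetical helper, and it assumes every leaf layer returns a single tensor.

```python
import torch
import torch.nn as nn


def layer_summary(model, input_size, device="cpu"):
    """Print each leaf layer's output shape and parameter count.

    input_size excludes the batch dimension, as in torchsummary.
    Assumes every leaf layer returns a single tensor.
    """
    rows, hooks = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            # Count only this layer's own parameters, not its children's
            n_params = sum(p.numel() for p in module.parameters(recurse=False))
            rows.append((name, type(module).__name__, tuple(output.shape), n_params))
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf layers only
            hooks.append(module.register_forward_hook(make_hook(name)))

    model.to(device).eval()
    with torch.no_grad():
        model(torch.randn(2, *input_size, device=device))  # dummy batch of 2
    for h in hooks:
        h.remove()  # detach the hooks again

    for name, kind, shape, n in rows:
        print(f"{name:<20} {kind:<12} {str(shape):<22} {n}")
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params: {total}  Trainable params: {trainable}")
    return rows, total
```

For example, layer_summary(nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 1)), (3, 256, 256)) lists three layers and reports 233 parameters in total (224 for the first convolution, 0 for ReLU, 9 for the 1x1 convolution).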


Origin blog.csdn.net/qq_41704436/article/details/131147158