pytorch基础——nn.Module模块

nn.Module模块

  • pytorch中所有的层结构和损失函数都来自于torch.nn
  • 所有的模型结构都是从nn.Module继承的
"""nn.Module模块定义一个计算图,并且这个结构可以复用多次"""
from torch import nn.Module


class net_name(nn.Module):
    # 继承
    def __init__(self, other_arguments):
        super(net_name, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels, kernel_size)
        # other network layer
        
    def forward(self, x):
        x = self.conv1
        return x
    
    

发布了165 篇原创文章 · 获赞 30 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/weixin_44478378/article/details/104296318
今日推荐