pytorch 使用

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_16949707/article/details/72571509

1 DataParallel

from torch.nn import DataParallel
net = DataParallel(net)

可以实现模块级别(?好处具体是啥不大懂)的并行计算,可以将一个模块forward部分分到各个gpu去计算,然后backwards时,合并gradients 到original module。

    >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
    >>> output = net(input_var)

2 DataLoader

其实这里trainset已经包含数据集了,dataloader只是定义输入网络的一些参数,入batch_size等等。

这里写图片描述

3 Transform

对数据集进行的操作
这里写图片描述
compose函数会将多个transforms包在一起。

参考:
http://www.jianshu.com/p/8da9b24b2fb6

猜你喜欢

转载自blog.csdn.net/qq_16949707/article/details/72571509