PyTorch一:简介

几乎所有的深度学习框架都是基于计算图计算的,而计算图又可分为静态图和动态图,静态图先定义再运行,一次定义多次运行,而动态图是在运行过程中被定义的,在运行的时候构建,可以多次构建多次运行。PyTorch和TensorFlow都是基于计算图的深度学习框架,PyTorch使用的是动态图,而TensorFlow使用的是静态图。

静态图一定创建就不能修改,而且静态图定义的时候,使用了特殊的语法,就像新学一门语言。这意味着你无法使用if、while等常用的Python语句。因此静态图框架不得不为这些操作专门设计语法,同时在构件图的时候必须把所有可能出现的情况都包含进去,这也导致了静态图过于庞大,可能占用过高的显存。动态图框架就没有这个问题,它可以使用Python的if、while等条件语句,最终创建的计算图取决于你执行的条件分支。

我们来看看if条件语句在TensorFlow和Pytorch中的两种实现方式,第一个利用PyTorch动态图的方式实现。

import torch as t
from torch.autograd import Variable

N,D,H = 3,4,5
x = Variable(t.randn(N,D))
w1 = Variable(t.randn(D,H))
w2 = Variable(t.randn(D,H))

z = 10
if z>0:
    y = x.mm(w1)
else:
    y = x.mm(w2)

第二个利用TensorFlow静态图的方式实现。

import tensorflow as tf
import numpy as np

N,D,H = 3,4,5
x = tf.placeholder(tf.float32,shape=(N,D))
w1 = tf.placeholder(tf.float32,shape=(D,H))
w2 = tf.placeholder(tf.float32,shape=(D,H))
z = tf.placeholder(tf.float32,shape=None)

def f1():
    return tf.matmul(x,w1)
def f2():
    return tf.matmul(x,w2)

y = tf.cond(z>0,f1,f2)

sess = tf.Session()
values = {x:np.random.randn(N,D),z:10,w1:np.random.randn(D,H),w2:np.random.randn(D,H)}
y_val = sess.run(y,feed_dict=values)

参考:深度学习框架PyTorch入门与实战

猜你喜欢

转载自blog.csdn.net/qq_24946843/article/details/88875721