03计算图与动态图机制

一、计算图

1.1 计算图定义

定义: 计算图是用来描述运算的有向无环图

计算图有两个主要元素:

  • 结点(Node):表示数据,如向量,矩阵,张量
  • 边(Edge):表示运算,如加减乘除卷积等

示例:
在这里插入图片描述
用计算图表示:y = (x + w)*(w + 1)

  • a = x + w
  • b = w + 1
  • y = a * b

1.2 计算图与梯度求导

y=(x + w)*(w + 1)

  • a = x + w
  • b = w + 1
  • y = a * b

在这里插入图片描述
y w = y a a w + y b b w = b × 1 + a × 1 = b + a = ( w + 1 ) × ( x + w ) = 2 × w + x + 1 = 2 × 1 + 2 + 1 = 5 \frac{\partial y}{\partial w}=\frac{\partial y}{\partial a} \frac{\partial a}{\partial w}+\frac{\partial y}{\partial b} \frac{\partial b}{\partial w} \\ \quad = b\times1+a \times 1 \\ \quad = b + a \\ \quad = (w + 1) \times (x + w) \\ \quad =2\times w +x + 1 \\ \quad =2 \times 1 + 2 + 1 = 5

# -*- coding:utf-8 -*-
import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)     # retain_grad()
b = torch.add(w, 1)
y = torch.mul(a, b)

y.backward()
print(w.grad)

在这里插入图片描述
在这里插入图片描述
叶子结点: 用户创建的结点,如图中的x和w

  • is_leaf:指示张量是否为叶子结点
  • grad_fn:记录创建该张量时所用的方法(函数)

注意: 反向传播过程中,Pytorch默认只记录叶子结点的梯度,如果需要记录非叶子结点的梯度,则使用retain_grad()方法

# -*- coding:utf-8 -*-
import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)

a = torch.add(w, x)
# a.retain_grad()  # 保存非叶子结点的梯度
b = torch.add(w, 1)
y = torch.mul(a, b)

y.backward()
print(w.grad)

# 查看叶子结点
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)

# 查看梯度
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)

# 查看 grad_fn
print("grad_fn:\n", w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn)

在这里插入图片描述

二、Pytorch的动态图机制

根据计算图搭建方式,可以将计算图分为动态图和静态图

动态图:运算与搭建图同时进行,灵活,易调节
静态图:先搭建图,后运算,高效,但不灵活

#TensorFlow
import tensorflow as tf
first_counter = tf.constant(0)
second_counter = tf.constant(10)
def cond(first_counter,second_counter,*args):
    return first_counter < second_counter
def body(first_counter, second_counter):
    first_counter = tf.add(first_counter,2)
    second_counter = tf.add(second_counter,1)
    return first_counter,second_counter
c1,c2 = tf.while_loop(cond, body, [first_counter,second_counter])
with tf.Session() as sess:
    counter_1_res, counter_2_res = sess.run([c1,c2])
print("counter_1_res:",counter_1_res)
print("counter_2_res:",counter_2_res)

运行结果:
在这里插入图片描述

import torch
first_counter = torch.Tensor([0])
second_counter = torch.Tensor([10])
while (first_counter < second_counter)[0]:
    first_counter += 2
    second_counter += 1
print(first_counter)
print(second_counter)

输出:
在这里插入图片描述

发布了105 篇原创文章 · 获赞 9 · 访问量 7820

猜你喜欢

转载自blog.csdn.net/qq_36825778/article/details/104083023