# python 激活函数图像代码 (Python code for plotting activation functions and their gradients)

import matplotlib.pylab as plt
import torch

def xyplot(x, y, name, figsize=None):
    """Plot ``y`` against ``x`` and display the figure.

    Args:
        x: torch tensor of sample points (may have ``requires_grad=True``).
        y: torch tensor of values to plot against ``x``, e.g. an
            activation's output or its gradient.
        name: label for the y-axis, rendered as ``"<name>(x)"``.
        figsize: optional ``(width, height)`` in inches. When given, a new
            figure of that size is created before plotting; the default
            ``None`` keeps matplotlib's current figure/size behaviour.
    """
    if figsize is not None:
        plt.figure(figsize=figsize)
    # detach() drops the autograd graph (numpy() refuses tensors that
    # require grad); cpu() lets numpy() work for tensors on an accelerator
    # and returns the tensor itself when it is already on the CPU.
    plt.plot(x.detach().cpu().numpy(), y.detach().cpu().numpy())
    plt.xlabel('x')
    plt.ylabel(name + '(x)')
    plt.show()

x = torch.arange(-8, 8, 0.1, requires_grad=True)

# (plot label, element-wise activation) pairs, plotted in this order.
ACTIVATIONS = (
    ('relu', torch.relu),
    ('sigmoid', torch.sigmoid),
    ('tanh', torch.tanh),
)

for name, fn in ACTIVATIONS:
    y = fn(x)
    xyplot(x, y, name)
    # Each activation is element-wise, so d(sum f(x))/dx_i == f'(x_i): the
    # gradient of the summed output is exactly the pointwise derivative.
    # autograd.grad returns a fresh tensor instead of accumulating into
    # x.grad, so no manual zero_() is needed between activations.
    (grad,) = torch.autograd.grad(y.sum(), x)
    xyplot(x, grad, 'grad of ' + name)

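# 猜你喜欢 ("You may also like" — navigation text left over from the scraped blog page)

# 转载自 (reposted from) blog.csdn.net/qq_40107571/article/details/131396752
# 今日推荐 ("Today's recommendation" — navigation text left over from the scraped blog page)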