JAX Compute SeLU function

1. SeLU (scaled exponential linear units) activation function calculation formula

selu ( x ) = λ { x  if  x > 0 α e x − α  if  x ⩽ 0. \text{selu}(x)= \lambda \begin{cases} x& \text{ if } x>0 \\ \alpha e^x-\alpha & \text{ if } x\leqslant 0. \end{cases} comb ( x )=l{ xa exa if x>0 if x0.

where λ = 1.0507009873554804934193349852946 \lambda=1.0507009873554804934193349852946l=1.0507009873554804934193349852946α = 1.6732632423543772848170429916717. \alpha=1.6732632423543772848170429916717.a=1.6732632423543772848170429916717.

2. JAX code implementation

#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
@Time        : 2022/7/20 13:51
@Author      : Albert Darren
@Contact     : [email protected]
@File        : Program1.1.py
@Version     : Version 1.0.0
@Description : TODO 利用jax计算selu函数,详见P12
@Created By  : PyCharm
"""
import jax.numpy as jnp  # 导入numpy计算包
from jax import random  # 导入random随机数包


def selu(x, alpha=1.6732632423543772848170429916717, lmbda=1.0507009873554804934193349852946):
    """
    实现selu激活函数
    :param x: 输入张量
    :param alpha: 预定义参数alpha
    :param lmbda: 预定义参数lambda,此处变量名故意拼写错误,避免与关键字lambda命名冲突
    :return: selu函数值
    """
    return lmbda * jnp.where(x > 0, x, alpha * (jnp.exp(x) - 1))


# 产生一个固定数字17作为key
key = random.PRNGKey(17)
# 随机生成一个大小为[1,5]的矩阵
x = random.normal(key, (5,))
print(selu(x))
# [-1.2497659   0.4546819   1.5760192  -0.81573856  0.27510932]

3. References

Guess you like

Origin blog.csdn.net/m0_46223009/article/details/125891430