minitorch系列记录——3. 自动求导

参考:

只使用标量构建minitorch的第一个版本,关键就是自动求导。
标量的自动求导

1. 基本功能完成

1.0 前置步骤

第一步:

将module0中实现的内容转移到module1中(注意,测试不一定能检测出所有问题,很可能会在这章有些问题,结果是module0造成的。)

# 这两个是替换
minitorch/operators.py 
minitorch/module.py 

# 下面三个是新增
tests/test_module.py 
tests/test_operators.py 

project/run_manual.py

第二步:

安装必要的包,切换到repo里面,即minitorch-module-1-CastleDream目录中。

python -m pip install -r requirements.txt
python -m pip install -r requirements.extra.txt
python -m pip install -Ue .
# 这步会把当前的minitorch-moduleX重新安装成一个新的包(基于本地的repo文件,所以repo文件不能随便换位置或者删除)
conda install llvmlite

Task 1.1 Numerical Derivatives

完成任务1,数值求导(需要有基本的高等数学求导基础)

PS:如果有数据公式无法正确显示,直接找个支持markdown语法的编辑器,放到公式编辑格式里,就可以看到了,如下:

f i ′ ( x 0 , … , x n − 1 ) f'_i(x_0, \ldots, x_{n-1}) fi(x0,,xn1)

完成之后,可以pytest -m task1_1只测试这一模块的内容。

@pytest.mark.task1_1
def test_central_diff():
    d = central_difference(operators.id, 5, arg=0)
    assert_close(d, 1.0)

    d = central_difference(operators.add, 5, 10, arg=0)
    assert_close(d, 1.0)

    d = central_difference(operators.mul, 5, 10, arg=0)
    assert_close(d, 10.0)

    d = central_difference(operators.mul, 5, 10, arg=1)
    assert_close(d, 5.0)

    d = central_difference(operators.exp, 2, arg=0)
    assert_close(d, operators.exp(2.0))

在这里插入图片描述
验证成功,显示如上图


Task 1.2 Scalars

这部分就是实现梯度传播前的基础,主要是标量的计算和前向传播。

  • 要实现这部分,需要对minitorch.Scalar这个类很熟悉,可以先去看一下Tracking Variables这个追踪变量的引导手册,可能会刷新你对Python numerical overrides的认知
  • 实现重载的数学函数需要minitorch.Scalar类,每种方法都需要将内部Python操作符连接到正确的minitorch.Function.forward()调用。

注意,是先完成这部分代码的填空:

class Mul(ScalarFunction):
   "Multiplication function"

   @staticmethod
   def forward(ctx, a, b):
       # TODO: Implement for Task 1.2.
       raise NotImplementedError('Need to implement for Task 1.2')

   @staticmethod
   def backward(ctx, d_output):
       # TODO: Implement for Task 1.4.
       raise NotImplementedError('Need to implement for Task 1.4')

然后再去补充这部分代码:

def __lt__(self, b):
    # TODO: Implement for Task 1.2.
    raise NotImplementedError('Need to implement for Task 1.2')

基类ScalarFunction是这样:

class ScalarFunction(FunctionBase):
    """
	处理和生成标量Variables的数学函数的包装器
	是一个从来不会被实例化的静态类,使用这个类把 `forward` 和`backward` 代码组合到一起
    """

    @staticmethod
    def forward(ctx, *inputs):
        r"""
        前向调用,主要就是计算math:`f(x_0 \ldots x_{n-1})`.

        Args:
            ctx (:class:`Context`): 
            一个存储反向调用过程中可能用到的所有信息的容器对象
            *inputs (list of floats): n-float values :math:`x_0 \ldots x_{n-1}`.

        Should return float the computation of the function :math:`f`.
        """
        pass  # pragma: no cover

    @staticmethod
    def backward(ctx, d_out):
        r"""
        反向调用, 计算math:`f'_{x_i}(x_0 \ldots x_{n-1}) \times d_{out}`.

        Args:
            ctx (Context): A container object holding any information saved during in the corresponding `forward` call.
            d_out (float): :math:`d_out` term in the chain rule.

        Should return the computation of the derivative function
        :math:`f'_{x_i}` for each input :math:`x_i` times `d_out`.

        """
        pass  # pragma: no cover

    # Checks.
    variable = Scalar
    data_type = float

    @staticmethod
    def data(a):
        return a

简单来说,

  • 正向是计算 f ( x 0 … x n − 1 ) f(x_0 \ldots x_{n-1}) f(x0xn1),返回计算结果(一般是一个浮点数)
  • 反向是计算 f x i ′ ( x 0 … x n − 1 ) × d o u t f'_{x_i}(x_0 \ldots x_{n-1}) \times d_{out} fxi(x0xn1)×dout,返回每个参数的导数(用到几个参数,返回对应于参数个数个导数)

填完之后,测试pytest -m task1_2
在这里插入图片描述

Task 1.3 Chain Rule

  • 这个任务比较困难,请确保你了解链式规则( chain rule)、变量(Variables)和 函数(Functions)。
  • 请先仔细阅读自动微分的先导手册Autodifferentiation ,同时读一些其他ScalarFunctions的代码
  • 对于含有任意个参数的函数,基于FunctionBase这个类实现chain_rule函数。这个函数应该具有以下功能:
    • 可以通过传递一个 c o n t e x t context context(上下文变量)和 d o u t d_{out} dout来收集局部梯度,进而达到反向处理函数的目的
    • 它需要将梯度和对应的变量正确配对,并返回。也是在这个函数中,过滤掉传给前向的常数,但是不需要导数。

代码完成后,测试使用pytest -m task1_3
在这里插入图片描述

Task 1.4 Backpropagation

完成这部分任务之前先看反向传播引导文档

pytest -m task1_4

Task 1.5 Training

2. 检查并上传

# 检查全部
black .
# 也可以只检查特定的几个文件
black minitorch/ tests/ project/

flake8
# 也可以只检查特定的几个文件
flake8 minitorch/ tests/ project/

猜你喜欢

转载自blog.csdn.net/Castlehe/article/details/121291776