Python은 동적 for 루프를 구현합니다.

배경

들어오는 사진과 텐서는 차원이 다를 수 있으므로 동일한 기능을 사용하여 처리하려면 동적 for 루프를 사용하여 이를 달성할 수 있습니다.

성취하다

pytorch의 텐서 텐서를 데이터 컨테이너로 사용합니다.

여기서 텐서 내의 형태 등의 매개변수는 형태를 반환하는 데 사용되지 않고 직접 전달됩니다. 동적 for 루프의 성립 과정을 누구나 쉽게 이해할 수 있도록 하기 위함입니다.

합계 구현

그림의 모든 요소와 차원이 다른 텐서의 합계를 계산하는 동적 for 루프를 구현합니다.

import torch


def dynamicFor(data, data_size):
    '''
    动态for循环之实现求和
    :param data: 传入的待求数据
    :param data_size: 数据的形状
    :return: 数据data中各元素之和
    '''
    count = 0
    place = 0
    ndim = len(data_size)  # 算出维度
    sum_num = 1
    sum_list = []
    sumx = 1
    for i in range(ndim):
        sum_num *= data_size[i]  # 计算所要计算的总次数
        if i != ndim - 1:
            sumx *= data_size[-(i + 1)]  # 计算在每个维度下的除数,方便后面用求余数的方法得到具体位置
            sum_list.append(sumx)
    sum_list.reverse()  # 记得取反,方便后面索引
    # print(sum_list)
    while place < sum_num:
        position = []  # 存储位置
        current_place = place
        for i in range(ndim):
            if i != ndim - 1:
                position.append(current_place // sum_list[i])
                # 下面这步很重要,去掉多余的位置记录,防止后面位置的小除数除出来越界的位置
                current_place = current_place - current_place // sum_list[i] * sum_list[i]
            else:
                position.append(current_place % sum_list[i - 1])
        result = torch.Tensor(data)
        for position_num in position:  # 访问具体位置以方便进行你要进行的操作,比如这里是求和
            result = result[position_num]
        count += result  # 访问到了具体位置,进行你要的操作,求和
        place += 1
    return count

올바른지 확인하십시오.

dim3 = torch.ones(400).reshape([5, 8, 10])
print(dynamicFor(dim3, [5, 8, 10]))

dim4 = torch.ones(400).reshape([5, 5, 4, 4])
print(dynamicFor(dim4, [5, 5, 4, 4]))

dim5 = torch.ones(400).reshape([2, 4, 2, 5, 5])
print(dynamicFor(dim5, [2, 4, 2, 5, 5]))

결과는 다음과 같습니다.

tensor(400.)
tensor(400.)
tensor(400.)

규범을 실현하다

(이 동적 for 루프를 하는 이유는 네트워크를 구축할 때 네트워크 가중치의 L2 norm을 계산하기 위해 손실 함수를 입력하려고 했기 때문입니다. 차원이 다릅니다.)

서로 다른 차원의 규범을 찾기 위해 동적 for 루프를 구축하려면 마지막 두 차원은 위치 인덱스에 포함되지 않지만 위치 인덱스는 행렬의 위치를 ​​찾는 데 사용됩니다. 이전 위치 지수가 계산됩니다.

import torch


def dynamicFor4Norm(data, data_size):
    '''
    动态for循环之求矩阵范数。即最后两个维度中矩阵的范数
    :param data: 待求的数据
    :param data_size: 数据data的形状
    :return: 数据data中,所有最后两个维度的矩阵的范数平方之和
    '''
    count = 0
    place = 0
    ndim = len(data_size)  # 算出总维度
    if ndim < 2:  # 如果维度小于2,那么应该就是偏置,就跳过
        return 0
        # raise Exception("Dimension of input is less than 2!")
    elif ndim == 2:  # 维度为2,直接求
        return torch.tensor(data).norm()
    else:
        sum_num = 1
        sum_list = []
        sumx = 1
        for i in range(ndim - 2):  # 因为要保留最后两个的维度,所以去掉最后2个维度的总次数和各维度除数的计算
            sum_num *= data_size[i]  # 计算前面所要计算的总次数
            if i != ndim - 3:
                sumx *= data_size[-2 - (i + 1)]  # 计算前面每个维度下的除数,方便后面用求余数的方法得到具体位置
                sum_list.append(sumx)
        sum_list.reverse()  # 记得取反,方便后面索引
        # print(sum_list)
        while place < sum_num:
            position = []  # 存储矩阵位置,方便便利所有的矩阵
            current_place = place
            for i in range(ndim - 2):
                if i != ndim - 3:
                    position.append(current_place // sum_list[i])
                    # 下面这步很重要,去掉多余的位置记录,防止后面位置的小除数除出来越界的位置
                    current_place = current_place - current_place // sum_list[i] * sum_list[i]
                else:
                    position.append(current_place % sum_list[i - 1])
            result = torch.Tensor(data)
            for position_num in position:  # 访问具体位置以方便进行你要进行的操作,比如这里是求范数的平方
                result = result[position_num]
            count += result.norm()  # 访问到了具体位置,进行你要的操作,求范数
            place += 1
    return count

테스트가 정확합니다.

# 用和前面同样的dim4、dim5
dim4 = torch.ones(400).reshape([5, 5, 4, 4])
dim5 = torch.ones(400).reshape([2, 4, 2, 5, 5])

print(dynamicFor4Norm(dim4, [5, 5, 4, 4]))
print(dynamicFor4Norm(dim5, [2, 4, 2, 5, 5]))

결과:

tensor(100.)
tensor(80.)

소주

일반적으로 Loss 함수에서 가중치 L2 놈의 제곱은 다음과 같이 계산됩니다.

이때 위의 43번째 줄은 노름을 찾기 위해 다음과 같이 변경되어야 합니다.

count += result.norm().pow(2)  # 没错,就是加了个pow(2)求平方

이때 동일한 입력 결과는 다음과 같습니다.

tensor(400.)
tensor(400.)

Supongo que te gusta

Origin blog.csdn.net/m0_46948660/article/details/129407470
Recomendado
Clasificación