矩阵乘法问题

import numpy as np
# 矩阵乘法问题
# 问题描述:
# 假设A和B是两个n阶矩阵,求它们的乘积C,n=2*k
# 1、暴力法
# 根据矩阵相乘的性质,直接遍历求解
def matrix_multiply(a, b):
    n = a.shape[0]
    c = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            # 累加求和
            for k in range(n):
                c[i][j] += a[i][k] * b[k][j]
    return c
# 2、分治法
# 算法思想:
# 将矩阵的行和列从n/2处分开,化成4个子矩阵
# A = [[A11, A12] B = [[B11, B12],那么C = [[C11, C12]
#     [A21, A22]]     [B21, B22]]         [C21, C22]]
# C11 = A11 * B11 + A12 * B21,C12 = A11 * B12 + A12 * B22
# C21 = A21 * B11 + A22 * B21,C22 = A21 * B12 * A22 * B22
# 再继续划分成4个更小的子矩阵,直到化成1×1的矩阵为止
def divide_matrix(a, number):
    n = a.shape[0]
    rows, cols = n//2, n//2
    b = np.zeros((rows, cols))
    # 左上
    if number == 1:
        for i in range(rows):
            for j in range(cols):
                b[i][j] = a[i][j]
    # 右上
    elif number == 2:
        for i in range(rows):
            for j in range(cols):
                b[i][j] = a[i][j+cols]
    # 左下
    elif number == 3:
        for i in range(rows):
            for j in range(cols):
                b[i][j] = a[i+rows][j]
    # 右下
    else:
        for i in range(rows):
            for j in range(cols):
                b[i][j] = a[i+rows][j+cols]
    return b
def merge_matrix(c11, c12, c21, c22):
    n = c11.shape[0]
    c = np.zeros((2*n, 2*n), np.int32)
    for i in range(2*n):
        for j in range(2*n):
            if i < n:
                if j < n:
                    c[i][j] = c11[i][j]
                else:
                    c[i][j] = c12[i][j-n]
            else:
                if j < n:
                    c[i][j] = c21[i-n][j]
                else:
                    c[i][j] = c22[i-n][j-n]
    return c
def matrix_multiply(a, b):
    n = a.shape[0]
    if n == 1:
        c = a * b
        return c
    if n > 1:
        c = np.zeros((n, n))
        # 划分矩阵
        a11 = divide_matrix(a, 1)
        a12 = divide_matrix(a, 2)
        a21 = divide_matrix(a, 3)
        a22 = divide_matrix(a, 4)
        b11 = divide_matrix(b, 1)
        b12 = divide_matrix(b, 2)
        b21 = divide_matrix(b, 3)
        b22 = divide_matrix(b, 4)
        c11 = divide_matrix(c, 1)
        c12 = divide_matrix(c, 2)
        c21 = divide_matrix(c, 3)
        c22 = divide_matrix(c, 4)
        # 递归求解
        c11 = matrix_multiply(a11, b11) + matrix_multiply(a12, b21)
        c12 = matrix_multiply(a11, b12) + matrix_multiply(a12, b22)
        c21 = matrix_multiply(a21, b11) + matrix_multiply(a22, b21)
        c22 = matrix_multiply(a21, b12) + matrix_multiply(a22, b22)
        # 合并矩阵
        c = merge_matrix(c11, c12, c21, c22)
        return c

Guess you like

Origin blog.csdn.net/weixin_49346755/article/details/121481609
Recommended