算法设计与分析:Strassen矩阵乘法

参考以及转载链接:

https://blog.csdn.net/dawn_after_dark/article/details/78686488

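题目描述:

每组输入先给出矩阵规模 n，随后按行依次给出两个 n×n 的整数矩阵 A 和 B；要求用 Strassen 算法计算乘积 A×B，并逐行输出结果，同一行的元素之间以一个空格分隔。输入可以包含多组数据。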

代码上菜:

package cn.htu.test;

import java.util.Scanner;

/**
 * Square integer matrix multiplication using Strassen's divide-and-conquer algorithm
 * (seven half-size products per level instead of eight).
 *
 * <p>Input format for {@link #main}: repeatedly, a size {@code n} followed by the
 * {@code n*n} entries of matrix A and then the {@code n*n} entries of matrix B, all in
 * row-major order. For each case the product is printed one row per line, entries
 * separated by single spaces.
 */
public class StrassenMutipleMatrix {

    /**
     * Element-wise subtraction {@code result = matrixA - matrixB}.
     *
     * <p>The three arrays must have the same shape. Arrays are passed by reference, so the
     * caller allocates {@code result}; it may be the same array as {@code matrixA} or
     * {@code matrixB} because every cell is read before it is written.
     *
     * @param matrixA the minuend
     * @param matrixB the subtrahend
     * @param result  receives the difference
     */
    public static void matrixSub(int[][] matrixA, int[][] matrixB, int[][] result) {
        for (int i = 0; i < matrixA.length; i++) {
            for (int j = 0; j < matrixA[i].length; j++) {
                result[i][j] = matrixA[i][j] - matrixB[i][j];
            }
        }
    }

    /**
     * Element-wise addition {@code result = matrixA + matrixB}.
     *
     * <p>The three arrays must have the same shape. {@code result} may be the same array as
     * either operand because every cell is read before it is written.
     *
     * @param matrixA first addend
     * @param matrixB second addend
     * @param result  receives the sum
     */
    public static void matrixAdd(int[][] matrixA, int[][] matrixB, int[][] result) {
        for (int i = 0; i < matrixA.length; i++) {
            for (int j = 0; j < matrixA[i].length; j++) {
                result[i][j] = matrixA[i][j] + matrixB[i][j];
            }
        }
    }

    /**
     * Computes the product of the top-left {@code N x N} blocks of {@code matrixA} and
     * {@code matrixB} with Strassen's algorithm and writes it into the top-left
     * {@code N x N} block of {@code result}.
     *
     * <p>The recursion halves the matrix at every level, so it works directly only when
     * {@code N} is a power of two. For any other {@code N}, the operands are zero-padded to
     * the next power of two. The padding rows and columns contribute nothing to the product
     * and are discarded when the answer is copied back.
     *
     * <p>Arithmetic is plain {@code int} and wraps on overflow. Strassen's identities hold in
     * any ring, so the result equals the ordinary product modulo 2^32.
     *
     * @param N       the matrix size; {@code 0} is a no-op
     * @param matrixA left operand, at least {@code N x N}; not modified
     * @param matrixB right operand, at least {@code N x N}; not modified
     * @param result  receives the product; at least {@code N x N}
     * @throws IllegalArgumentException if {@code N} is negative
     */
    public static void Strassen(int N, int[][] matrixA, int[][] matrixB, int[][] result) {
        if (N < 0) {
            throw new IllegalArgumentException("matrix size must be non-negative: " + N);
        }
        if (N == 0) {
            // Nothing to multiply; recursing would never shrink a zero half-size.
            return;
        }
        if (Integer.bitCount(N) == 1) {
            strassenPowerOfTwo(N, matrixA, matrixB, result);
            return;
        }
        // A direct quadrant split would silently drop the last row/column here,
        // so pad to the next power of two instead.
        int size = Integer.highestOneBit(N) << 1;
        int[][] product = new int[size][size];
        strassenPowerOfTwo(size, padTo(size, N, matrixA), padTo(size, N, matrixB), product);
        for (int i = 0; i < N; i++) {
            System.arraycopy(product[i], 0, result[i], 0, N);
        }
    }

    /**
     * Runs the Strassen recursion for {@code n} a power of two ({@code n >= 1}).
     * Operands are only read; the product goes into the top-left {@code n x n}
     * block of {@code result}, written after all operand reads.
     */
    private static void strassenPowerOfTwo(int n, int[][] a, int[][] b, int[][] result) {
        if (n == 1) {
            result[0][0] = a[0][0] * b[0][0];
            return;
        }
        int half = n / 2;

        // Cut both operands into four quadrants.
        int[][] a11 = quadrant(a, 0, 0, half);
        int[][] a12 = quadrant(a, 0, half, half);
        int[][] a21 = quadrant(a, half, 0, half);
        int[][] a22 = quadrant(a, half, half, half);
        int[][] b11 = quadrant(b, 0, 0, half);
        int[][] b12 = quadrant(b, 0, half, half);
        int[][] b21 = quadrant(b, half, 0, half);
        int[][] b22 = quadrant(b, half, half, half);

        // Scratch space for operand sums/differences. Each recursive call only reads
        // them before returning, so they can be reused for the next product.
        int[][] tempA = new int[half][half];
        int[][] tempB = new int[half][half];

        // M1 = A11 (B12 - B22)
        int[][] m1 = new int[half][half];
        matrixSub(b12, b22, tempB);
        strassenPowerOfTwo(half, a11, tempB, m1);

        // M2 = (A11 + A12) B22
        int[][] m2 = new int[half][half];
        matrixAdd(a11, a12, tempA);
        strassenPowerOfTwo(half, tempA, b22, m2);

        // M3 = (A21 + A22) B11
        int[][] m3 = new int[half][half];
        matrixAdd(a21, a22, tempA);
        strassenPowerOfTwo(half, tempA, b11, m3);

        // M4 = A22 (B21 - B11)
        int[][] m4 = new int[half][half];
        matrixSub(b21, b11, tempB);
        strassenPowerOfTwo(half, a22, tempB, m4);

        // M5 = (A11 + A22)(B11 + B22)
        int[][] m5 = new int[half][half];
        matrixAdd(a11, a22, tempA);
        matrixAdd(b11, b22, tempB);
        strassenPowerOfTwo(half, tempA, tempB, m5);

        // M6 = (A12 - A22)(B21 + B22)
        int[][] m6 = new int[half][half];
        matrixSub(a12, a22, tempA);
        matrixAdd(b21, b22, tempB);
        strassenPowerOfTwo(half, tempA, tempB, m6);

        // M7 = (A11 - A21)(B11 + B12)
        int[][] m7 = new int[half][half];
        matrixSub(a11, a21, tempA);
        matrixAdd(b11, b12, tempB);
        strassenPowerOfTwo(half, tempA, tempB, m7);

        // C11 = M5 + M4 - M2 + M6
        int[][] c11 = new int[half][half];
        matrixAdd(m5, m4, c11);
        matrixSub(c11, m2, c11);
        matrixAdd(c11, m6, c11);

        // C12 = M1 + M2
        int[][] c12 = new int[half][half];
        matrixAdd(m1, m2, c12);

        // C21 = M3 + M4
        int[][] c21 = new int[half][half];
        matrixAdd(m3, m4, c21);

        // C22 = M5 + M1 - M3 - M7
        int[][] c22 = new int[half][half];
        matrixAdd(m5, m1, c22);
        matrixSub(c22, m3, c22);
        matrixSub(c22, m7, c22);

        // Reassemble the four result quadrants.
        placeQuadrant(result, c11, 0, 0);
        placeQuadrant(result, c12, 0, half);
        placeQuadrant(result, c21, half, 0);
        placeQuadrant(result, c22, half, half);
    }

    /** Copies the {@code half x half} block of {@code src} whose top-left corner is (rowOffset, colOffset). */
    private static int[][] quadrant(int[][] src, int rowOffset, int colOffset, int half) {
        int[][] block = new int[half][half];
        for (int i = 0; i < half; i++) {
            System.arraycopy(src[rowOffset + i], colOffset, block[i], 0, half);
        }
        return block;
    }

    /** Writes the square {@code block} into {@code dst} with its top-left corner at (rowOffset, colOffset). */
    private static void placeQuadrant(int[][] dst, int[][] block, int rowOffset, int colOffset) {
        for (int i = 0; i < block.length; i++) {
            System.arraycopy(block[i], 0, dst[rowOffset + i], colOffset, block.length);
        }
    }

    /** Returns a {@code size x size} zero matrix holding a copy of the top-left {@code n x n} block of {@code src}. */
    private static int[][] padTo(int size, int n, int[][] src) {
        int[][] padded = new int[size][size];
        for (int i = 0; i < n; i++) {
            System.arraycopy(src[i], 0, padded[i], 0, n);
        }
        return padded;
    }

    /**
     * Reads test cases from standard input until it is exhausted and prints each product.
     *
     * <p>Each case is {@code n}, then the {@code n*n} entries of A, then the {@code n*n}
     * entries of B, all row-major and whitespace-separated.
     */
    public static void main(String[] args) {
        try (Scanner input = new Scanner(System.in)) {
            while (input.hasNext()) {
                int n = input.nextInt();
                int[][] matrixA = readMatrix(input, n);
                int[][] matrixB = readMatrix(input, n);
                int[][] result = new int[n][n];
                Strassen(n, matrixA, matrixB, result);
                printMatrix(result);
            }
        }
    }

    /** Reads an {@code n x n} matrix in row-major order. */
    private static int[][] readMatrix(Scanner input, int n) {
        int[][] matrix = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                matrix[i][j] = input.nextInt();
            }
        }
        return matrix;
    }

    /** Prints one row per line, entries separated by a single space. */
    private static void printMatrix(int[][] matrix) {
        for (int[] row : matrix) {
            StringBuilder line = new StringBuilder();
            for (int j = 0; j < row.length; j++) {
                if (j > 0) {
                    line.append(' ');
                }
                line.append(row[j]);
            }
            System.out.println(line);
        }
    }
}

样例输入:

2
2 1
4 3
1 2
1 0

样例输出:

3 4
7 8

猜你喜欢

转载自www.cnblogs.com/shallow920/p/12909906.html