算法导论第四章4.2代码实现

最简单的矩阵乘法实现


public static int[][] squareMatrixMultiply(int[][] arr1,int [][] arr2) {
		int n = arr1.length;
		
		int[][] arr3 = new int [n][n];
		
		for(int i = 0;i < n;i++) {
			for(int j = 0;j < n;j++) {
				for(int k = 0;k < n;k++) {
					arr3[i][j] += arr1[i][k] * arr2[k][j];
				}
			}
		}
		
		return arr3;
	}

这种方法的实现是最简单的,时间复杂度也是最长的,是Θ(n^3).


递归实现矩阵乘法

矩阵划分代码

public static void divideMatrix(int[][] arr,int[][] arr11,int[][] arr12,int[][] arr21,int[][] arr22){
		int n = arr.length;
		int mid = n/2;
		for(int i =0;i < mid;i++) {
			for(int j = 0;j < mid;j++) {
				arr11[i][j] = arr[i][j];
				arr12[i][j] = arr[i][j+mid];
				arr21[i][j] = arr[i+mid][j];
				arr22[i][j] = arr[i+mid][j+mid];
			}
		}
	}

其实算法导论书上说矩阵划分的时间复杂度是Θ(1),但是我没想出怎么实现。。

矩阵相加减代码

public static int[][] addOrSubtractionMatrix(int[][] arr1,int[][] arr2,int flag) {
		int n = arr1.length;
		int[][] addResultArr = new int[n][n];
		for(int i = 0 ;i < n;i++) {
			for (int j = 0; j < n; j++) {
				if(flag == 1) {
					addResultArr[i][j] = arr1[i][j] + arr2 [i][j];					
				}else {
					addResultArr[i][j] = arr1[i][j] - arr2 [i][j];		
				}
			}
		}
		return addResultArr;
	}


矩阵结合代码

public static void combineMatrix(int[][] arr,int[][] arr11,int[][] arr12,int[][] arr21,int[][] arr22){
		int n = arr.length;
		int mid = n/2;
		for(int i =0;i < mid;i++) {
			for(int j = 0;j < mid;j++) {
				arr[i][j] = arr11[i][j];
				arr[i][j+mid] = arr12[i][j];
				arr[i+mid][j] = arr21[i][j];
				arr[i+mid][j+mid] = arr22[i][j];
			}
		}
	}

最终代码

	public static int[][] squareMatrixMultiply2(int[][] arr1,int [][] arr2){
		int n = arr1.length;


		int[][] arr3 = new int [n][n];
				
		if(n==1) {
			arr3[0][0]=arr1[0][0]*arr2[0][0];
		}else {
			int m = n/2;
			
			int[][] arrA11 = new int[m][m];
			int[][] arrA12 = new int[m][m];
			int[][] arrA21 = new int[m][m];
			int[][] arrA22 = new int[m][m];
			int[][] arrB11 = new int[m][m];
			int[][] arrB12 = new int[m][m];
			int[][] arrB21 = new int[m][m];
			int[][] arrB22 = new int[m][m];
			int[][] arrC11 = new int[m][m];
			int[][] arrC12 = new int[m][m];
			int[][] arrC21 = new int[m][m];
			int[][] arrC22 = new int[m][m];
			
			divideMatrix(arr1,arrA11,arrA12,arrA21,arrA22);
			divideMatrix(arr2,arrB11,arrB12,arrB21,arrB22);
			divideMatrix(arr3,arrC11,arrC12,arrC21,arrC22);
			
			arrC11 = addMatrix(squareMatrixMultiply2(arrA11,arrB11),squareMatrixMultiply2(arrA12,arrB21));
			arrC12 = addMatrix(squareMatrixMultiply2(arrA11,arrB12),squareMatrixMultiply2(arrA12,arrB22));
			arrC21 = addMatrix(squareMatrixMultiply2(arrA21,arrB11),squareMatrixMultiply2(arrA22,arrB21));
			arrC22 = addMatrix(squareMatrixMultiply2(arrA21,arrB12),squareMatrixMultiply2(arrA22,arrB22));
			combineMatrix(arr3,arrC11,arrC12,arrC21,arrC22);
		}
		
		return arr3;
		
	}

递归实现的思路还是很明确的,采用分块矩阵的乘法方式,将一个大的矩阵经过一次次的划分再分别进行分块矩阵乘法

这种方法并没有比原来的那种方式快导数

原来n*n矩阵相乘的时间复杂度是T(n)

现在n/2*n/2矩阵相乘的时间复杂度是T(n/2)

又经历的8次乘法和4次加减法

所以递归的时间复杂度是8T(n/2)+Θ(n^2)


 Strassen’s 矩阵乘法

public static int[][] strassenA(int[][] arrA,int[][] arrB) {
		int n = arrA.length;
		
		int[][] arr3 = new int [n][n];
		
		if(n==1) {
			arr3[0][0] = arrA[0][0]*arrB[0][0];
			return arr3;
		}
		int m = n/2;
		
		
		int[][] arrA11 = new int[m][m];
		int[][] arrA12 = new int[m][m];
		int[][] arrA21 = new int[m][m];
		int[][] arrA22 = new int[m][m];
		int[][] arrB11 = new int[m][m];
		int[][] arrB12 = new int[m][m];
		int[][] arrB21 = new int[m][m];
		int[][] arrB22 = new int[m][m];
		
		
		divideMatrix(arrA,arrA11,arrA12,arrA21,arrA22);
		divideMatrix(arrB,arrB11,arrB12,arrB21,arrB22);
		
		int[][] p1 =  addOrSubtractionMatrix(strassenA(arrA11,arrB12),strassenA(arrA11,arrB22),0);
		int[][] p2 =  addOrSubtractionMatrix(strassenA(arrA11,arrB22),strassenA(arrA12,arrB22),1);
		int[][] p3 =  addOrSubtractionMatrix(strassenA(arrA21,arrB11),strassenA(arrA22,arrB11),1);
		int[][] p4 =  addOrSubtractionMatrix(strassenA(arrA22,arrB21),strassenA(arrA22,arrB11),0);
		int[][] p5 =  addOrSubtractionMatrix(addOrSubtractionMatrix(strassenA(arrA11,arrB11),strassenA(arrA11,arrB22),1),addOrSubtractionMatrix(strassenA(arrA22,arrB11),strassenA(arrA22,arrB22),1),1);
		int[][] p6 =  addOrSubtractionMatrix(addOrSubtractionMatrix(strassenA(arrA12,arrB21),strassenA(arrA12,arrB22),1),addOrSubtractionMatrix(strassenA(arrA22,arrB21),strassenA(arrA22,arrB22),1),0);
		int[][] p7 =  addOrSubtractionMatrix(addOrSubtractionMatrix(strassenA(arrA11,arrB11),strassenA(arrA11,arrB12),1),addOrSubtractionMatrix(strassenA(arrA21,arrB11),strassenA(arrA21,arrB12),1),0);
		
		
		int[][] arrC11 = addOrSubtractionMatrix(addOrSubtractionMatrix(addOrSubtractionMatrix(p5,p4,1),p6,1),p2,0);
		int[][] arrC12 = addOrSubtractionMatrix(p1,p2,1);
		int[][] arrC21 = addOrSubtractionMatrix(p3,p4,1);;
		int[][] arrC22 = addOrSubtractionMatrix(addOrSubtractionMatrix(p5,p1,1),addOrSubtractionMatrix(p3,p7,1),0);
		
		combineMatrix(arr3,arrC11,arrC12,arrC21,arrC22);
		
		return arr3;
		
	}

减少了1次矩阵乘法,那除去的一次乘法由多次矩阵加法替代。

最终的时间复杂度为7T(n/2)+Θ(n^2),由此可见该方法是复杂度最低的



猜你喜欢

转载自blog.csdn.net/sscout/article/details/81016217
今日推荐