算法导论（CLRS）练习 4.2-3：修改 Strassen 算法，使其能计算 n 不是 2 的幂的 n×n 矩阵乘法（做法：n 为奇数时补一行一列零，使矩阵可以均分为四块）。

/**
 * Copies the four quadrants of {@code arr} into the caller-supplied quadrant arrays.
 *
 * <p>The quadrant side is {@code arr.length / 2} (integer division), so for an odd-sized
 * {@code arr} the last row and last column are not copied anywhere. Each output array must
 * already be allocated with at least that many rows and columns.
 *
 * @param arr   source matrix (at least {@code arr.length} rows)
 * @param arr11 receives the top-left quadrant
 * @param arr12 receives the top-right quadrant
 * @param arr21 receives the bottom-left quadrant
 * @param arr22 receives the bottom-right quadrant
 */
public static void divideMatrix(int[][] arr,int[][] arr11,int[][] arr12,int[][] arr21,int[][] arr22){
	final int half = arr.length / 2;
	for (int row = 0; row < half; row++) {
		// Hoist the two source rows that feed this output row.
		final int[] upper = arr[row];
		final int[] lower = arr[row + half];
		for (int col = 0; col < half; col++) {
			arr11[row][col] = upper[col];
			arr12[row][col] = upper[col + half];
			arr21[row][col] = lower[col];
			arr22[row][col] = lower[col + half];
		}
	}
}

/**
 * Returns a new matrix holding the element-wise sum or difference of two square matrices.
 *
 * <p>The result is {@code n x n} with {@code n = arr1.length}; {@code arr2} is assumed to be
 * at least that large. Arithmetic is plain {@code int} and wraps on overflow.
 *
 * @param arr1 left operand
 * @param arr2 right operand
 * @param flag {@code 1} to compute {@code arr1 + arr2}; any other value computes
 *             {@code arr1 - arr2}
 * @return a freshly allocated result matrix
 */
public static int[][] addOrSubtractionMatrix(int[][] arr1,int[][] arr2,int flag) {
	final int size = arr1.length;
	final boolean adding = (flag == 1);
	final int[][] result = new int[size][size];
	for (int r = 0; r < size; r++) {
		final int[] lhs = arr1[r];
		final int[] rhs = arr2[r];
		final int[] out = result[r];
		for (int c = 0; c < size; c++) {
			out[c] = adding ? lhs[c] + rhs[c] : lhs[c] - rhs[c];
		}
	}
	return result;
}
	

/**
 * Assembles four quadrant matrices into {@code arr}, overwriting its contents.
 *
 * <p>The quadrant side is {@code ceil(arr.length / 2)}. When {@code arr.length} is odd, the
 * right and bottom quadrants are one wider than the space left in {@code arr}. Their last
 * column/row (e.g. zero padding) is therefore dropped rather than written out of bounds.
 *
 * @param arr   destination matrix, {@code n x n}
 * @param arr11 top-left quadrant
 * @param arr12 top-right quadrant
 * @param arr21 bottom-left quadrant
 * @param arr22 bottom-right quadrant
 */
public static void combineMatrix(int[][] arr,int[][] arr11,int[][] arr12,int[][] arr21,int[][] arr22){
	final int n = arr.length;
	// (n + 1) / 2 equals n / 2 for even n and (n + 1) / 2 for odd n, i.e. ceil(n / 2).
	final int half = (n + 1) / 2;
	for (int r = 0; r < half; r++) {
		final boolean lowerRowExists = r + half < n;
		for (int c = 0; c < half; c++) {
			final boolean rightColExists = c + half < n;
			arr[r][c] = arr11[r][c];
			if (rightColExists) {
				arr[r][c + half] = arr12[r][c];
			}
			if (lowerRowExists) {
				arr[r + half][c] = arr21[r][c];
			}
			if (rightColExists && lowerRowExists) {
				arr[r + half][c + half] = arr22[r][c];
			}
		}
	}
}

/**
 * Multiplies two square matrices with Strassen's algorithm, extended to sizes that are not
 * powers of two (CLRS exercise 4.2-3).
 *
 * <p>If the current size {@code n} is odd, both operands are zero-padded to {@code n + 1} so
 * they split into four equal quadrants. The padding is dropped again when the product
 * quadrants are combined into the {@code n x n} result.
 *
 * <p>Each recursion level makes exactly seven recursive multiplications (P1..P7). These
 * operate on ten quadrant sums/differences (S1..S10), which gives Θ(n^lg 7) time. An earlier
 * version expanded every P into its quadrant products and made 20 recursive calls per level,
 * which is slower than the naive algorithm. Arithmetic is plain {@code int} and wraps on
 * overflow.
 *
 * @param arrA left operand, {@code n x n}
 * @param arrB right operand, {@code n x n}
 * @return a new {@code n x n} matrix equal to {@code arrA * arrB}; a 0 x 0 matrix if
 *         {@code n == 0}
 * @throws IllegalArgumentException if the operands have different row counts
 */
public static int[][] strassenA(int[][] arrA,int[][] arrB) {
	final int n = arrA.length;
	if (arrB.length != n) {
		throw new IllegalArgumentException(
				"matrix sizes differ: " + n + " vs " + arrB.length);
	}

	int[][] arr3 = new int[n][n];
	// Without this guard an empty matrix recursed forever (m == 0 never reaches n == 1).
	if (n == 0) {
		return arr3;
	}
	if (n == 1) {
		arr3[0][0] = arrA[0][0] * arrB[0][0];
		return arr3;
	}

	// Pad an odd-sized problem with one zero row and column so it halves evenly.
	int size = n;
	if (n % 2 != 0) {
		size = n + 1;
		int[][] arrA2 = new int[size][size];
		int[][] arrB2 = new int[size][size];
		for (int i = 0; i < n; i++) {
			System.arraycopy(arrA[i], 0, arrA2[i], 0, n);
			System.arraycopy(arrB[i], 0, arrB2[i], 0, n);
		}
		arrA = arrA2;
		arrB = arrB2;
	}
	final int m = size / 2;

	// Step 1: split both operands into m x m quadrants.
	int[][] arrA11 = new int[m][m];
	int[][] arrA12 = new int[m][m];
	int[][] arrA21 = new int[m][m];
	int[][] arrA22 = new int[m][m];
	int[][] arrB11 = new int[m][m];
	int[][] arrB12 = new int[m][m];
	int[][] arrB21 = new int[m][m];
	int[][] arrB22 = new int[m][m];
	divideMatrix(arrA, arrA11, arrA12, arrA21, arrA22);
	divideMatrix(arrB, arrB11, arrB12, arrB21, arrB22);

	// addOrSubtractionMatrix adds when flag == 1 and subtracts for any other flag.
	final int add = 1;
	final int sub = 0;

	// Step 2: the ten quadrant sums/differences.
	int[][] s1 = addOrSubtractionMatrix(arrB12, arrB22, sub);
	int[][] s2 = addOrSubtractionMatrix(arrA11, arrA12, add);
	int[][] s3 = addOrSubtractionMatrix(arrA21, arrA22, add);
	int[][] s4 = addOrSubtractionMatrix(arrB21, arrB11, sub);
	int[][] s5 = addOrSubtractionMatrix(arrA11, arrA22, add);
	int[][] s6 = addOrSubtractionMatrix(arrB11, arrB22, add);
	int[][] s7 = addOrSubtractionMatrix(arrA12, arrA22, sub);
	int[][] s8 = addOrSubtractionMatrix(arrB21, arrB22, add);
	int[][] s9 = addOrSubtractionMatrix(arrA11, arrA21, sub);
	int[][] s10 = addOrSubtractionMatrix(arrB11, arrB12, add);

	// Step 3: exactly seven recursive products.
	int[][] p1 = strassenA(arrA11, s1);
	int[][] p2 = strassenA(s2, arrB22);
	int[][] p3 = strassenA(s3, arrB11);
	int[][] p4 = strassenA(arrA22, s4);
	int[][] p5 = strassenA(s5, s6);
	int[][] p6 = strassenA(s7, s8);
	int[][] p7 = strassenA(s9, s10);

	// Step 4: C11 = P5 + P4 - P2 + P6, C12 = P1 + P2, C21 = P3 + P4, C22 = P5 + P1 - P3 - P7.
	int[][] arrC11 = addOrSubtractionMatrix(
			addOrSubtractionMatrix(addOrSubtractionMatrix(p5, p4, add), p2, sub), p6, add);
	int[][] arrC12 = addOrSubtractionMatrix(p1, p2, add);
	int[][] arrC21 = addOrSubtractionMatrix(p3, p4, add);
	int[][] arrC22 = addOrSubtractionMatrix(
			addOrSubtractionMatrix(addOrSubtractionMatrix(p5, p1, add), p3, sub), p7, sub);

	// Writes only the top-left n x n region, discarding any padding row/column.
	combineMatrix(arr3, arrC11, arrC12, arrC21, arrC22);
	return arr3;
}

猜你喜欢

转载自blog.csdn.net/sscout/article/details/81022864