矩阵乘法Strassen算法

这里主要给出实现方法；关于算法本身的介绍，可以参考 @金戈大王 的文章。
下面的算法存在 bug：虽然已经对部分非 n×n 的矩阵做了处理，但并不完善，计算非 n×n 的矩阵时会出现数组越界异常。因此，两个相乘的矩阵必须都是 n×n 的。
不啰嗦,下面直接贴出Java的实现方法:

//调用入口:


         public static int[][] StrassenMulti(int a [][], int b[][])
	  {
		int acow = a.length, acol = a[0].length, bcow = b.length, bcol = b[0].length;
		if(acow!=acol || bcow != bcol || acow != bcow) return MatrixMulti(a,b);
		if((acow &(acow -1 )) != 0) return MatrixMulti(a,b);//不是2的幂
		//只有符合2的幂才满足Strassen算法使用条件
		if(acow == 2) return MatrixMulti(a,b);
		else 
		{
			int[][] A11 = new int[acow/2][acow/2];
			int[][] A12 = new int[acow/2][acow/2];
			int[][] A21 = new int[acow/2][acow/2];
			int[][] A22 = new int[acow/2][acow/2];
			nnMatrixSplitTo4Block(a,A11,A12,A21,A22);
			int[][] B11 = new int[acow/2][acow/2];
			int[][] B12 = new int[acow/2][acow/2];
			int[][] B21 = new int[acow/2][acow/2];
			int[][] B22 = new int[acow/2][acow/2];
			nnMatrixSplitTo4Block(b,B11,B12,B21,B22);
			
			int [][]S1 = MatrixNeg(B12 ,B22);
			int [][]S2 = MatrixPlus(A11 , A12);
			int [][]S3 = MatrixPlus(A21 , A22);
			int [][]S4 = MatrixNeg(B21 , B11);
			int [][]S5 = MatrixPlus(A11 , A22);
			int [][]S6 = MatrixPlus(B11 ,B22);
			int [][]S7 = MatrixNeg(A12 , A22);
			int [][]S8 = MatrixPlus(B21 , B22);
			int [][]S9 = MatrixNeg(A11 , A21);
			int [][]S10 = MatrixPlus(B11 , B12);
			
			int [][]P1 = StrassenMulti(A11 , S1);
			int [][]P2 = StrassenMulti(S2 , B22);
			int [][]P3 = StrassenMulti(S3 , B11);
			int [][]P4 = StrassenMulti(A22 , S4);
			int [][]P5 = StrassenMulti(S5 ,S6);
			int [][]P6 = StrassenMulti(S7 , S8);
			int [][]P7 = StrassenMulti( S9 , S10);
			
			int [][]C11 = MatrixPlus(MatrixNeg(MatrixPlus(P5, P4), P2) , P6);
			int [][]C12 = MatrixPlus(P1 , P2);
			int [][]C21 = MatrixPlus(P3,  P4);
			int [][]C22 = MatrixNeg(MatrixNeg(MatrixPlus(P5 , P1) , P3) , P7);
			return MatrixBlockPlus(C11,C12,C21,C22);

		}
	
		  
	  }
//将矩阵分为四个子矩阵
	  /**
	   * Copies the top-left region of {@code src} into four pre-allocated blocks laid out as
	   * [[A11, A12], [A21, A22]].
	   *
	   * <p>The block sizes define the split. The first {@code A11.length} rows go to
	   * {@code A11}/{@code A12} and the next {@code A21.length} rows go to {@code A21}/{@code A22}.
	   * Within each row, the first "left width" columns go to the left block and the next
	   * "right width" columns go to the right block. The widths are taken from {@code A11}/
	   * {@code A12}, or from {@code A21}/{@code A22} when the top block-row is empty.
	   *
	   * @param src source matrix, indexed {@code src[row][col]}; must be at least as large as
	   *     the blocks together
	   * @param A11 top-left destination block (overwritten)
	   * @param A12 top-right destination block (overwritten)
	   * @param A21 bottom-left destination block (overwritten)
	   * @param A22 bottom-right destination block (overwritten)
	   */
	  public static void nnMatrixSplitTo4Block(int [][]src, int [][]A11, int [][]A12, int [][]A21, int [][]A22)
	  {
	      int topRows = A11.length;
	      int totalRows = topRows + A21.length;
	      if (totalRows == 0) return; // every block is empty: nothing to copy
	      // Both block-rows share the same column boundary, so read the widths from whichever
	      // block-row actually has rows.
	      int leftCols = (topRows > 0 ? A11 : A21)[0].length;
	      int rightCols = (topRows > 0 ? A12 : A22)[0].length;
	      for (int i = 0; i < totalRows; i++)
	      {
	          boolean inTop = i < topRows;
	          int[] left = inTop ? A11[i] : A21[i - topRows];
	          int[] right = inTop ? A12[i] : A22[i - topRows];
	          // The right block starts at the LEFT width in src. The old code offset the
	          // bottom-right block by the right width, which breaks unequal splits.
	          System.arraycopy(src[i], 0, left, 0, leftCols);
	          System.arraycopy(src[i], leftCols, right, 0, rightCols);
	      }
	  }

//将四个矩阵合并为一个矩阵
	  
	  /**
	   * Assembles four blocks into one new matrix laid out as [[A11, A12], [A21, A22]].
	   *
	   * <p>Each top row is {@code A11[i]} followed by {@code A12[i]}. Each bottom row is
	   * {@code A21[i]} followed by {@code A22[i]}.
	   *
	   * @param A11 top-left block
	   * @param A12 top-right block; must have as many rows as {@code A11}
	   * @param A21 bottom-left block
	   * @param A22 bottom-right block; must have as many rows as {@code A21}
	   * @return the assembled matrix, or {@code null} if the blocks in a block-row have different
	   *     row counts or the two block-rows have different total widths
	   */
	  public static int[][] MatrixBlockPlus(int [][]A11, int [][]A12, int [][]A21, int [][]A22)
	  {
	      // Blocks in one block-row must have equal heights, and both block-rows must have the
	      // same total width; otherwise the result would be ragged or read past a block.
	      if (A11.length != A12.length || A21.length != A22.length
	              || A11[0].length + A12[0].length != A21[0].length + A22[0].length) return null;

	      int topRows = A11.length;
	      int width = A11[0].length + A12[0].length;
	      int[][] result = new int[topRows + A21.length][width];
	      for (int i = 0; i < topRows; i++)
	      {
	          System.arraycopy(A11[i], 0, result[i], 0, A11[0].length);
	          System.arraycopy(A12[i], 0, result[i], A11[0].length, A12[0].length);
	      }
	      // The bottom block-row is split at A21's width. The old code used A12's width here,
	      // which reads past A21 whenever the left and right widths differ.
	      for (int i = 0; i < A21.length; i++)
	      {
	          int[] row = result[topRows + i];
	          System.arraycopy(A21[i], 0, row, 0, A21[0].length);
	          System.arraycopy(A22[i], 0, row, A21[0].length, A22[0].length);
	      }
	      return result;
	  }
//矩阵减法
	  /**
	   * Subtracts two matrices element-wise.
	   *
	   * <p>Computed directly rather than as {@code a + (-1 * b)}. This avoids an extra temporary
	   * matrix and gives the same int results, since Java integer arithmetic wraps.
	   *
	   * @param a minuend, indexed {@code a[row][col]}
	   * @param b subtrahend, indexed {@code b[row][col]}
	   * @return a new matrix {@code a - b}; an empty matrix if both inputs have no rows;
	   *     {@code null} if the row or column counts differ
	   */
	  public static int[][] MatrixNeg(int a[][], int b[][])
	  {
	      if (a.length != b.length) return null;
	      int rows = a.length;
	      if (rows == 0) return new int[0][0];
	      int cols = a[0].length;
	      if (cols != b[0].length) return null;
	      int[][] diff = new int[rows][cols];
	      // Iterate over the column count, not the row count, so non-square inputs work.
	      for (int r = 0; r < rows; r++)
	          for (int c = 0; c < cols; c++)
	              diff[r][c] = a[r][c] - b[r][c];
	      return diff;
	  }
	  
//矩阵加法
	  /**
	   * Adds two matrices element-wise.
	   *
	   * @param a first addend, indexed {@code a[row][col]}
	   * @param b second addend, indexed {@code b[row][col]}
	   * @return a new matrix {@code a + b}; an empty matrix if both inputs have no rows;
	   *     {@code null} if the row or column counts differ
	   */
	  public static int[][] MatrixPlus(int a[][], int b[][])
	  {
	      if (a.length != b.length) return null;
	      int rows = a.length;
	      if (rows == 0) return new int[0][0];
	      int cols = a[0].length;
	      if (cols != b[0].length) return null;
	      int[][] sum = new int[rows][cols];
	      // The inner bound is the column count. The old code used b.length (the row count),
	      // which drops columns or overruns rows on non-square inputs.
	      for (int r = 0; r < rows; r++)
	          for (int c = 0; c < cols; c++)
	              sum[r][c] = a[r][c] + b[r][c];
	      return sum;
	  }
	  
//矩阵乘法
	  /**
	   * Computes the product {@code a * b} with the straightforward triple loop.
	   *
	   * @param a left operand with {@code m} rows and {@code k} columns; must have at least one row
	   * @param b right operand with {@code k} rows and {@code n} columns; must have at least one row
	   * @return a new {@code m x n} matrix, or {@code null} when the column count of {@code a}
	   *     differs from the row count of {@code b}
	   */
	  public static int[][] MatrixMulti(int a[][], int b[][])
	  {
	      final int outRows = a.length;
	      final int outCols = b[0].length;
	      final int inner = a[0].length;
	      // Inner dimensions must agree for the product to exist.
	      if (inner != b.length) {
	          return null;
	      }

	      int[][] product = new int[outRows][outCols];
	      for (int r = 0; r < outRows; r++) {
	          int[] rowOfA = a[r];
	          for (int c = 0; c < outCols; c++) {
	              // Dot product of row r of a with column c of b.
	              int dot = 0;
	              for (int t = 0; t < inner; t++) {
	                  dot += rowOfA[t] * b[t][c];
	              }
	              product[r][c] = dot;
	          }
	      }
	      return product;
	  }

猜你喜欢

转载自blog.csdn.net/qq_30268545/article/details/80598365
今日推荐