矩阵的乘法通用模板(C++/Java)

矩阵的乘法

0x00 C++版本

#include <bits/stdc++.h>

using namespace std;
// Capacity of Matrix's fixed backing buffer (maximum rows and columns).
constexpr int N = 25;
// Initial value of the best deviation found so far; any finite deviation beats it.
constexpr double INF = 1e19;
// Minimum improvement required before a new candidate replaces the current best.
constexpr double EPS = 1e-6;

// Dense matrix stored in a fixed N x N buffer; only the top-left
// rowCount x colCount entries are meaningful.
struct Matrix
{
    int rowCount;
    int colCount;
    double mat[N][N];
};

Matrix e;     // NOTE(review): not referenced anywhere in this file
Matrix a, b;  // read in main() and then transposed; dfs/calc_diff read them as globals

// 矩阵乘法
// Matrix product: returns a * b, an a.rowCount x b.colCount matrix.
// Precondition (not checked): a.colCount == b.rowCount.
// Operands are taken by const reference. Each Matrix carries an N x N double
// buffer, and this function runs once per leaf of the dfs search, so copying
// both operands on every call was pure overhead.
Matrix mat_mult(const Matrix &a, const Matrix &b)
{
    Matrix ans;
    ans.rowCount = a.rowCount;
    ans.colCount = b.colCount;
    for (int i = 0; i < a.rowCount; ++i)
    {
        for (int j = 0; j < b.colCount; ++j)
        {
            double tmp = 0;
            for (int k = 0; k < a.colCount; ++k)
                tmp += a.mat[i][k] * b.mat[k][j];
            ans.mat[i][j] = tmp;
        }
    }
    return ans;
}

//矩阵转置
// Matrix transpose: returns the a.colCount x a.rowCount matrix whose (c, r)
// entry equals a's (r, c) entry.
Matrix mat_transpose(Matrix a)
{
    Matrix t;
    t.rowCount = a.colCount;
    t.colCount = a.rowCount;
    for (int c = 0; c < a.colCount; ++c)
        for (int r = 0; r < a.rowCount; ++r)
            t.mat[c][r] = a.mat[r][c];
    return t;
}

//读入矩阵
// Reads an m x n matrix from standard input (row-major, whitespace separated)
// into a, and records its dimensions in a.rowCount / a.colCount.
// Throws std::out_of_range if m or n is negative or exceeds the capacity of
// a.mat. Without this check an oversized input indexed past the bounds of the
// fixed buffer (undefined behavior).
void read_mat(Matrix &a, int m, int n)
{
    // Capacity is taken from the buffer itself so it always matches Matrix.
    const int maxRows = static_cast<int>(sizeof(a.mat) / sizeof(a.mat[0]));
    const int maxCols = static_cast<int>(sizeof(a.mat[0]) / sizeof(a.mat[0][0]));
    if (m < 0 || n < 0 || m > maxRows || n > maxCols)
        throw std::out_of_range("read_mat: matrix dimensions exceed Matrix capacity");

    a.rowCount = m;
    a.colCount = n;
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            std::cin >> a.mat[i][j];
}

//矩阵打印
// Prints a separator line, then res one row per line, each entry followed by a space.
void print_mat(Matrix res)
{
    std::cout << "-------------------" << std::endl;
    for (int r = 0; r < res.rowCount; ++r)
    {
        for (int c = 0; c < res.colCount; ++c)
            std::cout << res.mat[r][c] << " ";
        std::cout << std::endl;
    }
}

//计算偏差
double calc_diff(Matrix Y)
{
   double diff = 0;
   for(int i = 0 ; i < Y.rowCount; i++)
   {
       for(int j = 0; j < Y.colCount; j++)
       {
           diff += fabs(Y.mat[i][j] - b.mat[i][j])/b.mat[i][j];
       }
   }
   return diff/Y.rowCount;
}
// Best (smallest) deviation found so far by dfs; starts at INF so any finite deviation beats it.
double diff = INF;
Matrix X;    // proportion vector handed to dfs by main()
Matrix ans;  // best proportion vector found by dfs

void dfs(Matrix X,int n,int cnt,double sum)
{
   //可行性剪枝
   if(sum > 1.0 || (cnt == n&&sum != 1))
       return;

   if(cnt == n)
   {
      Matrix res = mat_mult(a,X);
      double now_diff = calc_diff(res);
      if(diff - now_diff > EPS)
      {
          diff = now_diff;
          ans = X;
      }
      return;

   }
   for (int j = 0; j <= 100; ++j)
   {
       X.mat[cnt][0] = j / 100.0;
       dfs(X,n,cnt+1,sum+j/100.0);
   }
}

// Reads the batch matrix A (m batches x n indicators) and the target matrix B
// (x rows x y columns). It then searches for the best whole-percentage blend
// with dfs() and prints the proportions, the deviation and the blended
// indicators.
// Returns 1 with a message on stderr when the input is malformed, or when the
// target does not supply one value per indicator (calc_diff compares them row
// by row).
int main()
{
    int m = 0, n = 0, x = 0, y = 0;

    if (!(std::cin >> m >> n) || m <= 0 || n <= 0)
    {
        std::cerr << "invalid dimensions for the batch matrix" << std::endl;
        return 1;
    }
    read_mat(a, m, n);
    if (!std::cin)
    {
        std::cerr << "failed to read the batch matrix" << std::endl;
        return 1;
    }
    // Transpose to n x m so that a * X, with X the m x 1 proportion vector,
    // yields the n x 1 blended indicator vector.
    a = mat_transpose(a);

    if (!(std::cin >> x >> y) || x <= 0 || y <= 0)
    {
        std::cerr << "invalid dimensions for the target matrix" << std::endl;
        return 1;
    }
    read_mat(b, x, y);
    if (!std::cin)
    {
        std::cerr << "failed to read the target matrix" << std::endl;
        return 1;
    }
    // Transpose so each row holds one indicator, matching the rows of a * X.
    b = mat_transpose(b);

    // calc_diff reads b.mat[i][*] for every row i of a * X; otherwise it would
    // read entries that were never set.
    if (b.rowCount != a.rowCount)
    {
        std::cerr << "target must provide one value per indicator" << std::endl;
        return 1;
    }

    X.rowCount = a.colCount;
    X.colCount = 1;
    dfs(X, a.colCount, 0, 0);

    std::cout << "TargetX:" << std::endl;
    print_mat(ans);

    std::cout << "diff:" << diff << std::endl;

    std::cout << "Y:" << std::endl;
    print_mat(mat_transpose(mat_mult(a, ans)));
    return 0;
}
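/*
Sample input for a plain matrix multiplication A * B (input format: "m k n",
then the m x k matrix A, then the k x n matrix B).
NOTE: this is not the input format read by main() above.
3 3 2

1 2 -1
-1 3 4
1 1 1

5 6
-5 -6
6 0
*/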
/*
res (expected A * B for the multiplication sample above):
-11 -6
4 -24
6 0
*/

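/*
Sample input for main(): a 4 x 2 batch matrix (4 batches, 2 indicators each),
then the 1 x 2 target indicator values.
4 2
2.89 0.376
3.18 0.272
2.49 0.393
3.28 0.308

1 2
3.06 0.322
*/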

0x01 java版本

package Demo;


class Matrix {
	// Capacity of the backing array in each dimension.
	static int N = 25;

	// Logical size: only mat[0..rowCount-1][0..colCount-1] is meaningful.
	int rowCount;
	int colCount;
	double[][] mat = new double[N][N];

	// Empty 0 x 0 matrix with an all-zero backing array.
	Matrix() {
	}

	// rows x cols matrix whose entries are copied from the top-left block of m.
	Matrix(int rows, int cols, double[][] m) {
		rowCount = rows;
		colCount = cols;
		for (int r = 0; r < rows; r++) {
			for (int c = 0; c < cols; c++) {
				mat[r][c] = m[r][c];
			}
		}
	}

	// Copy constructor: a new matrix with its own array holding X's entries.
	Matrix(Matrix X) {
		this(X.rowCount, X.colCount, X.mat);
	}
}

public class BlendAlgorithm {
	Matrix a = new Matrix();
	Matrix b = new Matrix();
	Matrix X = new Matrix();

	// ans: proportion of each batch in the best blend; Y: blended result for ans
	Matrix ans = new Matrix();
	Matrix Y = new Matrix();

	static double INF = 1e19;
	static double EPS = 1e-6;
	// Batch shares are searched as multiples of 1/STEPS, i.e. whole percentage points.
	private static final int STEPS = 100;
	// Smallest deviation found so far by blend(); starts at INF so any finite deviation beats it.
	double diff = INF;

	// Matrix product a * b, an a.rowCount x b.colCount matrix.
	// Expects a.colCount == b.rowCount (not checked).
	Matrix matrixMultiply(Matrix a, Matrix b) {
		Matrix ans = new Matrix();
		for (int i = 0; i < a.rowCount; i++) {
			for (int j = 0; j < b.colCount; j++) {
				double tmp = 0;
				for (int k = 0; k < a.colCount; k++) {
					tmp += a.mat[i][k] * b.mat[k][j];
				}
				ans.mat[i][j] = tmp;
			}
		}
		ans.rowCount = a.rowCount;
		ans.colCount = b.colCount;
		return ans;
	}

	// Matrix transpose: returns the a.colCount x a.rowCount matrix with res[j][i] = a[i][j].
	Matrix matrixTranspose(Matrix a) {
		Matrix res = new Matrix();
		for (int i = 0; i < a.rowCount; i++) {
			for (int j = 0; j < a.colCount; j++) {
				res.mat[j][i] = a.mat[i][j];
			}
		}
		res.rowCount = a.colCount;
		res.colCount = a.rowCount;
		return res;
	}

	// Prints a separator, the dimensions, then the matrix one row per line.
	void printMatrix(Matrix res) {
		System.out.println("------------------------");
		System.out.printf("rowCount=%d colCount=%d\n", res.rowCount, res.colCount);
		for (int i = 0; i < res.rowCount; i++) {
			for (int j = 0; j < res.colCount; j++) {
				System.out.printf("%f ", res.mat[i][j]);
			}
			System.out.println();
		}
	}

	// Deviation of Y from the target b:
	//   sum over all (i, j) of |Y[i][j] - b[i][j]| / |b[i][j]|, divided by Y.rowCount.
	// Dividing by the signed target made errors against negative targets negative,
	// so they cancelled real errors. A zero target falls back to the absolute error
	// instead of producing Infinity/NaN. An empty Y has deviation 0.
	double calcDiff(Matrix Y) {
		if (Y.rowCount <= 0) {
			return 0;
		}
		double total = 0;
		for (int i = 0; i < Y.rowCount; i++) {
			for (int j = 0; j < Y.colCount; j++) {
				double target = b.mat[i][j];
				double err = Math.abs(Y.mat[i][j] - target);
				total += (target != 0) ? err / Math.abs(target) : err;
			}
		}
		return total / Y.rowCount;
	}

	// Blending search: tries every split of 100% over X.mat[cnt..n-1][0] in whole
	// percentage points, given that rows 0..cnt-1 already use the share `sum`. The
	// split minimising calcDiff(a * X) is stored in ans, and its deviation in diff.
	// Throws IllegalArgumentException if cnt is outside [0, n] or n exceeds Matrix.N.
	public void blend(Matrix X, int n, int cnt, double sum) {
		if (cnt < 0 || cnt > n || n > Matrix.N) {
			throw new IllegalArgumentException("blend: require 0 <= cnt <= n <= Matrix.N");
		}
		// Track the consumed share as an integer number of steps. Accumulating
		// j / 100.0 in floating point does not always land exactly on 1.0, so an
		// exact "sum != 1.0" test silently rejected valid 100% splits.
		long used = Math.round(sum * STEPS);
		if (used > STEPS) {
			return;
		}
		search(X, n, cnt, (int) (STEPS - used));
	}

	// Recursive worker for blend(). remaining = steps not yet assigned to rows
	// 0..cnt-1. Each level writes only X.mat[cnt][0], so the shared X stays valid
	// for the rows above.
	private void search(Matrix X, int n, int cnt, int remaining) {
		if (cnt == n) {
			// Feasible only when exactly 100% has been handed out.
			if (remaining != 0) {
				return;
			}
			Matrix res = matrixMultiply(a, X);
			double nowDiff = calcDiff(res);
			if (diff - nowDiff > EPS) {
				diff = nowDiff;
				// "ans = X" would only copy the reference; snapshot X with the copy constructor.
				ans = new Matrix(X);
			}
			return;
		}
		// A share never exceeds the remaining budget (nor 100%); the last batch must take
		// exactly what is left. Candidates are still visited in increasing order.
		int hi = Math.min(remaining, STEPS);
		int lo = (cnt == n - 1) ? remaining : 0;
		for (int j = lo; j <= hi; ++j) {
			X.mat[cnt][0] = j / (double) STEPS;
			search(X, n, cnt + 1, remaining - j);
		}
	}

	// Prints the best proportions, their deviation, and the blended result Y.
	void printAnswer() {
		System.out.println("*******************");
		System.out.println("TargetX:");
		printMatrix(ans);
		System.out.println("*******************");

		System.out.println("diff:\n------------------------");
		System.out.println(diff);

		System.out.println("*******************");
		System.out.println("Y:");
		Y = matrixTranspose(matrixMultiply(a, ans));
		printMatrix(Y);
	}

	// Demo: 4 batches with 2 indicators each, blended toward a 1 x 2 target.
	public static void main(String[] args) {
		BlendAlgorithm f = new BlendAlgorithm();

		// 4 x 2: one row per batch, one column per indicator
		double[][] a = new double[][]{
			{2.89, 0.376},
			{3.18, 0.272},
			{2.49, 0.393},
			{3.28, 0.308}
		};

		f.a = new Matrix(4, 2, a);
		// Transpose so that a * X (X: batches x 1) gives the blended indicators.
		f.a = f.matrixTranspose(f.a);

		// 1 x 2: target value for each indicator
		double[][] b = new double[][]{
			{3.06, 0.322}
		};
		f.b = new Matrix(1, 2, b);
		// Transpose to one row per indicator, matching a * X.
		f.b = f.matrixTranspose(f.b);

		f.X.rowCount = f.a.colCount;
		f.X.colCount = 1;
		f.blend(f.X, f.a.colCount, 0, 0);

		f.printAnswer();
	}
}
 
转载自blog.csdn.net/tb_youth/article/details/103912638