Java implementation of matrix multiplication and optimization methods

Traditional matrix multiplication implementation

  First of all, two matrices can be multiplied only if the number of columns of the first matrix equals the number of rows of the second matrix.

  The entry in row m, column n of the product matrix is the sum of the products of the corresponding elements of row m of the first matrix and column n of the second matrix.

Figure: the matrix multiplication process

  For example, with 0-based indices, when A has five columns and B has five rows:

C[1][1] = A[1][0] * B[0][1] + A[1][1] * B[1][1] + A[1][2] * B[2][1] + A[1][3] * B[3][1] + A[1][4] * B[4][1]

  The traditional Java implementation follows this rule directly, using a triple loop that accumulates the products:

public int[][] multiply(int[][] mat1, int[][] mat2){
	int m = mat1.length, n = mat2[0].length, p = mat1[0].length;
	int[][] mat = new int[m][n];
	for(int i = 0; i < m; i++){
		for(int j = 0; j < n; j++){
			for(int k = 0; k < p; k++){
				mat[i][j] += mat1[i][k] * mat2[k][j];
			}
		}
	}
	return mat;
}
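
As a quick check of the method above, the following sketch multiplies a 2×3 matrix by a 3×2 matrix. The class name `MultiplyDemo` and the explicit dimension check (which enforces the premise stated at the start of this section) are assumptions added for illustration; they are not part of the original code.

```java
import java.util.Arrays;

public class MultiplyDemo {

    // Same triple loop as above, plus a dimension check added for illustration.
    public static int[][] multiply(int[][] mat1, int[][] mat2) {
        if (mat1[0].length != mat2.length) {
            throw new IllegalArgumentException(
                "columns of mat1 (" + mat1[0].length
                + ") must equal rows of mat2 (" + mat2.length + ")");
        }
        int m = mat1.length, n = mat2[0].length, p = mat2.length;
        int[][] mat = new int[m][n];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                for (int k = 0; k < p; k++) {
                    mat[i][j] += mat1[i][k] * mat2[k][j];
                }
            }
        }
        return mat;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}};        // 2 x 3
        int[][] b = {{7, 8}, {9, 10}, {11, 12}};   // 3 x 2
        // prints [[58, 64], [139, 154]]
        System.out.println(Arrays.deepToString(multiply(a, b)));
    }
}
```

The result is 2×2: the row count comes from the first matrix and the column count from the second.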

  For two n×n matrices the time complexity of this method is O(n³), so the running time grows quickly with the matrix dimension and the program can easily time out on large inputs.
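
To see the cubic growth concretely, the small sketch below counts how many scalar multiplications the triple loop performs for two n×n matrices. The class name `CubicGrowth` and its method are hypothetical names used only for this illustration.

```java
public class CubicGrowth {

    // Counts the scalar multiplications the triple loop would perform
    // on two n x n matrices (one per innermost iteration).
    public static long countMultiplications(int n) {
        long count = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                for (int k = 0; k < n; k++) {
                    count++;
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        // Doubling n multiplies the work by 8.
        for (int n : new int[] {100, 200, 400}) {
            System.out.println(n + " -> " + countMultiplications(n));
        }
        // prints:
        // 100 -> 1000000
        // 200 -> 8000000
        // 400 -> 64000000
    }
}
```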

Optimization method (Strassen algorithm)

  The Strassen algorithm, published by Volker Strassen in 1969, was the first matrix multiplication algorithm with a time complexity below O(n³). Its main idea is divide and conquer: each n×n matrix is split into four n/2 × n/2 blocks, and the product is assembled from products and sums of those blocks.
Figure: splitting a matrix multiplication into a combination of several multiplications and additions
  Why is this method faster? With the traditional block-wise matrix multiplication we have:

C11 = A11 * B11 + A12 * B21
C12 = A11 * B12 + A12 * B22
C21 = A21 * B11 + A22 * B21
C22 = A21 * B12 + A22 * B22

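  This takes 8 block multiplications and 4 block additions, and the 8 multiplications dominate the running time. The Strassen method needs only 7 block multiplications, at the cost of 18 block additions and subtractions. Adding two n/2 × n/2 blocks costs only O(n²), which is an order of magnitude cheaper than multiplying them, so for large matrices 18 additions are still cheaper than the one multiplication that is saved. Applied recursively, the recurrence T(n) = 7T(n/2) + O(n²) gives a time complexity of O(n^log₂7) ≈ O(n^2.81).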
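
To make the count concrete, the sketch below applies the seven products to a plain 2×2 integer matrix, where every block is a single number, and compares the result with the direct eight-multiplication formula. The products are numbered P1 to P7 in the same way as in the implementation later in this article; the class and method names are assumptions made for illustration.

```java
import java.util.Arrays;

public class Strassen2x2 {

    // Direct formula: 8 multiplications, 4 additions.
    public static int[][] direct(int[][] a, int[][] b) {
        return new int[][] {
            {a[0][0] * b[0][0] + a[0][1] * b[1][0], a[0][0] * b[0][1] + a[0][1] * b[1][1]},
            {a[1][0] * b[0][0] + a[1][1] * b[1][0], a[1][0] * b[0][1] + a[1][1] * b[1][1]}
        };
    }

    // Strassen's scheme: 7 multiplications, 18 additions and subtractions.
    public static int[][] strassen(int[][] a, int[][] b) {
        int p1 = a[0][0] * (b[0][1] - b[1][1]);              // A11 * (B12 - B22)
        int p2 = (a[0][0] + a[0][1]) * b[1][1];              // (A11 + A12) * B22
        int p3 = (a[1][0] + a[1][1]) * b[0][0];              // (A21 + A22) * B11
        int p4 = a[1][1] * (b[1][0] - b[0][0]);              // A22 * (B21 - B11)
        int p5 = (a[0][0] + a[1][1]) * (b[0][0] + b[1][1]);  // (A11 + A22) * (B11 + B22)
        int p6 = (a[0][1] - a[1][1]) * (b[1][0] + b[1][1]);  // (A12 - A22) * (B21 + B22)
        int p7 = (a[0][0] - a[1][0]) * (b[0][0] + b[0][1]);  // (A11 - A21) * (B11 + B12)
        return new int[][] {
            {p5 + p4 - p2 + p6, p1 + p2},        // C11, C12
            {p3 + p4, p5 + p1 - p3 - p7}         // C21, C22
        };
    }

    public static void main(String[] args) {
        int[][] a = {{1, 3}, {5, 7}};
        int[][] b = {{8, 4}, {6, 2}};
        System.out.println(Arrays.deepToString(direct(a, b)));   // prints [[26, 10], [82, 34]]
        System.out.println(Arrays.deepToString(strassen(a, b))); // prints [[26, 10], [82, 34]]
    }
}
```

For single numbers the seven-product version is of course slower; the saving only appears when the entries are themselves large blocks and the scheme is applied recursively.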

  Of course, the Strassen algorithm is also much more complicated to implement than the traditional algorithm. Here is a Java implementation written by another author (original link: https://blog.csdn.net/wj310298/article/details/44857175 ). It works on square matrices whose dimension n is a power of 2:

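public class Matrix {
	// The four n/2 x n/2 quadrants: [0] top-left, [1] top-right,
	// [2] bottom-left, [3] bottom-right; null when n == 1.
	private final Matrix[] _matrixArray;
	// Dimension of this square matrix; expected to be a power of 2.
	private final int n;
	// The single value, used only when n == 1.
	private int element;

	// Creates an n x n matrix and allocates all quadrants recursively.
	public Matrix(int n) {
		this.n = n;
		if (n != 1) {
			this._matrixArray = new Matrix[4];
			for (int i = 0; i < 4; i++) {
				this._matrixArray[i] = new Matrix(n / 2);
			}
		} else {
			this._matrixArray = null;
		}
	}

	// Creates an n x n matrix whose quadrants are left unset, for results that
	// are filled in by the caller. The needInit flag is not used.
	private Matrix(int n, boolean needInit) {
		this.n = n;
		if (n != 1) {
			this._matrixArray = new Matrix[4];
		} else {
			this._matrixArray = null;
		}
	}

	// Stores a at row i, column j by descending into the matching quadrant.
	public void set(int i, int j, int a) {
		if (n == 1) {
			element = a;
		} else {
			int size = n / 2;
			this._matrixArray[(i / size) * 2 + (j / size)].set(i % size, j % size, a);
		}
	}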
	public Matrix multi(Matrix m) {
		Matrix result = null;
		if (n == 1) {
			result = new Matrix(1);
			result.set(0, 0, (element * m.element));
		} else {
			// Compute each of the seven products once and reuse it; calling
			// P1(m)..P5(m) again for every quadrant would repeat the recursion.
			Matrix p1 = P1(m), p2 = P2(m), p3 = P3(m), p4 = P4(m);
			Matrix p5 = P5(m), p6 = P6(m), p7 = P7(m);
			result = new Matrix(n, false);
			result._matrixArray[0] = p5.add(p4).minus(p2).add(p6);
			result._matrixArray[1] = p1.add(p2);
			result._matrixArray[2] = p3.add(p4);
			result._matrixArray[3] = p5.add(p1).minus(p3).minus(p7);
		}
		return result;
	}
	public Matrix add(Matrix m) {
		Matrix result = null;
		if (n == 1) {
			result = new Matrix(1);
			result.set(0, 0, (element + m.element));
		} else {
			result = new Matrix(n, false);
			result._matrixArray[0] = this._matrixArray[0].add(m._matrixArray[0]);
			result._matrixArray[1] = this._matrixArray[1].add(m._matrixArray[1]);
			result._matrixArray[2] = this._matrixArray[2].add(m._matrixArray[2]);
			result._matrixArray[3] = this._matrixArray[3].add(m._matrixArray[3]);
		}
		return result;
	}
	public Matrix minus(Matrix m) {
		Matrix result = null;
		if (n == 1) {
			result = new Matrix(1);
			result.set(0, 0, (element - m.element));
		} else {
			result = new Matrix(n, false);
			result._matrixArray[0] = this._matrixArray[0].minus(m._matrixArray[0]);
			result._matrixArray[1] = this._matrixArray[1].minus(m._matrixArray[1]);
			result._matrixArray[2] = this._matrixArray[2].minus(m._matrixArray[2]);
			result._matrixArray[3] = this._matrixArray[3].minus(m._matrixArray[3]);
		}
		return result;
	}
	// Each product below uses exactly one recursive multiplication.
	// P1 = A11 * (B12 - B22)
	protected Matrix P1(Matrix m) {
		return _matrixArray[0].multi(m._matrixArray[1].minus(m._matrixArray[3]));
	}
	// P2 = (A11 + A12) * B22
	protected Matrix P2(Matrix m) {
		return (_matrixArray[0].add(_matrixArray[1])).multi(m._matrixArray[3]);
	}
	// P3 = (A21 + A22) * B11
	protected Matrix P3(Matrix m) {
		return (_matrixArray[2].add(_matrixArray[3])).multi(m._matrixArray[0]);
	}
	// P4 = A22 * (B21 - B11)
	protected Matrix P4(Matrix m) {
		return _matrixArray[3].multi(m._matrixArray[2].minus(m._matrixArray[0]));
	}
	// P5 = (A11 + A22) * (B11 + B22)
	protected Matrix P5(Matrix m) {
		return (_matrixArray[0].add(_matrixArray[3])).multi(m._matrixArray[0].add(m._matrixArray[3]));
	}
	// P6 = (A12 - A22) * (B21 + B22)
	protected Matrix P6(Matrix m) {
		return (_matrixArray[1].minus(_matrixArray[3])).multi(m._matrixArray[2].add(m._matrixArray[3]));
	}
	// P7 = (A11 - A21) * (B11 + B12)
	protected Matrix P7(Matrix m) {
		return (_matrixArray[0].minus(_matrixArray[2])).multi(m._matrixArray[0].add(m._matrixArray[1]));
	}
	// Reads the value at row i, column j by descending into the matching quadrant.
	public int get(int i, int j) {
		if (n == 1) {
			return element;
		} else {
			int size = n / 2;
			return this._matrixArray[(i / size) * 2 + (j / size)].get(i % size, j % size);
		}
	}
	// Prints the matrix row by row, values separated by spaces.
	public void display() {
		for (int i = 0; i < n; i++) {
			for (int j = 0; j < n; j++) {
				System.out.print(get(i, j));
				System.out.print(" ");
			}
			System.out.println();
		}
	}

	public static void main(String[] args) {
		// m = [[1, 3], [5, 7]], n = [[8, 4], [6, 2]]
		Matrix m = new Matrix(2);
		Matrix n = new Matrix(2);
		m.set(0, 0, 1);
		m.set(0, 1, 3);
		m.set(1, 0, 5);
		m.set(1, 1, 7);
		n.set(0, 0, 8);
		n.set(0, 1, 4);
		n.set(1, 0, 6);
		n.set(1, 1, 2);
		Matrix res = m.multi(n);
		// Prints the rows "26 10" and "82 34".
		res.display();
	}
}
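
The class above halves the dimension at every level, so it only works when n is a power of 2. One common workaround, sketched here as an assumption rather than as part of the original code, is to pad the input with zeros up to the next power of 2, multiply, and then read back only the top-left part of the result. The names `PadDemo`, `nextPowerOfTwo`, and `pad` are hypothetical.

```java
import java.util.Arrays;

public class PadDemo {

    // Smallest power of 2 that is >= n, for 1 <= n <= 2^30.
    public static int nextPowerOfTwo(int n) {
        int p = 1;
        while (p < n) {
            p *= 2;
        }
        return p;
    }

    // Copies a square matrix into the top-left corner of a zero-filled
    // size x size matrix; the extra rows and columns stay 0.
    public static int[][] pad(int[][] src, int size) {
        int[][] out = new int[size][size];
        for (int i = 0; i < src.length; i++) {
            System.arraycopy(src[i], 0, out[i], 0, src[i].length);
        }
        return out;
    }

    public static void main(String[] args) {
        int[][] a = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
        int size = nextPowerOfTwo(a.length);   // 4
        // prints [[1, 2, 3, 0], [4, 5, 6, 0], [7, 8, 9, 0], [0, 0, 0, 0]]
        System.out.println(Arrays.deepToString(pad(a, size)));
    }
}
```

The zero rows and columns contribute nothing to any sum of products, so the top-left 3×3 part of the padded product equals the product of the original matrices. The padded values could then be loaded into a `Matrix` with `set(i, j, value)`.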

Origin blog.csdn.net/GGG_Yu/article/details/109693318