C++ 方矩阵乘法 + Strassen矩阵

  这几天看算法导论,看到矩阵一章,就实现了一下。

下面是普通的矩阵乘法,复杂度为:n^3。

template<unsigned M,unsigned N, unsigned Q>
void Square_matrix_multiply(int(&A)[M][N], int(&B)[N][Q], int(&C)[M][Q]) {                 
	for (size_t i = 0;i != M;++i) {
		for (size_t j = 0;j != Q;++j) {
			C[i][j] = 0;
			for (size_t n = 0;n != N;++n) {
				C[i][j] += A[i][n] * B[n][j];
			}
		}
	}
}

函数接受三个二维数组,A * B得到的矩阵赋值给C。

下面是分治策略的算法。

template<typename T>
Matrix Square_max_matrix_multiply_recursive(const T &A, const T &B) {
	size_t n = A.rows();
	Matrix C(n, n);
	if (n == 1)
		return C = A.get()*B.get();
	else {
		MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);    // 使用一个类MatirxRef
		MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);    // 含有三个size_t类型。其中两个实现坐标,一个指明矩阵长度
		MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);    // 进行分割
		C_11 = Square_max_matrix_multiply_recursive(A_11, B_11) + Square_max_matrix_multiply_recursive(A_12, B_21);  // Matrix::operator+;
		C_12 = Square_max_matrix_multiply_recursive(A_11, B_12) + Square_max_matrix_multiply_recursive(A_12, B_22);  // MatrixRef::operator=;
		C_21 = Square_max_matrix_multiply_recursive(A_21, B_11) + Square_max_matrix_multiply_recursive(A_22, B_21);
		C_22 = Square_max_matrix_multiply_recursive(A_21, B_12) + Square_max_matrix_multiply_recursive(A_22, B_22);
	}
	return C;
}

矩阵实现了一个Matrix类(具体实现在最下面),有一个构造函数:接受两个size_t值l、r,生成l*r大小值全为0的矩阵。

Matrix::Matrix(size_t l, size_t r) : hight(l), width(r), data(make_shared<vector<int>>()) {
	data->resize(l*r);
}

其中hight为矩阵行高,width为列宽,data为shared_ptr,矩阵用vector实现。

A.rows()返回A的width长度(即方矩阵的边长),Matrix(n,n)创建一个矩阵。

size_t rows() const {
		return width;
	}

如果n==1,通过Matrix的get函数返回第一个元素,也就是唯一的一个元素。

int Matrix::get() const {
	return (*data)[0];
}

为了不复制矩阵元素(如果可以复制矩阵元素的话,会简单很多),另实现了一个MatrixRef,其含有:两个size_t数据成员(实现坐标点)、一个size_t数据成员(实现矩阵长度)、一个weak_ptr(指向vector<int>)。

MatrixRef含有两个构造函数:一个接受Matrix加两个size_t;一个接受MatrixRef加两个size_t。都是为了指明引用范围。

MatrixRef::MatrixRef(const Matrix &m, size_t line, size_t row) : wptr(m.data), 
      hight_startptr(line), width_startptr(row), length(m.rows() / 2) { }
MatrixRef::MatrixRef(const MatrixRef &mref, size_t line, size_t row) : wptr(mref.wptr),                 
      hight_startptr(mref.hight_startptr + line),
      width_startptr(mref.width_startptr + row), length(mref.rows() / 2) { }

wptr用data或wptr初始化,避免拷贝。length为rows()的返回值除以2,因为是分割为4个矩阵,行列各除以2。

要注意:接受MatrixRef的坐标要加上之前的坐标。

MatrixRef也有一个rows成员函数,为了递归调用。

size_t rows() const {
		return length;
	}

Square_max_matrix_multiply_recursive函数返回一个Matrix,Matrix实现了operator+,但是行列必须相等。

Matrix& Matrix::operator+=(const Matrix &rhs) {
	if (hight == rhs.hight && width == rhs.width) {
		for (size_t i = 0;i != size();++i)
			(*data)[i] += (*rhs.data)[i];
	}
	else
		throw std::logic_error("Not Matched");
	return *this;
}
Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
	Matrix m(lhs);
	return m += rhs;
}

MatrixRef实现了一个operator=。

MatrixRef& MatrixRef::operator=(const Matrix &rhs) {   
	for (size_t i = 0;i != length;++i) {
		for (size_t j = 0;j != length;++j) {
			(*wptr.lock())[(i + hight_startptr)*length * 2 + j + width_startptr] = 
               rhs.get(i + 1, j + 1);  //注意:length*2  因为C也被分割了
		}
	}
	return *this;
}

其中(i + hight_startptr)*length * 2 + j + width_startptr)为当前下标(vector对应矩阵的下标,非矩阵行列)。此函数将分割的C进行“拼合”。注意:length * 2  ,因为C也被分割了,不乘以2为C_11及C_12的长度,乘以2才是C的行列长宽,才能给C的给定位置赋值。

下面是Strassen矩阵算法。

template<typename T, typename N>
Matrix Strassen_matrix_fit(const T &A, const N &B) { // 为2的幂的情况下
	size_t n = A.rows();
	Matrix C(n, n);
	if (n == 1) {
		return C = A.get()*B.get();
	}
	else {
		MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);        // 使用一个类MatirxRef
		MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);        // 含有三个size_t类型。其中两个实现坐标,一个指明矩阵长度
		MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);        // 进行分割
		Matrix S1 = B_12 - B_22, S2 = A_11 + A_12, S3 = A_21 + A_22, S4 = B_21 - B_11, S5 = A_11 + A_22,     //MatrixRef的加、减
			S6 = B_11 + B_22, S7 = A_12 - A_22, S8 = B_21 + B_22, S9 = A_11 - A_21, S10 = B_11 + B_12;
		Matrix P1 = Strassen_matrix_fit(A_11, S1), P2 = Strassen_matrix_fit(S2, B_22),
			P3 = Strassen_matrix_fit(S3, B_11), P4 = Strassen_matrix_fit(A_22, S4),
			P5 = Strassen_matrix_fit(S5, S6), P6 = Strassen_matrix_fit(S7, S8), P7 = Strassen_matrix_fit(S9, S10);
		C_11 = P5 + P4 - P2 + P6;
		C_12 = P1 + P2;
		C_21 = P3 + P4;
		C_22 = P5 + P1 - P3 - P7;
	}
	return C;
}

此算法较之前多了一个MatrixRef::operator-、以及MatrixRef::operator+。

Matrix& Matrix::operator-() {
	for (auto &f : *data)
		f = -f;
	return *this;
}
Matrix operator-(const MatrixRef &lhs, const MatrixRef &rhs) {
	Matrix ml(lhs), mr(rhs);
	return ml = -mr + ml;
}
Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
	Matrix m(lhs);
	return m += rhs;
}

operator-用Matrix的取负、以及Matrix的加法,同时最重要的还有Matrix(const MatrixRef &)。MatrixRef将此对象引用范围内的子矩阵创建一个局部Matrix对象。

operator+用Matrix的加法与Matrix(const MatrixRef &)。

Matrix::Matrix(const MatrixRef &rhs) : hight(rhs.length), width(rhs.length), data(make_shared<vector<int>>()) {
	size_t max_size = static_cast<size_t>(sqrt(rhs.wptr.lock()->size()));            // 未分解的原式中的矩阵长度
	auto ivec = *rhs.wptr.lock();
	for (size_t i = 0; i != hight; ++i) {
		for (size_t j = 0; j != width; ++j) {
			data->push_back(ivec[(i + rhs.hight_startptr)*max_size + j + rhs.width_startptr]);
		}
	}
}

其中max_size为wptr所指的vector<int>的size,进行根号得到。max_size就是MatrixRef对象未分解(即未分割的C)的矩阵边长。static_cast把sqrt返回的double转未size_t,因为是方矩阵,所以不会损失精度。

(i + rhs.hight_startptr)*max_size + j + rhs.width_startptr)为MatrixRef对象引用范围内对应vector的下标。此必需乘以max_size。

下面为不是2的幂的情况。

template<typename T, typename N>
Matrix Strassen_matrix(const T &A, const N &B) {
	size_t n = A.rows();
	double size = log(n) / log(2);
	size_t l_size = static_cast<size_t>(size);
	if (l_size != size) {
		size_t t_size = (l_size + 1)*(l_size + 1);
		Matrix a(t_size, t_size), b(t_size, t_size);
		a = A;
		b = B;
		Matrix C = Strassen_matrix_fit(a, b);
		Matrix c(n, n);
		c = C;
		return c;
	}
	else
		return Strassen_matrix_fit(A, B);
}

size与l_size比较可知是否为2的幂,如果是,执行else,不是,则执行if。

当不是2的幂时,我的思路是把它加0,拼成2的幂的形式。如下图。

1 2 3                1 2 3 0            5 6 7
2 3 2     --->       2 3 2 0   --->     4 5 2 
3 2 1                3 2 1 0            3 5 6
                     0 0 0 0

然后得出结果时再切去周围的零,其值是不变的。假如为n*n的矩阵,复杂度(n + k) ^ lg7。n + k 为最接近n的2的幂,其中0<k<n。

(n + k) ^ lg7 < (2n) ^ lg7 = 7 * n ^ lg7。

复杂度还是n ^ lg7。

加零还是切去零,我是通过赋值来实现的。

Matrix& Matrix::operator=(const Matrix &rhs) {
	if (hight == rhs.hight) {								//  rhs          this
		for (size_t i = 0; i != size(); ++i) {					        //	1 2 3		 1 2 3
			(*data)[i] = (*rhs.data)[i];						//	2 3 2   ->	 2 3 2
		}								        	//	3 2 1		 3 2 1
	}																		 
	else if (hight > rhs.hight) {							 	//	1 2 3		 1 2 3 0
		for (size_t i = 0;i != hight; ++i) {				        	//	2 3 2   ->	 2 3 2 0
			for (size_t j = 0, n = 1;j != width; ++j) {			        //	3 2 1		 3 2 1 0
				if (j >= rhs.width || i >= rhs.hight)			//				 0 0 0 0
					(*data)[i * width + j] = 0;
				else																
					(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j];	
			}
		}
	}
	else {											     //	1 2 3 4	      1 2 3 
			for (size_t i = 0;i != hight; ++i) {					     // 2 3 4 3   ->  2 3 4 
				for (size_t j = 0;j != width; ++j) {				     // 3 4 3 2	      3 4 3 
					(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j];     //	4 3 2 1		 
				}
			}
	}
	return *this;
}

有三种复制方式:当左矩阵边长与右矩阵边长相等,第一个,正常赋值。左>右时,左上角对其,剩余的赋0。左<右时,左上角对其,多余的切掉。

补充:

在我的电脑上,普通(n^3)的算法与Strassen算法在1500*1500左右的时候时间是差不多的,但是耗时达到30秒,之后Strassen算法会出现明显的优势。在小于100*100的矩阵乘法时普通算法耗时小于0.01秒,而Strassen可达到3秒,普通算法有绝对的优势。

在我的电脑上,把二维数组扩展到300*300以上时会有栈溢出,这时可以上网搜索找到相应的解决办法。

END

设置MatirxRef类不知道好不好,毕竟矩阵拷贝也不影响复杂度。

肯定有很多值得改进的地方,也有不对的地方,可以评论提醒一下。

附:

Matrix头文件。

#ifndef MATRIX_H
#define MATRIX_H
#include<iostream>
#include<memory>
#include<vector>
class MatrixRef;
class Matrix {
	friend Matrix operator+(const Matrix &, const Matrix &);
	friend Matrix operator-(const Matrix &, const Matrix &);
	friend std::ostream& operator<<(std::ostream&, const Matrix &);
	friend Matrix operator+(const MatrixRef &, const MatrixRef &);
	friend Matrix operator-(const MatrixRef &, const MatrixRef &);
	friend class MatrixRef;
public:
	Matrix();
	template<unsigned M, unsigned N>
	Matrix(int(&A)[M][N]) : hight(M), width(N), data(make_shared<vector<int>>()) {
		data->reserve(M*N);
		for (size_t i = 0;i != M;++i)
			for (size_t j = 0;j != N;++j)
				data->push_back(A[i][j]);
	}
	Matrix(size_t l, size_t r); // 创建一个行l、列r的零矩阵
	Matrix(const Matrix &rhs); // 深层次拷贝构造
	explicit Matrix(const MatrixRef &);  // 将MatrixRef转换为Matrix 
	int& get(size_t l, size_t r); // 取得行L、列R的值
	const int& get(size_t l, size_t r) const;
	int get() const; // 得到第一个值
	Matrix& operator=(const Matrix &rhs); // 深层次拷贝赋值
	Matrix& operator+=(const Matrix &rhs);
	Matrix& operator=(int i); // 把一个为i的值赋给行为1、列为1的矩阵
	Matrix& operator-(); // 对矩阵取负
	size_t rows() const {
		return width;
	}
	size_t size() const {
		return hight * width;
	}
private:
	void check_situation(size_t l, size_t r) const {
		if (l > hight || r > width)
			throw std::range_error("Invalid range");
	}
	size_t hight = 1;
	size_t width = 1;
	std::shared_ptr<std::vector<int>> data;
};

class MatrixRef {
	friend Matrix operator-(const MatrixRef &, const MatrixRef &);
	friend Matrix operator+(const MatrixRef &, const MatrixRef &);
	friend class Matrix;
public:
	MatrixRef(const Matrix &m, size_t line, size_t row);
	MatrixRef(const MatrixRef &mref, size_t line, size_t row);
	MatrixRef& operator=(const Matrix &rhs); // 对C_11、C_12、C_13、C_14进行赋值拼接的函数
	int& get() const; // 得到第一个值
	size_t rows() const {
		return length;
	}
private:
	std::weak_ptr<std::vector<int>> wptr;
	size_t hight_startptr;
	size_t width_startptr;
	size_t length;
};
Matrix operator+(const Matrix &, const Matrix &);
Matrix operator-(const Matrix &, const Matrix &);
Matrix operator+(const MatrixRef &, const MatrixRef &);
Matrix operator-(const MatrixRef &, const MatrixRef &);
std::ostream& operator<<(std::ostream&, const Matrix &);

template<typename T>
Matrix Square_max_matrix_multiply_recursive(const T &A, const T &B) {
	size_t n = A.rows();
	Matrix C(n, n);
	if (n == 1)
		return C = A.get()*B.get();
	else {
		MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);// 使用一个类MatirxRef
		MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);// 含有三个size_t类型。其中两个实现坐标,一个指明矩阵长度
		MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);// 进行分割
		C_11 = Square_max_matrix_multiply_recursive(A_11, B_11) + Square_max_matrix_multiply_recursive(A_12, B_21);// Matrix::operator+;
		C_12 = Square_max_matrix_multiply_recursive(A_11, B_12) + Square_max_matrix_multiply_recursive(A_12, B_22);// MatrixRef::operator=;
		C_21 = Square_max_matrix_multiply_recursive(A_21, B_11) + Square_max_matrix_multiply_recursive(A_22, B_21);
		C_22 = Square_max_matrix_multiply_recursive(A_21, B_12) + Square_max_matrix_multiply_recursive(A_22, B_22);
	}
	return C;
}
template<typename T, typename N>
Matrix Strassen_matrix_fit(const T &A, const N &B) { // 为2的幂的情况下
	size_t n = A.rows();
	Matrix C(n, n);
	if (n == 1) {
		return C = A.get()*B.get();
	}
	else {
		MatrixRef A_11(A, 0, 0), A_12(A, 0, n / 2), A_21(A, n / 2, 0), A_22(A, n / 2, n / 2);// 使用一个类MatirxRef
		MatrixRef B_11(B, 0, 0), B_12(B, 0, n / 2), B_21(B, n / 2, 0), B_22(B, n / 2, n / 2);// 含有三个size_t类型。其中两个实现坐标,一个指明矩阵长度
		MatrixRef C_11(C, 0, 0), C_12(C, 0, n / 2), C_21(C, n / 2, 0), C_22(C, n / 2, n / 2);// 进行分割
		Matrix S1 = B_12 - B_22, S2 = A_11 + A_12, S3 = A_21 + A_22, S4 = B_21 - B_11, S5 = A_11 + A_22, //MatrixRef的加、减
			S6 = B_11 + B_22, S7 = A_12 - A_22, S8 = B_21 + B_22, S9 = A_11 - A_21, S10 = B_11 + B_12;
		Matrix P1 = Strassen_matrix_fit(A_11, S1), P2 = Strassen_matrix_fit(S2, B_22),
			P3 = Strassen_matrix_fit(S3, B_11), P4 = Strassen_matrix_fit(A_22, S4),
			P5 = Strassen_matrix_fit(S5, S6), P6 = Strassen_matrix_fit(S7, S8), P7 = Strassen_matrix_fit(S9, S10);
		C_11 = P5 + P4 - P2 + P6;
		C_12 = P1 + P2;
		C_21 = P3 + P4;
		C_22 = P5 + P1 - P3 - P7;
	}
	return C;
}
template<typename T, typename N>
Matrix Strassen_matrix(const T &A, const N &B) {
	size_t n = A.rows();
	double size = log(n) / log(2);
	size_t l_size = static_cast<size_t>(size);
	if (l_size != size) {
		size_t t_size = (l_size + 1)*(l_size + 1);
		Matrix a(t_size, t_size), b(t_size, t_size);
		a = A;
		b = B;
		Matrix C = Strassen_matrix_fit(a, b);
		Matrix c(n, n);
		c = C;
		return c;
	}
	else
		return Strassen_matrix_fit(A, B);
}


#endif

头文件实现。

#include"Matrix1.h"
#include<math.h>
using namespace std;

Matrix::Matrix() :data(make_shared<vector<int>>()) {
	data->push_back(0);
}
Matrix::Matrix(size_t l, size_t r) : hight(l), width(r), data(make_shared<vector<int>>()) {
	data->resize(l*r);
}
Matrix::Matrix(const Matrix &rhs) : hight(rhs.hight), width(rhs.width), data(make_shared<vector<int>>()) {
		for (size_t i = 0;i != size(); ++i)                        
			data->push_back((*rhs.data)[i]);														
}
Matrix::Matrix(const MatrixRef &rhs) : hight(rhs.length), width(rhs.length), data(make_shared<vector<int>>()) {
	size_t max_size = static_cast<size_t>(sqrt(rhs.wptr.lock()->size())); // 未分解的原式中的矩阵长度
	auto ivec = *rhs.wptr.lock();
	for (size_t i = 0; i != hight; ++i) {
		for (size_t j = 0; j != width; ++j) {
			data->push_back(ivec[(i + rhs.hight_startptr)*max_size + j + rhs.width_startptr]);
		}
	}
}
int& Matrix::get(size_t l, size_t r) {
	check_situation(l, r);
	return (*data)[--l * width + --r];
}
const int& Matrix::get(size_t l, size_t r) const {
	check_situation(l, r);
	return (*data)[--l * width + --r];
}
int Matrix::get() const {
	return (*data)[0];
}
Matrix& Matrix::operator=(const Matrix &rhs) {
	if (hight == rhs.hight) {												//  rhs          this
		for (size_t i = 0; i != size(); ++i) {								//	1 2 3		 1 2 3
			(*data)[i] = (*rhs.data)[i];									//	2 3 2   ->	 2 3 2
		}																	//	3 2 1		 3 2 1
	}																		 
	else if (hight > rhs.hight) {							 				//	1 2 3		 1 2 3 0
		for (size_t i = 0;i != hight; ++i) {								//	2 3 2   ->	 2 3 2 0
			for (size_t j = 0, n = 1;j != width; ++j) {						//	3 2 1		 3 2 1 0
				if (j >= rhs.width || i >= rhs.hight)						//				 0 0 0 0
					(*data)[i * width + j] = 0;
				else																
					(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j];	
			}
		}
	}
	else {																	 //	1 2 3 4		  1 2 3 
			for (size_t i = 0;i != hight; ++i) {							 //	2 3 4 3   ->  2 3 4 
				for (size_t j = 0;j != width; ++j) {				    	 //	3 4 3 2	      3 4 3 
					(*data)[i * width + j] = (*rhs.data)[i * rhs.width + j]; //	4 3	2 1		 
				}
			}
	}
	return *this;
}
Matrix& Matrix::operator+=(const Matrix &rhs) {
	if (hight == rhs.hight && width == rhs.width) {
		for (size_t i = 0;i != size();++i)
			(*data)[i] += (*rhs.data)[i];
	}
	else
		throw std::logic_error("Not Matched");
	return *this;
}
Matrix& Matrix::operator=(int i) {
	if (hight == width && hight == 1)
		(*data)[0] = i;
	return *this;
}
Matrix& Matrix::operator-() {
	for (auto &f : *data)
		f = -f;
	return *this;
}

Matrix operator+(const Matrix &lhs, const Matrix &rhs) {
	Matrix m(lhs);
	return m += rhs;
}
Matrix operator-(const Matrix &lhs,const Matrix &rhs) {
	Matrix m(rhs);
	return m = -m + lhs;
}

MatrixRef::MatrixRef(const Matrix &m, size_t line, size_t row) : wptr(m.data), hight_startptr(line), width_startptr(row), length(m.rows() / 2) { }
MatrixRef::MatrixRef(const MatrixRef &mref, size_t line, size_t row) : wptr(mref.wptr), hight_startptr(mref.hight_startptr + line),
width_startptr(mref.width_startptr + row), length(mref.rows() / 2) { }
MatrixRef& MatrixRef::operator=(const Matrix &rhs) {   
	for (size_t i = 0;i != length;++i) {
		for (size_t j = 0;j != length;++j) {
			(*wptr.lock())[(i + hight_startptr)*length * 2 + j + width_startptr] = rhs.get(i + 1, j + 1);  //注意:length*2  因为C也被分割了
		}
	}
	return *this;
}
int& MatrixRef::get() const {
	return (*wptr.lock())[static_cast<size_t>(hight_startptr*sqrt(wptr.lock()->size())) + width_startptr];
}
Matrix operator+(const MatrixRef &lhs, const MatrixRef &rhs) {
	Matrix ml(lhs), mr(rhs);
	return ml += mr;
}
Matrix operator-(const MatrixRef &lhs, const MatrixRef &rhs) {
	Matrix ml(lhs), mr(rhs);
	return ml = -mr + ml;
}
ostream& operator<<(ostream &os, const Matrix &m) {
	int i = 0;
	for (auto f : *m.data) {
		cout << f;
		if (++i == m.width) {
			std::cout << '\n';
			i = 0;
		}
		else
			cout << ' ';
	}
	return os;
}

END

猜你喜欢

转载自blog.csdn.net/qq_35959271/article/details/83511693