问题：为什么要把二维数组展平为一维数组？答：一维数组在内存中连续存储，按顺序访问时可以充分利用 CPU Cache（空间局部性）。
Matrix 基类持有一块连续的内存空间（一维数组 linear），分配之后最好立即初始化（例如值初始化为 0），避免读到未初始化的数据。
对于行优先（row-major）存储的矩阵，我们额外分配一个指针数组 data_，其中第 i 个指针指向第 i 行在 linear 中的首地址，这样 data_[i][j] 就直接作用于原来那块线性内存上。其余部分实现起来并不困难。
此外，RowMatrixOperations 使用智能指针 std::unique_ptr 来传递和返回矩阵对象，从而自动管理其生命周期。
#pragma once
#include <memory>
#include <utility>
namespace bustub {
/*
 * The base class defining a Matrix.
 *
 * Owns one contiguous buffer (`linear`) holding rows * cols elements that
 * derived classes index into. The buffer is value-initialized when it is
 * allocated and released in the destructor.
 */
template <typename T>
class Matrix {
 protected:
  /*
   * Allocate a value-initialized buffer of r * c elements.
   * `new T[n]()` zeroes scalar types and default-constructs class types;
   * memset would be undefined behavior for non-trivially-copyable T.
   */
  Matrix(int r, int c) : rows(r), cols(c), linear(new T[r * c]()) {}

  // # of rows in the matrix
  int rows;
  // # of Columns in the matrix
  int cols;
  // Flattened array containing the elements of the matrix.
  // Allocated in the constructor and freed in the destructor.
  T *linear;

 public:
  // `linear` is an owning raw pointer, so a member-wise copy would delete the
  // same buffer twice. Copying and moving are therefore disabled.
  Matrix(const Matrix &) = delete;
  Matrix &operator=(const Matrix &) = delete;
  Matrix(Matrix &&) = delete;
  Matrix &operator=(Matrix &&) = delete;

  // Return the # of rows in the matrix
  virtual int GetRows() = 0;
  // Return the # of columns in the matrix
  virtual int GetColumns() = 0;
  // Return the (i,j)th matrix element
  virtual T GetElem(int i, int j) = 0;
  // Sets the (i,j)th matrix element to val
  virtual void SetElem(int i, int j, T val) = 0;
  // Sets the matrix elements based on the array arr
  virtual void MatImport(T *arr) = 0;

  // Releases the flattened element buffer.
  virtual ~Matrix() { delete[] linear; }
};
/*
 * Row-major matrix built on top of Matrix's flat `linear` buffer.
 * data_[i] points at the first element of row i inside `linear`, so
 * data_[i][j] addresses linear[i * cols + j].
 */
template <typename T>
class RowMatrix : public Matrix<T> {
 public:
  /*
   * Create an r x c matrix. The base class allocates `linear`; this
   * constructor only builds the table of row pointers into it.
   */
  RowMatrix(int r, int c) : Matrix<T>(r, c), data_(std::make_unique<T *[]>(r)) {
    for (int i = 0; i < r; i++) {
      data_[i] = this->linear + i * c;
    }
  }

  // data_ aliases the base class's buffer, so a copy or move would leave two
  // objects sharing (and freeing) the same storage. Disable both.
  RowMatrix(const RowMatrix &) = delete;
  RowMatrix &operator=(const RowMatrix &) = delete;
  RowMatrix(RowMatrix &&) = delete;
  RowMatrix &operator=(RowMatrix &&) = delete;

  int GetRows() override { return this->rows; }
  int GetColumns() override { return this->cols; }

  // Requires 0 <= i < rows and 0 <= j < cols; indices are not bounds-checked.
  T GetElem(int i, int j) override { return data_[i][j]; }

  // Requires 0 <= i < rows and 0 <= j < cols; indices are not bounds-checked.
  void SetElem(int i, int j, T val) override { data_[i][j] = val; }

  // Copies rows * cols elements from arr, read in row-major order.
  // arr must point to at least that many elements.
  void MatImport(T *arr) override {
    const int count = GetRows() * GetColumns();
    for (int k = 0; k < count; k++) {
      this->linear[k] = arr[k];
    }
  }

  // The row-pointer table is released by unique_ptr; `linear` by the base.
  ~RowMatrix() override = default;

 private:
  // 2D view of the matrix in row-major format: one pointer per row, each
  // pointing into the base class's 'linear' array. Allocated in the
  // constructor and freed automatically by unique_ptr.
  std::unique_ptr<T *[]> data_;
};
/*
 * Stateless arithmetic on RowMatrix objects. Every operation takes ownership
 * of its inputs and returns a newly allocated result, or nullptr when an
 * input is null or the dimensions are incompatible.
 */
template <typename T>
class RowMatrixOperations {
 public:
  // Compute (mat1 + mat2) and return the result.
  // Return nullptr if either input is null or dimensions mismatch.
  static std::unique_ptr<RowMatrix<T>> AddMatrices(std::unique_ptr<RowMatrix<T>> mat1,
                                                   std::unique_ptr<RowMatrix<T>> mat2) {
    if (mat1 == nullptr || mat2 == nullptr) {
      return nullptr;
    }
    if (mat1->GetRows() != mat2->GetRows() || mat1->GetColumns() != mat2->GetColumns()) {
      return nullptr;
    }
    const int row = mat1->GetRows();
    const int col = mat1->GetColumns();
    auto result = std::make_unique<RowMatrix<T>>(row, col);
    for (int i = 0; i < row; i++) {
      for (int j = 0; j < col; j++) {
        result->SetElem(i, j, mat1->GetElem(i, j) + mat2->GetElem(i, j));
      }
    }
    return result;
  }

  // Compute matrix multiplication (mat1 * mat2) and return the result.
  // Return nullptr if either input is null or dimensions mismatch
  // (mat1's column count must equal mat2's row count).
  static std::unique_ptr<RowMatrix<T>> MultiplyMatrices(std::unique_ptr<RowMatrix<T>> mat1,
                                                        std::unique_ptr<RowMatrix<T>> mat2) {
    if (mat1 == nullptr || mat2 == nullptr) {
      return nullptr;
    }
    if (mat1->GetColumns() != mat2->GetRows()) {
      return nullptr;
    }
    const int row = mat1->GetRows();
    const int col = mat2->GetColumns();
    const int inner = mat1->GetColumns();
    auto result = std::make_unique<RowMatrix<T>>(row, col);
    for (int i = 0; i < row; i++) {
      for (int j = 0; j < col; j++) {
        // Accumulate locally instead of read-modify-writing the result, so
        // correctness does not depend on the result buffer being zeroed.
        T sum{};
        for (int k = 0; k < inner; k++) {
          sum = sum + mat1->GetElem(i, k) * mat2->GetElem(k, j);
        }
        result->SetElem(i, j, sum);
      }
    }
    return result;
  }

  // Simplified GEMM (general matrix multiply) operation
  // Compute (matA * matB + matC). Return nullptr if any input is null or
  // dimensions mismatch for input matrices.
  static std::unique_ptr<RowMatrix<T>> GemmMatrices(std::unique_ptr<RowMatrix<T>> matA,
                                                    std::unique_ptr<RowMatrix<T>> matB,
                                                    std::unique_ptr<RowMatrix<T>> matC) {
    // unique_ptr is move-only: ownership has to be handed over explicitly.
    // Passing the parameters by name would try to copy them and fail to compile.
    auto product = MultiplyMatrices(std::move(matA), std::move(matB));
    if (product == nullptr) {
      return nullptr;
    }
    return AddMatrices(std::move(product), std::move(matC));
  }
};
} // namespace bustub