CMU 15-445 PROJECT #0 - C++ PRIMER（实现一个矩阵类详解）

问题：为什么要把二维数组展平为一维数组？（因为一维数组在内存中连续存储，能更好地利用 CPU Cache。）

Matrix 基类持有一块连续的内存空间（一维数组 `linear`），分配之后最好立即初始化（例如全部置为零值）。

对于行优先（row-major）存储的矩阵，我们再分配一个指针数组，其中第 i 个指针指向第 i 行在一维数组中的首地址，这样就可以用 `data_[i][j]` 直接在原来线性的内存上读写元素。剩下的部分就不难实现了。

同时，该实现使用智能指针 `std::unique_ptr` 来管理矩阵对象的所有权。

#pragma once

#include <memory>
#include <utility>

namespace bustub {

/*
 * The base class defining a Matrix
 */
template <typename T>
class Matrix {
 protected:
  /**
   * Allocate a flat buffer of r * c elements, each value-initialized
   * (zero for arithmetic types, default-constructed for class types).
   * Value-initialization replaces memset, which is undefined behavior for
   * non-trivially-copyable element types such as std::string.
   * @param r number of rows
   * @param c number of columns
   */
  Matrix(int r, int c) : rows(r), cols(c), linear(new T[r * c]()) {}

  // # of rows in the matrix
  int rows;
  // # of Columns in the matrix
  int cols;
  // Flattened array containing the elements of the matrix.
  // Allocated in the constructor and freed in the destructor; owned by this object.
  T *linear;

 public:
  // This object owns `linear`; a member-wise copy would share the buffer and
  // delete[] it twice, so copying is disabled.
  Matrix(const Matrix &) = delete;
  Matrix &operator=(const Matrix &) = delete;

  // Return the # of rows in the matrix
  virtual int GetRows() = 0;

  // Return the # of columns in the matrix
  virtual int GetColumns() = 0;

  // Return the (i,j)th  matrix element
  virtual T GetElem(int i, int j) = 0;

  // Sets the (i,j)th  matrix element to val
  virtual void SetElem(int i, int j, T val) = 0;

  // Sets the matrix elements based on the array arr
  virtual void MatImport(T *arr) = 0;

  // Releases the element buffer allocated in the constructor.
  virtual ~Matrix() { delete[] linear; }
};

template <typename T>
class RowMatrix : public Matrix<T> {
 public:
  /**
   * Construct an r x c matrix stored in row-major order.
   *
   * The elements live in the base class's flat `linear` buffer; `data_` is an
   * array of r row pointers into that buffer, so element (i, j) is data_[i][j].
   * @param r number of rows
   * @param c number of columns
   */
  RowMatrix(int r, int c) : Matrix<T>(r, c), data_(new T *[r]) {
    for (int i = 0; i < r; ++i) {
      // Row i begins i * c elements into the contiguous buffer.
      data_[i] = this->linear + i * c;
    }
  }

  // data_ points into this object's own buffer; a member-wise copy would leave
  // two objects aliasing (and later double-freeing) the same storage.
  RowMatrix(const RowMatrix &) = delete;
  RowMatrix &operator=(const RowMatrix &) = delete;

  /** @return the number of rows */
  int GetRows() override { return this->rows; }

  /** @return the number of columns */
  int GetColumns() override { return this->cols; }

  /**
   * @return the element at row i, column j.
   * Indices are not bounds-checked; callers must keep 0 <= i < rows and 0 <= j < cols.
   */
  T GetElem(int i, int j) override { return data_[i][j]; }

  /**
   * Set the element at row i, column j to val.
   * Indices are not bounds-checked; callers must keep 0 <= i < rows and 0 <= j < cols.
   */
  void SetElem(int i, int j, T val) override { data_[i][j] = val; }

  /**
   * Overwrite every element from arr, read in row-major order.
   * @param arr must point to at least rows * cols elements (not checked)
   */
  void MatImport(T *arr) override {
    const int count = this->rows * this->cols;
    for (int k = 0; k < count; ++k) {
      this->linear[k] = arr[k];
    }
  }

  // The row-pointer array is released by data_; the element buffer by ~Matrix().
  ~RowMatrix() override = default;

 private:
  // 2D view of the elements in row-major format: data_[i] == linear + i * cols.
  // Owns only the array of row pointers, not the elements they point to.
  std::unique_ptr<T *[]> data_;
};

template <typename T>
class RowMatrixOperations {
 public:
  /**
   * Compute (mat1 + mat2) element-wise.
   * Both inputs are taken by value, so ownership moves in and they are
   * destroyed when this function returns.
   * @return a new matrix, or nullptr if either input is null or the dimensions differ
   */
  static std::unique_ptr<RowMatrix<T>> AddMatrices(std::unique_ptr<RowMatrix<T>> mat1,
                                                   std::unique_ptr<RowMatrix<T>> mat2) {
    if (mat1 == nullptr || mat2 == nullptr) {
      return nullptr;
    }
    if (mat1->GetRows() != mat2->GetRows() || mat1->GetColumns() != mat2->GetColumns()) {
      return nullptr;
    }
    const int rows = mat1->GetRows();
    const int cols = mat1->GetColumns();
    auto result = std::make_unique<RowMatrix<T>>(rows, cols);
    for (int i = 0; i < rows; ++i) {
      for (int j = 0; j < cols; ++j) {
        result->SetElem(i, j, mat1->GetElem(i, j) + mat2->GetElem(i, j));
      }
    }
    return result;
  }

  /**
   * Compute the matrix product (mat1 * mat2).
   * Both inputs are taken by value (ownership moves in).
   * @return a new mat1.rows x mat2.cols matrix, or nullptr if either input is
   *         null or mat1's column count differs from mat2's row count
   */
  static std::unique_ptr<RowMatrix<T>> MultiplyMatrices(std::unique_ptr<RowMatrix<T>> mat1,
                                                        std::unique_ptr<RowMatrix<T>> mat2) {
    if (mat1 == nullptr || mat2 == nullptr) {
      return nullptr;
    }
    if (mat1->GetColumns() != mat2->GetRows()) {
      return nullptr;
    }
    const int rows = mat1->GetRows();
    const int cols = mat2->GetColumns();
    const int inner = mat1->GetColumns();
    auto result = std::make_unique<RowMatrix<T>>(rows, cols);
    for (int i = 0; i < rows; ++i) {
      for (int j = 0; j < cols; ++j) {
        // Accumulate locally from a value-initialized T instead of relying on the
        // output buffer's initial contents and a Get/Set round-trip per term.
        T sum{};
        for (int k = 0; k < inner; ++k) {
          sum = sum + mat1->GetElem(i, k) * mat2->GetElem(k, j);
        }
        result->SetElem(i, j, sum);
      }
    }
    return result;
  }

  /**
   * Simplified GEMM (general matrix multiply): compute (matA * matB + matC).
   * All inputs are taken by value (ownership moves in).
   * @return a new matrix, or nullptr if any input is null or any dimensions mismatch
   */
  static std::unique_ptr<RowMatrix<T>> GemmMatrices(std::unique_ptr<RowMatrix<T>> matA,
                                                    std::unique_ptr<RowMatrix<T>> matB,
                                                    std::unique_ptr<RowMatrix<T>> matC) {
    // unique_ptr is move-only: the inputs must be moved into the by-value
    // parameters, otherwise instantiating this template fails to compile.
    auto product = MultiplyMatrices(std::move(matA), std::move(matB));
    if (product == nullptr) {
      return nullptr;
    }
    return AddMatrices(std::move(product), std::move(matC));
  }
};
}  // namespace bustub

猜你喜欢

转载自blog.csdn.net/wwxy1995/article/details/113817834
今日推荐