CMU15-455プロジェクト#0-C ++ PRIMER(マトリックスクラスの詳細な説明を実装します)

質問:2次元配列を1次元配列に変換する必要があるのはなぜですか(1次元配列はCPUキャッシュを使用して継続的に格納されます)

 行列は連続メモリ空間(1次元配列)と同等です、初期化するのが最善です

行優先のストレージマトリックスでは、ポインタの配列を使用します。各ポインタは行の最初のアドレスを指し、元の線形メモリを操作します。残りは達成するのは難しくありません。

同時に、このタイプの実装では、スマートポインターUnique_Pointを使用してマトリックスを管理します。

#pragma once

#include <memory>

namespace bustub {

/*
 * The base class defining a Matrix
 */
template <typename T>
class Matrix {
 protected:
  Matrix(int r, int c):rows(r),cols(c) {
    linear = new T[rows*cols];
    memset(linear,0,sizeof(T)*rows*cols);
  }
  // # of rows in the matrix
  int rows;
  // # of Columns in the matrix
  int cols;
  // Flattened array containing the elements of the matrix
  // the array in the destructor.
  T *linear;

 public:
  // Return the # of rows in the matrix
  virtual int GetRows() = 0;

  // Return the # of columns in the matrix
  virtual int GetColumns() = 0;

  // Return the (i,j)th  matrix element
  virtual T GetElem(int i, int j) = 0;

  // Sets the (i,j)th  matrix element to val
  virtual void SetElem(int i, int j, T val) = 0;

  // Sets the matrix elements based on the array arr
  virtual void MatImport(T *arr) = 0;

  virtual ~Matrix(){
    delete [] linear;
  }
};

template <typename T>
class RowMatrix : public Matrix<T> {
 public:
  RowMatrix(int r, int c) : Matrix<T>(r, c) {
    data_ = new T*[r];
    for(int i=0;i<r;i++){
      data_[i] = this->linear + i*c;
    }
  }

  int GetRows() override { return this->rows; }

  int GetColumns() override { return this->cols; }

  T GetElem(int i, int j) override { return data_[i][j]; }

  void SetElem(int i, int j, T val) override {
    data_[i][j] = val;
  }

  void MatImport(T *arr) override {
    int row = GetRows();
    int col = GetColumns();
    for(int i=0;i<row*col;i++){
      this->linear[i] = arr[i];
    }
  }

  ~RowMatrix() override {
    delete [] data_;
  }

 private:
  // 2D array containing the elements of the matrix in row-major format
  // Allocate the array of row pointers in the constructor. Use these pointers
  // to point to corresponding elements of the 'linear' array.
  // Don't forget to free up the array in the destructor.
  T **data_;
};

template <typename T>
class RowMatrixOperations {
 public:
  // Compute (mat1 + mat2) and return the result.
  // Return nullptr if dimensions mismatch for input matrices.
  static std::unique_ptr<RowMatrix<T>> AddMatrices(std::unique_ptr<RowMatrix<T>> mat1,
                                                   std::unique_ptr<RowMatrix<T>> mat2) {
    if(mat1->GetRows()!=mat2->GetRows()||mat1->GetColumns()!=mat2->GetColumns()){
      return nullptr;
    }
    int row = mat1->GetRows();
    int col = mat1->GetColumns();
    std::unique_ptr<RowMatrix<T>> mat3 = std::make_unique<RowMatrix<T>>(row,col);
    for(int i=0;i<row;i++){
      for(int j=0;j<col;j++){
        mat3->SetElem(i,j,mat1->GetElem(i,j)+mat2->GetElem(i,j));
      }
    }
    return mat3;
  }

  // Compute matrix multiplication (mat1 * mat2) and return the result.
  // Return nullptr if dimensions mismatch for input matrices.
  static std::unique_ptr<RowMatrix<T>> MultiplyMatrices(std::unique_ptr<RowMatrix<T>> mat1,
                                                        std::unique_ptr<RowMatrix<T>> mat2) {
    if(mat1->GetColumns()!=mat2->GetRows()){
      return nullptr;
    }
    int row = mat1->GetRows();
    int col = mat2->GetColumns();
    int n = mat1->GetColumns();
    std::unique_ptr<RowMatrix<T>> mat3 = std::make_unique<RowMatrix<T>>(row,col);
    for(int i=0;i<row;i++){
      for(int j=0;j<col;j++){
        for(int k=0;k<n;k++){
          mat3->SetElem(i,j,mat3->GetElem(i,j)+mat1->GetElem(i,k)*mat2->GetElem(k,j));
        }
      }
    }
    return mat3;
  }

  // Simplified GEMM (general matrix multiply) operation
  // Compute (matA * matB + matC). Return nullptr if dimensions mismatch for input matrices
  static std::unique_ptr<RowMatrix<T>> GemmMatrices(std::unique_ptr<RowMatrix<T>> matA,
                                                    std::unique_ptr<RowMatrix<T>> matB,
                                                    std::unique_ptr<RowMatrix<T>> matC) {

    auto matTmp = MultiplyMatrices(matA,matB);
    if(matTmp== nullptr){
      return nullptr;
    }
    return AddMatrices(matTmp,matC);
  }
};
}  // namespace bustub

 

おすすめ

転載: blog.csdn.net/wwxy1995/article/details/113817834