使用C++模板实现矩阵的运算

//matrix.h
#ifndef MATRIX_H_INCLUDED
#define MATRIX_H_INCLUDED
#include <iostream>
#include <stdexcept>
using namespace std;

/**
 * Dense row-major matrix of T backed by a single heap buffer.
 *
 * Element (r, c) is stored at pdata[r*cols + c]. The class owns its buffer and
 * therefore follows the Rule of Five: copies are deep (each Matrix frees only
 * its own buffer), and moves transfer the buffer, leaving the source as an
 * empty 0x0 matrix with a null buffer.
 */
template <typename T>
class Matrix
{
private:
    int cols, rows;   // dimensions
    T *pdata;         // owned buffer of rows*cols elements (nullptr only after a move)
public:
    void AddVal(T a);
    bool AddMatrix(Matrix<T> m2);
    Matrix<T> MultiMatrix(Matrix<T> m2);
    Matrix<T> operator * (Matrix<T> &m2);

    /// 1x1 matrix whose single cell is value-initialised (0 for arithmetic T).
    Matrix()
    {
        cols=1, rows=1;
        pdata=new T [1]();
    }

    /// 1x1 matrix holding v.
    /// NOTE(review): single-argument and non-explicit, so T converts implicitly
    /// to Matrix<T>; kept that way for compatibility with existing callers.
    Matrix(T v)
    {
        cols=1, rows=1;
        pdata=new T [1];
        *pdata=v;
    }

    /**
     * rs x cols matrix filled row-major from Arr.
     *
     * @param rs      number of rows
     * @param cols    number of columns
     * @param Arr     initial values in row-major order (may be null when arrsize <= 0)
     * @param arrsize number of values available in Arr; cells beyond it are zero,
     *                and values beyond rs*cols are ignored
     */
    Matrix(int rs, int cols, T Arr[], int arrsize)
    {
        this->cols=cols;
        rows=rs;
        const int n=rows*cols;
        pdata=new T [n]();   // value-initialise so cells not supplied by Arr are zero
        // Copy at most n values so an oversized Arr cannot overflow the buffer.
        const int count=(arrsize<n)?arrsize:n;
        for(int i=0;i<count;i++)
        {
            pdata[i]=Arr[i];
        }
    }

    /// Deep copy: allocates a fresh buffer and copies every element.
    Matrix(const Matrix<T> &other)
        : cols(other.cols), rows(other.rows), pdata(new T [other.rows*other.cols])
    {
        const int n=rows*cols;
        for(int i=0;i<n;i++)
        {
            pdata[i]=other.pdata[i];
        }
    }

    /// Deep copy assignment; self-assignment is a no-op. The new buffer is
    /// built before the old one is released, so a failed allocation leaves
    /// *this unchanged.
    Matrix<T> & operator = (const Matrix<T> &other)
    {
        if(this!=&other)
        {
            const int n=other.rows*other.cols;
            T *fresh=new T [n];
            for(int i=0;i<n;i++)
            {
                fresh[i]=other.pdata[i];
            }
            delete []pdata;
            pdata=fresh;
            rows=other.rows;
            cols=other.cols;
        }
        return *this;
    }

    /// Move constructor: steals other's buffer; other becomes an empty 0x0 matrix.
    Matrix(Matrix<T> &&other) noexcept
        : cols(other.cols), rows(other.rows), pdata(other.pdata)
    {
        other.pdata=nullptr;
        other.rows=0;
        other.cols=0;
    }

    /// Move assignment: releases this buffer and steals other's; other becomes 0x0.
    Matrix<T> & operator = (Matrix<T> &&other) noexcept
    {
        if(this!=&other)
        {
            delete []pdata;
            pdata=other.pdata;
            rows=other.rows;
            cols=other.cols;
            other.pdata=nullptr;
            other.rows=0;
            other.cols=0;
        }
        return *this;
    }

    ~Matrix(){delete []pdata;}

    /// Number of rows.
    int GetRows() const {return rows;}
    /// Number of columns.
    int GetCols() const {return cols;}
    /// Element at row r, column c.
    /// Precondition (not checked): 0 <= r < GetRows() and 0 <= c < GetCols().
    T GetVal(int r, int c) const {return pdata[r*cols+c];}

    template <typename T1>
    friend std::ostream & operator << (std::ostream & o, Matrix <T1> & m);
};

/**
 * Writes m to o row by row: every element is followed by two spaces, and each
 * row ends with endl.
 *
 * @param o stream to write to
 * @param m matrix to print
 * @return o, to allow chaining
 */
template <typename T>
ostream & operator << (ostream & o, Matrix <T> & m)
{
    // Storage is row-major, so a single forward walk visits cells in print order.
    const T *cell = m.pdata;
    for (int r = 0; r < m.rows; ++r)
    {
        for (int c = 0; c < m.cols; ++c, ++cell)
            o << *cell << "  ";
        o << endl;
    }
    return o;
}

/**
 * Adds the scalar a to every element of this matrix, in place.
 *
 * @param a value added to each cell
 */
template <typename T>
void Matrix<T>::AddVal(T a)
{
    T *const stop = pdata + rows * cols;
    for (T *cell = pdata; cell < stop; ++cell)
        *cell += a;
}

/**
 * Element-wise adds m2 into this matrix, in place.
 *
 * @param m2 addend; must have the same dimensions as this matrix
 * @return true if the addition was performed; false (and this matrix is left
 *         untouched) when the dimensions differ
 */
template <typename T>
bool Matrix<T>::AddMatrix(Matrix<T> m2)
{
    if (rows != m2.rows || cols != m2.cols)
        return false;

    const int total = rows * cols;
    for (int k = 0; k < total; ++k)
        pdata[k] += m2.pdata[k];
    return true;
}

/**
 * Matrix product (*this) x m2.
 *
 * NOTE(review): m2 is taken by value, so the argument is copied on every call;
 * a const reference would avoid that but requires changing the in-class
 * declaration as well.
 *
 * @param m2 right-hand operand; must have exactly as many rows as this matrix
 *           has columns
 * @return a new rows x m2.cols matrix holding the product
 * @throws std::invalid_argument if cols != m2.rows
 */
template <typename T>
Matrix<T> Matrix<T>::MultiMatrix(Matrix<T> m2)
{
    if(cols!=m2.rows)
        throw std::invalid_argument("Matrix::MultiMatrix: inner dimensions do not match");

    // No initial values -> every cell of the result starts zero-filled. Passing a
    // T-typed (null) pointer keeps this valid for any element type T.
    Matrix<T> m3(rows, m2.cols, nullptr, 0);
    for(int i=0;i<rows;i++)
    {
        for(int j=0;j<m2.cols;j++)
        {
            // Accumulate in T (not int) so floating-point element types are not truncated.
            T sum=T();
            for(int q=0;q<cols;q++)
            {
                // Row i of *this has stride cols; column j of m2 has stride m2.cols.
                sum+=pdata[i*cols+q]*m2.pdata[q*m2.cols+j];
            }
            m3.pdata[i*m2.cols+j]=sum;
        }
    }
    return m3;
}

/**
 * Matrix product (*this) * m2. Computes the same result as MultiMatrix, but
 * binds m2 by reference, so the operand is not copied.
 *
 * @param m2 right-hand operand; must have exactly as many rows as this matrix
 *           has columns (it is only read, never modified)
 * @return a new rows x m2.cols matrix holding the product
 * @throws std::invalid_argument if cols != m2.rows
 */
template <typename T>
Matrix<T> Matrix<T>::operator * (Matrix<T> &m2)
{
    if(cols!=m2.rows)
        throw std::invalid_argument("Matrix::operator*: inner dimensions do not match");

    // No initial values -> every cell of the result starts zero-filled. Passing a
    // T-typed (null) pointer keeps this valid for any element type T.
    Matrix<T> m3(rows, m2.cols, nullptr, 0);
    for(int i=0;i<rows;i++)
    {
        for(int j=0;j<m2.cols;j++)
        {
            // Accumulate in T (not int) so floating-point element types are not truncated.
            T sum=T();
            for(int q=0;q<cols;q++)
            {
                // Row i of *this has stride cols; column j of m2 has stride m2.cols.
                sum+=pdata[i*cols+q]*m2.pdata[q*m2.cols+j];
            }
            m3.pdata[i*m2.cols+j]=sum;
        }
    }
    return m3;
}
#endif // MATRIX_H_INCLUDED

//main.cpp
#include <iostream>
#include "matrix.h"
using namespace std;

int main()
{
    int arr[5]={1,2,3,4,5};
    Matrix <int> m1(2,3, arr, 5);
    cout<<m1<<endl;
    //m1.AddVal(5);
    //cout<<m1<<endl;
    //int arr1[]={3,4,5,3,1};
    //Matrix <int> m2(2,3, arr1, 5);
    //cout<<m2<<endl;
    //if(m1.AddMatrix(m2)) cout<<m1<<endl;
    //else cout<<"error"<<endl;
    int arr2[]={1,2,3,4,5,6,7,8,9};
    Matrix <int> m3(3,3, arr2, 9);
    //Matrix <int> m4(m1.MultiMatrix(m3));
    Matrix<int> m4=m1*m3;
    cout<<m3<<endl;
    cout<<m4<<endl;
    return 0;
}
发布了5 篇原创文章 · 获赞 1 · 访问量 730

猜你喜欢

转载自blog.csdn.net/davidhan427/article/details/87208202