稀疏矩阵c++实现

稀疏矩阵c++实现

楼主最近在学习数据结构,想在网上看看相关的代码并实现一下,找了很多都不尽人意,因此自己写了一个,就当自己的学习笔记,以便时常翻阅查看

下面展示一些 `tirtuple.h文件。

// trituple.h
#pragma once
#include<iostream>
#include<assert.h>
using namespace std;
template<class T>
struct Trituple
{
    
    
	int row, col;
	T value;
	Trituple<T>& operator=(Trituple<T>& x) 
	{
    
    
		row = x.row;
		col = x.col;
		value = x.value;
		return *this;
	}
	
};

下面展示一些 "sparematrix.h

// An highlighted block
#pragma once
#include"trituple.h"
using std::ostream;
template<class T>
class sparsematrix 
{
    
    
public:
	sparsematrix<T>& transpose();
	sparsematrix<T>& add(sparsematrix<T>& );
	sparsematrix(int,int);
	sparsematrix(sparsematrix<T>& );
	~sparsematrix();
	T& getnum(int, int)const ;
	void insert(Trituple<T>& tmp);
	sparsematrix<T>& operator=(sparsematrix<T>& SM);
	sparsematrix<T>& multiply(sparsematrix<T>& b);
	friend ostream& operator<<(ostream& ostr, sparsematrix<T>& SM);
	friend istream& operator >> (istream& istr,sparsematrix<T>& SM);
private:
	int Rows;				//行数
	int Cols;				//列数
	int Terms;				//非0元素的个数
	Trituple<T> *smArry;	//存放非零元素的三元数组
	int maxTerms;			//能容纳的最大元素个数
};

下面展示一些 "sparematrix.cpp

// An highlighted block
#include"sparematrix.h"
template<class T>
inline sparsematrix<T>::sparsematrix(int maxcol,int maxrow) :maxTerms(maxcol*maxrow),Rows(maxrow),Cols(maxcol),Terms(0)
{
    
    
	if (maxcol<1||maxrow<1)
	{
    
    
		cerr << "init ERROR" << endl;
		return;
	}
	assert(smArry != nullptr);
}
template<class T>
inline sparsematrix<T>::sparsematrix(sparsematrix<T>& SM) 
{
    
    
	if (SM.Terms==0) 
	{
    
    
		smArry = nullptr;
		Cols = Rows = Terms = 0;
		return;
	}
	if (SM == *this) 
	{
    
    
		return;
	}
	Rows = SM.Rows;
	Cols = SM.Cols;
	Terms = 0;
	maxTerms = SM.maxTerms;
	for (int i = 0; i < Terms; i++) 
	{
    
    
		insert(SM.smArry[i]);
	}
}
template<class T>
sparsematrix<T>& sparsematrix<T>::operator=(sparsematrix<T>& SM) 
{
    
    
	if (SM.Terms == 0)
	{
    
    
		smArry = nullptr;
		Cols = Rows = Terms = 0;
		return;
	}
	if (SM == *this)
	{
    
    
		return;
	}
	Rows = SM.Rows;
	Cols = SM.Cols;
	Terms = SM.Terms;
	maxTerms = SM.maxTerms;
	smArry = new Trituple<T>[maxTerms];
	for (int i = 0; i < Terms; i++)
	{
    
    
		smArry[i] = SM.smArry[i];
	}
}
template<class T>
inline sparsematrix<T>::~sparsematrix() 
{
    
    
	if (maxTerms!=0) 
	{
    
    
		delete[] smArry;
		Rows = Cols = Terms = 0;
	}
}
template<class T>
inline T & sparsematrix<T>::getnum(int row1, int col1)const 
{
    
    
	for (int i = 0; i < Terms; i++) 
	{
    
    
		if (smArry[i].col == col1&&smArry[i].row == row1) 
		{
    
    
			return smArry[i].value;
		}
	}
	return 0;
}
template<class T>
void sparsematrix<T>::insert(Trituple<T>& tmp)
{
    
    
	for (int i = 0; i < Terms; i++) 
	{
    
    
		if (smArry[i].row = tmp.row&&smArry[i].col == tmp.col) 
		{
    
    
			crr << "already exist" << endl;
			return;
		}
	}
	Terms++;
	Trituple<T>* smarry = new Trituple<T>[Terms];
	if(Terms>1)//方便后续加法操作,按照索引排序,如果之前就存在元素,那么:
	{
    
    
		for (int i = 0; i < Terms-2; i++)//遍历到倒数第二个就行,不然指针越界
		{
    
    
			//如果比最小的还要小 
			if (tmp.row*Cols*tmp.col < smArry[0].row*Cols + smArry[0].col)
			{
    
    
				smarry = tmp;
				for (int j = 0; j < Terms - 1; j++) 
				{
    
    
					smarry[j + 1] = smArry[j];
				}
				break;
			}
			//如果在中间
			else if (smArry[i].row*Cols + smArry[i].col <= tmp.row*Cols*tmp.col
				&&tmp.row*Cols*tmp.col <= smArry[i+1].row*Cols + smArry[i+1].col)
			{
    
    
				for (int j = 0; j < i; j++)
				{
    
    
					smarry[j] = smArry[j];
				}
				smarry[i] = tmp;
				for (int j = i; j < Terms-1; j++)
				{
    
    
					smarry[j+1] = smArry[j];
				}
				break;
			}
			//如果在尾巴上
			else if (smArry[Terms - 2].row*Cols + smArry[Terms - 2].col < tmp.row*Cols*tmp.col)
			{
    
    
				smarry[Terms - 1] = tmp;
				break;
			}
		}
		delete[] smArry;
	}
	smArry = smarry;//头指针赋值
}
template<class T>
ostream& operator<<(ostream& ostr, sparsematrix<T>& SM) 
{
    
    
	ostr << "rows=" << SM.Rows << endl;
	ostr << "Cols=" << SM.Cols << endl;
	ostr << "terms=" << SM.Terms << endl;
	for (int i = 0; i < SM.Terms; i++) 
	{
    
    
		ostr << i + 1 << ":<" << SM.smArry[i].row << "," << SM.smArry[i].col << ">="
			<< SM.smArry[i].value << endl;
	}
	return ostr;
}
template<class T>
istream & operator>>(istream & istr, sparsematrix<T>& SM)
{
    
    
	istr >> SM.Rows >> SM.Cols >> SM.Terms;
	if (SM.Terms > SM.maxTerms) 
	{
    
    
		cerr << "index overflowed" << endl;
		exit(1);
	}
	for (int i=0; i < SM.Terms; i++) 
	{
    
    
		cin >> SM.smArry[i].row >> SM.smArry[i].col >> SM.smArry[i].value;
	}

	return istr;
}
template<class T>
sparsematrix<T>& sparsematrix<T>::add(sparsematrix<T>& b)
{
    
    
	sparsematrix<T> result(cols,rows);
	if (Rows != b.Rows || Cols != b.Cols)
	{
    
    
		cerr << "incompatable!" << endl;;
		return result;
	}
	if (b.Terms == 0)
	{
    
    
		return *this;
	}
	if (Terms == 0)
	{
    
    
		return b;
	}
	result.Rows = Rows;
	result.Cols = Cols;
	result.Terms = 0;
	result.maxTerms = Rows*Cols;
	int i = 0, j = 0, index_a, index_b;
	while (i < Terms&&j<b.Terms) //当两个矩形的非零元素都没有到极限的时候
	{
    
    //前提是smarry是按照左上到右下,从左到右排列的,也就是每个非零元素在矩阵中的索引递增
		index_a = smArry[i].row*Cols + smArry[i].col;//计算每一个非零元素在矩阵中的位置
		index_b = b.smArry[j].row*b.Cols + b.smArry[i].col;
		if (index_a < index_b)//如果本矩阵的元素排在前面,就把本矩阵的元素插入目标矩阵 
		{
    
    
			result.insert(smArry[i]);
			i++;//下次计算下一个位置
		}
		else if (index_a > index_b) //如果被加矩阵的元素排在前面
		{
    
    
			result.insert(smArry[j]);
			j++;//计算被加矩阵下一个元素的位置
		}
		else //如果两元素在同一个索引
		{
    
    
			if (smArry[i].value + b.smArry[j].value) 
			{
    
    
				Trituple<T> tmp;
				tmp = smArry[j];
				tmp.value = smArray[i].value + b.smArray[j].value;
				result.insert(tmp);
				i++;
				j++;
			}
		}
	}
	if (Terms > b.Terms) //当被加矩阵算完了,还剩下原来矩阵的数,因为原来矩阵中数据不重复,因此直接插入就行
	{
    
    
		for (; i < Terms; i++)
		{
    
    
			result.insert(smArry[i]);
			i++;
		}
	}
	else 
	{
    
    
		for (; j<b.Terms; j++)
		{
    
    
			result.insert(b.smArry[j]);
			j++;
		}
	}
	return result;
}
template<class T>
sparsematrix<T>& sparsematrix<T>::transpose()
{
    
    
	int *colSize = new int[Cols+1];//每列非零元素啊的个数
	int *rowstart = new int[Cols+1];//b中每一行的首个非零元素在smArry中的索引,因为是转置,因此列为行
	sparsematrix<T> b(maxTerms);//转置矩阵对应的三元组
	b.Rows = Rows;				//b的性质
	b.Cols = Cols;
	b.Terms = Terms;
	b.maxTerms = maxTerms;
	if (Terms > 0) //如果存在非零元素
	{
    
    
		int i;
		for (i = 0; i < Cols; i++) 
		{
    
    
			colSize[i] = 0;//初始化
		}
		for (i = 0; i < Terms; i++) 
		{
    
    
			colSize[smArry[i].col]++;//根据第i个非零数据的列属性,计算出每列
		}//new必须加1,否则这里会越界
		rowstart[0] = 0;
		for (i = 1; i < Cols+1; i++)//
		{
    
    
			rowstart[i] = rowstart[i - 1] + colSize[i - 1];//后一行的第一个首个非零元素的索引是前行首个非零元素索引+非零元素个数
		}
		//以上的所有步骤都是为了找到a中每一个smarry在转置后的矩阵中的脚标
		for (i = 0; i < Terms; i++)//遍历三元组 
		{
    
    
			int j = rowstart[smArry[i].col]++;//对应第i给元素所在列的首个不为零的数据的索引,按照行从左到右顺序存入
											//获取值之后自加1,因为一列可能有多个元素,下次再到这一列的时候,就是这一列的下一个元素,
											//同样转置后的矩阵的非零元素的下一个元素的脚标也要加1
			b.smArry[j] = smArry[i];
		}
	}
	delete[] rowstart;
	delete[] colSize;
	return b;
}
template<class T>
sparsematrix<T>& sparsematrix<T>::multiply(sparsematrix<T>& b)
{
    
    
	sparsematrix<T> result = {
    
    };
	if (Cols != b.Rows) 
	{
    
    
		cerr << "unable to multiply" << endl;
		return result;
	}
	int *colSize = new int[b.Cols+1];//b矩阵每列非零元素的个数
	int *colStart = new int[b.Cols + 1];//b矩阵每列第一个非零元素在b中的下标
	int i;
	for (i = 0; i < b.Cols; i++) 
	{
    
    
		colSize[i] = 0;
	}
	for (i = 0; i <Terms; i++) 
	{
    
    
		colSize[b.smArry[i].col]++;
	}
	colStart[0] = 0;
	for (i = 1; i < b.Cols+1; i++) 
	{
    
    
		colStart[i] = colStart[i - 1] + colSize[i - 1];
	}
	int index = 0;//非0元素的脚标
	int temp[b.Cols] = {
    
    };//暂存每个元素的运算结果
	//获取第一个元素所在行
	while (index < Terms)
		//只要不结束,结束上一行的所有计算后,进行下一行的计算
	{
    
    
		int  row_a = smArry[index].row;
		while (index < Terms&&smArry[index].row == row_a) //数组是从左到右,从上到下排列,因此可以这么写,获得本身在rowa_a的元素
		{
    
    //第一个判断条件是怕index越界,产生乱码
			for (int j = 1; j < b.Cols + 1; j++) //针对每列元素
			{
    
    
				for (i = colStart[j]; i < colStart[j + 1]; i++) //针对b在第col_a列的非零元素,进行乘积操作
				{
    
    
					int row_b = b.smArray[i].row;
					if (smArry[index].col == row_b) //如果b矩阵中非零元素的行等于本身非零元素的列,也就是说两个位置元素都为非零,才能相乘
					{
    
    
						temp[j-1] += smArry[index].value*b.smArry[i].value;
					}
				}
			}
			index++;

		}
		for (i = 0; i < b.Cols; i++) 
		{
    
    
			if (temp[i] != 0) 
			{
    
    
				Trituple<T>& Result;
				Result.row = smArray[index-1].row;//一定要减一,不然会row会全部+1
				Result.col = i;
				Result.value = temp[i];
				result.insert(Result);
			}
		}
	}
	result.Rows = Rows;
	result.Cols = b.Cols;
	delete[] colSize;
	delete[] colStart;
	return result;
}

猜你喜欢

转载自blog.csdn.net/weixin_42034081/article/details/114239513
今日推荐