C++ 使用哈希表的思想完成稀疏矩阵乘法

题设

要求构建一种数据结构,完成稀疏矩阵乘法。

思路

采用了普遍使用的三元组思路,由于看网上实现的方法感觉略复杂,便想自己用易懂的方式自己写一遍。

构建的数据结构为,结构体sparse_node,表示矩阵中不为0的点。结构体sparse_mat,内部放置sparse_node以及行数列数

#include <iostream>
#include <vector>
#include <unordered_map>
#include <map>
using namespace std;

struct sparse_node
{
	int i;
	int j;
	int val;
};
struct sparse_mat
{
	map<pair<int, int>, int> data;
	int row_num;
	int col_num;
};

关于这里为何未使用STL里的哈希表unordered_map而是使用map,这是因为我在使用unordered_map时,后续将其与vector/pair结合使用(因为想使用非零元素坐标作为键值),报错“error C2280: 'std::hash<_Kty>::hash(const std::hash<_Kty> &)': attempting to reference a deleted function”翻译过来就是“错误 C2280 “std::hash<_Kty>::hash(const std::hash<_Kty> &)”: 尝试引用已删除的函数”

上网查找原因,Stack Overflow上的解释是“ using  std::unordered_map with a  vector as a key is usually a bad idea, because of the likely high cost of implementing any kind of effective hashing function.”,就是说使用vector作为键值会导致哈希函数效率损失很多,所以官方直接把这个函数模板从unordered_map里面去掉了,使用vector/pair作为键值,最合适的数据结构是map(STL中以红黑树为底层)

稀疏矩阵乘法

于是我们就将非零元素的坐标,作为map的key,非零元素的值作为map的value。

例如,将第一个矩阵的一个元素表示为{0, 3, 4},于是(0, 3)作为键值,4作为value。假设隔壁矩阵的一个元素为{3, 1, 2},那么他们位置正好对的上,能够在最后结果矩阵里产生一个{0, 1, 8}的元素,因为4x2=8。如果另外的元素相乘也在(0, 1)位置产生了值,要将这个值与过去的8加起来才符合矩阵乘法的思想。

下面是稀疏矩阵乘法的代码示例

sparse_mat sparse_mat_product(sparse_mat mat_1, sparse_mat mat_2)
{
        //输出的结果稀疏矩阵
	sparse_mat result;

        //如果两矩阵维度对不上,则抛出异常
	if (mat_1.col_num != mat_2.row_num)
		throw("matrix size not match");
	else
	{
		result.row_num = mat_1.row_num;
		result.col_num = mat_2.col_num;
        
                //使用迭代器遍历第一个矩阵的非零元素
		for (auto it1 = mat_1.data.begin(); it1 != mat_1.data.end(); it1++)
		{
                        //(*it1).first代表key,这个key是一个坐标pair,(*it1).first.first代表坐标的第一个元素,这里判断是否输入值越界
			if ((*it1).first.first > mat_1.row_num - 1 || (*it1).first.second > mat_1.col_num - 1)
				throw("index out of range");
                        //使用迭代器遍历第二个矩阵的非零元素
			for (auto it2  = mat_2.data.begin(); it2 != mat_2.data.end(); it2++)
			{
                                //检测是否输入越界
				if ((*it2).first.first > mat_2.row_num - 1 || (*it2).first.second > mat_2.col_num - 1)
					throw("index out of range");
                                //如果第一个元素的位置和第二个元素的位置对的上,我们就将其相乘
				if ((*it1).first.second == (*it2).first.first)
				{
                                        //确认结果矩阵中是否已经出现过该位置的值,如果没有就创建
					if (result.data.find(make_pair((*it1).first.first, (*it2).first.second)) == result.data.end())
					{
						result.data[make_pair((*it1).first.first, (*it2).first.second)] = (*it1).second * (*it2).second;
					}
                                        //如果已经出现过,则将其与历史值相加
					else
					{
						result.data[make_pair((*it1).first.first, (*it2).first.second)] += (*it1).second * (*it2).second;
					}
				}
			}
		}
	}
	return result;
}

 

检验

int main()
{
        //构建稀疏矩阵1
	sparse_mat mat1;
	vector<vector<int>> mat1_data = { {0, 3, 4}, {0, 2, 1}, {2, 2, 4 } };
	mat1.row_num = 3;
	mat1.col_num = 4;
	for (int i = 0; i <= mat1_data.size() - 1; i++)
	{
		vector<int> tmp = mat1_data[i];
		mat1.data[make_pair(tmp[0], tmp[1])] = tmp[2];
	}
	cout << "sparse matrix 1 :" << endl;
	cout << " size :" << mat1.row_num << " x " << mat1.col_num << endl;
	for (auto it = mat1.data.begin(); it != mat1.data.end(); it++)
	{
		cout << "{" <<(*it).first.first << " ," << (*it).first.second << "}" << ", " << (*it).second << endl;
	}
	cout << endl;
        //构建稀疏矩阵2
	sparse_mat mat2;
	vector<vector<int>> mat2_data = { {3, 1, 2}, {2, 1, 1}, {2, 2, 4 } };
	mat2.row_num = 4;
	mat2.col_num = 3;
	for (int i = 0; i <= mat2_data.size() - 1; i++)
	{
		vector<int> tmp = mat2_data[i];
		mat2.data[make_pair(tmp[0], tmp[1])] = tmp[2];
	}
	cout << "sparse matrix 2 :" << endl;
	cout << " size :" << mat2.row_num << " x " << mat2.col_num << endl;
	for (auto it = mat2.data.begin(); it != mat2.data.end(); it++)
	{
		cout << "{" << (*it).first.first << " ," << (*it).first.second << "}" << ", " << (*it).second << endl;
	}
	cout << endl;
        //构建结果稀疏矩阵
	sparse_mat result = sparse_mat_product(mat1, mat2);
	cout << "sparse matrix product : " << endl;
	cout << " size :" << result.row_num << " x " << result.col_num << endl;
	for (auto it = result.data.begin(); it != result.data.end(); it++)
	{
		cout << "{" << (*it).first.first << " ," << (*it).first.second << "}" << ", " << (*it).second << endl;
	}
	return 0;
}

发布了94 篇原创文章 · 获赞 137 · 访问量 19万+

猜你喜欢

转载自blog.csdn.net/yyhhlancelot/article/details/100085294