稀疏矩阵 part 2

▶ 各种稀疏矩阵数据结构之间的转化

● MAT ←→ CSR

 1 CSR * MATToCSR(const MAT *in)                                       // MAT 转 CSR
 2 {
 3     checkNULL(in);
 4     CSR * out = initializeCSR(in->row, in->col, in->count);
 5     checkNULL(out);
 6     
 7     out->ptr[0] = 0;
 8     for (int i = 0, j = 0, k = 1; i < in->row * in->col; i++)       // i 遍历 in->data
 9     {
10         if (in->data[i] != 0)                                       // 找到非零元
11         {
12             if (j == in->count)                                     // 在 out->data 已经填满了的基础上又发现了非零元,错误
13                 return NULL;
14             out->data[j] = in->data[i];                             // 填充非零元素
15             out->index[j] = i % in->col;                            // 填充列号
16             j++;
17         }
18         if ((i + 1) % in->col == 0)                                 // 到了最后一列,写入行指针号
19             out->ptr[k++] = j;
20     }
21     return out;
22 }
23 
24 MAT * CSRToMAT(const CSR *in)                                       // CSR转MAT
25 {
26     checkNULL(in);
27     MAT *out = initializeMAT(in->row, in->col, in->ptr[in->row]);
28     checkNULL(out);
29 
30     memset(out->data, 0, sizeof(format) * in->row * in->col);
31     for (int i = 0; i < in->row; i++)                               // i 遍历行
32     {                                                
33         for (int j = in->ptr[i]; j < in->ptr[i + 1]; j++)           // j 遍历列 
34             out->data[i * in->col + in->index[j]] = in->data[j];
35     }
36     return out;
37 }

● MAT ←→ ELL

 1 ELL * MATToELL(const MAT *in)// MAT转ELL
 2 {
 3     checkNULL(in);
 4 
 5     int i, j, maxElement; 
 6     for (i = j = maxElement = 0; i < in->row * in->col; i++)                    // i 遍历 in->data,j 记录该行非零元素数,maxElement 记录一行非零元素最大值
 7     {     
 8         if (in->data[i] != 0)                                                   // 找到非零元       
 9             j++;                                                       
10         if ((i + 1) % in->col == 0)                                             // 行末,更新 maxElement                        
11         {
12             maxElement = MAX(j, maxElement);                    
13             j = 0;                                                              // 开始下一行之前清空 j
14         }
15     }
16     format* temp_data=(format *)malloc(sizeof(format) * in->row * maxElement);  // 临时数组,将列数压缩到 maxElement
17     checkNULL(temp_data);
18     int* temp_index = (int *)malloc(sizeof(int) * in->row * maxElement);
19     checkNULL(temp_index);
20     memset(temp_data, 0, sizeof(format) * in->row * maxElement);
21     memset(temp_index, 0, sizeof(int) * in->row * maxElement);
22     for (i = j = 0; i < in->row * in->col; i++)                                 // i 遍历 in->data,j 记录该行非零元素数,把 in 中每行的元素往左边推
23     {        
24         if (in->data[i] != 0)                                                   // 找到非零元
25         {
26             temp_data[i / in->col * maxElement + j] = in->data[i];              // 存放元素
27             temp_index[i / in->col * maxElement + j] = i % in->col;             // 记录所在的列号
28             j++;                                                    
29         }
30         if ((i + 1) % in->col == 0)                                             // 行末,将剩余位置的下标记作 -1,即无效元素
31         {            
32             for (j += i / in->col * in->col; j < maxElement * (i / in->col + 1); j++)   // 使得 j 指向本行最后一个非零元素之后的元素,再开始填充
33                 temp_index[j] = -1;                                 
34             j = 0;                                                              // 开始下一行之前清空 j
35         }
36     }    
37     ELL *out = initializeELL(maxElement, in->row, in->col);                     // 最终输出,如果不转置的话不要这部分
38     checkNULL(out);
39     for (i = 0; i < out->row * out->col; i++)                                   // 将 temp_data 和 temp_index 转置以提高缓存利用
40     {
41         out->data[i] = temp_data[i % out->col * out->row + i / out->col];
42         out->index[i] = temp_index[i % out->col * out->row + i / out->col];
43     }
44     free(temp_data);
45     free(temp_index);
46     return out;
47 }
48 
49 MAT * ELLToMAT(const ELL *in)                                                   // ELL转MAT
50 {
51     checkNULL(in);
52     MAT *out = initializeMAT(in->col, in->colOrigin);
53     checkNULL(out);
54 
55     for (int i = 0; i < in->row * in->col; i++)                                 // i 遍历 out->data 
56     {
57         if (in->index[i] < 0)                                                   // 注意跳过无效元素
58             continue;
59         out->data[i % in->col * in->colOrigin + in->index[i]] = in->data[i];
60     }
61     COUNT_MAT(out);
62     return out;
63 }

● MAT ←→ COO

 1 COO * MATToCOO(const MAT *in)                               // MAT转COO
 2 {
 3     checkNULL(in);
 4     COO *out = initializeCOO(in->row, in->col, in->count);
 5 
 6     for (int i=0, j = 0; i < in->row * in->col; i++)
 7     {
 8         if (in->data[i] != 0)
 9         {
10             out->data[j] = in->data[i];
11             out->rowIndex[j] = i / in->col;
12             out->colIndex[j] = i % in->col;
13             j++;
14         }
15     }
16     return out;
17 }
18 
19 MAT * COOToMAT(const COO *in)                               // COO转MAT
20 {
21     checkNULL(in);
22     MAT *out = initializeMAT(in->row, in->col, in->count);
23     checkNULL(out);
24 
25     for (int i = 0; i < in->row * in->col; i++)
26         out->data[i] = 0;
27     for (int i = 0; i < in->count; i++)
28         out->data[in->rowIndex[i] * in->col + in->colIndex[i]] = in->data[i];
29     return out;
30 }

● MAT ←→ DIA

 1 DIA * MATToDIA(const MAT *in)                                       // MAT转DIA
 2 {
 3     checkNULL(in);
 4 
 5     int *index = (int *)malloc(sizeof(int)*(in->row + in->col - 1));
 6     for (int diff = in->row - 1; diff > 0; diff--)                  // 左侧零对角线情况
 7     {        
 8         int flagNonZero = 0;
 9         for (int i = 0; i < in->col && i + diff < in->row; i++)     // i 沿着对角线方向遍历 in->data,flagNonZero 记录该对角线是否全部为零元
10         {            
11 #ifdef INT
12             if (in->data[(i + diff) * in->col + i] != 0)
13 #else
14             if (fabs(in->data[(i + diff) * in->col + i]) > EPSILON)
15 #endif            
16                 flagNonZero = 1;
17         }
18         index[in->row - 1 - diff] = flagNonZero;                    // 标记该对角线上有非零元
19     }
20     for (int diff = in->col - 1; diff >= 0; diff--)                 // 右侧零对角线情况
21     {                                                                                                                 
22         int flagNonZero = 0;
23         for (int j = 0; j < in->row && j + diff < in->col; j++)
24         {
25 #ifdef INT
26             if (in->data[j * in->col + j + diff] != 0)
27 #else
28             if (fabs(in->data[j * in->col + j + diff]) > EPSILON)
29 #endif            
30                 flagNonZero = 1;
31         }
32         index[in->row - 1 + diff] = flagNonZero;                    // 标记该对角线上有非零元
33     }       
34     int *prefixSumIndex = (int *)malloc(sizeof(int)*(in->row + in->col - 1));
35     prefixSumIndex[0] = index[0];
36     for (int i = 1; i < in->row + in->col - 1; i++)                 // 闭前缀和,prefixSumIndex[i] 表示原矩阵第 0 ~ i 条对角线中共有多少条非零对角线(含)
37         prefixSumIndex[i] = prefixSumIndex[i-1] + index[i];         // index[in->row + in->col -2] 表示原矩阵非零对角线条数,等于 DIA 矩阵列数
38     DIA *out = initializeDIA(in->row, prefixSumIndex[in->row + in->col - 2], in->col);
39     checkNULL(out);
40 
41     memset(out->data, 0, sizeof(int)*out->row * out->col);
42     for (int i = 0; i < in->row + in->col - 1; i++)             
43         out->index[i] = index[i];                                   // index 搬进 out
44     for (int i = 0; i < in->row; i++)                               // i,j 遍历原矩阵,将元素搬进 out
45     {
46         for (int j = 0; j < in->col; j++)
47         {
48             int temp = j - i + in->row - 1;
49             if (index[temp] == 0)
50                 continue;            
51             out->data[i * out->col + (temp > 0 ? prefixSumIndex[temp - 1] : 0)] = in->data[i * in->col + j];    // 第 row - 1 行第 0 列元素 temp == 0,单独处理
52         }
53     }
54     free(index);
55     free(prefixSumIndex);
56     return out;
57 }
58 
59 MAT * DIAToMAT(const DIA *in)                                       // DIA转MAT
60 {
61     checkNULL(in);
62     MAT *out = initializeMAT(in->row, in->colOrigin);
63     checkNULL(out);
64 
65     int * inverseIndex = (int *)malloc(sizeof(int) * in->col);
66     for (int i = 0, j = 0; i < in->row + in->col - 1; i++)          // 求一个 index 的逆,即 DIA 中第 i 列对应原矩阵第 inverseIndex[i] 对角线
67     {                                                               // 原矩阵对角线编号 (row-1, 0) 为第 0 条,(0, 0) 为第 row - 1 条,(col-1, 0) 为第 row + col - 2 条
68         if (in->index[i] == 1)
69         {
70             inverseIndex[j] = i;
71             j++;
72         }
73     }
74     for (int i = 0; i < in->row; i++)                               // i 遍历 in->data 行,j 遍历 in->data 列
75     {
76         for (int j = 0; j < in->col; j++)
77         {
78             if (i < in->row - 1 - inverseIndex[j] || i > inverseIndex[in->col - 1] - inverseIndex[j])   // 跳过两边呈三角形的无效元素
79                 continue;
80             out->data[i * in->col + inverseIndex[j] - in->row + 1] = in->data[i * in->col + j];         // 利用 inverseIndex 来找钙元素在原距震中的位置
81         }
82     }
83     free(inverseIndex);
84     return out;
85 }

猜你喜欢

转载自www.cnblogs.com/cuancuancuanhao/p/10428415.html