牛客网 牛客国庆集训派对Day2 矩阵乘法 (分块)

链接:https://www.nowcoder.com/acm/contest/202/A
来源:牛客网

题目描

深度学习算法很大程度上基于矩阵运算。例如神经网络中的全连接,本质上是一个矩阵乘法;而卷积运算也通常是用矩阵乘法来完成的。有一些科研工作者为了让神经网络的计算更快捷,提出了二值化网络的方法,就是将网络权重压缩成只用两种值表示的形式,这样就可以用一些 trick 加速计算了。例如两个二进制向量点乘,可以用计算机中的与运算代替,然后统计结果中 1 的个数即可。
然而有时候为了降低压缩带来的误差,只允许其中一个矩阵被压缩成二进制。这样的情况下矩阵乘法运算还能否做进一步优化呢?给定两个整数矩阵A, B,计算矩阵乘法 C = A x B。为了减少输出,你只需要计算 C 中所有元素的的异或和即可。

输入描述:

第一行有三个整数 N, P, M, 表示矩阵 A, B 的大小分别是 N x P, P x M 。
接下来 N 行是矩阵 A 的值,每一行有 P 个数字。第 i+1 行第 j 列的数字为 A
i,j
, A
i,j
 用大写的16进制表示(即只包含 0~9, A~F),每个数字后面都有一个空格。
接下来 M 行是矩阵 B 的值,每一行是一个长为 P 的 01字符串。第 i + N + 1 行第 j 个字符表示 B
j,i
 的值。

输出描述:

一个整数,矩阵 C 中所有元素的异或和。
示例1

输入

复制
4 2 3
3 4
8 A
F 5
6 7
01
11
10

输出

复制
2

说明

矩阵 C 的值为:
4 7 3
10 18 8
5 20 15
7 13 6
 
赛后题解给的是分块的方法,阐述一下自己的理解,该题目中,B矩阵是一个二进制矩阵,所以它每个位置上的数只有0或者1,两种可能,所以我们可以考虑枚举这个矩阵的每一列的部分情况进行分块。然后将每一个块当成一个整体的数,来缩小矩阵乘法的规模
A矩阵的极限情况是4096*64  如果进行一个块大小为8的分块,那么就能将它转化成4096*8的矩阵,B矩阵转化成8*4096的矩阵,缩小矩阵的规模,让其达到能够使用矩阵乘法的时间复杂度。
比如样例中 第一行3和4就在同一个块里,那么我们先预处理出3 4,遇到B矩阵的所有会出现的情况.AQ[1][1][0]代表第一行第一个块遇到的B矩阵二进制串为0,0的情况,AQ[1][1][2],代表遇到的B矩阵二进制串为0,1的情况,AQ[1][1][3]代表遇到B矩阵为1,1的情况
我们预处理出AQ数组,那么就能对B矩阵出现的所有情况进行对应的处理,并进行矩阵乘法
 
 
代码如下:
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
using namespace std;
int A[4100][70];
int B[70][4100];
int AQ[4100][10][260]; //表示A矩阵的第i行,就j个列块,在与B矩阵的对应块k相乘得到的结果;
int BQ[10][4100]; //表示B矩阵第j列,第i行块的二进制数
int n,p,m,size1;
int ans,tmp;
int main()
{
    scanf("%d%d%d",&n,&p,&m);
    for(int i=0;i<n;i++)
      for(int j=0;j<p;j++)
      scanf("%X",&A[i][j]);


    for(int i=0;i<m;i++)
        for(int j=0;j<p;j++)
         scanf("%1d",&B[j][i]);

    size1=(p-1)/8+1;    //分成大小为8的块,块的数量为size1
    for(int i=0;i<n;i++)
      for(int j=1;j<=size1;j++) //枚举A矩阵每一个块遇到对应的B矩阵对应块的的情况
      {
          int r=(j-1)*8;
          for(int k=0;k<=255;k++)//枚举二进制块的所有情况
            for(int z=0;z<8;z++)
          {
              if(k&(1<<z))
               AQ[i][j][k]+=A[i][r+z];
          }
      }

    for(int i=0;i<m;i++)   //将B矩阵缩小后的矩阵的二进制块的数值读入
        for(int j=1;j<=size1;j++)
    {
         int r=(j-1)*8;
         for(int k=0;k<8;k++)
           if(B[r+k][i]==1)
         BQ[j][i]+=(1<<k);
    }

    ans=0;
    for(int i=0;i<n;i++)
        for(int k=0;k<m;k++)
        {
          tmp=0;
          for(int j=1;j<=size1;j++)
           tmp+=AQ[i][j][BQ[j][k]];
            ans^=tmp;
        }
    printf("%d\n",ans);
    return 0;
}
 
 
 
 
 

猜你喜欢

转载自www.cnblogs.com/a249189046/p/9749042.html