CCF-CSP Past Exam Problem "202305-2 Matrix Operations": Approach and Full-Score Python/C++ Solutions

For other past CCF-CSP problems and their solutions, see: CCF-CSP real questions with solutions


Question No.: 202305-2
Question name: Matrix Operations
Time limit: 5.0s
Memory limit: 512.0MB
Problem Description:

Background

$\operatorname{Softmax}\left(\frac{Q \times K^T}{\sqrt{d}}\right) \times V$ is the core formula of the attention module in the Transformer, where Q, K and V are all matrices with n rows and d columns, $K^T$ denotes the transpose of K, and × denotes matrix multiplication.

Problem Description

To simplify the computation, Dun Dun replaces Softmax with an element-wise multiplication by a one-dimensional vector W of length n:
$(W \cdot (Q \times K^T)) \times V$
Here the dot denotes element-wise multiplication: if W(i) is the i-th element of W, every element in the i-th row of $Q \times K^T$ is multiplied by W(i).
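To make the simplified formula concrete, here is a minimal sketch that evaluates $(W \cdot (Q \times K^T)) \times V$ literally, in the order written. It assumes 0-based `std::vector` matrices of `long long`; the names `Matrix` and `naive_attention` are hypothetical and not part of the problem.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// Direct evaluation of (W . (Q x K^T)) x V in the order the formula is written.
// It builds the full n x n matrix S, so it costs O(n^2 * d) time and O(n^2) memory.
Matrix naive_attention(const Matrix& Q, const Matrix& K, const Matrix& V,
                       const std::vector<long long>& W) {
    const std::size_t n = Q.size();
    const std::size_t d = n ? Q[0].size() : 0;
    assert(K.size() == n && V.size() == n && W.size() == n);

    // S[i][j] = W[i] * (row i of Q) . (row j of K); (Q x K^T)[i][j] is that dot product,
    // and the element-wise product with W scales the whole row i by W[i].
    Matrix S(n, std::vector<long long>(n, 0));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j) {
            long long dot = 0;
            for (std::size_t k = 0; k < d; ++k) dot += Q[i][k] * K[j][k];
            S[i][j] = W[i] * dot;
        }

    // Result = S x V, an n x d matrix.
    Matrix R(n, std::vector<long long>(d, 0));
    for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j)
            for (std::size_t k = 0; k < d; ++k)
                R[i][k] += S[i][j] * V[j][k];
    return R;
}
```

This is only practical for small n (at n = 10^4 the matrix S alone has 10^8 entries), but it is a handy reference for checking faster code on small inputs.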

Given the matrices Q, K and V and the vector W, compute the result of the simplified formula.

Input Format

Read data from standard input.

The first line of input contains two positive integers n and d separated by a space, giving the dimensions of the matrices.

Next come the matrices Q, K and V, in that order. Each matrix takes n lines, each containing d space-separated integers; the j-th number on the i-th line is the element in row i, column j of the matrix.

The last line contains n integers, the elements of the vector W.
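The input layout above can be read with a small helper. This is a minimal sketch that assumes 0-based indexing and stores every value as `long long`; the struct `Input` and the functions `read_matrix` and `read_input` are hypothetical names for illustration.

```cpp
#include <cassert>
#include <istream>
#include <sstream>  // std::istringstream, convenient for feeding the sample as a string
#include <vector>

// Hypothetical container for one test case, filled in the order given by the input format.
struct Input {
    int n = 0, d = 0;
    std::vector<std::vector<long long>> Q, K, V;  // each n x d
    std::vector<long long> W;                      // length n
};

// Reads an n x d matrix row by row: the j-th number on line i is element (i, j).
static std::vector<std::vector<long long>> read_matrix(std::istream& in, int n, int d) {
    std::vector<std::vector<long long>> M(n, std::vector<long long>(d));
    for (auto& row : M)
        for (auto& x : row) in >> x;
    return M;
}

// Reads n and d, then Q, K, V in sequence, then the n elements of W.
Input read_input(std::istream& in) {
    Input t;
    in >> t.n >> t.d;
    t.Q = read_matrix(in, t.n, t.d);
    t.K = read_matrix(in, t.n, t.d);
    t.V = read_matrix(in, t.n, t.d);
    t.W.resize(t.n);
    for (auto& w : t.W) in >> w;
    return t;
}
```

In a submission, `read_input(std::cin)` would read the real data; passing a `std::istringstream` holding the sample is a quick way to test it.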

Output Format

Write to standard output.

Output n lines, each containing d space-separated integers: the rows of the resulting matrix.

Sample Input

3 2
1 2
3 4
5 6
10 10
-20 -20
30 30
6 5
4 3
2 1
4 0 -5

Sample Output

480 240
0 0
-2200 -1100
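For a hand check of the sample, it is easiest to evaluate $K^T \times V$ first and scale each row by W at the end. Because both columns of K are equal, $K^T \times V$ has two identical rows:

```latex
K^T V =
\begin{pmatrix} 10 & -20 & 30 \\ 10 & -20 & 30 \end{pmatrix}
\begin{pmatrix} 6 & 5 \\ 4 & 3 \\ 2 & 1 \end{pmatrix}
= \begin{pmatrix} 40 & 20 \\ 40 & 20 \end{pmatrix},
\qquad
Q\,(K^T V) =
\begin{pmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{pmatrix}
\begin{pmatrix} 40 & 20 \\ 40 & 20 \end{pmatrix}
= \begin{pmatrix} 120 & 60 \\ 280 & 140 \\ 440 & 220 \end{pmatrix}

% Scaling rows by W = (4, 0, -5):
\begin{pmatrix} 4 \cdot 120 & 4 \cdot 60 \\ 0 \cdot 280 & 0 \cdot 140 \\ -5 \cdot 440 & -5 \cdot 220 \end{pmatrix}
= \begin{pmatrix} 480 & 240 \\ 0 & 0 \\ -2200 & -1100 \end{pmatrix}
```

The last matrix matches the sample output.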

Subtasks

70% of the test data satisfy: n ≤ 100 and d ≤ 10; all elements of the input matrices and vector are integers with absolute value at most 30.

All test data satisfy: n ≤ 10^4 and d ≤ 20; all elements of the input matrices and vector are integers with absolute value at most 1000.

Hint

Carefully estimate the range of values produced by the matrix multiplications, and store the matrix entries in a data type wide enough to hold them.
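The hint can be turned into a compile-time check. This sketch assumes the full-data bounds (n ≤ 10^4, d ≤ 20, absolute values ≤ 1000) and the evaluation order $W \cdot (Q \times (K^T \times V))$; the constant names are hypothetical.

```cpp
#include <cassert>
#include <climits>

// Worst-case magnitudes under the full constraints.
constexpr long long kMaxN = 10000;
constexpr long long kMaxD = 20;
constexpr long long kMaxAbs = 1000;

// A single product of two inputs is at most 10^6, which still fits in int,
// but one entry of K^T x V sums n such products: up to 10^10.
constexpr long long kTmpBound = kMaxN * kMaxAbs * kMaxAbs;
// One entry of Q x (K^T x V) sums d products of an input and such an entry: up to 2 * 10^14.
constexpr long long kProdBound = kMaxD * kMaxAbs * kTmpBound;
// Scaling by W[i] multiplies by at most 1000 more: up to 2 * 10^17.
constexpr long long kFinalBound = kMaxAbs * kProdBound;

static_assert(kTmpBound > INT_MAX, "intermediate sums already overflow a 32-bit int");
static_assert(kFinalBound < LLONG_MAX, "the final result still fits in long long");
```

So `int` is enough for the input values, but the accumulated sums need `long long` (whose maximum is about 9.2 × 10^18).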

Original problem: Matrix Operations

Interested readers can submit their code there for practice.

Approach:

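This problem is not difficult; writing out the summation loops on paper makes the pattern clear. The real issue is time complexity. Multiplying Q by $K^T$ first produces an n × n matrix, which takes O(n²d) time (about 2 × 10^9 multiply-adds for n = 10^4, d = 20) plus n² storage, so it times out. Since W only scales the rows of $Q \times K^T$, and matrix multiplication is associative, the formula equals $W \cdot (Q \times (K^T \times V))$. Computing the d × d matrix $K^T \times V$ first takes O(nd²) time, and multiplying Q by it takes another O(nd²), which passes comfortably.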
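A minimal sketch of the reordered computation, using 0-based `std::vector` matrices instead of fixed 1-based global arrays. The names `Matrix` and `reordered_attention` are hypothetical; the logic follows the $W \cdot (Q \times (K^T \times V))$ order described above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// Evaluates (W . (Q x K^T)) x V as W . (Q x (K^T x V)).
// Never materialises the n x n matrix: O(n * d^2) time, O(d^2) extra memory.
Matrix reordered_attention(const Matrix& Q, const Matrix& K, const Matrix& V,
                           const std::vector<long long>& W) {
    const std::size_t n = Q.size();
    const std::size_t d = n ? Q[0].size() : 0;
    assert(K.size() == n && V.size() == n && W.size() == n);

    // T = K^T x V, a d x d matrix: T[a][b] = sum over k of K[k][a] * V[k][b].
    Matrix T(d, std::vector<long long>(d, 0));
    for (std::size_t k = 0; k < n; ++k)
        for (std::size_t a = 0; a < d; ++a)
            for (std::size_t b = 0; b < d; ++b)
                T[a][b] += K[k][a] * V[k][b];

    // R = Q x T (n x d), then row i is scaled by W[i].
    Matrix R(n, std::vector<long long>(d, 0));
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t a = 0; a < d; ++a)
            for (std::size_t b = 0; b < d; ++b)
                R[i][b] += Q[i][a] * T[a][b];
        for (std::size_t b = 0; b < d; ++b) R[i][b] *= W[i];
    }
    return R;
}
```

With n = 10^4 and d = 20, each of the two phases is about n·d² = 4 × 10^6 multiply-adds, versus roughly 2 × 10^9 for building $Q \times K^T$ alone.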

Full-score C++ solution:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 10010, D = 30;
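LL tmp[D][D], ans[N][D];  // sums reach ~10^10 and results ~2*10^17, so long long; ans needs only d columns (ans[N][N] would reserve ~800 MB)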
int n, d;
int Q[N][D], K[N][D], V[N][D], W[N];
int main()
{
    ios::sync_with_stdio(false);  // faster cin/cout for ~6 * 10^5 input numbers
    cin.tie(nullptr);
    cin >> n >> d;
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= d; j ++)
            cin >> Q[i][j];
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= d; j ++)
            cin >> K[i][j];
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= d; j ++)
            cin >> V[i][j];
    for (int i = 1; i <= n; i ++) cin >> W[i];
    
    // tmp = K^T * V (d x d): tmp[i][j] = sum over k of K[k][i] * V[k][j]
    for (int i = 1; i <= d; i ++)
        for (int j = 1; j <= d; j ++)
            for (int k = 1; k <= n; k ++)
                tmp[i][j] += K[k][i] * V[k][j];
                
    // ans = Q * tmp (n x d), then every element of row i is multiplied by W[i]
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= d; j ++)
        {
            for (int k = 1; k <= d; k ++)
                ans[i][j] += Q[i][k] * tmp[k][j];
            ans[i][j] *= (LL) W[i];
        }
        
    for (int i = 1; i <= n; i ++)
    {
        for (int j = 1; j <= d; j ++)
            cout << ans[i][j] << " ";
        cout << '\n';  // '\n' avoids flushing the output on every line
    }
    return 0;
}


Origin blog.csdn.net/weixin_53919192/article/details/131490291