CCF-CSP Zhenti「202305-2 マトリックス演算」アイデア + Python、C++ フルスコア ソリューション

実際の質問と他の質問の解決策を確認したい学生は、CCF-CSP の実際の質問と解決策にアクセスしてください。


質問番号: 202305-2
質問名: 行列演算
制限時間: 5.0秒
メモリ制限: 512.0MB
問題の説明:

トピックの背景

Softmax(Q×KTd)×V は、Transformer の注目モジュールの中心となる式です。ここで、Q、K、V はすべて n 行 d 列の行列で、KT は行列 K の転置を表し、× は行列の乗算を表します。

問題の説明

計算の便宜上、Dun Dun は Softmax をサイズ n の 1 次元ベクトル W のドット乗算に単純化します:
(W⋅(Q×KT))×V
ドット乗算は対応するビットの乗算であり、W(i) は次のようになります。ベクトル W の i 番目の要素、つまり (Q×KT) の i 行目の各要素に W(i) が掛けられます。

ここで、行列 Q、K、V とベクトル W が与えられたとして、簡略化された公式に従って計算された結果を計算してみます。

入力フォーマット

標準入力からデータを読み取ります。

入力の最初の行には、行列のサイズを示すスペースで区切られた 2 つの正の整数 n と d が含まれています。

次に、行列 Q、K、V が順番に入力されます。各行列に n 行を入力します。各行にはスペースで区切られた d 個の整数が含まれます。i 行目の j 番目の数値は行列の i 行、j 列に対応します。

最後の行は、ベクトル W を表す n 個の整数を入力します。

出力フォーマット

標準出力に出力します。

合計 n 行を出力します。各行には、計算結果を表すスペースで区切られた d 個の整数が含まれます。

サンプル入力

3 2
1 2 3
4
5 6
10 10
-20 -20
30 30
6
5 4
3 2 1
4 0 -5

サンプル出力

480 240
0 0
-2200 -1100

サブタスク

70 のテスト データは、n ≤ 100 および d ≤ 10 を満たします。入力行列とベクトルの要素はすべて整数であり、絶対値は 30 を超えません。

すべてのテスト データは次の条件を満たします: n≤104 および d≤20; 入力行列とベクトルの要素はすべて整数であり、絶対値は 1000 を超えません。

ヒント

行列の乗算後の値の範囲を慎重に評価し、適切なデータ型を使用して行列に整数を格納します。

本当の質問のソース:行列演算

 興味のある学生は、この方法でコードを作成して提出練習を行うことができます。

アイデアの説明:

この問題はそれほど難しいものではなく、法則を紙に押し込めば循環計算の法則を見つけることができます。この質問の焦点は時間計算量にあります。最初に QK 行列を乗算すると、n*n 行列が得られ、タイムアウトが表示されるため、最初に後者の 2 つの行列を計算する必要があります。時間計算量は次のようになります。合格した。

C++ フルスコア ソリューション:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 10010, D = 30;
LL tmp[D][D], ans[N][N];
int n, d;
int Q[N][D], K[N][D], V[N][D], W[N];
int main()
{
    cin >> n >> d;
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= d; j ++)
            cin >> Q[i][j];
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= d; j ++)
            cin >> K[i][j];
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= d; j ++)
            cin >> V[i][j];
    for (int i = 1; i <= n; i ++) cin >> W[i];
    
	// 计算 Q * V = tmp
    for (int i = 1; i <= d; i ++)
        for (int j = 1; j <= d; j ++)
            for (int k = 1; k <= n; k ++)
                tmp[i][j] += K[k][i] * V[k][j];
                
    // 计算 K * tmp = ans
    for (int i = 1; i <= n; i ++)
        for (int j = 1; j <= d; j ++)
        {
            for (int k = 1; k <= d; k ++)
                ans[i][j] += Q[i][k] * tmp[k][j];
            ans[i][j] *= (LL) W[i];
        }
        
    for (int i = 1; i <= n; i ++)
    {
        for (int j = 1; j <= d; j ++)
            cout << ans[i][j] << " ";
        cout << endl;
    }
    return 0;
}

 操作結果:

おすすめ

転載: blog.csdn.net/weixin_53919192/article/details/131490291