実際の質問と他の質問の解決策を確認したい学生は、CCF-CSP の実際の質問と解決策にアクセスしてください。
質問番号: | 202305-2 |
質問名: | 行列演算 |
制限時間: | 5.0秒 |
メモリ制限: | 512.0MB |
問題の説明: | トピックの背景Softmax(Q×KTd)×V は、Transformer の注目モジュールの中心となる式です。ここで、Q、K、V はすべて n 行 d 列の行列で、KT は行列 K の転置を表し、× は行列の乗算を表します。 問題の説明計算の便宜上、Dun Dun は Softmax をサイズ n の 1 次元ベクトル W のドット乗算に単純化します: ここで、行列 Q、K、V とベクトル W が与えられたとして、簡略化された公式に従って計算された結果を計算してみます。 入力フォーマット標準入力からデータを読み取ります。 入力の最初の行には、行列のサイズを示すスペースで区切られた 2 つの正の整数 n と d が含まれています。 次に、行列 Q、K、V が順番に入力されます。各行列に n 行を入力します。各行にはスペースで区切られた d 個の整数が含まれます。i 行目の j 番目の数値は行列の i 行、j 列に対応します。 最後の行は、ベクトル W を表す n 個の整数を入力します。 出力フォーマット標準出力に出力します。 合計 n 行を出力します。各行には、計算結果を表すスペースで区切られた d 個の整数が含まれます。 サンプル入力
サンプル出力
サブタスク70 のテスト データは、n ≤ 100 および d ≤ 10 を満たします。入力行列とベクトルの要素はすべて整数であり、絶対値は 30 を超えません。 すべてのテスト データは次の条件を満たします: n≤104 および d≤20; 入力行列とベクトルの要素はすべて整数であり、絶対値は 1000 を超えません。 ヒント行列の乗算後の値の範囲を慎重に評価し、適切なデータ型を使用して行列に整数を格納します。 |
本当の質問のソース:行列演算
興味のある学生は、この方法でコードを作成して提出練習を行うことができます。
アイデアの説明:
この問題はそれほど難しいものではなく、法則を紙に押し込めば循環計算の法則を見つけることができます。この質問の焦点は時間計算量にあります。最初に QK 行列を乗算すると、n*n 行列が得られ、タイムアウトが表示されるため、最初に後者の 2 つの行列を計算する必要があります。時間計算量は次のようになります。合格した。
C++ フルスコア ソリューション:
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 10010, D = 30;
LL tmp[D][D], ans[N][N];
int n, d;
int Q[N][D], K[N][D], V[N][D], W[N];
int main()
{
cin >> n >> d;
for (int i = 1; i <= n; i ++)
for (int j = 1; j <= d; j ++)
cin >> Q[i][j];
for (int i = 1; i <= n; i ++)
for (int j = 1; j <= d; j ++)
cin >> K[i][j];
for (int i = 1; i <= n; i ++)
for (int j = 1; j <= d; j ++)
cin >> V[i][j];
for (int i = 1; i <= n; i ++) cin >> W[i];
// 计算 Q * V = tmp
for (int i = 1; i <= d; i ++)
for (int j = 1; j <= d; j ++)
for (int k = 1; k <= n; k ++)
tmp[i][j] += K[k][i] * V[k][j];
// 计算 K * tmp = ans
for (int i = 1; i <= n; i ++)
for (int j = 1; j <= d; j ++)
{
for (int k = 1; k <= d; k ++)
ans[i][j] += Q[i][k] * tmp[k][j];
ans[i][j] *= (LL) W[i];
}
for (int i = 1; i <= n; i ++)
{
for (int j = 1; j <= d; j ++)
cout << ans[i][j] << " ";
cout << endl;
}
return 0;
}
操作結果: