题目大意
一个 行 列的矩阵,每个位置有权值
给定一个出发点,每次可以等概率的移动到一个权值小于当前点权值的点,同时得分加上两个点之间欧几里得距离的平方(欧几里得距离: ),问得分的期望
solution
考虑期望
很容易写出转移方程:
为
的元素个数
直接求解复杂度为 ,考虑优化
把 展开,得到
很明显在根据 排序后含 每一项都可以利用前缀和优化
可以用费马小定理 求逆元。
复杂度
注意细节,由于只能向小于的转移,所以转移时要分层转移,在转移完这一层后更新前缀和。
表示的是比 小的个数,不是在 前有多少个
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <bitset>
#include <vector>
#include <queue>
#include <set>
#include <map>
#include <cmath>
using namespace std;
#ifdef WIN32
#define LLIO "%I64d"
#else
#define LLIO "%lld"
#endif
struct FastIO {
inline FastIO& operator >> (int& x) {
x = 0; char f = 0, ch = getchar();
while(ch > '9' || ch < '0') f |= (ch == '-'), ch = getchar();
while(ch <= '9' && ch >= '0') x = x * 10 + ch - 48, ch = getchar();
return x = (f ? -x : x), *this;
}
inline FastIO& operator >> (long long& x) {
x = 0; char f = 0, ch = getchar();
while(ch > '9' || ch < '0') f |= (ch == '-'), ch = getchar();
while(ch <= '9' && ch >= '0') x = x * 10 + ch - 48, ch = getchar();
return x = (f ? -x : x), *this;
}
inline FastIO& operator >> (double& x) {
x = 0; char f = 0, ch = getchar();
double d = 0.1;
while(ch > '9' || ch < '0') f |= (ch == '-'), ch = getchar();
while(ch <= '9' && ch >= '0') x=x * 10 + ch - 48, ch = getchar();
if(ch == '.') {
ch = getchar();
while(ch <= '9' && ch >= '0') x += d * (ch ^ 48), d *= 0.1, ch = getchar();
}
return x = (f ? -x : x), *this;
}
}rin;
const int N = 1e6 + 50, mod = 998244353;
struct node {
long long x, y, v;
}m[N];
bool operator < (const node &a, const node &b) {
return a.v < b.v;
}
int n, k, r, c, tot, ans, st, smal, nsmal;
long long sum, sumx, sumy, sumxp, sumyp;
long long nsum, nsumx, nsumy, nsumxp, nsumyp;
long long f[N];
long long ksm(long long x, int p) {
long long ans = 1;
while(p) {
if(p & 1) ans = ans * x % mod;
x = x * x % mod;
p >>= 1;
}
return ans % mod;
}
signed main() {
rin >> n >> k;
for(int i = 1 ; i <= n ; ++ i) {
for(int j = 1 ; j <= k ; ++ j) {
rin >> m[++ tot].v;
m[tot].x = i;
m[tot].y = j;
}
}
rin >> r >> c;
sort(m + 1, m + 1 + tot);
for(int i = 1 ; i <= tot ; ++ i) {
if(m[i].v != m[i - 1].v && i != 1) break;
f[i] = 0;
nsumx = (nsumx + m[i].x) % mod;
nsumy = (nsumy + m[i].y) % mod;
nsumxp = (nsumxp + m[i].x * m[i].x % mod) % mod;
nsumyp = (nsumyp + m[i].y * m[i].y % mod) % mod;
++ st;
++ nsmal;
if(m[i].x == r && m[i].y == c) ans = i;
}
for(int i = st + 1 ; i <= tot ; ++ i) {
if(m[i].v != m[i - 1].v) {
sum = (sum + nsum) % mod; nsum = 0;
sumx = (sumx + nsumx) % mod; nsumx = 0;
sumy = (sumy + nsumy) % mod; nsumy = 0;
sumxp = (sumxp + nsumxp) % mod; nsumxp = 0;
sumyp = (sumyp + nsumyp) % mod; nsumyp = 0;
smal = nsmal + smal; nsmal = 0;
}
f[i] = (((((sum \
+ m[i].x * m[i].x % mod * (smal) % mod \
- 2 * m[i].x * sumx % mod \
+ sumxp \
+ sumyp \
+ m[i].y * m[i].y % mod * (smal) % mod \
- 2 * m[i].y * sumy % mod) + mod) % mod) \
* ksm(smal, mod - 2) % mod) % mod + mod) % mod;
if(m[i].x == r && m[i].y == c) ans = i;
nsum = (nsum + f[i]) % mod;
nsumx = (nsumx + m[i].x) % mod;
nsumy = (nsumy + m[i].y) % mod;
nsumxp = (nsumxp + m[i].x * m[i].x % mod) % mod;
nsumyp = (nsumyp + m[i].y * m[i].y % mod) % mod;
++ nsmal;
}
printf(LLIO "\n", f[ans] % mod);
return 0;
}