【题解】CF1042E Vasya and Magic Matrix【期望dp】

题目大意

一个 n n m m 列的矩阵,每个位置有权值 a i , j a_{i,j}

给定一个出发点,每次可以等概率的移动到一个权值小于当前点权值的点,同时得分加上两个点之间欧几里得距离的平方(欧几里得距离: ( x 1 x 2 ) 2 + ( y 1 y 2 ) 2 \sqrt{(x_1-x_2)^2+(y_1-y_2)^2} ),问得分的期望

solution

考虑期望 d p dp

很容易写出转移方程: f [ i ] = 1 m v j < v i f [ j ] + d i s ( i , j ) 2 f[i]=\frac{1}{m} \cdot \sum_{v_j<v_i}f[j]+dis(i,j)^2
m m v j < v i v_j<v_i 的元素个数

直接求解复杂度为 O ( n 2 ) O(n^2) ,考虑优化

d i s dis 展开,得到 x i 2 2 x i x j + x j 2 + y i 2 2 y i y j + y j 2 x_i^2-2x_ix_j+x_j^2+y_i^2-2y_iy_j+y_j^2

很明显在根据 v v 排序后含 j j 每一项都可以利用前缀和优化

1 m \frac{1}{m} 可以用费马小定理 O ( l o g n ) O(logn) 求逆元。

复杂度 O ( n l o g n ) O(nlogn)

注意细节,由于只能向小于的转移,所以转移时要分层转移,在转移完这一层后更新前缀和。

m m 表示的是比 v [ i ] v[i] 小的个数,不是在 i i 前有多少个

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <bitset>
#include <vector>
#include <queue>
#include <set>
#include <map>
#include <cmath>

using namespace std;

#ifdef WIN32
#define LLIO "%I64d"
#else
#define LLIO "%lld"
#endif

struct FastIO {
    inline FastIO& operator >> (int& x) {
        x = 0; char f = 0, ch = getchar();
        while(ch > '9' || ch < '0') f |= (ch == '-'), ch = getchar();
        while(ch <= '9' && ch >= '0') x = x * 10 + ch - 48, ch = getchar();
        return x = (f ? -x : x), *this;
    }
    inline FastIO& operator >> (long long& x) {
        x = 0; char f = 0, ch = getchar();
        while(ch > '9' || ch < '0') f |= (ch == '-'), ch = getchar();
        while(ch <= '9' && ch >= '0') x = x * 10 + ch - 48, ch = getchar();
        return x = (f ? -x : x), *this;
    }
    inline FastIO& operator >> (double& x) {
        x = 0; char f = 0, ch = getchar();
        double d = 0.1;
        while(ch > '9' || ch < '0') f |= (ch == '-'), ch = getchar();
        while(ch <= '9' && ch >= '0') x=x * 10 + ch - 48, ch = getchar();
        if(ch == '.') {
            ch = getchar();
            while(ch <= '9' && ch >= '0') x += d * (ch ^ 48), d *= 0.1, ch = getchar();
        }
        return x = (f ? -x : x), *this;
    }
}rin;
const int N = 1e6 + 50, mod = 998244353;
struct node {
	long long x, y, v;
}m[N];
bool operator < (const node &a, const node &b) {
	return a.v < b.v;
}
int n, k, r, c, tot, ans, st, smal, nsmal;
long long sum, sumx, sumy, sumxp, sumyp;
long long nsum, nsumx, nsumy, nsumxp, nsumyp;
long long f[N];
long long ksm(long long x, int p) {
	long long ans = 1;
	while(p) {
		if(p & 1) ans = ans * x % mod;
		x = x * x % mod;
		p >>= 1;
	}
	return ans % mod;
}
signed main() {
	rin >> n >> k; 
	for(int i = 1 ; i <= n ; ++ i) {
		for(int j = 1 ; j <= k ; ++ j) {
			rin >> m[++ tot].v;
			m[tot].x = i;
			m[tot].y = j;
		}
	}
	rin >> r >> c;
	sort(m + 1, m + 1 + tot);
	for(int i = 1 ; i <= tot ; ++ i) {
		if(m[i].v != m[i - 1].v && i != 1) break;
		f[i] = 0;
		nsumx  = (nsumx + m[i].x) % mod;
		nsumy  = (nsumy + m[i].y) % mod;
		nsumxp = (nsumxp + m[i].x * m[i].x % mod) % mod;
		nsumyp = (nsumyp + m[i].y * m[i].y % mod) % mod;
		++ st;
		++ nsmal;
		if(m[i].x == r && m[i].y == c) ans = i;
	}
	for(int i = st + 1 ; i <= tot ; ++ i) {
		if(m[i].v != m[i - 1].v) {
			sum   = (sum   + nsum)   % mod; nsum   = 0;
			sumx  = (sumx  + nsumx)  % mod; nsumx  = 0;
			sumy  = (sumy  + nsumy)  % mod; nsumy  = 0;
			sumxp = (sumxp + nsumxp) % mod; nsumxp = 0;
			sumyp = (sumyp + nsumyp) % mod; nsumyp = 0;
			smal  = nsmal + smal;           nsmal  = 0;
		}
		f[i] = (((((sum \
				+ m[i].x * m[i].x % mod * (smal) % mod \
				- 2 * m[i].x * sumx % mod \
				+ sumxp \
				+ sumyp \
				+ m[i].y * m[i].y % mod * (smal) % mod \
				- 2 * m[i].y * sumy % mod) + mod) % mod) \
				* ksm(smal, mod - 2) % mod) % mod + mod) % mod;
		if(m[i].x == r && m[i].y == c) ans = i;
		nsum   = (nsum   + f[i]) % mod;
		nsumx  = (nsumx  + m[i].x) % mod;
		nsumy  = (nsumy  + m[i].y) % mod;
		nsumxp = (nsumxp + m[i].x * m[i].x % mod) % mod;
		nsumyp = (nsumyp + m[i].y * m[i].y % mod) % mod;
		++ nsmal;
	}
	printf(LLIO "\n", f[ans] % mod);
	return 0;
}	

猜你喜欢

转载自blog.csdn.net/weixin_43933089/article/details/85011387