Codeforces #510-E Vasya and Magic Matrix

在一个n*m的矩阵中给定一个点,可以到达任意数值比它小的点,获得的分数是两点间欧几里得距离的平方,求期望分数。

- 期望的意思是所有可能发生的情况分别乘上它们发生的概率。

- 解法是从小往大求出每个数可以获得的值。可以推一下式子,用一用因式分解就出来了。记录一些奇怪的前缀和即可。

- 注意需要求逆元,每一次加法都要记得取模。

- 没有想到的原因是不熟悉dp,相当于可以直接用以前的结果了。(自己总是会想复杂)

//#include<bits/stdc++.h>
#include<cmath>
#include<cstdio>
#include<assert.h>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
int n,m,R,C;
int p = 998244353;
int dp[1005][1005];
int mul(int a,int b)
{
    return int(a * 1LL * b % p);
}
void upd(int &a,int b){
    a += b;
    while(a >= p) a -= p;
    while(a < 0) a += p;
}
int bp(int a,int b)
{
    int res = 1;
    while(b > 0)
    {
        if(b & 1) res = mul(res,a);
        a = mul(a,a);
        b >>= 1;
    }
    return res;
}
int inv(int a){
    int ia = bp(a,p-2);
    assert(mul(a,ia) == 1);
    return ia;
}
struct data{
    int v,r,c;
}d[1005*1005];


bool cmp(data a, data b)
{
    return a.v < b.v;
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i = 1;i <= n; i++)
    {
        for (int j = 1; j <= m; j++)
        {
            int xx;
            scanf("%d",&xx);
            d[(i-1)*m+j].v = xx;
            d[(i-1)*m+j].r = i;
            d[(i-1)*m+j].c = j;
        }
    }
    scanf("%d%d",&R,&C);
    sort(d+1,d+n*m+1,cmp);
    memset(dp,0,sizeof dp);
    int SumDP = 0, SumR = 0, SumC = 0, SumR2 = 0, SumC2 = 0;
  
    int l = 1;
 //   cout<<mul(2,2)<<" "<<bp(2,2)<<endl;
     
    while(l <= n*m)
    {
        int r = l;
        while(d[r].v == d[l].v && r <= n*m) r++;
      //  cout<<l<<" "<<r<<endl;
       
        int il = -1;
        if(l != 1) il = inv(l-1);
        
        for (int i = l; i < r; i++)
        {
            int rr = d[i].r,cc = d[i].c;
            if(il == -1) 
            {
                dp[rr][cc] = 0;
                continue;
            }
            upd(dp[rr][cc],mul(SumDP, il));
            upd(dp[rr][cc],mul(rr,rr));
            upd(dp[rr][cc],mul(cc,cc));
            upd(dp[rr][cc],mul(SumR2,il));
            upd(dp[rr][cc],mul(SumC2,il));
            upd(dp[rr][cc],mul(mul(-rr-rr,SumR),il));
            upd(dp[rr][cc],mul(mul(-cc-cc,SumC),il));
        }
        for (int i = l; i < r; i++)
        {
            int rr = d[i].r, cc = d[i].c;
          //  cout<<rr<<" "<<cc<<" "<<dp[rr][cc]<<endl;
            upd(SumDP,dp[rr][cc]);
            upd(SumR2,mul(rr,rr));
            upd(SumC2,mul(cc,cc));
            upd(SumR,rr);
            upd(SumC,cc);
          //  cout<<SumDP<<" "<<SumR2<<" "<<SumC2<<" "<<SumR<<" "<<SumC<<endl;
        }
        
        l = r;
    }
    printf("%d\n",dp[R][C]);
    return 0;
}
代码

猜你喜欢

转载自www.cnblogs.com/waxwing/p/9672179.html