CF1042E Vasya and Magic Matrix

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/qcwlmqy/article/details/102478015

CF1042E Vasya and Magic Matrix


题意

一个nm列的矩阵,每个位置有权值 a i , j a_{i,j}

给定一个出发点,每次可以等概率的移动到一个权值小于当前点权值的点,同时得分加上两个点之间欧几里得距离的平方(欧几里得距离: ( x 1 x 2 ) 2 + ( y 1 y 2 ) 2 \sqrt{(x_1-x_2)^2+(y_1-y_2)^2} ,问得分的期望


思路

按照计算期望的一般思路

我们可以考虑先计算小的点的期望,在用小的点的期望计算大的点的期望

我们先从小到大排序

转移方程

d p [ i ] = j a j < a i d p [ j ] + ( x i x j ) 2 + ( y i y j ) 2 dp[i]= \sum_{j}^{a_{j}<a_{i}}{dp[j]+(x_{i}-x_{j})^2+(y_{i}-y_{j})^2}

前缀和优化

  • ( x i x j ) 2 = x i 2 + 2 x i x j + x j 2 (x_{i}-x_{j})^2 =x_{i}^2 +2*x_{i}*x_{j} +x_{j}^2

    所以我们只要知道 x i \sum{x_{i}} x i 2 \sum {x_{i}^2} 就可以O(1)转移

  • 我们维护前缀和 s u m = d p [ j ] sum=\sum{dp[j]} s x = x [ j ] sx=\sum{x[j]} x 2 = x [ j ] 2 x2=\sum{x[j]^2} s y = y [ j ] sy=\sum{y[j]} y 2 = y [ j ] 2 y2=\sum{y[j]^2}

  • d p [ i ] = s u m + s x + s y 2 x [ i ] s x 2 y [ i ] s y + ( x [ i ] 2 + y [ i ] 2 ) ( i ) dp[i]=sum+sx+sy-2*x[i]*sx-2*y[i]*sy+(x[i]^2+y[i]^2)*(小于i的点数)

ans = (sum + x2 + y2) % mod;
LL inv = quickpow(p - 1, mod - 2);
dp[j] = ((ans - sx * 2 * node[j].x - sy * 2 * node[j].y) % mod + mod) % mod;
dp[j] = (dp[j] + LL(p - 1) * node[j].x * node[j].x % mod + LL(p - 1) * node[j].y * node[j].y % mod) % mod;
dp[j] = dp[j] * inv % mod;

代码

#include <bits/stdc++.h>
using namespace std;
struct Node {
    int data;
    int x;
    int y;
    friend bool operator<(const Node& a, const Node& b)
    {
        return a.data < b.data;
    }
};
const int maxn = 1001 * 1001 * 5;
Node node[maxn];
typedef long long LL;
const LL mod = 998244353;
LL quickpow(LL m, LL p)
{
    LL res = 1;
    while (p) {
        if (p & 1)
            res = res * m % mod;
        m = m * m % mod;
        p >>= 1;
    }
    return res;
}
LL dp[maxn];
int main()
{
    int n, m, k = 0;
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            scanf("%d", &node[++k].data);
            node[k].x = i;
            node[k].y = j;
        }
    }
    int ex, ey;
    scanf("%d%d", &ex, &ey);
    sort(node + 1, node + 1 + k);
    bool flag = false;
    LL sx, x2, sy, y2, sum, ans;
    sum = sx = sy = x2 = y2 = 0;
    for (int i = 1; i <= k;) {
        int p = i;
        while (++i <= k && node[i].data == node[p].data)
            ;
        ans = (sum + x2 + y2) % mod;
        LL inv = quickpow(p - 1, mod - 2);
        for (int j = p; j < i; j++) {
            dp[j] = ((ans - sx * 2 * node[j].x - sy * 2 * node[j].y) % mod + mod) % mod;
            dp[j] = (dp[j] + LL(p - 1) * node[j].x * node[j].x % mod + LL(p - 1) * node[j].y * node[j].y % mod) % mod;
            dp[j] = dp[j] * inv % mod;
            if (node[j].x == ex && node[j].y == ey) {
                flag = true;
                ans = dp[j];
                break;
            }
        }
        if (flag)
            break;
        for (int j = p; j < i; j++)
            sum = (sum + dp[j]) % mod,
            sx += node[j].x,
            x2 += LL(node[j].x) * node[j].x,
            sy += node[j].y,
            y2 += LL(node[j].y) * node[j].y;
    }
    printf("%lld\n", ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qcwlmqy/article/details/102478015