洛谷比赛 waaadreamer的圣诞虐题赛题解

比赛链接
本题解同步于洛谷博客

获奖名单

前三:
rank1:@GZY_GZY
rank2:@ljc1301
rank3:@加藤惠
orz三位dalao!tql!

一血榜:
A: @ThinkofBlank(09:05:39)
B: @muller(09:08:57)
C: @GZY_GZY(11:34:05)
D: @东师附中大头(10:57:32)
E: @GZY_GZY(18:29:35获得最高分58pts)

吐槽

谁知道我元旦想放水题放出锅了啊……T2据说是和别人差不多的题……对此我深表歉意。(既然这是基本全场AC的签到题大家就放我一马吧,感激不尽2333)
其它题应该没啥问题。

还有某人特别喜欢抄别人代码,不是我针对你,抄一题就算了,你T3T4T5都是抄的,在这里做个警告,你拿了高分又如何,别以为改了变量名函数名我就看不出来,别以为开了完全隐私保护就NB了。

总体评估

T1,T2简单题,T3,T4稍有难度,T5较难(可能不太准?)部分分放的真的好多好多啊……
好了废话不多说,放题解。

T1 WD与矩阵

subtask1:

暴力即可。复杂度 O ( T 2 n m ) O(T2^{nm})

subtask2:

这个部分分是脑子一抽给的,肯定没人写。直接枚举列,把行的状态状压起来就行。复杂度 O ( T m 2 2 n ) O(Tm2^{2n})

subtask3:

打表应该都能发现,答案就是 2 ( n 1 ) ( m 1 ) 2^{(n-1)(m-1)} ,因为我们可以最后一行和最后一列完全可以根据前面填的值直接算出来,因此任意排列前面的东西即可。复杂度 O ( T l o g n ) O(Tlogn)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int mod = 998244353;
ll modpow(ll a, ll b) {
    ll res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
int main(){
    int n, m, T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &n, &m);
        printf("%lld\n", modpow(2, (ll)(n - 1) * (m - 1)));
    }
    return 0;
}

T2 WD与循环

subtask1:

我们会显然的发现问题转化为求有多少种不同的值使 i = 1 n a i m \sum_{i=1}^na_i\le m 。因此这个直接dp即可。
f [ i ] [ j ] f[i][j] 表示当前到第 i i 个循环,前面值的和为 j j 的方案数,最后就是 i = 0 m f [ n ] [ i ] \sum_{i=0}^mf[n][i] ,复杂度 O ( n m + T ) O(nm+T) .

subtask2:

发现有不等号不太方便,考虑如果是等号怎么做。这是显然的插板法(如果不知道可以度娘……),也就是说 i = 1 n a i = m \sum_{i=1}^na_i=m 的非负解共有 ( n + m 1 m 1 ) \binom{n+m-1}{m-1} 组,因为我们可以视为在 n n 个数中加上 m 1 m-1 个可相邻的隔板,相邻两个隔板之间的长度就代表了对应的数字。
因此我们预处理阶乘,直接枚举 m m 计算对应的组合数,然后求和即可。复杂度 O ( T m + m o d ) O(Tm+mod)

subtask3:

有个结论,就是:
i = q n ( i q ) = ( n + 1 q + 1 ) \sum_{i=q}^n \binom{i}{q}=\binom{n+1}{q+1}
简略证明如下:
i = q n ( i q ) = ( n q ) + ( n q + 1 ) i = q n 1 ( i q ) = ( n q + 1 ) \sum_{i=q}^n\binom{i}{q}=\binom{n}{q}+\binom{n}{q+1}即\sum_{i=q}^{n-1}\binom{i}{q}=\binom{n}{q+1}
归纳即可。
因此原题的答案直接就是 ( n + m m ) \binom{n+m}{m} ,直接上卢卡斯定理即可。复杂度 O ( T + m o d ) O(T+mod) 。(假定卢卡斯求组合数为 O ( 1 ) O(1) )。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int mod = 19491001;
ll fac[mod], rev[mod], n, m;
ll modpow(ll a, int b) {
    ll res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
ll lucas(ll a, ll b) {
    if (a < b) return 0;
    if (a < mod) return fac[a] * rev[b] % mod * rev[a - b] % mod;
    return lucas(a / mod, b / mod) * lucas(a % mod, b % mod) % mod;
}
int main() {
    for (int i = fac[0] = 1; i < mod; i++) fac[i] = fac[i - 1] * i % mod;
    rev[mod - 1] = modpow(fac[mod - 1], mod - 2);
    for (int i = mod - 1; i > 0; i--) rev[i - 1] = rev[i] * i % mod;
    int T; scanf("%d", &T);
    while (T--) {
        scanf("%lld%lld", &n, &m);
        printf("%lld\n", lucas(n + m, m));
    }
    return 0;
}

T3 WD与数列

subtask1:

似乎 O ( n 3 ) O(n^3) 暴力直接可以把subtask2艹过去?luogu还是太快了啊!这个应该不用多说吧,直接暴力枚举一对开头,然后暴力向后扩展直到相减的值改变了。

subtask2:

本来这个点是给hash的……又怕hash常数大被卡,于是就只有1000的数据量……不管不管送分大甩卖!
(我看你是送分大甩锅吧2333)
首先可以想到差分整个序列,然后问题转化为求不相交不相邻的相等子串对数。
就是枚举答案的长度,然后从左往右跑,对于某个串,把他左边和他不相交或相邻的串的hash加到hash_map里,直接看有多少串的hash和它相同即可。复杂度 O ( n 2 ) O(n^2)

subtask3:

嘿嘿嘿这个点没法水过去了吧……
做法1: 肯定还是差分整个数列,接下来考虑怎么做。如果去掉不相交不相邻的这个限制直接SA+单调栈即可(或者SAM,但是慢啊)。接下来考虑如何把多算的减掉。
我们会发现多算的一定靠的很近(废话),可以考虑使用NOI2016优秀的拆分那题的方法,也就是使用调和级数。
F0fokD.png
我们考虑枚举两个串的偏移位置(即图中的 k k ),然后每 k k 位打一个点,可以发现我们只需要让这两个串在它们第一次出现红点的位置被计算一次即可。
考虑计算两个红点为结尾的LCS和为开头的LCP,由于我们要让每对串只能在第一个红点处被计算,因此LCS要和 k k 取个min。然后就是细节功底了……考虑暴力怎么做,我们可以枚举两个串的开头位置 a , a + k a,a+k ,并且由于两个串要相邻或相交,则第一个串的结尾 b a + k 1 b\ge a+k-1 。于是会发现最后实际上是计算一个等差数列求和的东西,就可以 O ( 1 ) O(1) 算了。注意边界即可。
因此总复杂度为 O ( n l o g n ) O(nlogn) ,LCP和LCS可以用st表 O ( 1 ) O(1) 算。常数比较大2333……

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int maxn = 500005, maxr = 10000000;
int fst[maxn], sec[maxn], cnt[maxn], srt[maxn], seq[maxn], lg[maxn], n;
struct SA {
    int lcp[maxn], sa[maxn], rnk[maxn], st[20][maxn];
    int cmp(int a, int b, int i) {
        if (a + i * 2 > n + 1 || b + i * 2 > n + 1) return 1;
        return sec[a] != sec[b] || sec[a + i] != sec[b + i];
    }
    void init_sa(int m) {
        memset(cnt, 0, sizeof(cnt));
        for (int i = 1; i <= n; i++) ++cnt[fst[i] = seq[i]];
        for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
        for (int i = n; i > 0; i--) sa[cnt[fst[i]]--] = i;
        for (int i = 1; i < n; i <<= 1) {
            int p = 0;
            for (int j = n - i + 1; j <= n; j++) sec[++p] = j;
            for (int j = 1; j <= n; j++) if (sa[j] > i) sec[++p] = sa[j] - i;
            memset(cnt, 0, sizeof(cnt));
            for (int j = 1; j <= n; j++) ++cnt[srt[j] = fst[sec[j]]];
            for (int j = 1; j <= m; j++) cnt[j] += cnt[j - 1];
            for (int j = n; j > 0; j--) sa[cnt[srt[j]]--] = sec[j];
            memcpy(sec, fst, sizeof(sec));
            p = fst[sa[1]] = 1;
            for (int j = 2; j <= n; j++) fst[sa[j]] = (p += cmp(sa[j - 1], sa[j], i));
            m = p;
        }
        for (int i = n + 1; i > 1; i--) sa[i] = sa[i - 1];
        sa[1] = n + 1;
    }
    void init(int m) {
        init_sa(m);
        for (int i = 1; i <= n + 1; i++) rnk[sa[i]] = i;
        int h = st[0][1] = lcp[1] = 0;
        for (int i = 1; i <= n; i++) {
            if (h > 0) --h;
            int j = sa[rnk[i] + 1];
            while (h + i <= n && h + j <= n && seq[h + i] == seq[h + j]) ++h;
            st[0][rnk[i]] = lcp[rnk[i]] = h;
        }
        for (int i = 1; i < 20; i++)
        for (int j = 1; j + (1 << i) - 1 <= n; j++)
            st[i][j] = min(st[i - 1][j], st[i - 1][j + (1 << i >> 1)]);
    }
    int get_lcp(int a, int b) {
        assert(a != b);
        a = rnk[a], b = rnk[b];
        if (a > b) swap(a, b);
        int l = lg[b - a];
        return min(st[l][a], st[l][b - (1 << l)]);
    }
} pre, suf;
int main(){
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%d", seq + i);
    for (int i = 1; i < n; i++) seq[i] = seq[i + 1] - seq[i];
    --n; for (int i = 1; i <= n; i++) srt[i] = seq[i];
    for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
    sort(srt + 1, srt + 1 + n);
    int m = unique(srt + 1, srt + 1 + n) - srt - 1;
    for (int i = 1; i <= n; i++)
        seq[i] = lower_bound(srt + 1, srt + 1 + m, seq[i]) - srt;
    pre.init(m);
    ll res = 0, sum = 0; int top = 0;
    for (int i = 1; i <= n + 1; i++) {
        res += sum;
        for (; top > 0 && pre.lcp[srt[top]] >= pre.lcp[i]; --top)
            sum -= (ll)(pre.lcp[srt[top]] - pre.lcp[i]) * (srt[top] - srt[top - 1]);
        sum += pre.lcp[srt[++top] = i];
    }
    reverse(seq + 1, seq + 1 + n);
    suf.init(m);
    for (int i = 1; i <= n; i++) {
        for (int j = i; j + i <= n; j += i) {
            int a = min(i, suf.get_lcp(n - j + 1, n - j - i + 1));
            int b = pre.get_lcp(j, j + i), l = min(a - 1, b + a - i);
            if (l >= 0) res -= (ll)(l + 1) * (b + a - i) - (ll)l * (l + 1) / 2;
        }
    }
    printf("%lld\n", res + (ll)n * (n + 1) / 2);
    return 0;
}

做法2: 我们考虑差分完数列后建出后缀自动机。我们让两个相等且不相交的串在它们的右端点处被统计,于是沿parent树向上线段树合并出每个节点的right集,同时利用启发式合并计算出一个节点对另一个节点的贡献,这个可以直接线段树查询位置和。总复杂度 O ( n l o g 2 n ) O(nlog^2n) ,但由于线段树/启发式合并常数小+不好卡,可以很快地通过此题。
这里放上@GZY_GZY 的代码。

// luogu-judger-enable-o2
#include<unordered_map>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<queue>
#include<stack>
#include<set>
#include<map>
#include<cassert>
#include<vector>
using namespace std;
#define ll long long
#define REP(i,a,b) for(int i=(a),_end_=(b);i<=_end_;i++)
#define DREP(i,a,b) for(int i=(a),_end_=(b);i>=_end_;i--)
#define EREP(i,a) for(int i=start[(a)];i;i=e[i].next)
template<class T>inline void chkmax(T &a,T b){ if(a<b)a=b;}
template<class T>inline void chkmin(T &a,T b){ if(a>b)a=b;}
#define fi first
#define se second
#define mkr(a,b) make_pair(a,b)
inline int read()
{
    int sum=0,p=1;char ch=getchar();
    while(!(('0'<=ch && ch<='9') || ch=='-'))ch=getchar();
    if(ch=='-')p=-1,ch=getchar();
    while('0'<=ch && ch<='9')sum=sum*10+ch-48,ch=getchar();
    return sum*p;
}

const int maxn=3e5+20;

struct node {
    int len,fa;
    unordered_map<int,int>ch;
};
node t[maxn<<1];
int tot,last[maxn];

int n,a[maxn];
ll ans;

inline void Add(int x,int pos)
{
    int np=last[pos]=++tot,p=last[pos-1];
    while(p && !t[p].ch.count(x))t[p].ch[x]=np,p=t[p].fa;
    if(!p)t[np].fa=1;
    else {
        int q=t[p].ch[x];
        if(t[q].len==t[p].len+1)t[np].fa=q;
        else {
            int nq=++tot; t[nq]=t[q];
            t[nq].len=t[p].len+1;
            t[q].fa=t[np].fa=nq;
            while(p && t[p].ch[x]==q)t[p].ch[x]=nq,p=t[p].fa;
        }
    }
}

struct Node {
    int v,next;
};
Node e[maxn<<1];
int cnt,start[maxn<<1];
inline void addedge(int u,int v){ e[++cnt]=(Node){v,start[u]};start[u]=cnt;}

struct NODE {
    int s;
    ll xs;
    int ls,rs;
};
NODE c[maxn<<6];
int ctot,rt[maxn<<1];

void update(int x,int &o,int l,int r)
{
    if(!o)o=++ctot;
    c[o].s++; c[o].xs+=x;
    if(l==r)return;
    int mid=l+r>>1;
    if(x<=mid)update(x,c[o].ls,l,mid);
    else update(x,c[o].rs,mid+1,r);
}

vector<int>tmp[maxn<<1];
vector<int>*f[maxn<<1];

int querys(int ql,int qr,int o,int l,int r)
{
    if(ql>qr || !o || qr<1)return 0;
    if(ql<=l && r<=qr)return c[o].s;
    int mid=l+r>>1,ans=0;
    if(ql<=mid)ans+=querys(ql,qr,c[o].ls,l,mid);
    if(qr>mid)ans+=querys(ql,qr,c[o].rs,mid+1,r);
    return ans;
}

ll queryxs(int ql,int qr,int o,int l,int r)
{
    if(ql>qr || !o || qr<1 || ql>n)return 0;
    if(ql<=l && r<=qr)return c[o].xs;
    int mid=l+r>>1; ll ans=0;
    if(ql<=mid)ans+=queryxs(ql,qr,c[o].ls,l,mid);
    if(qr>mid)ans+=queryxs(ql,qr,c[o].rs,mid+1,r);
    return ans;
}

int Merge(int x,int y)
{
    if(!x || !y)return x|y;
    c[x].s+=c[y].s;
    c[x].xs+=c[y].xs;
    c[x].ls=Merge(c[x].ls,c[y].ls);
    c[x].rs=Merge(c[x].rs,c[y].rs);
    return x;
}

void dfs(int u)
{
    int k=t[u].len;
    EREP(i,u)
    {
        int v=e[i].v;
        dfs(v);
        if(f[u]->size()<f[v]->size())swap(f[u],f[v]),swap(rt[u],rt[v]);
        for(int x:(*f[v]))
        {
            ll A=1ll*querys(x-k-1,x-2,rt[u],1,n)*(x-1)-queryxs(x-k-1,x-2,rt[u],1,n)+1ll*querys(1,x-k-2,rt[u],1,n)*k;
            ll B=queryxs(x+2,x+k+1,rt[u],1,n)-1ll*querys(x+2,x+k+1,rt[u],1,n)*(x+1) + 1ll*querys(x+k+2,n,rt[u],1,n)*k;
            ans+=A+B;
        }
        for(int x:(*f[v]))f[u]->push_back(x);
        rt[u]=Merge(rt[u],rt[v]);
    }
}

inline void init()
{
    n=read();
    REP(i,1,n)a[i]=read();
    REP(i,1,n-1)a[i]=a[i+1]-a[i];
    ans=(ll)n*(n-1)>>1;
    n--;
    last[0]=tot=1;
    REP(i,1,n)Add(a[i],i);
    REP(i,2,tot)addedge(t[i].fa,i);
    REP(i,1,tot)f[i]=&tmp[i];
    REP(i,1,n)f[last[i]]->push_back(i),update(i,rt[last[i]],1,n);
    dfs(1);
}

inline void doing()
{
    printf("%lld\n",ans);
}

int main()
{
#ifndef ONLINE_JUDGE
    freopen("C.in","r",stdin);
    freopen("C.out","w",stdout);
#endif
    init();
    doing();
    return 0;
}


T4 WD与积木

subtask1:

显然可以dp。由于等概率下期望=总和/方案数,我们令 f [ i ] f[i] 表示有 i i 个积木时产生层数的总和, g [ i ] g[i] 表示 i i 个积木产生不同堆法的数量,则答案就是 f [ n ] g [ n ] \frac{f[n]}{g[n]} 。显然,我们通过枚举第一层是哪些积木就可以dp了:
f [ n ] = g [ n ] + i = 1 n ( n i ) f [ n i ] , f [ 0 ] = 0 f[n]=g[n]+\sum_{i=1}^n\binom nif[n-i],f[0]=0
g [ n ] = i = 1 n ( n i ) g [ n i ] , g [ 0 ] = 1 g[n]=\sum_{i=1}^n\binom nig[n-i],g[0]=1
复杂度 O ( n 2 ) O(n^2)

subtask2:

我们考虑答案的组合意义。实际上是把这些积木划分成了一些集合,然后任意排列这些集合求方案数和总和。
也就是说我们实际上求的是这个东西:
s u m = i = 1 n S ( n , i ) i ! i , c n t = i = 1 n S ( n , i ) i ! = s u m c n t sum=\sum_{i=1}^nS(n,i)\cdot i!\cdot i,cnt=\sum_{i=1}^nS(n,i)\cdot i!,那么答案=\frac{sum}{cnt}
S ( n , i ) S(n,i) 是第二类斯特林数。因此直接NTT预处理第二类斯特林数即可,复杂度 O ( T n l o g n ) O(Tnlogn)

subtask3:

考虑如何优化subtask1中的dp做法。不妨设:
F ( x ) = i f [ i ] i ! x i , G ( x ) = i g [ i ] i ! x i , H ( x ) = i x i i ! F(x)=\sum_i\frac{f[i]}{i!}x^i,G(x)=\sum_i\frac{g[i]}{i!}x^i,H(x)=\sum_i\frac{x^i}{i!}
则很容易得到 2 F ( x ) = G ( x ) + F ( x ) H ( x ) 1 , 2 G ( x ) = G ( x ) H ( x ) + 1 2F(x)=G(x)+F(x)H(x)-1,2G(x)=G(x)H(x)+1 ,不难解得:
F ( x ) = G ( x ) ( G ( x ) 1 ) , G ( x ) = 1 2 H ( x ) F(x)=G(x)(G(x)-1),G(x)=\frac{1}{2-H(x)}
多项式求逆即可,复杂度 O ( n l o g n ) O(nlogn)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int maxn = 1 << 18, mod = 998244353, G = 3, maxr = 10000000;
char str[maxr], prt[maxr];
int rpos, ppos, mmx;
char readc(){
    if(!rpos) mmx = fread(str, 1, maxr, stdin);
    char c = str[rpos++];
    if(rpos == maxr) rpos = 0;
    if(rpos > mmx) return 0;
    return c;
}
int read(){
    int x; char c;
    while((c = readc()) < '0' || c > '9');
    x = c - '0';
    while((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
    return x;
}
void print(ll x){
    if(x){
        static char sta[20];
        int tp = 0;
        for(; x; x /= 10) sta[tp++] = x % 10 + '0';
        while(tp > 0) prt[ppos++] = sta[--tp];
    } else prt[ppos++] = '0';
    prt[ppos++] = '\n';
}
ll fac[maxn], rev[maxn], A[maxn], B[maxn], temp[maxn], F[maxn];
ll modpow(ll a, int b) {
    ll res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
void rader(ll *a, int n) {
    for (int i = 1, j = n >> 1; i < n - 1; i++) {
        if(i < j) swap(a[i], a[j]);
        int k = n >> 1;
        for (; j >= k; k >>= 1) j -= k;
        if(j < k) j += k;
    }
}
void ntt(ll *a, int n, int rev) {
    rader(a, n);
    for (int h = 2; h <= n; h <<= 1) {
        ll wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
        int hh = h >> 1;
        for (int i = 0; i < n; i += h)
        for (int j = i, w = 1; j < i + hh; j++, w = w * wn % mod) {
            int x = a[j], y = a[j + hh] * w % mod;
            a[j] = (x + y) % mod;
            a[j + hh] = (x - y + mod) % mod;
        }
    }
    if (rev) {
        int inv = modpow(n, mod - 2);
        for (int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
    }
}
void get_inv(ll *a, ll *b, int n) {
    if (n == 1) {b[0] = modpow(a[0], mod - 2); b[1] = 0; return;}
    get_inv(a, b, n >> 1);
    for (int i = 0; i < n; i++) temp[i] = a[i];
    int t = n << 1;
    for (int i = n; i < t; i++) temp[i] = 0;
    ntt(temp, t, 0), ntt(b, t, 0);
    for (int i = 0; i < t; i++)
        b[i] = (2 - b[i] * temp[i] % mod + mod) * b[i] % mod;
    ntt(b, t, 1);
    for(int i = n; i < t; i++) b[i] = 0;
}
int main() {
    int n = 100000, T = read();
    for (int i = fac[0] = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
    rev[n] = modpow(fac[n], mod - 2);
    for (int i = n; i > 0; i--) rev[i - 1] = rev[i] * i % mod;
    for (int i = 0; i <= n; i++) A[i] = mod - rev[i];
    A[0] = (A[0] + 2) % mod;
    int len = 1;
    while (len <= n) len <<= 1;
    get_inv(A, B, len);
    for (int i = 0; i < len; i++) A[i] = F[i] = B[i];
    A[0] = (A[0] + mod - 1) % mod;
    ntt(A, len << 1, 0), ntt(F, len << 1, 0);
    for (int i = 0; i < len << 1; i++) F[i] = F[i] * A[i] % mod;
    ntt(F, len << 1, 1);
    while (T--) {
        n = read();
        print(F[n] * modpow(B[n], mod - 2) % mod);
    }
    fwrite(prt, 1, ppos, stdout);
    return 0;
}

T5 WD与地图

subtask1:

显然,一个地区是原图的一个强连通分量,那么我们就在每次删除操作之后暴力跑一遍tarjan,然后对于每个强连通分量建立一颗权值线段树,节点记录权值的出现次数和节点上所有数字的和。每次修改就是把一个位置的出现次数-1,另一个位置+1;每次询问在线段树上二分一下就好了。复杂度 O ( Q l o g n + O(Qlogn+ 删除次数 × m l o g n ) \times mlogn) .

subtask2:

把所有操作倒序,这样我们考虑如何维护加边操作后的强联通分量。比如连了一条a到b的边,可以暴力从b搜到a,沿途把所有点的权值线段树合并起来。这样做的复杂度是 O ( m × + n l o g n + Q l o g n ) O(m\times 删除次数+nlogn+Qlogn) 的,瓶颈在于暴力从b跑到a进行缩点。

subtask3:

做法2的瓶颈就在于无法快速求出每条边在什么时刻被缩起来的,这个东西的确不好直接用数据结构维护,我们可以考虑整体二分。
每次二分出一个 m i d mid ,把出现时间小于等于 m i d mid 的边全部加入到图里求一遍tarjan,就可以知道哪些边在 m i d \le mid 的时间内被缩掉了。具体的,我们还需要一个可撤销并查集来维护一下在当前有多少点已经被缩到一起,接下来就直接是整体二分的事情了。总复杂度 O ( Q l o g n α ( n ) ) O(Qlogn\cdot\alpha(n))

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> P;

const int maxn = 100005, maxm = 200005, maxr = 10000000, maxt = 10000005;
char str[maxr], prt[maxr];
int rpos, ppos, mmx;
char readc(){
    //return getchar();
    if(!rpos) mmx = fread(str, 1, maxr, stdin);
    char c = str[rpos++];
    if(rpos == maxr) rpos = 0;
    if(rpos > mmx) return 0;
    return c;
}
int read(){
    int x; char c;
    while((c = readc()) < '0' || c > '9');
    x = c - '0';
    while((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
    return x;
}
void print(ll x){
    if(x){
        static char sta[20];
        int tp = 0;
        for(; x; x /= 10) sta[tp++] = x % 10 + '0';
        while(tp > 0) prt[ppos++] = sta[--tp];
    } else prt[ppos++] = '0';
    prt[ppos++] = '\n';
}
struct Edge { int to, next; } edge[maxm];
struct Graph { int u, v, t; } gra[maxm], cpy[maxm];
int head[maxn], par[maxn], val[maxn], n, m, Q, tot;
int type[maxm], arga[maxm], argb[maxm], arr[maxm * 2];
vector<Graph> gath[maxm];
int dfn[maxn], low[maxn], sta[maxn], ins[maxn], top;
int find(int x) { return x == par[x] ? x : par[x] = find(par[x]); }
void merge(int x, int y) {
    x = find(x), y = find(y);
    if (x == y) return;
    if (x > y) swap(x, y);
    par[x] = y;
}
void tarjan(int u) {
    dfn[u] = low[u] = ++tot;
    ins[sta[++top] = u] = 1;
    for (int i = head[u]; i; i = edge[i].next) {
        int v = edge[i].to;
        if (!dfn[v]) tarjan(v), low[u] = min(low[u], low[v]);
        else if (ins[v]) low[u] = min(low[u], dfn[v]);
    }
    if (low[u] == dfn[u]) {
        for (int x = 0; x != u; x = sta[top--])
            merge(sta[top], u), ins[sta[top]] = 0;
    }
}
void divide(int l, int r, int ql, int qr) {
    if (l == r) {
        for (int i = ql; i <= qr; i++) gath[l].push_back(gra[i]);
        return;
    } else if (ql == qr) {
        int t = find(gra[ql].u) == find(gra[ql].v) ? gra[ql].t : Q + 1;
        gath[t].push_back(gra[ql]);
        return;
    }
    int mid = (l + r) >> 1, cnt = tot = 0;
    for (int i = ql; i <= qr; i++) if (gra[i].t <= mid) {
        int u = find(gra[i].u), v = find(gra[i].v);
        arr[++cnt] = u, arr[++cnt] = v;
        head[u] = head[v] = 0;
    }
    for (int i = 1; i < cnt; i += 2) {
        int u = arr[i], v = arr[i + 1];
        edge[++tot] = (Edge) { v, head[u] };
        head[u] = tot;
        dfn[u] = low[u] = ins[u] = 0;
        dfn[v] = low[v] = ins[v] = 0;
    }
    tot = top = 0;
    for (int i = 1; i <= cnt; i++) if (!dfn[arr[i]]) tarjan(arr[i]);
    int a = 0, b = ql; tot = 1;
    vector<P> cpar;
    for (int i = ql; i <= qr; i++) {
        int flag = 0;
        if (gra[i].t <= mid) {
            int u = arr[tot], v = arr[tot + 1]; tot += 2;
            if (find(u) == find(v)) flag = 1;
            cpar.push_back(P(u, par[u]));
            cpar.push_back(P(v, par[v]));
        }
        if (flag) gra[b++] = gra[i];
        else cpy[++a] = gra[i];
    }
    for (int i = b; i <= qr; i++) gra[i] = cpy[i - b + 1];
    for (int i = 1; i <= cnt; i++) par[arr[i]] = arr[i];
    if (ql < b) divide(l, mid, ql, b - 1);
    for (P p : cpar) par[p.first] = p.second;
    if (b <= qr) divide(mid + 1, r, b, qr);
}
int ls[maxt], rs[maxt], cnt[maxt], rt[maxn], tn; ll sum[maxt], res[maxm];
void update(int x, int coel, int k, int l = 1, int r = tn) {
    cnt[k] += coel; sum[k] += x * coel;
    if (l == r) return;
    int mid = (l + r) >> 1;
    if (x <= mid) {
        if (!ls[k]) ls[k] = ++tot;
        update(x, coel, ls[k], l, mid);
    } else {
        if (!rs[k]) rs[k] = ++tot;
        update(x, coel, rs[k], mid + 1, r);
    }
}
ll query(int x, int k, int l = 1, int r = tn) {
    if (x >= cnt[k]) return sum[k];
    if (l == r) return sum[k] / cnt[k] * x;
    int mid = (l + r) >> 1;
    if (cnt[rs[k]] >= x) return query(x, rs[k], mid + 1, r);
    else return query(x - cnt[rs[k]], ls[k], l, mid) + sum[rs[k]];
}
int Tmerge(int x, int y) {
    if (!x || !y) return x + y;
    cnt[x] += cnt[y], sum[x] += sum[y];
    ls[x] = Tmerge(ls[x], ls[y]);
    rs[x] = Tmerge(rs[x], rs[y]);
    return x;
}
int main(){
    n = read(), m = read(), Q = read();
    for (int i = 1; i <= n; i++) val[par[i] = i] = read();
    set<P> ss;
    for (int i = 1; i <= m; i++) {
        int u = read(), v = read();
        ss.insert(P(u, v));
    }
    for (int i = Q; i > 0; i--) {
        int t = read(), a = read(), b = read();
        type[i] = t, arga[i] = a, argb[i] = b;
        if (t == 1) gra[++tot] = (Graph) { a, b, i }, ss.erase(P(a, b));
        else if (t == 2) val[a] += b;
    }
    for (P p : ss) gra[++tot] = (Graph) { p.first, p.second, 0 };
    assert(tot == m);
    divide(0, Q + 1, 1, m);
    /*for (int i = 0; i <= Q + 1; i++) {
        printf("%d ", i);
        for (Graph g : gath[i]) printf("(%d, %d) ", g.u, g.v);
        putchar('\n');
    }*/
    tot = top = 0; type[0] = 1;
    for (int i = 1; i <= n; i++) tn = max(tn, val[par[i] = i]);
    for (int i = 1; i <= n; i++) update(val[i], 1, rt[i] = ++tot);
    for (int i = 0; i <= Q; i++) {
        if (type[i] == 1) {
            for (Graph g : gath[i]) {
                g.u = find(g.u), g.v = find(g.v);
                if (g.u == g.v) continue;
                par[g.v] = g.u;
                rt[g.u] = Tmerge(rt[g.u], rt[g.v]);
            }
        } else if (type[i] == 2) {
            int &w = val[arga[i]], u = find(arga[i]);
            update(w, -1, rt[u]);
            update(w -= argb[i], 1, rt[u]);
        } else res[++top] = query(argb[i], rt[find(arga[i])]);
    }
    while (top > 0) print(res[top--]);
    fwrite(prt, 1, ppos, stdout);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/WAautomaton/article/details/85057257