莫队算法总结

莫队算法 总结

最近两天学习了一下莫队,感觉莫队算法还是挺好用的(现在看到离线询问就想莫队...
就稍微写一下总结吧,加深一下对算法的理解。

  • 普通莫队

核心思想:莫队算法一般用来离线处理一系列无修改的区间询问问题,通过将所有的询问保存下来,并且将所有的询问区间进行适当地排序,从而达到降低时间复杂度的效果。

对于所有的询问区间\([l_i,r_i]\),如果暴力地进行区间端点移动,那么对于一次询问,区间端点可能移动\(n\)的长度。假设询问的规模与\(n\)同级,那么复杂度就为\(O(n^2)\)

但其实,我们可以巧妙地安排区间顺序以降低时间复杂度。
莫队算法的思想如下:
将区间分为\(\sqrt{n}\)块,每块的长度也为\(\sqrt{n}\),之后对所有的询问区间排序,如果区间左端点在同一块内,则按右端点排序;否则则按左端点所在块进行排序。
就这样排序过后,暴力计算就行了,可以证明,时间复杂度为\(O(n^{\frac{3}{2}})\)

下面给出简单的证明:
假设区间左端点在同一块内,那么一次询问左端点最多移动\(\sqrt{n}\),由于右端点是单增的,则右端点移动总的复杂度为\(O(n)\),此时端点移动的总复杂度为\(O(n^{\frac{3}{2}})\)。(注意这里是均摊意义上的复杂度)
如果区间左端点不在同一块,也就是左端点跨块移动,因为一共有\(\sqrt{n}\)块,每次右端点的移动最多\(O(n)\),此时总的时间复杂度也为\(O(n^{\frac{3}{2}})\)
所以经过分块过后,时间复杂度可以降为\(O(n^\frac{3}{2})\)

可以先通过几道例题感受一下:
 

\(cnt[i]\)为第\(i\)种颜色的袜子的个数,当前区间为\([l,r]\),那么容易知道所求的答案为\(\frac{\sum_{i=1}^{k}{C_{cnt[i]}^2}}{C_{r-l+1}^{2}}\)
因为分母是与区间长度有关,我们只用考虑区间端点变化时,分子的变化情况就行了。

先单独把分子拿出来:\(\sum_{i=1}^{k}{C_{cnt[i]}^2}\),当区间范围增加一时,会存在一个\(t\),有\(cnt[t]+1\),在这个求和式中,其余项不会改变,那么我们就只用看这一项对答案的影响。
影响即为:\(C_{cnt[t]+1}^2-C_{cnt[t]}^2\),那么在进行答案更新时算一下这个式子就好了。
对于区间范围减小的情况分析也同理。
代码如下:


Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 50005;
int n, m, block;
int a[N];
struct query{
    int l, r, id ;
}Q[N];
struct Ans{
    ll p, q;
}answer[N];
bool cmp(query x, query y) {
    if((x.l - 1) / block + 1 == (y.l - 1) / block + 1) return x.r < y.r;
    return (x.l - 1) / block + 1 < (y.l - 1) / block + 1 ;
}
ll gcd(ll A, ll B) {
    return B == 0 ? A : gcd(B, A % B) ;
}
ll ans ;
ll cnt[N] ;
void update(int pos, int sign) {
    ans -= cnt[a[pos]] * cnt[a[pos]] ;
    cnt[a[pos]] += sign ;
    ans += cnt[a[pos]] * cnt[a[pos]] ;
}
int main() {
    scanf("%d%d",&n, &m) ;
    block = (int)sqrt(n) ;
    for(int i = 1; i <= n; i++) scanf("%d", &a[i]) ;
    for(int i = 1; i <= m; i++) {
        scanf("%d%d",&Q[i].l, &Q[i].r) ;
        Q[i].id = i ;
    }
    sort(Q + 1, Q + m + 1, cmp) ;
    int l = 1, r = 0;
    for(int i = 1; i <= m; i++) {
        for(; r < Q[i].r; r++) update(r + 1, 1) ;
        for(; r > Q[i].r; r--) update(r, -1) ;
        for(; l < Q[i].l; l++) update(l, -1) ;
        for(; l > Q[i].l; l--) update(l - 1, 1) ;
        answer[Q[i].id].p = ans - Q[i].r + Q[i].l - 1;
        answer[Q[i].id].q = 1ll * (Q[i].r - Q[i].l + 1) * (Q[i].r - Q[i].l) ;
        if(Q[i].l == Q[i].r) answer[Q[i].id].p = 0, answer[Q[i].id].q = 1;              
        ll g = gcd(answer[Q[i].id].p, answer[Q[i].id].q) ;
        answer[Q[i].id].p /= g; answer[Q[i].id].q /= g;
    }
    for(int i = 1; i <= m; i++) printf("%lld/%lld\n",answer[i].p, answer[i].q) ;
    return 0;
}

 

给出的数字串挺长的,但是质数\(p\)不是很大。
我们知道,如果一个数字\(t\)\(p\)的倍数,那么就有\(t\mod p=0\)。但是区间中的子串很多,我们直接时间复杂度等同于暴力。所以我们可以考虑将问题转化一下。

设串\(s\)所在区间为\([l,r]\),串的长度为\(n\),那么我们知道\(s*10^{r-l+1}=t[l,l+1,\cdots,n]-t[r+1,r+2,\cdots,n]*10^{r-l+1}\)
所以当质数\(p\)不为2和5时,\(s\mod p=0 => (t[l,l+1,\cdots,n]-t[r+1,r+2,\cdots,n]*10^{r-l+1}) \mod p=0 => t[l,l+1,\cdots,n]\mod p=t[r+1,r+2,\cdots,n]\mod p\)
所以我们就可以维护一个数组\(f[i]\),表示后缀\(i\)\(p\)取余的值为多少,那么我们就可以将一个区间为\([l,r]\)的询问转化为\([l,r+1]\)中有多少对\(f\)相等了。
之后就用莫队来搞,计算区间范围增加或者减小对答案的影响就好了。思路同上一题类似。

对于\(p\)为2或者5的情况,特判一波,维护前缀个数就好了。
代码如下:


Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 2e5 + 5;
ll p, cnt;
int block, n;
char s[N] ;
ll f[N], d[N];
ll num[N] ;
ll sum[N][3] ;
struct Query{
    int l, r, id;
    ll ans ;
}q[N];
int Q;
bool cmp(Query x, Query y) {
    if((x.l - 1) / block + 1 == (y.l - 1) / block + 1) return x.r < y.r;
    return (x.l - 1) / block + 1 < (y.l - 1) / block + 1 ;
}
bool cmp2(Query x, Query y) {
    return x.id < y.id ;
}
void sol1() {
    for(int i = 1; i <= n; i++) {
        sum[i][0] = sum[i - 1][0] ;
        sum[i][1] = sum[i - 1][1] ;
        sum[i][2] = sum[i - 1][2] ;
        if(p == 2 && (s[i] - '0') % p == 0) sum[i][0] += i, sum[i][2]++;
        if(p == 5 && (s[i] - '0') % p == 0) sum[i][1] += i, sum[i][2]++;
    }
    for(int i = 1; i <= Q; i++) {
        int l = q[i].l, r = q[i].r;
        ll k;
        if(p == 2) k = sum[r][0] - sum[l - 1][0] ;
        else k = sum[r][1] - sum[l - 1][1] ;
        q[i].ans = k - (sum[r][2] - sum[l - 1][2]) * (l - 1) ;
    }
}
void update2(int pos, int sign) {
    cnt -= (num[f[pos]] - 1) * num[f[pos]] / 2;
    num[f[pos]] += sign;
    cnt += (num[f[pos]] - 1) * num[f[pos]] / 2;
}
void sol2() {
    int l = 1, r = 0 ;
    for(int i = 1; i <= Q; i++) {
        q[i].r += 1;
        for(; r < q[i].r; r++) update2(r + 1, 1) ;
        for(; r > q[i].r; r--) update2(r, -1) ;
        for(; l > q[i].l; l--) update2(l - 1, 1) ;
        for(; l < q[i].l; l++) update2(l, -1) ;
        q[i].ans = cnt ;        
    }
}
int main() {
    scanf("%lld%s%d", &p, s + 1, &Q) ;
    n = strlen(s + 1) ;
    block = (int)sqrt(n) ;
    for(int i = 1; i <= Q; i++) {
        scanf("%d%d",&q[i].l, &q[i].r) ;
        q[i].id = i;
    }
    sort(q + 1, q + Q + 1, cmp) ;
    ll x = 0, qp = 1;
    int flag = -1;
    for(int i = n; i >= 1; i--) {
        x = (x + (s[i] - '0') * qp % p) % p;
        d[i] = f[i] = x;
        if(f[i] == 0) flag = i;
        qp = qp * 10 % p;
    }
    sort(d + 1, d + n + 1) ;
    int D = unique(d + 1, d + n + 1) - d - 1;
    for(int i = 1; i <= n; i++) f[i] = lower_bound(d + 1, d + n + 1, f[i]) - d;
    if(flag > 0) f[n + 1] = f[flag] ;   
    if(p == 2 || p == 5) sol1() ;
    else sol2() ;
    sort(q + 1, q + Q + 1, cmp2) ;
    for(int i = 1; i <= Q; i++) printf("%lld\n", q[i].ans) ; 
    return 0;
}

 

感觉这个题挺好的。没想到还可以用莫队来搞。
对于区间\([l,r]\),假设我们要将\(r\)增加1,那么就会多出\(r-l+2\)个序列,我们就分析他们对答案的影响。
假设区间\([l,r]\)中最小值所在位置为\(p\),那么很显然,左端点在\([l,l+1,\cdots,p]\)时,区间最小值就为\(a[p]\)

对于\(r+1\)而言,如果我们找到左边第一个比他小的位置为\(k\),那么此时对答案的贡献就为\((r-k+2)*a[k]\);同理对\(k\)也可以执行同样的操作。最后必然会存在一个位置\(q\),其左边第一个比他小的位置为\(q\),那么操作在这里就终止了。

每次这么操作时间复杂度过高,发现可以维护一个类似于前缀和一样的东西,递推地来维护就行了。设该前缀和函数为\(f\),那么区间右端点增加一位对答案的贡献为:\(a[p]*(p-l+1)+f[r+1]-f[p]\)
这样就可以O(1)算出对答案的影响了。
左端点的情况也类似考虑。
代码如下:


Code

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
using namespace std;
typedef long long ll;
const int N = 2e5 + 5;
int n, m, block;
int a[N];
struct Query{
    int l, r, id;
    ll ans;
}q[N];
bool cmp(Query A, Query B) {
    if((A.l - 1) / block  + 1 == (B.l - 1) / block + 1) return A.r < B.r;
    return (A.l - 1) / block + 1< (B.l - 1) / block + 1;
}
bool cmp_id(Query A, Query B) {
    return A.id < B.id ;
}
int l[N], r[N] ;
int sta[N], top;
ll f12[N], f21[N];
int f[N][22], pos[N][22], Log2[N];
ll ans ;
int Get_min(int L, int R) {
    ll k = Log2[R - L + 1];
    if(f[L][k] > f[R - (1LL << k) + 1][k]) return pos[R - (1LL << k) + 1][k] ;
    return pos[L][k] ;
}
void update1(int pos, int L, int R, int sign) {
    int p = Get_min(L, R) ;
    ll sum = f12[R] - f12[p] + 1ll * (p - L + 1) * a[p];
    ans += 1ll * sign * sum;
}
void update2(int pos, int L, int R, int sign) {
    int p = Get_min(L, R) ;
    ll sum = f21[L] - f21[p] + 1ll * (R - p + 1) * a[p] ;
    ans += 1ll * sign * sum ;
}
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    cin >> n >> m ;
    Log2[1] = 0;
    for(int i = 2; i <= n; i++) Log2[i] = Log2[i >> 1] + 1;
    block = sqrt(n) ;
    memset(f, INF, sizeof(f)) ;
    for(int i = 1; i <= n; i++) {
        cin >> a[i] ;
        f[i][0] = a[i] ;
        pos[i][0] = i ;
    }
    for(int j = 1; j <= 17; j++) {
        for(int i = 1; i + (1 << (j - 1)) <= n; i++) {
            if(f[i][j - 1] > f[i + (1 << (j - 1))][j - 1]) {
                f[i][j] = f[i + (1 << (j - 1))][j - 1] ;
                pos[i][j] = pos[i + (1 << (j - 1))][j - 1] ;
            } else {
                f[i][j] = f[i][j - 1];
                pos[i][j] = pos[i][j - 1] ;
            }
        }
    }
    for(int i = 1; i <= n + 1; i++) {
        while(top > 0 && a[sta[top]] >= a[i]) r[sta[top--]] = i ;
        sta[++top] = i;
    }
    top = 0;
    for(int i = n; i >= 0; i--) {
        while(top > 0 && a[sta[top]] >= a[i]) l[sta[top--]] = i;
        sta[++top] = i;
    }
    for(int i = 1; i <= n; i++)
        f12[i] = f12[l[i]] + 1ll * (i - l[i]) * a[i] ;
    for(int i = n; i >= 1; i--)
        f21[i] = f21[r[i]] + 1ll * (r[i] - i) * a[i] ;
    for(int i = 1; i <= m; i++) {
        int L, R;
        cin >> L >> R;
        q[i].l = L; q[i].r = R;
        q[i].id = i;
    }
    sort(q + 1, q + m + 1, cmp) ;
    int L = 1, R = 0;
    for(int i = 1; i <= m; i++) {
        for(; R < q[i].r; R++) update1(R + 1, L, R + 1, 1) ;
        for(; R > q[i].r; R--) update1(R, L, R, -1) ;
        for(; L < q[i].l; L++) update2(L, L, R, -1) ;
        for(; L > q[i].l; L--) update2(L - 1, L - 1, R, 1) ;
        q[i].ans = ans ;
    }
    sort(q + 1, q + m + 1, cmp_id) ;
    for(int i = 1; i <= m; i++)
        cout << q[i].ans << '\n' ;
    return 0;
}
  • 带修改莫队

之前说的莫队是不支持修改的,但其实也可以支持修改,只需要再加一维“时间状态”就行了,对于每个询问,新增一维,变为\([l,r,k]\),表示当前区间为\([l,r]\),之前经过\(k\)次修改操作的询问。
为什么这样是正确的呢?
因为我们如果知道了\([l,r,k]\)的答案,那么就很容易知道\([l+1,r,k],[l-1,r,k],[l,r-1,k],[l,r+1,k],[l,r,k-1],[l,r,k+1]\)对答案的影响。
具体来说,修改时间维度时,看看修改的位置是否在\([l,r]\)中,如果在则会对答案产生影响,否则直接修改就是了。之后区间端点左右移动时,遇到的位置也一定是完成\(k\)次修改过后的值了。

此时我们还是将区间进行分块,但现在要分为\(n^\frac{2}{3}\)块,每块长度为\(n^\frac{1}{3}\)。然后以左端点所在的块为第一关键字,右端点所在的块为第二关键字,修改次数为第三关键字进行排序。

可以证明这样的时间复杂度是\(O(n^\frac{5}{3})\)的。
证明方法就类似于上面的分析。

来看一道例题:
 

这就是个待修改莫队的模板题,多了一维对时间的修改,详细见代码吧:


Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 50005, MAX = 1e6 + 5;
int n, m, block, num, M, l, r, t, ans;
char ss[5];
int c[N], cnt[MAX], last[N];
struct Upd{
    int pos, v;
}upd[N];
struct query{
    int l, r, ans, id, k;
}q[N];
bool cmp(query a, query b) {
    if((a.l - 1) / block == (b.l - 1) / block && (a.r - 1) / block == (b.r - 1) / block) return a.k < b.k;
    else if((a.l - 1) / block == (b.l - 1) / block) return (a.r - 1) / block < (b.r - 1) / block ;
    return (a.l - 1) / block < (b.l - 1) / block;
}
bool cmp_id(query a, query b) {
    return a.id < b.id;
}
void update_add(int T) {
    int pos = upd[T].pos, v = upd[T].v;
    last[T] = c[pos] ;
    if(l <= pos && pos <= r) {
        cnt[c[pos]]--;
        if(cnt[c[pos]] == 0) ans--;
        cnt[v]++;
        if(cnt[v] == 1) ans++;
    }
    c[pos] = v;
}
void update_del(int T) {
    int pos = upd[T].pos, v = upd[T].v;
    if(l <= pos && pos <= r) {
        cnt[v]--;
        if(cnt[v] == 0) ans--;
        c[pos] = last[T] ;
        cnt[c[pos]]++;
        if(cnt[c[pos]] == 1) ans++;
    } else c[pos] = last[T] ;
}
void update(int pos, int val) {
    cnt[c[pos]] += val;
    if(val == 1) {
        if(cnt[c[pos]] == 1) ans++;
    } else if(val == -1)
        if(cnt[c[pos]] == 0) ans--;
}
int main() {
    scanf("%d%d",&n, &m) ;
    block = pow(n, 0.666666) ;
    for(int i = 1; i <= n; i++) scanf("%d", &c[i]) ;
    for(int i = 1; i <= m; i++) {
        scanf("%s",ss) ;
        if(ss[0] == 'R') {
            int pos, v;
            scanf("%d%d",&pos, &v) ;
            upd[++num].pos = pos; upd[num].v = v ;
        } else {
            int l, r;
            scanf("%d%d",&l, &r) ;
            q[++M].l = l; q[M].r = r;
            q[M].id = M; q[M].k = num;
        }
    }
    sort(q + 1, q + M + 1, cmp) ;
    l = 1, r = 0, t = 0;
    for(int i = 1; i <= M; i++) {
        for(; t < q[i].k; t++) update_add(t + 1) ;
        for(; t > q[i].k; t--) update_del(t) ;
        for(; r < q[i].r; r++) update(r + 1, 1) ;
        for(; r > q[i].r; r--) update(r, -1) ;
        for(; l < q[i].l; l++) update(l, -1) ;
        for(; l > q[i].l; l--) update(l - 1, 1) ;
        q[i].ans = ans ;
    }
    sort(q + 1, q + M + 1, cmp_id) ;
    for(int i = 1; i <= M; i++) printf("%d\n", q[i].ans) ;
    return 0 ;
}
  • 树上莫队

如果可以对树进行分块的话,那么也可以对树上的询问用莫队来搞。刚好有一道树上分块的模板题
那么树上莫队的具体做法就为,首先将树进行分块,然后对所有的询问\([x,y]\),首先让\(x\)的时间戳小于\(y\)的时间戳,然后就按照\(x\)所在的块为第一关键字,以y的时间戳为第二关键字进行排序就好了。
之后考虑询问间的转移,方法为直接将\(x_i->x_{i+1}\)路径上面的所有点除开它们lca的状态取反,同理也将\(y_i->y_{i+1}\)路径上面的所有点除开它们lca的状态取反,计算答案就是了。
具体证明直接引用vfk的博客:

用S(v, u)代表 v到u的路径上的结点的集合。
用root来代表根结点,用lca(v, u)来代表v、u的最近公共祖先。
那么
S(v, u) = S(root, v) xor S(root, u) xor lca(v, u)
其中xor是集合的对称差。
简单来说就是节点出现两次消掉。
lca很讨厌,于是再定义
T(v, u) = S(root, v) xor S(root, u)
观察将curV移动到targetV前后T(curV, curU)变化:
T(curV, curU) = S(root, curV) xor S(root, curU)
T(targetV, curU) = S(root, targetV) xor S(root, curU)
取对称差:
T(curV, curU) xor T(targetV, curU)= (S(root, curV) xor S(root, curU)) xor (S(root, targetV) xor S(root, curU))
由于对称差的交换律、结合律:
T(curV, curU) xor T(targetV, curU)= S(root, curV) xorS(root, targetV)
两边同时xor T(curV, curU):
T(targetV, curU)= T(curV, curU) xor S(root, curV) xor S(root, targetV)
发现最后两项很爽……哇哈哈
T(targetV, curU)= T(curV, curU) xor T(curV, targetV)
(有公式恐惧症的不要走啊 T_T)
也就是说,更新的时候,xor T(curV, targetV)就行了。
即,对curV到targetV路径(除开lca(curV, targetV))上的结点,将它们的存在性取反即可。

因为lca我们不会算,所以最后单独考虑一下lca就行了。

这是树上莫队的第一种解法,另外还有一种就是直接将树转化为dfs序,压缩成线性的,同时每个结点维护两个时间戳,一个是进去的时间戳,一个是出来的时间戳。
那么对于树上的路径比如从\(x\)\(y\),若\(LCA(x,y)\)为其中之一,那么两个的路径在dfs序中的体现就为\(in[x]->in[y]\);否则就为\(out[x]->in[y]\)

这样写的话也需要一个数组来记录当前结点是否被算入答案中,每到一个位置也要将相应的状态取反。这里注意第二种情况lca也不会算上,所以也要单独考虑一下lca。

既然有了树上莫队,也有树上带修改莫队,好吧,其实原理都是差不多的。

看一个例题:
 

这基本上就是莫队算法的集大成者了。对答案的影响很好计算,维护一种颜色出现的次数就行了。
主要就是代码,我写了两种,一种是dfs序的,一种是树上分块的。


dfs序

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
int n, m, qq, block;
ll w[N], c[N], in[N], out[N], v[N];
vector <int> g[N] ;
struct Query{
    int l, r, id, k;
    ll ans ;
}q[N];
struct Upd{
    int x, y, last;
}upd[N];
bool cmp_id(Query A, Query B) {
    return A.id < B.id ;
}
bool cmp(Query A, Query B) {
    if((A.l - 1) / block == (B.l - 1) / block && (A.r - 1) / block == (B.r - 1) / block) return A.k < B.k;
    if((A.l - 1) / block == (B.l - 1) / block) return (A.r - 1) / block < (B.r - 1) / block;
    return (A.l - 1) / block < (B.l - 1) / block ;
}
int dfn;
ll a[2 * N], f[N][22], deep[N], pre[N];
void dfs(int u, int fa) {
    in[u] = ++dfn;
    a[dfn] = u ;
    deep[u] = deep[fa] + 1;
    for(auto v : g[u]) {
        if(v == fa) continue ;
        f[v][0] = u;
        for(int i = 1; i <= 17; i++) f[v][i] = f[f[v][i - 1]][i - 1] ;
        dfs(v, u);
    }
    out[u] = ++dfn;
    a[dfn] = u;
}
int LCA(int x, int y) {
    if(deep[x] < deep[y]) swap(x, y) ;
    for(int i = 17; i >= 0; i--)
        if(deep[f[x][i]] >= deep[y]) x = f[x][i] ;
    if(x == y) return x;
    for(int i = 17; i >= 0; i--)
        if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i] ;
    return f[x][0] ;
}
ll ans ;
int l, r, t, qnum, num;
bool vis[2 * N];
ll cnt[N] ;
void update(int u) {
    int col = c[u] ;
    if(vis[u]) ans -= 1ll * w[cnt[col]--] * v[col] ;
    else ans += 1ll * w[++cnt[col]] * v[col] ;
    vis[u] ^= 1;
}
void update_t(int T, int sign) {
    int u = upd[T].x, col = upd[T].y;
    if(sign == -1) col = upd[T].last;
    if(vis[u]) {
        update(u);
        c[u] = col;
        update(u);
    } else c[u] = col;
}
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    cin >> n >> m >> qq;
    block = pow(n, 0.666666) ;
    for(int i = 1; i <= m; i++) cin >> v[i] ;
    for(int i = 1; i <= n; i++) cin >> w[i] ;
    for(int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v) ;
        g[v].push_back(u) ;
    }
    dfs(1, 0);
    for(int i = 1; i <= n; i++) cin >> c[i], pre[i] = c[i];
    for(int i = 1; i <= qq; i++) {
        int op, x, y;
        cin >> op >> x >> y;
        if(op == 1) {
            if(in[x] > in[y]) swap(x, y) ;
            int lca = LCA(x, y) ;
            q[++num].r = in[y];
            q[num].k = qnum;
            q[num].id = num;
            if(lca == x) q[num].l = in[x] ;
            else q[num].l = out[x];
        } else {
            upd[++qnum].x = x;
            upd[qnum].y = y;
            //pre[qnum] = (qnum == 1 ? c[x] : upd[qnum - 1].y) ;
            upd[qnum].last = pre[x];
            pre[x] = y;
            
        }
    }
    sort(q + 1, q + num + 1, cmp) ;
    l = 1, r = 0, t = 0;
    for(int i = 1; i <= num; i++) {
        for(; t < q[i].k; t++) update_t(t + 1, 1) ;
        for(; t > q[i].k; t--) update_t(t, -1) ;
        for(; r < q[i].r; r++) update(a[r + 1]) ;
        for(; r > q[i].r; r--) update(a[r]) ;
        for(; l < q[i].l; l++) update(a[l]) ;
        for(; l > q[i].l; l--) update(a[l - 1]) ;
        int lca = LCA(a[l], a[r]) ;
        if(lca != a[l] && lca != a[r]) {
            update(lca) ;
            q[i].ans = ans ;
            update(lca) ;
        } else q[i].ans = ans ;
    }
    sort(q + 1, q + num + 1, cmp_id) ;
    for(int i = 1; i <= num; i++)
        cout << q[i].ans << '\n' ;
    return 0;
}

 


树上分块

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
int n, m, qq, block;
int w[N], c[N], in[N], v[N];
int dfn;
int f[N][22], deep[N], pre[N];
int sta[N], bel[N];
int top, tot;
vector <int> g[N] ;
struct Query{
    int l, r, id, k;
    ll ans ;
}q[N];
struct Upd{
    int x, y, last;
}upd[N];
bool cmp_id(Query A, Query B) {
    return A.id < B.id ;
}
bool cmp(Query A, Query B) {
    if(bel[A.l] == bel[B.l] && bel[A.r] == bel[B.r]) return A.k < B.k;
    if(bel[A.l] == bel[B.l]) return bel[A.r] < bel[B.r] ;
    return bel[A.l] < bel[B.l] ;
}
void dfs(int u, int fa) {
    in[u] = ++dfn;
    deep[u] = deep[fa] + 1;
    int tmp = top ;
    for(auto v : g[u]) {
        if(v == fa) continue ;
        f[v][0] = u;
        for(int i = 1; i <= 16; i++) f[v][i] = f[f[v][i - 1]][i - 1] ;
        dfs(v, u);
        if(top - tmp >= block) {
            tot++;
            while(top > tmp) bel[sta[top--]] = tot;
        }
    }
    sta[++top] = u ;
}
int LCA(int x, int y) {
    if(deep[x] < deep[y]) swap(x, y) ;
    for(int i = 16; i >= 0; i--)
        if(deep[f[x][i]] >= deep[y]) x = f[x][i] ;
    if(x == y) return x;
    for(int i = 16; i >= 0; i--)
        if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i] ;
    return f[x][0] ;
}
ll ans ;
int l, r, t, qnum, num;
bool vis[N];
ll cnt[N] ;
void modify(int x) {
    int col = c[x] ;
    if(vis[x]) ans -= 1ll * w[cnt[col]--] * v[col] ;
    else ans += 1ll * w[++cnt[col]] * v[col] ;
    vis[x] ^= 1;
}
void update(int x, int y) {
    while(x != y) {
        if(deep[x] >= deep[y]) modify(x), x = f[x][0] ;
        else modify(y), y = f[y][0] ;
    }
}
void change(int x, int col) {
    if(vis[x]) {
        modify(x) ;
        c[x] = col ;
        modify(x) ;
    } else c[x] = col ;
}
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    cin >> n >> m >> qq;
    block = pow(n, 0.666666) ;
    for(int i = 1; i <= m; i++) cin >> v[i] ;
    for(int i = 1; i <= n; i++) cin >> w[i] ;
    for(int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v) ;
        g[v].push_back(u) ;
    }
    dfs(1, 0);
    for(int i = 1; i <= n; i++) cin >> c[i], pre[i] = c[i] ;
    for(int i = 1; i <= qq; i++) {
        int op, x, y;
        cin >> op >> x >> y;
        if(op == 1) {
            if(in[x] > in[y]) swap(x, y) ;
            q[++num].id = num;q[num].l = x;
            q[num].r = y;q[num].k = qnum;
        } else {
            upd[++qnum].x = x;upd[qnum].y = y;
            upd[qnum].last = pre[x] ;
            pre[x] = upd[qnum].y ;
        }
    }
    sort(q + 1, q + num + 1, cmp) ;
    l = q[1].l, r = q[1].r, t = 0;
    update(l, r);
    for(int i = 1; i <= num; i++) {
        for(;t < q[i].k; t++) change(upd[t + 1].x, upd[t + 1].y) ;
        for(;t > q[i].k; t--) change(upd[t].x, upd[t].last) ;
        update(l, q[i].l) ;
        update(r, q[i].r) ;
        int lca = LCA(q[i].l, q[i].r) ;
        modify(lca) ;
        q[q[i].id].ans = ans ;
        modify(lca) ;
        l = q[i].l, r = q[i].r ;
    }
    for(int i = 1; i <= num; i++) cout << q[i].ans << '\n' ;
    return 0;
}

 
再看看这个题:
 

这里询问的是出现次数大于等于k的颜色有多少种,看似比较棘手。实际上我们维护一个数组\(sum[i]\),表示大于等于\(i\)的颜色有多少种就行了。这个稍微想想还是比较清楚的。
代码如下:


Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 2e5 + 5;
int n, m, block;
int c[2 * N], a[2 * N], cnt[N];
int ans ;
vector <int> g[N];
struct Query{
    int l, r, k, id, ans;
}q[N];
bool cmp(Query A, Query B) {
    if((A.l - 1) / block == (B.l - 1) / block) return A.r < B.r;
    return (A.l - 1) / block < (B.l - 1) / block ;
}
bool cmp_id(Query A, Query B) {
    return A.id < B.id ;
}
int in[N], out[N] ;
int dfn, tot;
bool vis[2 * N], has[2 * N];
int sum[N] ;
void update(int pos, int val) {
    int col = c[a[pos]] ;
    if(val == 1) {
        if(vis[a[pos]]) return ;
        vis[a[pos]] = 1;
        sum[++cnt[col]]++;
    } else {
        if(!vis[a[pos]]) return ;
        vis[a[pos]] = 0;
        sum[cnt[col]--]--;
    }
}
void dfs(int u, int fa) {
    in[u] = ++dfn;
    a[dfn] = u ;
    int t = dfn;
    for(auto v : g[u]) {
        if(v == fa) continue ;
        dfs(v, u) ;
    }
    out[u] = ++dfn;
    a[dfn] = u ;
}
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    cin >> n >> m;
    block = sqrt(n) ;
    for(int i = 1; i <= n; i++) cin >> c[i] ;
    for(int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1, 0);
    for(int i = 1; i <= m; i++) {
        int v;
        cin >> v >> q[i].k;
        q[i].id = i;
        q[i].l = in[v] ; q[i].r = out[v] ;
    }
    sort(q + 1, q + m + 1, cmp);
    int l = 1, r = 0;
    for(int i = 1; i <= m; i++) {
        int k = q[i].k ;
        for(; r < q[i].r; r++) update(r + 1, 1) ;
        for(; r > q[i].r; r--) update(r, -1) ;
        for(; l < q[i].l; l++) update(l, -1) ;
        for(; l > q[i].l; l--) update(l - 1, 1) ;
        q[i].ans = sum[k] ;
    }
    sort(q + 1, q + m + 1, cmp_id) ;
    for(int i = 1; i <= m; i++)
        cout << q[i].ans << '\n' ;
    return 0;
}

 
最后再来看一道例题:
 

学过莫队之后是不是感觉很简单?
每次区间转移用树状数组维护信息即可。
代码如下:


Code

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cmath>
using namespace std;
typedef long long ll;
const int N = 50005;
int c[N], a[N], b[N];
int l, r ;
int n, block;
struct Query{
    int l, r, id ;
    ll ans ;
}q[N];
bool cmp(Query A, Query B) {
    if((A.l - 1) / block == (B.l - 1) / block) return A.r < B.r;
    return (A.l - 1) / block < (B.l - 1) / block ;
}
int lowbit(int x) {
    return x & (-x) ;
}
void add(int x, int val) {
    for(int i = x; i < N; i += lowbit(i)) c[i] += val;
}
ll query(int x) {
    ll ans = 0;
    for(int i = x; i > 0; i -= lowbit(i)) ans += c[i];
    return ans ;
}
ll ans ;
void update(int x, int v, int sign) {
    if(sign == 1) {
        if(v == 1) {
            add(a[x], 1) ;
            int sum = query(a[x]) ;
            ans += r - l + 2 - sum ;
        } else {
            int sum = query(a[x]) ;
            ans -= (r - l + 1 - sum) ;
            add(a[x], -1) ;
        }
    } else {
        if(v == 1) {
            int sum = query(a[x] - 1) ;
            ans += sum;
            add(a[x], 1) ;
        } else {
            add(a[x], -1) ;
            int sum = query(a[x] - 1) ;
            ans -= sum ;
        }
    }
}
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    cin >> n;
    block = sqrt(n) ;
    for(int i = 1; i <= n; i++) cin >> a[i], b[i] = a[i];
    sort(b + 1, b + n + 1);
    int D = unique(b + 1, b + n + 1) - b - 1;
    for(int i = 1; i <= n; i++) a[i] = lower_bound(b + 1, b + D + 1, a[i]) - b;
    int Q;
    cin >> Q;
    for(int i = 1; i <= Q; i++) {
        int l, r;
        cin >> l >> r;
        q[i].l = l; q[i].r = r; q[i].id = i;
    }
    sort(q + 1, q + Q + 1, cmp) ;
    l = 1, r = 0;
    for(int i = 1; i <= Q; i++) {
        for(; r < q[i].r; r++) update(r + 1, 1, 1) ;
        for(; r > q[i].r; r--) update(r, -1, 1) ;
        for(; l < q[i].l; l++) update(l, -1, -1) ;
        for(; l > q[i].l; l--) update(l - 1, 1, -1) ;
        q[q[i].id].ans = ans ;
    }
    for(int i = 1; i <= Q; i++)
        cout << q[i].ans << '\n' ;
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/heyuhhh/p/10827143.html
今日推荐