[洛谷P5286] 一个简单的询问

问题描述

给你一个长度为 \(N\) 的序列 \(a_i\)\(1\leq i\leq N\),和 \(q\) 组询问,每组询问读入 \(l_1,r_1,l_2,r_2\),需输出

\(\sum\limits_{x=0}^\infty \text{get}(l_1,r_1,x)\times \text{get}(l_2,r_2,x)\)

\(\text{get}(l,r,x)\) 表示计算区间 \([l,r]\) 中,数字 \(x\) 出现了多少次。

输入格式

第一行,一个数字 \(N\),表示序列长度。
第二行,\(N\) 个数字,表示 \(a_1\sim a_N\)
第三行,一个数字 \(Q\),表示询问个数。
\(4\sim Q+3\) 行,每行四个数字 \(l_1,r_1,l_2,r_2\),表示询问。

输出格式

对于每组询问,输出一行一个数字,表示答案。

样例输入

5
1 1 1 1 1
2
1 2 3 4
1 1 4 4

样例输出

4
1

说明

对于 \(20\%\) 的数据,\(1\leq N,Q\leq 1000\)
对于另外 \(30\%\) 的数据,\(1\leq a_i\leq 50\)
对于 \(100\%\) 的数据,\(N,Q\leq 50000\)\(1\leq a_i\leq N\)\(1\leq l_1\leq r_1\leq N\)\(1\leq l_2\leq r_2\leq N\)

数据范围与原题相同,但测试数据由 LibreOJ 自制,并非原数据。

注意:答案有可能超过 int 的最大值。

解析

区间问题往往可以想到前缀和。如果想用前缀和的形式表示这个式子,那么可以按照如下过程化简:
\[ \begin{aligned} \sum_{i=0}^{\infty} get(l_1, r_1, x) \times get(l_2, r_2, x) = &\sum_{i=0}^{\infty} get(0, r_1, x) \times get(0, r_2, x)\\ - &\sum_{i=0}^{\infty} get(0, l_1-1, x) \times get(0, r_2, x)\\ - &\sum_{i=0}^{\infty} get(0, r_1, x) \times get(0, l_2-1, x) \\ + &\sum_{i=0}^{\infty} get(0, l_1-1, x) \times get(0, l_2-1, x) \end{aligned} \]
这样,我们就把一个询问拆成了4个可以用莫队维护的询问。对于每一个询问,维护 \(num[0][x]\) 表示在区间 \([1,l]\)\(x\) 出现了多少次,\(num[1][x]\) 表示区间 \([1,r]\)\(x\) 出现了多少次。那么答案就是 \(\sum\limits_{x=0}^\infty num[0][x]*num[1][x]\) 。至于修改操作,可以这么看:假设一个数原来是 \(a*b\) ,现在要把它变成 \((a+1)*b\) ,其实就相当于在原来的基础上加上一个 \(b\) 。这里也是同理。

代码

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#define int long long
#define N 50002
using namespace std;
struct query{
    int l,r,id;
}q[4*N];
int n,m,i,a[N],b[N],gap,cnt,num[2][N],sum,ans[4*N],l,r;
int read()
{
    char c=getchar();
    int w=0;
    while(c<'0'||c>'9') c=getchar();
    while(c<='9'&&c>='0'){
        w=w*10+c-'0';
        c=getchar();
    }
    return w;
}
int my_comp(const query &x,const query &y)
{
    if(b[x.l]==b[y.l]) return x.r<y.r;
    return x.l<y.l;
}
void add(int op,int x)
{
    num[op][a[x]]++;
    sum+=num[op^1][a[x]];
}
void del(int op,int x)
{
    num[op][a[x]]--;
    sum-=num[op^1][a[x]];
}
signed main()
{
    n=read();
    gap=sqrt(1.0*n);
    for(i=1;i<=n;i++) a[i]=read();
    for(i=1;i<=n;i++) b[i]=(i-1)/gap+1;
    m=read();
    for(i=1;i<=m;i++){
        int l1=read(),r1=read(),l2=read(),r2=read();
        q[++cnt]=(query){min(r1,r2),max(r1,r2),cnt};
        q[++cnt]=(query){min(r1,l2-1),max(r1,l2-1),cnt};
        q[++cnt]=(query){min(l1-1,r2),max(l1-1,r2),cnt};
        q[++cnt]=(query){min(l1-1,l2-1),max(l1-1,l2-1),cnt};
    }
    sort(q+1,q+cnt+1,my_comp);
    for(i=1;i<=cnt;i++){
        while(l<q[i].l) add(0,++l);
        while(l>q[i].l) del(0,l--);
        while(r<q[i].r) add(1,++r);
        while(r>q[i].r) del(1,r--);
        ans[q[i].id]=sum;
    }
    for(i=1;i<=cnt;i+=4) printf("%lld\n",ans[i]-ans[i+1]-ans[i+2]+ans[i+3]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/LSlzf/p/12194881.html