Description
Given a tree of n nodes in which every node has a color, and each color occurs at most t (t ≤ 20) times in the tree, count the simple paths on which no color appears twice. The paths (u, v) and (v, u) are considered the same, and a single node (u, u) also counts as a path.
Solution
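Consider which pairs of endpoints (i.e. which paths) are forbidden. A path (u, v) is forbidden exactly when it passes through two distinct nodes x and y of the same color.
Number the nodes in DFS order so that the subtree of x occupies the positions $[L_x, R_x]$, and plot each ordered pair (u, v) as the grid point $(L_u, L_v)$. For a same-colored pair (x, y), the forbidden points form at most two rectangles:
- If neither node is an ancestor of the other, the forbidden set is the rectangle $[L_x, R_x] \times [L_y, R_y]$.
- If x is an ancestor of y, let p be the child of x on the path to y. The forbidden set is then $[1, L_p - 1] \times [L_y, R_y]$ together with $[R_p + 1, n] \times [L_y, R_y]$.
The mirrored rectangles are added as well, so that both (u, v) and (v, u) are covered.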
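Compute the area $S$ of the union of all these rectangles with a sweep line over a segment tree. The cover counts can be stored as permanent marks that are never pushed down. The $n$ diagonal points (u, u) are never covered and the covered set is symmetric, so the answer is $\frac{n^2 - S - n}{2} + n$.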
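Complexity: each color appears at most t times, so there are at most $n(t-1)/2$ same-colored pairs, and each pair produces O(1) rectangles. Finding p by binary lifting and each segment-tree update cost $O(\log n)$, for $O(nt \log n)$ overall.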
Code
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<vector>
#define fo(i,j,k) for(int i=j;i<=k;++i)
#define fd(i,j,k) for(int i=j;i>=k;--i)
#define rep(i,x) for(int i=ls[x];i;i=nx[i])
using namespace std;
typedef long long ll;
// Capacity limits. N bounds the node count (nodes are numbered 1..n).
// Every undirected tree edge is stored twice (u->v and v->u), so the edge
// pool must hold 2*(n-1) entries. The previous M=2e5+10 overflowed to[]/nx[]
// once n exceeded about 1e5.
const int N=2e5+10,M=2*N;
// Forward-star adjacency lists. ls[u] is the newest edge id leaving u
// (0 = none), to[e] is the edge's head, nx[e] is the next edge of the same
// tail, and num is the number of edges stored so far (ids start at 1).
int to[M],nx[M],ls[N],num=0;
// Append the directed edge u->v at the head of u's adjacency list.
void link(int u,int v){
    to[++num]=v,nx[num]=ls[u],ls[u]=num;
}
// One sweep-line event. When the sweep reaches column x, add v (+1 where a
// rectangle starts, -1 one column past its end) to the cover count of rows
// [l, r].
struct node{
    int x,l,r,v;
};
// Growable, 1-indexed event store (slots a[1..tot] are used).
// The former fixed array a[N*40] could overflow. Each same-coloured pair
// leads to 2 put calls, each emitting up to 2 rectangles of 2 events.
// That is 8 events per pair, or up to 4*(t-1)*n events in total
// (about 76n for t=20, above the 40n the array held).
// The existing uses keep working:
//   a[i] - reference to slot i, growing the store on demand (new slots zeroed)
//   a+k  - pointer to slot k
struct EventBuffer{
    std::vector<node> buf;
    // Slot 0 always exists, so a+k is a valid pointer even while tot==0.
    EventBuffer():buf(1){}
    node& operator[](int i){
        if(i>=(int)buf.size()){
            // Geometric growth keeps repeated appends amortised O(1).
            buf.resize(std::max<std::size_t>(2*buf.size(),(std::size_t)i+1));
        }
        return buf[i];
    }
    // Pointer to slot k, valid for k <= size(). This never reallocates, so
    // the begin/end pointers passed to one sort call stay consistent.
    friend node* operator+(EventBuffer& b,int k){
        return &b.buf[0]+k;
    }
} a;
// Order events by sweep column only.
bool cmp(const node& x,const node& y){
    return x.x<y.x;
}
vector<int> c[N];
int L[N],R[N],dep[N],tot=0;
int fa[N][20];
int n;
void pre(int x,int fr){
L[x]=++tot,dep[x]=dep[fr]+1;
rep(i,x){
int v=to[i];
if(v==fr) continue;
fa[v][0]=x;
pre(v,x);
}
R[x]=tot;
}
void putin(int x1,int x2,int y1,int y2){
if(x1>x2 || y1>y2) return;
a[++tot].x=x1,a[tot].l=y1,a[tot].r=y2,a[tot].v=1;
a[++tot].x=x2+1,a[tot].l=y1,a[tot].r=y2,a[tot].v=-1;
}
int get(int x,int t){
fd(i,17,0) if(dep[fa[x][i]]>dep[t]) x=fa[x][i];
return x;
}
// Emit (via putin) the rectangles of grid points (L[u], L[v]) for which u is
// positioned as the first coordinate and v as the second, and the tree path
// u..v passes through both x and y.
void put(int x,int y){
    // x is an ancestor of y: v must lie in y's subtree, and u anywhere outside
    // the subtree of p, the child of x on the way down to y.
    if(L[x]<=L[y] && R[y]<=R[x]){
        const int p=get(y,x);
        putin(1,L[p]-1,L[y],R[y]);
        putin(R[p]+1,n,L[y],R[y]);
        return;
    }
    // Mirror case: y is an ancestor of x, so the roles of the axes swap.
    if(L[y]<=L[x] && R[x]<=R[y]){
        const int p=get(x,y);
        putin(L[x],R[x],1,L[p]-1);
        putin(L[x],R[x],R[p]+1,n);
        return;
    }
    // Unrelated nodes: u in x's subtree and v in y's subtree.
    putin(L[x],R[x],L[y],R[y]);
}
// Segment tree over rows [1, n], used to measure the union of rectangles.
// lz[v]: number of active row ranges assigned exactly to node v. This is a
//        "permanent" mark that is never pushed down.
// tr[v]: number of rows in v's range covered by a mark on v or a descendant.
int tr[N<<2],lz[N<<2];
// Integer maximum (not referenced in the visible code).
int max(int x,int y){
    return x>y?x:y;
}
// Add t (+1 or -1) to the cover count of rows [x, y] within node v, which
// spans [l, r] (requires l <= x <= y <= r), then recompute tr[v].
void add(int v,int l,int r,int x,int y,int t){
    if(x==l && y==r) lz[v]+=t;
    else {
        int mid=(l+r)>>1;
        if(y<=mid) add(v<<1,l,mid,x,y,t);
        else if(x>mid) add((v<<1)+1,mid+1,r,x,y,t);
        else add(v<<1,l,mid,x,mid,t),add((v<<1)+1,mid+1,r,mid+1,y,t);
    }
    if(lz[v]>0) tr[v]=r-l+1;
    // A leaf has no children. Reading tr[2v] there (as the old code did)
    // indexes past tr[] for deep leaves (v >= 2N), which is undefined behaviour.
    else if(l==r) tr[v]=0;
    else tr[v]=tr[v<<1]+tr[(v<<1)+1];
}
int main()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
scanf("%d",&n);
fo(i,1,n){
int x;
scanf("%d",&x);
c[x].push_back(i);
}
fo(i,2,n){
int u,v;
scanf("%d %d",&u,&v);
link(u,v),link(v,u);
}
pre(1,0);
fo(j,1,17)
fo(i,1,n) fa[i][j]=fa[fa[i][j-1]][j-1];
tot=0;
fo(i,1,n){
int o=c[i].size();
if(o<2) continue;
fo(j,1,o-1)
fo(k,0,j-1) put(c[i][k],c[i][j]),put(c[i][j],c[i][k]);
}
sort(a+1,a+tot+1,cmp);
ll ans=0;
int p=0;
fo(i,1,n+1){
while(p<tot && a[p+1].x==i){
p++;
add(1,1,n,a[p].l,a[p].r,a[p].v);
}
ans+=tr[1];
}
ans=(ll)n*n-ans;
printf("%lld",(ans-n)/2+n);
}