牛客练习赛81 D 小 Q 与树 (权值线段树+dsu on tree)

tp链接

在这里插入图片描述
min(a[u],a[v])*dis(u,v)这个式子带min函数,dis函数,都比较麻烦,肯定需要化简的。

trick:

  • dis(u,v)可以引入LCA,转化成dis(1,u) + dis(1,v) - 2*dis(1,LCA)
  • 对于min函数的处理,我们可以分类讨论,在rt子树中,把点权大于等于a[rt]的节点分为一类,小于a[rt]的节点分一类
  • 于是可以把式子写成:
    对于
    a x 1 , a x 2 , . . . a x c n t 1 > = a u a_{x1},a_{x2},...a_{xcnt1} >= a_{u} ax1,ax2,...axcnt1>=au
    贡献为:
    a u ∗ ( d e p [ a u ] + d e p [ a x 1 ] − 2 ∗ d e p [ L C A ] ) a_{u}*(dep[a_{u}]+dep[a_{x1}]-2*dep[LCA]) au(dep[au]+dep[ax1]2dep[LCA])
    + a u ∗ ( d e p [ a u ] + d e p [ a x 2 ] − 2 ∗ d e p [ L C A ] ) +a_{u}*(dep[a_{u}]+dep[a_{x2}]-2*dep[LCA]) +au(dep[au]+dep[ax2]2dep[LCA])

    + a u ∗ ( d e p [ a u ] + d e p [ a x c n t 1 ] − 2 ∗ d e p [ L C A ] ) +a_{u}*(dep[a_{u}]+dep[a_{xcnt1}]-2*dep[LCA]) +au(dep[au]+dep[axcnt1]2dep[LCA])
    整理得到:
    在这里插入图片描述

对于:
a y 1 , a y 2 , . . . a y c n t 2 < a u a_{y1},a_{y2},...a_{ycnt2} < a_{u} ay1,ay2,...aycnt2<au
贡献为:
在这里插入图片描述
整理得到:
在这里插入图片描述

所以我们可以对点权建权值线段树,维护4个变量:
①Σa[i] (点权为a[i]的点的点权和)
②Σcnt (点权为a[i]的点的个数)
③Σdep[i] (点权为a[i]的点的深度和)
④Σa[i]*dep[i] (点权为a[i]的点的点权 * 深度之和)

然后枚举LCA作为子树根节点,计算每个LCA的贡献即可。

用树状数组1000ms就行。
卑微线段树常数大,离散化+动态开点跑了1600ms…

#include<bits/stdc++.h>
using namespace std;
//#pragma GCC optimize(2)
#define ull unsigned long long
#define ll long long
#define pii pair<int, int>
#define pdd pair<double, double>
#define re register
const int maxn = 2e5 + 10;
const ll mod = 998244353;
const ll inf = (ll)4e17+5;
const int INF = 1e9 + 7;
const double pi = acos(-1.0);
ll inv(ll b){
    
    if(b==1)return 1;return(mod-mod/b)*inv(mod%b)%mod;}
inline ll read()
{
    
    
    ll x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){
    
    if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){
    
    x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
//给定带点权的有根树 求 Σmin(a[i],b[j])*dis(i,j)
//分类讨论 枚举根节点作为LCA 求贡献 
vector<int> g[maxn];
int n,n0;
ll a[maxn],b[maxn],idx[maxn];
int in[maxn],pos[maxn],clk,son[maxn],siz[maxn],dep[maxn];
ll ret;
//权值线段树模板 
struct node 
{
    
    
	int cnt;
	ll sum_d,sum_a,sum_ad;
	node operator +(const node &f)const 
	{
    
    
		node t;
		t.cnt=cnt+f.cnt;
		t.sum_a=(sum_a+f.sum_a)%mod;
		t.sum_d=(sum_d+f.sum_d)%mod;
		t.sum_ad=(sum_ad+f.sum_ad)%mod;
		return t;
	}
}tree[maxn*40];
int rt_node=0,cnt=0,lc[maxn*40],rc[maxn*40];//动态开点 rt_node用于传引用
inline void pushup(int rt)
{
    
    
	tree[rt]=tree[lc[rt]]+tree[rc[rt]];
}
inline void upd(int &rt,int l,int r,int pos,int v,int f) //加入顶点v f为1或-1  表示加入或删除
{
    
    
	if(!rt) rt=++cnt;
	if(l==r) 
	{
    
    
		tree[rt].cnt+=f;
		tree[rt].sum_a=((tree[rt].sum_a+f*a[v])%mod + mod)%mod;
		tree[rt].sum_d=((tree[rt].sum_d+f*dep[v])%mod + mod)%mod;
		tree[rt].sum_ad=((tree[rt].sum_ad+f*a[v]*dep[v]%mod)%mod + mod)%mod;
		return ;
	}
	int mid=l+r>>1;
	if(pos<=mid) upd(lc[rt],l,mid,pos,v,f);
	else upd(rc[rt],mid+1,r,pos,v,f);
	pushup(rt);
}
inline node qry(int rt,int l,int r,int vl,int vr)
{
    
    
	if(!rt || l>r)
	{
    
    
		node t={
    
    0,0,0,0};
		return t;
	}
	if(vl<=l && r<=vr) return tree[rt];
	int mid=l+r>>1;
	if(vr<=mid) return qry(lc[rt],l,mid,vl,vr);
	else if(vl>mid) return qry(rc[rt],mid+1,r,vl,vr);
	return qry(lc[rt],l,mid,vl,vr)+qry(rc[rt],mid+1,r,vl,vr);
}
//求重儿子+dfs序
void dfs1(int rt,int fa)
{
    
    
	dep[rt]=dep[fa]+1;
	siz[rt]=1;
	in[rt]=++clk;
	pos[clk]=rt;
	for(int i:g[rt])
	{
    
    
		if(i==fa) continue;
		dfs1(i,rt);
		siz[rt]+=siz[i];
		if(siz[i] > siz[son[rt]]) son[rt]=i;
	}
}
int LCA;
inline ll cal(int u) //分别计算4部分 注意取模可能出现负数
{
    
    
	ll ret=0;
	node t1=qry(1,1,n0,idx[u],n0),t2=qry(1,1,n0,1,idx[u]-1);
	ret=(ret + (1ll*dep[u]-2*dep[LCA]+mod) % mod * t1.cnt % mod * a[u] % mod) % mod;
	ret=(ret + a[u] * t1.sum_d % mod)%mod;

	ret=(ret + (1ll*dep[u]-2*dep[LCA]+mod) % mod * t2.sum_a % mod)%mod;
	ret=(ret+t2.sum_ad)%mod;
	return ret;
}
inline void add(int rt) 
{
    
    
	for(int i=in[rt];i<in[rt]+siz[rt];i++)
	{
    
    
		int u=pos[i];
		ret=(ret+cal(u))%mod;
	}
}
inline void up(int rt,int v)//子树每个点都加入
{
    
    
	for(int i=in[rt];i<in[rt]+siz[rt];i++)
	{
    
    
		int u=pos[i];
		upd(rt_node,1,n0,idx[u],u,v);
	}
}
void dfs2(int rt,int fa,bool ok) 
{
    
    
	for(int i:g[rt]) 
	{
    
    
		if(i==son[rt] || i==fa) continue;
		dfs2(i,rt,0);
	}
	if(son[rt]) dfs2(son[rt],rt,1);
	LCA=rt;
	upd(rt_node,1,n0,idx[rt],rt,1);
	ret=(ret+cal(rt))%mod;//根节点也会产生贡献 需要在upd之前加
	
	for(int i:g[rt]) 
	{
    
    
		if(i==son[rt] || i==fa) continue;
		add(i);
		up(i,1);
	}
	if(!ok) up(rt,-1);
}
int main()
{
    
    	
	scanf("%d",&n);
	for(int i=1;i<=n;i++) 
	{
    
    
		a[i]=read();
		b[i]=a[i];
	}
	sort(b+1,b+n+1);
	n0=unique(b+1,b+n+1)-b-1;//离散化
	for(int i=1;i<=n;i++)
	{
    
    
		idx[i]=lower_bound(b+1,b+n0+1,a[i])-b;
	}
	for(int i=1,u,v;i<n;i++)
	{
    
    
		u=read();
		v=read();
		g[u].push_back(v);
		g[v].push_back(u);
	}
	dfs1(1,0);
	dfs2(1,0,1);
	cout<<ret*2%mod<<'\n';
	return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_46030630/article/details/120673457