[XSY] 线图(树形DP、分类讨论)

线图

  • 如图,每个L(L(T))上的点对应T上的一条三点链
    在这里插入图片描述
  • 在连接L(L(T))上两点,当且仅当两点代表的三点链在T上有共边,且边权为 共边边权*2+非共边1边权+非共边2边权
    在这里插入图片描述
  • 在L(L(T))上从点u走到点v,等价于u代表的三点链在T上删掉自己的一条边,然后在剩下来的两个点上选一个点接一条边,转化为v代表的三点链,代价为 不动边边权*2+删边边权+接边边权
    在这里插入图片描述
  • 先考虑两个三点链在树上的最短路。此处不赘述,大体上的分类讨论如图:
    在这里插入图片描述
  • 拓展到求任意两三点链的最短路径总和,可以用树形DP实现,考虑如何做到不重不漏:
    1.首先每对不相交三点链的贡献可以拆成两部分:树上最短路径的贡献+三点链的贡献。三点链的贡献只与树上最短路径连接的是三点链的中点还是端点有关,与具体选择什么样的最短路径无关:
    在这里插入图片描述
    一条边作为树上最短路径的一部分时,贡献永远是自身边权*4
    所以我们可以对每条边分别讨论它 作为三点链的一部分的贡献 和 作为树上最短路的一部分的贡献,再把这两部分的贡献加起来。
    2.再考虑相交的三点链对。
    对于X型,我们对每条边讨论它为四边中第1小、第2小、第3小、第4小边时自身的贡献,再把这些贡献加起来。
    对于Y型,边(u,v)(u=fa[v],u的度数为d)作为Y型的共边出现的情况有 ( d − 1 ) ( d − 2 ) / 2 (d-1)(d-2)/2 (d1)(d2)/2种,作为非共边出现的情况有 ( d − 1 ) ( d − 2 ) (d-1)(d-2) (d1)(d2)种,我们在扫到这条边时直接给答案加上 ( d − 1 ) ( d − 2 ) / 2 × 4 × 边 权 (d-1)(d-2)/2\times4\times边权 (d1)(d2)/2×4×即可。
    对于Z型,我们在共边处统计出整个Z型的贡献。

ps:为保证不重不漏地考虑到所有的三点链,我们在DP到树节点u时,就只考虑以u为中点的三点链

#include<iostream>
#include<cstdio>
#include<algorithm>
#define int long long
using namespace std;
const int N=5e5+5;
const int mod=998244353;
int read(){
    
    
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
    
    if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){
    
    x=x*10+ch-'0';ch=getchar();}
    return x*f;
}
int add(int a,int b){
    
    return (a+b)%mod;} 
int dec(int a,int b){
    
    return ((a-b)%mod+mod)%mod;}
void Add(int &a,int b){
    
    a=add(a,b);}
struct Edge{
    
    
	int v,w,nxt;
}e[N<<1];
int n,d[N],head[N],cnt=0,ans;
int f[N],g[N];
//f[i]:i子树内三点链的个数
//g[i]:i子树外三点链的个数
int addedge(int u,int v,int w){
    
    
	e[++cnt].v=v;e[cnt].w=w;e[cnt].nxt=head[u];head[u]=cnt;
}
void dfs(int u,int ff){
    
    
    f[u]=1ll*d[u]*(d[u]-1)/2%mod;
    for(int i=head[u];i;i=e[i].nxt){
    
    
    	if(e[i].v!=ff){
    
    
    		dfs(e[i].v,u);
            Add(f[u],f[e[i].v]);
		}
	}
}
int to[N],val[N],Sw[N];
int su[N],sv[N];
int pre[N],suf[N];
bool cmp(int a,int b){
    
    
	return val[a]<val[b];
}
void work(int u,int ff){
    
    
    for(int i=head[u];i;i=e[i].nxt){
    
    
    	int v=e[i].v;
    	if(v==ff) continue;
    	work(v,u);
	}
    int tot=0;
    for(int i=head[u];i;i=e[i].nxt){
    
    
    	int v=e[i].v;
    	to[++tot]=v;
    	val[v]=e[i].w; 
    	Add(Sw[u],e[i].w);
    	su[v]=v==ff?f[u]:g[v];
    	sv[v]=v==ff?g[u]:f[v];
	}
    sort(to+1,to+tot+1,cmp);
    pre[0]=suf[tot+1]=0; 
    for(int i=1;i<=tot;i++) pre[i]=add(pre[i-1],sv[to[i]]);
    for(int i=tot;i;i--) suf[i]=add(suf[i+1],sv[to[i]]);
    for(int i=1;i<=tot;i++){
    
    
        int v=to[i],du=d[u],dv=d[v],w=val[v];
        //处理这条边为相交的三点链做的贡献 
        Add(ans,1ll*(tot-i)*(tot-i-1)*(tot-i-2)/6*9%mod*w%mod);//以u为中心的X型,当前边为第1小边 
        Add(ans,1ll*(i-1)*(tot-i)*(tot-i-1)/2*7%mod*w%mod);//以u为中心的X型,当前边为第2小边 
        Add(ans,1ll*(i-1)*(i-2)/2*(tot-i)*5%mod*w%mod);//以u为中心的X型,当前边为第3小边 
        Add(ans,1ll*(i-1)*(i-2)*(i-3)/6*3%mod*w%mod);//以u为中心的X型,当前边为第4小边 
        Add(ans,1ll*(tot-1)*(tot-2)/2*4%mod*w%mod);//以u为中心的Y型
        if(v!=ff) Add(ans,1ll*(du-1)*(dv-1)*2%mod*w%mod+add(1ll*(Sw[u]-w)*(dv-1)%mod,1ll*(Sw[v]-w)*(du-1)%mod));//以u为中心的Z型 
        //处理这条边为不相交的三点链做的贡献 
        if(v!=ff) Add(ans,1ll*(sv[v]-(dv-1))*(su[v]-(du-1))*4%mod*w%mod);//这条边在树上最短路径中,注意if 
        Add(ans,1ll*dec(su[v],1ll*du*(du-1)/2%mod)*(tot-2)%mod*w%mod); 
        Add(ans,1ll*dec(su[v],pre[i-1]+1ll*du*(du-1)/2%mod)*(tot-i-1)%mod*2*w%mod);
        Add(ans,1ll*dec(su[v],suf[i+1]+1ll*du*(du-1)/2%mod)*(tot-i)%mod*2*w%mod);//树上最短路径连中点u
		Add(ans,(1ll*(tot-1)*3*w%mod+1ll*(Sw[u]-w))*(sv[v]-(dv-1))%mod);//树上最短路径连端点v 
    }
}
signed main(){
    
    
    n=read();
    for(int i=1;i<n;i++){
    
    
    	int u=read(),v=read(),w=read();
    	addedge(u,v,w);
    	addedge(v,u,w);
    	d[u]++;d[v]++;
	}
    dfs(1,0);
    for(int i=1;i<=n;i++) g[i]=f[1]-f[i];
    work(1,0);
    cout<<dec(ans,0)<<endl;
    return 0;
}

猜你喜欢

转载自blog.csdn.net/Emma2oo6/article/details/114396147