[CF915F]Imbalance Value of a Tree

[CF915F]Imbalance Value of a Tree

题目大意:

一棵\(n(n\le10^6)\)个结点的树,每个结点有一个权值\(w_i\)。定义\(I(i,j)\)\(i\)\(j\)之间简单路径上最大权值与最小权值之差,求\(\displaystyle\sum_{i=1}^n\sum_{j=1}^nI(i,j)\)

思路:

分别计算路径最大权值之和与最小权值之和。以最大权值之和为例,在图中按权值从大到小枚举每一个点,则对于该连通块中每一个经过该点的路径,该点为路径上权值最大的点,可以计算该点对答案的贡献,并将计算完贡献的点从图中删去。

由于删点是一个难以实现的操作,因此可以将操作改为加点操作,用并查集维护连通性即可。最小权值和同理。时间复杂度\(\mathcal O(n\log n)\)

源代码:

#include<cstdio>
#include<cctype>
#include<numeric>
#include<algorithm>
#include<forward_list>
using int64=long long;
inline int getint() {
    register char ch;
    while(!isdigit(ch=getchar()));
    register int x=ch^'0';
    while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0');
    return x;
}
constexpr int N=1e6+1;
int seq[N],pos[N],w[N];
std::forward_list<int> e[N];
inline void add_edge(const int &u,const int &v) {
    e[u].emplace_front(v);
    e[v].emplace_front(u);
}
struct DisjointSet {
    int anc[N],size[N];
    void reset(const int &n) {
        std::fill(&size[1],&size[n+1],1);
        std::iota(&anc[1],&anc[n+1],1);
    }
    int find(const int &x) {
        return x==anc[x]?x:anc[x]=find(anc[x]);
    }
    void merge(const int &x,const int &y) {
        size[find(y)]+=size[find(x)];
        anc[find(x)]=find(y);
    }
};
DisjointSet s;
int main() {
    const int n=getint();
    for(register int i=1;i<=n;i++) w[i]=getint();
    for(register int i=1;i<n;i++) {
        add_edge(getint(),getint());
    }
    int64 max=0,min=0;
    s.reset(n);
    std::iota(&seq[1],&seq[n+1],1);
    std::sort(&seq[1],&seq[n+1],[](const int &a,const int &b){return w[a]<w[b];});
    for(register int i=1;i<=n;i++) pos[seq[i]]=i;
    for(register int i=1;i<=n;i++) {
        const int &x=seq[i];
        int64 last=1,tmp=0;
        for(register auto &y:e[x]) {
            if(pos[y]>pos[x]) continue;
            tmp+=last*s.size[s.find(y)];
            last+=s.size[s.find(y)];
            s.merge(x,y);
        }
        max+=(int64)w[x]*tmp;
    }
    s.reset(n);
    std::iota(&seq[1],&seq[n+1],1);
    std::sort(&seq[1],&seq[n+1],[](const int &a,const int &b){return w[a]>w[b];});
    for(register int i=1;i<=n;i++) pos[seq[i]]=i;
    for(register int i=1;i<=n;i++) {
        const int &x=seq[i];
        int64 last=1,tmp=0;
        for(register auto &y:e[x]) {
            if(pos[y]>pos[x]) continue;
            tmp+=last*s.size[s.find(y)];
            last+=s.size[s.find(y)];
            s.merge(x,y);
        }
        min+=(int64)w[x]*tmp;
    }
    printf("%lld\n",max-min);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/skylee03/p/9107833.html