链接:https://www.nowcoder.com/acm/contest/123/D
来源:牛客网
题意:
在根节点为0的树上(编号:0,1,2...n)。一条边上有两个值 poweri ,numi,相当于网络流:poweri相当于路上的cost,numi相当于容量,相当于汇点为0,源点为所有叶子节点,求最大费用。
样例:
7
0 100 0
1 2 3
2 2 5
1 5 1
2 1 3
3 2 4
4 3 2
输出:
33分析:想到可以贪心优先分配路上总power最大的叶子节点,然后求路上最小值mi,ans加上power*mi,然后把路径上的num都减去最小值mi,重复下去。然后用树链剖分维护就行。(看别人代码不用树链刨分直接修改也能过)。
#include<iostream> #include<stdio.h> #include<algorithm> #include<string.h> #include<vector> using namespace std; typedef long long int ll; const int inf = 1e9; const int maxn = 100000; int n,son[maxn+5],sz[maxn+5],fa[maxn+5],top[maxn+5],tid[maxn+5],Rank[maxn+5],tim,dep[maxn+5]; int num[maxn+5],power[maxn+5],cost[maxn+5],st[maxn+5]; int mi[maxn*4],add[maxn*4]; vector<int>g[maxn+5]; void dfs1(int x,int f) { sz[x] = 1, fa[x] = f, son[x]=-1; int len = g[x].size(); for(int i=0; i<len; i++) { int y = g[x][i]; if(y==f) continue; dep[y] = dep[x] + 1; cost[y]+=cost[x]+power[y]; dfs1(y,x); sz[x]+=sz[y]; if(son[x]==-1||sz[son[x]]<sz[y]) son[x] = y; } } void dfs2(int x,int tp) { top[x] = tp; tid[x] = ++tim; Rank[tim] = x; if(son[x]!=-1) dfs2(son[x],tp); for(int i=0,len=g[x].size(); i<len; i++) { int y = g[x][i]; if(y!=son[x]&&y!=fa[x]) dfs2(y,y); } } void pushup(int o) { mi[o] = min(mi[o*2],mi[o*2+1]); } void build(int o,int l,int r) { add[o]=0; if(l==r) { mi[o] = num[Rank[l]]; return; } int mid = (l+r)>>1; build(o*2,l,mid); build(o*2+1,mid+1,r); pushup(o); } void pushdown(int o) { if(add[o]!=0) { int ls = o*2, rs = o*2+1; mi[ls]+=add[o], mi[rs]+=add[o]; add[ls]+=add[o], add[rs]+=add[o]; add[o] = 0; } } int query(int o,int l,int r,int L,int R) { if(L<=l&&r<=R) return mi[o]; int mid = (l+r)>>1; pushdown(o); if(R<=mid) return query(o*2,l,mid,L,R); else { if(mid<L) return query(o*2+1,mid+1,r,L,R); return min(query(o*2,l,mid,L,R),query(o*2+1,mid+1,r,L,R)); } } void updata(int o,int l,int r,int L,int R,int ad) { if(L<=l&&r<=R) { mi[o]+=ad; add[o]+=ad; } else { pushdown(o); int mid = (l+r)>>1; if(R<=mid) updata(o*2,l,mid,L,R,ad); else { if(mid<L) updata(o*2+1,mid+1,r,L,R,ad); else { updata(o*2,l,mid,L,R,ad); updata(o*2+1,mid+1,r,L,R,ad); } } pushup(o); } } int Find(int u,int v) { int f1 = top[u], f2 = top[v], tmp = inf; while(f1!=f2) { if(dep[f1]<dep[f2]) swap(f1,f2), swap(u,v); tmp = min(tmp,query(1,1,tim,tid[f1],tid[u])); u = fa[f1]; f1 = top[u]; } if(u==v) return tmp; if(dep[u]>dep[v]) swap(u,v); return min(tmp,query(1,1,tim,tid[son[u]],tid[v])); } void Updata(int va,int vb,int ad) { int f1 = top[va], f2 = top[vb]; while (f1 != f2) { if (dep[f1] < dep[f2]) { swap(f1, f2); swap(va, vb); } updata(1,1,tim,tid[f1],tid[va],ad); va = fa[f1]; f1 = top[va]; } if (va == vb) return; if (dep[va] > dep[vb]) swap(va, vb); updata(1, 1, tim, tid[son[va]], tid[vb],ad); } bool cmp(int x,int y) { return cost[x] > cost[y]; } int main() { scanf("%d",&n); for(int i=0; i<=n; i++) g[i].clear(); for(int i=1,f; i<=n; i++) { scanf("%d %d %d",&f,&num[i],&power[i]); g[f].push_back(i); } dep[0] = 0; dfs1(0,-1); tim = -1; dfs2(0,0); build(1,1,tim); int cnt = 0; for(int i=1; i<=n; i++) if(g[i].size()==0) st[++cnt] = i; sort(st+1,st+cnt+1,cmp); ll ans = 0; for(int i=1; i<=cnt; i++) { int x = st[i]; int w = Find(x,0); if(w<=0) continue; ans += cost[x]*1ll*w; Updata(x,0,-w); } printf("%lld\n",ans); return 0; }