洛谷 P3676 小清新数据结构题

https://www.luogu.org/problemnew/show/P3676

这题被我当成动态dp去做了,码了4k,搞了一个换根的动态dp

  1 #include<cstdio>
  2 #include<algorithm>
  3 #include<cstring>
  4 using namespace std;
  5 typedef long long ll;
  6 struct E
  7 {
  8     int to,nxt;
  9 }e[400011];
 10 int f1[200011],ne;
 11 struct P1
 12 {
 13     int len;ll a,b,c,d,e,f;
 14     //长度,(点权)和,后缀和之和,后缀和的平方之和,(答案)和
 15     //前缀和之和,前缀和的平方之和
 16 };
 17 struct P2
 18 {
 19     ll a,b;
 20     //点权和,答案和
 21 };
 22 ll a[200101];
 23 int sz[200101],hson[200101],ff[200101];
 24 int b[200101],pl[200101];
 25 int n,m;
 26 inline void merge(P1 &c,const P1 &a,const P1 &b)
 27 {
 28     c.len=a.len+b.len;
 29     c.a=a.a+b.a;
 30     c.b=b.b+a.b+b.a*a.len;
 31     c.c=b.c+b.a*b.a*a.len+2*a.b*b.a+a.c;
 32     c.d=a.d+b.d;
 33     c.e=a.e+b.e+a.a*b.len;
 34     c.f=a.f+a.a*a.a*b.len+2*b.e*a.a+b.f;
 35 }
 36 inline void initnode(P1 &c,const P2 &a)
 37 {
 38     c.len=1;c.a=c.b=c.e=a.a;c.c=c.f=a.a*a.a;c.d=a.b;
 39 }
 40 namespace S
 41 {
 42 #define lc (num<<1)
 43 #define rc (num<<1|1)
 44     P1 d[800101];
 45     inline void upd(int num){merge(d[num],d[lc],d[rc]);}
 46     P1 x;int L;
 47     void _setx(int l,int r,int num)
 48     {
 49         if(l==r)
 50         {
 51             d[num]=x;
 52             return;
 53         }
 54         int mid=(l+r)>>1;
 55         if(L<=mid)    _setx(l,mid,lc);
 56         else    _setx(mid+1,r,rc);
 57         upd(num);
 58     }
 59     P1 getx(int L,int R,int l,int r,int num)
 60     {
 61         if(L<=l&&r<=R)    return d[num];
 62         int mid=(l+r)>>1;
 63         if(L<=mid&&mid<R)
 64         {
 65             P1 x;
 66             merge(x,getx(L,R,l,mid,lc),getx(L,R,mid+1,r,rc));
 67             return x;
 68         }
 69         else if(L<=mid)
 70             return getx(L,R,l,mid,lc);
 71         else if(mid<R)
 72             return getx(L,R,mid+1,r,rc);
 73         else
 74             exit(-1);
 75     }
 76 }
 77 void dfs1(int u,int fa)
 78 {
 79     sz[u]=1;
 80     for(int v,k=f1[u];k;k=e[k].nxt)
 81         if(e[k].to!=fa)
 82         {
 83             v=e[k].to;
 84             ff[v]=u;
 85             dfs1(v,u);
 86             sz[u]+=sz[v];
 87             if(sz[v]>sz[hson[u]])    hson[u]=v;
 88         }
 89 }
 90 P2 d1[200101];//d1[i]维护i节点及其轻儿子的贡献
 91 P2 d2[200101];//d2[i]维护i节点(是重链顶)所在重链的值
 92 int tp[200101],dwn[200101];//链顶,链底
 93 inline void updd1(int x)
 94 {
 95     initnode(S::x,d1[x]);S::L=pl[x];S::_setx(1,n,1);
 96 }
 97 void dfs2(int u,int fa)
 98 {
 99     d1[u].a=a[u];
100     b[++b[0]]=u;pl[u]=b[0];
101     tp[u]=(u==hson[fa])?tp[fa]:u;
102     if(hson[u])    dfs2(hson[u],u);
103     dwn[u]=hson[u]?dwn[hson[u]]:u;
104     int v,k;
105     for(k=f1[u];k;k=e[k].nxt)
106         if(e[k].to!=fa&&e[k].to!=hson[u])
107         {
108             v=e[k].to;
109             dfs2(v,u);
110             d1[u].b+=d2[v].b;
111             d1[u].a+=d2[v].a;
112         }
113     updd1(u);
114     if(u==tp[u])
115     {
116         P1 t=S::getx(pl[u],pl[dwn[u]],1,n,1);
117         d2[u].a=t.a;d2[u].b=t.d+t.c;
118     }
119 }
120 inline ll getsize(int x)
121 {
122     return S::getx(pl[x],pl[dwn[x]],1,n,1).a;
123 }
124 int main()
125 {
126     int i,x,y,idx;ll z,ans,szall;P1 t;
127     scanf("%d%d",&n,&m);
128     for(i=1;i<n;++i)
129     {
130         scanf("%d%d",&x,&y);
131         e[++ne].to=y;e[ne].nxt=f1[x];f1[x]=ne;
132         e[++ne].to=x;e[ne].nxt=f1[y];f1[y]=ne;
133     }
134     for(i=1;i<=n;++i)    scanf("%lld",a+i);
135     dfs1(1,0);
136     dfs2(1,0);
137     while(m--)
138     {
139         scanf("%d",&idx);
140         if(idx==1)
141         {
142             scanf("%d%lld",&x,&z);
143             d1[x].a-=a[x];a[x]=z;d1[x].a+=z;
144             while(x)
145             {
146                 updd1(x);
147                 x=tp[x];y=ff[x];
148                 t=S::getx(pl[x],pl[dwn[x]],1,n,1);
149                 d1[y].a-=d2[x].a;d1[y].b-=d2[x].b;
150                 d2[x].a=t.a;d2[x].b=t.d+t.c;
151                 d1[y].a+=d2[x].a;d1[y].b+=d2[x].b;
152                 x=y;
153             }
154             //printf("3t%d\n",d2[1].b);
155         }
156         else
157         {
158             scanf("%d",&x);
159             ans=d2[1].b;
160             szall=getsize(1);
161             if(x!=tp[x])
162             {
163                 y=tp[x];
164                 z=d1[y].a;
165                 d1[y].a+=szall-getsize(y);
166                 updd1(y);
167                 if(y!=dwn[y])
168                 {
169                     t=S::getx(pl[y]+1,pl[dwn[y]],1,n,1);
170                     ans-=t.c;
171                 }
172                 if(x!=dwn[y])
173                 {
174                     t=S::getx(pl[x]+1,pl[dwn[y]],1,n,1);
175                     ans+=t.c;
176                 }
177                 t=S::getx(pl[y],pl[x]-1,1,n,1);
178                 ans+=t.f;
179                 d1[y].a=z;
180                 updd1(y);
181                 x=y;
182             }
183             while(x!=1)
184             {
185                 y=ff[x];
186                 z=getsize(x);
187                 ans-=z*z;
188                 z=szall-z;
189                 ans+=z*z;
190                 x=y;
191                 if(x!=tp[x])
192                 {
193                     y=tp[x];
194                     z=d1[y].a;
195                     d1[y].a+=szall-getsize(y);
196                     updd1(y);
197                     if(y!=dwn[y])
198                     {
199                         t=S::getx(pl[y]+1,pl[dwn[y]],1,n,1);
200                         ans-=t.c;
201                     }
202                     if(x!=dwn[y])
203                     {
204                         t=S::getx(pl[x]+1,pl[dwn[y]],1,n,1);
205                         ans+=t.c;
206                     }
207                     t=S::getx(pl[y],pl[x]-1,1,n,1);
208                     ans+=t.f;
209                     d1[y].a=z;
210                     updd1(y);
211                     x=y;
212                 }
213             }
214             printf("%lld\n",ans);
215         }
216     }
217     return 0;
218 }
View Code

码完一看题解,???好像画风不太对??

所以还是无视上面那个代码吧...

正常得多的做法:

  1 #include<cstdio>
  2 #include<algorithm>
  3 using namespace std;
  4 typedef long long ll;
  5 struct E
  6 {
  7     int to,nxt;
  8 }e[400011];
  9 int f1[200011],ne;
 10 int n,m;
 11 struct S
 12 {
 13 #define lowbit(x) ((x)&(-x))
 14     ll d1[200011],d2[200011];
 15     void _add(int p,ll x,ll *d)
 16     {
 17         for(;p<=n;p+=lowbit(p))
 18             d[p]+=x;
 19     }
 20     ll _sum(int p,ll *d)
 21     {
 22         ll ans=0;
 23         for(;p>0;p-=lowbit(p))
 24             ans+=d[p];
 25         return ans;
 26     }
 27     void add(int l,int r,ll x)
 28     {
 29         _add(l,x,d1);
 30         _add(r+1,-x,d1);
 31         _add(l,x*l,d2);
 32         _add(r+1,-x*(r+1),d2);
 33     }
 34     ll sum(int l,int r)
 35     {
 36         return (r+1)*_sum(r,d1)-_sum(r,d2)
 37             -l*_sum(l-1,d1)+_sum(l-1,d2);
 38     }
 39 }s1;
 40 int b[200011],pl[200011];
 41 ll a[200011],a2[200011];
 42 int sz[200011],hson[200011],tp[200011];
 43 ll dep[200011];
 44 int ff[200011];
 45 void dfs1(int u,int fa)
 46 {
 47     sz[u]=1;
 48     for(int k=f1[u];k;k=e[k].nxt)
 49         if(e[k].to!=fa)
 50         {
 51             ff[e[k].to]=u;
 52             dep[e[k].to]=dep[u]+1;
 53             dfs1(e[k].to,u);
 54             sz[u]+=sz[e[k].to];
 55             if(sz[e[k].to]>sz[hson[u]])    hson[u]=e[k].to;
 56         }
 57 }
 58 void dfs2(int u,int fa)
 59 {
 60     b[++b[0]]=u;pl[u]=b[0];
 61     tp[u]=u==hson[fa]?tp[fa]:u;
 62     a2[u]=a[u];
 63     if(hson[u])
 64     {
 65         dfs2(hson[u],u);
 66         a2[u]+=a2[hson[u]];
 67     }
 68     for(int k=f1[u];k;k=e[k].nxt)
 69         if(e[k].to!=fa&&e[k].to!=hson[u])
 70         {
 71             dfs2(e[k].to,u);
 72             a2[u]+=a2[e[k].to];
 73         }
 74 }
 75 inline ll gsum1(int x)//x到1的路径和
 76 {
 77     int y;ll an=0;
 78     for(;x;x=ff[y])
 79     {
 80         y=tp[x];
 81         an+=s1.sum(pl[y],pl[x]);
 82     }
 83     return an;
 84 }
 85 inline void add1(int x,ll z)//x到1加上z
 86 {
 87     int y;
 88     for(;x;x=ff[y])
 89     {
 90         y=tp[x];
 91         s1.add(pl[y],pl[x],z);
 92     }
 93 }
 94 ll anss;
 95 int main()
 96 {
 97     ll ans,z,t;
 98     int i,x,y,idx;
 99     scanf("%d%d",&n,&m);
100     for(i=1;i<n;++i)
101     {
102         scanf("%d%d",&x,&y);
103         e[++ne].to=y;e[ne].nxt=f1[x];f1[x]=ne;
104         e[++ne].to=x;e[ne].nxt=f1[y];f1[y]=ne;
105     }
106     for(i=1;i<=n;++i)
107         scanf("%lld",a+i);
108     dfs1(1,0);
109     dfs2(1,0);
110     for(i=1;i<=n;++i)
111     {
112         s1.add(pl[i],pl[i],a2[i]);
113         anss+=a2[i]*a2[i];
114     }
115     while(m--)
116     {
117         scanf("%d",&idx);
118         if(idx==1)
119         {
120             scanf("%d%lld",&x,&z);
121             z=z-a[x];a[x]+=z;
122             anss+=z*z*(dep[x]+1);
123             anss+=2*gsum1(x)*z;
124             add1(x,z);
125         }
126         else
127         {
128             scanf("%d",&x);
129             ans=anss;
130             t=gsum1(1);
131             ans+=dep[x]*t*t;
132             ans-=2*t*(gsum1(x)-t);
133             printf("%lld\n",ans);
134         }
135     }
136     return 0;
137 }
View Code

猜你喜欢

转载自www.cnblogs.com/hehe54321/p/10198047.html