[SDOI2017]切树游戏

description

洛谷
给一棵树,要求支持单点权值修改,以及询问树上有多少个连通块的权值异或和恰好为\(k\)
答案对\(1e4+7\)取模。

data range

\(2^m,q1,q2\)分别为最大权值,修改次数和询问次数。
\[n,q1+q2\le 3\times 10^4,2^m < 128 ,q2\le 10^4\]

solution

我们显然可以想到一个朴素的\(DP\)式:
\(f[u][s]\)表示异或和为\(s\),且以\(u\)为深度最浅的点的连通块个数,那么枚举树边\((u,v)(v\not=fa)\)
\[f'[u][s]=\sum_{t=0}^{2^m-1}f[u][t]\times f[v][s\oplus t]\]
其中\(\oplus\)表示二进制异或。
我们对于每一次修改都这样\(DP\)一遍,那么复杂度为\(O(qn2^{2m})\)

可以发现异或卷积可以使用\(FWT\)优化(考过哦),复杂度降为\(O(qnm2^m)\)

我们发现\(FWT\)后的数组更新是直接按位相乘/相加,于是我们只要对于初始化数组\(FWT\)一遍,之后

对于答案数组再\(FWT\)回去即可,复杂度将为\(O(nm2^m+q(n+m)2^m)\)

考虑优化我们的\(DP\)式。

考虑\(f[u]\)的生成函数\(f_u(x)=\sum_{i=0}^{2^m-1}a_ix^i\),那么
\[f_u(x)=\prod_{v\in son_u}x^{val_i}(f_v(x)+1)\]
其中多项式的积定义为异或卷积。

我们知道每次修改只有一条链的\(DP\)值会被修改,于是考虑用树链剖分维护。

考虑新开一个\(g_i(x)\)表示\(i\)的子树的\(f(x)\)之和,那么答案即为\(g_1(x)\)

维护轻儿子\(f(x)\)的乘积\(LF_i\)\(g(x)\)的和\(LG_i\),考虑一条从根到底的链\(c_1,c_2,...,c_k\),

那么我们有
\[ \begin{aligned} f_{c_i}(x)&=(x^{v_{c_i}}LF_{c_i})f_{c_{i+1}}(x)+(x^{v_{c_i}}LF_{c_i}) \\ g_{c_i}(x)&=LG_{c_i}+f_{c_i}(x)+g_{c_{i+1}}(x)=LG_{c_i}(x)+f_{c_i}(x)+g_{c_{i+1}}(x) \\ &=LG_{c_i}+(x^{v_{c_i}}LF_{c_i})+(x^{v_{c_i}}LF_{c_i})f_{c_{i+1}}(x)+g_{c_{i+1}}(x) \end{aligned} \]

这是什么?递推式啊!

对于递推式,我们可以使用矩阵快速幂进行优化。

\begin{equation}
(f_{c_{i+1}},g_{c_{i+1}},1){
\left( \begin{array}{ccc}
x^{v_{c_i}}LF_{c_i} & x^{v_{c_i}}LF_{c_i} & 0\
0 & 1 & 0\
x^{v_{c_i}}LF_{c_i} & x^{v_{c_i}}LF_{c_i}+LG_{c_i} & 1
\end{array}
\right )}=(f_{c_i},g_{c_i},1)
\end{equation}

然后树链剖分+线段树维护矩阵就可以了。

时间复杂度为\(O(n(m+logn)2^m+q(m+log^2n)2^m)\)

树剖\(log^2n\)跑不满,可以过。

最后注意维护\(LF_i\)的时候答案可能会除\(0\),维护一下\(0\)的个数就可以了。

Code

#include<algorithm>
#include<iostream>
#include<cstdlib>
#include<iomanip>
#include<cstring>
#include<complex>
#include<vector>
#include<cstdio>
#include<string>
#include<bitset>
#include<ctime>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<set>
#define Cpy(x,y) memcpy(x,y,sizeof(x))
#define Set(x,y) memset(x,y,sizeof(x))
#define FILE "4911"
#define mp make_pair
#define pb push_back
#define RG register
#define il inline
using namespace std;
typedef unsigned long long ull;
typedef vector<int>VI;
typedef long long ll;
typedef double dd;
const int N=30010;
const int M=1e7+10;
const int base=26;
const dd eps=1e-6;
const int inf=1e9;
const ll INF=1ll<<60;
const ll P=100000;
#define mod (10007)
il ll read(){
  RG ll data=0,w=1;RG char ch=getchar();
  while(ch!='-'&&(ch<'0'||ch>'9'))ch=getchar();
  if(ch=='-')w=-1,ch=getchar();
  while(ch<='9'&&ch>='0')data=data*10+ch-48,ch=getchar();
  return data*w;
}
il void file(){
  srand(time(NULL)+rand());
  freopen(FILE".in","r",stdin);
  freopen(FILE".out","w",stdout);
}
int n,m,q,inv[mod],a[N],tmp[128],I[128];
int head[N],nxt[N<<1],to[N<<1],cnt;
int fa[N],sz[N],son[N],dep[N],top[N],bot[N],w[N],fw[N],cntw;
il void add(int u,int v){to[++cnt]=v;nxt[cnt]=head[u];head[u]=cnt;}
il void upd(int &a,int b){a+=b;if(a>=mod)a-=mod;}
il void dec(int &a,int b){if(b)upd(a,mod-b);}
il void fwt(int *a,int n,int opt){
  for(RG int i=1;i<n;i<<=1)
    for(RG int j=0,p=i<<1;j<n;j+=p)
      for(RG int k=0;k<i;k++){
    RG int x=a[j+k],y=a[i+j+k];a[i+j+k]=x;
    upd(a[j+k],y);dec(a[i+j+k],y);
    if(opt==-1)
      a[j+k]=1ll*a[j+k]*inv[2]%mod,a[i+j+k]=1ll*a[i+j+k]*inv[2]%mod;
      }
}
struct int0{int v,z;il void init(int x){v=x?x:1;z=x?0:1;}};
int0 operator *(int0 a,int b){b?a.v=1ll*a.v*b%mod:a.z++;return a;}
int0 operator /(int0 a,int b){b?a.v=1ll*a.v*inv[b]%mod:a.z--;return a;}
int0 operator *(int0 a,int0 b){a.v=1ll*a.v*b.v%mod;a.z+=b.z;return a;}
int0 operator /(int0 a,int0 b){a.v=1ll*a.v*inv[b.v]%mod;a.z-=b.z;return a;}
il int ret(int0 a){return a.z?0:a.v;}
int0 LF[N][128];int LH[N][128];
struct matrix{int s[4][128];int* operator [](int x){return s[x];}};
matrix operator *(matrix x,matrix y){
  matrix z;
  for(RG int i=0;i<m;i++){
    z.s[0][i]=1ll*x.s[0][i]*y.s[0][i]%mod;
    z.s[1][i]=1ll*x.s[0][i]*y.s[1][i]%mod;upd(z.s[1][i],x.s[1][i]);
    z.s[2][i]=1ll*x.s[2][i]*y.s[0][i]%mod;upd(z.s[2][i],y.s[2][i]);
    z.s[3][i]=1ll*x.s[2][i]*y.s[1][i]%mod;
    upd(z.s[3][i],x.s[3][i]);upd(z.s[3][i],y.s[3][i]);
  }
  return z;
}
#define ls (i<<1)
#define rs (i<<1|1)
#define mid ((l+r)>>1)
matrix sum[N<<2];
il void update(int i){sum[i]=sum[rs]*sum[ls];}
void insert(int i,int l,int r,int p){
  if(l==r){
    memset(tmp,0,sizeof(tmp));tmp[a[fw[l]]]=1;fwt(tmp,m,1);
    for(RG int j=0,x;j<m;j++){
      x=1ll*ret(LF[l][j])*tmp[j]%mod;
      sum[i][0][j]=sum[i][1][j]=sum[i][2][j]=sum[i][3][j]=x;
      upd(sum[i][3][j],LH[l][j]);
    }
    return;
  }
  if(p<=mid)insert(ls,l,mid,p);else insert(rs,mid+1,r,p);update(i);
}
matrix query(int i,int l,int r,int x,int y){
  if(x<=l&&r<=y)return sum[i];
  if(y<=mid)return query(ls,l,mid,x,y);if(mid<x)return query(rs,mid+1,r,x,y);
  return query(rs,mid+1,r,x,y)*query(ls,l,mid,x,y);
}
void dfs1(int u,int ff){
  fa[u]=ff;sz[u]=1;son[u]=0;dep[u]=dep[ff]+1;
  for(RG int i=head[u];i;i=nxt[i]){
    RG int v=to[i];if(v==ff)continue;
    dfs1(v,u);sz[u]+=sz[v];if(sz[son[u]]<sz[v])son[u]=v;
  }
}
void dfs2(int u,int tp){
  top[u]=tp;w[u]=++cntw;fw[cntw]=u;bot[u]=u;
  if(son[u]){dfs2(son[u],tp);bot[u]=bot[son[u]];}  
  for(RG int j=0;j<m;j++)LF[w[u]][j].init(I[j]);  
  RG matrix r;
  for(RG int i=head[u];i;i=nxt[i]){
    RG int v=to[i];if(v==fa[u]||v==son[u])continue;
    dfs2(v,v);r=query(1,1,n,w[v],w[bot[v]]);    
    for(RG int j=0;j<m;j++)upd(r.s[2][j],I[j]);
    for(RG int j=0;j<m;j++)
      LF[w[u]][j]=LF[w[u]][j]*r.s[2][j],upd(LH[w[u]][j],r.s[3][j]);
  }
  insert(1,1,n,w[u]);
}
il void change(int x,int y){
  RG matrix r;
  for(RG int u=top[x],ff;fa[u];u=top[fa[u]]){
    ff=w[fa[u]];r=query(1,1,n,w[u],w[bot[u]]);
    for(RG int j=0;j<m;j++)upd(r.s[2][j],I[j]);
    for(RG int j=0;j<m;j++)
      LF[ff][j]=LF[ff][j]/r.s[2][j],dec(LH[ff][j],r.s[3][j]);
  }
  a[x]=y;insert(1,1,n,w[x]);
  for(RG int u=top[x],ff;fa[u];u=top[fa[u]]){
    ff=w[fa[u]];r=query(1,1,n,w[u],w[bot[u]]);
    for(RG int j=0;j<m;j++)upd(r.s[2][j],I[j]);
    for(RG int j=0;j<m;j++)
      LF[ff][j]=LF[ff][j]*r.s[2][j],upd(LH[ff][j],r.s[3][j]);
    insert(1,1,n,ff);
  }
}
int main()
{
  n=read();m=read();inv[1]=1;I[0]=1;fwt(I,m,1);
  for(RG int i=2;i<mod;i++)inv[i]=mod-1ll*(mod/i)*inv[mod%i]%mod;
  for(RG int i=1;i<=n;i++)a[i]=read();
  for(RG int i=1,u,v;i<n;i++){u=read();v=read();add(u,v);add(v,u);}
  dfs1(1,0);dfs2(1,1);RG matrix r;q=read();
  for(RG int i=1,c,x,y;i<=q;i++){
    c=0;while(c!='Q'&&c!='C')c=getchar();
    if(c=='Q'){
      x=read();r=query(1,1,n,w[1],w[bot[1]]);
      memcpy(tmp,r.s[3],sizeof(tmp));fwt(tmp,m,-1);
      printf("%d\n",tmp[x]);
    }
    else{x=read();y=read();change(x,y);}
  }
  return 0;
}

猜你喜欢

转载自www.cnblogs.com/cjfdf/p/9750958.html