LOJ 2542 「PKUWC2018」随机游走 ——树上高斯消元(期望DP)+最值反演+fmt

题目:https://loj.ac/problem/2542

可以最值反演。注意 min 不是独立地算从根走到每个点的最小值,在点集里取 min ,而是整体来看,“从根开始走到点集中的任意一个点就停下”的期望步数。

设 f[ i ] 表示从根走到 i ,再走期望几步就能走到点集中的某个点。有 \( f[i]=\frac{1}{d[i]}\sum\limits_{j}(f[j]+1) \) ( j 是和 i 有边的点)

于是要“树上高斯消元”。其实就是尝试写成 \( f[i]=a[i]*f[st]+b[i] \) (st 是根)之类的形式,从而让系数的转移有方向,然后根据 \( a[st] \) 和 \( b[st] \) 算 \( f[st] \) 。

为了有方向,这里设 \( f[i]=a[i]*f[st]+b[i]*f[fa]+c[i] \) (有 \( a[i]*f[st] \) 是为了算 \( f[st] \) )

\( f[i]=\frac{1}{d[i]}f[fa]+\frac{1}{d[i]}+\frac{1}{d[i]}\sum\limits_{j \in child}f[j]+\frac{d[i]-1}{d[i]} \)

\( d[i]*f[i]=f[fa]+1+\sum\limits_{j \in child}a[j]f[st]+\sum\limits_{j \in child}b[j]f[i]+\sum\limits_{j \in child}c[j] \)

\( (d[i]-\sum\limits_{j \in child}b[j])f[i]=\sum\limits_{j \in child}a[j]f[1]+f[fa]+d[i]+\sum\limits_{j \in child}c[j] \)

然后对于每个点集可以树形DP地算。如果走到了点集中的点,那么 a[cr] 、b[cr] 、c[cr] 都是 0 ,并且直接 return 即可。

最值反演的时候求子集的和可以用 fmt 算。那个 -1 的系数只要在初值的时候体现一下就行了。

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
int rdn()
{
  int ret=0;bool fx=1;char ch=getchar();
  while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return fx?ret:-ret;
}
const int N=20,M=(1<<18)+5,mod=998244353;
int upt(int x){if(x<0)x+=mod;if(x>=mod)x-=mod;return x;}
int pw(int x,int k)
{int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;}

int n,st,hd[N],xnt,to[N<<1],nxt[N<<1],dg[N],a[N],b[N],c[N];
int bin[N],f[M],ct[M];
void add(int x,int y){to[++xnt]=y;nxt[xnt]=hd[x];hd[x]=xnt;dg[x]++;}
void dfs(int cr,int fa,int s)
{
  a[cr]=b[cr]=c[cr]=0;int tp=0;
  if(s&bin[cr-1])return;//////return!
  for(int i=hd[cr],v;i;i=nxt[i])
    if((v=to[i])!=fa)
      {
    dfs(v,cr,s);
    a[cr]=upt(a[cr]+a[v]); c[cr]=upt(c[cr]+c[v]); tp=upt(tp+b[v]);
      }
  tp=pw(upt(dg[cr]-tp),mod-2);
  a[cr]=(ll)a[cr]*tp%mod; b[cr]=tp; c[cr]=(ll)(c[cr]+dg[cr])*tp%mod;
}
void fmt()
{
  for(int i=1;i<bin[n];i<<=1)
    for(int s=0;s<bin[n];s++)
      if(s&i)f[s]=upt(f[s]+f[s^i]);
}
int main()
{
  n=rdn();int Q=rdn();st=rdn();
  for(int i=1,u,v;i<n;i++)
    u=rdn(),v=rdn(),add(u,v),add(v,u);
  bin[0]=1;for(int i=1;i<=n;i++)bin[i]=bin[i-1]<<1;
  for(int s=1;s<bin[n];s++)ct[s]=ct[s-(s&-s)]+1;
  for(int s=1;s<bin[n];s++)
    {
      dfs(st,0,s);f[s]=(ll)c[st]*pw(upt(1-a[st]),mod-2)%mod;
      if((ct[s]&1)==0)f[s]=upt(-f[s]);
    }
  fmt();
  while(Q--)
    {
      n=rdn();int s=0;
      for(int i=1;i<=n;i++)s|=bin[rdn()-1];
      printf("%d\n",f[s]);
    }
  return 0;
}

猜你喜欢

转载自www.cnblogs.com/Narh/p/10279703.html
今日推荐