UOJ #348 州区划分 —— 状压DP+子集卷积

题目:http://uoj.ac/problem/348

一开始可以 3^n 子集DP,枚举一种状态的最后一个集合是什么来转移;

设 \( f[s] \) 表示 \( s \) 集合内的点都划分好了,\( g[s] = \sum\limits_{i \in s} w[i] \)

那么 \( f[s] = \sum\limits_{d \subseteq s} \frac{f[s-d] * g[d]}{g[s]} \)

注意判断一个集合是否合法,不仅要判断每个点的度数,还要判断整个集合是否连通;

这样就可以过 n <= 15 的点了,UOJ上有30分;

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const xn=(1<<21)+5,xxn=25,xm=505,mod=998244353;
int rd()
{
  int ret=0,f=1; char ch=getchar();
  while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return f?ret:-ret;
}
int n,m,p,f[xn],g[xn],w[xxn],g2[xn];
int hd[xxn],ct,to[xm],nxt[xm],bin[xxn];
void add(int x,int y){to[++ct]=y; nxt[ct]=hd[x]; hd[x]=ct;}
int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;}
ll pw(ll a,int b){ll ret=1; for(;b;b>>=1,a=a*a%mod)if(b&1)ret=ret*a%mod; return ret;}
bool vis[xxn];
int dfs(int x,int s)
{
  vis[x]=1; int ret=1;
  for(int i=hd[x],u;i;i=nxt[i])
    if(!vis[u=to[i]]&&(s&bin[u-1]))ret+=dfs(u,s);
  return ret;
}
bool ck(int s)//
{
  int cnt=0;
  for(int x=1;x<=n;x++)
    {
      if(!(s&bin[x-1]))continue;
      int deg=0; cnt++;
      for(int i=hd[x];i;i=nxt[i])
    {
      if(s&bin[to[i]-1])deg++;
    }
      if(deg&1)return 1;
    }
  for(int i=1;i<=n;i++)vis[i]=0;
  for(int i=1;i<=n;i++)
    if(s&bin[i-1])return dfs(i,s)!=cnt;
}
int main()
{
  n=rd(); m=rd(); p=rd();
  bin[0]=1; for(int i=1;i<=n;i++)bin[i]=bin[i-1]*2;
  for(int i=1,x,y;i<=m;i++)x=rd(),y=rd(),add(x,y),add(y,x);
  for(int i=1;i<=n;i++)g[bin[i-1]]=rd();
  for(int s=0;s<bin[n];s++)g[s]=upt(g[s&(-s)]+g[s^(s&(-s))]);
  for(int s=0;s<bin[n];s++)g[s]=pw(g[s],p),g2[s]=pw(g[s],mod-2);
  for(int s=0;s<bin[n];s++)if(!ck(s))g[s]=0;
  int num=0;
  f[0]=1;
  for(int s=1;s<bin[n];s++)
    {
      for(int d=s;d;d=(s&(d-1)))//d=s
    f[s]=(f[s]+(ll)f[s^d]*g[d])%mod;
      f[s]=(ll)f[s]*g2[s]%mod;
    }
  printf("%d\n",f[bin[n]-1]);
  return 0;
}
3^n

关于FMT(其实和高维前缀和差不多)和子集卷积:https://www.cnblogs.com/Dance-Of-Faith/p/8818211.html

于是可以做子集卷积加速DP的过程。

代码如下:

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
int const xn=(1<<21)+5,xxn=25,xm=505,mod=998244353;
int rd()
{
  int ret=0,f=1; char ch=getchar();
  while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();}
  while(ch>='0'&&ch<='9')ret=ret*10+ch-'0',ch=getchar();
  return f?ret:-ret;
}
int n,m,p,f[xxn][xn],g[xxn][xn],w[xxn],g2[xn];
int hd[xxn],ct,to[xm],nxt[xm],bin[xxn],cnt[xn];
void add(int x,int y){to[++ct]=y; nxt[ct]=hd[x]; hd[x]=ct;}
int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;}
ll pw(ll a,int b){ll ret=1; for(;b;b>>=1,a=a*a%mod)if(b&1)ret=ret*a%mod; return ret;}
bool vis[xxn];
int dfs(int x,int s)
{
  vis[x]=1; int ret=1;
  for(int i=hd[x],u;i;i=nxt[i])
    if(!vis[u=to[i]]&&(s&bin[u-1]))ret+=dfs(u,s);
  return ret;
}
bool ck(int s)//
{
  int cnt=0;
  for(int x=1;x<=n;x++)
    {
      if(!(s&bin[x-1]))continue;
      int deg=0; cnt++;
      for(int i=hd[x];i;i=nxt[i])
    {
      if(s&bin[to[i]-1])deg++;
    }
      if(deg&1)return 1;
    }
  for(int i=1;i<=n;i++)vis[i]=0;
  for(int i=1;i<=n;i++)
    if(s&bin[i-1])return dfs(i,s)!=cnt;
}
int cal(int s){int ret=0; while(s)ret+=(s&1),s>>=1; return ret;}
void fmt(int *a,int tp)
{
  for(int d=1;d<bin[n];d<<=1)
    for(int s=0;s<bin[n];s++)
      if(s&d)a[s]=upt(a[s]+a[s^d]*tp);
}
int main()
{
  n=rd(); m=rd(); p=rd();
  bin[0]=1; for(int i=1;i<=n;i++)bin[i]=bin[i-1]*2;
  for(int i=1,x,y;i<=m;i++)x=rd(),y=rd(),add(x,y),add(y,x);
  for(int s=0;s<bin[n];s++)cnt[s]=cal(s);
  for(int i=1;i<=n;i++)g2[bin[i-1]]=rd();
  for(int s=0;s<bin[n];s++)g2[s]=upt(g2[s&(-s)]+g2[s^(s&(-s))]);
  for(int s=0;s<bin[n];s++)g[cnt[s]][s]=pw(g2[s],p),g2[s]=pw(g[cnt[s]][s],mod-2);
  for(int s=0;s<bin[n];s++)if(!ck(s))g[cnt[s]][s]=0;
  for(int i=1;i<=n;i++)fmt(g[i],1);
  f[0][0]=1; fmt(f[0],1);
  for(int i=1;i<=n;i++)
    {
      for(int j=0;j<=i;j++)
    for(int s=0;s<bin[n];s++)
      f[i][s]=(f[i][s]+(ll)f[j][s]*g[i-j][s])%mod;
      fmt(f[i],-1);
      for(int s=0;s<bin[n];s++)
    if(cnt[s]==i)f[i][s]=(ll)f[i][s]*g2[s]%mod;
    else f[i][s]=0;
      fmt(f[i],1);
    }
  fmt(f[n],-1);
  printf("%d\n",f[n][bin[n]-1]);
  return 0;
}

猜你喜欢

转载自www.cnblogs.com/Zinn/p/10258606.html