Lock (fft + 状压dp)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/cys460714380/article/details/79996924

Description
Yplusplus has a rotary password lock. The lock has n (n <= 50000) positions, and each position shows a digit from 0 to 3. Because the lock is rotary, the digits at each position are circular: 0 can become 1 or 3 with one rotation, and likewise 3 can become 0 or 2 with one rotation. In addition, the lock is a bit rusty, so some positions cannot be rotated. The lock opens when some positions have been rotated so that the lock contains all k (1 <= k <= 8) specified password strings as non-overlapping, contiguous substrings. Yplusplus wants to open the lock with as few rotations as possible, so he turns to you. If the lock cannot be opened, output -1.

Input
The first line contains an integer T — the number of test cases. For each case, the first line contains two integers n and k — the length of the lock and the number of password strings. The next line gives the current state of the lock. The line after that gives the rust state of the lock: 0 means the corresponding position is not rusty and 1 means it is rusty. The next k lines give the password strings. See the samples for details.

Output
Print a single integer for each case — the fewest rotations to open the lock. If it is impossible, print -1.

Sample Input 1
2
6 2
012321
000000
10
12
6 2
022332
110000
10
13
Sample Output 1
2
5
Hint
For the first sample, change 012321 to 012310. For the second sample, change 022332 to 021013.

Source
yplusplus

先计算每个密码串在每个位置匹配的代价
因为数字只有0-3 把每个数字分开来算代价加到一起
计算代价时，把密码串反转后构造成一个多项式，与锁上各位置的代价多项式相乘
那么从位置 i 开始匹配一个长度为 len 的密码串的代价，就是乘积多项式第 i+len-1 项的系数
得到代价之后做一个简单的状压dp

顺便学会了一下fft的使用
大概就是先取一个不小于 2n 的长度 2^k
然后对两个复数集进行构造
调用fft(x,2^k,1) fft(y,2^k,1)
然后将两个复数集对应相乘
再调用fft(x,2^k,-1)
对应位置就是多项式对应的结果了
因为是double所以需要注意处理误差

#include <iostream>
#include <algorithm>
#include <sstream>
#include <string>
#include <queue>
#include <cstdio>
#include <map>
#include <set>
#include <utility>
#include <stack>
#include <cstring>
#include <cmath>
#include <vector>
//#include <ctime>
#include <bitset>
#include <assert.h>
using namespace std;
// ---- Contest-template shorthands (most are unused in this solution) ----
#define pb push_back
// scanf shorthands: d = int, ld = long long, f = double; doubled/tripled letters read 2/3 values.
#define sd(n) scanf("%d",&n)
#define sdd(n,m) scanf("%d%d",&n,&m)
#define sddd(n,m,k) scanf("%d%d%d",&n,&m,&k)
#define sld(n) scanf("%lld",&n)
#define sldd(n,m) scanf("%lld%lld",&n,&m)
#define slddd(n,m,k) scanf("%lld%lld%lld",&n,&m,&k)
#define sf(n) scanf("%lf",&n)
#define sff(n,m) scanf("%lf%lf",&n,&m)
#define sfff(n,m,k) scanf("%lf%lf%lf",&n,&m,&k)
#define ss(str) scanf("%s",str)
// Print a variable named `ans`, which must be in scope where the macro is expanded.
#define ansn() printf("%d\n",ans)
#define lansn() printf("%lld\n",ans)
// Loop shorthands: r0 runs i over [0, n), r1 over [1, e], rn from e down to 1.
#define r0(i,n) for(int i=0;i<(n);++i)
#define r1(i,e) for(int i=1;i<=e;++i)
#define rn(i,e) for(int i=e;i>=1;--i)
// memset over `sizeof abc`: fills the whole object only when given a real array, not a pointer.
#define mst(abc,bca) memset(abc,bca,sizeof abc)
// Lowest set bit of a (e.g. for Fenwick trees).
#define lowbit(a) (a&(-a))
#define all(a) a.begin(),a.end()
#define pii pair<int,int>
#define pll pair<long long,long long>
#define mp(aa,bb) make_pair(aa,bb)
// Child indices rt*2 and rt*2+1 of a heap-indexed tree node `rt` (unused in the visible code).
#define lrt rt<<1
#define rrt rt<<1|1
#define X first
#define Y second
// pi computed as acos(-1); used by the FFT twiddle factors.
#define PI (acos(-1.0))
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
// mod and eps are not referenced in the visible code.
const ll mod = 1000000009 ;
const double eps=1e-9;
// "Infinity" sentinel. Every byte is 0x3f, so memset(..., 0x3f, ...) fills int arrays with exactly inf.
const int inf=0x3f3f3f3f;
//const ll infl = 100000000000000000;//1e17
// Row bound: the problem guarantees n <= 50000; the extra 20 is slack.
const int maxn=  5e4+20;
// Mask bound for the DP: must be >= 1<<k, i.e. >= 256 for k <= 8.
const int maxm = 4e2+20;
// Linear-time modular inverse recurrence (note only): inv[i] = (p - p/i) * inv[p % i] % p
//muv[i]=(p-(p/i))*muv[p%i]%p;
// Fast reader for one (optionally negative) decimal integer from stdin.
// Skips every character that cannot start a number, then consumes the digits.
// @param ret  receives the parsed value; left unchanged when -1 is returned.
// @return 1 on success, -1 if end of input is reached before a number starts.
// NOTE: not called anywhere in the visible code; kept as a template utility.
int in(int &ret) {
    // int, not char: EOF must stay distinguishable from a real byte, and a
    // plain char may be unsigned on some platforms.
    int c = getchar();
    // Skip separators, stopping at EOF instead of spinning forever when the
    // input ends with trailing whitespace.
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == EOF) return -1;
    const int sgn = (c == '-') ? -1 : 1;
    int value = (c == '-') ? 0 : (c - '0');
    while (c = getchar(), c >= '0' && c <= '9') value = value * 10 + (c - '0');
    ret = value * sgn;
    return 1;
}
// Minimal complex number used by the FFT below.
struct cpx {
    double x, y;  // real and imaginary parts
    // Zero-initialise so a default-constructed value is never indeterminate.
    cpx() : x(0), y(0) {}
    cpx(double _x, double _y) : x(_x), y(_y) {}
    cpx operator - (const cpx &b) const { return cpx(x - b.x, y - b.y); }
    cpx operator + (const cpx &b) const { return cpx(x + b.x, y + b.y); }
    cpx operator * (const cpx &b) const {
        return cpx(x * b.x - y * b.y, x * b.y + y * b.x);
    }
};

// Scratch buffers for the FFT convolution. The transform length used by the
// solution is the smallest power of two >= 2n; with n <= 50000 that is 131072,
// so these buffers are larger than strictly required.
cpx d[1111111], f[1111111];

// Permute y[0..len) into bit-reversed index order, the input order required by
// the iterative radix-2 FFT. len must be a power of two.
void change(cpx y[], int len) {
    for (int i = 1, j = 0; i < len; ++i) {
        // Advance j as a bit-reversed counter: clear the run of set high bits,
        // then set the next one. After this, j is the bit reversal of i.
        int bit = len >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(y[i], y[j]);  // swap each pair exactly once
    }
}

// In-place iterative radix-2 FFT.
// @param y   array of len complex values; overwritten with the transform.
// @param len transform length; must be a power of two.
// @param on  +1 for the forward transform (kernel exp(-2*pi*i/h)), -1 for the
//            inverse. The inverse divides every element by len, so
//            fft(a,len,1) followed by fft(a,len,-1) restores a.
void fft(cpx y[], int len, int on) {
    change(y, len);
    const double pi = std::acos(-1.0);
    for (int h = 2; h <= len; h <<= 1) {
        const double ang = -on * 2 * pi / h;
        const cpx wn(std::cos(ang), std::sin(ang));  // principal h-th root of unity
        for (int j = 0; j < len; j += h) {
            cpx w(1, 0);
            for (int k = j; k < j + h / 2; ++k) {
                const cpx u = y[k];
                const cpx t = w * y[k + h / 2];
                y[k] = u + t;
                y[k + h / 2] = u - t;
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        // Scale both parts. Scaling only .x would leave the imaginary parts of
        // the inverse transform len times too large.
        for (int i = 0; i < len; ++i) {
            y[i].x /= len;
            y[i].y /= len;
        }
    }
}
// ---- Per-test-case state: filled by main(), consumed by calw()/caldp() ----
string lock;        // current lock digits, one char '0'..'3' per position
string rust;        // per position: '1' = rusty (cannot be rotated), '0' = free
string pass[10];    // the k password strings (k <= 8 per the statement)
int len[10];        // len[p] == pass[p].length()
// dp[i][mask]: fewest rotations after deciding positions [0, i), with the passwords
// in `mask` already placed. NOTE(review): about 84 MB (maxn * maxm ints), although
// only (n+1) x (1<<k) <= 50001 x 256 entries are ever used.
int dp[maxn][maxm];
int w[maxn][10];    // w[j][p]: rotations needed to make pass[p] start at position j (inf if impossible)
int n,k;            // lock length and number of passwords
void calw() {
    mst(w,0);
    int len1 = 1;//fft length
    while(len1<2*n)len1<<=1;
    for(int p=0; p<k; ++p) {
        for(int i=0; i<=3; ++i) {
            int len2 = len[p];//oassword length
            for(int j=0; j<n; ++j) {
                if(lock[j]-'0'==i)d[j] = cpx(0,0);
                else {
                    if(rust[j]-'0')d[j] = cpx(inf,0);
                    else {
                        int x = lock[j] - '0';
                        int c = min((x - i + 4)%4,(i - x +4 )%4 );
                        d[j] = cpx(c,0);
                    }
                }
            }
            for(int j=n;j<len1;++j)d[j] = cpx(0,0);
            for(int j=0;j<len2;++j)f[len2-j-1] = cpx(((pass[p][j]-'0')==i),0);
            for(int j = len2;j<len1;++j)f[j] = cpx(0,0);
            fft(d,len1,1);
            fft(f,len1,1);
            for(int j =0;j<len1;++j)d[j] = d[j]*f[j];
            fft(d,len1,-1);
            for(int j=0;j<n;++j)
            {
                w[j][p] += (int)(d[j+len2-1].x + 0.5);
            }
        }
    }
//    for(int i=0;i<k;++i)
//    {
//        for(int j=0;j<n;++j)printf("%d%c",w[j][i]," \n"[j==n-1]);
//    }
}
int caldp()
{
    mst(dp,0x3f);
    dp[0][0] = 0;
    int sz = 1<<k;
    for(int i=0;i<n;++i)
    {
        for(int j=0;j<sz;++j)
        {
            if(dp[i][j]==inf)continue;
            dp[i+1][j] = min(dp[i+1][j],dp[i][j]);
            for(int l=0;l<k;++l)
            {
                if((j>>l)&1)continue;
                if(w[i][l]>=inf)continue;
                int len2 = len[l];
                if(len2+i>n)continue;
                int nt = j|(1<<l);
                dp[i+len2][nt] = min(dp[i+len2][nt],dp[i][j] + w[i][l]);
//                printf("dp %d %d  = %d to dp %d %d = %d\n",i,j,dp[i][j],i+len2,nt,dp[i][j] + w[i][l]);

            }
        }
    }
    int ans = inf;
    for(int i=0;i<=n;++i)
    {
        ans = min(ans,dp[i][sz-1]);
    }

    return ans;
}
int main() {
#ifdef LOCAL
    freopen("input.txt","r",stdin);
//    freopen("output.txt","w",stdout);
#endif // LOCAL

    int t;
    sd(t);
    while(t--) {
//        int n,k;
        sdd(n,k);
        cin>>lock>>rust;
        int tot = 0;
        r0(i,k)cin>>pass[i],len[i] = pass[i].length(),tot += len[i];
        if(tot>n) {
            puts("-1");
            continue;
        }
        calw();
        int ans = caldp();
        if(ans==inf)ans = -1;
        ansn();

    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/cys460714380/article/details/79996924