牛客多校round8 H-Playing games

题面:https://www.nowcoder.com/acm/contest/146/H

题目描述

Niuniu likes playing games. He has n piles of stones. The i-th pile has ai stones. He wants to play with his good friend, UinUin. Niuniu can choose some piles out of the n piles. They will play with the chosen piles of stones. UinUin takes the first move. They take turns removing at least one stone from one chosen pile. The player who removes the last stone from the chosen piles wins the game. Niuniu wants to choose the maximum number of piles so that he can make sure he wins the game. Can you help Niuniu choose the piles?

输入描述:

The first line contains one integer n (1 ≤ n ≤ 500000), which means the number of piles.
The second line describes the piles, containing n non-negative integers, a1 a2 … an, separated by a space. The integers are less than or equal to 500000.

输出描述:

Print a single line with one number, which is the maximum number of piles Niuniu can choose to make sure he wins. If Niuniu cannot always win whatever piles he chooses, print 0.

sol:

1.

根据 Nim 游戏的性质,复合游戏的 SG 值等于各堆(单一 Nim 游戏)SG 值的异或和。由于 UinUin 先手,Niuniu 必胜当且仅当所选石堆的异或和为 0。设全部石堆的异或和为 x,要保留尽可能多的石堆,就等价于删去尽可能少的石堆,使被删去的这些石堆异或和恰好为 x。

2.

考虑朴素的dp方程:dp[i][j] 表示选取i个数,xor和为j的方案数。
则有

$$ans = \min\{\, i \ge 1 \mid dp[i][x] > 0 \,\}$$(最终输出的是 $n - ans$)

转移方程为

$$dp[i][j] = \sum_{a \oplus b = j} dp[i-1][a] \cdot dp[1][b]$$

其中 $dp[1][b]$ 表示数值 $b$ 是否出现在某一堆中,即每次再选一个数。

3.

直接做看起来是 $O(n^3)$ 的。考虑线性基的构造过程可知,最多选择 $O(\log n)$ 个数就可以凑出 x。本题数值不超过 $500000 < 2^{19}$,所以最多 19 个,因此 $i$ 只需枚举到 $O(\log n)$。

4.

从 $dp[i]$ 转移到 $dp[i+1]$ 是一次异或卷积,可以通过 FWT 由 $O(n^2)$ 优化到 $O(n\log n)$,所以本题最终可以优化到 $O(n\log^2 n)$。

code:

#include <bits/stdc++.h>
using namespace std;

// Shared numeric alias and constants.
using ll = long long;
const int maxn = 2000000 + 10;
const int mod = 1000000007;           // prime modulus used by the modular helpers
const int inv2 = (mod + 1) / 2;       // modular inverse of 2: 2 * inv2 == 1 (mod `mod`)
const ll inf = 1500000000000000000LL; // 1.5e18

// Contest shorthands.
#define fi first
#define se second
#define pll pair<ll, ll>

// x <- x + y, reduced by a single subtraction of `mod`.
// Assumes 0 <= x, y < mod so one subtraction is enough.
void Add(ll &x, ll y)
{
    const ll sum = x + y;
    x = (sum >= mod) ? sum - mod : sum;
}

// x <- x * y, taking the remainder only when the product reaches `mod`.
// Assumes |x|, |y| < mod so the product fits in 64 bits.
void Mul(ll &x, ll y)
{
    const ll prod = x * y;
    x = (prod >= mod) ? prod % mod : prod;
}

// Subtracts `mod` once when x >= mod (x is assumed to be below 2 * mod).
void Mod(ll &x)
{
    x -= (x >= mod) ? mod : 0;
}

struct FWT
{
    int N;
    void init(int n)
    {
        N = 1;
        while (N < n)
            N <<= 1;
    }

    void FWT_or(ll *a, int opt)
    {
        for (int i = 1; i < N; i <<= 1)
            for (int p = i << 1, j = 0; j < N; j += p)
                for (int k = 0; k < i; ++k)
                    if (opt == 1)
                        a[i + j + k] = (a[j + k] + a[i + j + k]) % mod;
                    else
                        a[i + j + k] = (a[i + j + k] + mod - a[j + k]) % mod;
    }
    void FWT_and(ll *a, int opt)
    {
        for (int i = 1; i < N; i <<= 1)
            for (int p = i << 1, j = 0; j < N; j += p)
                for (int k = 0; k < i; ++k)
                    if (opt == 1)
                        a[j + k] = (a[j + k] + a[i + j + k]) % mod;
                    else
                        a[j + k] = (a[j + k] + mod - a[i + j + k]) % mod;
    }

    void fwt_xor(int *a, int op)
    {
        for (int i = 1; i < N; i <<= 1)
            for (int p = i << 1, j = 0; j < N; j += p)
                for (int k = 0; k < i; ++k)
                {
                    ll x = a[j + k], y = a[i + j + k];
                    a[j + k] = (x + y);
                    a[i + j + k] = (x - y);
                    if (op == -1)
                    {
                        a[j + k] = a[j + k] >> 1;
                        // Mul(a[j+k],inv2);
                        a[i + j + k] = a[i + j + k] >>1;
                        // Mul(a[i+j+k],inv2);
                    }
                }
    }
} fwt;

namespace IO{
    #define BUF_SIZE 100000
    #define OUT_SIZE 100000
    #define ll long long

    // ===== Input: stdin is read in BUF_SIZE chunks with fread. =====

    // Set to 1 once fread returns no more data; the read() functions bail out after that.
    bool IOerror=0;

    // Returns the next input byte, refilling the static buffer when it is
    // exhausted. At end of input sets IOerror and returns (char)-1.
    inline char nc(){
        static char buf[BUF_SIZE],*p1=buf+BUF_SIZE,*pend=buf+BUF_SIZE;
        if (p1==pend){
            p1=buf; pend=buf+std::fread(buf,1,BUF_SIZE,stdin);
            if (pend==p1){IOerror=1;return -1;}
        }
        return *p1++;
    }

    // Separator characters skipped before every token.
    inline bool blank(char ch){return ch==' '||ch=='\n'||ch=='\r'||ch=='\t';}

    // Shared signed-integer reader: skips blanks, accepts one leading '-',
    // then consumes decimal digits. Digits are accumulated as a negative value
    // so the type's minimum (e.g. INT_MIN) parses without overflowing.
    // x is left as 0 when input is already exhausted.
    template <class T>
    inline void read_integer(T &x){
        bool sign=0; char ch=nc(); x=0;
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (;ch>='0'&&ch<='9';ch=nc())x=x*10-(ch-'0');
        if (!sign)x=-x;
    }
    inline void read(int &x){read_integer(x);}
    inline void read(ll &x){read_integer(x);}

    // Reads an optionally signed decimal number with an optional fractional part
    // (no exponent syntax).
    inline void read(double &x){
        bool sign=0; char ch=nc(); x=0;
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
        if (ch=='.'){
            double tmp=1; ch=nc();
            for (;ch>='0'&&ch<='9';ch=nc())tmp/=10.0,x+=tmp*(ch-'0');
        }
        if (sign)x=-x;
    }

    // Reads one blank-delimited token into s and NUL-terminates it.
    // The caller must supply a buffer large enough for the token.
    inline void read(char *s){
        char ch=nc();
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        for (;!blank(ch)&&!IOerror;ch=nc())*s++=ch;
        *s=0;
    }

    // Reads the next non-blank character; c becomes -1 at end of input.
    inline void read(char &c){
        for (c=nc();blank(c);c=nc());
        if (IOerror){c=-1;return;}
    }

    // ===== Output: bytes are buffered and written to stdout with fwrite. =====
    struct Ostream_fwrite{
        char buf[OUT_SIZE];        // pending output bytes
        char *p1=buf;              // next free slot in buf
        char *pend=buf+OUT_SIZE;   // one past the end of buf

        Ostream_fwrite()=default;
        // A copy would share (and double-flush) the same pending bytes.
        Ostream_fwrite(const Ostream_fwrite&)=delete;
        Ostream_fwrite& operator=(const Ostream_fwrite&)=delete;
        // Whatever is still buffered is written out on destruction.
        ~Ostream_fwrite(){flush();}

        void out(char ch){
            if (p1==pend) flush();
            *p1++=ch;
        }

        // Works on the unsigned magnitude so LLONG_MIN does not overflow on negation.
        void print(ll x){
            unsigned long long mag=x<0 ? 0ULL-static_cast<unsigned long long>(x)
                                       : static_cast<unsigned long long>(x);
            char s[24],*s1=s;
            if (x<0) out('-');
            do{ *s1++=static_cast<char>('0'+mag%10); mag/=10; }while(mag);
            while(s1!=s) out(*--s1);
        }
        void print(int x){print(static_cast<ll>(x));}
        void println(int x){print(x);out('\n');}
        void println(ll x){print(x);out('\n');}

        // Prints x rounded half-up to y digits after the decimal point.
        // y is clamped to [0, 17], the range covered by the power-of-ten table.
        void print(double x,int y){
            static const ll mul[]={1,10,100,1000,10000,100000,1000000,10000000,100000000,
                1000000000,10000000000LL,100000000000LL,1000000000000LL,10000000000000LL,
                100000000000000LL,1000000000000000LL,10000000000000000LL,100000000000000000LL};
            if (y<0) y=0;
            if (y>17) y=17;
            if (x<-1e-12) out('-'),x=-x;
            x*=mul[y];
            ll x1=static_cast<ll>(std::floor(x)); if (x-std::floor(x)>=0.5)++x1;
            ll x2=x1/mul[y],x3=x1-x2*mul[y]; print(x2);
            if (y>0){
                out('.');
                // Left-pad the fractional digits with zeros up to y digits.
                for (int i=1;i<y&&x3*mul[i]<mul[y];out('0'),++i);
                print(x3);
            }
        }
        void println(double x,int y){print(x,y);out('\n');}
        void print(const char *s){while (*s)out(*s++);}
        void println(const char *s){while (*s)out(*s++);out('\n');}
        void flush(){if (p1!=buf){std::fwrite(buf,1,p1-buf,stdout);p1=buf;}}
    }Ostream;

    // Free-function forwarders to the global Ostream.
    inline void print(int x){Ostream.print(x);}
    inline void println(int x){Ostream.println(x);}
    inline void print(char x){Ostream.out(x);}
    inline void println(char x){Ostream.out(x);Ostream.out('\n');}
    inline void print(ll x){Ostream.print(x);}
    inline void println(ll x){Ostream.println(x);}
    inline void print(double x,int y){Ostream.print(x,y);}
    inline void println(double x,int y){Ostream.println(x,y);}
    inline void print(const char *s){Ostream.print(s);}
    inline void println(const char *s){Ostream.println(s);}
    inline void println(){Ostream.out('\n');}
    inline void flush(){Ostream.flush();}
    #undef ll
    #undef OUT_SIZE
    #undef BUF_SIZE
}  // namespace IO

// Exact Walsh-Hadamard (xor) transform on 64-bit integers; a.size() must be a
// power of two. With inverse == true every butterfly output is halved; the
// halving is exact because the input is always the forward transform of an
// integer vector (here: a pointwise product of two forward transforms).
static void walsh_hadamard(std::vector<long long> &a, bool inverse)
{
    const std::size_t len = a.size();
    for (std::size_t half = 1; half < len; half <<= 1)
        for (std::size_t start = 0; start < len; start += 2 * half)
            for (std::size_t k = start; k < start + half; ++k)
            {
                const long long u = a[k];
                const long long v = a[k + half];
                a[k] = u + v;
                a[k + half] = u - v;
                if (inverse)
                {
                    a[k] /= 2;
                    a[k + half] /= 2;
                }
            }
}

// Reads n piles. If their xor is already 0, prints n. Otherwise finds the
// smallest number of piles whose xor equals the total xor and prints
// n minus that number.
int main()
{
    int n = 0;
    IO::read(n);
    if (n < 0)
        n = 0;

    std::vector<int> piles(n);
    int target = 0; // xor of every pile
    int mx = 0;
    for (int &v : piles)
    {
        IO::read(v);
        target ^= v;
        mx = std::max(mx, v);
    }

    if (target == 0)
    {
        IO::print(n);
        IO::print('\n');
        return 0;
    }

    // Power-of-two length covering every value, hence every xor of values.
    std::size_t len = 1;
    while (len < static_cast<std::size_t>(mx) + 1)
        len <<= 1;

    // present[v] == 1 iff some pile has exactly v stones.
    std::vector<long long> present(len, 0);
    for (int v : piles)
        present[v] = 1;

    // reach[j] != 0 iff j is the xor of `removed` pile values.
    std::vector<long long> reach(present);
    walsh_hadamard(present, false);

    // Each step xor-convolves `reach` with `present`. All arithmetic is 64-bit:
    // forward values are bounded by len (<= 2^19), so products stay below
    // 2^38 and never overflow.
    int removed = 1;
    while (reach[target] == 0)
    {
        walsh_hadamard(reach, false);
        for (std::size_t i = 0; i < len; ++i)
            reach[i] *= present[i];
        walsh_hadamard(reach, true);
        for (long long &r : reach)
            r = (r != 0);
        ++removed;
    }

    IO::print(n - removed);
    IO::print('\n');
    return 0;
}

猜你喜欢

转载自blog.csdn.net/oWuHen12/article/details/81676998