异或序列 [set优化DP]

也许更好的阅读体验

\(\mathcal{Description}\)
有一个长度为 \(n\)的自然数序列 \(a\),要求将这个序列分成至少 \(m\) 个连续子段
每个子段的价值为该子段的所有数的按位异或
要使所有子段的价值按位与的结果最大,输出这个最大值

\(T\)组询问
\(T\leq 10,n,m\leq 1000,a_i\leq 2^{30}\)
\(\mathcal{Solution}\)
实际上数据范围可开大很多

我们贪心的一位一位的确定最终答案,即看当前考虑的位能否为\(1\)
\(s_i\)表示前\(i\)个数的异或和,\(\bigoplus\)表示异或
设当前考虑到了第\(b\)
\(res=ans|(1<<b)\)
一段区间\([j+1,i]\)如果是一个合法的区间,可以得到
\(\left(s_i\bigoplus s_j\right)\&res=res\)
于是我们得到了一个\(n^2log\)\(DP\)方程
\(f_i=max{f_i,f_j+1}\)其中\(\left(s_i\bigoplus s_j\right)=res\)
枚举位是\(log\)的,这样就可以\(AC\)此题了

实际这个\(DP\)可以进一步优化
\(\left(s_i\bigoplus s_j\right)\&res=res\)可以推出
\(\left(s_i \& res\right)\bigoplus \left(s_j\& res\right)=res\)
\(\Rightarrow s_i \& res=\left(s_j\& res\right)\bigoplus res\)
即要将\(s_i\)\(s_j\)这段作为一个子段必须满足上面的条件
因为题目是至少\(m\)段,所以分的越多越好
则我们可以考虑完\(s_i\)的最优答案后将\(s_i\bigoplus res\)作为第一关键字存进\(set\)
\(f_i=find(s_i\bigoplus res)\)
这样一次转移就是\(log\)
复杂度为\(nlog^2\)

\(\mathcal{Code}\)

/*******************************
Author:Morning_Glory
LANG:C++
Created Time:2019年10月26日 星期六 09时18分19秒
*******************************/
#include <cstdio>
#include <fstream>
#include <cstring>
#include <set>
#define mp make_pair
using namespace std;
const int maxn = 2003;
//{{{cin
struct IO{
    template<typename T>
    IO & operator>>(T&res){
        res=0;
        bool flag=false;
        char ch;
        while((ch=getchar())>'9'||ch<'0')   flag|=ch=='-';
        while(ch>='0'&&ch<='9') res=(res<<1)+(res<<3)+(ch^'0'),ch=getchar();
        if (flag)   res=~res+1;
        return *this;
    }
}cin;
//}}}
int n,m,T,ans;
int a[maxn],s[maxn];
set < pair<int,int> > v;
set < pair<int,int> > :: iterator it,nx;
//{{{solve
void solve (int x)
{
    int res=ans|(1<<x);
    bool flag;
    v.clear();
    for (int i=1;i<=n;++i){
        int val=s[i]&res;
        v.insert(mp(val,0));
        nx=it=v.lower_bound(mp(val,0));
        ++nx;
        while (nx!=v.end()&&nx->first==val){
            v.erase(it);
            it=nx,++nx;
        }
        if (it->second==0){
            if (val==res){
                v.insert(mp(val^res,1));
                if (i==n)   flag=it->second+1>=m;
            }
        }
        else{
            v.insert(mp(val^res,(it->second)+1));
            if (i==n)   flag=it->second+1>=m;
        }
    }
    if (flag)   ans=res;
}
//}}}
int main()
{
    cin>>T;
    while (T--){
        cin>>n>>m;
        ans=0;
        for (int i=1;i<=n;++i){
            cin>>a[i];
            s[i]=s[i-1]^a[i];
        }

        for (int i=29;~i;--i)   solve(i);
        printf("%d\n",ans);
    }
    return 0;
}

如有哪里讲得不是很明白或是有错误,欢迎指正
如您喜欢的话不妨点个赞收藏一下吧

猜你喜欢

转载自www.cnblogs.com/Morning-Glory/p/11743945.html