Codeforces 868F 分治优化Dp

原题链接:http://codeforces.com/problemset/problem/868/F

大致题意:给出有n(n<=10^5)个元素的序列,元素值ai<=n,需要将其分为m(m<=min( 20, n))段,每段的费用是∑( calc[i]-1)*calc[i]/2,其中calc[i]为元素值为i的个数。

我的理解,首先这题似乎和合并类DP很相似,如果定下f[ i ][ j ]为将前i个元素分割为j段的最小代价,有个显然地转移方程就是f[ i ] [ j ] = MAX{ f[ i1 ] [ j-1 ]+cost( i1+1,i)  }(i1<i)

然后似乎有个单调性可以试着证明一下:如果i1 是满足f[ i ][ j ]=f[ i1 ][ j-1 ]+cost( i1+1,i)的最小值,则对于使得f[ k ][ j ]=f[ i2 ][ j-1 ]+cost( i2+1 , k ) (k>i)的i2必然有i2>=i1。

因为对于所有的i0 < i1 , f[ i0 ][ j-1 ]+cost( i0+1 , i ) >f[ i1 ][ j-1 ]+cost( i1+1 , i ),然后显然有cost( i0+1 , k )  > cost( i1+1 , k )

于是就得到了 f[ i0 ][ j-1 ]+cost( i0+1 , k ) >f[ i1 ][ j-1 ]+cost( i1+1 , k ),即对于任意的j,f[ i ][ j ]的转移点对于i具有单调性

换种简洁点的表述就是,令f[ i ][ j ]=f[  from[ i ][ j ]  ][ j-1 ]+cost( from[ i ][ j ]+1 , i ),则from[ i ][ j ]>=from[ i0 ][ j ]( i0< i )

这题的精妙之处就在于巧妙的利用了这个单调性,如果要求得f[ i ][ j ](L<= i <= R )的值,则可以先求from[ (L+R)/2 ][ j ]的值,然后对于i<(L+R)/2,就有from[ (L+R)/2 ][ j ]>from[ i0 ][ j ] ,对于i>(L+R)/2 ,就有from[ (L+R)/2 ][ j ],然后再递归处理i在[ L , (L+R)/2 -1 ]和[ (L+R)/2+1 , R ]区间时的情况。显然这样递归的层数是logn层,每一层都有总长度为n的扫描[L, R ]区间求from[ (L+R)/2 ][ j ]的花销,这部分总耗时为O(m*n*log n )。

但是怎么高效地求cost( l , r )的值,考虑到既然都是暴力扫描求from[ (L+R)/2 ][ j ],如果可以相同复杂度的消耗来完成就好了。观察分治后的各个子段内部,因为是从左到右依次暴力枚举,而cost的值的增量只和已有的数值数量有关,只要维护当前枚举段内各值的出现次数同时把增量就可以了,于是对于每个字段,求cost也是O(子段的长度)的复杂度了,也就是说,求cost的复杂度和扫描求from[ (L+R)/2 ][ j ]的总复杂度一样都是O(m*n*log n )。

整体算法就完成了

代码:

#include <bits/stdc++.h>
using namespace std;
inline void read(int &x){
    char ch;
    bool flag=false;
    for (ch=getchar();!isdigit(ch);ch=getchar())if (ch=='-') flag=true;
    for (x=0;isdigit(ch);x=x*10+ch-'0',ch=getchar());
    x=flag?-x:x;
}

inline void read(long long &x){
    char ch;
    bool flag=false;
    for (ch=getchar();!isdigit(ch);ch=getchar())if (ch=='-') flag=true;
    for (x=0;isdigit(ch);x=x*10+ch-'0',ch=getchar());
    x=flag?-x:x;
}
inline void write(int x){
    static const int maxlen=100;
    static char s[maxlen];
        if (x<0) {   putchar('-'); x=-x;}
    if(!x){ putchar('0'); return; }
    int len=0; for(;x;x/=10) s[len++]=x % 10+'0';
    for(int i=len-1;i>=0;--i) putchar(s[i]);
}

const int MAXN = 120000;
const int MAXM = 22 ;
typedef long long ll;

int n , m;
int num[ MAXN ];
int calc[ MAXN ];
ll f[ MAXN ];
ll pre[ MAXN ];
ll st,ed,sum;

void solve(int ans_l,int ans_r,int aim_l,int aim_r){
if ( aim_l > aim_r )
    return ;
int mid=(aim_l+aim_r)/2;
int ans_m=ans_l;
ll tmp=1ll<<60;
for (int i=ans_l;i<=min(mid-1, ans_r);i++)
    {
        while ( st< i+1 ) calc[ num[st] ]-- , sum-=calc[ num[st] ] ,st++ ;
        while ( ed> mid ) calc[ num[ed] ]-- , sum-=calc[ num[ed] ] ,ed-- ;
        while ( st> i+1 ) st-- ,sum+=calc[ num[st] ] , calc[ num[st] ]++ ;
        while ( ed< mid )
        {
            ed++ ,sum+=calc[ num[ed] ] , calc[ num[ed] ]++ ;
            //printf("----%d  %d  %d %d\n",ed, mid ,num[ed],calc[ num[ed] ]);
        }
        //printf("%d   %d  %d  %d %d\n",st,ed,i,mid,sum);
        //printf("%d  %d  %d  %d\n",pre[i],sum,pre[ans_m],tmp);
        if ( pre[i] + sum < pre[ ans_m ] + tmp )
            ans_m=i,tmp=sum;
    }
f[mid]=pre[ans_m]+tmp;
//printf("%d --> %d\n",ans_m,mid);
solve( ans_l, ans_m , aim_l , mid-1 );
solve( ans_m, ans_r , mid+1 , aim_r );
}


int main(){
    read(n); read(m);
    for (int i=1;i<=n;i++)
        read(num[i]);
    pre[0]=0;
    for (int i=1;i<=n;i++)
        pre[i]=1ll<<60;
    sum=0;
    st=1;ed=1;sum=0;calc[ num[1] ]=1;
    while ( m--)
        {
            solve(0,n-1,1,n);
            for (int i=1;i<=n;i++)
                pre[i]=f[i];
            /*
            for (int i=1;i<=n;i++)
                printf("%d ",pre[i]);
            puts("");
            */
            memset(f,0,sizeof(f));
        }
    printf("%I64d\n",pre[n]);
    return 0;
}


猜你喜欢

转载自blog.csdn.net/u012602144/article/details/78307801