[HNOI2017]礼物

Description:

给定两个有n个数的序列,你可以将其中一个进行旋转(想象是在一个环上),或者对序列的每个数加上一个非负整数C
求操作后 \(\sum{(a_i-b_i)^2}\)的最小值

Description:

\(n<=5*10^4,m<=100,a_i<=m\)

Solution:

Description:

给定两个有n个数的序列,你可以将其中一个进行旋转(想象是在一个环上),或者对序列的每个数加上一个非负整数C
求操作后 \(\sum{(a_i-b_i)^2}\)的最小值

Description:

\(n<=5*10^4,m<=100,a_i<=m\)

Solution:

一眼看去,十分不可做,于是开始拆式子

\(\sum(a_i-b_i+C)^2\)
\(=\sum a_i^2 +\sum b_i^2+2*\sum (a_i-b_i)*C +n*C^2-2*\sum a_ib_i\)

由于 \(m\) 很小,我们考虑枚举 C

然后只要求出 \(2*\sum a_ib_i​\) 的最大值就行了

将 b 数组翻转

即求 $ \sum a_ib_{n-i+1}$ 最大值

如何求 ?

将 a 数组倍长

由卷积的定义,FFT后对于 n+1 到 2*n 得到的数就分别对应所有的旋转

checkmax 即可

#include<bits/stdc++.h>
using namespace std;
const int mxn=1e6+5;
const double PI=acos(-1);
int n,m,l,s1,s2,s3,lim=1,r[mxn],tp[mxn];
int ans,res=100000000;

struct cp {
    double x,y;
    cp(double xx=0,double yy=0) {x=xx,y=yy;}
    friend cp operator + (cp a,cp b) {
        return cp(a.x+b.x,a.y+b.y);
    }
    friend cp operator - (cp a,cp b) {
        return cp(a.x-b.x,a.y-b.y);
    }
    friend cp operator * (cp a,cp b) {
        return cp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);
    }
}a[mxn],b[mxn];

void FFT(cp *p,int opt) 
{
    for(int i=0;i<=lim;++i) 
        if(i<r[i]) swap(p[i],p[r[i]]);
    for(int mid=1;mid<lim;mid<<=1) {
        cp wn(cos(PI/mid),opt*sin(PI/mid));
        for(int len=mid<<1,j=0;j<lim;j+=len) {
            cp w(1,0);
            for(int k=0;k<mid;++k,w=w*wn) {
                cp x=p[j+k],y=w*p[j+mid+k];
                p[j+k]=x+y,p[j+mid+k]=x-y;
            }
        }
    }   
}

int main()
{
    scanf("%d%d",&n,&m); 
    for(int i=1;i<=n;++i) {
        scanf("%lf",&a[i].x);
        a[i+n].x=a[i].x;
        ans+=a[i].x*a[i].x;
        s1+=a[i].x;
    }
    for(int i=1;i<=n;++i) {
        scanf("%d",&tp[i]);
        ans+=tp[i]*tp[i];
        s2+=tp[i];
    }
    for(int i=1;i<=n;++i) {
        b[i].x=tp[n-i+1];
    }
    while(lim<=3*n) lim<<=1,++l;
    for(int i=0;i<lim;++i) 
        r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    FFT(a,1); FFT(b,1);
    for(int i=0;i<=lim;++i) a[i]=a[i]*b[i];
    FFT(a,-1);
    for(int i=n+1;i<=n*2;++i) s3=max(s3,(int )(a[i].x/lim+0.5));
    ans-=2*s3;
    for(int i=-m;i<=m;++i) res=min(res,n*i*i+2*(s1-s2)*i);
    printf("%d",res+ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/list1/p/10436489.html
今日推荐