「AH2017/HNOI2017」礼物

题目链接

戳我

\(Solution\)

应为我们可以将任意一个数列加上一个非负整数,即可以变为将一个数列加上一个整数(可以为负),我们将这个整数设为\(z\).所以要求的式子的变为:
\[\sum_{i=1}^{n}(x_i-y_i+z)^2\]
首先来化简一下式子
\[\sum_{i=1}^{n}(x_i-y_i+z)^2\]
\[\sum_{i=1}^{n}x_i^2+\sum_{i=1}^{n}y_i^2+\sum_{i=1}^{n}z_i^2+2z\sum_{i=1}^{n}(x_i-y_i)-2\sum_{i=1}^{n}x_i*y_i\]
我们可以发现不管如何变化\(\sum_{i=1}^{n}x_i^2+\sum_{i=1}^{n}y_i^2\)的值都不会变.
然后再看看\(\sum_{i=1}^{n}z_i^2+2z\sum_{i=1}^{n}(x_i-y_i)\)显然这是一个关于\(z\)的二次函数.最小值可以\(O(1)\)算出但是注意一下\(z\)必须是整数而且可以为负,所以需要将\(-\frac{x_i-y_i}{n}\)向上和向下取整并带入式子取最小值.
所以我们现在只需要算出\(2\sum_{i=1}^{n}x_i*y_i\)的最大值即可.
这个如何去算,将\(y\)变成链,在进行一次\(fft\)在取一个最大值就好了

\(Code\)

#include<bits/stdc++.h>
#define int long long
#define rg register
#define file(x) freopen(x".in","r",stdin);freopen(x".out","w",stdout);
using namespace std;
int read(){
    int x=0,f=1;
    char c=getchar();
    while(c<'0'||c>'9') f=(c=='-')?-1:1,c=getchar();
    while(c>='0'&&c<='9') x=x*10+c-48,c=getchar();
    return f*x;
}
const int N=3000001;
const double pi=3.14159265358979323846;
struct node {
    double x,y;
    node operator + (node z)const {
        return (node){x+z.x,y+z.y};
    }
    node operator - (node z)const {
        return (node){x-z.x,y-z.y};
    }
    node operator * (node z)const {
        return (node){x*z.x-y*z.y,y*z.x+x*z.y};
    }
}f[N],g[N];
int r[N],limit=1;
void fft(node *a,int opt){
    for(int i=0;i<=limit;i++)
        if(i<r[i])
            swap(a[i],a[r[i]]);
    for(int i=1;i<limit;i<<=1){
        node w=(node){cos(pi/i),opt*sin(pi/i)};
        for(int j=0;j<limit;j+=i<<1){
            node l=(node){1,0};
            for(int k=j;k<j+i;k++){
                node p=l*a[k+i];
                a[k+i]=a[k]-p;
                a[k]=a[k]+p;
                l=l*w;
            }
        }
    }
}
int a[N],b[N];
main(){
    int n=read(),k=read(),l=0,m=n,js=0,ans=0;
    for(int i=1;i<=n;i++)
        a[i]=f[i].x=read(),f[i+n].x=f[i].x,js+=a[i],ans+=a[i]*a[i];
    for(int i=1;i<=n;i++)
        b[i]=read(),js-=b[i],ans+=b[i]*b[i];
    reverse(b+1,b+n+1);
    for(int i=1;i<=n;i++)
        g[i].x=b[i];
    int A=-floor(js*1.0/n),B=-ceil(js*1.0/n);
    ans+=min(A*A*n+2*A*js,B*B*n+2*B*js);
    n*=2;
    while(limit<=n+m)
        limit<<=1,l++;
    for(int i=0;i<limit;i++)
        r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
    fft(f,1),fft(g,1);
    for(int i=0;i<=limit;i++)
        f[i]=f[i]*g[i];
    fft(f,-1);
    int maxx=0;
    for(int i=m;i<=n+1;i++)
        maxx=max((int)(f[i].x/limit+0.5),maxx);
    printf("%lld",ans-2*maxx);
}

猜你喜欢

转载自www.cnblogs.com/hbxblog/p/10328890.html