SGU421: K-th Product 题解

这题贼难写…但我觉得锻炼代码能力和思维都是非常好的
我们先考虑全是正数的情况,这种情况是比较简单的,考虑到 k 10000 ,这种题目的套路都是把一个初始状态扔进一个大根堆,每次取堆顶元素后把它能扩展到的元素扔进堆
刚开始的最大值肯定是取最大的前 m 个数,设初始序列是1.2.3…m
对于一个我当前取的序列,它能扩展到的序列一定是只有一位向后移了一格的一个序列,因为 m 13 ,所以状态数不会太多
注意数字可能会非常大,所以要写高精度(恶心)
下面考虑有负数的情况
我们发现可以把正数和负数分开来考虑;我们先考虑乘积是非负数,这时负数一定取偶数个,然后把正数和负数按照绝对值从大到小排序就和上面一样了,乘积是负数,则负数一定取奇数个,然后把正数和负数按绝对值从小到大排序就和上面一样了


貌似题目已经做完了,但噩梦才刚刚开始
假设你已经写出了一个不错的程序,这时你通常会得到tle on test 20
跑一跑极限数据,发现开了O2竟然要跑9s
于是就开始了漫漫卡常路(之后的时间都是开了O2测的)
首先考虑用哈希替换map,map是在向堆里加新元素用到的,因为你要保证加进去的元素不重复,而事实上一个序列是可能有多种方法被扩展出来的
写完哈希以后,成功从9s减到3s
然后考虑优化高精度的常数,这里有一个常见的小技巧
我们通常会开一个数组,每一个元素存一位数字,但事实上两个int在做乘法的时候,数字小并不会变快
所以我们改变一下策略,每个元素存7位数字,这样我的数字的最大位数就会大大减小,做乘法是 O ( n 2 ) 的,相当于优化了更多
这个地方有一个小细节要注意一下:每7位一存的话,除了最高位,其他的要补前导零
优化完高精度的常数,成功的从3s优化到2s(效果好像没有预期的明显啊)
考虑到每7位一存的话要开LL,运行速度会变慢,所以改成了每3位一存,这样int就能解决,优化到了1.4s(这是个玄学)


到了这里,卡常死活卡不动了,然而这题要求0.75s
我们尝试开始考虑优化算法(什么这个算法还能优化…)
这个哈希表其实是非常慢的,因为cmp一次的代价是 O ( m ) 的,尝试在这里做文章
我们为什么要哈希?因为同一个序列不能出现多次,而事实上同一个序列可能有很多种不同的扩展方式
那我们能不能通过限定扩展方式,使得序列不会重复出现?
是可以的,我们只要规定一个扩展的顺序,例如,对于正数和负数,我都规定如果这次你把第i个数向后移了一位得到了新的状态,那么下一次你不能再移动i+1~m
这样对于任意一个序列,他的扩展方式都是唯一的,只有先移最后一个元素到指定位置,然后移倒数第二个…以此类推
这样把哈希表省掉了,这是一个大优化
时间卡到了0.2s,不开O2 0.4s,OK可以过了


Code

#include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <utility>
#include <cctype>
#include <algorithm>
#include <bitset>
#include <set>
#include <map>
#include <vector>
#include <queue>
#include <deque>
#include <stack>
#include <cmath>
#define LL long long
#define LB long double
#define x first
#define y second
#define Pair pair<int,int>
#define pb push_back
#define pf push_front
#define mp make_pair
#define LOWBIT(x) x & (-x)
using namespace std;

const int MOD=998244353;
const LL LINF=2e16;
const int INF=1e9;
const int magic=348;
const double eps=1e-10;
const double pi=3.14159265;

inline int getint()
{
    char ch;int res;bool f;
    while (!isdigit(ch=getchar()) && ch!='-') {}
    if (ch=='-') f=false,res=0; else f=true,res=ch-'0';
    while (isdigit(ch=getchar())) res=res*10+ch-'0';
    return f?res:-res;
}

struct Bigint
{
    int b[30],len;
    inline void clear() {memset(b,0,sizeof(b));len=1;}
    inline void read(int x) {clear();for (len=0;x;x/=1000) b[++len]=x%1000;if (!len) len=1;}
    inline void print()
    {
        for (register int i=len;i;i--)
        {
            if (i==len) {printf("%d",b[i]);continue;}
            int cnt=0,tt=100;
            while (b[i]<tt) cnt++,tt/=10;
            while (cnt--) printf("0");
            if (b[i]) printf("%d",b[i]);
        }
    }
    inline bool operator < (const Bigint &x) const
    {
        if (len!=x.len) return len<x.len;
        for (register int i=len;i>=1;i--)
            if (b[i]!=x.b[i]) return b[i]<x.b[i];
        return true;
    }
    inline bool operator > (const Bigint &x) const
    {
        if (len!=x.len) return len>x.len;
        for (register int i=len;i>=1;i--)
            if (b[i]!=x.b[i]) return b[i]>x.b[i];
        return true;
    }
    inline Bigint operator + (const Bigint x)
    {
        Bigint res;res.clear();
        for (register int i=1;i<=max(len,x.len);i++)
        {
            res.b[i]+=b[i]+x.b[i];
            res.b[i+1]+=res.b[i]/1000;res.b[i]%=1000;
        }
        res.len=max(len,x.len);if (res.b[res.len+1]) res.len++;
        return res;
    }
    inline Bigint operator * (const Bigint x)
    {
        Bigint res;res.clear();int i,j;
        for (i=1;i<=len;i++)
            for (j=1;j<=x.len;j++)
                res.b[i+j-1]+=b[i]*x.b[j];
        for (i=1;i<=len+x.len-1;i++) res.b[i+1]+=res.b[i]/1000,res.b[i]%=1000;
        res.len=len+x.len-1;if (res.b[res.len+1]) res.len++;
        return res;
    }
};

int n,m,k;
int a[10048];
Bigint pos[10048],neg[10048];int ptot,ntot;
Bigint res[71048];int rtot;

inline int myabs(int x) {return x>=0?x:-x;}
inline bool cmp_big(Bigint x,Bigint y) {return x>y;}
inline bool cmp_small(Bigint x,Bigint y) {return x<y;}

struct node_pos
{
    Bigint val;
    short pnum,nnum,npos[20],ppos[20];
    bool cur;int cpos;
    inline bool operator < (const node_pos &x) const {return val<x.val;}
};

priority_queue<node_pos> qpos;
inline void Getpos(short pnum,short nnum)
{
    int i,j,cnt;node_pos ins,cur;
    ins.pnum=pnum;ins.nnum=nnum;ins.val.clear();ins.val.b[1]=1;ins.val.len=1;
    for (i=1;i<=pnum;i++) ins.ppos[i]=i,ins.val=ins.val*pos[i];
    for (i=1;i<=nnum;i++) ins.npos[i]=i,ins.val=ins.val*neg[i];
    ins.cur=false;ins.cpos=pnum;
    while (!qpos.empty()) qpos.pop();
    qpos.push(ins);cnt=0;
    while (!qpos.empty() && cnt<k)
    {
        cur=qpos.top();qpos.pop();
        cnt++;res[++rtot]=cur.val;
        if (!cur.cur)
        {
            for (i=1;i<=cur.cpos;i++)
                if (cur.ppos[i]<ptot && (i==pnum || cur.ppos[i+1]!=cur.ppos[i]+1))
                {
                    ins.cpos=i;ins.cur=false;
                    for (j=1;j<=pnum;j++) ins.ppos[j]=cur.ppos[j];
                    for (j=1;j<=nnum;j++) ins.npos[j]=cur.npos[j];
                    ins.ppos[i]++;ins.val.clear();ins.val.len=1;ins.val.b[1]=1;
                    for (j=1;j<=pnum;j++) ins.val=ins.val*pos[ins.ppos[j]];
                    for (j=1;j<=nnum;j++) ins.val=ins.val*neg[ins.npos[j]];
                    qpos.push(ins);
                }
            for (i=nnum;i<=nnum && i;i++)
                if (cur.npos[i]<ntot && (i==nnum || cur.npos[i+1]!=cur.npos[i]+1))
                {
                    ins.cpos=i;ins.cur=true;
                    for (j=1;j<=pnum;j++) ins.ppos[j]=cur.ppos[j];
                    for (j=1;j<=nnum;j++) ins.npos[j]=cur.npos[j];
                    ins.npos[i]++;ins.val.clear();ins.val.len=1;ins.val.b[1]=1;
                    for (j=1;j<=pnum;j++) ins.val=ins.val*(pos[ins.ppos[j]]);
                    for (j=1;j<=nnum;j++) ins.val=ins.val*(neg[ins.npos[j]]);
                    qpos.push(ins);
                }
        }
        else
        {
            for (i=1;i<=cur.cpos;i++)
                if (cur.npos[i]<ntot && (i==nnum || cur.npos[i+1]!=cur.npos[i]+1))
                {
                    ins.cpos=i;ins.cur=true;
                    for (j=1;j<=pnum;j++) ins.ppos[j]=cur.ppos[j];
                    for (j=1;j<=nnum;j++) ins.npos[j]=cur.npos[j];
                    ins.npos[i]++;ins.val.clear();ins.val.len=1;ins.val.b[1]=1;
                    for (j=1;j<=pnum;j++) ins.val=ins.val*(pos[ins.ppos[j]]);
                    for (j=1;j<=nnum;j++) ins.val=ins.val*(neg[ins.npos[j]]);
                    qpos.push(ins);
                }
        }   
    }
}

struct node_neg
{
    Bigint val;
    short pnum,nnum,ppos[20],npos[20];
    bool cur;int cpos;
    inline bool operator < (const node_neg &x) const {return val>x.val;}
};

priority_queue<node_neg> qneg;
inline void Getneg(short pnum,short nnum)
{
    int i,j,cnt;node_neg ins,cur;
    ins.pnum=pnum;ins.nnum=nnum;ins.val.clear();ins.val.b[1]=1;ins.val.len=1;
    for (i=1;i<=pnum;i++) ins.ppos[i]=i,ins.val=ins.val*pos[i];
    for (i=1;i<=nnum;i++) ins.npos[i]=i,ins.val=ins.val*neg[i];
    ins.cur=false;ins.cpos=pnum;
    while (!qneg.empty()) qneg.pop();
    qneg.push(ins);cnt=0;
    while (!qneg.empty() && cnt<k)
    {
        cur=qneg.top();qneg.pop();
        cnt++;res[++rtot]=cur.val;
        if (!cur.cur)
        {
            for (i=1;i<=cur.cpos;i++)
                if (cur.ppos[i]<ptot && (i==pnum || cur.ppos[i+1]!=cur.ppos[i]+1))
                {
                    ins.cpos=i;ins.cur=false;
                    for (j=1;j<=pnum;j++) ins.ppos[j]=cur.ppos[j];
                    for (j=1;j<=nnum;j++) ins.npos[j]=cur.npos[j];
                    ins.ppos[i]++;ins.val.clear();ins.val.len=1;ins.val.b[1]=1;
                    for (j=1;j<=pnum;j++) ins.val=ins.val*pos[ins.ppos[j]];
                    for (j=1;j<=nnum;j++) ins.val=ins.val*neg[ins.npos[j]];
                    qneg.push(ins);
                }
            for (i=nnum;i<=nnum && i;i++)
                if (cur.npos[i]<ntot && (i==nnum || cur.npos[i+1]!=cur.npos[i]+1))
                {
                    ins.cpos=i;ins.cur=true;
                    for (j=1;j<=pnum;j++) ins.ppos[j]=cur.ppos[j];
                    for (j=1;j<=nnum;j++) ins.npos[j]=cur.npos[j];
                    ins.npos[i]++;ins.val.clear();ins.val.len=1;ins.val.b[1]=1;
                    for (j=1;j<=pnum;j++) ins.val=ins.val*(pos[ins.ppos[j]]);
                    for (j=1;j<=nnum;j++) ins.val=ins.val*(neg[ins.npos[j]]);
                    qneg.push(ins);
                }
        }
        else
        {
            for (i=1;i<=cur.cpos;i++)
                if (cur.npos[i]<ntot && (i==nnum || cur.npos[i+1]!=cur.npos[i]+1))
                {
                    ins.cpos=i;ins.cur=true;
                    for (j=1;j<=pnum;j++) ins.ppos[j]=cur.ppos[j];
                    for (j=1;j<=nnum;j++) ins.npos[j]=cur.npos[j];
                    ins.npos[i]++;ins.val.clear();ins.val.len=1;ins.val.b[1]=1;
                    for (j=1;j<=pnum;j++) ins.val=ins.val*(pos[ins.ppos[j]]);
                    for (j=1;j<=nnum;j++) ins.val=ins.val*(neg[ins.npos[j]]);
                    qneg.push(ins);
                }
        }
    }
}

Bigint tmp[400048];int fl;
inline void merge_sort(int left,int right)
{
    if (left>=right) return;
    int mid=(left+right)>>1,k1,k2,pt;
    merge_sort(left,mid);merge_sort(mid+1,right);
    for (k1=left,k2=mid+1,pt=left;k1<=mid && k2<=right;)
        if (fl==1)
        {
            if (res[k1]<res[k2]) tmp[pt++]=res[k1++]; else tmp[pt++]=res[k2++];
        }
        else
        {
            if (res[k1]>res[k2]) tmp[pt++]=res[k1++]; else tmp[pt++]=res[k2++];
        }
    for (;k1<=mid;) tmp[pt++]=res[k1++];
    for (;k2<=right;) tmp[pt++]=res[k2++];
    for (pt=left;pt<=right;pt++) res[pt]=tmp[pt];
}

int main ()
{
    int i,j;
    n=getint();m=getint();k=getint();
    for (i=1;i<=n;i++)
    {
        a[i]=getint();
        if (a[i]>=0) pos[++ptot].read(a[i]); else neg[++ntot].read(-a[i]);
    }
    for (i=1;i<=ptot;i++) res[i]=pos[i];
    fl=0;merge_sort(1,ptot);
    for (i=1;i<=ptot;i++) pos[i]=res[i];
    for (i=1;i<=ntot;i++) res[i]=neg[i];
    fl=0;merge_sort(1,ntot);
    for (i=1;i<=ntot;i++) neg[i]=res[i];
        rtot=0;
        for (i=0;i<=m;i+=2)
        {
            if (ptot<m-i || ntot<i) continue;
            Getpos(m-i,i);
        }
        if (rtot>=k)
        {
            fl=0;merge_sort(1,rtot);
            res[k].print();
            return 0;
        }
        k-=rtot;
        rtot=0;reverse(pos+1,pos+ptot+1);reverse(neg+1,neg+ntot+1);
        for (i=1;i<=m;i+=2)
        {
            if (ptot<m-i || ntot<i) continue;
            Getneg(m-i,i);
        }
        fl=1;merge_sort(1,rtot);
        printf("-");res[k].print();
    return 0;
}

猜你喜欢

转载自blog.csdn.net/iceprincess_1968/article/details/80041026
今日推荐