题解 bzoj 4398福慧双修(二进制分组)

二进制分组,算个小技巧

bzoj 4398福慧双修

给一张图,同一条边不同方向权值不同,一条边只能走一次,求从1号点出发再回到1号点的最短路
一开始没注意一条边只能走一次这个限制,打了个从一号点相邻节点为原点的dij,样例就挂了
其实就是要从这个错误思路上改进
对于不与1号点相接的边,权值为正,肯定不会重复走,所以这个条件可以忽略
考虑1号点相邻的点,走出第一步后所在的点,和走回1号点前的那个点不能相同
设这两个点编号为\(i\),\(j\),则\(i\),\(j\)的二进制至少有一位不同
所以用二进制分组其实这个方法也是看来题解之后才知道的,自己想真很难想出来
枚举每个二进制位,虚拟一个源点和汇点,这位为0/1的点分成两组,分别与源/汇点相连,跑两次dij,取\(dis[1]\)的最小值
复杂度高于那种构造新图的方式,bzoj上测的总时间2s+,肯定是倒数,但想起来也能简单一些
code.

#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<iomanip>
#include<cstring>
#define R register
#define EN std::puts("")
#define LL long long
inline int read(){
    int x=0,y=1;
    char c=std::getchar();
    while(c<'0'||c>'9'){if(c=='-') y=0;c=std::getchar();}
    while(c>='0'&&c<='9'){x=x*10+(c^48);c=std::getchar();}
    return y?x:-x;
}
int n,m;
int fir[40006],nex[200006],to[200006],w[200006],tot;
int dui[40006],size;
int dis[40006],in[40006];
inline void push(int x){
    dui[size++]=x;
    R int i=size-1,fa;
    while(i){
        fa=i>>1;
        if(dis[dui[fa]]<=dis[dui[i]]) return;
        std::swap(dui[fa],dui[i]);i=fa;
    }
}
inline int pop(){
    int ret=dui[0];dui[0]=dui[--size];
    R int i=0,ls,rs;
    while((i<<1)<size){
        ls=i<<1;rs=ls|1;
        if(rs<size&&dis[dui[rs]]<dis[dui[ls]]) ls=rs;
        if(dis[dui[ls]]>=dis[dui[i]]) break;
        std::swap(dui[ls],dui[i]);i=ls;
    }
    return ret;
}
inline int dij(int bit,int panduan){
    std::memset(dis,0x3f,sizeof dis);
    for(R int i=fir[1];i;i=nex[i])if((to[i]&bit)==panduan){
        push(to[i]);dis[to[i]]=w[i];in[to[i]]=1;
    }
    while(size){
        R int u=pop();in[u]=0;
        for(R int i=fir[u];i;i=nex[i]){
            R int v=to[i];
            if(v==1&&(u&bit)==panduan) continue;
            if(dis[v]>dis[u]+w[i]){
                dis[v]=dis[u]+w[i];
                if(!in[v]) push(v),in[v]=1;
            }
        }
    }
    return dis[1];
}
inline void add(int x,int y,int z){
    to[++tot]=y;w[tot]=z;
    nex[tot]=fir[x];fir[x]=tot;
}
int work(){
    R int ret=0x3f3f3f3f;
    R int tmp=0;int nn=n;
    while(nn) tmp++,nn>>=1;
    for(R int i=0;i<tmp;i++)
        {ret=std::min(ret,dij(1<<i,1<<i)),ret=std::min(ret,dij(1<<i,0));}
    return ret;
}
int main(){
    n=read();m=read();
    for(R int i=1;i<=m;i++){
        int x=read(),y=read(),ww=read(),www=read();
        add(x,y,ww);add(y,x,www);
    }
    R int ret=work();
    std::printf("%d",ret==0x3f3f3f3f?-1:ret);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/suxxsfe/p/12527529.html