洛谷 P3953 [ NOIP 2017 ] 逛公园 —— 最短路DP

题目:https://www.luogu.org/problemnew/show/P3953

主要是看题解...还是觉得好难想啊...

dfs DP,剩余容量的损耗是边权减去两点最短路差值...表示对于最短路来说多走了这么多...

还要注意该点能否到达 n 号点,不能就不走了(剪枝);

%p 那个地方会爆 int 吗?反正 %=p RE了一个点...(然而改成 ll 还是RE)

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
typedef long long ll;
int const maxn=2e5+5,maxm=4e5+5,inf=0x3f3f3f3f;
int T,n,m,K,p,hd[maxn],ct,head[maxn],xt;
int f[maxn][55],dis[maxn],s[maxn][55];
bool v[maxn][55],in[maxn],vis[maxn];
queue<int>q;
queue<int>ff;
struct N{
    int to,nxt,w;
    N(int t=0,int n=0,int w=0):to(t),nxt(n),w(w) {}
}ed[maxm],edge[maxm];
void add(int x,int y,int z){ed[++ct]=N(y,hd[x],z); hd[x]=ct;}
void add2(int x,int y,int z){edge[++xt]=N(y,head[x],z); head[x]=xt;}
void spfa()
{
    while(q.size())q.pop();
    memset(vis,0,sizeof vis);
    memset(dis,0x3f,sizeof dis);
    dis[1]=0; vis[1]=1; q.push(1);
    while(q.size())
    {
        int x=q.front(); q.pop(); vis[x]=0;
        for(int i=hd[x],u;i;i=ed[i].nxt)
            if(dis[u=ed[i].to]>dis[x]+ed[i].w)
            {
                dis[u]=dis[x]+ed[i].w;
                if(!vis[u])vis[u]=1,q.push(u);
            }
    }
}
void bfs()
{
    while(q.size())q.pop();
    in[n]=1; q.push(n);
    while(q.size())
    {
        int x=q.front(); q.pop();
        for(int i=head[x],u;i;i=edge[i].nxt)
        {
            if(in[u=edge[i].to])continue;
            in[u]=1; q.push(u);
        }
    }
}
int dfs(int x,int w)
{
    if(w<0)return 0;
    else if(v[x][w])return -inf;
    else if(s[x][w]!=-1)return s[x][w];
    else
    {
        int ret=0; v[x][w]=1;
        if(x==n)ret++;//
        for(int i=hd[x],u;i;i=ed[i].nxt)
        {
            if(!in[u=ed[i].to])continue;//!
            int tmp=dfs(u,w-(ed[i].w-(dis[u]-dis[x])));
            if(tmp==-inf)return -inf;
            else ret=(ret+tmp)%p;//(ret+=tmp)%=p RE?
        }
        v[x][w]=0;
        s[x][w]=ret%p;
        return ret;
    }
}
int main()
{
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d%d%d%d",&n,&m,&K,&p);
        ct=0; xt=0;
        memset(hd,0,sizeof hd);
        memset(head,0,sizeof head);
        for(int i=1,x,y,z;i<=m;i++)
        {
            scanf("%d%d%d",&x,&y,&z);
            add(x,y,z); add2(y,x,z);
        }
        memset(in,0,sizeof in);
        spfa(); bfs();
        memset(v,0,sizeof v);
        memset(s,-1,sizeof s);
        int ans=dfs(1,K);
        if(ans==-inf)printf("-1\n");
        else printf("%d\n",ans);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Zinn/p/9379254.html