bzoj 3872 [ Poi 2014 ] Ant colony —— 二分

题目:https://www.lydsy.com/JudgeOnline/problem.php?id=3872

从食蚁兽所在的边向叶节点推,会得到一个渐渐放大的取值区间,在叶子节点上二分有几群蚂蚁符合条件即可;

注意中途判断,如果已经超过范围就返回或者处理一下,据说会爆 long long 之类的;

而且食蚁兽所在的边的两个端点的初始值不一定是 k 和 k+1 !也要看度数!

注意统计答案的 num 也是 long long 。

代码如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define mid ((ll+rr)>>1)
using namespace std;
typedef long long ll;
int const xn=1e6+5;
int n,g,k,s,t,hd[xn],ct,to[xn<<1],nxt[xn<<1],deg[xn];
ll m[xn],l[xn],r[xn],num;//
bool vis[xn];
int rd()
{
    int ret=0,f=1; char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();}
    while(ch>='0'&&ch<='9')ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar();
    return f?ret:-ret;
}
void add(int x,int y){to[++ct]=y; nxt[ct]=hd[x]; hd[x]=ct;}
int findl(ll x)//第一个大于等于x的
{
    int ll=1,rr=g,ret=-1;
    while(ll<=rr)
    {
        if(m[mid]>=x)ret=mid,rr=mid-1;
        else ll=mid+1;
    }
    return ret;
}
int findr(ll x)
{
    int ll=1,rr=g,ret=-1;
    while(ll<=rr)
    {
        if(m[mid]<x)ret=mid,ll=mid+1;//<
        else rr=mid-1;
    }
    return ret;
}
void cal(int x)
{
    if(l[x]>m[g]||r[x]<=m[1])return;
    num+=findr(r[x])-findl(l[x])+1;
}
void dfs(int x)
{
    vis[x]=1;
    if(l[x]>m[g])return;//
    if(r[x]>m[g])r[x]=m[g]+1;//
    for(int i=hd[x],u;i;i=nxt[i])
    {
        if(vis[u=to[i]]||(x==s&&u==t)||(x==t&&u==s))continue;
        if(deg[u]>1)
        {
            l[u]=(deg[u]-1)*l[x]; r[u]=(deg[u]-1)*r[x];
            dfs(u);
        }
        else cal(x);
    }
}
int main()
{
    n=rd(); g=rd(); k=rd();
    for(int i=1;i<=g;i++)m[i]=rd();
    sort(m+1,m+g+1);
    for(int i=1,x,y;i<n;i++)
    {
        x=rd(); y=rd();
        if(i==1)s=x,t=y;
        add(x,y); add(y,x); deg[x]++; deg[y]++;
    }
//    l[s]=l[t]=k; r[s]=r[t]=k+1;
    if(deg[s]==1)l[s]=k,r[s]=k+1;
    else l[s]=(deg[s]-1)*k,r[s]=(deg[s]-1)*(k+1);
    if(deg[s]==1)cal(s);//!
    if(deg[t]==1)l[t]=k,r[t]=k+1;
    else l[t]=(deg[t]-1)*k,r[t]=(deg[t]-1)*(k+1);
    if(deg[t]==1)cal(t);//!
    dfs(s); dfs(t);
    printf("%lld\n",num*k);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Zinn/p/9707276.html