CF1039D 题解

题 目 链 接

题目大意

给出一棵\(n\)个节点的树,对于\(1\)~\(n\)间的每一个数\(k\),你需要求出:最多能选出多少条互不相交的路径,每条路径的长度都为\(k\)

\(Solution:\)

\(1\)~\(n\)不好算,先考虑单个的\(k\)
\(ans_i\)表示最多能选出多少条互不相交的路径,每条路径的长度都为\(i\)\(f_u\)表示回溯到\(u\)时最多能选出多少条互不相交的路径,每条路径的长度都为\(i\)\(i\)就是前面的\(i\));

考虑贪心地\(dp\):对于一个点\(u\),我们考虑最大化以\(u\)为根的子树中完整的长度为\(k\)的路径条数,其次最大化未完成的链的长度
正确性显然:由于树的特性,\(u\)节点只会在一个路径中被计数,那么我们用类似点分治的看法,设\(u\)是该路径上\(dep\)最浅的点,\(max1\)\(u\)节点向下挂出的最长链的长度,\(max2\)\(u\)向下挂出的次长链的长度,那么:
\((1)\)\(max1+max2+1 \geq k\),那么我们显然要将\(max1,max2,u\)组成的路径计入答案,否则会用到\(u\)的祖先节点,这是不优的;
\((2)Otherwise\),我们就用\(max1,u\)组成的链继承上去

代码大概是这个样子的(看了\(lzy\)大佬的\(blog\),发现其实可以先\(dfs\)一遍跑出\(dfs\)序,然后直接在\(dfs\)序上操作,新技能\(get\)

void dfs(int u,int fa)
{
	fat[u]=fa;
	go(u)
	{
		int v=e[i].to;
		if(v!=fa) dfs(v,u);
	}
	dfn[++idx]=u;
} 
inline int solve(int k)
{
	int ans=0;
	fr(i,1,n) f[i]=1;
	fr(i,1,n)
	{
		int u=dfn[i];
		if(fat[u]&&f[fat[u]]&&f[u])
		{
			if(f[u]+f[fat[u]]>=k)
			{
				++ans,f[fat[u]]=0;
			}
			else f[fat[u]]=max(f[fat[u]],f[u]+1);
		}
	}
	return ans;
}

然鹅这是\(O(n^2)\)的,考虑优化。可以发现有:\(ans_i \leq \frac{n}{i}\)(原因是路径长度为\(i\),那么最好情况即将所有点都用上,就会有\(\frac{n}{i}\)条路径)。

那么结合数据范围\(n \leq 10^5\),我们可以用上根号分治:对于一个确定的\(k\),我们设一个阀值\(B\)
\((1)\)\(k \leq B\),直接暴力\(dp\)(也就是上面的代码),复杂度\(O(nB)\)
\((2)\)\(k > B\),此时必然有:\(ans_k \in [0,\frac{n}{B}]\),也就是只有\(\frac{n}{B}\)这么多个取值,然后\(f\)(即\(dp\)数组)显然具有单调不增的性质,也就是说中间有一段一段的dp值是一样的,那么我们考虑二分出这些段的边界,每次二分用solve()\(check\),复杂度\(O(\frac{n}{B} n \log_2 n)\)

那么我们现在来分析阀值\(B\)的取值,由上面的分析可知,总复杂度\(O(nB+\frac{n}{B} n \log_2 n)=O(n(B+\frac{n}{B} \log n))\),由均值不等式:\(min=n \sqrt{n \log n}\),当且仅\(B=\frac{n}{B} \log n\)\(B=\sqrt{n \log n}\)时取得。

上代码:

\(Code:\)

#include<bits/stdc++.h>
using namespace std;
namespace my_std
{
	typedef long long ll;
	typedef double db;
	#define pf printf
	#define pc putchar
	#define fr(i,x,y) for(register int i=(x);i<=(y);++i)
	#define pfr(i,x,y) for(register int i=(x);i>=(y);--i)
	#define go(x) for(int i=head[u];i;i=e[i].nxt)
	#define enter pc('\n')
	#define space pc(' ')
	#define fir first
	#define sec second
	#define MP make_pair
	const int inf=0x3f3f3f3f;
	const ll inff=1e15;
	inline int read()
	{
		int sum=0,f=1;
		char ch=0;
		while(!isdigit(ch))
		{
			if(ch=='-') f=-1;
			ch=getchar();
		}
		while(isdigit(ch))
		{
			sum=sum*10+(ch^48);
			ch=getchar();
		}
		return sum*f;
	}
	inline void write(int x)
	{
		if(x<0)
		{
			x=-x;
			pc('-');
		}
		if(x>9) write(x/10);
		pc(x%10+'0');
	}
	inline void writeln(int x)
	{
		write(x);
		enter;
	}
	inline void writesp(int x)
	{
		write(x);
		space;
	}
}
using namespace my_std;
const int N=1e5+50;
int n,B,idx,f[N],fat[N],dfn[N],head[N],cnt,ans[N];
struct edge
{
	int to,nxt;
}e[N<<1];
inline void add(int u,int v)
{
	e[++cnt].to=v;
	e[cnt].nxt=head[u];
	head[u]=cnt;
}
void dfs(int u,int fa)
{
	fat[u]=fa;
	go(u)
	{
		int v=e[i].to;
		if(v!=fa) dfs(v,u);
	}
	dfn[++idx]=u;
} 
inline int solve(int k)
{
	int ans=0;
	fr(i,1,n) f[i]=1;
	fr(i,1,n)
	{
		int u=dfn[i];
		if(fat[u]&&f[fat[u]]&&f[u])
		{
			if(f[u]+f[fat[u]]>=k)
			{
				++ans,f[fat[u]]=0;
			}
			else f[fat[u]]=max(f[fat[u]],f[u]+1);
		}
	}
	return ans;
}
int main(void)
{
	n=read();
	B=sqrt(n*log(n)/log(2));
	fr(i,1,n-1)
	{
		int u=read(),v=read();
		add(u,v),add(v,u);
	}
	dfs(1,0);
	//fr(i,1,n) writesp(dfn[i]);
	ans[1]=n;
	fr(i,2,B) ans[i]=solve(i);
	for(int i=B+1,l,r;i<=n;i=l+1)
	{
		l=i,r=n;
		int tmp=solve(i);
		while(r-l>1)
		{
			int mid=(l+r)>>1;
			if(solve(mid)==tmp) l=mid;
			else r=mid;
		}
		fr(j,i,l) ans[j]=tmp;
	}
	fr(i,1,n) writeln(ans[i]);
	return 0;
}

猜你喜欢

转载自www.cnblogs.com/lgj-lgj/p/12714031.html