codeforces 895D - String Mark 组合数学

很像数位DP的过程。。。

一般这种上界+下届我们可以考虑小于上界的数-小于下届的数。

即a的所有排列中 字典序小于b的个数-字典序小于a的个数-1  就是最终答案。

很明显 按位处理。a排列的字典序要小于s.

处理到第i位时:dfs(id,s)

1.这一位填的字符小于s[i]时,后面的字符任意填  都满足字典序小于s.

2.这一位填的字符等于s[i]时,ans+=dfs(id+1,s);//相当于一个递归的过程。

但是这题dfs会爆栈,1e6.所以我们递推来写。

我们发现上述1,2的过程就相当于执行n次,每次填好前i个字符 使得前i个字符等于s。

所以我们这样做:

枚举前i个字符等于s。求出剩下位置填的方案  求个和即可。

这一定包含所有情况。因为i==1时 相当于所有情况都考虑了,只有第一位等于s的时候没考虑,但i+1时会把这种情况算上而且一定不重不漏,因为第i位都没有算i位相同的情况。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
//typedef __int128 LL;
//typedef unsigned long long ull;
//#define F first
//#define S second
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef pair<ld,ld> pdd;
const ld PI=acos(-1);
const ld eps=1e-9;
//unordered_map<int,int>mp;
#define ls (o<<1)
#define rs (o<<1|1)
#define pb push_back
//#define a(i,j) a[(i)*(m+2)+(j)]  //m是矩阵的列数
//pop_back()
const int seed=131;
const int M = 1e6+7;
const ll mod =1e9+7;
/*
int head[M],cnt;
void init(){cnt=0,memset(head,0,sizeof(head));}
struct EDGE{int to,nxt,val;}ee[M*2];
void add(int x,int y,int z){ee[++cnt].nxt=head[x],ee[cnt].to=y,ee[cnt].val=z,head[x]=cnt;}
*/
char a[M],b[M];
int l,nm[27];;
ll ji[M],inv[M];
ll qpow(ll a,ll b)
{
	ll ans=1;
	while(b)
	{
		if(b&1)ans=ans*a%mod;
		a=a*a%mod;
		b/=2;
	}
	return ans;
}
void pre()
{
	ji[0]=inv[1]=1;
	for(int i=1;i<=1e6;i++)ji[i]=ji[i-1]*i%mod;
    for(int i=1;i<=1e6;i++)inv[i]=qpow(ji[i],mod-2);
}
ll cal(char *s)//a的全排列 中字典序小于s的个数 
{
	ll ans=0;
	memset(nm,0,sizeof(nm));
	for(int i=0;i<l;i++)nm[a[i]-'a']++;//a-z  字符还剩多少个没用 
	for(int i=0;i<l;i++)//枚举a,b i 之前的字符都相同的 情况 
	{
		int ch=s[i]-'a';
		ll tp=1;
		for(int j=0;j<26;j++)if(nm[j])tp=tp*ji[nm[j]]%mod; 
		tp=qpow(tp,mod-2);
		for(int j=0;j<ch;j++)
		{
			if(nm[j])
			{
		//		cout<<j<<"  "<<l-i-1<<"  "<<ji[l-i-1]<<" "<<tp<<"  "<<nm[j]<<" "<<nm[j-1]<<endl;
				if(nm[j]>1)ans=(ans+ji[l-i-1]*tp%mod*ji[nm[j]]%mod*inv[nm[j]-1]%mod)%mod;
				else ans=(ans+ji[l-i-1]*tp%mod)%mod;
			}
		}
		if(nm[ch]<=0)break;//这一位不能凑相同 后面就不用考虑了
		nm[ch]--; 
	}
	return ans;
}
int main()
{
  	scanf("%s%s",a,b);
  	l=strlen(a);pre();
  	cout<<(cal(b)-cal(a)-1+mod)%mod<<endl;
	return 0;
}
发布了284 篇原创文章 · 获赞 13 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/bjfu170203101/article/details/104270326