计蒜客 Prefix Free Code 字典树 + 树状数组 + 求逆元

1.题意:给你n个字符串,给你一个k,意思是你能任选k个字符串组成一个长字符串。再给你一个长字符串问你这个字符串在所有任选k个字符串组合中字典序排第几。(所有字符串长度之和不大于1e6 , 要求结果对1e9 + 7取模)

2.思路:有康拓展开的思路联想到:我们如果把所有字符串由小到大排序后,标为1~n,再把要求的字符串映射为对应数列,比如:

5 3
a
b
c
d
e
cad

这个样例,cad映射为 3 1 4,接下来找所有比314字典序小的数列,由组合数学知:第 i 位上比当前字典序小的数列共有

getNum( ans[i] ) *C(n-i,m-i) * (m-i) !(其中getNum( ans[i] )表示比第i位数字小的没有用过的数字个数)

进一步推得:

于是第i位上的比他小的数字有: getNum(ans[i])*A(n-i,n-m);

这样最后的累和ans就是所有比要求字符串小的字符串数目,他的位数当然是ans+1,分析完成。

3.实现:

(1)这题肯定会卡超时,所以我们在字符串上的处理要使用字典树,因为所有字符串都不是另一个的前缀,我们把所有输入的字符串插入字典树里,排序后分别给他们标上下标。(分配下标完成)

(2)然后把长字符串(也就是我们要求的)转化为数列,怎么转化呢?还是用字典树不断搜索,当遇到该处judge = true;说明该处构成单词了,把该处下标保存,然后cnt = 0(字典树下标归零),重新找下一个单词。(数列映射完成)

(3)计算getNum(ans[i])时,表示比当前数字小的还未使用的数字数目,为了加快速度,这里使用树状数组处理(逆序数)

(4)计算(n-i)! / (n - m)! %mod 因为取模运算分配率不满足除法,所以这里必须来求逆元(不然会出现小数或者溢出),所以这里采用费马小定理来求逆元:(n-i)!/(n-m)! = (n - i)! * inv((n-m)!)%mod

注:如何求逆元参考我的下一篇文章

4.代码:

方法一:string + sort排序(132ms)

#include <iostream>
#include<cstdio>
#include<string>
#include<algorithm>
#include<cstring>
#include<map>
#include<vector>
#define lowbit(x) x&(-x)
#define mod 1000000007
using namespace std;
typedef long long LL;
const int maxn = 1000002;
int n,k,tot,number,l;
LL f[maxn];
LL com[maxn],ans[maxn];
string str[maxn];
struct Node
{
    int next[26];
    bool judge;
    int sign;
};
Node node[maxn];
int CreatTree()
{
    memset(node[tot].next,0,sizeof(node[tot].next));
    node[tot].sign = 0;
    node[tot].judge = false;
    return tot++;
}
LL power(LL a,LL b)//快速幂
{
    a%=mod;
    LL aans = 1;
    while(b){
        if(b&1)aans = (aans*a)%mod;
        b>>=1;
        a = (a*a)%mod;
    }
    return aans;
}
void insertTree(string s)
{
    int len = s.size();
    int cnt = 0;
    for(int i = 0;i<len;i++){
        int k = s[i] - 'a';
        if(node[cnt].next[k]==0){
            node[cnt].next[k] = CreatTree();
        }
        cnt = node[cnt].next[k];
    }
    node[cnt].sign = ++number;//分配数字
    node[cnt].judge = true;
}
void findTree(string s)
{
    int len = s.size();
    int cnt = 0;
    for(int i = 0;i<len;i++){
        int k = s[i] - 'a';
        cnt = node[cnt].next[k];
        if(node[cnt].judge==true){//找到一个字符串,保存数字
            ans[++l] = node[cnt].sign;
            cnt = 0;//归零找下一个字符串
        }
    }
}
void Update(int x,int c)
{
    for(int i = x;i<=n;i+=lowbit(i)){
        com[i]+=c;
    }
}
LL A(LL numA,LL numB)
{
    if(numB<0)return 0;//费马小定理
    return ((f[numA]%mod)*(power(f[numB],mod-2)%mod))%mod;
}
LL get_Num(int x)
{
    LL p = 0;
    for(int i = x;i>0;i-=lowbit(i)){
        p+=com[i];
    }
    return p;
}
int main()
{
    tot = number = l = 0;
    memset(com,0,sizeof(com));
    scanf("%d%d",&n,&k);
    CreatTree();//一开始忘了建立树根无限RE QAQ
    for(int i = 0;i<n;i++){
        cin>>str[i];
    }
    sort(str,str+n);//先排序再插入
    for(int i = 0;i<n;i++)insertTree(str[i]);
    string name;
    cin>>name;
    findTree(name);
    f[0] = 1;
    for(int i = 1;i<maxn;i++){f[i] =((f[i-1]%mod)*(i%mod))%mod;}//计算乘阶
    for(int i = 1;i<=n;i++)Update(i,1);//刚开始所有数字都能用
    LL sum = 0;
    for(int i = 1;i<=l;i++){
        sum = (sum + (A((LL)n-i,(LL)n-k)*(LL)(get_Num(ans[i])-1))%mod)%mod;
        Update(ans[i],-1);//更新
    }
    printf("%lld\n",(sum+1)%mod);
    return 0;
}

方法二:字符数组 + dfs排序(87ms)

#include <iostream>
#include<cstdio>
#include<string>
#include<algorithm>
#include<cstring>
#include<map>
#include<vector>
#define lowbit(x) x&(-x)
#define mod 1000000007
using namespace std;
typedef long long LL;
const int maxn = 1000007;
char str[maxn];
LL f[maxn];
int ans[maxn],sum[maxn],n,k,tot,lenth,number;
struct Node
{
    int next[26];
    bool judge;
    int index;
};
Node node[maxn];
int CreatTree(){
    node[tot].judge = false;
    memset(node[tot].next,0,sizeof(node[tot].next));
    node[tot].index = 0;
    return tot++;
}
LL power(LL x,LL y){
    x%=mod;
    LL cnt = 1;
    while(y){
        if(y&1)cnt = (cnt*x)%mod;
        y>>=1;
        x = (x*x)%mod;
    }
    return cnt;
}
void insertTree(char *s){
    int len = strlen(s);
    int cnt = 0;
    for(int i = 0;i<len;i++){
        int k = s[i] - 'a';
        if(node[cnt].next[k]==0){
            node[cnt].next[k] = CreatTree();
        }
        cnt = node[cnt].next[k];
    }
    node[cnt].judge = true;
}
void findTree(char *s){
    int len = strlen(s);
    int cnt = 0;
    for(int i = 0;i<len;i++){
        int k = s[i] - 'a';
        cnt = node[cnt].next[k];
        if(node[cnt].judge){
            ans[++lenth] = node[cnt].index;
            cnt = 0;
        }
    }
}
void DfsOrder(int a){//排序分配数字
    if(node[a].judge){node[a].index = ++number;return;}
    for(int i = 0;i<26;i++){
        if(node[a].next[i])DfsOrder(node[a].next[i]);
    }
}
void Update(int x,int c){
    for(int i = x;i<=n;i+=lowbit(i)){
        sum[i]+=c;
    }
}
int getNum(int x){
    int cnt = 0;
    for(int i = x;i>0;i-=lowbit(i)){
        cnt+=sum[i];
    }
    return cnt;
}
LL A(int x,int y){
    if(y<0)return 0;
    return (f[x]*(power(f[y],mod-2)%mod))%mod;
}
int main(){
    tot = lenth = number = 0;
    memset(sum,0,sizeof(sum));
    f[0] = 1;
    for(int i = 1;i<maxn;i++){f[i] = (f[i-1]*i)%mod;}
    scanf("%d%d",&n,&k);
    CreatTree();
    for(int i = 1;i<=n;i++)Update(i,1);
    for(int i = 0;i<n;i++){
        scanf("%s",str);
        insertTree(str);//先插入
    }
    DfsOrder(0);//递归排序,因为我们循环肯定先从小的开始,字典序由有小到大的
    scanf("%s",str);
    findTree(str);
    LL answer = 0;
    for(int i = 1;i<=lenth;i++){
        answer = (answer + (A(n-i,n-k)*(getNum(ans[i])-1))%mod)%mod;
        Update(ans[i],-1);
    }
    printf("%lld\n",(answer+1)%mod);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_40772692/article/details/81738156