The North American Invitational Programming Contest 2018 E. Prefix Free Code(Trie+树状数组+排列)

题意:给出n个字符串,可以从这n个串中取出m个进行排列,给出一个排列好的字符串,求这个串是所有排列中的第几个串

分析:场上用的map映射和康拓展开,不停WA,TLE,猝。正确解法是使用Tire将串映射成数字,然后对给出的串的每个字符,求他之前还有多少种排列,然后相加起来,这里可以用树状数组加速统计,同时再求排列数的时候用到逆元;

如下面一组样例

5 3
a
b
c
d
e
cad

通过映射之后 cad成为了数字314的排列,下面对其每个数字进行分析,设其是第ans个

在3前面有1和2 , 那么 ans = ans + 2*A(4,2);

在1前面没有数字,那么 ans = ans+0;

在4前面有虽然有1,2,3 但是 1,3已经被用过了,只剩下了 2 那么 ans = ans + 1*A(2,1);

在求逆元的时候使用费马小定理

AC代码:

#include <map>
#include <set>
#include <cmath>
#include <queue>
#include <stack>
#include <cstdio>
#include <vector>
#include <iomanip>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define ll long long
#define mod 1000000007
#define lowbit(x) (x&(-x))
#define mem(a) memset(a,0,sizeof(a))
#define FRER() freopen("in.txt","r",stdin);
#define FREW() freopen("out.txt","w",stdout);

using namespace std;

typedef pair<int,int> pii;
const int maxn = 1000000 + 7 , inf = 0x3f3f3f3f ;
int n,m,len,tot,dfn;
int c[maxn],son[maxn][26],vis[maxn],a[maxn];
ll f[maxn];
char s[maxn];
int get_sum(int x){
    int ans = 0;
    while(x>=1){
        ans+=c[x];
        x-=lowbit(x);
    }
    return ans;
}

void add(int x,int d){
    while(x<=n){
        c[x]+=d;
        x += lowbit(x);
    }
}

void dfs(int u){
    if(vis[u]) vis[u]=++dfn;
    for(int i=0;i<26;i++)
        if(son[u][i]) dfs(son[u][i]);
}

ll my_pow(ll x,ll y){
    ll res = 1;
    while(y){
        if(y&1) res = res*x%mod;
        x = x*x%mod;
        y>>=1;
    }
    return res;
}

ll A(int n,int m){
    if(n<m) return 0;
    return f[n]*my_pow(f[n-m],mod-2)%mod;
}

int main(){
    //FRER();
    tot = dfn = 0;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++){
        scanf("%s",s);
        len = strlen(s);
        int x = 0;
        for(int j=0;j<len;j++){
            if(!son[x][s[j]-'a']) son[x][s[j]-'a'] = ++tot;
            x = son[x][s[j]-'a'];
        }
        vis[x] = 1;
    }
    dfs(0);
    f[0] = 1;
    for(ll i=1;i<maxn;i++) f[i]=(f[i-1]*i)%mod;
    scanf("%s",s);
    int x = 0 , cnt = 0;
    for(int i=0;s[i];i++){
        x = son[x][s[i]-'a'];
        if(vis[x]) a[++cnt] = vis[x],x=0;
    }
    ll ans = 1;
    for(int i=1;i<=n;i++) add(i,1);
    for(int i=1;i<=cnt;i++){
        add(a[i],-1);
        ans = (ans + (A((ll)n-i,(ll)m-i)*(ll)get_sum(a[i]))%mod)%mod;
    }
    printf("%lld\n",ans);

}

猜你喜欢

转载自blog.csdn.net/Insist_77/article/details/81592154