题解:
枚举相同位置的长度是多少,然后可以设计一个DP,
表示第
位,相同状态为
的方案数(注意这里要带个-1的系数方便容斥),然后发现这个状态数很少,就可以过了。
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long ULL;
const int RLEN=1<<18|1;
inline char nc() {
static char ibuf[RLEN],*ib,*ob;
(ib==ob) && (ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin));
return (ib==ob) ? -1 : *ib++;
}
inline int rd() {
char ch=nc(); int i=0,f=1;
while(!isdigit(ch)) {if(ch=='-')f=-1; ch=nc();}
while(isdigit(ch)) {i=(i<<1)+(i<<3)+ch-'0'; ch=nc();}
return i*f;
}
const int N=2e2+5, M=1e5+5, mod=1e9+7;
const ULL base=131;
inline int add(int x,int y) {return (x+y>=mod) ? (x+y-mod) : (x+y);}
inline int dec(int x,int y) {return (x-y<0) ? (x-y+mod) : (x-y);}
inline int mul(int x,int y) {return (long long)x*y%mod;}
int n,m;
map <ULL,int> mp;
map <ULL,int> dis;
struct sta {
int anc[N];
inline int ga(int x) {
return (anc[x]==x) ? x : (anc[x]=ga(anc[x]));
}
inline ULL Hash() {
ULL sum=0, hv=1;
for(int i=1;i<=m;i++)
sum+=hv*ga(i), hv=hv*base;
return sum;
}
inline void merge(int x,int y) {
x=ga(x), y=ga(y);
if(x>y) swap(x,y);
anc[y]=x;
}
inline void opt(int p) {
for(int i=1,j=m-p+1;i<=p;i++,j++) merge(i,j);
}
inline int calc() {
int rs=0;
for(int i=1;i<=m;i++)
if(ga(i)==i) ++rs;
return rs;
}
} ori;
int tr[M][N],cnt[M],tot;
int f[N][M],g[N],pw[N];
vector <sta> cont[N];
inline void bfs(sta &now) {
dis.clear();
for(int i=1;i<=n-m+1;i++) cont[i].clear();
cont[1].push_back(now);
mp[now.Hash()]=++tot;
dis[now.Hash()]=1;
cnt[tot]=now.calc();
for(int i=1;i<=n-m+1;i++) {
for(int j=0;j<cont[i].size();j++) {
sta u=cont[i][j];
ULL v=u.Hash();
if(dis[v]!=i) continue;
int id=mp[v];
for(int z=1;z<m;z++) if(i+m-z<=n-m+1) {
sta nxt=u;
nxt.opt(z);
v=nxt.Hash();
if(!mp.count(v)) mp[v]=++tot, cnt[tot]=nxt.calc();
tr[id][z]=mp[v];
if(!dis.count(v) || dis[v]>i+m-z) {
dis[v]=i+m-z;
cont[dis[v]].push_back(nxt);
}
}
}
}
}
inline int calc(int nn) {
mp.clear(); m=nn; tot=0;
for(int i=1;i<=m;i++) ori.anc[i]=i;
bfs(ori);
for(int i=1;i<=n-m+1;i++)
for(int j=1;j<=tot;j++) f[i][j]=0;
f[1][1]=mod-1;
for(int i=1;i<=n-m+1;i++)
for(int s=1;s<=tot;s++) if(f[i][s])
for(int j=i+1;j<=n-m+1;j++) {
int nxt=(j<=i+m-1) ? tr[s][i+m-j] : s;
int v=(j<=i+m-1) ? f[i][s] : mul(f[i][s],pw[j-i-m]);
f[j][nxt]=dec(f[j][nxt],v);
}
int ans=0;
for(int i=2;i<=n-m+1;i++)
for(int s=1;s<=tot;s++) if(f[i][s])
ans=add(ans,mul(mul(f[i][s],pw[cnt[s]]),pw[n-i-m+1]));
return ans;
}
int main() {
n=rd(); pw[0]=1; pw[1]=rd();
for(int i=2;i<=n;i++) pw[i]=mul(pw[i-1],pw[1]);
for(int i=1;i<n;i++) g[i]=calc(i);
int ans=0;
for(int i=1;i<n;i++) ans=add(ans,mul(i,dec(g[i],g[i+1])));
cout<<ans<<'\n';
}