@codeforces - 718E@ Matvey's Birthday


@description@

给定一个长度为 n 的字符串 s,保证只包含前 8 个小写字母 'a', 'b', ... 'h'。

根据该字符串建一个图。两个点 p, q 之间有连边要么 |p - q| = 1,要么 s[p] = s[q]。

求该图直径的长度(所有点对之间的最短距离的最大值),以及直径的数量。

Input
第一行一个整数 n,表示字符串长度。
第二行一个字符串 s。保证只包含前 8 个小写字母。

Output
输出直径长度与直径数量。

Examples
Input
3
abc
Output
2 1

Input
7
aaabaaa
Output
2 4

@solution@

虽然 cf 的难度系统不太准,但至少难度 > 3000 都是我不会做的题.jpg。

先考虑如何快速求两个点 i, j 之间的最短路。
首先注意到最短路上不能出现不相邻的相同字符(即类似于 'a' -> ... -> 'a'),否则我可以直接从第一个相同字符跳到最后一个相同字符。
这意味着最短路径一定 <= 2*8。

假如不经过相邻的相同字符(即不经过 s[p] = s[q] 类型的边),最短路径长度为 |i - j|。
否则,我们以某种字符为中转,向两边求到 i, j 的最短路,两者之和即 i->j 的最短路。
即如果记 d[c][i] 表示 i 到达某一个字符 c 的最短路,此时最短路径为 min(d[c][i] + d[c][j])。
那么 i, j 之间的最短路一定为 min(|i - j|, min(d[c][i] + d[c][j]))。

怎么求 d[c][i]?我们可以 bfs 搞定。
只是需要注意由于相同字符构成了一个完全图,假如这个字符的所有点已经全部进入队列,我们需要打上 tag 防止之后反复访问。不然时间就炸了。
这一部分的复杂度为 O(8*n)。

考虑通用的解法:枚举 i,求以 i 为起点的最长路径及路径数量。但是这样子还是 O(n^2) 的。
首先,只有 i 前面 16 个以及后面 16 个是可能取 |i - j| 为最小值的,这些直接暴算。
那么剩下的 j 只可能取 min(d[c][i] + d[c][j]),我们再研究怎么简化这一部分的复杂度。

注意到这个只跟 d[c][j] 有关。我们或许可以将 d[0][j], d[1][j], ... 相同的 j 放在一起处理。
具体怎么操作?注意到相同字符 x 对应的 j, k,总有 |d[c][j] - d[c][k]| <= 1(因为它们之间有边连接)。
记 mnd[c][x] = min(d[c][j]),那么对于字符 x 对应的任意 j,有 d[c][j] = mnd[c][x] 或 d[c][j] = mnd[c][x] + 1。
我们可以用一个 8 位的二进制状态将 d[c][j] 压缩,并存储该二进制状态对应的数量,就可以实现我们的目的了。

枚举 i 过后,再枚举 x 以及 8 位的二进制状态得到 d[0][j], d[1][j], ...,然后枚举中转字符 c,根据式子算。
注意我们需要把 i 前面 16 个以及后面 16 个的贡献先消掉。不然会计算重复。

@accepted code@

#include <queue>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
const int MAXN = 100000;
int clr[MAXN + 5], n;
vector<int>v[MAXN + 5];
int d[10][MAXN + 5];
bool tag[10];
void get_dist(int x) {
    queue<int>que;
    for(int i=0;i<8;i++) tag[i] = false;
    for(int i=1;i<=n;i++) d[x][i] = MAXN + 5;
    for(int i=0;i<v[x].size();i++)
        d[x][v[x][i]] = 0, que.push(v[x][i]);
    while( !que.empty() ) {
        int f = que.front(); que.pop();
        if( !tag[clr[f]] ) {
            tag[clr[f]] = true;
            for(int i=0;i<v[clr[f]].size();i++) {
                int u = v[clr[f]][i];
                if( d[x][u] > d[x][f] + 1 )
                    d[x][u] = d[x][f] + 1, que.push(u);
            }
        }
        if( f != 1 && d[x][f-1] > d[x][f] + 1 )
            d[x][f-1] = d[x][f] + 1, que.push(f-1);
        if( f != n && d[x][f+1] > d[x][f] + 1 )
            d[x][f+1] = d[x][f] + 1, que.push(f+1);
    }
}
int mnd[10][10], bts[MAXN + 5], cnt[10][1<<10];
void get_mask(int x) {
    for(int i=0;i<8;i++) {
        mnd[x][i] = MAXN + 5;
        for(int j=0;j<v[x].size();j++)
            mnd[x][i] = min(mnd[x][i], d[i][v[x][j]]);
        for(int j=0;j<v[x].size();j++)
            bts[v[x][j]] |= ((d[i][v[x][j]] - mnd[x][i])<<i);
    }
    for(int j=0;j<v[x].size();j++)
        cnt[x][bts[v[x][j]]]++;
}
int ans1; long long ans2;
void update(int x, int t) {
    if( x == ans1 ) ans2 += t;
    else if( x > ans1 ) ans1 = x, ans2 = t;
}
char s[MAXN + 5];
int abs(int x) {return x >= 0 ? x : -x;}
int main() {
    scanf("%d%s", &n, s + 1);
    for(int i=1;i<=n;i++)
        v[s[i]-'a'].push_back(i), clr[i] = s[i] - 'a';
    for(int i=0;i<8;i++) get_dist(i);
    for(int i=0;i<8;i++) get_mask(i);
    ans1 = 0, ans2 = 0;
    int t = (1<<8);
    for(int i=1;i<=n;i++) {
        for(int j=max(1,i-16);j<=min(i+16,n);j++) {
            int mn = abs(i - j);
            for(int k=0;k<8;k++)
                mn = min(mn, d[k][i] + 1 + d[k][j]);
            cnt[clr[j]][bts[j]]--, update(mn, 1);
        }
        for(int j=0;j<8;j++) {
            for(int s=0;s<t;s++) {
                int mn = MAXN + 5;
                for(int k=0;k<8;k++)
                    mn = min(mn, mnd[j][k] + ((s>>k) & 1) + d[k][i] + 1);
                if( cnt[j][s] )
                    update(mn, cnt[j][s]);
            }
        }
        for(int j=max(1,i-16);j<=min(i+16,n);j++)
            cnt[clr[j]][bts[j]]++;
    }
    printf("%d %lld\n", ans1, ans2/2);
}

@detail@

因为卡在第一步(求两个点之间的最短路)所以。。。也没什么好说的吧。。。

感觉自己代码能力还是有所提升的样子(好久没有那种一道题调一天的快感了)
这个 detail 的板块最近基本也没用来提示代码细节,反倒是在吐槽了 2333。

猜你喜欢

转载自www.cnblogs.com/Tiw-Air-OAO/p/11847166.html