【学习笔记】Berlekamp-Massey

orz zhenzhendong

之前贺过一边周指导博客但是弃疗了, 今天又贺了一次.

问题描述

给定一个长度为 \(n\) 的数列 \(\{a_i\}\), 求一个最短的齐次线性递推数列\(\{b_i\}\)(设长度为 \(m\)),使得对于所有 \(m \leq k \leq n\), 有 \(a_k = \sum_{i = 0} ^ m a_{k - i} b_i\)

复杂度要求: \(O(n ^ 2)\)

一看就很适合配合常系数齐次线性递推食用.

算法流程

增量构造.

假设我们当前已经求出了 \(a_{0...i - 1}\) 的线性递推数列. 计算过程中, 我们曾经得出过 \(c\) 个递推式, 第 \(i\) 个递推式在 \(fail_i\) 的位置第一次失效了.

一开始 \(c = 0\), 我们有一个空的递推式.

现在我们加入数 \(a_i\).

\(R_c\) 的长度为 \(m\), \(delta_i = a_i - \sum_{k = 1} ^ {m} a_{i - k} R_c(k)\)

如果 \(delta_i = 0\), \(R_c\) 仍是一个合法的递推式.

否则我们要对 \(R_c\) 做出调整, 来得到一个新的符合条件的递推式.

\(c = 0\), 那么前 \(i - 1\) 个数都是 \(0\). 我们只需要构造一个包含 \(i\)\(0\) 的递推式即可.

\(c \not= 0\), 只需要构造一个递推式 \(R'\), 当 \(|R'| + 1 \leq k < n\) 时, \(\sum_{i = 1} ^ {|R'|} a_{k - i} R'_i = 0\), \(\sum_{i = 1} ^ {|R'| } a_{n - i} R'_i = delta_n\), 那么 \(R_{c + 1} = R_c + R'\) 就符合条件.

我们随便找一个 $ 0 \leq id < c$, 它的前 \(fail_{id} - 1\) 个数都是 \(0\). 如果我们对它作一个位移, 即前面补上 \(i - fail_{id} - 1\)\(0\), 后面跟个 \(1\), 然后接上 \(-R_{id}\), 我们就可以得到一个只有位置 \(i\)\(delta_{fail_{id}}\),其余位置都是 \(0\) 的数组. 然后我们把它整个乘上 \(tmp = \frac{delta_{i}}{delta_{fail_{id}}}\), 就构造出了 \(R'\)

也就是 \(R'\)
\[ \{0,0,...0,tmp,-tmp R_{id}(1),-tmp R_{id}(2),...\} \]

然后我们还要保证 \(R_c + R'\) 是最短的, 我们找到 \(i - fail_{id} + |R_{id}|\) 最小的即可. (不会严格证明)

模板

数据可以去周指导的博客上看.

#pragma GCC optimize("2,Ofast,inline")
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define LL long long
#define pii pair<int, int>
using namespace std;
const int mod = 1e9 + 7;

template <typename T> T read(T &x) {
    int f = 0;
    register char c = getchar();
    while (c > '9' || c < '0') f |= (c == '-'), c = getchar();
    for (x = 0; c >= '0' && c <= '9'; c = getchar())
        x = (x << 3) + (x << 1) + (c ^ 48);
    if (f) x = -x;
    return x;
}

inline void upd(int &x, int y) {
    (x += y) >= mod ? x -= mod : 0;
}

inline int add(int x, int y) {
    return (x += y) >= mod ? x - mod : x;
}

inline int dec(int x, int y) {
    return (x -= y) < 0 ? x + mod : x;
}

inline int Qpow(int x, int p) {
    int ans = 1;
    for (; p; p >>= 1) {
        if (p & 1) ans = 1LL * ans * x % mod;
        x = 1LL * x * x % mod;
    }
    return ans;
}

inline int Inv(int x) {
    return Qpow(x, mod - 2);
}

namespace BM {

    const int Maxn = 5005;
    
    int n, c;
    int a[Maxn], del[Maxn], fail[Maxn];
    vector<int> R[Maxn];

    vector<int> solve() {
        c = 0;
        for (int i = 1; i <= n; ++i) {
            if (c == 0) {
                if (a[i]) {
                    fail[0] = i;
                    ++c;
                    del[i] = a[i];
                    fail[c] = i;
                    R[c].resize(i);
                }
                continue;
            }
            del[i] = a[i];
            for (int j = 0; j < R[c].size(); ++j) {
                del[i] = dec(del[i], 1LL * R[c][j] * a[i - j - 1] % mod);
            }
            if (del[i] == 0) continue;
            fail[c] = i;
            int id = c - 1, v = i - fail[id] + R[id].size();
            for (int j = c - 1; j >= 0; --j) {
                if (v > i - fail[j] + R[j].size()) {
                    v = i - fail[j] + R[j].size();
                    id = j;
                }
            }
            int p = i - fail[id];
            int tmp = 1LL * del[i] * Inv(del[fail[id]]) % mod;
            R[c + 1] = R[c];
            if (R[c + 1].size() < v) R[c + 1].resize(v);
            upd(R[c + 1][p - 1], tmp);
            for (int j = 0; j < R[id].size(); ++j) {
                upd(R[c + 1][p + j], -1LL * tmp * R[id][j] % mod + mod);
            }
            ++c;
        }
        if (c == 0) return vector<int>(0);
        return R[c];
    }
}
using namespace BM;

int main() {
    read(n);
    for (int i = 1; i <= n; ++i) read(a[i]);
    vector<int> ans = BM::solve();
    cout << ans.size() << endl;
    for (int i = 0; i < ans.size(); ++i)
        cout << ans[i] << ' ';
    puts("");
}

猜你喜欢

转载自www.cnblogs.com/Vexoben/p/11845379.html