【bzoj5339】[TJOI2018]教科书般的亵渎(拉格朗日插值/第二类斯特林数)

传送门

题意:
一开始有很多怪兽,每个怪兽的血量在\(1\)\(n\)之间且各不相同,\(n\leq 10^{13}\)
然后有\(m\)种没有出现的血量,\(m\leq 50\)
现在有个人可以使用魔法卡片,使用一张会使得所有的怪兽掉一点血,如果有怪兽死亡,则继续施展魔法。
这个人能够获得一定的分数,分数计算如下,每一次使用卡片前,假设一个怪兽血量为\(x\),那么获得\(x^k\)的分数。\(k\)为杀死所有怪兽需要的卡片数量。
求最后总的分数。

思路:
因为\(m\)很小,那么我们可以对每次施展卡片前获得的分数单独计算,最后加起来即可。
那么这个问题的本质就是要算:
\[ \sum_{i=0}^ni^k-\sum_{j=1}^ma_j^k \]
后面一部分显然可以直接计算,那么主要问题就在于计算前面的部分。
而幂级数的形式可以直接用第二类斯特林数展开,最后问题就变为了预处理第二类斯特林数,计算可以直接\(O(k)\)计算。
展开过程详见:传送门

当然,这显然为一个与\(n\)有关的\(k+1\)次多项式,拉格朗日插值搞一搞就行。
当然,还有许多其它的方法,太菜了还不会...
斯特林数:

/*
 * Author:  heyuhhh
 * Created Time:  2019/12/14 11:00:17
 */
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <cmath>
#include <set>
#include <map>
#include <queue>
#include <iomanip>
#define MP make_pair
#define fi first
#define se second
#define sz(x) (int)(x).size()
#define all(x) (x).begin(), (x).end()
#define INF 0x3f3f3f3f
#define Local
#ifdef Local
  #define dbg(args...) do { cout << #args << " -> "; err(args); } while (0)
  void err() { std::cout << '\n'; }
  template<typename T, typename...Args>
  void err(T a, Args...args) { std::cout << a << ' '; err(args...); }
#else
  #define dbg(...)
#endif
void pt() {std::cout << '\n'; }
template<typename T, typename...Args>
void pt(T a, Args...args) {std::cout << a << ' '; pt(args...); }
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
//head
const int N = 55, MOD = 1e9 + 7;

ll n;
int m;
int s[N][N], fac[N], c[N];
ll a[N];

ll qpow(ll a, ll b) {
    a %= MOD;
    ll ans = 1;
    while(b) {
        if(b & 1) ans = ans * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return ans;
}

void init() {
    s[0][0] = 1;
    for(int i = 1; i < N; i++) 
        for(int j = 1; j <= i; j++) 
            s[i][j] = (1ll * s[i - 1][j] * j % MOD + s[i - 1][j - 1]) % MOD;
    fac[0] = 1;
    for(int i = 1; i < N; i++) fac[i] = 1ll * fac[i - 1] * i % MOD;
    c[0] = 1;
}

int calc(ll n, int k) {
    int res = 0;
    for(int i = 1; i <= k + 1; i++) c[i] = 1ll * c[i - 1] * ((n + 2 - i) % MOD) % MOD * qpow(i, MOD - 2) % MOD;
    for(int i = 1; i <= k; i++) {
        res = (res + 1ll * fac[i] * s[k][i] % MOD * c[i + 1] % MOD) % MOD;
    }   
    return res;
}

void run(){
    cin >> n >> m;
    for(int i = 1; i <= m; i++) cin >> a[i];
    sort(a + 1, a + m + 1);
    int ans = 0;
    for(int k = 0; k <= m; k++) {
        int res = calc(n - a[k], m + 1), tmp = 0;
        for(int i = k + 1; i <= m; i++) {
            tmp = (tmp + qpow(a[i] - a[k], m + 1)) % MOD;   
        }
        res = (res + MOD - tmp) % MOD;
        ans = (ans + res) % MOD;
    }
    cout << ans << '\n';
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cout << fixed << setprecision(20);
    init();
    int T; cin >> T;
    while(T--) run();
    return 0;
}

拉格朗日插值:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 55, MOD = 1e9 + 7;
int T;
ll a[N], fac[N];
ll qp(ll A, ll B) {
    ll ans = 1;
    while(B) {
        if(B & 1) ans = ans * A % MOD;
        A = A * A % MOD;
        B >>= 1;
    }
    return ans ;
}
void add(ll &x, ll y, ll z) {
    x += z * y % MOD;
    x %= MOD;
    if(x < 0) x += MOD;
}
void mul(ll &x, ll y) {
    x *= y;
    x %= MOD;
    if(x < 0) x += MOD;
}
ll calc(ll n, ll m) {
    ll ans = 0;
    if(n <= m + 2) {
        for(int i = 1; i <= n; i++) add(ans, qp(i, m), 1) ;
        return ans ;
    }
    ll g = 1, y = 0;
    for(int i = 1; i <= m + 2; i++) mul(g, n - i);
    for(int i = 1; i <= m + 2; i++) {
        ll t = qp(fac[i - 1] * fac[m + 2 - i] % MOD, MOD - 2) ;
        if((m + 2 - i) & 1) t = -t;
        add(y, qp(i, m), 1);
        ll tmp = qp(n - i, MOD - 2);
        mul(tmp, t * y % MOD * g % MOD) ;
        add(ans, tmp, 1);
    }
    return ans;
}
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    fac[0] = 1;
    for(int i = 1; i < N; i++) fac[i] = fac[i - 1] * i % MOD ;
    cin >> T;
    while(T--) {
        int n, m;
        cin >> n >> m;
        for(int i = 1; i <= m; i++) cin >> a[i];
        sort(a + 1, a + m + 1) ;
        ll ans = 0;
        for(int i = 0; i <= m; i++) {
            add(ans, calc(n - a[i], m + 1), 1);
            for(int j = i + 1; j <= m; j++)
                add(ans, qp(a[j] - a[i], m + 1), -1) ;
        }
        cout << ans << '\n';
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/heyuhhh/p/12053771.html