JZOJ 5752 斐波那契树

传送门
题面

思路

唉,我太弱了,什么都不会,好不容易 Rush 一波 T2,原因竟然是做不来 T1。看到大家都把 T1 A 了,看到自己又垫底了,不禁感叹:唉,我太弱啦!

事实上,解决问题还是要从两个方向出发:一是 Limited Constraint,而是 Special Instance。前者对这道题来说似乎不太凑效:直接打暴力可以 O ( n 2 ) 维护所有信息,也可以离线对后面的询问进行操作,没有太大研究价值。所以我们来看看 Special Instance。

这道题给了两个 Special Instance。第一个是 m = 0 ,也就是说只给自己加权值……算了不研究第一个,随便乱搞就有 40 分了,我们研究第二个。第二个是 u = 1 ,也就是只对 1 结点进行操作 1

显然我们得把操作想办法合并在一起,然后一并算出某个查询的答案,否则我们不可能降低时间复杂度。发现这样一个性质:我们对结点 1 做两次操作 1 ,距离都为无穷,首项分别为 a 1 b 1 a 2 b 2 。那么我们事实上可以将它看成一个操作,首项为 a 1 + a 2 b 1 + b 2 。现在加上距离限制,我们发现,我们可以只保留首项。由于某个结点到 1 结点的深度是确定的,因此只要我们求出了要算在内的 a b ,我们就能在 O ( log n ) 的时间内计算出某个结点的答案。

可以用一个树状数组,每次修改时,我们让深度范围在 m 内的结点的首项都加上 a b ,查询时我们只需要查询该结点在对应深度的 a b 分别的和就可以了。这相当于是一个区间加,单点查的操作。时间复杂度为 O ( n log n ) 。顺利得到 70 分。

上面的讨论默认了 1 结点为根结点。现在我们还是令 1 为根,但是只对 2 结点进行操作 1 (假设 2 结点是 1 的儿子),这会发生什么呢?发现,我们抛弃 2 结点所在子树,只看 1 的话,相当于递推一项,距离减一,其它的都一样。而对于 2 结点所在子树,我们就不能看 1 结点了——这难道不会让你想到点分治吗?显然原问题可以用点分治做,思路就是这一段的内容:递推 d 项,距离减 d 。所以这个题就这么做完啦!

时间复杂度 O ( n log 2 n )

扫描二维码关注公众号,回复: 1214514 查看本文章

你不会点分治?不会就算了。我相信会点分治的人到这一步一定都能做了。

参考代码
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <cassert>
#include <cctype>
#include <climits>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <deque>
#include <map>
#include <set>
#include <bitset>
#include <list>
#include <functional>
typedef long long LL;
typedef unsigned long long ULL;
using std::cin;
using std::cout;
using std::endl;
typedef int INT_PUT;
INT_PUT readIn()
{
    INT_PUT a = 0; bool positive = true;
    char ch = getchar();
    while (!(ch == '-' || std::isdigit(ch))) ch = getchar();
    if (ch == '-') { positive = false; ch = getchar(); }
    while (std::isdigit(ch)) { a = a * 10 - (ch - '0'); ch = getchar(); }
    return positive ? -a : a;
}
void printOut(INT_PUT x)
{
    char buffer[20]; int length = 0;
    if (x < 0) putchar('-'); else x = -x;
    do buffer[length++] = -(x % 10) + '0'; while (x /= 10);
    do putchar(buffer[--length]); while (length);
    putchar('\n');
}

const int mod = int(1e9) + 7;
const int maxn = int(2e5) + 5;
int n, q;

struct Graph
{
    struct Edge
    {
        int to;
        int next;
    } edges[maxn * 2];
    int i;
    int head[maxn];
    Graph() : i() { memset(head, -1, sizeof(head)); }
    void addEdge(int from, int to)
    {
        edges[i].to = to;
        edges[i].next = head[from];
        head[from] = i;
        i++;
    }
#define idx(G) idx_##G
#define wander(G, node) for (int idx(G) = G.head[node];  ~idx(G); idx(G) = G.edges[idx(G)].next)
#define DEF(G) const Graph::Edge& e = G.edges[idx(G)]; int to = e.to
} G;

struct Ins
{
    int type;
    int u, m, a, b;
    void read()
    {
        type = readIn();
        u = readIn();
        if (type == 1)
        {
            m = readIn();
            a = readIn();
            b = readIn();
        }
    }
} ins[maxn];

std::vector<int> offlineq[maxn];
std::vector<int> offlinem[maxn];

int fib[maxn];
std::pair<int, int> GetFib(LL a, LL b, int t)
{
    if (t == 0) return std::make_pair(a, b);
    if (t == 1) return std::make_pair(b, (a + b) % mod);
    static int rect[20][2][2];
    static bool inited;
    if (!inited)
    {
        inited = true;
        rect[0][1][0] = 1;
        rect[0][0][1] = 1;
        rect[0][1][1] = 1;
        for (int v = 1; v < 20; v++)
        {
            for (int i = 0; i < 2; i++)
                for (int k = 0; k < 2; k++) if (rect[v - 1][i][k])
                    for (int j = 0; j < 2; j++)
                        rect[v][i][j] = (rect[v][i][j] +
                        (LL)rect[v - 1][i][k] * rect[v - 1][k][j]) % mod;
        }
    }
    int x = 0;
    while (t)
    {
        if (t & 1)
        {
            LL ta = a * rect[x][0][0] + b * rect[x][1][0];
            LL tb = a * rect[x][0][1] + b * rect[x][1][1];
            a = ta % mod;
            b = tb % mod;
        }
        x++;
        t >>= 1;
    }
    return std::make_pair(a, b);
}
struct BIT
{
    int bit[maxn];
    inline static int lowbit(int x) { return x & (-x); }
    void add(int pos, int val)
    {
        register int t;
        while (pos <= n)
        {
            bit[pos] = (t = bit[pos] + val) >= mod ? t - mod : t;
            pos += lowbit(pos);
        }
    }
    int query(int pos)
    {
        register int ret = 0;
        register int t;
        while (pos)
        {
            ret = (t = ret + bit[pos]) >= mod ? t - mod : t;
            pos ^= lowbit(pos);
        }
        return ret;
    }
    void clear(int pos)
    {
        while (pos)
        {
            if (bit[pos]) bit[pos] = 0;
            else break;
            pos += lowbit(pos);
        }
    }
} bit[2];

LL ans[maxn];

bool vis[maxn];
int size[maxn];
void DFS1(int node, int parent)
{
    size[node] = 1;
    wander(G, node)
    {
        DEF(G);
        if (to == parent || vis[to]) continue;
        DFS1(to, node);
        size[node] += size[to];
    }
}
int findRoot(int node, int parent, int s)
{
    wander(G, node)
    {
        DEF(G);
        if (to == parent || vis[to]) continue;
        if (size[to] > (s >> 1))
            return findRoot(to, node, s);
    }
    return node;
}
int depth[maxn];
void DFS2(int node, int parent, std::vector<int>& qs, std::vector<int>& ms)
{
    for (const int& t : offlineq[node])
        qs.push_back(t);
    for (const int& t : offlinem[node])
        ms.push_back(t);
    wander(G, node)
    {
        DEF(G);
        if (to == parent || vis[to]) continue;
        depth[to] = depth[node] + 1;
        DFS2(to, node, qs, ms);
    }
}
void DFS3(int node, int parent, std::vector<int>& qs, std::vector<int>& ms)
{
    for (const int& t : offlineq[node])
        qs.push_back(t);
    for (const int& t : offlinem[node])
        ms.push_back(t);
    wander(G, node)
    {
        DEF(G);
        if (to == parent || vis[to]) continue;
        DFS3(to, node, qs, ms);
    }
}
void solve(int node, int parent)
{
    if (parent)
    {
        std::vector<int> qs;
        std::vector<int> ms;
        DFS3(node, 0, qs, ms);

        std::sort(qs.begin(), qs.end());
        std::sort(ms.begin(), ms.end());
        int j = 0;
        for (int i = 0; i < qs.size(); i++)
        {
            for (; j < ms.size() && ms[j] < qs[i]; j++)
            {
                int nodet = ins[ms[j]].u;
                int dis = ins[ms[j]].m;
                if (dis < depth[nodet])
                    continue;
                auto val = GetFib(ins[ms[j]].a, ins[ms[j]].b, depth[nodet]);
                bit[0].add(1, val.first);
                bit[1].add(1, val.second);
                bit[0].add(dis - depth[nodet] + 2, (-val.first + mod) % mod);
                bit[1].add(dis - depth[nodet] + 2, (-val.second + mod) % mod);
            }
            int nodet = ins[qs[i]].u;
            LL a = bit[0].query(depth[nodet] + 1);
            LL b = bit[1].query(depth[nodet] + 1);
            auto val = GetFib(a, b, depth[nodet]);
            ans[qs[i]] = ((ans[qs[i]] - val.first) % mod + mod) % mod;
        }
        bit[0].clear(1);
        bit[1].clear(1);
        for (const int& t : ms)
        {
            int nodet = ins[t].u;
            int dis = ins[t].m;
            if (dis < depth[nodet])
                continue;
            bit[0].clear(dis - depth[nodet] + 2);
            bit[1].clear(dis - depth[nodet] + 2);
        }
    }
    DFS1(node, 0);
    node = findRoot(node, 0, size[node]);
    vis[node] = true;

    {
        std::vector<int> qs;
        std::vector<int> ms;
        depth[node] = 0;
        DFS2(node, 0, qs, ms);

        std::sort(qs.begin(), qs.end());
        std::sort(ms.begin(), ms.end());
        int j = 0;
        for (int i = 0; i < qs.size(); i++)
        {
            for (; j < ms.size() && ms[j] < qs[i]; j++)
            {
                int nodet = ins[ms[j]].u;
                int dis = ins[ms[j]].m;
                if (dis < depth[nodet])
                    continue;
                auto val = GetFib(ins[ms[j]].a, ins[ms[j]].b, depth[nodet]);
                bit[0].add(1, val.first);
                bit[1].add(1, val.second);
                bit[0].add(dis - depth[nodet] + 2, (-val.first + mod) % mod);
                bit[1].add(dis - depth[nodet] + 2, (-val.second + mod) % mod);
            }
            int nodet = ins[qs[i]].u;
            LL a = bit[0].query(depth[nodet] + 1);
            LL b = bit[1].query(depth[nodet] + 1);
            auto val = GetFib(a, b, depth[nodet]);
            ans[qs[i]] = (ans[qs[i]] + val.first) % mod;
        }
        bit[0].clear(1);
        bit[1].clear(1);
        for (const int& t : ms)
        {
            int nodet = ins[t].u;
            int dis = ins[t].m;
            if (dis < depth[nodet])
                continue;
            bit[0].clear(dis - depth[nodet] + 2);
            bit[1].clear(dis - depth[nodet] + 2);
        }
    }

    wander(G, node)
    {
        DEF(G);
        if (!vis[to])
            solve(to, node);
    }
}

void run()
{
    n = readIn();
    q = readIn();
    for (int i = 2; i <= n; i++)
    {
        int from = readIn();
        int to = readIn();
        G.addEdge(from, to);
        G.addEdge(to, from);
    }
    for (int i = 1; i <= q; i++)
    {
        ins[i].read();
        if (ins[i].type == 1)
            offlinem[ins[i].u].push_back(i);
        else
            offlineq[ins[i].u].push_back(i);
    }

    fib[0] = 1;
    fib[1] = 0;
    for (int i = 2; i <= n; i++)
        fib[i] = (fib[i - 2] + fib[i - 1]) % mod;

    solve(1, 0);
    for (int i = 1; i <= q; i++)
        if (ins[i].type == 2)
            printOut(ans[i]);
}

int main()
{
#ifndef LOCAL
    freopen("fibtree.in", "r", stdin);
    freopen("fibtree.out", "w", stdout);
#endif
    run();
    return 0;
}

猜你喜欢

转载自blog.csdn.net/lycheng1215/article/details/80511335
今日推荐