【LOJ】#3046. 「ZJOI2019」语言

LOJ#3046. 「ZJOI2019」语言

先orz zsy吧

有一个\(n\log^3n\)的做法是把树链剖分后,形成logn个区间,这些区间两两搭配可以获得一个矩形,求矩形面积并

然后就是对于一个点把树链的log个区间加进去然后线段树合并,这是\(n \log^2 n\)

链并会形成一棵树,如果我们把经过某个点的链的端点按dfn序排序的话,相邻两项算一下距离,首尾两项再算一下,我们就可以获得链并的这棵树的边权和×2,由此可以求树上的点的个数

我们要求的就是经过每个点的链并-1的和,然后再除2

对于每个点,开一个线段树,如果有一条链\(s,t\)\(s\)的线段树上的\(s\)位置+ 1,在\(t\)的线段树上\(t\)的位置+1,在\(t\)的线段树上进行相同的操作

每次看看这个位置累加的是不是正数,是正数证明这个点的链并有这个点

\(lca\)的上方把这两个点的贡献删除即可

复杂度\(n \log n\)

#include <bits/stdc++.h>
#define fi first
#define se second
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
#define space putchar(' ')
#define enter putchar('\n')
#define eps 1e-10
#define MAXN 100005
#define ba 47
//#define ivorysi
using namespace std;
typedef long long int64;
typedef unsigned int u32;
typedef double db;
template<class T>
void read(T &res) {
    res = 0;T f = 1;char c = getchar();
    while(c < '0' || c > '9') {
    if(c == '-') f = -1;
    c = getchar();
    }
    while(c >= '0' && c <= '9') {
    res = res * 10 +c - '0';
    c = getchar();
    }
    res *= f;
}
template<class T>
void out(T x) {
    if(x < 0) {x = -x;putchar('-');}
    if(x >= 10) {
    out(x / 10);
    }
    putchar('0' + x % 10);
}
int N,M;
struct node {
    int to,next;
}E[MAXN * 2];

int head[MAXN],sumE;
int fa[MAXN],dep[MAXN],top[MAXN],siz[MAXN],dfn[MAXN],idx,line[MAXN];
int len[MAXN * 2],st[MAXN * 2][20],tot,pos[MAXN];
vector<int> del[MAXN];
int64 ans = 0;
int mindex(int a,int b) {return dep[a] < dep[b] ? a : b;}
int Query(int a,int b) {
    int l = len[b - a + 1];
    return mindex(st[a][l],st[b - (1 << l) + 1][l]);
}
int lca(int a,int b) {
    int u = pos[a],v = pos[b];
    if(u > v) swap(u,v);
    return Query(u,v);
}
int dist(int a,int b) {
    return dep[a] + dep[b] - 2 * dep[lca(a,b)];
}
void add(int u,int v) {
    E[++sumE].to = v;
    E[sumE].next = head[u];
    head[u] = sumE;
}
void dfs(int u) {
    dep[u] = dep[fa[u]] + 1;siz[u] = 1;dfn[u] = ++idx;line[idx] = u;
    st[++tot][0] = u;pos[u] = tot;
    for(int i = head[u] ; i ; i = E[i].next) {
    int v = E[i].to;
    if(v != fa[u]) {
        fa[v] = u;
        dfs(v);
        st[++tot][0] = u;
        siz[u] += siz[v];
    }
    }
}

struct tr_node {
    int ls,rs,lp,rp,cnt;
    int64 sum;
}tr[MAXN * 100];
int rt[MAXN],Ncnt = 0;
void update(int u) {
    int lson = tr[u].ls,rson = tr[u].rs;
    tr[u].rp = tr[rson].rp ? tr[rson].rp : tr[lson].rp;
    tr[u].lp = tr[lson].lp ? tr[lson].lp : tr[rson].lp;
    tr[u].sum = tr[lson].sum + tr[rson].sum;
    if(tr[lson].rp && tr[rson].lp) tr[u].sum += dist(tr[lson].rp,tr[rson].lp);
}
void Add(int &u,int l,int r,int pos,int v) {
    if(!u) u = ++Ncnt;
    if(l == r) {
    tr[u].cnt += v;
    if(tr[u].cnt) {tr[u].lp = tr[u].rp = line[pos];tr[u].sum = 0;}
    else {tr[u].lp = tr[u].rp = tr[u].sum = 0;}
    return;
    }
    int mid = (l + r) >> 1;
    if(pos <= mid) Add(tr[u].ls,l,mid,pos,v);
    else if(pos > mid) Add(tr[u].rs,mid + 1,r,pos,v);
    update(u);
}
int Merge(int u,int v,int l,int r) {
    if(!u) return v;
    if(!v) return u;
    if(l == r) {
    tr[u].cnt = tr[u].cnt + tr[v].cnt;
    if(tr[u].cnt) {tr[u].lp = tr[u].rp = line[l];tr[u].sum = 0;}
    else {tr[u].lp = tr[u].rp = tr[u].sum = 0;}
    return u;
    }
    int mid = (l + r) >> 1;
    tr[u].ls = Merge(tr[u].ls,tr[v].ls,l,mid);
    tr[u].rs = Merge(tr[u].rs,tr[v].rs,mid + 1,r);
    update(u);
    return u;
}
void Calc(int u) {
    for(int i = head[u] ; i ; i = E[i].next) {
    int v = E[i].to;
    if(v != fa[u]) {
        Calc(v);
        rt[u] = Merge(rt[u],rt[v],1,N);
    }
    }
    for(auto t : del[u]) Add(rt[u],1,N,dfn[t],-2);
    if(tr[rt[u]].lp && tr[rt[u]].rp) {
    int64 res = tr[rt[u]].sum + dist(tr[rt[u]].lp,tr[rt[u]].rp);res /= 2;
    ans += res;
    }
}
void Solve() {
    read(N);read(M);
    int a,b;
    for(int i = 1 ; i < N ; ++i) {
    read(a);read(b);
    add(a,b);add(b,a);
    }
    dfs(1);
    for(int i = 2 ; i <= tot ; ++i) len[i] = len[i / 2] + 1;
    for(int j = 1 ; j <= 19 ; ++j) {
    for(int i = 1 ; i <= tot ; ++i) {
        if(i + (1 << j) - 1 > tot) break;
        st[i][j] = mindex(st[i][j - 1],st[i + (1 << j - 1)][j - 1]);
    }
    }
    for(int i = 1 ; i <= M ; ++i) {
    read(a);read(b);
    Add(rt[a],1,N,dfn[a],1);Add(rt[a],1,N,dfn[b],1);
    Add(rt[b],1,N,dfn[b],1);Add(rt[b],1,N,dfn[a],1);
    int f = lca(a,b);
    if(fa[f]) {del[fa[f]].pb(a);del[fa[f]].pb(b);}
    }
    Calc(1);
    ans /= 2;
    out(ans);enter;
}
int main() {
#ifdef ivorysi
    freopen("f1.in","r",stdin);
#endif
    Solve();
}

猜你喜欢

转载自www.cnblogs.com/ivorysi/p/10972493.html