LOJ #3046. "ZJOI2019" Language

LOJ #3046. "ZJOI2019" Language

First of all, orz zsy (an expression of admiration for zsy).

An \(O(n\log^3 n)\) approach: split each chain with heavy-light decomposition into \(O(\log n)\) dfn intervals. Pairing these intervals two at a time gives rectangles, and the answer is the area of the union of those rectangles.

The next idea: for each point, insert its \(O(\log n)\) intervals into a segment tree and then merge segment trees up the tree. This is \(O(n\log^2 n)\).

The union of the chains passing through a point forms a subtree. Sort the endpoints of all chains through that point by dfn. Sum the distances between adjacent endpoints, and add the distance from the last endpoint back to the first. This gives twice the number of edges of that union, and from it we get the number of points in the union (edges + 1).

The desired answer is the sum, over all points, of (the number of points in the union of the chains through it) − 1, divided by 2.

Open a segment tree (indexed by dfn) for every point. For a chain \((s, t)\), add 1 at position \(s\) and at position \(t\) in \(s\)'s segment tree, and do the same in \(t\)'s segment tree.

Merge each child's segment tree into its parent's. After that, a position whose accumulated count is positive is an endpoint of some chain passing through the current point.

Just above the chain's \(\mathrm{lca}\) (at its parent), delete the contribution of these two endpoints.

Complexity: \(O(n\log n)\).

#include <bits/stdc++.h>
// Short-hand aliases; of these, only pb is referenced in the visible code
// (fi, se, pii and mp are not).
#define fi first
#define se second
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
// Write a single space / newline to stdout (enter is used after printing the answer).
#define space putchar(' ')
#define enter putchar('\n')
// NOTE(review): eps is not referenced in the visible code.
#define eps 1e-10
// Bound on the number of tree nodes; sizes every per-node array below.
#define MAXN 100005
// NOTE(review): meaning of ba unknown; it is not referenced in the visible code.
#define ba 47
// Uncomment to make main() read its input from f1.in (local debugging).
//#define ivorysi
using namespace std;
// int64 holds the answer and segment-tree sums; u32 and db are not referenced in the visible code.
typedef long long int64;
typedef unsigned int u32;
typedef double db;
/// Fast reader for one (optionally negative) decimal integer.
/// Characters before the first digit are skipped; a '-' among them makes the
/// result negative.
/// @param res  receives the parsed value; stays 0 if the stream ends before any digit
/// @param in   stream to read from; defaults to stdin, so existing calls are unchanged
template<class T>
void read(T &res, FILE *in = stdin) {
    res = 0;
    T f = 1;
    // int, not char: EOF must stay distinguishable from a 0xFF byte, and must be
    // detectable even on platforms where plain char is unsigned.
    int c = getc(in);
    while(c < '0' || c > '9') {
        // Stop at end of input instead of spinning forever.
        if(c == EOF) return;
        if(c == '-') f = -1;
        c = getc(in);
    }
    while(c >= '0' && c <= '9') {
        // Add the digit value in one step: `res * 10 + c - '0'` adds the raw
        // character code first and can overflow just below the type's maximum.
        res = res * 10 + (c - '0');
        c = getc(in);
    }
    res *= f;
}
/// Writes the decimal representation of integer x, with a leading '-' if negative.
/// x is never negated, so the most negative value of T prints correctly
/// (negating it would overflow).
/// @param x   value to print
/// @param os  destination stream; defaults to stdout, so existing calls are unchanged
template<class T>
void out(T x, FILE *os = stdout) {
    char buf[48];  // enough digits for any integer type up to 128 bits
    int n = 0;
    const bool neg = x < 0;
    // Collect digits least-significant first. Since C++11, `/` and `%` truncate
    // toward zero, so for negative x each remainder lies in [-9, 0].
    do {
        int d = static_cast<int>(x % 10);
        if(d < 0) d = -d;
        buf[n++] = static_cast<char>('0' + d);
        x /= 10;
    } while(x != 0);
    if(neg) putc('-', os);
    while(n > 0) putc(buf[--n], os);
}
// Input sizes: N tree nodes, M chains.
int N, M;

// Forward-star adjacency list. Every undirected tree edge is stored in both
// directions, hence 2 * MAXN slots; edge index 0 means "no edge".
struct node {
    int to;    // target node of this directed edge
    int next;  // next edge leaving the same source node
} E[MAXN * 2];

int head[MAXN];  // first edge leaving each node
int sumE;        // number of edges stored so far

// Filled by dfs():
//   fa   - parent
//   dep  - depth (the root has depth 1)
//   siz  - subtree size
//   dfn  - preorder index, with inverse line (line[dfn[u]] == u)
// `top` is not referenced in the visible code.
int fa[MAXN];
int dep[MAXN];
int top[MAXN];
int siz[MAXN];
int dfn[MAXN];
int idx;
int line[MAXN];

// Euler-tour sparse table for LCA:
//   st[i][j] - shallowest node among tour positions [i, i + 2^j - 1]
//   len[k]   - floor(log2 k)
//   tot      - tour length
//   pos[u]   - position where u first appears in the tour
int len[MAXN * 2];
int st[MAXN * 2][20];
int tot;
int pos[MAXN];

// del[x]: chain endpoints whose counts are removed when Calc() reaches x.
vector<int> del[MAXN];
// Sum over nodes of (size of the union of chains through it) - 1; halved at the end.
int64 ans = 0;
// Sparse-table combiner: returns the shallower of a and b (b when depths tie).
int mindex(int a,int b) {
    if(dep[a] < dep[b]) return a;
    return b;
}
// Range query over Euler-tour positions [a, b] (requires a <= b): returns the
// shallowest node, combining two overlapping power-of-two windows.
int Query(int a,int b) {
    const int k = len[b - a + 1];
    const int leftWin = st[a][k];
    const int rightWin = st[b - (1 << k) + 1][k];
    return mindex(leftWin,rightWin);
}
// Lowest common ancestor of a and b: the shallowest node between their first
// appearances in the Euler tour.
int lca(int a,int b) {
    const int first = pos[a];
    const int second = pos[b];
    return Query(std::min(first,second),std::max(first,second));
}
// Number of edges on the tree path between a and b.
int dist(int a,int b) {
    const int meet = lca(a,b);
    return (dep[a] - dep[meet]) + (dep[b] - dep[meet]);
}
// Prepends the directed edge u -> v to u's forward-star list. Edge ids start at 1,
// so 0 in head/next terminates a list.
void add(int u,int v) {
    const int id = ++sumE;
    E[id].to = v;
    E[id].next = head[u];
    head[u] = id;
}
// Recursive traversal from u (fa[u] must already be set; 0 for the root).
// Records depth, subtree size and preorder index (dfn/line), and builds the Euler
// tour used for LCA: u is appended on entry and again after each child returns.
void dfs(int u) {
    dep[u] = dep[fa[u]] + 1;
    siz[u] = 1;
    ++idx;
    dfn[u] = idx;
    line[idx] = u;
    ++tot;
    st[tot][0] = u;
    pos[u] = tot;
    for(int e = head[u] ; e != 0 ; e = E[e].next) {
        const int child = E[e].to;
        if(child == fa[u]) continue;
        fa[child] = u;
        dfs(child);
        ++tot;
        st[tot][0] = u;
        siz[u] += siz[child];
    }
}

// Node of a dynamically allocated segment tree over dfn positions.
// tr[0] is an all-zero "empty" node that is never written (allocation starts at 1).
// NOTE(review): MAXN * 100 of these is several hundred MB of static storage
// (~320 MB if int64 is 8-byte aligned); confirm against the memory limit.
struct tr_node {
    int ls;     // left child (0 = empty)
    int rs;     // right child (0 = empty)
    int lp;     // node with the smallest dfn present in this range (0 = none)
    int rp;     // node with the largest dfn present in this range (0 = none)
    int cnt;    // leaves only: accumulated endpoint count at this dfn
    int64 sum;  // sum of dist() between consecutive present nodes, in dfn order
} tr[MAXN * 100];
int rt[MAXN];   // segment-tree root for each tree node (0 = empty)
int Ncnt = 0;   // segment-tree nodes allocated so far
void update(int u) {
    int lson = tr[u].ls,rson = tr[u].rs;
    tr[u].rp = tr[rson].rp ? tr[rson].rp : tr[lson].rp;
    tr[u].lp = tr[lson].lp ? tr[lson].lp : tr[rson].lp;
    tr[u].sum = tr[lson].sum + tr[rson].sum;
    if(tr[lson].rp && tr[rson].lp) tr[u].sum += dist(tr[lson].rp,tr[rson].lp);
}
void Add(int &u,int l,int r,int pos,int v) {
    if(!u) u = ++Ncnt;
    if(l == r) {
    tr[u].cnt += v;
    if(tr[u].cnt) {tr[u].lp = tr[u].rp = line[pos];tr[u].sum = 0;}
    else {tr[u].lp = tr[u].rp = tr[u].sum = 0;}
    return;
    }
    int mid = (l + r) >> 1;
    if(pos <= mid) Add(tr[u].ls,l,mid,pos,v);
    else if(pos > mid) Add(tr[u].rs,mid + 1,r,pos,v);
    update(u);
}
// Destructively merges segment tree v into u (both cover [l, r]) and returns the
// merged root. u's nodes are reused, so v must not be used afterwards.
int Merge(int u,int v,int l,int r) {
    if(u == 0 || v == 0) return u != 0 ? u : v;
    if(l != r) {
        const int mid = (l + r) >> 1;
        tr[u].ls = Merge(tr[u].ls,tr[v].ls,l,mid);
        tr[u].rs = Merge(tr[u].rs,tr[v].rs,mid + 1,r);
        update(u);
        return u;
    }
    // Leaf: combine counts; the leaf represents line[l] while its count is nonzero.
    tr[u].cnt += tr[v].cnt;
    const int who = tr[u].cnt != 0 ? line[l] : 0;
    tr[u].lp = who;
    tr[u].rp = who;
    tr[u].sum = 0;
    return u;
}
void Calc(int u) {
    for(int i = head[u] ; i ; i = E[i].next) {
    int v = E[i].to;
    if(v != fa[u]) {
        Calc(v);
        rt[u] = Merge(rt[u],rt[v],1,N);
    }
    }
    for(auto t : del[u]) Add(rt[u],1,N,dfn[t],-2);
    if(tr[rt[u]].lp && tr[rt[u]].rp) {
    int64 res = tr[rt[u]].sum + dist(tr[rt[u]].lp,tr[rt[u]].rp);res /= 2;
    ans += res;
    }
}
// Steps:
//   - read the tree and build the Euler-tour sparse table for LCA;
//   - insert each chain's two endpoints into both endpoints' segment trees;
//   - schedule their removal at the parent of the chain's lca;
//   - run Calc from the root and print the answer.
void Solve() {
    read(N);
    read(M);
    int a,b;
    for(int e = 1 ; e < N ; ++e) {
        read(a);
        read(b);
        add(a,b);
        add(b,a);
    }
    dfs(1);
    // len[i] = floor(log2(i))
    for(int i = 2 ; i <= tot ; ++i) len[i] = len[i >> 1] + 1;
    for(int j = 1 ; j <= 19 ; ++j) {
        const int half = 1 << (j - 1);
        for(int i = 1 ; i + (1 << j) - 1 <= tot ; ++i)
            st[i][j] = mindex(st[i][j - 1],st[i + half][j - 1]);
    }
    for(int q = 1 ; q <= M ; ++q) {
        read(a);
        read(b);
        Add(rt[a],1,N,dfn[a],1);
        Add(rt[a],1,N,dfn[b],1);
        Add(rt[b],1,N,dfn[b],1);
        Add(rt[b],1,N,dfn[a],1);
        const int f = lca(a,b);
        // Both endpoint trees have been merged by the time Calc reaches fa[f],
        // so each endpoint is counted twice there and removed with -2.
        if(fa[f] != 0) {
            del[fa[f]].push_back(a);
            del[fa[f]].push_back(b);
        }
    }
    Calc(1);
    // Every unordered pair was counted once from each side.
    ans /= 2;
    out(ans);
    putchar('\n');
}
// Entry point. Defining `ivorysi` redirects stdin from f1.in for local debugging.
int main() {
#ifdef ivorysi
    freopen("f1.in","r",stdin);
#endif
    Solve();
    return 0;
}

You may also like

Source: www.cnblogs.com/ivorysi/p/10972493.html