重軽分解をしたかった。最近熱い気がするしcode festival本戦までに書いておきたかった。 予想以上に実装が重く苦労した。達成感はあった。 重軽分解 + LCA + segment木 + 遅延伝播 + 累積和。 良い問題だと思う。

No.235 めぐるはめぐる (5)

問題

頂点数$N$の木($N \le 2\cdot 10^5$)が与えられる。各頂点には重み$S_i$、係数$C_i$が定められている。以下のクエリが$Q$個($Q \le 2\cdot 10^5$)与えられるので処理せよ。

  • $X,Y,Z$が与えられる。頂点$X,Y$間の道上の各頂点$i$について、その重み$S_i$に$C_i \cdot Z$を加えよ。
  • $X,Y$が与えられる。頂点$X,Y$間の道上の各頂点$i$について、その重み$S_i$の総和を求めよ。

解法

道上の各頂点への一様でない加算と総和を処理しなければならない。 愚直に行なえば$O(NQ)$で間に合わない。

heavy-light decompositionをする。 heavy-light decompositionに関してはHeavy-Light Decomposition - math314のブログが詳しい。 これにより木上の道のクエリを、単なる列の上のクエリとして処理できる。

列のクエリになればsegment treeで処理できる。ただし遅延伝播などいくつか工夫が必要。 初期の重み$S_i$は累積和を取っておくなり適当にすればよい。 頂点は係数$C_i$を持つので加算が一様でなく、単純に処理することはできない。 加算された区間を包含する大きな区間に対する総和のクエリであれば問題ないが、そうでない場合は各頂点に別々の値を足す必要がある。 そこで、加算のクエリの時点では区間に対して足すところで処理を止めておいて、総和のクエリで必要になった場合に必要な分だけ続きを分割し加算していけばよい。

実装

#include <iostream>
#include <vector>
#include <map>
#include <cmath>
#include <cassert>
#define repeat_from(i,m,n) for (int i = (m); (i) < (n); ++(i))
#define repeat(i,n) repeat_from(i,0,n)
#define repeat_from_reverse(i,m,n) for (int i = (n)-1; (i) >= (m); --(i))
#define repeat_reverse(i,n) repeat_from_reverse(i,0,n)
typedef long long ll;
using namespace std;
struct heavy_light_decomposition_t {
    int n; // |V'|
    vector<int> a; // V ->> V' epic
    vector<vector<int> > path; // V' -> V*, bottom to top order, disjoint union of codomain matchs V
    vector<map<int,int> > pfind; // V' * V -> int, find in path
    vector<int> parent; // V' -> V
    heavy_light_decomposition_t(int v, vector<vector<int> > const & g) {
        n = 0;
        a.resize(g.size());
        dfs(v, -1, g);
    }
    int dfs(int v, int p, vector<vector<int> > const & g) {
        int heavy_node = -1;
        int heavy_size = 0;
        int desc_size = 1;
        for (int w : g[v]) if (w != p) {
            int size = dfs(w, v, g);
            desc_size += size;
            if (heavy_size < size) {
                heavy_node = w;
                heavy_size = size;
            }
        }
        if (heavy_node == -1) {
            a[v] = n;
            n += 1;
            path.emplace_back();
            path.back().push_back(v);
            pfind.emplace_back();
            pfind.back()[v] = 0;
            parent.push_back(p);
        } else {
            int i = a[heavy_node];
            a[v] = i;
            pfind[i][v] = path[i].size();
            path[i].push_back(v);
            parent[i] = p;
        }
        return desc_size;
    }
};
struct lowest_common_ancestor_t {
    vector<vector<int> > a;
    vector<int> depth;
    lowest_common_ancestor_t(int v, vector<vector<int> > const & g) {
        int n = g.size();
        int l = 1 + floor(log2(n));
        a.resize(l);
        repeat (k,l) a[k].resize(n, -1);
        depth.resize(n);
        dfs(v, -1, 0, g, a[0], depth);
        repeat (k,l-1) {
            repeat (i,n) {
                if (a[k][i] != -1) {
                    a[k+1][i] = a[k][a[k][i]];
                }
            }
        }
    }
    static void dfs(int v, int p, int current_depth, vector<vector<int> > const & g, vector<int> & parent, vector<int> & depth) {
        parent[v] = p;
        depth[v] = current_depth;
        for (int w : g[v]) if (w != p) {
            dfs(w, v, current_depth + 1, g, parent, depth);
        }
    }
    int operator () (int x, int y) const {
        int l = a.size();
        if (depth[x] < depth[y]) swap(x,y);
        repeat_reverse (k,l) {
            if (a[k][x] != -1 and depth[a[k][x]] >= depth[y]) {
                x = a[k][x];
            }
        }
        assert (depth[x] == depth[y]);
        if (x == y) return x;
        repeat_reverse (k,l) {
            if (a[k][x] != a[k][y]) {
                x = a[k][x];
                y = a[k][y];
            }
        }
        assert (x != y);
        assert (a[0][x] == a[0][y]);
        return a[0][x];
    }
};
constexpr ll mod = 1000000007;
struct segment_tree_t {
    int n;
    vector<ll> a;
    vector<ll> b; // lazy propagation
    vector<ll> s; // |s| = n+1, accumulated
    vector<ll> c; // |c| = n+1, accumulated
    segment_tree_t(vector<int> const & x, vector<ll> const & a_s, vector<ll> const & a_c) {
        n = pow(2,ceil(log2(x.size())));
        a.resize(2*n-1);
        b.resize(2*n-1);
        s.resize(n+1); repeat (i, x.size()) s[i+1] = (s[i] + a_s[x[i]]) % mod;
        c.resize(n+1); repeat (i, x.size()) c[i+1] = (c[i] + a_c[x[i]]) % mod;
    }
    void range_add(int l, int r, ll z) {
        assert (l < r);
        range_add(1, 0, n, l, r, z);
    }
    void range_add(int i, int il, int ir, int l, int r, ll z) {
        if (l <= il and ir <= r) {
            a[i-1] = (a[i-1] + z * ((c[ir] - c[il] + mod) % mod) % mod + mod) % mod;
            if (2*i < b.size()) {
                b[2*i-1] = (b[2*i-1] + z) % mod;
                b[2*i]   = (b[2*i]   + z) % mod;
            }
        } else if (r <= il or ir <= l) {
            // nop
        } else {
            range_add(2*i,   il, (il+ir)/2, l, r, z);
            range_add(2*i+1, (il+ir)/2, ir, l, r, z);
            a[i-1] =
                (range_sum(2*i,   il, (il+ir)/2, 0, n) +
                 range_sum(2*i+1, (il+ir)/2, ir, 0, n)) % mod;
        }
    }
    ll range_sum(int l, int r) {
        assert (l < r);
        return (range_sum(1, 0, n, l, r) + s[r] - s[l] + mod) % mod;
    }
    ll range_sum(int i, int il, int ir, int l, int r) {
        if (b[i-1]) {
            a[i-1] = (a[i-1] + b[i-1] * ((c[ir] - c[il] + mod) % mod) % mod + mod) % mod;
            if (2*i < b.size()) {
                b[2*i-1] = (b[2*i-1] + b[i-1]) % mod;
                b[2*i]   = (b[2*i]   + b[i-1]) % mod;
            }
            b[i-1] = 0;
        }
        if (l <= il and ir <= r) {
            return a[i-1];
        } else if (r <= il or ir <= l) {
            return 0;
        } else {
            return
                (range_sum(2*i,   il, (il+ir)/2, l, r) +
                 range_sum(2*i+1, (il+ir)/2, ir, l, r)) % mod;
        }
    }
};
ll query_sum_path(heavy_light_decomposition_t & hl, lowest_common_ancestor_t & lca, vector<segment_tree_t> & st, int v, int w) {
    ll a = 0;
    int i = hl.a[v];
    if (hl.a[w] == i) {
        assert (hl.pfind[i][v] <= hl.pfind[i][w]);
        a = (a + st[i].range_sum(hl.pfind[i][v], hl.pfind[i][w]+1)) % mod;
    } else {
        a = (a + st[i].range_sum(hl.pfind[i][v], hl.path[i].size())) % mod;
        a = (a + query_sum_path(hl, lca, st, hl.parent[i], w)) % mod;
    }
    return a;
}
ll query_sum(heavy_light_decomposition_t & hl, lowest_common_ancestor_t & lca, vector<segment_tree_t> & st, int x, int y) {
    int w = lca(x, y);
    ll a = 0;
    a = (a + query_sum_path(hl, lca, st, x, w)) % mod;
    a = (a + query_sum_path(hl, lca, st, y, w)) % mod;
    a = (a - query_sum_path(hl, lca, st, w, w) + mod) % mod;
    return a;
}
void query_add_path(heavy_light_decomposition_t & hl, lowest_common_ancestor_t & lca, vector<segment_tree_t> & st, int v, int w, ll z) {
    int i = hl.a[v];
    if (hl.a[w] == i) {
        assert (hl.pfind[i][v] <= hl.pfind[i][w]);
        st[i].range_add(hl.pfind[i][v], hl.pfind[i][w]+1, z);
    } else {
        st[i].range_add(hl.pfind[i][v], hl.path[i].size(), z);
        query_add_path(hl, lca, st, hl.parent[i], w, z);
    }
}
void query_add(heavy_light_decomposition_t & hl, lowest_common_ancestor_t & lca, vector<segment_tree_t> & st, int x, int y, ll z) {
    int w = lca(x, y);
    query_add_path(hl, lca, st, x, w,   z);
    query_add_path(hl, lca, st, y, w,   z);
    query_add_path(hl, lca, st, w, w, - z);
}
int main() {
    int n; cin >> n;
    vector<ll> s(n); repeat (i,n) cin >> s[i];
    vector<ll> c(n); repeat (i,n) cin >> c[i];
    vector<vector<int> > g(n);
    repeat (i,n-1) {
        int a, b; cin >> a >> b;
        -- a; -- b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    heavy_light_decomposition_t hl(0, g);
    lowest_common_ancestor_t lca(0, g);
    vector<segment_tree_t> st;
    repeat (i,hl.n) st.emplace_back(hl.path[i], s, c);
    int query; cin >> query;
    while (query --) {
        int query_type, x, y; cin >> query_type >> x >> y;
        -- x; -- y;
        if (query_type) {
            cout << query_sum(hl, lca, st, x, y) << endl;
        } else {
            ll z; cin >> z;
            query_add(hl, lca, st, x, y, z);
        }
    }
    return 0;
}

疲れた。