Mo’s algorithmの練習。方針はすぐだったが細部がつらかった。

problem

数列$a_i$が与えられる。 クエリとして与えられる区間$[l_i, r_i)$ごとに、その中に含まれるxor総和が$k$となるような部分区間の数を答えよ。

solution

Mo’s algorithm。$O((N+M)\sqrt{N})$で空間$O(\max \{ k, \max a_i \})$。

事前に累積和$b_i$を取っておくと、区間$[l_i, r_i)$に対する答えは$l_i \le i \lt j \le r_i$で$b_i \oplus b_j \oplus k = 0$なものの数になる。 ある区間$[l, r)$に対する答えと区間中の$b_i$の出現頻度の表があれば、区間$[l \pm 1, r)$や$[l, r \pm 1)$に対してのそれを$O(1)$で計算できる。 $b_i$の出現頻度の表は配列で持ち、答えの増減は丁寧にやる。 これはMo’s algorithmを適用でき$O((N+M)\sqrt{N})$。

詳細について。 累積和$b_i = \sum_{j \lt i}^{\oplus} a_j$。 区間$[l,r)$ごとに、答え$\mathrm{ans}_{l,r} = |\{ [i,j) \subseteq [l,r) \mid i \ne j \land b_j \oplus b_i = k \}|$、頻度表$c_{l,r}(i) = |\{ j \in [l,r) \mid b_j = i \}|$を持つ。 例えば右端を伸ばすとして、$\mathrm{ans}_{l,r+1} = |\{ [i,j) \subseteq [l,r+1) \mid i \ne j \land b_j \oplus b_i = k \}| = \mathrm{ans}_{l,r} + |\{ i \in [l,r+1) \mid b_{r+1} \oplus b_i = k \}| = \mathrm{ans}_{l,r} + c_{l,r+1}(b_{r+1} \oplus k)$。

implementation

#include <iostream>
#include <vector>
#include <algorithm>
#include <numeric>
#include <cmath>
#define repeat(i,n) for (int i = 0; (i) < int(n); ++(i))
#define whole(f,x,...) ([&](decltype((x)) whole) { return (f)(begin(whole), end(whole), ## __VA_ARGS__); })(x)
using ll = long long;
using namespace std;
int main() {
    // input
    int n, m, k; cin >> n >> m >> k;
    vector<int> a(n); repeat (i,n) cin >> a[i];
    vector<int> l(m), r(m); repeat (i,m) { cin >> l[i] >> r[i]; -- l[i]; }
    // solve
    //   Mo's algorithm
    int sqrt_n = sqrt(n);
    vector<int> ixs(m);
    whole(iota, ixs, 0);
    whole(sort, ixs, [&](int i, int j) {
        return make_pair(l[i] / sqrt_n, r[i]) < make_pair(l[j] / sqrt_n, r[j]);
    });
    vector<int> acc { 0 }; whole(partial_sum, a, back_inserter(acc), bit_xor<int>());
    vector<int> accs(max(k, *whole(max_element, a)) * 2 + 3);
    ll cnt = 0;
    int l_cur = 0, r_cur = 0;
    vector<ll> ans(m);
    for (int i : ixs) {
        while (l[i] < l_cur) {
            int i = -- l_cur;
            if ((acc[i] ^ acc[r_cur]) == k) cnt += 1;
            cnt += accs[acc[i] ^ k];
            accs[acc[i]] += 1;
        }
        while (r_cur < r[i]) {
            int i = r_cur ++;
            accs[acc[i]] += 1;
            cnt += accs[acc[i+1] ^ k];
        }
        while (l_cur < l[i]) {
            int i = l_cur ++;
            accs[acc[i]] -= 1;
            cnt -= accs[acc[i] ^ k];
            if ((acc[i] ^ acc[r_cur]) == k) cnt -= 1;
        }
        while (r[i] < r_cur) {
            int i = -- r_cur;
            cnt -= accs[acc[i+1] ^ k];
            accs[acc[i]] -= 1;
        }
        ans[i] = cnt;
    }
    // output
    repeat (i,m) {
        cout << ans[i] << endl;
    }
    return 0;
}