Thanks for visiting codestin.com
Credit goes to hitonanode.github.io

cplib-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub hitonanode/cplib-cpp

✔ data_structure/test/lazy_rbst.test.cpp (verified)

Depends on: data_structure/lazy_rbst.hpp, modint.hpp

Code

#define PROBLEM "https://judge.yosupo.jp/problem/dynamic_sequence_range_affine_range_sum"
#include "../lazy_rbst.hpp"
#include "../../modint.hpp"
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;

using mint = ModInt<998244353>;

// Monoid value for range-affine / range-sum: the sum of a segment together with its length.
// The length is carried so that mapping() can add `b` once per element of the segment.
struct S {
    mint sum;
    int sz = 0; // default 0 so a default-constructed S (e.g. an unused pool node) is well defined
};
// Lazy tag (active, (a, b)) meaning x -> a * x + b; active == false is the identity tag.
using F = pair<bool, pair<mint, mint>>;
// Monoid product: concatenating two segments adds up their sums and their lengths.
S op(S l, S r) {
    l.sum += r.sum;
    l.sz += r.sz;
    return l;
}
// Applies the affine tag f = (active, (a, b)) to a segment: sum -> a * sum + b * length.
// An inactive tag leaves the segment untouched; the length never changes.
S mapping(F f, S x) {
    if (f.first) {
        const mint &mul = f.second.first;
        const mint &add = f.second.second;
        x.sum = x.sum * mul + add * x.sz;
    }
    return x;
}
// Reversing a segment keeps its sum and length, so the product passes through unchanged.
S reversal(S x) {
    const S unchanged = x;
    return unchanged;
}
// Composes two affine tags so that the result applies gold first and then fnew:
// x -> anew * (aold * x + bold) + bnew. An inactive tag acts as the identity on either side.
F composition(F fnew, F gold) {
    if (fnew.first && gold.first) {
        const mint a1 = fnew.second.first, b1 = fnew.second.second;
        const mint a0 = gold.second.first, b0 = gold.second.second;
        return F{true, {a1 * a0, a1 * b0 + b1}};
    }
    return fnew.first ? fnew : gold;
}
// Identity tag: the `false` flag marks it inactive; (1, 0) is the matching identity map x -> x.
F id() {
    return make_pair(false, make_pair(mint(1), mint(0)));
}

int main() {
    cin.tie(nullptr), ios::sync_with_stdio(false);
    int N, Q;
    cin >> N >> Q;
    vector<S> A(N);
    for (auto &x : A) cin >> x.sum, x.sz = 1;
    lazy_rbst<1000001, S, op, F, reversal, mapping, composition, id> rbst;

    auto root = rbst.new_tree();
    rbst.assign(root, A);
    while (Q--) {
        int tp;
        cin >> tp;
        if (tp == 0) {
            int i, x;
            cin >> i >> x;
            rbst.insert(root, i, S{x, 1});
            N++;
        } else if (tp == 1) {
            int i;
            cin >> i;
            rbst.erase(root, i);
            N--;
        } else if (tp == 2) {
            int l, r;
            cin >> l >> r;
            rbst.reverse(root, l, r);
        } else if (tp == 3) {
            int l, r, b, c;
            cin >> l >> r >> b >> c;
            rbst.apply(root, l, r, {true, {b, c}});
        } else if (tp == 4) {
            int l, r;
            cin >> l >> r;
            cout << rbst.prod(root, l, r).sum << '\n';
        }
    }
}
#line 1 "data_structure/test/lazy_rbst.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/dynamic_sequence_range_affine_range_sum"
#line 2 "data_structure/lazy_rbst.hpp"
#include <array>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Lazy randomized binary search tree
template <int LEN, class S, S (*op)(S, S), class F, S (*reversal)(S), S (*mapping)(F, S),
          F (*composition)(F, F), F (*id)()>
struct lazy_rbst {
    // Do your RuBeSTy! ⌒°( ・ω・)°⌒
    inline uint32_t _rand() { // XorShift
        static uint32_t x = 123456789, y = 362436069, z = 521288629, w = 88675123;
        uint32_t t = x ^ (x << 11);
        x = y;
        y = z;
        z = w;
        return w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));
    }

    struct Node {
        Node *l, *r;
        S val, sum;
        F lz;
        bool is_reversed;
        int sz;
        Node(const S &v)
            : l(nullptr), r(nullptr), val(v), sum(v), lz(id()), is_reversed(false), sz(1) {}
        Node() : l(nullptr), r(nullptr), lz(id()), is_reversed(false), sz(0) {}
        template <class OStream> friend OStream &operator<<(OStream &os, const Node &n) {
            os << '[';
            if (n.l) os << *(n.l) << ',';
            os << n.val << ',';
            if (n.r) os << *(n.r);
            return os << ']';
        }
    };
    using Nptr = Node *;
    std::array<Node, LEN> data;
    int d_ptr;

    int size(Nptr t) const { return t != nullptr ? t->sz : 0; }

    lazy_rbst() : d_ptr(0) {}

protected:
    Nptr update(Nptr t) {
        t->sz = 1;
        t->sum = t->val;
        if (t->l) {
            t->sz += t->l->sz;
            t->sum = op(t->l->sum, t->sum);
        }
        if (t->r) {
            t->sz += t->r->sz;
            t->sum = op(t->sum, t->r->sum);
        }
        return t;
    }

    void all_apply(Nptr t, F f) {
        t->val = mapping(f, t->val);
        t->sum = mapping(f, t->sum);
        t->lz = composition(f, t->lz);
    }
    void _toggle(Nptr t) {
        auto tmp = t->l;
        t->l = t->r, t->r = tmp;
        t->sum = reversal(t->sum);
        t->is_reversed ^= true;
    }

    void push(Nptr &t) {
        _duplicate_node(t);
        if (t->lz != id()) {
            if (t->l) {
                _duplicate_node(t->l);
                all_apply(t->l, t->lz);
            }
            if (t->r) {
                _duplicate_node(t->r);
                all_apply(t->r, t->lz);
            }
            t->lz = id();
        }
        if (t->is_reversed) {
            if (t->l) _toggle(t->l);
            if (t->r) _toggle(t->r);
            t->is_reversed = false;
        }
    }

    virtual void _duplicate_node(Nptr &) {}

    Nptr _make_node(const S &val) {
        if (d_ptr >= LEN) throw;
        return &(data[d_ptr++] = Node(val));
    }

public:
    Nptr new_tree() { return nullptr; } // 新たな木を作成

    int mem_used() const { return d_ptr; }
    bool empty(Nptr t) const { return t == nullptr; }

    // lとrをrootとする木同士を結合して,新たなrootを返す
    Nptr merge(Nptr l, Nptr r) {
        if (l == nullptr or r == nullptr) return l != nullptr ? l : r;
        if (_rand() % uint32_t(l->sz + r->sz) < uint32_t(l->sz)) {
            push(l);
            l->r = merge(l->r, r);
            return update(l);
        } else {
            push(r);
            r->l = merge(l, r->l);
            return update(r);
        }
    }

    // [0, k)の木と[k, root->size())の木に分けて各root
    // (部分木の要素数が0ならnullptr)を返す
    std::pair<Nptr, Nptr> split(Nptr &root, int k) { // rootの子孫からあとk個欲しい
        if (root == nullptr) return std::make_pair(nullptr, nullptr);
        push(root);
        if (k <= size(root->l)) { // leftからk個拾える
            auto p = split(root->l, k);
            root->l = p.second;
            return std::make_pair(p.first, update(root));
        } else {
            auto p = split(root->r, k - size(root->l) - 1);
            root->r = p.first;
            return std::make_pair(update(root), p.second);
        }
    }

    // 0-indexedでarray[pos]の手前に新たな要素 x を挿入する
    void insert(Nptr &root, int pos, const S &x) {
        auto p = split(root, pos);
        root = merge(p.first, merge(_make_node(x), p.second));
    }

    // 0-indexedでarray[pos]を削除する(先頭からpos+1個目の要素)
    void erase(Nptr &root, int pos) {
        auto p = split(root, pos);
        auto p2 = split(p.second, 1);
        root = merge(p.first, p2.second);
    }

    // 1点更新 array[pos].valにupdvalを入れる
    void set(Nptr &root, int pos, const S &x) {
        auto p = split(root, pos);
        auto p2 = split(p.second, 1);
        _duplicate_node(p2.first);
        *p2.first = Node(x);
        root = merge(p.first, merge(p2.first, p2.second));
    }

    // 遅延評価を利用した範囲更新 [l, r)
    void apply(Nptr &root, int l, int r, const F &f) {
        if (l == r) return;
        auto p = split(root, l);
        auto p2 = split(p.second, r - l);
        all_apply(p2.first, f);
        root = merge(p.first, merge(p2.first, p2.second));
    }

    S prod(Nptr &root, int l, int r) {
        assert(l < r);
        auto p = split(root, l);
        auto p2 = split(p.second, r - l);
        if (p2.first != nullptr) push(p2.first);
        S res = p2.first->sum;
        root = merge(p.first, merge(p2.first, p2.second));
        return res;
    }

    // array[pos].valを取得する
    S get(Nptr &root, int pos) { return prod(root, pos, pos + 1); }

    template <bool (*g)(S)> int max_right(Nptr root, const S &e) {
        return max_right(root, e, [](S x) { return g(x); });
    }
    template <class G> int max_right(Nptr root, const S &e, G g) {
        assert(g(e));
        if (root == nullptr) return 0;
        push(root);
        Nptr now = root;
        S prod_now = e;
        int sz = 0;
        while (true) {
            if (now->l != nullptr) {
                push(now->l);
                auto pl = op(prod_now, now->l->sum);
                if (g(pl)) {
                    prod_now = pl;
                    sz += now->l->sz;
                } else {
                    now = now->l;
                    continue;
                }
            }
            auto pl = op(prod_now, now->val);
            if (!g(pl)) return sz;
            prod_now = pl, sz++;
            if (now->r == nullptr) return sz;
            push(now->r);
            now = now->r;
        }
    }

    template <bool (*g)(S)> int min_left(Nptr root, const S &e) {
        return min_left(root, e, [](S x) { return g(x); });
    }
    template <class G> int min_left(Nptr root, const S &e, G g) {
        assert(g(e));
        if (root == nullptr) return 0;
        push(root);
        Nptr now = root;
        S prod_now = e;
        int sz = size(root);
        while (true) {
            if (now->r != nullptr) {
                push(now->r);
                auto pr = op(now->r->sum, prod_now);
                if (g(pr)) {
                    prod_now = pr;
                    sz -= now->r->sz;
                } else {
                    now = now->r;
                    continue;
                }
            }
            auto pr = op(now->val, prod_now);
            if (!g(pr)) return sz;
            prod_now = pr, sz--;
            if (now->l == nullptr) return sz;
            push(now->l);
            now = now->l;
        }
    }

    void reverse(Nptr &root) { _duplicate_node(root), _toggle(root); }
    void reverse(Nptr &root, int l, int r) {
        auto p2 = split(root, r);
        auto p1 = split(p2.first, l);
        reverse(p1.second);
        root = merge(merge(p1.first, p1.second), p2.second);
    }

    // データを壊して新規にinitの内容を詰める
    void assign(Nptr &root, const std::vector<S> &init) {
        int N = init.size();
        root = N ? _assign_range(0, N, init) : new_tree();
    }
    Nptr _assign_range(int l, int r, const std::vector<S> &init) {
        if (r - l == 1) {
            Nptr t = _make_node(init[l]);
            return update(t);
        }
        return merge(_assign_range(l, (l + r) / 2, init), _assign_range((l + r) / 2, r, init));
    }

    // データをvecへ書き出し
    void dump(Nptr &t, std::vector<S> &vec) {
        if (t == nullptr) return;
        push(t);
        dump(t->l, vec);
        vec.push_back(t->val);
        dump(t->r, vec);
    }

    // gc
    void re_alloc(Nptr &root) {
        std::vector<S> mem;
        dump(root, mem);
        d_ptr = 0;
        assign(root, mem);
    }
};

// Persistent lazy randomized binary search tree: overrides the _duplicate_node() hook so
// that nodes are copied into the pool before being modified, intended to keep trees obtained
// earlier unchanged by later operations.
// Verified: https://atcoder.jp/contests/arc030/tasks/arc030_4
// CAUTION: https://yosupo.hatenablog.com/entry/2015/10/29/222536
template <int LEN, class S, S (*op)(S, S), class F, S (*reversal)(S), S (*mapping)(F, S),
          F (*composition)(F, F), F (*id)()>
struct persistent_lazy_rbst : lazy_rbst<LEN, S, op, F, reversal, mapping, composition, id> {
    using RBST = lazy_rbst<LEN, S, op, F, reversal, mapping, composition, id>;
    using Node = typename RBST::Node;
    using Nptr = typename RBST::Nptr;
    persistent_lazy_rbst() : RBST() {}

protected:
    // Replace t by a fresh copy of *t taken from the pool (nullptr is left as is).
    // Throws std::length_error when the pool is exhausted.
    void _duplicate_node(Nptr &t) override {
        if (t == nullptr) return;
        if (RBST::d_ptr >= LEN) throw std::length_error("persistent_lazy_rbst: node pool exhausted");
        t = &(RBST::data[RBST::d_ptr++] = *t);
    }

public:
    // Overwrite positions [target_l, target_l + d) with the values found at [l, l + d).
    // The copied slice shares nodes with its source, which persistence makes safe.
    void copy(Nptr &root, int l, int d, int target_l) {
        auto p1 = RBST::split(root, l);
        auto p2 = RBST::split(p1.second, d);
        root = RBST::merge(p1.first, RBST::merge(p2.first, p2.second));
        auto p3 = RBST::split(root, target_l);
        auto p4 = RBST::split(p3.second, d);
        root = RBST::merge(p3.first, RBST::merge(p2.first, p4.second));
    }
};
#line 3 "modint.hpp"
#include <iostream>
#include <set>
#line 6 "modint.hpp"

// Integer modulo the compile-time constant md (md > 1), stored normalized in [0, md).
// Division, inv(), facinv() and friends rely on Fermat's little theorem (pow(md - 2)), so
// they are meaningful only when md is prime.
template <int md> struct ModInt {
    static_assert(md > 1);
    using lint = long long;
    constexpr static int mod() { return md; }
    // Smallest g in [1, md) with g^((md-1)/p) != 1 for every prime p dividing md - 1
    // (a primitive root when md is prime), or -1 if there is none. md - 1 is factored by
    // trial division; the result is computed once and cached in a function-local static.
    static int get_primitive_root() {
        static int primitive_root = 0;
        if (!primitive_root) {
            primitive_root = [&]() {
                std::set<int> fac;
                int v = md - 1;
                for (lint i = 2; i * i <= v; i++)
                    while (v % i == 0) fac.insert(i), v /= i;
                if (v > 1) fac.insert(v);
                for (int g = 1; g < md; g++) {
                    bool ok = true;
                    for (auto i : fac)
                        if (ModInt(g).pow((md - 1) / i) == 1) {
                            ok = false;
                            break;
                        }
                    if (ok) return g;
                }
                return -1;
            }();
        }
        return primitive_root;
    }
    int val_; // representative in [0, md)
    int val() const noexcept { return val_; }
    constexpr ModInt() : val_(0) {}
    // Stores v after at most one subtraction of md; callers must pass 0 <= v < 2 * md.
    constexpr ModInt &_setval(lint v) { return val_ = (v >= md ? v - md : v), *this; }
    // Accepts any long long, negatives included: v % md + md always lies in (0, 2 * md).
    constexpr ModInt(lint v) { _setval(v % md + md); }
    constexpr explicit operator bool() const { return val_ != 0; }
    constexpr ModInt operator+(const ModInt &x) const {
        return ModInt()._setval((lint)val_ + x.val_);
    }
    constexpr ModInt operator-(const ModInt &x) const {
        return ModInt()._setval((lint)val_ - x.val_ + md);
    }
    constexpr ModInt operator*(const ModInt &x) const {
        return ModInt()._setval((lint)val_ * x.val_ % md);
    }
    constexpr ModInt operator/(const ModInt &x) const {
        return ModInt()._setval((lint)val_ * x.inv().val() % md);
    }
    constexpr ModInt operator-() const { return ModInt()._setval(md - val_); }
    constexpr ModInt &operator+=(const ModInt &x) { return *this = *this + x; }
    constexpr ModInt &operator-=(const ModInt &x) { return *this = *this - x; }
    constexpr ModInt &operator*=(const ModInt &x) { return *this = *this * x; }
    constexpr ModInt &operator/=(const ModInt &x) { return *this = *this / x; }
    friend constexpr ModInt operator+(lint a, const ModInt &x) { return ModInt(a) + x; }
    friend constexpr ModInt operator-(lint a, const ModInt &x) { return ModInt(a) - x; }
    friend constexpr ModInt operator*(lint a, const ModInt &x) { return ModInt(a) * x; }
    friend constexpr ModInt operator/(lint a, const ModInt &x) { return ModInt(a) / x; }
    constexpr bool operator==(const ModInt &x) const { return val_ == x.val_; }
    constexpr bool operator!=(const ModInt &x) const { return val_ != x.val_; }
    constexpr bool operator<(const ModInt &x) const {
        return val_ < x.val_;
    } // To use std::map<ModInt, T>
    // Reads a long long and reduces it modulo md.
    friend std::istream &operator>>(std::istream &is, ModInt &x) {
        lint t;
        return is >> t, x = ModInt(t), is;
    }
    constexpr friend std::ostream &operator<<(std::ostream &os, const ModInt &x) {
        return os << x.val_;
    }

    // Binary exponentiation: (*this)^n. n must be non-negative.
    constexpr ModInt pow(lint n) const {
        ModInt ans = 1, tmp = *this;
        while (n) {
            if (n & 1) ans *= tmp;
            tmp *= tmp, n >>= 1;
        }
        return ans;
    }

    // Values below this bound get their inverse from the cached tables in inv().
    static constexpr int cache_limit = std::min(md, 1 << 21);
    // facs[i] = i!, facinvs[i] = 1 / i!, invs[i] = 1 / i (invs[0] is a 0 placeholder).
    static std::vector<ModInt> facs, facinvs, invs;

    // Grow the three tables to length min(N, md); no-op if they are already that long.
    // Requires the tables to be non-empty, since facs[l0 - 1] is read.
    constexpr static void _precalculation(int N) {
        const int l0 = facs.size();
        if (N > md) N = md;
        if (N <= l0) return;
        facs.resize(N), facinvs.resize(N), invs.resize(N);
        for (int i = l0; i < N; i++) facs[i] = facs[i - 1] * i;
        facinvs[N - 1] = facs.back().pow(md - 2);
        for (int i = N - 2; i >= l0; i--) facinvs[i] = facinvs[i + 1] * (i + 1);
        for (int i = N - 1; i >= l0; i--) invs[i] = facinvs[i] * facs[i - 1];
    }

    // Multiplicative inverse (md assumed prime). Values below cache_limit use the tables,
    // grown by doubling as needed; larger values use Fermat. The inverse of 0 comes out as 0.
    constexpr ModInt inv() const {
        if (this->val_ < cache_limit) {
            // presumably guards use before the out-of-class table definitions are initialized
            if (facs.empty()) facs = {1}, facinvs = {1}, invs = {0};
            while (this->val_ >= int(facs.size())) _precalculation(facs.size() * 2);
            return invs[this->val_];
        } else {
            return this->pow(md - 2);
        }
    }

    // n! modulo md (0 once n >= md, since the product then contains md).
    constexpr static ModInt fac(int n) {
        assert(n >= 0);
        if (n >= md) return ModInt(0);
        while (n >= int(facs.size())) _precalculation(facs.size() * 2);
        return facs[n];
    }

    // 1 / n! modulo md; returns 0 for n >= md (where n! is not invertible).
    constexpr static ModInt facinv(int n) {
        assert(n >= 0);
        if (n >= md) return ModInt(0);
        while (n >= int(facs.size())) _precalculation(facs.size() * 2);
        return facinvs[n];
    }

    // Double factorial n!! = n * (n - 2) * ...: (2k)! / (2^k k!) for odd n = 2k - 1 and
    // 2^k k! for even n = 2k, where k = (n + 1) / 2.
    constexpr static ModInt doublefac(int n) {
        assert(n >= 0);
        if (n >= md) return ModInt(0);
        long long k = (n + 1) / 2;
        return (n & 1) ? ModInt::fac(k * 2) / (ModInt(2).pow(k) * ModInt::fac(k))
                       : ModInt::fac(k) * ModInt(2).pow(k);
    }

    // Binomial coefficient C(n, r) from the factorial tables (0 when r < 0 or r > n).
    constexpr static ModInt nCr(int n, int r) {
        assert(n >= 0);
        if (r < 0 or n < r) return ModInt(0);
        return ModInt::fac(n) * ModInt::facinv(r) * ModInt::facinv(n - r);
    }

    // Number of r-permutations n! / (n - r)! (0 when r < 0 or r > n).
    constexpr static ModInt nPr(int n, int r) {
        assert(n >= 0);
        if (r < 0 or n < r) return ModInt(0);
        return ModInt::fac(n) * ModInt::facinv(n - r);
    }

    // C(n, r) that avoids building tables up to n for one-off large n: it multiplies
    // n * (n - 1) * ... directly (O(min(r, n - r))) until the accumulated brute-force work
    // exceeds n, after which it falls back to nCr (which builds the tables).
    static ModInt binom(int n, int r) {
        static long long bruteforce_times = 0;

        if (r < 0 or n < r) return ModInt(0);
        if (n <= bruteforce_times or n < (int)facs.size()) return ModInt::nCr(n, r);

        r = std::min(r, n - r);

        ModInt ret = ModInt::facinv(r);
        for (int i = 0; i < r; ++i) ret *= n - i;
        bruteforce_times += r;

        return ret;
    }

    // Multinomial coefficient, (k_1 + k_2 + ... + k_m)! / (k_1! k_2! ... k_m!)
    // Complexity: O(sum(ks))
    template <class Vec> static ModInt multinomial(const Vec &ks) {
        ModInt ret{1};
        int sum = 0;
        for (int k : ks) {
            assert(k >= 0);
            ret *= ModInt::facinv(k), sum += k;
        }
        return ret * ModInt::fac(sum);
    }
    // Variadic form of the multinomial coefficient above.
    template <class... Args> static ModInt multinomial(Args... args) {
        int sum = (0 + ... + args);
        ModInt result = (1 * ... * ModInt::facinv(args));
        return ModInt::fac(sum) * result;
    }

    // Catalan number, C_n = binom(2n, n) / (n + 1) = # of Dyck words of length 2n
    // C_0 = 1, C_1 = 1, C_2 = 2, C_3 = 5, C_4 = 14, ...
    // https://oeis.org/A000108
    // Complexity: O(n)
    static ModInt catalan(int n) {
        if (n < 0) return ModInt(0);
        return ModInt::fac(n * 2) * ModInt::facinv(n + 1) * ModInt::facinv(n);
    }

    // Square root by Tonelli–Shanks (md assumed prime; md == 2 is special-cased).
    // Returns the smaller of the two roots, or 0 when *this is 0 or a quadratic non-residue
    // (checked with Euler's criterion).
    ModInt sqrt() const {
        if (val_ == 0) return 0;
        if (md == 2) return val_;
        if (pow((md - 1) / 2) != 1) return 0;
        ModInt b = 1;
        while (b.pow((md - 1) / 2) == 1) b += 1;
        int e = 0, m = md - 1;
        while (m % 2 == 0) m >>= 1, e++;
        ModInt x = pow((m - 1) / 2), y = (*this) * x * x;
        x *= (*this);
        ModInt z = b.pow(m);
        while (y != 1) {
            int j = 0;
            ModInt t = y;
            while (t != 1) j++, t *= t;
            z = z.pow(1LL << (e - j - 1));
            x *= z, z *= z, y *= z;
            e = j;
        }
        return ModInt(std::min(x.val_, md - x.val_));
    }
};
// Out-of-class definitions of the factorial caches, seeded with 0! = 1, 1/0! = 1 and a 0
// placeholder for the (nonexistent) inverse of 0; _precalculation() extends them on demand.
template <int md> std::vector<ModInt<md>> ModInt<md>::facs{ModInt<md>(1)};
template <int md> std::vector<ModInt<md>> ModInt<md>::facinvs{ModInt<md>(1)};
template <int md> std::vector<ModInt<md>> ModInt<md>::invs{ModInt<md>(0)};

using ModInt998244353 = ModInt<998244353>;
// using mint = ModInt<998244353>;
// using mint = ModInt<1000000007>;
#line 4 "data_structure/test/lazy_rbst.test.cpp"
#include <algorithm>
#line 8 "data_structure/test/lazy_rbst.test.cpp"
using namespace std;

using mint = ModInt<998244353>;

// Monoid value for range-affine / range-sum: the sum of a segment together with its length.
// The length is carried so that mapping() can add `b` once per element of the segment.
struct S {
    mint sum;
    int sz = 0; // default 0 so a default-constructed S (e.g. an unused pool node) is well defined
};
// Lazy tag (active, (a, b)) meaning x -> a * x + b; active == false is the identity tag.
using F = pair<bool, pair<mint, mint>>;
// Monoid product: concatenating two segments adds up their sums and their lengths.
S op(S l, S r) {
    l.sum += r.sum;
    l.sz += r.sz;
    return l;
}
// Applies the affine tag f = (active, (a, b)) to a segment: sum -> a * sum + b * length.
// An inactive tag leaves the segment untouched; the length never changes.
S mapping(F f, S x) {
    if (f.first) {
        const mint &mul = f.second.first;
        const mint &add = f.second.second;
        x.sum = x.sum * mul + add * x.sz;
    }
    return x;
}
// Reversing a segment keeps its sum and length, so the product passes through unchanged.
S reversal(S x) {
    const S unchanged = x;
    return unchanged;
}
// Composes two affine tags so that the result applies gold first and then fnew:
// x -> anew * (aold * x + bold) + bnew. An inactive tag acts as the identity on either side.
F composition(F fnew, F gold) {
    if (fnew.first && gold.first) {
        const mint a1 = fnew.second.first, b1 = fnew.second.second;
        const mint a0 = gold.second.first, b0 = gold.second.second;
        return F{true, {a1 * a0, a1 * b0 + b1}};
    }
    return fnew.first ? fnew : gold;
}
// Identity tag: the `false` flag marks it inactive; (1, 0) is the matching identity map x -> x.
F id() {
    return make_pair(false, make_pair(mint(1), mint(0)));
}

int main() {
    cin.tie(nullptr), ios::sync_with_stdio(false);
    int N, Q;
    cin >> N >> Q;
    vector<S> A(N);
    for (auto &x : A) cin >> x.sum, x.sz = 1;
    lazy_rbst<1000001, S, op, F, reversal, mapping, composition, id> rbst;

    auto root = rbst.new_tree();
    rbst.assign(root, A);
    while (Q--) {
        int tp;
        cin >> tp;
        if (tp == 0) {
            int i, x;
            cin >> i >> x;
            rbst.insert(root, i, S{x, 1});
            N++;
        } else if (tp == 1) {
            int i;
            cin >> i;
            rbst.erase(root, i);
            N--;
        } else if (tp == 2) {
            int l, r;
            cin >> l >> r;
            rbst.reverse(root, l, r);
        } else if (tp == 3) {
            int l, r, b, c;
            cin >> l >> r >> b >> c;
            rbst.apply(root, l, r, {true, {b, c}});
        } else if (tp == 4) {
            int l, r;
            cin >> l >> r;
            cout << rbst.prod(root, l, r).sum << '\n';
        }
    }
}
Back to top page