[Rust] Building a dual segment tree based on ACL's lazy segment tree

ACL (AtCoder Library) provides a segment tree and a lazy segment tree, but no dual segment tree (as of May 2024).

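A lazy segment tree supports range application (applying a map to every element of a range) and range queries, whereas a dual segment tree supports range application and point queries. Unlike a lazy segment tree, a dual segment tree cannot answer range queries. On the other hand, it has the following advantages: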

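  • The monoid you plug in is simpler to define.
    • A lazy segment tree needs a monoid for both the maps and the data, whereas a dual segment tree only needs one for the maps.
  • A dual segment tree is faster than a lazy segment tree by a constant factor, since it never has to recompute aggregated node values after an update.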
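To illustrate the first point, here is a minimal sketch of a map monoid written against a trait with the same shape as the one in this article (`F`, `S`, `identity_map`, `mapping`, `composition`). The `RangeAssign` type and its "overwrite with a value" semantics are a hypothetical example, not something taken from ACL. Note that `S` is a plain `i64` with no binary operation and no identity element; only the maps form a monoid.

```rust
// Same shape as the article's `MapMonoid`: only the maps F need a monoid.
pub trait MapMonoid {
    type F: Clone;
    type S: Clone;
    fn identity_map() -> Self::F;
    fn mapping(f: &Self::F, x: &Self::S) -> Self::S;
    fn composition(f: &Self::F, g: &Self::F) -> Self::F;
}

/// Hypothetical example: "range assign". A map is `Some(v)` (overwrite with v)
/// or `None` (identity). The data type S is a plain value with no operation.
struct RangeAssign;

impl MapMonoid for RangeAssign {
    type F = Option<i64>;
    type S = i64;

    fn identity_map() -> Self::F {
        None
    }

    fn mapping(f: &Self::F, x: &Self::S) -> Self::S {
        // Overwrite if f carries a value; otherwise keep x unchanged.
        f.unwrap_or(*x)
    }

    fn composition(f: &Self::F, g: &Self::F) -> Self::F {
        // f is applied after g, so the newer assignment f wins if present.
        f.or(*g)
    }
}

fn main() {
    let (f, g, x) = (Some(7), Some(3), 1);
    // composition(f, g) must act like "apply g, then apply f".
    let composed = RangeAssign::composition(&f, &g);
    let direct = RangeAssign::mapping(&f, &RangeAssign::mapping(&g, &x));
    assert_eq!(RangeAssign::mapping(&composed, &x), direct);
    println!("{}", direct); // 7
}
```

The sketch assumes the same convention as ACL, where `composition(f, g)` means f∘g (apply `g` first, then `f`); this matches `self.lz[k] = F::composition(&f, &self.lz[k])` in the code, where `f` is the newer map.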

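So when point queries are all you need, a dual segment tree is the better fit. I therefore built a dual segment tree based on ACL's lazy segment tree, mostly by stripping features out of the lazy segment tree.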

Diff between the ACL lazy segment tree code and my dual segment tree code

(The full code of the dual segment tree is in the second half of this article.)

The diff shows roughly how the dual segment tree was derived. In summary:

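  • The data being acted upon no longer needs a monoid (a monoid is needed only for the maps).
  • Data is stored only at the leaves of the tree (the lazy segment tree stores data at every node).
    • The data array therefore has length size instead of the lazy segment tree's 2 * size.
  • The range-query functions (prod, all_prod, max_right, min_left) were removed.
  • (and so on)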
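The index arithmetic behind the leaf-only storage can be checked in isolation. The sketch below reuses the article's `ceil_pow2`; `push_path` and `target_of` are hypothetical helper names that mirror, respectively, the `(p + self.size) >> i` loop in `set`/`get`/`apply` and the branch in `all_apply`.

```rust
/// Same formula as the article's `ceil_pow2`: the smallest `log` with `1 << log >= n`.
fn ceil_pow2(n: u32) -> u32 {
    32 - n.saturating_sub(1).leading_zeros()
}

/// Ancestors of leaf `p` whose lazy maps are pushed before touching `d[p]`,
/// from the root down: `(p + size) >> i` for `i = log, ..., 1`.
fn push_path(p: usize, log: usize) -> Vec<usize> {
    let size = 1usize << log;
    (1..=log).rev().map(|i| (p + size) >> i).collect()
}

/// Where `all_apply(k, f)` puts its effect: internal nodes (`k < size`) compose
/// `f` into `lz[k]`, while leaf node `k` maps the stored value `d[k - size]`.
#[derive(Debug, PartialEq)]
enum Target {
    Lazy(usize),
    Data(usize),
}

fn target_of(k: usize, size: usize) -> Target {
    if k < size {
        Target::Lazy(k)
    } else {
        Target::Data(k - size)
    }
}

fn main() {
    let n: usize = 10;
    let log = ceil_pow2(n as u32) as usize;
    let size = 1usize << log;
    // A lazy segtree would allocate 2 * size = 32 values; here `d` holds only `size` = 16.
    println!("log = {}, size = {}", log, size); // log = 4, size = 16
    println!("{:?}", push_path(5, log)); // [1, 2, 5, 10]
    println!("{:?}", target_of(5 + size, size)); // Data(5)
}
```

Because leaf node `p + size` is stored at `d[p]`, the public methods can index `d` with `p` directly and only add `size` when walking the ancestors.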

Here is the actual diff:

--- acl_lazy_segtree.rs  2024-05-05 22:01:52.807678542 +0900
+++ dual_segtree.rs   2024-05-05 22:27:35.097668299 +0900
@@ -1,148 +1,78 @@
-use crate::internal_bit::ceil_pow2;
-use crate::Monoid;
+use std::ops::{Bound, RangeBounds};
+
+fn ceil_pow2(n: u32) -> u32 {
+    32 - n.saturating_sub(1).leading_zeros()
+}
 
 pub trait MapMonoid {
-    type M: Monoid;
     type F: Clone;
-    // type S = <Self::M as Monoid>::S;
-    fn identity_element() -> <Self::M as Monoid>::S {
-        Self::M::identity()
-    }
-    fn binary_operation(
-        a: &<Self::M as Monoid>::S,
-        b: &<Self::M as Monoid>::S,
-    ) -> <Self::M as Monoid>::S {
-        Self::M::binary_operation(a, b)
-    }
+    type S: Clone;
     fn identity_map() -> Self::F;
-    fn mapping(f: &Self::F, x: &<Self::M as Monoid>::S) -> <Self::M as Monoid>::S;
+    fn mapping(f: &Self::F, x: &Self::S) -> Self::S;
     fn composition(f: &Self::F, g: &Self::F) -> Self::F;
 }
 
-impl<F: MapMonoid> Default for LazySegtree<F> {
+impl<F: MapMonoid> Default for DualSegtree<F>
+where
+    F::S: Default,
+{
     fn default() -> Self {
         Self::new(0)
     }
 }
-impl<F: MapMonoid> LazySegtree<F> {
-    pub fn new(n: usize) -> Self {
-        vec![F::identity_element(); n].into()
+impl<F: MapMonoid> DualSegtree<F> {
+    pub fn new(n: usize) -> Self
+    where
+        F::S: Default,
+    {
+        vec![F::S::default(); n].into()
     }
 }
-impl<F: MapMonoid> From<Vec<<F::M as Monoid>::S>> for LazySegtree<F> {
-    fn from(v: Vec<<F::M as Monoid>::S>) -> Self {
+
+impl<F: MapMonoid> From<Vec<F::S>> for DualSegtree<F>
+where
+    F::S: Default,
+{
+    fn from(v: Vec<F::S>) -> Self {
         let n = v.len();
         let log = ceil_pow2(n as u32) as usize;
         let size = 1 << log;
-        let mut d = vec![F::identity_element(); 2 * size];
+        let mut d = vec![F::S::default(); size];
         let lz = vec![F::identity_map(); size];
-        d[size..(size + n)].clone_from_slice(&v);
-        let mut ret = LazySegtree {
+        d[..n].clone_from_slice(&v);
+        DualSegtree {
             n,
             size,
             log,
             d,
             lz,
-        };
-        for i in (1..size).rev() {
-            ret.update(i);
         }
-        ret
     }
 }
 
-impl<F: MapMonoid> LazySegtree<F> {
-    pub fn set(&mut self, mut p: usize, x: <F::M as Monoid>::S) {
+impl<F: MapMonoid> DualSegtree<F> {
+    pub fn set(&mut self, p: usize, x: F::S) {
         assert!(p < self.n);
-        p += self.size;
         for i in (1..=self.log).rev() {
-            self.push(p >> i);
+            self.push((p + self.size) >> i);
         }
         self.d[p] = x;
-        for i in 1..=self.log {
-            self.update(p >> i);
-        }
     }
 
-    pub fn get(&mut self, mut p: usize) -> <F::M as Monoid>::S {
+    pub fn get(&mut self, p: usize) -> F::S {
         assert!(p < self.n);
-        p += self.size;
         for i in (1..=self.log).rev() {
-            self.push(p >> i);
+            self.push((p + self.size) >> i);
         }
         self.d[p].clone()
     }
 
-    pub fn prod<R>(&mut self, range: R) -> <F::M as Monoid>::S
-    where
-        R: RangeBounds<usize>,
-    {
-        // Trivial optimization
-        if range.start_bound() == Bound::Unbounded && range.end_bound() == Bound::Unbounded {
-            return self.all_prod();
-        }
-
-        let mut r = match range.end_bound() {
-            Bound::Included(r) => r + 1,
-            Bound::Excluded(r) => *r,
-            Bound::Unbounded => self.n,
-        };
-        let mut l = match range.start_bound() {
-            Bound::Included(l) => *l,
-            Bound::Excluded(l) => l + 1,
-            // TODO: There are another way of optimizing [0..r)
-            Bound::Unbounded => 0,
-        };
-
-        assert!(l <= r && r <= self.n);
-        if l == r {
-            return F::identity_element();
-        }
-
-        l += self.size;
-        r += self.size;
-
-        for i in (1..=self.log).rev() {
-            if ((l >> i) << i) != l {
-                self.push(l >> i);
-            }
-            if ((r >> i) << i) != r {
-                self.push(r >> i);
-            }
-        }
-
-        let mut sml = F::identity_element();
-        let mut smr = F::identity_element();
-        while l < r {
-            if l & 1 != 0 {
-                sml = F::binary_operation(&sml, &self.d[l]);
-                l += 1;
-            }
-            if r & 1 != 0 {
-                r -= 1;
-                smr = F::binary_operation(&self.d[r], &smr);
-            }
-            l >>= 1;
-            r >>= 1;
-        }
-
-        F::binary_operation(&sml, &smr)
-    }
-
-    pub fn all_prod(&self) -> <F::M as Monoid>::S {
-        self.d[1].clone()
-    }
-
-    pub fn apply(&mut self, mut p: usize, f: F::F) {
+    pub fn apply(&mut self, p: usize, f: F::F) {
         assert!(p < self.n);
-        p += self.size;
         for i in (1..=self.log).rev() {
-            self.push(p >> i);
+            self.push((p + self.size) >> i);
         }
         self.d[p] = F::mapping(&f, &self.d[p]);
-        for i in 1..=self.log {
-            self.update(p >> i);
-        }
     }
     pub fn apply_range<R>(&mut self, range: R, f: F::F)
     where
@@ -178,8 +108,6 @@
         }
 
         {
-            let l2 = l;
-            let r2 = r;
             while l < r {
                 if l & 1 != 0 {
                     self.all_apply(l, f.clone());
@@ -192,126 +120,29 @@
                 l >>= 1;
                 r >>= 1;
             }
-            l = l2;
-            r = r2;
-        }
-
-        for i in 1..=self.log {
-            if ((l >> i) << i) != l {
-                self.update(l >> i);
-            }
-            if ((r >> i) << i) != r {
-                self.update((r - 1) >> i);
-            }
         }
     }
-
-    pub fn max_right<G>(&mut self, mut l: usize, g: G) -> usize
-    where
-        G: Fn(<F::M as Monoid>::S) -> bool,
-    {
-        assert!(l <= self.n);
-        assert!(g(F::identity_element()));
-        if l == self.n {
-            return self.n;
-        }
-        l += self.size;
-        for i in (1..=self.log).rev() {
-            self.push(l >> i);
-        }
-        let mut sm = F::identity_element();
-        while {
-            // do
-            while l % 2 == 0 {
-                l >>= 1;
-            }
-            if !g(F::binary_operation(&sm, &self.d[l])) {
-                while l < self.size {
-                    self.push(l);
-                    l *= 2;
-                    let res = F::binary_operation(&sm, &self.d[l]);
-                    if g(res.clone()) {
-                        sm = res;
-                        l += 1;
-                    }
-                }
-                return l - self.size;
-            }
-            sm = F::binary_operation(&sm, &self.d[l]);
-            l += 1;
-            //while
-            {
-                let l = l as isize;
-                (l & -l) != l
-            }
-        } {}
-        self.n
-    }
-
-    pub fn min_left<G>(&mut self, mut r: usize, g: G) -> usize
-    where
-        G: Fn(<F::M as Monoid>::S) -> bool,
-    {
-        assert!(r <= self.n);
-        assert!(g(F::identity_element()));
-        if r == 0 {
-            return 0;
-        }
-        r += self.size;
-        for i in (1..=self.log).rev() {
-            self.push((r - 1) >> i);
-        }
-        let mut sm = F::identity_element();
-        while {
-            // do
-            r -= 1;
-            while r > 1 && r % 2 != 0 {
-                r >>= 1;
-            }
-            if !g(F::binary_operation(&self.d[r], &sm)) {
-                while r < self.size {
-                    self.push(r);
-                    r = 2 * r + 1;
-                    let res = F::binary_operation(&self.d[r], &sm);
-                    if g(res.clone()) {
-                        sm = res;
-                        r -= 1;
-                    }
-                }
-                return r + 1 - self.size;
-            }
-            sm = F::binary_operation(&self.d[r], &sm);
-            // while
-            {
-                let r = r as isize;
-                (r & -r) != r
-            }
-        } {}
-        0
-    }
 }
 
-pub struct LazySegtree<F>
+pub struct DualSegtree<F>
 where
     F: MapMonoid,
 {
     n: usize,
     size: usize,
     log: usize,
-    d: Vec<<F::M as Monoid>::S>,
+    d: Vec<F::S>,
     lz: Vec<F::F>,
 }
-impl<F> LazySegtree<F>
+impl<F> DualSegtree<F>
 where
     F: MapMonoid,
 {
-    fn update(&mut self, k: usize) {
-        self.d[k] = F::binary_operation(&self.d[2 * k], &self.d[2 * k + 1]);
-    }
     fn all_apply(&mut self, k: usize, f: F::F) {
-        self.d[k] = F::mapping(&f, &self.d[k]);
         if k < self.size {
             self.lz[k] = F::composition(&f, &self.lz[k]);
+        } else {
+            self.d[k - self.size] = F::mapping(&f, &self.d[k - self.size]);
         }
     }
     fn push(&mut self, k: usize) {
@@ -321,45 +152,16 @@
     }
 }
 
-// TODO is it useful?
-use std::{
-    fmt::{Debug, Error, Formatter, Write},
-    ops::{Bound, RangeBounds},
-};
-impl<F> Debug for LazySegtree<F>
-where
-    F: MapMonoid,
-    F::F: Debug,
-    <F::M as Monoid>::S: Debug,
-{
-    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
-        for i in 0..self.log {
-            for j in 0..1 << i {
-                f.write_fmt(format_args!(
-                    "{:?}[{:?}]\t",
-                    self.d[(1 << i) + j],
-                    self.lz[(1 << i) + j]
-                ))?;
-            }
-            f.write_char('\n')?;
-        }
-        for i in 0..self.size {
-            f.write_fmt(format_args!("{:?}\t", self.d[self.size + i]))?;
-        }
-        Ok(())
-    }
-}
-
 #[cfg(test)]
 mod tests {
-    use std::ops::{Bound::*, RangeBounds};
+    use std::convert::Infallible;
 
-    use crate::{LazySegtree, MapMonoid, Max};
+    use super::{DualSegtree, MapMonoid};
 
-    struct MaxAdd;
-    impl MapMonoid for MaxAdd {
-        type M = Max<i32>;
+    struct RangeAdd(Infallible);
+    impl MapMonoid for RangeAdd {
         type F = i32;
+        type S = i32;
 
         fn identity_map() -> Self::F {
             0
@@ -375,14 +177,15 @@
     }
 
     #[test]
-    fn test_max_add_lazy_segtree() {
+    fn test_range_add_dual_segtree() {
         let base = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3];
         let n = base.len();
-        let mut segtree: LazySegtree<MaxAdd> = base.clone().into();
+        let mut segtree: DualSegtree<RangeAdd> = base.clone().into();
         check_segtree(&base, &mut segtree);
 
-        let mut segtree = LazySegtree::<MaxAdd>::new(n);
         let mut internal = vec![i32::min_value(); n];
+        let mut segtree = DualSegtree::<RangeAdd>::from(internal.clone());
+
         for i in 0..n {
             segtree.set(i, base[i]);
             internal[i] = base[i];
@@ -411,69 +214,11 @@
     }
 
     //noinspection DuplicatedCode
-    fn check_segtree(base: &[i32], segtree: &mut LazySegtree<MaxAdd>) {
+    fn check_segtree(base: &[i32], segtree: &mut DualSegtree<RangeAdd>) {
         let n = base.len();
         #[allow(clippy::needless_range_loop)]
         for i in 0..n {
             assert_eq!(segtree.get(i), base[i]);
         }
-
-        check(base, segtree, ..);
-        for i in 0..=n {
-            check(base, segtree, ..i);
-            check(base, segtree, i..);
-            if i < n {
-                check(base, segtree, ..=i);
-            }
-            for j in i..=n {
-                check(base, segtree, i..j);
-                if j < n {
-                    check(base, segtree, i..=j);
-                    check(base, segtree, (Excluded(i), Included(j)));
-                }
-            }
-        }
-        assert_eq!(
-            segtree.all_prod(),
-            base.iter().max().copied().unwrap_or(i32::min_value())
-        );
-        for k in 0..=10 {
-            let f = |x| x < k;
-            for i in 0..=n {
-                assert_eq!(
-                    Some(segtree.max_right(i, f)),
-                    (i..=n)
-                        .filter(|&j| f(base[i..j]
-                            .iter()
-                            .max()
-                            .copied()
-                            .unwrap_or(i32::min_value())))
-                        .max()
-                );
-            }
-            for j in 0..=n {
-                assert_eq!(
-                    Some(segtree.min_left(j, f)),
-                    (0..=j)
-                        .filter(|&i| f(base[i..j]
-                            .iter()
-                            .max()
-                            .copied()
-                            .unwrap_or(i32::min_value())))
-                        .min()
-                );
-            }
-        }
-    }
-
-    fn check(base: &[i32], segtree: &mut LazySegtree<MaxAdd>, range: impl RangeBounds<usize>) {
-        let expected = base
-            .iter()
-            .enumerate()
-            .filter_map(|(i, a)| Some(a).filter(|_| range.contains(&i)))
-            .max()
-            .copied()
-            .unwrap_or(i32::min_value());
-        assert_eq!(segtree.prod(range), expected);
     }
 }

(304 lines deleted, 49 lines added)

Source code of the dual segment tree

Here is the full source code of the dual segment tree. Feel free to use it, but I make no guarantee that it is bug-free.

use std::ops::{Bound, RangeBounds};

fn ceil_pow2(n: u32) -> u32 {
    32 - n.saturating_sub(1).leading_zeros()
}

pub trait MapMonoid {
    type F: Clone;
    type S: Clone;
    fn identity_map() -> Self::F;
    fn mapping(f: &Self::F, x: &Self::S) -> Self::S;
    fn composition(f: &Self::F, g: &Self::F) -> Self::F;
}

impl<F: MapMonoid> Default for DualSegtree<F>
where
    F::S: Default,
{
    fn default() -> Self {
        Self::new(0)
    }
}
impl<F: MapMonoid> DualSegtree<F> {
    pub fn new(n: usize) -> Self
    where
        F::S: Default,
    {
        vec![F::S::default(); n].into()
    }
}

impl<F: MapMonoid> From<Vec<F::S>> for DualSegtree<F>
where
    F::S: Default,
{
    fn from(v: Vec<F::S>) -> Self {
        let n = v.len();
        let log = ceil_pow2(n as u32) as usize;
        let size = 1 << log;
        let mut d = vec![F::S::default(); size];
        let lz = vec![F::identity_map(); size];
        d[..n].clone_from_slice(&v);
        DualSegtree {
            n,
            size,
            log,
            d,
            lz,
        }
    }
}

impl<F: MapMonoid> DualSegtree<F> {
    pub fn set(&mut self, p: usize, x: F::S) {
        assert!(p < self.n);
        for i in (1..=self.log).rev() {
            self.push((p + self.size) >> i);
        }
        self.d[p] = x;
    }

    pub fn get(&mut self, p: usize) -> F::S {
        assert!(p < self.n);
        for i in (1..=self.log).rev() {
            self.push((p + self.size) >> i);
        }
        self.d[p].clone()
    }

    pub fn apply(&mut self, p: usize, f: F::F) {
        assert!(p < self.n);
        for i in (1..=self.log).rev() {
            self.push((p + self.size) >> i);
        }
        self.d[p] = F::mapping(&f, &self.d[p]);
    }
    pub fn apply_range<R>(&mut self, range: R, f: F::F)
    where
        R: RangeBounds<usize>,
    {
        let mut r = match range.end_bound() {
            Bound::Included(r) => r + 1,
            Bound::Excluded(r) => *r,
            Bound::Unbounded => self.n,
        };
        let mut l = match range.start_bound() {
            Bound::Included(l) => *l,
            Bound::Excluded(l) => l + 1,
            // TODO: There are another way of optimizing [0..r)
            Bound::Unbounded => 0,
        };

        assert!(l <= r && r <= self.n);
        if l == r {
            return;
        }

        l += self.size;
        r += self.size;

        for i in (1..=self.log).rev() {
            if ((l >> i) << i) != l {
                self.push(l >> i);
            }
            if ((r >> i) << i) != r {
                self.push((r - 1) >> i);
            }
        }

        {
            while l < r {
                if l & 1 != 0 {
                    self.all_apply(l, f.clone());
                    l += 1;
                }
                if r & 1 != 0 {
                    r -= 1;
                    self.all_apply(r, f.clone());
                }
                l >>= 1;
                r >>= 1;
            }
        }
    }
}

pub struct DualSegtree<F>
where
    F: MapMonoid,
{
    n: usize,
    size: usize,
    log: usize,
    d: Vec<F::S>,
    lz: Vec<F::F>,
}
impl<F> DualSegtree<F>
where
    F: MapMonoid,
{
    fn all_apply(&mut self, k: usize, f: F::F) {
        if k < self.size {
            self.lz[k] = F::composition(&f, &self.lz[k]);
        } else {
            self.d[k - self.size] = F::mapping(&f, &self.d[k - self.size]);
        }
    }
    fn push(&mut self, k: usize) {
        self.all_apply(2 * k, self.lz[k].clone());
        self.all_apply(2 * k + 1, self.lz[k].clone());
        self.lz[k] = F::identity_map();
    }
}

#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use super::{DualSegtree, MapMonoid};

    struct RangeAdd(Infallible);
    impl MapMonoid for RangeAdd {
        type F = i32;
        type S = i32;

        fn identity_map() -> Self::F {
            0
        }

        fn mapping(&f: &i32, &x: &i32) -> i32 {
            f + x
        }

        fn composition(&f: &i32, &g: &i32) -> i32 {
            f + g
        }
    }

    #[test]
    fn test_range_add_dual_segtree() {
        let base = vec![3, 1, 4, 1, 5, 9, 2, 6, 5, 3];
        let n = base.len();
        let mut segtree: DualSegtree<RangeAdd> = base.clone().into();
        check_segtree(&base, &mut segtree);

        let mut internal = vec![i32::min_value(); n];
        let mut segtree = DualSegtree::<RangeAdd>::from(internal.clone());

        for i in 0..n {
            segtree.set(i, base[i]);
            internal[i] = base[i];
            check_segtree(&internal, &mut segtree);
        }

        segtree.set(6, 5);
        internal[6] = 5;
        check_segtree(&internal, &mut segtree);

        segtree.apply(5, 1);
        internal[5] += 1;
        check_segtree(&internal, &mut segtree);

        segtree.set(6, 0);
        internal[6] = 0;
        check_segtree(&internal, &mut segtree);

        segtree.apply_range(3..8, 2);
        internal[3..8].iter_mut().for_each(|e| *e += 2);
        check_segtree(&internal, &mut segtree);

        segtree.apply_range(2..=5, 7);
        internal[2..=5].iter_mut().for_each(|e| *e += 7);
        check_segtree(&internal, &mut segtree);
    }

    //noinspection DuplicatedCode
    fn check_segtree(base: &[i32], segtree: &mut DualSegtree<RangeAdd>) {
        let n = base.len();
        #[allow(clippy::needless_range_loop)]
        for i in 0..n {
            assert_eq!(segtree.get(i), base[i]);
        }
    }
}

Tests

I confirmed that the tests bundled with the code pass. I also confirmed that this dual segment tree gets accepted on ABC332 F - Random Update Query (submission for ABC332 F using the dual segment tree).
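As an example of the kind of map monoid such a check needs, here is a sketch under the following reading of ABC332 F: each operation (L, R, X) overwrites one uniformly random position in [L, R] with X, and the expected final values are asked modulo 998244353. A position in [L, R] is then overwritten with probability q = 1/(R - L + 1), so its expectation E becomes (1 - q)E + qX, which is an affine map. `Affine`, `op_map`, and `pow_mod` are hypothetical names; this is not the author's submission.

```rust
const MOD: u64 = 998_244_353;

// Same shape as the article's `MapMonoid` trait.
pub trait MapMonoid {
    type F: Clone;
    type S: Clone;
    fn identity_map() -> Self::F;
    fn mapping(f: &Self::F, x: &Self::S) -> Self::S;
    fn composition(f: &Self::F, g: &Self::F) -> Self::F;
}

/// Modular exponentiation by squaring; `pow_mod(a, MOD - 2)` is the inverse of a.
fn pow_mod(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    base %= MOD;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % MOD;
        }
        base = base * base % MOD;
        exp >>= 1;
    }
    acc
}

/// Hypothetical map monoid: F = (a, b) stands for x -> a * x + b (mod MOD).
struct Affine;

impl MapMonoid for Affine {
    type F = (u64, u64);
    type S = u64;

    fn identity_map() -> Self::F {
        (1, 0)
    }

    fn mapping(&(a, b): &Self::F, &x: &Self::S) -> Self::S {
        (a * x + b) % MOD
    }

    fn composition(&(fa, fb): &Self::F, &(ga, gb): &Self::F) -> Self::F {
        // f after g: f(g(x)) = fa * (ga * x + gb) + fb
        (fa * ga % MOD, (fa * gb + fb) % MOD)
    }
}

/// Map for one operation (L, R, X), 1-indexed and inclusive:
/// E -> (1 - q) * E + q * X with q = 1 / (R - L + 1).
fn op_map(l: u64, r: u64, x: u64) -> (u64, u64) {
    let q = pow_mod(r - l + 1, MOD - 2);
    ((1 + MOD - q) % MOD, q * (x % MOD) % MOD)
}

fn main() {
    // Hand check: A = [2, 4] and one operation (L, R, X) = (1, 2, 0).
    // Each position becomes 0 with probability 1/2, so the expectations are 1 and 2.
    let a = [2u64, 4];
    let f = op_map(1, 2, 0);
    let e: Vec<u64> = a.iter().map(|x| Affine::mapping(&f, x)).collect();
    println!("{:?}", e); // [1, 2]
}
```

With the article's `DualSegtree<Affine>`, each operation would become one `apply_range` call over the 0-indexed range `L-1..R` with `op_map(L, R, X)`, and the answers would be read with `get`. Since affine maps do not commute in general, this relies on `apply_range` pushing the ancestors' pending maps down before applying the new one, which the code above does.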