mod tree {
use std::mem::swap;
use num::Num;
/// One directed half of a weighted edge, stored in the adjacency list of
/// `from`.
#[derive(Copy, Clone)]
struct Edge<T>
where
    T: Ord + Copy + Num,
{
    /// Vertex whose adjacency list holds this half-edge.
    from: usize,
    /// Vertex at the other end of the edge.
    to: usize,
    /// Weight attached to the edge.
    cost: T,
}
/// Lowest-common-ancestor oracle for a weighted tree, answering queries with
/// a binary-lifting (doubling) ancestor table.
pub struct LowestCommonAncestor<T>
where
    T: Ord + Copy + Num,
{
    /// Number of vertices (labelled `0..n`).
    n: usize,
    /// Highest jump level; the ancestor table stores levels `0..=logn`.
    logn: usize,
    /// Adjacency lists of half-edges, indexed by source vertex.
    edges: Vec<Vec<Edge<T>>>,
    /// Distance of each vertex from the root.
    depth: Vec<usize>,
    /// `tab[v][k]` is the `2^k`-th ancestor of `v`, or `None` if it does not exist.
    tab: Vec<Vec<Option<usize>>>,
    /// DFS pre-order index of each vertex.
    order: Vec<usize>,
}
// Not every program uses every method of this structure (e.g. the cost-table
// API), so unused ones are allowed without warnings.
#[allow(dead_code)]
impl<T: Ord + Copy + Num> LowestCommonAncestor<T> {
    /// Creates a tree with `n` vertices labelled `0..n` and no edges.
    ///
    /// Add every edge with [`add_edge`](Self::add_edge), then call
    /// [`build`](Self::build) before issuing any query.
    pub fn new(n: usize) -> Self {
        // Levels 0..=floor(log2(n)) can express any depth difference (< n) in
        // binary. For n <= 1 only level 0 is needed; guarding here also avoids
        // taking the logarithm of zero.
        let logn = if n <= 1 { 0 } else { bit_width(n) - 1 };
        Self {
            n,
            logn,
            edges: vec![vec![]; n],
            depth: vec![0; n],
            tab: vec![vec![None; logn + 1]; n],
            order: vec![0; n],
        }
    }

    /// Adds an undirected edge between `from` and `to` with weight `cost`.
    pub fn add_edge(&mut self, from: usize, to: usize, cost: T) {
        self.add_directed_edge(from, to, cost);
        self.add_directed_edge(to, from, cost);
    }

    /// Appends the half-edge `from -> to` to `from`'s adjacency list.
    fn add_directed_edge(&mut self, from: usize, to: usize, cost: T) {
        self.edges[from].push(Edge { from, to, cost });
    }

    /// Roots the tree at vertex 0 and fills the depth, pre-order and ancestor
    /// tables. Does nothing for an empty tree.
    pub fn build(&mut self) {
        if self.n == 0 {
            return;
        }
        self.build_depth();
        self.build_tab();
    }

    /// Computes `depth` and pre-order `order` for every vertex reachable from
    /// vertex 0.
    ///
    /// An explicit stack is used instead of recursion so that deep (e.g.
    /// path-shaped) trees cannot overflow the call stack. Children are visited
    /// in adjacency-list order, so the numbering equals that of a recursive
    /// pre-order DFS.
    fn build_depth(&mut self) {
        // Position of the next unexamined half-edge in each adjacency list.
        let mut next_edge = vec![0; self.n];
        // (vertex, parent); the root uses itself as its parent sentinel.
        let mut stack = vec![(0, 0)];
        let mut counter = 0;
        self.depth[0] = 0;
        self.order[0] = 0;
        while let Some(&(v, par)) = stack.last() {
            let Some(edge) = self.edges[v].get(next_edge[v]) else {
                // Every neighbour of `v` has been handled: retreat.
                stack.pop();
                continue;
            };
            next_edge[v] += 1;
            let child = edge.to;
            if child == par {
                continue;
            }
            counter += 1;
            self.depth[child] = self.depth[v] + 1;
            self.order[child] = counter;
            stack.push((child, v));
        }
    }

    /// Fills `tab[v][0]` with the parent of each vertex, then doubles:
    /// `tab[v][k + 1]` is the `2^k`-th ancestor of `tab[v][k]`.
    fn build_tab(&mut self) {
        for edge in self.edges.iter().flatten() {
            // A tree edge joins a parent and its child; the endpoint that is
            // not shallower than the other one is the child.
            if self.depth[edge.from] >= self.depth[edge.to] {
                self.tab[edge.from][0] = Some(edge.to);
            }
        }
        for k in 0..self.logn {
            for v in 0..self.n {
                if let Some(mid) = self.tab[v][k] {
                    self.tab[v][k + 1] = self.tab[mid][k];
                }
            }
        }
    }

    /// Returns the DFS pre-order index of `v` (the root has index 0).
    pub fn get_order(&self, v: usize) -> usize {
        self.order[v]
    }

    /// Returns the lowest common ancestor of `u` and `v`.
    ///
    /// # Panics
    /// Panics if `u` or `v` is out of range, or if a required ancestor is
    /// missing from the table (e.g. [`build`](Self::build) was not called).
    pub fn lca(&self, mut u: usize, mut v: usize) -> usize {
        if self.depth[u] < self.depth[v] {
            swap(&mut u, &mut v);
        }
        // Lift the deeper vertex to the other's depth, one binary digit at a time.
        let diff = self.depth[u] - self.depth[v];
        for k in 0..=self.logn {
            if (diff >> k) & 1 == 1 {
                u = self.tab[u][k].expect("ancestor table incomplete: was build() called?");
            }
        }
        if u == v {
            return u;
        }
        // Climb both vertices as far as possible while they stay distinct;
        // they end up as the two children of the LCA.
        for k in (0..=self.logn).rev() {
            if let (Some(pu), Some(pv)) = (self.tab[u][k], self.tab[v][k]) {
                if pu != pv {
                    u = pu;
                    v = pv;
                }
            }
        }
        self.tab[u][0].expect("ancestor table incomplete: was build() called?")
    }

    /// Builds a cost table parallel to the ancestor table: `[v][0]` is the cost
    /// of the edge from `v` to its parent (`e` for the root), and
    /// `[v][k + 1] = merge([v][k], [tab[v][k]][k])`.
    ///
    /// An entry holds the complete fold of the `2^k` edges above `v` only when
    /// the `2^k`-th ancestor exists; [`lca_cost_tab`](Self::lca_cost_tab)
    /// reads only such entries. Requires [`build`](Self::build) to have run.
    pub fn build_cost_tab<F: Fn(T, T) -> T>(&self, merge: F, e: T) -> Vec<Vec<T>> {
        let mut tab = vec![vec![e; self.logn + 1]; self.n];
        for edge in self.edges.iter().flatten() {
            // Record each edge's cost on its child endpoint.
            if self.depth[edge.from] >= self.depth[edge.to] {
                tab[edge.from][0] = edge.cost;
            }
        }
        for k in 0..self.logn {
            for v in 0..self.n {
                if let Some(mid) = self.tab[v][k] {
                    tab[v][k + 1] = merge(tab[v][k], tab[mid][k]);
                }
            }
        }
        tab
    }

    /// Folds, with `merge` starting from `e`, the costs of the edges on the
    /// path between `u` and `v`, using a table built by
    /// [`build_cost_tab`](Self::build_cost_tab) with the same `merge` and `e`.
    ///
    /// Costs from the `u` side and the `v` side are interleaved, so the result
    /// is only well-defined when `merge` is associative and commutative and
    /// `e` is its identity (e.g. `max`, `min`, or `+` with `0`). Returns `e`
    /// when `u == v`.
    ///
    /// # Panics
    /// Same conditions as [`lca`](Self::lca), plus an out-of-range `tab`.
    pub fn lca_cost_tab<F: Fn(T, T) -> T>(
        &self,
        mut u: usize,
        mut v: usize,
        merge: F,
        e: T,
        tab: &[Vec<T>],
    ) -> T {
        if self.depth[u] < self.depth[v] {
            swap(&mut u, &mut v);
        }
        let mut ret = e;
        let diff = self.depth[u] - self.depth[v];
        for k in 0..=self.logn {
            if (diff >> k) & 1 == 1 {
                ret = merge(ret, tab[u][k]);
                u = self.tab[u][k].expect("ancestor table incomplete: was build() called?");
            }
        }
        if u == v {
            return ret;
        }
        for k in (0..=self.logn).rev() {
            if let (Some(pu), Some(pv)) = (self.tab[u][k], self.tab[v][k]) {
                if pu != pv {
                    ret = merge(ret, tab[u][k]);
                    ret = merge(ret, tab[v][k]);
                    u = pu;
                    v = pv;
                }
            }
        }
        // `u` and `v` are now the two children of the LCA: add their parent edges.
        ret = merge(ret, tab[u][0]);
        merge(ret, tab[v][0])
    }
}
/// Returns the number of bits needed to represent `x`: `floor(log2(x)) + 1`
/// for `x > 0`, and `0` for `x == 0` (matching C++'s `std::bit_width`).
fn bit_width(x: usize) -> usize {
    // One past the position of the highest set bit; unlike `ilog2`, this does
    // not panic on zero.
    (usize::BITS - x.leading_zeros()) as usize
}
}