1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
| package main
import "fmt"
// SegmentTree answers range queries over a fixed-length slice of ints and
// supports point updates. Every internal node stores mergeFunc applied to the
// values of its two children, so range results are only a true aggregate of
// the range when mergeFunc is associative (sum, min, max, ...).
type SegmentTree struct {
	// data is the tree's own copy of the input elements; tree is the implicit
	// binary tree, where node i has children 2i+1 and 2i+2 and the root is 0.
	data, tree []int
	// mergeFunc combines the values of two adjacent ranges (left, right).
	mergeFunc func(lv, rv int) int
}
func NewSegmentTree(data []int, mergeFunc func(lv, rv int) int) *SegmentTree { st := &SegmentTree{} st.tree = make([]int, 4*len(data)) st.data = make([]int, len(data)) st.mergeFunc = mergeFunc copy(st.data, data) st.buildSegmentTree(0, 0, len(data)-1) return st }
// Get returns the element currently stored at position i (reflecting any
// Update calls). An out-of-range i panics with the usual slice index error.
func (st *SegmentTree) Get(i int) int {
	values := st.data
	return values[i]
}
// Size reports how many elements the tree was built over.
func (st *SegmentTree) Size() int {
	n := len(st.data)
	return n
}
// Search returns the mergeFunc-combination of the elements at positions
// lIdx..rIdx, both ends inclusive. It panics if the range is reversed
// (lIdx > rIdx) or not contained in [0, Size()), which includes every range
// on an empty tree.
func (st *SegmentTree) Search(lIdx, rIdx int) int {
	// Without this check an invalid range never matches a tree node. Depending
	// on the inputs, the recursion then either indexes past the end of st.tree
	// or never terminates. A goroutine stack overflow is fatal and cannot be
	// recovered, so fail fast with a descriptive panic instead.
	if lIdx < 0 || rIdx >= len(st.data) || lIdx > rIdx {
		panic(fmt.Sprintf("SegmentTree.Search: invalid range [%d, %d] for size %d", lIdx, rIdx, len(st.data)))
	}
	return st.search(0, 0, len(st.data)-1, lIdx, rIdx)
}
func (st *SegmentTree) search(root, l, r, lIdx, rIdx int) (res int) { if l == lIdx && r == rIdx { return st.tree[root] } lTreeIdx := st.leftChild(root) rTreeIdx := st.rightChild(root) mid := l + (r-l)/2 if rIdx <= mid { return st.search(lTreeIdx, l, mid, lIdx, rIdx) } if lIdx > mid { return st.search(rTreeIdx, mid+1, r, lIdx, rIdx) } return st.mergeFunc( st.search(lTreeIdx, l, mid, lIdx, mid), st.search(rTreeIdx, mid+1, r, mid+1, rIdx)) }
// Update sets the element at position idx to val and refreshes every tree
// node whose range contains idx. It panics if idx is outside [0, Size()),
// and in that case it leaves the tree unmodified.
func (st *SegmentTree) Update(idx int, val int) {
	// A negative or too-large idx never reaches a matching leaf. The recursion
	// then either indexes past the end of st.tree or never terminates, and the
	// resulting fatal stack overflow cannot be recovered. Validate first.
	if idx < 0 || idx >= len(st.data) {
		panic(fmt.Sprintf("SegmentTree.Update: index %d out of range for size %d", idx, len(st.data)))
	}
	st.update(0, 0, len(st.data)-1, idx, val)
}
func (st *SegmentTree) update(root, l, r int, idx int, val int) { if l == idx && l == r { st.tree[root] = val st.data[idx] = val return } lTreeIdx := st.leftChild(root) rTreeIdx := st.rightChild(root) mid := l + (r-l)/2 if idx <= mid { st.update(lTreeIdx, l, mid, idx, val) } else { st.update(rTreeIdx, mid+1, r, idx, val) } st.tree[root] = st.mergeFunc(st.tree[lTreeIdx], st.tree[rTreeIdx]) }
func (st *SegmentTree) buildSegmentTree(root, l, r int) { if l == r { st.tree[root] = st.data[l] return } lTreeIdx := st.leftChild(root) rTreeIdx := st.rightChild(root) midIdx := l + (r-l)/2 st.buildSegmentTree(lTreeIdx, l, midIdx) st.buildSegmentTree(rTreeIdx, midIdx+1, r) st.tree[root] = st.mergeFunc(st.tree[lTreeIdx], st.tree[rTreeIdx]) }
// leftChild returns the array index of node idx's left child (2*idx + 1).
// Shifting clears the low bit, so OR-ing in 1 is the same as adding 1.
func (st *SegmentTree) leftChild(idx int) int {
	return idx<<1 | 1
}
// rightChild returns the array index of node idx's right child (2*idx + 2),
// i.e. the slot immediately after the left child.
func (st *SegmentTree) rightChild(idx int) int {
	return (idx + 1) << 1
}
func main() { sg := NewSegmentTree([]int{-2, 0, 3, -5, 2, -1, 10, 23}, func(lv, rv int) int { return lv + rv }) fmt.Println(sg.tree) fmt.Println(sg.data) fmt.Println(sg.Search(1, 6)) fmt.Println(sg.Search(3, 4)) sg.Update(3, 5) fmt.Println(sg.tree) fmt.Println(sg.data) fmt.Println(sg.Search(1, 6)) fmt.Println(sg.Search(3, 4)) }
|