c++实现线段树结构

c++实现线段树结构,第1张

c++实现线段树结构

leetcode原题:307.区域和检索

给你一个数组 nums ,请你完成两类查询,其中一类查询要求更新数组下标对应的值,另一类查询要求返回数组中某个范围内元素的总和。

实现 NumArray 类:

NumArray(int[] nums) 用整数数组 nums 初始化对象
void update(int index, int val) 将 nums[index] 的值更新为 val
int sumRange(int left, int right) 返回子数组 nums[left, right] 的总和(即,nums[left] + nums[left + 1], ..., nums[right])

示例:

输入: [“NumArray”, “sumRange”, “update”, “sumRange”] [[[1, 3, 5]], [0,
2], [1, 2], [0, 2]]
输出: [null, 9, null, 8]

解释: NumArray numArray = new NumArray([1, 3, 5]);
numArray.sumRange(0,2); // 返回 9 ,sum([1,3,5]) = 9
numArray.update(1, 2); // nums =[1,2,5]
numArray.sumRange(0, 2); // 返回 8 ,sum([1,2,5]) = 8

提示:

1 <= nums.length <= 3 * 104
-100 <= nums[i] <= 100
0 <= index < nums.length
-100 <= val <= 100
0 <= left <= right < nums.length
最多调用 3 * 104 次 update 和 sumRange 方法

不看题解根本不知道还有一个线段树结构,从没见过这个概念。线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。


  1. 线段树构造过程

    如上图中有10个节点的线段树,根节点为全部区间,然后左节点为1-5,右节点为6-10;之后继续从中间裂变。

    所以,定义线段树结构:

    struct Node {
        int beg; // 起始区间
        int end; // 结束区间
        int val; // 区间之和
        Node *left; // 左子树
        Node *right; // 右子树
    
        Node() {
            beg = end = val = 0;
            left = right = nullptr;
        }
    };
    

    在依据中间裂变方式定义构造线段树方法:

        Node *buildTree(vector &nums, int beg, int end) {
            if (beg > end) {
                return nullptr;
            }
            Node *p = new Node;
            p->beg = beg;
            p->end = end;
            if (beg == end) {
                p->val = nums[beg];
            } else {
                p->left = buildTree(nums, beg, (beg + end) / 2);
                p->right = buildTree(nums, (beg + end) / 2 + 1, end);
            }
            return p;
        }
    

    再构造线段树过程中,叶子节点之和即为该叶子节点对应值,而需要更新非叶子节点中和值。

        int calcTree(Node *p) {
            if (p == nullptr) {
                return 0;
            }
            p->val += calcTree(p->left) + calcTree(p->right);
            return p->val;
        }
    
  2. 更新区间中某个值
    更新区间内某个位置值,则需要遍历线段树,判断更新位置在节点所属区间内,再更新差值。

        void updateTree(Node *p, int index, int diff) {
            if (p == nullptr) {
                return;
            }
            if (p->beg <= index && p->end >= index) {
                p->val += diff;
            }
            if (p->beg > index || p->end < index) {
                return;
            }
            updateTree(p->left, index, diff);
            updateTree(p->right, index, diff);
        }
    
  3. 查找计算区间值
    为了计算区间内之和,如区间[left, right],则可分别求出[1,right] - [1, left]。同样遍历线段树,并判断节点区间。如果节点都在位置之前,则加上整个节点之和。

        int getValue(Node *p, int index) {
            if (index < 0) {
                return 0;
            }
            if (index < p->beg) {
                return getValue(p->left, index);
            } else if (index > p->end) {
                return p->val + getValue(p->right, index);
            } else if (index == p->end) {
                return p->val;
            } else {
                int mid = (p->beg + p->end) / 2;
                if (index <= mid) {
                    return getValue(p->left, index);
                } else {
                    return p->left->val + getValue(p->right, index);
                }
            }
        }
    
  4. 完整实现如下所示

    struct Node {
        int beg;
        int end;
        int val;
        Node *left;
        Node *right;
    
        Node() {
            beg = end = val = 0;
            left = right = nullptr;
        }
    };
    
    class NumArray {
    public:
        NumArray(vector& nums) {
            this->nums = nums;
            root = buildTree(nums, 0, nums.size() - 1);
            calcTree(root);
        }
    
    	  ~NumArray() {
    	  	freeTree(root);
    	  }
    
        void update(int index, int val) {
            if (nums[index] == val) {
                return;
            }
            int diff = val - nums[index];
            nums[index] = val;
            updateTree(root, index, diff);
        }
    
        int sumRange(int left, int right) {
            return getValue(root, right) - getValue(root, left - 1);
        }
    
    private:
        Node *buildTree(vector &nums, int beg, int end) {
            if (beg > end) {
                return nullptr;
            }
            Node *p = new Node;
            p->beg = beg;
            p->end = end;
            if (beg == end) {
                p->val = nums[beg];
            } else {
                p->left = buildTree(nums, beg, (beg + end) / 2);
                p->right = buildTree(nums, (beg + end) / 2 + 1, end);
            }
            return p;
        }
    
        int calcTree(Node *p) {
            if (p == nullptr) {
                return 0;
            }
            p->val += calcTree(p->left) + calcTree(p->right);
            return p->val;
        }
    
        void updateTree(Node *p, int index, int diff) {
            if (p == nullptr) {
                return;
            }
            if (p->beg <= index && p->end >= index) {
                p->val += diff;
            }
            if (p->beg > index || p->end < index) {
                return;
            }
            updateTree(p->left, index, diff);
            updateTree(p->right, index, diff);
        }
    
        int getValue(Node *p, int index) {
            if (index < 0) {
                return 0;
            }
            if (index < p->beg) {
                return getValue(p->left, index);
            } else if (index > p->end) {
                return p->val + getValue(p->right, index);
            } else if (index == p->end) {
                return p->val;
            } else {
                int mid = (p->beg + p->end) / 2;
                if (index <= mid) {
                    return getValue(p->left, index);
                } else {
                    return p->left->val + getValue(p->right, index);
                }
            }
        }
    
        void freeTree(Node *p) {
            if (p == nullptr) {
                return;
            }
            freeTree(p->left);
            freeTree(p->right);
            delete p;
        }
    
    private:
        Node *root;
        vector nums;
    };
    
    
    
    
  5. 看下leetcode中推荐的题解,小丑还是我自己啊。不是说好的树结构,怎么就用个数组就搞定了。

欢迎分享,转载请注明来源:内存溢出

原文地址:https://www.54852.com/zaji/5698566.html

(0)
打赏 微信扫一扫微信扫一扫 支付宝扫一扫支付宝扫一扫
上一篇 2022-12-17
下一篇2022-12-17

发表评论

登录后才能评论

评论列表(0条)

    保存