题目描述

Given a binary search tree and the lowest and highest boundaries as L and R, trim the tree so that all its elements lies in [L, R] (R >= L). You might need to change the root of the tree, so the result should return the new root of the trimmed binary search tree.

Example 1:

Input: 
    1
   / \
  0   2

  L = 1
  R = 2

Output: 
    1
      \
       2

Example 2:

Input: 
    3
   / \
  0   4
   \
    2
   /
  1

  L = 1
  R = 3

Output: 
      3
     / 
   2   
  /
 1

解法一

思路

仔细想想这题其实就是“二叉搜索树中删除一个结点”的变形,所以可以直接套用那个模板了:

def trimBST(self, root, L, R):
    if not root:
        return root
    
    # TODO

    root.left = self.trimBST(root.left, L, R)
    root.right = self.trimBST(root.right, L, R)
    return root

那么下面的问题就是TODO该怎么写了。再来看看题目,如果当前结点的值不在[L, R]之间,那么这个结点应该删除掉,那么以哪个结点来替换呢?这个倒不难:如果小于L,那么一定是右子树结点;如果大于R,那么一定是左子树结点。注意这里只是说“子树结点”而不是“子节点”,是因为当前结点的“子节点”也可能不满足要求。那么怎么寻找这个“子树结点”呢?我们只用交给trimBST()来处理啦,这里可能有点难理解,还需自行体会。

所以TODO部分就是:

if root.val < L:
    return self.trimBST(root.right, L, R)
if root.val > R:
    return self.trimBST(root.left, L, R)

Python

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None


class Solution:
    def trimBST(self, root, L, R):
        """
        :type root: TreeNode
        :type L: int
        :type R: int
        :rtype: TreeNode
        """
        if not root:
            return root
        if root.val < L:
            return self.trimBST(root.right, L, R)
        if root.val > R:
            return self.trimBST(root.left, L, R)
        root.left = self.trimBST(root.left, L, R)
        root.right = self.trimBST(root.right, L, R)
        return root

Java

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public TreeNode trimBST(TreeNode root, int L, int R) {
        if (root == null) {
            return root;
        }
        if (root.val < L) {
            return trimBST(root.right, L, R);
        }
        if (root.val > R) {
            return trimBST(root.left, L, R);
        }
        root.left = trimBST(root.left, L, R);
        root.right = trimBST(root.right, L, R);
        return root;
    }
}

C++

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    TreeNode *trimBST(TreeNode *root, int L, int R) {
        if (root == nullptr) {
            return root;
        }
        if (root->val < L) {
            return trimBST(root->right, L, R);
        }
        if (root->val > R) {
            return trimBST(root->left, L, R);
        }
        root->left = trimBST(root->left, L, R);
        root->right = trimBST(root->right, L, R);
        return root;
    }
};