题目描述

Given a binary tree, return the tilt of the whole tree.

The tilt of a tree node is defined as the absolute difference between the sum of all left subtree node values and the sum of all right subtree node values. Null node has tilt 0.

The tilt of the whole tree is defined as the sum of all nodes' tilt.

Example:

Input: 
         1
       /   \
      2     3
Output: 1
Explanation: 
Tilt of node 2 : 0
Tilt of node 3 : 0
Tilt of node 1 : |2-3| = 1
Tilt of binary tree : 0 + 0 + 1 = 1

Note:

  1. The sum of node values in any subtree won't exceed the range of 32-bit integer.
  2. All the tilt values won't exceed the range of 32-bit integer.

解法一

思路

既然是二叉树的遍历问题,那想必是递归了。再看对于某个结点,需要先得到左子树和右子树的和,才能计算tilt,那就是后序遍历了,模板就是:

def sum_and_tilt(root):
    ...
    sum_left, tilt_left = sum_and_tilt(root.left)
    sum_right, tilt_right = sum_and_tilt(root.right)
    # TODO
    return sum,tilt

下面来考虑sumtilt该如何计算。sum即以当前结点为根节点的所有结点的和,那就是sum_left + sum_right + root.value了。当前结点的tilt是左子树结点和与右子树结点和的差,那就是abs(sum_left - sum_right);注意这里返回的tilt是包括其左右子树的tilt的,所以最终就是abs(sum_left - sum_right) + tilt_left + tilt_right

再仔细考虑tilt,实际上tilt不必要放在递归结构中返回,直接让tilt作为成员变量,保存累加和就好了,这样也避免和很多语言不支持return多个变量的问题。

Python

# Definition for a binary tree node.
# class TreeNode(object):
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None


class Solution(object):
    def findTilt(self, root):
        """
        :type root: TreeNode
        :rtype: int
        """

        def sum_and_tilt(root):
            if not root:
                return 0, 0
            sum_left, tilt_left = sum_and_tilt(root.left)
            sum_right, tilt_right = sum_and_tilt(root.right)
            return sum_left + sum_right + root.val, abs(sum_left - sum_right) + tilt_left + tilt_right

        sum_tree, tilt_tree = sum_and_tilt(root)
        return tilt_tree

Java

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public int findTilt(TreeNode root) {
        class Util {
            int tilt;

            int sum(TreeNode root) {
                if (root == null) {
                    return 0;
                }
                int sumLeft = sum(root.left);
                int sumRight = sum(root.right);
                tilt += Math.abs(sumLeft - sumRight);
                return sumLeft + sumRight + root.val;
            }
        }
        Util util = new Util();
        util.sum(root);
        return util.tilt;
    }
}

C++

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int findTilt(TreeNode *root) {
        class Util {
        public:
            int tilt = 0;

            int sum(TreeNode *root) {
                if (root == nullptr) {
                    return 0;
                }
                int sumLeft = sum(root->left);
                int sumRight = sum(root->right);
                tilt += abs(sumLeft - sumRight);
                return sumLeft + sumRight + root->val;
            }
        };
        Util util;
        util.sum(root);
        return util.tilt;
    }
};