Question
Given a binary tree, count the number of uni-value subtrees.
A Uni-value subtree means all nodes of the subtree have the same value.
Example:
Input: root = [5,1,5,5,5,null,5]
        5
       / \
      1   5
     / \   \
    5   5   5
Output: 4
Solution
# Definition for a binary tree node.
# class TreeNode(object):
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None
class Solution(object):
    """Count the uni-value subtrees of a binary tree.

    A subtree is uni-value when every node in it holds the same value.
    Nodes are expected to expose ``val``, ``left`` and ``right`` attributes,
    with ``None`` marking a missing child.
    """

    def countUnivalSubtrees(self, root):
        """Return the number of uni-value subtrees under ``root``.

        :type root: TreeNode
        :rtype: int
        """
        if root is None:
            return 0
        return self.helper(root)[0]

    def helper(self, root):
        """Return ``(count, is_unival)`` for the subtree rooted at ``root``.

        ``count`` is the number of uni-value subtrees inside this subtree and
        ``is_unival`` tells whether the whole subtree is itself uni-value.
        ``root`` must not be ``None``.

        The traversal is an iterative post-order walk with an explicit stack,
        so the depth of the tree is not bounded by the interpreter's
        recursion limit.
        """
        # id(node) -> whether the subtree rooted at that node is uni-value.
        # Every node stays referenced by the tree while we walk it, so ids
        # are stable for the duration of the call.
        unival_by_id = {}
        count = 0
        # Each entry is (node, children_already_scheduled).
        stack = [(root, False)]
        while stack:
            node, children_scheduled = stack.pop()
            children = [c for c in (node.left, node.right) if c is not None]
            if not children_scheduled:
                # Revisit this node only after both subtrees are resolved.
                stack.append((node, True))
                stack.extend((child, False) for child in children)
                continue
            # A node roots a uni-value subtree when each existing child
            # subtree is uni-value and carries the same value as the node;
            # a leaf satisfies this vacuously.
            is_match = all(
                node.val == child.val and unival_by_id[id(child)]
                for child in children
            )
            unival_by_id[id(node)] = is_match
            if is_match:
                count += 1
        return (count, unival_by_id[id(root)])