DEV Community

Abhishek Chaudhary

Lowest Common Ancestor of a Binary Tree

Given a binary tree, find the lowest common ancestor (LCA) of two given nodes in the tree.

According to the definition of LCA on Wikipedia: “The lowest common ancestor is defined between two nodes p and q as the lowest node in T that has both p and q as descendants (where we allow a node to be a descendant of itself).”
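The definition can be checked directly. First collect p and every ancestor of p, counting p itself because a node is its own descendant. Then walk upward from q until you reach a node in that set. The first such node is the lowest one that has both p and q as descendants. The sketch below assumes a minimal `TreeNode` with `val`, `left` and `right` fields. `lca_by_definition` is a hypothetical helper name, and it records parent links with a breadth-first traversal.

```python
from collections import deque
from typing import Dict, Optional


class TreeNode:
    """Minimal binary tree node (assumed shape: val, left, right)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def lca_by_definition(root: TreeNode, p: TreeNode, q: TreeNode) -> TreeNode:
    # Record every node's parent with a breadth-first traversal.
    parent: Dict[TreeNode, Optional[TreeNode]] = {root: None}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                queue.append(child)

    # Collect p and all of its ancestors (a node counts as its own descendant).
    ancestors = set()
    cur = p
    while cur is not None:
        ancestors.add(cur)
        cur = parent[cur]

    # Walk up from q; the first node that is also an ancestor of p is the LCA.
    cur = q
    while cur not in ancestors:
        cur = parent[cur]
    return cur


# The tree from Examples 1 and 2: [3,5,1,6,2,0,8,null,null,7,4]
n7, n4 = TreeNode(7), TreeNode(4)
n5 = TreeNode(5, TreeNode(6), TreeNode(2, n7, n4))
n1 = TreeNode(1, TreeNode(0), TreeNode(8))
root = TreeNode(3, n5, n1)
print(lca_by_definition(root, n5, n1).val)  # 3
print(lca_by_definition(root, n5, n4).val)  # 5
```

Building the parent map costs O(n) time and memory. The two upward walks are each bounded by the tree height.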

Example 1:

Input: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 1
Output: 3
Explanation: The LCA of nodes 5 and 1 is 3.

Example 2:

Input: root = [3,5,1,6,2,0,8,null,null,7,4], p = 5, q = 4
Output: 5
Explanation: The LCA of nodes 5 and 4 is 5, since a node can be a descendant of itself according to the LCA definition.

Example 3:

Input: root = [1,2], p = 1, q = 2
Output: 1
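The `root` arrays in these examples list node values in level order, with `null` marking a missing child. This is the serialization LeetCode uses, and trailing nulls are omitted. The minimal decoder below makes Example 2 concrete. It uses hypothetical helpers `build_tree` and `path_to`, with Python's `None` standing in for `null`. The path to 4 passes through 5, so 5 is an ancestor of 4, and trivially of itself.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries (if present) are this node's left and right children.
        left_val = values[i]
        if left_val is not None:
            node.left = TreeNode(left_val)
            queue.append(node.left)
        i += 1
        if i < len(values):
            right_val = values[i]
            if right_val is not None:
                node.right = TreeNode(right_val)
                queue.append(node.right)
            i += 1
    return root


def path_to(root: Optional[TreeNode], target: int) -> List[int]:
    """Values on the root-to-target path, or [] if target is absent."""
    stack = [(root, [root.val])] if root else []
    while stack:
        node, path = stack.pop()
        if node.val == target:
            return path
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, path + [child.val]))
    return []


root = build_tree([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4])
print(path_to(root, 5))  # [3, 5]
print(path_to(root, 4))  # [3, 5, 2, 4]
```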

Constraints:

  • The number of nodes in the tree is in the range [2, 10^5].
  • -10^9 <= Node.val <= 10^9
  • All Node.val are unique.
  • p != q
  • p and q will exist in the tree.
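With up to 10^5 nodes, the tree can also be a single chain 10^5 levels deep. CPython's default recursion limit is 1000 (see `sys.getrecursionlimit()`), so a naive recursive traversal may raise `RecursionError` on such input unless the limit is raised. An explicit stack, which the solution below also uses, avoids that. The sketch below measures the depth of such a chain without recursion. `make_left_chain` and `max_depth_iterative` are hypothetical helper names.

```python
import sys
from typing import Optional


class TreeNode:
    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["TreeNode"] = None
        self.right: Optional["TreeNode"] = None


def make_left_chain(n: int) -> TreeNode:
    """Build a degenerate tree 0 -> 1 -> ... -> n-1 linked through left children."""
    root = TreeNode(0)
    node = root
    for v in range(1, n):
        node.left = TreeNode(v)
        node = node.left
    return root


def max_depth_iterative(root: Optional[TreeNode]) -> int:
    """Depth of the tree using an explicit stack instead of recursion."""
    best = 0
    stack = [(root, 1)] if root else []
    while stack:
        node, depth = stack.pop()
        best = max(best, depth)
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, depth + 1))
    return best


chain = make_left_chain(10**5)
print(sys.getrecursionlimit())     # CPython's default is 1000
print(max_depth_iterative(chain))  # 100000
```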

SOLUTION:

```python
# Definition for a binary tree node (as provided by LeetCode).
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


# An earlier recursive approach, left commented out:
#
# class Solution:
#     def isInTree(self, root, node):
#         if root:
#             if root.val == node.val:
#                 return True
#             if self.isInTree(root.left, node) or self.isInTree(root.right, node):
#                 return True
#         return False
#
#     def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
#         if root:
#             if root.val == p.val or root.val == q.val:
#                 return root
#             pInLeft = self.isInTree(root.left, p)
#             qInLeft = self.isInTree(root.left, q)
#             if pInLeft and qInLeft:
#                 return self.lowestCommonAncestor(root.left, p, q)
#             if not pInLeft and not qInLeft:
#                 return self.lowestCommonAncestor(root.right, p, q)
#             if pInLeft ^ qInLeft:
#                 return root


class Solution:
    def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
        # Iterative DFS whose stack holds full root-to-node paths.
        nodepaths = [None, None]  # [root-to-p path, root-to-q path]
        paths = [[root]]
        while paths:
            curr = paths.pop()
            if curr[-1].val == p.val:
                nodepaths[0] = curr
            if curr[-1].val == q.val:
                nodepaths[1] = curr
            if nodepaths[0] and nodepaths[1]:
                break  # both paths found; stop searching
            if curr[-1].left:
                paths.append(curr + [curr[-1].left])
            if curr[-1].right:
                paths.append(curr + [curr[-1].right])

        # Both paths start at root; the LCA is the last node they share.
        i = 0
        k = min(len(nodepaths[0]), len(nodepaths[1]))
        while nodepaths[0][i].val == nodepaths[1][i].val:
            i += 1
            if i >= k:
                break  # the shorter path is a prefix of the longer one
        return nodepaths[0][i - 1]
```
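For comparison, a widely used alternative, which is not the author's approach, finds the LCA in a single post-order pass without storing paths. Each call returns p, q, their LCA, or `None` for its subtree. The first node that gets a non-`None` result from both sides is the LCA. It visits each node at most once, so it runs in O(n) time with O(h) stack space. The path-copying DFS above instead pays time proportional to a node's depth every time it pushes that node. Because this version recurses once per level, a very deep, skewed tree can exceed CPython's default recursion limit. The sketch below assumes a minimal `TreeNode` whose constructor takes optional children. `lca_single_pass` is a hypothetical name, and the function compares nodes by identity (`is`), relying on the guarantee that p and q are nodes of the tree.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def lca_single_pass(root: Optional[TreeNode], p: TreeNode, q: TreeNode) -> Optional[TreeNode]:
    """Post-order search: returns p, q, their LCA, or None for this subtree."""
    if root is None or root is p or root is q:
        # A subtree rooted at p or q can only report that node upward.
        return root
    left = lca_single_pass(root.left, p, q)
    right = lca_single_pass(root.right, p, q)
    if left is not None and right is not None:
        # p and q were found in different subtrees, so root is where they split.
        return root
    return left if left is not None else right


# The tree from Examples 1 and 2: [3,5,1,6,2,0,8,null,null,7,4]
n4 = TreeNode(4)
n5 = TreeNode(5, TreeNode(6), TreeNode(2, TreeNode(7), n4))
n1 = TreeNode(1, TreeNode(0), TreeNode(8))
root = TreeNode(3, n5, n1)
print(lca_single_pass(root, n5, n1).val)  # 3
print(lca_single_pass(root, n5, n4).val)  # 5
```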