1660. Correct a Binary Tree

Problem Description

You are given a binary tree that has exactly one defective node. This defective node has its right child pointer incorrectly pointing to another node that:

  • Is at the same depth (level) in the tree
  • Is positioned to the right of the defective node

Your task is to identify and remove this defective node along with all nodes in its subtree (except for the node it incorrectly points to), then return the root of the corrected binary tree.

The key insight is that in a valid binary tree, when traversing from root to leaves, you should never encounter the same node twice. If a node's right child has already been visited during the traversal, it means we've found the defective node with the invalid pointer.

The solution uses a depth-first search (DFS) approach with a visited set to track nodes. When processing each node:

  • If the current node is null, we return null
  • If its right child has already been visited (it exists in the visited set), the current node is the defective one, so we also return null, which removes it
  • Otherwise, we add the current node to the visited set and recursively process its right and left children
  • Processing the right child before the left child is important: the invalid pointer targets a node further right at the same depth, and a right-first traversal reaches that target before it reaches the defective node

For example, suppose node A's right child incorrectly points to node B, which is at the same level but to the right. The right-first traversal reaches B through its real parent first and adds B to the visited set. When it later reaches A and checks A's right child (which incorrectly points to B), it finds B already in the set and knows that A is the defective node to remove.

Flowchart Walkthrough

First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:

Is it a graph?

  • Yes: A binary tree is a special type of graph where each node has at most two children and there are no cycles.

Is it a tree?

  • Yes: The problem explicitly states we're working with a binary tree structure with nodes connected in a hierarchical manner.

DFS

  • Yes: We arrive at DFS as the solution approach.

Conclusion: The flowchart suggests using DFS (Depth-First Search) for this problem.

Why DFS is Perfect for This Problem

The DFS pattern is particularly well-suited for this problem because:

  1. Tree Traversal: We need to traverse the entire tree to find the defective node, and DFS naturally explores tree structures by going deep into each branch before backtracking.

  2. Detection Through Visitation: The key to identifying the defective node is detecting when a node's right child has already been visited. DFS with a visited set allows us to track which nodes we've already seen during our traversal.

  3. Subtree Removal: When we find the defective node, we need to remove it and its entire subtree. DFS makes this easy - we simply return null when we detect the invalid node, which effectively removes that entire branch from the tree.

  4. Order Matters: The solution processes the right child before the left child in the recursive calls. This ensures that when we encounter an invalid pointer (pointing to a node on the right), that target node has already been visited through the normal traversal path.

The DFS approach elegantly solves the problem by combining traversal with cycle detection (a node pointing to an already-visited node indicates the invalid connection) and tree modification in a single recursive pass.

Intuition

The core insight comes from understanding what makes the defective node "invalid" - it creates an improper connection in the tree structure. In a valid binary tree, each node should only be reachable through one unique path from the root. However, the defective node's right child points to a node that already exists elsewhere in the tree, creating two paths to reach the same node.

Think about traversing the tree normally. If we visit nodes in a systematic way and keep track of which nodes we've seen, we should never encounter the same node twice. But with the defective pointer, we will! The defective node's right child will point to a node we can also reach through the normal tree structure.

Here's the key realization: if we traverse the tree and find that a node's right child has already been visited, then we've found our defective node. Why? Because the only way a child pointer can lead to an already-visited node is if it's pointing somewhere it shouldn't - specifically to that node on the same level to its right.

The elegant part is that we don't need to explicitly find which node is defective and then remove it in a separate step. Instead, we can detect and remove in one pass. As we traverse with DFS:

  • We mark each node as visited when we first encounter it
  • Before processing any node, we check if its right child has been visited
  • If yes, this must be the defective node, so we return null to remove it
  • If no, we continue processing normally

The order of traversal matters here. We process the right subtree before the left subtree at each node. For any two nodes at the same depth, a right-first DFS visits the one further right first. So when we reach the defective node, the node it incorrectly points to (which is to its right at the same depth) has already been visited through the correct path, and the invalid pointer is always detected.

This approach turns a potentially complex tree manipulation problem into a simple traversal with cycle detection - if we ever see a "cycle" (a pointer to an already-visited node), we've found our culprit.
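The ordering argument can be checked on a small valid tree. The sketch below is illustrative only: `right_first_preorder` is a hypothetical helper, not part of the solution, and the `TreeNode` class is assumed to match the usual definition. It records, for each depth, the order in which a right-first DFS reaches the nodes.

```python
from collections import defaultdict


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def right_first_preorder(root):
    """Group node values by depth, in the order a right-first DFS reaches them."""
    order_by_depth = defaultdict(list)

    def dfs(node, depth):
        if node is None:
            return
        order_by_depth[depth].append(node.val)
        dfs(node.right, depth + 1)  # right subtree first
        dfs(node.left, depth + 1)   # then left subtree

    dfs(root, 0)
    return dict(order_by_depth)


# A valid tree (no defective pointer):
#         1
#        / \
#       2   4
#      /   / \
#     3   5   6
root = TreeNode(1, TreeNode(2, TreeNode(3)), TreeNode(4, TreeNode(5), TreeNode(6)))
levels = right_first_preorder(root)
print(levels)  # {0: [1], 1: [4, 2], 2: [6, 5, 3]}
```

Within every depth the values appear from right to left, which is why the target of the invalid pointer is already in the visited set by the time the defective node is examined.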

Solution Implementation

```python
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right

class Solution:
    def correctBinaryTree(self, root: TreeNode) -> TreeNode:
        """
        Corrects a binary tree by removing the node that has an incorrect right pointer.
        The incorrect node's right child points to a node at the same level, further to its right.

        Args:
            root: The root of the binary tree

        Returns:
            The root of the corrected binary tree
        """
        # Set to track visited nodes during traversal
        visited_nodes = set()

        def dfs(node: TreeNode) -> TreeNode:
            """
            Performs depth-first search to identify and remove the incorrect node.
            Processes nodes from right to left to detect incorrect right pointers.

            Args:
                node: Current node being processed

            Returns:
                The node itself if valid, None if it should be removed
            """
            # Base case: if node is None or its right child was already visited
            # (meaning this node has an incorrect right pointer)
            if node is None or node.right in visited_nodes:
                return None

            # Mark current node as visited
            visited_nodes.add(node)

            # Process right subtree first (right-to-left traversal)
            node.right = dfs(node.right)

            # Then process left subtree
            node.left = dfs(node.left)

            return node

        # Start DFS from root and return corrected tree
        return dfs(root)
```

```java
import java.util.HashSet;
import java.util.Set;

/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
class Solution {
    // Set to keep track of visited nodes during traversal
    private Set<TreeNode> visitedNodes = new HashSet<>();

    /**
     * Corrects a binary tree by removing the invalid node.
     * The invalid node is one whose right child points to a node
     * at the same level that has already been visited.
     *
     * @param root The root of the binary tree
     * @return The root of the corrected binary tree
     */
    public TreeNode correctBinaryTree(TreeNode root) {
        return dfs(root);
    }

    /**
     * Performs depth-first search to identify and remove the invalid node.
     * Traverses the right subtree first so the target of the invalid
     * reference is visited before the invalid node itself.
     *
     * @param currentNode The current node being processed
     * @return The corrected subtree rooted at currentNode, or null if invalid
     */
    private TreeNode dfs(TreeNode currentNode) {
        // Base case: null node or invalid node detected
        // If right child points to an already visited node, this node is invalid
        if (currentNode == null || visitedNodes.contains(currentNode.right)) {
            return null;
        }

        // Mark current node as visited
        visitedNodes.add(currentNode);

        // Process right subtree first (reverse of the usual left-first DFS)
        currentNode.right = dfs(currentNode.right);

        // Process left subtree
        currentNode.left = dfs(currentNode.left);

        // Return the processed node
        return currentNode;
    }
}
```

```cpp
#include <functional>
#include <unordered_set>
using namespace std;

/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode() : val(0), left(nullptr), right(nullptr) {}
 *     TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
 *     TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
 * };
 */
class Solution {
public:
    TreeNode* correctBinaryTree(TreeNode* root) {
        // Set to track visited nodes during traversal
        unordered_set<TreeNode*> visitedNodes;

        // Lambda function for depth-first search traversal
        // Returns the corrected subtree or nullptr if the node should be removed
        function<TreeNode*(TreeNode*)> dfs = [&](TreeNode* currentNode) -> TreeNode* {
            // Base case: if node is null or its right child points to an already visited node
            // (indicating an incorrect reference), return nullptr to remove this node
            if (!currentNode || visitedNodes.count(currentNode->right)) {
                return nullptr;
            }

            // Mark current node as visited
            visitedNodes.insert(currentNode);

            // Recursively process right subtree first (right-first preorder traversal)
            // This ensures we visit nodes from right to left at each level
            currentNode->right = dfs(currentNode->right);

            // Then process left subtree
            currentNode->left = dfs(currentNode->left);

            // Return the current node after processing its children
            return currentNode;
        };

        // Start DFS traversal from root and return the corrected tree
        return dfs(root);
    }
};
```

```typescript
/**
 * Definition for a binary tree node.
 */
interface TreeNode {
    val: number;
    left: TreeNode | null;
    right: TreeNode | null;
}

/**
 * Corrects a binary tree by removing the invalid node.
 * An invalid node is one whose right child points to a node that has already been visited.
 *
 * @param root - The root of the binary tree
 * @returns The root of the corrected binary tree
 */
function correctBinaryTree(root: TreeNode | null): TreeNode | null {
    // Set to track visited nodes during traversal
    const visitedNodes = new Set<TreeNode>();

    /**
     * Performs depth-first search to identify and remove the invalid node.
     * Traverses the tree from right to left to detect invalid references.
     *
     * @param node - Current node being processed
     * @returns The node if valid, null if it should be removed
     */
    const dfs = (node: TreeNode | null): TreeNode | null => {
        // Base case: node is null or node's right child has already been visited (invalid)
        if (!node || (node.right !== null && visitedNodes.has(node.right))) {
            return null;
        }

        // Mark current node as visited
        visitedNodes.add(node);

        // Process right subtree first (right-to-left traversal)
        node.right = dfs(node.right);

        // Process left subtree
        node.left = dfs(node.left);

        return node;
    };

    return dfs(root);
}
```

Solution Approach

The implementation uses a recursive DFS function with a visited set to track nodes we've already seen during traversal. Let's walk through the key components:

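Data Structure: We use a set to store references to nodes we've visited. It is named visited_nodes in the Python code above and abbreviated to vis in the rest of this article. This allows O(1) average lookup time to check if a node has been seen before.

The DFS Function: The recursive dfs function handles each subtree. Below, root denotes the node currently being processed (node in the Python code above):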

  1. Base Cases:

    • If root is None: We've reached a leaf's child, return None
    • If root.right in vis: We've found the defective node! Its right child points to an already-visited node, so we return None to remove this entire subtree
  2. Mark as Visited: Add the current node to the vis set. This must happen before processing children to ensure we can detect invalid pointers.

  3. Recursive Processing:

    root.right = dfs(root.right)
    root.left = dfs(root.left)

    Notice the order - we process the right child first, then the left. This ordering is crucial because:

    • The invalid pointer points to a node on the same level to the right
    • By processing right subtrees first, we are guaranteed to visit the target node through the correct path before encountering the invalid pointer
    • When we later reach the defective node, its right child will already be in vis
  4. Return the Node: If the node passes all checks, we return it unchanged, preserving the tree structure.

Why This Works: The algorithm exploits the fact that in a proper tree, each node should be reachable through exactly one path from the root. The defective node breaks this rule by creating a second path to some node. By tracking visited nodes, we can detect when a child pointer leads to an already-discovered node - this can only happen at the defective node.

The beauty of this approach is that it combines detection and removal in a single traversal. When we detect the invalid node (by finding its right child in vis), we immediately return None, which effectively removes that node and its entire subtree from the tree structure. The parent of the defective node will receive None and update its child pointer accordingly, seamlessly removing the problematic branch.

Example Walkthrough

Let's walk through a small example to illustrate how the solution identifies and removes the defective node.

Consider this binary tree where node 3 is defective (its right child incorrectly points to node 5):

        1
       / \
      2   4
     /   / \
    3   5   6
     \__|
   (invalid pointer: 3's right child points to 5)

Initial State:

  • vis = {} (empty set)
  • Start DFS from root (node 1)

Step 1 - Process Node 1:

  • Check: Is node 1 null? No
  • Check: Is node 1's right child (4) in vis? No
  • Add node 1 to vis: vis = {1}
  • Recursively process right child (node 4) first

Step 2 - Process Node 4:

  • Check: Is node 4 null? No
  • Check: Is node 4's right child (6) in vis? No
  • Add node 4 to vis: vis = {1, 4}
  • Recursively process right child (node 6)

Step 3 - Process Node 6:

  • Check: Is node 6 null? No
  • Check: Is node 6's right child (null) in vis? No
  • Add node 6 to vis: vis = {1, 4, 6}
  • Process children (both null), return node 6

Step 4 - Back to Node 4, Process Left Child (5):

  • Check: Is node 5 null? No
  • Check: Is node 5's right child (null) in vis? No
  • Add node 5 to vis: vis = {1, 4, 5, 6}
  • Process children (both null), return node 5
  • Node 4 processing complete, return node 4

Step 5 - Back to Node 1, Process Left Child (2):

  • Check: Is node 2 null? No
  • Check: Is node 2's right child (null) in vis? No
  • Add node 2 to vis: vis = {1, 2, 4, 5, 6}
  • Process right child (null), then left child (node 3)

Step 6 - Process Node 3 (The Defective Node!):

  • Check: Is node 3 null? No
  • Check: Is node 3's right child (5) in vis? YES! Node 5 is already in vis
  • This is the defective node! Return null to remove it

Final Result: Node 2's left child becomes null, effectively removing node 3 from the tree:

        1
       / \
      2   4
         / \
        5   6

The key moment was in Step 6: when we checked node 3's right child, we found it was already visited. This could only happen if the pointer was invalid, confirming node 3 as defective. By returning null, we removed node 3 and its entire subtree from the final tree structure.
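The walkthrough can be reproduced with a lightly instrumented version of the Python solution. This is a sketch for illustration: `correct_binary_tree_traced`, `visit_order`, and `removed` are names introduced here, and the null check is split from the visited check only so the removed node can be recorded.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def correct_binary_tree_traced(root):
    """Same right-first DFS as the solution, plus a record of what happened."""
    visited = set()
    visit_order = []  # values in the order nodes are marked visited
    removed = []      # values of nodes dropped by the visited check

    def dfs(node):
        if node is None:
            return None
        if node.right in visited:
            removed.append(node.val)
            return None
        visited.add(node)
        visit_order.append(node.val)
        node.right = dfs(node.right)
        node.left = dfs(node.left)
        return node

    return dfs(root), visit_order, removed


# Build the example tree: node 3's right pointer wrongly targets node 5.
n5 = TreeNode(5)
n3 = TreeNode(3, right=n5)
root = TreeNode(1, TreeNode(2, n3), TreeNode(4, n5, TreeNode(6)))

fixed, visit_order, removed = correct_binary_tree_traced(root)
print(visit_order)  # [1, 4, 6, 5, 2]
print(removed)      # [3]
```

The visit order matches Steps 1 to 5, and node 3 is the only node removed. Node 5 stays in the tree because it is still attached to its real parent, node 4.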

Time and Space Complexity

Time Complexity: O(n), where n is the number of nodes in the binary tree.

The algorithm performs a depth-first search (DFS) traversal of the tree. In the worst case, it visits every node exactly once before finding and removing the incorrect node. Each node is processed in constant time O(1) - checking if root.right is in the visited set, adding the current node to the visited set, and making recursive calls. Therefore, the total time complexity is O(n).

Space Complexity: O(n), where n is the number of nodes in the binary tree.

The space complexity consists of two components:

  1. Visited Set (vis): In the worst case, the algorithm stores all nodes in the visited set before finding the incorrect node, requiring O(n) space.
  2. Recursion Call Stack: The DFS traversal uses the call stack for recursion. In the worst case (a skewed tree), the maximum depth of recursion can be O(n), requiring O(n) space for the call stack.

Since both components can require O(n) space in the worst case, the overall space complexity is O(n).
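The O(n) call stack has a practical consequence in Python, where the default recursion limit is typically 1000 frames: a deep, skewed tree can raise RecursionError in the recursive version. One possible workaround, shown here as a sketch and not as the article's solution, is the same right-first preorder with an explicit stack. The helper names `correct_binary_tree_iterative` and `build_left_chain` are hypothetical.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def correct_binary_tree_iterative(root):
    """Right-first preorder with an explicit stack instead of recursion."""
    if root is None:
        return None
    visited = set()
    # Each entry: (node, parent, node_is_right_child_of_parent)
    stack = [(root, None, False)]
    while stack:
        node, parent, is_right = stack.pop()
        if node.right in visited:
            # Defective node found: detach it from its parent and stop.
            # The root has no node to its right, so parent is not None here.
            if is_right:
                parent.right = None
            else:
                parent.left = None
            return root
        visited.add(node)
        # Push left first so that the right child is popped (visited) first.
        if node.left is not None:
            stack.append((node.left, node, False))
        if node.right is not None:
            stack.append((node.right, node, True))
    return root


def build_left_chain(length):
    """Return `length` nodes, each one the left child of the previous node."""
    nodes = [TreeNode(i) for i in range(length)]
    for upper, lower in zip(nodes, nodes[1:]):
        upper.left = lower
    return nodes


# Two left-leaning chains of 5000 nodes hang off the root. The deepest node of
# the left chain wrongly points at the deepest node of the right chain.
depth = 5000
left_chain = build_left_chain(depth)
right_chain = build_left_chain(depth)
root = TreeNode(-1, left_chain[0], right_chain[0])
left_chain[-1].right = right_chain[-1]

fixed = correct_binary_tree_iterative(root)
```

Because exactly one node is defective, the loop can stop as soon as it detaches that node. Time and space remain O(n); only the stack moves from the interpreter to a Python list.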

Common Pitfalls

1. Processing Children in Wrong Order

One of the most critical pitfalls is processing the left child before the right child. With that order the detection check never fires at the defective node.

Why it's a problem: The defective node's incorrect pointer targets a node at the same depth to its right. A left-first traversal reaches the defective node before that target, so the target is not yet in the visited set and the check passes. The traversal then follows the invalid pointer and marks the target as visited, so the defective node survives, and the target's real parent can later be removed by mistake if the target is its right child.

Incorrect approach:

```python
# WRONG: Processing left before right
node.left = dfs(node.left)    # Processing left first
node.right = dfs(node.right)  # Processing right second
```

Correct approach:

```python
# CORRECT: Process right before left
node.right = dfs(node.right)  # Process right first
node.left = dfs(node.left)    # Process left second
```
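To see the failure concretely, the sketch below runs a deliberately wrong left-first variant on the tree from the Example Walkthrough (node 3's right pointer targets node 5). The function name `correct_left_first` is hypothetical and exists only for this demonstration.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def correct_left_first(root):
    """Deliberately WRONG: same check, but the left subtree is processed first."""
    visited = set()

    def dfs(node):
        if node is None or node.right in visited:
            return None
        visited.add(node)
        node.left = dfs(node.left)    # left first (the mistake)
        node.right = dfs(node.right)
        return node

    return dfs(root)


n5 = TreeNode(5)
n3 = TreeNode(3, right=n5)  # defective node
root = TreeNode(1, TreeNode(2, n3), TreeNode(4, n5, TreeNode(6)))

result = correct_left_first(root)
defect_survives = result.left.left is n3 and n3.right is n5
print(defect_survives)  # True
```

Node 3 is reached before node 5 has been visited, so the check passes and the invalid pointer stays in the returned tree.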

2. Checking Visited Status After Adding to Set

A related ordering question is whether to add the current node to the visited set before or after checking its right child.

Why it matters less than it may seem: A node's right child is never the node itself, so adding the node first does not change the result of node.right in visited_nodes. Checking first is still the better habit: it makes the intent clear and avoids recording a node that is about to be removed.

Works, but less clear:

```python
def dfs(node):
    if node is None:
        return None

    visited_nodes.add(node)  # Adding to visited first

    # The check runs after marking the current node as visited
    if node.right in visited_nodes:
        return None
```

Preferred approach:

```python
def dfs(node):
    # Check the right child BEFORE adding current node to visited
    if node is None or node.right in visited_nodes:
        return None

    visited_nodes.add(node)  # Add to visited AFTER the check
```

3. Using Node Values Instead of Node References

A subtle but important pitfall is using node values to track visited nodes instead of node references.

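Why it's a problem: In general, multiple nodes in a binary tree can share a value. Unless the problem guarantees unique values, tracking values produces false positives whenever a node's legitimate right child has the same value as an already-visited node. Tracking references works in either case.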

Incorrect approach:

```python
visited_nodes = set()

def dfs(node):
    if node is None or (node.right and node.right.val in visited_nodes):
        return None

    visited_nodes.add(node.val)  # Storing values, not references
```

Correct approach:

```python
visited_nodes = set()

def dfs(node):
    if node is None or node.right in visited_nodes:
        return None

    visited_nodes.add(node)  # Store the actual node reference
```
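The difference shows up on a tree with duplicate values. The sketch below assumes such a tree is possible (an assumption for illustration; check the constraints of the problem you are solving) and compares value tracking with reference tracking. The names `prune_by_value`, `prune_by_reference`, and `build_tree` are hypothetical.

```python
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def prune_by_value(root):
    """Tracks node VALUES; unreliable when values repeat."""
    seen_vals = set()

    def dfs(node):
        if node is None or (node.right and node.right.val in seen_vals):
            return None
        seen_vals.add(node.val)
        node.right = dfs(node.right)
        node.left = dfs(node.left)
        return node

    return dfs(root)


def prune_by_reference(root):
    """Tracks node REFERENCES, as in the solution above."""
    visited = set()

    def dfs(node):
        if node is None or node.right in visited:
            return None
        visited.add(node)
        node.right = dfs(node.right)
        node.left = dfs(node.left)
        return node

    return dfs(root)


def build_tree():
    # A valid tree with two distinct nodes that share the value 7:
    #       1
    #      / \
    #     2   7
    #      \
    #       7
    return TreeNode(1, TreeNode(2, right=TreeNode(7)), TreeNode(7))


by_value = prune_by_value(build_tree())
by_reference = prune_by_reference(build_tree())
print(by_value.left)          # None: node 2 was wrongly removed
print(by_reference.left.val)  # 2
```

With value tracking, node 2 is discarded because its legitimate right child shares a value with a node visited earlier. Reference tracking leaves the valid tree untouched.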

4. Forgetting to Update Parent's Child Pointers

While the provided solution handles this correctly, a common implementation mistake is detecting the defective node but forgetting to actually remove it from the tree structure.

Why it's a problem: Simply identifying the defective node isn't enough; we need to update its parent's pointer to exclude the defective subtree.

Incorrect approach:

```python
def dfs(node):
    if node is None:
        return None

    if node.right in visited_nodes:
        # Just returning without updating parent's pointer
        return

    visited_nodes.add(node)
    dfs(node.right)  # Not capturing return value
    dfs(node.left)   # Not capturing return value
    return node
```

Correct approach:

```python
def dfs(node):
    if node is None or node.right in visited_nodes:
        return None  # Return None to signal removal

    visited_nodes.add(node)
    node.right = dfs(node.right)  # Update child pointers
    node.left = dfs(node.left)    # Update child pointers
    return node
```

These pitfalls highlight the importance of understanding both the tree traversal order and the mechanism of tree modification through recursive pointer updates.
