Trim Dead Nodes from Binary Tree
You are given the reference to the root of a binary tree and are asked to trim the tree of "dead" nodes. A dead node is a node whose value is listed in the provided dead array. Once the tree has been trimmed of all dead nodes, return a list containing references to the roots of all the remaining segments of the tree.
Note: When a dead node is removed, its children become roots of new segments (if they are not dead themselves).
Example(s)
Example 1:
Input:
      3
    /   \
   1     7
  / \   / \
 2   8 4   6
dead = [7, 8]
Output: [3, 4, 6]
Explanation:
- Node 7 is dead, so it's removed. Its children (4 and 6) become new roots.
- Node 8 is dead, so it's removed. It has no children.
- The original root 3 remains (it's not dead).
- Result: [3, 4, 6]
Example 2:
Input:
      5
    /   \
   3     2
  / \   / \
 1   4 6   7
dead = [3, 2]
Output: [5, 1, 4, 6, 7]
Explanation:
- Node 3 is dead, so its children (1 and 4) become new roots.
- Node 2 is dead, so its children (6 and 7) become new roots.
- The original root 5 remains.
- Result: [5, 1, 4, 6, 7]
Example 3:
Input:
   1
  / \
 2   3
dead = [1]
Output: [2, 3]
Explanation:
- Root 1 is dead, so it's removed.
- Its children (2 and 3) become new roots.
- Result: [2, 3]
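The test cases below label each example tree by its level-order values (for example [3, 1, 7, 2, 8, 4, 6]). As a convenience for building such trees, here is a minimal sketch of a level-order builder. The helper names build_tree and level_order_values are hypothetical, and the convention that None marks a missing child is an assumption (the three example trees are complete, so they need no None entries).

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(level_order: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from level-order values; None marks a missing child (assumed convention)."""
    if not level_order or level_order[0] is None:
        return None
    root = TreeNode(level_order[0])
    queue = deque([root])
    i = 1
    while queue and i < len(level_order):
        node = queue.popleft()
        # The next value is this node's left child, the one after it the right child.
        if level_order[i] is not None:
            node.left = TreeNode(level_order[i])
            queue.append(node.left)
        i += 1
        if i < len(level_order) and level_order[i] is not None:
            node.right = TreeNode(level_order[i])
            queue.append(node.right)
        i += 1
    return root


def level_order_values(root: Optional[TreeNode]) -> List[int]:
    """Read values back breadth-first, skipping missing children."""
    values: List[int] = []
    queue = deque([root] if root else [])
    while queue:
        node = queue.popleft()
        values.append(node.val)
        for child in (node.left, node.right):
            if child:
                queue.append(child)
    return values


tree1 = build_tree([3, 1, 7, 2, 8, 4, 6])
print(level_order_values(tree1))  # [3, 1, 7, 2, 8, 4, 6]
print(tree1.right.left.val)  # 4
```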
Solution
The solution uses post-order traversal with a result collector:
- Post-order traversal: Process children before parent (left → right → root)
- Check if node is dead: If dead, add its non-dead children to result and return null
- Recursively process children: Update children by removing dead nodes
- Collect roots: When a dead node is found, its children become new roots
- Handle original root: If original root is not dead, add it to result
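The same steps can also be expressed as a single pre-order pass that tells each node whether its parent was removed. This is a swapped-in formulation of the idea, not the solution given below, and the name trim_dead_nodes_flag is hypothetical. Because a node is recorded before its subtrees are explored, roots come out in pre-order, which happens to match the order of the expected outputs in the examples above.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_dead_nodes_flag(root: Optional[TreeNode], dead: List[int]) -> List[TreeNode]:
    dead_set = set(dead)
    result: List[TreeNode] = []

    def dfs(node: Optional[TreeNode], parent_removed: bool) -> Optional[TreeNode]:
        if node is None:
            return None
        is_dead = node.val in dead_set
        # A surviving node starts a new segment when it has no surviving parent.
        if parent_removed and not is_dead:
            result.append(node)
        # Children of a dead node lose their parent, so they may start segments.
        node.left = dfs(node.left, is_dead)
        node.right = dfs(node.right, is_dead)
        # Returning None detaches a dead node from its parent.
        return None if is_dead else node

    # The original root has no parent, so treat it as if its parent was removed.
    dfs(root, True)
    return result


example1 = TreeNode(3, TreeNode(1, TreeNode(2), TreeNode(8)),
                    TreeNode(7, TreeNode(4), TreeNode(6)))
print([node.val for node in trim_dead_nodes_flag(example1, [7, 8])])  # [3, 4, 6]
```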
JavaScript Solution
/**
 * Definition for a binary tree node.
 * function TreeNode(val, left, right) {
 *   this.val = (val===undefined ? 0 : val)
 *   this.left = (left===undefined ? null : left)
 *   this.right = (right===undefined ? null : right)
 * }
 */
/**
 * Trim dead nodes from binary tree and return roots of remaining segments
 * @param {TreeNode} root - Root of the binary tree
 * @param {number[]} dead - Array of dead node values
 * @return {TreeNode[]} - Array of root references for remaining segments
 */
function trimDeadNodes(root, dead) {
  const deadSet = new Set(dead); // For O(1) lookup
  const result = [];
  /**
   * Recursively process the tree and collect roots
   * @param {TreeNode} node - Current node
   * @return {TreeNode|null} - Modified node or null if dead
   */
  function dfs(node) {
    if (!node) {
      return null;
    }
    // Check if current node is dead
    if (deadSet.has(node.val)) {
      // Node is dead, so its children become potential new roots
      // Process children first (they might also be dead)
      const leftChild = dfs(node.left);
      const rightChild = dfs(node.right);
      // Add non-dead children to result as new roots
      if (leftChild && !deadSet.has(leftChild.val)) {
        result.push(leftChild);
      }
      if (rightChild && !deadSet.has(rightChild.val)) {
        result.push(rightChild);
      }
      // Return null to remove this dead node
      return null;
    }
    // Node is not dead, recursively process children
    node.left = dfs(node.left);
    node.right = dfs(node.right);
    return node;
  }
  // Process the tree
  const processedRoot = dfs(root);
  // If original root is not dead, add it to result
  if (processedRoot && !deadSet.has(processedRoot.val)) {
    result.push(processedRoot);
  }
  return result;
}
// Helper function to create a tree node
function TreeNode(val, left, right) {
  this.val = (val===undefined ? 0 : val);
  this.left = (left===undefined ? null : left);
  this.right = (right===undefined ? null : right);
}
// Helper function to print tree values
function getValues(nodes) {
  return nodes.map(node => node.val);
}
// Test case 1: [3, 1, 7, 2, 8, 4, 6], dead = [7, 8]
const tree1 = new TreeNode(3);
tree1.left = new TreeNode(1);
tree1.right = new TreeNode(7);
tree1.left.left = new TreeNode(2);
tree1.left.right = new TreeNode(8);
tree1.right.left = new TreeNode(4);
tree1.right.right = new TreeNode(6);
const result1 = trimDeadNodes(tree1, [7, 8]);
console.log('Example 1:', getValues(result1)); // [4, 6, 3] (same roots as [3, 4, 6]; the original root is pushed last)
// Test case 2: [5, 3, 2, 1, 4, 6, 7], dead = [3, 2]
const tree2 = new TreeNode(5);
tree2.left = new TreeNode(3);
tree2.right = new TreeNode(2);
tree2.left.left = new TreeNode(1);
tree2.left.right = new TreeNode(4);
tree2.right.left = new TreeNode(6);
tree2.right.right = new TreeNode(7);
const result2 = trimDeadNodes(tree2, [3, 2]);
console.log('Example 2:', getValues(result2)); // [1, 4, 6, 7, 5] (same roots as [5, 1, 4, 6, 7]; the original root is pushed last)
// Test case 3: [1, 2, 3], dead = [1]
const tree3 = new TreeNode(1);
tree3.left = new TreeNode(2);
tree3.right = new TreeNode(3);
const result3 = trimDeadNodes(tree3, [1]);
console.log('Example 3:', getValues(result3)); // [2, 3]
Python Solution
from typing import List, Optional
# Definition for a binary tree node.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
def trim_dead_nodes(root: Optional[TreeNode], dead: List[int]) -> List[TreeNode]:
    """
    Trim dead nodes from binary tree and return roots of remaining segments
    Args:
        root: Root of the binary tree
        dead: List of dead node values
    Returns:
        List[TreeNode]: List of root references for remaining segments
    """
    dead_set = set(dead)  # For O(1) lookup
    result = []
    def dfs(node: Optional[TreeNode]) -> Optional[TreeNode]:
        """
        Recursively process the tree and collect roots
        Returns:
            TreeNode or None: Modified node or None if dead
        """
        if not node:
            return None
        # Check if current node is dead
        if node.val in dead_set:
            # Node is dead, so its children become potential new roots
            # Process children first (they might also be dead)
            left_child = dfs(node.left)
            right_child = dfs(node.right)
            # Add non-dead children to result as new roots
            if left_child and left_child.val not in dead_set:
                result.append(left_child)
            if right_child and right_child.val not in dead_set:
                result.append(right_child)
            # Return None to remove this dead node
            return None
        # Node is not dead, recursively process children
        node.left = dfs(node.left)
        node.right = dfs(node.right)
        return node
    # Process the tree
    processed_root = dfs(root)
    # If original root is not dead, add it to result
    if processed_root and processed_root.val not in dead_set:
        result.append(processed_root)
    return result
# Helper function to get values from nodes
def get_values(nodes: List[TreeNode]) -> List[int]:
    return [node.val for node in nodes]
# Test case 1: [3, 1, 7, 2, 8, 4, 6], dead = [7, 8]
tree1 = TreeNode(3)
tree1.left = TreeNode(1)
tree1.right = TreeNode(7)
tree1.left.left = TreeNode(2)
tree1.left.right = TreeNode(8)
tree1.right.left = TreeNode(4)
tree1.right.right = TreeNode(6)
result1 = trim_dead_nodes(tree1, [7, 8])
print('Example 1:', get_values(result1))  # [4, 6, 3] (same roots as [3, 4, 6]; the original root is appended last)
# Test case 2: [5, 3, 2, 1, 4, 6, 7], dead = [3, 2]
tree2 = TreeNode(5)
tree2.left = TreeNode(3)
tree2.right = TreeNode(2)
tree2.left.left = TreeNode(1)
tree2.left.right = TreeNode(4)
tree2.right.left = TreeNode(6)
tree2.right.right = TreeNode(7)
result2 = trim_dead_nodes(tree2, [3, 2])
print('Example 2:', get_values(result2))  # [1, 4, 6, 7, 5] (same roots as [5, 1, 4, 6, 7]; the original root is appended last)
# Test case 3: [1, 2, 3], dead = [1]
tree3 = TreeNode(1)
tree3.left = TreeNode(2)
tree3.right = TreeNode(3)
result3 = trim_dead_nodes(tree3, [1])
print('Example 3:', get_values(result3))  # [2, 3]
Alternative Solution (Two-Pass)
Here's an alternative approach that first collects all roots, then trims:
JavaScript Alternative
/**
 * Alternative: Two-pass approach
 * First pass: identify and collect roots
 * Second pass: trim the tree
 */
function trimDeadNodesTwoPass(root, dead) {
  const deadSet = new Set(dead);
  const roots = [];
  // First pass: collect all roots (original + children of dead nodes)
  function collectRoots(node, isRoot) {
    if (!node) return;
    if (isRoot && !deadSet.has(node.val)) {
      roots.push(node);
    }
    const isDead = deadSet.has(node.val);
    collectRoots(node.left, isDead);
    collectRoots(node.right, isDead);
  }
  collectRoots(root, true);
  // Second pass: trim dead nodes
  function trim(node) {
    if (!node) return null;
    if (deadSet.has(node.val)) return null;
    node.left = trim(node.left);
    node.right = trim(node.right);
    return node;
  }
  // Trim each root
  return roots.map(trim).filter(Boolean);
}
Python Alternative
def trim_dead_nodes_two_pass(root: Optional[TreeNode], dead: List[int]) -> List[TreeNode]:
    """
    Alternative: Two-pass approach
    """
    dead_set = set(dead)
    roots = []
    # First pass: collect all roots
    def collect_roots(node: Optional[TreeNode], is_root: bool):
        if not node:
            return
        if is_root and node.val not in dead_set:
            roots.append(node)
        is_dead = node.val in dead_set
        collect_roots(node.left, is_dead)
        collect_roots(node.right, is_dead)
    collect_roots(root, True)
    # Second pass: trim dead nodes
    def trim(node: Optional[TreeNode]) -> Optional[TreeNode]:
        if not node:
            return None
        if node.val in dead_set:
            return None
        node.left = trim(node.left)
        node.right = trim(node.right)
        return node
    # Trim each root exactly once, then drop any None results
    trimmed = [trim(r) for r in roots]
    return [t for t in trimmed if t]
Complexity
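- Time Complexity: O(n + m) - Where n is the number of nodes and m is the length of the dead array. Building the set takes O(m), and the traversal visits each node exactly once.
- Space Complexity: O(h + m) auxiliary - Where h is the height of the tree (recursion stack) and m is the number of dead values (for the set). h is O(n) for a skewed tree. The returned list can additionally hold up to O(n) node references.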
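The O(h) recursion stack matters in practice for Python: CPython's default recursion limit is 1000 (see sys.getrecursionlimit()), so a skewed tree deeper than roughly that raises RecursionError in the recursive versions. Below is a minimal sketch, assuming the same TreeNode shape, that replaces the call stack with an explicit stack and the pre-order "parent removed" flag; the name trim_dead_nodes_iterative is hypothetical. The explicit stack still holds O(h) entries, but on the heap.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_dead_nodes_iterative(root: Optional[TreeNode], dead: List[int]) -> List[TreeNode]:
    dead_set = set(dead)
    result: List[TreeNode] = []
    # Each entry: (node, parent_removed, parent, is_left_child)
    stack: List[Tuple[TreeNode, bool, Optional[TreeNode], bool]] = []
    if root is not None:
        stack.append((root, True, None, False))
    while stack:
        node, parent_removed, parent, is_left = stack.pop()
        is_dead = node.val in dead_set
        if is_dead:
            # Cut the link from a surviving parent to this dead node.
            if parent is not None and not parent_removed:
                if is_left:
                    parent.left = None
                else:
                    parent.right = None
        elif parent_removed:
            # Surviving node with no surviving parent: it roots a segment.
            result.append(node)
        # Push right first so the left subtree is handled first (pre-order).
        if node.right is not None:
            stack.append((node.right, is_dead, node, False))
        if node.left is not None:
            stack.append((node.left, is_dead, node, True))
    return result


# Usage: a 5000-node left-leaning chain, deeper than the default recursion limit
chain = [TreeNode(i) for i in range(5000)]
for parent, child in zip(chain, chain[1:]):
    parent.left = child
print([n.val for n in trim_dead_nodes_iterative(chain[0], [2500])])  # [0, 2501]
```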
Approach
The solution uses post-order traversal with root collection:
- Convert dead array to set: For O(1) lookup time
- Post-order traversal: Process children before parent
- Dead node handling:
- When a dead node is found, recursively process its children
- Add non-dead children to the result list (they become new roots)
- Return null to remove the dead node
- Non-dead node handling:
- Recursively process and update children
- Return the node (possibly with modified children)
- Original root handling: After processing, if the original root is not dead, add it to result
Key Insights
- Post-order traversal ensures children are processed before parent
- Children of dead nodes become roots: This is the key insight - when a node is removed, its children become independent trees
- Set for dead values: O(1) lookup instead of O(m) array search
- In-place modification: We modify the tree structure in place
- Root collection: Roots created by removing a dead node are collected as soon as that dead node is processed; only the original root is checked after the traversal ends
Step-by-Step Example
Let's trace through Example 1: Tree with nodes [3, 1, 7, 2, 8, 4, 6], dead = [7, 8]
Initial tree:
      3
    /   \
   1     7
  / \   / \
 2   8 4   6
Post-order traversal: 2 → 8 → 1 → 4 → 6 → 7 → 3
Step 1: Process node 2 (leaf, not dead)
- Return node 2
Step 2: Process node 8 (leaf, dead)
- Add children to result: none (leaf node)
- Return null
Step 3: Process node 1 (not dead)
- left = 2, right = null (8 was removed)
- Return node 1
Step 4: Process node 4 (leaf, not dead)
- Return node 4
Step 5: Process node 6 (leaf, not dead)
- Return node 6
Step 6: Process node 7 (dead)
- left = 4, right = 6 (both not dead)
- Add 4 and 6 to result
- Return null
Step 7: Process node 3 (root, not dead)
- left = 1, right = null (7 was removed)
- Return node 3
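Final: Add root 3 to result (it is appended after 4 and 6, which were collected in Step 6)
Result: [4, 6, 3], the same roots as the expected [3, 4, 6], in the order this implementation finds them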
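The trace above can be checked by logging the moment each node finishes. This sketch restructures the solution slightly (both branches recurse before inspecting the node, which gives the same result) and adds a log; the name trim_with_trace is hypothetical.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_with_trace(root: Optional[TreeNode], dead: List[int]) -> Tuple[List[int], List[int]]:
    """Trim like trim_dead_nodes, but also log the order in which nodes finish."""
    dead_set = set(dead)
    result: List[TreeNode] = []
    finish_order: List[int] = []

    def dfs(node: Optional[TreeNode]) -> Optional[TreeNode]:
        if node is None:
            return None
        left = dfs(node.left)
        right = dfs(node.right)
        # Both children are done before the node itself: post-order.
        finish_order.append(node.val)
        if node.val in dead_set:
            # Surviving children of a dead node become new roots.
            result.extend(child for child in (left, right) if child is not None)
            return None
        node.left, node.right = left, right
        return node

    processed = dfs(root)
    if processed is not None:
        result.append(processed)
    return finish_order, [node.val for node in result]


tree = TreeNode(3, TreeNode(1, TreeNode(2), TreeNode(8)),
                TreeNode(7, TreeNode(4), TreeNode(6)))
print(trim_with_trace(tree, [7, 8]))  # ([2, 8, 1, 4, 6, 7, 3], [4, 6, 3])
```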
Edge Cases
- Root is dead: Children become new roots
- All nodes are dead: Return empty array
- No dead nodes: Return array with original root
- Dead node with dead children: The dead children are removed too, and the nearest non-dead descendants (for example, grandchildren) become roots
- Single node tree: If dead, return empty; if not, return [root]
- Empty tree: Return empty array
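The edge cases above can be checked with a handful of assertions. This block repeats a compact copy of the Python solution so it runs on its own; the helper root_values is hypothetical and sorts the returned values so the checks do not depend on discovery order.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_dead_nodes(root: Optional[TreeNode], dead: List[int]) -> List[TreeNode]:
    # Compact copy of the Python solution above, repeated so this block runs alone.
    dead_set = set(dead)
    result: List[TreeNode] = []

    def dfs(node: Optional[TreeNode]) -> Optional[TreeNode]:
        if not node:
            return None
        if node.val in dead_set:
            # Keep whatever survives below a dead node as new roots.
            for child in (dfs(node.left), dfs(node.right)):
                if child:
                    result.append(child)
            return None
        node.left = dfs(node.left)
        node.right = dfs(node.right)
        return node

    processed = dfs(root)
    if processed:
        result.append(processed)
    return result


def root_values(root: Optional[TreeNode], dead: List[int]) -> List[int]:
    # Sorted, so the checks ignore the order in which roots are discovered.
    return sorted(node.val for node in trim_dead_nodes(root, dead))


def small() -> TreeNode:
    return TreeNode(1, TreeNode(2), TreeNode(3))


assert root_values(small(), [1]) == [2, 3]        # root is dead
assert root_values(small(), [1, 2, 3]) == []      # all nodes are dead
assert root_values(small(), []) == [1]            # no dead nodes
# Dead root 1 with dead child 2: grandchildren 4 and 5 become roots, as does 3.
assert root_values(TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3)), [1, 2]) == [3, 4, 5]
assert root_values(TreeNode(7), [7]) == []        # single dead node
assert root_values(TreeNode(7), []) == [7]        # single live node
assert root_values(None, [1]) == []               # empty tree
print("all edge cases pass")
```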
Important Notes
- Node references: The problem asks for references to nodes, not values
- Tree modification: The original tree structure is modified (dead nodes removed)
- Multiple roots: Result can contain multiple root references
- Children become roots: When a dead node is removed, its children become roots of new segments
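The first two notes can be observed directly: the returned items are the original node objects, and links into dead nodes are cut in place. If a caller needs the input tree intact, one option (an assumption, not part of the problem) is to trim a copy made with copy.deepcopy; the returned references then point into the copy rather than the original. The solution is repeated in compact form so the block runs on its own.

```python
import copy
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def trim_dead_nodes(root: Optional[TreeNode], dead: List[int]) -> List[TreeNode]:
    # Compact copy of the Python solution above so this block runs on its own.
    dead_set = set(dead)
    result: List[TreeNode] = []

    def dfs(node: Optional[TreeNode]) -> Optional[TreeNode]:
        if not node:
            return None
        if node.val in dead_set:
            for child in (dfs(node.left), dfs(node.right)):
                if child:
                    result.append(child)
            return None
        node.left = dfs(node.left)
        node.right = dfs(node.right)
        return node

    processed = dfs(root)
    if processed:
        result.append(processed)
    return result


tree = TreeNode(3, TreeNode(1, TreeNode(2), TreeNode(8)),
                TreeNode(7, TreeNode(4), TreeNode(6)))
four = tree.right.left

# The returned items are the original node objects (identity, not equal values)...
roots = trim_dead_nodes(tree, [7, 8])
print(roots[-1] is tree, any(r is four for r in roots))  # True True
# ...and the links to dead nodes were cut in place.
print(tree.right, tree.left.right)  # None None

# Caller-side option: trim a deep copy so the input tree is left untouched.
original = TreeNode(3, TreeNode(1, TreeNode(2), TreeNode(8)),
                    TreeNode(7, TreeNode(4), TreeNode(6)))
working = copy.deepcopy(original)
copy_roots = trim_dead_nodes(working, [7, 8])
print(original.right.val, copy_roots[-1] is working)  # 7 True
```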
Takeaways
- Post-order traversal is useful when children need to be processed before parent
- Collecting roots during the single traversal avoids the second pass that the two-pass alternative needs (both are O(n))
- Using a set for lookups improves time complexity from O(n·m) to O(n + m)
- In-place modification is space-efficient
- Understanding the problem - children of dead nodes become roots - is crucial