Problem link
LeetCode 236. Lowest Common Ancestor of a Binary Tree
Approach
Reference articles:
PAT (Advanced Level) 1151 LCA in a Binary Tree (30 points) — Lowest Common Ancestor
Luogu P3379 [Template] Lowest Common Ancestor (binary-lifting LCA)
Implementation code (Python)
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, x):
#         self.val = x
#         self.left = None
#         self.right = None
# Plain recursive LCA
class Solution:
    """Lowest common ancestor via a single post-order recursion."""

    def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
        """Return the lowest common ancestor of nodes ``p`` and ``q`` under ``root``.

        Nodes are matched by identity, not by ``val``, so trees that contain
        duplicate values are still answered correctly.

        If only one of ``p``/``q`` is present in the tree, that node is
        returned; if neither is present, ``None`` is returned.
        """
        def lca(node):
            # Returns p or q if one is found in this subtree (stopping at the
            # first one met on the way down), the LCA if both sides report a
            # hit, or None.
            if node is None or node is p or node is q:
                return node
            left = lca(node.left)
            right = lca(node.right)
            if left is None:
                return right
            if right is None:
                return left
            # p and q were found in different subtrees: this node joins them.
            return node

        return lca(root)
# Binary-lifting (doubling) LCA
# Definition for a binary tree node.
class TreeNode:
    """Binary tree node holding a value, an integer index slot and two child links.

    ``nId`` is a scratch index that the binary-lifting solution overwrites
    with the node's preorder number.
    """

    def __init__(self, x, nId=0):
        self.val, self.nId = x, nId
        self.left = self.right = None
class Solution:
    """Lowest common ancestor via binary lifting (doubling).

    Each call numbers the tree's nodes, builds a jump table ``up`` where
    ``up[j][i]`` is the ``2**j``-th ancestor of node ``i``, and then lifts
    ``p`` and ``q``. Index 0 is a sentinel meaning "above the root"; every
    jump past the root lands there.
    """

    def lowestCommonAncestor(self, root: 'TreeNode', p: 'TreeNode', q: 'TreeNode') -> 'TreeNode':
        """Return the lowest common ancestor of ``p`` and ``q`` in the tree at ``root``.

        Side effect: every node's ``nId`` attribute is overwritten with its
        1-based preorder index.

        Assumes ``p`` and ``q`` are nodes of this tree. Returns ``None`` for an
        empty tree.
        """
        if root is None:
            return None
        nodes, parent, depth = self._number_nodes(root)
        # levels jumps of size 1, 2, ..., 2**(levels-1) can climb up to
        # 2**levels - 1 >= n steps, which covers any depth in an n-node tree.
        levels = max(1, (len(nodes) - 1).bit_length())
        up = self._build_jump_table(parent, levels)

        a, b = p.nId, q.nId
        if depth[a] > depth[b]:  # make b the deeper (or equally deep) node
            a, b = b, a
        # Lift b by exactly the depth difference, one set bit at a time
        # (exact integer arithmetic instead of float logarithms).
        diff = depth[b] - depth[a]
        j = 0
        while diff:
            if diff & 1:
                b = up[j][b]
            diff >>= 1
            j += 1
        if a == b:  # a is an ancestor of b (or they are the same node)
            return nodes[a]
        # Take the largest jumps that keep a and b apart; afterwards they are
        # distinct children of the LCA.
        for j in reversed(range(levels)):
            if up[j][a] != up[j][b]:
                a, b = up[j][a], up[j][b]
        return nodes[up[0][a]]

    @staticmethod
    def _number_nodes(root):
        """Assign preorder ids 1..n; return (nodes, parent, depth) lists indexed by id."""
        nodes = [None]  # nodes[i] is the node whose nId == i; slot 0 is the sentinel
        parent = [0]
        depth = [0]
        stack = [(root, 0)]
        # Iterative preorder, so deep/skewed trees do not hit the recursion limit.
        while stack:
            node, par = stack.pop()
            node.nId = len(nodes)
            nodes.append(node)
            parent.append(par)
            depth.append(depth[par] + 1 if par else 0)
            # Push right before left so the left subtree is numbered first,
            # matching a recursive preorder walk.
            if node.right:
                stack.append((node.right, node.nId))
            if node.left:
                stack.append((node.left, node.nId))
        return nodes, parent, depth

    @staticmethod
    def _build_jump_table(parent, levels):
        """Return ``up`` with ``up[j][i]`` = the 2**j-th ancestor of ``i`` (0 past the root)."""
        up = [parent]
        for _ in range(1, levels):
            prev = up[-1]
            # The 2**j-th ancestor is the 2**(j-1)-th ancestor of the 2**(j-1)-th ancestor.
            up.append([prev[anc] for anc in prev])
        return up