这道最直接的思路是把二叉树转化成图,然后用BFS就能找出距离为K的节点。这个实际上在原来树的基础上,重新建了一个图。
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode(int x) : val(x), left(NULL), right(NULL) {}
* };
*/
class Solution {
public:
vector<int> distanceK(TreeNode* root, TreeNode* target, int K) {
buildGraph(root);
queue<TreeNode*> q;
vector<int> res;
unordered_set<TreeNode*> visited;
q.push(target);
visited.insert(target);
int count = 0;
while(!q.empty()&&count<=K){
int n = q.size();
while(n--){
auto tmp = q.front();
q.pop();
if(count==K) res.emplace_back(tmp->val);
for(auto neighbor:graph[tmp]){
if(!visited.count(neighbor)){
q.push(neighbor);
visited.insert(neighbor);
}
}
}
count++;
}
return res;
}
void buildGraph(TreeNode* root){
if(root==NULL){
return;
}
if(root->left){
graph[root].emplace_back(root->left);
graph[root->left].emplace_back(root);
}
if(root->right){
graph[root].emplace_back(root->right);
graph[root->right].emplace_back(root);
}
buildGraph(root->left);
buildGraph(root->right);
}
unordered_map<TreeNode*,vector<TreeNode*>> graph;
};
使用树的特点用递归函数比较难想, 思路如图所示,假设我们知道root A的左子树B到Target的距离是l,那么右子树C到T的距离就是l+2,此时到target的距离为K的节点可以分为两类,一类是以Target为根节点下方子树中的,另一部分是上方的,上方的可以通过找到距离当前根节点距离为K-L-2的节点。
/**
* Definition for a binary tree node.
* struct TreeNode {
* int val;
* TreeNode *left;
* TreeNode *right;
* TreeNode(int x) : val(x), left(NULL), right(NULL) {}
* };
*/
class Solution {
public:
vector<int> distanceK(TreeNode* root, TreeNode* target, int K) {
dfs(root,target,K);
return res;
}
vector<int> res;
// 这个函数表示从根节点到target的距离
int dfs(TreeNode* root, TreeNode* target, int K){
if(root==nullptr){
return -1;
}
if(root==target){
collect(target,K); // 从当前节点出发,将距离为K的节点加入答案
return 0;
}
int left = dfs(root->left,target,K);
int right = dfs(root->right,target,K);
if(left>=0){
if(left==K-1) res.push_back(root->val);
collect(root->right,K-left-2);
return left+1;
}
if(right>=0){
if(right==K-1) res.push_back(root->val);
collect(root->left,K-right-2);
return right+1;
}
return -1;
}
void collect(TreeNode* root, int d){
if(root==nullptr||d<0) return;
if(d==0) res.push_back(root->val);
collect(root->left,d-1);
collect(root->right,d-1);
}
};