最近公共祖先

求最近公共祖先的方法

  1. 向上标记法O(n)

  2. 倍增

    fa[i, j]表示从i开始,向上走2^j步所能走到的节点。0 <= j <= logn

    depth[i] 表示深度

    哨兵:如果从i开始向上跳2^j步跳过根节点,fa[i, j] = 0, depth[0] = 0;

    步骤:

    • 先让两个点跳到同一层
    • 让两个点同时往上跳,一直跳到他们的最近公共祖先的下一层。

    预处理O(nlogn)

    查询O(logn)

  3. Tarjan——离线求LCA O(n + m)

    深度优先遍历的时候将所有点分成三大类:

    • 已经遍历过,且回溯的点
    • 正在搜索分支的点
    • 还未搜索到的点

祖孙询问(基础)

题目描述

给定一棵包含 n 个节点的有根无向树,节点编号互不相同,但不一定是 1∼n。

有 m 个询问,每个询问给出了一对节点的编号 x 和 y,询问 x 与 y 的祖孙关系。

输入格式

输入第一行包括一个整数 表示节点个数;

接下来 n 行每行一对整数 a 和 b,表示 a 和 b 之间有一条无向边。如果 b 是 −1,那么 a 就是树的根;

第 n+2 行是一个整数 m 表示询问个数;

接下来 m 行,每行两个不同的正整数 x 和 y,表示一个询问。

输出格式

对于每一个询问,若 x 是 y 的祖先则输出 1,若 y 是 x 的祖先则输出2,否则输出 0。

数据范围

1≤n,m≤4×10^4
1≤每个节点的编号≤4×10^4

输入样例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
10
234 -1
12 234
13 234
14 234
15 234
16 234
17 234
18 234
19 234
233 19
5
234 233
233 12
233 13
233 15
233 19

输出样例:

1
2
3
4
5
1
0
0
0
2

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <bits/stdc++.h>

using namespace std;
const int N = 40010, M = N * 2;
int n, m;
int h[N], e[M], ne[M], idx;
int depth[N], fa[N][16];

void add(int a, int b){
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

void bfs(int root){
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[root] = 1;
queue<int> q;
q.push(root);
while(q.size()){
auto t = q.front(); q.pop();
for(int i = h[t]; ~i; i = ne[i]){
int j = e[i];
if(depth[j] > depth[t] + 1){
depth[j] = depth[t] + 1;
q.push(j);
fa[j][0] = t; // j跳2^0步
for(int k = 1; k <= 15; k ++) {
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}
}

int lca(int a, int b){
if(depth[a] < depth[b]) swap(a, b);
// 将两个点跳到同一深度
for(int k = 15; k >= 0; k--){
if(depth[fa[a][k]] >= depth[b])
a = fa[a][k];
}
if(a == b) return a;
//两个点同时往上跳, 找到最近祖先的下面的深度
for(int k = 15; k >= 0; k--){
if(fa[a][k] != fa[b][k]){
a = fa[a][k];
b = fa[b][k];
}
}
return fa[a][0];
}

int main ()
{
cin >> n;
int root = 0;
memset(h, -1, sizeof h);
for(int i = 0; i < n; i++) {
int a, b;
cin >> a >> b;
if(b == -1) root = a;
else add(a, b), add(b, a);
}
bfs(root);
// m次询问
cin >> m;
while(m--){
int a, b;
cin >> a >> b;
int p = lca(a, b);
if(p == a) cout << "1" << endl;
else if(p == b) cout << "2" << endl;
else cout << "0" << endl;
}
}

距离(提高)

题目描述

给出 n 个点的一棵树,多次询问两点之间的最短距离。

注意:

  • 边是无向的。
  • 所有节点的编号是 1,2,…,n。

输入格式

第一行为两个整数 n 和 m。n 表示点数,m 表示询问次数;

下来 n−1 行,每行三个整数 x,y,k,表示点 x 和点 y 之间存在一条边长度为 k;

再接下来 m 行,每行两个整数 x,y,表示询问点 x 到点 y 的最短距离。

树中结点编号从 1 到 n。

输出格式

共 m 行,对于每次询问,输出一行询问结果。

数据范围

2≤n≤104,
1≤m≤2×104,
0<k≤100,
1≤x,y≤n

输入样例1:

1
2
3
4
2 2 
1 2 100
1 2
2 1

输出样例1:

1
2
100
100

输入样例2:

1
2
3
4
5
3 2
1 2 10
3 1 15
1 2
3 2

输出样例2:

1
2
10
25

思路

深度优先遍历的时候将所有点分成三大类:

  • 已经遍历过,且回溯的点
  • 正在搜索分支的点
  • 还未搜索到的点

深搜计算每个节点到根节点的路径和。回溯之后标记为st[i] = 2, 那么正在路劲上的点就是已经搜索过的点的祖先。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include <bits/stdc++.h>
#define pii pair<int, int>

using namespace std;

const int N = 100010, M = N * 2;

int n, m, h[N], e[M], w[M], ne[M], idx, dist[N], res[M], st[N], p[N];
vector<pii> query[N];


void add(int a, int b, int c){
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}

//计算从当前节点到根节点的路径和
void dfs(int u, int fa){
for(int i = h[u]; ~i; i = ne[i]){
int j = e[i];
if(j == fa) continue; // 当键节点等于父节点就不用管,保证从上到下执行
dist[j] = dist[u] + w[i]; // 当前路径和等于父节点加边权
dfs(j, u);
}
}

int find(int x){
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}

void tarjin(int u){
st[u] = 1;
for(int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if(!st[j]) {
tarjin(j);
p[j] = u;
}
}

for(auto item : query[u]){
auto [y, id] = item;

//已经遍历过且回溯的点
if(st[y] == 2){
int anc = find(y); // 父节点是anc
res[id] = dist[u] + dist[y] - dist[anc] * 2; //将当前结果存到res数组里面
}
}

st[u] = 2;
}

int main (){
cin >> n >> m;

memset(h, -1, sizeof h);

for(int i = 0; i < n - 1; i++){
int a, b, c;
cin >> a >> b >> c;
add(a, b, c);
add(b, a, c);
}

for(int i = 0; i < m; i++) {
int a, b;
cin >> a >> b;
if(a != b) {
query[a].push_back({b, i});
query[b].push_back({a, i});
}
}

for(int i = 1; i <= N; i++) p[i] = i;

dfs(1, -1);
tarjin(1);

for(int i = 0; i < m; i++) cout << res[i] << endl;

return 0;
}