关于点分治的常见写法及时间复杂度-爱代码爱编程
关于点分治的常见写法及时间复杂度
0x01 概述
点分治是一种基于树的重心,统计树上路径的优秀算法。将树上的路径分为经过树的重心和不经过树的重心两种,同时利用树的重心性质,使得递归深度不超过 l o g n logn logn次。
总时间复杂取决于每次递归统计答案的时间复杂度。若每次统计是 O ( n ) O(n) O(n)的,那么总时间复杂度是 O ( n l o g n ) O(nlogn) O(nlogn)。若统计的时间复杂度是 O ( n l o g n ) O(nlogn) O(nlogn)的,那么总时间复杂度为 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)。这两种时间复杂度均为可也接受的。
0x02 写法
点分治有2种写法:
-
对于某个重心 u u u,统计以 u u u为根的所有路径,然后计算出所有组合情况。递归子树时,首先删除全部在一颗子树的路径,然后再进入子树递归求解。这样可以保证路径全部合法且不重不漏。
-
对于某个重心 u u u,先进入子树 v 1 v_1 v1,求解出 u u u到子树 v 1 v_1 v1所有节点的路径,然后进入子树 v 2 v_2 v2,进入时先统计答案,然后再统计相关值……每次进入新的子树时,先统计答案,这样每次计算的路径一定是和之前统计过子树的相连而成的,没有不合法的答案,所以不用删除。最后,依次递归进子树,找出重心递归求解。
图1表示直接统计以重心为根的子树上的路径,递归进入子树之前,首先删除子树中的不合法路径。
图2表示按子树依次递归进入,进入后先统计答案,然后再统计相关值。统计答案是利用到了先前子树的信息。
一般地,写法2常数更小。
0x03 分析
下面来探讨两种写法的时间复杂度。
两种写法首先在“分治”过程中,每次都是中心划分,树的层数(递归深度)是严格 l o g n logn logn的。也就是说总体框架上都是带一个log。
对当前层答案统计、每次问题规模减半,可看作 T ( n ) = T ( n / 2 ) + O ( n ) T(n)=T(n/2)+O(n) T(n)=T(n/2)+O(n)。
问题就出在两种写法在”统计计算”这块的时间。
对于写法1,一般会将所有子树的dis全部排个序或者二分,这样就躲不掉的一个log的时间,总时间复杂度是 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)。
对于写法2,采用的桶来记录前面子树的dis信息,或占用一定量的空间(当然在范围允许的情况下),或用map增加一个log的时间,总时间复杂度是 O ( n l o g n ) O(nlogn) O(nlogn)或 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)。
下给出两份题目的代码,供读者阅读体会:
//写法1实现
#include <bits/stdc++.h>
using namespace std;
inline int read() {
int x = 0; bool f = 0; char ch = getchar();
while (!isdigit(ch)) f |= ch == '-', ch = getchar();
while (isdigit(ch)) x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
return f ? -x : x;
}
const int MAXN = 4e4 + 6;
int n, k, root, cnte, cntd, ans;
int mxp[MAXN], vis[MAXN], sz[MAXN], hd[MAXN], dis[MAXN];
struct tEdge { int to, nxt, wei; } e[MAXN << 1];
inline void link(int u, int v, int w) {
e[++cnte] = {v, hd[u], w};
hd[u] = cnte;
}
void getroot(int u, int p, int n) {
sz[u] = 1; mxp[u] = 0;
for (int i = hd[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (v == p || vis[v]) continue;
getroot(v, u, n);
sz[u] += sz[v];
mxp[u] = max(mxp[u], sz[v]);
}
mxp[u] = max(mxp[u], n - sz[u]);
if (mxp[u] < mxp[root]) root = u;
// mxp[u] < mxp[root] ...YES
// mxp[u] <= n / 2 ...NO
}
void getdis(int u, int d, int p) {
dis[++cntd] = d;
for (int i = hd[u]; i; i = e[i].nxt) {
int v = e[i].to, w = e[i].wei;
if (v == p || vis[v]) continue;
getdis(v, d + w, u);
}
}
int calc(int u, int d) {
cntd = 0; getdis(u, d, 0);
sort(dis + 1, dis + cntd + 1);
int sum = 0;
for (int l = 1, r = cntd; ; ++l) {
while (r && dis[l] + dis[r] > k) --r;
if (r < l) break;
sum += r - l;
}
return sum;
}
void solve(int u) {
ans += calc(u, 0); vis[u] = 1;
for (int i = hd[u]; i; i = e[i].nxt) {
int v = e[i].to, w = e[i].wei;
if (vis[v]) continue;
ans -= calc(v, w);
mxp[root = 0] = INT_MAX;
getroot(v, 0, sz[v]); solve(root);
}
}
int main() {
n = read();
for (int i = 1; i < n; ++i) {
int u = read(), v = read(), w = read();
link(u, v, w); link(v, u, w);
}
k = read();
mxp[root = 0] = INT_MAX;
getroot(1, 0, n); solve(root);
printf("%d\n", ans);
return 0;
}
//写法2实现
#include <bits/stdc++.h>
using namespace std;
inline int read() {
int x = 0; bool f = 0; char ch = getchar();
while (!isdigit(ch)) f |= ch == '-', ch = getchar();
while (isdigit(ch)) x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
return f ? -x : x;
}
const int MAXN = 1e4 + 6, INF = 1e8 + 6;
int n, m, cnte, cntd, root;
int hd[MAXN], sz[MAXN], mxp[MAXN], vis[MAXN];
int q[MAXN], f[MAXN], dic[INF], dis[MAXN], tf[MAXN];
struct tEdge { int to, nxt, wei; } e[MAXN << 1];
inline void link(int u, int v, int w) {
e[++cnte] = {v, hd[u], w};
hd[u] = cnte;
}
void getroot(int u, int p, int n) {
sz[u] = 1, mxp[u] = 0;
for (int i = hd[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (vis[v] || v == p) continue;
getroot(v, u, n);
sz[u] += sz[v];
mxp[u] = max(mxp[u], sz[v]);
}
mxp[u] = max(mxp[u], n - sz[u]);
if (mxp[u] < mxp[root]) root = u;
}
void getdis(int u, int d, int p) {
dis[++cntd] = d;
for (int i = hd[u]; i; i = e[i].nxt) {
int v = e[i].to, w = e[i].wei;
if (v == p || vis[v]) continue;
getdis(v, d + w, u);
}
}
void calc(int u) {
int tot = 0;
for (int i = hd[u]; i; i = e[i].nxt) {
int v = e[i].to, w = e[i].wei;
if (vis[v]) continue;
cntd = 0; getdis(v, w, u);
for (int j = cntd; j; --j)
for (int k = 1; k <= m; ++k) {
if (q[k] < dis[j]) continue;
f[k] |= dic[q[k] - dis[j]];
}
for (int j = cntd; j; --j)
tf[++tot] = dis[j], dic[dis[j]] = 1;
}
for (int i = 1; i <= tot; ++i)
dic[tf[i]] = 0;
}
void solve(int u) {
vis[u] = dic[0] = 1; calc(u);
for (int i = hd[u]; i; i = e[i].nxt) {
int v = e[i].to;
if (vis[v]) continue;
mxp[root = 0] = INT_MAX;
getroot(v, 0, sz[v]); solve(root);
}
}
int main() {
n = read(), m = read();
for (int i = 1; i < n; ++i) {
int u = read(), v = read(), d = read();
link(u, v, d); link(v, u, d);
}
for (int i = 1; i <= m; ++i)
q[i] = read();
mxp[root = 0] = INT_MAX;
getroot(1, 0, n); solve(root);
for (int i = 1; i <= m; ++i) {
if (f[i]) puts("AYE");
else puts("NAY");
}
return 0;
}
0x04 相关
点分治常见题型有:
- 路径和等于或小于等于 k k k的点对(路径条数)。
- 路径和为某个数的倍数。
- 路径和为 k k k且路径的边数最少。
- 路径和 m o d M mod M modM后为某个值。
- 路径上经过不允许点的个数不超过某个值,且路径和最大。
- ……
若使用写法1,一般开一个栈,保存路径上的距离等相关信息,排序后利用单调性或者二分找答案。
若使用写法2,则一般处理这些问题时,都是开一个桶 m p i mp_i mpi表示距离为i的相关信息。
0x05 总结
个人感觉写法2更好,效率更高、适用面广、调试与思维难度低。