// Each directive must be on its own line: text that follows `#include <...>`
// on the same line is not processed as further directives or declarations.
#include <bits/stdc++.h>

// Contest-style shorthands.
#define mp make_pair
#define pb push_back
#define fl first
#define fr second

using namespace std;

typedef long long LL;
typedef pair<int, int> pii;
// Persistent disjoint-set union stored in a persistent segment tree.
// Version k is rooted at rt[k]. The leaf for element x holds its parent
// element (fa) and, for a representative, the height of its set's tree
// (depth, the union-by-rank rank). There is no path compression, so Find
// never writes and every version stays readable.

const int N = 200000 + 5;  // capacity for elements and operations (+ slack)
int n, m, rt[N], cnt;      // element count, op count, version roots, last used pool index

// One tree node. Internal nodes use lc/rc; leaves use fa/depth.
struct Node {
    int lc, rc;  // children (internal nodes)
    int fa;      // parent element (leaves); fa == own position marks a representative
    int depth;   // tree height of the set (meaningful on representative leaves)
};

// Node pool; index 0 is the null sentinel and cnt is the last allocated index.
// Budget: Build uses < 2N nodes. Each union allocates at most 19 nodes in
// Modify and 18 more in Add (root-to-leaf paths have <= 19 nodes when
// n < 2^18). So up to N-1 unions fit in 2N + 37N < 40N nodes.
Node t[N * 40];

// Builds the version-0 tree over positions [l, r]. Every element starts as its
// own representative with depth 0. Returns the subtree root index.
int Build(int l, int r) {
    int u = ++cnt;
    if (l == r) {
        t[u].fa = l;
        return u;
    }
    int mid = (l + r) >> 1;
    t[u].lc = Build(l, mid);
    t[u].rc = Build(mid + 1, r);
    return u;
}

// Allocates a copy of node u and returns the copy's index.
int Make(int u) {
    t[++cnt] = t[u];
    return cnt;
}

// Path-copying point update. Returns the root of a new tree that equals the
// one rooted at u, except that leaf `pos` has fa = x (its depth is kept).
int Modify(int u, int l, int r, int pos, int x) {
    u = Make(u);
    if (l == r) {
        t[u].fa = x;
        return u;
    }
    int mid = (l + r) >> 1;
    if (pos <= mid)
        t[u].lc = Modify(t[u].lc, l, mid, pos, x);
    else
        t[u].rc = Modify(t[u].rc, mid + 1, r, pos, x);
    return u;
}

// Returns the index of the leaf for position `pos` in the tree rooted at u,
// which covers positions [l, r].
int Query(int u, int l, int r, int pos) {
    while (l != r) {
        int mid = (l + r) >> 1;
        if (pos <= mid) {
            u = t[u].lc;
            r = mid;
        } else {
            u = t[u].rc;
            l = mid + 1;
        }
    }
    return u;
}

// Returns the leaf index of x's representative in version k by following fa
// links until an element is its own parent. Performs no writes.
int Find(int k, int x) {
    for (;;) {
        int f = Query(rt[k], 1, n, x);
        if (t[f].fa == x) return f;
        x = t[f].fa;
    }
}

// Worker for Add. Nodes with index >= fresh were allocated while building the
// current version and nothing else references them yet, so they are updated
// in place. Any older child on the way down may be shared with other versions,
// so it is copied before the walk continues into it.
static void AddFresh(int u, int l, int r, int pos, int fresh) {
    if (l == r) {
        t[u].depth++;
        return;
    }
    int mid = (l + r) >> 1;
    if (pos <= mid) {
        if (t[u].lc < fresh) t[u].lc = Make(t[u].lc);
        AddFresh(t[u].lc, l, mid, pos, fresh);
    } else {
        if (t[u].rc < fresh) t[u].rc = Make(t[u].rc);
        AddFresh(t[u].rc, mid + 1, r, pos, fresh);
    }
}

// Increments depth of leaf `pos` in the version rooted at u.
// Precondition: u is the root just returned by Modify (or Build), so no other
// version has been derived from it yet. The copies Modify made all have
// indices >= u and may be changed in place. Older nodes are copied first.
// Bumping a shared leaf in place would leak the rank increase into earlier
// versions and, after a rollback, break union-by-rank's O(log n) height bound.
void Add(int u, int l, int r, int pos) {
    AddFresh(u, l, r, pos, u);
}
int main() { scanf("%d%d", &n, &m); rt[0] = Build(1, n); for (int i = 1; i <= m; i++) { int opt, a, b; scanf("%d", &opt); if (opt == 1) { rt[i] = rt[i - 1]; scanf("%d%d", &a, &b); int fa = Find(i, a), fb = Find(i, b); if (t[fa].fa == t[fb].fa) continue; if (t[fa].depth > t[fb].depth) swap(fa, fb); rt[i] = Modify(rt[i], 1, n, t[fa].fa, t[fb].fa); if (t[fa].depth == t[fb].depth) Add(rt[i], 1, n, t[fb].fa); } if (opt == 2) { scanf("%d", &a); rt[i] = rt[a]; } if (opt == 3) { rt[i] = rt[i - 1]; scanf("%d%d", &a, &b); int fa = Find(i, a), fb = Find(i, b); if (t[fa].fa == t[fb].fa) printf("1\n"); else printf("0\n"); } } }
|