0%

浅谈Splay

前言

想了想还是写吧。

前置芝士

旋转和 BST。

正文

现发表一下个人的看法,splay 其实就是一种优雅的暴力的 BST,通过把一个节点不断旋转到根来维护其 BST 的性质。

旋转

因为你更改或插入时要更改从本节点到根节点路径上所有点的大小,所以就需要用到旋转了,也就是把本节点旋转到根。

具体的,对于一个节点,如果其性质和他的父亲的性质相同,就旋转他的父亲,反之旋转他自己,最后旋转他自己。

重复以上步骤直到根节点即可。

对于一个节点的性质为其属于其父亲的左子树或右子树。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
int get(int x) { return x==ch[fa[x]][1]; }

void rotate(int x) {
int y=fa[x];
int z=fa[y],tmp=get(x);
ch[y][tmp]=ch[x][tmp^1];
if (ch[x][tmp^1]) fa[ch[x][tmp^1]]=y;
ch[x][tmp^1]=y;
fa[y]=x;
fa[x]=z;
if (z) ch[z][y==ch[z][1]]=x;
upd(y);
upd(x);
}

void splay(int x) {
for (int i=fa[x];i=fa[x],i;rotate(x))
if (fa[i]) rotate(get(x)==get(i)?i:x);
rt=x;
}

插入

没什么好说的,对于当前节点如果插入值比它大,往右子树搜,反之往左子树搜,再分几种情况:

  1. 根节点为空,建一个新节点作为根节点。

  2. 树中存在这个数,搜到此节点,更新,然后旋转到根。

  3. 树中不存在这个数,一直搜到叶子节点,建立一个属于他自己的新节点,更新,旋转到根。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
void ins(int x) {
if (!rt) {
rt=++cnt;
val[cnt]=x;
num[cnt]++;
upd(rt);
return;
}
int y=rt,f=0;
while (1) {
if (x==val[y]) {
num[y]++;
upd(y);
upd(f);
splay(y);
break;
}
f=y;
y=ch[y][x>val[y]];
if (!y) {
val[++cnt]=x;
num[cnt]++;
fa[cnt]=f;
ch[f][val[cnt]>val[f]]=cnt;
upd(cnt);
upd(f);
splay(cnt);
break;
}
}
}

查询排名

首先不保证它在树中,所以要先插入,最后删掉即可。

在树中搜索,同样的,对于当前节点如果插入值比它小,往左子树搜,反之让答案加上左子树的大小,如果当前节点的值是我们所求的,就返回答案+1并旋转到根,否则再加上当前节点的个数并往右子树搜。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
int rk(int x) {
int res=0,y=rt;
while (1) {
if (x<val[y]) {
y=ch[y][0];
} else {
res+=siz[ch[y][0]];
if (x==val[y]) {
splay(y);
return res+1;
}
res+=num[y];
y=ch[y][1];
}

}
}

查询给定名次的值

在树中搜索,同样的,如果当前节点的左子树大小大于等于排名,往左子树搜,反之让排名减去左子树的大小和当前节点的个数,如果排名此时小于等于 $0$,就返回当前节点并旋转到根,否则往右子树搜。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int kth(int x) {
int y=rt;
while (1) {
if (ch[y][0] and x<=siz[ch[y][0]]) {
y=ch[y][0];
} else {
x-=(siz[ch[y][0]]+num[y]);
if (x<=0) {
splay(y);
return val[y];
}
y=ch[y][1];
}
}
}

前驱和后继

搜前驱就去找其左子树中的最大值,后继就是找其右子树中的最小值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int pre(int x) {
int y=ch[x][0];
if (!y) return y;
while (ch[y][1]) y=ch[y][1];
splay(y);
return y;
}

int nxt(int x) {
int y=ch[x][1];
if (!y) return y;
while (ch[y][0]) y=ch[y][0];
splay(y);
return y;
}

删除

分五种情况:

  1. 当前节点个数大于 $1$,个数减一,更新,旋转到根。

  2. 左右子树皆为空(即为根节点),删除本节点,rt=0

  3. 左子树有,右子树没有,删除本节点,rt=t[rt].ch[0]

  4. 左子树没有,右子树有,删除本节点,rt=t[rt].ch[1]

  5. 左右子树皆有,找到本节点的前驱,把前驱的右子树改为本节点的右子树,本节点的右子树的父亲改为前驱,删除本节点并更新。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
void clear(int x) {siz[x]=ch[x][0]=ch[x][1]=fa[x]=num[x]=val[x]=0;}

void del(int x) {
rk(x);
if (num[rt]>1) {
num[rt]--;
upd(rt);
} else if (!ch[rt][1] and !ch[rt][0]){
clear(rt);
rt=0;
} else if (!ch[rt][1] and ch[rt][0]) {
int tmp=rt;
rt=ch[rt][0];
fa[rt]=0;
clear(tmp);
} else if (ch[rt][1] and !ch[rt][0]) {
int tmp=rt;
rt=ch[rt][1];
fa[rt]=0;
clear(tmp);
} else {
int tmp=rt,x=pre(rt);
fa[ch[tmp][1]]=x;
ch[x][1]=ch[tmp][1];
clear(tmp);
upd(rt);
}
}

更新

没啥好说的。

1
void upd(int x) { siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+num[x]; }

习题

P3369 【模板】普通平衡树

把板子贴上去即可

P3391 【模板】文艺平衡树

就是把 splay 当区间树用,再加上一个 $lazytag$,对于每一个区间 $[l,r]$,把 $l-1$ 旋转到根,再把 $r+1$ 旋转到根的右子树,再对 $r+1$ 的左子树打上 $lazytag$ 即可。

旋转部分代码:

1
2
3
4
5
void splay(int x,int k) {
for (int f=t[x].fa;f=t[x].fa,f!=k;rotate(x))
if (t[f].fa!=k) rotate(get(x)==get(f)?f:x);
if (!k) rt=x;
}

P2596 [ZJOI2006]书架

把初始时每本书上面书的本数作为初始权值插入,并定义 $loc_i$ 为第 $i$ 本书在树中的编号。

  1. 对于操作 Top ,把本节点旋转到根,找到前驱,将右子树接到前驱的右子树出,更新旋转到根。

  2. 对于操作 Bottom ,同上。

  3. 对于操作 Insert,因为 $t \in {-1,0,1}$,所以当 $t$ 为 $0$ 时,直接 continue,那么当 $t=-1$ 时,找到前驱,交换权值和 $loc$,否则反之。

  4. 对于操作 Ask,将 $s$ 旋转到根,输出左子树大小。

  5. 对于操作 Query,就是求排名。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#include <iostream>
#include <string>
#include <cstdio>

#define INF 100000

using namespace std;
int n,m;

namespace Splay {
struct Node {
int val,num,siz,fa,ch[2];
}t[1000001];

int rt,cnt,loc[1000001];

void upd(int x) { t[x].siz=t[t[x].ch[0]].siz+t[t[x].ch[1]].siz+t[x].num; }

int get(int x) { return x==t[t[x].fa].ch[1]; }

void rotate(int x) {
int y=t[x].fa,z=t[y].fa,tmp=get(x);
t[y].ch[tmp]=t[x].ch[tmp^1];
if (t[x].ch[tmp^1]) t[t[x].ch[tmp^1]].fa=y;
t[x].ch[tmp^1]=y;
t[y].fa=x;
t[x].fa=z;
if (z) t[z].ch[y==t[z].ch[1]]=x;
upd(y);
upd(x);
}

void splay(int x,int goal) {
for (int f=t[x].fa;f=t[x].fa,f!=goal;rotate(x))
if (t[f].fa!=goal) rotate(get(f)==get(x)?f:x);
loc[t[x].val]=x;
if (!goal) rt=x;
}

int find_max() {
int y=t[rt].ch[1];
if (!y) return rt;
while (t[y].ch[1]) y=t[y].ch[1];
splay(y,0);
return y;
}

int find_min() {
int y=t[rt].ch[0];
if (!y) return rt;
while (t[y].ch[0]) y=t[y].ch[0];
splay(y,0);
return y;
}

int rk(int x) {
splay(x,0);
return t[t[x].ch[0]].siz;
}

int kth(int x) {
int y=rt;
while (1) {
if (x<=t[t[y].ch[0]].siz) y=t[y].ch[0];
else {
x-=t[y].num+t[t[y].ch[0]].siz;
if (x<=0) {
splay(y,0);
return y;
}
y=t[y].ch[1];
}
}
}

int pre(int x,int k) {
int y=t[x].ch[k];
if (!y) return y;
while (t[y].ch[k^1]) y=t[y].ch[k^1];
splay(y,0);
return y;
}

void ins(int x) {
if (!rt) {
t[rt=++cnt].val=x;
t[cnt].num++;
loc[x]=cnt;
upd(cnt);
return;
}
find_max();
t[++cnt].val=x;
t[cnt].num++;
t[cnt].fa=rt;
loc[x]=cnt;
t[rt].ch[1]=cnt;
upd(cnt);
upd(rt);
splay(cnt,0);
}
}
using namespace Splay;

void work(int x) {
if (!x) return;
work(t[x].ch[0]);
printf("%d ",t[x].val);
work(t[x].ch[1]);
}

void work1(int x) {
if (!x) return;
work1(t[x].ch[0]);
printf("%d %d ",loc[t[x].val],x);
work1(t[x].ch[1]);
}

int main() {
scanf("%d%d",&n,&m);
for (int i=1,x;i<=n;i++) {
scanf("%d",&x);
ins(x);
}
string s;
for (int i=1,x,tt;i<=m;i++) {
cin>>s;
scanf("%d",&x);
if (s=="Top") {
splay(loc[x],0);
if (!t[rt].ch[0]) continue;
if (!t[rt].ch[1]) {
t[rt].ch[1]=t[rt].ch[0];
t[rt].ch[0]=0;
continue;
}
int y=t[rt].ch[1];
while (t[y].ch[0]) y=t[y].ch[0];
t[y].ch[0]=t[rt].ch[0];
if (t[y].ch[0]) t[t[y].ch[0]].fa=y;
t[rt].ch[0]=0;
upd(y);
splay(t[y].ch[0],0);
} else if (s=="Bottom") {
splay(loc[x],0);
if (!t[rt].ch[1]) continue;
if (!t[rt].ch[0]) {
t[rt].ch[0]=t[rt].ch[1];
t[rt].ch[1]=0;
continue;
}
int y=t[rt].ch[0];
while (t[y].ch[1]) y=t[y].ch[1];
t[y].ch[1]=t[rt].ch[1];
if (t[y].ch[1]) t[t[y].ch[1]].fa=y;
t[rt].ch[1]=0;
upd(y);
splay(t[y].ch[1],0);
} else if (s=="Insert") {
splay(loc[x],0);
scanf("%d",&tt);
if (!tt) continue;
int t1=pre(loc[x],(tt==-1?0:1));
if (!t1) continue;
swap(t[loc[x]].val,t[rt].val);
swap(loc[x],loc[t[loc[x]].val]);
} else if (s=="Ask") {
printf("%d\n",rk(loc[x]));
} else if (s=="Query") {
printf("%d\n",t[kth(x)].val);
}
// work(rt);
// printf("\n");
}
return 0;
}