可持久化权值线段树

主席树可以解决不适用结合律的区间问题(如区间第 $K$ 大,区间种类数),这些问题原本是需要繁琐的树套树,而有了主席树就简单很多了。

主席树的中心思想是保留历史版本,最暴力的做法是没插入一个节点就新建一棵线段树,但这样会各种爆,其实我们可以只新建有更改的节点,然后直接连边到原来的节点即可。

类比普通的线段树,主席树的插入顺序相当于普通线段树的位置,而主席树中的位置是维护的权值。

例题

[POJ2104] K-th Number

给定一个长度为 $N$ 的序列 $a$ ,有 $M$ 次查询,每次查询区间 $[l,r]$ 中第 $K$ 大的数值。

$n\leq10^5,m\leq5000,a_i\leq10^9$

离散化之后用主席树维护

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>

const int MAXN=1e5+5;

int n,m;

struct TN{int lc,rc,sz;} t[MAXN*20];int tcnt;//nlogn
int rt[MAXN];

void insert(int l,int r,int x,int &y,int pos)
{
y=++tcnt;
t[y]=t[x];t[y].sz++;
if(l==r) return;
int mid=(l+r)>>1;
if(pos<=mid) insert(l,mid,t[x].lc,t[y].lc,pos);
else insert(mid+1,r,t[x].rc,t[y].rc,pos);
}

int r2i(int l,int r,int x,int y,int k)
{
if(l==r) return l;
int mid=(l+r)>>1;
int pos=t[t[y].lc].sz-t[t[x].lc].sz;
if(k<=pos) return r2i(l,mid,t[x].lc,t[y].lc,k);
else return r2i(mid+1,r,t[x].rc,t[y].rc,k-pos);
}

int ncnt,a[MAXN],b[MAXN];
int main()
{
int i;
scanf("%d%d",&n,&m);
for(i=1;i<=n;i++) scanf("%d",&a[i]),b[i]=a[i];
std::sort(b+1,b+n+1);
ncnt=std::unique(b+1,b+n+1)-b-1;
for(i=1;i<=n;i++)
insert(1,ncnt,rt[i-1],rt[i],std::lower_bound(b+1,b+ncnt+1,a[i])-b);
for(i=1;i<=m;i++)
{
int x,y,z;scanf("%d%d%d",&x,&y,&z);
printf("%d\n",b[r2i(1,ncnt,rt[x-1],rt[y],z)]);
}
}

[BZOJ1878] HH 的项链

给定一个长度为 $N$ 的序列 $a$,共有 $M$ 个询问,对每个询问 $[l,r]$ ,需要回答 $[l,r]$ 之间的种类数。

$N\leq5\times10^4,M\leq2\times10^5,0\leq a_i\leq10^6$

我们对于每个位置记录一下加一个与它相同的数的位置 $nextPos[i]$, 这样,我们用主席树维护 $nextPos$ 数组,对于每个查询只需要统计区间内有多少个 $nextPos[i]>r$ 即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
/**************************************************************
Problem: 1878
User: zhangche0526
Language: C++
Result: Accepted
Time:2656 ms
Memory:35860 kb
****************************************************************/

#include<iostream>
#include<cstdio>

const int MAXN=5e4+5;

struct TN{int lc,rc,sz;} t[MAXN*25];int tcnt;
int rt[MAXN];
void insert(int l,int r,int x,int &y,int k)
{
y=++tcnt;
t[y]=t[x];
t[y].sz++;
if(l==r) return;
int mid=l+r>>1;
if(k<=mid) insert(l,mid,t[x].lc,t[y].lc,k);
else insert(mid+1,r,t[x].rc,t[y].rc,k);
}
int calc(int l,int r,int x,int y,int v)
{
if(r<v) return 0;
if(l>=v) return t[y].sz-t[x].sz;
int mid=l+r>>1;
return calc(l,mid,t[x].lc,t[y].lc,v)+calc(mid+1,r,t[x].rc,t[y].rc,v);
}
int n,m;

int la[MAXN*100],nextPos[MAXN];
int main()
{
int i,x;
scanf("%d",&n);
for(i=1;i<=n;i++)
{
scanf("%d",&x);
if(la[x]) nextPos[la[x]]=i;
la[x]=i;
}
for(i=1;i<=n;i++) if(!nextPos[i]) nextPos[i]=n+1;
for(i=1;i<=n;i++) insert(1,n+1,rt[i-1],rt[i],nextPos[i]);
scanf("%d",&m);
for(i=1;i<=m;i++)
{
int l,r;scanf("%d%d",&l,&r);
printf("%d\n",calc(1,n+1,rt[l-1],rt[r],r+1));
}
}

[BZOJ1901]动态排名系统

给定一长度为 $N$ 的序列 $a$, 有 $M$ 次操作,每次查询区间 $[l,r]$ 中第 $K$ 大的数值,或单点修改。

$N\leq5\times10^4,M\leq10^4$

本题就是在例一的基础上增加了单点修改操作,如果直接修改每次需要 $O(n\log_2n)$ 的时间复杂度,这是不能接受的。注意到原本的主席树每个节点维护的是前缀信息,可以利用树状数组优化,即将原先的对于每一个新元素的前缀建树,改为对于树状数组中的前缀建树,这样可以将每次修改的时间复杂度降为 $O(\log_2^2n)$

需要注意的是,在修改时不用新建节点,因此插入函数稍有改动。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>

const int MAXN=5e6+5;

int n,m;

struct T{int lc,rc,sum;} t[MAXN];int tcnt,rt[MAXN];

int na[MAXN],srtdNa[MAXN],ncnt;
int X[MAXN],xcnt,Y[MAXN],ycnt;

inline int lowbit(int x){return x&(-x);}

void insert(int l,int r,int x,int &y,int k,int v)
{
if(!y) y=++tcnt,t[y]=t[x];
t[y].sum+=v;
if(l==r) return;
int mid=l+r>>1;
if(k<=mid) insert(l,mid,t[x].lc,t[y].lc,k,v);
else insert(mid+1,r,t[x].rc,t[y].rc,k,v);
}

void add(int x,int v)
{
int k=std::lower_bound(srtdNa+1,srtdNa+ncnt+1,na[x])-srtdNa;
for(;x<=n;x+=lowbit(x)) insert(1,ncnt,rt[x],rt[x],k,v);
}

int query(int l,int r,int k)
{
int i,sum=0,mid=l+r>>1;
if(l==r) return l;
for(i=1;i<=xcnt;i++) sum-=t[t[X[i]].lc].sum;
for(i=1;i<=ycnt;i++) sum+=t[t[Y[i]].lc].sum;
if(k<=sum)
{
for(i=1;i<=xcnt;i++) X[i]=t[X[i]].lc;
for(i=1;i<=ycnt;i++) Y[i]=t[Y[i]].lc;
return query(l,mid,k);
}else
{
for(i=1;i<=xcnt;i++) X[i]=t[X[i]].rc;
for(i=1;i<=ycnt;i++) Y[i]=t[Y[i]].rc;
return query(mid+1,r,k-sum);
}
}

int A[MAXN],B[MAXN],C[MAXN];
int main()
{
freopen("dynrank.in","r",stdin);
freopen("dynrank.out","w",stdout);
int T;scanf("%d",&T);
while(T--)
{
memset(rt,0,sizeof(rt));
memset(C,0,sizeof(C));
ncnt=tcnt=0;
int i,j;
scanf("%d%d",&n,&m);
for(i=1;i<=n;i++)
scanf("%d",na+i),srtdNa[++ncnt]=na[i];
for(i=1;i<=m;i++)
{
char s[10];scanf("%s%d%d",s,A+i,B+i);
if(s[0]=='Q') scanf("%d",C+i);
else srtdNa[++ncnt]=B[i];
}
std::sort(srtdNa+1,srtdNa+ncnt+1);
ncnt=std::unique(srtdNa+1,srtdNa+ncnt+1)-(srtdNa+1);
for(i=1;i<=n;i++) add(i,1);
for(i=1;i<=m;i++)
if(C[i])
{
xcnt=ycnt=0;
for(j=A[i]-1;j;j-=lowbit(j)) X[++xcnt]=rt[j];
for(j=B[i];j;j-=lowbit(j)) Y[++ycnt]=rt[j];
printf("%d\n",srtdNa[query(1,ncnt,C[i])]);
}
else
{
add(A[i],-1);
na[A[i]]=B[i];
add(A[i],1);
}
}
}

[BZOJ2120]数颜色

给定一个长度为 $N$ 的序列 $a$,共有 $M$ 个操作,查询需要回答 $[l,r]$ 之间的种类数,或单点修改。

$N\leq10^4,M\leq10^4$, 修改操作不多于 $10^3$ 次,所有的输入数据中出现的所有整数 $1\leq x\leq10^6$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
typedef long long ll;
const int MAXN=1e4+5;
struct Tn{int lc,rc;ll sum;} t[MAXN*600];int tcnt;
int n,m;
int na[MAXN],srtdNa[MAXN],X[55],Y[55],A[MAXN],B[MAXN],C[MAXN],nextPos[MAXN],la[MAXN*100],rt[MAXN];
int ncnt,xcnt,ycnt;

int lowbit(int x){return x&(-x);}

void insert(int l,int r,int x,int &y,int k,int v)
{
if(!y) y=++tcnt,t[y]=t[x];
t[y].sum+=(ll)v;
if(l==r) return;
int mid=l+r>>1;
if(k<=mid) insert(l,mid,t[x].lc,t[y].lc,k,v);
else insert(mid+1,r,t[x].rc,t[y].rc,k,v);
}

void add(int x,int v)
{
for(int i=x;i<=n+1;i+=lowbit(i))
insert(1,n+1,rt[i],rt[i],nextPos[x],v);
}

ll query(int l,int r,int v)
{
int i;ll sum=0;
if(r<v) return 0;
if(l>=v)
{
for(i=1;i<=xcnt;i++) sum-=t[X[i]].sum;
for(i=1;i<=ycnt;i++) sum+=t[Y[i]].sum;
return sum;
}
int laX[55],laY[55];
for(i=1;i<=xcnt;i++) laX[i]=X[i];
for(i=1;i<=ycnt;i++) laY[i]=Y[i];
int mid=l+r>>1;
for(i=1;i<=xcnt;i++) X[i]=t[laX[i]].lc;
for(i=1;i<=ycnt;i++) Y[i]=t[laY[i]].lc;
sum+=query(l,mid,v);
for(i=1;i<=xcnt;i++) X[i]=t[laX[i]].rc;
for(i=1;i<=ycnt;i++) Y[i]=t[laY[i]].rc;
sum+=query(mid+1,r,v);
return sum;
}
int main()
{
int i,j,k;
scanf("%d%d",&n,&m);
for(i=1;i<=n;i++)
{
scanf("%d",na+i);
if(la[na[i]]) nextPos[la[na[i]]]=i;
la[na[i]]=i;
}
for(i=1;i<=n;i++) if(!nextPos[i]) nextPos[i]=n+1;
for(i=1;i<=m;i++)
{
char s[10];scanf("%s%d%d",s,A+i,B+i);
if(s[0]=='Q') C[i]=1;
}
for(i=1;i<=n;i++) add(i,1);
for(i=1;i<=m;i++)
{
xcnt=ycnt=0;
if(C[i])
{
memset(X,0,sizeof(X));memset(Y,0,sizeof(Y));
for(j=A[i]-1;j;j-=lowbit(j)) X[++xcnt]=rt[j];
for(j=B[i];j;j-=lowbit(j)) Y[++ycnt]=rt[j];
printf("%d\n",query(1,n+1,B[i]+1));
}
else
{
add(A[i],-1);
na[A[i]]=B[i];
nextPos[A[i]]=n+1;
for(k=A[i]+1;k<=n;k++)
if(na[k]==B[i])
{
nextPos[A[i]]=k;
break;
}
add(A[i],1);
for(j=A[i]-1;j;j--)
{
if(nextPos[j]==A[i])
{
add(j,-1);
nextPos[j]=n+1;
for(k=j+1;k<=n;k++)
if(na[k]==na[j])
{
nextPos[j]=k;
break;
}
add(j,1);
}
if(na[j]==B[i]&&nextPos[j]>A[i])
{
add(j,-1);
nextPos[j]=A[i];
add(j,1);
}
}
}
}
return 0;
}