trie详解

Trie树略解

简介

Trie树,是一种树形借工,是一种哈希树的变种。典型应用是用于统计,排序和保存大量的字符串(但不仅限于字符串)。

它利用字符串的公共前缀来减少查询时间,最大限度地减少无谓的字符串比较,查询效率比哈希树高。

建图

假设我们有\(aa,aba,ba,caaa,cab,cba,cc,cb\)这么多字符串

那么我们可以构建一棵这样的Trie树(其中红色节点表示每一个trie树的节点)

可以发现,这棵字典树用边来代表字母

而从根节点到树上红色节点的路径就代表了一个读入的字符串

\(Trie\)的图非常好建,假设当前节点编号为u,下一个字符为c

则只需新申请一个节点,并将其作为\(Trie[u][c]\)的编号即可

void insert(char a[]){
    int root=0,len=strlen(a);
    for(int i=0;i<len;i++){
        int k=a[i]-'a';
        if(tr[root][k]==0)tr[root][k]=++tot;//申请新的节点
        root=tr[root][k];
    }
    vis[root]=1;//标记字符串的结尾
}

至于查询,也相当简单。按照下一条边的编号递归下去以及\(Trie\)数组中蕴含的信息,递归下去即可。

bool search(char a[]){
    int root=0,len=strlen(a);
    for(int i=0;i<len;i++){
        int k=a[i]-'a';
        if(tr[root][k]==0)return false;
        root=tr[root][k];
    }
    return vis[root];//判断a数组是否在给定的字符串集中出现过
}

动态开点Trie树

上述\(Trie\)树实现的过程中,我们使用了数组。

不过这样子可能会由于对空间的控制不当导致数组越界或者超空间。

此时,我们就可以运用指针来实现动态开点\(Trie\)

代码如下

struct Trie{
    Trie* Next[26];
    Trie(){
        for(int i=0;i<26;i++)Next[i]=NULL;
    }
}root;
void insert(char a[]){
    int len=strlen(a);
    Trie* rt=&root;
    for(int i=0;i<len;i++){
        if(rt->Next[a[i]-'a']==NULL)rt->Next[a[i]-'a']=new Trie;
        rt=rt->Next[a[i]-'a'];
    }
}
bool search(char a[]){
    int len=strlen(a);
    Trie* rt=&root;
    for(int i=0;i<len;i++){
        if(rt->Next[a[i]-'a']==NULL)return false;
        rt=rt->Next[a[i]-'a'];
    }
    return 1;
}

应用

假定这题的空间复杂度是\(512MB\)

在一次\(NOI\ Online\)的测试中,我遇到了这么一道题目

数据范围:\(n\leq3000\)

解题思路

\(Alice\)的初始串为\(s\)\(Bob\)的初始串为\(t\)

由于\(Bob\)的条件的特殊性,我们可以枚举未被删除的部分的左端点\(l\)

然后由于对于\(s\)串子序列与\(t\)串的匹配,我们可以运用贪心的思想。

如果存在\(s_i\)=\(t_j\),若此时不匹配,到\(s_k=t_j(k>j)\)时在匹配肯定不是最优的(感性理解一下)

然后匹配后,就把答案+1

但是这样可能会出现重复的情况,我们用\(trie\)树来处理重复情况。

#include<bits/stdc++.h>
using namespace std;
#define N 3005
int n,ans;
char s[N],t[N];
struct Trie{
    Trie* Next[26];
    Trie(){
        for(int i=0;i<26;i++)Next[i]=NULL;
    }
}root;
int main(){
    scanf("%d%s%s",&n,s+1,t+1);
    for(int i=1,j,k;i<=n;i++){
        Trie* rt=&root;
        for(j=i,k=1;j<=n&&k<=n;k++)
            if(t[j]==s[k]){
                if(rt->Next[t[j]-'a']==NULL)rt->Next[t[j]-'a']=new Trie;
                else ans--;
                rt=rt->Next[t[j]-'a'];
                j++;
            }
        ans+=j-i;
    }
    printf("%d",ans);
}

时间复杂度:\(O(n^2)\)

在空间限制为\(512MB\)时可以拿满分

01-Trie

将数的二进制表示看做一个字符串,就可以建出字符集为{\(0,1\)}的\(Trie\)树。

\(01-trie\)树有两种建树方式:

  1. 从低位到高位
  2. 从高位到低位

我们要分情况的使用

例题:最长异或路径

解题思路

首先用\(T(u,v)\)来表示 \((u,v)\)之间路径异或和

\(T(u,v)=T(root,u)\ xor\ T(root,v)\)(显然)

然后对于每个\(T(root,u)\),我们都可以运用贪心的思想求出最大的答案

\(Trie\)的根开始,如果能向和\(T(root,u)\)的当前位不同的子树走,就向那边走,否则向另一边走。(显然)

代码

#include<bits/stdc++.h>
using namespace std;
#define N 100005
int n,trie[N<<5][2],head[N],ans,edgenum,tot,dis[N];
struct edge{
    int v,next,w;
}e[N<<1];
void add(int u,int v,int w){
    edgenum++;
    e[edgenum]=(edge){v,head[u],w};
    head[u]=edgenum;
}
void insert(int w){
    int root=0;
    for(int i=30;i>=0;i--){
        int x=(w&(1<<i));
        if(x)x=1;
        if(!trie[root][x])trie[root][x]=++tot;
        root=trie[root][x];
    }
    //这种情况下,从高位到低位建树可以使我们更好的贪心
}
void get(int w){
    int rec=0,root=0;
    for(int i=30;i>=0;i--){
        int x=(w&(1<<i));
        if(x)x=1;
        if(trie[root][x^1]){
            root=trie[root][x^1];
            rec|=(1<<i);
        }
        else root=trie[root][x];
    }
    ans=max(ans,rec);
}
void dfs(int u,int fa){
    insert(dis[u]);
    get(dis[u]);
    for(int i=head[u],v;i;i=e[i].next)
        if((v=e[i].v)^fa){
            dis[v]=dis[u]^e[i].w;
            dfs(v,u);
        }
}
int main(){
    scanf("%d",&n);
    for(int i=1,u,v,w;i<n;i++){
        scanf("%d%d%d",&u,&v,&w);
        add(u,v,w);add(v,u,w);
    }
    dfs(1,0);
    printf("%d",ans);
}

维护异或和

\(01-trie\)可以用来维护一些数字的异或和,支持修改(删除+重新插入)和全局加1

插入&删除

如果要维护异或和,我们只需要知道某一位上0和1个数的奇偶性

也就是只有这一位上1的个数为奇数时,这一位上的数字才是1

对于每个节点,我们要记录一下三个量:

\(trie[u][0/1]\)表示\(u\)的两个子节点,\(trie[u][0]\)指下一位是\(0\)的子节点,\(trie[u][1]\)指下一位是\(1\)的子节点

\(w[u]\)是指从\(trie\)树根经过\(u\)向上的这条边的数字的数目,用于维护\(0\)\(1\)个数的奇偶性

\(sum[u]\)是指以\(u\)为根的子树的异或和

代码(维护当前节点)

void maintain(int u){
    w[u]=sum[u]=0;
    if(trie[u][0]){
        w[u]+=w[trie[u][0]];
        sum[u]^=sum[trie[u][0]]<<1;
        //由于我们是从低位到高位构建trie树的,所以在子树中得到的答案到当前节点要*2
    }
    if(trie[u][1]){
        w[u]+=w[trie[u][0]];
        sum[u]^=(sum[trie[u][1]]<<1)|(w[trie[u][1]]&1);
        //只有当为1时且w[u]为奇数时才会对答案做出贡献
    }
}

插入、删除的代码非常相似

为处理方便,我们强制定义了个\(Trie\)树的最大深度\(height\),也就是每个数字,即便是它的高位为0,也要到\(height+1\)才退出

至于插入、删除中答案的更新,只要不断调用\(maintain\)函数即可

总代吗

void maintain(int u){
    w[u]=sum[u]=0;
    if(trie[u][0]){
        w[u]+=w[trie[u][0]];
        sum[u]^=sum[trie[u][0]]<<1;
        //由于我们是从低位到高位构建trie树的,所以在子树中得到的答案到当前节点要*2
    }
    if(trie[u][1]){
        w[u]+=w[trie[u][1]];
        sum[u]^=(sum[trie[u][1]]<<1)|(w[trie[u][1]]&1);
        //只有当为1时且w[trie[u][1]]为奇数时才会对答案做出贡献
    }
}
void insert(int &u,int x,int dep){
    if(!u)u=++tot;
    if(dep>height){
        w[u]++;
        return;
    }
    insert(trie[u][x&1],x>>1,dep+1);
    maintain(u);
}
void erase(int u,int x,int dep){
    if(dep>height){
        w[u]--;
        return;
    }
    erase(trie[u][x&1],x>>1,dep+1);
    maintain(u);
}

全局加一

全局加一就是指让这棵\(trie\)中所有的数值+1

\(trie\)中维护的数值有\(V_1,V_2,V_3,\cdots,V_n\)

全局加一后其中维护的值应该变成\(V_1+1,V_2+1,\cdots,V_n+1\)

接下来我们思考一下二进制意义+1是如何操作的

我们从低位到高位找到第一个出现的\(0\),把它变成\(1\),然后把这个位置后面的\(1\)都变成0即可

1000(10)  + 1 = 1001(11)  ;
10011(19) + 1 = 10100(20) ;
11111(31) + 1 = 100000(32);
10101(21) + 1 = 10110(22) ;

对应\(trie\)树上的操作,其实就是交换其左右儿子,顺着交换后的通往\(0\)的边往下递归操作。

此时,我们之前从低位到高位的建树可以使我们轻松的处理该操作

void addall(int u){
    swap(trie[u][0],trie[u][1]);
    if(trie[u][0])addall(trie[u][0]);
    maintain(u);
}

01-trie合并

这指的是将两个\(01-Trie\)进行合并,同时合并维护的信息

首先考虑我们有一个int merge(int a,int b)的函数

这个函数传入两个\(Trie\)树位于同一相对位置的节点编号,合并完成后返回合并后的节点编号

我们分三种情况

  • 如果\(a\)没有这个位置上的节点,新合并的节点就是\(b\)
  • 如果\(b\)没有这个位置上的节点,新合并的节点就是\(a\)
  • 如果\(a,b\)都存在,那就把\(b\)的信息合并到\(a\)上,新合并的节点就是\(a\),然后递归处理\(a\)的左右儿子
int merge(int a,int b){
    if(!a)return b;
    if(!b)return a;
    w[a]=w[a]+w[b];
    sum[a]^=sum[b];
    trie[a][0]=merge(trie[a][0],trie[b][0]);
    trie[a][1]=merge(trie[a][1],trie[b][1]);
    return a;
}

例题HDU6191 Query on A Tree

题目大意

一棵树上,每次输入\(u,x\),询问以\(u\)为根节点的子树上的某个点与\(x\)异或最大可以是多少

\(2\leq n,q\leq10^5,0\leq v_i\leq10^9\)

解题思路

这是一道简单的模拟题,只需按照题目说的做即可

对于每个节点,我们都建一棵\(01trie\)

显而易见,我们可以先从叶子结点开始建树,随后运用\(01trie\)树合并的方法,得到每个节点的\(01trie\)树。

然后是贪心的求最大异或和了

代码

#include<algorithm>
#include<cstring>
#include<cstdio>
#include<vector>
using namespace std;
#define N 100005
inline int read(){
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        x=(x<<3)+(x<<1)+(ch^48);
        ch=getchar();
    }
    return x;
}
int n,q,ans[N],head[N],edgenum,a[N];
struct nade{ 
    int v,next;
}e[N];
vector<pair<int,int> >E[N];
struct trie{
    trie* Next[2];
    trie(){
        Next[0]=Next[1]=NULL;
    }
};
void add(int u,int v){
    edgenum++;
    e[edgenum]=(nade{v,head[u]});
    head[u]=edgenum;
}
trie* merge(trie* a,trie* b){
    if(!b)return a;
    if(!a)return b;
    a->Next[0]=merge(a->Next[0],b->Next[0]);
    a->Next[1]=merge(a->Next[1],b->Next[1]);
    free(b);
    return a;
}
void insert(int x,trie* a){
    for(int i=30;i>=0;i--){
        int k=(x&(1<<i));
        if(k)k=1;
        if(!a->Next[k])a->Next[k]=new trie;
        a=a->Next[k];
    }
}
int ask(trie* a,int x){
    int rec=0;
    for(int i=30;i>=0;i--){
        int k=(x&(1<<i));
        if(k)k=1;
        if(a->Next[k^1]){
            a=a->Next[k^1];
            rec|=(1<<i);
        }
        else a=a->Next[k];
    }
    return rec;
}
trie* dfs(int u){
    trie* root=new trie;
    for(int i=head[u],v;i;i=e[i].next)root=merge(root,dfs(e[i].v));
    insert(a[u],root);
    for(int i=0;i<E[u].size();i++)ans[E[u][i].first]=ask(root,E[u][i].second);
    return root;
}
void del(trie* rt){
    if(rt->Next[0])del(rt->Next[0]);
    if(rt->Next[1])del(rt->Next[1]);
    free(rt);
}
int main(){
    while(scanf("%d%d",&n,&q)!=EOF){
        edgenum=0;
        for(int i=1;i<=n;i++){
            head[i]=0;
            E[i].clear();
        }
        for(int i=1;i<=n;i++)a[i]=read();
        for(int i=2;i<=n;i++)add(read(),i);
        for(int i=1;i<=q;i++){
            int u=read(),x=read();
            E[u].push_back(make_pair(i,x));
        }
        del(dfs(1));
        for(int i=1;i<=q;i++)printf("%d\n",ans[i]);
    }
}

例题Fusion tree

题目大意

解题思路

这道题其实也是道很裸的\(01-trie\)

我们先令\(1\)为该树的根

然后对于每个节点,建立\(01-trie\)树来维护与它相邻的子节点的异或和

  • 对于操作\(1\),首先运用”全局+1″来搞定与它相邻的子节点的异或和。然后对于它的父亲节点,先将其从它爷爷的\(01-trie\)树中删除它,然后在插入更新后的权值
  • 对于操作\(2\),是操作\(1\)的简化版
  • 对于操作\(3\),合并以它为根的\(trie\)树的答案,再与其父亲的权值异或下即可

代码

#include<bits/stdc++.h>
using namespace std;
#define N 1000005
inline int read(){
    int x=0,f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        x=(x<<1)+(x<<3)+(ch^48);
        ch=getchar();
    }
    return x*f;
}
int tag[N<<5],sum[N<<5],trie[N<<5][2],rt[N],w[N<<5],a[N],tot,n,m,fa[N];
int head[N],edgenum;
struct edge{
    int v,next;
}e[N<<1];
void add(int u,int v){
    edgenum++;
    e[edgenum]={v,head[u]};
    head[u]=edgenum;
} 
void dfs(int u){
    for(int i=head[u];i;i=e[i].next){
        if(e[i].v==fa[u])continue;
        fa[e[i].v]=u;
        dfs(e[i].v);
    }
}
void weihu(int u){
    w[u]=sum[u]=0;
    if(trie[u][0]){
        w[u]+=w[trie[u][0]];
        sum[u]^=sum[trie[u][0]]<<1;
    }
    if(trie[u][1]){
        w[u]+=w[trie[u][1]];
        sum[u]^=(sum[trie[u][1]]<<1)|(w[trie[u][1]]&1);
    }
}
void insert(int &u,int x,int dep){
    if(!u)u=++tot;
    if(dep>30){
        w[u]++;
        return;
    }
    insert(trie[u][x&1],x>>1,dep+1);
    weihu(u);
}
void erase(int u,int x,int dep){
    if(dep>30){
        w[u]--;
        return;
    }
    erase(trie[u][x&1],x>>1,dep+1);
    weihu(u);
}
void addall(int u){
    swap(trie[u][0],trie[u][1]);
    if(trie[u][0])addall(trie[u][0]);
    weihu(u); 
}
int main(){
    n=read();m=read();
    for(int i=1,u,v;i<n;i++){
        u=read();v=read();
        add(u,v);add(v,u);
    }
    dfs(1);
    for(int i=1;i<=n;i++)rt[i]=++tot;
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i]);
        if(i>1)insert(rt[fa[i]],a[i],0);
    }
    for(int op,x,v;m--;){
        op=read();x=read();
        if(op==1){
            tag[x]++;
            addall(rt[x]);
            if(x>1){
                if(fa[x]>1)erase(rt[fa[fa[x]]],a[fa[x]]+tag[fa[fa[x]]],0);
                a[fa[x]]++;
                if(fa[x]>1)insert(rt[fa[fa[x]]],a[fa[x]]+tag[fa[fa[x]]],0);
            }
        }
        else if(op==2){
            v=read();
            if(x>1)erase(rt[fa[x]],a[x]+tag[fa[x]],0);
            a[x]-=v;
            if(x>1)insert(rt[fa[x]],a[x]+tag[fa[x]],0);
        }
        else{
            if(x>1){
                if(fa[x]>1)printf("%d\n",sum[rt[x]]^(a[fa[x]]+tag[fa[fa[x]]]));
                else printf("%d\n",sum[rt[x]]^a[fa[x]]);
            }
            else printf("%d\n",sum[1]);
        }
    }
}

例题[省选联考 2020 A 卷] 树

题目大意

\(1\leq n,v_i\leq525010,1\leq p_i\leq n\)

解题思路

这道题依旧是\(01-trie\)的裸题。

由于每个节点的答案只与它子树内的节点有关,我们可以先求出子树的答案,再运用\(01-trie\)合并更新其父节点的答案

\(d(fa[x],y)=d(x,y)+1(y是x子树内的节点)\)

这就和全局加1很像,于是便可直接套模板。

同时,应当先合并子节点的\(01-trie\),再全局加1,最后再插入当前节点的权值

代码

#include<bits/stdc++.h>
using namespace std;
#define N 1000005
#define int long long
int w[N<<5],sum[N<<5],trie[N<<5][2],n,a[N],head[N],edgenum,tot,rt[N],ans;
struct nade{
    int v,next;
}e[N<<1];
void add(int u,int v){
    edgenum++;
    e[edgenum]=nade{v,head[u]};
    head[u]=edgenum;
}
void weihu(int u){
    w[u]=sum[u]=0;
    if(trie[u][0]){
        w[u]+=w[trie[u][0]];
        sum[u]^=sum[trie[u][0]]<<1;
    }
    if(trie[u][1]){
        w[u]+=w[trie[u][1]];
        sum[u]^=(sum[trie[u][1]]<<1)|(w[trie[u][1]]&1);
    }
}
void insert(int &u,int x,int dep){
    if(!u)u=++tot;
    if(dep>30){
        w[u]++;
        return;
    }
    insert(trie[u][x&1],x>>1,dep+1);
    weihu(u);
}
int merge(int a,int b){
    if(!a)return b;
    if(!b)return a;
    w[a]+=w[b];
    sum[a]^=sum[b];
    trie[a][0]=merge(trie[a][0],trie[b][0]);
    trie[a][1]=merge(trie[a][1],trie[b][1]);
    return a;
}
void addall(int u){
    swap(trie[u][0],trie[u][1]);
    if(trie[u][0])addall(trie[u][0]);
    weihu(u);
}
int dfs(int u){
    for(int i=head[u];i;i=e[i].next)rt[u]=merge(rt[u],dfs(e[i].v));
    addall(rt[u]);
    insert(rt[u],a[u],0);
    ans+=sum[rt[u]];
    return rt[u];
}
signed main(){
    scanf("%lld",&n);
    for(int i=1;i<=n;i++){
        rt[i]=++tot;
        scanf("%lld",&a[i]);
    }
    for(int i=2,fa;i<=n;i++){
        scanf("%lld",&fa);
        add(fa,i);
    }
    dfs(1);
    printf("%lld",ans);
}