实现TRIE结构
第一种方法是用一个二维数组来存储:
int trie[MAX_NODE][CHARSET];
int k;
其中MAX_NODE是trie中最大能存储的节点数目,CHARSET是字符集的大小,k是当前trie中包含有多少个节点。Trie[i][j]的值是0表示trie树中i号节点,并没有一条连出去的边,满足边上的字符标识是字符集中第j个字符(从0开始);trie[i][j]的值是正整数x表示trie树中i号节点,有一条连出去的边,满足边上的字符标识是字符集中第j个字符,并且这条边的终点是x号节点
举个例子,下图中左边是trie树,右边是二维数组trie中非0的值
用二维数组实现trie的好处是用起来非常方便,因为trie的insert和search操作都要经常判断一个节点有没有标识某个字符的边,以及边的终点是几号节点。用二维数组的话,我们只要看相应的trie[i][j]的值即可。用二维数组的缺点是可能会浪费很多空间,因为我们对每一个节点都用了一个字符集大小的数组存储子节点号,但实际上每个点连出去的边很稀疏。比如上图中实际上每个点只有0-2个子节点,但是我们给每个节点都开辟了一个大小是26的数组去存储子节点
第二种方法是像邻接表一样用vector来存储每个节点的子节点:
vector<pair<char, int> > trie[MAX_NODE];
int k;
我们把i号节点的子节点和边的信息存在trie[i]里。每一个pair<char, int>其实是(边的标识,子节点号)这样的二元组:
用vector的好处是可以节省很多空间,起码理论上空间复杂度是O(N)的,其中N是节点总数。缺点是每次我们想找i号节点有没有标识是某个字符ch的边时,都需要遍历一遍trie[i]这个vector,而不能像数组一样直接查找
第三种方法是用unordered_map
unordered_map<char, int> trie[MAX_NODE];
int k;
使用unordered_map看上去是一个两全其美的做法。每次我们想找i号节点有没有标识是某个字符ch的边时,只要看trie[i][ch]的值即可。同时理论上也不需要每个节点都占用CHARSERT大小的空间去存储子节点,而是有几个子节点就用到几个子节点的空间。理论上时空复杂度都是O(N)
我们采用数组的方式来存储,下面分别看一下insert和search方法
#include <iostream>
#include <cstring>
using namespace std;
const int MAX_NODE = 1000000 + 10;
const int CHARSET = 26;
int trie[MAX_NODE][CHARSET] = {0};
int color[MAX_NODE0] = {0};
int k = 1;
void insert(char *w) {
int len = strlen(w);
int p = 0;
for(int i = 0;i < len;i++) {
int c = w[i] - 'a';
if(!trie[p][c]) {
trie[p][c] = k;
k++;
}
p = trie[p][c];
}
color[p] = 1;
}
代码的第6~8行,一开始trie[][]被初始化为0,保证每个节点被创建出来时,都没有子节点。K初始化为1表示一开始只有1个节点,也就是0号节点根节点。Color是用来标记一个节点是不是终结点。Color[i]=1标识i号节点是终结点
第9~21行是插入函数insert(w),w是字符指针,实际上可以看作是一个字符串
第11行是p从0号节点开始。第12~19行是依次插入w的每一个字符。第13行是计算w[i]是字符集第几个字符,这里我们假设字符集只包含26个小写字母。第14~17行是如果p没有连出标识是w[i]的边,那么就创建一个。这里新创建的节点一定就是k号节点。所谓创建新节点实际上也没什么可创建的,新节点就是个编号。所以我们直接令triei=k即可,然后将k累加1,整个创建过程就完成了。第18行是沿着标记着w[i]的边移动到下一个节点。最后第20行,是将最后到达的节点p标记为终结点
有了插入,我们再看查找代码如何实现:
int search(char *s) {
int len = strlen(s);
int p = 0;
for(int i = 0;i < len;i++) {
int c = s[i] - 'a';
if(!trie[p][c]) return 0;
p = trie[p][c];
}
return color[p] == 1;
}
查找的代码就比较简单直观,search返回0表示trie中没有s,返回1表示有s。第3行是从p=0也就是根节点开始。第4~8行是枚举s的每一个字符。第5行是计算当前字符s[i]在字符集的序号。第6行是判断p节点有没有连出标识s[i]字符的边,如果没有,说明现在无路可走,直接返回0;如果有的话,第7行就是移动到下一个节点。如果整个循环结束还没有return 0,那就说明成功沿着s的每一个字符到达了p节点。这时只要判断p节点是不是终结点即可,也就是第9行的代码
例1 题目链接:hihoCoder1014
这道题目的大意是给定一个包含N个字符串的集合,然后再给出M个询问。每次询问给出一个字符串s,要求你回答集合中有几个字符串的前缀是s
比如集合是{ babaab, babbbaaaa, abba, aaaaabaa, babaabab}询问前缀是bab,答案就是3。因为有babaab, babbbaaaa, babaabab 三个字符串前缀是bab
这道题是一道很经典的用Trie解决的题目。首先我们把集合中的N个字符串都插入到trie中。对于每一个查询s我们在trie中查找s,如果查找过程中无路可走,那么一定没有以s为前缀的字符串。如果最后停在一个节点p,那我们就要看看以p为根的子树里一共有多少终结点。终结点的数目就是答案
比如在上图中,询问是s=”in”。我们会停在2号节点上。这棵子树一共有3个终结点,对应着3个字符串”in”, “inn”, “int”,这三个字符串前缀都是”in”,所以答案是3
但是如果我们每次都遍历以P为根的子树,那时间复杂度就太高了。解决的办法是用空间换时间,我们增加一个数组int cnt[MAX_NODE],cnt[i]记录的是以i号节点为根的子树中,有几个终结点。然后我们每次insert一个字符串的时候,顺便就把沿途的节点的cnt值都+1。这样就不用每次遍历以P为根的子树,而是直接输出cnt[P]即可
#include <iostream>
#include <cstring>
using namespace std;
const int MAX_NODE = 1000000 + 10;
const int CHARSET = 26;
int trie[MAX_NODE][CHARSET] = {0};
int cnt[MAX_NODE] = {0};
int n,m,k = 1;
char s[20];
void insert(char *w) {
int len = strlen(w);
int p = 0;
for(int i = 0;i < len;i++) {
int c = w[i] - 'a';
if(!trie[p][c]) {
trie[p][c] = k;
k++;
}
p = trie[p][c];
cnt[p]++;
}
}
int search(char *s) {
int len = strlen(s);
int p = 0;
for(int i = 0;i < len;i++) {
int c = s[i] - 'a';
if(!trie[p][c]) return 0;
p = trie[p][c];
}
return cnt[p];
}
int main() {
memset(trie,0,sizeof(trie));
scanf("%d",&n);
for(int i = 0;i < n;i++) {
scanf("%s",s);
insert(s);
}
scanf("%d",&m);
while(m--) {
scanf("%s",s);
int ans = search(s);
printf("%d\n",ans);
}
return 0;
}