AC 自动机
AC 自动机,有的地方也叫 Trie 图,可以用来解决多串匹配的问题
多串匹配是这样一个问题:给定 N 个敏感词 W1, W2, W3, … WN,然后对于一个字符串 S,判断 S 中存在不存在任意敏感词
多串匹配比较常见的算法时间复杂度都比较高。比如我们可以用多次 KMP,伪代码如下:
- Match(W[],S):
- For i = 1...N
- If KMP(W[i],S)
- Return Found
- Retrun NotFound
这个算法复杂度是 $O (NS.len + \sum W [i].len)$。再比如我们可以用 Trie,先把 W [] 都插入到 Trie 中,再看 S [1..n], S [2..n], S [3..n], … S [n..n](其中 n=S.len)在 Trie 上匹配,能不能到达终结点,如果能到终结点,说明 S 中有敏感词,否则说明没有。伪代码如下:
- Match(W[],S):
- For i = 1...N
- Trie.insert(W[i])
- For i = 1...S.len
- P = Trie.Root
- For j = i...S.len
- If !P.thru(S[j])
- Break
- P = P.thru(S[j])
- If P是终结点
- Return Found
上面这个算法的复杂度是 $O (S.len^2 + \sum W [i].len)$ 的。其实多串匹配也有更优秀的算法,需要利用 AC 自动机或者叫 Trie 图这个工具。首先我们看一下 Trie 图是什么样子。假设 W []=["aa", "abb", "bb", "bba"],那么用 W [] 构造的 Trie 是这个样子:而用 W [] 构造的 Trie 图是这个样子:
上面这个图看着有点复杂。不过你仔细看发现图中的边有直的也有弯的。如果只看直的边,那么就是和上一张图的 Trie 树是完全一样的。所以 Trie 图实际上可以看成是在 Trie 树上增加了一些边。那增加的这些边究竟怎么来的呢?其实增加的这些边有点像 next 数组,它告诉我们如果当前状态 u 不存在标识为字符 c 的边,那么我们要移动到哪个节点去。比如之前在 trie 中,4 号节点没有标识是’a’的边,于是 Trie 图就给 4 号节点增加了标识是’a’的边,并且指向 1 号节点
Trie 图的构造方法这里就不仔细讲了,大家有兴趣的话,可以看 hihoCoder#1036 题里的讲解。此外在网上也有很多讲义,大家可以通过搜 AC 自动机或者 Trie 图、多串匹配等关键字搜到,然后对照着看一看
我们简单介绍一下如何用 Trie 图进行多串匹配。实际上我们只要从根节点开始,拿字符串 S 在 Trie 图上跑就可以了。如果经过了任意一个终结点,就说明 S 中包含敏感词。伪代码如下:
- Search(S)
- P = Trie.Root
- For i = 1...S.len
- P = P.thru(S[i])//对于Trie图来说,边一定存在
- If p是终结点
- Return Found
- Return NotFound
比如 S="abab",经过的 Trie 图路径是:01414,没有终结点,所以 S 不包含敏感词;再比如 S="babb",经过的 Trie 图路径是:02146,有终结点 6,所以包含对应的敏感词 abb
例 1 hihoCoder1440
这道题的大意就是敏感词过滤,给你 N 个敏感词和一个字符串 S。让你把 S 中只要是出现敏感词的地方都替换成 *。比如敏感词是 abc 和 cd,那么 abcxyzabcd 经过过滤之后就是 ***xyz****
- #include <iostream>
- #include <string>
- #include <algorithm>
- #include <iomanip>
- #include <vector>
- #include <map>
- #include <stdio.h>
- #include <math.h>
- #include <string.h>
- #include <queue>
- #include <set>
- using namespace std;
-
- /* --------------------------------- */
-
- #define ios ios_base::sync_with_stdio(false)
-
- const int MAXINT = 2147483647;
- int cnt = 0;
- class TrNode {
- public:
- int no;
- int value;
- vector<TrNode *> next;
- TrNode *pre;
- TrNode(int s) {
- value = 0;
- next = vector<TrNode *>(s, NULL);
- no = ++cnt;
- }
- void addStr(string str) {
- TrNode *rt = this;
- for (int i = 0; i < str.length(); i++) {
- if (rt -> next[str[i] - 'a'] == NULL) {
- rt -> next[str[i] - 'a'] = new TrNode(26);
- }
- rt = rt -> next[str[i] - 'a'];
- }
- rt -> value = max(rt -> value, (int)str.length());
- }
- void buildGraph() {
- queue<TrNode *> Q;
- pre = this;
- for (int i = 0; i < next.size(); i++) {
- if (next[i] != NULL) {
- Q.push(next[i]);
- next[i] -> pre = this;
- }
- else next[i] = this;
- }
- while (!Q.empty()) {
- TrNode *rt = Q.front(); Q.pop();
- for (int i = 0; i < rt -> next.size(); i ++) {
- if (rt -> next[i] != NULL) {
- Q.push(rt -> next[i]);
- rt -> next[i] -> pre = rt -> pre -> next[i];
- rt -> next[i] -> value = max(rt -> next[i] -> value, rt -> next[i] -> pre -> value);
- }
- else rt -> next[i] = rt -> pre -> next[i];
- }
- }
- }
- void print() {
- set<int> s; s.insert(no);
- queue<TrNode *> Q; Q.push(this);
- while (!Q.empty()) {
- TrNode *rt = Q.front(); Q.pop();
- for (int i = 0; i < rt -> next.size(); i ++) {
- if (rt -> next[i] != NULL) {
- cout << rt -> no << ' ' << (char)('a' + i) << ' ' << rt -> next[i] -> no << endl;
- if (s.find(rt -> next[i] -> no) == s.end()) {
- Q.push(rt -> next[i]);
- s.insert(rt -> next[i] -> no);
- }
- }
- }
- }
- }
- };
-
- /* --------------------------------- */
-
- int cmp(pair<int, int> a, pair<int, int> b) {
- return a.first < b.first;
- }
-
- int main() {
- ios;
-
- int n;
- cin >> n;
- vector<string> keys(n);
- for (int i = 0; i < n; i++) cin >> keys[i];
- TrNode *root = new TrNode(26);
- for (int i = 0; i < n; i++) root -> addStr(keys[i]);
- //root -> print();
- root -> buildGraph();
- //root -> print();
- vector<pair<int, int> > c;
- string str;
- cin >> str;
- for (int i = 0; i < str.length(); i++) {
- root = root -> next[str[i] - 'a'];
- if (root -> value > 0) {
- c.push_back(make_pair(i - root -> value + 1, i));
- }
- }
- sort(c.begin(), c.end(), cmp);
- int j = -1, k = 0;
- for (int i = 0; i < str.length(); i++) {
- while (k < c.size() && c[k].first <= i) {
- j = max(j, c[k].second);
- k++;
- }
- if (i > j) cout << str[i];
- else cout << '*';
- }
-
- return 0;
- }
回文自动机
最后我们再介绍一个叫做回文自动机或者叫回文树的东西。比如对于 S=”abbaabba”,构建的回文树或者回文自动机是这个样子:回文自动机有 2 个初始节点,0 和 1,分别代表长度是偶数的回文串起点和长度是奇数的回文串起点。每个节点代表了 S 的一个回文子串。比如 2 号节点代表 a,3 号节点代表 b,4 号节点代表 bb,6 号节点代表 aa。节点 u 连出一条标记字符 c 的边,指向 v,表示 v 代表的回文串可以通过在 u 代表的回文串左右各加上字符 c 得到。比如 7 号代表 baab,就是在 6 号 aa 的左右各加上一个 'b' 得到的
回文自动机可以解决 (1) 最长回文子串问题。这个问题一般用 Manacher 算法解决,并且 Manacher 足够优秀。但其实也可以用回文自动机解决。回文自动机中最深的节点就代表最长的回文子串。比如上图中 9 号节点,代表 abbaabba
此外回文自动机还可以解决 (2) S 中本质不同的回文子串数目。比如对于 S="abbaabba",本质不同的回文子串有 a, b, aa, bb, abba, baab, bbaabb, abbaabba 8 个。实际上就是回文自动机中除去 01 之外的剩余节点数目
对于字符串 S,构造回文自动机有 $O (S.len \times log (字符集大小))$ 的算法。大家有兴趣的话可以在网上找到资料