MENU

回文自动机、AC 自动机和后缀自动机介绍(2)

July 9, 2018 • Read: 3334 • 算法阅读设置

AC 自动机

AC 自动机,有的地方也叫 Trie 图,可以用来解决多串匹配的问题

多串匹配是这样一个问题:给定 N 个敏感词 W1, W2, W3, … WN,然后对于一个字符串 S,判断 S 中存在不存在任意敏感词

多串匹配比较常见的算法时间复杂度都比较高。比如我们可以用多次 KMP,伪代码如下:

  • Match(W[],S):
  • For i = 1...N
  • If KMP(W[i],S)
  • Return Found
  • Retrun NotFound

这个算法复杂度是 $O (NS.len + \sum W [i].len)$。再比如我们可以用 Trie,先把 W [] 都插入到 Trie 中,再看 S [1..n], S [2..n], S [3..n], … S [n..n](其中 n=S.len)在 Trie 上匹配,能不能到达终结点,如果能到终结点,说明 S 中有敏感词,否则说明没有。伪代码如下:

  • Match(W[],S):
  • For i = 1...N
  • Trie.insert(W[i])
  • For i = 1...S.len
  • P = Trie.Root
  • For j = i...S.len
  • If !P.thru(S[j])
  • Break
  • P = P.thru(S[j])
  • If P是终结点
  • Return Found

上面这个算法的复杂度是 $O (S.len^2 + \sum W [i].len)$ 的。其实多串匹配也有更优秀的算法,需要利用 AC 自动机或者叫 Trie 图这个工具。首先我们看一下 Trie 图是什么样子。假设 W []=["aa", "abb", "bb", "bba"],那么用 W [] 构造的 Trie 是这个样子:
而用 W [] 构造的 Trie 图是这个样子:
上面这个图看着有点复杂。不过你仔细看发现图中的边有直的也有弯的。如果只看直的边,那么就是和上一张图的 Trie 树是完全一样的。所以 Trie 图实际上可以看成是在 Trie 树上增加了一些边。那增加的这些边究竟怎么来的呢?其实增加的这些边有点像 next 数组,它告诉我们如果当前状态 u 不存在标识为字符 c 的边,那么我们要移动到哪个节点去。比如之前在 trie 中,4 号节点没有标识是’a’的边,于是 Trie 图就给 4 号节点增加了标识是’a’的边,并且指向 1 号节点

Trie 图的构造方法这里就不仔细讲了,大家有兴趣的话,可以看 hihoCoder#1036 题里的讲解。此外在网上也有很多讲义,大家可以通过搜 AC 自动机或者 Trie 图、多串匹配等关键字搜到,然后对照着看一看

我们简单介绍一下如何用 Trie 图进行多串匹配。实际上我们只要从根节点开始,拿字符串 S 在 Trie 图上跑就可以了。如果经过了任意一个终结点,就说明 S 中包含敏感词。伪代码如下:

  • Search(S)
  • P = Trie.Root
  • For i = 1...S.len
  • P = P.thru(S[i])//对于Trie图来说,边一定存在
  • If p是终结点
  • Return Found
  • Return NotFound

比如 S="abab",经过的 Trie 图路径是:01414,没有终结点,所以 S 不包含敏感词;再比如 S="babb",经过的 Trie 图路径是:02146,有终结点 6,所以包含对应的敏感词 abb

例 1 hihoCoder1440

这道题的大意就是敏感词过滤,给你 N 个敏感词和一个字符串 S。让你把 S 中只要是出现敏感词的地方都替换成 *。比如敏感词是 abc 和 cd,那么 abcxyzabcd 经过过滤之后就是 ***xyz****

  • #include <iostream>
  • #include <string>
  • #include <algorithm>
  • #include <iomanip>
  • #include <vector>
  • #include <map>
  • #include <stdio.h>
  • #include <math.h>
  • #include <string.h>
  • #include <queue>
  • #include <set>
  • using namespace std;
  • /* --------------------------------- */
  • #define ios ios_base::sync_with_stdio(false)
  • const int MAXINT = 2147483647;
  • int cnt = 0;
  • class TrNode {
  • public:
  • int no;
  • int value;
  • vector<TrNode *> next;
  • TrNode *pre;
  • TrNode(int s) {
  • value = 0;
  • next = vector<TrNode *>(s, NULL);
  • no = ++cnt;
  • }
  • void addStr(string str) {
  • TrNode *rt = this;
  • for (int i = 0; i < str.length(); i++) {
  • if (rt -> next[str[i] - 'a'] == NULL) {
  • rt -> next[str[i] - 'a'] = new TrNode(26);
  • }
  • rt = rt -> next[str[i] - 'a'];
  • }
  • rt -> value = max(rt -> value, (int)str.length());
  • }
  • void buildGraph() {
  • queue<TrNode *> Q;
  • pre = this;
  • for (int i = 0; i < next.size(); i++) {
  • if (next[i] != NULL) {
  • Q.push(next[i]);
  • next[i] -> pre = this;
  • }
  • else next[i] = this;
  • }
  • while (!Q.empty()) {
  • TrNode *rt = Q.front(); Q.pop();
  • for (int i = 0; i < rt -> next.size(); i ++) {
  • if (rt -> next[i] != NULL) {
  • Q.push(rt -> next[i]);
  • rt -> next[i] -> pre = rt -> pre -> next[i];
  • rt -> next[i] -> value = max(rt -> next[i] -> value, rt -> next[i] -> pre -> value);
  • }
  • else rt -> next[i] = rt -> pre -> next[i];
  • }
  • }
  • }
  • void print() {
  • set<int> s; s.insert(no);
  • queue<TrNode *> Q; Q.push(this);
  • while (!Q.empty()) {
  • TrNode *rt = Q.front(); Q.pop();
  • for (int i = 0; i < rt -> next.size(); i ++) {
  • if (rt -> next[i] != NULL) {
  • cout << rt -> no << ' ' << (char)('a' + i) << ' ' << rt -> next[i] -> no << endl;
  • if (s.find(rt -> next[i] -> no) == s.end()) {
  • Q.push(rt -> next[i]);
  • s.insert(rt -> next[i] -> no);
  • }
  • }
  • }
  • }
  • }
  • };
  • /* --------------------------------- */
  • int cmp(pair<int, int> a, pair<int, int> b) {
  • return a.first < b.first;
  • }
  • int main() {
  • ios;
  • int n;
  • cin >> n;
  • vector<string> keys(n);
  • for (int i = 0; i < n; i++) cin >> keys[i];
  • TrNode *root = new TrNode(26);
  • for (int i = 0; i < n; i++) root -> addStr(keys[i]);
  • //root -> print();
  • root -> buildGraph();
  • //root -> print();
  • vector<pair<int, int> > c;
  • string str;
  • cin >> str;
  • for (int i = 0; i < str.length(); i++) {
  • root = root -> next[str[i] - 'a'];
  • if (root -> value > 0) {
  • c.push_back(make_pair(i - root -> value + 1, i));
  • }
  • }
  • sort(c.begin(), c.end(), cmp);
  • int j = -1, k = 0;
  • for (int i = 0; i < str.length(); i++) {
  • while (k < c.size() && c[k].first <= i) {
  • j = max(j, c[k].second);
  • k++;
  • }
  • if (i > j) cout << str[i];
  • else cout << '*';
  • }
  • return 0;
  • }

回文自动机

最后我们再介绍一个叫做回文自动机或者叫回文树的东西。比如对于 S=”abbaabba”,构建的回文树或者回文自动机是这个样子:
回文自动机有 2 个初始节点,0 和 1,分别代表长度是偶数的回文串起点和长度是奇数的回文串起点。每个节点代表了 S 的一个回文子串。比如 2 号节点代表 a,3 号节点代表 b,4 号节点代表 bb,6 号节点代表 aa。节点 u 连出一条标记字符 c 的边,指向 v,表示 v 代表的回文串可以通过在 u 代表的回文串左右各加上字符 c 得到。比如 7 号代表 baab,就是在 6 号 aa 的左右各加上一个 'b' 得到的

回文自动机可以解决 (1) 最长回文子串问题。这个问题一般用 Manacher 算法解决,并且 Manacher 足够优秀。但其实也可以用回文自动机解决。回文自动机中最深的节点就代表最长的回文子串。比如上图中 9 号节点,代表 abbaabba

此外回文自动机还可以解决 (2) S 中本质不同的回文子串数目。比如对于 S="abbaabba",本质不同的回文子串有 a, b, aa, bb, abba, baab, bbaabb, abbaabba 8 个。实际上就是回文自动机中除去 01 之外的剩余节点数目

对于字符串 S,构造回文自动机有 $O (S.len \times log (字符集大小))$ 的算法。大家有兴趣的话可以在网上找到资料

Last Modified: May 12, 2021
Archives Tip
QR Code for this page
Tipping QR Code