MENU

回文自动机、AC自动机和后缀自动机介绍(2)

July 9, 2018 • Read: 2827 • 算法阅读设置

AC自动机

AC自动机,有的地方也叫Trie图,可以用来解决多串匹配的问题

多串匹配是这样一个问题:给定N个敏感词W1, W2, W3, … WN,然后对于一个字符串S,判断S中存在不存在任意敏感词

多串匹配比较常见的算法时间复杂度都比较高。比如我们可以用多次KMP,伪代码如下:

Match(W[],S):
    For i = 1...N
        If KMP(W[i],S)
            Return Found
Retrun NotFound

这个算法复杂度是$O(NS.len + \sum W[i].len)$。再比如我们可以用Trie,先把W[]都插入到Trie中,再看S[1..n], S[2..n], S[3..n], … S[n..n](其中n=S.len)在Trie上匹配,能不能到达终结点,如果能到终结点,说明S中有敏感词,否则说明没有。伪代码如下:

Match(W[],S):
    For i = 1...N
        Trie.insert(W[i])
    For i = 1...S.len
        P = Trie.Root
        For j = i...S.len
            If !P.thru(S[j])
                Break
            P = P.thru(S[j])
            If P是终结点
                Return Found

上面这个算法的复杂度是$O(S.len^2 + \sum W[i].len)$的。其实多串匹配也有更优秀的算法,需要利用AC自动机或者叫Trie图这个工具。首先我们看一下Trie图是什么样子。假设W[]=["aa", "abb", "bb", "bba"],那么用W[]构造的Trie是这个样子:

而用W[]构造的Trie图是这个样子:

上面这个图看着有点复杂。不过你仔细看发现图中的边有直的也有弯的。如果只看直的边,那么就是和上一张图的Trie树是完全一样的。所以Trie图实际上可以看成是在Trie树上增加了一些边。那增加的这些边究竟怎么来的呢?其实增加的这些边有点像next数组,它告诉我们如果当前状态u不存在标识为字符c的边,那么我们要移动到哪个节点去。比如之前在trie中,4号节点没有标识是’a’的边,于是Trie图就给4号节点增加了标识是’a’的边,并且指向1号节点

Trie图的构造方法这里就不仔细讲了,大家有兴趣的话,可以看hihoCoder#1036题里的讲解。此外在网上也有很多讲义,大家可以通过搜AC自动机或者Trie图、多串匹配等关键字搜到,然后对照着看一看

我们简单介绍一下如何用Trie图进行多串匹配。实际上我们只要从根节点开始,拿字符串S在Trie图上跑就可以了。如果经过了任意一个终结点,就说明S中包含敏感词。伪代码如下:

Search(S)
    P = Trie.Root
    For i = 1...S.len
        P = P.thru(S[i])//对于Trie图来说,边一定存在
        If p是终结点
            Return Found
Return NotFound

比如S="abab",经过的Trie图路径是:01414,没有终结点,所以S不包含敏感词;再比如S="babb",经过的Trie图路径是:02146,有终结点6,所以包含对应的敏感词abb

例1 hihoCoder1440


这道题的大意就是敏感词过滤,给你N个敏感词和一个字符串S。让你把S中只要是出现敏感词的地方都替换成*。比如敏感词是abc和cd,那么abcxyzabcd经过过滤之后就是***xyz****

#include <iostream>
#include <string>
#include <algorithm>
#include <iomanip>
#include <vector>
#include <map>
#include <stdio.h>
#include <math.h>
#include <string.h>
#include <queue>
#include <set>
using namespace std;

/* --------------------------------- */
 
#define ios ios_base::sync_with_stdio(false)

const int MAXINT = 2147483647;
int cnt = 0;
class TrNode {
public:
    int no;
    int value;
    vector<TrNode *> next;
    TrNode *pre;
    TrNode(int s) {
        value = 0;
        next = vector<TrNode *>(s, NULL);
        no = ++cnt;
    }
    void addStr(string str) {
        TrNode *rt = this;
        for (int i = 0; i < str.length(); i++) {
            if (rt -> next[str[i] - 'a'] == NULL) {
                rt -> next[str[i] - 'a'] = new TrNode(26);    
            }
            rt = rt -> next[str[i] - 'a']; 
        }
        rt -> value = max(rt -> value, (int)str.length());
    }
    void buildGraph() {
        queue<TrNode *> Q;
        pre = this;
        for (int i = 0; i < next.size(); i++) {
            if (next[i] != NULL) {
                Q.push(next[i]);
                next[i] -> pre = this;
            }
            else next[i] = this;
        }
        while (!Q.empty()) {
            TrNode *rt = Q.front(); Q.pop();
            for (int i = 0; i < rt -> next.size(); i ++) {
                if (rt -> next[i] != NULL) {
                    Q.push(rt -> next[i]);
                    rt -> next[i] -> pre = rt -> pre -> next[i];
                    rt -> next[i] -> value = max(rt -> next[i] -> value, rt -> next[i] -> pre -> value);
                }
                else rt -> next[i] = rt -> pre -> next[i];
            }
        }
    }
    void print() {
        set<int> s; s.insert(no);
        queue<TrNode *> Q; Q.push(this);
        while (!Q.empty()) {
            TrNode *rt = Q.front(); Q.pop();
            for (int i = 0; i < rt -> next.size(); i ++) {
                if (rt -> next[i] != NULL) {
                    cout << rt -> no << ' ' << (char)('a' + i) << ' ' << rt -> next[i] -> no << endl;
                    if (s.find(rt -> next[i] -> no) == s.end()) {
                        Q.push(rt -> next[i]);
                        s.insert(rt -> next[i] -> no);
                    }
                }
            }
        }
    }
};

/* --------------------------------- */

int cmp(pair<int, int> a, pair<int, int> b) {
    return a.first < b.first;
}

int main() {
    ios;
    
    int n;
    cin >> n;
    vector<string> keys(n);
    for (int i = 0; i < n; i++) cin >> keys[i];
    TrNode *root = new TrNode(26);
    for (int i = 0; i < n; i++) root -> addStr(keys[i]);
    //root -> print();
    root -> buildGraph();
      //root -> print();
      vector<pair<int, int> > c;
      string str;
      cin >> str;
    for (int i = 0; i < str.length(); i++) {
        root = root -> next[str[i] - 'a'];
        if (root -> value > 0) {
            c.push_back(make_pair(i - root -> value + 1, i));
        }
    }
    sort(c.begin(), c.end(), cmp);
    int j = -1, k = 0;
    for (int i = 0; i < str.length(); i++) {
        while (k < c.size() && c[k].first <= i) {
            j = max(j, c[k].second);
            k++;
        }
        if (i > j) cout << str[i];
        else cout << '*';
    }
    
    return 0;
}

回文自动机

最后我们再介绍一个叫做回文自动机或者叫回文树的东西。比如对于S=”abbaabba”,构建的回文树或者回文自动机是这个样子:

回文自动机有2个初始节点,0和1,分别代表长度是偶数的回文串起点和长度是奇数的回文串起点。每个节点代表了S的一个回文子串。比如2号节点代表a,3号节点代表b,4号节点代表bb,6号节点代表aa。节点u连出一条标记字符c的边,指向v,表示v代表的回文串可以通过在u代表的回文串左右各加上字符c得到。比如7号代表baab,就是在6号aa的左右各加上一个'b'得到的

回文自动机可以解决(1)最长回文子串问题。这个问题一般用Manacher算法解决,并且Manacher足够优秀。但其实也可以用回文自动机解决。回文自动机中最深的节点就代表最长的回文子串。比如上图中9号节点,代表abbaabba

此外回文自动机还可以解决(2)S中本质不同的回文子串数目。比如对于S="abbaabba",本质不同的回文子串有a, b, aa, bb, abba, baab, bbaabb, abbaabba 8个。实际上就是回文自动机中除去01之外的剩余节点数目

对于字符串S,构造回文自动机有$O(S.len \times log(字符集大小))$的算法。大家有兴趣的话可以在网上找到资料

Last Modified: May 12, 2021
Archives Tip
QR Code for this page
Tipping QR Code