我们还从一个非常经典的题目出发,最长公共子串问题。给定两个字符串 S 和 T,求 S 和 T 的最长公共子串的长度。比如 abcdefg 和 abacabca 的最长公共子串是 abc
这是一道经典的动态规划问题,大致思路就是用 fi 表示同时以 S [i] 和 T [j] 结尾的最长公共子串的长度。下面看伪代码:
- Ans = 0
- For i = 1...S.len
- For j = 1...T.len
- If S[i] == T[i]
- F[i][j] = F[i - 1][j - 1] + 1
- If F[i][j] > Ans
- Ans = F[i][j]
- Print Ans
DP 的时间复杂度是 O (S.len * T.len),但其实这道题利用后缀自动机,时间复杂度只到 O (S.len + T.len),下图就是字符串”aabbabd” 的后缀自动机:后缀自动机就是能接受并且只接受 S 的后缀字符串。也就是说,你可以发现对于 S 的后缀,我们都可以从 S 出发沿着字符标示的路径(蓝色实线)转移,最终到达终结状态,也就是红色的节点
例如 "bd" 对应的路径是 S59,"abd" 对应的路径是 S189,"abbabd" 对应的路径是 S184679。而对于不是 S 后缀的字符串,你会发现从 S 出发,最后会到达非终结状态或者 “无路可走”。特别的,对于 S 的子串,最终会到达一个合法状态。例如 "abba" 路径是 S1846,"bbab" 路径是 S5467。而对于其他不是 S 子串的字符串,最终会 “无路可走”。 例如 "aba" 对应 S18X,"aaba" 对应 S123X。(X 表示没有转移匹配该字符)
首先我们先介绍一个概念:子串的结束位置集合 endpos。对于 S 的一个子串 t,endpos (t) = t 在 S 中所有出现的结束位置集合。还是以 S="aabbabd" 为例,endpos ("ab") = {3, 6},因为 "ab" 一共出现了 2 次,结束位置分别是 3 和 6。同理 endpos ("a") = {1, 2, 5}, endpos ("abba") = {5}
我们把 S 的所有子串的 endpos 都求出来。如果两个子串的 endpos 相等,就把这两个子串归为一类。最终这些 endpos 的等价类就构成的 SAM 的状态集合。例如对于 S="aabbabd":
其中 maxlen 是一个状态包含的子串中,最长子串的长度。有了后缀自动机和每个状态的 maxlen,我们就能求解 S 和 T 的最长公共子串了。具体做法是先求出 S 的后缀自动机,然后用 T 的每一个字符在 S 的后缀自动机上跑一遍。这里跑一遍的意思就是从初始状态开始,根据每一个字符 T [i] 在自动机的不同状态之间转移
举个例子,假设 S=aabbabd,S 的后缀自动机就是一开始的那张图
T=abbbaabbab。我们要找 S 和 T 的最长公共子串,就从状态 S 开始匹配,用 u 表示当前的状态,l 表示当前匹配的长度。基本思路就是如果当前状态 u 没有标识 T [i] 的蓝色实线转移,我们就沿着绿色虚线向上找,不断令 u 等于 u 的绿色虚线指向的状态,同时令 l = u.maxlen,直到 u 存在标识 T [i] 的转移或者 u 走到头了
如果最后 u 走到头了,我们就令 u=S,l=0,从 T [i+1] 开始重新匹配。否则,我们就沿着蓝色实线转移,令 u 等于 u 的蓝色实线指向的状态,同时令 l=l+1。然后再从 T [i+1] 继续匹配
具体来说,字符串 T 的完整匹配过程如下:
一开始 u=S,l=0,我们要匹配 T [1]=a,刚好 u 有标识 a 的边,所以我们直接移动到 u=1, l=1
现在 u=1, l=1,我们要匹配 T [2]=b,刚好 u 有标识 b 的边,所以我们直接移动到 u=8, l=2
现在 u=8, l=2, 我们要匹配 T [3]=b,刚好 u 有标识 b 的边,所以我们直接移动到 u=4, l=3
现在 u=4, l=3, 我们要匹配 T [4]=b,u 没有标识 b 的边,所以沿绿色退到 u=5, l=maxlen [5]=1
现在 u=5, l=1, 我们要匹配 T [4]=b,u 有标识 a 的边,所以我们移动到 u=4, l=2
现在 u=4, l=2, 我们要匹配 T [5]=a,u 有标识 a 的边,所以我们移动到 u=6, l=3
现在 u=6, l=3, 我们要匹配 T [6]=a,u 没有标识 a 的边,所以沿绿色退到 u=1, l=maxlen [1]=1
现在 u=1, l=1, 我们要匹配 T [6]=a,u 有标识 a 的边,所以我们移动到 u=2, l=2
现在 u=2, l=2, 我们要匹配 T [7]=b,u 有标识 b 的边,所以我们移动到 u=3, l=3
现在 u=3, l=3, 我们要匹配 T [8]=b,u 有标识 b 的边,所以我们移动到 u=4, l=4
现在 u=4, l=4, 我们要匹配 T [9]=a,u 有标识 a 的边,所以我们移动到 u=6, l=5
现在 u=6, l=5, 我们要匹配 T [10]=b,u 有标识 b 的边,所以我们移动到 u=7, l=6
这样 T 就匹配完了,在匹配过程中,l 的最大值就是最长公共子串的长度,也就是 6。实际上匹配 T [10] 的时候 l 等于 6,也意味着最长公共子串是 T [5]~T [10] 即:aabbab;同时 u=7 也意味着最长公共子串是状态 7 中长度为 6 的子串,从之前的表格中我们知道也是 aabbab
对于字符串 S,构建 S 的后缀自动机的复杂度是 O (S.len) 的;之后对 T 的每个字母跑一遍匹配,复杂度是 O (T.len)。所以总的复杂度是 O (S.len+T.len)。最后我们给出用后缀自动机求最长公共子串的代码:
- #include<iostream>
- #include<string>
- using namespace std;
- const int MAXL = 1000000;
- string s, t;
- int n = 0, len, st;
- int maxlen[2 * MAXL + 10], minlen[2 * MAXL + 10], trans[2 * MAXL + 10][26], slink[2 * MAXL + 10];
- int new_state(int _maxlen, int _minlen, int* _trans, int _slink) {
- maxlen[n] = _maxlen;
- minlen[n] = _minlen;
- for(int i = 0; i < 26; i++) {
- if(_trans == NULL)
- trans[n][i] = -1;
- else
- trans[n][i] = _trans[i];
- }
- slink[n] = _slink;
- return n++;
- }
- int add_char(char ch, int u) {
- int c = ch - 'a';
- int z = new_state(maxlen[u] + 1, -1, NULL, -1);
- int v = u;
- while(v != -1 && trans[v][c] == -1) {
- trans[v][c] = z;
- v = slink[v];
- // cout << v << endl;
- }
- // cout << z << endl;
- if(v == -1) {
- minlen[z] = 1;
- slink[z] = 0;
- return z;
- }
- int x = trans[v][c];
- if(maxlen[v] + 1 == maxlen[x]) {
- minlen[z] = maxlen[x] + 1;
- slink[z] = x;
- return z;
- }
- int y = new_state(maxlen[v] + 1, -1, trans[x], slink[x]);
- slink[y] = slink[x];
- minlen[x] = maxlen[y] + 1;
- slink[x] = y;
- minlen[z] = maxlen[y] + 1;
- slink[z] = y;
- int w = v;
- while(w != -1 && trans[w][c] == x) {
- trans[w][c] = y;
- w = slink[w];
- }
- minlen[y] = maxlen[slink[y]] + 1;
- return z;
- }
- int main()
- {
- cin >> s >> t;
- len = s.size();
- st = new_state(0, 0, NULL, -1);
- // cout << trans[st][0] << endl;
- int u = st;
- for(int i = 0; i < len; i++) {
- int temp = add_char(s[i], u);
- u = temp;
- }
- u = st;
- int ans = 0, cur = 0;
- for(int i = 0; i < t.size(); i++) {
- int c = t[i] - 'a';
- while(u != -1 && trans[u][c] == -1) {
- u = slink[u];
- cur = maxlen[u];
- }
- if(u == -1) {
- u = 0;
- cur = 0;
- continue;
- }
- u = trans[u][c];
- cur++;
- if(cur > ans) ans = cur;
- }
- cout << ans << endl;
- return 0;
- }