例2 题目链接:HDOJ1711
这个题目的大意就是给你一个整数数组A=[A1, A2, … AN]和一个整数数组B=[B1, B2, … BM],让你找到一个位置K使得从A[K]开始的M个整数与[B1, B2, … BM]相等。如果没有满足条件的位置K就输出-1
这其实就是一个整数版的KMP。我们之前讲KMP的背景都是字符串匹配,其实也可以拿来做数组序列匹配。只要把整数相等对应成字符相等就可以了
#include <cstring>
#include <cstdio>
using namespace std;
int nxt[1000001];
int ans,t;
int p[1000001],s[1000001];
int n,m;
int main()
{
scanf("%d",&t);
while(t--)
{
ans = -1;
scanf("%d %d",&n,&m);
for(int i = 1;i <= n;i++)
scanf("%d",&s[i]);
for(int i = 1;i <= m;i++)
scanf("%d",&p[i]);
nxt[0] = -1;
int j = -1;
for(int i = 1;i <= m;i++)
{
while(j >= 0 && p[j + 1] != p[i])
j = nxt[j];
nxt[i] = ++j;
}
j = 0;
for(int i = 1;i <= n;i++)
{
while(j >= 0 && p[j + 1] != s[i])
j = nxt[j];
if(++j == m)
{
ans = i - m + 1;
break;
}
}
printf("%d\n",ans);
}
return 0;
}
例3 题目链接:HDOJ2087
这道题目的大意是给定两个字符串P和S。问P能从S中剪出来多少次。这道题看上去与上一节hihoCoder #1015题类似。但是其实是有一个关键条件不一样,这道题不允许剪出来P在S中重叠。比如在之前的题目中ADA在ADADADA中出现了3次,本题中就只能剪出来2次。这道题的样例也告诉你,从aaaaaa中最多剪出来3个aa
这道题有两种思路。第一种思路是先不管三七二十一,用hihoCoder #1015中的匹配方法,把所有出现P的位置都求出来,存在一个vector<int> start_pos里:
for(int i = 1;i <= n;i++)
{
while(j >= 0 && p[j + 1] != s[i])
j = nxt[j];
if(++j == m)
{
start_pos.push_back(i - m + 1);
j = nxt[j];
}
}
假如输入是aaaaaa和aa,那么start_pos中保存的就是[1, 2, 3, 4, 5]。现在我们的问题就变成要从start_pos中选出尽量多的整数,并且保证相邻的两个整数差不小于P.len。比如上面的例子中就应该选出[1, 3, 5]
这个问题是可以贪心选的。大致思路就是一定选第一个最小的整数,然后从小到大检查每一个整数,如果能选(与上一个差不小于P.len)就一定选,不能选再看下一个。伪代码如下:
Last_Pick = -P.len
Cnt = 0
For i:start_Pos
If i - Last_Pos >= P.len
Last_Pick = i
Cnt++
Print Cnt
第二种思路。这种思路是建立在贪心算法基础上的。假设S[L..L+P.len-1]是P第一次出现的位置。根据贪心的思路,这个S[L..L+P.len-1]是一定要被剪出来的。在之前的题目中,找到一个完整匹配j==m的时候是令j = next[j] (也就是j=next[m]),让S[i]与P[next[m]+1]继续匹配。这里S[i]就是S[L+P.len]
;本题要避免剪出重叠的P,要把S[L..L+P.len-1]”剪掉“,显然剪掉之后S[L+P.len]只能从P[1]开始匹配起了,而不能从P[next[m]+1]开始了。所以我们直接把j = next[j]改成j = 0即可让KMP按我们的要求避免重叠
#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std;
int nxt[1000001];
int ans,n,m;
char p[1000001],s[1000001];
int main()
{
scanf("%s",s + 1);
while(strcmp(s + 1,"#") != 0)
{
scanf("%s",p + 1);
n = strlen(s + 1);
m = strlen(p + 1);
ans = 0;
nxt[0] = -1;
int j = -1;
for(int i = 1;i <= m;i++)
{
while(j >= 0 && p[j + 1] != p[i])
j = nxt[j];
nxt[i] = ++j;
}
j = 0;
for(int i = 1;i <= n;i++)
{
while(j >= 0 && p[j + 1] != s[i])
j = nxt[j];
if(++j == m)
{
ans++;
j = 0;
}
}
cout << ans << endl;
scanf("%s",s + 1);
}
return 0;
}
上面的代码同上一节hihoCoder #1015题除了处理输入部分之外,就只有27行不一样。这道题的代码27是直接j = 0,从而强制下一个字符S[i]从P[1]开始匹配。第二个思路需要对KMP的理解更加清楚,才能更加灵活的运用
例4 题目链接:hihoCoder1625
题目大意是给定两个字符串A和B,请你求出字符串A最少重复几次才能使得B是A的子串。例如A="hiho",B="hohihohi"。则A重复3次之后变为"hihohihohiho",这时B是A的子串。如果没有解输出-1
这题降低时间复杂度的关键是一次性得到足够长的S,进行KMP;而不能从一个比较短的S开始尝试KMP,匹配不成就在S后面接一个A,再匹配不成再接一个A……
于是问题来了,S多长才是足够长?这个问题的答案是A.len+B.len。证明比较简单,留给大家自己思考。于是我们有一个伪代码:
S = 空串
While S.len < A.len + B.len
S = S + A//+这里是字符串连接
L = KMP(B,S)
If L > 0
Print L + B.len - 2 / A.len + 1//长度L + B.len - 1需要几个A拼起来
Else
Print -1
不过这道题目还有一种更简单的实现方法。我们可以并不真的拼出一个字符串S。而是“假装”有一个字符串S。当我们KMP的过程中需要用的S[i]的时候,就用A[(i-1) % A.len+1]代替。代码如下:
#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std;
int nxt[1000001];
int ans,n,m,t;
char p[1000001],s[1000001];
int main()
{
scanf("%d",&t);
while(t--)
{
scanf("%s%s",s + 1,p + 1);
n = strlen(s + 1);
m = strlen(p + 1);
ans = -1;
nxt[0] = -1;
int j = -1;
for(int i = 1;i <= m;i++)
{
while(j >= 0 && p[j + 1] != p[i])
j = nxt[j];
nxt[i] = ++j;
}
j = 0;
for(int i = 1;i <= n + m;i++)
{
char c = s[(i - 1) % n + 1];
while(j >= 0 && p[j + 1] != c)
j = nxt[j];
if(++j == m)
{
ans = (i - 1) / n + 1;
break;
}
}
cout << ans << endl;
}
return 0;
}
注意第28行这里的s实际上是题目中的A。我们实际上并没有S[i],而是用A中的对应字符代替
例5 题目链接:HDOJ2594
这道题目的大意是给定两个字符串P和S,让你找一个最长的字符串T满足T是P的前缀,也是S的后缀
如果你对KMP的过程非常清楚的话,你会发现KMP用P去匹配S的过程中,如果S[i]匹配上了P[j]那就说明P[1..j]是S[1..i]的最长后缀,同时P[1..j]显然是P的前缀。所以我们只要找到第一个匹配上S[n]的P[j]的即可,这时j就是答案
#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std;
int nxt[1000001];
int ans,n,m,t;
char p[1000001],s[1000001];
int main()
{
while(scanf("%s%s",p + 1,s + 1) != EOF)
{
n = strlen(s + 1);
m = strlen(p + 1);
nxt[0] = -1;
int j = -1;
for(int i = 1;i <= m;i++)
{
while(j >= 0 && p[j + 1] != p[i])
j = nxt[j];
nxt[i] = ++j;
}
j = 0;
for(int i = 1;i <= n;i++)
{
char c = s[(i - 1) % n + 1];
while(j >= 0 && p[j + 1] != c)
j = nxt[j];
if(++j == m)
{
if(i == n)
break;
j = nxt[j];
}
}
if(j == 0)
cout << 0 << endl;
else
{
p[j + 1] = 0;
cout << p + 1 << ' ' << j << endl;
}
}
return 0;
}