给两个字符串S和T,长度分别为n和m。且从0开始编号。

定义next数组:next[i]表示T[i…m-1]和T的最长公共前缀。

定义extend数组:extend[i]表示S[i…n-1]和T的最长公共前缀。

例子:

i 0 1 2 3 4 5 6 7
S a a a a a b b b
T a a a a a c
next[i] 6 4 3 2 1 0
extend[i] 5 4 3 2 1 0 0 0

若某个extend[i]=m,则说明在S中找到了一个T的匹配,且可以在S中找到所有T的匹配,所以说是扩展KMP算法。

流程

1

假设当前S遍历到i,即extend[0…i-1]都已经计算得到。两个变量a和p,p是从a开始S和T匹配成功的最远处。即S[a…p-1]==T[0…p-a-1]

2

如果i+next[i-a]<p,如上图,三个蓝色的椭圆形长度相同,为next[i-a],此时extend[i] = next[i-a]

3

如果i+next[i-a]==p,S[p]!=T[p-a]且T[p-a]!=T[p-i],但是S[p]有可能等于T[p-i],所以将S[p]和T[p-i]开始朴素的往后匹配。

4

如果i+next[i-a]>p呢?因为s[i…p)==T[i-a…p-a)相同,S[p]!=T[p-a],但是T[p-a]==T[p-i],所以S[p]!=T[p-i],所以没必要匹配了,这时候extend[i]=p-i

求解extend数组是S和T在匹配,求解next数组是T和T在匹配,两个算法类似。

模板

int nxt[maxn], extend[maxn];
char s[maxn], t[maxn];

void getNext()
{
	int n = strlen(t);
	nxt[0] = n;
	int a=0, p=0;
	for(int i=1;i<n;i++)
	{
	    // i>=p是S和T没有匹配字母的情况
		if(i>=p || i+nxt[i-a]>=p)
		{
			if(i>=p)p=i;
			// 朴素往后匹配
			while(p<n && t[p]==t[p-i])p++;
			nxt[i] = p-i;
			a=i;
		}
		else
			nxt[i] = nxt[i-a];
	}
}

// 注释参考上面
void getExtend()
{
	int n = strlen(s), m = strlen(t);
	int a=0, p=0;
	getNext();
	for(int i=0;i<n;i++)
	{
		if(i>=p || i+nxt[i-a]>=p)
		{
			if(i>=p)p=i;
			while(p<n && p-i<m && s[p]==t[p-i])p++;
			extend[i] = p-i;
			a=i;
		}
		else
			extend[i] = nxt[i-a];
	}
}

例题

洛谷P5410

5

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#define IO ios::sync_with_stdio(0)
typedef pair<int, int> P;
const int maxn = 2e5+5;
const ll mod = 1e9+7;
const double eps = 1e-9;
using namespace std;


int nxt[maxn], extend[maxn];
char s[maxn], t[maxn];

void getNext()
{
	int n = strlen(t);
	nxt[0] = n;
	int a=0, p=0;
	for(int i=1;i<n;i++)
	{
		if(i>=p || i+nxt[i-a]>=p)
		{
			if(i>=p)p=i;
			while(p<n && t[p]==t[p-i])p++;
			nxt[i] = p-i;
			a=i;
		}
		else
			nxt[i] = nxt[i-a];
	}
}

void getExtend()
{
	int n = strlen(s), m = strlen(t);
	int a=0, p=0;
	getNext();
	for(int i=0;i<n;i++)
	{
		if(i>=p || i+nxt[i-a]>=p)
		{
			if(i>=p)p=i;
			while(p<n && p-i<m && s[p]==t[p-i])p++;
			extend[i] = p-i;
			a=i;
		}
		else
			extend[i] = nxt[i-a];
	}
}

int main()
{
	scanf("%s", s);
	scanf("%s", t);
	getExtend();
	int lens = strlen(s), lent = strlen(t);
	for(int i=0;i<lent;i++)
		printf("%d ", nxt[i]);
	printf("\n");
	for(int i=0;i<lens;i++)
		printf("%d ", extend[i]);
	printf("\n");
	return 0;
}