AC自动机

AC自动机

  AC自动机很类似于kmp算法,kmp算法适合于一个串与另一个串匹配,而AC自动机用于一坨串匹配一个串。某种意义上来说AC自动机=kmp+trie。

  学习博客:
  1.https://www.cnblogs.com/cmmdc/p/7337611.html
  2.https://www.luogu.com.cn/blog/juruohyfhaha/ac-zi-dong-ji

前置知识:

1.trie树(把那一坨模式串插到trie树上)
2.kmp算法

这个算法的一个过程大致分为三个过程:
1.insert_() 把模式串插到trie树上
2.get_fail() 相当于kmp算法的get_next()
3.ac_query() 拿那一个长串去匹配

Fail指针:

Fail指针的实质含义是什么呢?

如果一个点i的Fail指针指向j。那么root到j的字符串是root到i的字符串的一个后缀。

如何求fail

1.每个点的fail一定比这个点所在的深度小
2.第一层所有点的fail都是root
3.设点i的父节点是fa,fa的fail指针指向fafail,则点i的fail指针应该指向fafail节点的子节点中与点i值相同的点
4.处理点i,必须先处理完fa,所以用bfs

一些细节

1.第一层所有点的fail都是root
2.如果点i不存在,则将它设为fafail节点的子节点中与点i值相同的点
3.无论fafail存不存在与i相同的节点j,都将i的fail指针指向j,因为在2中已经处理了。

模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
void get_fail()
{
//细节1
for(int i = 0; i < 26; ++i)
{
if(ac[0].vis[i])
{
ac[ac[0].vis[i]].fail = 0;
q.push(ac[0].vis[i]);
}
}

while(!q.empty())
{
u = q.front();
q.pop();
for(int i = 0; i < 26; ++i)
{
if(ac[u].vis[i])//细节3
{
ac[ac[u].vis[i]].fail = ac[ac[u].fail].vis[i];
q.push(ac[u].vis[i]);
}
else ac[u].vis[i] = ac[ac[u].fail].vis[i];//细节2
}
}
}

查询

1.查询过程类似与trie树
2.用一个标记记录一下该节点是不是某个串的结尾,是否记录过了
3.如果该点匹配成功,那么他的fail也一定成功,fail的fail也一定。。。

模板(以最基础的n个模式串中有几个出现这类题为例)

模板1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
int ac_query()
{
len = s.size();
int now = 0, ans = 0, p;
for(int i = 0; i < len; ++i)
{
id = s[i] - 'a';
while(ac[now].vis[id] == 0 && now != 0) now = ac[now].fail;
now = ac[now].vis[id];

now = (now == 0) ? 0 : now;
p = now;
while(p != 0 && ac[p].en != -1)
{
ans += ac[p].en;
ac[p].en = -1;
p = ac[p].fail;
}
}
return ans;
}

模板2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
int ac_query()
{
int len = s.size();
int now = 0, ans = 0;
for(int i = 0; i < len; ++i)
{
now = ac[now].vis[s[i] - 'a'];
for(int j = now; j && ac[j].en != -1; j = ac[j].fail)
{
ans += ac[j].en;
ac[j].en = -1;
}
}
return ans;
}

几道入门模板题

1.P3808 【模板】AC自动机(简单版)

img

AC代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#include <bits/stdc++.h>

using namespace std;

struct tree
{
int fail;
int vis[30];
int en;
}ac[1000050];
int tot, len, root, id;
char ch[1000050];
string s;

inline void insert_(char *str)
{
len = strlen(str);
root = 0;
for(int i = 0; i < len; ++i)
{
id = str[i] - 'a';
if(!ac[root].vis[id]) ac[root].vis[id] = ++tot;
root = ac[root].vis[id];
}
ac[root].en += 1;
}

void get_fail()
{
queue<int> q;
for(int i = 0; i < 26; ++i)
{
if(ac[0].vis[i])
{
ac[ac[0].vis[i]].fail = 0;
q.push(ac[0].vis[i]);
}
}

while(!q.empty())
{
int u = q.front();
q.pop();
for(int i = 0; i < 26; ++i)
{
if(ac[u].vis[i])
{
ac[ac[u].vis[i]].fail = ac[ac[u].fail].vis[i];
q.push(ac[u].vis[i]);
}
else ac[u].vis[i] = ac[ac[u].fail].vis[i];
}
}
}

int ac_query()
{
int len = s.size();
int now = 0, ans = 0;
for(int i = 0; i < len; ++i)
{
now = ac[now].vis[s[i] - 'a'];
for(int j = now; j && ac[j].en != -1; j = ac[j].fail)
{
ans += ac[j].en;
ac[j].en = -1;
}
}
return ans;
}

int main()
{
int t;
scanf("%d", &t);
while(t--)
{
scanf("%s", ch);
insert_(ch);
}
cin >> s;
get_fail();
cout << ac_query() << '\n';
return 0;
}

2.【模板】AC自动机(加强版)

AC代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#include <bits/stdc++.h>

using namespace std;

struct tree
{
int fail;
int vis[26];
int en;
}ac[10600];
int cnt = 0;
struct node
{
int num, pos;
}ans[10600];

bool cmp(node a, node b)
{
if(a.num != b.num) return a.num > b.num;
else return a.pos < b.pos;
}

char ch[160][80];
char s[1000050];

//注意一下清数组的技巧
inline void init(int x)
{
memset(ac[x].vis, 0, sizeof(ac[x].vis));
ac[x].fail = ac[x].en = 0;
}

inline void build(char *s, int k)
{
int l = strlen(s);
int now = 0;
for(int i = 0; i < l; ++i)
{
if(ac[now].vis[s[i] - 'a'] == 0)
{
ac[now].vis[s[i] - 'a'] = ++cnt;
init(cnt);
}
now = ac[now].vis[s[i] - 'a'];
}
ac[now].en = k;
}

void get_fail()
{
queue<int> q;
for(int i = 0; i < 26; ++i)
{
if(ac[0].vis[i])
{
ac[ac[0].vis[i]].fail = 0;
q.push(ac[0].vis[i]);
}
}
while(!q.empty())
{
int u = q.front();
q.pop();
for(int i = 0; i < 26; ++i)
{
if(ac[u].vis[i])
{
ac[ac[u].vis[i]].fail = ac[ac[u].fail].vis[i];
q.push(ac[u].vis[i]);
}
//不存在这个子节点
else ac[u].vis[i] = ac[ac[u].fail].vis[i];
//当前节点的这个子节点指向当
//前节点fail指针的这个子节点
}
}
}

//区别主要在这里
void ac_query()
{
int l = strlen(s);
int now = 0;
for(int i = 0; i < l; ++i)
{
now = ac[now].vis[s[i] - 'a'];
for(int j = now; j ; j = ac[j].fail) ans[ac[j].en].num++;
}
}

int main()
{
int n;
while(~scanf("%d", &n) && n)
{
cnt = 0;
init(0);
for(int i = 1; i <= n; ++i)
{
scanf("%s", ch[i]);
ans[i].num = 0;
ans[i].pos = i;
build(ch[i], i);
}
ac[0].fail = 0;
get_fail();

scanf("%s", s);
ac_query();
sort(ans + 1, ans + n + 1, cmp);
printf("%d\n", ans[1].num);
printf("%s\n", ch[ans[1].pos]);
for(int i = 2; i <= n; ++i)
{
if(ans[i].num == ans[i - 1].num) printf("%s\n", ch[ans[i].pos]);
else break;
}
}
return 0;
}

3.P5357 【模板】AC自动机(二次加强版)

分析:
这个题如果还用普通的ac自动机会超时,就是匹配成功之后再去找fail指针指向的点然后一直找到根节点,这样效率很低,有一个方法优化。
对于匹配成功的点,我们进行一些标记,而不是沿着他的fail指针一直走到根。
我们把fail指针看成有向边,那这颗trie树实际上也是一张拓扑图,然后我们使用拓扑排序,把拓扑路径上的点的标记累计即可。
详解参考dalao の blog

AC代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include <bits/stdc++.h>

using namespace std;

int n;
char s[2000050];
char t[2000050];
struct node
{
int vis[26];
int flag, ans, fail;
}ac[200050];
queue<int> q;
int in[200050];
int ans[200050];
int mp[200050];
int tot, id, len, u, v;

void insert_(char *str, int k)
{
u = 0;
len = strlen(str);
for(int i = 0; i < len; ++i)
{
id = str[i] - 'a';
if(!ac[u].vis[id]) ac[u].vis[id] = ++tot;
u = ac[u].vis[id];
}
if(!ac[u].flag) ac[u].flag = k;
mp[k] = ac[u].flag;
}

void get_fail()
{
for(int i = 0; i < 26; ++i)
{
if(ac[0].vis[i])
{
ac[ac[0].vis[i]].fail = 0;
q.push(ac[0].vis[i]);
}
}

while(!q.empty())
{
int u = q.front();
q.pop();
for(int i = 0; i < 26; ++i)
{
if(ac[u].vis[i])
{
ac[ac[u].vis[i]].fail = ac[ac[u].fail].vis[i];
++in[ac[ac[u].vis[i]].fail];
q.push(ac[u].vis[i]);
}
else ac[u].vis[i] = ac[ac[u].fail].vis[i];
}
}
}

void ac_query(char *str)
{
len = strlen(str);
u = 0;

for(int i = 0; i < len; ++i)
{
id = str[i] - 'a';
u = ac[u].vis[id];
ac[u].ans++;
}
}

void topo()
{
for(int i = 1; i <= tot; ++i) if(!in[i]) q.push(i);

while(!q.empty())
{
u = q.front();
q.pop();
ans[ac[u].flag] = ac[u].ans;
v = ac[u].fail;
in[v]--;
ac[v].ans += ac[u].ans;
if(!in[v]) q.push(v);
}
}

int main()
{
scanf("%d", &n);
for(int i = 1; i <= n; ++i)
{
scanf("%s", t);
insert_(t, i);
}
get_fail();
scanf("%s", s);
ac_query(s);
topo();
for(int i = 1; i <= n; ++i) printf("%d\n", ans[mp[i]]);
return 0;
}

4.P3966 [TJOI2013]单词

分析:
同第三题,只是那一个长串没直接告诉你,实际长串就是那一坨模式串每个之间加上一个奇怪的字符‘#’间隔开。
然后我们ac_query()的时候,遇到‘#’,直接跳回根节点并且continue即可。
AC代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#include <bits/stdc++.h>

using namespace std;

int n;
char s[2000050];
char t[2000050];
struct node
{
int vis[26];
int flag, ans, fail;
}ac[200050];
queue<int> q;
int in[200050];
int ans[200050];
int mp[200050];
int tot, id, len, u, v;

void insert_(char *str, int k)
{
u = 0;
len = strlen(str);
for(int i = 0; i < len; ++i)
{
id = str[i] - 'a';
if(!ac[u].vis[id]) ac[u].vis[id] = ++tot;
u = ac[u].vis[id];
}
if(!ac[u].flag) ac[u].flag = k;
mp[k] = ac[u].flag;
}

void get_fail()
{
for(int i = 0; i < 26; ++i)
{
if(ac[0].vis[i])
{
ac[ac[0].vis[i]].fail = 0;
q.push(ac[0].vis[i]);
}
}

while(!q.empty())
{
int u = q.front();
q.pop();
for(int i = 0; i < 26; ++i)
{
if(ac[u].vis[i])
{
ac[ac[u].vis[i]].fail = ac[ac[u].fail].vis[i];
++in[ac[ac[u].vis[i]].fail];
q.push(ac[u].vis[i]);
}
else ac[u].vis[i] = ac[ac[u].fail].vis[i];
}
}
}

void ac_query(char *str)
{
len = strlen(str);
u = 0;

for(int i = 0; i < len; ++i)
{
id = str[i] - 'a';
u = ac[u].vis[id];
ac[u].ans++;
}
}

void topo()
{
for(int i = 1; i <= tot; ++i) if(!in[i]) q.push(i);

while(!q.empty())
{
u = q.front();
q.pop();
ans[ac[u].flag] = ac[u].ans;
v = ac[u].fail;
in[v]--;
ac[v].ans += ac[u].ans;
if(!in[v]) q.push(v);
}
}

int main()
{
scanf("%d", &n);
for(int i = 1; i <= n; ++i)
{
scanf("%s", t);
insert_(t, i);
}
get_fail();
scanf("%s", s);
ac_query(s);
topo();
for(int i = 1; i <= n; ++i) printf("%d\n", ans[mp[i]]);
return 0;
}