Building a Generative Chatbot with seq2seq

Here we use Google's open-source NMT project to build a simple chatbot. The idea is intuitive: feed dialogue corpus pairs to the NMT model as if they were translation pairs, and the trained model then works as a simple chatbot.

Google's open-source tensorflow-nmt (seq2seq) model is described in detail in this post: tensorflow-nmt (seq2seq) model.

To build a chatbot with the seq2seq framework, I have prepared some dialogue corpus data, which we will use to build the chatbot application. Before that, let's look at the corpus format the original translation system expects, and then process our Chinese data into exactly the same format.
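
For reference, the preprocessing below will leave us with a data directory that mirrors the IWSLT sample layout, with input/output playing the role of the en/vi language suffixes (a sketch of the target layout, not the output of any command):

data/
├── train.input    # tokenized questions (source side)
├── train.output   # tokenized answers (target side)
├── val.input      # validation questions
├── val.output     # validation answers
├── test.input     # test questions
├── test.output    # test answers
├── vocab.input    # source vocabulary, starting with <unk>, <s>, </s>
└── vocab.output   # target vocabulary, same convention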

First, pull a sample dataset by running the download script included in the project:

!bash nmt/scripts/download_iwslt15.sh /tmp/nmt_data

Check which files it contains:

train.en   tst2012.en   tst2013.en   vocab.en
train.vi   tst2012.vi   tst2013.vi   vocab.vi

Take a look at the format of the source and target languages and the corresponding data volume; you can see the text has already been tokenized:

!head -10 /tmp/nmt_data/train.en

Rachel Pike : The science behind a climate headline
In 4 minutes , atmospheric chemist Rachel Pike provides a glimpse of the massive scientific effort behind the bold headlines on climate change , with her team -- one of thousands who contributed -- taking a risky flight over the rainforest in pursuit of data on a key molecule .
I 'd like to talk to you today about the scale of the scientific effort that goes into making the headlines you see in the paper .
Headlines that look like this when they have to do with climate change , and headlines that look like this when they have to do with air quality or smog .
They are both two branches of the same field of atmospheric science .
Recently the headlines looked like this when the Intergovernmental Panel on Climate Change , or IPCC , put out their report on the state of understanding of the atmospheric system .
That report was written by 620 scientists from 40 countries .
They wrote almost a thousand pages on the topic .
And all of those pages were reviewed by another 400-plus scientists and reviewers , from 113 countries .
It 's a big community . It 's such a big community , in fact , that our annual gathering is the largest scientific meeting in the world .
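
Besides peeking at the content, a quick way to check the data volume mentioned above is to count lines with wc (the exact counts depend on the download):

!wc -l /tmp/nmt_data/train.en /tmp/nmt_data/train.vi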

We also need to prepare the vocabulary files:

!head -10 /tmp/nmt_data/vocab.en

<unk>
<s>
</s>
Rachel
:
The
science
behind
a
climate

Processing the data

Next we train on the Xiaohuangji (小黄鸡) corpus.

First download the Xiaohuangji corpus and process it into the input format the seq2seq model expects.

!wget https://github.com/candlewill/Dialog_Corpus/raw/master/xiaohuangji50w_nofenci.conv.zip
!unzip xiaohuangji50w_nofenci.conv.zip
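
Before transforming it, it helps to know the raw .conv layout: each dialogue is an E separator line followed by two M lines, the question and then the answer. Schematically (placeholders, not real corpus content):

E
M <question>
M <answer>
E
M <question>
M <answer>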

Preprocess the data, marking the questions and answers:

!perl -pi.bak -e 's/(E\s)/$1Q /g' xiaohuangji50w_nofenci.conv
!perl -pi.bak -e 's/(Q M)/Q/g' xiaohuangji50w_nofenci.conv
!perl -pi.bak -e 's/(M )/A /g' xiaohuangji50w_nofenci.conv
!head -30 xiaohuangji50w_nofenci.conv
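
The three substitutions above relabel each pair so that the first M line becomes a Q (question) line and the second an A (answer) line, i.e. each block now looks roughly like:

E
Q <question>
A <answer>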

Tokenize with the jieba tool:

import jieba

def split_conv(in_f, out_q, out_a):
    out_question = open(out_q, 'w')
    out_answer = open(out_a, 'w')
    text = open(in_f).read().split("E\n")
    for pair in text:
        # skip dialogue pairs that are too short
        if len(pair) <= 4:
            continue
        # split into question and answer
        contents = pair.split("\n")
        out_question.write(" ".join(jieba.lcut(contents[0].strip("Q "))) + "\n")
        out_answer.write(" ".join(jieba.lcut(contents[1].strip("A "))) + "\n")
    out_question.close()
    out_answer.close()

in_f = "xiaohuangji50w_nofenci.conv"
out_q = 'question.file'
out_a = 'answer.file'
split_conv(in_f, out_q, out_a)
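
As a quick sanity check (assuming the files were written as expected), question.file and answer.file should have the same number of lines, one tokenized Q/A pair per line:

!wc -l question.file answer.file
!head -3 question.file answer.file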

Build the vocabulary:

import re

def get_vocab(in_f, out_f):
    vocab_dic = {}
    for line in open(in_f, encoding='utf-8'):
        words = line.strip().split(" ")
        for word in words:
            # keep only words made of Chinese characters
            if not re.match(r"[\u4e00-\u9fa5]+", word):
                continue
            try:
                vocab_dic[word] += 1
            except KeyError:
                vocab_dic[word] = 1
    out = open(out_f, 'w', encoding='utf-8')
    out.write("<unk>\n<s>\n</s>\n")
    vocab = sorted(vocab_dic.items(), key=lambda x: x[1], reverse=True)
    # keep the 80,000 most frequent words
    for word in [x[0] for x in vocab[:80000]]:
        out.write(word)
        out.write("\n")
    out.close()

import os
os.makedirs("./data", exist_ok=True)  # create data/ first, since the vocab files are written under it

in_file = "question.file"
out_file = "./data/vocab.input"
get_vocab(in_file, out_file)

in_file = "answer.file"
out_file = "./data/vocab.output"
get_vocab(in_file, out_file)
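
It is worth verifying that both vocabularies start with the <unk>, <s>, </s> special tokens that nmt expects, and checking their sizes:

!head -5 ./data/vocab.input
!wc -l ./data/vocab.input ./data/vocab.output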

Split into training, validation, and test sets:

!mkdir -p data
!head -300000 question.file > data/train.input
!head -300000 answer.file > data/train.output
!head -380000 question.file | tail -80000 > data/val.input
!head -380000 answer.file | tail -80000 > data/val.output
!tail -75000 question.file > data/test.input
!tail -75000 answer.file > data/test.output

Train the chatbot model:

!python3 -m nmt.nmt \
--attention=scaled_luong \
--src=input --tgt=output \
--vocab_prefix=./data/vocab \
--train_prefix=./data/train \
--dev_prefix=./data/val \
--test_prefix=./data/test \
--out_dir=/tmp/nmt_attention_model \
--num_train_steps=12000 \
--steps_per_stats=1 \
--num_layers=2 \
--num_units=128 \
--dropout=0.2 \
--metrics=bleu
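
Once training finishes, you can talk to the model through nmt's inference mode (a minimal sketch: the --inference_* flags follow the tensorflow/nmt README, while the file paths here are my own assumptions). The question must be tokenized with jieba the same way as the training data:

# write a pre-tokenized question to a file (path is an assumption)
!echo "你 好 吗" > /tmp/my_question.input
!python3 -m nmt.nmt \
    --out_dir=/tmp/nmt_attention_model \
    --inference_input_file=/tmp/my_question.input \
    --inference_output_file=/tmp/nmt_attention_model/reply.output
!cat /tmp/nmt_attention_model/reply.output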
