这里用的是google开源的nmt项目来简单实现的一个chatbot。很直觉的,把对话的语聊喂到nmt的模型进行训练,这样最终训练得到的模型就是一个简单的聊天机器人。
Google开源的tensorflow-nmt(seq2seq)模型,可以在下面这篇博客里详细了解下:tensorflow-nmt(seq2seq)模型
使用seq2seq框架完成一个聊天机器人构建的任务,我给大家准备了一些对话语料,我们使用这份数据来构建聊天机器人的AI应用。在此之前,我们先了解一下原有的翻译系统需要准备的语料格式,我们把中文数据处理成格式一致的形态。
我们先拉取一份样例数据。执行项目里的脚本文件进行下载。1
!bash nmt/scripts/download_iwslt15.sh /tmp/nmt_data
查看一下包含的文件:1
2train.en tst2012.en tst2013.en vocab.en
train.vi tst2012.vi tst2013.vi vocab.vi
看一下源语言与目标语言的格式,以及对应的数据量,可以看到都是做过tokenization之后的数据:1
!head -10 /tmp/nmt_data/train.en
1 | Rachel Pike : The science behind a climate headline |
还需要准备好vocabulary词表:1
!head -10 /tmp/nmt_data/vocab.en
1 | <unk> |
处理数据
下面使用小黄鸡进行训练。
首先下载小黄鸡语料,并对它做一个处理,使得它符合seq2seq模型的输入格式。1
2!wget https://github.com/candlewill/Dialog_Corpus/raw/master/xiaohuangji50w_nofenci.conv.zip
!unzip xiaohuangji50w_nofenci.conv.zip
预处理数据,将问题与回答标识出来:1
2
3
4!perl -pi.bak -e 's/(E\s)/\1Q /g' xiaohuangji50w_nofenci.conv
!perl -pi.bak -e 's/(Q M)/Q/g' xiaohuangji50w_nofenci.conv
!perl -pi.bak -e 's/(M )/A /g' xiaohuangji50w_nofenci.conv
!head -30 xiaohuangji50w_nofenci.conv
用jieba工具分词:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21import jieba
def split_conv(in_f, out_q, out_a):
out_question = open(out_q, 'w')
out_answer = open(out_a, 'w')
text = open(in_f).read().split("E\n")
for pair in text:
# 句子长度太短的问题对话,跳过
if len(pair)<=4:
continue
# 切分问题和回答
contents = pair.split("\n")
out_question.write(" ".join(jieba.lcut(contents[0].strip("Q ")))+"\n")
out_answer.write(" ".join(jieba.lcut(contents[1].strip("A ")))+"\n")
out_question.close()
out_answer.close()
in_f = "xiaohuangji50w_nofenci.conv"
out_q = 'question.file'
out_a = 'answer.file'
split_conv(in_f, out_q, out_a)
构建词表:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20import re
def get_vocab(in_f, out_f):
vocab_dic = {}
for line in open(in_f, encoding='utf-8'):
words = line.strip().split(" ")
for word in words:
# 保留汉字内容
if not re.match(r"[\u4e00-\u9fa5]+", word):
continue
try:
vocab_dic[word] += 1
except:
vocab_dic[word] = 1
out = open(out_f, 'w', encoding='utf-8')
out.write("<unk>\n<s>\n</s>\n")
vocab = sorted(vocab_dic.items(),key = lambda x:x[1],reverse = True)
for word in [x[0] for x in vocab[:80000]]:
out.write(word)
out.write("\n")
out.close()
1 | in_file = "question.file" |
1 | in_file = "answer.file" |
切分训练,验证,测试集:1
2
3
4
5
6
7!mkdir data
!head -300000 question.file > data/train.input
!head -300000 answer.file > data/train.output
!head -380000 question.file | tail -80000 > data/val.input
!head -380000 answer.file | tail -80000 > data/val.output
!tail -75000 question.file > data/test.input
!tail -75000 answer.file > data/test.output
训练摘要生成模型:1
2
3
4
5
6
7
8
9
10
11
12
13
14!python3 -m nmt.nmt \
--attention=scaled_luong \
--src=input --tgt=output \
--vocab_prefix=./data/vocab \
--train_prefix=./data/train \
--dev_prefix=./data/val \
--test_prefix=./data/test \
--out_dir=/tmp/nmt_attention_model \
--num_train_steps=12000 \
--steps_per_stats=1 \
--num_layers=2 \
--num_units=128 \
--dropout=0.2 \
--metrics=bleu