数据准备:http://www.manythings.org/anki/cmn-eng.zip

源代码:https://github.com/pjgao/seq2seq_keras

参考:https://blog.csdn.net/PIPIXIU/article/details/81016974

导入库

执行代码:

from keras.layers import Input,LSTM,Dense
from keras.models import Model,load_model
from keras.utils import plot_model

import pandas as pd
import numpy as np

返回信息:

Using TensorFlow backend.

执行代码:

def create_model(n_input,n_output,n_units):
    #训练阶段
    #encoder
    encoder_input = Input(shape = (None, n_input))
    #encoder输入维度n_input为每个时间步的输入xt的维度,这里是用来one-hot的英文字符数
    encoder = LSTM(n_units, return_state=True)
    #n_units为LSTM单元中每个门的神经元的个数,return_state设为True时才会返回最后时刻的状态h,c
    _,encoder_h,encoder_c = encoder(encoder_input)
    encoder_state = [encoder_h,encoder_c]
    #保留下来encoder的末状态作为decoder的初始状态

    #decoder
    decoder_input = Input(shape = (None, n_output))
    #decoder的输入维度为中文字符数
    decoder = LSTM(n_units,return_sequences=True, return_state=True)
    #训练模型时需要decoder的输出序列来与结果对比优化,故return_sequences也要设为True
    decoder_output, _, _ = decoder(decoder_input,initial_state=encoder_state)
    #在训练阶段只需要用到decoder的输出序列,不需要用最终状态h.c
    decoder_dense = Dense(n_output,activation='softmax')
    decoder_output = decoder_dense(decoder_output)
    #输出序列经过全连接层得到结果

    #生成的训练模型
    model = Model([encoder_input,decoder_input],decoder_output)
    #第一个参数为训练模型的输入,包含了encoder和decoder的输入,第二个参数为模型的输出,包含了decoder的输出

    #推理阶段,用于预测过程
    #推断模型—encoder
    encoder_infer = Model(encoder_input,encoder_state)

    #推断模型-decoder
    decoder_state_input_h = Input(shape=(n_units,))
    decoder_state_input_c = Input(shape=(n_units,))
    decoder_state_input = [decoder_state_input_h, decoder_state_input_c]#上个时刻的状态h,c   

    decoder_infer_output, decoder_infer_state_h, decoder_infer_state_c = decoder(decoder_input,initial_state=decoder_state_input)
    decoder_infer_state = [decoder_infer_state_h, decoder_infer_state_c]#当前时刻得到的状态
    decoder_infer_output = decoder_dense(decoder_infer_output)#当前时刻的输出
    decoder_infer = Model([decoder_input]+decoder_state_input,[decoder_infer_output]+decoder_infer_state)

    return model, encoder_infer, decoder_infer
N_UNITS = 256
BATCH_SIZE = 64
EPOCH = 50
NUM_SAMPLES = 10000

数据处理

data_path = 'data/cmn.txt'

读取数据

执行代码:

df = pd.read_table(data_path,header=None).iloc[:NUM_SAMPLES,:,]
df.columns=['inputs','targets']

df['targets'] = df['targets'].apply(lambda x: '\t'+x+'\n')

input_texts = df.inputs.values.tolist()
target_texts = df.targets.values.tolist()

input_characters = sorted(list(set(df.inputs.unique().sum())))
target_characters = sorted(list(set(df.targets.unique().sum())))

返回信息:

C:\3rd\Anaconda2\lib\site-packages\ipykernel_launcher.py:1: FutureWarning: read_table is deprecated, use read_csv instead, passing sep='\t'.
  """Entry point for launching an IPython kernel.

执行代码:

INUPT_LENGTH = max([len(i) for i in input_texts])
OUTPUT_LENGTH = max([len(i) for i in target_texts])
INPUT_FEATURE_LENGTH = len(input_characters)
OUTPUT_FEATURE_LENGTH = len(target_characters)

向量化

执行代码:

encoder_input = np.zeros((NUM_SAMPLES,INUPT_LENGTH,INPUT_FEATURE_LENGTH))
decoder_input = np.zeros((NUM_SAMPLES,OUTPUT_LENGTH,OUTPUT_FEATURE_LENGTH))
decoder_output = np.zeros((NUM_SAMPLES,OUTPUT_LENGTH,OUTPUT_FEATURE_LENGTH))
input_dict = {char:index for index,char in enumerate(input_characters)}
input_dict_reverse = {index:char for index,char in enumerate(input_characters)}
target_dict = {char:index for index,char in enumerate(target_characters)}
target_dict_reverse = {index:char for index,char in enumerate(target_characters)}
for seq_index,seq in enumerate(input_texts):
    for char_index, char in enumerate(seq):
        encoder_input[seq_index,char_index,input_dict[char]] = 1
for seq_index,seq in enumerate(target_texts):
    for char_index,char in enumerate(seq):
        decoder_input[seq_index,char_index,target_dict[char]] = 1.0
        if char_index > 0:
            decoder_output[seq_index,char_index-1,target_dict[char]] = 1.0

观察向量化的数据

执行代码:

''.join([input_dict_reverse[np.argmax(i)] for i in encoder_input[0] if max(i) !=0])

返回信息:

'Hi.'

执行代码:

''.join([target_dict_reverse[np.argmax(i)] for i in decoder_output[0] if max(i) !=0])

返回信息:

'嗨。\n'

执行代码:

''.join([target_dict_reverse[np.argmax(i)] for i in decoder_input[0] if max(i) !=0])

返回信息:

'\t嗨。\n'

创建模型

执行代码:

model_train, encoder_infer, decoder_infer = create_model(INPUT_FEATURE_LENGTH, OUTPUT_FEATURE_LENGTH, N_UNITS)

返回信息:

WARNING:tensorflow:From C:\3rd\Anaconda2\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.

执行代码:

#查看模型结构
plot_model(to_file='model.png',model=model_train,show_shapes=True)
plot_model(to_file='encoder.png',model=encoder_infer,show_shapes=True)
plot_model(to_file='decoder.png',model=decoder_infer,show_shapes=True)
model_train.compile(optimizer='rmsprop', loss='categorical_crossentropy')
model_train.summary()

返回信息:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, None, 73)     0
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None, 2623)   0
__________________________________________________________________________________________________
lstm_1 (LSTM)                   [(None, 256), (None, 337920      input_1[0][0]
__________________________________________________________________________________________________
lstm_2 (LSTM)                   [(None, None, 256),  2949120     input_2[0][0]
                                                                 lstm_1[0][1]
                                                                 lstm_1[0][2]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, None, 2623)   674111      lstm_2[0][0]
==================================================================================================
Total params: 3,961,151
Trainable params: 3,961,151
Non-trainable params: 0
__________________________________________________________________________________________________

执行代码:

encoder_infer.summary()

返回信息:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, None, 73)          0
_________________________________________________________________
lstm_1 (LSTM)                [(None, 256), (None, 256) 337920
=================================================================
Total params: 337,920
Trainable params: 337,920
Non-trainable params: 0
_________________________________________________________________

执行代码:

decoder_infer.summary()

返回信息:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_2 (InputLayer)            (None, None, 2623)   0
__________________________________________________________________________________________________
input_3 (InputLayer)            (None, 256)          0
__________________________________________________________________________________________________
input_4 (InputLayer)            (None, 256)          0
__________________________________________________________________________________________________
lstm_2 (LSTM)                   [(None, None, 256),  2949120     input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_4[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, None, 2623)   674111      lstm_2[1][0]
==================================================================================================
Total params: 3,623,231
Trainable params: 3,623,231
Non-trainable params: 0
__________________________________________________________________________________________________

模型训练

执行代码:

model_train.fit([encoder_input,decoder_input],decoder_output,batch_size=BATCH_SIZE,epochs=EPOCH,validation_split=0.2)

返回信息:

WARNING:tensorflow:From C:\3rd\Anaconda2\lib\site-packages\tensorflow\python\ops\math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
WARNING:tensorflow:From C:\3rd\Anaconda2\lib\site-packages\tensorflow\python\ops\math_grad.py:102: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Train on 8000 samples, validate on 2000 samples
Epoch 1/50
8000/8000 [==============================] - 165s 21ms/step - loss: 2.0312 - val_loss: 2.5347
Epoch 2/50
8000/8000 [==============================] - 135s 17ms/step - loss: 1.9115 - val_loss: 2.4281
Epoch 3/50
8000/8000 [==============================] - 137s 17ms/step - loss: 1.8016 - val_loss: 2.3045
Epoch 4/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.7122 - val_loss: 2.2427
Epoch 5/50
8000/8000 [==============================] - 138s 17ms/step - loss: 1.6271 - val_loss: 2.1663
Epoch 6/50
8000/8000 [==============================] - 135s 17ms/step - loss: 1.5521 - val_loss: 2.0765
Epoch 7/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.4851 - val_loss: 2.0489
Epoch 8/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.4294 - val_loss: 2.0093
Epoch 9/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.3731 - val_loss: 1.9413
Epoch 10/50
8000/8000 [==============================] - 137s 17ms/step - loss: 1.3267 - val_loss: 1.9104
Epoch 11/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.2840 - val_loss: 1.8889
Epoch 12/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.2452 - val_loss: 1.8658
Epoch 13/50
8000/8000 [==============================] - 135s 17ms/step - loss: 1.2077 - val_loss: 1.8468
Epoch 14/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.1747 - val_loss: 1.8372
Epoch 15/50
8000/8000 [==============================] - 135s 17ms/step - loss: 1.1425 - val_loss: 1.8197
Epoch 16/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.1125 - val_loss: 1.8159
Epoch 17/50
8000/8000 [==============================] - 135s 17ms/step - loss: 1.0834 - val_loss: 1.8012
Epoch 18/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.0553 - val_loss: 1.7965
Epoch 19/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.0287 - val_loss: 1.7954
Epoch 20/50
8000/8000 [==============================] - 137s 17ms/step - loss: 1.0026 - val_loss: 1.7882
Epoch 21/50
8000/8000 [==============================] - 136s 17ms/step - loss: 0.9766 - val_loss: 1.7861
Epoch 22/50
8000/8000 [==============================] - 138s 17ms/step - loss: 0.9517 - val_loss: 1.7907
Epoch 23/50
8000/8000 [==============================] - 137s 17ms/step - loss: 0.9274 - val_loss: 1.7936
Epoch 24/50
8000/8000 [==============================] - 135s 17ms/step - loss: 0.9044 - val_loss: 1.7815
Epoch 25/50
8000/8000 [==============================] - 135s 17ms/step - loss: 0.8811 - val_loss: 1.7831
Epoch 26/50
8000/8000 [==============================] - 136s 17ms/step - loss: 0.8592 - val_loss: 1.7894
Epoch 27/50
8000/8000 [==============================] - 137s 17ms/step - loss: 0.8376 - val_loss: 1.7932
Epoch 28/50
8000/8000 [==============================] - 139s 17ms/step - loss: 0.8161 - val_loss: 1.7874
Epoch 29/50
8000/8000 [==============================] - 137s 17ms/step - loss: 0.7947 - val_loss: 1.7913
Epoch 30/50
8000/8000 [==============================] - 136s 17ms/step - loss: 0.7746 - val_loss: 1.7912
Epoch 31/50
8000/8000 [==============================] - 136s 17ms/step - loss: 0.7545 - val_loss: 1.8008
Epoch 32/50
8000/8000 [==============================] - 136s 17ms/step - loss: 0.7353 - val_loss: 1.7989
Epoch 33/50
8000/8000 [==============================] - 135s 17ms/step - loss: 0.7164 - val_loss: 1.8023
Epoch 34/50
8000/8000 [==============================] - 136s 17ms/step - loss: 0.6984 - val_loss: 1.8090
Epoch 35/50
8000/8000 [==============================] - 135s 17ms/step - loss: 0.6803 - val_loss: 1.8092
Epoch 36/50
8000/8000 [==============================] - 142s 18ms/step - loss: 0.6627 - val_loss: 1.8177
Epoch 37/50
8000/8000 [==============================] - 140s 18ms/step - loss: 0.6467 - val_loss: 1.8249
Epoch 38/50
8000/8000 [==============================] - 144s 18ms/step - loss: 0.6288 - val_loss: 1.8290
Epoch 39/50
8000/8000 [==============================] - 143s 18ms/step - loss: 0.6127 - val_loss: 1.8333
Epoch 40/50
8000/8000 [==============================] - 144s 18ms/step - loss: 0.5973 - val_loss: 1.8454
Epoch 41/50
8000/8000 [==============================] - 143s 18ms/step - loss: 0.5824 - val_loss: 1.8488
Epoch 42/50
8000/8000 [==============================] - 160s 20ms/step - loss: 0.5669 - val_loss: 1.8490
Epoch 43/50
8000/8000 [==============================] - 143s 18ms/step - loss: 0.5529 - val_loss: 1.8600
Epoch 44/50
8000/8000 [==============================] - 144s 18ms/step - loss: 0.5383 - val_loss: 1.8636
Epoch 45/50
8000/8000 [==============================] - 144s 18ms/step - loss: 0.5244 - val_loss: 1.8806
Epoch 46/50
8000/8000 [==============================] - 827s 103ms/step - loss: 0.5120 - val_loss: 1.8877
Epoch 47/50
8000/8000 [==============================] - 147s 18ms/step - loss: 0.4992 - val_loss: 1.8902
Epoch 48/50
8000/8000 [==============================] - 150s 19ms/step - loss: 0.4859 - val_loss: 1.8981
Epoch 49/50
8000/8000 [==============================] - 149s 19ms/step - loss: 0.4737 - val_loss: 1.9022
Epoch 50/50
8000/8000 [==============================] - 149s 19ms/step - loss: 0.4619 - val_loss: 1.9096
<keras.callbacks.History at 0x2f3f3936908>

预测序列

执行代码:

def predict_chinese(source,encoder_inference, decoder_inference, n_steps, features):
    #先通过推理encoder获得预测输入序列的隐状态
    state = encoder_inference.predict(source)
    #第一个字符'\t',为起始标志
    predict_seq = np.zeros((1,1,features))
    predict_seq[0,0,target_dict['\t']] = 1

    output = ''
    #开始对encoder获得的隐状态进行推理
    #每次循环用上次预测的字符作为输入来预测下一次的字符,直到预测出了终止符
    for i in range(n_steps):#n_steps为句子最大长度
        #给decoder输入上一个时刻的h,c隐状态,以及上一次的预测字符predict_seq
        yhat,h,c = decoder_inference.predict([predict_seq]+state)
        #注意,这里的yhat为Dense之后输出的结果,因此与h不同
        char_index = np.argmax(yhat[0,-1,:])
        char = target_dict_reverse[char_index]
        output += char
        state = [h,c]#本次状态做为下一次的初始状态继续传递
        predict_seq = np.zeros((1,1,features))
        predict_seq[0,0,char_index] = 1
        if char == '\n':#预测到了终止符则停下来
            break
    return output
for i in range(1000,1100):
    test = encoder_input[i:i+1,:,:]#i:i+1保持数组是三维
    out = predict_chinese(test,encoder_infer,decoder_infer,OUTPUT_LENGTH,OUTPUT_FEATURE_LENGTH)
    #print(input_texts[i],'\n---\n',target_texts[i],'\n---\n',out)
    print(input_texts[i])
    print(out)

返回信息:

I have brothers.
我有一個意思。

I have ten pens.
我有一個意見。

I have to hurry!
我有一个好。

I have two cats.
我有一個意見。

I have two sons.
我有一個意見。

I just threw up.
我只是不能看看。

I lent him a CD.
我喜欢你的。

I like Tom, too.
我喜欢你的房间。

I like football.
我喜欢跑步。

I like potatoes.
我喜欢跑步。

I like the cold.
我喜欢跑步。

I like this dog.
我喜欢跑步。

I like your car.
我喜欢你的车。

I lived in Rome.
我愛這台車。

I love this car.
我愛這台車。

I might say yes.
我非常喜欢我的。

I must help her.
我在這裡吃飯了。

I need a friend.
我需要一張郵票。

I need evidence.
我需要一張郵票。

I need you here.
我需要你的幫助。

I paid the bill.
我同意他。

I played tennis.
我喜歡運動。

I run every day.
我非常喜欢。

I speak Swedish.
我同意他。

I talked to her.
我打算去那裡。

I teach Chinese.
我相信你。

I think it's OK.
我想感觉。

I took a shower.
我想要一張。

I want a guitar.
我想要一個。

I want that bag.
我想要一點。

I want to drive.
我想要一個。

I was surprised.
我真是不能运。

I wish you'd go.
我希望你去。

I woke up early.
我希望你去。

I work too much.
我在这个工作。

I'll bring wine.
我會開車。

I'll never stop.
我會打網球。

I'm a foreigner.
我是個老師。

I'm a night owl.
我是个男人。

I'm about ready.
我是個老師。

I'm always here.
我很高興。

I'm daydreaming.
我是個老師。

I'm feeling fit.
我是一個好男孩。

I'm left-handed.
我是一個好人。

I'm not serious.
我不是你的。

I'm out of time.
我是个男人。

I'm really busy.
我很快樂。

I'm really cold.
我很高興。

I'm still angry.
我是個老師。

I'm very hungry.
我很快樂。

I'm very lonely.
我很好奇。

I've had enough.
我有一切。

I've had enough.
我有一切。

Is Tom Canadian?
汤姆是好人。

Is he breathing?
他是日本人吗?

Is it all there?
那是一只是嗎?

Is it too salty?
那是你的車嗎?

Is she Japanese?
他的是日本人嗎?

Is this a river?
这是一只是用的铅笔吗?

Isn't that mine?
那是我的吗?

It is up to you.
那是你的想象。

It snowed a lot.
它不是真的。

It was terrible.
它是我們的錯誤。

It was very far.
它是我們的秘密。

It'll be cloudy.
它很大。

It's a dead end.
它是我們的錯。

It's a new book.
它是我們的秘密。

It's a nice day.
它是我們的秘密。

It's a surprise.
它是我們的!

It's almost six.
它是一個很好的問題。

It's already 11.
它是我們的秘密。

It's fine today.
它是一個好主意。

It's impossible.
它是一個很好的問題。

It's lunch time.
它是我們的錯。

It's okay to go.
它是我們的秘密。

It's over there.
它是一個好。

It's time to go.
它是我最好的。

It's time to go.
它是我最好的。

Jesus loves you.
把它给我的!

Keep on smiling.
保持微笑。

Keep on working.
让不要保持。

Keep the change!
繼續工作!

Large, isn't it?
很少, 不是嗎?

Lemons are sour.
讓我看一下。

Let me go alone.
讓我看看。

Let me see that.
讓我看看。

Let them decide.
讓我看看。

Let's eat sushi.
讓我們開始吧。

Let's go by bus.
讓我們開始吧。

Let's not argue.
讓我們開始吧。

Let's turn back.
讓我們開始吧。

Look at the sky.
看看那個好。

Look behind you.
看看那個人。

Make it smaller.
把它弄小一點。

May I leave now?
我可以看看你的行嗎?

May I try it on?
我可以看看你的行嗎?

Maybe next time.
大部電話是我的工作。

Men should work.
大下周日。

Merry Christmas!
瑪麗很高。

Mom, I'm hungry.
媽,我可以去游泳嗎?

参考:

https://zybuluo.com/hanbingtao/note/541458

https://zybuluo.com/hanbingtao/note/581764

最新文章

  1. .NET WebAPI 实现图片上传(包括附带参数上传图片)
  2. JAVA Day9
  3. final 评论 I
  4. 关于Memcached一致性hash的探究
  5. java 21 - 10 文本文件和集合之间互相存储数据
  6. input子系统详解
  7. C++11 Concurrency Features
  8. JavaScript与Flash的通信
  9. 关于ax+by=c的解x,y的min(|x|+|y|)值问题
  10. &lt;转&gt;GC其他:引用标记-清除、复制、标记-整理的说明
  11. eclipse重构详解(转)
  12. 【Sort】QuickSort
  13. Hihocoder 2月29日
  14. Linux中grep命令学习
  15. 7_linux下PHP、Apache、Mysql服务的安装
  16. String s=new String(&quot;abc&quot;)创建了几个对象?
  17. How to SetUp The Receiving Transaction Manager
  18. win10 开机自启指定软件
  19. python学习笔记(四)、条件、循环及其他语句
  20. 【java多线程】队列系统之说说队列Queue

热门文章

  1. EAC3 enhanced channel coupling
  2. DOM的方法和属性
  3. NPOI 导出Excel表报错
  4. 【visio】数据可视化 - 数据展示
  5. 清理rancher、k8s环境
  6. 快递查询API
  7. CSS学习(5)更多的选择器
  8. Mac配置内网穿透
  9. Validation failed for one or more entities. See ‘EntityValidationErrors
  10. java爬虫出现java.lang.IllegalArgumentException: Illegal character in path at index 31