Getting Started with Keras (7): Long Short-Term Memory (LSTM) Networks, Part 4
2024-09-05 07:27:06
Dataset: http://www.manythings.org/anki/cmn-eng.zip
Source code: https://github.com/pjgao/seq2seq_keras
Reference: https://blog.csdn.net/PIPIXIU/article/details/81016974
Import the libraries
Code:
from keras.layers import Input, LSTM, Dense
from keras.models import Model, load_model
from keras.utils import plot_model
import pandas as pd
import numpy as np
Output:
Using TensorFlow backend.
Code:
def create_model(n_input, n_output, n_units):
    # Training phase
    # Encoder
    encoder_input = Input(shape=(None, n_input))
    # n_input is the dimension of each time step's input xt; here it is the
    # number of English characters used for one-hot encoding
    encoder = LSTM(n_units, return_state=True)
    # n_units is the number of neurons in each LSTM gate; return_state must
    # be True to get back the final states h and c
    _, encoder_h, encoder_c = encoder(encoder_input)
    encoder_state = [encoder_h, encoder_c]
    # Keep the encoder's final state to use as the decoder's initial state

    # Decoder
    decoder_input = Input(shape=(None, n_output))
    # The decoder's input dimension is the number of Chinese characters
    decoder = LSTM(n_units, return_sequences=True, return_state=True)
    # Training compares the decoder's full output sequence against the
    # targets, so return_sequences must also be True
    decoder_output, _, _ = decoder(decoder_input, initial_state=encoder_state)
    # During training only the decoder's output sequence is needed, not the
    # final states h and c
    decoder_dense = Dense(n_output, activation='softmax')
    decoder_output = decoder_dense(decoder_output)
    # The output sequence passes through a dense layer to produce the result

    # The resulting training model
    model = Model([encoder_input, decoder_input], decoder_output)
    # The first argument holds the training model's inputs (both the encoder
    # and decoder inputs); the second is the model's output (the decoder's output)

    # Inference phase, used for prediction
    # Inference model - encoder
    encoder_infer = Model(encoder_input, encoder_state)

    # Inference model - decoder
    decoder_state_input_h = Input(shape=(n_units,))
    decoder_state_input_c = Input(shape=(n_units,))
    decoder_state_input = [decoder_state_input_h, decoder_state_input_c]  # previous step's states h, c
    decoder_infer_output, decoder_infer_state_h, decoder_infer_state_c = decoder(decoder_input, initial_state=decoder_state_input)
    decoder_infer_state = [decoder_infer_state_h, decoder_infer_state_c]  # states produced at the current step
    decoder_infer_output = decoder_dense(decoder_infer_output)  # output at the current step
    decoder_infer = Model([decoder_input] + decoder_state_input,
                          [decoder_infer_output] + decoder_infer_state)

    return model, encoder_infer, decoder_infer
N_UNITS = 256
BATCH_SIZE = 64
EPOCH = 50
NUM_SAMPLES = 10000
Data processing
data_path = 'data/cmn.txt'
Read the data
Code:
df = pd.read_table(data_path, header=None).iloc[:NUM_SAMPLES, :]
df.columns = ['inputs', 'targets']
# Wrap each target with '\t' (start token) and '\n' (end token)
df['targets'] = df['targets'].apply(lambda x: '\t' + x + '\n')
input_texts = df.inputs.values.tolist()
target_texts = df.targets.values.tolist()
input_characters = sorted(list(set(df.inputs.unique().sum())))
target_characters = sorted(list(set(df.targets.unique().sum())))
Output:
C:\3rd\Anaconda2\lib\site-packages\ipykernel_launcher.py:1: FutureWarning: read_table is deprecated, use read_csv instead, passing sep='\t'.
  """Entry point for launching an IPython kernel.
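The `sorted(list(set(df.inputs.unique().sum())))` idiom above builds the character vocabulary: it concatenates all unique sentences into one string, deduplicates the characters, and sorts them so every run assigns the same indices. A minimal plain-Python sketch of the same idea on two made-up sentences (the names `texts`, `vocab`, `char_to_id`, and `id_to_char` are illustrative, not from the original script):

```python
# Toy version of the character-vocabulary construction used above.
texts = ['Hi.', 'Run.']
vocab = sorted(set(''.join(texts)))   # unique characters, sorted by codepoint
char_to_id = {c: i for i, c in enumerate(vocab)}
id_to_char = {i: c for i, c in enumerate(vocab)}
print(vocab)  # ['.', 'H', 'R', 'i', 'n', 'u']
```

Sorting matters only for reproducibility; any fixed ordering of the character set would work.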
Code:
INUPT_LENGTH = max([len(i) for i in input_texts])
OUTPUT_LENGTH = max([len(i) for i in target_texts])
INPUT_FEATURE_LENGTH = len(input_characters)
OUTPUT_FEATURE_LENGTH = len(target_characters)
Vectorization
Code:
encoder_input = np.zeros((NUM_SAMPLES, INUPT_LENGTH, INPUT_FEATURE_LENGTH))
decoder_input = np.zeros((NUM_SAMPLES, OUTPUT_LENGTH, OUTPUT_FEATURE_LENGTH))
decoder_output = np.zeros((NUM_SAMPLES, OUTPUT_LENGTH, OUTPUT_FEATURE_LENGTH))

input_dict = {char: index for index, char in enumerate(input_characters)}
input_dict_reverse = {index: char for index, char in enumerate(input_characters)}
target_dict = {char: index for index, char in enumerate(target_characters)}
target_dict_reverse = {index: char for index, char in enumerate(target_characters)}

for seq_index, seq in enumerate(input_texts):
    for char_index, char in enumerate(seq):
        encoder_input[seq_index, char_index, input_dict[char]] = 1

for seq_index, seq in enumerate(target_texts):
    for char_index, char in enumerate(seq):
        decoder_input[seq_index, char_index, target_dict[char]] = 1.0
        if char_index > 0:
            # decoder_output is decoder_input shifted one time step to the left
            decoder_output[seq_index, char_index - 1, target_dict[char]] = 1.0
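The one-step offset between `decoder_input` and `decoder_output` is the teacher-forcing target shift: at every time step the decoder receives the correct previous character and is trained to predict the next one. A self-contained toy sketch of the same construction (the 4-character vocabulary is invented purely for illustration):

```python
import numpy as np

# Hypothetical 4-character "vocabulary"; the real script builds this from the dataset.
chars = ['\t', '\n', 'A', 'B']
char_to_id = {c: i for i, c in enumerate(chars)}

target = '\tAB\n'  # start token + text + end token, as in df['targets']
dec_in = np.zeros((1, len(target), len(chars)))
dec_out = np.zeros((1, len(target), len(chars)))
for t, ch in enumerate(target):
    dec_in[0, t, char_to_id[ch]] = 1.0
    if t > 0:
        # the target at step t-1 is the character fed in at step t
        dec_out[0, t - 1, char_to_id[ch]] = 1.0

# At step 0 the decoder sees '\t' (index 0) and must predict 'A' (index 2)
print(np.argmax(dec_in[0, 0]), np.argmax(dec_out[0, 0]))
```

The same inspection as below (`np.argmax` per time step) recovers `'\tAB\n'` from `dec_in` and `'AB\n'` from `dec_out`.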
Inspect the vectorized data
Code:
''.join([input_dict_reverse[np.argmax(i)] for i in encoder_input[0] if max(i) !=0])
Output:
'Hi.'
Code:
''.join([target_dict_reverse[np.argmax(i)] for i in decoder_output[0] if max(i) !=0])
Output:
'嗨。\n'
Code:
''.join([target_dict_reverse[np.argmax(i)] for i in decoder_input[0] if max(i) !=0])
Output:
'\t嗨。\n'
Create the model
Code:
model_train, encoder_infer, decoder_infer = create_model(INPUT_FEATURE_LENGTH, OUTPUT_FEATURE_LENGTH, N_UNITS)
Output:
WARNING:tensorflow:From C:\3rd\Anaconda2\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
Code:
# Inspect the model structure
plot_model(to_file='model.png', model=model_train, show_shapes=True)
plot_model(to_file='encoder.png', model=encoder_infer, show_shapes=True)
plot_model(to_file='decoder.png', model=decoder_infer, show_shapes=True)
model_train.compile(optimizer='rmsprop', loss='categorical_crossentropy')
model_train.summary()
Output:
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, None, 73)     0
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None, 2623)   0
__________________________________________________________________________________________________
lstm_1 (LSTM)                   [(None, 256), (None, 337920      input_1[0][0]
__________________________________________________________________________________________________
lstm_2 (LSTM)                   [(None, None, 256),  2949120     input_2[0][0]
                                                                 lstm_1[0][1]
                                                                 lstm_1[0][2]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, None, 2623)   674111      lstm_2[0][0]
==================================================================================================
Total params: 3,961,151
Trainable params: 3,961,151
Non-trainable params: 0
__________________________________________________________________________________________________
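The `categorical_crossentropy` loss chosen at compile time pairs with the one-hot targets built earlier: per sample and time step it is minus the log of the probability the softmax assigns to the correct character. A NumPy sketch of the formula (this is not Keras's internal implementation, which also handles batching and masking details):

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over samples of -sum(y_true * log(y_pred))."""
    y_pred = np.clip(y_pred, eps, 1.0)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))

y_true = np.array([[0.0, 1.0, 0.0]])   # one-hot: the correct class is index 1
y_pred = np.array([[0.1, 0.8, 0.1]])   # a softmax output
loss = categorical_crossentropy(y_true, y_pred)
print(round(loss, 4))  # 0.2231, i.e. -log(0.8)
```

Because the targets are strictly one-hot, only the predicted probability of the correct character contributes to the loss at each step.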
Code:
encoder_infer.summary()
Output:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, None, 73)          0
_________________________________________________________________
lstm_1 (LSTM)                [(None, 256), (None, 256) 337920
=================================================================
Total params: 337,920
Trainable params: 337,920
Non-trainable params: 0
_________________________________________________________________
Code:
decoder_infer.summary()
Output:
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_2 (InputLayer)            (None, None, 2623)   0
__________________________________________________________________________________________________
input_3 (InputLayer)            (None, 256)          0
__________________________________________________________________________________________________
input_4 (InputLayer)            (None, 256)          0
__________________________________________________________________________________________________
lstm_2 (LSTM)                   [(None, None, 256),  2949120     input_2[0][0]
                                                                 input_3[0][0]
                                                                 input_4[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, None, 2623)   674111      lstm_2[1][0]
==================================================================================================
Total params: 3,623,231
Trainable params: 3,623,231
Non-trainable params: 0
__________________________________________________________________________________________________
Model training
Code:
model_train.fit([encoder_input,decoder_input],decoder_output,batch_size=BATCH_SIZE,epochs=EPOCH,validation_split=0.2)
Output:
WARNING:tensorflow:From C:\3rd\Anaconda2\lib\site-packages\tensorflow\python\ops\math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
WARNING:tensorflow:From C:\3rd\Anaconda2\lib\site-packages\tensorflow\python\ops\math_grad.py:102: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Train on 8000 samples, validate on 2000 samples
Epoch 1/50
8000/8000 [==============================] - 165s 21ms/step - loss: 2.0312 - val_loss: 2.5347
Epoch 2/50
8000/8000 [==============================] - 135s 17ms/step - loss: 1.9115 - val_loss: 2.4281
Epoch 3/50
8000/8000 [==============================] - 137s 17ms/step - loss: 1.8016 - val_loss: 2.3045
Epoch 4/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.7122 - val_loss: 2.2427
Epoch 5/50
8000/8000 [==============================] - 138s 17ms/step - loss: 1.6271 - val_loss: 2.1663
Epoch 6/50
8000/8000 [==============================] - 135s 17ms/step - loss: 1.5521 - val_loss: 2.0765
Epoch 7/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.4851 - val_loss: 2.0489
Epoch 8/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.4294 - val_loss: 2.0093
Epoch 9/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.3731 - val_loss: 1.9413
Epoch 10/50
8000/8000 [==============================] - 137s 17ms/step - loss: 1.3267 - val_loss: 1.9104
Epoch 11/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.2840 - val_loss: 1.8889
Epoch 12/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.2452 - val_loss: 1.8658
Epoch 13/50
8000/8000 [==============================] - 135s 17ms/step - loss: 1.2077 - val_loss: 1.8468
Epoch 14/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.1747 - val_loss: 1.8372
Epoch 15/50
8000/8000 [==============================] - 135s 17ms/step - loss: 1.1425 - val_loss: 1.8197
Epoch 16/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.1125 - val_loss: 1.8159
Epoch 17/50
8000/8000 [==============================] - 135s 17ms/step - loss: 1.0834 - val_loss: 1.8012
Epoch 18/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.0553 - val_loss: 1.7965
Epoch 19/50
8000/8000 [==============================] - 136s 17ms/step - loss: 1.0287 - val_loss: 1.7954
Epoch 20/50
8000/8000 [==============================] - 137s 17ms/step - loss: 1.0026 - val_loss: 1.7882
Epoch 21/50
8000/8000 [==============================] - 136s 17ms/step - loss: 0.9766 - val_loss: 1.7861
Epoch 22/50
8000/8000 [==============================] - 138s 17ms/step - loss: 0.9517 - val_loss: 1.7907
Epoch 23/50
8000/8000 [==============================] - 137s 17ms/step - loss: 0.9274 - val_loss: 1.7936
Epoch 24/50
8000/8000 [==============================] - 135s 17ms/step - loss: 0.9044 - val_loss: 1.7815
Epoch 25/50
8000/8000 [==============================] - 135s 17ms/step - loss: 0.8811 - val_loss: 1.7831
Epoch 26/50
8000/8000 [==============================] - 136s 17ms/step - loss: 0.8592 - val_loss: 1.7894
Epoch 27/50
8000/8000 [==============================] - 137s 17ms/step - loss: 0.8376 - val_loss: 1.7932
Epoch 28/50
8000/8000 [==============================] - 139s 17ms/step - loss: 0.8161 - val_loss: 1.7874
Epoch 29/50
8000/8000 [==============================] - 137s 17ms/step - loss: 0.7947 - val_loss: 1.7913
Epoch 30/50
8000/8000 [==============================] - 136s 17ms/step - loss: 0.7746 - val_loss: 1.7912
Epoch 31/50
8000/8000 [==============================] - 136s 17ms/step - loss: 0.7545 - val_loss: 1.8008
Epoch 32/50
8000/8000 [==============================] - 136s 17ms/step - loss: 0.7353 - val_loss: 1.7989
Epoch 33/50
8000/8000 [==============================] - 135s 17ms/step - loss: 0.7164 - val_loss: 1.8023
Epoch 34/50
8000/8000 [==============================] - 136s 17ms/step - loss: 0.6984 - val_loss: 1.8090
Epoch 35/50
8000/8000 [==============================] - 135s 17ms/step - loss: 0.6803 - val_loss: 1.8092
Epoch 36/50
8000/8000 [==============================] - 142s 18ms/step - loss: 0.6627 - val_loss: 1.8177
Epoch 37/50
8000/8000 [==============================] - 140s 18ms/step - loss: 0.6467 - val_loss: 1.8249
Epoch 38/50
8000/8000 [==============================] - 144s 18ms/step - loss: 0.6288 - val_loss: 1.8290
Epoch 39/50
8000/8000 [==============================] - 143s 18ms/step - loss: 0.6127 - val_loss: 1.8333
Epoch 40/50
8000/8000 [==============================] - 144s 18ms/step - loss: 0.5973 - val_loss: 1.8454
Epoch 41/50
8000/8000 [==============================] - 143s 18ms/step - loss: 0.5824 - val_loss: 1.8488
Epoch 42/50
8000/8000 [==============================] - 160s 20ms/step - loss: 0.5669 - val_loss: 1.8490
Epoch 43/50
8000/8000 [==============================] - 143s 18ms/step - loss: 0.5529 - val_loss: 1.8600
Epoch 44/50
8000/8000 [==============================] - 144s 18ms/step - loss: 0.5383 - val_loss: 1.8636
Epoch 45/50
8000/8000 [==============================] - 144s 18ms/step - loss: 0.5244 - val_loss: 1.8806
Epoch 46/50
8000/8000 [==============================] - 827s 103ms/step - loss: 0.5120 - val_loss: 1.8877
Epoch 47/50
8000/8000 [==============================] - 147s 18ms/step - loss: 0.4992 - val_loss: 1.8902
Epoch 48/50
8000/8000 [==============================] - 150s 19ms/step - loss: 0.4859 - val_loss: 1.8981
Epoch 49/50
8000/8000 [==============================] - 149s 19ms/step - loss: 0.4737 - val_loss: 1.9022
Epoch 50/50
8000/8000 [==============================] - 149s 19ms/step - loss: 0.4619 - val_loss: 1.9096
<keras.callbacks.History at 0x2f3f3936908>
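Note the divergence in the log: the training loss falls monotonically to 0.46, but val_loss bottoms out near 1.78 around epoch 24 and then climbs steadily, so the model is overfitting the 8,000 training samples. Keras's `EarlyStopping` callback (e.g. `EarlyStopping(monitor='val_loss', patience=5)` passed via `fit(..., callbacks=[...])`) would halt training once val_loss stops improving. The plain-Python sketch below only illustrates that stopping rule, applied to val_loss values transcribed from the log (epochs 20 through 30):

```python
def early_stop_epoch(val_losses, patience=5):
    """Return (epoch stopped at, epoch with the best loss), 1-based."""
    best, best_epoch, wait = float('inf'), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:  # no improvement for `patience` epochs
                return epoch, best_epoch
    return len(val_losses), best_epoch

# val_loss values transcribed from the training log above (epochs 20-30)
logged = [1.7882, 1.7861, 1.7907, 1.7936, 1.7815, 1.7831, 1.7894,
          1.7932, 1.7874, 1.7913, 1.7912]
stop, best = early_stop_epoch(logged, patience=5)
print(stop, best)  # stops 5 epochs after the minimum at position 5 (= epoch 24)
```

In newer Keras versions `restore_best_weights=True` additionally rolls the model back to the best epoch's weights, which would keep the val_loss near its 1.78 minimum instead of the final 1.91.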
Predict sequences
Code:
def predict_chinese(source, encoder_inference, decoder_inference, n_steps, features):
    # First run the inference encoder to get the hidden states for the input sequence
    state = encoder_inference.predict(source)
    # The first character is '\t', the start token
    predict_seq = np.zeros((1, 1, features))
    predict_seq[0, 0, target_dict['\t']] = 1
    output = ''
    # Decode from the encoder's hidden states: each iteration feeds the
    # previously predicted character back in to predict the next character,
    # until the end token is produced
    for i in range(n_steps):  # n_steps is the maximum sentence length
        # Feed the decoder the previous step's h, c states and the previous
        # predicted character predict_seq
        yhat, h, c = decoder_inference.predict([predict_seq] + state)
        # Note: yhat is the output after the Dense layer, so it differs from h
        char_index = np.argmax(yhat[0, -1, :])
        char = target_dict_reverse[char_index]
        output += char
        state = [h, c]  # this step's states become the next step's initial states
        predict_seq = np.zeros((1, 1, features))
        predict_seq[0, 0, char_index] = 1
        if char == '\n':  # stop once the end token is predicted
            break
    return output
for i in range(1000, 1100):
    test = encoder_input[i:i+1, :, :]  # i:i+1 keeps the array three-dimensional
    out = predict_chinese(test, encoder_infer, decoder_infer, OUTPUT_LENGTH, OUTPUT_FEATURE_LENGTH)
    #print(input_texts[i], '\n---\n', target_texts[i], '\n---\n', out)
    print(input_texts[i])
    print(out)
Output:
I have brothers. 我有一個意思。
I have ten pens. 我有一個意見。
I have to hurry! 我有一个好。
I have two cats. 我有一個意見。
I have two sons. 我有一個意見。
I just threw up. 我只是不能看看。
I lent him a CD. 我喜欢你的。
I like Tom, too. 我喜欢你的房间。
I like football. 我喜欢跑步。
I like potatoes. 我喜欢跑步。
I like the cold. 我喜欢跑步。
I like this dog. 我喜欢跑步。
I like your car. 我喜欢你的车。
I lived in Rome. 我愛這台車。
I love this car. 我愛這台車。
I might say yes. 我非常喜欢我的。
I must help her. 我在這裡吃飯了。
I need a friend. 我需要一張郵票。
I need evidence. 我需要一張郵票。
I need you here. 我需要你的幫助。
I paid the bill. 我同意他。
I played tennis. 我喜歡運動。
I run every day. 我非常喜欢。
I speak Swedish. 我同意他。
I talked to her. 我打算去那裡。
I teach Chinese. 我相信你。
I think it's OK. 我想感觉。
I took a shower. 我想要一張。
I want a guitar. 我想要一個。
I want that bag. 我想要一點。
I want to drive. 我想要一個。
I was surprised. 我真是不能运。
I wish you'd go. 我希望你去。
I woke up early. 我希望你去。
I work too much. 我在这个工作。
I'll bring wine. 我會開車。
I'll never stop. 我會打網球。
I'm a foreigner. 我是個老師。
I'm a night owl. 我是个男人。
I'm about ready. 我是個老師。
I'm always here. 我很高興。
I'm daydreaming. 我是個老師。
I'm feeling fit. 我是一個好男孩。
I'm left-handed. 我是一個好人。
I'm not serious. 我不是你的。
I'm out of time. 我是个男人。
I'm really busy. 我很快樂。
I'm really cold. 我很高興。
I'm still angry. 我是個老師。
I'm very hungry. 我很快樂。
I'm very lonely. 我很好奇。
I've had enough. 我有一切。
I've had enough. 我有一切。
Is Tom Canadian? 汤姆是好人。
Is he breathing? 他是日本人吗?
Is it all there? 那是一只是嗎?
Is it too salty? 那是你的車嗎?
Is she Japanese? 他的是日本人嗎?
Is this a river? 这是一只是用的铅笔吗?
Isn't that mine? 那是我的吗?
It is up to you. 那是你的想象。
It snowed a lot. 它不是真的。
It was terrible. 它是我們的錯誤。
It was very far. 它是我們的秘密。
It'll be cloudy. 它很大。
It's a dead end. 它是我們的錯。
It's a new book. 它是我們的秘密。
It's a nice day. 它是我們的秘密。
It's a surprise. 它是我們的!
It's almost six. 它是一個很好的問題。
It's already 11. 它是我們的秘密。
It's fine today. 它是一個好主意。
It's impossible. 它是一個很好的問題。
It's lunch time. 它是我們的錯。
It's okay to go. 它是我們的秘密。
It's over there. 它是一個好。
It's time to go. 它是我最好的。
It's time to go. 它是我最好的。
Jesus loves you. 把它给我的!
Keep on smiling. 保持微笑。
Keep on working. 让不要保持。
Keep the change! 繼續工作!
Large, isn't it? 很少, 不是嗎?
Lemons are sour. 讓我看一下。
Let me go alone. 讓我看看。
Let me see that. 讓我看看。
Let them decide. 讓我看看。
Let's eat sushi. 讓我們開始吧。
Let's go by bus. 讓我們開始吧。
Let's not argue. 讓我們開始吧。
Let's turn back. 讓我們開始吧。
Look at the sky. 看看那個好。
Look behind you. 看看那個人。
Make it smaller. 把它弄小一點。
May I leave now? 我可以看看你的行嗎?
May I try it on? 我可以看看你的行嗎?
Maybe next time. 大部電話是我的工作。
Men should work. 大下周日。
Merry Christmas! 瑪麗很高。
Mom, I'm hungry. 媽,我可以去游泳嗎?
References:
https://zybuluo.com/hanbingtao/note/541458
https://zybuluo.com/hanbingtao/note/581764