最近看了不少关于写诗的博客,在前人的基础上做了一些小的改动,因比较喜欢一次输入很长的开头句,所以让机器人输出压缩为一个开头字生成两个诗句,写五言和七言诗,当然如果你想写更长的诗句是可以继续改动的。

在输入做了些改动,去除误输入的标点符号,例如输入下面词句:

怒发冲冠,凭栏处,潇潇雨歇。抬望眼,仰天长啸,壮怀激烈。

机器人写出如下:

怒漠多无度袍小,巡管山明恰见偷。
发杵共鸿莼散暮,家山曾住上阳台。
冲钩麻衣隐步障,楼舟复别赤轮楼。
冠盘一线倾中令,音信长思两足阴。
凭栏十字送月沈,莫待长筵韦与兵。
栏湿地闲鱼脚吠,匣中虚伴虎前羊。
处向梅香千万里,石城心中明碧簟。
潇湘夜塘独思不,几处深笼不尽迟。
潇湘十二海云天,五月今留报主人。
雨后青青沧练过,月明南渡雁来新。
歇鞭千里知贤哭,清镜草浮麋没繁。
抬吾谢安犹带减,醉来赤此不才齐。
望台已待隋家咏,楼上不能犹扫成。
眼看药炉香岭上,比惟无言旧青春。
仰归安得衰无事,三笑出师湘水春。
天初碧玉衣襟淡,国药满川霞彩寒。
长叹榆关家已远,从来父子屈襄郎。
啸花青石速望尽,宛逐汀洲随并年。
壮时天下还如旧,生计孤吟去杀无。
怀哉却寄终拘束,莫道人来有也才。
激眼剑旗喧并髻,新菰麦落破门骄。
烈灵不识槛西间,芳草青天有五禽。
怒搜温液切,若近太阳香。
发欲奔宾影,争得频人怜。
冲腾临缺曙,谢豹出红残。
冠盖若移在,想得绛皇皇。
凭高晋家雨,才不寄黄金。
栏落临巨浸,根孔恨仙桃。
处世愿越游,飘扬共行之。
潇湘南北洞,蜀国湘江湄。
潇湘弦管绝,上月洞庭时。
雨历道中朔,不起列仙风。
歇毂须江道,家歌住忽依。
抬山弄弓寞,装束岛霞裙。
望闻拜天子,幸有窦金狐。
眼看尽无些,香雨不兴天。
仰阮子不笔,不敢相思逸。
天与十二胆,彩笔取时七。
长安陇波归,银没乌方地。
啸作胡人船,大作康庄匠。
壮何时七两,隐括为匡庐。
怀君青纪肥,常带牺胆圣。
激逐鸱子子,人来儆此处。
烈太仓毛黄,家长受德吉。

代码如下:

main.py

import collections
import os
import sys
import re
import numpy as np
import tensorflow as tf
from model import rnn_model
from poems import process_poems, generate_batch os.environ['TF_CPP_MIN_LOG_LEVEL']='' tf.flags.DEFINE_integer('batch_size', 64, 'batch size.')
tf.flags.DEFINE_float('learning_rate', 0.01, 'learning rate.')
# set this to 'main.py' relative path
tf.flags.DEFINE_string('checkpoints_dir', './checkpoints/', 'checkpoints save path.')
tf.flags.DEFINE_string('file_path', './data/poetry.txt', 'file name of poems.') tf.flags.DEFINE_integer('epochs', 50, 'train how many epochs.') FLAGS = tf.flags.FLAGS start_token = 'G'
end_token = 'E' #开始训练
def run_training():
if not os.path.exists(os.path.dirname(FLAGS.checkpoints_dir)):
os.mkdir(os.path.dirname(FLAGS.checkpoints_dir))
if not os.path.exists(FLAGS.checkpoints_dir):
os.mkdir(FLAGS.checkpoints_dir)
# 单词转化的数字:向量,单词和数字一一对应的字典,单词
poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
# 真实值和目标值
batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)
# 数据占位符
input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None]) end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets, vocab_size=len(
vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate)
# 实例化保存模型
saver = tf.train.Saver(tf.global_variables())
# 全局变量进行初始化
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
# sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
# sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
# 先执行,全局变量初始化
sess.run(init_op) start_epoch = 0
# 把之前训练过的checkpoint拿出来
checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
if checkpoint:
# 拿出训练保存模型
saver.restore(sess, checkpoint)
print("[INFO] restore from the checkpoint {0}".format(checkpoint))
start_epoch += int(checkpoint.split('-')[-1])
print('[INFO] start training...')
try:
for epoch in range(start_epoch, FLAGS.epochs):
n = 0
# 多少行唐诗//每次训练的个数
n_chunk = len(poems_vector) // FLAGS.batch_size
for batch in range(n_chunk):
loss, _, _ = sess.run([
end_points['total_loss'], # 损失
end_points['last_state'], # 最后一次输出
end_points['train_op'] # 训练优化损失
], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
n += 1
print('[INFO] Epoch: %d , batch: %d , training loss: %.6f' % (epoch, batch, loss))
if epoch % 6 == 0: # 每隔多少次保存
saver.save(sess, FLAGS.checkpoints_dir, global_step=epoch)
except KeyboardInterrupt:
print('[INFO] Interrupt manually, try saving checkpoint for now...')
saver.save(sess, FLAGS.checkpoints_dir, global_step=epoch)
print('[INFO] Last epoch were saved, next time will start from epoch {}.'.format(epoch)) def to_word(predict, vocabs):
t = np.cumsum(predict)
s = np.sum(predict)
# searchsorted 在前面查找后面的
sample = int(np.searchsorted(t, np.random.rand(1) * s))
# sample = np.argmax(predict)
if sample > len(vocabs):
sample = len(vocabs) - 1
return vocabs[sample] #调用模型生成诗句
def gen_poem(begin_words, num):
batch_size = 1
print('[INFO] loading corpus from %s' % FLAGS.file_path)
# 单词转化的数字:向量,单词和数字一一对应的字典,单词
poems_vector, word_int_map, vocabularies = process_poems(FLAGS.file_path)
# 此时输入为1个
input_data = tf.placeholder(tf.int32, [batch_size, None])
# 损失等
end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=FLAGS.learning_rate) saver = tf.train.Saver(tf.global_variables())
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
with tf.Session() as sess:
sess.run(init_op)
# 保存模型的位置,拿回sess
checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoints_dir)
# checkpoint = tf.train.latest_checkpoint('./model/') saver.restore(sess, checkpoint)
# saver.restore(sess,'./model/-24')
# 从字典里面获取到的开始值
x = np.array([list(map(word_int_map.get, start_token))]) [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
feed_dict={input_data: x})
poem = ''
for begin_word in begin_words: while True:
if begin_word:
word = begin_word
else:
word = to_word(predict, vocabularies)
sentence = ''
while word != end_token: sentence += word
x = np.zeros((1, 1))
x[0, 0] = word_int_map[word]
[predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
feed_dict={input_data: x, end_points['initial_state']: last_state})
word = to_word(predict, vocabularies)
# word = words[np.argmax(probs_)]
if len(sentence) == 2 + 2 * num and (',' or '?') not in sentence[:num] and (',' or '?') not in sentence[num+1:-1] and sentence[num] == ',' and '□' not in sentence:
poem += sentence
# sentence = ''
break
else:
print("我正在写诗呢") return poem #这里将生成的诗句,按照中文诗词的格式输出
#同时方便接入应用
def pretty_print_poem(poem):
poem_sentences = poem.split('。')
# print(poem_sentences)
for s in poem_sentences:
if s != '' and len(s) > 10:
# if s != '': print(s + '。') def main():
if len(sys.argv) == 2:
if sys.argv[1] == '':
print('[INFO] train tang poem...')
run_training()
elif sys.argv[1] == '':
num = int(input("请输入训练诗句(5:五言,7:七言):"))
if num == 5 or num == 7:
print('[INFO] write tang poem...')
begin_word = input('开始作诗,请输入起始字:')
if len(begin_word) == 0:
print("请输入词句")
return
r1 = '[a-zA-Z0-9’!"#$%&\'()*+,-./:;<=>?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'
begin_word = re.sub(r1, '', begin_word)
poem2 = gen_poem(begin_word, num)
pretty_print_poem(poem2)
else:
print('输入有误') else:
print('a',sys.argv[1])
print("请按照以下方式执行:")
print("python xxxx.py 1(1:训练,2:写诗)")
else:
print(len(sys.argv))
print("请按照以下方式执行:")
print("python xxxx.py 1(1:训练,2:写诗)") if __name__ == '__main__':
main()

将诗歌和单词转换为一一对应的数字:

def process_poems(file_name):
# 诗集
poems = []
with open(file_name, "r", encoding='utf-8', ) as f:
for line in f.readlines():
try:
title, content = line.strip().split(':')
content = content.replace(' ', '')
if '_' in content or '(' in content or '《' in content or '[' in content or \
start_token in content or end_token in content:
continue
if len(content) < 5 or len(content) > 79:
continue
content = start_token + content + end_token
poems.append(content)
except ValueError as e:
pass
# 按诗的字数排序
poems = sorted(poems, key=lambda l: len(line)) # 统计每个字出现次数
all_words = []
for poem in poems:
all_words += [word for word in poem]
# 这里根据包含了每个字对应的频率 Counter({'我': 3, '你': 2, '他': 1})
counter = collections.Counter(all_words) # items转化为列表,里面为元组 [('他', 1), ('你', 2), ('我', 3)]
count_pairs = sorted(counter.items(), key=lambda x: x[-1])
# ('他', '你', '我'),(1, 2, 3)
words, _ = zip(*count_pairs) # 取前多少个常用字
words = words[:len(words)] + (' ',)
# words = words[:len(words)]
# 每个字映射为一个数字ID {'他': 0, '你': 1, '我': 2}
word_int_map = dict(zip(words, range(len(words))))
# 将诗歌中每个字转换成一一对应的数字,输入poem中的每个字,转化成对应数字返回
# 没有获得word对应数字Id,就返回len(words)
poems_vector = [list(map(lambda word: word_int_map.get(word, len(words)), poem)) for poem in poems] return poems_vector, word_int_map, words

定义真实值和目标值:

def generate_batch(batch_size, poems_vec, word_to_int):
# 每次取64首诗进行训练
# 计算有多少个batch_size
n_chunk = len(poems_vec) // batch_size
x_batches = []
y_batches = []
for i in range(n_chunk):
start_index = i * batch_size
end_index = start_index + batch_size batches = poems_vec[start_index:end_index]
# 找到这个batch的所有poem中最长的poem的长度
length = max(map(len, batches))
# 填充一个这么大小的空batch,空的地方放空格进行填充
x_data = np.full((batch_size, length), word_to_int[' '], np.int32)
for row in range(batch_size):
# 每一行就是一首诗,在原本的长度上把诗还原上去
x_data[row, :len(batches[row])] = batches[row]
y_data = np.copy(x_data)
# y的话就是x向左边也就是前面移动一个
y_data[:, :-1] = x_data[:, 1:]
"""
x_data y_data
[6,2,4,6,9] [2,4,6,9,9]
[1,4,2,8,5] [4,2,8,5,5]
"""
x_batches.append(x_data)
y_batches.append(y_data)
return x_batches, y_batches

定义训练模型:

# -*- coding: utf-8 -*-
import os
import tensorflow as tf
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL']='' #不显示一些提示警告信息 def rnn_model(model, input_data, output_data, vocab_size, rnn_size=128, num_layers=2, batch_size=64,
learning_rate=0.01):
"""
construct rnn seq2seq model.
:param model: model class 模型种类
:param input_data: input data placeholder 输入
:param output_data: output data placeholder 输出
:param vocab_size: 词长度
:param rnn_size: 一个RNN单元的大小
:param num_layers: RNN层数,神经元
:param batch_size: 步长
:param learning_rate: 学习速率
:return:
"""
end_points = {} def rnn_cell():
if model == 'rnn':
cell_fun = tf.contrib.rnn.BasicRNNCell
elif model == 'gru':
cell_fun = tf.contrib.rnn.GRUCell
elif model == 'lstm':
# 基础模型
cell_fun = tf.contrib.rnn.BasicLSTMCell
# 指定rnn_size大小,H,C,控制值和输出值,是否当做元组返回,默认为True
cell = cell_fun(rnn_size, state_is_tuple=True)
return cell
# 基本单元
cell = tf.contrib.rnn.MultiRNNCell([rnn_cell() for _ in range(num_layers)], state_is_tuple=True)
# cell = tf.contrib.rnn.MultiRNNCell([rnn_cell()] * num_layers, state_is_tuple=True) if output_data is not None:
# 初始化
initial_state = cell.zero_state(batch_size, tf.float32)
else:
initial_state = cell.zero_state(1, tf.float32)
# 指定用cpu运行
with tf.device("/cpu:0"):
# 输入的向量转化为128维向量,所以先构建隐层,指定值为+1 到-1区间
embedding = tf.get_variable('embedding', initializer=tf.random_uniform(
[vocab_size + 1, rnn_size], -1.0, 1.0))
# embedding = tf.Variable(tf.random_uniform([vocab_size + 1,rnn_size],-1.0,1.0))
# 输入的不是所有的词,是所有词的一部分,寻找属于哪个词
inputs = tf.nn.embedding_lookup(embedding, input_data) # [batch_size, ?, rnn_size] = [64, ?, 128]
# 输出,和最后一次的输出
outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
# 转换为128列的输出,output相当于中间一个128维隐层结果,向量
output = tf.reshape(outputs, [-1, rnn_size])
# 128维的权重,和词汇量
weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
# 偏置
bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
# bias加到前面没一行,预测值
logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias)
# [?, vocab_size+1] if output_data is not None:
# output_data must be one-hot encode 指定深度为,词汇量深度+1,真实值
labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)
# should be [?, vocab_size+1]
# 计算损失,真实值和预测值
loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
# loss shape should be [?, vocab_size+1]
# 对损失求平均
total_loss = tf.reduce_mean(loss)
# 训练,优化损失
train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss) end_points['initial_state'] = initial_state
end_points['output'] = output
end_points['train_op'] = train_op
end_points['total_loss'] = total_loss
end_points['loss'] = loss
end_points['last_state'] = last_state
else:
# 预测值
prediction = tf.nn.softmax(logits) end_points['initial_state'] = initial_state
end_points['last_state'] = last_state
end_points['prediction'] = prediction return end_points

训练的时候大概训练了半天吧,基于cpu,loss降到2到3之间就没训练了。古诗里面有些宋词,所以可能会写出一些宋词来。因此当写出宋词的时候,统统过滤了,让机器人从新写一次。

怒浩叛奴风入襄,相思犹记浇秋木。
发夜在关江伯均,屏帏倾酒醺醺时。
冲棱出与征轮钓,北阙朱门注胜囚。
冠子弄衣诗句苦,雁归羊祜得行戎。
凭仗眼巡分一已,拂眉下枕风吹鬓。
栏畔秋风听别情,无将日月悲卮花。
处飞更醉紫骝出,莫怕黄芽头孤酒。
潇湘似人风共恶,月影湿金横向波。
潇湘湘水涵红团,拖蜡桐花过酒楼。
雨丰寒鸟没疏角,雨过嵩门逐几层。
歇山地在秦栖狎,借问纯山色色微。
抬普高花量且白,也遭纱帽却离歌。
望华衔泥不复相,时时避缴如红兰。
眼眼南天南岳畔,十年章岸一何当。
仰头忽要天台绿,九没街边雪里临。
天下别容兵起力,怀王过写复何人。
长江九泽分明月,应逐征帆转汉阳。
啸骨立亡何所赠,别时身去坦云中。
壮时烟雨分雪发,战马翩翩一千尺。
怀乡生抛怀甲得,终竟颠应年子人。
激尔饷于无一事,岛层明月上高台。
烈士不能还续竹,应经此去虽迷定。

最新文章

  1. JS函数无响应
  2. 参考__MySql
  3. [Android] HttpURLConnection &amp; HttpClient &amp; Socket
  4. C# 路径
  5. PS网页设计教程XXVIII——如何在PS中创建一个干净的网页布局
  6. javap(反汇编命令)详解【转】
  7. bodybuilding
  8. javaSE第七天
  9. 探索VS中C++多态实现原理
  10. Linux 命令 - kill: 向进程发送信号
  11. 《JS权威指南学习总结--9.5 类和类型》
  12. DataTabel DataSet 对象 转换成json
  13. 团队作业8——Beta 阶段冲刺6th day
  14. 再谈RunLoop
  15. Log4Net 生成多个文件、文件名累加解决方法
  16. JS_高程6.面向对象的程序设计(2)创建对象_3 构造函数存在的问题
  17. CSDN不登录阅读全文(最新更新
  18. Jenkins 批量删除历史构建
  19. 《算法》第四章部分程序 part 12
  20. 怎么使用C++标准库来实现二维数组

热门文章

  1. too many open files linux服务器 golang java
  2. JS代码检查工具ESLint
  3. CAD中用户选择实体
  4. 「关于一种处理关于$p$成多项式的数论函数筛法」
  5. BZOJ_[JSOI2010]Group 部落划分 Group_kruskal
  6. 当需要向数据库插入空值时,sql语句的判断
  7. 死链接检查工具:Xenu 使用教程
  8. cannot be cast to java.lang.Comparable
  9. 【ODI】| 数据ETL:从零开始使用Oracle ODI完成数据集成(二)
  10. 把时间留给重要的事——Markdown 模板功能上线