#!/usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf

class TRNNConfig(object):
"""RNN配置参数"""

# 模型参数
embedding_dim = 64 # 词向量维度
seq_length = 600 # 序列长度
num_classes = 10 # 类别数
vocab_size = 5000 # 词汇表达小

num_layers= 2 # 隐藏层层数
hidden_dim = 128 # 隐藏层神经元
rnn = 'gru' # lstm 或 gru

dropout_keep_prob = 0.8 # dropout保留比例
learning_rate = 1e-3 # 学习率

batch_size = 128 # 每批训练大小
num_epochs = 10 # 总迭代轮次

print_per_batch = 100 # 每多少轮输出一次结果
save_per_batch = 10 # 每多少轮存入tensorboard

class TextRNN(object):
"""文本分类,RNN模型"""
def __init__(self, config):
self.config = config

# 三个待输入的数据
self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

self.rnn()

def rnn(self):
"""rnn模型"""

def lstm_cell(): # lstm核
return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True)

def gru_cell(): # gru核
return tf.contrib.rnn.GRUCell(self.config.hidden_dim)

def dropout(): # 为每一个rnn核后面加一个dropout层
if (self.config.rnn == 'lstm'):
cell = lstm_cell()
else:
cell = gru_cell()
return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)

# 动作映射
with tf.device('/cpu:0'):
embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

with tf.name_scope("rnn"):
# 多层rnn网络
cells = [dropout() for _ in range(self.config.num_layers)]
rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

_outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32)
last = _outputs[:, -1, :] # 取最后一个时序输出作为结果

with tf.name_scope("score"):
# 全连接层,后面接dropout以及relu激活
fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1')
fc = tf.contrib.layers.dropout(fc, self.keep_prob)
fc = tf.nn.relu(fc)

# 分类器
self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
# 预测类别
self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)

with tf.name_scope("optimize"):
# 损失函数,交叉熵
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
#求输入的所有行的预测值的均值
self.loss = tf.reduce_mean(cross_entropy)
# 优化器
self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

with tf.name_scope("accuracy"):
# 准确率 其中 self.y_pred_cls为预测的类别
correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
#
self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

最新文章

  1. 迅为-iMX6开发板 飞思卡尔iMX6Q开发板 工业级开发板
  2. eclipse配置jdk的src.zip源代码步骤
  3. linux | 管道符、输出重定向
  4. ASP.NET MVC路由配置(转载自http://www.cnblogs.com/zeusro/p/RouteConfig.html )
  5. hdu 2545(并查集求节点到根节点的距离)
  6. Android 2D游戏引擎AndEngine配置环境
  7. [Everyday Mathematics]20150123
  8. Spring MVC 3.0.5+Spring 3.0.5+MyBatis3.0.4全注解实例详解(一)
  9. smarty 变量调节器
  10. 【最大流】【HDU3572】Task Schedule
  11. ubuntu字符界面怎么设置中文显示和中文输入
  12. Can't find msguniq. Make sure you have GNU gettext tools 0.15 or newer installed
  13. 多个git使用的 ssh key共存
  14. linux I/O状态实时监控iostat
  15. 在linux下,去除^M,将windows格式文件(dos文件)改为unix格式文件
  16. 手动搭建一个webpack+react笔记
  17. OCR 识别原理
  18. .net 使用com组件操作word遇到的一些问题
  19. Python3 Tkinter-Scale
  20. autofac.webapi2

热门文章

  1. Linux网络管理——nslookup
  2. python链接sql server 乱码问题
  3. powerlink的Windows-DEMO生成笔记
  4. (14)占位符%和format
  5. ubuntu---解决pip安装tf很慢的问题
  6. 使用python批量造测试数据
  7. FFmpeg常用命令学习笔记(七)直播相关命令
  8. 将 Python 程序打包成 .exe 文件
  9. ZrOJ #882. 画画
  10. Math.cbrt() Math.sqrt() Math.pow()