首先需要安装gym模块,提供游戏的。

1,所需模块

import tensorflow as tf
import numpy as np
import gym
import random
from collections import deque
from keras.utils.np_utils import to_categorical

2,自定义一个简单的3层Dense Model

# 自定义Model
class QNetwork(tf.keras.Model):
def __init__(self):
super().__init__()
# 简单的3个Dense
self.dense1=tf.keras.layers.Dense(24,activation='relu')
self.dense2=tf.keras.layers.Dense(24,activation='relu')
self.dense3=tf.keras.layers.Dense(2)
def call(self,inputs):
x=self.dense1(inputs)
x=self.dense2(x)
x=self.dense3(x)
return x
def predict(self,inputs):
q_values=self(inputs)#调用call
return tf.argmax(q_values,axis=-1)

3,定义相关参数

# 游戏环境,实例化一个游戏
env=gym.make('CartPole-v1')
model=QNetwork() # 循环轮数设置小一点,50就可以了
num_episodes=500
num_exploration=100
max_len=1000
batch_size=32
lr=1e-3
gamma=1.
initial_epsilon=1.
final_epsilon=0.01
replay_buffer=deque(maxlen=10000) epsilon=initial_epsilon
# tensorflow2.0
optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=lr)

4,训练,测试

for i in range(num_episodes):
# 初始化环境
state=env.reset()
# 逐渐衰减,至final_epsilon
epsilon=max(initial_epsilon*(num_exploration-i)/num_exploration,final_epsilon)
for t in range(max_len):
# 当前帧绘制到屏幕
env.render()
# 以epsilon的概率随机行动,epsilon是衰减的,说明游戏动作会越来越稳定
if random.random()<epsilon:
action=env.action_space.sample()
else:
# 从当前状态预测一个动作
action=model.predict(tf.constant(np.expand_dims(state,axis=0),dtype=tf.float32)).numpy()
action=action[0]
# 执行一步动作
next_state,reward,done,info=env.step(action)
# 奖励
reward=-10.if done else reward
# 缓存
replay_buffer.append((state,action,reward,next_state,done))
state=next_state
if done:
print('episode %d,epsilon %f,score %d'%(i,epsilon,t))
break
# 预测batch_size步后执行
if len(replay_buffer)>=batch_size:
# 随机获取一个batch的数据
batch_state,batch_action,batch_reward,batch_next_state,batch_done=\
[np.array(a,dtype=np.float32) for a in zip(*random.sample(replay_buffer,batch_size))]
# 下一个状态,由此得到的y为真实值
# 预测值与真实值的计算看不太懂
q_value=model(tf.constant(batch_next_state,dtype=tf.float32))
y=batch_reward+(gamma*tf.reduce_max(q_value,axis=1))*(1-batch_done)
with tf.GradientTape() as tape:
# loss=tf.losses.mean_squared_error(labels=y,predictions=tf.reduce_sum(
# model(tf.constant(batch_state))*tf.one_hot(batch_action,depth=2),axis=1))
loss=tf.losses.mean_squared_error(y,tf.reduce_sum(
model(tf.constant(batch_state))*to_categorical(batch_action,num_classes=2),axis=1))
grads=tape.gradient(loss,model.variables)
optimizer.apply_gradients(grads_and_vars=zip(grads,model.variables))

最终会出现一个窗口,平衡游戏不断进行。。。

上面注释部分因为tf.one_hot方法会报错。

最新文章

  1. 佳能6d 魔灯
  2. App Transport Security has blocked a cleartext HTTP (http://)
  3. Div添加阴影效果
  4. elasticsearch 之IK分词器安装
  5. Moqui学习Day3
  6. 解决Windows照片查看器中图片显示发黄的问题
  7. Cg Programming/Vertex Transformations
  8. 修改webftp,在线文件管理
  9. eMarketer:DMP帮广告主搞定大数据处理问题
  10. Python异步IO --- 轻松管理10k+并发连接
  11. 高性能WEB开发之Web性能测试工具推荐
  12. AliCTF 2016
  13. SVD神秘值分解
  14. Vulkan Tutorial 19 Vertex input description
  15. kvm克隆
  16. 安卓 新版本 获取wifi状态网络是否可用等
  17. [十二省联考2019]字符串问题——后缀自动机+parent树优化建图+拓扑序DP+倍增
  18. docker安装,err:exit status 255,提示找不到虚拟机IP
  19. 文本超过控件长度自动显示省略号的css
  20. 《python语言程序设计》_第二章笔记

热门文章

  1. TO B是什么?TO C呢?
  2. 构建的Web应用界面还不够好看?DevExtreme v19.1全新主题来袭
  3. grunt-contrib-compass 编译sass
  4. qt5---滑动条QSlider
  5. Adboost几个要点分析
  6. vue项目搭建步骤以及一些安装依赖包
  7. Array数组对象方法
  8. maven项目创建4
  9. TTTTTTTTTTTT POJ 2112 奶牛与机器 多重二分匹配 跑最大流 建图很经典!!
  10. 家谱(gen)x