全部代码如下(原文中以红色标出、代码中以 `changed` 注释标记的部分为与笔记二不同之处):

# 1. Import the necessary libraries
import numpy as np
import tensorflow as tf
import matplotlib
from matplotlib import pyplot as plt ######################################################################## #2.Set default parameters for plots
# 2. Default plot settings: large fonts and figure size, the STKaiTi font so
#    the Chinese axis/legend labels used in main() can be drawn, and ASCII
#    hyphens for minus signs (axes.unicode_minus = False).
matplotlib.rcParams.update({
    'font.size': 20,
    'figure.titlesize': 20,
    'figure.figsize': [9, 7],
    'font.family': ['STKaiTi'],
    'axes.unicode_minus': False,
})

# 3. Hyper-parameters and bookkeeping shared with main().
lr = 1e-2        # SGD step size
batchsz = 512    # samples per mini-batch
losses = []      # training loss, appended by main() at each logging step
accs = []        # test accuracy, appended by main() at each logging step

# Parameters of the 784 -> 256 -> 128 -> 10 fully connected network:
# kernels drawn from a truncated normal (stddev 0.1), biases start at zero.
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1))
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]))
#4.Define preprocess function #----------------------changed
def preprocess(x, y):
    """Turn one batch of raw MNIST samples into network inputs and targets.

    Applied with ``Dataset.map`` after ``batch``, so it receives whole batches.

    Args:
        x: batch of raw images; pixel values are divided by 255 to land in
            [0, 1] (assumes 0-255 pixels as loaded by keras mnist) and each
            image is flattened to 28*28 = 784 features.
        y: batch of integer class labels.

    Returns:
        Tuple ``(x, y)``: ``x`` is float32 with shape [-1, 784]; ``y`` is the
        one-hot encoding of the labels with depth 10.
    """
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [-1, 28 * 28])
    # tf.one_hot takes integer indices (int32 here); its output is float32.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y
# 5. Load MNIST from a local copy and build the input pipelines.
# NOTE(review): this absolute Windows path is machine-specific; adjust it on
# other hosts (per the Keras docs, a relative path is resolved under
# ~/.keras/datasets).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    path=r'F:\learning\machineLearning\TensorFlow2_deeplearning\forward_progression\mnist.npz')

train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(10000)
train_db = train_db.batch(batchsz)
# preprocess runs on whole batches, hence its reshape to [-1, 28*28].
train_db = train_db.map(preprocess)
# Repeating the dataset 20 times replaces an explicit `for epoch in range(20)`.
train_db = train_db.repeat(20)

test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(1000).batch(batchsz).map(preprocess)
def _forward(x):
    """Run the 3-layer MLP (two ReLU hidden layers) on ``x``; return logits."""
    h1 = tf.nn.relu(x @ w1 + b1)
    h2 = tf.nn.relu(h1 @ w2 + b2)
    return h2 @ w3 + b3


def _evaluate():
    """Return the classification accuracy of the current weights on test_db."""
    total, total_correct = 0., 0
    for x_batch, y_batch in test_db:
        pred = tf.argmax(_forward(x_batch), axis=1)
        # Targets are one-hot, so argmax recovers the class index.
        label = tf.argmax(y_batch, axis=1)
        correct = tf.equal(pred, label)
        total_correct += tf.reduce_sum(tf.cast(correct, dtype=tf.int32)).numpy()
        total += x_batch.shape[0]
    return total_correct / total


def main():
    """Train the MLP with plain SGD on train_db, then plot loss and accuracy.

    Every ``eval_every`` steps the current training loss is printed and
    appended to the module-level ``losses`` list, and the test accuracy is
    printed and appended to ``accs``. After training, one figure shows the
    loss curve and another the accuracy curve.
    """
    eval_every = 80  # logging/evaluation interval, also the plot x-axis spacing
    params = [w1, b1, w2, b2, w3, b3]

    for step, (x, y) in enumerate(train_db):
        with tf.GradientTape() as tape:
            out = _forward(x)
            # Mean squared error between the logits and the one-hot targets.
            loss = tf.reduce_mean(tf.square(y - out))
        grads = tape.gradient(loss, params)
        # Plain SGD: p <- p - lr * dL/dp, updated in place.
        for p, g in zip(params, grads):
            p.assign_sub(lr * g)

        if step % eval_every == 0:
            print(step, 'loss:', float(loss))
            losses.append(float(loss))
            acc = _evaluate()
            print(step, 'Evaluate ACC:', acc)
            accs.append(acc)

    # losses and accs are appended together, so they share one x axis.
    steps = [i * eval_every for i in range(len(losses))]

    plt.figure()
    plt.plot(steps, losses, color='C0', marker='s', label='训练')
    plt.ylabel('MSE')
    plt.xlabel('Step')
    plt.legend()

    plt.figure()
    plt.plot(steps, accs, color='C1', marker='s', label='测试')
    plt.ylabel('准确率')
    plt.xlabel('Step')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    main()

其中学习率(learning rate)在此处改为了 1e-2。经测试,若取 1e-3,准确率增长较慢,训练 20 个 epoch 后最终只能达到 30%~40%;而取 1e-2 时则接近 80%。

此外,通过 `.map(preprocess)` 对 train_db(以及 test_db)进行了预处理:将图片像素值归一化到 [0, 1],reshape 为 [-1, 28*28],并将标签转换为深度为 10 的 one-hot 编码;通过 `train_db = train_db.repeat(20)` 代替了 `for epoch in range(20)` 循环;用

# Apply one SGD step to every parameter, pairing each with its gradient.
for p, g in zip([w1, b1, w2, b2, w3, b3], grads):
    p.assign_sub(lr * g)
代替了
# Manual form: one SGD update per parameter, indexing grads by position
# (grads is ordered as [w1, b1, w2, b2, w3, b3]).
w1.assign_sub(lr * grads[0])
b1.assign_sub(lr * grads[1])
w2.assign_sub(lr * grads[2])
b2.assign_sub(lr * grads[3])
w3.assign_sub(lr * grads[4])
b3.assign_sub(lr * grads[5])

最新文章

  1. [Bug]2016-02
  2. Android 多媒体视频播放一( 多媒体理解与经验分享)
  3. [转] java中int,char,string三种类型的相互转换
  4. 读取缓存模拟----FIFO
  5. MySQL数据库优化技术概述
  6. dom div重合提示
  7. 解决lucene 重复索引的问题
  8. 微博输入相关js 代码
  9. ZABBIX安装官方指南
  10. Swift - 点击输入框外部屏幕关闭虚拟键盘
  11. kendo ui 单击取消编辑数据grid减少的原因和治疗方法的数据
  12. 使用Cookie来统计浏览次数,当天重复刷新不增加
  13. HandlerMapping 和 HandlerAdapter
  14. kafka全部数据清空与某一topic数据清空
  15. 胜利大逃亡,bfs,广度优先搜索
  16. JWT的相关讲解
  17. php获取本月、上月、上上月、今日、昨日、上周的起始时间
  18. Python+Selenium笔记(四):unittest的Test Suite(测试套件)
  19. k8s+Jenkins+GitLab-自动化部署asp.net core项目
  20. (c#) 销毁资源和释放内存

热门文章

  1. Bootstrap table插件 被选中的行颜色改变
  2. nginx/apache静态资源跨域访问问题详解
  3. ssh连接的原理
  4. .NET Core开源Quartz.Net作业调度框架实战演练
  5. 【并行计算-CUDA开发】【视频开发】ffmpeg Nvidia硬件加速总结
  6. hive 集群搭建
  7. readiness与liveness
  8. 终端下更改printk打印级别
  9. 更改ubuntu桌面环境
  10. C++工程师养成 每日一题(string使用)