from tensorflow.python.keras.preprocessing.image import load_img,img_to_array
from tensorflow.python.keras.models import Sequential,Model
from tensorflow.python.keras.layers import Dense,Flatten,Input
import tensorflow as tf
from tensorflow.python.keras.losses import sparse_categorical_crossentropy
from tensorflow.python import keras
import os
import numpy as np class SingleNN(object): #建立神经网络模型
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28,28)),
keras.layers.Dense(128,activation=tf.nn.relu),
keras.layers.Dense(10,activation=tf.nn.softmax)
]) def __init__(self):
(self.x_train,self.y_train),(self.x_test,self.y_test) = keras.datasets.fashion_mnist.load_data()
#归一化
self.x_train = self.x_train/255.0
self.x_test = self.x_test/255.0 def singlenn_compile(self):
'''
编译模型优化器、损失、准确率
:return:
'''
SingleNN.model.compile(
optimizer=keras.optimizers.SGD(lr=0.01),
loss=keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy']
) def singlenn_fit(self):
"""
进行fit训练
:return:
"""
SingleNN.model.fit(self.x_train,self.y_train,epochs=5) def single_evalute(self):
'''
模型评估
:return:
'''
test_loss,test_acc = SingleNN.model.evaluate(self.x_test,self.y_test)
print(test_loss,test_acc) def single_predict(self):
'''
预测结果
:return:
'''
# if os.path.exists("./ckpt/checkpoink"):
# SingleNN.model.load_weights("./ckpt/SingleNN") if os.path.exists("./ckpt/SingleNN.h5"):
SingleNN.model.load_weights("./ckpt/SingleNN.h5") predictions = SingleNN.model.predict(self.x_test) return predictions if __name__ == '__main__':
snn = SingleNN()
# snn.singlenn_compile()
# snn.singlenn_fit()
# snn.single_evalute()
# # SingleNN.model.save_weights("./ckpt/SingleNN")
# SingleNN.model.save_weights("./ckpt/SingleNN.h5")
predictions = snn.single_predict()
print(predictions)
result = np.argmax(predictions,axis=1)
print(result)

  

最新文章

  1. Duilib源码分析(六)整体流程
  2. STM32之待机唤醒
  3. 浅析Java中CountDownLatch用法
  4. 网络数据包收发流程(二):不配置NAPI的情况
  5. asp.net mvc 配合前端js的CMD模块化部署思想,小思路
  6. PHP学习笔记:用mysqli连接数据库
  7. [Educational Codeforces Round 16]A. King Moves
  8. 看项目得到info_freeCsdn-01闪屏页面
  9. 【转】Android Studio Essential Training
  10. Qt 学习之路:线程和 QObject
  11. 2015.9.11模拟赛 codevs 4159【hzwer的迷の数列】
  12. nodejs安装express遇到的坑
  13. vue-router 快速入门
  14. MapReduce程序依赖的jar包
  15. js基础之冒号
  16. PHP学习笔记10-图片加水印
  17. ZZCMS8.1|代码审计
  18. springboot 学习之路 6(集成durid连接池)
  19. MySql Scaffolding an Existing Database in EF Core
  20. XML基础入门

热门文章

  1. Mysql 随笔记录
  2. mybatis 入门基础
  3. MySQL进阶篇(01):基于多个维度,分析服务器性能
  4. CentOS7通过wget下载文件到指定目录
  5. Django-User
  6. ServletConfig&ServletContext对比
  7. 微信小程序template富文本插件image宽度被js强制设置
  8. C/C++ 数据精确度的设置
  9. Flask 入门(五)
  10. json格式的文件操作2