对于生活中的熟悉的动物,我们人脑经过一次扫描,便可以得到该动物的物种!那么机器是如何识别这个图片上的动物是属于哪一物种呢?

本次实验借生活中最常见的猫和狗来探究其原理!

环境准备:

tensorflow ,python,一些data

实验预期:

  当模型训练完成后,我们可以用该模型去预测一张图片属于哪一个类别,很显然,本次项目属于一个二分类问题,

  网上有很多此类的项目,但是都不能很好的落地,那么这次实验所完成的最终结果是,我们上传一张图片,控制台

  便会返回该图片的类别:猫/狗

模型搭建:

  对于图片识别来说,最强大的工具莫过于卷积神经网络,对于CNN的原理也不是很难,只要知道其主要的计算过程即可,

  熟悉CNN的人都知道,并不是层数越多越好,因为层数过多,会造正过拟合,导致实验结果不会很理想,所以经过我多次的实验,

  最终模型的设置如下:

  

model = tf.keras.models.Sequential([

    tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(150, 150, 3)),
tf.keras.layers.MaxPooling2D(2, 2),
tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D(2, 2),
tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D(2, 2), tf.keras.layers.Flatten(), tf.keras.layers.Dense(512, activation='relu'), tf.keras.layers.Dense(1, activation='sigmoid')
])

  每一层卷积跟一层最大池化,Conv2D()中参数:16表示卷积核个数,(3,3)表示卷积核大小,很多论文中给出的代码中设定的也是(3,3),input_shape表示输入数据形状,后面是通道数;

  经过最大池化留下来的神经元对输出才会有贡献!环节卷积层对位置的敏感性!

然后再模型之前,我们也需要对数据进行一些操作:读取数据,将数据分为验证数据集和训练数据集

base_dir = 'D:/cats and dogs'

train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation') train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs') validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

接下来的操作就是一些固定的步骤,对数据进行归一化,生成带标签的数据,绘制损失曲线等,直接上代码:

train_datagen = ImageDataGenerator(rescale=1.0 / 255.)
test_datagen = ImageDataGenerator(rescale=1.0 / 255.) train_generator = train_datagen.flow_from_directory(train_dir,
batch_size=20,
class_mode='binary',
target_size=(150, 150)) validation_generator = test_datagen.flow_from_directory(validation_dir,
batch_size=20,
class_mode='binary',
target_size=(150, 150)) history = model.fit_generator(train_generator,
validation_data=validation_generator,
steps_per_epoch=100,
epochs=15,
validation_steps=50,
verbose=2) model.save('model.h5') acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss'] epochs = range(len(acc)) plt.plot(epochs, acc)
plt.plot(epochs, val_acc)
plt.title('Training and validation accuracy')
plt.legend(('Training accuracy', 'validation accuracy'))
plt.figure() plt.plot(epochs, loss)
plt.plot(epochs, val_loss)
plt.legend(('Training loss', 'validation loss'))
plt.title('Training and validation loss')
plt.show()

预测部分

from tensorflow.keras.models import load_model
import numpy as np
from tensorflow.keras.preprocessing import image path = 'D:/cats and dogs/cat.123.jpg'
model = load_model('model.h5')
img = image.load_img(path, target_size=(150, 150))
x = image.img_to_array(img) / 255.0 x = np.expand_dims(x, axis=0)
# np.vstack:按垂直方向(行顺序)堆叠数组构成一个新的数组
images = np.vstack([x]) classes = model.predict(images, batch_size=1) if classes[0] > 0.5:
print("图片识别为狗")
else:
print("图片识别为猫")

结果说明还可以!!!!!!!

最新文章

  1. ie6兼容问题汇总
  2. mvc Razor 视图中找不到 ViewBag的定义
  3. excel日期格式转换为文本格式
  4. 中国移动测试大会 PPT 和视频
  5. CentOS版本选择说明
  6. 移动端调试工具-Weinre
  7. webservice basics
  8. 无线路由器wds桥接技术+丢包率
  9. Delphi 两个应用程序(进程)之间的通信
  10. lintcode : 二叉树的层次遍历
  11. windows 下FFMPEG的编译方法 附2012-9-19发布的FFMPEG编译好的SDK下载
  12. 搭建Windows Azure开发环境-环境搭建
  13. Spring读书笔记-----Spring的Bean之设置Bean值
  14. 关于多本小说站的SEO—从”易读中文网”获得的心得体会
  15. [读书笔记]黑客与画家[Hackers.and.Painters]
  16. Arrays.asList的那点事
  17. ASP.NET MVC5 怒跨 Linux 平台
  18. Python中的文件路径的分隔符
  19. perl学习笔记---标量
  20. Git以一个远程分支为基础新建一个远程分支(转载)

热门文章

  1. windows版本rabbitmq安装及日志level设置
  2. 零基础学Java(13)方法参数
  3. 筛 sigma_k
  4. 我和Apache DolphinScheduler的这一年
  5. 图片系列(6)不同版本上 Bitmap 内存分配与回收原理对比
  6. 如何定义 Java 的回调函数,与 JavaScript 回调函数的区别
  7. 从0搭建Vue3组件库:button组件
  8. PyTorch中的CUDA操作
  9. 利用Hugging Face中的模型进行句子相似性实践
  10. Go 语言入门 3-动态数组(slice)的特性及实现原理