这篇博客主要介绍如何使用Keras框架微调Inception V3模型,对卫星图片进行分类,并对训练好的模型进行测试。

1. 流程概述

  微调Inception V3对卫星图片进行分类;整个流程可以大致分成四个步骤,如下:

  • (1)Satellite数据集准备;
  • (2)搭建Inception V3网络;
  • (3)进行训练;
  • (4)测试;

2. 准备数据集

2.1 Satellite数据集介绍

  用于实验训练与测试的数据集来自于《21个项目玩转深度学习:基于Tensorflow的实践详解》第三章中提供的实验卫星图片数据集;

  Satellite数据集目录结构如下:

# 其中共6类卫星图片,训练集共4800张,每类800张;验证集共1200张,每类200张;
# 注:下文代码中对应的路径为 satellite/data/train 和 satellite/data/validation
Satellite/
    train/
        glacier/
        rock/
        urban/
        water/
        wetland/
        wood/
    validation/
        glacier/
        rock/
        urban/
        water/
        wetland/
        wood/

3. Inception V3网络

  待补充;

4. 训练

4.1 基于Keras微调Inception V3网络

# Fixed module path: the package is `keras.applications`, the module `inception_v3`.
from keras.applications.inception_v3 import InceptionV3, preprocess_input
from keras.layers import GlobalAveragePooling2D, Dense
from keras.models import Model

# Number of satellite classes (glacier, rock, urban, water, wetland, wood).
NUM_CLASSES = 6

# Base InceptionV3 with ImageNet weights, without its fully-connected top.
base_model = InceptionV3(weights='imagenet', include_top=False)

# New classification head on top of the convolutional base.
x = base_model.output
x = GlobalAveragePooling2D()(x)  # collapse spatial dims to one vector per image
x = Dense(1024, activation='relu')(x)
predictions = Dense(NUM_CLASSES, activation='softmax')(x)

# Assemble the full model. The training code below refers to `model`,
# which the original snippet never created.
model = Model(inputs=base_model.input, outputs=predictions)

4.2 Keras实时生成批量增强数据

from keras.preprocessing.image import ImageDataGenerator

# Training data: real-time random augmentation, then InceptionV3 preprocessing.
train_datagen = ImageDataGenerator(
    # Scales every image to [-1, 1]; Keras runs this after augmentation.
    preprocessing_function=preprocess_input,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Validation data: the same preprocessing as training but NO random augmentation.
# Augmenting validation images made the reported accuracy noisy and not
# representative of the real images the model will see.
val_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
)

# Point the generators at the dataset directories. Class labels are taken
# from the sub-directory names.
train_generator = train_datagen.flow_from_directory(
    directory='satellite/data/train',
    target_size=(299, 299),  # input size expected by InceptionV3
    batch_size=64,
)
val_generator = val_datagen.flow_from_directory(
    directory='satellite/data/validation',
    target_size=(299, 299),
    batch_size=64,
)

4.3 配置transfer learning & finetune

from keras.optimizers import Adagrad


# transfer learning
def setup_to_transfer_learning(model, base_model):
    """Freeze the whole convolutional base and compile `model` for head-only training.

    Argument:
        model: keras Model
            full model (base + new classification head)
        base_model: keras Model
            pretrained InceptionV3 base whose layers are frozen
    """
    for layer in base_model.layers:
        layer.trainable = False
    # Compile after changing `trainable` so the new flags are used in training.
    model.compile(optimizer='adam', loss='categorical_crossentropy',
                  metrics=['accuracy'])


# finetune
def setup_to_fine_tune(model, base_model, gap_layer=17):
    """Unfreeze the base from layer `gap_layer + 1` on and recompile `model`.

    Argument:
        model: keras Model
            full model (base + new classification head)
        base_model: keras Model
            pretrained InceptionV3 base
        gap_layer: int
            index of the last layer that stays frozen. Layers
            base_model.layers[0 .. gap_layer] are frozen and all later layers
            are trained. The default 17 keeps the original behaviour; the
            original note says this is the layer `max_pooling_2d_2`. Layer
            names and indices can differ between Keras versions, so verify
            against base_model.layers.
    """
    for layer in base_model.layers[:gap_layer + 1]:
        layer.trainable = False
    for layer in base_model.layers[gap_layer + 1:]:
        layer.trainable = True
    # A small learning rate keeps fine-tuning from destroying the pretrained weights.
    # NOTE(review): `lr` matches the Keras 2.x API used here; newer Keras
    # versions expect `learning_rate` instead.
    model.compile(optimizer=Adagrad(lr=0.0001), loss='categorical_crossentropy',
                  metrics=['accuracy'])

4.4 执行训练

import math
import os


def _steps(generator):
    """Number of batches needed for exactly one pass over `generator`'s images."""
    return int(math.ceil(generator.samples / float(generator.batch_size)))


# Directory the test / GUI scripts load models from.
model_dir = 'satellite/train_dir/models'
if not os.path.isdir(model_dir):
    os.makedirs(model_dir)

# Step 1: transfer learning -- train only the new head on a frozen base.
setup_to_transfer_learning(model, base_model)
# 4800 training images / 64 -> 75 steps; 1200 validation images / 64 -> 19 steps.
# The old validation_steps=64 re-used the validation set about 3.4 times per epoch.
# class_weight is omitted: the classes are balanced (800 images each), and
# Keras expects a dict here, not the string 'auto'.
history_tl = model.fit_generator(generator=train_generator,
                                 steps_per_epoch=_steps(train_generator),
                                 epochs=10,
                                 validation_data=val_generator,
                                 validation_steps=_steps(val_generator))
model.save(os.path.join(model_dir, 'satellite_iv3_tl.h5'))

# Step 2: finetune -- unfreeze the upper part of the base and keep training.
setup_to_fine_tune(model, base_model)
history_ft = model.fit_generator(generator=train_generator,
                                 steps_per_epoch=_steps(train_generator),
                                 epochs=10,
                                 validation_data=val_generator,
                                 validation_steps=_steps(val_generator))
model.save(os.path.join(model_dir, 'satellite_iv3_ft.h5'))

5. 测试

5.1 对单张图片进行测试

# -*- coding: utf-8 -*-
"""
Classify satellite images with a trained Keras h5 model file.
"""
# ================================================================
import functools

import tensorflow as tf
import numpy as np
from skimage import io
from keras.models import load_model
from keras.applications.inception_v3 import preprocess_input
from keras.preprocessing.image import load_img, img_to_array

# Must match the target_size used by flow_from_directory during training.
INPUT_SIZE = (299, 299)
# Default model file; change it when a new model is trained.
DEFAULT_MODEL_FILE_PATH = "satellite/train_dir/models/satellite_iv3_ft.h5"
# Class names in label-index order. flow_from_directory assigns indices
# from the alphanumerically sorted sub-directory names.
LABELS = ["glacier", "rock", "urban", "water", "wetland", "wood"]


def normalize(array):
    """Mean-center an array and divide it by its value range (max - min).

    NOTE: this is NOT the preprocessing used during training
    (`preprocess_input`). It is kept for backward compatibility only and is
    no longer used by the classifier.

    Argument:
        array: array
            numeric input array of any shape
    Return:
        array_norm: array
            float64 array with the same shape; all zeros if the input is
            constant (the old code divided by zero and returned NaNs)
    """
    values = np.asarray(array, dtype=np.float64)
    value_range = values.max() - values.min()
    if value_range == 0:
        return np.zeros_like(values)
    return (values - values.mean()) / value_range


def img_preprocess(image_path):
    """Load an image and preprocess it exactly like the training pipeline did.

    Training read images through flow_from_directory (converted to RGB and
    resized to 299x299) and then applied `preprocess_input`. Anything else
    here, such as no resize or a different normalization, feeds the network
    data it was never trained on.

    Argument:
        image_path: str
            path of the input image
    Return:
        image_data: array
            float array of shape (1, 299, 299, 3), ready for model.predict
    """
    # load_img converts to RGB, so grayscale/RGBA files no longer break the reshape.
    img = load_img(image_path, target_size=INPUT_SIZE)
    batch = np.expand_dims(img_to_array(img), axis=0)
    return preprocess_input(batch)


def index_to_label(index):
    """Convert a label index into a human-readable class name.

    Argument:
        index: int
            label index (e.g. the argmax of the model output)
    Return:
        human_label: str
            human-readable label
    """
    return LABELS[int(index)]


@functools.lru_cache(maxsize=4)
def _load_model_cached(model_file_path):
    """Load an h5 model once per path; later calls reuse the loaded model."""
    return load_model(model_file_path)


def classifier_satellite_byh5(image_path, model_file_path):
    """Classify a single image with a trained model.

    Argument:
        image_path: str
            path of the input image
        model_file_path: str
            trained h5 model file
    Return:
        human_label: str
            human-readable label of the image
    """
    image_data = img_preprocess(image_path)
    # Loading the model is expensive, so it is cached instead of reloaded per image.
    model = _load_model_cached(model_file_path)
    predictions = model.predict(image_data)
    return index_to_label(np.argmax(predictions[0]))


def classifier_satellite_byh5_hci(image_path):
    """Classify an image passed in from the interactive GUI.

    Argument:
        image_path: str
            path of the input image
    Return:
        human_label: str
            human-readable label of the image
    """
    return classifier_satellite_byh5(image_path, DEFAULT_MODEL_FILE_PATH)


# Entry-point name the interactive GUI script calls.
classifier_satellite_img = classifier_satellite_byh5_hci


# Test a single image
if __name__ == "__main__":
    image_path = "satellite/data/train/glacier/40965_91335_18.jpg"
    model_file_path = "satellite/train_dir/models/satellite_iv3_ft.h5"
    human_label = classifier_satellite_byh5(image_path, model_file_path)
    print(human_label)

6. 可视化分类界面

6.1 交互界面设计

# encoding: utf-8
"""
Interactive GUI: classify satellite images with the trained model.
"""
from tkinter import *
import tkinter
import tkinter.filedialog
import os
import tkinter.messagebox
from PIL import Image, ImageTk
import test_satellite_bypb

# Window properties
root = tkinter.Tk()
root.title('Satellite图像分类')
root.geometry('800x600')

# Accepted image extensions (lower-case, without the dot); compared case-insensitively.
formatImg = ['jpg', 'jpeg']


def resize(w, h, w_box, h_box, pil_image):
    """Scale pil_image to fit inside a w_box x h_box box while keeping its aspect ratio.

    Argument:
        w, h: int
            original image width / height in pixels
        w_box, h_box: int
            size of the target box in pixels
        pil_image: PIL.Image.Image
            image to scale
    Return:
        PIL.Image.Image: the resized image
    """
    # This assignment used to sit inside the comment line, so `f1` was undefined.
    f1 = 1.0 * w_box / w  # 1.0 forces float division in Python 2
    f2 = 1.0 * h_box / h
    factor = min(f1, f2)
    # Keep at least 1 px so extremely thin images don't request a 0-sized resize.
    width = max(1, int(w * factor))
    height = max(1, int(h * factor))
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    return pil_image.resize((width, height), Image.LANCZOS)


def _set_result_text(text):
    """Replace the content of the read-only result box with `text`."""
    # A DISABLED Text widget ignores delete/insert, so unlock it only while writing.
    text_showClass.config(state=NORMAL)
    text_showClass.delete('1.0', END)
    text_showClass.insert(END, text)
    text_showClass.config(state=DISABLED)


def showImg():
    """Show the image whose path is in the entry box, scaled into a 400x400 box."""
    img1 = entry_imgPath.get()  # selected image path
    # Desired display size
    w_box = 400
    h_box = 400
    # Close the file once the resized copy has been made.
    with Image.open(img1) as pil_image:
        w, h = pil_image.size
        pil_image_resized = resize(w, h, w_box, h_box, pil_image)
    # Convert the PIL image into a Tkinter PhotoImage.
    tk_image = ImageTk.PhotoImage(pil_image_resized)
    # Reuse one label instead of stacking a new Label on every selection.
    label_showImg.config(image=tk_image, width=w_box, height=h_box)
    # Keep a Python reference so the PhotoImage is not garbage-collected.
    label_showImg.image = tk_image


def choose_file():
    """Let the user pick an image file, then show its path and a preview."""
    _set_result_text('')  # clear the previous result before a new selection
    selectFileName = tkinter.filedialog.askopenfilename(title='选择文件')
    # An empty result means the dialog was cancelled.
    ext = os.path.splitext(selectFileName)[1].lstrip('.').lower() if selectFileName else ''
    if ext not in formatImg:
        tkinter.messagebox.showerror(title='出错', message='未选择图片或图片格式不正确')
        return
    e.set(selectFileName)  # shows the path in the entry box
    showImg()


def ouputOfModel():
    """Run the classifier on the selected image and display the predicted class."""
    _set_result_text('')  # clear the previous result
    img_path = entry_imgPath.get()
    # isfile (not exists) so that a directory path is rejected too.
    if not os.path.isfile(img_path):
        tkinter.messagebox.showerror(title='出错', message='未选择图片文件或图片格式不正确')
        return
    # Prefer the name this GUI has always called; fall back to the h5 backend function.
    classify = getattr(test_satellite_bypb, 'classifier_satellite_img', None)
    if classify is None:
        classify = test_satellite_bypb.classifier_satellite_byh5_hci
    try:
        human_label = classify(img_path)
    except Exception as err:
        # GUI boundary: show the failure to the user instead of only a stderr traceback.
        tkinter.messagebox.showerror(title='出错', message='识别失败: %s' % err)
        return
    _set_result_text('%s\n' % (human_label,))


##################
# Widgets
##################
e = tkinter.StringVar()  # holds the selected image path
# Label: "select file"
label_selectImg = tkinter.Label(root, text='选择图片:')
label_selectImg.grid(row=0, column=0)
# Entry: shows the selected image path
entry_imgPath = tkinter.Entry(root, width=80, textvariable=e)
entry_imgPath.grid(row=0, column=1)
# Button: choose an image file
button_selectImg = tkinter.Button(root, text="选择", command=choose_file)
button_selectImg.grid(row=0, column=2)
# Button: run recognition
button_recogImg = tkinter.Button(root, text="开始识别", command=ouputOfModel)
button_recogImg.grid(row=0, column=3)
# Text: read-only box showing the predicted class
text_showClass = tkinter.Text(root, width=20, height=1, font='18',)
text_showClass.grid(row=1, column=1)
text_showClass.config(state=DISABLED)
# Label: image preview (filled in by showImg)
label_showImg = tkinter.Label(root)
label_showImg.place(x=50, y=150)
root.mainloop()

6.2 后台核心代码:模型加载并分类

# -*- coding: utf-8 -*-
"""
Classify satellite images with a trained Keras h5 model file.
"""
# ================================================================
import functools

import tensorflow as tf
import numpy as np
from skimage import io
from keras.models import load_model
from keras.applications.inception_v3 import preprocess_input
from keras.preprocessing.image import load_img, img_to_array

# Must match the target_size used by flow_from_directory during training.
INPUT_SIZE = (299, 299)
# Default model file; change it when a new model is trained.
DEFAULT_MODEL_FILE_PATH = "satellite/train_dir/models/satellite_iv3_ft.h5"
# Class names in label-index order. flow_from_directory assigns indices
# from the alphanumerically sorted sub-directory names.
LABELS = ["glacier", "rock", "urban", "water", "wetland", "wood"]


def normalize(array):
    """Mean-center an array and divide it by its value range (max - min).

    NOTE: this is NOT the preprocessing used during training
    (`preprocess_input`). It is kept for backward compatibility only and is
    no longer used by the classifier.

    Argument:
        array: array
            numeric input array of any shape
    Return:
        array_norm: array
            float64 array with the same shape; all zeros if the input is
            constant (the old code divided by zero and returned NaNs)
    """
    values = np.asarray(array, dtype=np.float64)
    value_range = values.max() - values.min()
    if value_range == 0:
        return np.zeros_like(values)
    return (values - values.mean()) / value_range


def img_preprocess(image_path):
    """Load an image and preprocess it exactly like the training pipeline did.

    Training read images through flow_from_directory (converted to RGB and
    resized to 299x299) and then applied `preprocess_input`. Anything else
    here, such as no resize or a different normalization, feeds the network
    data it was never trained on.

    Argument:
        image_path: str
            path of the input image
    Return:
        image_data: array
            float array of shape (1, 299, 299, 3), ready for model.predict
    """
    # load_img converts to RGB, so grayscale/RGBA files no longer break the reshape.
    img = load_img(image_path, target_size=INPUT_SIZE)
    batch = np.expand_dims(img_to_array(img), axis=0)
    return preprocess_input(batch)


def index_to_label(index):
    """Convert a label index into a human-readable class name.

    Argument:
        index: int
            label index (e.g. the argmax of the model output)
    Return:
        human_label: str
            human-readable label
    """
    return LABELS[int(index)]


@functools.lru_cache(maxsize=4)
def _load_model_cached(model_file_path):
    """Load an h5 model once per path; later calls reuse the loaded model."""
    return load_model(model_file_path)


def classifier_satellite_byh5(image_path, model_file_path):
    """Classify a single image with a trained model.

    Argument:
        image_path: str
            path of the input image
        model_file_path: str
            trained h5 model file
    Return:
        human_label: str
            human-readable label of the image
    """
    image_data = img_preprocess(image_path)
    # Loading the model is expensive; the GUI calls this once per click.
    model = _load_model_cached(model_file_path)
    predictions = model.predict(image_data)
    return index_to_label(np.argmax(predictions[0]))


def classifier_satellite_byh5_hci(image_path):
    """Classify an image passed in from the interactive GUI.

    Argument:
        image_path: str
            path of the input image
    Return:
        human_label: str
            human-readable label of the image
    """
    return classifier_satellite_byh5(image_path, DEFAULT_MODEL_FILE_PATH)


# Entry-point name the interactive GUI script calls.
classifier_satellite_img = classifier_satellite_byh5_hci


# Test a single image
if __name__ == "__main__":
    image_path = "satellite/data/train/glacier/40965_91335_18.jpg"
    model_file_path = "satellite/train_dir/models/satellite_iv3_ft.h5"
    human_label = classifier_satellite_byh5(image_path, model_file_path)
    print(human_label)

6.3 交互界面效果

最新文章

  1. web api9
  2. ArrayList实现删除重复元素(元素不是对象类型的情况)
  3. Responsive设计——不同设备的分辨率设置
  4. 设置DIV块元素在浏览器页面中垂直居中
  5. JavaScript函数编程-Ramdajs
  6. 是否要学SpringMVC
  7. css定位之绝对定位
  8. Intellij Idea使用技巧、快捷键
  9. codeforces 439 E. Devu and Birthday Celebration 组合数学 容斥定理
  10. NET垃圾回收机制【Copy By Internet】
  11. CodeForces 705A Hulk (水题)
  12. 在Mac OS X中使用VIM开发STM32(3)
  13. WCF消息
  14. Berserk Rook
  15. 移动web前端的一些硬技能(二)动手前必须掌握的基本常识
  16. Spring+SpringMVC+MyBatis+easyUI整合基础篇(七)JDBC url的连接参数
  17. 关于导入excel问题
  18. C# 通过KD树进行距离最近点的查找.
  19. 分布式Session共享解决方案
  20. 络谷AT941(水提高+)题解

热门文章

  1. Redis之java增删改查
  2. 如何查看apache配置文件路径
  3. ubuntu安装opencv(自己编译)
  4. h大数据
  5. Android4.4 GPS框架分析【转】
  6. jumpserver v3.0
  7. sql中使用timestamp增量抽取数据
  8. Git_学习_02_ 分支
  9. ACM学习历程——POJ 1700 Crossing River(贪心)
  10. 洛谷 2585 [ZJOI2006]三色二叉树——树形dp