一直在研究怎样用caffe做行人检測问题。然而參考那些经典结构比方faster-rcnn等,都是自己定义的caffe层来完毕的检測任务。

这些都要求对caffe框架有一定程度的了解。近期看到了怎样用caffe完毕回归的任务,就想把检測问题当成回归问题来解决。

我们把行人检測问题当成回归来看待,就须要限制检出目标的个数,由于我们的输出个数是固定的。所以,这里我假定每张图片最多检出的目标个数为2。即每一个目标用4个值来表示其位置信息(中心位置坐标x,y。

BBox的宽和高)。则网络的最后输出是8个值。

制作HDF5数据

这里我们使用HDF5格式的数据来完毕我们的回归任务,那么首先我们须要的是制作h5格式的数据。

这里以VOC数据集为例。以下是制作HDF5格式数据的python代码。

import h5py
import caffe
import os
import xml.etree.ElementTree as ET
import cv2
import time
import math
from os.path import join, exists
import numpy as np def convert(size, box):
dw = 1./size[0]
dh = 1./size[1]
x = (box[0] + box[1])/2.0
y = (box[2] + box[3])/2.0
w = box[1] - box[0]
h = box[3] - box[2]
x = x*dw
w = w*dw
y = y*dh
h = h*dh
return (x,y,w,h) def shuffle_in_unison_scary(a, b):
rng_state = np.random.get_state()
np.random.shuffle(a)
np.random.set_state(rng_state)
np.random.shuffle(b) def processImage(imgs):
imgs = imgs.astype(np.float32)
for i, img in enumerate(imgs):
m = img.mean()
s = img.std()
imgs[i] = (img - m) / s
return imgs TrainImgDir = 'F:/GenerateHDF5/trainImage'
TrainLabelDir = 'F:/GenerateHDF5/trainLabels'
TestImgDir = 'F:/GenerateHDF5/testImg'
TestLabelDir = 'F:/GenerateHDF5/testLabels' InImg = []
InBBox = [] for rootDir,dirs,files in os.walk(TestLabelDir): #####
for file in files:
file_name = file.split('.')[0]
full_file_name = '%s%s'%(file_name,'.jpg')
full_file_dir = '%s/%s'%(TestImgDir,full_file_name) #####
Img = cv2.imread(full_file_dir,cv2.CV_LOAD_IMAGE_GRAYSCALE)
xml_file = open("%s/%s"%(rootDir,file))
tree = ET.parse(xml_file)
root = tree.getroot()
size = root.find('size')
w = int(size.find('width').text)
h = int(size.find('height').text) landmark = np.zeros(8)
count = 0
for obj in root.iter('object'):
count = count + 1
if count == 3:
break
xmlbox = obj.find('bndbox')
b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
bb = convert((w,h), b)
landmark[(count-1)*4+0]=bb[0]
landmark[(count-1)*4+1]=bb[1]
landmark[(count-1)*4+2]=bb[2]
landmark[(count-1)*4+3]=bb[3] InBBox.append(landmark.reshape(8))
Img = cv2.resize(Img,(h,w))
InImg.append(Img.reshape((1,h,w))) InImg, InBBox = np.asarray(InImg), np.asarray(InBBox)
InImg = processImage(InImg)
shuffle_in_unison_scary(InImg, InBBox) outputDir = 'hdf5/'
HDF5_file_name = 'hdf5_test.h5' #####
if not os.path.exists(outputDir):
os.makedirs(outputDir) output = join(outputDir,HDF5_file_name)
with h5py.File(output, 'w') as h5:
h5['data'] = InImg.astype(np.float32)
h5['labels'] = InBBox.astype(np.float32)
h5.close()

这里注意一点,全部的BBox数据都要做归一化操作,即全部坐标要除以图片相应的宽高。据说,这样做能使最后得到的结果更好。

制作好了HDF5数据后。注意每一个H5文件大小不能超过2G(这是caffe的规定,假设一个文件超过2G。请分开制作多个)。

然后建立一个TXT文件,文件中写上全部H5文件的绝对路径。比方我这里建立的文件是list_train.txt。

然后我仅仅有一个H5文件,即hdf5_train.h5。所以我的list_train.txt文件中的内容就是/home/XXX/caffe/model/hdf5/hdf5_train.h5

配置solver文件

接下来是caffe的solver文件。这个文件没有什么差别,

test_iter: 20
test_interval: 70
base_lr: 0.0000000005
display: 9
max_iter: 210000
lr_policy: "step"
gamma: 0.1
momentum: 0.9
weight_decay: 0.0001
stepsize: 700
snapshot: 500
snapshot_prefix: "snapshot"
solver_mode: GPU
net: "train_val.prototxt"
solver_type: SGD

配置train_val.prototxt文件

接下来是网络的train_val.prototxt文件。这是caffe的网络结构文件,我们这里以LeNet网络为例。我这里是这种:

name: "LeNet"
layer {
name: "data"
type: "HDF5Data"
top: "data"
top: "labels"
include {
phase: TRAIN
}
hdf5_data_param {
source: "list_train.txt"
batch_size: 50
}
}
layer {
name: "data"
type: "HDF5Data"
top: "data"
top: "labels"
include {
phase: TEST
}
hdf5_data_param {
source: "list_test.txt"
batch_size: 50
}
}
layer {
name: "conv1"
type: "Convolution"
bottom: "scaled"
top: "conv1"
param {
lr_mult: 1.0
}
param {
lr_mult: 2.0
}
convolution_param {
num_output: 20
kernel_size: 5
stride: 1
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
}
layer {
name: "pool1"
type: "Pooling"
bottom: "conv1"
top: "pool1"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "conv2"
type: "Convolution"
bottom: "pool1"
top: "conv2"
param {
lr_mult: 1.0
}
param {
lr_mult: 2.0
}
convolution_param {
num_output: 50
kernel_size: 5
stride: 1
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
}
layer {
name: "pool2"
type: "Pooling"
bottom: "conv2"
top: "pool2"
pooling_param {
pool: MAX
kernel_size: 2
stride: 2
}
}
layer {
name: "ip1"
type: "InnerProduct"
bottom: "pool2"
top: "ip1"
param {
lr_mult: 1.0
}
param {
lr_mult: 2.0
}
inner_product_param {
num_output: 500
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
}
layer {
name: "relu1"
type: "ReLU"
bottom: "ip1"
top: "ip1"
}
layer {
name: "ip2"
type: "InnerProduct"
bottom: "ip1"
top: "ip2"
param {
lr_mult: 1.0
}
param {
lr_mult: 2.0
}
inner_product_param {
num_output: 8
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
}
}
}
layer {
name: "error"
type: "EuclideanLoss"
bottom: "ip2"
bottom: "labels"
top: "error"
include {
phase: TEST
}
}
layer {
name: "loss"
type: "EuclideanLoss"
bottom: "ip2"
bottom: "labels"
top: "loss"
include {
phase: TRAIN
}
}

这里注意的是。最后的一层全连接层,输出的num_output应该是你label的维度,我这里是8。然后最后的loss计算,我使用的是欧氏距离的loss,也能够试着用其它类型的loss。

開始训练

依照以上步骤配置好了,最后就是训练了。

在控制台中输入下面指令来训练我们的数据:

./cafferoot/caffe/tools/caffe train --solver=solver.prototxt

可能是我数据源的问题,我的loss一開始很大。然后一直降不下来。也有可能是LeNet本身网络性能就不好。

关于网络的性能还须要另外再想办法提升。

最新文章

  1. hibernate关联映射学习
  2. 昨天一日和彭讨论post请求数据的问题
  3. Kl 证明 凸函数
  4. MongoDB的C#封装类
  5. mysql-主从复制(二)
  6. Objective-C 【构造方法(重写、场景、自定义)、super】
  7. DTcms会员中心添加新页面-会员投稿,获得所有文章并分页
  8. bugku login2 writeup 不使用vps的方法
  9. PowerDesigner 提示 Existence of index、key、reference错误
  10. 为Sublime Text 设置全局启动快捷键
  11. 因缺失log4j.properties 配置文件导致flume无法正常启动。
  12. 一个单js文件也可以运行vue
  13. [ci]jenkins-slave-ssh docker容器化-自动注入key
  14. webpack学习笔记-2-file-loader 和 url-loader
  15. 关于安装php时 --with-mysql命令参数问题
  16. Django(完整的登录示例、render字符串替换和redirect跳转)
  17. nginx 错误日志分析
  18. Python爬虫进阶四之PySpider的用法
  19. SQLSERVER带端口号的链接方式
  20. 使用 jQuery 避免鼠标双击

热门文章

  1. rxjava 视频
  2. 数据库SQL归纳(三)
  3. 【二分】Codeforces Round #435 (Div. 2) D. Mahmoud and Ehab and the binary string
  4. SSL 认证之后,request.getScheme()获取不到https的问题记录
  5. Educational Codeforces Round 8 A. Tennis Tournament 暴力
  6. 安装gcc-linaro-6.1.1-2016.08-x86_64_arm-linux-gnueabi交叉编译器
  7. js阻止浏览器、元素的默认事件与js阻止事件冒泡、阻止事件流
  8. 【个人专用&入门级】LAMP一键安装包
  9. List 中的最大最小值
  10. [Java基础] java的守护线程与非守护线程