1、同文章中建议的使用ubuntu-python隔离环境,真的很好用

参照:http://blog.topspeedsnail.com/archives/5618
启动虚拟环境:
source env/bin/activate
退出虚拟环境:
deactivate
注意:下面的操作全部都要在隔离环境中完成
2、搭建虚拟环境
pip install -r(requests)应该是安装request中所有的包
pip install Cython == 0.26
sudo apt-get install python3-dev
editdistance == 0.3.13、

3、

参照,编译百度warpctc
http://blog.csdn.net/amds123/article/details/73433926
git clone
https://github.com/baidu-research/warp-ctc.git

cd warp-ctc
mkdir build
cd build
cmake ..
make
sudo make install

执行文章中snt-orc
mxnet/metrics/ctc` and run `python setup.py build_ext --inplace`

4、
编译MXNET:
git clonr --recursive mxnet
cd mxnet
git tag
git checkout v0.9.3
按照论文中的方法编译失败,只能下载新版本编译
新版本编译步骤参考:https://www.bbsmax.com/A/A7zgqGk54n/
安装依赖:
$ sudo apt-get install -y build-essential git

$ sudo apt-get install -y libopenblas-dev

$ sudo apt-get install -y libopencv-dev

git clone --recursive https://github.com/dmlc/mxnet.git
cd mxnet
cp make/*.ck ./(编译选项文件)
vim *(按需修改编译文件)文章要求加入warpctc
https://mxnet.incubator.apache.org/tutorials/speech_recognition/baidu_warp_ctc.html
make -j4

5、
编译python接口参照
http://blog.csdn.net/zziahgf/article/details/72729883
编译 MXNet的Python API:
安装所需包
sudo apt-get install -y python-dev python-setuptools python-numpy
cd python
sudo python setup.py install

6、
下载stn-orc网络:https://github.com/Bartzi/stn-ocr
这个网络感觉跟FCN使用差不多,应该不需要什么格外操作

7、
下载model
https://bartzi.de/research/stn-ocr
中的文本识别:会有model文件夹,测试数据集
model文件夹中有两个文件
*.params是模型文件,*.json应该是网络描述文件
测试数据集中有图片文件夹,gt文件,还有一个不知道是什么用
还需要一个文件stn-orc网络中data文件对应‘文本’中应有个char_map文件,后面需要
模型预测代码就是stn-orc文件下的eva的py代码,看名字就知道,不过由于之前下载的是新版本,跟文中不同,所以使用这里的py文件没有运行成功,仿照文件自己写了一个简单的测试文件:

import matplotlib.pyplot as plt

import argparse
import csv
import json
import os
from collections import namedtuple from PIL import Image import editdistance
import mxnet as mx
import numpy as np from callbacks.save_bboxes import BBOXPlotter
from metrics.ctc_metrics import strip_prediction
from networks.text_rec import SVHNMultiLineCTCNetwork
from operations.disable_shearing import *
from utils.datatypes import Size Batch = namedtuple('Batch', ['data']) #后缀都不能加的,程序自己添加,似乎同时加载两个文件
sym,arg_params,aux_params = mx.model.load_checkpoint('./testxt/model/model',2)
#这里面应该是训练的参数
#print(arg_params)
net, loc, transformed_output, size_params = SVHNMultiLineCTCNetwork.get_network((1,1,64,200),Size(50,50),46,2,23)
output = mx.sym.Group([loc, transformed_output, net]) #靠 在这里预定义的话,TMD,soft 层怎么办?
mod = mx.mod.Module(output,context=mx.cpu(),data_names=['data',
'softmax_label',
'l0_forward_init_h_state',
'l0_forward_init_c_state_cell',
'l1_forward_init_h_state',
'l1_forward_init_c_state_cell' ],label_names=[])
mod.bind(for_training=False,grad_req='null',data_shapes=[
('data',(1,1,64,200)),
('softmax_label', (1,23)),
('l0_forward_init_h_state', (1, 1, 256)),
('l0_forward_init_c_state_cell', (1, 1, 256)),
('l1_forward_init_h_state', (1, 1, 256)),
('l1_forward_init_c_state_cell', (1, 1, 256))
])
arg_params['l0_forward_init_h_state'] = mx.nd.zeros((1, 1, 256))
arg_params['l0_forward_init_c_state_cell'] = mx.nd.zeros((1, 1, 256))
arg_params['l1_forward_init_h_state'] = mx.nd.zeros((1, 1, 256))
arg_params['l1_forward_init_c_state_cell'] = mx.nd.zeros((1, 1, 256))
mod.set_params(arg_params, aux_params) #看看怎么加载label
#一个映射文件,类似caffe中的label,在下面循环中用到
with open('/home/lbk/python-env/stn-ocr/mxnet/testxt/ctc_char_map.json') as char_map_file:
char_map = json.load(char_map_file)
reverse_char_map = {v: k for k, v in char_map.items()}
print(len(reverse_char_map)) with open('/home/lbk/python-env/stn-ocr/mxnet/testxt/icdar2013_eval/one_gt.txt') as eval_gt:
reader = csv.reader(eval_gt,delimiter=';')
for idx,line in enumerate(reader):
file_name = line[0]
label = line[1].strip()
gt_word = label.lower()
print(gt_word)
#这一步又是干什么的
#dict.get(key,default)查找,不存在返回default
label = [reverse_char_map.get(ord(char.lower()),reverse_char_map[9250]) for char in gt_word]
label+=[reverse_char_map[9250]]*(23-len(label))
#print(label)
the_image = Image.open(file_name)
the_image = the_image.convert('L')
the_image = the_image.resize((200,64), Image.ANTIALIAS)
image = np.asarray(the_image, dtype=np.float32)[np.newaxis, np.newaxis, ...]
image/=255
temp = mx.nd.zeros((1,1,256))
label = mx.nd.array([label])
image = mx.nd.array(image)
print(type(temp),type(label))
input_batch = Batch(data=[image,label,temp,temp,temp,temp]) mod.forward(input_batch,is_train=False)
print(len(mod.get_outputs()))
print('0000',mod.get_outputs()[2])
predictions = mod.get_outputs()[2].asnumpy()
predicted_classes = np.argmax(predictions,axis=1)
print(len(predicted_classes))
print(predicted_classes) predicted_classes = strip_prediction(predicted_classes, int(reverse_char_map[9250]))
predicted_word = ''.join([chr(char_map[str(p)]) for p in predicted_classes]).replace(' ', '')
print(predicted_word) distance = editdistance.eval(gt_word, predicted_word)
print("{} - {}\t\t{}: {}".format(idx, gt_word, predicted_word, distance)) results = [prediction == label for prediction, label in zip(predicted_word, gt_word)]
print(results)

  

补充:
学习MXNET:
http://www.infoq.com/cn/articles/an-introduction-to-the-mxnet-api-part04
http://blog.csdn.net/yiweibian/article/details/72678020
http://ysfalo.github.io/2016/04/01/mxnet%E4%B9%8Bfine-tune/
http://shuokay.com/2016/01/01/mxnet-memo/

最新文章

  1. oc 单例
  2. swift学习笔记之-访问控制
  3. yii2 实现多表联查
  4. unity3d中dllimport方法的使用,以接入腾讯平台为例!!!
  5. zju 1037 Gridland(找规律,水题)
  6. Mysql 的变量
  7. 如何让input之间无空隙
  8. poj 1018 Communication System_贪心
  9. 本地php 连接 MySQL
  10. mysql、mysqli、pdo使用
  11. ‘true’==true返回false详解
  12. TestNG监听器实现用例运行失败自动截图、重运行功能
  13. Spring中配置使用slf4j + log4j
  14. HDU 6300
  15. 测试 Java 类的非公有成员变量和方法
  16. 打包python为可执行文件时报错R6034解决方案
  17. ats Linux Bridge内联
  18. Linux内存使用方法详细解析
  19. [CF1111C]Creative Snap
  20. AtCoder Grand Contest

热门文章

  1. js 值类型和引用类型
  2. Http头 Range、Content-Range
  3. php报错配置问题
  4. Ubuntu免安装配置MySQL
  5. linux调整缓存写入磁盘的时间,减少磁盘爆掉的可能性
  6. 转: maven打可执行的jar包以及classpath设置
  7. [java面试]关于多态性的理解
  8. MongoDB入门学习(二):MongoDB的基本概念和数据类型
  9. Linux在中国正在走向没落
  10. HTTP状态码中301与302的区别