CTC是2006年的论文Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks中提到的,论文地址: http://www.cs.toronto.edu/~graves/icml_2006.pdf

论文中CTC的定义是这样的:把对未分割的序列数据label的任务叫做Temporal Classification,把使用RNNs对未分割的序列数据label叫做Connectionist Temporal Classification(CTC) 。与之相对的是,把对数据序列的每一个time-step或者frame独立label 叫做framewise classification

tensorflow中的相关实现在 /tensorflow/python/ops/ctc_ops.py

1. ctc_loss, 计算ctc loss

def ctc_loss(labels, inputs, sequence_length,
preprocess_collapse_repeated=False,
ctc_merge_repeated=True, time_major=True):

这个类执行softmax操作,所以输入应该是LSTM输出的线性映射

inputs, 最内部维度大小是num_classes,代表“num_labels +1” 个类别,其中num_labels是真实的balebs的数目,最大值“num_labels-1”是为blank label保留的

例如,如果一个单词包含3个labels ‘[a, b, c]’,则num_classes =4, 且labels的索引号是 ‘{a:0, b:1, c:2, blank:3}’

至于参数 preprocess_collapse_repeated 和 ctc_merge_repeated:

如果 preprocess_collapse_repeated = True ,在计算ctc之前,重复的labels会被合并为一个labels。这种预处理对下面这种情况是有用的:如果训练数据是强制对齐得到的,会包含不必要的重复。

如果 ctc_merge_repeated = False,那么伴随ctc计算的深入,重复的非blank将不会被合并,会被解释为独立的labels。这是ctc的简化的非标准的版本

具体见下表

  • preprocess_collapse_repeated = False,ctc_merge_repeated = True:经典CTC,输出的真实的重复的中间带有blanks类别,也可以通过解码器解码,输出不带有blanks的重复类别
  • preprocess_collapse_repeated = True,ctc_merge_repeated = False:因为在training之前,input 的labels已经合并重复项了,所以不会输出重复的类
  • preprocess_collapse_repeated = False,ctc_merge_repeated = False:输出重复的中间带有blank的类别,但是通常不需要解码器合并重复项
  • preprocess_collapse_repeated = True,ctc_merge_repeated = True: 未测试,非常可能不会学会输出重复类

参数:

labels: int32 SparseTensor, 标准的输出,稀疏矩阵

inputs: 3-D float tensor . 计算得到的logits。 如果time_major = False, shape:batch_size x max_time x num_classes. 如果 time_major = True, shape:max_time x batch_size x num_classes

sequence_length: 1-D int32 向量, batch_size

输出:

1-D float tensor,size:[batch], 概率的负对数

2. ctc_beam_search_decoder: 对输入的logits执行beam search 解码

def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100,
top_paths=1, merge_repeated=True):

如果 merge_repeated = True, 在输出的beam中合并重复类。这意味着如果一个beam中的连续项( consecutive entries) 相同,只有第一个提交。即,如果top path 是‘A B B B ’,返回值是‘A B’(当merge_repeated = True),‘A B B B ’ (当merge_repeated = False)

参数:

inputs: 3-D float tensor , shape:max_time x batch_size x num_classes

sequence_length: 1-D int32 向量, batch_size

beam_width: int scalar>=0

top_paths: int scalar>=0, <= beam_width, 输出解码后的数目

输出:

元组:(decoded, log_prob)

其中:

decoded : a list of length top_paths, 每一个是一个稀疏矩阵

log_prob : matrix , shape (batch_size x top_paths)

最新文章

  1. 生产环境中,数据库升级维护的最佳解决方案flyway
  2. .NET轻量级MVC框架:Nancy入门教程(一)——初识Nancy
  3. 使用markdown编辑evernote(印象笔记)的常用方法汇总
  4. 【STL】-Map/Multimap的用法
  5. MySQL语法
  6. STM32 使用 printf 发送数据配置方法 -- 串口 UART, JTAG SWO, JLINK RTT
  7. iOS 进阶 第十二天(0413)
  8. Java的jLinqer包介绍
  9. Git命令详解
  10. Perfect Squares——Leetcode
  11. Fileupload控件导致500错误
  12. eclipse 和myEclipse 项目导入
  13. 关于WM_ERASEBKGND和WM_PAINT的深刻理解
  14. 8天入门docker系列 —— 第五天 使用aspnetcore小案例熟悉容器互联和docker-compose一键部署
  15. BELLMEN-FORD普通
  16. 微信小程序 WXS实现json数据需要做过滤转义(filter)
  17. Hbase记录-hbase部署
  18. 基于MATLAB System Generator 搭建Display Enhancement模型
  19. [C++ Primer Plus] 第2章、开始学习c++
  20. LeetCode – All Nodes Distance K in Binary Tree

热门文章

  1. JavaEE--EL表达式
  2. 在Windows与Ubuntu上使用tensorboard的不同点
  3. Linux--磁盘管理--04
  4. laravel-admin后台框架基本使用
  5. TCP/IP的网络客户端和服务器端程序
  6. python snippets
  7. laravel 中数据库查询结果自动转数组
  8. Python面向对象的三大特性之继承和组合
  9. 使用python开发WebService
  10. springMVC项目访问URL链接时遇到某一段然后忽略后面的部分