Semi-Supervised Semantic Segmentation with High- and Low-level Consistency

TPAMI 2019

论文原文

code

创新点:

利用两个分支结构分别处理low-level和high-level的特征,进行半监督语义分割

网络结构



上分支:Semi-Supervised Semantic Segmentation GAN (s4GAN)

下分支:Multi-Label Mean Teacher (MLMT)

s4GAN

训练segmentation network \(S\)

segmentation network \(S\)的损失函数由以下三部分组成:

  1. Cross-entropy loss

    输入原图到segmentation network \(S\)中,对于labeled images,输出的分割结果\(S(x^l)\)和标签\(y^l\)对比,计算交叉熵损失\(L_{ce}\)

  2. Feature matching loss

    为了使得分割结果\(S(x^l)\)和标签\(y^l\)的特征分布尽可能一致,本文计算分割结果\(S(x^l)\)和标签\(y^l\)的特征分布差异mean discrepancy,并设计Feature matching loss



    上式中\(D_k\)表示discriminator的第\(k\)层

    注:此Feature matching loss适用于有标签和无标签的数据
  3. Self-training loss

    本文认为,在训练过程中generator和discriminator需要达到某种平衡,如果discriminator过于strong,则无法给generator任何有用的学习信号。因此,对于unlabeled image,本文每次将generator产生的,可以成功欺骗discriminator的分割图当作真实标签,用于监督学习。由此可以促使segmentation network(即generator)变强,且一定程度上阻碍discriminator的进步,不希望discriminator过于强大,破坏平衡。

    具体而言,discriminator在s4GAN中用于在image-level判断一张分割图是真实标签(real label),还是segmentation network的输出(fake label),根据为真实标签的可能性输出一个0~1之间的概率值(若为真实标签,则输出1)

    文章设置闸值,对于输出大于闸值的分割图,作为高质量的预测图,当作真实标签,用于监督学习,并计算交叉熵损失

s4GAN总损失:

训练discriminator

discriminator的输入包含原图image和对应标签,训练discriminator,希望discriminator能给真实标签打高分,给fake label打低分。具体损失函数和传统的GAN相同。



(channel wise)

MLMT

该分支包含两个网络,分别为学生网络和老师网络,训练时,一张image经过微小的,不同的扰动之后分别输入学生网络和老师网络,学生网络和老师网络使用online ensemble的weight(老师网络是学生网络学习的目标,老师网络的权重在学生网络的基础上根据指数平均移动线移动,详见论文)。本文希望学生网络的输出和老师网络的输出尽可能一致,则对于所有image,使用均方误差来衡量两个网络输出的差异,对于labeled image,同时使用类交叉熵函数计算损失

Network Fusion

简单的通过deactivate segmentation networks的输出中没有出现在input image中的图片来融合两个网络的结果。

对于一张image分割图的一个类别c的mask,尺寸为\(HxWx1\),(对于每一个像素?)如果学生网络的输出(soft label)小于设定的某个闸值,则令segmentation network的输出为0,否则segmentation network的输出不变。

实验

数据集:

PASCAL VOC 2012 segmentation benchmark, the PASCAL-Context dataset, and the Cityscapes dataset.

网络具体结构:

segmentation network:

deeplab v2

discriminator:

4层卷积层,通道数分别为\({64,128,256,512}\),卷积核大小为4x4,每个卷积层后面都有一个negative slope of 0.2的Leaky-ReLU层和一个dropout概率为0.5的dropout层(该高概率的dropout layer对于GAN的稳定训练非常关键)。最后一个卷积层后面是一个全局平均池化层和全连接层,全局平均池化的输出用于Feature matching loss的计算

学生网络和老师网络:

ResNet101(在imagenet上预训练)

实验结果:

疑问:

  1. 网络融合的目的?
  2. self-train loss的设定(为阻止discriminator变强)?

最新文章

  1. nginx+lua
  2. json转js对象
  3. 文本 To 音频
  4. 【leetcode】Wildcard Matching
  5. ORACLE object_id和data_object_id
  6. 利用PPT的WebBroswer控件助力系统汇报演示
  7. Java输入、输入、IO流 类层次关系梳理
  8. latin1
  9. HDOJ(HDU) 1860 统计字符
  10. [原创]安卓使用Termux做渗透测试(演示sqlmap安装,并附上一个神器)
  11. zk mysql 主从自动切换
  12. SOM网络聚类完整示例(利用python和java)
  13. AEAI HR开源人力资源管理v1.6.0发版公告
  14. WPF InkCanvas 书写毛笔效果
  15. A bean with that name has already been defined in DataSourceConfiguration$Hikari.class
  16. Vue 局部组件和全局组件的使用
  17. 最新的 Vue 相关开源项目库汇总
  18. 转:Exploiting Windows 10 in a Local Network with WPAD/PAC and JScript
  19. string 类简介和例程
  20. webpack分离第三方库(CommonsChunkPlugin并不是分离第三方库的好办法DllPlugin科学利用浏览器缓存)

热门文章

  1. Codeforces 888D: Almost Identity Permutations(错排公式,组合数)
  2. 【C\C++笔记】指针输出字符串
  3. 大数据分布式存储之Cassandra
  4. gojs 实用高级用法
  5. 美和易思 · 「云农职互联网技术学院」HTML+CSS 做西普尼金表官网
  6. mysql总结笔记 -- 索引篇
  7. Parallel.ForEach 之 MaxDegreeOfParallelism
  8. sqlsugar freesql hisql 三个ORM框架性能测试对比
  9. centos7 安装locate
  10. spring security 动态 修改当前登录用户的 权限