Adversarial Learning for Semi-Supervised Semantic Segmentation

论文原文

摘要

创新点:我们提出了一种使用对抗网络进行半监督语义分割的方法。

在传统的GAN网络中,discriminator大多是用来进行输入图像的真伪分类(Datasets里面sample的图片打高分,generator产生的图片打低分),而本文设计了一种全卷积的discriminator,用于区分输入标签图中各个像素(pixel-wise)的分类结果是ground truth或是segmentation network给出的。本文证明了所提出的discriminator可以通过耦合模型的对抗损失和标准交叉熵损失来提高语义分割的准确性。此外,全卷积鉴别器通过发现未标记图像预测结果中的可信区域,实现半监督学习,从而提供额外的监督信号。

网络模型

对于labeled images:

  1. image \(x_n\)输入segmentation network,得到分割结果 \(S(x_n)\)

  2. 分割结果\(S(x_n)\)和该图片对应真实标签\(Y_n\)比较,计算交叉熵损失\(L_{ce}\)

  3. 分割结果\(S(x_n)\)送入discriminator中求 \(L_{adv}\)

  4. 使用\(S(x_n)\)和真实标签\(Y_n\)训练discriminator:分别将\(S(x_n)\)和真实标签\(y\)输入discriminator,让discriminator分辨输入标签的每个像素是来自是ground truth还是segmentation network(即输入的每个像素为来自于\(S(x_n)\)还是真实标签\(Y_n\))

    discriminator的输入为\(S(x_n)\)或真实标签\(Y_n\),尺寸为\(HxWxC\),其中\(C\)为语义分割的类别数;输出尺寸为\(HxWx1\),像素值代表这个pixel来自于真实标签\(Y_n\)的概率(如果discriminator认为该像素100%是来自真实标签\(Y_n\),则该位置像素值为1)

    损失函数为

  • 注:上式中,当输入为\(S(x_n)\)时,\(y_n = 0\),当输入为\(Y_n\)时,\(y_n = 1\)

对于unlabeled image:

  1. 将image \(x_n\)输入segmentation network,得到输出\(S(x_n)\),尺寸为\(HxWxC\),每个维度上的值代表该像素取这个类别的概率值。对输入进行 one-hot encode,得到 \(\hat{Y_n}\)

    编码过程:

  2. 用\(\hat{Y_n}\)和\(S(x_n)\)进行交叉熵损失计算

  3. 将\(S(x_n)\)通过训练后的discriminator,得到\(D(S(x_n))\),尺寸为\(HxWx1\),并设置闸值,通过指示函数对输出进行二值化(对于输出中像素值大于闸值的像素,认为是可信的,以突出正确的区域)

    无标签部分的损失函数为:



    实际中\(T_{semi}\)的取值为0.1~0.3

训练总损失:





Tips:

  1. 在训练过程中首先用labeled image进行5000iteration的训练(segmentation network和discriminator交替update)
  2. 此后随机sample,每个batch里面都可能有labeled image和unlabeled image,各自按照自己的步骤训练
  3. discriminator只用每个batch里面的labeled image进行训练

具体网络结构

Segmentation network:

首先采用DeepLab-v2 中的ResNet-101作为backbone进行预训练,并去掉最后一个分类层,将最后两个卷积层的步幅从2修改为1,从而使输出特征图的分辨率有效地达到输入图像大小的1/8。为了扩大感受野,我们将扩展后的卷积分别应用于步幅为2和4的conv4和conv5层。此外,我们在最后一层使用了Atrous Spatial Pyramid Pooling (ASPP)。最后,我们应用一个上采样层和softmax输出来匹配输入图像的大小。

Discriminator:

实验

Table4: 训练数据集为pascal VOC标准的1464张图片,SBD中的图片作为无标签数据进行训练



最新文章

  1. css设置背景图片
  2. AngularJs--angular-pagination可复用的分页指令
  3. Java 项目JDBC 链接数据库中会出现的错误
  4. NSOJ A fairy tale of the two(最小费用最大流、SPFA版本、ZKW版本)
  5. SpringMvc异常处理
  6. MVC4学习笔记(一)
  7. Maven详解之仓库------本地仓库、远程仓库
  8. 对git认识
  9. Python之路,Day17 - 分分钟做个BBS论坛
  10. CAEmitterLayer实现粒子效果
  11. 用Ultraiso刻录U盘装系统
  12. 【Tomcat】Invalid character found in the request target
  13. 201521123049 《JAVA程序设计》 第1周学习总结
  14. Extjs6(七)——增删查改之删除
  15. 【NOI2015】程序自动分析
  16. Django rest framework(5)----解析器
  17. ubuntu环境下实现 多线程的socket(tcp) 通信
  18. Centos7X部署Zabbix监控
  19. 用react + redux + router写一个todo
  20. CSS3系列教程:HSL 和HSL

热门文章

  1. TensorFlow.NET机器学习入门【7】采用卷积神经网络(CNN)处理Fashion-MNIST
  2. 解决opencv:AttributeError: 'NoneType' object has no attribute 'copy'
  3. Java网络编程Demo,使用TCP 实现简单群聊功能GroupchatSimple,多个客户端输入消息,显示在服务端的控制台
  4. ActiveMQ基础教程(二):安装与配置(单机与集群)
  5. .net core集成使用EasyNetQ来使用rabbitmq
  6. CentOS 7安装Etherpad(在线协作编辑)
  7. 简单的树莓派4b装64位系统+docker和docker-compose
  8. Leetcode算法系列(链表)之删除链表倒数第N个节点
  9. vue3.0获取地址栏参数
  10. 10个JS技巧