先训练G:

先不计算D的梯度:                                           判别器输入类型为(源域,0)或者(目标域,1),输出图片为真实图片(源域)的概率值
for param in model_D.parameters():    # model_D = nn.ModuleList([FCDiscriminator...]) 判别器是一个全卷积网络,其实就是一个二分类,输出一个条件概率,即输入样本属于源域或者目标域的概率
param.requires_grad = False 判别损失 Ld 是一个二分类交叉熵损失,判断输入属于源域还是目标域
                                                      怎么才算训练好判别器:判别器能对真图打高分,对假图打低分
                                                       
输入图片:
images.size: torch.Size([1, 3, 512, 1024])
labels.size: torch.Size([1, 512, 1024])
源域图片S 的输出分割特征图:
feat_source: ([1, 2048, 65, 129])
pred_source: ([1, 19, 65, 129])
输出特征图接一个上采样后 pred_source 大小变成: ([1, 19, 512, 1024])
计算交叉熵损失:
loss_seg = seg_loss(pred_source, labels)
计算梯度值,并反传梯度值: (只是计算,不更新)
loss_seg.backward()

目标域图片T的大小、特征图大小 和上面的源域S一样,不同的是,经过分割网络时,得到一个加权的特征图(注:加权后的特征图大小不变)
和S一样,得到特征图后,接一个上采样:
pred_target = interp_target(pred_target)
先损失清零
loss_adv = 0
然后计算判别损失值,即对倒数第二层的T域特征图打分

D_out = model_D[0](feat_target) (判别器D[0]输入通道为2048,输出通道为1)
再用上面的判别损失值来计算对抗损失,即用bce_loss(均方差MSELoss())来计算D_out和source_label的分布差
loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))
# source_label=0

先对最后一层的T域特征图打分:特征图先变成概率图(用softmax()),然后对概率图打分
D_out = model_D[1](F.softmax(pred_target, dim=1)) (判别器D[1]输入19,输出1)
然后计算对抗损失:
loss_adv += bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).to(device))

loss_adv = loss_adv * 0.01
计算梯度值,并将梯度反传:
loss_adv.backward()
更新模型参数:
optimizer.step()
再训练D:


 

最新文章

  1. Oracle通过一个Value值查询数据库
  2. Javascript快速入门(上篇)
  3. Percona XtraBackup User Manual 阅读笔记
  4. 06章 Struts2国际化
  5. C#特性学习笔记二
  6. 使用JQuery Mobile实现手机新闻浏览器
  7. [转] Android SDK manager 无法获取更新版本列表
  8. handlebar helper帮助方法
  9. KEIL、uVision、RealView、MDK、KEIL C51区别比较
  10. vbox要手动mount才能挂载windows的共享文件夹(好用,不用安装samba了)
  11. Java考查“==”和equals
  12. 算法01 C语言设计
  13. sface
  14. JavaWeb架构发展
  15. hdu4746莫比乌斯反演+分块
  16. MapReduce Demo
  17. TZOJ 3209 后序遍历(已知中序前序求后序)
  18. [leetcode]179. Largest Number最大数
  19. eclipse的.properties文件中文显示问题
  20. WPF XAML 特殊字符(小于号、大于号、引号、&符号)

热门文章

  1. JAVA集合框架特征介绍
  2. 浅析sleep()方法与wait()方法
  3. [已解决]Android studio连接远程MySQL问题解决
  4. go-使用 vscore 调试 go 语言
  5. 使用Git GUI Here进行推送时产生报错
  6. 基于Axi4_lite的UART串口Verilog代码实现
  7. 修改 npm 全局模块及模块缓存存放位置
  8. 动态规划-3-RNA的二级结构
  9. DoTween结束后删除对象
  10. Linux下查找并杀死 zombile 和 stopped 进程