Source code: https://github.com/mrzhu-cool/pix2pix-pytorch

Compared with Jun-Yan Zhu's official version, this implementation is simpler and easier to read.

The training code lives in train.py. It opens with the usual three steps shared by most training scripts: parse the arguments, load the data, and build the models.

Command-line arguments

# Training settings
parser = argparse.ArgumentParser(description='pix2pix-pytorch-implementation')
parser.add_argument('--dataset', required=True, help='facades')
parser.add_argument('--batch_size', type=int, default=1, help='training batch size')
parser.add_argument('--test_batch_size', type=int, default=1, help='testing batch size')
parser.add_argument('--direction', type=str, default='b2a', help='a2b or b2a')
parser.add_argument('--input_nc', type=int, default=3, help='input image channels')
parser.add_argument('--output_nc', type=int, default=3, help='output image channels')
parser.add_argument('--ngf', type=int, default=64, help='generator filters in first conv layer')
parser.add_argument('--ndf', type=int, default=64, help='discriminator filters in first conv layer')
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count')
parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='use cuda?')
parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use')
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
parser.add_argument('--lamb', type=int, default=10, help='weight on L1 term in objective')
opt = parser.parse_args()

Data

print('===> Loading datasets')
root_path = "dataset/"
train_set = get_training_set(root_path + opt.dataset, opt.direction)
test_set = get_test_set(root_path + opt.dataset, opt.direction)
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batch_size, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.test_batch_size, shuffle=False)
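get_training_set and get_test_set are helpers defined elsewhere in the repo. As a rough illustration only (the folder layout and class name here are assumptions, not the repo's actual code), a paired dataset for pix2pix might look like this: a map-style object (anything with `__len__`/`__getitem__` works with DataLoader) where input and target images share a filename under "a/" and "b/" subfolders, and `direction` ('a2b' or 'b2a') picks which side is the input and which is the target.

```python
import os
from PIL import Image

# Hypothetical sketch of a paired dataset like the one get_training_set
# might return; folder names "a"/"b" and the class name are assumptions.
class PairedImageDataset:
    def __init__(self, root, direction='b2a', transform=None):
        self.dir_a = os.path.join(root, 'a')
        self.dir_b = os.path.join(root, 'b')
        # paired images are matched by filename
        self.names = sorted(os.listdir(self.dir_a))
        self.direction = direction
        self.transform = transform

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        a = Image.open(os.path.join(self.dir_a, self.names[idx])).convert('RGB')
        b = Image.open(os.path.join(self.dir_b, self.names[idx])).convert('RGB')
        if self.transform:
            a, b = self.transform(a), self.transform(b)
        # direction decides which side is input and which is target
        return (a, b) if self.direction == 'a2b' else (b, a)
```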

Models

print('===> Building models')
net_g = define_G(opt.input_nc, opt.output_nc, opt.ngf, 'batch', False, 'normal', 0.02, gpu_id=device)
net_d = define_D(opt.input_nc + opt.output_nc, opt.ndf, 'basic', gpu_id=device)
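The 'basic' discriminator is a PatchGAN: instead of one real/fake score per image, it outputs a grid of scores, each judging one local patch of the (input, output) pair. The repo's define_D builds this internally; as a hedged sketch (assuming the standard 70x70 PatchGAN layout from the pix2pix paper, not copied from this repo), the architecture is roughly:

```python
import torch
import torch.nn as nn

# Sketch of a PatchGAN-style discriminator ('basic'); layer layout assumes
# the standard 70x70 PatchGAN from the pix2pix paper.
class PatchDiscriminator(nn.Module):
    def __init__(self, in_nc=6, ndf=64):
        super().__init__()
        self.model = nn.Sequential(
            # three stride-2 downsampling blocks
            nn.Conv2d(in_nc, ndf, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, True),
            # two stride-1 blocks; final conv emits one logit per patch
            nn.Conv2d(ndf * 4, ndf * 8, 4, 1, 1), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 8, 1, 4, 1, 1),
        )

    def forward(self, x):
        # x is the channel-wise concatenation of input and output images,
        # hence in_nc = input_nc + output_nc = 6 by default
        return self.model(x)
```

For a 256x256 input pair this yields a 30x30 grid of patch scores rather than a single scalar.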

Optimizers and loss functions

criterionGAN = GANLoss().to(device)
criterionL1 = nn.L1Loss().to(device)
criterionMSE = nn.MSELoss().to(device)

# setup optimizer
optimizer_g = optim.Adam(net_g.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizer_d = optim.Adam(net_d.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
net_g_scheduler = get_scheduler(optimizer_g, opt)
net_d_scheduler = get_scheduler(optimizer_d, opt)
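With the default --lr_policy lambda, the schedule keeps the learning rate constant for the first --niter epochs and then decays it linearly to zero over the next --niter_decay epochs. Assuming get_scheduler wraps torch.optim.lr_scheduler.LambdaLR (as in the reference pix2pix code; this sketch is not copied from the repo), the multiplicative rule would be:

```python
# Sketch of the 'lambda' learning-rate policy: the returned factor multiplies
# the base lr. It stays 1.0 for `niter` epochs, then falls linearly to zero
# over `niter_decay` epochs. Defaults mirror the argparse defaults above.
def lambda_rule(epoch, epoch_count=1, niter=100, niter_decay=100):
    return 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)
```

So with the defaults, training runs 200 epochs total: 100 at lr=0.0002, then 100 fading to zero.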

Then the data is read batch by batch. The discriminator is updated first; its input is an image pair, either (real, real) or (real, fake). Note that fake_b is detached here so this step does not backpropagate into the generator.

######################
# (1) Update D network
######################
optimizer_d.zero_grad()

# train with fake
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)

# train with real
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)

# Combined D loss
loss_d = (loss_d_fake + loss_d_real) * 0.5
loss_d.backward()
optimizer_d.step()
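criterionGAN(pred, True/False) hides a small convenience: it builds a target tensor of ones or zeros matching the prediction's shape, then applies a base loss. As a hedged sketch of what GANLoss might look like (here with MSE, i.e. LSGAN-style; the repo's actual base loss may differ), one common implementation is:

```python
import torch
import torch.nn as nn

# Sketch of a GANLoss helper: the boolean flag selects a real (1.0) or
# fake (0.0) target, broadcast to the prediction's shape. The MSE base
# loss here is an assumption (LSGAN-style); BCE is the other common choice.
class GANLoss(nn.Module):
    def __init__(self, real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        self.loss = nn.MSELoss()

    def forward(self, prediction, target_is_real):
        target = self.real_label if target_is_real else self.fake_label
        return self.loss(prediction, target.expand_as(prediction))
```

This works unchanged with a PatchGAN, since the target grid is expanded to whatever shape the discriminator outputs.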

Then the generator is updated. Its loss combines the adversarial loss produced by the discriminator with an L1 constraint between the fake and real images, weighted by --lamb. Here fake_ab is not detached, so gradients flow through the discriminator back into the generator.

######################
# (2) Update G network
######################
optimizer_g.zero_grad()

# First, G(A) should fake the discriminator
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab)
loss_g_gan = criterionGAN(pred_fake, True)

# Second, G(A) = B
loss_g_l1 = criterionL1(fake_b, real_b) * opt.lamb

loss_g = loss_g_gan + loss_g_l1
loss_g.backward()
optimizer_g.step()

Finally, the learning rates are updated once per epoch:

update_learning_rate(net_g_scheduler, optimizer_g)
update_learning_rate(net_d_scheduler, optimizer_d)
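update_learning_rate is a small utility defined elsewhere in the repo. A plausible sketch (an assumption, not the repo's exact code) is just a scheduler step plus a log line:

```python
# Sketch of update_learning_rate: advance the scheduler by one epoch and
# report the optimizer's current learning rate for monitoring.
def update_learning_rate(scheduler, optimizer):
    scheduler.step()
    lr = optimizer.param_groups[0]['lr']
    print('learning rate = %.7f' % lr)
```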

The most important parts, the network construction (define_G, define_D) and a few utility functions, are covered in a later post.
