import tensorflow as tf

def initialize_uninitialized(sess):
    """Initialize only the global variables that are not initialized yet.

    Unlike ``tf.global_variables_initializer()``, variables that already hold
    a value (for example ones loaded with ``saver.restore()``) are left as-is.

    Args:
        sess: the ``tf.Session`` whose variables are inspected and initialized.
    """
    global_vars = tf.global_variables()
    # One boolean per variable, in the same order as ``global_vars``.
    is_initialized = sess.run(
        [tf.is_variable_initialized(var) for var in global_vars])
    not_initialized_vars = [
        v for v, initialized in zip(global_vars, is_initialized)
        if not initialized]
    print([str(v.name) for v in not_initialized_vars])  # only for testing
    if not_initialized_vars:
        sess.run(tf.variables_initializer(not_initialized_vars))

上述代码定义了一个函数,用于初始化图中尚未被初始化的剩余变量。

需要注意的是,我们一般使用 tf.global_variables_initializer() 作为初始化 op,但它会重新初始化图中的所有全局变量:如果在 saver.restore() 之后运行它,就会覆盖原先通过 saver.restore() 加载的变量值。因此,在这种场景下不能采用此方法。

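另外,如果在 saver.restore() 之前执行 sess.run(tf.global_variables_initializer()),被恢复的变量最终保持的是检查点中的值,看起来 restore() 似乎"屏蔽"了 global_variables_initializer() 的初始化效果。其原因在于:restore() 本质上是对 Saver 的 var_list 中的变量执行赋值操作,会覆盖初始化器先前赋予的初值;而不在 var_list 中的变量则仍保留初始化后的值。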

选择性加载变量时,可以用 variable_scope 将不同的子网络隔离开,并构造 {检查点中的变量名: 图中的变量} 形式的字典,作为 tf.train.Saver 的 var_list 参数来加载。示例代码如下:

# stage_merged.py
# transform from single frame into multi-frame enhanced single raw
from __future__ import division
import os, time, scipy.io
import tensorflow as tf
import numpy as np
import rawpy
import glob
from model_sid_latest import network_stages_merged, network_my_unet, network_enhance_raw
import platform
from PIL import Image if platform.system() == 'Windows':
data_dir = 'D:/data/Sony/dataset/bbf-raw-selected/'
elif platform.system() == 'Linux':
data_dir = './dataset/bbf-raw-selected/'
else:
print('platform not supported!')
assert False os.environ["CUDA_VISIBLE_DEVICES"] = ""
checkpoint_dir = './model_stage_merged/'
result_dir = './out_stage_merged/'
log_dir = './log_stage_merged/'
learning_rate = 1e-4
epoch_bound = 20000
save_model_every_n_epoch = 10 if platform.system() == 'Windows':
output_every_n_steps = 1
else:
output_every_n_steps = 100 if platform.system() == 'Windows':
ckpt_enhance_raw = 'D:/model/enhance_raw/'
ckpt_raw2rgb = 'D:/model/raw2rgb-c1/'
else:
ckpt_enhance_raw = './model/enhance_raw/'
ckpt_raw2rgb = './model/raw2rgb-c1/' # BBF100-2
bbf_w = 4032
bbf_h = 3024 patch_w = 512
patch_h = 512 max_level = 1023
black_level = 64 patch_w = 512
patch_h = 512 # set up dataset
input_files = sorted(glob.glob(os.path.join(data_dir, '*.dng')))


def preprocess(raw, bl, wl):
    """Return the visible raw mosaic of *raw* as normalized float32.

    The black level *bl* is subtracted, negative values are clamped to 0 and
    the result is scaled so that the white level *wl* maps to 1.0. Values
    above *wl* are not clipped here.
    """
    im = raw.raw_image_visible.astype(np.float32)
    im = np.maximum(im - bl, 0)
    return im / (wl - bl)


# (row, col) offsets of the R, G, B, G samples inside one 2x2 Bayer tile.
_CFA_OFFSETS = {
    'RGGB': ((0, 0), (0, 1), (1, 1), (1, 0)),
    'BGGR': ((1, 1), (0, 1), (0, 0), (1, 0)),
    'GRBG': ((0, 1), (0, 0), (1, 0), (1, 1)),
    'GBRG': ((1, 0), (0, 0), (0, 1), (1, 1)),
}


def _cfa_name(pattern):
    """Map a rawpy ``raw_pattern`` 2x2 array to a key of ``_CFA_OFFSETS``.

    NOTE(review): rawpy may label the second green as 3; a pattern whose
    top-left entry is 3 is rejected here -- confirm for the target sensors.
    """
    if pattern[0, 0] == 0:
        return 'RGGB'
    if pattern[0, 0] == 2:
        return 'BGGR'
    if pattern[0, 0] == 1 and pattern[0, 1] == 0:
        return 'GRBG'
    if pattern[0, 0] == 1 and pattern[0, 1] == 2:
        return 'GBRG'
    raise ValueError('unsupported CFA pattern: %r' % (pattern,))


def pack_raw_bbf(path, bl=64, wl=1023):
    """Read a DNG file and pack its Bayer mosaic into 4 channels.

    Args:
        path: path of the raw file to read with rawpy.
        bl: black level subtracted from the raw values.
        wl: white level; after normalization it maps to 1.0.

    Returns:
        A ``(H/2, W/2, 4)`` uint16 array in R, G, B, G channel order,
        white-balanced with the camera multipliers (greens normalized to 1),
        clipped to 1.0 and rescaled to the integer range ``[0, wl - bl]``.

    Raises:
        ValueError: if the sensor's CFA layout is not one of the four
            supported Bayer patterns.
    """
    # Close the LibRaw handle once everything needed has been copied out.
    with rawpy.imread(path) as raw:
        im = preprocess(raw, bl, wl)
        offsets = _CFA_OFFSETS[_cfa_name(raw.raw_pattern)]
        wb = np.array(raw.camera_whitebalance)
    im = np.expand_dims(im, axis=2)
    H, W = im.shape[0], im.shape[1]
    out = np.concatenate([im[r:H:2, c:W:2, :] for r, c in offsets], axis=2)
    # Reuse the first green gain for the second green, then normalize so the
    # green multiplier is 1.
    wb[3] = wb[1]
    wb = wb / wb[1]
    out = np.minimum(out * wb, 1.0)
    return (out * (wl - bl)).astype(np.uint16)


# ---- build the TF graph ----
tf.reset_default_graph()
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
in_im = tf.placeholder(tf.float32, [1, patch_h, patch_w, 4], name='input')

# Target: the two sub-networks applied in sequence. Their weights are
# restored from checkpoints below and are not part of the optimized vars.
with tf.variable_scope('enhance_raw', reuse=tf.AUTO_REUSE):
    enhanced_raw = network_enhance_raw(in_im, patch_h, patch_w)
with tf.variable_scope('raw2rgb', reuse=tf.AUTO_REUSE):
    gt_im = network_my_unet(enhanced_raw, patch_h, patch_w)
# Network being trained: maps the same input directly to the output.
with tf.variable_scope('stage_merged', reuse=tf.AUTO_REUSE):
    out_im = network_stages_merged(in_im, patch_h, patch_w)

# Clamp both images to [0, 1]. stop_gradient makes explicit that the loss
# never back-propagates into the target sub-networks.
gt_im_cut = tf.minimum(tf.maximum(tf.stop_gradient(gt_im), 0.0), 1.0)
out_im_cut = tf.minimum(tf.maximum(out_im, 0.0), 1.0)

# Multi-scale SSIM on the single image in the batch (dynamic range 1.0).
ssim_loss = 1 - tf.image.ssim_multiscale(gt_im_cut[0], out_im_cut[0], 1.0)
# Per-pixel L1 / L2 distance summed over channels, averaged over pixels.
l1_loss = tf.reduce_mean(tf.reduce_sum(tf.abs(gt_im_cut - out_im_cut), axis=-1))
l2_loss = tf.reduce_mean(tf.reduce_sum(tf.square(gt_im_cut - out_im_cut), axis=-1))
G_loss = ssim_loss
# Alternative objective: G_loss = l1_loss + l2_loss

# Summary tags must not contain spaces (TF would rename them with a warning).
tf.summary.scalar('G_loss', G_loss)
tf.summary.scalar('L1_Loss', l1_loss)
tf.summary.scalar('L2_Loss', l2_loss)

########## LOADING MODELS #############
def _scope_var_map(scope):
    """Map checkpoint names to the graph's global variables under *scope*.

    The leading ``scope/`` prefix and the ``:0`` output suffix are stripped,
    so the checkpoint is expected to store the names without the scope.
    """
    prefix = scope + '/'
    var_map = {}
    # Filter on ``scope/`` so e.g. 'raw2rgb' does not also match 'raw2rgb_x'.
    for v in tf.global_variables(prefix):
        name = v.name.split(':')[0]
        if name.startswith(prefix):
            name = name[len(prefix):]
        var_map[name] = v
    return var_map


def _restore_scope(sess, scope, ckpt_dir, required):
    """Restore the variables of *scope* from the latest checkpoint in *ckpt_dir*.

    Prints an error when no checkpoint is found; if *required* is true the
    script then exits. Returns the ``tf.train.Saver`` that was built.
    """
    scope_saver = tf.train.Saver(var_list=_scope_var_map(scope))
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        scope_saver.restore(sess, ckpt.model_checkpoint_path)
        print('loaded %s model: %s' % (scope, ckpt.model_checkpoint_path))
    else:
        print('Error: failed to load %s model!' % scope)
        if required:
            # Raise instead of `assert False`, which `python -O` strips.
            raise SystemExit(1)
    return scope_saver


# NOTE(review): a missing enhance_raw checkpoint is only reported (its vars
# are later initialized from scratch), while raw2rgb is mandatory -- this
# mirrors the original behavior; confirm it is intentional.
saver_enhance_raw = _restore_scope(sess, 'enhance_raw', ckpt_enhance_raw, required=False)
#----------------------------------------
saver_raw2rgb = _restore_scope(sess, 'raw2rgb', ckpt_raw2rgb, required=True)
#----------------------------------------
def initialize_uninitialized(sess):
    """Initialize the global variables of *sess* that still have no value.

    Already-initialized variables (e.g. restored weights) are left untouched;
    the names of the variables that get initialized are printed.
    """
    global_vars = tf.global_variables()
    bool_inits = sess.run([tf.is_variable_initialized(var) for var in global_vars])
    uninit_vars = [v for v, b in zip(global_vars, bool_inits) if not b]
    for v in uninit_vars:
        print(str(v.name))
    if uninit_vars:
        sess.run(tf.variables_initializer(uninit_vars))


# Only variables of the stage_merged network are optimized and saved.
# The trailing '/' keeps the scope filter from matching e.g. 'stage_merged2'.
t_vars = tf.trainable_variables(scope='stage_merged/')
lr = tf.placeholder(tf.float32)
G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss, var_list=t_vars)
# NOTE(review): the optimizer's own state is not in var_list, so it is
# re-initialized on every resume.
saver = tf.train.Saver(var_list=t_vars)
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    print('loaded ' + ckpt.model_checkpoint_path)
else:
    sess.run(tf.variables_initializer(var_list=t_vars))
# Initialize whatever is still missing (e.g. optimizer variables) without
# overwriting any restored weights.
initialize_uninitialized(sess)
#######################################
if not os.path.isdir(result_dir):
    os.makedirs(result_dir)
# tf.train.Saver.save fails if the parent directory does not exist.
if not os.path.isdir(checkpoint_dir):
    os.makedirs(checkpoint_dir)

input_images = [None] * len(input_files)  # lazily filled cache of packed raws
g_loss = np.zeros([500, 1])  # ring buffer of recent per-step losses

merged = tf.summary.merge_all()
writer = tf.summary.FileWriter(log_dir, sess.graph)

steps = 0
st = time.time()
for epoch in range(0, epoch_bound):
    for ind in np.random.permutation(len(input_images)):
        steps += 1
        if input_images[ind] is None:
            input_images[ind] = np.expand_dims(pack_raw_bbf(input_files[ind]), axis=0)

        # random cropping: randint's upper bound is exclusive, so +1 keeps
        # the last valid offset reachable. Uses the actual image size.
        img_h, img_w = input_images[ind].shape[1], input_images[ind].shape[2]
        xx = np.random.randint(0, img_w - patch_w + 1)
        yy = np.random.randint(0, img_h - patch_h + 1)
        input_patch = np.float32(input_images[ind][:, yy:yy + patch_h, xx:xx + patch_w, :]) / (
            max_level - black_level)

        # random flipping: the patch is (1, H, W, C), so axis 1 flips rows and
        # axis 2 flips columns (axis 0 is the size-1 batch; flipping it is a no-op).
        if np.random.randint(2, size=1)[0] == 1:
            input_patch = np.flip(input_patch, axis=1)
        if np.random.randint(2, size=1)[0] == 1:
            input_patch = np.flip(input_patch, axis=2)
        # random transpose (swaps H and W, so only valid for square patches;
        # the random draw happens first to keep the RNG stream unchanged).
        if np.random.randint(2, size=1)[0] == 1 and patch_h == patch_w:
            input_patch = np.transpose(input_patch, (0, 2, 1, 3))

        summary, _, G_current, output, gt_im_ = sess.run(
            [merged, G_opt, G_loss, out_im_cut, gt_im_cut],
            feed_dict={
                in_im: input_patch,
                lr: learning_rate})
        g_loss[steps % len(g_loss)] = G_current

        if steps % output_every_n_steps == 0:
            # Mean over the filled (non-zero) entries of the loss buffer.
            loss_ = np.mean(g_loss[np.where(g_loss)])
            cost_ = (time.time() - st) / output_every_n_steps
            st = time.time()
            print("%d %d Loss=%.6f Speed=%.6f" % (epoch, steps, loss_, cost_))
            writer.add_summary(summary, global_step=steps)
            # Side-by-side input | target | output. The network outputs are
            # sampled at stride 2 -- assumes they are 2x the patch resolution
            # (TODO confirm against model_sid_latest).
            temp = np.concatenate(
                (input_patch[0, :, :, :3],
                 gt_im_[0, 0:patch_h * 2:2, 0:patch_w * 2:2, :3],
                 output[0, 0:patch_h * 2:2, 0:patch_w * 2:2, :3]), axis=1)
            # scipy.misc.toimage no longer exists in SciPy; this reproduces its
            # byte scaling for low=0, high=255, cmin=0, cmax=255 (clip + round).
            sample = np.uint8(np.clip(temp * 255, 0, 255) + 0.5)
            Image.fromarray(sample).save(
                os.path.join(result_dir, '%d_%d.jpg' % (epoch, steps)))

        # clean up the memory if necessary
        if platform.system() == 'Windows':
            input_images[ind] = None

    if epoch % save_model_every_n_epoch == 0:
        saver.save(sess, os.path.join(checkpoint_dir, '%d.ckpt' % epoch))
        print('model saved.')

writer.close()

最新文章

  1. ES6之变量常量字符串数值
  2. 你需要知道的包管理器(Package Manager)
  3. FUND
  4. Linux设备驱动之中断支持及中断分层
  5. 深入理解拉格朗日乘子法(Lagrange Multiplier) 和KKT条件
  6. AngularJs ngCloak、ngController、ngInit、ngModel
  7. Spring.Net的AOP的通知
  8. springboot
  9. Spark Streaming源码解读之Job动态生成和深度思考
  10. PHP 实现多服务器共享 SESSION 数据
  11. Leetcode#99 Recover Binary Search Tree
  12. 优雅的python 写排序算法
  13. linux下安装
  14. 轻量级的数据交换语言(JSON)
  15. 【BZOJ 1367】 1367: [Baltic2004]sequence (可并堆-左偏树)
  16. linux 关于Apache默认编码错误 导致网站乱码的解决方案
  17. 【XSY1905】【XSY2761】新访问计划 二分 树型DP
  18. JS中for in 与 for of
  19. eclipse签名使用的key文件(android生成keystore)
  20. 4.update更新和delete删除用法

热门文章

  1. node获取当前路径的三种方法
  2. CCF CSP 201712-1 最小差值
  3. Oarcle之序列
  4. jq复制
  5. 4.产生10个1-100的随机数,并放到一个数组中 (1)把数组中大于等于10的数字放到一个list集合中,并打印到控制台。 (2)把数组中的数字放到当前文件夹的numArr.txt文件中
  6. 转载-《Python学习手册》读书笔记
  7. postman的几个问题
  8. oracle SQL性能分析之10053事件
  9. MySQL高性能优化规范建议,速度收藏
  10. JDBC——Java语言连接数据库的标准