import glob
import os
import struct
import time

import numpy as np
import rawpy
import tensorflow as tf
import tensorflow.contrib.slim as slim
from PIL import Image

# Identifiers of the supported camera sources (datasets); compared against
# source_id below.
__sony__ = 0
__huawei__ = 1
__blackberry__ = 2

# Identifiers of the training stages; compared against training_stage below.
__stage_raw2raw__ = 0
__stage_raw2rgb__ = 1
__stage_overall__ = 2

# File-name prefixes selecting the train / validation / test subsets. They are
# prepended to the glob patterns used to find the data files; empty strings
# match every file.
train_prefix = ''
valid_prefix = ''
test_prefix = ''
# ============ CONFIGURATION ============
# USE_GPU also marks "running on the Linux server": it selects the Linux
# data_dir below and enables in-memory caching of decoded training files.
USE_GPU = False
if USE_GPU:
    # NOTE(review): an empty CUDA_VISIBLE_DEVICES hides every GPU from
    # TensorFlow, which contradicts the flag's name; presumably a device index
    # was intended here -- confirm before relying on it.
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
# change this to switch between datasets
source_id = __sony__
# switch between training stages
training_stage = __stage_raw2rgb__
# patch size should be set on running
patch_size = (512, 512)
# patch_size = (2840, 4248)
# switch between training and validation
current_prefix = train_prefix
# model saving settings
max_epoch = 2000
save_epoch_delay = 1
model_dir = './result_raw2raw/'
out_dir = './output_raw2raw/'
log_dir = './log_raw2raw/'
learn_rate = 1e-2
# ============ CONFIGURATION ============

# Sensor constants of the selected camera: raw white/black levels and the
# size of the full raw frame.
if source_id == __blackberry__:
    WHITE_LEVEL = 1023
    BLACK_LEVEL = 64
    HEIGHT = 3024
    WIDTH = 4032
elif source_id == __sony__:
    WHITE_LEVEL = 16383
    BLACK_LEVEL = 512
    HEIGHT = 2848
    WIDTH = 4256
elif source_id == __huawei__:
    WHITE_LEVEL = 1023
    BLACK_LEVEL = 64
    HEIGHT = 2976
    WIDTH = 3968
else:
    # Fail fast: otherwise WHITE_LEVEL & co. stay undefined and the script
    # only crashes later with a confusing NameError.
    raise ValueError('unsupported source_id: %r' % (source_id,))

if USE_GPU:
    data_dir = '../see_in_the_dark/dataset/Sony_small/'
else:
    data_dir = 'D:/data/Sony_small/'

# !!!!!! DO NOT TOUCH THIS SETTING !!!!!!
# Size the half-resolution RGB is resized to before ISP parameter estimation.
fixed_size = (128, 128)
num_of_denoise_filter = 3
standard_brightness = 0.1
# !!!!!! DO NOT TOUCH THIS SETTING !!!!!!
def has_nan_in_tensor(x):
    """Return True if the NumPy array ``x`` contains at least one NaN.

    NaN is the only value that compares unequal to itself, so ``x != x``
    marks exactly the NaN entries.
    """
    return np.sum(x != x) > 0


def _load_rawpy_bayer(path):
    """Decode ``path`` with rawpy and return its visible Bayer mosaic as float32."""
    # The context manager releases the rawpy handle; astype() copies the data,
    # so the returned array stays valid after the handle is closed.
    with rawpy.imread(path) as data:
        return data.raw_image_visible.astype(np.float32)


def _load_headerless_bayer(path, height, width):
    """Read a header-less file of native-endian uint16 samples.

    Returns:
        float32 array of shape (height, width).

    Raises:
        ValueError: if the file does not hold exactly height * width samples.
    """
    return np.fromfile(path, dtype=np.uint16).astype(np.float32).reshape(height, width)


def raw_from_file(path):
    """Load the raw Bayer mosaic of ``path`` for the camera ``source_id``.

    Sony and Huawei files are decoded with rawpy and cropped by one pixel on
    two opposite borders so that every 2x2 cell has the "standard GRGB"
    layout (G R / B G) that bayer_to_rgb expects. BlackBerry files are
    header-less uint16 dumps used without cropping (assumed to already be in
    that layout -- TODO confirm).

    Returns:
        float32 array of shape (H, W, 1) in raw sensor units (the black level
        is not subtracted here).

    Raises:
        ValueError: if ``source_id`` is not a supported camera.
    """
    if source_id == __sony__:
        raw = _load_rawpy_bayer(path).reshape(HEIGHT, WIDTH)
        # convert from RGBG into standard GRGB format:
        # cut the strips of left and right borders
        raw = raw[:, 1:-1]
    elif source_id == __huawei__:
        raw = _load_rawpy_bayer(path).reshape(HEIGHT, WIDTH)
        # convert from BGRG into standard GRGB format:
        # cut the strips of top and bottom borders
        raw = raw[1:-1, :]
    elif source_id == __blackberry__:
        raw = _load_headerless_bayer(path, HEIGHT, WIDTH)
    else:
        raise ValueError('file type [%r] is not supported!' % (source_id,))
    return raw[:, :, np.newaxis]


def rgb_from_file(path):
    """Render ``path`` to an RGB image in [0, 1] with rawpy.

    Uses the camera white balance, no auto-brightness and 16-bit output. The
    result is cropped like raw_from_file so it stays pixel-aligned with the
    raw mosaic.

    Returns:
        float32 array of shape (H, W, 3).

    Raises:
        NameError: if ``source_id`` is not a rawpy-readable camera (kept for
            backward compatibility).
    """
    if source_id == __sony__:
        crop = (slice(None), slice(1, -1))
    elif source_id == __huawei__:
        crop = (slice(1, -1), slice(None))
    else:
        raise NameError('file type [%d] does not support rawpy!' % source_id)
    with rawpy.imread(path) as data:
        rgb = data.postprocess(
            use_camera_wb=True,
            half_size=False,
            no_auto_bright=True,
            output_bps=16,
        )
    return (np.float32(rgb) / 65535.0)[crop]


def black_level_correction(bayer):
    """Subtract BLACK_LEVEL, scale so WHITE_LEVEL maps to 1 and clamp negatives to 0."""
    with tf.name_scope('black_level_corr'):
        r = 1.0 / (WHITE_LEVEL - BLACK_LEVEL)
        return tf.nn.relu((bayer - BLACK_LEVEL) * r)


def bound(bayer):
    """Clip a tensor to the range [0, 1]."""
    return tf.minimum(tf.maximum(bayer, 0), 1)


def bayer_to_rgb(bayer):
    """Collapse each 2x2 GRGB cell of an NHWC Bayer tensor into one RGB pixel.

    R and B are taken directly; G is the mean of the two green samples. The
    output has half the height and width of the input.
    """
    with tf.name_scope('bayer2rgb'):
        filters = np.array([
            [0.0, 1.0, 0.0, 0.0],  # R
            [0.5, 0.0, 0.0, 0.5],  # (G1+G2)/2
            [0.0, 0.0, 1.0, 0.0],  # B
        ], dtype=np.float32).reshape([1, 3, 2, 2]).transpose([2, 3, 0, 1])
        # filters now has conv2d's HWIO layout [2, 2, 1, 3]; an explicit
        # float32 dtype matches the float32 inputs this is applied to.
        return tf.nn.conv2d(
            bayer,
            filters,
            strides=(1, 2, 2, 1),
            padding='VALID',
            name='bayer_converter'
        )


def demosaic(rgb):
    """Bilinearly resize an NHWC image to patch_size (the full patch resolution)."""
    with tf.name_scope('demosaic'):
        return tf.image.resize_bilinear(rgb, patch_size)


def color_correction(rgb, color_matrix):
    """Apply a 3x3 color matrix to every pixel via a 1x1 convolution.

    The resulting op is named 'output' (inside the enclosing name scopes).
    """
    with tf.name_scope('color_corr'):
        filters = tf.reshape(color_matrix, [1, 1, 3, 3])
        return tf.nn.conv2d(rgb, filters, (1, 1, 1, 1), 'SAME', name='output')


def min_max_normalize(rgb):
    """Linearly rescale the whole tensor to (0, 1].

    The 1e-8 terms avoid division by zero and keep the minimum strictly
    positive.
    """
    _min = tf.reduce_min(rgb)
    _max = tf.reduce_max(rgb)
    return (rgb - _min + 1e-8) / (_max - _min + 1e-8)


def gaussian_norm(rgb, eps=1e-8):
    """Standardize a tensor to zero mean and unit standard deviation.

    Statistics are taken over the whole tensor. ``eps`` is added to the
    variance so that constant inputs do not cause a division by zero, and so
    that the gradient of sqrt stays finite.
    """
    _mean = tf.reduce_mean(rgb)
    _std = tf.sqrt(tf.reduce_mean(tf.square(rgb - _mean)) + eps)
    return (rgb - _mean) / _std
def gamma_correction(rgb, gamma):
    """Return ``min_max_normalize(rgb) ** gamma``.

    Not supported on SNPE, so on the phone this step runs on the CPU. The
    input is min-max normalized first so that negative values never reach
    the power operation.
    """
    with tf.name_scope('gamma_corr'):
        return tf.pow(min_max_normalize(rgb), gamma)


def lrelu(x):
    """Leaky ReLU: max(0.2 * x, x)."""
    return tf.maximum(x * 0.2, x)


def network_raw2raw(inputs):
    """Dilated-convolution network refining a one-channel Bayer tensor.

    Five 3x3 conv layers with 32 filters and dilation rates 1, 2, 4, 8 and 16
    (leaky ReLU), followed by a 1x1 conv back to one channel. The variable
    scopes are g_conv1..g_conv5 and g_conv_last.
    """
    # Weights use slim's default initializer. Passing tf.initializers.constant
    # (default value 0) started every 3x3 kernel at zero; with all-zero
    # kernels their gradients are zero too, so those layers never learned and
    # the output ignored the input. Scope names are unchanged, so existing
    # checkpoints still restore.
    with tf.name_scope('raw2raw'):
        net = inputs
        for index, rate in enumerate((1, 2, 4, 8, 16), start=1):
            net = slim.conv2d(net, 32, [3, 3], rate=rate, activation_fn=lrelu,
                              scope='g_conv%d' % index)
        net = slim.conv2d(net, 1, [1, 1], rate=1, activation_fn=None, scope='g_conv_last')
        return net


def _to_uint8(rgb):
    """Convert a float image in [0, 1] to uint8, clipping out-of-range values first.

    Without the clip, values outside [0, 1] do not survive the uint8 cast
    meaningfully (negative or > 255 after scaling).
    """
    return np.uint8(np.clip(rgb, 0.0, 1.0) * 255)


def show(rgb, title):
    """Display a float RGB image (values in [0, 1]) via PIL's Image.show."""
    Image.fromarray(_to_uint8(rgb)).show(title)


def save(rgb, path):
    """Save a float RGB image (values in [0, 1]) to ``path`` via PIL."""
    Image.fromarray(_to_uint8(rgb)).save(path)


def concat(ims):
    """Concatenate images side by side (along the width axis, axis 1)."""
    return np.concatenate(ims, axis=1)


def get_color_matrix_and_gamma(bayer):
    """Predict ISP parameters (a 3x3 color matrix and a gamma) from an image tensor.

    A shared conv extractor feeds two heads. One head regresses the 9
    color-matrix entries; the other regresses a single gamma, floored at 1e-3.

    Args:
        bayer: NHWC tensor whose spatial size is known at graph-build time,
            because the heads flatten conv features using their static shape.
            In this script it is the half-resolution RGB resized to
            fixed_size.

    Returns:
        (color_matrix, gamma): tensors of shape [3, 3] and [1]. The reshape to
        [3, 3] only succeeds for a batch of one.
    """
    # NOTE: tf.layers auto-names these layers by creation order (conv2d,
    # conv2d_1, ..., dense, dense_1), and checkpoints are keyed by those
    # names -- keep the layer order unchanged.
    with tf.name_scope('isp_param_gen'):
        with tf.name_scope('common_extractor'):
            channels = tf.layers.conv2d(bayer, 3, kernel_size=3, strides=2, padding='valid')
            activations = tf.nn.tanh(channels)
            channels = tf.layers.conv2d(activations, 5, kernel_size=3, strides=2, padding='valid')
            activations = tf.nn.relu(channels)
        with tf.name_scope('color_matrix'):
            channels_cm = tf.layers.conv2d(activations, 7, kernel_size=3, strides=2, padding='valid')
            activations_cm = tf.nn.tanh(channels_cm)
            channels_cm = tf.layers.conv2d(activations_cm, 5, kernel_size=3, strides=2, padding='valid')
            channels_flat_cm = tf.reshape(
                channels_cm,
                [-1, channels_cm.shape[1] * channels_cm.shape[2] * channels_cm.shape[3]])
            color_matrix = tf.reshape(tf.layers.dense(channels_flat_cm, 9), [3, 3])
        with tf.name_scope('gamma'):
            channels_gamma = tf.layers.conv2d(activations, 7, kernel_size=3, strides=2, padding='valid')
            activations_gamma = tf.nn.tanh(channels_gamma)
            channels_gamma = tf.layers.conv2d(activations_gamma, 5, kernel_size=3, strides=2, padding='valid')
            channels_flat_gamma = tf.reshape(
                channels_gamma,
                [-1, channels_gamma.shape[1] * channels_gamma.shape[2] * channels_gamma.shape[3]])
            gamma = tf.reshape(tf.maximum(tf.layers.dense(channels_flat_gamma, 1), 1e-3), [1])
        return color_matrix, gamma


def build_isp_process_flow(bayer, color_matrix, gamma):
    """Build the differentiable ISP: resize -> color correction -> gamma correction.

    Despite its name, ``bayer`` is the half-resolution RGB produced by
    bayer_to_rgb in this script; demosaic only resizes it to patch_size.
    """
    with tf.name_scope('isp_flow'):
        rgb = demosaic(bayer)
        rgb = color_correction(rgb, color_matrix)
        return gamma_correction(rgb, gamma)
def color_normalize(rgb):
    """Divide every pixel of an NHWC tensor by the sum of its channels.

    The per-pixel channel sum is floored at 1e-7 to avoid division by zero.
    """
    channel_sum = tf.maximum(tf.reduce_sum(rgb, axis=3), 1e-7)
    return rgb / tf.expand_dims(channel_sum, axis=-1)


def color_loss(rgb_out, rgb_gt):
    """Mean absolute difference between the chromaticities of two NHWC images."""
    chroma_diff = color_normalize(rgb_out) - color_normalize(rgb_gt)
    return tf.reduce_mean(tf.abs(chroma_diff))
# load images from files
gt_files = glob.glob(data_dir + '/long/' + current_prefix + '*.ARW')
in_files = [None] * len(gt_files)
train_ids = [None] * len(gt_files)
# Per-index caches of decoded images (filled only when USE_GPU).
gt_raws = [None] * len(train_ids)
gt_rgbs = [None] * len(train_ids)
in_raws = [None] * len(train_ids)

# Reorganize the raw files according to their training id
for i, gt_file in enumerate(gt_files):
    # basename() splits on the running OS's separators, so this works for both
    # the Linux paths and the Windows paths glob returns.
    # NOTE(review): [1:5] assumes names made of a one-character subset prefix
    # followed by a 4-digit id; with an empty current_prefix the short-exposure
    # glob below may also match other scenes' files -- confirm the naming.
    train_ids[i] = os.path.basename(gt_file)[1:5]
    # for input files, multiple ones may relate to single ground truth file
    in_files[i] = glob.glob(data_dir + '/short/' + current_prefix + train_ids[i] + '*.ARW')
    in_raws[i] = [None] * len(in_files[i])


def get_gt_file_by_train_id(tid):
    """Return the ground-truth (long exposure) file at index ``tid``."""
    return gt_files[tid]


def get_in_file_by_train_id_file_id(tid, fid):
    """Return input (short exposure) file number ``fid`` of index ``tid``."""
    return in_files[tid][fid]


def _random_patch_pair(a, b):
    """Crop the same random patch_size window from two HWC arrays.

    The window position is drawn from ``a``'s spatial size; ``b`` is assumed
    to be pixel-aligned with ``a`` and at least as large. Each crop gets a
    leading batch axis of size 1.
    """
    h, w = a.shape[0], a.shape[1]
    # randint's upper bound is exclusive: the +1 lets the window reach the
    # last row/column and avoids a ValueError when a side equals the patch size.
    y = np.random.randint(0, h - patch_size[0] + 1)
    x = np.random.randint(0, w - patch_size[1] + 1)
    window = (slice(y, y + patch_size[0]), slice(x, x + patch_size[1]))
    return np.expand_dims(a[window], axis=0), np.expand_dims(b[window], axis=0)


def get_patch_pair_raw_raw(raw_in, raw_gt):
    """Return aligned random patches of an input raw and a ground-truth raw."""
    return _random_patch_pair(raw_in, raw_gt)


def get_patch_pair_raw_rgb(raw, rgb):
    """Return aligned random patches of a raw mosaic and its ground-truth RGB."""
    return _random_patch_pair(raw, rgb)


def _shuffled_sample_ids():
    """Endlessly yield (training index, input file index) pairs.

    Each pass visits, in random order, every training index that has at
    least one input file, and picks one of its input files at random.

    Raises:
        IOError: if no ground-truth file has a matching input file.
    """
    usable = [ind for ind in range(len(train_ids)) if in_files[ind]]
    if not usable:
        # Without this check the loop below would spin forever without yielding.
        raise IOError('no training pairs found under %s' % data_dir)
    while True:
        for ind in np.random.permutation(usable):
            yield ind, np.random.randint(0, len(in_files[ind]))


def _input_raw(ind, fid):
    """Return the decoded input raw (ind, fid), taken from the cache when present.

    Freshly decoded files are cached only when USE_GPU (on the Linux server,
    memory is sufficient).
    """
    in_raw = in_raws[ind][fid]
    if in_raw is None:
        in_raw = raw_from_file(get_in_file_by_train_id_file_id(ind, fid))
        if USE_GPU:
            in_raws[ind][fid] = in_raw
    return in_raw


def get_rand_patch_from_file_raw2rgb():
    """Endlessly yield random (input raw patch, ground-truth RGB patch) pairs."""
    for ind, fid in _shuffled_sample_ids():
        gt_rgb = gt_rgbs[ind]
        if gt_rgb is None:
            # resource not found in cache, load it from disk
            gt_rgb = rgb_from_file(get_gt_file_by_train_id(ind))
            # cache them when using GPU on linux server since memory is sufficient
            if USE_GPU:
                gt_rgbs[ind] = gt_rgb
        yield get_patch_pair_raw_rgb(_input_raw(ind, fid), gt_rgb)


def get_rand_patch_from_file_raw2raw():
    """Endlessly yield random (input raw patch, ground-truth raw patch) pairs."""
    for ind, fid in _shuffled_sample_ids():
        gt_raw = gt_raws[ind]
        if gt_raw is None:
            # resource not found in cache, load it from disk
            gt_raw = raw_from_file(get_gt_file_by_train_id(ind))
            # cache them when using GPU on linux server since memory is sufficient
            if USE_GPU:
                gt_raws[ind] = gt_raw
        yield get_patch_pair_raw_raw(_input_raw(ind, fid), gt_raw)
# basic nodes
# NOTE: keep the statement order of this section. Graph construction order
# determines auto-generated op names (e.g. the unnamed placeholders below
# become 'Placeholder' and 'Placeholder_1').
# Input (short exposure) and ground-truth (long exposure) Bayer mosaics,
# NHWC with one channel.
t_bayer_in = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1], name='input')
t_bayer_gt = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 1])
# Black level subtracted and scaled so that WHITE_LEVEL maps to 1.
t_bayer_std = black_level_correction(t_bayer_in)
t_bayer_gt_std = black_level_correction(t_bayer_gt)
# Amplify the dark input by a fixed gain of 300, clip at 1, then refine it
# with the raw2raw network.
t_bayer_boosted = network_raw2raw(tf.minimum(300*t_bayer_std, 1.0))
# Half-resolution RGB views (one pixel per 2x2 Bayer cell).
t_half_rgb = bayer_to_rgb(t_bayer_std)
t_half_rgb_boosted = bayer_to_rgb(bound(t_bayer_boosted))
t_half_rgb_gt = bayer_to_rgb(t_bayer_gt_std)
# The ISP parameter generator sees the un-boosted input at a fixed size.
t_half_rgb_resized = tf.image.resize_bilinear(t_half_rgb, fixed_size)
# Ground-truth RGB, NHWC with three channels.
t_rgb_gt = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 3])
# ISP nodes
t_color_matrix, t_gamma = get_color_matrix_and_gamma(t_half_rgb_resized)
# training raw2raw alone
# Both sides are standardised (zero mean / unit std) first, so the loss is
# invariant to a global gain and offset.
# Alternative plain L1 loss:
# t_err_raw = tf.reduce_mean(tf.abs(t_half_rgb_gt - t_half_rgb_boosted))
t_err_raw = tf.reduce_mean(tf.abs(gaussian_norm(t_half_rgb_boosted) - gaussian_norm(t_half_rgb_gt)))
# training raw2rgb alone
# stop_gradient keeps this stage from updating the raw2raw network.
t_half_rgb_freeze = tf.stop_gradient(t_half_rgb_boosted)
t_rgb_freeze = build_isp_process_flow(t_half_rgb_freeze, t_color_matrix, t_gamma)
# Chromaticity loss plus a penalty pulling gamma towards 1/2.5.
# Alternative plain L1 loss:
# t_err_rgb = tf.reduce_mean(tf.abs(t_rgb_gt - t_rgb_freeze))
t_err_rgb = color_loss(t_rgb_freeze, t_rgb_gt) + tf.abs(t_gamma[0] - 1.0/2.5)
# Variant without the gamma penalty:
# t_err_rgb = color_loss(t_rgb_freeze, t_rgb_gt)
# training overall model
# Same ISP flow, but gradients also reach the raw2raw network.
t_rgb_final = build_isp_process_flow(t_half_rgb_boosted, t_color_matrix, t_gamma)
# Alternative plain L1 loss:
# t_err_overall = tf.reduce_mean(tf.abs(t_rgb_gt - t_rgb_final))
# Loss of the end-to-end stage: chromaticity loss of the final RGB output.
t_err_overall = color_loss(t_rgb_final, t_rgb_gt)


def clean_no_grad_vars(vs, gs):
    """Drop the variables whose gradient is None (e.g. cut off by stop_gradient).

    Args:
        vs: list of variables.
        gs: list of gradients, parallel to ``vs``.

    Returns:
        (variables, gradients) lists without the None-gradient entries, in
        their original order.
    """
    kept = [(v, g) for v, g in zip(vs, gs) if g is not None]
    return [v for v, _ in kept], [g for _, g in kept]


def make_var_grad_pairs(vs, gs):
    """Pair variables with their gradients as the (gradient, variable) tuples
    that Optimizer.apply_gradients expects."""
    return list(zip(gs, vs))


def _train_stage(sess, saver, logger, patches, t_minimizer, t_loss, t_summary,
                 t_target, t_gamma_to_log=None):
    """Run one training stage until max_epoch epochs have been processed.

    Args:
        sess: session to run in.
        saver: saver used to write checkpoints to model_dir.
        logger: summary FileWriter that receives ``t_summary`` each step.
        patches: iterator of (input bayer patch, target patch) pairs; the
            input feeds t_bayer_in, the target feeds ``t_target``.
        t_minimizer, t_loss, t_summary: the stage's train op, loss and
            loss summary.
        t_target: placeholder receiving the target patch.
        t_gamma_to_log: optional gamma tensor (shape [1]); when given, each
            step also prints Gamma=1/value.
    """
    fetches = [t_minimizer, t_loss, t_summary]
    if t_gamma_to_log is not None:
        fetches.append(t_gamma_to_log)
    counter = 0
    t_start = time.time()
    for patch_in, patch_target in patches:
        results = sess.run(fetches, feed_dict={t_bayer_in: patch_in, t_target: patch_target})
        loss, summary = results[1], results[2]
        logger.add_summary(summary, counter)
        epoch = counter // len(train_ids)
        message = 'Epoch# %d Counter# %d Loss= %.7f' % (epoch, counter, loss)
        if t_gamma_to_log is not None:
            message += ' Gamma=%.6f' % (1.0 / results[3][0])
        print(message)
        counter += 1
        if counter % 100 == 0:
            # time.time(): time.clock() was removed in Python 3.8.
            t_stop = time.time()
            print('Speed: %.6f' % ((t_stop - t_start) / 100))
            t_start = t_stop
        if counter > max_epoch * len(train_ids):
            saver.save(sess, model_dir + '/' + str(epoch))
            print('Training done.')
            # patches is endless, so the stage must stop explicitly.
            break
        elif counter % (len(train_ids) * save_epoch_delay) == 0:
            saver.save(sess, model_dir + '/' + str(epoch))
            print('Model saved.')


def train():
    """Run the training stage selected by ``training_stage``.

    Restores the newest checkpoint in model_dir when one exists. Per-step
    losses are written as summaries to log_dir, and a checkpoint is saved
    every save_epoch_delay epochs until max_epoch epochs are done.
    """
    print('Staged training begins...')
    t_opt = tf.train.GradientDescentOptimizer(learning_rate=learn_rate)
    with tf.Session() as sess:
        t_minimizer_raw2raw = t_opt.minimize(t_err_raw)
        t_minimizer_raw2rgb = t_opt.minimize(t_err_rgb)
        t_minimizer_overall = t_opt.minimize(t_err_overall)
        # To restore only part of the graph (e.g. just the raw2raw network),
        # build the restorer from an explicit variable list instead:
        # include = ['g_conv1', 'g_conv2', 'g_conv3', 'g_conv4', 'g_conv5', 'g_conv_last']
        # variables_to_restore = slim.get_variables_to_restore(include=include)
        # restorer = tf.train.Saver(variables_to_restore)
        restorer = tf.train.Saver(tf.global_variables())
        sess.run(tf.global_variables_initializer())

        # logger
        if not os.path.exists(log_dir):
            os.mkdir(log_dir)
        logger = tf.summary.FileWriter(log_dir, graph=sess.graph)
        try:
            t_sum_raw = tf.summary.scalar('raw2raw_loss', t_err_raw)
            t_sum_rgb = tf.summary.scalar('raw2rgb_loss', t_err_rgb)
            t_sum_all = tf.summary.scalar('overall_loss', t_err_overall)

            # latest_checkpoint resolves the path recorded in model_dir's
            # 'checkpoint' file and returns None when there is none.
            latest = tf.train.latest_checkpoint(model_dir)
            if latest is None:
                if not os.path.exists(model_dir):
                    os.mkdir(model_dir)
            else:
                print('Restoring model...')
                restorer.restore(sess, latest)
            # Save with a saver bound to the full graph rather than to the
            # (possibly partial) restore list; otherwise variables missing from
            # that list would silently never be saved.
            saver = tf.train.Saver(tf.global_variables())

            # first stage: raw to raw training
            if training_stage == __stage_raw2raw__:
                print('Stage I: train to map input raw into ground truth raw')
                _train_stage(sess, saver, logger, get_rand_patch_from_file_raw2raw(),
                             t_minimizer_raw2raw, t_err_raw, t_sum_raw, t_bayer_gt)
            # second stage: raw to rgb training
            if training_stage == __stage_raw2rgb__:
                print('Stage II: train to map generated raw into ground truth rgb')
                # To inspect or clip gradients, build the train op by hand:
                # t_vs = tf.trainable_variables()
                # t_gs = tf.gradients(t_err_rgb, t_vs)
                # t_vs, t_gs = clean_no_grad_vars(t_vs, t_gs)
                # t_minimizer_raw2rgb = t_opt.apply_gradients(make_var_grad_pairs(t_vs, t_gs))
                # and check fetched gradients with has_nan_in_tensor().
                _train_stage(sess, saver, logger, get_rand_patch_from_file_raw2rgb(),
                             t_minimizer_raw2rgb, t_err_rgb, t_sum_rgb, t_rgb_gt,
                             t_gamma_to_log=t_gamma)
            # third stage: overall training
            if training_stage == __stage_overall__:
                print('Stage III: train to map input raw into ground truth rgb')
                _train_stage(sess, saver, logger, get_rand_patch_from_file_raw2rgb(),
                             t_minimizer_overall, t_err_overall, t_sum_all, t_rgb_gt)
        finally:
            # finalization
            logger.close()


def test_half_rgb():
    """Restore the newest checkpoint and write 20 comparison images.

    Each image is the boosted half-resolution RGB next to the ground-truth
    half-resolution RGB, saved as out_dir/HALF_%04d.jpg.

    Raises:
        IOError: if model_dir or a checkpoint in it does not exist.
    """
    print('Testing Half RGB reconstruction...')
    with tf.Session() as sess:
        saver = tf.train.Saver(tf.global_variables())
        if not os.path.exists(model_dir):
            raise IOError('path not found!')
        latest = tf.train.latest_checkpoint(model_dir)
        if latest is None:
            raise IOError('no checkpoint found in %s' % model_dir)
        saver.restore(sess, latest)
        print('Model loaded.')
        if not os.path.exists(out_dir):
            os.mkdir(out_dir)

        # zip stops after 20 images without drawing another patch.
        patches = get_rand_patch_from_file_raw2raw()
        for counter, (raw_in, raw_gt) in zip(range(20), patches):
            half_rgb_boosted, half_rgb_gt = sess.run(
                [t_half_rgb_boosted, t_half_rgb_gt],
                feed_dict={
                    t_bayer_in: raw_in,
                    t_bayer_gt: raw_gt
                }
            )
            im_cmp = concat((half_rgb_boosted[0], half_rgb_gt[0]))
            # show(im_cmp, str(counter))
            save(im_cmp, (out_dir + '/HALF_%04d.jpg') % counter)


if __name__ == '__main__':
    # test_half_rgb()
    train()

1.先说tf.train.Saver()的坑,这个比较严重,其损失是不可挽回的!!!

由于经常需要迁移学习,需要执行图融合的操作,于是,需要先加载一部分子图然后创建另一部分子图,训练完后保存整个模型。

问题是:直接采用tf.train.Saver()的话,等效于saver = tf.train.Saver(tf.global_variables())

在加载子图的时候会报错:因为在子图的checkpoint文件中找不到新创建的子图中的变量,因此需要显式指定要恢复的变量列表,而不是采用tf.global_variables()。

于是将tf.global_variables()这个替换掉,方案有两种:

1.直接利用name的prefix进行变量过滤,即对tf.global_variables()得到的变量列表中的部分变量根据其v.name进行剔除,剩下的就是需要加载的变量。

2.采用tf.contrib.slim直接获取要加载的变量列表,然而这里出现了一个坑:

slim.get_variables_to_restore(include=include) 中 include 是一个作用域(scope)名称列表,内部对每一项用正则 re.match 从变量名的开头进行匹配,效果相当于前缀匹配,原理是:if v.name.startswith('VAR_NAME_PREFIX'): ADD_TO_LIST(ret)

于是当你的include list中有conv2d这个名称前缀时,所有conv2d_xxx(如conv2d_1)下的变量都会被自动添加到列表中;如果include中同时还列出了conv2d_1,这些变量就会被添加两次,而SLIM并不会对结果做去重检查!!!于是你得到的var_list中将会出现重复的变量,导致创建Saver、加载模型时报错:At least two variables have the same name: conv2d_1/bias !!!

填坑完毕!

用于加载的saver一定要指定要加载的变量列表;而用于保存的saver则必须覆盖全部变量(例如加载完成后重新创建saver = tf.train.Saver(tf.global_variables())),否则会不知不觉地导致辛辛苦苦训练好的变量(参数)最终没有被保存,在训练结束时随内存永远地消亡了~~~~~

最新文章

  1. Python之路Day17-jQuery
  2. 如何启动app时全屏显示Default.png(图片)?
  3. 套题 codeforces 361
  4. 【C语言】C语言常量和变量
  5. 。【自学总结 2】------3ds Max 菜单
  6. hdu4081 次小生成树
  7. 第四周 课堂Scrum站立会议
  8. 【socket】高级用法-异步
  9. 【转贴】gdb中的信号(signal)相关调试技巧
  10. js函数、表单验证
  11. 复习hiernate
  12. Centos7下安装php7
  13. 微信小程序与AspNetCore SignalR聊天实例
  14. Echarts扩展地图文字位置错乱的问题
  15. HDU 1046(最短路径 **)
  16. Vmware ESXi 的虚拟机的开机自启动
  17. KMeans算法分析以及实现
  18. 【Android端 adb相关】adb相关总结
  19. idea创建第一个maven web项目
  20. spring cloud 学习(4) - hystrix 服务熔断处理

热门文章

  1. ok6410 nandflash 启动uboot 超过256k怎么办
  2. Appium(一)---环境搭建的一些问题
  3. 刷Python核心编程第三版的习题时遇到一个findall的坑
  4. ceph添加osd(ceph-deploy)
  5. 《视觉SLAM十四讲课后作业》第二讲
  6. netperf
  7. 旧版本firefox添加扩展addons的地址
  8. 用python计算圆周率
  9. 手写JavaScript常用的函数
  10. 类中为什么要定义__init__()方法