Purpose of This Article

When introducing the distributed Estimator, the official documentation has fallen out of sync with the API because of version changes. Specifically: when a dataset is used as the data input for distributed Estimator training, in version 1.12 training consumes the dataset exactly once, i.e. all devices together make a single pass over the data.

In version 2.0, however, the amount of training data consumed is the dataset multiplied by the number of devices. In other words, every device makes a complete pass over the entire dataset.
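
For reference, the kind of input pipeline being discussed is an input_fn that returns a tf.data.Dataset. The following is a minimal, hypothetical example (names and shapes are illustrative only, not taken from the official docs):

import numpy as np
import tensorflow as tf

def train_input_fn():
  # Toy in-memory data, purely for illustration.
  features = np.random.rand(1000, 10).astype(np.float32)
  labels = np.random.randint(0, 2, size=1000).astype(np.int64)
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  # One pass over 1000 examples: under 1.12 the devices collectively consume
  # these 1000 examples once; under 2.0 (per the behavior described above)
  # each device consumes all 1000.
  return dataset.shuffle(1000).batch(32)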

How Version 1.12 Reads Data

1. Creating the graph in the main thread

In the code below, the client calls the input function and obtains an iterator. This code belongs to Estimator's distributed train path.

with ops.Graph().as_default() as g:
  # We want to create the iterations variable outside the distribution scope
  # as that is just stored on the host and mainly used to drive the loop
  # and doesn't need to be a Mirrored/Device variable.
  if is_tpu_strategy:
    steps_per_run_variable = training.get_or_create_steps_per_run_variable()
  with self._train_distribution.scope():
    random_seed.set_random_seed(self._config.tf_random_seed)
    iterator, input_hooks = self._get_iterator_from_input_fn(
        input_fn, model_fn_lib.ModeKeys.TRAIN, self._train_distribution)
  • _get_iterator_from_input_fn: this function creates the iterator from which the subsequent training steps read data.
def _get_iterator_from_input_fn(self, input_fn, mode, distribution=None):
  if distribution is not None:
    result = distribution.distribute_dataset(
        lambda: self._call_input_fn(input_fn, mode))
  else:
    result = self._call_input_fn(input_fn, mode)
  iterator = result.make_initializable_iterator()
  input_hooks = [estimator_util._DatasetInitializerHook(iterator)]  # pylint: disable=protected-access
  return iterator, input_hooks
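
The _DatasetInitializerHook returned alongside the iterator is what actually runs the iterator's initializer once the session is created. Roughly, it behaves like the following paraphrased sketch (not the exact estimator source):

from tensorflow.python.training import session_run_hook

class _DatasetInitializerHook(session_run_hook.SessionRunHook):
  """Initializes the dataset iterator right after the session is created."""

  def __init__(self, iterator):
    self._iterator = iterator

  def begin(self):
    # Capture the initializer op while the graph is still being built.
    self._initializer = self._iterator.initializer

  def after_create_session(self, session, coord):
    del coord
    session.run(self._initializer)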

Here distribute_dataset is called to produce the distributed dataset.

Stepping into that call, you can see that it creates a PerDeviceDataset:

class PerDeviceDataset(object):
  """Like `tf.data.Dataset` split devices, producing `PerDevice` data."""

  def __init__(self, dataset, devices, prefetch_on_device=None):
    self._devices = devices
    # Default to using prefetching in graph mode, unless specified.
    # TODO(priyag): Enable prefetching in eager mode.
    self._prefetch_on_device = prefetch_on_device
    if self._prefetch_on_device is None:
      self._prefetch_on_device = not context.executing_eagerly()
    assert not (self._prefetch_on_device and context.executing_eagerly()), (
        "Prefetching is only supported in graph mode currently")

    if self._prefetch_on_device:
      self._dataset = dataset.apply(
          prefetching_ops_v2.prefetch_to_devices(self._devices))
    else:
      # TODO(priyag): If dropping remainder is not appropriate, find another
      # approach to distributing the dataset when not possible to divide evenly.
      # Possibly not an issue when we start using PartitionedDataset.
      self._dataset = dataset.batch(len(devices), drop_remainder=True)

As the last line shows, another layer of batch() is wrapped around the original dataset, splitting the data according to the number of devices.

The iterator created afterwards is likewise wrapped as a PerDeviceDataIterator, which forms a dictionary mapping from device to data: each device gets different data, split by its index within the outer batch.
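
To make the splitting concrete, here is a small standalone illustration (written with eager execution for brevity, not the 1.12 source) of how the extra batch(len(devices), drop_remainder=True) groups per-step batches so that index i of each outer batch goes to device i:

import tensorflow as tf

devices = ["/device:GPU:0", "/device:GPU:1"]  # hypothetical device list

# Each element of `dataset` stands for one per-step batch from the input_fn.
dataset = tf.data.Dataset.range(8).batch(2)   # [0,1], [2,3], [4,5], [6,7]
per_device = dataset.batch(len(devices), drop_remainder=True)

for outer in per_device:
  for i, d in enumerate(devices):
    # outer[i] is the slice a PerDeviceDataIterator would hand to devices[i].
    print(d, outer[i].numpy())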

Distributed Training

Training in version 1.12 is fairly simple. For MirroredStrategy, one thread is created per device. A drawback is that the threads are re-created on every run() call; judging from the TODO in the source, this should be optimized away later.

Below is the client-side code that fetches data from the iterator and hands it to each device for computation, via self._train_distribution.call_for_each_tower:

features, labels = estimator_util.parse_iterator_result(
    iterator.get_next())
grouped_estimator_spec = self._train_distribution.call_for_each_tower(
    self._call_model_fn,
    features,
    labels,  # although this will be None it seems
    model_fn_lib.ModeKeys.TRAIN,
    self.config)
loss = self._train_distribution.unwrap(
    self._train_distribution.reduce(
        distribute_lib.get_loss_reduction(),
        grouped_estimator_spec.loss,
        destinations='/device:CPU:0'))[0]
distributed_train_op = grouped_estimator_spec.train_op
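
Conceptually (a plain-Python analogy, not the actual TF API), grouped_estimator_spec.loss behaves like a per-device mapping; reduce() collapses it into a single value placed on the destination device, and unwrap()[0] exposes the underlying tensor:

# Hypothetical per-tower losses keyed by device, for illustration only.
per_tower_loss = {"/device:GPU:0": 0.52, "/device:GPU:1": 0.48}

def reduce_mean(per_device_values):
  # Stands in for reduce(get_loss_reduction(), ..., destinations='/device:CPU:0')
  # when the loss reduction is MEAN.
  return sum(per_device_values.values()) / len(per_device_values)

loss = [reduce_mean(per_tower_loss)][0]  # unwrap(...) returns a list; [0] picks the value
print(loss)  # 0.5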

call_for_each_tower is the interface through which each device runs its training step:

def _call_for_each_tower(distribution, fn, *args, **kwargs):
  """Run `fn` in separate threads, once per tower/worker device."""
  run_concurrently = kwargs.pop("run_concurrently", True)
  if not context.executing_eagerly():
    # Lots of TF library code isn't thread-safe in graph mode, and
    # there is little to be gained by turning on multithreading when
    # constructing a graph.
    run_concurrently = False
    # Needed for per-thread device, etc. contexts in graph mode.
    ops.get_default_graph().switch_to_thread_local()
  elif run_concurrently is None:
    run_concurrently = True

  coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))

  shared_variable_store = {}

  # TODO(isaprykin): Create these threads once instead of during every run()
  # call.
  threads = []
  for index, d in enumerate(distribution.worker_devices):
    variable_creator_fn = shared_variable_creator.make_fn(
        shared_variable_store, index)
    t = MirroredStrategy._MirroredTowerThread(  # pylint: disable=protected-access
        distribution, coord, d, variable_creator_fn, fn,
        *values.select_device(d, args), **values.select_device(d, kwargs))
    threads.append(t)

  for t in threads:
    t.start()

Here, select_device simply picks, from the per-device arguments, the value keyed by the given device. With that, the whole distributed training step is in place.
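
For intuition, select_device behaves roughly like the following simplified, self-contained paraphrase (not the TF source); per-device arguments are modeled here as a plain dict keyed by device string:

def select_device(device, structured):
  # Walk the argument list; for any per-device mapping keep only the entry
  # for `device`, and pass ordinary values through unchanged.
  def _get(x):
    if isinstance(x, dict) and device in x:  # stands in for PerDevice values
      return x[device]
    return x
  return [_get(x) for x in structured]

args = ({"/device:GPU:0": "batch_0", "/device:GPU:1": "batch_1"}, 0.001)
print(select_device("/device:GPU:1", args))  # ['batch_1', 0.001]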
