[源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader

0x00 摘要

为了更好的介绍参数服务器Paracel的数据加载,我们临时插入两篇PyTorch的数据加载,主要是从分布式的角度进行切入。本文只算是开胃甜点,后续会有专门系列分析PyTorch分布式。

参数服务器系列其他文章如下:

[源码解析] 机器学习参数服务器ps-lite 之(1) ----- PostOffice

[源码解析] 机器学习参数服务器ps-lite(2) ----- 通信模块Van

[源码解析] 机器学习参数服务器ps-lite 之(3) ----- 代理人Customer

[源码解析]机器学习参数服务器ps-lite(4) ----- 应用节点实现

[源码解析] 机器学习参数服务器 Paracel (1)-----总体架构

[源码解析] 机器学习参数服务器 Paracel (2)--------SSP控制协议实现

[源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler

0x01 前情回顾

关于数据加载,上回书我们说到了 DistributedSampler,本文接下来就进行 DataLoader的分析。

为了更好说明,我们首先给出上文的流水线图,本文会对这个图进行细化。

                    +------------+
+--------+ | |
| | | Process 1 |
| Data 1 +--------> | +------+
| | | Load Data | |
+--------+ | | |
+------------+ |
|
|
|
+------------+ | +-----------------------------------+
+--------+ | | | | |
| | | Process 2 | +------> | Pin-memory process |
| Data 2 +--------> | | | |
| | | Load Data +-------------> | |
+--------+ | | | Transfer to Pinned Memory |
+------------+ +-----> | |
| | |
| +-----------------------------------+
|
+--------+ +------------+ |
| | | | |
| Data 3 +--------> | Process 3 +-------+
| | | |
+--------+ | Load Data |
| |
+------------+

其次,我们再看看数据加载总体逻辑,具体如下图,简要说就是:

  1. DataSet 把数据集数目发给DistributedSampler。
  2. Sampler 按照某种规则生成数据indices并发送给DataLoader。
  3. DataLoader 依据indices来从DataSet之中加载数据(其内部的DataLoaderIter对象负责协调单进程/多进程加载Dataset)。
  4. DataLoader 把数据发给模型,进行训练。
+------------------------+                     +-----------+
|DistributedSampler | |DataLoader |
| | 2 indices | |
| Some strategy +-------------------> | |
| | | |
|-------------+----------| | |
^ | | 4 data +-------+
| | -------------->+ train |
1 | length | | +-------+
| | |
+-------------+----------+ | |
|DataSet | | |
| +---------+ | 3 Load | |
| | Data +-------------------------> | |
| +---------+ | | |
| | | |
+------------------------+ +-----------+

接下来,我们就正式进入 DataLoader。

0x02 DataLoader

DataLoader的作用是:结合Dataset和Sampler之后,在数据集上提供了一个迭代器

可以这么理解:

DataSet 是原始数据,Sampler 提供了如何切分数据的策略(或者说是提供了切分数据的维度),DataLoader就是依据策略来具体打工干活的,其中单进程加载就是一个人干活,多进程加载就是多拉几个人一起干活

2.1 初始化

初始化的主要参数如下:

  • dataset (Dataset) :所加载的数据集。
  • batch_size (int, optional) :每个批次加载多少个样本。
  • shuffle (bool, optional) :如果为 True,则每个epoch 都会再打乱数据。
  • sampler (Sampler or Iterable, optional) :定义了如何从样本采样的策略。可以是任何实现了 __len__的迭代器。
  • batch_sampler (Sampler or Iterable, optional) :与sampler类似,但是每次返回一个批次的数据索引。
  • num_workers (int, optional) :数据加载的子进程数目。如果是 0,表示从主进程加载数据。
  • collate_fn (callable, optional):从一个小批次( mini-batch)张量中合并出一个样本列表。当从 map-style 数据集做批量加载时候使用。
  • pin_memory (bool, optional) : 如果为true,则在返回张量之前把张量拷贝到CUDA固定内存之中。
  • drop_last (bool, optional) :当数据集不能被均匀分割时,如果为true,丢掉最后一个不完整的批次。如果为False,那么最后一个批次的数据较小。
  • timeout (numeric, optional): 如果是整数,则是worker收集批次数据的超时值。
  • worker_init_fn (callable, optional):如果非空,则会在seeding和数据加载之前被每个子进程调用,以Iworker id ([0, num_workers - 1])作为输入参数。
  • generator (torch.Generator, optional):如果非空,则被RandomSampler 用来产生随机索引,也被多进程用来产生 base_seed
  • prefetch_factor (int, optional, keyword-only arg):每个 worker 提前加载 的 sample 数量。
  • persistent_workers (bool, optional):如果为 True, 则在消费一次之后,data loader也 不会关掉worker进程。这允许workerDataset实例维持活动状态。

具体初始化代码如下,主要就是各种设置,为了更好的说明,去除了异常处理代码:

class DataLoader(Generic[T_co]):

    dataset: Dataset[T_co]
batch_size: Optional[int]
num_workers: int
pin_memory: bool
drop_last: bool
timeout: float
sampler: Sampler
prefetch_factor: int
_iterator : Optional['_BaseDataLoaderIter']
__initialized = False def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
multiprocessing_context=None, generator=None,
*, prefetch_factor: int = 2,
persistent_workers: bool = False):
torch._C._log_api_usage_once("python.data_loader") self.dataset = dataset
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
self.pin_memory = pin_memory
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.multiprocessing_context = multiprocessing_context if isinstance(dataset, IterableDataset):
self._dataset_kind = _DatasetKind.Iterable
# 省略异常处理
else:
self._dataset_kind = _DatasetKind.Map if batch_sampler is not None:
# auto_collation with custom batch_sampler
# 省略异常处理
batch_size = None
drop_last = False
elif batch_size is None:
# no auto_collation
if drop_last:
raise ValueError('batch_size=None option disables auto-batching '
'and is mutually exclusive with drop_last') if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset, generator=generator)
else:
sampler = SequentialSampler(dataset) if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last) self.batch_size = batch_size
self.drop_last = drop_last
self.sampler = sampler
self.batch_sampler = batch_sampler
self.generator = generator if collate_fn is None:
if self._auto_collation:
collate_fn = _utils.collate.default_collate
else:
collate_fn = _utils.collate.default_convert self.collate_fn = collate_fn
self.persistent_workers = persistent_workers
self.__initialized = True
self._IterableDataset_len_called = None
self._iterator = None
self.check_worker_number_rationality()

2.2 关键函数

这里关键函数之一就是_index_sampler,用来让迭代器调用sampler,我们接下来就会讲到

    @property
def _index_sampler(self):
# The actual sampler used for generating indices for `_DatasetFetcher`
# (see _utils/fetch.py) to read data at each time. This would be
# `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
# We can't change `.sampler` and `.batch_sampler` attributes for BC
# reasons.
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler

2.3 单进程加载

单进程模式下,Data Loader会在计算进程内加载数据,所以加载过程中可能会阻塞计算。

for 语句会调用enumerate 会返回一个迭代器,以此来遍历数据集。在eumerate之中,dataloader 的 __next__(self) 方法会被调用,逐一获取下一个对象,从而遍历数据集。

    cuda0 = torch.device('cuda:0')  # CUDA GPU 0
for i, x in enumerate(train_loader):
x = x.to(cuda0)

2.3.1 区分生成

当多进程加载时候,在DataLoader声明周期之中,迭代器只被建立一次,这样worker可以重用迭代器。

在单进程加载时候,应该每次生成,以避免重置状态。

    def __iter__(self) -> '_BaseDataLoaderIter':
if self.persistent_workers and self.num_workers > 0: # 如果是多进程或者设置了持久化
if self._iterator is None: # 如果没有,才会新生成
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else: # 单进程
return self._get_iterator() # 每次都直接生成新的

具体会依据是否是多进程来区别生成。

    def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)

2.3.2 迭代器基类

_BaseDataLoaderIter 是迭代器基类,我们挑选关键函数看看。

这里关键成员变量就是:

  • _index_sampler:这里设置了loader 的 sampler,所以迭代器可以据此获取采样策略。
  • _sampler_iter:得到 sampler 的迭代器。
class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
# 初始化参数
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._IterableDataset_len_called = loader._IterableDataset_len_called
self._auto_collation = loader._auto_collation
self._drop_last = loader.drop_last
self._index_sampler = loader._index_sampler # 得到采样策略
self._num_workers = loader.num_workers
self._prefetch_factor = loader.prefetch_factor
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler) # 得到sampler的迭代器
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
self._persistent_workers = loader.persistent_workers
self._num_yielded = 0
self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__) def __next__(self) -> Any:
with torch.autograd.profiler.record_function(self._profile_name):
if self._sampler_iter is None:
self._reset()
data = self._next_data() # 获取数据
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
self._IterableDataset_len_called is not None and \
self._num_yielded > self._IterableDataset_len_called:
# 忽略错误提示处理
warnings.warn(warn_msg)
return data

2.3.3 单进程迭代器

_SingleProcessDataLoaderIter 继承了 _BaseDataLoaderIter,可以看到,其增加了 _dataset_fetcher,在构造时候传入了 _collate_fn 等各种参数。

回忆下,__next__会调用 self._next_data() 获取数据,而在这里,_next_data 就会:

  • 使用 self._next_index(),其又会使用 _sampler_iter(采样器的迭代器)来获取indices 。
  • 使用 self._dataset_fetcher.fetch(index)来依据indices获取数据。
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0 # 获取样本方法
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last) def _next_data(self):
index = self._next_index() # may raise StopIteration
# 获取样本
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data def _next_index(self): # 得到indices
return next(self._sampler_iter) # may raise StopIteration

2.3.4 获取样本

我们接下来看看如何获取样本。就是通过索引传入 fetcher,从而获取想要的样本。

fetcher生成如下,这是在_SingleProcessDataLoaderIter初始化时候生成的:

class _DatasetKind(object):
Map = 0
Iterable = 1 @staticmethod
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
if kind == _DatasetKind.Map:
return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
else:
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)

对于Map-style,就使用 _MapDatasetFetcher 处理,就是使用 possibly_batched_index 从数据集之中提取数据,possibly_batched_index 是key。

如果有batch sampler,就使用 batch sampler。

如果需要从一个小批次( mini-batch)张量中合并出一个样本列表。就使用 collate_fn后处理。

class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last) def fetch(self, possibly_batched_index):
if self.auto_collation:
# 如果配置了batch_sampler,_auto_collation就为True,
# 那么就优先使用batch_sampler,此时fetcher中传入的就是一个batch的索引
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)

对于 Iterable-style,因为 __init__ 方法内设置了 dataset 初始的迭代器,所以在fetch 方法内获取元素的时候,如果是常规 sampler,index 其实已经不起作用,直接从dataset迭代器获取。如果是batch sampler,则index有效果。

class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
self.dataset_iter = iter(dataset) def fetch(self, possibly_batched_index):
if self.auto_collation:
# 即auto_collation为True,表示使用batch_sampler。
# 则使用possibly_batched_index,获取1个batch大小的样本
data = []
for _ in possibly_batched_index:
try:
data.append(next(self.dataset_iter))
except StopIteration:
break
if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
raise StopIteration
else:
# sampler则直接往后遍历,提取1个样本
data = next(self.dataset_iter)
return self.collate_fn(data)

此时总逻辑如下:

     +--------------------------+            +-------------------------------+
| DataLoader | | _SingleProcessDataLoaderIter |
| | | |
| | | __next__ |
+---------------+ Sampler | | |
| | | | _next_data +-----------+
| | Dataset | | | |
| | | | _next_index | |
| | __iter__ | | | |
| | | | _index_sampler | |
| | _get_iterator +--------------> | + | |
| | | | | | |
| +--------------------------+ +-------------------------------+ |
| | |
| | |
| | |
| | |
| | |
| +----------------------------+ | |
| |Sampler | | |
+------------------------> | | <------+ |
| | |
| | |
| | |
+----------------------------+ |
|
|
+----------------------------+ |
|_BaseDatasetFetcher | |
| | |
| | |
| dataset | |
| | <----------------------+
| collate_fn |
| |
+----------------------------+

动态流程如下:

  User              DataLoader    _SingleProcessDataLoaderIter _DatasetKind   Sampler

    +                   +                    +                        +           +
| | | | |
| 1 | | | |
enumerate--------> __iter__ | | |
| + | | |
| | | | |
| | | | |
| | 2 v 3 v |
| _get_iterator--------> __init__ +----------> create_fetcher |
| 4 | + + |
| <-----------------+ | | |
| iterator | | | |
| | 5 | | |
for loop +------------------------------> __next__ | |
| | | | |
| | | | |
| | | | |
| | _next_data | |
| | | | |
| | | | |
| | | 6 next | |
| | _next_index +-------------------------> |
| | | | |
| | | <---------------------------------+
| | | 7 index | |
| | | | |
| | | | |
| | | 8 fetch(index) | |
| | | +--------------------> | |
| | | | |
| | | <---------------------+ |
| | | 9 data | |
| <-------------------------------------+ | |
| 10 data | | | |
| | | | |
v v v v v

2.4 多进程加载

为了加速,PyTorch提供了多进程下载,只要把将参数 num_workers 设置为正整数,系统就会相应生成多进程处理,在这种模式下,每个worker都是一个独立进程。

由上节我们可以知道,_SingleProcessDataLoaderIter 是单进程加载数据的核心,loader通过它来与sampler,dataset交互。在多进程中,这个核心对应的就是 _MultiProcessingDataLoaderIter。

    def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)

我们接下来就从 _MultiProcessingDataLoaderIter 开始分析。

2.4.1 总体逻辑

_MultiProcessingDataLoaderIter 中的注释十分详尽,值得大家深读,而且给出了逻辑流程图如下,其基本流程是围绕着三个queue进行的:

  • 主进程把需要获取的数据 index 放入index_queue,这是指定子进程需要获取哪些数据的队列。同时也给子进程传入结果队列,关于结果队列,有两个分支:

    • 如果设置了pin memory,则传入的是 worker_result_queue。
    • 否则传入 data_queue。
  • 子进程从 index_queue 之中读取 index,进行数据读取,然后把读取数据的index放入worker_result_queue,这是向主进程返回结果的队列
  • 主进程进行处理,这里有两个分支:
    • 如果设置了pin memory,则主进程的 pin_memory_thread 会从 worker_result_queue 读取数据index,依据这个index进行读取数据,进行处理,把结果放入 data_queue,这是处理结果的队列
    • 如果不需要pin memory,则结果已经存在 data_queue 之中,不做新操作。

可以看到,每个进程的输入是一个队列index_queue ,输出也是一个队列worker_result_queue。主进程和子进程通过这2~3个 queue 联系了起来,从而达到解耦合和加速的作用

    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
#
# Preliminary:
#
# Our data model looks like this (queues are indicated with curly brackets):
#
# main process ||
# | ||
# {index_queue} ||
# | ||
# worker processes || DATA
# | ||
# {worker_result_queue} || FLOW
# | ||
# pin_memory_thread of main process || DIRECTION
# | ||
# {data_queue} ||
# | ||
# data output \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
# `pin_memory=False`.

具体如下图所示,如果不需要 pin memory,则为:

                                               +-----------+
indices -------------+ indices | Worker | Data
+--------->+index queue +-------->+ Process +------+
| | | | | |
| -------------+ +-----------+ |
| | +------------+
| | | |
+---------+ | +---> |
| Main | | indices -------------+ indices +-----------+ | |
| Process +------------>+index queue +-------->+ Worker | Data | Data Queue |
| | | | | | Process +----------> |
+---------+ | -------------+ | | | |
| +-----------+ +---> |
| | +------------+
| |
| indices -------------+ indices +-----------+ |
+--------->+index queue +-------->+ Worker | Data |
| | | Process +------+
-------------+ | |
+-----------+

当有pin memory时候,则是先进入 result queue,然后 pin_memory_thread 处理之后会转入到 data queue:

                                               +-----------+
indices -------------+ indices | Worker | Data
+--------->+index queue +-------->+ Process +------+
| | | | | |
| -------------+ +-----------+ |
| | --------------+
| | | |
+---------+ | +---> |
| Main | | indices -------------+ indices +-----------+ | |
| Process +------------>+index queue +-------->+ Worker | Data | result_queue|
| | | | | | Process +----------> |
+---------+ | -------------+ | | | |
| +-----------+ +---> |
| | ---------+----+
| | |
| indices -------------+ indices +-----------+ | |
+--------->+index queue +-------->+ Worker | Data | +---------+--------+
| | | Process +------+ | pin_memory_thread|
-------------+ | | | | |
+-----------+ | | |
| | |
+------------------+
|
|
|
v
+-----+------+
| Data Queue |
| |
+------------+

2.4.2 初始化

初始化函数如下,主要是:

  • 配置,生成各种成员变量,配置各种queue。
  • 启动各个子进程。
  • 启动主进程中的pin_memory的线程。

主要成员变量为:

  • _index_queues: 这是一个queue 列表,列表的每一个元素是一个 queue,就是每个子进程的队列需要处理的数据index,每个子进程对应一个 queue。
  • _worker_result_queue: 子进程处理完的 (idx, data)。
  • data_queue: 经过主进程 pin_memory 线程处理之后的数据队列,如果不需要pin,则直接会使用 _worker_result_queue
  • _worker_queue_idx_cycle 用以找出下一个工作的worker。

具体代码如下:

class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
r"""Iterates once over the DataLoader's dataset, as specified by the sampler""" def __init__(self, loader):
super(_MultiProcessingDataLoaderIter, self).__init__(loader) assert self._num_workers > 0
assert self._prefetch_factor > 0 if loader.multiprocessing_context is None:
multiprocessing_context = multiprocessing
else:
multiprocessing_context = loader.multiprocessing_context self._worker_init_fn = loader.worker_init_fn
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
# No certainty which module multiprocessing_context is
self._worker_result_queue = multiprocessing_context.Queue() # 子进程输出,读取完数据的index
self._worker_pids_set = False
self._shutdown = False
self._workers_done_event = multiprocessing_context.Event() self._index_queues = [] # 子进程输入,需读取数据的index
self._workers = []
for i in range(self._num_workers):
# No certainty which module multiprocessing_context is
index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
# Need to `cancel_join_thread` here!
# See sections (2) and (3b) above.
index_queue.cancel_join_thread()
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop, # worker进程主函数,把各种queue和函数传进去
args=(self._dataset_kind, self._dataset, index_queue,
self._worker_result_queue, self._workers_done_event,
self._auto_collation, self._collate_fn, self._drop_last,
self._base_seed, self._worker_init_fn, i, self._num_workers,
self._persistent_workers))
w.daemon = True
w.start()
self._index_queues.append(index_queue) # 把这个worker对应的index_queue放到主进程这里存起来,以后就可以交互了
self._workers.append(w) if self._pin_memory:
self._pin_memory_thread_done_event = threading.Event() # Queue is not type-annotated
self._data_queue = queue.Queue() # pin 处理之后的数据结果
pin_memory_thread = threading.Thread(
target=_utils.pin_memory._pin_memory_loop,
args=(self._worker_result_queue, self._data_queue,
torch.cuda.current_device(),
self._pin_memory_thread_done_event))
pin_memory_thread.daemon = True
pin_memory_thread.start()
# Similar to workers (see comment above), we only register
# pin_memory_thread once it is started.
self._pin_memory_thread = pin_memory_thread
else:
self._data_queue = self._worker_result_queue # 如果不需要pin,则直接使用_worker_result_queue # .pid can be None only before process is spawned (not the case, so ignore)
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
_utils.signal_handling._set_SIGCHLD_handler()
self._worker_pids_set = True self._reset(loader, first_iter=True) # 继续完善业务

2.4.3 业务重置

__init__ 函数最后会调用 _reset 函数,这是进一步完善业务初始化,也用来重置环境。

上小节函数中,已经启动了worker子进程,但是没有分配任务,所以_reset函数会进行任务分配,预取。

_MultiProcessingDataLoaderIter有如下 flag 参数来协调各个 worker (包括各种queue)之间的工作:

  • _send_idx: 发送索引,用来记录这次要放 index_queue 中 batch 的 idx

  • _rcvd_idx: 接受索引,记录要从 data_queue 中取出的 batch 的 idx

  • _task_info: 存储将要产生的 data 信息的 dict,key为 task idx(由 0 开始的整型索引),value 为 (worker_id,)(worker_id, data),分别对应数据 未取 和 已取 的情况

  • _tasks_outstanding: 整型,代表已经准备好的 task/batch 的数量(可能有些正在准备中)

  • _send_idx: 发送索引,记录下一次要放 index_queue 中 task batch 的 idx。

  • _rcvd_idx: 接受索引,记录下一次要从 data_queue 中取出的 task batch 的 idx。_send_idx_rcvd_idx 主要用来进行流量控制和确保接受索引有意义。

  • _task_info: 存储将要产生的 data 信息的 dict,key为 task batch idx(由 0 开始的整型索引),value 为 (worker_id,)(worker_id, data),分别对应数据 未取 和 已取 的情况。_task_info的作用是依据 task batch idx 获取对应的 worker id 和暂存乱序数据。

  • _tasks_outstanding: 整型,正在准备的 task/batch 的数量,实际上就是进行一些确认工作,没有太实际的意义。

对于加载数据,每个 worker 一次产生一个 batch 的数据,返回 batch 数据前,会放入下一个批次要处理的数据下标,所以 reset 函数会把 _send_idx_rcvd_idx 都恢复成0,这样下次迭代就可以重新处理。

在 reset 方法最后,有一个预取数据操作。我们会在后面结合乱序处理进行讲解

    def _reset(self, loader, first_iter=False):
super()._reset(loader, first_iter)
self._send_idx = 0 # idx of the next task to be sent to workers
self._rcvd_idx = 0 # idx of the next task to be returned in __next__
# information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
# map: task idx => - (worker_id,) if data isn't fetched (outstanding)
# \ (worker_id, data) if data is already fetched (out-of-order)
self._task_info = {}
self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
# A list of booleans representing whether each worker still has work to
# do, i.e., not having exhausted its iterable dataset object. It always
# contains all `True`s if not using an iterable-style dataset
# (i.e., if kind != Iterable).
# Not that this indicates that a worker still has work to do *for this epoch*.
# It does not mean that a worker is dead. In case of `_persistent_workers`,
# the worker will be reset to available in the next epoch.
# 每个worker的状态
self._workers_status = [True for i in range(self._num_workers)]
# We resume the prefetching in case it was enabled
if not first_iter:
for idx in range(self._num_workers):
self._index_queues[idx].put(_utils.worker._ResumeIteration())
resume_iteration_cnt = self._num_workers
while resume_iteration_cnt > 0:
return_idx, return_data = self._get_data()
if isinstance(return_idx, _utils.worker._ResumeIteration):
assert return_data is None
resume_iteration_cnt -= 1
# prime the prefetch loop # 预取若干index,目的是为了配合后续的乱序处理。
for _ in range(self._prefetch_factor * self._num_workers):
self._try_put_index()

2.4.4 获取 index

_try_put_index 函数就是使用sampler获取下一批次的数据index。这里 _prefetch_factor 缺省值是 2,主要逻辑如下。

  • 从sampler获取下一批次的index。
  • 通过 _worker_queue_idx_cycle 找出下一个可用的工作worker,然后把index分给它。
  • 并且调整主进程的信息。
    def _next_index(self): # 定义在基类 _BaseDataLoaderIter 之中,就是获取下一批index
return next(self._sampler_iter) # may raise StopIteration def _try_put_index(self): assert self._tasks_outstanding < self._prefetch_factor * self._num_workers try:
index = self._next_index() # 获取下一批index
except StopIteration:
return
for _ in range(self._num_workers): # find the next active worker, if any
worker_queue_idx = next(self._worker_queue_idx_cycle)
if self._workers_status[worker_queue_idx]: # 如果已经工作,就继续找
break
else:
# not found (i.e., didn't break)
return # 以下是主进程进行相关记录
# 给下一个工作worker放入 (任务index, 数据index), 就是给queue放入数据,所以worker loop之中就立刻会从queue中得到index,从而开始获取数据。
self._index_queues[worker_queue_idx].put((self._send_idx, index))
# 记录 将要产生的 data 信息
self._task_info[self._send_idx] = (worker_queue_idx,)
# 正在处理的batch个数+1
self._tasks_outstanding += 1
# send_idx 记录从sample_iter中发送索引到index_queue的次数
self._send_idx += 1 # 递增下一批发送的task index

2.4.5 worker主函数

_worker_loop 是 worker进程的主函数,主要逻辑如其注释所示:

    # [ worker processes ]
# While loader process is alive:
# Get from `index_queue`.
# If get anything else,
# Check `workers_done_event`.
# If set, continue to next iteration
# i.e., keep getting until see the `None`, then exit.
# Otherwise, process data:
# If is fetching from an `IterableDataset` and the iterator
# is exhausted, send an `_IterableDatasetStopIteration`
# object to signal iteration end. The main process, upon
# receiving such an object, will send `None` to this
# worker and not use the corresponding `index_queue`
# anymore.
# If timed out,
# No matter `workers_done_event` is set (still need to see `None`)
# or not, must continue to next iteration.
# (outside loop)
# If `workers_done_event` is set, (this can be False with `IterableDataset`)
# `data_queue.cancel_join_thread()`. (Everything is ending here:
# main process won't read from it;
# other workers will also call
# `cancel_join_thread`.)

就是通过index_queue, data_queue与主进程交互。

  • 从 index_queue 获取新的数据index;
  • 如果没有设置本worker结束,就使用 fetcher获取数据
  • 然后把数据放入data_queue,并且通知主进程,这里需要注意,data_queue是传入的参数,如果设置了pin memory,则传入的是 worker_result_queue, 否则传入 data_queue
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
num_workers, persistent_workers):
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function. try:
# Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
# module's handlers are executed after Python returns from C low-level
# handlers, likely when the same fatal signal had already happened
# again.
# https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
signal_handling._set_worker_signal_handlers() torch.set_num_threads(1)
seed = base_seed + worker_id
random.seed(seed)
torch.manual_seed(seed)
if HAS_NUMPY:
np_seed = _generate_state(base_seed, worker_id)
import numpy as np
np.random.seed(np_seed) global _worker_info
_worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
seed=seed, dataset=dataset) from torch.utils.data import _DatasetKind init_exception = None try:
if init_fn is not None:
init_fn(worker_id) fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
except Exception:
init_exception = ExceptionWrapper(
where="in DataLoader worker process {}".format(worker_id)) iteration_end = False
watchdog = ManagerWatchdog() while watchdog.is_alive(): # 等待在这里
try:
# _try_put_index 如果放入了数据index,这里就被激活,开始工作
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
if isinstance(r, _ResumeIteration):
# Acknowledge the main process
data_queue.put((r, None))
iteration_end = False
# Recreate the fetcher for worker-reuse policy
fetcher = _DatasetKind.create_fetcher(
dataset_kind, dataset, auto_collation, collate_fn, drop_last)
continue
elif r is None:
# Received the final signal
assert done_event.is_set() or iteration_end
break
elif done_event.is_set() or iteration_end:
# `done_event` is set. But I haven't received the final signal
# (None) yet. I will keep continuing until get it, and skip the
# processing steps.
continue
idx, index = r
data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
if init_exception is not None:
data = init_exception
init_exception = None
else:
try:
data = fetcher.fetch(index)
except Exception as e:
# 省略处理代码 data_queue.put((idx, data)) # 放入数据,通知主进程
del data, idx, index, r # save memory
except KeyboardInterrupt:
# Main process will raise KeyboardInterrupt anyways.
pass
if done_event.is_set():
data_queue.cancel_join_thread()
data_queue.close()

2.4.6 Pin memory thread

在主进程之中,如果设置了需要pin memory,主进程的 pin_memory_thread 会从 worker_result_queue 读取数据,进行处理(加速CPU和GPU的数据拷贝),把结果放入 data_queue。

    # [ pin_memory_thread ]
# # No need to check main thread. If this thread is alive, the main loader
# # thread must be alive, because this thread is set as daemonic.
# While `pin_memory_thread_done_event` is not set:
# Get from `index_queue`.
# If timed out, continue to get in the next iteration.
# Otherwise, process data.
# While `pin_memory_thread_done_event` is not set:
# Put processed data to `data_queue` (a `queue.Queue` with blocking put)
# If timed out, continue to put in the next iteration.
# Otherwise, break, i.e., continuing to the out loop.
#
# NOTE: we don't check the status of the main thread because
# 1. if the process is killed by fatal signal, `pin_memory_thread`
# ends.
# 2. in other cases, either the cleaning-up in __del__ or the
# automatic exit of daemonic thread will take care of it.
# This won't busy-wait either because `.get(timeout)` does not
# busy-wait.

具体代码如下:

def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
# This setting is thread local, and prevents the copy in pin_memory from
# consuming all CPU cores.
torch.set_num_threads(1) torch.cuda.set_device(device_id) # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function.
while not done_event.is_set():
try:
r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
idx, data = r
if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
data = pin_memory(data)
# 省略异常处理代码
r = (idx, data)
while not done_event.is_set():
try:
out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
break
except queue.Full:
continue
del r # save memory def pin_memory(data):
if isinstance(data, torch.Tensor):
return data.pin_memory()
elif isinstance(data, string_classes):
return data
elif isinstance(data, collections.abc.Mapping):
return {k: pin_memory(sample) for k, sample in data.items()}
elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
return type(data)(*(pin_memory(sample) for sample in data))
elif isinstance(data, collections.abc.Sequence):
return [pin_memory(sample) for sample in data]
elif hasattr(data, "pin_memory"):
return data.pin_memory()
else:
return data

2.4.7 用户获取data

现在数据已经加载完毕,我们接下来看用户如何从DataLoader之中获取数据。

这里有一个很关键的地方:如何保持在不同实验之中数据读取顺序的一致性。为了让多次实验之间可以比对,就需要尽量保证在这些实验中,每次读取数据的顺序都是一致的,这样才不会因为数据原因造成结果的误差。

打破顺序一致性的最大可能就是乱序数据。而造成乱序问题的原因就是:多进程读取,可能某个进程快,某个进程慢。比如,用户这次需要读取6-19,16-26,37-46。但是某一个worker慢,6-19不能即时返回,另一个worker 的 16-26 先返回了,于是就会造成乱序。

如何处理乱序数据?PyTorch的具体做法就是:DataLoader严格按照Sampler的顺序返回数据。如果某一个数据是乱序的,则会把它暂存起来,转而去获取下一个数据,见下面代码中 "store out-of-order samples" 注释处。等到应该返回时候(这个数据顺序到了)才返回。

但是其风险就是数据返回会比当前请求慢,比如应该获取 6,但是Data queue里面没有这个数据,只有 16,27,于是用户只能等待 6 加载完成。

解决慢的方法是:预取(prefetch)。就是在reset方法最后,提前提取若干index,让DataLoader提前去取,这虽然不能保证任意两次训练的数据返回顺序完全一致,但是可以最大限度保证。

具体代码如下,首先,回忆基类的 __next__ 函数 ,可以看到其调用了 _next_data 获取数据。

class _BaseDataLoaderIter(object):
def __next__(self) -> Any:
with torch.autograd.profiler.record_function(self._profile_name):
if self._sampler_iter is None:
self._reset()
data = self._next_data() # 获取数据
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
self._IterableDataset_len_called is not None and \
self._num_yielded > self._IterableDataset_len_called:
# 忽略错误提示处理
warnings.warn(warn_msg)
return data

所以,我们要看 _MultiProcessingDataLoaderIter_next_data

  • 因为之前有预取了index,worker进程已经开始获取数据,所以主进程此时可以得到数据,如果没有数据,就继续while True等待。
  • 如果获取成功,则使用 _process_data 设定下一次的indx,准备下一次迭代。
  • 通过 _task_info 来记录乱序数据,如果暂时无法处理,就在这里保存。
    def _next_data(self):
while True:
# If the worker responsible for `self._rcvd_idx` has already ended
# and was unable to fulfill this task (due to exhausting an `IterableDataset`),
# we try to advance `self._rcvd_idx` to find the next valid index.
#
# This part needs to run in the loop because both the `self._get_data()`
# call and `_IterableDatasetStopIteration` check below can mark
# extra worker(s) as dead. # 找到待取idx
while self._rcvd_idx < self._send_idx: # 如果 待取batch idx < 已取batch idx
info = self._task_info[self._rcvd_idx]
worker_id = info[0]
if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active
break # 有数据或者正在工作,就跳出内部这个while
del self._task_info[self._rcvd_idx]
self._rcvd_idx += 1
else:
# no valid `self._rcvd_idx` is found (i.e., didn't break)
if not self._persistent_workers:
self._shutdown_workers()
raise StopIteration # Now `self._rcvd_idx` is the batch index we want to fetch # Check if the next sample has already been generated
if len(self._task_info[self._rcvd_idx]) == 2:
data = self._task_info.pop(self._rcvd_idx)[1]
return self._process_data(data) # 设定下一次的indx,进行下一次迭代 assert not self._shutdown and self._tasks_outstanding > 0
idx, data = self._get_data() # 从 self._data_queue 中取数据
self._tasks_outstanding -= 1 # 正在准备的batch个数需要减1 if self._dataset_kind == _DatasetKind.Iterable:
# Check for _IterableDatasetStopIteration
if isinstance(data, _utils.worker._IterableDatasetStopIteration):
if self._persistent_workers:
self._workers_status[data.worker_id] = False
else:
self._mark_worker_as_unavailable(data.worker_id)
self._try_put_index()
continue if idx != self._rcvd_idx: # 乱序数据
# store out-of-order samples
self._task_info[idx] += (data,)
else:
del self._task_info[idx] # 正常数据
return self._process_data(data) # 设定下一次的indx,进行下一次迭代

其次,我们看看 _get_data 如何从 self._data_queue 中取数据。具体是使用 _try_get_data 来提取。

  • 如果有超时配置,就按照超时读取。
  • 如果设置了pin memory,则从pin 线程处理之后的数据读取。
  • 否则循环读取worker处理的数据,直至获取到数据为止。
    def _get_data(self):
# Fetches data from `self._data_queue`.
#
# We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
# which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
# in a loop. This is the only mechanism to detect worker failures for
# Windows. For other platforms, a SIGCHLD handler is also used for
# worker failure detection.
#
# If `pin_memory=True`, we also need check if `pin_memory_thread` had
# died at timeouts.
if self._timeout > 0: # 如果有超时配置,就按照超时读取
success, data = self._try_get_data(self._timeout)
if success:
return data
else:
raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
elif self._pin_memory: # 从pin 线程处理之后的数据读取
while self._pin_memory_thread.is_alive():
success, data = self._try_get_data()
if success:
return data
else:
# while condition is false, i.e., pin_memory_thread died.
raise RuntimeError('Pin memory thread exited unexpectedly')
# In this case, `self._data_queue` is a `queue.Queue`,. But we don't
# need to call `.task_done()` because we don't use `.join()`.
else:
while True:
success, data = self._try_get_data() # 读取worker处理的数据
if success:
return data

_try_get_data 就是从 _data_queue 读取。主进程和worker进程通过queue上的put, get进行通讯交互。

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
# Tries to fetch data from `self._data_queue` once for a given timeout.
# This can also be used as inner loop of fetching without timeout, with
# the sender status as the loop condition.
#
# This raises a `RuntimeError` if any worker died expectedly. This error
# can come from either the SIGCHLD handler in `_utils/signal_handling.py`
# (only for non-Windows platforms), or the manual check below on errors
# and timeouts.
#
# Returns a 2-tuple:
# (bool: whether successfully get data, any: data if successful else None)
try:
data = self._data_queue.get(timeout=timeout)
return (True, data)
except Exception as e:
# At timeout and error, we manually check whether any worker has
# failed. Note that this is the only mechanism for Windows to detect
# worker failures.
failed_workers = []
for worker_id, w in enumerate(self._workers):
if self._workers_status[worker_id] and not w.is_alive():
failed_workers.append(w)
self._mark_worker_as_unavailable(worker_id)
# 省略异常处理代码
import tempfile
import errno
try:
# Raise an exception if we are this close to the FDs limit.
# Apparently, trying to open only one file is not a sufficient
# test.
# See NOTE [ DataLoader on Linux and open files limit ]
fds_limit_margin = 10
fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
except OSError as e:
# 省略异常处理代码
raise

设置下一次迭代是使用_process_data

    def _process_data(self, data):
self._rcvd_idx += 1
self._try_put_index() # 设定下一次的indx,进行下一次迭代
if isinstance(data, ExceptionWrapper):
data.reraise()
return data # 返回数据

2.4.8 小结

我们小结一下多进程逻辑。

总体逻辑如下:

  • 主进程把需要获取的数据 index 放入index_queue。
  • 子进程从 index_queue 之中读取 index,进行数据读取,然后把读取数据的index放入worker_result_queue。
  • 主进程的 pin_memory_thread 会从 worker_result_queue 读取数据index,依据这个index进行读取数据,进行处理,把结果放入 data_queue。

具体流程如下图:

  1. 在 _MultiProcessingDataLoaderIter 的初始化函数 __init__ 之中会进行初始化:

    • 配置,生成各种成员变量,配置各种queue。
    • 启动各个子进程。
    • 启动主进程中的pin_memory的线程。
    • 调用 _reset 函数,这是进一步完善业务初始化,也用来重置环境。上面已经启动了worker子进程,但是没有分配任务,所以reset函数会进行任务分配,预取
  2. 接下来是一个预取操作(在看下图中一定要留意)。
    • _try_put_index 函数就是使用sampler获取下一批次的数据index。这里 _prefetch_factor 缺省值是 2,主要逻辑如下。

      • 使用 _next_index 从sampler获取下一批次的index。
      • 通过 _worker_queue_idx_cycle 找出下一个可用的工作worker,然后把index分给它。
      • 并且调整主进程的信息。
    • 拿到index之后,回到主线程。这里会进行数据提取。就是通过index_queue, data_queue与主进程交互。
      • 从 index_queue 获取新的数据index;
      • 如果没有设置本worker结束,就使用 fetcher获取数据。
      • 然后把数据放入data_queue,并且通知主进程,这里需要注意,data_queue是传入的参数,如果设置了pin memory,则传入的是 worker_result_queue,否则传入 data_queue。
  3. 当用户迭代时,调用了Loader基类的 __next__ 函数 ,其调用 _next_data 从 DataLoader 之中获取数据。
    • 使用 _get_data 如何从 self._data_queue 中取数据。
    • 使用_process_data 设置下一次迭代的 index,即使用 _try_put_index_next_index 来进行下一轮设置。

具体如下图:

user        _MultiProcessingDataLoaderIter   Sampler        Queue(index_queue)    Queue(data_queue)    _worker_loop     Fetcher
+ + + + + + +
| | | | | | |
| | | | | | |
| v | | | | |
| __init__ | | | | |
| 1 _reset | | | | |
| + | | | | |
| | | | | | |
| | | | | | |
| v | | | | |
| 2 _try_put_index next | | | | |
| _next_index +------------> | | | | |
| + | | | | |
| | <-----------------+ | | | | |
| | index | | | | |
| | | | | | |
| | +------------------------------------> | | | |
| | put | | | get | |
| | | +--------------------------------------> | |
| | | | | | index |
| | | | | +------------> |
| next | | | | | <----------+ |
+---------------------> | | | | <----------------+ data |
| | | | | data | |
| + | | | | |
| _next_data | | | | |
| 3 _get_data get | | | | |
| _try_get_data +--------------------------------------------------> | | |
| + | | | | |
| | <----------------------------------------------------------+ | | |
| | data | | | | |
| + | | | | |
| _process_data | | | | |
| _try_put_index next | | | | |
| _next_index +-------------> | | | | |
| + <--------------------+ | | | |
| | index | | | | |
| +---------------------------------------> | | get | |
| <-------------------+ | put | +-------------------------------------> | index |
| data | | | | | +----------> |
| | | | | +<-----------+ |
v v v v v v data v

手机上如下:

2.5 Pipleline

至此,我们把之前的pipeline图进一步细化,具体如下:

                                                  +------------+
+--------+ | |
| | | Process 1 |
+-----> | Data 1 +--------> | +------+
| | | | Load Data | |
| +--------+ | | |
| +------------+ |
| |
| |
| |
+----------------+ | +------------+ | +-------------------------+
|Main process | | +--------+ | | | | pin_memory_thread |
| | | | | | Process 2 | +------> +------------------------+ | | +------------+
| index_queue +----------> | Data 2 +--------> | | | | | | | |
| | | | | | Load Data +-------------> | _worker_result_queue +-----> | Write to pinned memory +--------> | data_queue |
| | | +--------+ | | | | | | | |
+----------------+ | +------------+ +-----> | | | | +------------+
| | +------------------------+ | |
| | +-------------------------+
| |
| +--------+ +------------+ |
| | | | | |
+-----> | Data 3 +--------> | Process 3 +-------+
| | | |
+--------+ | Load Data |
| |
+------------+

手机如下:

至此,PyTorch 分布式的数据加载部分分析完毕,下一篇我们回归到 Paracel 如何处理数据加载。

0xFF 参考

卷积神经网络的并行化模型--One weird trick for parallelizing convolutional neural networks

AI框架中数据处理的挑战与解决思路

PyTorch 源码解读之 torch.utils.data:解析数据处理全流程

谈谈你对大规模机器学习这个领域的理解和认识?

Nvidia-DALI 从放弃到入门

pytorch(分布式)数据并行个人实践总结——DataParallel/DistributedDataParallel

Pytorch数据Pipeline设计总结

深度学习框架数据Pipeline设计

最新文章

  1. [转]SQL 常用函数及示例
  2. Android Activity的生命周期简单总结
  3. 关闭CENTOS不必要的默认服务
  4. JSon_零基础_007_将JSon格式的&quot;数组&quot;字符串转换为Java对象&quot;数组&quot;
  5. Stm32_调试出现 Error:Flash Download Failed-&quot;Cortex-M3&quot;
  6. Python爬取百度贴吧图片
  7. \\ip 映射 指定的网络名不再可用
  8. 如何在Win10中启用和关闭管理员账户?
  9. Redis同步(主从复制)
  10. mysql 获取当前时间戳
  11. ElasticSearch 集群监控
  12. python3全栈开发-并发编程的多进程理论
  13. java并发包分析之———AQS框架
  14. 2018-2019-2 20175217 实验三《敏捷开发与XP实践》实验报告
  15. Excel 2010如何打开多个独立窗口?
  16. [Vue warn]: Duplicate keys detected: &#39;1&#39;. This may cause an update error
  17. ucore-lab1-练习6report
  18. node基础知识
  19. newcoder 筱玛的迷阵探险(搜索 + 01字典树)题解
  20. Iphone控件大全

热门文章

  1. ROS笔记一
  2. Java | 方法的定义 &amp; 重载 &amp; 递归
  3. 从 Vue 中 parseHTML 方法来看前端 html 词法分析
  4. navicate for mysql命令中输入中文报错
  5. 00JAVA语法基础_四则运算 01
  6. Kubernetes实战:高可用集群的搭建和部署
  7. 【剑指offer】42.和为S的两个数字
  8. 两万字Vue.js基础学习笔记
  9. 什么是jstl表达式,怎么应用
  10. ts 学习笔记 - 类