前段时间,由于自己偶遇了一个关于pytorch dataset的bug,便顺便把pytorch dataloader的源码看了下,以了解pytorch dataloader并行读数据的原理。(本来不想做什么笔记来着,恰好今天闲来无事,就记一下吧)

首先我们先来看,pytorch dataloader接收的一些参数:

1
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

dataloader对于数据的读取延迟主要取决于num_workerspin_memory这两个参数。首先,我先介绍一下比较简单的pin_memory参数。

什么是pin_memory

所谓的pin_memory就是锁页内存的意思。

计算机为了运行程序会先将程序和数据读到内存里。一般来说,计算机的内存都是比较小的,很难存的下太多的数据。但是,某个程序在某个时间段所需的程序和数据往往是比较少的,也就是说在某个时间点我们不需要将一个程序所需要的所有资源都放在内存里。我们可以将这些暂时用不到的数据或程序存放在硬盘一个被称为虚拟内存的地方。在程序运行的时候,我们可以不断交换内存和虚拟内存的数据以减少内存所需存储的数据。而且这些交换往往是通过某些规律预测下个时刻程序会用到的数据和代码并提前交换至内存的,这些规律的使用以及预测的准确性将会影响到程序的速度。

所谓的锁页内存就是说,我们不允许系统将某些内存里的数据交换至虚拟内存,毋庸置疑这将会提升程序的运行速度。但是也会是内存的存储占用消耗很多。

(我自己没有测试pin_memory为true的时候速度的提升会有多大,只是原理性的理解了一下,我自己感觉影响应该不是特别大,有兴趣的可以试一下。)

Dataloader的多进程读数据细节

Dataloader多进程读取数据的参数是通过num_workers指定的,num_workers为0的话就用主进程去读取数据,num_workers为N的话就会多开N个进程去读取数据。这里的多进程是通过python的multiprocessing module实现的(其实pytorch在multiprocessing又加了一个wraper以实现shared memory)。

先借用源码中的流程图解释一下Dataloader读数据的整个流程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14

# main process ||
# | ||
# {index_queue} ||
# | ||
# worker processes || DATA
# | ||
# {worker_result_queue} || FLOW
# | ||
# pin_memory_thread of main process || DIRECTION
# | ||
# {data_queue} ||
# | ||
# data output \/
  1. 首先每个worker的进程会拥有一个index_queue,dataloader初始化的时候,每个worker的index_queue会放入两个batch的index。

    1
    2
    3
    4
    # dataloader初始化,共放入2 * self.num_workers个batch的index
    # 每个worker的index_queue获得两个batch的index
    for _ in range(2 * self.num_workers):
    self._put_indices()

    index的放入是根据worker的id顺序放入

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    def _put_indices(self):
    assert self.batches_outstanding < 2 * self.num_workers
    indices = next(self.sample_iter, None)
    if indices is None:
    return
    self.index_queues[self.worker_queue_idx].put((self.send_idx, indices))
    #index是根据worker的id顺序放入的,不会出现先放某个worker的情况
    self.worker_queue_idx = (self.worker_queue_idx + 1) % self.num_workers
    self.batches_outstanding += 1
    self.send_idx += 1
  1. 每个worker的进程会不断检查自己的index_queue里有没有值,没有的话就继续检查。

    1
    2
    3
    4
    try:
    r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
    except queue.Empty:
    continue

    有的话,就去读一个batch(这个读的过程是通过调用dataset的get_item()实现的,并通过函数将数据合并为一个batch)。放入所有worker共享的data_queue(如果指定了pin_memory,这个新加的batch是会被放入pin_memory的)

    1
    data_queue.put((idx, samples))
  2. Dataloader会返回一个迭代器,每迭代一次

    首先程序会检查这次要load的idx数据是不是之前已经load过了(已经从共享的data_queue里取出来了),事先放在一个字典里存起来了(为什么会load过,下面会解释),如果是的话,就直接拿来用

    1
    2
    3
    4
    # self.rcvd_idx是第几个batch的计数
    if self.rcvd_idx in self.reorder_dict:
    batch = self.reorder_dict.pop(self.rcvd_idx)
    return self._process_next_batch(batch)

    如果没有load过,就从data_queue获取下一个batch和相应的idx,但是这里从data_queue获得的batch可能不是按顺序的,因为有的worker可能比较快提前将它的数据读好放到data_queue里了。这时候我们将这个提前来的batch先保存到self.reorder_dict这个字典里面,这就解释了上面为什么会出现load过的问题。如果一直等不到我们就会一直将提前来的batch放入self.reorder_dict暂存,直至我们等到那个按顺序来的batch。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    while True:
    assert (not self.shutdown and self.batches_outstanding > 0)
    idx, batch = self._get_batch()
    self.batches_outstanding -= 1
    if idx != self.rcvd_idx:
    # store out-of-order samples
    self.reorder_dict[idx] = batch
    continue
    return self._process_next_batch(batch)

    在每次迭代成功的时候,dataloader会放入一个新的batch_index到特定worker的index_queue里面:

    1
    2
    3
    4
    5
    6
    def _process_next_batch(self, batch):
    self.rcvd_idx += 1 #在返回数据的同时为当前的batch计数加一
    self._put_indices() #在返回数据的同时放入新的index
    if isinstance(batch, ExceptionWrapper):
    raise batch.exc_type(batch.exc_msg)
    return batch

    可以看出,dataloader只会在每次迭代成功的时候才会放入新的index到index_queue里面。因为上面写了在初始化dataloader的时候,我们一共放了2 x self.num_workers个batch的index到index_queue。读了一个batch才会放新的batch,所以这所有的worker进程最多缓存的batch数量就是2 x self.num_workers个。

那应该设定多少个worker呢?

怎么说呢?自己试吧,太多太少都不好吧。

太少,程序想去从data_queue里拿 batch的时候这个batch可能还没被读进来,会影响整个程序的运行时间。

太多的话,每个进程都是要占cpu核的,核不够,并行就只能变成串行呗,那就相当于指定那么多worker就没什么用了呗,如果拥有较靠前index的那个worker一直被阻塞没将那个batch读进来,dataloader也会一直等待这个batch,反而会更慢。(这些都是我理论分析的,哈哈哈,不信的话就去实验验证吧,毕竟实践是检验真理的唯一标准(ಡωಡ))