Repository: Link-Li/Balanced-DataParallel
Branch: master
Commit: 347b6cc4ac81
Files: 4
Total size: 13.3 KB

Directory structure:
gitextract_7w_mbpd8/
├── README.md
├── data_parallel.py
├── data_parallel_my.py
└── data_parallel_my_v2.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Balanced-DataParallel

这里是改进了pytorch的DataParallel, 用来平衡第一个GPU的显存使用量

本代码来自transformer-XL:https://github.com/kimiyoung/transformer-xl

代码不是本人写的, 但是感觉很好用, 就分享一下.

# 怎么使用:

这个 `BalancedDataParallel` 类使用起来和 `DataParallel` 类似, 下面是一个示例代码:

```
my_net = MyNet()
my_net = BalancedDataParallel(gpu0_bsz // acc_grad, my_net, dim=0).cuda()
```

这里包含三个参数, 第一个参数是第一个GPU要分配多大的batch_size, 但是要注意, 如果你使用了梯度累积, 那么这里传入的是每次进行运算的实际batch_size大小. 举个例子, 比如你在3个GPU上面跑代码, 但是一个GPU最大只能跑3条数据, 但是因为0号GPU还要做一些数据的整合操作, 于是0号GPU只能跑2条数据, 这样一算, 你可以跑的大小是2+3+3=8, 于是你可以设置下面的这样的参数:

```
batch_size = 8
gpu0_bsz = 2
acc_grad = 1
my_net = MyNet()
my_net = BalancedDataParallel(gpu0_bsz // acc_grad, my_net, dim=0).cuda()
```

这个时候突然想跑个batch size是16的怎么办呢, 那就是4+6+6=16了, 这样设置累积梯度为2就行了:

```
batch_size = 16
gpu0_bsz = 4
acc_grad = 2
my_net = MyNet()
my_net = BalancedDataParallel(gpu0_bsz // acc_grad, my_net, dim=0).cuda()
```

### 各个版本的data_parallel

- data_parallel.py: 原作者的代码, 但是使用的时候发现, 如果batch size设置的小于GPU的数量, 会导致最后一个批次的数据分配的不足以所有的GPU分配, 然后报错.
- data_parallel_my.py: 我稍微改了一点, 然后稍微测试了一下, 应该是解决了上面的问题.
- data_parallel_my_v2.py:上面第一个版本的修改,导致无法设置gpu0_bsz=0,这个版本应该是修复这个问题了 ================================================ FILE: data_parallel.py ================================================ from torch.nn.parallel import DataParallel import torch from torch.nn.parallel._functions import Scatter from torch.nn.parallel.parallel_apply import parallel_apply def scatter(inputs, target_gpus, chunk_sizes, dim=0): r""" Slices tensors into approximately equal chunks and distributes them across given GPUs. Duplicates references to objects that are not tensors. """ def scatter_map(obj): if isinstance(obj, torch.Tensor): try: return Scatter.apply(target_gpus, chunk_sizes, dim, obj) except: print('obj', obj.size()) print('dim', dim) print('chunk_sizes', chunk_sizes) quit() if isinstance(obj, tuple) and len(obj) > 0: return list(zip(*map(scatter_map, obj))) if isinstance(obj, list) and len(obj) > 0: return list(map(list, zip(*map(scatter_map, obj)))) if isinstance(obj, dict) and len(obj) > 0: return list(map(type(obj), zip(*map(scatter_map, obj.items())))) return [obj for targets in target_gpus] # After scatter_map is called, a scatter_map cell will exist. This cell # has a reference to the actual function scatter_map, which has references # to a closure that has a reference to the scatter_map cell (because the # fn is recursive). 
To avoid this reference cycle, we set the function to # None, clearing the cell try: return scatter_map(inputs) finally: scatter_map = None def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): r"""Scatter with support for kwargs dictionary""" inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] if len(inputs) < len(kwargs): inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) elif len(kwargs) < len(inputs): kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) inputs = tuple(inputs) kwargs = tuple(kwargs) return inputs, kwargs class BalancedDataParallel(DataParallel): def __init__(self, gpu0_bsz, *args, **kwargs): self.gpu0_bsz = gpu0_bsz super().__init__(*args, **kwargs) def forward(self, *inputs, **kwargs): if not self.device_ids: return self.module(*inputs, **kwargs) if self.gpu0_bsz == 0: device_ids = self.device_ids[1:] else: device_ids = self.device_ids inputs, kwargs = self.scatter(inputs, kwargs, device_ids) if len(self.device_ids) == 1: return self.module(*inputs[0], **kwargs[0]) replicas = self.replicate(self.module, self.device_ids) if self.gpu0_bsz == 0: replicas = replicas[1:] outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) return self.gather(outputs, self.output_device) def parallel_apply(self, replicas, device_ids, inputs, kwargs): return parallel_apply(replicas, inputs, kwargs, device_ids) def scatter(self, inputs, kwargs, device_ids): bsz = inputs[0].size(self.dim) num_dev = len(self.device_ids) gpu0_bsz = self.gpu0_bsz bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) if gpu0_bsz < bsz_unit: chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) delta = bsz - sum(chunk_sizes) for i in range(delta): chunk_sizes[i + 1] += 1 if gpu0_bsz == 0: chunk_sizes = chunk_sizes[1:] else: return super().scatter(inputs, kwargs, device_ids) # print('bsz: ', bsz) # print('num_dev: ', num_dev) # print('gpu0_bsz: ', gpu0_bsz) 
# print('bsz_unit: ', bsz_unit) # print('chunk_sizes: ', chunk_sizes) return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) ================================================ FILE: data_parallel_my.py ================================================ from torch.nn.parallel import DataParallel import torch from torch.nn.parallel._functions import Scatter from torch.nn.parallel.parallel_apply import parallel_apply def scatter(inputs, target_gpus, chunk_sizes, dim=0): r""" Slices tensors into approximately equal chunks and distributes them across given GPUs. Duplicates references to objects that are not tensors. """ def scatter_map(obj): if isinstance(obj, torch.Tensor): try: return Scatter.apply(target_gpus, chunk_sizes, dim, obj) except: print('obj', obj.size()) print('dim', dim) print('chunk_sizes', chunk_sizes) quit() if isinstance(obj, tuple) and len(obj) > 0: return list(zip(*map(scatter_map, obj))) if isinstance(obj, list) and len(obj) > 0: return list(map(list, zip(*map(scatter_map, obj)))) if isinstance(obj, dict) and len(obj) > 0: return list(map(type(obj), zip(*map(scatter_map, obj.items())))) return [obj for targets in target_gpus] # After scatter_map is called, a scatter_map cell will exist. This cell # has a reference to the actual function scatter_map, which has references # to a closure that has a reference to the scatter_map cell (because the # fn is recursive). 
To avoid this reference cycle, we set the function to # None, clearing the cell try: return scatter_map(inputs) finally: scatter_map = None def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): r"""Scatter with support for kwargs dictionary""" inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] if len(inputs) < len(kwargs): inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) elif len(kwargs) < len(inputs): kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) inputs = tuple(inputs) kwargs = tuple(kwargs) return inputs, kwargs class BalancedDataParallel(DataParallel): def __init__(self, gpu0_bsz, *args, **kwargs): self.gpu0_bsz = gpu0_bsz super().__init__(*args, **kwargs) def forward(self, *inputs, **kwargs): if not self.device_ids: return self.module(*inputs, **kwargs) if self.gpu0_bsz == 0: device_ids = self.device_ids[1:] else: device_ids = self.device_ids inputs, kwargs = self.scatter(inputs, kwargs, device_ids) # print('len(inputs)1: ', str(len(inputs))) # print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)])) if len(self.device_ids) == 1: return self.module(*inputs[0], **kwargs[0]) replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) if self.gpu0_bsz == 0: replicas = replicas[1:] outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) return self.gather(outputs, self.output_device) def parallel_apply(self, replicas, device_ids, inputs, kwargs): return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)]) def scatter(self, inputs, kwargs, device_ids): bsz = inputs[0].size(self.dim) num_dev = len(self.device_ids) gpu0_bsz = self.gpu0_bsz bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) if gpu0_bsz < bsz_unit: chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) delta = bsz - sum(chunk_sizes) for i in range(delta): chunk_sizes[i + 1] += 1 if gpu0_bsz == 0: chunk_sizes = 
chunk_sizes[1:] else: return super().scatter(inputs, kwargs, device_ids) # print('bsz: ', bsz) # print('num_dev: ', num_dev) # print('gpu0_bsz: ', gpu0_bsz) # print('bsz_unit: ', bsz_unit) # print('chunk_sizes: ', chunk_sizes) return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) ================================================ FILE: data_parallel_my_v2.py ================================================ from torch.nn.parallel import DataParallel import torch from torch.nn.parallel._functions import Scatter from torch.nn.parallel.parallel_apply import parallel_apply def scatter(inputs, target_gpus, chunk_sizes, dim=0): r""" Slices tensors into approximately equal chunks and distributes them across given GPUs. Duplicates references to objects that are not tensors. """ def scatter_map(obj): if isinstance(obj, torch.Tensor): try: return Scatter.apply(target_gpus, chunk_sizes, dim, obj) except: print('obj', obj.size()) print('dim', dim) print('chunk_sizes', chunk_sizes) quit() if isinstance(obj, tuple) and len(obj) > 0: return list(zip(*map(scatter_map, obj))) if isinstance(obj, list) and len(obj) > 0: return list(map(list, zip(*map(scatter_map, obj)))) if isinstance(obj, dict) and len(obj) > 0: return list(map(type(obj), zip(*map(scatter_map, obj.items())))) return [obj for targets in target_gpus] # After scatter_map is called, a scatter_map cell will exist. This cell # has a reference to the actual function scatter_map, which has references # to a closure that has a reference to the scatter_map cell (because the # fn is recursive). 
To avoid this reference cycle, we set the function to # None, clearing the cell try: return scatter_map(inputs) finally: scatter_map = None def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): r"""Scatter with support for kwargs dictionary""" inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] if len(inputs) < len(kwargs): inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) elif len(kwargs) < len(inputs): kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) inputs = tuple(inputs) kwargs = tuple(kwargs) return inputs, kwargs class BalancedDataParallel(DataParallel): def __init__(self, gpu0_bsz, *args, **kwargs): self.gpu0_bsz = gpu0_bsz super().__init__(*args, **kwargs) def forward(self, *inputs, **kwargs): if not self.device_ids: return self.module(*inputs, **kwargs) if self.gpu0_bsz == 0: device_ids = self.device_ids[1:] else: device_ids = self.device_ids inputs, kwargs = self.scatter(inputs, kwargs, device_ids) print('len(inputs): ', str(len(inputs))) print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)])) if len(self.device_ids) == 1: return self.module(*inputs[0], **kwargs[0]) if self.gpu0_bsz == 0: replicas = self.replicate(self.module, self.device_ids) else: replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) # replicas = self.replicate(self.module, device_ids[:len(inputs)]) if self.gpu0_bsz == 0: replicas = replicas[1:] print('replicas:', str(len(replicas))) outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) return self.gather(outputs, self.output_device) def parallel_apply(self, replicas, device_ids, inputs, kwargs): return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)]) def scatter(self, inputs, kwargs, device_ids): bsz = inputs[0].size(self.dim) num_dev = len(self.device_ids) gpu0_bsz = self.gpu0_bsz bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) if 
gpu0_bsz < bsz_unit: chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) delta = bsz - sum(chunk_sizes) for i in range(delta): chunk_sizes[i + 1] += 1 if gpu0_bsz == 0: chunk_sizes = chunk_sizes[1:] else: return super().scatter(inputs, kwargs, device_ids) print('bsz: ', bsz) print('num_dev: ', num_dev) print('gpu0_bsz: ', gpu0_bsz) print('bsz_unit: ', bsz_unit) print('chunk_sizes: ', chunk_sizes) return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)