[
  {
    "path": "README.md",
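    "content": "# Balanced-DataParallel\nAn improved version of PyTorch's `DataParallel` that balances memory usage on the first GPU (GPU 0) by letting you give it a smaller share of each batch.\n\nThis code comes from Transformer-XL: https://github.com/kimiyoung/transformer-xl\n\nI did not write this code, but I found it very useful, so I am sharing it.\n\n# How to use\n\nThe `BalancedDataParallel` class is used much like `DataParallel`. Here is an example:\n\n```\nmy_net = MyNet()\nmy_net = BalancedDataParallel(gpu0_bsz // acc_grad, my_net, dim=0).cuda()\n```\n\nThe constructor takes three arguments here. The first is the batch size assigned to GPU 0. The module and `dim` are passed on to `DataParallel` unchanged. Note that if you use gradient accumulation, the first argument must be the batch size actually processed in each forward pass, not the accumulated batch size. For example, suppose you train on 3 GPUs and each GPU can hold at most 3 samples, but GPU 0 can only hold 2 because it also has to gather the outputs. The total batch size is then 2 + 3 + 3 = 8, so you can set the parameters like this:\n\n```\nbatch_size = 8\ngpu0_bsz = 2\nacc_grad = 1\nmy_net = MyNet()\nmy_net = BalancedDataParallel(gpu0_bsz // acc_grad, my_net, dim=0).cuda()\n```\n\nNow suppose you want a batch size of 16. That is 4 + 6 + 6 = 16, which you can reach by setting the number of gradient accumulation steps to 2. Each forward pass then processes 16 / 2 = 8 samples, split 2 + 3 + 3 across the three GPUs:\n\n```\nbatch_size = 16\ngpu0_bsz = 4\nacc_grad = 2\nmy_net = MyNet()\nmy_net = BalancedDataParallel(gpu0_bsz // acc_grad, my_net, dim=0).cuda()\n```\n\n### Versions of data_parallel\n\n- data_parallel.py: the original author's code. In use, however, I found that if the batch size is smaller than the number of GPUs, the last batch does not have enough samples to give every GPU a share, which raises an error.\n\n- data_parallel_my.py: I changed it slightly and tested it briefly. It should fix the problem above.\n\n- data_parallel_my_v2.py: the changes in the previous version made it impossible to set `gpu0_bsz=0`. This version should fix that.\n"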
  },
  {
    "path": "data_parallel.py",
    "content": "\nfrom torch.nn.parallel import DataParallel\nimport torch\nfrom torch.nn.parallel._functions import Scatter\nfrom torch.nn.parallel.parallel_apply import parallel_apply\n\ndef scatter(inputs, target_gpus, chunk_sizes, dim=0):\n    r\"\"\"\n    Slices tensors into approximately equal chunks and\n    distributes them across given GPUs. Duplicates\n    references to objects that are not tensors.\n    \"\"\"\n    def scatter_map(obj):\n        if isinstance(obj, torch.Tensor):\n            try:\n                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)\n            except:\n                print('obj', obj.size())\n                print('dim', dim)\n                print('chunk_sizes', chunk_sizes)\n                quit()\n        if isinstance(obj, tuple) and len(obj) > 0:\n            return list(zip(*map(scatter_map, obj)))\n        if isinstance(obj, list) and len(obj) > 0:\n            return list(map(list, zip(*map(scatter_map, obj))))\n        if isinstance(obj, dict) and len(obj) > 0:\n            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))\n        return [obj for targets in target_gpus]\n\n    # After scatter_map is called, a scatter_map cell will exist. This cell\n    # has a reference to the actual function scatter_map, which has references\n    # to a closure that has a reference to the scatter_map cell (because the\n    # fn is recursive). 
To avoid this reference cycle, we set the function to\n    # None, clearing the cell\n    try:\n        return scatter_map(inputs)\n    finally:\n        scatter_map = None\n\ndef scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):\n    r\"\"\"Scatter with support for kwargs dictionary\"\"\"\n    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []\n    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []\n    if len(inputs) < len(kwargs):\n        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])\n    elif len(kwargs) < len(inputs):\n        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])\n    inputs = tuple(inputs)\n    kwargs = tuple(kwargs)\n    return inputs, kwargs\n\nclass BalancedDataParallel(DataParallel):\n    def __init__(self, gpu0_bsz, *args, **kwargs):\n        self.gpu0_bsz = gpu0_bsz\n        super().__init__(*args, **kwargs)\n\n    def forward(self, *inputs, **kwargs):\n        if not self.device_ids:\n            return self.module(*inputs, **kwargs)\n        if self.gpu0_bsz == 0:\n            device_ids = self.device_ids[1:]\n        else:\n            device_ids = self.device_ids\n        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)\n        if len(self.device_ids) == 1:\n            return self.module(*inputs[0], **kwargs[0])\n        replicas = self.replicate(self.module, self.device_ids)\n        if self.gpu0_bsz == 0:\n            replicas = replicas[1:]\n        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)\n        return self.gather(outputs, self.output_device)\n\n    def parallel_apply(self, replicas, device_ids, inputs, kwargs):\n        return parallel_apply(replicas, inputs, kwargs, device_ids)\n\n    def scatter(self, inputs, kwargs, device_ids):\n        bsz = inputs[0].size(self.dim)\n        num_dev = len(self.device_ids)\n        gpu0_bsz = self.gpu0_bsz\n        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)\n       
 if gpu0_bsz < bsz_unit:\n            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)\n            delta = bsz - sum(chunk_sizes)\n            for i in range(delta):\n                chunk_sizes[i + 1] += 1\n            if gpu0_bsz == 0:\n                chunk_sizes = chunk_sizes[1:]\n        else:\n            return super().scatter(inputs, kwargs, device_ids)\n\n        # print('bsz: ', bsz)\n        # print('num_dev: ', num_dev)\n        # print('gpu0_bsz: ', gpu0_bsz)\n        # print('bsz_unit: ', bsz_unit)\n        # print('chunk_sizes: ', chunk_sizes)\n        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)\n\n"
  },
  {
    "path": "data_parallel_my.py",
    "content": "\nfrom torch.nn.parallel import DataParallel\nimport torch\nfrom torch.nn.parallel._functions import Scatter\nfrom torch.nn.parallel.parallel_apply import parallel_apply\n\ndef scatter(inputs, target_gpus, chunk_sizes, dim=0):\n    r\"\"\"\n    Slices tensors into approximately equal chunks and\n    distributes them across given GPUs. Duplicates\n    references to objects that are not tensors.\n    \"\"\"\n    def scatter_map(obj):\n        if isinstance(obj, torch.Tensor):\n            try:\n                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)\n            except:\n                print('obj', obj.size())\n                print('dim', dim)\n                print('chunk_sizes', chunk_sizes)\n                quit()\n        if isinstance(obj, tuple) and len(obj) > 0:\n            return list(zip(*map(scatter_map, obj)))\n        if isinstance(obj, list) and len(obj) > 0:\n            return list(map(list, zip(*map(scatter_map, obj))))\n        if isinstance(obj, dict) and len(obj) > 0:\n            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))\n        return [obj for targets in target_gpus]\n\n    # After scatter_map is called, a scatter_map cell will exist. This cell\n    # has a reference to the actual function scatter_map, which has references\n    # to a closure that has a reference to the scatter_map cell (because the\n    # fn is recursive). 
To avoid this reference cycle, we set the function to\n    # None, clearing the cell\n    try:\n        return scatter_map(inputs)\n    finally:\n        scatter_map = None\n\ndef scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):\n    r\"\"\"Scatter with support for kwargs dictionary\"\"\"\n    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []\n    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []\n    if len(inputs) < len(kwargs):\n        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])\n    elif len(kwargs) < len(inputs):\n        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])\n    inputs = tuple(inputs)\n    kwargs = tuple(kwargs)\n    return inputs, kwargs\n\nclass BalancedDataParallel(DataParallel):\n    def __init__(self, gpu0_bsz, *args, **kwargs):\n        self.gpu0_bsz = gpu0_bsz\n        super().__init__(*args, **kwargs)\n\n    def forward(self, *inputs, **kwargs):\n        if not self.device_ids:\n            return self.module(*inputs, **kwargs)\n        if self.gpu0_bsz == 0:\n            device_ids = self.device_ids[1:]\n        else:\n            device_ids = self.device_ids\n        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)\n        # print('len(inputs)1: ', str(len(inputs)))\n        # print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)]))\n        if len(self.device_ids) == 1:\n            return self.module(*inputs[0], **kwargs[0])\n        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])\n        if self.gpu0_bsz == 0:\n            replicas = replicas[1:]\n        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)\n        return self.gather(outputs, self.output_device)\n\n    def parallel_apply(self, replicas, device_ids, inputs, kwargs):\n        return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])\n\n    def scatter(self, inputs, kwargs, device_ids):\n        
bsz = inputs[0].size(self.dim)\n        num_dev = len(self.device_ids)\n        gpu0_bsz = self.gpu0_bsz\n        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)\n        if gpu0_bsz < bsz_unit:\n            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)\n            delta = bsz - sum(chunk_sizes)\n            for i in range(delta):\n                chunk_sizes[i + 1] += 1\n            if gpu0_bsz == 0:\n                chunk_sizes = chunk_sizes[1:]\n        else:\n            return super().scatter(inputs, kwargs, device_ids)\n\n        # print('bsz: ', bsz)\n        # print('num_dev: ', num_dev)\n        # print('gpu0_bsz: ', gpu0_bsz)\n        # print('bsz_unit: ', bsz_unit)\n        # print('chunk_sizes: ', chunk_sizes)\n        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)\n\n"
  },
  {
    "path": "data_parallel_my_v2.py",
    "content": "\nfrom torch.nn.parallel import DataParallel\nimport torch\nfrom torch.nn.parallel._functions import Scatter\nfrom torch.nn.parallel.parallel_apply import parallel_apply\n\ndef scatter(inputs, target_gpus, chunk_sizes, dim=0):\n    r\"\"\"\n    Slices tensors into chunks of the given chunk_sizes along dim and\n    distributes them across the given GPUs. Duplicates\n    references to objects that are not tensors.\n    \"\"\"\n    def scatter_map(obj):\n        if isinstance(obj, torch.Tensor):\n            try:\n                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)\n            except Exception as exc:\n                # Report the failing split instead of silently quitting the process\n                raise RuntimeError('scatter failed: obj size {}, dim {}, chunk_sizes {}'.format(\n                    tuple(obj.size()), dim, chunk_sizes)) from exc\n        if isinstance(obj, tuple) and len(obj) > 0:\n            return list(zip(*map(scatter_map, obj)))\n        if isinstance(obj, list) and len(obj) > 0:\n            return list(map(list, zip(*map(scatter_map, obj))))\n        if isinstance(obj, dict) and len(obj) > 0:\n            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))\n        return [obj for _ in target_gpus]\n\n    # After scatter_map is called, a scatter_map cell will exist. This cell\n    # has a reference to the actual function scatter_map, which has references\n    # to a closure that has a reference to the scatter_map cell (because the\n    # fn is recursive). To avoid this reference cycle, we set the function to\n    # None, clearing the cell\n    try:\n        return scatter_map(inputs)\n    finally:\n        scatter_map = None\n\ndef scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):\n    r\"\"\"Scatter with support for kwargs dictionary\"\"\"\n    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []\n    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []\n    if len(inputs) < len(kwargs):\n        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])\n    elif len(kwargs) < len(inputs):\n        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])\n    inputs = tuple(inputs)\n    kwargs = tuple(kwargs)\n    return inputs, kwargs\n\nclass BalancedDataParallel(DataParallel):\n    def __init__(self, gpu0_bsz, *args, **kwargs):\n        self.gpu0_bsz = gpu0_bsz\n        super().__init__(*args, **kwargs)\n\n    def forward(self, *inputs, **kwargs):\n        if not self.device_ids:\n            return self.module(*inputs, **kwargs)\n        if self.gpu0_bsz == 0:\n            device_ids = self.device_ids[1:]\n        else:\n            device_ids = self.device_ids\n        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)\n\n        # print('len(inputs): ', str(len(inputs)))\n        # print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)]))\n\n        if len(self.device_ids) == 1:\n            return self.module(*inputs[0], **kwargs[0])\n        if self.gpu0_bsz == 0:\n            # The module lives on device_ids[0], so replicate from there and drop that\n            # replica, leaving one replica per scattered chunk on GPUs 1..len(inputs)\n            replicas = self.replicate(self.module, self.device_ids[:len(inputs) + 1])\n            replicas = replicas[1:]\n        else:\n            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])\n\n        # print('replicas:', str(len(replicas)))\n\n        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)\n        return self.gather(outputs, self.output_device)\n\n    def parallel_apply(self, replicas, device_ids, inputs, kwargs):\n        return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])\n\n    def scatter(self, inputs, kwargs, device_ids):\n        bsz = inputs[0].size(self.dim)\n        num_dev = len(self.device_ids)\n        if num_dev == 1:\n            # Nothing to balance on a single GPU; this also avoids dividing by zero below\n            return super().scatter(inputs, kwargs, device_ids)\n        gpu0_bsz = self.gpu0_bsz\n        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)\n        if gpu0_bsz < bsz_unit:\n            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)\n            delta = bsz - sum(chunk_sizes)\n            for i in range(delta):\n                chunk_sizes[i + 1] += 1\n            if gpu0_bsz == 0:\n                chunk_sizes = chunk_sizes[1:]\n        else:\n            return super().scatter(inputs, kwargs, device_ids)\n\n        # print('bsz: ', bsz)\n        # print('num_dev: ', num_dev)\n        # print('gpu0_bsz: ', gpu0_bsz)\n        # print('bsz_unit: ', bsz_unit)\n        # print('chunk_sizes: ', chunk_sizes)\n        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)\n\n"
  }
]