[
  {
    "path": ".gitignore",
    "content": "application/cache\n*.pyc\n\n# general things to ignore\nbuild/\ndist/\n*.egg-info/\n*.egg\n*.py[cod]\n__pycache__/\n*~\n\n# due to using tox and pytest\n.tox\n.cache\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"third_party/megatron\"]\n\tpath = third_party/megatron\n\turl = https://github.com/NVIDIA/Megatron-LM.git\n[submodule \"third_party/deepspeed\"]\n\tpath = third_party/deepspeed\n\turl = https://github.com/microsoft/DeepSpeed.git\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# veGiantModel\nVeGiantModel is a torch based high efficient training library developed by the Applied Machine Learning team at Bytedance. This repository is for ongoing research to make giant model (such as [GPT](https://arxiv.org/abs/2005.14165), [BERT](https://arxiv.org/pdf/1810.04805.pdf) and [T5](https://arxiv.org/abs/1910.10683)) training easy, efficient, and effective. VeGiantModel builds on top of [Megatron](https://github.com/NVIDIA/Megatron-LM) and [DeepSpeed](https://github.com/microsoft/DeepSpeed), improves communication efficiency by integrating high efficient communication library [BytePs](https://github.com/bytedance/byteps) and providing customized pipline partitioning.\n## initialization\n\n```python\nimport veGiantModel\npipeline_parallel_size = 1\nmodel_parallel_size = 2\nveGiantModel.initialize.init_distribute(pipeline_parallel_size, model_parallel_size, init_method=\"env://\")\nmp_size = veGiantModel.distributed.get_model_parallel_world_size()\ndp_size = veGiantModel.distributed.get_data_parallel_world_size()\n```\n\n## modules\n\n\n```python\nfrom veGiantModel.module import ColumnParallelLinear, RowParallelLinear\n\nclass PositionWiseFeedForward(nn.Module):\n    \"\"\" FeedForward Neural Networks for each position \"\"\"\n\n    def __init__(self, config: Config):\n        super().__init__()\n\n        if self.config.use_mp_linear_in_ffn:\n            assert ColumnParallelLinear is not None\n            assert RowParallelLinear is not None\n            self.fc1 = ColumnParallelLinear(config.dim, config.dim_ff, use_ft=False)\n            self.fc2 = RowParallelLinear(config.dim_ff, config.dim, use_ft=False)\n        else:\n            self.fc1 = nn.Linear(config.dim, config.dim_ff)\n            self.fc2 = nn.Linear(config.dim_ff, config.dim)\n        self.act = Activation(config.act)\n        self.dropout = nn.Dropout(config.p_drop_hidden)\n\n    def forward(self, x) -> torch.Tensor:\n        # (bsz, seq_len, dim) -> (bsz, seq_len, dim_ff / model_parallel_size) -> (bsz, seq_len, dim)\n        fc1_out = self.act(self.fc1(x))\n        if self.config.dropout_in_ffn:\n            fc1_out = self.dropout(fc1_out)\n        fc2_out = self.fc2(fc1_out)\n        if self.config.use_ffn_output_dropout:\n            fc2_out = self.dropout(fc2_out)\n        return fc2_out\n```\n\n\n## Examples\n### GPT Pretraining\nThe `examples/gpt/pretrain_gpt2_distributed.sh` scrips runs 345M parameter GPT pretraining on single 8 GPUs node. It follows largely the same as Megatron GPT script with a few notable differences. It shows good compatiblility with current megatron/Deepseed training job with little changes to adpot VeGiantModel.\n"
  },
  {
    "path": "docs/Dockerfile",
    "content": "FROM nvcr.io/nvidia/pytorch:21.05-py3 \n\nRUN pip3 install boto3 regex tensorboardX==1.8 wheel pybind11 ninja psutil pyprof\nRUN apt-get -yq autoremove --purge ibverbs-providers\nRUN apt-get update && \\\n    DEBIAN_FRONTEND=noninteractive apt-get install -yq --no-install-recommends --allow-downgrades \\\n     libibverbs-dev=28.0-1ubuntu1 libibverbs1=28.0-1ubuntu1\n\nRUN apt-get update && \\\n    DEBIAN_FRONTEND=noninteractive apt-get install -yq --no-install-recommends --allow-downgrades \\\n        cmake \\\n        libopenmpi-dev \\\n        openmpi-bin \\\n        openssh-client \\\n        openssh-server \\\n        ibverbs-providers \\\n        libibverbs-dev=28.0-1ubuntu1 \\\n        librdmacm-dev \\\n        vim \\\n        iputils-ping \\\n        llvm-10-dev \\\n        iproute2 \\\n        unzip\n\nRUN ln -s /usr/bin/aclocal-1.16 /usr/local/bin/aclocal-1.14\nRUN ln -s /usr/bin/automake /usr/local/bin/automake-1.14\n\nENV LD_LIBRARY_PATH \"/usr/lib/x86_64-linux-gnu:${LD_LIBRARY_PATH}\"\nENV BYTEPS_WITH_UCX 0\n\n#install byteps from package stored in tos at volcengine\n# RUN pip3 install https://giant-model-package.tos-cn-beijing.volces.com/byteps-0.7.2-cp38-cp38-linux_x86_64.whl\n\n#install byteps from source\nRUN git clone --recursive -b bccl-github https://github.com/bytedance/byteps.git && \\\n    cd byteps && python3 setup.py install\n\nWORKDIR /root"
  },
  {
    "path": "docs/step-by-step-tutorial.md",
    "content": "# A Step-by-Step Tutorial\nThe goal of this tutorial is to help you run the example quickly.\n\n## Pre-requisite\npytorch:\n```\npip3 install pytorch\n```\n\nApex:\n```\ngit clone https://github.com/NVIDIA/apex.git\ncd apex\npython3 setup.py -v --cpp_ext --cuda_ext bdist_wheel\nsudo pip3 install dist/*\n```\n\nBytePs:\n```\ngit clone --recursive -b bccl-github https://github.com/bytedance/byteps.git\ncd byteps\npython3 setup.py install\n```\n## Prepare data\n    [GPT data preprocess](https://github.com/NVIDIA/Megatron-LM#data-preprocessing)\n\n## Setup veGiantModel\n```\ngit clone https://github.com/volcengine/veGiantModel.git\ncd veGiantModel\ngit submodule update --init --recursive\n```\n\n## Modify script\nModify examples/gpt/pretrain_gpt2_distributed.sh before run\n```\nDATA_PATH           -- the preprocessed gpt data local folder path\nCHECKPOINT_PATH     -- local path to save/load check point\nMASTER_PORT         -- port number used by torch ddp\nWORKER_0_PORT       -- port number for veGiantModel use for communication\nWORKER_0_HOST       -- ip of the master node (single node training can use 'localhost')\nNUM_WORKER          -- number of workers in the training\nWORKER_RANK         -- rank of current node\nGPU_PER_WORKER      -- number of GPUs per node\n```\n\n## run script\n```\nbash examples/gpt/pretrain_gpt2_distributed.sh\n```\n\n"
  },
  {
    "path": "examples/gpt/gpt_piped.py",
    "content": "import torch\n\nfrom megatron import get_args, mpu\n\nfrom megatron.model.language_model import parallel_lm_logits, Embedding\nfrom megatron.model.transformer import ParallelTransformerLayer\nfrom megatron.model.transformer import LayerNorm\nfrom megatron.model.gpt2_model import gpt2_attention_mask_func\nfrom megatron.model.utils import init_method_normal\nfrom megatron.model.utils import scaled_init_method_normal\nfrom megatron.module import MegatronModule\nfrom megatron.utils import get_ltor_masks_and_position_ids\n\nfrom deepspeed.pipe import LayerSpec, TiedLayerSpec\nfrom megatron import get_tokenizer\nfrom veGiantModel.engine.module import VeGiantModule\nimport veGiantModel\n\nclass GPTModelPiped(VeGiantModule):\n    def __init__(self):\n        args = get_args()\n        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy\n        self.tokenizer = get_tokenizer()\n        self.parallel_output = True\n\n        self.num_layers = args.num_layers\n        self.hidden_size = args.hidden_size\n\n        self.init_method = init_method_normal(args.init_method_std)\n        self.scale_init_method = scaled_init_method_normal(args.init_method_std,\n                                                           args.num_layers)\n\n        self.num_tokentypes = 0\n\n        layers = []\n        layers.append(lambda x: self._get_batch(x))\n        layers.append(TiedLayerSpec(\"SharedEmbedding\",\n                                    EmbeddingPiped,\n                                    self.hidden_size,\n                                    args.padded_vocab_size,\n                                    args.max_position_embeddings,\n                                    args.hidden_dropout,\n                                    self.init_method,\n                                    self.num_tokentypes,\n                                    tied_weight_attr='embedding_weight'))\n\n        layers.append(lambda x: (x[0].transpose(0, 1).contiguous(), x[1]))\n\n        for i in range(self.num_layers):\n            layers.append(LayerSpec(ParallelTransformerLayerPiped,\n                                    gpt2_attention_mask_func,\n                                    self.init_method,\n                                    self.scale_init_method,\n                                    i+1))\n\n        layers.append(lambda x: (x[0].transpose(0, 1).contiguous()))\n\n        layers.append(LayerSpec(LayerNorm, args.hidden_size, eps=args.layernorm_epsilon))\n\n        layers.append(TiedLayerSpec(\"SharedEmbedding\",\n                                    LMLogitsPiped,\n                                    self.hidden_size,\n                                    args.padded_vocab_size,\n                                    self.init_method,\n                                    tied_weight_attr='embedding_weight'))\n\n        super().__init__(layers=layers,\n                         num_stages = args.num_stages, \n                         partition_method=args.partition_method,\n                         grid=veGiantModel.distributed.get_grid(),\n                         loss_fn=self.loss_fn)\n        \n\n    # Data Preprocessing, copied from pretrain_gpt2.py\n    def _get_batch(self, data):\n        \"\"\"Generate a batch\"\"\"\n        args = get_args()\n        # Unpack.\n        tokens = data\n\n        attention_mask, _, position_ids = get_ltor_masks_and_position_ids(\n            tokens,\n            self.tokenizer.eod,\n            args.reset_position_ids,\n            args.reset_attention_mask,\n         
   args.eod_mask_loss)\n\n        return (tokens.to(device=\"cuda\"),\n                position_ids.to(device=\"cuda\"),\n                attention_mask.to(device=\"cuda\"))\n\n    def loss_fn(self, inputs, data):\n        tokens = data[0]\n        target = data[1]\n        args = get_args()\n        _, loss_mask, _ = get_ltor_masks_and_position_ids(\n            tokens,\n            self.tokenizer.eod,\n            args.reset_position_ids,\n            args.reset_attention_mask,\n            args.eod_mask_loss)\n\n        if self.fp16_lm_cross_entropy:\n            assert inputs.dtype == torch.half\n            loss = mpu.vocab_parallel_cross_entropy(inputs, target)\n        else:\n            loss = mpu.vocab_parallel_cross_entropy(inputs.float(), target)\n        loss_mask = loss_mask.view(-1)\n        loss_avg = torch.sum(loss.view(-1) * loss_mask) / loss_mask.sum()\n        if loss.dtype == torch.half:\n            loss_avg = loss_avg.half()\n\n        return loss_avg\n\n    def batch_fn(self, batch, is_train:bool):\n        if batch is not None:\n            data = {'text': torch.tensor(batch['text'].numpy())}\n        else:\n            data = None\n\n        keys = ['text']\n        datatype = torch.int64\n\n        data_b = mpu.broadcast_data(keys, data, datatype)\n\n        tokens_ = data_b['text'].long()\n        tokens_write = tokens_\n        labels = tokens_[:, 1:].contiguous()\n        tokens_ = tokens_[:, :-1].contiguous()\n        tokens_2 = torch.unsqueeze(tokens_, 0)\n        data2 = torch.cat((tokens_2, labels[None, :, :]), dim=0)\n        data = []\n        data.append(tokens_)\n        data.append(data2)\n        return data\n\nclass LMLogitsPiped(MegatronModule):\n    def __init__(self, hidden_size, vocab_size, init_method):\n        super().__init__()\n        self.word_embeddings = mpu.VocabParallelEmbedding(\n            vocab_size, hidden_size, init_method=init_method)\n        self.embedding_weight = self.word_embeddings.weight\n\n    def forward(self, lm_output):\n        return parallel_lm_logits(lm_output, self.embedding_weight, True)\n\n\nclass EmbeddingPiped(Embedding):\n    def __init__(self,\n                hidden_size,\n                vocab_size,\n                max_sequence_length,\n                embedding_dropout_prob,\n                init_method,\n                num_tokentypes=0):\n        super().__init__(hidden_size,\n                        vocab_size,\n                        max_sequence_length,\n                        embedding_dropout_prob,\n                        init_method,\n                        num_tokentypes)\n        self.embedding_weight = self.word_embeddings.weight\n\n    def forward(self, inputs):\n        input_ids, position_ids, attention_mask = inputs\n        return super().forward(input_ids, position_ids, None), attention_mask\n\nclass ParallelTransformerLayerPiped(ParallelTransformerLayer):\n    def __init__(self,\n                attention_mask_func,\n                init_method,\n                output_layer_init_method,\n                layer_number):\n        super().__init__(attention_mask_func,\n                         init_method,\n                         output_layer_init_method,\n                         layer_number)\n\n    def forward(self, inputs):\n        hidden_states, attention_mask = inputs\n        return (super().forward(hidden_states, attention_mask),\n                attention_mask)"
  },
  {
    "path": "examples/gpt/initialize.py",
    "content": "import torch\nimport json\nimport veGiantModel\n\nfrom megatron import get_args, mpu\nfrom megatron.fp16 import FP16_Module\nfrom torch.nn.parallel.distributed import DistributedDataParallel as torchDDP\nfrom megatron.model import DistributedDataParallel as LocalDDP\nfrom megatron.model import get_params_for_weight_decay_optimization\nfrom apex.optimizers import FusedAdam as Adam\nfrom megatron.learning_rates import AnnealingLR\nfrom megatron import print_rank_0\n\n\ndef get_learning_rate_scheduler(optimizer, lr_scheduler_builder):\n    \"\"\"Build the learning rate scheduler.\"\"\"\n    args = get_args()\n\n\n    if lr_scheduler_builder is not None:\n        lr_scheduler = lr_scheduler_builder(optimizer)\n    else:\n        # Add linear learning rate scheduler.\n        if args.lr_decay_iters is not None:\n            num_iters = args.lr_decay_iters\n        else:\n            num_iters = args.train_iters\n        num_iters = max(1, num_iters)\n        init_step = 0\n        warmup_iter = args.warmup * num_iters\n        lr_scheduler = AnnealingLR(\n            optimizer,\n            start_lr=args.lr,\n            warmup_iter=warmup_iter,\n            total_iters=num_iters,\n            decay_style=args.lr_decay_style,\n            last_iter=init_step,\n            min_lr=args.min_lr,\n            use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,\n            override_lr_scheduler=args.override_lr_scheduler)\n\n    return lr_scheduler\n\n\ndef get_model(model_provider_func):\n    \"\"\"Build the model.\"\"\"\n    args = get_args()\n\n    # Build model on cpu.\n    model = model_provider_func()\n\n    # Print number of parameters.\n    if mpu.get_data_parallel_rank() == 0:\n        print(' > number of parameters on model parallel rank {}: {}'.format(\n            mpu.get_model_parallel_rank(),\n            sum([p.nelement() for p in model.parameters()])), flush=True)\n\n    # GPU allocation.\n    model.cuda(torch.cuda.current_device())\n\n    return model\n\ndef get_optimizer(model):\n    \"\"\"Set up the optimizer.\"\"\"\n    args = get_args()\n\n    # Build parameter groups (weight decay and non-decay).\n    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):\n        model = model.module\n    param_groups = get_params_for_weight_decay_optimization(model)\n\n    # Add model parallel attribute if it is not set.\n    for param_group in param_groups:\n        for param in param_group['params']:\n            if not hasattr(param, 'model_parallel'):\n                param.model_parallel = False\n\n    if args.cpu_optimizer:\n        if args.cpu_torch_adam:\n            cpu_adam_optimizer = torch.optim.Adam\n        else:\n            from deepspeed.ops.adam import DeepSpeedCPUAdam\n            cpu_adam_optimizer = DeepSpeedCPUAdam\n        optimizer = cpu_adam_optimizer(param_groups,\n                        lr=args.lr, weight_decay=args.weight_decay)\n    else:\n        # Use Adam.\n        optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)\n\n    if args.deepspeed:\n        # fp16 wrapper is not required for DeepSpeed.\n        return optimizer\n\n    # Wrap into fp16 optimizer.\n    if args.fp16:\n        optimizer = FP16_Optimizer(optimizer,\n                                   static_loss_scale=args.loss_scale,\n                                   dynamic_loss_scale=args.dynamic_loss_scale,\n                                   dynamic_loss_args={\n                                       'scale_window': args.loss_scale_window,\n     
                                  'min_scale': args.min_scale,\n                                       'delayed_shift': args.hysteresis},\n                                   fp16_optim=args.fp16_optim)\n\n    return optimizer\n\ndef setup_model_and_optimizer(model, optimizer, train_dataset_provider, lr_scheduler_builder):\n    \"\"\"Setup model and optimizer.\"\"\"\n    args = get_args()\n    if optimizer is None:\n        optimizer = get_optimizer(model)\n    lr_scheduler = get_learning_rate_scheduler(optimizer, lr_scheduler_builder)\n\n    print_rank_0(\"DeepSpeed is enabled.\")\n\n    # Print number of parameters.\n    if mpu.get_data_parallel_rank() == 0:\n        print(' > number of parameters on data parallel rank {}, model parallel rank {}, pipeline parallel rank {}: {}'.format(\n            mpu.get_data_parallel_rank(),\n            mpu.get_model_parallel_rank(),\n            mpu.get_pipe_parallel_rank(),\n            sum([p.nelement() for p in model.parameters()])), flush=True)\n\n    if args.deepspeed_pipeline:\n        print_rank_0(\"Pipeline Parallelism is enabled.\")\n        train_data = train_dataset_provider() if train_dataset_provider is not None else None\n        _param_dict = json.loads(args.config_param)\n        engine, optimizer, _, lr_scheduler = veGiantModel.initialize(\n            model=model,\n            optimizer=optimizer,\n            args=args,\n            lr_scheduler=lr_scheduler,\n            mpu=None,\n            dist_init_required=False,\n            config_params = _param_dict,\n            training_data=train_data\n        )\n        engine.set_batch_fn(model.batch_fn)\n    else:\n        engine, optimizer, _, lr_scheduler = veGiantModel.initialize(\n            model=model,\n            optimizer=optimizer,\n            args=args,\n            lr_scheduler=lr_scheduler,\n            mpu=mpu,\n            dist_init_required=False\n        )\n\n    print_rank_0(\"Model Preparation Done\")\n    args.iteration = 0\n\n    return engine, optimizer, lr_scheduler\n\n\ndef initialize_pipeline(model, optimizer, train_dataset_provider, lr_scheduler_builder=None):\n    return setup_model_and_optimizer(model, optimizer, train_dataset_provider, lr_scheduler_builder)\n\n\ndef initialize_distributed(num_stages, mp_size, distributed_backend='nccl'):\n    veGiantModel.init_distribute(num_stages=num_stages, mp_size=mp_size, distributed_backend=distributed_backend)\n\ndef initialize_megatron(extra_args_provider=None, args_defaults={}):\n    veGiantModel.initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args_defaults)\n"
  },
  {
    "path": "examples/gpt/pretrain_gpt2.py",
    "content": "# Copyright (c) 2021, ByteDance Inc.  All rights reserved.\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n\"\"\"Pretrain GPT2\"\"\"\nimport torch\nimport os\nimport numpy as np\nimport time\nimport sys\n\n_cwd = os.path.dirname(os.path.abspath(__file__))\n_giantModel_dir = os.path.join(_cwd, '../../src')\nsys.path.append(_giantModel_dir)\n\nfrom initialize import initialize_megatron, initialize_pipeline\nfrom gpt_piped import GPTModelPiped\n\nfrom megatron import get_args, mpu\nfrom megatron import get_timers\nfrom megatron import get_tensorboard_writer\nfrom megatron import print_rank_0\nfrom megatron.learning_rates import AnnealingLR\nfrom megatron.training import build_train_valid_test_data_iterators\nfrom megatron.data.gpt2_dataset import get_indexed_dataset_, get_train_valid_test_split_, _num_tokens, _num_epochs, _build_doc_idx, _build_shuffle_idx\nfrom deepspeed.utils import log_dist\n\ndef _build_index_mappings(name, data_prefix, documents, sizes,\n                        num_samples, seq_length, seed):\n    \"\"\"Build doc-idx, sample-idx, and shuffle-idx.\n    doc-idx: is an array (ordered) of documents to be used in training.\n    sample-idx: is the start document index and document offset for each\n    training sample.\n    shuffle-idx: maps the sample index into a random index into sample-idx.\n    \"\"\"\n    log_dist(f' >>>> Entering _build_index_mappings', ranks=[-1])\n    # Number of tokens in each epoch and number of required epochs.\n    args = get_args()\n    tokens_per_epoch = _num_tokens(documents, sizes)\n    num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)\n    # rng state\n    np_rng = np.random.RandomState(seed=seed)\n\n    # Filename of the index mappings.\n    _filename = data_prefix\n    _filename += '_{}_{}_indexmap'.format(args.rank, name)\n    _filename += '_{}ns'.format(num_samples)\n    _filename += '_{}sl'.format(seq_length)\n    _filename += '_{}s'.format(seed)\n    doc_idx_filename = _filename + '_doc_idx.npy'\n    sample_idx_filename = _filename + '_sample_idx.npy'\n    shuffle_idx_filename = _filename + '_shuffle_idx.npy'\n\n    # Build the indexed mapping if not exist.\n    device_count = torch.cuda.device_count()\n    if (not os.path.isfile(doc_idx_filename)) or \\\n    (not os.path.isfile(sample_idx_filename)) or \\\n    (not os.path.isfile(shuffle_idx_filename)):\n\n        log_dist(f' > WARNING: could not find index map files, building '\n                    'the indices ...', ranks=[-1])\n        # doc-idx.\n        start_time = time.time()\n        doc_idx = _build_doc_idx(documents, num_epochs, np_rng)\n        np.save(doc_idx_filename, doc_idx, allow_pickle=True)\n        log_dist(' > elasped time to build and save doc-idx mapping '\n                    '(seconds): {:4f}'.format(time.time() - start_time), ranks=[-1])\n        # sample-idx.\n        start_time = time.time()\n        # Use C++ implementation for speed.\n        # First compile and then import.\n        from megatron.data.dataset_utils import compile_helper\n        compile_helper()\n        from megatron.data import helpers\n        assert doc_idx.dtype == np.int32\n        assert sizes.dtype == np.int32\n        sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length,\n                                            num_epochs, tokens_per_epoch)\n        # sample_idx = _build_sample_idx(sizes, doc_idx, seq_length,\n        #                               num_epochs, tokens_per_epoch)\n        
np.save(sample_idx_filename, sample_idx, allow_pickle=True)\n        log_dist(' > elapsed time to build and save sample-idx mapping '\n                    '(seconds): {:4f}'.format(time.time() - start_time), ranks=[-1])\n        # shuffle-idx.\n        start_time = time.time()\n        # -1 is due to data structure used to retrieve the index:\n        #    sample i --> [sample_idx[i], sample_idx[i+1])\n        shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)\n        np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)\n        log_dist(' > elapsed time to build and save shuffle-idx mapping'\n                    ' (seconds): {:4f}'.format(time.time() - start_time), ranks=[-1])\n\n    # This should be a barrier but nccl barrier assumes\n    # device_index=rank which is not the case for model\n    # parallel case\n    counts = torch.cuda.LongTensor([1])\n    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())\n    assert counts[0].item() == torch.distributed.get_world_size(\n        group=mpu.get_data_parallel_group())\n\n    # Load mappings.\n    start_time = time.time()\n    log_dist(' > loading doc-idx mapping from {}'.format(\n        doc_idx_filename), ranks=[-1])\n\n    if not os.path.isfile(doc_idx_filename):\n        log_dist(' > loading doc-idx mapping from {} failed, file does not exist'.format(\n        doc_idx_filename), ranks=[-1])\n\n    doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r')\n    log_dist(' > loading sample-idx mapping from {}'.format(\n        sample_idx_filename), ranks=[-1])\n    if not os.path.isfile(sample_idx_filename):\n        log_dist(' > loading sample-idx mapping from {} failed, file does not exist'.format(\n        sample_idx_filename), ranks=[-1])\n    sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r')\n    log_dist(' > loading shuffle-idx mapping from {}'.format(\n        shuffle_idx_filename), ranks=[-1])\n    if not os.path.isfile(shuffle_idx_filename):\n        log_dist(' > loading shuffle-idx mapping from {} failed, file does not exist'.format(\n        shuffle_idx_filename), ranks=[-1])\n    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r')\n    log_dist('    loaded indexed file in {:3.3f} seconds'.format(\n        time.time() - start_time), ranks=[-1])\n    log_dist('    total number of samples: {}'.format(\n        sample_idx.shape[0]), ranks=[-1])\n    log_dist('    total number of epochs: {}'.format(num_epochs), ranks=[-1])\n\n    log_dist(f' >>>> exiting _build_index_mappings', ranks=[-1])\n    return doc_idx, sample_idx, shuffle_idx\n\nclass GPT2DatasetFixed(torch.utils.data.Dataset):\n    def __init__(self, name, data_prefix, documents, indexed_dataset,\n                 num_samples, seq_length, seed):\n\n        self.name = name\n        self.indexed_dataset = indexed_dataset\n\n        # Checks\n        assert np.min(documents) >= 0\n        assert np.max(documents) < indexed_dataset.sizes.shape[0]\n\n        # Build index mappings.\n        self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(\n            self.name, data_prefix, documents, self.indexed_dataset.sizes,\n            num_samples, seq_length, seed)\n\n    def __len__(self):\n        # -1 is due to data structure used to retrieve the index:\n        #    sample i --> [sample_idx[i], sample_idx[i+1])\n        return self.sample_idx.shape[0] - 1\n\n    def __getitem__(self, idx):\n        # Get the shuffled index.\n        idx = self.shuffle_idx[idx]\n        # Start and 
end documents and offsets.\n        doc_index_f = self.sample_idx[idx][0]\n        doc_index_l = self.sample_idx[idx + 1][0]\n        offset_f = self.sample_idx[idx][1]\n        offset_l = self.sample_idx[idx + 1][1]\n        # If we are within the same document, just extract the chunk.\n        if doc_index_f == doc_index_l:\n            sample = self.indexed_dataset.get(self.doc_idx[doc_index_f],\n                                              offset=offset_f,\n                                              length=offset_l - offset_f + 1)\n        else:\n            # Otherwise, get the rest of the initial document.\n            sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f],\n                                                    offset=offset_f)]\n            # Loop over all in between documents and add the entire document.\n            for i in range(doc_index_f + 1, doc_index_l):\n                sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))\n            # And finally add the relevant portion of last document.\n            sample_list.append(self.indexed_dataset.get(\n                self.doc_idx[doc_index_l],\n                length=offset_l + 1))\n            sample = np.concatenate(sample_list)\n\n        return {'text': np.array(sample, dtype=np.int64)}\n\n\n\ndef build_train_valid_test_datasets(data_prefix, data_impl, splits_string,\n                                    train_valid_test_num_samples,\n                                    seq_length, seed, skip_warmup):\n    \"\"\"Build train, valid, and test datasets.\"\"\"\n\n    # Indexed dataset.\n    indexed_dataset = get_indexed_dataset_(data_prefix,\n                                           data_impl,\n                                           skip_warmup)\n\n    total_num_of_documents = indexed_dataset.sizes.shape[0]\n    splits = get_train_valid_test_split_(splits_string, total_num_of_documents)\n\n    # Print stats about the splits.\n    print_rank_0(' > dataset split:')\n\n    def print_split_stats(name, index):\n        print_rank_0('    {}:'.format(name))\n        print_rank_0('     document indices in [{}, {}) total of {} '\n                     'documents'.format(splits[index], splits[index + 1],\n                                        splits[index + 1] - splits[index]))\n    print_split_stats('train', 0)\n    print_split_stats('validation', 1)\n    print_split_stats('test', 2)\n\n    def build_dataset(index, name):\n        dataset = None\n        if splits[index + 1] > splits[index]:\n            documents = np.arange(start=splits[index], stop=splits[index + 1],\n                                  step=1, dtype=np.int32)\n            dataset = GPT2DatasetFixed(name, data_prefix,\n                                  documents, indexed_dataset,\n                                  train_valid_test_num_samples[index],\n                                  seq_length, seed)\n        return dataset\n\n    train_dataset = build_dataset(0, 'train')\n    valid_dataset = build_dataset(1, 'valid')\n    test_dataset = build_dataset(2, 'test')\n\n    return (train_dataset, valid_dataset, test_dataset)\n\ndef model_provider():\n    \"\"\"Build the model.\"\"\"\n\n    print_rank_0('building GPT2 model ...')\n    model = GPTModelPiped()\n    return model\n\ndef lr_scheduler_builder(optimizer):\n    \"\"\"Build the learning rate scheduler.\"\"\"\n    args = get_args()\n\n    # Add linear learning rate scheduler.\n    if args.lr_decay_iters is not None:\n        num_iters = args.lr_decay_iters\n    else:\n  
      num_iters = args.train_iters\n    num_iters = max(1, num_iters)\n    init_step = 0\n    warmup_iter = args.warmup * num_iters\n\n    lr_scheduler = AnnealingLR(\n        optimizer,\n        start_lr=args.lr,\n        warmup_iter=warmup_iter,\n        total_iters=num_iters,\n        decay_style=args.lr_decay_style,\n        last_iter=init_step,\n        min_lr=args.min_lr,\n        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,\n        override_lr_scheduler=args.override_lr_scheduler)\n\n    return lr_scheduler\n\n\ndef pretrain(model_provider, args_defaults={}):\n    initialize_megatron(args_defaults=args_defaults)\n    timers = get_timers()\n\n    # Model, optimizer, and learning rate.\n    timers('model and optimizer').start()\n    model = model_provider()\n    engine, optimizer, lr_scheduler = initialize_pipeline(model, None, None, lr_scheduler_builder)\n    timers('model and optimizer').stop()\n\n    # Print setup timing.\n    print_rank_0('done with setups ...')\n    print_rank_0('training ...')\n\n    train(engine, optimizer, lr_scheduler)\n\ndef training_log(loss_dict, iteration):\n    args = get_args()\n    timers = get_timers()\n    writer = get_tensorboard_writer()\n\n    # Logging.\n    timers_to_log = []\n\n    def add_to_logging(name):\n        if name in timers.timers:\n            timers_to_log.append(name)\n    add_to_logging('forward')\n    add_to_logging('backward')\n    add_to_logging('backward-backward')\n    add_to_logging('backward-allreduce')\n    add_to_logging('backward-master-grad')\n    add_to_logging('backward-clip-grad')\n    add_to_logging('optimizer')\n    add_to_logging('batch generator')\n\n    if writer and torch.distributed.get_rank() == 0:\n        writer.add_scalar('loss', loss_dict, iteration)\n        normalizer = iteration % args.log_interval\n        if normalizer == 0:\n            normalizer = args.log_interval\n        timers.write(timers_to_log, writer, iteration,\n                     normalizer=normalizer)\n\ndef train_valid_test_dataset_provider(train_val_test_num_samples):\n    \"\"\"Build train, valid, and test datasets.\"\"\"\n    args = get_args()\n\n    print_rank_0('> building train, validation, and test datasets '\n                 'for GPT ...')\n    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(\n        data_prefix=args.data_path,\n        data_impl=args.data_impl,\n        splits_string=args.split,\n        train_valid_test_num_samples=train_val_test_num_samples,\n        seq_length=args.seq_length,\n        seed=args.seed,\n        skip_warmup=(not args.mmap_warmup))\n    print_rank_0(\"> finished creating GPT datasets ...\")\n\n    return train_ds, valid_ds, test_ds\n\ndef train(engine, optimizer, lr_scheduler):\n    \"\"\"Train the model.\"\"\"\n    args = get_args()\n    timers = get_timers()\n\n    # Turn on training mode which enables dropout.\n    engine.train()\n\n    # Iterations.\n    iteration = args.iteration\n\n    timers('interval time').start()\n\n    train_data_iterator, valid_data_iterator, test_data_iterator \\\n        = build_train_valid_test_data_iterators(train_valid_test_dataset_provider)\n\n    log_dist(f' >>>> start training', ranks=[-1])\n    while iteration < args.train_iters:\n        engine.train_batch(train_data_iterator)\n        # Each train_batch() call consumes one global batch; advance the count.\n        iteration += 1\n\nif __name__ == \"__main__\":\n    pretrain(model_provider,\n             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})\n
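\n# Note: this script is normally launched via torch.distributed.launch with the\n# flags set in examples/gpt/pretrain_gpt2_distributed.sh (model parallelism,\n# pipeline stages, DeepSpeed config, etc.); see that script for a full invocation.\n"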
  },
  {
    "path": "examples/gpt/pretrain_gpt2_distributed.sh",
    "content": "#! /bin/bash\n# Runs the \"345M\" parameter model\n\nDATA_PATH=<Specify path where >\nCHECKPOINT_PATH=<Specify path>\n\nexport WORKER_0_HOST=127.0.0.1\nexport DMLC_NODE_HOST=127.0.0.1\nexport WORKER_0_PORT=6000\nexport NUM_WORKER=1\nexport WORKER_RANK=0\nexport GPU_PER_WORKER=8\n\nexport BYTEPS_WITH_UCX=0 \nexport DMLC_ENABLE_UCX=0\nexport DMLC_ENABLE_RDMA=0\n\nMASTER_PORT=6002\nMASTER_ADDR=$WORKER_0_HOST\n\nGPUS_PER_NODE=$GPU_PER_WORKER\n\nNNODES=$NUM_WORKER\nNODE_RANK=$WORKER_RANK\n\nWORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))\n\nbase_dir=$(cd `dirname $0`; pwd)\necho base_dir $base_dir\n\nDISTRIBUTED_ARGS=\"--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT\"\n\nds_config='{\n    \"train_micro_batch_size_per_gpu\":16,\n    \"train_batch_size\" : 16,\n    \"gradient_accumulation_steps\": 2,\n    \"steps_per_print\": 1,\n    \"gradient_clipping\": 1.0,\n    \"zero_optimization\": {\n      \"stage\": 0,\n      \"allgather_partitions\": true,\n      \"allgather_bucket_size\": 500000000,\n      \"overlap_comm\": true,\n      \"reduce_scatter\": true,\n      \"reduce_bucket_size\": 500000000,\n      \"contiguous_gradients\" : true,\n      \"cpu_offload\": false\n    },\n    \"fp16\": {\n      \"enabled\": true,\n      \"loss_scale\": 0,\n      \"loss_scale_window\": 1000,\n      \"hysteresis\": 2,\n      \"min_loss_scale\": 1\n    },\n    \"wall_clock_breakdown\": true\n}'\n\npython3 -m torch.distributed.launch $DISTRIBUTED_ARGS \\\n       --no_python --use_env python3 \\\n       ${base_dir}/pretrain_gpt2.py \\\n       --model-parallel-size 2 \\\n       --num-stages 2 \\\n       --num-layers 24 \\\n       --hidden-size 1024 \\\n       --train-batch-size 64 \\\n       --gradient_accumulation_steps 16 \\\n       --num-attention-heads 16 \\\n       --batch-size 4 \\\n       --seq-length 1024 \\\n       --max-position-embeddings 1024 \\\n       --train-iters 500000 \\\n       --lr-decay-iters 450000 \\\n       --save $CHECKPOINT_PATH \\\n       --load $CHECKPOINT_PATH \\\n       --data-path $DATA_PATH/openwebtext-gpt2_text_document \\\n       --vocab-file $DATA_PATH/gpt2-vocab.json \\\n       --merge-file $DATA_PATH/gpt2-merges.txt \\\n       --data-impl mmap \\\n       --split 949,50,1 \\\n       --distributed-backend nccl \\\n       --lr 0.00025 \\\n       --lr-decay-style cosine \\\n       --min-lr 1.0e-5 \\\n       --weight-decay 1e-2 \\\n       --clip-grad 1.0 \\\n       --warmup .02 \\\n       --log-interval 1 \\\n       --save-interval 100000 \\\n       --vocab-size 145608 \\\n       --DDP-impl torch \\\n       --eod-mask-loss \\\n       --deepspeed-pipeline \\\n       --deepspeed \\\n       --config_param \"$ds_config\" \\\n       --fp16 \\\n       --partition_method \"type:ParallelTransformerLayerPiped\" \\\n       $@\nset +x\n"
  },
  {
    "path": "src/veGiantModel/__init__.py",
    "content": "# Copyright (c) 2021, ByteDance Inc.  All rights reserved.\nimport sys\nimport os\n\ncwd = os.path.dirname(os.path.abspath(__file__))\n_deepspeed_dir = os.path.join(cwd, '../../third_party/deepspeed')\n_megatron_dir = os.path.join(cwd, '../../third_party/megatron')\nsys.path.append(cwd)\nsys.path.append(_deepspeed_dir)\nsys.path.append(_megatron_dir)\n\nfrom . import patcher\nfrom .engine.engine import VeGiantModelEngine\nfrom .initialize import initialize_megatron, init_distribute\nfrom .distributed import *\n\ndef initialize(args,\n               model,\n               optimizer=None,\n               model_parameters=None,\n               training_data=None,\n               lr_scheduler=None,\n               mpu=None,\n               dist_init_required=None,\n               collate_fn=None,\n               config_params=None):\n    engine = VeGiantModelEngine(args=args,\n                    model=model,\n                    optimizer=optimizer,\n                    model_parameters=model_parameters,\n                    training_data=training_data,\n                    lr_scheduler=lr_scheduler,\n                    mpu=model.mpu(),\n                    dist_init_required=dist_init_required,\n                    collate_fn=collate_fn,\n                    config_params=config_params)\n\n    return_items = [\n        engine,\n        engine.optimizer,\n        engine.training_dataloader,\n        engine.lr_scheduler\n    ]\n    return tuple(return_items)\n"
  },
  {
    "path": "src/veGiantModel/distributed/__init__.py",
    "content": "from .. import patcher as dist\nfrom megatron import mpu\n\ndef get_model_parallel_world_size():\n    return dist.get_model_parallel_world_size()\n\ndef get_model_parallel_rank():\n    return dist.get_model_parallel_rank()\n\ndef get_data_parallel_world_size():\n    return dist.get_data_parallel_world_size()\n\ndef get_model_parallel_group():\n    return dist.get_model_parallel_group()\n\ndef get_grid():\n    return dist.get_grid()\n\ndef copy_to_model_parallel_region(input_):\n    return mpu.copy_to_model_parallel_region(input_)\n\ndef reduce_from_model_parallel_region(input_):\n    return mpu.reduce_from_model_parallel_region(input_)\n\ndef gather_from_model_parallel_region(input_):\n    return mpu.gather_from_model_parallel_region(input_)\n"
  },
  {
    "path": "src/veGiantModel/engine/engine.py",
    "content": "# Copyright (c) 2021, ByteDance Inc.  All rights reserved.\r\n# Copyright 2019 The Microsoft DeepSpeed Team\r\nimport os\r\n\r\nfrom types import MethodType\r\n\r\nimport torch\r\n\r\nimport torch.distributed as dist\r\n\r\nfrom deepspeed.utils.logging import logger\r\nfrom deepspeed.utils.timer import ThroughputTimer\r\n\r\nfrom deepspeed.runtime.engine import MEMORY_OPT_ALLREDUCE_SIZE\r\nfrom deepspeed.runtime.dataloader import RepeatingLoader\r\n\r\nfrom deepspeed.runtime.pipe.module import PipelineModule, PipelineError\r\nfrom deepspeed.runtime.pipe.engine import PipelineEngine\r\nfrom . import p2p\r\nfrom . import schedule\r\ntry:\r\n    import byteps.torch as bps\r\nexcept ImportError:\r\n    print(\"byteps is not installed. Pipeline parallelism is disabled\")\r\n    bps = None\r\n\r\nfrom .module import VeGiantModule\r\nfrom deepspeed.utils import log_dist\r\nimport logging\r\nfrom torch._six import inf\r\n\r\n# from inspect import signature\r\n\r\nLOG_STAGE = -2\r\nDATA_PARALLEL_ID = -2\r\n\r\ntry:\r\n    from apex import amp\r\nexcept ImportError:\r\n    # Fail silently so we don't spam logs unnecessarily if user isn't using amp\r\n    pass\r\n\r\n\r\ndef is_even(number):\r\n    return number % 2 == 0\r\n\r\nENABLE_PYTORCH_BROADCAST = os.environ.get(\"ENABLE_PYTORCH_BROADCAST\", \"0\") != \"0\"\r\n\r\n\r\n\r\nDS_PIPE_VERBOSE = int(os.environ.get('DS_PIPE_VERBOSE', \"0\"))\r\nMEGATRON_DEBUG_DATA = os.environ.get('MEGATRON_DEBUG_DATA', \"0\") != \"0\"\r\nMEGATRON_DEBUG_GRAD = os.environ.get('MEGATRON_DEBUG_GRAD', \"0\") != \"0\"\r\nENABLE_BPS_PARTITION = os.environ.get(\"ENABLE_BPS_PARTITION\", \"0\") != \"0\"\r\n\r\n\r\ndef _tensor_bytes(tensor):\r\n    return tensor.numel() * tensor.element_size()\r\n\r\ndef _dtype_to_code(dtype):\r\n    if dtype == torch.half:\r\n        return 0\r\n    elif dtype == torch.float:\r\n        return 1\r\n    elif dtype == torch.int16:\r\n        return 2\r\n    elif dtype == torch.int32:\r\n        return 3\r\n    elif dtype == torch.int64:\r\n        return 4\r\n    elif dtype == torch.bool:\r\n        return 5\r\n    else:\r\n        raise AssertionError(\"not recognized tensor type for pipeline send\")\r\n\r\ndef _code_to_dtype(code):\r\n    if code == 0:\r\n        return torch.half\r\n    elif code == 1:\r\n        return torch.float\r\n    elif code == 2:\r\n        return torch.int16\r\n    elif code == 3:\r\n        return torch.int32\r\n    elif code == 4:\r\n        return torch.int64\r\n    elif code == 5:\r\n        return torch.bool\r\n    else:\r\n        raise AssertionError(\"not recognized tensor type code for pipeline recv\")\r\n\r\nclass VeGiantModelEngine(PipelineEngine):\r\n    \"\"\" A training engine hybrid pipeline, data, and model parallel training.\r\n\r\n    This engine is created by ``deepspeed.initialize()`` when a :class:`PipelineModule`\r\n    is provided.\r\n    \"\"\"\r\n    def overwrite(self, config_params, args):\r\n        if args.batch_size is not None:\r\n            log_dist(f'overwrite dsconfig train_micro_batch_size_per_gpu to {args.batch_size}', \\\r\n                ranks=[-1], level=logging.DEBUG)\r\n            config_params['train_micro_batch_size_per_gpu'] = args.batch_size\r\n        \r\n        if args.gradient_accumulation_steps is not None:\r\n            log_dist(f'overwrite dsconfig gradient_accumulation_steps to {args.gradient_accumulation_steps}', \\\r\n                ranks=[-1], level=logging.DEBUG)\r\n            config_params['gradient_accumulation_steps'] = 
args.gradient_accumulation_steps\r\n\r\n        if args.train_batch_size is not None:\r\n            log_dist(f'overwrite dsconfig train_batch_size to {args.train_batch_size}, ', \\\r\n                ranks=[-1], level=logging.DEBUG)\r\n            config_params['train_batch_size'] = args.train_batch_size\r\n\r\n        if args.log_interval is not None:\r\n            config_params['steps_per_print'] = args.log_interval\r\n\r\n    def __init__(self, args,\r\n                    model,\r\n                    optimizer,\r\n                    model_parameters,\r\n                    training_data,\r\n                    lr_scheduler,\r\n                    mpu,\r\n                    dist_init_required,\r\n                    collate_fn,\r\n                    config_params):\r\n        \r\n        self.overwrite(config_params, args)\r\n        super(PipelineEngine, self).__init__(args,\r\n                    model,\r\n                    optimizer,\r\n                    model_parameters,\r\n                    training_data,\r\n                    lr_scheduler,\r\n                    mpu,\r\n                    dist_init_required,\r\n                    collate_fn,\r\n                    config_params)\r\n        assert isinstance(self.module, PipelineModule), \"model must base PipelineModule\"\r\n\r\n        # pipeline step for logging\r\n        self.args = args\r\n        self.log_batch_step_id = -1\r\n        self.train_mode = True\r\n\r\n        self.enable_backward_allreduce = False\r\n        self.micro_batch_size = self.train_micro_batch_size_per_gpu()\r\n        self.micro_batches = self.gradient_accumulation_steps()\r\n        self.first_train = True\r\n        self.first_eval = True\r\n\r\n        # Set Grid and Communication Groups\r\n        self.grid = self.module._grid\r\n        if self.grid.get_global_rank() == 0:\r\n            logger.info(f'CONFIG: micro_batches={self.micro_batches} '\r\n                        f'micro_batch_size={self.micro_batch_size}')\r\n\r\n        self.global_rank = self.grid.get_global_rank()\r\n\r\n        assert self.dp_world_size == self.grid.data_parallel_size\r\n        assert self.train_batch_size() == \\\r\n            self.micro_batch_size * self.micro_batches * self.grid.data_parallel_size\r\n\r\n        #  Set Stage Inf\r\n        self.num_stages = self.grid.pipe_parallel_size\r\n        self.stage_id = self.grid.get_stage_id()\r\n        self.mp_id = self.grid.get_model_parallel_id()\r\n        self.prev_stage = self.stage_id - 1\r\n        self.next_stage = self.stage_id + 1\r\n\r\n        self.data_iterator = None\r\n        self.batch_fn = None\r\n        self.result_dict = {}\r\n\r\n        self._force_grad_boundary = False\r\n\r\n        self.batch_timer = ThroughputTimer(batch_size=self.micro_batch_size *\r\n                                           self.micro_batches,\r\n                                           num_workers=self.dp_world_size,\r\n                                           logging_fn=self.tput_log,\r\n                                           monitor_memory=False,\r\n                                           steps_per_output=self.steps_per_print())\r\n\r\n        # PipelineEngine needs to handle data loading specially due to only the first\r\n        # and last stages loading inputs/labels. 
\r\n        if self.training_data:\r\n            self._build_data_iter(self.training_data)\r\n\r\n        self.is_pipe_parallel = self.grid.pipe_parallel_size > 1\r\n        self.is_data_parallel = self.grid.data_parallel_size > 1\r\n        self.is_model_parallel = self.grid.model_parallel_size > 1\r\n\r\n        # Partition input/output buffers\r\n        self.is_pipe_partitioned = False if self.args.broadcast_activation else (self.is_model_parallel and ENABLE_PYTORCH_BROADCAST)\r\n        self.is_grad_partitioned = False\r\n\r\n        model_parameters = filter(lambda p: p.requires_grad, self.module.parameters())\r\n        num_params = sum([p.numel() for p in model_parameters])\r\n        unique_params = num_params\r\n        # Subtract tied parameters if we don't own them\r\n        if self.module.tied_comms:\r\n            tied_params = 0\r\n            for key, d in self.module.tied_comms.items():\r\n                if self.global_rank != min(d['ranks']):\r\n                    tied_params += sum(p.numel() for p in d['module'].parameters())\r\n            unique_params -= tied_params\r\n        params_tensor = torch.LongTensor(data=[num_params,\r\n                                               unique_params]).to(self.device)\r\n        print('Calculating param sizes ...', flush=True)\r\n\r\n        dist.all_reduce(params_tensor, group=self.grid.get_model_parallel_group())\r\n        params_tensor = params_tensor.tolist()\r\n        total_params = params_tensor[0]\r\n        unique_params = params_tensor[1]\r\n        if self.grid.data_parallel_id == 0:\r\n            logger.info(f'RANK={self.global_rank} '\r\n                        f'STAGE={self.stage_id} '\r\n                        f'LAYERS={self.module._local_stop - self.module._local_start} '\r\n                        f'[{self.module._local_start}, {self.module._local_stop}) '\r\n                        f'STAGE_PARAMS={num_params} ({num_params/1e6:0.3f}M) '\r\n                        f'TOTAL_PARAMS={total_params} ({total_params/1e6:0.3f}M) '\r\n                        f'UNIQUE_PARAMS={unique_params} ({unique_params/1e6:0.3f}M)')\r\n\r\n        print('DONE calculating param sizes. Now init proc groups', flush=True)
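\r\n\r\n        # Example of the tied-parameter accounting above: if, say, a 50M-parameter\r\n        # embedding is tied between the first and last stage, only the owning rank\r\n        # (the minimum rank in its tie group) counts it, so UNIQUE_PARAMS drops by\r\n        # 50M on the other rank while STAGE_PARAMS and TOTAL_PARAMS still include it.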
\r\n\r\n        # Initialize peer-to-peer communication and allreduce groups\r\n        if self.is_pipe_parallel:\r\n            p2p.init_process_groups(self.grid)\r\n\r\n        # Pipeline buffers\r\n        self.num_pipe_buffers = 0\r\n        self.pipe_buffers = {\r\n            'inputs' : [],   # batch input and received activations\r\n            'labels' : [],   # labels from batch input\r\n            'outputs' : [],  # activations\r\n            'output_tensors' : [], # tensor object to preserve backward graph\r\n            'bps_act_recv' : [],  # activations recv\r\n            'bps_grad_recv' : [],  # gradients recv\r\n        }\r\n        self.pipe_recv_buf = None\r\n        self.grad_layer = None\r\n\r\n        self.meta_buffer = None\r\n\r\n        self.first_output_send = True\r\n        self.first_gradient_send = True\r\n\r\n        # stores the loss for the current micro batch being processed\r\n        self.loss = torch.tensor(0.0).to(self.device)\r\n        self.metric = 0\r\n\r\n        # stores the loss for the entire batch\r\n        self.total_loss = None\r\n        self.agg_loss = torch.tensor(0.0, requires_grad=False).to(self.device)\r\n        self.dp_group_loss = torch.tensor(0.0, requires_grad=False).to(self.device)\r\n\r\n        if self._config.pipeline['activation_checkpoint_interval'] > 0:\r\n            self.module.activation_checkpoint_interval = self._config.pipeline[\r\n                'activation_checkpoint_interval']\r\n\r\n        if self.is_last_stage():\r\n            self.loss_model = self.module.loss_fn\r\n\r\n        log_dist('Initialize pipeline communicators', \\\r\n            ranks=[-1], level=logging.DEBUG)\r\n\r\n        # Initialize pipeline communicators by exchanging a dummy 0. Even stages\r\n        # send first and odd stages receive first, so two neighboring stages are\r\n        # never both blocked on a send to each other (deadlock avoidance).\r\n        if is_even(self.stage_id):\r\n            if not self.is_last_stage():\r\n                p2p.send(self.loss, self.next_stage)\r\n            if not self.is_first_stage():\r\n                p2p.recv(self.loss, self.prev_stage)\r\n        else:\r\n            if not self.is_first_stage():\r\n                p2p.recv(self.loss, self.prev_stage)\r\n            if not self.is_last_stage():\r\n                p2p.send(self.loss, self.next_stage)\r\n\r\n        log_dist('DONE initializing pipeline communicators', \\\r\n            ranks=[-1], level=logging.DEBUG)\r\n\r\n        # XXX look into timer reporting timing\r\n        # Initialize some timers because of early weirdness.\r\n        if self.wall_clock_breakdown():\r\n            self.timers('forward_microstep').start()\r\n            self.timers('forward_microstep').stop()\r\n            self.timers('backward_microstep').start()\r\n            self.timers('backward_microstep').stop()\r\n            self.timers('backward_inner_microstep').start()\r\n            self.timers('backward_inner_microstep').stop()\r\n            self.timers('backward_allreduce_microstep').start()\r\n            self.timers('backward_allreduce_microstep').stop()\r\n            self.timers('backward_allreduce').start()\r\n            self.timers('backward_allreduce').stop()\r\n            self.timers('step_microstep').start()\r\n            self.timers('step_microstep').stop()\r\n\r\n        if self.local_rank == -1:\r\n            # using the number of visible devices would be better here\r\n            self.local_rank = self.global_rank % torch.cuda.device_count()
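\r\n\r\n        # When the BytePS transport is in use (ENABLE_PYTORCH_BROADCAST unset),\r\n        # the engine bootstraps an in-process BytePS \"joint\" worker below: the\r\n        # BYTEPS_*/DMLC_* environment variables identify this GPU within the\r\n        # node and the job before bps.init() is called. GPU_PER_WORKER is\r\n        # expected to be provided by the launcher.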
\r\n\r\n        if not p2p.ENABLE_PYTORCH_BROADCAST:\r\n            gpu_per_node = int(os.environ['GPU_PER_WORKER'])\r\n            print(f'bps init worker: {gpu_per_node}, {self.local_rank}/{self.global_rank}', flush=True)\r\n            os.environ['BYTEPS_LOCAL_RANK'] = str(self.local_rank)\r\n            os.environ['BYTEPS_LOCAL_SIZE'] = str(gpu_per_node)\r\n            os.environ['BYTEPS_VISIBLE_DEVICE'] = str(self.local_rank)\r\n            os.environ['DMLC_ROLE'] = 'joint'\r\n            os.environ['DMLC_WORKER_ID'] = str(self.global_rank)\r\n            bps.init(lazy=False)\r\n            print('bps init DONE', flush=True)\r\n\r\n    def _profiling_func_exit(self):\r\n        torch.cuda.nvtx.range_pop()\r\n\r\n    def _profiling_func_enter(self, func):\r\n        torch.cuda.nvtx.range_push(f'stage_id: {self.stage_id}, mp_id: {self.mp_id}, fun: {func}')\r\n\r\n    def _build_data_iter(self, dataset):\r\n        if not isinstance(dataset, torch.utils.data.Dataset):\r\n            self.set_dataloader(dataset)\r\n        else:\r\n            sampler = torch.utils.data.distributed.DistributedSampler(\r\n                dataset,\r\n                num_replicas=self.dp_world_size,\r\n                rank=self.mpu.get_data_parallel_rank(),\r\n                shuffle=False)\r\n            # Build a loader and make it repeating.\r\n            pipe_dataloader = self.deepspeed_io(dataset, data_sampler=sampler)\r\n            pipe_dataloader = RepeatingLoader(pipe_dataloader)\r\n            self.set_dataloader(pipe_dataloader)\r\n\r\n    def _exec_reduce_tied_grads(self):\r\n        self._profiling_func_enter('_exec_reduce_tied_grads')\r\n        self.module.allreduce_tied_weight_gradients()\r\n        self._profiling_func_exit()\r\n\r\n    def _exec_reduce_grads(self):\r\n        self._profiling_func_enter('_exec_reduce_grads')\r\n        self._force_grad_boundary = True\r\n        if self.is_data_parallel:\r\n            self.buffered_allreduce_fallback(\r\n                elements_per_buffer=MEMORY_OPT_ALLREDUCE_SIZE)\r\n        self._force_grad_boundary = False\r\n        self._profiling_func_exit()\r\n\r\n    def _reserve_pipe_buffers(self, num_buffers):\r\n        \"\"\"Ensure that each pipeline buffer has at least ``num_buffers`` slots.\r\n\r\n        This method only reserves slots and does not allocate tensors.\r\n\r\n        Args:\r\n            num_buffers (int): The number of buffers to reserve.\r\n        \"\"\"\r\n        if self.num_pipe_buffers >= num_buffers:\r\n            return\r\n\r\n        num_added = num_buffers - self.num_pipe_buffers\r\n        for key in self.pipe_buffers:\r\n            self.pipe_buffers[key].extend([None] * num_added)\r\n        self.num_pipe_buffers = num_buffers\r\n\r\n    def train_batch(self, data_iter=None):\r\n        \"\"\"Progress the pipeline to train the next batch of data. The engine will ingest\r\n        ``self.train_batch_size()`` total samples collectively across all workers.\r\n\r\n        An iterator over the training data should be provided as an argument\r\n        unless ``deepspeed.initialize()`` was provided a training set. In that event,\r\n        the training data will automatically be read.\r\n\r\n        .. warning::\r\n            A total of ``self.gradient_accumulation_steps()`` entries will be pulled\r\n            from ``data_iter`` by each pipeline. 
There must be sufficient\r\n            data left in ``data_iter`` or else a ``StopIteration`` will halt training.\r\n\r\n            DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`\r\n            that wraps data loaders to automatically restart upon a ``StopIteration``.\r\n\r\n        Args:\r\n            data_iter (Iterator, optional): Iterator of training data.\r\n\r\n        Returns:\r\n            The arithmetic mean of the losses computed this batch.\r\n        \"\"\"\r\n\r\n        if DS_PIPE_VERBOSE:\r\n            print(f'[{self.global_rank}] start train_batch()', flush=True)\r\n        if not torch._C.is_grad_enabled():\r\n            raise RuntimeError(\r\n                f'train_batch() requires gradients enabled. Use eval_batch() instead.')\r\n\r\n        if data_iter is not None:\r\n            self.set_dataiterator(data_iter)\r\n\r\n        self.module.train()\r\n        self.train()\r\n        self.total_loss = None\r\n\r\n        # Do the work\r\n        self.timers('train_batch').start()\r\n        # We only enable prefetching starting from the second batch\r\n        if not ENABLE_PYTORCH_BROADCAST:\r\n            sched = schedule.BytePSTrainSchedule(micro_batches=self.micro_batches,\r\n                                                stages=self.num_stages,\r\n                                                stage_id=self.stage_id, prefetch=not self.first_train)\r\n        else:\r\n            sched = schedule.TrainSchedule(micro_batches=self.micro_batches,\r\n                                       stages=self.num_stages,\r\n                                       stage_id=self.stage_id)\r\n        cmd = ','.join(str(x) for x in sched)\r\n        # log_dist(f'stage_id: {self.stage_id}, sched:{cmd}', ranks=[-1], level=logging.INFO)\r\n        self._exec_schedule(sched)\r\n        self.agg_train_loss = self._aggregate_total_loss()\r\n        self.timers('train_batch').stop()\r\n\r\n        if self.global_steps % self.steps_per_print() == 0:\r\n            if self.global_rank == 0:\r\n                elapsed = self.timers('train_batch').elapsed(reset=True)\r\n                iter_time = elapsed / self.steps_per_print()\r\n                tput = self.train_batch_size() / iter_time\r\n                print(f'steps: {self.global_steps} '\r\n                      f'loss: {self.agg_train_loss:0.4f} '\r\n                      f'iter time (s): {iter_time:0.3f} '\r\n                      f'samples/sec: {tput:0.3f}')\r\n\r\n        # Tensorboard\r\n        if self.tensorboard_enabled():\r\n            if self.global_rank == 0:\r\n                self.summary_events = [(f'Train/Samples/train_loss',\r\n                                        self.agg_train_loss.mean().item(),\r\n                                        self.global_samples)]\r\n                for event in self.summary_events:  # write_summary_events\r\n                    self.summary_writer.add_scalar(event[0], event[1], event[2])\r\n                if self.global_steps % self.steps_per_print() == 0:\r\n                    self.summary_writer.flush()\r\n\r\n        if self.wall_clock_breakdown(\r\n        ) and self.global_steps % self.steps_per_print() == 0:\r\n            self.timers.log([\r\n                'pipe_send_output',\r\n                'pipe_send_grad',\r\n                'pipe_recv_input',\r\n                'pipe_recv_grad'\r\n            ])\r\n\r\n        # TODO: should return precisely what loss returned and allow others to be queried?\r\n        self.first_train = 
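False\r\n\r\n        # For reference, a minimal driver loop might look like the sketch below\r\n        # (kept inert like the other debugging snippets in this file; assumes\r\n        # `engine` came from deepspeed.initialize() with a PipelineModule and\r\n        # `loader` is a RepeatingLoader over the training set):\r\n        '''\r\n        data_iter = iter(loader)\r\n        for _ in range(num_steps):\r\n            results = engine.train_batch(data_iter=data_iter)\r\n            print(results['loss'])\r\n        '''\r\n        # Prefetching kicks in on later calls once self.first_train is 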
False\r\n        if DS_PIPE_VERBOSE:\r\n            print(f'[{self.global_rank}] DONE train_batch()', flush=True)\r\n        \r\n        self.result_dict['loss'] = self.agg_train_loss\r\n        return self.result_dict\r\n\r\n    def eval_batch(self, data_iter):\r\n        \"\"\"Evaluate the pipeline on a batch of data from ``data_iter``. The\r\n        engine will evaluate ``self.train_batch_size()`` total samples\r\n        collectively across all workers.\r\n\r\n        This method is equivalent to:\r\n\r\n        .. code-block:: python\r\n\r\n            module.eval()\r\n            with torch.no_grad():\r\n                output = module(batch)\r\n\r\n        .. warning::\r\n            A total of ``self.gradient_accumulation_steps()`` entries will be pulled\r\n            from ``data_iter`` by each pipeline. There must be sufficient\r\n            data left in ``data_iter`` or else a ``StopIteration`` will halt training.\r\n\r\n            DeepSpeed provides a convenience class :class:`deepspeed.utils.RepeatingLoader`\r\n            that wraps data loaders to automatically restart upon a ``StopIteration``.\r\n\r\n        Args:\r\n            data_iter (Iterator): Iterator of data to evaluate.\r\n\r\n        Returns:\r\n            The arithmetic mean of the losses computed this batch.\r\n        \"\"\"\r\n\r\n        self.module.eval()\r\n        self.eval()\r\n        self.total_loss = None\r\n\r\n        # Use the provided data iterator\r\n        train_iterator = self.data_iterator\r\n        self.set_dataiterator(data_iter)\r\n\r\n        # Do the work\r\n        self.timers('eval_batch').start()\r\n        if not ENABLE_PYTORCH_BROADCAST:\r\n            sched = schedule.BytePSInferenceSchedule(micro_batches=1,\r\n                                           stages=self.num_stages,\r\n                                           stage_id=self.stage_id, prefetch=False)\r\n        else:\r\n            sched = schedule.InferenceSchedule(micro_batches=self.micro_batches,\r\n                                           stages=self.num_stages,\r\n                                           stage_id=self.stage_id)\r\n        with torch.no_grad():\r\n            self._exec_schedule(sched)\r\n\r\n        self.agg_eval_loss = self._aggregate_total_loss()\r\n        self.timers('eval_batch').stop()\r\n        # # XXX hack model attribute\r\n        # if hasattr(self.module, '_get_metrics'):\r\n        #     self.module._ref_model[0].metric = {'pscc': self._aggregate_metric()}\r\n\r\n        # if self.global_rank == 0:\r\n        #     elapsed = self.timers('eval_batch').elapsed(reset=True)\r\n        #     iter_time = elapsed\r\n        #     print(f'loss: {self.agg_eval_loss:0.4f} '\r\n        #             f'iter time (s): {iter_time:0.3f} ')\r\n\r\n        if self.tensorboard_enabled():\r\n            if self.global_rank == 0:\r\n                self.summary_events = [(f'Train/Samples/eval_loss',\r\n                                        self.agg_eval_loss.mean().item(),\r\n                                        self.global_samples)]\r\n                for event in self.summary_events:  # write_summary_events\r\n                    self.summary_writer.add_scalar(event[0], event[1], event[2])\r\n                self.summary_writer.flush()\r\n\r\n        # Restore the training iterator\r\n        self.set_dataiterator(train_iterator)\r\n\r\n        # Reset any buffers that may have been populated during the forward passes.\r\n        #ds_checkpointing.reset()\r\n        self.first_eval = 
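False\r\n\r\n        # Note that the BytePS schedule above evaluates a single micro-batch\r\n        # (micro_batches=1) per call, while the PyTorch-broadcast path runs the\r\n        # full self.micro_batches; self.first_eval is now 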
False\r\n        self.result_dict['loss'] = self.agg_eval_loss\r\n        return self.result_dict\r\n\r\n    def is_first_stage(self):\r\n        \"\"\"True if this process is in the first stage in the pipeline.\"\"\"\r\n        return self.stage_id == 0\r\n\r\n    def is_last_stage(self):\r\n        \"\"\"True if this process is in the last stage in the pipeline.\"\"\"\r\n        return self.stage_id == self.num_stages - 1\r\n\r\n    def _aggregate_metric(self):\r\n        # Scale loss, average among DP ranks, and bcast loss to the rest of my DP group\r\n        if self.is_last_stage():\r\n            if DS_PIPE_VERBOSE:\r\n                print(f'[{self.global_rank}] bcast src={self.global_rank} group={self.grid.pp_group}', flush=True)\r\n            if self.is_data_parallel:\r\n                assert False\r\n\r\n            assert self.global_rank in self.grid.pp_group\r\n            metric = torch.Tensor([self.metric]).to(self.device)\r\n            dist.broadcast(tensor=metric,\r\n                           src=self.global_rank,\r\n                           group=self.mpu.get_pipe_parallel_group())\r\n\r\n        else:\r\n            # Get loss from last stage\r\n            src_rank = self.grid.stage_to_global(self.num_stages - 1)\r\n            if DS_PIPE_VERBOSE:\r\n                print(f'[{self.global_rank}] bcast src={src_rank} group={self.grid.pp_group}', flush=True)\r\n            assert src_rank in self.grid.pp_group\r\n            metric = torch.Tensor([0.]).to(self.device)\r\n            dist.broadcast(tensor=metric,\r\n                           src=src_rank,\r\n                           group=self.grid.get_pipe_parallel_group())\r\n            self.metric = metric.clone().detach().cpu().numpy()\r\n\r\n        return self.metric\r\n\r\n    def _aggregate_total_loss(self):\r\n        # Scale loss, average among DP ranks, and bcast loss to the rest of my DP group\r\n        if self.is_last_stage():\r\n            # XXX Hack: do not scale loss\r\n            loss = self._scale_loss(self.total_loss)\r\n\r\n            self.dp_group_loss = loss.clone().detach()\r\n\r\n            ## Average loss across all data-parallel groups\r\n            agg_loss = self.dp_group_loss.clone().detach()\r\n\r\n            if DS_PIPE_VERBOSE:\r\n                print(f'[{self.global_rank}] bcast SENDER src={self.global_rank} group={self.grid.pp_group}', flush=True)\r\n            if self.is_data_parallel:\r\n                dist.all_reduce(agg_loss, group=self.mpu.get_data_parallel_group())\r\n                agg_loss /= self.dp_world_size\r\n\r\n            assert self.global_rank in self.grid.pp_group\r\n            losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)\r\n            dist.broadcast(tensor=losses,\r\n                           src=self.global_rank,\r\n                           group=self.mpu.get_pipe_parallel_group())\r\n\r\n        else:\r\n            # Get loss from last stage\r\n            src_rank = self.grid.stage_to_global(self.num_stages - 1)\r\n            assert src_rank in self.grid.pp_group\r\n            losses = torch.Tensor([0., 0.]).to(self.device)\r\n            if DS_PIPE_VERBOSE:\r\n                print(f'[{self.global_rank}] bcast RECVER src={src_rank} group={self.grid.pp_group}', flush=True)\r\n            dist.broadcast(tensor=losses,\r\n                           src=src_rank,\r\n                           group=self.grid.get_pipe_parallel_group())\r\n            self.dp_group_loss = losses[0].clone().detach()\r\n            
agg_loss = losses[1].clone().detach()\r\n        if DS_PIPE_VERBOSE:\r\n            print(f'DONE aggregate total loss', flush=True)\r\n        return agg_loss\r\n\r\n    def set_dataloader(self, loader):\r\n        \"\"\"\"\"\"\r\n        if self.is_first_stage() or self.is_last_stage():\r\n            self.training_dataloader = loader\r\n            self.data_iterator = iter(self.training_dataloader)\r\n\r\n    def set_dataiterator(self, iterator):\r\n        \"\"\" Store an iterator to sample for training data. \"\"\"\r\n        if self.is_first_stage() or self.is_last_stage():\r\n            self.training_dataloader = None\r\n            self.data_iterator = iterator\r\n\r\n    def set_batch_fn(self, fn):\r\n        self.batch_fn = fn\r\n        # sig = signature(fn)\r\n        # params = sig.parameters\r\n\r\n    def is_gradient_accumulation_boundary(self):\r\n        \"\"\"True if the engine is executing a gradient reduction or optimizer step instruction.\r\n\r\n        This is overridden from :class:`DeepSpeedEngine` to force reductions\r\n        and steps when the pipeline engine is instructed to do so.\r\n\r\n        Returns:\r\n            bool: whether reductions and optimizer steps should occur.\r\n        \"\"\"\r\n        return self._force_grad_boundary\r\n\r\n\r\n    def tput_log(self, *msg):\r\n        if self.global_rank == 0 and self.global_steps % self.steps_per_print() == 0:\r\n            print(*msg)\r\n\r\n    def _next_batch(self):\r\n        if self.is_model_parallel:\r\n            mp_rank = self.grid.get_slice_parallel_rank()\r\n        else:\r\n            mp_rank = 0\r\n\r\n        batch = None\r\n\r\n        # Only MP rank 0 loads the data.\r\n        if mp_rank == 0:\r\n            if self.data_iterator is None:\r\n                raise ValueError(f\"RANK={self.global_rank} no data iterator provided.\")\r\n            batch = next(self.data_iterator)\r\n\r\n        # All MP ranks participate in batch_fn, where they might broadcast the data.\r\n        if self.batch_fn:\r\n            batch = self.batch_fn(batch, self.train_mode)\r\n\r\n        # Sanity check dimensions.\r\n        # XXX: the last minibatch with size < micro_batch_size kills us\r\n        if torch.is_tensor(batch[0]):\r\n            if batch[0].size(0) != self.micro_batch_size:\r\n                print(f'size mismatch: {batch[0].size(0)} mb: {self.micro_batch_size}')\r\n                assert batch[0].size(0) == self.micro_batch_size\r\n                return self._next_batch()\r\n        else:\r\n            assert torch.is_tensor(batch[0][0])\r\n            if batch[0][0].size(0) != self.micro_batch_size:\r\n                print(f'HB next_batch: {batch[0][0].shape} vs {self.micro_batch_size}', flush=True)\r\n                return self._next_batch()\r\n        \r\n        return batch\r\n\r\n    def _exec_bps_forward_pass(self, buffer_id):\r\n        self.tput_timer.start()\r\n        self.mem_status('BEFORE FWD', reset_max=True)\r\n        self._profiling_func_enter('_exec_bps_forward_pass')\r\n\r\n        if isinstance(self.pipe_buffers['inputs'][buffer_id], tuple):\r\n            inputs = tuple(t.clone() for t in self.pipe_buffers['inputs'][buffer_id])\r\n        else:\r\n            inputs = self.pipe_buffers['inputs'][buffer_id].clone()\r\n\r\n        # collect the partitioned input from the previous stage\r\n        assert not self.is_pipe_partitioned\r\n\r\n        # Zero out the gradients each time we use the tensor because only the data in\r\n        # tensor changes across 
batches\r\n        self._zero_grads(inputs)\r\n\r\n        outputs = super(PipelineEngine, self).forward(inputs)\r\n\r\n        # Partition the outputs if we are not the last stage\r\n        assert not self.is_pipe_partitioned\r\n\r\n        self.pipe_buffers['outputs'][buffer_id] = outputs\r\n\r\n        # Optionally compute loss and metrics on the last device\r\n        if self.is_last_stage():\r\n            if self.loss_model is not None:\r\n                labels = self.pipe_buffers['labels'][buffer_id]\r\n                ret = self.loss_model(outputs, labels)\r\n                if isinstance(ret, dict):\r\n                    self.result_dict = ret\r\n                    self.loss = self.result_dict['loss']\r\n                else:\r\n                    self.loss = ret\r\n            else:\r\n                # Some models just return loss from forward()\r\n                self.loss = outputs\r\n            # get metric from self.module\r\n\r\n            if isinstance(self.loss, torch.Tensor):\r\n                if self.total_loss is None:\r\n                    self.total_loss = torch.zeros_like(self.loss)\r\n                self.total_loss += self.loss.detach()\r\n            else:\r\n                if self.total_loss is None:\r\n                    self.total_loss = [torch.zeros_like(l) for l in self.loss]\r\n                for idx, l in enumerate(self.loss):\r\n                    self.total_loss[idx] += l.detach()\r\n\r\n        self._profiling_func_exit()\r\n\r\n    def _exec_bps_backward_pass(self, buffer_id):\r\n        self._profiling_func_enter('_exec_bps_backward_pass')\r\n        assert self.optimizer is not None, \"must provide optimizer during \" \\\r\n                                           \"init in order to use backward\"\r\n\r\n        self.mem_status('BEFORE BWD', reset_max=True)\r\n\r\n        # The last stage just runs backward on the loss using DeepSpeed's typical\r\n        # mechanisms.\r\n        if self.is_last_stage():\r\n            super(PipelineEngine, self).backward(self.loss)\r\n            self.mem_status('AFTER BWD')\r\n            self._profiling_func_exit()\r\n            return\r\n\r\n        outputs = self.pipe_buffers['outputs'][buffer_id]\r\n\r\n        if self.wall_clock_breakdown():\r\n            self.timers('backward_microstep').start()\r\n            self.timers('backward').start()\r\n            self.timers('backward_inner_microstep').start()\r\n            self.timers('backward_inner').start()\r\n\r\n        assert not self.is_pipe_partitioned\r\n        assert not self.is_grad_partitioned\r\n        # TODO: do we need to clone()?\r\n        grad_tensors = self.pipe_buffers['bps_grad_recv'][buffer_id]\r\n\r\n        if isinstance(outputs, tuple):\r\n            out_tensors = [t for t in outputs if t.is_floating_point()]\r\n            assert len(out_tensors) == len(grad_tensors)\r\n            new_out_tensors=[]\r\n            new_grad_tensors=[]\r\n            for t,g in zip(out_tensors, grad_tensors):\r\n                if t.requires_grad:\r\n                    new_out_tensors.append(t)\r\n                    new_grad_tensors.append(g)\r\n\r\n            assert len(new_out_tensors) == len(new_grad_tensors)\r\n            torch.autograd.backward(tensors=new_out_tensors, grad_tensors=new_grad_tensors)\r\n        else:\r\n            torch.autograd.backward(tensors=(outputs,), grad_tensors=(grad_tensors,))\r\n\r\n        # Free up the memory from the output of forward()\r\n        self.pipe_buffers['output_tensors'][buffer_id] = 
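None\r\n        # Dropping these references releases this micro-batch's autograd graph\r\n        # as soon as its backward completes; keeping them around would hold one\r\n        # graph alive per in-flight pipeline buffer. The 'outputs' slot below is\r\n        # likewise reset to 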
None\r\n        self.pipe_buffers['outputs'][buffer_id] = None\r\n        grad_tensors = None\r\n\r\n        if self.wall_clock_breakdown():\r\n            self.timers('backward_inner').stop()\r\n            self.timers('backward_inner_microstep').stop()\r\n            self.timers('backward').stop()\r\n            self.timers('backward_microstep').stop()\r\n\r\n        self.mem_status('AFTER BWD')\r\n        self._profiling_func_exit()\r\n\r\n    def _exec_load_micro_batch(self, buffer_id):\r\n        self._profiling_func_enter('_exec_load_micro_batch')\r\n        if self.wall_clock_breakdown():\r\n            self.timers('batch_input').start()\r\n\r\n        batch = self._next_batch()\r\n\r\n        if self.is_first_stage():\r\n            loaded = None\r\n            if torch.is_tensor(batch[0]):\r\n                loaded = batch[0].clone().to(self.device).detach()\r\n                loaded.requires_grad = loaded.is_floating_point()\r\n                if MEGATRON_DEBUG_DATA:\r\n                    print(f'batch = {loaded.sum().detach()}', flush=True)\r\n            else:\r\n                assert isinstance(batch[0], tuple)\r\n                # Assume list or tuple\r\n                loaded = []\r\n                for x in batch[0]:\r\n                    assert torch.is_tensor(x)\r\n                    mine = x.clone().detach().to(self.device)\r\n                    mine.requires_grad = mine.is_floating_point()\r\n                    loaded.append(mine)\r\n                loaded = tuple(loaded)\r\n                if MEGATRON_DEBUG_DATA:\r\n                    print(f'rank: {self.global_rank}, stage: {self.stage_id},  batch[0] = {[x.sum().detach() for x in loaded]}', flush=True)\r\n\r\n            self.pipe_buffers['inputs'][buffer_id] = loaded\r\n\r\n        if self.is_last_stage():\r\n            loaded = batch[1]\r\n            if torch.is_tensor(batch[1]):\r\n                loaded = batch[1].to(self.device)\r\n                if MEGATRON_DEBUG_DATA:\r\n                    print(f'rank: {self.global_rank}, stage: {self.stage_id},  batch[1] = {[x.sum().detach() for x in loaded]}', flush=True)\r\n            elif isinstance(batch[1], tuple):\r\n                loaded = []\r\n                for x in batch[1]:\r\n                    assert torch.is_tensor(x)\r\n                    x = x.to(self.device).detach()\r\n                    loaded.append(x)\r\n                loaded = tuple(loaded)\r\n                if MEGATRON_DEBUG_DATA:\r\n                    print(f'rank: {self.global_rank}, stage: {self.stage_id},  batch[1] = {[x.sum().detach() for x in loaded]}', flush=True)\r\n\r\n            self.pipe_buffers['labels'][buffer_id] = loaded\r\n\r\n        if self.wall_clock_breakdown():\r\n            self.timers('batch_input').stop()\r\n        self._profiling_func_exit()\r\n\r\n    def _send_tensor_meta(self, buffer, recv_stage):\r\n        self._profiling_func_enter('_send_tensor_meta')\r\n        \"\"\" Communicate metadata about upcoming p2p transfers.\r\n\r\n        Metadata is communicated in this order:\r\n            * type (0: tensor, 1: list)\r\n            * num_tensors if type=list\r\n            foreach tensor in buffer:\r\n                * ndims\r\n                * shape\r\n        \"\"\"\r\n        send_bytes = 0\r\n        if isinstance(buffer, torch.Tensor):\r\n            type_tensor = torch.LongTensor(data=[0]).to(self.device)\r\n            p2p.send(type_tensor, recv_stage)\r\n            send_shape = torch.LongTensor(data=buffer.size()).to(self.device)\r\n        
    send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)\r\n            send_dtype = torch.LongTensor(data=[_dtype_to_code(buffer.dtype)]).to(self.device)\r\n            p2p.send(send_ndims, recv_stage)\r\n            p2p.send(send_shape, recv_stage)\r\n            p2p.send(send_dtype, recv_stage)\r\n            send_bytes += _tensor_bytes(buffer)\r\n        elif isinstance(buffer, list):\r\n            assert False, 'list buffers are not supported for pipeline send; use a tuple'\r\n            type_tensor = torch.LongTensor(data=[1]).to(self.device)\r\n            p2p.send(type_tensor, recv_stage)\r\n            count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)\r\n            p2p.send(count_tensor, recv_stage)\r\n            for tensor in buffer:\r\n                assert isinstance(tensor, torch.Tensor)\r\n                send_shape = torch.LongTensor(data=tensor.size()).to(self.device)\r\n                send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)\r\n                send_dtype = torch.LongTensor(data=[_dtype_to_code(tensor.dtype)]).to(self.device)\r\n                p2p.send(send_ndims, recv_stage)\r\n                p2p.send(send_shape, recv_stage)\r\n                p2p.send(send_dtype, recv_stage)\r\n                send_bytes += _tensor_bytes(tensor)\r\n        elif isinstance(buffer, tuple):\r\n            type_tensor = torch.LongTensor(data=[2]).to(self.device)\r\n            p2p.send(type_tensor, recv_stage)\r\n            count_tensor = torch.LongTensor(data=[len(buffer)]).to(self.device)\r\n            p2p.send(count_tensor, recv_stage)\r\n            for idx, tensor in enumerate(buffer):\r\n                assert isinstance(tensor, torch.Tensor)\r\n                send_shape = torch.LongTensor(data=tensor.size()).to(self.device)\r\n                send_ndims = torch.LongTensor(data=[len(tensor.size())]).to(self.device)\r\n                send_dtype = torch.LongTensor(data=[_dtype_to_code(tensor.dtype)]).to(self.device)\r\n                p2p.send(send_ndims, recv_stage)\r\n                p2p.send(send_shape, recv_stage)\r\n                p2p.send(send_dtype, recv_stage)\r\n                send_bytes += _tensor_bytes(tensor)\r\n                # Useful for performance debugging.\r\n                '''\r\n                new_bytes = _tensor_bytes(tensor)\r\n                if self.grid.data_parallel_id == 0:\r\n                    print(\r\n                        f'STAGE={self.stage_id} pipe-send-volume[{idx}]: shape={send_shape} {new_bytes/1024**2:0.2f}MB'\r\n                    )\r\n                '''\r\n        else:\r\n            raise NotImplementedError(f'Could not send meta type {type(buffer)}')\r\n\r\n        self._profiling_func_exit()\r\n        # Useful for performance debugging.\r\n        '''\r\n        if self.grid.data_parallel_id == 0:\r\n            print(f'STAGE={self.stage_id} pipe-send-volume: {send_bytes/1024**2:0.2f}MB')\r\n        '''\r\n\r\n    def _recv_tensor_meta(self, send_stage):\r\n        self._profiling_func_enter('_recv_tensor_meta')\r\n        \"\"\"Receive metadata about upcoming p2p transfers and return allocated buffers.\r\n\r\n        Metadata is communicated in this order:\r\n            * type (0: tensor, 1: list, 2: tuple)\r\n            * num_tensors if type=list or tuple\r\n            foreach tensor in buffer:\r\n                * ndims\r\n                * shape\r\n                * dtype code\r\n\r\n        Returns:\r\n            Allocated buffer for receiving from send_stage.\r\n        \"\"\"
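\r\n\r\n        # Wire format per transfer, mirroring _send_tensor_meta above (sketch):\r\n        #   [type] -> [num_tensors if list/tuple] ->\r\n        #   per tensor: [ndims] -> [shape] -> [dtype code]\r\n        # The two sides must stay in lockstep or the recvs below will misparse.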
\r\n\r\n        type_tensor = torch.LongTensor(data=[0]).to(self.device)\r\n        p2p.recv(type_tensor, send_stage)\r\n        recv_type = type_tensor.item()\r\n\r\n        # A single tensor will be sent.\r\n        if recv_type == 0:\r\n            recv_ndims = torch.LongTensor(data=[0]).to(self.device)\r\n            p2p.recv(recv_ndims, send_stage)\r\n            recv_ndims = recv_ndims.item()\r\n            recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)\r\n            p2p.recv(recv_shape, send_stage)\r\n            recv_shape = recv_shape.tolist()\r\n            recv_dtype = torch.LongTensor(data=[0]).to(self.device)\r\n            p2p.recv(recv_dtype, send_stage)\r\n            recv_dtype_code = recv_dtype.item()\r\n            recv_dtype = _code_to_dtype(recv_dtype_code)\r\n            buffers = self._allocate_buffer2(recv_shape, recv_dtype, num_buffers=1)[0]\r\n\r\n        # List or tuple of tensors\r\n        elif recv_type == 1 or recv_type == 2:\r\n            count_tensor = torch.LongTensor(data=[0]).to(self.device)\r\n            p2p.recv(count_tensor, send_stage)\r\n            num_tensors = count_tensor.item()\r\n            recv_shapes = []\r\n            recv_dtypes = []\r\n            for idx in range(num_tensors):\r\n                recv_ndims = torch.LongTensor(data=[0]).to(self.device)\r\n                p2p.recv(recv_ndims, send_stage)\r\n                recv_ndims = recv_ndims.item()\r\n                recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)\r\n                p2p.recv(recv_shape, send_stage)\r\n                recv_shapes.append(recv_shape.tolist())\r\n                recv_dtype = torch.LongTensor(data=[0]).to(self.device)\r\n                p2p.recv(recv_dtype, send_stage)\r\n                recv_dtype_code = recv_dtype.item()\r\n                recv_dtype = _code_to_dtype(recv_dtype_code)\r\n                recv_dtypes.append(recv_dtype)\r\n\r\n            buffers = self._allocate_buffers2(recv_shapes, recv_dtypes, num_buffers=1)[0]\r\n            # Convert to tuples if requested.\r\n            if recv_type == 2:\r\n                buffers = tuple(buffers)\r\n\r\n        else:\r\n            raise NotImplementedError(f'Could not receive meta type {recv_type}')\r\n\r\n        self._profiling_func_exit()\r\n        return buffers\r\n\r\n    def _mp_slice(self, x):\r\n        mp_size = self.grid.get_model_parallel_world_size()\r\n        return x.reshape((mp_size, -1))[self.mp_id:self.mp_id+1, :].detach()\r\n\r\n    def _mp_view(self, x, rank):\r\n        mp_size = self.grid.get_model_parallel_world_size()\r\n        return x.view((mp_size, -1))[rank:rank+1, :]\r\n\r\n    def _exec_bps_send_partitioned_activations(self, buffer_id):\r\n        self._profiling_func_enter('_exec_bps_send_partitioned_activations')\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_send_output').start()\r\n\r\n        outputs = self.pipe_buffers['outputs'][buffer_id]\r\n\r\n        if self.first_output_send:\r\n            self.first_output_send = False\r\n            self._send_tensor_meta(outputs, self.next_stage)\r\n\r\n        assert not self.args.broadcast_activation\r\n        assert ENABLE_BPS_PARTITION\r\n        name = f'act_{buffer_id}'\r\n        if isinstance(outputs, torch.Tensor):\r\n            p2p.bps_send(self._mp_slice(outputs.contiguous()),\r\n                         self.next_stage, name, index=0, async_op=True)\r\n        elif isinstance(outputs, (tuple, list)):\r\n            for idx, buffer in enumerate(outputs):\r\n            
    if DS_PIPE_VERBOSE >= 3:\r\n                    print(f'DS BPS_SEND tensors {idx}/{len(outputs)}', flush=True)\r\n                p2p.bps_send(self._mp_slice(buffer.contiguous()), self.next_stage,\r\n                             name, index=idx, async_op=True)\r\n        else:\r\n            raise NotImplementedError('Could not send output of type '\r\n                                      f'{type(outputs)}')\r\n\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_send_output').stop()\r\n        self._profiling_func_exit()\r\n\r\n    def _exec_bps_send_activations(self, buffer_id):\r\n        self._profiling_func_enter('_exec_bps_send_activations')\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_send_output').start()\r\n\r\n        outputs = self.pipe_buffers['outputs'][buffer_id]\r\n\r\n        if self.first_output_send:\r\n            self.first_output_send = False\r\n            self._send_tensor_meta(outputs, self.next_stage)\r\n\r\n        assert not self.args.broadcast_activation\r\n        assert not ENABLE_BPS_PARTITION\r\n        if self.mp_id == 0:\r\n            name = f'act_{buffer_id}'\r\n            if isinstance(outputs, torch.Tensor):\r\n                p2p.bps_send(outputs.contiguous(), self.next_stage, name, index=0, async_op=True)\r\n            elif isinstance(outputs, (tuple, list)):\r\n                for idx, buffer in enumerate(outputs):\r\n                    if DS_PIPE_VERBOSE >= 3:\r\n                        print(f'DS BPS_SEND tensors {idx}/{len(outputs)} start', flush=True)\r\n                    p2p.bps_send(buffer.contiguous(), self.next_stage, name, index=idx, async_op=True)\r\n                    if DS_PIPE_VERBOSE >= 3:\r\n                        print(f'DS BPS_SEND tensors {idx}/{len(outputs)} end', flush=True)\r\n            else:\r\n                raise NotImplementedError('Could not send output of type '\r\n                                          f'{type(outputs)}')\r\n\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_send_output').stop()\r\n        self._profiling_func_exit()\r\n\r\n    def _exec_bps_send_grads(self, buffer_id):\r\n        self._profiling_func_enter('_exec_bps_send_grads')\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_send_grad').start()\r\n\r\n        inputs = self.pipe_buffers['inputs'][buffer_id]\r\n\r\n        # Partition the gradient\r\n        assert not self.is_grad_partitioned\r\n        assert not self.args.broadcast_grads\r\n\r\n        name = f'grad_{buffer_id}'\r\n        # only MP rank 0 sends the gradient\r\n        if self.grid.get_model_parallel_rank() == 0:\r\n            if isinstance(inputs, torch.Tensor):\r\n                if inputs.grad is None:\r\n                    send_data = self._allocate_zeros(inputs.size())\r\n                else:\r\n                    send_data = inputs.grad\r\n                assert send_data.is_floating_point()\r\n                assert send_data is not None\r\n                p2p.bps_send(send_data, self.prev_stage, name, index=0, async_op=True)\r\n\r\n            else:\r\n                for idx, buffer in enumerate(inputs):\r\n                    if not buffer.is_floating_point():\r\n                        continue\r\n                    if buffer.grad is None:\r\n                        send_data = self._allocate_zeros(buffer.size())\r\n                    else:\r\n                        send_data = buffer.grad\r\n                    assert 
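send_data.is_floating_point()\r\n                    # When a floating-point input has no grad we send zeros of the\r\n                    # same shape, presumably so the receiver (which posts one recv\r\n                    # per floating-point output) stays index-aligned with us.\r\n                    assert 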
send_data.is_floating_point()\r\n                    assert send_data is not None\r\n                    p2p.bps_send(send_data, self.prev_stage, name, index=idx, async_op=True)\r\n\r\n        # We can free up the input buffer now\r\n        self.pipe_buffers['inputs'][buffer_id] = None\r\n\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_send_grad').stop()\r\n        self._profiling_func_exit()\r\n\r\n    def _exec_bps_send_partitioned_grads(self, buffer_id):\r\n        self._profiling_func_enter('_exec_bps_send_grads')\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_send_grad').start()\r\n\r\n        inputs = self.pipe_buffers['inputs'][buffer_id]\r\n\r\n        # Partition the gradient\r\n        assert not self.is_grad_partitioned\r\n        assert not self.args.broadcast_grads\r\n        assert ENABLE_BPS_PARTITION\r\n\r\n        name = f'grad_{buffer_id}'\r\n        if isinstance(inputs, torch.Tensor):\r\n            if inputs.grad is None:\r\n                send_data = self._allocate_zeros(inputs.size())\r\n            else:\r\n                send_data = inputs.grad\r\n            assert send_data.is_floating_point()\r\n            assert send_data is not None\r\n            p2p.bps_send(self._mp_slice(send_data), self.prev_stage, name,\r\n                         index=0, async_op=True)\r\n        else:\r\n            for idx, buffer in enumerate(inputs):\r\n                if not buffer.is_floating_point():\r\n                    continue\r\n                if buffer.grad is None:\r\n                    send_data = self._allocate_zeros(buffer.size())\r\n                else:\r\n                    send_data = buffer.grad\r\n                assert send_data.is_floating_point()\r\n                assert send_data is not None\r\n                p2p.bps_send(self._mp_slice(send_data), self.prev_stage,\r\n                             name, index=idx, async_op=True)\r\n\r\n        # We can free up the input buffer now\r\n        self.pipe_buffers['inputs'][buffer_id] = None\r\n\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_send_grad').stop()\r\n        self._profiling_func_exit()\r\n\r\n    def _exec_bps_sync_all(self):\r\n        p2p.bps_sync_all()\r\n\r\n    def _exec_bps_sync_partitioned_grads(self, buffer_id):\r\n        name = f'grad_{buffer_id}'\r\n        recv_buff = self.pipe_buffers['bps_grad_recv'][buffer_id]\r\n        if isinstance(recv_buff, torch.Tensor):\r\n            p2p.bps_sync(self.next_stage, name, index=0)\r\n        else:\r\n            for i in range(len(recv_buff)):\r\n                p2p.bps_sync(self.next_stage, name, index=i)\r\n\r\n        # all_gather the gradient from other ranks\r\n        mp_size = self.grid.model_parallel_size\r\n        if mp_size > 1:\r\n            src_rank = self.grid.slice_parallel_src_id\r\n            group = self.grid.slice_proc_group\r\n            if isinstance(recv_buff, torch.Tensor):\r\n                recv_buff_views = [self._mp_view(recv_buff, i) for i in range(mp_size)]\r\n                dist.all_gather(recv_buff_views, recv_buff_views[self.mp_id].clone(),\r\n                                group=group, async_op=False)\r\n            else:\r\n                for i in range(len(recv_buff)):\r\n                    if recv_buff[i].is_floating_point():\r\n                        recv_buff_views = [self._mp_view(recv_buff[i], j) for j in range(mp_size)]\r\n                        dist.all_gather(recv_buff_views, 
recv_buff_views[self.mp_id].clone(),\r\n                                        group=group, async_op=False)\r\n\r\n    def _exec_bps_sync_grads(self, buffer_id):\r\n        name = f'grad_{buffer_id}'\r\n        recv_buff = self.pipe_buffers['bps_grad_recv'][buffer_id]\r\n        if self.mp_id == 0:\r\n            if isinstance(recv_buff, torch.Tensor):\r\n                p2p.bps_sync(self.next_stage, name, index=0)\r\n            else:\r\n                for i in range(len(recv_buff)):\r\n                    p2p.bps_sync(self.next_stage, name, index=i)\r\n\r\n        # broadcast the activation at MP rank 0 to other ranks\r\n        if self.grid.model_parallel_size > 1:\r\n            src_rank = self.grid.slice_parallel_src_id\r\n            group = self.grid.slice_proc_group\r\n            if isinstance(recv_buff, torch.Tensor):        \r\n                dist.broadcast(recv_buff, src_rank, group=group, async_op=False)\r\n            else:\r\n                for i in range(len(recv_buff)):\r\n                    if recv_buff[i].is_floating_point():\r\n                        dist.broadcast(recv_buff[i], src_rank, group=group, async_op=False)\r\n\r\n    def _exec_bps_sync_partitioned_activations(self, buffer_id):\r\n        recv_buff = self.pipe_buffers['bps_act_recv'][buffer_id]\r\n        recvd = None\r\n        src_rank = self.grid.slice_parallel_src_id\r\n        mp_size = self.grid.model_parallel_size\r\n        group = self.grid.slice_proc_group\r\n        name = f'act_{buffer_id}'\r\n\r\n        if isinstance(recv_buff, torch.Tensor):\r\n            p2p.bps_sync(self.prev_stage, name, index=0)\r\n            # broadcast the activation at MP rank 0 to other ranks\r\n            if mp_size > 1:\r\n                recv_buff_views = [self._mp_view(recv_buff, i) for i in range(mp_size)]\r\n                dist.all_gather(recv_buff_views, recv_buff_views[self.mp_id].clone(),\r\n                                group=group, async_op=False)\r\n            recvd = recv_buff.clone().detach()\r\n            recvd.requires_grad = recv_buff.is_floating_point()\r\n        else:\r\n            recvd = [None] * len(recv_buff)\r\n            for i in range(len(recv_buff)):\r\n                p2p.bps_sync(self.prev_stage, name, index=i)\r\n                # broadcast the activation at MP rank 0 to other ranks\r\n                if mp_size > 1:\r\n                    recv_buff_views = [self._mp_view(recv_buff[i], j) for j in range(mp_size)]\r\n                    dist.all_gather(recv_buff_views, recv_buff_views[self.mp_id].clone(),\r\n                                    group=group, async_op=False)\r\n                recvd[i] = recv_buff[i].clone().detach()\r\n            recvd = tuple(recvd)\r\n            for buffer in recvd:\r\n                buffer.requires_grad = buffer.is_floating_point()\r\n\r\n        self.pipe_buffers['inputs'][buffer_id] = recvd\r\n\r\n    def _exec_bps_sync_activations(self, buffer_id):\r\n        recv_buff = self.pipe_buffers['bps_act_recv'][buffer_id]\r\n        recvd = None\r\n        src_rank = self.grid.slice_parallel_src_id\r\n        group = self.grid.slice_proc_group\r\n        name = f'act_{buffer_id}'\r\n\r\n        if isinstance(recv_buff, torch.Tensor):\r\n            if self.mp_id == 0:        \r\n                p2p.bps_sync(self.prev_stage, name, index=0)\r\n            # broadcast the activation at MP rank 0 to other ranks\r\n            if self.grid.model_parallel_size > 1:\r\n                dist.broadcast(recv_buff, src_rank, group=group, 
async_op=False)\r\n            recvd = recv_buff.clone().detach()\r\n            recvd.requires_grad = recv_buff.is_floating_point()\r\n        else:\r\n            recvd = [None] * len(recv_buff)\r\n            for i in range(len(recv_buff)):\r\n                if self.mp_id == 0:\r\n                    p2p.bps_sync(self.prev_stage, name, index=i)\r\n                # broadcast the activation at MP rank 0 to other ranks\r\n                if self.grid.model_parallel_size > 1:\r\n                    dist.broadcast(recv_buff[i], src_rank, group=group, async_op=False)\r\n                recvd[i] = recv_buff[i].clone().detach()\r\n            recvd = tuple(recvd)\r\n            for buffer in recvd:\r\n                buffer.requires_grad = buffer.is_floating_point()\r\n\r\n        self.pipe_buffers['inputs'][buffer_id] = recvd\r\n\r\n    def _exec_bps_recv_partitioned_activations(self, buffer_id):\r\n        self._profiling_func_enter('_exec_bps_recv_activations')\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_recv_input').start()\r\n\r\n        recv_buffs = self.pipe_buffers['bps_act_recv']\r\n\r\n        # Allocate the buffer if necessary\r\n        if recv_buffs[buffer_id] is None:\r\n            if recv_buffs[0] is None:\r\n                recv_buffs[buffer_id] = self._recv_tensor_meta(self.prev_stage)\r\n            else:\r\n                if torch.is_tensor(recv_buffs[0]):\r\n                    recv_buffs[buffer_id] = recv_buffs[0].clone().detach()\r\n                else:\r\n                    recv_buffs[buffer_id] = tuple([x.clone().detach() for x in recv_buffs[0]])\r\n\r\n        assert not self.args.broadcast_activation\r\n        assert not self.is_pipe_partitioned\r\n        recv_buff = recv_buffs[buffer_id]\r\n        name = f'act_{buffer_id}'\r\n        if isinstance(recv_buff, torch.Tensor):\r\n            p2p.bps_recv(self._mp_view(recv_buff, self.mp_id), self.prev_stage,\r\n                         name, index=0, async_op=True)\r\n        else:\r\n            assert isinstance(recv_buff, (tuple, list))\r\n            for idx, buffer in enumerate(recv_buff):\r\n                assert torch.is_tensor(buffer)\r\n                p2p.bps_recv(self._mp_view(buffer, self.mp_id), self.prev_stage,\r\n                             name, index=idx, async_op=True)\r\n\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_recv_input').stop()\r\n        self._profiling_func_exit()\r\n\r\n    def _exec_bps_recv_activations(self, buffer_id):\r\n        self._profiling_func_enter('_exec_bps_recv_activations')\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_recv_input').start()\r\n\r\n        recv_buffs = self.pipe_buffers['bps_act_recv']\r\n\r\n        # Allocate the buffer if necessary\r\n        if recv_buffs[buffer_id] is None:\r\n            if recv_buffs[0] is None:\r\n                recv_buffs[buffer_id] = self._recv_tensor_meta(self.prev_stage)\r\n            else:\r\n                if torch.is_tensor(recv_buffs[0]):\r\n                    recv_buffs[buffer_id] = recv_buffs[0].clone().detach()\r\n                else:\r\n                    recv_buffs[buffer_id] = tuple([x.clone().detach() for x in recv_buffs[0]])\r\n\r\n        assert not self.args.broadcast_activation\r\n        assert not self.is_pipe_partitioned\r\n        recv_buff = recv_buffs[buffer_id]\r\n        if self.mp_id == 0:\r\n            name = f'act_{buffer_id}'\r\n            if isinstance(recv_buff, torch.Tensor):\r\n                
p2p.bps_recv(recv_buff, self.prev_stage, name, index=0, async_op=True)\r\n            else:\r\n                assert isinstance(recv_buff, (tuple, list))\r\n                for idx, buffer in enumerate(recv_buff):\r\n                    assert torch.is_tensor(buffer)\r\n                    p2p.bps_recv(buffer, self.prev_stage, name, index=idx, async_op=True)\r\n\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_recv_input').stop()\r\n        self._profiling_func_exit()\r\n\r\n    def _exec_bps_recv_partitioned_grads(self, buffer_id):\r\n        self._profiling_func_enter('_exec_bps_recv_grads')\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_recv_grad').start()\r\n\r\n        outputs = self.pipe_buffers['outputs'][buffer_id]\r\n        grad_buffs = self.pipe_buffers['bps_grad_recv']\r\n        # Restore partitioned output if it was partitioned and we are sending full gradients\r\n        assert not self.is_pipe_partitioned\r\n        assert not self.is_grad_partitioned\r\n        assert not self.args.broadcast_grads\r\n        assert ENABLE_BPS_PARTITION\r\n        # Allocate gradient if necessary\r\n        if grad_buffs[buffer_id] is None:\r\n            if isinstance(outputs, torch.Tensor):\r\n                s = list(outputs.size())\r\n                grad_buffs[buffer_id] = self._allocate_buffer(s, num_buffers=1)[0]\r\n            else:\r\n                sizes = [list(t.size()) for t in outputs if t.is_floating_point()]\r\n                grad_buffs[buffer_id] = self._allocate_buffers(sizes, num_buffers=1)[0]\r\n        grad_buff = grad_buffs[buffer_id]\r\n        name = f'grad_{buffer_id}'\r\n        if isinstance(grad_buff, torch.Tensor):\r\n            p2p.bps_recv(self._mp_view(grad_buff, self.mp_id), self.next_stage,\r\n                         name, index=0, async_op=True)\r\n        else:\r\n            assert isinstance(outputs, tuple)\r\n            recv_idx = 0\r\n            for idx, buffer in enumerate(grad_buff):\r\n                p2p.bps_recv(self._mp_view(buffer, self.mp_id), self.next_stage,\r\n                             name, index=recv_idx, async_op=True)\r\n                recv_idx += 1\r\n\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_recv_grad').stop()\r\n        self._profiling_func_exit()\r\n\r\n    def _exec_bps_recv_grads(self, buffer_id):\r\n        self._profiling_func_enter('_exec_bps_recv_grads')\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_recv_grad').start()\r\n\r\n        outputs = self.pipe_buffers['outputs'][buffer_id]\r\n        grad_buffs = self.pipe_buffers['bps_grad_recv']\r\n        # Restore partitioned output if it was partitioned and we are sending full gradients\r\n        assert not self.is_pipe_partitioned\r\n        assert not self.is_grad_partitioned\r\n        assert not self.args.broadcast_grads\r\n        # Allocate gradient if necessary\r\n        if grad_buffs[buffer_id] is None:\r\n            if isinstance(outputs, torch.Tensor):\r\n                s = list(outputs.size())\r\n                grad_buffs[buffer_id] = self._allocate_buffer(s, num_buffers=1)[0]\r\n            else:\r\n                sizes = [list(t.size()) for t in outputs if t.is_floating_point()]\r\n                grad_buffs[buffer_id] = self._allocate_buffers(sizes, num_buffers=1)[0]\r\n        grad_buff = grad_buffs[buffer_id]\r\n        name = f'grad_{buffer_id}'\r\n        if isinstance(grad_buff, torch.Tensor):\r\n            if self.mp_id == 
0:\r\n                p2p.bps_recv(grad_buff, self.next_stage, name, index=0, async_op=True)\r\n        else:\r\n            assert isinstance(outputs, tuple)\r\n            recv_idx = 0\r\n            if self.mp_id == 0:\r\n                for idx, buffer in enumerate(grad_buff):\r\n                    p2p.bps_recv(buffer, self.next_stage, name, index=recv_idx, async_op=True)\r\n                    recv_idx += 1\r\n\r\n        if self.wall_clock_breakdown():\r\n            self.timers('pipe_recv_grad').stop()\r\n        self._profiling_func_exit()\r\n\r\n    def _exec_optimizer_step(self, lr_kwargs=None):\r\n        self._profiling_func_enter('_exec_optimizer_step')\r\n        if self.wall_clock_breakdown():\r\n            self.timers('step_microstep').start()\r\n            self.timers('step').start()\r\n        self.mem_status('BEFORE STEP', reset_max=True)\r\n\r\n        if self.global_rank == 0 and MEGATRON_DEBUG_GRAD:\r\n             params = list(self.module.named_parameters())\r\n             for i in (0, 1, -2, -1):\r\n                 p = params[i]\r\n                 if p[1] is None:\r\n                     print(f'name={p[0]} | None', flush=True)\r\n                 elif p[1].grad is None:\r\n                     print(f'name={p[0]} | weight={p[1].mean()}', flush=True)\r\n                 else:\r\n                     print(f'name={p[0]} | weight={p[1].norm()} | grad={p[1].grad.norm()}', flush=True)\r\n             params_w_grad = []\r\n             params_wo_grad = []\r\n             for p in params:\r\n                 if p[1].grad is not None:\r\n                     params_w_grad.append(p[0])\r\n                 else:\r\n                     params_wo_grad.append(p[0])\r\n\r\n        self._force_grad_boundary = True\r\n        self._take_model_step(lr_kwargs)\r\n        self._force_grad_boundary = False\r\n\r\n        self.mem_status('AFTER STEP')\r\n\r\n        if self.tensorboard_enabled():\r\n            if self.global_rank == 0:\r\n                self.summary_events = [(f'Train/Samples/lr',\r\n                                        self.get_lr()[0],\r\n                                        self.global_samples)]\r\n                if self.fp16_enabled() and hasattr(self.optimizer, 'cur_scale'):\r\n                    self.summary_events.append((f'Train/Samples/loss_scale',\r\n                                                self.optimizer.cur_scale,\r\n                                                self.global_samples))\r\n                for event in self.summary_events:  # write_summary_events\r\n                    self.summary_writer.add_scalar(event[0], event[1], event[2])\r\n\r\n        if self.wall_clock_breakdown():\r\n            self.timers('step_microstep').stop()\r\n            self.timers('step').stop()\r\n            if self.global_steps % self.steps_per_print() == 0:\r\n                self.timers.log([\r\n                    'batch_input',\r\n                    'forward_microstep',\r\n                    'backward_microstep',\r\n                    'backward_inner_microstep',\r\n                    'backward_allreduce_microstep',\r\n                    'backward_tied_allreduce_microstep',\r\n                    'step_microstep'\r\n                ])\r\n            if self.global_steps % self.steps_per_print() == 0:\r\n                self.timers.log([\r\n                    'forward',\r\n                    'backward',\r\n                    'backward_inner',\r\n                    'backward_allreduce',\r\n                    'step'\r\n               
 ])\r\n        self._profiling_func_exit()\r\n\r\n    def _zero_grads(self, inputs):\r\n        if isinstance(inputs, torch.Tensor):\r\n            if inputs.grad is not None:\r\n                inputs.grad.data.zero_()\r\n        else:\r\n            for t in inputs:\r\n                if t.grad is not None:\r\n                    t.grad.data.zero_()\r\n\r\n    def _allocate_zeros(self, shape, fp16=None, **kwargs):\r\n        \"\"\" Allocate a tensor of zeros on the engine's device.\r\n\r\n        Arguments:\r\n            shape: the shape of the tensor to allocate\r\n            fp16 (bool): whether to use FP16. default: defer to self.fp16_enabled()\r\n            kwargs: passed to torch.zeros()\r\n\r\n        Returns:\r\n            A tensor from torch.zeros() allocated on self.device.\r\n        \"\"\"\r\n\r\n        if fp16 is None:\r\n            fp16 = self.fp16_enabled()\r\n\r\n        if fp16:\r\n            return torch.zeros(shape, dtype=torch.half, device=self.device, **kwargs)\r\n        else:\r\n            return torch.zeros(shape, device=self.device, **kwargs)\r\n\r\n    def _allocate_zeros2(self, shape, dtype, **kwargs):\r\n        return torch.zeros(shape, dtype=dtype, device=self.device, **kwargs)\r\n\r\n    def _allocate_buffer(self, shape, num_buffers=-1, **kwargs):\r\n        buffers = []\r\n        if num_buffers == -1:\r\n            num_buffers = self.num_pipe_buffers\r\n        for count in range(num_buffers):\r\n            buffers.append(self._allocate_zeros(shape, **kwargs))\r\n        return buffers\r\n\r\n    def _allocate_buffer2(self, shape, dtype, num_buffers=-1, **kwargs):\r\n        buffers = []\r\n        if num_buffers == -1:\r\n            num_buffers = self.num_pipe_buffers\r\n        for count in range(num_buffers):\r\n            buffers.append(self._allocate_zeros2(shape, dtype, **kwargs))\r\n        return buffers\r\n\r\n    def _allocate_buffers(self, shapes, requires_grad=False, num_buffers=-1):\r\n        buffers = []\r\n        if num_buffers == -1:\r\n            num_buffers = self.num_pipe_buffers\r\n        for count in range(num_buffers):\r\n            buffer = []\r\n            for shape in shapes:\r\n                buffer.append(self._allocate_zeros(shape, requires_grad=requires_grad))\r\n            buffers.append(buffer)\r\n        return buffers\r\n\r\n    def _allocate_buffers2(self, shapes, dtypes, requires_grad=False, num_buffers=-1):\r\n        buffers = []\r\n        if num_buffers == -1:\r\n            num_buffers = self.num_pipe_buffers\r\n        for count in range(num_buffers):\r\n            buffer = []\r\n            for i in range(len(shapes)):\r\n                buffer.append(self._allocate_zeros2(shapes[i], dtypes[i], requires_grad=requires_grad))\r\n            buffers.append(buffer)\r\n        return buffers\r\n\r\n    def forward(self, *args, **kwargs):\r\n        \"\"\"Disabled for pipeline parallel training. See ``train_batch()``. \"\"\"\r\n        raise PipelineError(\"Only train_batch() is accessible in pipeline mode.\")\r\n\r\n    def backward(self, *args, **kwargs):\r\n        \"\"\"Disabled for pipeline parallel training. See ``train_batch()``. \"\"\"\r\n        raise PipelineError(\"Only train_batch() is accessible in pipeline mode.\")\r\n\r\n    def step(self, *args, **kwargs):\r\n        \"\"\"Disabled for pipeline parallel training. See ``train_batch()``. \"\"\"\r\n        raise PipelineError(\"Only train_batch() is accessible in pipeline mode.\")\r\n\r\n    # A map of PipeInstruction types to methods. 
Each method will be executed with the\r\n    # kwargs provided to the PipeInstruction from the scheduler.\r\n    _INSTRUCTION_MAP = {\r\n        schedule.OptimizerStep: _exec_optimizer_step,\r\n        schedule.ReduceGrads: _exec_reduce_grads,\r\n        schedule.ReduceTiedGrads: _exec_reduce_tied_grads,\r\n        schedule.LoadMicroBatch: _exec_load_micro_batch,\r\n        schedule.BytePSForwardPass: _exec_bps_forward_pass,\r\n        schedule.BytePSBackwardPass: _exec_bps_backward_pass,\r\n        schedule.BytePSSendActivation: _exec_bps_send_partitioned_activations if ENABLE_BPS_PARTITION else _exec_bps_send_activations,\r\n        schedule.BytePSRecvActivation: _exec_bps_recv_partitioned_activations if ENABLE_BPS_PARTITION else _exec_bps_recv_activations,\r\n        schedule.BytePSSyncActivation: _exec_bps_sync_partitioned_activations if ENABLE_BPS_PARTITION else _exec_bps_sync_activations,\r\n        schedule.BytePSSyncGrad: _exec_bps_sync_partitioned_grads if ENABLE_BPS_PARTITION else _exec_bps_sync_grads,\r\n        schedule.BytePSSendGrad: _exec_bps_send_partitioned_grads if ENABLE_BPS_PARTITION else _exec_bps_send_grads,\r\n        schedule.BytePSRecvGrad: _exec_bps_recv_partitioned_grads if ENABLE_BPS_PARTITION else _exec_bps_recv_grads,\r\n        schedule.BytePSSyncAll: _exec_bps_sync_all\r\n    }\r\n\r\n    def _exec_schedule(self, pipe_schedule):\r\n        self._reserve_pipe_buffers(pipe_schedule.num_pipe_buffers())\r\n        # For each step in the schedule\r\n        has_optim_step = False\r\n        for step_cmds in pipe_schedule:\r\n            # For each instruction in the step\r\n            for cmd in step_cmds:\r\n                if isinstance(cmd, schedule.OptimizerStep):\r\n                    has_optim_step = True\r\n                if DS_PIPE_VERBOSE:\r\n                    if \"buffer_id\" in cmd.kwargs:\r\n                        print(f'[{self.grid.get_global_rank()}] | cmd={cmd.__class__.__name__} | {cmd.kwargs[\"buffer_id\"]}', flush=True)\r\n                    else:\r\n                        print(f'[{self.grid.get_global_rank()}] | cmd={cmd.__class__.__name__}', flush=True)\r\n                if type(cmd) not in self._INSTRUCTION_MAP:\r\n                    raise RuntimeError(\r\n                        f'{self.__class__.__name__} does not understand instruction {repr(cmd)}'\r\n                    )\r\n\r\n                self._exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)\r\n                self._exec_instr(**cmd.kwargs)\r\n        # check for anomalies\r\n        if isinstance(pipe_schedule, (schedule.BytePSTrainSchedule, schedule.TrainSchedule)):\r\n            assert has_optim_step\r\n"
  },
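  {
    "path": "examples/sketches/instruction_dispatch_sketch.py",
    "content": "# Hypothetical sketch, NOT part of the original source tree: a minimal,\n# self-contained illustration of the instruction-dispatch pattern used by the\n# pipeline engine above, where _INSTRUCTION_MAP maps PipeInstruction types to\n# unbound methods that are rebound onto the engine with MethodType before each\n# call. All class and method names here are illustrative stand-ins.\nfrom types import MethodType\n\nclass PipeInstruction:\n    def __init__(self, **kwargs):\n        self.kwargs = kwargs\n\nclass ForwardPass(PipeInstruction):\n    pass\n\nclass OptimizerStep(PipeInstruction):\n    pass\n\nclass ToyEngine:\n    def _exec_forward_pass(self, buffer_id):\n        print(f'forward pass on buffer {buffer_id}')\n\n    def _exec_optimizer_step(self):\n        print('optimizer step')\n\n    # Maps instruction types to the (unbound) functions that execute them.\n    _INSTRUCTION_MAP = {\n        ForwardPass: _exec_forward_pass,\n        OptimizerStep: _exec_optimizer_step,\n    }\n\n    def exec_schedule(self, cmds):\n        for cmd in cmds:\n            if type(cmd) not in self._INSTRUCTION_MAP:\n                raise RuntimeError(f'unknown instruction {repr(cmd)}')\n            # Bind the function to this engine instance, then invoke it with\n            # the kwargs carried by the instruction.\n            exec_instr = MethodType(self._INSTRUCTION_MAP[type(cmd)], self)\n            exec_instr(**cmd.kwargs)\n\nif __name__ == '__main__':\n    ToyEngine().exec_schedule([ForwardPass(buffer_id=0), OptimizerStep()])\n"
  },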
  {
    "path": "src/veGiantModel/engine/module.py",
    "content": "# Copyright (c) 2021, ByteDance Inc.  All rights reserved.\n# Copyright 2019 The Microsoft DeepSpeed Team\nimport os\n\nimport re as regex\n\nfrom functools import partial\n\nimport torch\nimport torch.nn as nn\nimport torch.distributed as dist\n\nfrom math import floor\n\nfrom deepspeed.utils import logger\nfrom deepspeed.runtime import utils as ds_utils\nfrom deepspeed.runtime.activation_checkpointing import checkpointing\nfrom deepspeed.pipe import PipelineModule,LayerSpec, TiedLayerSpec\nfrom .topology import PipeDataParallelTopology, PipelineParallelGrid\n\nclass VeGiantModule(PipelineModule):\n    def __init__(self,\n                 layers,\n                 num_stages=None,\n                 loss_fn=None,\n                 seed_layers=False,\n                 seed_fn=None,\n                 base_seed=1234,\n                 grid=None,\n                 partition_method='parameters',\n                 activation_checkpoint_interval=0,\n                 activation_checkpoint_func=checkpointing.checkpoint):\n        \"\"\"Modules to be parallelized with pipeline parallelism.\n\n        The key constraint that enables pipeline parallelism is the\n        representation of the forward pass as a sequence of layers\n        and the enforcement of a simple interface between them. The\n        forward pass is implicitly defined by the module ``layers``. The key\n        assumption is that the output of each layer can be directly fed as\n        input to the next, like a ``torch.nn.Sequence``. The forward pass is\n        implicitly:\n\n        .. code-block:: python\n\n            def forward(self, inputs):\n                x = inputs\n                for layer in self.layers:\n                    x = layer(x)\n                return x\n\n        Args:\n            layers (Iterable): A sequence of layers defining pipeline structure. Can be a ``torch.nn.Sequential`` module.\n            num_stages (int, optional): The degree of pipeline parallelism. If not specified, ``topology`` must be provided.\n            topology (``deepseed.pipe.ProcessTopology``, optional): Defines the axes of parallelism axes for training. Must be provided if ``num_stages`` is ``None``.\n            loss_fn (callable, optional): Loss is computed ``loss = loss_fn(outputs, label)``\n            base_seed (int, optional): [description]. Defaults to 1234.\n            partition_method (str, optional): [description]. Defaults to 'parameters'.\n            activation_checkpoint_interval (int, optional): The granularity activation checkpointing in terms of number of layers. 0 disables activation checkpointing.\n            activation_checkpoint_func (callable, optional): The function to use for activation checkpointing. 
Defaults to ``deepspeed.checkpointing.checkpoint``.\n        \"\"\"\n\n        super(PipelineModule, self).__init__()\n\n        topology = grid.topology() if grid is not None else None\n\n        if num_stages is None and topology is None:\n            raise RuntimeError('must provide num_stages or topology')\n\n        self.micro_offset = 0\n\n        self.loss_fn = loss_fn\n\n        self.seed_layers = seed_layers\n        self.seed_fn = seed_fn\n        self.base_seed = base_seed\n        if dist.get_rank() == 0:\n            try:\n                seed_str = self.seed_fn.__name__\n            except AttributeError:\n                seed_str = None\n            print(\n                f'SEED_LAYERS={self.seed_layers} BASE_SEED={self.base_seed} SEED_FN={seed_str}'\n            )\n\n        # Setup world info\n        self.world_group = dist.new_group(ranks=range(dist.get_world_size()))\n        self.global_rank = dist.get_rank(group=self.world_group)\n        self.world_size = dist.get_world_size(group=self.world_group)\n\n        if topology:\n            self._topo = topology\n            self.num_stages = self._topo.get_dim('pipe')\n        else:\n            self.num_stages = num_stages\n            if topology is None:\n                if self.world_size % self.num_stages != 0:\n                    raise RuntimeError(\n                        f'num_stages ({self.num_stages}) must divide distributed world size ({self.world_size})'\n                    )\n                dp = self.world_size // num_stages\n                topology = PipeDataParallelTopology(num_pp=num_stages, num_dp=dp)\n                self._topo = topology\n\n        # Construct communicators for pipeline topology\n        self._grid = grid if grid is not None else PipelineParallelGrid(process_group=self.world_group, topology=self._topo)\n\n        self.stage_id = self._topo.get_coord(self.global_rank).pipe\n\n        # Initialize partition information\n        self._layer_specs = list(layers)\n        self._num_layers = len(self._layer_specs)\n        self._local_start = 0\n        self._local_stop = None\n        self._partition_layers(method=partition_method)\n\n        self.forward_funcs = []\n        self.tied_modules = nn.ModuleDict()\n        self.tied_weight_attrs = {}\n\n        # Offset the random seed by the stage ID.\n        #newseed = torch.cuda.initial_seed() + self._grid.get_stage_id()\n        #ds_utils.set_random_seed(newseed)\n\n        #with torch.random.fork_rng(devices=[torch.cuda.current_device()]):\n        self._build()\n        self.to('cuda')\n\n        self.tied_comms = self._index_tied_modules()\n        self._synchronize_tied_weights()\n\n        self.activation_checkpoint_interval = activation_checkpoint_interval\n        self.activation_checkpoint_func = activation_checkpoint_func\n\n    def _build(self):\n        specs = self._layer_specs\n\n        for local_idx, layer in enumerate(specs[self._local_start:self._local_stop]):\n            layer_idx = local_idx + self._local_start\n            if self.seed_layers:\n                if self.seed_fn:\n                    self.seed_fn(self.base_seed + layer_idx)\n                else:\n                    ds_utils.set_random_seed(self.base_seed + layer_idx)\n\n            # Recursively build PipelineModule objects\n            if isinstance(layer, PipelineModule):\n                raise NotImplementedError('RECURSIVE BUILD NOT YET IMPLEMENTED')\n\n            # Already-constructed nn.Module instances are used directly.\n            elif isinstance(layer, nn.Module):\n                name = str(layer_idx)\n                self.forward_funcs.append(layer)\n                self.add_module(name, layer)\n\n            # TiedLayerSpec objects contain an nn.Module that should be allocated now.\n            elif isinstance(layer, TiedLayerSpec):\n                # Build and register the module if we haven't seen it before.\n                if layer.key not in self.tied_modules:\n                    self.tied_modules[layer.key] = layer.build()\n                    self.tied_weight_attrs[layer.key] = layer.tied_weight_attr\n\n                if layer.forward_fn is None:\n                    # Just use forward()\n                    self.forward_funcs.append(self.tied_modules[layer.key])\n                else:\n                    # User specified fn with args (module, input)\n                    self.forward_funcs.append(\n                        partial(layer.forward_fn,\n                                self.tied_modules[layer.key]))\n\n            # LayerSpec objects contain an nn.Module that should be allocated now.\n            elif isinstance(layer, LayerSpec):\n                module = layer.build()\n                name = str(layer_idx)\n                self.forward_funcs.append(module)\n                self.add_module(name, module)\n\n            # Last option: layer may be a functional (e.g., lambda). We do nothing in\n            # that case and just use it in forward()\n            else:\n                self.forward_funcs.append(layer)\n\n        # All pipeline parameters should be considered as model parallel in the context\n        # of our FP16 optimizer\n        for p in self.parameters():\n            p.model_parallel = True\n\n    def _count_layer_params(self):\n        \"\"\"Count the trainable parameters in individual layers.\n\n        This routine will only build one layer at a time.\n\n        Returns:\n            A list of the number of parameters in each layer.\n        \"\"\"\n        param_counts = [0] * len(self._layer_specs)\n        for idx, layer in enumerate(self._layer_specs):\n            if isinstance(layer, LayerSpec):\n                l = layer.build()\n                params = filter(lambda p: p.requires_grad, l.parameters())\n                param_counts[idx] = sum(p.numel() for p in params)\n            elif isinstance(layer, nn.Module):\n                params = filter(lambda p: p.requires_grad, layer.parameters())\n                param_counts[idx] = sum(p.numel() for p in params)\n        return param_counts\n\n    def _find_layer_type(self, layername):\n        idxs = []\n        typeregex = regex.compile(layername, regex.IGNORECASE)\n        for idx, layer in enumerate(self._layer_specs):\n            name = None\n            if isinstance(layer, LayerSpec):\n                name = layer.typename.__name__\n            elif isinstance(layer, nn.Module):\n                name = layer.__class__.__name__\n            else:\n                try:\n                    name = layer.__name__\n                except AttributeError:\n                    continue\n            if typeregex.search(name):\n                idxs.append(idx)\n\n        if len(idxs) == 0:\n            raise RuntimeError(\n                f\"Partitioning '{layername}' found no valid layers to partition.\")\n        return idxs\n\n    def forward(self, forward_input):\n        # We need to offset the seed by the microbatch ID. Save it in a local var to\n        # ensure it is preserved in the closure. 
Otherwise checkpointed forward funcs\n        # will see a different offset.\n        self.micro_offset += 1\n\n        def exec_range_func(start, end):\n            ''' Helper function to be used with checkpoint()\n            Adapted from torch.utils.checkpoint:checkpoint_sequential()\n            '''\n            local_micro_offset = self.micro_offset + 1\n\n            def exec_func(*inputs):\n                # Single tensor inputs need to be unwrapped\n                if len(inputs) == 1:\n                    inputs = inputs[0]\n                for idx, layer in enumerate(self.forward_funcs[start:end]):\n                    self.curr_layer = idx + self._local_start\n                    if self.seed_layers:\n                        new_seed = (self.base_seed *\n                                    local_micro_offset) + self.curr_layer\n                        if self.seed_fn:\n                            self.seed_fn(new_seed)\n                        else:\n                            ds_utils.set_random_seed(new_seed)\n\n                    inputs = layer(inputs)\n                return inputs\n\n            return exec_func\n\n        if self.activation_checkpoint_interval == 0:\n            func = exec_range_func(0, len(self.forward_funcs))\n            x = func(forward_input)\n        else:\n            num_layers = len(self.forward_funcs)\n            x = forward_input\n            for start_idx in range(0, num_layers, self.activation_checkpoint_interval):\n                end_idx = min(start_idx + self.activation_checkpoint_interval,\n                              num_layers)\n\n                funcs = self.forward_funcs[start_idx:end_idx]\n                # Since we either pass tensors or tuples of tensors without unpacking, we\n                # need to be careful not to double-wrap tensors with tuple.\n                if not isinstance(x, tuple):\n                    x = (x, )\n\n                if self._is_checkpointable(funcs):\n                    x = self.activation_checkpoint_func(\n                        exec_range_func(start_idx,\n                                        end_idx),\n                        *x)\n                else:\n                    x = exec_range_func(start_idx, end_idx)(*x)\n        return x\n\n    def _partition_uniform(self, num_items, num_parts):\n        # print(f'enter _partition_uniform', flush=True)\n        parts = [0] * (num_parts + 1)\n        if num_items <= num_parts:\n            for p in range(num_parts + 1):\n                parts[p] = min(p, num_items)\n            return parts\n        expected_chunksize = num_items / num_parts\n        for p in range(num_parts):\n            parts[p] = min(floor(expected_chunksize * p), num_items)\n        parts[num_parts] = num_items\n        return parts\n\n    def _partition_balanced(self, weights, num_parts, eps=1e-3):\n        num_items = len(weights)\n        # First check for the trivial edge case\n        if num_items <= num_parts:\n            return self._partition_uniform(num_items, num_parts)\n\n        weights_ = ds_utils.prefix_sum_inc(weights)\n\n        # Find the smallest bottleneck (weight of heaviest partition)\n        bottleneck = ds_utils._rb_partition_balanced(weights_, num_parts, eps=eps)\n\n        # Now compute that partitioning\n        parts, success = ds_utils._lprobe(weights_, num_parts, bottleneck)\n        assert success\n\n        return parts\n\n    def _partition_layers(self, method='uniform'):\n        num_stages = self._topo.get_dim('pipe')\n        stage_id = 
self._topo.get_coord(self.global_rank).pipe\n\n        if self.global_rank == 0:\n            logger.info(f'Partitioning pipeline stages with method {method}')\n\n        method = method.lower()\n\n        # Each stage gets a simple uniform number of layers.\n        if method == 'uniform':\n            num_layers = len(self._layer_specs)\n            self.parts = self._partition_uniform(num_items=num_layers,\n                                                 num_parts=num_stages)\n        elif method == 'parameters':\n            param_counts = self._count_layer_params()\n            self.parts = self._partition_balanced(weights=param_counts,\n                                                  num_parts=num_stages)\n        elif method.startswith('type:'):\n            layertype = method.split(':')[1]\n            binary_weights = [0] * len(self._layer_specs)\n            for idx in self._find_layer_type(layertype):\n                binary_weights[idx] = 1\n            self.parts = self._partition_balanced(weights=binary_weights,\n                                                  num_parts=num_stages)\n        elif method.startswith('manual:'):\n            msplit = method.split(':')\n            layernum = int(msplit[1])\n            layerparts = msplit[2].split(',')\n            assert len(self._layer_specs) == layernum # failsafe check for layer num\n            assert num_stages == len(layerparts) - 1 # failsafe check for num stages\n            self.parts = list(map(int, layerparts))\n        elif method == 'profile':\n            raise NotImplementedError(f'Partitioning method {method} not implemented.')\n        else:\n            raise NotImplementedError(f'Partitioning method {method} not implemented.')\n\n        # Print some information on the partitioning.\n        if self.global_rank == 0:\n            for stage in range(num_stages):\n                start = self.parts[stage]\n                stop = self.parts[stage + 1]\n                print(f'stage={stage} layers={stop - start}')\n                for idx, layer in enumerate(self._layer_specs[start:stop]):\n                    name = str(layer)\n                    if isinstance(layer, LayerSpec):\n                        name = layer.typename.__name__\n                    elif isinstance(layer, nn.Module):\n                        name = layer.__class__.__name__\n                    else:\n                        try:\n                            name = layer.__name__\n                        except AttributeError:\n                            pass\n                    print(f'    {idx+start:2d}: {name}')\n            if self.loss_fn:\n                try:\n                    print(f'  loss: {self.loss_fn.__name__}')\n                except AttributeError:\n                    print(f'  loss: {self.loss_fn.__class__.__name__}')\n\n        self._set_bounds(start=self.parts[stage_id], stop=self.parts[stage_id + 1])\n\n    def allreduce_tied_weight_gradients(self):\n        '''All reduce the gradients of the tied weights between tied stages'''\n        for key, comm in self.tied_comms.items():\n            weight = getattr(self.tied_modules[key], comm['weight_attr'])\n            dist.all_reduce(weight.grad, group=comm['group'])\n\n    def _synchronize_tied_weights(self):\n        for key, comm in self.tied_comms.items():\n            dist.broadcast(\n                getattr(comm['module'],\n                        comm['weight_attr']),\n                src=min(comm['ranks']),\n                
group=comm['group'],\n            )\n\n    def _index_tied_modules(self):\n        ''' Build communication structures for tied modules. '''\n        tied_comms = {}\n        if self._topo.get_dim('pipe') == 1:\n            return tied_comms\n\n        specs = self._layer_specs\n        tie_keys = set(s.key for s in specs if isinstance(s, TiedLayerSpec))\n        for key in tie_keys:\n            # Find the layers that the tied module appears in\n            tied_layers = []\n            for idx, layer in enumerate(specs):\n                if isinstance(layer, TiedLayerSpec) and layer.key == key:\n                    tied_layers.append(idx)\n            # Find all stages with this tied module\n            # TODO: Would be nice to remove the nested data/model parallelism loops and\n            # TODO: instead generalize in some way, since we really just care about the\n            # TODO: stage that owns the tied layer. Then loop over each (dp, mp, ...)\n            # TODO: fiber to generate process groups.\n            tied_stages = set(self.stage_owner(idx) for idx in tied_layers)\n            for dp in range(self._grid.data_parallel_size):\n                for mp in range(self._grid.model_parallel_size):\n                    tied_ranks = []\n                    for s in sorted(tied_stages):\n                        if self._grid.model_parallel_size > 1:\n                            tied_ranks.append(\n                                self._grid.stage_to_global(stage_id=s,\n                                                           data=dp,\n                                                           model=mp))\n                        else:\n                            tied_ranks.append(\n                                self._grid.stage_to_global(stage_id=s,\n                                                           data=dp))\n                    group = dist.new_group(ranks=tied_ranks)\n\n                    # Record this tied module if we own a local copy of it.\n                    if self.global_rank in tied_ranks:\n                        assert key in self.tied_modules\n                        if key in self.tied_modules:\n                            tied_comms[key] = {\n                                'ranks': tied_ranks,\n                                'group': group,\n                                'weight_attr': self.tied_weight_attrs[key],\n                                'module': self.tied_modules[key],\n                            }\n                            # Only count the tied module once in the eyes of the FP16 optimizer\n                            if self.global_rank != tied_ranks[0]:\n                                for p in self.tied_modules[key].parameters():\n                                    p.model_parallel = False\n        '''\n        if len(tied_comms) > 0:\n            print(f'RANK={self.global_rank} tied_comms={tied_comms}')\n        '''\n\n        return tied_comms\n\n    def partitions(self):\n        return self.parts\n\n    def stage_owner(self, layer_idx):\n        assert 0 <= layer_idx < self._num_layers\n        for stage in range(self._topo.get_dim('pipe')):\n            if self.parts[stage] <= layer_idx < self.parts[stage + 1]:\n                return stage\n        raise RuntimeError(f'Layer {layer_idx} not owned? 
parts={self.parts}')\n\n    def _set_bounds(self, start=None, stop=None):\n        \"\"\"Manually define the range of layers that will be built on this process.\n\n        These boundaries are treated as list slices and so start is inclusive and stop is\n        exclusive. The default of None for both results in all layers being built\n        locally.\n        \"\"\"\n        self._local_start = start\n        self._local_stop = stop\n\n    def set_checkpoint_interval(self, interval):\n        assert interval >= 0\n        self.activation_checkpoint_interval = interval\n\n    def topology(self):\n        \"\"\" ProcessTopology object to query process mappings. \"\"\"\n        return self._topo\n\n    def mpu(self):\n        return self._grid\n\n    def num_pipeline_stages(self):\n        return self._topo.get_dim('pipe')\n\n    def ckpt_prefix(self, checkpoints_path, tag):\n        \"\"\"Build a prefix for all checkpoint files written by this module. \"\"\"\n        # All checkpoint files start with this\n        rank_name = 'module'\n\n        # Data parallelism is omitted from the naming convention because we are agnostic\n        # to this in the checkpoint.\n        omit_dims = frozenset(['data'])\n        axes = [a for a in self._grid._topo.get_axis_names() if a not in omit_dims]\n        for dim in axes:\n            rank = getattr(self._grid._topo.get_coord(rank=self.global_rank), dim)\n            rank_name += f'-{dim}_{rank:02d}'\n\n        ckpt_name = os.path.join(checkpoints_path, str(tag), rank_name)\n        return ckpt_name\n\n    def ckpt_layer_path(self, ckpt_dir, local_layer_idx):\n        \"\"\"Customize a prefix for a specific pipeline module layer. \"\"\"\n        idx = local_layer_idx + self._local_start\n        layer_ckpt_path = os.path.join(ckpt_dir, f'layer_{idx:02d}')\n        rank_repr = self._grid._topo.get_rank_repr(rank=self.global_rank)\n        if rank_repr != '':\n            layer_ckpt_path += f'-{rank_repr}'\n        layer_ckpt_path += '-model_states.pt'\n        return layer_ckpt_path\n\n    def save_state_dict(self, save_dir):\n        if self._grid.data_parallel_id != 0:\n            return\n\n        os.makedirs(save_dir, exist_ok=True)\n        layer_offset = self._local_start\n        for idx, layer in enumerate(self.forward_funcs):\n            model_ckpt_path = self.ckpt_layer_path(save_dir, idx)\n            if not hasattr(layer, 'state_dict'):\n                continue\n            torch.save(layer.state_dict(), model_ckpt_path)\n\n    def load_state_dir(self, load_dir, strict=True):\n        rank = dist.get_rank()\n\n        layer_offset = self._local_start\n        for idx, layer in enumerate(self.forward_funcs):\n            # Functions, etc. 
will not have state_dicts\n            if not hasattr(layer, 'load_state_dict'):\n                continue\n\n            model_ckpt_path = self.ckpt_layer_path(load_dir, idx)\n            layer.load_state_dict(torch.load(model_ckpt_path,\n                                             map_location=lambda storage, loc: storage),\n                                  strict=strict)\n            if self._grid.data_parallel_id == 0:\n                logger.info(\n                    f'RANK={self.global_rank} Loaded layer={idx+layer_offset} file={model_ckpt_path}'\n                )\n\n        self._synchronize_tied_weights()\n\n    def _is_checkpointable(self, funcs):\n        if self.__class__.__name__ == 'GPT2ModelPipe':\n            return all('ParallelTransformerLayerPipe' in f.__class__.__name__\n                       for f in funcs)\n\n        params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]\n        return any(len(list(p)) > 0 for p in params)\n"
  },
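  {
    "path": "examples/sketches/partitioning_sketch.py",
    "content": "# Hypothetical sketch, NOT part of the original source tree: reproduces the\n# uniform layer-partitioning arithmetic from VeGiantModule._partition_uniform\n# and the boundary search from stage_owner() so they can be sanity-checked\n# without torch or deepspeed installed.\nfrom math import floor\n\ndef partition_uniform(num_items, num_parts):\n    # parts has num_parts + 1 boundaries; stage p owns layers\n    # [parts[p], parts[p + 1]).\n    parts = [0] * (num_parts + 1)\n    if num_items <= num_parts:\n        for p in range(num_parts + 1):\n            parts[p] = min(p, num_items)\n        return parts\n    expected_chunksize = num_items / num_parts\n    for p in range(num_parts):\n        parts[p] = min(floor(expected_chunksize * p), num_items)\n    parts[num_parts] = num_items\n    return parts\n\ndef stage_owner(parts, layer_idx):\n    # Linear scan over stage boundaries, mirroring module.py's stage_owner().\n    for stage in range(len(parts) - 1):\n        if parts[stage] <= layer_idx < parts[stage + 1]:\n            return stage\n    raise RuntimeError(f'Layer {layer_idx} not owned? parts={parts}')\n\nif __name__ == '__main__':\n    parts = partition_uniform(num_items=10, num_parts=4)\n    print(parts)                  # [0, 2, 5, 7, 10]\n    print(stage_owner(parts, 6))  # 2: stage 2 owns layers [5, 7)\n"
  },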
  {
    "path": "src/veGiantModel/engine/p2p.py",
    "content": "# Copyright (c) 2021, ByteDance Inc.  All rights reserved.\r\n# Copyright 2019 The Microsoft DeepSpeed Team\r\n'''\r\nCopyright 2019 The Microsoft DeepSpeed Team\r\n'''\r\n\r\nimport os\r\nimport torch\r\nimport torch.distributed as dist\r\nfrom deepspeed.utils import logger, log_dist\r\n\r\nENABLE_PYTORCH_BROADCAST = os.environ.get(\"ENABLE_PYTORCH_BROADCAST\", \"0\") != \"0\"\r\n\r\ntry:\r\n    if not ENABLE_PYTORCH_BROADCAST:\r\n        import byteps.torch as bps\r\n    else:\r\n        print(\"BytePS import is disabled\", flush=True)\r\n        bps = None\r\nexcept ImportError:\r\n    print(\"BytePS is not installed\")\r\n    bps = None\r\n\r\n_groups = None\r\n_grid = None\r\n\r\nDS_PIPE_VERBOSE = os.environ.get('DS_PIPE_VERBOSE', \"0\") != \"0\"\r\n\r\ndid_recv = False\r\nsend_stream = None\r\nrecv_stream = None \r\n\r\n\r\nbps_send_handles = {}\r\nbps_recv_handles = {}\r\n\r\n\r\n#initializes adjacent process groups\r\n#run this only after torch.distributed.init_process_group() has been called\r\ndef init_process_groups(grid):\r\n    global _groups, _grid\r\n    _grid = grid\r\n\r\n    assert _grid.pipe_parallel_size > 1, \"There is no model parallelism\"\r\n\r\n    _groups = [dist.new_group(ranks=group) for group in _grid.p2p_groups]\r\n\r\n\r\ndef _is_valid_send_recv(src_stage, dest_stage):\r\n    first_stage = 0\r\n    last_stage = _grid.pipe_parallel_size - 1\r\n    assert abs(src_stage-dest_stage) == 1 or \\\r\n        (src_stage == first_stage and dest_stage == last_stage) or \\\r\n        (src_stage == last_stage and dest_stage == first_stage), \\\r\n    \"Functionality currently limited to send and receive between adjacent ranks only\"\r\n\r\n\r\ndef send(tensor, dest_stage, async_op=False):\r\n    global _groups\r\n\r\n    async_op = False\r\n    src_stage = _grid.get_stage_id()\r\n    _is_valid_send_recv(src_stage, dest_stage)\r\n\r\n    group = _get_send_recv_group(src_stage, dest_stage)\r\n    src_rank = _grid.stage_to_global(stage_id=src_stage)\r\n\r\n    import torch\r\n    if tensor.dtype != torch.float32 and DS_PIPE_VERBOSE:\r\n        print('warning: p2p send', tensor.dtype, tensor.shape, flush=True)\r\n    return _send(tensor, src_rank, group, async_op)\r\n\r\ndef _bps_get_name(src, dest, name, suffix):\r\n    return \"_\".join([str(src), str(dest), str(name), str(suffix)])\r\n\r\ndef bps_send(tensor, dest_stage, name, index, async_op=True):\r\n    global bps_send_handles\r\n\r\n    src_stage = _grid.get_stage_id()\r\n    _is_valid_send_recv(src_stage, dest_stage)\r\n    src_rank = _grid.stage_to_global(stage_id=src_stage)\r\n    dest_rank = _grid.stage_to_global(stage_id=dest_stage)\r\n    name = _bps_get_name(src_rank, dest_rank, name, index)\r\n    if name not in bps_send_handles:\r\n        # XXX hard-code max number of tensors for this name\r\n        bps_send_handles[name] = [None] * 10\r\n    else:\r\n        handle = bps_send_handles[name][index]\r\n        if handle is not None:\r\n            bps.synchronize(handle)\r\n    handle = bps.send_async(tensor, dest_rank, name=name)\r\n    # XXX\r\n    if not async_op:\r\n        bps.synchronize(handle)\r\n    else:\r\n        bps_send_handles[name][index] = handle\r\n    return tensor\r\n\r\ndef bps_sync(src_stage, name, index=0):\r\n    dest_stage = _grid.get_stage_id()\r\n    _is_valid_send_recv(src_stage, dest_stage)\r\n    src_rank = _grid.stage_to_global(stage_id=src_stage)\r\n    dest_rank = _grid.stage_to_global(stage_id=dest_stage)\r\n    name = _bps_get_name(src_rank, dest_rank, name, 
index)\r\n    if name in bps_recv_handles:\r\n        handle = bps_recv_handles[name][index]\r\n        if handle is not None:\r\n            bps.synchronize(handle)\r\n\r\ndef bps_sync_all():\r\n    for name, handles in bps_send_handles.items():\r\n        for handle in handles:\r\n            if handle is not None:\r\n                bps.synchronize(handle)\r\n\r\n    for name, handles in bps_recv_handles.items():\r\n        for handle in handles:\r\n            if handle is not None:\r\n                bps.synchronize(handle)\r\n\r\ndef bps_recv(tensor, src_stage, name, index=0, async_op=True):\r\n    global bps_recv_handles\r\n\r\n    dest_stage = _grid.get_stage_id()\r\n    _is_valid_send_recv(src_stage, dest_stage)\r\n    src_rank = _grid.stage_to_global(stage_id=src_stage)\r\n    dest_rank = _grid.stage_to_global(stage_id=dest_stage)\r\n    name = _bps_get_name(src_rank, dest_rank, name, index)\r\n    if name not in bps_recv_handles:\r\n        # XXX hard-code max number of tensors for this name\r\n        bps_recv_handles[name] = [None] * 10\r\n    else:\r\n        handle = bps_recv_handles[name][index]\r\n        if handle is not None:\r\n            bps.synchronize(handle)\r\n    handle = bps.recv_async(tensor, src_rank, name=name)\r\n    if not async_op:\r\n        bps.synchronize(handle)\r\n    else:\r\n        bps_recv_handles[name][index] = handle\r\n    return tensor\r\n\r\n\r\ndef _send(tensor, src_rank, group, async_op):\r\n    global did_recv\r\n    return dist.broadcast(tensor, src_rank, group=group, async_op=async_op)\r\n\r\ndef send_grads(tensor, grid, async_op=False):\r\n    async_op = False\r\n    if  grid.send_grads_src_rank == grid.global_rank:\r\n        # print(f'start rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _send_grad_src_rank: {grid.send_grads_src_rank}, send group: {grid.send_grads_group}, send_grad_groups: {grid.send_grads_proc_group}', flush=True)\r\n        _send(tensor, grid.send_grads_src_rank, grid.send_grads_proc_group, async_op)\r\n        # print(f'finis rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _send_grad_src_rank: {grid.send_grads_src_rank}, send group: {grid.send_grads_group}', flush=True)\r\n    else:\r\n        # print(f'finish fast rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _send_grad_src_rank: {grid.send_grads_src_rank}, send group: {grid.send_grads_group}', flush=True)\r\n        pass\r\n\r\ndef _recv(tensor, src_rank, group, async_op):\r\n    global did_recv\r\n    tensor = dist.broadcast(tensor, src_rank, group=group, async_op=async_op)\r\n    did_recv = True\r\n    return tensor\r\n\r\ndef recv_grads(tensor, grid, async_op=False):\r\n    async_op = False\r\n    # print(f'start rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _recv_grad_src_rank: {grid.recv_grads_src_rank}, recv group: {grid.recv_grads_group}, recv_grad_groups: {grid.recv_grads_proc_group}', flush=True)\r\n    _recv(tensor, grid.recv_grads_src_rank, grid.recv_grads_proc_group, async_op)\r\n    # print(f'finish rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _recv_grad_src_rank: {grid.recv_grads_src_rank}, recv group: {grid.recv_grads_group}', flush=True)\r\n\r\n\r\ndef send_activations(tensor, grid, async_op=False):\r\n    async_op = False\r\n    if  grid.send_activation_src_rank == grid.global_rank:\r\n        # print(f'start rank: {grid.global_rank}, stage_id: 
{grid.stage_id}, mp_id: {grid.model_parallel_id}, _send_grad_src_rank: {grid.send_grads_src_rank}, send group: {grid.send_grads_group}, send_grad_groups: {grid.send_grads_proc_group}', flush=True)\r\n        _send(tensor, grid.send_activation_src_rank, grid.send_activation_proc_group, async_op)\r\n        # print(f'finis rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _send_grad_src_rank: {grid.send_grads_src_rank}, send group: {grid.send_grads_group}', flush=True)\r\n    else:\r\n        # print(f'finish fast rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _send_grad_src_rank: {grid.send_grads_src_rank}, send group: {grid.send_grads_group}', flush=True)\r\n        pass\r\n\r\ndef recv_activations(tensor, grid, async_op=False):\r\n    async_op = False\r\n    # print(f'start rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _recv_grad_src_rank: {grid.recv_grads_src_rank}, recv group: {grid.recv_grads_group}, recv_grad_groups: {grid.recv_grads_proc_group}', flush=True)\r\n    _recv(tensor, grid.recv_activation_src_rank, grid.recv_activation_proc_group, async_op)\r\n    # print(f'finish rank: {grid.global_rank}, stage_id: {grid.stage_id}, mp_id: {grid.model_parallel_id}, _recv_grad_src_rank: {grid.recv_grads_src_rank}, recv group: {grid.recv_grads_group}', flush=True)\r\n\r\ndef recv(tensor, src_stage, async_op=False):\r\n    global _groups\r\n    global did_recv\r\n\r\n    async_op = False\r\n    dest_stage = _grid.get_stage_id()\r\n    _is_valid_send_recv(src_stage, dest_stage)\r\n\r\n    group = _get_send_recv_group(src_stage, dest_stage)\r\n    src_rank = _grid.stage_to_global(stage_id=src_stage)\r\n    return _recv(tensor, src_rank, group, async_op)\r\n\r\n\r\ndef barrier(stage_id):\r\n    global _groups, _grid\r\n    group_id = _grid.stage_to_global(stage_id=stage_id)\r\n    if (dist.get_rank() >= 0):\r\n        print(\"Barrier Group ID\", group_id)\r\n        print(\"Barrier Group\", _grid.p2p_groups[group_id])\r\n    dist.barrier(group=_groups[group_id])\r\n    if (dist.get_rank() >= 0):\r\n        print(\"Exiting Barrier \", group_id)\r\n\r\n\r\ndef _get_send_recv_group(src_stage, dest_stage):\r\n    '''The group id is always the smaller rank unless it's a wrap-around.'''\r\n\r\n    stage_id = None\r\n\r\n    first_stage = 0\r\n    last_stage = _grid.pipe_parallel_size - 1\r\n\r\n    if (src_stage == first_stage and dest_stage == last_stage\r\n            or dest_stage == first_stage and src_stage == last_stage):\r\n        stage_id = last_stage\r\n    elif src_stage > dest_stage:\r\n        stage_id = dest_stage\r\n    else:\r\n        stage_id = src_stage\r\n    '''group_id corresponds to group of [group_id, group_id+1]\r\n     unless group_id is the rank of the last stage\r\n     in which case group_id corresponds to group[group_id-num_stages+1, group_id]\r\n     '''\r\n    group_id = _grid.stage_to_global(stage_id=stage_id)\r\n\r\n    return _groups[group_id]\r\n"
  },
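  {
    "path": "examples/sketches/p2p_group_sketch.py",
    "content": "# Hypothetical sketch, NOT part of the original source tree: restates the two\n# pure rules from p2p.py so they can be checked without a distributed setup.\n# bps_get_name mirrors the tensor-naming helper _bps_get_name; group_stage_id\n# mirrors the group-selection rule in _get_send_recv_group (the smaller stage\n# id, except for the wrap-around pair between the first and last stages).\n\ndef bps_get_name(src, dest, name, suffix):\n    # Same formatting as p2p._bps_get_name.\n    return '_'.join([str(src), str(dest), str(name), str(suffix)])\n\ndef group_stage_id(src_stage, dest_stage, num_stages):\n    first_stage = 0\n    last_stage = num_stages - 1\n    if (src_stage == first_stage and dest_stage == last_stage\n            or dest_stage == first_stage and src_stage == last_stage):\n        return last_stage  # the wrap-around pair uses the last stage's group\n    return min(src_stage, dest_stage)\n\nif __name__ == '__main__':\n    print(bps_get_name(0, 1, 'activation', 3))   # 0_1_activation_3\n    print(group_stage_id(2, 3, num_stages=4))    # 2\n    print(group_stage_id(0, 3, num_stages=4))    # 3 (wrap-around)\n"
  },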
  {
    "path": "src/veGiantModel/engine/schedule.py",
    "content": "# Copyright (c) 2021, ByteDance Inc.  All rights reserved.\r\nfrom deepspeed.runtime.pipe.schedule import (\r\n    BufferOpInstruction,PipeInstruction,\r\n    ReduceTiedGrads,ReduceGrads,OptimizerStep,\r\n    LoadMicroBatch,PipeSchedule,TrainSchedule,\r\n)\r\n\r\nimport os\r\n\r\nBYTEPS_REDUCED_MEM = os.environ.get('BYTEPS_REDUCED_MEM', '1') != '0'\r\n\r\nclass BytePSInferenceSchedule(PipeSchedule):\r\n    \"\"\"A schedule for inferencing batches using pipeline parallelism.\r\n    \"\"\"\r\n    def __init__(self, micro_batches, stages, stage_id, prefetch=True):\r\n        super().__init__(micro_batches, stages, stage_id)\r\n        self.prefetch = prefetch\r\n\r\n    def steps(self):\r\n        \"\"\"\"\"\"\r\n        total_steps = self.micro_batches + self.stages - 1\r\n        for step_id in range(total_steps):\r\n            cmds = []\r\n            micro_batch_id = step_id - self.stage_id\r\n\r\n            buffer_id = micro_batch_id % self.num_pipe_buffers()\r\n            batch_is_valid = self._valid_micro_batch(micro_batch_id)\r\n\r\n            if not self.prefetch:    \r\n                if batch_is_valid:\r\n                    if self.is_first_stage or self.is_last_stage:\r\n                        cmds.append(LoadMicroBatch(buffer_id))\r\n                    if self._valid_stage(self.prev_stage):\r\n                        cmds.append(BytePSRecvActivation(buffer_id))\r\n                        cmds.append(BytePSSyncActivation(buffer_id))\r\n                    cmds.append(BytePSForwardPass(buffer_id))\r\n                    if self._valid_stage(self.next_stage):\r\n                        cmds.append(BytePSSendActivation(buffer_id))\r\n            else:\r\n                next_buffer_id = (micro_batch_id + 1) % self.num_pipe_buffers()\r\n                next_batch_is_valid = self._valid_micro_batch(micro_batch_id + 1)\r\n                # micro_batch starts at 0. 
Get the current batch, and start prefetching\r\n                if micro_batch_id == 0:\r\n                    if self.is_first_stage or self.is_last_stage:\r\n                        cmds.append(LoadMicroBatch(buffer_id))\r\n                    if self._valid_stage(self.prev_stage):\r\n                        cmds.append(BytePSRecvActivation(buffer_id))\r\n                        if next_batch_is_valid:\r\n                            cmds.append(BytePSRecvActivation(next_buffer_id))\r\n                        cmds.append(BytePSSyncActivation(buffer_id))\r\n                    cmds.append(BytePSForwardPass(buffer_id))\r\n                    if self._valid_stage(self.next_stage):\r\n                        cmds.append(BytePSSendActivation(buffer_id))\r\n                elif batch_is_valid:\r\n                    # After micro_batch 0, we prefetch the next one,\r\n                    # and wait for the current one\r\n                    if self._valid_stage(self.prev_stage) and next_batch_is_valid:\r\n                        cmds.append(BytePSRecvActivation(next_buffer_id))\r\n                    if self.is_first_stage or self.is_last_stage:\r\n                        cmds.append(LoadMicroBatch(buffer_id))\r\n                    if self._valid_stage(self.prev_stage):\r\n                        cmds.append(BytePSSyncActivation(buffer_id))\r\n                    cmds.append(BytePSForwardPass(buffer_id))\r\n                    if self._valid_stage(self.next_stage):\r\n                        cmds.append(BytePSSendActivation(buffer_id))\r\n\r\n            yield cmds\r\n\r\n    def num_pipe_buffers(self):\r\n        \"\"\"Return the number of pipeline buffers used for inferencing.\r\n\r\n        Returns:\r\n            ``min(stages + 1, micro_batches)`` when ``BYTEPS_REDUCED_MEM`` is set,\r\n            otherwise ``min(2 * stages, micro_batches)``, with a floor of 2.\r\n        \"\"\"\r\n        buffers = min(self.micro_batches, self.stages * 2)\r\n        if BYTEPS_REDUCED_MEM:\r\n            buffers = min(self.stages + 1, self.micro_batches)\r\n        return max(2, buffers)\r\n\r\n\r\nclass BytePSTrainSchedule(TrainSchedule):\r\n    \"\"\"A schedule for training a batch using hybrid parallelism.\r\n\r\n    Pipeline parallelism is extracted through gradient accumulation and thus\r\n    convergence follows that of a data parallel approach with the same batch\r\n    size.\r\n    \"\"\"\r\n    def __init__(self, micro_batches, stages, stage_id, prefetch=True):\r\n        super().__init__(micro_batches, stages, stage_id)\r\n        self.prefetch = prefetch and micro_batches > 1\r\n        if not self.prefetch:\r\n            print('BYTEPS NO PREFETCH STEPS', flush=True)\r\n\r\n    def steps(self):\r\n        if self.prefetch:\r\n            return self._steps()\r\n        else:\r\n            return self._steps_no_prefetch()\r\n\r\n    def _steps(self):\r\n        \"\"\"\"\"\"\r\n        total_steps = 2 * (self.micro_batches + self.stages - 1)\r\n        for step_id in range(total_steps):\r\n            # Map the step of the pipeline to the micro-batch id and also whether it is a\r\n            # forward or backward pass step.\r\n            cmds = []\r\n            micro_batch_id, is_forward = self._step_to_micro_batch(step_id)\r\n            batch_is_valid = self._valid_micro_batch(micro_batch_id)\r\n            if not batch_is_valid:\r\n                if step_id == total_steps - 1:\r\n                    cmds.append(BytePSSyncAll())\r\n                    cmds.append(ReduceTiedGrads())\r\n                    cmds.append(ReduceGrads())\r\n                    cmds.append(OptimizerStep())\r\n                    yield cmds\r\n                continue\r\n            curr_buffer = self._buffer_idx(micro_batch_id)\r\n\r\n            # try to find the next valid batch\r\n            next_step_id = step_id + 1\r\n            next_micro_batch_id, next_is_forward, next_batch_is_valid = None, None, None\r\n            while next_step_id < total_steps:\r\n                next_micro_batch_id, next_is_forward = self._step_to_micro_batch(next_step_id)\r\n                next_batch_is_valid = self._valid_micro_batch(next_micro_batch_id)\r\n                if next_batch_is_valid:\r\n                    break\r\n                next_step_id += 1\r\n\r\n            next_buffer = None\r\n            if next_batch_is_valid:\r\n                next_buffer = self._buffer_idx(next_micro_batch_id)\r\n\r\n            if micro_batch_id == 0 and is_forward:\r\n                # first/last stage loads\r\n                if self.stage_id == 0 or self.stage_id == self.stages - 1:\r\n                    cmds.append(LoadMicroBatch(curr_buffer))\r\n                # fetch\r\n                if self._valid_stage(self.prev_stage):\r\n                    cmds.append(BytePSRecvActivation(curr_buffer))\r\n                # pre-fetch\r\n                if next_batch_is_valid:\r\n                    if self._valid_stage(self.prev_stage) and next_is_forward:\r\n                        cmds.append(BytePSRecvActivation(next_buffer))\r\n                    if self._valid_stage(self.next_stage) and not next_is_forward:\r\n                        cmds.append(BytePSRecvGrad(next_buffer))\r\n                # sync and compute\r\n                if self._valid_stage(self.prev_stage):\r\n                    cmds.append(BytePSSyncActivation(curr_buffer))\r\n                cmds.append(BytePSForwardPass(curr_buffer))\r\n                if self._valid_stage(self.next_stage):\r\n                    cmds.append(BytePSSendActivation(curr_buffer))\r\n            else:\r\n                # prefetch\r\n                if next_batch_is_valid:\r\n                    if self._valid_stage(self.prev_stage) and next_is_forward:\r\n                        cmds.append(BytePSRecvActivation(next_buffer))\r\n                    if self._valid_stage(self.next_stage) and not next_is_forward:\r\n                        cmds.append(BytePSRecvGrad(next_buffer))\r\n                if is_forward:\r\n                    if self.stage_id == 0 or self.stage_id == self.stages - 1:\r\n                        # First/last stage loads\r\n                        cmds.append(LoadMicroBatch(curr_buffer))\r\n                    if self._valid_stage(self.prev_stage):\r\n                        cmds.append(BytePSSyncActivation(curr_buffer))\r\n                    cmds.append(BytePSForwardPass(curr_buffer))\r\n                    if self._valid_stage(self.next_stage):\r\n                        cmds.append(BytePSSendActivation(curr_buffer))\r\n                else:\r\n                    if self._valid_stage(self.next_stage):\r\n                        cmds.append(BytePSSyncGrad(curr_buffer))\r\n                    cmds.append(BytePSBackwardPass(curr_buffer))\r\n                    if self._valid_stage(self.prev_stage):\r\n                        cmds.append(BytePSSendGrad(curr_buffer))\r\n\r\n            # Model step at the end of the batch\r\n            if step_id == total_steps - 1:\r\n                cmds.append(BytePSSyncAll())\r\n                cmds.append(ReduceTiedGrads())\r\n                cmds.append(ReduceGrads())\r\n                cmds.append(OptimizerStep())\r\n\r\n            yield cmds\r\n\r\n    def _steps_no_prefetch(self):\r\n        \"\"\"\"\"\"\r\n        total_steps = 2 * (self.micro_batches + self.stages - 1)\r\n        for step_id in range(total_steps):\r\n            # Map the step of the pipeline to the micro-batch id and also whether it is a\r\n            # forward or backward pass step.\r\n            cmds = []\r\n            micro_batch_id, is_forward = self._step_to_micro_batch(step_id)\r\n            batch_is_valid = self._valid_micro_batch(micro_batch_id)\r\n            if not batch_is_valid:\r\n                if step_id == total_steps - 1:\r\n                    cmds.append(BytePSSyncAll())\r\n                    cmds.append(ReduceTiedGrads())\r\n                    cmds.append(ReduceGrads())\r\n                    cmds.append(OptimizerStep())\r\n                    yield cmds\r\n                continue\r\n\r\n            curr_buffer = self._buffer_idx(micro_batch_id)\r\n\r\n            if is_forward:\r\n                if self._valid_stage(self.prev_stage):\r\n                    cmds.append(BytePSRecvActivation(curr_buffer))\r\n                    cmds.append(BytePSSyncActivation(curr_buffer))\r\n                if self.stage_id == 0 or self.stage_id == self.stages - 1:\r\n                    # First/last stage loads\r\n                    cmds.append(LoadMicroBatch(curr_buffer))\r\n                cmds.append(BytePSForwardPass(curr_buffer))\r\n                if self._valid_stage(self.next_stage):\r\n                    cmds.append(BytePSSendActivation(curr_buffer))\r\n            else:\r\n                if self._valid_stage(self.next_stage):\r\n                    cmds.append(BytePSRecvGrad(curr_buffer))\r\n                    cmds.append(BytePSSyncGrad(curr_buffer))\r\n                cmds.append(BytePSBackwardPass(curr_buffer))\r\n                if self._valid_stage(self.prev_stage):\r\n                    cmds.append(BytePSSendGrad(curr_buffer))\r\n\r\n            # Model step at the end of the batch\r\n            if step_id == total_steps - 1:\r\n                cmds.append(BytePSSyncAll())\r\n                cmds.append(ReduceTiedGrads())\r\n                cmds.append(ReduceGrads())\r\n                cmds.append(OptimizerStep())\r\n\r\n            yield cmds\r\n\r\n    def num_pipe_buffers(self):\r\n        \"\"\"Return the number of pipeline buffers used for training.\r\n\r\n        Returns:\r\n            ``min(stages + 1, micro_batches)`` when ``BYTEPS_REDUCED_MEM`` is set,\r\n            otherwise ``min(2 * stages, micro_batches)``, with a floor of 2.\r\n        \"\"\"\r\n        buffers = min(self.micro_batches, self.stages * 2)\r\n        if BYTEPS_REDUCED_MEM:\r\n            buffers = min(self.stages + 1, self.micro_batches)\r\n        return max(2, buffers)\r\n\r\n\r\nclass BytePSSendActivation(BufferOpInstruction):\r\n    pass\r\n\r\nclass BytePSRecvActivation(BufferOpInstruction):\r\n    pass\r\n\r\nclass BytePSSyncActivation(BufferOpInstruction):\r\n    pass\r\n\r\nclass BytePSSyncGrad(BufferOpInstruction):\r\n    pass\r\n\r\nclass BytePSSendGrad(BufferOpInstruction):\r\n    pass\r\n\r\nclass BytePSRecvGrad(BufferOpInstruction):\r\n    pass\r\n\r\nclass BytePSForwardPass(BufferOpInstruction):\r\n    pass\r\n\r\nclass BytePSBackwardPass(BufferOpInstruction):\r\n    pass\r\n\r\nclass BytePSSyncAll(PipeInstruction):\r\n    pass\r\n"
  },
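  {
    "path": "examples/sketches/pipe_buffer_sizing_sketch.py",
    "content": "# Hypothetical sketch, NOT part of the original source tree: evaluates the\n# pipeline-buffer sizing rule shared by num_pipe_buffers() in the BytePS\n# schedules above, including the BYTEPS_REDUCED_MEM variant that caps the\n# count at stages + 1 instead of 2 * stages.\n\ndef num_pipe_buffers(micro_batches, stages, reduced_mem=True):\n    buffers = min(micro_batches, stages * 2)\n    if reduced_mem:\n        buffers = min(stages + 1, micro_batches)\n    # Never fewer than two buffers, so prefetch has somewhere to land.\n    return max(2, buffers)\n\nif __name__ == '__main__':\n    for micro_batches, stages in [(1, 4), (4, 4), (16, 4), (16, 8)]:\n        print(micro_batches, stages,\n              num_pipe_buffers(micro_batches, stages),\n              num_pipe_buffers(micro_batches, stages, reduced_mem=False))\n"
  },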
  {
    "path": "src/veGiantModel/engine/topology.py",
    "content": "# Copyright (c) 2021, ByteDance Inc.  All rights reserved.\r\n# Copyright 2019 The Microsoft DeepSpeed Team\r\n\r\nfrom deepspeed.utils import log_dist\r\n\r\nimport torch.distributed as dist\r\n\r\nfrom collections import namedtuple\r\nfrom itertools import product as cartesian_product\r\nimport logging, os\r\n\r\nimport torch\r\n\r\nclass ProcessTopology:\r\n    \"\"\" Manages the mapping of n-dimensional Cartesian coordinates to linear\r\n    indices. This mapping is used to map the rank of processes to the grid\r\n    for various forms of parallelism.\r\n\r\n    Each axis of the tensor is accessed by its name. The provided ordering\r\n    of the axes defines the layout of the topology. ProcessTopology uses a \"row-major\"\r\n    layout of the tensor axes, and so axes=['x', 'y'] would map coordinates (x,y) and\r\n    (x,y+1) to adjacent linear indices. If instead axes=['y', 'x'] was used, coordinates\r\n    (x,y) and (x+1,y) would be adjacent.\r\n\r\n    Some methods return ProcessCoord namedtuples.\r\n    \"\"\"\r\n    def __init__(self, axes, dims):\r\n        \"\"\"Create a mapping of n-dimensional tensor coordinates to linear indices.\r\n\r\n        Arguments:\r\n            axes (list): the names of the tensor axes\r\n            dims (list): the dimension (length) of each axis of the topology tensor\r\n        \"\"\"\r\n\r\n        self.axes = axes  # names of each topology axis\r\n        self.dims = dims  # length of each topology axis\r\n\r\n        # This is actually a class that lets us hash {'row':3, 'col':2} mappings\r\n        self.ProcessCoord = namedtuple('ProcessCoord', axes)\r\n\r\n        self.mapping = {}\r\n        ranges = [range(d) for d in dims]\r\n        # example: 1, (0,0,1)\r\n        for global_rank, coord in enumerate(cartesian_product(*ranges)):\r\n            key = {axis: coord[self.axes.index(axis)] for axis in self.axes}\r\n            key = self.ProcessCoord(**key)\r\n            # for example, {ProcessCoord(row=0, col=1) : 1}\r\n            self.mapping[key] = global_rank\r\n\r\n    def get_rank(self, **coord_kwargs):\r\n        \"\"\"Return the global rank of a process via its coordinates.\r\n\r\n        Coordinates are specified as kwargs. For example:\r\n\r\n            >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])\r\n            >>> X.get_rank(x=0, y=1)\r\n            1\r\n        \"\"\"\r\n        if len(coord_kwargs) != len(self.axes):\r\n            raise ValueError('get_rank() does not support slices. Use filter_match())')\r\n\r\n        key = self.ProcessCoord(**coord_kwargs)\r\n        assert key in self.mapping, f'key {kwargs} invalid'\r\n        return self.mapping[key]\r\n\r\n    def get_axis_names(self):\r\n        \"\"\"Return a list of the axis names in the ordering of the topology. 
\"\"\"\r\n        return self.axes\r\n\r\n    def get_rank_repr(self,\r\n                      rank,\r\n                      omit_axes=['data',\r\n                                 'pipe'],\r\n                      inner_sep='_',\r\n                      outer_sep='-'):\r\n        \"\"\"Return a string representation of a rank.\r\n\r\n        This method is primarily used for checkpointing model data.\r\n\r\n        For example:\r\n            >>> topo = Topo(axes=['a', 'b'], dims=[2, 2])\r\n            >>> topo.get_rank_repr(rank=3)\r\n            'a_01-b_01'\r\n            >>> topo.get_rank_repr(rank=3, omit_axes=['a'])\r\n            'b_01'\r\n\r\n        Args:\r\n            rank (int): A rank in the topology.\r\n            omit_axes (list, optional): Axes that should not be in the representation. Defaults to ['data', 'pipe'].\r\n            inner_sep (str, optional): [description]. Defaults to '_'.\r\n            outer_sep (str, optional): [description]. Defaults to '-'.\r\n\r\n        Returns:\r\n            str: A string representation of the coordinate owned by ``rank``.\r\n        \"\"\"\r\n        omit_axes = frozenset(omit_axes)\r\n        axes = [a for a in self.get_axis_names() if a not in omit_axes]\r\n        names = []\r\n        for ax in axes:\r\n            ax_rank = getattr(self.get_coord(rank=rank), ax)\r\n            names.append(f'{ax}{inner_sep}{ax_rank:02d}')\r\n        return outer_sep.join(names)\r\n\r\n    def get_dim(self, axis):\r\n        \"\"\"Return the number of processes along the given axis.\r\n\r\n        For example:\r\n            >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])\r\n            >>> X.get_dim('y')\r\n            3\r\n        \"\"\"\r\n        if axis not in self.axes:\r\n            return 0\r\n        return self.dims[self.axes.index(axis)]\r\n\r\n    def get_coord(self, rank):\r\n        \"\"\"Return the coordinate owned by a process rank.\r\n\r\n        The axes of the returned namedtuple can be directly accessed as members. 
For\r\n        example:\r\n            >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])\r\n            >>> coord = X.get_coord(rank=1)\r\n            >>> coord.x\r\n            0\r\n            >>> coord.y\r\n            1\r\n        \"\"\"\r\n        for coord, idx in self.mapping.items():\r\n            if idx == rank:\r\n                return coord\r\n        raise ValueError(f'rank {rank} not found in topology.')\r\n\r\n    def get_axis_comm_lists(self, axis):\r\n        \"\"\" Construct lists suitable for a communicator group along axis ``axis``.\r\n\r\n        Example:\r\n            >>> topo = ProcessTopology(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])\r\n            >>> topo.get_axis_comm_lists('pipe')\r\n            [\r\n                [0, 4], # data=0, model=0\r\n                [1, 5], # data=0, model=1\r\n                [2, 6], # data=1, model=0\r\n                [3, 7], # data=1, model=1\r\n            ]\r\n\r\n        Returns:\r\n            A list of lists whose coordinates match in all axes *except* ``axis``.\r\n        \"\"\"\r\n\r\n        # Return an empty list instead of raising an error; this lets us write more\r\n        # generalized code for hybrid parallelisms.\r\n        if axis not in self.axes:\r\n            return []\r\n\r\n        # Grab all axes but `axis`\r\n        other_axes = [a for a in self.axes if a != axis]\r\n\r\n        lists = []\r\n\r\n        # Construct all combinations of coords with other_axes\r\n        ranges = [range(self.get_dim(a)) for a in other_axes]\r\n        for coord in cartesian_product(*ranges):\r\n            other_keys = {a: coord[other_axes.index(a)] for a in other_axes}\r\n            # now go over all ranks in `axis`.\r\n            sub_list = []\r\n            for axis_key in range(self.get_dim(axis)):\r\n                key = self.ProcessCoord(**other_keys, **{axis: axis_key})\r\n                sub_list.append(self.mapping[key])\r\n            lists.append(sub_list)\r\n\r\n        return lists\r\n\r\n    def filter_match(self, **filter_kwargs):\r\n        \"\"\"Return the list of ranks whose coordinates match the provided criteria.\r\n\r\n        Example:\r\n            >>> X = ProcessTopology(axes=['pipe', 'data', 'model'], dims=[2, 2, 2])\r\n            >>> X.filter_match(pipe=0, data=1)\r\n            [2, 3]\r\n            >>> [X.get_coord(rank) for rank in X.filter_match(pipe=0, data=1)]\r\n            [ProcessCoord(pipe=0, data=1, model=0), ProcessCoord(pipe=0, data=1, model=1)]\r\n\r\n        Arguments:\r\n            **filter_kwargs (dict): criteria used to select coordinates.\r\n\r\n        Returns:\r\n            The list of ranks whose coordinates match filter_kwargs.\r\n        \"\"\"\r\n        def _filter_helper(x):\r\n            for key, val in filter_kwargs.items():\r\n                if getattr(x, key) != val:\r\n                    return False\r\n            return True\r\n\r\n        coords = filter(_filter_helper, self.mapping.keys())\r\n        return [self.mapping[coo] for coo in coords]\r\n\r\n    def get_axis_list(self, axis, idx):\r\n        \"\"\"Returns the list of global ranks whose coordinate in an axis is idx.\r\n\r\n        For example:\r\n            >>> X = ProcessTopology(axes=['x', 'y'], dims=[2,3])\r\n            >>> X.get_axis_list(axis='x', idx=0)\r\n            [0, 1, 2]\r\n            >>> X.get_axis_list(axis='y', idx=0)\r\n            [0, 3]\r\n        \"\"\"\r\n\r\n        # This could be faster by generating the desired keys directly instead of\r\n        # filtering.\r\n        
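# A direct-construction sketch (an untested illustration, equivalent in result):\r\n        #   ranges = [[idx] if a == axis else range(d) for a, d in zip(self.axes, self.dims)]\r\n        #   return [self.mapping[self.ProcessCoord(*c)] for c in cartesian_product(*ranges)]\r\n        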
axis_num = self.axes.index(axis)\r\n        ranks = [self.mapping[k] for k in self.mapping.keys() if k[axis_num] == idx]\r\n        return ranks\r\n\r\n    def world_size(self):\r\n        return len(self.mapping)\r\n\r\n    def __str__(self):\r\n        return str(self.mapping)\r\n\r\n\r\ndef _prime_factors(N):\r\n    \"\"\" Returns the prime factorization of positive integer N. \"\"\"\r\n    if N <= 0:\r\n        raise ValueError(\"Values must be strictly positive.\")\r\n\r\n    primes = []\r\n    while N != 1:\r\n        for candidate in range(2, N + 1):\r\n            if N % candidate == 0:\r\n                primes.append(candidate)\r\n                N //= candidate\r\n                break\r\n    return primes\r\n\r\n\r\nclass PipeDataParallelTopology(ProcessTopology):\r\n    \"\"\" A topology specialization for hybrid data and pipeline parallelism.\r\n\r\n        Uses data parallelism on the last dimension to encourage gradient\r\n        reductions to use high-bandwidth intra-node links and lower-volume\r\n        pipeline communications to use low-bandwidth inter-node links.\r\n    \"\"\"\r\n    def __init__(self, num_pp, num_dp):\r\n        super().__init__(axes=['pipe', 'data'], dims=[num_pp, num_dp])\r\n\r\n\r\nclass PipeModelDataParallelTopology(ProcessTopology):\r\n    \"\"\" A topology for hybrid pipeline, model, and data parallelism. \"\"\"\r\n    def __init__(self, num_dp, num_pp, num_mp):\r\n        super().__init__(axes=['pipe', 'data', 'model'], dims=[num_pp, num_dp, num_mp])\r\n\r\n\r\nclass PipelineParallelGrid:\r\n    \"\"\"Implements a grid object that stores the data parallel ranks\r\n    corresponding to each of the model parallel stages.\r\n\r\n    The grid object organizes the processes in a distributed pytorch job\r\n    into a 2D grid of stage_id and data_parallel_id.\r\n\r\n    self.stage_id and self.data_parallel_id store the stage id\r\n    and the data parallel id of the current process.\r\n\r\n    self.dp_group groups the processes by stage_id.\r\n    self.dp_group[i] is a list containing all process ranks whose\r\n    stage_id is i.\r\n\r\n    self.p2p_groups stores a list of tuples, where each tuple\r\n    stores the process ranks of adjacent stages for a given data_parallel_id.\r\n    For example, if num_stage is 5, then the tuple [7,8] represents stages [3, 4]\r\n    with data_parallel_id = 1. 
A stage wrap-around appears as a pair of non-adjacent ranks;\r\n    for example, the tuple [4,0] represents the wrap-around from stage 4 to stage 0\r\n    for data_parallel_id = 0, and similarly [9,5] represents the wrap-around from\r\n    stage 4 to stage 0 for data_parallel_id = 1.\r\n    \"\"\"\r\n    def __init__(self, topology=None, process_group=None):\r\n        # TODO use process_group if provided\r\n        self.global_rank = dist.get_rank()\r\n        self.world_size = dist.get_world_size()\r\n        if topology is not None:\r\n            log_dist(f'building PipelineParallelGrid with topology: {topology}', ranks=[-1], level=logging.DEBUG)\r\n            self._topo = topology\r\n        else:\r\n            num_pp = 1\r\n            num_dp = 1\r\n            for idx, prime in enumerate(_prime_factors(self.world_size)):\r\n                if idx % 2 == 0:\r\n                    num_pp *= prime\r\n                else:\r\n                    num_dp *= prime\r\n            self._topo = PipeDataParallelTopology(num_dp=num_dp, num_pp=num_pp)\r\n        self.data_parallel_size = max(self._topo.get_dim('data'), 1)\r\n        self.pipe_parallel_size = max(self._topo.get_dim('pipe'), 1)\r\n        self.model_parallel_size = max(self._topo.get_dim('model'), 1)\r\n        assert self._is_grid_valid(), \"Invalid Grid\"\r\n\r\n        self.stage_id = self.get_stage_id()\r\n        self.data_parallel_id = self.get_data_parallel_id()\r\n        self.model_parallel_id = self.get_model_parallel_id()\r\n        self.slice_parallel_src_id = self.get_src_parallel_src_id()\r\n        log_dist(f'stage_id: {self.stage_id}, slice_parallel_src_id: {self.slice_parallel_src_id}', ranks=[-1], level=logging.DEBUG)\r\n        # Create new ProcessGroups for all model parallelism. DeepSpeedLight uses these\r\n        # to detect overflow, etc.\r\n\r\n        self.ds_model_proc_group = None\r\n        self.ds_model_rank = -1\r\n        for dp in range(self.data_parallel_size):\r\n            ranks = sorted(self._topo.get_axis_list(axis='data', idx=dp))\r\n            proc_group = dist.new_group(ranks=ranks)\r\n\r\n            if self.global_rank in ranks:\r\n                log_dist(f'data_parallel_id: {self.data_parallel_id}, model_parallel_id: {self.model_parallel_id}, \\\r\n                    stage_id: {self.stage_id}, building ds model group: {ranks}', ranks=[-1], level=logging.DEBUG)\r\n                self.ds_model_proc_group = proc_group\r\n                self.ds_model_world_size = len(ranks)\r\n                self.ds_model_rank = ranks.index(self.global_rank)\r\n        assert self.ds_model_rank > -1\r\n        assert self.ds_model_proc_group is not None\r\n\r\n        # Create new ProcessGroup for gradient all-reduces - these are the data parallel groups\r\n        self.dp_group = []\r\n        self.dp_groups = self._topo.get_axis_comm_lists('data')\r\n        for g in self.dp_groups:\r\n            proc_group = dist.new_group(ranks=g)\r\n            if self.global_rank in g:\r\n                log_dist(f'data_parallel_id: {self.data_parallel_id}, model_parallel_id: {self.model_parallel_id}, \\\r\n                    stage_id: {self.stage_id}, building dp group: {g}', ranks=[-1], level=logging.DEBUG)\r\n                self.dp_group = g\r\n                self.dp_proc_group = proc_group\r\n\r\n        self.is_first_stage = (self.stage_id == 0)\r\n        
self.is_last_stage = (self.stage_id == (self.pipe_parallel_size - 1))\r\n\r\n        self.p2p_groups = self._build_p2p_groups()\r\n        self._build_grads_groups()\r\n        self._build_activation_groups()\r\n\r\n        # Create new ProcessGroup for pipeline collectives - these are pipe parallel groups\r\n        self.pp_group = []\r\n        self.pp_proc_group = None\r\n        self.pipe_groups = self._topo.get_axis_comm_lists('pipe')\r\n        for ranks in self.pipe_groups:\r\n            proc_group = dist.new_group(ranks=ranks)\r\n            if self.global_rank in ranks:\r\n                log_dist(f'data_parallel_id: {self.data_parallel_id}, model_parallel_id: {self.model_parallel_id},\\\r\n                    stage_id: {self.stage_id}, building pipeline group: {ranks}', \\\r\n                    ranks=[-1], level=logging.DEBUG)\r\n                self.pp_group = ranks\r\n                self.pp_proc_group = proc_group\r\n        assert self.pp_proc_group is not None\r\n\r\n        # Create new ProcessGroup for model (tensor-slicing) collectives\r\n\r\n        # Short circuit case without model parallelism.\r\n        # TODO: it would be nice if topology had bcast semantics to avoid this branching\r\n        # case?\r\n        if self.model_parallel_size == 1:\r\n            for rank in range(self.world_size):\r\n                group = dist.new_group(ranks=[rank])\r\n                if rank == self.global_rank:\r\n                    self.slice_group = [rank]\r\n                    self.slice_proc_group = group\r\n            return\r\n        else:\r\n            self.mp_group = []\r\n            self.model_groups = self._topo.get_axis_comm_lists('model')\r\n            for g in self.model_groups:\r\n                proc_group = dist.new_group(ranks=g)\r\n                if self.global_rank in g:\r\n                    log_dist(f'data_parallel_id: {self.data_parallel_id}, model_parallel_id: {self.model_parallel_id}, \\\r\n                        stage_id: {self.stage_id}, building slice group: {g}', ranks=[-1], level=logging.DEBUG)\r\n                    self.slice_group = g\r\n                    self.slice_proc_group = proc_group\r\n\r\n    def get_stage_id(self):\r\n        return self._topo.get_coord(rank=self.global_rank).pipe\r\n\r\n    def get_data_parallel_id(self):\r\n        return self._topo.get_coord(rank=self.global_rank).data\r\n\r\n    def get_model_parallel_id(self):\r\n        if 'model' in self._topo.get_axis_names():\r\n            return self._topo.get_coord(rank=self.global_rank).model\r\n        return 0\r\n\r\n    def get_src_parallel_src_id(self):\r\n        if 'model' not in self._topo.get_axis_names():\r\n            return 0\r\n        return self.stage_to_global(stage_id=self.stage_id,\r\n                                    data=self.data_parallel_id,\r\n                                    model=0)\r\n\r\n    def _build_p2p_groups(self):\r\n        \"\"\"Groups for sending and receiving activations and gradients across model\r\n        parallel stages.\r\n        \"\"\"\r\n        comm_lists = self._topo.get_axis_comm_lists('pipe')\r\n        log_dist(f'_build_p2p_groups data_parallel_id: {self.data_parallel_id}, \\\r\n           
 model_parallel_id: {self.model_parallel_id}, stage_id: {self.stage_id}, \\\r\n            comm_lists: {comm_lists}', ranks=[-1], level=logging.DEBUG)\r\n\r\n        p2p_lists = []\r\n        for rank in range(self.world_size):\r\n            for l in comm_lists:\r\n                assert len(l) == self.pipe_parallel_size\r\n                if rank in l:\r\n                    idx = l.index(rank)\r\n                    buddy_rank = l[(idx + 1) % self.pipe_parallel_size]\r\n                    p2p_lists.append([rank, buddy_rank])\r\n                    break  # next global rank\r\n        assert len(p2p_lists) == self.world_size\r\n        log_dist(f'data_parallel_id: {self.data_parallel_id}, model_parallel_id: \\\r\n            {self.model_parallel_id}, stage_id: {self.stage_id}, \\\r\n            p2p_lists: {p2p_lists}', ranks=[-1], level=logging.DEBUG)\r\n        return p2p_lists\r\n    \r\n    def _build_grads_groups(self):\r\n        self.send_grads_src_rank = -1\r\n        self.recv_grads_src_rank = -1\r\n\r\n        self.send_grads_group = []\r\n        self.recv_grads_group = []\r\n\r\n        self.send_grads_proc_group = None\r\n        self.recv_grads_proc_group = None\r\n        self.grads_proc_groups = []\r\n\r\n        for dp_id in range(self.data_parallel_size):\r\n            for stage in range(self.pipe_parallel_size):\r\n                next_stage = stage + 1\r\n                prev_stage = stage - 1\r\n\r\n                grads_group = []\r\n                grads_proc_group = None\r\n            \r\n                if prev_stage > -1:\r\n                    grads_src_rank = self._topo.filter_match(data=dp_id, pipe=stage, model=0)[0]\r\n                    prev_mp_group = self._topo.filter_match(data=dp_id, pipe=prev_stage)\r\n                    grads_group.append(grads_src_rank)\r\n                    grads_group.extend(prev_mp_group)\r\n                    grads_group.sort()\r\n                    # log_dist(f'_build_grads_groups stage: {stage}, grads_group: {grads_group}', ranks=[-1])\r\n                    grads_proc_group = dist.new_group(ranks=grads_group)\r\n                    self.grads_proc_groups.append(grads_proc_group)\r\n                    if stage == self.stage_id and self.data_parallel_id == dp_id:\r\n                        self.send_grads_src_rank = grads_src_rank\r\n                        self.send_grads_group = grads_group\r\n                        self.send_grads_proc_group = grads_proc_group\r\n                    \r\n                    elif stage == self.stage_id + 1 and self.data_parallel_id == dp_id:\r\n                        self.recv_grads_src_rank = grads_src_rank\r\n                        self.recv_grads_group = grads_group\r\n                        self.recv_grads_proc_group = grads_proc_group\r\n        log_dist(f'_build_grads_groups stage: {self.stage_id}, send_grads_src_rank : {self.send_grads_src_rank}, '\r\n                f'send_grads_group: {self.send_grads_group}, recv_grads_group: {self.recv_grads_group}', \\\r\n                ranks=[-1], level=logging.DEBUG)\r\n\r\n    def _build_activation_groups(self):\r\n        self.send_activation_src_rank = -1\r\n        self.recv_activation_src_rank = -1\r\n\r\n        self.send_activation_group = []\r\n        self.recv_activation_group = []\r\n\r\n        self.send_activation_proc_group = None\r\n        self.recv_activation_proc_group = None\r\n        self.activation_proc_groups = []\r\n\r\n        for dp_id in range(self.data_parallel_size):\r\n            for stage in 
range(self.pipe_parallel_size):\r\n                next_stage = stage + 1\r\n                prev_stage = stage - 1\r\n\r\n                activation_group = []\r\n                activation_proc_group = None\r\n\r\n                if next_stage < self.pipe_parallel_size:\r\n                    activation_src_rank = self._topo.filter_match(data=dp_id, pipe=stage, model=0)[0]\r\n                    next_mp_group = self._topo.filter_match(data=dp_id, pipe=next_stage)\r\n                    activation_group.append(activation_src_rank)\r\n                    activation_group.extend(next_mp_group)\r\n                    activation_group.sort()\r\n                    activation_proc_group = dist.new_group(ranks=activation_group)\r\n                    self.activation_proc_groups.append(activation_proc_group)\r\n                    if stage == self.stage_id and self.data_parallel_id == dp_id:\r\n                        self.send_activation_src_rank = activation_src_rank\r\n                        self.send_activation_group = activation_group\r\n                        self.send_activation_proc_group = activation_proc_group\r\n                    elif stage == self.stage_id - 1 and self.data_parallel_id == dp_id:\r\n                        self.recv_activation_src_rank = activation_src_rank\r\n                        self.recv_activation_group = activation_group\r\n                        self.recv_activation_proc_group = activation_proc_group\r\n        log_dist(f'_build_activation_groups stage: {self.stage_id}, send_activation_src_rank : '\\\r\n            f'{self.send_activation_src_rank}, send_activation_group: {self.send_activation_group}, '\\\r\n            f'recv_activation_group: {self.recv_activation_group}', ranks=[-1], level=logging.DEBUG)\r\n\r\n    def _is_grid_valid(self):\r\n        ranks = 1\r\n        for ax in self._topo.get_axis_names():\r\n            ranks *= self._topo.get_dim(ax)\r\n        return ranks == dist.get_world_size()\r\n\r\n    # Returns the global rank of the process with the provided stage id,\r\n    # which has the same data_parallel_id as the caller process\r\n    def stage_to_global(self, stage_id, **kwargs):\r\n        me = self._topo.get_coord(self.global_rank)\r\n        transform = me._replace(pipe=stage_id, **kwargs)._asdict()\r\n        return self._topo.get_rank(**transform)\r\n\r\n    # Returns the byteps rank of the process with the provided stage id\r\n    def stage_to_byteps(self, stage_id):\r\n        return self.pipe_parallel_size * self.data_parallel_id + stage_id\r\n\r\n    def topology(self):\r\n        return self._topo\r\n\r\n    # MPU functions for DeepSpeed integration\r\n    def get_global_rank(self):\r\n        return self.global_rank\r\n\r\n    def get_pipe_parallel_rank(self):\r\n        \"\"\" The stage of the pipeline this rank resides in. \"\"\"\r\n        return self.stage_id\r\n\r\n    def get_pipe_parallel_world_size(self):\r\n        \"\"\" The number of stages in the pipeline. \"\"\"\r\n        return self.pipe_parallel_size\r\n\r\n    def get_pipe_parallel_group(self):\r\n        \"\"\" The group of ranks within the same pipeline. \"\"\"\r\n        return self.pp_proc_group\r\n\r\n    def get_data_parallel_rank(self):\r\n        \"\"\" Which pipeline this rank resides in. \"\"\"\r\n        return self.data_parallel_id\r\n\r\n    def get_data_parallel_world_size(self):\r\n        \"\"\" The number of pipelines. 
\"\"\"\r\n        return self.data_parallel_size\r\n\r\n    def get_data_parallel_group(self):\r\n        \"\"\" The group of ranks within the same stage of all pipelines. \"\"\"\r\n        return self.dp_proc_group\r\n\r\n    # These are model parallel groups across all types of model parallelism.\r\n    # Deepspeed uses them to detect overflow, etc.\r\n    def get_model_parallel_rank(self):\r\n        return self.model_parallel_id\r\n\r\n    def get_model_parallel_world_size(self):\r\n        return self.model_parallel_size\r\n\r\n    def get_model_parallel_group(self):\r\n        return self.slice_proc_group\r\n\r\n    # For Megatron-style tensor slicing\r\n    def get_slice_parallel_rank(self):\r\n        return self.model_parallel_id\r\n\r\n    def get_slice_parallel_world_size(self):\r\n        return self.model_parallel_size\r\n\r\n    def get_slice_parallel_group(self):\r\n        return self.slice_proc_group\r\n\r\n    def get_slice_parallel_src_rank(self):\r\n        return self.slice_parallel_src_id\r\n"
  },
  {
    "path": "src/veGiantModel/initialize.py",
    "content": "# Copyright (c) 2021, ByteDance Inc.  All rights reserved.\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\nimport torch\nimport os\nimport random\nimport numpy as np\n\nfrom megatron.global_vars import set_global_variables\nfrom megatron import get_args, mpu, print_rank_0\nfrom .engine.topology import PipeModelDataParallelTopology, PipelineParallelGrid\nfrom .launcher.launch import launch_bps\nfrom deepspeed.utils import log_dist\nimport logging\n\n\n\ndef add_byte_giant_model_customize_args(parser):\n    import deepspeed\n    parser = deepspeed.add_config_arguments(parser)\n    group = parser.add_argument_group(title='bytedance')\n    group.add_argument('--cpu-optimizer', action='store_true',\n                       help='Run optimizer on CPU')\n    group.add_argument('--cpu_torch_adam', action='store_true',\n                       help='Use Torch Adam as optimizer on CPU.')\n    group.add_argument('--vocab-size', type=int, default=1000,\n                       help='vocab size.')\n    group.add_argument('--train-batch-size', type=int, default=0,\n                       help='global batch size')\n    group.add_argument('--train_micro_batch_size_per_gpu', type=int, default=0,\n                       help='Batch size per model instance (for deepspeed). '\n                       'Global batch size is local batch size times data '\n                       'parallel size.')\n    group.add_argument('--deepspeed-activation-checkpointing', action='store_true',\n                       help='deepspeed_activation_checkpointing.')\n    group.add_argument('--deepspeed-pipeline', action='store_true',\n                       help='enable pipeline parallelism via deepspeed.')\n    group.add_argument('--ci', action='store_true', help=\"run in CI environment\")\n    group.add_argument('--gradient_accumulation_steps', type=int, default=1,\n                        help=\"set gradient_accumulation_steps for deepspeed config\")\n    group.add_argument('--train_batch_size', type=int, default=0,\n                        help=\"train_batch_size\")\n    group.add_argument('--broadcast_activation', action='store_true', help=\"use broadcast to send/recv activation\")\n    group.add_argument('--broadcast_grads', action='store_true', help=\"use broadcast to send/recv grads\")\n    group.add_argument('--partition_method', type=str, default='uniform',\n                       help='the method to partition layers in pipeline parallelism.')\n    group.add_argument('--config_param', type=str, default='',\n                       help='json dict for deepspeed config')\n\n    group.add_argument('--num-stages', type=int, default=1,\n                       help='number of stages')\n    return parser\n\ndef initialize_megatron(extra_args_provider=None, args_defaults={}):\n    set_global_variables(extra_args_provider=add_byte_giant_model_customize_args, args_defaults=args_defaults)\n    args = get_args()\n    init_distribute(args.num_stages, args.model_parallel_size)\n    _set_random_seed(args.seed)\n\ndef _init_topology(num_stages, mp_size):\n    num_pp = num_stages\n    num_mp = mp_size\n    num_dp = (torch.distributed.get_world_size() // num_pp) // num_mp\n    log_dist('rank: {args.rank}, init topology with num_pp:{num_pp}, num_mp:{num_mp}, \\\n        num_dp: {num_dp}', ranks=[-1], level=logging.DEBUG)\n    topology = PipeModelDataParallelTopology(num_pp=num_pp, num_mp=num_mp, num_dp=num_dp)\n    log_dist(f'finish building topology, topology.mapping: {topology.mapping}', \\\n        ranks=[-1], 
level=logging.DEBUG)\n    return PipelineParallelGrid(topology)\n\ndef _set_random_seed(seed):\n    \"\"\"Set random seed for reproducibility.\"\"\"\n    if seed is not None and seed > 0:\n        random.seed(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        if torch.cuda.device_count() > 0:\n            mpu.model_parallel_cuda_manual_seed(seed)\n    else:\n        raise ValueError('Seed ({}) should be a positive integer.'.format(seed))\n\ndef init_distribute(num_stages, mp_size,\n                    distributed_backend='nccl', init_method='tcp://'):\n    rank = int(os.getenv('RANK', '0'))\n    world_size = int(os.getenv(\"WORLD_SIZE\", '1'))\n    device_count = torch.cuda.device_count()\n    local_rank = rank % device_count\n\n    if torch.distributed.is_initialized():\n        print_rank_0('torch distributed is already initialized, '\n                'skipping initialization ...')\n    else:\n        print_rank_0('> initializing torch distributed ...')\n\n        torch.cuda.set_device(local_rank)\n        # Call the init process\n        master_ip = os.getenv('MASTER_ADDR', 'localhost')\n        master_port = os.getenv('MASTER_PORT', '6000')\n        init_method += master_ip + ':' + master_port\n        torch.distributed.init_process_group(\n            backend=distributed_backend,\n            world_size=world_size, rank=rank,\n            init_method=init_method)\n\n    # Set the model-parallel / data-parallel communicators.\n    grid = _init_topology(num_stages, mp_size)\n    mpu.initialize_model_parallel(grid)\n    if num_stages > 1:\n        import byteps.torch as bps\n        assert bps is not None\n        launch_bps(local_rank)\n"
  },
  {
    "path": "src/veGiantModel/launcher/launch.py",
    "content": "# Copyright (c) 2021, ByteDance Inc.  All rights reserved.\n#!/usr/bin/python\n\nfrom __future__ import print_function\nimport os\nimport subprocess\nimport threading\nimport sys\nfrom megatron import mpu\nfrom deepspeed.utils import log_dist\nimport logging\n\nclass PropagatingThread(threading.Thread):\n    \"\"\" propagate exceptions to the parent's thread\n    refer to https://stackoverflow.com/a/31614591/9601110\n    \"\"\"\n\n    def run(self):\n        self.exc = None\n        try:\n            if hasattr(self, '_Thread__target'):\n                #  python 2.x\n                self.ret = self._Thread__target(\n                    *self._Thread__args, **self._Thread__kwargs)\n            else:\n                # python 3.x\n                self.ret = self._target(*self._args, **self._kwargs)\n        except BaseException as e:\n            self.exc = e\n\n    def join(self):\n        super(PropagatingThread, self).join()\n        if self.exc:\n            raise self.exc\n        return self.exc\n\ndef launch_scheduler(local_rank):\n    if os.environ['WORKER_RANK'] != '0':\n        return\n\n    if local_rank != 0:\n        return\n\n\n    def scheduler_runner():\n        my_env = os.environ.copy()\n        my_env['DMLC_ROLE'] = 'scheduler'\n        my_env['PS_VERBOSE'] = os.environ.get('PS_VERBOSE', '1')\n        nvidia_smi = f'nvidia-smi -L'\n        devices = os.popen(nvidia_smi).read().strip()\n        if 'A100' in devices:\n            ip_cmd = f'ip addr show eth2'\n            ip = os.popen(ip_cmd + ' | grep \"\\<inet\\>\" | awk \\'{ print $2 }\\' | awk -F \"/\" \\'{ print $1 }\\'').read().strip()\n            my_env['DMLC_NODE_HOST'] = ip\n            my_env['UCX_RDMA_CM_SOURCE_ADDRESS'] = ip\n            os.environ['UCX_NET_DEVICES'] = 'mlx5_2:1,eth0,eth1,eth2,eth3'\n\n        command = \"python3 -c 'import byteps.server'\"\n        subprocess.check_call(command, env=my_env,\n                          stdout=sys.stdout, stderr=sys.stderr, shell=True)\n    t = PropagatingThread(target=scheduler_runner)\n    t.setDaemon(True)\n    t.start()\n\ndef get_worker0_host():\n    host = os.environ['WORKER_0_HOST']\n    return host\n\ndef get_worker0_port():\n    port = os.environ['WORKER_0_PORT']\n    return port\n\ndef setup_env(local_rank):\n    mp_size = mpu.get_model_parallel_world_size()\n\n    num_nodes = int(os.environ['NUM_WORKER'])\n    gpu_per_node = int(os.environ['GPU_PER_WORKER'])\n    assert gpu_per_node >= mp_size\n    assert gpu_per_node % mp_size == 0\n\n    os.environ['BYTEPS_RDMA_START_DEPTH'] = str(32)\n    os.environ['BYTEPS_RDMA_RX_DEPTH'] = str(512)\n\n    os.environ['DMLC_NUM_WORKER'] = str(gpu_per_node * num_nodes)\n    os.environ['DMLC_NUM_SERVER'] = str(gpu_per_node * num_nodes)\n\n    os.environ['BYTEPS_LOCAL_SIZE'] = str(gpu_per_node)\n    os.environ['BYTEPS_FORCE_DISTRIBUTED'] = '1'\n    os.environ['BYTEPS_ENABLE_IPC'] = '0'\n    os.environ['DMLC_PS_ROOT_PORT'] = get_worker0_port()\n    os.environ['DMLC_PS_ROOT_URI'] = get_worker0_host()\n\n    if 'DMLC_ENABLE_RDMA' not in os.environ:\n        os.environ['DMLC_ENABLE_RDMA'] = '1'\n    os.environ['DMLC_ENABLE_UCX'] = os.environ.get('DMLC_ENABLE_UCX', '1')\n    os.environ['UCX_IB_TRAFFIC_CLASS'] = '236'\n    os.environ['UCX_TLS'] = os.environ.get('UCX_TLS', 'rc_x,tcp,sm')\n    nvidia_smi = f'nvidia-smi -L'\n    devices = os.popen(nvidia_smi).read().strip()\n    if 'A100' in devices:\n        nic = 2 # TODO: use multiple NICs with `int(local_rank / 2)`\n        ip_cmd = f'ip addr show 
eth{nic}'\n        ip = os.popen(ip_cmd + ' | grep \"\\<inet\\>\" | awk \\'{ print $2 }\\' | awk -F \"/\" \\'{ print $1 }\\'').read().strip()\n        os.environ['UCX_RDMA_CM_SOURCE_ADDRESS'] = os.environ.get('UCX_RDMA_CM_SOURCE_ADDRESS', ip)\n        devs = os.environ.get('UCX_NET_DEVICES', f'mlx5_{nic}:1,eth0,eth1,eth2,eth3')\n        os.environ['UCX_NET_DEVICES'] = devs\n        os.environ['DMLC_NODE_HOST'] = os.environ['UCX_RDMA_CM_SOURCE_ADDRESS']\n    elif 'V100' in devices or 'T4' in devices:\n        devs = os.environ.get('UCX_NET_DEVICES', 'mlx5_2:1,eth0,eth2')\n        os.environ['UCX_NET_DEVICES'] = devs\n    else:\n        raise RuntimeError(f\"Unknown devices: {devices}\")\n\ndef launch_bps(local_rank):\n    log_dist(f'launch_bps({local_rank})', ranks=[-1], level=logging.DEBUG)\n    setup_env(local_rank)\n    launch_scheduler(local_rank)"
  },
  {
    "path": "src/veGiantModel/module/__init__.py",
    "content": "# Copyright (c) 2021, ByteDance Inc.  All rights reserved.\nfrom .dense import ColumnSerialLinear, ColumnParallelLinear\nfrom .dense import RowSerialLinear, RowParallelLinear, MockModule\nfrom .dense import ColumnParallelLinearTranspose, ColumnSerialLinearTranspose\n\n__all__ = ['ColumnSerialLinear',\n           'ColumnParallelLinear',\n           'ColumnParallelLinearTranspose',\n           'ColumnSerialLinearTranspose',\n           'RowSerialLinear',\n           'RowParallelLinear',\n           'MockModule']\n"
  },
  {
    "path": "src/veGiantModel/module/dense.py",
    "content": "# Copyright (c) 2021, ByteDance Inc.  All rights reserved.\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nimport torch.nn as nn\nimport torch.autograd as autograd\n\n# try:\n#     import veGiantModel\n# except ImportError:\n#     byteGiantModel = None\n\nclass MockModule(nn.Module):\n    \"\"\"Module for testing model parallelism\"\"\"\n    pass\n\ntry:\n    from th_fastertransformer import Linear\n\n    class LinearFunction(autograd.Function):\n\n        @staticmethod\n        def forward(ctx, input_tensor, weight, bias, act_gelu=False, dropout_rate=0.0):\n            bias_out = torch.Tensor(0)\n            dropout_mask = torch.Tensor(0)\n            if act_gelu == True or dropout_rate > 0.0:\n                output, bias_out, dropout_mask = Linear.forward_gelu_dropout(input_tensor, weight, bias, act_gelu, dropout_rate)\n            else:\n                output = Linear.forward(input_tensor, weight, bias)\n            ctx.save_for_backward(input_tensor, weight, bias_out, dropout_mask)\n            ctx.act_gelu = act_gelu\n            ctx.dropout_rate = dropout_rate\n            return output\n\n        @staticmethod\n        def backward(ctx, grad_out):\n            act_gelu = ctx.act_gelu\n            dropout_rate = ctx.dropout_rate\n            input_tensor, weight, bias_out, dropout_mask = ctx.saved_tensors\n            if act_gelu == True or dropout_rate > 0.0:\n                grad_in, grad_weight, grad_bias = Linear.backward_gelu_dropout(\n                    grad_out, input_tensor, weight, act_gelu, dropout_rate, bias_out, dropout_mask)\n            else:\n                grad_in, grad_weight, grad_bias = Linear.backward(\n                    grad_out, input_tensor, weight)\n            return grad_in, grad_weight, grad_bias, None, None\n\n    class FTLinear(nn.Module):\n        def __init__(self, in_features, out_features, initializer_range=0.02, act_gelu=False, dropout_rate=0.0):\n            super().__init__()\n\n            self.in_features = in_features\n            self.out_features = out_features\n            self.weight = nn.Parameter(torch.Tensor(out_features, in_features))\n            self.bias = nn.Parameter(torch.Tensor(out_features))\n            self.act_gelu = act_gelu\n            self.dropout_rate = dropout_rate\n\n            self.weight.data.normal_(mean=0.0, std=initializer_range)\n            self.bias.data.zero_()\n\n        def forward(self, input_tensor):\n            return LinearFunction.apply(input_tensor, self.weight, self.bias, self.act_gelu, self.dropout_rate if self.training else 0.)\n\n        def extra_repr(self):\n            return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)\n\nexcept Exception as e:\n    FTLinear = None\n\ntry:\n    from th_fastertransformer import LinearTranspose\n\n    class LinearTransposeFunction(autograd.Function):\n        @staticmethod\n        def forward(ctx, input_tensor, weight, 
bias, head_num, transpose_type):\n            output = LinearTranspose.forward(input_tensor, weight, bias, head_num, transpose_type)\n            ctx.head_num = head_num\n            ctx.transpose_type = transpose_type\n            ctx.save_for_backward(input_tensor, weight)\n            return output\n\n        @staticmethod\n        def backward(ctx, grad_out):\n            input_tensor, weight = ctx.saved_tensors\n            grad_in, grad_weight, grad_bias = LinearTranspose.backward(grad_out, input_tensor, weight, ctx.head_num, ctx.transpose_type)\n            return grad_in, grad_weight, grad_bias, None, None\n\n    class FTLinearTranspose(nn.Module):\n        def __init__(self, in_features, out_features, head_num, transpose_type=\"0213\", initializer_range=0.02):\n            super().__init__()\n\n            self.in_features = in_features\n            self.out_features = out_features\n            self.head_num = head_num\n            self.transpose_type = transpose_type\n            self.weight = nn.Parameter(torch.Tensor(out_features, in_features))\n            self.bias = nn.Parameter(torch.Tensor(out_features))\n\n            self.weight.data.normal_(mean=0.0, std=initializer_range)\n            self.bias.data.zero_()\n\n        def forward(self, input_tensor):\n            return LinearTransposeFunction.apply(input_tensor, self.weight, self.bias, self.head_num, self.transpose_type)\n\n        def extra_repr(self):\n            return 'in_features={}, out_features={}, head_num={}'.format(self.in_features, self.out_features, self.head_num)\n\nexcept Exception:\n    FTLinearTranspose = None\n    FTDAGather = None\n\ndef column_parallel_load_hook(module, log_fn):\n    \"\"\"Hook for column parallel linear's load_state_dict function.\n    It is a helper function to load the checkpoint from a\n    non-model-parallel module. 
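For example (with illustrative sizes): given 2 model-parallel ranks and a full\n    (1024, 512) 'weight', each rank keeps a (512, 512) row slice, with rank r\n    taking rows [r*512, (r+1)*512).\n    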
It returns a hook function that\n    pre-processes the checkpoint into parallel slices so that\n    each model parallel rank can load its corresponding slice.\n\n    Arguments:\n        module: ColumnParallelLinear or ColumnParallelLinearTranspose\n\n        log_fn: function for logging\n\n    Returns:\n        A hook function to help load model parallel modules from non-\n        model-parallel checkpoints.\n    \"\"\"\n    assert module.mp_rank is not None\n    assert module.out_features is not None\n    def hook(state_dict, prefix, local_metadata, strict, missing_keys,\n             unexpected_keys, error_msgs):\n        weight_name = prefix + 'weight'\n        bias_name = prefix + 'bias'\n        if weight_name in state_dict:\n            v = state_dict[weight_name]\n            assert len(v.shape) == 2, v.shape\n            idx_begin = module.mp_rank * module.out_features\n            idx_end = (module.mp_rank + 1) * module.out_features\n            shard = v[idx_begin:idx_end, :]\n            state_dict[weight_name] = shard\n            log_fn(f\"slice param {weight_name}\\tfor model parallelism: {v.shape} -> {shard.shape}\")\n        if bias_name in state_dict:\n            v = state_dict[bias_name]\n            assert len(v.shape) == 1, v.shape\n            idx_begin = module.mp_rank * module.out_features\n            idx_end = (module.mp_rank + 1) * module.out_features\n            shard = v[idx_begin:idx_end]\n            state_dict[bias_name] = shard\n            log_fn(f\"slice param {bias_name}\\tfor model parallelism: {v.shape} -> {shard.shape}\")\n    return hook\n\ndef column_serial_load_hook(module, log_fn):\n    \"\"\"Hook for column serial linear's load_state_dict function.\n    It is a helper function to load the checkpoint from a\n    non-model-parallel module. 
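For example (with illustrative sizes): given model_parallel_size 2 and a full\n    (1024, 512) 'weight', the hook replaces it with 'weight.0' and 'weight.1',\n    each holding a (512, 512) row slice.\n    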
It returns a hook function that\n    pre-processes the checkpoint into parallel slices so that\n    each model parallel rank can load its corresponding slice.\n\n    Arguments:\n        module: ColumnSerialLinear or ColumnSerialLinearTranspose\n\n        log_fn: function for logging\n\n    Returns:\n        A hook function to help load model serial modules from non-\n        model-parallel checkpoints.\n    \"\"\"\n    assert module.model_parallel_size is not None\n    assert module.out_features is not None\n    def hook(state_dict, prefix, local_metadata, strict, missing_keys,\n             unexpected_keys, error_msgs):\n        weight_name = prefix + 'weight'\n        bias_name = prefix + 'bias'\n        if weight_name in state_dict:\n            v = state_dict[weight_name]\n            assert len(v.shape) == 2, v.shape\n            for i in range(module.model_parallel_size):\n                weight_name_i = weight_name + \".\" + str(i)\n                idx_begin = i * module.out_features\n                idx_end = (i + 1) * module.out_features\n                shard = v[idx_begin:idx_end, :]\n                state_dict[weight_name_i] = shard\n                log_fn(f\"slice param {weight_name_i}\\tfor model parallelism: {v.shape} -> {shard.shape}\")\n            del state_dict[weight_name]\n        if bias_name in state_dict:\n            v = state_dict[bias_name]\n            assert len(v.shape) == 1, v.shape\n            for i in range(module.model_parallel_size):\n                bias_name_i = bias_name + \".\" + str(i)\n                idx_begin = i * module.out_features\n                idx_end = (i + 1) * module.out_features\n                shard = v[idx_begin:idx_end]\n                state_dict[bias_name_i] = shard\n                log_fn(f\"slice param {bias_name_i}\\tfor model parallelism: {v.shape} -> {shard.shape}\")\n            del state_dict[bias_name]\n    return hook\n\nclass ColumnSerialLinear(MockModule):\n    def __init__(self, in_features, out_features, initializer_range=0.02,\n                 act_gelu=False, dropout_rate=0.0, load_from_shards=False, use_ft=False):\n        \"\"\"\n        A serial module that mocks the ColumnParallelLinear module. 
It mimics the parallel\n        logic by performing the per-shard computations serially on the same rank,\n        reducing the result if needed.\n        \"\"\"\n        super().__init__()\n        import veGiantModel\n        model_parallel_size = veGiantModel.distributed.get_model_parallel_world_size()\n        self.model_parallel_size = model_parallel_size\n        self.in_features = in_features\n        self.out_features = out_features // model_parallel_size\n        assert out_features % model_parallel_size == 0, (out_features, model_parallel_size)\n        weight_params = [nn.Parameter(torch.Tensor(self.out_features, self.in_features)) for _ in range(model_parallel_size)]\n        self.weight = nn.ParameterList(weight_params)\n        bias_params = [nn.Parameter(torch.Tensor(self.out_features)) for _ in range(model_parallel_size)]\n        self.bias = nn.ParameterList(bias_params)\n        self.act_gelu = act_gelu\n        self.dropout_rate = dropout_rate\n        for weight in self.weight:\n            weight.data.normal_(mean=0.0, std=initializer_range)\n        for bias in self.bias:\n            bias.data.zero_()\n        self.use_ft = use_ft\n        if not use_ft:\n            assert not act_gelu\n            assert not dropout_rate, dropout_rate\n        if not load_from_shards:\n            load_hook = column_serial_load_hook(self, print)\n            self._register_load_state_dict_pre_hook(load_hook)\n\n    def forward(self, input_tensor):\n        outputs = []\n        for i in range(self.model_parallel_size):\n            if self.use_ft:\n                output_i = LinearFunction.apply(input_tensor, self.weight[i], self.bias[i], self.act_gelu,\n                                                self.dropout_rate if self.training else 0.)\n            else:\n                output_i = nn.functional.linear(input_tensor, self.weight[i], self.bias[i])\n            outputs.append(output_i)\n        output = torch.cat(outputs, dim=-1)\n        return output\n\n    def extra_repr(self):\n        return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)\n\nclass ColumnParallelLinear(nn.Module):\n    def __init__(self, in_features, out_features, initializer_range=0.02,\n                 act_gelu=False, dropout_rate=0.0, load_from_shards=False, use_ft=False,\n                 bias=True, gather_output=False):\n        \"\"\"Linear layer with column parallelism.\n\n        The linear layer is defined as Y = dropout(gelu(XA + b)). A is parallelized along\n        its second dimension as A = [A_1, ..., A_p].\n\n        Arguments:\n            in_features: first dimension of matrix A.\n            out_features: second dimension of matrix A.\n            initializer_range: range for weight initialization. Note that bias is always set\n                        to zero.\n            act_gelu: If true, apply gelu activation to (XA+b)\n            dropout_rate: If greater than zero, apply dropout to gelu(XA+b)\n            load_from_shards: If true, load the states from sharded checkpoints. 
Otherwise,\n                        the module automatically slices the checkpoint tensor based on its\n                        model parallel rank.\n            use_ft: use faster transformer for acceleration.\n            bias: If true, add bias.\n            gather_output: If true, call all-gather on the output and make Y available\n                        to all GPUs; otherwise, every GPU will have only its own\n                        output Y_i = XA_i.\n        \"\"\"\n        super().__init__()\n        import veGiantModel\n        model_parallel_size = veGiantModel.distributed.get_model_parallel_world_size()\n        self.in_features = in_features\n        self.out_features = out_features // model_parallel_size\n        assert out_features % model_parallel_size == 0, (out_features, model_parallel_size)\n        self.weight = nn.Parameter(torch.Tensor(self.out_features, self.in_features))\n        self.weight.data.normal_(mean=0.0, std=initializer_range)\n        if bias:\n            self.bias = nn.Parameter(torch.Tensor(self.out_features))\n            self.bias.data.zero_()\n        else:\n            self.bias = None\n            assert not use_ft\n        self.gather_output = gather_output\n        self.act_gelu = act_gelu\n        self.dropout_rate = dropout_rate\n        self.use_ft = use_ft\n        self.mp_rank = veGiantModel.distributed.get_model_parallel_rank()\n        if not use_ft:\n            assert not act_gelu\n            assert not dropout_rate, dropout_rate\n        if not load_from_shards:\n            load_hook = column_parallel_load_hook(self, print)\n            self._register_load_state_dict_pre_hook(load_hook)\n\n    def forward(self, input_tensor):\n        import veGiantModel\n        input_tensor = veGiantModel.distributed.copy_to_model_parallel_region(input_tensor)\n        if self.use_ft:\n            output = LinearFunction.apply(input_tensor, self.weight, self.bias, self.act_gelu,\n                                            self.dropout_rate if self.training else 0.)\n        else:\n            output = nn.functional.linear(input_tensor, self.weight, self.bias)\n        if self.gather_output:\n            output = veGiantModel.distributed.gather_from_model_parallel_region(output)\n        return output\n\n    def extra_repr(self):\n        return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)\n\nclass RowSerialLinear(MockModule):\n    def __init__(self, in_features, out_features, initializer_range=0.02, dropout_rate=0.0,\n                 load_from_shards=False, use_ft=False):\n        \"\"\"\n        A serial module that mocks the RowParallelLinear module. 
It mimics the parallel\n        logic by performing the per-shard computations serially on the same rank.\n        \"\"\"\n        super().__init__()\n        import veGiantModel\n        model_parallel_size = veGiantModel.distributed.get_model_parallel_world_size()\n        self.model_parallel_size = model_parallel_size\n        self.in_features = in_features // model_parallel_size\n        self.out_features = out_features\n        assert in_features % model_parallel_size == 0, (in_features, model_parallel_size)\n        weight_params = [nn.Parameter(torch.Tensor(self.out_features, self.in_features)) for _ in range(model_parallel_size)]\n        self.weight = nn.ParameterList(weight_params)\n        self.bias = nn.Parameter(torch.Tensor(self.out_features))\n        self.dropout_rate = dropout_rate\n        for weight in self.weight:\n            weight.data.normal_(mean=0.0, std=initializer_range)\n        self.bias.data.zero_()\n        self.dropout = nn.Dropout(dropout_rate)\n        self.use_ft = use_ft\n        self.mp_rank = veGiantModel.distributed.get_model_parallel_rank()\n        if not load_from_shards:\n            def load_hook(state_dict, prefix, local_metadata, strict, missing_keys,\n                          unexpected_keys, error_msgs):\n                weight_name = prefix + 'weight'\n                if weight_name in state_dict:\n                    v = state_dict[weight_name]\n                    assert len(v.shape) == 2, v.shape\n                    for i in range(model_parallel_size):\n                        weight_name_i = weight_name + '.' + str(i)\n                        idx_begin = i * self.in_features\n                        idx_end = (i + 1) * self.in_features\n                        shard = v[:, idx_begin:idx_end]\n                        state_dict[weight_name_i] = shard\n                        print(f\"slice param {weight_name_i}\\tfor model parallelism: {v.shape} -> {shard.shape}\")\n                    del state_dict[weight_name]\n            self._register_load_state_dict_pre_hook(load_hook)\n\n    def forward(self, input_tensor):\n        input_tensors = torch.split(input_tensor, self.in_features, dim=-1)\n        outputs = []\n        for i in range(self.model_parallel_size):\n            if self.use_ft:\n                output_i = LinearFunction.apply(input_tensors[i].contiguous(), self.weight[i], self.bias, False, 0.)\n            else:\n                output_i = nn.functional.linear(input_tensors[i].contiguous(), self.weight[i], self.bias)\n            outputs.append(output_i)\n        output = outputs[0]\n        for i in range(self.model_parallel_size - 1):\n            output = output + outputs[i + 1]\n        if self.dropout_rate:\n            output = self.dropout(output)\n        return output\n\n    def extra_repr(self):\n        return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)\n\nclass RowParallelLinear(nn.Module):\n    def __init__(self, in_features, out_features, initializer_range=0.02, dropout_rate=0.0,\n                 load_from_shards=False, use_ft=False):\n        \"\"\"Linear layer with row parallelism.\n\n        The linear layer is defined as Y = XA + b. A is parallelized along\n        its first dimension and X along its second dimension as:\n                -   -\n                | A_1 |\n                | .   |\n            A = | .   |        X = [X_1, ..., X_p]\n                | .   
|\n                | A_p |\n                -   -\n\n        Arguments:\n            in_features: first dimension of matrix A.\n            out_features: second dimension of matrix A.\n            initializer_range: range for weight initialization. Note that bias is always set\n                        to zero.\n            dropout_rate: If greater than zero, apply dropout to XA+b\n            load_from_shards: If true, load the states from sharded checkpoints. Otherwise,\n                        the module automatically slices the checkpoint tensor based on its\n                        model parallel rank.\n            use_ft: use faster transformer for acceleration.\n        \"\"\"\n        super().__init__()\n        import veGiantModel\n        model_parallel_size = veGiantModel.distributed.get_model_parallel_world_size()\n        self.in_features = in_features // model_parallel_size\n        self.out_features = out_features\n        assert in_features % model_parallel_size == 0, (in_features, model_parallel_size)\n        self.weight = nn.Parameter(torch.Tensor(self.out_features, self.in_features))\n        self.bias = nn.Parameter(torch.Tensor(self.out_features))\n        self.dropout_rate = dropout_rate\n        self.weight.data.normal_(mean=0.0, std=initializer_range)\n        self.bias.data.zero_()\n        self.dropout = nn.Dropout(dropout_rate)\n        self.use_ft = use_ft\n        self.mp_rank = veGiantModel.distributed.get_model_parallel_rank()\n        if not load_from_shards:\n            def load_hook(state_dict, prefix, local_metadata, strict, missing_keys,\n                            unexpected_keys, error_msgs):\n                weight_name = prefix + 'weight'\n                if weight_name in state_dict:\n                    v = state_dict[weight_name]\n                    assert len(v.shape) == 2, v.shape\n                    idx_begin = self.mp_rank * self.in_features\n                    idx_end = (self.mp_rank + 1) * self.in_features\n                    shard = v[:, idx_begin:idx_end]\n                    state_dict[weight_name] = shard\n                    print(f\"slice param {weight_name}\\tfor model parallelism: {v.shape} -> {shard.shape}\")\n            self._register_load_state_dict_pre_hook(load_hook)\n\n    def forward(self, input_tensor):\n        if self.use_ft:\n            output = LinearFunction.apply(input_tensor, self.weight, self.bias, False, 0.)\n        else:\n            output = nn.functional.linear(input_tensor, self.weight, self.bias)\n        import veGiantModel\n        output = veGiantModel.distributed.reduce_from_model_parallel_region(output)\n\n        if self.dropout_rate:\n            output = self.dropout(output)\n        return output\n\n    def extra_repr(self):\n        return 'in_features={}, out_features={}'.format(self.in_features, self.out_features)\n\n\nclass ColumnParallelLinearTranspose(nn.Module):\n    def __init__(self, in_features, out_features, head_num, transpose_type=\"0213\", initializer_range=0.02,\n                 use_ft=False, load_from_shards=False):\n        \"\"\"Linear layer with column parallelism. The output is then reshaped to 4D with\n        (dim0, dim1, head_num, out_features / head_num), then permuted with axes provided by transpose_type.\n        For equivalent computation, check the implementation of `ColumnSerialLinearTranspose`.\n\n        The linear layer is defined as Y = XA + b. 
A is parallelized along\n        its second dimension as A = [A_1, ..., A_p].\n\n        Arguments:\n            in_features: first dimension of matrix A.\n            out_features: second dimension of matrix A.\n            head_num: number of \"heads\" for the out_feature dimension.\n            transpose_type: the axes for permutation on the output.\n            initializer_range: range for weight initialization. Note that bias is always set\n                        to zero.\n            use_ft: use faster transformer for acceleration.\n            load_from_shards: If true, load the states from sharded checkpoints. Otherwise,\n                        the module automatically slices the checkpoint tensor based on its\n                        model parallel rank.\n        \"\"\"\n        super().__init__()\n        self.use_ft = use_ft\n        self.in_features = in_features\n        import veGiantModel\n        model_parallel_size = veGiantModel.distributed.get_model_parallel_world_size()\n        self.mp_rank = veGiantModel.distributed.get_model_parallel_rank()\n\n        assert out_features % model_parallel_size == 0, (out_features, model_parallel_size)\n        self.out_features = out_features // model_parallel_size\n        assert head_num % model_parallel_size == 0, (head_num, model_parallel_size)\n        self.head_num = head_num // model_parallel_size\n        self.head_dim = self.out_features // self.head_num\n        self.transpose_type = transpose_type\n        self.weight = nn.Parameter(torch.Tensor(self.out_features, in_features))\n        self.bias = nn.Parameter(torch.Tensor(self.out_features))\n        self.weight.data.normal_(mean=0.0, std=initializer_range)\n        self.bias.data.zero_()\n        if not load_from_shards:\n            load_hook = column_parallel_load_hook(self, print)\n            self._register_load_state_dict_pre_hook(load_hook)\n\n    def forward(self, input_tensor):\n        import veGiantModel\n        input_tensor = veGiantModel.distributed.copy_to_model_parallel_region(input_tensor)\n        if self.use_ft:\n            output = LinearTransposeFunction.apply(input_tensor, self.weight, self.bias,\n                                                    self.head_num, self.transpose_type)\n        else:\n            assert self.transpose_type == \"0213\", self.transpose_type\n            linear_out = nn.functional.linear(input_tensor, self.weight, self.bias)\n            new_shape = linear_out.size()[:-1] + (self.head_num, self.head_dim)\n            linear_out = linear_out.view(*new_shape)\n            output = linear_out.permute(0, 2, 1, 3).contiguous()\n        return output\n\n    def extra_repr(self):\n        return 'in_features={}, out_features={}, head_num={}'.format(self.in_features, self.out_features, self.head_num)\n\nclass ColumnSerialLinearTranspose(MockModule):\n    def __init__(self, in_features, out_features, head_num, transpose_type=\"0213\", initializer_range=0.02,\n                    use_ft=False, load_from_shards=False):\n        \"\"\"\n        A serial module that mocks the ColumnParallelLinearTranspose module. 
It mimics the parallel\n        logic by performing the per-shard computations serially on the same rank.\n        \"\"\"\n        super().__init__()\n        self.use_ft = use_ft\n        self.in_features = in_features\n        import veGiantModel\n        model_parallel_size = veGiantModel.distributed.get_model_parallel_world_size()\n        self.model_parallel_size = model_parallel_size\n        self.mp_rank = veGiantModel.distributed.get_model_parallel_rank()\n        assert out_features % model_parallel_size == 0, (out_features, model_parallel_size)\n        self.out_features = out_features // model_parallel_size\n        assert head_num % model_parallel_size == 0, (head_num, model_parallel_size)\n        self.head_num = head_num // model_parallel_size\n        self.head_dim = self.out_features // self.head_num\n        self.transpose_type = transpose_type\n        weight_params = [nn.Parameter(torch.Tensor(self.out_features, self.in_features)) for _ in range(model_parallel_size)]\n        self.weight = nn.ParameterList(weight_params)\n        bias_params = [nn.Parameter(torch.Tensor(self.out_features)) for _ in range(model_parallel_size)]\n        self.bias = nn.ParameterList(bias_params)\n        for weight in self.weight:\n            weight.data.normal_(mean=0.0, std=initializer_range)\n        for bias in self.bias:\n            bias.data.zero_()\n\n        if not load_from_shards:\n            load_hook = column_serial_load_hook(self, print)\n            self._register_load_state_dict_pre_hook(load_hook)\n\n    def forward(self, input_tensor):\n        outputs = []\n        for i in range(self.model_parallel_size):\n            if self.use_ft:\n                output_i = LinearTransposeFunction.apply(input_tensor, self.weight[i], self.bias[i], self.head_num, self.transpose_type)\n            else:\n                assert self.transpose_type == \"0213\", self.transpose_type\n                linear_out = nn.functional.linear(input_tensor, self.weight[i], self.bias[i])\n                new_shape = linear_out.size()[:-1] + (self.head_num, self.head_dim)\n                linear_out = linear_out.view(*new_shape)\n                output_i = linear_out.permute(0, 2, 1, 3).contiguous()\n            outputs.append(output_i)\n        output = torch.cat(outputs, dim=1)\n        return output\n\n    def extra_repr(self):\n        return 'in_features={}, out_features={}, head_num={}'.format(self.in_features, self.out_features, self.head_num)"
  },
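  {
    "path": "examples/transpose_0213_sketch.py",
    "content": "# Hypothetical sketch, not a file from the original repo: it illustrates the\n# use_ft=False fallback path of ColumnParallelLinearTranspose on a single process,\n# i.e. a plain linear layer whose output is reshaped to (batch, head_num, seq_len,\n# head_dim) via the \"0213\" permute. All sizes below are made up for illustration;\n# no veGiantModel initialization or model parallelism is involved.\nimport torch\nimport torch.nn as nn\n\nbatch, seq_len, in_features = 2, 8, 16\nout_features, head_num = 32, 4\nhead_dim = out_features // head_num\n\nweight = torch.randn(out_features, in_features)\nbias = torch.zeros(out_features)\nx = torch.randn(batch, seq_len, in_features)\n\n# Same computation as the else-branch of ColumnParallelLinearTranspose.forward().\nlinear_out = nn.functional.linear(x, weight, bias)  # (batch, seq_len, out_features)\nnew_shape = linear_out.size()[:-1] + (head_num, head_dim)\noutput = linear_out.view(*new_shape).permute(0, 2, 1, 3).contiguous()\n\nassert output.shape == (batch, head_num, seq_len, head_dim)\nprint(output.shape)  # torch.Size([2, 4, 8, 4])\n"
  },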
  {
    "path": "src/veGiantModel/patcher.py",
    "content": "# Copyright (c) 2021, ByteDance Inc.  All rights reserved.\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n\nimport torch\nprint(\"Loading veGiantModel submodules ...\")\n\n_TOPOLOGY = None\n\ndef is_unitialized():\n    \"\"\"Useful for code segments that may be accessed with or without mpu initialization\"\"\"\n    return _TOPOLOGY is None\n\n\ndef initialize_model_parallel(grid):\n    # Get world size and rank. Ensure some consistencies.\n    assert torch.distributed.is_initialized()\n    global _TOPOLOGY\n    _TOPOLOGY = grid\n\n\ndef model_parallel_is_initialized():\n    \"\"\"Check if model and data parallel groups are initialized.\"\"\"\n    if _TOPOLOGY is None:\n        return False\n    return True\n\n\ndef get_model_parallel_group():\n    \"\"\"Get the parallel group the caller rank belongs to.\"\"\"\n    assert _TOPOLOGY is not None, \\\n        ' parallel group is not initialized'\n    return _TOPOLOGY.get_slice_parallel_group()\n\n\ndef get_data_parallel_group():\n    \"\"\"Get the data parallel group the caller rank belongs to.\"\"\"\n    assert _TOPOLOGY is not None, \\\n        'data parallel group is not initialized'\n    return _TOPOLOGY.get_data_parallel_group()\n\n\ndef set_model_parallel_world_size(world_size):\n    pass\n\n\ndef get_model_parallel_world_size():\n    \"\"\"Return world size for the model parallel group.\"\"\"\n    return _TOPOLOGY.get_slice_parallel_world_size()\n\n\ndef set_model_parallel_rank(rank):\n    pass\n\n\ndef get_model_parallel_rank():\n    \"\"\"Return my rank for the model parallel group.\"\"\"\n    return _TOPOLOGY.get_slice_parallel_rank()\n\n\ndef get_model_parallel_src_rank():\n    return _TOPOLOGY.get_slice_parallel_src_rank()\n\n\ndef get_data_parallel_world_size():\n    \"\"\"Return world size for the data parallel group.\"\"\"\n    return _TOPOLOGY.get_data_parallel_world_size()\n\n\ndef get_data_parallel_rank():\n    \"\"\"Return my rank for the data parallel group.\"\"\"\n    return _TOPOLOGY.get_data_parallel_rank()\n\ndef get_pipe_parallel_rank():\n    return _TOPOLOGY.get_pipe_parallel_rank()\n\ndef destroy_model_parallel():\n    \"\"\"Set the groups to none.\"\"\"\n    global _TOPOLOGY\n    _TOPOLOGY = None\n\ndef get_grid():\n    return _TOPOLOGY\n\ndef get_topo():\n    return _TOPOLOGY.topology()\n\nimport megatron.mpu.initialize as initialize\ninitialize.is_unitialized = is_unitialized\ninitialize.initialize_model_parallel = initialize_model_parallel\ninitialize.model_parallel_is_initialized = model_parallel_is_initialized\ninitialize.get_model_parallel_group = get_model_parallel_group\ninitialize.get_data_parallel_group = get_data_parallel_group\ninitialize.set_model_parallel_world_size = set_model_parallel_world_size\ninitialize.get_model_parallel_world_size = get_model_parallel_world_size\ninitialize.set_model_parallel_rank = set_model_parallel_rank\ninitialize.get_model_parallel_rank = get_model_parallel_rank\ninitialize.get_model_parallel_src_rank = get_model_parallel_src_rank\ninitialize.get_data_parallel_world_size = get_data_parallel_world_size\ninitialize.get_data_parallel_rank = get_data_parallel_rank\ninitialize.get_pipe_parallel_rank = get_pipe_parallel_rank\ninitialize.destroy_model_parallel = destroy_model_parallel\n\nfrom megatron import mpu\nfrom importlib import reload  \nreload(mpu.data)\nreload(mpu.mappings)\nreload(mpu.cross_entropy)\nmpu.get_pipe_parallel_rank = get_pipe_parallel_rank\nreload(mpu)\n\nfrom megatron.mpu import mappings\n\ndef _gather(input_):\n    
\"\"\"Gather tensors and concatinate along the last dimension.\"\"\"\n\n    world_size = get_model_parallel_world_size()\n    # Bypass the function if we are using only 1 GPU.\n    if world_size==1:\n        return input_\n\n    # Size and dimension.\n    last_dim = input_.dim() - 1\n    rank = get_model_parallel_rank()\n\n    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]\n    tensor_list[rank] = input_\n    group = get_model_parallel_group()\n    torch.distributed.all_gather(tensor_list, input_, group=group)\n\n    # Note: torch.cat already creates a contiguous tensor.\n    output = torch.cat(tensor_list, dim=last_dim).contiguous()\n\n    return output\n\nmappings._gather = _gather\n\nfrom megatron.tokenizer import tokenizer as token\nfrom megatron.tokenizer.tokenizer import _BertWordPieceTokenizer, _vocab_size_with_padding, _GPT2BPETokenizer\n\ndef build_tokenizer(args):\n    if args.vocab_file is None:\n        args.padded_vocab_size = _vocab_size_with_padding(args.vocab_size,\n                                                    args)\n        return None\n    \"\"\"Initialize tokenizer.\"\"\"\n    if args.rank == 0:\n        print('> building {} tokenizer ...'.format(args.tokenizer_type),\n              flush=True)\n\n    # Select and instantiate the tokenizer.\n    assert args.vocab_file is not None\n    if args.tokenizer_type == 'BertWordPieceLowerCase':\n        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,\n                                            lower_case=True)\n    elif args.tokenizer_type == 'BertWordPieceCase':\n        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,\n                                            lower_case=False)\n    elif args.tokenizer_type == 'GPT2BPETokenizer':\n        assert args.merge_file is not None\n        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)\n    else:\n        raise NotImplementedError('{} tokenizer is not '\n                                  'implemented.'.format(args.tokenizer_type))\n\n    # Add vocab size.\n    args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size,\n                                                      args)\n\n    return tokenizer\n\ntoken.build_tokenizer = build_tokenizer\nimport megatron\nreload(megatron.tokenizer)\nreload(megatron.global_vars)\nreload(megatron.global_vars)\nprint(\"veGiantModel loaded.\")\n"
  },
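  {
    "path": "examples/monkeypatch_sketch.py",
    "content": "# Hypothetical sketch, not a file from the original repo: it demonstrates the\n# patching technique patcher.py relies on, namely rebinding a function attribute on\n# an already-imported module so later callers pick up the replacement. The stdlib\n# json module stands in for megatron.mpu here; all names below are made up.\nimport json\n\noriginal_dumps = json.dumps\n\ndef traced_dumps(obj, **kwargs):\n    # Replacement that logs the call, then delegates to the original implementation.\n    print(\"json.dumps called\")\n    return original_dumps(obj, **kwargs)\n\n# Rebind the module attribute, as patcher.py does for the mpu.initialize functions.\njson.dumps = traced_dumps\nprint(json.dumps({\"a\": 1}))\n\n# Code that bound the name earlier (\"from json import dumps\") still holds the old\n# function; patcher.py handles that case by reload()-ing the dependent megatron modules.\njson.dumps = original_dumps\n"
  }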
]