[
  {
    "path": ".gitignore",
    "content": ".DS_Store\n*.py[cod]\n*.log\n\n# C extensions\n*.so\n\n# Packages\n*.egg\n*.egg-info\ndist\nbuild\neggs\nparts\nvar\nsdist\ndevelop-eggs\n.installed.cfg\nlib\nlib64\n__pycache__\n\n# Installer logs\npip-log.txt\nfiles.txt\n\n# Unit test / coverage reports\n.coverage\n.tox\nnosetests.xml\n\n# Translations\n*.mo\n\n# Mr Developer\n.mr.developer.cfg\n.project\n.pydevproject\n"
  },
  {
    "path": "LICENSE",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2022 Michael Isaev, Nic McDonald\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "Makefile",
    "content": ".SUFFIXES:\n.PHONY: help install clean lint test count\n\nhelp:\n\t@echo \"options are: install clean lint test count\"\n\ninstall:\n\tpython3 setup.py install --user --record files.txt\n\nuninstall:\n\tcat files.txt | xargs rm -rf\n\nclean:\n\trm -rf build dist calculon.egg-info calculon/*.pyc calculon/__pycache__ calculon/*/__pycache__ test/*.pyc test/__pycache__\n\nlint:\n\tpylint -r n calculon\n\ntest:\n\tpython3 -m unittest -v -f --buffer\n\t@echo -e \"Unit testing successful!\\n\\n\"\n\t./test/test.sh\n\ncount:\n\t@wc calculon/*.py test/*.py | sort -n -k1\n\t@echo \"files : \"$(shell echo calculon/*.py test/*.py | wc -w)\n\t@echo \"commits : \"$(shell git rev-list HEAD --count) \n"
  },
  {
    "path": "NOTICE",
    "content": "Calculon - Co-design for large scale parallel applications\nCopyright 2022 Michael Isaev, Nic McDonald\nAll rights reserved."
  },
  {
    "path": "README.md",
    "content": "[![DOI](https://zenodo.org/badge/660734586.svg)](https://zenodo.org/badge/latestdoi/660734586)\n# Calculon - Co-design for large scale parallel applications\n\n## Running\n\nRun Calculon like this:\n``` sh\n$> PYTHONPATH=. ./bin/calculon <args>\n```\n\nCalculon has a hierarchical command-line interface. To see the commands it accepts, use `--help` or `-h`:\n``` sh\n$> PYTHONPATH=. ./bin/calculon -h\n```\n\nYou can also see how to use any specific command by passing `--help` or `-h` to that command:\n``` sh\n$> PYTHONPATH=. ./bin/calculon llm -h\n```\n\n## LLM Example\n\nRun a single calculation for LLM (~1 sec):\n``` sh\n$> PYTHONPATH=. ./bin/calculon llm models/megatron-1T.json examples/3072_t4_p64_d12_mbs4_full.json systems/a100_80g.json -\n```\n\nRun a system execution optimizer for LLM (~1 min):\n``` sh\n$> PYTHONPATH=. ./bin/calculon llm-optimal-execution models/turing-530B.json 5128 2520 float16 systems/a100_80g.json output.json -m\n```\n`output.json` will contain the optimal way to run Turing-530B across 5128 A100 GPUs.\n\nTo store results from all successful runs from the same experiment, run a special system optimizer (~1 min):\n``` sh\n$> PYTHONPATH=. ./bin/calculon llm-all-executions models/turing-530B.json 5128 2520 float16 systems/a100_80g.json all_output.csv\n```\n\n## Testing and validation (optional)\nTo make sure that the current build is working, use\n\n``` sh\n$> make test\n```\nTo validate Calculon performance modeling against Megatron runs on NVIDIA's Selene A100-based supercomputer with results published in the [\"Sequence parallelism\" paper](https://arxiv.org/abs/2205.05198), use\n\n``` sh\n$> PYTHONPATH=. ./bin/calculon llm-validation\n```\n\n## Publications\n\n* Calculon: A Methodology and Tool for High-Level Co-Design of Systems and Large Language Models\\\nMikhail Isaev, Nic McDonald, Larry Dennison, Richard Vuduc\\\n[Paper](https://dl.acm.org/doi/pdf/10.1145/3581784.3607102)\n\n* Scaling Infrastructure to Support Multi-Trillion Parameter LLM Training\\\nMikhail Isaev, Nic McDonald, Richard Vuduc\\\n[Paper](https://openreview.net/pdf?id=rqn2v1Ltgn0)\n"
  },
  {
    "path": "bin/calculon",
    "content": "#!/usr/bin/env python3\n\n\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nimport argparse\nimport calculon\nimport logging\nimport sys\n\n\n\n\nif __name__ == '__main__':\n  # CLI inspired from: https://github.com/ssnetsim/ssplot/\n\n  # Creates an argparser and subparsers.\n  desc = 'Calculon: Co-design for large scale parallel applications'\n  ap = argparse.ArgumentParser(description=desc)\n  ap.add_argument('-l', '--log', default='-',\n                  help='Sets the log file, or - for stdout (default)')\n  ap.add_argument('-v', '--verbosity', default='INFO',\n                  help='Sets the logging level (see logging docs)')\n  sp = ap.add_subparsers(title='commands', dest='command',\n                         description='commands available in Calculon',\n                         help='the command')\n  sp.required = True\n\n  # Registers each command line interface.\n  for cls in calculon.CommandLine.command_lines():\n    cls.create_parser(sp)\n\n  # Parses the args and creates the logger\n  args = ap.parse_args()\n  logger = logging.getLogger()\n  if args.log == '-':\n    logger.addHandler(logging.StreamHandler(stream=sys.stdout))\n  else:\n    fd = open(args.log, 'w')\n    logger.addHandler(logging.StreamHandler(stream=fd))\n  logger.setLevel(args.verbosity)\n\n  # Calls the corresponding command 
function\n  sys.exit(args.func(logger, args))\n"
  },
  {
    "path": "calculon/__init__.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\n__version__ = '0.1.0'\n\n# Imports of this module\nfrom .command_line import CommandLine\nfrom .io import *\nfrom .system import System\nfrom .util import *\nfrom .version import Version\n\n# Imports submodules\nfrom .llm import *\n"
  },
  {
    "path": "calculon/command_line.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nimport copy\n\nclass CommandLine:\n  \"\"\"Defines the abstract interface definition for a command line interface.\n  Inspired from: https://github.com/ssnetsim/ssplot/\n  \"\"\"\n\n  @staticmethod\n  def create_parser(subparser):\n    \"\"\"\n    This function adds a parser to the subparser object according to the\n    specific command line interface implementation.\n    \"\"\"\n    raise NotImplementedError('subclasses must override this')\n\n  @staticmethod\n  def run_command(logger, args):\n    \"\"\"\n    This function is used to run the command if it is chosen at the command\n    line. 
This function should be registered to the parser in create_parser().\n    \"\"\"\n    raise NotImplementedError('subclasses must override this')\n\n  # this is a mapping of all names (class->names)\n  _names = {}\n\n  @staticmethod\n  def register(cls):\n    # gather names\n    primary_name = cls.NAME\n    aliases = cls.ALIASES\n\n    # create a set to hold all\n    all_names = [primary_name] + aliases\n\n    # check current names against all new names\n    for new_name in all_names:\n      for pname in CommandLine._names:\n        assert new_name is not pname, f'{new_name} already exists'\n        for alias in CommandLine._names[pname]:\n          assert new_name is not alias, f'{new_name} already exists'\n\n    # add to map\n    CommandLine._names[cls] = all_names\n\n  @staticmethod\n  def command_lines():\n    return set(CommandLine._names.keys())\n\n  @staticmethod\n  def all_names():\n    return copy.copy(CommandLine._names)\n"
  },
  {
    "path": "calculon/io.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\nimport gzip\nimport json\nimport numpy as np\n\n\nclass NpEncoder(json.JSONEncoder):\n  def default(self, obj):\n    if isinstance(obj, np.integer):\n      return int(obj)\n    if isinstance(obj, np.floating):\n      return float(obj)\n    if isinstance(obj, np.ndarray):\n      return obj.tolist()\n    if isinstance(obj, np.bool_):\n      return bool(obj)\n    return super(NpEncoder, self).default(obj)\n\ndef is_json_extension(filename):\n  return filename.endswith('.json') or filename.endswith('.json.gz')\n\n\ndef write_json_file(jdata, filename):\n  assert is_json_extension(filename)\n  opener = gzip.open if filename.endswith('.gz') else open\n  indent = None if filename.endswith('.gz') else 2\n  with opener(filename, 'wb') as fd:\n    fd.write(bytes(json.dumps(jdata, indent=indent, cls=NpEncoder), 'utf-8'))\n\n\ndef read_json_file(filename):\n  assert is_json_extension(filename)\n  opener = gzip.open if filename.endswith('.gz') else open\n  with opener(filename, 'rb') as fd:\n    return json.loads(fd.read().decode('utf-8'))\n"
  },
  {
    "path": "calculon/llm/__init__.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nfrom .layers import *\nfrom .llm import *\n\n# Command lines\nfrom .all_executions import AllExecutions\nfrom .optimal_execution import OptimalExecution\nfrom .parameter_calculator import ParameterCalculator\nfrom .validation import Validation\nfrom .runner import Runner\n"
  },
  {
    "path": "calculon/llm/all_executions.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nimport datetime\nimport gzip\nimport itertools\nimport logging\nimport math\nimport multiprocessing as mp\nimport os\nimport pandas\nimport psutil\nimport random\n\nimport calculon\nfrom calculon.util import pick, arg_true_false_all\nfrom calculon.llm import *\n\n\nclass AllExecutions(calculon.CommandLine):\n  NAME = 'llm-all-executions'\n  ALIASES = ['lae']\n\n  @staticmethod\n  def create_parser(subparser):\n    sp = subparser.add_parser(\n      AllExecutions.NAME, aliases=AllExecutions.ALIASES,\n      help='run a search to find the optimal llm execution')\n    sp.set_defaults(func=AllExecutions.run_command)\n    sp.add_argument('-d', '--debug', action='store_true',\n                    help='Loop over executions, don\\'t run them')\n    sp.add_argument('application', type=str,\n                    help='File path to application configuration')\n    sp.add_argument('num_procs', type=int,\n                    help='Number of processors in execution')\n    sp.add_argument('max_batch_size', type=int,\n                    help='Maximum batch size, will be largest multiple of DP')\n    sp.add_argument('datatype', type=str, choices=System.supported_datatypes(),\n                    help='The datatype to use')\n    sp.add_argument('system', type=str,\n              
      help='File path to system configuration')\n    sp.add_argument('output', type=str,\n                    help='File path to the output file'\n                    \" ('*.csv', '*.csv.gz')\")\n    sp.add_argument('-c', '--cpus', type=int, default=psutil.cpu_count(logical=False),\n                    help='CPUs to use for parallelization')\n    sp.add_argument('-n', '--noneok', action='store_true',\n                    help='Don\\'t give failure status when no good execution exists')\n    sp.add_argument('-f', '--fused_activation', type=arg_true_false_all,\n                    default='true', help='Mode of fused activation')\n\n  @staticmethod\n  def execution_fields():\n    return (\n      'num_procs', 'tensor_par', 'pipeline_par', 'data_par', 'tensor_par_net',\n      'pipeline_par_net', 'data_par_net', 'batch_size', 'microbatch_size',\n      'datatype', 'fused_activation', 'attention_type', 'activation_recompute',\n      'pipeline_interleaving', 'optimizer_sharding', 'tensor_par_comm_type',\n      'tensor_par_overlap', 'seq_par_ag_redo', 'data_par_overlap',\n      'weight_offload', 'activations_offload', 'optimizer_offload', 'training')\n\n  @staticmethod\n  def get_batch_size(data_par, max_batch_size):\n    if data_par > max_batch_size:\n      return None\n    last = data_par\n    while True:\n      if last + data_par > max_batch_size:\n        return last\n      else:\n        last += data_par\n\n  @staticmethod\n  def all_executions(app, syst, num_procs, max_batch_size, datatype, fused_activation):\n    has_mem2 = syst.mem2.capacity > 0\n    num_nets = syst.num_networks\n    count = 0\n    for tp in Llm.get_all_tensor_parallelisms(\n        num_procs, app.hidden, app.attn_heads):\n      for pp in Llm.get_all_pipeline_parallelisms(\n          num_procs, tp, app.num_blocks):\n        dp = Llm.get_data_parallelism(num_procs, tp, pp)\n        for ppint in Llm.get_valid_pipeline_interleavings(app.num_blocks, pp):\n          batch_size = 
AllExecutions.get_batch_size(dp, max_batch_size)\n          if batch_size is None:\n            continue\n          for activation_recompute in ['full', 'attn_only', 'none']:\n            for optimizer_sharding in pick(dp>1, [True, False], [False]):\n              for tensor_par_comm_type in ['ar', 'p2p_rs_ag', 'rs_ag']:\n                can_redo = Llm.can_redo_ag(tensor_par_comm_type,\n                                           activation_recompute)\n                for seq_par_ag_redo in pick(can_redo, [True, False], [False]):\n                  for data_par_overlap in pick(dp>1, [True, False], [False]):\n                    for tensor_par_overlap in pick(tp>1, ['none', 'ring', 'pipe'], ['none']):\n                      for weight_offload in pick(has_mem2, [True, False], [False]):\n                        if activation_recompute == 'full' or not has_mem2:\n                          activations_offloads = [False]\n                        else:\n                          activations_offloads = [True, False]\n                        for activations_offload in activations_offloads:\n                          for optimizer_offload in pick(has_mem2, [True, False],\n                                                        [False]):\n                            for fused_act in fused_activation:\n                              for microbatch_size in Llm.get_valid_microbatch_sizes(\n                                  app.seq_size, tp, dp, batch_size, pp):\n                                for tn in pick(tp>1, range(num_nets), [0]):\n                                  for pn in pick(pp>1, range(num_nets), [0]):\n                                    for dn in pick(dp>1, range(num_nets), [0]):\n                                      yield (num_procs, tp, pp, dp, tn, pn, dn,\n                                             batch_size, microbatch_size, datatype,\n                                             fused_act, 'multihead', activation_recompute,\n                                  
           ppint, optimizer_sharding, tensor_par_comm_type,\n                                             tensor_par_overlap, seq_par_ag_redo,\n                                             data_par_overlap, weight_offload,\n                                             activations_offload, optimizer_offload,\n                                             True)\n                                      count += 1\n\n  @staticmethod\n  def run_command(logger, args):\n    assert args.output.endswith('.csv') or args.output.endswith('.csv.gz')\n\n    app = Llm.Application(calculon.io.read_json_file(args.application))\n    syst = System(calculon.io.read_json_file(args.system))\n\n    executions = list(AllExecutions.all_executions(\n      app, syst, args.num_procs, args.max_batch_size, args.datatype,\n      args.fused_activation))\n    random.shuffle(executions)\n    exe_count = len(executions)\n    logger.info(f'Total executions: {exe_count}')\n\n    step = math.ceil(len(executions) / args.cpus)\n    worker_args = []\n    for index in range(0, len(executions), step):\n      worker_args.append((app, syst, executions[index : index + step]))\n    del executions\n\n    # Runs parallel searches\n    start_time = datetime.datetime.now()\n    with mp.Pool(args.cpus) as pool:\n      goods = pool.starmap(AllExecutions.search, worker_args)\n    end_time = datetime.datetime.now()\n    good_count = sum(len(good) for good in goods)\n\n    # Console statistics\n    logger.info(f'Good executions: {good_count}')\n    logger.info(f'Bad executions: {exe_count-good_count}')\n    calc_rate = exe_count / (end_time - start_time).total_seconds()\n    logger.info(f'Calculation rate: {calc_rate:.2f} calcs/sec')\n\n    # Check if OK\n    if good_count == 0:\n      if not args.noneok:\n        logger.fatal('No acceptable configurations found :(')\n        return -1\n      else:\n        logger.info('No acceptable configurations found :(')\n\n    if args.debug:\n      return 0\n\n    # Writes to CSV\n   
 fields = Llm.Execution.fields() + Llm.get_stats_fields()\n    assert len(fields) == len(goods[0][0])\n    logger.info(f'Output: {args.output}')\n    opener = gzip.open if args.output.endswith('.gz') else open\n    with opener(args.output, 'wb') as fd:\n      fd.write(bytes(','.join(fields) + '\\n', 'utf-8'))\n      for vals in itertools.chain(*goods):\n        fd.write(bytes(','.join(str(v) for v in vals) + '\\n', 'utf-8'))\n\n    return 0\n\n  @staticmethod\n  def search(app, syst, executions):\n    good = []\n    for execution in executions:\n      try:\n        model = Llm(app, logging.Logger('sub'))\n        model.compile(syst, Llm.Execution(*execution))\n        model.run(syst)\n        statistics = model.get_stats_values()\n        good.append(execution + statistics)\n      except Llm.Error as ex:\n        logger = logging.getLogger()\n        logger.debug(f'ERROR:{ex}\\n')\n    return good\n\n  @staticmethod\n  def update_list(current, candidate, quantity):\n    if not isinstance(candidate, list):\n      current.append(candidate)\n    else:\n      current.extend(candidate)\n    if quantity <= 0:\n      return current  # don't sort and chop\n    else:\n      current.sort(reverse=True, key=lambda x: x[0])\n      return current[:quantity]\n\n\ncalculon.CommandLine.register(AllExecutions)\n"
  },
  {
    "path": "calculon/llm/layers.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nfrom calculon import *\n\n\nclass Layer:\n  \"\"\"\n  A single layer of a neural network. Has weights, activation space,\n  gradients, and optimizer state associated with it. May invoke compute,\n  memory access, or network operation.\n  \"\"\"\n\n  def __init__(self, name, sys, fw_flops=0, agrad_flops=0, wgrad_flops=0,\n               inputs_size=0, output_size=0, activation_space=0,\n               activation_grads=0, weight_space=0, weight_grads=0,\n               optim_space=0, needs_recompute=False, needs_recomm=False,\n               activation_reused=False, activation_stored=True,\n               output_stored=True):\n    self.name = name\n    self.sys = sys\n    self.fw_flops = fw_flops\n    self.agrad_flops = agrad_flops\n    self.wgrad_flops = wgrad_flops\n    self.inputs_size = inputs_size\n    self.output_size = output_size\n    # activations equal input size, we store them to compute Wgrad during BW\n    self.activation_space = activation_space\n    # activation grads equal output size and correspond grads w.r.t. 
the output\n    self.activation_grads = activation_grads\n    self.weight_space = weight_space\n    self.weight_grads = weight_grads\n    self.optim_space = optim_space\n    self.optim_sharding_num_proc = 1\n\n    # Add optimizations and parallelization split\n    self.needs_recompute = needs_recompute\n    self.needs_recomm = needs_recomm\n    self.activation_reused=activation_reused\n    self.activation_stored = activation_stored\n    self.output_stored = output_stored\n    # Before bytes_per_element set by SW config, we operate with just\n    # parameter count, setting bytes_per_element to 1\n    self.bytes_per_element = 1\n    self.processing_time = None\n    self.net_exposed_time = None\n\n  def get_stats_json(self):\n    return {\n      'name': self.name,\n      'inputs_size': self.inputs_size,\n      'outputs_size': self.output_size,\n      'fw_flops': self.get_fw_flops(),\n      'fw_mem_accessed': self.get_fw_mem_accessed(),\n      'fw_arithmetic_intensity': self.get_fw_arithmetic_intensity(),\n      'fw_processing_time': self.compute_processing_time('fw'),\n      'baseblock_fw_tp_comm_tile': self.get_comm_tile('fw', baseblock=True),\n      'edgeblock_fw_tp_comm_tile': self.get_comm_tile('fw', baseblock=False),\n      'baseblock_fw_tp_comm_size': self.get_comm_bytes('fw', baseblock=True),\n      'edgeblock_fw_tp_comm_size': self.get_comm_bytes('fw', baseblock=False),\n      'baseblock_fw_tp_comm_time': self.compute_net_time('fw', baseblock=True),\n      'edgeblock_fw_tp_comm_time': self.compute_net_time('fw',baseblock=False),\n      'baseblock_fw_tp_comm_time_exposed': self.get_exposed_net_time(\n        'fw', baseblock=True),\n      'edgeblock_fw_tp_comm_time_exposed': self.get_exposed_net_time(\n        'fw', baseblock=False),\n      'agrad_flops': self.get_agrad_flops(),\n      'agrad_mem_accessed': self.get_agrad_mem_accessed(),\n      'agrad_arithmetic_intensity': self.get_agrad_arithmetic_intensity(),\n      'agrad_processing_time': 
self.compute_processing_time('agrad'),\n      'baseblock_bw_tp_comm_tile': self.get_comm_tile('agrad', baseblock=True),\n      'edgeblock_bw_tp_comm_tile': self.get_comm_tile('agrad', baseblock=False),\n      'baseblock_bw_tp_comm_size': self.get_comm_bytes('agrad', baseblock=True),\n      'edgeblock_bw_tp_comm_size': self.get_comm_bytes('agrad', baseblock=False),\n      'baseblock_bw_tp_comm_time': self.compute_net_time('agrad', baseblock=True),\n      'edgeblock_bw_tp_comm_time': self.compute_net_time('agrad', baseblock=False),\n      'baseblock_bw_tp_comm_time_exposed': self.get_exposed_net_time(\n        'agrad', baseblock=True),\n      'edgeblock_bw_tp_comm_time_exposed': self.get_exposed_net_time(\n        'agrad', baseblock=False),\n      'wgrad_flops': self.get_wgrad_flops(),\n      'wgrad_mem_accessed': self.get_wgrad_mem_accessed(),\n      'wgrad_arithmetic_intensity': self.get_wgrad_arithmetic_intensity(),\n      'wgrad_processing_time': self.compute_processing_time('wgrad'),\n      'baseblock_recomm_tile': self.get_comm_tile('wgrad', baseblock=True),\n      'edgeblock_recomm_tile': self.get_comm_tile('wgrad', baseblock=False),\n      'baseblock_recomm_size': self.get_comm_bytes('wgrad', baseblock=True),\n      'edgeblock_recomm_size': self.get_comm_bytes('wgrad', baseblock=False),\n      'baseblock_recomm_time': self.compute_net_time('wgrad', baseblock=True),\n      'edgeblock_recomm_time': self.compute_net_time('wgrad', baseblock=False),\n      'baseblock_recomm_time_exposed': self.get_exposed_net_time(\n        'wgrad', baseblock=True),\n      'edgeblock_recomm_time_exposed': self.get_exposed_net_time(\n        'wgrad', baseblock=False),\n      'optim_flops': self.get_optim_step_flops(),\n      'optim_mem_accessed': self.get_optim_step_mem_accessed(),\n      'optim_arithmetic_intensity': self.get_optim_step_arithmetic_intensity(),\n      'optim_processing_time': self.compute_processing_time('optim'),\n      'weight': self.get_weight(),\n      
'activation': self.get_activation(),\n      'weight_grad': self.get_weight_grad(),\n      'activation_grad': self.get_activation_grad(),\n      'optimizer': self.get_optimizer()\n    }\n\n  def get_stats_str(self):\n    \"\"\"Returns a human-readable summary of the layer's per-stage (FW, Agrad,\n    Wgrad, Optim) flops, bytes accessed and arithmetic intensity, followed by\n    its weight, activation, grad and optimizer memory footprints.\"\"\"\n    stats = \"Operation {0}:\\n{1} FW flops, {2} FW bytes accessed,\".format(\n      self.name,\n      human_format(self.get_fw_flops(), 'flops'),\n      human_format(self.get_fw_mem_accessed(), 'bytes'))\n    stats += \" FW AI: {0:.3f}\\n\".format(self.get_fw_arithmetic_intensity())\n    stats += \"{0} BW Agrad flops, {1} BW Agrad bytes accessed,\".format(\n      human_format(self.get_agrad_flops(), 'flops'),\n      human_format(self.get_agrad_mem_accessed(), 'bytes'))\n    stats += \" BW Agrad AI: {0:.3f}\\n\".format(\n      self.get_agrad_arithmetic_intensity())\n    stats += \"{0} BW Wgrad flops, {1} BW Wgrad bytes accessed,\".format(\n      human_format(self.get_wgrad_flops(), 'flops'),\n      human_format(self.get_wgrad_mem_accessed(), 'bytes'))\n    stats += \" BW Wgrad AI: {0:.3f}\\n\".format(\n      self.get_wgrad_arithmetic_intensity())\n    stats += \"{0} Optim flops, {1} Optim bytes accessed,\".format(\n      human_format(self.get_optim_step_flops(), 'flops'),\n      human_format(self.get_optim_step_mem_accessed(), 'bytes'))\n    stats += \" Optim AI: {0:.3f}\\n\".format(\n      self.get_optim_step_arithmetic_intensity())\n    stats += \"W: {0}, Act: {1}, WGrad: {2}, AGrad: {3}, Optim: {4}\".format(\n      human_format(self.get_weight(), 'bytes'),\n      human_format(self.get_activation(), 'bytes'),\n      human_format(self.get_weight_grad(), 'bytes'),\n      human_format(self.get_activation_grad(), 'bytes'),\n      human_format(self.get_optimizer(), 'bytes'))\n    return stats\n\n  def set_bytes_per_element(self, bytes_per_element):\n    self.bytes_per_element = bytes_per_element\n\n  # Shard (distribute) optimizer and weight grads between data parallel nodes\n  def shard_optimizer(self, num_procs):\n    self.optim_sharding_num_proc = 
num_procs\n\n  # getters that will be called from Llm model class, can be rewritten\n  def get_fw_flops(self):\n    return self.fw_flops\n\n  def get_fw_mem_accessed(self):\n    mem_accessed = self.inputs_size + self.output_size + self.weight_space\n    mem_accessed *= self.bytes_per_element\n    return mem_accessed\n\n  def get_fw_arithmetic_intensity(self):\n    if self.fw_flops == 0:\n      return 0\n    if self.get_fw_mem_accessed() == 0:\n      return float('inf')\n    return self.fw_flops / self.get_fw_mem_accessed()\n\n  def get_recompute_flag(self):\n    return self.needs_recompute\n\n  def get_recomm_flag(self):\n    return self.needs_recomm\n\n  def reuses_activation(self):\n    return self.activation_reused\n\n  def stores_activation(self):\n    return self.activation_stored\n\n  def stores_output(self):\n    return self.output_stored\n\n  def get_agrad_flops(self):\n    return self.agrad_flops\n\n  def get_agrad_mem_accessed(self):\n    # activation grads equal output size and correspond grads w.r.t.\n    # layer output; activations are equal to input size\n    grad_mem = self.weight_space + (\n      self.activation_space + self.activation_grads)\n    grad_mem *= self.bytes_per_element\n    return grad_mem\n\n  def get_agrad_arithmetic_intensity(self):\n    if self.agrad_flops == 0:\n      return 0\n    if self.get_agrad_mem_accessed() == 0:\n      return float('inf')\n    return self.agrad_flops / self.get_agrad_mem_accessed()\n\n  def get_wgrad_flops(self):\n    return self.wgrad_flops\n\n  def get_wgrad_mem_accessed(self):\n    if self.weight_space == 0:\n      assert self.wgrad_flops == 0, \\\n        f\"Haven't expected to see wgrad flops in layer {self.name}\"\n      return 0\n    # activation grads equal output size and correspond grads w.r.t.\n    # layer output; activations are equal to input size\n    grad_mem = self.weight_grads + (\n      self.activation_space + self.activation_grads)\n    grad_mem *= self.bytes_per_element\n    return 
grad_mem\n\n  def get_wgrad_arithmetic_intensity(self):\n    if self.wgrad_flops == 0:\n      return 0\n    if self.get_wgrad_mem_accessed() == 0:\n      return float('inf')\n    return self.wgrad_flops / self.get_wgrad_mem_accessed()\n\n  # We use Adam optimizer. The amount of flops is based on the number of\n  # weight grads to accommodate for possible weight_grad sharding\n  # among data parallel nodes\n  def get_optim_step_flops(self):\n    optim_flops = self.weight_grads / self.optim_sharding_num_proc * 11\n    return optim_flops\n\n  def get_optim_step_mem_accessed(self):\n    return self.get_optimizer()\n\n  def get_optim_step_arithmetic_intensity(self):\n    if self.get_optim_step_flops() == 0:\n      return 0\n    if self.get_optim_step_mem_accessed() == 0:\n      return float('inf')\n    return self.get_optim_step_flops() / self.get_optim_step_mem_accessed()\n\n  def get_weight(self):\n    return self.weight_space * self.bytes_per_element\n\n  def get_activation(self):\n    return self.activation_space * self.bytes_per_element\n\n  def get_output(self):\n    return self.output_size * self.bytes_per_element\n\n  def get_weight_grad(self, sharded=True):\n    # Keep lower precision copy of grads for mem and net transfers\n    grads = self.weight_grads\n    if sharded:\n      # We keep grads in lower precision for communication\n      grads *= self.bytes_per_element\n      grads /= self.optim_sharding_num_proc\n    else:\n      # otherwise keep grads in 32 bit for accumulation\n      grads *= 4\n    return grads\n\n  def get_activation_grad(self):\n    return self.activation_grads * self.bytes_per_element\n\n  def get_optimizer(self):\n    # Keep 32-bits master copy of weights, plus both moments (m,v)\n    # master copy for grads is accounted for in get_weight_grad()\n    moments_size = self.optim_space * 4\n    if self.bytes_per_element < 4:\n      master_copy_size = self.weight_space * 4\n    else:\n      master_copy_size = 0\n    return (master_copy_size + 
moments_size) / self.optim_sharding_num_proc\n\n  def set_processing_time(self, processing_time):\n    self.processing_time = processing_time\n\n  def get_processing_time(self):\n    return self.processing_time\n\n  def use_matrix_engine(self):\n    return False\n\n  def get_comm_bytes(self, stage, baseblock=True):\n    return 0\n\n  def get_comm_tile(self, stage, baseblock=True):\n    return self.get_comm_bytes(stage, baseblock)\n\n  def compute_flops_time(self, stage):\n    if stage == \"fw\":\n      flops = self.get_fw_flops()\n    elif stage == \"agrad\":\n      flops = self.get_agrad_flops()\n    elif stage == \"wgrad\":\n      flops = self.get_wgrad_flops()\n    elif stage == \"optim\":\n      flops = self.get_optim_step_flops()\n    else:\n      raise Exception(f'Bad compute stage : {stage}')\n    if self.use_matrix_engine() and stage != \"optim\":\n      throughput = self.sys.get_matrix_throughput(flops)\n    else:\n      throughput = self.sys.get_vector_throughput(flops)\n    return flops / throughput\n\n  def compute_mem_time(self, stage):\n    if stage == \"fw\":\n      mem = self.get_fw_mem_accessed()\n    elif stage == \"agrad\":\n      mem = self.get_agrad_mem_accessed()\n    elif stage == \"wgrad\":\n      mem = self.get_wgrad_mem_accessed()\n    elif stage == \"optim\":\n      mem = self.get_optim_step_mem_accessed()\n    else:\n      raise Exception(f'Bad compute stage : {stage}')\n    return mem / self.sys.get_mem1_throughput(mem)\n\n  def compute_net_time(self, stage, baseblock=True):\n    return 0\n\n  def get_exposed_net_time(self, stage, baseblock=True):\n    return 0\n\n  def get_required_bandwidth(self, stage, baseblock=True):\n    return 0\n\n  def compute_processing_time(self, stage):\n    self.processing_time =  self.sys.get_processing_time(\n      self.compute_flops_time(stage),\n      self.compute_mem_time(stage)\n    )\n    return self.processing_time\n\n# We can factor all layers peculiarities and layer-wise optimizations by\n# 
rewriting parent class member functions when needed\nclass Linear(Layer):\n  def __init__(self, name, sys, batch_seq, c_in, c_out,\n               needs_recompute=False, activation_reused=False,\n               activation_stored=True, output_stored=True):\n    m, n, k = batch_seq, c_in, c_out\n    super().__init__(name,\n                     sys,\n                     fw_flops=2*m*n*k,\n                     agrad_flops=2*m*n*k,\n                     wgrad_flops=2*m*n*k,\n                     inputs_size=m*n,\n                     output_size=m*k,\n                     weight_space=n*k,\n                     weight_grads=n*k,\n                     activation_space=m*n,\n                     activation_grads=m*k,\n                     optim_space=2*n*k,\n                     needs_recompute=needs_recompute,\n                     activation_reused=activation_reused,\n                     activation_stored=activation_stored,\n                     output_stored=output_stored)\n\n  def use_matrix_engine(self):\n    return True\n\nclass LinearOverlapped(Layer):\n  def __init__(self, name, sys, batch_seq, c_in, c_out, tensor_par_comm_type,\n               num_tiles, net_id, num_peers, conjugate=False,\n               in_network_reduction=False, tp_overlap='pipe',\n               needs_recompute=False, needs_recomm=False,\n               activation_reused=False, activation_stored=True,\n               output_stored=True):\n    m, n, k = batch_seq, c_in, c_out\n    self.tensor_par_comm_type = tensor_par_comm_type\n    self.num_tiles = num_tiles\n    self.net = sys.get_network(net_id)\n    self.num_peers = num_peers\n    self.conjugate = conjugate\n    self.in_network_reduction = in_network_reduction\n    self.tp_overlap = tp_overlap\n    self._processed_flag = False\n    if self.tensor_par_comm_type == 'rs_ag':\n      if not conjugate:\n        #AllGather case\n        assert k % self.num_peers == 0\n        # assert m % self.num_peers == 0         # this should be true for 
seq_par\n        k = k // self.num_peers\n        act_space = m * n // num_tiles\n        act_grad_space = m * k\n        act_net_buffer = m * n // num_tiles\n        act_grad_net_buffer = 0\n      else:\n        # ReduceScatter case\n        assert n % self.num_peers == 0\n        # assert m % self.num_peers == 0         # this should be true for seq_par\n        n = n // self.num_peers\n        act_space = m * n\n        act_grad_space = m * k // num_tiles\n        act_net_buffer = 0\n        act_grad_net_buffer = m * k // num_tiles\n        #act_net_buffer = m * k // num_tiles\n    else:\n      if not conjugate:\n        # AllReduce case\n        assert k % self.num_peers == 0\n        k = k // self.num_peers\n        act_space = m * n\n        act_grad_space = 0\n        act_net_buffer = m * n // num_tiles\n        act_grad_net_buffer = 0\n      else:\n        # Identityy case\n        assert n % self.num_peers == 0\n        n = n // self.num_peers\n        act_space = 0\n        act_grad_space = m * k\n        act_net_buffer = 0\n        act_grad_net_buffer = m * k\n\n    super().__init__(name,\n                     sys,\n                     fw_flops=2*m*n*k,\n                     agrad_flops=2*m*n*k,\n                     wgrad_flops=2*m*n*k,\n                     inputs_size=m*n,\n                     output_size=m*k,\n                     weight_space=n*k,\n                     weight_grads=n*k,\n                     activation_space=act_space, # + act_net_buffer,\n                     activation_grads=act_grad_space + act_grad_net_buffer,\n                     optim_space=2*n*k,\n                     needs_recompute=needs_recompute,\n                     needs_recomm=needs_recomm,\n                     activation_reused=activation_reused,\n                     activation_stored=activation_stored,\n                     output_stored=output_stored)\n\n  def use_matrix_engine(self):\n    return True\n\n  def get_comm_bytes(self, stage, baseblock=True):\n    
if self.num_peers == 1:\n      return 0\n    split_comm = (self.tensor_par_comm_type == 'rs_ag') or (\n      (self.tensor_par_comm_type == 'p2p_rs_ag') and not baseblock)\n    ag_comm_size = self.inputs_size * self.bytes_per_element\n    ar_rs_comm_size = self.output_size * self.bytes_per_element\n    if stage == 'fw':\n      if self.conjugate:\n        # ReduceScatter or AllReduce on FW\n        return ar_rs_comm_size\n      else:\n        if split_comm:\n          # AllGather on FW\n          return ag_comm_size\n        else:\n          # Identity on FW\n          return 0\n    if stage == 'agrad':\n      # Comm sizes during FW and BW pass are the same\n      if not self.conjugate:\n        # ReduceScatter or AllReduce on BW\n        return ag_comm_size\n      else:\n        if split_comm:\n          # AllGather on BW\n          return ar_rs_comm_size\n        else:\n          # Identity on BW\n          return 0\n    if stage == 'wgrad':\n      if self.needs_recomm:\n        return self.get_comm_bytes('fw', baseblock)\n      else:\n        return 0\n    if stage == 'optim':\n      return 0\n\n  def get_comm_flops(self, stage, baseblock=True):\n    return self.get_comm_bytes(stage, baseblock) / self.bytes_per_element\n\n  def get_num_tiles(self):\n    return self.num_tiles\n\n  def get_comm_tile(self, stage, baseblock=True):\n    return self.get_comm_bytes(stage, baseblock) / self.get_num_tiles()\n\n  def compute_net_time(self, stage, baseblock=True):\n    if self.num_peers == 1:\n      return 0\n    split_comm = (self.tensor_par_comm_type == 'rs_ag') or (\n      (self.tensor_par_comm_type == 'p2p_rs_ag') and not baseblock)\n    if self.conjugate:\n      if split_comm:\n        # ReduceScatter case\n        fw_comm_type = 'reduce_scatter'\n        bw_comm_type = 'all_gather'\n      else:\n        #AllReduce case\n        fw_comm_type = 'all_reduce'\n        bw_comm_type = None\n      if not self.in_network_reduction:\n        fw_flops = 
self.get_comm_flops(stage, baseblock) * (\n          self.num_peers - 1) / self.num_peers\n        fw_flop_time = fw_flops / self.sys.get_vector_throughput(fw_flops)\n      else:\n        fw_flop_time = 0\n      bw_flop_time = 0\n    else:\n      if split_comm:\n        #AllGather case\n        fw_comm_type = 'all_gather'\n        bw_comm_type = 'reduce_scatter'\n      else:\n        # Identity case\n        fw_comm_type = None\n        bw_comm_type = 'all_reduce'\n      fw_flop_time = 0\n      if not self.in_network_reduction:\n        bw_flops = self.get_comm_flops(stage, baseblock) * (\n          self.num_peers - 1) / self.num_peers\n        bw_flop_time = bw_flops / self.sys.get_vector_throughput(bw_flops)\n      else:\n        bw_flop_time = 0\n    if stage == 'fw':\n      if fw_comm_type == None:\n        return 0\n      else:\n        fw_net_time = self.net.time(\n          fw_comm_type, self.get_comm_bytes(stage, baseblock), self.num_peers)\n        return fw_net_time + fw_flop_time\n    if stage == 'agrad':\n      if bw_comm_type == None:\n        return 0\n      else:\n        bw_net_time = self.net.time(\n          bw_comm_type, self.get_comm_bytes(stage, baseblock), self.num_peers)\n        return bw_net_time + bw_flop_time\n    if stage == 'wgrad':\n      if self.needs_recomm and fw_comm_type:\n        # AllGather Redo (RS_AG only) or full recompute\n        return self.net.time(\n          fw_comm_type, self.get_comm_bytes(stage, baseblock), self.num_peers)\n      else:\n        return 0\n    if stage == 'optim':\n      return 0\n\n  def compute_processing_time(self, stage):\n    flop_time = self.compute_flops_time(stage)\n    flop_time_slowed = flop_time / (1 - self.net.processor_usage)\n    mem_time = self.compute_mem_time(stage)\n    net_time = self.compute_net_time(stage)\n    compute_time = self.sys.get_processing_time(flop_time, mem_time)\n    if net_time == 0:\n      time = compute_time\n      net_exposed_time = 0\n    else:\n      
compute_time_slowed = self.sys.get_processing_time(\n        flop_time_slowed, mem_time)\n      # Tiled time computed as fraction of full time, to model high effective\n      # throughput when processing many consequitive tiles\n      flop_tile = flop_time / self.num_tiles\n      flop_tile_slowed = flop_time_slowed / self.num_tiles\n      net_tile = net_time / self.num_tiles\n      compute_tile = compute_time / self.num_tiles\n      compute_tile_slowed = compute_time_slowed / self.num_tiles\n      overlap_inflection = net_tile - flop_tile_slowed\n      # we have one exposed comm tile if tp_comm is not ring,\n      # one exposed compute tile, and\n      # (Proc - 1) overlapped tiles, where either compute or comm is exposed\n      if overlap_inflection > 0:\n        # Tcomm is larger than compute, excess is exposed\n        # compute time itself is the compute + mem\n        time = compute_tile + (self.num_tiles - 1) * compute_tile_slowed\n        net_exposed_time = (self.num_tiles - 1) * overlap_inflection\n      else:\n        # Tcomm is smaller than compute and hidden, but it contributes to\n        # compute slowdown due part of compute resources orchestrating comm\n        time = compute_tile + (self.num_tiles - 1) * compute_tile + (\n          self.num_tiles - 1) * net_tile * self.net.processor_usage\n        net_exposed_time = 0\n      if self.tp_overlap == 'pipe':\n        # If overlap type is pipe, we need to add an exposed comm tile\n        # with ring-based overlap, we have a special schedule for comm and avoid\n        # sending an extra tile we have in the beginning\n        net_exposed_time += net_tile\n        time += net_tile\n    self.processing_time = time\n    self.net_exposed_time = net_exposed_time\n    self._processed_flag = True\n    return self.processing_time\n\n  def get_exposed_net_time(self, stage, baseblock=True):\n    # only use after calling compute_processing_time(), otherwise it's set with None\n    assert self._processed_flag\n    
return self.net_exposed_time\n\n  def get_required_bandwidth(self, stage, baseblock=True):\n    assert self._processed_flag\n    net_tile_size = self.get_comm_tile(stage, baseblock)\n    flop_time = self.compute_flops_time(stage)\n    flop_time_slowed = flop_time / (1 - self.net.processor_usage)\n    flop_tile_slowed = flop_time_slowed / self.num_tiles\n    return net_tile_size / flop_tile_slowed\n\nclass BatchMatMul(Layer):\n  def __init__(self, name, sys, batch, size_a, contraction_size, size_b,\n               needs_recompute=False, activation_reused=False,\n               activation_stored=True, output_stored=True):\n    m, n, k = size_a, contraction_size, size_b\n    super().__init__(name,\n                     sys,\n                     fw_flops=batch*2*m*n*k,\n                     agrad_flops=batch*2*2*m*n*k,\n                     inputs_size=batch*(m*n+n*k),\n                     output_size=batch*m*k,\n                     activation_space=batch*(m*n+n*k),\n                     activation_grads=batch*m*k,\n                     needs_recompute=needs_recompute,\n                     activation_reused=activation_reused,\n                     activation_stored=activation_stored,\n                     output_stored=output_stored)\n\n  def use_matrix_engine(self):\n    return True\n\n# https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html\n# https://cthorey.github.io./blog/2016/backpropagation/\nclass LayerNorm(Layer):\n  def __init__(self, name, sys, act_size, hidden,\n               needs_recompute=False, activation_reused=False,\n               activation_stored=True, output_stored=True):\n    super().__init__(name,\n                     sys,\n                     fw_flops=9*act_size,\n                     agrad_flops=14*act_size,\n                     wgrad_flops=7*act_size,\n                     inputs_size=act_size,\n                     output_size=act_size,\n                     
activation_space=act_size,\n                     activation_grads=act_size,\n                     weight_space=2*hidden,\n                     weight_grads=2*hidden,\n                     optim_space=2*2*hidden,\n                     needs_recompute=needs_recompute,\n                     activation_reused=activation_reused,\n                     activation_stored=activation_stored,\n                     output_stored=output_stored)\n\n\nclass DropOut(Layer):\n  def __init__(self, name, sys, act_size,\n               needs_recompute=False, activation_reused=False,\n               activation_stored=True, output_stored=True):\n    super().__init__(name,\n                     sys,\n                     fw_flops=act_size,\n                     agrad_flops=act_size,\n                     inputs_size=act_size,\n                     output_size=act_size,\n                     activation_space=act_size,\n                     activation_grads=act_size,\n                     needs_recompute=needs_recompute,\n                     activation_reused=activation_reused,\n                     activation_stored=activation_stored,\n                     output_stored=output_stored)\n\n\n  # need to account for DropOut mask of bool type that takes 1 B per element\n  # mask is the only DropOut activation\n  def get_activation(self):\n    return self.activation_space\n\n  def get_activation_grad(self):\n    return self.activation_grads\n\n  def get_fw_mem_accessed(self):\n    mask_size = self.activation_space\n    mem_accessed = self.inputs_size + self.output_size\n    mem_accessed *= self.bytes_per_element\n    mem_accessed += mask_size\n    return mem_accessed\n\n  def get_agrad_mem_accessed(self):\n    return self.get_fw_mem_accessed()\n\n\n# https://mlfromscratch.com/activation-functions-explained/#/\nclass GeLU(Layer):\n  def __init__(self, name, sys, act_size,\n               needs_recompute=False, activation_reused=False,\n               activation_stored=True, 
output_stored=True,\n               fused=False):\n    # Fused GeLU runs right after previous Linear layer and does not store\n    # activations or gradients\n    self._fused = fused\n    if fused:\n      eff_act_space = 0\n      eff_act_grads = 0\n    else:\n      eff_act_space = act_size\n      eff_act_grads = act_size\n    super().__init__(name, sys, fw_flops=8*act_size, agrad_flops=13*act_size,\n                     inputs_size=act_size, output_size=act_size,\n                     activation_space=eff_act_space,\n                     activation_grads=eff_act_grads,\n                     needs_recompute=needs_recompute,\n                     activation_reused=activation_reused,\n                     activation_stored=activation_stored,\n                     output_stored=output_stored)\n\n  def get_agrad_mem_accessed(self):\n    return self.get_fw_mem_accessed()\n\n\n# https://automata88.medium.com/how-to-implement-the-softmax-derivative-independently-from-any-loss-function-ae6d44363a9d\nclass SoftMax(Layer):\n  def __init__(self, name, sys, act_size,\n               needs_recompute=False, activation_reused=False,\n               activation_stored=True, output_stored=True):\n    super().__init__(name,\n                     sys,\n                     fw_flops=5*act_size,\n                     agrad_flops=8*act_size,\n                     inputs_size=act_size,\n                     output_size=act_size,\n                     activation_space=act_size,\n                     activation_grads=act_size,\n                     needs_recompute=needs_recompute,\n                     activation_reused=activation_reused,\n                     activation_stored=activation_stored,\n                     output_stored=output_stored)\n\n  def get_agrad_mem_accessed(self):\n    return self.get_fw_mem_accessed()\n\n\n# https://explained.ai/matrix-calculus/#sec:1.4.2\nclass ElementWise(Layer):\n  def __init__(self, name, sys, operand1, operand2,\n               
needs_recompute=False, activation_reused=False,\n               activation_stored=True, output_stored=True):\n    act_size = max(operand1, operand2)\n    super().__init__(name,\n                     sys,\n                     fw_flops=act_size,\n                     agrad_flops=(operand1+operand2),\n                     inputs_size=(operand1+operand2),\n                     output_size=act_size,\n                     activation_space=(operand1+operand2),\n                     activation_grads=act_size,\n                     needs_recompute=needs_recompute,\n                     activation_reused=activation_reused,\n                     activation_stored=activation_stored,\n                     output_stored=output_stored)\n\n\n# Splits activation on the forward pass, sums gradients on the backward\nclass Fork(Layer):\n  def __init__(self, name, sys, act_size, num_users,\n               needs_recompute=False, activation_reused=False,\n               activation_stored=True, output_stored=True):\n    self.num_users = num_users\n    super().__init__(name,\n                     sys,\n                     inputs_size=act_size,\n                     agrad_flops=num_users*act_size,\n                     activation_space=act_size,\n                     # Gradients from num_users accumulated in a single storage\n                     # that's accounted in the other layers\n                     # use 0 here to avoid double accounting\n                     activation_grads=0,\n                     needs_recompute=needs_recompute,\n                     activation_reused=activation_reused,\n                     activation_stored=activation_stored,\n                     output_stored=output_stored)\n\n  def get_fw_mem_accessed(self):\n    return 0\n\n  def get_agrad_mem_accessed(self):\n    return self.activation_space * self.bytes_per_element * (\n      self.num_users + 1)\n\n\nclass TPComm(Layer):\n\n  def __init__(self, name, sys, act_size, net_id, num_peers, 
tensor_par_comm_type,\n               conjugate=False, in_network_reduction=False,\n               needs_recomm=False, activation_reused=False,\n               activation_stored=True, output_stored=True):\n    self.net = sys.get_network(net_id)\n    self.num_peers = num_peers\n    self.tensor_par_comm_type = tensor_par_comm_type\n    self.comm_size = act_size\n    self.conjugate = conjugate\n    if self.num_peers == 1:\n      fw_flops = 0\n      bw_flops = 0\n      in_size = 0\n      out_size = 0\n    else:\n      if not self.conjugate:\n        # FW pass Identity/AllGather, BW pass AllReduce/ReduceScatter\n        fw_flops = 0\n        if not in_network_reduction:\n          bw_flops = act_size * (self.num_peers - 1) / self.num_peers\n        else:\n          bw_flops = 0\n        in_size = act_size\n        out_size = act_size\n      else:\n        # Conjugate function is opposite\n        if not in_network_reduction:\n          fw_flops = act_size * (self.num_peers - 1) / self.num_peers\n        else:\n          fw_flops = 0\n        bw_flops = 0\n        in_size = act_size\n        out_size = act_size\n    super().__init__(name,\n                     sys,\n                     fw_flops=fw_flops,\n                     agrad_flops=bw_flops,\n                     inputs_size=in_size,\n                     output_size=out_size,\n                     activation_space=in_size,\n                     activation_grads=out_size,\n                     needs_recomm=needs_recomm,\n                     activation_reused=activation_reused,\n                     activation_stored=activation_stored,\n                     output_stored=output_stored)\n\n  def get_activation(self):\n    if self.tensor_par_comm_type == 'rs_ag':\n      return self.activation_space * self.bytes_per_element / self.num_peers\n    else:\n      if self.conjugate:\n        return self.activation_space * self.bytes_per_element\n      else:\n        # Identity\n        return 0\n\n  def 
get_fw_mem_accessed(self):\n    if not self.tensor_par_comm_type == 'rs_ag' and not self.conjugate:\n      # Identity\n      return 0\n    else:\n      return super().get_fw_mem_accessed()\n\n  def get_activation_grad(self):\n    if self.tensor_par_comm_type == 'rs_ag':\n      return self.activation_space * self.bytes_per_element / self.num_peers\n    else:\n      if not self.conjugate:\n        return self.activation_grads * self.bytes_per_element\n      else:\n        # Identity\n        return 0\n\n  def get_agrad_mem_accessed(self):\n    if not self.tensor_par_comm_type == 'rs_ag' and self.conjugate:\n      # Identity\n      return 0\n    else:\n      return super().get_agrad_mem_accessed()\n\n  def get_comm_bytes(self, stage, baseblock=True):\n    if self.num_peers == 1:\n      return 0\n    split_comm = (self.tensor_par_comm_type == 'rs_ag') or (\n      (self.tensor_par_comm_type == 'p2p_rs_ag') and not baseblock)\n    if (not split_comm and (self.conjugate and stage == 'agrad' or\n        not self.conjugate and stage == 'fw')):\n      # Identity FW or AllReduce BW\n      return 0\n    else:\n      if stage == 'fw' or stage == 'agrad':\n        return self.comm_size * self.bytes_per_element\n      if stage == 'wgrad' and self.needs_recomm and (\n          split_comm or self.conjugate):\n        # with AG Redo, we need recomm both on FW pass (not self.conjugate)\n        # and BW pass (self.conjugate)\n        return self.comm_size * self.bytes_per_element\n      else:\n        # optim and wgrad stage has no comm if no ag_redo flag for RS_AG\n        return 0\n\n  def compute_net_time(self, stage, baseblock=True):\n    if self.num_peers == 1:\n      return 0\n    split_comm = (self.tensor_par_comm_type == 'rs_ag') or (\n      (self.tensor_par_comm_type == 'p2p_rs_ag') and not baseblock)\n    net_compute_time = super().compute_processing_time(stage)\n    if split_comm:\n      if self.conjugate:\n        # ReduceScatter case\n        fw_net_time = 
self.net.time('reduce_scatter',\n          self.get_comm_bytes(stage, baseblock), self.num_peers)\n        bw_net_time = self.net.time('all_gather',\n          self.get_comm_bytes(stage, baseblock), self.num_peers)\n      else:\n        #AllGather case\n        fw_net_time = self.net.time('all_gather',\n          self.get_comm_bytes(stage, baseblock), self.num_peers)\n        bw_net_time = self.net.time('reduce_scatter',\n          self.get_comm_bytes(stage, baseblock), self.num_peers)\n    else:\n      if self.conjugate:\n        fw_net_time = self.net.time('all_reduce',\n          self.get_comm_bytes(stage, baseblock), self.num_peers)\n        bw_net_time = 0\n      else:\n        fw_net_time = 0\n        bw_net_time = self.net.time('all_reduce',\n          self.get_comm_bytes(stage, baseblock), self.num_peers)\n    if stage == 'fw':\n      return fw_net_time + net_compute_time\n    elif stage == 'agrad':\n      return bw_net_time + net_compute_time\n    elif stage == 'wgrad':\n      # with AG Redo, we need recomm both on FW pass (not self.conjugate)\n      # and BW pass (self.conjugate)\n      if self.needs_recomm:\n        return fw_net_time + net_compute_time\n      else:\n        return 0\n    elif stage == 'optim':\n      return 0\n    else:\n      raise Exception(f'Bad compute stage : {stage}')\n    return 0\n\n  def get_exposed_net_time(self, stage, baseblock=True):\n    # only use after calling compute_processing_time(), otherwise it's set witth None\n    return self.compute_net_time(stage, baseblock)\n\n  def compute_processing_time(self, stage):\n    return 0\n"
  },
  {
    "path": "calculon/llm/llm.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nfrom calculon import *\nfrom .layers import *\n\n\nclass Llm:\n  \"\"\"\n  This implements the transformer with tensor, pipeline, and data parallelism.\n  Using it follows this pattern:\n  1. Initialize the model with certain model parameters\n  2. Compile it with certain optimizations and parallelization strategies\n  3. 
Run on particular hardware system\n  \"\"\"\n\n  class Application:\n    \"\"\"Specifies the application configuration.\"\"\"\n    def __init__(self, cfg):\n      self.cfg = cfg\n      self.hidden = cfg['hidden']\n      self.feedforward = cfg['feedforward']\n      self.seq_size = cfg['seq_size']\n      self.attn_heads = cfg['attn_heads']\n      self.attn_size = cfg['attn_size']\n      self.num_blocks = cfg['num_blocks']\n\n    def num_parameters(self):\n      # https://cs.stanford.edu/~matei/papers/2021/sc_megatron_lm.pdf\n      # Equation 2\n      p = 2 * self.hidden * self.feedforward                   # MLP weights\n      p += 4 * self.hidden * self.attn_heads * self.attn_size  # Attn weights\n      p += self.hidden + self.feedforward                      # biases MLP\n      p += 3 * self.attn_heads * self.attn_size + self.hidden  # biases Attn\n      p += 2 * 2 * self.hidden                                 # layer norm\n      p *= self.num_blocks                                     # per each block\n      p += (51200 + self.seq_size) * self.hidden               # embeddings\n      return p\n\n  class Execution:\n    \"\"\"Specifies the execution configuration.\"\"\"\n\n    @staticmethod\n    def fields():\n      return (\n        'num_procs', 'tensor_par', 'pipeline_par', 'data_par', 'tensor_par_net',\n        'pipeline_par_net', 'data_par_net', 'batch_size', 'microbatch_size',\n        'datatype', 'fused_activation', 'attention_type', 'activation_recompute',\n        'pipeline_interleaving', 'optimizer_sharding', 'tensor_par_comm_type',\n        'tensor_par_overlap', 'seq_par_ag_redo', 'data_par_overlap',\n        'weight_offload', 'activations_offload', 'optimizer_offload', 'training')\n\n    @staticmethod\n    def from_json(cfg):\n      assert set(cfg.keys()) == set(Llm.Execution.fields())\n      values = [cfg[field] for field in Llm.Execution.fields()]\n      return Llm.Execution(*values)\n\n    def __init__(self, num_procs, tensor_par, pipeline_par, 
data_par,\n                 tensor_par_net, pipeline_par_net, data_par_net,\n                 batch_size, microbatch_size, datatype,\n                 fused_activation, attention_type, activation_recompute,\n                 pipeline_interleaving, optimizer_sharding,\n                 tensor_par_comm_type, tensor_par_overlap,\n                 seq_par_ag_redo, data_par_overlap, weight_offload,\n                 activations_offload, optimizer_offload, training):\n      self.training = training\n      self.num_procs = num_procs\n      assert self.num_procs > 0\n      self.tensor_par = tensor_par\n      assert self.tensor_par > 0\n      self.pipeline_par = pipeline_par\n      assert self.pipeline_par > 0\n      self.data_par = data_par\n      assert self.data_par > 0\n      assert self.num_procs == self.tensor_par * self.pipeline_par * \\\n        self.data_par, 'tensor * pipeline * data parallelism != num_procs'\n      self.tensor_par_net = tensor_par_net\n      self.pipeline_par_net = pipeline_par_net\n      self.data_par_net = data_par_net\n      self.global_batch_size = batch_size\n      assert self.global_batch_size > 0\n      self.microbatch_size = microbatch_size\n      assert self.microbatch_size > 0\n      assert self.global_batch_size % self.data_par == 0\n      self._local_batch_size = self.global_batch_size // self.data_par\n      assert self._local_batch_size % self.microbatch_size == 0\n      self._num_microbatches = self._local_batch_size // self.microbatch_size\n      self.datatype = datatype\n      self.fused_activation = fused_activation\n      self.attention_type = attention_type\n      assert self.attention_type in ['multihead', 'multiquery']\n      self.activation_recompute = activation_recompute\n      assert self.activation_recompute in ['full', 'attn_only', 'none']\n      if self.activation_recompute in ['full', 'attn_only']:\n        assert self.training, \"We only perform recompute during training\"\n      self.pipeline_interleaving = 
pipeline_interleaving\n      assert self.pipeline_interleaving > 0, \\\n        f'Bad pipeline interleaving of {self.pipeline_interleaving}'\n      if self.pipeline_par == 1:\n        assert self.pipeline_interleaving == 1, \\\n        f'Bad pipeline interleaving of {self.pipeline_interleaving} with PP=1'\n      self.optimizer_sharding = optimizer_sharding\n      if self.optimizer_sharding:\n        assert self.data_par > 1, \"We perform optimizer sharding with DP > 1\"\n      self.tensor_par_comm_type = tensor_par_comm_type\n      self.in_network_reduction = False\n      assert self.tensor_par_comm_type in ['ar', 'p2p_rs_ag', 'rs_ag']\n      self.tensor_par_overlap = tensor_par_overlap\n      assert self.tensor_par_overlap in ['none', 'ring', 'pipe']\n      if self.tensor_par_overlap != 'none':\n        assert self.tensor_par > 1, \"We perform TP comm overlap with TP > 1\"\n      self._sequence_par = self.tensor_par_comm_type == 'rs_ag'\n      self.seq_par_ag_redo = seq_par_ag_redo\n      if self.seq_par_ag_redo:\n        assert self.tensor_par_comm_type == 'rs_ag', \"We only redo AG comm\"\n        assert self._sequence_par, \"We only redo AG with sequence parallelism\"\n        assert self.activation_recompute != 'full', \\\n          \"We assume no extra AG with full recompute\"\n      self._pipeline_par_rs_ag = \\\n        self.tensor_par_comm_type in ['p2p_rs_ag', 'rs_ag']\n      self.data_par_overlap = data_par_overlap\n      if self.data_par_overlap:\n        assert self.training, \"We only perform DP comm overlap during training\"\n        assert self.data_par > 1, \"We perform DP comm overlap with DP > 1\"\n      self.weight_offload = weight_offload\n      self.activations_offload = activations_offload\n      self.optimizer_offload = optimizer_offload\n      if self.optimizer_offload:\n        assert self.training, \\\n          \"We only perform optimizer offloading during training\"\n\n    def get_json(self):\n      keys = Llm.Execution.fields()\n      
values = [\n        self.num_procs, self.tensor_par, self.pipeline_par, self.data_par, self.tensor_par_net,\n        self.pipeline_par_net, self.data_par_net, self.global_batch_size, self.microbatch_size,\n        self.datatype, self.fused_activation, self.attention_type, self.activation_recompute,\n        self.pipeline_interleaving, self.optimizer_sharding, self.tensor_par_comm_type,\n        self.tensor_par_overlap, self.seq_par_ag_redo, self.data_par_overlap,\n        self.weight_offload, self.activations_offload, self.optimizer_offload, self.training\n      ]\n      assert len(keys) == len(values)\n      return dict(zip(keys, values))\n\n    def get_peers_json(self):\n      peers = {}\n      for di in range(self.data_par):\n        for pi in range(self.pipeline_par):\n          for ti in range(self.tensor_par):\n            nid = (di * self.tensor_par * self.pipeline_par +\n                   pi * self.tensor_par +\n                   ti)\n            peers[nid] = {}\n\n            # tensor parallelism peers\n            if self.tensor_par > 1:\n              peers[nid]['tensor'] = []\n              for ti2 in range(self.tensor_par):\n                pid = (di * self.tensor_par * self.pipeline_par +\n                       pi * self.tensor_par +\n                       ti2)\n                peers[nid]['tensor'].append(pid)\n\n            # pipeline parallelism peer\n            if self.pipeline_par > 1:\n              peers[nid]['pipeline'] = None\n              pi2 = (pi + 1) % self.pipeline_par\n              pid = (di * self.tensor_par * self.pipeline_par +\n                     pi2 * self.tensor_par +\n                     ti)\n              peers[nid]['pipeline'] = pid\n\n            # data parallelism peers\n            if self.data_par > 1:\n              peers[nid]['data'] = []\n              for di2 in range(self.data_par):\n                pid = (di2 * self.tensor_par * self.pipeline_par +\n                       pi * self.tensor_par +\n              
         ti)\n                peers[nid]['data'].append(pid)\n      return peers\n\n\n  # This is used for errors where the user may not be fully aware of\n  # limitations. Use it like this:\n  #   raise self.Error(f'Foo bar {num1} is not {num2}')\n  class Error(Exception):\n    pass\n\n  @staticmethod\n  def _factors(x):\n    for cand in range(1, x + 1):\n      if x % cand == 0:\n        yield cand\n\n  @staticmethod\n  def get_all_tensor_parallelisms(num_procs, hidden, attn_heads):\n    for cand in Llm._factors(num_procs):\n      if hidden % cand == 0 and attn_heads % cand == 0:\n        yield cand\n\n  @staticmethod\n  def get_all_pipeline_parallelisms(num_procs, tensor_par, num_blocks):\n    assert num_procs % tensor_par == 0\n    max_pp = min(num_procs // tensor_par, num_blocks)\n    for cand in Llm._factors(max_pp):\n      if (num_procs % (tensor_par * cand) == 0 and\n          num_blocks % cand == 0):\n        yield cand\n\n  @staticmethod\n  def get_data_parallelism(num_procs, tensor_par, pipeline_par):\n    assert num_procs % (tensor_par * pipeline_par) == 0, \\\n      f'np={num_procs} tp={tensor_par} pp={pipeline_par}'\n    return num_procs // (tensor_par * pipeline_par)\n\n  @staticmethod\n  def get_valid_pipeline_interleavings(num_blocks, pipeline_par):\n    assert num_blocks % pipeline_par == 0\n    if pipeline_par == 1:\n      yield 1\n    else:\n      max_ppint = num_blocks // pipeline_par\n      yield from Llm._factors(max_ppint)\n\n  @staticmethod\n  def get_valid_microbatch_sizes(\n      seq_size, tensor_par, data_par, global_batch_size, pipeline_par):\n    assert global_batch_size % data_par == 0\n    local_batch_size = global_batch_size // data_par\n    for cand in Llm._factors(local_batch_size):\n      batch_seq = cand * seq_size\n      if batch_seq % tensor_par == 0:\n        yield cand\n\n  @staticmethod\n  def can_redo_ag(tensor_par_comm_type, activation_recompute):\n    return tensor_par_comm_type == 'rs_ag' and activation_recompute != 
'full'\n\n  def __init__(self, app, log):\n    assert isinstance(app, self.Application)\n    self.app = app\n    self.log = log\n\n    # Set during compile\n    self.exe = None\n\n    # Set during run\n    self.sys = None\n\n    # State of calling compile() and run()\n    self._compiled = False\n    self._executed = False\n\n    # Holds the layers in a single block\n    self._llm_block = []\n\n    # A chunk is a set of blocks for microbatch before passing to the next\n    # processor in the pipeline. Each chunk is modeled as a base\n    # block that is repeated N-1 times and followed by 1 edge block.\n    # Recommunication time is the same in both base and edge blocks.\n    self._blocks_per_proc = None\n    self._bubble_reduction_blocks = None\n    self._blocks_per_chunk = None\n    self._chunks_per_proc = None\n    self._baseblocks_per_chunk = None\n    self._edgeblocks_per_chunk = None\n\n    # Misc compilation values\n    self._bytes_per_element = None\n    self._batch_seq = None\n    self._batch_seq_par = None\n    self._activation_size = None\n    self._seq_par_activation_size = None\n\n    # Assignments to specific networks\n    self._tp_net = None\n    self._pp_net = None\n    self._dp_net = None\n\n    # metrics collected after run for each microbatch\n    self._block_fw_flops = None\n    self._block_fw_flops_time = None\n    self._block_fw_mem_accessed = None\n    self._block_fw_mem_time = None\n    self._block_fw_time = None\n    self._block_re_flops = None\n    self._block_re_flops_time = None\n    self._block_re_mem_accessed = None\n    self._block_re_mem_time = None\n    self._block_re_time = None\n    self._block_agrad_flops = None\n    self._block_agrad_flops_time = None\n    self._block_agrad_mem_accessed = None\n    self._block_agrad_mem_time = None\n    self._block_agrad_time = None\n    self._block_wgrad_flops = None\n    self._block_wgrad_flops_time = None\n    self._block_wgrad_mem_accessed = None\n    self._block_wgrad_mem_time = None\n    
self._block_wgrad_time = None\n    self._block_optim_flops = None\n    self._block_optim_flops_time = None\n    self._block_optim_mem_accessed = None\n    self._block_optim_mem_time = None\n    self._block_optim_time = None\n\n    self._baseblock_fw_tp_size = None\n    self._edgeblock_fw_tp_size = None\n    self._baseblock_agrad_tp_size = None\n    self._edgeblock_agrad_tp_size = None\n    self._baseblock_recomm_size = None\n    self._edgeblock_recomm_size = None\n    self._block_fw_pp_size = None\n    self._block_bw_pp_size = None\n    self._block_dp_size = None\n    self._baseblock_fw_time_no_offload = None\n    self._edgeblock_fw_time_no_offload = None\n    self._baseblock_bw_time_no_offload = None\n    self._edgeblock_bw_time_no_offload = None\n    self._baseblock_fw_offload_overhead = None\n    self._edgeblock_fw_offload_overhead = None\n    self._baseblock_bw_offload_overhead = None\n    self._edgeblock_bw_offload_overhead = None\n    self._baseblock_fw_time = None\n    self._edgeblock_fw_time = None\n    self._baseblock_bw_time = None\n    self._edgeblock_bw_time = None\n    self._block_dp_time = None\n    self._tp_bw_overlap_req = None\n    self._dp_bw_overlap_req_chunk = None\n    self._dp_bw_overlap_req_tail = None\n\n    self._block_weight_space = None\n    self._block_act_working_space = None\n    self._block_act_storage_space = None\n    self._block_act_checkpoint_size = None\n    self._block_weight_grad_space = None\n    self._block_weight_grad_space_no_sharding = None\n    self._block_act_grad_space = None\n    self._block_optimizer_space = None\n\n    # Top level memory usage stats\n    self._weight_space = None\n    self._act_space = None\n    self._act_checkpoint_size = None\n    self._weight_grad_space = None\n    self._act_grad_space = None\n    self._optimizer_space = None\n\n    # Top level throughput stats\n    self._fw_flops = None\n    self._fw_flops_time = None\n    self._fw_mem_accessed = None\n    self._fw_mem_time = None\n    
self._fw_time = None\n    self._baseblock_fw_tp_time = None\n    self._edgeblock_fw_tp_time = None\n    self._baseblock_fw_tp_time_exposed = None\n    self._edgeblock_fw_tp_time_exposed = None\n    self._re_flops = None\n    self._re_flops_time = None\n    self._re_mem_accessed = None\n    self._re_mem_time = None\n    self._re_time = None\n    self._baseblock_recomm_time = None\n    self._edgeblock_recomm_time = None\n    self._baseblock_recomm_time_exposed = None\n    self._edgeblock_recomm_time_exposed = None\n    self._agrad_flops = None\n    self._agrad_flops_time = None\n    self._agrad_mem_accessed = None\n    self._agrad_mem_time = None\n    self._baseblock_agrad_tp_time = None\n    self._edgeblock_agrad_tp_time = None\n    self._baseblock_agrad_tp_time_exposed = None\n    self._edgeblock_agrad_tp_time_exposed = None\n    self._agrad_time = None\n    self._wgrad_flops = None\n    self._wgrad_flops_time = None\n    self._wgrad_mem_accessed = None\n    self._wgrad_mem_time = None\n    self._wgrad_time = None\n    self._optim_flops = None\n    self._optim_flops_time = None\n    self._optim_mem_accessed = None\n    self._optim_mem_time = None\n    self._optim_time = None\n\n    # Top level network stats\n    self._tp_comm_time_exposed = None\n    self._tp_comm_time_link = None\n    self._recomm_time_exposed = None\n    self._recomm_time_link = None\n    self._pp_comm_time_exposed = None\n    self._pp_comm_time_link = None\n    self._dp_comm_time_exposed = None\n    self._dp_comm_time_link = None\n    self._bubble_time = None\n\n  @staticmethod\n  def get_stats_fields():\n    return (\n      'block_fw_flops',\n      'block_fw_flops_time',\n      'block_fw_mem_accessed',\n      'block_fw_mem_time',\n      'block_fw_time',\n      'baseblock_fw_tp_time',\n      'edgeblock_fw_tp_time',\n      'baseblock_fw_tp_time_exposed',\n      'edgeblock_fw_tp_time_exposed',\n      'block_re_flops',\n      'block_re_flops_time',\n      'block_re_mem_accessed',\n      
'block_re_mem_time',\n      'block_re_time',\n      'baseblock_recomm_time',\n      'edgeblock_recomm_time',\n      'baseblock_recomm_time_exposed',\n      'edgeblock_recomm_time_exposed',\n      'block_agrad_flops',\n      'block_agrad_flops_time',\n      'block_agrad_mem_accessed',\n      'block_agrad_mem_time',\n      'block_agrad_time',\n      'baseblock_agrad_tp_time',\n      'edgeblock_agrad_tp_time',\n      'baseblock_agrad_tp_time_exposed',\n      'edgeblock_agrad_tp_time_exposed',\n      'block_wgrad_flops',\n      'block_wgrad_flops_time',\n      'block_wgrad_mem_accessed',\n      'block_wgrad_mem_time',\n      'block_wgrad_time',\n      'block_optim_flops',\n      'block_optim_flops_time',\n      'block_optim_mem_accessed',\n      'block_optim_mem_time',\n      'block_optim_time',\n\n      'baseblock_fw_tp_size',\n      'edgeblock_fw_tp_size',\n      'baseblock_bw_tp_size',\n      'edgeblock_bw_tp_size',\n      'baseblock_recomm_size',\n      'edgeblock_recomm_size',\n      'block_fw_pp_size',\n      'block_bw_pp_size',\n      'block_dp_size',\n      'tp_bw_overlap_req',\n      'dp_bw_overlap_req_chunk',\n      'dp_bw_overlap_req_tail',\n\n      'block_weight_space',\n      'block_act_working_space',\n      'block_act_storage_space',\n      'block_act_checkpoint_size',\n      'block_weight_grad_space',\n      'block_weight_grad_space_no_sharding',\n      'block_act_grad_space',\n      'block_optimizer_space',\n\n      'weight_space_with_offload',\n      'act_space_with_offload',\n      'act_checkpoint_size_with_offload',\n      'act_grad_space_with_offload',\n      'weight_grad_space_with_offload',\n      'optimizer_space_with_offload',\n\n      'weight_space',\n      'act_space',\n      'act_checkpoint_size',\n      'act_grad_space',\n      'weight_grad_space',\n      'optimizer_space',\n\n      'fw_time',\n      'bw_time',\n      'optim_step_time',\n      'recompute_time',\n      'recomm_link_time',\n      'recomm_exposed_time',\n      'bubble_time',\n 
     'tp_comm_link_time',\n      'pp_comm_link_time',\n      'dp_comm_link_time',\n      'tp_comm_exposed_time',\n      'pp_comm_exposed_time',\n      'dp_comm_exposed_time',\n      'fw_offload_exposed_time',\n      'bw_offload_exposed_time',\n      'total_time',\n      'act_offload_bw_req',\n      'weight_offload_bw_req',\n      'optim_offload_bw_req',\n      'offload_mem_bw_req',\n      'proc_mem_tier1_cap_req',\n      'proc_mem_tier2_cap_req',\n      'useful_flops',\n      'compute_efficiency',\n      'system_efficiency',\n      'total_efficiency',\n      'sample_rate')\n\n  def get_stats_values(self):\n    assert self._executed\n    return (\n      self._block_fw_flops,\n      self._block_fw_flops_time,\n      self._block_fw_mem_accessed,\n      self._block_fw_mem_time,\n      self._block_fw_time,\n      self._baseblock_fw_tp_time,\n      self._edgeblock_fw_tp_time,\n      self._baseblock_fw_tp_time_exposed,\n      self._edgeblock_fw_tp_time_exposed,\n      self._block_re_flops,\n      self._block_re_flops_time,\n      self._block_re_mem_accessed,\n      self._block_re_mem_time,\n      self._block_re_time,\n      self._baseblock_recomm_time,\n      self._edgeblock_recomm_time,\n      self._baseblock_recomm_time_exposed,\n      self._edgeblock_recomm_time_exposed,\n      self._block_agrad_flops,\n      self._block_agrad_flops_time,\n      self._block_agrad_mem_accessed,\n      self._block_agrad_mem_time,\n      self._block_agrad_time,\n      self._baseblock_agrad_tp_time,\n      self._edgeblock_agrad_tp_time,\n      self._baseblock_agrad_tp_time_exposed,\n      self._edgeblock_agrad_tp_time_exposed,\n      self._block_wgrad_flops,\n      self._block_wgrad_flops_time,\n      self._block_wgrad_mem_accessed,\n      self._block_wgrad_mem_time,\n      self._block_wgrad_time,\n      self._block_optim_flops,\n      self._block_optim_flops_time,\n      self._block_optim_mem_accessed,\n      self._block_optim_mem_time,\n      self._block_optim_time,\n\n      
self._baseblock_fw_tp_size,\n      self._edgeblock_fw_tp_size,\n      self._baseblock_agrad_tp_size,\n      self._edgeblock_agrad_tp_size,\n      self._baseblock_recomm_size,\n      self._edgeblock_recomm_size,\n      self._block_fw_pp_size,\n      self._block_bw_pp_size,\n      self._block_dp_size,\n      self._tp_bw_overlap_req,\n      self._dp_bw_overlap_req_chunk,\n      self._dp_bw_overlap_req_tail,\n\n      self._block_weight_space,\n      self._block_act_working_space,\n      self._block_act_storage_space,\n      self._block_act_checkpoint_size,\n      self._block_weight_grad_space,\n      self._block_weight_grad_space_no_sharding,\n      self._block_act_grad_space,\n      self._block_optimizer_space,\n\n      self.get_weight_space_min(),\n      self.get_act_space_min(),\n      self.get_act_checkpoint_size_min(),\n      self.get_act_grad_space_min(),\n      self.get_weight_grad_space_min(),\n      self.get_optimizer_space_min(),\n\n      self.get_weight_space(),\n      self.get_act_space(),\n      self.get_act_checkpoint_size(),\n      self.get_act_grad_space(),\n      self.get_weight_grad_space(),\n      self.get_optimizer_space(),\n\n      self.get_fw_time(),\n      self.get_bw_time(),\n      self.get_optim_step_time(),\n      self.get_recompute_time(),\n      self.get_recomm_link_time(),\n      self.get_recomm_exposed_time(),\n      self.get_bubble_time(),\n      self.get_tp_comm_link_time(),\n      self.get_pp_comm_link_time(),\n      self.get_dp_comm_link_time(),\n      self.get_tp_comm_exposed_time(),\n      self.get_pp_comm_exposed_time(),\n      self.get_dp_comm_exposed_time(),\n      self.get_fw_offload_overhead(),\n      self.get_bw_offload_overhead(),\n      self.get_total_time(),\n      self.get_act_offload_bw_req(),\n      self.get_weight_offload_bw_req(),\n      self.get_optim_offload_bw_req(),\n      self.get_offload_mem_bw_req(),\n      self.get_mem_tier1_cap_req(),\n      self.get_mem_tier2_cap_req(),\n      self.get_useful_flops(),\n      
self.get_compute_efficiency(),\n      self.get_system_efficiency(),\n      self.get_total_efficiency(),\n      self.get_sample_rate())\n\n  def get_stats_json(self, include_layers):\n    assert self._executed\n    keys = Llm.get_stats_fields()\n    values = self.get_stats_values()\n    assert len(keys) == len(values), f'{len(keys)} {len(values)}'\n    j = dict(zip(keys, values))\n    if include_layers:\n      j['layers'] = []\n      for layer in self._llm_block:\n        j['layers'].append(layer.get_stats_json())\n    return j\n\n  def _build_attn_block(self):\n    recompute_flag = self.exe.activation_recompute == \"full\"\n    recompute_attn_flag = self.exe.activation_recompute in \\\n      [\"full\", \"attn_only\"]\n    recompute_ag_flag = recompute_attn_flag or self.exe.seq_par_ag_redo\n\n    assert self.app.hidden % self.exe.tensor_par == 0, (\n      f\"We should split hidden={self.app.hidden} between\"\n      f\" {self.exe.tensor_par} TP partitions evenly\")\n    assert self.app.feedforward % self.exe.tensor_par == 0, (\n      f\"We should split feedforward={self.app.feedforward} between\"\n      f\" {self.exe.tensor_par} TP partitions evenly\")\n    assert self.app.attn_heads % self.exe.tensor_par == 0, (\n      f\"We should split {self.app.attn_heads} attn_heads between\"\n      f\" {self.exe.tensor_par} TP partitions evenly\")\n\n    self._llm_block.append(Fork(\n      \"AttnBlock_Fork\",\n      self.sys,\n      pick(self.exe._sequence_par, self._seq_par_activation_size,\n           self._activation_size),\n      2,\n      needs_recompute=recompute_flag,\n      # We account this activation when consider Residual and LayerNorm\n      activation_stored=True))\n    self._llm_block.append(LayerNorm(\n      \"AttnBlock_LayerNorm\",\n      self.sys,\n      pick(self.exe._sequence_par, self._seq_par_activation_size,\n           self._activation_size),\n      self.app.hidden,\n      needs_recompute=recompute_flag,\n      # Activation is stored in Fork instead\n     
 activation_stored=False,\n      activation_reused=True))\n    if self.exe.tensor_par_overlap == 'none':\n      self._llm_block.append(TPComm(\n        \"AttnBlock_F\",\n        self.sys,\n        self._activation_size,\n        self.exe.tensor_par_net,\n        self.exe.tensor_par,\n        # We only compute flops/mem analyzing this layers, comm analyzed later\n        # This is conservative estimate that does not consider p2p_rs_ag\n        # because we don't differentiate between edge and middle blocks here\n        tensor_par_comm_type=self.exe.tensor_par_comm_type,\n        conjugate=False,\n        in_network_reduction=self.exe.in_network_reduction,\n        needs_recomm=recompute_ag_flag))\n      self._llm_block.append(Fork(\n        \"AttnBlock_Multihead_Fork\",\n        self.sys,\n        self._activation_size,\n        3,\n        needs_recompute=recompute_ag_flag,\n        # With seq_par, we use activations from Comm layers to reflect that\n        # they're split, otherwise we keep full size activations\n        activation_stored=(not recompute_ag_flag)))\n      self._llm_block.append(Linear(\n        \"AttnBlock_Query\",\n        self.sys,\n        self._batch_seq,\n        self.app.hidden,\n        self.app.attn_heads * self.app.attn_size // self.exe.tensor_par,\n        needs_recompute=recompute_flag,\n        # Activation is stored in Fork instead,\n        activation_stored=False,\n        activation_reused=True))\n      if self.exe.attention_type == 'multihead':\n        self._llm_block.append(Linear(\n          \"AttnBlock_Key\",\n          self.sys,\n          self._batch_seq,\n          self.app.hidden,\n          self.app.attn_heads * self.app.attn_size // self.exe.tensor_par,\n          needs_recompute=recompute_flag,\n          # Activation is stored in Fork instead,\n          activation_stored=False,\n          activation_reused=True))\n        self._llm_block.append(Linear(\n          \"AttnBlock_Value\",\n          self.sys,\n          
self._batch_seq,\n          self.app.hidden,\n          self.app.attn_heads * self.app.attn_size // self.exe.tensor_par,\n          needs_recompute=recompute_flag,\n          # Activation is stored in Fork instead,\n          activation_stored=False,\n          activation_reused=True))\n      elif self.exe.attention_type == 'multiquery':\n        # Multiqueri attention uses the same K, V for all \"heads\" resulting in\n        # smaller Wk and Wv, less matmul, faster inference\n        self._llm_block.append(Linear(\n          \"AttnBlock_Key\",\n          self.sys,\n          self._batch_seq,\n          self.app.hidden,\n          self.app.attn_size,\n          needs_recompute=recompute_flag,\n          # Activation is stored in Fork instead,\n          activation_stored=False,\n          activation_reused=True))\n        self._llm_block.append(Linear(\n          \"AttnBlock_Value\",\n          self.sys,\n          self._batch_seq,\n          self.app.hidden,\n          self.app.attn_size,\n          needs_recompute=recompute_flag,\n          # Activation is stored in Fork instead,\n          activation_stored=False,\n          activation_reused=True))\n      else:\n        raise self.Error('Wrong attention type', self.exe.attention_type)\n    else:\n      if self.exe.attention_type == 'multihead':\n        self._llm_block.append(LinearOverlapped(\n          \"AttnBlock_QKV_AG\",\n          self.sys,\n          self._batch_seq,\n          self.app.hidden,\n          self.app.attn_heads * self.app.attn_size *3,          # Q, K, V\n          self.exe.tensor_par_comm_type,\n          self.exe.tensor_par,\n          self.exe.tensor_par_net,\n          self.exe.tensor_par,\n          conjugate=False,\n          tp_overlap=self.exe.tensor_par_overlap,\n          needs_recompute=recompute_flag,\n          needs_recomm=recompute_ag_flag))\n      elif self.exe.attention_type == 'multiquery':\n        self._llm_block.append(LinearOverlapped(\n          
\"AttnBlock_Query_AG\",\n          self.sys,\n          self._batch_seq,\n          self.app.hidden,\n          self.app.attn_heads * self.app.attn_size,\n          self.exe.tensor_par_comm_type,\n          self.exe.tensor_par,\n          self.exe.tensor_par_net,\n          self.exe.tensor_par,\n          conjugate=False,\n          tp_overlap=self.exe.tensor_par_overlap,\n          needs_recompute=recompute_flag,\n          needs_recomm=recompute_ag_flag))\n        self._llm_block.append(Fork(\n          \"AttnBlock_KV_Fork\",\n          self.sys,\n          self._activation_size,\n          2,\n          needs_recompute=recompute_ag_flag,\n          # With seq_par, we use activations from Comm layers to reflect that\n          # they're split, otherwise we keep full size activations\n          activation_stored=(not recompute_ag_flag)))\n        self._llm_block.append(Linear(\n          \"AttnBlock_Key\",\n          self.sys,\n          self._batch_seq,\n          self.app.hidden,\n          self.app.attn_size,\n          needs_recompute=recompute_flag,\n          # Activation is stored in Fork instead,\n          activation_stored=False,\n          activation_reused=True))\n        self._llm_block.append(Linear(\n          \"AttnBlock_Value\",\n          self.sys,\n          self._batch_seq,\n          self.app.hidden,\n          self.app.attn_size,\n          needs_recompute=recompute_flag,\n          # Activation is stored in Fork instead,\n          activation_stored=False,\n          activation_reused=True))\n      else:\n        raise self.Error('Wrong attention type', self.exe.attention_type)\n    self._llm_block.append(BatchMatMul(\n      \"AttnBlock_Multihead_Key_Query\",\n      self.sys,\n      self.exe.microbatch_size * self.app.attn_heads // self.exe.tensor_par,\n      self.app.seq_size,\n      self.app.attn_size,\n      self.app.seq_size,\n      needs_recompute=recompute_attn_flag,\n      output_stored=(not recompute_attn_flag)))\n    
self._llm_block.append(SoftMax(\n      \"AttnBlock_Multihead_SoftMax\",\n      self.sys,\n      self.app.attn_heads // self.exe.tensor_par * \\\n        self.app.seq_size**2 * self.exe.microbatch_size,\n      needs_recompute=recompute_attn_flag,\n      output_stored=(not recompute_attn_flag)))\n    self._llm_block.append(DropOut(\n      \"AttnBlock_Multihead_DropOut\",\n      self.sys,\n      self.app.attn_heads // self.exe.tensor_par * \\\n        self.app.seq_size**2 * self.exe.microbatch_size,\n      needs_recompute=recompute_attn_flag,\n      activation_stored=(not recompute_attn_flag)))\n    self._llm_block.append(BatchMatMul(\n      \"AttnBlock_Multihead_Attn\",\n      self.sys,\n      self.exe.microbatch_size * self.app.attn_heads // self.exe.tensor_par,\n      self.app.seq_size,\n      self.app.seq_size,\n      self.app.attn_heads * self.app.attn_size // self.app.attn_heads,\n      needs_recompute=recompute_flag))\n    if self.exe.tensor_par_overlap == 'none':\n      self._llm_block.append(Linear(\n        \"AttnBlock_MLP\",\n        self.sys,\n        self._batch_seq,\n        self.app.attn_heads * self.app.attn_size // self.exe.tensor_par,\n        self.app.hidden,\n        needs_recompute=recompute_flag))\n      self._llm_block.append(TPComm(\n        \"AttnBlock_G\",\n        self.sys,\n        self._activation_size,\n        self.exe.tensor_par_net,\n        self.exe.tensor_par,\n        # We only compute flops/mem analyzing this layers, comm analyzed later\n        # This is conservative estimate that does not consider p2p_rs_ag\n        # because we don't differentiate between edge and middle blocks here\n        tensor_par_comm_type=self.exe.tensor_par_comm_type,\n        conjugate=True,\n        in_network_reduction=self.exe.in_network_reduction,\n        needs_recomm=recompute_flag,\n        # We don't store input to RS/AR\n        activation_stored=False))\n    else:\n      self._llm_block.append(LinearOverlapped(\n        \"AttnBlock_MLP_RS\",\n 
       self.sys,\n        self._batch_seq,\n        self.app.attn_heads * self.app.attn_size,\n        self.app.hidden,\n        self.exe.tensor_par_comm_type,\n        self.exe.tensor_par,\n        self.exe.tensor_par_net,\n        self.exe.tensor_par,\n        conjugate=True,\n        tp_overlap=self.exe.tensor_par_overlap,\n        needs_recompute=recompute_flag,\n        needs_recomm=recompute_flag))\n    self._llm_block.append(DropOut(\n      \"AttnBlock_DropOut\",\n      self.sys,\n      pick(self.exe._sequence_par, self._seq_par_activation_size,\n           self._activation_size),\n      needs_recompute=recompute_flag))\n    self._llm_block.append(ElementWise(\n      \"AttnBlock_Residual\",\n      self.sys,\n      pick(self.exe._sequence_par, self._seq_par_activation_size,\n           self._activation_size),\n      pick(self.exe._sequence_par, self._seq_par_activation_size,\n           self._activation_size),\n      needs_recompute=recompute_flag,\n      # Activation is stored in Fork instead\n      activation_stored=False,\n      activation_reused=True))\n\n  def _build_mlp_block(self):\n    recompute_flag = self.exe.activation_recompute == \"full\"\n    recompute_ag_flag = recompute_flag or self.exe.seq_par_ag_redo\n\n    self._llm_block.append(Fork(\n      \"MlpBlock_Fork\",\n      self.sys,\n      pick(self.exe._sequence_par, self._seq_par_activation_size,\n           self._activation_size),\n      2,\n      needs_recompute=recompute_flag,\n      # We account this activation when consider Residual and LayerNorm\n      activation_stored=True))\n    self._llm_block.append(LayerNorm(\n      \"MlpBlock_LayerNorm\",\n      self.sys,\n      pick(self.exe._sequence_par, self._seq_par_activation_size,\n           self._activation_size),\n      self.app.hidden,\n      needs_recompute=recompute_flag,\n      # Activation is stored in Fork instead\n      activation_stored=False,\n      activation_reused=True))\n    if self.exe.tensor_par_overlap == 'none':\n      
self._llm_block.append(TPComm(\n        \"MlpBlock_F\",\n        self.sys,\n        # We only do compute/mem analyzing this layers, comm analyzed later\n        # We keep extra mem buffer for comm, consider full tensor mem access\n        # to be consistent with how much data comm moves/touches\n        # This is conservative estimate that does not consider p2p_rs_ag\n        # because we don't differentiate between edge and middle blocks here\n        self._activation_size,\n        self.exe.tensor_par_net,\n        self.exe.tensor_par,\n        tensor_par_comm_type=self.exe.tensor_par_comm_type,\n        conjugate=False,\n        in_network_reduction=self.exe.in_network_reduction,\n        needs_recomm=recompute_ag_flag))\n      self._llm_block.append(Linear(\n        \"MlpBlock_Mlp1\",\n        self.sys,\n        self._batch_seq,\n        self.app.hidden,\n        self.app.feedforward // self.exe.tensor_par,\n        needs_recompute=recompute_flag,\n        # With seq_par, we use activations from Comm layers to reflect that\n        # they're split, otherwise we keep full size activations\n        activation_stored=(not recompute_ag_flag)))\n    else:\n      self._llm_block.append(LinearOverlapped(\n        \"MlpBlock_Mlp1_AG\",\n        self.sys,\n        self._batch_seq,\n        self.app.hidden,\n        self.app.feedforward,\n        self.exe.tensor_par_comm_type,\n        self.exe.tensor_par,\n        self.exe.tensor_par_net,\n        self.exe.tensor_par,\n        conjugate=False,\n        tp_overlap=self.exe.tensor_par_overlap,\n        needs_recompute=recompute_flag,\n        needs_recomm=recompute_ag_flag))\n    self._llm_block.append(GeLU(\n      \"MlpBlock_GeLU\",\n      self.sys,\n      self.app.feedforward * self._batch_seq // self.exe.tensor_par,\n      needs_recompute=recompute_flag,\n      fused=self.exe.fused_activation))\n    if self.exe.tensor_par_overlap == 'none':\n      self._llm_block.append(Linear(\n        \"MlpBlock_Mlp2\",\n        
self.sys,\n        self._batch_seq,\n        self.app.feedforward // self.exe.tensor_par,\n        self.app.hidden,\n        needs_recompute=recompute_flag))\n      self._llm_block.append(TPComm(\n        \"MlpBlock_G\",\n        self.sys,\n        self._activation_size,\n        self.exe.tensor_par_net,\n        self.exe.tensor_par,\n        # We only compute flops/mem analyzing this layers, comm analyzed later\n        # This is conservative estimate that does not consider p2p_rs_ag\n        # because we don't differentiate between edge and middle blocks here\n        tensor_par_comm_type=self.exe.tensor_par_comm_type,\n        conjugate=True,\n        in_network_reduction=self.exe.in_network_reduction,\n        needs_recomm=recompute_flag,\n        # We don't store input to RS/AR\n        activation_stored=False))\n    else:\n      self._llm_block.append(LinearOverlapped(\n        \"MlpBlock_Mlp2_RS\",\n        self.sys,\n        self._batch_seq,\n        self.app.feedforward,\n        self.app.hidden,\n        self.exe.tensor_par_comm_type,\n        self.exe.tensor_par,\n        self.exe.tensor_par_net,\n        self.exe.tensor_par,\n        conjugate=True,\n        tp_overlap=self.exe.tensor_par_overlap,\n        needs_recompute=recompute_flag,\n        needs_recomm=recompute_flag))\n    self._llm_block.append(DropOut(\n      \"MlpBlock_DropOut\",\n      self.sys,\n      pick(self.exe._sequence_par, self._seq_par_activation_size,\n           self._activation_size),\n      needs_recompute=recompute_flag))\n    self._llm_block.append(ElementWise(\n      \"MlpBlock_Residual\",\n      self.sys,\n      pick(self.exe._sequence_par, self._seq_par_activation_size,\n           self._activation_size),\n      pick(self.exe._sequence_par, self._seq_par_activation_size,\n           self._activation_size),\n      needs_recompute=recompute_flag,\n      # Activation is stored in Fork instead\n      activation_stored=False,\n      activation_reused=True))\n\n  def 
compile(self, sys, exe):\n    assert not self._compiled\n    assert isinstance(exe, self.Execution)\n    self.exe = exe\n    assert isinstance(sys, System)\n    self.sys = sys\n    self._check_network_assignments()\n\n    self.sys.set_datatype(self.exe.datatype)\n\n    # If we have number of blocks not divisible by PP, we can allocate the\n    # reminder of the blocks on the first num_block % PP Procs and block\n    # \"bubbles\" on the last PP - (num_block % PP) Procs. To reflect that,\n    # we round up blocks_per_proc. We report time for Proc0. In that case\n    # its bubble time is `PP - (num_block % PP)` blocks shorter\n    self._blocks_per_proc = self.app.num_blocks // self.exe.pipeline_par\n    if self.app.num_blocks % self.exe.pipeline_par != 0:\n      self._blocks_per_proc += 1\n      self._bubble_reduction_blocks = self.exe.pipeline_par - (\n        self.app.num_blocks % self.exe.pipeline_par)\n    else:\n      self._bubble_reduction_blocks = 0\n    if self.exe.pipeline_interleaving > self._blocks_per_proc:\n      raise self.Error('Pipeline interleaving must be less than or equal to '\n                       'the number of blocks per processor')\n    if self._blocks_per_proc % self.exe.pipeline_interleaving != 0:\n      raise self.Error('Pipeline interleaving must be a factor value of the '\n                       'number of blocks per processor')\n    self._bytes_per_element = System.TypeSizes[self.exe.datatype]\n\n    # Checks that enough blocks per processor exist if offloading is being\n    # performed\n    if (self.exe.weight_offload or self.exe.activations_offload or\n        self.exe.optimizer_offload) and (self._blocks_per_proc <= 2):\n      raise self.Error('Offloading requires each processor to handle at least'\n                       ' 3 blocks')\n\n    # A chunk is a set of blocks for microbatch before passing to the next\n    # processor in the pipeline. 
Each chunk is modeled as a base\n    # block that is repeated N-1 times and followed by 1 edge block.\n    # Recommunication time is the same in both base and edge blocks.\n    self._blocks_per_chunk = \\\n      self._blocks_per_proc // self.exe.pipeline_interleaving\n    assert self._blocks_per_proc % self._blocks_per_chunk == 0, \\\n      \"PP interleaving should evenly devide {self._blocks_per_proc} blocks\"\n    self._chunks_per_proc = self._blocks_per_proc // self._blocks_per_chunk\n    assert self._chunks_per_proc == self.exe.pipeline_interleaving, \\\n      \"Number of chunks should be equal to pipeline_interleaving\"\n    self._baseblocks_per_chunk = self._blocks_per_chunk - 1\n    self._edgeblocks_per_chunk = 1\n\n    # Build model during the compilation step\n    self._batch_seq = self.exe.microbatch_size * self.app.seq_size\n    self._activation_size = self._batch_seq * self.app.hidden\n    self._batch_seq_par = self._batch_seq // self.exe.tensor_par\n    if self.exe._sequence_par or self.exe._pipeline_par_rs_ag:\n      assert self._batch_seq % self.exe.tensor_par == 0, (\n        f\"We should split batch_seq={self._batch_seq} between\"\n        f\" {self.exe.tensor_par} TP partitions evenly\")\n    self._seq_par_activation_size = self._batch_seq_par * self.app.hidden\n    self._build_attn_block()\n    self._build_mlp_block()\n    for layer in self._llm_block:\n      layer.set_bytes_per_element(self._bytes_per_element)\n      if self.exe.optimizer_sharding:\n        layer.shard_optimizer(self.exe.data_par)\n    self._compiled = True\n\n  def _check_network_assignments(self):\n    used = [False] * self.sys.num_networks\n    size = [1] * self.sys.num_networks\n\n    assert self.exe.tensor_par_net < self.sys.num_networks\n    assert self.exe.pipeline_par_net < self.sys.num_networks\n    assert self.exe.data_par_net < self.sys.num_networks\n\n    if self.exe.tensor_par > 1:\n      used[self.exe.tensor_par_net] = True\n      size[self.exe.tensor_par_net] *= 
self.exe.tensor_par\n    self._tp_net = self.sys.get_network(self.exe.tensor_par_net)\n\n    if self.exe.pipeline_par > 1:\n      used[self.exe.pipeline_par_net] = True\n      size[self.exe.pipeline_par_net] *= self.exe.pipeline_par\n    self._pp_net = self.sys.get_network(self.exe.pipeline_par_net)\n\n    if self.exe.data_par > 1:\n      used[self.exe.data_par_net] = True\n      size[self.exe.data_par_net] *= self.exe.data_par\n    self._dp_net = self.sys.get_network(self.exe.data_par_net)\n\n    for tier_used, tier_size, tier in zip(\n        used, size, range(self.sys.num_networks)):\n      if tier_used:\n        if tier_size > self.sys.get_network(tier).size:\n          raise self.Error(f'Network tier{tier} isn\\'t big enough')\n        if (self.sys.get_network(tier).must_be_filled and\n            self.sys.get_network(tier).size % tier_size != 0):\n          raise self.Error(f'Network tier{tier} isn\\'t fully used')\n\n  def _compute_block_stats(self):\n    \"\"\"\n    This function computes the statistics for one microbatch on a single block.\n    This only computes flops, flop time, and communication sizes. 
Since\n    tensor and pipeline parallelism cause different communication operations to\n    occur at the full batch level, the communication times are computed later.\n    \"\"\"\n    if self.exe.training and self.exe.activation_recompute == \"full\":\n      self._block_act_checkpoint_size = \\\n        self._activation_size * self._bytes_per_element\n    else:\n      self._block_act_checkpoint_size = 0\n\n    # Initializes values to zero for accumulation in layer loop\n    self._block_fw_flops = 0\n    self._block_fw_flops_time = 0\n    self._block_fw_mem_accessed = 0\n    self._block_fw_mem_time = 0\n    self._block_fw_time = 0\n    self._baseblock_fw_tp_size = 0\n    self._edgeblock_fw_tp_size = 0\n    self._baseblock_fw_tp_time = 0\n    self._edgeblock_fw_tp_time = 0\n    self._baseblock_fw_tp_time_exposed = 0\n    self._edgeblock_fw_tp_time_exposed = 0\n    self._block_weight_space = 0\n    self._block_act_working_space = 0\n    self._block_act_storage_space = 0\n    # We use this block for self.exe.training, but initialize anyway\n    self._block_re_flops = 0\n    self._block_re_flops_time = 0\n    self._block_re_mem_accessed = 0\n    self._block_re_mem_time = 0\n    self._block_re_time = 0\n    self._baseblock_recomm_size = 0\n    self._edgeblock_recomm_size = 0\n    self._baseblock_recomm_time = 0\n    self._edgeblock_recomm_time = 0\n    self._baseblock_recomm_time_exposed = 0\n    self._edgeblock_recomm_time_exposed = 0\n    self._block_agrad_flops = 0\n    self._block_agrad_flops_time = 0\n    self._block_agrad_mem_accessed = 0\n    self._block_agrad_mem_time = 0\n    self._block_agrad_time = 0\n    self._baseblock_agrad_tp_size = 0\n    self._edgeblock_agrad_tp_size = 0\n    self._baseblock_agrad_tp_time = 0\n    self._edgeblock_agrad_tp_time = 0\n    self._baseblock_agrad_tp_time_exposed = 0\n    self._edgeblock_agrad_tp_time_exposed = 0\n    self._block_wgrad_flops = 0\n    self._block_wgrad_flops_time = 0\n    self._block_wgrad_mem_accessed = 0\n    
self._block_wgrad_mem_time = 0\n    self._block_wgrad_time = 0\n    self._block_optim_flops = 0\n    self._block_optim_flops_time = 0\n    self._block_optim_mem_accessed = 0\n    self._block_optim_mem_time = 0\n    self._block_optim_time = 0\n    self._block_weight_grad_space = 0\n    self._block_weight_grad_space_no_sharding = 0\n    self._block_act_grad_space = 0\n    self._block_optimizer_space = 0\n    self._tp_bw_overlap_req = 0\n\n    prev_layer_recompute = False\n    for layer in self._llm_block:\n      # Add flops/bytes/times per layer\n      self._block_fw_flops += layer.get_fw_flops()\n      self._block_fw_flops_time += layer.compute_flops_time(\"fw\")\n      self._block_fw_mem_accessed += layer.get_fw_mem_accessed()\n      self._block_fw_mem_time += layer.compute_mem_time(\"fw\")\n      self._block_fw_time += layer.compute_processing_time(\"fw\")\n      self._baseblock_fw_tp_size += layer.get_comm_bytes(\"fw\",\n        baseblock=True)\n      self._edgeblock_fw_tp_size += layer.get_comm_bytes(\"fw\",\n        baseblock=False)\n      self._baseblock_fw_tp_time += layer.compute_net_time(\"fw\",\n        baseblock=True)\n      self._edgeblock_fw_tp_time += layer.compute_net_time(\"fw\",\n        baseblock=False)\n      self._baseblock_fw_tp_time_exposed += layer.get_exposed_net_time(\"fw\",\n        baseblock=True)\n      self._edgeblock_fw_tp_time_exposed += layer.get_exposed_net_time(\"fw\",\n        baseblock=False)\n      self._tp_bw_overlap_req = max(self._tp_bw_overlap_req,\n        layer.get_required_bandwidth(\"fw\", baseblock=True))\n      self._tp_bw_overlap_req = max(self._tp_bw_overlap_req,\n        layer.get_required_bandwidth(\"fw\", baseblock=False))\n      if self.exe.training:\n        if layer.get_recompute_flag():\n          self._block_re_flops += self._block_fw_flops\n          self._block_re_flops_time += self._block_fw_flops_time\n          self._block_re_mem_accessed += self._block_fw_mem_accessed\n          self._block_re_mem_time 
+= self._block_fw_mem_time\n          self._block_re_time += layer.compute_processing_time(\"fw\")\n        if layer.get_recomm_flag():\n          self._baseblock_recomm_size += layer.get_comm_bytes(\"wgrad\",\n            baseblock=True)\n          self._edgeblock_recomm_size += layer.get_comm_bytes(\"wgrad\",\n            baseblock=False)\n          self._baseblock_recomm_time += layer.compute_net_time(\"wgrad\",\n            baseblock=True)\n          self._edgeblock_recomm_time += layer.compute_net_time(\"wgrad\",\n            baseblock=False)\n          self._baseblock_recomm_time_exposed += layer.get_exposed_net_time(\n            \"wgrad\", baseblock=True)\n          self._edgeblock_recomm_time_exposed += layer.get_exposed_net_time(\n            \"wgrad\", baseblock=False)\n        self._block_agrad_flops += layer.get_agrad_flops()\n        self._block_agrad_flops_time += layer.compute_flops_time(\"agrad\")\n        self._block_agrad_mem_accessed += layer.get_agrad_mem_accessed()\n        self._block_agrad_mem_time += layer.compute_mem_time(\"agrad\")\n        self._block_agrad_time += layer.compute_processing_time(\"agrad\")\n        self._baseblock_agrad_tp_size += layer.get_comm_bytes(\"agrad\",\n          baseblock=True)\n        self._edgeblock_agrad_tp_size += layer.get_comm_bytes(\"agrad\",\n          baseblock=False)\n        self._baseblock_agrad_tp_time += layer.compute_net_time(\"agrad\",\n          baseblock=True)\n        self._edgeblock_agrad_tp_time += layer.compute_net_time(\"agrad\",\n          baseblock=False)\n        self._baseblock_agrad_tp_time_exposed += layer.get_exposed_net_time(\n          \"agrad\", baseblock=True)\n        self._edgeblock_agrad_tp_time_exposed += layer.get_exposed_net_time(\n          \"agrad\", baseblock=False)\n        self._tp_bw_overlap_req = max(self._tp_bw_overlap_req,\n          layer.get_required_bandwidth(\"agrad\", baseblock=True))\n        self._tp_bw_overlap_req = max(self._tp_bw_overlap_req,\n         
 layer.get_required_bandwidth(\"agrad\", baseblock=False))\n        self._block_wgrad_flops += layer.get_wgrad_flops()\n        self._block_wgrad_flops_time += layer.compute_flops_time(\"wgrad\")\n        self._block_wgrad_mem_accessed += layer.get_wgrad_mem_accessed()\n        self._block_wgrad_mem_time += layer.compute_mem_time(\"wgrad\")\n        self._block_wgrad_time += layer.compute_processing_time(\"wgrad\")\n        self._block_optim_flops += layer.get_optim_step_flops()\n        self._block_optim_flops_time += layer.compute_flops_time(\"optim\")\n        self._block_optim_mem_accessed += layer.get_optim_step_mem_accessed()\n        self._block_optim_mem_time += layer.compute_mem_time(\"optim\")\n        self._block_optim_time += layer.compute_processing_time(\"optim\")\n\n      # Accumulate space requirements per block\n      self._block_weight_space += layer.get_weight()\n      if not layer.reuses_activation():\n        self._block_act_working_space += layer.get_activation()\n      self._block_act_storage_space += layer.get_activation()\n      if self.exe.training:\n        if not layer.stores_output():\n          self._block_act_storage_space -= layer.get_output()\n        if not layer.stores_activation():\n          self._block_act_storage_space -= layer.get_activation()\n        self._block_weight_grad_space += layer.get_weight_grad()\n        self._block_weight_grad_space_no_sharding += layer.get_weight_grad(\n          sharded=False)\n        self._block_act_grad_space += layer.get_activation_grad()\n        self._block_optimizer_space += layer.get_optimizer()\n\n      self.log.debug(\"%s %s %s\", layer.name, 'Recompute flag:',\n                     str(layer.get_recompute_flag()))\n      self.log.debug(\"%s %s %s\", layer.name, 'Recomm flag:',\n                     str(layer.get_recomm_flag()))\n      self.log.debug(\"%s %s %s\", layer.name, 'Stores activation:',\n                     str(layer.stores_activation()))\n      self.log.debug(\"%s %s 
%s\", layer.name, 'Reuses activation:',\n                     str(layer.reuses_activation()))\n      self.log.debug(\"%s %s %s\", layer.name, 'Stores output:',\n                     str(layer.stores_output()))\n      self.log.debug(\"%s %s %s\", layer.name, 'FW flops:',\n                     human_format(layer.get_fw_flops(), 'flops'))\n      self.log.debug(\"%s %s %s\", layer.name, 'FW num inputs:',\n                     human_format(layer.inputs_size, 'base2'))\n      self.log.debug(\"%s %s %s\", layer.name, 'FW num output:',\n                     human_format(layer.output_size, 'base2'))\n      self.log.debug(\"%s %s %s\", layer.name, 'FW num weights:',\n                     human_format(layer.weight_space, 'base2'))\n      self.log.debug(\"%s %s %s\", layer.name, 'FW mem:',\n                     human_format(layer.get_fw_mem_accessed(), 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'FW baseblock comm tile size:',\n                     human_format(layer.get_comm_tile(\"fw\", baseblock=True),\n                     'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'FW edgeblock comm tile size:',\n                     human_format(layer.get_comm_tile(\"fw\", baseblock=False),\n                     'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'FW baseblock comm size:',\n                     human_format(layer.get_comm_bytes(\"fw\", baseblock=True),\n                     'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'FW edgeblock comm size:',\n                     human_format(layer.get_comm_bytes(\"fw\", baseblock=False),\n                     'bytes'))\n      self.log.debug(\"%s %s %.3e\", layer.name, 'FW net link time:',\n                     layer.compute_net_time(\"fw\"))\n      self.log.debug(\"%s %s %.3e\", layer.name, 'FW net exposed time:',\n                     layer.get_exposed_net_time(\"fw\"))\n      self.log.debug(\"%s %s %.3e\", layer.name, 'FW time:',\n                     
layer.compute_processing_time(\"fw\"))\n      self.log.debug(\"%s %s %s\", layer.name, 'BW flops:',\n                     human_format(\n                      layer.get_agrad_flops() + layer.get_wgrad_flops(),\n                      'flops'))\n      self.log.debug(\"%s %s %s\", layer.name, 'BW num Wgrads:',\n                     human_format(layer.weight_grads, 'base2'))\n      self.log.debug(\"%s %s %s\", layer.name, 'BW num Agrads:',\n                     human_format(layer.activation_grads, 'base2'))\n      self.log.debug(\"%s %s %s\", layer.name, 'BW num Igrads:',\n                     human_format(layer.inputs_size, 'base2'))\n      self.log.debug(\"%s %s %s\", layer.name, 'BW mem:',\n                     human_format(\n                      layer.get_agrad_mem_accessed() +\n                      layer.get_wgrad_mem_accessed(), 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'BW baseblock comm tile size:',\n                     human_format(layer.get_comm_tile(\"agrad\", baseblock=True),\n                     'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'BW edgeblock comm tile size:',\n                     human_format(layer.get_comm_tile(\"agrad\", baseblock=False),\n                     'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'BW baseblock comm size:',\n                     human_format(layer.get_comm_bytes(\"agrad\", baseblock=True),\n                     'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'BW edgeblock comm size:',\n                     human_format(layer.get_comm_bytes(\"agrad\", baseblock=False),\n                     'bytes'))\n      self.log.debug(\"%s %s %.3e\", layer.name, 'BW net link time:',\n                     layer.compute_net_time(\"agrad\"))\n      self.log.debug(\"%s %s %.3e\", layer.name, 'BW net exposed time:',\n                     layer.get_exposed_net_time(\"agrad\"))\n      self.log.debug(\"%s %s %.3e\", layer.name, 'BW time:',\n                     
layer.compute_processing_time(\"agrad\") +\n                     layer.compute_processing_time(\"wgrad\"))\n      self.log.debug(\"%s %s %s\", layer.name, 'Recomm baseblock comm tile size:',\n                     human_format(layer.get_comm_tile(\"wgrad\", baseblock=True),\n                     'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Recomm edgeblock comm tile size:',\n                     human_format(layer.get_comm_tile(\"wgrad\", baseblock=False),\n                     'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Recomm baseblock comm size:',\n                     human_format(layer.get_comm_bytes(\"wgrad\", baseblock=True),\n                     'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Recomm edgeblock comm size:',\n                     human_format(layer.get_comm_bytes(\"wgrad\", baseblock=False),\n                     'bytes'))\n      self.log.debug(\"%s %s %.3e\", layer.name, 'Recomm net link time:',\n                     layer.compute_net_time(\"wgrad\"))\n      self.log.debug(\"%s %s %.3e\", layer.name, 'Recomm net exposed time:',\n                     layer.get_exposed_net_time(\"wgrad\"))\n      self.log.debug(\"%s %s %s\", layer.name, 'Optim flops:',\n                     human_format(layer.get_optim_step_flops(), 'flops'))\n      self.log.debug(\"%s %s %s\", layer.name, 'BW Optimizer size:',\n                     human_format(layer.get_optimizer(), 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Optim mem:',\n                     human_format(layer.get_optim_step_mem_accessed(), 'bytes'))\n      self.log.debug(\"%s %s %.3e\", layer.name, 'Optim time:',\n                     layer.compute_processing_time(\"optim\"))\n      self.log.debug(\"%s %s %.3e\", layer.name, 'Recompute:',\n                     layer.get_recompute_flag())\n      self.log.debug(\"%s %s %s\", layer.name, 'Recompute mem saving:',\n                     human_format(layer.stores_output() * \\\n                       
layer.get_output(), 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Weight:',\n                     human_format(layer.get_weight(), 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Act:',\n                     human_format(layer.get_activation(), 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Weight grad:',\n                     human_format(layer.get_weight_grad(), 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Act grad:',\n                     human_format(layer.get_activation_grad(), 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Optim:',\n                     human_format(layer.get_optimizer(), 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Incremental Weight:',\n                     human_format(self._block_weight_space, 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Incremental Act Working space:',\n                     human_format(self._block_act_working_space, 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Incremental Act Storage space:',\n                     human_format(self._block_act_storage_space, 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Incremental Weight grad:',\n                     human_format(self._block_weight_grad_space, 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Incremental Act grad:',\n                     human_format(self._block_act_grad_space, 'bytes'))\n      self.log.debug(\"%s %s %s\", layer.name, 'Incremental Optim:',\n                     human_format(self._block_optimizer_space, 'bytes'))\n      prev_layer_recompute = layer.get_recompute_flag()\n    if self.exe.activation_recompute == 'full':\n      self._block_act_storage_space = 0\n\n    # Sets the PP communication operation size\n    if self.exe.pipeline_par > 1:\n      if self.exe._pipeline_par_rs_ag:\n        self._block_fw_pp_size = self._seq_par_activation_size * \\\n          self._bytes_per_element\n      else:\n        self._block_fw_pp_size = 
self._activation_size * \\\n          self._bytes_per_element\n    else:\n      self._block_fw_pp_size = 0\n\n    # When training, BW sizes for TP and PP are same as FW\n    if self.exe.training:\n      self._block_bw_pp_size = self._block_fw_pp_size\n    else:\n      self._block_bw_pp_size = 0\n\n    self.log.debug(\"%s %s\", 'TP comm FW baseblock size:',\n                   human_format(self._baseblock_fw_tp_size, 'bytes'))\n    self.log.debug(\"%s %s\", 'TP comm FW edgeblock size:',\n                   human_format(self._edgeblock_fw_tp_size, 'bytes'))\n    self.log.debug(\"%s %s\", 'PP comm FW size:',\n                   human_format(self._block_fw_pp_size, 'bytes'))\n    self.log.debug(\"%s %s\", 'TP comm BW baseblock size:',\n                   human_format(self._baseblock_agrad_tp_size, 'bytes'))\n    self.log.debug(\"%s %s\", 'TP comm BW edgeblock size:',\n                   human_format(self._edgeblock_agrad_tp_size, 'bytes'))\n    self.log.debug(\"%s %s\", 'PP comm BW size:',\n                   human_format(self._block_bw_pp_size, 'bytes'))\n    self.log.debug(\"%s %s\", 'TP recomm baseblock size:',\n                   human_format(self._baseblock_recomm_size, 'bytes'))\n    self.log.debug(\"%s %s\", 'TP recomm edgeblock size:',\n                   human_format(self._edgeblock_recomm_size, 'bytes'))\n    self.log.debug(\"%s %s\", 'TP comm required bandwidth for tiled overlap:',\n                   human_format(self._tp_bw_overlap_req, 'bandwidth'))\n\n  def _compute_batch_stats(self):\n    \"\"\"\n    This function computes the statistics for a full batch. 
This uses the per\n    microbatch per block statistics from the prior function (see above).\n    \"\"\"\n    # Total stats for compute and memory\n    mult = self._blocks_per_proc * self.exe._num_microbatches\n    self._fw_flops = mult * self._block_fw_flops\n    self._fw_flops_time = mult * self._block_fw_flops_time\n    self._fw_mem_accessed = mult * self._block_fw_mem_accessed\n    self._fw_mem_time = mult * self._block_fw_mem_time\n    self._fw_time = mult * self._block_fw_time\n    self._re_flops = mult * self._block_re_flops\n    self._re_flops_time = mult * self._block_re_flops_time\n    self._re_mem_accessed = mult * self._block_re_mem_accessed\n    self._re_mem_time = mult * self._block_re_mem_time\n    self._re_time = mult * self._block_re_time\n    self._agrad_flops = mult * self._block_agrad_flops\n    self._agrad_flops_time = mult * self._block_agrad_flops_time\n    self._agrad_mem_accessed = mult * self._block_agrad_mem_accessed\n    self._agrad_mem_time = mult * self._block_agrad_mem_time\n    self._agrad_time = mult * self._block_agrad_time\n    self._wgrad_flops = mult * self._block_wgrad_flops\n    self._wgrad_flops_time = mult * self._block_wgrad_flops_time\n    self._wgrad_mem_accessed = mult * self._block_wgrad_mem_accessed\n    self._wgrad_mem_time = mult * self._block_wgrad_mem_time\n    self._wgrad_time = mult * self._block_wgrad_time\n    self._optim_flops = self._blocks_per_proc * self._block_optim_flops\n    self._optim_flops_time = self._blocks_per_proc * self._block_optim_flops_time\n    self._optim_mem_accessed = self._blocks_per_proc * self._block_optim_mem_accessed\n    self._optim_mem_time = self._blocks_per_proc * self._block_optim_mem_time\n    self._optim_time = self._blocks_per_proc * self._block_optim_time\n\n    # These TP numbers are for total times for all blocks in all chunks\n    tp_fw_comm_time = self.exe._num_microbatches * self._chunks_per_proc * (\n      (self._baseblocks_per_chunk * self._baseblock_fw_tp_time) +\n     
 (self._edgeblocks_per_chunk * self._edgeblock_fw_tp_time))\n    tp_fw_comm_time_exposed = \\\n      self.exe._num_microbatches * self._chunks_per_proc * (\n        (self._baseblocks_per_chunk * self._baseblock_fw_tp_time_exposed) +\n        (self._edgeblocks_per_chunk * self._edgeblock_fw_tp_time_exposed))\n    tp_bw_comm_time = self.exe._num_microbatches * self._chunks_per_proc * (\n      self._baseblocks_per_chunk * self._baseblock_agrad_tp_time +\n      self._edgeblocks_per_chunk * self._edgeblock_agrad_tp_time)\n    tp_bw_comm_time_exposed = \\\n      self.exe._num_microbatches * self._chunks_per_proc * (\n        self._baseblocks_per_chunk * self._baseblock_agrad_tp_time_exposed +\n        self._edgeblocks_per_chunk * self._edgeblock_agrad_tp_time_exposed)\n    tp_recomm_time = self.exe._num_microbatches * self._chunks_per_proc * (\n      (self._baseblocks_per_chunk * self._baseblock_recomm_time) +\n      (self._edgeblocks_per_chunk * self._edgeblock_recomm_time))\n    tp_recomm_time_exposed = \\\n      self.exe._num_microbatches * self._chunks_per_proc * (\n        (self._baseblocks_per_chunk * self._baseblock_recomm_time_exposed) +\n        (self._edgeblocks_per_chunk * self._edgeblock_recomm_time_exposed))\n\n    # Per chunk PP comm time\n    chunk_fw_pp_time = self._pp_net.time('p2p', self._block_fw_pp_size, 2)\n    chunk_bw_pp_time = self._pp_net.time('p2p', self._block_bw_pp_size, 2)\n\n    # Determines number of times PP causes pipeline p2p communications per\n    # chunk during the forward and backward pass (equal to chunks per proc)\n    if self.exe.pipeline_par > 1:\n      num_fw_pp_p2ps = self._chunks_per_proc\n      if self.exe.training:\n        num_bw_pp_p2ps = self._chunks_per_proc\n      else:\n        num_bw_pp_p2ps = 0\n    else:\n      num_fw_pp_p2ps = 0\n      num_bw_pp_p2ps = 0\n\n    # These PP numbers are for total times for all blocks and all microbatches\n    pp_fw_comm_time = self.exe._num_microbatches * num_fw_pp_p2ps * \\\n      
chunk_fw_pp_time\n    pp_bw_comm_time = self.exe._num_microbatches * num_bw_pp_p2ps * \\\n      chunk_bw_pp_time\n\n    # Aggregrates metrics\n    self._tp_comm_time_link = tp_fw_comm_time + tp_bw_comm_time\n    self._tp_comm_time_exposed = (tp_fw_comm_time_exposed +\n      tp_bw_comm_time_exposed)\n    self._recomm_time_link = tp_recomm_time\n    self._recomm_time_exposed = tp_recomm_time_exposed\n    self._pp_comm_time_link = pp_fw_comm_time + pp_bw_comm_time\n    self._pp_comm_time_exposed = self._pp_comm_time_link\n\n    self.log.debug(\"%s %s\", 'TP comm baseblock FW time:',\n      self._baseblock_fw_tp_time)\n    self.log.debug(\"%s %s\", 'TP comm edgeblock FW time:',\n      self._edgeblock_fw_tp_time)\n    self.log.debug(\"%s %s\", 'TP comm FW time:', tp_fw_comm_time)\n    self.log.debug(\"%s %s\", 'TP comm baseblock FW exposed time:',\n      self._baseblock_fw_tp_time_exposed)\n    self.log.debug(\"%s %s\", 'TP comm edgeblock FW exposed time:',\n      self._edgeblock_fw_tp_time_exposed)\n    self.log.debug(\"%s %s\", 'TP comm FW exposed time:', tp_fw_comm_time_exposed)\n    self.log.debug(\"%s %s\", 'TP comm baseblock BW time:',\n      self._baseblock_agrad_tp_time)\n    self.log.debug(\"%s %s\", 'TP comm edgeblock BW time:',\n      self._edgeblock_agrad_tp_time)\n    self.log.debug(\"%s %s\", 'TP comm BW time:', tp_bw_comm_time)\n    self.log.debug(\"%s %s\", 'TP comm baseblock BW exposed time:',\n      self._baseblock_agrad_tp_time_exposed)\n    self.log.debug(\"%s %s\", 'TP comm edgeblock BW exposed time:',\n      self._edgeblock_agrad_tp_time_exposed)\n    self.log.debug(\"%s %s\", 'TP comm BW exposed time:',\n      tp_bw_comm_time_exposed)\n    self.log.debug(\"%s %s\", 'PP comm chunk FW time:', chunk_fw_pp_time)\n    self.log.debug(\"%s %s\", 'PP comm chunk BW time:', chunk_bw_pp_time)\n    self.log.debug(\"%s %s\", 'PP comm FW time:', pp_fw_comm_time)\n    self.log.debug(\"%s %s\", 'PP comm BW time:', pp_bw_comm_time)\n\n    # Bubble forms between 
i-th microbatch FW and BW passes on the 1st GPU.\n    # With no interleaving between blocks, it includes\n    # L/gpu x microbatch_time x (p-1) x Tcycle, where cycle includes both\n    # FW and BW passes, TP and PP communication for FW and BW passes\n    # With full interleaving, we only need microbatch_time x (p-1) x Tcycle time\n    self._baseblock_fw_time_no_offload = (\n      self._block_fw_time + self._baseblock_fw_tp_time_exposed)\n    self._edgeblock_fw_time_no_offload = (\n      self._block_fw_time + self._edgeblock_fw_tp_time_exposed +\n      chunk_fw_pp_time)\n    self._baseblock_fw_offload_overhead = max(\n      0, self.get_fw_offload_time() + self._block_fw_mem_time -\n      self._baseblock_fw_time_no_offload)\n    self._edgeblock_fw_offload_overhead = max(\n      0, self.get_fw_offload_time() + self._block_fw_mem_time -\n      self._edgeblock_fw_time_no_offload)\n    self._baseblock_fw_time = (\n      self._baseblock_fw_time_no_offload + self._baseblock_fw_offload_overhead)\n    self._edgeblock_fw_time = (\n      self._edgeblock_fw_time_no_offload + self._edgeblock_fw_offload_overhead)\n    # When we consider block BW time, we do not add optimizer step to it\n    # because we have optimizer only for last microbatches, while offloading\n    # works during the whole backward pass.\n    # Optimizer step is overall memory bound streaming task, itt is reasonable\n    # to not overlap offloading with optimizer step\n    self._baseblock_bw_time_no_offload = (\n      self._block_re_time + self._baseblock_recomm_time_exposed +\n      self._block_agrad_time + self._block_wgrad_time +\n      self._baseblock_agrad_tp_time_exposed)\n    self._edgeblock_bw_time_no_offload = (\n      self._block_re_time + self._edgeblock_recomm_time_exposed +\n      self._block_agrad_time + self._block_wgrad_time +\n      self._edgeblock_agrad_tp_time_exposed + chunk_bw_pp_time)\n    self._baseblock_bw_offload_overhead = max(\n      0, self.get_bw_offload_time() + 
self._block_agrad_mem_time +\n      self._block_wgrad_mem_time -\n      self._baseblock_bw_time_no_offload)\n    self._edgeblock_bw_offload_overhead = max(\n      0, self.get_bw_offload_time() + self._block_agrad_mem_time +\n      self._block_wgrad_mem_time -\n      self._edgeblock_bw_time_no_offload)\n    self._baseblock_bw_time = (\n      self._baseblock_bw_time_no_offload + self._baseblock_bw_offload_overhead)\n    self._edgeblock_bw_time = (\n      self._edgeblock_bw_time_no_offload + self._edgeblock_bw_offload_overhead)\n    chunk_fw_time = (\n      (self._baseblocks_per_chunk * self._baseblock_fw_time) +\n      (self._edgeblocks_per_chunk * self._edgeblock_fw_time))\n    chunk_bw_time = (\n      (self._baseblocks_per_chunk * self._baseblock_bw_time) +\n      (self._edgeblocks_per_chunk * self._edgeblock_bw_time))\n    # Can't overlap DP comm with mem accesses, but can overlap with offload\n    baseblock_dp_overlap_time = self._baseblock_bw_time - (\n      self._block_agrad_mem_time + self._block_wgrad_mem_time +\n      self._block_re_mem_time)\n    edgeblock_dp_overlap_time = self._edgeblock_bw_time - (\n      self._block_agrad_mem_time + self._block_wgrad_mem_time +\n      self._block_re_mem_time)\n    block_dp_compute_time = (\n      self._block_agrad_flops_time + self._block_wgrad_flops_time +\n      self._block_re_flops_time)\n    if not self.exe.optimizer_sharding:\n      # If optimizer is not sharded, we can overlap optimizer step with\n      # communication, except for memory access time\n      baseblock_dp_overlap_time += (\n        self._block_optim_time - self._block_optim_mem_time)\n      edgeblock_dp_overlap_time += (\n        self._block_optim_time - self._block_optim_mem_time)\n      block_dp_compute_time += self._block_optim_flops_time\n    if self._dp_net == self._tp_net:\n      # Can't overlap DP with TP if in the same network\n      baseblock_dp_overlap_time -= (\n        self._baseblock_recomm_time + self._baseblock_agrad_tp_time)\n      
edgeblock_dp_overlap_time -= (\n        self._edgeblock_recomm_time + self._edgeblock_agrad_tp_time)\n    chunk_dp_overlap_time = (\n      self._baseblocks_per_chunk * baseblock_dp_overlap_time +\n      self._edgeblocks_per_chunk * edgeblock_dp_overlap_time)\n    chunk_dp_compute_time = self._blocks_per_chunk * block_dp_compute_time\n    chunk_time = chunk_fw_time + chunk_bw_time\n    # Block bubbles appear due to uneven division of blocks by pipeline stages\n    # and result in the schedule bubble shorten by the missing edge blocks on\n    # the later pipeline stages (missing block case)\n    if self._baseblocks_per_chunk > 0:\n      # We cut last block of chunk, which is half-edge (has PP comm in the end)\n      bubble_reduction_time = self._bubble_reduction_blocks * (\n        self._baseblock_fw_time + self._edgeblock_fw_time +\n        self._baseblock_bw_time + self._edgeblock_bw_time) / 2\n    else:\n      # If chunk doesn't have base blocks, we cut edge block\n      bubble_reduction_time = self._bubble_reduction_blocks * (\n        self._edgeblock_fw_time + self._edgeblock_bw_time)\n    # With PP interleaving we assume that we move through every chunk at least\n    # PP mini batches. If num_microbatches < PP, then we have extra bubbles\n    # (missing microbatches case). We have the bubbles in the last microbatches\n    # of every overlappable chunk (all but last chunks). 
Size of bubbles is\n    # equal to microbatch_shortage, same number of microbatches will be missing\n    # in the last chunk\n    chunks_in_bubble = self.exe.pipeline_par - 1\n    num_overlappable_chunks = self.exe.pipeline_interleaving - 1\n    microbatch_shortage = self.exe.pipeline_par - (\n      self.exe._num_microbatches % self.exe.pipeline_par)\n    if self.exe._num_microbatches % self.exe.pipeline_par != 0:\n      extra_interleaving_bubbles = num_overlappable_chunks * \\\n        microbatch_shortage\n    else:\n      extra_interleaving_bubbles = 0\n    self._bubble_time = chunks_in_bubble * chunk_time + (\n      extra_interleaving_bubbles * chunk_time - bubble_reduction_time)\n\n    self.log.debug(\"%s %s\", 'Block FW time:', self._block_fw_time)\n    self.log.debug(\"%s %s\", 'Baseblock FW time:', self._baseblock_fw_time)\n    self.log.debug(\"%s %s\", 'With FW offload overhead time:',\n      self._baseblock_fw_offload_overhead)\n    self.log.debug(\"%s %s\", 'Edgeblock FW time:', self._edgeblock_fw_time)\n    self.log.debug(\"%s %s\", 'With FW offload overhead time:',\n      self._edgeblock_fw_offload_overhead)\n    self.log.debug(\"%s %s\", 'Baseblock REcomm exposed time:',\n      self._baseblock_recomm_time_exposed)\n    self.log.debug(\"%s %s\", 'Edgeblock REcomm exposed time:',\n      self._edgeblock_recomm_time_exposed)\n    self.log.debug(\"%s %s\", 'Block RE time:', self._block_re_time)\n    self.log.debug(\"%s %s\", 'Block BW Agrad time:', self._block_agrad_time)\n    self.log.debug(\"%s %s\", 'Block BW Wgrad time:', self._block_wgrad_time)\n    self.log.debug(\"%s %s\", 'Block optim time:', self._block_optim_time)\n    self.log.debug(\"%s %s\", 'Baseblock BW time:', self._baseblock_bw_time)\n    self.log.debug(\"%s %s\", 'With BW offload overhead time:',\n      self._baseblock_bw_offload_overhead)\n    self.log.debug(\"%s %s\", 'Edgeblock BW time:', self._edgeblock_bw_time)\n    self.log.debug(\"%s %s\", 'With BW offload overhead time:',\n      
self._edgeblock_bw_offload_overhead)\n\n    # Determines how long it takes to perform the DP per block\n    # This assumes no DP communication overlap (will be adjusted later).\n    if self.exe.data_par > 1 and self.exe.training:\n      self._block_dp_size = self._block_weight_space\n      if self.exe.optimizer_sharding:\n        # When performing optimizer sharding, the communication time is a\n        # reduce-scatter plus an all-gather.\n        self._block_dp_time = (\n          self._dp_net.time(\n            'reduce_scatter', self._block_dp_size, self.exe.data_par) +\n          self._dp_net.time(\n            'all_gather', self._block_dp_size, self.exe.data_par))\n      else:\n        # When not performing optimizer sharding, the communication time is a\n        # single all-reduce.\n        self._block_dp_time = self._dp_net.time(\n          'all_reduce', self._block_dp_size, self.exe.data_par)\n    else:\n      self._block_dp_size = 0\n      self._block_dp_time = 0\n    self.log.debug('DP block comm size: %s',\n                   human_format(self._block_dp_size, 'bytes'))\n    self.log.debug('DP block comm time (no overlap): %.3e',\n                   self._block_dp_time)\n\n    # DP overlap happens if DP time for a previous block(s) is lower than\n    # microbatch BW pass time for next pack of consecutive blocks\n    # If no interleaving, we move a single microbatch through each block\n    # and need to overlap DP during a single block single microbatch time\n    # In case of full interleaving, we propagate p microbatches through each\n    # block and need to overlap DP comm with p-1 microbatches over a block\n    # In a mixed case, we can overlap DP communication of several chunks, e.g.\n    # non-interleaved blocks (L/gpu / interleaving_factor) over BW pass of\n    # p-1 microbatches through the same amount of blocks if memory capacity is\n    # enough, or perform offload/prefetch after each block-microbatch\n    # For simplicity we count only 
bandwidth-optimal case\n    # Note that uneven extra PP bubbles won't affect overlapping\n    if self.exe.data_par > 1 and self.exe.training:\n      if self.exe.data_par_overlap:\n        # we can evenly overlap all the chunks except for the last one\n        # in the last chunk we can overlap only all blocks except for the last\n        num_overlappable_chunks = self.exe.pipeline_interleaving - 1\n        last_chunk_overlap_size = self._blocks_per_chunk - 1\n        # We can overlap DP with BW pass, overlap[ing AR for previous layer\n        # with BW for current, except when optimizer sharded. We can't overlap\n        # during optimizer step as we RS grads before step and AG weights after\n        # Overlappable chunks have overlap size equal to\n        # blocks_per_chunk * num_microbatches\n        # In case of 1F1B schedule, num_microbatches == pipeline_par\n        overlap_window = self.exe.pipeline_par * chunk_dp_overlap_time\n        overlap_compute = self.exe.pipeline_par * chunk_dp_compute_time\n        chunk_dp_time = self._blocks_per_chunk * self._block_dp_time\n        # We may have PP and DP comm colliding if DP comm takes longer than\n        # a single chunk BW time. 
We can't collide more PP than microbatches\n        if self._dp_net == self._pp_net:\n          if self.exe._num_microbatches % self.exe.pipeline_par != 0:\n            num_overlapped_pp = min(\n              chunk_dp_time // chunk_bw_time,\n              self.exe._num_microbatches % self.exe.pipeline_par)\n          else:\n            num_overlapped_pp = min(\n              chunk_dp_time // chunk_bw_time,\n              self.exe.pipeline_par)\n        else:\n          # if PP and DP on different networks, overlapping is fine\n          num_overlapped_pp = 0\n        # we add DP/PP collision time and compute slowdown due to overlap\n        overlap_inflection = chunk_dp_time - (overlap_window -\n          num_overlapped_pp * chunk_bw_pp_time) + overlap_compute * \\\n          self._dp_net.processor_usage\n        if overlap_inflection > 0:\n          # Tcomm is larger than compute, excess is exposed\n          overlappable_chunks_exposed_time = num_overlappable_chunks * \\\n            overlap_inflection\n        else:\n          # Tcomm is smaller than compute and hidden, but it contributes to\n          # compute slowdown due part of compute resources orchestrating comm\n          overlappable_chunks_exposed_time = num_overlappable_chunks * \\\n            chunk_dp_time * self._dp_net.processor_usage\n        # Compute minimal bandwidth required for DP comm overlap of all chunks\n        # but the last one.\n        chunk_overlap_time = overlap_window + overlap_compute * \\\n          self._dp_net.processor_usage\n        if self._dp_net == self._pp_net:\n          chunk_overlap_time -= chunk_bw_pp_time\n        chunk_overlap_time *= num_overlappable_chunks\n        if chunk_overlap_time > 0:\n          self._dp_bw_overlap_req_chunk = self._blocks_per_chunk * \\\n            self._block_dp_size / chunk_overlap_time\n          if self.exe.optimizer_sharding:\n            self._dp_bw_overlap_req_chunk *= (\n              self._dp_net._ops[\"reduce_scatter\"].scalar 
+\n              self._dp_net._ops[\"all_gather\"].scalar)\n          else:\n            self._dp_bw_overlap_req_chunk *= self._dp_net._ops[\"all_reduce\"].scalar\n        else:\n          self._dp_bw_overlap_req_chunk = 0\n        # in the last chunk, we overlap DP comm over last edge block and all\n        # middle blocks, so we substract the time of the first edge block\n        if self._baseblocks_per_chunk > 0:\n          last_chunk_window = chunk_dp_overlap_time - chunk_bw_pp_time - (\n            self._baseblock_bw_time + self._edgeblock_bw_time) / 2\n          if not self.exe.optimizer_sharding:\n            # If optimizer is not sharded, we can overlap optimizer step with\n            # communication, except for memory access time\n            last_chunk_window += (\n              self._block_optim_time - self._block_optim_mem_time)\n        else:\n          # if there is no base blocks, we only have a single edge block\n          # and last chunk is completely not overlappable\n          last_chunk_window = 0\n        last_chunk_inflection = (\n          last_chunk_overlap_size * self._block_dp_time) + (\n            block_dp_compute_time * self._dp_net.processor_usage -\n            last_chunk_window)\n        if last_chunk_inflection > 0:\n          # Tcomm is larger than compute, excess is exposed\n          last_chunk_exposed_time = last_chunk_inflection\n        else:\n          # Tcomm is smaller than compute and hidden, but it contributes to\n          # compute slowdown due part of compute resources orchestrating comm\n          last_chunk_exposed_time = last_chunk_overlap_size * \\\n            self._block_dp_time * self._dp_net.processor_usage\n        exposed_time = \\\n          overlappable_chunks_exposed_time + last_chunk_exposed_time\n        # Compute minimal bandwidth required for DP comm overlap of last chunk\n        tail_overlap_time = last_chunk_window + last_chunk_overlap_size * \\\n          self._block_dp_time * 
self._dp_net.processor_usage\n        if tail_overlap_time > 0:\n          self._dp_bw_overlap_req_tail = self._blocks_per_chunk * \\\n          self._block_dp_size / tail_overlap_time\n          if self.exe.optimizer_sharding:\n            self._dp_bw_overlap_req_tail *= (\n              self._dp_net._ops[\"reduce_scatter\"].scalar +\n              self._dp_net._ops[\"all_gather\"].scalar)\n          else:\n            self._dp_bw_overlap_req_tail *= self._dp_net._ops[\"all_reduce\"].scalar\n        else:\n          self._dp_bw_overlap_req_tail = 0\n        self._dp_comm_time_exposed = self._block_dp_time + exposed_time\n        self._dp_comm_time_link = self._blocks_per_proc * self._block_dp_time\n        self.log.debug('Blocks per chunk: %d', self._blocks_per_chunk)\n        self.log.debug('Num overlappable chunks: %d', num_overlappable_chunks)\n        self.log.debug('Last chunk size: %d', last_chunk_overlap_size)\n        self.log.debug('Chunk exposed time: %.3e', max(0, \\\n          chunk_dp_time + num_overlapped_pp * chunk_bw_pp_time - \\\n          overlap_window))\n        self.log.debug('Last chunk exposed time: %.3e', last_chunk_exposed_time)\n      else:\n        self._dp_comm_time_exposed = self._blocks_per_proc * self._block_dp_time\n        self._dp_comm_time_link = self._dp_comm_time_exposed\n        self._dp_bw_overlap_req_chunk = 0\n        self._dp_bw_overlap_req_tail = 0\n    else:\n      self._dp_comm_time_exposed = 0\n      self._dp_comm_time_link = 0\n      self._dp_bw_overlap_req_chunk = 0\n      self._dp_bw_overlap_req_tail = 0\n    self.log.debug('Chunk FW time: %.3e', chunk_fw_time)\n    self.log.debug('Chunk BW time: %.3e', chunk_bw_time)\n    self.log.debug('Chunk BW time for DP overlap: %.3e', chunk_dp_overlap_time)\n    self.log.debug('DP comm time exposed: %.3e', self._dp_comm_time_exposed)\n    self.log.debug('DP comm time on the link: %.3e',\n                   self._dp_comm_time_link)\n    self.log.debug('DP comm required 
bandwidth for overlapped chunks: %s',\n                   human_format(self._dp_bw_overlap_req_chunk, \"bandwidth\"))\n    self.log.debug('DP comm required bandwidth for the last chunk: %s',\n                   human_format(self._dp_bw_overlap_req_tail, \"bandwidth\"))\n\n    # memory capacity stats\n    self._weight_space = self._block_weight_space * self._blocks_per_proc\n    # account for activation recomputation\n    # for full recompute we keep single block's activations\n    # (no scaling by L/gpu)\n    if self.exe.training:\n      # With 1F1B schedule we only keep `pipeline_par` microbatches\n      # If num_microbatches < PP, we keep num_microbatches for all PP stages\n      if self.exe._num_microbatches < self.exe.pipeline_par:\n        mem_microbatches = self.exe._num_microbatches\n      else:\n        mem_microbatches = self.exe.pipeline_par\n      if self.exe.activation_recompute == \"full\":\n        assert self._block_act_storage_space == 0, \\\n          \"We expect with full act recomputation we recompute ALL activations\"\n        self._act_space = self._block_act_working_space\n        # We would need to store checkpoints for all microbatches before we\n        # compute BW pass with regular schedule, but we ONLY use 1F1B schedule\n        self._act_checkpoint_size = self._blocks_per_proc * \\\n          self._block_act_checkpoint_size\n        # Keep activation checkpoints for all pipeline stages for PP\n        if self.exe.pipeline_interleaving > 1:\n          self._act_checkpoint_size *= mem_microbatches * (\n            1 + (self.exe.pipeline_par - 1) / (self.exe.pipeline_interleaving *\n                                               self.exe.pipeline_par))\n        else:\n          assert self.exe.pipeline_interleaving == 1\n          self._act_checkpoint_size *= mem_microbatches\n      else:\n        # Without full recompute, we don't need checkpoints\n        self._act_checkpoint_size = 0\n        # Without full recompute, we keep 
activations for all blocks on the GPU,\n        # one activation for working block, and activation for other blocks for\n        # all pipeline stages w.r.t. interleaved 1F1B schedule\n        if self.exe.pipeline_interleaving > 1:\n          pp_microbatch_factor = mem_microbatches * (\n            1 + (self.exe.pipeline_par - 1) / (self.exe.pipeline_interleaving *\n                                               self.exe.pipeline_par))\n        else:\n          assert self.exe.pipeline_interleaving == 1\n          pp_microbatch_factor = mem_microbatches\n        self._act_space = self._block_act_working_space + \\\n          self._block_act_storage_space * (\n            self._blocks_per_proc * pp_microbatch_factor - 1)\n      # Only need activation grads for a single block\n      self._act_grad_space = self._block_act_grad_space\n    else:\n      self._act_space = self._block_act_working_space\n      self._act_checkpoint_size = 0\n      self._act_grad_space = 0\n\n    # Optimizer split  already accounted for during block compilation\n    # We should keep non-sharded weight grad for a current block for AllReduce\n    # and one that we currently compute, so 2x total\n    # We only need a single no sharded weight grad copy for before reduction\n    if self.exe.training:\n      if self._blocks_per_proc == 1:\n        self._weight_grad_space = self._block_weight_grad_space_no_sharding\n      else:\n        self._weight_grad_space = \\\n          self._block_weight_grad_space_no_sharding + \\\n          self._block_weight_grad_space * (self._blocks_per_proc - 1)\n      self._optimizer_space = \\\n        self._block_optimizer_space * self._blocks_per_proc\n    else:\n      self._weight_grad_space = 0\n      self._optimizer_space = 0\n\n  def _check_mem_caps(self):\n    if self.get_mem_tier1_cap_req() > self.sys.mem1.capacity:\n      raise self.Error(f'Mem tier1 needs '\n                       f'{human_format(self.get_mem_tier1_cap_req(), \"bytes\")} '\n                  
     f'but only has '\n                       f'{human_format(self.sys.mem1.capacity, \"bytes\")}')\n    if self.get_mem_tier2_cap_req() > self.sys.mem2.capacity:\n      raise self.Error(f'Mem tier2 needs '\n                       f'{human_format(self.get_mem_tier2_cap_req(), \"bytes\")} '\n                       f'but only has '\n                       f'{human_format(self.sys.mem2.capacity, \"bytes\")}')\n\n  def _misc_sanity_checks(self):\n    if self.exe.tensor_par == 1:\n      assert self.get_tp_comm_exposed_time() == 0\n      assert self.get_tp_comm_link_time() == 0\n    if self.exe.pipeline_par == 1:\n      assert self.get_pp_comm_exposed_time() == 0\n      assert self.get_pp_comm_link_time() == 0\n    if self.exe.data_par == 1:\n      assert self.get_dp_comm_exposed_time() == 0\n      assert self.get_dp_comm_link_time() == 0\n\n    assert self._fw_flops >= self._block_fw_flops\n    assert self._fw_flops_time >= self._block_fw_flops_time\n    assert self._fw_mem_accessed >= self._block_fw_mem_accessed\n    assert self._fw_mem_time >= self._block_fw_mem_time\n    assert self._fw_time >= self._block_fw_time\n    assert self._re_flops >= self._block_re_flops\n    assert self._re_flops_time >= self._block_re_flops_time\n    assert self._re_mem_accessed >= self._block_re_mem_accessed\n    assert self._re_mem_time >= self._block_re_mem_time\n    assert self._re_time >= self._block_re_time\n    assert self._agrad_flops >= self._block_agrad_flops\n    assert self._agrad_flops_time >= self._block_agrad_flops_time\n    assert self._agrad_mem_accessed >= self._block_agrad_mem_accessed\n    assert self._agrad_mem_time >= self._block_agrad_mem_time\n    assert self._agrad_time >= self._block_agrad_time\n    assert self._wgrad_flops >= self._block_wgrad_flops\n    assert self._wgrad_flops_time >= self._block_wgrad_flops_time\n    assert self._wgrad_mem_accessed >= self._block_wgrad_mem_accessed\n    assert self._wgrad_mem_time >= self._block_wgrad_mem_time\n    assert 
self._wgrad_time >= self._block_wgrad_time\n    assert self._optim_flops >= self._block_optim_flops\n    assert self._optim_flops_time >= self._block_optim_flops_time\n    assert self._optim_mem_accessed >= self._block_optim_mem_accessed\n    assert self._optim_mem_time >= self._block_optim_mem_time\n    assert self._optim_time >= self._block_optim_time\n    assert self._weight_space >= self._block_weight_space\n    assert self._act_space >= self._block_act_working_space\n    assert self._act_checkpoint_size >= self._block_act_checkpoint_size\n    assert self._weight_grad_space >= self._block_weight_grad_space_no_sharding\n    assert self._act_grad_space == self._block_act_grad_space\n    assert self._optimizer_space >= self._block_optimizer_space\n\n    if not self.exe.training:\n      # when not training (inference), backward is not performed and DP has no\n      # communication overhead\n      assert self.get_bw_time() == 0\n      assert self.get_optim_step_time() == 0\n      assert self.get_bw_offload_time() == 0\n      assert self.get_recompute_time() == 0\n      assert self.get_act_checkpoint_size() == 0\n      assert self.get_dp_comm_exposed_time() == 0\n      assert self.get_dp_comm_link_time() == 0\n    else:\n      # when training, backward is performed\n      assert self.get_bw_time() > 0\n      assert self.get_optim_step_time() > 0\n      if self.exe.activation_recompute == 'full':\n        assert self.get_recompute_time() > 0\n        assert self.get_act_checkpoint_size() > 0\n      elif self.exe.activation_recompute == 'attn_only':\n        assert self.get_recompute_time() > 0\n        assert self.get_act_checkpoint_size() == 0\n      else:\n        if not self.exe.seq_par_ag_redo:\n          assert self.get_recompute_time() == 0\n        assert self.get_act_checkpoint_size() == 0\n\n\n  def run(self, sys):\n    assert self._compiled, \"You must first call self.compile()\"\n    assert not self._executed\n    assert isinstance(sys, System)\n    
self._compute_block_stats()\n    self._compute_batch_stats()\n    self._check_mem_caps()\n    self._misc_sanity_checks()\n    self._executed = True\n\n  def _get_fw_offload_size(self):\n    if self.exe.weight_offload:\n      weight_offload_size = self._block_weight_space\n    else:\n      weight_offload_size = 0\n    if self.exe.activations_offload:\n      if self.exe.activation_recompute != 'full':\n        act_offload_size = self._block_act_storage_space\n      else:\n        act_offload_size = self._block_act_checkpoint_size\n    else:\n      act_offload_size = 0\n    return max(weight_offload_size, act_offload_size)\n\n  def _get_bw_offload_size(self):\n    bw_offload_size = 0\n    if self.exe.training:\n      if self.exe.weight_offload:\n        bw_offload_size += self._block_weight_space\n      if self.exe.activations_offload:\n        if self.exe.activation_recompute != 'full':\n          bw_offload_size += self._block_act_storage_space\n        else:\n          bw_offload_size += self._block_act_checkpoint_size\n      if self.exe.optimizer_offload:\n        bw_offload_size += self._block_optimizer_space\n    return bw_offload_size\n\n  def get_fw_time(self):\n    return self._fw_time\n\n  def get_fw_offload_time(self):\n    return self.sys.compute_offload_time(self._get_fw_offload_size())\n\n  def get_fw_offload_overhead(self):\n    full_overhead = self.exe._num_microbatches * self._chunks_per_proc * (\n      (self._baseblocks_per_chunk * self._baseblock_fw_offload_overhead) +\n      (self._edgeblocks_per_chunk * self._edgeblock_fw_offload_overhead))\n    return full_overhead\n\n  def get_bw_time(self):\n    return self._agrad_time + self._wgrad_time\n\n  def get_optim_step_time(self):\n    return self._optim_time\n\n  def get_bw_offload_time(self):\n    if self.exe.training:\n      return self.sys.compute_offload_time(self._get_bw_offload_size())\n    else:\n      return 0\n\n  def get_bw_offload_overhead(self):\n    if self.exe.training:\n      
full_overhead = self.exe._num_microbatches * self._chunks_per_proc * (\n        (self._baseblocks_per_chunk * self._baseblock_bw_offload_overhead) +\n        (self._edgeblocks_per_chunk * self._edgeblock_bw_offload_overhead))\n      return full_overhead\n    else:\n      return 0\n\n  def get_recompute_time(self):\n    return self._re_time\n\n  def get_recomm_exposed_time(self):\n    if self.exe.training:\n      return self._recomm_time_exposed\n    else:\n      return 0\n\n  def get_recomm_link_time(self):\n    if self.exe.training:\n      return self._recomm_time_link\n    else:\n      return 0\n\n  def get_bubble_time(self):\n    return self._bubble_time\n\n  def get_tp_comm_exposed_time(self):\n    return self._tp_comm_time_exposed\n\n  def get_pp_comm_exposed_time(self):\n    return self._pp_comm_time_exposed\n\n  def get_dp_comm_exposed_time(self):\n    if self.exe.training:\n      return self._dp_comm_time_exposed\n    else:\n      return 0\n\n  def get_tp_comm_link_time(self):\n    return self._tp_comm_time_link\n\n  def get_pp_comm_link_time(self):\n    return self._pp_comm_time_link\n\n  def get_dp_comm_link_time(self):\n    if self.exe.training:\n      return self._dp_comm_time_link\n    else:\n      return 0\n\n  def get_dp_comm_net_time(self):\n    if self.exe.training:\n      return self._blocks_per_proc * self._block_dp_time\n    else:\n      return 0\n\n  def get_total_time(self):\n    time = self.get_fw_time()\n    time += self.get_bw_time()\n    time += self.get_optim_step_time()\n    time += self.get_fw_offload_overhead()\n    time += self.get_bw_offload_overhead()\n    time += self.get_recompute_time()\n    time += self.get_recomm_exposed_time()\n    time += self.get_bubble_time()\n    time += self.get_tp_comm_exposed_time()\n    time += self.get_pp_comm_exposed_time()\n    time += self.get_dp_comm_exposed_time()\n    return time\n\n  def get_useful_flops(self):\n    total_flops = sum(\n      [block.get_fw_flops() for block in 
self._llm_block])\n    if self.exe.training:\n      total_flops += sum(\n        [block.get_agrad_flops() + block.get_wgrad_flops() + \\\n          block.get_optim_step_flops() for block in self._llm_block])\n    return total_flops\n\n  def get_compute_efficiency(self):\n    total_flops = self.get_useful_flops()\n    compute_time = self.get_fw_time() + self.get_bw_time() + \\\n      self.get_optim_step_time()\n    perfect_time = self._blocks_per_proc * self.exe._num_microbatches * \\\n      total_flops / self.sys.matrix.flops(self.exe.datatype)\n    return perfect_time / compute_time\n\n  def get_system_efficiency(self):\n    compute_time = self.get_fw_time() + self.get_bw_time() + \\\n      self.get_optim_step_time()\n    return compute_time / self.get_total_time()\n\n  def get_total_efficiency(self):\n    total_flops = self.get_useful_flops()\n    perfect_time = self._blocks_per_proc * self.exe._num_microbatches * \\\n      total_flops / self.sys.matrix.flops(self.exe.datatype)\n    return perfect_time / self.get_total_time()\n\n  def get_weight_space_min(self):\n    return self._block_weight_space * 2\n\n  def get_weight_space(self):\n    return self._weight_space\n\n  def get_act_space_min(self):\n    if self.exe.activation_recompute != 'full':\n      return self._block_act_working_space + self._block_act_storage_space\n    else:\n      return self._block_act_working_space\n\n  def get_act_space(self):\n    return self._act_space\n\n  def get_act_checkpoint_size_min(self):\n    if self.exe.training:\n      if self.exe.activation_recompute != 'full':\n        return 0\n      else:\n        return self._block_act_checkpoint_size * 2\n\n  def get_act_checkpoint_size(self):\n    if self.exe.training:\n      if self.exe.activation_recompute != 'full':\n        return 0\n      else:\n        return self._act_checkpoint_size\n    else:\n      return 0\n\n  def get_weight_grad_space_min(self):\n    if self.exe.training:\n      # We keep one set of non-sharded weight 
grads after compute before\n      # reduction, and one sharded set for offloading\n      return self._block_weight_grad_space_no_sharding + \\\n        self._block_weight_grad_space\n    else:\n      return 0\n\n  def get_weight_grad_space(self):\n    if self.exe.training:\n      return self._weight_grad_space\n    else:\n      return 0\n\n  def get_act_grad_space_min(self):\n    return self.get_act_grad_space()\n\n  def get_act_grad_space(self):\n    if self.exe.training:\n      return self._act_grad_space\n    else:\n      return 0\n\n    return self._block_optimizer_space * 2\n\n  def get_optimizer_space_min(self):\n    if self.exe.training:\n      return self._block_optimizer_space * 2\n    else:\n      return 0\n\n  def get_optimizer_space(self):\n    if self.exe.training:\n      return self._optimizer_space\n    else:\n      return 0\n\n  def _get_mem_cap_reqs(self):\n    tier1 = 0\n    tier2 = 0\n    if self.exe.weight_offload:\n      tier1 += self.get_weight_space_min()\n      tier2 += self.get_weight_space()\n    else:\n      tier1 += self.get_weight_space()\n    if self.exe.activations_offload:\n      if self.exe.activation_recompute != 'full':\n        tier1 += self.get_act_space_min()\n        tier2 += self.get_act_space()\n      else:\n        tier1 += self.get_act_space_min()\n        tier1 += self.get_act_checkpoint_size_min()\n        tier2 += self.get_act_checkpoint_size()\n    else:\n      tier1 += self.get_act_space()\n      tier1 += self.get_act_checkpoint_size()\n    if self.exe.optimizer_offload:\n      # We keep one set of non-sharded weight grads after compute before\n      # reduction, and one sharded set for offloading\n      tier1 += self.get_weight_grad_space_min()\n      tier1 += self.get_optimizer_space_min()\n      tier2 += self._block_weight_grad_space * self._blocks_per_proc\n      tier2 += self.get_optimizer_space()\n    else:\n      tier1 += self.get_weight_grad_space() + \\\n        self.get_optimizer_space()\n    tier1 += 
self.get_act_grad_space()\n    return tier1, tier2\n\n  def get_mem_tier1_cap_req(self):\n    return self._get_mem_cap_reqs()[0]\n\n  def get_mem_tier2_cap_req(self):\n    return self._get_mem_cap_reqs()[1]\n\n  def get_act_offload_bw_req(self):\n    # We should be able to offload (write) activation during FW pass and\n    # prefetch it (read) during BW pass for block (i-1)\n    # After BW pass activations are discarded\n    if self.exe.activation_recompute != 'full':\n      act_offload_size = self._block_act_storage_space\n    else:\n      act_offload_size = self._block_act_checkpoint_size\n    offload_time = min(\n      self._baseblock_fw_time_no_offload - self._block_fw_mem_time,\n      self._edgeblock_fw_time_no_offload - self._block_fw_mem_time)\n    return act_offload_size / offload_time\n\n  def get_weight_offload_bw_req(self):\n    # We should be able to offload (write) and prefetch (read) weights both\n    # during FW and BW passes for blocks (i-1) / (i+1).\n    # We always keep weights, they cannot be discarded\n    offload_time = min(\n      self._baseblock_fw_time_no_offload - self._block_fw_mem_time,\n      self._edgeblock_fw_time_no_offload - self._block_fw_mem_time)\n    return self._block_weight_space / offload_time\n\n  def get_optim_offload_bw_req(self):\n    # We should be able to offload (write) weight grads and optimizer state\n    # and prefetch (read) optimizer state during BW passes for blocks\n    # (i-1) / (i+1).\n    if self.exe.training:\n      offload_time = min(\n        self._baseblock_bw_time_no_offload - (self._block_agrad_mem_time +\n          self._block_wgrad_mem_time),\n        self._edgeblock_bw_time_no_offload - (self._block_agrad_mem_time +\n          self._block_wgrad_mem_time))\n      return (self._block_weight_grad_space + self._block_optimizer_space) / \\\n        offload_time\n    else:\n      return 0\n\n  def get_offload_mem_bw_req(self):\n    fw_offload_time = min(\n      self._baseblock_fw_time_no_offload - 
self._block_fw_mem_time,\n      self._edgeblock_fw_time_no_offload - self._block_fw_mem_time)\n    if self.exe.training:\n      bw_offload_time = min(\n        self._baseblock_bw_time_no_offload - (self._block_agrad_mem_time +\n          self._block_wgrad_mem_time),\n        self._edgeblock_bw_time_no_offload - (self._block_agrad_mem_time +\n          self._block_wgrad_mem_time))\n      req_bw = max(self._get_fw_offload_size() / fw_offload_time,\n                   self._get_bw_offload_size() / bw_offload_time)\n      return req_bw\n    else:\n      return self._get_fw_offload_size() / fw_offload_time\n\n  def get_sample_rate(self):\n    return self.exe.global_batch_size / self.get_total_time()\n\n  def display_stats(self):\n    stats = \"=\" * 80 + \"\\n\"\n    stats += \"\" \\\n      f\"blocks={self.app.num_blocks}, \" \\\n      f\"hidden={self.app.hidden}, feedforward={self.app.feedforward}\\n\" \\\n      f\"num attn heads: {self.app.attn_heads}, \" \\\n      f\"attn_size={self.app.attn_size}\\n\" \\\n      f\"Run on {self.exe.num_procs} processors with:\\n\" \\\n      f\"TP={self.exe.tensor_par}\\n\" \\\n      f\"PP={self.exe.pipeline_par}\\n\" \\\n      f\"DP={self.exe.data_par}\\n\" \\\n      f\"Blocks per processor: {self._blocks_per_proc}\\n\" \\\n      f\"Execution: {self.exe.get_json()};\\n\" \\\n      f\"System: {self.sys.cfg};\\n\" \\\n      f\"Weights: {human_format(self.get_weight_space(), 'bytes')};\\n\" \\\n      f\"Act: {human_format(self.get_act_space(), 'bytes')};\\n\" \\\n      f\"Act CP: {human_format(self.get_act_checkpoint_size(), 'bytes')};\\n\" \\\n      f\"Act grad: {human_format(self.get_act_grad_space(), 'bytes')};\\n\" \\\n      f\"Weight grad: {human_format(self.get_weight_grad_space(), 'bytes')};\\n\" \\\n      f\"Optim space: {human_format(self.get_optimizer_space(), 'bytes')};\\n\" \\\n      f\"Batch FW time: {self.get_fw_time():.4f};\\n\" \\\n      f\"Batch BW time: {self.get_bw_time():.4f};\\n\" \\\n      f\"Batch optim time: 
{self.get_optim_step_time():.4f};\\n\" \\\n      f\"Batch FW offload overhead: {self.get_fw_offload_overhead():.4f};\\n\" \\\n      f\"Batch BW offload overhead: {self.get_bw_offload_overhead():.4f};\\n\" \\\n      f\"Batch recompute overhead: {self.get_recompute_time():.4f};\\n\" \\\n      f\"Batch recomm overhead: {self.get_recomm_exposed_time():.4f};\\n\" \\\n      f\"Batch bubble overhead: {self.get_bubble_time():.4f};\\n\" \\\n      f\"Batch TP comm overhead: {self.get_tp_comm_exposed_time():.4f};\\n\" \\\n      f\"Batch PP comm overhead: {self.get_pp_comm_exposed_time():.4f};\\n\" \\\n      f\"Batch DP comm overhead: {self.get_dp_comm_exposed_time():.4f};\\n\" \\\n      f\"Batch TP comm time on link: {self.get_tp_comm_link_time():.4f};\\n\" \\\n      f\"Batch PP comm time on link: {self.get_pp_comm_link_time():.4f};\\n\" \\\n      f\"Batch DP comm time on link: {self.get_dp_comm_link_time():.4f};\\n\" \\\n      f\"Batch total time: {self.get_total_time():.4f};\\n\" \\\n      f\"Activation offload required BW: \" \\\n      f\"{human_format(self.get_act_offload_bw_req(), 'bandwidth')};\\n\" \\\n      f\"Weight offload required BW: \" \\\n      f\"{human_format(self.get_weight_offload_bw_req(), 'bandwidth')};\\n\" \\\n      f\"Optimizer offload required BW: \" \\\n      f\"{human_format(self.get_optim_offload_bw_req(), 'bandwidth')};\\n\" \\\n      f\"Total offload required BW: \" \\\n      f\"{human_format(self.get_offload_mem_bw_req(), 'bandwidth')};\\n\" \\\n      f\"Mem tier1 capacity requirement: \" \\\n      f\"{human_format(self.get_mem_tier1_cap_req(), 'bytes')};\\n\" \\\n      f\"Mem tier2 capacity requirement: \" \\\n      f\"{human_format(self.get_mem_tier2_cap_req(), 'bytes')};\\n\" \\\n      f\"Mem tier2 BW for offload: \" \\\n      f\"{human_format(self.get_offload_mem_bw_req(), 'bandwidth')};\\n\" \\\n      f\"Compute efficiency: {self.get_compute_efficiency()*100:.2f}%;\\n\" \\\n      f\"System efficiency: 
{self.get_system_efficiency()*100:.2f}%;\\n\" \\\n      f\"Total efficiency: {self.get_total_efficiency()*100:.2f}%;\\n\" \\\n      f\"Sample rate: {self.get_sample_rate():.2f};\\n\"\n    self.log.info(stats)\n"
  },
  {
    "path": "calculon/llm/optimal_execution.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nimport datetime\nimport gzip\nimport logging\nimport multiprocessing as mp\nimport psutil\nimport os\n\nimport calculon\nfrom calculon.util import pick, arg_true_false_all\nfrom calculon.llm import *\n\n\nclass OptimalExecution(calculon.CommandLine):\n  NAME = 'llm-optimal-execution'\n  ALIASES = ['loe']\n\n  @staticmethod\n  def create_parser(subparser):\n    sp = subparser.add_parser(\n      OptimalExecution.NAME, aliases=OptimalExecution.ALIASES,\n      help='run a search to find the optimal llm execution')\n    sp.set_defaults(func=OptimalExecution.run_command)\n    sp.add_argument('-d', '--debug', action='store_true',\n                    help='Loop over executions, don\\'t run them')\n    sp.add_argument('application', type=str,\n                    help='File path to application configuration')\n    sp.add_argument('num_procs', type=int,\n                    help='Number of processors in execution')\n    sp.add_argument('max_batch_size', type=int,\n                    help='Maximum batch size, will be largest multiple of DP')\n    sp.add_argument('datatype', type=str, choices=System.supported_datatypes(),\n                    help='The datatype to use')\n    sp.add_argument('system', type=str,\n                    help='File path to system 
configuration')\n    sp.add_argument('output', type=str,\n                    help='File path to the output file'\n                    \" ('*.csv', '*.csv.gz', '*.json', '*.json.gz')\")\n    sp.add_argument('-c', '--cpus', type=int, default=psutil.cpu_count(logical=False),\n                    help='CPUs to use for parallelization')\n    sp.add_argument('-n', '--noneok', action='store_true',\n                    help='Don\\'t give failure status when no good execution exists')\n    sp.add_argument('-m', '--mbs-break', action='store_true',\n                    help='Search across MBS and break earlier when possible')\n    sp.add_argument('-t', '--top-n', type=int, default=1,\n                    help='Number of best outputs')\n    sp.add_argument('-l', '--layers', action='store_true',\n                    help='Include layers information in output stats file')\n    sp.add_argument('-f', '--fused_activation', type=arg_true_false_all,\n                    default='true', help='Mode of fused activation')\n    sp.add_argument('--no-tp-overlap', action='store_true',\n                    help='Don\\'t allow TP overlap')\n    sp.add_argument('--no-dp-overlap', action='store_true',\n                    help='Don\\'t allow DP overlap')\n\n  @staticmethod\n  def run_command(logger, args):\n    assert args.top_n > 0, 'top-n must be > 0'\n\n    app = Llm.Application(calculon.io.read_json_file(args.application))\n    syst = System(calculon.io.read_json_file(args.system))\n\n    params = []\n    for tp in Llm.get_all_tensor_parallelisms(\n        args.num_procs, app.hidden, app.attn_heads):\n      for pp in Llm.get_all_pipeline_parallelisms(\n          args.num_procs, tp, app.num_blocks):\n        dp = Llm.get_data_parallelism(args.num_procs, tp, pp)\n        for ppint in Llm.get_valid_pipeline_interleavings(app.num_blocks, pp):\n          batch_size = OptimalExecution.get_batch_size(dp, args.max_batch_size)\n          if batch_size is None:\n            continue\n          for 
activation_recompute in ['full', 'attn_only', 'none']:\n            for optimizer_sharding in pick(dp>1, [True, False], [False]):\n              for tensor_par_comm_type in ['ar', 'p2p_rs_ag', 'rs_ag']:\n                params.append(\n                  (args.debug, args.top_n, args.layers, args.num_procs,\n                   args.max_batch_size, args.datatype, app, syst, tp, pp, dp,\n                   ppint, batch_size, activation_recompute, optimizer_sharding,\n                   tensor_par_comm_type, args.fused_activation, args.mbs_break,\n                   not args.no_tp_overlap, not args.no_dp_overlap))\n\n    # Runs parallel searches\n    start_time = datetime.datetime.now()\n    with mp.Pool(args.cpus) as pool:\n      searches = pool.starmap(OptimalExecution.search, params)\n    end_time = datetime.datetime.now()\n\n    # Combines parallel search result into one data structure\n    best = []\n    exe_count = 0\n    good_exe_count = 0\n    bad_exe_count = 0\n    for cbest, ec, gec, bec, tp, pp in searches:\n      best = OptimalExecution.update_list(best, cbest, args.top_n)\n      exe_count += ec\n      good_exe_count += gec\n      bad_exe_count += bec\n\n    logger.info(f'Total executions: {exe_count}')\n    logger.info(f'Good executions: {good_exe_count}')\n    logger.info(f'Bad executions: {bad_exe_count}')\n    calc_rate = exe_count / (end_time - start_time).total_seconds()\n    logger.info(f'Calculation rate: {calc_rate:.2f} calcs/sec')\n    if args.debug:\n      return 0\n\n    if len(best) == 0:\n      if not args.noneok:\n        logger.fatal('No acceptable configurations found :(')\n        return -1\n      else:\n        logger.info('No acceptable configurations found :(')\n    else:\n      logger.info(f'Best sample rate: {best[0][0]}')\n\n    output = {}\n    for index, run in enumerate(best):\n      _, execution, stats = run\n      output[index] = {\n        'execution': execution,\n        'stats': stats\n      }\n\n    if 
calculon.io.is_json_extension(args.output):\n      logger.info(f'Output: {args.output}')\n      calculon.io.write_json_file(output, args.output)\n    elif args.output.endswith('.csv') or args.output.endswith('.csv.gz'):\n      logger.info(f'Output: {args.output}')\n      exe_keys = list(output[0]['execution'].keys())\n      stats_keys = list(output[0]['stats'].keys())\n      opener = gzip.open if args.output.endswith('.gz') else open\n      with opener(args.output, 'wb') as fd:\n        fd.write(bytes(f',{\",\".join(exe_keys)},{\",\".join(stats_keys)}\\n',\n                       'utf-8'))\n        for index in sorted(output.keys()):\n          fd.write(bytes(f'{index}', 'utf-8'))\n          for exe_key in exe_keys:\n            fd.write(bytes(f',{output[index][\"execution\"][exe_key]}', 'utf-8'))\n          for stats_key in stats_keys:\n            fd.write(bytes(f',{output[index][\"stats\"][stats_key]}', 'utf-8'))\n          fd.write(bytes('\\n', 'utf-8'))\n    else:\n      assert False, f'Unknown file type: {args.output}'\n\n    return 0\n\n  @staticmethod\n  def get_batch_size(data_par, max_batch_size):\n    if data_par > max_batch_size:\n      return None\n    last = data_par\n    while True:\n      if last + data_par > max_batch_size:\n        return last\n      else:\n        last += data_par\n\n  @staticmethod\n  def search(debug, top_n, layers, num_procs, max_batch_size, datatype,\n             app, syst, tp, pp, dp, ppint, batch_size, activation_recompute,\n             optimizer_sharding, tensor_par_comm_type, fused_acts, mbs_break,\n             allow_tp_overlap, allow_dp_overlap):\n    num_nets = syst.num_networks\n\n    best = []\n    exe_count = 0\n    good_exe_count = 0\n    bad_exe_count = 0\n\n    has_mem2 = syst.mem2.capacity > 0\n\n    can_redo = Llm.can_redo_ag(tensor_par_comm_type,\n                               activation_recompute)\n    for seq_par_ag_redo in pick(can_redo, [True, False], [False]):\n      for data_par_overlap in pick(dp>1 
and allow_dp_overlap, [True, False],\n                                   [False]):\n        for tensor_par_overlap in pick(tp>1 and allow_tp_overlap,\n                                       ['none', 'ring', 'pipe'], ['none']):\n          for weight_offload in pick(has_mem2, [True, False], [False]):\n            if activation_recompute == 'full' or not has_mem2:\n              activations_offloads = [False]\n            else:\n              activations_offloads = [True, False]\n            for activations_offload in activations_offloads:\n              for optimizer_offload in pick(has_mem2, [True, False],\n                                            [False]):\n                for fused_act in fused_acts:\n                  for microbatch_size in Llm.get_valid_microbatch_sizes(\n                      app.seq_size, tp, dp, batch_size, pp):\n                    mbs_break_good = good_exe_count\n                    for tn in pick(tp>1, range(num_nets), [0]):\n                      for pn in pick(pp>1, range(num_nets), [0]):\n                        for dn in pick(dp>1, range(num_nets), [0]):\n                          exe_count += 1\n                          exe_json = {\n                            'num_procs': num_procs,\n                            'tensor_par': tp,\n                            'pipeline_par': pp,\n                            'data_par': dp,\n                            'tensor_par_net': tn,\n                            'pipeline_par_net': pn,\n                            'data_par_net': dn,\n                            'batch_size': batch_size,\n                            'microbatch_size': microbatch_size,\n                            'datatype': datatype,\n                            'fused_activation': fused_act,\n                            'attention_type': 'multihead',\n                            'activation_recompute': activation_recompute,\n                            'pipeline_interleaving': ppint,\n                            
'optimizer_sharding': optimizer_sharding,\n                            'tensor_par_comm_type': tensor_par_comm_type,\n                            'tensor_par_overlap': tensor_par_overlap,\n                            'seq_par_ag_redo': seq_par_ag_redo,\n                            'data_par_overlap': data_par_overlap,\n                            'weight_offload': weight_offload,\n                            'activations_offload': activations_offload,\n                            'optimizer_offload': optimizer_offload,\n                            'training': True\n                          }\n\n                          if not debug:\n                            try:\n                              logger = logging.Logger('sub')\n                              model = Llm(app, logger)\n                              model.compile(\n                                syst,\n                                Llm.Execution.from_json(exe_json))\n                              model.run(syst)\n                              stats = model.get_stats_json(layers)\n                              good_exe_count += 1\n                              curr = (stats['sample_rate'], exe_json, stats)\n                              best = OptimalExecution.update_list(best, curr,\n                                                                  top_n)\n                            except Llm.Error as ex:\n                              logger = logging.getLogger()\n                              logger.debug(f'JSON:{exe_json}\\nERROR:{ex}\\n')\n                              bad_exe_count += 1\n                    if mbs_break and good_exe_count == mbs_break_good:\n                      break\n    return (best, exe_count, good_exe_count, bad_exe_count, tp, pp)\n\n  @staticmethod\n  def update_list(current, candidate, quantity):\n    if not isinstance(candidate, list):\n      current.append(candidate)\n    else:\n      current.extend(candidate)\n    current.sort(reverse=True, key=lambda x: x[0])\n  
  return current[:quantity]\n\n\ncalculon.CommandLine.register(OptimalExecution)\n"
  },
  {
    "path": "calculon/llm/parameter_calculator.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nimport os\n\nimport calculon\nfrom calculon.llm import *\n\nclass ParameterCalculator(calculon.CommandLine):\n  NAME = 'llm-parameter-calculator'\n  ALIASES = ['lpc']\n\n  @staticmethod\n  def create_parser(subparser):\n    sp = subparser.add_parser(ParameterCalculator.NAME,\n                              aliases=ParameterCalculator.ALIASES,\n                              help='run a single llm calculation')\n    sp.set_defaults(func=ParameterCalculator.run_command)\n    sp.add_argument('application', type=str,\n                    help='File path to application configuration')\n    sp.add_argument('-a', '--alignment', type=int, default=13,\n                    help='Alignment spaces')\n\n  @staticmethod\n  def run_command(logger, args):\n    app_json = calculon.io.read_json_file(args.application)\n\n    try:\n      app = Llm.Application(app_json)\n    except Llm.Error as error:\n      print(f'ERROR: {error}')\n      return -1\n\n    app_name, _ = os.path.splitext(os.path.basename(args.application))\n\n    logger.info(f'{app_name}'\n                f'{\" \" * (args.alignment - len(app_name))}'\n                ' -> '\n                f'{human_format(app.num_parameters())}')\n\n\ncalculon.CommandLine.register(ParameterCalculator)\n"
  },
  {
    "path": "calculon/llm/runner.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nimport calculon\nfrom calculon.llm import *\n\nclass Runner(calculon.CommandLine):\n  NAME = 'llm'\n  ALIASES = []\n\n  @staticmethod\n  def create_parser(subparser):\n    sp = subparser.add_parser(Runner.NAME, aliases=Runner.ALIASES,\n                              help='run a single llm calculation')\n    sp.set_defaults(func=Runner.run_command)\n    sp.add_argument('application', type=str,\n                    help='File path to application configuration')\n    sp.add_argument('execution', type=str,\n                    help='File path to execution configuration')\n    sp.add_argument('system', type=str,\n                    help='File path to system configuration')\n    sp.add_argument('stats', type=str,\n                    help='File path to stats output (\"-\" for stdout\")')\n    sp.add_argument('-p', '--peers', type=str, default=None,\n                    help='File path to write out peers file')\n    sp.add_argument('-l', '--layers', action='store_true',\n                    help='Include layers information in output stats file')\n\n  @staticmethod\n  def run_command(logger, args):\n    app_json = calculon.io.read_json_file(args.application)\n    exe_json = calculon.io.read_json_file(args.execution)\n    sys_json = 
calculon.io.read_json_file(args.system)\n\n    app = Llm.Application(app_json)\n    exe = Llm.Execution.from_json(exe_json)\n    syst = System(sys_json)\n\n    try:\n      model = Llm(app, logger)\n      model.compile(syst, exe)\n      model.run(syst)\n    except Llm.Error as error:\n      print(f'ERROR: {error}')\n      return -1\n\n    if args.stats == '-':\n      model.display_stats()\n    elif calculon.is_json_extension(args.stats):\n      calculon.write_json_file(model.get_stats_json(args.layers), args.stats)\n    else:\n      assert False, f'unknown stats extension: {args.stats}'\n\n    if args.peers:\n      calculon.write_json_file(exe.get_peers_json(), args.peers)\n\n    return 0\n\n\ncalculon.CommandLine.register(Runner)\n"
  },
  {
    "path": "calculon/llm/validation.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nimport logging\nimport math\nimport os\n\nimport calculon\nfrom calculon.util import pick\nfrom calculon.llm import *\n\n\nclass Validation(calculon.CommandLine):\n  NAME = 'llm-validation'\n  ALIASES = ['lv']\n\n  @staticmethod\n  def create_parser(subparser):\n    sp = subparser.add_parser(\n      Validation.NAME, aliases=Validation.ALIASES,\n      help='run a validation of llm execution')\n    sp.set_defaults(func=Validation.run_command)\n    sp.add_argument('-b', '--base_dir', default='.',\n                    help='Base directory')\n    sp.add_argument('-v', '--verbose', action='store_true',\n                    help='Show verbose output while running')\n\n  @staticmethod\n  def run_command(logger, args):\n    funcs = [\n      Validation.seqsel_fig1,\n      Validation.seqsel_fig7,\n      Validation.seqsel_tab5\n    ]\n    for func in funcs:\n      if args.verbose:\n        print(f'\\n\\nNow running test: {func.__name__}')\n      if func(logger, args) is not None:\n        return -1\n\n  @staticmethod\n  def seqsel_fig1(logger, args):\n    kModels = ['megatron-22B', 'gpt3-175B', 'turing-530B', 'megatron-1T']\n    kModes = ['none', 'seqsel']\n    # These profiled values are reported here:\n    # https://arxiv.org/pdf/2205.05198.pdf\n    # Figure 1\n    
kProfile = {\n      'megatron-22B': {\n        'none': {\n          'par_opt': 45.5625,\n          'act': 59.25\n        },\n        'seqsel': {\n          'par_opt': 45.5625,\n          'act': 9.5625\n        }\n      },\n      'gpt3-175B': {\n        'none': {\n          'par_opt': 45.5625,\n          'act': 66.84375\n        },\n        'seqsel': {\n          'par_opt': 45.5625,\n          'act': 12.3515625\n        }\n      },\n      'turing-530B': {\n        'none': {\n          'par_opt': 31.640625,\n          'act': 114.0234375\n        },\n        'seqsel': {\n          'par_opt': 31.640625,\n          'act': 23.076171875\n        }\n      },\n      'megatron-1T': {\n        'none': {\n          'par_opt': 32.958984375,\n          'act': 131.25\n        },\n        'seqsel': {\n          'par_opt': 32.958984375,\n          'act': 26.5625\n        }\n      }\n    }\n\n    def get_files(model, mode):\n      assert model in kModels\n      assert mode in kModes\n      app = os.path.join(args.base_dir, 'models', f'{model}.json')\n      exe = os.path.join(args.base_dir, 'validation', 'seqsel', 'fig1',\n                         f'{model}_{mode}.json')\n      return app, exe\n\n    def get_profile(model, mode):\n      assert model in kModels\n      assert mode in kModes\n      return kProfile[model][mode]\n\n    syst_file = os.path.join(args.base_dir, 'systems', 'a100_80e.json')\n    syst = System(calculon.io.read_json_file(syst_file))\n    data = {}\n    for model in kModels:\n      data[model] = {}\n      for mode in kModes:\n        if args.verbose:\n          print(f'Analyzing {model} {mode}')\n        data[model][mode] = {}\n        app_file, exe_file = get_files(model, mode)\n        app = Llm.Application(calculon.read_json_file(app_file))\n        exe = Llm.Execution.from_json(calculon.read_json_file(exe_file))\n        mt = Llm(app, logger)\n        mt.compile(syst, exe)\n        mt.run(syst)\n        stats = mt.get_stats_json(False)\n        
data[model][mode]['profile_gib'] = get_profile(model, mode)\n        act_par_opt = (stats['weight_space'] + stats['weight_grad_space'] +\n                       stats['optimizer_space']) / (1024**3)\n        act_act = stats['act_space'] / (1024**3)\n        data[model][mode]['actual_gib'] = {\n          'par_opt': act_par_opt,\n          'act': act_act\n        }\n\n    print('*Params & Opt,|,none,,,|,seqsel,,,')\n    print('Model,|,Profile,Calc,Delta,|,Profile,Calc,Delta,')\n    max_error = 0\n    abs_error = 0\n    for model in kModels:\n      print(f'{model},', end='')\n      for mode in kModes:\n        p = data[model][mode]['profile_gib']['par_opt']\n        a = data[model][mode]['actual_gib']['par_opt']\n        d = 100*(1-a/p)\n        if math.fabs(d) > max_error:\n          max_error = math.fabs(d)\n        abs_error += math.fabs(d)\n        print(f'|,{p},{a:.2f},{d:.2f}%,', end='')\n      print()\n    ave_error = abs_error / (len(kModels) * len(kModes))\n    print(f'Ave,,{ave_error:.2f}%')\n    print(f'Max,,{max_error:.2f}%')\n    print(',')\n\n    print('*Activations,|,none,,,|,seqsel,,,')\n    print('Model,|,Profile,Calc,Delta,|,Profile,Calc,Delta,')\n    max_error = 0\n    abs_error = 0\n    for model in kModels:\n      print(f'{model},', end='')\n      for mode in kModes:\n        p = data[model][mode]['profile_gib']['act']\n        a = data[model][mode]['actual_gib']['act']\n        d = 100*(1-a/p)\n        if math.fabs(d) > max_error:\n          max_error = math.fabs(d)\n        abs_error += math.fabs(d)\n        print(f'|,{p},{a:.2f},{d:.2f}%,', end='')\n      print()\n    ave_error = abs_error / (len(kModels) * len(kModes))\n    print(f'Ave,,{ave_error:.2f}%')\n    print(f'Max,,{max_error:.2f}%')\n    print(',')\n\n  @staticmethod\n  def seqsel_fig7(logger, args):\n    kModels = ['megatron-22B', 'gpt3-175B', 'turing-530B', 'megatron-1T']\n    kModes = ['none', 'seq', 'sel', 'seqsel', 'full']\n    # These profiled values are reported here:\n    # 
https://arxiv.org/pdf/2205.05198.pdf\n    # Figure 7\n    kProfile = {\n      'megatron-22B': {\n        'none': 100.00,\n        'seq': 66.84,\n        'sel': 49.42,\n        'seqsel': 16.18,\n        'full': 7.64\n      },\n      'gpt3-175B': {\n        'none': 100.00,\n        'seq': 62.04,\n        'sel': 56.53,\n        'seqsel': 18.49,\n        'full': 8.71\n      },\n      'turing-530B': {\n        'none': 100.00,\n        'seq': 58.31,\n        'sel': 62.04,\n        'seqsel': 20.27,\n        'full': 9.42\n      },\n      'megatron-1T': {\n        'none': 100.00,\n        'seq': 58.31,\n        'sel': 62.04,\n        'seqsel': 20.27,\n        'full': 9.42\n      }\n    }\n\n    def get_files(model, mode):\n      assert model in kModels\n      assert mode in kModes\n      app = os.path.join(args.base_dir, 'models', f'{model}.json')\n      exe = os.path.join(args.base_dir, 'validation', 'seqsel', 'fig7',\n                         f'{model}_{mode}.json')\n      return app, exe\n\n    def get_profile(model, mode):\n      assert model in kModels\n      assert mode in kModes\n      return kProfile[model][mode]\n\n    syst_file = os.path.join(args.base_dir, 'systems', 'a100_80e.json')\n    syst = System(calculon.io.read_json_file(syst_file))\n    raw = {}\n    for model in kModels:\n      raw[model] = {}\n      for mode in kModes:\n        if args.verbose:\n          print(f'Analyzing {model} {mode}')\n        raw[model][mode] = {}\n        app_file, exe_file = get_files(model, mode)\n        app = Llm.Application(calculon.read_json_file(app_file))\n        exe = Llm.Execution.from_json(calculon.read_json_file(exe_file))\n        mt = Llm(app, logger)\n        mt.compile(syst, exe)\n        mt.run(syst)\n        stats = mt.get_stats_json(False)\n        raw[model][mode] = stats['act_space'] + stats['act_checkpoint_size']\n\n    rel = {}\n    for model in kModels:\n      rel[model] = {}\n      for mode in kModes:\n        rel[model][mode] = {}\n        
rel[model][mode] = raw[model][mode] / raw[model]['none'] * 100\n\n    print('Activations,|,none,,,|,seq,,,|,sel,,,|,seqsel,,,|,full,,,')\n    print('Model,|,Profile,Calc,Delta,|,Profile,Calc,Delta,|'\n          ',Profile,Calc,Delta,|,Profile,Calc,Delta,|,Profile,Calc,Delta,')\n    max_error = 0\n    abs_error = 0\n    for model in kModels:\n      print(f'{model},', end='')\n      for mode in kModes:\n        p = get_profile(model, mode)\n        a = rel[model][mode]\n        d = 100*(1-a/p)\n        if math.fabs(d) > max_error:\n          max_error = math.fabs(d)\n        abs_error += math.fabs(d)\n        print(f'|,{p}%,{a:.2f}%,{d:.2f}%,', end='')\n      print()\n    ave_error = abs_error / (len(kModels) * len(kModes))\n    print(f'Ave,,{ave_error:.2f}%')\n    print(f'Max,,{max_error:.2f}%')\n    print(',')\n\n  @staticmethod\n  def seqsel_tab5(logger, args):\n    kModels = ['megatron-22B', 'gpt3-175B', 'turing-530B', 'megatron-1T']\n    kModes = ['full', 'seqsel']\n    # These profiled values are reported here:\n    # https://arxiv.org/pdf/2205.05198.pdf\n    # Table 5\n    kProfile = {\n      'megatron-22B': {\n        'full': 1.42,\n        'seqsel': 1.10\n      },\n      'gpt3-175B': {\n        'full': 18.13,\n        'seqsel': 13.75\n      },\n      'turing-530B': {\n        'full': 49.05,\n        'seqsel': 37.83\n      },\n      'megatron-1T': {\n        'full': 94.42,\n        'seqsel': 71.49\n      }\n    }\n\n    def get_files(model, mode):\n      assert model in kModels\n      assert mode in kModes\n      app = os.path.join(args.base_dir, 'models', f'{model}.json')\n      exe = os.path.join(args.base_dir, 'validation', 'seqsel', 'tab5',\n                         f'{model}_{mode}.json')\n      return app, exe\n\n    def get_profile(model, mode):\n      assert model in kModels\n      assert mode in kModes\n      return kProfile[model][mode]\n\n    syst_file = os.path.join(args.base_dir, 'systems', 'a100_80g.json')\n    syst = 
System(calculon.io.read_json_file(syst_file))\n    data = {}\n    for model in kModels:\n      data[model] = {}\n      for mode in kModes:\n        if args.verbose:\n          print(f'Analyzing {model} {mode}')\n        data[model][mode] = {}\n        app_file, exe_file = get_files(model, mode)\n        app = Llm.Application(calculon.read_json_file(app_file))\n        exe = Llm.Execution.from_json(calculon.read_json_file(exe_file))\n        mt = Llm(app, logger)\n        mt.compile(syst, exe)\n        mt.run(syst)\n        stats = mt.get_stats_json(False)\n        data[model][mode]['profile_time'] = get_profile(model, mode)\n        data[model][mode]['actual_time'] = stats[\"total_time\"]\n        data[model][mode]['memory_req'] = stats[\"proc_mem_tier1_cap_req\"]\n\n    print('End-to-end,|,full,,,,|,seqsel,,,,')\n    print('Model,|,Profile,Calc,Delta,GiB,|,Profile,Calc,Delta,GiB,')\n    max_error = 0\n    abs_error = 0\n    for model in kModels:\n      print(f'{model},', end='')\n      for mode in kModes:\n        p = data[model][mode]['profile_time']\n        a = data[model][mode]['actual_time']\n        d = 100*(1-a/p)\n        if math.fabs(d) > max_error:\n          max_error = math.fabs(d)\n        abs_error += math.fabs(d)\n        m = data[model][mode]['memory_req'] / (1024**3)\n        print(f'|,{p},{a:.2f},{d:.2f}%,{m:.2f},', end='')\n      print()\n    ave_error = abs_error / (len(kModels) * len(kModes))\n    print(f'Ave,,{ave_error:.2f}%')\n    print(f'Max,,{max_error:.2f}%')\n    print(',')\n\ncalculon.CommandLine.register(Validation)\n"
  },
  {
    "path": "calculon/memory.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nclass Memory:\n  \"\"\"Configuration for a memory.\n\n  Args:\n    cfg (dict): must provide 'GiB' (capacity in GiB), 'GBps' (bandwidth in\n      GB/s) and 'MB_efficiency', a list of (MB, efficiency) pairs. An\n      operation of at least MB megabytes runs at the paired efficiency.\n  \"\"\"\n\n  def __init__(self, cfg):\n    self._capacity = cfg['GiB'] * 1024**3\n    self._bandwidth = cfg['GBps'] * 1e9\n    self._efficiency = []\n    for mbytes, eff in cfg['MB_efficiency']:\n      assert 0 < eff <= 1.0\n      self._efficiency.append((mbytes * 1e6, eff))\n    # efficiency() returns the first threshold an operation reaches, so the\n    # table must be scanned from the largest threshold down. Sorting here\n    # (stable, so equal thresholds keep their config order) makes the lookup\n    # independent of the order used in the config file.\n    self._efficiency.sort(key=lambda entry: entry[0], reverse=True)\n\n  @property\n  def capacity(self):\n    \"\"\"Capacity in bytes.\"\"\"\n    return self._capacity\n\n  @property\n  def bandwidth(self):\n    \"\"\"Peak bandwidth in bytes per second.\"\"\"\n    return self._bandwidth\n\n  def efficiency(self, op_bytes):\n    \"\"\"Returns the bandwidth efficiency for an operation of op_bytes bytes.\"\"\"\n    for threshold, eff in self._efficiency:\n      if op_bytes >= threshold:\n        return eff\n    assert False, f'OP bytes {op_bytes} wasn\\'t covered'\n\n  def throughput(self, op_bytes):\n    \"\"\"Returns the effective bandwidth in bytes per second for op_bytes.\"\"\"\n    return self._bandwidth * self.efficiency(op_bytes)\n"
  },
  {
    "path": "calculon/network.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\n\nclass Network:\n  \"\"\"Configuration for a network.\"\"\"\n\n  kKeys = set(['bandwidth', 'efficiency', 'size', 'latency', 'ops',\n               'must_be_filled', 'processor_usage'])\n  kNetOps = set(['p2p', 'reduce_scatter', 'all_gather', 'all_reduce'])\n  kCollectives = set(['reduce_scatter', 'all_gather', 'all_reduce'])\n\n  class Op:\n    \"\"\"Per-operation size scalar and chunk offset.\"\"\"\n\n    def __init__(self, scalar, offset):\n      self.scalar = scalar\n      self.offset = offset\n\n  @staticmethod\n  def _parse_op(op, scalar, offset):\n    \"\"\"Validates one 'ops' entry and returns it as a Network.Op.\n\n    Collectives require an offset; p2p must not give one and gets 0.\n    \"\"\"\n    assert op in Network.kNetOps, f'Invalid network op: {op}'\n    assert scalar > 0.0, f'Invalid network scalar for {op}: {scalar}'\n    if op in Network.kCollectives:\n      assert offset is not None, f'Must give offset for {op}'\n      return Network.Op(scalar, offset)\n    assert offset is None, f'Can\\'t give offset for {op}'\n    return Network.Op(scalar, 0)\n\n  def __init__(self, cfg):\n    cfg_keys = set(cfg.keys())\n    # Report the missing/unexpected keys instead of a bare assertion.\n    assert Network.kKeys == cfg_keys, (\n      f'Network config keys mismatch: {sorted(Network.kKeys ^ cfg_keys)}')\n    self._bw = cfg['bandwidth'] * 1e9  # Specified in GB/s\n    assert self._bw > 0\n    self._eff = cfg['efficiency']\n    assert 0 < self._eff <= 1.0\n    self._size = cfg['size']\n    assert self._size >= 0\n    self._latency = cfg['latency']\n    assert self._latency >= 0, f'Invalid network latency: {self._latency}'\n    self._ops = {}\n    for op in cfg['ops']:\n      self._ops[op] = Network._parse_op(\n        op, cfg['ops'][op][0], cfg['ops'][op][1])\n    assert set(self._ops.keys()) == Network.kNetOps\n    self._must_be_filled = cfg['must_be_filled']\n    self._proc_usage = cfg['processor_usage']\n    assert 0.0 <= self._proc_usage < 1.0\n\n  @property\n  def size(self):\n    return self._size\n\n  @property\n  def must_be_filled(self):\n    return self._must_be_filled\n\n  @property\n  def processor_usage(self):\n    return self._proc_usage\n\n  def time(self, op, op_size, comm_size):\n    \"\"\" Computes the time taken for a network operation.\n\n    Args:\n      op (str)        : operation name\n      op_size (int)   : operation size in bytes\n      comm_size (int) : number of participants in operation\n\n    Returns:\n      time (float)    : latency plus the scaled op size divided by the\n                        efficiency-scaled bandwidth (bytes per second)\n    \"\"\"\n    # Validate the op first so an unknown op reports itself rather than\n    # tripping the comm_size checks or a KeyError on self._ops.\n    assert op in Network.kNetOps, f'Invalid network op: {op}'\n    if op in Network.kCollectives:\n      assert comm_size >= 2\n    else:\n      assert comm_size == 2\n    assert op_size >= 0\n\n    # Scales the op_size by the scalar\n    op_size *= self._ops[op].scalar\n\n    # Adds `offset` extra chunks, where a chunk is 1/comm_size of the op\n    chunk_size = 1 / comm_size * op_size\n    op_size += chunk_size * self._ops[op].offset\n\n    # Calculates time based on raw bandwidth, bandwidth efficiency, and latency\n    return self._latency + op_size / (self._bw * self._eff)\n"
  },
  {
    "path": "calculon/processor.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nclass Processor:\n  \"\"\"Configuration for a processing engine.\n\n  Args:\n    cfg (dict): maps each datatype name to a dict holding 'tflops' (peak\n      throughput in TFLOP/s) and 'gflops_efficiency', a list of\n      (GFLOP, efficiency) pairs whose GFLOP values strictly decrease.\n  \"\"\"\n\n  def __init__(self, cfg):\n    self._datatypes = {}\n    for datatype in cfg.keys():\n      self._datatypes[datatype] = {\n        'flops': cfg[datatype]['tflops'] * 1e12,\n        'efficiency': []\n      }\n      last = None\n      for gflops, eff in cfg[datatype]['gflops_efficiency']:\n        flops = gflops * 1e9\n        assert 0 < eff <= 1.0\n        # Compare against None explicitly: a threshold of 0 is falsy, and a\n        # truthiness test would silently skip the ordering check after it.\n        if last is not None:\n          assert flops < last, (\n            f'{datatype} efficiency thresholds must strictly decrease')\n        last = flops\n        self._datatypes[datatype]['efficiency'].append((flops, eff))\n\n  def flops(self, datatype):\n    \"\"\"Returns the peak FLOP/s for datatype.\"\"\"\n    assert datatype in self._datatypes, f'Unsupported type: {datatype}'\n    return self._datatypes[datatype]['flops']\n\n  def efficiency(self, datatype, op_flops):\n    \"\"\"Returns the efficiency for an operation of op_flops FLOPs.\n\n    Entries are scanned from the largest threshold down and the first one\n    that op_flops reaches is used.\n    \"\"\"\n    assert datatype in self._datatypes, f'Unsupported type: {datatype}'\n    for flops, eff in self._datatypes[datatype]['efficiency']:\n      if op_flops >= flops:\n        return eff\n    assert False, f'{op_flops} wasn\\'t covered in {datatype} efficiency curve'\n\n  def throughput(self, datatype, op_flops):\n    \"\"\"Returns the effective FLOP/s for an operation of op_flops FLOPs.\"\"\"\n    assert datatype in self._datatypes, f'Unsupported type: {datatype}'\n    return self.flops(datatype) * self.efficiency(datatype, op_flops)\n"
  },
  {
    "path": "calculon/system.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nfrom .memory import *\nfrom .network import *\nfrom .processor import *\n\nclass System:\n  \"\"\"Configuration for a system: compute engines, memories and networks.\"\"\"\n\n  TypeSizes = {\n    'float8'   : 1,\n    'float16'  : 2,\n    'float32'  : 4,\n    'bfloat16' : 2\n  }\n\n  @staticmethod\n  def supported_datatypes():\n    \"\"\"Returns the datatype names accepted by set_datatype().\"\"\"\n    return list(System.TypeSizes.keys())\n\n  def __init__(self, cfg):\n    self.cfg = cfg\n    self.matrix = Processor(cfg['matrix'])\n    self.vector = Processor(cfg['vector'])\n    # Must be chosen with set_datatype() before querying throughputs.\n    self.datatype = None\n\n    self.mem1 = Memory(cfg['mem1'])\n    self.mem2 = Memory(cfg['mem2'])\n\n    self.proc_mode = cfg['processing_mode']\n    assert self.proc_mode in ['roofline', 'no_overlap'], (\n      f'Unsupported processing mode: {self.proc_mode}')\n\n    self.networks = [Network(n) for n in cfg['networks']]\n\n  @property\n  def num_networks(self):\n    return len(self.networks)\n\n  def get_network(self, tier):\n    \"\"\"Returns the Network for a tier index in [0, num_networks).\"\"\"\n    # Reject negative tiers too: Python would otherwise index from the end.\n    assert 0 <= tier < len(self.networks), f'Bad network tier ID: {tier}'\n    return self.networks[tier]\n\n  def set_datatype(self, datatype):\n    assert datatype in System.TypeSizes, f'Unsupported data type: {datatype}'\n    self.datatype = datatype\n\n  def get_matrix_throughput(self, flops):\n    \"\"\"Effective matrix-engine FLOP/s for an op of `flops` FLOPs.\"\"\"\n    return self.matrix.throughput(self.datatype, flops)\n\n  def get_vector_throughput(self, flops):\n    \"\"\"Effective vector-engine FLOP/s for an op of `flops` FLOPs.\"\"\"\n    return self.vector.throughput(self.datatype, flops)\n\n  def get_mem1_throughput(self, size):\n    return self.mem1.throughput(size)\n\n  def get_mem2_throughput(self, size):\n    return self.mem2.throughput(size)\n\n  def compute_offload_time(self, size):\n    \"\"\"Time to move `size` bytes at mem2's effective throughput.\"\"\"\n    return size / self.mem2.throughput(size)\n\n  def get_processing_time(self, flops_time, mem_time):\n    \"\"\"Combines compute and memory time according to the processing mode.\n\n    'roofline' takes the max of the two; 'no_overlap' adds them.\n    \"\"\"\n    if self.proc_mode == 'roofline':\n      return max(flops_time, mem_time)\n    if self.proc_mode == 'no_overlap':\n      return flops_time + mem_time\n    # Only reachable if proc_mode was changed after construction; fail loudly\n    # instead of returning None into downstream arithmetic.\n    raise ValueError(f'Unsupported processing mode: {self.proc_mode}')\n"
  },
  {
    "path": "calculon/util.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nimport argparse\n\n\ndef human_format(value, v_type='base10', precision=3):\n  \"\"\"Formats value with a magnitude prefix (k, M, G, T, P, E) and unit.\n\n  Args:\n    value (number or None): value to format; None renders as 'n/a'.\n    v_type (str): one of 'base10', 'base2', 'bytes', 'bandwidth', 'flops'\n      or 'throughput'; selects the magnitude step (1000 or 1024) and the\n      unit suffix.\n    precision (int): digits after the decimal point.\n\n  Returns:\n    str: e.g. human_format(1536, 'bytes') == '1.500 kiB'.\n\n  Raises:\n    ValueError: if v_type is not supported.\n  \"\"\"\n  formats = {\n    'base10': (1000, ''),\n    'base2': (1024, ''),\n    'bytes': (1024, 'iB'),\n    'bandwidth': (1000, 'B/s'),\n    'flops': (1000, 'Ops'),\n    'throughput': (1000, 'Op/s'),\n  }\n  if v_type not in formats:\n    raise ValueError(\n      f\"Type value should be 'base10', 'base2', 'bytes', 'flops', \"\n      f\"'bandwidth', or 'throughput'. You gave {v_type}\")\n  step, suffix = formats[v_type]\n  labels = ['', 'k', 'M', 'G', 'T', 'P', 'E']\n  if value is None:\n    return f'n/a {suffix}'\n  sign = 1 if value >= 0 else -1\n  abs_value = abs(value)\n  index = 0\n  # Stop at the largest prefix; without this bound a value of step**7 or more\n  # would advance index past the end of labels and raise IndexError.\n  while abs_value >= step and index < len(labels) - 1:\n    abs_value /= step\n    index += 1\n  return '{0:.{1}f} {2}{3}'.format(sign * abs_value, precision, labels[index],\n                                    suffix)\n\n\ndef pick(en, a, b):\n  \"\"\"Returns a if en is truthy, otherwise b.\"\"\"\n  if en:\n    return a\n  return b\n\n\ndef arg_true_false_all(arg):\n  \"\"\"argparse type that parses a true/false/all flag into a list of bools.\n\n  Returns [True], [False], or [False, True] for 'both'/'all'/'*'; raises\n  argparse.ArgumentTypeError for anything else.\n  \"\"\"\n  trues = ['t', 'true', 'T', 'True', '1']\n  falses = ['f', 'false', 'F', 'False', '0']\n  alls = ['both', 'all', '*']\n  if arg in trues:\n    return [True]\n  elif arg in falses:\n    return [False]\n  elif arg in alls:\n    return [False, True]\n  else:\n    raise argparse.ArgumentTypeError(f'Invalid true/false/all: {arg}')\n"
  },
  {
    "path": "calculon/version.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nimport calculon\n\n\nclass Version(calculon.CommandLine):\n  \"\"\"Subcommand that logs the calculon package version.\"\"\"\n\n  NAME = 'version'\n  ALIASES = ['v']\n\n  @staticmethod\n  def create_parser(subparser):\n    \"\"\"Adds the 'version' subcommand (alias 'v') to subparser.\"\"\"\n    parser = subparser.add_parser(\n      Version.NAME,\n      aliases=Version.ALIASES,\n      help='show the version and exit')\n    parser.set_defaults(func=Version.run_command)\n\n  @staticmethod\n  def run_command(logger, args):\n    \"\"\"Logs calculon.__version__, which is set in calculon/__init__.py.\"\"\"\n    version_string = calculon.__version__\n    logger.info(version_string)\n\n\ncalculon.CommandLine.register(Version)\n"
  },
  {
    "path": "examples/3072_t4_p64_d12_mbs4_full.json",
    "content": "{\n  \"num_procs\": 3072,\n  \"tensor_par\": 4,\n  \"pipeline_par\": 64,\n  \"data_par\": 12,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 3072,\n  \"microbatch_size\": 4,\n  \"datatype\": \"float16\",\n  \"fused_activation\": true,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"full\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": true,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "models/anthropic-52B.json",
    "content": "{\n  \"hidden\": 8192,\n  \"feedforward\": 32768,\n  \"seq_size\": 8192,\n  \"attn_heads\": 64,\n  \"attn_size\": 128,\n  \"num_blocks\": 64\n}\n"
  },
  {
    "path": "models/chinchilla.json",
    "content": "{\n  \"hidden\": 8192,\n  \"feedforward\": 32768,\n  \"seq_size\": 2048,\n  \"attn_heads\": 64,\n  \"attn_size\": 128,\n  \"num_blocks\": 80\n}\n"
  },
  {
    "path": "models/gopher-280B.json",
    "content": "{\n  \"hidden\": 16384,\n  \"feedforward\": 65536,\n  \"seq_size\": 2048,\n  \"attn_heads\": 128,\n  \"attn_size\": 128,\n  \"num_blocks\": 80\n}\n"
  },
  {
    "path": "models/gpt3-13B.json",
    "content": "{\n  \"hidden\": 5140,\n  \"feedforward\": 20560,\n  \"seq_size\": 2048,\n  \"attn_heads\": 40,\n  \"attn_size\": 128,\n  \"num_blocks\": 40\n}\n"
  },
  {
    "path": "models/gpt3-175B.json",
    "content": "{\n  \"hidden\": 12288,\n  \"feedforward\": 49152,\n  \"seq_size\": 2048,\n  \"attn_heads\": 96,\n  \"attn_size\": 128,\n  \"num_blocks\": 96\n}\n"
  },
  {
    "path": "models/lamda.json",
    "content": "{\n  \"hidden\": 8192,\n  \"feedforward\": 65536,\n  \"seq_size\": 2048,\n  \"attn_heads\": 128,\n  \"attn_size\": 128,\n  \"num_blocks\": 64\n}\n"
  },
  {
    "path": "models/megatron-126M.json",
    "content": "{\n  \"hidden\": 768,\n  \"feedforward\": 3072,\n  \"seq_size\": 2048,\n  \"attn_heads\": 16,\n  \"attn_size\": 48,\n  \"num_blocks\": 12\n}\n"
  },
  {
    "path": "models/megatron-1T.json",
    "content": "{\n  \"hidden\": 25600,\n  \"feedforward\": 102400,\n  \"seq_size\": 2048,\n  \"attn_heads\": 160,\n  \"attn_size\": 160,\n  \"num_blocks\": 128\n}\n"
  },
  {
    "path": "models/megatron-22B.json",
    "content": "{\n  \"hidden\": 6144,\n  \"feedforward\": 24576,\n  \"seq_size\": 2048,\n  \"attn_heads\": 64,\n  \"attn_size\": 96,\n  \"num_blocks\": 48\n}\n"
  },
  {
    "path": "models/megatron-40B.json",
    "content": "{\n  \"hidden\": 8192,\n  \"feedforward\": 32768,\n  \"seq_size\": 2048,\n  \"attn_heads\": 64,\n  \"attn_size\": 128,\n  \"num_blocks\": 48\n}\n"
  },
  {
    "path": "models/megatron-5B.json",
    "content": "{\n  \"hidden\": 4096,\n  \"feedforward\": 16384,\n  \"seq_size\": 2048,\n  \"attn_heads\": 32,\n  \"attn_size\": 128,\n  \"num_blocks\": 24\n}\n"
  },
  {
    "path": "models/palm-540B.json",
    "content": "{\n  \"hidden\": 18432,\n  \"feedforward\": 73728,\n  \"seq_size\": 2048,\n  \"attn_heads\": 48,\n  \"attn_size\": 256,\n  \"num_blocks\": 118\n}\n"
  },
  {
    "path": "models/turing-530B.json",
    "content": "{\n  \"hidden\": 20480,\n  \"feedforward\": 81920,\n  \"seq_size\": 2048,\n  \"attn_heads\": 128,\n  \"attn_size\": 160,\n  \"num_blocks\": 105\n}\n"
  },
  {
    "path": "pylintrc",
    "content": "[MESSAGES CONTROL]\ndisable=locally-disabled,\n\ttoo-many-branches,\n\ttoo-many-instance-attributes,\n\ttoo-many-return-statements,\n\tduplicate-code,\n\ttoo-many-arguments,\n\tno-method-argument\n\n[FORMAT]\nindent-string='  '\nindent-after-paren=2\n\n[DESIGN]\nmin-public-methods=0\nmax-public-methods=9999"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\n    \"setuptools>=42\",\n    \"wheel\"\n]\nbuild-backend = \"setuptools.build_meta\"\n"
  },
  {
    "path": "scripts/3dplot.py",
    "content": "#!/usr/bin/env python3\n\nimport argparse\nimport calculon\nimport matplotlib\nmatplotlib.use('TkAgg')\nimport matplotlib.pyplot as plt\nimport matplotlib.ticker as tkr\nimport numpy as np\n\n\ndef index_formatter(values):\n  \"\"\"Returns a tick formatter that labels index positions with values.\n\n  The surface is drawn against list indices (0, 1, 2, ...), so each tick\n  must be mapped back to the parallelism value stored at that index. Ticks\n  that do not land on a valid index get an empty label.\n  \"\"\"\n  @tkr.FuncFormatter\n  def formatter(x, pos):\n    idx = int(round(x))\n    if abs(x - idx) > 1e-6 or not 0 <= idx < len(values):\n      return ''\n    return str(values[idx])\n  return formatter\n\n\ndef main(args):\n  \"\"\"Plots sample rate as a 3D surface over tensor/pipeline parallelism.\"\"\"\n  data = calculon.io.read_json_file(args.stats)\n\n  # Turns the keys back into integers\n  ndata = {}\n  for tp in data.keys():\n    tpi = int(tp)\n    ndata[tpi] = {}\n    for pp in data[tp].keys():\n      ppi = int(pp)\n      ndata[tpi][ppi] = data[tp][pp]\n  data = ndata\n  tps = sorted(list(data.keys()))\n  pps = set()\n  for tp in data.keys():\n    for pp in data[tp].keys():\n      pps.add(pp)\n  pps = sorted(list(pps))\n  assert len(tps) > 1, f'len(tps)={len(tps)} can\\'t plot'\n  assert len(pps) > 1, f'len(pps)={len(pps)} can\\'t plot'\n\n  # Gathers data\n  fdata = np.full((len(pps), len(tps)), float('NaN'))\n  for tp in data.keys():\n    for pp in data[tp].keys():\n      if 'stats' in data[tp][pp]:\n        v = data[tp][pp]['stats']['sample_rate']\n        fdata[pps.index(pp)][tps.index(tp)] = v\n        print(f'{tp},{pp} is {v}')\n      else:\n        print(f'{tp},{pp} has none')\n\n  fig = plt.figure()\n  ax = fig.add_subplot(111, projection='3d')\n  X, Y = np.meshgrid(list(range(len(tps))), list(range(len(pps))))\n  ax.plot_surface(X, Y, fdata, rstride=1, cstride=1,\n                  cmap='rainbow',\n                  edgecolor='none')\n  ax.set_xlabel('Tensor Parallelism')\n  ax.set_ylabel('Pipeline Parallelism')\n  ax.set_zlabel('Sample Rate (s/sec)')\n  if args.title:\n    ax.set_title(args.title)\n  ax.view_init(20, 180+25)\n  # Label ticks with the actual TP/PP values rather than assuming they are\n  # consecutive powers of two starting at 1.\n  ax.xaxis.set_major_formatter(index_formatter(tps))\n  ax.yaxis.set_major_formatter(index_formatter(pps))\n  fig.tight_layout()\n  plt.show()\n\n\n\nif __name__ == '__main__':\n  ap = argparse.ArgumentParser()\n  ap.add_argument('stats', type=str,\n                  help='File path to stats input')\n  ap.add_argument('-t', '--title', type=str, default=None,\n                  help='Title of plot')\n  args = ap.parse_args()\n  main(args)\n"
  },
  {
    "path": "scripts/find_huge.py",
    "content": "#!/usr/bin/env python3\n\nimport sys\n\nimport numpy as np\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\nimport tol_colors as tc\n\n########## Utils ##########\ndef transformer_attn_size(hidden, layers, attn_size_step=32):\n  return step_rounder(hidden / layers, attn_size_step)\n\ndef transformer_num_parameters(hidden, layers, attn_size_step=32):\n  attn_heads = layers\n  attn_size = transformer_attn_size(hidden, layers, attn_size_step)\n  mlp_params = 8 * layers * hidden **2\n  attn_params = 4 * layers * hidden * attn_heads * attn_size\n  return mlp_params + attn_params\n  #return 12 * layers * hidden **2\n\ndef transformer_t_params(hidden, layers):\n  return transformer_num_parameters(hidden, layers) / 10**12\n\ndef step_rounder(layer, step=1):\n  return np.round(layer/step) * step\n\ndef model_ratio(hidden, layers):\n  return hidden / layers\n\ndef human_format(value, v_type='base10', precision=3):\n  step = 1\n  suffix = ''\n  if v_type == 'base10':\n    step = 1000\n    suffix = ''\n  elif v_type == 'base2':\n    step = 1024\n    suffix = ''\n  elif v_type == 'bytes':\n    step = 1024\n    suffix = 'iB'\n  elif v_type == 'bandwidth':\n    step = 1000\n    suffix = 'B/s'\n  elif v_type == 'flops':\n    step = 1000\n    suffix = 'Ops'\n  elif v_type == 'throughput':\n    step = 1000\n    suffix = 'Op/s'\n  else:\n    raise ValueError(\n      f\"Type value should be 'base10', 'base2', 'bytes', 'flops', \"\n      f\"'bandwidth', or 'throughput'. You gave {v_type}\")\n  labels = ['', 'k', 'M', 'G', 'T', 'P', 'E']\n  index = 0\n  if value is not None:\n    abs_value = abs(value)\n    if value >= 0:\n      sign = 1\n    else:\n      sign = -1\n    # At most len(labels) - 1 divisions, so index never runs past 'E'.\n    for _ in labels[1:]:\n      if abs_value >= step:\n        abs_value /= step\n        index += 1\n      else:\n        break\n    value = sign * abs_value\n    return \"{0:.{1}f}{2}{3}\".format(value, precision, labels[index], suffix)\n  else:\n    return \"n/a {1}{2}\".format(value, labels[0], suffix)\n\n########## Scale rules with ratio ##########\ndef ratio_layer_scale(hidden, ratio=128, step=4):\n  return step_rounder(hidden/ratio, step=step)\ndef ratio_hidden_scale(layers, ratio=128, step=4096):\n  return step_rounder(layers * ratio, step=step)\ndef ratio_param_layer_scale(layers, ratio=128, step=4096):\n  return transformer_num_parameters(\n    ratio_hidden_scale(layers, ratio=ratio, step=step), layers)\ndef ratio_param_hidden_scale(hidden, ratio=128, step=4):\n  return transformer_num_parameters(\n    hidden, ratio_layer_scale(hidden, ratio=ratio, step=step))\n\n\n\nhidden_step = 1024\nlayer_step = 32\nhiddens = [x for x in range(24*1024, 8192*24 + 1, hidden_step)]\nlayers = [x for x in range(128, 576 + 1, layer_step)]\nslope = (320-192) / ((512-128)/layer_step)\ny_intercept = 192\ntargets = [slope * x + y_intercept for x in range(len(layers))]\n#targets = [200 for x in range(len(layers))]\nhiddens = np.asarray(hiddens)\nlayers = np.asarray(layers)\nparams_grid = np.zeros((hiddens.shape[0], layers.shape[0]), dtype=\"float\")\nratio_grid = np.zeros((hiddens.shape[0], layers.shape[0]), dtype=\"float\")\ntarget_ratio_grid = np.zeros((hiddens.shape[0], layers.shape[0]), dtype=\"float\")\nfor row, h in enumerate(hiddens):\n  for col, l in enumerate(layers):\n    params_grid[row][col] = transformer_num_parameters(h, l)\n    ratio = model_ratio(h, l)\n    ratio_grid[row][col] = ratio\n    target_ratio_grid[row][col] = ratio / targets[col]\n\n\nfig = plt.figure(figsize=(16, 16), dpi=200)\nax = fig.add_subplot(1, 1, 1)\n\nim = ax.imshow(target_ratio_grid, cmap=tc.tol_cmap('BuRd'),\n               vmin=.5, vmax=1.5, origin='lower')#, aspect=0.8)\nax.set_xlabel('# of blocks')\nax.set_ylabel('Hidden size')\n\n# Show all ticks and label them with the respective list entries\nax.set_yticks(np.arange(hiddens.shape[0]))\nax.set_xticks(np.arange(layers.shape[0]))\nax.set_yticklabels(hiddens)\nax.set_xticklabels(layers)\n\n# Rotate the tick labels and set their alignment.\nplt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\",\n         rotation_mode=\"anchor\")\n\n# Loop over data dimensions and create text annotations.\nprint('name,hidden,feedforward,seq_size,attn_heads,attn_size,num_blocks,gbs,ratio')\nfor col, l in enumerate(layers):\n  best_val = 9999\n  best_row = None\n  for row, h in enumerate(hiddens):\n    val = abs(target_ratio_grid[row][col] - 1)\n    if val < best_val:\n      best_val = val\n      best_row = row\n  for row, h in enumerate(hiddens):\n    result = human_format(params_grid[row][col], precision=0)\n    result += \"\\n\"\n    result += human_format(ratio_grid[row][col], precision=0)\n    weight = 'bold' if row == best_row else None\n    text = ax.text(col, row, result, ha=\"center\", va=\"center\", color=\"k\", size=8, weight=weight)\n    if row == best_row:\n      attn_size = int(step_rounder(hiddens[row] / layers[col]))\n      params = human_format(transformer_num_parameters(hiddens[row], layers[col]), precision=0)\n      ratio = hiddens[row] / layers[col]\n      print(f'{params},{hiddens[row]},{hiddens[row]*4},8192,{layers[col]},{attn_size},{layers[col]},3072,{ratio}')\n\n# The script currently stops after printing the CSV; the plot finishing and\n# saving below (huge.png) is disabled. Remove this exit to produce the image.\n# sys.exit is used because the bare exit() builtin is only meant for the\n# interactive interpreter and is absent when Python runs without `site`.\nsys.exit(0)\nax.spines[:].set_visible(False)\nax.set_xticks(np.arange(params_grid.shape[1]+1)-.5, minor=True)\nax.set_yticks(np.arange(params_grid.shape[0]+1)-.5, minor=True)\nax.grid(which=\"minor\", color=\"w\", linestyle='-', linewidth=2)\nax.tick_params(which=\"minor\", bottom=False, left=False)\nax.set_title(\"Number of parameters in trillions, and model ratio, colored by ratio\")\n\nfig.tight_layout()\nfig.savefig('huge.png')\nplt.close(fig)\n"
  },
  {
    "path": "scripts/heatmap.py",
    "content": "#!/usr/bin/env python3\n\nimport argparse\nimport calculon\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\nimport matplotlib.ticker as tkr\nimport numpy as np\nimport tol_colors as tc\n\n\ndef main(args):\n  data = calculon.io.read_json_file(args.stats)\n\n  # Turns the keys back into integers\n  ndata = {}\n  for tp in data.keys():\n    tpi = int(tp)\n    ndata[tpi] = {}\n    for pp in data[tp].keys():\n      ppi = int(pp)\n      ndata[tpi][ppi] = data[tp][pp]\n  data = ndata\n  tps = sorted(list(data.keys()))\n  pps = set()\n  for tp in data.keys():\n    for pp in data[tp].keys():\n      pps.add(pp)\n  pps = sorted(list(pps))\n  assert len(tps) > 1, f'len(tps)={len(tps)} can\\'t plot'\n  assert len(pps) > 1, f'len(pps)={len(pps)} can\\'t plot'\n\n  # Gathers data\n  fdata = np.full((len(tps), len(pps)), float('NaN'))\n  for tp in data.keys():\n    for pp in data[tp].keys():\n      if 'stats' in data[tp][pp]:\n        v = data[tp][pp]['stats']['sample_rate']\n        fdata[tps.index(tp)][pps.index(pp)] = v\n        print(f'{tp},{pp} is {v}')\n      else:\n        print(f'{tp},{pp} has none')\n\n  # Determines range\n  minf = min(map(min, fdata))\n  maxf = max(map(max, fdata))\n  black_threshold = minf + (maxf - minf) * 0.30\n  print(f'min={minf} max={maxf} thres={black_threshold}')\n\n  # Creates the plot\n  fig = plt.figure()\n  ax = fig.add_subplot(1, 1, 1)\n  ax.imshow(fdata, origin='lower', cmap='hot')#, linewidth=0.5)\n  ax.set_xticks(np.arange(len(pps)), labels=pps)\n  ax.set_xlabel('Pipeline Parallelism')\n  ax.set_yticks(np.arange(len(tps)), labels=tps)\n  ax.set_ylabel('Tensor Parallelism')\n  for tp in tps:\n    for pp in pps:\n      perf = fdata[tps.index(tp), pps.index(pp)]\n      color = 'black' if perf > black_threshold else 'white'\n      perf = f'{perf:.1f}'\n      text = f'{perf}'\n      ax.text(pps.index(pp), tps.index(tp), text, ha='center', va='center',\n              color=color)\n  if args.title:\n  
  ax.set_title(args.title)\n  print(f'writing {args.output}')\n  fig.tight_layout()\n  fig.savefig(args.output)\n  plt.close(fig)\n\n\nif __name__ == '__main__':\n  ap = argparse.ArgumentParser()\n  ap.add_argument('stats', type=str,\n                  help='File path to stats input')\n  ap.add_argument('output', type=str,\n                  help='Output plot file')\n  ap.add_argument('-t', '--title', type=str, default=None,\n                  help='Title of plot')\n  args = ap.parse_args()\n  main(args)\n"
  },
  {
    "path": "scripts/install_hooks.sh",
    "content": "#!/bin/bash\n\nset -e\n\n# Pre-commit hook\ncat > .git/hooks/pre-commit <<-EOF\n#!/bin/bash\necho -n \"Testing...\"\nif ! make test &> /dev/null; then\n    echo \" failed :(\"\n    exit -1\nfi\nEOF\n\nchmod a+x .git/hooks/pre-commit\n"
  },
  {
    "path": "scripts/json_to_csv.py",
    "content": "#!/usr/bin/env python3\n\nimport argparse\nimport calculon\nimport gzip\nimport json\nimport sys\n\n\ndef main(args):\n  j = calculon.read_json_file(args.json_file)\n\n  header_entries = []\n  for category in j['0']:\n    for key in j['0'][category]:\n      header_entries.append((category, key))\n\n  opener = gzip.open if args.csv_file.endswith('.gz') else open\n  with opener(args.csv_file, 'wb') as fd:\n    # Header\n    fd.write(bytes(',', 'utf-8'))\n    for _, key in header_entries:\n      fd.write(bytes(f'{key},', 'utf-8'))\n    fd.write(bytes(',\\n', 'utf-8'))\n\n    # Rows\n    for entry in j.keys():\n      fd.write(bytes(f'{entry},', 'utf-8'))\n      for category, key in header_entries:\n        v = j[entry][category][key]\n        fd.write(bytes(f'{v},', 'utf-8'))\n      fd.write(bytes(',\\n', 'utf-8'))\n\nif __name__ == '__main__':\n  ap = argparse.ArgumentParser()\n  ap.add_argument('json_file', help='input JSON file')\n  ap.add_argument('csv_file', help='output CSV file')\n  sys.exit(main(ap.parse_args()))\n"
  },
  {
    "path": "setup.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n\nimport codecs\nimport re\nimport os\nimport sys\n\ntry:\n  from setuptools import setup\nexcept:\n  print('please install setuptools via pip:')\n  print('  pip3 install setuptools')\n  sys.exit(-1)\n\ndef find_version(*file_paths):\n  version_file = codecs.open(os.path.join(os.path.abspath(\n    os.path.dirname(__file__)), *file_paths), 'r').read()\n  version_match = re.search(r\"^__version__ = ['\\\"]([^'\\\"]*)['\\\"]\",\n                            version_file, re.M)\n  if version_match:\n    return version_match.group(1)\n  raise RuntimeError(\"Unable to find version string.\")\n\n\nsetup(\n  name='calculon',\n  version=find_version('calculon', '__init__.py'),\n  description='Co-design for large scale parallel applications',\n  author='Michael Isaev',\n  author_email='michael.v.isaev@gmail.com',\n  license='Apache 2',\n  url='http://github.com/calculon-ai/calculon',\n  packages=['calculon', 'calculon.llm'],\n  scripts=['bin/calculon'],\n  install_requires=[],\n)\n"
  },
  {
    "path": "systems/a100_80e.json",
    "content": "{\n  \"matrix\" : {\n    \"float16\": {\n      \"tflops\": 312,\n      \"gflops_efficiency\": [\n        [128, 0.99],\n        [16, 0.9],\n        [1, 0.6],\n        [0, 0.1]\n      ]\n    }\n  },\n  \"vector\": {\n    \"float16\": {\n      \"tflops\": 78,\n      \"gflops_efficiency\": [\n        [16, 0.95],\n        [1, 0.5],\n        [0, 0.1]\n      ]\n    }\n  },\n  \"mem1\": {\n    \"GiB\": 80000000000,\n    \"GBps\": 2048,\n    \"MB_efficiency\": [\n      [100, 0.95],\n      [10, 0.90],\n      [1, 0.7],\n      [0, 0.3]\n    ]\n  },\n  \"mem2\": {\n    \"GiB\": 512,\n    \"GBps\": 32,\n    \"MB_efficiency\": [\n      [100, 0.95],\n      [10, 0.9],\n      [1, 0.7],\n      [0, 0.3]\n    ]\n  },\n  \"processing_mode\": \"no_overlap\",\n  \"networks\": [\n    {\n      \"bandwidth\": 300,\n      \"efficiency\": 0.65,\n      \"size\": 8,\n      \"latency\": 0.00001,\n      \"ops\": {\n        \"p2p\": [1.0, null],\n        \"reduce_scatter\": [1.5, -1],\n        \"all_gather\": [1.5, -1],\n        \"all_reduce\": [2.0, -1]\n      },\n      \"must_be_filled\": true,\n      \"processor_usage\": 0.15\n    },{\n      \"bandwidth\": 25,\n      \"efficiency\": 0.9,\n      \"size\": 65536,\n      \"latency\": 0.00002,\n      \"ops\": {\n        \"p2p\": [1.0, null],\n        \"reduce_scatter\": [1.0, 0],\n        \"all_gather\": [1.0, 0],\n        \"all_reduce\": [1.0, 0]\n      },\n      \"must_be_filled\": false,\n      \"processor_usage\": 0.02\n    }\n  ]\n}\n"
  },
  {
    "path": "systems/a100_80g.json",
    "content": "{\n  \"matrix\" : {\n    \"float16\": {\n      \"tflops\": 312,\n      \"gflops_efficiency\": [\n        [128, 0.95],\n        [16, 0.9],\n        [1, 0.6],\n        [0, 0.1]\n      ]\n    }\n  },\n  \"vector\": {\n    \"float16\": {\n      \"tflops\": 78,\n      \"gflops_efficiency\": [\n        [16, 0.95],\n        [1, 0.5],\n        [0, 0.1]\n      ]\n    }\n  },\n  \"mem1\": {\n    \"GiB\": 80,\n    \"GBps\": 2048,\n    \"MB_efficiency\": [\n      [100, 0.90],\n      [10, 0.75],\n      [1, 0.6],\n      [0, 0.3]\n    ]\n  },\n  \"mem2\": {\n    \"GiB\": 512,\n    \"GBps\": 32,\n    \"MB_efficiency\": [\n      [100, 0.95],\n      [10, 0.9],\n      [1, 0.7],\n      [0, 0.3]\n    ]\n  },\n  \"processing_mode\": \"no_overlap\",\n  \"networks\": [\n    {\n      \"bandwidth\": 300,\n      \"efficiency\": 0.65,\n      \"size\": 8,\n      \"latency\": 0.00001,\n      \"ops\": {\n        \"p2p\": [1.0, null],\n        \"reduce_scatter\": [1.5, -1],\n        \"all_gather\": [1.5, -1],\n        \"all_reduce\": [2.0, -1]\n      },\n      \"must_be_filled\": true,\n      \"processor_usage\": 0.15\n    },{\n      \"bandwidth\": 25,\n      \"efficiency\": 0.9,\n      \"size\": 65536,\n      \"latency\": 0.00002,\n      \"ops\": {\n        \"p2p\": [1.0, null],\n        \"reduce_scatter\": [1.0, 0],\n        \"all_gather\": [1.0, 0],\n        \"all_reduce\": [1.0, 0]\n      },\n      \"must_be_filled\": false,\n      \"processor_usage\": 0.02\n    }\n  ]\n}\n"
  },
  {
    "path": "systems/h100_80g_nvl8.json",
    "content": "{\n  \"matrix\": {\n    \"float8\": {\n      \"tflops\": 2000,\n      \"gflops_efficiency\": [\n        [128, 0.95],\n        [16, 0.9],\n        [1, 0.6],\n        [0, 0.1]\n      ]\n    },\n    \"float16\": {\n      \"tflops\": 1000,\n      \"gflops_efficiency\": [\n        [128, 0.95],\n        [16, 0.9],\n        [1, 0.6],\n        [0, 0.1]\n      ]\n    }\n  },\n  \"vector\": {\n    \"float8\": {\n      \"tflops\": 120,\n      \"gflops_efficiency\": [\n        [16, 0.95],\n        [1, 0.5],\n        [0, 0.1]\n      ]\n    },\n    \"float16\": {\n      \"tflops\": 120,\n      \"gflops_efficiency\": [\n        [16, 0.95],\n        [1, 0.5],\n        [0, 0.1]\n      ]\n    }\n  },\n  \"mem1\": {\n    \"GiB\": 80,\n    \"GBps\": 3072,\n    \"MB_efficiency\": [\n      [100, 0.90],\n      [10, 0.75],\n      [1, 0.6],\n      [0, 0.3]\n    ]\n  },\n  \"mem2\": {\n    \"GiB\": 512,\n    \"GBps\": 450,\n    \"MB_efficiency\": [\n      [100, 0.95],\n      [10, 0.9],\n      [1, 0.7],\n      [0, 0.3]\n    ]\n  },\n  \"processing_mode\": \"no_overlap\",\n  \"networks\": [\n    {\n      \"bandwidth\": 450,\n      \"efficiency\": 0.65,\n      \"size\": 8,\n      \"latency\": 0.00001,\n      \"ops\": {\n        \"p2p\": [1.0, null],\n        \"reduce_scatter\": [1.0, 0],\n        \"all_gather\": [1.0, 0],\n        \"all_reduce\": [1.0, 1]\n      },\n      \"must_be_filled\": true,\n      \"processor_usage\": 0.15\n    },{\n      \"bandwidth\": 50,\n      \"efficiency\": 0.9,\n      \"size\": 65536,\n      \"latency\": 0.00002,\n      \"ops\": {\n        \"p2p\": [1.0, null],\n        \"reduce_scatter\": [1.0, 0],\n        \"all_gather\": [1.0, 0],\n        \"all_reduce\": [1.0, 0]\n      },\n      \"must_be_filled\": false,\n      \"processor_usage\": 0.02\n    }\n  ]\n}\n"
  },
  {
    "path": "test/__init__.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\n"
  },
  {
    "path": "test/test.sh",
    "content": "#!/bin/bash\n\nset -e\n\nexport PYTHONPATH=.\n\n# CLI interface infrastructure\necho -e \"### Testing top level --help\"\n./bin/calculon --help > /dev/null\ncommands=$(./bin/calculon --help | head -n 2 | tail -n 1 | tr '{' ' ' | tr '}' ' ' | tr ',' ' ')\nfor command in $commands; do\n    if [ $command == 'v' ] || [ $command == 'version' ]; then\n\techo -e \"### Testing \\\"$command\\\"\"\n\t./bin/calculon $command\n    else\n\techo -e \"### Testing \\\"$command\\\" --help\"\n\t./bin/calculon $command --help > /dev/null\n    fi\ndone\necho -e \"\\n\\n\"\n\n# Model size calculations\necho -e \"### Testing llm-parameter-calculator\"\nfor model in models/*json; do\n    ./bin/calculon llm-parameter-calculator -a 15 $model\ndone\necho -e \"\\n\\n\"\n\n# Model tests\necho -e \"### Testing llm\"\nfor model in models/*json; do\n    echo $model\n    ./bin/calculon llm $model examples/3072_t4_p64_d12_mbs4_full.json systems/a100_80e.json - > /dev/null\n    ./bin/calculon llm $model examples/3072_t4_p64_d12_mbs4_full.json systems/a100_80e.json /tmp/calculon_stats.json -p /tmp/calculon_peers.json\ndone\necho -e \"\\n\\n\"\n\n# Llm validation\necho -e \"### Testing llm-validation\"\n./bin/calculon lv -v\necho -e \"\\n\\n\"\n\n# Llm optimal execution\necho -e \"### Testing llm-optimal-execution (float16) (using -f)\"\n./bin/calculon loe models/turing-530B.json 5128 2520 float16 systems/h100_80g_nvl8.json /tmp/calculon_530B_fp16.json -t 3 -f False --no-tp-overlap\necho -e \"\\n\"\n\necho -e \"### Testing llm-optimal-execution (float8) (using -m)\"\n./bin/calculon loe models/turing-530B.json 5128 2520 float8 systems/h100_80g_nvl8.json /tmp/calculon_530B_fp8.csv.gz -t 10 -m\necho -e \"\\n\\n\"\n\n# Llm all executions\necho -e \"### Testing llm-all-executions (float8)\"\n./bin/calculon lae models/turing-530B.json 5128 2520 float8 systems/h100_80g_nvl8.json /tmp/calculon_530B_fp8_all.csv.gz\necho -e \"\\n\\n\"\n\n"
  },
  {
    "path": "test/test_json_write_read.py",
    "content": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n *  https://www.apache.org/licenses/LICENSE-2.0\n *\n * See the NOTICE file distributed with this work for additional information\n * regarding copyright ownership.\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n\"\"\"\nimport calculon\nimport os\nimport tempfile\nimport unittest\n\n\nclass JsonWriteReadTestCase(unittest.TestCase):\n  def test_json_read_write(self):\n    jd = {\n      'a': 1239,\n      'hi': {\n        '34': 'world',\n        'ugh': 77,\n        '1': 'hello world world world world world world world world world world'\n      }\n    }\n\n    _, reg_file = tempfile.mkstemp(suffix='.json')\n    _, gz_file = tempfile.mkstemp(suffix='.json.gz')\n    _, foo_file = tempfile.mkstemp(suffix='.json.foo')\n    _, bar_file = tempfile.mkstemp(suffix='.bar.gz')\n    os.remove(reg_file)\n    os.remove(gz_file)\n    os.remove(foo_file)\n    os.remove(bar_file)\n\n    self.assertTrue(calculon.is_json_extension(reg_file))\n    self.assertTrue(calculon.is_json_extension(gz_file))\n    self.assertFalse(calculon.is_json_extension(foo_file))\n    self.assertFalse(calculon.is_json_extension(bar_file))\n\n    self.assertFalse(os.path.exists(reg_file))\n    self.assertFalse(os.path.exists(gz_file))\n\n    calculon.io.write_json_file(jd, reg_file)\n    calculon.io.write_json_file(jd, gz_file)\n\n    self.assertTrue(os.path.exists(reg_file))\n    self.assertTrue(os.path.exists(gz_file))\n\n    reg_size = os.path.getsize(reg_file)\n    gz_size = os.path.getsize(gz_file)\n    
self.assertTrue(reg_size > 0)\n    self.assertTrue(reg_size > gz_size)\n    self.assertTrue(gz_size > 0)\n\n    reg_jd = calculon.io.read_json_file(reg_file)\n    gz_jd = calculon.io.read_json_file(gz_file)\n\n    self.assertEqual(reg_jd, jd)\n    self.assertEqual(gz_jd, jd)\n\n    os.remove(reg_file)\n    os.remove(gz_file)\n"
  },
  {
    "path": "validation/seqsel/fig1/gpt3-175B_none.json",
    "content": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 64,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"none\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig1/gpt3-175B_seqsel.json",
    "content": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 64,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig1/megatron-1T_none.json",
    "content": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 512,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"none\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig1/megatron-1T_seqsel.json",
    "content": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 512,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\":  false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig1/megatron-22B_none.json",
    "content": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 4,\n  \"microbatch_size\": 4,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"none\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig1/megatron-22B_seqsel.json",
    "content": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 4,\n  \"microbatch_size\": 4,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig1/turing-530B_none.json",
    "content": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 280,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"none\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig1/turing-530B_seqsel.json",
    "content": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 280,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/gpt3-175B_full.json",
    "content": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 64,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"full\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/gpt3-175B_none.json",
    "content": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 64,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"none\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/gpt3-175B_sel.json",
    "content": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 64,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/gpt3-175B_seq.json",
    "content": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 64,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"none\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/gpt3-175B_seqsel.json",
    "content": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 64,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-1T_full.json",
    "content": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 512,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"full\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-1T_none.json",
    "content": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 512,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"none\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-1T_sel.json",
    "content": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 512,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-1T_seq.json",
    "content": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 512,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"none\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-1T_seqsel.json",
    "content": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 512,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-22B_full.json",
    "content": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 4,\n  \"microbatch_size\": 4,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"full\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-22B_none.json",
    "content": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 4,\n  \"microbatch_size\": 4,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"none\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-22B_sel.json",
    "content": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 4,\n  \"microbatch_size\": 4,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-22B_seq.json",
    "content": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 4,\n  \"microbatch_size\": 4,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"none\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-22B_seqsel.json",
    "content": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 4,\n  \"microbatch_size\": 4,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/turing-530B_full.json",
    "content": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 280,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"full\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/turing-530B_none.json",
    "content": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 280,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"none\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/turing-530B_sel.json",
    "content": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 280,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/turing-530B_seq.json",
    "content": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 280,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"none\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/fig7/turing-530B_seqsel.json",
    "content": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 280,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/tab5/gpt3-175B_full.json",
    "content": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 64,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"full\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/tab5/gpt3-175B_seqsel.json",
    "content": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 64,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/tab5/megatron-1T_full.json",
    "content": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 512,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"full\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/tab5/megatron-1T_seqsel.json",
    "content": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 512,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/tab5/megatron-22B_full.json",
    "content": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 4,\n  \"microbatch_size\": 4,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"full\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/tab5/megatron-22B_seqsel.json",
    "content": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 4,\n  \"microbatch_size\": 4,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 1,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/tab5/turing-530B_full.json",
    "content": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 280,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"full\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"ar\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": false,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  },
  {
    "path": "validation/seqsel/tab5/turing-530B_seqsel.json",
    "content": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\": 1,\n  \"data_par_net\": 1,\n  \"batch_size\": 280,\n  \"microbatch_size\": 1,\n  \"datatype\": \"float16\",\n  \"fused_activation\": false,\n  \"attention_type\": \"multihead\",\n  \"activation_recompute\": \"attn_only\",\n  \"pipeline_interleaving\": 3,\n  \"optimizer_sharding\": false,\n  \"tensor_par_comm_type\": \"rs_ag\",\n  \"tensor_par_overlap\": \"none\",\n  \"seq_par_ag_redo\": true,\n  \"data_par_overlap\": false,\n  \"weight_offload\": false,\n  \"activations_offload\": false,\n  \"optimizer_offload\": false,\n  \"training\": true\n}\n"
  }
]