Repository: calculon-ai/calculon
Branch: main
Commit: caa4b11f8fe1
Files: 87
Total size: 249.7 KB

Directory structure:
gitextract_qq3udcpt/

├── .gitignore
├── LICENSE
├── Makefile
├── NOTICE
├── README.md
├── bin/
│   └── calculon
├── calculon/
│   ├── __init__.py
│   ├── command_line.py
│   ├── io.py
│   ├── llm/
│   │   ├── __init__.py
│   │   ├── all_executions.py
│   │   ├── layers.py
│   │   ├── llm.py
│   │   ├── optimal_execution.py
│   │   ├── parameter_calculator.py
│   │   ├── runner.py
│   │   └── validation.py
│   ├── memory.py
│   ├── network.py
│   ├── processor.py
│   ├── system.py
│   ├── util.py
│   └── version.py
├── examples/
│   └── 3072_t4_p64_d12_mbs4_full.json
├── models/
│   ├── anthropic-52B.json
│   ├── chinchilla.json
│   ├── gopher-280B.json
│   ├── gpt3-13B.json
│   ├── gpt3-175B.json
│   ├── lamda.json
│   ├── megatron-126M.json
│   ├── megatron-1T.json
│   ├── megatron-22B.json
│   ├── megatron-40B.json
│   ├── megatron-5B.json
│   ├── palm-540B.json
│   └── turing-530B.json
├── pylintrc
├── pyproject.toml
├── scripts/
│   ├── 3dplot.py
│   ├── find_huge.py
│   ├── heatmap.py
│   ├── install_hooks.sh
│   └── json_to_csv.py
├── setup.py
├── systems/
│   ├── a100_80e.json
│   ├── a100_80g.json
│   └── h100_80g_nvl8.json
├── test/
│   ├── __init__.py
│   ├── test.sh
│   └── test_json_write_read.py
└── validation/
    └── seqsel/
        ├── fig1/
        │   ├── gpt3-175B_none.json
        │   ├── gpt3-175B_seqsel.json
        │   ├── megatron-1T_none.json
        │   ├── megatron-1T_seqsel.json
        │   ├── megatron-22B_none.json
        │   ├── megatron-22B_seqsel.json
        │   ├── turing-530B_none.json
        │   └── turing-530B_seqsel.json
        ├── fig7/
        │   ├── gpt3-175B_full.json
        │   ├── gpt3-175B_none.json
        │   ├── gpt3-175B_sel.json
        │   ├── gpt3-175B_seq.json
        │   ├── gpt3-175B_seqsel.json
        │   ├── megatron-1T_full.json
        │   ├── megatron-1T_none.json
        │   ├── megatron-1T_sel.json
        │   ├── megatron-1T_seq.json
        │   ├── megatron-1T_seqsel.json
        │   ├── megatron-22B_full.json
        │   ├── megatron-22B_none.json
        │   ├── megatron-22B_sel.json
        │   ├── megatron-22B_seq.json
        │   ├── megatron-22B_seqsel.json
        │   ├── turing-530B_full.json
        │   ├── turing-530B_none.json
        │   ├── turing-530B_sel.json
        │   ├── turing-530B_seq.json
        │   └── turing-530B_seqsel.json
        └── tab5/
            ├── gpt3-175B_full.json
            ├── gpt3-175B_seqsel.json
            ├── megatron-1T_full.json
            ├── megatron-1T_seqsel.json
            ├── megatron-22B_full.json
            ├── megatron-22B_seqsel.json
            ├── turing-530B_full.json
            └── turing-530B_seqsel.json

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.DS_Store
*.py[cod]
*.log

# C extensions
*.so

# Packages
*.egg
*.egg-info
dist
build
eggs
parts
var
sdist
develop-eggs
.installed.cfg
lib
lib64
__pycache__

# Installer logs
pip-log.txt
files.txt

# Unit test / coverage reports
.coverage
.tox
nosetests.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject


================================================
FILE: LICENSE
================================================

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2022 Michael Isaev, Nic McDonald

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: Makefile
================================================
.SUFFIXES:
.PHONY: help install uninstall clean lint test count

help:
	@echo "options are: install uninstall clean lint test count"

install:
	python3 setup.py install --user --record files.txt

uninstall:
	cat files.txt | xargs rm -rf

clean:
	rm -rf build dist calculon.egg-info calculon/*.pyc calculon/__pycache__ calculon/*/__pycache__ test/*.pyc test/__pycache__

lint:
	pylint -r n calculon

test:
	python3 -m unittest -v -f --buffer
	@printf "Unit testing successful!\n\n"
	./test/test.sh

count:
	@wc calculon/*.py calculon/llm/*.py test/*.py | sort -n -k1
	@echo "files : "$(shell echo calculon/*.py calculon/llm/*.py test/*.py | wc -w)
	@echo "commits : "$(shell git rev-list HEAD --count)


================================================
FILE: NOTICE
================================================
Calculon - Co-design for large scale parallel applications
Copyright 2022 Michael Isaev, Nic McDonald
All rights reserved.

================================================
FILE: README.md
================================================
[![DOI](https://zenodo.org/badge/660734586.svg)](https://zenodo.org/badge/latestdoi/660734586)
# Calculon - Co-design for large scale parallel applications

## Running

Run Calculon like this:
``` sh
$> PYTHONPATH=. ./bin/calculon <args>
```

Calculon has a hierarchical command line interface. To see the commands it accepts, use `--help` or `-h`:
``` sh
$> PYTHONPATH=. ./bin/calculon -h
```

To see how to use a specific command, pass `--help` or `-h` to that command:
``` sh
$> PYTHONPATH=. ./bin/calculon llm -h
```
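The hierarchy is built with `argparse` subparsers. `bin/calculon` creates one top-level parser and asks every registered `CommandLine` subclass to attach its own subparser. Each subparser stores its handler with `set_defaults(func=...)`, and the script then calls `args.func(logger, args)`. The standalone sketch below mirrors that pattern in simplified form. The `Hello` command and its `name` argument are hypothetical. This `register` also returns the class so it can be used as a decorator, which is a convenience not present in Calculon itself.

```python
import argparse


class CommandLine:
  """Simplified stand-in for calculon.CommandLine (illustration only)."""
  _names = {}  # maps command class -> [primary name] + aliases

  @staticmethod
  def register(cls):
    all_names = [cls.NAME] + cls.ALIASES
    # Reject any name or alias that another command already uses.
    for new_name in all_names:
      for existing in CommandLine._names.values():
        assert new_name not in existing, f'{new_name} already exists'
    CommandLine._names[cls] = all_names
    return cls

  @staticmethod
  def command_lines():
    return set(CommandLine._names.keys())


@CommandLine.register
class Hello(CommandLine):
  # Hypothetical command, not part of Calculon.
  NAME = 'hello'
  ALIASES = ['hi']

  @staticmethod
  def create_parser(subparser):
    sp = subparser.add_parser(Hello.NAME, aliases=Hello.ALIASES,
                              help='print a greeting')
    sp.set_defaults(func=Hello.run_command)
    sp.add_argument('name', type=str, help='who to greet')

  @staticmethod
  def run_command(logger, args):
    return f'hello {args.name}'


def build_parser():
  ap = argparse.ArgumentParser(description='toy hierarchical CLI')
  sp = ap.add_subparsers(title='commands', dest='command')
  sp.required = True
  # Every registered command attaches its own subparser.
  for cls in CommandLine.command_lines():
    cls.create_parser(sp)
  return ap


args = build_parser().parse_args(['hi', 'world'])
print(args.func(None, args))  # hello world
```

Because aliases map to the same subparser, `hi world` and `hello world` dispatch to the same handler.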

## LLM Example

Run a single calculation for LLM (~1 sec):
``` sh
$> PYTHONPATH=. ./bin/calculon llm models/megatron-1T.json examples/3072_t4_p64_d12_mbs4_full.json systems/a100_80g.json -
```
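The example execution file name appears to encode its configuration: 3072 processors, tensor parallelism 4, pipeline parallelism 64, data parallelism 12, microbatch size 4, and full activation recomputation. This reading is an assumption based on the name alone, not a documented convention. It is at least internally consistent, because 4 × 64 × 12 = 3072.

The sketch below decodes such a name and checks that the three parallelism degrees multiply to the processor count. Both the `parse_execution_name` helper and the file name pattern are hypothetical.

```python
import re
from typing import NamedTuple


class ExecutionName(NamedTuple):
  num_procs: int
  tensor_par: int
  pipeline_par: int
  data_par: int
  microbatch_size: int
  recompute: str


# Hypothetical pattern: <procs>_t<TP>_p<PP>_d<DP>_mbs<MBS>_<recompute>.json
_PATTERN = re.compile(
  r'(?P<procs>\d+)_t(?P<tp>\d+)_p(?P<pp>\d+)_d(?P<dp>\d+)'
  r'_mbs(?P<mbs>\d+)_(?P<recompute>\w+)\.json$')


def parse_execution_name(filename):
  m = _PATTERN.search(filename)
  if m is None:
    raise ValueError(f'unrecognized execution file name: {filename}')
  name = ExecutionName(int(m['procs']), int(m['tp']), int(m['pp']),
                       int(m['dp']), int(m['mbs']), m['recompute'])
  # The parallelism degrees must multiply to the processor count.
  if name.tensor_par * name.pipeline_par * name.data_par != name.num_procs:
    raise ValueError(f'{filename}: TP * PP * DP != num_procs')
  return name


print(parse_execution_name('examples/3072_t4_p64_d12_mbs4_full.json'))
```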

Run a system execution optimizer for LLM (~1 min):
``` sh
$> PYTHONPATH=. ./bin/calculon llm-optimal-execution models/turing-530B.json 5128 2520 float16 systems/a100_80g.json output.json -m
```
`output.json` will contain the optimal way to run Turing-530B across 5128 A100 GPUs.
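The numeric arguments after the model file appear to be the processor count and the maximum batch size, followed by the datatype. This matches the argument order of `llm-all-executions` in `calculon/llm/all_executions.py`. Each candidate execution splits the processors as tensor × pipeline × data parallelism. The batch size used is the largest multiple of the data-parallel degree that does not exceed the maximum.

The sketch below enumerates such splits. It assumes that divisibility of the processor count is the only constraint. Calculon's own helpers, such as `Llm.get_all_tensor_parallelisms`, also take model dimensions into account (hidden size, attention heads, block count).

```python
def divisors(n):
  """All positive divisors of n in ascending order."""
  small, large = [], []
  d = 1
  while d * d <= n:
    if n % d == 0:
      small.append(d)
      if d != n // d:
        large.append(n // d)
    d += 1
  return small + large[::-1]


def batch_size(data_par, max_batch_size):
  """Largest multiple of data_par not exceeding max_batch_size, or None."""
  if data_par > max_batch_size:
    return None
  return (max_batch_size // data_par) * data_par


def parallel_splits(num_procs, max_batch_size):
  """Yield (tp, pp, dp, batch) with tp * pp * dp == num_procs.

  Simplified: a real search also filters on the model's shape.
  """
  for tp in divisors(num_procs):
    for pp in divisors(num_procs // tp):
      dp = num_procs // (tp * pp)
      batch = batch_size(dp, max_batch_size)
      if batch is not None:  # skip splits where DP exceeds the max batch
        yield tp, pp, dp, batch


for split in parallel_splits(8, 6):
  print(split)
```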

To store the results of every successful execution found in the same search, run the exhaustive variant of the optimizer (~1 min):
``` sh
$> PYTHONPATH=. ./bin/calculon llm-all-executions models/turing-530B.json 5128 2520 float16 systems/a100_80g.json all_output.csv
```
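Judging from `calculon/llm/all_executions.py`, `llm-all-executions` first enumerates every candidate execution and shuffles the list. It then cuts the list into roughly one contiguous chunk per CPU (`--cpus`, which defaults to the physical core count) and evaluates the chunks with `multiprocessing.Pool.starmap`.

The sketch below reproduces that split-and-gather pattern. The `evaluate` function and its pass/fail rule are hypothetical stand-ins for Calculon's per-execution model. The sketch also uses `Pool.map` with a single argument instead of `starmap`.

```python
import math
import multiprocessing as mp
import random


def evaluate(chunk):
  """Hypothetical stand-in for modeling each execution in a chunk.

  Keeps the executions that 'succeed' (here: the even numbers).
  """
  return [exe for exe in chunk if exe % 2 == 0]


def split_into_chunks(items, num_workers):
  """Split items into at most num_workers contiguous chunks."""
  # max(1, ...) keeps range() valid when items is empty.
  step = max(1, math.ceil(len(items) / num_workers))
  return [items[i:i + step] for i in range(0, len(items), step)]


def run_all(executions, num_workers):
  executions = list(executions)
  # Shuffle so that costly executions are likely spread across workers.
  random.shuffle(executions)
  chunks = split_into_chunks(executions, num_workers)
  with mp.Pool(num_workers) as pool:
    goods = pool.map(evaluate, chunks)
  return [exe for good in goods for exe in good]


print([len(chunk) for chunk in split_into_chunks(list(range(10)), 4)])
```

Call `run_all` from inside an `if __name__ == '__main__':` guard in a script, so that worker processes can import `evaluate` safely.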

## Testing and validation (optional)
To make sure that the current build is working, use

``` sh
$> make test
```
To validate Calculon's performance model against Megatron runs on NVIDIA's A100-based Selene supercomputer, as published in the ["Sequence parallelism" paper](https://arxiv.org/abs/2205.05198), use

``` sh
$> PYTHONPATH=. ./bin/calculon llm-validation
```

## Publications

* Calculon: A Methodology and Tool for High-Level Co-Design of Systems and Large Language Models\
Mikhail Isaev, Nic McDonald, Larry Dennison, Richard Vuduc\
[Paper](https://dl.acm.org/doi/pdf/10.1145/3581784.3607102)

* Scaling Infrastructure to Support Multi-Trillion Parameter LLM Training\
Mikhail Isaev, Nic McDonald, Richard Vuduc\
[Paper](https://openreview.net/pdf?id=rqn2v1Ltgn0)


================================================
FILE: bin/calculon
================================================
#!/usr/bin/env python3

"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  https://www.apache.org/licenses/LICENSE-2.0
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
"""

import argparse
import calculon
import logging
import sys




if __name__ == '__main__':
  # CLI inspired from: https://github.com/ssnetsim/ssplot/

  # Creates an argparser and subparsers.
  desc = 'Calculon: Co-design for large scale parallel applications'
  ap = argparse.ArgumentParser(description=desc)
  ap.add_argument('-l', '--log', default='-',
                  help='Sets the log file, or - for stdout (default)')
  ap.add_argument('-v', '--verbosity', default='INFO',
                  help='Sets the logging level (see logging docs)')
  sp = ap.add_subparsers(title='commands', dest='command',
                         description='commands available in Calculon',
                         help='the command')
  sp.required = True

  # Registers each command line interface.
  for cls in calculon.CommandLine.command_lines():
    cls.create_parser(sp)

  # Parses the args and creates the logger
  args = ap.parse_args()
  logger = logging.getLogger()
  if args.log == '-':
    logger.addHandler(logging.StreamHandler(stream=sys.stdout))
  else:
    logger.addHandler(logging.FileHandler(args.log, mode='w'))
  logger.setLevel(args.verbosity)

  # Calls the corresponding command function
  sys.exit(args.func(logger, args))


================================================
FILE: calculon/__init__.py
================================================
"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  https://www.apache.org/licenses/LICENSE-2.0
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
"""

__version__ = '0.1.0'

# Imports of this module
from .command_line import CommandLine
from .io import *
from .system import System
from .util import *
from .version import Version

# Imports submodules
from .llm import *


================================================
FILE: calculon/command_line.py
================================================
"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  https://www.apache.org/licenses/LICENSE-2.0
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
"""

import copy

class CommandLine:
  """Defines the abstract interface definition for a command line interface.
  Inspired from: https://github.com/ssnetsim/ssplot/
  """

  @staticmethod
  def create_parser(subparser):
    """
    This function adds a parser to the subparser object according to the
    specific command line interface implementation.
    """
    raise NotImplementedError('subclasses must override this')

  @staticmethod
  def run_command(logger, args):
    """
    This function is used to run the command if it is chosen at the command
    line. This function should be registered to the parser in create_parser().
    """
    raise NotImplementedError('subclasses must override this')

  # this is a mapping of all names (class->names)
  _names = {}

  @staticmethod
  def register(cls):
    # gather names
    primary_name = cls.NAME
    aliases = cls.ALIASES

    # combine the primary name and aliases into one list
    all_names = [primary_name] + aliases

    # check every new name against the names and aliases already registered
    for new_name in all_names:
      for existing_cls, existing_names in CommandLine._names.items():
        assert new_name not in existing_names, (
          f'{new_name} already exists (registered by {existing_cls.__name__})')

    # add to map
    CommandLine._names[cls] = all_names

  @staticmethod
  def command_lines():
    return set(CommandLine._names.keys())

  @staticmethod
  def all_names():
    return copy.copy(CommandLine._names)


================================================
FILE: calculon/io.py
================================================
"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  https://www.apache.org/licenses/LICENSE-2.0
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
"""
import gzip
import json
import numpy as np


class NpEncoder(json.JSONEncoder):
  def default(self, obj):
    if isinstance(obj, np.integer):
      return int(obj)
    if isinstance(obj, np.floating):
      return float(obj)
    if isinstance(obj, np.ndarray):
      return obj.tolist()
    if isinstance(obj, np.bool_):
      return bool(obj)
    return super(NpEncoder, self).default(obj)

def is_json_extension(filename):
  return filename.endswith('.json') or filename.endswith('.json.gz')


def write_json_file(jdata, filename):
  assert is_json_extension(filename)
  opener = gzip.open if filename.endswith('.gz') else open
  indent = None if filename.endswith('.gz') else 2
  with opener(filename, 'wb') as fd:
    fd.write(bytes(json.dumps(jdata, indent=indent, cls=NpEncoder), 'utf-8'))


def read_json_file(filename):
  assert is_json_extension(filename)
  opener = gzip.open if filename.endswith('.gz') else open
  with opener(filename, 'rb') as fd:
    return json.loads(fd.read().decode('utf-8'))


================================================
FILE: calculon/llm/__init__.py
================================================
"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  https://www.apache.org/licenses/LICENSE-2.0
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
"""

from .layers import *
from .llm import *

# Command lines
from .all_executions import AllExecutions
from .optimal_execution import OptimalExecution
from .parameter_calculator import ParameterCalculator
from .validation import Validation
from .runner import Runner


================================================
FILE: calculon/llm/all_executions.py
================================================
"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  https://www.apache.org/licenses/LICENSE-2.0
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
"""

import datetime
import gzip
import itertools
import logging
import math
import multiprocessing as mp
import os
import pandas
import psutil
import random

import calculon
from calculon.util import pick, arg_true_false_all
from calculon.llm import *


class AllExecutions(calculon.CommandLine):
  NAME = 'llm-all-executions'
  ALIASES = ['lae']

  @staticmethod
  def create_parser(subparser):
    sp = subparser.add_parser(
      AllExecutions.NAME, aliases=AllExecutions.ALIASES,
      help='run a search over all llm executions and save every successful one')
    sp.set_defaults(func=AllExecutions.run_command)
    sp.add_argument('-d', '--debug', action='store_true',
                    help='Loop over executions, don\'t run them')
    sp.add_argument('application', type=str,
                    help='File path to application configuration')
    sp.add_argument('num_procs', type=int,
                    help='Number of processors in execution')
    sp.add_argument('max_batch_size', type=int,
                    help='Maximum batch size, will be largest multiple of DP')
    sp.add_argument('datatype', type=str, choices=System.supported_datatypes(),
                    help='The datatype to use')
    sp.add_argument('system', type=str,
                    help='File path to system configuration')
    sp.add_argument('output', type=str,
                    help='File path to the output file'
                    " ('*.csv', '*.csv.gz')")
    sp.add_argument('-c', '--cpus', type=int, default=psutil.cpu_count(logical=False),
                    help='CPUs to use for parallelization')
    sp.add_argument('-n', '--noneok', action='store_true',
                    help='Don\'t give failure status when no good execution exists')
    sp.add_argument('-f', '--fused_activation', type=arg_true_false_all,
                    default='true', help='Mode of fused activation')

  @staticmethod
  def execution_fields():
    return (
      'num_procs', 'tensor_par', 'pipeline_par', 'data_par', 'tensor_par_net',
      'pipeline_par_net', 'data_par_net', 'batch_size', 'microbatch_size',
      'datatype', 'fused_activation', 'attention_type', 'activation_recompute',
      'pipeline_interleaving', 'optimizer_sharding', 'tensor_par_comm_type',
      'tensor_par_overlap', 'seq_par_ag_redo', 'data_par_overlap',
      'weight_offload', 'activations_offload', 'optimizer_offload', 'training')

  @staticmethod
  def get_batch_size(data_par, max_batch_size):
    if data_par > max_batch_size:
      return None
    # Largest multiple of data_par that does not exceed max_batch_size.
    return (max_batch_size // data_par) * data_par

  @staticmethod
  def all_executions(app, syst, num_procs, max_batch_size, datatype, fused_activation):
    has_mem2 = syst.mem2.capacity > 0
    num_nets = syst.num_networks
    count = 0
    for tp in Llm.get_all_tensor_parallelisms(
        num_procs, app.hidden, app.attn_heads):
      for pp in Llm.get_all_pipeline_parallelisms(
          num_procs, tp, app.num_blocks):
        dp = Llm.get_data_parallelism(num_procs, tp, pp)
        for ppint in Llm.get_valid_pipeline_interleavings(app.num_blocks, pp):
          batch_size = AllExecutions.get_batch_size(dp, max_batch_size)
          if batch_size is None:
            continue
          for activation_recompute in ['full', 'attn_only', 'none']:
            for optimizer_sharding in pick(dp>1, [True, False], [False]):
              for tensor_par_comm_type in ['ar', 'p2p_rs_ag', 'rs_ag']:
                can_redo = Llm.can_redo_ag(tensor_par_comm_type,
                                           activation_recompute)
                for seq_par_ag_redo in pick(can_redo, [True, False], [False]):
                  for data_par_overlap in pick(dp>1, [True, False], [False]):
                    for tensor_par_overlap in pick(tp>1, ['none', 'ring', 'pipe'], ['none']):
                      for weight_offload in pick(has_mem2, [True, False], [False]):
                        if activation_recompute == 'full' or not has_mem2:
                          activations_offloads = [False]
                        else:
                          activations_offloads = [True, False]
                        for activations_offload in activations_offloads:
                          for optimizer_offload in pick(has_mem2, [True, False],
                                                        [False]):
                            for fused_act in fused_activation:
                              for microbatch_size in Llm.get_valid_microbatch_sizes(
                                  app.seq_size, tp, dp, batch_size, pp):
                                for tn in pick(tp>1, range(num_nets), [0]):
                                  for pn in pick(pp>1, range(num_nets), [0]):
                                    for dn in pick(dp>1, range(num_nets), [0]):
                                      yield (num_procs, tp, pp, dp, tn, pn, dn,
                                             batch_size, microbatch_size, datatype,
                                             fused_act, 'multihead', activation_recompute,
                                             ppint, optimizer_sharding, tensor_par_comm_type,
                                             tensor_par_overlap, seq_par_ag_redo,
                                             data_par_overlap, weight_offload,
                                             activations_offload, optimizer_offload,
                                             True)
                                      count += 1

  @staticmethod
  def run_command(logger, args):
    assert args.output.endswith('.csv') or args.output.endswith('.csv.gz')

    app = Llm.Application(calculon.io.read_json_file(args.application))
    syst = System(calculon.io.read_json_file(args.system))

    executions = list(AllExecutions.all_executions(
      app, syst, args.num_procs, args.max_batch_size, args.datatype,
      args.fused_activation))
    random.shuffle(executions)
    exe_count = len(executions)
    logger.info(f'Total executions: {exe_count}')

    # range() rejects a step of 0, so keep step >= 1 for an empty list
    step = max(1, math.ceil(len(executions) / args.cpus))
    worker_args = []
    for index in range(0, len(executions), step):
      worker_args.append((app, syst, executions[index : index + step]))
    del executions

    # Runs parallel searches
    start_time = datetime.datetime.now()
    with mp.Pool(args.cpus) as pool:
      goods = pool.starmap(AllExecutions.search, worker_args)
    end_time = datetime.datetime.now()
    good_count = sum(len(good) for good in goods)

    # Console statistics
    logger.info(f'Good executions: {good_count}')
    logger.info(f'Bad executions: {exe_count-good_count}')
    calc_rate = exe_count / (end_time - start_time).total_seconds()
    logger.info(f'Calculation rate: {calc_rate:.2f} calcs/sec')

    # Check if OK
    if good_count == 0:
      if not args.noneok:
        logger.fatal('No acceptable configurations found :(')
        return -1
      else:
        logger.info('No acceptable configurations found :(')

    if args.debug:
      return 0

    # Writes to CSV
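    fields = Llm.Execution.fields() + Llm.get_stats_fields()
    # goods[0] may be empty even when other workers found results, so check
    # the first good result from any worker (if there is one)
    first_good = next(itertools.chain(*goods), None)
    assert first_good is None or len(fields) == len(first_good)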
    logger.info(f'Output: {args.output}')
    opener = gzip.open if args.output.endswith('.gz') else open
    with opener(args.output, 'wb') as fd:
      fd.write(bytes(','.join(fields) + '\n', 'utf-8'))
      for vals in itertools.chain(*goods):
        fd.write(bytes(','.join(str(v) for v in vals) + '\n', 'utf-8'))

    return 0

  @staticmethod
  def search(app, syst, executions):
    good = []
    for execution in executions:
      try:
        model = Llm(app, logging.Logger('sub'))
        model.compile(syst, Llm.Execution(*execution))
        model.run(syst)
        statistics = model.get_stats_values()
        good.append(execution + statistics)
      except Llm.Error as ex:
        logger = logging.getLogger()
        logger.debug(f'ERROR:{ex}\n')
    return good

  @staticmethod
  def update_list(current, candidate, quantity):
    if not isinstance(candidate, list):
      current.append(candidate)
    else:
      current.extend(candidate)
    if quantity <= 0:
      return current  # don't sort and chop
    else:
      current.sort(reverse=True, key=lambda x: x[0])
      return current[:quantity]


calculon.CommandLine.register(AllExecutions)
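

# Illustrative sketch (hypothetical helper, not part of Calculon's API):
# run_command() splits the shuffled execution list into contiguous chunks of
# ceil(len(executions) / cpus) items and hands one chunk to each starmap task.
# The function below isolates that arithmetic for a plain list. It returns no
# chunks for an empty list, where a step of 0 would make range() raise.
def sketch_split_work(items, num_workers):
  """Splits items into contiguous chunks, at most num_workers of them."""
  if not items:
    return []
  step = -(-len(items) // num_workers)  # ceiling division without math.ceil
  return [items[i:i + step] for i in range(0, len(items), step)]

# Example: 10 executions over 4 workers gives chunk sizes 3, 3, 3 and 1.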


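# Illustrative sketch (hypothetical helper, not part of Calculon's API):
# update_list() appends one candidate (or extends by a list of them), then
# keeps the `quantity` entries with the largest first element, in descending
# order. This variant swaps in heapq.nlargest, which the Python docs define as
# equivalent to sorted(..., reverse=True)[:n]. Unlike update_list(), it builds
# a new list instead of mutating `current` in place.
import heapq

def sketch_top_k(current, candidate, quantity):
  """Merges candidate(s) into current and keeps the top `quantity` by x[0]."""
  new_items = candidate if isinstance(candidate, list) else [candidate]
  merged = current + new_items
  if quantity <= 0:
    return merged  # mirror update_list: no sorting or truncation
  return heapq.nlargest(quantity, merged, key=lambda x: x[0])

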
================================================
FILE: calculon/llm/layers.py
================================================
"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  https://www.apache.org/licenses/LICENSE-2.0
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
"""

from calculon import *


class Layer:
  """
  A single layer of a neural network. Has weights, activation space,
  gradients, and optimizer state associated with it. May invoke compute,
  memory access, or network operation.
  """

  def __init__(self, name, sys, fw_flops=0, agrad_flops=0, wgrad_flops=0,
               inputs_size=0, output_size=0, activation_space=0,
               activation_grads=0, weight_space=0, weight_grads=0,
               optim_space=0, needs_recompute=False, needs_recomm=False,
               activation_reused=False, activation_stored=True,
               output_stored=True):
    self.name = name
    self.sys = sys
    self.fw_flops = fw_flops
    self.agrad_flops = agrad_flops
    self.wgrad_flops = wgrad_flops
    self.inputs_size = inputs_size
    self.output_size = output_size
    # Activations equal the input size; we store them to compute Wgrad in BW
    self.activation_space = activation_space
    # Activation grads equal the output size: grads w.r.t. the layer output
    self.activation_grads = activation_grads
    self.weight_space = weight_space
    self.weight_grads = weight_grads
    self.optim_space = optim_space
    self.optim_sharding_num_proc = 1

    # Add optimizations and parallelization split
    self.needs_recompute = needs_recompute
    self.needs_recomm = needs_recomm
    self.activation_reused = activation_reused
    self.activation_stored = activation_stored
    self.output_stored = output_stored
    # Before bytes_per_element set by SW config, we operate with just
    # parameter count, setting bytes_per_element to 1
    self.bytes_per_element = 1
    self.processing_time = None
    self.net_exposed_time = None

  def get_stats_json(self):
    return {
      'name': self.name,
      'inputs_size': self.inputs_size,
      'outputs_size': self.output_size,
      'fw_flops': self.get_fw_flops(),
      'fw_mem_accessed': self.get_fw_mem_accessed(),
      'fw_arithmetic_intensity': self.get_fw_arithmetic_intensity(),
      'fw_processing_time': self.compute_processing_time('fw'),
      'baseblock_fw_tp_comm_tile': self.get_comm_tile('fw', baseblock=True),
      'edgeblock_fw_tp_comm_tile': self.get_comm_tile('fw', baseblock=False),
      'baseblock_fw_tp_comm_size': self.get_comm_bytes('fw', baseblock=True),
      'edgeblock_fw_tp_comm_size': self.get_comm_bytes('fw', baseblock=False),
      'baseblock_fw_tp_comm_time': self.compute_net_time('fw', baseblock=True),
      'edgeblock_fw_tp_comm_time': self.compute_net_time('fw', baseblock=False),
      'baseblock_fw_tp_comm_time_exposed': self.get_exposed_net_time(
        'fw', baseblock=True),
      'edgeblock_fw_tp_comm_time_exposed': self.get_exposed_net_time(
        'fw', baseblock=False),
      'agrad_flops': self.get_agrad_flops(),
      'agrad_mem_accessed': self.get_agrad_mem_accessed(),
      'agrad_arithmetic_intensity': self.get_agrad_arithmetic_intensity(),
      'agrad_processing_time': self.compute_processing_time('agrad'),
      'baseblock_bw_tp_comm_tile': self.get_comm_tile('agrad', baseblock=True),
      'edgeblock_bw_tp_comm_tile': self.get_comm_tile('agrad', baseblock=False),
      'baseblock_bw_tp_comm_size': self.get_comm_bytes('agrad', baseblock=True),
      'edgeblock_bw_tp_comm_size': self.get_comm_bytes('agrad', baseblock=False),
      'baseblock_bw_tp_comm_time': self.compute_net_time('agrad', baseblock=True),
      'edgeblock_bw_tp_comm_time': self.compute_net_time('agrad', baseblock=False),
      'baseblock_bw_tp_comm_time_exposed': self.get_exposed_net_time(
        'agrad', baseblock=True),
      'edgeblock_bw_tp_comm_time_exposed': self.get_exposed_net_time(
        'agrad', baseblock=False),
      'wgrad_flops': self.get_wgrad_flops(),
      'wgrad_mem_accessed': self.get_wgrad_mem_accessed(),
      'wgrad_arithmetic_intensity': self.get_wgrad_arithmetic_intensity(),
      'wgrad_processing_time': self.compute_processing_time('wgrad'),
      'baseblock_recomm_tile': self.get_comm_tile('wgrad', baseblock=True),
      'edgeblock_recomm_tile': self.get_comm_tile('wgrad', baseblock=False),
      'baseblock_recomm_size': self.get_comm_bytes('wgrad', baseblock=True),
      'edgeblock_recomm_size': self.get_comm_bytes('wgrad', baseblock=False),
      'baseblock_recomm_time': self.compute_net_time('wgrad', baseblock=True),
      'edgeblock_recomm_time': self.compute_net_time('wgrad', baseblock=False),
      'baseblock_recomm_time_exposed': self.get_exposed_net_time(
        'wgrad', baseblock=True),
      'edgeblock_recomm_time_exposed': self.get_exposed_net_time(
        'wgrad', baseblock=False),
      'optim_flops': self.get_optim_step_flops(),
      'optim_mem_accessed': self.get_optim_step_mem_accessed(),
      'optim_arithmetic_intensity': self.get_optim_step_arithmetic_intensity(),
      'optim_processing_time': self.compute_processing_time('optim'),
      'weight': self.get_weight(),
      'activation': self.get_activation(),
      'weight_grad': self.get_weight_grad(),
      'activation_grad': self.get_activation_grad(),
      'optimizer': self.get_optimizer()
    }

  def get_stats_str(self):
    stats = "Operation {0}:\n{1} FW flops, {2} FW bytes accessed,".format(
      self.name,
      human_format(self.get_fw_flops(), 'flops'),
      human_format(self.get_fw_mem_accessed(), 'bytes'))
    stats += " FW AI: {0:.3f}\n".format(self.get_fw_arithmetic_intensity())
    stats += "{0} BW Agrad flops, {1} BW Agrad bytes accessed,".format(
      human_format(self.get_agrad_flops(), 'flops'),
      human_format(self.get_agrad_mem_accessed(), 'bytes'))
    stats += " BW Agrad AI: {0:.3f}\n".format(
      self.get_agrad_arithmetic_intensity())
    stats += "{0} BW Wgrad flops, {1} BW Wgrad bytes accessed,".format(
      human_format(self.get_wgrad_flops(), 'flops'),
      human_format(self.get_wgrad_mem_accessed(), 'bytes'))
    stats += " BW Wgrad AI: {0:.3f}\n".format(
      self.get_wgrad_arithmetic_intensity())
    stats += "{0} Optim flops, {1} Optim bytes accessed,".format(
      human_format(self.get_optim_step_flops(), 'flops'),
      human_format(self.get_optim_step_mem_accessed(), 'bytes'))
    stats += " Optim AI: {0:.3f}\n".format(
      self.get_optim_step_arithmetic_intensity())
    stats += "W: {0}, Act: {1}, WGrad: {2}, AGrad: {3}, Optim: {4}".format(
      human_format(self.get_weight(), 'bytes'),
      human_format(self.get_activation(), 'bytes'),
      human_format(self.get_weight_grad(), 'bytes'),
      human_format(self.get_activation_grad(), 'bytes'),
      human_format(self.get_optimizer(), 'bytes'))
    return stats

  def set_bytes_per_element(self, bytes_per_element):
    self.bytes_per_element = bytes_per_element

  # Shard (distribute) optimizer and weight grads between data parallel nodes
  def shard_optimizer(self, num_procs):
    self.optim_sharding_num_proc = num_procs

  # Getters called from the Llm model class; subclasses may override them
  def get_fw_flops(self):
    return self.fw_flops

  def get_fw_mem_accessed(self):
    mem_accessed = self.inputs_size + self.output_size + self.weight_space
    mem_accessed *= self.bytes_per_element
    return mem_accessed

  def get_fw_arithmetic_intensity(self):
    if self.fw_flops == 0:
      return 0
    if self.get_fw_mem_accessed() == 0:
      return float('inf')
    return self.fw_flops / self.get_fw_mem_accessed()

  def get_recompute_flag(self):
    return self.needs_recompute

  def get_recomm_flag(self):
    return self.needs_recomm

  def reuses_activation(self):
    return self.activation_reused

  def stores_activation(self):
    return self.activation_stored

  def stores_output(self):
    return self.output_stored

  def get_agrad_flops(self):
    return self.agrad_flops

  def get_agrad_mem_accessed(self):
    # Activation grads equal the output size and correspond to grads w.r.t.
    # the layer output; activations equal the input size
    grad_mem = self.weight_space + (
      self.activation_space + self.activation_grads)
    grad_mem *= self.bytes_per_element
    return grad_mem

  def get_agrad_arithmetic_intensity(self):
    if self.agrad_flops == 0:
      return 0
    if self.get_agrad_mem_accessed() == 0:
      return float('inf')
    return self.agrad_flops / self.get_agrad_mem_accessed()

  def get_wgrad_flops(self):
    return self.wgrad_flops

  def get_wgrad_mem_accessed(self):
    if self.weight_space == 0:
      assert self.wgrad_flops == 0, \
        f"Did not expect wgrad flops in layer {self.name}"
      return 0
    # Activation grads equal the output size and correspond to grads w.r.t.
    # the layer output; activations equal the input size
    grad_mem = self.weight_grads + (
      self.activation_space + self.activation_grads)
    grad_mem *= self.bytes_per_element
    return grad_mem

  def get_wgrad_arithmetic_intensity(self):
    if self.wgrad_flops == 0:
      return 0
    if self.get_wgrad_mem_accessed() == 0:
      return float('inf')
    return self.wgrad_flops / self.get_wgrad_mem_accessed()

  # We use the Adam optimizer. The flop count is based on the number of
  # weight grads to account for possible weight_grad sharding among data
  # parallel nodes
  def get_optim_step_flops(self):
    optim_flops = self.weight_grads / self.optim_sharding_num_proc * 11
    return optim_flops

  def get_optim_step_mem_accessed(self):
    return self.get_optimizer()

  def get_optim_step_arithmetic_intensity(self):
    if self.get_optim_step_flops() == 0:
      return 0
    if self.get_optim_step_mem_accessed() == 0:
      return float('inf')
    return self.get_optim_step_flops() / self.get_optim_step_mem_accessed()

  def get_weight(self):
    return self.weight_space * self.bytes_per_element

  def get_activation(self):
    return self.activation_space * self.bytes_per_element

  def get_output(self):
    return self.output_size * self.bytes_per_element

  def get_weight_grad(self, sharded=True):
    # Keep lower precision copy of grads for mem and net transfers
    grads = self.weight_grads
    if sharded:
      # We keep grads in lower precision for communication
      grads *= self.bytes_per_element
      grads /= self.optim_sharding_num_proc
    else:
      # otherwise keep grads in 32 bit for accumulation
      grads *= 4
    return grads

  def get_activation_grad(self):
    return self.activation_grads * self.bytes_per_element

  def get_optimizer(self):
    # Keep a 32-bit master copy of the weights, plus both moments (m, v);
    # the 32-bit copy of the grads is accounted for in get_weight_grad()
    moments_size = self.optim_space * 4
    if self.bytes_per_element < 4:
      master_copy_size = self.weight_space * 4
    else:
      master_copy_size = 0
    return (master_copy_size + moments_size) / self.optim_sharding_num_proc

  def set_processing_time(self, processing_time):
    self.processing_time = processing_time

  def get_processing_time(self):
    return self.processing_time

  def use_matrix_engine(self):
    return False

  def get_comm_bytes(self, stage, baseblock=True):
    return 0

  def get_comm_tile(self, stage, baseblock=True):
    return self.get_comm_bytes(stage, baseblock)

  def compute_flops_time(self, stage):
    if stage == "fw":
      flops = self.get_fw_flops()
    elif stage == "agrad":
      flops = self.get_agrad_flops()
    elif stage == "wgrad":
      flops = self.get_wgrad_flops()
    elif stage == "optim":
      flops = self.get_optim_step_flops()
    else:
      raise Exception(f'Bad compute stage : {stage}')
    if self.use_matrix_engine() and stage != "optim":
      throughput = self.sys.get_matrix_throughput(flops)
    else:
      throughput = self.sys.get_vector_throughput(flops)
    return flops / throughput

  def compute_mem_time(self, stage):
    if stage == "fw":
      mem = self.get_fw_mem_accessed()
    elif stage == "agrad":
      mem = self.get_agrad_mem_accessed()
    elif stage == "wgrad":
      mem = self.get_wgrad_mem_accessed()
    elif stage == "optim":
      mem = self.get_optim_step_mem_accessed()
    else:
      raise Exception(f'Bad compute stage : {stage}')
    return mem / self.sys.get_mem1_throughput(mem)

  def compute_net_time(self, stage, baseblock=True):
    return 0

  def get_exposed_net_time(self, stage, baseblock=True):
    return 0

  def get_required_bandwidth(self, stage, baseblock=True):
    return 0

  def compute_processing_time(self, stage):
    self.processing_time = self.sys.get_processing_time(
      self.compute_flops_time(stage),
      self.compute_mem_time(stage)
    )
    return self.processing_time
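

# Illustrative sketch (hypothetical helper, not part of Calculon's API):
# Layer.get_optimizer() and get_optim_step_flops() model Adam as two fp32
# moments per weight (optim_space = 2 * weights, 4 bytes each), plus an fp32
# master copy of the weights when training below 32-bit precision, and 11
# flops per weight grad. All of these are divided across the
# optimizer-sharding processors. The function below repeats that arithmetic
# for num_weights weights, assuming optim_space is exactly 2 * num_weights
# as in Linear.
def sketch_adam_cost(num_weights, bytes_per_element, sharding_procs=1):
  """Returns (optimizer_bytes, optimizer_step_flops) per processor."""
  moments_bytes = 2 * num_weights * 4
  master_bytes = num_weights * 4 if bytes_per_element < 4 else 0
  optim_bytes = (moments_bytes + master_bytes) / sharding_procs
  optim_flops = num_weights / sharding_procs * 11
  return optim_bytes, optim_flops

# Example: 1e9 fp16 weights (2 bytes/element) sharded over 8 processors need
# 1.5e9 bytes of optimizer state and 1.375e9 optimizer-step flops each.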

# Each layer's peculiarities and layer-wise optimizations are captured by
# overriding parent class member functions where needed
class Linear(Layer):
  def __init__(self, name, sys, batch_seq, c_in, c_out,
               needs_recompute=False, activation_reused=False,
               activation_stored=True, output_stored=True):
    m, n, k = batch_seq, c_in, c_out
    super().__init__(name,
                     sys,
                     fw_flops=2*m*n*k,
                     agrad_flops=2*m*n*k,
                     wgrad_flops=2*m*n*k,
                     inputs_size=m*n,
                     output_size=m*k,
                     weight_space=n*k,
                     weight_grads=n*k,
                     activation_space=m*n,
                     activation_grads=m*k,
                     optim_space=2*n*k,
                     needs_recompute=needs_recompute,
                     activation_reused=activation_reused,
                     activation_stored=activation_stored,
                     output_stored=output_stored)

  def use_matrix_engine(self):
    return True
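

# Illustrative sketch (hypothetical helper, not part of Calculon's API):
# for a Linear layer with m = batch_seq, n = c_in and k = c_out, the forward
# pass costs 2*m*n*k flops and touches the input (m*n), the output (m*k) and
# the weights (n*k), each at bytes_per_element. These are the values that
# get_fw_flops() and get_fw_mem_accessed() return. Their ratio is the forward
# arithmetic intensity that get_fw_arithmetic_intensity() reports.
def sketch_linear_fw_intensity(m, n, k, bytes_per_element):
  """Returns forward flops per byte accessed for an (m x n) @ (n x k) GEMM."""
  flops = 2 * m * n * k
  bytes_accessed = (m * n + m * k + n * k) * bytes_per_element
  return flops / bytes_accessed

# Example: a square 1024 GEMM in 2-byte precision gives 1024 / 3, or roughly
# 341 flops per byte.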

class LinearOverlapped(Layer):
  def __init__(self, name, sys, batch_seq, c_in, c_out, tensor_par_comm_type,
               num_tiles, net_id, num_peers, conjugate=False,
               in_network_reduction=False, tp_overlap='pipe',
               needs_recompute=False, needs_recomm=False,
               activation_reused=False, activation_stored=True,
               output_stored=True):
    m, n, k = batch_seq, c_in, c_out
    self.tensor_par_comm_type = tensor_par_comm_type
    self.num_tiles = num_tiles
    self.net = sys.get_network(net_id)
    self.num_peers = num_peers
    self.conjugate = conjugate
    self.in_network_reduction = in_network_reduction
    self.tp_overlap = tp_overlap
    self._processed_flag = False
    if self.tensor_par_comm_type == 'rs_ag':
      if not conjugate:
        # AllGather case
        assert k % self.num_peers == 0
        # assert m % self.num_peers == 0         # this should be true for seq_par
        k = k // self.num_peers
        act_space = m * n // num_tiles
        act_grad_space = m * k
        act_net_buffer = m * n // num_tiles
        act_grad_net_buffer = 0
      else:
        # ReduceScatter case
        assert n % self.num_peers == 0
        # assert m % self.num_peers == 0         # this should be true for seq_par
        n = n // self.num_peers
        act_space = m * n
        act_grad_space = m * k // num_tiles
        act_net_buffer = 0
        act_grad_net_buffer = m * k // num_tiles
        #act_net_buffer = m * k // num_tiles
    else:
      if not conjugate:
        # AllReduce case
        assert k % self.num_peers == 0
        k = k // self.num_peers
        act_space = m * n
        act_grad_space = 0
        act_net_buffer = m * n // num_tiles
        act_grad_net_buffer = 0
      else:
        # Identity case
        assert n % self.num_peers == 0
        n = n // self.num_peers
        act_space = 0
        act_grad_space = m * k
        act_net_buffer = 0
        act_grad_net_buffer = m * k

    super().__init__(name,
                     sys,
                     fw_flops=2*m*n*k,
                     agrad_flops=2*m*n*k,
                     wgrad_flops=2*m*n*k,
                     inputs_size=m*n,
                     output_size=m*k,
                     weight_space=n*k,
                     weight_grads=n*k,
                     activation_space=act_space, # + act_net_buffer,
                     activation_grads=act_grad_space + act_grad_net_buffer,
                     optim_space=2*n*k,
                     needs_recompute=needs_recompute,
                     needs_recomm=needs_recomm,
                     activation_reused=activation_reused,
                     activation_stored=activation_stored,
                     output_stored=output_stored)

  def use_matrix_engine(self):
    return True

  def get_comm_bytes(self, stage, baseblock=True):
    if self.num_peers == 1:
      return 0
    split_comm = (self.tensor_par_comm_type == 'rs_ag') or (
      (self.tensor_par_comm_type == 'p2p_rs_ag') and not baseblock)
    ag_comm_size = self.inputs_size * self.bytes_per_element
    ar_rs_comm_size = self.output_size * self.bytes_per_element
    if stage == 'fw':
      if self.conjugate:
        # ReduceScatter or AllReduce on FW
        return ar_rs_comm_size
      else:
        if split_comm:
          # AllGather on FW
          return ag_comm_size
        else:
          # Identity on FW
          return 0
    if stage == 'agrad':
      # Comm sizes during FW and BW pass are the same
      if not self.conjugate:
        # ReduceScatter or AllReduce on BW
        return ag_comm_size
      else:
        if split_comm:
          # AllGather on BW
          return ar_rs_comm_size
        else:
          # Identity on BW
          return 0
    if stage == 'wgrad':
      if self.needs_recomm:
        return self.get_comm_bytes('fw', baseblock)
      else:
        return 0
    if stage == 'optim':
      return 0

  def get_comm_flops(self, stage, baseblock=True):
    return self.get_comm_bytes(stage, baseblock) / self.bytes_per_element

  def get_num_tiles(self):
    return self.num_tiles

  def get_comm_tile(self, stage, baseblock=True):
    return self.get_comm_bytes(stage, baseblock) / self.get_num_tiles()

  def compute_net_time(self, stage, baseblock=True):
    if self.num_peers == 1:
      return 0
    split_comm = (self.tensor_par_comm_type == 'rs_ag') or (
      (self.tensor_par_comm_type == 'p2p_rs_ag') and not baseblock)
    if self.conjugate:
      if split_comm:
        # ReduceScatter case
        fw_comm_type = 'reduce_scatter'
        bw_comm_type = 'all_gather'
      else:
        # AllReduce case
        fw_comm_type = 'all_reduce'
        bw_comm_type = None
      if not self.in_network_reduction:
        fw_flops = self.get_comm_flops(stage, baseblock) * (
          self.num_peers - 1) / self.num_peers
        fw_flop_time = fw_flops / self.sys.get_vector_throughput(fw_flops)
      else:
        fw_flop_time = 0
      bw_flop_time = 0
    else:
      if split_comm:
        # AllGather case
        fw_comm_type = 'all_gather'
        bw_comm_type = 'reduce_scatter'
      else:
        # Identity case
        fw_comm_type = None
        bw_comm_type = 'all_reduce'
      fw_flop_time = 0
      if not self.in_network_reduction:
        bw_flops = self.get_comm_flops(stage, baseblock) * (
          self.num_peers - 1) / self.num_peers
        bw_flop_time = bw_flops / self.sys.get_vector_throughput(bw_flops)
      else:
        bw_flop_time = 0
    if stage == 'fw':
      if fw_comm_type is None:
        return 0
      else:
        fw_net_time = self.net.time(
          fw_comm_type, self.get_comm_bytes(stage, baseblock), self.num_peers)
        return fw_net_time + fw_flop_time
    if stage == 'agrad':
      if bw_comm_type is None:
        return 0
      else:
        bw_net_time = self.net.time(
          bw_comm_type, self.get_comm_bytes(stage, baseblock), self.num_peers)
        return bw_net_time + bw_flop_time
    if stage == 'wgrad':
      if self.needs_recomm and fw_comm_type:
        # AllGather Redo (RS_AG only) or full recompute
        return self.net.time(
          fw_comm_type, self.get_comm_bytes(stage, baseblock), self.num_peers)
      else:
        return 0
    if stage == 'optim':
      return 0

  def compute_processing_time(self, stage):
    flop_time = self.compute_flops_time(stage)
    flop_time_slowed = flop_time / (1 - self.net.processor_usage)
    mem_time = self.compute_mem_time(stage)
    net_time = self.compute_net_time(stage)
    compute_time = self.sys.get_processing_time(flop_time, mem_time)
    if net_time == 0:
      time = compute_time
      net_exposed_time = 0
    else:
      compute_time_slowed = self.sys.get_processing_time(
        flop_time_slowed, mem_time)
      # Tiled time computed as a fraction of the full time, to model high
      # effective throughput when processing many consecutive tiles
      flop_tile = flop_time / self.num_tiles
      flop_tile_slowed = flop_time_slowed / self.num_tiles
      net_tile = net_time / self.num_tiles
      compute_tile = compute_time / self.num_tiles
      compute_tile_slowed = compute_time_slowed / self.num_tiles
      overlap_inflection = net_tile - flop_tile_slowed
      # There is one exposed comm tile if tp_overlap is 'pipe' (not 'ring'),
      # one exposed compute tile, and (num_tiles - 1) overlapped tiles, where
      # either compute or comm is exposed
      if overlap_inflection > 0:
        # Tcomm is larger than compute, excess is exposed
        # compute time itself is the compute + mem
        time = compute_tile + (self.num_tiles - 1) * compute_tile_slowed
        net_exposed_time = (self.num_tiles - 1) * overlap_inflection
      else:
        # Tcomm is smaller than compute and hidden, but it slows compute down
        # because part of the compute resources orchestrate the comm
        time = compute_tile + (self.num_tiles - 1) * compute_tile + (
          self.num_tiles - 1) * net_tile * self.net.processor_usage
        net_exposed_time = 0
      if self.tp_overlap == 'pipe':
        # With pipe overlap we need to add an exposed comm tile. Ring-based
        # overlap uses a special comm schedule that avoids sending the extra
        # tile at the beginning
        net_exposed_time += net_tile
        time += net_tile
    self.processing_time = time
    self.net_exposed_time = net_exposed_time
    self._processed_flag = True
    return self.processing_time

  def get_exposed_net_time(self, stage, baseblock=True):
    # Only valid after compute_processing_time(); until then it is None
    assert self._processed_flag
    return self.net_exposed_time

  def get_required_bandwidth(self, stage, baseblock=True):
    assert self._processed_flag
    net_tile_size = self.get_comm_tile(stage, baseblock)
    flop_time = self.compute_flops_time(stage)
    flop_time_slowed = flop_time / (1 - self.net.processor_usage)
    flop_tile_slowed = flop_time_slowed / self.num_tiles
    return net_tile_size / flop_tile_slowed
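

# Illustrative sketch (hypothetical helper, not part of Calculon's API):
# LinearOverlapped.compute_processing_time() splits a GEMM and its tensor
# parallel collective into num_tiles tiles, so the communication of one tile
# overlaps the compute of another. This simplified version assumes the layer
# is compute bound (compute time equals flop time, memory time ignored). It
# returns (processing_time, exposed_net_time) using the same per-tile formulas.
def sketch_tiled_overlap(flop_time, net_time, num_tiles, processor_usage,
                         tp_overlap='pipe'):
  """Models per-tile overlap of compute and tensor-parallel communication."""
  if net_time == 0:
    return flop_time, 0.0
  # Compute slows down while part of the processor drives communication
  flop_tile = flop_time / num_tiles
  flop_tile_slowed = flop_tile / (1 - processor_usage)
  net_tile = net_time / num_tiles
  overlap_inflection = net_tile - flop_tile_slowed
  if overlap_inflection > 0:
    # Communication bound: the excess over compute is exposed on each tile
    total_time = flop_tile + (num_tiles - 1) * flop_tile_slowed
    exposed = (num_tiles - 1) * overlap_inflection
  else:
    # Compute bound: communication is hidden but steals processor cycles
    total_time = num_tiles * flop_tile + (
      num_tiles - 1) * net_tile * processor_usage
    exposed = 0.0
  if tp_overlap == 'pipe':
    # Pipe overlap leaves one communication tile exposed
    total_time += net_tile
    exposed += net_tile
  return total_time, exposed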

class BatchMatMul(Layer):
  def __init__(self, name, sys, batch, size_a, contraction_size, size_b,
               needs_recompute=False, activation_reused=False,
               activation_stored=True, output_stored=True):
    m, n, k = size_a, contraction_size, size_b
    super().__init__(name,
                     sys,
                     fw_flops=batch*2*m*n*k,
                     agrad_flops=batch*2*2*m*n*k,
                     inputs_size=batch*(m*n+n*k),
                     output_size=batch*m*k,
                     activation_space=batch*(m*n+n*k),
                     activation_grads=batch*m*k,
                     needs_recompute=needs_recompute,
                     activation_reused=activation_reused,
                     activation_stored=activation_stored,
                     output_stored=output_stored)

  def use_matrix_engine(self):
    return True
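

# Illustrative sketch (hypothetical helper, not part of Calculon's API):
# BatchMatMul models `batch` independent (m x n) @ (n x k) products. The
# forward pass costs batch * 2*m*n*k flops. The backward pass computes
# gradients w.r.t. both operands, so agrad_flops is twice that. There are no
# weights, so wgrad_flops stays 0.
def sketch_bmm_flops(batch, m, n, k):
  """Returns (fw_flops, agrad_flops) for a batched matrix multiply."""
  fw_flops = batch * 2 * m * n * k
  agrad_flops = 2 * fw_flops
  return fw_flops, agrad_flops

# Example (assumed mapping, for illustration only): for attention scores,
# batch would be microbatch * heads, m = k = sequence length and n = head size.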

# https://kratzert.github.io/2016/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer.html
# https://cthorey.github.io./blog/2016/backpropagation/
class LayerNorm(Layer):
  def __init__(self, name, sys, act_size, hidden,
               needs_recompute=False, activation_reused=False,
               activation_stored=True, output_stored=True):
    super().__init__(name,
                     sys,
                     fw_flops=9*act_size,
                     agrad_flops=14*act_size,
                     wgrad_flops=7*act_size,
                     inputs_size=act_size,
                     output_size=act_size,
                     activation_space=act_size,
                     activation_grads=act_size,
                     weight_space=2*hidden,
                     weight_grads=2*hidden,
                     optim_space=2*2*hidden,
                     needs_recompute=needs_recompute,
                     activation_reused=activation_reused,
                     activation_stored=activation_stored,
                     output_stored=output_stored)


class DropOut(Layer):
  def __init__(self, name, sys, act_size,
               needs_recompute=False, activation_reused=False,
               activation_stored=True, output_stored=True):
    super().__init__(name,
                     sys,
                     fw_flops=act_size,
                     agrad_flops=act_size,
                     inputs_size=act_size,
                     output_size=act_size,
                     activation_space=act_size,
                     activation_grads=act_size,
                     needs_recompute=needs_recompute,
                     activation_reused=activation_reused,
                     activation_stored=activation_stored,
                     output_stored=output_stored)


  # Account for the DropOut mask of bool type, which takes 1 B per element;
  # the mask is the only DropOut activation
  def get_activation(self):
    return self.activation_space

  def get_activation_grad(self):
    return self.activation_grads

  def get_fw_mem_accessed(self):
    mask_size = self.activation_space
    mem_accessed = self.inputs_size + self.output_size
    mem_accessed *= self.bytes_per_element
    mem_accessed += mask_size
    return mem_accessed

  def get_agrad_mem_accessed(self):
    return self.get_fw_mem_accessed()
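

# Illustrative sketch (hypothetical helper, not part of Calculon's API):
# DropOut.get_fw_mem_accessed() reads the input and writes the output at
# bytes_per_element each. It also accesses a 1-byte boolean mask per element,
# which is not scaled by the datatype size.
def sketch_dropout_fw_bytes(act_size, bytes_per_element):
  """Returns forward bytes accessed by DropOut, including its bool mask."""
  mask_bytes = act_size  # one byte per element, independent of the datatype
  return 2 * act_size * bytes_per_element + mask_bytes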


# https://mlfromscratch.com/activation-functions-explained/#/
class GeLU(Layer):
  def __init__(self, name, sys, act_size,
               needs_recompute=False, activation_reused=False,
               activation_stored=True, output_stored=True,
               fused=False):
    # Fused GeLU runs right after previous Linear layer and does not store
    # activations or gradients
    self._fused = fused
    if fused:
      eff_act_space = 0
      eff_act_grads = 0
    else:
      eff_act_space = act_size
      eff_act_grads = act_size
    super().__init__(name, sys, fw_flops=8*act_size, agrad_flops=13*act_size,
                     inputs_size=act_size, output_size=act_size,
                     activation_space=eff_act_space,
                     activation_grads=eff_act_grads,
                     needs_recompute=needs_recompute,
                     activation_reused=activation_reused,
                     activation_stored=activation_stored,
                     output_stored=output_stored)

  def get_agrad_mem_accessed(self):
    return self.get_fw_mem_accessed()


# https://automata88.medium.com/how-to-implement-the-softmax-derivative-independently-from-any-loss-function-ae6d44363a9d
class SoftMax(Layer):
  def __init__(self, name, sys, act_size,
               needs_recompute=False, activation_reused=False,
               activation_stored=True, output_stored=True):
    super().__init__(name,
                     sys,
                     fw_flops=5*act_size,
                     agrad_flops=8*act_size,
                     inputs_size=act_size,
                     output_size=act_size,
                     activation_space=act_size,
                     activation_grads=act_size,
                     needs_recompute=needs_recompute,
                     activation_reused=activation_reused,
                     activation_stored=activation_stored,
                     output_stored=output_stored)

  def get_agrad_mem_accessed(self):
    return self.get_fw_mem_accessed()
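

# Illustrative sketch (not part of calculon): softmax over one row and its
# derivative applied as a vector-Jacobian product,
#   dx_i = y_i * (dy_i - sum_j dy_j * y_j),
# which avoids materializing the full Jacobian. SoftMax above only charges
# constant per-element costs (fw_flops=5*act_size, agrad_flops=8*act_size);
# this sketch does not try to reproduce those counts. _softmax and
# _softmax_backward are hypothetical helper names.
import math


def _softmax(xs):
  """Numerically stable softmax: subtract the max before exponentiating."""
  m = max(xs)
  exps = [math.exp(x - m) for x in xs]
  total = sum(exps)
  return [e / total for e in exps]


def _softmax_backward(ys, dys):
  """Given ys = softmax(xs) and upstream gradients dys, return dL/dxs."""
  dot = sum(dy * y for dy, y in zip(dys, ys))
  return [y * (dy - dot) for dy, y in zip(dys, ys)]
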


# https://explained.ai/matrix-calculus/#sec:1.4.2
class ElementWise(Layer):
  def __init__(self, name, sys, operand1, operand2,
               needs_recompute=False, activation_reused=False,
               activation_stored=True, output_stored=True):
    act_size = max(operand1, operand2)
    super().__init__(name,
                     sys,
                     fw_flops=act_size,
                     agrad_flops=(operand1+operand2),
                     inputs_size=(operand1+operand2),
                     output_size=act_size,
                     activation_space=(operand1+operand2),
                     activation_grads=act_size,
                     needs_recompute=needs_recompute,
                     activation_reused=activation_reused,
                     activation_stored=activation_stored,
                     output_stored=output_stored)
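

# Illustrative sketch (not part of calculon): the cost model ElementWise
# passes to Layer, restated as a plain function so the numbers are easy to
# inspect. The output has max(operand1, operand2) elements (the assumption
# here is that the smaller operand is broadcast), the forward pass costs one
# op per output element, and the backward pass produces one gradient element
# per input element. _elementwise_costs is a hypothetical helper name.

def _elementwise_costs(operand1, operand2):
  """Mirror the sizes/flops the ElementWise constructor hands to Layer."""
  act_size = max(operand1, operand2)
  return {
    'fw_flops': act_size,
    'agrad_flops': operand1 + operand2,
    'inputs_size': operand1 + operand2,
    'output_size': act_size,
    'activation_space': operand1 + operand2,
    'activation_grads': act_size,
  }

# Usage: a residual add of two equal-size tensors, e.g.
# _elementwise_costs(n, n), costs n forward ops and 2 * n backward ops.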


# Splits activation on the forward pass, sums gradients on the backward pass
class Fork(Layer):
  def __init__(self, name, sys, act_size, num_users,
               needs_recompute=False, activation_reused=False,
               activation_stored=True, output_stored=True):
    self.num_users = num_users
    super().__init__(name,
                     sys,
                     inputs_size=act_size,
                     agrad_flops=num_users*act_size,
                     activation_space=act_size,
                     # Gradients from num_users are accumulated into a single
                     # buffer that is already accounted for in the other
                     # layers; use 0 here to avoid double counting
                     activation_grads=0,
                     needs_recompute=needs_recompute,
                     activation_reused=activation_reused,
                     activation_stored=activation_stored,
                     output_stored=output_stored)

  def get_fw_mem_accessed(self):
    return 0

  def get_agrad_mem_accessed(self):
    return self.activation_space * self.bytes_per_element * (
      self.num_users + 1)
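

# Illustrative sketch (not part of calculon): what a Fork does numerically.
# The forward pass hands the same tensor to every user; the backward pass sums
# the gradients the users send back. One plausible reading (an assumption) of
# the (num_users + 1) factor in Fork.get_agrad_mem_accessed is "read
# num_users gradient tensors, write one summed tensor". Tensors are plain
# Python lists here; _fork_forward, _fork_backward and _fork_agrad_bytes are
# hypothetical helper names.

def _fork_forward(x, num_users):
  # The same object for every user, not copies
  return [x] * num_users


def _fork_backward(grads):
  assert grads, 'need at least one user gradient'
  # Element-wise sum of the gradients from all users
  return [sum(vals) for vals in zip(*grads)]


def _fork_agrad_bytes(act_size, bytes_per_element, num_users):
  # num_users gradient reads plus one write of the accumulated gradient
  return act_size * bytes_per_element * (num_users + 1)
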


class TPComm(Layer):

  def __init__(self, name, sys, act_size, net_id, num_peers, tensor_par_comm_type,
               conjugate=False, in_network_reduction=False,
               needs_recomm=False, activation_reused=False,
               activation_stored=True, output_stored=True):
    self.net = sys.get_network(net_id)
    self.num_peers = num_peers
    self.tensor_par_comm_type = tensor_par_comm_type
    self.comm_size = act_size
    self.conjugate = conjugate
    if self.num_peers == 1:
      fw_flops = 0
      bw_flops = 0
      in_size = 0
      out_size = 0
    else:
      if not self.conjugate:
        # FW pass Identity/AllGather, BW pass AllReduce/ReduceScatter
        fw_flops = 0
        if not in_network_reduction:
          bw_flops = act_size * (self.num_peers - 1) / self.num_peers
        else:
          bw_flops = 0
        in_size = act_size
        out_size = act_size
      else:
        # Conjugate function is opposite
        if not in_network_reduction:
          fw_flops = act_size * (self.num_peers - 1) / self.num_peers
        else:
          fw_flops = 0
        bw_flops = 0
        in_size = act_size
        out_size = act_size
    super().__init__(name,
                     sys,
                     fw_flops=fw_flops,
                     agrad_flops=bw_flops,
                     inputs_size=in_size,
                     output_size=out_size,
                     activation_space=in_size,
                     activation_grads=out_size,
                     needs_recomm=needs_recomm,
                     activation_reused=activation_reused,
                     activation_stored=activation_stored,
                     output_stored=output_stored)

  def get_activation(self):
    if self.tensor_par_comm_type == 'rs_ag':
      return self.activation_space * self.bytes_per_element / self.num_peers
    else:
      if self.conjugate:
        return self.activation_space * self.bytes_per_element
      else:
        # Identity
        return 0

  def get_fw_mem_accessed(self):
    if not self.tensor_par_comm_type == 'rs_ag' and not self.conjugate:
      # Identity
      return 0
    else:
      return super().get_fw_mem_accessed()

  def get_activation_grad(self):
    if self.tensor_par_comm_type == 'rs_ag':
      return self.activation_space * self.bytes_per_element / self.num_peers
    else:
      if not self.conjugate:
        return self.activation_grads * self.bytes_per_element
      else:
        # Identity
        return 0

  def get_agrad_mem_accessed(self):
    if not self.tensor_par_comm_type == 'rs_ag' and self.conjugate:
      # Identity
      return 0
    else:
      return super().get_agrad_mem_accessed()

  def get_comm_bytes(self, stage, baseblock=True):
    if self.num_peers == 1:
      return 0
    split_comm = (self.tensor_par_comm_type == 'rs_ag') or (
      (self.tensor_par_comm_type == 'p2p_rs_ag') and not baseblock)
    if (not split_comm and (self.conjugate and stage == 'agrad' or
        not self.conjugate and stage == 'fw')):
      # Identity FW or AllReduce BW
      return 0
    else:
      if stage == 'fw' or stage == 'agrad':
        return self.comm_size * self.bytes_per_element
      if stage == 'wgrad' and self.needs_recomm and (
          split_comm or self.conjugate):
        # With AG redo, we need recomm on both the FW pass (not self.conjugate)
        # and the BW pass (self.conjugate)
        return self.comm_size * self.bytes_per_element
      else:
        # optim has no comm, and neither does wgrad unless recomm is needed
        # (e.g. the seq_par_ag_redo flag with RS_AG)
        return 0

  def compute_net_time(self, stage, baseblock=True):
    if self.num_peers == 1:
      return 0
    split_comm = (self.tensor_par_comm_type == 'rs_ag') or (
      (self.tensor_par_comm_type == 'p2p_rs_ag') and not baseblock)
    net_compute_time = super().compute_processing_time(stage)
    if split_comm:
      if self.conjugate:
        # ReduceScatter case
        fw_net_time = self.net.time('reduce_scatter',
          self.get_comm_bytes(stage, baseblock), self.num_peers)
        bw_net_time = self.net.time('all_gather',
          self.get_comm_bytes(stage, baseblock), self.num_peers)
      else:
        # AllGather case
        fw_net_time = self.net.time('all_gather',
          self.get_comm_bytes(stage, baseblock), self.num_peers)
        bw_net_time = self.net.time('reduce_scatter',
          self.get_comm_bytes(stage, baseblock), self.num_peers)
    else:
      if self.conjugate:
        fw_net_time = self.net.time('all_reduce',
          self.get_comm_bytes(stage, baseblock), self.num_peers)
        bw_net_time = 0
      else:
        fw_net_time = 0
        bw_net_time = self.net.time('all_reduce',
          self.get_comm_bytes(stage, baseblock), self.num_peers)
    if stage == 'fw':
      return fw_net_time + net_compute_time
    elif stage == 'agrad':
      return bw_net_time + net_compute_time
    elif stage == 'wgrad':
      # With AG redo, we need recomm on both the FW pass (not self.conjugate)
      # and the BW pass (self.conjugate)
      if self.needs_recomm:
        return fw_net_time + net_compute_time
      else:
        return 0
    elif stage == 'optim':
      return 0
    else:
      raise Exception(f'Bad compute stage : {stage}')
    return 0

  def get_exposed_net_time(self, stage, baseblock=True):
    # Only use after calling compute_processing_time(); otherwise the values
    # are still set to None
    return self.compute_net_time(stage, baseblock)

  def compute_processing_time(self, stage):
    return 0
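

# Illustrative sketch (not part of calculon): the collectives that
# TPComm.compute_net_time charges on the forward ('fw') and backward
# ('agrad') passes, restated as a lookup. The split case covers
# tensor_par_comm_type == 'rs_ag', and 'p2p_rs_ag' outside the base block;
# every other case is AllReduce on one pass and the identity (None) on the
# other. The reduction-flops formula matches the act_size * (n - 1) / n term
# in TPComm, which is what a ring-style reduction does per peer (the ring
# algorithm is an assumption here). _tp_collectives and _reduction_flops are
# hypothetical helper names.

def _tp_collectives(tensor_par_comm_type, conjugate, baseblock=True):
  """Return (fw_collective, agrad_collective); None means identity."""
  split_comm = tensor_par_comm_type == 'rs_ag' or (
    tensor_par_comm_type == 'p2p_rs_ag' and not baseblock)
  if split_comm:
    if conjugate:
      return ('reduce_scatter', 'all_gather')
    return ('all_gather', 'reduce_scatter')
  if conjugate:
    return ('all_reduce', None)
  return (None, 'all_reduce')


def _reduction_flops(act_size, num_peers):
  # Additions performed by each peer when reducing act_size elements
  return act_size * (num_peers - 1) / num_peers
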


================================================
FILE: calculon/llm/llm.py
================================================
"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  https://www.apache.org/licenses/LICENSE-2.0
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
"""

from calculon import *
from .layers import *


class Llm:
  """
  This implements the transformer with tensor, pipeline, and data parallelism.
  Using it follows this pattern:
  1. Initialize the model with certain model parameters
  2. Compile it with certain optimizations and parallelization strategies
  3. Run on particular hardware system
  """

  class Application:
    """Specifies the application configuration."""
    def __init__(self, cfg):
      self.cfg = cfg
      self.hidden = cfg['hidden']
      self.feedforward = cfg['feedforward']
      self.seq_size = cfg['seq_size']
      self.attn_heads = cfg['attn_heads']
      self.attn_size = cfg['attn_size']
      self.num_blocks = cfg['num_blocks']

    def num_parameters(self):
      # https://cs.stanford.edu/~matei/papers/2021/sc_megatron_lm.pdf
      # Equation 2
      p = 2 * self.hidden * self.feedforward                   # MLP weights
      p += 4 * self.hidden * self.attn_heads * self.attn_size  # Attn weights
      p += self.hidden + self.feedforward                      # biases MLP
      p += 3 * self.attn_heads * self.attn_size + self.hidden  # biases Attn
      p += 2 * 2 * self.hidden                                 # layer norm
      p *= self.num_blocks                                     # per each block
      # Token embeddings (vocabulary size hard-coded to 51200) + positional
      p += (51200 + self.seq_size) * self.hidden               # embeddings
      return p
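      # Worked example (illustrative only; this GPT-3 175B-like shape is an
      # assumption, not read from models/gpt3-175B.json): hidden=12288,
      # feedforward=49152, attn_heads=96, attn_size=128, num_blocks=96,
      # seq_size=2048.
      #   MLP weights   2 * 12288 * 49152       =   1,207,959,552
      #   Attn weights  4 * 12288 * 96 * 128    =     603,979,776
      #   MLP biases    12288 + 49152           =          61,440
      #   Attn biases   3 * 96 * 128 + 12288    =          49,152
      #   Layer norms   2 * 2 * 12288           =          49,152
      #   Per block                             =   1,812,099,072
      #   x 96 blocks                           = 173,961,510,912
      #   Embeddings    (51200 + 2048) * 12288  =     654,311,424
      #   Total                                 = 174,615,822,336 (~174.6B)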

  class Execution:
    """Specifies the execution configuration."""

    @staticmethod
    def fields():
      return (
        'num_procs', 'tensor_par', 'pipeline_par', 'data_par', 'tensor_par_net',
        'pipeline_par_net', 'data_par_net', 'batch_size', 'microbatch_size',
        'datatype', 'fused_activation', 'attention_type', 'activation_recompute',
        'pipeline_interleaving', 'optimizer_sharding', 'tensor_par_comm_type',
        'tensor_par_overlap', 'seq_par_ag_redo', 'data_par_overlap',
        'weight_offload', 'activations_offload', 'optimizer_offload', 'training')

    @staticmethod
    def from_json(cfg):
      assert set(cfg.keys()) == set(Llm.Execution.fields())
      values = [cfg[field] for field in Llm.Execution.fields()]
      return Llm.Execution(*values)

    def __init__(self, num_procs, tensor_par, pipeline_par, data_par,
                 tensor_par_net, pipeline_par_net, data_par_net,
                 batch_size, microbatch_size, datatype,
                 fused_activation, attention_type, activation_recompute,
                 pipeline_interleaving, optimizer_sharding,
                 tensor_par_comm_type, tensor_par_overlap,
                 seq_par_ag_redo, data_par_overlap, weight_offload,
                 activations_offload, optimizer_offload, training):
      self.training = training
      self.num_procs = num_procs
      assert self.num_procs > 0
      self.tensor_par = tensor_par
      assert self.tensor_par > 0
      self.pipeline_par = pipeline_par
      assert self.pipeline_par > 0
      self.data_par = data_par
      assert self.data_par > 0
      assert self.num_procs == self.tensor_par * self.pipeline_par * \
        self.data_par, 'tensor * pipeline * data parallelism != num_procs'
      self.tensor_par_net = tensor_par_net
      self.pipeline_par_net = pipeline_par_net
      self.data_par_net = data_par_net
      self.global_batch_size = batch_size
      assert self.global_batch_size > 0
      self.microbatch_size = microbatch_size
      assert self.microbatch_size > 0
      assert self.global_batch_size % self.data_par == 0
      self._local_batch_size = self.global_batch_size // self.data_par
      assert self._local_batch_size % self.microbatch_size == 0
      self._num_microbatches = self._local_batch_size // self.microbatch_size
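      # Worked example with hypothetical values: tensor_par=4,
      # pipeline_par=64 and data_par=12 give num_procs = 4 * 64 * 12 = 3072;
      # batch_size=1536 then yields a local batch of 1536 // 12 = 128, and
      # microbatch_size=4 yields 128 // 4 = 32 microbatches per data-parallel
      # replica.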
      self.datatype = datatype
      self.fused_activation = fused_activation
      self.attention_type = attention_type
      assert self.attention_type in ['multihead', 'multiquery']
      self.activation_recompute = activation_recompute
      assert self.activation_recompute in ['full', 'attn_only', 'none']
      if self.activation_recompute in ['full', 'attn_only']:
        assert self.training, "We only perform recompute during training"
      self.pipeline_interleaving = pipeline_interleaving
      assert self.pipeline_interleaving > 0, \
        f'Bad pipeline interleaving of {self.pipeline_interleaving}'
      if self.pipeline_par == 1:
        assert self.pipeline_interleaving == 1, \
        f'Bad pipeline interleaving of {self.pipeline_interleaving} with PP=1'
      self.optimizer_sharding = optimizer_sharding
      if self.optimizer_sharding:
        assert self.data_par > 1, "We only perform optimizer sharding with DP > 1"
      self.tensor_par_comm_type = tensor_par_comm_type
      self.in_network_reduction = False
      assert self.tensor_par_comm_type in ['ar', 'p2p_rs_ag', 'rs_ag']
      self.tensor_par_overlap = tensor_par_overlap
      assert self.tensor_par_overlap in ['none', 'ring', 'pipe']
      if self.tensor_par_overlap != 'none':
        assert self.tensor_par > 1, "We only perform TP comm overlap with TP > 1"
      self._sequence_par = self.tensor_par_comm_type == 'rs_ag'
      self.seq_par_ag_redo = seq_par_ag_redo
      if self.seq_par_ag_redo:
        assert self.tensor_par_comm_type == 'rs_ag', "We only redo AG comm"
        assert self._sequence_par, "We only redo AG with sequence parallelism"
        assert self.activation_recompute != 'full', \
          "We assume no extra AG with full recompute"
      self._pipeline_par_rs_ag = \
        self.tensor_par_comm_type in ['p2p_rs_ag', 'rs_ag']
      self.data_par_overlap = data_par_overlap
      if self.data_par_overlap:
        assert self.training, "We only perform DP comm overlap during training"
        assert self.data_par > 1, "We only perform DP comm overlap with DP > 1"
      self.weight_offload = weight_offload
      self.activations_offload = activations_offload
      self.optimizer_offload = optimizer_offload
      if self.optimizer_offload:
        assert self.training, \
          "We only perform optimizer offloading during training"

    def get_json(self):
      keys = Llm.Execution.fields()
      values = [
        self.num_procs, self.tensor_par, self.pipeline_par, self.data_par, self.tensor_par_net,
        self.pipeline_par_net, self.data_par_net, self.global_batch_size, self.microbatch_size,
        self.datatype, self.fused_activation, self.attention_type, self.activation_recompute,
        self.pipeline_interleaving, self.optimizer_sharding, self.tensor_par_comm_type,
        self.tensor_par_overlap, self.seq_par_ag_redo, self.data_par_overlap,
        self.weight_offload, self.activations_offload, self.optimizer_offload, self.training
      ]
      assert len(keys) == len(values)
      return dict(zip(keys, values))

    def get_peers_json(self):
      peers = {}
      for di in range(self.data_par):
        for pi in range(self.pipeline_par):
          for ti in range(self.tensor_par):
            nid = (di * self.tensor_par * self.pipeline_par +
                   pi * self.tensor_par +
                   ti)
            peers[nid] = {}

            # tensor parallelism peers
            if self.tensor_par > 1:
              peers[nid]['tensor'] = []
              for ti2 in range(self.tensor_par):
                pid = (di * self.tensor_par * self.pipeline_par +
                       pi * self.tensor_par +
                       ti2)
                peers[nid]['tensor'].append(pid)

            # pipeline parallelism peer
            if self.pipeline_par > 1:
              peers[nid]['pipeline'] = None
              pi2 = (pi + 1) % self.pipeline_par
              pid = (di * self.tensor_par * self.pipeline_par +
                     pi2 * self.tensor_par +
                     ti)
              peers[nid]['pipeline'] = pid

            # data parallelism peers
            if self.data_par > 1:
              peers[nid]['data'] = []
              for di2 in range(self.data_par):
                pid = (di2 * self.tensor_par * self.pipeline_par +
                       pi * self.tensor_par +
                       ti)
                peers[nid]['data'].append(pid)
      return peers
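      # Worked example (hypothetical sizes): with tensor_par=2, pipeline_par=2
      # and data_par=2, nid = di * 4 + pi * 2 + ti. Node 5 (di=1, pi=0, ti=1)
      # gets tensor peers [4, 5], pipeline peer 7 (pi=1) and data peers
      # [1, 5]; node 7 (pi=1) wraps around to pipeline peer 5 (pi=0). The
      # tensor and data peer lists include the node itself.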


  # This is used for errors where the user may not be fully aware of
  # limitations. Use it like this:
  #   raise self.Error(f'Foo bar {num1} is not {num2}')
  class Error(Exception):
    pass

  @staticmethod
  def _factors(x):
    for cand in range(1, x + 1):
      if x % cand == 0:
        yield cand

  @staticmethod
  def get_all_tensor_parallelisms(num_procs, hidden, attn_heads):
    for cand in Llm._factors(num_procs):
      if hidden % cand == 0 and attn_heads % cand == 0:
        yield cand

  @staticmethod
  def get_all_pipeline_parallelisms(num_procs, tensor_par, num_blocks):
    assert num_procs % tensor_par == 0
    max_pp = min(num_procs // tensor_par, num_blocks)
    for cand in Llm._factors(max_pp):
      if (num_procs % (tensor_par * cand) == 0 and
          num_blocks % cand == 0):
        yield cand
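    # Worked example (hypothetical sizes): num_procs=64, tensor_par=8 and
    # num_blocks=12 give max_pp = min(64 // 8, 12) = 8, whose factors are
    # 1, 2, 4, 8; 8 is rejected because 12 % 8 != 0, so this yields 1, 2, 4.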

  @staticmethod
  def get_data_parallelism(num_procs, tensor_par, pipeline_par):
    assert num_procs % (tensor_par * pipeline_par) == 0, \
      f'np={num_procs} tp={tensor_par} pp={pipeline_par}'
    return num_procs // (tensor_par * pipeline_par)

  @staticmethod
  def get_valid_pipeline_interleavings(num_blocks, pipeline_par):
    assert num_blocks % pipeline_par == 0
    if pipeline_par == 1:
      yield 1
    else:
      max_ppint = num_blocks // pipeline_par
      yield from Llm._factors(max_ppint)

  @staticmethod
  def get_valid_microbatch_sizes(
      seq_size, tensor_par, data_par, global_batch_size, pipeline_par):
    assert global_batch_size % data_par == 0
    local_batch_size = global_batch_size // data_par
    for cand in Llm._factors(local_batch_size):
      batch_seq = cand * seq_size
      if batch_seq % tensor_par == 0:
        yield cand

  @staticmethod
  def can_redo_ag(tensor_par_comm_type, activation_recompute):
    return tensor_par_comm_type == 'rs_ag' and activation_recompute != 'full'

  def __init__(self, app, log):
    assert isinstance(app, self.Application)
    self.app = app
    self.log = log

    # Set during compile
    self.exe = None

    # Set during run
    self.sys = None

    # State of calling compile() and run()
    self._compiled = False
    self._executed = False

    # Holds the layers in a single block
    self._llm_block = []

    # A chunk is the set of blocks a processor runs on a microbatch before
    # passing it to the next processor in the pipeline. Each chunk of N
    # blocks is modeled as a base block repeated N-1 times followed by one
    # edge block. Recommunication time is the same in base and edge blocks.
    self._blocks_per_proc = None
    self._bubble_reduction_blocks = None
    self._blocks_per_chunk = None
    self._chunks_per_proc = None
    self._baseblocks_per_chunk = None
    self._edgeblocks_per_chunk = None

    # Misc compilation values
    self._bytes_per_element = None
    self._batch_seq = None
    self._batch_seq_par = None
    self._activation_size = None
    self._seq_par_activation_size = None

    # Assignments to specific networks
    self._tp_net = None
    self._pp_net = None
    self._dp_net = None

    # metrics collected after run for each microbatch
    self._block_fw_flops = None
    self._block_fw_flops_time = None
    self._block_fw_mem_accessed = None
    self._block_fw_mem_time = None
    self._block_fw_time = None
    self._block_re_flops = None
    self._block_re_flops_time = None
    self._block_re_mem_accessed = None
    self._block_re_mem_time = None
    self._block_re_time = None
    self._block_agrad_flops = None
    self._block_agrad_flops_time = None
    self._block_agrad_mem_accessed = None
    self._block_agrad_mem_time = None
    self._block_agrad_time = None
    self._block_wgrad_flops = None
    self._block_wgrad_flops_time = None
    self._block_wgrad_mem_accessed = None
    self._block_wgrad_mem_time = None
    self._block_wgrad_time = None
    self._block_optim_flops = None
    self._block_optim_flops_time = None
    self._block_optim_mem_accessed = None
    self._block_optim_mem_time = None
    self._block_optim_time = None

    self._baseblock_fw_tp_size = None
    self._edgeblock_fw_tp_size = None
    self._baseblock_agrad_tp_size = None
    self._edgeblock_agrad_tp_size = None
    self._baseblock_recomm_size = None
    self._edgeblock_recomm_size = None
    self._block_fw_pp_size = None
    self._block_bw_pp_size = None
    self._block_dp_size = None
    self._baseblock_fw_time_no_offload = None
    self._edgeblock_fw_time_no_offload = None
    self._baseblock_bw_time_no_offload = None
    self._edgeblock_bw_time_no_offload = None
    self._baseblock_fw_offload_overhead = None
    self._edgeblock_fw_offload_overhead = None
    self._baseblock_bw_offload_overhead = None
    self._edgeblock_bw_offload_overhead = None
    self._baseblock_fw_time = None
    self._edgeblock_fw_time = None
    self._baseblock_bw_time = None
    self._edgeblock_bw_time = None
    self._block_dp_time = None
    self._tp_bw_overlap_req = None
    self._dp_bw_overlap_req_chunk = None
    self._dp_bw_overlap_req_tail = None

    self._block_weight_space = None
    self._block_act_working_space = None
    self._block_act_storage_space = None
    self._block_act_checkpoint_size = None
    self._block_weight_grad_space = None
    self._block_weight_grad_space_no_sharding = None
    self._block_act_grad_space = None
    self._block_optimizer_space = None

    # Top level memory usage stats
    self._weight_space = None
    self._act_space = None
    self._act_checkpoint_size = None
    self._weight_grad_space = None
    self._act_grad_space = None
    self._optimizer_space = None

    # Top level throughput stats
    self._fw_flops = None
    self._fw_flops_time = None
    self._fw_mem_accessed = None
    self._fw_mem_time = None
    self._fw_time = None
    self._baseblock_fw_tp_time = None
    self._edgeblock_fw_tp_time = None
    self._baseblock_fw_tp_time_exposed = None
    self._edgeblock_fw_tp_time_exposed = None
    self._re_flops = None
    self._re_flops_time = None
    self._re_mem_accessed = None
    self._re_mem_time = None
    self._re_time = None
    self._baseblock_recomm_time = None
    self._edgeblock_recomm_time = None
    self._baseblock_recomm_time_exposed = None
    self._edgeblock_recomm_time_exposed = None
    self._agrad_flops = None
    self._agrad_flops_time = None
    self._agrad_mem_accessed = None
    self._agrad_mem_time = None
    self._baseblock_agrad_tp_time = None
    self._edgeblock_agrad_tp_time = None
    self._baseblock_agrad_tp_time_exposed = None
    self._edgeblock_agrad_tp_time_exposed = None
    self._agrad_time = None
    self._wgrad_flops = None
    self._wgrad_flops_time = None
    self._wgrad_mem_accessed = None
    self._wgrad_mem_time = None
    self._wgrad_time = None
    self._optim_flops = None
    self._optim_flops_time = None
    self._optim_mem_accessed = None
    self._optim_mem_time = None
    self._optim_time = None

    # Top level network stats
    self._tp_comm_time_exposed = None
    self._tp_comm_time_link = None
    self._recomm_time_exposed = None
    self._recomm_time_link = None
    self._pp_comm_time_exposed = None
    self._pp_comm_time_link = None
    self._dp_comm_time_exposed = None
    self._dp_comm_time_link = None
    self._bubble_time = None

  @staticmethod
  def get_stats_fields():
    return (
      'block_fw_flops',
      'block_fw_flops_time',
      'block_fw_mem_accessed',
      'block_fw_mem_time',
      'block_fw_time',
      'baseblock_fw_tp_time',
      'edgeblock_fw_tp_time',
      'baseblock_fw_tp_time_exposed',
      'edgeblock_fw_tp_time_exposed',
      'block_re_flops',
      'block_re_flops_time',
      'block_re_mem_accessed',
      'block_re_mem_time',
      'block_re_time',
      'baseblock_recomm_time',
      'edgeblock_recomm_time',
      'baseblock_recomm_time_exposed',
      'edgeblock_recomm_time_exposed',
      'block_agrad_flops',
      'block_agrad_flops_time',
      'block_agrad_mem_accessed',
      'block_agrad_mem_time',
      'block_agrad_time',
      'baseblock_agrad_tp_time',
      'edgeblock_agrad_tp_time',
      'baseblock_agrad_tp_time_exposed',
      'edgeblock_agrad_tp_time_exposed',
      'block_wgrad_flops',
      'block_wgrad_flops_time',
      'block_wgrad_mem_accessed',
      'block_wgrad_mem_time',
      'block_wgrad_time',
      'block_optim_flops',
      'block_optim_flops_time',
      'block_optim_mem_accessed',
      'block_optim_mem_time',
      'block_optim_time',

      'baseblock_fw_tp_size',
      'edgeblock_fw_tp_size',
      'baseblock_bw_tp_size',
      'edgeblock_bw_tp_size',
      'baseblock_recomm_size',
      'edgeblock_recomm_size',
      'block_fw_pp_size',
      'block_bw_pp_size',
      'block_dp_size',
      'tp_bw_overlap_req',
      'dp_bw_overlap_req_chunk',
      'dp_bw_overlap_req_tail',

      'block_weight_space',
      'block_act_working_space',
      'block_act_storage_space',
      'block_act_checkpoint_size',
      'block_weight_grad_space',
      'block_weight_grad_space_no_sharding',
      'block_act_grad_space',
      'block_optimizer_space',

      'weight_space_with_offload',
      'act_space_with_offload',
      'act_checkpoint_size_with_offload',
      'act_grad_space_with_offload',
      'weight_grad_space_with_offload',
      'optimizer_space_with_offload',

      'weight_space',
      'act_space',
      'act_checkpoint_size',
      'act_grad_space',
      'weight_grad_space',
      'optimizer_space',

      'fw_time',
      'bw_time',
      'optim_step_time',
      'recompute_time',
      'recomm_link_time',
      'recomm_exposed_time',
      'bubble_time',
      'tp_comm_link_time',
      'pp_comm_link_time',
      'dp_comm_link_time',
      'tp_comm_exposed_time',
      'pp_comm_exposed_time',
      'dp_comm_exposed_time',
      'fw_offload_exposed_time',
      'bw_offload_exposed_time',
      'total_time',
      'act_offload_bw_req',
      'weight_offload_bw_req',
      'optim_offload_bw_req',
      'offload_mem_bw_req',
      'proc_mem_tier1_cap_req',
      'proc_mem_tier2_cap_req',
      'useful_flops',
      'compute_efficiency',
      'system_efficiency',
      'total_efficiency',
      'sample_rate')

  def get_stats_values(self):
    assert self._executed
    return (
      self._block_fw_flops,
      self._block_fw_flops_time,
      self._block_fw_mem_accessed,
      self._block_fw_mem_time,
      self._block_fw_time,
      self._baseblock_fw_tp_time,
      self._edgeblock_fw_tp_time,
      self._baseblock_fw_tp_time_exposed,
      self._edgeblock_fw_tp_time_exposed,
      self._block_re_flops,
      self._block_re_flops_time,
      self._block_re_mem_accessed,
      self._block_re_mem_time,
      self._block_re_time,
      self._baseblock_recomm_time,
      self._edgeblock_recomm_time,
      self._baseblock_recomm_time_exposed,
      self._edgeblock_recomm_time_exposed,
      self._block_agrad_flops,
      self._block_agrad_flops_time,
      self._block_agrad_mem_accessed,
      self._block_agrad_mem_time,
      self._block_agrad_time,
      self._baseblock_agrad_tp_time,
      self._edgeblock_agrad_tp_time,
      self._baseblock_agrad_tp_time_exposed,
      self._edgeblock_agrad_tp_time_exposed,
      self._block_wgrad_flops,
      self._block_wgrad_flops_time,
      self._block_wgrad_mem_accessed,
      self._block_wgrad_mem_time,
      self._block_wgrad_time,
      self._block_optim_flops,
      self._block_optim_flops_time,
      self._block_optim_mem_accessed,
      self._block_optim_mem_time,
      self._block_optim_time,

      self._baseblock_fw_tp_size,
      self._edgeblock_fw_tp_size,
      self._baseblock_agrad_tp_size,
      self._edgeblock_agrad_tp_size,
      self._baseblock_recomm_size,
      self._edgeblock_recomm_size,
      self._block_fw_pp_size,
      self._block_bw_pp_size,
      self._block_dp_size,
      self._tp_bw_overlap_req,
      self._dp_bw_overlap_req_chunk,
      self._dp_bw_overlap_req_tail,

      self._block_weight_space,
      self._block_act_working_space,
      self._block_act_storage_space,
      self._block_act_checkpoint_size,
      self._block_weight_grad_space,
      self._block_weight_grad_space_no_sharding,
      self._block_act_grad_space,
      self._block_optimizer_space,

      self.get_weight_space_min(),
      self.get_act_space_min(),
      self.get_act_checkpoint_size_min(),
      self.get_act_grad_space_min(),
      self.get_weight_grad_space_min(),
      self.get_optimizer_space_min(),

      self.get_weight_space(),
      self.get_act_space(),
      self.get_act_checkpoint_size(),
      self.get_act_grad_space(),
      self.get_weight_grad_space(),
      self.get_optimizer_space(),

      self.get_fw_time(),
      self.get_bw_time(),
      self.get_optim_step_time(),
      self.get_recompute_time(),
      self.get_recomm_link_time(),
      self.get_recomm_exposed_time(),
      self.get_bubble_time(),
      self.get_tp_comm_link_time(),
      self.get_pp_comm_link_time(),
      self.get_dp_comm_link_time(),
      self.get_tp_comm_exposed_time(),
      self.get_pp_comm_exposed_time(),
      self.get_dp_comm_exposed_time(),
      self.get_fw_offload_overhead(),
      self.get_bw_offload_overhead(),
      self.get_total_time(),
      self.get_act_offload_bw_req(),
      self.get_weight_offload_bw_req(),
      self.get_optim_offload_bw_req(),
      self.get_offload_mem_bw_req(),
      self.get_mem_tier1_cap_req(),
      self.get_mem_tier2_cap_req(),
      self.get_useful_flops(),
      self.get_compute_efficiency(),
      self.get_system_efficiency(),
      self.get_total_efficiency(),
      self.get_sample_rate())

  def get_stats_json(self, include_layers):
    assert self._executed
    keys = Llm.get_stats_fields()
    values = self.get_stats_values()
    assert len(keys) == len(values), f'{len(keys)} {len(values)}'
    j = dict(zip(keys, values))
    if include_layers:
      j['layers'] = []
      for layer in self._llm_block:
        j['layers'].append(layer.get_stats_json())
    return j

  def _build_attn_block(self):
    recompute_flag = self.exe.activation_recompute == "full"
    recompute_attn_flag = self.exe.activation_recompute in \
      ["full", "attn_only"]
    recompute_ag_flag = recompute_attn_flag or self.exe.seq_par_ag_redo

    assert self.app.hidden % self.exe.tensor_par == 0, (
      f"We should split hidden={self.app.hidden} between"
      f" {self.exe.tensor_par} TP partitions evenly")
    assert self.app.feedforward % self.exe.tensor_par == 0, (
      f"We should split feedforward={self.app.feedforward} between"
      f" {self.exe.tensor_par} TP partitions evenly")
    assert self.app.attn_heads % self.exe.tensor_par == 0, (
      f"We should split {self.app.attn_heads} attn_heads between"
      f" {self.exe.tensor_par} TP partitions evenly")

    self._llm_block.append(Fork(
      "AttnBlock_Fork",
      self.sys,
      pick(self.exe._sequence_par, self._seq_par_activation_size,
           self._activation_size),
      2,
      needs_recompute=recompute_flag,
      # We account for this activation when considering Residual and LayerNorm
      activation_stored=True))
    self._llm_block.append(LayerNorm(
      "AttnBlock_LayerNorm",
      self.sys,
      pick(self.exe._sequence_par, self._seq_par_activation_size,
           self._activation_size),
      self.app.hidden,
      needs_recompute=recompute_flag,
      # Activation is stored in Fork instead
      activation_stored=False,
      activation_reused=True))
    if self.exe.tensor_par_overlap == 'none':
      self._llm_block.append(TPComm(
        "AttnBlock_F",
        self.sys,
        self._activation_size,
        self.exe.tensor_par_net,
        self.exe.tensor_par,
        # We only compute flops/mem when analyzing this layer; comm is
        # analyzed later. This is a conservative estimate that does not
        # consider p2p_rs_ag because we don't differentiate between edge and
        # middle blocks here
        tensor_par_comm_type=self.exe.tensor_par_comm_type,
        conjugate=False,
        in_network_reduction=self.exe.in_network_reduction,
        needs_recomm=recompute_ag_flag))
      self._llm_block.append(Fork(
        "AttnBlock_Multihead_Fork",
        self.sys,
        self._activation_size,
        3,
        needs_recompute=recompute_ag_flag,
        # With seq_par, we use activations from Comm layers to reflect that
        # they're split, otherwise we keep full size activations
        activation_stored=(not recompute_ag_flag)))
      self._llm_block.append(Linear(
        "AttnBlock_Query",
        self.sys,
        self._batch_seq,
        self.app.hidden,
        self.app.attn_heads * self.app.attn_size // self.exe.tensor_par,
        needs_recompute=recompute_flag,
        # Activation is stored in Fork instead
        activation_stored=False,
        activation_reused=True))
      if self.exe.attention_type == 'multihead':
        self._llm_block.append(Linear(
          "AttnBlock_Key",
          self.sys,
          self._batch_seq,
          self.app.hidden,
          self.app.attn_heads * self.app.attn_size // self.exe.tensor_par,
          needs_recompute=recompute_flag,
          # Activation is stored in Fork instead
          activation_stored=False,
          activation_reused=True))
        self._llm_block.append(Linear(
          "AttnBlock_Value",
          self.sys,
          self._batch_seq,
          self.app.hidden,
          self.app.attn_heads * self.app.attn_size // self.exe.tensor_par,
          needs_recompute=recompute_flag,
          # Activation is stored in Fork instead
          activation_stored=False,
          activation_reused=True))
      elif self.exe.attention_type == 'multiquery':
        # Multi-query attention shares a single K and V head across all query
        # heads, resulting in smaller Wk and Wv, fewer matmul flops, and
        # faster inference
        self._llm_block.append(Linear(
          "AttnBlock_Key",
          self.sys,
          self._batch_seq,
          self.app.hidden,
          self.app.attn_size,
          needs_recompute=recompute_flag,
          # Activation is stored in Fork instead
          activation_stored=False,
          activation_reused=True))
        self._llm_block.append(Linear(
          "AttnBlock_Value",
          self.sys,
          self._batch_seq,
          self.app.hidden,
          self.app.attn_size,
          needs_recompute=recompute_flag,
          # Activation is stored in Fork instead
          activation_stored=False,
          activation_reused=True))
      else:
        raise self.Error('Wrong attention type', self.exe.attention_type)
    else:
      if self.exe.attention_type == 'multihead':
        self._llm_block.append(LinearOverlapped(
          "AttnBlock_QKV_AG",
          self.sys,
          self._batch_seq,
          self.app.hidden,
          self.app.attn_heads * self.app.attn_size * 3,  # Q, K, V
          self.exe.tensor_par_comm_type,
          self.exe.tensor_par,
          self.exe.tensor_par_net,
          self.exe.tensor_par,
          conjugate=False,
          tp_overlap=self.exe.tensor_par_overlap,
          needs_recompute=recompute_flag,
          needs_recomm=recompute_ag_flag))
      elif self.exe.attention_type == 'multiquery':
        self._llm_block.append(LinearOverlapped(
          "AttnBlock_Query_AG",
          self.sys,
          self._batch_seq,
          self.app.hidden,
          self.app.attn_heads * self.app.attn_size,
          self.exe.tensor_par_comm_type,
          self.exe.tensor_par,
          self.exe.tensor_par_net,
          self.exe.tensor_par,
          conjugate=False,
          tp_overlap=self.exe.tensor_par_overlap,
          needs_recompute=recompute_flag,
          needs_recomm=recompute_ag_flag))
        self._llm_block.append(Fork(
          "AttnBlock_KV_Fork",
          self.sys,
          self._activation_size,
          2,
          needs_recompute=recompute_ag_flag,
          # With seq_par, we use activations from Comm layers to reflect that
          # they're split, otherwise we keep full size activations
          activation_stored=(not recompute_ag_flag)))
        self._llm_block.append(Linear(
          "AttnBlock_Key",
          self.sys,
          self._batch_seq,
          self.app.hidden,
          self.app.attn_size,
          needs_recompute=recompute_flag,
          # Activation is stored in Fork instead
          activation_stored=False,
          activation_reused=True))
        self._llm_block.append(Linear(
          "AttnBlock_Value",
          self.sys,
          self._batch_seq,
          self.app.hidden,
          self.app.attn_size,
          needs_recompute=recompute_flag,
          # Activation is stored in Fork instead
          activation_stored=False,
          activation_reused=True))
      else:
        raise self.Error('Wrong attention type', self.exe.attention_type)
    self._llm_block.append(BatchMatMul(
      "AttnBlock_Multihead_Key_Query",
      self.sys,
      self.exe.microbatch_size * self.app.attn_heads // self.exe.tensor_par,
      self.app.seq_size,
      self.app.attn_size,
      self.app.seq_size,
      needs_recompute=recompute_attn_flag,
      output_stored=(not recompute_attn_flag)))
    self._llm_block.append(SoftMax(
      "AttnBlock_Multihead_SoftMax",
      self.sys,
      self.app.attn_heads // self.exe.tensor_par * \
        self.app.seq_size**2 * self.exe.microbatch_size,
      needs_recompute=recompute_attn_flag,
      output_stored=(not recompute_attn_flag)))
    self._llm_block.append(DropOut(
      "AttnBlock_Multihead_DropOut",
      self.sys,
      self.app.attn_heads // self.exe.tensor_par * \
        self.app.seq_size**2 * self.exe.microbatch_size,
      needs_recompute=recompute_attn_flag,
      activation_stored=(not recompute_attn_flag)))
    self._llm_block.append(BatchMatMul(
      "AttnBlock_Multihead_Attn",
      self.sys,
      self.exe.microbatch_size * self.app.attn_heads // self.exe.tensor_par,
      self.app.seq_size,
      self.app.seq_size,
      self.app.attn_size,
      needs_recompute=recompute_flag))
    if self.exe.tensor_par_overlap == 'none':
      self._llm_block.append(Linear(
        "AttnBlock_MLP",
        self.sys,
        self._batch_seq,
        self.app.attn_heads * self.app.attn_size // self.exe.tensor_par,
        self.app.hidden,
        needs_recompute=recompute_flag))
      self._llm_block.append(TPComm(
        "AttnBlock_G",
        self.sys,
        self._activation_size,
        self.exe.tensor_par_net,
        self.exe.tensor_par,
        # We only compute flops/mem when analyzing this layer; comm is
        # analyzed later. This is a conservative estimate that does not
        # consider p2p_rs_ag because we don't differentiate between edge and
        # middle blocks here
        tensor_par_comm_type=self.exe.tensor_par_comm_type,
        conjugate=True,
        in_network_reduction=self.exe.in_network_reduction,
        needs_recomm=recompute_flag,
        # We don't store input to RS/AR
        activation_stored=False))
    else:
      self._llm_block.append(LinearOverlapped(
        "AttnBlock_MLP_RS",
        self.sys,
        self._batch_seq,
        self.app.attn_heads * self.app.attn_size,
        self.app.hidden,
        self.exe.tensor_par_comm_type,
        self.exe.tensor_par,
        self.exe.tensor_par_net,
        self.exe.tensor_par,
        conjugate=True,
        tp_overlap=self.exe.tensor_par_overlap,
        needs_recompute=recompute_flag,
        needs_recomm=recompute_flag))
    self._llm_block.append(DropOut(
      "AttnBlock_DropOut",
      self.sys,
      pick(self.exe._sequence_par, self._seq_par_activation_size,
           self._activation_size),
      needs_recompute=recompute_flag))
    self._llm_block.append(ElementWise(
      "AttnBlock_Residual",
      self.sys,
      pick(self.exe._sequence_par, self._seq_par_activation_size,
           self._activation_size),
      pick(self.exe._sequence_par, self._seq_par_activation_size,
           self._activation_size),
      needs_recompute=recompute_flag,
      # Activation is stored in Fork instead
      activation_stored=False,
      activation_reused=True))
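
  # Illustrative sketch (hypothetical helpers, not called anywhere in
  # Calculon): the per-TP-partition shapes that _build_attn_block passes to
  # its non-overlapped Query/Key/Value Linear layers and to the
  # SoftMax/DropOut layers. Assumes plain ints that already satisfy the
  # divisibility asserts above. For example, with 96 heads, attn_size=128
  # and TP=8, multihead gives 1536-wide Q/K/V per partition, while
  # multiquery keeps K/V at 128.
  @staticmethod
  def _example_qkv_widths(attn_heads, attn_size, tensor_par, attention_type):
    """Returns (q_width, k_width, v_width) for one TP partition."""
    assert attn_heads % tensor_par == 0
    q_width = attn_heads * attn_size // tensor_par
    if attention_type == 'multihead':
      # Every head has its own K/V projection, split across TP ranks
      kv_width = q_width
    elif attention_type == 'multiquery':
      # One K/V head is shared by all query heads, so Wk and Wv stay
      # hidden x attn_size, as in the multiquery branch above
      kv_width = attn_size
    else:
      raise ValueError(f'Wrong attention type: {attention_type}')
    return q_width, kv_width, kv_width

  @staticmethod
  def _example_attn_score_elements(microbatch_size, seq_size, attn_heads,
                                   tensor_par):
    """Elements in the (heads / TP) x seq x seq score tensor per microbatch."""
    return attn_heads // tensor_par * seq_size**2 * microbatch_size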

  def _build_mlp_block(self):
    recompute_flag = self.exe.activation_recompute == "full"
    recompute_ag_flag = recompute_flag or self.exe.seq_par_ag_redo

    self._llm_block.append(Fork(
      "MlpBlock_Fork",
      self.sys,
      pick(self.exe._sequence_par, self._seq_par_activation_size,
           self._activation_size),
      2,
      needs_recompute=recompute_flag,
      # We account for this activation when considering Residual and LayerNorm
      activation_stored=True))
    self._llm_block.append(LayerNorm(
      "MlpBlock_LayerNorm",
      self.sys,
      pick(self.exe._sequence_par, self._seq_par_activation_size,
           self._activation_size),
      self.app.hidden,
      needs_recompute=recompute_flag,
      # Activation is stored in Fork instead
      activation_stored=False,
      activation_reused=True))
    if self.exe.tensor_par_overlap == 'none':
      self._llm_block.append(TPComm(
        "MlpBlock_F",
        self.sys,
        # We only compute flops/mem when analyzing this layer; comm is
        # analyzed later. We keep an extra memory buffer for comm and count
        # full-tensor memory access to be consistent with how much data the
        # comm moves/touches. This is a conservative estimate that does not
        # consider p2p_rs_ag because we don't differentiate between edge and
        # middle blocks here
        self._activation_size,
        self.exe.tensor_par_net,
        self.exe.tensor_par,
        tensor_par_comm_type=self.exe.tensor_par_comm_type,
        conjugate=False,
        in_network_reduction=self.exe.in_network_reduction,
        needs_recomm=recompute_ag_flag))
      self._llm_block.append(Linear(
        "MlpBlock_Mlp1",
        self.sys,
        self._batch_seq,
        self.app.hidden,
        self.app.feedforward // self.exe.tensor_par,
        needs_recompute=recompute_flag,
        # With seq_par, we use activations from Comm layers to reflect that
        # they're split, otherwise we keep full size activations
        activation_stored=(not recompute_ag_flag)))
    else:
      self._llm_block.append(LinearOverlapped(
        "MlpBlock_Mlp1_AG",
        self.sys,
        self._batch_seq,
        self.app.hidden,
        self.app.feedforward,
        self.exe.tensor_par_comm_type,
        self.exe.tensor_par,
        self.exe.tensor_par_net,
        self.exe.tensor_par,
        conjugate=False,
        tp_overlap=self.exe.tensor_par_overlap,
        needs_recompute=recompute_flag,
        needs_recomm=recompute_ag_flag))
    self._llm_block.append(GeLU(
      "MlpBlock_GeLU",
      self.sys,
      self.app.feedforward * self._batch_seq // self.exe.tensor_par,
      needs_recompute=recompute_flag,
      fused=self.exe.fused_activation))
    if self.exe.tensor_par_overlap == 'none':
      self._llm_block.append(Linear(
        "MlpBlock_Mlp2",
        self.sys,
        self._batch_seq,
        self.app.feedforward // self.exe.tensor_par,
        self.app.hidden,
        needs_recompute=recompute_flag))
      self._llm_block.append(TPComm(
        "MlpBlock_G",
        self.sys,
        self._activation_size,
        self.exe.tensor_par_net,
        self.exe.tensor_par,
        # We only compute flops/mem when analyzing this layer; comm is
        # analyzed later. This is a conservative estimate that does not
        # consider p2p_rs_ag because we don't differentiate between edge and
        # middle blocks here
        tensor_par_comm_type=self.exe.tensor_par_comm_type,
        conjugate=True,
        in_network_reduction=self.exe.in_network_reduction,
        needs_recomm=recompute_flag,
        # We don't store input to RS/AR
        activation_stored=False))
    else:
      self._llm_block.append(LinearOverlapped(
        "MlpBlock_Mlp2_RS",
        self.sys,
        self._batch_seq,
        self.app.feedforward,
        self.app.hidden,
        self.exe.tensor_par_comm_type,
        self.exe.tensor_par,
        self.exe.tensor_par_net,
        self.exe.tensor_par,
        conjugate=True,
        tp_overlap=self.exe.tensor_par_overlap,
        needs_recompute=recompute_flag,
        needs_recomm=recompute_flag))
    self._llm_block.append(DropOut(
      "MlpBlock_DropOut",
      self.sys,
      pick(self.exe._sequence_par, self._seq_par_activation_size,
           self._activation_size),
      needs_recompute=recompute_flag))
    self._llm_block.append(ElementWise(
      "MlpBlock_Residual",
      self.sys,
      pick(self.exe._sequence_par, self._seq_par_activation_size,
           self._activation_size),
      pick(self.exe._sequence_par, self._seq_par_activation_size,
           self._activation_size),
      needs_recompute=recompute_flag,
      # Activation is stored in Fork instead
      activation_stored=False,
      activation_reused=True))
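
  # Illustrative sketch (hypothetical helper, not part of Calculon): forward
  # GEMM flops of MlpBlock_Mlp1 + MlpBlock_Mlp2 on one TP partition. It
  # assumes the common 2*m*k*n flop count per matmul, which may differ from
  # how the Linear layer class actually counts flops. The total reduces to
  # 4 * batch_seq * hidden * feedforward / tensor_par.
  @staticmethod
  def _example_mlp_fw_gemm_flops(microbatch_size, seq_size, hidden,
                                 feedforward, tensor_par):
    assert feedforward % tensor_par == 0
    batch_seq = microbatch_size * seq_size
    ff_per_tp = feedforward // tensor_par
    mlp1 = 2 * batch_seq * hidden * ff_per_tp  # [b*s, h] x [h, ff/tp]
    mlp2 = 2 * batch_seq * ff_per_tp * hidden  # [b*s, ff/tp] x [ff/tp, h]
    return mlp1 + mlp2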

  def compile(self, sys, exe):
    assert not self._compiled
    assert isinstance(exe, self.Execution)
    self.exe = exe
    assert isinstance(sys, System)
    self.sys = sys
    self._check_network_assignments()

    self.sys.set_datatype(self.exe.datatype)

    # If the number of blocks is not divisible by PP, we allocate the
    # remainder of the blocks on the first num_blocks % PP procs and block
    # "bubbles" on the last PP - (num_blocks % PP) procs. To reflect that,
    # we round up blocks_per_proc. We report time for proc 0, whose bubble
    # time is therefore `PP - (num_blocks % PP)` blocks shorter.
    self._blocks_per_proc = self.app.num_blocks // self.exe.pipeline_par
    if self.app.num_blocks % self.exe.pipeline_par != 0:
      self._blocks_per_proc += 1
      self._bubble_reduction_blocks = self.exe.pipeline_par - (
        self.app.num_blocks % self.exe.pipeline_par)
    else:
      self._bubble_reduction_blocks = 0
    if self.exe.pipeline_interleaving > self._blocks_per_proc:
      raise self.Error('Pipeline interleaving must be less than or equal to '
                       'the number of blocks per processor')
    if self._blocks_per_proc % self.exe.pipeline_interleaving != 0:
      raise self.Error('Pipeline interleaving must evenly divide the '
                       'number of blocks per processor')
    self._bytes_per_element = System.TypeSizes[self.exe.datatype]

    # Checks that enough blocks per processor exist if offloading is being
    # performed
    if (self.exe.weight_offload or self.exe.activations_offload or
        self.exe.optimizer_offload) and (self._blocks_per_proc <= 2):
      raise self.Error('Offloading requires each processor to handle at least'
                       ' 3 blocks')

    # A chunk is the set of consecutive blocks a processor runs for one
    # microbatch before passing it to the next processor in the pipeline.
    # Each chunk of N blocks is modeled as a base block repeated N-1 times
    # followed by 1 edge block. Recommunication time is the same in both
    # base and edge blocks.
    self._blocks_per_chunk = \
      self._blocks_per_proc // self.exe.pipeline_interleaving
    assert self._blocks_per_proc % self._blocks_per_chunk == 0, \
      f"PP interleaving should evenly divide {self._blocks_per_proc} blocks"
    self._chunks_per_proc = self._blocks_per_proc // self._blocks_per_chunk
    assert self._chunks_per_proc == self.exe.pipeline_interleaving, \
      "Number of chunks should be equal to pipeline_interleaving"
    self._baseblocks_per_chunk = self._blocks_per_chunk - 1
    self._edgeblocks_per_chunk = 1

    # Build model during the compilation step
    self._batch_seq = self.exe.microbatch_size * self.app.seq_size
    self._activation_size = self._batch_seq * self.app.hidden
    self._batch_seq_par = self._batch_seq // self.exe.tensor_par
    if self.exe._sequence_par or self.exe._pipeline_par_rs_ag:
      assert self._batch_seq % self.exe.tensor_par == 0, (
        f"We should split batch_seq={self._batch_seq} between"
        f" {self.exe.tensor_par} TP partitions evenly")
    self._seq_par_activation_size = self._batch_seq_par * self.app.hidden
    self._build_attn_block()
    self._build_mlp_block()
    for layer in self._llm_block:
      layer.set_bytes_per_element(self._bytes_per_element)
      if self.exe.optimizer_sharding:
        layer.shard_optimizer(self.exe.data_par)
    self._compiled = True
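
  # Illustrative sketch (hypothetical helpers, not called by Calculon): the
  # pipeline partitioning and activation-size arithmetic from compile(),
  # written as pure functions on plain ints. bytes_per_element stands in for
  # System.TypeSizes[datatype] (e.g. 2 bytes for a 16-bit datatype). For
  # example, 10 blocks on PP=4 gives 3 blocks per proc and proc 0's bubble
  # is 2 blocks shorter.
  @staticmethod
  def _example_pipeline_partition(num_blocks, pipeline_par,
                                  pipeline_interleaving):
    blocks_per_proc = -(-num_blocks // pipeline_par)  # ceiling division
    remainder = num_blocks % pipeline_par
    bubble_reduction_blocks = pipeline_par - remainder if remainder else 0
    # Also rejects interleaving > blocks_per_proc, since then the modulo
    # equals blocks_per_proc, which is nonzero
    if blocks_per_proc % pipeline_interleaving != 0:
      raise ValueError('pipeline_interleaving must evenly divide '
                       f'{blocks_per_proc} blocks per processor')
    blocks_per_chunk = blocks_per_proc // pipeline_interleaving
    return {
      'blocks_per_proc': blocks_per_proc,
      'bubble_reduction_blocks': bubble_reduction_blocks,
      'blocks_per_chunk': blocks_per_chunk,
      # N-1 base blocks followed by 1 edge block per chunk
      'baseblocks_per_chunk': blocks_per_chunk - 1,
      'edgeblocks_per_chunk': 1,
    }

  @staticmethod
  def _example_activation_bytes(microbatch_size, seq_size, hidden,
                                tensor_par, bytes_per_element, sequence_par):
    batch_seq = microbatch_size * seq_size
    if sequence_par:
      # Sequence parallelism splits batch_seq across the TP ranks
      assert batch_seq % tensor_par == 0
      return batch_seq // tensor_par * hidden * bytes_per_element
    return batch_seq * hidden * bytes_per_element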

  def _check_network_assignments(self):
    used = [False] * self.sys.num_networks
    size = [1] * self.sys.num_networks

    assert self.exe.tensor_par_net < self.sys.num_networks
    assert self.exe.pipeline_par_net < self.sys.num_networks
    assert self.exe.data_par_net < self.sys.num_networks

    if self.exe.tensor_par > 1:
      used[self.exe.tensor_par_net] = True
      size[self.exe.tensor_par_net] *= self.exe.tensor_par
    self._tp_net = self.sys.get_network(self.exe.tensor_par_net)

    if self.exe.pipeline_par > 1:
      used[self.exe.pipeline_par_net] = True
      size[self.exe.pipeline_par_net] *= self.exe.pipeline_par
    self._pp_net = self.sys.get_network(self.exe.pipeline_par_net)

    if self.exe.data_par > 1:
      used[self.exe.data_par_net] = True
      size[self.exe.data_par_net] *= self.exe.data_par
    self._dp_net = self.sys.get_network(self.exe.data_par_net)

    for tier_used, tier_size, tier in zip(
        used, size, range(self.sys.num_networks)):
      if tier_used:
        if tier_size > self.sys.get_network(tier).size:
          raise self.Error(f'Network tier{tier} isn\'t big enough')
        if (self.sys.get_network(tier).must_be_filled and
            self.sys.get_network(tier).size % tier_size != 0):
          raise self.Error(f'Network tier{tier} isn\'t fully used')
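
  # Illustrative sketch (hypothetical helper, not part of Calculon): how
  # _check_network_assignments sizes each network tier. Parallelism degrees
  # mapped to the same tier multiply, and a degree of 1 leaves the tier
  # unused. For example, TP=8 on tier 0 with PP=4 and DP=16 both on tier 1
  # requires tier 1 to span at least 64 processors.
  @staticmethod
  def _example_network_tier_sizes(num_networks, assignments):
    """assignments: iterable of (degree, network_index) for TP/PP/DP."""
    used = [False] * num_networks
    size = [1] * num_networks
    for degree, net in assignments:
      assert 0 <= net < num_networks
      if degree > 1:
        used[net] = True
        size[net] *= degree
    return used, size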

  def _compute_block_stats(self):
    """
    This function computes the statistics for one microbatch on a single block.
    This only computes flops, flop time, and communication sizes. Since
    tensor and pipeline parallelism cause different communication operations to
    occur at the full batch level, the communication times are computed later.
    """
    if self.exe.training and self.exe.activation_recompute == "full":
      self._block_act_checkpoint_size = \
        self._activation_size * self._bytes_per_element
    else:
      self._block_act_checkpoint_size = 0

    # Initializes values to zero for accumulation in layer loop
    self._block_fw_flops = 0
    self._block_fw_flops_time = 0
    self._block_fw_mem_accessed = 0
    self._block_fw_mem_time = 0
    self._block_fw_time = 0
    self._baseblock_fw_tp_size = 0
    self._edgeblock_fw_tp_size = 0
    self._baseblock_fw_tp_time = 0
    self._edgeblock_fw_tp_time = 0
    self._baseblock_fw_tp_time_exposed = 0
    self._edgeblock_fw_tp_time_exposed = 0
    self._block_weight_space = 0
    self._block_act_working_space = 0
    self._block_act_storage_space = 0
    # The following are only used when training, but we initialize them anyway
    self._block_re_flops = 0
    self._block_re_flops_time = 0
    self._block_re_mem_accessed = 0
    self._block_re_mem_time = 0
    self._block_re_time = 0
    self._baseblock_recomm_size = 0
    self._edgeblock_recomm_size = 0
    self._baseblock_recomm_time = 0
    self._edgeblock_recomm_time = 0
    self._baseblock_recomm_time_exposed = 0
    self._edgeblock_recomm_time_exposed = 0
    self._block_agrad_flops = 0
    self._block_agrad_flops_time = 0
    self._block_agrad_mem_accessed = 0
    self._block_agrad_mem_time = 0
    self._block_agrad_time = 0
    self._baseblock_agrad_tp_size = 0
    self._edgeblock_agrad_tp_size = 0
    self._baseblock_agrad_tp_time = 0
    self._edgeblock_agrad_tp_time = 0
    self._baseblock_agrad_tp_time_exposed = 0
    self._edgeblock_agrad_tp_time_exposed = 0
    self._block_wgrad_flops = 0
    self._block_wgrad_flops_time = 0
    self._block_wgrad_mem_accessed = 0
    self._block_wgrad_mem_time = 0
    self._block_wgrad_time = 0
    self._block_optim_flops = 0
    self._block_optim_flops_time = 0
    self._block_optim_mem_accessed = 0
    self._block_optim_mem_time = 0
    self._block_optim_time = 0
    self._block_weight_grad_space = 0
    self._block_weight_grad_space_no_sharding = 0
    self._block_act_grad_space = 0
    self._block_optimizer_space = 0
    self._tp_bw_overlap_req = 0

    prev_layer_recompute = False
    for layer in self._llm_block:
      # Add flops/bytes/times per layer
      self._block_fw_flops += layer.get_fw_flops()
      self._block_fw_flops_time += layer.compute_flops_time("fw")
      self._block_fw_mem_accessed += layer.get_fw_mem_accessed()
      self._block_fw_mem_time += layer.compute_mem_time("fw")
      self._block_fw_time += layer.compute_processing_time("fw")
      self._baseblock_fw_tp_size += layer.get_comm_bytes("fw",
        baseblock=True)
      self._edgeblock_fw_tp_size += layer.get_comm_bytes("fw",
        baseblock=False)
      self._baseblock_fw_tp_time += layer.compute_net_time("fw",
        baseblock=True)
      self._edgeblock_fw_tp_time += layer.compute_net_time("fw",
        baseblock=False)
      self._baseblock_fw_tp_time_exposed += layer.get_exposed_net_time("fw",
        baseblock=True)
      self._edgeblock_fw_tp_time_exposed += layer.get_exposed_net_time("fw",
        baseblock=False)
      self._tp_bw_overlap_req = max(self._tp_bw_overlap_req,
        layer.get_required_bandwidth("fw", baseblock=True))
      self._tp_bw_overlap_req = max(self._tp_bw_overlap_req,
        layer.get_required_bandwidth("fw", baseblock=False))
      if self.exe.training:
        if layer.get_recompute_flag():
          self._block_re_flops += layer.get_fw_flops()
          self._block_re_flops_time += layer.compute_flops_time("fw")
          self._block_re_mem_accessed += layer.get_fw_mem_accessed()
          self._block_re_mem_time += layer.compute_mem_time("fw")
          self._block_re_time += layer.compute_processing_time("fw")
        if layer.get_recomm_flag():
          self._baseblock_recomm_size += layer.get_comm_bytes("wgrad",
            baseblock=True)
          self._edgeblock_recomm_size += layer.get_comm_bytes("wgrad",
            baseblock=False)
          self._baseblock_recomm_time += layer.compute_net_time("wgrad",
            baseblock=True)
          self._edgeblock_recomm_time += layer.compute_net_time("wgrad",
            baseblock=False)
          self._baseblock_recomm_time_exposed += layer.get_exposed_net_time(
            "wgrad", baseblock=True)
          self._edgeblock_recomm_time_exposed += layer.get_exposed_net_time(
            "wgrad", baseblock=False)
        self._block_agrad_flops += layer.get_agrad_flops()
        self._block_agrad_flops_time += layer.compute_flops_time("agrad")
        self._block_agrad_mem_accessed += layer.get_agrad_mem_accessed()
        self._block_agrad_mem_time += layer.compute_mem_time("agrad")
        self._block_agrad_time += layer.compute_processing_time("agrad")
        self._baseblock_agrad_tp_size += layer.get_comm_bytes("agrad",
          baseblock=True)
        self._edgeblock_agrad_tp_size += layer.get_comm_bytes("agrad",
          baseblock=False)
        self._baseblock_agrad_tp_time += layer.compute_net_time("agrad",
          baseblock=True)
        self._edgeblock_agrad_tp_time += layer.compute_net_time("agrad",
          baseblock=False)
        self._baseblock_agrad_tp_time_exposed += layer.get_exposed_net_time(
          "agrad", baseblock=True)
        self._edgeblock_agrad_tp_time_exposed += layer.get_exposed_net_time(
          "agrad", baseblock=False)
        self._tp_bw_overlap_req = max(self._tp_bw_overlap_req,
          layer.get_required_bandwidth("agrad", baseblock=True))
        self._tp_bw_overlap_req = max(self._tp_bw_overlap_req,
          layer.get_required_bandwidth("agrad", baseblock=False))
        self._block_wgrad_flops += layer.get_wgrad_flops()
        self._block_wgrad_flops_time += layer.compute_flops_time("wgrad")
        self._block_wgrad_mem_accessed += layer.get_wgrad_mem_accessed()
        self._block_wgrad_mem_time += layer.compute_mem_time("wgrad")
        self._block_wgrad_time += layer.compute_processing_time("wgrad")
        self._block_optim_flops += layer.get_optim_step_flops()
        self._block_optim_flops_time += layer.compute_flops_time("optim")
        self._block_optim_mem_accessed += layer.get_optim_step_mem_accessed()
        self._block_optim_mem_time += layer.compute_mem_time("optim")
        self._block_optim_time += layer.compute_processing_time("optim")

      # Accumulate space requirements per block
      self._block_weight_space += layer.get_weight()
      if not layer.reuses_activation():
        self._block_act_working_space += layer.get_activation()
      self._block_act_storage_space += layer.get_activation()
      if self.exe.training:
        if not layer.stores_output():
          self._block_act_storage_space -= layer.get_output()
        if not layer.stores_activation():
          self._block_act_storage_space -= layer.get_activation()
        self._block_weight_grad_space += layer.get_weight_grad()
        self._block_weight_grad_space_no_sharding += layer.get_weight_grad(
          sharded=False)
        self._block_act_grad_space += layer.get_activation_grad()
        self._block_optimizer_space += layer.get_optimizer()

      self.log.debug("%s %s %s", layer.name, 'Recompute flag:',
                     str(layer.get_recompute_flag()))
      self.log.debug("%s %s %s", layer.name, 'Recomm flag:',
                     str(layer.get_recomm_flag()))
      self.log.debug("%s %s %s", layer.name, 'Stores activation:',
                     str(layer.stores_activation()))
      self.log.debug("%s %s %s", layer.name, 'Reuses activation:',
                     str(layer.reuses_activation()))
      self.log.debug("%s %s %s", layer.name, 'Stores output:',
                     str(layer.stores_output()))
      self.log.debug("%s %s %s", layer.name, 'FW flops:',
                     human_format(layer.get_fw_flops(), 'flops'))
      self.log.debug("%s %s %s", layer.name, 'FW num inputs:',
                     human_format(layer.inputs_size, 'base2'))
      self.log.debug("%s %s %s", layer.name, 'FW num output:',
                     human_format(layer.output_size, 'base2'))
      self.log.debug("%s %s %s", layer.name, 'FW num weights:',
                     human_format(layer.weight_space, 'base2'))
      self.log.debug("%s %s %s", layer.name, 'FW mem:',
                     human_format(layer.get_fw_mem_accessed(), 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'FW baseblock comm tile size:',
                     human_format(layer.get_comm_tile("fw", baseblock=True),
                     'bytes'))
      self.log.debug("%s %s %s", layer.name, 'FW edgeblock comm tile size:',
                     human_format(layer.get_comm_tile("fw", baseblock=False),
                     'bytes'))
      self.log.debug("%s %s %s", layer.name, 'FW baseblock comm size:',
                     human_format(layer.get_comm_bytes("fw", baseblock=True),
                     'bytes'))
      self.log.debug("%s %s %s", layer.name, 'FW edgeblock comm size:',
                     human_format(layer.get_comm_bytes("fw", baseblock=False),
                     'bytes'))
      self.log.debug("%s %s %.3e", layer.name, 'FW net link time:',
                     layer.compute_net_time("fw"))
      self.log.debug("%s %s %.3e", layer.name, 'FW net exposed time:',
                     layer.get_exposed_net_time("fw"))
      self.log.debug("%s %s %.3e", layer.name, 'FW time:',
                     layer.compute_processing_time("fw"))
      self.log.debug("%s %s %s", layer.name, 'BW flops:',
                     human_format(
                      layer.get_agrad_flops() + layer.get_wgrad_flops(),
                      'flops'))
      self.log.debug("%s %s %s", layer.name, 'BW num Wgrads:',
                     human_format(layer.weight_grads, 'base2'))
      self.log.debug("%s %s %s", layer.name, 'BW num Agrads:',
                     human_format(layer.activation_grads, 'base2'))
      self.log.debug("%s %s %s", layer.name, 'BW num Igrads:',
                     human_format(layer.inputs_size, 'base2'))
      self.log.debug("%s %s %s", layer.name, 'BW mem:',
                     human_format(
                      layer.get_agrad_mem_accessed() +
                      layer.get_wgrad_mem_accessed(), 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'BW baseblock comm tile size:',
                     human_format(layer.get_comm_tile("agrad", baseblock=True),
                     'bytes'))
      self.log.debug("%s %s %s", layer.name, 'BW edgeblock comm tile size:',
                     human_format(layer.get_comm_tile("agrad", baseblock=False),
                     'bytes'))
      self.log.debug("%s %s %s", layer.name, 'BW baseblock comm size:',
                     human_format(layer.get_comm_bytes("agrad", baseblock=True),
                     'bytes'))
      self.log.debug("%s %s %s", layer.name, 'BW edgeblock comm size:',
                     human_format(layer.get_comm_bytes("agrad", baseblock=False),
                     'bytes'))
      self.log.debug("%s %s %.3e", layer.name, 'BW net link time:',
                     layer.compute_net_time("agrad"))
      self.log.debug("%s %s %.3e", layer.name, 'BW net exposed time:',
                     layer.get_exposed_net_time("agrad"))
      self.log.debug("%s %s %.3e", layer.name, 'BW time:',
                     layer.compute_processing_time("agrad") +
                     layer.compute_processing_time("wgrad"))
      self.log.debug("%s %s %s", layer.name, 'Recomm baseblock comm tile size:',
                     human_format(layer.get_comm_tile("wgrad", baseblock=True),
                     'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Recomm edgeblock comm tile size:',
                     human_format(layer.get_comm_tile("wgrad", baseblock=False),
                     'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Recomm baseblock comm size:',
                     human_format(layer.get_comm_bytes("wgrad", baseblock=True),
                     'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Recomm edgeblock comm size:',
                     human_format(layer.get_comm_bytes("wgrad", baseblock=False),
                     'bytes'))
      self.log.debug("%s %s %.3e", layer.name, 'Recomm net link time:',
                     layer.compute_net_time("wgrad"))
      self.log.debug("%s %s %.3e", layer.name, 'Recomm net exposed time:',
                     layer.get_exposed_net_time("wgrad"))
      self.log.debug("%s %s %s", layer.name, 'Optim flops:',
                     human_format(layer.get_optim_step_flops(), 'flops'))
      self.log.debug("%s %s %s", layer.name, 'BW Optimizer size:',
                     human_format(layer.get_optimizer(), 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Optim mem:',
                     human_format(layer.get_optim_step_mem_accessed(), 'bytes'))
      self.log.debug("%s %s %.3e", layer.name, 'Optim time:',
                     layer.compute_processing_time("optim"))
      self.log.debug("%s %s %s", layer.name, 'Recompute:',
                     str(layer.get_recompute_flag()))
      self.log.debug("%s %s %s", layer.name, 'Recompute mem saving:',
                     human_format(layer.stores_output() *
                                  layer.get_output(), 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Weight:',
                     human_format(layer.get_weight(), 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Act:',
                     human_format(layer.get_activation(), 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Weight grad:',
                     human_format(layer.get_weight_grad(), 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Act grad:',
                     human_format(layer.get_activation_grad(), 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Optim:',
                     human_format(layer.get_optimizer(), 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Incremental Weight:',
                     human_format(self._block_weight_space, 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Incremental Act Working space:',
                     human_format(self._block_act_working_space, 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Incremental Act Storage space:',
                     human_format(self._block_act_storage_space, 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Incremental Weight grad:',
                     human_format(self._block_weight_grad_space, 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Incremental Act grad:',
                     human_format(self._block_act_grad_space, 'bytes'))
      self.log.debug("%s %s %s", layer.name, 'Incremental Optim:',
                     human_format(self._block_optimizer_space, 'bytes'))
      prev_layer_recompute = layer.get_recompute_flag()
    if self.exe.activation_recompute == 'full':
      self._block_act_storage_space = 0

    # Sets the PP communication operation size
    if self.exe.pipeline_par > 1:
      if self.exe._pipeline_par_rs_ag:
        self._block_fw_pp_size = self._seq_par_activation_size * \
          self._bytes_per_element
      else:
        self._block_fw_pp_size = self._activation_size * \
          self._bytes_per_element
    else:
      self._block_fw_pp_size = 0

    # When training, the BW PP communication size equals the FW size
    if self.exe.training:
      self._block_bw_pp_size = self._block_fw_pp_size
    else:
      self._block_bw_pp_size = 0

    self.log.debug("%s %s", 'TP comm FW baseblock size:',
                   human_format(self._baseblock_fw_tp_size, 'bytes'))
    self.log.debug("%s %s", 'TP comm FW edgeblock size:',
                   human_format(self._edgeblock_fw_tp_size, 'bytes'))
    self.log.debug("%s %s", 'PP comm FW size:',
                   human_format(self._block_fw_pp_size, 'bytes'))
    self.log.debug("%s %s", 'TP comm BW baseblock size:',
                   human_format(self._baseblock_agrad_tp_size, 'bytes'))
    self.log.debug("%s %s", 'TP comm BW edgeblock size:',
                   human_format(self._edgeblock_agrad_tp_size, 'bytes'))
    self.log.debug("%s %s", 'PP comm BW size:',
                   human_format(self._block_bw_pp_size, 'bytes'))
    self.log.debug("%s %s", 'TP recomm baseblock size:',
                   human_format(self._baseblock_recomm_size, 'bytes'))
    self.log.debug("%s %s", 'TP recomm edgeblock size:',
                   human_format(self._edgeblock_recomm_size, 'bytes'))
    self.log.debug("%s %s", 'TP comm required bandwidth for tiled overlap:',
                   human_format(self._tp_bw_overlap_req, 'bandwidth'))

  def _compute_batch_stats(self):
    """
    Computes the statistics for a full batch by scaling the per-microbatch,
    per-block statistics produced by _compute_block_stats() (see above).
    """
    # Total stats for compute and memory
    mult = self._blocks_per_proc * self.exe._num_microbatches
    self._fw_flops = mult * self._block_fw_flops
    self._fw_flops_time = mult * self._block_fw_flops_time
    self._fw_mem_accessed = mult * self._block_fw_mem_accessed
    self._fw_mem_time = mult * self._block_fw_mem_time
    self._fw_time = mult * self._block_fw_time
    self._re_flops = mult * self._block_re_flops
    self._re_flops_time = mult * self._block_re_flops_time
    self._re_mem_accessed = mult * self._block_re_mem_accessed
    self._re_mem_time = mult * self._block_re_mem_time
    self._re_time = mult * self._block_re_time
    self._agrad_flops = mult * self._block_agrad_flops
    self._agrad_flops_time = mult * self._block_agrad_flops_time
    self._agrad_mem_accessed = mult * self._block_agrad_mem_accessed
    self._agrad_mem_time = mult * self._block_agrad_mem_time
    self._agrad_time = mult * self._block_agrad_time
    self._wgrad_flops = mult * self._block_wgrad_flops
    self._wgrad_flops_time = mult * self._block_wgrad_flops_time
    self._wgrad_mem_accessed = mult * self._block_wgrad_mem_accessed
    self._wgrad_mem_time = mult * self._block_wgrad_mem_time
    self._wgrad_time = mult * self._block_wgrad_time
    self._optim_flops = self._blocks_per_proc * self._block_optim_flops
    self._optim_flops_time = self._blocks_per_proc * self._block_optim_flops_time
    self._optim_mem_accessed = self._blocks_per_proc * self._block_optim_mem_accessed
    self._optim_mem_time = self._blocks_per_proc * self._block_optim_mem_time
    self._optim_time = self._blocks_per_proc * self._block_optim_time

    # These TP numbers are total times over all blocks in all chunks and all
    # microbatches
    tp_fw_comm_time = self.exe._num_microbatches * self._chunks_per_proc * (
      (self._baseblocks_per_chunk * self._baseblock_fw_tp_time) +
      (self._edgeblocks_per_chunk * self._edgeblock_fw_tp_time))
    tp_fw_comm_time_exposed = \
      self.exe._num_microbatches * self._chunks_per_proc * (
        (self._baseblocks_per_chunk * self._baseblock_fw_tp_time_exposed) +
        (self._edgeblocks_per_chunk * self._edgeblock_fw_tp_time_exposed))
    tp_bw_comm_time = self.exe._num_microbatches * self._chunks_per_proc * (
      self._baseblocks_per_chunk * self._baseblock_agrad_tp_time +
      self._edgeblocks_per_chunk * self._edgeblock_agrad_tp_time)
    tp_bw_comm_time_exposed = \
      self.exe._num_microbatches * self._chunks_per_proc * (
        self._baseblocks_per_chunk * self._baseblock_agrad_tp_time_exposed +
        self._edgeblocks_per_chunk * self._edgeblock_agrad_tp_time_exposed)
    tp_recomm_time = self.exe._num_microbatches * self._chunks_per_proc * (
      (self._baseblocks_per_chunk * self._baseblock_recomm_time) +
      (self._edgeblocks_per_chunk * self._edgeblock_recomm_time))
    tp_recomm_time_exposed = \
      self.exe._num_microbatches * self._chunks_per_proc * (
        (self._baseblocks_per_chunk * self._baseblock_recomm_time_exposed) +
        (self._edgeblocks_per_chunk * self._edgeblock_recomm_time_exposed))

    # Per chunk PP comm time
    chunk_fw_pp_time = self._pp_net.time('p2p', self._block_fw_pp_size, 2)
    chunk_bw_pp_time = self._pp_net.time('p2p', self._block_bw_pp_size, 2)

    # Determines how many pipeline p2p transfers occur per microbatch during
    # the forward and backward passes (one per chunk, i.e. chunks per proc)
    if self.exe.pipeline_par > 1:
      num_fw_pp_p2ps = self._chunks_per_proc
      if self.exe.training:
        num_bw_pp_p2ps = self._chunks_per_proc
      else:
        num_bw_pp_p2ps = 0
    else:
      num_fw_pp_p2ps = 0
      num_bw_pp_p2ps = 0

    # These PP numbers are for total times for all blocks and all microbatches
    pp_fw_comm_time = self.exe._num_microbatches * num_fw_pp_p2ps * \
      chunk_fw_pp_time
    pp_bw_comm_time = self.exe._num_microbatches * num_bw_pp_p2ps * \
      chunk_bw_pp_time
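    # Illustrative arithmetic (hypothetical values, not from any shipped
    # config): with _num_microbatches=8, pipeline_par>1 and _chunks_per_proc=2,
    # num_fw_pp_p2ps is 2, so a chunk_fw_pp_time of 1 ms gives
    # pp_fw_comm_time = 8 * 2 * 1 ms = 16 ms.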

    # Aggregates metrics
    self._tp_comm_time_link = tp_fw_comm_time + tp_bw_comm_time
    self._tp_comm_time_exposed = (tp_fw_comm_time_exposed +
      tp_bw_comm_time_exposed)
    self._recomm_time_link = tp_recomm_time
    self._recomm_time_exposed = tp_recomm_time_exposed
    self._pp_comm_time_link = pp_fw_comm_time + pp_bw_comm_time
    self._pp_comm_time_exposed = self._pp_comm_time_link

    self.log.debug("%s %s", 'TP comm baseblock FW time:',
      self._baseblock_fw_tp_time)
    self.log.debug("%s %s", 'TP comm edgeblock FW time:',
      self._edgeblock_fw_tp_time)
    self.log.debug("%s %s", 'TP comm FW time:', tp_fw_comm_time)
    self.log.debug("%s %s", 'TP comm baseblock FW exposed time:',
      self._baseblock_fw_tp_time_exposed)
    self.log.debug("%s %s", 'TP comm edgeblock FW exposed time:',
      self._edgeblock_fw_tp_time_exposed)
    self.log.debug("%s %s", 'TP comm FW exposed time:', tp_fw_comm_time_exposed)
    self.log.debug("%s %s", 'TP comm baseblock BW time:',
      self._baseblock_agrad_tp_time)
    self.log.debug("%s %s", 'TP comm edgeblock BW time:',
      self._edgeblock_agrad_tp_time)
    self.log.debug("%s %s", 'TP comm BW time:', tp_bw_comm_time)
    self.log.debug("%s %s", 'TP comm baseblock BW exposed time:',
      self._baseblock_agrad_tp_time_exposed)
    self.log.debug("%s %s", 'TP comm edgeblock BW exposed time:',
      self._edgeblock_agrad_tp_time_exposed)
    self.log.debug("%s %s", 'TP comm BW exposed time:',
      tp_bw_comm_time_exposed)
    self.log.debug("%s %s", 'PP comm chunk FW time:', chunk_fw_pp_time)
    self.log.debug("%s %s", 'PP comm chunk BW time:', chunk_bw_pp_time)
    self.log.debug("%s %s", 'PP comm FW time:', pp_fw_comm_time)
    self.log.debug("%s %s", 'PP comm BW time:', pp_bw_comm_time)

    # The bubble forms between the FW and BW passes of the i-th microbatch on
    # the first GPU. Without interleaving between blocks it is
    # L/gpu x microbatch_time x (p-1) x Tcycle, where Tcycle covers the FW and
    # BW passes plus their TP and PP communication. With full interleaving it
    # shrinks to microbatch_time x (p-1) x Tcycle.
    self._baseblock_fw_time_no_offload = (
      self._block_fw_time + self._baseblock_fw_tp_time_exposed)
    self._edgeblock_fw_time_no_offload = (
      self._block_fw_time + self._edgeblock_fw_tp_time_exposed +
      chunk_fw_pp_time)
    self._baseblock_fw_offload_overhead = max(
      0, self.get_fw_offload_time() + self._block_fw_mem_time -
      self._baseblock_fw_time_no_offload)
    self._edgeblock_fw_offload_overhead = max(
      0, self.get_fw_offload_time() + self._block_fw_mem_time -
      self._edgeblock_fw_time_no_offload)
    self._baseblock_fw_time = (
      self._baseblock_fw_time_no_offload + self._baseblock_fw_offload_overhead)
    self._edgeblock_fw_time = (
      self._edgeblock_fw_time_no_offload + self._edgeblock_fw_offload_overhead)
    # The block BW time excludes the optimizer step: the optimizer runs only
    # after the last microbatch, whereas offloading runs during the whole
    # backward pass. Since the optimizer step is an overall memory-bound
    # streaming task, it is reasonable not to overlap offloading with it.
    self._baseblock_bw_time_no_offload = (
      self._block_re_time + self._baseblock_recomm_time_exposed +
      self._block_agrad_time + self._block_wgrad_time +
      self._baseblock_agrad_tp_time_exposed)
    self._edgeblock_bw_time_no_offload = (
      self._block_re_time + self._edgeblock_recomm_time_exposed +
      self._block_agrad_time + self._block_wgrad_time +
      self._edgeblock_agrad_tp_time_exposed + chunk_bw_pp_time)
    self._baseblock_bw_offload_overhead = max(
      0, self.get_bw_offload_time() + self._block_agrad_mem_time +
      self._block_wgrad_mem_time -
      self._baseblock_bw_time_no_offload)
    self._edgeblock_bw_offload_overhead = max(
      0, self.get_bw_offload_time() + self._block_agrad_mem_time +
      self._block_wgrad_mem_time -
      self._edgeblock_bw_time_no_offload)
    self._baseblock_bw_time = (
      self._baseblock_bw_time_no_offload + self._baseblock_bw_offload_overhead)
    self._edgeblock_bw_time = (
      self._edgeblock_bw_time_no_offload + self._edgeblock_bw_offload_overhead)
    chunk_fw_time = (
      (self._baseblocks_per_chunk * self._baseblock_fw_time) +
      (self._edgeblocks_per_chunk * self._edgeblock_fw_time))
    chunk_bw_time = (
      (self._baseblocks_per_chunk * self._baseblock_bw_time) +
      (self._edgeblocks_per_chunk * self._edgeblock_bw_time))
    # Can't overlap DP comm with mem accesses, but can overlap with offload
    baseblock_dp_overlap_time = self._baseblock_bw_time - (
      self._block_agrad_mem_time + self._block_wgrad_mem_time +
      self._block_re_mem_time)
    edgeblock_dp_overlap_time = self._edgeblock_bw_time - (
      self._block_agrad_mem_time + self._block_wgrad_mem_time +
      self._block_re_mem_time)
    block_dp_compute_time = (
      self._block_agrad_flops_time + self._block_wgrad_flops_time +
      self._block_re_flops_time)
    if not self.exe.optimizer_sharding:
      # If optimizer is not sharded, we can overlap optimizer step with
      # communication, except for memory access time
      baseblock_dp_overlap_time += (
        self._block_optim_time - self._block_optim_mem_time)
      edgeblock_dp_overlap_time += (
        self._block_optim_time - self._block_optim_mem_time)
      block_dp_compute_time += self._block_optim_flops_time
    if self._dp_net == self._tp_net:
      # Can't overlap DP with TP if in the same network
      baseblock_dp_overlap_time -= (
        self._baseblock_recomm_time + self._baseblock_agrad_tp_time)
      edgeblock_dp_overlap_time -= (
        self._edgeblock_recomm_time + self._edgeblock_agrad_tp_time)
    chunk_dp_overlap_time = (
      self._baseblocks_per_chunk * baseblock_dp_overlap_time +
      self._edgeblocks_per_chunk * edgeblock_dp_overlap_time)
    chunk_dp_compute_time = self._blocks_per_chunk * block_dp_compute_time
    chunk_time = chunk_fw_time + chunk_bw_time
    # Block bubbles appear when the blocks do not divide evenly across pipeline
    # stages; the schedule bubble is then shortened by the edge blocks missing
    # on the later pipeline stages (missing block case)
    if self._baseblocks_per_chunk > 0:
      # We cut the last block of the chunk, which is a half-edge block (it has
      # PP comm at the end)
      bubble_reduction_time = self._bubble_reduction_blocks * (
        self._baseblock_fw_time + self._edgeblock_fw_time +
        self._baseblock_bw_time + self._edgeblock_bw_time) / 2
    else:
      # If the chunk has no base blocks, we cut an edge block
      bubble_reduction_time = self._bubble_reduction_blocks * (
        self._edgeblock_fw_time + self._edgeblock_bw_time)
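    # Illustrative arithmetic (hypothetical values): with
    # _bubble_reduction_blocks=1, base blocks present, baseblock/edgeblock FW
    # times of 1.0/1.2 ms and BW times of 2.0/2.4 ms, the first branch gives
    # bubble_reduction_time = 1 * (1.0 + 1.2 + 2.0 + 2.4) / 2 = 3.3 ms, i.e. the
    # average FW+BW time of one base block and one edge block.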
    # With PP interleaving we assume that every chunk processes at least PP
    # microbatches. If num_microbatches is not a multiple of PP, we get extra
    # bubbles (missing microbatches case) in the last microbatches of every
    # overlappable chunk (all chunks but the last). The bubble count per chunk
    # equals microbatch_shortage; the same number of microbatches will be
    # missing in the last chunk
    chunks_in_bubble = self.exe.pipeline_par - 1
    num_overlappable_chunks = self.exe.pipeline_interleaving - 1
    microbatch_shortage = self.exe.pipeline_par - (
      self.exe._num_microbatches % self.exe.pipeline_par)
    if self.exe._num_microbatches % self.exe.pipeline_par != 0:
      extra_interleaving_bubbles = num_overlappable_chunks * \
        microbatch_shortage
    else:
      extra_interleaving_bubbles = 0
    self._bubble_time = chunks_in_bubble * chunk_time + (
      extra_interleaving_bubbles * chunk_time - bubble_reduction_time)
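    # Illustrative arithmetic (hypothetical values, not from any shipped
    # config): pipeline_par=4, pipeline_interleaving=2, _num_microbatches=6 and
    # _bubble_reduction_blocks=0 give chunks_in_bubble=3,
    # num_overlappable_chunks=1, microbatch_shortage=4-(6%4)=2 and
    # extra_interleaving_bubbles=1*2=2, so
    # _bubble_time = 3*chunk_time + 2*chunk_time - 0 = 5*chunk_time.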

    self.log.debug("%s %s", 'Block FW time:', self._block_fw_time)
    self.log.debug("%s %s", 'Baseblock FW time:', self._baseblock_fw_time)
    self.log.debug("%s %s", 'With FW offload overhead time:',
      self._baseblock_fw_offload_overhead)
    self.log.debug("%s %s", 'Edgeblock FW time:', self._edgeblock_fw_time)
    self.log.debug("%s %s", 'With FW offload overhead time:',
      self._edgeblock_fw_offload_overhead)
    self.log.debug("%s %s", 'Baseblock REcomm exposed time:',
      self._baseblock_recomm_time_exposed)
    self.log.debug("%s %s", 'Edgeblock REcomm exposed time:',
      self._edgeblock_recomm_time_exposed)
    self.log.debug("%s %s", 'Block RE time:', self._block_re_time)
    self.log.debug("%s %s", 'Block BW Agrad time:', self._block_agrad_time)
    self.log.debug("%s %s", 'Block BW Wgrad time:', self._block_wgrad_time)
    self.log.debug("%s %s", 'Block optim time:', self._block_optim_time)
    self.log.debug("%s %s", 'Baseblock BW time:', self._baseblock_bw_time)
    self.log.debug("%s %s", 'With BW offload overhead time:',
      self._baseblock_bw_offload_overhead)
    self.log.debug("%s %s", 'Edgeblock BW time:', self._edgeblock_bw_time)
    self.log.debug("%s %s", 'With BW offload overhead time:',
      self._edgeblock_bw_offload_overhead)

    # Determines how long it takes to perform the DP per block
    # This assumes no DP communication overlap (will be adjusted later).
    if self.exe.data_par > 1 and self.exe.training:
      self._block_dp_size = self._block_weight_space
      if self.exe.optimizer_sharding:
        # When performing optimizer sharding, the communication time is a
        # reduce-scatter plus an all-gather.
        self._block_dp_time = (
          self._dp_net.time(
            'reduce_scatter', self._block_dp_size, self.exe.data_par) +
          self._dp_net.time(
            'all_gather', self._block_dp_size, self.exe.data_par))
      else:
        # When not performing optimizer sharding, the communication time is a
        # single all-reduce.
        self._block_dp_time = self._dp_net.time(
          'all_reduce', self._block_dp_size, self.exe.data_par)
    else:
      self._block_dp_size = 0
      self._block_dp_time = 0
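    # Illustrative arithmetic under an assumed ring algorithm (the network
    # model here may use different per-operation scalars): with data_par=8 and
    # a 1 GB _block_dp_size, a reduce-scatter and an all-gather each move
    # 7/8 GB per rank, 1.75 GB in total, which matches the 2 * 7/8 = 1.75 GB of
    # a single ring all-reduce. Under that assumption optimizer sharding does
    # not change the per-rank link volume.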
    self.log.debug('DP block comm size: %s',
                   human_format(self._block_dp_size, 'bytes'))
    self.log.debug('DP block comm time (no overlap): %.3e',
                   self._block_dp_time)

    # DP overlap happens if the DP time for the previous block(s) is lower than
    # the microbatch BW pass time for the next set of consecutive blocks.
    # Without interleaving, we move a single microbatch through each block
    # and must overlap DP within a single block's single-microbatch time.
    # With full interleaving, we propagate p microbatches through each
    # block and overlap DP comm with p-1 microbatches over a block.
    # In a mixed case, we can overlap the DP communication of several chunks,
    # e.g. non-interleaved blocks (L/gpu / interleaving_factor), over the BW
    # pass of p-1 microbatches through the same number of blocks if memory
    # capacity is sufficient, or perform offload/prefetch after each
    # block-microbatch. For simplicity we only count the bandwidth-optimal case.
    # Note that uneven extra PP bubbles won't affect overlapping
    if self.exe.data_par > 1 and self.exe.training:
      if self.exe.data_par_overlap:
        # We can evenly overlap all the chunks except the last one; in the last
        # chunk we can overlap all blocks except the last one
        num_overlappable_chunks = self.exe.pipeline_interleaving - 1
        last_chunk_overlap_size = self._blocks_per_chunk - 1
        # We can overlap DP with the BW pass, overlapping the AR for the
        # previous layer with the BW for the current one. When the optimizer is
        # sharded we can't overlap during the optimizer step, because we RS
        # grads before the step and AG weights after it.
        # Overlappable chunks have an overlap size equal to
        # blocks_per_chunk * num_microbatches.
        # In case of 1F1B schedule, num_microbatches == pipeline_par
        overlap_window = self.exe.pipeline_par * chunk_dp_overlap_time
        overlap_compute = self.exe.pipeline_par * chunk_dp_compute_time
        chunk_dp_time = self._blocks_per_chunk * self._block_dp_time
        # We may have PP and DP comm colliding if DP comm takes longer than
        # a single chunk BW time. We can't collide more PP than microbatches
        if self._dp_net == self._pp_net:
          if self.exe._num_microbatches % self.exe.pipeline_par != 0:
            num_overlapped_pp = min(
              chunk_dp_time // chunk_bw_time,
              self.exe._num_microbatches % self.exe.pipeline_par)
          else:
            num_overlapped_pp = min(
              chunk_dp_time // chunk_bw_time,
              self.exe.pipeline_par)
        else:
          # if PP and DP on different networks, overlapping is fine
          num_overlapped_pp = 0
        # we add DP/PP collision time and compute slowdown due to overlap
        overlap_inflection = chunk_dp_time - (overlap_window -
          num_overlapped_pp * chunk_bw_pp_time) + overlap_compute * \
          self._dp_net.processor_usage
        if overlap_inflection > 0:
          # Tcomm is larger than compute, excess is exposed
          overlappable_chunks_exposed_time = num_overlappable_chunks * \
            overlap_inflection
        else:
          # Tcomm is smaller than compute and hidden, but it contributes to
          # compute slowdown because part of the compute resources orchestrate
          # comm
          overlappable_chunks_exposed_time = num_overlappable_chunks * \
            chunk_dp_time * self._dp_net.processor_usage
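        # Illustrative arithmetic (hypothetical values): chunk_dp_time=10 ms,
        # overlap_window=8 ms, num_overlapped_pp=0, overlap_compute=20 ms and
        # _dp_net.processor_usage=0.05 give
        # overlap_inflection = 10 - (8 - 0) + 20 * 0.05 = 3 ms > 0, so each
        # overlappable chunk exposes 3 ms of DP comm.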
        # Compute minimal bandwidth required for DP comm overlap of all chunks
        # but the last one.
        chunk_overlap_time = overlap_window + overlap_compute * \
          self._dp_net.processor_usage
        if self._dp_net == self._pp_net:
          chunk_overlap_time -= chunk_bw_pp_time
        chunk_overlap_time *= num_overlappable_chunks
        if chunk_overlap_time > 0:
          self._dp_bw_overlap_req_chunk = self._blocks_per_chunk * \
            self._block_dp_size / chunk_overlap_time
          if self.exe.optimizer_sharding:
            self._dp_bw_overlap_req_chunk *= (
              self._dp_net._ops["reduce_scatter"].scalar +
              self._dp_net._ops["all_gather"].scalar)
          else:
            self._dp_bw_overlap_req_chunk *= self._dp_net._ops["all_reduce"].scalar
        else:
          self._dp_bw_overlap_req_chunk = 0
        # in the last chunk, we overlap DP comm over last edge block and all
        # middle blocks, so we subtract the time of the first edge block
        if self._baseblocks_per_chunk > 0:
          last_chunk_window = chunk_dp_overlap_time - chunk_bw_pp_time - (
            self._baseblock_bw_time + self._edgeblock_bw_time) / 2
          if not self.exe.optimizer_sharding:
            # If optimizer is not sharded, we can overlap optimizer step with
            # communication, except for memory access time
            last_chunk_window += (
              self._block_optim_time - self._block_optim_mem_time)
        else:
          # if there is no base blocks, we only have a single edge block
          # and last chunk is completely not overlappable
          last_chunk_window = 0
        last_chunk_inflection = (
          last_chunk_overlap_size * self._block_dp_time) + (
            block_dp_compute_time * self._dp_net.processor_usage -
            last_chunk_window)
        if last_chunk_inflection > 0:
          # Tcomm is larger than compute, excess is exposed
          last_chunk_exposed_time = last_chunk_inflection
        else:
          # Tcomm is smaller than compute and hidden, but it contributes to
          # compute slowdown because part of the compute resources orchestrate
          # comm
          last_chunk_exposed_time = last_chunk_overlap_size * \
            self._block_dp_time * self._dp_net.processor_usage
        exposed_time = \
          overlappable_chunks_exposed_time + last_chunk_exposed_time
        # Compute minimal bandwidth required for DP comm overlap of last chunk
        tail_overlap_time = last_chunk_window + last_chunk_overlap_size * \
          self._block_dp_time * self._dp_net.processor_usage
        if tail_overlap_time > 0:
          self._dp_bw_overlap_req_tail = self._blocks_per_chunk * \
            self._block_dp_size / tail_overlap_time
          if self.exe.optimizer_sharding:
            self._dp_bw_overlap_req_tail *= (
              self._dp_net._ops["reduce_scatter"].scalar +
              self._dp_net._ops["all_gather"].scalar)
          else:
            self._dp_bw_overlap_req_tail *= self._dp_net._ops["all_reduce"].scalar
        else:
          self._dp_bw_overlap_req_tail = 0
        self._dp_comm_time_exposed = self._block_dp_time + exposed_time
        self._dp_comm_time_link = self._blocks_per_proc * self._block_dp_time
        self.log.debug('Blocks per chunk: %d', self._blocks_per_chunk)
        self.log.debug('Num overlappable chunks: %d', num_overlappable_chunks)
        self.log.debug('Last chunk size: %d', last_chunk_overlap_size)
        self.log.debug('Chunk exposed time: %.3e', max(
          0, chunk_dp_time + num_overlapped_pp * chunk_bw_pp_time -
          overlap_window))
        self.log.debug('Last chunk exposed time: %.3e', last_chunk_exposed_time)
      else:
        self._dp_comm_time_exposed = self._blocks_per_proc * self._block_dp_time
        self._dp_comm_time_link = self._dp_comm_time_exposed
        self._dp_bw_overlap_req_chunk = 0
        self._dp_bw_overlap_req_tail = 0
    else:
      self._dp_comm_time_exposed = 0
      self._dp_comm_time_link = 0
      self._dp_bw_overlap_req_chunk = 0
      self._dp_bw_overlap_req_tail = 0
    self.log.debug('Chunk FW time: %.3e', chunk_fw_time)
    self.log.debug('Chunk BW time: %.3e', chunk_bw_time)
    self.log.debug('Chunk BW time for DP overlap: %.3e', chunk_dp_overlap_time)
    self.log.debug('DP comm time exposed: %.3e', self._dp_comm_time_exposed)
    self.log.debug('DP comm time on the link: %.3e',
                   self._dp_comm_time_link)
    self.log.debug('DP comm required bandwidth for overlapped chunks: %s',
                   human_format(self._dp_bw_overlap_req_chunk, "bandwidth"))
    self.log.debug('DP comm required bandwidth for the last chunk: %s',
                   human_format(self._dp_bw_overlap_req_tail, "bandwidth"))

    # memory capacity stats
    self._weight_space = self._block_weight_space * self._blocks_per_proc
    # account for activation recomputation
    # for full recompute we keep single block's activations
    # (no scaling by L/gpu)
    if self.exe.training:
      # With 1F1B schedule we only keep `pipeline_par` microbatches
      # If num_microbatches < PP, we keep num_microbatches for all PP stages
      if self.exe._num_microbatches < self.exe.pipeline_par:
        mem_microbatches = self.exe._num_microbatches
      else:
        mem_microbatches = self.exe.pipeline_par
      if self.exe.activation_recompute == "full":
        assert self._block_act_storage_space == 0, \
          "We expect with full act recomputation we recompute ALL activations"
        self._act_space = self._block_act_working_space
        # We would need to store checkpoints for all microbatches before we
        # compute BW pass with regular schedule, but we ONLY use 1F1B schedule
        self._act_checkpoint_size = self._blocks_per_proc * \
          self._block_act_checkpoint_size
        # Keep activation checkpoints for all pipeline stages for PP
        if self.exe.pipeline_interleaving > 1:
          self._act_checkpoint_size *= mem_microbatches * (
            1 + (self.exe.pipeline_par - 1) / (self.exe.pipeline_interleaving *
                                               self.exe.pipeline_par))
        else:
          assert self.exe.pipeline_interleaving == 1
          self._act_checkpoint_size *= mem_microbatches
      else:
        # Without full recompute, we don't need checkpoints
        self._act_checkpoint_size = 0
        # Without full recompute, we keep activations for all blocks on the GPU,
        # one activation for working block, and activation for other blocks for
        # all pipeline stages w.r.t. interleaved 1F1B schedule
        if self.exe.pipeline_interleaving > 1:
          pp_microbatch_factor = mem_microbatches * (
            1 + (self.exe.pipeline_par - 1) / (self.exe.pipeline_interleaving *
                                               self.exe.pipeline_par))
        else:
          assert self.exe.pipeline_interleaving == 1
          pp_microbatch_factor = mem_microbatches
        self._act_space = self._block_act_working_space + \
          self._block_act_storage_space * (
            self._blocks_per_proc * pp_microbatch_factor - 1)
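        # Illustrative arithmetic (hypothetical values): pipeline_par=4,
        # pipeline_interleaving=2 and _num_microbatches>=4 give
        # mem_microbatches=4 and pp_microbatch_factor = 4 * (1 + 3 / (2 * 4))
        # = 5.5; with _blocks_per_proc=8 this stores 8 * 5.5 - 1 = 43 blocks'
        # worth of activation storage on top of one block's working space.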
      # Only need activation grads for a single block
      self._act_grad_space = self._block_act_grad_space
    else:
      self._act_space = self._block_act_working_space
      self._act_checkpoint_size = 0
      self._act_grad_space = 0

    # The optimizer split is already accounted for during block compilation.
    # Only one non-sharded weight grad copy is needed, for the block whose
    # grads are being reduced; the remaining blocks keep their (possibly
    # sharded) weight grads
    if self.exe.training:
      if self._blocks_per_proc == 1:
        self._weight_grad_space = self._block_weight_grad_space_no_sharding
      else:
        self._weight_grad_space = \
          self._block_weight_grad_space_no_sharding + \
          self._block_weight_grad_space * (self._blocks_per_proc - 1)
      self._optimizer_space = \
        self._block_optimizer_space * self._blocks_per_proc
    else:
      self._weight_grad_space = 0
      self._optimizer_space = 0

  def _check_mem_caps(self):
    if self.get_mem_tier1_cap_req() > self.sys.mem1.capacity:
      raise self.Error(f'Mem tier1 needs '
                       f'{human_format(self.get_mem_tier1_cap_req(), "bytes")} '
                       f'but only has '
                       f'{human_format(self.sys.mem1.capacity, "bytes")}')
    if self.get_mem_tier2_cap_req() > self.sys.mem2.capacity:
      raise self.Error(f'Mem tier2 needs '
                       f'{human_format(self.get_mem_tier2_cap_req(), "bytes")} '
                       f'but only has '
                       f'{human_format(self.sys.mem2.capacity, "bytes")}')

  def _misc_sanity_checks(self):
    if self.exe.tensor_par == 1:
      assert self.get_tp_comm_exposed_time() == 0
      assert self.get_tp_comm_link_time() == 0
    if self.exe.pipeline_par == 1:
      assert self.get_pp_comm_exposed_time() == 0
      assert self.get_pp_comm_link_time() == 0
    if self.exe.data_par == 1:
      assert self.get_dp_comm_exposed_time() == 0
      assert self.get_dp_comm_link_time() == 0

    assert self._fw_flops >= self._block_fw_flops
    assert self._fw_flops_time >= self._block_fw_flops_time
    assert self._fw_mem_accessed >= self._block_fw_mem_accessed
    assert self._fw_mem_time >= self._block_fw_mem_time
    assert self._fw_time >= self._block_fw_time
    assert self._re_flops >= self._block_re_flops
    assert self._re_flops_time >= self._block_re_flops_time
    assert self._re_mem_accessed >= self._block_re_mem_accessed
    assert self._re_mem_time >= self._block_re_mem_time
    assert self._re_time >= self._block_re_time
    assert self._agrad_flops >= self._block_agrad_flops
    assert self._agrad_flops_time >= self._block_agrad_flops_time
    assert self._agrad_mem_accessed >= self._block_agrad_mem_accessed
    assert self._agrad_mem_time >= self._block_agrad_mem_time
    assert self._agrad_time >= self._block_agrad_time
    assert self._wgrad_flops >= self._block_wgrad_flops
    assert self._wgrad_flops_time >= self._block_wgrad_flops_time
    assert self._wgrad_mem_accessed >= self._block_wgrad_mem_accessed
    assert self._wgrad_mem_time >= self._block_wgrad_mem_time
    assert self._wgrad_time >= self._block_wgrad_time
    assert self._optim_flops >= self._block_optim_flops
    assert self._optim_flops_time >= self._block_optim_flops_time
    assert self._optim_mem_accessed >= self._block_optim_mem_accessed
    assert self._optim_mem_time >= self._block_optim_mem_time
    assert self._optim_time >= self._block_optim_time
    assert self._weight_space >= self._block_weight_space
    assert self._act_space >= self._block_act_working_space
    assert self._act_checkpoint_size >= self._block_act_checkpoint_size
    assert self._weight_grad_space >= self._block_weight_grad_space_no_sharding
    assert self._act_grad_space == self._block_act_grad_space
    assert self._optimizer_space >= self._block_optimizer_space

    if not self.exe.training:
      # when not training (inference), backward is not performed and DP has no
      # communication overhead
      assert self.get_bw_time() == 0
      assert self.get_optim_step_time() == 0
      assert self.get_bw_offload_time() == 0
      assert self.get_recompute_time() == 0
      assert self.get_act_checkpoint_size() == 0
      assert self.get_dp_comm_exposed_time() == 0
      assert self.get_dp_comm_link_time() == 0
    else:
      # when training, backward is performed
      assert self.get_bw_time() > 0
      assert self.get_optim_step_time() > 0
      if self.exe.activation_recompute == 'full':
        assert self.get_recompute_time() > 0
        assert self.get_act_checkpoint_size() > 0
      elif self.exe.activation_recompute == 'attn_only':
        assert self.get_recompute_time() > 0
        assert self.get_act_checkpoint_size() == 0
      else:
        if not self.exe.seq_par_ag_redo:
          assert self.get_recompute_time() == 0
        assert self.get_act_checkpoint_size() == 0


  def run(self, sys):
    assert self._compiled, "You must first call self.compile()"
    assert not self._executed
    assert isinstance(sys, System)
    self._compute_block_stats()
    self._compute_batch_stats()
    self._check_mem_caps()
    self._misc_sanity_checks()
    self._executed = True

  def _get_fw_offload_size(self):
    if self.exe.weight_offload:
      weight_offload_size = self._block_weight_space
    else:
      weight_offload_size = 0
    if self.exe.activations_offload:
      if self.exe.activation_recompute != 'full':
        act_offload_size = self._block_act_storage_space
      else:
        act_offload_size = self._block_act_checkpoint_size
    else:
      act_offload_size = 0
    return max(weight_offload_size, act_offload_size)

  def _get_bw_offload_size(self):
    bw_offload_size = 0
    if self.exe.training:
      if self.exe.weight_offload:
        bw_offload_size += self._block_weight_space
      if self.exe.activations_offload:
        if self.exe.activation_recompute != 'full':
          bw_offload_size += self._block_act_storage_space
        else:
          bw_offload_size += self._block_act_checkpoint_size
      if self.exe.optimizer_offload:
        bw_offload_size += self._block_optimizer_space
    return bw_offload_size

  def get_fw_time(self):
    return self._fw_time

  def get_fw_offload_time(self):
    return self.sys.compute_offload_time(self._get_fw_offload_size())

  def get_fw_offload_overhead(self):
    full_overhead = self.exe._num_microbatches * self._chunks_per_proc * (
      (self._baseblocks_per_chunk * self._baseblock_fw_offload_overhead) +
      (self._edgeblocks_per_chunk * self._edgeblock_fw_offload_overhead))
    return full_overhead

  def get_bw_time(self):
    return self._agrad_time + self._wgrad_time

  def get_optim_step_time(self):
    return self._optim_time

  def get_bw_offload_time(self):
    if self.exe.training:
      return self.sys.compute_offload_time(self._get_bw_offload_size())
    else:
      return 0

  def get_bw_offload_overhead(self):
    if self.exe.training:
      full_overhead = self.exe._num_microbatches * self._chunks_per_proc * (
        (self._baseblocks_per_chunk * self._baseblock_bw_offload_overhead) +
        (self._edgeblocks_per_chunk * self._edgeblock_bw_offload_overhead))
      return full_overhead
    else:
      return 0

  def get_recompute_time(self):
    return self._re_time

  def get_recomm_exposed_time(self):
    if self.exe.training:
      return self._recomm_time_exposed
    else:
      return 0

  def get_recomm_link_time(self):
    if self.exe.training:
      return self._recomm_time_link
    else:
      return 0

  def get_bubble_time(self):
    return self._bubble_time

  def get_tp_comm_exposed_time(self):
    return self._tp_comm_time_exposed

  def get_pp_comm_exposed_time(self):
    return self._pp_comm_time_exposed

  def get_dp_comm_exposed_time(self):
    if self.exe.training:
      return self._dp_comm_time_exposed
    else:
      return 0

  def get_tp_comm_link_time(self):
    return self._tp_comm_time_link

  def get_pp_comm_link_time(self):
    return self._pp_comm_time_link

  def get_dp_comm_link_time(self):
    if self.exe.training:
      return self._dp_comm_time_link
    else:
      return 0

  def get_dp_comm_net_time(self):
    if self.exe.training:
      return self._blocks_per_proc * self._block_dp_time
    else:
      return 0

  def get_total_time(self):
    time = self.get_fw_time()
    time += self.get_bw_time()
    time += self.get_optim_step_time()
    time += self.get_fw_offload_overhead()
    time += self.get_bw_offload_overhead()
    time += self.get_recompute_time()
    time += self.get_recomm_exposed_time()
    time += self.get_bubble_time()
    time += self.get_tp_comm_exposed_time()
    time += self.get_pp_comm_exposed_time()
    time += self.get_dp_comm_exposed_time()
    return time

  def get_useful_flops(self):
    total_flops = sum(
      [block.get_fw_flops() for block in self._llm_block])
    if self.exe.training:
      total_flops += sum(
        [block.get_agrad_flops() + block.get_wgrad_flops() +
         block.get_optim_step_flops() for block in self._llm_block])
    return total_flops

  def get_compute_efficiency(self):
    total_flops = self.get_useful_flops()
    compute_time = self.get_fw_time() + self.get_bw_time() + \
      self.get_optim_step_time()
    perfect_time = self._blocks_per_proc * self.exe._num_microbatches * \
      total_flops / self.sys.matrix.flops(self.exe.datatype)
    return perfect_time / compute_time

  def get_system_efficiency(self):
    compute_time = self.get_fw_time() + self.get_bw_time() + \
      self.get_optim_step_time()
    return compute_time / self.get_total_time()

  def get_total_efficiency(self):
    total_flops = self.get_useful_flops()
    perfect_time = self._blocks_per_proc * self.exe._num_microbatches * \
      total_flops / self.sys.matrix.flops(self.exe.datatype)
    return perfect_time / self.get_total_time()

  def get_weight_space_min(self):
    return self._block_weight_space * 2

  def get_weight_space(self):
    return self._weight_space

  def get_act_space_min(self):
    if self.exe.activation_recompute != 'full':
      return self._block_act_working_space + self._block_act_storage_space
    else:
      return self._block_act_working_space

  def get_act_space(self):
    return self._act_space

  def get_act_checkpoint_size_min(self):
    if self.exe.training:
      if self.exe.activation_recompute != 'full':
        return 0
      else:
        return self._block_act_checkpoint_size * 2
    else:
      return 0

  def get_act_checkpoint_size(self):
    if self.exe.training:
      if self.exe.activation_recompute != 'full':
        return 0
      else:
        return self._act_checkpoint_size
    else:
      return 0

  def get_weight_grad_space_min(self):
    if self.exe.training:
      # We keep one set of non-sharded weight grads after compute before
      # reduction, and one sharded set for offloading
      return self._block_weight_grad_space_no_sharding + \
        self._block_weight_grad_space
    else:
      return 0

  def get_weight_grad_space(self):
    if self.exe.training:
      return self._weight_grad_space
    else:
      return 0

  def get_act_grad_space_min(self):
    return self.get_act_grad_space()

  def get_act_grad_space(self):
    if self.exe.training:
      return self._act_grad_space
    else:
      return 0

  def get_optimizer_space_min(self):
    if self.exe.training:
      return self._block_optimizer_space * 2
    else:
      return 0

  def get_optimizer_space(self):
    if self.exe.training:
      return self._optimizer_space
    else:
      return 0

  def _get_mem_cap_reqs(self):
    tier1 = 0
    tier2 = 0
    if self.exe.weight_offload:
      tier1 += self.get_weight_space_min()
      tier2 += self.get_weight_space()
    else:
      tier1 += self.get_weight_space()
    if self.exe.activations_offload:
      if self.exe.activation_recompute != 'full':
        tier1 += self.get_act_space_min()
        tier2 += self.get_act_space()
      else:
        tier1 += self.get_act_space_min()
        tier1 += self.get_act_checkpoint_size_min()
        tier2 += self.get_act_checkpoint_size()
    else:
      tier1 += self.get_act_space()
      tier1 += self.get_act_checkpoint_size()
    if self.exe.optimizer_offload:
      # We keep one set of non-sharded weight grads after compute before
      # reduction, and one sharded set for offloading
      tier1 += self.get_weight_grad_space_min()
      tier1 += self.get_optimizer_space_min()
      tier2 += self._block_weight_grad_space * self._blocks_per_proc
      tier2 += self.get_optimizer_space()
    else:
      tier1 += self.get_weight_grad_space() + \
        self.get_optimizer_space()
    tier1 += self.get_act_grad_space()
    return tier1, tier2

  def get_mem_tier1_cap_req(self):
    return self._get_mem_cap_reqs()[0]

  def get_mem_tier2_cap_req(self):
    return self._get_mem_cap_reqs()[1]

  def get_act_offload_bw_req(self):
    # We should be able to offload (write) activations during the FW pass and
    # prefetch (read) them during the BW pass for block (i-1).
    # After the BW pass, activations are discarded.
    if self.exe.activation_recompute != 'full':
      act_offload_size = self._block_act_storage_space
    else:
      act_offload_size = self._block_act_checkpoint_size
    offload_time = min(
      self._baseblock_fw_time_no_offload - self._block_fw_mem_time,
      self._edgeblock_fw_time_no_offload - self._block_fw_mem_time)
    return act_offload_size / offload_time

  def get_weight_offload_bw_req(self):
    # We should be able to offload (write) and prefetch (read) weights both
    # during FW and BW passes for blocks (i-1) / (i+1).
    # Weights are always kept; they cannot be discarded.
    offload_time = min(
      self._baseblock_fw_time_no_offload - self._block_fw_mem_time,
      self._edgeblock_fw_time_no_offload - self._block_fw_mem_time)
    return self._block_weight_space / offload_time

  def get_optim_offload_bw_req(self):
    # We should be able to offload (write) weight grads and optimizer state
    # and prefetch (read) optimizer state during BW passes for blocks
    # (i-1) / (i+1).
    if self.exe.training:
      offload_time = min(
        self._baseblock_bw_time_no_offload - (self._block_agrad_mem_time +
          self._block_wgrad_mem_time),
        self._edgeblock_bw_time_no_offload - (self._block_agrad_mem_time +
          self._block_wgrad_mem_time))
      return (self._block_weight_grad_space + self._block_optimizer_space) / \
        offload_time
    else:
      return 0

  def get_offload_mem_bw_req(self):
    fw_offload_time = min(
      self._baseblock_fw_time_no_offload - self._block_fw_mem_time,
      self._edgeblock_fw_time_no_offload - self._block_fw_mem_time)
    if self.exe.training:
      bw_offload_time = min(
        self._baseblock_bw_time_no_offload - (self._block_agrad_mem_time +
          self._block_wgrad_mem_time),
        self._edgeblock_bw_time_no_offload - (self._block_agrad_mem_time +
          self._block_wgrad_mem_time))
      req_bw = max(self._get_fw_offload_size() / fw_offload_time,
                   self._get_bw_offload_size() / bw_offload_time)
      return req_bw
    else:
      return self._get_fw_offload_size() / fw_offload_time

  def get_sample_rate(self):
    return self.exe.global_batch_size / self.get_total_time()

  def display_stats(self):
    stats = "=" * 80 + "\n"
    stats += "" \
      f"blocks={self.app.num_blocks}, " \
      f"hidden={self.app.hidden}, feedforward={self.app.feedforward}\n" \
      f"num attn heads: {self.app.attn_heads}, " \
      f"attn_size={self.app.attn_size}\n" \
      f"Run on {self.exe.num_procs} processors with:\n" \
      f"TP={self.exe.tensor_par}\n" \
      f"PP={self.exe.pipeline_par}\n" \
      f"DP={self.exe.data_par}\n" \
      f"Blocks per processor: {self._blocks_per_proc}\n" \
      f"Execution: {self.exe.get_json()};\n" \
      f"System: {self.sys.cfg};\n" \
      f"Weights: {human_format(self.get_weight_space(), 'bytes')};\n" \
      f"Act: {human_format(self.get_act_space(), 'bytes')};\n" \
      f"Act CP: {human_format(self.get_act_checkpoint_size(), 'bytes')};\n" \
      f"Act grad: {human_format(self.get_act_grad_space(), 'bytes')};\n" \
      f"Weight grad: {human_format(self.get_weight_grad_space(), 'bytes')};\n" \
      f"Optim space: {human_format(self.get_optimizer_space(), 'bytes')};\n" \
      f"Batch FW time: {self.get_fw_time():.4f};\n" \
      f"Batch BW time: {self.get_bw_time():.4f};\n" \
      f"Batch optim time: {self.get_optim_step_time():.4f};\n" \
      f"Batch FW offload overhead: {self.get_fw_offload_overhead():.4f};\n" \
      f"Batch BW offload overhead: {self.get_bw_offload_overhead():.4f};\n" \
      f"Batch recompute overhead: {self.get_recompute_time():.4f};\n" \
      f"Batch recomm overhead: {self.get_recomm_exposed_time():.4f};\n" \
      f"Batch bubble overhead: {self.get_bubble_time():.4f};\n" \
      f"Batch TP comm overhead: {self.get_tp_comm_exposed_time():.4f};\n" \
      f"Batch PP comm overhead: {self.get_pp_comm_exposed_time():.4f};\n" \
      f"Batch DP comm overhead: {self.get_dp_comm_exposed_time():.4f};\n" \
      f"Batch TP comm time on link: {self.get_tp_comm_link_time():.4f};\n" \
      f"Batch PP comm time on link: {self.get_pp_comm_link_time():.4f};\n" \
      f"Batch DP comm time on link: {self.get_dp_comm_link_time():.4f};\n" \
      f"Batch total time: {self.get_total_time():.4f};\n" \
      f"Activation offload required BW: " \
      f"{human_format(self.get_act_offload_bw_req(), 'bandwidth')};\n" \
      f"Weight offload required BW: " \
      f"{human_format(self.get_weight_offload_bw_req(), 'bandwidth')};\n" \
      f"Optimizer offload required BW: " \
      f"{human_format(self.get_optim_offload_bw_req(), 'bandwidth')};\n" \
      f"Total offload required BW: " \
      f"{human_format(self.get_offload_mem_bw_req(), 'bandwidth')};\n" \
      f"Mem tier1 capacity requirement: " \
      f"{human_format(self.get_mem_tier1_cap_req(), 'bytes')};\n" \
      f"Mem tier2 capacity requirement: " \
      f"{human_format(self.get_mem_tier2_cap_req(), 'bytes')};\n" \
      f"Mem tier2 BW for offload: " \
      f"{human_format(self.get_offload_mem_bw_req(), 'bandwidth')};\n" \
      f"Compute efficiency: {self.get_compute_efficiency()*100:.2f}%;\n" \
      f"System efficiency: {self.get_system_efficiency()*100:.2f}%;\n" \
      f"Total efficiency: {self.get_total_efficiency()*100:.2f}%;\n" \
      f"Sample rate: {self.get_sample_rate():.2f};\n"
    self.log.info(stats)

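# Sketch only (hypothetical helper, not part of calculon's API): the three
# efficiencies reported by display_stats() decompose multiplicatively.
# Assuming perfect_time is the time at peak matrix FLOP rate for the useful
# FLOPs, compute_time is FW + BW + optimizer step time, and total_time is
# get_total_time() (compute plus exposed overheads):
#   compute efficiency = perfect_time / compute_time
#   system efficiency  = compute_time / total_time
#   total efficiency   = perfect_time / total_time = compute * system
def _sketch_efficiencies(perfect_time, compute_time, total_time):
  """Returns (compute_eff, system_eff, total_eff) for positive times."""
  if min(perfect_time, compute_time, total_time) <= 0:
    raise ValueError('times must be positive')
  compute_eff = perfect_time / compute_time
  system_eff = compute_time / total_time
  total_eff = perfect_time / total_time
  return compute_eff, system_eff, total_eff

# Usage: _sketch_efficiencies(2.0, 4.0, 5.0) gives compute 0.5, system 0.8
# and total 0.4, so compute * system matches total.
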

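# Sketch only (hypothetical helper): get_act_offload_bw_req() and
# get_weight_offload_bw_req() size the tier-2 bandwidth so one block's data
# can move while a neighboring block computes. The window is the block's FW
# time without offload minus the block's own memory-traffic time, taking the
# smaller of the base-block and edge-block windows. Units are assumed to be
# bytes and seconds (result in bytes/second). Unlike the original methods,
# this sketch rejects a non-positive window instead of dividing by it.
def _sketch_offload_bw_req(offload_bytes, baseblock_time, edgeblock_time,
                           block_mem_time):
  window = min(baseblock_time - block_mem_time,
               edgeblock_time - block_mem_time)
  if window <= 0:
    raise ValueError('no compute time left to hide the offload')
  return offload_bytes / window

# Usage: _sketch_offload_bw_req(1000, 3.0, 4.0, 1.0) -> 500.0, since the
# base-block window (3.0 - 1.0 = 2.0 s) is the smaller one.
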
================================================
FILE: calculon/llm/optimal_execution.py
================================================
"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  https://www.apache.org/licenses/LICENSE-2.0
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
"""

import datetime
import gzip
import logging
import multiprocessing as mp
import psutil
import os

import calculon
from calculon.util import pick, arg_true_false_all
from calculon.llm import *


class OptimalExecution(calculon.CommandLine):
  NAME = 'llm-optimal-execution'
  ALIASES = ['loe']

  @staticmethod
  def create_parser(subparser):
    sp = subparser.add_parser(
      OptimalExecution.NAME, aliases=OptimalExecution.ALIASES,
      help='run a search to find the optimal llm execution')
    sp.set_defaults(func=OptimalExecution.run_command)
    sp.add_argument('-d', '--debug', action='store_true',
                    help='Loop over executions, don\'t run them')
    sp.add_argument('application', type=str,
                    help='File path to application configuration')
    sp.add_argument('num_procs', type=int,
                    help='Number of processors in execution')
    sp.add_argument('max_batch_size', type=int,
                    help='Maximum batch size; the largest multiple of DP not exceeding it is used')
    sp.add_argument('datatype', type=str, choices=System.supported_datatypes(),
                    help='The datatype to use')
    sp.add_argument('system', type=str,
                    help='File path to system configuration')
    sp.add_argument('output', type=str,
                    help='File path to the output file'
                    " ('*.csv', '*.csv.gz', '*.json', '*.json.gz')")
    sp.add_argument('-c', '--cpus', type=int, default=psutil.cpu_count(logical=False),
                    help='CPUs to use for parallelization')
    sp.add_argument('-n', '--noneok', action='store_true',
                    help='Don\'t give failure status when no good execution exists')
    sp.add_argument('-m', '--mbs-break', action='store_true',
                    help='Search across MBS and break earlier when possible')
    sp.add_argument('-t', '--top-n', type=int, default=1,
                    help='Number of best executions to output')
    sp.add_argument('-l', '--layers', action='store_true',
                    help='Include layers information in output stats file')
    sp.add_argument('-f', '--fused_activation', type=arg_true_false_all,
                    default='true', help='Mode of fused activation')
    sp.add_argument('--no-tp-overlap', action='store_true',
                    help='Don\'t allow TP overlap')
    sp.add_argument('--no-dp-overlap', action='store_true',
                    help='Don\'t allow DP overlap')

  @staticmethod
  def run_command(logger, args):
    assert args.top_n > 0, 'top-n must be > 0'

    app = Llm.Application(calculon.io.read_json_file(args.application))
    syst = System(calculon.io.read_json_file(args.system))

    params = []
    for tp in Llm.get_all_tensor_parallelisms(
        args.num_procs, app.hidden, app.attn_heads):
      for pp in Llm.get_all_pipeline_parallelisms(
          args.num_procs, tp, app.num_blocks):
        dp = Llm.get_data_parallelism(args.num_procs, tp, pp)
        for ppint in Llm.get_valid_pipeline_interleavings(app.num_blocks, pp):
          batch_size = OptimalExecution.get_batch_size(dp, args.max_batch_size)
          if batch_size is None:
            continue
          for activation_recompute in ['full', 'attn_only', 'none']:
            for optimizer_sharding in pick(dp>1, [True, False], [False]):
              for tensor_par_comm_type in ['ar', 'p2p_rs_ag', 'rs_ag']:
                params.append(
                  (args.debug, args.top_n, args.layers, args.num_procs,
                   args.max_batch_size, args.datatype, app, syst, tp, pp, dp,
                   ppint, batch_size, activation_recompute, optimizer_sharding,
                   tensor_par_comm_type, args.fused_activation, args.mbs_break,
                   not args.no_tp_overlap, not args.no_dp_overlap))

    # Runs parallel searches
    start_time = datetime.datetime.now()
    with mp.Pool(args.cpus) as pool:
      searches = pool.starmap(OptimalExecution.search, params)
    end_time = datetime.datetime.now()

    # Combines parallel search result into one data structure
    best = []
    exe_count = 0
    good_exe_count = 0
    bad_exe_count = 0
    for cbest, ec, gec, bec, tp, pp in searches:
      best = OptimalExecution.update_list(best, cbest, args.top_n)
      exe_count += ec
      good_exe_count += gec
      bad_exe_count += bec

    logger.info(f'Total executions: {exe_count}')
    logger.info(f'Good executions: {good_exe_count}')
    logger.info(f'Bad executions: {bad_exe_count}')
    calc_rate = exe_count / (end_time - start_time).total_seconds()
    logger.info(f'Calculation rate: {calc_rate:.2f} calcs/sec')
    if args.debug:
      return 0

    if len(best) == 0:
      if not args.noneok:
        logger.fatal('No acceptable configurations found :(')
        return -1
      else:
        logger.info('No acceptable configurations found :(')
    else:
      logger.info(f'Best sample rate: {best[0][0]}')

    output = {}
    for index, run in enumerate(best):
      _, execution, stats = run
      output[index] = {
        'execution': execution,
        'stats': stats
      }

    if calculon.io.is_json_extension(args.output):
      logger.info(f'Output: {args.output}')
      calculon.io.write_json_file(output, args.output)
    elif args.output.endswith('.csv') or args.output.endswith('.csv.gz'):
      logger.info(f'Output: {args.output}')
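      # With --noneok and no results, 'output' is empty: emit only the header
      exe_keys = list(output[0]['execution'].keys()) if output else []
      stats_keys = list(output[0]['stats'].keys()) if output else []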
      opener = gzip.open if args.output.endswith('.gz') else open
      with opener(args.output, 'wb') as fd:
        fd.write(bytes(f',{",".join(exe_keys)},{",".join(stats_keys)}\n',
                       'utf-8'))
        for index in sorted(output.keys()):
          fd.write(bytes(f'{index}', 'utf-8'))
          for exe_key in exe_keys:
            fd.write(bytes(f',{output[index]["execution"][exe_key]}', 'utf-8'))
          for stats_key in stats_keys:
            fd.write(bytes(f',{output[index]["stats"][stats_key]}', 'utf-8'))
          fd.write(bytes('\n', 'utf-8'))
    else:
      assert False, f'Unknown file type: {args.output}'

    return 0

  @staticmethod
  def get_batch_size(data_par, max_batch_size):
    if data_par > max_batch_size:
      return None
    last = data_par
    while True:
      if last + data_par > max_batch_size:
        return last
      else:
        last += data_par

  @staticmethod
  def search(debug, top_n, layers, num_procs, max_batch_size, datatype,
             app, syst, tp, pp, dp, ppint, batch_size, activation_recompute,
             optimizer_sharding, tensor_par_comm_type, fused_acts, mbs_break,
             allow_tp_overlap, allow_dp_overlap):
    num_nets = syst.num_networks

    best = []
    exe_count = 0
    good_exe_count = 0
    bad_exe_count = 0

    has_mem2 = syst.mem2.capacity > 0

    can_redo = Llm.can_redo_ag(tensor_par_comm_type,
                               activation_recompute)
    for seq_par_ag_redo in pick(can_redo, [True, False], [False]):
      for data_par_overlap in pick(dp>1 and allow_dp_overlap, [True, False],
                                   [False]):
        for tensor_par_overlap in pick(tp>1 and allow_tp_overlap,
                                       ['none', 'ring', 'pipe'], ['none']):
          for weight_offload in pick(has_mem2, [True, False], [False]):
            if activation_recompute == 'full' or not has_mem2:
              activations_offloads = [False]
            else:
              activations_offloads = [True, False]
            for activations_offload in activations_offloads:
              for optimizer_offload in pick(has_mem2, [True, False],
                                            [False]):
                for fused_act in fused_acts:
                  for microbatch_size in Llm.get_valid_microbatch_sizes(
                      app.seq_size, tp, dp, batch_size, pp):
                    mbs_break_good = good_exe_count
                    for tn in pick(tp>1, range(num_nets), [0]):
                      for pn in pick(pp>1, range(num_nets), [0]):
                        for dn in pick(dp>1, range(num_nets), [0]):
                          exe_count += 1
                          exe_json = {
                            'num_procs': num_procs,
                            'tensor_par': tp,
                            'pipeline_par': pp,
                            'data_par': dp,
                            'tensor_par_net': tn,
                            'pipeline_par_net': pn,
                            'data_par_net': dn,
                            'batch_size': batch_size,
                            'microbatch_size': microbatch_size,
                            'datatype': datatype,
                            'fused_activation': fused_act,
                            'attention_type': 'multihead',
                            'activation_recompute': activation_recompute,
                            'pipeline_interleaving': ppint,
                            'optimizer_sharding': optimizer_sharding,
                            'tensor_par_comm_type': tensor_par_comm_type,
                            'tensor_par_overlap': tensor_par_overlap,
                            'seq_par_ag_redo': seq_par_ag_redo,
                            'data_par_overlap': data_par_overlap,
                            'weight_offload': weight_offload,
                            'activations_offload': activations_offload,
                            'optimizer_offload': optimizer_offload,
                            'training': True
                          }

                          if not debug:
                            try:
                              logger = logging.Logger('sub')
                              model = Llm(app, logger)
                              model.compile(
                                syst,
                                Llm.Execution.from_json(exe_json))
                              model.run(syst)
                              stats = model.get_stats_json(layers)
                              good_exe_count += 1
                              curr = (stats['sample_rate'], exe_json, stats)
                              best = OptimalExecution.update_list(best, curr,
                                                                  top_n)
                            except Llm.Error as ex:
                              logger = logging.getLogger()
                              logger.debug(f'JSON:{exe_json}\nERROR:{ex}\n')
                              bad_exe_count += 1
                    if mbs_break and good_exe_count == mbs_break_good:
                      break
    return (best, exe_count, good_exe_count, bad_exe_count, tp, pp)

  @staticmethod
  def update_list(current, candidate, quantity):
    if not isinstance(candidate, list):
      current.append(candidate)
    else:
      current.extend(candidate)
    current.sort(reverse=True, key=lambda x: x[0])
    return current[:quantity]


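# Sketch only (hypothetical helper): OptimalExecution.get_batch_size() returns
# the largest multiple of data_par that does not exceed max_batch_size, or
# None when data_par alone is already larger. The accumulating loop is
# equivalent to a single floor division.
def _sketch_batch_size(data_par, max_batch_size):
  if data_par > max_batch_size:
    return None
  return (max_batch_size // data_par) * data_par

# Usage: _sketch_batch_size(12, 3072) -> 3072; _sketch_batch_size(7, 100) -> 98
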
calculon.CommandLine.register(OptimalExecution)

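# Sketch only (hypothetical helper): OptimalExecution.update_list() keeps the
# top-N (sample_rate, execution, stats) tuples by sorting the merged list on
# sample rate and truncating. heapq.nlargest() is documented as equivalent to
# sorted(..., key=key, reverse=True)[:n] and avoids sorting everything. With
# a key function, only the sample rates are compared, never the dicts. This
# version builds a new list instead of mutating 'current'.
import heapq

def _sketch_update_list(current, candidate, quantity):
  items = list(current)
  if isinstance(candidate, list):
    items.extend(candidate)
  else:
    items.append(candidate)
  return heapq.nlargest(quantity, items, key=lambda x: x[0])

# Usage: merging [(3.0, 'a', {})] with [(5.0, 'b', {}), (1.0, 'c', {})] and
# quantity=2 keeps the 'b' and 'a' entries, in that order.
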

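# Sketch only (hypothetical helper): in OptimalExecution.search(), each of TP,
# PP and DP is assigned a network index only when that parallelism is > 1;
# otherwise it is pinned to network 0. The three innermost loops therefore
# enumerate this Cartesian product of (tensor_par_net, pipeline_par_net,
# data_par_net) choices.
import itertools

def _sketch_network_assignments(tp, pp, dp, num_nets):
  choices = [range(num_nets) if par > 1 else [0] for par in (tp, pp, dp)]
  return list(itertools.product(*choices))

# Usage: tp=8, pp=1, dp=4 with 2 networks yields 2 * 1 * 2 = 4 assignments.
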
================================================
FILE: calculon/llm/parameter_calculator.py
================================================
"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  https://www.apache.org/licenses/LICENSE-2.0
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
"""

import os

import calculon
from calculon.llm import *

class ParameterCalculator(calculon.CommandLine):
  NAME = 'llm-parameter-calculator'
  ALIASES = ['lpc']

  @staticmethod
  def create_parser(subparser):
    sp = subparser.add_parser(ParameterCalculator.NAME,
                              aliases=ParameterCalculator.ALIASES,
                              help='calculate the number of parameters of an llm')
    sp.set_defaults(func=ParameterCalculator.run_command)
    sp.add_argument('application', type=str,
                    help='File path to application configuration')
    sp.add_argument('-a', '--alignment', type=int, default=13,
                    help='Alignment spaces')

  @staticmethod
  def run_command(logger, args):
    app_json = calculon.io.read_json_file(args.application)

    try:
      app = Llm.Application(app_json)
    except Llm.Error as error:
      print(f'ERROR: {error}')
      return -1

    app_name, _ = os.path.splitext(os.path.basename(args.application))

    logger.info(f'{app_name}'
                f'{" " * (args.alignment - len(app_name))}'
                ' -> '
                f'{human_format(app.num_parameters())}')

    return 0


calculon.CommandLine.register(ParameterCalculator)

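# Sketch only (hypothetical helper): run_command() pads the model name to
# args.alignment columns before ' -> '. Because " " * n is "" for n <= 0,
# the padding behaves like str.ljust(), which also leaves longer names
# untruncated.
def _sketch_align(name, alignment):
  return f'{name.ljust(alignment)} -> '

# Usage: _sketch_align('gpt3-175B', 13) pads the 9-character name with 4
# spaces before ' -> '.
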

================================================
FILE: calculon/llm/runner.py
================================================
"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  https://www.apache.org/licenses/LICENSE-2.0
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
"""

import calculon
from calculon.llm import *

class Runner(calculon.CommandLine):
  NAME = 'llm'
  ALIASES = []

  @staticmethod
  def create_parser(subparser):
    sp = subparser.add_parser(Runner.NAME, aliases=Runner.ALIASES,
                              help='run a single llm calculation')
    sp.set_defaults(func=Runner.run_command)
    sp.add_argument('application', type=str,
                    help='File path to application configuration')
    sp.add_argument('execution', type=str,
                    help='File path to execution configuration')
    sp.add_argument('system', type=str,
                    help='File path to system configuration')
    sp.add_argument('stats', type=str,
                    help='File path to stats output ("-" for stdout)')
    sp.add_argument('-p', '--peers', type=str, default=None,
                    help='File path to write out peers file')
    sp.add_argument('-l', '--layers', action='store_true',
                    help='Include layers information in output stats file')

  @staticmethod
  def run_command(logger, args):
    app_json = calculon.io.read_json_file(args.application)
    exe_json = calculon.io.read_json_file(args.execution)
    sys_json = calculon.io.read_json_file(args.system)

    app = Llm.Application(app_json)
    exe = Llm.Execution.from_json(exe_json)
    syst = System(sys_json)

    try:
      model = Llm(app, logger)
      model.compile(syst, exe)
      model.run(syst)
    except Llm.Error as error:
      print(f'ERROR: {error}')
      return -1

    if args.stats == '-':
      model.display_stats()
    elif calculon.is_json_extension(args.stats):
      calculon.write_json_file(model.get_stats_json(args.layers), args.stats)
    else:
      assert False, f'unknown stats extension: {args.stats}'

    if args.peers:
      calculon.write_json_file(exe.get_peers_json(), args.peers)

    return 0


calculon.CommandLine.register(Runner)

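# Sketch only (hypothetical helper): Runner.run_command() sends stats to
# stdout for '-', writes JSON for a JSON path, and otherwise fails. The JSON
# suffixes accepted here ('.json', '.json.gz') are an assumption taken from
# the llm-optimal-execution output help text, not from calculon.io itself.
def _sketch_stats_destination(path):
  if path == '-':
    return 'stdout'
  if path.endswith('.json') or path.endswith('.json.gz'):
    return 'json'
  raise ValueError(f'unknown stats extension: {path}')
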

================================================
FILE: calculon/llm/validation.py
================================================
"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  https://www.apache.org/licenses/LICENSE-2.0
 *
 * See the NOTICE file distributed with this work for additional information
 * regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
"""

import logging
import math
import os

import calculon
from calculon.util import pick
from calculon.llm import *


class Validation(calculon.CommandLine):
  NAME = 'llm-validation'
  ALIASES = ['lv']

  @staticmethod
  def create_parser(subparser):
    sp = subparser.add_parser(
      Validation.NAME, aliases=Validation.ALIASES,
      help='run a validation of llm execution')
    sp.set_defaults(func=Validation.run_command)
    sp.add_argument('-b', '--base_dir', default='.',
                    help='Base directory')
    sp.add_argument('-v', '--verbose', action='store_true',
                    help='Show verbose output while running')

  @staticmethod
  def run_command(logger, args):
    funcs = [
      Validation.seqsel_fig1,
      Validation.seqsel_fig7,
      Validation.seqsel_tab5
    ]
    for func in funcs:
      if args.verbose:
        print(f'\n\nNow running test: {func.__name__}')
      if func(logger, args) is not None:
        return -1
    return 0

  @staticmethod
  def seqsel_fig1(logger, args):
    kModels = ['megatron-22B', 'gpt3-175B', 'turing-530B', 'megatron-1T']
    kModes = ['none', 'seqsel']
    # These profiled values are reported here:
    # https://arxiv.org/pdf/2205.05198.pdf
    # Figure 1
    kProfile = {
      'megatron-22B': {
        'none': {
          'par_opt': 45.5625,
          'act': 59.25
        },
        'seqsel': {
          'par_opt': 45.5625,
          'act': 9.5625
        }
      },
      'gpt3-175B': {
        'none': {
          'par_opt': 45.5625,
          'act': 66.84375
        },
        'seqsel': {
          'par_opt': 45.5625,
          'act': 12.3515625
        }
      },
      'turing-530B': {
        'none': {
          'par_opt': 31.640625,
          'act': 114.0234375
        },
        'seqsel': {
          'par_opt': 31.640625,
          'act': 23.076171875
        }
      },
      'megatron-1T': {
        'none': {
          'par_opt': 32.958984375,
          'act': 131.25
        },
        'seqsel': {
          'par_opt': 32.958984375,
          'act': 26.5625
        }
      }
    }

    def get_files(model, mode):
      assert model in kModels
      assert mode in kModes
      app = os.path.join(args.base_dir, 'models', f'{model}.json')
      exe = os.path.join(args.base_dir, 'validation', 'seqsel', 'fig1',
                         f'{model}_{mode}.json')
      return app, exe

    def get_profile(model, mode):
      assert model in kModels
      assert mode in kModes
      return kProfile[model][mode]

    syst_file = os.path.join(args.base_dir, 'systems', 'a100_80e.json')
    syst = System(calculon.io.read_json_file(syst_file))
    data = {}
    for model in kModels:
      data[model] = {}
      for mode in kModes:
        if args.verbose:
          print(f'Analyzing {model} {mode}')
        data[model][mode] = {}
        app_file, exe_file = get_files(model, mode)
        app = Llm.Application(calculon.read_json_file(app_file))
        exe = Llm.Execution.from_json(calculon.read_json_file(exe_file))
        mt = Llm(app, logger)
        mt.compile(syst, exe)
        mt.run(syst)
        stats = mt.get_stats_json(False)
        data[model][mode]['profile_gib'] = get_profile(model, mode)
        act_par_opt = (stats['weight_space'] + stats['weight_grad_space'] +
                       stats['optimizer_space']) / (1024**3)
        act_act = stats['act_space'] / (1024**3)
        data[model][mode]['actual_gib'] = {
          'par_opt': act_par_opt,
          'act': act_act
        }

    print('*Params & Opt,|,none,,,|,seqsel,,,')
    print('Model,|,Profile,Calc,Delta,|,Profile,Calc,Delta,')
    max_error = 0
    abs_error = 0
    for model in kModels:
      print(f'{model},', end='')
      for mode in kModes:
        p = data[model][mode]['profile_gib']['par_opt']
        a = data[model][mode]['actual_gib']['par_opt']
        d = 100*(1-a/p)
        if math.fabs(d) > max_error:
          max_error = math.fabs(d)
        abs_error += math.fabs(d)
        print(f'|,{p},{a:.2f},{d:.2f}%,', end='')
      print()
    ave_error = abs_error / (len(kModels) * len(kModes))
    print(f'Ave,,{ave_error:.2f}%')
    print(f'Max,,{max_error:.2f}%')
    print(',')

    print('*Activations,|,none,,,|,seqsel,,,')
    print('Model,|,Profile,Calc,Delta,|,Profile,Calc,Delta,')
    max_error = 0
    abs_error = 0
    for model in kModels:
      print(f'{model},', end='')
      for mode in kModes:
        p = data[model][mode]['profile_gib']['act']
        a = data[model][mode]['actual_gib']['act']
        d = 100*(1-a/p)
        if math.fabs(d) > max_error:
          max_error = math.fabs(d)
        abs_error += math.fabs(d)
        print(f'|,{p},{a:.2f},{d:.2f}%,', end='')
      print()
    ave_error = abs_error / (len(kModels) * len(kModes))
    print(f'Ave,,{ave_error:.2f}%')
    print(f'Max,,{max_error:.2f}%')
    print(',')

  @staticmethod
  def seqsel_fig7(logger, args):
    kModels = ['megatron-22B', 'gpt3-175B', 'turing-530B', 'megatron-1T']
    kModes = ['none', 'seq', 'sel', 'seqsel', 'full']
    # These profiled values are reported here:
    # https://arxiv.org/pdf/2205.05198.pdf
    # Figure 7
    kProfile = {
      'megatron-22B': {
        'none': 100.00,
        'seq': 66.84,
        'sel': 49.42,
        'seqsel': 16.18,
        'full': 7.64
      },
      'gpt3-175B': {
        'none': 100.00,
        'seq': 62.04,
        'sel': 56.53,
        'seqsel': 18.49,
        'full': 8.71
      },
      'turing-530B': {
        'none': 100.00,
        'seq': 58.31,
        'sel': 62.04,
        'seqsel': 20.27,
        'full': 9.42
      },
      'megatron-1T': {
        'none': 100.00,
        'seq': 58.31,
        'sel': 62.04,
        'seqsel': 20.27,
        'full': 9.42
      }
    }

    def get_files(model, mode):
      assert model in kModels
      assert mode in kModes
      app = os.path.join(args.base_dir, 'models', f'{model}.json')
      exe = os.path.join(args.base_dir, 'validation', 'seqsel', 'fig7',
                         f'{model}_{mode}.json')
      return app, exe

    def get_profile(model, mode):
      assert model in kModels
      assert mode in kModes
      return kProfile[model][mode]

    syst_file = os.path.join(args.base_dir, 'systems', 'a100_80e.json')
    syst = System(calculon.io.read_json_file(syst_file))
    raw = {}
    for model in kModels:
      raw[model] = {}
      for mode in kModes:
        if args.verbose:
          print(f'Analyzing {model} {mode}')
        app_file, exe_file = get_files(model, mode)
        app = Llm.Application(calculon.read_json_file(app_file))
        exe = Llm.Execution.from_json(calculon.read_json_file(exe_file))
        mt = Llm(app, logger)
        mt.compile(syst, exe)
        mt.run(syst)
        stats = mt.get_stats_json(False)
        raw[model][mode] = stats['act_space'] + stats['act_checkpoint_size']

    rel = {}
    for model in kModels:
      rel[model] = {}
      for mode in kModes:
        rel[model][mode] = raw[model][mode] / raw[model]['none'] * 100

    print('Activations,|,none,,,|,seq,,,|,sel,,,|,seqsel,,,|,full,,,')
    print('Model,|,Profile,Calc,Delta,|,Profile,Calc,Delta,|'
          ',Profile,Calc,Delta,|,Profile,Calc,Delta,|,Profile,Calc,Delta,')
    max_error = 0
    abs_error = 0
    for model in kModels:
      print(f'{model},', end='')
      for mode in kModes:
        p = get_profile(model, mode)
        a = rel[model][mode]
        d = 100*(1-a/p)
        if math.fabs(d) > max_error:
          max_error = math.fabs(d)
        abs_error += math.fabs(d)
        print(f'|,{p}%,{a:.2f}%,{d:.2f}%,', end='')
      print()
    ave_error = abs_error / (len(kModels) * len(kModes))
    print(f'Ave,,{ave_error:.2f}%')
    print(f'Max,,{max_error:.2f}%')
    print(',')

  @staticmethod
  def seqsel_tab5(logger, args):
    kModels = ['megatron-22B', 'gpt3-175B', 'turing-530B', 'megatron-1T']
    kModes = ['full', 'seqsel']
    # These profiled values are reported here:
    # https://arxiv.org/pdf/2205.05198.pdf
    # Table 5
    kProfile = {
      'megatron-22B': {
        'full': 1.42,
        'seqsel': 1.10
      },
      'gpt3-175B': {
        'full': 18.13,
        'seqsel': 13.75
      },
      'turing-530B': {
        'full': 49.05,
        'seqsel': 37.83
      },
      'megatron-1T': {
        'full': 94.42,
        'seqsel': 71.49
      }
    }

    def get_files(model, mode):
      assert model in kModels
      assert mode in kModes
      app = os.path.join(args.base_dir, 'models', f'{model}.json')
      exe = os.path.join(args.base_dir, 'validation', 'seqsel', 'tab5',
                         f'{model}_{mode}.json')
      return app, exe

    def get_profile(model, mode):
      assert model in kModels
      assert mode in kModes
      return kProfile[model][mode]

    syst_file = os.path.join(args.base_dir, 'systems', 'a100_80g.json')
    syst = System(calculon.io.read_json_file(syst_file))
    data = {}
    for model in kModels:
      data[model] = {}
      for mode in kModes:
        if args.verbose:
          print(f'Analyzing {model} {mode}')
        data[model][mode] = {}
        app_file, exe_file = get_files(model, mode)
        app = Llm.Application(calculon.read_json_file(app_file))
        exe = Llm.Execution.from_json(calculon.read_json_file(exe_file))
        mt = Llm(app, logger)
        mt.compile(syst, exe)
        mt.run(syst)
        stats = mt.get_stats_json(False)
        data[model][mode]['profile_time'] = get_profile(model, mode)
        data[model][mode]['actual_time'] = stats["total_time"]
        data[model][mode]['memory_req'] = stats["proc_mem_tier1_cap_req"]

    print('End-to-end,|,full,,,,|,seqsel,,,,')
    print('Model,|,Profile,Calc,Delta,GiB,|,Profile,Calc,Delta,GiB,')
    max_error = 0
    abs_error = 0
    for model in kModels:
      print(f'{model},', end='')
      for mode in kModes:
        p = data[model][mode]['profile_time']
        a = data[model][mode]['actual_time']
        d = 100*(1-a/p)
        if math.fabs(d) > max_error:
          max_error = math.fabs(d)
        abs_error += math.fabs(d)
        m = data[model][mode]['memory_req'] / (1024**3)
        print(f'|,{p},{a:.2f},{d:.2f}%,{m:.2f},', end='')
      print()
    ave_error = abs_error / (len(kModels) * len(kModes))
    print(f'Ave,,{ave_error:.2f}%')
    print(f'Max,,{max_error:.2f}%')
    print(',')

calculon.CommandLine.register(Validation)
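For Figure 1, `seqsel_fig1` splits Calculon's memory estimate into two buckets before comparing against the paper. The first bucket is parameters plus optimizer state (`weight_space + weight_grad_space + optimizer_space`). The second is activations (`act_space`). Each bucket is converted from bytes to GiB by dividing by `1024**3`. Below is a sketch of that grouping over a `get_stats_json()` dictionary. The function name is hypothetical, and the inputs in the usage note are illustrative values rather than Calculon output.

```python
GIB = 1024 ** 3  # bytes per GiB (binary units, as used in Validation)


def fig1_buckets(stats):
  """Group a stats dict into the two Figure 1 buckets, in GiB."""
  par_opt = (stats['weight_space'] + stats['weight_grad_space'] +
             stats['optimizer_space']) / GIB
  act = stats['act_space'] / GIB
  return {'par_opt': par_opt, 'act': act}
```

With one GiB each of weights and weight gradients, two GiB of optimizer state, and half a GiB of activations, this returns `{'par_opt': 4.0, 'act': 0.5}`.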

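`seqsel_fig7` compares relative rather than absolute activation footprints. For each model, it sums `act_space` and `act_checkpoint_size` and then expresses every mode as a percentage of that model's `'none'` mode, which by construction is 100. This matches how Figure 7 of the referenced paper reports its numbers. Below is a minimal sketch of the normalization. The function name is hypothetical, and it builds each inner dict directly instead of assigning an empty dict first and then overwriting it.

```python
def relative_to_baseline(raw, baseline='none'):
  """Express each mode's footprint as a percent of the baseline mode.

  raw maps model -> mode -> bytes (act_space + act_checkpoint_size in
  Validation.seqsel_fig7).
  """
  rel = {}
  for model, modes in raw.items():
    base = modes[baseline]
    rel[model] = {mode: value / base * 100 for mode, value in modes.items()}
  return rel
```

For example, `relative_to_baseline({'m': {'none': 200, 'full': 50}})` gives `{'m': {'none': 100.0, 'full': 25.0}}`.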

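All three validation tables score Calculon the same way. For each (model, mode) cell, the delta is `100*(1 - calc/profile)` percent. The table footer then reports the average and the maximum of `|delta|` over all cells. Below is a hypothetical refactoring of that repeated loop into one function. `delta_stats` is not part of calculon. It keeps the same formula and sign convention: a positive delta means the calculated value is below the profiled one.

```python
import math


def delta_stats(pairs):
  """Return (deltas, average |delta|, max |delta|) for (profile, calc) pairs.

  Same convention as Validation: delta = 100 * (1 - calc / profile).
  pairs must be non-empty and every profile value must be non-zero.
  """
  deltas = [100 * (1 - calc / profile) for profile, calc in pairs]
  abs_deltas = [math.fabs(d) for d in deltas]
  ave_error = sum(abs_deltas) / len(abs_deltas)
  max_error = max(abs_deltas)
  return deltas, ave_error, max_error
```

For example, `delta_stats([(8.0, 6.0), (4.0, 5.0)])` returns `([25.0, -25.0], 25.0, 25.0)`. An undershoot and an overshoot of the same size contribute equally to the average because the absolute value is taken.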
================================================
FILE: calculon/memory.py
================================================
"""
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 
SYMBOL INDEX (255 symbols across 21 files)

FILE: calculon/command_line.py
  class CommandLine (line 20) | class CommandLine:
    method create_parser (line 26) | def create_parser(subparser):
    method run_command (line 34) | def run_command(logger, args):
    method register (line 45) | def register(cls):
    method command_lines (line 64) | def command_lines():
    method all_names (line 68) | def all_names():

FILE: calculon/io.py
  class NpEncoder (line 22) | class NpEncoder(json.JSONEncoder):
    method default (line 23) | def default(self, obj):
  function is_json_extension (line 34) | def is_json_extension(filename):
  function write_json_file (line 38) | def write_json_file(jdata, filename):
  function read_json_file (line 46) | def read_json_file(filename):

FILE: calculon/llm/all_executions.py
  class AllExecutions (line 34) | class AllExecutions(calculon.CommandLine):
    method create_parser (line 39) | def create_parser(subparser):
    method execution_fields (line 67) | def execution_fields():
    method get_batch_size (line 77) | def get_batch_size(data_par, max_batch_size):
    method all_executions (line 88) | def all_executions(app, syst, num_procs, max_batch_size, datatype, fus...
    method run_command (line 134) | def run_command(logger, args):
    method search (line 190) | def search(app, syst, executions):
    method update_list (line 205) | def update_list(current, candidate, quantity):

FILE: calculon/llm/layers.py
  class Layer (line 21) | class Layer:
    method __init__ (line 28) | def __init__(self, name, sys, fw_flops=0, agrad_flops=0, wgrad_flops=0,
    method get_stats_json (line 62) | def get_stats_json(self):
    method get_stats_str (line 120) | def get_stats_str(self):
    method set_bytes_per_element (line 149) | def set_bytes_per_element(self, bytes_per_element):
    method shard_optimizer (line 153) | def shard_optimizer(self, num_procs):
    method get_fw_flops (line 157) | def get_fw_flops(self):
    method get_fw_mem_accessed (line 160) | def get_fw_mem_accessed(self):
    method get_fw_arithmetic_intensity (line 165) | def get_fw_arithmetic_intensity(self):
    method get_recompute_flag (line 172) | def get_recompute_flag(self):
    method get_recomm_flag (line 175) | def get_recomm_flag(self):
    method reuses_activation (line 178) | def reuses_activation(self):
    method stores_activation (line 181) | def stores_activation(self):
    method stores_output (line 184) | def stores_output(self):
    method get_agrad_flops (line 187) | def get_agrad_flops(self):
    method get_agrad_mem_accessed (line 190) | def get_agrad_mem_accessed(self):
    method get_agrad_arithmetic_intensity (line 198) | def get_agrad_arithmetic_intensity(self):
    method get_wgrad_flops (line 205) | def get_wgrad_flops(self):
    method get_wgrad_mem_accessed (line 208) | def get_wgrad_mem_accessed(self):
    method get_wgrad_arithmetic_intensity (line 220) | def get_wgrad_arithmetic_intensity(self):
    method get_optim_step_flops (line 230) | def get_optim_step_flops(self):
    method get_optim_step_mem_accessed (line 234) | def get_optim_step_mem_accessed(self):
    method get_optim_step_arithmetic_intensity (line 237) | def get_optim_step_arithmetic_intensity(self):
    method get_weight (line 244) | def get_weight(self):
    method get_activation (line 247) | def get_activation(self):
    method get_output (line 250) | def get_output(self):
    method get_weight_grad (line 253) | def get_weight_grad(self, sharded=True):
    method get_activation_grad (line 265) | def get_activation_grad(self):
    method get_optimizer (line 268) | def get_optimizer(self):
    method set_processing_time (line 278) | def set_processing_time(self, processing_time):
    method get_processing_time (line 281) | def get_processing_time(self):
    method use_matrix_engine (line 284) | def use_matrix_engine(self):
    method get_comm_bytes (line 287) | def get_comm_bytes(self, stage, baseblock=True):
    method get_comm_tile (line 290) | def get_comm_tile(self, stage, baseblock=True):
    method compute_flops_time (line 293) | def compute_flops_time(self, stage):
    method compute_mem_time (line 310) | def compute_mem_time(self, stage):
    method compute_net_time (line 323) | def compute_net_time(self, stage, baseblock=True):
    method get_exposed_net_time (line 326) | def get_exposed_net_time(self, stage, baseblock=True):
    method get_required_bandwidth (line 329) | def get_required_bandwidth(self, stage, baseblock=True):
    method compute_processing_time (line 332) | def compute_processing_time(self, stage):
  class Linear (line 341) | class Linear(Layer):
    method __init__ (line 342) | def __init__(self, name, sys, batch_seq, c_in, c_out,
    method use_matrix_engine (line 363) | def use_matrix_engine(self):
  class LinearOverlapped (line 366) | class LinearOverlapped(Layer):
    method __init__ (line 367) | def __init__(self, name, sys, batch_seq, c_in, c_out, tensor_par_comm_...
    method use_matrix_engine (line 438) | def use_matrix_engine(self):
    method get_comm_bytes (line 441) | def get_comm_bytes(self, stage, baseblock=True):
    method get_comm_flops (line 479) | def get_comm_flops(self, stage, baseblock=True):
    method get_num_tiles (line 482) | def get_num_tiles(self):
    method get_comm_tile (line 485) | def get_comm_tile(self, stage, baseblock=True):
    method compute_net_time (line 488) | def compute_net_time(self, stage, baseblock=True):
    method compute_processing_time (line 549) | def compute_processing_time(self, stage):
    method get_exposed_net_time (line 594) | def get_exposed_net_time(self, stage, baseblock=True):
    method get_required_bandwidth (line 599) | def get_required_bandwidth(self, stage, baseblock=True):
  class BatchMatMul (line 607) | class BatchMatMul(Layer):
    method __init__ (line 608) | def __init__(self, name, sys, batch, size_a, contraction_size, size_b,
    method use_matrix_engine (line 625) | def use_matrix_engine(self):
  class LayerNorm (line 630) | class LayerNorm(Layer):
    method __init__ (line 631) | def __init__(self, name, sys, act_size, hidden,
  class DropOut (line 652) | class DropOut(Layer):
    method __init__ (line 653) | def __init__(self, name, sys, act_size,
    method get_activation (line 672) | def get_activation(self):
    method get_activation_grad (line 675) | def get_activation_grad(self):
    method get_fw_mem_accessed (line 678) | def get_fw_mem_accessed(self):
    method get_agrad_mem_accessed (line 685) | def get_agrad_mem_accessed(self):
  class GeLU (line 690) | class GeLU(Layer):
    method __init__ (line 691) | def __init__(self, name, sys, act_size,
    method get_agrad_mem_accessed (line 713) | def get_agrad_mem_accessed(self):
  class SoftMax (line 718) | class SoftMax(Layer):
    method __init__ (line 719) | def __init__(self, name, sys, act_size,
    method get_agrad_mem_accessed (line 735) | def get_agrad_mem_accessed(self):
  class ElementWise (line 740) | class ElementWise(Layer):
    method __init__ (line 741) | def __init__(self, name, sys, operand1, operand2,
  class Fork (line 760) | class Fork(Layer):
    method __init__ (line 761) | def __init__(self, name, sys, act_size, num_users,
    method get_fw_mem_accessed (line 779) | def get_fw_mem_accessed(self):
    method get_agrad_mem_accessed (line 782) | def get_agrad_mem_accessed(self):
  class TPComm (line 787) | class TPComm(Layer):
    method __init__ (line 789) | def __init__(self, name, sys, act_size, net_id, num_peers, tensor_par_...
    method get_activation (line 835) | def get_activation(self):
    method get_fw_mem_accessed (line 845) | def get_fw_mem_accessed(self):
    method get_activation_grad (line 852) | def get_activation_grad(self):
    method get_agrad_mem_accessed (line 862) | def get_agrad_mem_accessed(self):
    method get_comm_bytes (line 869) | def get_comm_bytes(self, stage, baseblock=True):
    method compute_net_time (line 890) | def compute_net_time(self, stage, baseblock=True):
    method get_exposed_net_time (line 935) | def get_exposed_net_time(self, stage, baseblock=True):
    method compute_processing_time (line 939) | def compute_processing_time(self, stage):

FILE: calculon/llm/llm.py
  class Llm (line 22) | class Llm:
    class Application (line 31) | class Application:
      method __init__ (line 33) | def __init__(self, cfg):
      method num_parameters (line 42) | def num_parameters(self):
    class Execution (line 54) | class Execution:
      method fields (line 58) | def fields():
      method from_json (line 68) | def from_json(cfg):
      method __init__ (line 73) | def __init__(self, num_procs, tensor_par, pipeline_par, data_par,
      method get_json (line 147) | def get_json(self):
      method get_peers_json (line 160) | def get_peers_json(self):
    class Error (line 202) | class Error(Exception):
    method _factors (line 206) | def _factors(x):
    method get_all_tensor_parallelisms (line 212) | def get_all_tensor_parallelisms(num_procs, hidden, attn_heads):
    method get_all_pipeline_parallelisms (line 218) | def get_all_pipeline_parallelisms(num_procs, tensor_par, num_blocks):
    method get_data_parallelism (line 227) | def get_data_parallelism(num_procs, tensor_par, pipeline_par):
    method get_valid_pipeline_interleavings (line 233) | def get_valid_pipeline_interleavings(num_blocks, pipeline_par):
    method get_valid_microbatch_sizes (line 242) | def get_valid_microbatch_sizes(
    method can_redo_ag (line 252) | def can_redo_ag(tensor_par_comm_type, activation_recompute):
    method __init__ (line 255) | def __init__(self, app, log):
    method get_stats_fields (line 417) | def get_stats_fields():
    method get_stats_values (line 521) | def get_stats_values(self):
    method get_stats_json (line 626) | def get_stats_json(self, include_layers):
    method _build_attn_block (line 638) | def _build_attn_block(self):
    method _build_mlp_block (line 901) | def _build_mlp_block(self):
    method compile (line 1027) | def compile(self, sys, exe):
    method _check_network_assignments (line 1095) | def _check_network_assignments(self):
    method _compute_block_stats (line 1127) | def _compute_block_stats(self):
    method _compute_batch_stats (line 1448) | def _compute_batch_stats(self):
    method _check_mem_caps (line 1930) | def _check_mem_caps(self):
    method _misc_sanity_checks (line 1942) | def _misc_sanity_checks(self):
    method run (line 2011) | def run(self, sys):
    method _get_fw_offload_size (line 2021) | def _get_fw_offload_size(self):
    method _get_bw_offload_size (line 2035) | def _get_bw_offload_size(self):
    method get_fw_time (line 2049) | def get_fw_time(self):
    method get_fw_offload_time (line 2052) | def get_fw_offload_time(self):
    method get_fw_offload_overhead (line 2055) | def get_fw_offload_overhead(self):
    method get_bw_time (line 2061) | def get_bw_time(self):
    method get_optim_step_time (line 2064) | def get_optim_step_time(self):
    method get_bw_offload_time (line 2067) | def get_bw_offload_time(self):
    method get_bw_offload_overhead (line 2073) | def get_bw_offload_overhead(self):
    method get_recompute_time (line 2082) | def get_recompute_time(self):
    method get_recomm_exposed_time (line 2085) | def get_recomm_exposed_time(self):
    method get_recomm_link_time (line 2091) | def get_recomm_link_time(self):
    method get_bubble_time (line 2097) | def get_bubble_time(self):
    method get_tp_comm_exposed_time (line 2100) | def get_tp_comm_exposed_time(self):
    method get_pp_comm_exposed_time (line 2103) | def get_pp_comm_exposed_time(self):
    method get_dp_comm_exposed_time (line 2106) | def get_dp_comm_exposed_time(self):
    method get_tp_comm_link_time (line 2112) | def get_tp_comm_link_time(self):
    method get_pp_comm_link_time (line 2115) | def get_pp_comm_link_time(self):
    method get_dp_comm_link_time (line 2118) | def get_dp_comm_link_time(self):
    method get_dp_comm_net_time (line 2124) | def get_dp_comm_net_time(self):
    method get_total_time (line 2130) | def get_total_time(self):
    method get_useful_flops (line 2144) | def get_useful_flops(self):
    method get_compute_efficiency (line 2153) | def get_compute_efficiency(self):
    method get_system_efficiency (line 2161) | def get_system_efficiency(self):
    method get_total_efficiency (line 2166) | def get_total_efficiency(self):
    method get_weight_space_min (line 2172) | def get_weight_space_min(self):
    method get_weight_space (line 2175) | def get_weight_space(self):
    method get_act_space_min (line 2178) | def get_act_space_min(self):
    method get_act_space (line 2184) | def get_act_space(self):
    method get_act_checkpoint_size_min (line 2187) | def get_act_checkpoint_size_min(self):
    method get_act_checkpoint_size (line 2194) | def get_act_checkpoint_size(self):
    method get_weight_grad_space_min (line 2203) | def get_weight_grad_space_min(self):
    method get_weight_grad_space (line 2212) | def get_weight_grad_space(self):
    method get_act_grad_space_min (line 2218) | def get_act_grad_space_min(self):
    method get_act_grad_space (line 2221) | def get_act_grad_space(self):
    method get_optimizer_space_min (line 2229) | def get_optimizer_space_min(self):
    method get_optimizer_space (line 2235) | def get_optimizer_space(self):
    method _get_mem_cap_reqs (line 2241) | def _get_mem_cap_reqs(self):
    method get_mem_tier1_cap_req (line 2273) | def get_mem_tier1_cap_req(self):
    method get_mem_tier2_cap_req (line 2276) | def get_mem_tier2_cap_req(self):
    method get_act_offload_bw_req (line 2279) | def get_act_offload_bw_req(self):
    method get_weight_offload_bw_req (line 2292) | def get_weight_offload_bw_req(self):
    method get_optim_offload_bw_req (line 2301) | def get_optim_offload_bw_req(self):
    method get_offload_mem_bw_req (line 2316) | def get_offload_mem_bw_req(self):
    method get_sample_rate (line 2332) | def get_sample_rate(self):
    method display_stats (line 2335) | def display_stats(self):

FILE: calculon/llm/optimal_execution.py
  class OptimalExecution (line 30) | class OptimalExecution(calculon.CommandLine):
    method create_parser (line 35) | def create_parser(subparser):
    method run_command (line 73) | def run_command(logger, args):
    method get_batch_size (line 165) | def get_batch_size(data_par, max_batch_size):
    method search (line 176) | def search(debug, top_n, layers, num_procs, max_batch_size, datatype,
    method update_list (line 260) | def update_list(current, candidate, quantity):

FILE: calculon/llm/parameter_calculator.py
  class ParameterCalculator (line 23) | class ParameterCalculator(calculon.CommandLine):
    method create_parser (line 28) | def create_parser(subparser):
    method run_command (line 39) | def run_command(logger, args):

FILE: calculon/llm/runner.py
  class Runner (line 21) | class Runner(calculon.CommandLine):
    method create_parser (line 26) | def create_parser(subparser):
    method run_command (line 44) | def run_command(logger, args):

FILE: calculon/llm/validation.py
  class Validation (line 27) | class Validation(calculon.CommandLine):
    method create_parser (line 32) | def create_parser(subparser):
    method run_command (line 43) | def run_command(logger, args):
    method seqsel_fig1 (line 56) | def seqsel_fig1(logger, args):
    method seqsel_fig7 (line 184) | def seqsel_fig7(logger, args):
    method seqsel_tab5 (line 281) | def seqsel_tab5(logger, args):

FILE: calculon/memory.py
  class Memory (line 18) | class Memory:
    method __init__ (line 21) | def __init__(self, cfg):
    method capacity (line 31) | def capacity(self):
    method bandwidth (line 35) | def bandwidth(self):
    method efficiency (line 38) | def efficiency(self, op_bytes):
    method throughput (line 44) | def throughput(self, op_bytes):

FILE: calculon/network.py
  class Network (line 19) | class Network:
    class Op (line 27) | class Op:
      method __init__ (line 28) | def __init__(self, scalar, offset):
    method _parse_op (line 33) | def _parse_op(op, scalar, offset):
    method __init__ (line 43) | def __init__(self, cfg):
    method size (line 62) | def size(self):
    method must_be_filled (line 66) | def must_be_filled(self):
    method processor_usage (line 70) | def processor_usage(self):
    method time (line 73) | def time(self, op, op_size, comm_size):

FILE: calculon/processor.py
  class Processor (line 18) | class Processor:
    method __init__ (line 21) | def __init__(self, cfg):
    method flops (line 37) | def flops(self, datatype):
    method efficiency (line 40) | def efficiency(self, datatype, op_flops):
    method throughput (line 46) | def throughput(self, datatype, op_flops):

FILE: calculon/system.py
  class System (line 22) | class System:
    method supported_datatypes (line 33) | def supported_datatypes():
    method __init__ (line 36) | def __init__(self, cfg):
    method num_networks (line 51) | def num_networks(self):
    method get_network (line 54) | def get_network(self, tier):
    method set_datatype (line 58) | def set_datatype(self, datatype):
    method get_matrix_throughput (line 62) | def get_matrix_throughput(self, flops):
    method get_vector_throughput (line 65) | def get_vector_throughput(self, flops):
    method get_mem1_throughput (line 68) | def get_mem1_throughput(self, size):
    method get_mem2_throughput (line 71) | def get_mem2_throughput(self, size):
    method compute_offload_time (line 74) | def compute_offload_time(self, size):
    method get_processing_time (line 77) | def get_processing_time(self, flops_time, mem_time):

FILE: calculon/util.py
  function human_format (line 21) | def human_format(value, v_type='base10', precision=3):
  function pick (line 66) | def pick(en, a, b):
  function arg_true_false_all (line 72) | def arg_true_false_all(arg):

FILE: calculon/version.py
  class Version (line 20) | class Version(calculon.CommandLine):
    method create_parser (line 25) | def create_parser(subparser):
    method run_command (line 31) | def run_command(logger, args):

FILE: scripts/3dplot.py
  function main (line 12) | def main(args):

FILE: scripts/find_huge.py
  function transformer_attn_size (line 10) | def transformer_attn_size(hidden, layers, attn_size_step=32):
  function transformer_num_parameters (line 13) | def transformer_num_parameters(hidden, layers, attn_size_step=32):
  function transformer_t_params (line 21) | def transformer_t_params(hidden, layers):
  function step_rounder (line 24) | def step_rounder(layer, step=1):
  function model_ratio (line 27) | def model_ratio(hidden, layers):
  function human_format (line 30) | def human_format(value, v_type='base10', precision=3):
  function ratio_layer_scale (line 75) | def ratio_layer_scale(hidden, ratio=128, step=4):
  function ratio_hidden_scale (line 77) | def ratio_hidden_scale(layers, ratio=128, step=4096):
  function ratio_param_layer_scale (line 79) | def ratio_param_layer_scale(layers, ratio=128, step=4096):
  function ratio_param_hidden_scale (line 82) | def ratio_param_hidden_scale(hidden, ratio=128, step=4):

FILE: scripts/heatmap.py
  function main (line 13) | def main(args):

FILE: scripts/json_to_csv.py
  function main (line 10) | def main(args):

FILE: setup.py
  function find_version (line 30) | def find_version(*file_paths):

FILE: test/test_json_write_read.py
  class JsonWriteReadTestCase (line 23) | class JsonWriteReadTestCase(unittest.TestCase):
    method test_json_read_write (line 24) | def test_json_read_write(self):
Condensed preview — 87 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 337,
    "preview": ".DS_Store\n*.py[cod]\n*.log\n\n# C extensions\n*.so\n\n# Packages\n*.egg\n*.egg-info\ndist\nbuild\neggs\nparts\nvar\nsdist\ndevelop-eggs"
  },
  {
    "path": "LICENSE",
    "chars": 11362,
    "preview": "\n                                 Apache License\n                           Version 2.0, January 2004\n                  "
  },
  {
    "path": "Makefile",
    "chars": 647,
    "preview": ".SUFFIXES:\n.PHONY: help install clean lint test count\n\nhelp:\n\t@echo \"options are: install clean lint test count\"\n\ninstal"
  },
  {
    "path": "NOTICE",
    "chars": 122,
    "preview": "Calculon - Co-design for large scale parallel applications\nCopyright 2022 Michael Isaev, Nic McDonald\nAll rights reserve"
  },
  {
    "path": "README.md",
    "chars": 1988,
    "preview": "[![DOI](https://zenodo.org/badge/660734586.svg)](https://zenodo.org/badge/latestdoi/660734586)\n# Calculon - Co-design fo"
  },
  {
    "path": "bin/calculon",
    "chars": 1960,
    "preview": "#!/usr/bin/env python3\n\n\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this f"
  },
  {
    "path": "calculon/__init__.py",
    "chars": 897,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/command_line.py",
    "chars": 2150,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/io.py",
    "chars": 1685,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/llm/__init__.py",
    "chars": 940,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/llm/all_executions.py",
    "chars": 9038,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/llm/layers.py",
    "chars": 36143,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/llm/llm.py",
    "chars": 102961,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/llm/optimal_execution.py",
    "chars": 11717,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/llm/parameter_calculator.py",
    "chars": 1885,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/llm/runner.py",
    "chars": 2601,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/llm/validation.py",
    "chars": 11117,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/memory.py",
    "chars": 1399,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/network.py",
    "chars": 3222,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/processor.py",
    "chars": 1715,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/system.py",
    "chars": 2368,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/util.py",
    "chars": 2160,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "calculon/version.py",
    "chars": 1166,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "examples/3072_t4_p64_d12_mbs4_full.json",
    "chars": 611,
    "preview": "{\n  \"num_procs\": 3072,\n  \"tensor_par\": 4,\n  \"pipeline_par\": 64,\n  \"data_par\": 12,\n  \"tensor_par_net\": 0,\n  \"pipeline_par"
  },
  {
    "path": "models/anthropic-52B.json",
    "chars": 125,
    "preview": "{\n  \"hidden\": 8192,\n  \"feedforward\": 32768,\n  \"seq_size\": 8192,\n  \"attn_heads\": 64,\n  \"attn_size\": 128,\n  \"num_blocks\": "
  },
  {
    "path": "models/chinchilla.json",
    "chars": 125,
    "preview": "{\n  \"hidden\": 8192,\n  \"feedforward\": 32768,\n  \"seq_size\": 2048,\n  \"attn_heads\": 64,\n  \"attn_size\": 128,\n  \"num_blocks\": "
  },
  {
    "path": "models/gopher-280B.json",
    "chars": 127,
    "preview": "{\n  \"hidden\": 16384,\n  \"feedforward\": 65536,\n  \"seq_size\": 2048,\n  \"attn_heads\": 128,\n  \"attn_size\": 128,\n  \"num_blocks\""
  },
  {
    "path": "models/gpt3-13B.json",
    "chars": 125,
    "preview": "{\n  \"hidden\": 5140,\n  \"feedforward\": 20560,\n  \"seq_size\": 2048,\n  \"attn_heads\": 40,\n  \"attn_size\": 128,\n  \"num_blocks\": "
  },
  {
    "path": "models/gpt3-175B.json",
    "chars": 126,
    "preview": "{\n  \"hidden\": 12288,\n  \"feedforward\": 49152,\n  \"seq_size\": 2048,\n  \"attn_heads\": 96,\n  \"attn_size\": 128,\n  \"num_blocks\":"
  },
  {
    "path": "models/lamda.json",
    "chars": 126,
    "preview": "{\n  \"hidden\": 8192,\n  \"feedforward\": 65536,\n  \"seq_size\": 2048,\n  \"attn_heads\": 128,\n  \"attn_size\": 128,\n  \"num_blocks\":"
  },
  {
    "path": "models/megatron-126M.json",
    "chars": 122,
    "preview": "{\n  \"hidden\": 768,\n  \"feedforward\": 3072,\n  \"seq_size\": 2048,\n  \"attn_heads\": 16,\n  \"attn_size\": 48,\n  \"num_blocks\": 12\n"
  },
  {
    "path": "models/megatron-1T.json",
    "chars": 129,
    "preview": "{\n  \"hidden\": 25600,\n  \"feedforward\": 102400,\n  \"seq_size\": 2048,\n  \"attn_heads\": 160,\n  \"attn_size\": 160,\n  \"num_blocks"
  },
  {
    "path": "models/megatron-22B.json",
    "chars": 124,
    "preview": "{\n  \"hidden\": 6144,\n  \"feedforward\": 24576,\n  \"seq_size\": 2048,\n  \"attn_heads\": 64,\n  \"attn_size\": 96,\n  \"num_blocks\": 4"
  },
  {
    "path": "models/megatron-40B.json",
    "chars": 125,
    "preview": "{\n  \"hidden\": 8192,\n  \"feedforward\": 32768,\n  \"seq_size\": 2048,\n  \"attn_heads\": 64,\n  \"attn_size\": 128,\n  \"num_blocks\": "
  },
  {
    "path": "models/megatron-5B.json",
    "chars": 125,
    "preview": "{\n  \"hidden\": 4096,\n  \"feedforward\": 16384,\n  \"seq_size\": 2048,\n  \"attn_heads\": 32,\n  \"attn_size\": 128,\n  \"num_blocks\": "
  },
  {
    "path": "models/palm-540B.json",
    "chars": 127,
    "preview": "{\n  \"hidden\": 18432,\n  \"feedforward\": 73728,\n  \"seq_size\": 2048,\n  \"attn_heads\": 48,\n  \"attn_size\": 256,\n  \"num_blocks\":"
  },
  {
    "path": "models/turing-530B.json",
    "chars": 128,
    "preview": "{\n  \"hidden\": 20480,\n  \"feedforward\": 81920,\n  \"seq_size\": 2048,\n  \"attn_heads\": 128,\n  \"attn_size\": 160,\n  \"num_blocks\""
  },
  {
    "path": "pylintrc",
    "chars": 287,
    "preview": "[MESSAGES CONTROL]\ndisable=locally-disabled,\n\ttoo-many-branches,\n\ttoo-many-instance-attributes,\n\ttoo-many-return-stateme"
  },
  {
    "path": "pyproject.toml",
    "chars": 104,
    "preview": "[build-system]\nrequires = [\n    \"setuptools>=42\",\n    \"wheel\"\n]\nbuild-backend = \"setuptools.build_meta\"\n"
  },
  {
    "path": "scripts/3dplot.py",
    "chars": 2052,
    "preview": "#!/usr/bin/env python3\n\nimport argparse\nimport calculon\nimport matplotlib\nmatplotlib.use('TkAgg')\nimport matplotlib.pypl"
  },
  {
    "path": "scripts/find_huge.py",
    "chars": 5392,
    "preview": "#!/usr/bin/env python3\n\nimport numpy as np\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\nimpor"
  },
  {
    "path": "scripts/heatmap.py",
    "chars": 2409,
    "preview": "#!/usr/bin/env python3\n\nimport argparse\nimport calculon\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot"
  },
  {
    "path": "scripts/install_hooks.sh",
    "chars": 215,
    "preview": "#!/bin/bash\n\nset -e\n\n# Pre-commit hook\ncat > .git/hooks/pre-commit <<-EOF\n#!/bin/bash\necho -n \"Testing...\"\nif ! make tes"
  },
  {
    "path": "scripts/json_to_csv.py",
    "chars": 994,
    "preview": "#!/usr/bin/env python3\n\nimport argparse\nimport calculon\nimport gzip\nimport json\nimport sys\n\n\ndef main(args):\n  j = calcu"
  },
  {
    "path": "setup.py",
    "chars": 1615,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "systems/a100_80e.json",
    "chars": 1324,
    "preview": "{\n  \"matrix\" : {\n    \"float16\": {\n      \"tflops\": 312,\n      \"gflops_efficiency\": [\n        [128, 0.99],\n        [16, 0."
  },
  {
    "path": "systems/a100_80g.json",
    "chars": 1315,
    "preview": "{\n  \"matrix\" : {\n    \"float16\": {\n      \"tflops\": 312,\n      \"gflops_efficiency\": [\n        [128, 0.95],\n        [16, 0."
  },
  {
    "path": "systems/h100_80g_nvl8.json",
    "chars": 1607,
    "preview": "{\n  \"matrix\": {\n    \"float8\": {\n      \"tflops\": 2000,\n      \"gflops_efficiency\": [\n        [128, 0.95],\n        [16, 0.9"
  },
  {
    "path": "test/__init__.py",
    "chars": 675,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "test/test.sh",
    "chars": 1813,
    "preview": "#!/bin/bash\n\nset -e\n\nexport PYTHONPATH=.\n\n# CLI interface infrastructure\necho -e \"### Testing top level --help\"\n./bin/ca"
  },
  {
    "path": "test/test_json_write_read.py",
    "chars": 2217,
    "preview": "\"\"\"\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance"
  },
  {
    "path": "validation/seqsel/fig1/gpt3-175B_none.json",
    "chars": 604,
    "preview": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net"
  },
  {
    "path": "validation/seqsel/fig1/gpt3-175B_seqsel.json",
    "chars": 611,
    "preview": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net"
  },
  {
    "path": "validation/seqsel/fig1/megatron-1T_none.json",
    "chars": 607,
    "preview": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig1/megatron-1T_seqsel.json",
    "chars": 615,
    "preview": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig1/megatron-22B_none.json",
    "chars": 602,
    "preview": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\""
  },
  {
    "path": "validation/seqsel/fig1/megatron-22B_seqsel.json",
    "chars": 609,
    "preview": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\""
  },
  {
    "path": "validation/seqsel/fig1/turing-530B_none.json",
    "chars": 607,
    "preview": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig1/turing-530B_seqsel.json",
    "chars": 614,
    "preview": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig7/gpt3-175B_full.json",
    "chars": 604,
    "preview": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net"
  },
  {
    "path": "validation/seqsel/fig7/gpt3-175B_none.json",
    "chars": 604,
    "preview": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net"
  },
  {
    "path": "validation/seqsel/fig7/gpt3-175B_sel.json",
    "chars": 609,
    "preview": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net"
  },
  {
    "path": "validation/seqsel/fig7/gpt3-175B_seq.json",
    "chars": 606,
    "preview": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net"
  },
  {
    "path": "validation/seqsel/fig7/gpt3-175B_seqsel.json",
    "chars": 611,
    "preview": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net"
  },
  {
    "path": "validation/seqsel/fig7/megatron-1T_full.json",
    "chars": 607,
    "preview": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-1T_none.json",
    "chars": 607,
    "preview": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-1T_sel.json",
    "chars": 612,
    "preview": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-1T_seq.json",
    "chars": 609,
    "preview": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-1T_seqsel.json",
    "chars": 614,
    "preview": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig7/megatron-22B_full.json",
    "chars": 602,
    "preview": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\""
  },
  {
    "path": "validation/seqsel/fig7/megatron-22B_none.json",
    "chars": 602,
    "preview": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\""
  },
  {
    "path": "validation/seqsel/fig7/megatron-22B_sel.json",
    "chars": 607,
    "preview": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\""
  },
  {
    "path": "validation/seqsel/fig7/megatron-22B_seq.json",
    "chars": 604,
    "preview": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\""
  },
  {
    "path": "validation/seqsel/fig7/megatron-22B_seqsel.json",
    "chars": 609,
    "preview": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\""
  },
  {
    "path": "validation/seqsel/fig7/turing-530B_full.json",
    "chars": 607,
    "preview": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig7/turing-530B_none.json",
    "chars": 607,
    "preview": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig7/turing-530B_sel.json",
    "chars": 612,
    "preview": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig7/turing-530B_seq.json",
    "chars": 609,
    "preview": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/fig7/turing-530B_seqsel.json",
    "chars": 614,
    "preview": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/tab5/gpt3-175B_full.json",
    "chars": 604,
    "preview": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net"
  },
  {
    "path": "validation/seqsel/tab5/gpt3-175B_seqsel.json",
    "chars": 611,
    "preview": "{\n  \"num_procs\": 64,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 8,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net"
  },
  {
    "path": "validation/seqsel/tab5/megatron-1T_full.json",
    "chars": 607,
    "preview": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/tab5/megatron-1T_seqsel.json",
    "chars": 614,
    "preview": "{\n  \"num_procs\": 512,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 64,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/tab5/megatron-22B_full.json",
    "chars": 602,
    "preview": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\""
  },
  {
    "path": "validation/seqsel/tab5/megatron-22B_seqsel.json",
    "chars": 609,
    "preview": "{\n  \"num_procs\": 8,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 1,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_net\""
  },
  {
    "path": "validation/seqsel/tab5/turing-530B_full.json",
    "chars": 607,
    "preview": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  },
  {
    "path": "validation/seqsel/tab5/turing-530B_seqsel.json",
    "chars": 614,
    "preview": "{\n  \"num_procs\": 280,\n  \"tensor_par\": 8,\n  \"pipeline_par\": 35,\n  \"data_par\": 1,\n  \"tensor_par_net\": 0,\n  \"pipeline_par_n"
  }
]
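
Every execution configuration previewed above (`examples/` and `validation/seqsel/`) satisfies num_procs = tensor_par × pipeline_par × data_par. For example, 4 × 64 × 12 = 3072 for `examples/3072_t4_p64_d12_mbs4_full.json`, and 8 × 35 × 1 = 280 for the Turing-530B runs. The sketch below checks that relationship. It treats the relationship as an intended invariant, which is an assumption, because the previews do not show how Calculon itself validates configurations. The helper names `is_consistent` and `check_dir` are hypothetical and are not part of the calculon package. The sample values are copied from the truncated previews.

```python
import json
from pathlib import Path

# Keys seen at the top of every execution config preview above.
PAR_KEYS = ("num_procs", "tensor_par", "pipeline_par", "data_par")

# (path, num_procs, tensor_par, pipeline_par, data_par) copied from the previews.
PREVIEW_CONFIGS = [
    ("examples/3072_t4_p64_d12_mbs4_full.json", 3072, 4, 64, 12),
    ("validation/seqsel/fig1/gpt3-175B_none.json", 64, 8, 8, 1),
    ("validation/seqsel/fig1/megatron-1T_none.json", 512, 8, 64, 1),
    ("validation/seqsel/fig1/megatron-22B_none.json", 8, 8, 1, 1),
    ("validation/seqsel/fig1/turing-530B_none.json", 280, 8, 35, 1),
]


def is_consistent(cfg):
    """Hypothetical check: num_procs must equal the product of the three parallelism degrees."""
    return cfg["num_procs"] == cfg["tensor_par"] * cfg["pipeline_par"] * cfg["data_par"]


def check_dir(root):
    """Return paths of JSON execution configs under root that violate the assumed invariant."""
    bad = []
    for path in sorted(Path(root).rglob("*.json")):
        cfg = json.loads(path.read_text())
        # Model and system JSON files lack these keys, so they are skipped.
        if isinstance(cfg, dict) and all(k in cfg for k in PAR_KEYS):
            if not is_consistent(cfg):
                bad.append(str(path))
    return bad


if __name__ == "__main__":
    for path, *values in PREVIEW_CONFIGS:
        cfg = dict(zip(PAR_KEYS, values))
        print(path, "ok" if is_consistent(cfg) else "MISMATCH")
```

From a checkout of the repository, `check_dir("validation")` or `check_dir("examples")` applies the same check to the complete files rather than to the previews.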

About this extraction

This page contains the full source code of the calculon-ai/calculon GitHub repository, extracted as plain text. The extraction includes 87 files (249.7 KB), approximately 68.3k tokens, and a symbol index with 255 extracted functions, classes, methods, constants, and types.