[
  {
    "path": ".gitignore",
    "content": "checkpoints/*\n!checkpoints/README.md\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "content": "Be excellent to each other.\n"
  },
  {
    "path": "LICENSE.txt",
    "content": "\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# Grok-1\n\nThis repository contains JAX example code for loading and running the Grok-1 open-weights model.\n\nMake sure to download the checkpoint and place the `ckpt-0` directory in `checkpoints`; see [Downloading the weights](#downloading-the-weights).\n\nThen, run\n\n```shell\npip install -r requirements.txt\npython run.py\n```\n\nto test the code.\n\nThe script loads the checkpoint and samples from the model on a test input.\n\nDue to the large size of the model (314B parameters), a machine with enough GPU memory is required to test the model with the example code.\nThe implementation of the MoE layer in this repository is not efficient; it was chosen so that the model's correctness can be validated without writing custom kernels.\n\n# Model Specifications\n\nGrok-1 is currently designed with the following specifications:\n\n- **Parameters:** 314B\n- **Architecture:** Mixture of 8 Experts (MoE)\n- **Expert Utilization:** 2 experts used per token\n- **Layers:** 64\n- **Attention Heads:** 48 for queries, 8 for keys/values\n- **Embedding Size:** 6,144\n- **Tokenization:** SentencePiece tokenizer with 131,072 tokens\n- **Additional Features:**\n  - Rotary embeddings (RoPE)\n  - Supports activation sharding and 8-bit quantization\n- **Maximum Sequence Length (context):** 8,192 tokens\n\n# Downloading the weights\n\nYou can download the weights using a torrent client and this magnet link:\n\n```\nmagnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce\n```\n\nor directly from the [Hugging Face 🤗 Hub](https://huggingface.co/xai-org/grok-1):\n\n```shell\ngit clone https://github.com/xai-org/grok-1.git && cd grok-1\npip install huggingface_hub[hf_transfer]\nhuggingface-cli download xai-org/grok-1 --repo-type model --include ckpt-0/* --local-dir checkpoints --local-dir-use-symlinks 
False\n```\n\n# License\n\nThe code and associated Grok-1 weights in this release are licensed under the\nApache 2.0 license. The license only applies to the source files in this\nrepository and the model weights of Grok-1.\n"
  },
  {
    "path": "checkpoint.py",
    "content": "# Copyright 2024 X.AI Corp.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom __future__ import annotations\n\nimport contextlib\nimport logging\nimport math\nimport os\nimport pickle\nimport re\nimport shutil\nimport sys\nimport tempfile\nfrom concurrent.futures import ThreadPoolExecutor, wait\nfrom typing import Any, Optional\n\nimport jax\nimport numpy as np\nfrom jax.experimental import multihost_utils\n\nfrom model import QuantizedWeight8bit\n\nlogger = logging.getLogger(__name__)\nrank_logger = logging.getLogger(\"rank\")\n\n# Needed for loading the checkpoint with pickle.\nsys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit\n\n\n@contextlib.contextmanager\ndef copy_to_shm(file: str):\n    if file.startswith(\"/dev/shm/\"):\n        # Nothing to do, the file is already in shared memory.\n        yield file\n        return\n\n    tmp_dir = \"/dev/shm/\"\n    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)\n    try:\n        shutil.copyfile(file, tmp_path)\n        yield tmp_path\n    finally:\n        os.remove(tmp_path)\n        os.close(fd)\n\n\n@contextlib.contextmanager\ndef copy_from_shm(file: str):\n    tmp_dir = \"/dev/shm/\"\n    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)\n    try:\n        yield tmp_path\n        shutil.copyfile(tmp_path, file)\n    finally:\n        os.remove(tmp_path)\n        os.close(fd)\n\n\ndef fast_unpickle(path: str) -> Any:\n    with copy_to_shm(path) as tmp_path:\n        with 
open(tmp_path, \"rb\") as f:\n            return pickle.load(f)\n\n\ndef fast_pickle(obj: Any, path: str) -> None:\n    with copy_from_shm(path) as tmp_path:\n        with open(tmp_path, \"wb\") as f:\n            pickle.dump(obj, f)\n\n\ndef load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):\n    \"\"\"Loads a set of arrays.\"\"\"\n    pool = ThreadPoolExecutor(max_workers=32)\n    fs = list()\n    num_tensors = 0\n    num_replicas = 1\n    data_model_shards = math.prod(mesh_config)\n    if tensor_indices is None:\n        iterator = enumerate(shaped_arrays)\n    else:\n        iterator = zip(tensor_indices, shaped_arrays)\n    for i, t in iterator:\n        if (i % num_replicas) == ((jax.process_index() // data_model_shards) % num_replicas):\n            idx = (\n                jax.process_index() // (num_replicas * data_model_shards) * data_model_shards\n                + jax.process_index() % data_model_shards\n            )\n            fs.append(\n                pool.submit(fast_unpickle, os.path.join(directory, f\"tensor{i:05d}_{idx:03d}\"))\n            )\n            num_tensors += 1\n        else:\n            fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))\n    wait(fs)\n    return [f.result() for f in fs]\n\n\ndef path_tuple_to_string(path: tuple) -> str:\n    pieces = []\n    for elem in path:\n        if isinstance(elem, jax.tree_util.DictKey):\n            pieces.append(elem.key)\n        elif isinstance(elem, jax.tree_util.GetAttrKey):\n            pieces.append(elem.name)\n        else:\n            assert isinstance(elem, (jax.tree_util.FlattenedIndexKey, jax.tree_util.SequenceKey))\n    return \"/\".join(pieces)\n\n\ndef get_load_path_str(\n    init_path_str: str,\n    load_rename_rules: Optional[list[tuple[str, str]]] = None,\n    load_exclude_rules: Optional[list[str]] = None,\n) -> Optional[str]:\n    # Exclusion\n    if load_exclude_rules is not None:\n        for search_pattern in load_exclude_rules:\n       
     if re.search(search_pattern, init_path_str):\n                return None\n\n    # Renaming\n    load_path_str = init_path_str\n    if load_rename_rules is not None:\n        for search_pattern, replacement_pattern in load_rename_rules:\n            if re.search(search_pattern, load_path_str):\n                load_path_str = re.sub(search_pattern, replacement_pattern, load_path_str)\n                break\n\n    return load_path_str\n\n\ndef replace_with_load_state(\n    init_state: Any,\n    load_state: Any,\n    load_rename_rules: Optional[list[tuple[str, str]]] = None,\n    load_exclude_rules: Optional[list[str]] = None,\n    mesh_config: tuple = (1, 1),\n) -> Any:\n    flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)\n    flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)\n    load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}\n\n    replaced = []\n    num_replicas = 1\n    data_model_shards = math.prod(mesh_config)\n    for i, (init_path, tensor) in enumerate(flatten_init):\n        init_path_str = path_tuple_to_string(init_path)\n        load_path_str = get_load_path_str(init_path_str, load_rename_rules, load_exclude_rules)\n        if load_path_str is None:\n            rank_logger.info(f\"Excluded from restore: {init_path_str}.\")\n            replaced.append(tensor)\n        elif load_path_str in load_map:\n            if load_path_str == init_path_str:\n                rank_logger.info(f\"Restored from ckpt: {init_path_str}.\")\n            else:\n                rank_logger.info(f\"Restored from ckpt: {init_path_str} <-- {load_path_str}.\")\n            replaced.append(load_map[load_path_str])\n        else:\n            rank_logger.info(f\"Not found in ckpt: {init_path_str}.\")\n            if (i % num_replicas) == ((jax.process_index() // data_model_shards) % num_replicas):\n                replaced.append(tensor)\n            else:\n                
replaced.append(np.zeros_like(tensor))\n\n    return jax.tree_util.tree_unflatten(structure_init, replaced)\n\n\ndef restore(\n    checkpoint_path: str,\n    state_shapes: Any,\n    mesh,\n    between_hosts_config,\n    params_only,\n    state_sharding,\n    init_state: Optional[Any] = None,\n) -> Any:\n    ckpt_path = os.path.join(checkpoint_path, \"ckpt-0\")\n\n    rank_logger.info(\"Loading checkpoint at {}\".format(ckpt_path))\n    ckpt_shapes = state_shapes\n    ckpt_shapes_with_path, structure = jax.tree_util.tree_flatten_with_path(ckpt_shapes)\n\n    ckpt_shapes_flat = [elem[1] for elem in ckpt_shapes_with_path]\n    loaded_tensors = load_tensors(ckpt_shapes_flat, ckpt_path, between_hosts_config)\n\n    state = jax.tree_util.tree_unflatten(structure, loaded_tensors)\n\n    # Sanity check to give a better error message.\n    ckpt_keys = set(state.params.keys())\n    code_keys = set(state_sharding.params.keys())\n\n    if ckpt_keys != code_keys and init_state is None:\n        missing_in_ckpt = code_keys - ckpt_keys\n        missing_locally = ckpt_keys - code_keys\n        raise ValueError(\n            \"Parameters in the code are not matching checkpoint parameters.\\n\"\n            \"Params missing in checkpoint: {}\\nParams missing in code: {}\".format(\n                missing_in_ckpt, missing_locally\n            )\n        )\n    state_sharding = jax.tree_util.tree_map(\n        lambda x: jax.sharding.PartitionSpec() if x is None else x,\n        state_sharding,\n        is_leaf=lambda x: x is None,\n    )\n    state = multihost_utils.host_local_array_to_global_array(state, mesh, state_sharding)\n    if params_only:\n        state = state.params\n    return state\n"
  },
  {
    "path": "checkpoints/README.md",
    "content": "# Checkpoint directory\n\nPlace the Grok-1 checkpoint here as a `ckpt-0` directory (i.e. `checkpoints/ckpt-0`) so it can be loaded by the example script (`run.py`).\n"
  },
  {
    "path": "model.py",
    "content": "# Copyright 2024 X.AI Corp.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport functools\nimport logging\nimport re\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union\n\nimport haiku as hk\nimport jax\nimport jax.experimental.maps\nimport jax.numpy as jnp\nfrom jax import config, tree_util\nfrom jax.experimental.shard_map import shard_map\nfrom jax.lax import with_sharding_constraint as pjit_sharding_constraint\nfrom jax.sharding import PartitionSpec\nfrom jax.sharding import PartitionSpec as P\n\nconfig.update(\"jax_spmd_mode\", \"allow_all\")\n\nlogger = logging.getLogger(__name__)\nrank_logger = logging.getLogger(\"rank\")\n\n\n@dataclass\nclass QuantizedWeight8bit:\n    weight: jnp.array\n    scales: jnp.array\n\n    @property\n    def shape(self):\n        return self.weight.shape\n\n\ntree_util.register_pytree_node(\n    QuantizedWeight8bit,\n    lambda qw: ([qw.weight, qw.scales], ()),\n    lambda _, children: QuantizedWeight8bit(children[0], children[1]),\n)\n\n\nclass TrainingState(NamedTuple):\n    \"\"\"Container for the training state.\"\"\"\n\n    params: hk.Params\n\n\ndef _match(qs, ks):\n    \"\"\"Return True if regexes in qs match any window of strings in tuple ks.\"\"\"\n    # compile regexes and force complete match\n    qts = tuple(map(lambda x: re.compile(x + \"$\"), qs))\n    for i in range(len(ks) - len(qs) + 1):\n        matches = 
[x.match(y) for x, y in zip(qts, ks[i:])]\n        if matches and all(matches):\n            return True\n    return False\n\n\ndef with_sharding_constraint(x, constraint):\n    if jax.experimental.maps.thread_resources.env.physical_mesh.empty:\n        return x\n    else:\n        return pjit_sharding_constraint(x, constraint)\n\n\ndef cast_bfloat16(x):\n    if x.dtype.kind == \"f\":\n        return x.astype(jnp.bfloat16)\n    else:\n        return x\n\n\ndef ffn_size(emb_size, widening_factor):\n    _ffn_size = int(widening_factor * emb_size) * 2 // 3\n    _ffn_size = _ffn_size + (8 - _ffn_size) % 8  # ensure it's a multiple of 8\n    logger.debug(f\"emd_size: {emb_size} adjusted ffn_size: {_ffn_size}\")\n    return _ffn_size\n\n\ndef apply_rules(rules):\n    def _apply_rules(path, value):\n        del value  # Unused.\n\n        path_list = [str(i.key).split(\"/\") for i in path if isinstance(i, jax.tree_util.DictKey)]\n        flattened_path = jax.tree_util.tree_flatten(path_list)[0]\n\n        for rule, replacement in rules:\n            if _match(rule, flattened_path):\n                if isinstance(replacement, PartitionSpec):\n                    if \"layer_stack\" in flattened_path:\n                        replacement = PartitionSpec(None, *replacement)\n                rank_logger.debug(f\"Apply {replacement} to {flattened_path} with rule {rule}\")\n                return replacement\n        rank_logger.info(f\"{flattened_path} no matching found!\")\n        return None\n\n    return _apply_rules\n\n\nTRANSFORMER_PARTITION_RULES = [\n    # attention\n    ((\"multi_head_attention\", \"(query|key|value)\", \"w\"), P(\"data\", \"model\")),\n    ((\"multi_head_attention\", \"(query|key|value)\", \"b\"), P(None)),\n    ((\"multi_head_attention\", \"linear\", \"w\"), P(\"model\", \"data\")),\n    ((\"multi_head_attention\", \"linear\", \"b\"), P(None)),\n    # mlp\n    ((r\"decoder_layer_[0-9]+\", \"linear\", \"w\"), P(\"data\", \"model\")),\n    
((r\"decoder_layer_[0-9]+\", \"linear\", \"b\"), P(None)),\n    ((r\"decoder_layer_[0-9]+\", \"linear_v\", \"w\"), P(\"data\", \"model\")),\n    ((r\"decoder_layer_[0-9]+\", \"linear_v\", \"b\"), P(None)),\n    (\n        (r\"decoder_layer_[0-9]+\", \"linear_1\", \"w\"),\n        P(\n            \"model\",\n            \"data\",\n        ),\n    ),\n    ((r\"decoder_layer_[0-9]+\", \"linear_1\", \"b\"), P(None)),\n    # layer norms\n    ((r\"decoder_layer_[0-9]+\", \"layer_norm\", \"offset\"), P(None)),\n    ((r\"decoder_layer_[0-9]+\", \"layer_norm\", \"scale\"), P(None)),\n    ((r\"decoder_layer_[0-9]+\", \"layer_norm_1\", \"offset\"), P(None)),\n    ((r\"decoder_layer_[0-9]+\", \"layer_norm_1\", \"scale\"), P(None)),\n    # rms norms\n    ((r\"decoder_layer_[0-9]+\", \"rms_norm\", \"scale\"), P(None)),\n    ((r\"decoder_layer_[0-9]+\", \"rms_norm_1\", \"scale\"), P(None)),\n    ((r\"decoder_layer_[0-9]+\", \"rms_norm_2\", \"scale\"), P(None)),\n    ((r\"decoder_layer_[0-9]+\", \"rms_norm_3\", \"scale\"), P(None)),\n    # router\n    ((\"router\", \"w\"), P(\"data\")),\n    # moe mlp\n    ((\"moe\", \"linear\", \"w\"), P(None, \"data\", \"model\")),\n    ((\"moe\", \"linear\", \"b\"), P(None)),\n    ((\"moe\", \"linear_v\", \"w\"), P(None, \"data\", \"model\")),\n    ((\"moe\", \"linear_v\", \"b\"), P(None)),\n    ((\"moe\", \"linear_1\", \"w\"), P(None, \"model\", \"data\")),\n    ((\"moe\", \"linear_1\", \"b\"), P(None)),\n    # layer norms\n    ((\"moe\", \"layer_norm\", \"offset\"), P(None)),\n    ((\"moe\", \"layer_norm\", \"scale\"), P(None)),\n    ((\"moe\", \"layer_norm_1\", \"offset\"), P(None)),\n    ((\"moe\", \"layer_norm_1\", \"scale\"), P(None)),\n    # rms norms\n    ((\"moe\", \"rms_norm\", \"scale\"), P(None)),\n    ((\"moe\", \"rms_norm_1\", \"scale\"), P(None)),\n    ((\"moe\", \"rms_norm_2\", \"scale\"), P(None)),\n    ((\"moe\", \"rms_norm_3\", \"scale\"), P(None)),\n]\n\nLM_PARTITION_RULES = [\n    # Embedding layer.\n    (\n        
(\"language_model\", \"positional_embeddings\"),\n        P(None, (\"data\", \"model\")),\n    ),\n    (\n        (\"language_model\", \"in_out_embed\", \"embeddings\"),\n        P(None, (\"data\", \"model\")),\n    ),\n    # Final RMSNorm.\n    ((\"language_model\", \"rms_norm\"), P(None)),\n]\nTOP_K = 8\n\n\nclass KVMemory(NamedTuple):\n    k: Optional[jax.Array]\n    v: Optional[jax.Array]\n    step: Optional[jax.Array]\n\n\ndef init_layer_memories(\n    batch_size: int,\n    sequence_len: int,\n    num_kv_heads: int,\n    key_size: int,\n    num_layers: int,\n    step: Optional[jax.Array] = None,\n    dtype=jnp.bfloat16,\n):\n    return [\n        KVMemory(\n            k=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype),\n            v=jnp.zeros((batch_size, sequence_len, num_kv_heads, key_size), dtype=dtype),\n            step=step,\n        )\n        for _ in range(num_layers)\n    ]\n\n\nclass Memory(NamedTuple):\n    # Self-attention key/value cache.\n    layers: List[KVMemory]\n\n\nclass Router(hk.Module):\n    def __init__(\n        self,\n        num_selected_experts: int,\n        data_axis: Union[str, Tuple[str, ...]] = \"data\",\n        model_axis: Union[str, Tuple[str, ...]] = \"model\",\n        shard_activations: bool = False,\n        mesh: Any = None,\n        name: str = \"router\",\n    ):\n        super().__init__(name)\n        self.shard_activations = shard_activations\n        self.data_axis = data_axis\n        self.model_axis = model_axis\n        self.mesh = mesh\n        self.num_selected_experts = num_selected_experts\n\n    def compute_routing_prob(\n        self, inputs: jax.Array, padding_mask: Optional[jax.Array], num_experts: int\n    ):\n        return self._compute_routing_prob(inputs, padding_mask, num_experts)\n\n    @hk.transparent\n    def _compute_routing_prob(\n        self,\n        inputs: jax.Array,\n        padding_mask: Optional[jax.Array],\n        num_experts: int,\n    ):\n        # 
Using fp32 for the routing prob computation.\n        inputs = jax.lax.convert_element_type(inputs, jnp.float32)\n\n        # [batch_size, seq_len, num_experts]\n        routing_logits = self._router_weights(inputs, num_experts, sharding=P(\"data\"))\n        assert routing_logits.dtype == jnp.float32\n        routing_probs = jax.nn.softmax(routing_logits)\n\n        if padding_mask is not None:\n            routing_probs *= padding_mask\n\n        return routing_probs, routing_logits, 0\n\n    @hk.transparent\n    def _router_weights(\n        self,\n        x: jax.Array,\n        num_experts: int,\n        sharding: Optional[P] = None,\n    ):\n        fprop_dtype = x.dtype\n        if not x.shape:\n            raise ValueError(\"Input must not be scalar.\")\n\n        input_size = self.input_size = x.shape[-1]\n        w = hk.get_parameter(\n            \"w\", [input_size, num_experts], jnp.float32, init=hk.initializers.Constant(0)\n        )\n        if sharding:\n            w = with_sharding_constraint(w, sharding)\n\n        out = jnp.dot(x, w.astype(fprop_dtype))\n        return out\n\n\nclass MoELayer(hk.Module):\n    def __init__(\n        self,\n        num_experts: int,\n        layer_fn: Callable,\n        router: Router,\n        mesh: Any = None,\n        shard_activations: bool = False,\n        data_axis: Union[str, Tuple[str, ...]] = \"data\",\n        model_axis: Union[str, Tuple[str, ...]] = \"model\",\n        name: Optional[str] = \"moe\",\n    ):\n        super().__init__(name)\n        self.num_experts = num_experts\n        self.layer_fn = layer_fn\n        self.router = router\n        self.mesh = mesh\n        self.shard_activations = shard_activations\n        self.data_axis = data_axis\n        self.model_axis = model_axis\n\n    @hk.transparent\n    def _inference_call(self, inputs: jax.Array, padding_mask: Optional[jax.Array] = None):\n        routing_probs, _, _ = self.router.compute_routing_prob(\n            inputs, padding_mask, 
self.num_experts\n        )\n        expert_gate, expert_index = jax.lax.top_k(routing_probs, k=self.router.num_selected_experts)\n        tmp = jnp.reshape(inputs, (inputs.shape[0] * inputs.shape[1], inputs.shape[2]))\n        broad_inputs = jnp.tile(tmp[:, jnp.newaxis, :], (1, self.router.num_selected_experts, 1))\n        broad_inputs = jnp.reshape(\n            broad_inputs, (broad_inputs.shape[0] * broad_inputs.shape[1], broad_inputs.shape[2])\n        )\n        init_fn, _ = hk.transform(self.layer_fn)\n        vmapped_init_fn = jax.vmap(init_fn, in_axes=0, out_axes=0)\n        lifted_init_fn = hk.experimental.transparent_lift(vmapped_init_fn)\n        # Fetch the vmapped params of the DenseBlock.\n        params = lifted_init_fn(\n            jax.random.split(jax.random.PRNGKey(1), self.num_experts),\n            jnp.zeros((self.num_experts, 1, 1, inputs.shape[-1])),\n        )\n\n        # Index and prob are in the shape [m, 2] indicating which token assigned to which experts.\n        # b: num_expert\n        # m: token or sequence dim\n        # k: input embed dim\n        # n: output embed dim\n        # e: the number of experts chosen for each token\n        @functools.partial(\n            shard_map,\n            mesh=self.mesh,\n            in_specs=(\n                P(self.data_axis, None),\n                P(None, None, self.model_axis),\n                P(None, None, self.model_axis),\n                P(None),\n                P(None),\n            ),\n            out_specs=P(self.data_axis, self.model_axis),\n            check_rep=False,\n        )\n        def moe_slow_matmul1(input, weight, scales, index, prob):\n            weight = weight * scales\n            one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0)\n            all_expert_output = jnp.einsum(\"mk,bkn->bmn\", input, weight)\n            output = jnp.einsum(\"bm,bmn->mn\", one_hot_indices, all_expert_output)\n            return output\n\n        @functools.partial(\n     
       shard_map,\n            mesh=self.mesh,\n            in_specs=(\n                P(self.data_axis, self.model_axis),\n                P(None, self.model_axis, None),\n                P(None, self.model_axis, None),\n                P(None),\n                P(None),\n            ),\n            out_specs=P(self.data_axis, None),\n            check_rep=False,\n        )\n        def moe_slow_matmul2(input, weight, scales, index, prob):\n            weight = weight * scales\n            one_hot_indices = jax.nn.one_hot(index.reshape(-1), 8, axis=0)\n            all_expert_output = jnp.einsum(\"mk,bkn->bmn\", input, weight)\n            output = jnp.einsum(\"bm,bmn->mn\", one_hot_indices, all_expert_output)\n            return jax.lax.psum(output, axis_name=\"model\")\n\n        if hasattr(params[\"linear\"][\"w\"], \"scales\"):\n            x = moe_slow_matmul1(\n                broad_inputs,\n                params[\"linear_v\"][\"w\"].weight,\n                params[\"linear_v\"][\"w\"].scales,\n                expert_index,\n                expert_gate,\n            )\n            y = moe_slow_matmul1(\n                broad_inputs,\n                params[\"linear\"][\"w\"].weight,\n                params[\"linear\"][\"w\"].scales,\n                expert_index,\n                expert_gate,\n            )\n            y = jax.nn.gelu(y)\n            out = moe_slow_matmul2(\n                x * y,\n                params[\"linear_1\"][\"w\"].weight,\n                params[\"linear_1\"][\"w\"].scales,\n                expert_index,\n                expert_gate,\n            )\n            out = jnp.reshape(\n                out,\n                [\n                    inputs.shape[0],\n                    inputs.shape[1],\n                    self.router.num_selected_experts,\n                    out.shape[-1],\n                ],\n            )\n            out = expert_gate[:, :, :, None].astype(jnp.bfloat16) * out\n            out = jnp.sum(out, 
axis=2)\n            out = out.astype(jnp.bfloat16)\n        else:\n            # This is only here so that we can construct a valid init_fn with this code.\n            return inputs\n        return out\n\n    def __call__(self, inputs: jax.Array, padding_mask: jax.Array):\n        return self._inference_call(inputs)\n\n\nclass MHAOutput(NamedTuple):\n    \"\"\"Outputs of the multi-head attention operation.\"\"\"\n\n    embeddings: jax.Array\n    memory: Any\n\n\nclass DecoderOutput(NamedTuple):\n    embeddings: jax.Array\n    memory: Any\n\n\nclass TransformerOutput(NamedTuple):\n    embeddings: jax.Array\n    memory: Any\n\n\n@dataclass\nclass TransformerConfig:\n    emb_size: int\n    key_size: int\n    num_q_heads: int\n    num_kv_heads: int\n    num_layers: int\n    vocab_size: int = 128 * 1024\n    widening_factor: float = 4.0\n\n    attn_output_multiplier: float = 1.0\n\n    name: Optional[str] = None\n\n    num_experts: int = -1\n    capacity_factor: float = 1.0\n    num_selected_experts: int = 1\n\n    init_scale: float = 1.0\n    shard_activations: bool = False\n\n    # Used for activation sharding.\n    data_axis: Union[str, Tuple[str, ...]] = \"data\"\n    model_axis: Union[str, Tuple[str, ...]] = \"model\"\n\n    def __post_init__(self):\n        if isinstance(self.data_axis, list):\n            self.data_axis = tuple(self.data_axis)\n        if isinstance(self.model_axis, list):\n            self.model_axis = tuple(self.model_axis)\n\n    def partition_rules(self):\n        return TRANSFORMER_PARTITION_RULES\n\n    def make(self, mesh=None) -> \"Transformer\":\n        data_axis = tuple(self.data_axis) if isinstance(self.data_axis, list) else self.data_axis\n        model_axis = (\n            tuple(self.model_axis) if isinstance(self.model_axis, list) else self.model_axis\n        )\n\n        return Transformer(\n            num_q_heads=self.num_q_heads,\n            num_kv_heads=self.num_kv_heads,\n            
widening_factor=self.widening_factor,\n            key_size=self.key_size,\n            init_scale=self.init_scale,\n            mesh=mesh,\n            attn_output_multiplier=self.attn_output_multiplier,\n            shard_activations=self.shard_activations,\n            num_layers=self.num_layers,\n            num_experts=self.num_experts,\n            num_selected_experts=self.num_selected_experts,\n            data_axis=data_axis,\n            model_axis=model_axis,\n        )\n\n    def get_memory_sharding(self):\n        return Memory(\n            layers=[\n                KVMemory(\n                    k=P(self.data_axis, self.model_axis),\n                    v=P(self.data_axis, self.model_axis),\n                    step=P(self.data_axis),\n                )\n                for _ in range(self.num_layers)\n            ],\n        )\n\n\ndef hk_rms_norm(\n    x: jax.Array,\n    fixed_scale=False,\n    sharding=P(None),\n) -> jax.Array:\n    \"\"\"Applies a unique LayerNorm to x with default settings.\"\"\"\n    ln = RMSNorm(axis=-1, create_scale=not fixed_scale, sharding=sharding)\n    return ln(x)\n\n\ndef make_attention_mask(\n    query_input: jax.Array,\n    key_input: jax.Array,\n    pairwise_fn: Callable[..., Any] = jnp.multiply,\n    dtype: Any = jnp.bfloat16,\n):\n    \"\"\"Mask-making helper for attention weights.\n\n    In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the\n    attention weights will be `[batch..., heads, len_q, len_kv]` and this\n    function will produce `[batch..., 1, len_q, len_kv]`.\n\n    Args:\n      query_input: a batched, flat input of query_length size\n      key_input: a batched, flat input of key_length size\n      pairwise_fn: broadcasting elementwise comparison function\n      dtype: mask return dtype\n\n    Returns:\n      A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention.\n    \"\"\"\n    mask = pairwise_fn(jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, 
axis=-2))\n    mask = jnp.expand_dims(mask, axis=-3)\n    return mask.astype(dtype)\n\n\nclass Linear(hk.Linear):\n    def __init__(\n        self,\n        output_size: int,\n        with_bias: bool = True,\n        sharding: Optional[P] = None,\n        mesh: Any = None,\n        name: Optional[str] = None,\n        shard_axis: int = 0,\n    ):\n        super().__init__(\n            output_size=output_size,\n            with_bias=with_bias,\n            name=name,\n        )\n        self.sharding = sharding\n        self.mesh = mesh\n        self.shard_axis = shard_axis\n\n    def __call__(\n        self,\n        inputs: jax.Array,\n    ) -> jax.Array:\n        \"\"\"Computes a linear transform of the input.\"\"\"\n\n        fprop_dtype = inputs.dtype\n        if not inputs.shape:\n            raise ValueError(\"Input must not be scalar.\")\n\n        input_size = self.input_size = inputs.shape[-1]\n        output_size = self.output_size\n\n        w = hk.get_parameter(\n            \"w\", [input_size, output_size], jnp.float32, init=hk.initializers.Constant(0)\n        )\n\n        if hasattr(w, \"scales\"):\n            shape = inputs.shape\n            inputs = jnp.reshape(inputs, (-1, shape[-1]))\n\n            @functools.partial(\n                shard_map,\n                mesh=self.mesh,\n                in_specs=(self.sharding, self.sharding),\n                out_specs=self.sharding,\n                check_rep=False,\n            )\n            def mul(w, s):\n                return w.astype(s.dtype) * s\n\n            w = mul(w.weight, w.scales)\n        out = jnp.dot(inputs, w.astype(fprop_dtype))\n        if self.with_bias:\n            b = hk.get_parameter(\n                \"b\", [self.output_size], jnp.float32, init=hk.initializers.Constant(0)\n            )\n            b = jnp.broadcast_to(b, out.shape)\n            out = out + b.astype(fprop_dtype)\n\n        return out\n\n\nclass RMSNorm(hk.RMSNorm):\n\n    def __init__(\n        self,\n     
   axis: Union[int, Sequence[int], slice],\n        eps: float = 1e-5,\n        name: Optional[str] = None,\n        create_scale: bool = True,\n        sharding: Optional[P] = None,\n    ):\n        super().__init__(axis, eps, create_scale=create_scale, name=name)\n        self.sharding = sharding\n\n    def __call__(self, inputs: jax.Array):\n        fprop_dtype = inputs.dtype\n        param_shape = (inputs.shape[-1],)\n        if self.create_scale:\n            scale = hk.get_parameter(\n                \"scale\",\n                param_shape,\n                dtype=jnp.float32,\n                init=hk.initializers.Constant(0),\n            )\n            if self.sharding:\n                scale = with_sharding_constraint(scale, self.sharding)\n            scale = jnp.broadcast_to(scale.astype(jnp.float32), inputs.shape)\n        else:\n            scale = 1.0\n        inputs = inputs.astype(jnp.float32)\n        scale = scale.astype(jnp.float32)\n        mean_squared = jnp.mean(jnp.square(inputs), axis=[-1], keepdims=True)\n        mean_squared = jnp.broadcast_to(mean_squared, inputs.shape)\n\n        normed_inputs = inputs * jax.lax.rsqrt(mean_squared + self.eps)\n\n        outputs = scale * normed_inputs\n\n        return outputs.astype(fprop_dtype)\n\n\ndef rotate_half(\n    x: jax.Array,\n) -> jax.Array:\n    \"\"\"Obtain the rotated counterpart of each feature\"\"\"\n    x1, x2 = jnp.split(x, 2, axis=-1)\n    return jnp.concatenate((-x2, x1), axis=-1)\n\n\nclass RotaryEmbedding(hk.Module):\n    \"\"\"Applies rotary embeddings (RoPE) to the input sequence tensor,\n    as described in https://arxiv.org/abs/2104.09864.\n\n    Attributes:\n        dim (int): Dimensionality of the feature vectors\n        base_exponent (int): Base exponent to compute embeddings from\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        name: Optional[str] = None,\n        base_exponent: int = 10000,\n    ):\n        super().__init__(name)\n        
self.dim = dim\n        self.base_exponent = base_exponent\n        assert self.dim % 2 == 0\n\n    def __call__(\n        self,\n        x: jax.Array,\n        seq_dim: int,\n        offset: jax.Array,\n        const_position: Optional[int] = None,\n        t: Optional[jax.Array] = None,\n    ) -> jax.Array:\n        fprop_dtype = x.dtype\n        # Compute the per-dimension frequencies\n        exponents = jnp.arange(0, self.dim, 2, dtype=jnp.float32)\n        inv_freq = jnp.asarray(\n            1.0 / (self.base_exponent ** (exponents / self.dim)), dtype=jnp.float32\n        )\n\n        if jnp.shape(offset) == ():\n            # Offset can be a scalar or one offset per batch element.\n            offset = jnp.expand_dims(offset, 0)\n\n        # Compute the per element phase (to pass into sin and cos)\n        if const_position:\n            t = const_position * jnp.ones(\n                (\n                    1,\n                    x.shape[seq_dim],\n                ),\n                dtype=jnp.float32,\n            )\n        elif t is None:\n            t = jnp.arange(x.shape[seq_dim], dtype=jnp.float32) + jnp.expand_dims(offset, -1)\n        phase = jnp.einsum(\"bi,j->bij\", t, inv_freq)\n        phase = jnp.tile(phase, reps=(1, 2))[:, :, None, :]\n\n        x = x * jnp.cos(phase) + rotate_half(x) * jnp.sin(phase)\n        x = x.astype(fprop_dtype)\n\n        return x\n\n\nclass MultiHeadAttention(hk.Module):\n    def __init__(\n        self,\n        num_q_heads: int,\n        num_kv_heads: int,\n        key_size: int,\n        *,\n        with_bias: bool = True,\n        value_size: Optional[int] = None,\n        model_size: Optional[int] = None,\n        attn_output_multiplier: 1.0,\n        data_axis: Union[str, Tuple[str, ...]] = \"data\",\n        model_axis: Union[str, Tuple[str, ...]] = \"model\",\n        name: Optional[str] = None,\n    ):\n        super().__init__(name=name)\n        self.num_q_heads = num_q_heads\n        self.num_kv_heads = 
num_kv_heads\n        self.key_size = key_size\n        self.value_size = value_size or key_size\n        self.model_size = model_size or key_size * num_q_heads\n        self.data_axis = data_axis\n        self.model_axis = model_axis\n        self.attn_output_multiplier = attn_output_multiplier\n        self.with_bias = with_bias\n\n    def __call__(\n        self,\n        query: jax.Array,\n        key: Optional[jax.Array],\n        value: Optional[jax.Array],\n        mask: Optional[jax.Array] = None,\n        kv_memory: Optional[KVMemory] = None,\n        mesh: Any = None,\n    ) -> MHAOutput:\n        # In shape hints below, we suppress the leading dims [...] for brevity.\n        # Hence e.g. [A, B] should be read in every case as [..., A, B].\n        sequence_length = query.shape[1]\n        projection = self._linear_projection\n        use_memory = False\n        if kv_memory is not None:\n            if kv_memory.k is None:\n                assert kv_memory.v is None\n                assert key is not None\n                assert value is not None\n            else:\n                assert kv_memory.v is not None\n                use_memory = True\n        else:\n            assert key is not None\n            assert value is not None\n\n        # Check that the keys and values have consistent batch size and sequence length.\n        if not use_memory:\n            assert key.shape[:2] == value.shape[:2], f\"key/value shape: {key.shape}/{value.shape}\"\n\n        if mask is not None:\n            assert mask.ndim == 4\n            assert mask.shape[0] in {\n                1,\n                query.shape[0],\n            }, f\"mask/query shape: {mask.shape}/{query.shape}\"\n            if not use_memory:\n                assert key.shape[0] in {\n                    1,\n                    query.shape[0],\n                }, f\"key/query shape: {key.shape}/{query.shape}\"\n            assert mask.shape[1] == 1\n            assert mask.shape[2] in {\n     
           1,\n                query.shape[1],\n            }, f\"mask/query shape: {mask.shape}/{query.shape}\"\n            if not use_memory:\n                assert mask.shape[3] in {\n                    1,\n                    key.shape[1],\n                }, f\"mask/query shape: {mask.shape}/{key.shape}\"\n\n        # Compute key/query/values (overload K/Q/V to denote the respective sizes).\n        assert self.num_q_heads % self.num_kv_heads == 0\n        query_heads = projection(\n            query,\n            self.key_size,\n            self.num_q_heads,\n            name=\"query\",\n            sharding=P(\"data\", \"model\"),\n            mesh=mesh,\n        )  # [B, T', H, Q=K]\n\n        new_memory = None\n        key_heads = projection(\n            key,\n            self.key_size,\n            self.num_kv_heads,\n            name=\"key\",\n            sharding=P(\"data\", \"model\"),\n            mesh=mesh,\n        )  # [B, T, H, K]\n        value_heads = projection(\n            value,\n            self.value_size,\n            self.num_kv_heads,\n            name=\"value\",\n            sharding=P(\"data\", \"model\"),\n            mesh=mesh,\n        )  # [B, T, H, V]\n\n        rotate = RotaryEmbedding(dim=self.key_size, base_exponent=int(1e4))\n        key_heads = rotate(key_heads, seq_dim=1, offset=(kv_memory.step if kv_memory else 0))\n        query_heads = rotate(query_heads, seq_dim=1, offset=(kv_memory.step if kv_memory else 0))\n\n        @functools.partial(jax.vmap)\n        def update_into(mem, start, update):\n            return jax.lax.dynamic_update_slice_in_dim(mem, update, start, axis=0)\n\n        if kv_memory:\n            if mesh is not None:\n\n                @functools.partial(\n                    shard_map,\n                    mesh=mesh,\n                    in_specs=(\n                        P(\"data\", None, \"model\"),\n                        P(\"data\"),\n                        P(\"data\", None, \"model\"),\n    
                ),\n                    out_specs=P(\"data\", None, \"model\"),\n                    check_rep=False,\n                )\n                def update_into_shmap(mems, starts, updates):\n                    return update_into(mems, starts, updates)\n\n                key_heads = update_into_shmap(kv_memory.k, kv_memory.step, key_heads)\n                value_heads = update_into_shmap(kv_memory.v, kv_memory.step, value_heads)\n            else:\n                key_heads = update_into(kv_memory.k, kv_memory.step, key_heads)\n                value_heads = update_into(kv_memory.v, kv_memory.step, value_heads)\n\n            new_step = kv_memory.step + sequence_length\n            memory_mask = jnp.arange(kv_memory.k.shape[1]) < new_step[:, None]\n            memory_mask = memory_mask[:, None, None, :]  # [B, H, T, T]\n            if mask is not None:\n                mask = memory_mask * mask\n            else:\n                mask = memory_mask\n\n            new_memory = KVMemory(\n                k=key_heads,\n                v=value_heads,\n                step=new_step,\n            )\n        # Add separate dimension for grouped query heads.\n        query_heads = with_sharding_constraint(query_heads, P(self.data_axis, None, \"model\", None))\n        key_heads = with_sharding_constraint(key_heads, P(self.data_axis, None, \"model\", None))\n        value_heads = with_sharding_constraint(value_heads, P(self.data_axis, None, \"model\", None))\n        b, t, h, d = query_heads.shape\n        _, _, kv_h, _ = key_heads.shape\n        assert h % kv_h == 0, f\"query_heads {h} must be a multiple of kv_heads {kv_h}\"\n\n        query_heads = jnp.reshape(query_heads, (b, t, kv_h, h // kv_h, d))\n        query_heads = with_sharding_constraint(\n            query_heads, P(self.data_axis, None, \"model\", None, None)\n        )\n\n        # Compute attention weights.\n        # Attention softmax is always carried out in fp32.\n        attn_logits = 
jnp.einsum(\"...thHd,...Thd->...hHtT\", query_heads, key_heads).astype(\n            jnp.float32\n        )\n        attn_logits *= self.attn_output_multiplier\n        max_attn_val = jnp.array(30.0, dtype=attn_logits.dtype)\n        attn_logits = max_attn_val * jnp.tanh(attn_logits / max_attn_val)\n\n        mask = mask[:, :, None, :, :]\n\n        if mask is not None:\n            if mask.ndim != attn_logits.ndim:\n                raise ValueError(\n                    f\"Mask dimensionality {mask.ndim} must match logits dimensionality \"\n                    f\"{attn_logits.ndim} for {mask.shape}/{attn_logits.shape}.\"\n                )\n            attn_logits = jnp.where(mask, attn_logits, -1e30)\n        attn_weights = jax.nn.softmax(attn_logits).astype(query.dtype)  # [H, T', T]\n\n        # Weight the values by the attention and flatten the head vectors.\n        attn = jnp.einsum(\"...hHtT,...Thd->...thHd\", attn_weights, value_heads)\n        attn = with_sharding_constraint(attn, P(self.data_axis, None, \"model\", None, None))\n        leading_dims = attn.shape[:2]\n        attn = jnp.reshape(attn, (*leading_dims, -1))  # [T', H*V]\n        attn = with_sharding_constraint(attn, P(self.data_axis, None, \"model\"))\n        # Apply another projection to get the final embeddings.\n        final_projection = Linear(\n            self.model_size,\n            with_bias=False,\n            sharding=P(\"model\", \"data\"),\n            mesh=mesh,\n        )\n        return MHAOutput(final_projection(attn), new_memory)\n\n    @hk.transparent\n    def _linear_projection(\n        self,\n        x: jax.Array,\n        head_size: int,\n        num_heads: int,\n        sharding: Optional[P] = None,\n        name: Optional[str] = None,\n        mesh: Any = None,\n    ) -> jax.Array:\n        y = Linear(\n            num_heads * head_size,\n            with_bias=False,\n            name=name,\n            sharding=sharding,\n            mesh=mesh,\n        )(x)\n      
  *leading_dims, _ = x.shape\n        return y.reshape((*leading_dims, num_heads, head_size))\n\n\n@dataclass\nclass MHABlock(hk.Module):\n    \"\"\"A MHA Block\"\"\"\n\n    num_q_heads: int\n    num_kv_heads: int\n    key_size: int\n    attn_output_multiplier: float = 1.0\n    mesh: Any = None\n    data_axis: Union[str, Tuple[str, ...]] = \"data\"\n    model_axis: Union[str, Tuple[str, ...]] = \"model\"\n\n    @hk.transparent\n    def __call__(\n        self,\n        inputs: jax.Array,  # [B, T, D]\n        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T] or B[1, 1, 1, 1]\n        layer_memory: Optional[KVMemory],\n    ) -> MHAOutput:\n        _, _, model_size = inputs.shape\n        assert mask.ndim == 4, f\"shape: {mask.shape}\"\n        assert mask.shape[2] in {1, inputs.shape[1]}, str(mask.shape)\n        assert mask.shape[3] in {1, inputs.shape[1]}, str(mask.shape)\n        side_input = inputs\n\n        def attn_block(query, key, value, mask, memory) -> MHAOutput:\n            return MultiHeadAttention(\n                num_q_heads=self.num_q_heads,\n                num_kv_heads=self.num_kv_heads,\n                key_size=self.key_size,\n                model_size=model_size,\n                data_axis=self.data_axis,\n                model_axis=self.model_axis,\n                attn_output_multiplier=self.attn_output_multiplier,\n            )(\n                query,\n                key,\n                value,\n                mask,\n                memory,\n                mesh=self.mesh,\n            )\n\n        attn_output = attn_block(inputs, side_input, side_input, mask, layer_memory)\n        h_attn = attn_output.embeddings\n\n        return attn_output._replace(embeddings=h_attn)\n\n\n@dataclass\nclass DenseBlock(hk.Module):\n    num_q_heads: int\n    num_kv_heads: int\n    key_size: int\n    widening_factor: float = 4.0\n    sharding_constraint: bool = False\n    mesh: Any = None\n\n    @hk.transparent\n    def __call__(\n        self,\n       
 inputs: jax.Array,  # [B, T, D]\n    ) -> jax.Array:  # [B, T, D]\n        _, _, model_size = inputs.shape\n        h_v = Linear(\n            ffn_size(\n                model_size,\n                self.widening_factor,\n            ),\n            with_bias=False,\n            mesh=self.mesh,\n            sharding=P(\"data\", \"model\"),\n            name=\"linear_v\",\n        )(inputs)\n        h_w1 = jax.nn.gelu(\n            Linear(\n                ffn_size(\n                    model_size,\n                    self.widening_factor,\n                ),\n                with_bias=False,\n                mesh=self.mesh,\n                sharding=P(\"data\", \"model\"),\n            )(inputs)\n        )\n        h_dense = Linear(\n            model_size,\n            with_bias=False,\n            sharding=P(\"model\", \"data\"),\n            mesh=self.mesh,\n            shard_axis=1,\n        )(h_w1 * h_v)\n\n        return h_dense\n\n\n@dataclass\nclass DecoderLayer(hk.Module):\n    \"\"\"A transformer stack.\"\"\"\n\n    num_q_heads: int\n    num_kv_heads: int\n    key_size: int\n    num_layers: int\n    # MoE.\n    num_experts: int\n    layer_index: Optional[int] = None\n    num_selected_experts: int = 1\n    widening_factor: float = 4.0\n    name: Optional[str] = None\n    data_axis: Union[str, Tuple[str, ...]] = \"data\"\n    model_axis: Union[str, Tuple[str, ...]] = \"model\"\n    shard_activations: bool = False\n    attn_output_multiplier: float = 1.0\n    mesh: Any = None\n\n    def __call__(\n        self,\n        inputs: jax.Array,  # [B, T, D]\n        mask: jax.Array,  # [B, 1, T, T] or [B, 1, 1, T]\n        padding_mask: Optional[jax.Array],\n        layer_memory: Optional[KVMemory],\n    ) -> DecoderOutput:\n        \"\"\"Transforms input embedding sequences to output embedding sequences.\"\"\"\n\n        def layer_norm(x):\n            return hk_rms_norm(x)\n\n        if self.shard_activations:\n            sharding = P(self.data_axis, None, 
self.model_axis)\n        else:\n            sharding = P(self.data_axis, None)\n        h = with_sharding_constraint(inputs, sharding)\n\n        attn_output = MHABlock(\n            num_q_heads=self.num_q_heads,\n            num_kv_heads=self.num_kv_heads,\n            key_size=self.key_size,\n            attn_output_multiplier=self.attn_output_multiplier,\n            mesh=self.mesh,\n            data_axis=self.data_axis,\n            model_axis=self.model_axis,\n        )(layer_norm(h), mask, layer_memory)\n        h_attn = attn_output.embeddings\n\n        h_attn = layer_norm(h_attn)\n        h += h_attn\n        h = with_sharding_constraint(h, sharding)\n\n        def base_dense_block(h):\n            h = DenseBlock(\n                num_q_heads=self.num_q_heads,\n                num_kv_heads=self.num_kv_heads,\n                key_size=self.key_size,\n                widening_factor=self.widening_factor,\n                sharding_constraint=False,\n                mesh=self.mesh,\n            )(h)\n            return h\n\n        if self.num_experts > 1:\n            rank_logger.debug(\"Using MoE!\")\n            router = Router(\n                num_selected_experts=self.num_selected_experts,\n                shard_activations=self.shard_activations,\n                data_axis=self.data_axis,\n                model_axis=self.model_axis,\n                mesh=self.mesh,\n            )\n            h_dense = MoELayer(\n                num_experts=self.num_experts,\n                mesh=self.mesh,\n                layer_fn=base_dense_block,\n                router=router,\n                shard_activations=self.shard_activations,\n                data_axis=self.data_axis,\n                model_axis=self.model_axis,\n            )(layer_norm(h), padding_mask)\n        else:\n            h_dense = base_dense_block(layer_norm(h))\n\n        h_dense = layer_norm(h_dense)\n        h += h_dense\n        h = with_sharding_constraint(h, sharding)\n\n        return 
DecoderOutput(\n            embeddings=h,\n            memory=attn_output.memory,\n        )\n\n\nclass LanguageModelOutput(NamedTuple):\n    logits: jax.Array\n    model_state: Any\n\n\nclass InOutEmbed(hk.Embed):\n    \"\"\"Module for embedding tokens in a low-dimensional space.\"\"\"\n\n    def __init__(\n        self,\n        vocab_size: Optional[int] = None,\n        embed_dim: Optional[int] = None,\n        sharding: Optional[P] = None,\n        name: Optional[str] = None,\n    ):\n        super().__init__(\n            vocab_size=vocab_size,\n            embed_dim=embed_dim,\n            name=name,\n        )\n        self.sharding = sharding\n\n    @property\n    def embeddings(self):\n        embed_mat = hk.get_parameter(\n            \"embeddings\",\n            [self.vocab_size, self.embed_dim],\n            dtype=jnp.float32,\n            init=hk.initializers.Constant(0),\n        )\n        if self.sharding:\n            embed_mat = with_sharding_constraint(embed_mat, self.sharding)\n        return embed_mat\n\n    def decode(\n        self,\n        inputs: jax.Array,\n    ) -> jax.Array:\n        return jnp.dot(inputs, self.embeddings.T.astype(inputs.dtype))\n\n\n@dataclass\nclass LanguageModelConfig:\n    \"\"\"An autoregressive transformer-based language model.\"\"\"\n\n    model: Optional[TransformerConfig]\n    vocab_size: int\n    pad_token: int\n    eos_token: int\n    sequence_len: int\n    model_size: int = 0\n    embedding_init_scale: float = 1.0\n    embedding_multiplier_scale: float = 1.0\n    output_multiplier_scale: float = 1.0\n    name: Optional[str] = None\n    fprop_dtype: Any = jnp.bfloat16\n    model_type: Optional[str] = None\n    init_scale_override: Optional[float] = None\n    shard_embeddings: bool = True\n\n    _initialized = False\n\n    def initialize(self):\n        # We cannot specify [] as a default value (it is mutable), hence None.\n        model_config = self.model\n        assert self.init_scale_override is None, (\n 
           \"Overriding model initialize scale is supported only for predefined models.\"\n        )\n        if self.model_size == 0:\n            self.model_size = model_config.emb_size\n        assert self.model is not None, \"Model could not be initialized.\"\n        self._initialized = True\n        return self\n\n    def make(self, *args, **kwargs):\n        if not self._initialized:\n            logger.warning(\n                f\"LanguageModel {self.name} is not initialized. Initializing for one replica.\"\n            )\n            self.initialize()\n\n        return LanguageModel(\n            model=self.model.make(*args, **kwargs),\n            config=self,\n            fprop_dtype=self.fprop_dtype,\n            mesh=kwargs.get(\"mesh\", None),\n        )\n\n    def partition_rules(self):\n        return LM_PARTITION_RULES + self.model.partition_rules()\n\n\ndef layer_norm(x, model):\n    return hk_rms_norm(x)\n\n\n@dataclass\nclass LanguageModel(hk.Module):\n    \"\"\"An autoregressive transformer-based language model.\"\"\"\n\n    model: \"Transformer\"\n    config: LanguageModelConfig\n    fprop_dtype: Any = jnp.bfloat16\n    name: Optional[str] = None\n    mesh: Any = None\n\n    def __call__(\n        self,\n        tokens: jax.Array,\n        memory: Optional[Memory] = None,\n        *,\n        batch: Dict[str, jax.Array] = {},\n        last_hid_only: bool = False,\n        length: Optional[jax.Array] = None,\n    ) -> LanguageModelOutput:\n        \"\"\"Forward pass, producing a sequence of logits.\"\"\"\n        del batch  # Unused.\n\n        config = self.config\n\n        input_mask = jnp.greater(tokens, config.pad_token)\n\n        # Embed the input tokens and positions.\n        in_out_embed = InOutEmbed(\n            self.config.vocab_size,\n            embed_dim=self.config.model_size,\n            sharding=P(None, (\"data\", \"model\")),\n        )\n        input_embeddings = in_out_embed(tokens).astype(config.fprop_dtype)\n        
input_embeddings = with_sharding_constraint(\n            input_embeddings, P(\"data\", None, self.model.model_axis)\n        )\n        input_embeddings *= config.embedding_multiplier_scale\n\n        model_output = self.model(\n            input_embeddings,\n            input_mask,\n            memory=memory,\n        )  # [B, T, D]\n        embeddings, model_state = model_output.embeddings, model_output.memory\n        if self.model.shard_activations:\n            embeddings = with_sharding_constraint(\n                embeddings, P(\"data\", None, self.model.model_axis)\n            )\n        else:\n            embeddings = with_sharding_constraint(embeddings, P(\"data\", None))\n        rank_logger.debug(f\"Final embedding shape: {embeddings.shape}\")\n        embeddings = layer_norm(embeddings, self.model)\n        assert embeddings.dtype == self.fprop_dtype\n\n        if last_hid_only:\n            last_step = jnp.maximum(jnp.sum(input_mask.astype(jnp.int32), axis=1) - 1, 0)\n            last_hid = jax.vmap(lambda x, i: x[i], in_axes=0, out_axes=0)(embeddings, last_step)\n            return last_hid\n\n        if length is not None:\n            last_step = jnp.maximum(length.astype(jnp.int32) - 1, 0)\n            embeddings = jax.vmap(lambda x, i: x[i], in_axes=0, out_axes=0)(embeddings, last_step)\n            embeddings = jnp.expand_dims(embeddings, axis=1)\n\n        # Decode the embeddings (here, we use tied weights).\n        rank_logger.info(embeddings.shape)\n        out = in_out_embed.decode(embeddings)\n        rank_logger.info(out.shape)\n        out *= config.output_multiplier_scale\n\n        if self.model.shard_activations:\n            out = with_sharding_constraint(out, P(\"data\", None, self.model.model_axis))\n        else:\n            out = with_sharding_constraint(out, P(\"data\", None))\n\n        return LanguageModelOutput(\n            logits=out,\n            model_state=model_state,\n        )\n\n    def init_memory(self, 
batch_size: int, seq_len: int, dtype=jnp.bfloat16):\n        return self.model.init_memory(batch_size=batch_size, sequence_len=seq_len, dtype=dtype)\n\n    def prefill_memory(self, prompts, memory):\n        # Pad to the left and right align?\n        # Basically assume prompt is already padded\n        model_output = self(prompts, memory=memory)\n        return model_output.logits, model_output.model_state\n\n\n@dataclass\nclass Transformer(hk.Module):\n    \"\"\"A transformer stack.\"\"\"\n\n    num_q_heads: int\n    num_kv_heads: int\n    key_size: int\n    widening_factor: float\n    init_scale: float\n    mesh: Any\n    attn_output_multiplier: float\n    shard_activations: bool\n    num_layers: int\n    # MoE\n    num_experts: int\n    num_selected_experts: int\n    name: Optional[str] = None\n\n    # Used for activation sharding\n    data_axis: Union[str, Tuple[str, ...]] = \"data\"\n    model_axis: Union[str, Tuple[str, ...]] = \"model\"\n\n    def init_memory(self, batch_size: int, sequence_len: int, dtype=jnp.bfloat16):\n        return Memory(\n            layers=init_layer_memories(\n                batch_size,\n                sequence_len,\n                self.num_kv_heads,\n                self.key_size,\n                self.num_layers,\n                step=jnp.zeros(batch_size, dtype=jnp.int32),\n                dtype=dtype,\n            ),\n        )\n\n    def __call__(\n        self,\n        embeddings: jax.Array,  # [B, T, D]\n        mask: jax.Array,  # [B, T]\n        memory: Optional[Memory],\n    ) -> TransformerOutput:\n        \"\"\"Transforms input embedding sequences to output embedding sequences.\"\"\"\n\n        fprop_dtype = embeddings.dtype\n        _, seq_len, model_size = embeddings.shape\n        padding_mask = mask.copy()\n        mask = mask[:, None, None, :]  # [B, H=1, T'=1, T]\n\n        # Compute causal mask for autoregressive sequence modelling.\n        causal_mask = jnp.tril(jnp.ones((1, 1, seq_len, seq_len))).astype(\n 
           fprop_dtype\n        )  # [B=1, H=1, T, T]\n        mask = mask * causal_mask  # [B, H=1, T, T]\n\n        h = embeddings\n        kv_memories = []\n\n        def block(\n            h,\n            mask,\n            padding_mask,\n            memory,\n            layer_index: Optional[int] = None,\n            widening_factor: Optional[int] = None,\n            name: Optional[str] = None,\n        ) -> DecoderOutput:\n            return DecoderLayer(\n                num_q_heads=self.num_q_heads,\n                num_kv_heads=self.num_kv_heads,\n                key_size=self.key_size,\n                widening_factor=widening_factor or self.widening_factor,\n                num_layers=self.num_layers,\n                mesh=self.mesh,\n                data_axis=self.data_axis,\n                model_axis=self.model_axis,\n                attn_output_multiplier=self.attn_output_multiplier,\n                shard_activations=self.shard_activations,\n                # MoE.\n                num_experts=self.num_experts,\n                num_selected_experts=self.num_selected_experts,\n                name=name,\n                layer_index=layer_index,\n            )(\n                h,\n                mask,\n                padding_mask,\n                memory,\n            )\n\n        for i in range(self.num_layers):\n            decoder_output = block(\n                h,\n                mask,\n                padding_mask,\n                memory.layers[i] if memory else None,\n                layer_index=i,\n                name=f\"decoder_layer_{i}\",\n            )\n            h, new_kv_memory = (\n                decoder_output.embeddings,\n                decoder_output.memory,\n            )\n            kv_memories.append(new_kv_memory)\n\n        return TransformerOutput(\n            embeddings=h,\n            memory=Memory(layers=kv_memories),\n        )\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.ruff]\nindent-width = 4\nline-length = 100\n\n[tool.ruff.lint]\nignore = [\n    \"E722\",\n    \"E731\",\n    \"E741\",\n    \"F405\",\n    \"E402\",\n    \"F403\",\n]\nselect = [\"ISC001\"]\n"
  },
  {
    "path": "requirements.txt",
    "content": "dm_haiku==0.0.12\njax[cuda12-pip]==0.4.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\nnumpy==1.26.4\nsentencepiece==0.2.0\n"
  },
  {
    "path": "run.py",
    "content": "# Copyright 2024 X.AI Corp.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport logging\n\nfrom model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit\nfrom runners import InferenceRunner, ModelRunner, sample_from_model\n\n\nCKPT_PATH = \"./checkpoints/\"\n\n\ndef main():\n    grok_1_model = LanguageModelConfig(\n        vocab_size=128 * 1024,\n        pad_token=0,\n        eos_token=2,\n        sequence_len=8192,\n        embedding_init_scale=1.0,\n        output_multiplier_scale=0.5773502691896257,\n        embedding_multiplier_scale=78.38367176906169,\n        model=TransformerConfig(\n            emb_size=48 * 128,\n            widening_factor=8,\n            key_size=128,\n            num_q_heads=48,\n            num_kv_heads=8,\n            num_layers=64,\n            attn_output_multiplier=0.08838834764831845,\n            shard_activations=True,\n            # MoE.\n            num_experts=8,\n            num_selected_experts=2,\n            # Activation sharding.\n            data_axis=\"data\",\n            model_axis=\"model\",\n        ),\n    )\n    inference_runner = InferenceRunner(\n        pad_sizes=(1024,),\n        runner=ModelRunner(\n            model=grok_1_model,\n            bs_per_device=0.125,\n            checkpoint_path=CKPT_PATH,\n        ),\n        name=\"local\",\n        load=CKPT_PATH,\n        tokenizer_path=\"./tokenizer.model\",\n        local_mesh_config=(1, 8),\n        
between_hosts_config=(1, 1),\n    )\n    inference_runner.initialize()\n    gen = inference_runner.run()\n\n    inp = \"The answer to life the universe and everything is of course\"\n    print(f\"Output for prompt: {inp}\", sample_from_model(gen, inp, max_len=100, temperature=0.01))\n\n\nif __name__ == \"__main__\":\n    logging.basicConfig(level=logging.INFO)\n    main()\n"
  },
  {
    "path": "runners.py",
    "content": "# Copyright 2024 X.AI Corp.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport bisect\nimport functools\nimport logging\nimport math\nimport re\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, NamedTuple, Optional, Tuple\n\nimport haiku as hk\nimport jax\nimport jax.experimental.pjit as pjit\nimport jax.numpy as jnp\nimport numpy as np\nimport sentencepiece\nfrom jax.experimental import mesh_utils\nfrom jax.sharding import PartitionSpec as P\nfrom jax.typing import ArrayLike\n\nimport checkpoint as xai_checkpoint\nfrom model import (\n    LanguageModelConfig,\n    LanguageModelOutput,\n    TrainingState,\n    apply_rules,\n    Memory,\n    KVMemory,\n)\n\nlogger = logging.getLogger(__name__)\nrank_logger = logging.getLogger(\"rank\")\n\nTOP_K = 8\n\n\nclass SampleSettings(NamedTuple):\n    temperature: ArrayLike\n    nucleus_p: ArrayLike\n    mask: ArrayLike\n    # Whether a given batch element is actively used. 
[B]\n    active: ArrayLike\n\n\nclass SampleOutput(NamedTuple):\n    token_id: ArrayLike\n    prob: ArrayLike\n    top_k_token_ids: ArrayLike\n    top_k_probs: ArrayLike\n\n\ndef insert_slice(memory: Memory, slice, length, i):\n    slice = Memory(\n        layers=[\n            KVMemory(layer.k, layer.v, step=jnp.array([length]))\n            for layer in slice.layers\n        ],\n    )\n\n    return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0),\n                        memory, slice)\n\n\ndef pad_to_size(x, size):\n    if x.shape[0] > size:\n        # Left truncate if the context is too long.\n        x = x[-size:]\n    return np.pad(x, [0, size - x.shape[0]], mode=\"constant\", constant_values=0)\n\n\ndef top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array:\n    \"\"\"Performs nucleus filtering on logits.\"\"\"\n    assert logits.ndim == top_p.ndim, f\"Expected {logits.ndim} equal {top_p.ndim}\"\n    sorted_logits = jax.lax.sort(logits, is_stable=False)\n    sorted_probs = jax.nn.softmax(sorted_logits)\n    threshold_idx = jnp.argmax(jnp.cumsum(sorted_probs, -1) >= 1 - top_p, axis=-1)\n    threshold_largest_logits = jnp.take_along_axis(\n        sorted_logits, threshold_idx[..., jnp.newaxis], axis=-1\n    )\n    assert threshold_largest_logits.shape == logits.shape[:-1] + (1,)\n    mask = logits >= threshold_largest_logits\n    # Set unused logits to -inf.\n    logits = jnp.where(mask, logits, -1e10)\n    return logits\n\n\ndef sample_token(\n    rngs: jax.random.PRNGKey,\n    lm_outputs: LanguageModelOutput,\n    settings: SampleSettings,\n) -> SampleOutput:\n    # Expand the settings shape to match the logit shape.\n    settings = SampleSettings(\n        temperature=jnp.expand_dims(settings.temperature, (1, 2)),  # Input [B], output [B, 1, 1].\n        nucleus_p=jnp.expand_dims(settings.nucleus_p, (1, 2)),  # Input [B], output [B, 1, 1].\n        mask=jnp.expand_dims(settings.mask, 1),  # Input [B, V], output [B, 
1, V].\n        active=settings.active,  # [B].\n    )\n    logits = lm_outputs.logits / settings.temperature.astype(lm_outputs.logits.dtype)\n    # Mask out all disallowed tokens by assigning them a near-zero probability.\n    logits = jnp.where(settings.mask, logits, -1e10)\n    # Mask out all tokens that don't fall into the p-th percentile.\n    logits = top_p_filter(logits, settings.nucleus_p.astype(logits.dtype))\n\n    new_token = jax.vmap(jax.random.categorical)(rngs, logits)\n\n    probabilities = jax.nn.softmax(logits)\n    token_prob = jnp.take_along_axis(probabilities, jnp.expand_dims(new_token, 1), axis=2)\n    token_prob = jnp.squeeze(token_prob, 1)\n\n    # Gather the top-k tokens and probabilities.\n    top_k_probs, top_k_token_ids = jax.lax.top_k(probabilities, TOP_K)\n    top_k_probs = jnp.squeeze(top_k_probs, 1)\n    top_k_token_ids = jnp.squeeze(top_k_token_ids, 1)\n    return SampleOutput(\n        new_token,\n        token_prob,\n        top_k_token_ids,\n        top_k_probs,\n    )\n\n\n@dataclass\nclass ModelRunner:\n    model: LanguageModelConfig\n\n    bs_per_device: float = 2.0\n\n    load_rename_rules: Optional[list[tuple[str, str]]] = None\n    load_exclude_rules: Optional[list[str]] = None\n\n    rng_seed: int = 42  # Initial rng seed.\n    transform_forward: bool = False\n\n    checkpoint_path: str = \"\"\n\n    def make_forward_fn(self, mesh: Any):\n        def forward(tokens):\n            out = self.model.make(mesh=mesh)(tokens)\n            return out, None\n\n        if self.transform_forward:\n            forward = hk.transform(forward)\n        return forward\n\n    def initialize(\n        self,\n        init_data,\n        local_mesh_config: tuple[int, int],\n        between_hosts_config: tuple[int, int],\n    ):\n        num_replicas = math.prod(between_hosts_config)\n        self.model.initialize()\n        self.model.fprop_dtype = jnp.bfloat16\n        num_local_gpus = len(jax.local_devices())\n\n        # Calculate the 
global batch size from the local batch size.\n        self.batch_size = int(self.bs_per_device * num_local_gpus * num_replicas)\n\n        # Calculate the batch size per host from the global batch size.\n        self.local_batch_size = self.batch_size // jax.process_count()\n\n        self.local_mesh_config = local_mesh_config\n        self.between_hosts_config = between_hosts_config\n        rank_logger.info(\n            f\"Initializing mesh for {self.local_mesh_config=} {self.between_hosts_config=}...\"\n        )\n        self.mesh = make_mesh(self.local_mesh_config, self.between_hosts_config)\n        self.forward = self.make_forward_fn(mesh=self.mesh)\n        self.logits_fn = hk.transform(lambda tokens: self.forward(tokens)[0])\n\n        self.eval_forward = self.make_forward_fn(mesh=self.mesh)\n        self.logits_eval_fn = hk.transform(lambda tokens: self.eval_forward(tokens)[0])\n\n        if self.transform_forward:\n            self.state_sharding = self.get_state_sharding(init_data)\n            rank_logger.info(f\"State sharding type: {type(self.state_sharding)}\")\n            self.init_fn = pjit.pjit(self.init, out_shardings=self.state_sharding)\n\n    def init(self, rng: jax.Array, data) -> TrainingState:\n        assert self.transform_forward\n        rng, init_rng = jax.random.split(rng)\n        params = self.forward.init(init_rng, data[\"inputs\"])\n        return TrainingState(params=params)\n\n    def get_state_sharding(self, init_data):\n        assert self.transform_forward\n        rng = jax.random.PRNGKey(self.rng_seed)\n        rank_logger.info(f\"partition rules: {self.model.partition_rules}\")\n\n        with self.mesh:\n            shapes = jax.eval_shape(self.init, rng, init_data)\n            sharding = jax.tree_util.tree_map_with_path(\n                apply_rules(self.model.partition_rules()),\n                shapes,\n            )\n        return sharding\n\n    def load_or_init(\n        self,\n        init_data: Any,\n        
from_checkpoint: bool = True,\n        init_fn: Optional[Callable] = None,\n    ):\n        rng = jax.random.PRNGKey(self.rng_seed)\n\n        if not self.checkpoint_path or not from_checkpoint:\n            rank_logger.info(\"Initializing model...\")\n            with self.mesh:\n                if init_fn is not None:\n                    state = init_fn(rng, init_data)\n                else:\n                    assert self.transform_forward\n                    state = self.init_fn(rng, init_data)\n            rank_logger.info(\"Model state is newly initialized.\")\n        else:\n            with self.mesh:\n                if init_fn:\n                    state_shapes = jax.eval_shape(init_fn, rng, init_data)\n                else:\n                    assert self.transform_forward\n                    state_shapes = jax.eval_shape(self.init_fn, rng, init_data)\n            init_state = None\n\n            state = xai_checkpoint.restore(\n                checkpoint_path=self.checkpoint_path,\n                state_shapes=state_shapes,\n                mesh=self.mesh,\n                between_hosts_config=self.between_hosts_config,\n                state_sharding=self.state_sharding,\n                init_state=init_state,\n                params_only=True,\n            )\n\n            del init_state\n        return state\n\n\n@dataclass\nclass Request:\n    prompt: str\n    temperature: float\n    nucleus_p: float\n    rng_seed: int\n    max_len: int\n\n\n@dataclass\nclass InferenceRunner:\n    name: str\n    runner: Any\n    load: str\n    tokenizer_path: str = \"/tmp/xai_data/tokenizer.model\"\n    local_mesh_config: Tuple[int, int] = (1, 1)\n    between_hosts_config: Tuple[int, int] = (1, 1)\n    pad_sizes: tuple[int] = (1024,)\n\n    def get_pad_bucket(self, size):\n        i = bisect.bisect_left(self.pad_sizes, size)\n        return self.pad_sizes[min(i, len(self.pad_sizes) - 1)]\n\n    def initialize(self):\n        runner = self.runner\n        
self.runner.transform_forward = True\n        dummy_data = dict(\n            inputs=np.zeros((1, 256), dtype=np.int32),\n            targets=np.zeros((1, 256), dtype=np.int32),\n        )\n        runner.initialize(\n            dummy_data,\n            local_mesh_config=self.local_mesh_config,\n            between_hosts_config=self.between_hosts_config,\n        )\n\n        self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=self.tokenizer_path)\n\n        max_len = runner.model.sequence_len\n\n        self.vocab_size = self.runner.model.vocab_size\n\n        params = runner.load_or_init(dummy_data)\n        self.params = params\n\n        def pad_to_max_len(x):\n            if len(x.shape) > 1:\n                pad_width = max_len - x.shape[1]\n                return jnp.pad(x, [(0, 0), (0, pad_width), (0, 0), (0, 0)])\n            else:\n                return x\n\n        @functools.lru_cache\n        def lm():\n            return runner.model.make(mesh=runner.mesh)\n\n        def hk_forward(\n            tokens,\n            memory=None,\n            length=None,\n            active=None,\n        ) -> LanguageModelOutput:\n            if memory is not None:\n                assert active is not None\n                layers = []\n                for l in memory.layers:\n                    # Reset steps to 0 for inactive requests to avoid unnecessary computations.\n                    step = jnp.where(active, l.step, jnp.zeros_like(l.step))\n                    layers.append(l._replace(step=step))\n                memory = memory._replace(layers=layers)\n            return lm()(tokens, memory, length=length)\n\n        def hk_sample_step(rngs, last_output: SampleOutput, memory, settings):\n            rngs, rngs_ = jax.vmap(jax.random.split, out_axes=1)(rngs)\n            lm_outputs = hk_forward(last_output.token_id, memory=memory, active=settings.active)\n            sample_result = sample_token(rngs_, lm_outputs, settings)\n            return 
rngs, sample_result, lm_outputs.model_state\n\n        def hk_new_memory(batch_size, sequence_len):\n            return lm().init_memory(batch_size, sequence_len)\n\n        def hk_prefill_memory(\n            rngs,\n            memory,\n            settings,\n            last_output,\n            prompt,\n            length,\n            rng_seed,\n            new_settings,\n            i,\n        ):\n            rng = jax.random.PRNGKey(seed=rng_seed)\n            rng, rng_ = jax.random.split(rng)\n\n            # Allocate new memory for this sample. The memory length is equal to the length of the\n            # prompt.\n            slice = hk_new_memory(1, prompt.shape[0])\n\n            # Move the settings for this individual batch entry into the joint settings tensor.\n            settings = jax.tree_map(\n                lambda o, v: jax.lax.dynamic_update_index_in_dim(o, v, i, axis=0),\n                settings,\n                new_settings,\n            )\n\n            # Get the settings for the batch entry from the joint settings tensor.\n            settings_slice = jax.tree_map(lambda t: jnp.expand_dims(t[i], axis=0), settings)\n\n            # Process the first n-1 tokens of the prompt.\n            lm_outputs = hk_forward(\n                jnp.expand_dims(prompt, 0),\n                memory=slice,\n                length=jnp.expand_dims(length, 0),\n                active=settings_slice.active,\n            )\n\n            # The forward pass doesn't correctly set the `step` counter inside the memory. 
Manually\n            # override it so `hk_forward` uses the correct context length in the next call.\n            slice = lm_outputs.model_state\n            slice = slice._replace(\n                layers=[l._replace(step=jnp.array([length])) for l in slice.layers]\n            )\n\n            # Sample the actual output token.\n            rng_ = jnp.expand_dims(rng_, 0)\n            new_output = sample_token(rng_, lm_outputs, settings_slice)\n\n            # Update the KV cache/memory.\n            slice = jax.tree_map(pad_to_max_len, slice)\n            memory = insert_slice(memory, slice, length, i)\n\n            rng = jnp.expand_dims(rng, 0)\n            rngs = jax.lax.dynamic_update_index_in_dim(rngs, rng, i, axis=0)\n\n            # Move the network outputs for this batch entry into the joint output tensor.\n            last_output = jax.tree_util.tree_map(\n                lambda last, new: jax.lax.dynamic_update_index_in_dim(last, new, i, axis=0),\n                last_output,\n                new_output,\n            )\n            return rngs, last_output, memory, settings\n\n        sample_step_ = hk.without_apply_rng(hk.transform(hk_sample_step))\n        prefill_memory_ = hk.without_apply_rng(hk.transform(hk_prefill_memory))\n        new_memory_ = hk.without_apply_rng(hk.transform(hk_new_memory))\n        forward_ = hk.without_apply_rng(hk.transform(hk_forward))\n\n        rng = jax.random.PRNGKey(42)\n        dummy_tokens = jnp.zeros((1, max_len), jnp.int32)\n\n        with runner.mesh:\n            shapes = jax.eval_shape(forward_.init, rng, dummy_tokens)\n\n        self.params_sharding = jax.tree_util.tree_map_with_path(\n            apply_rules(runner.model.partition_rules()),\n            shapes,\n        )\n\n        ds = P(\"data\")\n        ms = runner.model.model.get_memory_sharding()\n        self.sample_step = pjit.pjit(\n            sample_step_.apply,\n            in_shardings=(self.params_sharding, None, ds, ms, None),\n            
out_shardings=(None, ds, ms),\n            donate_argnums=3,\n        )\n        self.prefill_memory = pjit.pjit(\n            functools.partial(prefill_memory_.apply),\n            in_shardings=(\n                self.params_sharding,\n                None,\n                ms,\n                None,\n                ds,\n                None,\n                None,\n                None,\n                None,\n                None,\n            ),\n            out_shardings=(None, ds, ms, None),\n            donate_argnums=(2,),\n        )\n        self.new_memory = pjit.pjit(\n            new_memory_.apply,\n            static_argnums=(1, 2),\n            out_shardings=ms,\n        )\n\n    def run(self):\n        \"\"\"Generator that accepts prompts.\"\"\"\n        runner = self.runner\n        mesh = runner.mesh\n        max_len = runner.model.sequence_len\n        batch_size = runner.batch_size\n        params = self.params\n        rngs = jax.random.split(jax.random.PRNGKey(1), batch_size)\n        with mesh:\n            memory = self.new_memory(params, batch_size, max_len)\n            settings = SampleSettings(\n                temperature=np.zeros((batch_size,), dtype=np.float32),\n                nucleus_p=np.zeros((batch_size,), dtype=np.float32),\n                mask=np.ones((batch_size, self.vocab_size), dtype=np.int32),\n                active=np.zeros((batch_size), dtype=np.int32),\n            )\n            last_output = SampleOutput(\n                token_id=np.zeros((batch_size, 1), dtype=np.int32),\n                prob=np.zeros((batch_size, 1), dtype=jnp.bfloat16),\n                top_k_token_ids=np.zeros((batch_size, TOP_K), dtype=np.int32),\n                top_k_probs=np.zeros((batch_size, TOP_K), dtype=jnp.bfloat16),\n            )\n\n            prompt = np.array([300, 400, 500, 600, 600, 700, 800])\n\n            new_settings = SampleSettings(\n                temperature=np.float32(1),\n                nucleus_p=np.float32(1),\n    
            mask=np.ones((self.vocab_size,), dtype=np.int32),\n                active=np.zeros((), dtype=np.int32),\n            )\n            rng_seed = np.uint64(1)\n\n            for size in self.pad_sizes:\n                if size > runner.model.sequence_len:\n                    break\n                logger.info(\"Precompile {}\".format(size))\n                prompt_len = len(prompt)\n                prompt = pad_to_size(prompt, size)\n                rngs, last_output, memory, settings = self.prefill_memory(\n                    params,\n                    rngs,\n                    memory,\n                    settings,\n                    last_output,\n                    prompt,\n                    prompt_len,\n                    rng_seed,\n                    new_settings,\n                    0,\n                )\n        with runner.mesh:\n            logger.info(\"Compiling...\")\n            rngs, last_output, memory = self.sample_step(\n                params, rngs, last_output, memory, settings\n            )\n            logger.info(\"Done compiling.\")\n\n        all_tokens = []\n        free_slots = list(range(batch_size))\n        requests = [None] * batch_size\n        first_output = [None] * batch_size\n        jax.tree_map(lambda x: x.copy_to_host_async(), last_output)\n        prev_token = last_output\n        step = 0\n        total_num_tokens = 0\n        total_num_sequences = 0\n        with mesh:\n            while True:\n                while free_slots:\n                    request: Optional[Request] = yield\n                    tokens = self.tokenizer.encode(request.prompt)\n                    temperature = request.temperature\n                    nucleus_p = request.nucleus_p\n                    rng_seed = request.rng_seed\n\n                    i = free_slots.pop()\n                    prompt = np.array(tokens, dtype=np.int32)\n                    prompt_len = len(prompt)\n                    prompt = pad_to_size(prompt, 
self.get_pad_bucket(prompt.shape[0]))\n                    # All tokens are allowed.\n                    mask = np.ones((self.vocab_size,), dtype=np.int32)\n\n                    new_settings = SampleSettings(\n                        temperature=np.float32(temperature),\n                        nucleus_p=np.float32(nucleus_p),\n                        mask=mask,\n                        active=np.ones((), dtype=np.int32),\n                    )\n                    rng_seed = np.uint64(rng_seed)\n                    rngs, last_output, memory, settings = self.prefill_memory(\n                        params,\n                        rngs,\n                        memory,\n                        settings,\n                        last_output,\n                        prompt,\n                        prompt_len,\n                        rng_seed,\n                        new_settings,\n                        i,\n                    )\n                    jax.tree_map(lambda x: x.copy_to_host_async(), last_output)\n                    first_output[i] = last_output\n                    requests[i] = request\n                    total_num_sequences += 1\n\n                rngs, last_output, memory = self.sample_step(\n                    params, rngs, last_output, memory, settings\n                )\n                total_num_tokens += batch_size - len(free_slots)\n\n                # prev_token should already be on the host.\n                prev_token = jax.tree_map(np.array, prev_token)\n                for i in range(batch_size):\n                    if requests[i] is not None:\n                        if first_output[i] is not None:\n                            first_output_i = jax.tree_map(np.array, first_output[i])\n                            all_tokens.append(int(first_output_i.token_id[i][0]))\n                            first_output[i] = None\n                            continue\n\n                        
all_tokens.append(int(prev_token.token_id[i][0]))\n                        cont = len(all_tokens) < requests[i].max_len\n\n                        if not cont:\n                            output_str = self.tokenizer.decode(all_tokens)\n                            requests[i] = None\n                            free_slots.append(i)\n                            all_tokens = []\n                            settings = settings._replace(active=settings.active.at[i].set(0))\n                            yield output_str\n\n                jax.tree_map(lambda x: x.copy_to_host_async(), last_output)\n                prev_token = last_output\n                step += 1\n\n\ndef make_mesh(\n    local_mesh_config: tuple[int, ...], between_hosts_config: tuple[int, ...]\n) -> jax.sharding.Mesh:\n    assert len(local_mesh_config) == 2\n    assert len(between_hosts_config) == 2\n    rank_logger.info(\"Detected %s devices in mesh\", jax.device_count())\n    device_mesh = mesh_utils.create_hybrid_device_mesh(\n        local_mesh_config,\n        between_hosts_config,\n        devices=jax.devices(),\n        process_is_granule=True,\n    )\n    rank_logger.debug(re.sub(\"\\n+\", \"\\n\", f\"Job device mesh is:\\n{device_mesh}\"))\n    return jax.sharding.Mesh(device_mesh, (\"data\", \"model\"))\n\n\ndef sample_from_model(server, prompt, max_len, temperature):\n    next(server)\n    inp = Request(\n        prompt=prompt,\n        temperature=temperature,\n        nucleus_p=1.0,\n        rng_seed=42,\n        max_len=max_len,\n    )\n    return server.send(inp)\n"
  }
]