[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control\n.pdm.toml\n.pdm-python\n.pdm-build/\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n\n# Ruff\n.ruff_cache"
  },
  {
    "path": "Install.md",
    "content": "## Prerequisites\n\n-   **Python**: This project requires Python version 3.8 or higher. It has been tested on Python 3.10.\n\n### Environment Setup\n\n1. **Clone the Repository**\n\n    First, clone the repository to your local machine:\n\n    ```bash\n    git clone git@github.com:idoh/mamba.np.git\n    cd mamba.np\n    ```\n\n2. **Set Up Virtual Environment**\n\n    It is recommended to use a virtual environment to manage dependencies. You can set up a virtual environment using `venv`:\n\n    ```bash\n    python3 -m venv venv\n    ```\n\n3. **Activate Virtual Environment**\n\n    - On **Windows**:\n        ```bash\n        .\\venv\\Scripts\\activate\n        ```\n    - On **macOS/Linux**:\n        ```bash\n        source venv/bin/activate\n        ```\n\n4. **Install Required Packages**\n\n    Install the required packages using `pip`:\n\n    ```bash\n    pip install -r requirements.txt\n    ```\n\n5. **Install PyTorch**\n\n    Install PyTorch from the official [PyTorch website](https://pytorch.org/get-started/locally/). Choose the appropriate configuration for your system. For example, for a basic CPU-only installation on Windows/Mac, you can use:\n\n    ```bash\n    pip install torch\n    ```\n\n### Notes\n\n-   Ensure your Python version is at least 3.8. You can check your Python version by running:\n    ```bash\n    python --version\n    ```\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2024 idoh\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# mamba.np\n\n`mamba.np` is a pure NumPy implementation of Mamba.\n\n<p align=\"center\">\n  <img src=\"assets/mamba.jpeg\" width=\"300\" alt=\"mamba.np\">\n</p>\n\n## Installation\n\nPlease refer to [Install.md](./Install.md) guide\n\n## Usage\n\n```shell\n$ python mamba.py \"I have a dream that\"\n\"\"\"\nI have a dream that I will be able to see the sunrise in the morning.\n\nToken count: 18, elapsed: 18.57s, 1 tokens/s\n\"\"\"\n```\n\n## Citing mamba.np\n\nIf you use or discuss `mamba.np` in your academic research, please cite the project to help spread awareness:\n\n```\n@misc{mamba.np,\n  title = {mamba.np: pure NumPy implementation for Mamba},\n  author = {Ido Hakimi},\n  howpublished = {\\url{https://github.com/idoh/mamba.np}},\n  note = {mamba.np, MIT License}\n  year = {2024},\n}\n```\n\n# References\n\nThank you to the creators of the following libraries and tools and their contributors:\n\n-   [mamba-minimal](https://github.com/johnma2006/mamba-minimal) - @johnma2006\n-   [llama3.np](https://github.com/likejazz/llama3.np) - @likejazz\n-   The Mamba architecture was introduced in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by [Albert Gu](https://twitter.com/_albertgu?lang=en) and [Tri Dao](https://twitter.com/tri_dao?ref_src=twsrc%5Egoogle%7Ctwcamp%5Eserp%7Ctwgr%5Eauthor)\n-   The official implementation is here: https://github.com/state-spaces/mamba\n-   Title image was generated by [Microsoft Designer](https://designer.microsoft.com/)\n"
  },
  {
    "path": "mamba.py",
    "content": "\"\"\"Simple, minimal implementation of Mamba in one file of Numpy adapted from (1) and inspired from (2).\n\nSuggest reading the following before/while reading the code:\n    [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)\n        https://arxiv.org/abs/2312.00752\n    [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)\n        https://srush.github.io/annotated-s4\n\nGlossary:\n    b: batch size                       (`B` in Mamba paper [1] Algorithm 2)\n    l: sequence length                  (`L` in [1] Algorithm 2)\n    d or d_model: hidden dim\n    n or d_state: latent state dim      (`N` in [1] Algorithm 2)\n    expand: expansion factor            (`E` in [1] Section 3.4)\n    d_in or d_inner: d * expand         (`D` in [1] Algorithm 2)\n    A, B, C, D: state space parameters  (See any state space representation formula)\n                                        (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)\n    Δ or delta: input-dependent step size\n    dt_rank: rank of Δ                  (See [1] Section 3.6 \"Parameterization of ∆\")\n\nReferences:\n    (1) mamba-minimal (John Ma)\n        https://github.com/johnma2006/mamba-minimal/\n    (2) llama3.np (Sang Park)\n        https://github.com/likejazz/llama3.np\n\"\"\" # noqa: E501\n\nfrom __future__ import annotations\n\nimport json\nimport math\nimport sys\nimport time\nfrom collections.abc import Mapping\nfrom dataclasses import dataclass\nfrom typing import Union\n\nimport numpy as np\nfrom einops import einsum, rearrange\nfrom transformers import AutoTokenizer\n\n_MAX_NEW_TOKENS = 18\n\n\ndef load_model(pretrained_model_name: str) -> Mamba:\n    \"\"\"Load pretrained weights from HuggingFace into model.\n\n    Args:\n        pretrained_model_name: One of\n            * 'state-spaces/mamba-2.8b-slimpj'\n            * 'state-spaces/mamba-2.8b'\n            * 'state-spaces/mamba-1.4b'\n            * 'state-spaces/mamba-790m'\n            * 'state-spaces/mamba-370m'\n            * 'state-spaces/mamba-130m'\n\n    Returns:\n        model: Mamba model with weights loaded\n\n    \"\"\"\n    import torch\n    from transformers.utils import CONFIG_NAME, WEIGHTS_NAME\n    from transformers.utils.hub import cached_file\n\n    def load_config_hf(model_name):\n        resolved_archive_file = cached_file(\n            model_name,\n            CONFIG_NAME,\n            _raise_exceptions_for_missing_entries=False,\n        )\n        return json.load(open(resolved_archive_file))\n\n    def load_state_dict_hf(model_name):\n        resolved_archive_file = cached_file(\n            model_name,\n            WEIGHTS_NAME,\n            _raise_exceptions_for_missing_entries=False,\n        )\n        return torch.load(\n            resolved_archive_file,\n            weights_only=True,\n            map_location=\"cpu\",\n            mmap=True,\n        )\n\n    config_data = load_config_hf(pretrained_model_name)\n    args = ModelArgs(\n        d_model=config_data[\"d_model\"],\n        n_layer=config_data[\"n_layer\"],\n        vocab_size=config_data[\"vocab_size\"],\n    )\n    state_dict = load_state_dict_hf(pretrained_model_name)\n\n    weights = {}\n    for key in state_dict:\n        new_key = key.replace(\"backbone.\", \"\")\n        weights[new_key] = state_dict[key].numpy()\n\n    model = Mamba(weights, args)\n\n    return model\n\n\n@dataclass\nclass ModelArgs:\n    \"\"\"\n    Data class for storing model-specific arguments.\n\n    Args:\n  
      d_model (int): Model dimension.\n        n_layer (int): Number of layers.\n        vocab_size (int): Vocabulary size.\n        d_state (int, optional): State dimension (default: 16).\n        expand (int, optional): Expansion factor (default: 2).\n        dt_rank (Union[int, str], optional): Rank for Δ (default: \"auto\").\n        d_conv (int, optional): Convolution dimension (default: 4).\n        pad_vocab_size_multiple (int, optional): Padding vocabulary size multiple (default: 8).\n        conv_bias (bool, optional): Whether to use bias in convolution layers (default: True).\n        bias (bool, optional): Whether to use bias in linear layers (default: False).\n\n    Attributes:\n        d_inner (int): Inner dimension calculated as expand * d_model.\n\n    Notes:\n        - If dt_rank is set to \"auto\", it computes it as the ceiling of d_model / 16.\n        - Ensures that vocab_size is a multiple of pad_vocab_size_multiple.\n    \"\"\" # noqa: E501\n\n    d_model: int\n    n_layer: int\n    vocab_size: int\n    d_state: int = 16\n    expand: int = 2\n    dt_rank: Union[int, str] = \"auto\"\n    d_conv: int = 4\n    pad_vocab_size_multiple: int = 8\n    conv_bias: bool = True\n    bias: bool = False\n\n    def __post_init__(self):\n        self.d_inner = int(self.expand * self.d_model)\n        if self.dt_rank == \"auto\":\n            self.dt_rank = math.ceil(self.d_model / 16)\n        if self.vocab_size % self.pad_vocab_size_multiple != 0:\n            self.vocab_size += (\n                self.pad_vocab_size_multiple\n                - self.vocab_size % self.pad_vocab_size_multiple\n            )\n\n\nclass Mamba:\n    def __init__(self, weights: Mapping[str, np.ndarray], args: ModelArgs):\n        \"\"\"\n        Full Mamba model.\n\n        Args:\n            weights (Mapping[str, np.ndarray]): Pre-trained weights.\n            args (ModelArgs): Model-specific arguments.\n        \"\"\"\n        self.args = args\n        self.embedding = Embedding(weight=weights.get(\"embedding.weight\"))\n        self.layers = [\n            ResidualBlock(i, weights, args) for i in range(args.n_layer)\n        ]\n        self.norm_f = RMSNorm(weight=weights.get(\"norm_f.weight\"))\n\n        # Tie output projection to embedding weights.\n        # See \"Weight Tying\" paper\n        self.lm_head = Linear(weight=self.embedding.weight, bias=None)\n\n    def __call__(self, input_ids: np.ndarray) -> np.ndarray:\n        \"\"\"\n        Forward pass through the Mamba model.\n\n        Args:\n            input_ids (np.ndarray): shape (b, l), dtype long.\n\n        Returns:\n            np.ndarray: shape (b, l, vocab_size). 
The output logits tensor.\n\n        Official Implementation:\n            class MambaLMHeadModel, see https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L118\n        \"\"\"\n        x = self.embedding(input_ids)\n\n        for layer in self.layers:\n            x = layer(x)\n\n        x = self.norm_f(x)\n        logits = self.lm_head(x)\n\n        return logits\n\n    def generate(self, input_ids: np.ndarray, max_new_tokens: int):\n        _, L = input_ids.shape\n        for _ in range(L, max_new_tokens):\n            logits = self(input_ids)[:, -1]\n            next_id = np.argmax(logits, axis=-1, keepdims=True)\n            input_ids = np.concatenate([input_ids, next_id], axis=-1)\n            yield next_id\n\n\nclass ResidualBlock:\n    def __init__(\n        self, layer_id: int, weights: Mapping[str, np.ndarray], args: ModelArgs\n    ):\n        \"\"\"\n        Residual block for Mamba-based models.\n\n        Args:\n            layer_id (int): Identifier for the layer.\n            weights (Mapping[str, np.ndarray]): Pre-trained weights.\n            args (ModelArgs): Model-specific arguments.\n        \"\"\"\n        self.args = args\n        self.mixer = MambaBlock(\n            in_proj=Linear(\n                weight=weights.get(f\"layers.{layer_id}.mixer.in_proj.weight\"),\n                bias=None,\n            ),\n            conv1d=MambaConv1d(\n                weight=weights.get(f\"layers.{layer_id}.mixer.conv1d.weight\"),\n                bias=weights.get(f\"layers.{layer_id}.mixer.conv1d.bias\"),\n            ),\n            x_proj=Linear(\n                weight=weights.get(f\"layers.{layer_id}.mixer.x_proj.weight\"),\n                bias=None,\n            ),\n            dt_proj=Linear(\n                weight=weights.get(f\"layers.{layer_id}.mixer.dt_proj.weight\"),\n                bias=weights.get(f\"layers.{layer_id}.mixer.dt_proj.bias\"),\n            ),\n            A_log=weights.get(f\"layers.{layer_id}.mixer.A_log\"),\n            D=weights.get(f\"layers.{layer_id}.mixer.D\"),\n            out_proj=Linear(\n                weight=weights.get(f\"layers.{layer_id}.mixer.out_proj.weight\"),\n                bias=None,\n            ),\n            args=args,\n        )\n\n        self.norm = RMSNorm(\n            weight=weights.get(f\"layers.{layer_id}.norm.weight\")\n        )\n\n    def __call__(self, x: np.ndarray) -> np.ndarray:\n        \"\"\"\n        Forward pass through the residual block.\n\n        Args:\n            x (np.ndarray): shape (b, l, d).\n\n        Returns:\n            np.ndarray: shape (b, l, d).\n\n        Official Implementation:\n            Block.forward(), see https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L142\n\n            Note: The official repo chains residual blocks that look like\n                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...\n            where the first Add is a no-op. 
This is purely for performance reasons as this\n            allows them to fuse the Add->Norm.\n\n            We instead implement our blocks as the more familiar, simpler, and numerically equivalent\n                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ...\n\n        \"\"\" # noqa: E501\n        output = self.mixer(self.norm(x)) + x\n        return output\n\n\nclass MambaBlock:\n    def __init__(\n        self,\n        in_proj: Linear,\n        conv1d: MambaConv1d,\n        x_proj: Linear,\n        dt_proj: Linear,\n        A_log: np.ndarray,\n        D: np.ndarray,\n        out_proj: Linear,\n        args: ModelArgs,\n    ):\n        \"\"\"\n        A single Mamba block, as described in Figure 3 in Section 3.4 of the Mamba paper [1].\n\n        Args:\n            in_proj (Linear): shape (d, 2*d_in). Linear layer for input projection.\n            conv1d (MambaConv1d): shape (d_in, 1, d_conv). Mamba-specific 1D convolutional layer.\n            x_proj (Linear): shape (d_in, dt_rank+2*d_state). Linear layer for projecting input-specific Δ, B, and C.\n            dt_proj (Linear): shape (dt_rank, d_in). Linear layer for projecting Δ from dt_rank to d_in.\n            A_log (np.ndarray): shape (d_in, n). Matrix A_log.\n            D (np.ndarray): shape (d_in,). Vector D.\n            out_proj (Linear): shape (d_in, d). Linear layer for output projection.\n            args (ModelArgs): Model-specific arguments.\n        \"\"\" # noqa: E501\n        self.args = args\n        self.in_proj: Linear = in_proj\n        self.conv1d: MambaConv1d = conv1d\n        self.x_proj: Linear = x_proj\n        self.dt_proj: Linear = dt_proj\n        self.A_log: np.ndarray = A_log\n        self.D: np.ndarray = D\n        self.out_proj: Linear = out_proj\n\n    def __call__(self, x: np.ndarray) -> np.ndarray:\n        \"\"\"\n        Forward pass through the Mamba block.\n\n        Args:\n            x (np.ndarray): Input tensor of shape (b, l, d).\n\n        Returns:\n            np.ndarray: Output tensor of shape (b, l, d).\n        \"\"\"\n        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)\n        (x, res) = np.split(\n            x_and_res,\n            indices_or_sections=(self.args.d_inner, 2 * self.args.d_inner),\n            axis=-1,\n        )[:-1]\n\n        x = self.conv1d(x)\n        x = silu(x)\n\n        y = self.ssm(x)\n        y = y * silu(res)\n\n        output = self.out_proj(y)\n\n        return output\n\n    def ssm(self, x: np.ndarray) -> np.ndarray:\n        \"\"\"\n        Runs the SSM. 
See:\n            - Algorithm 2 in Section 3.2 in the Mamba paper [1].\n            - run_SSM(A, B, C, u) in The Annotated S4 [2]\n\n        Args:\n            x (np.ndarray): shape (b, l, d_in).\n\n        Returns:\n            np.ndarray: shape (b, l, d_in).\n\n        Official Implementation:\n            mamba_inner_ref(), see https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311\n\n        References:\n            [1] Mamba paper: https://arxiv.org/abs/2312.00752\n            [2] The Annotated S4: https://srush.github.io/annotated-s4\n        \"\"\" # noqa: E501\n        (d_in, n) = self.A_log.shape\n\n        # Compute ∆, A, B, C, D (state space parameters)\n        # A and D are input-independent (see Mamba paper [1], Section 3.5.2 for A's interpretation) # noqa: E501\n        # ∆, B, C are input-dependent (a key difference between Mamba and linear time-invariant S4) # noqa: E501\n\n        A = -np.exp(self.A_log.astype(float))  # shape (d_in, n)\n        D = self.D.astype(float)\n\n        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)\n        (delta, B, C) = np.split(\n            x_dbl,\n            indices_or_sections=(\n                self.args.dt_rank,\n                self.args.dt_rank + n,\n                self.args.dt_rank + 2 * n,\n            ),\n            axis=-1,\n        )[\n            :-1\n        ]  # delta: (b, l, dt_rank). B, C: (b, l, n)\n        delta = softplus(self.dt_proj(delta))  # (b, l, d_in)\n\n        y = self.selective_scan(\n            x, delta, A, B, C, D\n        )  # Similar to run_SSM(A, B, C, u) in The Annotated S4 [2]\n\n        return y\n\n    def selective_scan(\n        self,\n        u: np.ndarray,\n        delta: np.ndarray,\n        A: np.ndarray,\n        B: np.ndarray,\n        C: np.ndarray,\n        D: np.ndarray,\n    ) -> np.ndarray:\n        \"\"\"\n        Performs the selective scan algorithm as described in the Mamba paper [1].\n        This function computes the output based on input data and state space parameters.\n        See:\n            - Section 2 State Space Models in the Mamba paper [1]\n            - Algorithm 2 in Section 3.2 in the Mamba paper [1]\n            - run_SSM(A, B, C, u) in The Annotated S4 [2]\n\n        This is the classic discrete state space formula:\n            x(t + 1) = Ax(t) + Bu(t)\n            y(t)     = Cx(t) + Du(t)\n        except B and C (and the step size delta, which is used for discretization) are dependent on the input u(t).\n\n        Args:\n            u (np.ndarray): shape (b, l, d_in). Input tensor.\n            delta (np.ndarray): shape (b, l, d_in). Step size tensor.\n            A (np.ndarray): shape (d_in, n). Matrix A.\n            B (np.ndarray): shape (b, l, n). Tensor B.\n            C (np.ndarray): shape (b, l, n). Tensor C.\n            D (np.ndarray): shape (d_in,). 
Vector D.\n\n        Returns:\n            np.ndarray: Output tensor of shape (b, l, d_in).\n\n        Official Implementation:\n            selective_scan_ref(), see https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86\n            Note: Some parts have been refactored from `selective_scan_ref`, so the functionality may not match exactly.\n\n        References:\n            [1] Mamba paper: https://arxiv.org/abs/2312.00752\n            [2] The Annotated S4: https://srush.github.io/annotated-s4\n        \"\"\" # noqa: E501\n        (b, l, d_in) = u.shape\n        n = A.shape[1]\n\n        # Discretize continuous parameters (A, B)\n        deltaA = np.exp(einsum(delta, A, \"b l d_in, d_in n -> b l d_in n\"))\n        deltaB_u = einsum(\n            delta, B, u, \"b l d_in, b l n, b l d_in -> b l d_in n\"\n        )\n\n        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])\n        # Note that the below is sequential, while the official implementation does a much faster parallel scan that is additionally hardware-aware (like FlashAttention). # noqa: E501\n        x = np.zeros((b, d_in, n))\n        ys = []\n        for i in range(l):\n            x = deltaA[:, i] * x + deltaB_u[:, i]\n            y = einsum(x, C[:, i, :], \"b d_in n, b n -> b d_in\")\n            ys.append(y)\n\n        y = np.stack(ys, axis=1)  # shape (b, l, d_in)\n\n        y = y + u * D\n\n        return y\n\n\nclass MambaConv1d:\n    \"\"\"\n    A 1-dimensional convolution layer with pre-defined weights and optional bias that is applied on each channel separately.\n\n    Args:\n        weight (np.ndarray): The weight tensor for the convolution layer.\n        bias (np.ndarray or None): The bias tensor (optional). If None, no bias is applied.\n    \"\"\" # noqa: E501\n\n    def __init__(self, weight: np.ndarray, bias: np.ndarray):\n        self.weight = weight\n        self.bias = bias\n\n    def __call__(self, x: np.ndarray) -> np.ndarray:\n        \"\"\"\n        Applies 1D convolution to the input tensor `x`.\n\n        Args:\n            x (np.ndarray): Input tensor with shape (batch_size, sequence_length, in_channels).\n\n        Returns:\n            np.ndarray: Output tensor after 1D convolution.\n        \"\"\" # noqa: E501\n        x = rearrange(x, \"b l d_in -> b d_in l\")\n\n        batch_size, in_channels, length = x.shape\n        out_channels, _, kernel_size = self.weight.shape\n\n        assert in_channels == out_channels\n\n        output_length = length + kernel_size - 1\n        output_tensor = np.empty((batch_size, in_channels, output_length))\n\n        for b in range(batch_size):\n            for i in range(in_channels):\n                output_tensor[b, i, :] = np.convolve(\n                    x[b, i, :], self.weight[i, 0, ::-1], mode=\"full\"\n                )\n                if self.bias is not None:\n                    output_tensor[b, i, :] += self.bias[i]\n\n        output_tensor = output_tensor[:, :, :length]\n        output_tensor = rearrange(output_tensor, \"b d_in l -> b l d_in\")\n\n        return output_tensor\n\n\nclass Linear:\n    \"\"\"\n    Represents a linear transformation layer.\n\n    Args:\n        weight (np.ndarray): The weight matrix for the linear transformation.\n        bias (np.ndarray or None): The bias vector (optional). 
If None, no bias is applied.\n    \"\"\" # noqa: E501\n\n    def __init__(self, weight: np.ndarray, bias: np.ndarray):\n        self.weight = weight\n        self.bias = bias\n\n    def __call__(self, x: np.ndarray) -> np.ndarray:\n        \"\"\"\n        Applies the linear transformation to the input tensor `x`.\n\n        Args:\n            x (np.ndarray): Input tensor.\n\n        Returns:\n            np.ndarray: Output tensor after linear transformation.\n        \"\"\"\n        output_tensor = x @ self.weight.T\n        if self.bias is not None:\n            output_tensor += self.bias\n\n        return output_tensor\n\n\nclass Embedding:\n    \"\"\"\n    Represents an embedding layer with pre-defined weights.\n\n    Args:\n        weight (np.ndarray): The weight matrix for the embedding layer.\n    \"\"\"\n\n    def __init__(self, weight: np.ndarray):\n        self.weight = weight\n\n    def __call__(self, x: np.ndarray) -> np.ndarray:\n        \"\"\"\n        Returns the embedding vectors for the given indices `x`.\n\n        Args:\n            x (np.ndarray): Indices of the desired embeddings, shape: (b, l).\n\n        Returns:\n            np.ndarray: The embedding vectors corresponding to the input indices.\n        \"\"\" # noqa: E501\n        return self.weight[x]\n\n\nclass RMSNorm:\n    def __init__(self, weight: np.ndarray, eps: float = 1e-5):\n        \"\"\"\n        Initializes an instance of the RMSNorm class.\n\n        Args:\n            weight (np.ndarray): Weight vector for normalization.\n            eps (float, optional): Small constant to prevent division by zero. Defaults to 1e-5.\n        \"\"\" # noqa: E501\n        self.weight = weight\n        self.eps = eps\n\n    def __call__(self, x: np.ndarray):\n        \"\"\"\n        Applies RMS normalization to the input tensor.\n\n        Args:\n            x (np.ndarray): Input tensor.\n\n        Returns:\n            np.ndarray: Normalized tensor.\n        \"\"\"\n        # Compute the root mean square along the last dimension\n        rms = np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + self.eps)\n        return x * self.weight / rms\n\n\ndef silu(x: np.ndarray) -> np.ndarray:\n    \"\"\"\n    Applies the Sigmoid Linear Unit (SiLU) activation function to the input tensor.\n\n    Args:\n        x (np.ndarray): Input tensor.\n\n    Returns:\n        np.ndarray: Output tensor after applying SiLU.\n    \"\"\" # noqa: E501\n    return x / (1 + np.exp(-x))\n\n\ndef softplus(x: np.ndarray) -> np.ndarray:\n    \"\"\"\n    Applies the Softplus activation function to the input tensor.\n\n    Args:\n        x (np.ndarray): Input tensor.\n\n    Returns:\n        np.ndarray: Output tensor after applying Softplus.\n    \"\"\"\n    return np.log(1 + np.exp(x))\n\n\ndef generate_text(model: Mamba, tokenizer: AutoTokenizer, prompt: str) -> None:\n    \"\"\"\n    Generates text using a pre-trained language model.\n\n    Args:\n        model (Mamba): The pre-trained language model.\n        tokenizer (AutoTokenizer): The tokenizer for encoding input prompts.\n        prompt (str): Input prompt for text generation.\n    \"\"\"\n\n    # Print the input prompt\n    print(f\"\\n{prompt}\", end=\"\")\n\n    # Encode the input prompt and initialize token count\n    input_ids = np.array([tokenizer.encode(prompt)])\n    _, L = input_ids.shape\n\n    # Generate text\n    start = time.time()\n    for id in model.generate(input_ids, max_new_tokens=_MAX_NEW_TOKENS):\n        L += 1\n        output_id = id[0].tolist()\n        if output_id[-1] in 
[tokenizer.eos_token_id, tokenizer.bos_token_id]:\n            break\n        print(tokenizer.decode(output_id), end=\"\")\n        sys.stdout.flush()\n\n    # Calculate elapsed time and tokens per second\n    elapsed = time.time() - start\n    print(\n        f\"\\n\\nToken count: {L}, elapsed: {elapsed:.2f}s, \"\n        f\"{round(L / elapsed)} tokens/s\"\n    )\n\n\nif __name__ == \"__main__\":\n    # Read input prompt\n    prompt = \"I have a dream that\" if len(sys.argv) == 1 else sys.argv[1]\n\n    # Load the pre-trained language model and tokenizer\n    model = load_model(\"state-spaces/mamba-130m\")\n    tokenizer = AutoTokenizer.from_pretrained(\"EleutherAI/gpt-neox-20b\")\n\n    generate_text(model=model, tokenizer=tokenizer, prompt=prompt)\n"
  },
  {
    "path": "requirements.txt",
    "content": "einops\ntransformers\nruff\ntorch"
  },
  {
    "path": "ruff.toml",
    "content": "line-length = 79\n\n[lint]\nselect = [\"E501\", \"I\"]\n\n[format]\ndocstring-code-format = true\ndocstring-code-line-length = 200"
  }
]