Full Code of ml-explore/mlx-examples for AI

main e52c128d113f cached
216 files
4.5 MB
1.2M tokens
1221 symbols
1 requests
Download .txt
Showing preview only (4,794K chars total). Download the full file or copy to clipboard to get everything.
Repository: ml-explore/mlx-examples
Branch: main
Commit: e52c128d113f
Files: 216
Total size: 4.5 MB

Directory structure:
gitextract_ur1ntgnw/

├── .github/
│   └── workflows/
│       └── pull_request.yml
├── .gitignore
├── .pre-commit-config.yaml
├── ACKNOWLEDGMENTS.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── bert/
│   ├── README.md
│   ├── convert.py
│   ├── model.py
│   ├── requirements.txt
│   ├── test.py
│   └── weights/
│       └── .gitignore
├── cifar/
│   ├── README.md
│   ├── dataset.py
│   ├── main.py
│   ├── requirements.txt
│   └── resnet.py
├── clip/
│   ├── .gitignore
│   ├── README.md
│   ├── clip.py
│   ├── convert.py
│   ├── hf_preproc.py
│   ├── image_processor.py
│   ├── linear_probe.py
│   ├── model.py
│   ├── requirements.txt
│   ├── test.py
│   └── tokenizer.py
├── cvae/
│   ├── .gitignore
│   ├── README.md
│   ├── dataset.py
│   ├── main.py
│   ├── requirements.txt
│   └── vae.py
├── encodec/
│   ├── README.md
│   ├── benchmarks/
│   │   ├── bench_mx.py
│   │   └── bench_pt.py
│   ├── convert.py
│   ├── encodec.py
│   ├── example.py
│   ├── requirements.txt
│   ├── test.py
│   └── utils.py
├── flux/
│   ├── README.md
│   ├── dreambooth.py
│   ├── flux/
│   │   ├── __init__.py
│   │   ├── autoencoder.py
│   │   ├── clip.py
│   │   ├── datasets.py
│   │   ├── flux.py
│   │   ├── layers.py
│   │   ├── lora.py
│   │   ├── model.py
│   │   ├── sampler.py
│   │   ├── t5.py
│   │   ├── tokenizers.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── generate_interactive.py
│   ├── requirements.txt
│   └── txt2image.py
├── gcn/
│   ├── .gitignore
│   ├── README.md
│   ├── datasets.py
│   ├── gcn.py
│   ├── main.py
│   └── requirements.txt
├── llava/
│   ├── .gitignore
│   ├── README.md
│   ├── generate.py
│   ├── language.py
│   ├── llava.py
│   ├── requirements.txt
│   ├── test.py
│   └── vision.py
├── llms/
│   ├── README.md
│   ├── gguf_llm/
│   │   ├── README.md
│   │   ├── generate.py
│   │   ├── models.py
│   │   ├── requirements.txt
│   │   └── utils.py
│   ├── llama/
│   │   ├── README.md
│   │   ├── convert.py
│   │   ├── llama.py
│   │   ├── requirements.txt
│   │   └── sample_prompt.txt
│   ├── mistral/
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── convert.py
│   │   ├── mistral.py
│   │   ├── requirements.txt
│   │   └── test.py
│   ├── mixtral/
│   │   ├── README.md
│   │   ├── convert.py
│   │   ├── mixtral.py
│   │   ├── params.json
│   │   └── requirements.txt
│   └── speculative_decoding/
│       ├── README.md
│       ├── convert.py
│       ├── decoder.py
│       ├── main.py
│       ├── model.py
│       └── requirements.txt
├── lora/
│   ├── .gitignore
│   ├── README.md
│   ├── convert.py
│   ├── data/
│   │   ├── test.jsonl
│   │   ├── train.jsonl
│   │   ├── valid.jsonl
│   │   └── wikisql.py
│   ├── fuse.py
│   ├── lora.py
│   ├── models.py
│   ├── requirements.txt
│   └── utils.py
├── mnist/
│   ├── README.md
│   ├── main.py
│   ├── mnist.py
│   └── requirements.txt
├── musicgen/
│   ├── README.md
│   ├── benchmarks/
│   │   ├── bench_mx.py
│   │   └── bench_pt.py
│   ├── generate.py
│   ├── musicgen.py
│   ├── requirements.txt
│   └── utils.py
├── normalizing_flow/
│   ├── README.md
│   ├── bijectors.py
│   ├── distributions.py
│   ├── flows.py
│   ├── main.py
│   └── requirements.txt
├── segment_anything/
│   ├── README.md
│   ├── convert.py
│   ├── main.py
│   ├── notebooks/
│   │   ├── automatic_mask_generator_example.ipynb
│   │   └── predictor_example.ipynb
│   ├── requirements.txt
│   └── segment_anything/
│       ├── __init__.py
│       ├── automatic_mask_generator.py
│       ├── common.py
│       ├── image_encoder.py
│       ├── mask_decoder.py
│       ├── predictor.py
│       ├── prompt_encoder.py
│       ├── sam.py
│       ├── transformer.py
│       └── utils/
│           ├── __init__.py
│           ├── amg.py
│           └── transforms.py
├── speechcommands/
│   ├── README.md
│   ├── kwt.py
│   ├── main.py
│   └── requirements.txt
├── stable_diffusion/
│   ├── README.md
│   ├── image2image.py
│   ├── requirements.txt
│   ├── stable_diffusion/
│   │   ├── __init__.py
│   │   ├── clip.py
│   │   ├── config.py
│   │   ├── model_io.py
│   │   ├── sampler.py
│   │   ├── tokenizer.py
│   │   ├── unet.py
│   │   └── vae.py
│   └── txt2image.py
├── t5/
│   ├── .gitignore
│   ├── README.md
│   ├── hf_t5.py
│   ├── requirements.txt
│   └── t5.py
├── transformer_lm/
│   ├── README.md
│   ├── datasets.py
│   ├── main.py
│   └── requirements.txt
├── whisper/
│   ├── MANIFEST.in
│   ├── README.md
│   ├── benchmark.py
│   ├── convert.py
│   ├── mlx_whisper/
│   │   ├── __init__.py
│   │   ├── _version.py
│   │   ├── assets/
│   │   │   ├── download_alice.sh
│   │   │   ├── gpt2.tiktoken
│   │   │   ├── ls_test.flac
│   │   │   ├── mel_filters.npz
│   │   │   └── multilingual.tiktoken
│   │   ├── audio.py
│   │   ├── cli.py
│   │   ├── decoding.py
│   │   ├── load_models.py
│   │   ├── requirements.txt
│   │   ├── timing.py
│   │   ├── tokenizer.py
│   │   ├── torch_whisper.py
│   │   ├── transcribe.py
│   │   ├── whisper.py
│   │   └── writers.py
│   ├── setup.py
│   └── test.py
└── wwdc25/
    ├── Explore_language_models_on_Apple_silicon_with_MLX.ipynb
    ├── Get_started_with_MLX_for_Apple_silicon.ipynb
    ├── README.md
    ├── WWDC25MLXSwiftExamples/
    │   ├── WWDC25MLXSwiftExamples/
    │   │   ├── SimpleMLXLM.swift
    │   │   ├── SimpleMLXLMWithKVCache.swift
    │   │   └── main.swift
    │   └── WWDC25MLXSwiftExamples.xcodeproj/
    │       ├── project.pbxproj
    │       ├── project.xcworkspace/
    │       │   ├── contents.xcworkspacedata
    │       │   ├── xcshareddata/
    │       │   │   └── swiftpm/
    │       │   │       └── Package.resolved
    │       │   └── xcuserdata/
    │       │       └── shashankprasanna.xcuserdatad/
    │       │           └── UserInterfaceState.xcuserstate
    │       └── xcuserdata/
    │           └── shashankprasanna.xcuserdatad/
    │               └── xcschemes/
    │                   └── xcschememanagement.plist
    ├── data/
    │   ├── all.jsonl
    │   ├── train.jsonl
    │   └── valid.jsonl
    └── requirements.txt

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/pull_request.yml
================================================
name: Test

on:
  push:
    branches: ["main"]
  pull_request:

permissions:
  contents: read

# One run per workflow + ref. Superseded runs are cancelled, except on main,
# where every push is allowed to finish. Branch refs have the form
# "refs/heads/<branch>" (note the plural "heads").
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

jobs:
  check_lint:
    # Only run in the upstream repository (skipped on forks).
    if: github.repository == 'ml-explore/mlx-examples'
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v5
      - uses: actions/setup-python@v6
        with:
          python-version: "3.10"
      - uses: pre-commit/action@v3.0.1



================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Vim
*.swp

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# IDE files
.idea/
.vscode/

# .DS_Store files
.DS_Store


================================================
FILE: .pre-commit-config.yaml
================================================
repos:
-   repo: https://github.com/psf/black-pre-commit-mirror
    rev: 25.1.0
    hooks:
    -   id: black
-   repo: https://github.com/pycqa/isort
    rev: 6.0.0
    hooks:
    -   id: isort
        args:
            - --profile=black


================================================
FILE: ACKNOWLEDGMENTS.md
================================================
# Individual Contributors

If you wish to be acknowledged for your contributions, please list your name
with a short description of your contribution(s) below. For example:

- Jane Smith: Added the `foo` example.

MLX Examples was developed with contributions from the following individuals:

- Juarez Bochi: Added support for T5 models.
- Sarthak Yadav: Added the `cifar` and `speechcommands` examples.
- Shunta Saito: Added support for PLaMo models.
- Gabrijel Boduljak: Implemented `CLIP`.
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1`, `OLMoE` architectures and support for `full-fine-tuning`.

================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
  community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or advances of
  any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
  without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
[opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com).
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series of
actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within the
community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].

Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].

For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].

[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations


================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to mlx-examples

We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests

1. Fork and submit pull requests to the repo. 
2. If you've added code that should be tested, add tests.
3. Every PR should have passing tests and at least one review. 
4. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
   This installs the hooks configured in `.pre-commit-config.yaml` (currently
   `black` and `isort`) to ensure consistent style for Python code.
 
   You can also run the formatters manually as follows on individual files:
 
     ```bash
     clang-format -i file.cpp
     ```
 
     ```bash
     black file.py
     ```

     or,

     ```bash
     # single file
     pre-commit run --files file1.py 

     # specific files
     pre-commit run --files file1.py file2.py
     ```
 
   or run `pre-commit run --all-files` to check all files in the repo.

## Issues

We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

## License

By contributing to mlx-examples, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.


================================================
FILE: LICENSE
================================================
MIT License

Copyright © 2023 Apple Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# MLX Examples

This repo contains a variety of standalone examples using the [MLX
framework](https://github.com/ml-explore/mlx).

The [MNIST](mnist) example is a good starting point to learn how to use MLX.
Some more useful examples are listed below. Check out [MLX
LM](https://github.com/ml-explore/mlx-lm) for a more fully featured Python
package for LLMs with MLX.

### Text Models 

- [Transformer language model](transformer_lm) training.
- Minimal examples of large scale text generation with [LLaMA](llms/llama),
  [Mistral](llms/mistral), and more in the [LLMs](llms) directory.
- A mixture-of-experts (MoE) language model with [Mixtral 8x7B](llms/mixtral).
- Parameter efficient fine-tuning with [LoRA or QLoRA](lora).
- Text-to-text multi-task Transformers with [T5](t5).
- Bidirectional language understanding with [BERT](bert).

### Image Models 

- Generating images
  - [FLUX](flux)
  - [Stable Diffusion or SDXL](stable_diffusion)
- Image classification using [ResNets on CIFAR-10](cifar).
- Convolutional variational autoencoder [(CVAE) on MNIST](cvae).

### Audio Models

- Speech recognition with [OpenAI's Whisper](whisper).
- Audio compression and generation with [Meta's EnCodec](encodec).
- Music generation with [Meta's MusicGen](musicgen).

### Multimodal models

- Joint text and image embeddings with [CLIP](clip).
- Text generation from image and text inputs with [LLaVA](llava).
- Image segmentation with [Segment Anything (SAM)](segment_anything).

### Other Models 

- Semi-supervised learning on graph-structured data with [GCN](gcn).
- Real NVP [normalizing flow](normalizing_flow) for density estimation and
  sampling.

### Hugging Face

You can directly use or download converted checkpoints from the [MLX
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
We encourage you to join the community and [contribute new
models](https://github.com/ml-explore/mlx-examples/issues/155).

## Contributing 

We are grateful for all of [our
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX Examples and wish to be acknowledged, please add your name to the list in your
pull request.

## Citing MLX Examples

The MLX software suite was initially developed with equal contribution by Awni
Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX Examples useful in your research and wish to cite it, please use the following
BibTeX entry:

```
@software{mlx2023,
  author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
  title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
  url = {https://github.com/ml-explore},
  version = {0.0},
  year = {2023},
}
```


================================================
FILE: bert/README.md
================================================
# BERT

An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) in MLX.

## Setup 

Install the requirements:

```
pip install -r requirements.txt
```

Then convert the weights with:

```
python convert.py \
    --bert-model bert-base-uncased \
    --mlx-model weights/bert-base-uncased.npz
```

## Usage

To use the `Bert` model in your own code, you can load it with:

```python
import mlx.core as mx
from model import Bert, load_model

model, tokenizer = load_model(
    "bert-base-uncased",
    "weights/bert-base-uncased.npz")

batch = ["This is an example of BERT working on MLX."]
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

output, pooled = model(**tokens)
```

The `output` contains a `Batch x Tokens x Dims` tensor, representing a vector
for every input token. If you want to train anything at the **token-level**,
use this.

The `pooled` contains a `Batch x Dims` tensor, which is the pooled
representation for each input. If you want to train a **classification**
model, use this.


## Test

You can check that the output of the default model (`bert-base-uncased`) matches the
Hugging Face version with:

```
python test.py
```


================================================
FILE: bert/convert.py
================================================
import argparse

import numpy
from transformers import AutoModel


def replace_key(key: str) -> str:
    """Translate a Hugging Face BERT parameter name into the MLX ``Bert`` layout.

    The substitutions are applied one after another in a fixed order. The
    attention-specific patterns (``.attention.output.dense.`` and
    ``.attention.output.LayerNorm.``) must be rewritten before the generic
    ``.output.dense.`` / ``.output.LayerNorm.`` patterns, which would
    otherwise also match them.
    """
    substitutions = (
        (".layer.", ".layers."),
        (".self.key.", ".key_proj."),
        (".self.query.", ".query_proj."),
        (".self.value.", ".value_proj."),
        (".attention.output.dense.", ".attention.out_proj."),
        (".attention.output.LayerNorm.", ".ln1."),
        (".output.LayerNorm.", ".ln2."),
        (".intermediate.dense.", ".linear1."),
        (".output.dense.", ".linear2."),
        (".LayerNorm.", ".norm."),
        ("pooler.dense.", "pooler."),
    )
    for pattern, replacement in substitutions:
        key = key.replace(pattern, replacement)
    return key


def convert(bert_model: str, mlx_model: str) -> None:
    """Load a Hugging Face BERT checkpoint and save its weights as an ``.npz``.

    Every parameter name is passed through ``replace_key`` so that it matches
    the attribute layout of the MLX ``Bert`` model.

    Args:
        bert_model: Hugging Face model identifier passed to ``AutoModel``.
        mlx_model: Destination path handed to ``numpy.savez``.
    """
    state_dict = AutoModel.from_pretrained(bert_model).state_dict()
    arrays = {}
    for name, tensor in state_dict.items():
        arrays[replace_key(name)] = tensor.numpy()
    numpy.savez(mlx_model, **arrays)


if __name__ == "__main__":
    # CLI entry point: read the source checkpoint name and output path, then convert.
    cli = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
    cli.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The huggingface name of the BERT model to save. Any BERT-like model can be specified.",
    )
    cli.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The output path for the MLX BERT weights.",
    )
    cli_args = cli.parse_args()

    convert(bert_model=cli_args.bert_model, mlx_model=cli_args.mlx_model)


================================================
FILE: bert/model.py
================================================
import argparse
from pathlib import Path
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase


class TransformerEncoderLayer(nn.Module):
    """
    A single BERT encoder layer using (the original BERT) post-normalization:
    each sub-layer's output is added to its input and then layer-normalized.

    Args:
        dims: Model (hidden) dimension.
        num_heads: Number of attention heads.
        mlp_dims: Hidden size of the feed-forward block; falls back to
            ``4 * dims`` when not given.
        layer_norm_eps: Epsilon used by both LayerNorms.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        layer_norm_eps: float = 1e-12,
    ):
        super().__init__()
        hidden_dims = mlp_dims if mlp_dims else 4 * dims
        # Attribute names must line up with the keys produced by replace_key in
        # convert.py (ln1, ln2, linear1, linear2, attention.*_proj).
        self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
        self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
        self.linear1 = nn.Linear(dims, hidden_dims)
        self.linear2 = nn.Linear(hidden_dims, dims)
        self.gelu = nn.GELU()

    def __call__(self, x, mask):
        # Self-attention sub-layer: residual add, then normalize.
        h = self.ln1(x + self.attention(x, x, x, mask))

        # Feed-forward sub-layer: residual add, then normalize.
        mlp_out = self.linear2(self.gelu(self.linear1(h)))
        return self.ln2(mlp_out + h)


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` post-norm encoder layers applied in sequence.

    Args:
        num_layers: Number of encoder layers.
        dims: Model (hidden) dimension.
        num_heads: Number of attention heads per layer.
        mlp_dims: Feed-forward hidden size (``TransformerEncoderLayer`` falls
            back to ``4 * dims`` when ``None``).
        layer_norm_eps: Epsilon forwarded to every layer's LayerNorms; the
            default matches ``TransformerEncoderLayer``'s own default.
    """

    def __init__(
        self,
        num_layers: int,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        layer_norm_eps: float = 1e-12,
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims, layer_norm_eps)
            for _ in range(num_layers)
        ]

    def __call__(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)

        return x


class BertEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings, followed by LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size
        )
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def __call__(
        self, input_ids: mx.array, token_type_ids: Optional[mx.array] = None
    ) -> mx.array:
        words = self.word_embeddings(input_ids)
        # Position ids 0..seq_len-1 broadcast to input_ids' shape
        # (assumes input_ids is (batch, seq)).
        position = self.position_embeddings(
            mx.broadcast_to(mx.arange(input_ids.shape[1]), input_ids.shape)
        )

        if token_type_ids is None:
            # If token_type_ids is not provided, default to zeros
            token_type_ids = mx.zeros_like(input_ids)

        token_types = self.token_type_embeddings(token_type_ids)

        embeddings = position + words + token_types
        return self.norm(embeddings)


class Bert(nn.Module):
    """BERT encoder: embeddings, a transformer stack, and a tanh pooler.

    ``config`` is a Hugging Face config object. The fields read here are
    ``num_hidden_layers``, ``hidden_size``, ``num_attention_heads``,
    ``intermediate_size`` and ``layer_norm_eps``, plus those read by
    ``BertEmbeddings``.
    """

    def __init__(self, config):
        super().__init__()
        self.embeddings = BertEmbeddings(config)
        self.encoder = TransformerEncoder(
            num_layers=config.num_hidden_layers,
            dims=config.hidden_size,
            num_heads=config.num_attention_heads,
            mlp_dims=config.intermediate_size,
            # Use the checkpoint's epsilon (as BertEmbeddings does) instead of
            # the encoder's hard default, so BERT-like configs with a different
            # layer_norm_eps are reproduced faithfully.
            layer_norm_eps=config.layer_norm_eps,
        )
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)

    def __call__(
        self,
        input_ids: mx.array,
        token_type_ids: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        """Run the encoder.

        Args:
            input_ids: Token ids (assumed (batch, seq) — TODO confirm for callers).
            token_type_ids: Optional segment ids; zeros when omitted.
            attention_mask: Optional 0/1 mask, 1 for tokens to attend to.

        Returns:
            ``(hidden_states, pooled)``, where ``pooled`` is
            ``tanh(pooler(hidden_states[:, 0]))`` — the first token's state.
        """
        x = self.embeddings(input_ids, token_type_ids)

        if attention_mask is not None:
            # convert 0's to -infs, 1's to 0's, and make it broadcastable
            # (inserting two singleton axes at positions 1 and 2).
            attention_mask = mx.log(attention_mask)
            attention_mask = mx.expand_dims(attention_mask, (1, 2))

        y = self.encoder(x, attention_mask)
        return y, mx.tanh(self.pooler(y[:, 0]))


def load_model(
    bert_model: str, weights_path: str
) -> Tuple[Bert, PreTrainedTokenizerBase]:
    """Build a ``Bert`` from a Hugging Face config and load converted MLX weights.

    Args:
        bert_model: Hugging Face identifier used for both config and tokenizer.
        weights_path: Path to the converted MLX weights file.

    Returns:
        The weight-loaded model and its matching tokenizer.

    Raises:
        ValueError: If ``weights_path`` does not exist.
    """
    weights = Path(weights_path)
    if not weights.exists():
        raise ValueError(f"No model weights found in {weights_path}")

    model = Bert(AutoConfig.from_pretrained(bert_model))
    model.load_weights(weights_path)

    return model, AutoTokenizer.from_pretrained(bert_model)


def run(bert_model: str, mlx_model: str, batch: List[str]):
    """Tokenize ``batch`` (padded) and run it through the MLX model.

    Returns:
        Whatever ``Bert.__call__`` returns: ``(hidden_states, pooled)``.
    """
    model, tokenizer = load_model(bert_model, mlx_model)

    encoded = tokenizer(batch, return_tensors="np", padding=True)
    inputs = {}
    for name, value in encoded.items():
        inputs[name] = mx.array(value)

    return model(**inputs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the BERT model using MLX.")
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The huggingface name of the BERT model to load.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path of the stored MLX BERT weights (npz file).",
    )
    parser.add_argument(
        "--text",
        type=str,
        default="This is an example of BERT working in MLX",
        help="The text to generate embeddings for.",
    )
    args = parser.parse_args()
    # run() takes a batch (List[str]), so wrap the single --text value.
    output, pooled = run(args.bert_model, args.mlx_model, [args.text])
    # MLX evaluates lazily; force the computation so the run actually happens
    # before reporting the result.
    mx.eval(output, pooled)
    print(f"Output shape: {output.shape}")
    print(f"Pooled shape: {pooled.shape}")


================================================
FILE: bert/requirements.txt
================================================
mlx>=0.0.5
transformers
numpy


================================================
FILE: bert/test.py
================================================
import argparse
from typing import List

import model
import numpy as np
from transformers import AutoModel, AutoTokenizer


def run_torch(bert_model: str, batch: List[str]):
    """Run the reference PyTorch model on ``batch``.

    Returns ``(last_hidden_state, pooler_output)`` as NumPy arrays.
    """
    tokenizer = AutoTokenizer.from_pretrained(bert_model)
    reference = AutoModel.from_pretrained(bert_model)
    outputs = reference(**tokenizer(batch, return_tensors="pt", padding=True))
    hidden = outputs.last_hidden_state.detach().numpy()
    pooled = outputs.pooler_output.detach().numpy()
    return hidden, pooled


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run a BERT-like model for a batch of text."
    )
    parser.add_argument(
        "--bert-model",
        type=str,
        default="bert-base-uncased",
        help="The model identifier for a BERT-like model from Hugging Face Transformers.",
    )
    parser.add_argument(
        "--mlx-model",
        type=str,
        default="weights/bert-base-uncased.npz",
        help="The path of the stored MLX BERT weights (npz file).",
    )
    parser.add_argument(
        "--text",
        nargs="+",
        default=["This is an example of BERT working in MLX."],
        help="A batch of texts to process. Multiple texts should be separated by spaces.",
    )

    args = parser.parse_args()

    torch_output, torch_pooled = run_torch(args.bert_model, args.text)

    mlx_output, mlx_pooled = model.run(args.bert_model, args.mlx_model, args.text)

    # The last hidden state is always produced, so always compare it; only the
    # pooled outputs may legitimately be missing.
    assert np.allclose(
        torch_output, np.array(mlx_output), rtol=1e-4, atol=1e-5
    ), "Model output is different"

    if torch_pooled is not None and mlx_pooled is not None:
        assert np.allclose(
            torch_pooled, np.array(mlx_pooled), rtol=1e-4, atol=1e-5
        ), "Model pooled output is different"
    else:
        print("Pooled outputs were not compared due to one or both being None.")
    print("Tests pass :)")


================================================
FILE: bert/weights/.gitignore
================================================
*.npz

================================================
FILE: cifar/README.md
================================================
# CIFAR and ResNets

An example of training a ResNet on CIFAR-10 with MLX. Several ResNet
configurations in accordance with the original
[paper](https://arxiv.org/abs/1512.03385) are available. The example also
illustrates how to use [MLX Data](https://github.com/ml-explore/mlx-data) to
load the dataset.

## Prerequisites

Install the dependencies:

```
pip install -r requirements.txt
```

## Running the example

Run the example with:

```
python main.py
```

By default, the example runs on the GPU. To run on the CPU, use:

```
python main.py --cpu
```

For all available options, run:

```
python main.py --help
```

## Results

After training with the default `resnet20` architecture for 30 epochs, you
should see the following results:

```
Epoch: 29 | avg. Train loss 0.294 | avg. Train acc 0.897 | Throughput: 270.81 images/sec
Epoch: 29 | Test acc 0.841
```

Note: this was run on an M1 MacBook Pro with 16GB of RAM.

At the time of writing, `mlx` doesn't have built-in learning rate schedules.
We intend to update this example once these features are added.

## Distributed training

The example also supports distributed data parallel training. You can launch a
distributed training as follows:

```shell
$ cat >hostfile.json
[
    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
]
$ mlx.launch --verbose --hostfile hostfile.json main.py --batch_size 256 --epochs 5 --arch resnet20
```


================================================
FILE: cifar/dataset.py
================================================
import mlx.core as mx
import numpy as np
from mlx.data.datasets import load_cifar10


def get_cifar10(batch_size, root=None):
    """Build the CIFAR-10 training and test streams.

    In a distributed run (``mx.distributed`` group size > 1), each rank only
    sees its own partition of both splits.

    The training stream is:
    - shuffled and randomly flipped horizontally;
    - padded by 4 pixels on image axes 0 and 1 (assumes an HWC layout — TODO
      confirm) and randomly cropped back to 32x32;
    - normalized, batched and prefetched.

    The test stream is only normalized and batched.

    Args:
        batch_size: Number of images per batch.
        root: Optional dataset directory forwarded to ``load_cifar10``.

    Returns:
        ``(train_stream, test_stream)``.
    """
    tr = load_cifar10(root=root)

    # Keep the per-channel statistics in float32: float64 constants would
    # silently upcast every normalized image (and every batch) to float64.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape((1, 1, 3))
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape((1, 1, 3))

    def normalize(x):
        x = x.astype("float32") / 255.0
        return (x - mean) / std

    group = mx.distributed.init()

    tr_iter = (
        tr.shuffle()
        .partition_if(group.size() > 1, group.size(), group.rank())
        .to_stream()
        .image_random_h_flip("image", prob=0.5)
        .pad("image", 0, 4, 4, 0.0)
        .pad("image", 1, 4, 4, 0.0)
        .image_random_crop("image", 32, 32)
        .key_transform("image", normalize)
        .batch(batch_size)
        .prefetch(4, 4)
    )

    test = load_cifar10(root=root, train=False)
    test_iter = (
        test.to_stream()
        .partition_if(group.size() > 1, group.size(), group.rank())
        .key_transform("image", normalize)
        .batch(batch_size)
    )

    return tr_iter, test_iter


================================================
FILE: cifar/main.py
================================================
import argparse
import time
from functools import partial

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import resnet
from dataset import get_cifar10

def _build_parser():
    """Create the command-line parser for the CIFAR-10 ResNet training script."""
    cli = argparse.ArgumentParser(add_help=True)
    architectures = ["resnet" + str(depth) for depth in (20, 32, 44, 56, 110, 1202)]
    cli.add_argument(
        "--arch",
        type=str,
        default="resnet20",
        choices=architectures,
        help="model architecture",
    )
    cli.add_argument("--batch_size", type=int, default=256, help="batch size")
    cli.add_argument("--epochs", type=int, default=30, help="number of epochs")
    cli.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    cli.add_argument("--seed", type=int, default=0, help="random seed")
    cli.add_argument("--cpu", action="store_true", help="use cpu only")
    return cli


parser = _build_parser()


def print_zero(group, *args, **kwargs):
    """``print`` only on rank 0 of ``group``; output is flushed unless ``flush`` is given."""
    if group.rank() == 0:
        kwargs.setdefault("flush", True)
        print(*args, **kwargs)


def eval_fn(model, inp, tgt):
    """Return the top-1 accuracy of ``model`` on the batch ``(inp, tgt)``."""
    predicted = mx.argmax(model(inp), axis=1)
    return mx.mean(predicted == tgt)


def train_epoch(model, train_iter, optimizer, epoch):
    """Run one training epoch and return averaged statistics.

    Each batch goes through a compiled step that:
    - computes loss and accuracy;
    - averages gradients across ranks with ``nn.utils.average_gradients``;
    - updates ``model`` in place through ``optimizer``.

    Every 10 batches, running averages are printed on rank 0.

    Args:
        model: The network being trained (updated in place).
        train_iter: Iterable of batches with ``"image"`` and ``"label"`` entries.
        optimizer: MLX optimizer; its state is updated in place.
        epoch: Epoch index, used only for logging.

    Returns:
        A list ``[loss, accuracy, throughput]``:
        - loss and accuracy are per-batch averages (not weighted by batch size),
          pooled over all ranks when running distributed;
        - throughput is scaled by ``world.size()`` as an estimate of aggregate
          images/second.
    """
    def train_step(model, inp, tgt):
        # Differentiated by value_and_grad below; it returns (loss, acc) and,
        # presumably, gradients are taken w.r.t. the first output (loss).
        output = model(inp)
        loss = mx.mean(nn.losses.cross_entropy(output, tgt))
        acc = mx.mean(mx.argmax(output, axis=1) == tgt)
        return loss, acc

    world = mx.distributed.init()
    losses = 0
    accuracies = 0
    samples_per_sec = 0
    count = 0

    def average_stats(stats, count):
        # Single process: plain mean over batches.
        if world.size() == 1:
            return [s / count for s in stats]

        # Multi-process: sum numerators and batch counts over all ranks before
        # dividing. The collectives run on the CPU stream (presumably required
        # by the distributed backend — confirm).
        with mx.stream(mx.cpu):
            stats = mx.distributed.all_sum(mx.array(stats))
            count = mx.distributed.all_sum(count)
            return (stats / count).tolist()

    # Declared as compile inputs/outputs so the compiled step sees and
    # updates the current model parameters and optimizer state.
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(inp, tgt):
        train_step_fn = nn.value_and_grad(model, train_step)
        (loss, acc), grads = train_step_fn(model, inp, tgt)
        grads = nn.utils.average_gradients(grads)
        optimizer.update(model, grads)
        return loss, acc

    for batch_counter, batch in enumerate(train_iter):
        x = mx.array(batch["image"])
        y = mx.array(batch["label"])
        tic = time.perf_counter()
        loss, acc = step(x, y)
        # mx.eval forces the step to run before the timer stops.
        mx.eval(loss, acc, state)
        toc = time.perf_counter()
        losses += loss.item()
        accuracies += acc.item()
        samples_per_sec += x.shape[0] / (toc - tic)
        count += 1
        if batch_counter % 10 == 0:
            # average_stats may issue collectives, so every rank must reach this call.
            l, a, s = average_stats(
                [losses, accuracies, world.size() * samples_per_sec],
                count,
            )
            print_zero(
                world,
                " | ".join(
                    (
                        f"Epoch {epoch:02d} [{batch_counter:03d}]",
                        f"Train loss {l:.3f}",
                        f"Train acc {a:.3f}",
                        f"Throughput: {s:.2f} images/second",
                    )
                ),
            )

    return average_stats([losses, accuracies, world.size() * samples_per_sec], count)


def test_epoch(model, test_iter, epoch):
    """Return the top-1 test accuracy of ``model`` over ``test_iter``.

    Accuracy is computed per sample (correct predictions / total samples), so a
    smaller final batch or unequal per-rank partitions are not over-weighted.
    Counts are summed across all ranks of the distributed group.

    Args:
        model: The network to evaluate.
        test_iter: Iterable of batches with ``"image"`` and ``"label"`` entries.
        epoch: Unused; kept for interface compatibility.

    Returns:
        The accuracy as a Python float.
    """
    correct = 0
    total = 0
    for batch in test_iter:
        x = mx.array(batch["image"])
        y = mx.array(batch["label"])
        batch_size = x.shape[0]
        # eval_fn returns the batch mean; scale back to a count of correct
        # predictions (rounded, since the mean may not be exactly representable).
        correct += round(eval_fn(model, x, y).item() * batch_size)
        total += batch_size

    with mx.stream(mx.cpu):
        correct = mx.distributed.all_sum(mx.array(correct, dtype=mx.float32))
        total = mx.distributed.all_sum(mx.array(total, dtype=mx.float32))
        return (correct / total).item()


def main(args):
    """Train and evaluate the selected ResNet on CIFAR-10 for ``args.epochs`` epochs."""
    mx.random.seed(args.seed)

    # Bring up the distributed group (a single process is a group of size 1)
    world = mx.distributed.init()
    if world.size() > 1:
        print(f"Starting rank {world.rank()} of {world.size()}", flush=True)

    model = getattr(resnet, args.arch)()
    print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")

    optimizer = optim.Adam(learning_rate=args.lr)
    train_data, test_data = get_cifar10(args.batch_size)

    for epoch in range(args.epochs):
        train_loss, train_acc, throughput = train_epoch(
            model, train_data, optimizer, epoch
        )
        summary = [
            f"Epoch: {epoch}",
            f"avg. Train loss {train_loss:.3f}",
            f"avg. Train acc {train_acc:.3f}",
            f"Throughput: {throughput:.2f} images/sec",
        ]
        print_zero(world, " | ".join(summary))

        test_acc = test_epoch(model, test_data, epoch)
        print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}")

        # Rewind both streams for the next epoch.
        train_data.reset()
        test_data.reset()


if __name__ == "__main__":
    cli_args = parser.parse_args()
    # Pick the device before main() creates the model or any arrays.
    if cli_args.cpu:
        mx.set_default_device(mx.cpu)
    main(cli_args)


================================================
FILE: cifar/requirements.txt
================================================
mlx>=0.2
mlx-data
numpy


================================================
FILE: cifar/resnet.py
================================================
"""
Implementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385].
Configurations include ResNet-20, ResNet-32, ResNet-44, ResNet-56, ResNet-110, ResNet-1202.
"""

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

# Public API of this module; main.py selects a constructor by name via getattr.
__all__ = [
    "ResNet",
    "resnet20",
    "resnet32",
    "resnet44",
    "resnet56",
    "resnet110",
    "resnet1202",
]


class ShortcutA(nn.Module):
    """Parameter-free "option A" shortcut.

    Subsamples axes 1 and 2 by a factor of 2, then zero-pads the last (channel)
    axis by ``dims // 4`` on both sides. With channels-last input of
    ``dims // 2`` channels, the output therefore has ``dims`` channels.
    """

    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def __call__(self, x):
        half_pad = self.dims // 4
        strided = x[:, ::2, ::2, :]
        widths = [(0, 0)] * 3 + [(half_pad, half_pad)]
        return mx.pad(strided, pad_width=widths)


class Block(nn.Module):
    """
    Basic ResNet block: two 3x3 conv + batch-norm layers plus a skip connection.
    Following the paper's CIFAR-10 setup, downsampling blocks use the
    parameter-free type-A shortcut (see paper for details).
    """

    def __init__(self, in_dims, dims, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm(dims)

        self.conv2 = nn.Conv2d(
            dims, dims, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm(dims)

        # Only strided blocks change the tensor shape and need a projection;
        # otherwise the identity is used (signalled by None).
        self.shortcut = ShortcutA(dims) if stride != 1 else None

    def __call__(self, x):
        residual = x if self.shortcut is None else self.shortcut(x)
        h = nn.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return nn.relu(h + residual)


class ResNet(nn.Module):
    """
    CIFAR-10 ResNet from the original paper.

    Layout:
    - a 16-channel stem;
    - three stages of ``block`` with 16/32/64 channels (stages 2 and 3
      downsample by stride 2);
    - global average pooling and a linear classifier.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm(16)

        self.layer1 = self._make_layer(block, 16, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 16, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 32, 64, num_blocks[2], stride=2)

        self.linear = nn.Linear(64, num_classes)

    def _make_layer(self, block, in_dims, dims, num_blocks, stride):
        # Only the first block of a stage changes width/resolution; the rest
        # map dims -> dims with stride 1. At least one block is always built.
        blocks = [block(in_dims, dims, stride)]
        blocks.extend(block(dims, dims, 1) for _ in range(num_blocks - 1))
        return nn.Sequential(*blocks)

    def num_params(self):
        """Total number of scalar parameters in the model."""
        return sum(value.size for _, value in tree_flatten(self.parameters()))

    def __call__(self, x):
        x = nn.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        pooled = mx.mean(x, axis=[1, 2]).reshape(x.shape[0], -1)
        return self.linear(pooled)


def _cifar_resnet(depth, **kwargs):
    """Build a CIFAR ResNet with ``depth = 6n + 2`` layers (n blocks per stage)."""
    blocks_per_stage = (depth - 2) // 6
    return ResNet(Block, [blocks_per_stage] * 3, **kwargs)


def resnet20(**kwargs):
    return _cifar_resnet(20, **kwargs)


def resnet32(**kwargs):
    return _cifar_resnet(32, **kwargs)


def resnet44(**kwargs):
    return _cifar_resnet(44, **kwargs)


def resnet56(**kwargs):
    return _cifar_resnet(56, **kwargs)


def resnet110(**kwargs):
    return _cifar_resnet(110, **kwargs)


def resnet1202(**kwargs):
    return _cifar_resnet(1202, **kwargs)


================================================
FILE: clip/.gitignore
================================================
mlx_model/


================================================
FILE: clip/README.md
================================================
# CLIP

An example of OpenAI's CLIP in MLX. The CLIP (contrastive language-image
pre-training) model embeds images and text in the same space.[^1]

### Setup

Install the dependencies:

```shell
pip install -r requirements.txt
```

Next, download a CLIP model from Hugging Face and convert it to MLX. The
default model is
[openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32).

```
python convert.py
```

By default, the script downloads the model from Hugging Face, converts it, and
saves the weights and configuration files to the directory ``mlx_model/``.

### Run

You can use the CLIP model to embed images and text.

```python
from PIL import Image
import clip

model, tokenizer, img_processor = clip.load("mlx_model")
inputs = {
    "input_ids": tokenizer(["a photo of a cat", "a photo of a dog"]),
    "pixel_values": img_processor(
        [Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")]
    ),
}
output = model(**inputs)

# Get text and image embeddings:
text_embeds = output.text_embeds
image_embeds = output.image_embeds
```

Run the above example with `python clip.py`.

To embed only images or only the text, pass only the ``pixel_values`` or
``input_ids``, respectively.

This example re-implements minimal image preprocessing and tokenization to reduce
dependencies. For additional preprocessing functionality, you can use
``transformers``. The file `hf_preproc.py` has an example.

MLX CLIP has been tested and works with the following Hugging Face repos:

- [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32)
- [openai/clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)

You can run the tests with:

```shell
python test.py
```

To test new models, update the `MLX_PATH` and `HF_PATH` in `test.py`.

### Attribution

- `assets/cat.jpeg` is a "Cat" by London's, licensed under CC BY-SA 2.0.
- `assets/dog.jpeg` is a "Happy Dog" by tedmurphy, licensed under CC BY 2.0.

[^1]: Refer to the original paper [Learning Transferable Visual Models From
  Natural Language Supervision](https://arxiv.org/abs/2103.00020) or [blog
  post](https://openai.com/research/clip).


================================================
FILE: clip/clip.py
================================================
from typing import Tuple

from image_processor import CLIPImageProcessor
from model import CLIPModel
from tokenizer import CLIPTokenizer


def load(model_dir: str) -> Tuple[CLIPModel, CLIPTokenizer, CLIPImageProcessor]:
    """Load the CLIP model, tokenizer and image processor stored in ``model_dir``."""
    return (
        CLIPModel.from_pretrained(model_dir),
        CLIPTokenizer.from_pretrained(model_dir),
        CLIPImageProcessor.from_pretrained(model_dir),
    )


if __name__ == "__main__":
    from PIL import Image

    model, tokenizer, img_processor = load("mlx_model")
    captions = ["a photo of a cat", "a photo of a dog"]
    pictures = [Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")]
    output = model(
        input_ids=tokenizer(captions),
        pixel_values=img_processor(pictures),
    )

    # Text and image embeddings live in the same space
    text_embeds, image_embeds = output.text_embeds, output.image_embeds
    print("Text embeddings shape:", text_embeds.shape)
    print("Image embeddings shape:", image_embeds.shape)


================================================
FILE: clip/convert.py
================================================
# Copyright © 2023-2024 Apple Inc.

import argparse
import json
import shutil
from pathlib import Path
from typing import Any, Dict, Union

import mlx.core as mx
import torch
from huggingface_hub import snapshot_download


def make_shards(weights: dict, max_file_size_gb: int = 5) -> list:
    """Split ``weights`` into consecutive shards of at most ``max_file_size_gb`` GiB.

    Insertion order is preserved. A single tensor larger than the limit gets a
    shard of its own rather than producing an empty shard. At least one
    (possibly empty) shard is always returned.

    Args:
        weights: Mapping of name -> array-like exposing ``nbytes``.
        max_file_size_gb: Size limit per shard in GiB; fractional values are allowed.

    Returns:
        A list of dicts, each holding a contiguous subset of ``weights``.
    """
    max_file_size_bytes = int(max_file_size_gb * (1 << 30))
    shards = []
    shard, shard_size = {}, 0
    for k, v in weights.items():
        # Flush only when the current shard already holds something; otherwise
        # an oversized first tensor would emit an empty shard (and file).
        if shard and shard_size + v.nbytes > max_file_size_bytes:
            shards.append(shard)
            shard, shard_size = {}, 0
        shard[k] = v
        shard_size += v.nbytes
    shards.append(shard)
    return shards


def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
    """Save model weights into the specified directory.

    The directory receives:
    - one ``model.safetensors`` file, or numbered ``model-XXXXX-of-YYYYY.safetensors``
      files when the weights need several shards;
    - a ``model.safetensors.index.json`` file mapping each weight name to its
      file, with keys sorted.
    """
    save_path = Path(save_path) if isinstance(save_path, str) else save_path
    save_path.mkdir(parents=True, exist_ok=True)

    shards = make_shards(weights)
    num_shards = len(shards)
    if num_shards > 1:
        name_template = "model-{:05d}-of-{:05d}.safetensors"
    else:
        name_template = "model.safetensors"

    weight_map = {}
    for index, shard in enumerate(shards, start=1):
        shard_name = name_template.format(index, num_shards)
        mx.save_safetensors(str(save_path / shard_name), shard)
        weight_map.update(dict.fromkeys(shard, shard_name))

    index_data = {
        "metadata": {"total_size": sum(v.nbytes for v in weights.values())},
        "weight_map": {name: weight_map[name] for name in sorted(weight_map)},
    }
    with open(save_path / "model.safetensors.index.json", "w") as f:
        json.dump(index_data, f, indent=4)


def get_model_path(path_or_hf_repo: str, force_download: bool = False) -> Path:
    """Return a local directory for ``path_or_hf_repo``.

    An existing local path is returned as-is. Otherwise the name is treated as
    a Hugging Face repo and only its ``*.bin``, ``*.json`` and ``*.txt`` files
    are downloaded.
    """
    local = Path(path_or_hf_repo)
    if local.exists():
        return local
    downloaded = snapshot_download(
        repo_id=path_or_hf_repo,
        allow_patterns=["*.bin", "*.json", "*.txt"],
        force_download=force_download,
    )
    return Path(downloaded)


def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
    """Convert a torch tensor to an MLX array with dtype name ``dtype`` (e.g. "float16")."""
    # NumPy cannot represent bfloat16, so stage through float32 (lossless)
    # and let MLX perform the final cast.
    staging_dtype = torch.float32 if dtype == "bfloat16" else getattr(torch, dtype)
    return mx.array(a.to(staging_dtype).numpy(), getattr(mx, dtype))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Download and Convert (OpenAI) CLIP weights to MLX"
    )
    parser.add_argument(
        "--hf-repo",
        type=str,
        default="openai/clip-vit-base-patch32",
        help="Hugging Face repository name.",
    )
    parser.add_argument(
        "--mlx-path",
        type=str,
        default="mlx_model",
        help="Path to save the MLX model.",
    )
    parser.add_argument(
        "--dtype",
        help="The data type to save the converted model.",
        type=str,
        default="float32",
    )
    parser.add_argument(
        "-f",
        "--force-download",
        help="Force download the model from Hugging Face.",
        action="store_true",
    )
    args = parser.parse_args()

    torch_path = get_model_path(args.hf_repo, args.force_download)
    mlx_path = Path(args.mlx_path)
    mlx_path.mkdir(parents=True, exist_ok=True)

    print("[INFO] Loading")
    state_dict = torch.load(torch_path / "pytorch_model.bin", weights_only=True)

    print("[INFO] Converting")
    converted = {}
    for name, tensor in state_dict.items():
        converted[name] = torch_to_mx(tensor, dtype=args.dtype)

    print("[INFO] Saving")
    save_weights(mlx_path, converted)

    # Ship the model config, tokenizer files and image preprocessor config
    # alongside the converted weights.
    for filename in (
        "config.json",
        "merges.txt",
        "vocab.json",
        "preprocessor_config.json",
    ):
        shutil.copyfile(torch_path / filename, mlx_path / filename)


================================================
FILE: clip/hf_preproc.py
================================================
import mlx.core as mx
import transformers
from PIL import Image

import clip

hf_model = "openai/clip-vit-base-patch32"
mlx_model = "mlx_model"

model, *_ = clip.load(mlx_model)
processor = transformers.CLIPProcessor.from_pretrained(hf_model)

captions = ["a photo of a cat", "a photo of a dog"]
pictures = [Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")]
inputs = processor(text=captions, images=pictures, return_tensors="np")

# Reorder the processor's pixel axes (0, 1, 2, 3) -> (0, 2, 3, 1), i.e. move
# the channel axis last for the MLX model.
input_ids = mx.array(inputs.input_ids)
pixel_values = mx.array(inputs.pixel_values).transpose((0, 2, 3, 1))
out = model(input_ids=input_ids, pixel_values=pixel_values, return_loss=True)

for heading, embeds in (
    ("text embeddings:", out.text_embeds),
    ("image embeddings:", out.image_embeds),
):
    print(heading)
    print(embeds)
print(f"CLIP loss: {out.loss.item():.3f}")


================================================
FILE: clip/image_processor.py
================================================
# Copyright © 2023-2024 Apple Inc.

import json
from pathlib import Path
from typing import List, Tuple

import mlx.core as mx
import numpy as np
from PIL.Image import Image


class CLIPImageProcessor:
    """
    A simple port of
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py.

    Each image is processed in order:
    1. converted to RGB (if ``do_convert_rgb``);
    2. resized so its shorter side equals ``size``;
    3. center-cropped to ``crop_size``;
    4. rescaled to [0, 1] and normalized with ``image_mean`` / ``image_std``.

    ``__call__`` stacks the results into a channels-last batch (N, H, W, C).

    ``size`` and ``crop_size`` may be plain ints or the dict forms found in
    newer Hugging Face configs (``{"shortest_edge": int}`` and
    ``{"height": int, "width": int}``). Unknown config keys are ignored.
    """

    def __init__(
        self,
        crop_size: int = 224,
        do_center_crop: bool = True,
        do_normalize: bool = True,
        do_resize: bool = True,
        image_mean: List[float] = [0.48145466, 0.4578275, 0.40821073],
        image_std: List[float] = [0.26862954, 0.26130258, 0.27577711],
        size: int = 224,
        do_convert_rgb: bool = True,
        **kwargs
    ) -> None:
        # Accept the dict-valued sizes used by newer preprocessor configs.
        if isinstance(size, dict):
            size = size["shortest_edge"]
        if isinstance(crop_size, dict):
            crop_size = (crop_size["height"], crop_size["width"])
        self.crop_size = crop_size
        self.do_center_crop = do_center_crop
        self.do_normalize = do_normalize
        self.do_resize = do_resize
        self.image_mean = mx.array(image_mean)
        self.image_std = mx.array(image_std)
        self.size = size
        self.do_convert_rgb = do_convert_rgb

    def __call__(self, images: List[Image]) -> mx.array:
        return mx.concatenate(
            [self._preprocess(image)[None] for image in images], axis=0
        )

    def _preprocess(self, image: Image) -> mx.array:
        # Grayscale, RGBA or palette images would otherwise yield the wrong
        # number of channels and fail to broadcast against the per-channel
        # mean/std (3 values by default).
        if self.do_convert_rgb and image.mode != "RGB":
            image = image.convert("RGB")
        if self.do_resize:
            image = resize(image, self.size)
        if self.do_center_crop:
            crop = self.crop_size
            if not isinstance(crop, tuple):
                crop = (crop, crop)
            image = center_crop(image, crop)
        image = mx.array(np.array(image))
        image = rescale(image)
        if self.do_normalize:
            image = normalize(image, self.image_mean, self.image_std)
        return image

    @staticmethod
    def from_pretrained(path: str):
        """Create a processor from ``<path>/preprocessor_config.json``."""
        path = Path(path)
        with open(path / "preprocessor_config.json", encoding="utf-8") as f:
            config = json.load(f)
        return CLIPImageProcessor(**config)


def resize(image: Image, short_size: int) -> Image:
    """
    Resize ``image`` so that its shorter side equals ``short_size``.

    The aspect ratio is kept, with the longer side truncated to an int. The
    image is returned unchanged if it already has that short side.
    """
    width, height = image.size
    short_side, long_side = sorted((width, height))
    if short_side == short_size:
        return image
    scaled_long = int(short_size * long_side / short_side)
    if width <= height:
        return image.resize((short_size, scaled_long))
    return image.resize((scaled_long, short_size))


def center_crop(image: Image, size: Tuple[int, int]) -> Image:
    """Crop a ``size = (height, width)`` window from the center of ``image``.

    Raises:
        ValueError: If either crop dimension is odd.
    """
    if size[0] % 2 or size[1] % 2:
        raise ValueError("Only even crop sizes supported.")
    crop_height, crop_width = size
    image_width, image_height = image.size
    left = (image_width - crop_width) // 2
    top = (image_height - crop_height) // 2
    return image.crop((left, top, left + crop_width, top + crop_height))


def rescale(image: mx.array) -> mx.array:
    """Cast to float32 and map pixel values from [0, 255] to [0, 1]."""
    scale = 1 / 255.0
    return image.astype(mx.float32) * scale


def normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array:
    """Standardize ``image`` as ``(image - mean) / std`` with broadcasting."""
    centered = image - mean
    return centered / std


================================================
FILE: clip/linear_probe.py
================================================
# Mirror of the Linear Probe Evaluation Script
# from the official CLIP Repository.

import mlx.core as mx
import numpy as np
from image_processor import CLIPImageProcessor
from mlx.data.datasets import load_cifar10
from model import CLIPModel
from PIL import Image
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm


def get_cifar10(batch_size, root=None):
    """Return ``(train, test)`` CIFAR-10 buffers batched by ``batch_size`` (no shuffling)."""
    train_batches = load_cifar10(root=root).batch(batch_size)
    test_batches = load_cifar10(root=root, train=False).batch(batch_size)
    return train_batches, test_batches


def get_features(model, image_proc, iter):
    """Embed every batch in ``iter`` with ``model`` and collect the labels.

    Returns ``(features, labels)``, each concatenated along the first axis.
    """
    features, labels = [], []
    for batch in tqdm(iter):
        pixels = image_proc([Image.fromarray(img) for img in batch["image"]])
        targets = mx.array(batch["label"])

        embeds = model.get_image_features(pixels)
        # Evaluate per batch so each step's work is done before the next one.
        mx.eval(embeds)

        features.append(embeds)
        labels.append(targets)

    return mx.concatenate(features), mx.concatenate(labels)


if __name__ == "__main__":
    checkpoint_dir = "mlx_model"
    model = CLIPModel.from_pretrained(checkpoint_dir)
    image_proc = CLIPImageProcessor.from_pretrained(checkpoint_dir)

    train_iter, test_iter = get_cifar10(batch_size=256)
    train_features, train_labels = get_features(model, image_proc, train_iter)
    test_features, test_labels = get_features(model, image_proc, test_iter)

    # Fit a logistic-regression probe on the frozen image features.
    # NOTE: The value of C should be determined via a hyperparameter sweep
    # using a validation split
    classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
    classifier.fit(train_features, train_labels)

    # Score the probe on the held-out features
    predictions = classifier.predict(test_features)
    hits = test_labels.squeeze() == predictions
    accuracy = hits.mean().item() * 100
    print(f"Accuracy = {accuracy:.3f}")


================================================
FILE: clip/model.py
================================================
# Copyright © 2023-2024 Apple Inc.

import glob
import json
import logging
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.core import linalg as LA
from mlx.nn.losses import cross_entropy


@dataclass
class CLIPVisionOutput:
    """Outputs of the CLIP vision encoder."""

    pooler_output: mx.array
    last_hidden_state: mx.array
    hidden_states: Optional[mx.array]  # may be None


@dataclass
class CLIPTextOutput:
    """Outputs of the CLIP text encoder (as built by ``ClipTextModel``)."""

    pooler_output: mx.array  # last_hidden_state gathered at one position per sequence
    last_hidden_state: mx.array  # final-layer-norm output for every position


@dataclass
class CLIPModelOutput:
    """Combined outputs of the full CLIP model."""

    loss: Optional[mx.array]  # None when no loss is computed
    text_embeds: Optional[mx.array]  # presumably None when no text input is given — confirm
    image_embeds: Optional[mx.array]  # presumably None when no image input is given — confirm
    text_model_output: CLIPTextOutput
    vision_model_output: CLIPVisionOutput


@dataclass
class CLIPTextConfig:
    """Hyperparameters of the CLIP text encoder."""

    num_hidden_layers: int  # number of EncoderLayer blocks
    hidden_size: int  # model/embedding width
    intermediate_size: int  # MLP hidden width
    num_attention_heads: int
    max_position_embeddings: int  # rows of the learned position-embedding table
    vocab_size: int
    layer_norm_eps: float


@dataclass
class CLIPVisionConfig:
    """Hyperparameters of the CLIP vision encoder."""

    num_hidden_layers: int
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_channels: int  # input channels of the patch-embedding convolution
    image_size: int
    patch_size: int
    layer_norm_eps: float


@dataclass
class CLIPConfig:
    """Top-level CLIP configuration combining both towers."""

    text_config: CLIPTextConfig
    vision_config: CLIPVisionConfig
    projection_dim: int  # presumably the shared embedding size — confirm against CLIPModel


def quick_gelu(x: mx.array) -> mx.array:
    """
    Sigmoid-based GELU approximation, ``x * sigmoid(1.702 * x)``
    (https://github.com/hendrycks/GELUs).
    """
    gate = mx.sigmoid(1.702 * x)
    return x * gate


def clip_loss(logits: mx.array) -> mx.array:
    """Symmetric contrastive loss.

    Averages the mean cross-entropy over rows and over columns of ``logits``,
    where the target of row/column ``i`` is index ``i``.
    """
    num_rows, num_cols = logits.shape
    row_loss = cross_entropy(logits, mx.arange(num_rows), reduction="mean")
    col_loss = cross_entropy(logits.T, mx.arange(num_cols), reduction="mean")
    return (row_loss + col_loss) / 2.0


class Attention(nn.Module):
    """Multi-head scaled dot-product attention with q/k/v/out projections.

    Unspecified input/value dimensions default to ``dims``; the value input
    dimension defaults to the key input dimension. ``mask``, if given, is
    added to the attention scores before the softmax.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        query_input_dims: Optional[int] = None,
        key_input_dims: Optional[int] = None,
        value_input_dims: Optional[int] = None,
        value_dims: Optional[int] = None,
        value_output_dims: Optional[int] = None,
        bias: bool = False,
    ):
        super().__init__()

        if dims % num_heads:
            raise ValueError(
                "The input feature dimensions should be divisible by the "
                f"number of heads ({dims} % {num_heads}) != 0"
            )

        q_in = query_input_dims or dims
        k_in = key_input_dims or dims
        v_in = value_input_dims or k_in
        v_dims = value_dims or dims
        out_dims = value_output_dims or dims

        self.num_heads = num_heads
        self.q_proj = nn.Linear(q_in, dims, bias=bias)
        self.k_proj = nn.Linear(k_in, dims, bias=bias)
        self.v_proj = nn.Linear(v_in, v_dims, bias=bias)
        self.out_proj = nn.Linear(v_dims, out_dims, bias=bias)

    def __call__(self, queries, keys, values, mask=None):
        heads = self.num_heads
        q = self.q_proj(queries)
        k = self.k_proj(keys)
        v = self.v_proj(values)

        batch, q_len, _ = q.shape
        _, kv_len, _ = k.shape

        # Split heads: q and v become (B, heads, len, head_dim); k is laid out
        # as (B, heads, head_dim, S) so that q @ k yields (B, heads, L, S).
        q = q.reshape(batch, q_len, heads, -1).transpose(0, 2, 1, 3)
        k = k.reshape(batch, kv_len, heads, -1).transpose(0, 2, 3, 1)
        v = v.reshape(batch, kv_len, heads, -1).transpose(0, 2, 1, 3)

        scale = math.sqrt(1 / q.shape[-1])
        scores = (q * scale) @ k
        if mask is not None:
            scores = scores + mask.astype(scores.dtype)
        weights = mx.softmax(scores, axis=-1)
        context = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, q_len, -1)

        return self.out_proj(context)


class MLP(nn.Module):
    """Feed-forward block: hidden -> intermediate -> hidden with quick-GELU in between."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        self.activation_fn = quick_gelu
        width, inner = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(width, inner)
        self.fc2 = nn.Linear(inner, width)

    def __call__(self, x: mx.array) -> mx.array:
        hidden = self.activation_fn(self.fc1(x))
        return self.fc2(hidden)


class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer used by CLIP (self-attention + MLP, each residual)."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_eps
        self.embed_dim = width
        # The attention projections carry biases
        self.self_attn = Attention(width, config.num_attention_heads, bias=True)
        self.layer_norm1 = nn.LayerNorm(width, eps=eps)
        self.mlp = MLP(config)
        self.layer_norm2 = nn.LayerNorm(width, eps=eps)

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        normed = self.layer_norm1(x)
        x = x + self.self_attn(normed, normed, normed, mask)
        return x + self.mlp(self.layer_norm2(x))


class TextEmbeddings(nn.Module):
    """Token embeddings summed with learned absolute position embeddings."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        dim = config.hidden_size
        self.token_embedding = nn.Embedding(config.vocab_size, dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, dim)

    def __call__(self, x: mx.array) -> mx.array:
        # Only the first seq_len position vectors are used.
        seq_len = x.shape[1]
        positions = self.position_embedding.weight[:seq_len]
        return self.token_embedding(x) + positions


class Encoder(nn.Module):
    """Container for the stack of CLIP encoder layers.

    Holds ``config.num_hidden_layers`` EncoderLayer instances in ``self.layers``;
    the text and vision models iterate over them directly. It is also built
    with a CLIPVisionConfig by ClipVisionModel.
    """

    def __init__(self, config: CLIPTextConfig):
        # Every other module in this file calls the base initializer; nn.Module
        # relies on it to set up internal state (e.g. the training flag used
        # by train()/eval()), so skipping it leaves this module half-built.
        super().__init__()
        self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]


class ClipTextModel(nn.Module):
    """Implements the text encoder transformer from CLIP."""

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.embeddings = TextEmbeddings(config)
        self.encoder = Encoder(config)
        # Use the configured epsilon, as the encoder layer norms do, rather
        # than silently falling back to the LayerNorm default.
        self.final_layer_norm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def __call__(self, x: mx.array) -> CLIPTextOutput:
        """Encode a batch of token ids.

        Args:
            x: token ids of shape ``(B, N)``.

        Returns:
            CLIPTextOutput whose ``last_hidden_state`` is the final-layer-normed
            sequence and whose ``pooler_output`` is that state gathered at each
            row's argmax token position.
        """
        B, N = x.shape
        # Pool at the position of the largest token id; this presumably relies
        # on the end-of-text token having the highest id in the vocabulary.
        eot_tokens = mx.argmax(x, axis=-1)
        x = self.embeddings(x)
        # Causal additive mask: each token attends only to itself and earlier.
        mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
        for l in self.encoder.layers:
            x = l(x, mask)
        last_hidden_state = self.final_layer_norm(x)
        pooler_output = last_hidden_state[mx.arange(B), eot_tokens]

        return CLIPTextOutput(
            pooler_output=pooler_output, last_hidden_state=last_hidden_state
        )


class VisionEmbeddings(nn.Module):
    """Patch embeddings with a prepended <CLS> vector and learned positions."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # <CLS> vector; zeros until pretrained weights are loaded.
        self.class_embedding = mx.zeros((config.hidden_size,))

        # kernel == stride == patch size: non-overlapping patch projection.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        self.num_positions = self.num_patches + 1  # +1 for the <CLS> slot
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def __call__(self, x: mx.array) -> mx.array:
        # Conv output is [B, h, w, D]; merge the two spatial axes -> [B, h*w, D].
        patches = mx.flatten(self.patch_embedding(x), start_axis=1, end_axis=2)
        B, _, D = patches.shape
        # One <CLS> row per image, placed before the patch tokens.
        cls_rows = mx.broadcast_to(self.class_embedding, (B, 1, D))
        tokens = mx.concatenate((cls_rows, patches), axis=1)
        return tokens + self.position_embedding.weight


class ClipVisionModel(nn.Module):
    """Implements the vision encoder transformer from CLIP."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.embeddings = VisionEmbeddings(config)
        # Attribute spelling ("layrnorm") is kept as-is; it presumably matches
        # the checkpoint parameter names, so do not rename it.
        # Both norms use the configured epsilon, like the encoder layers.
        self.pre_layrnorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.encoder = Encoder(config)
        self.post_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def __call__(
        self,
        x: mx.array,
        output_hidden_states: Optional[bool] = None,
    ) -> CLIPVisionOutput:
        """Encode a batch of images.

        Args:
            x: pixel values in channels-last layout (the tests pass NHWC).
            output_hidden_states: if truthy, also return the input to the
                encoder followed by the output of every encoder layer.

        Returns:
            CLIPVisionOutput with the un-normed ``last_hidden_state``, the
            post-layer-normed <CLS> vector as ``pooler_output`` and, optionally,
            ``hidden_states``.
        """
        x = self.embeddings(x)
        x = self.pre_layrnorm(x)

        encoder_states = (x,) if output_hidden_states else None

        for l in self.encoder.layers:
            x = l(x, mask=None)
            if output_hidden_states:
                encoder_states = encoder_states + (x,)

        # Extract <CLS> token embedding
        pooler_output = self.post_layernorm(x[:, 0, :])
        return CLIPVisionOutput(
            pooler_output=pooler_output,
            last_hidden_state=x,
            hidden_states=encoder_states,
        )


class CLIPModel(nn.Module):
    """Joint CLIP model: text and vision towers plus linear projections."""

    def __init__(self, config: CLIPConfig):
        # nn.Module's initializer must run before sub-modules are attached, as
        # it does for every other module in this file.
        super().__init__()
        self.text_model = ClipTextModel(config.text_config)
        self.vision_model = ClipVisionModel(config.vision_config)

        text_embed_dim = config.text_config.hidden_size
        vision_embed_dim = config.vision_config.hidden_size
        projection_dim = config.projection_dim

        self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False)
        self.text_projection = nn.Linear(text_embed_dim, projection_dim, bias=False)
        # Log of the logit temperature; exp() is applied when computing loss.
        self.logit_scale = mx.array(0.0)

    def get_text_features(self, x: mx.array) -> mx.array:
        """Project pooled text features (not normalized) into the joint space."""
        return self.text_projection(self.text_model(x).pooler_output)

    def get_image_features(self, x: mx.array) -> mx.array:
        """Project pooled image features (not normalized) into the joint space."""
        return self.visual_projection(self.vision_model(x).pooler_output)

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        return_loss=False,
    ) -> CLIPModelOutput:
        """Embed text and/or images and optionally compute the CLIP loss.

        Args:
            input_ids: token ids, or None to skip the text tower.
            pixel_values: images, or None to skip the vision tower.
            return_loss: if True, compute the contrastive loss; both inputs
                are then required.

        Returns:
            CLIPModelOutput with L2-normalized ``text_embeds`` /
            ``image_embeds`` (None for a skipped tower), tower outputs and
            ``loss`` (None unless requested).

        Raises:
            ValueError: if ``return_loss`` is set without both inputs.
        """
        # Validate up front so no work is wasted on an invalid call.
        if return_loss and (input_ids is None or pixel_values is None):
            raise ValueError("Must provide text and image inputs to compute loss.")

        if input_ids is not None:
            text_model_output = self.text_model(input_ids)
            text_embeds = self.text_projection(text_model_output.pooler_output)
            text_embeds = text_embeds / LA.norm(text_embeds, axis=-1, keepdims=True)
        else:
            text_embeds = None
            text_model_output = None

        if pixel_values is not None:
            vision_model_output = self.vision_model(pixel_values)
            image_embeds = self.visual_projection(vision_model_output.pooler_output)
            image_embeds = image_embeds / LA.norm(image_embeds, axis=-1, keepdims=True)
        else:
            image_embeds = None
            vision_model_output = None

        if return_loss:
            logit_scale = mx.exp(self.logit_scale)
            logits = (text_embeds @ image_embeds.T) * logit_scale
            loss = clip_loss(logits)
        else:
            loss = None

        return CLIPModelOutput(
            loss=loss,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            vision_model_output=vision_model_output,
            text_model_output=text_model_output,
        )

    @staticmethod
    def from_pretrained(path: str):
        """Build a CLIPModel from ``config.json`` and ``*.safetensors`` in ``path``.

        Raises:
            FileNotFoundError: if no safetensors files are present.
        """
        path = Path(path)

        with open(path / "config.json", "r") as fid:
            config = json.load(fid)

        text_config = config["text_config"]
        text_config = CLIPTextConfig(
            num_hidden_layers=text_config["num_hidden_layers"],
            hidden_size=text_config["hidden_size"],
            intermediate_size=text_config["intermediate_size"],
            num_attention_heads=text_config["num_attention_heads"],
            max_position_embeddings=text_config["max_position_embeddings"],
            vocab_size=text_config["vocab_size"],
            layer_norm_eps=text_config["layer_norm_eps"],
        )

        vision_config = config["vision_config"]

        vision_config = CLIPVisionConfig(
            num_hidden_layers=vision_config["num_hidden_layers"],
            hidden_size=vision_config["hidden_size"],
            intermediate_size=vision_config["intermediate_size"],
            num_attention_heads=vision_config["num_attention_heads"],
            num_channels=3,
            image_size=vision_config["image_size"],
            patch_size=vision_config["patch_size"],
            layer_norm_eps=vision_config["layer_norm_eps"],
        )

        config = CLIPConfig(
            text_config=text_config,
            vision_config=vision_config,
            projection_dim=config["projection_dim"],
        )
        model = CLIPModel(config)
        weight_files = glob.glob(str(path / "*.safetensors"))
        if not weight_files:
            logging.error(f"No safetensors found in {path}")
            raise FileNotFoundError(f"No safetensors found in {path}")

        weights = {}
        for wf in weight_files:
            weights.update(mx.load(wf))

        weights = model.sanitize(weights)
        model.load_weights(list(weights.items()))
        return model

    @staticmethod
    def sanitize(weights):
        """Drop unused entries and convert PyTorch-layout weights to MLX layout."""
        sanitized_weights = {}
        for k, v in weights.items():
            if "position_ids" in k:
                # Remove unused position_ids
                continue
            elif "patch_embedding.weight" in k:
                # pytorch conv2d expects the weight tensor to be of shape [out_channels, in_channels, kH, KW]
                # mlx conv2d expects the weight tensor to be of shape [out_channels, kH, KW, in_channels]
                sanitized_weights[k] = v.transpose(0, 2, 3, 1)
            else:
                sanitized_weights[k] = v

        return sanitized_weights


================================================
FILE: clip/requirements.txt
================================================
mlx
mlx-data
numpy
transformers
torch
huggingface_hub
Pillow


================================================
FILE: clip/test.py
================================================
import unittest

import mlx.core as mx
import model
import numpy as np
import torch
import transformers
from image_processor import CLIPImageProcessor
from PIL import Image
from tokenizer import CLIPTokenizer
from transformers import AutoTokenizer
from transformers.image_processing_utils import ChannelDimension

# Local directory holding the converted MLX weights, tokenizer files and
# preprocessor config (presumably produced by convert.py — confirm).
MLX_PATH = "mlx_model"
# Hugging Face hub id of the reference checkpoint the MLX port is compared to.
HF_PATH = "openai/clip-vit-base-patch32"


def load_mlx_models(path):
    """Return ``(image_processor, tokenizer, model)`` for the MLX port at ``path``."""
    return (
        CLIPImageProcessor.from_pretrained(path),
        CLIPTokenizer.from_pretrained(path),
        model.CLIPModel.from_pretrained(path),
    )


def load_hf_models(path):
    """Return ``(image_processor, tokenizer, model)`` from Hugging Face ``path``."""
    return (
        transformers.CLIPImageProcessor.from_pretrained(path),
        AutoTokenizer.from_pretrained(path),
        transformers.CLIPModel.from_pretrained(path),
    )


class TestCLIP(unittest.TestCase):
    """Numerically compare the MLX CLIP port with the Hugging Face reference."""

    @classmethod
    def setUpClass(cls):
        # Model loading is expensive, so do it once for the whole class.
        cls.mx_image_proc, cls.mx_tokenizer, cls.mx_clip = load_mlx_models(MLX_PATH)
        cls.hf_image_proc, cls.hf_tokenizer, cls.hf_clip = load_hf_models(HF_PATH)

    def test_image_processor(self):
        # Use the image as a context manager so its file handle is closed.
        with Image.open("assets/cat.jpeg") as image:
            mx_data = self.mx_image_proc([image])
            hf_data = mx.array(
                np.array(
                    self.hf_image_proc([image], data_format=ChannelDimension.LAST)[
                        "pixel_values"
                    ]
                )
            )
        self.assertTrue(mx.allclose(mx_data, hf_data, atol=1e-5))

    def test_text_tokenizer(self):
        texts = ["a photo of a cat", "a photo of a dog"]
        for txt in texts:
            self.assertTrue(
                np.array_equal(
                    self.mx_tokenizer.tokenize(txt)[None, :],
                    self.hf_tokenizer(txt, return_tensors="np")["input_ids"],
                ),
            )

    def test_text_encoder(self):
        texts = ["a photo of a cat", "a photo of a dog"]
        # Tokenize
        hf_tokens = self.hf_tokenizer(texts, return_tensors="pt")
        mx_tokens = self.mx_tokenizer(texts)
        # Get expected
        with torch.inference_mode():
            expected_out = self.hf_clip.text_model(**hf_tokens)
            expected_last_hidden = expected_out.last_hidden_state.numpy()
            expected_pooler_output = expected_out.pooler_output.numpy()
        out = self.mx_clip.text_model(mx_tokens)
        self.assertTrue(
            np.allclose(out.last_hidden_state, expected_last_hidden, atol=1e-5)
        )
        self.assertTrue(
            np.allclose(out.pooler_output, expected_pooler_output, atol=1e-5)
        )

    def test_vision_encoder(self):
        # Load and process test image
        with Image.open("assets/dog.jpeg") as image:
            x = self.hf_image_proc(images=[image], return_tensors="np").pixel_values

        # Infer with HuggingFace model
        with torch.inference_mode():
            # Get expected
            x_tc = torch.tensor(x)
            expected_out = self.hf_clip.vision_model(x_tc, output_hidden_states=True)
            expected_last_hidden = expected_out.last_hidden_state.numpy()
            expected_pooler_output = expected_out.pooler_output.numpy()
            expected_hidden_states = [hs.numpy() for hs in expected_out.hidden_states]
        # Test MLX vision encoder (MLX expects channels-last input)
        out = self.mx_clip.vision_model(
            mx.array(x.transpose(0, 2, 3, 1)), output_hidden_states=True
        )
        self.assertTrue(
            np.allclose(
                out.last_hidden_state, expected_last_hidden, rtol=1e-4, atol=1e-3
            ),
        )
        self.assertTrue(
            np.allclose(
                out.pooler_output, expected_pooler_output, rtol=1e-4, atol=1e-3
            ),
        )
        # zip() below would silently stop at the shorter sequence; make sure
        # both implementations report the same number of hidden states.
        self.assertEqual(len(out.hidden_states), len(expected_hidden_states))
        for expected_hs, out_hs in zip(expected_hidden_states, out.hidden_states):
            self.assertTrue(
                np.allclose(expected_hs, out_hs, rtol=1e-4, atol=1e-3),
            )

    def test_clip_model(self):
        with Image.open("assets/cat.jpeg") as cat, Image.open(
            "assets/dog.jpeg"
        ) as dog:
            image_input = self.hf_image_proc(
                images=[cat, dog],
                return_tensors="np",
            )["pixel_values"]
        text = ["a photo of a cat", "a photo of a dog"]
        tokens = self.hf_tokenizer(text, return_tensors="np")["input_ids"]
        with torch.inference_mode():
            expected_out = self.hf_clip(
                input_ids=torch.tensor(tokens),
                pixel_values=torch.tensor(image_input),
                return_loss=True,
            )

        out = self.mx_clip(
            input_ids=mx.array(tokens),
            pixel_values=mx.array(image_input.transpose((0, 2, 3, 1))),
            return_loss=True,
        )

        self.assertTrue(
            np.allclose(out.text_embeds, expected_out.text_embeds, atol=1e-5)
        )
        self.assertTrue(
            np.allclose(out.image_embeds, expected_out.image_embeds, atol=1e-5)
        )
        self.assertTrue(np.allclose(out.loss, expected_out.loss, atol=1e-5))


# Run the comparison tests when executed as a script.
if __name__ == "__main__":
    unittest.main()


================================================
FILE: clip/tokenizer.py
================================================
# Copyright © 2023-2024 Apple Inc.

import json
from pathlib import Path
from typing import Any

import mlx.core as mx
import regex


class CLIPTokenizer:
    """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""

    def __init__(self, bpe_ranks, vocab):
        """
        Args:
            bpe_ranks: mapping from a merge pair ``(a, b)`` to its priority
                (lower merges first).
            vocab: mapping from token string to token id.
        """
        self.bpe_ranks = bpe_ranks
        self.vocab = vocab
        self.pat = regex.compile(
            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            regex.IGNORECASE,
        )
        # bpe() returns a *list* of sub-tokens, so the special tokens must be
        # cached as one-element lists. Caching the bare string made tokenize()
        # iterate over its characters and emit one id per character.
        self._cache = {self.bos: [self.bos], self.eos: [self.eos]}

    @property
    def bos(self):
        return "<|startoftext|>"

    @property
    def bos_token(self):
        return self.vocab[self.bos]

    @property
    def eos(self):
        return "<|endoftext|>"

    @property
    def eos_token(self):
        return self.vocab[self.eos]

    def bpe(self, text):
        """Split one pre-tokenized word into BPE sub-tokens (a list of str).

        The last character gets the ``</w>`` end-of-word suffix before merging.
        Results for multi-character words are memoized in ``self._cache``.
        """
        if text in self._cache:
            return self._cache[text]

        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
        unique_bigrams = set(zip(unigrams, unigrams[1:]))

        if not unique_bigrams:
            return unigrams

        # In every iteration try to merge the two most likely bigrams. If none
        # was merged we are done.
        #
        # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_py
        while unique_bigrams:
            bigram = min(
                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
            )
            if bigram not in self.bpe_ranks:
                break

            new_unigrams = []
            skip = False
            for a, b in zip(unigrams, unigrams[1:]):
                if skip:
                    skip = False
                    continue

                if (a, b) == bigram:
                    new_unigrams.append(a + b)
                    skip = True

                else:
                    new_unigrams.append(a)

            # The final unigram is only consumed by the loop when it was
            # merged; otherwise append it here.
            if not skip:
                new_unigrams.append(b)

            unigrams = new_unigrams
            unique_bigrams = set(zip(unigrams, unigrams[1:]))

        self._cache[text] = unigrams

        return unigrams

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.tokenize(*args, **kwargs)

    def tokenize(self, text, prepend_bos=True, append_eos=True) -> mx.array:
        """Convert ``text`` (a str or a list of str) to token ids.

        For a list, each string is tokenized independently and the results
        are stacked; no padding is applied, so the texts presumably must
        tokenize to the same length.
        """
        if isinstance(text, list):
            return mx.array([self.tokenize(t, prepend_bos, append_eos) for t in text])

        # Lower case, cleanup, and split. Hugging Face does a much,
        # more thorough job here but this should suffice for 95% of
        # cases.
        clean_text = regex.sub(r"\s+", " ", text.lower())
        tokens = regex.findall(self.pat, clean_text)

        # Split the tokens according to the byte-pair merge file
        bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]

        # Map to token ids and return
        tokens = []
        if prepend_bos:
            tokens.append(self.bos_token)
        tokens.extend(self.vocab[t] for t in bpe_tokens)
        if append_eos:
            tokens.append(self.eos_token)
        return mx.array(tokens)

    @staticmethod
    def from_pretrained(path: str):
        """Load ``vocab.json`` and ``merges.txt`` from ``path``."""
        path = Path(path)

        with open(path / "vocab.json", encoding="utf-8") as f:
            vocab = json.load(f)
        with open(path / "merges.txt", encoding="utf-8") as f:
            # Skip the header line and keep a fixed number of merges; the
            # bound presumably mirrors the reference CLIP implementation.
            bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]

        bpe_merges = [tuple(m.split()) for m in bpe_merges]
        bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))

        return CLIPTokenizer(bpe_ranks, vocab)


================================================
FILE: cvae/.gitignore
================================================
models/


================================================
FILE: cvae/README.md
================================================
# Convolutional Variational Autoencoder (CVAE) on MNIST

Convolutional variational autoencoder (CVAE) implementation in MLX using
MNIST.[^1]

## Setup 

Install the requirements:

```
pip install -r requirements.txt
```

## Run


To train a VAE run:

```shell
python main.py
```

To see the supported options, do `python main.py -h`.

Training with the default options should give:

```shell
$ python main.py
Options: 
  Device: GPU
  Seed: 0
  Batch size: 128
  Max number of filters: 64
  Number of epochs: 50
  Learning rate: 0.001
  Number of latent dimensions: 8
Number of trainable params: 0.1493 M
Epoch    1 | Loss   14626.96 | Throughput  1803.44 im/s | Time     34.3 (s)
Epoch    2 | Loss   10462.21 | Throughput  1802.20 im/s | Time     34.3 (s)
...
Epoch   50 | Loss    8293.13 | Throughput  1804.91 im/s | Time     34.2 (s)
```

The throughput was measured on a 32GB M1 Max. 

Reconstructed and generated images will be saved after each epoch in the
`models/` path. Below are examples of reconstructed training set images and
generated images.

#### Reconstruction

![MNIST Reconstructions](assets/rec_mnist.png)

#### Generation 

![MNIST Samples](assets/samples_mnist.png)


## Limitations

At the time of writing, MLX does not have transposed 2D convolutions. The
example approximates them with a combination of nearest neighbor upsampling and
regular convolutions, similar to the original U-Net. We intend to update this
example once transposed 2D convolutions are available.

[^1]: For a good overview of VAEs see the original paper [Auto-Encoding
  Variational Bayes](https://arxiv.org/abs/1312.6114) or [An Introduction to
  Variational Autoencoders](https://arxiv.org/abs/1906.02691).


================================================
FILE: cvae/dataset.py
================================================
# Copyright © 2023-2024 Apple Inc.

from mlx.data.datasets import load_mnist


def mnist(batch_size, img_size, root=None):
    """Build MNIST train and test streams with mlx-data.

    Args:
        batch_size: number of images per batch.
        img_size: ``(height, width)`` every image is resized to.
        root: optional data directory forwarded to ``load_mnist``.

    Returns:
        ``(train_iter, test_iter)``: batched streams of dicts with ``"image"``
        (float32, scaled by 1/255) and ``"label"`` entries. The training
        stream is shuffled and prefetched; the test stream is neither.
    """
    train = load_mnist(root=root, train=True)
    test = load_mnist(root=root, train=False)

    def normalize(x):
        # Map raw pixel values (assumed 0-255) to [0, 1].
        return x.astype("float32") / 255.0

    # Iterator over the training set
    train_iter = (
        train.shuffle()
        .to_stream()
        .image_resize("image", h=img_size[0], w=img_size[1])
        .key_transform("image", normalize)
        .batch(batch_size)
        .prefetch(4, 4)
    )

    # Iterator over the test set
    test_iter = (
        test.to_stream()
        .image_resize("image", h=img_size[0], w=img_size[1])
        .key_transform("image", normalize)
        .batch(batch_size)
    )
    return train_iter, test_iter


# Smoke test: fetch one batch from each stream and check the shapes.
if __name__ == "__main__":
    batch_size = 32
    img_size = (64, 64)  # (H, W)

    tr_iter, test_iter = mnist(batch_size=batch_size, img_size=img_size)

    B, H, W, C = batch_size, img_size[0], img_size[1], 1
    print(f"Batch size: {B}, Channels: {C}, Height: {H}, Width: {W}")

    batch_tr_iter = next(tr_iter)
    assert batch_tr_iter["image"].shape == (B, H, W, C), "Wrong training set size"
    assert batch_tr_iter["label"].shape == (B,), "Wrong training label size"

    batch_test_iter = next(test_iter)
    assert batch_test_iter["image"].shape == (B, H, W, C), "Wrong test set size"
    assert batch_test_iter["label"].shape == (B,), "Wrong test label size"


================================================
FILE: cvae/main.py
================================================
# Copyright © 2023-2024 Apple Inc.

import argparse
import time
from functools import partial
from pathlib import Path

import dataset
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import vae
from mlx.utils import tree_flatten
from PIL import Image


def grid_image_from_batch(image_batch, num_rows):
    """
    Tile a (B, H, W, C) batch into one PIL image with ``num_rows`` rows.

    Images are placed row-major, ``B // num_rows`` per row. Values are scaled
    by 255 before conversion to uint8, so inputs are expected in [0, 1].
    """
    batch, height, width, _ = image_batch.shape
    num_cols = batch // num_rows

    # Scale to 0-255 and convert to 8-bit pixels.
    pixels = np.array(image_batch * 255).astype(np.uint8)

    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C)
    tiled = (
        pixels.reshape(num_rows, num_cols, height, width, -1)
        .swapaxes(1, 2)
        .reshape(num_rows * height, num_cols * width, -1)
    )

    # squeeze() drops a singleton channel axis so grayscale becomes 2-D.
    return Image.fromarray(tiled.squeeze())


def loss_fn(model, X):
    """VAE training loss: summed squared reconstruction error plus KL term."""
    X_recon, mu, logvar = model(X)
    reconstruction = nn.losses.mse_loss(X_recon, X, reduction="sum")
    # Closed-form KL( N(mu, exp(logvar)) || N(0, I) ), summed over everything.
    kl = -0.5 * mx.sum(1 + logvar - mu.square() - logvar.exp())
    return reconstruction + kl


def reconstruct(model, batch, out_file):
    """Save a grid pairing each image in ``batch`` with its reconstruction."""
    originals = mx.array(batch["image"])
    recons = model(originals)[0]
    # Stack on axis 1 then merge: [img0, rec0, img1, rec1, ...].
    pairs = mx.stack([originals, recons], axis=1).flatten(0, 1)
    grid_image_from_batch(pairs, num_rows=16).save(out_file)


def generate(
    model,
    out_file,
    num_samples=128,
):
    """Decode ``num_samples`` draws from a standard normal and save a grid."""
    latents = mx.random.normal([num_samples, model.num_latent_dims])
    samples = model.decode(latents)
    grid_image_from_batch(samples, num_rows=8).save(out_file)


def main(args):
    """Train the CVAE on MNIST, saving image grids and weights each epoch.

    Args:
        args: parsed CLI namespace providing ``batch_size``, ``save_dir``,
            ``latent_dims``, ``max_filters``, ``lr`` and ``epochs``.
    """
    # Load the data
    img_size = (64, 64, 1)
    train_iter, test_iter = dataset.mnist(
        batch_size=args.batch_size, img_size=img_size[:2]
    )

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Load the model
    model = vae.CVAE(args.latent_dims, img_size, args.max_filters)
    mx.eval(model.parameters())

    num_params = sum(x.size for _, x in tree_flatten(model.trainable_parameters()))
    print("Number of trainable params: {:0.04f} M".format(num_params / 1e6))

    optimizer = optim.AdamW(learning_rate=args.lr)

    # Fixed batches used for the per-epoch reconstruction images
    train_batch = next(train_iter)
    test_batch = next(test_iter)

    # The compiled step reads and updates both model and optimizer state.
    state = [model.state, optimizer.state]

    @partial(mx.compile, inputs=state, outputs=state)
    def step(X):
        loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
        loss, grads = loss_and_grad_fn(model, X)
        optimizer.update(model, grads)
        return loss

    for e in range(1, args.epochs + 1):
        # Reset iterators and stats at the beginning of each epoch
        train_iter.reset()
        model.train()

        # Train one epoch
        tic = time.perf_counter()
        loss_acc = 0.0
        throughput_acc = 0.0
        num_batches = 0

        # Count batches from 1 so the running averages divide by the number
        # of batches actually accumulated (starting at 0 under-counted by one).
        for num_batches, batch in enumerate(train_iter, start=1):
            X = mx.array(batch["image"])
            throughput_tic = time.perf_counter()

            # Forward pass + backward pass + update
            loss = step(X)

            # Evaluate updated model parameters
            mx.eval(state)

            throughput_toc = time.perf_counter()
            throughput_acc += X.shape[0] / (throughput_toc - throughput_tic)
            loss_acc += loss.item()

            if num_batches % 10 == 0:
                print(
                    " | ".join(
                        [
                            f"Epoch {e:4d}",
                            f"Loss {(loss_acc / num_batches):10.2f}",
                            f"Throughput {(throughput_acc / num_batches):8.2f} im/s",
                            f"Batch {num_batches:5d}",
                        ]
                    ),
                    end="\r",
                )
        toc = time.perf_counter()

        # Avoid dividing by zero if the stream produced no batches.
        denom = max(num_batches, 1)
        print(
            " | ".join(
                [
                    f"Epoch {e:4d}",
                    f"Loss {(loss_acc / denom):10.2f}",
                    f"Throughput {(throughput_acc / denom):8.2f} im/s",
                    f"Time {toc - tic:8.1f} (s)",
                ]
            )
        )

        model.eval()

        # Reconstruct a batch of training and test images
        reconstruct(model, train_batch, save_dir / f"train_{e:03d}.png")
        reconstruct(model, test_batch, save_dir / f"test_{e:03d}.png")

        # Generate images
        generate(model, save_dir / f"generated_{e:03d}.png")

        model.save_weights(str(save_dir / "weights.npz"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cpu", action="store_true", help="Use CPU instead of GPU acceleration"
    )
    # (flag, type, default, help) for every value-taking option, in CLI order.
    for flag, kind, default, help_text in (
        ("--seed", int, 0, "Random seed"),
        ("--batch-size", int, 128, "Batch size for training"),
        (
            "--max-filters",
            int,
            64,
            "Maximum number of filters in the convolutional layers",
        ),
        ("--epochs", int, 50, "Number of training epochs"),
        ("--lr", float, 1e-3, "Learning rate"),
        (
            "--latent-dims",
            int,
            8,
            "Number of latent dimensions (positive integer)",
        ),
        (
            "--save-dir",
            str,
            "models/",
            "Path to save the model and reconstructed images.",
        ),
    ):
        parser.add_argument(flag, type=kind, default=default, help=help_text)

    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)

    # Seed both RNGs for reproducibility.
    np.random.seed(args.seed)
    mx.random.seed(args.seed)

    print("Options: ")
    for label, value in (
        ("Device", "CPU" if args.cpu else "GPU"),
        ("Seed", args.seed),
        ("Batch size", args.batch_size),
        ("Max number of filters", args.max_filters),
        ("Number of epochs", args.epochs),
        ("Learning rate", args.lr),
        ("Number of latent dimensions", args.latent_dims),
    ):
        print(f"  {label}: {value}")

    main(args)


================================================
FILE: cvae/requirements.txt
================================================
mlx>=0.2
mlx-data
numpy
Pillow


================================================
FILE: cvae/vae.py
================================================
# Copyright © 2023-2024 Apple Inc.

import math

import mlx.core as mx
import mlx.nn as nn


# from https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
def upsample_nearest(x, scale: int = 2):
    """Nearest-neighbour upsample a BHWC array by an integer ``scale``.

    Each pixel is repeated ``scale`` times along height and along width.
    """
    B, H, W, C = x.shape
    # Insert singleton axes after H and W, broadcast them to `scale`, then fold.
    expanded = x[:, :, None, :, None, :]
    tiled = mx.broadcast_to(expanded, (B, H, scale, W, scale, C))
    return tiled.reshape(B, H * scale, W * scale, C)


class UpsamplingConv2d(nn.Module):
    """
    2x nearest-neighbour upsampling followed by a Conv2d.

    Stands in for a transposed convolution, similar to the approach used in
    the original U-Net.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding
        )

    def __call__(self, x):
        upsampled = upsample_nearest(x)
        return self.conv(upsampled)


class Encoder(nn.Module):
    """
    Convolutional variational encoder.

    Three stride-2 conv + batch-norm + leaky-ReLU stages shrink the image; the
    flattened features are projected to the mean and log-variance of a
    diagonal Gaussian, from which ``z`` is drawn with the reparameterization
    trick. ``__call__`` returns ``(z, mu, logvar)``.
    """

    def __init__(self, num_latent_dims, image_shape, max_num_filters):
        super().__init__()

        # Channel widths for the three stages (grow as resolution shrinks).
        widths = [max_num_filters // 4, max_num_filters // 2, max_num_filters]
        in_ch = image_shape[-1]

        # For a 64x64 input the outputs are 32x32, 16x16 and 8x8 (BHWC).
        self.conv1 = nn.Conv2d(in_ch, widths[0], 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(widths[0], widths[1], 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(widths[1], widths[2], 3, stride=2, padding=1)

        self.bn1 = nn.BatchNorm(widths[0])
        self.bn2 = nn.BatchNorm(widths[1])
        self.bn3 = nn.BatchNorm(widths[2])

        # Three stride-2 convs divide each spatial dimension by 8 (exact only
        # when the image size is a multiple of 8).
        flattened_dim = widths[2] * math.prod(d // 8 for d in image_shape[:-1])

        # Heads for the mean and the log-variance of q(z | x).
        self.proj_mu = nn.Linear(flattened_dim, num_latent_dims)
        self.proj_log_var = nn.Linear(flattened_dim, num_latent_dims)

    def __call__(self, x):
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, bn in stages:
            x = nn.leaky_relu(bn(conv(x)))
        x = mx.flatten(x, 1)  # keep the batch axis, flatten the rest

        mu = self.proj_mu(x)
        logvar = self.proj_log_var(x)
        # Standard deviation from log-variance.
        sigma = mx.exp(0.5 * logvar)

        # Reparameterization trick so gradients flow through the sampling.
        z = mu + sigma * mx.random.normal(sigma.shape)

        return z, mu, logvar


class Decoder(nn.Module):
    """A convolutional decoder.

    Maps a latent vector to an image: a linear layer produces a feature map at
    1/8 of the image resolution, three ``UpsamplingConv2d`` layers bring it
    back up, and a sigmoid keeps the pixel values in [0, 1].
    """

    def __init__(self, num_latent_dims, image_shape, max_num_filters):
        """
        Args:
            num_latent_dims: Dimensionality of the latent vector.
            image_shape: Image shape as ``(H, W, C)`` (channels last). H and W
                are floor-divided by 8 below, so they are presumably expected
                to be multiples of 8.
            max_num_filters: Channels of the first feature map; the next two
                layers use 1/2 and 1/4 of this value.
        """
        super().__init__()
        self.num_latent_dims = num_latent_dims
        num_img_channels = image_shape[-1]
        self.max_num_filters = max_num_filters

        # decoder layers
        num_filters_1 = max_num_filters
        num_filters_2 = max_num_filters // 2
        num_filters_3 = max_num_filters // 4

        # Divide the spatial dimensions (H, W) by 8 because of the 3
        # upsampling convolutions; the channel count is num_filters_1.
        self.input_shape = [dimension // 8 for dimension in image_shape[:-1]] + [
            num_filters_1
        ]
        flattened_dim = math.prod(self.input_shape)

        # The shape comments below are for a 64 x 64 image.
        # Output: flattened_dim
        self.lin1 = nn.Linear(num_latent_dims, flattened_dim)
        # Output (BHWC):  B x 16 x 16 x num_filters_2
        self.upconv1 = UpsamplingConv2d(
            num_filters_1, num_filters_2, 3, stride=1, padding=1
        )
        # Output (BHWC):  B x 32 x 32 x num_filters_3
        self.upconv2 = UpsamplingConv2d(
            num_filters_2, num_filters_3, 3, stride=1, padding=1
        )
        # Output (BHWC):  B x 64 x 64 x #img_channels
        self.upconv3 = UpsamplingConv2d(
            num_filters_3, num_img_channels, 3, stride=1, padding=1
        )

        # Batch Normalizations
        self.bn1 = nn.BatchNorm(num_filters_2)
        self.bn2 = nn.BatchNorm(num_filters_3)

    def __call__(self, z):
        """Decode latent vectors (last axis of size ``num_latent_dims``) into images."""
        x = self.lin1(z)

        # reshape to BHWC using the shape computed in __init__
        x = x.reshape(-1, *self.input_shape)

        # approximate transposed convolutions with nearest neighbor upsampling
        x = nn.leaky_relu(self.bn1(self.upconv1(x)))
        x = nn.leaky_relu(self.bn2(self.upconv2(x)))
        # sigmoid to ensure pixel values are in [0,1]
        x = mx.sigmoid(self.upconv3(x))
        return x


class CVAE(nn.Module):
    """
    Convolutional variational autoencoder: an ``Encoder`` that maps images to
    latent samples and a ``Decoder`` that maps latent vectors back to images.
    """

    def __init__(self, num_latent_dims, input_shape, max_num_filters):
        super().__init__()
        self.num_latent_dims = num_latent_dims
        self.encoder = Encoder(num_latent_dims, input_shape, max_num_filters)
        self.decoder = Decoder(num_latent_dims, input_shape, max_num_filters)

    def __call__(self, x):
        """Reconstruct ``x``; returns ``(reconstruction, mu, logvar)``."""
        z, mu, logvar = self.encoder(x)
        return self.decode(z), mu, logvar

    def encode(self, x):
        """Return a latent sample for ``x`` (mean and log-variance are dropped)."""
        z, _, _ = self.encoder(x)
        return z

    def decode(self, z):
        """Map latent vectors back to image space."""
        return self.decoder(z)


================================================
FILE: encodec/README.md
================================================
# EnCodec

An example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and
generate audio.

### Setup

Install the requirements:

```
pip install -r requirements.txt
```

Optionally install FFmpeg and SciPy for loading and saving audio files,
respectively.

Install [FFmpeg](https://ffmpeg.org/):

```
# on macOS using Homebrew (https://brew.sh/)
brew install ffmpeg
```

Install SciPy:

```
pip install scipy
```

### Example

An example using the model:

```python
import mlx.core as mx
from encodec import EncodecModel
from utils import load_audio, save_audio

# Load the 48 kHz model and preprocessor.
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")

# Load an audio file
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)

# Preprocess the audio (this can also be a list of arrays for batched
# processing).
feats, mask = processor(audio)

# Encode at the given bandwidth. A lower bandwidth results in more
# compression but lower reconstruction quality.
@mx.compile
def encode(feats, mask):
    return model.encode(feats, mask, bandwidth=3)

# Decode to reconstruct the audio
@mx.compile
def decode(codes, scales, mask):
    return model.decode(codes, scales, mask)


codes, scales = encode(feats, mask)
reconstructed = decode(codes, scales, mask)

# Trim any padding:
reconstructed = reconstructed[0, : len(audio)]

# Save the audio as a wave file
save_audio("reconstructed.wav", reconstructed, model.sampling_rate)
```

The 24 kHz, 32 kHz, and 48 kHz MLX-formatted models are available in the
[Hugging Face MLX Community](https://huggingface.co/collections/mlx-community/encodec-66e62334038300b07a43b164)
in several data types.

### Optional

To convert models, use the `convert.py` script. To see the options, run:

```bash
python convert.py -h
```

[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2210.13438) and
  [code](https://github.com/facebookresearch/encodec) for more details.


================================================
FILE: encodec/benchmarks/bench_mx.py
================================================
# Copyright © 2024 Apple Inc.

import time

import mlx.core as mx

from encodec import EncodecModel

# Untimed warm-up iterations (these also trigger compilation) and timed
# iterations.
NUM_WARMUP = 5
NUM_ITERS = 10

model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")

# Random input of shape (288000, 2): 6 s of two-channel audio at 48 kHz.
audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)
mx.eval(model, feats, mask)


@mx.compile
def fun():
    """Encode at bandwidth 3 and decode back to audio."""
    codes, scales = model.encode(feats, mask, bandwidth=3)
    reconstructed = model.decode(codes, scales, mask)
    return reconstructed


for _ in range(NUM_WARMUP):
    mx.eval(fun())

# perf_counter is monotonic and high resolution; time.time follows the wall
# clock and can jump if the system time is adjusted during the run.
tic = time.perf_counter()
for _ in range(NUM_ITERS):
    mx.eval(fun())
toc = time.perf_counter()
ms = 1000 * (toc - tic) / NUM_ITERS
print(f"Time per it: {ms:.3f}")


================================================
FILE: encodec/benchmarks/bench_pt.py
================================================
# Copyright © 2024 Apple Inc.

import time

import numpy as np
import torch
from transformers import AutoProcessor, EncodecModel

# Untimed warm-up iterations and timed iterations.
NUM_WARMUP = 5
NUM_ITERS = 10

processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
# Random input of shape (2, 288000), matching the MLX benchmark's audio length.
audio = np.random.uniform(size=(2, 288000)).astype(np.float32)

pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz").to("mps")
pt_inputs = processor(
    raw_audio=audio, sampling_rate=processor.sampling_rate, return_tensors="pt"
).to("mps")


def fun():
    """Encode and decode once, then wait for the MPS device to finish."""
    # Inference only: skip autograd bookkeeping so the timing reflects the
    # forward pass (the MLX benchmark builds no gradient graph either).
    with torch.no_grad():
        pt_encoded = pt_model.encode(
            pt_inputs["input_values"], pt_inputs["padding_mask"]
        )
        pt_model.decode(
            pt_encoded.audio_codes, pt_encoded.audio_scales, pt_inputs["padding_mask"]
        )
    # MPS execution is asynchronous; synchronize so the GPU work is timed.
    torch.mps.synchronize()


for _ in range(NUM_WARMUP):
    fun()

# perf_counter is monotonic and high resolution, unlike time.time.
tic = time.perf_counter()
for _ in range(NUM_ITERS):
    fun()
toc = time.perf_counter()
ms = 1000 * (toc - tic) / NUM_ITERS
print(f"Time per it: {ms:.3f}")


================================================
FILE: encodec/convert.py
================================================
# Copyright © 2024 Apple Inc.

import argparse
import json
from pathlib import Path
from textwrap import dedent
from types import SimpleNamespace
from typing import Any, Dict, Union

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download

import encodec


def fetch_from_hub(hf_repo: str) -> Path:
    """Download the ``*.json`` and ``*.safetensors`` files of ``hf_repo``.

    Returns:
        Path to the local snapshot directory.
    """
    local_dir = snapshot_download(
        repo_id=hf_repo,
        allow_patterns=["*.json", "*.safetensors"],
    )
    return Path(local_dir)


def upload_to_hub(path: str, upload_repo: str, hf_path: str):
    """
    Write a model card into ``path`` and upload the folder to the Hugging Face hub.

    Args:
        path (str): Local directory holding the converted model.
        upload_repo (str): Name of the HF repo to upload to; it is created if
            it does not exist yet.
        hf_path (str): Name of the original Hugging Face model, linked from
            the model card.
    """
    import os

    from huggingface_hub import HfApi, ModelCard, logging

    readme_text = dedent(
        f"""
        ---
        language: en
        license: other
        library: mlx
        tags:
        - mlx
        ---

        The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
        converted to MLX format from
        [{hf_path}](https://huggingface.co/{hf_path}).

        This model is intended to be used with the [EnCodec MLX
        example](https://github.com/ml-explore/mlx-examples/tree/main/encodec).
        """
    )
    ModelCard(readme_text).save(os.path.join(path, "README.md"))

    logging.set_verbosity_info()

    hub = HfApi()
    hub.create_repo(repo_id=upload_repo, exist_ok=True)
    # NOTE(review): ``multi_commits`` / ``multi_commits_verbose`` appear to be
    # deprecated in newer huggingface_hub releases; confirm against the
    # installed version.
    hub.upload_folder(
        folder_path=path,
        repo_id=upload_repo,
        repo_type="model",
        multi_commits=True,
        multi_commits_verbose=True,
    )
    print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")


def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -> None:
    """Write ``weights`` to ``save_path/model.safetensors`` plus an index file.

    The directory is created if needed. The index
    (``model.safetensors.index.json``) stores the summed byte size of all
    arrays and maps every weight name, in sorted order, to the single
    ``model.safetensors`` shard.
    """
    out_dir = Path(save_path) if isinstance(save_path, str) else save_path
    out_dir.mkdir(parents=True, exist_ok=True)

    shard_name = "model.safetensors"
    mx.save_safetensors(
        str(out_dir / shard_name), weights, metadata={"format": "mlx"}
    )

    index = {
        "metadata": {"total_size": sum(arr.nbytes for arr in weights.values())},
        "weight_map": {name: shard_name for name in sorted(weights)},
    }
    with open(out_dir / "model.safetensors.index.json", "w") as fp:
        json.dump(index, fp, indent=4)


def save_config(
    config: dict,
    config_path: Union[str, Path],
) -> None:
    """Save the model configuration to the ``config_path``.

    The ``_name_or_path`` key is dropped and the remaining keys are sorted
    for better readability. The caller's ``config`` is left unmodified.

    Args:
        config (dict): The model configuration.
        config_path (Union[str, Path]): Model configuration file path.
    """
    # Build a new dict instead of popping from ``config``: callers may pass
    # an object's live ``__dict__`` (e.g. ``vars(namespace)``), and popping
    # would silently delete that attribute from the object.
    config = {k: v for k, v in sorted(config.items()) if k != "_name_or_path"}

    # write the cleaned, sorted config to the config_path
    with open(config_path, "w") as fid:
        json.dump(config, fid, indent=4)


def convert(
    upload: bool,
    model: str,
    dtype: str = None,
):
    """Convert a Hugging Face EnCodec checkpoint to MLX format.

    Downloads ``facebook/encodec_{model}``, renames/reshapes its weights into
    the layout of ``encodec.EncodecModel``, loads them into that model,
    optionally casts them, and writes weights and config to ``mlx_models/``.

    Args:
        upload (bool): If ``True``, also upload the result to
            ``mlx-community/encodec-{model}-{dtype}``.
        model (str): Model variant suffix, e.g. ``"48khz"``.
        dtype (str, optional): Name of an ``mlx.core`` dtype to cast the
            weights to; if ``None`` the weights are saved as converted.
    """
    hf_repo = f"facebook/encodec_{model}"
    # NOTE(review): with dtype=None this repo name ends in "-None".
    mlx_repo = f"mlx-community/encodec-{model}-{dtype}"
    path = fetch_from_hub(hf_repo)
    save_path = Path("mlx_models")

    weights = mx.load(str(Path(path) / "model.safetensors"))

    with open(path / "config.json", "r") as fid:
        config = SimpleNamespace(**json.load(fid))

    # Rebinds ``model`` from the variant string to the MLX module; the string
    # is not used after this point.
    model = encodec.EncodecModel(config)

    new_weights = {}
    for k, v in weights.items():
        basename, pname = k.rsplit(".", 1)
        if pname == "weight_v":
            # Fold the weight-norm parametrization into a plain weight:
            # weight = g * v / ||v||, with the norm taken over axes 1 and 2.
            g = weights[basename + ".weight_g"]
            v = g * (v / mx.linalg.norm(v, axis=(1, 2), keepdims=True))
            k = basename + ".weight"
        elif pname in ["weight_g", "embed_avg", "cluster_size", "inited"]:
            # weight_g is consumed with weight_v above; the other entries have
            # no counterpart in the MLX codebook (which only holds ``embed``).
            continue
        elif "lstm" in basename:
            # Parameter names are presumably PyTorch-style, e.g.
            # "weight_ih_l0": kind, input/hidden, "l" + layer index.
            w_or_b, ih_or_hh, ln = pname.split("_")
            if w_or_b == "weight":
                new_pname = "Wx" if ih_or_hh == "ih" else "Wh"
            elif w_or_b == "bias" and ih_or_hh == "ih":
                # Skipped here; folded into the hidden bias in the branch below.
                continue
            else:
                # The MLX LSTM has a single bias: sum the input and hidden biases.
                v = v + weights[k.replace("_hh_", "_ih_")]
                new_pname = "bias"
            # "<basename>.weight_ih_l0" -> "<basename>.0.Wx"
            k = basename + "." + ln[1:] + "." + new_pname
        if "conv.weight" in k:
            # Possibly a transposed conv which has a different order
            if "decoder" in k:
                # Index of the decoder layer, assuming names like
                # "decoder.layers.<i>...." — TODO confirm for all checkpoints.
                ln = int(k.split(".")[2])
                if "conv" in model.decoder.layers[ln] and isinstance(
                    model.decoder.layers[ln].conv, nn.ConvTranspose1d
                ):
                    # Assuming PyTorch (in, out, k) layout -> MLX (out, k, in).
                    v = mx.moveaxis(v, 0, 2)
                else:
                    # Assuming PyTorch (out, in, k) layout -> MLX (out, k, in).
                    v = mx.moveaxis(v, 1, 2)
            else:
                v = mx.moveaxis(v, 1, 2)

        new_weights[k] = v
    weights = new_weights

    # Load into the model so mismatched names/shapes surface here (relies on
    # load_weights' default strict checking).
    model.load_weights(list(weights.items()))

    if dtype is not None:
        t = getattr(mx, dtype)
        weights = {k: v.astype(t) for k, v in weights.items()}

    # save_path is always a Path at this point; this check is defensive only.
    if isinstance(save_path, str):
        save_path = Path(save_path)

    save_weights(save_path, weights)

    save_config(vars(config), config_path=save_path / "config.json")

    if upload:
        upload_to_hub(save_path, mlx_repo, hf_repo)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert EnCodec weights to MLX.")
    parser.add_argument(
        "--model",
        type=str,
        default="48khz",
        help="Sampling-rate variant of the EnCodec model to convert.",
        choices=["24khz", "32khz", "48khz"],
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help="Upload the weights to Hugging Face.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        help="Data type to convert the model to.",
        default="float32",
        choices=["float32", "bfloat16", "float16"],
    )
    args = parser.parse_args()
    convert(upload=args.upload, model=args.model, dtype=args.dtype)


================================================
FILE: encodec/encodec.py
================================================
# Copyright © 2024 Apple Inc.

import functools
import json
import math
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

# Metal kernel computing one LSTM time step. Each thread produces hidden unit
# j (thread_position_in_grid.y) of batch element b (thread_position_in_grid.x).
#
# Inputs, indexed as dense row-major arrays (shapes as built by LSTM.__call__):
#   x     -- input projections for all steps, (B, num_time_steps, 4 * hidden_size)
#   h_in  -- hidden-state projection for this step, (B, 4 * hidden_size)
#   cell  -- previous cell state, (B, hidden_size)
# Along the 4 * hidden_size axis the gates are laid out as i, f, g, o.
# Outputs hidden_state and cell_state are (B, hidden_size).
#
# Each array gets its own row offset for batch element b: a single shared
# offset of b * 4 * hidden_size is only correct when B == 1.
#
# The header's sigmoid evaluates exp(-|x|) and reflects for negative x, so
# the exponential never overflows.
_lstm_kernel = mx.fast.metal_kernel(
    name="lstm",
    input_names=["x", "h_in", "cell", "hidden_size", "time_step", "num_time_steps"],
    output_names=["hidden_state", "cell_state"],
    header="""
    template <typename T>
    T sigmoid(T x) {
        auto y = 1 / (1 + metal::exp(-metal::abs(x)));
        return (x < 0) ? 1 - y : y;
    }
    """,
    source="""
        uint b = thread_position_in_grid.x;
        uint j = thread_position_in_grid.y;
        uint d = hidden_size * 4;

        // Offsets of (b, j) in cell/outputs, in h_in, and in x at time_step.
        uint elem = b * hidden_size + j;
        uint index = b * d + j;
        uint x_index = b * num_time_steps * d + time_step * d + j;

        auto i = sigmoid(h_in[index] + x[x_index]);
        index += hidden_size;
        x_index += hidden_size;
        auto f = sigmoid(h_in[index] + x[x_index]);
        index += hidden_size;
        x_index += hidden_size;
        auto g = metal::precise::tanh(h_in[index] + x[x_index]);
        index += hidden_size;
        x_index += hidden_size;
        auto o = sigmoid(h_in[index] + x[x_index]);

        cell_state[elem] = f * cell[elem] + i * g;
        hidden_state[elem] = o * metal::precise::tanh(cell_state[elem]);
    """,
)


def lstm_custom(x, h_in, cell, time_step):
    """Run one LSTM time step with the Metal kernel.

    Args:
        x: Input projections (bias included) for all time steps,
            ``(B, T, 4 * hidden_size)``.
        h_in: Hidden-state projection for this step, ``(B, 4 * hidden_size)``.
        cell: Previous cell state, ``(B, hidden_size)``.
        time_step: Index of the time step to compute.

    Returns:
        ``(hidden_state, cell_state)``, both shaped like ``cell``.
    """
    assert x.ndim == 3, "Input to LSTM must have 3 dimensions."
    out_shape = cell.shape
    return _lstm_kernel(
        inputs=[x, h_in, cell, out_shape[-1], time_step, x.shape[-2]],
        output_shapes=[out_shape, out_shape],
        output_dtypes=[h_in.dtype, h_in.dtype],
        # One thread per (batch element, hidden unit).
        grid=(x.shape[0], out_shape[-1], 1),
        threadgroup=(256, 1, 1),
    )


class LSTM(nn.Module):
    """Single-layer LSTM whose recurrence runs in a custom Metal kernel.

    Parameters: ``Wx`` ``(4 * hidden_size, input_size)``, ``Wh``
    ``(4 * hidden_size, hidden_size)`` and an optional single ``bias``
    ``(4 * hidden_size,)``. All are zero-initialized here; real values come
    from loaded weights.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.Wx = mx.zeros((4 * hidden_size, input_size))
        self.Wh = mx.zeros((4 * hidden_size, hidden_size))
        self.bias = mx.zeros((4 * hidden_size,)) if bias else None

    def __call__(self, x, hidden=None, cell=None):
        """Run the LSTM over axis -2 (time) of ``x``.

        Args:
            x: Input of shape ``(B, T, input_size)``.
            hidden: Optional initial hidden state ``(B, hidden_size)``.
            cell: Optional initial cell state ``(B, hidden_size)``.

        Returns:
            The hidden state of every step, ``(B, T, hidden_size)``.
        """
        # Project the inputs of all time steps at once.
        if self.bias is not None:
            x = mx.addmm(self.bias, x, self.Wx.T)
        else:
            x = x @ self.Wx.T

        all_hidden = []

        B = x.shape[0]
        # ``cell or ...`` would call bool() on an array, which fails for a
        # multi-element array, so compare against None explicitly.
        if cell is None:
            cell = mx.zeros((B, self.hidden_size), x.dtype)
        for t in range(x.shape[-2]):
            if hidden is None:
                # Zero initial state: its projection h @ Wh.T is zero too.
                hidden = mx.zeros((B, self.hidden_size * 4), x.dtype)
            else:
                hidden = hidden @ self.Wh.T
            hidden, cell = lstm_custom(x, hidden, cell, t)
            all_hidden.append(hidden)

        return mx.stack(all_hidden, axis=-2)


class EncodecConv1d(nn.Module):
    """Conv1d with asymmetric or causal padding and normalization.

    Operates on channels-last arrays with the sequence length on axis 1.
    """

    def __init__(
        self,
        config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.pad_mode = config.pad_mode
        self.norm_type = config.norm_type

        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, stride, dilation=dilation
        )
        if self.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)

        self.stride = stride

        # Effective kernel size with dilations.
        self.kernel_size = (kernel_size - 1) * dilation + 1

        # Base the total padding on the effective (dilated) kernel size so a
        # stride-1 conv keeps the sequence length even with dilation > 1 (the
        # residual add in EncodecResnetBlock relies on that). With
        # dilation == 1 this equals kernel_size - stride.
        self.padding_total = self.kernel_size - stride

    def _get_extra_padding_for_conv1d(
        self,
        hidden_states: mx.array,
    ) -> mx.array:
        """Extra right padding so the last window is complete."""
        length = hidden_states.shape[1]
        n_frames = (length - self.kernel_size + self.padding_total) / self.stride + 1
        n_frames = int(math.ceil(n_frames)) - 1
        ideal_length = n_frames * self.stride + self.kernel_size - self.padding_total
        return ideal_length - length

    def _pad1d(
        self,
        hidden_states: mx.array,
        paddings: Tuple[int, int],
        mode: str = "zero",
        value: float = 0.0,
    ):
        """Pad axis 1 by ``paddings = (left, right)`` with reflection or a constant."""
        if mode != "reflect":
            # Pad only the length axis; mx.pad would apply a bare
            # (left, right) pair to every axis, including batch and channels.
            return mx.pad(
                hidden_states,
                [(0, 0), paddings, (0, 0)],
                mode="constant",
                constant_values=value,
            )

        length = hidden_states.shape[1]
        prefix = hidden_states[:, 1 : paddings[0] + 1][:, ::-1]
        suffix = hidden_states[:, max(length - (paddings[1] + 1), 0) : -1][:, ::-1]
        return mx.concatenate([prefix, hidden_states, suffix], axis=1)

    def __call__(self, hidden_states):
        extra_padding = self._get_extra_padding_for_conv1d(hidden_states)

        if self.causal:
            # Left padding for causal
            hidden_states = self._pad1d(
                hidden_states, (self.padding_total, extra_padding), mode=self.pad_mode
            )
        else:
            # Asymmetric padding required for odd strides
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right
            hidden_states = self._pad1d(
                hidden_states,
                (padding_left, padding_right + extra_padding),
                mode=self.pad_mode,
            )

        hidden_states = self.conv(hidden_states)

        if self.norm_type == "time_group_norm":
            hidden_states = self.norm(hidden_states)

        return hidden_states


class EncodecConvTranspose1d(nn.Module):
    """ConvTranspose1d with asymmetric or causal padding and normalization.

    After the transposed convolution (and optional group norm),
    ``kernel_size - stride`` frames are trimmed from axis 1. For causal models
    ``ceil(total * trim_right_ratio)`` come off the right end; otherwise the
    right end loses ``total // 2``. The rest is trimmed from the left.
    """

    def __init__(
        self,
        config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
    ):
        super().__init__()
        self.causal = config.use_causal_conv
        self.trim_right_ratio = config.trim_right_ratio
        self.norm_type = config.norm_type
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride)
        if self.norm_type == "time_group_norm":
            self.norm = nn.GroupNorm(1, out_channels, pytorch_compatible=True)
        self.padding_total = kernel_size - stride

    def __call__(self, hidden_states):
        out = self.conv(hidden_states)
        if self.norm_type == "time_group_norm":
            out = self.norm(out)

        total = self.padding_total
        if self.causal:
            trim_right = math.ceil(total * self.trim_right_ratio)
        else:
            trim_right = total // 2
        trim_left = total - trim_right

        return out[:, trim_left : out.shape[1] - trim_right, :]


class EncodecLSTM(nn.Module):
    """Stack of ``config.num_lstm_layers`` LSTMs wrapped in a residual connection.

    Every layer maps ``dimension`` features to ``dimension`` features, and the
    stack's output is added to its input.
    """

    def __init__(self, config, dimension):
        super().__init__()
        self.lstm = [
            LSTM(dimension, dimension) for _ in range(config.num_lstm_layers)
        ]

    def __call__(self, hidden_states):
        out = hidden_states
        for layer in self.lstm:
            out = layer(out)
        # Residual connection around the whole stack.
        return out + hidden_states


class EncodecResnetBlock(nn.Module):
    """
    Residual block from SEANet model as used by EnCodec.

    Two (ELU, conv) pairs with kernel sizes ``config.residual_kernel_size``
    and 1, and channels ``dim -> dim // config.compress -> dim``, plus a
    shortcut: a 1-wide conv when ``use_conv_shortcut`` (default True),
    otherwise the identity.
    """

    def __init__(self, config, dim: int, dilations: List[int]):
        super().__init__()
        kernel_sizes = (config.residual_kernel_size, 1)
        if len(dilations) != len(kernel_sizes):
            raise ValueError("Number of kernel sizes should match number of dilations")

        hidden = dim // config.compress
        last = len(kernel_sizes) - 1
        layers = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            in_chs = hidden if i > 0 else dim
            out_chs = hidden if i < last else dim
            layers.append(nn.ELU())
            layers.append(
                EncodecConv1d(config, in_chs, out_chs, kernel_size, dilation=dilation)
            )
        self.block = layers

        if getattr(config, "use_conv_shortcut", True):
            self.shortcut = EncodecConv1d(config, dim, dim, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def __call__(self, hidden_states):
        out = hidden_states
        for layer in self.block:
            out = layer(out)
        return self.shortcut(hidden_states) + out


class EncodecEncoder(nn.Module):
    """SEANet encoder as used by EnCodec.

    Input conv, then for each ratio in ``reversed(upsampling_ratios)``:
    residual blocks, ELU and a strided conv that doubles the channels. Then
    an LSTM stack, ELU and an output conv to ``hidden_size`` channels.
    """

    def __init__(self, config):
        super().__init__()
        layers = [
            EncodecConv1d(
                config, config.audio_channels, config.num_filters, config.kernel_size
            )
        ]

        channels = config.num_filters
        for ratio in reversed(config.upsampling_ratios):
            for j in range(config.num_residual_layers):
                layers.append(
                    EncodecResnetBlock(
                        config, channels, [config.dilation_growth_rate**j, 1]
                    )
                )
            layers.append(nn.ELU())
            layers.append(
                EncodecConv1d(
                    config,
                    channels,
                    channels * 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                )
            )
            channels *= 2

        layers.append(EncodecLSTM(config, channels))
        layers.append(nn.ELU())
        layers.append(
            EncodecConv1d(
                config,
                channels,
                config.hidden_size,
                config.last_kernel_size,
            )
        )

        self.layers = layers

    def __call__(self, hidden_states):
        x = hidden_states
        for layer in self.layers:
            x = layer(x)
        return x


class EncodecDecoder(nn.Module):
    """SEANet decoder as used by EnCodec.

    Input conv and an LSTM stack at ``num_filters * 2**len(upsampling_ratios)``
    channels. Then, for each ratio, ELU, a transposed conv that halves the
    channels, and residual blocks. Finally ELU and an output conv to
    ``audio_channels``.
    """

    def __init__(self, config):
        super().__init__()
        channels = int(2 ** len(config.upsampling_ratios)) * config.num_filters
        layers = [
            EncodecConv1d(
                config,
                config.hidden_size,
                channels,
                config.kernel_size,
            ),
            EncodecLSTM(config, channels),
        ]

        for ratio in config.upsampling_ratios:
            layers.append(nn.ELU())
            layers.append(
                EncodecConvTranspose1d(
                    config,
                    channels,
                    channels // 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                )
            )
            channels //= 2
            for j in range(config.num_residual_layers):
                layers.append(
                    EncodecResnetBlock(
                        config, channels, (config.dilation_growth_rate**j, 1)
                    )
                )

        layers.append(nn.ELU())
        layers.append(
            EncodecConv1d(
                config,
                config.num_filters,
                config.audio_channels,
                config.last_kernel_size,
            )
        )
        self.layers = layers

    def __call__(self, hidden_states):
        x = hidden_states
        for layer in self.layers:
            x = layer(x)
        return x


class EncodecEuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    ``embed`` holds ``codebook_size`` code vectors of length ``codebook_dim``.
    """

    def __init__(self, config):
        super().__init__()
        self.embed = mx.zeros((config.codebook_size, config.codebook_dim))

    def quantize(self, hidden_states):
        """Index of the nearest code for every row of 2-D ``hidden_states``."""
        codes_t = self.embed.T
        # -(||x||^2 - 2 x.e + ||e||^2): the largest value is the nearest code.
        x_sq = hidden_states.square().sum(axis=1, keepdims=True)
        e_sq = codes_t.square().sum(axis=0, keepdims=True)
        neg_dist = -(x_sq - 2 * hidden_states @ codes_t + e_sq)
        return neg_dist.argmax(axis=-1)

    def encode(self, hidden_states):
        """Map ``(..., codebook_dim)`` vectors to code indices of shape ``(...)``."""
        lead_shape = hidden_states.shape[:-1]
        flat = hidden_states.reshape((-1, hidden_states.shape[-1]))
        return self.quantize(flat).reshape(*lead_shape)

    def decode(self, embed_ind):
        """Look up the code vectors for ``embed_ind``."""
        return self.embed[embed_ind]


class EncodecVectorQuantization(nn.Module):
    """
    Vector quantization implementation. Currently supports only euclidean distance.

    Forwards ``encode``/``decode`` to an ``EncodecEuclideanCodebook`` stored
    as ``codebook``.
    """

    def __init__(self, config):
        super().__init__()
        self.codebook = EncodecEuclideanCodebook(config)

    def encode(self, hidden_states):
        """Return the nearest-code indices for ``hidden_states``."""
        book = self.codebook
        return book.encode(hidden_states)

    def decode(self, embed_ind):
        """Return the code vectors for ``embed_ind``."""
        book = self.codebook
        return book.decode(embed_ind)


class EncodecResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer.

    Each layer quantizes the residual left over by the previous layers; the
    decoded representation is the sum of all layers' code vectors.
    """

    def __init__(self, config):
        super().__init__()
        self.codebook_size = config.codebook_size

        # Encoder frames per second: the sampling rate divided by the product
        # of the encoder's downsampling strides.
        hop_length = np.prod(config.upsampling_ratios)
        self.frame_rate = math.ceil(config.sampling_rate / hop_length)
        # NOTE(review): the factor 10 presumably stands for log2 of a
        # 1024-entry codebook (bits per code) — confirm.
        self.num_quantizers = int(
            1000 * config.target_bandwidths[-1] // (self.frame_rate * 10)
        )
        self.layers = [
            EncodecVectorQuantization(config) for _ in range(self.num_quantizers)
        ]

    def get_num_quantizers_for_bandwidth(
        self, bandwidth: Optional[float] = None
    ) -> int:
        """Return num_quantizers based on specified target bandwidth.

        Without a positive ``bandwidth`` all quantizers are used; otherwise at
        least one, and as many as fit in ``bandwidth * 1000`` bits per second.
        """
        bw_per_quantizer = math.log2(self.codebook_size) * self.frame_rate
        if bandwidth is None or not bandwidth > 0.0:
            return self.num_quantizers
        return int(max(1, math.floor(bandwidth * 1000 / bw_per_quantizer)))

    def encode(
        self, embeddings: mx.array, bandwidth: Optional[float] = None
    ) -> mx.array:
        """
        Encode ``embeddings`` at the given bandwidth.

        Uses as many quantizers as the bandwidth allows; each one encodes the
        residual of the previous ones. The per-quantizer indices are stacked
        on axis 1.
        """
        n_q = self.get_num_quantizers_for_bandwidth(bandwidth)
        residual = embeddings
        indices_per_layer = []
        for layer in self.layers[:n_q]:
            idx = layer.encode(residual)
            residual = residual - layer.decode(idx)
            indices_per_layer.append(idx)
        return mx.stack(indices_per_layer, axis=1)

    def decode(self, codes: mx.array) -> mx.array:
        """Decode ``codes`` (quantizer index on axis 1) to the quantized representation."""
        total = None
        for i, piece in enumerate(codes.split(codes.shape[1], axis=1)):
            vectors = self.layers[i].decode(piece.squeeze(1))
            total = vectors if total is None else vectors + total
        return total


class EncodecModel(nn.Module):
    """
    EnCodec neural audio codec: an encoder, a residual vector quantizer and a
    decoder. Audio arrays are channels-last, i.e. ``(batch_size,
    sequence_length, channels)``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder = EncodecEncoder(config)
        self.decoder = EncodecDecoder(config)
        self.quantizer = EncodecResidualVectorQuantizer(config)

    def _encode_frame(
        self, input_values: mx.array, bandwidth: float, padding_mask: mx.array
    ) -> Tuple[mx.array, Optional[mx.array]]:
        """
        Encode a single chunk of audio into discrete codes.

        Args:
            input_values (mx.array): Chunk of shape
                ``(batch_size, chunk_length, channels)``.
            bandwidth (float): Target bandwidth in kbps.
            padding_mask (mx.array): Mask of shape
                ``(batch_size, chunk_length)``. When ``config.normalize`` is
                set, masked-out samples are zeroed before computing the scale.

        Returns:
            A tuple ``(codes, scale)``. ``scale`` is the RMS factor the chunk
            was divided by when ``config.normalize`` is set, otherwise ``None``.

        Raises:
            RuntimeError: If the chunk is longer than ``config.chunk_length_s``.
        """
        length = input_values.shape[1]
        duration = length / self.config.sampling_rate

        if (
            self.config.chunk_length_s is not None
            and duration > 1e-5 + self.config.chunk_length_s
        ):
            raise RuntimeError(
                f"Duration of frame ({duration}) is longer than chunk {self.config.chunk_length_s}"
            )

        scale = None
        if self.config.normalize:
            # Zero out padded samples so they do not contribute to the scale.
            input_values = input_values * padding_mask[..., None]
            # Average the channels into a mono signal and take its RMS over
            # time; the epsilon keeps the division below finite for silence.
            mono = mx.sum(input_values, axis=2, keepdims=True) / input_values.shape[2]
            scale = mono.square().mean(axis=1, keepdims=True).sqrt() + 1e-8
            input_values = input_values / scale

        embeddings = self.encoder(input_values)
        codes = self.quantizer.encode(embeddings, bandwidth)
        return codes, scale

    def encode(
        self,
        input_values: mx.array,
        padding_mask: Optional[mx.array] = None,
        bandwidth: Optional[float] = None,
    ) -> Tuple[mx.array, List[Optional[mx.array]]]:
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (mx.array): The input audio waveform with shape
                ``(batch_size, sequence_length, channels)``.
            padding_mask (mx.array, optional): Mask of shape
                ``(batch_size, sequence_length)`` marking the valid
                (non-padded) samples. Defaults to all ones.
            bandwidth (float, optional): The target bandwidth in kbps, e.g.
                ``6.0`` for 6 kbps. Must be one of ``config.target_bandwidths``.
                If ``None``, the first entry of ``config.target_bandwidths`` is
                used.

        Returns:
            A tuple ``(codes, scales)``. ``codes`` has shape ``(num_chunks,
            batch_size, num_codebooks, frames)``. ``scales`` is a list with one
            entry per chunk: the rescaling factor when ``config.normalize`` is
            set, otherwise ``None``.

        Raises:
            ValueError: If the bandwidth is not supported, the number of
                channels is not 1 or 2, or the input length is not padded
                for chunked encoding (see ``preprocess_audio``).
        """

        if bandwidth is None:
            bandwidth = self.config.target_bandwidths[0]
        if bandwidth not in self.config.target_bandwidths:
            raise ValueError(
                f"This model doesn't support the bandwidth {bandwidth}. "
                f"Select one of {self.config.target_bandwidths}."
            )

        _, input_length, channels = input_values.shape

        if channels < 1 or channels > 2:
            raise ValueError(
                f"Number of audio channels must be 1 or 2, but got {channels}"
            )

        chunk_length = self.chunk_length
        if chunk_length is None:
            # No chunking: the whole input is encoded as a single chunk.
            chunk_length = input_length
            stride = input_length
        else:
            stride = self.chunk_stride

        if padding_mask is None:
            padding_mask = mx.ones(input_values.shape[:2], dtype=mx.bool_)
        encoded_frames = []
        scales = []

        # Consecutive chunks overlap by ``step`` samples. The input must be
        # padded so that the last chunk is complete.
        step = chunk_length - stride
        if (input_length % stride) != step:
            raise ValueError(
                "The input length is not properly padded for batched chunked "
                "encoding. Make sure to pad the input correctly."
            )

        for offset in range(0, input_length - step, stride):
            mask = padding_mask[:, offset : offset + chunk_length].astype(mx.bool_)
            frame = input_values[:, offset : offset + chunk_length]
            encoded_frame, scale = self._encode_frame(frame, bandwidth, mask)
            encoded_frames.append(encoded_frame)
            scales.append(scale)

        # Chunks are stacked along a new leading axis (axis 0).
        encoded_frames = mx.stack(encoded_frames)

        return (encoded_frames, scales)

    @staticmethod
    def _linear_overlap_add(frames: List[mx.array], stride: int):
        """
        Blend decoded chunks placed ``stride`` samples apart.

        Overlapping regions are combined as a weighted average using a
        triangular window that peaks in the middle of each chunk and falls off
        linearly towards its edges.

        Raises:
            ValueError: If ``frames`` is empty.
        """
        if len(frames) == 0:
            raise ValueError("`frames` cannot be an empty list.")

        dtype = frames[0].dtype
        N, frame_length, C = frames[0].shape
        total_size = stride * (len(frames) - 1) + frames[-1].shape[1]

        # Drop the endpoints of the linspace so no weight is exactly zero.
        time_vec = mx.linspace(0, 1, frame_length + 2, dtype=dtype)[1:-1]
        weight = 0.5 - (time_vec - 0.5).abs()

        weight = weight[:, None]
        sum_weight = mx.zeros((total_size, 1), dtype=dtype)
        out = mx.zeros((N, total_size, C), dtype=dtype)
        offset = 0

        for frame in frames:
            frame_length = frame.shape[1]
            out[:, offset : offset + frame_length] += weight[:frame_length] * frame
            sum_weight[offset : offset + frame_length] += weight[:frame_length]
            offset += stride

        return out / sum_weight

    def _decode_frame(
        self, codes: mx.array, scale: Optional[mx.array] = None
    ) -> mx.array:
        """Decode the codes of one chunk, undoing normalization if ``scale`` is set."""
        embeddings = self.quantizer.decode(codes)
        outputs = self.decoder(embeddings)
        if scale is not None:
            outputs = outputs * scale
        return outputs

    @property
    def channels(self):
        """Number of audio channels the model expects."""
        return self.config.audio_channels

    @property
    def sampling_rate(self):
        """Sampling rate of the model in Hz."""
        return self.config.sampling_rate

    @property
    def chunk_length(self):
        """Chunk length in samples, or ``None`` if the model is not chunked."""
        if self.config.chunk_length_s is None:
            return None
        else:
            return int(self.config.chunk_length_s * self.config.sampling_rate)

    @property
    def chunk_stride(self):
        """Distance in samples between chunk starts, or ``None`` if not chunked."""
        if self.config.chunk_length_s is None or self.config.overlap is None:
            return None
        else:
            return max(1, int((1.0 - self.config.overlap) * self.chunk_length))

    def decode(
        self,
        audio_codes: mx.array,
        audio_scales: Union[mx.array, List[mx.array]],
        padding_mask: Optional[mx.array] = None,
    ) -> mx.array:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that
        case, any extra steps at the end should be trimmed.

        Args:
            audio_codes (mx.array): Discrete codes of shape ``(num_chunks,
                batch_size, num_codebooks, frames)``, as returned by
                :meth:`encode`.
            audio_scales (list): Per-chunk scaling factors as returned by
                :meth:`encode`; entries may be ``None``.
            padding_mask (mx.array, optional): Padding mask of shape
                ``(batch_size, sequence_length)``. If given, the output is
                truncated to its length.

        Returns:
            The channels-last waveform ``(batch_size, length, channels)``.

        Raises:
            ValueError: If the model is not chunked and the codes contain more
                than one chunk.
        """
        chunk_length = self.chunk_length
        if chunk_length is None:
            # Without chunking ``encode`` yields exactly one chunk on axis 0;
            # decode it for the whole batch.
            if audio_codes.shape[0] != 1:
                raise ValueError(f"Expected one frame, got {len(audio_codes)}")
            audio_values = self._decode_frame(audio_codes[0], audio_scales[0])
        else:
            # Iterating ``audio_codes`` walks axis 0, i.e. one chunk at a time.
            decoded_frames = [
                self._decode_frame(frame, scale)
                for frame, scale in zip(audio_codes, audio_scales)
            ]

            audio_values = self._linear_overlap_add(
                decoded_frames, self.chunk_stride or 1
            )

        # truncate based on padding mask
        if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
            audio_values = audio_values[:, : padding_mask.shape[1]]
        return audio_values

    @classmethod
    def from_pretrained(cls, path_or_repo: str):
        """
        Load a model and its matching preprocessing function.

        Args:
            path_or_repo (str): A local directory containing ``config.json``
                and ``model.safetensors``, or a Hugging Face Hub repo id to
                download them from.

        Returns:
            A tuple ``(model, processor)``. ``processor`` is
            :func:`preprocess_audio` bound to the model's sampling rate, chunk
            length and chunk stride.
        """
        from huggingface_hub import snapshot_download

        path = Path(path_or_repo)
        if not path.exists():
            path = Path(
                snapshot_download(
                    repo_id=path_or_repo,
                    allow_patterns=["*.json", "*.safetensors", "*.model"],
                )
            )

        with open(path / "config.json", "r") as f:
            config = SimpleNamespace(**json.load(f))

        model = cls(config)
        model.load_weights(str(path / "model.safetensors"))
        processor = functools.partial(
            preprocess_audio,
            sampling_rate=config.sampling_rate,
            chunk_length=model.chunk_length,
            chunk_stride=model.chunk_stride,
        )
        mx.eval(model)
        return model, processor


def preprocess_audio(
    raw_audio: Union[mx.array, List[mx.array]],
    sampling_rate: int = 24000,
    chunk_length: Optional[int] = None,
    chunk_stride: Optional[int] = None,
):
    r"""
    Prepare inputs for the EnCodec model.

    Every waveform is right-padded with zeros to a common length. When both
    ``chunk_length`` and ``chunk_stride`` are given, that length is rounded
    up to ``(num_chunks - 1) * chunk_stride + chunk_length``, where
    ``num_chunks = ceil(max_length / chunk_stride)``. This is the same rule
    as the Hugging Face EnCodec feature extractor, and it guarantees the
    input splits into whole overlapping chunks.

    Args:
        raw_audio (mx.array or List[mx.array]): A waveform or a list of
            waveforms, each of shape ``(length,)`` (mono) or
            ``(length, channels)``.
        sampling_rate (int): The sampling rate of the audio. Not used by the
            computation here; kept for API compatibility.
        chunk_length (int, optional): The model's chunk length in samples.
        chunk_stride (int, optional): The model's chunk stride in samples.

    Returns:
        A tuple ``(inputs, masks)``. ``inputs`` has shape ``(batch_size,
        padded_length, channels)``. ``masks`` is a boolean array of shape
        ``(batch_size, padded_length)`` that is ``True`` for real samples.
    """
    if not isinstance(raw_audio, list):
        raw_audio = [raw_audio]

    # Promote mono 1-D waveforms to (length, 1).
    raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]

    max_length = max(array.shape[0] for array in raw_audio)
    if chunk_length is not None and chunk_stride is not None:
        # Integer ceil division; at least one chunk so empty input still
        # yields a full (padding-only) chunk.
        num_chunks = max(1, -(-max_length // chunk_stride))
        max_length = (num_chunks - 1) * chunk_stride + chunk_length

    inputs = []
    masks = []
    for x in raw_audio:
        length = x.shape[0]
        mask = mx.ones((length,), dtype=mx.bool_)
        difference = max_length - length
        if difference > 0:
            mask = mx.pad(mask, (0, difference))
            x = mx.pad(x, ((0, difference), (0, 0)))
        inputs.append(x)
        masks.append(mask)
    return mx.stack(inputs), mx.stack(masks)


================================================
FILE: encodec/example.py
================================================
# Copyright © 2024 Apple Inc.

import mlx.core as mx
from utils import load_audio, save_audio

from encodec import EncodecModel

# Load the 48 kHz model together with its matching preprocessing function.
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")

# Read the input file, resampled/remixed to what the model expects.
waveform = load_audio("/path/to/audio", model.sampling_rate, model.channels)

# Pad the waveform and build its padding mask. A list of arrays can be passed
# instead to process a batch.
features, padding_mask = processor(waveform)


@mx.compile
def encode(x, pad_mask):
    """Compress to discrete codes at 3 kbps.

    A lower bandwidth gives more compression but worse reconstruction quality.
    """
    return model.encode(x, pad_mask, bandwidth=3)


@mx.compile
def decode(codes, scales, pad_mask):
    """Reconstruct audio from the codes and per-chunk scales."""
    return model.decode(codes, scales, pad_mask)


audio_codes, audio_scales = encode(features, padding_mask)
output = decode(audio_codes, audio_scales, padding_mask)

# Drop the batch axis and any samples that are only padding.
output = output[0, : len(waveform)]

# Write the result as a wave file.
save_audio("reconstructed.wav", output, model.sampling_rate)


================================================
FILE: encodec/requirements.txt
================================================
mlx>=0.18
numpy
huggingface_hub


================================================
FILE: encodec/test.py
================================================
# Copyright © 2024 Apple Inc.

import mlx.core as mx
import numpy as np
import torch
from transformers import AutoProcessor
from transformers import EncodecModel as PTEncodecModel

from encodec import EncodecModel, preprocess_audio


def compare_processors():
    """Check that ``preprocess_audio`` pads exactly like the HF processor."""
    np.random.seed(0)
    num_samples = 95500
    waveform = np.random.uniform(size=(2, num_samples)).astype(np.float32)

    hf_processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
    sr = hf_processor.sampling_rate

    expected = hf_processor(raw_audio=waveform, sampling_rate=sr, return_tensors="pt")
    features, padding_mask = preprocess_audio(
        mx.array(waveform).T,
        sr,
        hf_processor.chunk_length,
        hf_processor.chunk_stride,
    )

    # HF returns channels-first features; the MLX ones are channels-last.
    assert np.array_equal(expected["input_values"], features.moveaxis(2, 1))
    assert np.array_equal(expected["padding_mask"], padding_mask)


def compare_models():
    """Check MLX encoding/decoding against the Hugging Face PyTorch model."""
    pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz")
    mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")

    np.random.seed(0)
    num_samples = 190560
    waveform = np.random.uniform(size=(1, num_samples, 2)).astype(np.float32)
    mask = np.ones((1, num_samples), dtype=np.int32)

    # Encoding: codes must match exactly, scales approximately.
    pt_encoded = pt_model.encode(
        torch.tensor(waveform).moveaxis(2, 1), torch.tensor(mask)[None]
    )
    mx_codes, mx_scales = mx_model.encode(mx.array(waveform), mx.array(mask))
    assert np.array_equal(
        pt_encoded.audio_codes.numpy(), mx_codes
    ), "Encoding codes mismatch"

    for mx_scale, pt_scale in zip(mx_scales, pt_encoded.audio_scales):
        if mx_scale is None:
            continue
        assert np.allclose(pt_scale.numpy(), mx_scale, atol=1e-3, rtol=1e-4)

    # Decoding: compare both reconstructions as (length, channels).
    pt_audio = pt_model.decode(
        pt_encoded.audio_codes, pt_encoded.audio_scales, torch.tensor(mask)[None]
    )
    pt_audio = pt_audio[0].squeeze().T.detach().numpy()
    mx_audio = mx_model.decode(mx_codes, mx_scales, mx.array(mask)).squeeze()
    assert np.allclose(
        pt_audio, mx_audio, atol=1e-4, rtol=1e-4
    ), "Decoding audio mismatch"


if __name__ == "__main__":
    # Run the processor check first, then the full model comparison.
    for check in (compare_processors, compare_models):
        check()


================================================
FILE: encodec/utils.py
================================================
# Copyright © 2024 Apple Inc.

import mlx.core as mx
import numpy as np


def save_audio(file: str, audio: mx.array, sampling_rate: int):
    """
    Save audio to a 16-bit PCM wave (.wav) file.

    Args:
        file (str): Destination path.
        audio (mx.array): Float waveform, nominally in ``[-1, 1]``, of shape
            ``(length,)`` or ``(length, channels)``.
        sampling_rate (int): Sample rate written to the file header.
    """
    from scipy.io.wavfile import write

    # Clip before converting: decoded audio can overshoot [-1, 1], and
    # out-of-range floats are not guaranteed to saturate when cast to int16
    # (they may wrap around and produce loud clicks).
    audio = (mx.clip(audio, -1.0, 1.0) * 32767).astype(mx.int16)
    write(file, sampling_rate, np.array(audio))


def load_audio(file: str, sampling_rate: int, channels: int):
    """
    Decode an audio file with ffmpeg into a float32 mx.array.

    Args:
        file (str): Path of the audio file to decode.
        sampling_rate (int): Sample rate to resample the audio to.
        channels (int): Number of channels to mix the audio to.

    Returns:
        An mx.array of shape ``(num_samples, channels)`` holding the 16-bit
        PCM samples divided by 32767.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    from subprocess import CalledProcessError, run

    # ffmpeg (must be on PATH) resamples and remixes the input, then writes
    # raw signed 16-bit little-endian PCM to stdout. The arguments are passed
    # as a list, so no shell is involved.
    cmd = ["ffmpeg", "-nostdin", "-threads", "0", "-i", file]
    cmd += ["-f", "s16le", "-ac", str(channels), "-acodec", "pcm_s16le"]
    cmd += ["-ar", str(sampling_rate), "-"]
    try:
        pcm_bytes = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    samples = mx.array(np.frombuffer(pcm_bytes, np.int16))
    return samples.reshape(-1, channels).astype(mx.float32) / 32767.0


================================================
FILE: flux/README.md
================================================
FLUX
====

FLUX implementation in MLX. The implementation is ported directly from
[https://github.com/black-forest-labs/flux](https://github.com/black-forest-labs/flux)
and the model weights are downloaded directly from the Hugging Face Hub.

The goal of this example is to be clean, educational and to allow for
experimentation with finetuning FLUX models as well as adding extra
functionality such as in-/outpainting, guidance with custom losses etc.

![MLX image](static/generated-mlx.png)    
*Image generated using FLUX-dev in MLX and the prompt 'An image in the style of
tron emanating futuristic technology with the word "MLX" in the center with
capital red letters.'*

Installation
------------

The dependencies are minimal, namely:

- `huggingface-hub` to download the checkpoints.
- `regex` for the tokenization
- `tqdm`, `PIL`, and `numpy` for the scripts
- `sentencepiece` for the T5 tokenizer
- `datasets` for using an HF dataset directly

You can install all of the above with the `requirements.txt` as follows:

    pip install -r requirements.txt


Usage
---------

Use the following command to generate an image. The `--output` option sets where the image is saved (default: `out.png`).

```shell
python txt2image.py --model schnell \
    --n-images 1 \
    --image-size 256x512 \
    --verbose \
    'A photo of an astronaut riding a horse on Mars.'
```

To see all available parameters, run the script with `--help`:

```shell
python txt2image.py --help
```

Inference
---------

Inference in this example is similar to the stable diffusion example. The
class to get you started is `FluxPipeline` from the `flux` module.

```python
import mlx.core as mx
from flux import FluxPipeline

# This will download all the weights from HF hub
flux = FluxPipeline("flux-schnell")

# Make a generator that returns the latent variables from the reverse diffusion
# process
latent_generator = flux.generate_latents(
    "A photo of an astronaut riding a horse on Mars",
    num_steps=4,
    latent_size=(32, 64),  # 256x512 image
)

# The first return value of the generator contains the conditioning and the
# random noise at the beginning of the diffusion process.
conditioning = next(latent_generator)
(
    x_T,                # The initial noise
    x_positions,        # The integer positions used for image positional encoding
    t5_conditioning,    # The T5 features from the text prompt
    t5_positions,       # Integer positions for text (normally all 0s)
    clip_conditioning,  # The clip text features from the text prompt
) = conditioning

# Returning the conditioning as the first output from the generator allows us
# to unload T5 and clip before running the diffusion transformer.
mx.eval(conditioning)

# Evaluate each diffusion step
for x_t in latent_generator:
    mx.eval(x_t)

# Note that we need to pass the latent size because it is collapsed and
# patchified in x_t and we need to unwrap it.
img = flux.decode(x_t, latent_size=(32, 64))
```

The above is essentially the implementation of the `txt2image.py` script,
except for some additional logic to quantize the model and/or load trained
adapters. One can use the script as follows:

```shell
python txt2image.py \
    --n-images 4 \
    --n-rows 2 \
    --image-size 256x512 \
    'A photo of an astronaut riding a horse on Mars.'
```

### Experimental Options

FLUX pads the prompt to a specific size of 512 tokens for the dev model and
256 for the schnell model. Not applying padding results in faster generation
but it is not clear how it may affect the generated images. To enable that
option in this example pass `--no-t5-padding` to the `txt2image.py` script or
instantiate the pipeline with `FluxPipeline("flux-schnell", t5_padding=False)`.

Finetuning
----------

The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell,
though results may vary) on a provided image dataset. The dataset folder must
have a `train.jsonl` file with the following format:

```jsonl
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
...
```

The training script by default trains for 600 iterations with a batch size of
1, gradient accumulation of 4 and LoRA rank of 8. Run `python dreambooth.py
--help` for the list of hyperparameters you can tune.

> [!Note]
> FLUX finetuning requires approximately 50GB of RAM. QLoRA is coming soon and
> should reduce this number significantly.

### Training Example

This is a step-by-step finetuning example. We will be using the data from
[https://github.com/google/dreambooth](https://github.com/google/dreambooth).
In particular, we will use `dog6` which is a popular example for showcasing
dreambooth [^1].

The training images are the following 5 images [^2]:

![dog6](static/dog6.png)

We start by making the following `train.jsonl` file and placing it in the same
folder as the images.

```jsonl
{"image": "00.jpg", "prompt": "A photo of sks dog"}
{"image": "01.jpg", "prompt": "A photo of sks dog"}
{"image": "02.jpg", "prompt": "A photo of sks dog"}
{"image": "03.jpg", "prompt": "A photo of sks dog"}
{"image": "04.jpg", "prompt": "A photo of sks dog"}
```

Subsequently we finetune FLUX using the following command:

```shell
python dreambooth.py \
    --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
    --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
    --lora-rank 4 --grad-accumulate 8 \
    path/to/dreambooth/dataset/dog6
```

Or you can directly use the pre-processed Hugging Face dataset
[mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6)
for fine-tuning.

```shell
python dreambooth.py \
    --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
    --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
    --lora-rank 4 --grad-accumulate 8 \
    mlx-community/dreambooth-dog6
```

The training requires approximately 50GB of RAM and on an M2 Ultra it takes a
bit more than 1 hour.

### Using the Adapter

The adapters are saved in `mlx_output` and can be used directly by the
`txt2image.py` script. For instance,

```shell
python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
    --adapter mlx_output/final_adapters.safetensors \
    --fuse-adapter \
    --no-t5-padding \
    'A photo of an sks dog lying on the sand at a beach in Greece'
```

generates an image that looks like the following,

![dog image](static/dog-r4-g8-1200.png)

and of course we can pass `--image-size 512x1024` to get larger images with
different aspect ratios,

![wide dog image](static/dog-r4-g8-1200-512x1024.png)

The arguments that are relevant to the adapters are of course `--adapter` and
`--fuse-adapter`. The first defines the path to an adapter to apply to the
model and the second fuses the adapter back into the model to get a bit more
speed during generation.

[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
[^2]: The images are from unsplash by https://unsplash.com/@alvannee .


Distributed Computation
------------------------

The FLUX example supports distributed computation during both generation and
training. See the [distributed communication
documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
for information on how to set-up MLX for distributed communication. The rest of
this section assumes you can launch distributed MLX programs using `mlx.launch
--hostfile hostfile.json`.

### Distributed Finetuning

Distributed finetuning scales very well with FLUX and all one has to do is
adjust the gradient accumulation and training iterations so that the batch
size remains the same. For instance, to replicate the following training

```shell
python dreambooth.py \
    --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
    --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
    --lora-rank 4 --grad-accumulate 8 \
    mlx-community/dreambooth-dog6
```

On 4 machines we simply run

```shell
mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
    --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
    --progress-every 150 --iterations 300 --learning-rate 0.0001 \
    --lora-rank 4 --grad-accumulate 2 \
    mlx-community/dreambooth-dog6
```

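Note that the iterations changed from 1200 to 300 and the gradient accumulation from 8 to 2. This keeps the effective batch per optimizer update at 8 (4 machines × 2 accumulation steps instead of 1 × 8) and the number of optimizer updates at 150 (300 / 2 instead of 1200 / 8).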

### Distributed Inference

Distributed inference can be divided into two different approaches. The first
approach is the data-parallel approach, where each node generates its own
images and shares them at the end. The second approach is the model-parallel
approach, where the model is sharded across the nodes and they collaboratively
generate the images.

The `txt2image.py` script will attempt to choose the best approach depending on
how many images are being generated across the nodes. The model-parallel
approach can be forced by passing the argument `--force-shard`.

For better performance in the model-parallel approach we suggest that you use a
[thunderbolt
ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).

All you have to do once again is use `mlx.launch` as follows

```shell
mlx.launch --verbose --hostfile hostfile.json -- \
    python txt2image.py --model schnell \
    --n-images 8 \
    --image-size 512x512 \
    --verbose \
    'A photo of an astronaut riding a horse on Mars'
```

For model-parallel generation you may also want to pass `--env
MLX_METAL_FAST_SYNCH=1` to `mlx.launch`, which is an experimental setting that
reduces the CPU/GPU synchronization overhead.


================================================
FILE: flux/dreambooth.py
================================================
# Copyright © 2024 Apple Inc.

import argparse
import time
from functools import partial
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image

from flux import FluxPipeline, Trainer, load_dataset, save_config


def generate_progress_images(iteration, flux, args):
    """
    Render the 4 progress images for ``args.progress_prompt`` as a grid with
    2 rows and save it to ``<output_dir>/<iteration:07d>_progress.png``.
    """
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f"{iteration:07d}_progress.png"
    print(f"Generating {str(out_file)}", flush=True)

    n_rows, n_images = 2, 4
    images = flux.generate_images(
        args.progress_prompt,
        n_images,
        args.progress_steps,
    )

    # Give every image a 4px border, tile them row-major into the grid, then
    # add a 4px border around the whole grid.
    images = mx.pad(images, [(0, 0), (4, 4), (4, 4), (0, 0)])
    B, H, W, C = images.shape
    n_cols = B // n_rows
    grid = images.reshape(n_rows, n_cols, H, W, C).transpose(0, 2, 1, 3, 4)
    grid = grid.reshape(n_rows * H, n_cols * W, C)
    grid = mx.pad(grid, [(4, 4), (4, 4), (0, 0)])

    # Scale to 8-bit pixels (assumes the images are floats in [0, 1]) and
    # write them to disk.
    pixels = (grid * 255).astype(mx.uint8)
    Image.fromarray(np.array(pixels)).save(out_file)


def save_adapters(adapter_name, flux, args):
    """
    Write the trainable parameters of ``flux.flow`` (the LoRA weights in this
    script) to ``<output_dir>/<adapter_name>`` as safetensors. The LoRA rank
    and block count are stored as string metadata.
    """
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / adapter_name
    print(f"Saving {str(out_file)}")

    weights = dict(tree_flatten(flux.flow.trainable_parameters()))
    metadata = {
        "lora_rank": str(args.lora_rank),
        "lora_blocks": str(args.lora_blocks),
    }
    mx.save_safetensors(str(out_file), weights, metadata=metadata)


def setup_arg_parser():
    """Set up and return the argument parser for FLUX LoRA finetuning.

    The only positional argument is ``dataset``: a local dataset folder or a
    Hugging Face dataset name. ``--progress-prompt`` is required.
    """
    parser = argparse.ArgumentParser(
        description="Finetune Flux to generate images with a specific subject"
    )

    parser.add_argument(
        "--model",
        default="dev",
        choices=[
            "dev",
            "schnell",
        ],
        help="Which flux model to train",
    )
    parser.add_argument(
        "--guidance", type=float, default=4.0, help="The guidance factor to use."
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=600,
        help="How many iterations to train for",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="The batch size to use when training the FLUX model",
    )
    parser.add_argument(
        "--resolution",
        # Parsed from "WIDTHxHEIGHT"-style strings such as "512x512".
        type=lambda x: tuple(map(int, x.split("x"))),
        default=(512, 512),
        help="The resolution of the training images",
    )
    parser.add_argument(
        "--num-augmentations",
        type=int,
        default=5,
        help="Augment the images by random cropping and panning",
    )
    parser.add_argument(
        "--progress-prompt",
        required=True,
        help="Use this prompt when generating images for evaluation",
    )
    parser.add_argument(
        "--progress-steps",
        type=int,
        default=50,
        help="Use this many steps when generating images for evaluation",
    )
    parser.add_argument(
        "--progress-every",
        type=int,
        default=50,
        help="Generate images every PROGRESS_EVERY steps",
    )
    parser.add_argument(
        "--checkpoint-every",
        type=int,
        default=50,
        help="Save the model every CHECKPOINT_EVERY steps",
    )
    parser.add_argument(
        "--lora-blocks",
        type=int,
        default=-1,
        help="Train the last LORA_BLOCKS transformer blocks",
    )
    parser.add_argument(
        "--lora-rank", type=int, default=8, help="LoRA rank for finetuning"
    )
    parser.add_argument(
        "--warmup-steps", type=int, default=100, help="Learning rate warmup"
    )
    parser.add_argument(
        "--learning-rate", type=float, default=1e-4, help="Learning rate for training"
    )
    parser.add_argument(
        "--grad-accumulate",
        type=int,
        default=4,
        help="Accumulate gradients for that many iterations before applying them",
    )
    parser.add_argument(
        "--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
    )

    parser.add_argument("dataset")
    return parser


if __name__ == "__main__":
    # Entry point: LoRA-finetune a FLUX model on the given dataset, with
    # optional distributed data parallelism and gradient accumulation.
    parser = setup_arg_parser()
    args = parser.parse_args()

    # Record the run's arguments next to the saved adapters.
    output_path = Path(args.output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    save_config(vars(args), output_path / "adapter_config.json")

    # Load the model and set it up for LoRA training. We use the same random
    # state when creating the LoRA layers so all workers will have the same
    # initial weights.
    mx.random.seed(0x0F0F0F0F)
    flux = FluxPipeline("flux-" + args.model)
    # Freeze the base weights; presumably only the LoRA parameters added by
    # the next call remain trainable — TODO confirm in linear_to_lora_layers.
    flux.flow.freeze()
    flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)

    # Reset the seed to a different seed per worker if we are in distributed
    # mode so that each worker is working on different data, diffusion step and
    # random noise.
    mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())

    # Report how many parameters we are training
    # NOTE(review): this divides by 1024**2, so "M" means 2**20, not 1e6.
    trainable_params = tree_reduce(
        lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
    )
    print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)

    # Set up the optimizer and training steps. The steps are a bit verbose to
    # support gradient accumulation together with compilation.
    # The schedule advances once per optimizer update (every
    # ``grad_accumulate`` iterations). It warms up linearly for
    # ``warmup_steps`` updates and then follows a cosine decay.
    warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
    cosine = optim.cosine_decay(
        args.learning_rate, args.iterations // args.grad_accumulate
    )
    lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
    optimizer = optim.Adam(learning_rate=lr_schedule)
    # State that the compiled functions below read and update implicitly:
    # the model parameters, the optimizer state and the global random state
    # (so any random draws inside the compiled loss advance between calls).
    state = [flux.flow.state, optimizer.state, mx.random.state]

    # No accumulation: compute loss and gradients, average the gradients
    # across distributed workers, then apply an optimizer update.
    @partial(mx.compile, inputs=state, outputs=state)
    def single_step(x, t5_feat, clip_feat, guidance):
        loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
            x, t5_feat, clip_feat, guidance
        )
        grads = average_gradients(grads)
        optimizer.update(flux.flow, grads)

        return loss

    # First iteration of an accumulation window: return loss and raw grads.
    @partial(mx.compile, inputs=state, outputs=state)
    def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
        return nn.value_and_grad(flux.flow, flux.training_loss)(
            x, t5_feat, clip_feat, guidance
        )

    # Intermediate iteration: add this iteration's grads to the running sum.
    @partial(mx.compile, inputs=state, outputs=state)
    def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
        loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
            x, t5_feat, clip_feat, guidance
        )
        grads = tree_map(lambda a, b: a + b, prev_grads, grads)
        return loss, grads

    # Last iteration of a window: add this iteration's grads and divide the
    # sum by ``grad_accumulate`` to get the mean. Then average across
    # workers and apply the optimizer update.
    @partial(mx.compile, inputs=state, outputs=state)
    def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
        loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
            x, t5_feat, clip_feat, guidance
        )
        grads = tree_map(
            lambda a, b: (a + b) / args.grad_accumulate,
            prev_grads,
            grads,
        )
        grads = average_gradients(grads)
        optimizer.update(flux.flow, grads)

        return loss

    # We simply route to the appropriate step based on whether we have
    # gradients from a previous step and whether we should be performing an
    # update or simply computing and accumulating gradients in this step.
    # It always returns ``(loss, grads)``; ``grads`` is None right after an
    # update, which starts a new accumulation window.
    def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
        if prev_grads is None:
            if perform_step:
                return single_step(x, t5_feat, clip_feat, guidance), None
            else:
                return compute_loss_and_grads(x, t5_feat, clip_feat, guidance)
        else:
            if perform_step:
                return (
                    grad_accumulate_and_step(
                        x, t5_feat, clip_feat, guidance, prev_grads
                    ),
                    None,
                )
            else:
                return compute_loss_and_accumulate_grads(
                    x, t5_feat, clip_feat, guidance, prev_grads
                )

    dataset = load_dataset(args.dataset)
    trainer = Trainer(flux, dataset, args)
    # Let the trainer pre-encode the dataset before iteration starts
    # (see Trainer.encode_dataset for what is computed).
    trainer.encode_dataset()

    # The same guidance value for every sample in a batch.
    guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)

    # An initial generation to compare
    generate_progress_images(0, flux, args)

    grads = None
    losses = []
    tic = time.time()
    # Each batch unpacks into the (x, t5_feat, clip_feat) arguments of ``step``.
    # NOTE(review): if ``iterations`` is not a multiple of ``grad_accumulate``,
    # the gradients of the final partial window are never applied.
    for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
        # Apply an optimizer update on every ``grad_accumulate``-th iteration.
        loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
        mx.eval(loss, grads, state)
        losses.append(loss.item())

        # Every 10 iterations, report the mean loss and the iterations per
        # second since ``tic`` (both are reset at the end of this iteration).
        if (i + 1) % 10 == 0:
            toc = time.time()
            peak_mem = mx.metal.get_peak_memory() / 1024**3
            print(
                f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
                f"It/s: {10 / (toc - tic):.3f} "
                f"Peak mem: {peak_mem:.3f} GB",
                flush=True,
            )

        if (i + 1) % args.progress_every == 0:
            generate_progress_images(i + 1, flux, args)

        if (i + 1) % args.checkpoint_every == 0:
            save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)

        # Reset the reporting window after any image generation / saving so
        # that their cost is excluded from the next It/s measurement.
        if (i + 1) % 10 == 0:
            losses = []
            tic = time.time()

    save_adapters("final_adapters.safetensors", flux, args)
    print("Training successful.")


================================================
FILE: flux/flux/__init__.py
================================================
# Copyright © 2024 Apple Inc.

from .datasets import Dataset, load_dataset
from .flux import FluxPipeline
from .lora import LoRALinear
from .sampler import FluxSampler
from .trainer import Trainer
from .utils import (
    load_ae,
    load_clip,
    load_clip_tokenizer,
    load_flow_model,
    load_t5,
    load_t5_tokenizer,
    save_config,
)


================================================
FILE: flux/flux/autoencoder.py
================================================
# Copyright © 2024 Apple Inc.

from dataclasses import dataclass
from typing import List

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.upsample import upsample_nearest


@dataclass
class AutoEncoderParams:
    """Hyper-parameters used to build the Flux ``AutoEncoder``."""

    resolution: int  # image resolution passed to both Encoder and Decoder
    in_channels: int  # channels of the encoder's input image
    ch: int  # base channel count; each level uses ``ch * ch_mult[level]``
    out_ch: int  # channels produced by the decoder's final convolution
    ch_mult: List[int]  # per-level channel multipliers (one entry per level)
    num_res_blocks: int  # ResnetBlocks per encoder level (decoder uses one more)
    z_channels: int  # latent channels; the encoder emits 2x (mean and logvar)
    scale_factor: float  # multiplies the shifted latent in ``AutoEncoder.encode``
    shift_factor: float  # subtracted from the latent before scaling in ``encode``


class AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    The (B, H, W, C) feature map is flattened into one sequence of ``H * W``
    tokens, group-normalized, projected to queries/keys/values, attended over
    and projected back before being added to the input.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(
            num_groups=32,
            dims=in_channels,
            eps=1e-6,
            affine=True,
            pytorch_compatible=True,
        )
        # Linear layers are used for the per-position projections; 1x1 conv
        # checkpoints are squeezed to 2-D by ``AutoEncoder.sanitize``.
        self.q, self.k, self.v, self.proj_out = (
            nn.Linear(in_channels, in_channels) for _ in range(4)
        )

    def __call__(self, x: mx.array) -> mx.array:
        batch, height, width, channels = x.shape

        # One "head" attending over all H*W positions.
        tokens = self.norm(x.reshape(batch, 1, -1, channels))
        attended = mx.fast.scaled_dot_product_attention(
            self.q(tokens),
            self.k(tokens),
            self.v(tokens),
            scale=channels ** (-0.5),
        )
        out = self.proj_out(attended).reshape(batch, height, width, channels)
        return x + out


class ResnetBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> SiLU -> 3x3 conv) twice.

    When the input and output channel counts differ, the skip path goes
    through a linear projection (``nin_shortcut``) so it can be added to the
    residual branch. Passing ``out_channels=None`` keeps ``in_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        def group_norm(dims):
            return nn.GroupNorm(
                num_groups=32, dims=dims, eps=1e-6, affine=True, pytorch_compatible=True
            )

        def conv3x3(cin, cout):
            return nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1)

        self.norm1 = group_norm(in_channels)
        self.conv1 = conv3x3(in_channels, out_channels)
        self.norm2 = group_norm(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        if in_channels != out_channels:
            self.nin_shortcut = nn.Linear(in_channels, out_channels)

    def __call__(self, x):
        residual = self.conv1(nn.silu(self.norm1(x)))
        residual = self.conv2(nn.silu(self.norm2(residual)))
        shortcut = (
            self.nin_shortcut(x) if self.in_channels != self.out_channels else x
        )
        return shortcut + residual


class Downsample(nn.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution.

    Instead of using the convolution's own padding, the input is padded by
    one element at the end of axes 1 and 2 (H and W for NHWC input).
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=2, padding=0
        )

    def __call__(self, x: mx.array):
        untouched = (0, 0)
        pad_after = (0, 1)
        padded = mx.pad(x, [untouched, pad_after, pad_after, untouched])
        return self.conv(padded)


class Upsample(nn.Module):
    """Double the spatial resolution (nearest neighbour), then apply a 3x3 conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )

    def __call__(self, x: mx.array):
        return self.conv(upsample_nearest(x, (2, 2)))


class Encoder(nn.Module):
    """Convolutional encoder of the autoencoder.

    Produces a feature map with ``2 * z_channels`` channels, which
    ``AutoEncoder.encode`` passes to ``DiagonalGaussian`` (split into mean and
    log-variance). A ``Downsample`` follows every level except the last.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        # NOTE(review): curr_res is updated below but never read.
        curr_res = resolution
        # Level i's input multiplier is level i-1's output multiplier; the
        # leading 1 matches conv_in's ``ch`` output channels.
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = []
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = []
            attn = []  # TODO: Remove the attn, nobody appends anything to it
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = {}
            down["block"] = block
            down["attn"] = attn
            # Every level but the last halves the spatial size.
            if i_level != self.num_resolutions - 1:
                down["downsample"] = Downsample(block_in)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid["attn_1"] = AttnBlock(block_in)
        self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
        )
        # 2 * z_channels: mean and log-variance, concatenated on the last axis.
        self.conv_out = nn.Conv2d(
            block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
        )

    def __call__(self, x: mx.array) -> mx.array:
        # Only hs[-1] is ever read; the list retains every intermediate output.
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level]["block"][i_block](hs[-1])

                # TODO: Remove the attn
                if len(self.down[i_level]["attn"]) > 0:
                    h = self.down[i_level]["attn"][i_block](h)

                hs.append(h)

            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level]["downsample"](hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)

        # end
        h = self.norm_out(h)
        h = nn.silu(h)
        h = self.conv_out(h)

        return h


class Decoder(nn.Module):
    """Convolutional decoder of the autoencoder.

    Maps a ``z_channels`` latent back to an ``out_ch`` image, applying an
    ``Upsample`` after every level except level 0. Each level has
    ``num_res_blocks + 1`` ResnetBlocks.
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # NOTE(review): ffactor is not read within this class.
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        # NOTE(review): z_shape puts channels before the spatial dims, whereas
        # the MLX convolutions in __call__ consume channels-last input; it is
        # not read within this class.
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = {}
        self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid["attn_1"] = AttnBlock(block_in)
        self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling
        # Levels are built from the deepest (lowest resolution) to level 0;
        # block_in carries the previous level's output channels forward.
        self.up = []
        for i_level in reversed(range(self.num_resolutions)):
            block = []
            attn = []  # TODO: Remove the attn, nobody appends anything to it

            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            up = {}
            up["block"] = block
            up["attn"] = attn
            if i_level != 0:
                up["upsample"] = Upsample(block_in)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = nn.GroupNorm(
            num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
        )
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def __call__(self, z: mx.array) -> mx.array:
        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid["block_1"](h)
        h = self.mid["attn_1"](h)
        h = self.mid["block_2"](h)

        # upsampling
        # self.up is indexed by level, so iterate from the deepest level down.
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level]["block"][i_block](h)

                # TODO: Remove the attn
                if len(self.up[i_level]["attn"]) > 0:
                    h = self.up[i_level]["attn"][i_block](h)

            if i_level != 0:
                h = self.up[i_level]["upsample"](h)

        # end
        h = self.norm_out(h)
        h = nn.silu(h)
        h = self.conv_out(h)

        return h


class DiagonalGaussian(nn.Module):
    """Sample from (or take the mode of) a diagonal Gaussian posterior.

    The last axis of the input holds the mean and log-variance concatenated.
    In training mode a reparameterized sample ``mean + exp(logvar / 2) * eps``
    is returned; otherwise the mean is returned.
    """

    def __call__(self, z: mx.array):
        mean, logvar = mx.split(z, 2, axis=-1)
        if not self.training:
            return mean
        std = mx.exp(0.5 * logvar)
        # The noise must match ``mean``/``std``, which have half the channels
        # of ``z``; drawing it with ``z.shape`` does not broadcast.
        eps = mx.random.normal(shape=mean.shape, dtype=mean.dtype)
        return mean + std * eps


class AutoEncoder(nn.Module):
    """Flux VAE: encoder, diagonal-Gaussian bottleneck and decoder.

    ``encode`` returns latents normalized with ``shift_factor`` and
    ``scale_factor``; ``decode`` undoes that normalization before decoding.
    """

    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        shared = dict(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.encoder = Encoder(**shared)
        self.decoder = Decoder(out_ch=params.out_ch, **shared)
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def sanitize(self, weights):
        """Convert 4-D convolution kernels to the layout this model expects.

        Axis 1 of every 4-D weight is moved to the end (the channels-last
        kernel layout of MLX convolutions). Kernels whose spatial size is 1x1
        are squeezed to 2-D so they load into the ``nn.Linear`` layers used in
        place of 1x1 convolutions. Other weights are passed through unchanged.
        """

        def convert(w):
            if w.ndim != 4:
                return w
            w = w.transpose(0, 2, 3, 1)
            # Materialize the transposed array in row-contiguous order.
            w = w.reshape(-1).reshape(w.shape)
            return w.squeeze((1, 2)) if w.shape[1:3] == (1, 1) else w

        return {k: convert(w) for k, w in weights.items()}

    def encode(self, x: mx.array):
        moments = self.encoder(x)
        z = self.reg(moments)
        return self.scale_factor * (z - self.shift_factor)

    def decode(self, z: mx.array):
        unnormalized = z / self.scale_factor + self.shift_factor
        return self.decoder(unnormalized)

    def __call__(self, x: mx.array):
        latent = self.encode(x)
        return self.decode(latent)


================================================
FILE: flux/flux/clip.py
================================================
# Copyright © 2024 Apple Inc.

from dataclasses import dataclass
from typing import List, Optional

import mlx.core as mx
import mlx.nn as nn

# Maps a CLIP config's ``hidden_act`` string to the MLX activation used in the
# encoder MLP; "quick_gelu" is served by MLX's fast GELU approximation.
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}


@dataclass
class CLIPTextModelConfig:
    """Hyper-parameters of the CLIP text encoder."""

    num_layers: int = 23
    model_dims: int = 1024
    num_heads: int = 16
    max_length: int = 77
    vocab_size: int = 49408
    hidden_act: str = "quick_gelu"

    @classmethod
    def from_dict(cls, config):
        """Build a config from a Hugging Face style config dictionary.

        Every mapped key is required (a missing one raises ``KeyError``);
        unrelated keys are ignored.
        """
        key_for_field = {
            "num_layers": "num_hidden_layers",
            "model_dims": "hidden_size",
            "num_heads": "num_attention_heads",
            "max_length": "max_position_embeddings",
            "vocab_size": "vocab_size",
            "hidden_act": "hidden_act",
        }
        return cls(**{name: config[key] for name, key in key_for_field.items()})


@dataclass
class CLIPOutput:
    """Outputs of ``CLIPTextModel``; every field defaults to ``None``."""

    # The last_hidden_state indexed at the EOS token and possibly projected if
    # the model has a projection layer
    pooled_output: Optional[mx.array] = None

    # The full sequence output of the transformer after the final layernorm
    last_hidden_state: Optional[mx.array] = None

    # A list of hidden states corresponding to the outputs of the transformer layers
    hidden_states: Optional[List[mx.array]] = None


class CLIPEncoderLayer(nn.Module):
    """The transformer encoder layer from CLIP.

    Pre-norm block: LayerNorm -> self-attention and LayerNorm -> MLP with a
    4x hidden width, each wrapped in a residual connection.
    """

    def __init__(self, model_dims: int, num_heads: int, activation: str):
        super().__init__()

        self.layer_norm1 = nn.LayerNorm(model_dims)
        self.layer_norm2 = nn.LayerNorm(model_dims)

        self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)

        hidden_dims = 4 * model_dims
        self.linear1 = nn.Linear(model_dims, hidden_dims)
        self.linear2 = nn.Linear(hidden_dims, model_dims)

        self.act = _ACTIVATIONS[activation]

    def __call__(self, x, attn_mask=None):
        normed = self.layer_norm1(x)
        x = self.attention(normed, normed, normed, attn_mask) + x

        mlp_out = self.linear2(self.act(self.linear1(self.layer_norm2(x))))
        return mlp_out + x


class CLIPTextModel(nn.Module):
    """Implements the text encoder transformer from CLIP."""

    def __init__(self, config: CLIPTextModelConfig):
        super().__init__()

        self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
        self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
        self.layers = [
            CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
            for _ in range(config.num_layers)
        ]
        self.final_layer_norm = nn.LayerNorm(config.model_dims)

    def _get_mask(self, N, dtype):
        """Return an additive causal mask of shape (N, N) in ``dtype``.

        Entry (i, j) is 0 when j <= i and a large negative value otherwise, so
        each position only attends to itself and earlier positions. float16
        uses -6e4 to stay within its finite range.
        """
        indices = mx.arange(N)
        mask = indices[:, None] < indices[None]
        mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
        return mask

    def sanitize(self, weights):
        """Map checkpoint weight names (``text_model.`` / ``self_attn.`` /
        ``mlp.fc*`` style) onto this module's parameter names.

        ``position_ids`` entries are dropped: they are an index buffer stored
        by some checkpoints, not a parameter of this model, and a strict
        weight load would reject them as unexpected keys.
        """
        new_weights = {}
        for key, w in weights.items():
            # Remove prefixes
            if key.startswith("text_model."):
                key = key[11:]
            if key.startswith("embeddings."):
                key = key[11:]
            if key.startswith("encoder."):
                key = key[8:]

            # Skip non-parameter buffers
            if key.endswith("position_ids"):
                continue

            # Map attention layers
            if "self_attn." in key:
                key = key.replace("self_attn.", "attention.")
            if "q_proj." in key:
                key = key.replace("q_proj.", "query_proj.")
            if "k_proj." in key:
                key = key.replace("k_proj.", "key_proj.")
            if "v_proj." in key:
                key = key.replace("v_proj.", "value_proj.")

            # Map ffn layers
            if "mlp.fc1" in key:
                key = key.replace("mlp.fc1", "linear1")
            if "mlp.fc2" in key:
                key = key.replace("mlp.fc2", "linear2")

            new_weights[key] = w

        return new_weights

    def __call__(self, x):
        """Encode a (B, N) batch of token ids into a ``CLIPOutput``."""
        # Extract some shapes
        B, N = x.shape
        # Assumes the EOS token has the largest id in each sequence, so its
        # position is the argmax of the ids.
        eos_tokens = x.argmax(-1)

        # Compute the embeddings
        x = self.token_embedding(x)
        x = x + self.position_embedding.weight[:N]

        # Compute the features from the transformer
        mask = self._get_mask(N, x.dtype)
        hidden_states = []
        for layer in self.layers:
            x = layer(x, mask)
            hidden_states.append(x)

        # Apply the final layernorm and return
        x = self.final_layer_norm(x)
        last_hidden_state = x

        # Select the EOS token
        pooled_output = x[mx.arange(len(x)), eos_tokens]

        return CLIPOutput(
            pooled_output=pooled_output,
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
        )


================================================
FILE: flux/flux/datasets.py
================================================
import json
from pathlib import Path

from PIL import Image


class Dataset:
    """Abstract interface for image/prompt datasets.

    Subclasses implement ``len()`` and integer indexing; indexing returns an
    ``(image, prompt)`` pair.
    """

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, index: int):
        raise NotImplementedError


class LocalDataset(Dataset):
    """Dataset backed by a local JSON Lines file of image/prompt records.

    Each non-blank line of ``data_file`` must be a JSON object with an
    ``"image"`` path (relative to ``dataset``) and a prompt stored under
    ``prompt_key``.
    """

    prompt_key = "prompt"

    def __init__(self, dataset: str, data_file):
        self.dataset_base = Path(dataset)
        # JSON text is UTF-8; do not depend on the platform default encoding.
        # Blank lines (e.g. extra trailing newlines) are skipped instead of
        # making ``json.loads`` fail.
        with open(data_file, "r", encoding="utf-8") as fid:
            self._data = [json.loads(line) for line in fid if line.strip()]

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index: int):
        """Return ``(PIL image, prompt)`` for record ``index``."""
        item = self._data[index]
        image = Image.open(self.dataset_base / item["image"])
        return image, item[self.prompt_key]


class LegacyDataset(LocalDataset):
    """Dataset read from the deprecated ``index.json`` format.

    ``index.json`` is a single JSON object whose ``"data"`` list holds the
    records; prompts are stored under ``"text"``.
    """

    prompt_key = "text"

    def __init__(self, dataset: str):
        # LocalDataset.__init__ is intentionally not called: the records come
        # from one JSON document rather than a JSON Lines file.
        self.dataset_base = Path(dataset)
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(self.dataset_base / "index.json", encoding="utf-8") as f:
            self._data = json.load(f)["data"]


class HuggingFaceDataset(Dataset):
    """Dataset loaded with the Hugging Face ``datasets`` library.

    Uses the ``"train"`` split; each row must provide ``"image"`` and
    ``"prompt"`` columns.
    """

    def __init__(self, dataset: str):
        # Imported lazily so the dependency is only needed for Hub datasets.
        from datasets import load_dataset as hf_load_dataset

        splits = hf_load_dataset(dataset)
        self._df = splits["train"]

    def __len__(self):
        return len(self._df)

    def __getitem__(self, index: int):
        row = self._df[index]
        return row["image"], row["prompt"]


def load_dataset(dataset: str):
    """Choose and construct the dataset implementation for ``dataset``.

    Resolution order: ``<dataset>/train.jsonl`` (LocalDataset), then the
    deprecated ``<dataset>/index.json`` (LegacyDataset); otherwise ``dataset``
    is treated as a Hugging Face dataset name.
    """
    base = Path(dataset)

    data_file = base / "train.jsonl"
    if data_file.exists():
        print(f"Load the local dataset {data_file} .", flush=True)
        return LocalDataset(dataset, data_file)

    legacy_file = base / "index.json"
    if legacy_file.exists():
        print(f"Load the local dataset {legacy_file} .")
        print()
        print("     WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
        print("              See the README for details.")
        print(flush=True)
        return LegacyDataset(dataset)

    print(f"Load the Hugging Face dataset {dataset} .", flush=True)
    return HuggingFaceDataset(dataset)


================================================
FILE: flux/flux/flux.py
================================================
# Copyright © 2024 Apple Inc.

from typing import Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from tqdm import tqdm

from .lora import LoRALinear
from .sampler import FluxSampler
from .utils import (
    load_ae,
    load_clip,
    load_clip_tokenizer,
    load_flow_model,
    load_t5,
    load_t5_tokenizer,
)


class FluxPipeline:
    """Text-to-image pipeline bundling every component of a Flux model.

    Loads the autoencoder, flow (denoising) transformer, CLIP and T5 text
    encoders with their tokenizers, and the sampler for model ``name``.
    Latents and scalar conditioning are created in ``self.dtype`` (bfloat16).

    Args:
        name: Model name passed to every loader.
        t5_padding: Forwarded as ``pad`` to the T5 tokenizer's ``encode``.
    """

    def __init__(self, name: str, t5_padding: bool = True):
        self.dtype = mx.bfloat16
        self.name = name
        self.t5_padding = t5_padding

        self.ae = load_ae(name)
        self.flow = load_flow_model(name)
        self.clip = load_clip(name)
        self.clip_tokenizer = load_clip_tokenizer(name)
        self.t5 = load_t5(name)
        self.t5_tokenizer = load_t5_tokenizer(name)
        self.sampler = FluxSampler(name)

    def ensure_models_are_loaded(self):
        """Evaluate the parameters of the AE, flow, CLIP and T5 models."""
        mx.eval(
            self.ae.parameters(),
            self.flow.parameters(),
            self.clip.parameters(),
            self.t5.parameters(),
        )

    def reload_text_encoders(self):
        """Replace the T5 and CLIP encoders with freshly loaded instances."""
        # NOTE(review): presumably used to release the evaluated encoder
        # weights once text features are computed; confirm in load_t5/load_clip.
        self.t5 = load_t5(self.name)
        self.clip = load_clip(self.name)

    def tokenize(self, text):
        """Return ``(t5_tokens, clip_tokens)`` for ``text``."""
        t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
        clip_tokens = self.clip_tokenizer.encode(text)
        return t5_tokens, clip_tokens

    def _prepare_latent_images(self, x):
        """Pack a (B, H, W, C) latent into 2x2 patches plus position ids.

        Returns ``x`` of shape (B, H*W/4, 4*C) and int32 ``x_ids`` of shape
        (B, H*W/4, 3).
        """
        b, h, w, c = x.shape

        # Pack the latent image to 2x2 patches
        x = x.reshape(b, h // 2, 2, w // 2, 2, c)
        x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)

        # Create positions ids used to positionally encode each patch. Due to
        # the way RoPE works, this results in an interesting positional
        # encoding where parts of the feature are holding different positional
        # information. Namely, the first part holds information independent of
        # the spatial position (hence 0s), the 2nd part holds vertical spatial
        # information and the last one horizontal.
        i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
        j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
        x_ids = mx.stack([i, j, k], axis=-1)
        x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)

        return x, x_ids

    def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
        """Encode tokens into ``(txt, txt_ids, vec)`` for ``n_images`` images.

        ``txt`` are the T5 features with all-zero int32 ids; ``vec`` is the
        CLIP pooled output. Batch-size-1 features are broadcast to
        ``n_images``.
        """
        # Prepare the text features
        txt = self.t5(t5_tokens)
        if len(txt) == 1 and n_images > 1:
            txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
        txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)

        # Prepare the clip text features
        vec = self.clip(clip_tokens).pooled_output
        if len(vec) == 1 and n_images > 1:
            vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))

        return txt, txt_ids, vec

    def _denoising_loop(
        self,
        x_t,
        x_ids,
        txt,
        txt_ids,
        vec,
        num_steps: int = 35,
        guidance: float = 4.0,
        start: float = 1,
        stop: float = 0,
    ):
        """Run ``num_steps`` sampler steps from ``x_t``, yielding each latent.

        Step ``i`` moves from ``timesteps[i]`` to ``timesteps[i + 1]``, so the
        sampler's schedule must provide at least ``num_steps + 1`` values.
        """
        B = len(x_t)

        def scalar(x):
            return mx.full((B,), x, dtype=self.dtype)

        guidance = scalar(guidance)
        timesteps = self.sampler.timesteps(
            num_steps,
            x_t.shape[1],
            start=start,
            stop=stop,
        )
        for i in range(num_steps):
            t = timesteps[i]
            t_prev = timesteps[i + 1]

            pred = self.flow(
                img=x_t,
                img_ids=x_ids,
                txt=txt,
                txt_ids=txt_ids,
                y=vec,
                timesteps=scalar(t),
                guidance=guidance,
            )
            x_t = self.sampler.step(pred, x_t, t, t_prev)

            yield x_t

    def generate_latents(
        self,
        text: str,
        n_images: int = 1,
        num_steps: int = 35,
        guidance: float = 4.0,
        latent_size: Tuple[int, int] = (64, 64),
        seed=None,
    ):
        """Generator over the sampling process for ``text``.

        First yields the conditioning tuple ``(x_T, x_ids, txt, txt_ids, vec)``
        and then the packed latent after each of the ``num_steps`` steps.
        """
        # Set the PRNG state
        if seed is not None:
            mx.random.seed(seed)

        # Create the latent variables
        x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
        x_T, x_ids = self._prepare_latent_images(x_T)

        # Get the conditioning
        t5_tokens, clip_tokens = self.tokenize(text)
        txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)

        # Yield the conditioning for controlled evaluation by the caller
        yield (x_T, x_ids, txt, txt_ids, vec)

        # Yield the latent sequences from the denoising loop
        yield from self._denoising_loop(
            x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
        )

    def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
        """Unpack patch latents of spatial size ``latent_size`` and decode them.

        Inverts the 2x2 packing of ``_prepare_latent_images``, runs the AE
        decoder and maps the output from [-1, 1] to [0, 1] (clipped).
        """
        h, w = latent_size
        x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
        x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
        x = self.ae.decode(x)
        return mx.clip(x + 1, 0, 2) * 0.5

    def generate_images(
        self,
        text: str,
        n_images: int = 1,
        num_steps: int = 35,
        guidance: float = 4.0,
        latent_size: Tuple[int, int] = (64, 64),
        seed=None,
        reload_text_encoders: bool = True,
        progress: bool = True,
    ):
        """Generate ``n_images`` images for ``text`` and return them stacked.

        Images are decoded one at a time and concatenated along axis 0, with
        values in [0, 1].
        """
        latents = self.generate_latents(
            text, n_images, num_steps, guidance, latent_size, seed
        )
        # Evaluate the conditioning before (optionally) reloading the encoders.
        mx.eval(next(latents))

        if reload_text_encoders:
            self.reload_text_encoders()

        for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
            mx.eval(x_t)

        images = []
        for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
            # Pass the caller's latent size: decoding with the default (64, 64)
            # would unpack any other size incorrectly.
            images.append(self.decode(x_t[i : i + 1], latent_size))
            mx.eval(images[-1])
        images = mx.concatenate(images, axis=0)
        mx.eval(images)

        return images

    def training_loss(
        self,
        x_0: mx.array,
        t5_features: mx.array,
        clip_features: mx.array,
        guidance: mx.array,
    ):
        """Flow-matching loss for a batch of clean (unpacked) latents ``x_0``.

        Noise ``eps`` is added at random timesteps by the sampler; the loss is
        the mean squared error between the flow prediction and ``eps - x_0``.
        """
        # Get the text conditioning
        txt = t5_features
        txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
        vec = clip_features

        # Prepare the latent input
        x_0, x_ids = self._prepare_latent_images(x_0)

        # Forward process
        t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
        eps = mx.random.normal(x_0.shape, dtype=self.dtype)
        x_t = self.sampler.add_noise(x_0, t, noise=eps)
        x_t = mx.stop_gradient(x_t)

        # Do the denoising
        pred = self.flow(
            img=x_t,
            img_ids=x_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t,
            guidance=guidance,
        )

        # (pred + x_0 - eps)^2 == (pred - (eps - x_0))^2
        return (pred + x_0 - eps).square().mean()

    def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
        """Swap the linear layers in the transformer blocks with LoRA layers.

        Blocks are visited from the end of ``double_blocks + single_blocks``
        backwards, so a positive ``num_blocks`` converts only the last
        ``num_blocks`` blocks; ``num_blocks <= 0`` converts all of them.
        """
        all_blocks = self.flow.double_blocks + self.flow.single_blocks
        all_blocks.reverse()
        num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
        for block in all_blocks[:num_blocks]:
            loras = []
            for name, module in block.named_modules():
                if isinstance(module, nn.Linear):
                    loras.append((name, LoRALinear.from_base(module, r=rank)))
            block.update_modules(tree_unflatten(loras))

    def fuse_lora_layers(self):
        """Replace every ``LoRALinear`` in the flow model with its fused layer."""
        fused_layers = []
        for name, module in self.flow.named_modules():
            if isinstance(module, LoRALinear):
                fused_layers.append((name, module.fuse()))
        self.flow.update_modules(tree_unflatten(fused_layers))


================================================
FILE: flux/flux/layers.py
================================================
# Copyright © 2024 Apple Inc.

import math
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn


def _rope(pos: mx.array, dim: int, theta: float):
    """Build rotary-embedding rotation matrices for positions ``pos``.

    For even ``dim`` the result has shape ``pos.shape + (dim // 2, 2, 2)``.
    Each trailing 2x2 block is ``[[cos, -sin], [sin, cos]]`` of
    ``pos * omega_i`` with ``omega_i = 1 / theta ** (2 * i / dim)``.
    """
    exponents = mx.arange(0, dim, 2, dtype=mx.float32) / dim
    angles = pos[..., None] * (1.0 / (theta**exponents))
    cos, sin = mx.cos(angles), mx.sin(angles)
    rotation = mx.stack([cos, -sin, sin, cos], axis=-1)
    return rotation.reshape(*rotation.shape[:-1], 2, 2)


@partial(mx.compile, shapeless=True)
def _ab_plus_cd(a, b, c, d):
    """Return ``a * b + c * d`` as a single compiled elementwise kernel.

    Compiled with ``shapeless=True`` so differently shaped inputs reuse the
    same compiled graph instead of triggering recompilation.
    """
    return a * b + c * d


def _apply_rope(x, pe):
    """Rotate consecutive feature pairs of ``x`` by the matrices in ``pe``.

    The last axis of ``x`` is viewed as (D/2, 1, 2) pairs ``(x0, x1)``. ``pe``
    is expected to end in (D/2, 2, 2) rotation blocks, such as those built by
    ``_rope``. Each pair becomes ``x0 * pe[..., 0] + x1 * pe[..., 1]``, and the
    result is reshaped back to ``x``'s original shape.
    """
    s = x.shape
    # The singleton axis lets each x component broadcast across the two
    # output components of its 2x2 rotation block.
    x = x.reshape(*s[:-1], -1, 1, 2)
    x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
    return x.reshape(s)


def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
    """Scaled dot-product attention with rotary embeddings on ``q`` and ``k``.

    Inputs are (B, H, L, D). The output merges the heads into (B, L, H * D).
    """
    batch, _, seq_len, head_dim = q.shape
    out = mx.fast.scaled_dot_product_attention(
        _apply_rope(q, pe), _apply_rope(k, pe), v, scale=head_dim ** (-0.5)
    )
    # (B, H, L, D) -> (B, L, H, D) -> (B, L, H * D)
    return out.transpose(0, 2, 1, 3).reshape(batch, seq_len, -1)


def timestep_embedding(
    t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
):
    """Sinusoidal embedding of per-sample timesteps.

    Args:
        t: 1-D array of timesteps, one per batch element.
        dim: Number of output features. For odd ``dim`` the last feature is
            zero, so the output always has exactly ``dim`` features.
        max_period: Sets the frequency range; frequencies decay geometrically
            from 1 toward ``1 / max_period``.
        time_factor: Multiplier applied to ``t`` before embedding.

    Returns:
        Array of shape ``(len(t), dim)`` in ``t``'s dtype. It holds the cosines
        followed by the sines of ``time_factor * t * freqs``.
    """
    half = dim // 2
    freqs = mx.arange(0, half, dtype=mx.float32) / half
    freqs = freqs * (-math.log(max_period))
    freqs = mx.exp(freqs)

    x = (time_factor * t)[:, None] * freqs[None]
    x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
    if dim % 2:
        # cos/sin only fill 2 * (dim // 2) features; pad to exactly ``dim``.
        x = mx.concatenate([x, mx.zeros_like(x[:, :1])], axis=-1)

    return x.astype(t.dtype)


class EmbedND(nn.Module):
    """Rotary position embedding over several position axes.

    Axis ``i`` of the trailing dimension of ``ids`` gets ``axes_dim[i]``
    rotary features. The per-axis rotation matrices are concatenated along the
    frequency axis, and a singleton axis is inserted at position 1 (presumably
    so ``pe`` broadcasts over attention heads).
    """

    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
        super().__init__()

        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def __call__(self, ids: mx.array):
        per_axis = []
        for axis in range(ids.shape[-1]):
            per_axis.append(_rope(ids[..., axis], self.axes_dim[axis], self.theta))
        pe = mx.concatenate(per_axis, axis=-3)
        return pe[:, None]


class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) projecting to ``hidden_dim``."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def __call__(self, x: mx.array) -> mx.array:
        hidden = nn.silu(self.in_layer(x))
        return self.out_layer(hidden)


class QKNorm(nn.Module):
    """Separate RMSNorms for queries and keys over the per-head dimension."""

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = nn.RMSNorm(dim)
        self.key_norm = nn.RMSNorm(dim)

    def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
        normed_q = self.query_norm(q)
        normed_k = self.key_norm(k)
        return normed_q, normed_k


class SelfAttention(nn.Module):
    """Multi-head self-attention with QK RMS-normalization and rotary embeddings.

    A single fused ``qkv`` projection produces queries, keys and values.
    """

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
        batch, seq_len, _ = x.shape
        heads = self.num_heads

        def to_heads(t):
            # (B, L, H * D) -> (B, H, L, D)
            return t.reshape(batch, seq_len, heads, -1).transpose(0, 2, 1, 3)

        q, k, v = (to_heads(t) for t in mx.split(self.qkv(x), 3, axis=-1))
        q, k = self.norm(q, k)
        return self.proj(_attention(q, k, v, pe))


@dataclass
class ModulationOut:
    """One set of adaptive modulation parameters produced by ``Modulation``."""

    shift: mx.array  # added after scaling
    scale: mx.array  # used as (1 + scale) by callers, e.g. DoubleStreamBlock
    gate: mx.array  # third chunk of the modulation output


class Modulation(nn.Module):
    """Produce shift/scale/gate modulation parameters from a conditioning vector.

    With ``double=True`` the projection has 6 chunks and two ``ModulationOut``
    sets are returned; otherwise 3 chunks, one set and ``None``.
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
        # assumes x is (B, dim): a length-1 axis is inserted after the batch
        # axis before splitting into ``multiplier`` equal chunks
        projected = self.lin(nn.silu(x))
        chunks = mx.split(projected[:, None, :], self.multiplier, axis=-1)

        shift, scale, gate = chunks[:3]
        first = ModulationOut(shift=shift, scale=scale, gate=gate)
        if not self.is_double:
            return first, None
        shift2, scale2, gate2 = chunks[3:]
        return first, ModulationOut(shift=shift2, scale=scale2, gate=gate2)


class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text streams and joint attention.

    Each stream has its own modulation, norms, QKV/proj weights and MLP. The
    text and image queries/keys/values are concatenated (text first) for one
    attention call, and the result is split back into the two streams.
    """

    def __init__(
        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approx="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approx="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        # Set by Flux.shard(). When not None, __call__ all-sums the outputs of
        # the sharded projection and MLP layers across this group.
        self.sharding_group = None

    def __call__(
        self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
    ) -> Tuple[mx.array, mx.array]:
        """Run one double-stream step.

        Args:
            img: image tokens, unpacked as (B, L, _).
            txt: text tokens, unpacked as (B, S, _).
            vec: conditioning vector fed to both Modulation layers.
            pe: positional embedding forwarded to ``_attention``.

        Returns:
            The updated ``(img, txt)`` pair.
        """
        B, L, _ = img.shape
        _, S, _ = txt.shape
        H = self.num_heads

        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention (uses the SelfAttention sublayers'
        # weights directly because attention runs jointly with the text)
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
        img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        img_q, img_k = self.img_attn.norm(img_q, img_k)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
        txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
        txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
        txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)

        # joint attention over the [txt, img] sequence (concatenated on the
        # token axis, which is axis 2 after the head transpose)
        q = mx.concatenate([txt_q, img_q], axis=2)
        k = mx.concatenate([txt_k, img_k], axis=2)
        v = mx.concatenate([txt_v, img_v], axis=2)

        attn = _attention(q, k, v, pe)
        # the first S positions along axis 1 are the text tokens
        txt_attn, img_attn = mx.split(attn, [S], axis=1)

        # Project each stream; when sharded, concatenate, all-sum the partial
        # projections across the group, then split back into txt/img
        txt_attn = self.txt_attn.proj(txt_attn)
        img_attn = self.img_attn.proj(img_attn)
        if self.sharding_group is not None:
            attn = mx.concatenate([txt_attn, img_attn], axis=1)
            attn = mx.distributed.all_sum(attn, group=self.sharding_group)
            txt_attn, img_attn = mx.split(attn, [S], axis=1)

        # image stream: gated attention residual, then modulated MLP input
        img = img + img_mod1.gate * img_attn
        img_mlp = self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )

        # text stream: gated attention residual, then modulated MLP input
        txt = txt + txt_mod1.gate * txt_attn
        txt_mlp = self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )

        # when sharded, all-sum the partial MLP outputs with one collective
        if self.sharding_group is not None:
            txt_img = mx.concatenate([txt_mlp, img_mlp], axis=1)
            txt_img = mx.distributed.all_sum(txt_img, group=self.sharding_group)
            txt_mlp, img_mlp = mx.split(txt_img, [S], axis=1)

        # finalize the img/txt blocks with the gated MLP residuals
        img = img + img_mod2.gate * img_mlp
        txt = txt + txt_mod2.gate * txt_mlp

        return img, txt


class SingleStreamBlock(nn.Module):
    """Transformer block with parallel attention and MLP over one token stream.

    ``linear1`` produces q, k, v and the MLP input in a single projection.
    ``linear2`` maps the attention output, concatenated with the activated MLP
    input, back to ``hidden_size``. The result is added to the input through a
    modulation gate.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: Optional[float] = None,
    ):
        super().__init__()
        # NOTE(review): hidden_dim is not read anywhere in this class; __call__
        # uses self.hidden_size instead.
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        # NOTE(review): self.scale is not passed to _attention in __call__, so
        # qk_scale currently has no effect. `or` also maps qk_scale=0.0 to the
        # default head_dim**-0.5.
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approx="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
        """Apply the block to ``x`` (unpacked as (B, L, _)) conditioned on ``vec``."""
        B, L, _ = x.shape
        H = self.num_heads

        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift

        # split points h, 2h, 3h give q | k | v | mlp input (remaining columns)
        q, k, v, mlp = mx.split(
            self.linear1(x_mod),
            [self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
            axis=-1,
        )
        q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        q, k = self.norm(q, k)

        # compute attention
        y = _attention(q, k, v, pe)

        # compute activation in mlp stream, cat again and run second linear layer
        y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
        return x + mod.gate * y


class LastLayer(nn.Module):
    """Final adaptively-modulated projection to ``patch_size**2 * out_channels``."""

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def __call__(self, x: mx.array, vec: mx.array):
        modulation = self.adaLN_modulation(vec)
        shift, scale = mx.split(modulation, 2, axis=1)
        # Add a token axis so the per-sample shift/scale broadcast over tokens.
        shift = shift[:, None, :]
        scale = scale[:, None, :]
        modulated = (1 + scale) * self.norm_final(x) + shift
        return self.linear(modulated)


================================================
FILE: flux/flux/lora.py
================================================
# Copyright © 2024 Apple Inc.

import math

import mlx.core as mx
import mlx.nn as nn


class LoRALinear(nn.Module):
    """Linear layer plus a low-rank (LoRA) additive update.

    Output is ``linear(x) + scale * ((dropout(x) @ lora_a) @ lora_b)`` with
    ``lora_a`` of shape ``(input_dims, r)`` and ``lora_b`` of shape
    ``(r, output_dims)``. ``lora_b`` is initialized to zeros, so the update
    starts at zero.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        r: int = 8,
        dropout: float = 0.0,
        scale: float = 1.0,
        bias: bool = False,
    ):
        super().__init__()

        # Base projection; from_base() replaces it with the caller's layer.
        self.linear = nn.Linear(input_dims, output_dims, bias=bias)

        self.dropout = nn.Dropout(p=dropout)

        # Multiplier applied to the low-rank update.
        self.scale = scale

        # lora_a ~ U(-1/sqrt(input_dims), 1/sqrt(input_dims)); lora_b = 0.
        bound = 1 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(
            low=-bound,
            high=bound,
            shape=(input_dims, r),
        )
        self.lora_b = mx.zeros(shape=(r, output_dims))

    @staticmethod
    def from_base(
        linear: nn.Linear,
        r: int = 8,
        dropout: float = 0.0,
        scale: float = 1.0,
    ):
        """Wrap an existing ``nn.Linear`` with a freshly initialized LoRA update."""
        out_features, in_features = linear.weight.shape
        wrapped = LoRALinear(
            input_dims=in_features,
            output_dims=out_features,
            r=r,
            dropout=dropout,
            scale=scale,
        )
        # Swap the placeholder base layer for the provided one.
        wrapped.linear = linear
        return wrapped

    def fuse(self):
        """Return a plain ``nn.Linear`` whose weight has the LoRA update folded in."""
        base = self.linear
        has_bias = "bias" in base
        base_weight = base.weight

        out_features, in_features = base_weight.shape
        merged = nn.Linear(in_features, out_features, bias=has_bias)

        # (output_dims, r) @ (r, input_dims) matches base_weight's shape.
        delta = (self.scale * self.lora_b.T) @ self.lora_a.T
        merged.weight = base_weight + delta.astype(base_weight.dtype)
        if has_bias:
            merged.bias = base.bias

        return merged

    def __call__(self, x):
        base_out = self.linear(x)
        low_rank = (self.dropout(x) @ self.lora_a) @ self.lora_b
        return base_out + (self.scale * low_rank).astype(x.dtype)


================================================
FILE: flux/flux/model.py
================================================
# Copyright © 2024 Apple Inc.

from dataclasses import dataclass
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_inplace, shard_linear

from .layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)


@dataclass
class FluxParams:
    """Hyperparameters consumed by :class:`Flux` to build the model."""

    in_channels: int  # input width of img_in; Flux also uses it as out_channels
    vec_in_dim: int  # input width of the vector_in embedder (the `y` input)
    context_in_dim: int  # input width of txt_in (text context features)
    hidden_size: int  # model width; must be divisible by num_heads
    mlp_ratio: float  # MLP hidden width is int(hidden_size * mlp_ratio)
    num_heads: int
    depth: int  # number of DoubleStreamBlocks
    depth_single_blocks: int  # number of SingleStreamBlocks
    axes_dim: list[int]  # per-id-axis positional dims; sum must equal hidden_size // num_heads
    theta: int  # forwarded to EmbedND as `theta`
    qkv_bias: bool  # bias on the double blocks' qkv projections
    guidance_embed: bool  # if True, a guidance-strength embedding is added to vec


class Flux(nn.Module):
    """Flux transformer: double-stream blocks followed by single-stream blocks.

    Image and text tokens are processed in separate streams by
    ``double_blocks``, then concatenated (text first) and processed jointly by
    ``single_blocks``. Only the image positions are fed to ``final_layer``.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
        )
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if params.guidance_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = [
            DoubleStreamBlock(
                self.hidden_size,
                self.num_heads,
                mlp_ratio=params.mlp_ratio,
                qkv_bias=params.qkv_bias,
            )
            for _ in range(params.depth)
        ]

        self.single_blocks = [
            SingleStreamBlock(
                self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
            )
            for _ in range(params.depth_single_blocks)
        ]

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def sanitize(self, weights):
        """Rename checkpoint keys to this module's parameter names.

        Strips a leading ``model.diffusion_model.`` prefix, renames a trailing
        ``.scale`` to ``.weight``, and inserts ``.layers.`` after the first
        ``img_mlp``/``txt_mlp``/``adaLN_modulation`` component found (those
        submodules are ``nn.Sequential`` here).
        """
        new_weights = {}
        for k, w in weights.items():
            if k.startswith("model.diffusion_model."):
                k = k[22:]  # len("model.diffusion_model.") == 22
            if k.endswith(".scale"):
                k = k[:-6] + ".weight"
            for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
                if f".{seq}." in k:
                    k = k.replace(f".{seq}.", f".{seq}.layers.")
                    break
            new_weights[k] = w
        return new_weights

    def shard(self, group: Optional[mx.distributed.Group] = None):
        """Tensor-parallel shard the transformer blocks across ``group`` in place.

        Attention heads and MLP hidden units are split evenly across the group.
        The call is a no-op for a group of size 1.

        Raises:
            ValueError: if ``num_heads`` is not divisible by the group size.
        """
        group = group or mx.distributed.init()
        N = group.size()
        if N == 1:
            return
        # Check before mutating anything: a non-divisible head count would
        # otherwise be silently truncated by `//=` and fail later in a reshape.
        if self.num_heads % N != 0:
            raise ValueError(
                f"num_heads {self.num_heads} must be divisible by the group size {N}"
            )

        for block in self.double_blocks:
            block.num_heads //= N
            block.img_attn.num_heads //= N
            block.txt_attn.num_heads //= N
            # DoubleStreamBlock all-sums the outputs of the in-place sharded
            # proj / second MLP layers itself when sharding_group is set.
            block.sharding_group = group
            block.img_attn.qkv = shard_linear(
                block.img_attn.qkv, "all-to-sharded", segments=3, group=group
            )
            block.txt_attn.qkv = shard_linear(
                block.txt_attn.qkv, "all-to-sharded", segments=3, group=group
            )
            shard_inplace(block.img_attn.proj, "sharded-to-all", group=group)
            shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group)
            block.img_mlp.layers[0] = shard_linear(
                block.img_mlp.layers[0], "all-to-sharded", group=group
            )
            block.txt_mlp.layers[0] = shard_linear(
                block.txt_mlp.layers[0], "all-to-sharded", group=group
            )
            shard_inplace(block.img_mlp.layers[2], "sharded-to-all", group=group)
            shard_inplace(block.txt_mlp.layers[2], "sharded-to-all", group=group)

        for block in self.single_blocks:
            # linear1 outputs [q | k | v | mlp_in] with widths (h, h, h, m) and
            # linear2 consumes [attn | mlp_act] with widths (h, m). Derive the
            # segment fractions from the real sizes (read before the sizes are
            # divided below) rather than assuming m == 4 * h.
            h = block.hidden_size
            m = block.mlp_hidden_dim
            qkv_mlp = 3 * h + m
            block.num_heads //= N
            block.hidden_size //= N
            block.linear1 = shard_linear(
                block.linear1,
                "all-to-sharded",
                segments=[h / qkv_mlp, 2 * h / qkv_mlp, 3 * h / qkv_mlp],
                group=group,
            )
            # SingleStreamBlock does no explicit all_sum, so it relies on the
            # sharded-to-all layer returned by shard_linear for the reduction.
            block.linear2 = shard_linear(
                block.linear2, "sharded-to-all", segments=[h / (h + m)], group=group
            )

    def __call__(
        self,
        img: mx.array,
        img_ids: mx.array,
        txt: mx.array,
        txt_ids: mx.array,
        timesteps: mx.array,
        y: mx.array,
        guidance: Optional[mx.array] = None,
    ) -> mx.array:
        """Predict the model output for the image tokens.

        Args:
            img: 3-D image token array, projected by ``img_in``.
            img_ids: position ids for the image tokens.
            txt: 3-D text token array, projected by ``txt_in``.
            txt_ids: position ids for the text tokens.
            timesteps: embedded with a 256-dim ``timestep_embedding``.
            y: input to ``vector_in``, added to the conditioning vector.
            guidance: guidance strength; required when
                ``params.guidance_embed`` is True.

        Raises:
            ValueError: if ``img``/``txt`` are not 3-D, or guidance is missing
                for a guidance-distilled model.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # Text positions come first, matching the token concatenation below.
        ids = mx.concatenate([txt_ids, img_ids], axis=1)
        pe = self.pe_embedder(ids).astype(img.dtype)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = mx.concatenate([txt, img], axis=1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        # Drop the text positions; keep only the image tokens.
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)

        return img


================================================
FILE: flux/flux/sampler.py
================================================
# Copyright © 2024 Apple Inc.

import math
from functools import lru_cache

import mlx.core as mx


class FluxSampler:
    """Flow-matching timestep schedule and noise/step helpers for Flux.

    If ``name`` contains ``"schnell"``, the resolution-dependent time shift is
    skipped in ``timesteps``, and ``random_timesteps`` draws from the discrete
    set {0.25, 0.5, 0.75, 1.0}.
    """

    # Maximum number of memoized schedules kept per sampler instance.
    _TIMESTEPS_CACHE_SIZE = 128

    def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.15):
        self._base_shift = base_shift
        self._max_shift = max_shift
        self._schnell = "schnell" in name

    def _time_shift(self, x, t):
        """Shift ``t`` toward 1 by an amount that grows with sequence length ``x``.

        ``mu`` is linear in ``x`` through (256, base_shift) and
        (4096, max_shift), and ``t' = e^mu / (e^mu + 1/t - 1)``.
        """
        x1, x2 = 256, 4096
        t1, t2 = self._base_shift, self._max_shift
        exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1)
        t = exp_mu / (exp_mu + (1 / t - 1))
        return t

    def timesteps(
        self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
    ):
        """Return ``num_steps + 1`` timesteps from ``start`` to ``stop`` as a list.

        Results are memoized per instance. A fresh list is returned on each
        call, so callers may mutate it freely.
        """
        # A per-instance memo replaces functools.lru_cache on the method, which
        # keys on `self` and keeps every sampler alive for the cache's lifetime.
        cache = self.__dict__.setdefault("_timesteps_cache", {})
        key = (num_steps, image_sequence_length, start, stop)
        schedule = cache.get(key)
        if schedule is None:
            t = mx.linspace(start, stop, num_steps + 1)

            if not self._schnell:
                t = self._time_shift(image_sequence_length, t)

            schedule = t.tolist()
            if len(cache) >= self._TIMESTEPS_CACHE_SIZE:
                # Evict the oldest entry (dicts preserve insertion order).
                cache.pop(next(iter(cache)))
            cache[key] = schedule

        # Copy so a caller mutating the result cannot corrupt the memo.
        return list(schedule)

    def random_timesteps(self, B, L, dtype=mx.float32, key=None):
        """Sample ``B`` training timesteps for sequence length ``L``."""
        if self._schnell:
            # TODO: Should we upweigh 1 and 0.75?
            t = mx.random.randint(1, 5, shape=(B,), key=key)
            t = t.astype(dtype) / 4
        else:
            t = mx.random.uniform(shape=(B,), dtype=dtype, key=key)
            t = self._time_shift(L, t)

        return t

    def sample_prior(self, shape, dtype=mx.float32, key=None):
        """Draw standard-normal noise of ``shape``."""
        return mx.random.normal(shape, dtype=dtype, key=key)

    def add_noise(self, x, t, noise=None, key=None):
        """Interpolate ``x * (1 - t) + t * noise``, broadcasting ``t`` per sample."""
        noise = (
            noise
            if noise is not None
            else mx.random.normal(x.shape, dtype=x.dtype, key=key)
        )
        t = t.reshape([-1] + [1] * (x.ndim - 1))
        return x * (1 - t) + t * noise

    def step(self, pred, x_t, t, t_prev):
        """Take one Euler step from ``t`` to ``t_prev`` along ``pred``."""
        return x_t + (t_prev - t) * pred


================================================
FILE: flux/flux/t5.py
================================================
# Copyright © 2024 Apple Inc.

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn

# (old, new) substring pairs. NOTE(review): presumably applied in order with
# str.replace to rename pretrained checkpoint keys to this module's parameter
# names. The code that applies them is outside this view, so confirm there.
_SHARED_REPLACEMENT_PATTERNS = [
    (".block.", ".layers."),
    (".k.", ".key_proj."),
    (".o.", ".out_proj."),
    (".q.", ".query_proj."),
    (".v.", ".value_proj."),
    ("shared.", "wte."),
    ("lm_head.", "lm_head.linear."),
    (".layer.0.layer_norm.", ".ln1."),
    (".layer.1.layer_norm.", ".ln2."),
    (".layer.2.layer_norm.", ".ln3."),
    (".final_layer_norm.", ".ln."),
    (
        "layers.0.layer.0.SelfAttention.relative_attention_bias.",
        "relative_attention_bias.embeddings.",
    ),
]

# Further (old, new) pairs whose names suggest they apply to encoder
# self-attention / feed-forward sublayers. NOTE(review): verify how they are
# combined with the shared patterns at the call site.
_ENCODER_REPLACEMENT_PATTERNS = [
    (".layer.0.SelfAttention.", ".attention."),
    (".layer.1.DenseReluDense.", ".dense."),
]


@dataclass
class T5Config:
    vocab_size: int
    num_layers: int
    num_heads: int
    relative_attention_num_buckets: int
    d_kv: int
    d_model: int
    feed_forward_proj: str
    tie_word_embeddings: bool

    d_ff: Optional[int] = None
    num_decoder_layers: Optional[int] = None
    relative_attention_max_distance: int = 128
    layer_norm_epsilon: float = 1e-6

    @classmethod
    def from_dict(cls, config):
        return cls(
            vocab_size=config["vocab_size"],
            num_layers=config["num_layers"],
            num_heads=config["num_heads"],
         
Download .txt
gitextract_ur1ntgnw/

├── .github/
│   └── workflows/
│       └── pull_request.yml
├── .gitignore
├── .pre-commit-config.yaml
├── ACKNOWLEDGMENTS.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── bert/
│   ├── README.md
│   ├── convert.py
│   ├── model.py
│   ├── requirements.txt
│   ├── test.py
│   └── weights/
│       └── .gitignore
├── cifar/
│   ├── README.md
│   ├── dataset.py
│   ├── main.py
│   ├── requirements.txt
│   └── resnet.py
├── clip/
│   ├── .gitignore
│   ├── README.md
│   ├── clip.py
│   ├── convert.py
│   ├── hf_preproc.py
│   ├── image_processor.py
│   ├── linear_probe.py
│   ├── model.py
│   ├── requirements.txt
│   ├── test.py
│   └── tokenizer.py
├── cvae/
│   ├── .gitignore
│   ├── README.md
│   ├── dataset.py
│   ├── main.py
│   ├── requirements.txt
│   └── vae.py
├── encodec/
│   ├── README.md
│   ├── benchmarks/
│   │   ├── bench_mx.py
│   │   └── bench_pt.py
│   ├── convert.py
│   ├── encodec.py
│   ├── example.py
│   ├── requirements.txt
│   ├── test.py
│   └── utils.py
├── flux/
│   ├── README.md
│   ├── dreambooth.py
│   ├── flux/
│   │   ├── __init__.py
│   │   ├── autoencoder.py
│   │   ├── clip.py
│   │   ├── datasets.py
│   │   ├── flux.py
│   │   ├── layers.py
│   │   ├── lora.py
│   │   ├── model.py
│   │   ├── sampler.py
│   │   ├── t5.py
│   │   ├── tokenizers.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── generate_interactive.py
│   ├── requirements.txt
│   └── txt2image.py
├── gcn/
│   ├── .gitignore
│   ├── README.md
│   ├── datasets.py
│   ├── gcn.py
│   ├── main.py
│   └── requirements.txt
├── llava/
│   ├── .gitignore
│   ├── README.md
│   ├── generate.py
│   ├── language.py
│   ├── llava.py
│   ├── requirements.txt
│   ├── test.py
│   └── vision.py
├── llms/
│   ├── README.md
│   ├── gguf_llm/
│   │   ├── README.md
│   │   ├── generate.py
│   │   ├── models.py
│   │   ├── requirements.txt
│   │   └── utils.py
│   ├── llama/
│   │   ├── README.md
│   │   ├── convert.py
│   │   ├── llama.py
│   │   ├── requirements.txt
│   │   └── sample_prompt.txt
│   ├── mistral/
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── convert.py
│   │   ├── mistral.py
│   │   ├── requirements.txt
│   │   └── test.py
│   ├── mixtral/
│   │   ├── README.md
│   │   ├── convert.py
│   │   ├── mixtral.py
│   │   ├── params.json
│   │   └── requirements.txt
│   └── speculative_decoding/
│       ├── README.md
│       ├── convert.py
│       ├── decoder.py
│       ├── main.py
│       ├── model.py
│       └── requirements.txt
├── lora/
│   ├── .gitignore
│   ├── README.md
│   ├── convert.py
│   ├── data/
│   │   ├── test.jsonl
│   │   ├── train.jsonl
│   │   ├── valid.jsonl
│   │   └── wikisql.py
│   ├── fuse.py
│   ├── lora.py
│   ├── models.py
│   ├── requirements.txt
│   └── utils.py
├── mnist/
│   ├── README.md
│   ├── main.py
│   ├── mnist.py
│   └── requirements.txt
├── musicgen/
│   ├── README.md
│   ├── benchmarks/
│   │   ├── bench_mx.py
│   │   └── bench_pt.py
│   ├── generate.py
│   ├── musicgen.py
│   ├── requirements.txt
│   └── utils.py
├── normalizing_flow/
│   ├── README.md
│   ├── bijectors.py
│   ├── distributions.py
│   ├── flows.py
│   ├── main.py
│   └── requirements.txt
├── segment_anything/
│   ├── README.md
│   ├── convert.py
│   ├── main.py
│   ├── notebooks/
│   │   ├── automatic_mask_generator_example.ipynb
│   │   └── predictor_example.ipynb
│   ├── requirements.txt
│   └── segment_anything/
│       ├── __init__.py
│       ├── automatic_mask_generator.py
│       ├── common.py
│       ├── image_encoder.py
│       ├── mask_decoder.py
│       ├── predictor.py
│       ├── prompt_encoder.py
│       ├── sam.py
│       ├── transformer.py
│       └── utils/
│           ├── __init__.py
│           ├── amg.py
│           └── transforms.py
├── speechcommands/
│   ├── README.md
│   ├── kwt.py
│   ├── main.py
│   └── requirements.txt
├── stable_diffusion/
│   ├── README.md
│   ├── image2image.py
│   ├── requirements.txt
│   ├── stable_diffusion/
│   │   ├── __init__.py
│   │   ├── clip.py
│   │   ├── config.py
│   │   ├── model_io.py
│   │   ├── sampler.py
│   │   ├── tokenizer.py
│   │   ├── unet.py
│   │   └── vae.py
│   └── txt2image.py
├── t5/
│   ├── .gitignore
│   ├── README.md
│   ├── hf_t5.py
│   ├── requirements.txt
│   └── t5.py
├── transformer_lm/
│   ├── README.md
│   ├── datasets.py
│   ├── main.py
│   └── requirements.txt
├── whisper/
│   ├── MANIFEST.in
│   ├── README.md
│   ├── benchmark.py
│   ├── convert.py
│   ├── mlx_whisper/
│   │   ├── __init__.py
│   │   ├── _version.py
│   │   ├── assets/
│   │   │   ├── download_alice.sh
│   │   │   ├── gpt2.tiktoken
│   │   │   ├── ls_test.flac
│   │   │   ├── mel_filters.npz
│   │   │   └── multilingual.tiktoken
│   │   ├── audio.py
│   │   ├── cli.py
│   │   ├── decoding.py
│   │   ├── load_models.py
│   │   ├── requirements.txt
│   │   ├── timing.py
│   │   ├── tokenizer.py
│   │   ├── torch_whisper.py
│   │   ├── transcribe.py
│   │   ├── whisper.py
│   │   └── writers.py
│   ├── setup.py
│   └── test.py
└── wwdc25/
    ├── Explore_language_models_on_Apple_silicon_with_MLX.ipynb
    ├── Get_started_with_MLX_for_Apple_silicon.ipynb
    ├── README.md
    ├── WWDC25MLXSwiftExamples/
    │   ├── WWDC25MLXSwiftExamples/
    │   │   ├── SimpleMLXLM.swift
    │   │   ├── SimpleMLXLMWithKVCache.swift
    │   │   └── main.swift
    │   └── WWDC25MLXSwiftExamples.xcodeproj/
    │       ├── project.pbxproj
    │       ├── project.xcworkspace/
    │       │   ├── contents.xcworkspacedata
    │       │   ├── xcshareddata/
    │       │   │   └── swiftpm/
    │       │   │       └── Package.resolved
    │       │   └── xcuserdata/
    │       │       └── shashankprasanna.xcuserdatad/
    │       │           └── UserInterfaceState.xcuserstate
    │       └── xcuserdata/
    │           └── shashankprasanna.xcuserdatad/
    │               └── xcschemes/
    │                   └── xcschememanagement.plist
    ├── data/
    │   ├── all.jsonl
    │   ├── train.jsonl
    │   └── valid.jsonl
    └── requirements.txt
Download .txt
SYMBOL INDEX (1221 symbols across 113 files)

FILE: bert/convert.py
  function replace_key (line 7) | def replace_key(key: str) -> str:
  function convert (line 22) | def convert(bert_model: str, mlx_model: str) -> None:

FILE: bert/model.py
  class TransformerEncoderLayer (line 11) | class TransformerEncoderLayer(nn.Module):
    method __init__ (line 16) | def __init__(
    method __call__ (line 32) | def __call__(self, x, mask):
  class TransformerEncoder (line 44) | class TransformerEncoder(nn.Module):
    method __init__ (line 45) | def __init__(
    method __call__ (line 54) | def __call__(self, x, mask):
  class BertEmbeddings (line 61) | class BertEmbeddings(nn.Module):
    method __init__ (line 62) | def __init__(self, config):
    method __call__ (line 73) | def __call__(
  class Bert (line 91) | class Bert(nn.Module):
    method __init__ (line 92) | def __init__(self, config):
    method __call__ (line 103) | def __call__(
  function load_model (line 120) | def load_model(
  function run (line 137) | def run(bert_model: str, mlx_model: str, batch: List[str]):

FILE: bert/test.py
  function run_torch (line 9) | def run_torch(bert_model: str, batch: List[str]):

FILE: cifar/dataset.py
  function get_cifar10 (line 6) | def get_cifar10(batch_size, root=None):

FILE: cifar/main.py
  function print_zero (line 26) | def print_zero(group, *args, **kwargs):
  function eval_fn (line 33) | def eval_fn(model, inp, tgt):
  function train_epoch (line 37) | def train_epoch(model, train_iter, optimizer, epoch):
  function test_epoch (line 100) | def test_epoch(model, test_iter, epoch):
  function main (line 116) | def main(args):

FILE: cifar/resnet.py
  class ShortcutA (line 21) | class ShortcutA(nn.Module):
    method __init__ (line 22) | def __init__(self, dims):
    method __call__ (line 26) | def __call__(self, x):
  class Block (line 33) | class Block(nn.Module):
    method __init__ (line 39) | def __init__(self, in_dims, dims, stride=1):
    method __call__ (line 57) | def __call__(self, x):
  class ResNet (line 68) | class ResNet(nn.Module):
    method __init__ (line 73) | def __init__(self, block, num_blocks, num_classes=10):
    method _make_layer (line 84) | def _make_layer(self, block, in_dims, dims, num_blocks, stride):
    method num_params (line 92) | def num_params(self):
    method __call__ (line 96) | def __call__(self, x):
  function resnet20 (line 106) | def resnet20(**kwargs):
  function resnet32 (line 110) | def resnet32(**kwargs):
  function resnet44 (line 114) | def resnet44(**kwargs):
  function resnet56 (line 118) | def resnet56(**kwargs):
  function resnet110 (line 122) | def resnet110(**kwargs):
  function resnet1202 (line 126) | def resnet1202(**kwargs):

FILE: clip/clip.py
  function load (line 8) | def load(model_dir: str) -> Tuple[CLIPModel, CLIPTokenizer, CLIPImagePro...

FILE: clip/convert.py
  function make_shards (line 14) | def make_shards(weights: dict, max_file_size_gb: int = 5) -> list:
  function save_weights (line 28) | def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -...
  function get_model_path (line 66) | def get_model_path(path_or_hf_repo: str, force_download: bool = False) -...
  function torch_to_mx (line 83) | def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:

FILE: clip/image_processor.py
  class CLIPImageProcessor (line 12) | class CLIPImageProcessor:
    method __init__ (line 18) | def __init__(
    method __call__ (line 37) | def __call__(self, images: List[Image]) -> mx.array:
    method _preprocess (line 42) | def _preprocess(self, image: Image) -> mx.array:
    method from_pretrained (line 54) | def from_pretrained(path: str):
  function resize (line 61) | def resize(image: Image, short_size: int) -> Image:
  function center_crop (line 76) | def center_crop(image: Image, size: Tuple[int, int]) -> Image:
  function rescale (line 88) | def rescale(image: mx.array) -> mx.array:
  function normalize (line 92) | def normalize(image: mx.array, mean: mx.array, std: mx.array) -> mx.array:

FILE: clip/linear_probe.py
  function get_cifar10 (line 14) | def get_cifar10(batch_size, root=None):
  function get_features (line 21) | def get_features(model, image_proc, iter):

FILE: clip/model.py
  class CLIPVisionOutput (line 18) | class CLIPVisionOutput:
  class CLIPTextOutput (line 25) | class CLIPTextOutput:
  class CLIPModelOutput (line 31) | class CLIPModelOutput:
  class CLIPTextConfig (line 40) | class CLIPTextConfig:
  class CLIPVisionConfig (line 51) | class CLIPVisionConfig:
  class CLIPConfig (line 63) | class CLIPConfig:
  function quick_gelu (line 69) | def quick_gelu(x: mx.array) -> mx.array:
  function clip_loss (line 76) | def clip_loss(logits: mx.array) -> mx.array:
  class Attention (line 83) | class Attention(nn.Module):
    method __init__ (line 84) | def __init__(
    method __call__ (line 115) | def __call__(self, queries, keys, values, mask=None):
  class MLP (line 137) | class MLP(nn.Module):
    method __init__ (line 138) | def __init__(self, config: CLIPTextConfig):
    method __call__ (line 145) | def __call__(self, x: mx.array) -> mx.array:
  class EncoderLayer (line 151) | class EncoderLayer(nn.Module):
    method __init__ (line 154) | def __init__(self, config: CLIPTextConfig):
    method __call__ (line 165) | def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx...
  class TextEmbeddings (line 174) | class TextEmbeddings(nn.Module):
    method __init__ (line 175) | def __init__(self, config: CLIPTextConfig):
    method __call__ (line 184) | def __call__(self, x: mx.array) -> mx.array:
  class Encoder (line 190) | class Encoder(nn.Module):
    method __init__ (line 191) | def __init__(self, config: CLIPTextConfig):
  class ClipTextModel (line 195) | class ClipTextModel(nn.Module):
    method __init__ (line 198) | def __init__(self, config: CLIPTextConfig):
    method __call__ (line 204) | def __call__(self, x: mx.array) -> CLIPTextOutput:
  class VisionEmbeddings (line 219) | class VisionEmbeddings(nn.Module):
    method __init__ (line 220) | def __init__(self, config: CLIPVisionConfig):
    method __call__ (line 241) | def __call__(self, x: mx.array) -> mx.array:
  class ClipVisionModel (line 261) | class ClipVisionModel(nn.Module):
    method __init__ (line 264) | def __init__(self, config: CLIPVisionConfig):
    method __call__ (line 271) | def __call__(
  class CLIPModel (line 295) | class CLIPModel(nn.Module):
    method __init__ (line 296) | def __init__(self, config: CLIPConfig):
    method get_text_features (line 308) | def get_text_features(self, x: mx.array) -> mx.array:
    method get_image_features (line 311) | def get_image_features(self, x: mx.array) -> mx.array:
    method __call__ (line 314) | def __call__(
    method from_pretrained (line 355) | def from_pretrained(path: str):
    method sanitize (line 405) | def sanitize(weights):

FILE: clip/test.py
  function load_mlx_models (line 18) | def load_mlx_models(path):
  function load_hf_models (line 25) | def load_hf_models(path):
  class TestCLIP (line 32) | class TestCLIP(unittest.TestCase):
    method setUpClass (line 34) | def setUpClass(cls):
    method test_image_processor (line 38) | def test_image_processor(self):
    method test_text_tokenizer (line 51) | def test_text_tokenizer(self):
    method test_text_encoder (line 61) | def test_text_encoder(self):
    method test_vision_encoder (line 79) | def test_vision_encoder(self):
    method test_clip_model (line 112) | def test_clip_model(self):

FILE: clip/tokenizer.py
  class CLIPTokenizer (line 11) | class CLIPTokenizer:
    method __init__ (line 14) | def __init__(self, bpe_ranks, vocab):
    method bos (line 24) | def bos(self):
    method bos_token (line 28) | def bos_token(self):
    method eos (line 32) | def eos(self):
    method eos_token (line 36) | def eos_token(self):
    method bpe (line 39) | def bpe(self, text):
    method __call__ (line 84) | def __call__(self, *args: Any, **kwargs: Any) -> Any:
    method tokenize (line 87) | def tokenize(self, text, prepend_bos=True, append_eos=True) -> mx.array:
    method from_pretrained (line 110) | def from_pretrained(path: str):

FILE: cvae/dataset.py
  function mnist (line 6) | def mnist(batch_size, img_size, root=None):

FILE: cvae/main.py
  function grid_image_from_batch (line 18) | def grid_image_from_batch(image_batch, num_rows):
  function loss_fn (line 44) | def loss_fn(model, X):
  function reconstruct (line 57) | def reconstruct(model, batch, out_file):
  function generate (line 66) | def generate(
  function main (line 82) | def main(args):

FILE: cvae/vae.py
  function upsample_nearest (line 10) | def upsample_nearest(x, scale: int = 2):
  class UpsamplingConv2d (line 17) | class UpsamplingConv2d(nn.Module):
    method __init__ (line 25) | def __init__(self, in_channels, out_channels, kernel_size, stride, pad...
    method __call__ (line 31) | def __call__(self, x):
  class Encoder (line 36) | class Encoder(nn.Module):
    method __init__ (line 43) | def __init__(self, num_latent_dims, image_shape, max_num_filters):
    method __call__ (line 74) | def __call__(self, x):
  class Decoder (line 94) | class Decoder(nn.Module):
    method __init__ (line 97) | def __init__(self, num_latent_dims, image_shape, max_num_filters):
    method __call__ (line 133) | def __call__(self, z):
  class CVAE (line 149) | class CVAE(nn.Module):
    method __init__ (line 155) | def __init__(self, num_latent_dims, input_shape, max_num_filters):
    method __call__ (line 161) | def __call__(self, x):
    method encode (line 168) | def encode(self, x):
    method decode (line 171) | def decode(self, z):

FILE: encodec/benchmarks/bench_mx.py
  function fun (line 17) | def fun():

FILE: encodec/benchmarks/bench_pt.py
  function fun (line 18) | def fun():

FILE: encodec/convert.py
  function fetch_from_hub (line 17) | def fetch_from_hub(hf_repo: str) -> Path:
  function upload_to_hub (line 27) | def upload_to_hub(path: str, upload_repo: str, hf_path: str):
  function save_weights (line 76) | def save_weights(save_path: Union[str, Path], weights: Dict[str, Any]) -...
  function save_config (line 98) | def save_config(
  function convert (line 121) | def convert(

FILE: encodec/encodec.py
  function lstm_custom (line 50) | def lstm_custom(x, h_in, cell, time_step):
  class LSTM (line 62) | class LSTM(nn.Module):
    method __init__ (line 63) | def __init__(
    method __call__ (line 76) | def __call__(self, x, hidden=None, cell=None):
  class EncodecConv1d (line 97) | class EncodecConv1d(nn.Module):
    method __init__ (line 100) | def __init__(
    method _get_extra_padding_for_conv1d (line 127) | def _get_extra_padding_for_conv1d(
    method _pad1d (line 137) | def _pad1d(
    method __call__ (line 154) | def __call__(self, hidden_states):
  class EncodecConvTranspose1d (line 180) | class EncodecConvTranspose1d(nn.Module):
    method __init__ (line 183) | def __init__(
    method __call__ (line 200) | def __call__(self, hidden_states):
  class EncodecLSTM (line 218) | class EncodecLSTM(nn.Module):
    method __init__ (line 219) | def __init__(self, config, dimension):
    method __call__ (line 223) | def __call__(self, hidden_states):
  class EncodecResnetBlock (line 230) | class EncodecResnetBlock(nn.Module):
    method __init__ (line 235) | def __init__(self, config, dim: int, dilations: List[int]):
    method __call__ (line 257) | def __call__(self, hidden_states):
  class EncodecEncoder (line 265) | class EncodecEncoder(nn.Module):
    method __init__ (line 268) | def __init__(self, config):
    method __call__ (line 310) | def __call__(self, hidden_states):
  class EncodecDecoder (line 316) | class EncodecDecoder(nn.Module):
    method __init__ (line 319) | def __init__(self, config):
    method __call__ (line 364) | def __call__(self, hidden_states):
  class EncodecEuclideanCodebook (line 370) | class EncodecEuclideanCodebook(nn.Module):
    method __init__ (line 373) | def __init__(self, config):
    method quantize (line 377) | def quantize(self, hidden_states):
    method encode (line 388) | def encode(self, hidden_states):
    method decode (line 395) | def decode(self, embed_ind):
  class EncodecVectorQuantization (line 399) | class EncodecVectorQuantization(nn.Module):
    method __init__ (line 404) | def __init__(self, config):
    method encode (line 408) | def encode(self, hidden_states):
    method decode (line 411) | def decode(self, embed_ind):
  class EncodecResidualVectorQuantizer (line 415) | class EncodecResidualVectorQuantizer(nn.Module):
    method __init__ (line 418) | def __init__(self, config):
    method get_num_quantizers_for_bandwidth (line 431) | def get_num_quantizers_for_bandwidth(
    method encode (line 441) | def encode(
    method decode (line 460) | def decode(self, codes: mx.array) -> mx.array:
  class EncodecModel (line 473) | class EncodecModel(nn.Module):
    method __init__ (line 474) | def __init__(self, config):
    method _encode_frame (line 481) | def _encode_frame(
    method encode (line 510) | def encode(
    method _linear_overlap_add (line 582) | def _linear_overlap_add(frames: List[mx.array], stride: int):
    method _decode_frame (line 606) | def _decode_frame(
    method channels (line 616) | def channels(self):
    method sampling_rate (line 620) | def sampling_rate(self):
    method chunk_length (line 624) | def chunk_length(self):
    method chunk_stride (line 631) | def chunk_stride(self):
    method decode (line 637) | def decode(
    method from_pretrained (line 677) | def from_pretrained(cls, path_or_repo: str):
  function preprocess_audio (line 704) | def preprocess_audio(

FILE: encodec/example.py
  function encode (line 22) | def encode(feats, mask):
  function decode (line 28) | def decode(codes, scales, mask):

FILE: encodec/test.py
  function compare_processors (line 12) | def compare_processors():
  function compare_models (line 33) | def compare_models():

FILE: encodec/utils.py
  function save_audio (line 7) | def save_audio(file: str, audio: mx.array, sampling_rate: int):
  function load_audio (line 17) | def load_audio(file: str, sampling_rate: int, channels: int):

FILE: flux/dreambooth.py
  function generate_progress_images (line 19) | def generate_progress_images(iteration, flux, args):
  function save_adapters (line 46) | def save_adapters(adapter_name, flux, args):
  function setup_arg_parser (line 62) | def setup_arg_parser():
  function single_step (line 194) | def single_step(x, t5_feat, clip_feat, guidance):
  function compute_loss_and_grads (line 204) | def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
  function compute_loss_and_accumulate_grads (line 210) | def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, p...
  function grad_accumulate_and_step (line 218) | def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
  function step (line 235) | def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):

FILE: flux/flux/autoencoder.py
  class AutoEncoderParams (line 12) | class AutoEncoderParams:
  class AttnBlock (line 24) | class AttnBlock(nn.Module):
    method __init__ (line 25) | def __init__(self, in_channels: int):
    method __call__ (line 41) | def __call__(self, x: mx.array) -> mx.array:
  class ResnetBlock (line 55) | class ResnetBlock(nn.Module):
    method __init__ (line 56) | def __init__(self, in_channels: int, out_channels: int):
    method __call__ (line 85) | def __call__(self, x):
  class Downsample (line 101) | class Downsample(nn.Module):
    method __init__ (line 102) | def __init__(self, in_channels: int):
    method __call__ (line 108) | def __call__(self, x: mx.array):
  class Upsample (line 114) | class Upsample(nn.Module):
    method __init__ (line 115) | def __init__(self, in_channels: int):
    method __call__ (line 121) | def __call__(self, x: mx.array):
  class Encoder (line 127) | class Encoder(nn.Module):
    method __init__ (line 128) | def __init__(
    method __call__ (line 183) | def __call__(self, x: mx.array):
  class Decoder (line 212) | class Decoder(nn.Module):
    method __init__ (line 213) | def __init__(
    method __call__ (line 271) | def __call__(self, z: mx.array):
  class DiagonalGaussian (line 300) | class DiagonalGaussian(nn.Module):
    method __call__ (line 301) | def __call__(self, z: mx.array):
  class AutoEncoder (line 311) | class AutoEncoder(nn.Module):
    method __init__ (line 312) | def __init__(self, params: AutoEncoderParams):
    method sanitize (line 336) | def sanitize(self, weights):
    method encode (line 347) | def encode(self, x: mx.array):
    method decode (line 352) | def decode(self, z: mx.array):
    method __call__ (line 356) | def __call__(self, x: mx.array):

FILE: flux/flux/clip.py
  class CLIPTextModelConfig (line 13) | class CLIPTextModelConfig:
    method from_dict (line 22) | def from_dict(cls, config):
  class CLIPOutput (line 34) | class CLIPOutput:
  class CLIPEncoderLayer (line 46) | class CLIPEncoderLayer(nn.Module):
    method __init__ (line 49) | def __init__(self, model_dims: int, num_heads: int, activation: str):
    method __call__ (line 62) | def __call__(self, x, attn_mask=None):
  class CLIPTextModel (line 76) | class CLIPTextModel(nn.Module):
    method __init__ (line 79) | def __init__(self, config: CLIPTextModelConfig):
    method _get_mask (line 90) | def _get_mask(self, N, dtype):
    method sanitize (line 96) | def sanitize(self, weights):
    method __call__ (line 127) | def __call__(self, x):

FILE: flux/flux/datasets.py
  class Dataset (line 7) | class Dataset:
    method __getitem__ (line 8) | def __getitem__(self, index: int):
    method __len__ (line 11) | def __len__(self):
  class LocalDataset (line 15) | class LocalDataset(Dataset):
    method __init__ (line 18) | def __init__(self, dataset: str, data_file):
    method __len__ (line 23) | def __len__(self):
    method __getitem__ (line 26) | def __getitem__(self, index: int):
  class LegacyDataset (line 32) | class LegacyDataset(LocalDataset):
    method __init__ (line 35) | def __init__(self, dataset: str):
  class HuggingFaceDataset (line 41) | class HuggingFaceDataset(Dataset):
    method __init__ (line 43) | def __init__(self, dataset: str):
    method __len__ (line 48) | def __len__(self):
    method __getitem__ (line 51) | def __getitem__(self, index: int):
  function load_dataset (line 56) | def load_dataset(dataset: str):

FILE: flux/flux/flux.py
  class FluxPipeline (line 22) | class FluxPipeline:
    method __init__ (line 23) | def __init__(self, name: str, t5_padding: bool = True):
    method ensure_models_are_loaded (line 36) | def ensure_models_are_loaded(self):
    method reload_text_encoders (line 44) | def reload_text_encoders(self):
    method tokenize (line 48) | def tokenize(self, text):
    method _prepare_latent_images (line 53) | def _prepare_latent_images(self, x):
    method _prepare_conditioning (line 73) | def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
    method _denoising_loop (line 87) | def _denoising_loop(
    method generate_latents (line 128) | def generate_latents(
    method decode (line 157) | def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
    method generate_images (line 164) | def generate_images(
    method training_loss (line 195) | def training_loss(
    method linear_to_lora_layers (line 229) | def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
    method fuse_lora_layers (line 241) | def fuse_lora_layers(self):

FILE: flux/flux/layers.py
  function _rope (line 12) | def _rope(pos: mx.array, dim: int, theta: float):
  function _ab_plus_cd (line 25) | def _ab_plus_cd(a, b, c, d):
  function _apply_rope (line 29) | def _apply_rope(x, pe):
  function _attention (line 36) | def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
  function timestep_embedding (line 46) | def timestep_embedding(
  class EmbedND (line 60) | class EmbedND(nn.Module):
    method __init__ (line 61) | def __init__(self, dim: int, theta: int, axes_dim: List[int]):
    method __call__ (line 68) | def __call__(self, ids: mx.array):
  class MLPEmbedder (line 78) | class MLPEmbedder(nn.Module):
    method __init__ (line 79) | def __init__(self, in_dim: int, hidden_dim: int):
    method __call__ (line 84) | def __call__(self, x: mx.array) -> mx.array:
  class QKNorm (line 88) | class QKNorm(nn.Module):
    method __init__ (line 89) | def __init__(self, dim: int):
    method __call__ (line 94) | def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.arr...
  class SelfAttention (line 98) | class SelfAttention(nn.Module):
    method __init__ (line 99) | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
    method __call__ (line 108) | def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
  class ModulationOut (line 123) | class ModulationOut:
  class Modulation (line 129) | class Modulation(nn.Module):
    method __init__ (line 130) | def __init__(self, dim: int, double: bool):
    method __call__ (line 136) | def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[Modul...
  class DoubleStreamBlock (line 146) | class DoubleStreamBlock(nn.Module):
    method __init__ (line 147) | def __init__(
    method __call__ (line 183) | def __call__(
  class SingleStreamBlock (line 253) | class SingleStreamBlock(nn.Module):
    method __init__ (line 254) | def __init__(
    method __call__ (line 281) | def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
  class LastLayer (line 306) | class LastLayer(nn.Module):
    method __init__ (line 307) | def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
    method __call__ (line 317) | def __call__(self, x: mx.array, vec: mx.array):

FILE: flux/flux/lora.py
  class LoRALinear (line 9) | class LoRALinear(nn.Module):
    method from_base (line 11) | def from_base(
    method fuse (line 28) | def fuse(self):
    method __init__ (line 45) | def __init__(
    method __call__ (line 73) | def __call__(self, x):

FILE: flux/flux/model.py
  class FluxParams (line 21) | class FluxParams:
  class Flux (line 36) | class Flux(nn.Module):
    method __init__ (line 37) | def __init__(self, params: FluxParams):
    method sanitize (line 86) | def sanitize(self, weights):
    method shard (line 100) | def shard(self, group: Optional[mx.distributed.Group] = None):
    method __call__ (line 141) | def __call__(

FILE: flux/flux/sampler.py
  class FluxSampler (line 9) | class FluxSampler:
    method __init__ (line 10) | def __init__(self, name: str, base_shift: float = 0.5, max_shift: floa...
    method _time_shift (line 15) | def _time_shift(self, x, t):
    method timesteps (line 23) | def timesteps(
    method random_timesteps (line 33) | def random_timesteps(self, B, L, dtype=mx.float32, key=None):
    method sample_prior (line 44) | def sample_prior(self, shape, dtype=mx.float32, key=None):
    method add_noise (line 47) | def add_noise(self, x, t, noise=None, key=None):
    method step (line 56) | def step(self, pred, x_t, t, t_prev):

FILE: flux/flux/t5.py
  class T5Config (line 35) | class T5Config:
    method from_dict (line 51) | def from_dict(cls, config):
  class RelativePositionBias (line 70) | class RelativePositionBias(nn.Module):
    method __init__ (line 71) | def __init__(self, config: T5Config, bidirectional: bool):
    method _relative_position_bucket (line 79) | def _relative_position_bucket(rpos, bidirectional, num_buckets, max_di...
    method __call__ (line 98) | def __call__(self, query_length: int, key_length: int, offset: int = 0):
  class MultiHeadAttention (line 119) | class MultiHeadAttention(nn.Module):
    method __init__ (line 120) | def __init__(self, config: T5Config):
    method __call__ (line 129) | def __call__(
  class DenseActivation (line 161) | class DenseActivation(nn.Module):
    method __init__ (line 162) | def __init__(self, config: T5Config):
    method __call__ (line 182) | def __call__(self, x):
  class TransformerEncoderLayer (line 192) | class TransformerEncoderLayer(nn.Module):
    method __init__ (line 193) | def __init__(self, config: T5Config):
    method __call__ (line 200) | def __call__(self, x, mask):
  class TransformerEncoder (line 210) | class TransformerEncoder(nn.Module):
    method __init__ (line 211) | def __init__(self, config: T5Config):
    method __call__ (line 219) | def __call__(self, x: mx.array):
  class T5Encoder (line 227) | class T5Encoder(nn.Module):
    method __init__ (line 228) | def __init__(self, config: T5Config):
    method sanitize (line 232) | def sanitize(self, weights):
    method __call__ (line 243) | def __call__(self, inputs: mx.array):

FILE: flux/flux/tokenizers.py
  class CLIPTokenizer (line 8) | class CLIPTokenizer:
    method __init__ (line 11) | def __init__(self, bpe_ranks, vocab, max_length=77):
    method bos (line 23) | def bos(self):
    method bos_token (line 27) | def bos_token(self):
    method eos (line 31) | def eos(self):
    method eos_token (line 35) | def eos_token(self):
    method bpe (line 38) | def bpe(self, text):
    method tokenize (line 83) | def tokenize(self, text, prepend_bos=True, append_eos=True):
    method encode (line 110) | def encode(self, text):
  class T5Tokenizer (line 122) | class T5Tokenizer:
    method __init__ (line 123) | def __init__(self, model_file, max_length=512):
    method pad (line 128) | def pad(self):
    method pad_token (line 135) | def pad_token(self):
    method bos (line 139) | def bos(self):
    method bos_token (line 146) | def bos_token(self):
    method eos (line 150) | def eos(self):
    method eos_token (line 157) | def eos_token(self):
    method tokenize (line 160) | def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
    method encode (line 175) | def encode(self, text, pad=True):

FILE: flux/flux/trainer.py
  class Trainer (line 10) | class Trainer:
    method __init__ (line 12) | def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
    method _random_crop_resize (line 20) | def _random_crop_resize(self, img):
    method _encode_image (line 62) | def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentati...
    method _encode_prompt (line 71) | def _encode_prompt(self, prompt):
    method encode_dataset (line 79) | def encode_dataset(self):
    method iterate (line 86) | def iterate(self, batch_size):

FILE: flux/flux/utils.py
  class ModelSpec (line 20) | class ModelSpec:
  function load_flow_model (line 98) | def load_flow_model(name: str, hf_download: bool = True):
  function load_ae (line 123) | def load_ae(name: str, hf_download: bool = True):
  function load_clip (line 148) | def load_clip(name: str):
  function load_t5 (line 166) | def load_t5(name: str):
  function load_clip_tokenizer (line 194) | def load_clip_tokenizer(name: str):
  function load_t5_tokenizer (line 208) | def load_t5_tokenizer(name: str, pad: bool = True):
  function save_config (line 213) | def save_config(

FILE: flux/generate_interactive.py
  function print_zero (line 12) | def print_zero(group, *args, **kwargs):
  function quantization_predicate (line 18) | def quantization_predicate(name, m):
  function to_latent_size (line 22) | def to_latent_size(image_size):
  function print_help (line 59) | def print_help():

FILE: flux/txt2image.py
  function to_latent_size (line 14) | def to_latent_size(image_size):
  function quantization_predicate (line 28) | def quantization_predicate(name, m):
  function load_adapter (line 32) | def load_adapter(flux, adapter_file, fuse=False):

FILE: gcn/datasets.py
  function download_cora (line 16) | def download_cora():
  function train_val_test_mask (line 41) | def train_val_test_mask():
  function enumerate_labels (line 51) | def enumerate_labels(labels):
  function normalize_adjacency (line 60) | def normalize_adjacency(adj):
  function load_data (line 77) | def load_data(config):

FILE: gcn/gcn.py
  class GCNLayer (line 4) | class GCNLayer(nn.Module):
    method __init__ (line 5) | def __init__(self, in_features, out_features, bias=True):
    method __call__ (line 9) | def __call__(self, x, adj):
  class GCN (line 14) | class GCN(nn.Module):
    method __init__ (line 15) | def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bi...
    method __call__ (line 25) | def __call__(self, x, adj):

FILE: gcn/main.py
  function loss_fn (line 14) | def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
  function eval_fn (line 25) | def eval_fn(x, y):
  function forward_fn (line 29) | def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
  function main (line 35) | def main(args):

FILE: llava/generate.py
  function parse_arguments (line 15) | def parse_arguments():
  function load_image (line 55) | def load_image(image_source):
  function prepare_inputs (line 79) | def prepare_inputs(processor, image, prompt):
  function load_model (line 88) | def load_model(model_path, tokenizer_config={}):
  function sample (line 94) | def sample(logits, temperature=0.0):
  function generate_text (line 101) | def generate_text(input_ids, pixel_values, model, processor, max_tokens,...
  function main (line 119) | def main():

FILE: llava/language.py
  class TextConfig (line 12) | class TextConfig:
    method from_dict (line 26) | def from_dict(cls, params):
    method __post_init__ (line 35) | def __post_init__(self):
  class Attention (line 48) | class Attention(nn.Module):
    method __init__ (line 49) | def __init__(self, config: TextConfig):
    method __call__ (line 79) | def __call__(
  class MLP (line 111) | class MLP(nn.Module):
    method __init__ (line 112) | def __init__(self, dim, hidden_dim):
    method __call__ (line 118) | def __call__(self, x) -> mx.array:
  class TransformerBlock (line 122) | class TransformerBlock(nn.Module):
    method __init__ (line 123) | def __init__(self, config: TextConfig):
    method __call__ (line 135) | def __call__(
  class Llama (line 148) | class Llama(nn.Module):
    method __init__ (line 149) | def __init__(self, config: TextConfig):
    method __call__ (line 161) | def __call__(
  class LanguageModel (line 187) | class LanguageModel(nn.Module):
    method __init__ (line 188) | def __init__(self, config: TextConfig):
    method __call__ (line 198) | def __call__(
    method sanitize (line 208) | def sanitize(weights):

FILE: llava/llava.py
  class LlaVAConfig (line 19) | class LlaVAConfig:
    method from_dict (line 29) | def from_dict(cls, params):
  class LlavaMultiModalProjector (line 39) | class LlavaMultiModalProjector(nn.Module):
    method __init__ (line 40) | def __init__(self, config: LlaVAConfig):
    method __call__ (line 50) | def __call__(self, x: mx.array) -> mx.array:
  class LlavaModel (line 57) | class LlavaModel(nn.Module):
    method __init__ (line 58) | def __init__(self, config: LlaVAConfig):
    method get_input_embeddings (line 66) | def get_input_embeddings(
    method _merge_input_ids_with_image_features (line 103) | def _merge_input_ids_with_image_features(
    method __call__ (line 123) | def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=...
    method from_pretrained (line 131) | def from_pretrained(path_or_hf_repo: str):

FILE: llava/test.py
  function load_mlx_models (line 18) | def load_mlx_models(path):
  function load_hf_models (line 24) | def load_hf_models(path):
  class TestVisionTower (line 30) | class TestVisionTower(unittest.TestCase):
    method setUpClass (line 32) | def setUpClass(cls):
    method test_image_features (line 37) | def test_image_features(self):
  class TestLlava (line 77) | class TestLlava(unittest.TestCase):
    method setUpClass (line 79) | def setUpClass(cls):
    method test_merge_input_ids_with_image_features (line 84) | def test_merge_input_ids_with_image_features(self):
    method test_generated_tokens (line 139) | def test_generated_tokens(self):

FILE: llava/vision.py
  class VisionConfig (line 13) | class VisionConfig:
    method from_dict (line 27) | def from_dict(cls, params):
  class Attention (line 37) | class Attention(nn.Module):
    method __init__ (line 38) | def __init__(
    method __call__ (line 69) | def __call__(self, queries, keys, values, mask=None):
  class MLP (line 91) | class MLP(nn.Module):
    method __init__ (line 92) | def __init__(self, config: VisionConfig):
    method __call__ (line 98) | def __call__(self, x: mx.array) -> mx.array:
  class EncoderLayer (line 104) | class EncoderLayer(nn.Module):
    method __init__ (line 105) | def __init__(self, config: VisionConfig):
    method __call__ (line 115) | def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx...
  class Encoder (line 124) | class Encoder(nn.Module):
    method __init__ (line 125) | def __init__(self, config: VisionConfig):
  class VisionEmbeddings (line 130) | class VisionEmbeddings(nn.Module):
    method __init__ (line 131) | def __init__(self, config: VisionConfig):
    method __call__ (line 152) | def __call__(self, x: mx.array) -> mx.array:
  class ClipVisionModel (line 165) | class ClipVisionModel(nn.Module):
    method __init__ (line 166) | def __init__(self, config: VisionConfig):
    method __call__ (line 173) | def __call__(
  class VisionModel (line 192) | class VisionModel(nn.Module):
    method __init__ (line 193) | def __init__(self, config: VisionConfig):
    method __call__ (line 202) | def __call__(
    method sanitize (line 208) | def sanitize(weights):

FILE: llms/gguf_llm/generate.py
  function generate (line 10) | def generate(

FILE: llms/gguf_llm/models.py
  class ModelArgs (line 14) | class ModelArgs:
    method __post_init__ (line 28) | def __post_init__(self):
    method from_dict (line 41) | def from_dict(cls, params):
  class Attention (line 51) | class Attention(nn.Module):
    method __init__ (line 52) | def __init__(self, args: ModelArgs):
    method __call__ (line 80) | def __call__(
  class MLP (line 112) | class MLP(nn.Module):
    method __init__ (line 113) | def __init__(self, dim, hidden_dim):
    method __call__ (line 119) | def __call__(self, x) -> mx.array:
  class TransformerBlock (line 123) | class TransformerBlock(nn.Module):
    method __init__ (line 124) | def __init__(self, args: ModelArgs):
    method __call__ (line 136) | def __call__(
  class LlamaModel (line 149) | class LlamaModel(nn.Module):
    method __init__ (line 150) | def __init__(self, args: ModelArgs):
    method __call__ (line 172) | def __call__(
  class Model (line 193) | class Model(nn.Module):
    method __init__ (line 194) | def __init__(self, args: ModelArgs):
    method __call__ (line 199) | def __call__(
  function get_config (line 208) | def get_config(metadata: dict):
  class GGUFTokenizer (line 225) | class GGUFTokenizer:
    method __init__ (line 226) | def __init__(self, metadata):
    method encode (line 229) | def encode(self, s: str) -> mx.array:
    method eos_token_id (line 233) | def eos_token_id(self):
    method decode (line 236) | def decode(self, toks: List[int]) -> str:
  function translate_weight_names (line 240) | def translate_weight_names(name):
  function load (line 257) | def load(gguf_file: str, repo: Optional[str] = None):
  function generate (line 313) | def generate(prompt: mx.array, model: Model, temp: float = 0.0):

FILE: llms/gguf_llm/utils.py
  function spm_tokenizer (line 5) | def spm_tokenizer(metadata):

FILE: llms/llama/convert.py
  function torch_to_mx (line 19) | def torch_to_mx(a: torch.Tensor, *, dtype: str) -> mx.array:
  function llama (line 25) | def llama(model_path, *, dtype: str):
  function tiny_llama (line 67) | def tiny_llama(model_path, *, dtype: str):
  function quantize (line 128) | def quantize(weights, config, args):
  function make_shards (line 150) | def make_shards(weights: dict, max_file_size_gibibyte: int = 15):

FILE: llms/llama/llama.py
  class ModelArgs (line 18) | class ModelArgs:
  class Attention (line 31) | class Attention(nn.Module):
    method __init__ (line 32) | def __init__(self, args: ModelArgs):
    method __call__ (line 51) | def __call__(
  class FeedForward (line 90) | class FeedForward(nn.Module):
    method __init__ (line 91) | def __init__(self, args: ModelArgs):
    method __call__ (line 98) | def __call__(self, x) -> mx.array:
  class TransformerBlock (line 102) | class TransformerBlock(nn.Module):
    method __init__ (line 103) | def __init__(self, args: ModelArgs):
    method __call__ (line 113) | def __call__(
  class Llama (line 126) | class Llama(nn.Module):
    method __init__ (line 127) | def __init__(self, args: ModelArgs):
    method __call__ (line 136) | def __call__(self, x):
    method generate (line 146) | def generate(self, x, temp=1.0):
  function tic (line 198) | def tic():
  function toc (line 202) | def toc(msg, start):
  function generate (line 207) | def generate(args):
  function few_shot_generate (line 243) | def few_shot_generate(args):
  function sanitize_config (line 299) | def sanitize_config(config, weights):
  function load_model (line 318) | def load_model(model_path):

FILE: llms/mistral/convert.py
  function quantize (line 17) | def quantize(weights, config, args):

FILE: llms/mistral/mistral.py
  class ModelArgs (line 17) | class ModelArgs:
  class Attention (line 29) | class Attention(nn.Module):
    method __init__ (line 30) | def __init__(self, args: ModelArgs):
    method __call__ (line 47) | def __call__(
  class FeedForward (line 79) | class FeedForward(nn.Module):
    method __init__ (line 80) | def __init__(self, args: ModelArgs):
    method __call__ (line 87) | def __call__(self, x) -> mx.array:
  class TransformerBlock (line 91) | class TransformerBlock(nn.Module):
    method __init__ (line 92) | def __init__(self, args: ModelArgs):
    method __call__ (line 102) | def __call__(
  class Mistral (line 115) | class Mistral(nn.Module):
    method __init__ (line 116) | def __init__(self, args: ModelArgs):
    method __call__ (line 127) | def __call__(
  class Tokenizer (line 148) | class Tokenizer:
    method __init__ (line 149) | def __init__(self, model_path: str):
    method eos_id (line 156) | def eos_id(self) -> int:
    method pad_id (line 160) | def pad_id(self) -> int:
    method encode (line 163) | def encode(self, s: str) -> List[int]:
    method decode (line 166) | def decode(self, t: List[int]) -> str:
  function load_model (line 173) | def load_model(folder: str):
  function generate (line 192) | def generate(prompt: mx.array, model: Mistral, temp: Optional[float] = 0...

FILE: llms/mistral/test.py
  class TestMistral (line 10) | class TestMistral(unittest.TestCase):
    method test_model (line 11) | def test_model(self):
    method test_generate (line 37) | def test_generate(self):
    method benchmark (line 77) | def benchmark(self):

FILE: llms/mixtral/convert.py
  function convert (line 18) | def convert(tf, config):
  function quantize (line 48) | def quantize(weights, config, args):

FILE: llms/mixtral/mixtral.py
  class ModelArgs (line 17) | class ModelArgs:
  class Attention (line 29) | class Attention(nn.Module):
    method __init__ (line 30) | def __init__(self, args: ModelArgs):
    method __call__ (line 47) | def __call__(
  class FeedForward (line 79) | class FeedForward(nn.Module):
    method __init__ (line 80) | def __init__(self, args: ModelArgs):
    method __call__ (line 87) | def __call__(self, x) -> mx.array:
  class MOEFeedForward (line 91) | class MOEFeedForward(nn.Module):
    method __init__ (line 92) | def __init__(self, args: ModelArgs):
    method __call__ (line 99) | def __call__(self, x) -> mx.array:
  class MOETransformerBlock (line 120) | class MOETransformerBlock(nn.Module):
    method __init__ (line 121) | def __init__(self, args: ModelArgs):
    method __call__ (line 131) | def __call__(
  class Mixtral (line 144) | class Mixtral(nn.Module):
    method __init__ (line 145) | def __init__(self, args: ModelArgs):
    method __call__ (line 156) | def __call__(
  class Tokenizer (line 178) | class Tokenizer:
    method __init__ (line 179) | def __init__(self, model_path: str):
    method eos_id (line 186) | def eos_id(self) -> int:
    method pad_id (line 190) | def pad_id(self) -> int:
    method encode (line 193) | def encode(self, s: str) -> List[int]:
    method decode (line 196) | def decode(self, t: List[int]) -> str:
  function load_model (line 203) | def load_model(folder: str):
  function generate (line 224) | def generate(prompt: mx.array, model: Mixtral, temp: Optional[float] = 0...

FILE: llms/speculative_decoding/convert.py
  function replace_key (line 34) | def replace_key(key: str) -> str:
  function convert (line 46) | def convert(model_name, dtype):

FILE: llms/speculative_decoding/decoder.py
  class Tokenizer (line 8) | class Tokenizer:
    method __init__ (line 9) | def __init__(self, model_name: str):
    method eos_id (line 18) | def eos_id(self) -> int:
    method decoder_start_id (line 22) | def decoder_start_id(self) -> int:
    method encode (line 25) | def encode(self, s: str) -> mx.array:
    method decode (line 36) | def decode(self, t: List[int]) -> str:
  class SpeculativeDecoder (line 40) | class SpeculativeDecoder:
    method __init__ (line 41) | def __init__(
    method _generate (line 55) | def _generate(
    method generate (line 69) | def generate(
    method _get_num_accept (line 91) | def _get_num_accept(self, draft_tokens, draft_probs, model_logits):
    method speculative_decode (line 105) | def speculative_decode(

FILE: llms/speculative_decoding/main.py
  function load_model (line 11) | def load_model(model_name: str):
  function main (line 21) | def main(args):

FILE: llms/speculative_decoding/model.py
  function _relative_position_bucket (line 10) | def _relative_position_bucket(
  class RelativePositionBias (line 60) | class RelativePositionBias(nn.Module):
    method __init__ (line 61) | def __init__(self, config: T5Config, bidirectional: bool):
    method __call__ (line 70) | def __call__(self, query_length: int, key_length: int, offset: int = 0):
  class MultiHeadAttention (line 91) | class MultiHeadAttention(nn.Module):
    method __init__ (line 92) | def __init__(self, config: T5Config):
    method __call__ (line 101) | def __call__(
  class DenseActivation (line 135) | class DenseActivation(nn.Module):
    method __init__ (line 136) | def __init__(self, config: T5Config):
    method __call__ (line 156) | def __call__(self, x):
  class TransformerEncoderLayer (line 166) | class TransformerEncoderLayer(nn.Module):
    method __init__ (line 167) | def __init__(self, config: T5Config):
    method __call__ (line 174) | def __call__(self, x, mask):
  class TransformerEncoder (line 184) | class TransformerEncoder(nn.Module):
    method __init__ (line 185) | def __init__(self, config: T5Config):
    method __call__ (line 193) | def __call__(self, x: mx.array):
  class TransformerDecoderLayer (line 200) | class TransformerDecoderLayer(nn.Module):
    method __init__ (line 201) | def __init__(self, config: T5Config):
    method __call__ (line 210) | def __call__(
  function create_additive_causal_mask (line 233) | def create_additive_causal_mask(N: int, offset: int = 0):
  class TransformerDecoder (line 240) | class TransformerDecoder(nn.Module):
    method __init__ (line 241) | def __init__(self, config: T5Config):
    method __call__ (line 248) | def __call__(self, x, memory, cache=None):
  class OutputHead (line 273) | class OutputHead(nn.Module):
    method __init__ (line 274) | def __init__(self, config: T5Config):
    method __call__ (line 277) | def __call__(self, inputs):
  class Model (line 281) | class Model(nn.Module):
    method __init__ (line 282) | def __init__(self, config: T5Config):
    method encode (line 292) | def encode(self, inputs: mx.array):
    method truncate_cache (line 295) | def truncate_cache(self, num_to_truncate):
    method reset_cache (line 304) | def reset_cache(self):
    method decode (line 307) | def decode(
    method __call__ (line 321) | def __call__(

FILE: lora/convert.py
  function quantize (line 13) | def quantize(weights, config, args):

FILE: lora/data/wikisql.py
  function load (line 14) | def load():
  class WikiSQL (line 21) | class WikiSQL:
    method __init__ (line 22) | def __init__(self, dataset, save_dir="/tmp"):
    method _maybe_download (line 32) | def _maybe_download(self, data_dir):
    method _parse_tables (line 43) | def _parse_tables(self, tables):
    method _parse_queries (line 54) | def _parse_queries(self, queries):
    method query_to_text (line 68) | def query_to_text(self, query, table, columns, types):
    method __getitem__ (line 92) | def __getitem__(self, idx):
    method __len__ (line 95) | def __len__(self):

FILE: lora/lora.py
  function build_parser (line 23) | def build_parser():
  class Dataset (line 131) | class Dataset:
    method __init__ (line 136) | def __init__(self, path: Path, key: str = "text"):
    method __getitem__ (line 144) | def __getitem__(self, idx: int):
    method __len__ (line 147) | def __len__(self):
  function load (line 151) | def load(args):
  function loss (line 178) | def loss(model, inputs, targets, lengths):
  function iterate_batches (line 193) | def iterate_batches(dset, tokenizer, batch_size, train=False):
  function evaluate (line 225) | def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
  function train (line 243) | def train(model, train_set, val_set, optimizer, loss, tokenizer, args):
  function generate (line 303) | def generate(model, prompt, tokenizer, args):

FILE: lora/models.py
  class ModelArgs (line 13) | class ModelArgs:
    method __post_init__ (line 26) | def __post_init__(self):
    method from_dict (line 39) | def from_dict(cls, params):
  class LoRALinear (line 49) | class LoRALinear(nn.Module):
    method from_linear (line 51) | def from_linear(linear: nn.Linear, rank: int = 8):
    method to_linear (line 61) | def to_linear(self):
    method __init__ (line 97) | def __init__(
    method __call__ (line 122) | def __call__(self, x):
  class Attention (line 131) | class Attention(nn.Module):
    method __init__ (line 132) | def __init__(self, args: ModelArgs):
    method __call__ (line 160) | def __call__(
  class MLP (line 192) | class MLP(nn.Module):
    method __init__ (line 193) | def __init__(self, dim, hidden_dim):
    method __call__ (line 199) | def __call__(self, x) -> mx.array:
  class TransformerBlock (line 203) | class TransformerBlock(nn.Module):
    method __init__ (line 204) | def __init__(self, args: ModelArgs):
    method __call__ (line 216) | def __call__(
  class LlamaModel (line 229) | class LlamaModel(nn.Module):
    method __init__ (line 230) | def __init__(self, args: ModelArgs):
    method __call__ (line 242) | def __call__(
  class Model (line 263) | class Model(nn.Module):
    method __init__ (line 264) | def __init__(self, args: ModelArgs):
    method __call__ (line 269) | def __call__(

FILE: lora/utils.py
  function fetch_from_hub (line 16) | def fetch_from_hub(hf_path: str):
  function upload_to_hub (line 36) | def upload_to_hub(path: str, name: str, hf_path: str):
  function make_shards (line 72) | def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
  function save_model (line 87) | def save_model(save_dir: Union[str, Path], weights, tokenizer, config):
  function load (line 129) | def load(path_or_hf_repo: str, tokenizer_config={}):
  function generate (line 175) | def generate(

FILE: mnist/main.py
  class MLP (line 15) | class MLP(nn.Module):
    method __init__ (line 18) | def __init__(
    method __call__ (line 28) | def __call__(self, x):
  function loss_fn (line 34) | def loss_fn(model, X, y):
  function batch_iterate (line 38) | def batch_iterate(batch_size, X, y):
  function main (line 45) | def main(args):

FILE: mnist/mnist.py
  function mnist (line 11) | def mnist(
  function fashion_mnist (line 70) | def fashion_mnist(save_dir="/tmp"):

FILE: musicgen/generate.py
  function main (line 10) | def main(text: str, output_path: str, model_name: str, max_steps: int):

FILE: musicgen/musicgen.py
  class TextConditioner (line 17) | class TextConditioner(nn.Module):
    method __init__ (line 18) | def __init__(self, t5_name, input_dim, output_dim):
    method __call__ (line 23) | def __call__(self, text):
  class KVCache (line 29) | class KVCache:
    method __init__ (line 30) | def __init__(self, head_dim, n_kv_heads):
    method update_and_fetch (line 43) | def update_and_fetch(self, keys, values):
    method state (line 67) | def state(self):
  class MultiHeadAttention (line 71) | class MultiHeadAttention(nn.Module):
    method __init__ (line 72) | def __init__(self, dim, n_heads):
    method __call__ (line 86) | def __call__(
  class TransformerBlock (line 118) | class TransformerBlock(nn.Module):
    method __init__ (line 119) | def __init__(self, config):
    method __call__ (line 132) | def __call__(
  function top_k_sampling (line 149) | def top_k_sampling(
  function create_sin_embedding (line 186) | def create_sin_embedding(positions: mx.array, dim: int, max_period: floa...
  class MusicGen (line 194) | class MusicGen(nn.Module):
    method __init__ (line 195) | def __init__(self, config):
    method __call__ (line 226) | def __call__(
    method generate (line 249) | def generate(
    method sanitize (line 307) | def sanitize(cls, weights):
    method from_pretrained (line 333) | def from_pretrained(cls, path_or_repo: str):

FILE: musicgen/utils.py
  function save_audio (line 7) | def save_audio(file: str, audio: mx.array, sampling_rate: int):

FILE: normalizing_flow/bijectors.py
  class Bijector (line 9) | class Bijector:
    method forward_and_log_det (line 10) | def forward_and_log_det(self, x: mx.array) -> Tuple[mx.array, mx.array]:
    method inverse_and_log_det (line 13) | def inverse_and_log_det(self, y: mx.array) -> Tuple[mx.array, mx.array]:
  class AffineBijector (line 17) | class AffineBijector(Bijector):
    method __init__ (line 18) | def __init__(self, shift_and_log_scale: mx.array):
    method forward_and_log_det (line 21) | def forward_and_log_det(self, x: mx.array):
    method inverse_and_log_det (line 27) | def inverse_and_log_det(self, y: mx.array):
  class MaskedCoupling (line 34) | class MaskedCoupling(Bijector):
    method __init__ (line 35) | def __init__(self, mask: mx.array, conditioner: nn.Module, bijector: B...
    method apply_mask (line 41) | def apply_mask(self, x: mx.array, func: callable):
    method forward_and_log_det (line 50) | def forward_and_log_det(self, x: mx.array):
    method inverse_and_log_det (line 56) | def inverse_and_log_det(self, y: mx.array):

FILE: normalizing_flow/distributions.py
  class Normal (line 9) | class Normal:
    method __init__ (line 10) | def __init__(self, mu: mx.array, sigma: mx.array):
    method sample (line 15) | def sample(
    method log_prob (line 20) | def log_prob(self, x: mx.array):
    method sample_and_log_prob (line 27) | def sample_and_log_prob(

FILE: normalizing_flow/flows.py
  class MLP (line 11) | class MLP(nn.Module):
    method __init__ (line 12) | def __init__(self, n_layers: int, d_in: int, d_hidden: int, d_out: int):
    method __call__ (line 20) | def __call__(self, x):
  class RealNVP (line 26) | class RealNVP(nn.Module):
    method __init__ (line 27) | def __init__(self, n_transforms: int, d_params: int, d_hidden: int, n_...
    method log_prob (line 43) | def log_prob(self, x: mx.array):
    method sample (line 56) | def sample(
    method __call__ (line 74) | def __call__(self, x: mx.array):

FILE: normalizing_flow/main.py
  function get_moons_dataset (line 15) | def get_moons_dataset(n_samples=100_000, noise=0.06):
  function main (line 23) | def main(args):

FILE: segment_anything/convert.py
  function save_weights (line 11) | def save_weights(save_path: Union[str, Path], weights: Dict[str, mx.arra...
  function download (line 34) | def download(hf_repo):
  function convert (line 44) | def convert(model_path):

FILE: segment_anything/main.py
  function write_masks_to_folder (line 138) | def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
  function get_amg_kwargs (line 163) | def get_amg_kwargs(args):
  function main (line 181) | def main(args: argparse.Namespace) -> None:

FILE: segment_anything/segment_anything/automatic_mask_generator.py
  class SamAutomaticMaskGenerator (line 28) | class SamAutomaticMaskGenerator:
    method __init__ (line 29) | def __init__(
    method generate (line 129) | def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
    method _generate_masks (line 191) | def _generate_masks(self, image: np.ndarray) -> MaskData:
    method _process_crop (line 217) | def _process_crop(
    method _process_batch (line 258) | def _process_batch(
    method postprocess_small_regions (line 316) | def postprocess_small_regions(
  function box_area (line 368) | def box_area(boxes: mx.array) -> mx.array:
  function batched_iou (line 384) | def batched_iou(boxes_a: mx.array, boxes_b: mx.array) -> mx.array:
  function non_max_supression (line 406) | def non_max_supression(

FILE: segment_anything/segment_anything/common.py
  class MLPBlock (line 7) | class MLPBlock(nn.Module):
    method __init__ (line 8) | def __init__(
    method __call__ (line 19) | def __call__(self, x: mx.array) -> mx.array:
  class LayerNorm2d (line 23) | class LayerNorm2d(nn.Module):
    method __init__ (line 24) | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
    method __call__ (line 30) | def __call__(self, x: mx.array) -> mx.array:

FILE: segment_anything/segment_anything/image_encoder.py
  class ImageEncoderViT (line 9) | class ImageEncoderViT(nn.Module):
    method __init__ (line 10) | def __init__(
    method __call__ (line 83) | def __call__(self, x: mx.array) -> mx.array:
  class Neck (line 95) | class Neck(nn.Module):
    method __init__ (line 96) | def __init__(self, embed_dim, out_chans):
    method __call__ (line 114) | def __call__(self, x):
  class Block (line 118) | class Block(nn.Module):
    method __init__ (line 121) | def __init__(
    method __call__ (line 167) | def __call__(self, x: mx.array) -> mx.array:
  class Attention (line 186) | class Attention(nn.Module):
    method __init__ (line 189) | def __init__(
    method __call__ (line 225) | def __call__(self, x: mx.array) -> mx.array:
  function window_partition (line 257) | def window_partition(x: mx.array, window_size: int) -> Tuple[mx.array, T...
  function window_unpartition (line 281) | def window_unpartition(
  function get_rel_pos (line 311) | def get_rel_pos(q_size: int, k_size: int, rel_pos: mx.array) -> mx.array:
  function add_decomposed_rel_pos (line 347) | def add_decomposed_rel_pos(
  class PatchEmbed (line 393) | class PatchEmbed(nn.Module):
    method __init__ (line 398) | def __init__(
    method __call__ (line 420) | def __call__(self, x: mx.array) -> mx.array:

FILE: segment_anything/segment_anything/mask_decoder.py
  class MaskDecoder (line 10) | class MaskDecoder(nn.Module):
    method __init__ (line 11) | def __init__(
    method __call__ (line 75) | def __call__(
    method predict_masks (line 116) | def predict_masks(
  class MLP (line 175) | class MLP(nn.Module):
    method __init__ (line 176) | def __init__(
    method __call__ (line 191) | def __call__(self, x):
  class ConvTranspose2d (line 202) | class ConvTranspose2d(nn.Module):
    method __init__ (line 203) | def __init__(
    method _extra_repr (line 232) | def _extra_repr(self):
    method __call__ (line 240) | def __call__(self, x):

FILE: segment_anything/segment_anything/predictor.py
  class SamPredictor (line 10) | class SamPredictor:
    method __init__ (line 11) | def __init__(
    method set_image (line 27) | def set_image(
    method predict (line 59) | def predict(
    method get_image_embedding (line 148) | def get_image_embedding(self) -> mx.array:
    method reset_image (line 163) | def reset_image(self) -> None:

FILE: segment_anything/segment_anything/prompt_encoder.py
  class PromptEncoder (line 9) | class PromptEncoder(nn.Module):
    method __init__ (line 10) | def __init__(
    method _embed_points (line 50) | def _embed_points(
    method _embed_boxes (line 82) | def _embed_boxes(self, boxes: mx.array, pe_layer: nn.Module) -> mx.array:
    method _embed_masks (line 91) | def _embed_masks(self, masks: mx.array) -> mx.array:
    method _get_batch_size (line 96) | def _get_batch_size(
    method __call__ (line 114) | def __call__(
  class MaskEmbed (line 172) | class MaskEmbed(nn.Module):
    method __init__ (line 173) | def __init__(self, embed_dim, mask_in_chans, activation):
    method __call__ (line 184) | def __call__(self, x):
  class PositionEmbeddingRandom (line 190) | class PositionEmbeddingRandom(nn.Module):
    method __init__ (line 195) | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = N...
    method _pe_encoding (line 201) | def _pe_encoding(self, coords: mx.array) -> mx.array:
    method __call__ (line 210) | def __call__(self, size: Tuple[int, int]) -> mx.array:
    method forward_with_coords (line 222) | def forward_with_coords(

FILE: segment_anything/segment_anything/sam.py
  class Sam (line 15) | class Sam(nn.Module):
    method __init__ (line 19) | def __init__(
    method __call__ (line 49) | def __call__(
    method postprocess_masks (line 134) | def postprocess_masks(
    method preprocess (line 172) | def preprocess(self, x: mx.array) -> mx.array:
  function load (line 193) | def load(model_path):

FILE: segment_anything/segment_anything/transformer.py
  class TwoWayTransformer (line 10) | class TwoWayTransformer(nn.Module):
    method __init__ (line 11) | def __init__(
    method __call__ (line 56) | def __call__(
  class TwoWayAttentionBlock (line 102) | class TwoWayAttentionBlock(nn.Module):
    method __init__ (line 103) | def __init__(
    method __call__ (line 144) | def __call__(
  class Attention (line 178) | class Attention(nn.Module):
    method __init__ (line 184) | def __init__(
    method _separate_heads (line 203) | def _separate_heads(self, x: mx.array, num_heads: int) -> mx.array:
    method _recombine_heads (line 208) | def _recombine_heads(self, x: mx.array) -> mx.array:
    method __call__ (line 213) | def __call__(self, q: mx.array, k: mx.array, v: mx.array) -> mx.array:

FILE: segment_anything/segment_anything/utils/amg.py
  class MaskData (line 10) | class MaskData:
    method __init__ (line 16) | def __init__(self, **kwargs) -> None:
    method __setitem__ (line 23) | def __setitem__(self, key: str, item: Any) -> None:
    method __delitem__ (line 29) | def __delitem__(self, key: str) -> None:
    method __getitem__ (line 32) | def __getitem__(self, key: str) -> Any:
    method items (line 35) | def items(self) -> ItemsView[str, Any]:
    method filter (line 38) | def filter(self, keep: mx.array) -> None:
    method cat (line 53) | def cat(self, new_stats: "MaskData") -> None:
    method to_numpy (line 66) | def to_numpy(self) -> None:
  function is_box_near_crop_edge (line 72) | def is_box_near_crop_edge(
  function box_xyxy_to_xywh (line 85) | def box_xyxy_to_xywh(box_xyxy: mx.array) -> mx.array:
  function batch_iterator (line 92) | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None,...
  function mask_to_rle_mlx (line 101) | def mask_to_rle_mlx(tensor: mx.array) -> List[Dict[str, Any]]:
  function rle_to_mask (line 134) | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
  function area_from_rle (line 148) | def area_from_rle(rle: Dict[str, Any]) -> int:
  function calculate_stability_score (line 152) | def calculate_stability_score(
  function build_point_grid (line 181) | def build_point_grid(n_per_side: int) -> np.ndarray:
  function build_all_layer_point_grids (line 191) | def build_all_layer_point_grids(
  function generate_crop_boxes (line 202) | def generate_crop_boxes(
  function uncrop_boxes_xyxy (line 239) | def uncrop_boxes_xyxy(boxes: mx.array, crop_box: List[int]) -> mx.array:
  function uncrop_points (line 248) | def uncrop_points(points: mx.array, crop_box: List[int]) -> mx.array:
  function uncrop_masks (line 257) | def uncrop_masks(
  function remove_small_regions (line 269) | def remove_small_regions(
  function coco_encode_rle (line 296) | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
  function batched_mask_to_box (line 305) | def batched_mask_to_box(masks: mx.array) -> mx.array:

FILE: segment_anything/segment_anything/utils/transforms.py
  class ResizeLongestSide (line 10) | class ResizeLongestSide:
    method __init__ (line 17) | def __init__(self, target_length: int) -> None:
    method apply_image (line 20) | def apply_image(self, image: np.ndarray) -> np.ndarray:
    method apply_coords (line 33) | def apply_coords(
    method apply_boxes (line 46) | def apply_boxes(self, boxes: mx.array, original_size: Tuple[int, ...])...
    method get_preprocess_shape (line 55) | def get_preprocess_shape(

FILE: speechcommands/kwt.py
  class FeedForward (line 8) | class FeedForward(nn.Sequential):
    method __init__ (line 9) | def __init__(self, dim, hidden_dim, dropout=0.0):
  class Attention (line 19) | class Attention(nn.Module):
    method __init__ (line 20) | def __init__(self, dim, heads, dropout=0.0):
    method __call__ (line 27) | def __call__(self, x):
  class Block (line 39) | class Block(nn.Module):
    method __init__ (line 40) | def __init__(self, dim, heads, mlp_dim, dropout=0.0):
    method __call__ (line 47) | def __call__(self, x):
  class Transformer (line 53) | class Transformer(nn.Module):
    method __init__ (line 54) | def __init__(self, dim, depth, heads, mlp_dim, dropout=0.0):
    method __call__ (line 61) | def __call__(self, x):
  class KWT (line 67) | class KWT(nn.Module):
    method __init__ (line 105) | def __init__(
    method num_params (line 139) | def num_params(self):
    method __call__ (line 143) | def __call__(self, x):
  function parse_kwt_args (line 162) | def parse_kwt_args(**kwargs):
  function kwt1 (line 170) | def kwt1(**kwargs):
  function kwt2 (line 185) | def kwt2(**kwargs):
  function kwt3 (line 200) | def kwt3(**kwargs):

FILE: speechcommands/main.py
  function prepare_dataset (line 27) | def prepare_dataset(batch_size, split, root=None):
  function eval_fn (line 55) | def eval_fn(model, x, y):
  function train_epoch (line 59) | def train_epoch(model, train_iter, optimizer, epoch):
  function test_epoch (line 111) | def test_epoch(model, test_iter):
  function main (line 129) | def main(args):

FILE: stable_diffusion/stable_diffusion/__init__.py
  class StableDiffusion (line 19) | class StableDiffusion:
    method __init__ (line 20) | def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
    method ensure_models_are_loaded (line 29) | def ensure_models_are_loaded(self):
    method _tokenize (line 34) | def _tokenize(self, tokenizer, text: str, negative_text: Optional[str]...
    method _get_text_conditioning (line 46) | def _get_text_conditioning(
    method _denoising_step (line 67) | def _denoising_step(
    method _denoising_loop (line 84) | def _denoising_loop(
    method generate_latents (line 102) | def generate_latents(
    method generate_latents_from_image (line 131) | def generate_latents_from_image(
    method decode (line 166) | def decode(self, x_t):
  class StableDiffusionXL (line 172) | class StableDiffusionXL(StableDiffusion):
    method __init__ (line 173) | def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
    method ensure_models_are_loaded (line 193) | def ensure_models_are_loaded(self):
    method _get_text_conditioning (line 199) | def _get_text_conditioning(
    method generate_latents (line 231) | def generate_latents(
    method generate_latents_from_image (line 269) | def generate_latents_from_image(

FILE: stable_diffusion/stable_diffusion/clip.py
  class CLIPOutput (line 15) | class CLIPOutput:
  class CLIPEncoderLayer (line 27) | class CLIPEncoderLayer(nn.Module):
    method __init__ (line 30) | def __init__(self, model_dims: int, num_heads: int, activation: str):
    method __call__ (line 48) | def __call__(self, x, attn_mask=None):
  class CLIPTextModel (line 62) | class CLIPTextModel(nn.Module):
    method __init__ (line 65) | def __init__(self, config: CLIPTextModelConfig):
    method _get_mask (line 81) | def _get_mask(self, N, dtype):
    method __call__ (line 87) | def __call__(self, x):

FILE: stable_diffusion/stable_diffusion/config.py
  class AutoencoderConfig (line 8) | class AutoencoderConfig:
  class CLIPTextModelConfig (line 20) | class CLIPTextModelConfig:
  class UNetConfig (line 31) | class UNetConfig:
  class DiffusionConfig (line 61) | class DiffusionConfig:

FILE: stable_diffusion/stable_diffusion/model_io.py
  function map_unet_weights (line 49) | def map_unet_weights(key, value):
  function map_clip_text_encoder_weights (line 98) | def map_clip_text_encoder_weights(key, value):
  function map_vae_weights (line 126) | def map_vae_weights(key, value):
  function _flatten (line 167) | def _flatten(params):
  function _load_safetensor_weights (line 171) | def _load_safetensor_weights(mapper, model, weight_file, float16: bool =...
  function _check_key (line 178) | def _check_key(key: str, part: str):
  function load_unet (line 185) | def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
  function load_text_encoder (line 229) | def load_text_encoder(
  function load_autoencoder (line 267) | def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
  function load_diffusion_config (line 297) | def load_diffusion_config(key: str = _DEFAULT_MODEL):
  function load_tokenizer (line 313) | def load_tokenizer(

FILE: stable_diffusion/stable_diffusion/sampler.py
  function _linspace (line 8) | def _linspace(a, b, num):
  function _interp (line 13) | def _interp(y, x_new):
  class SimpleEulerSampler (line 26) | class SimpleEulerSampler:
    method __init__ (line 32) | def __init__(self, config: DiffusionConfig):
    method max_time (line 53) | def max_time(self):
    method sample_prior (line 56) | def sample_prior(self, shape, dtype=mx.float32, key=None):
    method add_noise (line 62) | def add_noise(self, x, t, key=None):
    method sigmas (line 67) | def sigmas(self, t):
    method timesteps (line 70) | def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
    method step (line 76) | def step(self, eps_pred, x_t, t, t_prev):
  class SimpleEulerAncestralSampler (line 88) | class SimpleEulerAncestralSampler(SimpleEulerSampler):
    method step (line 89) | def step(self, eps_pred, x_t, t, t_prev):

FILE: stable_diffusion/stable_diffusion/tokenizer.py
  class Tokenizer (line 6) | class Tokenizer:
    method __init__ (line 9) | def __init__(self, bpe_ranks, vocab):
    method bos (line 20) | def bos(self):
    method bos_token (line 24) | def bos_token(self):
    method eos (line 28) | def eos(self):
    method eos_token (line 32) | def eos_token(self):
    method bpe (line 35) | def bpe(self, text):
    method tokenize (line 80) | def tokenize(self, text, prepend_bos=True, append_eos=True):

FILE: stable_diffusion/stable_diffusion/unet.py
  function upsample_nearest (line 12) | def upsample_nearest(x, scale: int = 2):
  class TimestepEmbedding (line 20) | class TimestepEmbedding(nn.Module):
    method __init__ (line 21) | def __init__(self, in_channels: int, time_embed_dim: int):
    method __call__ (line 27) | def __call__(self, x):
  class TransformerBlock (line 35) | class TransformerBlock(nn.Module):
    method __init__ (line 36) | def __init__(
    method __call__ (line 62) | def __call__(self, x, memory, attn_mask, memory_mask):
  class Transformer2D (line 84) | class Transformer2D(nn.Module):
    method __init__ (line 87) | def __init__(
    method __call__ (line 106) | def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
  class ResnetBlock2D (line 127) | class ResnetBlock2D(nn.Module):
    method __init__ (line 128) | def __init__(
    method __call__ (line 153) | def __call__(self, x, temb=None):
  class UNetBlock2D (line 173) | class UNetBlock2D(nn.Module):
    method __init__ (line 174) | def __init__(
    method __call__ (line 237) | def __call__(
  class UNetModel (line 270) | class UNetModel(nn.Module):
    method __init__ (line 273) | def __init__(self, config: UNetConfig):
    method __call__ (line 403) | def __call__(

FILE: stable_diffusion/stable_diffusion/vae.py
  class Attention (line 13) | class Attention(nn.Module):
    method __init__ (line 16) | def __init__(self, dims: int, norm_groups: int = 32):
    method __call__ (line 25) | def __call__(self, x):
  class EncoderDecoderBlock2D (line 45) | class EncoderDecoderBlock2D(nn.Module):
    method __init__ (line 46) | def __init__(
    method __call__ (line 79) | def __call__(self, x):
  class Encoder (line 93) | class Encoder(nn.Module):
    method __init__ (line 96) | def __init__(
    method __call__ (line 142) | def __call__(self, x):
  class Decoder (line 159) | class Decoder(nn.Module):
    method __init__ (line 162) | def __init__(
    method __call__ (line 209) | def __call__(self, x):
  class Autoencoder (line 226) | class Autoencoder(nn.Module):
    method __init__ (line 229) | def __init__(self, config: AutoencoderConfig):
    method decode (line 256) | def decode(self, z):
    method encode (line 260) | def encode(self, x):
    method __call__ (line 269) | def __call__(self, x, key=None):

FILE: t5/hf_t5.py
  function embed (line 6) | def embed(t5_model: str):
  function generate (line 25) | def generate(t5_model: str):

FILE: t5/t5.py
  class Tokenizer (line 14) | class Tokenizer:
    method __init__ (line 15) | def __init__(self, config, model_name):
    method eos_id (line 24) | def eos_id(self) -> int:
    method decoder_start_id (line 28) | def decoder_start_id(self) -> int:
    method encode (line 31) | def encode(self, s: str) -> mx.array:
    method decode (line 40) | def decode(self, t: List[int], with_sep: bool = True) -> str:
  function _relative_position_bucket (line 45) | def _relative_position_bucket(
  class RelativePositionBias (line 95) | class RelativePositionBias(nn.Module):
    method __init__ (line 96) | def __init__(self, config, bidirectional: bool):
    method __call__ (line 105) | def __call__(self, query_length: int, key_length: int, offset: int = 0):
  class MultiHeadAttention (line 126) | class MultiHeadAttention(nn.Module):
    method __init__ (line 127) | def __init__(self, config):
    method __call__ (line 136) | def __call__(
  class DenseActivation (line 170) | class DenseActivation(nn.Module):
    method __init__ (line 171) | def __init__(self, config):
    method __call__ (line 195) | def __call__(self, x):
  class TransformerEncoderLayer (line 205) | class TransformerEncoderLayer(nn.Module):
    method __init__ (line 206) | def __init__(self, config):
    method __call__ (line 213) | def __call__(self, x, mask):
  class TransformerEncoder (line 223) | class TransformerEncoder(nn.Module):
    method __init__ (line 224) | def __init__(self, config):
    method __call__ (line 232) | def __call__(self, x: mx.array):
  class TransformerDecoderLayer (line 239) | class TransformerDecoderLayer(nn.Module):
    method __init__ (line 240) | def __init__(self, config):
    method __call__ (line 249) | def __call__(
  class TransformerDecoder (line 272) | class TransformerDecoder(nn.Module):
    method __init__ (line 273) | def __init__(self, config):
    method __call__ (line 280) | def __call__(self, x, memory, mask, memory_mask, cache=None):
  class OutputHead (line 301) | class OutputHead(nn.Module):
    method __init__ (line 302) | def __init__(self, config):
    method __call__ (line 305) | def __call__(self, inputs):
  class T5 (line 309) | class T5(nn.Module):
    method __init__ (line 310) | def __init__(self, config):
    method encode (line 319) | def encode(self, inputs: mx.array):
    method decode (line 322) | def decode(
    method __call__ (line 346) | def __call__(
    method sanitize (line 354) | def sanitize(cls, weights):
    method from_pretrained (line 406) | def from_pretrained(
  function generate (line 431) | def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optiona...

FILE: transformer_lm/datasets.py
  function load_dataset (line 12) | def load_dataset(dataname):
  function _load (line 23) | def _load(save_dir, filenames):
  function wikitext (line 43) | def wikitext(dataset="2", save_dir="/tmp"):
  function ptb (line 66) | def ptb(save_dir="/tmp"):
  function enwik8 (line 92) | def enwik8(save_dir="/tmp"):

FILE: transformer_lm/main.py
  class TransformerLM (line 14) | class TransformerLM(nn.Module):
    method __init__ (line 15) | def __init__(
    method __call__ (line 32) | def __call__(self, x):
  function to_samples (line 41) | def to_samples(context_size, dataset):
  function iterate_batches (line 48) | def iterate_batches(batch_size, context_size, dataset):
  function main (line 62) | def main(args):

FILE: whisper/benchmark.py
  function parse_arguments (line 12) | def parse_arguments():
  function timer (line 28) | def timer(fn, *args):
  function feats (line 41) | def feats(n_mels: int = 80):
  function model_forward (line 49) | def model_forward(model, mels, tokens):
  function decode (line 55) | def decode(model, mels):
  function everything (line 59) | def everything(model_path):

FILE: whisper/convert.py
  function _download (line 62) | def _download(url: str, root: str) -> str:
  function available_models (line 107) | def available_models() -> List[str]:
  function hf_to_pt (line 112) | def hf_to_pt(weights, config):
  function load_torch_weights_and_config (line 152) | def load_torch_weights_and_config(
  function load_torch_model (line 195) | def load_torch_model(
  function convert (line 232) | def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
  function upload_to_hub (line 256) | def upload_to_hub(path: str, name: str, torch_name_or_path: str):
  function quantize (line 298) | def quantize(weights, config, args):

FILE: whisper/mlx_whisper/audio.py
  function load_audio (line 24) | def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_st...
  function pad_or_trim (line 66) | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
  function mel_filters (line 84) | def mel_filters(n_mels: int) -> mx.array:
  function hanning (line 102) | def hanning(size):
  function stft (line 106) | def stft(x, window, nperseg=256, noverlap=None, nfft=None, axis=-1, pad_...
  function log_mel_spectrogram (line 132) | def log_mel_spectrogram(

FILE: whisper/mlx_whisper/cli.py
  function build_parser (line 15) | def build_parser():
  function main (line 205) | def main():

FILE: whisper/mlx_whisper/decoding.py
  function compression_ratio (line 15) | def compression_ratio(text) -> float:
  function detect_language (line 20) | def detect_language(
  class DecodingOptions (line 83) | class DecodingOptions:
  class DecodingResult (line 120) | class DecodingResult:
  class Inference (line 132) | class Inference:
    method __init__ (line 133) | def __init__(self, model: "Whisper"):
    method logits (line 137) | def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
    method rearrange_kv_cache (line 144) | def rearrange_kv_cache(self, source_indices):
    method reset (line 150) | def reset(self):
  class SequenceRanker (line 154) | class SequenceRanker:
    method rank (line 155) | def rank(
  class MaximumLikelihoodRanker (line 165) | class MaximumLikelihoodRanker(SequenceRanker):
    method __init__ (line 171) | def __init__(self, length_penalty: Optional[float]):
    method rank (line 174) | def rank(self, tokens: List[List[List[int]]], sum_logprobs: List[List[...
  class TokenDecoder (line 191) | class TokenDecoder:
    method reset (line 192) | def reset(self):
    method update (line 195) | def update(
    method finalize (line 225) | def finalize(
  function categorical (line 251) | def categorical(logits, temp):
  class GreedyDecoder (line 255) | class GreedyDecoder(TokenDecoder):
    method __init__ (line 256) | def __init__(self, temperature: float, eot: int):
    method update (line 260) | def update(
    method finalize (line 280) | def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
  class LogitFilter (line 286) | class LogitFilter:
    method apply (line 287) | def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
  class SuppressBlank (line 302) | class SuppressBlank(LogitFilter):
    method __init__ (line 303) | def __init__(self, tokenizer: Tokenizer, sample_begin: int, n_vocab: i...
    method apply (line 309) | def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
  class SuppressTokens (line 315) | class SuppressTokens(LogitFilter):
    method __init__ (line 316) | def __init__(self, suppress_tokens: Sequence[int], n_vocab: int):
    method apply (line 321) | def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
  class ApplyTimestampRules (line 325) | class ApplyTimestampRules(LogitFilter):
    method __init__ (line 326) | def __init__(
    method apply (line 336) | def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
  class DecodingTask (line 398) | class DecodingTask:
    method __init__ (line 404) | def __init__(self, model: "Whisper", options: DecodingOptions):
    method _verify_options (line 465) | def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
    method _get_initial_tokens (line 480) | def _get_initial_tokens(self) -> Tuple[int]:
    method _get_suppress_tokens (line 508) | def _get_suppress_tokens(self) -> Tuple[int]:
    method _get_audio_features (line 537) | def _get_audio_features(self, mel: mx.array):
    method _detect_language (line 557) | def _detect_language(self, audio_features: mx.array, tokens: np.array):
    method _main_loop (line 572) | def _main_loop(self, audio_features: mx.array, tokens: mx.array):
    method run (line 619) | def run(self, mel: mx.array) -> List[DecodingResult]:
  function decode (line 709) | def decode(

FILE: whisper/mlx_whisper/load_models.py
  function load_model (line 14) | def load_model(

FILE: whisper/mlx_whisper/timing.py
  function median_filter (line 19) | def median_filter(x: np.ndarray, filter_width: int):
  function backtrace (line 48) | def backtrace(trace: np.ndarray):
  function dtw_cpu (line 73) | def dtw_cpu(x: np.ndarray):
  function dtw (line 98) | def dtw(x: np.ndarray) -> np.ndarray:
  class WordTiming (line 104) | class WordTiming:
  function find_alignment (line 112) | def find_alignment(
  function merge_punctuations (line 186) | def merge_punctuations(alignment: List[WordTiming], prepended: str, appe...
  function add_word_timestamps (line 220) | def add_word_timestamps(

FILE: whisper/mlx_whisper/tokenizer.py
  class Tokenizer (line 134) | class Tokenizer:
    method __post_init__ (line 144) | def __post_init__(self):
    method encode (line 163) | def encode(self, text, **kwargs):
    method decode (line 166) | def decode(self, token_ids: List[int], **kwargs) -> str:
    method decode_with_timestamps (line 170) | def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
    method eot (line 178) | def eot(self) -> int:
    method transcribe (line 182) | def transcribe(self) -> int:
    method translate (line 186) | def translate(self) -> int:
    method sot (line 190) | def sot(self) -> int:
    method sot_lm (line 194) | def sot_lm(self) -> int:
    method sot_prev (line 198) | def sot_prev(self) -> int:
    method no_speech (line 202) | def no_speech(self) -> int:
    method no_timestamps (line 206) | def no_timestamps(self) -> int:
    method timestamp_begin (line 210) | def timestamp_begin(self) -> int:
    method language_token (line 214) | def language_token(self) -> int:
    method to_language_token (line 221) | def to_language_token(self, language):
    method all_language_tokens (line 228) | def all_language_tokens(self) -> Tuple[int]:
    method all_language_codes (line 236) | def all_language_codes(self) -> Tuple[str]:
    method sot_sequence_including_notimestamps (line 240) | def sot_sequence_including_notimestamps(self) -> Tuple[int]:
    method non_speech_tokens (line 244) | def non_speech_tokens(self) -> Tuple[int]:
    method split_to_word_tokens (line 279) | def split_to_word_tokens(self, tokens: List[int]):
    method split_tokens_on_unicode (line 288) | def split_tokens_on_unicode(self, tokens: List[int]):
    method split_tokens_on_spaces (line 313) | def split_tokens_on_spaces(self, tokens: List[int]):
  function get_encoding (line 333) | def get_encoding(name: str = "gpt2", num_languages: int = 99):
  function get_tokenizer (line 370) | def get_tokenizer(

FILE: whisper/mlx_whisper/torch_whisper.py
  class ModelDimensions (line 15) | class ModelDimensions:
  class LayerNorm (line 28) | class LayerNorm(nn.LayerNorm):
    method forward (line 29) | def forward(self, x: Tensor) -> Tensor:
  class Linear (line 33) | class Linear(nn.Linear):
    method forward (line 34) | def forward(self, x: Tensor) -> Tensor:
  class Conv1d (line 42) | class Conv1d(nn.Conv1d):
    method _conv_forward (line 43) | def _conv_forward(
  function sinusoids (line 51) | def sinusoids(length, channels, max_timescale=10000):
  class MultiHeadAttention (line 60) | class MultiHeadAttention(nn.Module):
    method __init__ (line 61) | def __init__(self, n_state: int, n_head: int):
    method forward (line 69) | def forward(
    method qkv_attention (line 91) | def qkv_attention(
  class ResidualAttentionBlock (line 109) | class ResidualAttentionBlock(nn.Module):
    method __init__ (line 110) | def __init__(self, n_state: int, n_head: int, cross_attention: bool = ...
    method forward (line 127) | def forward(
  class AudioEncoder (line 141) | class AudioEncoder(nn.Module):
    method __init__ (line 142) | def __init__(
    method forward (line 155) | def forward(self, x: Tensor):
  class TextDecoder (line 174) | class TextDecoder(nn.Module):
    method __init__ (line 175) | def __init__(
    method forward (line 194) | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = No...
  class Whisper (line 219) | class Whisper(nn.Module):
    method __init__ (line 220) | def __init__(self, dims: ModelDimensions):
    method set_alignment_heads (line 245) | def set_alignment_heads(self, dump: bytes):
    method embed_audio (line 254) | def embed_audio(self, mel: torch.Tensor):
    method logits (line 257) | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
    method forward (line 260) | def forward(
    method device (line 266) | def device(self):
    method is_multilingual (line 270) | def is_multilingual(self):
    method num_languages (line 274) | def num_languages(self):
    method install_kv_cache_hooks (line 277) | def install_kv_cache_hooks(self, cache: Optional[dict] = None):

FILE: whisper/mlx_whisper/transcribe.py
  function _format_timestamp (line 26) | def _format_timestamp(seconds: float):
  function _get_end (line 43) | def _get_end(segments: List[dict]) -> Optional[float]:
  class ModelHolder (line 50) | class ModelHolder:
    method get_model (line 55) | def get_model(cls, model_path: str, dtype: mx.Dtype):
  function transcribe (line 62) | def transcribe(

FILE: whisper/mlx_whisper/whisper.py
  class ModelDimensions (line 18) | class ModelDimensions:
  function sinusoids (line 31) | def sinusoids(length, channels, max_timescale=10000):
  class MultiHeadAttention (line 40) | class MultiHeadAttention(nn.Module):
    method __init__ (line 41) | def __init__(self, n_state: int, n_head: int):
    method __call__ (line 49) | def __call__(
    method qkv_attention (line 73) | def qkv_attention(self, q, k, v, mask=None):
  class ResidualAttentionBlock (line 90) | class ResidualAttentionBlock(nn.Module):
    method __init__ (line 91) | def __init__(self, n_state: int, n_head: int, cross_attention: bool = ...
    method __call__ (line 107) | def __call__(self, x, xa=None, mask=None, kv_cache=None):
  class AudioEncoder (line 121) | class AudioEncoder(nn.Module):
    method __init__ (line 122) | def __init__(
    method __call__ (line 139) | def __call__(self, x):
  class TextDecoder (line 152) | class TextDecoder(nn.Module):
    method __init__ (line 153) | def __init__(
    method __call__ (line 176) | def __call__(self, x, xa, kv_cache=None):
  class Whisper (line 201) | class Whisper(nn.Module):
    method __init__ (line 202) | def __init__(self, dims: ModelDimensions, dtype: mx.Dtype = mx.float16):
    method set_alignment_heads (line 229) | def set_alignment_heads(self, dump: Union[bytes, np.ndarray]):
    method embed_audio (line 244) | def embed_audio(self, mel):
    method logits (line 247) | def logits(self, tokens, audio_features):
    method forward_with_cross_qk (line 250) | def forward_with_cross_qk(self, mel, tokens):
    method __call__ (line 254) | def __call__(self, mel, tokens):
    method is_multilingual (line 258) | def is_multilingual(self):
    method num_languages (line 262) | def num_languages(self):

FILE: whisper/mlx_whisper/writers.py
  function format_timestamp (line 9) | def format_timestamp(
  function get_start (line 30) | def get_start(segments: List[dict]) -> Optional[float]:
  class ResultWriter (line 37) | class ResultWriter:
    method __init__ (line 40) | def __init__(self, output_dir: str):
    method __call__ (line 43) | def __call__(
    method write_result (line 53) | def write_result(
  class WriteTXT (line 59) | class WriteTXT(ResultWriter):
    method write_result (line 62) | def write_result(
  class SubtitlesWriter (line 69) | class SubtitlesWriter(ResultWriter):
    method iterate_result (line 73) | def iterate_result(
    method format_timestamp (line 180) | def format_timestamp(self, seconds: float):
  class WriteVTT (line 188) | class WriteVTT(SubtitlesWriter):
    method write_result (line 193) | def write_result(
  class WriteSRT (line 201) | class WriteSRT(SubtitlesWriter):
    method write_result (line 206) | def write_result(
  class WriteTSV (line 215) | class WriteTSV(ResultWriter):
    method write_result (line 227) | def write_result(
  class WriteJSON (line 237) | class WriteJSON(ResultWriter):
    method write_result (line 240) | def write_result(
  function get_writer (line 246) | def get_writer(

FILE: whisper/test.py
  function _save_model (line 26) | def _save_model(save_dir, weights, config):
  function load_torch_and_mlx (line 41) | def load_torch_and_mlx():
  function forward_torch (line 63) | def forward_torch(model, mels, tokens):
  function forward_mlx (line 71) | def forward_mlx(model, mels, tokens):
  class TestWhisper (line 78) | class TestWhisper(unittest.TestCase):
    method setUpClass (line 80) | def setUpClass(cls):
    method test_torch_mlx (line 86) | def test_torch_mlx(self):
    method test_fp16 (line 101) | def test_fp16(self):
    method test_quantized_4bits (line 109) | def test_quantized_4bits(self):
    method test_decode_lang (line 118) | def test_decode_lang(self):
    method test_decode_greedy (line 127) | def test_decode_greedy(self):
    method test_transcribe (line 188) | def test_transcribe(self):
    method test_transcribe_alice (line 200) | def test_transcribe_alice(self):
    method test_transcribe_word_level_timestamps_confidence_scores (line 312) | def test_transcribe_word_level_timestamps_confidence_scores(self):
  class TestAudio (line 443) | class TestAudio(unittest.TestCase):
    method test_load (line 444) | def test_load(self):
    method test_pad (line 452) | def test_pad(self):
    method test_mel_spec (line 457) | def test_mel_spec(self):
Condensed preview — 216 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (4,962K chars).
[
  {
    "path": ".github/workflows/pull_request.yml",
    "chars": 490,
    "preview": "name: Test\n\non:\n  push:\n    branches: [\"main\"]\n  pull_request:\n\npermissions:\n  contents: read\n\nconcurrency:\n  group: ${{"
  },
  {
    "path": ".gitignore",
    "chars": 1870,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Vim\n*.swp\n\n# Distribut"
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 238,
    "preview": "repos:\n-   repo: https://github.com/psf/black-pre-commit-mirror\n    rev: 25.1.0\n    hooks:\n    -   id: black\n-   repo: h"
  },
  {
    "path": "ACKNOWLEDGMENTS.md",
    "chars": 782,
    "preview": "# Individual Contributors\n\nIf you wish to be acknowledged for your contributions, please list your name\nwith a short des"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 5544,
    "preview": "# Contributor Covenant Code of Conduct\n\n## Our Pledge\n\nWe as members, contributors, and leaders pledge to make participa"
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 1277,
    "preview": "# Contributing to mlx-examples\n\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Pu"
  },
  {
    "path": "LICENSE",
    "chars": 1065,
    "preview": "MIT License\n\nCopyright © 2023 Apple Inc.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\no"
  },
  {
    "path": "README.md",
    "chars": 2726,
    "preview": "# MLX Examples\n\nThis repo contains a variety of standalone examples using the [MLX\nframework](https://github.com/ml-expl"
  },
  {
    "path": "bert/README.md",
    "chars": 1242,
    "preview": "# BERT\n\nAn implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) in MLX.\n\n## Setup \n\nInsta"
  },
  {
    "path": "bert/convert.py",
    "chars": 1572,
    "preview": "import argparse\n\nimport numpy\nfrom transformers import AutoModel\n\n\ndef replace_key(key: str) -> str:\n    key = key.repla"
  },
  {
    "path": "bert/model.py",
    "chars": 5223,
    "preview": "import argparse\nfrom pathlib import Path\nfrom typing import List, Optional, Tuple\n\nimport mlx.core as mx\nimport mlx.nn a"
  },
  {
    "path": "bert/requirements.txt",
    "chars": 30,
    "preview": "mlx>=0.0.5\ntransformers\nnumpy\n"
  },
  {
    "path": "bert/test.py",
    "chars": 1959,
    "preview": "import argparse\nfrom typing import List\n\nimport model\nimport numpy as np\nfrom transformers import AutoModel, AutoTokeniz"
  },
  {
    "path": "bert/weights/.gitignore",
    "chars": 5,
    "preview": "*.npz"
  },
  {
    "path": "cifar/README.md",
    "chars": 1453,
    "preview": "# CIFAR and ResNets\n\nAn example of training a ResNet on CIFAR-10 with MLX. Several ResNet\nconfigurations in accordance w"
  },
  {
    "path": "cifar/dataset.py",
    "chars": 1071,
    "preview": "import mlx.core as mx\nimport numpy as np\nfrom mlx.data.datasets import load_cifar10\n\n\ndef get_cifar10(batch_size, root=N"
  },
  {
    "path": "cifar/main.py",
    "chars": 4815,
    "preview": "import argparse\nimport time\nfrom functools import partial\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport mlx.optimize"
  },
  {
    "path": "cifar/requirements.txt",
    "chars": 24,
    "preview": "mlx>=0.2\nmlx-data\nnumpy\n"
  },
  {
    "path": "cifar/resnet.py",
    "chars": 3420,
    "preview": "\"\"\"\nImplementation of ResNets for CIFAR-10 as per the original paper [https://arxiv.org/abs/1512.03385].\nConfigurations "
  },
  {
    "path": "clip/.gitignore",
    "chars": 11,
    "preview": "mlx_model/\n"
  },
  {
    "path": "clip/README.md",
    "chars": 2135,
    "preview": "# CLIP\n\nAn example of OpenAI's CLIP in MLX. The CLIP (contrastive language-image\npre-training) model embeds images and t"
  },
  {
    "path": "clip/clip.py",
    "chars": 1022,
    "preview": "from typing import Tuple\n\nfrom image_processor import CLIPImageProcessor\nfrom model import CLIPModel\nfrom tokenizer impo"
  },
  {
    "path": "clip/convert.py",
    "chars": 4055,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\nimport json\nimport shutil\nfrom pathlib import Path\nfrom typing impor"
  },
  {
    "path": "clip/hf_preproc.py",
    "chars": 713,
    "preview": "import mlx.core as mx\nimport transformers\nfrom PIL import Image\n\nimport clip\n\nhf_model = \"openai/clip-vit-base-patch32\"\n"
  },
  {
    "path": "clip/image_processor.py",
    "chars": 2924,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport json\nfrom pathlib import Path\nfrom typing import List, Tuple\n\nimport mlx.core"
  },
  {
    "path": "clip/linear_probe.py",
    "chars": 1861,
    "preview": "# Mirror of the Linear Probe Evaluation Script\n# from the official CLIP Repository.\n\nimport mlx.core as mx\nimport numpy "
  },
  {
    "path": "clip/model.py",
    "chars": 14135,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport glob\nimport json\nimport logging\nimport math\nfrom dataclasses import dataclass"
  },
  {
    "path": "clip/requirements.txt",
    "chars": 61,
    "preview": "mlx\nmlx-data\nnumpy\ntransformers\ntorch\nhuggingface_hub\nPillow\n"
  },
  {
    "path": "clip/test.py",
    "chars": 5062,
    "preview": "import unittest\n\nimport mlx.core as mx\nimport model\nimport numpy as np\nimport torch\nimport transformers\nfrom image_proce"
  },
  {
    "path": "clip/tokenizer.py",
    "chars": 3704,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport json\nfrom pathlib import Path\nfrom typing import Any\n\nimport mlx.core as mx\ni"
  },
  {
    "path": "cvae/.gitignore",
    "chars": 8,
    "preview": "models/\n"
  },
  {
    "path": "cvae/README.md",
    "chars": 1708,
    "preview": "# Convolutional Variational Autoencoder (CVAE) on MNIST\n\nConvolutional variational autoencoder (CVAE) implementation in "
  },
  {
    "path": "cvae/dataset.py",
    "chars": 1630,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nfrom mlx.data.datasets import load_mnist\n\n\ndef mnist(batch_size, img_size, root=None"
  },
  {
    "path": "cvae/main.py",
    "chars": 6684,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\nimport time\nfrom functools import partial\nfrom pathlib import Path\n\n"
  },
  {
    "path": "cvae/requirements.txt",
    "chars": 31,
    "preview": "mlx>=0.2\nmlx-data\nnumpy\nPillow\n"
  },
  {
    "path": "cvae/vae.py",
    "chars": 5894,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\n\n# from https://github.com/m"
  },
  {
    "path": "encodec/README.md",
    "chars": 1978,
    "preview": "# EnCodec\n\nAn example of Meta's EnCodec model in MLX.[^1] EnCodec is used to compress and\ngenerate audio.\n\n### Setup\n\nIn"
  },
  {
    "path": "encodec/benchmarks/bench_mx.py",
    "chars": 633,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport time\n\nimport mlx.core as mx\n\nfrom encodec import EncodecModel\n\nmodel, processor = "
  },
  {
    "path": "encodec/benchmarks/bench_pt.py",
    "chars": 861,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport time\n\nimport numpy as np\nimport torch\nfrom transformers import AutoProcessor, Enco"
  },
  {
    "path": "encodec/convert.py",
    "chars": 6202,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport argparse\nimport json\nfrom pathlib import Path\nfrom textwrap import dedent\nfrom typ"
  },
  {
    "path": "encodec/encodec.py",
    "chars": 24910,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport functools\nimport json\nimport math\nfrom pathlib import Path\nfrom types import Simpl"
  },
  {
    "path": "encodec/example.py",
    "chars": 1052,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport mlx.core as mx\nfrom utils import load_audio, save_audio\n\nfrom encodec import Encod"
  },
  {
    "path": "encodec/requirements.txt",
    "chars": 32,
    "preview": "mlx>=0.18\nnumpy\nhuggingface_hub\n"
  },
  {
    "path": "encodec/test.py",
    "chars": 2243,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport mlx.core as mx\nimport numpy as np\nimport torch\nfrom transformers import AutoProces"
  },
  {
    "path": "encodec/utils.py",
    "chars": 1483,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport mlx.core as mx\nimport numpy as np\n\n\ndef save_audio(file: str, audio: mx.array, sam"
  },
  {
    "path": "flux/README.md",
    "chars": 9791,
    "preview": "FLUX\n====\n\nFLUX implementation in MLX. The implementation is ported directly from\n[https://github.com/black-forest-labs/"
  },
  {
    "path": "flux/dreambooth.py",
    "chars": 9674,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport argparse\nimport time\nfrom functools import partial\nfrom pathlib import Path\n\nimpor"
  },
  {
    "path": "flux/flux/__init__.py",
    "chars": 347,
    "preview": "# Copyright © 2024 Apple Inc.\n\nfrom .datasets import Dataset, load_dataset\nfrom .flux import FluxPipeline\nfrom .lora imp"
  },
  {
    "path": "flux/flux/autoencoder.py",
    "chars": 10796,
    "preview": "# Copyright © 2024 Apple Inc.\n\nfrom dataclasses import dataclass\nfrom typing import List\n\nimport mlx.core as mx\nimport m"
  },
  {
    "path": "flux/flux/clip.py",
    "chars": 4730,
    "preview": "# Copyright © 2024 Apple Inc.\n\nfrom dataclasses import dataclass\nfrom typing import List, Optional\n\nimport mlx.core as m"
  },
  {
    "path": "flux/flux/datasets.py",
    "chars": 2064,
    "preview": "import json\nfrom pathlib import Path\n\nfrom PIL import Image\n\n\nclass Dataset:\n    def __getitem__(self, index: int):\n    "
  },
  {
    "path": "flux/flux/flux.py",
    "chars": 7878,
    "preview": "# Copyright © 2024 Apple Inc.\n\nfrom typing import Tuple\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom mlx.utils import"
  },
  {
    "path": "flux/flux/layers.py",
    "chars": 10747,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport math\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing i"
  },
  {
    "path": "flux/flux/lora.py",
    "chars": 1928,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport math\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\n\nclass LoRALinear(nn.Module):\n   "
  },
  {
    "path": "flux/flux/model.py",
    "chars": 5996,
    "preview": "# Copyright © 2024 Apple Inc.\n\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport mlx.core as mx\nimpo"
  },
  {
    "path": "flux/flux/sampler.py",
    "chars": 1727,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport math\nfrom functools import lru_cache\n\nimport mlx.core as mx\n\n\nclass FluxSampler:\n "
  },
  {
    "path": "flux/flux/t5.py",
    "chars": 8657,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport math\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Tuple\n\ni"
  },
  {
    "path": "flux/flux/tokenizers.py",
    "chars": 5469,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport mlx.core as mx\nimport regex\nfrom sentencepiece import SentencePieceProcessor\n\n\ncla"
  },
  {
    "path": "flux/flux/trainer.py",
    "chars": 3389,
    "preview": "import mlx.core as mx\nimport numpy as np\nfrom PIL import Image, ImageFile\nfrom tqdm import tqdm\n\nfrom .datasets import D"
  },
  {
    "path": "flux/flux/utils.py",
    "chars": 6633,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport json\nimport os\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typ"
  },
  {
    "path": "flux/generate_interactive.py",
    "chars": 3382,
    "preview": "import argparse\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport numpy as np\nfrom PIL import Image\nfrom tqdm import tqd"
  },
  {
    "path": "flux/requirements.txt",
    "chars": 66,
    "preview": "mlx>=0.18.1\nhuggingface-hub\nregex\nnumpy\ntqdm\nPillow\nsentencepiece\n"
  },
  {
    "path": "flux/txt2image.py",
    "chars": 6322,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport argparse\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport numpy as np\nfrom PIL im"
  },
  {
    "path": "gcn/.gitignore",
    "chars": 6,
    "preview": "cora/\n"
  },
  {
    "path": "gcn/README.md",
    "chars": 372,
    "preview": "# Graph Convolutional Network\n\nAn example of [GCN](https://arxiv.org/abs/1609.02907) implementation with MLX.\n\n### Insta"
  },
  {
    "path": "gcn/datasets.py",
    "chars": 3518,
    "preview": "import os\nimport tarfile\n\nimport mlx.core as mx\nimport numpy as np\nimport requests\nimport scipy.sparse as sparse\n\n\"\"\"\nPr"
  },
  {
    "path": "gcn/gcn.py",
    "chars": 915,
    "preview": "import mlx.nn as nn\n\n\nclass GCNLayer(nn.Module):\n    def __init__(self, in_features, out_features, bias=True):\n        s"
  },
  {
    "path": "gcn/main.py",
    "chars": 3621,
    "preview": "import time\nfrom argparse import ArgumentParser\nfrom functools import partial\n\nimport mlx.core as mx\nimport mlx.nn as nn"
  },
  {
    "path": "gcn/requirements.txt",
    "chars": 56,
    "preview": "mlx>=0.0.4\nnumpy>=1.26.2\nscipy>=1.11.4\nrequests>=2.31.0\n"
  },
  {
    "path": "llava/.gitignore",
    "chars": 8,
    "preview": "**.ipynb"
  },
  {
    "path": "llava/README.md",
    "chars": 1361,
    "preview": "# LLaVA\n\nAn example of LLaVA: Large Language and Vision Assistant in MLX.[^1] LLlava is\na multimodal model that can gene"
  },
  {
    "path": "llava/generate.py",
    "chars": 4020,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport argparse\nimport codecs\nfrom pathlib import Path\n\nimport mlx.core as mx\nimport requ"
  },
  {
    "path": "llava/language.py",
    "chars": 7039,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport inspect\nfrom dataclasses import dataclass\nfrom typing import Dict, Optional, Tuple"
  },
  {
    "path": "llava/llava.py",
    "chars": 5730,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport glob\nimport inspect\nimport json\nfrom dataclasses import dataclass\nfrom pathlib imp"
  },
  {
    "path": "llava/requirements.txt",
    "chars": 59,
    "preview": "mlx>=0.8.0\nnumpy\ntransformers\ntorch\nhuggingface_hub\nPillow\n"
  },
  {
    "path": "llava/test.py",
    "chars": 5572,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport unittest\n\nimport mlx.core as mx\nimport requests\nimport torch\nfrom PIL import Image"
  },
  {
    "path": "llava/vision.py",
    "chars": 7415,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport inspect\nimport math\nfrom dataclasses import dataclass\nfrom typing import Optional\n"
  },
  {
    "path": "llms/README.md",
    "chars": 212,
    "preview": "# MOVE NOTICE \n\nThe mlx-lm package has moved to a [new repo](https://github.com/ml-explore/mlx-lm).\n\nThe package has bee"
  },
  {
    "path": "llms/gguf_llm/README.md",
    "chars": 1568,
    "preview": "# LLMs in MLX with GGUF\n\nAn example generating text using GGUF format models in MLX.[^1]\n\n> [!NOTE]\n> MLX is able to rea"
  },
  {
    "path": "llms/gguf_llm/generate.py",
    "chars": 2253,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport time\n\nimport mlx.core as mx\nimport models\n\n\ndef generate(\n    mode"
  },
  {
    "path": "llms/gguf_llm/models.py",
    "chars": 11179,
    "preview": "# Copyright © 2023 Apple Inc.\n\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Dict, List,"
  },
  {
    "path": "llms/gguf_llm/requirements.txt",
    "chars": 62,
    "preview": "mlx>=0.8\nnumpy\nprotobuf==3.20.2\nsentencepiece\nhuggingface_hub\n"
  },
  {
    "path": "llms/gguf_llm/utils.py",
    "chars": 1897,
    "preview": "import sentencepiece as spm\nimport sentencepiece.sentencepiece_model_pb2 as model\n\n\ndef spm_tokenizer(metadata):\n    tok"
  },
  {
    "path": "llms/llama/README.md",
    "chars": 2043,
    "preview": "# Llama\n\nAn example of generating text with Llama (1 or 2) using MLX.\n\nLlama is a set of open source language models fro"
  },
  {
    "path": "llms/llama/convert.py",
    "chars": 7547,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport collections\nimport copy\nimport glob\nimport json\nimport shutil\nfrom"
  },
  {
    "path": "llms/llama/llama.py",
    "chars": 12833,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport glob\nimport json\nimport time\nfrom dataclasses import dataclass\nfro"
  },
  {
    "path": "llms/llama/requirements.txt",
    "chars": 38,
    "preview": "mlx>=0.11.0\nsentencepiece\ntorch\nnumpy\n"
  },
  {
    "path": "llms/llama/sample_prompt.txt",
    "chars": 1404,
    "preview": "[Instruction] Give the list of U.S. states bordering Canada\n[Answer] OK, here is the list of U.S. states located on the "
  },
  {
    "path": "llms/mistral/.gitignore",
    "chars": 17,
    "preview": "mistral-7B-v0.1/\n"
  },
  {
    "path": "llms/mistral/README.md",
    "chars": 1383,
    "preview": "# Mistral \n\nAn example of generating text with Mistral using MLX.\n\nMistral 7B is one of the top large language models in"
  },
  {
    "path": "llms/mistral/convert.py",
    "chars": 2622,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport copy\nimport json\nimport shutil\nfrom pathlib import Path\n\nimport ml"
  },
  {
    "path": "llms/mistral/mistral.py",
    "chars": 8640,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport json\nimport time\nfrom dataclasses import dataclass\nfrom pathlib im"
  },
  {
    "path": "llms/mistral/requirements.txt",
    "chars": 38,
    "preview": "mlx>=0.11.0\nsentencepiece\ntorch\nnumpy\n"
  },
  {
    "path": "llms/mistral/test.py",
    "chars": 3145,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport unittest\n\nimport mistral\nimport mlx.core as mx\nfrom mlx.utils import tree_map\n\n\ncl"
  },
  {
    "path": "llms/mixtral/README.md",
    "chars": 1930,
    "preview": "## Mixtral 8x7B\n\nRun the Mixtral[^mixtral] 8x7B mixture-of-experts (MoE) model in MLX on Apple silicon.\n\nThis example al"
  },
  {
    "path": "llms/mixtral/convert.py",
    "chars": 4248,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport copy\nimport glob\nimport json\nimport shutil\nfrom pathlib import Pat"
  },
  {
    "path": "llms/mixtral/mixtral.py",
    "chars": 9475,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport glob\nimport json\nfrom dataclasses import dataclass\nfrom pathlib im"
  },
  {
    "path": "llms/mixtral/params.json",
    "chars": 193,
    "preview": "{\"dim\": 4096, \"n_layers\": 32, \"head_dim\": 128, \"hidden_dim\": 14336, \"n_heads\": 32, \"n_kv_heads\": 8, \"norm_eps\": 1e-05, \""
  },
  {
    "path": "llms/mixtral/requirements.txt",
    "chars": 38,
    "preview": "mlx>=0.11.0\nsentencepiece\ntorch\nnumpy\n"
  },
  {
    "path": "llms/speculative_decoding/README.md",
    "chars": 2188,
    "preview": "# Speculative Decoding\n\nThis example implements speculative decoding with the T5 model for text\ngeneration.[^1][^2] Spec"
  },
  {
    "path": "llms/speculative_decoding/convert.py",
    "chars": 2203,
    "preview": "import numpy as np\nfrom transformers import T5ForConditionalGeneration\n\nSHARED_REPLACEMENT_PATTERNS = [\n    (\".block.\", "
  },
  {
    "path": "llms/speculative_decoding/decoder.py",
    "chars": 6187,
    "preview": "from typing import List\n\nimport mlx.core as mx\nimport transformers\nfrom model import Model\n\n\nclass Tokenizer:\n    def __"
  },
  {
    "path": "llms/speculative_decoding/main.py",
    "chars": 2616,
    "preview": "import argparse\nimport time\n\nimport mlx.core as mx\nfrom decoder import SpeculativeDecoder\nfrom mlx.utils import tree_unf"
  },
  {
    "path": "llms/speculative_decoding/model.py",
    "chars": 11948,
    "preview": "from typing import List, Optional, Tuple\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport numpy as np\nfrom mlx.utils im"
  },
  {
    "path": "llms/speculative_decoding/requirements.txt",
    "chars": 30,
    "preview": "mlx>=0.8.0\ntransformers\nnumpy\n"
  },
  {
    "path": "lora/.gitignore",
    "chars": 13,
    "preview": "adapters.npz\n"
  },
  {
    "path": "lora/README.md",
    "chars": 6797,
    "preview": "# Fine-Tuning with LoRA or QLoRA\n\nThis is an example of using MLX to fine-tune an LLM with low rank adaptation\n(LoRA) fo"
  },
  {
    "path": "lora/convert.py",
    "chars": 2506,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\nimport copy\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport models"
  },
  {
    "path": "lora/data/test.jsonl",
    "chars": 27862,
    "preview": "{\"text\": \"table: 1-10015132-16\\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\\nQ: What"
  },
  {
    "path": "lora/data/train.jsonl",
    "chars": 272537,
    "preview": "{\"text\": \"table: 1-1000181-1\\ncolumns: State/territory, Text/background colour, Format, Current slogan, Current series, "
  },
  {
    "path": "lora/data/valid.jsonl",
    "chars": 28495,
    "preview": "{\"text\": \"table: 1-10015132-11\\ncolumns: Player, No., Nationality, Position, Years in Toronto, School/Club Team\\nQ: What"
  },
  {
    "path": "lora/data/wikisql.py",
    "chars": 3855,
    "preview": "# Copyright © 2023 Apple Inc.\n\n\"\"\"\nCode to preprocess the WikiSQL dataset adapted from\nhttps://github.com/salesforce/Wik"
  },
  {
    "path": "lora/fuse.py",
    "chars": 3740,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\nfrom pathlib import Path\n\nimport mlx.core as mx\nimport mlx.nn as nn\n"
  },
  {
    "path": "lora/lora.py",
    "chars": 12031,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\nimport json\nimport math\nimport os\nimport sys\nimport time\nfrom pathli"
  },
  {
    "path": "lora/models.py",
    "chars": 8835,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport inspect\nimport math\nfrom dataclasses import dataclass\nfrom typing import Dict, Opt"
  },
  {
    "path": "lora/requirements.txt",
    "chars": 30,
    "preview": "mlx>=0.8.0\ntransformers\nnumpy\n"
  },
  {
    "path": "lora/utils.py",
    "chars": 5943,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport glob\nimport json\nimport logging\nfrom pathlib import Path\nfrom typing import A"
  },
  {
    "path": "mnist/README.md",
    "chars": 415,
    "preview": "# MNIST\n\nThis example shows how to run some simple models on MNIST. \n\nInstall the dependencies:\n\n```\npip install -r requ"
  },
  {
    "path": "mnist/main.py",
    "chars": 2791,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport time\nfrom functools import partial\n\nimport mlx.core as mx\nimport m"
  },
  {
    "path": "mnist/mnist.py",
    "chars": 2660,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport gzip\nimport os\nimport pickle\nfrom urllib import request\n\nimport numpy as np\n\n\ndef "
  },
  {
    "path": "mnist/requirements.txt",
    "chars": 15,
    "preview": "mlx>=0.2\nnumpy\n"
  },
  {
    "path": "musicgen/README.md",
    "chars": 639,
    "preview": "# MusicGen\n\nAn example of Meta's MusicGen model in MLX.[^1] MusicGen is used to generate\nmusic from text descriptions.\n\n"
  },
  {
    "path": "musicgen/benchmarks/bench_mx.py",
    "chars": 549,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport sys\nimport time\nfrom pathlib import Path\n\nimport mlx.core as mx\n\ncur_path = Path(_"
  },
  {
    "path": "musicgen/benchmarks/bench_pt.py",
    "chars": 822,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport time\n\nimport torch\nfrom transformers import AutoProcessor, MusicgenForConditionalG"
  },
  {
    "path": "musicgen/generate.py",
    "chars": 817,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport argparse\n\nfrom utils import save_audio\n\nfrom musicgen import MusicGen\n\n\ndef main(t"
  },
  {
    "path": "musicgen/musicgen.py",
    "chars": 13165,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport json\nfrom functools import partial\nfrom pathlib import Path\nfrom types import Simp"
  },
  {
    "path": "musicgen/requirements.txt",
    "chars": 57,
    "preview": "mlx>=0.18\nnumpy\nhuggingface_hub\ntorch\ntransformers\nscipy\n"
  },
  {
    "path": "musicgen/utils.py",
    "chars": 359,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport mlx.core as mx\nimport numpy as np\n\n\ndef save_audio(file: str, audio: mx.array, sam"
  },
  {
    "path": "normalizing_flow/README.md",
    "chars": 1073,
    "preview": "# Normalizing Flow\n\nAn example of a normalizing flow for density estimation and sampling\nimplemented in MLX. This exampl"
  },
  {
    "path": "normalizing_flow/bijectors.py",
    "chars": 2152,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nfrom typing import Tuple\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\n\nclass Bijector"
  },
  {
    "path": "normalizing_flow/distributions.py",
    "chars": 849,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\nfrom typing import Optional, Tuple, Union\n\nimport mlx.core as mx\n\n\nclass"
  },
  {
    "path": "normalizing_flow/flows.py",
    "chars": 2482,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nfrom typing import Optional, Tuple, Union\n\nimport mlx.core as mx\nimport mlx.nn as nn"
  },
  {
    "path": "normalizing_flow/main.py",
    "chars": 3655,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nfrom functools import partial\n\nimport matplotlib.pyplot as plt\nimport mlx.core as mx"
  },
  {
    "path": "normalizing_flow/requirements.txt",
    "chars": 44,
    "preview": "mlx>=0.2\nnumpy\ntqdm\nscikit-learn\nmatplotlib\n"
  },
  {
    "path": "segment_anything/README.md",
    "chars": 971,
    "preview": "# Segment Anything\n\nAn implementation of the Segment Anything Model (SAM) in MLX. See the original\nrepo by Meta AI for m"
  },
  {
    "path": "segment_anything/convert.py",
    "chars": 2719,
    "preview": "import argparse\nimport json\nimport shutil\nfrom pathlib import Path\nfrom typing import Dict, Union\n\nimport mlx.core as mx"
  },
  {
    "path": "segment_anything/main.py",
    "chars": 6507,
    "preview": "import argparse\nimport json\nimport os\nimport sys\nfrom typing import Any, Dict, List\n\nimport cv2\n\nfrom segment_anything i"
  },
  {
    "path": "segment_anything/notebooks/automatic_mask_generator_example.ipynb",
    "chars": 6846,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Automatically generating object m"
  },
  {
    "path": "segment_anything/notebooks/predictor_example.ipynb",
    "chars": 17389,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"## Segmenting from Prompts\\n\",\n    "
  },
  {
    "path": "segment_anything/requirements.txt",
    "chars": 41,
    "preview": "matplotlib\nopencv-python\nhuggingface_hub\n"
  },
  {
    "path": "segment_anything/segment_anything/__init__.py",
    "chars": 64,
    "preview": "from .automatic_mask_generator import SamAutomaticMaskGenerator\n"
  },
  {
    "path": "segment_anything/segment_anything/automatic_mask_generator.py",
    "chars": 16408,
    "preview": "from typing import Any, Dict, List, Optional, Tuple\n\nimport mlx.core as mx\nimport numpy as np\n\nfrom .predictor import Sa"
  },
  {
    "path": "segment_anything/segment_anything/common.py",
    "chars": 969,
    "preview": "from typing import Type\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\n\nclass MLPBlock(nn.Module):\n    def __init__(\n      "
  },
  {
    "path": "segment_anything/segment_anything/image_encoder.py",
    "chars": 14671,
    "preview": "from typing import Optional, Tuple, Type\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .common import LayerNorm2d, ML"
  },
  {
    "path": "segment_anything/segment_anything/mask_decoder.py",
    "chars": 8340,
    "preview": "import math\nfrom typing import List, Tuple, Type, Union\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .common import "
  },
  {
    "path": "segment_anything/segment_anything/predictor.py",
    "chars": 6792,
    "preview": "from typing import Optional, Tuple\n\nimport mlx.core as mx\nimport numpy as np\n\nfrom .sam import Sam\nfrom .utils.transform"
  },
  {
    "path": "segment_anything/segment_anything/prompt_encoder.py",
    "chars": 8523,
    "preview": "from typing import Optional, Tuple, Type\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .common import LayerNorm2d\n\n\nc"
  },
  {
    "path": "segment_anything/segment_anything/sam.py",
    "chars": 9588,
    "preview": "import json\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Tuple\n\nimport mlx"
  },
  {
    "path": "segment_anything/segment_anything/transformer.py",
    "chars": 8327,
    "preview": "import math\nfrom typing import Tuple, Type\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .common import MLPBlock\n\n\ncl"
  },
  {
    "path": "segment_anything/segment_anything/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "segment_anything/segment_anything/utils/amg.py",
    "chars": 12208,
    "preview": "import math\nfrom copy import deepcopy\nfrom itertools import product\nfrom typing import Any, Dict, Generator, ItemsView, "
  },
  {
    "path": "segment_anything/segment_anything/utils/transforms.py",
    "chars": 2147,
    "preview": "from copy import deepcopy\nfrom typing import Tuple\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport numpy as np\nfrom PI"
  },
  {
    "path": "speechcommands/README.md",
    "chars": 1866,
    "preview": "# Train a Keyword Spotting Transformer on Speech Commands\n\nAn example of training a Keyword Spotting Transformer[^1] on "
  },
  {
    "path": "speechcommands/kwt.py",
    "chars": 5908,
    "preview": "import mlx.core as mx\nimport mlx.nn as nn\nfrom mlx.utils import tree_flatten\n\n__all__ = [\"KWT\", \"kwt1\", \"kwt2\", \"kwt3\"]\n"
  },
  {
    "path": "speechcommands/main.py",
    "chars": 5335,
    "preview": "import argparse\nimport time\nfrom functools import partial\n\nimport kwt\nimport mlx.core as mx\nimport mlx.nn as nn\nimport m"
  },
  {
    "path": "speechcommands/requirements.txt",
    "chars": 18,
    "preview": "mlx>=0.2\nmlx-data\n"
  },
  {
    "path": "stable_diffusion/README.md",
    "chars": 4094,
    "preview": "Stable Diffusion\n================\n\nStable Diffusion in MLX. The implementation was ported from Hugging Face's\n[diffusers"
  },
  {
    "path": "stable_diffusion/image2image.py",
    "chars": 4980,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\nimport math\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport numpy as np"
  },
  {
    "path": "stable_diffusion/requirements.txt",
    "chars": 50,
    "preview": "mlx>=0.11\nhuggingface-hub\nregex\nnumpy\ntqdm\nPillow\n"
  },
  {
    "path": "stable_diffusion/stable_diffusion/__init__.py",
    "chars": 9523,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport time\nfrom typing import Optional, Tuple\n\nimport mlx.core as mx\n\nfrom .model_i"
  },
  {
    "path": "stable_diffusion/stable_diffusion/clip.py",
    "chars": 3732,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nfrom dataclasses import dataclass\nfrom typing import List, Optional\n\nimport mlx.core"
  },
  {
    "path": "stable_diffusion/stable_diffusion/config.py",
    "chars": 1772,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nfrom dataclasses import dataclass\nfrom typing import Optional, Tuple\n\n\n@dataclass\ncl"
  },
  {
    "path": "stable_diffusion/stable_diffusion/model_io.py",
    "chars": 11667,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport json\nfrom typing import Optional\n\nimport mlx.core as mx\nfrom huggingface_hub "
  },
  {
    "path": "stable_diffusion/stable_diffusion/sampler.py",
    "chars": 3396,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport mlx.core as mx\n\nfrom .config import DiffusionConfig\n\n\ndef _linspace(a, b, num):\n  "
  },
  {
    "path": "stable_diffusion/stable_diffusion/tokenizer.py",
    "chars": 2990,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport regex\n\n\nclass Tokenizer:\n    \"\"\"A simple port of CLIPTokenizer from https://github"
  },
  {
    "path": "stable_diffusion/stable_diffusion/unet.py",
    "chars": 14761,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport math\nfrom typing import Optional\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom "
  },
  {
    "path": "stable_diffusion/stable_diffusion/vae.py",
    "chars": 8043,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport math\nfrom typing import List\n\nimport mlx.core as mx\nimport mlx.nn as nn\n\nfrom .con"
  },
  {
    "path": "stable_diffusion/txt2image.py",
    "chars": 3914,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport argparse\n\nimport mlx.core as mx\nimport mlx.nn as nn\nimport numpy as np\nfrom PIL im"
  },
  {
    "path": "t5/.gitignore",
    "chars": 6,
    "preview": "*.npz\n"
  },
  {
    "path": "t5/README.md",
    "chars": 1309,
    "preview": "# T5\n\nThe T5 models are encoder-decoder models pre-trained on a mixture of\nunsupervised and supervised tasks.[^1] These "
  },
  {
    "path": "t5/hf_t5.py",
    "chars": 1892,
    "preview": "import argparse\n\nfrom transformers import AutoModelForSeq2SeqLM, AutoTokenizer, T5EncoderModel\n\n\ndef embed(t5_model: str"
  },
  {
    "path": "t5/requirements.txt",
    "chars": 30,
    "preview": "mlx>=0.8.0\nnumpy\ntransformers\n"
  },
  {
    "path": "t5/t5.py",
    "chars": 17934,
    "preview": "import argparse\nimport json\nfrom pathlib import Path\nfrom time import perf_counter_ns\nfrom types import SimpleNamespace\n"
  },
  {
    "path": "transformer_lm/README.md",
    "chars": 316,
    "preview": "# Transformer LM \n\nThis is an example of a decoder-only Transformer LM. The only dependency is\nMLX. \n\nRun the example on"
  },
  {
    "path": "transformer_lm/datasets.py",
    "chars": 3842,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport io\nimport itertools\nimport os\nimport zipfile\nfrom urllib import request\n\nimport nu"
  },
  {
    "path": "transformer_lm/main.py",
    "chars": 6705,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport math\nimport time\nfrom functools import partial\n\nimport datasets\nimport mlx.co"
  },
  {
    "path": "transformer_lm/requirements.txt",
    "chars": 9,
    "preview": "mlx>=0.2\n"
  },
  {
    "path": "whisper/MANIFEST.in",
    "chars": 170,
    "preview": "include mlx_whisper/requirements.txt\ninclude mlx_whisper/assets/mel_filters.npz\ninclude mlx_whisper/assets/multilingual."
  },
  {
    "path": "whisper/README.md",
    "chars": 3455,
    "preview": "# Whisper\n\nSpeech recognition with Whisper in MLX. Whisper is a set of open source speech\nrecognition models from OpenAI"
  },
  {
    "path": "whisper/benchmark.py",
    "chars": 3058,
    "preview": "# Copyright © 2023-2024 Apple Inc.\nimport argparse\nimport os\nimport time\n\nimport mlx.core as mx\nfrom mlx_whisper import "
  },
  {
    "path": "whisper/convert.py",
    "chars": 13931,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport argparse\nimport copy\nimport hashlib\nimport json\nimport os\nimport urllib\nimpor"
  },
  {
    "path": "whisper/mlx_whisper/__init__.py",
    "chars": 148,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nfrom . import audio, decoding, load_models\nfrom ._version import __version__\nfrom .t"
  },
  {
    "path": "whisper/mlx_whisper/_version.py",
    "chars": 58,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\n__version__ = \"0.4.3\"\n"
  },
  {
    "path": "whisper/mlx_whisper/assets/download_alice.sh",
    "chars": 281,
    "preview": "#!/bin/bash\n\naudio_file=$HOME/.cache/whisper/alice.mp3\necho $audio_file\nzipf=alice_in_wonderland_librivox_64kb_mp3.zip\nu"
  },
  {
    "path": "whisper/mlx_whisper/assets/gpt2.tiktoken",
    "chars": 835554,
    "preview": "IQ== 0\nIg== 1\nIw== 2\nJA== 3\nJQ== 4\nJg== 5\nJw== 6\nKA== 7\nKQ== 8\nKg== 9\nKw== 10\nLA== 11\nLQ== 12\nLg== 13\nLw== 14\nMA== 15\nMQ"
  },
  {
    "path": "whisper/mlx_whisper/assets/multilingual.tiktoken",
    "chars": 816730,
    "preview": "IQ== 0\nIg== 1\nIw== 2\nJA== 3\nJQ== 4\nJg== 5\nJw== 6\nKA== 7\nKQ== 8\nKg== 9\nKw== 10\nLA== 11\nLQ== 12\nLg== 13\nLw== 14\nMA== 15\nMQ"
  },
  {
    "path": "whisper/mlx_whisper/audio.py",
    "chars": 5089,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport os\nfrom functools import lru_cache\nfrom subprocess import CalledProcessError, run\n"
  },
  {
    "path": "whisper/mlx_whisper/cli.py",
    "chars": 8630,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport argparse\nimport os\nimport pathlib\nimport traceback\nimport warnings\n\nfrom . import "
  },
  {
    "path": "whisper/mlx_whisper/decoding.py",
    "chars": 28493,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport zlib\nfrom dataclasses import dataclass, field, replace\nfrom typing import Dict, It"
  },
  {
    "path": "whisper/mlx_whisper/load_models.py",
    "chars": 1426,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport json\nfrom pathlib import Path\n\nimport mlx.core as mx\nimport mlx.nn as nn\nfrom hugg"
  },
  {
    "path": "whisper/mlx_whisper/requirements.txt",
    "chars": 79,
    "preview": "mlx>=0.11\nnumba\nnumpy\ntorch\ntqdm\nmore-itertools\ntiktoken\nhuggingface_hub\nscipy\n"
  },
  {
    "path": "whisper/mlx_whisper/timing.py",
    "chars": 10865,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport itertools\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, List"
  },
  {
    "path": "whisper/mlx_whisper/tokenizer.py",
    "chars": 12368,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport base64\nimport os\nimport string\nfrom dataclasses import dataclass, field\nfrom funct"
  },
  {
    "path": "whisper/mlx_whisper/torch_whisper.py",
    "chars": 10591,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport base64\nimport gzip\nfrom dataclasses import dataclass\nfrom typing import Dict, Iter"
  },
  {
    "path": "whisper/mlx_whisper/transcribe.py",
    "chars": 22982,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport sys\nimport warnings\nfrom typing import List, Optional, Tuple, Union\n\nimport mlx.co"
  },
  {
    "path": "whisper/mlx_whisper/whisper.py",
    "chars": 8631,
    "preview": "# Copyright © 2023 Apple Inc.\n\nimport base64\nimport gzip\nimport math\nfrom dataclasses import dataclass\nfrom typing impor"
  },
  {
    "path": "whisper/mlx_whisper/writers.py",
    "chars": 10137,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport json\nimport pathlib\nimport re\nfrom typing import Callable, List, Optional, TextIO\n"
  },
  {
    "path": "whisper/setup.py",
    "chars": 1031,
    "preview": "# Copyright © 2024 Apple Inc.\n\nimport sys\nfrom pathlib import Path\n\nfrom setuptools import find_namespace_packages, setu"
  },
  {
    "path": "whisper/test.py",
    "chars": 14410,
    "preview": "# Copyright © 2023-2024 Apple Inc.\n\nimport json\nimport os\nimport unittest\nfrom dataclasses import asdict\nfrom pathlib im"
  },
  {
    "path": "wwdc25/Explore_language_models_on_Apple_silicon_with_MLX.ipynb",
    "chars": 12858,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"id\": \"74bc2ccb\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Explore larg"
  }
]

// ... and 16 more files (download for full content)

About this extraction

This page contains the full source code of the ml-explore/mlx-examples GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 216 files (4.5 MB), approximately 1.2M tokens, and a symbol index with 1221 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!