Full Code of vikhyat/moondream for AI

main 9fe3ad77616b cached
66 files
343.1 KB
82.9k tokens
212 symbols
1 requests
Download .txt
Showing preview only (363K chars total). Download the full file or copy to clipboard to get everything.
Repository: vikhyat/moondream
Branch: main
Commit: 9fe3ad77616b
Files: 66
Total size: 343.1 KB

Directory structure:
gitextract_g04l91uk/

├── .github/
│   └── workflows/
│       ├── pylint.yml
│       └── test.yml
├── .gitignore
├── LICENSE
├── README.md
├── batch_generate_example.py
├── gradio_demo.py
├── moondream/
│   ├── __init__.py
│   ├── config/
│   │   ├── config_md05.json
│   │   └── config_md2.json
│   ├── eval/
│   │   ├── chartqa.py
│   │   ├── coco_map.py
│   │   ├── countbenchqa.py
│   │   ├── docvqa.py
│   │   ├── eval_all.py
│   │   ├── gazefollow.py
│   │   ├── mmstar.py
│   │   ├── naturalbench.py
│   │   ├── pope.py
│   │   ├── realworldqa.py
│   │   ├── tallyqa.py
│   │   ├── textvqa.py
│   │   ├── utils.py
│   │   └── waste_detection.py
│   └── torch/
│       ├── config.py
│       ├── hf_moondream.py
│       ├── hf_release.py
│       ├── image_crops.py
│       ├── layers.py
│       ├── lora.py
│       ├── moondream.py
│       ├── region.py
│       ├── rope.py
│       ├── sample.py
│       ├── text.py
│       ├── utils.py
│       ├── vision.py
│       └── weights.py
├── notebooks/
│   └── RepEng.ipynb
├── recipes/
│   ├── gaze-detection-video/
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── gaze-detection-video.py
│   │   ├── input/
│   │   │   └── .gitkeep
│   │   ├── output/
│   │   │   └── .gitkeep
│   │   ├── requirements.txt
│   │   └── temp/
│   │       └── .gitkeep
│   ├── promptable-content-moderation/
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── app.py
│   │   ├── deep_sort_integration.py
│   │   ├── main.py
│   │   ├── packages.txt
│   │   ├── persistence.py
│   │   ├── requirements.txt
│   │   ├── video_visualization.py
│   │   └── visualization.py
│   └── promptable-video-redaction/
│       ├── .gitignore
│       ├── README.md
│       ├── app.py
│       ├── main.py
│       ├── packages.txt
│       └── requirements.txt
├── requirements.txt
├── sample.py
├── tests/
│   └── test_image_crops.py
└── webcam_gradio_demo.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/pylint.yml
================================================
name: Lint

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

permissions:
  contents: read
jobs:
  build:
    runs-on: ubuntu-latest
    permissions:
      contents: read
    strategy:
      matrix:
        python-version: ["3.12"]  # Run lint checks only on latest Python version
    steps:
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install autoflake black
    - name: Checking for unused imports
      run: |
        autoflake -c -r .
    - name: Checking code style
      run: |
        black --check .


================================================
FILE: .github/workflows/test.yml
================================================
name: Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11", "3.12"]

    steps:
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pytest
        pip install -r requirements.txt
    - name: Run tests
      run: |
        python -m pytest tests/test_image_crops.py -v

================================================
FILE: .gitignore
================================================
.venv
__pycache__
checkpoints
data
/pyproject.toml
poetry.lock
dist
clients/python/moondream/torch
wandb/
moondream_finetune.safetensors


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

================================================
FILE: README.md
================================================
# 🌔 moondream

a tiny vision language model that kicks ass and runs anywhere

[Website](https://moondream.ai/) | [Demo](https://moondream.ai/playground)

## Examples

| Image                  | Example                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               |
| ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| ![](assets/demo-1.jpg) | **What is the girl doing?**<br>The girl is sitting at a table and eating a large hamburger.<br><br>**What color is the girl's hair?**<br>The girl's hair is white.                                                                                                                                                                                                                                                                                                                                                                                                    |
| ![](assets/demo-2.jpg) | **What is this?**<br>This is a computer server rack, which is a device used to store and manage multiple computer servers. The rack is filled with various computer servers, each with their own dedicated space and power supply. The servers are connected to the rack via multiple cables, indicating that they are part of a larger system. The rack is placed on a carpeted floor, and there is a couch nearby, suggesting that the setup is in a living or entertainment area.<br><br>**What is behind the stand?**<br>Behind the stand, there is a brick wall. |

## About

Moondream is a highly efficient open-source vision language model that combines powerful image understanding capabilities with a remarkably small footprint. It's designed to be versatile and accessible, capable of running on a wide range of devices and platforms.

The project offers two model variants:

- **Moondream 2B**: The primary model with 2 billion parameters, offering robust performance for general-purpose image understanding tasks including captioning, visual question answering, and object detection.
- **Moondream 0.5B**: A compact 500 million parameter model specifically optimized as a distillation target for edge devices, enabling efficient deployment on resource-constrained hardware while maintaining impressive capabilities.

## How to use

Moondream can be run locally, or in the cloud. Please refer to the [Getting Started](https://moondream.ai/c/docs/quickstart) page for details.

## Special thanks

* [Modal](https://modal.com/?utm_source=github&utm_medium=github&utm_campaign=moondream) - Modal lets you run jobs in the cloud, by just writing a few lines of Python. Here's an [example of how to run Moondream on Modal](https://github.com/m87-labs/moondream-examples/tree/main/quickstart/modal).


================================================
FILE: batch_generate_example.py
================================================
from PIL import Image
from transformers import AutoTokenizer

from moondream.hf import LATEST_REVISION, Moondream, detect_device

device, dtype = detect_device()

model_id = "vikhyatk/moondream2"
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)
moondream = Moondream.from_pretrained(
    model_id,
    revision=LATEST_REVISION,
    torch_dtype=dtype,
).to(device=device)
moondream.eval()

image1 = Image.open("assets/demo-1.jpg")
image2 = Image.open("assets/demo-2.jpg")
prompts = [
    "What is the girl doing?",
    "What color is the girl's hair?",
    "What is this?",
    "What is behind the stand?",
]

answers = moondream.batch_answer(
    images=[image1, image1, image2, image2],
    prompts=prompts,
    tokenizer=tokenizer,
)

for question, answer in zip(prompts, answers):
    print(f"Q: {question}")
    print(f"A: {answer}")
    print()


================================================
FILE: gradio_demo.py
================================================
import argparse
import re
from threading import Thread

import gradio as gr
import torch
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize
from transformers import AutoTokenizer, TextIteratorStreamer

from moondream.hf import LATEST_REVISION, Moondream, detect_device

parser = argparse.ArgumentParser()
parser.add_argument("--cpu", action="store_true")
args = parser.parse_args()

if args.cpu:
    device = torch.device("cpu")
    dtype = torch.float32
else:
    device, dtype = detect_device()
    if device != torch.device("cpu"):
        print("Using device:", device)
        print("If you run into issues, pass the `--cpu` flag to this script.")
        print()

model_id = "vikhyatk/moondream2"
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)
moondream = Moondream.from_pretrained(
    model_id, revision=LATEST_REVISION, torch_dtype=dtype
).to(device=device)
moondream.eval()


def answer_question(img, prompt):
    image_embeds = moondream.encode_image(img)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    thread = Thread(
        target=moondream.answer_question,
        kwargs={
            "image_embeds": image_embeds,
            "question": prompt,
            "tokenizer": tokenizer,
            "streamer": streamer,
        },
    )
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer


def extract_floats(text):
    # Regular expression to match an array of four floating point numbers
    pattern = r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]"
    match = re.search(pattern, text)
    if match:
        # Extract the numbers and convert them to floats
        return [float(num) for num in match.groups()]
    return None  # Return None if no match is found


def extract_bbox(text):
    bbox = None
    if extract_floats(text) is not None:
        x1, y1, x2, y2 = extract_floats(text)
        bbox = (x1, y1, x2, y2)
    return bbox


def process_answer(img, answer):
    if extract_bbox(answer) is not None:
        x1, y1, x2, y2 = extract_bbox(answer)
        draw_image = Resize(768)(img)
        width, height = draw_image.size
        x1, x2 = int(x1 * width), int(x2 * width)
        y1, y2 = int(y1 * height), int(y2 * height)
        bbox = (x1, y1, x2, y2)
        ImageDraw.Draw(draw_image).rectangle(bbox, outline="red", width=3)
        return gr.update(visible=True, value=draw_image)

    return gr.update(visible=False, value=None)


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🌔 moondream
        """
    )
    with gr.Row():
        prompt = gr.Textbox(label="Input Prompt", value="Describe this image.", scale=4)
        submit = gr.Button("Submit")
    with gr.Row():
        img = gr.Image(type="pil", label="Upload an Image")
        with gr.Column():
            output = gr.Markdown(label="Response")
            ann = gr.Image(visible=False, label="Annotated Image")

    submit.click(answer_question, [img, prompt], output)
    prompt.submit(answer_question, [img, prompt], output)
    output.change(process_answer, [img, output], ann, show_progress=False)

demo.queue().launch(debug=True)


================================================
FILE: moondream/__init__.py
================================================


================================================
FILE: moondream/config/config_md05.json
================================================
{
    "text": {
        "dim": 1024,
        "ff_dim": 4096,
        "n_layers": 24,
        "vocab_size": 51200,
        "max_context": 2048,
        "n_heads": 16,
        "prefix_attn": 730
    },
    "vision": {
        "enc_dim": 720,
        "enc_patch_size": 14,
        "enc_n_layers": 27,
        "enc_ff_dim": 2690,
        "enc_n_heads": 10,
        "proj_out_dim": 1024,
        "crop_size": 378,
        "in_channels": 3,
        "max_crops": 12,
        "overlap_margin": 4,
        "proj_inner_dim": 8192
    },
    "region": {
        "dim": 1024,
        "coord_feat_dim": 256,
        "coord_out_dim": 1024,
        "size_feat_dim": 512,
        "size_out_dim": 2048,
        "inner_dim": 8192
    },
    "tokenizer": {
        "bos_id": 50256,
        "eos_id": 50256,
        "templates": {
            "caption": {
                "short": [
                    198,
                    198,
                    16438,
                    8305,
                    25
                ],
                "normal": [
                    198,
                    198,
                    24334,
                    1159,
                    25
                ]
            },
            "query": {
                "prefix": [
                    198,
                    198,
                    24361,
                    25
                ],
                "suffix": [
                    198,
                    198,
                    33706,
                    25
                ]
            },
            "detect": {
                "prefix": [
                    198,
                    198,
                    47504,
                    25
                ],
                "suffix": [
                    628
                ]
            },
            "point": {
                "prefix": [
                    198,
                    198,
                    12727,
                    25
                ],
                "suffix": [
                    628
                ]
            }
        }
    }
}

================================================
FILE: moondream/config/config_md2.json
================================================
{
    "text": {
        "dim": 2048,
        "ff_dim": 8192,
        "n_layers": 24,
        "vocab_size": 51200,
        "max_context": 2048,
        "n_heads": 32,
        "prefix_attn": 730
    },
    "vision": {
        "enc_dim": 1152,
        "enc_patch_size": 14,
        "enc_n_layers": 27,
        "enc_ff_dim": 4304,
        "enc_n_heads": 16,
        "proj_out_dim": 2048,
        "crop_size": 378,
        "in_channels": 3,
        "max_crops": 12,
        "overlap_margin": 4,
        "proj_inner_dim": 8192
    },
    "region": {
        "dim": 2048,
        "coord_feat_dim": 256,
        "coord_out_dim": 1024,
        "size_feat_dim": 512,
        "size_out_dim": 2048,
        "inner_dim": 8192
    },
    "tokenizer": {
        "bos_id": 50256,
        "eos_id": 50256,
        "templates": {
            "caption": {
                "short": [
                    198,
                    198,
                    16438,
                    8305,
                    25
                ],
                "normal": [
                    198,
                    198,
                    24334,
                    1159,
                    25
                ]
            },
            "query": {
                "prefix": [
                    198,
                    198,
                    24361,
                    25
                ],
                "suffix": [
                    198,
                    198,
                    33706,
                    25
                ]
            },
            "detect": {
                "prefix": [
                    198,
                    198,
                    47504,
                    25
                ],
                "suffix": [
                    628
                ]
            },
            "point": {
                "prefix": [
                    198,
                    198,
                    12727,
                    25
                ],
                "suffix": [
                    628
                ]
            }
        }
    }
}

================================================
FILE: moondream/eval/chartqa.py
================================================
import argparse
import datasets
import torch

from tqdm import tqdm
import json

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model

PREFIX = "Analyze the chart carefully, consider both visual features and data values, and provide a precise answer without any additional explanation or formatting. "


# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81
def relaxed_correctness(
    target: str, prediction: str, max_relative_change: float = 0.05
) -> bool:
    """Calculates relaxed correctness.

    The correctness tolerates certain error ratio defined by max_relative_change.
    See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:
    “Following Methani et al. (2020), we use a relaxed accuracy measure for the
    numeric answers to allow a minor inaccuracy that may result from the automatic
    data extraction process. We consider an answer to be correct if it is within
    5% of the gold answer. For non-numeric answers, we still need an exact match
    to consider an answer to be correct.”

    Args:
      target: Target string.
      prediction: Predicted string.
      max_relative_change: Maximum relative change.

    Returns:
      Whether the prediction was correct given the specified tolerance.
    """

    def _to_float(text):
        try:
            if text.endswith("%"):
                # Convert percentages to floats.
                return float(text.rstrip("%")) / 100.0
            else:
                return float(text)
        except ValueError:
            return None

    prediction = str(prediction)
    target = str(target)
    prediction_float = _to_float(prediction)
    target_float = _to_float(target)
    if prediction_float is not None and target_float:
        relative_change = abs(prediction_float - target_float) / abs(target_float)
        return relative_change <= max_relative_change
    else:
        return prediction == target


def eval_chartqa(model, debug=False):
    dataset = datasets.load_dataset("vikhyatk/chartqa", split="test")

    correct = 0
    total = 0
    human_correct = 0
    human_total = 0
    results = []

    for row in tqdm(dataset, disable=debug, desc="ChartQA"):
        image = row["image"]
        encoded_image = model.encode_image(image)

        result = []
        for qa in row["qa"]:
            question = PREFIX + qa["question"]
            answer = qa["answer"]
            model_answer = model.query(encoded_image, question)["answer"]

            # Attempt to parse both answers into lists, otherwise
            try:
                answer_list = json.loads(answer)
                model_answer_list = json.loads(model_answer)
                if not (
                    isinstance(answer_list, list)
                    and isinstance(model_answer_list, list)
                    and len(answer_list) == len(model_answer_list)
                ):
                    raise ValueError
            except:
                # If parsing fails or lengths are not equal, compare the strings directly instead
                answer_list = [answer]
                model_answer_list = [model_answer]

            total += 1
            if qa["source"] == "human":
                human_total += 1

            is_correct = False
            if all(
                relaxed_correctness(
                    str(cur_answer).strip().lower(),
                    str(cur_model_answer).strip().lower(),
                )
                for cur_answer, cur_model_answer in zip(answer_list, model_answer_list)
            ):
                correct += 1
                if qa["source"] == "human":
                    human_correct += 1
                is_correct = True
            if debug:
                print(
                    f"Correct: {correct}, Total: {total}, Human Correct: {human_correct}, Human Total: {human_total}"
                )
                print(f"Human Accuracy: {human_correct * 100 / human_total:.2f}")
                print(f"Total Accuracy: {correct * 100 / total:.2f}")
                print("---------")
            result.append(
                {
                    "question": question,
                    "ground_truth": answer_list,
                    "model_answer": model_answer_list,
                    "is_correct": is_correct,
                    "source": qa["source"],
                }
            )
        results.append(result)

    return {
        "human_acc": human_correct * 100 / human_total,
        "total_acc": correct * 100 / total,
        "results": results,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(args.model, model)
    model.compile()

    results = eval_chartqa(model, args.debug)
    print(f"Human Accuracy: {results['human_acc']:.2f}")
    print(f"Total Accuracy: {results['total_acc']:.2f}")


================================================
FILE: moondream/eval/coco_map.py
================================================
import argparse
import datasets
import torch
import json
import numpy as np

from typing import List, Tuple
from tqdm import tqdm

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model


coco_classes = [
    "None",
    "person",
    "bicycle",
    "car",
    "motorcycle",
    "airplane",
    "bus",
    "train",
    "truck",
    "boat",
    "traffic light",
    "fire hydrant",
    "street sign",
    "stop sign",
    "parking meter",
    "bench",
    "bird",
    "cat",
    "dog",
    "horse",
    "sheep",
    "cow",
    "elephant",
    "bear",
    "zebra",
    "giraffe",
    "hat",
    "backpack",
    "umbrella",
    "shoe",
    "eye glasses",
    "handbag",
    "tie",
    "suitcase",
    "frisbee",
    "skis",
    "snowboard",
    "sports ball",
    "kite",
    "baseball bat",
    "baseball glove",
    "skateboard",
    "surfboard",
    "tennis racket",
    "bottle",
    "plate",
    "wine glass",
    "cup",
    "fork",
    "knife",
    "spoon",
    "bowl",
    "banana",
    "apple",
    "sandwich",
    "orange",
    "broccoli",
    "carrot",
    "hot dog",
    "pizza",
    "donut",
    "cake",
    "chair",
    "couch",
    "potted plant",
    "bed",
    "mirror",
    "dining table",
    "window",
    "desk",
    "toilet",
    "door",
    "tv",
    "laptop",
    "mouse",
    "remote",
    "keyboard",
    "cell phone",
    "microwave",
    "oven",
    "toaster",
    "sink",
    "refrigerator",
    "blender",
    "book",
    "clock",
    "vase",
    "scissors",
    "teddy bear",
    "hair drier",
    "toothbrush",
    "hair brush",
]

COCO_LABELS = {}

for i, c in enumerate(coco_classes):
    COCO_LABELS[i] = c


def calculate_iou(
    box1: Tuple[float, float, float, float], box2: Tuple[float, float, float, float]
) -> float:
    """Calculate IoU between two boxes (x1, y1, x2, y2 format)"""
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

    return intersection / (box1_area + box2_area - intersection)


def calculate_map(
    ground_truth_boxes: List[List[Tuple[float, float, float, float]]],
    predicted_boxes: List[List[Tuple[float, float, float, float, float]]],
    iou_threshold: float = 0.5,
) -> float:
    """
    Calculate mAP for object detection

    Args:
        ground_truth_boxes: List of lists of ground truth boxes per image [(x1, y1, x2, y2)]
        predicted_boxes: List of lists of predicted boxes per image [(x1, y1, x2, y2, confidence)]
        iou_threshold: IoU threshold for considering a detection as correct

    Returns:
        mean Average Precision
    """
    total_precision = 0
    num_classes = len(ground_truth_boxes)

    for class_idx in range(num_classes):
        # Get all predictions and ground truths for this class
        gt_boxes = ground_truth_boxes[class_idx]
        pred_boxes = predicted_boxes[class_idx]

        # Sort predictions by confidence
        pred_boxes = sorted(pred_boxes, key=lambda x: x[4], reverse=True)

        # Initialize arrays for precision-recall calculation
        num_gt = len(gt_boxes)
        if num_gt == 0:
            continue

        tp = np.zeros(len(pred_boxes))
        fp = np.zeros(len(pred_boxes))
        gt_matched = [False] * num_gt

        # Match each prediction to ground truth
        for pred_idx, pred_box in enumerate(pred_boxes):
            max_iou = 0
            max_idx = -1

            # Find best matching ground truth box
            for gt_idx, gt_box in enumerate(gt_boxes):
                if gt_matched[gt_idx]:
                    continue

                iou = calculate_iou(pred_box[:4], gt_box)
                if iou > max_iou:
                    max_iou = iou
                    max_idx = gt_idx

            # If IoU exceeds threshold, count as true positive
            if max_iou >= iou_threshold:
                tp[pred_idx] = 1
                gt_matched[max_idx] = True
            else:
                fp[pred_idx] = 1

        # Calculate cumulative precision and recall
        cumsum_tp = np.cumsum(tp)
        cumsum_fp = np.cumsum(fp)
        recalls = cumsum_tp / num_gt
        precisions = cumsum_tp / (cumsum_tp + cumsum_fp)

        # Calculate average precision using all points
        ap = 0
        for t in np.arange(0, 1.1, 0.1):
            if np.sum(recalls >= t) == 0:
                p = 0
            else:
                p = np.max(precisions[recalls >= t])
            ap += p / 11

        total_precision += ap

    return total_precision / num_classes


def get_total_map(results_by_label, frequency_by_label):
    total_count = 0
    total_map = 0
    for results, frequency in zip(
        results_by_label.values(), frequency_by_label.values()
    ):
        cur_total_map = sum(results)
        total_map += cur_total_map
        total_count += frequency
    return total_map / total_count


def eval_coco_map(model, iou_threshold=0.5, debug=False):
    dataset = datasets.load_dataset(
        "moondream/coco-val-2017-bbox-cleaned", split="validation"
    )

    total = 0
    results_by_label = {}  # map to list of raw map results for each label
    frequency_by_label = {}  #  many images contain a given label

    for row in tqdm(dataset, disable=debug, desc="COCO mAP"):
        width = row["image"].width
        height = row["image"].height
        total += 1

        objects = json.loads(row["objects"])

        gt_label_to_boxes = {}

        for bbox, label in zip(objects["bbox"], objects["label"]):
            if label not in gt_label_to_boxes:
                gt_label_to_boxes[label] = []
            x1, y1, w, h = bbox
            gt_label_to_boxes[label].append((x1, y1, x1 + w, y1 + h))

        unique_labels = [label for label in set(objects["label"])]

        for label in unique_labels:

            encoded_image = model.encode_image(row["image"])
            model_answer = model.detect(encoded_image, COCO_LABELS[label])["objects"]

            moondream_boxes = []

            for box in model_answer:
                moondream_boxes.append(
                    (
                        box["x_min"] * width,
                        box["y_min"] * height,
                        box["x_max"] * width,
                        box["y_max"] * height,
                        1.0,  # Using default confidence of 1.0
                    )
                )
            map_result = calculate_map(
                [gt_label_to_boxes[label]], [moondream_boxes], iou_threshold
            )
            if debug and map_result == 0:
                print(
                    f"0 Map result for index {total} and label {label} ({COCO_LABELS[label]})"
                )

            if label not in results_by_label:
                results_by_label[label] = []
            results_by_label[label].append(map_result)

            if label not in frequency_by_label:
                frequency_by_label[label] = 0
            frequency_by_label[label] += 1

        if debug and total % 100 == 0:
            print(
                f"Total map: {get_total_map(results_by_label, frequency_by_label)*100:.2f}, ({total} images)"
            )

    return {
        # "results_by_label": results_by_label,
        # "frequency_by_label": frequency_by_label,
        "total_map": get_total_map(results_by_label, frequency_by_label),
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    # This repo doesn't have moondream deps we need
    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(args.model, model)
    model.compile()

    result = eval_coco_map(model, 0.5, args.debug)

    print(f"Overall MAP: {result['total_map']*100:.2f}")


================================================
FILE: moondream/eval/countbenchqa.py
================================================
import argparse
import datasets
import torch

from tqdm import tqdm

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model

PREFIX = "Look at the image carefully and count the objects. Answer with just a number, without any additional text. "


def eval_countbenchqa(model, debug=False):
    dataset = datasets.load_dataset("vikhyatk/CountBenchQA", split="test")

    correct = 0
    total = 0
    results = []

    for row in tqdm(dataset, disable=debug, desc="CountBenchQA"):
        image = row["image"]
        encoded_image = model.encode_image(image)

        question = PREFIX + row["question"]
        answer = str(row["number"])
        model_answer = model.query(encoded_image, question)["answer"]
        is_correct = model_answer.strip().lower() == answer.strip().lower()

        results.append(
            {
                "question": question,
                "ground_truth": answer,
                "model_answer": model_answer,
                "is_correct": is_correct,
            }
        )

        total += 1
        if is_correct:
            correct += 1
        elif debug:
            print(f"Question: {row['question']}")
            print(f"Answer: {answer}")
            print(f"Model Answer: {model_answer}")
        if debug:
            print(f"Correct: {correct}, Total: {total}")
            print(f"Accuracy: {correct * 100 / total:.2f}")
            print("---------")

    return {
        "acc": correct * 100 / total,
        "correct_count": correct,
        "total_count": total,
        "results": results,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(args.model, model)

    result = eval_countbenchqa(model, args.debug)

    print(f"Accuracy: {result['acc']:.2f}")
    print(f"Correct: {result['correct_count']}, Total: {result['total_count']}")


================================================
FILE: moondream/eval/docvqa.py
================================================
import argparse
import editdistance
from datasets import load_dataset
from tqdm import tqdm
import torch

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model

SUFFIX = " The answer should be a short text span taken verbatim from the document."


def get_anls(s1, s2):
    s1 = s1.lower().strip()
    s2 = s2.lower().strip()
    iou = 1 - editdistance.eval(s1, s2) / max(len(s1), len(s2))
    anls = iou if iou >= 0.5 else 0.0
    return anls


def eval_docvqa(model, debug=False):
    docvqa_val = load_dataset("vikhyatk/docvqa-val", split="validation")

    scores = []
    results = []

    for row in tqdm(docvqa_val, disable=debug, desc="DocVQA"):
        image = row["image"]
        encoded_image = model.encode_image(image)

        result = []
        for qa in row["qa"]:
            question = qa["question"]
            answers = qa["answers"]
            prompt = question + SUFFIX

            model_answer = model.query(encoded_image, prompt)["answer"]
            anls = max(get_anls(model_answer, gt) for gt in answers)
            scores.append(anls)
            result.append(
                {
                    "question": question,
                    "ground_truth": answers,
                    "model_answer": model_answer,
                    "anls": anls,
                }
            )

            if debug:
                print(f"Question: {question}")
                print(f"Ground Truth: {answers}")
                print(f"Model Answer: {model_answer}")
                print(f"ANLS: {anls}")
                print(f"Current Average ANLS: {sum(scores) / len(scores):.4f}")
                print("---------")
        results.append(result)

    return {
        "anls": sum(scores) / len(scores),
        "results": results,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(args.model, model)
    model.compile()

    result = eval_docvqa(model, args.debug)

    print(f"ANLS: {result['anls']:.4f}")


================================================
FILE: moondream/eval/eval_all.py
================================================
import argparse
import torch

from pprint import pprint

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model

from .countbenchqa import eval_countbenchqa
from .pope import evaluate_pope
from .realworldqa import eval_realworldqa
from .chartqa import eval_chartqa
from .textvqa import eval_textvqa
from .docvqa import eval_docvqa
from .mmstar import eval_mmstar
from .coco_map import eval_coco_map
from .naturalbench import eval_naturalbench
from .tallyqa import eval_tallyqa


def create_model(ckpt_path):
    config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(ckpt_path, model)
    model.compile()
    return model


def eval_all(model, skip=[]):
    evals = {
        "countbenchqa": eval_countbenchqa,
        "pope": evaluate_pope,
        "realworldqa": eval_realworldqa,
        "chartqa": eval_chartqa,
        "mmstar": eval_mmstar,
        "docvqa": eval_docvqa,
        "coco_map": eval_coco_map,
        "textvqa": eval_textvqa,
        "naturalbench": eval_naturalbench,
        "tallyqa": eval_tallyqa,
    }

    for b in skip:
        del evals[b]

    results = {}
    for name, eval_fn in evals.items():
        results[name] = eval_fn(model)
        pprint({k: v for k, v in results[name].items() if k != "results"})

    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    model = create_model(args.model)
    eval_all(model)


================================================
FILE: moondream/eval/gazefollow.py
================================================
import torch
import datasets
import math

from tqdm import tqdm

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model


def eval_gazefollow(model, debug=False):
    dataset = datasets.load_dataset("vikhyatk/gazefollow", split="test")

    mean_l2_error = []
    min_l2_error = []
    total = 0

    for i, row in tqdm(enumerate(dataset), total=len(dataset)):
        heads = []

        for gaze in row["gazes"]:
            head_bbox = gaze["head_bbox"]  # xmin, ymin, xmax, ymax
            eye_coord = (gaze["eye"]["x"], gaze["eye"]["y"])
            mean_target_gaze = (gaze["gaze"]["x"], gaze["gaze"]["y"])

            # Check if a head already exists with the same approximate bbox.
            # If so, use that head instead of creating a new one.
            for head in heads:
                if (
                    abs(head["head_bbox"]["xmin"] - head_bbox["xmin"]) < 0.001
                    and abs(head["head_bbox"]["xmax"] - head_bbox["xmax"]) < 0.001
                    and abs(head["head_bbox"]["ymin"] - head_bbox["ymin"]) < 0.001
                    and abs(head["head_bbox"]["ymax"] - head_bbox["ymax"]) < 0.001
                ):
                    head["gazes"].append(mean_target_gaze)
                    break
            else:
                heads.append(
                    {
                        "head_bbox": head_bbox,
                        "eye_coord": eye_coord,
                        "gazes": [mean_target_gaze],
                    }
                )

        for head in heads:
            pred_gaze = model.detect_gaze(
                row["image"],
                eye=head["eye_coord"],
                face={
                    "x_min": head["head_bbox"]["xmin"],
                    "y_min": head["head_bbox"]["ymin"],
                    "x_max": head["head_bbox"]["xmax"],
                    "y_max": head["head_bbox"]["ymax"],
                },
                unstable_settings={"force_detect": True},
            )["gaze"]

            mean_target_gaze = (
                sum(gaze[0] for gaze in head["gazes"]) / len(head["gazes"]),
                sum(gaze[1] for gaze in head["gazes"]) / len(head["gazes"]),
            )
            mean_l2 = math.sqrt(
                (mean_target_gaze[0] - pred_gaze["x"]) ** 2
                + (mean_target_gaze[1] - pred_gaze["y"]) ** 2
            )
            min_l2 = min(
                math.sqrt(
                    (target_gaze[0] - pred_gaze["x"]) ** 2
                    + (target_gaze[1] - pred_gaze["y"]) ** 2
                )
                for target_gaze in head["gazes"]
            )

            mean_l2_error.append(mean_l2)
            min_l2_error.append(min_l2)
            total += 1

            if i % 100 == 0 and debug:
                print("Mean L2 error:", sum(mean_l2_error) / total)
                print("Min L2 error:", sum(min_l2_error) / total)

    return {
        "mean_l2": sum(mean_l2_error) / total,
        "min_l2": sum(min_l2_error) / total,
    }


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)

    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(args.model, model)

    results = eval_gazefollow(model, debug=args.debug)

    print(f"Mean L2 error: {results['mean_l2']:.4f}")
    print(f"Min L2 error: {results['min_l2']:.4f}")


================================================
FILE: moondream/eval/mmstar.py
================================================
import datasets
import torch

from tqdm import tqdm

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model

SUFFIX = " Please answer directly with only the letter of the correct option and nothing else."


def eval_mmstar(model, debug=False):
    dataset = datasets.load_dataset("Lin-Chen/MMStar", split="val")

    correct = 0
    total = 0
    category_stats = {}
    results = []

    for row in tqdm(dataset, disable=debug, desc="MMStar"):
        image = row["image"]
        question = row["question"] + SUFFIX
        answer = row["answer"]
        model_answer = model.query(image, question)["answer"]
        is_correct = model_answer.strip().lower() == answer.strip().lower()

        category = f"{row['category']} / {row['l2_category']}"
        if category not in category_stats:
            category_stats[category] = {"correct": 0, "total": 0}

        total += 1
        category_stats[category]["total"] += 1

        results.append(
            {
                "question": question,
                "ground_truth": answer,
                "model_answer": model_answer,
                "is_correct": is_correct,
                "category": category,
            }
        )

        if is_correct:
            correct += 1
            category_stats[category]["correct"] += 1
        elif debug:
            print(f"Index: {row['index']}")
            print(f"Question: {row['question']}")
            print(f"Answer: {answer}")
            print(f"Model Answer: {model_answer}")
        if debug:
            print(f"Correct: {correct}, Total: {total}")
            print(f"Accuracy: {correct * 100 / total:.2f}")
            print("Results by category:")
            for category, stats in category_stats.items():
                acc = stats["correct"] * 100 / stats["total"]
                print(f"{category}: {stats['correct']}/{stats['total']} = {acc:.2f}%")
            print("---------")

    return {
        "acc": correct * 100 / total,
        "correct_count": correct,
        "total_count": total,
        "category_stats": category_stats,
        "results": results,
    }


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(args.model, model)
    model.compile()

    result = eval_mmstar(model, args.debug)

    print(f"Correct: {result['correct_count']}, Total: {result['total_count']}")
    print(f"Accuracy: {result['acc']:.2f}")

    print("\nResults by category:")
    for category, stats in result["category_stats"].items():
        acc = stats["correct"] * 100 / stats["total"]
        print(f"{category}: {stats['correct']}/{stats['total']} = {acc:.2f}%")


================================================
FILE: moondream/eval/naturalbench.py
================================================
from datasets import load_dataset
from tqdm import tqdm
import torch

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model


def eval_naturalbench(model, debug=False):
    # Yes, the benchmark test set is stored in the 'train' split...
    dataset = load_dataset("BaiqiL/NaturalBench", split="train")

    acc = []
    q_acc = []
    i_acc = []
    g_acc = []

    for row in tqdm(dataset, disable=debug, desc="NaturalBench"):
        if row["Question_Type"] == "yes_no":
            suffix = " Answer yes or no."
        else:
            suffix = ""

        images = [row["Image_0"], row["Image_1"], row["Image_0"], row["Image_1"]]
        prompts = [
            row["Question_0"] + suffix,
            row["Question_0"] + suffix,
            row["Question_1"] + suffix,
            row["Question_1"] + suffix,
        ]
        expected = [
            row["Image_0_Question_0"].strip().lower(),
            row["Image_1_Question_0"].strip().lower(),
            row["Image_0_Question_1"].strip().lower(),
            row["Image_0_Question_1"].strip().lower(),
        ]

        answers = []
        for img, prompt in zip(images, prompts):
            encoded_image = model.encode_image(img)
            answer = model.query(encoded_image, prompt)["answer"]
            answers.append(answer.strip().lower())

        if debug:
            for i, (q, a, e) in enumerate(zip(prompts, answers, expected)):
                print(f"Q{i}: {q}")
                print(f"Model: {a}")
                print(f"Expected: {e}")
                print(f"Correct: {a == e}")
                print("---")

        acc.append(answers[0] == expected[0])
        acc.append(answers[1] == expected[1])
        acc.append(answers[2] == expected[2])
        acc.append(answers[3] == expected[3])

        i_acc.append(answers[0] == expected[0] and answers[2] == expected[2])
        i_acc.append(answers[1] == expected[1] and answers[3] == expected[3])

        q_acc.append(answers[0] == expected[0] and answers[1] == expected[1])
        q_acc.append(answers[2] == expected[2] and answers[3] == expected[3])

        g_acc.append(
            answers[0] == expected[0]
            and answers[1] == expected[1]
            and answers[2] == expected[2]
            and answers[3] == expected[3]
        )

        if debug:
            print(f"Current Overall Accuracy: {sum(acc) / len(acc):.4f}")
            print(f"Current Image Accuracy: {sum(i_acc) / len(i_acc):.4f}")
            print(f"Current Question Accuracy: {sum(q_acc) / len(q_acc):.4f}")
            print(f"Current Group Accuracy: {sum(g_acc) / len(g_acc):.4f}")
            print("=========")

    return {
        "overall_acc": sum(acc) / len(acc),
        "image_acc": sum(i_acc) / len(i_acc),
        "question_acc": sum(q_acc) / len(q_acc),
        "group_acc": sum(g_acc) / len(g_acc),
    }


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(args.model, model)
    model.compile()

    results = eval_naturalbench(model, debug=args.debug)

    print(f"Overall Accuracy: {results['overall_acc']:.4f}")
    print(f"Image Accuracy: {results['image_acc']:.4f}")
    print(f"Question Accuracy: {results['question_acc']:.4f}")
    print(f"Group Accuracy: {results['group_acc']:.4f}")


================================================
FILE: moondream/eval/pope.py
================================================
import argparse
from datasets import load_dataset
from tqdm import tqdm
import torch

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model


def evaluate_pope(model, debug=False):
    pope_dataset = load_dataset("vikhyatk/POPE", split="test")

    stats = {
        "random": (0, 0),
        "popular": (0, 0),
        "adversarial": (0, 0),
    }

    for row in tqdm(pope_dataset, disable=debug, desc="POPE"):
        image = row["image"]
        encoded_image = model.encode_image(image)
        for split in ["adversarial", "popular", "random"]:
            for qa in row[split]:
                question = qa["question"]
                answer = qa["answer"]
                prompt = f"{question}\nAnswer yes or no."
                model_answer = model.query(encoded_image, prompt)["answer"].strip()

                if debug:
                    print(f"Split: {split}")
                    print(f"Question: {question}")
                    print(f"Model: {model_answer}")
                    print(f"Expected: {answer}")
                    print(f"Correct: {model_answer.lower() == answer.lower()}")
                    print("---")

                if model_answer.lower() == answer.lower():
                    stats[split] = (stats[split][0] + 1, stats[split][1] + 1)
                else:
                    stats[split] = (stats[split][0], stats[split][1] + 1)

                if debug:
                    for s in stats:
                        if stats[s][1] > 0:
                            print(
                                f"{s.capitalize()}: {stats[s][0]}/{stats[s][1]} = {stats[s][0] * 100.0 / stats[s][1]:.2f}%"
                            )
                    print("=========")

    return {
        "random": stats["random"][0] * 100.0 / stats["random"][1],
        "popular": stats["popular"][0] * 100.0 / stats["popular"][1],
        "adversarial": stats["adversarial"][0] * 100.0 / stats["adversarial"][1],
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(args.model, model)

    result = evaluate_pope(model, args.debug)

    print(f"Random Accuracy: {result['random']:.2f}")
    print(f"Popular Accuracy: {result['popular']:.2f}")
    print(f"Adversarial Accuracy: {result['adversarial']:.2f}")


================================================
FILE: moondream/eval/realworldqa.py
================================================
import argparse
import datasets
import torch

from tqdm import tqdm

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model


def eval_realworldqa(model, debug=False):
    dataset = datasets.load_dataset("lmms-lab/RealWorldQA", split="test")

    correct = 0
    total = 0
    results = []

    for row in tqdm(dataset, disable=debug, desc="RealWorldQA"):
        image = row["image"]
        question = row["question"]
        answer = row["answer"]
        model_answer = model.query(image, question)["answer"]
        is_correct = model_answer.strip().lower() == answer.strip().lower()

        results.append(
            {
                "question": question,
                "ground_truth": answer,
                "model_answer": model_answer,
                "is_correct": is_correct,
            }
        )

        total += 1
        if is_correct:
            correct += 1
        elif debug:
            print(f"Image: {row['image_path']}")
            print(f"Question: {question}")
            print(f"Answer: {answer}")
            print(f"Model Answer: {model_answer}")
        if debug:
            print(f"Correct: {correct}, Total: {total}")
            print(f"Accuracy: {correct * 100 / total:.2f}")
            print("---------")

    return {
        "acc": correct * 100 / total,
        "correct_count": correct,
        "total_count": total,
        "results": results,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(args.model, model)
    model.compile()

    result = eval_realworldqa(model, args.debug)

    print(f"Accuracy: {result['acc']:.2f}")
    print(f"Correct: {result['correct_count']} / {result['total_count']}")


================================================
FILE: moondream/eval/tallyqa.py
================================================
import argparse
import datasets
import torch

from tqdm import tqdm

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model

PREFIX = "Look at the image carefully and count the objects. Answer with just a number, without any additional text. "


def eval_tallyqa(model, debug=False):
    dataset = datasets.load_dataset(
        "vikhyatk/tallyqa-test",
        split="test",
        download_config=datasets.DownloadConfig(num_proc=16),
    )

    total = 0
    total_simple = 0
    correct = 0
    correct_simple = 0

    for row in tqdm(dataset, disable=args.debug):
        image = row["image"]
        encoded_image = model.encode_image(image)

        for qa in row["qa"]:
            question = PREFIX + qa["question"]
            answer = str(qa["answer"])
            is_simple = qa["is_simple"]

            model_answer = model.query(encoded_image, question)["answer"]

            total += 1
            if model_answer.strip().lower() == answer.strip().lower():
                correct += 1
            elif args.debug:
                print(f"Question: {qa['question']}")
                print(f"Answer: {answer}")
                print(f"Model Answer: {model_answer}")

            if is_simple:
                total_simple += 1
                if model_answer.strip().lower() == answer.strip().lower():
                    correct_simple += 1

            if args.debug:
                print(f"Simple - Correct: {correct_simple}, Total: {total_simple}")
                print(f"Simple Accuracy: {correct_simple * 100 / total_simple:.2f}")
                print(f"All - Correct: {correct}, Total: {total}")
                print(f"All Accuracy: {correct * 100 / total:.2f}")
                print("---------")

    return {
        "simple_acc": correct_simple * 100 / total_simple,
        "full_acc": correct * 100 / total,
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(args.model, model)
    model.compile()

    result = eval_tallyqa(model, args.debug)

    print(f"Simple acc: {result['simple_acc']:.2f}")
    print(f"Full acc: {result['full_acc']:.2f}")


================================================
FILE: moondream/eval/textvqa.py
================================================
import argparse
import datasets
import torch

from tqdm import tqdm

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model
from .utils import VQAScorer

PREFIX_TEXTVQA = "Read the text in the image and provide a brief lowercase answer. Respond 'unanswerable' only if there is no plausible answer. "


def eval_textvqa(model, debug=False):
    dataset = datasets.load_dataset("vikhyatk/textvqa_val", split="validation")

    scorer = VQAScorer()

    total_score = 0
    total_samples = 0
    results = []

    for row in tqdm(dataset, disable=debug, desc="TextVQA"):
        image = row["image"]
        encoded_image = model.encode_image(image)
        question = PREFIX_TEXTVQA + row["question"]
        model_answer = model.query(encoded_image, question)["answer"]

        score = scorer.compute_score(model_answer, row["answers"])
        total_score += score
        total_samples += 1

        results.append(
            {
                "question": question,
                "ground_truth": row["answers"],
                "model_answer": model_answer,
                "score": score,
            }
        )

        if debug:
            print(f"Question: {row['question']}")
            print(f"Ground Truth Answers: {row['answers']}")
            print(f"Model Answer: {model_answer}")
            print(f"Score: {score}")
            print(f"Running Average Score: {total_score * 100 / total_samples:.2f}")
            print("---------")

    return {"score": total_score * 100 / total_samples, "results": results}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.set_default_device("cuda")
    elif torch.backends.mps.is_available():
        torch.set_default_device("mps")

    config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(args.model, model)
    model.compile()

    result = eval_textvqa(model, args.debug)

    print(f"Score: {result['score']}")


================================================
FILE: moondream/eval/utils.py
================================================
import re
from typing import List


class VQAScorer:
    def __init__(self):
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            "somebody'd": "somebodyd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }

        self.manualMap = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }

        self.articles = ["a", "an", "the"]

        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]
        self.commaStrip = re.compile(r"(\d)(,)(\d)")
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")

    def process_punctuation(self, inText: str) -> str:
        outText = inText

        for p in self.punct:
            if (p + " " in inText or " " + p in inText) or (
                re.search(self.commaStrip, inText) is not None
            ):
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        outText = self.periodStrip.sub("", outText, re.UNICODE)
        return outText

    def process_digit_article(self, inText: str) -> str:
        outText = []
        tempText = inText.lower().split()
        for word in tempText:
            word = self.manualMap.setdefault(word, word)
            if word not in self.articles:
                outText.append(word)
        for wordId, word in enumerate(outText):
            if word in self.contractions:
                outText[wordId] = self.contractions[word]
        outText = " ".join(outText)
        return outText

    def process_answer(self, answer):
        answer = answer.replace("\n", " ")
        answer = answer.replace("\t", " ")
        answer = answer.strip()
        answer = self.process_punctuation(answer)
        answer = self.process_digit_article(answer)
        return answer

    def process_line(self, prediction: str, gt_answers: List[str]) -> float:
        gt_answers = [self.process_answer(x) for x in gt_answers]
        prediction = self.process_answer(prediction)
        matches = []
        for current_idx, gtAnsDatum in enumerate(gt_answers):
            otherGTAns = [
                item
                for ret_gt_idx, item in enumerate(gt_answers)
                if ret_gt_idx != current_idx
            ]
            matchingAns = [item for item in otherGTAns if item == prediction]
            acc = min(1, float(len(matchingAns)) / 3)
            matches.append(acc)

        return sum(matches) / len(matches)

    def compute_score(
        self, candidate_answer: str, ground_truth_answers: List[str]
    ) -> float:
        """
        Compute VQA score for a candidate answer against ground truth answers,
        exactly matching the VQAEval scoring logic
        """
        # Process candidate answer
        candidate = self.process_answer(candidate_answer)

        # Process ground truth answers
        processed_gts = []
        for gt in ground_truth_answers:
            gt = gt.replace("\n", " ")
            gt = gt.replace("\t", " ")
            gt = gt.strip()
            processed_gts.append(gt)

        # If there are multiple different answers, apply additional processing
        if len(set(processed_gts)) > 1:
            candidate = self.process_punctuation(candidate)
            candidate = self.process_digit_article(candidate)
            processed_gts = [
                self.process_punctuation(self.process_digit_article(gt))
                for gt in processed_gts
            ]

        # Count matches
        matching_answers = [1 for gt in processed_gts if gt == candidate]
        score = min(1.0, float(len(matching_answers)) / 3.0)

        return score


================================================
FILE: moondream/eval/waste_detection.py
================================================
import argparse
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
from PIL import Image
from tqdm import tqdm
from datasets import load_dataset

from ..torch.config import MoondreamConfig
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model


Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2) – in proportion form


def iou(a: Box, b: Box) -> float:
    """Corner-format IoU. Returns 0 when either box has zero area."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)

    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union else 0.0


def match(gt: List[Box], pr: List[Box], iou_thr: float) -> Tuple[int, int, int]:
    """
    Greedy one-to-one matching with no confidences.
    Predictions are taken in the order produced by the model.
    """
    tp = fp = 0
    seen = [False] * len(gt)

    for p in pr:
        best, best_i = 0.0, -1
        for i, g in enumerate(gt):
            if seen[i]:
                continue
            iou_ = iou(p, g)
            if iou_ > best:
                best, best_i = iou_, i
        if best >= iou_thr:
            tp += 1
            seen[best_i] = True
        else:
            fp += 1

    fn = len(gt) - tp
    return tp, fp, fn


class WasteDetection(torch.utils.data.Dataset):
    def __init__(self, name: str = "moondream/waste_detection", split: str = "test"):
        self.ds = load_dataset(name, split=split)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx: int) -> Dict:
        s = self.ds[idx]
        img = (
            s["image"]
            if isinstance(s["image"], Image.Image)
            else Image.fromarray(s["image"])
        )
        W, H = float(s.get("width", img.width)), float(s.get("height", img.height))

        lbl_to_boxes = defaultdict(list)
        for (xc, yc, bw, bh), lbl in zip(s["boxes"], s["labels"]):
            x1 = xc - bw / 2
            y1 = yc - bh / 2
            x2 = xc + bw / 2
            y2 = yc + bh / 2
            lbl_to_boxes[lbl].append((x1, y1, x2, y2))

        return {"image": img, "gt": lbl_to_boxes, "W": W, "H": H}


def evaluate(
    model: MoondreamModel,
    iou_thr: float,
    debug: bool,
):
    ds = WasteDetection(split="test")
    TP = FP = FN = 0

    for s in tqdm(ds, disable=debug, desc="Waste"):
        img, gts = s["image"], s["gt"]
        enc = model.encode_image(img)

        for lbl, gt_boxes in gts.items():
            preds: List[Box] = [
                (
                    o["x_min"],
                    o["y_min"],
                    o["x_max"],
                    o["y_max"],
                )
                for o in model.detect(enc, lbl)["objects"]
            ]
            tp, fp, fn = match(gt_boxes, preds, iou_thr)
            TP += tp
            FP += fp
            FN += fn

    prec = TP / (TP + FP) if TP + FP else 0.0
    rec = TP / (TP + FN) if TP + FN else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    return dict(precision=prec, recall=rec, f1=f1, tp=TP, fp=FP, fn=FN)


def load_model(path: str, device: torch.device) -> MoondreamModel:
    cfg = MoondreamConfig()
    model = MoondreamModel(cfg)
    load_weights_into_model(path, model)
    model.compile()
    model.to(device)
    return model


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True)
    p.add_argument("--iou_thr", type=float, default=0.5)
    p.add_argument("--gpu", type=int, default=0)
    p.add_argument("--debug", action="store_true")
    args = p.parse_args()

    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu)
        device = torch.device(f"cuda:{args.gpu}")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    model = load_model(args.model, device)
    res = evaluate(model, args.iou_thr, args.debug)

    print(f"Precision: {res['precision']*100:.2f}%")
    print(f"Recall: {res['recall']*100:.2f}%")
    print(f"F1 Score:  {res['f1']*100:.2f}%")
    print(f"TP: {res['tp']}  FP: {res['fp']}  FN: {res['fn']}")


if __name__ == "__main__":
    """
    Eval to accompany finetune_region.py.
    """
    main()


================================================
FILE: moondream/torch/config.py
================================================
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass(frozen=True)
class TextMoeConfig:
    num_experts: int = 64
    start_layer: int = 4
    experts_per_token: int = 8
    expert_inner_dim: int = 1024


@dataclass(frozen=True)
class TextConfig:
    dim: int = 2048
    ff_dim: int = 8192
    n_layers: int = 24
    vocab_size: int = 51200
    max_context: int = 4096
    n_heads: int = 32
    n_kv_heads: int = 32
    prefix_attn: int = 730
    group_size: Optional[int] = None
    moe: Optional[TextMoeConfig] = TextMoeConfig()


@dataclass(frozen=True)
class VisionConfig:
    enc_dim: int = 1152
    enc_patch_size: int = 14
    enc_n_layers: int = 27
    enc_ff_dim: int = 4304
    enc_n_heads: int = 16
    proj_out_dim: int = 2048
    crop_size: int = 378
    in_channels: int = 3
    max_crops: int = 12
    overlap_margin: int = 4
    proj_inner_dim: int = 8192


@dataclass(frozen=True)
class RegionConfig:
    dim: int = 2048
    coord_feat_dim: int = 256
    coord_out_dim: int = 1024
    size_feat_dim: int = 512
    size_out_dim: int = 2048
    group_size: Optional[int] = None


@dataclass(frozen=True)
class TokenizerConfig:
    bos_id: int = 0
    eos_id: int = 0
    answer_id: int = 3
    thinking_id: int = 4
    coord_id: int = 5
    size_id: int = 6
    start_ground_points_id: int = 7
    end_ground_id: int = 9
    templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
        default_factory=lambda: {
            "caption": {
                "short": [1, 32708, 2, 12492, 3],
                "normal": [1, 32708, 2, 6382, 3],
                "long": [1, 32708, 2, 4059, 3],
            },
            "query": {"prefix": [1, 15381, 2], "suffix": [3]},
            "detect": {"prefix": [1, 7235, 476, 2], "suffix": [3]},
            "point": {"prefix": [1, 2581, 2], "suffix": [3]},
        }
    )


@dataclass(frozen=True)
class MoondreamConfig:
    text: TextConfig = TextConfig()
    vision: VisionConfig = VisionConfig()
    region: RegionConfig = RegionConfig()
    tokenizer: TokenizerConfig = TokenizerConfig()

    @classmethod
    def from_dict(cls, config_dict: dict):
        text_config = TextConfig(**config_dict.get("text", {}))
        vision_config = VisionConfig(**config_dict.get("vision", {}))
        region_config = RegionConfig(**config_dict.get("region", {}))
        tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {}))
        return cls(
            text=text_config,
            vision=vision_config,
            region=region_config,
            tokenizer=tokenizer_config,
        )

    def to_dict(self):
        return {
            "text": self.text.__dict__,
            "vision": self.vision.__dict__,
            "region": self.region.__dict__,
            "tokenizer": self.tokenizer.__dict__,
        }


================================================
FILE: moondream/torch/hf_moondream.py
================================================
import torch
import torch.nn as nn

from transformers import PreTrainedModel, PretrainedConfig
from typing import Union

from .config import MoondreamConfig
from .moondream import MoondreamModel

# Files sometimes don't get loaded without these...
from .image_crops import *
from .vision import *
from .text import *
from .region import *
from .utils import *


def extract_question(text):
    prefix = "<image>\n\nQuestion: "
    suffix = "\n\nAnswer:"

    if text.startswith(prefix) and text.endswith(suffix):
        return text[len(prefix) : -len(suffix)]
    else:
        return None


class HfConfig(PretrainedConfig):
    _auto_class = "AutoConfig"
    model_type = "moondream3"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.config = {"skills": ["query", "caption", "detect", "point"]}


class HfMoondream(PreTrainedModel):
    _auto_class = "AutoModelForCausalLM"
    config_class = HfConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MoondreamModel(
            MoondreamConfig.from_dict(config.config), setup_caches=False
        )
        self._is_kv_cache_setup = False

    def _setup_caches(self):
        if not self._is_kv_cache_setup:
            self.model._setup_caches()
            self._is_kv_cache_setup = True

    @property
    def encode_image(self):
        self._setup_caches()
        return self.model.encode_image

    @property
    def query(self):
        self._setup_caches()
        return self.model.query

    @property
    def caption(self):
        self._setup_caches()
        return self.model.caption

    @property
    def detect(self):
        self._setup_caches()
        return self.model.detect

    @property
    def point(self):
        self._setup_caches()
        return self.model.point

    @property
    def detect_gaze(self):
        self._setup_caches()
        return self.model.detect_gaze

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer=None,
        chat_history="",
        result_queue=None,
        max_new_tokens=256,
        **kwargs
    ):
        answer = self.query(image_embeds, question)["answer"].strip()

        if result_queue is not None:
            result_queue.put(answer)
        return answer

    def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
        answers = []
        for image, prompt in zip(images, prompts):
            answers.append(self.query(image, prompt)["answer"].strip())
        return answers

    def _unsupported_exception(self):
        raise NotImplementedError(
            "This method is not supported in the latest version of moondream. "
            "Consider upgrading to the updated API spec, or alternately pin "
            "to 'revision=2024-08-26'."
        )

    def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
        """
        Function definition remains unchanged for backwards compatibility.
        Be aware that tokenizer, max_new_takens, and kwargs are ignored.
        """
        prompt_extracted = extract_question(prompt)
        if prompt_extracted is not None:
            answer = self.model.query(
                image=image_embeds, question=prompt_extracted, stream=False
            )["answer"]
        else:
            image_embeds = self.encode_image(image_embeds)
            prompt_tokens = torch.tensor(
                [self.model.tokenizer.encode(prompt).ids],
                device=self.device,
            )

            def generator():
                for token in self.model._generate_answer(
                    prompt_tokens,
                    image_embeds.kv_cache,
                    image_embeds.pos,
                    max_new_tokens,
                ):
                    yield token

            answer = "".join(list(generator()))

        return [answer]

    def get_input_embeddings(self) -> nn.Embedding:
        """
        Lazily wrap the raw parameter `self.model.text.wte` in a real
        `nn.Embedding` layer so that HF mix-ins recognise it.  The wrapper
        **shares** the weight tensor—no copy is made.
        """
        if not hasattr(self, "_input_embeddings"):
            self._input_embeddings = nn.Embedding.from_pretrained(
                self.model.text.wte,  # tensor created in text.py
                freeze=True,  # set to False if you need it trainable
            )
        return self._input_embeddings

    def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None:
        """
        Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the
        embeddings and keeps everything tied to `self.model.text.wte`.
        """
        # 1. point the low-level parameter to the new weight matrix
        self.model.text.wte = value.weight
        # 2. keep a reference for get_input_embeddings()
        self._input_embeddings = value

    def input_embeds(
        self,
        input_ids: Union[torch.LongTensor, list, tuple],
        *,
        device: torch.device | None = None
    ) -> torch.FloatTensor:
        """
        Back-compat wrapper that turns token IDs into embeddings.

        Example:
            ids = torch.tensor([[1, 2, 3]])
            embeds = model.input_embeds(ids)      # (1, 3, hidden_dim)
        """
        if not torch.is_tensor(input_ids):
            input_ids = torch.as_tensor(input_ids)
        if device is not None:
            input_ids = input_ids.to(device)

        return self.get_input_embeddings()(input_ids)


================================================
FILE: moondream/torch/hf_release.py
================================================
import torch
import argparse

from .weights import load_weights_into_model
from .hf_moondream import HfConfig, HfMoondream

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="vikhyatk/moondream-next")
    parser.add_argument("--ckpt", type=str, required=True)
    args = parser.parse_args()

    config = HfConfig()
    model = HfMoondream(config)
    load_weights_into_model(args.ckpt, model.model)

    model.push_to_hub(args.model_name, config=config)


================================================
FILE: moondream/torch/image_crops.py
================================================
import math
import numpy as np
import torch

from typing import TypedDict

try:
    import pyvips

    HAS_VIPS = True
except:
    from PIL import Image

    HAS_VIPS = False


def select_tiling(
    height: int, width: int, crop_size: int, max_crops: int
) -> tuple[int, int]:
    """
    Determine the optimal number of tiles to cover an image with overlapping crops.
    """
    if height <= crop_size or width <= crop_size:
        return (1, 1)

    # Minimum required tiles in each dimension
    min_h = math.ceil(height / crop_size)
    min_w = math.ceil(width / crop_size)

    # If minimum required tiles exceed max_crops, return proportional distribution
    if min_h * min_w > max_crops:
        ratio = math.sqrt(max_crops / (min_h * min_w))
        return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))

    # Perfect aspect-ratio tiles that satisfy max_crops
    h_tiles = math.floor(math.sqrt(max_crops * height / width))
    w_tiles = math.floor(math.sqrt(max_crops * width / height))

    # Ensure we meet minimum tile requirements
    h_tiles = max(h_tiles, min_h)
    w_tiles = max(w_tiles, min_w)

    # If we exceeded max_crops, scale down the larger dimension
    if h_tiles * w_tiles > max_crops:
        if w_tiles > h_tiles:
            w_tiles = math.floor(max_crops / h_tiles)
        else:
            h_tiles = math.floor(max_crops / w_tiles)

    return (max(1, h_tiles), max(1, w_tiles))


class OverlapCropOutput(TypedDict):
    crops: np.ndarray
    tiling: tuple[int, int]


def overlap_crop_image(
    image: np.ndarray,
    overlap_margin: int,
    max_crops: int,
    base_size: tuple[int, int] = (378, 378),
    patch_size: int = 14,
) -> OverlapCropOutput:
    """
    Process an image using an overlap-and-resize cropping strategy with margin handling.

    This function takes an input image and creates multiple overlapping crops with
    consistent margins. It produces:
    1. A single global crop resized to base_size
    2. Multiple overlapping local crops that maintain high resolution details
    3. A patch ordering matrix that tracks correspondence between crops

    The overlap strategy ensures:
    - Smooth transitions between adjacent crops
    - No loss of information at crop boundaries
    - Proper handling of features that cross crop boundaries
    - Consistent patch indexing across the full image

    Args:
        image (np.ndarray): Input image as numpy array with shape (H,W,C)
        base_size (tuple[int,int]): Target size for crops, default (378,378)
        patch_size (int): Size of patches in pixels, default 14
        overlap_margin (int): Margin size in patch units, default 4
        max_crops (int): Maximum number of crops allowed, default 12

    Returns:
        OverlapCropOutput: Dictionary containing:
            - crops: A numpy array containing the global crop of the full image (index 0)
                followed by the overlapping cropped regions (indices 1+)
            - tiling: Tuple of (height,width) tile counts
    """
    original_h, original_w = image.shape[:2]

    # Convert margin from patch units to pixels
    margin_pixels = patch_size * overlap_margin
    total_margin_pixels = margin_pixels * 2  # Both sides

    # Calculate crop parameters
    crop_patches = base_size[0] // patch_size  # patches per crop dimension
    crop_window_patches = crop_patches - (2 * overlap_margin)  # usable patches
    crop_window_size = crop_window_patches * patch_size  # usable size in pixels

    # Determine tiling
    tiling = select_tiling(
        original_h - total_margin_pixels,
        original_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    # Pre-allocate crops.
    n_crops = tiling[0] * tiling[1] + 1  # 1 = global crop
    crops = np.zeros(
        (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
    )

    # Resize image to fit tiling
    target_size = (
        tiling[0] * crop_window_size + total_margin_pixels,
        tiling[1] * crop_window_size + total_margin_pixels,
    )

    if HAS_VIPS:
        # Convert to vips for resizing
        vips_image = pyvips.Image.new_from_array(image)
        scale_x = target_size[1] / image.shape[1]
        scale_y = target_size[0] / image.shape[0]
        resized = vips_image.resize(scale_x, vscale=scale_y)
        image = resized.numpy()

        # Create global crop
        scale_x = base_size[1] / vips_image.width
        scale_y = base_size[0] / vips_image.height
        global_vips = vips_image.resize(scale_x, vscale=scale_y)
        crops[0] = global_vips.numpy()
    else:
        # Fallback to PIL
        pil_img = Image.fromarray(image)
        resized = pil_img.resize(
            (int(target_size[1]), int(target_size[0])),
            resample=Image.Resampling.LANCZOS,
        )
        image = np.asarray(resized)

        # Create global crop
        global_pil = pil_img.resize(
            (int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS
        )
        crops[0] = np.asarray(global_pil)

    for i in range(tiling[0]):
        for j in range(tiling[1]):
            # Calculate crop coordinates
            y0 = i * crop_window_size
            x0 = j * crop_window_size

            # Extract crop with padding if needed
            y_end = min(y0 + base_size[0], image.shape[0])
            x_end = min(x0 + base_size[1], image.shape[1])

            crop_region = image[y0:y_end, x0:x_end]
            crops[
                1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
            ] = crop_region

    return {"crops": crops, "tiling": tiling}


def reconstruct_from_crops(
    crops: torch.Tensor,
    tiling: tuple[int, int],
    overlap_margin: int,
    patch_size: int = 14,
) -> torch.Tensor:
    """
    Reconstruct the original image from overlapping crops into a single seamless image.

    Takes a list of overlapping image crops along with their positional metadata and
    reconstructs them into a single coherent image by carefully stitching together
    non-overlapping regions. Handles both numpy arrays and PyTorch tensors.

    Args:
        crops: List of image crops as numpy arrays or PyTorch tensors with shape
            (H,W,C)
        tiling: Tuple of (height,width) indicating crop grid layout
        patch_size: Size in pixels of each patch, default 14
        overlap_margin: Number of overlapping patches on each edge, default 4

    Returns:
        Reconstructed image as numpy array or PyTorch tensor matching input type,
        with shape (H,W,C) where H,W are the original image dimensions
    """
    tiling_h, tiling_w = tiling
    crop_height, crop_width = crops[0].shape[:2]
    margin_pixels = overlap_margin * patch_size

    # Calculate output size (only adding margins once)
    output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
    output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels

    reconstructed = torch.zeros(
        (output_h, output_w, crops[0].shape[2]),
        device=crops[0].device,
        dtype=crops[0].dtype,
    )

    for i, crop in enumerate(crops):
        tile_y = i // tiling_w
        tile_x = i % tiling_w

        # For each tile, determine which part to keep
        # Keep left margin only for first column
        x_start = 0 if tile_x == 0 else margin_pixels
        # Keep right margin only for last column
        x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
        # Keep top margin only for first row
        y_start = 0 if tile_y == 0 else margin_pixels
        # Keep bottom margin only for last row
        y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels

        # Calculate where this piece belongs in the output
        out_x = tile_x * (crop_width - 2 * margin_pixels)
        out_y = tile_y * (crop_height - 2 * margin_pixels)

        # Place the piece
        reconstructed[
            out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
        ] = crop[y_start:y_end, x_start:x_end]

    return reconstructed


================================================
FILE: moondream/torch/layers.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from typing import Literal, Optional

try:
    from torchao import quantize_
    from torchao.quantization import int4_weight_only
except ImportError:

    def quantize_(model, quant_mode):
        raise ImportError(
            "torchao is not installed. Please install it with `pip install torchao`."
        )

    def int4_weight_only(group_size):
        raise ImportError(
            "torchao is not installed. Please install it with `pip install torchao`."
        )


def gelu_approx(x):
    return F.gelu(x, approximate="tanh")


@dataclass
class LinearWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    return F.linear(x, w.weight, w.bias)


def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
    _step = W_q.shape[0]
    W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
    W_r[:_step] = (W_q & 0b11110000) >> 4
    W_r[_step:] = W_q & 0b00001111
    W_r.sub_(zero).mul_(scale)
    return W_r.reshape(orig_shape)


class QuantizedLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        dtype: torch.dtype,
    ):
        # TODO: Take group_size as an input instead of hardcoding it here.
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.ParameterDict(
            {
                "packed": nn.Parameter(
                    torch.empty(
                        out_features * in_features // (128 * 2), 128, dtype=torch.uint8
                    ),
                    requires_grad=False,
                ),
                "scale": nn.Parameter(
                    torch.empty(out_features * in_features // 128, 1),
                    requires_grad=False,
                ),
                "zero_point": nn.Parameter(
                    torch.empty(out_features * in_features // 128, 1),
                    requires_grad=False,
                ),
            }
        )
        self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
        self.unpacked = False

    def unpack(self):
        if self.unpacked:
            return

        self.weight = nn.Parameter(
            dequantize_tensor(
                self.weight["packed"],
                self.weight["scale"],
                self.weight["zero_point"],
                (self.out_features, self.in_features),
                torch.bfloat16,
            )
        )
        with torch.device("meta"):
            self.linear = nn.Linear(
                self.in_features, self.out_features, dtype=torch.bfloat16
            )
        self.linear.weight = self.weight
        self.linear.bias = nn.Parameter(
            self.bias.to(torch.bfloat16), requires_grad=False
        )

        del self.weight, self.bias
        quantize_(self, int4_weight_only(group_size=128))
        self.unpacked = True
        torch.cuda.empty_cache()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.unpacked:
            self.unpack()
        return self.linear(x)


@dataclass
class LayerNormWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)


@dataclass
class MLPWeights:
    fc1: LinearWeights
    fc2: LinearWeights
    act: Literal["gelu_approx"] = "gelu_approx"


def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:
    x0 = w.fc1(x)
    if lora is not None:
        x1 = F.linear(F.linear(x, lora["fc1"]["A"]), lora["fc1"]["B"])
        x = x0 + x1
    else:
        x = x0

    x = gelu_approx(x)

    x0 = w.fc2(x)
    if lora is not None:
        x1 = F.linear(F.linear(x, lora["fc2"]["A"]), lora["fc2"]["B"])
        x = x0 + x1
    else:
        x = x0

    return x


def moe_mlp(
    x: torch.Tensor, mlp_module: nn.Module, experts_per_token: int
) -> torch.Tensor:
    B, T, C = x.shape
    x = x.reshape(-1, C)

    # Router computation
    router_logits = mlp_module.router(x)
    topk_logits, topk_idxs = torch.topk(router_logits, experts_per_token, dim=-1)
    topk_weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32).to(x.dtype)
    num_tokens, top_k = topk_idxs.shape

    if T == 1:
        w1_weight = mlp_module.fc1.weight
        w2_weight = mlp_module.fc2.weight

        # Flatten to process all token-expert pairs at once
        flat_idxs = topk_idxs.view(-1)  # [T*A]
        flat_weights = topk_weights.view(-1)  # [T*A]

        # Select expert weights
        w1_selected = w1_weight[flat_idxs]  # [T*A, H, D]
        w2_selected = w2_weight[flat_idxs]  # [T*A, D, H]

        # Expand input for all token-expert pairs
        x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C)  # [T*A, D]

        # First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H]
        x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze(
            -1
        )  # [T*A, H]
        x1, g = x1_full.chunk(2, dim=-1)
        x1 = F.gelu(x1) * (g + 1)

        # Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D]
        expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1)  # [T*A, D]

        # Apply weights and reshape
        weighted_outs = expert_outs * flat_weights.unsqueeze(-1)  # [T*A, D]
        weighted_outs = weighted_outs.view(num_tokens, top_k, C)  # [T, A, D]

        # Sum over experts
        mlp_out = weighted_outs.sum(dim=1)  # [T, D]
        mlp_out = mlp_out.view(B, T, C)

        return mlp_out
    else:
        out = x.new_zeros(x.size())

        for expert_id in range(mlp_module.fc1.weight.shape[0]):
            token_pos, which_k = (topk_idxs == expert_id).nonzero(as_tuple=True)
            if token_pos.numel() == 0:
                continue

            x_tok = x.index_select(0, token_pos)
            gate_tok = topk_weights[token_pos, which_k]

            h_full = F.linear(x_tok, mlp_module.fc1.weight[expert_id])
            h, g = h_full.chunk(2, dim=-1)
            h = F.gelu(h) * (g + 1)
            y = F.linear(h, mlp_module.fc2.weight[expert_id])

            y.mul_(gate_tok.unsqueeze(-1))
            out.index_add_(0, token_pos, y)

        return out.view(B, T, C)


@dataclass
class AttentionWeights:
    qkv: LinearWeights
    proj: LinearWeights


def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in linear(x, w.qkv).chunk(3, dim=-1)
    ]
    out = F.scaled_dot_product_attention(q, k, v)
    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)
    out = linear(out, w.proj)
    return out


================================================
FILE: moondream/torch/lora.py
================================================
import functools
import os
import shutil
import torch

from pathlib import Path
from urllib.request import Request, urlopen
from typing import Optional


def variant_cache_dir():
    hf_hub_cache = os.environ.get("HF_HUB_CACHE")
    if hf_hub_cache is not None:
        return Path(hf_hub_cache) / "md_variants"

    hf_home = os.environ.get("HF_HOME")
    if hf_home is not None:
        return Path(hf_home) / "hub" / "md_variants"

    return Path("~/.cache/huggingface/hub").expanduser() / "md_variants"


def cached_variant_path(variant_id: str):
    variant, *rest = variant_id.split("/", 1)
    step = rest[0] if rest else "final"

    cache_dir = variant_cache_dir() / variant
    os.makedirs(cache_dir, exist_ok=True)
    dest = cache_dir / f"{step}.pt"
    if dest.exists():
        return dest

    md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai")

    headers = {"User-Agent": "moondream-torch"}
    api_key = os.getenv("MOONDREAM_API_KEY")
    if api_key is not None:
        headers["X-Moondream-Auth"] = api_key

    req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers)
    with urlopen(req) as r, open(dest, "wb") as f:
        shutil.copyfileobj(r, f)
    return dest


def nest(flat):
    tree = {}
    for k, v in flat.items():
        parts = k.split(".")
        d = tree
        for p in parts[:-1]:
            d = d.setdefault(p, {})
        d[parts[-1]] = v
    return tree


@functools.lru_cache(maxsize=5)
def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"):
    if variant_id is None:
        return None

    state_dict = torch.load(
        cached_variant_path(variant_id), map_location=device, weights_only=True
    )

    # TODO: Move these into the training code that saves checkpoints...
    rename_rules = [
        ("text_model.transformer.h", "text.blocks"),
        (".mixer", ".attn"),
        (".out_proj", ".proj"),
        (".Wqkv", ".qkv"),
        (".parametrizations.weight.0", ""),
    ]
    new_state_dict = {}
    for key, tensor in state_dict.items():
        new_key = key
        for old, new in rename_rules:
            if old in new_key:
                new_key = new_key.replace(old, new)
        new_state_dict[new_key] = tensor

    return nest(new_state_dict)


================================================
FILE: moondream/torch/moondream.py
================================================
import torch
import torch.nn as nn
import random

from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List
from PIL import Image
from dataclasses import dataclass
from tokenizers import Tokenizer
from torch.nn.attention.flex_attention import create_block_mask

from .config import MoondreamConfig
from .image_crops import reconstruct_from_crops
from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model
from .text import build_text_model, text_encoder, lm_head, text_decoder
from .region import (
    decode_coordinate,
    encode_coordinate,
    decode_size,
    encode_size,
    encode_spatial_refs,
    SpatialRefs,
)
from .layers import QuantizedLinear
from .lora import variant_state_dict
from .utils import remove_outlier_points

ImageEncodingSettings = TypedDict(
    "ImageEncodingSettings",
    {"variant": str},
    total=False,
)

TextSamplingSettings = TypedDict(
    "TextSamplingSettings",
    {
        "max_tokens": int,
        "temperature": float,
        "top_p": float,
        "variant": str,
    },
    total=False,
)

ObjectSamplingSettings = TypedDict(
    "ObjectSamplingSettings",
    {"max_objects": int, "variant": str},
    total=False,
)


DEFAULT_MAX_TOKENS = 768
DEFAULT_TEMPERATURE = 0.5
DEFAULT_TOP_P = 0.9
DEFAULT_MAX_OBJECTS = 150


@dataclass(frozen=True)
class EncodedImage:
    pos: int
    caches: List[Tuple[torch.Tensor, torch.Tensor]]


class KVCache(nn.Module):

    def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
        super().__init__()
        cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
        self.register_buffer(
            "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
        )
        self.register_buffer(
            "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
        )

    def update(self, pos_ids, k, v):
        kout, vout = self.k_cache, self.v_cache
        kout[:, :, pos_ids, :] = k
        vout[:, :, pos_ids, :] = v
        return kout, vout


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def get_mask_mod(mask_mod, offset):
    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + offset, kv)

    return _mask_mod


class MoondreamModel(nn.Module):

    def __init__(
        self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True
    ):
        super().__init__()
        self.config = config

        self.tokenizer = Tokenizer.from_pretrained("moondream/starmie-v1")
        self.vision = build_vision_model(config.vision, dtype)
        self.text = build_text_model(config.text, dtype)

        # Region Model
        linear_cls = (
            QuantizedLinear if config.region.group_size is not None else nn.Linear
        )
        self.region = nn.ModuleDict(
            {
                "coord_encoder": linear_cls(
                    config.region.coord_feat_dim, config.region.dim, dtype=dtype
                ),
                "coord_decoder": linear_cls(
                    config.region.dim, config.region.coord_out_dim, dtype=dtype
                ),
                "size_encoder": linear_cls(
                    config.region.size_feat_dim, config.region.dim, dtype=dtype
                ),
                "size_decoder": linear_cls(
                    config.region.dim, config.region.size_out_dim, dtype=dtype
                ),
            }
        )
        self.region.coord_features = nn.Parameter(
            torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
        )
        self.region.size_features = nn.Parameter(
            torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
        )

        attn_mask = torch.tril(
            torch.ones(
                1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
            )
        )
        patch_w = config.vision.crop_size // config.vision.enc_patch_size
        prefix_attn_len = 1 + patch_w**2
        attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
        self.register_buffer("attn_mask", attn_mask, persistent=False)

        self.use_flex_decoding = True
        self._causal_block_mask = None
        self._point_gen_indices = None

        # Initialize KV caches.
        if setup_caches:
            self._setup_caches()

    @property
    def causal_block_mask(self):
        # The things we do to deal with ZeroGPU...
        if self._causal_block_mask is None:
            self._causal_block_mask = create_block_mask(
                causal_mask,
                B=None,
                H=None,
                Q_LEN=self.config.text.max_context,
                KV_LEN=self.config.text.max_context,
            )
        return self._causal_block_mask

    @property
    def point_gen_indices(self):
        if self._point_gen_indices is None:
            self._point_gen_indices = torch.tensor(
                [self.config.tokenizer.coord_id, self.config.tokenizer.eos_id],
                device=self.device,
            )
        return self._point_gen_indices

    def _setup_caches(self):
        c = self.config.text
        for b in self.text.blocks:
            b.kv_cache = KVCache(
                c.n_heads,
                c.n_kv_heads,
                c.max_context,
                c.dim,
                device=self.device,
                dtype=self.vision.pos_emb.dtype,
            )

    @property
    def device(self):
        return self.vision.pos_emb.device

    def _vis_enc(self, x: torch.Tensor):
        return vision_encoder(x, self.vision, self.config.vision)

    def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
        return vision_projection(g, r, self.vision, self.config.vision)

    def _prefill(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        pos_ids: torch.Tensor,
        lora: Optional[torch.Tensor],
    ):
        return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)

    def _decode_one_tok(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        pos_ids: torch.Tensor,
        lora: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
    ):
        if self.use_flex_decoding:
            torch._assert(pos_ids.shape[-1] == 1, "Invalid position ID shape")
            block_index = pos_ids // self.causal_block_mask.BLOCK_SIZE[0]
            mask = self.causal_block_mask[:, :, block_index]
            mask.seq_lengths = (1, mask.seq_lengths[1])
            mask.mask_mod = get_mask_mod(self.causal_block_mask.mask_mod, pos_ids[0])
        else:
            mask = None

        hidden = text_decoder(
            x,
            self.text,
            attn_mask,
            pos_ids,
            self.config.text,
            lora=lora,
            flex_block_mask_slice=mask,
        )
        logits = lm_head(hidden, self.text, indices=lm_head_indices)
        return logits, hidden

    def compile(self):
        for module in self.modules():
            if isinstance(module, QuantizedLinear):
                module.unpack()

        # Initialize lazy properties to avoid first-call overhead
        self.causal_block_mask
        self.point_gen_indices

        # TODO: vision_projection and _prefill is not being compiled
        self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)
        self._decode_one_tok = torch.compile(
            self._decode_one_tok, fullgraph=True, mode="reduce-overhead"
        )

        # Warm up compiled methods with dummy forward passes
        device = self.device
        dtype = self.vision.pos_emb.dtype
        with torch.no_grad():
            # Warmup vision encoder
            dummy_crops = torch.randn(1, 3, 378, 378, device=device, dtype=dtype)
            self._vis_enc(dummy_crops)

            # Warmup _decode_one_tok (both normal and point generation modes)
            dummy_emb = torch.randn(
                1, 1, self.config.text.dim, device=device, dtype=dtype
            )
            dummy_mask = torch.ones(
                1, 1, self.config.text.max_context, device=device, dtype=torch.bool
            )
            dummy_pos_ids = torch.tensor([100], device=device, dtype=torch.long)
            self._decode_one_tok(dummy_emb, dummy_mask, dummy_pos_ids, None)
            self._decode_one_tok(
                dummy_emb,
                dummy_mask,
                dummy_pos_ids,
                None,
                lm_head_indices=self.point_gen_indices,
            )

    def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
        all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)

        torch._dynamo.mark_dynamic(all_crops, 0)

        outputs = self._vis_enc(all_crops)

        global_features = outputs[0]
        local_features = outputs[1:].view(
            -1,
            self.config.vision.enc_n_layers,
            self.config.vision.enc_n_layers,
            self.config.vision.enc_dim,
        )

        reconstructed = reconstruct_from_crops(
            local_features,
            tiling,
            patch_size=1,
            overlap_margin=self.config.vision.overlap_margin,
        )

        return self._vis_proj(global_features, reconstructed)

    def encode_image(
        self,
        image: Union[Image.Image, EncodedImage],
        settings: Optional[ImageEncodingSettings] = None,
    ) -> EncodedImage:
        if isinstance(image, EncodedImage):
            return image
        elif not isinstance(image, Image.Image):
            raise ValueError("image must be a PIL Image or EncodedImage")

        lora = (
            variant_state_dict(settings["variant"], device=self.device)
            if settings is not None and "variant" in settings
            else None
        )

        # Run through text model in addition to the vision encoder, to minimize
        # re-computation if multiple queries are performed on this image.
        with torch.inference_mode():
            img_emb = self._run_vision_encoder(image)
            bos_emb = text_encoder(
                torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),
                self.text,
            )
            inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
            mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
            pos_ids = torch.arange(
                inputs_embeds.size(1), dtype=torch.long, device=self.device
            )
            self._prefill(inputs_embeds, mask, pos_ids, lora)

        return EncodedImage(
            pos=inputs_embeds.size(1),
            caches=[
                (
                    b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
                    b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
                )
                for b in self.text.blocks
            ],
        )

    def _apply_top_p(self, probs: torch.Tensor, top_p: float):
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_probs = torch.zeros_like(probs)
        next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)
        return next_probs

    def _prefill_prompt(
        self,
        prompt_tokens: torch.Tensor,
        pos: int,
        temperature: float,
        top_p: float,
        spatial_refs: Optional[SpatialRefs] = None,
        attn_mask: Optional[torch.Tensor] = None,
        lora: Optional[dict] = None,
    ):
        with torch.inference_mode():
            prompt_emb = text_encoder(prompt_tokens, self.text)

            if spatial_refs:
                encoded_refs = encode_spatial_refs(spatial_refs, self.region)
                prompt_emb[prompt_tokens == self.config.tokenizer.coord_id] = (
                    encoded_refs["coords"]
                )
                if encoded_refs["sizes"] is not None:
                    prompt_emb[prompt_tokens == self.config.tokenizer.size_id] = (
                        encoded_refs["sizes"]
                    )

            torch._dynamo.mark_dynamic(prompt_emb, 1)

            if attn_mask is None:
                attn_mask = self.attn_mask

            mask = attn_mask[:, :, pos : pos + prompt_emb.size(1), :]
            pos_ids = torch.arange(
                pos, pos + prompt_emb.size(1), dtype=torch.long, device=self.device
            )
            hidden_BC = self._prefill(prompt_emb, mask, pos_ids, lora)
            logits_BV = lm_head(hidden_BC, self.text)

            if temperature == 0:
                next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1)
            else:
                probs = torch.softmax(logits_BV / temperature, dim=-1)
                probs = self._apply_top_p(probs, top_p)
                next_token = torch.multinomial(probs, num_samples=1)

        pos = pos + prompt_emb.size(1)
        return logits_BV, hidden_BC, next_token, pos

    def _generate_reasoning(
        self,
        prompt_tokens,
        pos,
        settings: Optional[TextSamplingSettings] = None,
        spatial_refs: Optional[SpatialRefs] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[int, str, List[dict]]:
        max_tokens = (
            settings.get("max_tokens", DEFAULT_MAX_TOKENS)
            if settings
            else DEFAULT_MAX_TOKENS
        )
        temperature = (
            settings.get("temperature", DEFAULT_TEMPERATURE)
            if settings
            else DEFAULT_TEMPERATURE
        )
        lora = (
            variant_state_dict(settings["variant"], device=self.device)
            if settings is not None and "variant" in settings
            else None
        )

        top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
        eos_id = self.config.tokenizer.answer_id

        _, last_hidden_BC, next_token, pos = self._prefill_prompt(
            prompt_tokens,
            pos,
            temperature,
            top_p,
            spatial_refs,
            attn_mask=attn_mask,
            lora=lora,
        )

        text_token_chunks = [[]]
        grounding_chunks = [[]]

        mask = torch.zeros(
            1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
        )
        mask[:, :, :pos] = 1
        pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
        generated_tokens = 0

        while (
            next_token_id := next_token.item()
        ) != eos_id and generated_tokens < max_tokens:
            if (
                next_token_id == self.config.tokenizer.start_ground_points_id
                or next_token_id == self.config.tokenizer.end_ground_id
            ):
                text_token_chunks.append([])
                grounding_chunks.append([])

            text_token_chunks[-1].append(next_token_id)

            with torch.inference_mode():
                if next_token_id == self.config.tokenizer.coord_id:
                    coord_logits = decode_coordinate(last_hidden_BC, self.region)
                    coord = torch.argmax(coord_logits, dim=-1) / coord_logits.size(-1)
                    grounding_chunks[-1].append(coord.item())

                    next_emb = encode_coordinate(
                        coord.to(dtype=coord_logits.dtype), self.region
                    ).unsqueeze(0)
                else:
                    next_emb = text_encoder(next_token, self.text)

                mask[:, :, pos], pos_ids[0] = 1, pos

                logits_BV, last_hidden_BC = self._decode_one_tok(
                    next_emb, mask, pos_ids, lora
                )
                logits_BV[:, self.config.tokenizer.eos_id] = float("-inf")
                logits_BV[:, self.config.tokenizer.size_id] = float("-inf")

                pos += 1

                if temperature == 0:
                    next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1)  # (1, 1)
                else:
                    probs = torch.softmax(logits_BV / temperature, dim=-1)  # (1, V)
                    probs = self._apply_top_p(probs, top_p)
                    next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)

                generated_tokens += 1

        text_chunks = [
            self.tokenizer.decode(chunk_tokens) for chunk_tokens in text_token_chunks
        ]
        text = "".join(text_chunks)

        start_idx = 0
        grounding = []
        for text_chunk, grounding_chunk in zip(text_chunks, grounding_chunks):
            if len(grounding_chunk) > 1:
                points = []
                for i in range(0, len(grounding_chunk) - (len(grounding_chunk) % 2), 2):
                    points.append((grounding_chunk[i], grounding_chunk[i + 1]))
                grounding.append(
                    {
                        "start_idx": start_idx,
                        "end_idx": start_idx + len(text_chunk),
                        "points": points,
                    }
                )
            start_idx += len(text_chunk)

        return pos, text, grounding

    def _generate_answer(
        self,
        prompt_tokens: torch.Tensor,
        pos: int,
        settings: Optional[TextSamplingSettings] = None,
        spatial_refs: Optional[SpatialRefs] = None,
        eos_id: Optional[int] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        max_tokens = (
            settings.get("max_tokens", DEFAULT_MAX_TOKENS)
            if settings
            else DEFAULT_MAX_TOKENS
        )
        temperature = (
            settings.get("temperature", DEFAULT_TEMPERATURE)
            if settings
            else DEFAULT_TEMPERATURE
        )
        top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P
        eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id
        lora = (
            variant_state_dict(settings["variant"], device=self.device)
            if settings is not None and "variant" in settings
            else None
        )

        _, _, next_token, pos = self._prefill_prompt(
            prompt_tokens,
            pos,
            temperature,
            top_p,
            spatial_refs,
            attn_mask=attn_mask,
            lora=lora,
        )

        def generator(next_token, pos):
            mask = torch.zeros(
                1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
            )
            mask[:, :, :pos] = 1
            pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)
            generated_tokens = 0

            # For properly handling token streaming with Unicode
            token_cache = []
            print_len = 0

            while (
                next_token_id := next_token.item()
            ) != eos_id and generated_tokens < max_tokens:
                # Add token to our cache
                token_cache.append(next_token_id)

                # Decode all tokens collected so far
                text = self.tokenizer.decode(token_cache)

                # After a newline, we flush the cache completely
                if text.endswith("\n"):
                    printable_text = text[print_len:]
                    token_cache = []
                    print_len = 0
                    if printable_text:
                        yield printable_text
                # If the last token is a CJK character, we can safely print it
                elif len(text) > 0 and _is_cjk_char(ord(text[-1])):
                    printable_text = text[print_len:]
                    print_len += len(printable_text)
                    if printable_text:
                        yield printable_text
                # Otherwise, only yield up to the last space to avoid cutting words
                else:
                    last_space_idx = text.rfind(" ", print_len)
                    if last_space_idx >= print_len:
                        printable_text = text[print_len : last_space_idx + 1]
                        print_len += len(printable_text)
                        if printable_text:
                            yield printable_text

                with torch.inference_mode():
                    next_emb = text_encoder(next_token, self.text)
                    mask[:, :, pos], pos_ids[0] = 1, pos

                    logits_BV, _ = self._decode_one_tok(next_emb, mask, pos_ids, lora)
                    logits_BV[:, self.config.tokenizer.answer_id] = float("-inf")

                    pos += 1

                    if temperature == 0:
                        next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(
                            1
                        )  # (1, 1)
                    else:
                        probs = torch.softmax(logits_BV / temperature, dim=-1)  # (1, V)
                        probs = self._apply_top_p(probs, top_p)
                        next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)

                    generated_tokens += 1

            # Flush any remaining text in the cache
            if token_cache:
                text = self.tokenizer.decode(token_cache)
                printable_text = text[print_len:]
                if printable_text:
                    yield printable_text

        return generator(next_token, pos)

    def query(
        self,
        image: Optional[Union[Image.Image, EncodedImage]] = None,
        question: str = None,
        reasoning: bool = True,
        spatial_refs: Optional[SpatialRefs] = None,
        stream: bool = False,
        settings: Optional[TextSamplingSettings] = None,
    ):
        if self.config.tokenizer.templates["query"] is None:
            raise NotImplementedError("Model does not support querying.")

        if question is None:
            raise ValueError("question must be provided.")

        if spatial_refs and image is None:
            raise ValueError("spatial_refs can only be used with an image.")

        attn_mask = self.attn_mask
        if image is not None:
            image = self.encode_image(image, settings)
            self.load_encoded_image(image)
            pos = image.pos
            prompt_toks = self.config.tokenizer.templates["query"]["prefix"]
        else:
            self._setup_caches()
            pos = 0
            prompt_toks = [
                self.config.tokenizer.bos_id
            ] + self.config.tokenizer.templates["query"]["prefix"]
            max_context = self.config.text.max_context
            attn_mask = torch.tril(
                torch.ones(1, 1, max_context, max_context, dtype=torch.bool)
            ).to(self.device)

        spatial_toks = []
        if spatial_refs:
            for ref in spatial_refs:
                coord_id = self.config.tokenizer.coord_id
                size_id = self.config.tokenizer.size_id
                if len(ref) == 2:
                    spatial_toks.extend([coord_id, coord_id])
                else:
                    spatial_toks.extend([coord_id, coord_id, size_id])

        prompt_tokens = [
            prompt_toks + spatial_toks + self.tokenizer.encode(question).ids
        ]

        if reasoning:
            prompt_tokens[0] += [self.config.tokenizer.thinking_id]
            prompt_tokens = torch.tensor(prompt_tokens, device=self.device)
            pos, reasoning_text, reasoning_grounding = self._generate_reasoning(
                prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
            )
            prompt_tokens = [self.config.tokenizer.templates["query"]["suffix"]]
            reasoning_dict = {
                "reasoning": {"text": reasoning_text, "grounding": reasoning_grounding}
            }
        else:
            prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"]
            reasoning_dict = {}

        prompt_tokens = torch.tensor(prompt_tokens, device=self.device)

        def generator():
            for token in self._generate_answer(
                prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask
            ):
                yield token

        if stream:
            return {**reasoning_dict, "answer": generator()}
        else:
            return {**reasoning_dict, "answer": "".join(list(generator()))}

    def load_encoded_image(self, encoded_image: EncodedImage):
        for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
            b.kv_cache.k_cache[:, :, : k.size(2), :] = k
            b.kv_cache.v_cache[:, :, : v.size(2), :] = v

    def caption(
        self,
        image: Union[Image.Image, EncodedImage],
        length: Literal["normal", "short", "long"] = "normal",
        stream: bool = False,
        settings: Optional[TextSamplingSettings] = None,
    ):
        if self.config.tokenizer.templates["caption"] is None:
            raise NotImplementedError("Model does not support captioning.")
        if length not in self.config.tokenizer.templates["caption"]:
            raise ValueError(f"Model does not support caption length '{length}'.")

        image = self.encode_image(image, settings)
        self.load_encoded_image(image)

        prompt_tokens = torch.tensor(
            [self.config.tokenizer.templates["caption"][length]], device=self.device
        )

        def generator():
            for token in self._generate_answer(prompt_tokens, image.pos, settings):
                yield token

        if stream:
            return {"caption": generator()}
        else:
            return {"caption": "".join(list(generator()))}

    def _generate_points(
        self,
        hidden: torch.Tensor,
        next_token: torch.Tensor,
        pos: int,
        include_size: bool = True,
        max_objects: int = DEFAULT_MAX_OBJECTS,
        lora: Optional[dict] = None,
    ):
        out = []
        mask = torch.zeros(
            1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool
        )
        mask[:, :, :pos] = 1
        pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)

        with torch.inference_mode():
            while (
                next_token.item() != self.config.tokenizer.eos_id
                and len(out) < max_objects
            ):
                x_logits = decode_coordinate(hidden, self.region)
                x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)
                next_emb = encode_coordinate(
                    x_center.to(dtype=x_logits.dtype), self.region
                ).unsqueeze(0)

                # Decode y-coordinate
                mask[:, :, pos], pos_ids[0] = 1, pos
                _, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
                pos += 1
                y_logits = decode_coordinate(hidden, self.region)
                y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1)
                next_emb = encode_coordinate(
                    y_center.to(dtype=y_logits.dtype), self.region
                ).unsqueeze(0)

                # Decode size
                if include_size:
                    mask[:, :, pos], pos_ids[0] = 1, pos
                    logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)
                    pos += 1
                    size_logits = decode_size(hidden, self.region)

                    # Get bin indices from the logits
                    w_bin = torch.argmax(size_logits[0], dim=-1)
                    h_bin = torch.argmax(size_logits[1], dim=-1)

                    # Convert from bin indices to actual size values using the inverse of the log-scale mapping
                    # Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0)
                    w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)
                    h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)

                    next_emb = (
                        encode_size(
                            torch.tensor(
                                [w, h], device=self.device, dtype=size_logits.dtype
                            ),
                            self.region,
                        )
                        .unsqueeze(0)
                        .unsqueeze(0)
                    )

                    # Add object
                    out.append(
                        {
                            "x_min": x_center.item() - w.item() / 2,
                            "y_min": y_center.item() - h.item() / 2,
                            "x_max": x_center.item() + w.item() / 2,
                            "y_max": y_center.item() + h.item() / 2,
                        }
                    )
                else:
                    out.append({"x": x_center.item(), "y": y_center.item()})

                # Decode next token (x-coordinate, or eos)
                mask[:, :, pos], pos_ids[0] = 1, pos
                logits, hidden = self._decode_one_tok(
                    next_emb,
                    mask,
                    pos_ids,
                    lora,
                    lm_head_indices=self.point_gen_indices,
                )
                pos += 1
                # Map back: index 0 -> coord_id, index 1 -> eos_id
                next_token_idx = torch.argmax(logits, dim=-1)
                next_token = self.point_gen_indices[next_token_idx]

        return out

    def detect(
        self,
        image: Union[Image.Image, EncodedImage],
        object: str,
        settings: Optional[ObjectSamplingSettings] = None,
    ):
        if self.config.tokenizer.templates["detect"] is None:
            raise NotImplementedError("Model does not support object detection.")

        image = self.encode_image(image, settings)
        self.load_encoded_image(image)

        prompt_tokens = torch.tensor(
            [
                self.config.tokenizer.templates["detect"]["prefix"]
                + self.tokenizer.encode(" " + object).ids
                + self.config.tokenizer.templates["detect"]["suffix"]
            ],
            device=self.device,
        )

        lora = (
            variant_state_dict(settings["variant"], device=self.device)
            if settings is not None and "variant" in settings
            else None
        )

        _, hidden, next_token, pos = self._prefill_prompt(
            prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
        )
        hidden = hidden[:, -1:, :]

        max_objects = (
            settings.get("max_objects", DEFAULT_MAX_OBJECTS)
            if settings
            else DEFAULT_MAX_OBJECTS
        )
        objects = self._generate_points(
            hidden,
            next_token,
            pos,
            include_size=True,
            max_objects=max_objects,
            lora=lora,
        )

        return {"objects": objects}

    def point(
        self,
        image: Union[Image.Image, EncodedImage],
        object: str,
        settings: Optional[ObjectSamplingSettings] = None,
    ):
        if self.config.tokenizer.templates["point"] is None:
            raise NotImplementedError("Model does not support pointing.")

        image = self.encode_image(image, settings)
        self.load_encoded_image(image)

        prompt_tokens = torch.tensor(
            [
                self.config.tokenizer.templates["point"]["prefix"]
                + self.tokenizer.encode(" " + object).ids
                + self.config.tokenizer.templates["point"]["suffix"]
            ],
            device=self.device,
        )

        lora = (
            variant_state_dict(settings["variant"], device=self.device)
            if settings is not None and "variant" in settings
            else None
        )

        _, hidden, next_token, pos = self._prefill_prompt(
            prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora
        )
        hidden = hidden[:, -1:, :]

        max_objects = (
            settings.get("max_objects", DEFAULT_MAX_OBJECTS)
            if settings
            else DEFAULT_MAX_OBJECTS
        )
        objects = self._generate_points(
            hidden,
            next_token,
            pos,
            include_size=False,
            max_objects=max_objects,
            lora=lora,
        )

        return {"points": objects}

    def _detect_gaze(
        self,
        image: EncodedImage,
        source: Tuple[float, float],
        force_detect: bool = False,
    ):
        with torch.inference_mode():
            before_emb = text_encoder(
                torch.tensor(
                    [self.tokenizer.encode("\n\nPoint:").ids], device=self.device
                ),
                self.text,
            )
            after_emb = text_encoder(
                torch.tensor(
                    [self.tokenizer.encode(" gaze\n\n").ids], device=self.device
                ),
                self.text,
            )
            x_emb = encode_coordinate(
                torch.tensor([[[source[0]]]], device=self.device, dtype=torch.bfloat16),
                self.region,
            )
            y_emb = encode_coordinate(
                torch.tensor([[[source[1]]]], device=self.device, dtype=torch.bfloat16),
                self.region,
            )

            prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1)

            self.load_encoded_image(image)

            mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :]
            pos_ids = torch.arange(
                image.pos,
                image.pos + prompt_emb.size(1),
                dtype=torch.long,
                device=self.device,
            )
            hidden = self._prefill(prompt_emb, mask, pos_ids, lora=None)
            logits = lm_head(hidden, self.text)
            next_token = torch.argmax(logits, dim=-1)
            pos = image.pos + prompt_emb.size(1)
            hidden = hidden[:, -1:, :]

            if force_detect:
                next_token = torch.tensor([[0]], device=self.device)

            if next_token.item() == self.config.tokenizer.eos_id:
                return None

            gaze = self._generate_points(
                hidden, next_token, pos, include_size=False, max_objects=1
            )
            return gaze[0]

    def detect_gaze(
        self,
        image: Union[Image.Image, EncodedImage],
        eye: Optional[Tuple[float, float]] = None,
        face: Optional[Dict[str, float]] = None,
        unstable_settings: Dict[str, Any] = {},
    ):
        if "force_detect" in unstable_settings:
            force_detect = unstable_settings["force_detect"]
        else:
            force_detect = False

        if "prioritize_accuracy" in unstable_settings:
            prioritize_accuracy = unstable_settings["prioritize_accuracy"]
        else:
            prioritize_accuracy = False

        if not prioritize_accuracy:
            if eye is None:
                raise ValueError("eye must be provided when prioritize_accuracy=False")
            image = self.encode_image(image)
            return {"gaze": self._detect_gaze(image, eye, force_detect=force_detect)}
        else:
            if (
                not isinstance(image, Image.Image)
                and "flip_enc_img" not in unstable_settings
            ):
                raise ValueError(
                    "image must be a PIL Image when prioritize_accuracy=True, "
                    "or flip_enc_img must be provided"
                )
            if face is None:
                raise ValueError("face must be provided when prioritize_accuracy=True")

            encoded_image = self.encode_image(image)
            if (
                isinstance(image, Image.Image)
                and "flip_enc_img" not in unstable_settings
            ):
                flipped_pil = image.copy()
                flipped_pil = flipped_pil.transpose(method=Image.FLIP_LEFT_RIGHT)
                encoded_flipped_image = self.encode_image(flipped_pil)
            else:
                encoded_flipped_image = unstable_settings["flip_enc_img"]

            N = 10

            detections = [
                self._detect_gaze(
                    encoded_image,
                    (
                        random.uniform(face["x_min"], face["x_max"]),
                        random.uniform(face["y_min"], face["y_max"]),
                    ),
                    force_detect=force_detect,
                )
                for _ in range(N)
            ]
            detections = [
                (gaze["x"], gaze["y"]) for gaze in detections if gaze is not None
            ]
            flipped_detections = [
                self._detect_gaze(
                    encoded_flipped_image,
                    (
                        1 - random.uniform(face["x_min"], face["x_max"]),
                        random.uniform(face["y_min"], face["y_max"]),
                    ),
                    force_detect=force_detect,
                )
                for _ in range(N)
            ]
            detections.extend(
                [
                    (1 - gaze["x"], gaze["y"])
                    for gaze in flipped_detections
                    if gaze is not None
                ]
            )

            if len(detections) < N:
                return {"gaze": None}

            detections = remove_outlier_points(detections)
            mean_gaze = (
                sum(gaze[0] for gaze in detections) / len(detections),
                sum(gaze[1] for gaze in detections) / len(detections),
            )

            return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}}


def _is_cjk_char(cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    if (
        (cp >= 0x4E00 and cp <= 0x9FFF)
        or (cp >= 0x3400 and cp <= 0x4DBF)
        or (cp >= 0x2F800 and cp <= 0x2FA1F)
    ):
        return True
    return False


================================================
FILE: moondream/torch/region.py
================================================
import torch
import torch.nn as nn
import math

from typing import List, Tuple, Union

SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]]


def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """
    Applies Fourier feature mapping to input tensor x using frequency matrix w. This
    projects inputs through sinusoidal functions to create higher dimensional features
    that help mitigate spectral bias - the tendency of neural networks to learn
    low-frequency functions more easily than high-frequency ones. By explicitly
    mapping inputs to higher frequencies through sin/cos transformations, we enable
    better learning of fine details and higher frequency patterns.

    Args:
        x: Input tensor to transform
        w: Matrix of frequencies for the Fourier features transformation

    Returns:
        Concatenated cosine and sine transformed features as a tensor
    """
    f = 2 * math.pi * x @ w
    return torch.cat([f.cos(), f.sin()], dim=-1)


def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """
    Takes as input a tensor containing a single float coordinate value (x or y)
    and encodes it into hidden states for input to the text model.

    Args:
        coord: Tensor with single float coordinate value

    Returns:
        Encoded hidden states tensor for input to text model
    """
    return w.coord_encoder(fourier_features(coord, w.coord_features))


def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """
    Takes as input the last hidden state from the text model and outputs a single logit
    representing either an x or y coordinate prediction.

    Args:
        hidden_state: The final hidden state tensor from the text model.

    Returns:
        A single logit representing the predicted coordinate value (x or y)
    """
    return w.coord_decoder(hidden_state)


def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """
    Takes a tensor containing width and height values and encodes them into
    hidden states for input to the text model.

    Args:
        size: Tensor with two floats for width and height

    Returns:
        Encoded hidden states tensor for input to text model
    """
    return w.size_encoder(fourier_features(size, w.size_features))


def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
    """
    Takes as input the last hidden state from the text model and outputs logits
    for 1024 bins representing width and height in log-scale.

    The bins are distributed according to the formula:
    bin = (log2(size) + 10.0) / 10.0 * 1023.0
    where size values are clamped to be at least 1/1024.

    To convert from bin back to size:
    size = 2^((bin / 1023.0) * 10.0 - 10.0)

    Args:
        hidden_state: The final hidden state tensor from the text model.

    Returns:
        A tensor containing logits for 1024 bins for width and height.
        Shape is (2, 1024) where the first dimension corresponds to width and height.
    """
    return w.size_decoder(hidden_state).view(2, -1)


def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torch.Tensor:
    """
    Takes a list of spatial references (points or regions) and encodes them into
    hidden states for input to the text model.

    Args:
        spatial_refs: List of spatial references (points or boxes)
            - Points are represented as normalized (x, y) tuples
            - Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples

    Returns:
        {"coords": torch.Tensor, "sizes": Optional[torch.Tensor]}
    """
    coords, sizes = [], []
    for ref in spatial_refs:
        if len(ref) == 2:
            coords.append(ref[0])
            coords.append(ref[1])
        else:
            x_c = (ref[0] + ref[2]) / 2
            y_c = (ref[1] + ref[3]) / 2
            width = ref[2] - ref[0]
            height = ref[3] - ref[1]
            coords.append(x_c)
            coords.append(y_c)
            sizes.append([width, height])

    coords = torch.tensor(
        coords, device=w.coord_features.device, dtype=w.coord_features.dtype
    ).view(-1, 1)
    coords = encode_coordinate(coords, w)

    if sizes:
        sizes = torch.tensor(
            sizes, device=w.size_features.device, dtype=w.size_features.dtype
        )
        sizes = encode_size(sizes, w)
    else:
        sizes = None

    return {"coords": coords, "sizes": sizes}


================================================
FILE: moondream/torch/rope.py
================================================
# Ethically sourced from https://github.com/xjdr-alt/entropix

import torch


def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 1500000.0,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))
    t = torch.arange(end, dtype=dtype).unsqueeze(1)
    freqs = t * freqs.unsqueeze(0)
    freqs = torch.exp(1j * freqs)
    return torch.stack([freqs.real, freqs.imag], dim=-1)


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
    position_ids: torch.Tensor,
    num_heads: int,
    rot_dim: int = 32,
    interleave: bool = False,
) -> torch.Tensor:
    assert rot_dim == freqs_cis.shape[-2] * 2
    assert num_heads == x.shape[1]

    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]

    if interleave:
        xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
        xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
    else:
        d_q = x_rot.shape[-1] // 2
        xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]

    freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0)
    freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)

    # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)

    return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)


================================================
FILE: moondream/torch/sample.py
================================================
import argparse
import json
import os
import torch

from PIL import Image, ImageDraw
from tqdm import tqdm

from .weights import load_weights_into_model
from .moondream import MoondreamModel, MoondreamConfig

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", "-i", type=str, required=True)
    parser.add_argument("--prompt", "-p", type=str, required=True)
    parser.add_argument("--model", "-m", type=str, required=True)
    parser.add_argument("--config", "-c", type=str, default=None)
    parser.add_argument("--max-tokens", "-t", type=int, default=200)
    parser.add_argument("--sampler", "-s", type=str, default="greedy")
    parser.add_argument("--benchmark", "-b", action="store_true")
    args = parser.parse_args()

    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    print(f"Using device: {device}")

    # Load model.
    if args.config is not None:
        with open(args.config, "r") as f:
            config = json.load(f)
        config = MoondreamConfig.from_dict(config)
    else:
        config = MoondreamConfig()
    model = MoondreamModel(config)
    load_weights_into_model(args.model, model)
    model.to(device, dtype=torch.bfloat16)
    model.compile()

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()

    # Encode image.
    image_path = args.image
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found at {image_path}")
    image = Image.open(image_path)

    if not args.benchmark:
        encoded_image = model.encode_image(image)

        # Text query
        text_query = "What is the capital of Washington, USA? Answer in JSON format."
        print("Query:", text_query)
        text_response = model.query(None, text_query, reasoning=True, stream=True)
        print("Reasoning:", text_response["reasoning"])
        for t in text_response["answer"]:
            print(t, end="", flush=True)
        print()
        print()

        # Short caption
        print("Caption: short")
        for t in model.caption(encoded_image, "short", stream=True)["caption"]:
            print(t, end="", flush=True)
        print()
        print()

        # Regular caption
        print("Caption: normal")
        for t in model.caption(encoded_image, "normal", stream=True)["caption"]:
            print(t, end="", flush=True)
        print()
        print()

        # Long caption
        print("Caption: long")
        for t in model.caption(encoded_image, "long", stream=True)["caption"]:
            print(t, end="", flush=True)
        print()
        print()

        # Query
        print("Query:", args.prompt)
        for t in model.query(
            encoded_image,
            args.prompt,
            stream=True,
            settings={"variant": "geoguesser_lora_only"},
        )["answer"]:
            print(t, end="", flush=True)
        print()
        print()

        # Query (reasoning)
        reasoning_prompt = "How many sesame seeds are on the burger?"
        print("Query (reasoning):", reasoning_prompt)
        resp = model.query(encoded_image, reasoning_prompt, reasoning=True, stream=True)
        print("Reasoning:", resp["reasoning"])
        for t in resp["answer"]:
            print(t, end="", flush=True)
        print()
        print()

        # Detect
        obj = "hand"
        print(f"Detect: {obj}")
        objs = model.detect(encoded_image, obj)["objects"]
        print(f"Found {len(objs)}")
        print()
        draw = ImageDraw.Draw(image)
        for obj in objs:
            x_min, y_min, x_max, y_max = (
                obj["x_min"] * image.width,
                obj["y_min"] * image.height,
                obj["x_max"] * image.width,
                obj["y_max"] * image.height,
            )
            draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=2)
        image.save("detect.jpg")

        # Spatial query
        if len(objs) > 0:
            print("Spatial query: What is this?")
            for t in model.query(
                encoded_image,
                "What is this?",
                spatial_refs=[
                    [
                        (obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"])
                        for obj in objs
                    ][0]
                ],
                stream=True,
            )["answer"]:
                print(t, end="", flush=True)
            print()
            print()

        # Point
        obj = "ear"
        print(f"Point: {obj}")
        points = model.point(encoded_image, obj)["points"]
        print(f"Found {len(points)}")
        draw = ImageDraw.Draw(image)
        for point in points:
            x, y = point["x"] * image.width, point["y"] * image.height
            draw.ellipse([x - 5, y - 5, x + 5, y + 5], fill="red")
        image.save("point.jpg")
        print()
        print()

        # Spatial query
        if len(objs) > 0:
            for o in ["hand", "ear", "face"]:
                for k in [(objs, "hand"), (points, "ear")]:
                    print(f"Spatial query: Is this a {o}? ({k[1]})")
                    for t in model.query(
                        encoded_image,
                        f"Is this a {o}?",
                        spatial_refs=[
                            [
                                (
                                    (
                                        obj["x_min"],
                                        obj["y_min"],
                                        obj["x_max"],
                                        obj["y_max"],
                                    )
                                    if "x_min" in obj
                                    else (obj["x"], obj["y"])
                                )
                                for obj in k[0]
                            ][0]
                        ],
                    )["answer"]:
                        print(t, end="", flush=True)
                    print()

        # Detect gaze
        model.detect_gaze(encoded_image, (0.5, 0.5))
    elif model.device.type != "mps":
        # Warmup runs
        for _ in tqdm(range(5), desc="Warmup"):
            encoded_image = model.encode_image(image)
            for _ in model.query(encoded_image, args.prompt, stream=True)["answer"]:
                pass

        # Benchmark runs
        encode_times = []
        query_speeds = []
        for i in tqdm(range(10), desc="Benchmark"):
            # Measure encode time
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            encoded_image = model.encode_image(image)
            end.record()
            torch.cuda.synchronize()
            encode_time = start.elapsed_time(end)
            encode_times.append(encode_time)

            # Measure query speed
            tokens = []
            query_start = torch.cuda.Event(enable_timing=True)
            query_end = torch.cuda.Event(enable_timing=True)
            query_start.record()
            for t in model.query(encoded_image, args.prompt, stream=True)["answer"]:
                tokens.append(t)
            query_end.record()
            torch.cuda.synchronize()
            query_time = query_start.elapsed_time(query_end)
            tokens_per_sec = len(tokens) / (query_time / 1000.0)  # Convert ms to s
            query_speeds.append(tokens_per_sec)

        # Print results
        print("\nBenchmark Results (10 runs):")
        print("Image Encoding Time (ms):")
        print(f"  Mean: {sum(encode_times)/len(encode_times):.2f}")
        print(f"  Min:  {min(encode_times):.2f}")
        print(f"  Max:  {max(encode_times):.2f}")
        print("\nQuery Speed (tokens/sec):")
        print(f"  Mean: {sum(query_speeds)/len(query_speeds):.2f}")
        print(f"  Min:  {min(query_speeds):.2f}")
        print(f"  Max:  {max(query_speeds):.2f}")

        print(torch.cuda.memory_summary(abbreviated=False))


================================================
FILE: moondream/torch/text.py
================================================
import torch
import torch.nn as nn

from torch.nn import functional as F
from torch.nn.attention.flex_attention import flex_attention
from typing import Optional

from .layers import layer_norm, mlp, QuantizedLinear, moe_mlp
from .rope import apply_rotary_emb, precompute_freqs_cis
from .config import TextConfig


def text_encoder(input_ids: torch.Tensor, w: nn.Module):
    return F.embedding(input_ids, w.wte)


def attn(
    x: torch.Tensor,
    w: nn.Module,
    freqs_cis: torch.Tensor,
    kv_cache: nn.Module,
    attn_mask: torch.Tensor,
    n_heads: int,
    n_kv_heads: int,
    position_ids: torch.Tensor,
    lora: Optional[dict] = None,
    flex_block_mask_slice=None,
):
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads

    qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
    if lora is not None:
        qkv_out += F.linear(F.linear(x, lora["qkv"]["A"]), lora["qkv"]["B"])
    q_dim = n_heads * head_dim
    kv_dim = n_kv_heads * head_dim
    q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)

    q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
    k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
    v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

    if hasattr(w, "tau") and w.tau is not None:
        tok_feat = F.gelu(qkv_out)
        tok_q = torch.tanh(torch.matmul(tok_feat, w.tau["wq"].t())).permute(0, 2, 1)
        tok_v = torch.tanh(torch.matmul(tok_feat, w.tau["wv"].t())).permute(0, 2, 1)
        pos = position_ids.to(q.dtype) + 1
        tau_pos = 1 + (
            torch.sigmoid(w.tau["alpha"][:, None] * pos.log()) - 0.5
        )  # (H,S)
        tau_q = (tok_q + tau_pos[None]).unsqueeze(-1)  # (B,H,S,1)
        tau_v = (tok_v + tau_pos[None]).unsqueeze(-1)
        q = q * tau_q
        v = v * tau_v

    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
    k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)

    if kv_cache is not None:
        k, v = kv_cache.update(position_ids, k, v)

    if flex_block_mask_slice is not None:
        torch._assert(n_heads == n_kv_heads, "gqa not supported yet")
        out = flex_attention(q, k, v, block_mask=flex_block_mask_slice)
    else:
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads
        )

    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)

    out0 = w.proj(out)
    if lora is not None:
        out1 = F.linear(F.linear(x, lora["proj"]["A"]), lora["proj"]["B"])
        out = out0 + out1
    else:
        out = out0

    return out


def text_decoder(
    x: torch.Tensor,
    w: nn.Module,
    attn_mask: torch.Tensor,
    position_ids: torch.Tensor,
    config: TextConfig,
    lora: Optional[dict] = None,
    flex_block_mask_slice=None,
):
    for i, block in enumerate(w.blocks):
        if lora is not None:
            layer_lora = lora["text"]["blocks"][str(i)]
            mlp_lora = layer_lora["mlp"]
            attn_lora = layer_lora["attn"]
        else:
            mlp_lora = None
            attn_lora = None

        l_in = layer_norm(x, block.ln)
        l_attn = attn(
            l_in,
            block.attn,
            freqs_cis=w.freqs_cis,
            kv_cache=block.kv_cache,
            attn_mask=attn_mask,
            n_heads=config.n_heads,
            n_kv_heads=config.n_kv_heads,
            position_ids=position_ids,
            lora=attn_lora,
            flex_block_mask_slice=flex_block_mask_slice,
        )

        if config.moe is not None and i >= config.moe.start_layer:
            l_mlp = moe_mlp(l_in, block.mlp, config.moe.experts_per_token)
        else:
            l_mlp = mlp(l_in, block.mlp, lora=mlp_lora)

        x = x + l_attn + l_mlp

    return x


def lm_head(
    hidden_BTC: torch.Tensor, w: nn.Module, indices: Optional[torch.Tensor] = None
):
    hidden_BC = hidden_BTC[:, -1, :]
    hidden_BC = layer_norm(hidden_BC, w.post_ln)
    if indices is not None:
        # Only compute logits for specified token indices
        logits = hidden_BC @ w.lm_head.weight[indices].T + w.lm_head.bias[indices]
    else:
        logits = w.lm_head(hidden_BC)
    return logits


def build_dense_mlp(d_model, d_ffn, dtype, linear_cls):
    return nn.ModuleDict(
        {
            "fc1": linear_cls(d_model, d_ffn, dtype=dtype),
            "fc2": linear_cls(d_ffn, d_model, dtype=dtype),
        }
    )


def build_moe_mlp(d_model, d_ffn, n_experts, dtype):
    # For GeGLU, fc1 needs to output 2 * d_ffn (for gating)
    return nn.ModuleDict(
        {
            "router": nn.Linear(d_model, n_experts, dtype=dtype),
            "fc1": nn.ParameterDict(
                {
                    "weight": nn.Parameter(
                        torch.empty(n_experts, 2 * d_ffn, d_model, dtype=dtype)
                    )
                }
            ),
            "fc2": nn.ParameterDict(
                {
                    "weight": nn.Parameter(
                        torch.empty(n_experts, d_model, d_ffn, dtype=dtype)
                    )
                }
            ),
        }
    )


def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
    qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))
    linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear

    text = nn.ModuleDict(
        {
            "blocks": nn.ModuleList(
                [
                    nn.ModuleDict(
                        {
                            "ln": nn.LayerNorm(config.dim, dtype=dtype),
                            "attn": nn.ModuleDict(
                                {
                                    "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype),
                                    "proj": linear_cls(
                                        config.dim, config.dim, dtype=dtype
                                    ),
                                    "tau": nn.ParameterDict(
                                        {
                                            "wq": nn.Parameter(
                                                torch.empty(
                                                    config.n_heads, qkv_dim, dtype=dtype
                                                )
                                            ),
                                            "wv": nn.Parameter(
                                                torch.empty(
                                                    config.n_heads, qkv_dim, dtype=dtype
                                                )
                                            ),
                                            "alpha": nn.Parameter(
                                                torch.empty(config.n_heads, dtype=dtype)
                                            ),
                                        }
                                    ),
                                }
                            ),
                            "mlp": (
                                build_moe_mlp(
                                    config.dim,
                                    config.moe.expert_inner_dim,
                                    config.moe.num_experts,
                                    dtype,
                                )
                                if config.moe is not None
                                and layer_idx >= config.moe.start_layer
                                else build_dense_mlp(
                                    config.dim, config.ff_dim, dtype, linear_cls
                                )
                            ),
                        }
                    )
                    for layer_idx in range(config.n_layers)
                ]
            ),
            "post_ln": nn.LayerNorm(config.dim, dtype=dtype),
            "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype),
        }
    )
    text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))
    text.register_buffer(
        "freqs_cis",
        precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),
        persistent=False,
    )

    return text


================================================
FILE: moondream/torch/utils.py
================================================
import numpy as np


def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):
    """
    Robust outlier detection for list of (x,y) tuples.
    Only requires numpy.

    Args:
        points_tuples: list of (x,y) tuples
        k_nearest: number of neighbors to consider
        threshold: multiplier for median distance

    Returns:
        list: filtered list of (x,y) tuples with outliers removed
        list: list of booleans indicating which points were kept (True = kept)
    """
    points = np.array(points_tuples)
    n_points = len(points)

    # Calculate pairwise distances manually
    dist_matrix = np.zeros((n_points, n_points))
    for i in range(n_points):
        for j in range(i + 1, n_points):
            # Euclidean distance between points i and j
            dist = np.sqrt(np.sum((points[i] - points[j]) ** 2))
            dist_matrix[i, j] = dist
            dist_matrix[j, i] = dist

    # Get k nearest neighbors' distances
    k = min(k_nearest, n_points - 1)
    neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k]
    avg_neighbor_dist = np.mean(neighbor_distances, axis=1)

    # Calculate mask using median distance
    median_dist = np.median(avg_neighbor_dist)
    mask = avg_neighbor_dist <= threshold * median_dist

    # Return filtered tuples and mask
    filtered_tuples = [t for t, m in zip(points_tuples, mask) if m]
    return filtered_tuples


================================================
FILE: moondream/torch/vision.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from typing import Union, Tuple
from PIL import Image

from .layers import attn, layer_norm, mlp
from .image_crops import overlap_crop_image
from .config import VisionConfig

if torch.backends.mps.is_available():
    # Non-divisible input sizes are not implemented on MPS device yet.
    # https://github.com/pytorch/pytorch/issues/96056
    def adaptive_avg_pool2d(input, output_size):
        return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to("mps")

else:
    adaptive_avg_pool2d = F.adaptive_avg_pool2d

DeviceLike = Union[str, torch.device, int]


def prepare_crops(
    image: Image.Image, config: VisionConfig, device: DeviceLike
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    np_image = np.array(image.convert("RGB"))
    overlap_crops = overlap_crop_image(
        np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
    )
    all_crops = overlap_crops["crops"]
    all_crops = np.transpose(all_crops, (0, 3, 1, 2))
    all_crops = (
        torch.from_numpy(all_crops)
        .to(device=device, dtype=torch.bfloat16)
        .div_(255.0)
        .sub_(0.5)
        .div_(0.5)
    )
    return all_crops, overlap_crops["tiling"]


def create_patches(x, patch_size):
    # Original shape: [B, C, H, W]
    B, C, H, W = x.shape
    P1 = P2 = patch_size

    # Step 1: Split H and W dimensions into patches
    # [B, C, H/P1, P1, W/P2, P2]
    x = x.reshape(B, C, H // P1, P1, W // P2, P2)

    # Step 2: Rearrange dimensions to match target shape
    # [B, H/P1, W/P2, C, P1, P2]
    x = x.permute(0, 2, 4, 1, 3, 5)

    # Step 3: Combine dimensions to get final shape
    # [B, (H/P1)*(W/P2), C*P1*P2]
    x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)

    return x


def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
    x = create_patches(input_BCHW, config.enc_patch_size)

    x = w.patch_emb(x)
    x = x + w.pos_emb
    for block in w.blocks:
        x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)
        x = x + mlp(layer_norm(x, block.ln2), block.mlp)
    x = layer_norm(x, w.post_ln)

    return x


def vision_projection(
    global_features: torch.Tensor,
    reconstructed: torch.Tensor,
    w: nn.Module,
    config: VisionConfig,
):
    reconstructed = reconstructed.permute(2, 0, 1)
    reconstructed = adaptive_avg_pool2d(
        reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers)
    )
    reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim)
    final_features = torch.cat([global_features, reconstructed], dim=-1)
    return mlp(final_features, w.proj_mlp)


def build_vision_model(config: VisionConfig, dtype: torch.dtype):
    patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
    grid_size = config.crop_size // config.enc_patch_size
    num_patches = grid_size * grid_size

    vision = nn.ModuleDict(
        {
            "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype),
            "blocks": nn.ModuleList(
                [
                    nn.ModuleDict(
                        {
                            "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype),
                            "attn": nn.ModuleDict(
                                {
                                    "qkv": nn.Linear(
                                        config.enc_dim, 3 * config.enc_dim, dtype=dtype
                                    ),
                                    "proj": nn.Linear(
                                        config.enc_dim, config.enc_dim, dtype=dtype
                                    ),
                                }
                            ),
                            "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype),
                            "mlp": nn.ModuleDict(
                                {
                                    "fc1": nn.Linear(
                                        config.enc_dim, config.enc_ff_dim, dtype=dtype
                                    ),
                                    "fc2": nn.Linear(
                                        config.enc_ff_dim, config.enc_dim, dtype=dtype
                                    ),
                                }
                            ),
                        }
                    )
                    for _ in range(config.enc_n_layers)
                ]
            ),
            "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype),
            "proj_mlp": nn.ModuleDict(
                {
                    "fc1": nn.Linear(
                        config.enc_dim * 2, config.proj_inner_dim, dtype=dtype
                    ),
                    "fc2": nn.Linear(
                        config.proj_inner_dim, config.proj_out_dim, dtype=dtype
                    ),
                }
            ),
        }
    )
    vision.pos_emb = nn.Parameter(
        torch.zeros(1, num_patches, config.enc_dim, dtype=dtype)
    )
    return vision


================================================
FILE: moondream/torch/weights.py
================================================
import safetensors
import torch
import torch.nn as nn

from contextlib import contextmanager
from typing import Callable, List


@contextmanager
def safetensors_open(safetensors_file: str):
    """
    Simplify interfacing with safetensors files. Eliminates the need to ignore
    type errors when using the `safe_open` function.
    """
    with safetensors.safe_open(
        safetensors_file, framework="pt"
    ) as st:  # pyright: ignore

        def get_tensor(name: str) -> torch.Tensor:
            return st.get_tensor(name)

        def get_keys() -> List[str]:
            return st.keys()

        get_tensor.keys = get_keys

        yield get_tensor


def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:
    """Internal function to load weights using a tensor getter function."""
    model = model.to(dtype=torch.bfloat16)

    vision = model.vision
    region = model.region
    weight_map = {
        "vision_encoder.encoder.model.visual.patch_embed.linear.weight": vision[
            "patch_emb"
        ].weight,
        "vision_encoder.encoder.model.visual.patch_embed.linear.bias": vision[
            "patch_emb"
        ].bias,
        "vision_encoder.encoder.model.visual.pos_embed": vision.pos_emb,
        "vision_encoder.encoder.model.visual.norm.weight": vision["post_ln"].weight,
        "vision_encoder.encoder.model.visual.norm.bias": vision["post_ln"].bias,
        "vision_encoder.projection.mlp.fc1.weight": vision["proj_mlp"]["fc1"].weight,
        "vision_encoder.projection.mlp.fc1.bias": vision["proj_mlp"]["fc1"].bias,
        "vision_encoder.projection.mlp.fc2.weight": vision["proj_mlp"]["fc2"].weight,
        "vision_encoder.projection.mlp.fc2.bias": vision["proj_mlp"]["fc2"].bias,
        "text_model.transformer.embd.wte.weight": model.text.wte,
        "text_model.lm_head.ln.weight": model.text["post_ln"].weight,
        "text_model.lm_head.ln.bias": model.text["post_ln"].bias,
        "text_model.lm_head.linear.weight": model.text["lm_head"].weight,
        "text_model.lm_head.linear.bias": model.text["lm_head"].bias,
        "region_model.coordinate_encoder.weight": region["coord_encoder"].weight,
        "region_model.coordinate_encoder.bias": region["coord_encoder"].bias,
        "region_model.coordinate_head.weight": region["coord_decoder"].weight,
        "region_model.coordinate_head.bias": region["coord_decoder"].bias,
        "region_model.size_encoder.weight": region["size_encoder"].weight,
        "region_model.size_encoder.bias": region["size_encoder"].bias,
        "region_model.size_head.weight": region["size_decoder"].weight,
        "region_model.size_head.bias": region["size_decoder"].bias,
    }

    for i in range(len(model.vision["blocks"])):
        prefix = f"vision_encoder.encoder.model.visual.blocks.{i}"
        blk = model.vision["blocks"][i]
        weight_map.update(
            {
                f"{prefix}.norm1.weight": blk["ln1"].weight,
                f"{prefix}.norm1.bias": blk["ln1"].bias,
                f"{prefix}.norm2.weight": blk["ln2"].weight,
                f"{prefix}.norm2.bias": blk["ln2"].bias,
                f"{prefix}.attn.qkv.weight": blk["attn"]["qkv"].weight,
                f"{prefix}.attn.qkv.bias": blk["attn"]["qkv"].bias,
                f"{prefix}.attn.proj.weight": blk["attn"]["proj"].weight,
                f"{prefix}.attn.proj.bias": blk["attn"]["proj"].bias,
                f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
                f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
                f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
                f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
            }
        )

    for i in range(len(model.text["blocks"])):
        prefix = f"text_model.transformer.h.{i}"
        blk = model.text["blocks"][i]
        is_moe = hasattr(blk.mlp, "router")
        weight_map.update(
            {
                f"{prefix}.ln.weight": blk["ln"].weight,
                f"{prefix}.ln.bias": blk["ln"].bias,
                f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight,
                f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias,
                f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight,
                f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias,
                f"{prefix}.tau_wq": blk["attn"]["tau"]["wq"],
                f"{prefix}.tau_wv": blk["attn"]["tau"]["wv"],
                f"{prefix}.tau_alpha": blk["attn"]["tau"]["alpha"],
            }
        )
        if is_moe:
            weight_map.update(
                {
                    f"{prefix}.gate.weight": blk["mlp"]["router"].weight,
                    f"{prefix}.gate.bias": blk["mlp"]["router"].bias,
                    f"{prefix}.mlp.experts.weight": blk["mlp"]["fc1"].weight,
                    f"{prefix}.mlp.output_experts.weight": blk["mlp"]["fc2"].weight,
                }
            )
        else:
            weight_map.update(
                {
                    f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight,
                    f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias,
                    f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight,
                    f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias,
                }
            )

    for key, tensor in weight_map.items():
        tensor.data.copy_(get_tensor(key))

    region.coord_features.data.copy_(
        get_tensor("region_model.coordinate_features.weight").T
    )
    region.size_features.data.copy_(get_tensor("region_model.size_features.weight").T)


def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:
    """Load weights from a safetensors file into a MoondreamModel instance."""
    with safetensors_open(weights_file) as get_tensor:
        if (
            "vision.blocks.0.attn.proj.bias" in get_tensor.keys()
            or "model.vision.blocks.0.attn.proj.bias" in get_tensor.keys()
        ):
            with safetensors_open(weights_file) as get_tensor:
                tensors = {
                    k.replace("model.", ""): get_tensor(k) for k in get_tensor.keys()
                }
                model.load_state_dict(tensors, strict=False)
        else:
            # Wrap the get_tensor function to handle key normalization
            name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()}
            _load_weights(
                lambda x: get_tensor(name_map[x]).to(dtype=torch.bfloat16), model
            )


def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
    """Load weights from a PyTorch file into a MoondreamModel instance."""
    device = str(torch.empty(0).device)
    tensors = torch.load(weights_file, map_location=device, weights_only=True)
    if "vision.blocks.0.attn.proj.bias" in tensors.keys():
        missing_keys, unexpected_keys = model.load_state_dict(tensors, strict=False)
        print("Missing keys:", missing_keys)
        print("Unexpected keys:", unexpected_keys)
    else:
        tensors = {
            k.replace("._orig_mod", ""): v.to(dtype=torch.bfloat16)
            for k, v in tensors.items()
        }
        _load_weights(lambda x: tensors[x], model)


def load_weights_into_model(weights_file: str, model: nn.Module) -> None:
    """
    Load weights from either a safetensors or PyTorch file directly into a MoondreamModel instance.

    Args:
        weights_file: Path to weights file (either .safetensors or .pt)
        model: MoondreamModel instance to load weights into
    """
    if weights_file.endswith(".safetensors"):
        load_weights_from_safetensors(weights_file, model)
    else:
        load_weights_from_pt(weights_file, model)

    # Make all parameters contiguous
    for param in model.parameters():
        param.data = param.data.contiguous()


================================================
FILE: notebooks/RepEng.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook shows how to compute control vectors to steer moondream's behavior\n",
    "in fun and interesting ways. To learn more about control vectors and representation\n",
    "engineering check out [Theia's blog post on the topic](https://vgel.me/posts/representation-engineering/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "from sklearn.decomposition import PCA\n",
    "from IPython.display import display, HTML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"vikhyatk/moondream2\")\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"vikhyatk/moondream2\", trust_remote_code=True,\n",
    "    torch_dtype=torch.float16, device_map={\"\": \"cuda\"}\n",
    ")\n",
    "\n",
    "# We will only be using the images, so it doesn't really matter what\n",
    "# dataset we use here.\n",
    "dataset = load_dataset(\"vikhyatk/lnqa\", streaming=True)[\"train\"]\n",
    "\n",
    "def hidden_states(enc_img, prompt):\n",
    "    with torch.no_grad():\n",
    "        inputs_embeds = model.input_embeds(prompt, enc_img, tokenizer)\n",
    "        hidden_states = model.text_model.generate(\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            max_new_tokens=128,\n",
    "            pad_token_id=tokenizer.eos_token_id,\n",
    "            eos_token_id=tokenizer.eos_token_id,\n",
    "            return_dict_in_generate=True,\n",
    "            output_hidden_states=True,\n",
    "            do_sample=True,\n",
    "            temperature=0.5\n",
    "        ).hidden_states[1:]\n",
    "    return [torch.stack([hs.view(-1, 2048) for hs in h[1:]]).cpu() for h in hidden_states]\n",
    "\n",
    "class LayerWrapper(torch.nn.Module):\n",
    "    def __init__(self, og_layer, control_vectors, scale=4.2):\n",
    "        super().__init__()\n",
    "        self.og_layer = og_layer\n",
    "        self.control_vectors = control_vectors\n",
    "        self.scale = scale\n",
    "\n",
    "    def forward(self, *args, **kwargs):\n",
    "        layer_outputs = self.og_layer(*args, **kwargs)\n",
    "        layer_outputs = (layer_outputs[0] + self.scale * self.control_vectors, *layer_outputs[1:])\n",
    "        return layer_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "negative_prompt = \"<image>\\n\\nQuestion: Describe this image.\\n\\nAnswer:\"\n",
    "positive_prompt = \"<image>\\n\\nQuestion: What is the meaning of life?\\n\\nAnswer:\"\n",
    "\n",
    "# This can be lowered without noticeable loss in quality. Feel free to drop it to\n",
    "# IMAGES_PER_CONTROL=50 and SAMPLES_PER_IMAGE=2 if it's taking too long.\n",
    "IMAGES_PER_CONTROL = 200\n",
    "SAMPLES_PER_IMAGE = 5\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 200/200 [37:09<00:00, 11.15s/it]\n"
     ]
    }
   ],
   "source": [
    "# This is not very efficient, batching would speed things up a lot.\n",
    "# But eh, works for a quick demo.\n",
    "\n",
    "hs_dataset = [[] for _ in range(24)]\n",
    "\n",
    "for i, sample in tqdm(enumerate(dataset), total=IMAGES_PER_CONTROL):\n",
    "    if i >= IMAGES_PER_CONTROL:\n",
    "        break\n",
    "    image = sample[\"image\"]\n",
    "    enc_img = model.encode_image(image)\n",
    "    for _ in range(SAMPLES_PER_IMAGE):\n",
    "        phs = hidden_states(enc_img, positive_prompt)\n",
    "        nhs = hidden_states(enc_img, negative_prompt)\n",
    "        t_max = min(len(phs), len(nhs))\n",
    "        for t in range(t_max):\n",
    "            phs_t = phs[t]\n",
    "            nhs_t = nhs[t]\n",
    "            for j in range(24):\n",
    "                hs_dataset[j].append(phs_t[j])\n",
    "                hs_dataset[j].append(nhs_t[j])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 24/24 [02:30<00:00,  6.26s/it]\n"
     ]
    }
   ],
   "source": [
    "control_vectors = []\n",
    "\n",
    "for i in tqdm(range(24)):\n",
    "    layer_hiddens = torch.stack(hs_dataset[i])\n",
    "\n",
    "    layer_centers = (layer_hiddens[::2] + layer_hiddens[1::2]) / 2\n",
    "    relative_layer_hiddens = layer_hiddens\n",
    "    relative_layer_hiddens[::2] -= layer_centers\n",
    "    relative_layer_hiddens[1::2] -= layer_centers\n",
    "\n",
    "    train = relative_layer_hiddens - relative_layer_hiddens.mean(axis=0, keepdims=True)\n",
    "    train = train.view(-1, 2048).cpu().numpy()\n",
    "    pca_model = PCA(n_components=1, whiten=False).fit(train)\n",
    "    directions = pca_model.components_.astype(np.float32).squeeze(axis=0)\n",
    "\n",
    "    projected_hiddens = (layer_hiddens.cpu().numpy() @ directions) / np.linalg.norm(directions)\n",
    "\n",
    "    positive_smaller_mean = np.mean(\n",
    "        [\n",
    "            projected_hiddens[i] < projected_hiddens[i + 1]\n",
    "            for i in range(0, len(hs_dataset[i]), 2)\n",
    "        ]\n",
    "    )\n",
    "    positive_larger_mean = np.mean(\n",
    "        [\n",
    "            projected_hiddens[i] > projected_hiddens[i + 1]\n",
    "            for i in range(0, len(hs_dataset[i]), 2)\n",
    "        ]\n",
    "    )\n",
    "    if positive_smaller_mean > positive_larger_mean:  # type: ignore\n",
    "        directions *= -1\n",
    "\n",
    "    control_vectors.append(torch.tensor(directions, device=\"cuda\", dtype=torch.float16))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "<table style=\"border: 1px solid black; border-collapse: collapse;\">\n",
       "    <tr>\n",
       "        <th style=\"border: 1px solid black; padding: 8px;\">Image</th>\n",
       "        <th style=\"border: 1px solid black; padding: 8px;\">Text</th>\n",
       "    </tr>\n",
       "\n",
       "    <tr>\n",
       "        <td style=\"border: 1px solid black; padding: 8px;\">\n",
       "            <img src=\"../assets/demo-1.jpg\" style=\"width:250px;\">\n",
       "        </td>\n",
       "        <td style=\"border: 1px solid black; padding: 8px; text-align: left\">\n",
       "            <strong>Describe this image.</strong><br><br>\n",
       "            In this image, we will explore the concept of happiness through an anthropomorphic interpretation. We'll look at it from a spiritual or religious perspective by examining what brings joy and fulfillment to human beings. This can be approached in various ways depending on individual beliefs and experiences.<br><br>To understand happiness, one might consider factors such as personal goals, relationships, health, wealth, spirituality, ethics, and more. It is subjective and varies greatly among individuals based on their values and life experiences. However, many people often seek happiness in different aspects like achieving success, finding inner peace, or living a meaningful life.<br><br>In the context of the provided statements about happiness, some may find happiness in material possessions (e.of), while others may seek happiness through relationships, creativity, or even simple pleasures that bring satisfaction and contentment. Ultimately, happiness is a deeply personal journey for each person to define and pursue according to their own beliefs and values.\n",
       "        </td>\n",
       "    </tr>\n",
       "    \n",
       "    <tr>\n",
       "        <td style=\"border: 1px solid black; padding: 8px;\">\n",
       "            <img src=\"../assets/demo-2.jpg\" style=\"width:250px;\">\n",
       "        </td>\n",
       "        <td style=\"border: 1px solid black; padding: 8px; text-align: left\">\n",
       "            <strong>What is this?</strong><br><br>\n",
       "            In the context of this image, a server or processor is an essential component for computing and running various applications on computers. A Processor (or CPU) can refer to any device that executes tasks according to specific programming requirements.<br><br>In this particular scenario, we are referring to advanced technologies like virtual machines, artificial intelligence, machine learning, etc., which require powerful computing systems to function effectively.<br><br>For example, in the case of AI research, researchers develop and test theories using sophisticated computer models and simulations. These concepts may involve analyzing vast amounts of data, exploring ethical questions, understanding existence, or even developing new knowledge about life itself.<br><br>In summary, when people talk about \"the meaning\" or \"purpose,\" they often refer to these advanced concepts as well. It's subjective and varies from person to person based on their beliefs, values, and experiences.\n",
       "        </td>\n",
       "    </tr>\n",
       "    \n",
       "    <tr>\n",
       "        <td style=\"border: 1px solid black; padding: 8px;\">\n",
       "            <img src=\"../assets/demo-2.jpg\" style=\"width:250px;\">\n",
       "        </td>\n",
       "        <td style=\"border: 1px solid black; padding: 8px; text-align: left\">\n",
       "            <strong>What color is the couch?</strong><br><br>\n",
       "            The couch in the image is described as \"black.\" However, without more information or context from different sources, it's difficult to determine its actual color. It could be any of those things like comfort, aesthetics, personal preferences, etc., which can vary among individuals.\n",
       "        </td>\n",
       "    </tr>\n",
       "    </table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "prompts = [\n",
    "    (\"../assets/demo-1.jpg\", \"Describe this image.\"),\n",
    "    (\"../assets/demo-2.jpg\", \"What is this?\"),\n",
    "    (\"../assets/demo-2.jpg\", \"What color is the couch?\"),\n",
    "]\n",
    "data = []\n",
    "\n",
    "def run_model(img_path, prompt, scale=4.2):\n",
    "    og_h = model.text_model.transformer.h\n",
    "    model.text_model.transformer.h = torch.nn.ModuleList([\n",
    "        LayerWrapper(layer, vector, scale) for layer, vector in zip(og_h, control_vectors)\n",
    "    ])\n",
    "    answer = model.answer_question(\n",
    "        model.encode_image(Image.open(img_path)), prompt, tokenizer,\n",
    "        repetition_penalty=1.2, temperature=0.1, do_sample=True,\n",
    "        length_penalty=1.2\n",
    "    )\n",
    "    model.text_model.transformer.h = og_h\n",
    "    return answer\n",
    "\n",
    "for img_path, prompt in prompts:\n",
    "    answer = run_model(img_path, prompt)\n",
    "    data.append({\"prompt\": prompt, \"answer\": answer.replace(\"\\n\", \"<br>\"), \"image\": img_path})\n",
    "\n",
    "html_table = \"\"\"\n",
    "<table style=\"border: 1px solid black; border-collapse: collapse;\">\n",
    "    <tr>\n",
    "        <th style=\"border: 1px solid black; padding: 8px;\">Image</th>\n",
    "        <th style=\"border: 1px solid black; padding: 8px;\">Text</th>\n",
    "    </tr>\n",
    "\"\"\"\n",
    "\n",
    "for item in data:\n",
    "    html_table += f\"\"\"\n",
    "    <tr>\n",
    "        <td style=\"border: 1px solid black; padding: 8px;\">\n",
    "            <img src=\"{item['image']}\" style=\"width:250px;\">\n",
    "        </td>\n",
    "        <td style=\"border: 1px solid black; padding: 8px; text-align: left\">\n",
    "            <strong>{item['prompt']}</strong><br><br>\n",
    "            {item['answer']}\n",
    "        </td>\n",
    "    </tr>\n",
    "    \"\"\"\n",
    "\n",
    "html_table += \"</table>\"\n",
    "\n",
    "# Display the HTML table\n",
    "display(HTML(html_table))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


================================================
FILE: recipes/gaze-detection-video/.gitignore
================================================
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual Environment
venv/
ENV/

# IDE
.idea/
.vscode/
*.swp
*.swo

# Project specific
# input/*
# !input/.gitkeep
# output/*
# !output/.gitkeep
# temp/*
# !temp/.gitkeep

# Model files
*.pt
*.pth
*.ckpt

# Logs
*.log

# OS specific
.DS_Store
Thumbs.db


================================================
FILE: recipes/gaze-detection-video/README.md
================================================
# Gaze Detection Video Processor

> **⚠️ IMPORTANT:** This project currently uses Moondream 2B (2025-01-09 release) via the Hugging Face Transformers library. We will migrate to the official Moondream client
> libraries once they become available for this version.

## Table of Contents

- [Overview](#overview)
- [Sample Output](#sample-output)
- [Features](#features)
- [Prerequisites](#prerequisites)
- [Installation](#installation)
  - [Linux/macOS Installation](#linuxmacos-installation)
  - [Windows Installation](#windows-installation)
- [Usage](#usage)
- [Output](#output)
- [Troubleshooting](#troubleshooting)
- [Performance Notes](#performance-notes)
- [Dependencies](#dependencies)
- [Model Details](#model-details)
- [License](#license)

## Overview

This project uses the Moondream 2B model to detect faces and their gaze directions in videos. It processes videos frame by frame, visualizing face detections and gaze directions.

## Sample Output

|              Input Video              |              Processed Output               |
| :-----------------------------------: | :-----------------------------------------: |
| ![Input Video](https://github.com/parsakhaz/gaze-detection-video/blob/master/gif-input-sample.gif?raw=true) | ![Processed Output](https://github.com/parsakhaz/gaze-detection-video/blob/master/gif-output-sample.gif?raw=true) |

## Features

- Face detection in video frames
- Gaze direction tracking
- Real-time visualization with:
  - Colored bounding boxes for faces
  - Gradient lines showing gaze direction
  - Gaze target points
- Supports multiple faces per frame
- Processes all com
Download .txt
gitextract_g04l91uk/

├── .github/
│   └── workflows/
│       ├── pylint.yml
│       └── test.yml
├── .gitignore
├── LICENSE
├── README.md
├── batch_generate_example.py
├── gradio_demo.py
├── moondream/
│   ├── __init__.py
│   ├── config/
│   │   ├── config_md05.json
│   │   └── config_md2.json
│   ├── eval/
│   │   ├── chartqa.py
│   │   ├── coco_map.py
│   │   ├── countbenchqa.py
│   │   ├── docvqa.py
│   │   ├── eval_all.py
│   │   ├── gazefollow.py
│   │   ├── mmstar.py
│   │   ├── naturalbench.py
│   │   ├── pope.py
│   │   ├── realworldqa.py
│   │   ├── tallyqa.py
│   │   ├── textvqa.py
│   │   ├── utils.py
│   │   └── waste_detection.py
│   └── torch/
│       ├── config.py
│       ├── hf_moondream.py
│       ├── hf_release.py
│       ├── image_crops.py
│       ├── layers.py
│       ├── lora.py
│       ├── moondream.py
│       ├── region.py
│       ├── rope.py
│       ├── sample.py
│       ├── text.py
│       ├── utils.py
│       ├── vision.py
│       └── weights.py
├── notebooks/
│   └── RepEng.ipynb
├── recipes/
│   ├── gaze-detection-video/
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── gaze-detection-video.py
│   │   ├── input/
│   │   │   └── .gitkeep
│   │   ├── output/
│   │   │   └── .gitkeep
│   │   ├── requirements.txt
│   │   └── temp/
│   │       └── .gitkeep
│   ├── promptable-content-moderation/
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── app.py
│   │   ├── deep_sort_integration.py
│   │   ├── main.py
│   │   ├── packages.txt
│   │   ├── persistence.py
│   │   ├── requirements.txt
│   │   ├── video_visualization.py
│   │   └── visualization.py
│   └── promptable-video-redaction/
│       ├── .gitignore
│       ├── README.md
│       ├── app.py
│       ├── main.py
│       ├── packages.txt
│       └── requirements.txt
├── requirements.txt
├── sample.py
├── tests/
│   └── test_image_crops.py
└── webcam_gradio_demo.py
Download .txt
SYMBOL INDEX (212 symbols across 38 files)

FILE: gradio_demo.py
  function answer_question (line 35) | def answer_question(img, prompt):
  function extract_floats (line 55) | def extract_floats(text):
  function extract_bbox (line 65) | def extract_bbox(text):
  function process_answer (line 73) | def process_answer(img, answer):

FILE: moondream/eval/chartqa.py
  function relaxed_correctness (line 16) | def relaxed_correctness(
  function eval_chartqa (line 59) | def eval_chartqa(model, debug=False):

FILE: moondream/eval/coco_map.py
  function calculate_iou (line 116) | def calculate_iou(
  function calculate_map (line 132) | def calculate_map(
  function get_total_map (line 210) | def get_total_map(results_by_label, frequency_by_label):
  function eval_coco_map (line 222) | def eval_coco_map(model, iou_threshold=0.5, debug=False):

FILE: moondream/eval/countbenchqa.py
  function eval_countbenchqa (line 14) | def eval_countbenchqa(model, debug=False):

FILE: moondream/eval/docvqa.py
  function get_anls (line 14) | def get_anls(s1, s2):
  function eval_docvqa (line 22) | def eval_docvqa(model, debug=False):

FILE: moondream/eval/eval_all.py
  function create_model (line 22) | def create_model(ckpt_path):
  function eval_all (line 30) | def eval_all(model, skip=[]):

FILE: moondream/eval/gazefollow.py
  function eval_gazefollow (line 12) | def eval_gazefollow(model, debug=False):

FILE: moondream/eval/mmstar.py
  function eval_mmstar (line 13) | def eval_mmstar(model, debug=False):

FILE: moondream/eval/naturalbench.py
  function eval_naturalbench (line 10) | def eval_naturalbench(model, debug=False):

FILE: moondream/eval/pope.py
  function evaluate_pope (line 11) | def evaluate_pope(model, debug=False):

FILE: moondream/eval/realworldqa.py
  function eval_realworldqa (line 12) | def eval_realworldqa(model, debug=False):

FILE: moondream/eval/tallyqa.py
  function eval_tallyqa (line 14) | def eval_tallyqa(model, debug=False):

FILE: moondream/eval/textvqa.py
  function eval_textvqa (line 15) | def eval_textvqa(model, debug=False):

FILE: moondream/eval/utils.py
  class VQAScorer (line 5) | class VQAScorer:
    method __init__ (line 6) | def __init__(self):
    method process_punctuation (line 175) | def process_punctuation(self, inText: str) -> str:
    method process_digit_article (line 188) | def process_digit_article(self, inText: str) -> str:
    method process_answer (line 201) | def process_answer(self, answer):
    method process_line (line 209) | def process_line(self, prediction: str, gt_answers: List[str]) -> float:
    method compute_score (line 225) | def compute_score(

FILE: moondream/eval/waste_detection.py
  function iou (line 18) | def iou(a: Box, b: Box) -> float:
  function match (line 28) | def match(gt: List[Box], pr: List[Box], iou_thr: float) -> Tuple[int, in...
  class WasteDetection (line 54) | class WasteDetection(torch.utils.data.Dataset):
    method __init__ (line 55) | def __init__(self, name: str = "moondream/waste_detection", split: str...
    method __len__ (line 58) | def __len__(self):
    method __getitem__ (line 61) | def __getitem__(self, idx: int) -> Dict:
  function evaluate (line 81) | def evaluate(
  function load_model (line 114) | def load_model(path: str, device: torch.device) -> MoondreamModel:
  function main (line 123) | def main():

FILE: moondream/torch/config.py
  class TextMoeConfig (line 6) | class TextMoeConfig:
  class TextConfig (line 14) | class TextConfig:
  class VisionConfig (line 28) | class VisionConfig:
  class RegionConfig (line 43) | class RegionConfig:
  class TokenizerConfig (line 53) | class TokenizerConfig:
  class MoondreamConfig (line 77) | class MoondreamConfig:
    method from_dict (line 84) | def from_dict(cls, config_dict: dict):
    method to_dict (line 96) | def to_dict(self):

FILE: moondream/torch/hf_moondream.py
  function extract_question (line 18) | def extract_question(text):
  class HfConfig (line 28) | class HfConfig(PretrainedConfig):
    method __init__ (line 32) | def __init__(self, **kwargs):
  class HfMoondream (line 37) | class HfMoondream(PreTrainedModel):
    method __init__ (line 41) | def __init__(self, config):
    method _setup_caches (line 48) | def _setup_caches(self):
    method encode_image (line 54) | def encode_image(self):
    method query (line 59) | def query(self):
    method caption (line 64) | def caption(self):
    method detect (line 69) | def detect(self):
    method point (line 74) | def point(self):
    method detect_gaze (line 79) | def detect_gaze(self):
    method answer_question (line 83) | def answer_question(
    method batch_answer (line 99) | def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
    method _unsupported_exception (line 105) | def _unsupported_exception(self):
    method generate (line 112) | def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128...
    method get_input_embeddings (line 142) | def get_input_embeddings(self) -> nn.Embedding:
    method set_input_embeddings (line 155) | def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) ...
    method input_embeds (line 165) | def input_embeds(

FILE: moondream/torch/image_crops.py
  function select_tiling (line 17) | def select_tiling(
  class OverlapCropOutput (line 53) | class OverlapCropOutput(TypedDict):
  function overlap_crop_image (line 58) | def overlap_crop_image(
  function reconstruct_from_crops (line 170) | def reconstruct_from_crops(

FILE: moondream/torch/layers.py
  function quantize_ (line 13) | def quantize_(model, quant_mode):
  function int4_weight_only (line 18) | def int4_weight_only(group_size):
  function gelu_approx (line 24) | def gelu_approx(x):
  class LinearWeights (line 29) | class LinearWeights:
  function linear (line 34) | def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
  function dequantize_tensor (line 38) | def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
  class QuantizedLinear (line 47) | class QuantizedLinear(nn.Module):
    method __init__ (line 48) | def __init__(
    method unpack (line 79) | def unpack(self):
    method forward (line 106) | def forward(self, x: torch.Tensor) -> torch.Tensor:
  class LayerNormWeights (line 113) | class LayerNormWeights:
  function layer_norm (line 118) | def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
  class MLPWeights (line 123) | class MLPWeights:
  function mlp (line 129) | def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> ...
  function moe_mlp (line 149) | def moe_mlp(
  class AttentionWeights (line 218) | class AttentionWeights:
  function attn (line 223) | def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Te...

FILE: moondream/torch/lora.py
  function variant_cache_dir (line 11) | def variant_cache_dir():
  function cached_variant_path (line 23) | def cached_variant_path(variant_id: str):
  function nest (line 46) | def nest(flat):
  function variant_state_dict (line 58) | def variant_state_dict(variant_id: Optional[str] = None, device: str = "...

FILE: moondream/torch/moondream.py
  class EncodedImage (line 58) | class EncodedImage:
  class KVCache (line 63) | class KVCache(nn.Module):
    method __init__ (line 65) | def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
    method update (line 75) | def update(self, pos_ids, k, v):
  function causal_mask (line 82) | def causal_mask(b, h, q_idx, kv_idx):
  function get_mask_mod (line 86) | def get_mask_mod(mask_mod, offset):
  class MoondreamModel (line 93) | class MoondreamModel(nn.Module):
    method __init__ (line 95) | def __init__(
    method causal_block_mask (line 151) | def causal_block_mask(self):
    method point_gen_indices (line 164) | def point_gen_indices(self):
    method _setup_caches (line 172) | def _setup_caches(self):
    method device (line 185) | def device(self):
    method _vis_enc (line 188) | def _vis_enc(self, x: torch.Tensor):
    method _vis_proj (line 191) | def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
    method _prefill (line 194) | def _prefill(
    method _decode_one_tok (line 203) | def _decode_one_tok(
    method compile (line 232) | def compile(self):
    method _run_vision_encoder (line 272) | def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:
    method encode_image (line 296) | def encode_image(
    method _apply_top_p (line 338) | def _apply_top_p(self, probs: torch.Tensor, top_p: float):
    method _prefill_prompt (line 348) | def _prefill_prompt(
    method _generate_reasoning (line 393) | def _generate_reasoning(
    method _generate_answer (line 506) | def _generate_answer(
    method query (line 615) | def query(
    method load_encoded_image (line 691) | def load_encoded_image(self, encoded_image: EncodedImage):
    method caption (line 696) | def caption(
    method _generate_points (line 724) | def _generate_points(
    method detect (line 816) | def detect(
    method point (line 864) | def point(
    method _detect_gaze (line 912) | def _detect_gaze(
    method detect_gaze (line 968) | def detect_gaze(
  function _is_cjk_char (line 1060) | def _is_cjk_char(cp):

FILE: moondream/torch/region.py
  function fourier_features (line 10) | def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
  function encode_coordinate (line 30) | def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:
  function decode_coordinate (line 44) | def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch...
  function encode_size (line 58) | def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
  function decode_size (line 72) | def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
  function encode_spatial_refs (line 94) | def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torc...

FILE: moondream/torch/rope.py
  function precompute_freqs_cis (line 6) | def precompute_freqs_cis(
  function apply_rotary_emb (line 19) | def apply_rotary_emb(

FILE: moondream/torch/text.py
  function text_encoder (line 13) | def text_encoder(input_ids: torch.Tensor, w: nn.Module):
  function attn (line 17) | def attn(
  function text_decoder (line 82) | def text_decoder(
  function lm_head (line 124) | def lm_head(
  function build_dense_mlp (line 137) | def build_dense_mlp(d_model, d_ffn, dtype, linear_cls):
  function build_moe_mlp (line 146) | def build_moe_mlp(d_model, d_ffn, n_experts, dtype):
  function build_text_model (line 169) | def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:

FILE: moondream/torch/utils.py
  function remove_outlier_points (line 4) | def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):

FILE: moondream/torch/vision.py
  function adaptive_avg_pool2d (line 16) | def adaptive_avg_pool2d(input, output_size):
  function prepare_crops (line 25) | def prepare_crops(
  function create_patches (line 44) | def create_patches(x, patch_size):
  function vision_encoder (line 64) | def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: Visio...
  function vision_projection (line 77) | def vision_projection(
  function build_vision_model (line 92) | def build_vision_model(config: VisionConfig, dtype: torch.dtype):

FILE: moondream/torch/weights.py
  function safetensors_open (line 10) | def safetensors_open(safetensors_file: str):
  function _load_weights (line 30) | def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.M...
  function load_weights_from_safetensors (line 130) | def load_weights_from_safetensors(weights_file: str, model: nn.Module) -...
  function load_weights_from_pt (line 150) | def load_weights_from_pt(weights_file: str, model: nn.Module) -> None:
  function load_weights_into_model (line 166) | def load_weights_into_model(weights_file: str, model: nn.Module) -> None:

FILE: recipes/gaze-detection-video/gaze-detection-video.py
  function initialize_model (line 20) | def initialize_model() -> Optional[AutoModelForCausalLM]:
  function video_handler (line 56) | def video_handler(
  function fig2rgb_array (line 81) | def fig2rgb_array(fig: plt.Figure) -> np.ndarray:
  function visualize_frame (line 91) | def visualize_frame(
  function process_video (line 200) | def process_video(

FILE: recipes/promptable-content-moderation/app.py
  function process_video_file (line 28) | def process_video_file(
  function create_visualization_plots (line 119) | def create_visualization_plots(json_path):

FILE: recipes/promptable-content-moderation/deep_sort_integration.py
  class DeepSORTTracker (line 7) | class DeepSORTTracker:
    method __init__ (line 8) | def __init__(self, max_age=5):
    method _create_tracker (line 13) | def _create_tracker(self):
    method reset (line 21) | def reset(self):
    method update (line 26) | def update(self, frame, detections):

FILE: recipes/promptable-content-moderation/main.py
  function get_sam_model (line 50) | def get_sam_model(slim=False):
  function load_sam_model (line 72) | def load_sam_model(slim=False):
  function generate_color_pair (line 77) | def generate_color_pair():
  function create_mask_overlay (line 84) | def create_mask_overlay(image, masks, points=None, labels=None):
  function process_sam_detection (line 153) | def process_sam_detection(image, center_x, center_y, slim=False):
  function load_moondream (line 184) | def load_moondream():
  function get_video_properties (line 193) | def get_video_properties(video_path):
  function is_valid_bounding_box (line 204) | def is_valid_bounding_box(bounding_box):
  function split_frame_into_grid (line 221) | def split_frame_into_grid(frame, grid_rows, grid_cols):
  function convert_tile_coords_to_frame (line 243) | def convert_tile_coords_to_frame(box, tile_pos, frame_shape):
  function merge_tile_detections (line 273) | def merge_tile_detections(tile_detections, iou_threshold=0.5):
  function detect_objects_in_frame (line 330) | def detect_objects_in_frame(
  function detect_objects_in_frame_single (line 376) | def detect_objects_in_frame_single(model, tokenizer, image, target_object):
  function draw_hitmarker (line 401) | def draw_hitmarker(
  function draw_ad_boxes (line 501) | def draw_ad_boxes(frame, detected_objects, detect_keyword, model, box_st...
  function filter_temporal_outliers (line 767) | def filter_temporal_outliers(detections_dict):
  function describe_frames (line 805) | def describe_frames(
  function create_detection_video (line 875) | def create_detection_video(
  function process_video (line 1058) | def process_video(
  function main (line 1226) | def main():

FILE: recipes/promptable-content-moderation/persistence.py
  function save_detection_data (line 5) | def save_detection_data(data, output_file):
  function load_detection_data (line 26) | def load_detection_data(input_file):

FILE: recipes/promptable-content-moderation/video_visualization.py
  function create_frame_data (line 12) | def create_frame_data(json_path):
  function generate_frame_image (line 64) | def generate_frame_image(df, frame_num, temp_dir, max_y):
  function generate_gauge_frame (line 124) | def generate_gauge_frame(df, frame_num, temp_dir, detect_keyword="OBJECT"):
  function create_video_visualization (line 256) | def create_video_visualization(json_path, style="timeline"):

FILE: recipes/promptable-content-moderation/visualization.py
  function visualize_detections (line 7) | def visualize_detections(json_path):
  function main (line 97) | def main():

FILE: recipes/promptable-video-redaction/app.py
  function process_video_file (line 22) | def process_video_file(

FILE: recipes/promptable-video-redaction/main.py
  function load_moondream (line 35) | def load_moondream():
  function get_video_properties (line 44) | def get_video_properties(video_path):
  function is_valid_box (line 55) | def is_valid_box(box):
  function split_frame_into_tiles (line 72) | def split_frame_into_tiles(frame, rows, cols):
  function convert_tile_coords_to_frame (line 94) | def convert_tile_coords_to_frame(box, tile_pos, frame_shape):
  function merge_tile_detections (line 124) | def merge_tile_detections(tile_detections, iou_threshold=0.5):
  function detect_ads_in_frame (line 181) | def detect_ads_in_frame(model, tokenizer, image, detect_keyword, rows=1,...
  function detect_ads_in_frame_single (line 225) | def detect_ads_in_frame_single(model, tokenizer, image, detect_keyword):
  function draw_hitmarker (line 250) | def draw_hitmarker(
  function draw_ad_boxes (line 350) | def draw_ad_boxes(frame, detected_objects, detect_keyword, box_style="ce...
  function filter_temporal_outliers (line 427) | def filter_temporal_outliers(detections_dict):
  function describe_frames (line 455) | def describe_frames(
  function create_detection_video (line 503) | def create_detection_video(
  function process_video (line 617) | def process_video(
  function main (line 657) | def main():

FILE: tests/test_image_crops.py
  function test_overlap_crop_basic (line 6) | def test_overlap_crop_basic():
  function test_overlap_crop_small_image (line 21) | def test_overlap_crop_small_image():
  function test_reconstruction (line 32) | def test_reconstruction():

FILE: webcam_gradio_demo.py
  function answer_question (line 33) | def answer_question(img, prompt):
  function img_change (line 81) | def img_change(img):
  function prompt_change (line 86) | def prompt_change(prompt):
  function live_video (line 91) | def live_video():
Condensed preview — 66 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (372K chars).
[
  {
    "path": ".github/workflows/pylint.yml",
    "chars": 780,
    "preview": "name: Lint\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\npermissions:\n  contents: read\njob"
  },
  {
    "path": ".github/workflows/test.yml",
    "chars": 658,
    "preview": "name: Tests\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\njobs:\n  test:\n    runs-on: ubunt"
  },
  {
    "path": ".gitignore",
    "chars": 137,
    "preview": ".venv\n__pycache__\ncheckpoints\ndata\n/pyproject.toml\npoetry.lock\ndist\nclients/python/moondream/torch\nwandb/\nmoondream_fine"
  },
  {
    "path": "LICENSE",
    "chars": 11356,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 3717,
    "preview": "# 🌔 moondream\n\na tiny vision language model that kicks ass and runs anywhere\n\n[Website](https://moondream.ai/) | [Demo]("
  },
  {
    "path": "batch_generate_example.py",
    "chars": 879,
    "preview": "from PIL import Image\nfrom transformers import AutoTokenizer\n\nfrom moondream.hf import LATEST_REVISION, Moondream, detec"
  },
  {
    "path": "gradio_demo.py",
    "chars": 3240,
    "preview": "import argparse\nimport re\nfrom threading import Thread\n\nimport gradio as gr\nimport torch\nfrom PIL import ImageDraw\nfrom "
  },
  {
    "path": "moondream/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "moondream/config/config_md05.json",
    "chars": 2053,
    "preview": "{\n    \"text\": {\n        \"dim\": 1024,\n        \"ff_dim\": 4096,\n        \"n_layers\": 24,\n        \"vocab_size\": 51200,\n      "
  },
  {
    "path": "moondream/config/config_md2.json",
    "chars": 2054,
    "preview": "{\n    \"text\": {\n        \"dim\": 2048,\n        \"ff_dim\": 8192,\n        \"n_layers\": 24,\n        \"vocab_size\": 51200,\n      "
  },
  {
    "path": "moondream/eval/chartqa.py",
    "chars": 5317,
    "preview": "import argparse\nimport datasets\nimport torch\n\nfrom tqdm import tqdm\nimport json\n\nfrom ..torch.config import MoondreamCon"
  },
  {
    "path": "moondream/eval/coco_map.py",
    "chars": 8263,
    "preview": "import argparse\nimport datasets\nimport torch\nimport json\nimport numpy as np\n\nfrom typing import List, Tuple\nfrom tqdm im"
  },
  {
    "path": "moondream/eval/countbenchqa.py",
    "chars": 2309,
    "preview": "import argparse\nimport datasets\nimport torch\n\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..t"
  },
  {
    "path": "moondream/eval/docvqa.py",
    "chars": 2451,
    "preview": "import argparse\nimport editdistance\nfrom datasets import load_dataset\nfrom tqdm import tqdm\nimport torch\n\nfrom ..torch.c"
  },
  {
    "path": "moondream/eval/eval_all.py",
    "chars": 1758,
    "preview": "import argparse\nimport torch\n\nfrom pprint import pprint\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondre"
  },
  {
    "path": "moondream/eval/gazefollow.py",
    "chars": 3763,
    "preview": "import torch\nimport datasets\nimport math\n\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch"
  },
  {
    "path": "moondream/eval/mmstar.py",
    "chars": 3125,
    "preview": "import datasets\nimport torch\n\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream i"
  },
  {
    "path": "moondream/eval/naturalbench.py",
    "chars": 3767,
    "preview": "from datasets import load_dataset\nfrom tqdm import tqdm\nimport torch\n\nfrom ..torch.config import MoondreamConfig\nfrom .."
  },
  {
    "path": "moondream/eval/pope.py",
    "chars": 2749,
    "preview": "import argparse\nfrom datasets import load_dataset\nfrom tqdm import tqdm\nimport torch\n\nfrom ..torch.config import Moondre"
  },
  {
    "path": "moondream/eval/realworldqa.py",
    "chars": 2168,
    "preview": "import argparse\nimport datasets\nimport torch\n\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..t"
  },
  {
    "path": "moondream/eval/tallyqa.py",
    "chars": 2592,
    "preview": "import argparse\nimport datasets\nimport torch\n\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..t"
  },
  {
    "path": "moondream/eval/textvqa.py",
    "chars": 2211,
    "preview": "import argparse\nimport datasets\nimport torch\n\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..t"
  },
  {
    "path": "moondream/eval/utils.py",
    "chars": 8296,
    "preview": "import re\nfrom typing import List\n\n\nclass VQAScorer:\n    def __init__(self):\n        self.contractions = {\n            \""
  },
  {
    "path": "moondream/eval/waste_detection.py",
    "chars": 4403,
    "preview": "import argparse\nfrom collections import defaultdict\nfrom typing import Dict, List, Tuple\n\nimport torch\nfrom PIL import I"
  },
  {
    "path": "moondream/torch/config.py",
    "chars": 2834,
    "preview": "from dataclasses import dataclass, field\nfrom typing import Dict, List, Optional\n\n\n@dataclass(frozen=True)\nclass TextMoe"
  },
  {
    "path": "moondream/torch/hf_moondream.py",
    "chars": 5581,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom transformers import PreTrainedModel, PretrainedConfig\nfrom typing import Union\n"
  },
  {
    "path": "moondream/torch/hf_release.py",
    "chars": 529,
    "preview": "import torch\nimport argparse\n\nfrom .weights import load_weights_into_model\nfrom .hf_moondream import HfConfig, HfMoondre"
  },
  {
    "path": "moondream/torch/image_crops.py",
    "chars": 8145,
    "preview": "import math\nimport numpy as np\nimport torch\n\nfrom typing import TypedDict\n\ntry:\n    import pyvips\n\n    HAS_VIPS = True\ne"
  },
  {
    "path": "moondream/torch/layers.py",
    "chars": 6985,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dataclasses import dataclass\nfrom typing import"
  },
  {
    "path": "moondream/torch/lora.py",
    "chars": 2299,
    "preview": "import functools\nimport os\nimport shutil\nimport torch\n\nfrom pathlib import Path\nfrom urllib.request import Request, urlo"
  },
  {
    "path": "moondream/torch/moondream.py",
    "chars": 38203,
    "preview": "import torch\nimport torch.nn as nn\nimport random\n\nfrom typing import Literal, Tuple, TypedDict, Union, Dict, Any, Option"
  },
  {
    "path": "moondream/torch/region.py",
    "chars": 4510,
    "preview": "import torch\nimport torch.nn as nn\nimport math\n\nfrom typing import List, Tuple, Union\n\nSpatialRefs = List[Union[Tuple[fl"
  },
  {
    "path": "moondream/torch/rope.py",
    "chars": 1545,
    "preview": "# Ethically sourced from https://github.com/xjdr-alt/entropix\n\nimport torch\n\n\ndef precompute_freqs_cis(\n    dim: int,\n  "
  },
  {
    "path": "moondream/torch/sample.py",
    "chars": 8142,
    "preview": "import argparse\nimport json\nimport os\nimport torch\n\nfrom PIL import Image, ImageDraw\nfrom tqdm import tqdm\n\nfrom .weight"
  },
  {
    "path": "moondream/torch/text.py",
    "chars": 8253,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom torch.nn import functional as F\nfrom torch.nn.attention.flex_attention import f"
  },
  {
    "path": "moondream/torch/utils.py",
    "chars": 1415,
    "preview": "import numpy as np\n\n\ndef remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):\n    \"\"\"\n    Robust outlier de"
  },
  {
    "path": "moondream/torch/vision.py",
    "chars": 5087,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nfrom typing import Union, Tuple\nf"
  },
  {
    "path": "moondream/torch/weights.py",
    "chars": 7898,
    "preview": "import safetensors\nimport torch\nimport torch.nn as nn\n\nfrom contextlib import contextmanager\nfrom typing import Callable"
  },
  {
    "path": "notebooks/RepEng.ipynb",
    "chars": 13041,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"This notebook shows how to compute "
  },
  {
    "path": "recipes/gaze-detection-video/.gitignore",
    "chars": 439,
    "preview": "# Python\n__pycache__/\n*.py[cod]\n*$py.class\n*.so\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib"
  },
  {
    "path": "recipes/gaze-detection-video/README.md",
    "chars": 5864,
    "preview": "# Gaze Detection Video Processor\n\n> **⚠️ IMPORTANT:** This project currently uses Moondream 2B (2025-01-09 release) via "
  },
  {
    "path": "recipes/gaze-detection-video/gaze-detection-video.py",
    "chars": 10546,
    "preview": "\"\"\"\nGaze Detection Video Processor using Moondream 2\n------------------------------------------------\nRead the README.md"
  },
  {
    "path": "recipes/gaze-detection-video/input/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "recipes/gaze-detection-video/output/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "recipes/gaze-detection-video/requirements.txt",
    "chars": 147,
    "preview": "torch>=2.0.0\ntransformers>=4.36.0\nopencv-python>=4.8.0\npillow>=10.0.0\nmatplotlib>=3.7.0\nnumpy>=1.24.0\ntqdm>=4.65.0\npyvip"
  },
  {
    "path": "recipes/gaze-detection-video/temp/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "recipes/promptable-content-moderation/.gitignore",
    "chars": 420,
    "preview": "# Python\n__pycache__/\n*.py[cod]\n*$py.class\n*.so\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\np"
  },
  {
    "path": "recipes/promptable-content-moderation/README.md",
    "chars": 5400,
    "preview": "# Promptable Content Moderation with Moondream\n\nWelcome to the future of content moderation with Moondream 2B, a powerfu"
  },
  {
    "path": "recipes/promptable-content-moderation/app.py",
    "chars": 24689,
    "preview": "#!/usr/bin/env python3\nimport gradio as gr\nimport os\nfrom main import load_moondream, process_video, load_sam_model\nimpo"
  },
  {
    "path": "recipes/promptable-content-moderation/deep_sort_integration.py",
    "chars": 2440,
    "preview": "import numpy as np\nimport torch\nfrom deep_sort_realtime.deepsort_tracker import DeepSort\nfrom datetime import datetime\n\n"
  },
  {
    "path": "recipes/promptable-content-moderation/main.py",
    "chars": 46200,
    "preview": "#!/usr/bin/env python3\nimport cv2, os, subprocess, argparse\nfrom PIL import Image\nimport torch\nfrom transformers import "
  },
  {
    "path": "recipes/promptable-content-moderation/packages.txt",
    "chars": 14,
    "preview": "libvips\nffmpeg"
  },
  {
    "path": "recipes/promptable-content-moderation/persistence.py",
    "chars": 1056,
    "preview": "import json\nimport os\n\n\ndef save_detection_data(data, output_file):\n    \"\"\"\n    Saves the detection data to a JSON file."
  },
  {
    "path": "recipes/promptable-content-moderation/requirements.txt",
    "chars": 687,
    "preview": "gradio>=4.0.0\ntorch>=2.0.0\n# if on windows: pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https:/"
  },
  {
    "path": "recipes/promptable-content-moderation/video_visualization.py",
    "chars": 12594,
    "preview": "import os\nimport tempfile\nimport subprocess\nimport matplotlib.pyplot as plt\nimport pandas as pd\nimport cv2\nimport numpy "
  },
  {
    "path": "recipes/promptable-content-moderation/visualization.py",
    "chars": 3256,
    "preview": "import pandas as pd\nimport matplotlib.pyplot as plt\nfrom persistence import load_detection_data\nimport argparse\n\n\ndef vi"
  },
  {
    "path": "recipes/promptable-video-redaction/.gitignore",
    "chars": 414,
    "preview": "# Python\n__pycache__/\n*.py[cod]\n*$py.class\n*.so\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\np"
  },
  {
    "path": "recipes/promptable-video-redaction/README.md",
    "chars": 5960,
    "preview": "# Promptable Video Redaction with Moondream\n\nThis tool uses Moondream 2B, a powerful yet lightweight vision-language mod"
  },
  {
    "path": "recipes/promptable-video-redaction/app.py",
    "chars": 7564,
    "preview": "#!/usr/bin/env python3\nimport gradio as gr\nimport os\nfrom main import load_moondream, process_video\nimport shutil\nimport"
  },
  {
    "path": "recipes/promptable-video-redaction/main.py",
    "chars": 23233,
    "preview": "#!/usr/bin/env python3\nimport cv2, os, subprocess, argparse\nfrom PIL import Image\nfrom transformers import AutoModelForC"
  },
  {
    "path": "recipes/promptable-video-redaction/packages.txt",
    "chars": 14,
    "preview": "libvips\nffmpeg"
  },
  {
    "path": "recipes/promptable-video-redaction/requirements.txt",
    "chars": 103,
    "preview": "gradio>=4.0.0\ntorch\ntransformers\nopencv-python\npillow\nnumpy\ntqdm\nffmpeg-python\neinops\npyvips\naccelerate"
  },
  {
    "path": "requirements.txt",
    "chars": 193,
    "preview": "torch==2.8.0\nPillow-SIMD==9.5.0.post2\ntransformers==4.56.1\npyvips-binary==8.16.0\npyvips==2.2.3\naccelerate==1.10.1\ngradio"
  },
  {
    "path": "sample.py",
    "chars": 2747,
    "preview": "import argparse\nfrom queue import Queue\nfrom threading import Thread\n\nimport torch\nfrom PIL import Image\nfrom transforme"
  },
  {
    "path": "tests/test_image_crops.py",
    "chars": 2158,
    "preview": "import numpy as np\nimport torch\nfrom moondream.torch.image_crops import overlap_crop_image, reconstruct_from_crops\n\n\ndef"
  },
  {
    "path": "webcam_gradio_demo.py",
    "chars": 2663,
    "preview": "import argparse\nimport time\nfrom threading import Thread\n\nimport gradio as gr\nimport torch\nfrom transformers import Auto"
  }
]

About this extraction

This page contains the full source code of the vikhyat/moondream GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 66 files (343.1 KB), approximately 82.9k tokens, and a symbol index with 212 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!