[
  {
    "path": ".github/workflows/pylint.yml",
    "content": "name: Lint\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\npermissions:\n  contents: read\njobs:\n  build:\n    runs-on: ubuntu-latest\n    permissions:\n      contents: read\n    strategy:\n      matrix:\n        python-version: [\"3.12\"]  # Run lint checks only on latest Python version\n    steps:\n    - uses: actions/checkout@v4\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v4\n      with:\n        python-version: ${{ matrix.python-version }}\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install autoflake black\n    - name: Checking for unused imports\n      run: |\n        autoflake -c -r .\n    - name: Checking code style\n      run: |\n        black --check .\n"
  },
  {
    "path": ".github/workflows/test.yml",
    "content": "name: Tests\n\non:\n  push:\n    branches: [ main ]\n  pull_request:\n    branches: [ main ]\n\njobs:\n  test:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [\"3.9\", \"3.10\", \"3.11\", \"3.12\"]\n\n    steps:\n    - uses: actions/checkout@v4\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v5\n      with:\n        python-version: ${{ matrix.python-version }}\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install pytest\n        pip install -r requirements.txt\n    - name: Run tests\n      run: |\n        python -m pytest tests/test_image_crops.py -v"
  },
  {
    "path": ".gitignore",
    "content": ".venv\n__pycache__\ncheckpoints\ndata\n/pyproject.toml\npoetry.lock\ndist\nclients/python/moondream/torch\nwandb/\nmoondream_finetune.safetensors\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License."
  },
  {
    "path": "README.md",
    "content": "# 🌔 moondream\n\na tiny vision language model that kicks ass and runs anywhere\n\n[Website](https://moondream.ai/) | [Demo](https://moondream.ai/playground)\n\n## Examples\n\n| Image                  | Example                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               |\n| ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| ![](assets/demo-1.jpg) | **What is the girl doing?**<br>The girl is sitting at a table and eating a large hamburger.<br><br>**What color is the girl's hair?**<br>The girl's hair is white.                                                                                                                                                                                                                                                                                                                                                                                                    
|\n| ![](assets/demo-2.jpg) | **What is this?**<br>This is a computer server rack, which is a device used to store and manage multiple computer servers. The rack is filled with various computer servers, each with their own dedicated space and power supply. The servers are connected to the rack via multiple cables, indicating that they are part of a larger system. The rack is placed on a carpeted floor, and there is a couch nearby, suggesting that the setup is in a living or entertainment area.<br><br>**What is behind the stand?**<br>Behind the stand, there is a brick wall. |\n\n## About\n\nMoondream is a highly efficient open-source vision language model that combines powerful image understanding capabilities with a remarkably small footprint. It's designed to be versatile and accessible, capable of running on a wide range of devices and platforms.\n\nThe project offers two model variants:\n\n- **Moondream 2B**: The primary model with 2 billion parameters, offering robust performance for general-purpose image understanding tasks including captioning, visual question answering, and object detection.\n- **Moondream 0.5B**: A compact 500 million parameter model specifically optimized as a distillation target for edge devices, enabling efficient deployment on resource-constrained hardware while maintaining impressive capabilities.\n\n## How to use\n\nMoondream can be run locally or in the cloud. Please refer to the [Getting Started](https://moondream.ai/c/docs/quickstart) page for details.\n\n## Special thanks\n\n- [Modal](https://modal.com/?utm_source=github&utm_medium=github&utm_campaign=moondream) - Modal lets you run jobs in the cloud by writing just a few lines of Python. Here's an [example of how to run Moondream on Modal](https://github.com/m87-labs/moondream-examples/tree/main/quickstart/modal).\n"
  },
  {
    "path": "batch_generate_example.py",
    "content": "from PIL import Image\nfrom transformers import AutoTokenizer\n\nfrom moondream.hf import LATEST_REVISION, Moondream, detect_device\n\ndevice, dtype = detect_device()\n\nmodel_id = \"vikhyatk/moondream2\"\ntokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)\nmoondream = Moondream.from_pretrained(\n    model_id,\n    revision=LATEST_REVISION,\n    torch_dtype=dtype,\n).to(device=device)\nmoondream.eval()\n\nimage1 = Image.open(\"assets/demo-1.jpg\")\nimage2 = Image.open(\"assets/demo-2.jpg\")\nprompts = [\n    \"What is the girl doing?\",\n    \"What color is the girl's hair?\",\n    \"What is this?\",\n    \"What is behind the stand?\",\n]\n\nanswers = moondream.batch_answer(\n    images=[image1, image1, image2, image2],\n    prompts=prompts,\n    tokenizer=tokenizer,\n)\n\nfor question, answer in zip(prompts, answers):\n    print(f\"Q: {question}\")\n    print(f\"A: {answer}\")\n    print()\n"
  },
  {
    "path": "gradio_demo.py",
    "content": "import argparse\nimport re\nfrom threading import Thread\n\nimport gradio as gr\nimport torch\nfrom PIL import ImageDraw\nfrom torchvision.transforms.v2 import Resize\nfrom transformers import AutoTokenizer, TextIteratorStreamer\n\nfrom moondream.hf import LATEST_REVISION, Moondream, detect_device\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--cpu\", action=\"store_true\")\nargs = parser.parse_args()\n\nif args.cpu:\n    device = torch.device(\"cpu\")\n    dtype = torch.float32\nelse:\n    device, dtype = detect_device()\n    if device != torch.device(\"cpu\"):\n        print(\"Using device:\", device)\n        print(\"If you run into issues, pass the `--cpu` flag to this script.\")\n        print()\n\nmodel_id = \"vikhyatk/moondream2\"\ntokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)\nmoondream = Moondream.from_pretrained(\n    model_id, revision=LATEST_REVISION, torch_dtype=dtype\n).to(device=device)\nmoondream.eval()\n\n\ndef answer_question(img, prompt):\n    image_embeds = moondream.encode_image(img)\n    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)\n    thread = Thread(\n        target=moondream.answer_question,\n        kwargs={\n            \"image_embeds\": image_embeds,\n            \"question\": prompt,\n            \"tokenizer\": tokenizer,\n            \"streamer\": streamer,\n        },\n    )\n    thread.start()\n\n    buffer = \"\"\n    for new_text in streamer:\n        buffer += new_text\n        yield buffer\n\n\ndef extract_floats(text):\n    # Regular expression to match an array of four floating point numbers\n    pattern = r\"\\[\\s*(-?\\d+\\.\\d+)\\s*,\\s*(-?\\d+\\.\\d+)\\s*,\\s*(-?\\d+\\.\\d+)\\s*,\\s*(-?\\d+\\.\\d+)\\s*\\]\"\n    match = re.search(pattern, text)\n    if match:\n        # Extract the numbers and convert them to floats\n        return [float(num) for num in match.groups()]\n    return None  # Return None if no match is found\n\n\ndef 
extract_bbox(text):\n    bbox = None\n    if extract_floats(text) is not None:\n        x1, y1, x2, y2 = extract_floats(text)\n        bbox = (x1, y1, x2, y2)\n    return bbox\n\n\ndef process_answer(img, answer):\n    if extract_bbox(answer) is not None:\n        x1, y1, x2, y2 = extract_bbox(answer)\n        draw_image = Resize(768)(img)\n        width, height = draw_image.size\n        x1, x2 = int(x1 * width), int(x2 * width)\n        y1, y2 = int(y1 * height), int(y2 * height)\n        bbox = (x1, y1, x2, y2)\n        ImageDraw.Draw(draw_image).rectangle(bbox, outline=\"red\", width=3)\n        return gr.update(visible=True, value=draw_image)\n\n    return gr.update(visible=False, value=None)\n\n\nwith gr.Blocks() as demo:\n    gr.Markdown(\n        \"\"\"\n        # 🌔 moondream\n        \"\"\"\n    )\n    with gr.Row():\n        prompt = gr.Textbox(label=\"Input Prompt\", value=\"Describe this image.\", scale=4)\n        submit = gr.Button(\"Submit\")\n    with gr.Row():\n        img = gr.Image(type=\"pil\", label=\"Upload an Image\")\n        with gr.Column():\n            output = gr.Markdown(label=\"Response\")\n            ann = gr.Image(visible=False, label=\"Annotated Image\")\n\n    submit.click(answer_question, [img, prompt], output)\n    prompt.submit(answer_question, [img, prompt], output)\n    output.change(process_answer, [img, output], ann, show_progress=False)\n\ndemo.queue().launch(debug=True)\n"
  },
  {
    "path": "moondream/__init__.py",
    "content": ""
  },
  {
    "path": "moondream/config/config_md05.json",
    "content": "{\n    \"text\": {\n        \"dim\": 1024,\n        \"ff_dim\": 4096,\n        \"n_layers\": 24,\n        \"vocab_size\": 51200,\n        \"max_context\": 2048,\n        \"n_heads\": 16,\n        \"prefix_attn\": 730\n    },\n    \"vision\": {\n        \"enc_dim\": 720,\n        \"enc_patch_size\": 14,\n        \"enc_n_layers\": 27,\n        \"enc_ff_dim\": 2690,\n        \"enc_n_heads\": 10,\n        \"proj_out_dim\": 1024,\n        \"crop_size\": 378,\n        \"in_channels\": 3,\n        \"max_crops\": 12,\n        \"overlap_margin\": 4,\n        \"proj_inner_dim\": 8192\n    },\n    \"region\": {\n        \"dim\": 1024,\n        \"coord_feat_dim\": 256,\n        \"coord_out_dim\": 1024,\n        \"size_feat_dim\": 512,\n        \"size_out_dim\": 2048,\n        \"inner_dim\": 8192\n    },\n    \"tokenizer\": {\n        \"bos_id\": 50256,\n        \"eos_id\": 50256,\n        \"templates\": {\n            \"caption\": {\n                \"short\": [\n                    198,\n                    198,\n                    16438,\n                    8305,\n                    25\n                ],\n                \"normal\": [\n                    198,\n                    198,\n                    24334,\n                    1159,\n                    25\n                ]\n            },\n            \"query\": {\n                \"prefix\": [\n                    198,\n                    198,\n                    24361,\n                    25\n                ],\n                \"suffix\": [\n                    198,\n                    198,\n                    33706,\n                    25\n                ]\n            },\n            \"detect\": {\n                \"prefix\": [\n                    198,\n                    198,\n                    47504,\n                    25\n                ],\n                \"suffix\": [\n                    628\n                ]\n            },\n            \"point\": {\n                
\"prefix\": [\n                    198,\n                    198,\n                    12727,\n                    25\n                ],\n                \"suffix\": [\n                    628\n                ]\n            }\n        }\n    }\n}"
  },
  {
    "path": "moondream/config/config_md2.json",
    "content": "{\n    \"text\": {\n        \"dim\": 2048,\n        \"ff_dim\": 8192,\n        \"n_layers\": 24,\n        \"vocab_size\": 51200,\n        \"max_context\": 2048,\n        \"n_heads\": 32,\n        \"prefix_attn\": 730\n    },\n    \"vision\": {\n        \"enc_dim\": 1152,\n        \"enc_patch_size\": 14,\n        \"enc_n_layers\": 27,\n        \"enc_ff_dim\": 4304,\n        \"enc_n_heads\": 16,\n        \"proj_out_dim\": 2048,\n        \"crop_size\": 378,\n        \"in_channels\": 3,\n        \"max_crops\": 12,\n        \"overlap_margin\": 4,\n        \"proj_inner_dim\": 8192\n    },\n    \"region\": {\n        \"dim\": 2048,\n        \"coord_feat_dim\": 256,\n        \"coord_out_dim\": 1024,\n        \"size_feat_dim\": 512,\n        \"size_out_dim\": 2048,\n        \"inner_dim\": 8192\n    },\n    \"tokenizer\": {\n        \"bos_id\": 50256,\n        \"eos_id\": 50256,\n        \"templates\": {\n            \"caption\": {\n                \"short\": [\n                    198,\n                    198,\n                    16438,\n                    8305,\n                    25\n                ],\n                \"normal\": [\n                    198,\n                    198,\n                    24334,\n                    1159,\n                    25\n                ]\n            },\n            \"query\": {\n                \"prefix\": [\n                    198,\n                    198,\n                    24361,\n                    25\n                ],\n                \"suffix\": [\n                    198,\n                    198,\n                    33706,\n                    25\n                ]\n            },\n            \"detect\": {\n                \"prefix\": [\n                    198,\n                    198,\n                    47504,\n                    25\n                ],\n                \"suffix\": [\n                    628\n                ]\n            },\n            \"point\": {\n                
\"prefix\": [\n                    198,\n                    198,\n                    12727,\n                    25\n                ],\n                \"suffix\": [\n                    628\n                ]\n            }\n        }\n    }\n}"
  },
  {
    "path": "moondream/eval/chartqa.py",
    "content": "import argparse\nimport datasets\nimport torch\n\nfrom tqdm import tqdm\nimport json\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\n\nPREFIX = \"Analyze the chart carefully, consider both visual features and data values, and provide a precise answer without any additional explanation or formatting. \"\n\n\n# https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81\ndef relaxed_correctness(\n    target: str, prediction: str, max_relative_change: float = 0.05\n) -> bool:\n    \"\"\"Calculates relaxed correctness.\n\n    The correctness tolerates certain error ratio defined by max_relative_change.\n    See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1:\n    “Following Methani et al. (2020), we use a relaxed accuracy measure for the\n    numeric answers to allow a minor inaccuracy that may result from the automatic\n    data extraction process. We consider an answer to be correct if it is within\n    5% of the gold answer. 
For non-numeric answers, we still need an exact match\n    to consider an answer to be correct.”\n\n    Args:\n      target: Target string.\n      prediction: Predicted string.\n      max_relative_change: Maximum relative change.\n\n    Returns:\n      Whether the prediction was correct given the specified tolerance.\n    \"\"\"\n\n    def _to_float(text):\n        try:\n            if text.endswith(\"%\"):\n                # Convert percentages to floats.\n                return float(text.rstrip(\"%\")) / 100.0\n            else:\n                return float(text)\n        except ValueError:\n            return None\n\n    prediction = str(prediction)\n    target = str(target)\n    prediction_float = _to_float(prediction)\n    target_float = _to_float(target)\n    if prediction_float is not None and target_float:\n        relative_change = abs(prediction_float - target_float) / abs(target_float)\n        return relative_change <= max_relative_change\n    else:\n        return prediction == target\n\n\ndef eval_chartqa(model, debug=False):\n    dataset = datasets.load_dataset(\"vikhyatk/chartqa\", split=\"test\")\n\n    correct = 0\n    total = 0\n    human_correct = 0\n    human_total = 0\n    results = []\n\n    for row in tqdm(dataset, disable=debug, desc=\"ChartQA\"):\n        image = row[\"image\"]\n        encoded_image = model.encode_image(image)\n\n        result = []\n        for qa in row[\"qa\"]:\n            question = PREFIX + qa[\"question\"]\n            answer = qa[\"answer\"]\n            model_answer = model.query(encoded_image, question)[\"answer\"]\n\n            # Attempt to parse both answers as JSON lists of equal length; otherwise\n            # fall back to comparing the raw answer strings.\n            try:\n                answer_list = json.loads(answer)\n                model_answer_list = json.loads(model_answer)\n                if not (\n                    isinstance(answer_list, list)\n                    and isinstance(model_answer_list, list)\n                    and len(answer_list) == 
len(model_answer_list)\n                ):\n                    raise ValueError\n            except:\n                # If parsing fails or lengths are not equal, compare the strings directly instead\n                answer_list = [answer]\n                model_answer_list = [model_answer]\n\n            total += 1\n            if qa[\"source\"] == \"human\":\n                human_total += 1\n\n            is_correct = False\n            if all(\n                relaxed_correctness(\n                    str(cur_answer).strip().lower(),\n                    str(cur_model_answer).strip().lower(),\n                )\n                for cur_answer, cur_model_answer in zip(answer_list, model_answer_list)\n            ):\n                correct += 1\n                if qa[\"source\"] == \"human\":\n                    human_correct += 1\n                is_correct = True\n            if debug:\n                print(\n                    f\"Correct: {correct}, Total: {total}, Human Correct: {human_correct}, Human Total: {human_total}\"\n                )\n                print(f\"Human Accuracy: {human_correct * 100 / human_total:.2f}\")\n                print(f\"Total Accuracy: {correct * 100 / total:.2f}\")\n                print(\"---------\")\n            result.append(\n                {\n                    \"question\": question,\n                    \"ground_truth\": answer_list,\n                    \"model_answer\": model_answer_list,\n                    \"is_correct\": is_correct,\n                    \"source\": qa[\"source\"],\n                }\n            )\n        results.append(result)\n\n    return {\n        \"human_acc\": human_correct * 100 / human_total,\n        \"total_acc\": correct * 100 / total,\n        \"results\": results,\n    }\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    args 
= parser.parse_args()\n\n    if torch.cuda.is_available():\n        torch.set_default_device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        torch.set_default_device(\"mps\")\n\n    config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(args.model, model)\n    model.compile()\n\n    results = eval_chartqa(model, args.debug)\n    print(f\"Human Accuracy: {results['human_acc']:.2f}\")\n    print(f\"Total Accuracy: {results['total_acc']:.2f}\")\n"
  },
  {
    "path": "moondream/eval/coco_map.py",
    "content": "import argparse\nimport datasets\nimport torch\nimport json\nimport numpy as np\n\nfrom typing import List, Tuple\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\n\n\ncoco_classes = [\n    \"None\",\n    \"person\",\n    \"bicycle\",\n    \"car\",\n    \"motorcycle\",\n    \"airplane\",\n    \"bus\",\n    \"train\",\n    \"truck\",\n    \"boat\",\n    \"traffic light\",\n    \"fire hydrant\",\n    \"street sign\",\n    \"stop sign\",\n    \"parking meter\",\n    \"bench\",\n    \"bird\",\n    \"cat\",\n    \"dog\",\n    \"horse\",\n    \"sheep\",\n    \"cow\",\n    \"elephant\",\n    \"bear\",\n    \"zebra\",\n    \"giraffe\",\n    \"hat\",\n    \"backpack\",\n    \"umbrella\",\n    \"shoe\",\n    \"eye glasses\",\n    \"handbag\",\n    \"tie\",\n    \"suitcase\",\n    \"frisbee\",\n    \"skis\",\n    \"snowboard\",\n    \"sports ball\",\n    \"kite\",\n    \"baseball bat\",\n    \"baseball glove\",\n    \"skateboard\",\n    \"surfboard\",\n    \"tennis racket\",\n    \"bottle\",\n    \"plate\",\n    \"wine glass\",\n    \"cup\",\n    \"fork\",\n    \"knife\",\n    \"spoon\",\n    \"bowl\",\n    \"banana\",\n    \"apple\",\n    \"sandwich\",\n    \"orange\",\n    \"broccoli\",\n    \"carrot\",\n    \"hot dog\",\n    \"pizza\",\n    \"donut\",\n    \"cake\",\n    \"chair\",\n    \"couch\",\n    \"potted plant\",\n    \"bed\",\n    \"mirror\",\n    \"dining table\",\n    \"window\",\n    \"desk\",\n    \"toilet\",\n    \"door\",\n    \"tv\",\n    \"laptop\",\n    \"mouse\",\n    \"remote\",\n    \"keyboard\",\n    \"cell phone\",\n    \"microwave\",\n    \"oven\",\n    \"toaster\",\n    \"sink\",\n    \"refrigerator\",\n    \"blender\",\n    \"book\",\n    \"clock\",\n    \"vase\",\n    \"scissors\",\n    \"teddy bear\",\n    \"hair drier\",\n    \"toothbrush\",\n    \"hair brush\",\n]\n\nCOCO_LABELS = {}\n\nfor i, c in enumerate(coco_classes):\n    COCO_LABELS[i] = c\n\n\ndef calculate_iou(\n    box1: Tuple[float, float, float, float], box2: Tuple[float, float, float, float]\n) -> float:\n    \"\"\"Calculate IoU between two boxes (x1, y1, x2, y2 format)\n\n    Returns 0.0 when the union area is not positive (both boxes degenerate).\n    \"\"\"\n    x1 = max(box1[0], box2[0])\n    y1 = max(box1[1], box2[1])\n    x2 = min(box1[2], box2[2])\n    y2 = min(box1[3], box2[3])\n\n    intersection = max(0, x2 - x1) * max(0, y2 - y1)\n    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])\n    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])\n\n    union = box1_area + box2_area - intersection\n    if union <= 0:\n        # Zero-area boxes cannot overlap; avoid dividing by zero.\n        return 0.0\n    return intersection / union\n\n\ndef calculate_map(\n    ground_truth_boxes: List[List[Tuple[float, float, float, float]]],\n    predicted_boxes: List[List[Tuple[float, float, float, float, float]]],\n    iou_threshold: float = 0.5,\n) -> float:\n    \"\"\"\n    Calculate mAP for object detection\n\n    Args:\n        ground_truth_boxes: List of lists of ground truth boxes per image [(x1, y1, x2, y2)]\n        predicted_boxes: List of lists of predicted boxes per image [(x1, y1, x2, y2, confidence)]\n        iou_threshold: IoU threshold for considering a detection as correct\n\n    Returns:\n        mean Average Precision\n    \"\"\"\n    total_precision = 0\n    num_classes = len(ground_truth_boxes)\n\n    for class_idx in range(num_classes):\n        # Get all predictions and ground truths for this class\n        gt_boxes = ground_truth_boxes[class_idx]\n        pred_boxes = predicted_boxes[class_idx]\n\n        # Sort predictions by confidence\n        pred_boxes = sorted(pred_boxes, key=lambda x: x[4], reverse=True)\n\n        # Initialize arrays for precision-recall calculation\n        num_gt = len(gt_boxes)\n        if num_gt == 0:\n            continue\n\n        tp = np.zeros(len(pred_boxes))\n        fp = np.zeros(len(pred_boxes))\n        gt_matched = [False] * num_gt\n\n        # Match each prediction to ground truth\n        for pred_idx, pred_box in enumerate(pred_boxes):\n            max_iou = 0\n            max_idx = -1\n\n            # Find best matching ground truth box\n            for gt_idx, gt_box in enumerate(gt_boxes):\n                if gt_matched[gt_idx]:\n                    continue\n\n                iou = calculate_iou(pred_box[:4], gt_box)\n                if iou > max_iou:\n                    max_iou = iou\n                    max_idx = gt_idx\n\n            # If IoU exceeds threshold, count as true positive\n            if max_iou >= iou_threshold:\n                tp[pred_idx] = 1\n                gt_matched[max_idx] = True\n            else:\n                fp[pred_idx] = 1\n\n        # Calculate cumulative precision and recall\n        cumsum_tp = np.cumsum(tp)\n        cumsum_fp = np.cumsum(fp)\n        recalls = cumsum_tp / num_gt\n        precisions = cumsum_tp / (cumsum_tp + cumsum_fp)\n\n        # Calculate average precision using all points\n        ap = 0\n        for t in np.arange(0, 1.1, 0.1):\n            if np.sum(recalls >= t) == 0:\n                p = 0\n            else:\n                p = np.max(precisions[recalls >= t])\n            ap += p / 11\n\n        total_precision += ap\n\n    return total_precision / num_classes\n\n\ndef get_total_map(results_by_label, frequency_by_label):\n    \"\"\"Sum every per-label result and divide by the summed label frequencies.\"\"\"\n    total_count = 0\n    total_map = 0\n    for results, frequency in zip(\n        results_by_label.values(), frequency_by_label.values()\n    ):\n        cur_total_map = sum(results)\n        total_map += cur_total_map\n        total_count += frequency\n    return total_map / total_count\n\n\ndef eval_coco_map(model, iou_threshold=0.5, debug=False):\n    \"\"\"Run per-image, per-label detection and return {\"total_map\": ...}.\"\"\"\n    dataset = datasets.load_dataset(\n        \"moondream/coco-val-2017-bbox-cleaned\", split=\"validation\"\n    )\n\n    total = 0\n    results_by_label = {}  # map to list of raw map results for each label\n    frequency_by_label = {}  # how many images contain a given label\n\n    for row in tqdm(dataset, disable=debug, desc=\"COCO mAP\"):\n        width = row[\"image\"].width\n        height = row[\"image\"].height\n        total += 1\n\n        objects = json.loads(row[\"objects\"])\n\n        gt_label_to_boxes = {}\n\n        for bbox, label in zip(objects[\"bbox\"], objects[\"label\"]):\n            if label not in gt_label_to_boxes:\n                gt_label_to_boxes[label] = []\n            x1, y1, w, h = bbox\n            gt_label_to_boxes[label].append((x1, y1, x1 + w, y1 + h))\n\n        unique_labels = [label for label in set(objects[\"label\"])]\n\n        # Encode the image once and reuse the encoding for every label query;\n        # skip the encode entirely when the image has no labelled objects.\n        encoded_image = model.encode_image(row[\"image\"]) if unique_labels else None\n\n        for label in unique_labels:\n            model_answer = model.detect(encoded_image, COCO_LABELS[label])[\"objects\"]\n\n            moondream_boxes = []\n\n            for box in model_answer:\n                moondream_boxes.append(\n                    (\n                        box[\"x_min\"] * width,\n                        box[\"y_min\"] * height,\n                        box[\"x_max\"] * width,\n                        box[\"y_max\"] * height,\n                        1.0,  # Using default confidence of 1.0\n                    )\n                )\n            map_result = calculate_map(\n                [gt_label_to_boxes[label]], [moondream_boxes], iou_threshold\n            )\n            if debug and map_result == 0:\n                print(\n                    f\"0 Map result for index {total} and label {label} ({COCO_LABELS[label]})\"\n                )\n\n            if label not in results_by_label:\n                results_by_label[label] = []\n            results_by_label[label].append(map_result)\n\n            if label not in frequency_by_label:\n                frequency_by_label[label] = 0\n            frequency_by_label[label] += 1\n\n        if debug and total % 100 == 0:\n            print(\n                f\"Total map: {get_total_map(results_by_label, frequency_by_label)*100:.2f}, ({total} images)\"\n            )\n\n    return {\n        # \"results_by_label\": results_by_label,\n        # \"frequency_by_label\": frequency_by_label,\n        \"total_map\": get_total_map(results_by_label, frequency_by_label),\n    }\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    args = parser.parse_args()\n\n    # This repo doesn't have moondream deps we need\n    if torch.cuda.is_available():\n        torch.set_default_device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        torch.set_default_device(\"mps\")\n\n    config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(args.model, model)\n    model.compile()\n\n    result = eval_coco_map(model, 0.5, args.debug)\n\n    print(f\"Overall MAP: {result['total_map']*100:.2f}\")\n"
  },
  {
    "path": "moondream/eval/countbenchqa.py",
    "content": "import argparse\nimport datasets\nimport torch\n\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\n\nPREFIX = \"Look at the image carefully and count the objects. Answer with just a number, without any additional text. \"\n\n\ndef eval_countbenchqa(model, debug=False):\n    dataset = datasets.load_dataset(\"vikhyatk/CountBenchQA\", split=\"test\")\n\n    correct = 0\n    total = 0\n    results = []\n\n    for row in tqdm(dataset, disable=debug, desc=\"CountBenchQA\"):\n        image = row[\"image\"]\n        encoded_image = model.encode_image(image)\n\n        question = PREFIX + row[\"question\"]\n        answer = str(row[\"number\"])\n        model_answer = model.query(encoded_image, question)[\"answer\"]\n        is_correct = model_answer.strip().lower() == answer.strip().lower()\n\n        results.append(\n            {\n                \"question\": question,\n                \"ground_truth\": answer,\n                \"model_answer\": model_answer,\n                \"is_correct\": is_correct,\n            }\n        )\n\n        total += 1\n        if is_correct:\n            correct += 1\n        elif debug:\n            print(f\"Question: {row['question']}\")\n            print(f\"Answer: {answer}\")\n            print(f\"Model Answer: {model_answer}\")\n        if debug:\n            print(f\"Correct: {correct}, Total: {total}\")\n            print(f\"Accuracy: {correct * 100 / total:.2f}\")\n            print(\"---------\")\n\n    return {\n        \"acc\": correct * 100 / total,\n        \"correct_count\": correct,\n        \"total_count\": total,\n        \"results\": results,\n    }\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    args = parser.parse_args()\n\n    if 
torch.cuda.is_available():\n        torch.set_default_device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        torch.set_default_device(\"mps\")\n\n    config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(args.model, model)\n\n    result = eval_countbenchqa(model, args.debug)\n\n    print(f\"Accuracy: {result['acc']:.2f}\")\n    print(f\"Correct: {result['correct_count']}, Total: {result['total_count']}\")\n"
  },
  {
    "path": "moondream/eval/docvqa.py",
    "content": "import argparse\nimport editdistance\nfrom datasets import load_dataset\nfrom tqdm import tqdm\nimport torch\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\n\nSUFFIX = \" The answer should be a short text span taken verbatim from the document.\"\n\n\ndef get_anls(s1, s2):\n    \"\"\"Score two strings as 1 - edit_distance / max_length, zeroed below 0.5.\n\n    Both strings are lower-cased and stripped first. Two empty strings score 1.0.\n    \"\"\"\n    s1 = s1.lower().strip()\n    s2 = s2.lower().strip()\n    max_len = max(len(s1), len(s2))\n    if max_len == 0:\n        # Both strings are empty, hence identical; avoid dividing by zero.\n        return 1.0\n    iou = 1 - editdistance.eval(s1, s2) / max_len\n    anls = iou if iou >= 0.5 else 0.0\n    return anls\n\n\ndef eval_docvqa(model, debug=False):\n    docvqa_val = load_dataset(\"vikhyatk/docvqa-val\", split=\"validation\")\n\n    scores = []\n    results = []\n\n    for row in tqdm(docvqa_val, disable=debug, desc=\"DocVQA\"):\n        image = row[\"image\"]\n        encoded_image = model.encode_image(image)\n\n        result = []\n        for qa in row[\"qa\"]:\n            question = qa[\"question\"]\n            answers = qa[\"answers\"]\n            prompt = question + SUFFIX\n\n            model_answer = model.query(encoded_image, prompt)[\"answer\"]\n            anls = max(get_anls(model_answer, gt) for gt in answers)\n            scores.append(anls)\n            result.append(\n                {\n                    \"question\": question,\n                    \"ground_truth\": answers,\n                    \"model_answer\": model_answer,\n                    \"anls\": anls,\n                }\n            )\n\n            if debug:\n                print(f\"Question: {question}\")\n                print(f\"Ground Truth: {answers}\")\n                print(f\"Model Answer: {model_answer}\")\n                print(f\"ANLS: {anls}\")\n                print(f\"Current Average ANLS: {sum(scores) / len(scores):.4f}\")\n                print(\"---------\")\n        results.append(result)\n\n    return {\n        \"anls\": sum(scores) / len(scores),\n        \"results\": results,\n    }\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    args = parser.parse_args()\n\n    if torch.cuda.is_available():\n        torch.set_default_device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        torch.set_default_device(\"mps\")\n\n    config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(args.model, model)\n    model.compile()\n\n    result = eval_docvqa(model, args.debug)\n\n    print(f\"ANLS: {result['anls']:.4f}\")\n"
  },
  {
    "path": "moondream/eval/eval_all.py",
    "content": "import argparse\nimport torch\n\nfrom pprint import pprint\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\n\nfrom .countbenchqa import eval_countbenchqa\nfrom .pope import evaluate_pope\nfrom .realworldqa import eval_realworldqa\nfrom .chartqa import eval_chartqa\nfrom .textvqa import eval_textvqa\nfrom .docvqa import eval_docvqa\nfrom .mmstar import eval_mmstar\nfrom .coco_map import eval_coco_map\nfrom .naturalbench import eval_naturalbench\nfrom .tallyqa import eval_tallyqa\n\n\ndef create_model(ckpt_path):\n    config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(ckpt_path, model)\n    model.compile()\n    return model\n\n\ndef eval_all(model, skip=[]):\n    evals = {\n        \"countbenchqa\": eval_countbenchqa,\n        \"pope\": evaluate_pope,\n        \"realworldqa\": eval_realworldqa,\n        \"chartqa\": eval_chartqa,\n        \"mmstar\": eval_mmstar,\n        \"docvqa\": eval_docvqa,\n        \"coco_map\": eval_coco_map,\n        \"textvqa\": eval_textvqa,\n        \"naturalbench\": eval_naturalbench,\n        \"tallyqa\": eval_tallyqa,\n    }\n\n    for b in skip:\n        del evals[b]\n\n    results = {}\n    for name, eval_fn in evals.items():\n        results[name] = eval_fn(model)\n        pprint({k: v for k, v in results[name].items() if k != \"results\"})\n\n    return results\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    args = parser.parse_args()\n\n    if torch.cuda.is_available():\n        torch.set_default_device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        torch.set_default_device(\"mps\")\n\n    model = create_model(args.model)\n    eval_all(model)\n"
  },
  {
    "path": "moondream/eval/gazefollow.py",
    "content": "import torch\nimport datasets\nimport math\n\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\n\n\ndef eval_gazefollow(model, debug=False):\n    dataset = datasets.load_dataset(\"vikhyatk/gazefollow\", split=\"test\")\n\n    mean_l2_error = []\n    min_l2_error = []\n    total = 0\n\n    for i, row in tqdm(enumerate(dataset), total=len(dataset)):\n        heads = []\n\n        for gaze in row[\"gazes\"]:\n            head_bbox = gaze[\"head_bbox\"]  # xmin, ymin, xmax, ymax\n            eye_coord = (gaze[\"eye\"][\"x\"], gaze[\"eye\"][\"y\"])\n            mean_target_gaze = (gaze[\"gaze\"][\"x\"], gaze[\"gaze\"][\"y\"])\n\n            # Check if a head already exists with the same approximate bbox.\n            # If so, use that head instead of creating a new one.\n            for head in heads:\n                if (\n                    abs(head[\"head_bbox\"][\"xmin\"] - head_bbox[\"xmin\"]) < 0.001\n                    and abs(head[\"head_bbox\"][\"xmax\"] - head_bbox[\"xmax\"]) < 0.001\n                    and abs(head[\"head_bbox\"][\"ymin\"] - head_bbox[\"ymin\"]) < 0.001\n                    and abs(head[\"head_bbox\"][\"ymax\"] - head_bbox[\"ymax\"]) < 0.001\n                ):\n                    head[\"gazes\"].append(mean_target_gaze)\n                    break\n            else:\n                heads.append(\n                    {\n                        \"head_bbox\": head_bbox,\n                        \"eye_coord\": eye_coord,\n                        \"gazes\": [mean_target_gaze],\n                    }\n                )\n\n        for head in heads:\n            pred_gaze = model.detect_gaze(\n                row[\"image\"],\n                eye=head[\"eye_coord\"],\n                face={\n                    \"x_min\": head[\"head_bbox\"][\"xmin\"],\n                    \"y_min\": 
head[\"head_bbox\"][\"ymin\"],\n                    \"x_max\": head[\"head_bbox\"][\"xmax\"],\n                    \"y_max\": head[\"head_bbox\"][\"ymax\"],\n                },\n                unstable_settings={\"force_detect\": True},\n            )[\"gaze\"]\n\n            mean_target_gaze = (\n                sum(gaze[0] for gaze in head[\"gazes\"]) / len(head[\"gazes\"]),\n                sum(gaze[1] for gaze in head[\"gazes\"]) / len(head[\"gazes\"]),\n            )\n            mean_l2 = math.sqrt(\n                (mean_target_gaze[0] - pred_gaze[\"x\"]) ** 2\n                + (mean_target_gaze[1] - pred_gaze[\"y\"]) ** 2\n            )\n            min_l2 = min(\n                math.sqrt(\n                    (target_gaze[0] - pred_gaze[\"x\"]) ** 2\n                    + (target_gaze[1] - pred_gaze[\"y\"]) ** 2\n                )\n                for target_gaze in head[\"gazes\"]\n            )\n\n            mean_l2_error.append(mean_l2)\n            min_l2_error.append(min_l2)\n            total += 1\n\n            if i % 100 == 0 and debug:\n                print(\"Mean L2 error:\", sum(mean_l2_error) / total)\n                print(\"Min L2 error:\", sum(min_l2_error) / total)\n\n    return {\n        \"mean_l2\": sum(mean_l2_error) / total,\n        \"min_l2\": sum(min_l2_error) / total,\n    }\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    args = parser.parse_args()\n\n    if torch.cuda.is_available():\n        torch.set_default_device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        torch.set_default_device(\"mps\")\n\n    config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(args.model, model)\n\n    results = eval_gazefollow(model, debug=args.debug)\n\n    print(f\"Mean L2 error: {results['mean_l2']:.4f}\")\n    
print(f\"Min L2 error: {results['min_l2']:.4f}\")\n"
  },
  {
    "path": "moondream/eval/mmstar.py",
    "content": "import datasets\nimport torch\n\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\n\nSUFFIX = \" Please answer directly with only the letter of the correct option and nothing else.\"\n\n\ndef eval_mmstar(model, debug=False):\n    dataset = datasets.load_dataset(\"Lin-Chen/MMStar\", split=\"val\")\n\n    correct = 0\n    total = 0\n    category_stats = {}\n    results = []\n\n    for row in tqdm(dataset, disable=debug, desc=\"MMStar\"):\n        image = row[\"image\"]\n        question = row[\"question\"] + SUFFIX\n        answer = row[\"answer\"]\n        model_answer = model.query(image, question)[\"answer\"]\n        is_correct = model_answer.strip().lower() == answer.strip().lower()\n\n        category = f\"{row['category']} / {row['l2_category']}\"\n        if category not in category_stats:\n            category_stats[category] = {\"correct\": 0, \"total\": 0}\n\n        total += 1\n        category_stats[category][\"total\"] += 1\n\n        results.append(\n            {\n                \"question\": question,\n                \"ground_truth\": answer,\n                \"model_answer\": model_answer,\n                \"is_correct\": is_correct,\n                \"category\": category,\n            }\n        )\n\n        if is_correct:\n            correct += 1\n            category_stats[category][\"correct\"] += 1\n        elif debug:\n            print(f\"Index: {row['index']}\")\n            print(f\"Question: {row['question']}\")\n            print(f\"Answer: {answer}\")\n            print(f\"Model Answer: {model_answer}\")\n        if debug:\n            print(f\"Correct: {correct}, Total: {total}\")\n            print(f\"Accuracy: {correct * 100 / total:.2f}\")\n            print(\"Results by category:\")\n            for category, stats in category_stats.items():\n                acc = stats[\"correct\"] * 100 / 
stats[\"total\"]\n                print(f\"{category}: {stats['correct']}/{stats['total']} = {acc:.2f}%\")\n            print(\"---------\")\n\n    return {\n        \"acc\": correct * 100 / total,\n        \"correct_count\": correct,\n        \"total_count\": total,\n        \"category_stats\": category_stats,\n        \"results\": results,\n    }\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    args = parser.parse_args()\n\n    if torch.cuda.is_available():\n        torch.set_default_device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        torch.set_default_device(\"mps\")\n\n    config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(args.model, model)\n    model.compile()\n\n    result = eval_mmstar(model, args.debug)\n\n    print(f\"Correct: {result['correct_count']}, Total: {result['total_count']}\")\n    print(f\"Accuracy: {result['acc']:.2f}\")\n\n    print(\"\\nResults by category:\")\n    for category, stats in result[\"category_stats\"].items():\n        acc = stats[\"correct\"] * 100 / stats[\"total\"]\n        print(f\"{category}: {stats['correct']}/{stats['total']} = {acc:.2f}%\")\n"
  },
  {
    "path": "moondream/eval/naturalbench.py",
    "content": "from datasets import load_dataset\nfrom tqdm import tqdm\nimport torch\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\n\n\ndef eval_naturalbench(model, debug=False):\n    \"\"\"Evaluate a model on NaturalBench.\n\n    Each row pairs two images with two questions, giving four (image, question)\n    queries. Answers are compared to the expected answers after stripping and\n    lower-casing.\n\n    Returns a dict of accuracies as fractions in [0, 1]:\n      overall_acc: share of individual queries answered correctly\n      image_acc: share of images with both questions answered correctly\n      question_acc: share of questions answered correctly for both images\n      group_acc: share of rows with all four queries answered correctly\n    \"\"\"\n    # Yes, the benchmark test set is stored in the 'train' split...\n    dataset = load_dataset(\"BaiqiL/NaturalBench\", split=\"train\")\n\n    acc = []\n    q_acc = []\n    i_acc = []\n    g_acc = []\n\n    for row in tqdm(dataset, disable=debug, desc=\"NaturalBench\"):\n        if row[\"Question_Type\"] == \"yes_no\":\n            suffix = \" Answer yes or no.\"\n        else:\n            suffix = \"\"\n\n        # Query order: (Image_0, Q0), (Image_1, Q0), (Image_0, Q1), (Image_1, Q1).\n        images = [row[\"Image_0\"], row[\"Image_1\"], row[\"Image_0\"], row[\"Image_1\"]]\n        prompts = [\n            row[\"Question_0\"] + suffix,\n            row[\"Question_0\"] + suffix,\n            row[\"Question_1\"] + suffix,\n            row[\"Question_1\"] + suffix,\n        ]\n        # expected[k] must line up with images[k] / prompts[k].\n        expected = [\n            row[\"Image_0_Question_0\"].strip().lower(),\n            row[\"Image_1_Question_0\"].strip().lower(),\n            row[\"Image_0_Question_1\"].strip().lower(),\n            row[\"Image_1_Question_1\"].strip().lower(),\n        ]\n\n        answers = []\n        for img, prompt in zip(images, prompts):\n            encoded_image = model.encode_image(img)\n            answer = model.query(encoded_image, prompt)[\"answer\"]\n            answers.append(answer.strip().lower())\n\n        if debug:\n            for i, (q, a, e) in enumerate(zip(prompts, answers, expected)):\n                print(f\"Q{i}: {q}\")\n                print(f\"Model: {a}\")\n                print(f\"Expected: {e}\")\n                print(f\"Correct: {a == e}\")\n                print(\"---\")\n\n        acc.append(answers[0] == expected[0])\n        acc.append(answers[1] == expected[1])\n        acc.append(answers[2] == expected[2])\n        acc.append(answers[3] == expected[3])\n\n        i_acc.append(answers[0] == expected[0] and answers[2] == expected[2])\n        i_acc.append(answers[1] == expected[1] and answers[3] == expected[3])\n\n        q_acc.append(answers[0] == expected[0] and answers[1] == expected[1])\n        q_acc.append(answers[2] == expected[2] and answers[3] == expected[3])\n\n        g_acc.append(\n            answers[0] == expected[0]\n            and answers[1] == expected[1]\n            and answers[2] == expected[2]\n            and answers[3] == expected[3]\n        )\n\n        if debug:\n            print(f\"Current Overall Accuracy: {sum(acc) / len(acc):.4f}\")\n            print(f\"Current Image Accuracy: {sum(i_acc) / len(i_acc):.4f}\")\n            print(f\"Current Question Accuracy: {sum(q_acc) / len(q_acc):.4f}\")\n            print(f\"Current Group Accuracy: {sum(g_acc) / len(g_acc):.4f}\")\n            print(\"=========\")\n\n    return {\n        \"overall_acc\": sum(acc) / len(acc),\n        \"image_acc\": sum(i_acc) / len(i_acc),\n        \"question_acc\": sum(q_acc) / len(q_acc),\n        \"group_acc\": sum(g_acc) / len(g_acc),\n    }\n\n\nif __name__ == \"__main__\":\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    args = parser.parse_args()\n\n    if torch.cuda.is_available():\n        torch.set_default_device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        torch.set_default_device(\"mps\")\n\n    config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(args.model, model)\n    model.compile()\n\n    results = eval_naturalbench(model, debug=args.debug)\n\n    print(f\"Overall Accuracy: {results['overall_acc']:.4f}\")\n    print(f\"Image Accuracy: {results['image_acc']:.4f}\")\n    print(f\"Question Accuracy: {results['question_acc']:.4f}\")\n    print(f\"Group Accuracy: {results['group_acc']:.4f}\")\n"
  },
  {
    "path": "moondream/eval/pope.py",
    "content": "import argparse\nfrom datasets import load_dataset\nfrom tqdm import tqdm\nimport torch\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\n\n\ndef evaluate_pope(model, debug=False):\n    pope_dataset = load_dataset(\"vikhyatk/POPE\", split=\"test\")\n\n    stats = {\n        \"random\": (0, 0),\n        \"popular\": (0, 0),\n        \"adversarial\": (0, 0),\n    }\n\n    for row in tqdm(pope_dataset, disable=debug, desc=\"POPE\"):\n        image = row[\"image\"]\n        encoded_image = model.encode_image(image)\n        for split in [\"adversarial\", \"popular\", \"random\"]:\n            for qa in row[split]:\n                question = qa[\"question\"]\n                answer = qa[\"answer\"]\n                prompt = f\"{question}\\nAnswer yes or no.\"\n                model_answer = model.query(encoded_image, prompt)[\"answer\"].strip()\n\n                if debug:\n                    print(f\"Split: {split}\")\n                    print(f\"Question: {question}\")\n                    print(f\"Model: {model_answer}\")\n                    print(f\"Expected: {answer}\")\n                    print(f\"Correct: {model_answer.lower() == answer.lower()}\")\n                    print(\"---\")\n\n                if model_answer.lower() == answer.lower():\n                    stats[split] = (stats[split][0] + 1, stats[split][1] + 1)\n                else:\n                    stats[split] = (stats[split][0], stats[split][1] + 1)\n\n                if debug:\n                    for s in stats:\n                        if stats[s][1] > 0:\n                            print(\n                                f\"{s.capitalize()}: {stats[s][0]}/{stats[s][1]} = {stats[s][0] * 100.0 / stats[s][1]:.2f}%\"\n                            )\n                    print(\"=========\")\n\n    return {\n        \"random\": stats[\"random\"][0] * 100.0 / stats[\"random\"][1],\n 
       \"popular\": stats[\"popular\"][0] * 100.0 / stats[\"popular\"][1],\n        \"adversarial\": stats[\"adversarial\"][0] * 100.0 / stats[\"adversarial\"][1],\n    }\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    args = parser.parse_args()\n\n    if torch.cuda.is_available():\n        torch.set_default_device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        torch.set_default_device(\"mps\")\n\n    config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(args.model, model)\n\n    result = evaluate_pope(model, args.debug)\n\n    print(f\"Random Accuracy: {result['random']:.2f}\")\n    print(f\"Popular Accuracy: {result['popular']:.2f}\")\n    print(f\"Adversarial Accuracy: {result['adversarial']:.2f}\")\n"
  },
  {
    "path": "moondream/eval/realworldqa.py",
    "content": "import argparse\nimport datasets\nimport torch\n\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\n\n\ndef eval_realworldqa(model, debug=False):\n    dataset = datasets.load_dataset(\"lmms-lab/RealWorldQA\", split=\"test\")\n\n    correct = 0\n    total = 0\n    results = []\n\n    for row in tqdm(dataset, disable=debug, desc=\"RealWorldQA\"):\n        image = row[\"image\"]\n        question = row[\"question\"]\n        answer = row[\"answer\"]\n        model_answer = model.query(image, question)[\"answer\"]\n        is_correct = model_answer.strip().lower() == answer.strip().lower()\n\n        results.append(\n            {\n                \"question\": question,\n                \"ground_truth\": answer,\n                \"model_answer\": model_answer,\n                \"is_correct\": is_correct,\n            }\n        )\n\n        total += 1\n        if is_correct:\n            correct += 1\n        elif debug:\n            print(f\"Image: {row['image_path']}\")\n            print(f\"Question: {question}\")\n            print(f\"Answer: {answer}\")\n            print(f\"Model Answer: {model_answer}\")\n        if debug:\n            print(f\"Correct: {correct}, Total: {total}\")\n            print(f\"Accuracy: {correct * 100 / total:.2f}\")\n            print(\"---------\")\n\n    return {\n        \"acc\": correct * 100 / total,\n        \"correct_count\": correct,\n        \"total_count\": total,\n        \"results\": results,\n    }\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    args = parser.parse_args()\n\n    if torch.cuda.is_available():\n        torch.set_default_device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        
torch.set_default_device(\"mps\")\n\n    config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(args.model, model)\n    model.compile()\n\n    result = eval_realworldqa(model, args.debug)\n\n    print(f\"Accuracy: {result['acc']:.2f}\")\n    print(f\"Correct: {result['correct_count']} / {result['total_count']}\")\n"
  },
  {
    "path": "moondream/eval/tallyqa.py",
    "content": "import argparse\nimport datasets\nimport torch\n\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\n\nPREFIX = \"Look at the image carefully and count the objects. Answer with just a number, without any additional text. \"\n\n\ndef eval_tallyqa(model, debug=False):\n    dataset = datasets.load_dataset(\n        \"vikhyatk/tallyqa-test\",\n        split=\"test\",\n        download_config=datasets.DownloadConfig(num_proc=16),\n    )\n\n    total = 0\n    total_simple = 0\n    correct = 0\n    correct_simple = 0\n\n    for row in tqdm(dataset, disable=args.debug):\n        image = row[\"image\"]\n        encoded_image = model.encode_image(image)\n\n        for qa in row[\"qa\"]:\n            question = PREFIX + qa[\"question\"]\n            answer = str(qa[\"answer\"])\n            is_simple = qa[\"is_simple\"]\n\n            model_answer = model.query(encoded_image, question)[\"answer\"]\n\n            total += 1\n            if model_answer.strip().lower() == answer.strip().lower():\n                correct += 1\n            elif args.debug:\n                print(f\"Question: {qa['question']}\")\n                print(f\"Answer: {answer}\")\n                print(f\"Model Answer: {model_answer}\")\n\n            if is_simple:\n                total_simple += 1\n                if model_answer.strip().lower() == answer.strip().lower():\n                    correct_simple += 1\n\n            if args.debug:\n                print(f\"Simple - Correct: {correct_simple}, Total: {total_simple}\")\n                print(f\"Simple Accuracy: {correct_simple * 100 / total_simple:.2f}\")\n                print(f\"All - Correct: {correct}, Total: {total}\")\n                print(f\"All Accuracy: {correct * 100 / total:.2f}\")\n                print(\"---------\")\n\n    return {\n        \"simple_acc\": correct_simple * 100 / total_simple,\n     
   \"full_acc\": correct * 100 / total,\n    }\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    args = parser.parse_args()\n\n    if torch.cuda.is_available():\n        torch.set_default_device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        torch.set_default_device(\"mps\")\n\n    config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(args.model, model)\n    model.compile()\n\n    result = eval_tallyqa(model, args.debug)\n\n    print(f\"Simple acc: {result['simple_acc']:.2f}\")\n    print(f\"Full acc: {result['full_acc']:.2f}\")\n"
  },
  {
    "path": "moondream/eval/textvqa.py",
    "content": "import argparse\nimport datasets\nimport torch\n\nfrom tqdm import tqdm\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\nfrom .utils import VQAScorer\n\nPREFIX_TEXTVQA = \"Read the text in the image and provide a brief lowercase answer. Respond 'unanswerable' only if there is no plausible answer. \"\n\n\ndef eval_textvqa(model, debug=False):\n    dataset = datasets.load_dataset(\"vikhyatk/textvqa_val\", split=\"validation\")\n\n    scorer = VQAScorer()\n\n    total_score = 0\n    total_samples = 0\n    results = []\n\n    for row in tqdm(dataset, disable=debug, desc=\"TextVQA\"):\n        image = row[\"image\"]\n        encoded_image = model.encode_image(image)\n        question = PREFIX_TEXTVQA + row[\"question\"]\n        model_answer = model.query(encoded_image, question)[\"answer\"]\n\n        score = scorer.compute_score(model_answer, row[\"answers\"])\n        total_score += score\n        total_samples += 1\n\n        results.append(\n            {\n                \"question\": question,\n                \"ground_truth\": row[\"answers\"],\n                \"model_answer\": model_answer,\n                \"score\": score,\n            }\n        )\n\n        if debug:\n            print(f\"Question: {row['question']}\")\n            print(f\"Ground Truth Answers: {row['answers']}\")\n            print(f\"Model Answer: {model_answer}\")\n            print(f\"Score: {score}\")\n            print(f\"Running Average Score: {total_score * 100 / total_samples:.2f}\")\n            print(\"---------\")\n\n    return {\"score\": total_score * 100 / total_samples, \"results\": results}\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    args = parser.parse_args()\n\n    if torch.cuda.is_available():\n        
torch.set_default_device(\"cuda\")\n    elif torch.backends.mps.is_available():\n        torch.set_default_device(\"mps\")\n\n    config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(args.model, model)\n    model.compile()\n\n    result = eval_textvqa(model, args.debug)\n\n    print(f\"Score: {result['score']}\")\n"
  },
  {
    "path": "moondream/eval/utils.py",
    "content": "import re\nfrom typing import List\n\n\nclass VQAScorer:\n    def __init__(self):\n        self.contractions = {\n            \"aint\": \"ain't\",\n            \"arent\": \"aren't\",\n            \"cant\": \"can't\",\n            \"couldve\": \"could've\",\n            \"couldnt\": \"couldn't\",\n            \"couldn'tve\": \"couldn't've\",\n            \"couldnt've\": \"couldn't've\",\n            \"didnt\": \"didn't\",\n            \"doesnt\": \"doesn't\",\n            \"dont\": \"don't\",\n            \"hadnt\": \"hadn't\",\n            \"hadnt've\": \"hadn't've\",\n            \"hadn'tve\": \"hadn't've\",\n            \"hasnt\": \"hasn't\",\n            \"havent\": \"haven't\",\n            \"hed\": \"he'd\",\n            \"hed've\": \"he'd've\",\n            \"he'dve\": \"he'd've\",\n            \"hes\": \"he's\",\n            \"howd\": \"how'd\",\n            \"howll\": \"how'll\",\n            \"hows\": \"how's\",\n            \"Id've\": \"I'd've\",\n            \"I'dve\": \"I'd've\",\n            \"Im\": \"I'm\",\n            \"Ive\": \"I've\",\n            \"isnt\": \"isn't\",\n            \"itd\": \"it'd\",\n            \"itd've\": \"it'd've\",\n            \"it'dve\": \"it'd've\",\n            \"itll\": \"it'll\",\n            \"let's\": \"let's\",\n            \"maam\": \"ma'am\",\n            \"mightnt\": \"mightn't\",\n            \"mightnt've\": \"mightn't've\",\n            \"mightn'tve\": \"mightn't've\",\n            \"mightve\": \"might've\",\n            \"mustnt\": \"mustn't\",\n            \"mustve\": \"must've\",\n            \"neednt\": \"needn't\",\n            \"notve\": \"not've\",\n            \"oclock\": \"o'clock\",\n            \"oughtnt\": \"oughtn't\",\n            \"ow's'at\": \"'ow's'at\",\n            \"'ows'at\": \"'ow's'at\",\n            \"'ow'sat\": \"'ow's'at\",\n            \"shant\": \"shan't\",\n            \"shed've\": \"she'd've\",\n            \"she'dve\": \"she'd've\",\n            \"she's\": 
\"she's\",\n            \"shouldve\": \"should've\",\n            \"shouldnt\": \"shouldn't\",\n            \"shouldnt've\": \"shouldn't've\",\n            \"shouldn'tve\": \"shouldn't've\",\n            \"somebody'd\": \"somebodyd\",\n            \"somebodyd've\": \"somebody'd've\",\n            \"somebody'dve\": \"somebody'd've\",\n            \"somebodyll\": \"somebody'll\",\n            \"somebodys\": \"somebody's\",\n            \"someoned\": \"someone'd\",\n            \"someoned've\": \"someone'd've\",\n            \"someone'dve\": \"someone'd've\",\n            \"someonell\": \"someone'll\",\n            \"someones\": \"someone's\",\n            \"somethingd\": \"something'd\",\n            \"somethingd've\": \"something'd've\",\n            \"something'dve\": \"something'd've\",\n            \"somethingll\": \"something'll\",\n            \"thats\": \"that's\",\n            \"thered\": \"there'd\",\n            \"thered've\": \"there'd've\",\n            \"there'dve\": \"there'd've\",\n            \"therere\": \"there're\",\n            \"theres\": \"there's\",\n            \"theyd\": \"they'd\",\n            \"theyd've\": \"they'd've\",\n            \"they'dve\": \"they'd've\",\n            \"theyll\": \"they'll\",\n            \"theyre\": \"they're\",\n            \"theyve\": \"they've\",\n            \"twas\": \"'twas\",\n            \"wasnt\": \"wasn't\",\n            \"wed've\": \"we'd've\",\n            \"we'dve\": \"we'd've\",\n            \"weve\": \"we've\",\n            \"werent\": \"weren't\",\n            \"whatll\": \"what'll\",\n            \"whatre\": \"what're\",\n            \"whats\": \"what's\",\n            \"whatve\": \"what've\",\n            \"whens\": \"when's\",\n            \"whered\": \"where'd\",\n            \"wheres\": \"where's\",\n            \"whereve\": \"where've\",\n            \"whod\": \"who'd\",\n            \"whod've\": \"who'd've\",\n            \"who'dve\": \"who'd've\",\n            \"wholl\": \"who'll\",\n        
    \"whos\": \"who's\",\n            \"whove\": \"who've\",\n            \"whyll\": \"why'll\",\n            \"whyre\": \"why're\",\n            \"whys\": \"why's\",\n            \"wont\": \"won't\",\n            \"wouldve\": \"would've\",\n            \"wouldnt\": \"wouldn't\",\n            \"wouldnt've\": \"wouldn't've\",\n            \"wouldn'tve\": \"wouldn't've\",\n            \"yall\": \"y'all\",\n            \"yall'll\": \"y'all'll\",\n            \"y'allll\": \"y'all'll\",\n            \"yall'd've\": \"y'all'd've\",\n            \"y'alld've\": \"y'all'd've\",\n            \"y'all'dve\": \"y'all'd've\",\n            \"youd\": \"you'd\",\n            \"youd've\": \"you'd've\",\n            \"you'dve\": \"you'd've\",\n            \"youll\": \"you'll\",\n            \"youre\": \"you're\",\n            \"youve\": \"you've\",\n        }\n\n        self.manualMap = {\n            \"none\": \"0\",\n            \"zero\": \"0\",\n            \"one\": \"1\",\n            \"two\": \"2\",\n            \"three\": \"3\",\n            \"four\": \"4\",\n            \"five\": \"5\",\n            \"six\": \"6\",\n            \"seven\": \"7\",\n            \"eight\": \"8\",\n            \"nine\": \"9\",\n            \"ten\": \"10\",\n        }\n\n        self.articles = [\"a\", \"an\", \"the\"]\n\n        self.periodStrip = re.compile(r\"(?!<=\\d)(\\.)(?!\\d)\")\n        self.commaStrip = re.compile(r\"(\\d)(\\,)(\\d)\")\n        self.punct = [\n            \";\",\n            r\"/\",\n            \"[\",\n            \"]\",\n            '\"',\n            \"{\",\n            \"}\",\n            \"(\",\n            \")\",\n            \"=\",\n            \"+\",\n            \"\\\\\",\n            \"_\",\n            \"-\",\n            \">\",\n            \"<\",\n            \"@\",\n            \"`\",\n            \",\",\n            \"?\",\n            \"!\",\n        ]\n        self.commaStrip = re.compile(r\"(\\d)(,)(\\d)\")\n        self.periodStrip = 
re.compile(r\"(?!<=\\d)(\\.)(?!\\d)\")\n\n    def process_punctuation(self, inText: str) -> str:\n        outText = inText\n\n        for p in self.punct:\n            if (p + \" \" in inText or \" \" + p in inText) or (\n                re.search(self.commaStrip, inText) is not None\n            ):\n                outText = outText.replace(p, \"\")\n            else:\n                outText = outText.replace(p, \" \")\n        outText = self.periodStrip.sub(\"\", outText, re.UNICODE)\n        return outText\n\n    def process_digit_article(self, inText: str) -> str:\n        outText = []\n        tempText = inText.lower().split()\n        for word in tempText:\n            word = self.manualMap.setdefault(word, word)\n            if word not in self.articles:\n                outText.append(word)\n        for wordId, word in enumerate(outText):\n            if word in self.contractions:\n                outText[wordId] = self.contractions[word]\n        outText = \" \".join(outText)\n        return outText\n\n    def process_answer(self, answer):\n        answer = answer.replace(\"\\n\", \" \")\n        answer = answer.replace(\"\\t\", \" \")\n        answer = answer.strip()\n        answer = self.process_punctuation(answer)\n        answer = self.process_digit_article(answer)\n        return answer\n\n    def process_line(self, prediction: str, gt_answers: List[str]) -> float:\n        gt_answers = [self.process_answer(x) for x in gt_answers]\n        prediction = self.process_answer(prediction)\n        matches = []\n        for current_idx, gtAnsDatum in enumerate(gt_answers):\n            otherGTAns = [\n                item\n                for ret_gt_idx, item in enumerate(gt_answers)\n                if ret_gt_idx != current_idx\n            ]\n            matchingAns = [item for item in otherGTAns if item == prediction]\n            acc = min(1, float(len(matchingAns)) / 3)\n            matches.append(acc)\n\n        return sum(matches) / 
len(matches)\n\n    def compute_score(\n        self, candidate_answer: str, ground_truth_answers: List[str]\n    ) -> float:\n        \"\"\"\n        Compute VQA score for a candidate answer against ground truth answers,\n        exactly matching the VQAEval scoring logic\n        \"\"\"\n        # Process candidate answer\n        candidate = self.process_answer(candidate_answer)\n\n        # Process ground truth answers\n        processed_gts = []\n        for gt in ground_truth_answers:\n            gt = gt.replace(\"\\n\", \" \")\n            gt = gt.replace(\"\\t\", \" \")\n            gt = gt.strip()\n            processed_gts.append(gt)\n\n        # If there are multiple different answers, apply additional processing\n        if len(set(processed_gts)) > 1:\n            candidate = self.process_punctuation(candidate)\n            candidate = self.process_digit_article(candidate)\n            processed_gts = [\n                self.process_punctuation(self.process_digit_article(gt))\n                for gt in processed_gts\n            ]\n\n        # Count matches\n        matching_answers = [1 for gt in processed_gts if gt == candidate]\n        score = min(1.0, float(len(matching_answers)) / 3.0)\n\n        return score\n"
  },
  {
    "path": "moondream/eval/waste_detection.py",
    "content": "import argparse\nfrom collections import defaultdict\nfrom typing import Dict, List, Tuple\n\nimport torch\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom datasets import load_dataset\n\nfrom ..torch.config import MoondreamConfig\nfrom ..torch.moondream import MoondreamModel\nfrom ..torch.weights import load_weights_into_model\n\n\nBox = Tuple[float, float, float, float]  # (x1, y1, x2, y2) – in proportion form\n\n\ndef iou(a: Box, b: Box) -> float:\n    \"\"\"Corner-format IoU. Returns 0 when either box has zero area.\"\"\"\n    x1, y1 = max(a[0], b[0]), max(a[1], b[1])\n    x2, y2 = min(a[2], b[2]), min(a[3], b[3])\n    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)\n\n    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter\n    return inter / union if union else 0.0\n\n\ndef match(gt: List[Box], pr: List[Box], iou_thr: float) -> Tuple[int, int, int]:\n    \"\"\"\n    Greedy one-to-one matching with no confidences.\n    Predictions are taken in the order produced by the model.\n    \"\"\"\n    tp = fp = 0\n    seen = [False] * len(gt)\n\n    for p in pr:\n        best, best_i = 0.0, -1\n        for i, g in enumerate(gt):\n            if seen[i]:\n                continue\n            iou_ = iou(p, g)\n            if iou_ > best:\n                best, best_i = iou_, i\n        if best >= iou_thr:\n            tp += 1\n            seen[best_i] = True\n        else:\n            fp += 1\n\n    fn = len(gt) - tp\n    return tp, fp, fn\n\n\nclass WasteDetection(torch.utils.data.Dataset):\n    def __init__(self, name: str = \"moondream/waste_detection\", split: str = \"test\"):\n        self.ds = load_dataset(name, split=split)\n\n    def __len__(self):\n        return len(self.ds)\n\n    def __getitem__(self, idx: int) -> Dict:\n        s = self.ds[idx]\n        img = (\n            s[\"image\"]\n            if isinstance(s[\"image\"], Image.Image)\n            else Image.fromarray(s[\"image\"])\n        )\n        W, H = 
float(s.get(\"width\", img.width)), float(s.get(\"height\", img.height))\n\n        lbl_to_boxes = defaultdict(list)\n        for (xc, yc, bw, bh), lbl in zip(s[\"boxes\"], s[\"labels\"]):\n            x1 = xc - bw / 2\n            y1 = yc - bh / 2\n            x2 = xc + bw / 2\n            y2 = yc + bh / 2\n            lbl_to_boxes[lbl].append((x1, y1, x2, y2))\n\n        return {\"image\": img, \"gt\": lbl_to_boxes, \"W\": W, \"H\": H}\n\n\ndef evaluate(\n    model: MoondreamModel,\n    iou_thr: float,\n    debug: bool,\n):\n    ds = WasteDetection(split=\"test\")\n    TP = FP = FN = 0\n\n    for s in tqdm(ds, disable=debug, desc=\"Waste\"):\n        img, gts = s[\"image\"], s[\"gt\"]\n        enc = model.encode_image(img)\n\n        for lbl, gt_boxes in gts.items():\n            preds: List[Box] = [\n                (\n                    o[\"x_min\"],\n                    o[\"y_min\"],\n                    o[\"x_max\"],\n                    o[\"y_max\"],\n                )\n                for o in model.detect(enc, lbl)[\"objects\"]\n            ]\n            tp, fp, fn = match(gt_boxes, preds, iou_thr)\n            TP += tp\n            FP += fp\n            FN += fn\n\n    prec = TP / (TP + FP) if TP + FP else 0.0\n    rec = TP / (TP + FN) if TP + FN else 0.0\n    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0\n    return dict(precision=prec, recall=rec, f1=f1, tp=TP, fp=FP, fn=FN)\n\n\ndef load_model(path: str, device: torch.device) -> MoondreamModel:\n    cfg = MoondreamConfig()\n    model = MoondreamModel(cfg)\n    load_weights_into_model(path, model)\n    model.compile()\n    model.to(device)\n    return model\n\n\ndef main():\n    p = argparse.ArgumentParser()\n    p.add_argument(\"--model\", required=True)\n    p.add_argument(\"--iou_thr\", type=float, default=0.5)\n    p.add_argument(\"--gpu\", type=int, default=0)\n    p.add_argument(\"--debug\", action=\"store_true\")\n    args = p.parse_args()\n\n    if torch.cuda.is_available():\n      
  torch.cuda.set_device(args.gpu)\n        device = torch.device(f\"cuda:{args.gpu}\")\n    elif torch.backends.mps.is_available():\n        device = torch.device(\"mps\")\n    else:\n        device = torch.device(\"cpu\")\n\n    model = load_model(args.model, device)\n    res = evaluate(model, args.iou_thr, args.debug)\n\n    print(f\"Precision: {res['precision']*100:.2f}%\")\n    print(f\"Recall: {res['recall']*100:.2f}%\")\n    print(f\"F1 Score:  {res['f1']*100:.2f}%\")\n    print(f\"TP: {res['tp']}  FP: {res['fp']}  FN: {res['fn']}\")\n\n\nif __name__ == \"__main__\":\n    \"\"\"\n    Eval to accompany finetune_region.py.\n    \"\"\"\n    main()\n"
  },
  {
    "path": "moondream/torch/config.py",
    "content": "from dataclasses import dataclass, field\nfrom typing import Dict, List, Optional\n\n\n@dataclass(frozen=True)\nclass TextMoeConfig:\n    num_experts: int = 64\n    start_layer: int = 4\n    experts_per_token: int = 8\n    expert_inner_dim: int = 1024\n\n\n@dataclass(frozen=True)\nclass TextConfig:\n    dim: int = 2048\n    ff_dim: int = 8192\n    n_layers: int = 24\n    vocab_size: int = 51200\n    max_context: int = 4096\n    n_heads: int = 32\n    n_kv_heads: int = 32\n    prefix_attn: int = 730\n    group_size: Optional[int] = None\n    moe: Optional[TextMoeConfig] = TextMoeConfig()\n\n\n@dataclass(frozen=True)\nclass VisionConfig:\n    enc_dim: int = 1152\n    enc_patch_size: int = 14\n    enc_n_layers: int = 27\n    enc_ff_dim: int = 4304\n    enc_n_heads: int = 16\n    proj_out_dim: int = 2048\n    crop_size: int = 378\n    in_channels: int = 3\n    max_crops: int = 12\n    overlap_margin: int = 4\n    proj_inner_dim: int = 8192\n\n\n@dataclass(frozen=True)\nclass RegionConfig:\n    dim: int = 2048\n    coord_feat_dim: int = 256\n    coord_out_dim: int = 1024\n    size_feat_dim: int = 512\n    size_out_dim: int = 2048\n    group_size: Optional[int] = None\n\n\n@dataclass(frozen=True)\nclass TokenizerConfig:\n    bos_id: int = 0\n    eos_id: int = 0\n    answer_id: int = 3\n    thinking_id: int = 4\n    coord_id: int = 5\n    size_id: int = 6\n    start_ground_points_id: int = 7\n    end_ground_id: int = 9\n    templates: Dict[str, Optional[Dict[str, List[int]]]] = field(\n        default_factory=lambda: {\n            \"caption\": {\n                \"short\": [1, 32708, 2, 12492, 3],\n                \"normal\": [1, 32708, 2, 6382, 3],\n                \"long\": [1, 32708, 2, 4059, 3],\n            },\n            \"query\": {\"prefix\": [1, 15381, 2], \"suffix\": [3]},\n            \"detect\": {\"prefix\": [1, 7235, 476, 2], \"suffix\": [3]},\n            \"point\": {\"prefix\": [1, 2581, 2], \"suffix\": [3]},\n        }\n    
)\n\n\n@dataclass(frozen=True)\nclass MoondreamConfig:\n    text: TextConfig = TextConfig()\n    vision: VisionConfig = VisionConfig()\n    region: RegionConfig = RegionConfig()\n    tokenizer: TokenizerConfig = TokenizerConfig()\n\n    @classmethod\n    def from_dict(cls, config_dict: dict):\n        text_config = TextConfig(**config_dict.get(\"text\", {}))\n        vision_config = VisionConfig(**config_dict.get(\"vision\", {}))\n        region_config = RegionConfig(**config_dict.get(\"region\", {}))\n        tokenizer_config = TokenizerConfig(**config_dict.get(\"tokenizer\", {}))\n        return cls(\n            text=text_config,\n            vision=vision_config,\n            region=region_config,\n            tokenizer=tokenizer_config,\n        )\n\n    def to_dict(self):\n        return {\n            \"text\": self.text.__dict__,\n            \"vision\": self.vision.__dict__,\n            \"region\": self.region.__dict__,\n            \"tokenizer\": self.tokenizer.__dict__,\n        }\n"
  },
  {
    "path": "moondream/torch/hf_moondream.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom transformers import PreTrainedModel, PretrainedConfig\nfrom typing import Union\n\nfrom .config import MoondreamConfig\nfrom .moondream import MoondreamModel\n\n# Files sometimes don't get loaded without these...\nfrom .image_crops import *\nfrom .vision import *\nfrom .text import *\nfrom .region import *\nfrom .utils import *\n\n\ndef extract_question(text):\n    prefix = \"<image>\\n\\nQuestion: \"\n    suffix = \"\\n\\nAnswer:\"\n\n    if text.startswith(prefix) and text.endswith(suffix):\n        return text[len(prefix) : -len(suffix)]\n    else:\n        return None\n\n\nclass HfConfig(PretrainedConfig):\n    _auto_class = \"AutoConfig\"\n    model_type = \"moondream3\"\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.config = {\"skills\": [\"query\", \"caption\", \"detect\", \"point\"]}\n\n\nclass HfMoondream(PreTrainedModel):\n    _auto_class = \"AutoModelForCausalLM\"\n    config_class = HfConfig\n\n    def __init__(self, config):\n        super().__init__(config)\n        self.model = MoondreamModel(\n            MoondreamConfig.from_dict(config.config), setup_caches=False\n        )\n        self._is_kv_cache_setup = False\n\n    def _setup_caches(self):\n        if not self._is_kv_cache_setup:\n            self.model._setup_caches()\n            self._is_kv_cache_setup = True\n\n    @property\n    def encode_image(self):\n        self._setup_caches()\n        return self.model.encode_image\n\n    @property\n    def query(self):\n        self._setup_caches()\n        return self.model.query\n\n    @property\n    def caption(self):\n        self._setup_caches()\n        return self.model.caption\n\n    @property\n    def detect(self):\n        self._setup_caches()\n        return self.model.detect\n\n    @property\n    def point(self):\n        self._setup_caches()\n        return self.model.point\n\n    @property\n    def detect_gaze(self):\n        
self._setup_caches()\n        return self.model.detect_gaze\n\n    def answer_question(\n        self,\n        image_embeds,\n        question,\n        tokenizer=None,\n        chat_history=\"\",\n        result_queue=None,\n        max_new_tokens=256,\n        **kwargs\n    ):\n        answer = self.query(image_embeds, question)[\"answer\"].strip()\n\n        if result_queue is not None:\n            result_queue.put(answer)\n        return answer\n\n    def batch_answer(self, images, prompts, tokenizer=None, **kwargs):\n        answers = []\n        for image, prompt in zip(images, prompts):\n            answers.append(self.query(image, prompt)[\"answer\"].strip())\n        return answers\n\n    def _unsupported_exception(self):\n        raise NotImplementedError(\n            \"This method is not supported in the latest version of moondream. \"\n            \"Consider upgrading to the updated API spec, or alternately pin \"\n            \"to 'revision=2024-08-26'.\"\n        )\n\n    def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):\n        \"\"\"\n        Function definition remains unchanged for backwards compatibility.\n        Be aware that tokenizer, max_new_tokens, and kwargs are ignored.\n        \"\"\"\n        prompt_extracted = extract_question(prompt)\n        if prompt_extracted is not None:\n            answer = self.model.query(\n                image=image_embeds, question=prompt_extracted, stream=False\n            )[\"answer\"]\n        else:\n            image_embeds = self.encode_image(image_embeds)\n            prompt_tokens = torch.tensor(\n                [self.model.tokenizer.encode(prompt).ids],\n                device=self.device,\n            )\n\n            def generator():\n                for token in self.model._generate_answer(\n                    prompt_tokens,\n                    image_embeds.kv_cache,\n                    image_embeds.pos,\n                    max_new_tokens,\n              
  ):\n                    yield token\n\n            answer = \"\".join(list(generator()))\n\n        return [answer]\n\n    def get_input_embeddings(self) -> nn.Embedding:\n        \"\"\"\n        Lazily wrap the raw parameter `self.model.text.wte` in a real\n        `nn.Embedding` layer so that HF mix-ins recognise it.  The wrapper\n        **shares** the weight tensor—no copy is made.\n        \"\"\"\n        if not hasattr(self, \"_input_embeddings\"):\n            self._input_embeddings = nn.Embedding.from_pretrained(\n                self.model.text.wte,  # tensor created in text.py\n                freeze=True,  # set to False if you need it trainable\n            )\n        return self._input_embeddings\n\n    def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None:\n        \"\"\"\n        Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the\n        embeddings and keeps everything tied to `self.model.text.wte`.\n        \"\"\"\n        # 1. point the low-level parameter to the new weight matrix\n        self.model.text.wte = value.weight\n        # 2. keep a reference for get_input_embeddings()\n        self._input_embeddings = value\n\n    def input_embeds(\n        self,\n        input_ids: Union[torch.LongTensor, list, tuple],\n        *,\n        device: torch.device | None = None\n    ) -> torch.FloatTensor:\n        \"\"\"\n        Back-compat wrapper that turns token IDs into embeddings.\n\n        Example:\n            ids = torch.tensor([[1, 2, 3]])\n            embeds = model.input_embeds(ids)      # (1, 3, hidden_dim)\n        \"\"\"\n        if not torch.is_tensor(input_ids):\n            input_ids = torch.as_tensor(input_ids)\n        if device is not None:\n            input_ids = input_ids.to(device)\n\n        return self.get_input_embeddings()(input_ids)\n"
  },
  {
    "path": "moondream/torch/hf_release.py",
    "content": "import torch\nimport argparse\n\nfrom .weights import load_weights_into_model\nfrom .hf_moondream import HfConfig, HfMoondream\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model-name\", type=str, default=\"vikhyatk/moondream-next\")\n    parser.add_argument(\"--ckpt\", type=str, required=True)\n    args = parser.parse_args()\n\n    config = HfConfig()\n    model = HfMoondream(config)\n    load_weights_into_model(args.ckpt, model.model)\n\n    model.push_to_hub(args.model_name, config=config)\n"
  },
  {
    "path": "moondream/torch/image_crops.py",
    "content": "import math\nimport numpy as np\nimport torch\n\nfrom typing import TypedDict\n\ntry:\n    import pyvips\n\n    HAS_VIPS = True\nexcept:\n    from PIL import Image\n\n    HAS_VIPS = False\n\n\ndef select_tiling(\n    height: int, width: int, crop_size: int, max_crops: int\n) -> tuple[int, int]:\n    \"\"\"\n    Determine the optimal number of tiles to cover an image with overlapping crops.\n    \"\"\"\n    if height <= crop_size or width <= crop_size:\n        return (1, 1)\n\n    # Minimum required tiles in each dimension\n    min_h = math.ceil(height / crop_size)\n    min_w = math.ceil(width / crop_size)\n\n    # If minimum required tiles exceed max_crops, return proportional distribution\n    if min_h * min_w > max_crops:\n        ratio = math.sqrt(max_crops / (min_h * min_w))\n        return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))\n\n    # Perfect aspect-ratio tiles that satisfy max_crops\n    h_tiles = math.floor(math.sqrt(max_crops * height / width))\n    w_tiles = math.floor(math.sqrt(max_crops * width / height))\n\n    # Ensure we meet minimum tile requirements\n    h_tiles = max(h_tiles, min_h)\n    w_tiles = max(w_tiles, min_w)\n\n    # If we exceeded max_crops, scale down the larger dimension\n    if h_tiles * w_tiles > max_crops:\n        if w_tiles > h_tiles:\n            w_tiles = math.floor(max_crops / h_tiles)\n        else:\n            h_tiles = math.floor(max_crops / w_tiles)\n\n    return (max(1, h_tiles), max(1, w_tiles))\n\n\nclass OverlapCropOutput(TypedDict):\n    crops: np.ndarray\n    tiling: tuple[int, int]\n\n\ndef overlap_crop_image(\n    image: np.ndarray,\n    overlap_margin: int,\n    max_crops: int,\n    base_size: tuple[int, int] = (378, 378),\n    patch_size: int = 14,\n) -> OverlapCropOutput:\n    \"\"\"\n    Process an image using an overlap-and-resize cropping strategy with margin handling.\n\n    This function takes an input image and creates multiple overlapping crops with\n    
consistent margins. It produces:\n    1. A single global crop resized to base_size\n    2. Multiple overlapping local crops that maintain high resolution details\n    3. The (height, width) tiling that records how the local crops are laid out\n\n    The overlap strategy ensures:\n    - Smooth transitions between adjacent crops\n    - No loss of information at crop boundaries\n    - Proper handling of features that cross crop boundaries\n    - Consistent patch indexing across the full image\n\n    Args:\n        image (np.ndarray): Input image as numpy array with shape (H,W,C)\n        base_size (tuple[int,int]): Target size for crops, default (378,378)\n        patch_size (int): Size of patches in pixels, default 14\n        overlap_margin (int): Margin size in patch units (required, no default)\n        max_crops (int): Maximum number of crops allowed (required, no default)\n\n    Returns:\n        OverlapCropOutput: Dictionary containing:\n            - crops: A numpy array containing the global crop of the full image (index 0)\n                followed by the overlapping cropped regions (indices 1+)\n            - tiling: Tuple of (height,width) tile counts\n    \"\"\"\n    original_h, original_w = image.shape[:2]\n\n    # Convert margin from patch units to pixels\n    margin_pixels = patch_size * overlap_margin\n    total_margin_pixels = margin_pixels * 2  # Both sides\n\n    # Calculate crop parameters\n    crop_patches = base_size[0] // patch_size  # patches per crop dimension\n    crop_window_patches = crop_patches - (2 * overlap_margin)  # usable patches\n    crop_window_size = crop_window_patches * patch_size  # usable size in pixels\n\n    # Determine tiling\n    tiling = select_tiling(\n        original_h - total_margin_pixels,\n        original_w - total_margin_pixels,\n        crop_window_size,\n        max_crops,\n    )\n\n    # Pre-allocate crops.\n    n_crops = tiling[0] * tiling[1] + 1  # 1 = global crop\n    crops = np.zeros(\n        (n_crops, base_size[0], base_size[1], 
image.shape[2]), dtype=np.uint8\n    )\n\n    # Resize image to fit tiling\n    target_size = (\n        tiling[0] * crop_window_size + total_margin_pixels,\n        tiling[1] * crop_window_size + total_margin_pixels,\n    )\n\n    if HAS_VIPS:\n        # Convert to vips for resizing\n        vips_image = pyvips.Image.new_from_array(image)\n        scale_x = target_size[1] / image.shape[1]\n        scale_y = target_size[0] / image.shape[0]\n        resized = vips_image.resize(scale_x, vscale=scale_y)\n        image = resized.numpy()\n\n        # Create global crop\n        scale_x = base_size[1] / vips_image.width\n        scale_y = base_size[0] / vips_image.height\n        global_vips = vips_image.resize(scale_x, vscale=scale_y)\n        crops[0] = global_vips.numpy()\n    else:\n        # Fallback to PIL\n        pil_img = Image.fromarray(image)\n        resized = pil_img.resize(\n            (int(target_size[1]), int(target_size[0])),\n            resample=Image.Resampling.LANCZOS,\n        )\n        image = np.asarray(resized)\n\n        # Create global crop\n        global_pil = pil_img.resize(\n            (int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS\n        )\n        crops[0] = np.asarray(global_pil)\n\n    for i in range(tiling[0]):\n        for j in range(tiling[1]):\n            # Calculate crop coordinates\n            y0 = i * crop_window_size\n            x0 = j * crop_window_size\n\n            # Extract crop with padding if needed\n            y_end = min(y0 + base_size[0], image.shape[0])\n            x_end = min(x0 + base_size[1], image.shape[1])\n\n            crop_region = image[y0:y_end, x0:x_end]\n            crops[\n                1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]\n            ] = crop_region\n\n    return {\"crops\": crops, \"tiling\": tiling}\n\n\ndef reconstruct_from_crops(\n    crops: torch.Tensor,\n    tiling: tuple[int, int],\n    overlap_margin: int,\n    patch_size: 
int = 14,\n) -> torch.Tensor:\n    \"\"\"\n    Reconstruct the original image from overlapping crops into a single seamless image.\n\n    Takes a list of overlapping image crops along with their positional metadata and\n    reconstructs them into a single coherent image by carefully stitching together\n    non-overlapping regions. Crops must be PyTorch tensors: the output is allocated\n    with torch.zeros using the device and dtype of the first crop.\n\n    Args:\n        crops: Image crops as PyTorch tensors, each with shape (H,W,C), given in\n            row-major tile order (index i maps to row i // width, column i % width)\n        tiling: Tuple of (height,width) indicating crop grid layout\n        patch_size: Size in pixels of each patch, default 14\n        overlap_margin: Number of overlapping patches on each edge (required, no\n            default)\n\n    Returns:\n        Reconstructed image as a PyTorch tensor with shape (H,W,C), where\n        H = (crop_height - 2 * margin) * tiling[0] + 2 * margin, W is computed the\n        same way from crop_width and tiling[1], and margin = overlap_margin * patch_size\n    \"\"\"\n    tiling_h, tiling_w = tiling\n    crop_height, crop_width = crops[0].shape[:2]\n    margin_pixels = overlap_margin * patch_size\n\n    # Calculate output size (only adding margins once)\n    output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels\n    output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels\n\n    reconstructed = torch.zeros(\n        (output_h, output_w, crops[0].shape[2]),\n        device=crops[0].device,\n        dtype=crops[0].dtype,\n    )\n\n    for i, crop in enumerate(crops):\n        tile_y = i // tiling_w\n        tile_x = i % tiling_w\n\n        # For each tile, determine which part to keep\n        # Keep left margin only for first column\n        x_start = 0 if tile_x == 0 else margin_pixels\n        # Keep right margin only for last column\n        x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels\n        # Keep top margin only for first row\n        y_start = 0 if tile_y == 0 else margin_pixels\n        # Keep bottom margin only for last row\n        y_end = crop_height if tile_y == tiling_h - 
1 else crop_height - margin_pixels\n\n        # Calculate where this piece belongs in the output\n        out_x = tile_x * (crop_width - 2 * margin_pixels)\n        out_y = tile_y * (crop_height - 2 * margin_pixels)\n\n        # Place the piece\n        reconstructed[\n            out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end\n        ] = crop[y_start:y_end, x_start:x_end]\n\n    return reconstructed\n"
  },
  {
    "path": "moondream/torch/layers.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom dataclasses import dataclass\nfrom typing import Literal, Optional\n\ntry:\n    from torchao import quantize_\n    from torchao.quantization import int4_weight_only\nexcept ImportError:\n\n    def quantize_(model, quant_mode):\n        raise ImportError(\n            \"torchao is not installed. Please install it with `pip install torchao`.\"\n        )\n\n    def int4_weight_only(group_size):\n        raise ImportError(\n            \"torchao is not installed. Please install it with `pip install torchao`.\"\n        )\n\n\ndef gelu_approx(x):\n    return F.gelu(x, approximate=\"tanh\")\n\n\n@dataclass\nclass LinearWeights:\n    weight: torch.Tensor\n    bias: torch.Tensor\n\n\ndef linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:\n    return F.linear(x, w.weight, w.bias)\n\n\ndef dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):\n    _step = W_q.shape[0]\n    W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)\n    W_r[:_step] = (W_q & 0b11110000) >> 4\n    W_r[_step:] = W_q & 0b00001111\n    W_r.sub_(zero).mul_(scale)\n    return W_r.reshape(orig_shape)\n\n\nclass QuantizedLinear(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        out_features: int,\n        dtype: torch.dtype,\n    ):\n        # TODO: Take group_size as an input instead of hardcoding it here.\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.weight = nn.ParameterDict(\n            {\n                \"packed\": nn.Parameter(\n                    torch.empty(\n                        out_features * in_features // (128 * 2), 128, dtype=torch.uint8\n                    ),\n                    requires_grad=False,\n                ),\n                \"scale\": nn.Parameter(\n                    torch.empty(out_features * in_features // 128, 1),\n       
             requires_grad=False,\n                ),\n                \"zero_point\": nn.Parameter(\n                    torch.empty(out_features * in_features // 128, 1),\n                    requires_grad=False,\n                ),\n            }\n        )\n        self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)\n        self.unpacked = False\n\n    def unpack(self):\n        if self.unpacked:\n            return\n\n        self.weight = nn.Parameter(\n            dequantize_tensor(\n                self.weight[\"packed\"],\n                self.weight[\"scale\"],\n                self.weight[\"zero_point\"],\n                (self.out_features, self.in_features),\n                torch.bfloat16,\n            )\n        )\n        with torch.device(\"meta\"):\n            self.linear = nn.Linear(\n                self.in_features, self.out_features, dtype=torch.bfloat16\n            )\n        self.linear.weight = self.weight\n        self.linear.bias = nn.Parameter(\n            self.bias.to(torch.bfloat16), requires_grad=False\n        )\n\n        del self.weight, self.bias\n        quantize_(self, int4_weight_only(group_size=128))\n        self.unpacked = True\n        torch.cuda.empty_cache()\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if not self.unpacked:\n            self.unpack()\n        return self.linear(x)\n\n\n@dataclass\nclass LayerNormWeights:\n    weight: torch.Tensor\n    bias: torch.Tensor\n\n\ndef layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:\n    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)\n\n\n@dataclass\nclass MLPWeights:\n    fc1: LinearWeights\n    fc2: LinearWeights\n    act: Literal[\"gelu_approx\"] = \"gelu_approx\"\n\n\ndef mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:\n    x0 = w.fc1(x)\n    if lora is not None:\n        x1 = F.linear(F.linear(x, lora[\"fc1\"][\"A\"]), lora[\"fc1\"][\"B\"])\n        x = x0 + x1\n    
else:\n        x = x0\n\n    x = gelu_approx(x)\n\n    x0 = w.fc2(x)\n    if lora is not None:\n        x1 = F.linear(F.linear(x, lora[\"fc2\"][\"A\"]), lora[\"fc2\"][\"B\"])\n        x = x0 + x1\n    else:\n        x = x0\n\n    return x\n\n\ndef moe_mlp(\n    x: torch.Tensor, mlp_module: nn.Module, experts_per_token: int\n) -> torch.Tensor:\n    B, T, C = x.shape\n    x = x.reshape(-1, C)\n\n    # Router computation\n    router_logits = mlp_module.router(x)\n    topk_logits, topk_idxs = torch.topk(router_logits, experts_per_token, dim=-1)\n    topk_weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32).to(x.dtype)\n    num_tokens, top_k = topk_idxs.shape\n\n    if T == 1:\n        w1_weight = mlp_module.fc1.weight\n        w2_weight = mlp_module.fc2.weight\n\n        # Flatten to process all token-expert pairs at once\n        flat_idxs = topk_idxs.view(-1)  # [T*A]\n        flat_weights = topk_weights.view(-1)  # [T*A]\n\n        # Select expert weights\n        w1_selected = w1_weight[flat_idxs]  # [T*A, H, D]\n        w2_selected = w2_weight[flat_idxs]  # [T*A, D, H]\n\n        # Expand input for all token-expert pairs\n        x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C)  # [T*A, D]\n\n        # First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H]\n        x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze(\n            -1\n        )  # [T*A, H]\n        x1, g = x1_full.chunk(2, dim=-1)\n        x1 = F.gelu(x1) * (g + 1)\n\n        # Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D]\n        expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1)  # [T*A, D]\n\n        # Apply weights and reshape\n        weighted_outs = expert_outs * flat_weights.unsqueeze(-1)  # [T*A, D]\n        weighted_outs = weighted_outs.view(num_tokens, top_k, C)  # [T, A, D]\n\n        # Sum over experts\n        mlp_out = weighted_outs.sum(dim=1)  # [T, D]\n        mlp_out = mlp_out.view(B, T, C)\n\n      
  return mlp_out\n    else:\n        out = x.new_zeros(x.size())\n\n        for expert_id in range(mlp_module.fc1.weight.shape[0]):\n            token_pos, which_k = (topk_idxs == expert_id).nonzero(as_tuple=True)\n            if token_pos.numel() == 0:\n                continue\n\n            x_tok = x.index_select(0, token_pos)\n            gate_tok = topk_weights[token_pos, which_k]\n\n            h_full = F.linear(x_tok, mlp_module.fc1.weight[expert_id])\n            h, g = h_full.chunk(2, dim=-1)\n            h = F.gelu(h) * (g + 1)\n            y = F.linear(h, mlp_module.fc2.weight[expert_id])\n\n            y.mul_(gate_tok.unsqueeze(-1))\n            out.index_add_(0, token_pos, y)\n\n        return out.view(B, T, C)\n\n\n@dataclass\nclass AttentionWeights:\n    qkv: LinearWeights\n    proj: LinearWeights\n\n\ndef attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:\n    bsz, q_len, d_model = x.shape\n    head_dim = d_model // n_heads\n\n    q, k, v = [\n        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)\n        for t in linear(x, w.qkv).chunk(3, dim=-1)\n    ]\n    out = F.scaled_dot_product_attention(q, k, v)\n    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)\n    out = linear(out, w.proj)\n    return out\n"
  },
  {
    "path": "moondream/torch/lora.py",
    "content": "import functools\nimport os\nimport shutil\nimport torch\n\nfrom pathlib import Path\nfrom urllib.request import Request, urlopen\nfrom typing import Optional\n\n\ndef variant_cache_dir():\n    hf_hub_cache = os.environ.get(\"HF_HUB_CACHE\")\n    if hf_hub_cache is not None:\n        return Path(hf_hub_cache) / \"md_variants\"\n\n    hf_home = os.environ.get(\"HF_HOME\")\n    if hf_home is not None:\n        return Path(hf_home) / \"hub\" / \"md_variants\"\n\n    return Path(\"~/.cache/huggingface/hub\").expanduser() / \"md_variants\"\n\n\ndef cached_variant_path(variant_id: str):\n    variant, *rest = variant_id.split(\"/\", 1)\n    step = rest[0] if rest else \"final\"\n\n    cache_dir = variant_cache_dir() / variant\n    os.makedirs(cache_dir, exist_ok=True)\n    dest = cache_dir / f\"{step}.pt\"\n    if dest.exists():\n        return dest\n\n    md_endpoint = os.getenv(\"MOONDREAM_ENDPOINT\", \"https://api.moondream.ai\")\n\n    headers = {\"User-Agent\": \"moondream-torch\"}\n    api_key = os.getenv(\"MOONDREAM_API_KEY\")\n    if api_key is not None:\n        headers[\"X-Moondream-Auth\"] = api_key\n\n    req = Request(f\"{md_endpoint}/v1/variants/{variant_id}/download\", headers=headers)\n    with urlopen(req) as r, open(dest, \"wb\") as f:\n        shutil.copyfileobj(r, f)\n    return dest\n\n\ndef nest(flat):\n    tree = {}\n    for k, v in flat.items():\n        parts = k.split(\".\")\n        d = tree\n        for p in parts[:-1]:\n            d = d.setdefault(p, {})\n        d[parts[-1]] = v\n    return tree\n\n\n@functools.lru_cache(maxsize=5)\ndef variant_state_dict(variant_id: Optional[str] = None, device: str = \"cpu\"):\n    if variant_id is None:\n        return None\n\n    state_dict = torch.load(\n        cached_variant_path(variant_id), map_location=device, weights_only=True\n    )\n\n    # TODO: Move these into the training code that saves checkpoints...\n    rename_rules = [\n        (\"text_model.transformer.h\", 
\"text.blocks\"),\n        (\".mixer\", \".attn\"),\n        (\".out_proj\", \".proj\"),\n        (\".Wqkv\", \".qkv\"),\n        (\".parametrizations.weight.0\", \"\"),\n    ]\n    new_state_dict = {}\n    for key, tensor in state_dict.items():\n        new_key = key\n        for old, new in rename_rules:\n            if old in new_key:\n                new_key = new_key.replace(old, new)\n        new_state_dict[new_key] = tensor\n\n    return nest(new_state_dict)\n"
  },
  {
    "path": "moondream/torch/moondream.py",
    "content": "import torch\nimport torch.nn as nn\nimport random\n\nfrom typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List\nfrom PIL import Image\nfrom dataclasses import dataclass\nfrom tokenizers import Tokenizer\nfrom torch.nn.attention.flex_attention import create_block_mask\n\nfrom .config import MoondreamConfig\nfrom .image_crops import reconstruct_from_crops\nfrom .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model\nfrom .text import build_text_model, text_encoder, lm_head, text_decoder\nfrom .region import (\n    decode_coordinate,\n    encode_coordinate,\n    decode_size,\n    encode_size,\n    encode_spatial_refs,\n    SpatialRefs,\n)\nfrom .layers import QuantizedLinear\nfrom .lora import variant_state_dict\nfrom .utils import remove_outlier_points\n\nImageEncodingSettings = TypedDict(\n    \"ImageEncodingSettings\",\n    {\"variant\": str},\n    total=False,\n)\n\nTextSamplingSettings = TypedDict(\n    \"TextSamplingSettings\",\n    {\n        \"max_tokens\": int,\n        \"temperature\": float,\n        \"top_p\": float,\n        \"variant\": str,\n    },\n    total=False,\n)\n\nObjectSamplingSettings = TypedDict(\n    \"ObjectSamplingSettings\",\n    {\"max_objects\": int, \"variant\": str},\n    total=False,\n)\n\n\nDEFAULT_MAX_TOKENS = 768\nDEFAULT_TEMPERATURE = 0.5\nDEFAULT_TOP_P = 0.9\nDEFAULT_MAX_OBJECTS = 150\n\n\n@dataclass(frozen=True)\nclass EncodedImage:\n    pos: int\n    caches: List[Tuple[torch.Tensor, torch.Tensor]]\n\n\nclass KVCache(nn.Module):\n\n    def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):\n        super().__init__()\n        cache_shape = (1, n_kv_heads, max_context, dim // n_heads)\n        self.register_buffer(\n            \"k_cache\", torch.zeros(*cache_shape, device=device, dtype=dtype)\n        )\n        self.register_buffer(\n            \"v_cache\", torch.zeros(*cache_shape, device=device, dtype=dtype)\n        )\n\n    def 
update(self, pos_ids, k, v):\n        kout, vout = self.k_cache, self.v_cache\n        kout[:, :, pos_ids, :] = k\n        vout[:, :, pos_ids, :] = v\n        return kout, vout\n\n\ndef causal_mask(b, h, q_idx, kv_idx):\n    return q_idx >= kv_idx\n\n\ndef get_mask_mod(mask_mod, offset):\n    def _mask_mod(b, h, q, kv):\n        return mask_mod(b, h, q + offset, kv)\n\n    return _mask_mod\n\n\nclass MoondreamModel(nn.Module):\n\n    def __init__(\n        self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True\n    ):\n        super().__init__()\n        self.config = config\n\n        self.tokenizer = Tokenizer.from_pretrained(\"moondream/starmie-v1\")\n        self.vision = build_vision_model(config.vision, dtype)\n        self.text = build_text_model(config.text, dtype)\n\n        # Region Model\n        linear_cls = (\n            QuantizedLinear if config.region.group_size is not None else nn.Linear\n        )\n        self.region = nn.ModuleDict(\n            {\n                \"coord_encoder\": linear_cls(\n                    config.region.coord_feat_dim, config.region.dim, dtype=dtype\n                ),\n                \"coord_decoder\": linear_cls(\n                    config.region.dim, config.region.coord_out_dim, dtype=dtype\n                ),\n                \"size_encoder\": linear_cls(\n                    config.region.size_feat_dim, config.region.dim, dtype=dtype\n                ),\n                \"size_decoder\": linear_cls(\n                    config.region.dim, config.region.size_out_dim, dtype=dtype\n                ),\n            }\n        )\n        self.region.coord_features = nn.Parameter(\n            torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T\n        )\n        self.region.size_features = nn.Parameter(\n            torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T\n        )\n\n        attn_mask = torch.tril(\n            torch.ones(\n                1, 1, 
config.text.max_context, config.text.max_context, dtype=torch.bool\n            )\n        )\n        patch_w = config.vision.crop_size // config.vision.enc_patch_size\n        prefix_attn_len = 1 + patch_w**2\n        attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1\n        self.register_buffer(\"attn_mask\", attn_mask, persistent=False)\n\n        self.use_flex_decoding = True\n        self._causal_block_mask = None\n        self._point_gen_indices = None\n\n        # Initialize KV caches.\n        if setup_caches:\n            self._setup_caches()\n\n    @property\n    def causal_block_mask(self):\n        # The things we do to deal with ZeroGPU...\n        if self._causal_block_mask is None:\n            self._causal_block_mask = create_block_mask(\n                causal_mask,\n                B=None,\n                H=None,\n                Q_LEN=self.config.text.max_context,\n                KV_LEN=self.config.text.max_context,\n            )\n        return self._causal_block_mask\n\n    @property\n    def point_gen_indices(self):\n        if self._point_gen_indices is None:\n            self._point_gen_indices = torch.tensor(\n                [self.config.tokenizer.coord_id, self.config.tokenizer.eos_id],\n                device=self.device,\n            )\n        return self._point_gen_indices\n\n    def _setup_caches(self):\n        c = self.config.text\n        for b in self.text.blocks:\n            b.kv_cache = KVCache(\n                c.n_heads,\n                c.n_kv_heads,\n                c.max_context,\n                c.dim,\n                device=self.device,\n                dtype=self.vision.pos_emb.dtype,\n            )\n\n    @property\n    def device(self):\n        return self.vision.pos_emb.device\n\n    def _vis_enc(self, x: torch.Tensor):\n        return vision_encoder(x, self.vision, self.config.vision)\n\n    def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):\n        return vision_projection(g, r, self.vision, 
self.config.vision)\n\n    def _prefill(\n        self,\n        x: torch.Tensor,\n        attn_mask: torch.Tensor,\n        pos_ids: torch.Tensor,\n        lora: Optional[torch.Tensor],\n    ):\n        return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)\n\n    def _decode_one_tok(\n        self,\n        x: torch.Tensor,\n        attn_mask: torch.Tensor,\n        pos_ids: torch.Tensor,\n        lora: Optional[torch.Tensor],\n        lm_head_indices: Optional[torch.Tensor] = None,\n    ):\n        if self.use_flex_decoding:\n            torch._assert(pos_ids.shape[-1] == 1, \"Invalid position ID shape\")\n            block_index = pos_ids // self.causal_block_mask.BLOCK_SIZE[0]\n            mask = self.causal_block_mask[:, :, block_index]\n            mask.seq_lengths = (1, mask.seq_lengths[1])\n            mask.mask_mod = get_mask_mod(self.causal_block_mask.mask_mod, pos_ids[0])\n        else:\n            mask = None\n\n        hidden = text_decoder(\n            x,\n            self.text,\n            attn_mask,\n            pos_ids,\n            self.config.text,\n            lora=lora,\n            flex_block_mask_slice=mask,\n        )\n        logits = lm_head(hidden, self.text, indices=lm_head_indices)\n        return logits, hidden\n\n    def compile(self):\n        for module in self.modules():\n            if isinstance(module, QuantizedLinear):\n                module.unpack()\n\n        # Initialize lazy properties to avoid first-call overhead\n        self.causal_block_mask\n        self.point_gen_indices\n\n        # TODO: vision_projection and _prefill is not being compiled\n        self._vis_enc = torch.compile(self._vis_enc, fullgraph=True)\n        self._decode_one_tok = torch.compile(\n            self._decode_one_tok, fullgraph=True, mode=\"reduce-overhead\"\n        )\n\n        # Warm up compiled methods with dummy forward passes\n        device = self.device\n        dtype = self.vision.pos_emb.dtype\n        with 
torch.no_grad():\n            # Warmup vision encoder\n            dummy_crops = torch.randn(1, 3, 378, 378, device=device, dtype=dtype)\n            self._vis_enc(dummy_crops)\n\n            # Warmup _decode_one_tok (both normal and point generation modes)\n            dummy_emb = torch.randn(\n                1, 1, self.config.text.dim, device=device, dtype=dtype\n            )\n            dummy_mask = torch.ones(\n                1, 1, self.config.text.max_context, device=device, dtype=torch.bool\n            )\n            dummy_pos_ids = torch.tensor([100], device=device, dtype=torch.long)\n            self._decode_one_tok(dummy_emb, dummy_mask, dummy_pos_ids, None)\n            self._decode_one_tok(\n                dummy_emb,\n                dummy_mask,\n                dummy_pos_ids,\n                None,\n                lm_head_indices=self.point_gen_indices,\n            )\n\n    def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor:\n        all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device)\n\n        torch._dynamo.mark_dynamic(all_crops, 0)\n\n        outputs = self._vis_enc(all_crops)\n\n        global_features = outputs[0]\n        local_features = outputs[1:].view(\n            -1,\n            self.config.vision.enc_n_layers,\n            self.config.vision.enc_n_layers,\n            self.config.vision.enc_dim,\n        )\n\n        reconstructed = reconstruct_from_crops(\n            local_features,\n            tiling,\n            patch_size=1,\n            overlap_margin=self.config.vision.overlap_margin,\n        )\n\n        return self._vis_proj(global_features, reconstructed)\n\n    def encode_image(\n        self,\n        image: Union[Image.Image, EncodedImage],\n        settings: Optional[ImageEncodingSettings] = None,\n    ) -> EncodedImage:\n        if isinstance(image, EncodedImage):\n            return image\n        elif not isinstance(image, Image.Image):\n            raise 
ValueError(\"image must be a PIL Image or EncodedImage\")\n\n        lora = (\n            variant_state_dict(settings[\"variant\"], device=self.device)\n            if settings is not None and \"variant\" in settings\n            else None\n        )\n\n        # Run through text model in addition to the vision encoder, to minimize\n        # re-computation if multiple queries are performed on this image.\n        with torch.inference_mode():\n            img_emb = self._run_vision_encoder(image)\n            bos_emb = text_encoder(\n                torch.tensor([[self.config.tokenizer.bos_id]], device=self.device),\n                self.text,\n            )\n            inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)\n            mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]\n            pos_ids = torch.arange(\n                inputs_embeds.size(1), dtype=torch.long, device=self.device\n            )\n            self._prefill(inputs_embeds, mask, pos_ids, lora)\n\n        return EncodedImage(\n            pos=inputs_embeds.size(1),\n            caches=[\n                (\n                    b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),\n                    b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),\n                )\n                for b in self.text.blocks\n            ],\n        )\n\n    def _apply_top_p(self, probs: torch.Tensor, top_p: float):\n        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)\n        probs_sum = torch.cumsum(probs_sort, dim=-1)\n        mask = probs_sum - probs_sort > top_p\n        probs_sort[mask] = 0.0\n        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))\n        next_probs = torch.zeros_like(probs)\n        next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort)\n        return next_probs\n\n    def _prefill_prompt(\n        self,\n        prompt_tokens: torch.Tensor,\n        pos: int,\n        temperature: float,\n        top_p: 
float,\n        spatial_refs: Optional[SpatialRefs] = None,\n        attn_mask: Optional[torch.Tensor] = None,\n        lora: Optional[dict] = None,\n    ):\n        with torch.inference_mode():\n            prompt_emb = text_encoder(prompt_tokens, self.text)\n\n            if spatial_refs:\n                encoded_refs = encode_spatial_refs(spatial_refs, self.region)\n                prompt_emb[prompt_tokens == self.config.tokenizer.coord_id] = (\n                    encoded_refs[\"coords\"]\n                )\n                if encoded_refs[\"sizes\"] is not None:\n                    prompt_emb[prompt_tokens == self.config.tokenizer.size_id] = (\n                        encoded_refs[\"sizes\"]\n                    )\n\n            torch._dynamo.mark_dynamic(prompt_emb, 1)\n\n            if attn_mask is None:\n                attn_mask = self.attn_mask\n\n            mask = attn_mask[:, :, pos : pos + prompt_emb.size(1), :]\n            pos_ids = torch.arange(\n                pos, pos + prompt_emb.size(1), dtype=torch.long, device=self.device\n            )\n            hidden_BC = self._prefill(prompt_emb, mask, pos_ids, lora)\n            logits_BV = lm_head(hidden_BC, self.text)\n\n            if temperature == 0:\n                next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1)\n            else:\n                probs = torch.softmax(logits_BV / temperature, dim=-1)\n                probs = self._apply_top_p(probs, top_p)\n                next_token = torch.multinomial(probs, num_samples=1)\n\n        pos = pos + prompt_emb.size(1)\n        return logits_BV, hidden_BC, next_token, pos\n\n    def _generate_reasoning(\n        self,\n        prompt_tokens,\n        pos,\n        settings: Optional[TextSamplingSettings] = None,\n        spatial_refs: Optional[SpatialRefs] = None,\n        attn_mask: Optional[torch.Tensor] = None,\n    ) -> Tuple[int, str, List[dict]]:\n        max_tokens = (\n            settings.get(\"max_tokens\", 
DEFAULT_MAX_TOKENS)\n            if settings\n            else DEFAULT_MAX_TOKENS\n        )\n        temperature = (\n            settings.get(\"temperature\", DEFAULT_TEMPERATURE)\n            if settings\n            else DEFAULT_TEMPERATURE\n        )\n        lora = (\n            variant_state_dict(settings[\"variant\"], device=self.device)\n            if settings is not None and \"variant\" in settings\n            else None\n        )\n\n        top_p = settings.get(\"top_p\", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P\n        eos_id = self.config.tokenizer.answer_id\n\n        _, last_hidden_BC, next_token, pos = self._prefill_prompt(\n            prompt_tokens,\n            pos,\n            temperature,\n            top_p,\n            spatial_refs,\n            attn_mask=attn_mask,\n            lora=lora,\n        )\n\n        text_token_chunks = [[]]\n        grounding_chunks = [[]]\n\n        mask = torch.zeros(\n            1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool\n        )\n        mask[:, :, :pos] = 1\n        pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)\n        generated_tokens = 0\n\n        while (\n            next_token_id := next_token.item()\n        ) != eos_id and generated_tokens < max_tokens:\n            if (\n                next_token_id == self.config.tokenizer.start_ground_points_id\n                or next_token_id == self.config.tokenizer.end_ground_id\n            ):\n                text_token_chunks.append([])\n                grounding_chunks.append([])\n\n            text_token_chunks[-1].append(next_token_id)\n\n            with torch.inference_mode():\n                if next_token_id == self.config.tokenizer.coord_id:\n                    coord_logits = decode_coordinate(last_hidden_BC, self.region)\n                    coord = torch.argmax(coord_logits, dim=-1) / coord_logits.size(-1)\n                    grounding_chunks[-1].append(coord.item())\n\n            
        next_emb = encode_coordinate(\n                        coord.to(dtype=coord_logits.dtype), self.region\n                    ).unsqueeze(0)\n                else:\n                    next_emb = text_encoder(next_token, self.text)\n\n                mask[:, :, pos], pos_ids[0] = 1, pos\n\n                logits_BV, last_hidden_BC = self._decode_one_tok(\n                    next_emb, mask, pos_ids, lora\n                )\n                logits_BV[:, self.config.tokenizer.eos_id] = float(\"-inf\")\n                logits_BV[:, self.config.tokenizer.size_id] = float(\"-inf\")\n\n                pos += 1\n\n                if temperature == 0:\n                    next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1)  # (1, 1)\n                else:\n                    probs = torch.softmax(logits_BV / temperature, dim=-1)  # (1, V)\n                    probs = self._apply_top_p(probs, top_p)\n                    next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)\n\n                generated_tokens += 1\n\n        text_chunks = [\n            self.tokenizer.decode(chunk_tokens) for chunk_tokens in text_token_chunks\n        ]\n        text = \"\".join(text_chunks)\n\n        start_idx = 0\n        grounding = []\n        for text_chunk, grounding_chunk in zip(text_chunks, grounding_chunks):\n            if len(grounding_chunk) > 1:\n                points = []\n                for i in range(0, len(grounding_chunk) - (len(grounding_chunk) % 2), 2):\n                    points.append((grounding_chunk[i], grounding_chunk[i + 1]))\n                grounding.append(\n                    {\n                        \"start_idx\": start_idx,\n                        \"end_idx\": start_idx + len(text_chunk),\n                        \"points\": points,\n                    }\n                )\n            start_idx += len(text_chunk)\n\n        return pos, text, grounding\n\n    def _generate_answer(\n        self,\n        prompt_tokens: 
torch.Tensor,\n        pos: int,\n        settings: Optional[TextSamplingSettings] = None,\n        spatial_refs: Optional[SpatialRefs] = None,\n        eos_id: Optional[int] = None,\n        attn_mask: Optional[torch.Tensor] = None,\n    ):\n        max_tokens = (\n            settings.get(\"max_tokens\", DEFAULT_MAX_TOKENS)\n            if settings\n            else DEFAULT_MAX_TOKENS\n        )\n        temperature = (\n            settings.get(\"temperature\", DEFAULT_TEMPERATURE)\n            if settings\n            else DEFAULT_TEMPERATURE\n        )\n        top_p = settings.get(\"top_p\", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P\n        eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id\n        lora = (\n            variant_state_dict(settings[\"variant\"], device=self.device)\n            if settings is not None and \"variant\" in settings\n            else None\n        )\n\n        _, _, next_token, pos = self._prefill_prompt(\n            prompt_tokens,\n            pos,\n            temperature,\n            top_p,\n            spatial_refs,\n            attn_mask=attn_mask,\n            lora=lora,\n        )\n\n        def generator(next_token, pos):\n            mask = torch.zeros(\n                1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool\n            )\n            mask[:, :, :pos] = 1\n            pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)\n            generated_tokens = 0\n\n            # For properly handling token streaming with Unicode\n            token_cache = []\n            print_len = 0\n\n            while (\n                next_token_id := next_token.item()\n            ) != eos_id and generated_tokens < max_tokens:\n                # Add token to our cache\n                token_cache.append(next_token_id)\n\n                # Decode all tokens collected so far\n                text = self.tokenizer.decode(token_cache)\n\n                # After a 
newline, we flush the cache completely\n                if text.endswith(\"\\n\"):\n                    printable_text = text[print_len:]\n                    token_cache = []\n                    print_len = 0\n                    if printable_text:\n                        yield printable_text\n                # If the last token is a CJK character, we can safely print it\n                elif len(text) > 0 and _is_cjk_char(ord(text[-1])):\n                    printable_text = text[print_len:]\n                    print_len += len(printable_text)\n                    if printable_text:\n                        yield printable_text\n                # Otherwise, only yield up to the last space to avoid cutting words\n                else:\n                    last_space_idx = text.rfind(\" \", print_len)\n                    if last_space_idx >= print_len:\n                        printable_text = text[print_len : last_space_idx + 1]\n                        print_len += len(printable_text)\n                        if printable_text:\n                            yield printable_text\n\n                with torch.inference_mode():\n                    next_emb = text_encoder(next_token, self.text)\n                    mask[:, :, pos], pos_ids[0] = 1, pos\n\n                    logits_BV, _ = self._decode_one_tok(next_emb, mask, pos_ids, lora)\n                    logits_BV[:, self.config.tokenizer.answer_id] = float(\"-inf\")\n\n                    pos += 1\n\n                    if temperature == 0:\n                        next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(\n                            1\n                        )  # (1, 1)\n                    else:\n                        probs = torch.softmax(logits_BV / temperature, dim=-1)  # (1, V)\n                        probs = self._apply_top_p(probs, top_p)\n                        next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)\n\n                    generated_tokens += 1\n\n        
    # Flush any remaining text in the cache\n            if token_cache:\n                text = self.tokenizer.decode(token_cache)\n                printable_text = text[print_len:]\n                if printable_text:\n                    yield printable_text\n\n        return generator(next_token, pos)\n\n    def query(\n        self,\n        image: Optional[Union[Image.Image, EncodedImage]] = None,\n        question: str = None,\n        reasoning: bool = True,\n        spatial_refs: Optional[SpatialRefs] = None,\n        stream: bool = False,\n        settings: Optional[TextSamplingSettings] = None,\n    ):\n        if self.config.tokenizer.templates[\"query\"] is None:\n            raise NotImplementedError(\"Model does not support querying.\")\n\n        if question is None:\n            raise ValueError(\"question must be provided.\")\n\n        if spatial_refs and image is None:\n            raise ValueError(\"spatial_refs can only be used with an image.\")\n\n        attn_mask = self.attn_mask\n        if image is not None:\n            image = self.encode_image(image, settings)\n            self.load_encoded_image(image)\n            pos = image.pos\n            prompt_toks = self.config.tokenizer.templates[\"query\"][\"prefix\"]\n        else:\n            self._setup_caches()\n            pos = 0\n            prompt_toks = [\n                self.config.tokenizer.bos_id\n            ] + self.config.tokenizer.templates[\"query\"][\"prefix\"]\n            max_context = self.config.text.max_context\n            attn_mask = torch.tril(\n                torch.ones(1, 1, max_context, max_context, dtype=torch.bool)\n            ).to(self.device)\n\n        spatial_toks = []\n        if spatial_refs:\n            for ref in spatial_refs:\n                coord_id = self.config.tokenizer.coord_id\n                size_id = self.config.tokenizer.size_id\n                if len(ref) == 2:\n                    spatial_toks.extend([coord_id, coord_id])\n            
    else:\n                    spatial_toks.extend([coord_id, coord_id, size_id])\n\n        prompt_tokens = [\n            prompt_toks + spatial_toks + self.tokenizer.encode(question).ids\n        ]\n\n        if reasoning:\n            prompt_tokens[0] += [self.config.tokenizer.thinking_id]\n            prompt_tokens = torch.tensor(prompt_tokens, device=self.device)\n            pos, reasoning_text, reasoning_grounding = self._generate_reasoning(\n                prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask\n            )\n            prompt_tokens = [self.config.tokenizer.templates[\"query\"][\"suffix\"]]\n            reasoning_dict = {\n                \"reasoning\": {\"text\": reasoning_text, \"grounding\": reasoning_grounding}\n            }\n        else:\n            prompt_tokens[0] += self.config.tokenizer.templates[\"query\"][\"suffix\"]\n            reasoning_dict = {}\n\n        prompt_tokens = torch.tensor(prompt_tokens, device=self.device)\n\n        def generator():\n            for token in self._generate_answer(\n                prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask\n            ):\n                yield token\n\n        if stream:\n            return {**reasoning_dict, \"answer\": generator()}\n        else:\n            return {**reasoning_dict, \"answer\": \"\".join(list(generator()))}\n\n    def load_encoded_image(self, encoded_image: EncodedImage):\n        for b, (k, v) in zip(self.text.blocks, encoded_image.caches):\n            b.kv_cache.k_cache[:, :, : k.size(2), :] = k\n            b.kv_cache.v_cache[:, :, : v.size(2), :] = v\n\n    def caption(\n        self,\n        image: Union[Image.Image, EncodedImage],\n        length: Literal[\"normal\", \"short\", \"long\"] = \"normal\",\n        stream: bool = False,\n        settings: Optional[TextSamplingSettings] = None,\n    ):\n        if self.config.tokenizer.templates[\"caption\"] is None:\n            raise NotImplementedError(\"Model does 
not support captioning.\")\n        if length not in self.config.tokenizer.templates[\"caption\"]:\n            raise ValueError(f\"Model does not support caption length '{length}'.\")\n\n        image = self.encode_image(image, settings)\n        self.load_encoded_image(image)\n\n        prompt_tokens = torch.tensor(\n            [self.config.tokenizer.templates[\"caption\"][length]], device=self.device\n        )\n\n        def generator():\n            for token in self._generate_answer(prompt_tokens, image.pos, settings):\n                yield token\n\n        if stream:\n            return {\"caption\": generator()}\n        else:\n            return {\"caption\": \"\".join(list(generator()))}\n\n    def _generate_points(\n        self,\n        hidden: torch.Tensor,\n        next_token: torch.Tensor,\n        pos: int,\n        include_size: bool = True,\n        max_objects: int = DEFAULT_MAX_OBJECTS,\n        lora: Optional[dict] = None,\n    ):\n        out = []\n        mask = torch.zeros(\n            1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool\n        )\n        mask[:, :, :pos] = 1\n        pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long)\n\n        with torch.inference_mode():\n            while (\n                next_token.item() != self.config.tokenizer.eos_id\n                and len(out) < max_objects\n            ):\n                x_logits = decode_coordinate(hidden, self.region)\n                x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1)\n                next_emb = encode_coordinate(\n                    x_center.to(dtype=x_logits.dtype), self.region\n                ).unsqueeze(0)\n\n                # Decode y-coordinate\n                mask[:, :, pos], pos_ids[0] = 1, pos\n                _, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)\n                pos += 1\n                y_logits = decode_coordinate(hidden, self.region)\n                y_center = 
torch.argmax(y_logits, dim=-1) / y_logits.size(-1)\n                next_emb = encode_coordinate(\n                    y_center.to(dtype=y_logits.dtype), self.region\n                ).unsqueeze(0)\n\n                # Decode size\n                if include_size:\n                    mask[:, :, pos], pos_ids[0] = 1, pos\n                    logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora)\n                    pos += 1\n                    size_logits = decode_size(hidden, self.region)\n\n                    # Get bin indices from the logits\n                    w_bin = torch.argmax(size_logits[0], dim=-1)\n                    h_bin = torch.argmax(size_logits[1], dim=-1)\n\n                    # Convert from bin indices to actual size values using the inverse of the log-scale mapping\n                    # Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0)\n                    w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0)\n                    h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0)\n\n                    next_emb = (\n                        encode_size(\n                            torch.tensor(\n                                [w, h], device=self.device, dtype=size_logits.dtype\n                            ),\n                            self.region,\n                        )\n                        .unsqueeze(0)\n                        .unsqueeze(0)\n                    )\n\n                    # Add object\n                    out.append(\n                        {\n                            \"x_min\": x_center.item() - w.item() / 2,\n                            \"y_min\": y_center.item() - h.item() / 2,\n                            \"x_max\": x_center.item() + w.item() / 2,\n                            \"y_max\": y_center.item() + h.item() / 2,\n                        }\n                    )\n                else:\n                    out.append({\"x\": x_center.item(), \"y\": y_center.item()})\n\n       
         # Decode next token (x-coordinate, or eos)\n                mask[:, :, pos], pos_ids[0] = 1, pos\n                logits, hidden = self._decode_one_tok(\n                    next_emb,\n                    mask,\n                    pos_ids,\n                    lora,\n                    lm_head_indices=self.point_gen_indices,\n                )\n                pos += 1\n                # Map back: index 0 -> coord_id, index 1 -> eos_id\n                next_token_idx = torch.argmax(logits, dim=-1)\n                next_token = self.point_gen_indices[next_token_idx]\n\n        return out\n\n    def detect(\n        self,\n        image: Union[Image.Image, EncodedImage],\n        object: str,\n        settings: Optional[ObjectSamplingSettings] = None,\n    ):\n        if self.config.tokenizer.templates[\"detect\"] is None:\n            raise NotImplementedError(\"Model does not support object detection.\")\n\n        image = self.encode_image(image, settings)\n        self.load_encoded_image(image)\n\n        prompt_tokens = torch.tensor(\n            [\n                self.config.tokenizer.templates[\"detect\"][\"prefix\"]\n                + self.tokenizer.encode(\" \" + object).ids\n                + self.config.tokenizer.templates[\"detect\"][\"suffix\"]\n            ],\n            device=self.device,\n        )\n\n        lora = (\n            variant_state_dict(settings[\"variant\"], device=self.device)\n            if settings is not None and \"variant\" in settings\n            else None\n        )\n\n        _, hidden, next_token, pos = self._prefill_prompt(\n            prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora\n        )\n        hidden = hidden[:, -1:, :]\n\n        max_objects = (\n            settings.get(\"max_objects\", DEFAULT_MAX_OBJECTS)\n            if settings\n            else DEFAULT_MAX_OBJECTS\n        )\n        objects = self._generate_points(\n            hidden,\n            next_token,\n            pos,\n    
        include_size=True,\n            max_objects=max_objects,\n            lora=lora,\n        )\n\n        return {\"objects\": objects}\n\n    def point(\n        self,\n        image: Union[Image.Image, EncodedImage],\n        object: str,\n        settings: Optional[ObjectSamplingSettings] = None,\n    ):\n        if self.config.tokenizer.templates[\"point\"] is None:\n            raise NotImplementedError(\"Model does not support pointing.\")\n\n        image = self.encode_image(image, settings)\n        self.load_encoded_image(image)\n\n        prompt_tokens = torch.tensor(\n            [\n                self.config.tokenizer.templates[\"point\"][\"prefix\"]\n                + self.tokenizer.encode(\" \" + object).ids\n                + self.config.tokenizer.templates[\"point\"][\"suffix\"]\n            ],\n            device=self.device,\n        )\n\n        lora = (\n            variant_state_dict(settings[\"variant\"], device=self.device)\n            if settings is not None and \"variant\" in settings\n            else None\n        )\n\n        _, hidden, next_token, pos = self._prefill_prompt(\n            prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora\n        )\n        hidden = hidden[:, -1:, :]\n\n        max_objects = (\n            settings.get(\"max_objects\", DEFAULT_MAX_OBJECTS)\n            if settings\n            else DEFAULT_MAX_OBJECTS\n        )\n        objects = self._generate_points(\n            hidden,\n            next_token,\n            pos,\n            include_size=False,\n            max_objects=max_objects,\n            lora=lora,\n        )\n\n        return {\"points\": objects}\n\n    def _detect_gaze(\n        self,\n        image: EncodedImage,\n        source: Tuple[float, float],\n        force_detect: bool = False,\n    ):\n        with torch.inference_mode():\n            before_emb = text_encoder(\n                torch.tensor(\n                    [self.tokenizer.encode(\"\\n\\nPoint:\").ids], 
device=self.device\n                ),\n                self.text,\n            )\n            after_emb = text_encoder(\n                torch.tensor(\n                    [self.tokenizer.encode(\" gaze\\n\\n\").ids], device=self.device\n                ),\n                self.text,\n            )\n            x_emb = encode_coordinate(\n                torch.tensor([[[source[0]]]], device=self.device, dtype=torch.bfloat16),\n                self.region,\n            )\n            y_emb = encode_coordinate(\n                torch.tensor([[[source[1]]]], device=self.device, dtype=torch.bfloat16),\n                self.region,\n            )\n\n            prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1)\n\n            self.load_encoded_image(image)\n\n            mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :]\n            pos_ids = torch.arange(\n                image.pos,\n                image.pos + prompt_emb.size(1),\n                dtype=torch.long,\n                device=self.device,\n            )\n            hidden = self._prefill(prompt_emb, mask, pos_ids, lora=None)\n            logits = lm_head(hidden, self.text)\n            next_token = torch.argmax(logits, dim=-1)\n            pos = image.pos + prompt_emb.size(1)\n            hidden = hidden[:, -1:, :]\n\n            if force_detect:\n                next_token = torch.tensor([[0]], device=self.device)\n\n            if next_token.item() == self.config.tokenizer.eos_id:\n                return None\n\n            gaze = self._generate_points(\n                hidden, next_token, pos, include_size=False, max_objects=1\n            )\n            return gaze[0]\n\n    def detect_gaze(\n        self,\n        image: Union[Image.Image, EncodedImage],\n        eye: Optional[Tuple[float, float]] = None,\n        face: Optional[Dict[str, float]] = None,\n        unstable_settings: Dict[str, Any] = {},\n    ):\n        if \"force_detect\" in 
unstable_settings:\n            force_detect = unstable_settings[\"force_detect\"]\n        else:\n            force_detect = False\n\n        if \"prioritize_accuracy\" in unstable_settings:\n            prioritize_accuracy = unstable_settings[\"prioritize_accuracy\"]\n        else:\n            prioritize_accuracy = False\n\n        if not prioritize_accuracy:\n            if eye is None:\n                raise ValueError(\"eye must be provided when prioritize_accuracy=False\")\n            image = self.encode_image(image)\n            return {\"gaze\": self._detect_gaze(image, eye, force_detect=force_detect)}\n        else:\n            if (\n                not isinstance(image, Image.Image)\n                and \"flip_enc_img\" not in unstable_settings\n            ):\n                raise ValueError(\n                    \"image must be a PIL Image when prioritize_accuracy=True, \"\n                    \"or flip_enc_img must be provided\"\n                )\n            if face is None:\n                raise ValueError(\"face must be provided when prioritize_accuracy=True\")\n\n            encoded_image = self.encode_image(image)\n            if (\n                isinstance(image, Image.Image)\n                and \"flip_enc_img\" not in unstable_settings\n            ):\n                flipped_pil = image.copy()\n                flipped_pil = flipped_pil.transpose(method=Image.FLIP_LEFT_RIGHT)\n                encoded_flipped_image = self.encode_image(flipped_pil)\n            else:\n                encoded_flipped_image = unstable_settings[\"flip_enc_img\"]\n\n            N = 10\n\n            detections = [\n                self._detect_gaze(\n                    encoded_image,\n                    (\n                        random.uniform(face[\"x_min\"], face[\"x_max\"]),\n                        random.uniform(face[\"y_min\"], face[\"y_max\"]),\n                    ),\n                    force_detect=force_detect,\n                )\n               
 for _ in range(N)\n            ]\n            detections = [\n                (gaze[\"x\"], gaze[\"y\"]) for gaze in detections if gaze is not None\n            ]\n            flipped_detections = [\n                self._detect_gaze(\n                    encoded_flipped_image,\n                    (\n                        1 - random.uniform(face[\"x_min\"], face[\"x_max\"]),\n                        random.uniform(face[\"y_min\"], face[\"y_max\"]),\n                    ),\n                    force_detect=force_detect,\n                )\n                for _ in range(N)\n            ]\n            detections.extend(\n                [\n                    (1 - gaze[\"x\"], gaze[\"y\"])\n                    for gaze in flipped_detections\n                    if gaze is not None\n                ]\n            )\n\n            if len(detections) < N:\n                return {\"gaze\": None}\n\n            detections = remove_outlier_points(detections)\n            mean_gaze = (\n                sum(gaze[0] for gaze in detections) / len(detections),\n                sum(gaze[1] for gaze in detections) / len(detections),\n            )\n\n            return {\"gaze\": {\"x\": mean_gaze[0], \"y\": mean_gaze[1]}}\n\n\ndef _is_cjk_char(cp):\n    \"\"\"Checks whether CP is the codepoint of a CJK character.\"\"\"\n    # This defines a \"chinese character\" as anything in the CJK Unicode block:\n    # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)\n    if (\n        (cp >= 0x4E00 and cp <= 0x9FFF)\n        or (cp >= 0x3400 and cp <= 0x4DBF)\n        or (cp >= 0x2F800 and cp <= 0x2FA1F)\n    ):\n        return True\n    return False\n"
  },
  {
    "path": "moondream/torch/region.py",
    "content": "import torch\nimport torch.nn as nn\nimport math\n\nfrom typing import List, Tuple, Union\n\nSpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]]\n\n\ndef fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Applies Fourier feature mapping to input tensor x using frequency matrix w. This\n    projects inputs through sinusoidal functions to create higher dimensional features\n    that help mitigate spectral bias - the tendency of neural networks to learn\n    low-frequency functions more easily than high-frequency ones. By explicitly\n    mapping inputs to higher frequencies through sin/cos transformations, we enable\n    better learning of fine details and higher frequency patterns.\n\n    Args:\n        x: Input tensor to transform\n        w: Matrix of frequencies for the Fourier features transformation\n\n    Returns:\n        Concatenated cosine and sine transformed features as a tensor\n    \"\"\"\n    f = 2 * math.pi * x @ w\n    return torch.cat([f.cos(), f.sin()], dim=-1)\n\n\ndef encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor:\n    \"\"\"\n    Takes as input a tensor containing a single float coordinate value (x or y)\n    and encodes it into hidden states for input to the text model.\n\n    Args:\n        coord: Tensor with single float coordinate value\n\n    Returns:\n        Encoded hidden states tensor for input to text model\n    \"\"\"\n    return w.coord_encoder(fourier_features(coord, w.coord_features))\n\n\ndef decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:\n    \"\"\"\n    Takes as input the last hidden state from the text model and outputs logits\n    over the coordinate bins for either an x or y coordinate prediction.\n\n    Args:\n        hidden_state: The final hidden state tensor from the text model.\n\n    Returns:\n        Logits over the coordinate bins; the predicted coordinate value (x or y)\n        is recovered by the caller as argmax(logits) / number_of_bins.\n    \"\"\"\n    return 
w.coord_decoder(hidden_state)\n\n\ndef encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:\n    \"\"\"\n    Takes a tensor containing width and height values and encodes them into\n    hidden states for input to the text model.\n\n    Args:\n        size: Tensor with two floats for width and height\n\n    Returns:\n        Encoded hidden states tensor for input to text model\n    \"\"\"\n    return w.size_encoder(fourier_features(size, w.size_features))\n\n\ndef decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:\n    \"\"\"\n    Takes as input the last hidden state from the text model and outputs logits\n    for 1024 bins representing width and height in log-scale.\n\n    The bins are distributed according to the formula:\n    bin = (log2(size) + 10.0) / 10.0 * 1023.0\n    where size values are clamped to be at least 1/1024.\n\n    To convert from bin back to size:\n    size = 2^((bin / 1023.0) * 10.0 - 10.0)\n\n    Args:\n        hidden_state: The final hidden state tensor from the text model.\n\n    Returns:\n        A tensor containing logits for 1024 bins for width and height.\n        Shape is (2, 1024) where the first dimension corresponds to width and height.\n    \"\"\"\n    return w.size_decoder(hidden_state).view(2, -1)\n\n\ndef encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torch.Tensor:\n    \"\"\"\n    Takes a list of spatial references (points or regions) and encodes them into\n    hidden states for input to the text model.\n\n    Args:\n        spatial_refs: List of spatial references (points or boxes)\n            - Points are represented as normalized (x, y) tuples\n            - Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples\n\n    Returns:\n        {\"coords\": torch.Tensor, \"sizes\": Optional[torch.Tensor]}\n    \"\"\"\n    coords, sizes = [], []\n    for ref in spatial_refs:\n        if len(ref) == 2:\n            coords.append(ref[0])\n            coords.append(ref[1])\n 
       else:\n            x_c = (ref[0] + ref[2]) / 2\n            y_c = (ref[1] + ref[3]) / 2\n            width = ref[2] - ref[0]\n            height = ref[3] - ref[1]\n            coords.append(x_c)\n            coords.append(y_c)\n            sizes.append([width, height])\n\n    coords = torch.tensor(\n        coords, device=w.coord_features.device, dtype=w.coord_features.dtype\n    ).view(-1, 1)\n    coords = encode_coordinate(coords, w)\n\n    if sizes:\n        sizes = torch.tensor(\n            sizes, device=w.size_features.device, dtype=w.size_features.dtype\n        )\n        sizes = encode_size(sizes, w)\n    else:\n        sizes = None\n\n    return {\"coords\": coords, \"sizes\": sizes}\n"
  },
  {
    "path": "moondream/torch/rope.py",
    "content": "# Ethically sourced from https://github.com/xjdr-alt/entropix\n\nimport torch\n\n\ndef precompute_freqs_cis(\n    dim: int,\n    end: int,\n    theta: float = 1500000.0,\n    dtype: torch.dtype = torch.float32,\n) -> torch.Tensor:\n    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim))\n    t = torch.arange(end, dtype=dtype).unsqueeze(1)\n    freqs = t * freqs.unsqueeze(0)\n    freqs = torch.exp(1j * freqs)\n    return torch.stack([freqs.real, freqs.imag], dim=-1)\n\n\ndef apply_rotary_emb(\n    x: torch.Tensor,\n    freqs_cis: torch.Tensor,\n    position_ids: torch.Tensor,\n    num_heads: int,\n    rot_dim: int = 32,\n    interleave: bool = False,\n) -> torch.Tensor:\n    assert rot_dim == freqs_cis.shape[-2] * 2\n    assert num_heads == x.shape[1]\n\n    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]\n\n    if interleave:\n        xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]\n        xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]\n    else:\n        d_q = x_rot.shape[-1] // 2\n        xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:]\n\n    freqs_cos = freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0)\n    freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0)\n\n    # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i\n    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin\n    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos\n    xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2)\n\n    return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1)\n"
  },
  {
    "path": "moondream/torch/sample.py",
    "content": "import argparse\nimport json\nimport os\nimport torch\n\nfrom PIL import Image, ImageDraw\nfrom tqdm import tqdm\n\nfrom .weights import load_weights_into_model\nfrom .moondream import MoondreamModel, MoondreamConfig\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--image\", \"-i\", type=str, required=True)\n    parser.add_argument(\"--prompt\", \"-p\", type=str, required=True)\n    parser.add_argument(\"--model\", \"-m\", type=str, required=True)\n    parser.add_argument(\"--config\", \"-c\", type=str, default=None)\n    parser.add_argument(\"--max-tokens\", \"-t\", type=int, default=200)\n    parser.add_argument(\"--sampler\", \"-s\", type=str, default=\"greedy\")\n    parser.add_argument(\"--benchmark\", \"-b\", action=\"store_true\")\n    args = parser.parse_args()\n\n    if torch.cuda.is_available():\n        device = \"cuda\"\n    elif torch.backends.mps.is_available():\n        device = \"mps\"\n    print(f\"Using device: {device}\")\n\n    # Load model.\n    if args.config is not None:\n        with open(args.config, \"r\") as f:\n            config = json.load(f)\n        config = MoondreamConfig.from_dict(config)\n    else:\n        config = MoondreamConfig()\n    model = MoondreamModel(config)\n    load_weights_into_model(args.model, model)\n    model.to(device, dtype=torch.bfloat16)\n    model.compile()\n\n    torch.cuda.empty_cache()\n    torch.cuda.reset_peak_memory_stats()\n    torch.cuda.reset_accumulated_memory_stats()\n\n    # Encode image.\n    image_path = args.image\n    if not os.path.exists(image_path):\n        raise FileNotFoundError(f\"Image not found at {image_path}\")\n    image = Image.open(image_path)\n\n    if not args.benchmark:\n        encoded_image = model.encode_image(image)\n\n        # Text query\n        text_query = \"What is the capital of Washington, USA? 
Answer in JSON format.\"\n        print(\"Query:\", text_query)\n        text_response = model.query(None, text_query, reasoning=True, stream=True)\n        print(\"Reasoning:\", text_response[\"reasoning\"])\n        for t in text_response[\"answer\"]:\n            print(t, end=\"\", flush=True)\n        print()\n        print()\n\n        # Short caption\n        print(\"Caption: short\")\n        for t in model.caption(encoded_image, \"short\", stream=True)[\"caption\"]:\n            print(t, end=\"\", flush=True)\n        print()\n        print()\n\n        # Regular caption\n        print(\"Caption: normal\")\n        for t in model.caption(encoded_image, \"normal\", stream=True)[\"caption\"]:\n            print(t, end=\"\", flush=True)\n        print()\n        print()\n\n        # Long caption\n        print(\"Caption: long\")\n        for t in model.caption(encoded_image, \"long\", stream=True)[\"caption\"]:\n            print(t, end=\"\", flush=True)\n        print()\n        print()\n\n        # Query\n        print(\"Query:\", args.prompt)\n        for t in model.query(\n            encoded_image,\n            args.prompt,\n            stream=True,\n            settings={\"variant\": \"geoguesser_lora_only\"},\n        )[\"answer\"]:\n            print(t, end=\"\", flush=True)\n        print()\n        print()\n\n        # Query (reasoning)\n        reasoning_prompt = \"How many sesame seeds are on the burger?\"\n        print(\"Query (reasoning):\", reasoning_prompt)\n        resp = model.query(encoded_image, reasoning_prompt, reasoning=True, stream=True)\n        print(\"Reasoning:\", resp[\"reasoning\"])\n        for t in resp[\"answer\"]:\n            print(t, end=\"\", flush=True)\n        print()\n        print()\n\n        # Detect\n        obj = \"hand\"\n        print(f\"Detect: {obj}\")\n        objs = model.detect(encoded_image, obj)[\"objects\"]\n        print(f\"Found {len(objs)}\")\n        print()\n        draw = ImageDraw.Draw(image)\n    
    for obj in objs:\n            x_min, y_min, x_max, y_max = (\n                obj[\"x_min\"] * image.width,\n                obj[\"y_min\"] * image.height,\n                obj[\"x_max\"] * image.width,\n                obj[\"y_max\"] * image.height,\n            )\n            draw.rectangle([x_min, y_min, x_max, y_max], outline=\"red\", width=2)\n        image.save(\"detect.jpg\")\n\n        # Spatial query\n        if len(objs) > 0:\n            print(\"Spatial query: What is this?\")\n            for t in model.query(\n                encoded_image,\n                \"What is this?\",\n                spatial_refs=[\n                    [\n                        (obj[\"x_min\"], obj[\"y_min\"], obj[\"x_max\"], obj[\"y_max\"])\n                        for obj in objs\n                    ][0]\n                ],\n                stream=True,\n            )[\"answer\"]:\n                print(t, end=\"\", flush=True)\n            print()\n            print()\n\n        # Point\n        obj = \"ear\"\n        print(f\"Point: {obj}\")\n        points = model.point(encoded_image, obj)[\"points\"]\n        print(f\"Found {len(points)}\")\n        draw = ImageDraw.Draw(image)\n        for point in points:\n            x, y = point[\"x\"] * image.width, point[\"y\"] * image.height\n            draw.ellipse([x - 5, y - 5, x + 5, y + 5], fill=\"red\")\n        image.save(\"point.jpg\")\n        print()\n        print()\n\n        # Spatial query\n        if len(objs) > 0:\n            for o in [\"hand\", \"ear\", \"face\"]:\n                for k in [(objs, \"hand\"), (points, \"ear\")]:\n                    print(f\"Spatial query: Is this a {o}? 
({k[1]})\")\n                    for t in model.query(\n                        encoded_image,\n                        f\"Is this a {o}?\",\n                        spatial_refs=[\n                            [\n                                (\n                                    (\n                                        obj[\"x_min\"],\n                                        obj[\"y_min\"],\n                                        obj[\"x_max\"],\n                                        obj[\"y_max\"],\n                                    )\n                                    if \"x_min\" in obj\n                                    else (obj[\"x\"], obj[\"y\"])\n                                )\n                                for obj in k[0]\n                            ][0]\n                        ],\n                    )[\"answer\"]:\n                        print(t, end=\"\", flush=True)\n                    print()\n\n        # Detect gaze\n        model.detect_gaze(encoded_image, (0.5, 0.5))\n    elif model.device.type != \"mps\":\n        # Warmup runs\n        for _ in tqdm(range(5), desc=\"Warmup\"):\n            encoded_image = model.encode_image(image)\n            for _ in model.query(encoded_image, args.prompt, stream=True)[\"answer\"]:\n                pass\n\n        # Benchmark runs\n        encode_times = []\n        query_speeds = []\n        for i in tqdm(range(10), desc=\"Benchmark\"):\n            # Measure encode time\n            start = torch.cuda.Event(enable_timing=True)\n            end = torch.cuda.Event(enable_timing=True)\n            start.record()\n            encoded_image = model.encode_image(image)\n            end.record()\n            torch.cuda.synchronize()\n            encode_time = start.elapsed_time(end)\n            encode_times.append(encode_time)\n\n            # Measure query speed\n            tokens = []\n            query_start = torch.cuda.Event(enable_timing=True)\n            query_end = 
torch.cuda.Event(enable_timing=True)\n            query_start.record()\n            for t in model.query(encoded_image, args.prompt, stream=True)[\"answer\"]:\n                tokens.append(t)\n            query_end.record()\n            torch.cuda.synchronize()\n            query_time = query_start.elapsed_time(query_end)\n            tokens_per_sec = len(tokens) / (query_time / 1000.0)  # Convert ms to s\n            query_speeds.append(tokens_per_sec)\n\n        # Print results\n        print(\"\\nBenchmark Results (10 runs):\")\n        print(\"Image Encoding Time (ms):\")\n        print(f\"  Mean: {sum(encode_times)/len(encode_times):.2f}\")\n        print(f\"  Min:  {min(encode_times):.2f}\")\n        print(f\"  Max:  {max(encode_times):.2f}\")\n        print(\"\\nQuery Speed (tokens/sec):\")\n        print(f\"  Mean: {sum(query_speeds)/len(query_speeds):.2f}\")\n        print(f\"  Min:  {min(query_speeds):.2f}\")\n        print(f\"  Max:  {max(query_speeds):.2f}\")\n\n        print(torch.cuda.memory_summary(abbreviated=False))\n"
  },
  {
    "path": "moondream/torch/text.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom torch.nn import functional as F\nfrom torch.nn.attention.flex_attention import flex_attention\nfrom typing import Optional\n\nfrom .layers import layer_norm, mlp, QuantizedLinear, moe_mlp\nfrom .rope import apply_rotary_emb, precompute_freqs_cis\nfrom .config import TextConfig\n\n\ndef text_encoder(input_ids: torch.Tensor, w: nn.Module):\n    return F.embedding(input_ids, w.wte)\n\n\ndef attn(\n    x: torch.Tensor,\n    w: nn.Module,\n    freqs_cis: torch.Tensor,\n    kv_cache: nn.Module,\n    attn_mask: torch.Tensor,\n    n_heads: int,\n    n_kv_heads: int,\n    position_ids: torch.Tensor,\n    lora: Optional[dict] = None,\n    flex_block_mask_slice=None,\n):\n    bsz, q_len, d_model = x.shape\n    head_dim = d_model // n_heads\n\n    qkv_out = w.qkv(x)  # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)\n    if lora is not None:\n        qkv_out += F.linear(F.linear(x, lora[\"qkv\"][\"A\"]), lora[\"qkv\"][\"B\"])\n    q_dim = n_heads * head_dim\n    kv_dim = n_kv_heads * head_dim\n    q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)\n\n    q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)\n    k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n    v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)\n\n    if hasattr(w, \"tau\") and w.tau is not None:\n        tok_feat = F.gelu(qkv_out)\n        tok_q = torch.tanh(torch.matmul(tok_feat, w.tau[\"wq\"].t())).permute(0, 2, 1)\n        tok_v = torch.tanh(torch.matmul(tok_feat, w.tau[\"wv\"].t())).permute(0, 2, 1)\n        pos = position_ids.to(q.dtype) + 1\n        tau_pos = 1 + (\n            torch.sigmoid(w.tau[\"alpha\"][:, None] * pos.log()) - 0.5\n        )  # (H,S)\n        tau_q = (tok_q + tau_pos[None]).unsqueeze(-1)  # (B,H,S,1)\n        tau_v = (tok_v + tau_pos[None]).unsqueeze(-1)\n        q = q * tau_q\n        v = v * tau_v\n\n    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)\n    k = 
apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)\n\n    if kv_cache is not None:\n        k, v = kv_cache.update(position_ids, k, v)\n\n    if flex_block_mask_slice is not None:\n        torch._assert(n_heads == n_kv_heads, \"gqa not supported yet\")\n        out = flex_attention(q, k, v, block_mask=flex_block_mask_slice)\n    else:\n        out = F.scaled_dot_product_attention(\n            q, k, v, attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads\n        )\n\n    out = out.transpose(1, 2).reshape(bsz, q_len, d_model)\n\n    out0 = w.proj(out)\n    if lora is not None:\n        out1 = F.linear(F.linear(x, lora[\"proj\"][\"A\"]), lora[\"proj\"][\"B\"])\n        out = out0 + out1\n    else:\n        out = out0\n\n    return out\n\n\ndef text_decoder(\n    x: torch.Tensor,\n    w: nn.Module,\n    attn_mask: torch.Tensor,\n    position_ids: torch.Tensor,\n    config: TextConfig,\n    lora: Optional[dict] = None,\n    flex_block_mask_slice=None,\n):\n    for i, block in enumerate(w.blocks):\n        if lora is not None:\n            layer_lora = lora[\"text\"][\"blocks\"][str(i)]\n            mlp_lora = layer_lora[\"mlp\"]\n            attn_lora = layer_lora[\"attn\"]\n        else:\n            mlp_lora = None\n            attn_lora = None\n\n        l_in = layer_norm(x, block.ln)\n        l_attn = attn(\n            l_in,\n            block.attn,\n            freqs_cis=w.freqs_cis,\n            kv_cache=block.kv_cache,\n            attn_mask=attn_mask,\n            n_heads=config.n_heads,\n            n_kv_heads=config.n_kv_heads,\n            position_ids=position_ids,\n            lora=attn_lora,\n            flex_block_mask_slice=flex_block_mask_slice,\n        )\n\n        if config.moe is not None and i >= config.moe.start_layer:\n            l_mlp = moe_mlp(l_in, block.mlp, config.moe.experts_per_token)\n        else:\n            l_mlp = mlp(l_in, block.mlp, lora=mlp_lora)\n\n        x = x + l_attn + l_mlp\n\n    return x\n\n\ndef lm_head(\n    
hidden_BTC: torch.Tensor, w: nn.Module, indices: Optional[torch.Tensor] = None\n):\n    hidden_BC = hidden_BTC[:, -1, :]\n    hidden_BC = layer_norm(hidden_BC, w.post_ln)\n    if indices is not None:\n        # Only compute logits for specified token indices\n        logits = hidden_BC @ w.lm_head.weight[indices].T + w.lm_head.bias[indices]\n    else:\n        logits = w.lm_head(hidden_BC)\n    return logits\n\n\ndef build_dense_mlp(d_model, d_ffn, dtype, linear_cls):\n    return nn.ModuleDict(\n        {\n            \"fc1\": linear_cls(d_model, d_ffn, dtype=dtype),\n            \"fc2\": linear_cls(d_ffn, d_model, dtype=dtype),\n        }\n    )\n\n\ndef build_moe_mlp(d_model, d_ffn, n_experts, dtype):\n    # For GeGLU, fc1 needs to output 2 * d_ffn (for gating)\n    return nn.ModuleDict(\n        {\n            \"router\": nn.Linear(d_model, n_experts, dtype=dtype),\n            \"fc1\": nn.ParameterDict(\n                {\n                    \"weight\": nn.Parameter(\n                        torch.empty(n_experts, 2 * d_ffn, d_model, dtype=dtype)\n                    )\n                }\n            ),\n            \"fc2\": nn.ParameterDict(\n                {\n                    \"weight\": nn.Parameter(\n                        torch.empty(n_experts, d_model, d_ffn, dtype=dtype)\n                    )\n                }\n            ),\n        }\n    )\n\n\ndef build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:\n    qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads))\n    linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear\n\n    text = nn.ModuleDict(\n        {\n            \"blocks\": nn.ModuleList(\n                [\n                    nn.ModuleDict(\n                        {\n                            \"ln\": nn.LayerNorm(config.dim, dtype=dtype),\n                            \"attn\": nn.ModuleDict(\n                                {\n                                    \"qkv\": 
linear_cls(config.dim, qkv_dim, dtype=dtype),\n                                    \"proj\": linear_cls(\n                                        config.dim, config.dim, dtype=dtype\n                                    ),\n                                    \"tau\": nn.ParameterDict(\n                                        {\n                                            \"wq\": nn.Parameter(\n                                                torch.empty(\n                                                    config.n_heads, qkv_dim, dtype=dtype\n                                                )\n                                            ),\n                                            \"wv\": nn.Parameter(\n                                                torch.empty(\n                                                    config.n_heads, qkv_dim, dtype=dtype\n                                                )\n                                            ),\n                                            \"alpha\": nn.Parameter(\n                                                torch.empty(config.n_heads, dtype=dtype)\n                                            ),\n                                        }\n                                    ),\n                                }\n                            ),\n                            \"mlp\": (\n                                build_moe_mlp(\n                                    config.dim,\n                                    config.moe.expert_inner_dim,\n                                    config.moe.num_experts,\n                                    dtype,\n                                )\n                                if config.moe is not None\n                                and layer_idx >= config.moe.start_layer\n                                else build_dense_mlp(\n                                    config.dim, config.ff_dim, dtype, linear_cls\n                                )\n                            
),\n                        }\n                    )\n                    for layer_idx in range(config.n_layers)\n                ]\n            ),\n            \"post_ln\": nn.LayerNorm(config.dim, dtype=dtype),\n            \"lm_head\": nn.Linear(config.dim, config.vocab_size, dtype=dtype),\n        }\n    )\n    text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype))\n    text.register_buffer(\n        \"freqs_cis\",\n        precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context),\n        persistent=False,\n    )\n\n    return text\n"
  },
  {
    "path": "moondream/torch/utils.py",
    "content": "import numpy as np\n\n\ndef remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0):\n    \"\"\"\n    Robust outlier detection for list of (x,y) tuples.\n    Only requires numpy.\n\n    Args:\n        points_tuples: list of (x,y) tuples\n        k_nearest: number of neighbors to consider\n        threshold: multiplier for median distance\n\n    Returns:\n        list: filtered list of (x,y) tuples with outliers removed\n        list: list of booleans indicating which points were kept (True = kept)\n    \"\"\"\n    points = np.array(points_tuples)\n    n_points = len(points)\n\n    # Calculate pairwise distances manually\n    dist_matrix = np.zeros((n_points, n_points))\n    for i in range(n_points):\n        for j in range(i + 1, n_points):\n            # Euclidean distance between points i and j\n            dist = np.sqrt(np.sum((points[i] - points[j]) ** 2))\n            dist_matrix[i, j] = dist\n            dist_matrix[j, i] = dist\n\n    # Get k nearest neighbors' distances\n    k = min(k_nearest, n_points - 1)\n    neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k]\n    avg_neighbor_dist = np.mean(neighbor_distances, axis=1)\n\n    # Calculate mask using median distance\n    median_dist = np.median(avg_neighbor_dist)\n    mask = avg_neighbor_dist <= threshold * median_dist\n\n    # Return filtered tuples and mask\n    filtered_tuples = [t for t, m in zip(points_tuples, mask) if m]\n    return filtered_tuples\n"
  },
  {
    "path": "moondream/torch/vision.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nfrom typing import Union, Tuple\nfrom PIL import Image\n\nfrom .layers import attn, layer_norm, mlp\nfrom .image_crops import overlap_crop_image\nfrom .config import VisionConfig\n\nif torch.backends.mps.is_available():\n    # Non-divisible input sizes are not implemented on MPS device yet.\n    # https://github.com/pytorch/pytorch/issues/96056\n    def adaptive_avg_pool2d(input, output_size):\n        return F.adaptive_avg_pool2d(input.to(\"cpu\"), output_size).to(\"mps\")\n\nelse:\n    adaptive_avg_pool2d = F.adaptive_avg_pool2d\n\nDeviceLike = Union[str, torch.device, int]\n\n\ndef prepare_crops(\n    image: Image.Image, config: VisionConfig, device: DeviceLike\n) -> Tuple[torch.Tensor, Tuple[int, int]]:\n    np_image = np.array(image.convert(\"RGB\"))\n    overlap_crops = overlap_crop_image(\n        np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin\n    )\n    all_crops = overlap_crops[\"crops\"]\n    all_crops = np.transpose(all_crops, (0, 3, 1, 2))\n    all_crops = (\n        torch.from_numpy(all_crops)\n        .to(device=device, dtype=torch.bfloat16)\n        .div_(255.0)\n        .sub_(0.5)\n        .div_(0.5)\n    )\n    return all_crops, overlap_crops[\"tiling\"]\n\n\ndef create_patches(x, patch_size):\n    # Original shape: [B, C, H, W]\n    B, C, H, W = x.shape\n    P1 = P2 = patch_size\n\n    # Step 1: Split H and W dimensions into patches\n    # [B, C, H/P1, P1, W/P2, P2]\n    x = x.reshape(B, C, H // P1, P1, W // P2, P2)\n\n    # Step 2: Rearrange dimensions to match target shape\n    # [B, H/P1, W/P2, C, P1, P2]\n    x = x.permute(0, 2, 4, 1, 3, 5)\n\n    # Step 3: Combine dimensions to get final shape\n    # [B, (H/P1)*(W/P2), C*P1*P2]\n    x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2)\n\n    return x\n\n\ndef vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):\n    x = 
create_patches(input_BCHW, config.enc_patch_size)\n\n    x = w.patch_emb(x)\n    x = x + w.pos_emb\n    for block in w.blocks:\n        x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads)\n        x = x + mlp(layer_norm(x, block.ln2), block.mlp)\n    x = layer_norm(x, w.post_ln)\n\n    return x\n\n\ndef vision_projection(\n    global_features: torch.Tensor,\n    reconstructed: torch.Tensor,\n    w: nn.Module,\n    config: VisionConfig,\n):\n    reconstructed = reconstructed.permute(2, 0, 1)\n    reconstructed = adaptive_avg_pool2d(\n        reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers)\n    )\n    reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim)\n    final_features = torch.cat([global_features, reconstructed], dim=-1)\n    return mlp(final_features, w.proj_mlp)\n\n\ndef build_vision_model(config: VisionConfig, dtype: torch.dtype):\n    patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels\n    grid_size = config.crop_size // config.enc_patch_size\n    num_patches = grid_size * grid_size\n\n    vision = nn.ModuleDict(\n        {\n            \"patch_emb\": nn.Linear(patch_dim, config.enc_dim, dtype=dtype),\n            \"blocks\": nn.ModuleList(\n                [\n                    nn.ModuleDict(\n                        {\n                            \"ln1\": nn.LayerNorm(config.enc_dim, dtype=dtype),\n                            \"attn\": nn.ModuleDict(\n                                {\n                                    \"qkv\": nn.Linear(\n                                        config.enc_dim, 3 * config.enc_dim, dtype=dtype\n                                    ),\n                                    \"proj\": nn.Linear(\n                                        config.enc_dim, config.enc_dim, dtype=dtype\n                                    ),\n                                }\n                            ),\n                            \"ln2\": 
nn.LayerNorm(config.enc_dim, dtype=dtype),\n                            \"mlp\": nn.ModuleDict(\n                                {\n                                    \"fc1\": nn.Linear(\n                                        config.enc_dim, config.enc_ff_dim, dtype=dtype\n                                    ),\n                                    \"fc2\": nn.Linear(\n                                        config.enc_ff_dim, config.enc_dim, dtype=dtype\n                                    ),\n                                }\n                            ),\n                        }\n                    )\n                    for _ in range(config.enc_n_layers)\n                ]\n            ),\n            \"post_ln\": nn.LayerNorm(config.enc_dim, dtype=dtype),\n            \"proj_mlp\": nn.ModuleDict(\n                {\n                    \"fc1\": nn.Linear(\n                        config.enc_dim * 2, config.proj_inner_dim, dtype=dtype\n                    ),\n                    \"fc2\": nn.Linear(\n                        config.proj_inner_dim, config.proj_out_dim, dtype=dtype\n                    ),\n                }\n            ),\n        }\n    )\n    vision.pos_emb = nn.Parameter(\n        torch.zeros(1, num_patches, config.enc_dim, dtype=dtype)\n    )\n    return vision\n"
  },
  {
    "path": "moondream/torch/weights.py",
    "content": "import safetensors\nimport torch\nimport torch.nn as nn\n\nfrom contextlib import contextmanager\nfrom typing import Callable, List\n\n\n@contextmanager\ndef safetensors_open(safetensors_file: str):\n    \"\"\"\n    Simplify interfacing with safetensors files. Eliminates the need to ignore\n    type errors when using the `safe_open` function.\n    \"\"\"\n    with safetensors.safe_open(\n        safetensors_file, framework=\"pt\"\n    ) as st:  # pyright: ignore\n\n        def get_tensor(name: str) -> torch.Tensor:\n            return st.get_tensor(name)\n\n        def get_keys() -> List[str]:\n            return st.keys()\n\n        get_tensor.keys = get_keys\n\n        yield get_tensor\n\n\ndef _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None:\n    \"\"\"Internal function to load weights using a tensor getter function.\"\"\"\n    model = model.to(dtype=torch.bfloat16)\n\n    vision = model.vision\n    region = model.region\n    weight_map = {\n        \"vision_encoder.encoder.model.visual.patch_embed.linear.weight\": vision[\n            \"patch_emb\"\n        ].weight,\n        \"vision_encoder.encoder.model.visual.patch_embed.linear.bias\": vision[\n            \"patch_emb\"\n        ].bias,\n        \"vision_encoder.encoder.model.visual.pos_embed\": vision.pos_emb,\n        \"vision_encoder.encoder.model.visual.norm.weight\": vision[\"post_ln\"].weight,\n        \"vision_encoder.encoder.model.visual.norm.bias\": vision[\"post_ln\"].bias,\n        \"vision_encoder.projection.mlp.fc1.weight\": vision[\"proj_mlp\"][\"fc1\"].weight,\n        \"vision_encoder.projection.mlp.fc1.bias\": vision[\"proj_mlp\"][\"fc1\"].bias,\n        \"vision_encoder.projection.mlp.fc2.weight\": vision[\"proj_mlp\"][\"fc2\"].weight,\n        \"vision_encoder.projection.mlp.fc2.bias\": vision[\"proj_mlp\"][\"fc2\"].bias,\n        \"text_model.transformer.embd.wte.weight\": model.text.wte,\n        \"text_model.lm_head.ln.weight\": 
model.text[\"post_ln\"].weight,\n        \"text_model.lm_head.ln.bias\": model.text[\"post_ln\"].bias,\n        \"text_model.lm_head.linear.weight\": model.text[\"lm_head\"].weight,\n        \"text_model.lm_head.linear.bias\": model.text[\"lm_head\"].bias,\n        \"region_model.coordinate_encoder.weight\": region[\"coord_encoder\"].weight,\n        \"region_model.coordinate_encoder.bias\": region[\"coord_encoder\"].bias,\n        \"region_model.coordinate_head.weight\": region[\"coord_decoder\"].weight,\n        \"region_model.coordinate_head.bias\": region[\"coord_decoder\"].bias,\n        \"region_model.size_encoder.weight\": region[\"size_encoder\"].weight,\n        \"region_model.size_encoder.bias\": region[\"size_encoder\"].bias,\n        \"region_model.size_head.weight\": region[\"size_decoder\"].weight,\n        \"region_model.size_head.bias\": region[\"size_decoder\"].bias,\n    }\n\n    for i in range(len(model.vision[\"blocks\"])):\n        prefix = f\"vision_encoder.encoder.model.visual.blocks.{i}\"\n        blk = model.vision[\"blocks\"][i]\n        weight_map.update(\n            {\n                f\"{prefix}.norm1.weight\": blk[\"ln1\"].weight,\n                f\"{prefix}.norm1.bias\": blk[\"ln1\"].bias,\n                f\"{prefix}.norm2.weight\": blk[\"ln2\"].weight,\n                f\"{prefix}.norm2.bias\": blk[\"ln2\"].bias,\n                f\"{prefix}.attn.qkv.weight\": blk[\"attn\"][\"qkv\"].weight,\n                f\"{prefix}.attn.qkv.bias\": blk[\"attn\"][\"qkv\"].bias,\n                f\"{prefix}.attn.proj.weight\": blk[\"attn\"][\"proj\"].weight,\n                f\"{prefix}.attn.proj.bias\": blk[\"attn\"][\"proj\"].bias,\n                f\"{prefix}.mlp.fc1.weight\": blk[\"mlp\"][\"fc1\"].weight,\n                f\"{prefix}.mlp.fc1.bias\": blk[\"mlp\"][\"fc1\"].bias,\n                f\"{prefix}.mlp.fc2.weight\": blk[\"mlp\"][\"fc2\"].weight,\n                f\"{prefix}.mlp.fc2.bias\": blk[\"mlp\"][\"fc2\"].bias,\n            }\n  
      )\n\n    for i in range(len(model.text[\"blocks\"])):\n        prefix = f\"text_model.transformer.h.{i}\"\n        blk = model.text[\"blocks\"][i]\n        is_moe = hasattr(blk.mlp, \"router\")\n        weight_map.update(\n            {\n                f\"{prefix}.ln.weight\": blk[\"ln\"].weight,\n                f\"{prefix}.ln.bias\": blk[\"ln\"].bias,\n                f\"{prefix}.mixer.Wqkv.weight\": blk[\"attn\"][\"qkv\"].weight,\n                f\"{prefix}.mixer.Wqkv.bias\": blk[\"attn\"][\"qkv\"].bias,\n                f\"{prefix}.mixer.out_proj.weight\": blk[\"attn\"][\"proj\"].weight,\n                f\"{prefix}.mixer.out_proj.bias\": blk[\"attn\"][\"proj\"].bias,\n                f\"{prefix}.tau_wq\": blk[\"attn\"][\"tau\"][\"wq\"],\n                f\"{prefix}.tau_wv\": blk[\"attn\"][\"tau\"][\"wv\"],\n                f\"{prefix}.tau_alpha\": blk[\"attn\"][\"tau\"][\"alpha\"],\n            }\n        )\n        if is_moe:\n            weight_map.update(\n                {\n                    f\"{prefix}.gate.weight\": blk[\"mlp\"][\"router\"].weight,\n                    f\"{prefix}.gate.bias\": blk[\"mlp\"][\"router\"].bias,\n                    f\"{prefix}.mlp.experts.weight\": blk[\"mlp\"][\"fc1\"].weight,\n                    f\"{prefix}.mlp.output_experts.weight\": blk[\"mlp\"][\"fc2\"].weight,\n                }\n            )\n        else:\n            weight_map.update(\n                {\n                    f\"{prefix}.mlp.fc1.weight\": blk[\"mlp\"][\"fc1\"].weight,\n                    f\"{prefix}.mlp.fc1.bias\": blk[\"mlp\"][\"fc1\"].bias,\n                    f\"{prefix}.mlp.fc2.weight\": blk[\"mlp\"][\"fc2\"].weight,\n                    f\"{prefix}.mlp.fc2.bias\": blk[\"mlp\"][\"fc2\"].bias,\n                }\n            )\n\n    for key, tensor in weight_map.items():\n        tensor.data.copy_(get_tensor(key))\n\n    region.coord_features.data.copy_(\n        get_tensor(\"region_model.coordinate_features.weight\").T\n    )\n    
region.size_features.data.copy_(get_tensor(\"region_model.size_features.weight\").T)\n\n\ndef load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None:\n    \"\"\"Load weights from a safetensors file into a MoondreamModel instance.\"\"\"\n    with safetensors_open(weights_file) as get_tensor:\n        if (\n            \"vision.blocks.0.attn.proj.bias\" in get_tensor.keys()\n            or \"model.vision.blocks.0.attn.proj.bias\" in get_tensor.keys()\n        ):\n            with safetensors_open(weights_file) as get_tensor:\n                tensors = {\n                    k.replace(\"model.\", \"\"): get_tensor(k) for k in get_tensor.keys()\n                }\n                model.load_state_dict(tensors, strict=False)\n        else:\n            # Wrap the get_tensor function to handle key normalization\n            name_map = {k.replace(\"._orig_mod\", \"\"): k for k in get_tensor.keys()}\n            _load_weights(\n                lambda x: get_tensor(name_map[x]).to(dtype=torch.bfloat16), model\n            )\n\n\ndef load_weights_from_pt(weights_file: str, model: nn.Module) -> None:\n    \"\"\"Load weights from a PyTorch file into a MoondreamModel instance.\"\"\"\n    device = str(torch.empty(0).device)\n    tensors = torch.load(weights_file, map_location=device, weights_only=True)\n    if \"vision.blocks.0.attn.proj.bias\" in tensors.keys():\n        missing_keys, unexpected_keys = model.load_state_dict(tensors, strict=False)\n        print(\"Missing keys:\", missing_keys)\n        print(\"Unexpected keys:\", unexpected_keys)\n    else:\n        tensors = {\n            k.replace(\"._orig_mod\", \"\"): v.to(dtype=torch.bfloat16)\n            for k, v in tensors.items()\n        }\n        _load_weights(lambda x: tensors[x], model)\n\n\ndef load_weights_into_model(weights_file: str, model: nn.Module) -> None:\n    \"\"\"\n    Load weights from either a safetensors or PyTorch file directly into a MoondreamModel instance.\n\n    Args:\n     
   weights_file: Path to weights file (either .safetensors or .pt)\n        model: MoondreamModel instance to load weights into\n    \"\"\"\n    if weights_file.endswith(\".safetensors\"):\n        load_weights_from_safetensors(weights_file, model)\n    else:\n        load_weights_from_pt(weights_file, model)\n\n    # Make all parameters contiguous\n    for param in model.parameters():\n        param.data = param.data.contiguous()\n"
  },
  {
    "path": "notebooks/RepEng.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"This notebook shows how to compute control vectors to steer moondream's behavior\\n\",\n    \"in fun and interesting ways. To learn more about control vectors and representation\\n\",\n    \"engineering check out [Theia's blog post on the topic](https://vgel.me/posts/representation-engineering/).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 32,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"from transformers import AutoModelForCausalLM, AutoTokenizer\\n\",\n    \"from datasets import load_dataset\\n\",\n    \"from tqdm import tqdm\\n\",\n    \"from PIL import Image\\n\",\n    \"import numpy as np\\n\",\n    \"from sklearn.decomposition import PCA\\n\",\n    \"from IPython.display import display, HTML\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"tokenizer = AutoTokenizer.from_pretrained(\\\"vikhyatk/moondream2\\\")\\n\",\n    \"model = AutoModelForCausalLM.from_pretrained(\\n\",\n    \"    \\\"vikhyatk/moondream2\\\", trust_remote_code=True,\\n\",\n    \"    torch_dtype=torch.float16, device_map={\\\"\\\": \\\"cuda\\\"}\\n\",\n    \")\\n\",\n    \"\\n\",\n    \"# We will only be using the images, so it doesn't really matter what\\n\",\n    \"# dataset we use here.\\n\",\n    \"dataset = load_dataset(\\\"vikhyatk/lnqa\\\", streaming=True)[\\\"train\\\"]\\n\",\n    \"\\n\",\n    \"def hidden_states(enc_img, prompt):\\n\",\n    \"    with torch.no_grad():\\n\",\n    \"        inputs_embeds = model.input_embeds(prompt, enc_img, tokenizer)\\n\",\n    \"        hidden_states = model.text_model.generate(\\n\",\n    \"            inputs_embeds=inputs_embeds,\\n\",\n    \"            max_new_tokens=128,\\n\",\n    \"            pad_token_id=tokenizer.eos_token_id,\\n\",\n    \"         
   eos_token_id=tokenizer.eos_token_id,\\n\",\n    \"            return_dict_in_generate=True,\\n\",\n    \"            output_hidden_states=True,\\n\",\n    \"            do_sample=True,\\n\",\n    \"            temperature=0.5\\n\",\n    \"        ).hidden_states[1:]\\n\",\n    \"    return [torch.stack([hs.view(-1, 2048) for hs in h[1:]]).cpu() for h in hidden_states]\\n\",\n    \"\\n\",\n    \"class LayerWrapper(torch.nn.Module):\\n\",\n    \"    def __init__(self, og_layer, control_vectors, scale=4.2):\\n\",\n    \"        super().__init__()\\n\",\n    \"        self.og_layer = og_layer\\n\",\n    \"        self.control_vectors = control_vectors\\n\",\n    \"        self.scale = scale\\n\",\n    \"\\n\",\n    \"    def forward(self, *args, **kwargs):\\n\",\n    \"        layer_outputs = self.og_layer(*args, **kwargs)\\n\",\n    \"        layer_outputs = (layer_outputs[0] + self.scale * self.control_vectors, *layer_outputs[1:])\\n\",\n    \"        return layer_outputs\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 112,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"negative_prompt = \\\"<image>\\\\n\\\\nQuestion: Describe this image.\\\\n\\\\nAnswer:\\\"\\n\",\n    \"positive_prompt = \\\"<image>\\\\n\\\\nQuestion: What is the meaning of life?\\\\n\\\\nAnswer:\\\"\\n\",\n    \"\\n\",\n    \"# This can be lowered without noticeable loss in quality. 
Feel free to drop it to\\n\",\n    \"# IMAGES_PER_CONTROL=50 and SAMPLES_PER_IMAGE=2 if it's taking too long.\\n\",\n    \"IMAGES_PER_CONTROL = 200\\n\",\n    \"SAMPLES_PER_IMAGE = 5\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 113,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"100%|██████████| 200/200 [37:09<00:00, 11.15s/it]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# This is not very efficient, batching would speed things up a lot.\\n\",\n    \"# But eh, works for a quick demo.\\n\",\n    \"\\n\",\n    \"hs_dataset = [[] for _ in range(24)]\\n\",\n    \"\\n\",\n    \"for i, sample in tqdm(enumerate(dataset), total=IMAGES_PER_CONTROL):\\n\",\n    \"    if i >= IMAGES_PER_CONTROL:\\n\",\n    \"        break\\n\",\n    \"    image = sample[\\\"image\\\"]\\n\",\n    \"    enc_img = model.encode_image(image)\\n\",\n    \"    for _ in range(SAMPLES_PER_IMAGE):\\n\",\n    \"        phs = hidden_states(enc_img, positive_prompt)\\n\",\n    \"        nhs = hidden_states(enc_img, negative_prompt)\\n\",\n    \"        t_max = min(len(phs), len(nhs))\\n\",\n    \"        for t in range(t_max):\\n\",\n    \"            phs_t = phs[t]\\n\",\n    \"            nhs_t = nhs[t]\\n\",\n    \"            for j in range(24):\\n\",\n    \"                hs_dataset[j].append(phs_t[j])\\n\",\n    \"                hs_dataset[j].append(nhs_t[j])\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 114,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"100%|██████████| 24/24 [02:30<00:00,  6.26s/it]\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"control_vectors = []\\n\",\n    \"\\n\",\n    \"for i in tqdm(range(24)):\\n\",\n    \"    layer_hiddens = torch.stack(hs_dataset[i])\\n\",\n    \"\\n\",\n    \"    layer_centers = (layer_hiddens[::2] + 
layer_hiddens[1::2]) / 2\\n\",\n    \"    relative_layer_hiddens = layer_hiddens\\n\",\n    \"    relative_layer_hiddens[::2] -= layer_centers\\n\",\n    \"    relative_layer_hiddens[1::2] -= layer_centers\\n\",\n    \"\\n\",\n    \"    train = relative_layer_hiddens - relative_layer_hiddens.mean(axis=0, keepdims=True)\\n\",\n    \"    train = train.view(-1, 2048).cpu().numpy()\\n\",\n    \"    pca_model = PCA(n_components=1, whiten=False).fit(train)\\n\",\n    \"    directions = pca_model.components_.astype(np.float32).squeeze(axis=0)\\n\",\n    \"\\n\",\n    \"    projected_hiddens = (layer_hiddens.cpu().numpy() @ directions) / np.linalg.norm(directions)\\n\",\n    \"\\n\",\n    \"    positive_smaller_mean = np.mean(\\n\",\n    \"        [\\n\",\n    \"            projected_hiddens[i] < projected_hiddens[i + 1]\\n\",\n    \"            for i in range(0, len(hs_dataset[i]), 2)\\n\",\n    \"        ]\\n\",\n    \"    )\\n\",\n    \"    positive_larger_mean = np.mean(\\n\",\n    \"        [\\n\",\n    \"            projected_hiddens[i] > projected_hiddens[i + 1]\\n\",\n    \"            for i in range(0, len(hs_dataset[i]), 2)\\n\",\n    \"        ]\\n\",\n    \"    )\\n\",\n    \"    if positive_smaller_mean > positive_larger_mean:  # type: ignore\\n\",\n    \"        directions *= -1\\n\",\n    \"\\n\",\n    \"    control_vectors.append(torch.tensor(directions, device=\\\"cuda\\\", dtype=torch.float16))\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 116,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/html\": [\n       \"\\n\",\n       \"<table style=\\\"border: 1px solid black; border-collapse: collapse;\\\">\\n\",\n       \"    <tr>\\n\",\n       \"        <th style=\\\"border: 1px solid black; padding: 8px;\\\">Image</th>\\n\",\n       \"        <th style=\\\"border: 1px solid black; padding: 8px;\\\">Text</th>\\n\",\n       \"    </tr>\\n\",\n       \"\\n\",\n       \"    <tr>\\n\",\n       \"        
<td style=\\\"border: 1px solid black; padding: 8px;\\\">\\n\",\n       \"            <img src=\\\"../assets/demo-1.jpg\\\" style=\\\"width:250px;\\\">\\n\",\n       \"        </td>\\n\",\n       \"        <td style=\\\"border: 1px solid black; padding: 8px; text-align: left\\\">\\n\",\n       \"            <strong>Describe this image.</strong><br><br>\\n\",\n       \"            In this image, we will explore the concept of happiness through an anthropomorphic interpretation. We'll look at it from a spiritual or religious perspective by examining what brings joy and fulfillment to human beings. This can be approached in various ways depending on individual beliefs and experiences.<br><br>To understand happiness, one might consider factors such as personal goals, relationships, health, wealth, spirituality, ethics, and more. It is subjective and varies greatly among individuals based on their values and life experiences. However, many people often seek happiness in different aspects like achieving success, finding inner peace, or living a meaningful life.<br><br>In the context of the provided statements about happiness, some may find happiness in material possessions (e.of), while others may seek happiness through relationships, creativity, or even simple pleasures that bring satisfaction and contentment. 
Ultimately, happiness is a deeply personal journey for each person to define and pursue according to their own beliefs and values.\\n\",\n       \"        </td>\\n\",\n       \"    </tr>\\n\",\n       \"    \\n\",\n       \"    <tr>\\n\",\n       \"        <td style=\\\"border: 1px solid black; padding: 8px;\\\">\\n\",\n       \"            <img src=\\\"../assets/demo-2.jpg\\\" style=\\\"width:250px;\\\">\\n\",\n       \"        </td>\\n\",\n       \"        <td style=\\\"border: 1px solid black; padding: 8px; text-align: left\\\">\\n\",\n       \"            <strong>What is this?</strong><br><br>\\n\",\n       \"            In the context of this image, a server or processor is an essential component for computing and running various applications on computers. A Processor (or CPU) can refer to any device that executes tasks according to specific programming requirements.<br><br>In this particular scenario, we are referring to advanced technologies like virtual machines, artificial intelligence, machine learning, etc., which require powerful computing systems to function effectively.<br><br>For example, in the case of AI research, researchers develop and test theories using sophisticated computer models and simulations. These concepts may involve analyzing vast amounts of data, exploring ethical questions, understanding existence, or even developing new knowledge about life itself.<br><br>In summary, when people talk about \\\"the meaning\\\" or \\\"purpose,\\\" they often refer to these advanced concepts as well. 
It's subjective and varies from person to person based on their beliefs, values, and experiences.\\n\",\n       \"        </td>\\n\",\n       \"    </tr>\\n\",\n       \"    \\n\",\n       \"    <tr>\\n\",\n       \"        <td style=\\\"border: 1px solid black; padding: 8px;\\\">\\n\",\n       \"            <img src=\\\"../assets/demo-2.jpg\\\" style=\\\"width:250px;\\\">\\n\",\n       \"        </td>\\n\",\n       \"        <td style=\\\"border: 1px solid black; padding: 8px; text-align: left\\\">\\n\",\n       \"            <strong>What color is the couch?</strong><br><br>\\n\",\n       \"            The couch in the image is described as \\\"black.\\\" However, without more information or context from different sources, it's difficult to determine its actual color. It could be any of those things like comfort, aesthetics, personal preferences, etc., which can vary among individuals.\\n\",\n       \"        </td>\\n\",\n       \"    </tr>\\n\",\n       \"    </table>\"\n      ],\n      \"text/plain\": [\n       \"<IPython.core.display.HTML object>\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"prompts = [\\n\",\n    \"    (\\\"../assets/demo-1.jpg\\\", \\\"Describe this image.\\\"),\\n\",\n    \"    (\\\"../assets/demo-2.jpg\\\", \\\"What is this?\\\"),\\n\",\n    \"    (\\\"../assets/demo-2.jpg\\\", \\\"What color is the couch?\\\"),\\n\",\n    \"]\\n\",\n    \"data = []\\n\",\n    \"\\n\",\n    \"def run_model(img_path, prompt, scale=4.2):\\n\",\n    \"    og_h = model.text_model.transformer.h\\n\",\n    \"    model.text_model.transformer.h = torch.nn.ModuleList([\\n\",\n    \"        LayerWrapper(layer, vector, scale) for layer, vector in zip(og_h, control_vectors)\\n\",\n    \"    ])\\n\",\n    \"    answer = model.answer_question(\\n\",\n    \"        model.encode_image(Image.open(img_path)), prompt, tokenizer,\\n\",\n    \"        repetition_penalty=1.2, temperature=0.1, 
do_sample=True,\\n\",\n    \"        length_penalty=1.2\\n\",\n    \"    )\\n\",\n    \"    model.text_model.transformer.h = og_h\\n\",\n    \"    return answer\\n\",\n    \"\\n\",\n    \"for img_path, prompt in prompts:\\n\",\n    \"    answer = run_model(img_path, prompt)\\n\",\n    \"    data.append({\\\"prompt\\\": prompt, \\\"answer\\\": answer.replace(\\\"\\\\n\\\", \\\"<br>\\\"), \\\"image\\\": img_path})\\n\",\n    \"\\n\",\n    \"html_table = \\\"\\\"\\\"\\n\",\n    \"<table style=\\\"border: 1px solid black; border-collapse: collapse;\\\">\\n\",\n    \"    <tr>\\n\",\n    \"        <th style=\\\"border: 1px solid black; padding: 8px;\\\">Image</th>\\n\",\n    \"        <th style=\\\"border: 1px solid black; padding: 8px;\\\">Text</th>\\n\",\n    \"    </tr>\\n\",\n    \"\\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"for item in data:\\n\",\n    \"    html_table += f\\\"\\\"\\\"\\n\",\n    \"    <tr>\\n\",\n    \"        <td style=\\\"border: 1px solid black; padding: 8px;\\\">\\n\",\n    \"            <img src=\\\"{item['image']}\\\" style=\\\"width:250px;\\\">\\n\",\n    \"        </td>\\n\",\n    \"        <td style=\\\"border: 1px solid black; padding: 8px; text-align: left\\\">\\n\",\n    \"            <strong>{item['prompt']}</strong><br><br>\\n\",\n    \"            {item['answer']}\\n\",\n    \"        </td>\\n\",\n    \"    </tr>\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"html_table += \\\"</table>\\\"\\n\",\n    \"\\n\",\n    \"# Display the HTML table\\n\",\n    \"display(HTML(html_table))\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \".venv\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": 
\"3.10.12\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "recipes/gaze-detection-video/.gitignore",
    "content": "# Python\n__pycache__/\n*.py[cod]\n*$py.class\n*.so\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# Virtual Environment\nvenv/\nENV/\n\n# IDE\n.idea/\n.vscode/\n*.swp\n*.swo\n\n# Project specific\n# input/*\n# !input/.gitkeep\n# output/*\n# !output/.gitkeep\n# temp/*\n# !temp/.gitkeep\n\n# Model files\n*.pt\n*.pth\n*.ckpt\n\n# Logs\n*.log\n\n# OS specific\n.DS_Store\nThumbs.db\n"
  },
  {
    "path": "recipes/gaze-detection-video/README.md",
    "content": "# Gaze Detection Video Processor\n\n> **⚠️ IMPORTANT:** This project currently uses Moondream 2B (2025-01-09 release) via the Hugging Face Transformers library. We will migrate to the official Moondream client\n> libraries once they become available for this version.\n\n## Table of Contents\n\n- [Overview](#overview)\n- [Sample Output](#sample-output)\n- [Features](#features)\n- [Prerequisites](#prerequisites)\n- [Installation](#installation)\n  - [Linux/macOS Installation](#linuxmacos-installation)\n  - [Windows Installation](#windows-installation)\n- [Usage](#usage)\n- [Output](#output)\n- [Troubleshooting](#troubleshooting)\n- [Performance Notes](#performance-notes)\n- [Dependencies](#dependencies)\n- [Model Details](#model-details)\n- [License](#license)\n\n## Overview\n\nThis project uses the Moondream 2B model to detect faces and their gaze directions in videos. It processes videos frame by frame, visualizing face detections and gaze directions.\n\n## Sample Output\n\n|              Input Video              |              Processed Output               |\n| :-----------------------------------: | :-----------------------------------------: |\n| ![Input Video](https://github.com/parsakhaz/gaze-detection-video/blob/master/gif-input-sample.gif?raw=true) | ![Processed Output](https://github.com/parsakhaz/gaze-detection-video/blob/master/gif-output-sample.gif?raw=true) |\n\n## Features\n\n- Face detection in video frames\n- Gaze direction tracking\n- Real-time visualization with:\n  - Colored bounding boxes for faces\n  - Gradient lines showing gaze direction\n  - Gaze target points\n- Supports multiple faces per frame\n- Processes all common video formats (.mp4, .avi, .mov, .mkv)\n- Uses Moondream 2 (2025-01-09 release) via Hugging Face Transformers\n  - Note: Will be migrated to official client libraries in future updates\n  - No authentication required\n\n## Prerequisites\n\n1. Python 3.8 or later\n2. 
CUDA-capable GPU recommended (but CPU mode works too)\n3. FFmpeg installed on your system\n\n## Installation\n\n### Linux/macOS Installation\n\n1. Install system dependencies:\n\n   ```bash\n   # Ubuntu/Debian\n   sudo apt-get update && sudo apt-get install -y libvips42 libvips-dev ffmpeg\n\n   # CentOS/RHEL\n   sudo yum install vips vips-devel ffmpeg\n\n   # macOS\n   brew install vips ffmpeg\n   ```\n\n2. Clone and set up the project:\n   ```bash\n   git clone https://github.com/vikhyat/moondream.git\n   cd moondream/recipes/gaze-detection-video\n   python3 -m venv venv\n   source venv/bin/activate\n   pip install -r requirements.txt\n   ```\n\n### Windows Installation\n\nWindows setup requires a few additional steps for proper GPU support and libvips installation.\n\n1. Clone the repository:\n\n   ```bash\n   git clone https://github.com/vikhyat/moondream.git\n   cd moondream/recipes/gaze-detection-video\n   ```\n\n2. Create and activate virtual environment:\n\n   ```bash\n   python -m venv venv\n   .\\venv\\Scripts\\activate\n   ```\n\n3. Install PyTorch with CUDA support:\n\n   ```bash\n   # For NVIDIA GPUs\n   pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n   ```\n\n4. Install libvips: Download the appropriate version based on your system architecture:\n\n   | Architecture | VIPS Version to Download |\n   | ------------ | ------------------------ |\n   | 32-bit x86   | vips-dev-w32-all-8.16.0.zip |\n   | 64-bit x64   | vips-dev-w64-all-8.16.0.zip |\n\n   - Extract the ZIP file\n   - Copy all DLL files from `vips-dev-8.16\\bin` to either:\n     - Your project's root directory (easier) OR\n     - `C:\\Windows\\System32` (requires admin privileges)\n   - Add to PATH:\n     1. Open System Properties → Advanced → Environment Variables\n     2. Under System Variables, find PATH\n     3. Add the full path to the `vips-dev-8.16\\bin` directory\n\n5. 
Install FFmpeg:\n\n   - Download from https://ffmpeg.org/download.html#build-windows\n   - Extract and add the `bin` folder to your system PATH (similar to step 4) or to the project root directory\n\n6. Install other dependencies:\n   ```bash\n   pip install -r requirements.txt\n   ```\n\n## Usage\n\n1. Place your input videos in the `input` directory\n\n   - Supported formats: .mp4, .avi, .mov, .mkv\n   - The directory will be created automatically if it doesn't exist\n\n2. Run the script:\n\n   ```bash\n   python gaze-detection-video.py\n   ```\n\n3. The script will:\n   - Process all videos in the input directory\n   - Show progress bars for each video\n   - Save processed videos to the `output` directory with prefix 'processed\\_'\n\n## Output\n\n- Processed videos are saved as `output/processed_[original_name].[ext]`\n- Each frame in the output video shows:\n  - Colored boxes around detected faces\n  - Lines indicating gaze direction\n  - Points showing where each person is looking\n\n## Troubleshooting\n\n1. CUDA/GPU Issues:\n\n   - Ensure you have CUDA installed for GPU support\n   - The script will automatically fall back to CPU if no GPU is available\n\n2. Memory Issues:\n\n   - If processing large videos, ensure you have enough RAM\n   - Consider reducing video resolution if needed\n\n3. libvips Errors:\n\n   - Make sure libvips is properly installed for your OS\n   - Check system PATH includes libvips\n\n4. 
Video Format Issues:\n   - Ensure FFmpeg is installed and in your system PATH\n   - Try converting problematic videos to MP4 format\n\n## Performance Notes\n\n- GPU processing is significantly faster than CPU\n- Processing time depends on:\n  - Video resolution\n  - Number of faces per frame\n  - Frame rate\n  - Available computing power\n\n## Dependencies\n\n- transformers (for Moondream 2 model access)\n- torch\n- opencv-python\n- pillow\n- matplotlib\n- numpy\n- tqdm\n- pyvips\n- accelerate\n- einops\n\n## Model Details\n\n> **⚠️ IMPORTANT:** This project currently uses Moondream 2 (2025-01-09 release) via the Hugging Face Transformers library. We will migrate to the official Moondream client\n> libraries once they become available for this version.\n\nThe model is loaded using:\n\n```python\nfrom transformers import AutoModelForCausalLM\n\nmodel = AutoModelForCausalLM.from_pretrained(\n    \"vikhyatk/moondream2\",\n    revision=\"2025-01-09\",\n    trust_remote_code=True,\n)\n```\n"
  },
  {
    "path": "recipes/gaze-detection-video/gaze-detection-video.py",
    "content": "\"\"\"\nGaze Detection Video Processor using Moondream 2\n------------------------------------------------\nRead the README.md file for more information on how to use this script. Contact us in our discord for any questions if you get stuck.\n\"\"\"\n\nimport torch\nimport numpy as np\nimport cv2\nimport matplotlib.pyplot as plt\nfrom PIL import Image\nfrom transformers import AutoModelForCausalLM\nfrom tqdm import tqdm\nimport os\nimport glob\nfrom typing import List, Dict, Tuple, Optional\nfrom contextlib import contextmanager\n\n\ndef initialize_model() -> Optional[AutoModelForCausalLM]:\n    \"\"\"Initialize the Moondream 2 model with error handling.\"\"\"\n    try:\n        print(\"\\nInitializing Moondream 2 model...\")\n        model_id = \"vikhyatk/moondream2\"\n        revision = \"2025-01-09\"  # Specify revision for stability\n\n        if torch.cuda.is_available():\n            print(f\"GPU detected: {torch.cuda.get_device_name(0)}\")\n            device = \"cuda\"\n        else:\n            print(\"No GPU detected, using CPU\")\n            device = \"cpu\"\n\n        print(\"Loading model from HuggingFace...\")\n        model = AutoModelForCausalLM.from_pretrained(\n            model_id,\n            revision=revision,\n            trust_remote_code=True,\n            torch_dtype=torch.float16 if device == \"cuda\" else torch.float32,\n            low_cpu_mem_usage=True,\n            device_map={\"\": device} if device == \"cuda\" else None,\n        )\n\n        if device == \"cpu\":\n            model = model.to(device)\n        model.eval()\n\n        print(\"✓ Model initialized successfully\")\n        return model\n    except Exception as e:\n        print(f\"\\nError initializing model: {e}\")\n        return None\n\n\n@contextmanager\ndef video_handler(\n    input_path: str, output_path: str\n) -> Tuple[cv2.VideoCapture, cv2.VideoWriter]:\n    \"\"\"Context manager for handling video capture and writer.\"\"\"\n    cap = 
cv2.VideoCapture(input_path)\n    if not cap.isOpened():\n        raise ValueError(f\"Could not open video file: {input_path}\")\n\n    # Get video properties\n    fps = int(cap.get(cv2.CAP_PROP_FPS))\n    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n\n    # Create video writer\n    fourcc = cv2.VideoWriter_fourcc(*\"mp4v\")\n    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))\n\n    try:\n        yield cap, out\n    finally:\n        cap.release()\n        out.release()\n        cv2.destroyAllWindows()\n\n\ndef fig2rgb_array(fig: plt.Figure) -> np.ndarray:\n    \"\"\"Convert matplotlib figure to RGB array\"\"\"\n    fig.canvas.draw()\n    buf = fig.canvas.buffer_rgba()\n    w, h = fig.canvas.get_width_height()\n    img_array = np.asarray(buf).reshape((h, w, 4))\n    rgb_array = img_array[:, :, :3]  # Drop alpha channel\n    return rgb_array\n\n\ndef visualize_frame(\n    frame: np.ndarray, faces: List[Dict], model: AutoModelForCausalLM, pil_image: Image\n) -> np.ndarray:\n    \"\"\"Visualize a single frame using matplotlib\"\"\"\n    try:\n        # Create figure without margins\n        fig = plt.figure(figsize=(frame.shape[1] / 100, frame.shape[0] / 100), dpi=100)\n        ax = fig.add_axes([0, 0, 1, 1])\n\n        # Display frame\n        ax.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n\n        # Sort faces by x_min coordinate for stable colors\n        faces = sorted(faces, key=lambda f: (f[\"y_min\"], f[\"x_min\"]))\n\n        # Generate colors\n        colors = plt.cm.rainbow(np.linspace(0, 1, max(1, len(faces))))\n\n        # Process each face\n        for face, color in zip(faces, colors):\n            try:\n                # Calculate face box coordinates\n                x_min = int(float(face[\"x_min\"]) * frame.shape[1])\n                y_min = int(float(face[\"y_min\"]) * frame.shape[0])\n                width = int(float(face[\"x_max\"] - face[\"x_min\"]) * 
frame.shape[1])\n                height = int(float(face[\"y_max\"] - face[\"y_min\"]) * frame.shape[0])\n\n                # Draw face rectangle\n                rect = plt.Rectangle(\n                    (x_min, y_min), width, height, fill=False, color=color, linewidth=2\n                )\n                ax.add_patch(rect)\n\n                # Calculate face center\n                face_center = (\n                    float(face[\"x_min\"] + face[\"x_max\"]) / 2,\n                    float(face[\"y_min\"] + face[\"y_max\"]) / 2,\n                )\n\n                # Try to detect gaze\n                try:\n                    gaze_result = model.detect_gaze(pil_image, face_center)\n                    if isinstance(gaze_result, dict) and \"gaze\" in gaze_result:\n                        gaze = gaze_result[\"gaze\"]\n                    else:\n                        gaze = gaze_result\n                except Exception as e:\n                    print(f\"Error detecting gaze: {e}\")\n                    continue\n\n                if (\n                    gaze is not None\n                    and isinstance(gaze, dict)\n                    and \"x\" in gaze\n                    and \"y\" in gaze\n                ):\n                    gaze_x = int(float(gaze[\"x\"]) * frame.shape[1])\n                    gaze_y = int(float(gaze[\"y\"]) * frame.shape[0])\n                    face_center_x = x_min + width // 2\n                    face_center_y = y_min + height // 2\n\n                    # Draw gaze line with gradient effect\n                    points = 50\n                    alphas = np.linspace(0.8, 0, points)\n\n                    # Calculate points along the line\n                    x_points = np.linspace(face_center_x, gaze_x, points)\n                    y_points = np.linspace(face_center_y, gaze_y, points)\n\n                    # Draw gradient line segments\n                    for i in range(points - 1):\n                        ax.plot(\n       
                     [x_points[i], x_points[i + 1]],\n                            [y_points[i], y_points[i + 1]],\n                            color=color,\n                            alpha=alphas[i],\n                            linewidth=4,\n                        )\n\n                    # Draw gaze point\n                    ax.scatter(gaze_x, gaze_y, color=color, s=100, zorder=5)\n                    ax.scatter(gaze_x, gaze_y, color=\"white\", s=50, zorder=6)\n\n            except Exception as e:\n                print(f\"Error processing face: {e}\")\n                continue\n\n        # Configure axes\n        ax.set_xlim(0, frame.shape[1])\n        ax.set_ylim(frame.shape[0], 0)\n        ax.axis(\"off\")\n\n        # Convert matplotlib figure to image\n        frame_rgb = fig2rgb_array(fig)\n\n        # Convert RGB to BGR for OpenCV\n        frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)\n\n        # Clean up\n        plt.close(fig)\n\n        return frame_bgr\n\n    except Exception as e:\n        print(f\"Error in visualize_frame: {e}\")\n        plt.close(\"all\")\n        return frame\n\n\ndef process_video(\n    input_path: str, output_path: str, model: AutoModelForCausalLM\n) -> None:\n    \"\"\"Process video file and create new video with gaze visualization\"\"\"\n    with video_handler(input_path, output_path) as (cap, out):\n        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n        fps = int(cap.get(cv2.CAP_PROP_FPS))\n        print(f\"Processing video: {total_frames} frames at {fps} FPS\")\n\n        # Process frames\n        with tqdm(\n            total=total_frames, desc=f\"Processing {os.path.basename(input_path)}\"\n        ) as pbar:\n            while True:\n                ret, frame = cap.read()\n                if not ret:\n                    break\n\n                try:\n                    # Convert frame for model\n                    pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n\n     
               # Detect faces\n                    detection_result = model.detect(pil_image, \"face\")\n\n                    # Handle different possible return formats\n                    if (\n                        isinstance(detection_result, dict)\n                        and \"objects\" in detection_result\n                    ):\n                        faces = detection_result[\"objects\"]\n                    elif isinstance(detection_result, list):\n                        faces = detection_result\n                    else:\n                        print(\n                            f\"Unexpected detection result format: {type(detection_result)}\"\n                        )\n                        faces = []\n\n                    # Ensure each face has the required coordinates\n                    faces = [\n                        face\n                        for face in faces\n                        if all(k in face for k in [\"x_min\", \"y_min\", \"x_max\", \"y_max\"])\n                    ]\n\n                    if not faces:\n                        processed_frame = frame\n                    else:\n                        # Visualize frame with matplotlib\n                        processed_frame = visualize_frame(\n                            frame, faces, model, pil_image\n                        )\n\n                    # Write frame\n                    out.write(processed_frame)\n                    pbar.update(1)\n\n                    # Force matplotlib to clean up\n                    plt.close(\"all\")\n\n                except Exception as e:\n                    print(f\"Error processing frame: {e}\")\n                    out.write(frame)  # Write original frame on error\n                    pbar.update(1)\n                    plt.close(\"all\")  # Clean up even on error\n\n\nif __name__ == \"__main__\":\n    # Ensure input and output directories exist\n    input_dir = os.path.join(os.path.dirname(__file__), \"input\")\n    
output_dir = os.path.join(os.path.dirname(__file__), \"output\")\n    os.makedirs(input_dir, exist_ok=True)\n    os.makedirs(output_dir, exist_ok=True)\n\n    # Find all video files in input directory\n    video_extensions = [\".mp4\", \".avi\", \".mov\", \".mkv\"]\n    input_videos = []\n    for ext in video_extensions:\n        input_videos.extend(glob.glob(os.path.join(input_dir, f\"*{ext}\")))\n\n    if not input_videos:\n        print(\"No video files found in input directory\")\n        exit(1)\n\n    # Initialize model once for all videos\n    model = initialize_model()\n    if model is None:\n        print(\"Failed to initialize model\")\n        exit(1)\n\n    # Process each video file\n    for input_video in input_videos:\n        base_name = os.path.basename(input_video)\n        output_video = os.path.join(output_dir, f\"processed_{base_name}\")\n        try:\n            process_video(input_video, output_video, model)\n        except Exception as e:\n            print(f\"Error processing {base_name}: {e}\")\n            continue\n"
  },
  {
    "path": "recipes/gaze-detection-video/input/.gitkeep",
    "content": ""
  },
  {
    "path": "recipes/gaze-detection-video/output/.gitkeep",
    "content": ""
  },
  {
    "path": "recipes/gaze-detection-video/requirements.txt",
    "content": "torch>=2.0.0\ntransformers>=4.36.0\nopencv-python>=4.8.0\npillow>=10.0.0\nmatplotlib>=3.7.0\nnumpy>=1.24.0\ntqdm>=4.65.0\npyvips\naccelerate>=0.26.0\neinops"
  },
  {
    "path": "recipes/gaze-detection-video/temp/.gitkeep",
    "content": ""
  },
  {
    "path": "recipes/promptable-content-moderation/.gitignore",
    "content": "# Python\n__pycache__/\n*.py[cod]\n*$py.class\n*.so\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n*.dll\n\n# Virtual Environment\nvenv/\nenv/\nENV/\n.venv/\n\n# IDE\n.idea/\n.vscode/\n*.swp\n*.swo\n\n# Project specific\ninputs/*\noutputs/*\n!inputs/.gitkeep\n!outputs/.gitkeep\ninputs/\noutputs/\n\n# Model files\n*.pth\n*.onnx\n*.pt\n\n# Logs\n*.log\n\ncertificate.pem"
  },
  {
    "path": "recipes/promptable-content-moderation/README.md",
    "content": "# Promptable Content Moderation with Moondream\n\nWelcome to the future of content moderation with Moondream 2B, a powerful and lightweight vision-language model that enables detection and moderation of video content using natural language prompts.\n\n[Try it now.](https://huggingface.co/spaces/moondream/content-moderation)\n\n## Features\n\n- Content moderation through natural language prompts\n- Multiple visualization styles\n- Intelligent scene detection and tracking:\n  - DeepSORT tracking with scene-aware reset\n  - Persistent moderation across frames\n  - Smart tracker reset at scene boundaries\n- Optional grid-based detection for improved accuracy on complex scenes\n- Frame-by-frame processing with IoU-based merging\n- Web-compatible output format\n- Test mode (process only first X seconds)\n- Advanced moderation analysis with multiple visualization plots\n\n## Examples\n\n| Prompt | Output |\n|--------|-----------------|\n| \"white cigarette\" | ![Demo](https://github.com/parsakhaz/promptable-content-moderation/raw/main/examples/clip-cig.gif)     |\n| \"gun\" | ![Demo](https://github.com/parsakhaz/promptable-content-moderation/raw/main/examples/clip-gu.gif)      |\n| \"confederate flag\" | ![Demo](https://github.com/parsakhaz/promptable-content-moderation/raw/main/examples/clip-conflag.gif) |\n\n## Requirements\n\n### Python Dependencies\n\nFor Windows users, before installing other requirements, first install PyTorch with CUDA support:\n\n```bash\npip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https://download.pytorch.org/whl/cu121\n```\n\nThen install the remaining dependencies:\n\n```bash\npip install -r requirements.txt\n```\n\n### System Requirements\n\n- FFmpeg (required for video processing)\n- libvips (required for image processing)\n\nInstallation by platform:\n\n- Ubuntu/Debian: `sudo apt-get install ffmpeg libvips`\n- macOS: `brew install ffmpeg libvips`\n- Windows:\n  - Download FFmpeg from 
[ffmpeg.org](https://ffmpeg.org/download.html)\n  - Follow [libvips Windows installation guide](https://docs.moondream.ai/quick-start)\n\n## Installation\n\n1. Clone this repository and create a new virtual environment:\n\n```bash\ngit clone https://github.com/vikhyat/moondream.git\ncd moondream/recipes/promptable-content-moderation\npython -m venv .venv\nsource .venv/bin/activate  # On Windows: .venv\\Scripts\\activate\n```\n\n2. Install Python dependencies:\n\n```bash\npip install -r requirements.txt\n```\n\n3. Install ffmpeg and libvips:\n   - On Ubuntu/Debian: `sudo apt-get install ffmpeg libvips`\n   - On macOS: `brew install ffmpeg libvips`\n   - On Windows: Download from [ffmpeg.org](https://ffmpeg.org/download.html)\n\n> Downloading libvips for Windows requires some additional steps, see [here](https://docs.moondream.ai/quick-start)\n\n## Usage\n\nThe easiest way to use this tool is through its web interface, which provides a user-friendly experience for video content moderation.\n\n### Web Interface\n\n1. Start the web interface:\n\n```bash\npython app.py\n```\n\n2. Open the provided URL in your browser (typically <http://localhost:7860>)\n\n3. Use the interface to:\n   - Upload your video file\n   - Specify content to moderate (e.g., \"face\", \"cigarette\", \"gun\")\n   - Choose redaction style (default: obfuscated-pixel)\n   - OPTIONAL: Configure advanced settings\n     - Processing speed/quality\n     - Grid size for detection\n     - Test mode for quick validation (default: on, 3 seconds)\n   - Process the video and download results\n   - Analyze detection patterns with visualization tools\n\n## Output Files\n\nThe tool generates two types of output files in the `outputs` directory:\n\n1. Processed Videos:\n   - Format: `[style]_[content_type]_[original_filename].mp4`\n   - Example: `censor_inappropriate_video.mp4`\n\n2. 
Detection Data:\n   - Format: `[style]_[content_type]_[original_filename]_detections.json`\n   - Contains frame-by-frame detection information\n   - Used for visualization and analysis\n\n## Technical Details\n\n### Scene Detection and Tracking\n\nThe tool uses advanced scene detection and object tracking:\n\n1. Scene Detection:\n   - Powered by PySceneDetect's ContentDetector\n   - Automatically identifies scene changes in videos\n   - Configurable detection threshold (default: 30.0)\n   - Helps maintain tracking accuracy across scene boundaries\n\n2. Object Tracking:\n   - DeepSORT tracking for consistent object identification\n   - Automatic tracker reset at scene changes\n   - Maintains object identity within scenes\n   - Prevents tracking errors across scene boundaries\n\n3. Integration Benefits:\n   - More accurate object tracking\n   - Better handling of scene transitions\n   - Reduced false positives in tracking\n   - Improved tracking consistency\n\n## Best Practices\n\n- Use test mode for initial configuration\n- Enable grid-based detection for complex scenes\n- Choose appropriate redaction style based on content type:\n  - Censor: Complete content blocking\n  - Blur styles: Less intrusive moderation\n  - Bounding Box: Content review and analysis\n- Monitor system resources during processing\n- Use appropriate processing quality settings based on your needs\n\n## Notes\n\n- Processing time depends on video length, resolution, GPU availability, and chosen settings\n- GPU is strongly recommended for faster processing\n- Grid-based detection increases accuracy but requires more processing time (each grid cell is processed independently)\n- Test mode processes only first X seconds (default: 3 seconds) for quick validation\n"
  },
  {
    "path": "recipes/promptable-content-moderation/app.py",
    "content": "#!/usr/bin/env python3\nimport gradio as gr\nimport os\nfrom main import load_moondream, process_video, load_sam_model\nimport shutil\nimport torch\nfrom visualization import visualize_detections\nfrom persistence import load_detection_data\nimport matplotlib.pyplot as plt\nimport io\nfrom PIL import Image\nimport pandas as pd\nfrom video_visualization import create_video_visualization\n\n# Get absolute path to workspace root\nWORKSPACE_ROOT = os.path.dirname(os.path.abspath(__file__))\n\n# Check CUDA availability\nprint(f\"Is CUDA available: {torch.cuda.is_available()}\")\n# We want to get True\nprint(f\"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}\")\n# GPU Name\n\n# Initialize Moondream model globally for reuse (will be loaded on first use)\nmodel, tokenizer = None, None\n\n\ndef process_video_file(\n    video_file,\n    target_object,\n    box_style,\n    ffmpeg_preset,\n    grid_rows,\n    grid_cols,\n    test_mode,\n    test_duration,\n):\n    \"\"\"Process a video file through the Gradio interface.\"\"\"\n    try:\n        if not video_file:\n            raise gr.Error(\"Please upload a video file\")\n\n        # Load models if not already loaded\n        global model, tokenizer\n        if model is None or tokenizer is None:\n            model, tokenizer = load_moondream()\n\n        # Ensure input/output directories exist using absolute paths\n        inputs_dir = os.path.join(WORKSPACE_ROOT, \"inputs\")\n        outputs_dir = os.path.join(WORKSPACE_ROOT, \"outputs\")\n        os.makedirs(inputs_dir, exist_ok=True)\n        os.makedirs(outputs_dir, exist_ok=True)\n\n        # Copy uploaded video to inputs directory\n        video_filename = f\"input_{os.path.basename(video_file)}\"\n        input_video_path = os.path.join(inputs_dir, video_filename)\n        shutil.copy2(video_file, input_video_path)\n\n        try:\n            # Process the video\n            output_path = process_video(\n                
input_video_path,\n                target_object,\n                test_mode=test_mode,\n                test_duration=test_duration,\n                ffmpeg_preset=ffmpeg_preset,\n                grid_rows=grid_rows,\n                grid_cols=grid_cols,\n                box_style=box_style,\n            )\n\n            # Get the corresponding JSON path\n            base_name = os.path.splitext(os.path.basename(video_filename))[0]\n            json_path = os.path.join(\n                outputs_dir, f\"{box_style}_{target_object}_{base_name}_detections.json\"\n            )\n\n            # Verify output exists and is readable\n            if not output_path or not os.path.exists(output_path):\n                print(f\"Warning: Output path {output_path} does not exist\")\n                # Try to find the output based on expected naming convention\n                expected_output = os.path.join(\n                    outputs_dir, f\"{box_style}_{target_object}_{video_filename}\"\n                )\n                if os.path.exists(expected_output):\n                    output_path = expected_output\n                else:\n                    # Try searching in outputs directory for any matching file\n                    matching_files = [\n                        f\n                        for f in os.listdir(outputs_dir)\n                        if f.startswith(f\"{box_style}_{target_object}_\")\n                    ]\n                    if matching_files:\n                        output_path = os.path.join(outputs_dir, matching_files[0])\n                    else:\n                        raise gr.Error(\"Failed to locate output video\")\n\n            # Convert output path to absolute path if it isn't already\n            if not os.path.isabs(output_path):\n                output_path = os.path.join(WORKSPACE_ROOT, output_path)\n\n            print(f\"Returning output path: {output_path}\")\n            return output_path, json_path\n\n        finally:\n       
     # Clean up input file\n            try:\n                if os.path.exists(input_video_path):\n                    os.remove(input_video_path)\n            except:\n                pass\n\n    except Exception as e:\n        print(f\"Error in process_video_file: {str(e)}\")\n        raise gr.Error(f\"Error processing video: {str(e)}\")\n\n\ndef create_visualization_plots(json_path):\n    \"\"\"Create visualization plots and return them as images.\"\"\"\n    try:\n        # Load the data\n        data = load_detection_data(json_path)\n        if not data:\n            return None, None, None, None, None, None, None, None, \"No data found\"\n\n        # Convert to DataFrame\n        rows = []\n        for frame_data in data[\"frame_detections\"]:\n            frame = frame_data[\"frame\"]\n            timestamp = frame_data[\"timestamp\"]\n            for obj in frame_data[\"objects\"]:\n                rows.append(\n                    {\n                        \"frame\": frame,\n                        \"timestamp\": timestamp,\n                        \"keyword\": obj[\"keyword\"],\n                        \"x1\": obj[\"bbox\"][0],\n                        \"y1\": obj[\"bbox\"][1],\n                        \"x2\": obj[\"bbox\"][2],\n                        \"y2\": obj[\"bbox\"][3],\n                        \"area\": (obj[\"bbox\"][2] - obj[\"bbox\"][0])\n                        * (obj[\"bbox\"][3] - obj[\"bbox\"][1]),\n                        \"center_x\": (obj[\"bbox\"][0] + obj[\"bbox\"][2]) / 2,\n                        \"center_y\": (obj[\"bbox\"][1] + obj[\"bbox\"][3]) / 2,\n                    }\n                )\n\n        if not rows:\n            return (\n                None,\n                None,\n                None,\n                None,\n                None,\n                None,\n                None,\n                None,\n                \"No detections found in the data\",\n            )\n\n        df = pd.DataFrame(rows)\n        
plots = []\n\n        # Create each plot and convert to image\n        for plot_num in range(8):  # Increased to 8 plots\n            plt.figure(figsize=(8, 6))\n\n            if plot_num == 0:\n                # Plot 1: Number of detections per frame (Original)\n                detections_per_frame = df.groupby(\"frame\").size()\n                plt.plot(detections_per_frame.index, detections_per_frame.values)\n                plt.xlabel(\"Frame\")\n                plt.ylabel(\"Number of Detections\")\n                plt.title(\"Detections Per Frame\")\n\n            elif plot_num == 1:\n                # Plot 2: Distribution of detection areas (Original)\n                df[\"area\"].hist(bins=30)\n                plt.xlabel(\"Detection Area (normalized)\")\n                plt.ylabel(\"Count\")\n                plt.title(\"Distribution of Detection Areas\")\n\n            elif plot_num == 2:\n                # Plot 3: Average detection area over time (Original)\n                avg_area = df.groupby(\"frame\")[\"area\"].mean()\n                plt.plot(avg_area.index, avg_area.values)\n                plt.xlabel(\"Frame\")\n                plt.ylabel(\"Average Detection Area\")\n                plt.title(\"Average Detection Area Over Time\")\n\n            elif plot_num == 3:\n                # Plot 4: Heatmap of detection centers (Original)\n                plt.hist2d(df[\"center_x\"], df[\"center_y\"], bins=30)\n                plt.colorbar()\n                plt.xlabel(\"X Position\")\n                plt.ylabel(\"Y Position\")\n                plt.title(\"Detection Center Heatmap\")\n\n            elif plot_num == 4:\n                # Plot 5: Time-based Detection Density\n                # Shows when in the video most detections occur\n                df[\"time_bucket\"] = pd.qcut(df[\"timestamp\"], q=20, labels=False)\n                time_density = df.groupby(\"time_bucket\").size()\n                plt.bar(time_density.index, time_density.values)\n      
          plt.xlabel(\"Video Timeline (20 segments)\")\n                plt.ylabel(\"Number of Detections\")\n                plt.title(\"Detection Density Over Video Duration\")\n\n            elif plot_num == 5:\n                # Plot 6: Screen Region Analysis\n                # Divide screen into 3x3 grid and show detection counts\n                try:\n                    df[\"grid_x\"] = pd.qcut(\n                        df[\"center_x\"],\n                        q=3,\n                        labels=[\"Left\", \"Center\", \"Right\"],\n                        duplicates=\"drop\",\n                    )\n                    df[\"grid_y\"] = pd.qcut(\n                        df[\"center_y\"],\n                        q=3,\n                        labels=[\"Top\", \"Middle\", \"Bottom\"],\n                        duplicates=\"drop\",\n                    )\n                    region_counts = (\n                        df.groupby([\"grid_y\", \"grid_x\"]).size().unstack(fill_value=0)\n                    )\n                    plt.imshow(region_counts, cmap=\"YlOrRd\")\n                    plt.colorbar(label=\"Detection Count\")\n                    for i in range(3):\n                        for j in range(3):\n                            plt.text(\n                                j, i, region_counts.iloc[i, j], ha=\"center\", va=\"center\"\n                            )\n                    plt.xticks(range(3), [\"Left\", \"Center\", \"Right\"])\n                    plt.yticks(range(3), [\"Top\", \"Middle\", \"Bottom\"])\n                    plt.title(\"Screen Region Analysis\")\n                except Exception as e:\n                    plt.text(\n                        0.5,\n                        0.5,\n                        \"Insufficient variation in detection positions\",\n                        ha=\"center\",\n                        va=\"center\",\n                    )\n                    plt.title(\"Screen Region Analysis (Not Available)\")\n\n  
          elif plot_num == 6:\n                # Plot 7: Detection Size Categories\n                # Categorize detections by size for content moderation\n                try:\n                    size_labels = [\n                        \"Small (likely far/background)\",\n                        \"Medium-small\",\n                        \"Medium-large\",\n                        \"Large (likely foreground/close)\",\n                    ]\n\n                    # Handle cases with limited unique values\n                    unique_areas = df[\"area\"].nunique()\n                    if unique_areas >= 4:\n                        df[\"size_category\"] = pd.qcut(\n                            df[\"area\"], q=4, labels=size_labels, duplicates=\"drop\"\n                        )\n                    else:\n                        # Alternative binning for limited unique values\n                        df[\"size_category\"] = pd.cut(\n                            df[\"area\"],\n                            bins=unique_areas,\n                            labels=size_labels[:unique_areas],\n                        )\n\n                    size_dist = df[\"size_category\"].value_counts()\n                    plt.pie(size_dist.values, labels=size_dist.index, autopct=\"%1.1f%%\")\n                    plt.title(\"Detection Size Distribution\")\n                except Exception as e:\n                    plt.text(\n                        0.5,\n                        0.5,\n                        \"Insufficient variation in detection sizes\",\n                        ha=\"center\",\n                        va=\"center\",\n                    )\n                    plt.title(\"Detection Size Distribution (Not Available)\")\n\n            elif plot_num == 7:\n                # Plot 8: Temporal Pattern Analysis\n                # Show patterns of when detections occur in sequence\n                try:\n                    detection_gaps = 
df.sort_values(\"frame\")[\"frame\"].diff()\n                    if len(detection_gaps.dropna().unique()) > 1:\n                        plt.hist(\n                            detection_gaps.dropna(),\n                            bins=min(30, len(detection_gaps.dropna().unique())),\n                            edgecolor=\"black\",\n                        )\n                        plt.xlabel(\"Frames Between Detections\")\n                        plt.ylabel(\"Frequency\")\n                        plt.title(\"Detection Temporal Pattern Analysis\")\n                    else:\n                        plt.text(\n                            0.5,\n                            0.5,\n                            \"Uniform detection intervals\",\n                            ha=\"center\",\n                            va=\"center\",\n                        )\n                        plt.title(\"Temporal Pattern Analysis (Uniform)\")\n                except Exception as e:\n                    plt.text(\n                        0.5, 0.5, \"Insufficient temporal data\", ha=\"center\", va=\"center\"\n                    )\n                    plt.title(\"Temporal Pattern Analysis (Not Available)\")\n\n            # Save plot to bytes\n            buf = io.BytesIO()\n            plt.savefig(buf, format=\"png\", bbox_inches=\"tight\")\n            buf.seek(0)\n            plots.append(Image.open(buf))\n            plt.close()\n\n        # Enhanced summary text\n        summary = f\"\"\"Summary Statistics:\nTotal frames analyzed: {len(data['frame_detections'])}\nTotal detections: {len(df)}\nAverage detections per frame: {len(df) / len(data['frame_detections']):.2f}\n\nDetection Patterns:\n- Peak detection count: {df.groupby('frame').size().max()} (in a single frame)\n- Most common screen region: {df.groupby(['grid_y', 'grid_x']).size().idxmax()}\n- Average detection size: {df['area'].mean():.3f}\n- Median frames between detections: {detection_gaps.median():.1f}\n\nVideo 
metadata:\n\"\"\"\n        for key, value in data[\"video_metadata\"].items():\n            summary += f\"{key}: {value}\\n\"\n\n        return (\n            plots[0],\n            plots[1],\n            plots[2],\n            plots[3],\n            plots[4],\n            plots[5],\n            plots[6],\n            plots[7],\n            summary,\n        )\n\n    except Exception as e:\n        print(f\"Error creating visualization: {str(e)}\")\n        import traceback\n\n        traceback.print_exc()\n        return (\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            None,\n            f\"Error creating visualization: {str(e)}\",\n        )\n\n\n# Create the Gradio interface\nwith gr.Blocks(title=\"Promptable Content Moderation\") as app:\n    with gr.Tabs():\n        with gr.Tab(\"Process Video\"):\n            gr.Markdown(\"# Promptable Content Moderation with Moondream\")\n            gr.Markdown(\n                \"\"\"\n            Powered by [Moondream 2B](https://github.com/vikhyat/moondream).\n\n            Upload a video and specify what to moderate. The app will process each frame and moderate any visual content that matches the prompt. For help, join the [Moondream Discord](https://discord.com/invite/tRUdpjDQfH).\n            \"\"\"\n            )\n\n            with gr.Row():\n                with gr.Column():\n                    # Input components\n                    video_input = gr.Video(label=\"Upload Video\")\n\n                    detect_input = gr.Textbox(\n                        label=\"What to Moderate\",\n                        placeholder=\"e.g. 
face, cigarette, gun, etc.\",\n                        value=\"face\",\n                        info=\"Moondream can moderate anything that you can describe in natural language\",\n                    )\n\n                    process_btn = gr.Button(\"Process Video\", variant=\"primary\")\n\n                    with gr.Accordion(\"Advanced Settings\", open=False):\n                        box_style_input = gr.Radio(\n                            choices=[\n                                \"censor\",\n                                \"bounding-box\",\n                                \"hitmarker\",\n                                \"sam\",\n                                \"sam-fast\",\n                                \"fuzzy-blur\",\n                                \"pixelated-blur\",\n                                \"intense-pixelated-blur\",\n                                \"obfuscated-pixel\",\n                            ],\n                            value=\"obfuscated-pixel\",\n                            label=\"Visualization Style\",\n                            info=\"Choose how to display moderations: censor (black boxes), bounding-box (red boxes with labels), hitmarker (COD-style markers), sam (precise segmentation), sam-fast (faster but less precise segmentation), fuzzy-blur (Gaussian blur), pixelated-blur (pixelated with blur), obfuscated-pixel (advanced pixelation with neighborhood averaging)\",\n                        )\n                        preset_input = gr.Dropdown(\n                            choices=[\n                                \"ultrafast\",\n                                \"superfast\",\n                                \"veryfast\",\n                                \"faster\",\n                                \"fast\",\n                                \"medium\",\n                                \"slow\",\n                                \"slower\",\n                                \"veryslow\",\n                            ],\n   
                         value=\"medium\",\n                            label=\"Processing Speed (faster = lower quality)\",\n                        )\n                        with gr.Row():\n                            rows_input = gr.Slider(\n                                minimum=1, maximum=4, value=1, step=1, label=\"Grid Rows\"\n                            )\n                            cols_input = gr.Slider(\n                                minimum=1,\n                                maximum=4,\n                                value=1,\n                                step=1,\n                                label=\"Grid Columns\",\n                            )\n\n                        test_mode_input = gr.Checkbox(\n                            label=\"Test Mode (Process first 3 seconds only)\",\n                            value=True,\n                            info=\"Enable to quickly test settings on a short clip before processing the full video (recommended). If using the data visualizations, disable.\",\n                        )\n\n                        test_duration_input = gr.Slider(\n                            minimum=1,\n                            maximum=10,\n                            value=3,\n                            step=1,\n                            label=\"Test Mode Duration (seconds)\",\n                            info=\"Number of seconds to process in test mode\",\n                        )\n\n                        gr.Markdown(\n                            \"\"\"\n                        Note: Processing in test mode will only process the first 3 seconds of the video and is recommended for testing settings.\n                        \"\"\"\n                        )\n\n                        gr.Markdown(\n                            \"\"\"\n                        We can get a rough estimate of how long the video will take to process by multiplying the videos framerate * seconds * the number of rows and columns and 
assuming 0.12 seconds processing time per detection.\n                        For example, a 3 second video at 30fps with 2x2 grid, the estimated time is 3 * 30 * 2 * 2 * 0.12 = 43.2 seconds (tested on a 4090 GPU).\n                        \n                        Note: Using the SAM visualization style will increase processing time significantly as it performs additional segmentation for each detection. The sam-fast option uses a smaller model for faster processing at the cost of some accuracy.\n                        \"\"\"\n                        )\n\n                with gr.Column():\n                    # Output components\n                    video_output = gr.Video(label=\"Processed Video\")\n                    json_output = gr.Text(label=\"Detection Data Path\", visible=False)\n\n                    # About section under the video output\n                    gr.Markdown(\n                        \"\"\"\n                    ### Links:\n                    - [GitHub Repository](https://github.com/vikhyat/moondream)\n                    - [Hugging Face](https://huggingface.co/vikhyatk/moondream2)\n                    - [Quick Start](https://docs.moondream.ai/quick-start)\n                    - [Moondream Recipes](https://docs.moondream.ai/recipes)\n                    \"\"\"\n                    )\n\n        with gr.Tab(\"Analyze Results\"):\n            gr.Markdown(\"# Detection Analysis\")\n            gr.Markdown(\n                \"\"\"\n            Analyze the detection results from processed videos. 
The analysis includes:\n            - Basic detection statistics and patterns\n            - Temporal and spatial distribution analysis\n            - Size-based categorization\n            - Screen region analysis\n            - Detection density patterns\n            \"\"\"\n            )\n\n            with gr.Row():\n                json_input = gr.File(\n                    label=\"Upload Detection Data (JSON)\",\n                    file_types=[\".json\"],\n                )\n                analyze_btn = gr.Button(\"Analyze\", variant=\"primary\")\n\n            with gr.Row():\n                with gr.Column():\n                    plot1 = gr.Image(\n                        label=\"Detections Per Frame\",\n                    )\n                    plot2 = gr.Image(\n                        label=\"Detection Areas Distribution\",\n                    )\n                    plot5 = gr.Image(\n                        label=\"Detection Density Timeline\",\n                    )\n                    plot6 = gr.Image(\n                        label=\"Screen Region Analysis\",\n                    )\n\n                with gr.Column():\n                    plot3 = gr.Image(\n                        label=\"Average Detection Area Over Time\",\n                    )\n                    plot4 = gr.Image(\n                        label=\"Detection Center Heatmap\",\n                    )\n                    plot7 = gr.Image(\n                        label=\"Detection Size Categories\",\n                    )\n                    plot8 = gr.Image(\n                        label=\"Temporal Pattern Analysis\",\n                    )\n\n            stats_output = gr.Textbox(\n                label=\"Statistics\",\n                info=\"Summary of key metrics and patterns found in the detection data.\",\n                lines=12,\n                max_lines=15,\n                interactive=False,\n            )\n\n        # with gr.Tab(\"Video Visualizations\"):\n        
#     gr.Markdown(\"# Real-time Detection Visualization\")\n        #     gr.Markdown(\n        #         \"\"\"\n        #     Watch the detection patterns unfold in real-time. Choose from:\n        #     - Timeline: Shows number of detections over time\n        #     - Gauge: Simple yes/no indicator for current frame detections\n        #     \"\"\"\n        #     )\n\n        #     with gr.Row():\n        #         json_input_realtime = gr.File(\n        #             label=\"Upload Detection Data (JSON)\",\n        #             file_types=[\".json\"],\n        #         )\n        #         viz_style = gr.Radio(\n        #             choices=[\"timeline\", \"gauge\"],\n        #             value=\"timeline\",\n        #             label=\"Visualization Style\",\n        #             info=\"Choose between timeline view or simple gauge indicator\"\n        #         )\n        #         visualize_btn = gr.Button(\"Visualize\", variant=\"primary\")\n\n        #     with gr.Row():\n        #         video_visualization = gr.Video(\n        #             label=\"Detection Visualization\",\n        #             interactive=False\n        #         )\n        #         stats_realtime = gr.Textbox(\n        #             label=\"Video Statistics\",\n        #             lines=6,\n        #             max_lines=8,\n        #             interactive=False\n        #         )\n\n    # Event handlers\n    process_outputs = process_btn.click(\n        fn=process_video_file,\n        inputs=[\n            video_input,\n            detect_input,\n            box_style_input,\n            preset_input,\n            rows_input,\n            cols_input,\n            test_mode_input,\n            test_duration_input,\n        ],\n        outputs=[video_output, json_output],\n    )\n\n    # Auto-analyze after processing\n    process_outputs.then(\n        fn=create_visualization_plots,\n        inputs=[json_output],\n        outputs=[plot1, plot2, plot3, plot4, plot5, 
plot6, plot7, plot8, stats_output],\n    )\n\n    # Manual analysis button\n    analyze_btn.click(\n        fn=create_visualization_plots,\n        inputs=[json_input],\n        outputs=[plot1, plot2, plot3, plot4, plot5, plot6, plot7, plot8, stats_output],\n    )\n\n    # Video visualization button\n    # visualize_btn.click(\n    #     fn=lambda json_file, style: create_video_visualization(json_file.name if json_file else None, style),\n    #     inputs=[json_input_realtime, viz_style],\n    #     outputs=[video_visualization, stats_realtime],\n    # )\n\nif __name__ == \"__main__\":\n    app.launch(share=True)\n"
  },
  {
    "path": "recipes/promptable-content-moderation/deep_sort_integration.py",
    "content": "import numpy as np\nimport torch\nfrom deep_sort_realtime.deepsort_tracker import DeepSort\nfrom datetime import datetime\n\n\nclass DeepSORTTracker:\n    def __init__(self, max_age=5):\n        \"\"\"Initialize DeepSORT tracker.\"\"\"\n        self.max_age = max_age\n        self.tracker = self._create_tracker()\n\n    def _create_tracker(self):\n        \"\"\"Create a new instance of DeepSort tracker.\"\"\"\n        return DeepSort(\n            max_age=self.max_age,\n            embedder=\"mobilenet\",  # Using default MobileNetV2 embedder\n            today=datetime.now().date(),  # For track naming and daily ID reset\n        )\n\n    def reset(self):\n        \"\"\"Reset the tracker state by creating a new instance.\"\"\"\n        print(\"Resetting DeepSORT tracker...\")\n        self.tracker = self._create_tracker()\n\n    def update(self, frame, detections):\n        \"\"\"Update tracking with new detections.\n\n        Args:\n            frame: Current video frame (numpy array)\n            detections: List of (box, keyword) tuples where box is [x1, y1, x2, y2] normalized\n\n        Returns:\n            List of (box, keyword, track_id) tuples\n        \"\"\"\n        if not detections:\n            return []\n\n        height, width = frame.shape[:2]\n\n        # Convert normalized coordinates to absolute and format detections\n        detection_list = []\n        for box, keyword in detections:\n            x1 = int(box[0] * width)\n            y1 = int(box[1] * height)\n            x2 = int(box[2] * width)\n            y2 = int(box[3] * height)\n            w = x2 - x1\n            h = y2 - y1\n\n            # Format: ([left,top,w,h], confidence, detection_class)\n            detection_list.append(([x1, y1, w, h], 1.0, keyword))\n\n        # Update tracker\n        tracks = self.tracker.update_tracks(detection_list, frame=frame)\n\n        # Convert back to normalized coordinates with track IDs\n        tracked_objects = []\n        for 
track in tracks:\n            if not track.is_confirmed():\n                continue\n\n            ltrb = track.to_ltrb()  # Get [left,top,right,bottom] format\n            x1, y1, x2, y2 = ltrb\n\n            # Normalize coordinates\n            x1 = max(0.0, min(1.0, x1 / width))\n            y1 = max(0.0, min(1.0, y1 / height))\n            x2 = max(0.0, min(1.0, x2 / width))\n            y2 = max(0.0, min(1.0, y2 / height))\n\n            tracked_objects.append(([x1, y1, x2, y2], track.det_class, track.track_id))\n\n        return tracked_objects\n"
  },
  {
    "path": "recipes/promptable-content-moderation/main.py",
    "content": "#!/usr/bin/env python3\nimport cv2, os, subprocess, argparse\nfrom PIL import Image\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, SamModel, SamProcessor\nfrom tqdm import tqdm\nimport numpy as np\nfrom datetime import datetime\nfrom deep_sort_integration import DeepSORTTracker\nfrom scenedetect import detect, ContentDetector\nfrom functools import lru_cache\n\n# Constants\nDEFAULT_TEST_MODE_DURATION = 3  # Process only first 3 seconds in test mode by default\nFFMPEG_PRESETS = [\n    \"ultrafast\",\n    \"superfast\",\n    \"veryfast\",\n    \"faster\",\n    \"fast\",\n    \"medium\",\n    \"slow\",\n    \"slower\",\n    \"veryslow\",\n]\nFONT = cv2.FONT_HERSHEY_SIMPLEX  # Font for bounding-box-style labels\n\n# Detection parameters\nIOU_THRESHOLD = 0.5  # IoU threshold for considering boxes related\n\n# Hitmarker parameters\nHITMARKER_SIZE = 20  # Size of the hitmarker in pixels\nHITMARKER_GAP = 3  # Size of the empty space in the middle (reduced from 8)\nHITMARKER_THICKNESS = 2  # Thickness of hitmarker lines\nHITMARKER_COLOR = (255, 255, 255)  # White color for hitmarker\nHITMARKER_SHADOW_COLOR = (80, 80, 80)  # Lighter gray for shadow effect\nHITMARKER_SHADOW_OFFSET = 1  # Smaller shadow offset\n\n# SAM parameters\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# Initialize model variables as None\nsam_model = None\nsam_processor = None\nslimsam_model = None\nslimsam_processor = None\n\n\n@lru_cache(maxsize=2)  # Cache both regular and slim SAM models\ndef get_sam_model(slim=False):\n    \"\"\"Get cached SAM model and processor.\"\"\"\n    global sam_model, sam_processor, slimsam_model, slimsam_processor\n\n    if slim:\n        if slimsam_model is None:\n            print(\"Loading SlimSAM model for the first time...\")\n            slimsam_model = SamModel.from_pretrained(\"nielsr/slimsam-50-uniform\").to(\n                device\n            )\n            slimsam_processor = 
SamProcessor.from_pretrained(\n                \"nielsr/slimsam-50-uniform\"\n            )\n        return slimsam_model, slimsam_processor\n    else:\n        if sam_model is None:\n            print(\"Loading SAM model for the first time...\")\n            sam_model = SamModel.from_pretrained(\"facebook/sam-vit-huge\").to(device)\n            sam_processor = SamProcessor.from_pretrained(\"facebook/sam-vit-huge\")\n        return sam_model, sam_processor\n\n\ndef load_sam_model(slim=False):\n    \"\"\"Load SAM model and processor with caching.\"\"\"\n    return get_sam_model(slim=slim)\n\n\ndef generate_color_pair():\n    \"\"\"Generate a generic light blue and dark blue color pair for SAM visualization.\"\"\"\n    dark_rgb = [0, 0, 139]  # Dark blue\n    light_rgb = [173, 216, 230]  # Light blue\n    return dark_rgb, light_rgb\n\n\ndef create_mask_overlay(image, masks, points=None, labels=None):\n    \"\"\"Create a mask overlay with contours for multiple SAM visualizations.\n\n    Args:\n        image: PIL Image to overlay masks on\n        masks: List of binary masks or single mask\n        points: Optional list of (x,y) points for labels\n        labels: Optional list of label strings for each point\n    \"\"\"\n    # Convert single mask to list for uniform processing\n    if not isinstance(masks, list):\n        masks = [masks]\n\n    # Create empty overlays\n    overlay = np.zeros((*image.size[::-1], 4), dtype=np.uint8)\n    outline = np.zeros((*image.size[::-1], 4), dtype=np.uint8)\n\n    # Process each mask\n    for i, mask in enumerate(masks):\n        # Convert binary mask to uint8\n        mask_uint8 = (mask > 0).astype(np.uint8)\n\n        # Dilation to fill gaps\n        kernel = np.ones((5, 5), np.uint8)\n        mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=1)\n\n        # Find contours of the dilated mask\n        contours, _ = cv2.findContours(\n            mask_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE\n        )\n\n        # 
Generate random color pair for this segmentation\n        dark_color, light_color = generate_color_pair()\n\n        # Add to the overlays\n        overlay[mask_dilated > 0] = [*light_color, 90]  # Light color with 35% opacity\n        cv2.drawContours(\n            outline, contours, -1, (*dark_color, 255), 2\n        )  # Dark color outline\n\n    # Convert to PIL images\n    mask_overlay = Image.fromarray(overlay, \"RGBA\")\n    outline_overlay = Image.fromarray(outline, \"RGBA\")\n\n    # Composite the layers\n    result = image.convert(\"RGBA\")\n    result.paste(mask_overlay, (0, 0), mask_overlay)\n    result.paste(outline_overlay, (0, 0), outline_overlay)\n\n    # Add labels if provided\n    if points and labels:\n        result_array = np.array(result)\n        for (x, y), label in zip(points, labels):\n            label_size = cv2.getTextSize(label, FONT, 0.5, 1)[0]\n            cv2.putText(\n                result_array,\n                label,\n                (int(x - label_size[0] // 2), int(y - 20)),\n                FONT,\n                0.5,\n                (255, 255, 255),\n                1,\n                cv2.LINE_AA,\n            )\n        result = Image.fromarray(result_array)\n\n    return result\n\n\ndef process_sam_detection(image, center_x, center_y, slim=False):\n    \"\"\"Process a single detection point with SAM.\n\n    Returns:\n        tuple: (mask, result_pil) where mask is the binary mask and result_pil is the visualization\n    \"\"\"\n    if not isinstance(image, Image.Image):\n        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n\n    # Get appropriate model from cache\n    model, processor = get_sam_model(slim)\n\n    # Process the image with SAM\n    inputs = processor(\n        image, input_points=[[[center_x, center_y]]], return_tensors=\"pt\"\n    ).to(device)\n\n    with torch.no_grad():\n        outputs = model(**inputs)\n\n    mask = processor.post_process_masks(\n        
outputs.pred_masks.cpu(),\n        inputs[\"original_sizes\"].cpu(),\n        inputs[\"reshaped_input_sizes\"].cpu(),\n    )[0][0][0].numpy()\n\n    # Create the visualization\n    result = create_mask_overlay(image, mask)\n    return mask, result\n\n\ndef load_moondream():\n    \"\"\"Load Moondream model and tokenizer.\"\"\"\n    model = AutoModelForCausalLM.from_pretrained(\n        \"vikhyatk/moondream2\", trust_remote_code=True, device_map={\"\": \"cuda\"}\n    )\n    tokenizer = AutoTokenizer.from_pretrained(\"vikhyatk/moondream2\")\n    return model, tokenizer\n\n\ndef get_video_properties(video_path):\n    \"\"\"Get basic video properties.\"\"\"\n    video = cv2.VideoCapture(video_path)\n    fps = video.get(cv2.CAP_PROP_FPS)\n    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))\n    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))\n    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))\n    video.release()\n    return {\"fps\": fps, \"frame_count\": frame_count, \"width\": width, \"height\": height}\n\n\ndef is_valid_bounding_box(bounding_box):\n    \"\"\"Check if bounding box coordinates are reasonable.\"\"\"\n    x1, y1, x2, y2 = bounding_box\n    width = x2 - x1\n    height = y2 - y1\n\n    # Reject boxes that are too large (over 90% of frame in both dimensions)\n    if width > 0.9 and height > 0.9:\n        return False\n\n    # Reject boxes that are too small (less than 1% of frame)\n    if width < 0.01 or height < 0.01:\n        return False\n\n    return True\n\n\ndef split_frame_into_grid(frame, grid_rows, grid_cols):\n    \"\"\"Split a frame into a grid of tiles.\"\"\"\n    height, width = frame.shape[:2]\n    tile_height = height // grid_rows\n    tile_width = width // grid_cols\n    tiles = []\n    tile_positions = []\n\n    for i in range(grid_rows):\n        for j in range(grid_cols):\n            y1 = i * tile_height\n            y2 = (i + 1) * tile_height if i < grid_rows - 1 else height\n            x1 = j * tile_width\n            x2 = 
(j + 1) * tile_width if j < grid_cols - 1 else width\n\n            tile = frame[y1:y2, x1:x2]\n            tiles.append(tile)\n            tile_positions.append((x1, y1, x2, y2))\n\n    return tiles, tile_positions\n\n\ndef convert_tile_coords_to_frame(box, tile_pos, frame_shape):\n    \"\"\"Convert coordinates from tile space to frame space.\"\"\"\n    frame_height, frame_width = frame_shape[:2]\n    tile_x1, tile_y1, tile_x2, tile_y2 = tile_pos\n    tile_width = tile_x2 - tile_x1\n    tile_height = tile_y2 - tile_y1\n\n    x1_tile_abs = box[0] * tile_width\n    y1_tile_abs = box[1] * tile_height\n    x2_tile_abs = box[2] * tile_width\n    y2_tile_abs = box[3] * tile_height\n\n    x1_frame_abs = tile_x1 + x1_tile_abs\n    y1_frame_abs = tile_y1 + y1_tile_abs\n    x2_frame_abs = tile_x1 + x2_tile_abs\n    y2_frame_abs = tile_y1 + y2_tile_abs\n\n    x1_norm = x1_frame_abs / frame_width\n    y1_norm = y1_frame_abs / frame_height\n    x2_norm = x2_frame_abs / frame_width\n    y2_norm = y2_frame_abs / frame_height\n\n    x1_norm = max(0.0, min(1.0, x1_norm))\n    y1_norm = max(0.0, min(1.0, y1_norm))\n    x2_norm = max(0.0, min(1.0, x2_norm))\n    y2_norm = max(0.0, min(1.0, y2_norm))\n\n    return [x1_norm, y1_norm, x2_norm, y2_norm]\n\n\ndef merge_tile_detections(tile_detections, iou_threshold=0.5):\n    \"\"\"Merge detections from different tiles using NMS-like approach.\"\"\"\n    if not tile_detections:\n        return []\n\n    all_boxes = []\n    all_keywords = []\n\n    # Collect all boxes and their keywords\n    for detections in tile_detections:\n        for box, keyword in detections:\n            all_boxes.append(box)\n            all_keywords.append(keyword)\n\n    if not all_boxes:\n        return []\n\n    # Convert to numpy for easier processing\n    boxes = np.array(all_boxes)\n\n    # Calculate areas\n    x1 = boxes[:, 0]\n    y1 = boxes[:, 1]\n    x2 = boxes[:, 2]\n    y2 = boxes[:, 3]\n    areas = (x2 - x1) * (y2 - y1)\n\n    # Sort boxes by area\n 
   order = areas.argsort()[::-1]\n\n    keep = []\n    while order.size > 0:\n        i = order[0]\n        keep.append(i)\n\n        if order.size == 1:\n            break\n\n        # Calculate IoU with rest of boxes\n        xx1 = np.maximum(x1[i], x1[order[1:]])\n        yy1 = np.maximum(y1[i], y1[order[1:]])\n        xx2 = np.minimum(x2[i], x2[order[1:]])\n        yy2 = np.minimum(y2[i], y2[order[1:]])\n\n        w = np.maximum(0.0, xx2 - xx1)\n        h = np.maximum(0.0, yy2 - yy1)\n        inter = w * h\n\n        ovr = inter / (areas[i] + areas[order[1:]] - inter)\n\n        # Get indices of boxes with IoU less than threshold\n        inds = np.where(ovr <= iou_threshold)[0]\n        order = order[inds + 1]\n\n    return [(all_boxes[i], all_keywords[i]) for i in keep]\n\n\ndef detect_objects_in_frame(\n    model, tokenizer, image, target_object, grid_rows=1, grid_cols=1\n):\n    \"\"\"Detect specified objects in a frame using grid-based analysis.\"\"\"\n    if grid_rows == 1 and grid_cols == 1:\n        return detect_objects_in_frame_single(model, tokenizer, image, target_object)\n\n    # Convert numpy array to PIL Image if needed\n    if not isinstance(image, Image.Image):\n        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n\n    # Split frame into tiles\n    tiles, tile_positions = split_frame_into_grid(image, grid_rows, grid_cols)\n\n    # Process each tile\n    tile_detections = []\n    for tile, tile_pos in zip(tiles, tile_positions):\n        # Convert tile to PIL Image\n        tile_pil = Image.fromarray(tile)\n\n        # Detect objects in tile\n        response = model.detect(tile_pil, target_object)\n\n        if response and \"objects\" in response and response[\"objects\"]:\n            objects = response[\"objects\"]\n            tile_objects = []\n\n            for obj in objects:\n                if all(k in obj for k in [\"x_min\", \"y_min\", \"x_max\", \"y_max\"]):\n                    box = [obj[\"x_min\"], obj[\"y_min\"], 
obj[\"x_max\"], obj[\"y_max\"]]\n\n                    if is_valid_bounding_box(box):\n                        # Convert tile coordinates to frame coordinates\n                        frame_box = convert_tile_coords_to_frame(\n                            box, tile_pos, image.shape\n                        )\n                        tile_objects.append((frame_box, target_object))\n\n            if tile_objects:  # Only append if we found valid objects\n                tile_detections.append(tile_objects)\n\n    # Merge detections from all tiles\n    merged_detections = merge_tile_detections(tile_detections)\n    return merged_detections\n\n\ndef detect_objects_in_frame_single(model, tokenizer, image, target_object):\n    \"\"\"Single-frame detection function.\"\"\"\n    detected_objects = []\n\n    # Convert numpy array to PIL Image if needed\n    if not isinstance(image, Image.Image):\n        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n\n    # Detect objects\n    response = model.detect(image, target_object)\n\n    # Check if we have valid objects\n    if response and \"objects\" in response and response[\"objects\"]:\n        objects = response[\"objects\"]\n\n        for obj in objects:\n            if all(k in obj for k in [\"x_min\", \"y_min\", \"x_max\", \"y_max\"]):\n                box = [obj[\"x_min\"], obj[\"y_min\"], obj[\"x_max\"], obj[\"y_max\"]]\n                # If box is valid (not full-frame), add it\n                if is_valid_bounding_box(box):\n                    detected_objects.append((box, target_object))\n\n    return detected_objects\n\n\ndef draw_hitmarker(\n    frame, center_x, center_y, size=HITMARKER_SIZE, color=HITMARKER_COLOR, shadow=True\n):\n    \"\"\"Draw a COD-style hitmarker cross with more space in the middle.\"\"\"\n    half_size = size // 2\n\n    # Draw shadow first if enabled\n    if shadow:\n        # Top-left to center shadow\n        cv2.line(\n            frame,\n            (\n                
center_x - half_size + HITMARKER_SHADOW_OFFSET,\n                center_y - half_size + HITMARKER_SHADOW_OFFSET,\n            ),\n            (\n                center_x - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n                center_y - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n            ),\n            HITMARKER_SHADOW_COLOR,\n            HITMARKER_THICKNESS,\n        )\n        # Top-right to center shadow\n        cv2.line(\n            frame,\n            (\n                center_x + half_size + HITMARKER_SHADOW_OFFSET,\n                center_y - half_size + HITMARKER_SHADOW_OFFSET,\n            ),\n            (\n                center_x + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n                center_y - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n            ),\n            HITMARKER_SHADOW_COLOR,\n            HITMARKER_THICKNESS,\n        )\n        # Bottom-left to center shadow\n        cv2.line(\n            frame,\n            (\n                center_x - half_size + HITMARKER_SHADOW_OFFSET,\n                center_y + half_size + HITMARKER_SHADOW_OFFSET,\n            ),\n            (\n                center_x - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n                center_y + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n            ),\n            HITMARKER_SHADOW_COLOR,\n            HITMARKER_THICKNESS,\n        )\n        # Bottom-right to center shadow\n        cv2.line(\n            frame,\n            (\n                center_x + half_size + HITMARKER_SHADOW_OFFSET,\n                center_y + half_size + HITMARKER_SHADOW_OFFSET,\n            ),\n            (\n                center_x + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n                center_y + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n            ),\n            HITMARKER_SHADOW_COLOR,\n            HITMARKER_THICKNESS,\n        )\n\n    # Draw main hitmarker\n    # Top-left to center\n    cv2.line(\n        frame,\n        (center_x - half_size, center_y - half_size),\n      
  (center_x - HITMARKER_GAP, center_y - HITMARKER_GAP),\n        color,\n        HITMARKER_THICKNESS,\n    )\n    # Top-right to center\n    cv2.line(\n        frame,\n        (center_x + half_size, center_y - half_size),\n        (center_x + HITMARKER_GAP, center_y - HITMARKER_GAP),\n        color,\n        HITMARKER_THICKNESS,\n    )\n    # Bottom-left to center\n    cv2.line(\n        frame,\n        (center_x - half_size, center_y + half_size),\n        (center_x - HITMARKER_GAP, center_y + HITMARKER_GAP),\n        color,\n        HITMARKER_THICKNESS,\n    )\n    # Bottom-right to center\n    cv2.line(\n        frame,\n        (center_x + half_size, center_y + half_size),\n        (center_x + HITMARKER_GAP, center_y + HITMARKER_GAP),\n        color,\n        HITMARKER_THICKNESS,\n    )\n\n\ndef draw_ad_boxes(frame, detected_objects, detect_keyword, model, box_style=\"censor\"):\n    height, width = frame.shape[:2]\n\n    points = []\n    # Only get points if we need them for hitmarker or SAM styles\n    if box_style in [\"hitmarker\", \"sam\", \"sam-fast\"]:\n        frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n        try:\n            point_response = model.point(frame_pil, detect_keyword)\n\n            if isinstance(point_response, dict) and \"points\" in point_response:\n                points = point_response[\"points\"]\n        except Exception as e:\n            print(f\"Error during point detection: {str(e)}\")\n            points = []\n\n    # Only load SAM models and process points if we're using SAM styles and have points\n    if box_style in [\"sam\", \"sam-fast\"] and points:\n        # Start with the original PIL image\n        frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n\n        # Collect all masks and points\n        all_masks = []\n        point_coords = []\n        point_labels = []\n\n        for point in points:\n            try:\n                center_x = int(float(point[\"x\"]) * 
width)\n                center_y = int(float(point[\"y\"]) * height)\n\n                # Get mask and visualization\n                mask, _ = process_sam_detection(\n                    frame_pil, center_x, center_y, slim=(box_style == \"sam-fast\")\n                )\n\n                # Collect mask and point data\n                all_masks.append(mask)\n                point_coords.append((center_x, center_y))\n                point_labels.append(detect_keyword)\n\n            except Exception as e:\n                print(f\"Error processing individual SAM point: {str(e)}\")\n                print(f\"Point data: {point}\")\n\n        if all_masks:\n            # Create final visualization with all masks\n            result_pil = create_mask_overlay(\n                frame_pil, all_masks, point_coords, point_labels\n            )\n            frame = cv2.cvtColor(np.array(result_pil), cv2.COLOR_RGB2BGR)\n\n    # Process other visualization styles\n    for detection in detected_objects:\n        try:\n            # Handle both tracked and untracked detections\n            if len(detection) == 3:  # Tracked detection with ID\n                box, keyword, track_id = detection\n            else:  # Regular detection without tracking\n                box, keyword = detection\n                track_id = None\n\n            x1 = int(box[0] * width)\n            y1 = int(box[1] * height)\n            x2 = int(box[2] * width)\n            y2 = int(box[3] * height)\n\n            x1 = max(0, min(x1, width - 1))\n            y1 = max(0, min(y1, height - 1))\n            x2 = max(0, min(x2, width - 1))\n            y2 = max(0, min(y2, height - 1))\n\n            if x2 > x1 and y2 > y1:\n                if box_style == \"censor\":\n                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 0), -1)\n                elif box_style == \"bounding-box\":\n                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 3)\n\n                    label = (\n     
                   f\"{detect_keyword}\" if track_id is not None else detect_keyword\n                    )\n                    label_size = cv2.getTextSize(label, FONT, 0.7, 2)[0]\n                    cv2.rectangle(\n                        frame, (x1, y1 - 25), (x1 + label_size[0], y1), (0, 0, 255), -1\n                    )\n                    cv2.putText(\n                        frame,\n                        label,\n                        (x1, y1 - 6),\n                        FONT,\n                        0.7,\n                        (255, 255, 255),\n                        2,\n                        cv2.LINE_AA,\n                    )\n                elif box_style == \"fuzzy-blur\":\n                    # Extract ROI\n                    roi = frame[y1:y2, x1:x2]\n                    # Apply Gaussian blur with much larger kernel for intense blur\n                    blurred_roi = cv2.GaussianBlur(roi, (125, 125), 0)\n                    # Replace original ROI with blurred version\n                    frame[y1:y2, x1:x2] = blurred_roi\n                elif box_style == \"pixelated-blur\":\n                    # Extract ROI\n                    roi = frame[y1:y2, x1:x2]\n                    # Pixelate by resizing down and up\n                    h, w = roi.shape[:2]\n                    temp = cv2.resize(roi, (10, 10), interpolation=cv2.INTER_LINEAR)\n                    pixelated = cv2.resize(\n                        temp, (w, h), interpolation=cv2.INTER_NEAREST\n                    )\n                    # Mix up the pixelated frame slightly by adding random noise\n                    noise = np.random.randint(0, 50, (h, w, 3), dtype=np.uint8)\n                    pixelated = cv2.add(pixelated, noise)\n                    # Apply stronger Gaussian blur to smooth edges\n                    blurred_pixelated = cv2.GaussianBlur(pixelated, (15, 15), 0)\n                    # Replace original ROI\n                    frame[y1:y2, x1:x2] = 
blurred_pixelated\n                elif box_style == \"obfuscated-pixel\":\n                    # Calculate expansion amount based on 10% of object dimensions\n                    box_width = x2 - x1\n                    box_height = y2 - y1\n                    expand_x = int(box_width * 0.10)\n                    expand_y = int(box_height * 0.10)\n\n                    # Expand the bounding box by 10% in all directions\n                    x1_expanded = max(0, x1 - expand_x)\n                    y1_expanded = max(0, y1 - expand_y)\n                    x2_expanded = min(width - 1, x2 + expand_x)\n                    y2_expanded = min(height - 1, y2 + expand_y)\n\n                    # Extract ROI with much larger padding for true background sampling\n                    padding = 100  # Much larger padding to get true background\n                    y1_pad = max(0, y1_expanded - padding)\n                    y2_pad = min(height, y2_expanded + padding)\n                    x1_pad = max(0, x1_expanded - padding)\n                    x2_pad = min(width, x2_expanded + padding)\n\n                    # Get the padded region including background\n                    padded_roi = frame[y1_pad:y2_pad, x1_pad:x2_pad]\n\n                    # Create mask that excludes a larger region around the detection\n                    h, w = y2_expanded - y1_expanded, x2_expanded - x1_expanded\n                    bg_mask = np.ones(padded_roi.shape[:2], dtype=bool)\n\n                    # Exclude a larger region around the detection from background sampling\n                    exclusion_padding = 50  # Area to exclude around detection\n                    exclude_y1 = padding - exclusion_padding\n                    exclude_y2 = padding + h + exclusion_padding\n                    exclude_x1 = padding - exclusion_padding\n                    exclude_x2 = padding + w + exclusion_padding\n\n                    # Make sure exclusion coordinates are valid\n                    
exclude_y1 = max(0, exclude_y1)\n                    exclude_y2 = min(padded_roi.shape[0], exclude_y2)\n                    exclude_x1 = max(0, exclude_x1)\n                    exclude_x2 = min(padded_roi.shape[1], exclude_x2)\n\n                    # Mark the exclusion zone in the mask\n                    bg_mask[exclude_y1:exclude_y2, exclude_x1:exclude_x2] = False\n\n                    # If we have enough background pixels, calculate average color\n                    if np.any(bg_mask):\n                        bg_color = np.mean(padded_roi[bg_mask], axis=0).astype(np.uint8)\n                    else:\n                        # Fallback to edges if we couldn't get enough background\n                        edge_samples = np.concatenate(\n                            [\n                                padded_roi[0],  # Top edge\n                                padded_roi[-1],  # Bottom edge\n                                padded_roi[:, 0],  # Left edge\n                                padded_roi[:, -1],  # Right edge\n                            ]\n                        )\n                        bg_color = np.mean(edge_samples, axis=0).astype(np.uint8)\n\n                    # Create base pixelated version (of the expanded region)\n                    temp = cv2.resize(\n                        frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded],\n                        (6, 6),\n                        interpolation=cv2.INTER_LINEAR,\n                    )\n                    pixelated = cv2.resize(\n                        temp, (w, h), interpolation=cv2.INTER_NEAREST\n                    )\n\n                    # Blend heavily towards background color\n                    blend_factor = 0.9  # Much stronger blend with background\n                    blended = cv2.addWeighted(\n                        pixelated,\n                        1 - blend_factor,\n                        np.full((h, w, 3), bg_color, dtype=np.uint8),\n                        
blend_factor,\n                        0,\n                    )\n\n                    # Replace original ROI with blended version (using expanded coordinates)\n                    frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded] = blended\n                elif box_style == \"intense-pixelated-blur\":\n                    # Expand the bounding box by pixels in all directions\n                    x1_expanded = max(0, x1 - 15)\n                    y1_expanded = max(0, y1 - 15)\n                    x2_expanded = min(width - 1, x2 + 25)\n                    y2_expanded = min(height - 1, y2 + 25)\n\n                    # Extract ROI\n                    roi = frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded]\n                    # Pixelate by resizing down and up\n                    h, w = roi.shape[:2]\n                    temp = cv2.resize(roi, (10, 10), interpolation=cv2.INTER_LINEAR)\n                    pixelated = cv2.resize(\n                        temp, (w, h), interpolation=cv2.INTER_NEAREST\n                    )\n                    # Mix up the pixelated frame slightly by adding random noise\n                    noise = np.random.randint(0, 50, (h, w, 3), dtype=np.uint8)\n                    pixelated = cv2.add(pixelated, noise)\n                    # Apply stronger Gaussian blur to smooth edges\n                    blurred_pixelated = cv2.GaussianBlur(pixelated, (15, 15), 0)\n                    # Replace original ROI\n                    frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded] = (\n                        blurred_pixelated\n                    )\n                elif box_style == \"hitmarker\":\n                    if points:\n                        for point in points:\n                            try:\n                                print(f\"Processing point: {point}\")\n                                center_x = int(float(point[\"x\"]) * width)\n                                center_y = int(float(point[\"y\"]) * 
height)\n                                print(\n                                    f\"Converted coordinates: ({center_x}, {center_y})\"\n                                )\n\n                                draw_hitmarker(frame, center_x, center_y)\n\n                                label = (\n                                    f\"{detect_keyword}\"\n                                    if track_id is not None\n                                    else detect_keyword\n                                )\n                                label_size = cv2.getTextSize(label, FONT, 0.5, 1)[0]\n                                cv2.putText(\n                                    frame,\n                                    label,\n                                    (\n                                        center_x - label_size[0] // 2,\n                                        center_y - HITMARKER_SIZE - 5,\n                                    ),\n                                    FONT,\n                                    0.5,\n                                    HITMARKER_COLOR,\n                                    1,\n                                    cv2.LINE_AA,\n                                )\n                            except Exception as e:\n                                print(f\"Error processing individual point: {str(e)}\")\n                                print(f\"Point data: {point}\")\n\n        except Exception as e:\n            print(f\"Error drawing {box_style} style box: {str(e)}\")\n            print(f\"Box data: {box}\")\n            print(f\"Keyword: {keyword}\")\n\n    return frame\n\n\ndef filter_temporal_outliers(detections_dict):\n    \"\"\"Filter out extremely large detections that take up most of the frame.\n    Only keeps detections that are reasonable in size.\n\n    Args:\n        detections_dict: Dictionary of {frame_number: [(box, keyword, track_id), ...]}\n    \"\"\"\n    filtered_detections = {}\n\n    for t, detections in 
detections_dict.items():\n        # Only keep detections that aren't too large\n        valid_detections = []\n        for detection in detections:\n            # Handle both tracked and untracked detections\n            if len(detection) == 3:  # Tracked detection with ID\n                box, keyword, track_id = detection\n            else:  # Regular detection without tracking\n                box, keyword = detection\n                track_id = None\n\n            # Calculate box size as percentage of frame\n            width = box[2] - box[0]\n            height = box[3] - box[1]\n            area = width * height\n\n            # If box is less than 90% of frame, keep it\n            if area < 0.9:\n                if track_id is not None:\n                    valid_detections.append((box, keyword, track_id))\n                else:\n                    valid_detections.append((box, keyword))\n\n        if valid_detections:\n            filtered_detections[t] = valid_detections\n\n    return filtered_detections\n\n\ndef describe_frames(\n    video_path,\n    model,\n    tokenizer,\n    detect_keyword,\n    test_mode=False,\n    test_duration=DEFAULT_TEST_MODE_DURATION,\n    grid_rows=1,\n    grid_cols=1,\n):\n    \"\"\"Extract and detect objects in frames.\"\"\"\n    props = get_video_properties(video_path)\n    fps = props[\"fps\"]\n\n    # Initialize DeepSORT tracker\n    tracker = DeepSORTTracker()\n\n    # If in test mode, only process first N seconds\n    if test_mode:\n        frame_count = min(int(fps * test_duration), props[\"frame_count\"])\n    else:\n        frame_count = props[\"frame_count\"]\n\n    ad_detections = {}  # Store detection results by frame number\n\n    print(\"Extracting frames and detecting objects...\")\n    video = cv2.VideoCapture(video_path)\n\n    # Detect scenes first\n    scenes = detect(video_path, scene_detector)\n    scene_changes = set(end.get_frames() for _, end in scenes)\n    print(f\"Detected {len(scenes)} 
scenes\")\n\n    frame_count_processed = 0\n    with tqdm(total=frame_count) as pbar:\n        while frame_count_processed < frame_count:\n            ret, frame = video.read()\n            if not ret:\n                break\n\n            # Check if current frame is a scene change\n            if frame_count_processed in scene_changes:\n                # Detect objects in the frame\n                detected_objects = detect_objects_in_frame(\n                    model,\n                    tokenizer,\n                    frame,\n                    detect_keyword,\n                    grid_rows=grid_rows,\n                    grid_cols=grid_cols,\n                )\n\n            # Update tracker with current detections\n            tracked_objects = tracker.update(frame, detected_objects)\n\n            # Store results for every frame, even if empty\n            ad_detections[frame_count_processed] = tracked_objects\n\n            frame_count_processed += 1\n            pbar.update(1)\n\n    video.release()\n\n    if frame_count_processed == 0:\n        print(\"No frames could be read from video\")\n        return {}\n\n    return ad_detections\n\n\ndef create_detection_video(\n    video_path,\n    ad_detections,\n    detect_keyword,\n    model,\n    output_path=None,\n    ffmpeg_preset=\"medium\",\n    test_mode=False,\n    test_duration=DEFAULT_TEST_MODE_DURATION,\n    box_style=\"censor\",\n):\n    \"\"\"Create video with detection boxes while preserving audio.\"\"\"\n    if output_path is None:\n        # Create outputs directory if it doesn't exist\n        outputs_dir = os.path.join(\n            os.path.dirname(os.path.abspath(__file__)), \"outputs\"\n        )\n        os.makedirs(outputs_dir, exist_ok=True)\n\n        # Clean the detect_keyword for filename\n        safe_keyword = \"\".join(\n            x for x in detect_keyword if x.isalnum() or x in (\" \", \"_\", \"-\")\n        )\n        safe_keyword = safe_keyword.replace(\" \", \"_\")\n\n        
# Create output filename\n        base_name = os.path.splitext(os.path.basename(video_path))[0]\n        output_path = os.path.join(\n            outputs_dir, f\"{box_style}_{safe_keyword}_{base_name}.mp4\"\n        )\n\n    print(f\"Will save output to: {output_path}\")\n\n    props = get_video_properties(video_path)\n    fps, width, height = props[\"fps\"], props[\"width\"], props[\"height\"]\n\n    # If in test mode, only process first few seconds\n    if test_mode:\n        frame_count = min(int(fps * test_duration), props[\"frame_count\"])\n        print(\n            f\"Test mode enabled: Processing first {test_duration} seconds ({frame_count} frames)\"\n        )\n    else:\n        frame_count = props[\"frame_count\"]\n        print(\"Full video mode: Processing entire video\")\n\n    video = cv2.VideoCapture(video_path)\n\n    # Create temp output path by adding _temp before the extension\n    base, ext = os.path.splitext(output_path)\n    temp_output = f\"{base}_temp{ext}\"\n    temp_audio = f\"{base}_audio.aac\"  # Temporary audio file\n\n    out = cv2.VideoWriter(\n        temp_output, cv2.VideoWriter_fourcc(*\"mp4v\"), fps, (width, height)\n    )\n\n    print(\"Creating detection video...\")\n    frame_count_processed = 0\n\n    with tqdm(total=frame_count) as pbar:\n        while frame_count_processed < frame_count:\n            ret, frame = video.read()\n            if not ret:\n                break\n\n            # Get detections for this exact frame\n            if frame_count_processed in ad_detections:\n                current_detections = ad_detections[frame_count_processed]\n                if current_detections:\n                    frame = draw_ad_boxes(\n                        frame,\n                        current_detections,\n                        detect_keyword,\n                        model,\n                        box_style=box_style,\n                    )\n\n            out.write(frame)\n            frame_count_processed += 1\n 
           pbar.update(1)\n\n    video.release()\n    out.release()\n\n    # Extract audio from original video\n    try:\n        if test_mode:\n            # In test mode, extract only the required duration of audio\n            subprocess.run(\n                [\n                    \"ffmpeg\",\n                    \"-y\",\n                    \"-i\",\n                    video_path,\n                    \"-t\",\n                    str(test_duration),\n                    \"-vn\",  # No video\n                    \"-acodec\",\n                    \"copy\",\n                    temp_audio,\n                ],\n                check=True,\n            )\n        else:\n            subprocess.run(\n                [\n                    \"ffmpeg\",\n                    \"-y\",\n                    \"-i\",\n                    video_path,\n                    \"-vn\",  # No video\n                    \"-acodec\",\n                    \"copy\",\n                    temp_audio,\n                ],\n                check=True,\n            )\n    except subprocess.CalledProcessError as e:\n        print(f\"Error extracting audio: {str(e)}\")\n        if os.path.exists(temp_output):\n            os.remove(temp_output)\n        return None\n\n    # Merge processed video with original audio\n    try:\n        # Base FFmpeg command\n        ffmpeg_cmd = [\n            \"ffmpeg\",\n            \"-y\",\n            \"-i\",\n            temp_output,\n            \"-i\",\n            temp_audio,\n            \"-c:v\",\n            \"libx264\",\n            \"-preset\",\n            ffmpeg_preset,\n            \"-crf\",\n            \"23\",\n            \"-c:a\",\n            \"aac\",\n            \"-b:a\",\n            \"192k\",\n            \"-movflags\",\n            \"+faststart\",  # Better web playback\n        ]\n\n        if test_mode:\n            # In test mode, ensure output duration matches test_duration\n            ffmpeg_cmd.extend(\n                [\n           
         \"-t\",\n                    str(test_duration),\n                    \"-shortest\",  # Ensure output duration matches shortest input\n                ]\n            )\n\n        ffmpeg_cmd.extend([\"-loglevel\", \"error\", output_path])\n\n        subprocess.run(ffmpeg_cmd, check=True)\n\n        # Clean up temporary files\n        os.remove(temp_output)\n        os.remove(temp_audio)\n\n        if not os.path.exists(output_path):\n            print(\n                f\"Warning: FFmpeg completed but output file not found at {output_path}\"\n            )\n            return None\n\n        return output_path\n\n    except subprocess.CalledProcessError as e:\n        print(f\"Error merging audio with video: {str(e)}\")\n        if os.path.exists(temp_output):\n            os.remove(temp_output)\n        if os.path.exists(temp_audio):\n            os.remove(temp_audio)\n        return None\n\n\ndef process_video(\n    video_path,\n    target_object,\n    test_mode=False,\n    test_duration=DEFAULT_TEST_MODE_DURATION,\n    ffmpeg_preset=\"medium\",\n    grid_rows=1,\n    grid_cols=1,\n    box_style=\"censor\",\n):\n    \"\"\"Process a video to detect and visualize specified objects.\"\"\"\n    try:\n        print(f\"\\nProcessing: {video_path}\")\n        print(f\"Looking for: {target_object}\")\n\n        # Load model\n        print(\"Loading Moondream model...\")\n        model, tokenizer = load_moondream()\n\n        # Get video properties\n        props = get_video_properties(video_path)\n\n        # Initialize scene detector with ContentDetector\n        scene_detector = ContentDetector(threshold=30.0)  # Adjust threshold as needed\n\n        # Initialize DeepSORT tracker\n        tracker = DeepSORTTracker()\n\n        # If in test mode, only process first N seconds\n        if test_mode:\n            frame_count = min(int(props[\"fps\"] * test_duration), props[\"frame_count\"])\n        else:\n            frame_count = props[\"frame_count\"]\n\n        
ad_detections = {}  # Store detection results by frame number\n\n        print(\"Extracting frames and detecting objects...\")\n        video = cv2.VideoCapture(video_path)\n\n        # Detect scenes first\n        scenes = detect(video_path, scene_detector)\n        scene_changes = set(end.get_frames() for _, end in scenes)\n        print(f\"Detected {len(scenes)} scenes\")\n\n        frame_count_processed = 0\n        with tqdm(total=frame_count) as pbar:\n            while frame_count_processed < frame_count:\n                ret, frame = video.read()\n                if not ret:\n                    break\n\n                # Check if current frame is a scene change\n                if frame_count_processed in scene_changes:\n                    print(\n                        f\"Scene change detected at frame {frame_count_processed}. Resetting tracker.\"\n                    )\n                    tracker.reset()\n\n                # Detect objects in the frame\n                detected_objects = detect_objects_in_frame(\n                    model,\n                    tokenizer,\n                    frame,\n                    target_object,\n                    grid_rows=grid_rows,\n                    grid_cols=grid_cols,\n                )\n\n                # Update tracker with current detections\n                tracked_objects = tracker.update(frame, detected_objects)\n\n                # Store results for every frame, even if empty\n                ad_detections[frame_count_processed] = tracked_objects\n\n                frame_count_processed += 1\n                pbar.update(1)\n\n        video.release()\n\n        if frame_count_processed == 0:\n            print(\"No frames could be read from video\")\n            return {}\n\n        # Apply filtering\n        filtered_ad_detections = filter_temporal_outliers(ad_detections)\n\n        # Build detection data structure\n        detection_data = {\n            \"video_metadata\": {\n                
\"file_name\": os.path.basename(video_path),\n                \"fps\": props[\"fps\"],\n                \"width\": props[\"width\"],\n                \"height\": props[\"height\"],\n                \"total_frames\": props[\"frame_count\"],\n                \"duration_sec\": props[\"frame_count\"] / props[\"fps\"],\n                \"detect_keyword\": target_object,\n                \"test_mode\": test_mode,\n                \"grid_size\": f\"{grid_rows}x{grid_cols}\",\n                \"box_style\": box_style,\n                \"timestamp\": datetime.now().isoformat(),\n            },\n            \"frame_detections\": [\n                {\n                    \"frame\": frame_num,\n                    \"timestamp\": frame_num / props[\"fps\"],\n                    \"objects\": [\n                        {\n                            \"keyword\": kw,\n                            \"bbox\": list(box),  # Convert numpy array to list if needed\n                            \"track_id\": track_id if len(detection) == 3 else None,\n                        }\n                        for detection in filtered_ad_detections.get(frame_num, [])\n                        for box, kw, *track_id in [\n                            detection\n                        ]  # Unpack detection tuple, track_id will be empty list if not present\n                    ],\n                }\n                for frame_num in range(\n                    props[\"frame_count\"]\n                    if not test_mode\n                    else min(int(props[\"fps\"] * test_duration), props[\"frame_count\"])\n                )\n            ],\n        }\n\n        # Save filtered data\n        outputs_dir = os.path.join(\n            os.path.dirname(os.path.abspath(__file__)), \"outputs\"\n        )\n        os.makedirs(outputs_dir, exist_ok=True)\n        base_name = os.path.splitext(os.path.basename(video_path))[0]\n        json_path = os.path.join(\n            outputs_dir, 
f\"{box_style}_{target_object}_{base_name}_detections.json\"\n        )\n\n        from persistence import save_detection_data\n\n        if not save_detection_data(detection_data, json_path):\n            print(\"Warning: Failed to save detection data\")\n\n        # Create video with filtered data\n        output_path = create_detection_video(\n            video_path,\n            filtered_ad_detections,\n            target_object,\n            model,\n            ffmpeg_preset=ffmpeg_preset,\n            test_mode=test_mode,\n            test_duration=test_duration,\n            box_style=box_style,\n        )\n\n        if output_path is None:\n            print(\"\\nError: Failed to create output video\")\n            return None\n\n        print(f\"\\nOutput saved to: {output_path}\")\n        print(f\"Detection data saved to: {json_path}\")\n        return output_path\n\n    except Exception as e:\n        print(f\"Error processing video: {str(e)}\")\n        import traceback\n\n        traceback.print_exc()\n        return None\n\n\ndef main():\n    \"\"\"Process all videos in the inputs directory.\"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Detect objects in videos using Moondream2\"\n    )\n    parser.add_argument(\n        \"--test\", action=\"store_true\", help=\"Process only first 3 seconds of each video\"\n    )\n    parser.add_argument(\n        \"--test-duration\",\n        type=int,\n        default=DEFAULT_TEST_MODE_DURATION,\n        help=f\"Number of seconds to process in test mode (default: {DEFAULT_TEST_MODE_DURATION})\",\n    )\n    parser.add_argument(\n        \"--preset\",\n        choices=FFMPEG_PRESETS,\n        default=\"medium\",\n        help=\"FFmpeg encoding preset (default: medium). 
Faster presets = lower quality\",\n    )\n    parser.add_argument(\n        \"--detect\",\n        type=str,\n        default=\"face\",\n        help='Object to detect in the video (default: face, use --detect \"thing to detect\" to override)',\n    )\n    parser.add_argument(\n        \"--rows\",\n        type=int,\n        default=1,\n        help=\"Number of rows to split each frame into (default: 1)\",\n    )\n    parser.add_argument(\n        \"--cols\",\n        type=int,\n        default=1,\n        help=\"Number of columns to split each frame into (default: 1)\",\n    )\n    parser.add_argument(\n        \"--box-style\",\n        choices=[\n            \"censor\",\n            \"bounding-box\",\n            \"hitmarker\",\n            \"sam\",\n            \"sam-fast\",\n            \"fuzzy-blur\",\n            \"pixelated-blur\",\n            \"intense-pixelated-blur\",\n            \"obfuscated-pixel\",\n        ],\n        default=\"censor\",\n        help=\"Style of detection visualization (default: censor)\",\n    )\n    args = parser.parse_args()\n\n    input_dir = \"inputs\"\n    os.makedirs(input_dir, exist_ok=True)\n    os.makedirs(\"outputs\", exist_ok=True)\n\n    video_files = [\n        f\n        for f in os.listdir(input_dir)\n        if f.lower().endswith((\".mp4\", \".avi\", \".mov\", \".mkv\", \".webm\"))\n    ]\n\n    if not video_files:\n        print(\"No video files found in 'inputs' directory\")\n        return\n\n    print(f\"Found {len(video_files)} videos to process\")\n    print(f\"Will detect: {args.detect}\")\n    if args.test:\n        print(\"Running in test mode - processing only first 3 seconds of each video\")\n    print(f\"Using FFmpeg preset: {args.preset}\")\n    print(f\"Grid size: {args.rows}x{args.cols}\")\n    print(f\"Box style: {args.box_style}\")\n\n    success_count = 0\n    for video_file in video_files:\n        video_path = os.path.join(input_dir, video_file)\n        output_path = process_video(\n            
video_path,\n            args.detect,\n            test_mode=args.test,\n            test_duration=args.test_duration,\n            ffmpeg_preset=args.preset,\n            grid_rows=args.rows,\n            grid_cols=args.cols,\n            box_style=args.box_style,\n        )\n        if output_path:\n            success_count += 1\n\n    print(\n        f\"\\nProcessing complete. Successfully processed {success_count} out of {len(video_files)} videos.\"\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "recipes/promptable-content-moderation/packages.txt",
    "content": "libvips\nffmpeg"
  },
  {
    "path": "recipes/promptable-content-moderation/persistence.py",
    "content": "import json\nimport os\n\n\ndef save_detection_data(data, output_file):\n    \"\"\"\n    Saves the detection data to a JSON file.\n\n    Args:\n        data (dict): The complete detection data structure.\n        output_file (str): Path to the output JSON file.\n    \"\"\"\n    try:\n        # Create directory if it doesn't exist\n        os.makedirs(os.path.dirname(output_file), exist_ok=True)\n\n        with open(output_file, \"w\") as f:\n            json.dump(data, f, indent=4)\n        print(f\"Detection data saved to {output_file}\")\n        return True\n    except Exception as e:\n        print(f\"Error saving data: {str(e)}\")\n        return False\n\n\ndef load_detection_data(input_file):\n    \"\"\"\n    Loads the detection data from a JSON file.\n\n    Args:\n        input_file (str): Path to the JSON file.\n\n    Returns:\n        dict: The loaded detection data, or None if there was an error.\n    \"\"\"\n    try:\n        with open(input_file, \"r\") as f:\n            return json.load(f)\n    except Exception as e:\n        print(f\"Error loading data: {str(e)}\")\n        return None\n"
  },
  {
    "path": "recipes/promptable-content-moderation/requirements.txt",
    "content": "gradio>=4.0.0\ntorch>=2.0.0\n# if on windows: pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https://download.pytorch.org/whl/cu121 \ntransformers>=4.36.0\nopencv-python>=4.8.0\npillow>=10.0.0\nnumpy>=1.24.0\ntqdm>=4.66.0\nffmpeg-python\neinops\npyvips-binary\npyvips\naccelerate\n# for spaces\n--extra-index-url https://download.pytorch.org/whl/cu113\nspaces\n# SAM dependencies\ntorchvision>=0.20.1\nmatplotlib>=3.7.0\npandas>=2.0.0\nplotly\n# DeepSORT dependencies\ndeep-sort-realtime>=1.3.2\nscikit-learn  # Required for deep-sort-realtime\n# Scene detection dependencies (for intelligent scene-aware tracking)\nscenedetect[opencv]>=0.6.2  # Provides scene change detection capabilities"
  },
  {
    "path": "recipes/promptable-content-moderation/video_visualization.py",
    "content": "import os\nimport tempfile\nimport subprocess\nimport matplotlib.pyplot as plt\nimport pandas as pd\nimport cv2\nimport numpy as np\nfrom tqdm import tqdm\nfrom persistence import load_detection_data\n\n\ndef create_frame_data(json_path):\n    \"\"\"Create frame-by-frame detection data for visualization.\"\"\"\n    try:\n        data = load_detection_data(json_path)\n        if not data:\n            print(\"No data loaded from JSON file\")\n            return None\n\n        if \"video_metadata\" not in data or \"frame_detections\" not in data:\n            print(\"Invalid JSON structure: missing required fields\")\n            return None\n\n        # Extract video metadata\n        metadata = data[\"video_metadata\"]\n        if \"fps\" not in metadata or \"total_frames\" not in metadata:\n            print(\"Invalid metadata: missing fps or total_frames\")\n            return None\n\n        fps = metadata[\"fps\"]\n        total_frames = metadata[\"total_frames\"]\n\n        # Create frame data\n        frame_counts = {}\n        for frame_data in data[\"frame_detections\"]:\n            if \"frame\" not in frame_data or \"objects\" not in frame_data:\n                continue  # Skip invalid frame data\n            frame_num = frame_data[\"frame\"]\n            frame_counts[frame_num] = len(frame_data[\"objects\"])\n\n        # Fill in missing frames with 0 detections\n        for frame in range(total_frames):\n            if frame not in frame_counts:\n                frame_counts[frame] = 0\n\n        if not frame_counts:\n            print(\"No valid frame data found\")\n            return None\n\n        # Convert to DataFrame\n        df = pd.DataFrame(list(frame_counts.items()), columns=[\"frame\", \"detections\"])\n        df[\"timestamp\"] = df[\"frame\"] / fps\n\n        return df, metadata\n\n    except Exception as e:\n        print(f\"Error creating frame data: {str(e)}\")\n        import traceback\n\n        
traceback.print_exc()\n        return None\n\n\ndef generate_frame_image(df, frame_num, temp_dir, max_y):\n    \"\"\"Generate and save a single frame of the visualization.\"\"\"\n    # Set the style to dark background\n    plt.style.use(\"dark_background\")\n\n    # Set global font to monospace\n    plt.rcParams[\"font.family\"] = \"monospace\"\n    plt.rcParams[\"font.monospace\"] = [\"DejaVu Sans Mono\"]\n\n    plt.figure(figsize=(10, 6))\n\n    # Plot data up to current frame\n    current_data = df[df[\"frame\"] <= frame_num]\n    plt.plot(\n        df[\"frame\"], df[\"detections\"], color=\"#1a1a1a\", alpha=0.5\n    )  # Darker background line\n    plt.plot(\n        current_data[\"frame\"], current_data[\"detections\"], color=\"#00ff41\"\n    )  # Matrix green\n\n    # Add vertical line for current position\n    plt.axvline(\n        x=frame_num, color=\"#ff0000\", linestyle=\"-\", alpha=0.7\n    )  # Keep red for position\n\n    # Set consistent axes\n    plt.xlim(0, len(df) - 1)\n    plt.ylim(0, max_y * 1.1)  # Add 10% padding\n\n    # Add labels with Matrix green color\n    plt.title(f\"FRAME {frame_num:04d} - DETECTIONS OVER TIME\", color=\"#00ff41\", pad=20)\n    plt.xlabel(\"FRAME NUMBER\", color=\"#00ff41\")\n    plt.ylabel(\"NUMBER OF DETECTIONS\", color=\"#00ff41\")\n\n    # Add current stats in Matrix green with monospace formatting\n    current_detections = df[df[\"frame\"] == frame_num][\"detections\"].iloc[0]\n    plt.text(\n        0.02,\n        0.98,\n        f\"CURRENT DETECTIONS: {current_detections:02d}\",\n        transform=plt.gca().transAxes,\n        verticalalignment=\"top\",\n        color=\"#00ff41\",\n        family=\"monospace\",\n    )\n\n    # Style the grid and ticks\n    plt.grid(True, color=\"#1a1a1a\", linestyle=\"-\", alpha=0.3)\n    plt.tick_params(colors=\"#00ff41\")\n\n    # Save frame\n    frame_path = os.path.join(temp_dir, f\"frame_{frame_num:05d}.png\")\n    plt.savefig(\n        frame_path, bbox_inches=\"tight\", 
dpi=100, facecolor=\"black\", edgecolor=\"none\"\n    )\n    plt.close()\n\n    return frame_path\n\n\ndef generate_gauge_frame(df, frame_num, temp_dir, detect_keyword=\"OBJECT\"):\n    \"\"\"Generate a modern square-style binary gauge visualization frame.\"\"\"\n    # Set the style to dark background\n    plt.style.use(\"dark_background\")\n\n    # Set global font to monospace\n    plt.rcParams[\"font.family\"] = \"monospace\"\n    plt.rcParams[\"font.monospace\"] = [\"DejaVu Sans Mono\"]\n\n    # Create figure with 16:9 aspect ratio\n    plt.figure(figsize=(16, 9))\n\n    # Get current detection state\n    current_detections = df[df[\"frame\"] == frame_num][\"detections\"].iloc[0]\n    has_detection = current_detections > 0\n\n    # Create a simple gauge visualization\n    plt.axis(\"off\")\n\n    # Set colors\n    if has_detection:\n        color = \"#00ff41\"  # Matrix green for YES\n        status = \"YES\"\n        indicator_pos = 0.8  # Right position\n    else:\n        color = \"#ff0000\"  # Red for NO\n        status = \"NO\"\n        indicator_pos = 0.2  # Left position\n\n    # Draw background rectangle\n    background = plt.Rectangle(\n        (0.1, 0.3), 0.8, 0.2, facecolor=\"#1a1a1a\", edgecolor=\"#333333\", linewidth=2\n    )\n    plt.gca().add_patch(background)\n\n    # Draw indicator\n    indicator_width = 0.05\n    indicator = plt.Rectangle(\n        (indicator_pos - indicator_width / 2, 0.25),\n        indicator_width,\n        0.3,\n        facecolor=color,\n        edgecolor=None,\n    )\n    plt.gca().add_patch(indicator)\n\n    # Add tick marks\n    tick_positions = [0.2, 0.5, 0.8]  # NO, CENTER, YES\n    for x in tick_positions:\n        plt.plot([x, x], [0.3, 0.5], color=\"#444444\", linewidth=2)\n\n    # Add YES/NO labels\n    plt.text(\n        0.8,\n        0.2,\n        \"YES\",\n        color=\"#00ff41\",\n        fontsize=14,\n        ha=\"center\",\n        va=\"center\",\n        family=\"monospace\",\n    )\n    plt.text(\n        
0.2,\n        0.2,\n        \"NO\",\n        color=\"#ff0000\",\n        fontsize=14,\n        ha=\"center\",\n        va=\"center\",\n        family=\"monospace\",\n    )\n\n    # Add status box at top with detection keyword\n    plt.text(\n        0.5,\n        0.8,\n        f\"{detect_keyword.upper()} DETECTED?\",\n        color=color,\n        fontsize=16,\n        ha=\"center\",\n        va=\"center\",\n        family=\"monospace\",\n        bbox=dict(facecolor=\"#1a1a1a\", edgecolor=color, linewidth=2, pad=10),\n    )\n\n    # Add frame counter at bottom\n    plt.text(\n        0.5,\n        0.1,\n        f\"FRAME: {frame_num:04d}\",\n        color=\"#00ff41\",\n        fontsize=14,\n        ha=\"center\",\n        va=\"center\",\n        family=\"monospace\",\n    )\n\n    # Add subtle grid lines for depth\n    for x in np.linspace(0.2, 0.8, 7):\n        plt.plot([x, x], [0.3, 0.5], color=\"#222222\", linewidth=1, zorder=0)\n\n    # Add glow effect to indicator\n    for i in range(3):\n        glow = plt.Rectangle(\n            (indicator_pos - (indicator_width + i * 0.01) / 2, 0.25 - i * 0.01),\n            indicator_width + i * 0.01,\n            0.3 + i * 0.02,\n            facecolor=color,\n            alpha=0.1 / (i + 1),\n        )\n        plt.gca().add_patch(glow)\n\n    # Set consistent plot limits\n    plt.xlim(0, 1)\n    plt.ylim(0, 1)\n\n    # Save frame with 16:9 aspect ratio\n    frame_path = os.path.join(temp_dir, f\"gauge_{frame_num:05d}.png\")\n    plt.savefig(\n        frame_path,\n        bbox_inches=\"tight\",\n        dpi=100,\n        facecolor=\"black\",\n        edgecolor=\"none\",\n        pad_inches=0,\n    )\n    plt.close()\n\n    return frame_path\n\n\ndef create_video_visualization(json_path, style=\"timeline\"):\n    \"\"\"Create a video visualization of the detection data.\"\"\"\n    try:\n        if not json_path:\n            return None, \"No JSON file provided\"\n\n        if not os.path.exists(json_path):\n            
return None, f\"File not found: {json_path}\"\n\n        # Load and process data\n        result = create_frame_data(json_path)\n        if result is None:\n            return None, \"Failed to load detection data from JSON file\"\n\n        frame_data, metadata = result\n        if len(frame_data) == 0:\n            return None, \"No frame data found in JSON file\"\n\n        total_frames = metadata[\"total_frames\"]\n        detect_keyword = metadata.get(\n            \"detect_keyword\", \"OBJECT\"\n        )  # Get the detection keyword\n\n        # Create temporary directory for frames\n        with tempfile.TemporaryDirectory() as temp_dir:\n            max_y = frame_data[\"detections\"].max()\n\n            # Generate each frame\n            print(\"Generating frames...\")\n            frame_paths = []\n            with tqdm(total=total_frames, desc=\"Generating frames\") as pbar:\n                for frame in range(total_frames):\n                    try:\n                        if style == \"gauge\":\n                            frame_path = generate_gauge_frame(\n                                frame_data, frame, temp_dir, detect_keyword\n                            )\n                        else:  # default to timeline\n                            frame_path = generate_frame_image(\n                                frame_data, frame, temp_dir, max_y\n                            )\n                        if frame_path and os.path.exists(frame_path):\n                            frame_paths.append(frame_path)\n                        else:\n                            print(f\"Warning: Failed to generate frame {frame}\")\n                        pbar.update(1)\n                    except Exception as e:\n                        print(f\"Error generating frame {frame}: {str(e)}\")\n                        continue\n\n            if not frame_paths:\n                return None, \"Failed to generate any frames\"\n\n            # Create output video path\n   
         output_dir = os.path.join(\n                os.path.dirname(os.path.abspath(__file__)), \"outputs\"\n            )\n            os.makedirs(output_dir, exist_ok=True)\n            output_video = os.path.join(\n                output_dir, f\"detection_visualization_{style}.mp4\"\n            )\n\n            # Create temp output path\n            base, ext = os.path.splitext(output_video)\n            temp_output = f\"{base}_temp{ext}\"\n\n            # First pass: Create video with OpenCV VideoWriter\n            print(\"Creating initial video...\")\n            # Get frame size from first image\n            first_frame = cv2.imread(frame_paths[0])\n            height, width = first_frame.shape[:2]\n\n            out = cv2.VideoWriter(\n                temp_output,\n                cv2.VideoWriter_fourcc(*\"mp4v\"),\n                metadata[\"fps\"],\n                (width, height),\n            )\n\n            with tqdm(\n                total=total_frames, desc=\"Creating video\"\n            ) as pbar:  # Use total_frames here too\n                for frame_path in frame_paths:\n                    frame = cv2.imread(frame_path)\n                    out.write(frame)\n                    pbar.update(1)\n\n            out.release()\n\n            # Second pass: Convert to web-compatible format\n            print(\"Converting to web format...\")\n            try:\n                subprocess.run(\n                    [\n                        \"ffmpeg\",\n                        \"-y\",\n                        \"-i\",\n                        temp_output,\n                        \"-c:v\",\n                        \"libx264\",\n                        \"-preset\",\n                        \"medium\",\n                        \"-crf\",\n                        \"23\",\n                        \"-movflags\",\n                        \"+faststart\",  # Better web playback\n                        \"-loglevel\",\n                        \"error\",\n        
                output_video,\n                    ],\n                    check=True,\n                )\n\n                os.remove(temp_output)  # Remove the temporary file\n\n                if not os.path.exists(output_video):\n                    print(\n                        f\"Warning: FFmpeg completed but output file not found at {output_video}\"\n                    )\n                    return None, \"Failed to create video\"\n\n                # Return video path and stats\n                stats = f\"\"\"Video Stats:\nFPS: {metadata['fps']}\nTotal Frames: {metadata['total_frames']}\nDuration: {metadata['duration_sec']:.2f} seconds\nMax Detections in a Frame: {frame_data['detections'].max()}\nAverage Detections per Frame: {frame_data['detections'].mean():.2f}\"\"\"\n\n                return output_video, stats\n\n            except subprocess.CalledProcessError as e:\n                print(f\"Error running FFmpeg: {str(e)}\")\n                if os.path.exists(temp_output):\n                    os.remove(temp_output)\n                return None, f\"Error creating visualization: {str(e)}\"\n\n    except Exception as e:\n        print(f\"Error creating video visualization: {str(e)}\")\n        import traceback\n\n        traceback.print_exc()\n        return None, f\"Error creating visualization: {str(e)}\"\n"
  },
  {
    "path": "recipes/promptable-content-moderation/visualization.py",
    "content": "import pandas as pd\nimport matplotlib.pyplot as plt\nfrom persistence import load_detection_data\nimport argparse\n\n\ndef visualize_detections(json_path):\n    \"\"\"\n    Visualize detection data from a JSON file.\n\n    Args:\n        json_path (str): Path to the JSON file containing detection data.\n    \"\"\"\n    # Load the persisted JSON data\n    data = load_detection_data(json_path)\n    if not data:\n        return\n\n    # Convert the frame detections to a DataFrame\n    rows = []\n    for frame_data in data[\"frame_detections\"]:\n        frame = frame_data[\"frame\"]\n        timestamp = frame_data[\"timestamp\"]\n        for obj in frame_data[\"objects\"]:\n            rows.append(\n                {\n                    \"frame\": frame,\n                    \"timestamp\": timestamp,\n                    \"keyword\": obj[\"keyword\"],\n                    \"x1\": obj[\"bbox\"][0],\n                    \"y1\": obj[\"bbox\"][1],\n                    \"x2\": obj[\"bbox\"][2],\n                    \"y2\": obj[\"bbox\"][3],\n                    \"area\": (obj[\"bbox\"][2] - obj[\"bbox\"][0])\n                    * (obj[\"bbox\"][3] - obj[\"bbox\"][1]),\n                }\n            )\n\n    if not rows:\n        print(\"No detections found in the data\")\n        return\n\n    df = pd.DataFrame(rows)\n\n    # Create a figure with multiple subplots\n    fig = plt.figure(figsize=(15, 10))\n\n    # Plot 1: Number of detections per frame\n    plt.subplot(2, 2, 1)\n    detections_per_frame = df.groupby(\"frame\").size()\n    plt.plot(detections_per_frame.index, detections_per_frame.values)\n    plt.xlabel(\"Frame\")\n    plt.ylabel(\"Number of Detections\")\n    plt.title(\"Detections Per Frame\")\n\n    # Plot 2: Distribution of detection areas\n    plt.subplot(2, 2, 2)\n    df[\"area\"].hist(bins=30)\n    plt.xlabel(\"Detection Area (normalized)\")\n    plt.ylabel(\"Count\")\n    plt.title(\"Distribution of Detection Areas\")\n\n    # Plot 
3: Average detection area over time\n    plt.subplot(2, 2, 3)\n    avg_area = df.groupby(\"frame\")[\"area\"].mean()\n    plt.plot(avg_area.index, avg_area.values)\n    plt.xlabel(\"Frame\")\n    plt.ylabel(\"Average Detection Area\")\n    plt.title(\"Average Detection Area Over Time\")\n\n    # Plot 4: Heatmap of detection centers\n    plt.subplot(2, 2, 4)\n    df[\"center_x\"] = (df[\"x1\"] + df[\"x2\"]) / 2\n    df[\"center_y\"] = (df[\"y1\"] + df[\"y2\"]) / 2\n    plt.hist2d(df[\"center_x\"], df[\"center_y\"], bins=30)\n    plt.colorbar()\n    plt.xlabel(\"X Position\")\n    plt.ylabel(\"Y Position\")\n    plt.title(\"Detection Center Heatmap\")\n\n    # Adjust layout and display\n    plt.tight_layout()\n    plt.show()\n\n    # Print summary statistics\n    print(\"\\nSummary Statistics:\")\n    print(f\"Total frames analyzed: {len(data['frame_detections'])}\")\n    print(f\"Total detections: {len(df)}\")\n    print(\n        f\"Average detections per frame: {len(df) / len(data['frame_detections']):.2f}\"\n    )\n    print(f\"\\nVideo metadata:\")\n    for key, value in data[\"video_metadata\"].items():\n        print(f\"{key}: {value}\")\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Visualize object detection data\")\n    parser.add_argument(\n        \"json_file\", help=\"Path to the JSON file containing detection data\"\n    )\n    args = parser.parse_args()\n\n    visualize_detections(args.json_file)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "recipes/promptable-video-redaction/.gitignore",
    "content": "# Python\n__pycache__/\n*.py[cod]\n*$py.class\n*.so\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# Virtual Environment\nvenv/\nenv/\nENV/\n.venv/\n\n# IDE\n.idea/\n.vscode/\n*.swp\n*.swo\n\n# Project specific\ninputs/*\noutputs/*\n!inputs/.gitkeep\n!outputs/.gitkeep\ninputs/\noutputs/\n\n# Model files\n*.pth\n*.onnx\n*.pt\n\n# Logs\n*.log\n\ncertificate.pem"
  },
  {
    "path": "recipes/promptable-video-redaction/README.md",
    "content": "# Promptable Video Redaction with Moondream\n\nThis tool uses Moondream 2B, a powerful yet lightweight vision-language model, to detect and redact objects from videos. Moondream can recognize a wide variety of objects, people,\ntext, and more with high accuracy while being much smaller than traditional models.\n\n[Try it now.](https://huggingface.co/spaces/moondream/promptable-video-redaction)\n\n## About Moondream\n\nMoondream is a tiny yet powerful vision-language model that can analyze images and answer questions about them. It's designed to be lightweight and efficient while maintaining high\naccuracy. Some key features:\n\n- Only 2B parameters\n- Fast inference with minimal resource requirements\n- Supports CPU and GPU execution\n- Open source and free to use\n- Can detect almost anything you can describe in natural language\n\nLinks:\n\n- [GitHub Repository](https://github.com/vikhyat/moondream)\n- [Hugging Face](https://huggingface.co/vikhyatk/moondream2)\n- [Build with Moondream](http://docs.moondream.ai/)\n\n## Features\n\n- Real-time object detection in videos using Moondream\n- Multiple visualization styles:\n  - Censor: Black boxes over detected objects\n  - Bounding Box: Traditional bounding boxes with labels\n  - Hitmarker: Call of Duty style crosshair markers\n- Optional grid-based detection for improved accuracy\n- Flexible object type detection using natural language\n- Frame-by-frame processing with IoU-based merging\n- Batch processing of multiple videos\n- Web-compatible output format\n- User-friendly web interface\n- Command-line interface for automation\n\n## Requirements\n\n- Python 3.8+\n- OpenCV (cv2)\n- PyTorch\n- Transformers\n- Pillow (PIL)\n- tqdm\n- ffmpeg\n- numpy\n- gradio (for web interface)\n\n## Installation\n\n1. 
Clone this repository and create a new virtual environment\n\n```bash\ngit clone https://github.com/vikhyat/moondream.git\ncd moondream/recipes/promptable-video-redaction\npython -m venv .venv\nsource .venv/bin/activate\n```\n\n2. Install the required packages:\n\n```bash\npip install -r requirements.txt\n```\n\n3. Install ffmpeg:\n   - On Ubuntu/Debian: `sudo apt-get install ffmpeg libvips`\n   - On macOS: `brew install ffmpeg`\n   - On Windows: Download from [ffmpeg.org](https://ffmpeg.org/download.html)\n     > Downloading libvips for Windows requires some additional steps, see [here](https://docs.moondream.ai/quick-start)\n\n## Usage\n\n### Web Interface\n\n1. Start the web interface:\n\n```bash\npython app.py\n```\n\n2. Open the provided URL in your browser\n\n3. Use the interface to:\n   - Upload your video\n   - Specify what to censor (e.g., face, logo, text)\n   - Adjust processing speed and quality\n   - Configure grid size for detection\n   - Process and download the censored video\n\n### Command Line Interface\n\n1. Create an `inputs` directory in the same folder as the script:\n\n```bash\nmkdir inputs\n```\n\n2. Place your video files in the `inputs` directory. Supported formats:\n\n   - .mp4\n   - .avi\n   - .mov\n   - .mkv\n   - .webm\n\n3. Run the script:\n\n```bash\npython main.py\n```\n\n### Optional Arguments:\n\n- `--test`: Process only first 3 seconds of each video (useful for testing detection settings)\n\n```bash\npython main.py --test\n```\n\n- `--preset`: Choose FFmpeg encoding preset (affects output quality vs. 
speed)\n\n```bash\npython main.py --preset ultrafast  # Fastest, lower quality\npython main.py --preset veryslow   # Slowest, highest quality\n```\n\n- `--detect`: Specify what object type to detect (using natural language)\n\n```bash\npython main.py --detect person     # Detect people\npython main.py --detect \"red car\"  # Detect red cars\npython main.py --detect \"person wearing a hat\"  # Detect people with hats\n```\n\n- `--box-style`: Choose visualization style\n\n```bash\npython main.py --box-style censor     # Black boxes (default)\npython main.py --box-style bounding-box       # Bounding box-style boxes with labels\npython main.py --box-style hitmarker  # COD-style hitmarkers\n```\n\n- `--rows` and `--cols`: Enable grid-based detection by splitting frames\n\n```bash\npython main.py --rows 2 --cols 2   # Split each frame into 2x2 grid\npython main.py --rows 3 --cols 3   # Split each frame into 3x3 grid\n```\n\nYou can combine arguments:\n\n```bash\npython main.py --detect \"person wearing sunglasses\" --box-style bounding-box --test --preset \"fast\" --rows 2 --cols 2\n```\n\n### Visualization Styles\n\nThe tool supports three different visualization styles for detected objects:\n\n1. **Censor** (default)\n\n   - Places solid black rectangles over detected objects\n   - Best for privacy and content moderation\n   - Completely obscures the detected region\n\n2. **Bounding Box**\n\n   - Traditional object detection style\n   - Red bounding box around detected objects\n   - Label showing object type above the box\n   - Good for analysis and debugging\n\n3. 
**Hitmarker**\n   - Call of Duty inspired visualization\n   - White crosshair marker at center of detected objects\n   - Small label above the marker\n   - Stylistic choice for gaming-inspired visualization\n\nChoose the style that best fits your use case using the `--box-style` argument.\n\n## Output\n\nProcessed videos will be saved in the `outputs` directory with the format: `[style]_[object_type]_[original_filename].mp4`\n\nFor example:\n\n- `censor_face_video.mp4`\n- `bounding-box_person_video.mp4`\n- `hitmarker_car_video.mp4`\n\nThe output videos will include:\n\n- Original video content\n- Selected visualization style for detected objects\n- Web-compatible H.264 encoding\n\n## Notes\n\n- Processing time depends on video length, grid size, and GPU availability\n- GPU is strongly recommended for faster processing\n- Requires sufficient disk space for temporary files\n- Detection quality varies based on video quality and Moondream's ability to recognize the specified object\n- Grid-based detection impacts performance significantly - use only when needed\n- Web interface shows progress updates and errors\n- Choose visualization style based on your use case\n- Moondream can detect almost anything you can describe in natural language\n"
  },
  {
    "path": "recipes/promptable-video-redaction/app.py",
    "content": "#!/usr/bin/env python3\nimport gradio as gr\nimport os\nfrom main import load_moondream, process_video\nimport shutil\nimport torch\n\n# Get absolute path to workspace root\nWORKSPACE_ROOT = os.path.dirname(os.path.abspath(__file__))\n\n# Check CUDA availability\nprint(f\"Is CUDA available: {torch.cuda.is_available()}\")\n# We want to get True\nprint(f\"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}\")\n# GPU Name\n\n# Initialize model globally for reuse\nprint(\"Loading Moondream model...\")\nmodel, tokenizer = load_moondream()\n\n\ndef process_video_file(\n    video_file, detect_keyword, box_style, ffmpeg_preset, rows, cols, test_mode\n):\n    \"\"\"Process a video file through the Gradio interface.\"\"\"\n    try:\n        if not video_file:\n            raise gr.Error(\"Please upload a video file\")\n\n        # Ensure input/output directories exist using absolute paths\n        inputs_dir = os.path.join(WORKSPACE_ROOT, \"inputs\")\n        outputs_dir = os.path.join(WORKSPACE_ROOT, \"outputs\")\n        os.makedirs(inputs_dir, exist_ok=True)\n        os.makedirs(outputs_dir, exist_ok=True)\n\n        # Copy uploaded video to inputs directory\n        video_filename = f\"input_{os.path.basename(video_file)}\"\n        input_video_path = os.path.join(inputs_dir, video_filename)\n        shutil.copy2(video_file, input_video_path)\n\n        try:\n            # Process the video\n            output_path = process_video(\n                input_video_path,\n                detect_keyword,\n                test_mode=test_mode,\n                ffmpeg_preset=ffmpeg_preset,\n                rows=rows,\n                cols=cols,\n                box_style=box_style,\n            )\n\n            # Verify output exists and is readable\n            if not output_path or not os.path.exists(output_path):\n                print(f\"Warning: Output path {output_path} does not exist\")\n                # Try to find the output based on 
expected naming convention\n                expected_output = os.path.join(\n                    outputs_dir, f\"{box_style}_{detect_keyword}_{video_filename}\"\n                )\n                if os.path.exists(expected_output):\n                    output_path = expected_output\n                else:\n                    # Try searching in outputs directory for any matching file\n                    matching_files = [\n                        f\n                        for f in os.listdir(outputs_dir)\n                        if f.startswith(f\"{box_style}_{detect_keyword}_\")\n                    ]\n                    if matching_files:\n                        output_path = os.path.join(outputs_dir, matching_files[0])\n                    else:\n                        raise gr.Error(\"Failed to locate output video\")\n\n            # Convert output path to absolute path if it isn't already\n            if not os.path.isabs(output_path):\n                output_path = os.path.join(WORKSPACE_ROOT, output_path)\n\n            print(f\"Returning output path: {output_path}\")\n            return output_path\n\n        finally:\n            # Clean up input file\n            try:\n                if os.path.exists(input_video_path):\n                    os.remove(input_video_path)\n            except:\n                pass\n\n    except Exception as e:\n        print(f\"Error in process_video_file: {str(e)}\")\n        raise gr.Error(f\"Error processing video: {str(e)}\")\n\n\n# Create the Gradio interface\nwith gr.Blocks(title=\"Promptable Video Redaction\") as app:\n    gr.Markdown(\"# Promptable Video Redaction with Moondream\")\n    gr.Markdown(\n        \"\"\"\n    [Moondream 2B](https://github.com/vikhyat/moondream) is a lightweight vision model that detects and visualizes objects in videos. It can identify objects, people, text and more.\n\n    Upload a video and specify what to detect. 
The app will process each frame and apply your chosen visualization style. For help, join the [Moondream Discord](https://discord.com/invite/tRUdpjDQfH).\n    \"\"\"\n    )\n\n    with gr.Row():\n        with gr.Column():\n            # Input components\n            video_input = gr.Video(label=\"Upload Video\")\n            detect_input = gr.Textbox(\n                label=\"What to Detect\",\n                placeholder=\"e.g. face, logo, text, person, car, dog, etc.\",\n                value=\"face\",\n                info=\"Moondream can detect anything that you can describe in natural language\",\n            )\n            process_btn = gr.Button(\"Process Video\", variant=\"primary\")\n\n            with gr.Accordion(\"Advanced Settings\", open=False):\n                box_style_input = gr.Radio(\n                    choices=[\"censor\", \"bounding-box\", \"hitmarker\"],\n                    value=\"censor\",\n                    label=\"Visualization Style\",\n                    info=\"Choose how to display detections\",\n                )\n                preset_input = gr.Dropdown(\n                    choices=[\n                        \"ultrafast\",\n                        \"superfast\",\n                        \"veryfast\",\n                        \"faster\",\n                        \"fast\",\n                        \"medium\",\n                        \"slow\",\n                        \"slower\",\n                        \"veryslow\",\n                    ],\n                    value=\"medium\",\n                    label=\"Processing Speed (faster = lower quality)\",\n                )\n                with gr.Row():\n                    rows_input = gr.Slider(\n                        minimum=1, maximum=4, value=1, step=1, label=\"Grid Rows\"\n                    )\n                    cols_input = gr.Slider(\n                        minimum=1, maximum=4, value=1, step=1, label=\"Grid Columns\"\n                    )\n\n                
test_mode_input = gr.Checkbox(\n                    label=\"Test Mode (Process first 3 seconds only)\",\n                    value=True,\n                    info=\"Enable to quickly test settings on a short clip before processing the full video (recommended)\",\n                )\n\n                gr.Markdown(\n                    \"\"\"\n                Note: Processing in test mode will only process the first 3 seconds of the video and is recommended for testing settings.\n                \"\"\"\n                )\n\n                gr.Markdown(\n                    \"\"\"\n                We can get a rough estimate of how long the video will take to process by multiplying the videos framerate * seconds * the number of rows and columns and assuming 0.12 seconds processing time per detection.\n                For example, a 3 second video at 30fps with 2x2 grid, the estimated time is 3 * 30 * 2 * 2 * 0.12 = 43.2 seconds (tested on a 4090 GPU).\n                \"\"\"\n                )\n\n        with gr.Column():\n            # Output components\n            video_output = gr.Video(label=\"Processed Video\")\n\n            # About section under the video output\n            gr.Markdown(\n                \"\"\"\n            ### Links:\n            - [GitHub Repository](https://github.com/vikhyat/moondream)\n            - [Hugging Face](https://huggingface.co/vikhyatk/moondream2)\n            - [Python Package](https://pypi.org/project/moondream/)\n            - [Moondream Recipes](https://docs.moondream.ai/recipes)\n            \"\"\"\n            )\n\n    # Event handlers\n    process_btn.click(\n        fn=process_video_file,\n        inputs=[\n            video_input,\n            detect_input,\n            box_style_input,\n            preset_input,\n            rows_input,\n            cols_input,\n            test_mode_input,\n        ],\n        outputs=video_output,\n    )\n\nif __name__ == \"__main__\":\n    app.launch(share=True)\n"
  },
  {
    "path": "recipes/promptable-video-redaction/main.py",
    "content": "#!/usr/bin/env python3\nimport cv2, os, subprocess, argparse\nfrom PIL import Image\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\nfrom tqdm import tqdm\nimport numpy as np\n\n# Constants\nTEST_MODE_DURATION = 3  # Process only first 3 seconds in test mode\nFFMPEG_PRESETS = [\n    \"ultrafast\",\n    \"superfast\",\n    \"veryfast\",\n    \"faster\",\n    \"fast\",\n    \"medium\",\n    \"slow\",\n    \"slower\",\n    \"veryslow\",\n]\nFONT = cv2.FONT_HERSHEY_SIMPLEX  # Font for bounding-box-style labels\n\n# Detection parameters\nIOU_THRESHOLD = 0.5  # IoU threshold for considering boxes related\n\n# Hitmarker parameters\nHITMARKER_SIZE = 20  # Size of the hitmarker in pixels\nHITMARKER_GAP = 3  # Size of the empty space in the middle (reduced from 8)\nHITMARKER_THICKNESS = 2  # Thickness of hitmarker lines\nHITMARKER_COLOR = (255, 255, 255)  # White color for hitmarker\nHITMARKER_SHADOW_COLOR = (80, 80, 80)  # Lighter gray for shadow effect\nHITMARKER_SHADOW_OFFSET = 1  # Smaller shadow offset\n\n\ndef load_moondream():\n    \"\"\"Load Moondream model and tokenizer.\"\"\"\n    model = AutoModelForCausalLM.from_pretrained(\n        \"vikhyatk/moondream2\", trust_remote_code=True, device_map={\"\": \"cuda\"}\n    )\n    tokenizer = AutoTokenizer.from_pretrained(\"vikhyatk/moondream2\")\n    return model, tokenizer\n\n\ndef get_video_properties(video_path):\n    \"\"\"Get basic video properties.\"\"\"\n    video = cv2.VideoCapture(video_path)\n    fps = video.get(cv2.CAP_PROP_FPS)\n    frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))\n    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))\n    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))\n    video.release()\n    return {\"fps\": fps, \"frame_count\": frame_count, \"width\": width, \"height\": height}\n\n\ndef is_valid_box(box):\n    \"\"\"Check if box coordinates are reasonable.\"\"\"\n    x1, y1, x2, y2 = box\n    width = x2 - x1\n    height = y2 - y1\n\n    # Reject boxes 
that are too large (over 90% of frame in both dimensions)\n    if width > 0.9 and height > 0.9:\n        return False\n\n    # Reject boxes that are too small (less than 1% of frame)\n    if width < 0.01 or height < 0.01:\n        return False\n\n    return True\n\n\ndef split_frame_into_tiles(frame, rows, cols):\n    \"\"\"Split a frame into a grid of tiles.\"\"\"\n    height, width = frame.shape[:2]\n    tile_height = height // rows\n    tile_width = width // cols\n    tiles = []\n    tile_positions = []\n\n    for i in range(rows):\n        for j in range(cols):\n            y1 = i * tile_height\n            y2 = (i + 1) * tile_height if i < rows - 1 else height\n            x1 = j * tile_width\n            x2 = (j + 1) * tile_width if j < cols - 1 else width\n\n            tile = frame[y1:y2, x1:x2]\n            tiles.append(tile)\n            tile_positions.append((x1, y1, x2, y2))\n\n    return tiles, tile_positions\n\n\ndef convert_tile_coords_to_frame(box, tile_pos, frame_shape):\n    \"\"\"Convert coordinates from tile space to frame space.\"\"\"\n    frame_height, frame_width = frame_shape[:2]\n    tile_x1, tile_y1, tile_x2, tile_y2 = tile_pos\n    tile_width = tile_x2 - tile_x1\n    tile_height = tile_y2 - tile_y1\n\n    x1_tile_abs = box[0] * tile_width\n    y1_tile_abs = box[1] * tile_height\n    x2_tile_abs = box[2] * tile_width\n    y2_tile_abs = box[3] * tile_height\n\n    x1_frame_abs = tile_x1 + x1_tile_abs\n    y1_frame_abs = tile_y1 + y1_tile_abs\n    x2_frame_abs = tile_x1 + x2_tile_abs\n    y2_frame_abs = tile_y1 + y2_tile_abs\n\n    x1_norm = x1_frame_abs / frame_width\n    y1_norm = y1_frame_abs / frame_height\n    x2_norm = x2_frame_abs / frame_width\n    y2_norm = y2_frame_abs / frame_height\n\n    x1_norm = max(0.0, min(1.0, x1_norm))\n    y1_norm = max(0.0, min(1.0, y1_norm))\n    x2_norm = max(0.0, min(1.0, x2_norm))\n    y2_norm = max(0.0, min(1.0, y2_norm))\n\n    return [x1_norm, y1_norm, x2_norm, y2_norm]\n\n\ndef 
merge_tile_detections(tile_detections, iou_threshold=0.5):\n    \"\"\"Merge detections from different tiles using NMS-like approach.\"\"\"\n    if not tile_detections:\n        return []\n\n    all_boxes = []\n    all_keywords = []\n\n    # Collect all boxes and their keywords\n    for detections in tile_detections:\n        for box, keyword in detections:\n            all_boxes.append(box)\n            all_keywords.append(keyword)\n\n    if not all_boxes:\n        return []\n\n    # Convert to numpy for easier processing\n    boxes = np.array(all_boxes)\n\n    # Calculate areas\n    x1 = boxes[:, 0]\n    y1 = boxes[:, 1]\n    x2 = boxes[:, 2]\n    y2 = boxes[:, 3]\n    areas = (x2 - x1) * (y2 - y1)\n\n    # Sort boxes by area\n    order = areas.argsort()[::-1]\n\n    keep = []\n    while order.size > 0:\n        i = order[0]\n        keep.append(i)\n\n        if order.size == 1:\n            break\n\n        # Calculate IoU with rest of boxes\n        xx1 = np.maximum(x1[i], x1[order[1:]])\n        yy1 = np.maximum(y1[i], y1[order[1:]])\n        xx2 = np.minimum(x2[i], x2[order[1:]])\n        yy2 = np.minimum(y2[i], y2[order[1:]])\n\n        w = np.maximum(0.0, xx2 - xx1)\n        h = np.maximum(0.0, yy2 - yy1)\n        inter = w * h\n\n        ovr = inter / (areas[i] + areas[order[1:]] - inter)\n\n        # Get indices of boxes with IoU less than threshold\n        inds = np.where(ovr <= iou_threshold)[0]\n        order = order[inds + 1]\n\n    return [(all_boxes[i], all_keywords[i]) for i in keep]\n\n\ndef detect_ads_in_frame(model, tokenizer, image, detect_keyword, rows=1, cols=1):\n    \"\"\"Detect objects in a frame using grid-based detection.\"\"\"\n    if rows == 1 and cols == 1:\n        return detect_ads_in_frame_single(model, tokenizer, image, detect_keyword)\n\n    # Convert numpy array to PIL Image if needed\n    if not isinstance(image, Image.Image):\n        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n\n    # Split frame into tiles\n    tiles, 
tile_positions = split_frame_into_tiles(image, rows, cols)\n\n    # Process each tile\n    tile_detections = []\n    for tile, tile_pos in zip(tiles, tile_positions):\n        # Convert tile to PIL Image\n        tile_pil = Image.fromarray(tile)\n\n        # Detect objects in tile\n        response = model.detect(tile_pil, detect_keyword)\n\n        if response and \"objects\" in response and response[\"objects\"]:\n            objects = response[\"objects\"]\n            tile_objects = []\n\n            for obj in objects:\n                if all(k in obj for k in [\"x_min\", \"y_min\", \"x_max\", \"y_max\"]):\n                    box = [obj[\"x_min\"], obj[\"y_min\"], obj[\"x_max\"], obj[\"y_max\"]]\n\n                    if is_valid_box(box):\n                        # Convert tile coordinates to frame coordinates\n                        frame_box = convert_tile_coords_to_frame(\n                            box, tile_pos, image.shape\n                        )\n                        tile_objects.append((frame_box, detect_keyword))\n\n            if tile_objects:  # Only append if we found valid objects\n                tile_detections.append(tile_objects)\n\n    # Merge detections from all tiles\n    merged_detections = merge_tile_detections(tile_detections)\n    return merged_detections\n\n\ndef detect_ads_in_frame_single(model, tokenizer, image, detect_keyword):\n    \"\"\"Single-frame detection function.\"\"\"\n    detected_objects = []\n\n    # Convert numpy array to PIL Image if needed\n    if not isinstance(image, Image.Image):\n        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))\n\n    # Detect objects\n    response = model.detect(image, detect_keyword)\n\n    # Check if we have valid objects\n    if response and \"objects\" in response and response[\"objects\"]:\n        objects = response[\"objects\"]\n\n        for obj in objects:\n            if all(k in obj for k in [\"x_min\", \"y_min\", \"x_max\", \"y_max\"]):\n               
 box = [obj[\"x_min\"], obj[\"y_min\"], obj[\"x_max\"], obj[\"y_max\"]]\n                # If box is valid (not full-frame), add it\n                if is_valid_box(box):\n                    detected_objects.append((box, detect_keyword))\n\n    return detected_objects\n\n\ndef draw_hitmarker(\n    frame, center_x, center_y, size=HITMARKER_SIZE, color=HITMARKER_COLOR, shadow=True\n):\n    \"\"\"Draw a COD-style hitmarker cross with more space in the middle.\"\"\"\n    half_size = size // 2\n\n    # Draw shadow first if enabled\n    if shadow:\n        # Top-left to center shadow\n        cv2.line(\n            frame,\n            (\n                center_x - half_size + HITMARKER_SHADOW_OFFSET,\n                center_y - half_size + HITMARKER_SHADOW_OFFSET,\n            ),\n            (\n                center_x - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n                center_y - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n            ),\n            HITMARKER_SHADOW_COLOR,\n            HITMARKER_THICKNESS,\n        )\n        # Top-right to center shadow\n        cv2.line(\n            frame,\n            (\n                center_x + half_size + HITMARKER_SHADOW_OFFSET,\n                center_y - half_size + HITMARKER_SHADOW_OFFSET,\n            ),\n            (\n                center_x + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n                center_y - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n            ),\n            HITMARKER_SHADOW_COLOR,\n            HITMARKER_THICKNESS,\n        )\n        # Bottom-left to center shadow\n        cv2.line(\n            frame,\n            (\n                center_x - half_size + HITMARKER_SHADOW_OFFSET,\n                center_y + half_size + HITMARKER_SHADOW_OFFSET,\n            ),\n            (\n                center_x - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n                center_y + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n            ),\n            HITMARKER_SHADOW_COLOR,\n            
HITMARKER_THICKNESS,\n        )\n        # Bottom-right to center shadow\n        cv2.line(\n            frame,\n            (\n                center_x + half_size + HITMARKER_SHADOW_OFFSET,\n                center_y + half_size + HITMARKER_SHADOW_OFFSET,\n            ),\n            (\n                center_x + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n                center_y + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET,\n            ),\n            HITMARKER_SHADOW_COLOR,\n            HITMARKER_THICKNESS,\n        )\n\n    # Draw main hitmarker\n    # Top-left to center\n    cv2.line(\n        frame,\n        (center_x - half_size, center_y - half_size),\n        (center_x - HITMARKER_GAP, center_y - HITMARKER_GAP),\n        color,\n        HITMARKER_THICKNESS,\n    )\n    # Top-right to center\n    cv2.line(\n        frame,\n        (center_x + half_size, center_y - half_size),\n        (center_x + HITMARKER_GAP, center_y - HITMARKER_GAP),\n        color,\n        HITMARKER_THICKNESS,\n    )\n    # Bottom-left to center\n    cv2.line(\n        frame,\n        (center_x - half_size, center_y + half_size),\n        (center_x - HITMARKER_GAP, center_y + HITMARKER_GAP),\n        color,\n        HITMARKER_THICKNESS,\n    )\n    # Bottom-right to center\n    cv2.line(\n        frame,\n        (center_x + half_size, center_y + half_size),\n        (center_x + HITMARKER_GAP, center_y + HITMARKER_GAP),\n        color,\n        HITMARKER_THICKNESS,\n    )\n\n\ndef draw_ad_boxes(frame, detected_objects, detect_keyword, box_style=\"censor\"):\n    \"\"\"Draw detection visualizations over detected objects.\n\n    Args:\n        frame: The video frame to draw on\n        detected_objects: List of (box, keyword) tuples\n        detect_keyword: The detection keyword\n        box_style: Visualization style ('censor', 'bounding-box', or 'hitmarker')\n    \"\"\"\n    height, width = frame.shape[:2]\n\n    for box, keyword in detected_objects:\n        try:\n            # Convert 
normalized coordinates to pixel coordinates\n            x1 = int(box[0] * width)\n            y1 = int(box[1] * height)\n            x2 = int(box[2] * width)\n            y2 = int(box[3] * height)\n\n            # Ensure coordinates are within frame boundaries\n            x1 = max(0, min(x1, width - 1))\n            y1 = max(0, min(y1, height - 1))\n            x2 = max(0, min(x2, width - 1))\n            y2 = max(0, min(y2, height - 1))\n\n            # Only draw if box has reasonable size\n            if x2 > x1 and y2 > y1:\n                if box_style == \"censor\":\n                    # Draw solid black rectangle\n                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 0), -1)\n                elif box_style == \"bounding-box\":\n                    # Draw red rectangle with thicker line\n                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 3)\n\n                    # Add label with background\n                    label = detect_keyword  # Use exact capitalization\n                    label_size = cv2.getTextSize(label, FONT, 0.7, 2)[0]\n                    cv2.rectangle(\n                        frame, (x1, y1 - 25), (x1 + label_size[0], y1), (0, 0, 255), -1\n                    )\n                    cv2.putText(\n                        frame,\n                        label,\n                        (x1, y1 - 6),\n                        FONT,\n                        0.7,\n                        (255, 255, 255),\n                        2,\n                        cv2.LINE_AA,\n                    )\n                elif box_style == \"hitmarker\":\n                    # Calculate center of the box\n                    center_x = (x1 + x2) // 2\n                    center_y = (y1 + y2) // 2\n\n                    # Draw hitmarker at the center\n                    draw_hitmarker(frame, center_x, center_y)\n\n                    # Optional: Add small label above hitmarker\n                    label = detect_keyword  # 
Use exact capitalization\n                    label_size = cv2.getTextSize(label, FONT, 0.5, 1)[0]\n                    cv2.putText(\n                        frame,\n                        label,\n                        (center_x - label_size[0] // 2, center_y - HITMARKER_SIZE - 5),\n                        FONT,\n                        0.5,\n                        HITMARKER_COLOR,\n                        1,\n                        cv2.LINE_AA,\n                    )\n        except Exception as e:\n            print(f\"Error drawing {box_style} style box: {str(e)}\")\n\n    return frame\n\n\ndef filter_temporal_outliers(detections_dict):\n    \"\"\"Filter out extremely large detections that take up most of the frame.\n    Only keeps detections that are reasonable in size.\n\n    Args:\n        detections_dict: Dictionary of {frame_number: [(box, keyword), ...]}\n    \"\"\"\n    filtered_detections = {}\n\n    for t, detections in detections_dict.items():\n        # Only keep detections that aren't too large\n        valid_detections = []\n        for box, keyword in detections:\n            # Calculate box size as percentage of frame\n            width = box[2] - box[0]\n            height = box[3] - box[1]\n            area = width * height\n\n            # If box is less than 90% of frame, keep it\n            if area < 0.9:\n                valid_detections.append((box, keyword))\n\n        if valid_detections:\n            filtered_detections[t] = valid_detections\n\n    return filtered_detections\n\n\ndef describe_frames(\n    video_path, model, tokenizer, detect_keyword, test_mode=False, rows=1, cols=1\n):\n    \"\"\"Extract and detect objects in frames.\"\"\"\n    props = get_video_properties(video_path)\n    fps = props[\"fps\"]\n\n    # If in test mode, only process first 3 seconds\n    if test_mode:\n        frame_count = min(int(fps * TEST_MODE_DURATION), props[\"frame_count\"])\n    else:\n        frame_count = props[\"frame_count\"]\n\n    
ad_detections = {}  # Store detection results by frame number\n\n    print(\"Extracting frames and detecting objects...\")\n    video = cv2.VideoCapture(video_path)\n\n    # Process every frame\n    frame_count_processed = 0\n    with tqdm(total=frame_count) as pbar:\n        while frame_count_processed < frame_count:\n            ret, frame = video.read()\n            if not ret:\n                break\n\n            # Detect objects in the frame\n            detected_objects = detect_ads_in_frame(\n                model, tokenizer, frame, detect_keyword, rows=rows, cols=cols\n            )\n\n            # Store results for every frame, even if empty\n            ad_detections[frame_count_processed] = detected_objects\n\n            frame_count_processed += 1\n            pbar.update(1)\n\n    video.release()\n\n    if frame_count_processed == 0:\n        print(\"No frames could be read from video\")\n        return {}\n\n    # Filter out only extremely large detections\n    ad_detections = filter_temporal_outliers(ad_detections)\n    return ad_detections\n\n\ndef create_detection_video(\n    video_path,\n    ad_detections,\n    detect_keyword,\n    output_path=None,\n    ffmpeg_preset=\"medium\",\n    test_mode=False,\n    box_style=\"censor\",\n):\n    \"\"\"Create video with detection boxes.\"\"\"\n    if output_path is None:\n        # Create outputs directory if it doesn't exist\n        outputs_dir = os.path.join(\n            os.path.dirname(os.path.abspath(__file__)), \"outputs\"\n        )\n        os.makedirs(outputs_dir, exist_ok=True)\n\n        # Clean the detect_keyword for filename\n        safe_keyword = \"\".join(\n            x for x in detect_keyword if x.isalnum() or x in (\" \", \"_\", \"-\")\n        )\n        safe_keyword = safe_keyword.replace(\" \", \"_\")\n\n        # Create output filename\n        base_name = os.path.splitext(os.path.basename(video_path))[0]\n        output_path = os.path.join(\n            outputs_dir, 
f\"{box_style}_{safe_keyword}_{base_name}.mp4\"\n        )\n\n    print(f\"Will save output to: {output_path}\")\n\n    props = get_video_properties(video_path)\n    fps, width, height = props[\"fps\"], props[\"width\"], props[\"height\"]\n\n    # If in test mode, only process first few seconds\n    if test_mode:\n        frame_count = min(int(fps * TEST_MODE_DURATION), props[\"frame_count\"])\n    else:\n        frame_count = props[\"frame_count\"]\n\n    video = cv2.VideoCapture(video_path)\n\n    # Create temp output path by adding _temp before the extension\n    base, ext = os.path.splitext(output_path)\n    temp_output = f\"{base}_temp{ext}\"\n\n    out = cv2.VideoWriter(\n        temp_output, cv2.VideoWriter_fourcc(*\"mp4v\"), fps, (width, height)\n    )\n\n    print(\"Creating detection video...\")\n    frame_count_processed = 0\n\n    with tqdm(total=frame_count) as pbar:\n        while frame_count_processed < frame_count:\n            ret, frame = video.read()\n            if not ret:\n                break\n\n            # Get detections for this exact frame\n            if frame_count_processed in ad_detections:\n                current_detections = ad_detections[frame_count_processed]\n                if current_detections:\n                    frame = draw_ad_boxes(\n                        frame, current_detections, detect_keyword, box_style=box_style\n                    )\n\n            out.write(frame)\n            frame_count_processed += 1\n            pbar.update(1)\n\n    video.release()\n    out.release()\n\n    # Convert to web-compatible format more efficiently\n    try:\n        subprocess.run(\n            [\n                \"ffmpeg\",\n                \"-y\",\n                \"-i\",\n                temp_output,\n                \"-c:v\",\n                \"libx264\",\n                \"-preset\",\n                ffmpeg_preset,\n                \"-crf\",\n                \"23\",\n                \"-movflags\",\n                
\"+faststart\",  # Better web playback\n                \"-loglevel\",\n                \"error\",\n                output_path,\n            ],\n            check=True,\n        )\n\n        os.remove(temp_output)  # Remove the temporary file\n\n        if not os.path.exists(output_path):\n            print(\n                f\"Warning: FFmpeg completed but output file not found at {output_path}\"\n            )\n            return None\n\n        return output_path\n\n    except subprocess.CalledProcessError as e:\n        print(f\"Error running FFmpeg: {str(e)}\")\n        if os.path.exists(temp_output):\n            os.remove(temp_output)\n        return None\n\n\ndef process_video(\n    video_path,\n    detect_keyword,\n    test_mode=False,\n    ffmpeg_preset=\"medium\",\n    rows=1,\n    cols=1,\n    box_style=\"censor\",\n):\n    \"\"\"Process a single video file.\"\"\"\n    print(f\"\\nProcessing: {video_path}\")\n    print(f\"Looking for: {detect_keyword}\")\n\n    # Load model\n    print(\"Loading Moondream model...\")\n    model, tokenizer = load_moondream()\n\n    # Process video - detect objects\n    ad_detections = describe_frames(\n        video_path, model, tokenizer, detect_keyword, test_mode, rows, cols\n    )\n\n    # Create video with detection boxes\n    output_path = create_detection_video(\n        video_path,\n        ad_detections,\n        detect_keyword,\n        ffmpeg_preset=ffmpeg_preset,\n        test_mode=test_mode,\n        box_style=box_style,\n    )\n\n    if output_path is None:\n        print(\"\\nError: Failed to create output video\")\n        return None\n\n    print(f\"\\nOutput saved to: {output_path}\")\n    return output_path\n\n\ndef main():\n    \"\"\"Process all videos in the inputs directory.\"\"\"\n    parser = argparse.ArgumentParser(\n        description=\"Detect objects in videos using Moondream2\"\n    )\n    parser.add_argument(\n        \"--test\", action=\"store_true\", help=\"Process only first 3 seconds of 
each video\"\n    )\n    parser.add_argument(\n        \"--preset\",\n        choices=FFMPEG_PRESETS,\n        default=\"medium\",\n        help=\"FFmpeg encoding preset (default: medium). Faster presets = lower quality\",\n    )\n    parser.add_argument(\n        \"--detect\",\n        type=str,\n        default=\"face\",\n        help='Object to detect in the video (default: face, use --detect \"thing to detect\" to override)',\n    )\n    parser.add_argument(\n        \"--rows\",\n        type=int,\n        default=1,\n        help=\"Number of rows to split each frame into (default: 1)\",\n    )\n    parser.add_argument(\n        \"--cols\",\n        type=int,\n        default=1,\n        help=\"Number of columns to split each frame into (default: 1)\",\n    )\n    parser.add_argument(\n        \"--box-style\",\n        choices=[\"censor\", \"bounding-box\", \"hitmarker\"],\n        default=\"censor\",\n        help=\"Style of detection visualization (default: censor)\",\n    )\n    args = parser.parse_args()\n\n    input_dir = \"inputs\"\n    os.makedirs(input_dir, exist_ok=True)\n    os.makedirs(\"outputs\", exist_ok=True)\n\n    video_files = [\n        f\n        for f in os.listdir(input_dir)\n        if f.lower().endswith((\".mp4\", \".avi\", \".mov\", \".mkv\", \".webm\"))\n    ]\n\n    if not video_files:\n        print(\"No video files found in 'inputs' directory\")\n        return\n\n    print(f\"Found {len(video_files)} videos to process\")\n    print(f\"Will detect: {args.detect}\")\n    if args.test:\n        print(\"Running in test mode - processing only first 3 seconds of each video\")\n    print(f\"Using FFmpeg preset: {args.preset}\")\n    print(f\"Grid size: {args.rows}x{args.cols}\")\n    print(f\"Box style: {args.box_style}\")\n\n    success_count = 0\n    for video_file in video_files:\n        video_path = os.path.join(input_dir, video_file)\n        output_path = process_video(\n            video_path,\n            args.detect,\n           
 test_mode=args.test,\n            ffmpeg_preset=args.preset,\n            rows=args.rows,\n            cols=args.cols,\n            box_style=args.box_style,\n        )\n        if output_path:\n            success_count += 1\n\n    print(\n        f\"\\nProcessing complete. Successfully processed {success_count} out of {len(video_files)} videos.\"\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "recipes/promptable-video-redaction/packages.txt",
    "content": "libvips\nffmpeg"
  },
  {
    "path": "recipes/promptable-video-redaction/requirements.txt",
    "content": "gradio>=4.0.0\ntorch\ntransformers\nopencv-python\npillow\nnumpy\ntqdm\nffmpeg-python\neinops\npyvips\naccelerate"
  },
  {
    "path": "requirements.txt",
    "content": "torch==2.8.0\nPillow-SIMD==9.5.0.post2\ntransformers==4.56.1\npyvips-binary==8.16.0\npyvips==2.2.3\naccelerate==1.10.1\ngradio==4.38.1\n\n# Needed for running evals\ndatasets==3.2.0\neditdistance==0.8.1\n"
  },
  {
    "path": "sample.py",
    "content": "import argparse\nfrom queue import Queue\nfrom threading import Thread\n\nimport torch\nfrom PIL import Image\nfrom transformers import AutoTokenizer, TextIteratorStreamer\n\nfrom moondream.hf import LATEST_REVISION, Moondream, detect_device\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--image\", type=str, required=True)\n    parser.add_argument(\"--prompt\", type=str, required=False)\n    parser.add_argument(\"--caption\", action=\"store_true\")\n    parser.add_argument(\"--cpu\", action=\"store_true\")\n    args = parser.parse_args()\n\n    if args.cpu:\n        device = torch.device(\"cpu\")\n        dtype = torch.float32\n    else:\n        device, dtype = detect_device()\n        if device != torch.device(\"cpu\"):\n            print(\"Using device:\", device)\n            print(\"If you run into issues, pass the `--cpu` flag to this script.\")\n            print()\n\n    image_path = args.image\n    prompt = args.prompt\n\n    model_id = \"vikhyatk/moondream2\"\n    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)\n    moondream = Moondream.from_pretrained(\n        model_id,\n        revision=LATEST_REVISION,\n        torch_dtype=dtype,\n    ).to(device=device)\n    moondream.eval()\n\n    image = Image.open(image_path)\n\n    if args.caption:\n        print(moondream.caption(images=[image], tokenizer=tokenizer)[0])\n    else:\n        image_embeds = moondream.encode_image(image)\n\n        if prompt is None:\n            chat_history = \"\"\n\n            while True:\n                question = input(\"> \")\n\n                result_queue = Queue()\n\n                streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)\n\n                # Separate direct arguments from keyword arguments\n                thread_args = (image_embeds, question, tokenizer, chat_history)\n                thread_kwargs = {\"streamer\": streamer, \"result_queue\": 
result_queue}\n\n                thread = Thread(\n                    target=moondream.answer_question,\n                    args=thread_args,\n                    kwargs=thread_kwargs,\n                )\n                thread.start()\n\n                buffer = \"\"\n                for new_text in streamer:\n                    buffer += new_text\n                    if not new_text.endswith(\"<\") and not new_text.endswith(\"END\"):\n                        print(buffer, end=\"\", flush=True)\n                        buffer = \"\"\n                print(buffer)\n\n                thread.join()\n\n                answer = result_queue.get()\n                chat_history += f\"Question: {question}\\n\\nAnswer: {answer}\\n\\n\"\n        else:\n            print(\">\", prompt)\n            answer = moondream.answer_question(image_embeds, prompt, tokenizer)\n            print(answer)\n"
  },
  {
    "path": "tests/test_image_crops.py",
    "content": "import numpy as np\nimport torch\nfrom moondream.torch.image_crops import overlap_crop_image, reconstruct_from_crops\n\n\ndef test_overlap_crop_basic():\n    # Create a test image\n    test_image = np.zeros((800, 600, 3), dtype=np.uint8)\n    # Add a recognizable pattern - white rectangle in the middle\n    test_image[300:500, 200:400] = 255\n\n    result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12)\n\n    # Check basic properties\n    assert result[\"crops\"][0].shape == (378, 378, 3)\n    assert len(result[\"crops\"]) > 1\n    assert all(crop.shape == (378, 378, 3) for crop in result[\"crops\"])\n    assert len(result[\"tiling\"]) == 2\n\n\ndef test_overlap_crop_small_image():\n    # Test with image smaller than crop size\n    test_image = np.zeros((300, 200, 3), dtype=np.uint8)\n    result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12)\n\n    # Should still produce valid output\n    assert result[\"crops\"][0].shape == (378, 378, 3)\n    assert len(result[\"crops\"]) == 2\n    assert result[\"tiling\"] == (1, 1)\n\n\ndef test_reconstruction():\n    # Create a test image\n    test_image = np.zeros((800, 600, 3), dtype=np.uint8)\n    # Add a recognizable pattern\n    test_image[300:500, 200:400] = 255\n\n    # Crop and reconstruct\n    result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12)\n    crops_tensor = [torch.from_numpy(crop) for crop in result[\"crops\"][1:]]\n    reconstructed = reconstruct_from_crops(\n        crops_tensor, result[\"tiling\"], overlap_margin=4\n    )\n\n    # Convert back to numpy for comparison\n    reconstructed_np = reconstructed.numpy()\n\n    # The reconstructed image should be similar to the input\n    # We can't expect exact equality due to resizing operations\n    # but the white rectangle should still be visible in the middle\n    center_reconstructed = reconstructed_np[\n        reconstructed_np.shape[0] // 2 - 100 : reconstructed_np.shape[0] // 2 + 100,\n   
     reconstructed_np.shape[1] // 2 - 100 : reconstructed_np.shape[1] // 2 + 100,\n    ].mean()\n\n    # The center region should be significantly brighter than the edges\n    assert center_reconstructed > reconstructed_np[:100, :100].mean() + 100\n"
  },
  {
    "path": "webcam_gradio_demo.py",
    "content": "import argparse\nimport time\nfrom threading import Thread\n\nimport gradio as gr\nimport torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer\n\nfrom moondream.hf import LATEST_REVISION, detect_device\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--cpu\", action=\"store_true\")\nargs = parser.parse_args()\n\nif args.cpu:\n    device = torch.device(\"cpu\")\n    dtype = torch.float32\nelse:\n    device, dtype = detect_device()\n    if device != torch.device(\"cpu\"):\n        print(\"Using device:\", device)\n        print(\"If you run into issues, pass the `--cpu` flag to this script.\")\n        print()\n\nmodel_id = \"vikhyatk/moondream2\"\ntokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION)\nmoondream = AutoModelForCausalLM.from_pretrained(\n    model_id, trust_remote_code=True, revision=LATEST_REVISION\n).to(device=device, dtype=dtype)\nmoondream.eval()\n\n\ndef answer_question(img, prompt):\n    image_embeds = moondream.encode_image(img)\n    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)\n    thread = Thread(\n        target=moondream.answer_question,\n        kwargs={\n            \"image_embeds\": image_embeds,\n            \"question\": prompt,\n            \"tokenizer\": tokenizer,\n            \"streamer\": streamer,\n        },\n    )\n    thread.start()\n\n    buffer = \"\"\n    for new_text in streamer:\n        buffer += new_text\n        yield buffer\n\n\nwith gr.Blocks() as demo:\n    gr.Markdown(\"# 🌔 moondream\")\n\n    gr.HTML(\n        \"\"\"\n        <style type=\"text/css\">\n            .md_output p {\n                padding-top: 1rem;\n                font-size: 1.2rem !important;\n            }\n        </style>\n        \"\"\"\n    )\n\n    with gr.Row():\n        prompt = gr.Textbox(\n            label=\"Prompt\",\n            value=\"What's going on? 
Respond with a single sentence.\",\n            interactive=True,\n        )\n    with gr.Row():\n        img = gr.Image(type=\"pil\", label=\"Upload an Image\", streaming=True)\n        output = gr.Markdown(elem_classes=[\"md_output\"])\n\n    latest_img = None\n    latest_prompt = prompt.value\n\n    @img.change(inputs=[img])\n    def img_change(img):\n        global latest_img\n        latest_img = img\n\n    @prompt.change(inputs=[prompt])\n    def prompt_change(prompt):\n        global latest_prompt\n        latest_prompt = prompt\n\n    @demo.load(outputs=[output])\n    def live_video():\n        while True:\n            if latest_img is None:\n                time.sleep(0.1)\n            else:\n                for text in answer_question(latest_img, latest_prompt):\n                    if len(text) > 0:\n                        yield text\n\n\ndemo.queue().launch(debug=True)\n"
  }
]