Repository: vikhyat/moondream Branch: main Commit: 9fe3ad77616b Files: 66 Total size: 343.1 KB Directory structure: gitextract_g04l91uk/ ├── .github/ │ └── workflows/ │ ├── pylint.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── batch_generate_example.py ├── gradio_demo.py ├── moondream/ │ ├── __init__.py │ ├── config/ │ │ ├── config_md05.json │ │ └── config_md2.json │ ├── eval/ │ │ ├── chartqa.py │ │ ├── coco_map.py │ │ ├── countbenchqa.py │ │ ├── docvqa.py │ │ ├── eval_all.py │ │ ├── gazefollow.py │ │ ├── mmstar.py │ │ ├── naturalbench.py │ │ ├── pope.py │ │ ├── realworldqa.py │ │ ├── tallyqa.py │ │ ├── textvqa.py │ │ ├── utils.py │ │ └── waste_detection.py │ └── torch/ │ ├── config.py │ ├── hf_moondream.py │ ├── hf_release.py │ ├── image_crops.py │ ├── layers.py │ ├── lora.py │ ├── moondream.py │ ├── region.py │ ├── rope.py │ ├── sample.py │ ├── text.py │ ├── utils.py │ ├── vision.py │ └── weights.py ├── notebooks/ │ └── RepEng.ipynb ├── recipes/ │ ├── gaze-detection-video/ │ │ ├── .gitignore │ │ ├── README.md │ │ ├── gaze-detection-video.py │ │ ├── input/ │ │ │ └── .gitkeep │ │ ├── output/ │ │ │ └── .gitkeep │ │ ├── requirements.txt │ │ └── temp/ │ │ └── .gitkeep │ ├── promptable-content-moderation/ │ │ ├── .gitignore │ │ ├── README.md │ │ ├── app.py │ │ ├── deep_sort_integration.py │ │ ├── main.py │ │ ├── packages.txt │ │ ├── persistence.py │ │ ├── requirements.txt │ │ ├── video_visualization.py │ │ └── visualization.py │ └── promptable-video-redaction/ │ ├── .gitignore │ ├── README.md │ ├── app.py │ ├── main.py │ ├── packages.txt │ └── requirements.txt ├── requirements.txt ├── sample.py ├── tests/ │ └── test_image_crops.py └── webcam_gradio_demo.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/pylint.yml ================================================ name: Lint on: push: branches: [ main ] pull_request: 
branches: [ main ] permissions: contents: read jobs: build: runs-on: ubuntu-latest permissions: contents: read strategy: matrix: python-version: ["3.12"] # Run lint checks only on latest Python version steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip pip install autoflake black - name: Checking for unused imports run: | autoflake -c -r . - name: Checking code style run: | black --check . ================================================ FILE: .github/workflows/test.yml ================================================ name: Tests on: push: branches: [ main ] pull_request: branches: [ main ] jobs: test: runs-on: ubuntu-latest strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip pip install pytest pip install -r requirements.txt - name: Run tests run: | python -m pytest tests/test_image_crops.py -v ================================================ FILE: .gitignore ================================================ .venv __pycache__ checkpoints data /pyproject.toml poetry.lock dist clients/python/moondream/torch wandb/ moondream_finetune.safetensors ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # 🌔 moondream a tiny vision language model that kicks ass and runs anywhere [Website](https://moondream.ai/) | [Demo](https://moondream.ai/playground) ## Examples | Image | Example | | ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ![](assets/demo-1.jpg) | **What is the girl doing?**
The girl is sitting at a table and eating a large hamburger.

**What color is the girl's hair?**
The girl's hair is white. | | ![](assets/demo-2.jpg) | **What is this?**
This is a computer server rack, which is a device used to store and manage multiple computer servers. The rack is filled with various computer servers, each with their own dedicated space and power supply. The servers are connected to the rack via multiple cables, indicating that they are part of a larger system. The rack is placed on a carpeted floor, and there is a couch nearby, suggesting that the setup is in a living or entertainment area.

**What is behind the stand?**
Behind the stand, there is a brick wall. | ## About Moondream is a highly efficient open-source vision language model that combines powerful image understanding capabilities with a remarkably small footprint. It's designed to be versatile and accessible, capable of running on a wide range of devices and platforms. The project offers two model variants: - **Moondream 2B**: The primary model with 2 billion parameters, offering robust performance for general-purpose image understanding tasks including captioning, visual question answering, and object detection. - **Moondream 0.5B**: A compact 500-million-parameter model specifically optimized as a distillation target for edge devices, enabling efficient deployment on resource-constrained hardware while maintaining impressive capabilities. ## How to use Moondream can be run locally or in the cloud. Please refer to the [Getting Started](https://moondream.ai/c/docs/quickstart) page for details. ## Special thanks * [Modal](https://modal.com/?utm_source=github&utm_medium=github&utm_campaign=moondream) - Modal lets you run jobs in the cloud by writing just a few lines of Python. Here's an [example of how to run Moondream on Modal](https://github.com/m87-labs/moondream-examples/tree/main/quickstart/modal).
================================================ FILE: batch_generate_example.py ================================================ from PIL import Image from transformers import AutoTokenizer from moondream.hf import LATEST_REVISION, Moondream, detect_device device, dtype = detect_device() model_id = "vikhyatk/moondream2" tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION) moondream = Moondream.from_pretrained( model_id, revision=LATEST_REVISION, torch_dtype=dtype, ).to(device=device) moondream.eval() image1 = Image.open("assets/demo-1.jpg") image2 = Image.open("assets/demo-2.jpg") prompts = [ "What is the girl doing?", "What color is the girl's hair?", "What is this?", "What is behind the stand?", ] answers = moondream.batch_answer( images=[image1, image1, image2, image2], prompts=prompts, tokenizer=tokenizer, ) for question, answer in zip(prompts, answers): print(f"Q: {question}") print(f"A: {answer}") print() ================================================ FILE: gradio_demo.py ================================================ import argparse import re from threading import Thread import gradio as gr import torch from PIL import ImageDraw from torchvision.transforms.v2 import Resize from transformers import AutoTokenizer, TextIteratorStreamer from moondream.hf import LATEST_REVISION, Moondream, detect_device parser = argparse.ArgumentParser() parser.add_argument("--cpu", action="store_true") args = parser.parse_args() if args.cpu: device = torch.device("cpu") dtype = torch.float32 else: device, dtype = detect_device() if device != torch.device("cpu"): print("Using device:", device) print("If you run into issues, pass the `--cpu` flag to this script.") print() model_id = "vikhyatk/moondream2" tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION) moondream = Moondream.from_pretrained( model_id, revision=LATEST_REVISION, torch_dtype=dtype ).to(device=device) moondream.eval() def answer_question(img, prompt): 
image_embeds = moondream.encode_image(img) streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True) thread = Thread( target=moondream.answer_question, kwargs={ "image_embeds": image_embeds, "question": prompt, "tokenizer": tokenizer, "streamer": streamer, }, ) thread.start() buffer = "" for new_text in streamer: buffer += new_text yield buffer def extract_floats(text): # Regular expression to match an array of four floating point numbers pattern = r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]" match = re.search(pattern, text) if match: # Extract the numbers and convert them to floats return [float(num) for num in match.groups()] return None # Return None if no match is found def extract_bbox(text): bbox = None if extract_floats(text) is not None: x1, y1, x2, y2 = extract_floats(text) bbox = (x1, y1, x2, y2) return bbox def process_answer(img, answer): if extract_bbox(answer) is not None: x1, y1, x2, y2 = extract_bbox(answer) draw_image = Resize(768)(img) width, height = draw_image.size x1, x2 = int(x1 * width), int(x2 * width) y1, y2 = int(y1 * height), int(y2 * height) bbox = (x1, y1, x2, y2) ImageDraw.Draw(draw_image).rectangle(bbox, outline="red", width=3) return gr.update(visible=True, value=draw_image) return gr.update(visible=False, value=None) with gr.Blocks() as demo: gr.Markdown( """ # 🌔 moondream """ ) with gr.Row(): prompt = gr.Textbox(label="Input Prompt", value="Describe this image.", scale=4) submit = gr.Button("Submit") with gr.Row(): img = gr.Image(type="pil", label="Upload an Image") with gr.Column(): output = gr.Markdown(label="Response") ann = gr.Image(visible=False, label="Annotated Image") submit.click(answer_question, [img, prompt], output) prompt.submit(answer_question, [img, prompt], output) output.change(process_answer, [img, output], ann, show_progress=False) demo.queue().launch(debug=True) ================================================ FILE: moondream/__init__.py 
================================================ ================================================ FILE: moondream/config/config_md05.json ================================================ { "text": { "dim": 1024, "ff_dim": 4096, "n_layers": 24, "vocab_size": 51200, "max_context": 2048, "n_heads": 16, "prefix_attn": 730 }, "vision": { "enc_dim": 720, "enc_patch_size": 14, "enc_n_layers": 27, "enc_ff_dim": 2690, "enc_n_heads": 10, "proj_out_dim": 1024, "crop_size": 378, "in_channels": 3, "max_crops": 12, "overlap_margin": 4, "proj_inner_dim": 8192 }, "region": { "dim": 1024, "coord_feat_dim": 256, "coord_out_dim": 1024, "size_feat_dim": 512, "size_out_dim": 2048, "inner_dim": 8192 }, "tokenizer": { "bos_id": 50256, "eos_id": 50256, "templates": { "caption": { "short": [ 198, 198, 16438, 8305, 25 ], "normal": [ 198, 198, 24334, 1159, 25 ] }, "query": { "prefix": [ 198, 198, 24361, 25 ], "suffix": [ 198, 198, 33706, 25 ] }, "detect": { "prefix": [ 198, 198, 47504, 25 ], "suffix": [ 628 ] }, "point": { "prefix": [ 198, 198, 12727, 25 ], "suffix": [ 628 ] } } } } ================================================ FILE: moondream/config/config_md2.json ================================================ { "text": { "dim": 2048, "ff_dim": 8192, "n_layers": 24, "vocab_size": 51200, "max_context": 2048, "n_heads": 32, "prefix_attn": 730 }, "vision": { "enc_dim": 1152, "enc_patch_size": 14, "enc_n_layers": 27, "enc_ff_dim": 4304, "enc_n_heads": 16, "proj_out_dim": 2048, "crop_size": 378, "in_channels": 3, "max_crops": 12, "overlap_margin": 4, "proj_inner_dim": 8192 }, "region": { "dim": 2048, "coord_feat_dim": 256, "coord_out_dim": 1024, "size_feat_dim": 512, "size_out_dim": 2048, "inner_dim": 8192 }, "tokenizer": { "bos_id": 50256, "eos_id": 50256, "templates": { "caption": { "short": [ 198, 198, 16438, 8305, 25 ], "normal": [ 198, 198, 24334, 1159, 25 ] }, "query": { "prefix": [ 198, 198, 24361, 25 ], "suffix": [ 198, 198, 33706, 25 ] }, "detect": { "prefix": [ 198, 198, 47504, 
25 ], "suffix": [ 628 ] }, "point": { "prefix": [ 198, 198, 12727, 25 ], "suffix": [ 628 ] } } } } ================================================ FILE: moondream/eval/chartqa.py ================================================ import argparse import datasets import torch from tqdm import tqdm import json from ..torch.config import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights import load_weights_into_model PREFIX = "Analyze the chart carefully, consider both visual features and data values, and provide a precise answer without any additional explanation or formatting. " # https://github.com/google-research/pix2struct/blob/main/pix2struct/metrics.py#L81 def relaxed_correctness( target: str, prediction: str, max_relative_change: float = 0.05 ) -> bool: """Calculates relaxed correctness. The correctness tolerates certain error ratio defined by max_relative_change. See https://arxiv.org/pdf/2203.10244.pdf, end of section 5.1: “Following Methani et al. (2020), we use a relaxed accuracy measure for the numeric answers to allow a minor inaccuracy that may result from the automatic data extraction process. We consider an answer to be correct if it is within 5% of the gold answer. For non-numeric answers, we still need an exact match to consider an answer to be correct.” Args: target: Target string. prediction: Predicted string. max_relative_change: Maximum relative change. Returns: Whether the prediction was correct given the specified tolerance. """ def _to_float(text): try: if text.endswith("%"): # Convert percentages to floats. 
return float(text.rstrip("%")) / 100.0 else: return float(text) except ValueError: return None prediction = str(prediction) target = str(target) prediction_float = _to_float(prediction) target_float = _to_float(target) if prediction_float is not None and target_float: relative_change = abs(prediction_float - target_float) / abs(target_float) return relative_change <= max_relative_change else: return prediction == target def eval_chartqa(model, debug=False): dataset = datasets.load_dataset("vikhyatk/chartqa", split="test") correct = 0 total = 0 human_correct = 0 human_total = 0 results = [] for row in tqdm(dataset, disable=debug, desc="ChartQA"): image = row["image"] encoded_image = model.encode_image(image) result = [] for qa in row["qa"]: question = PREFIX + qa["question"] answer = qa["answer"] model_answer = model.query(encoded_image, question)["answer"] # Attempt to parse both answers into lists, otherwise try: answer_list = json.loads(answer) model_answer_list = json.loads(model_answer) if not ( isinstance(answer_list, list) and isinstance(model_answer_list, list) and len(answer_list) == len(model_answer_list) ): raise ValueError except: # If parsing fails or lengths are not equal, compare the strings directly instead answer_list = [answer] model_answer_list = [model_answer] total += 1 if qa["source"] == "human": human_total += 1 is_correct = False if all( relaxed_correctness( str(cur_answer).strip().lower(), str(cur_model_answer).strip().lower(), ) for cur_answer, cur_model_answer in zip(answer_list, model_answer_list) ): correct += 1 if qa["source"] == "human": human_correct += 1 is_correct = True if debug: print( f"Correct: {correct}, Total: {total}, Human Correct: {human_correct}, Human Total: {human_total}" ) print(f"Human Accuracy: {human_correct * 100 / human_total:.2f}") print(f"Total Accuracy: {correct * 100 / total:.2f}") print("---------") result.append( { "question": question, "ground_truth": answer_list, "model_answer": model_answer_list, 
"is_correct": is_correct, "source": qa["source"], } ) results.append(result) return { "human_acc": human_correct * 100 / human_total, "total_acc": correct * 100 / total, "results": results, } if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) parser.add_argument("--debug", action="store_true") args = parser.parse_args() if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) model.compile() results = eval_chartqa(model, args.debug) print(f"Human Accuracy: {results['human_acc']:.2f}") print(f"Total Accuracy: {results['total_acc']:.2f}") ================================================ FILE: moondream/eval/coco_map.py ================================================ import argparse import datasets import torch import json import numpy as np from typing import List, Tuple from tqdm import tqdm from ..torch.config import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights import load_weights_into_model coco_classes = [ "None", "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "street sign", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "hat", "backpack", "umbrella", "shoe", "eye glasses", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "plate", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "mirror", "dining table", "window", "desk", "toilet", "door", "tv", "laptop", "mouse", 
"remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "blender", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush", "hair brush", ] COCO_LABELS = {} for i, c in enumerate(coco_classes): COCO_LABELS[i] = c def calculate_iou( box1: Tuple[float, float, float, float], box2: Tuple[float, float, float, float] ) -> float: """Calculate IoU between two boxes (x1, y1, x2, y2 format)""" x1 = max(box1[0], box2[0]) y1 = max(box1[1], box2[1]) x2 = min(box1[2], box2[2]) y2 = min(box1[3], box2[3]) intersection = max(0, x2 - x1) * max(0, y2 - y1) box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) return intersection / (box1_area + box2_area - intersection) def calculate_map( ground_truth_boxes: List[List[Tuple[float, float, float, float]]], predicted_boxes: List[List[Tuple[float, float, float, float, float]]], iou_threshold: float = 0.5, ) -> float: """ Calculate mAP for object detection Args: ground_truth_boxes: List of lists of ground truth boxes per image [(x1, y1, x2, y2)] predicted_boxes: List of lists of predicted boxes per image [(x1, y1, x2, y2, confidence)] iou_threshold: IoU threshold for considering a detection as correct Returns: mean Average Precision """ total_precision = 0 num_classes = len(ground_truth_boxes) for class_idx in range(num_classes): # Get all predictions and ground truths for this class gt_boxes = ground_truth_boxes[class_idx] pred_boxes = predicted_boxes[class_idx] # Sort predictions by confidence pred_boxes = sorted(pred_boxes, key=lambda x: x[4], reverse=True) # Initialize arrays for precision-recall calculation num_gt = len(gt_boxes) if num_gt == 0: continue tp = np.zeros(len(pred_boxes)) fp = np.zeros(len(pred_boxes)) gt_matched = [False] * num_gt # Match each prediction to ground truth for pred_idx, pred_box in enumerate(pred_boxes): max_iou = 0 max_idx = -1 # Find best matching ground truth box for gt_idx, gt_box in 
enumerate(gt_boxes): if gt_matched[gt_idx]: continue iou = calculate_iou(pred_box[:4], gt_box) if iou > max_iou: max_iou = iou max_idx = gt_idx # If IoU exceeds threshold, count as true positive if max_iou >= iou_threshold: tp[pred_idx] = 1 gt_matched[max_idx] = True else: fp[pred_idx] = 1 # Calculate cumulative precision and recall cumsum_tp = np.cumsum(tp) cumsum_fp = np.cumsum(fp) recalls = cumsum_tp / num_gt precisions = cumsum_tp / (cumsum_tp + cumsum_fp) # Calculate average precision using all points ap = 0 for t in np.arange(0, 1.1, 0.1): if np.sum(recalls >= t) == 0: p = 0 else: p = np.max(precisions[recalls >= t]) ap += p / 11 total_precision += ap return total_precision / num_classes def get_total_map(results_by_label, frequency_by_label): total_count = 0 total_map = 0 for results, frequency in zip( results_by_label.values(), frequency_by_label.values() ): cur_total_map = sum(results) total_map += cur_total_map total_count += frequency return total_map / total_count def eval_coco_map(model, iou_threshold=0.5, debug=False): dataset = datasets.load_dataset( "moondream/coco-val-2017-bbox-cleaned", split="validation" ) total = 0 results_by_label = {} # map to list of raw map results for each label frequency_by_label = {} # many images contain a given label for row in tqdm(dataset, disable=debug, desc="COCO mAP"): width = row["image"].width height = row["image"].height total += 1 objects = json.loads(row["objects"]) gt_label_to_boxes = {} for bbox, label in zip(objects["bbox"], objects["label"]): if label not in gt_label_to_boxes: gt_label_to_boxes[label] = [] x1, y1, w, h = bbox gt_label_to_boxes[label].append((x1, y1, x1 + w, y1 + h)) unique_labels = [label for label in set(objects["label"])] for label in unique_labels: encoded_image = model.encode_image(row["image"]) model_answer = model.detect(encoded_image, COCO_LABELS[label])["objects"] moondream_boxes = [] for box in model_answer: moondream_boxes.append( ( box["x_min"] * width, box["y_min"] * height, 
box["x_max"] * width, box["y_max"] * height, 1.0, # Using default confidence of 1.0 ) ) map_result = calculate_map( [gt_label_to_boxes[label]], [moondream_boxes], iou_threshold ) if debug and map_result == 0: print( f"0 Map result for index {total} and label {label} ({COCO_LABELS[label]})" ) if label not in results_by_label: results_by_label[label] = [] results_by_label[label].append(map_result) if label not in frequency_by_label: frequency_by_label[label] = 0 frequency_by_label[label] += 1 if debug and total % 100 == 0: print( f"Total map: {get_total_map(results_by_label, frequency_by_label)*100:.2f}, ({total} images)" ) return { # "results_by_label": results_by_label, # "frequency_by_label": frequency_by_label, "total_map": get_total_map(results_by_label, frequency_by_label), } if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) parser.add_argument("--debug", action="store_true") args = parser.parse_args() # This repo doesn't have moondream deps we need if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) model.compile() result = eval_coco_map(model, 0.5, args.debug) print(f"Overall MAP: {result['total_map']*100:.2f}") ================================================ FILE: moondream/eval/countbenchqa.py ================================================ import argparse import datasets import torch from tqdm import tqdm from ..torch.config import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights import load_weights_into_model PREFIX = "Look at the image carefully and count the objects. Answer with just a number, without any additional text. 
" def eval_countbenchqa(model, debug=False): dataset = datasets.load_dataset("vikhyatk/CountBenchQA", split="test") correct = 0 total = 0 results = [] for row in tqdm(dataset, disable=debug, desc="CountBenchQA"): image = row["image"] encoded_image = model.encode_image(image) question = PREFIX + row["question"] answer = str(row["number"]) model_answer = model.query(encoded_image, question)["answer"] is_correct = model_answer.strip().lower() == answer.strip().lower() results.append( { "question": question, "ground_truth": answer, "model_answer": model_answer, "is_correct": is_correct, } ) total += 1 if is_correct: correct += 1 elif debug: print(f"Question: {row['question']}") print(f"Answer: {answer}") print(f"Model Answer: {model_answer}") if debug: print(f"Correct: {correct}, Total: {total}") print(f"Accuracy: {correct * 100 / total:.2f}") print("---------") return { "acc": correct * 100 / total, "correct_count": correct, "total_count": total, "results": results, } if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) parser.add_argument("--debug", action="store_true") args = parser.parse_args() if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) result = eval_countbenchqa(model, args.debug) print(f"Accuracy: {result['acc']:.2f}") print(f"Correct: {result['correct_count']}, Total: {result['total_count']}") ================================================ FILE: moondream/eval/docvqa.py ================================================ import argparse import editdistance from datasets import load_dataset from tqdm import tqdm import torch from ..torch.config import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights import load_weights_into_model SUFFIX = " The answer should be a short text span 
taken verbatim from the document." def get_anls(s1, s2): s1 = s1.lower().strip() s2 = s2.lower().strip() iou = 1 - editdistance.eval(s1, s2) / max(len(s1), len(s2)) anls = iou if iou >= 0.5 else 0.0 return anls def eval_docvqa(model, debug=False): docvqa_val = load_dataset("vikhyatk/docvqa-val", split="validation") scores = [] results = [] for row in tqdm(docvqa_val, disable=debug, desc="DocVQA"): image = row["image"] encoded_image = model.encode_image(image) result = [] for qa in row["qa"]: question = qa["question"] answers = qa["answers"] prompt = question + SUFFIX model_answer = model.query(encoded_image, prompt)["answer"] anls = max(get_anls(model_answer, gt) for gt in answers) scores.append(anls) result.append( { "question": question, "ground_truth": answers, "model_answer": model_answer, "anls": anls, } ) if debug: print(f"Question: {question}") print(f"Ground Truth: {answers}") print(f"Model Answer: {model_answer}") print(f"ANLS: {anls}") print(f"Current Average ANLS: {sum(scores) / len(scores):.4f}") print("---------") results.append(result) return { "anls": sum(scores) / len(scores), "results": results, } if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) parser.add_argument("--debug", action="store_true") args = parser.parse_args() if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) model.compile() result = eval_docvqa(model, args.debug) print(f"ANLS: {result['anls']:.4f}") ================================================ FILE: moondream/eval/eval_all.py ================================================ import argparse import torch from pprint import pprint from ..torch.config import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights import load_weights_into_model from 
.countbenchqa import eval_countbenchqa from .pope import evaluate_pope from .realworldqa import eval_realworldqa from .chartqa import eval_chartqa from .textvqa import eval_textvqa from .docvqa import eval_docvqa from .mmstar import eval_mmstar from .coco_map import eval_coco_map from .naturalbench import eval_naturalbench from .tallyqa import eval_tallyqa def create_model(ckpt_path): config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(ckpt_path, model) model.compile() return model def eval_all(model, skip=[]): evals = { "countbenchqa": eval_countbenchqa, "pope": evaluate_pope, "realworldqa": eval_realworldqa, "chartqa": eval_chartqa, "mmstar": eval_mmstar, "docvqa": eval_docvqa, "coco_map": eval_coco_map, "textvqa": eval_textvqa, "naturalbench": eval_naturalbench, "tallyqa": eval_tallyqa, } for b in skip: del evals[b] results = {} for name, eval_fn in evals.items(): results[name] = eval_fn(model) pprint({k: v for k, v in results[name].items() if k != "results"}) return results if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) args = parser.parse_args() if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") model = create_model(args.model) eval_all(model) ================================================ FILE: moondream/eval/gazefollow.py ================================================ import torch import datasets import math from tqdm import tqdm from ..torch.config import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights import load_weights_into_model def eval_gazefollow(model, debug=False): dataset = datasets.load_dataset("vikhyatk/gazefollow", split="test") mean_l2_error = [] min_l2_error = [] total = 0 for i, row in tqdm(enumerate(dataset), total=len(dataset)): heads = [] for gaze in row["gazes"]: head_bbox = gaze["head_bbox"] # xmin, ymin, xmax, ymax 
eye_coord = (gaze["eye"]["x"], gaze["eye"]["y"]) mean_target_gaze = (gaze["gaze"]["x"], gaze["gaze"]["y"]) # Check if a head already exists with the same approximate bbox. # If so, use that head instead of creating a new one. for head in heads: if ( abs(head["head_bbox"]["xmin"] - head_bbox["xmin"]) < 0.001 and abs(head["head_bbox"]["xmax"] - head_bbox["xmax"]) < 0.001 and abs(head["head_bbox"]["ymin"] - head_bbox["ymin"]) < 0.001 and abs(head["head_bbox"]["ymax"] - head_bbox["ymax"]) < 0.001 ): head["gazes"].append(mean_target_gaze) break else: heads.append( { "head_bbox": head_bbox, "eye_coord": eye_coord, "gazes": [mean_target_gaze], } ) for head in heads: pred_gaze = model.detect_gaze( row["image"], eye=head["eye_coord"], face={ "x_min": head["head_bbox"]["xmin"], "y_min": head["head_bbox"]["ymin"], "x_max": head["head_bbox"]["xmax"], "y_max": head["head_bbox"]["ymax"], }, unstable_settings={"force_detect": True}, )["gaze"] mean_target_gaze = ( sum(gaze[0] for gaze in head["gazes"]) / len(head["gazes"]), sum(gaze[1] for gaze in head["gazes"]) / len(head["gazes"]), ) mean_l2 = math.sqrt( (mean_target_gaze[0] - pred_gaze["x"]) ** 2 + (mean_target_gaze[1] - pred_gaze["y"]) ** 2 ) min_l2 = min( math.sqrt( (target_gaze[0] - pred_gaze["x"]) ** 2 + (target_gaze[1] - pred_gaze["y"]) ** 2 ) for target_gaze in head["gazes"] ) mean_l2_error.append(mean_l2) min_l2_error.append(min_l2) total += 1 if i % 100 == 0 and debug: print("Mean L2 error:", sum(mean_l2_error) / total) print("Min L2 error:", sum(min_l2_error) / total) return { "mean_l2": sum(mean_l2_error) / total, "min_l2": sum(min_l2_error) / total, } if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) parser.add_argument("--debug", action="store_true") args = parser.parse_args() if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") config = 
MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) results = eval_gazefollow(model, debug=args.debug) print(f"Mean L2 error: {results['mean_l2']:.4f}") print(f"Min L2 error: {results['min_l2']:.4f}") ================================================ FILE: moondream/eval/mmstar.py ================================================ import datasets import torch from tqdm import tqdm from ..torch.config import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights import load_weights_into_model SUFFIX = " Please answer directly with only the letter of the correct option and nothing else." def eval_mmstar(model, debug=False): dataset = datasets.load_dataset("Lin-Chen/MMStar", split="val") correct = 0 total = 0 category_stats = {} results = [] for row in tqdm(dataset, disable=debug, desc="MMStar"): image = row["image"] question = row["question"] + SUFFIX answer = row["answer"] model_answer = model.query(image, question)["answer"] is_correct = model_answer.strip().lower() == answer.strip().lower() category = f"{row['category']} / {row['l2_category']}" if category not in category_stats: category_stats[category] = {"correct": 0, "total": 0} total += 1 category_stats[category]["total"] += 1 results.append( { "question": question, "ground_truth": answer, "model_answer": model_answer, "is_correct": is_correct, "category": category, } ) if is_correct: correct += 1 category_stats[category]["correct"] += 1 elif debug: print(f"Index: {row['index']}") print(f"Question: {row['question']}") print(f"Answer: {answer}") print(f"Model Answer: {model_answer}") if debug: print(f"Correct: {correct}, Total: {total}") print(f"Accuracy: {correct * 100 / total:.2f}") print("Results by category:") for category, stats in category_stats.items(): acc = stats["correct"] * 100 / stats["total"] print(f"{category}: {stats['correct']}/{stats['total']} = {acc:.2f}%") print("---------") return { "acc": correct * 100 / total, "correct_count": 
correct, "total_count": total, "category_stats": category_stats, "results": results, } if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) parser.add_argument("--debug", action="store_true") args = parser.parse_args() if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) model.compile() result = eval_mmstar(model, args.debug) print(f"Correct: {result['correct_count']}, Total: {result['total_count']}") print(f"Accuracy: {result['acc']:.2f}") print("\nResults by category:") for category, stats in result["category_stats"].items(): acc = stats["correct"] * 100 / stats["total"] print(f"{category}: {stats['correct']}/{stats['total']} = {acc:.2f}%") ================================================ FILE: moondream/eval/naturalbench.py ================================================ from datasets import load_dataset from tqdm import tqdm import torch from ..torch.config import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights import load_weights_into_model def eval_naturalbench(model, debug=False): # Yes, the benchmark test set is stored in the 'train' split... dataset = load_dataset("BaiqiL/NaturalBench", split="train") acc = [] q_acc = [] i_acc = [] g_acc = [] for row in tqdm(dataset, disable=debug, desc="NaturalBench"): if row["Question_Type"] == "yes_no": suffix = " Answer yes or no." 
else: suffix = "" images = [row["Image_0"], row["Image_1"], row["Image_0"], row["Image_1"]] prompts = [ row["Question_0"] + suffix, row["Question_0"] + suffix, row["Question_1"] + suffix, row["Question_1"] + suffix, ] expected = [ row["Image_0_Question_0"].strip().lower(), row["Image_1_Question_0"].strip().lower(), row["Image_0_Question_1"].strip().lower(), row["Image_0_Question_1"].strip().lower(), ] answers = [] for img, prompt in zip(images, prompts): encoded_image = model.encode_image(img) answer = model.query(encoded_image, prompt)["answer"] answers.append(answer.strip().lower()) if debug: for i, (q, a, e) in enumerate(zip(prompts, answers, expected)): print(f"Q{i}: {q}") print(f"Model: {a}") print(f"Expected: {e}") print(f"Correct: {a == e}") print("---") acc.append(answers[0] == expected[0]) acc.append(answers[1] == expected[1]) acc.append(answers[2] == expected[2]) acc.append(answers[3] == expected[3]) i_acc.append(answers[0] == expected[0] and answers[2] == expected[2]) i_acc.append(answers[1] == expected[1] and answers[3] == expected[3]) q_acc.append(answers[0] == expected[0] and answers[1] == expected[1]) q_acc.append(answers[2] == expected[2] and answers[3] == expected[3]) g_acc.append( answers[0] == expected[0] and answers[1] == expected[1] and answers[2] == expected[2] and answers[3] == expected[3] ) if debug: print(f"Current Overall Accuracy: {sum(acc) / len(acc):.4f}") print(f"Current Image Accuracy: {sum(i_acc) / len(i_acc):.4f}") print(f"Current Question Accuracy: {sum(q_acc) / len(q_acc):.4f}") print(f"Current Group Accuracy: {sum(g_acc) / len(g_acc):.4f}") print("=========") return { "overall_acc": sum(acc) / len(acc), "image_acc": sum(i_acc) / len(i_acc), "question_acc": sum(q_acc) / len(q_acc), "group_acc": sum(g_acc) / len(g_acc), } if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) parser.add_argument("--debug", action="store_true") args = parser.parse_args() 
if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) model.compile() results = eval_naturalbench(model, debug=args.debug) print(f"Overall Accuracy: {results['overall_acc']:.4f}") print(f"Image Accuracy: {results['image_acc']:.4f}") print(f"Question Accuracy: {results['question_acc']:.4f}") print(f"Group Accuracy: {results['group_acc']:.4f}") ================================================ FILE: moondream/eval/pope.py ================================================ import argparse from datasets import load_dataset from tqdm import tqdm import torch from ..torch.config import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights import load_weights_into_model def evaluate_pope(model, debug=False): pope_dataset = load_dataset("vikhyatk/POPE", split="test") stats = { "random": (0, 0), "popular": (0, 0), "adversarial": (0, 0), } for row in tqdm(pope_dataset, disable=debug, desc="POPE"): image = row["image"] encoded_image = model.encode_image(image) for split in ["adversarial", "popular", "random"]: for qa in row[split]: question = qa["question"] answer = qa["answer"] prompt = f"{question}\nAnswer yes or no." 
model_answer = model.query(encoded_image, prompt)["answer"].strip() if debug: print(f"Split: {split}") print(f"Question: {question}") print(f"Model: {model_answer}") print(f"Expected: {answer}") print(f"Correct: {model_answer.lower() == answer.lower()}") print("---") if model_answer.lower() == answer.lower(): stats[split] = (stats[split][0] + 1, stats[split][1] + 1) else: stats[split] = (stats[split][0], stats[split][1] + 1) if debug: for s in stats: if stats[s][1] > 0: print( f"{s.capitalize()}: {stats[s][0]}/{stats[s][1]} = {stats[s][0] * 100.0 / stats[s][1]:.2f}%" ) print("=========") return { "random": stats["random"][0] * 100.0 / stats["random"][1], "popular": stats["popular"][0] * 100.0 / stats["popular"][1], "adversarial": stats["adversarial"][0] * 100.0 / stats["adversarial"][1], } if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) parser.add_argument("--debug", action="store_true") args = parser.parse_args() if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) result = evaluate_pope(model, args.debug) print(f"Random Accuracy: {result['random']:.2f}") print(f"Popular Accuracy: {result['popular']:.2f}") print(f"Adversarial Accuracy: {result['adversarial']:.2f}") ================================================ FILE: moondream/eval/realworldqa.py ================================================ import argparse import datasets import torch from tqdm import tqdm from ..torch.config import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights import load_weights_into_model def eval_realworldqa(model, debug=False): dataset = datasets.load_dataset("lmms-lab/RealWorldQA", split="test") correct = 0 total = 0 results = [] for row in tqdm(dataset, disable=debug, desc="RealWorldQA"): image = 
row["image"] question = row["question"] answer = row["answer"] model_answer = model.query(image, question)["answer"] is_correct = model_answer.strip().lower() == answer.strip().lower() results.append( { "question": question, "ground_truth": answer, "model_answer": model_answer, "is_correct": is_correct, } ) total += 1 if is_correct: correct += 1 elif debug: print(f"Image: {row['image_path']}") print(f"Question: {question}") print(f"Answer: {answer}") print(f"Model Answer: {model_answer}") if debug: print(f"Correct: {correct}, Total: {total}") print(f"Accuracy: {correct * 100 / total:.2f}") print("---------") return { "acc": correct * 100 / total, "correct_count": correct, "total_count": total, "results": results, } if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) parser.add_argument("--debug", action="store_true") args = parser.parse_args() if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) model.compile() result = eval_realworldqa(model, args.debug) print(f"Accuracy: {result['acc']:.2f}") print(f"Correct: {result['correct_count']} / {result['total_count']}") ================================================ FILE: moondream/eval/tallyqa.py ================================================ import argparse import datasets import torch from tqdm import tqdm from ..torch.config import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights import load_weights_into_model PREFIX = "Look at the image carefully and count the objects. Answer with just a number, without any additional text. 
" def eval_tallyqa(model, debug=False): dataset = datasets.load_dataset( "vikhyatk/tallyqa-test", split="test", download_config=datasets.DownloadConfig(num_proc=16), ) total = 0 total_simple = 0 correct = 0 correct_simple = 0 for row in tqdm(dataset, disable=args.debug): image = row["image"] encoded_image = model.encode_image(image) for qa in row["qa"]: question = PREFIX + qa["question"] answer = str(qa["answer"]) is_simple = qa["is_simple"] model_answer = model.query(encoded_image, question)["answer"] total += 1 if model_answer.strip().lower() == answer.strip().lower(): correct += 1 elif args.debug: print(f"Question: {qa['question']}") print(f"Answer: {answer}") print(f"Model Answer: {model_answer}") if is_simple: total_simple += 1 if model_answer.strip().lower() == answer.strip().lower(): correct_simple += 1 if args.debug: print(f"Simple - Correct: {correct_simple}, Total: {total_simple}") print(f"Simple Accuracy: {correct_simple * 100 / total_simple:.2f}") print(f"All - Correct: {correct}, Total: {total}") print(f"All Accuracy: {correct * 100 / total:.2f}") print("---------") return { "simple_acc": correct_simple * 100 / total_simple, "full_acc": correct * 100 / total, } if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) parser.add_argument("--debug", action="store_true") args = parser.parse_args() if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) model.compile() result = eval_tallyqa(model, args.debug) print(f"Simple acc: {result['simple_acc']:.2f}") print(f"Full acc: {result['full_acc']:.2f}") ================================================ FILE: moondream/eval/textvqa.py ================================================ import argparse import datasets import torch from tqdm import tqdm from ..torch.config 
import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights import load_weights_into_model from .utils import VQAScorer PREFIX_TEXTVQA = "Read the text in the image and provide a brief lowercase answer. Respond 'unanswerable' only if there is no plausible answer. " def eval_textvqa(model, debug=False): dataset = datasets.load_dataset("vikhyatk/textvqa_val", split="validation") scorer = VQAScorer() total_score = 0 total_samples = 0 results = [] for row in tqdm(dataset, disable=debug, desc="TextVQA"): image = row["image"] encoded_image = model.encode_image(image) question = PREFIX_TEXTVQA + row["question"] model_answer = model.query(encoded_image, question)["answer"] score = scorer.compute_score(model_answer, row["answers"]) total_score += score total_samples += 1 results.append( { "question": question, "ground_truth": row["answers"], "model_answer": model_answer, "score": score, } ) if debug: print(f"Question: {row['question']}") print(f"Ground Truth Answers: {row['answers']}") print(f"Model Answer: {model_answer}") print(f"Score: {score}") print(f"Running Average Score: {total_score * 100 / total_samples:.2f}") print("---------") return {"score": total_score * 100 / total_samples, "results": results} if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, required=True) parser.add_argument("--debug", action="store_true") args = parser.parse_args() if torch.cuda.is_available(): torch.set_default_device("cuda") elif torch.backends.mps.is_available(): torch.set_default_device("mps") config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) model.compile() result = eval_textvqa(model, args.debug) print(f"Score: {result['score']}") ================================================ FILE: moondream/eval/utils.py ================================================ import re from typing import List class VQAScorer: def __init__(self): self.contractions = { 
"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", "whats": "what's", 
"whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", "youll": "you'll", "youre": "you're", "youve": "you've", } self.manualMap = { "none": "0", "zero": "0", "one": "1", "two": "2", "three": "3", "four": "4", "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9", "ten": "10", } self.articles = ["a", "an", "the"] self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)") self.commaStrip = re.compile(r"(\d)(\,)(\d)") self.punct = [ ";", r"/", "[", "]", '"', "{", "}", "(", ")", "=", "+", "\\", "_", "-", ">", "<", "@", "`", ",", "?", "!", ] self.commaStrip = re.compile(r"(\d)(,)(\d)") self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)") def process_punctuation(self, inText: str) -> str: outText = inText for p in self.punct: if (p + " " in inText or " " + p in inText) or ( re.search(self.commaStrip, inText) is not None ): outText = outText.replace(p, "") else: outText = outText.replace(p, " ") outText = self.periodStrip.sub("", outText, re.UNICODE) return outText def process_digit_article(self, inText: str) -> str: outText = [] tempText = inText.lower().split() for word in tempText: word = self.manualMap.setdefault(word, word) if word not in self.articles: outText.append(word) for wordId, word in enumerate(outText): if word in self.contractions: outText[wordId] = self.contractions[word] outText = " ".join(outText) return outText def process_answer(self, answer): answer = answer.replace("\n", " ") answer = 
answer.replace("\t", " ") answer = answer.strip() answer = self.process_punctuation(answer) answer = self.process_digit_article(answer) return answer def process_line(self, prediction: str, gt_answers: List[str]) -> float: gt_answers = [self.process_answer(x) for x in gt_answers] prediction = self.process_answer(prediction) matches = [] for current_idx, gtAnsDatum in enumerate(gt_answers): otherGTAns = [ item for ret_gt_idx, item in enumerate(gt_answers) if ret_gt_idx != current_idx ] matchingAns = [item for item in otherGTAns if item == prediction] acc = min(1, float(len(matchingAns)) / 3) matches.append(acc) return sum(matches) / len(matches) def compute_score( self, candidate_answer: str, ground_truth_answers: List[str] ) -> float: """ Compute VQA score for a candidate answer against ground truth answers, exactly matching the VQAEval scoring logic """ # Process candidate answer candidate = self.process_answer(candidate_answer) # Process ground truth answers processed_gts = [] for gt in ground_truth_answers: gt = gt.replace("\n", " ") gt = gt.replace("\t", " ") gt = gt.strip() processed_gts.append(gt) # If there are multiple different answers, apply additional processing if len(set(processed_gts)) > 1: candidate = self.process_punctuation(candidate) candidate = self.process_digit_article(candidate) processed_gts = [ self.process_punctuation(self.process_digit_article(gt)) for gt in processed_gts ] # Count matches matching_answers = [1 for gt in processed_gts if gt == candidate] score = min(1.0, float(len(matching_answers)) / 3.0) return score ================================================ FILE: moondream/eval/waste_detection.py ================================================ import argparse from collections import defaultdict from typing import Dict, List, Tuple import torch from PIL import Image from tqdm import tqdm from datasets import load_dataset from ..torch.config import MoondreamConfig from ..torch.moondream import MoondreamModel from ..torch.weights 
import load_weights_into_model Box = Tuple[float, float, float, float] # (x1, y1, x2, y2) – in proportion form def iou(a: Box, b: Box) -> float: """Corner-format IoU. Returns 0 when either box has zero area.""" x1, y1 = max(a[0], b[0]), max(a[1], b[1]) x2, y2 = min(a[2], b[2]), min(a[3], b[3]) inter = max(0.0, x2 - x1) * max(0.0, y2 - y1) union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter return inter / union if union else 0.0 def match(gt: List[Box], pr: List[Box], iou_thr: float) -> Tuple[int, int, int]: """ Greedy one-to-one matching with no confidences. Predictions are taken in the order produced by the model. """ tp = fp = 0 seen = [False] * len(gt) for p in pr: best, best_i = 0.0, -1 for i, g in enumerate(gt): if seen[i]: continue iou_ = iou(p, g) if iou_ > best: best, best_i = iou_, i if best >= iou_thr: tp += 1 seen[best_i] = True else: fp += 1 fn = len(gt) - tp return tp, fp, fn class WasteDetection(torch.utils.data.Dataset): def __init__(self, name: str = "moondream/waste_detection", split: str = "test"): self.ds = load_dataset(name, split=split) def __len__(self): return len(self.ds) def __getitem__(self, idx: int) -> Dict: s = self.ds[idx] img = ( s["image"] if isinstance(s["image"], Image.Image) else Image.fromarray(s["image"]) ) W, H = float(s.get("width", img.width)), float(s.get("height", img.height)) lbl_to_boxes = defaultdict(list) for (xc, yc, bw, bh), lbl in zip(s["boxes"], s["labels"]): x1 = xc - bw / 2 y1 = yc - bh / 2 x2 = xc + bw / 2 y2 = yc + bh / 2 lbl_to_boxes[lbl].append((x1, y1, x2, y2)) return {"image": img, "gt": lbl_to_boxes, "W": W, "H": H} def evaluate( model: MoondreamModel, iou_thr: float, debug: bool, ): ds = WasteDetection(split="test") TP = FP = FN = 0 for s in tqdm(ds, disable=debug, desc="Waste"): img, gts = s["image"], s["gt"] enc = model.encode_image(img) for lbl, gt_boxes in gts.items(): preds: List[Box] = [ ( o["x_min"], o["y_min"], o["x_max"], o["y_max"], ) for o in model.detect(enc, 
lbl)["objects"] ] tp, fp, fn = match(gt_boxes, preds, iou_thr) TP += tp FP += fp FN += fn prec = TP / (TP + FP) if TP + FP else 0.0 rec = TP / (TP + FN) if TP + FN else 0.0 f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0 return dict(precision=prec, recall=rec, f1=f1, tp=TP, fp=FP, fn=FN) def load_model(path: str, device: torch.device) -> MoondreamModel: cfg = MoondreamConfig() model = MoondreamModel(cfg) load_weights_into_model(path, model) model.compile() model.to(device) return model def main(): p = argparse.ArgumentParser() p.add_argument("--model", required=True) p.add_argument("--iou_thr", type=float, default=0.5) p.add_argument("--gpu", type=int, default=0) p.add_argument("--debug", action="store_true") args = p.parse_args() if torch.cuda.is_available(): torch.cuda.set_device(args.gpu) device = torch.device(f"cuda:{args.gpu}") elif torch.backends.mps.is_available(): device = torch.device("mps") else: device = torch.device("cpu") model = load_model(args.model, device) res = evaluate(model, args.iou_thr, args.debug) print(f"Precision: {res['precision']*100:.2f}%") print(f"Recall: {res['recall']*100:.2f}%") print(f"F1 Score: {res['f1']*100:.2f}%") print(f"TP: {res['tp']} FP: {res['fp']} FN: {res['fn']}") if __name__ == "__main__": """ Eval to accompany finetune_region.py. 
""" main() ================================================ FILE: moondream/torch/config.py ================================================ from dataclasses import dataclass, field from typing import Dict, List, Optional @dataclass(frozen=True) class TextMoeConfig: num_experts: int = 64 start_layer: int = 4 experts_per_token: int = 8 expert_inner_dim: int = 1024 @dataclass(frozen=True) class TextConfig: dim: int = 2048 ff_dim: int = 8192 n_layers: int = 24 vocab_size: int = 51200 max_context: int = 4096 n_heads: int = 32 n_kv_heads: int = 32 prefix_attn: int = 730 group_size: Optional[int] = None moe: Optional[TextMoeConfig] = TextMoeConfig() @dataclass(frozen=True) class VisionConfig: enc_dim: int = 1152 enc_patch_size: int = 14 enc_n_layers: int = 27 enc_ff_dim: int = 4304 enc_n_heads: int = 16 proj_out_dim: int = 2048 crop_size: int = 378 in_channels: int = 3 max_crops: int = 12 overlap_margin: int = 4 proj_inner_dim: int = 8192 @dataclass(frozen=True) class RegionConfig: dim: int = 2048 coord_feat_dim: int = 256 coord_out_dim: int = 1024 size_feat_dim: int = 512 size_out_dim: int = 2048 group_size: Optional[int] = None @dataclass(frozen=True) class TokenizerConfig: bos_id: int = 0 eos_id: int = 0 answer_id: int = 3 thinking_id: int = 4 coord_id: int = 5 size_id: int = 6 start_ground_points_id: int = 7 end_ground_id: int = 9 templates: Dict[str, Optional[Dict[str, List[int]]]] = field( default_factory=lambda: { "caption": { "short": [1, 32708, 2, 12492, 3], "normal": [1, 32708, 2, 6382, 3], "long": [1, 32708, 2, 4059, 3], }, "query": {"prefix": [1, 15381, 2], "suffix": [3]}, "detect": {"prefix": [1, 7235, 476, 2], "suffix": [3]}, "point": {"prefix": [1, 2581, 2], "suffix": [3]}, } ) @dataclass(frozen=True) class MoondreamConfig: text: TextConfig = TextConfig() vision: VisionConfig = VisionConfig() region: RegionConfig = RegionConfig() tokenizer: TokenizerConfig = TokenizerConfig() @classmethod def from_dict(cls, config_dict: dict): text_config = 
TextConfig(**config_dict.get("text", {})) vision_config = VisionConfig(**config_dict.get("vision", {})) region_config = RegionConfig(**config_dict.get("region", {})) tokenizer_config = TokenizerConfig(**config_dict.get("tokenizer", {})) return cls( text=text_config, vision=vision_config, region=region_config, tokenizer=tokenizer_config, ) def to_dict(self): return { "text": self.text.__dict__, "vision": self.vision.__dict__, "region": self.region.__dict__, "tokenizer": self.tokenizer.__dict__, } ================================================ FILE: moondream/torch/hf_moondream.py ================================================ import torch import torch.nn as nn from transformers import PreTrainedModel, PretrainedConfig from typing import Union from .config import MoondreamConfig from .moondream import MoondreamModel # Files sometimes don't get loaded without these... from .image_crops import * from .vision import * from .text import * from .region import * from .utils import * def extract_question(text): prefix = "\n\nQuestion: " suffix = "\n\nAnswer:" if text.startswith(prefix) and text.endswith(suffix): return text[len(prefix) : -len(suffix)] else: return None class HfConfig(PretrainedConfig): _auto_class = "AutoConfig" model_type = "moondream3" def __init__(self, **kwargs): super().__init__(**kwargs) self.config = {"skills": ["query", "caption", "detect", "point"]} class HfMoondream(PreTrainedModel): _auto_class = "AutoModelForCausalLM" config_class = HfConfig def __init__(self, config): super().__init__(config) self.model = MoondreamModel( MoondreamConfig.from_dict(config.config), setup_caches=False ) self._is_kv_cache_setup = False def _setup_caches(self): if not self._is_kv_cache_setup: self.model._setup_caches() self._is_kv_cache_setup = True @property def encode_image(self): self._setup_caches() return self.model.encode_image @property def query(self): self._setup_caches() return self.model.query @property def caption(self): self._setup_caches() return 
self.model.caption @property def detect(self): self._setup_caches() return self.model.detect @property def point(self): self._setup_caches() return self.model.point @property def detect_gaze(self): self._setup_caches() return self.model.detect_gaze def answer_question( self, image_embeds, question, tokenizer=None, chat_history="", result_queue=None, max_new_tokens=256, **kwargs ): answer = self.query(image_embeds, question)["answer"].strip() if result_queue is not None: result_queue.put(answer) return answer def batch_answer(self, images, prompts, tokenizer=None, **kwargs): answers = [] for image, prompt in zip(images, prompts): answers.append(self.query(image, prompt)["answer"].strip()) return answers def _unsupported_exception(self): raise NotImplementedError( "This method is not supported in the latest version of moondream. " "Consider upgrading to the updated API spec, or alternately pin " "to 'revision=2024-08-26'." ) def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs): """ Function definition remains unchanged for backwards compatibility. Be aware that tokenizer, max_new_takens, and kwargs are ignored. """ prompt_extracted = extract_question(prompt) if prompt_extracted is not None: answer = self.model.query( image=image_embeds, question=prompt_extracted, stream=False )["answer"] else: image_embeds = self.encode_image(image_embeds) prompt_tokens = torch.tensor( [self.model.tokenizer.encode(prompt).ids], device=self.device, ) def generator(): for token in self.model._generate_answer( prompt_tokens, image_embeds.kv_cache, image_embeds.pos, max_new_tokens, ): yield token answer = "".join(list(generator())) return [answer] def get_input_embeddings(self) -> nn.Embedding: """ Lazily wrap the raw parameter `self.model.text.wte` in a real `nn.Embedding` layer so that HF mix-ins recognise it. The wrapper **shares** the weight tensor—no copy is made. 
""" if not hasattr(self, "_input_embeddings"): self._input_embeddings = nn.Embedding.from_pretrained( self.model.text.wte, # tensor created in text.py freeze=True, # set to False if you need it trainable ) return self._input_embeddings def set_input_embeddings(self, value: Union[nn.Embedding, nn.Module]) -> None: """ Lets HF functions (e.g. `resize_token_embeddings`) replace or resize the embeddings and keeps everything tied to `self.model.text.wte`. """ # 1. point the low-level parameter to the new weight matrix self.model.text.wte = value.weight # 2. keep a reference for get_input_embeddings() self._input_embeddings = value def input_embeds( self, input_ids: Union[torch.LongTensor, list, tuple], *, device: torch.device | None = None ) -> torch.FloatTensor: """ Back-compat wrapper that turns token IDs into embeddings. Example: ids = torch.tensor([[1, 2, 3]]) embeds = model.input_embeds(ids) # (1, 3, hidden_dim) """ if not torch.is_tensor(input_ids): input_ids = torch.as_tensor(input_ids) if device is not None: input_ids = input_ids.to(device) return self.get_input_embeddings()(input_ids) ================================================ FILE: moondream/torch/hf_release.py ================================================ import torch import argparse from .weights import load_weights_into_model from .hf_moondream import HfConfig, HfMoondream if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model-name", type=str, default="vikhyatk/moondream-next") parser.add_argument("--ckpt", type=str, required=True) args = parser.parse_args() config = HfConfig() model = HfMoondream(config) load_weights_into_model(args.ckpt, model.model) model.push_to_hub(args.model_name, config=config) ================================================ FILE: moondream/torch/image_crops.py ================================================ import math import numpy as np import torch from typing import TypedDict try: import pyvips HAS_VIPS = True except: from PIL import 
Image

    HAS_VIPS = False


def select_tiling(
    height: int, width: int, crop_size: int, max_crops: int
) -> tuple[int, int]:
    """
    Determine the optimal number of tiles to cover an image with overlapping crops.

    Returns ``(h_tiles, w_tiles)``, each at least 1. If either side already
    fits within ``crop_size``, a single ``(1, 1)`` tiling is returned.
    """
    if height <= crop_size or width <= crop_size:
        return (1, 1)

    # Minimum required tiles in each dimension
    min_h = math.ceil(height / crop_size)
    min_w = math.ceil(width / crop_size)

    # If minimum required tiles exceed max_crops, return proportional distribution
    if min_h * min_w > max_crops:
        ratio = math.sqrt(max_crops / (min_h * min_w))
        return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))

    # Perfect aspect-ratio tiles that satisfy max_crops
    h_tiles = math.floor(math.sqrt(max_crops * height / width))
    w_tiles = math.floor(math.sqrt(max_crops * width / height))

    # Ensure we meet minimum tile requirements
    h_tiles = max(h_tiles, min_h)
    w_tiles = max(w_tiles, min_w)

    # If we exceeded max_crops, scale down the larger dimension
    if h_tiles * w_tiles > max_crops:
        if w_tiles > h_tiles:
            w_tiles = math.floor(max_crops / h_tiles)
        else:
            h_tiles = math.floor(max_crops / w_tiles)

    return (max(1, h_tiles), max(1, w_tiles))


class OverlapCropOutput(TypedDict):
    # uint8 array of shape (1 + h_tiles * w_tiles, base_h, base_w, C); index 0
    # is the global crop.
    crops: np.ndarray
    # (h_tiles, w_tiles) as returned by select_tiling.
    tiling: tuple[int, int]


def overlap_crop_image(
    image: np.ndarray,
    overlap_margin: int,
    max_crops: int,
    base_size: tuple[int, int] = (378, 378),
    patch_size: int = 14,
) -> OverlapCropOutput:
    """
    Process an image using an overlap-and-resize cropping strategy with margin handling.

    The image is resized so that an ``h_tiles x w_tiles`` grid of crop windows,
    plus a margin of ``overlap_margin`` patches on each outer edge, covers it.
    Every local crop is ``base_size`` pixels, so horizontally/vertically
    adjacent crops overlap by ``2 * overlap_margin`` patches. It produces:
    1. A single global crop: the original (un-tiled) image resized to base_size
    2. The overlapping local crops, in row-major tile order

    Args:
        image (np.ndarray): Input image as numpy array with shape (H,W,C)
        overlap_margin (int): Margin size in patch units (no default)
        max_crops (int): Tile budget forwarded to select_tiling (no default)
        base_size (tuple[int,int]): Target size for crops, default (378,378)
        patch_size (int): Size of patches in pixels, default 14

    Returns:
        OverlapCropOutput: Dictionary containing:
            - crops: uint8 array of shape (1 + h_tiles*w_tiles, base_size[0],
              base_size[1], C). The global crop is at index 0 and tile (i, j)
              at index 1 + i * w_tiles + j. Any part of a local crop that falls
              outside the resized image stays zero.
            - tiling: Tuple of (height,width) tile counts
    """
    original_h, original_w = image.shape[:2]

    # Convert margin from patch units to pixels
    margin_pixels = patch_size * overlap_margin
    total_margin_pixels = margin_pixels * 2  # Both sides

    # Calculate crop parameters
    crop_patches = base_size[0] // patch_size  # patches per crop dimension
    crop_window_patches = crop_patches - (2 * overlap_margin)  # usable patches
    crop_window_size = crop_window_patches * patch_size  # usable size in pixels

    # Determine tiling
    tiling = select_tiling(
        original_h - total_margin_pixels,
        original_w - total_margin_pixels,
        crop_window_size,
        max_crops,
    )

    # Pre-allocate crops.
    n_crops = tiling[0] * tiling[1] + 1  # 1 = global crop
    crops = np.zeros(
        (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
    )

    # Resize image to fit tiling
    target_size = (
        tiling[0] * crop_window_size + total_margin_pixels,
        tiling[1] * crop_window_size + total_margin_pixels,
    )

    if HAS_VIPS:
        # Convert to vips for resizing
        vips_image = pyvips.Image.new_from_array(image)
        scale_x = target_size[1] / image.shape[1]
        scale_y = target_size[0] / image.shape[0]
        resized = vips_image.resize(scale_x, vscale=scale_y)
        image = resized.numpy()

        # Create global crop (from the original image, not the tiled resize)
        scale_x = base_size[1] / vips_image.width
        scale_y = base_size[0] / vips_image.height
        global_vips = vips_image.resize(scale_x, vscale=scale_y)
        crops[0] = global_vips.numpy()
    else:
        # Fallback to PIL
        pil_img = Image.fromarray(image)
        resized = pil_img.resize(
            (int(target_size[1]), int(target_size[0])),
            resample=Image.Resampling.LANCZOS,
        )
        image = np.asarray(resized)

        # Create global crop (from the original image, not the tiled resize)
        global_pil = pil_img.resize(
            (int(base_size[1]), int(base_size[0])), resample=Image.Resampling.LANCZOS
        )
        crops[0] = np.asarray(global_pil)

    for i in range(tiling[0]):
        for j in range(tiling[1]):
            # Calculate crop coordinates
            y0 = i * crop_window_size
            x0 = j * crop_window_size

            # Extract crop with padding if needed
            y_end = min(y0 + base_size[0], image.shape[0])
            x_end = min(x0 + base_size[1], image.shape[1])

            crop_region = image[y0:y_end, x0:x_end]
            # Copy into the zero-initialised slot; a short region leaves zeros.
            crops[
                1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
            ] = crop_region

    return {"crops": crops, "tiling": tiling}


def reconstruct_from_crops(
    crops: torch.Tensor,
    tiling: tuple[int, int],
    overlap_margin: int,
    patch_size: int = 14,
) -> torch.Tensor:
    """
    Reconstruct the original image from overlapping crops into a single seamless image.

    Crops are placed in row-major tile order. Interior margins are dropped so
    each output region comes from exactly one crop; only the outer edges of the
    grid keep their margins. Requires torch tensors: the output is allocated
    with ``torch.zeros`` on ``crops[0].device`` (numpy arrays are not supported).

    Args:
        crops: Tensor indexed along dim 0 by tile (``h_tiles * w_tiles``
            entries), each crop with shape (H,W,C)
        tiling: Tuple of (height,width) indicating crop grid layout
        overlap_margin: Number of overlapping patches on each edge (no default)
        patch_size: Size in pixels of each patch, default 14 (callers working
            on patch-feature grids may pass 1)

    Returns:
        Tensor of shape (out_H, out_W, C) with crops[0]'s dtype and device,
        where out_H = (H - 2*m) * h_tiles + 2*m and likewise for out_W, with
        m = overlap_margin * patch_size.
    """
    tiling_h, tiling_w = tiling
    crop_height, crop_width = crops[0].shape[:2]
    margin_pixels = overlap_margin * patch_size

    # Calculate output size (only adding margins once)
    output_h = (crop_height - 2 * margin_pixels) * tiling_h + 2 * margin_pixels
    output_w = (crop_width - 2 * margin_pixels) * tiling_w + 2 * margin_pixels

    reconstructed = torch.zeros(
        (output_h, output_w, crops[0].shape[2]),
        device=crops[0].device,
        dtype=crops[0].dtype,
    )

    for i, crop in enumerate(crops):
        tile_y = i // tiling_w
        tile_x = i % tiling_w

        # For each tile, determine which part to keep
        # Keep left margin only for first column
        x_start = 0 if tile_x == 0 else margin_pixels
        # Keep right margin only for last column
        x_end = crop_width if tile_x == tiling_w - 1 else crop_width - margin_pixels
        # Keep top margin only for first row
        y_start = 0 if tile_y == 0 else margin_pixels
        # Keep bottom margin only for last row
        y_end = crop_height if tile_y == tiling_h - 1 else crop_height - margin_pixels

        # Calculate where this piece belongs in the output
        out_x = tile_x * (crop_width - 2 * margin_pixels)
        out_y = tile_y * (crop_height - 2 * margin_pixels)

        # Place the piece
        reconstructed[
            out_y + y_start : out_y + y_end, out_x + x_start : out_x + x_end
        ] = crop[y_start:y_end, x_start:x_end]

    return reconstructed


================================================
FILE: moondream/torch/layers.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from typing import Literal, Optional

try:
    from torchao import quantize_
    from torchao.quantization import int4_weight_only
except ImportError:
    # torchao is optional: these stubs only fail when quantization is used.

    def quantize_(model, quant_mode):
        raise ImportError(
            "torchao is not installed. Please install it with `pip install torchao`."
        )

    def int4_weight_only(group_size):
        raise ImportError(
            "torchao is not installed. Please install it with `pip install torchao`."
        )


def gelu_approx(x):
    # GELU with the tanh approximation.
    return F.gelu(x, approximate="tanh")


@dataclass
class LinearWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def linear(x: torch.Tensor, w: LinearWeights) -> torch.Tensor:
    return F.linear(x, w.weight, w.bias)


def dequantize_tensor(W_q, scale, zero, orig_shape, dtype=torch.bfloat16):
    """Unpack 4-bit values from ``W_q`` and apply ``(q - zero) * scale``.

    Each uint8 holds two values: high nibbles fill the first ``W_q.shape[0]``
    rows of the result, low nibbles the following rows. ``scale``/``zero``
    broadcast per row (QuantizedLinear stores them as (rows, 1), i.e. one pair
    per 128 values). The result is reshaped to ``orig_shape``.
    """
    _step = W_q.shape[0]
    W_r = torch.empty([2 * _step, W_q.shape[1]], dtype=dtype, device=W_q.device)
    W_r[:_step] = (W_q & 0b11110000) >> 4
    W_r[_step:] = W_q & 0b00001111
    W_r.sub_(zero).mul_(scale)
    return W_r.reshape(orig_shape)


class QuantizedLinear(nn.Module):
    """Linear layer loaded from packed 4-bit weights.

    Holds packed uint8 weights plus per-group scale/zero-point (group size 128)
    until the first forward (or compile()), when unpack() dequantizes them to
    bfloat16, wraps them in an nn.Linear and re-quantizes that layer with
    torchao's int4 weight-only scheme.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dtype: torch.dtype,
    ):
        # TODO: Take group_size as an input instead of hardcoding it here.
        # NOTE: ``dtype`` is accepted so this class can stand in for nn.Linear
        # (see MoondreamModel's linear_cls) but is not used here.
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.ParameterDict(
            {
                "packed": nn.Parameter(
                    torch.empty(
                        out_features * in_features // (128 * 2), 128, dtype=torch.uint8
                    ),
                    requires_grad=False,
                ),
                "scale": nn.Parameter(
                    torch.empty(out_features * in_features // 128, 1),
                    requires_grad=False,
                ),
                "zero_point": nn.Parameter(
                    torch.empty(out_features * in_features // 128, 1),
                    requires_grad=False,
                ),
            }
        )
        self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
        self.unpacked = False

    def unpack(self):
        """Dequantize once and hand the layer to torchao; later calls are no-ops."""
        if self.unpacked:
            return

        self.weight = nn.Parameter(
            dequantize_tensor(
                self.weight["packed"],
                self.weight["scale"],
                self.weight["zero_point"],
                (self.out_features, self.in_features),
                torch.bfloat16,
            )
        )
        # Build the module on the meta device (no allocation), then attach the
        # real tensors below.
        with torch.device("meta"):
            self.linear = nn.Linear(
                self.in_features, self.out_features, dtype=torch.bfloat16
            )
        self.linear.weight = self.weight
        self.linear.bias = nn.Parameter(
            self.bias.to(torch.bfloat16), requires_grad=False
        )

        del self.weight, self.bias
        quantize_(self, int4_weight_only(group_size=128))
        self.unpacked = True
        torch.cuda.empty_cache()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.unpacked:
            self.unpack()
        return self.linear(x)


@dataclass
class LayerNormWeights:
    weight: torch.Tensor
    bias: torch.Tensor


def layer_norm(x: torch.Tensor, w: LayerNormWeights) -> torch.Tensor:
    # Normalises over the trailing dimension(s) matching w.bias's shape.
    return F.layer_norm(x, w.bias.shape, w.weight, w.bias)


@dataclass
class MLPWeights:
    fc1: LinearWeights
    fc2: LinearWeights
    act: Literal["gelu_approx"] = "gelu_approx"


def mlp(x: torch.Tensor, w: MLPWeights, lora: Optional[dict] = None) -> torch.Tensor:
    """fc1 -> tanh-approximate GELU -> fc2, each optionally plus a LoRA delta
    ``F.linear(F.linear(x, A), B)`` from ``lora["fc1"|"fc2"]["A"|"B"]``.

    NOTE(review): ``w.fc1``/``w.fc2`` are *called*, so callers evidently pass
    modules (e.g. nn.Linear) rather than the LinearWeights dataclass that the
    MLPWeights annotation suggests.
    """
    x0 = w.fc1(x)
    if lora is not None:
        x1 = F.linear(F.linear(x, lora["fc1"]["A"]), lora["fc1"]["B"])
        x = x0 + x1
    else:
        x = x0

    x = gelu_approx(x)

    x0 = w.fc2(x)
    if lora is not None:
        x1 = F.linear(F.linear(x, lora["fc2"]["A"]), lora["fc2"]["B"])
        x = x0 + x1
    else:
        x = x0

    return x


def moe_mlp(
    x: torch.Tensor, mlp_module: nn.Module, experts_per_token: int
) -> torch.Tensor:
    """Top-k mixture-of-experts MLP.

    Each token is routed to ``experts_per_token`` experts chosen by
    ``mlp_module.router``; their outputs are weighted by a softmax over the
    selected router logits. Experts use a GeGLU-style activation
    ``gelu(h) * (g + 1)``. When T == 1 all token/expert pairs are batched with
    bmm; otherwise the code loops over experts and scatters results back with
    index_add_. ``x`` is unpacked as (B, T, C).
    """
    B, T, C = x.shape
    x = x.reshape(-1, C)

    # Router computation
    router_logits = mlp_module.router(x)
    topk_logits, topk_idxs = torch.topk(router_logits, experts_per_token, dim=-1)
    topk_weights = F.softmax(topk_logits, dim=-1, dtype=torch.float32).to(x.dtype)
    num_tokens, top_k = topk_idxs.shape

    if T == 1:
        w1_weight = mlp_module.fc1.weight
        w2_weight = mlp_module.fc2.weight

        # Flatten to process all token-expert pairs at once
        flat_idxs = topk_idxs.view(-1)  # [T*A]
        flat_weights = topk_weights.view(-1)  # [T*A]

        # Select expert weights
        w1_selected = w1_weight[flat_idxs]  # [T*A, H, D]
        w2_selected = w2_weight[flat_idxs]  # [T*A, D, H]

        # Expand input for all token-expert pairs
        x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, C)  # [T*A, D]

        # First linear layer with GeGLU: [T*A, H, D] @ [T*A, D, 1] -> [T*A, H]
        x1_full = torch.bmm(w1_selected, x_expanded.unsqueeze(-1)).squeeze(
            -1
        )  # [T*A, H]
        x1, g = x1_full.chunk(2, dim=-1)
        x1 = F.gelu(x1) * (g + 1)

        # Second linear layer: [T*A, D, H] @ [T*A, H, 1] -> [T*A, D]
        expert_outs = torch.bmm(w2_selected, x1.unsqueeze(-1)).squeeze(-1)  # [T*A, D]

        # Apply weights and reshape
        weighted_outs = expert_outs * flat_weights.unsqueeze(-1)  # [T*A, D]
        weighted_outs = weighted_outs.view(num_tokens, top_k, C)  # [T, A, D]

        # Sum over experts
        mlp_out = weighted_outs.sum(dim=1)  # [T, D]
        mlp_out = mlp_out.view(B, T, C)

        return mlp_out
    else:
        out = x.new_zeros(x.size())

        for expert_id in range(mlp_module.fc1.weight.shape[0]):
            # Tokens (and their top-k slot) that selected this expert.
            token_pos, which_k = (topk_idxs == expert_id).nonzero(as_tuple=True)
            if token_pos.numel() == 0:
                continue

            x_tok = x.index_select(0, token_pos)
            gate_tok = topk_weights[token_pos, which_k]

            h_full = F.linear(x_tok, mlp_module.fc1.weight[expert_id])
            h, g = h_full.chunk(2, dim=-1)
            h = F.gelu(h) * (g + 1)
            y = F.linear(h, mlp_module.fc2.weight[expert_id])

            y.mul_(gate_tok.unsqueeze(-1))
            out.index_add_(0, token_pos, y)

        return out.view(B, T, C)


@dataclass
class AttentionWeights:
    qkv: LinearWeights
    proj: LinearWeights


def attn(x: torch.Tensor,
w: AttentionWeights, n_heads: int) -> torch.Tensor: bsz, q_len, d_model = x.shape head_dim = d_model // n_heads q, k, v = [ t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) for t in linear(x, w.qkv).chunk(3, dim=-1) ] out = F.scaled_dot_product_attention(q, k, v) out = out.transpose(1, 2).reshape(bsz, q_len, d_model) out = linear(out, w.proj) return out ================================================ FILE: moondream/torch/lora.py ================================================ import functools import os import shutil import torch from pathlib import Path from urllib.request import Request, urlopen from typing import Optional def variant_cache_dir(): hf_hub_cache = os.environ.get("HF_HUB_CACHE") if hf_hub_cache is not None: return Path(hf_hub_cache) / "md_variants" hf_home = os.environ.get("HF_HOME") if hf_home is not None: return Path(hf_home) / "hub" / "md_variants" return Path("~/.cache/huggingface/hub").expanduser() / "md_variants" def cached_variant_path(variant_id: str): variant, *rest = variant_id.split("/", 1) step = rest[0] if rest else "final" cache_dir = variant_cache_dir() / variant os.makedirs(cache_dir, exist_ok=True) dest = cache_dir / f"{step}.pt" if dest.exists(): return dest md_endpoint = os.getenv("MOONDREAM_ENDPOINT", "https://api.moondream.ai") headers = {"User-Agent": "moondream-torch"} api_key = os.getenv("MOONDREAM_API_KEY") if api_key is not None: headers["X-Moondream-Auth"] = api_key req = Request(f"{md_endpoint}/v1/variants/{variant_id}/download", headers=headers) with urlopen(req) as r, open(dest, "wb") as f: shutil.copyfileobj(r, f) return dest def nest(flat): tree = {} for k, v in flat.items(): parts = k.split(".") d = tree for p in parts[:-1]: d = d.setdefault(p, {}) d[parts[-1]] = v return tree @functools.lru_cache(maxsize=5) def variant_state_dict(variant_id: Optional[str] = None, device: str = "cpu"): if variant_id is None: return None state_dict = torch.load( cached_variant_path(variant_id), map_location=device, 
weights_only=True ) # TODO: Move these into the training code that saves checkpoints... rename_rules = [ ("text_model.transformer.h", "text.blocks"), (".mixer", ".attn"), (".out_proj", ".proj"), (".Wqkv", ".qkv"), (".parametrizations.weight.0", ""), ] new_state_dict = {} for key, tensor in state_dict.items(): new_key = key for old, new in rename_rules: if old in new_key: new_key = new_key.replace(old, new) new_state_dict[new_key] = tensor return nest(new_state_dict) ================================================ FILE: moondream/torch/moondream.py ================================================ import torch import torch.nn as nn import random from typing import Literal, Tuple, TypedDict, Union, Dict, Any, Optional, List from PIL import Image from dataclasses import dataclass from tokenizers import Tokenizer from torch.nn.attention.flex_attention import create_block_mask from .config import MoondreamConfig from .image_crops import reconstruct_from_crops from .vision import vision_encoder, vision_projection, prepare_crops, build_vision_model from .text import build_text_model, text_encoder, lm_head, text_decoder from .region import ( decode_coordinate, encode_coordinate, decode_size, encode_size, encode_spatial_refs, SpatialRefs, ) from .layers import QuantizedLinear from .lora import variant_state_dict from .utils import remove_outlier_points ImageEncodingSettings = TypedDict( "ImageEncodingSettings", {"variant": str}, total=False, ) TextSamplingSettings = TypedDict( "TextSamplingSettings", { "max_tokens": int, "temperature": float, "top_p": float, "variant": str, }, total=False, ) ObjectSamplingSettings = TypedDict( "ObjectSamplingSettings", {"max_objects": int, "variant": str}, total=False, ) DEFAULT_MAX_TOKENS = 768 DEFAULT_TEMPERATURE = 0.5 DEFAULT_TOP_P = 0.9 DEFAULT_MAX_OBJECTS = 150 @dataclass(frozen=True) class EncodedImage: pos: int caches: List[Tuple[torch.Tensor, torch.Tensor]] class KVCache(nn.Module): def __init__(self, n_heads, n_kv_heads, 
max_context, dim, device, dtype):
        super().__init__()
        cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
        self.register_buffer(
            "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
        )
        self.register_buffer(
            "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
        )

    def update(self, pos_ids, k, v):
        """Write ``k``/``v`` at sequence positions ``pos_ids`` in place and
        return the full (k, v) cache buffers."""
        kout, vout = self.k_cache, self.v_cache
        kout[:, :, pos_ids, :] = k
        vout[:, :, pos_ids, :] = v
        return kout, vout


def causal_mask(b, h, q_idx, kv_idx):
    # flex_attention mask_mod: a query may attend to its own and earlier positions.
    return q_idx >= kv_idx


def get_mask_mod(mask_mod, offset):
    """Wrap ``mask_mod`` so query indices are shifted by ``offset``.

    _decode_one_tok uses this when a one-row mask slice (local query index 0)
    represents absolute position ``offset``.
    """

    def _mask_mod(b, h, q, kv):
        return mask_mod(b, h, q + offset, kv)

    return _mask_mod


class MoondreamModel(nn.Module):
    def __init__(
        self, config: MoondreamConfig, dtype=torch.bfloat16, setup_caches=True
    ):
        super().__init__()
        self.config = config

        self.tokenizer = Tokenizer.from_pretrained("moondream/starmie-v1")
        self.vision = build_vision_model(config.vision, dtype)
        self.text = build_text_model(config.text, dtype)

        # Region Model
        linear_cls = (
            QuantizedLinear if config.region.group_size is not None else nn.Linear
        )
        self.region = nn.ModuleDict(
            {
                "coord_encoder": linear_cls(
                    config.region.coord_feat_dim, config.region.dim, dtype=dtype
                ),
                "coord_decoder": linear_cls(
                    config.region.dim, config.region.coord_out_dim, dtype=dtype
                ),
                "size_encoder": linear_cls(
                    config.region.size_feat_dim, config.region.dim, dtype=dtype
                ),
                "size_decoder": linear_cls(
                    config.region.dim, config.region.size_out_dim, dtype=dtype
                ),
            }
        )
        self.region.coord_features = nn.Parameter(
            torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
        )
        self.region.size_features = nn.Parameter(
            torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
        )

        # Causal mask, except that the first 1 + patch_w**2 positions (BOS plus
        # image tokens, as laid out by encode_image) attend to each other fully.
        attn_mask = torch.tril(
            torch.ones(
                1, 1, config.text.max_context, config.text.max_context, dtype=torch.bool
            )
        )
        patch_w = config.vision.crop_size // config.vision.enc_patch_size
        prefix_attn_len = 1 + patch_w**2
        attn_mask[..., :prefix_attn_len, :prefix_attn_len] = 1
        self.register_buffer("attn_mask", attn_mask, persistent=False)

        self.use_flex_decoding = True
        self._causal_block_mask = None
        self._point_gen_indices = None

        # Initialize KV caches.
        if setup_caches:
            self._setup_caches()

    @property
    def causal_block_mask(self):
        """Lazily built flex-attention block mask for a causal
        max_context x max_context pattern."""
        # The things we do to deal with ZeroGPU...
        if self._causal_block_mask is None:
            self._causal_block_mask = create_block_mask(
                causal_mask,
                B=None,
                H=None,
                Q_LEN=self.config.text.max_context,
                KV_LEN=self.config.text.max_context,
            )
        return self._causal_block_mask

    @property
    def point_gen_indices(self):
        """Lazily built tensor ``[coord_id, eos_id]``.

        NOTE: cached on first access, so it stays on whatever device the model
        was on at that moment.
        """
        if self._point_gen_indices is None:
            self._point_gen_indices = torch.tensor(
                [self.config.tokenizer.coord_id, self.config.tokenizer.eos_id],
                device=self.device,
            )
        return self._point_gen_indices

    def _setup_caches(self):
        # Attach a fresh KVCache to every text block, on the model's device and
        # with the vision positional embedding's dtype.
        c = self.config.text
        for b in self.text.blocks:
            b.kv_cache = KVCache(
                c.n_heads,
                c.n_kv_heads,
                c.max_context,
                c.dim,
                device=self.device,
                dtype=self.vision.pos_emb.dtype,
            )

    @property
    def device(self):
        # The vision positional embedding's device stands in for the model's.
        return self.vision.pos_emb.device

    def _vis_enc(self, x: torch.Tensor):
        # Thin wrapper so compile() can replace it with a compiled version.
        return vision_encoder(x, self.vision, self.config.vision)

    def _vis_proj(self, g: torch.Tensor, r: torch.Tensor):
        return vision_projection(g, r, self.vision, self.config.vision)

    def _prefill(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        pos_ids: torch.Tensor,
        lora: Optional[torch.Tensor],
    ):
        # Multi-token text decoder pass without a flex block-mask slice.
        return text_decoder(x, self.text, attn_mask, pos_ids, self.config.text, lora)

    def _decode_one_tok(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        pos_ids: torch.Tensor,
        lora: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
    ):
        """Run the text decoder for a single position and return (logits, hidden).

        With flex decoding, the precomputed causal block mask is sliced to the
        block row containing ``pos_ids`` and its mask_mod is offset so local
        query 0 maps to absolute position ``pos_ids[0]``. ``lm_head_indices``
        is forwarded to lm_head.
        """
        if self.use_flex_decoding:
            torch._assert(pos_ids.shape[-1] == 1, "Invalid position ID shape")
            block_index = pos_ids // self.causal_block_mask.BLOCK_SIZE[0]
            mask = self.causal_block_mask[:, :, block_index]
            mask.seq_lengths = (1, mask.seq_lengths[1])
            mask.mask_mod = get_mask_mod(self.causal_block_mask.mask_mod, pos_ids[0])
        else:
            mask = None

        hidden = text_decoder(
            x,
            self.text,
            attn_mask,
            pos_ids,
            self.config.text,
            lora=lora,
            flex_block_mask_slice=mask,
        )
        logits = \
lm_head(hidden, self.text, indices=lm_head_indices) return logits, hidden def compile(self): for module in self.modules(): if isinstance(module, QuantizedLinear): module.unpack() # Initialize lazy properties to avoid first-call overhead self.causal_block_mask self.point_gen_indices # TODO: vision_projection and _prefill is not being compiled self._vis_enc = torch.compile(self._vis_enc, fullgraph=True) self._decode_one_tok = torch.compile( self._decode_one_tok, fullgraph=True, mode="reduce-overhead" ) # Warm up compiled methods with dummy forward passes device = self.device dtype = self.vision.pos_emb.dtype with torch.no_grad(): # Warmup vision encoder dummy_crops = torch.randn(1, 3, 378, 378, device=device, dtype=dtype) self._vis_enc(dummy_crops) # Warmup _decode_one_tok (both normal and point generation modes) dummy_emb = torch.randn( 1, 1, self.config.text.dim, device=device, dtype=dtype ) dummy_mask = torch.ones( 1, 1, self.config.text.max_context, device=device, dtype=torch.bool ) dummy_pos_ids = torch.tensor([100], device=device, dtype=torch.long) self._decode_one_tok(dummy_emb, dummy_mask, dummy_pos_ids, None) self._decode_one_tok( dummy_emb, dummy_mask, dummy_pos_ids, None, lm_head_indices=self.point_gen_indices, ) def _run_vision_encoder(self, image: Image.Image) -> torch.Tensor: all_crops, tiling = prepare_crops(image, self.config.vision, device=self.device) torch._dynamo.mark_dynamic(all_crops, 0) outputs = self._vis_enc(all_crops) global_features = outputs[0] local_features = outputs[1:].view( -1, self.config.vision.enc_n_layers, self.config.vision.enc_n_layers, self.config.vision.enc_dim, ) reconstructed = reconstruct_from_crops( local_features, tiling, patch_size=1, overlap_margin=self.config.vision.overlap_margin, ) return self._vis_proj(global_features, reconstructed) def encode_image( self, image: Union[Image.Image, EncodedImage], settings: Optional[ImageEncodingSettings] = None, ) -> EncodedImage: if isinstance(image, EncodedImage): return image 
elif not isinstance(image, Image.Image): raise ValueError("image must be a PIL Image or EncodedImage") lora = ( variant_state_dict(settings["variant"], device=self.device) if settings is not None and "variant" in settings else None ) # Run through text model in addition to the vision encoder, to minimize # re-computation if multiple queries are performed on this image. with torch.inference_mode(): img_emb = self._run_vision_encoder(image) bos_emb = text_encoder( torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text, ) inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1) mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :] pos_ids = torch.arange( inputs_embeds.size(1), dtype=torch.long, device=self.device ) self._prefill(inputs_embeds, mask, pos_ids, lora) return EncodedImage( pos=inputs_embeds.size(1), caches=[ ( b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(), b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(), ) for b in self.text.blocks ], ) def _apply_top_p(self, probs: torch.Tensor, top_p: float): probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) probs_sum = torch.cumsum(probs_sort, dim=-1) mask = probs_sum - probs_sort > top_p probs_sort[mask] = 0.0 probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) next_probs = torch.zeros_like(probs) next_probs.scatter_(dim=-1, index=probs_idx, src=probs_sort) return next_probs def _prefill_prompt( self, prompt_tokens: torch.Tensor, pos: int, temperature: float, top_p: float, spatial_refs: Optional[SpatialRefs] = None, attn_mask: Optional[torch.Tensor] = None, lora: Optional[dict] = None, ): with torch.inference_mode(): prompt_emb = text_encoder(prompt_tokens, self.text) if spatial_refs: encoded_refs = encode_spatial_refs(spatial_refs, self.region) prompt_emb[prompt_tokens == self.config.tokenizer.coord_id] = ( encoded_refs["coords"] ) if encoded_refs["sizes"] is not None: prompt_emb[prompt_tokens == self.config.tokenizer.size_id] = ( 
encoded_refs["sizes"] ) torch._dynamo.mark_dynamic(prompt_emb, 1) if attn_mask is None: attn_mask = self.attn_mask mask = attn_mask[:, :, pos : pos + prompt_emb.size(1), :] pos_ids = torch.arange( pos, pos + prompt_emb.size(1), dtype=torch.long, device=self.device ) hidden_BC = self._prefill(prompt_emb, mask, pos_ids, lora) logits_BV = lm_head(hidden_BC, self.text) if temperature == 0: next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1) else: probs = torch.softmax(logits_BV / temperature, dim=-1) probs = self._apply_top_p(probs, top_p) next_token = torch.multinomial(probs, num_samples=1) pos = pos + prompt_emb.size(1) return logits_BV, hidden_BC, next_token, pos def _generate_reasoning( self, prompt_tokens, pos, settings: Optional[TextSamplingSettings] = None, spatial_refs: Optional[SpatialRefs] = None, attn_mask: Optional[torch.Tensor] = None, ) -> Tuple[int, str, List[dict]]: max_tokens = ( settings.get("max_tokens", DEFAULT_MAX_TOKENS) if settings else DEFAULT_MAX_TOKENS ) temperature = ( settings.get("temperature", DEFAULT_TEMPERATURE) if settings else DEFAULT_TEMPERATURE ) lora = ( variant_state_dict(settings["variant"], device=self.device) if settings is not None and "variant" in settings else None ) top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P eos_id = self.config.tokenizer.answer_id _, last_hidden_BC, next_token, pos = self._prefill_prompt( prompt_tokens, pos, temperature, top_p, spatial_refs, attn_mask=attn_mask, lora=lora, ) text_token_chunks = [[]] grounding_chunks = [[]] mask = torch.zeros( 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool ) mask[:, :, :pos] = 1 pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long) generated_tokens = 0 while ( next_token_id := next_token.item() ) != eos_id and generated_tokens < max_tokens: if ( next_token_id == self.config.tokenizer.start_ground_points_id or next_token_id == self.config.tokenizer.end_ground_id ): text_token_chunks.append([]) 
# Continuation of _generate_reasoning's decode loop: a grounding boundary token
# starts a new text chunk and a new grounding chunk.
                grounding_chunks.append([])

            text_token_chunks[-1].append(next_token_id)

            with torch.inference_mode():
                if next_token_id == self.config.tokenizer.coord_id:
                    # Coordinates come from the region head, not the lm_head;
                    # argmax over bins / bin count gives a normalized value.
                    coord_logits = decode_coordinate(last_hidden_BC, self.region)
                    coord = torch.argmax(coord_logits, dim=-1) / coord_logits.size(-1)
                    grounding_chunks[-1].append(coord.item())

                    next_emb = encode_coordinate(
                        coord.to(dtype=coord_logits.dtype), self.region
                    ).unsqueeze(0)
                else:
                    next_emb = text_encoder(next_token, self.text)

                mask[:, :, pos], pos_ids[0] = 1, pos

                logits_BV, last_hidden_BC = self._decode_one_tok(
                    next_emb, mask, pos_ids, lora
                )
                # eos and size tokens are not allowed inside the reasoning trace.
                logits_BV[:, self.config.tokenizer.eos_id] = float("-inf")
                logits_BV[:, self.config.tokenizer.size_id] = float("-inf")

                pos += 1

                if temperature == 0:
                    next_token = torch.argmax(logits_BV, dim=-1).unsqueeze(1)  # (1, 1)
                else:
                    probs = torch.softmax(logits_BV / temperature, dim=-1)  # (1, V)
                    probs = self._apply_top_p(probs, top_p)
                    next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)

            generated_tokens += 1

        text_chunks = [
            self.tokenizer.decode(chunk_tokens) for chunk_tokens in text_token_chunks
        ]
        text = "".join(text_chunks)

        # Map each chunk's coordinates to (x, y) pairs anchored at the chunk's
        # character span in the joined text; a trailing unpaired value is dropped.
        start_idx = 0
        grounding = []
        for text_chunk, grounding_chunk in zip(text_chunks, grounding_chunks):
            if len(grounding_chunk) > 1:
                points = []
                for i in range(0, len(grounding_chunk) - (len(grounding_chunk) % 2), 2):
                    points.append((grounding_chunk[i], grounding_chunk[i + 1]))
                grounding.append(
                    {
                        "start_idx": start_idx,
                        "end_idx": start_idx + len(text_chunk),
                        "points": points,
                    }
                )
            start_idx += len(text_chunk)

        return pos, text, grounding

    def _generate_answer(
        self,
        prompt_tokens: torch.Tensor,
        pos: int,
        settings: Optional[TextSamplingSettings] = None,
        spatial_refs: Optional[SpatialRefs] = None,
        eos_id: Optional[int] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        """Prefill the prompt and return a generator that streams the answer.

        The prefill runs eagerly when this method is called. Decoding runs
        lazily as the generator is consumed, until ``eos_id`` (default: the
        tokenizer's eos) or ``max_tokens``. ``answer_id`` is masked out while
        sampling. Text is yielded in pieces:
        - the whole pending text after a newline;
        - immediately after a CJK character;
        - otherwise up to the last space.
        Any remainder is flushed at the end.
        """
        max_tokens = (
            settings.get("max_tokens", DEFAULT_MAX_TOKENS)
            if settings
            else DEFAULT_MAX_TOKENS
        )
        temperature = (
            settings.get("temperature", DEFAULT_TEMPERATURE)
            if settings
            else DEFAULT_TEMPERATURE
        )
top_p = settings.get("top_p", DEFAULT_TOP_P) if settings else DEFAULT_TOP_P eos_id = eos_id if eos_id is not None else self.config.tokenizer.eos_id lora = ( variant_state_dict(settings["variant"], device=self.device) if settings is not None and "variant" in settings else None ) _, _, next_token, pos = self._prefill_prompt( prompt_tokens, pos, temperature, top_p, spatial_refs, attn_mask=attn_mask, lora=lora, ) def generator(next_token, pos): mask = torch.zeros( 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool ) mask[:, :, :pos] = 1 pos_ids = torch.tensor([pos], device=self.device, dtype=torch.long) generated_tokens = 0 # For properly handling token streaming with Unicode token_cache = [] print_len = 0 while ( next_token_id := next_token.item() ) != eos_id and generated_tokens < max_tokens: # Add token to our cache token_cache.append(next_token_id) # Decode all tokens collected so far text = self.tokenizer.decode(token_cache) # After a newline, we flush the cache completely if text.endswith("\n"): printable_text = text[print_len:] token_cache = [] print_len = 0 if printable_text: yield printable_text # If the last token is a CJK character, we can safely print it elif len(text) > 0 and _is_cjk_char(ord(text[-1])): printable_text = text[print_len:] print_len += len(printable_text) if printable_text: yield printable_text # Otherwise, only yield up to the last space to avoid cutting words else: last_space_idx = text.rfind(" ", print_len) if last_space_idx >= print_len: printable_text = text[print_len : last_space_idx + 1] print_len += len(printable_text) if printable_text: yield printable_text with torch.inference_mode(): next_emb = text_encoder(next_token, self.text) mask[:, :, pos], pos_ids[0] = 1, pos logits_BV, _ = self._decode_one_tok(next_emb, mask, pos_ids, lora) logits_BV[:, self.config.tokenizer.answer_id] = float("-inf") pos += 1 if temperature == 0: next_token = torch.argmax(logits_BV, dim=-1).unsqueeze( 1 ) # (1, 1) else: probs = 
torch.softmax(logits_BV / temperature, dim=-1) # (1, V) probs = self._apply_top_p(probs, top_p) next_token = torch.multinomial(probs, num_samples=1) # (1, 1) generated_tokens += 1 # Flush any remaining text in the cache if token_cache: text = self.tokenizer.decode(token_cache) printable_text = text[print_len:] if printable_text: yield printable_text return generator(next_token, pos) def query( self, image: Optional[Union[Image.Image, EncodedImage]] = None, question: str = None, reasoning: bool = True, spatial_refs: Optional[SpatialRefs] = None, stream: bool = False, settings: Optional[TextSamplingSettings] = None, ): if self.config.tokenizer.templates["query"] is None: raise NotImplementedError("Model does not support querying.") if question is None: raise ValueError("question must be provided.") if spatial_refs and image is None: raise ValueError("spatial_refs can only be used with an image.") attn_mask = self.attn_mask if image is not None: image = self.encode_image(image, settings) self.load_encoded_image(image) pos = image.pos prompt_toks = self.config.tokenizer.templates["query"]["prefix"] else: self._setup_caches() pos = 0 prompt_toks = [ self.config.tokenizer.bos_id ] + self.config.tokenizer.templates["query"]["prefix"] max_context = self.config.text.max_context attn_mask = torch.tril( torch.ones(1, 1, max_context, max_context, dtype=torch.bool) ).to(self.device) spatial_toks = [] if spatial_refs: for ref in spatial_refs: coord_id = self.config.tokenizer.coord_id size_id = self.config.tokenizer.size_id if len(ref) == 2: spatial_toks.extend([coord_id, coord_id]) else: spatial_toks.extend([coord_id, coord_id, size_id]) prompt_tokens = [ prompt_toks + spatial_toks + self.tokenizer.encode(question).ids ] if reasoning: prompt_tokens[0] += [self.config.tokenizer.thinking_id] prompt_tokens = torch.tensor(prompt_tokens, device=self.device) pos, reasoning_text, reasoning_grounding = self._generate_reasoning( prompt_tokens, pos, settings, spatial_refs, 
attn_mask=attn_mask ) prompt_tokens = [self.config.tokenizer.templates["query"]["suffix"]] reasoning_dict = { "reasoning": {"text": reasoning_text, "grounding": reasoning_grounding} } else: prompt_tokens[0] += self.config.tokenizer.templates["query"]["suffix"] reasoning_dict = {} prompt_tokens = torch.tensor(prompt_tokens, device=self.device) def generator(): for token in self._generate_answer( prompt_tokens, pos, settings, spatial_refs, attn_mask=attn_mask ): yield token if stream: return {**reasoning_dict, "answer": generator()} else: return {**reasoning_dict, "answer": "".join(list(generator()))} def load_encoded_image(self, encoded_image: EncodedImage): for b, (k, v) in zip(self.text.blocks, encoded_image.caches): b.kv_cache.k_cache[:, :, : k.size(2), :] = k b.kv_cache.v_cache[:, :, : v.size(2), :] = v def caption( self, image: Union[Image.Image, EncodedImage], length: Literal["normal", "short", "long"] = "normal", stream: bool = False, settings: Optional[TextSamplingSettings] = None, ): if self.config.tokenizer.templates["caption"] is None: raise NotImplementedError("Model does not support captioning.") if length not in self.config.tokenizer.templates["caption"]: raise ValueError(f"Model does not support caption length '{length}'.") image = self.encode_image(image, settings) self.load_encoded_image(image) prompt_tokens = torch.tensor( [self.config.tokenizer.templates["caption"][length]], device=self.device ) def generator(): for token in self._generate_answer(prompt_tokens, image.pos, settings): yield token if stream: return {"caption": generator()} else: return {"caption": "".join(list(generator()))} def _generate_points( self, hidden: torch.Tensor, next_token: torch.Tensor, pos: int, include_size: bool = True, max_objects: int = DEFAULT_MAX_OBJECTS, lora: Optional[dict] = None, ): out = [] mask = torch.zeros( 1, 1, self.config.text.max_context, device=self.device, dtype=torch.bool ) mask[:, :, :pos] = 1 pos_ids = torch.tensor([pos], device=self.device, 
dtype=torch.long) with torch.inference_mode(): while ( next_token.item() != self.config.tokenizer.eos_id and len(out) < max_objects ): x_logits = decode_coordinate(hidden, self.region) x_center = torch.argmax(x_logits, dim=-1) / x_logits.size(-1) next_emb = encode_coordinate( x_center.to(dtype=x_logits.dtype), self.region ).unsqueeze(0) # Decode y-coordinate mask[:, :, pos], pos_ids[0] = 1, pos _, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora) pos += 1 y_logits = decode_coordinate(hidden, self.region) y_center = torch.argmax(y_logits, dim=-1) / y_logits.size(-1) next_emb = encode_coordinate( y_center.to(dtype=y_logits.dtype), self.region ).unsqueeze(0) # Decode size if include_size: mask[:, :, pos], pos_ids[0] = 1, pos logits, hidden = self._decode_one_tok(next_emb, mask, pos_ids, lora) pos += 1 size_logits = decode_size(hidden, self.region) # Get bin indices from the logits w_bin = torch.argmax(size_logits[0], dim=-1) h_bin = torch.argmax(size_logits[1], dim=-1) # Convert from bin indices to actual size values using the inverse of the log-scale mapping # Formula: size = 2^((bin / 1023.0) * 10.0 - 10.0) w = torch.pow(2.0, (w_bin.float() / 1023.0) * 10.0 - 10.0) h = torch.pow(2.0, (h_bin.float() / 1023.0) * 10.0 - 10.0) next_emb = ( encode_size( torch.tensor( [w, h], device=self.device, dtype=size_logits.dtype ), self.region, ) .unsqueeze(0) .unsqueeze(0) ) # Add object out.append( { "x_min": x_center.item() - w.item() / 2, "y_min": y_center.item() - h.item() / 2, "x_max": x_center.item() + w.item() / 2, "y_max": y_center.item() + h.item() / 2, } ) else: out.append({"x": x_center.item(), "y": y_center.item()}) # Decode next token (x-coordinate, or eos) mask[:, :, pos], pos_ids[0] = 1, pos logits, hidden = self._decode_one_tok( next_emb, mask, pos_ids, lora, lm_head_indices=self.point_gen_indices, ) pos += 1 # Map back: index 0 -> coord_id, index 1 -> eos_id next_token_idx = torch.argmax(logits, dim=-1) next_token = 
self.point_gen_indices[next_token_idx] return out def detect( self, image: Union[Image.Image, EncodedImage], object: str, settings: Optional[ObjectSamplingSettings] = None, ): if self.config.tokenizer.templates["detect"] is None: raise NotImplementedError("Model does not support object detection.") image = self.encode_image(image, settings) self.load_encoded_image(image) prompt_tokens = torch.tensor( [ self.config.tokenizer.templates["detect"]["prefix"] + self.tokenizer.encode(" " + object).ids + self.config.tokenizer.templates["detect"]["suffix"] ], device=self.device, ) lora = ( variant_state_dict(settings["variant"], device=self.device) if settings is not None and "variant" in settings else None ) _, hidden, next_token, pos = self._prefill_prompt( prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora ) hidden = hidden[:, -1:, :] max_objects = ( settings.get("max_objects", DEFAULT_MAX_OBJECTS) if settings else DEFAULT_MAX_OBJECTS ) objects = self._generate_points( hidden, next_token, pos, include_size=True, max_objects=max_objects, lora=lora, ) return {"objects": objects} def point( self, image: Union[Image.Image, EncodedImage], object: str, settings: Optional[ObjectSamplingSettings] = None, ): if self.config.tokenizer.templates["point"] is None: raise NotImplementedError("Model does not support pointing.") image = self.encode_image(image, settings) self.load_encoded_image(image) prompt_tokens = torch.tensor( [ self.config.tokenizer.templates["point"]["prefix"] + self.tokenizer.encode(" " + object).ids + self.config.tokenizer.templates["point"]["suffix"] ], device=self.device, ) lora = ( variant_state_dict(settings["variant"], device=self.device) if settings is not None and "variant" in settings else None ) _, hidden, next_token, pos = self._prefill_prompt( prompt_tokens, image.pos, temperature=0, top_p=0, lora=lora ) hidden = hidden[:, -1:, :] max_objects = ( settings.get("max_objects", DEFAULT_MAX_OBJECTS) if settings else DEFAULT_MAX_OBJECTS ) objects = 
self._generate_points( hidden, next_token, pos, include_size=False, max_objects=max_objects, lora=lora, ) return {"points": objects} def _detect_gaze( self, image: EncodedImage, source: Tuple[float, float], force_detect: bool = False, ): with torch.inference_mode(): before_emb = text_encoder( torch.tensor( [self.tokenizer.encode("\n\nPoint:").ids], device=self.device ), self.text, ) after_emb = text_encoder( torch.tensor( [self.tokenizer.encode(" gaze\n\n").ids], device=self.device ), self.text, ) x_emb = encode_coordinate( torch.tensor([[[source[0]]]], device=self.device, dtype=torch.bfloat16), self.region, ) y_emb = encode_coordinate( torch.tensor([[[source[1]]]], device=self.device, dtype=torch.bfloat16), self.region, ) prompt_emb = torch.cat([before_emb, x_emb, y_emb, after_emb], dim=1) self.load_encoded_image(image) mask = self.attn_mask[:, :, image.pos : image.pos + prompt_emb.size(1), :] pos_ids = torch.arange( image.pos, image.pos + prompt_emb.size(1), dtype=torch.long, device=self.device, ) hidden = self._prefill(prompt_emb, mask, pos_ids, lora=None) logits = lm_head(hidden, self.text) next_token = torch.argmax(logits, dim=-1) pos = image.pos + prompt_emb.size(1) hidden = hidden[:, -1:, :] if force_detect: next_token = torch.tensor([[0]], device=self.device) if next_token.item() == self.config.tokenizer.eos_id: return None gaze = self._generate_points( hidden, next_token, pos, include_size=False, max_objects=1 ) return gaze[0] def detect_gaze( self, image: Union[Image.Image, EncodedImage], eye: Optional[Tuple[float, float]] = None, face: Optional[Dict[str, float]] = None, unstable_settings: Dict[str, Any] = {}, ): if "force_detect" in unstable_settings: force_detect = unstable_settings["force_detect"] else: force_detect = False if "prioritize_accuracy" in unstable_settings: prioritize_accuracy = unstable_settings["prioritize_accuracy"] else: prioritize_accuracy = False if not prioritize_accuracy: if eye is None: raise ValueError("eye must be provided when 
prioritize_accuracy=False") image = self.encode_image(image) return {"gaze": self._detect_gaze(image, eye, force_detect=force_detect)} else: if ( not isinstance(image, Image.Image) and "flip_enc_img" not in unstable_settings ): raise ValueError( "image must be a PIL Image when prioritize_accuracy=True, " "or flip_enc_img must be provided" ) if face is None: raise ValueError("face must be provided when prioritize_accuracy=True") encoded_image = self.encode_image(image) if ( isinstance(image, Image.Image) and "flip_enc_img" not in unstable_settings ): flipped_pil = image.copy() flipped_pil = flipped_pil.transpose(method=Image.FLIP_LEFT_RIGHT) encoded_flipped_image = self.encode_image(flipped_pil) else: encoded_flipped_image = unstable_settings["flip_enc_img"] N = 10 detections = [ self._detect_gaze( encoded_image, ( random.uniform(face["x_min"], face["x_max"]), random.uniform(face["y_min"], face["y_max"]), ), force_detect=force_detect, ) for _ in range(N) ] detections = [ (gaze["x"], gaze["y"]) for gaze in detections if gaze is not None ] flipped_detections = [ self._detect_gaze( encoded_flipped_image, ( 1 - random.uniform(face["x_min"], face["x_max"]), random.uniform(face["y_min"], face["y_max"]), ), force_detect=force_detect, ) for _ in range(N) ] detections.extend( [ (1 - gaze["x"], gaze["y"]) for gaze in flipped_detections if gaze is not None ] ) if len(detections) < N: return {"gaze": None} detections = remove_outlier_points(detections) mean_gaze = ( sum(gaze[0] for gaze in detections) / len(detections), sum(gaze[1] for gaze in detections) / len(detections), ) return {"gaze": {"x": mean_gaze[0], "y": mean_gaze[1]}} def _is_cjk_char(cp): """Checks whether CP is the codepoint of a CJK character.""" # This defines a "chinese character" as anything in the CJK Unicode block: # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) if ( (cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF) or (cp >= 0x2F800 and cp <= 0x2FA1F) ): return True 
return False ================================================ FILE: moondream/torch/region.py ================================================ import torch import torch.nn as nn import math from typing import List, Tuple, Union SpatialRefs = List[Union[Tuple[float, float], Tuple[float, float, float, float]]] def fourier_features(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: """ Applies Fourier feature mapping to input tensor x using frequency matrix w. This projects inputs through sinusoidal functions to create higher dimensional features that help mitigate spectral bias - the tendency of neural networks to learn low-frequency functions more easily than high-frequency ones. By explicitly mapping inputs to higher frequencies through sin/cos transformations, we enable better learning of fine details and higher frequency patterns. Args: x: Input tensor to transform w: Matrix of frequencies for the Fourier features transformation Returns: Concatenated cosine and sine transformed features as a tensor """ f = 2 * math.pi * x @ w return torch.cat([f.cos(), f.sin()], dim=-1) def encode_coordinate(coord: torch.Tensor, w: nn.Module) -> torch.Tensor: """ Takes as input a tensor containing a single float coordinate value (x or y) and encodes it into hidden states for input to the text model. Args: coord: Tensor with single float coordinate value Returns: Encoded hidden states tensor for input to text model """ return w.coord_encoder(fourier_features(coord, w.coord_features)) def decode_coordinate(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor: """ Takes as input the last hidden state from the text model and outputs a single logit representing either an x or y coordinate prediction. Args: hidden_state: The final hidden state tensor from the text model. 
Returns: A single logit representing the predicted coordinate value (x or y) """ return w.coord_decoder(hidden_state) def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor: """ Takes a tensor containing width and height values and encodes them into hidden states for input to the text model. Args: size: Tensor with two floats for width and height Returns: Encoded hidden states tensor for input to text model """ return w.size_encoder(fourier_features(size, w.size_features)) def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor: """ Takes as input the last hidden state from the text model and outputs logits for 1024 bins representing width and height in log-scale. The bins are distributed according to the formula: bin = (log2(size) + 10.0) / 10.0 * 1023.0 where size values are clamped to be at least 1/1024. To convert from bin back to size: size = 2^((bin / 1023.0) * 10.0 - 10.0) Args: hidden_state: The final hidden state tensor from the text model. Returns: A tensor containing logits for 1024 bins for width and height. Shape is (2, 1024) where the first dimension corresponds to width and height. """ return w.size_decoder(hidden_state).view(2, -1) def encode_spatial_refs(spatial_refs: SpatialRefs, w: nn.Module) -> torch.Tensor: """ Takes a list of spatial references (points or regions) and encodes them into hidden states for input to the text model. 
Args: spatial_refs: List of spatial references (points or boxes) - Points are represented as normalized (x, y) tuples - Boxes are represented as normalized (x_min, y_min, x_max, y_max) tuples Returns: {"coords": torch.Tensor, "sizes": Optional[torch.Tensor]} """ coords, sizes = [], [] for ref in spatial_refs: if len(ref) == 2: coords.append(ref[0]) coords.append(ref[1]) else: x_c = (ref[0] + ref[2]) / 2 y_c = (ref[1] + ref[3]) / 2 width = ref[2] - ref[0] height = ref[3] - ref[1] coords.append(x_c) coords.append(y_c) sizes.append([width, height]) coords = torch.tensor( coords, device=w.coord_features.device, dtype=w.coord_features.dtype ).view(-1, 1) coords = encode_coordinate(coords, w) if sizes: sizes = torch.tensor( sizes, device=w.size_features.device, dtype=w.size_features.dtype ) sizes = encode_size(sizes, w) else: sizes = None return {"coords": coords, "sizes": sizes} ================================================ FILE: moondream/torch/rope.py ================================================ # Ethically sourced from https://github.com/xjdr-alt/entropix import torch def precompute_freqs_cis( dim: int, end: int, theta: float = 1500000.0, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=dtype)[: (dim // 2)] / dim)) t = torch.arange(end, dtype=dtype).unsqueeze(1) freqs = t * freqs.unsqueeze(0) freqs = torch.exp(1j * freqs) return torch.stack([freqs.real, freqs.imag], dim=-1) def apply_rotary_emb( x: torch.Tensor, freqs_cis: torch.Tensor, position_ids: torch.Tensor, num_heads: int, rot_dim: int = 32, interleave: bool = False, ) -> torch.Tensor: assert rot_dim == freqs_cis.shape[-2] * 2 assert num_heads == x.shape[1] x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:] if interleave: xq_r = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0] xq_i = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1] else: d_q = x_rot.shape[-1] // 2 xq_r, xq_i = x_rot[..., :d_q], x_rot[..., d_q:] freqs_cos = 
freqs_cis[..., 0][position_ids, :].unsqueeze(0).unsqueeze(0) freqs_sin = freqs_cis[..., 1][position_ids, :].unsqueeze(0).unsqueeze(0) # Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos xq_out = torch.stack((xq_out_r, xq_out_i), dim=-1).flatten(-2) return torch.cat([xq_out.to(x.dtype), x_pass], dim=-1) ================================================ FILE: moondream/torch/sample.py ================================================ import argparse import json import os import torch from PIL import Image, ImageDraw from tqdm import tqdm from .weights import load_weights_into_model from .moondream import MoondreamModel, MoondreamConfig if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--image", "-i", type=str, required=True) parser.add_argument("--prompt", "-p", type=str, required=True) parser.add_argument("--model", "-m", type=str, required=True) parser.add_argument("--config", "-c", type=str, default=None) parser.add_argument("--max-tokens", "-t", type=int, default=200) parser.add_argument("--sampler", "-s", type=str, default="greedy") parser.add_argument("--benchmark", "-b", action="store_true") args = parser.parse_args() if torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available(): device = "mps" print(f"Using device: {device}") # Load model. if args.config is not None: with open(args.config, "r") as f: config = json.load(f) config = MoondreamConfig.from_dict(config) else: config = MoondreamConfig() model = MoondreamModel(config) load_weights_into_model(args.model, model) model.to(device, dtype=torch.bfloat16) model.compile() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() torch.cuda.reset_accumulated_memory_stats() # Encode image. 
image_path = args.image if not os.path.exists(image_path): raise FileNotFoundError(f"Image not found at {image_path}") image = Image.open(image_path) if not args.benchmark: encoded_image = model.encode_image(image) # Text query text_query = "What is the capital of Washington, USA? Answer in JSON format." print("Query:", text_query) text_response = model.query(None, text_query, reasoning=True, stream=True) print("Reasoning:", text_response["reasoning"]) for t in text_response["answer"]: print(t, end="", flush=True) print() print() # Short caption print("Caption: short") for t in model.caption(encoded_image, "short", stream=True)["caption"]: print(t, end="", flush=True) print() print() # Regular caption print("Caption: normal") for t in model.caption(encoded_image, "normal", stream=True)["caption"]: print(t, end="", flush=True) print() print() # Long caption print("Caption: long") for t in model.caption(encoded_image, "long", stream=True)["caption"]: print(t, end="", flush=True) print() print() # Query print("Query:", args.prompt) for t in model.query( encoded_image, args.prompt, stream=True, settings={"variant": "geoguesser_lora_only"}, )["answer"]: print(t, end="", flush=True) print() print() # Query (reasoning) reasoning_prompt = "How many sesame seeds are on the burger?" 
print("Query (reasoning):", reasoning_prompt) resp = model.query(encoded_image, reasoning_prompt, reasoning=True, stream=True) print("Reasoning:", resp["reasoning"]) for t in resp["answer"]: print(t, end="", flush=True) print() print() # Detect obj = "hand" print(f"Detect: {obj}") objs = model.detect(encoded_image, obj)["objects"] print(f"Found {len(objs)}") print() draw = ImageDraw.Draw(image) for obj in objs: x_min, y_min, x_max, y_max = ( obj["x_min"] * image.width, obj["y_min"] * image.height, obj["x_max"] * image.width, obj["y_max"] * image.height, ) draw.rectangle([x_min, y_min, x_max, y_max], outline="red", width=2) image.save("detect.jpg") # Spatial query if len(objs) > 0: print("Spatial query: What is this?") for t in model.query( encoded_image, "What is this?", spatial_refs=[ [ (obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]) for obj in objs ][0] ], stream=True, )["answer"]: print(t, end="", flush=True) print() print() # Point obj = "ear" print(f"Point: {obj}") points = model.point(encoded_image, obj)["points"] print(f"Found {len(points)}") draw = ImageDraw.Draw(image) for point in points: x, y = point["x"] * image.width, point["y"] * image.height draw.ellipse([x - 5, y - 5, x + 5, y + 5], fill="red") image.save("point.jpg") print() print() # Spatial query if len(objs) > 0: for o in ["hand", "ear", "face"]: for k in [(objs, "hand"), (points, "ear")]: print(f"Spatial query: Is this a {o}? 
({k[1]})") for t in model.query( encoded_image, f"Is this a {o}?", spatial_refs=[ [ ( ( obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"], ) if "x_min" in obj else (obj["x"], obj["y"]) ) for obj in k[0] ][0] ], )["answer"]: print(t, end="", flush=True) print() # Detect gaze model.detect_gaze(encoded_image, (0.5, 0.5)) elif model.device.type != "mps": # Warmup runs for _ in tqdm(range(5), desc="Warmup"): encoded_image = model.encode_image(image) for _ in model.query(encoded_image, args.prompt, stream=True)["answer"]: pass # Benchmark runs encode_times = [] query_speeds = [] for i in tqdm(range(10), desc="Benchmark"): # Measure encode time start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() encoded_image = model.encode_image(image) end.record() torch.cuda.synchronize() encode_time = start.elapsed_time(end) encode_times.append(encode_time) # Measure query speed tokens = [] query_start = torch.cuda.Event(enable_timing=True) query_end = torch.cuda.Event(enable_timing=True) query_start.record() for t in model.query(encoded_image, args.prompt, stream=True)["answer"]: tokens.append(t) query_end.record() torch.cuda.synchronize() query_time = query_start.elapsed_time(query_end) tokens_per_sec = len(tokens) / (query_time / 1000.0) # Convert ms to s query_speeds.append(tokens_per_sec) # Print results print("\nBenchmark Results (10 runs):") print("Image Encoding Time (ms):") print(f" Mean: {sum(encode_times)/len(encode_times):.2f}") print(f" Min: {min(encode_times):.2f}") print(f" Max: {max(encode_times):.2f}") print("\nQuery Speed (tokens/sec):") print(f" Mean: {sum(query_speeds)/len(query_speeds):.2f}") print(f" Min: {min(query_speeds):.2f}") print(f" Max: {max(query_speeds):.2f}") print(torch.cuda.memory_summary(abbreviated=False)) ================================================ FILE: moondream/torch/text.py ================================================ import torch import torch.nn as nn from torch.nn import 
functional as F from torch.nn.attention.flex_attention import flex_attention from typing import Optional from .layers import layer_norm, mlp, QuantizedLinear, moe_mlp from .rope import apply_rotary_emb, precompute_freqs_cis from .config import TextConfig def text_encoder(input_ids: torch.Tensor, w: nn.Module): return F.embedding(input_ids, w.wte) def attn( x: torch.Tensor, w: nn.Module, freqs_cis: torch.Tensor, kv_cache: nn.Module, attn_mask: torch.Tensor, n_heads: int, n_kv_heads: int, position_ids: torch.Tensor, lora: Optional[dict] = None, flex_block_mask_slice=None, ): bsz, q_len, d_model = x.shape head_dim = d_model // n_heads qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim) if lora is not None: qkv_out += F.linear(F.linear(x, lora["qkv"]["A"]), lora["qkv"]["B"]) q_dim = n_heads * head_dim kv_dim = n_kv_heads * head_dim q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1) q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2) if hasattr(w, "tau") and w.tau is not None: tok_feat = F.gelu(qkv_out) tok_q = torch.tanh(torch.matmul(tok_feat, w.tau["wq"].t())).permute(0, 2, 1) tok_v = torch.tanh(torch.matmul(tok_feat, w.tau["wv"].t())).permute(0, 2, 1) pos = position_ids.to(q.dtype) + 1 tau_pos = 1 + ( torch.sigmoid(w.tau["alpha"][:, None] * pos.log()) - 0.5 ) # (H,S) tau_q = (tok_q + tau_pos[None]).unsqueeze(-1) # (B,H,S,1) tau_v = (tok_v + tau_pos[None]).unsqueeze(-1) q = q * tau_q v = v * tau_v q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads) k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads) if kv_cache is not None: k, v = kv_cache.update(position_ids, k, v) if flex_block_mask_slice is not None: torch._assert(n_heads == n_kv_heads, "gqa not supported yet") out = flex_attention(q, k, v, block_mask=flex_block_mask_slice) else: out = F.scaled_dot_product_attention( q, k, v, 
attn_mask=attn_mask, enable_gqa=n_heads != n_kv_heads ) out = out.transpose(1, 2).reshape(bsz, q_len, d_model) out0 = w.proj(out) if lora is not None: out1 = F.linear(F.linear(x, lora["proj"]["A"]), lora["proj"]["B"]) out = out0 + out1 else: out = out0 return out def text_decoder( x: torch.Tensor, w: nn.Module, attn_mask: torch.Tensor, position_ids: torch.Tensor, config: TextConfig, lora: Optional[dict] = None, flex_block_mask_slice=None, ): for i, block in enumerate(w.blocks): if lora is not None: layer_lora = lora["text"]["blocks"][str(i)] mlp_lora = layer_lora["mlp"] attn_lora = layer_lora["attn"] else: mlp_lora = None attn_lora = None l_in = layer_norm(x, block.ln) l_attn = attn( l_in, block.attn, freqs_cis=w.freqs_cis, kv_cache=block.kv_cache, attn_mask=attn_mask, n_heads=config.n_heads, n_kv_heads=config.n_kv_heads, position_ids=position_ids, lora=attn_lora, flex_block_mask_slice=flex_block_mask_slice, ) if config.moe is not None and i >= config.moe.start_layer: l_mlp = moe_mlp(l_in, block.mlp, config.moe.experts_per_token) else: l_mlp = mlp(l_in, block.mlp, lora=mlp_lora) x = x + l_attn + l_mlp return x def lm_head( hidden_BTC: torch.Tensor, w: nn.Module, indices: Optional[torch.Tensor] = None ): hidden_BC = hidden_BTC[:, -1, :] hidden_BC = layer_norm(hidden_BC, w.post_ln) if indices is not None: # Only compute logits for specified token indices logits = hidden_BC @ w.lm_head.weight[indices].T + w.lm_head.bias[indices] else: logits = w.lm_head(hidden_BC) return logits def build_dense_mlp(d_model, d_ffn, dtype, linear_cls): return nn.ModuleDict( { "fc1": linear_cls(d_model, d_ffn, dtype=dtype), "fc2": linear_cls(d_ffn, d_model, dtype=dtype), } ) def build_moe_mlp(d_model, d_ffn, n_experts, dtype): # For GeGLU, fc1 needs to output 2 * d_ffn (for gating) return nn.ModuleDict( { "router": nn.Linear(d_model, n_experts, dtype=dtype), "fc1": nn.ParameterDict( { "weight": nn.Parameter( torch.empty(n_experts, 2 * d_ffn, d_model, dtype=dtype) ) } ), "fc2": 
nn.ParameterDict( { "weight": nn.Parameter( torch.empty(n_experts, d_model, d_ffn, dtype=dtype) ) } ), } ) def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module: qkv_dim = int(config.dim * (1 + 2 * config.n_kv_heads / config.n_heads)) linear_cls = QuantizedLinear if config.group_size is not None else nn.Linear text = nn.ModuleDict( { "blocks": nn.ModuleList( [ nn.ModuleDict( { "ln": nn.LayerNorm(config.dim, dtype=dtype), "attn": nn.ModuleDict( { "qkv": linear_cls(config.dim, qkv_dim, dtype=dtype), "proj": linear_cls( config.dim, config.dim, dtype=dtype ), "tau": nn.ParameterDict( { "wq": nn.Parameter( torch.empty( config.n_heads, qkv_dim, dtype=dtype ) ), "wv": nn.Parameter( torch.empty( config.n_heads, qkv_dim, dtype=dtype ) ), "alpha": nn.Parameter( torch.empty(config.n_heads, dtype=dtype) ), } ), } ), "mlp": ( build_moe_mlp( config.dim, config.moe.expert_inner_dim, config.moe.num_experts, dtype, ) if config.moe is not None and layer_idx >= config.moe.start_layer else build_dense_mlp( config.dim, config.ff_dim, dtype, linear_cls ) ), } ) for layer_idx in range(config.n_layers) ] ), "post_ln": nn.LayerNorm(config.dim, dtype=dtype), "lm_head": nn.Linear(config.dim, config.vocab_size, dtype=dtype), } ) text.wte = nn.Parameter(torch.empty(config.vocab_size, config.dim, dtype=dtype)) text.register_buffer( "freqs_cis", precompute_freqs_cis(config.dim // (2 * config.n_heads), config.max_context), persistent=False, ) return text ================================================ FILE: moondream/torch/utils.py ================================================ import numpy as np def remove_outlier_points(points_tuples, k_nearest=2, threshold=2.0): """ Robust outlier detection for list of (x,y) tuples. Only requires numpy. 
Args: points_tuples: list of (x,y) tuples k_nearest: number of neighbors to consider threshold: multiplier for median distance Returns: list: filtered list of (x,y) tuples with outliers removed list: list of booleans indicating which points were kept (True = kept) """ points = np.array(points_tuples) n_points = len(points) # Calculate pairwise distances manually dist_matrix = np.zeros((n_points, n_points)) for i in range(n_points): for j in range(i + 1, n_points): # Euclidean distance between points i and j dist = np.sqrt(np.sum((points[i] - points[j]) ** 2)) dist_matrix[i, j] = dist dist_matrix[j, i] = dist # Get k nearest neighbors' distances k = min(k_nearest, n_points - 1) neighbor_distances = np.partition(dist_matrix, k, axis=1)[:, :k] avg_neighbor_dist = np.mean(neighbor_distances, axis=1) # Calculate mask using median distance median_dist = np.median(avg_neighbor_dist) mask = avg_neighbor_dist <= threshold * median_dist # Return filtered tuples and mask filtered_tuples = [t for t, m in zip(points_tuples, mask) if m] return filtered_tuples ================================================ FILE: moondream/torch/vision.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from typing import Union, Tuple from PIL import Image from .layers import attn, layer_norm, mlp from .image_crops import overlap_crop_image from .config import VisionConfig if torch.backends.mps.is_available(): # Non-divisible input sizes are not implemented on MPS device yet. 
# https://github.com/pytorch/pytorch/issues/96056 def adaptive_avg_pool2d(input, output_size): return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to("mps") else: adaptive_avg_pool2d = F.adaptive_avg_pool2d DeviceLike = Union[str, torch.device, int] def prepare_crops( image: Image.Image, config: VisionConfig, device: DeviceLike ) -> Tuple[torch.Tensor, Tuple[int, int]]: np_image = np.array(image.convert("RGB")) overlap_crops = overlap_crop_image( np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin ) all_crops = overlap_crops["crops"] all_crops = np.transpose(all_crops, (0, 3, 1, 2)) all_crops = ( torch.from_numpy(all_crops) .to(device=device, dtype=torch.bfloat16) .div_(255.0) .sub_(0.5) .div_(0.5) ) return all_crops, overlap_crops["tiling"] def create_patches(x, patch_size): # Original shape: [B, C, H, W] B, C, H, W = x.shape P1 = P2 = patch_size # Step 1: Split H and W dimensions into patches # [B, C, H/P1, P1, W/P2, P2] x = x.reshape(B, C, H // P1, P1, W // P2, P2) # Step 2: Rearrange dimensions to match target shape # [B, H/P1, W/P2, C, P1, P2] x = x.permute(0, 2, 4, 1, 3, 5) # Step 3: Combine dimensions to get final shape # [B, (H/P1)*(W/P2), C*P1*P2] x = x.reshape(B, (H // P1) * (W // P2), C * P1 * P2) return x def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig): x = create_patches(input_BCHW, config.enc_patch_size) x = w.patch_emb(x) x = x + w.pos_emb for block in w.blocks: x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads) x = x + mlp(layer_norm(x, block.ln2), block.mlp) x = layer_norm(x, w.post_ln) return x def vision_projection( global_features: torch.Tensor, reconstructed: torch.Tensor, w: nn.Module, config: VisionConfig, ): reconstructed = reconstructed.permute(2, 0, 1) reconstructed = adaptive_avg_pool2d( reconstructed, output_size=(config.enc_n_layers, config.enc_n_layers) ) reconstructed = reconstructed.permute(1, 2, 0).view(729, config.enc_dim) final_features = 
torch.cat([global_features, reconstructed], dim=-1) return mlp(final_features, w.proj_mlp) def build_vision_model(config: VisionConfig, dtype: torch.dtype): patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels grid_size = config.crop_size // config.enc_patch_size num_patches = grid_size * grid_size vision = nn.ModuleDict( { "patch_emb": nn.Linear(patch_dim, config.enc_dim, dtype=dtype), "blocks": nn.ModuleList( [ nn.ModuleDict( { "ln1": nn.LayerNorm(config.enc_dim, dtype=dtype), "attn": nn.ModuleDict( { "qkv": nn.Linear( config.enc_dim, 3 * config.enc_dim, dtype=dtype ), "proj": nn.Linear( config.enc_dim, config.enc_dim, dtype=dtype ), } ), "ln2": nn.LayerNorm(config.enc_dim, dtype=dtype), "mlp": nn.ModuleDict( { "fc1": nn.Linear( config.enc_dim, config.enc_ff_dim, dtype=dtype ), "fc2": nn.Linear( config.enc_ff_dim, config.enc_dim, dtype=dtype ), } ), } ) for _ in range(config.enc_n_layers) ] ), "post_ln": nn.LayerNorm(config.enc_dim, dtype=dtype), "proj_mlp": nn.ModuleDict( { "fc1": nn.Linear( config.enc_dim * 2, config.proj_inner_dim, dtype=dtype ), "fc2": nn.Linear( config.proj_inner_dim, config.proj_out_dim, dtype=dtype ), } ), } ) vision.pos_emb = nn.Parameter( torch.zeros(1, num_patches, config.enc_dim, dtype=dtype) ) return vision ================================================ FILE: moondream/torch/weights.py ================================================ import safetensors import torch import torch.nn as nn from contextlib import contextmanager from typing import Callable, List @contextmanager def safetensors_open(safetensors_file: str): """ Simplify interfacing with safetensors files. Eliminates the need to ignore type errors when using the `safe_open` function. 
""" with safetensors.safe_open( safetensors_file, framework="pt" ) as st: # pyright: ignore def get_tensor(name: str) -> torch.Tensor: return st.get_tensor(name) def get_keys() -> List[str]: return st.keys() get_tensor.keys = get_keys yield get_tensor def _load_weights(get_tensor: Callable[[str], torch.Tensor], model: nn.Module) -> None: """Internal function to load weights using a tensor getter function.""" model = model.to(dtype=torch.bfloat16) vision = model.vision region = model.region weight_map = { "vision_encoder.encoder.model.visual.patch_embed.linear.weight": vision[ "patch_emb" ].weight, "vision_encoder.encoder.model.visual.patch_embed.linear.bias": vision[ "patch_emb" ].bias, "vision_encoder.encoder.model.visual.pos_embed": vision.pos_emb, "vision_encoder.encoder.model.visual.norm.weight": vision["post_ln"].weight, "vision_encoder.encoder.model.visual.norm.bias": vision["post_ln"].bias, "vision_encoder.projection.mlp.fc1.weight": vision["proj_mlp"]["fc1"].weight, "vision_encoder.projection.mlp.fc1.bias": vision["proj_mlp"]["fc1"].bias, "vision_encoder.projection.mlp.fc2.weight": vision["proj_mlp"]["fc2"].weight, "vision_encoder.projection.mlp.fc2.bias": vision["proj_mlp"]["fc2"].bias, "text_model.transformer.embd.wte.weight": model.text.wte, "text_model.lm_head.ln.weight": model.text["post_ln"].weight, "text_model.lm_head.ln.bias": model.text["post_ln"].bias, "text_model.lm_head.linear.weight": model.text["lm_head"].weight, "text_model.lm_head.linear.bias": model.text["lm_head"].bias, "region_model.coordinate_encoder.weight": region["coord_encoder"].weight, "region_model.coordinate_encoder.bias": region["coord_encoder"].bias, "region_model.coordinate_head.weight": region["coord_decoder"].weight, "region_model.coordinate_head.bias": region["coord_decoder"].bias, "region_model.size_encoder.weight": region["size_encoder"].weight, "region_model.size_encoder.bias": region["size_encoder"].bias, "region_model.size_head.weight": region["size_decoder"].weight, 
"region_model.size_head.bias": region["size_decoder"].bias, } for i in range(len(model.vision["blocks"])): prefix = f"vision_encoder.encoder.model.visual.blocks.{i}" blk = model.vision["blocks"][i] weight_map.update( { f"{prefix}.norm1.weight": blk["ln1"].weight, f"{prefix}.norm1.bias": blk["ln1"].bias, f"{prefix}.norm2.weight": blk["ln2"].weight, f"{prefix}.norm2.bias": blk["ln2"].bias, f"{prefix}.attn.qkv.weight": blk["attn"]["qkv"].weight, f"{prefix}.attn.qkv.bias": blk["attn"]["qkv"].bias, f"{prefix}.attn.proj.weight": blk["attn"]["proj"].weight, f"{prefix}.attn.proj.bias": blk["attn"]["proj"].bias, f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight, f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias, f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight, f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias, } ) for i in range(len(model.text["blocks"])): prefix = f"text_model.transformer.h.{i}" blk = model.text["blocks"][i] is_moe = hasattr(blk.mlp, "router") weight_map.update( { f"{prefix}.ln.weight": blk["ln"].weight, f"{prefix}.ln.bias": blk["ln"].bias, f"{prefix}.mixer.Wqkv.weight": blk["attn"]["qkv"].weight, f"{prefix}.mixer.Wqkv.bias": blk["attn"]["qkv"].bias, f"{prefix}.mixer.out_proj.weight": blk["attn"]["proj"].weight, f"{prefix}.mixer.out_proj.bias": blk["attn"]["proj"].bias, f"{prefix}.tau_wq": blk["attn"]["tau"]["wq"], f"{prefix}.tau_wv": blk["attn"]["tau"]["wv"], f"{prefix}.tau_alpha": blk["attn"]["tau"]["alpha"], } ) if is_moe: weight_map.update( { f"{prefix}.gate.weight": blk["mlp"]["router"].weight, f"{prefix}.gate.bias": blk["mlp"]["router"].bias, f"{prefix}.mlp.experts.weight": blk["mlp"]["fc1"].weight, f"{prefix}.mlp.output_experts.weight": blk["mlp"]["fc2"].weight, } ) else: weight_map.update( { f"{prefix}.mlp.fc1.weight": blk["mlp"]["fc1"].weight, f"{prefix}.mlp.fc1.bias": blk["mlp"]["fc1"].bias, f"{prefix}.mlp.fc2.weight": blk["mlp"]["fc2"].weight, f"{prefix}.mlp.fc2.bias": blk["mlp"]["fc2"].bias, } ) for key, tensor in weight_map.items(): 
tensor.data.copy_(get_tensor(key)) region.coord_features.data.copy_( get_tensor("region_model.coordinate_features.weight").T ) region.size_features.data.copy_(get_tensor("region_model.size_features.weight").T) def load_weights_from_safetensors(weights_file: str, model: nn.Module) -> None: """Load weights from a safetensors file into a MoondreamModel instance.""" with safetensors_open(weights_file) as get_tensor: if ( "vision.blocks.0.attn.proj.bias" in get_tensor.keys() or "model.vision.blocks.0.attn.proj.bias" in get_tensor.keys() ): with safetensors_open(weights_file) as get_tensor: tensors = { k.replace("model.", ""): get_tensor(k) for k in get_tensor.keys() } model.load_state_dict(tensors, strict=False) else: # Wrap the get_tensor function to handle key normalization name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()} _load_weights( lambda x: get_tensor(name_map[x]).to(dtype=torch.bfloat16), model ) def load_weights_from_pt(weights_file: str, model: nn.Module) -> None: """Load weights from a PyTorch file into a MoondreamModel instance.""" device = str(torch.empty(0).device) tensors = torch.load(weights_file, map_location=device, weights_only=True) if "vision.blocks.0.attn.proj.bias" in tensors.keys(): missing_keys, unexpected_keys = model.load_state_dict(tensors, strict=False) print("Missing keys:", missing_keys) print("Unexpected keys:", unexpected_keys) else: tensors = { k.replace("._orig_mod", ""): v.to(dtype=torch.bfloat16) for k, v in tensors.items() } _load_weights(lambda x: tensors[x], model) def load_weights_into_model(weights_file: str, model: nn.Module) -> None: """ Load weights from either a safetensors or PyTorch file directly into a MoondreamModel instance. 
Args: weights_file: Path to weights file (either .safetensors or .pt) model: MoondreamModel instance to load weights into """ if weights_file.endswith(".safetensors"): load_weights_from_safetensors(weights_file, model) else: load_weights_from_pt(weights_file, model) # Make all parameters contiguous for param in model.parameters(): param.data = param.data.contiguous() ================================================ FILE: notebooks/RepEng.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook shows how to compute control vectors to steer moondream's behavior\n", "in fun and interesting ways. To learn more about control vectors and representation\n", "engineering check out [Theia's blog post on the topic](https://vgel.me/posts/representation-engineering/)." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "from datasets import load_dataset\n", "from tqdm import tqdm\n", "from PIL import Image\n", "import numpy as np\n", "from sklearn.decomposition import PCA\n", "from IPython.display import display, HTML" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"vikhyatk/moondream2\")\n", "model = AutoModelForCausalLM.from_pretrained(\n", " \"vikhyatk/moondream2\", trust_remote_code=True,\n", " torch_dtype=torch.float16, device_map={\"\": \"cuda\"}\n", ")\n", "\n", "# We will only be using the images, so it doesn't really matter what\n", "# dataset we use here.\n", "dataset = load_dataset(\"vikhyatk/lnqa\", streaming=True)[\"train\"]\n", "\n", "def hidden_states(enc_img, prompt):\n", " with torch.no_grad():\n", " inputs_embeds = model.input_embeds(prompt, enc_img, tokenizer)\n", " hidden_states = model.text_model.generate(\n", " inputs_embeds=inputs_embeds,\n", " 
max_new_tokens=128,\n", " pad_token_id=tokenizer.eos_token_id,\n", " eos_token_id=tokenizer.eos_token_id,\n", " return_dict_in_generate=True,\n", " output_hidden_states=True,\n", " do_sample=True,\n", " temperature=0.5\n", " ).hidden_states[1:]\n", " return [torch.stack([hs.view(-1, 2048) for hs in h[1:]]).cpu() for h in hidden_states]\n", "\n", "class LayerWrapper(torch.nn.Module):\n", " def __init__(self, og_layer, control_vectors, scale=4.2):\n", " super().__init__()\n", " self.og_layer = og_layer\n", " self.control_vectors = control_vectors\n", " self.scale = scale\n", "\n", " def forward(self, *args, **kwargs):\n", " layer_outputs = self.og_layer(*args, **kwargs)\n", " layer_outputs = (layer_outputs[0] + self.scale * self.control_vectors, *layer_outputs[1:])\n", " return layer_outputs" ] }, { "cell_type": "code", "execution_count": 112, "metadata": {}, "outputs": [], "source": [ "negative_prompt = \"\\n\\nQuestion: Describe this image.\\n\\nAnswer:\"\n", "positive_prompt = \"\\n\\nQuestion: What is the meaning of life?\\n\\nAnswer:\"\n", "\n", "# This can be lowered without noticeable loss in quality. 
Feel free to drop it to\n", "# IMAGES_PER_CONTROL=50 and SAMPLES_PER_IMAGE=2 if it's taking too long.\n", "IMAGES_PER_CONTROL = 200\n", "SAMPLES_PER_IMAGE = 5\n" ] }, { "cell_type": "code", "execution_count": 113, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 200/200 [37:09<00:00, 11.15s/it]\n" ] } ], "source": [ "# This is not very efficient, batching would speed things up a lot.\n", "# But eh, works for a quick demo.\n", "\n", "hs_dataset = [[] for _ in range(24)]\n", "\n", "for i, sample in tqdm(enumerate(dataset), total=IMAGES_PER_CONTROL):\n", " if i >= IMAGES_PER_CONTROL:\n", " break\n", " image = sample[\"image\"]\n", " enc_img = model.encode_image(image)\n", " for _ in range(SAMPLES_PER_IMAGE):\n", " phs = hidden_states(enc_img, positive_prompt)\n", " nhs = hidden_states(enc_img, negative_prompt)\n", " t_max = min(len(phs), len(nhs))\n", " for t in range(t_max):\n", " phs_t = phs[t]\n", " nhs_t = nhs[t]\n", " for j in range(24):\n", " hs_dataset[j].append(phs_t[j])\n", " hs_dataset[j].append(nhs_t[j])" ] }, { "cell_type": "code", "execution_count": 114, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 24/24 [02:30<00:00, 6.26s/it]\n" ] } ], "source": [ "control_vectors = []\n", "\n", "for i in tqdm(range(24)):\n", " layer_hiddens = torch.stack(hs_dataset[i])\n", "\n", " layer_centers = (layer_hiddens[::2] + layer_hiddens[1::2]) / 2\n", " relative_layer_hiddens = layer_hiddens\n", " relative_layer_hiddens[::2] -= layer_centers\n", " relative_layer_hiddens[1::2] -= layer_centers\n", "\n", " train = relative_layer_hiddens - relative_layer_hiddens.mean(axis=0, keepdims=True)\n", " train = train.view(-1, 2048).cpu().numpy()\n", " pca_model = PCA(n_components=1, whiten=False).fit(train)\n", " directions = pca_model.components_.astype(np.float32).squeeze(axis=0)\n", "\n", " projected_hiddens = (layer_hiddens.cpu().numpy() @ directions) / 
np.linalg.norm(directions)\n", "\n", " positive_smaller_mean = np.mean(\n", " [\n", " projected_hiddens[i] < projected_hiddens[i + 1]\n", " for i in range(0, len(hs_dataset[i]), 2)\n", " ]\n", " )\n", " positive_larger_mean = np.mean(\n", " [\n", " projected_hiddens[i] > projected_hiddens[i + 1]\n", " for i in range(0, len(hs_dataset[i]), 2)\n", " ]\n", " )\n", " if positive_smaller_mean > positive_larger_mean: # type: ignore\n", " directions *= -1\n", "\n", " control_vectors.append(torch.tensor(directions, device=\"cuda\", dtype=torch.float16))\n" ] }, { "cell_type": "code", "execution_count": 116, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "\n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ImageText
\n", " \n", " \n", " Describe this image.

\n", " In this image, we will explore the concept of happiness through an anthropomorphic interpretation. We'll look at it from a spiritual or religious perspective by examining what brings joy and fulfillment to human beings. This can be approached in various ways depending on individual beliefs and experiences.

To understand happiness, one might consider factors such as personal goals, relationships, health, wealth, spirituality, ethics, and more. It is subjective and varies greatly among individuals based on their values and life experiences. However, many people often seek happiness in different aspects like achieving success, finding inner peace, or living a meaningful life.

In the context of the provided statements about happiness, some may find happiness in material possessions (e.of), while others may seek happiness through relationships, creativity, or even simple pleasures that bring satisfaction and contentment. Ultimately, happiness is a deeply personal journey for each person to define and pursue according to their own beliefs and values.\n", "
\n", " \n", " \n", " What is this?

\n", " In the context of this image, a server or processor is an essential component for computing and running various applications on computers. A Processor (or CPU) can refer to any device that executes tasks according to specific programming requirements.

In this particular scenario, we are referring to advanced technologies like virtual machines, artificial intelligence, machine learning, etc., which require powerful computing systems to function effectively.

For example, in the case of AI research, researchers develop and test theories using sophisticated computer models and simulations. These concepts may involve analyzing vast amounts of data, exploring ethical questions, understanding existence, or even developing new knowledge about life itself.

In summary, when people talk about \"the meaning\" or \"purpose,\" they often refer to these advanced concepts as well. It's subjective and varies from person to person based on their beliefs, values, and experiences.\n", "
\n", " \n", " \n", " What color is the couch?

\n", " The couch in the image is described as \"black.\" However, without more information or context from different sources, it's difficult to determine its actual color. It could be any of those things like comfort, aesthetics, personal preferences, etc., which can vary among individuals.\n", "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "prompts = [\n", " (\"../assets/demo-1.jpg\", \"Describe this image.\"),\n", " (\"../assets/demo-2.jpg\", \"What is this?\"),\n", " (\"../assets/demo-2.jpg\", \"What color is the couch?\"),\n", "]\n", "data = []\n", "\n", "def run_model(img_path, prompt, scale=4.2):\n", " og_h = model.text_model.transformer.h\n", " model.text_model.transformer.h = torch.nn.ModuleList([\n", " LayerWrapper(layer, vector, scale) for layer, vector in zip(og_h, control_vectors)\n", " ])\n", " answer = model.answer_question(\n", " model.encode_image(Image.open(img_path)), prompt, tokenizer,\n", " repetition_penalty=1.2, temperature=0.1, do_sample=True,\n", " length_penalty=1.2\n", " )\n", " model.text_model.transformer.h = og_h\n", " return answer\n", "\n", "for img_path, prompt in prompts:\n", " answer = run_model(img_path, prompt)\n", " data.append({\"prompt\": prompt, \"answer\": answer.replace(\"\\n\", \"
\"), \"image\": img_path})\n", "\n", "html_table = \"\"\"\n", "\n", " \n", " \n", " \n", " \n", "\"\"\"\n", "\n", "for item in data:\n", " html_table += f\"\"\"\n", " \n", " \n", " \n", " \n", " \"\"\"\n", "\n", "html_table += \"
ImageText
\n", " \n", " \n", " {item['prompt']}

\n", " {item['answer']}\n", "
\"\n", "\n", "# Display the HTML table\n", "display(HTML(html_table))" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: recipes/gaze-detection-video/.gitignore ================================================ # Python __pycache__/ *.py[cod] *$py.class *.so .Python env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg # Virtual Environment venv/ ENV/ # IDE .idea/ .vscode/ *.swp *.swo # Project specific # input/* # !input/.gitkeep # output/* # !output/.gitkeep # temp/* # !temp/.gitkeep # Model files *.pt *.pth *.ckpt # Logs *.log # OS specific .DS_Store Thumbs.db ================================================ FILE: recipes/gaze-detection-video/README.md ================================================ # Gaze Detection Video Processor > **⚠️ IMPORTANT:** This project currently uses Moondream 2B (2025-01-09 release) via the Hugging Face Transformers library. We will migrate to the official Moondream client > libraries once they become available for this version. 
## Table of Contents - [Overview](#overview) - [Sample Output](#sample-output) - [Features](#features) - [Prerequisites](#prerequisites) - [Installation](#installation) - [Linux/macOS Installation](#linuxmacos-installation) - [Windows Installation](#windows-installation) - [Usage](#usage) - [Output](#output) - [Troubleshooting](#troubleshooting) - [Performance Notes](#performance-notes) - [Dependencies](#dependencies) - [Model Details](#model-details) - [License](#license) ## Overview This project uses the Moondream 2B model to detect faces and their gaze directions in videos. It processes videos frame by frame, visualizing face detections and gaze directions. ## Sample Output | Input Video | Processed Output | | :-----------------------------------: | :-----------------------------------------: | | ![Input Video](https://github.com/parsakhaz/gaze-detection-video/blob/master/gif-input-sample.gif?raw=true) | ![Processed Output](https://github.com/parsakhaz/gaze-detection-video/blob/master/gif-output-sample.gif?raw=true) | ## Features - Face detection in video frames - Gaze direction tracking - Frame-by-frame visualization with: - Colored bounding boxes for faces - Gradient lines showing gaze direction - Gaze target points - Supports multiple faces per frame - Processes all common video formats (.mp4, .avi, .mov, .mkv) - Uses Moondream 2 (2025-01-09 release) via Hugging Face Transformers - Note: Will be migrated to official client libraries in future updates - No authentication required ## Prerequisites 1. Python 3.8 or later 2. CUDA-capable GPU recommended (but CPU mode works too) 3. FFmpeg installed on your system ## Installation ### Linux/macOS Installation 1. Install system dependencies: ```bash # Ubuntu/Debian sudo apt-get update && sudo apt-get install -y libvips42 libvips-dev ffmpeg # CentOS/RHEL sudo yum install vips vips-devel ffmpeg # macOS brew install vips ffmpeg ``` 2.
Clone and set up the project: ```bash git clone https://github.com/vikhyat/moondream.git cd moondream/recipes/gaze-detection-video python3 -m venv venv source venv/bin/activate pip install -r requirements.txt ``` ### Windows Installation Windows setup requires a few additional steps for proper GPU support and libvips installation. 1. Clone the repository: ```bash git clone https://github.com/vikhyat/moondream.git cd moondream/recipes/gaze-detection-video ``` 2. Create and activate a virtual environment: ```bash python -m venv venv .\venv\Scripts\activate ``` 3. Install PyTorch with CUDA support: ```bash # For NVIDIA GPUs pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 ``` 4. Install libvips: Download the appropriate version based on your system architecture: | Architecture | VIPS Version to Download | | ------------ | ------------------------ | | 32-bit x86 | vips-dev-w32-all-8.16.0.zip | | 64-bit x64 | vips-dev-w64-all-8.16.0.zip | - Extract the ZIP file - Copy all DLL files from `vips-dev-8.16\bin` to either: - Your project's root directory (easier) OR - `C:\Windows\System32` (requires admin privileges) - Add to PATH: 1. Open System Properties → Advanced → Environment Variables 2. Under System Variables, find PATH 3. Add the full path to the `vips-dev-8.16\bin` directory 5. Install FFmpeg: - Download from https://ffmpeg.org/download.html#build-windows - Extract and add the `bin` folder to your system PATH (similar to step 4) or to the project root directory 6. Install other dependencies: ```bash pip install -r requirements.txt ``` ## Usage 1. Place your input videos in the `input` directory - Supported formats: .mp4, .avi, .mov, .mkv - The directory will be created automatically if it doesn't exist 2. Run the script: ```bash python gaze-detection-video.py ``` 3.
The script will: - Process all videos in the input directory - Show progress bars for each video - Save processed videos to the `output` directory with prefix 'processed\_' ## Output - Processed videos are saved as `output/processed_[original_name].[ext]` - Each frame in the output video shows: - Colored boxes around detected faces - Lines indicating gaze direction - Points showing where each person is looking ## Troubleshooting 1. CUDA/GPU Issues: - Ensure you have CUDA installed for GPU support - The script will automatically fall back to CPU if no GPU is available 2. Memory Issues: - If processing large videos, ensure you have enough RAM - Consider reducing video resolution if needed 3. libvips Errors: - Make sure libvips is properly installed for your OS - Check system PATH includes libvips 4. Video Format Issues: - Ensure FFmpeg is installed and in your system PATH - Try converting problematic videos to MP4 format ## Performance Notes - GPU processing is significantly faster than CPU - Processing time depends on: - Video resolution - Number of faces per frame - Frame rate - Available computing power ## Dependencies - transformers (for Moondream 2 model access) - torch - opencv-python - pillow - matplotlib - numpy - tqdm - pyvips - accelerate - einops ## Model Details > **⚠️ IMPORTANT:** This project currently uses Moondream 2 (2025-01-09 release) via the Hugging Face Transformers library. We will migrate to the official Moondream client > libraries once they become available for this version. The model is loaded using: ================================================ FILE: recipes/gaze-detection-video/gaze-detection-video.py ================================================ """ Gaze Detection Video Processor using Moondream 2 ------------------------------------------------ Read the README.md file for more information on how to use this script. Contact us in our discord for any questions if you get stuck. 
""" import torch import numpy as np import cv2 import matplotlib.pyplot as plt from PIL import Image from transformers import AutoModelForCausalLM from tqdm import tqdm import os import glob from typing import List, Dict, Tuple, Optional from contextlib import contextmanager def initialize_model() -> Optional[AutoModelForCausalLM]: """Initialize the Moondream 2 model with error handling.""" try: print("\nInitializing Moondream 2 model...") model_id = "vikhyatk/moondream2" revision = "2025-01-09" # Specify revision for stability if torch.cuda.is_available(): print(f"GPU detected: {torch.cuda.get_device_name(0)}") device = "cuda" else: print("No GPU detected, using CPU") device = "cpu" print("Loading model from HuggingFace...") model = AutoModelForCausalLM.from_pretrained( model_id, revision=revision, trust_remote_code=True, torch_dtype=torch.float16 if device == "cuda" else torch.float32, low_cpu_mem_usage=True, device_map={"": device} if device == "cuda" else None, ) if device == "cpu": model = model.to(device) model.eval() print("✓ Model initialized successfully") return model except Exception as e: print(f"\nError initializing model: {e}") return None @contextmanager def video_handler( input_path: str, output_path: str ) -> Tuple[cv2.VideoCapture, cv2.VideoWriter]: """Context manager for handling video capture and writer.""" cap = cv2.VideoCapture(input_path) if not cap.isOpened(): raise ValueError(f"Could not open video file: {input_path}") # Get video properties fps = int(cap.get(cv2.CAP_PROP_FPS)) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # Create video writer fourcc = cv2.VideoWriter_fourcc(*"mp4v") out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) try: yield cap, out finally: cap.release() out.release() cv2.destroyAllWindows() def fig2rgb_array(fig: plt.Figure) -> np.ndarray: """Convert matplotlib figure to RGB array""" fig.canvas.draw() buf = fig.canvas.buffer_rgba() w, h = 
fig.canvas.get_width_height() img_array = np.asarray(buf).reshape((h, w, 4)) rgb_array = img_array[:, :, :3] # Drop alpha channel return rgb_array def visualize_frame( frame: np.ndarray, faces: List[Dict], model: AutoModelForCausalLM, pil_image: Image ) -> np.ndarray: """Visualize a single frame using matplotlib""" try: # Create figure without margins fig = plt.figure(figsize=(frame.shape[1] / 100, frame.shape[0] / 100), dpi=100) ax = fig.add_axes([0, 0, 1, 1]) # Display frame ax.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) # Sort faces by x_min coordinate for stable colors faces = sorted(faces, key=lambda f: (f["y_min"], f["x_min"])) # Generate colors colors = plt.cm.rainbow(np.linspace(0, 1, max(1, len(faces)))) # Process each face for face, color in zip(faces, colors): try: # Calculate face box coordinates x_min = int(float(face["x_min"]) * frame.shape[1]) y_min = int(float(face["y_min"]) * frame.shape[0]) width = int(float(face["x_max"] - face["x_min"]) * frame.shape[1]) height = int(float(face["y_max"] - face["y_min"]) * frame.shape[0]) # Draw face rectangle rect = plt.Rectangle( (x_min, y_min), width, height, fill=False, color=color, linewidth=2 ) ax.add_patch(rect) # Calculate face center face_center = ( float(face["x_min"] + face["x_max"]) / 2, float(face["y_min"] + face["y_max"]) / 2, ) # Try to detect gaze try: gaze_result = model.detect_gaze(pil_image, face_center) if isinstance(gaze_result, dict) and "gaze" in gaze_result: gaze = gaze_result["gaze"] else: gaze = gaze_result except Exception as e: print(f"Error detecting gaze: {e}") continue if ( gaze is not None and isinstance(gaze, dict) and "x" in gaze and "y" in gaze ): gaze_x = int(float(gaze["x"]) * frame.shape[1]) gaze_y = int(float(gaze["y"]) * frame.shape[0]) face_center_x = x_min + width // 2 face_center_y = y_min + height // 2 # Draw gaze line with gradient effect points = 50 alphas = np.linspace(0.8, 0, points) # Calculate points along the line x_points = np.linspace(face_center_x, gaze_x, 
points) y_points = np.linspace(face_center_y, gaze_y, points) # Draw gradient line segments for i in range(points - 1): ax.plot( [x_points[i], x_points[i + 1]], [y_points[i], y_points[i + 1]], color=color, alpha=alphas[i], linewidth=4, ) # Draw gaze point ax.scatter(gaze_x, gaze_y, color=color, s=100, zorder=5) ax.scatter(gaze_x, gaze_y, color="white", s=50, zorder=6) except Exception as e: print(f"Error processing face: {e}") continue # Configure axes ax.set_xlim(0, frame.shape[1]) ax.set_ylim(frame.shape[0], 0) ax.axis("off") # Convert matplotlib figure to image frame_rgb = fig2rgb_array(fig) # Convert RGB to BGR for OpenCV frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR) # Clean up plt.close(fig) return frame_bgr except Exception as e: print(f"Error in visualize_frame: {e}") plt.close("all") return frame def process_video( input_path: str, output_path: str, model: AutoModelForCausalLM ) -> None: """Process video file and create new video with gaze visualization""" with video_handler(input_path, output_path) as (cap, out): total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) fps = int(cap.get(cv2.CAP_PROP_FPS)) print(f"Processing video: {total_frames} frames at {fps} FPS") # Process frames with tqdm( total=total_frames, desc=f"Processing {os.path.basename(input_path)}" ) as pbar: while True: ret, frame = cap.read() if not ret: break try: # Convert frame for model pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) # Detect faces detection_result = model.detect(pil_image, "face") # Handle different possible return formats if ( isinstance(detection_result, dict) and "objects" in detection_result ): faces = detection_result["objects"] elif isinstance(detection_result, list): faces = detection_result else: print( f"Unexpected detection result format: {type(detection_result)}" ) faces = [] # Ensure each face has the required coordinates faces = [ face for face in faces if all(k in face for k in ["x_min", "y_min", "x_max", "y_max"]) ] if not 
faces: processed_frame = frame else: # Visualize frame with matplotlib processed_frame = visualize_frame( frame, faces, model, pil_image ) # Write frame out.write(processed_frame) pbar.update(1) # Force matplotlib to clean up plt.close("all") except Exception as e: print(f"Error processing frame: {e}") out.write(frame) # Write original frame on error pbar.update(1) plt.close("all") # Clean up even on error if __name__ == "__main__": # Ensure input and output directories exist input_dir = os.path.join(os.path.dirname(__file__), "input") output_dir = os.path.join(os.path.dirname(__file__), "output") os.makedirs(input_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True) # Find all video files in input directory video_extensions = [".mp4", ".avi", ".mov", ".mkv"] input_videos = [] for ext in video_extensions: input_videos.extend(glob.glob(os.path.join(input_dir, f"*{ext}"))) if not input_videos: print("No video files found in input directory") exit(1) # Initialize model once for all videos model = initialize_model() if model is None: print("Failed to initialize model") exit(1) # Process each video file for input_video in input_videos: base_name = os.path.basename(input_video) output_video = os.path.join(output_dir, f"processed_{base_name}") try: process_video(input_video, output_video, model) except Exception as e: print(f"Error processing {base_name}: {e}") continue ================================================ FILE: recipes/gaze-detection-video/input/.gitkeep ================================================ ================================================ FILE: recipes/gaze-detection-video/output/.gitkeep ================================================ ================================================ FILE: recipes/gaze-detection-video/requirements.txt ================================================ torch>=2.0.0 transformers>=4.36.0 opencv-python>=4.8.0 pillow>=10.0.0 matplotlib>=3.7.0 numpy>=1.24.0 tqdm>=4.65.0 pyvips accelerate>=0.26.0 einops 
================================================ FILE: recipes/gaze-detection-video/temp/.gitkeep ================================================ ================================================ FILE: recipes/promptable-content-moderation/.gitignore ================================================ # Python __pycache__/ *.py[cod] *$py.class *.so .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg *.dll # Virtual Environment venv/ env/ ENV/ .venv/ # IDE .idea/ .vscode/ *.swp *.swo # Project specific inputs/* outputs/* !inputs/.gitkeep !outputs/.gitkeep inputs/ outputs/ # Model files *.pth *.onnx *.pt # Logs *.log certificate.pem ================================================ FILE: recipes/promptable-content-moderation/README.md ================================================ # Promptable Content Moderation with Moondream Welcome to the future of content moderation with Moondream 2B, a powerful and lightweight vision-language model that enables detection and moderation of video content using natural language prompts. 
[Try it now.](https://huggingface.co/spaces/moondream/content-moderation) ## Features - Content moderation through natural language prompts - Multiple visualization styles - Intelligent scene detection and tracking: - DeepSORT tracking with scene-aware reset - Persistent moderation across frames - Smart tracker reset at scene boundaries - Optional grid-based detection for improved accuracy on complex scenes - Frame-by-frame processing with IoU-based merging - Web-compatible output format - Test mode (process only first X seconds) - Advanced moderation analysis with multiple visualization plots ## Examples | Prompt | Output | |--------|-----------------| | "white cigarette" | ![Demo](https://github.com/parsakhaz/promptable-content-moderation/raw/main/examples/clip-cig.gif) | | "gun" | ![Demo](https://github.com/parsakhaz/promptable-content-moderation/raw/main/examples/clip-gu.gif) | | "confederate flag" | ![Demo](https://github.com/parsakhaz/promptable-content-moderation/raw/main/examples/clip-conflag.gif) | ## Requirements ### Python Dependencies For Windows users, before installing other requirements, first install PyTorch with CUDA support: ```bash pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https://download.pytorch.org/whl/cu121 ``` Then install the remaining dependencies: ```bash pip install -r requirements.txt ``` ### System Requirements - FFmpeg (required for video processing) - libvips (required for image processing) Installation by platform: - Ubuntu/Debian: `sudo apt-get install ffmpeg libvips` - macOS: `brew install ffmpeg libvips` - Windows: - Download FFmpeg from [ffmpeg.org](https://ffmpeg.org/download.html) - Follow [libvips Windows installation guide](https://docs.moondream.ai/quick-start) ## Installation 1. 
Clone this repository and create a new virtual environment: ```bash git clone https://github.com/vikhyat/moondream.git cd moondream/recipes/promptable-content-moderation python -m venv .venv source .venv/bin/activate # On Windows: .venv\Scripts\activate ``` 2. Install Python dependencies: ```bash pip install -r requirements.txt ``` 3. Install ffmpeg and libvips: - On Ubuntu/Debian: `sudo apt-get install ffmpeg libvips` - On macOS: `brew install ffmpeg libvips` - On Windows: Download FFmpeg from [ffmpeg.org](https://ffmpeg.org/download.html) > Downloading libvips for Windows requires some additional steps, see [here](https://docs.moondream.ai/quick-start) ## Usage The easiest way to use this tool is through its web interface, which provides a user-friendly experience for video content moderation. ### Web Interface 1. Start the web interface: ```bash python app.py ``` 2. Open the URL printed in the terminal in your browser (typically http://127.0.0.1:7860) 3. Use the interface to: - Upload your video file - Specify content to moderate (e.g., "face", "cigarette", "gun") - Choose redaction style (default: obfuscated-pixel) - OPTIONAL: Configure advanced settings - Processing speed/quality - Grid size for detection - Test mode for quick validation (default: on, 3 seconds) - Process the video and download results - Analyze detection patterns with visualization tools ## Output Files The tool generates two types of output files in the `outputs` directory: 1. Processed Videos: - Format: `[style]_[content_type]_[original_filename].mp4` - Example: `censor_inappropriate_video.mp4` 2. Detection Data: - Format: `[style]_[content_type]_[original_filename]_detections.json` - Contains frame-by-frame detection information - Used for visualization and analysis ## Technical Details ### Scene Detection and Tracking The tool uses advanced scene detection and object tracking: 1.
Scene Detection: - Powered by PySceneDetect's ContentDetector - Automatically identifies scene changes in videos - Configurable detection threshold (default: 30.0) - Helps maintain tracking accuracy across scene boundaries 2. Object Tracking: - DeepSORT tracking for consistent object identification - Automatic tracker reset at scene changes - Maintains object identity within scenes - Prevents tracking errors across scene boundaries 3. Integration Benefits: - More accurate object tracking - Better handling of scene transitions - Reduced false positives in tracking - Improved tracking consistency ## Best Practices - Use test mode for initial configuration - Enable grid-based detection for complex scenes - Choose appropriate redaction style based on content type: - Censor: Complete content blocking - Blur styles: Less intrusive moderation - Bounding Box: Content review and analysis - Monitor system resources during processing - Use appropriate processing quality settings based on your needs ## Notes - Processing time depends on video length, resolution, GPU availability, and chosen settings - GPU is strongly recommended for faster processing - Grid-based detection increases accuracy but requires more processing time (each grid cell is processed independently) - Test mode processes only first X seconds (default: 3 seconds) for quick validation ================================================ FILE: recipes/promptable-content-moderation/app.py ================================================ #!/usr/bin/env python3 import gradio as gr import os from main import load_moondream, process_video, load_sam_model import shutil import torch from visualization import visualize_detections from persistence import load_detection_data import matplotlib.pyplot as plt import io from PIL import Image import pandas as pd from video_visualization import create_video_visualization # Get absolute path to workspace root WORKSPACE_ROOT = os.path.dirname(os.path.abspath(__file__)) # Check CUDA 
availability print(f"Is CUDA available: {torch.cuda.is_available()}") # We want to get True print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") # GPU Name # Initialize Moondream model globally for reuse (will be loaded on first use) model, tokenizer = None, None def process_video_file( video_file, target_object, box_style, ffmpeg_preset, grid_rows, grid_cols, test_mode, test_duration, ): """Process a video file through the Gradio interface.""" try: if not video_file: raise gr.Error("Please upload a video file") # Load models if not already loaded global model, tokenizer if model is None or tokenizer is None: model, tokenizer = load_moondream() # Ensure input/output directories exist using absolute paths inputs_dir = os.path.join(WORKSPACE_ROOT, "inputs") outputs_dir = os.path.join(WORKSPACE_ROOT, "outputs") os.makedirs(inputs_dir, exist_ok=True) os.makedirs(outputs_dir, exist_ok=True) # Copy uploaded video to inputs directory video_filename = f"input_{os.path.basename(video_file)}" input_video_path = os.path.join(inputs_dir, video_filename) shutil.copy2(video_file, input_video_path) try: # Process the video output_path = process_video( input_video_path, target_object, test_mode=test_mode, test_duration=test_duration, ffmpeg_preset=ffmpeg_preset, grid_rows=grid_rows, grid_cols=grid_cols, box_style=box_style, ) # Get the corresponding JSON path base_name = os.path.splitext(os.path.basename(video_filename))[0] json_path = os.path.join( outputs_dir, f"{box_style}_{target_object}_{base_name}_detections.json" ) # Verify output exists and is readable if not output_path or not os.path.exists(output_path): print(f"Warning: Output path {output_path} does not exist") # Try to find the output based on expected naming convention expected_output = os.path.join( outputs_dir, f"{box_style}_{target_object}_{video_filename}" ) if os.path.exists(expected_output): output_path = expected_output else: # Try searching in outputs directory for any matching 
file matching_files = [ f for f in os.listdir(outputs_dir) if f.startswith(f"{box_style}_{target_object}_") ] if matching_files: output_path = os.path.join(outputs_dir, matching_files[0]) else: raise gr.Error("Failed to locate output video") # Convert output path to absolute path if it isn't already if not os.path.isabs(output_path): output_path = os.path.join(WORKSPACE_ROOT, output_path) print(f"Returning output path: {output_path}") return output_path, json_path finally: # Clean up input file try: if os.path.exists(input_video_path): os.remove(input_video_path) except: pass except Exception as e: print(f"Error in process_video_file: {str(e)}") raise gr.Error(f"Error processing video: {str(e)}") def create_visualization_plots(json_path): """Create visualization plots and return them as images.""" try: # Load the data data = load_detection_data(json_path) if not data: return None, None, None, None, None, None, None, None, "No data found" # Convert to DataFrame rows = [] for frame_data in data["frame_detections"]: frame = frame_data["frame"] timestamp = frame_data["timestamp"] for obj in frame_data["objects"]: rows.append( { "frame": frame, "timestamp": timestamp, "keyword": obj["keyword"], "x1": obj["bbox"][0], "y1": obj["bbox"][1], "x2": obj["bbox"][2], "y2": obj["bbox"][3], "area": (obj["bbox"][2] - obj["bbox"][0]) * (obj["bbox"][3] - obj["bbox"][1]), "center_x": (obj["bbox"][0] + obj["bbox"][2]) / 2, "center_y": (obj["bbox"][1] + obj["bbox"][3]) / 2, } ) if not rows: return ( None, None, None, None, None, None, None, None, "No detections found in the data", ) df = pd.DataFrame(rows) plots = [] # Create each plot and convert to image for plot_num in range(8): # Increased to 8 plots plt.figure(figsize=(8, 6)) if plot_num == 0: # Plot 1: Number of detections per frame (Original) detections_per_frame = df.groupby("frame").size() plt.plot(detections_per_frame.index, detections_per_frame.values) plt.xlabel("Frame") plt.ylabel("Number of Detections") 
plt.title("Detections Per Frame") elif plot_num == 1: # Plot 2: Distribution of detection areas (Original) df["area"].hist(bins=30) plt.xlabel("Detection Area (normalized)") plt.ylabel("Count") plt.title("Distribution of Detection Areas") elif plot_num == 2: # Plot 3: Average detection area over time (Original) avg_area = df.groupby("frame")["area"].mean() plt.plot(avg_area.index, avg_area.values) plt.xlabel("Frame") plt.ylabel("Average Detection Area") plt.title("Average Detection Area Over Time") elif plot_num == 3: # Plot 4: Heatmap of detection centers (Original) plt.hist2d(df["center_x"], df["center_y"], bins=30) plt.colorbar() plt.xlabel("X Position") plt.ylabel("Y Position") plt.title("Detection Center Heatmap") elif plot_num == 4: # Plot 5: Time-based Detection Density # Shows when in the video most detections occur df["time_bucket"] = pd.qcut(df["timestamp"], q=20, labels=False) time_density = df.groupby("time_bucket").size() plt.bar(time_density.index, time_density.values) plt.xlabel("Video Timeline (20 segments)") plt.ylabel("Number of Detections") plt.title("Detection Density Over Video Duration") elif plot_num == 5: # Plot 6: Screen Region Analysis # Divide screen into 3x3 grid and show detection counts try: df["grid_x"] = pd.qcut( df["center_x"], q=3, labels=["Left", "Center", "Right"], duplicates="drop", ) df["grid_y"] = pd.qcut( df["center_y"], q=3, labels=["Top", "Middle", "Bottom"], duplicates="drop", ) region_counts = ( df.groupby(["grid_y", "grid_x"]).size().unstack(fill_value=0) ) plt.imshow(region_counts, cmap="YlOrRd") plt.colorbar(label="Detection Count") for i in range(3): for j in range(3): plt.text( j, i, region_counts.iloc[i, j], ha="center", va="center" ) plt.xticks(range(3), ["Left", "Center", "Right"]) plt.yticks(range(3), ["Top", "Middle", "Bottom"]) plt.title("Screen Region Analysis") except Exception as e: plt.text( 0.5, 0.5, "Insufficient variation in detection positions", ha="center", va="center", ) plt.title("Screen Region 
Analysis (Not Available)") elif plot_num == 6: # Plot 7: Detection Size Categories # Categorize detections by size for content moderation try: size_labels = [ "Small (likely far/background)", "Medium-small", "Medium-large", "Large (likely foreground/close)", ] # Handle cases with limited unique values unique_areas = df["area"].nunique() if unique_areas >= 4: df["size_category"] = pd.qcut( df["area"], q=4, labels=size_labels, duplicates="drop" ) else: # Alternative binning for limited unique values df["size_category"] = pd.cut( df["area"], bins=unique_areas, labels=size_labels[:unique_areas], ) size_dist = df["size_category"].value_counts() plt.pie(size_dist.values, labels=size_dist.index, autopct="%1.1f%%") plt.title("Detection Size Distribution") except Exception as e: plt.text( 0.5, 0.5, "Insufficient variation in detection sizes", ha="center", va="center", ) plt.title("Detection Size Distribution (Not Available)") elif plot_num == 7: # Plot 8: Temporal Pattern Analysis # Show patterns of when detections occur in sequence try: detection_gaps = df.sort_values("frame")["frame"].diff() if len(detection_gaps.dropna().unique()) > 1: plt.hist( detection_gaps.dropna(), bins=min(30, len(detection_gaps.dropna().unique())), edgecolor="black", ) plt.xlabel("Frames Between Detections") plt.ylabel("Frequency") plt.title("Detection Temporal Pattern Analysis") else: plt.text( 0.5, 0.5, "Uniform detection intervals", ha="center", va="center", ) plt.title("Temporal Pattern Analysis (Uniform)") except Exception as e: plt.text( 0.5, 0.5, "Insufficient temporal data", ha="center", va="center" ) plt.title("Temporal Pattern Analysis (Not Available)") # Save plot to bytes buf = io.BytesIO() plt.savefig(buf, format="png", bbox_inches="tight") buf.seek(0) plots.append(Image.open(buf)) plt.close() # Enhanced summary text summary = f"""Summary Statistics: Total frames analyzed: {len(data['frame_detections'])} Total detections: {len(df)} Average detections per frame: {len(df) / 
len(data['frame_detections']):.2f} Detection Patterns: - Peak detection count: {df.groupby('frame').size().max()} (in a single frame) - Most common screen region: {df.groupby(['grid_y', 'grid_x']).size().idxmax()} - Average detection size: {df['area'].mean():.3f} - Median frames between detections: {detection_gaps.median():.1f} Video metadata: """ for key, value in data["video_metadata"].items(): summary += f"{key}: {value}\n" return ( plots[0], plots[1], plots[2], plots[3], plots[4], plots[5], plots[6], plots[7], summary, ) except Exception as e: print(f"Error creating visualization: {str(e)}") import traceback traceback.print_exc() return ( None, None, None, None, None, None, None, None, f"Error creating visualization: {str(e)}", ) # Create the Gradio interface with gr.Blocks(title="Promptable Content Moderation") as app: with gr.Tabs(): with gr.Tab("Process Video"): gr.Markdown("# Promptable Content Moderation with Moondream") gr.Markdown( """ Powered by [Moondream 2B](https://github.com/vikhyat/moondream). Upload a video and specify what to moderate. The app will process each frame and moderate any visual content that matches the prompt. For help, join the [Moondream Discord](https://discord.com/invite/tRUdpjDQfH). """ ) with gr.Row(): with gr.Column(): # Input components video_input = gr.Video(label="Upload Video") detect_input = gr.Textbox( label="What to Moderate", placeholder="e.g. 
face, cigarette, gun, etc.", value="face", info="Moondream can moderate anything that you can describe in natural language", ) process_btn = gr.Button("Process Video", variant="primary") with gr.Accordion("Advanced Settings", open=False): box_style_input = gr.Radio( choices=[ "censor", "bounding-box", "hitmarker", "sam", "sam-fast", "fuzzy-blur", "pixelated-blur", "intense-pixelated-blur", "obfuscated-pixel", ], value="obfuscated-pixel", label="Visualization Style", info="Choose how to display moderations: censor (black boxes), bounding-box (red boxes with labels), hitmarker (COD-style markers), sam (precise segmentation), sam-fast (faster but less precise segmentation), fuzzy-blur (Gaussian blur), pixelated-blur (pixelated with blur), obfuscated-pixel (advanced pixelation with neighborhood averaging)", ) preset_input = gr.Dropdown( choices=[ "ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow", ], value="medium", label="Processing Speed (faster = lower quality)", ) with gr.Row(): rows_input = gr.Slider( minimum=1, maximum=4, value=1, step=1, label="Grid Rows" ) cols_input = gr.Slider( minimum=1, maximum=4, value=1, step=1, label="Grid Columns", ) test_mode_input = gr.Checkbox( label="Test Mode (Process first 3 seconds only)", value=True, info="Enable to quickly test settings on a short clip before processing the full video (recommended). If using the data visualizations, disable.", ) test_duration_input = gr.Slider( minimum=1, maximum=10, value=3, step=1, label="Test Mode Duration (seconds)", info="Number of seconds to process in test mode", ) gr.Markdown( """ Note: Processing in test mode will only process the first 3 seconds of the video and is recommended for testing settings. """ ) gr.Markdown( """ We can get a rough estimate of how long the video will take to process by multiplying the videos framerate * seconds * the number of rows and columns and assuming 0.12 seconds processing time per detection. 
For example, a 3 second video at 30fps with 2x2 grid, the estimated time is 3 * 30 * 2 * 2 * 0.12 = 43.2 seconds (tested on a 4090 GPU). Note: Using the SAM visualization style will increase processing time significantly as it performs additional segmentation for each detection. The sam-fast option uses a smaller model for faster processing at the cost of some accuracy. """ ) with gr.Column(): # Output components video_output = gr.Video(label="Processed Video") json_output = gr.Text(label="Detection Data Path", visible=False) # About section under the video output gr.Markdown( """ ### Links: - [GitHub Repository](https://github.com/vikhyat/moondream) - [Hugging Face](https://huggingface.co/vikhyatk/moondream2) - [Quick Start](https://docs.moondream.ai/quick-start) - [Moondream Recipes](https://docs.moondream.ai/recipes) """ ) with gr.Tab("Analyze Results"): gr.Markdown("# Detection Analysis") gr.Markdown( """ Analyze the detection results from processed videos. The analysis includes: - Basic detection statistics and patterns - Temporal and spatial distribution analysis - Size-based categorization - Screen region analysis - Detection density patterns """ ) with gr.Row(): json_input = gr.File( label="Upload Detection Data (JSON)", file_types=[".json"], ) analyze_btn = gr.Button("Analyze", variant="primary") with gr.Row(): with gr.Column(): plot1 = gr.Image( label="Detections Per Frame", ) plot2 = gr.Image( label="Detection Areas Distribution", ) plot5 = gr.Image( label="Detection Density Timeline", ) plot6 = gr.Image( label="Screen Region Analysis", ) with gr.Column(): plot3 = gr.Image( label="Average Detection Area Over Time", ) plot4 = gr.Image( label="Detection Center Heatmap", ) plot7 = gr.Image( label="Detection Size Categories", ) plot8 = gr.Image( label="Temporal Pattern Analysis", ) stats_output = gr.Textbox( label="Statistics", info="Summary of key metrics and patterns found in the detection data.", lines=12, max_lines=15, interactive=False, ) # with 
gr.Tab("Video Visualizations"): # gr.Markdown("# Real-time Detection Visualization") # gr.Markdown( # """ # Watch the detection patterns unfold in real-time. Choose from: # - Timeline: Shows number of detections over time # - Gauge: Simple yes/no indicator for current frame detections # """ # ) # with gr.Row(): # json_input_realtime = gr.File( # label="Upload Detection Data (JSON)", # file_types=[".json"], # ) # viz_style = gr.Radio( # choices=["timeline", "gauge"], # value="timeline", # label="Visualization Style", # info="Choose between timeline view or simple gauge indicator" # ) # visualize_btn = gr.Button("Visualize", variant="primary") # with gr.Row(): # video_visualization = gr.Video( # label="Detection Visualization", # interactive=False # ) # stats_realtime = gr.Textbox( # label="Video Statistics", # lines=6, # max_lines=8, # interactive=False # ) # Event handlers process_outputs = process_btn.click( fn=process_video_file, inputs=[ video_input, detect_input, box_style_input, preset_input, rows_input, cols_input, test_mode_input, test_duration_input, ], outputs=[video_output, json_output], ) # Auto-analyze after processing process_outputs.then( fn=create_visualization_plots, inputs=[json_output], outputs=[plot1, plot2, plot3, plot4, plot5, plot6, plot7, plot8, stats_output], ) # Manual analysis button analyze_btn.click( fn=create_visualization_plots, inputs=[json_input], outputs=[plot1, plot2, plot3, plot4, plot5, plot6, plot7, plot8, stats_output], ) # Video visualization button # visualize_btn.click( # fn=lambda json_file, style: create_video_visualization(json_file.name if json_file else None, style), # inputs=[json_input_realtime, viz_style], # outputs=[video_visualization, stats_realtime], # ) if __name__ == "__main__": app.launch(share=True) ================================================ FILE: recipes/promptable-content-moderation/deep_sort_integration.py ================================================ import numpy as np import torch from 
deep_sort_realtime.deepsort_tracker import DeepSort from datetime import datetime class DeepSORTTracker: def __init__(self, max_age=5): """Initialize DeepSORT tracker.""" self.max_age = max_age self.tracker = self._create_tracker() def _create_tracker(self): """Create a new instance of DeepSort tracker.""" return DeepSort( max_age=self.max_age, embedder="mobilenet", # Using default MobileNetV2 embedder today=datetime.now().date(), # For track naming and daily ID reset ) def reset(self): """Reset the tracker state by creating a new instance.""" print("Resetting DeepSORT tracker...") self.tracker = self._create_tracker() def update(self, frame, detections): """Update tracking with new detections. Args: frame: Current video frame (numpy array) detections: List of (box, keyword) tuples where box is [x1, y1, x2, y2] normalized Returns: List of (box, keyword, track_id) tuples """ if not detections: return [] height, width = frame.shape[:2] # Convert normalized coordinates to absolute and format detections detection_list = [] for box, keyword in detections: x1 = int(box[0] * width) y1 = int(box[1] * height) x2 = int(box[2] * width) y2 = int(box[3] * height) w = x2 - x1 h = y2 - y1 # Format: ([left,top,w,h], confidence, detection_class) detection_list.append(([x1, y1, w, h], 1.0, keyword)) # Update tracker tracks = self.tracker.update_tracks(detection_list, frame=frame) # Convert back to normalized coordinates with track IDs tracked_objects = [] for track in tracks: if not track.is_confirmed(): continue ltrb = track.to_ltrb() # Get [left,top,right,bottom] format x1, y1, x2, y2 = ltrb # Normalize coordinates x1 = max(0.0, min(1.0, x1 / width)) y1 = max(0.0, min(1.0, y1 / height)) x2 = max(0.0, min(1.0, x2 / width)) y2 = max(0.0, min(1.0, y2 / height)) tracked_objects.append(([x1, y1, x2, y2], track.det_class, track.track_id)) return tracked_objects ================================================ FILE: recipes/promptable-content-moderation/main.py 
================================================ #!/usr/bin/env python3 import cv2, os, subprocess, argparse from PIL import Image import torch from transformers import AutoModelForCausalLM, AutoTokenizer, SamModel, SamProcessor from tqdm import tqdm import numpy as np from datetime import datetime from deep_sort_integration import DeepSORTTracker from scenedetect import detect, ContentDetector from functools import lru_cache # Constants DEFAULT_TEST_MODE_DURATION = 3 # Process only first 3 seconds in test mode by default FFMPEG_PRESETS = [ "ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow", ] FONT = cv2.FONT_HERSHEY_SIMPLEX # Font for bounding-box-style labels # Detection parameters IOU_THRESHOLD = 0.5 # IoU threshold for considering boxes related # Hitmarker parameters HITMARKER_SIZE = 20 # Size of the hitmarker in pixels HITMARKER_GAP = 3 # Size of the empty space in the middle (reduced from 8) HITMARKER_THICKNESS = 2 # Thickness of hitmarker lines HITMARKER_COLOR = (255, 255, 255) # White color for hitmarker HITMARKER_SHADOW_COLOR = (80, 80, 80) # Lighter gray for shadow effect HITMARKER_SHADOW_OFFSET = 1 # Smaller shadow offset # SAM parameters device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Initialize model variables as None sam_model = None sam_processor = None slimsam_model = None slimsam_processor = None @lru_cache(maxsize=2) # Cache both regular and slim SAM models def get_sam_model(slim=False): """Get cached SAM model and processor.""" global sam_model, sam_processor, slimsam_model, slimsam_processor if slim: if slimsam_model is None: print("Loading SlimSAM model for the first time...") slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to( device ) slimsam_processor = SamProcessor.from_pretrained( "nielsr/slimsam-50-uniform" ) return slimsam_model, slimsam_processor else: if sam_model is None: print("Loading SAM model for the first time...") sam_model = 
SamModel.from_pretrained("facebook/sam-vit-huge").to(device) sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") return sam_model, sam_processor def load_sam_model(slim=False): """Load SAM model and processor with caching.""" return get_sam_model(slim=slim) def generate_color_pair(): """Generate a generic light blue and dark blue color pair for SAM visualization.""" dark_rgb = [0, 0, 139] # Dark blue light_rgb = [173, 216, 230] # Light blue return dark_rgb, light_rgb def create_mask_overlay(image, masks, points=None, labels=None): """Create a mask overlay with contours for multiple SAM visualizations. Args: image: PIL Image to overlay masks on masks: List of binary masks or single mask points: Optional list of (x,y) points for labels labels: Optional list of label strings for each point """ # Convert single mask to list for uniform processing if not isinstance(masks, list): masks = [masks] # Create empty overlays overlay = np.zeros((*image.size[::-1], 4), dtype=np.uint8) outline = np.zeros((*image.size[::-1], 4), dtype=np.uint8) # Process each mask for i, mask in enumerate(masks): # Convert binary mask to uint8 mask_uint8 = (mask > 0).astype(np.uint8) # Dilation to fill gaps kernel = np.ones((5, 5), np.uint8) mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=1) # Find contours of the dilated mask contours, _ = cv2.findContours( mask_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE ) # Generate random color pair for this segmentation dark_color, light_color = generate_color_pair() # Add to the overlays overlay[mask_dilated > 0] = [*light_color, 90] # Light color with 35% opacity cv2.drawContours( outline, contours, -1, (*dark_color, 255), 2 ) # Dark color outline # Convert to PIL images mask_overlay = Image.fromarray(overlay, "RGBA") outline_overlay = Image.fromarray(outline, "RGBA") # Composite the layers result = image.convert("RGBA") result.paste(mask_overlay, (0, 0), mask_overlay) result.paste(outline_overlay, (0, 0), outline_overlay) 
# Add labels if provided if points and labels: result_array = np.array(result) for (x, y), label in zip(points, labels): label_size = cv2.getTextSize(label, FONT, 0.5, 1)[0] cv2.putText( result_array, label, (int(x - label_size[0] // 2), int(y - 20)), FONT, 0.5, (255, 255, 255), 1, cv2.LINE_AA, ) result = Image.fromarray(result_array) return result def process_sam_detection(image, center_x, center_y, slim=False): """Process a single detection point with SAM. Returns: tuple: (mask, result_pil) where mask is the binary mask and result_pil is the visualization """ if not isinstance(image, Image.Image): image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) # Get appropriate model from cache model, processor = get_sam_model(slim) # Process the image with SAM inputs = processor( image, input_points=[[[center_x, center_y]]], return_tensors="pt" ).to(device) with torch.no_grad(): outputs = model(**inputs) mask = processor.post_process_masks( outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu(), )[0][0][0].numpy() # Create the visualization result = create_mask_overlay(image, mask) return mask, result def load_moondream(): """Load Moondream model and tokenizer.""" model = AutoModelForCausalLM.from_pretrained( "vikhyatk/moondream2", trust_remote_code=True, device_map={"": "cuda"} ) tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2") return model, tokenizer def get_video_properties(video_path): """Get basic video properties.""" video = cv2.VideoCapture(video_path) fps = video.get(cv2.CAP_PROP_FPS) frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) video.release() return {"fps": fps, "frame_count": frame_count, "width": width, "height": height} def is_valid_bounding_box(bounding_box): """Check if bounding box coordinates are reasonable.""" x1, y1, x2, y2 = bounding_box width = x2 - x1 height = y2 - y1 # Reject 
boxes that are too large (over 90% of frame in both dimensions) if width > 0.9 and height > 0.9: return False # Reject boxes that are too small (less than 1% of frame) if width < 0.01 or height < 0.01: return False return True def split_frame_into_grid(frame, grid_rows, grid_cols): """Split a frame into a grid of tiles.""" height, width = frame.shape[:2] tile_height = height // grid_rows tile_width = width // grid_cols tiles = [] tile_positions = [] for i in range(grid_rows): for j in range(grid_cols): y1 = i * tile_height y2 = (i + 1) * tile_height if i < grid_rows - 1 else height x1 = j * tile_width x2 = (j + 1) * tile_width if j < grid_cols - 1 else width tile = frame[y1:y2, x1:x2] tiles.append(tile) tile_positions.append((x1, y1, x2, y2)) return tiles, tile_positions def convert_tile_coords_to_frame(box, tile_pos, frame_shape): """Convert coordinates from tile space to frame space.""" frame_height, frame_width = frame_shape[:2] tile_x1, tile_y1, tile_x2, tile_y2 = tile_pos tile_width = tile_x2 - tile_x1 tile_height = tile_y2 - tile_y1 x1_tile_abs = box[0] * tile_width y1_tile_abs = box[1] * tile_height x2_tile_abs = box[2] * tile_width y2_tile_abs = box[3] * tile_height x1_frame_abs = tile_x1 + x1_tile_abs y1_frame_abs = tile_y1 + y1_tile_abs x2_frame_abs = tile_x1 + x2_tile_abs y2_frame_abs = tile_y1 + y2_tile_abs x1_norm = x1_frame_abs / frame_width y1_norm = y1_frame_abs / frame_height x2_norm = x2_frame_abs / frame_width y2_norm = y2_frame_abs / frame_height x1_norm = max(0.0, min(1.0, x1_norm)) y1_norm = max(0.0, min(1.0, y1_norm)) x2_norm = max(0.0, min(1.0, x2_norm)) y2_norm = max(0.0, min(1.0, y2_norm)) return [x1_norm, y1_norm, x2_norm, y2_norm] def merge_tile_detections(tile_detections, iou_threshold=0.5): """Merge detections from different tiles using NMS-like approach.""" if not tile_detections: return [] all_boxes = [] all_keywords = [] # Collect all boxes and their keywords for detections in tile_detections: for box, keyword in detections: 
all_boxes.append(box) all_keywords.append(keyword) if not all_boxes: return [] # Convert to numpy for easier processing boxes = np.array(all_boxes) # Calculate areas x1 = boxes[:, 0] y1 = boxes[:, 1] x2 = boxes[:, 2] y2 = boxes[:, 3] areas = (x2 - x1) * (y2 - y1) # Sort boxes by area order = areas.argsort()[::-1] keep = [] while order.size > 0: i = order[0] keep.append(i) if order.size == 1: break # Calculate IoU with rest of boxes xx1 = np.maximum(x1[i], x1[order[1:]]) yy1 = np.maximum(y1[i], y1[order[1:]]) xx2 = np.minimum(x2[i], x2[order[1:]]) yy2 = np.minimum(y2[i], y2[order[1:]]) w = np.maximum(0.0, xx2 - xx1) h = np.maximum(0.0, yy2 - yy1) inter = w * h ovr = inter / (areas[i] + areas[order[1:]] - inter) # Get indices of boxes with IoU less than threshold inds = np.where(ovr <= iou_threshold)[0] order = order[inds + 1] return [(all_boxes[i], all_keywords[i]) for i in keep] def detect_objects_in_frame( model, tokenizer, image, target_object, grid_rows=1, grid_cols=1 ): """Detect specified objects in a frame using grid-based analysis.""" if grid_rows == 1 and grid_cols == 1: return detect_objects_in_frame_single(model, tokenizer, image, target_object) # Convert numpy array to PIL Image if needed if not isinstance(image, Image.Image): image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Split frame into tiles tiles, tile_positions = split_frame_into_grid(image, grid_rows, grid_cols) # Process each tile tile_detections = [] for tile, tile_pos in zip(tiles, tile_positions): # Convert tile to PIL Image tile_pil = Image.fromarray(tile) # Detect objects in tile response = model.detect(tile_pil, target_object) if response and "objects" in response and response["objects"]: objects = response["objects"] tile_objects = [] for obj in objects: if all(k in obj for k in ["x_min", "y_min", "x_max", "y_max"]): box = [obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]] if is_valid_bounding_box(box): # Convert tile coordinates to frame coordinates frame_box = 
convert_tile_coords_to_frame( box, tile_pos, image.shape ) tile_objects.append((frame_box, target_object)) if tile_objects: # Only append if we found valid objects tile_detections.append(tile_objects) # Merge detections from all tiles merged_detections = merge_tile_detections(tile_detections) return merged_detections def detect_objects_in_frame_single(model, tokenizer, image, target_object): """Single-frame detection function.""" detected_objects = [] # Convert numpy array to PIL Image if needed if not isinstance(image, Image.Image): image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) # Detect objects response = model.detect(image, target_object) # Check if we have valid objects if response and "objects" in response and response["objects"]: objects = response["objects"] for obj in objects: if all(k in obj for k in ["x_min", "y_min", "x_max", "y_max"]): box = [obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]] # If box is valid (not full-frame), add it if is_valid_bounding_box(box): detected_objects.append((box, target_object)) return detected_objects def draw_hitmarker( frame, center_x, center_y, size=HITMARKER_SIZE, color=HITMARKER_COLOR, shadow=True ): """Draw a COD-style hitmarker cross with more space in the middle.""" half_size = size // 2 # Draw shadow first if enabled if shadow: # Top-left to center shadow cv2.line( frame, ( center_x - half_size + HITMARKER_SHADOW_OFFSET, center_y - half_size + HITMARKER_SHADOW_OFFSET, ), ( center_x - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, center_y - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, ), HITMARKER_SHADOW_COLOR, HITMARKER_THICKNESS, ) # Top-right to center shadow cv2.line( frame, ( center_x + half_size + HITMARKER_SHADOW_OFFSET, center_y - half_size + HITMARKER_SHADOW_OFFSET, ), ( center_x + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, center_y - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, ), HITMARKER_SHADOW_COLOR, HITMARKER_THICKNESS, ) # Bottom-left to center shadow cv2.line( frame, ( center_x - half_size 
+ HITMARKER_SHADOW_OFFSET, center_y + half_size + HITMARKER_SHADOW_OFFSET, ), ( center_x - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, center_y + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, ), HITMARKER_SHADOW_COLOR, HITMARKER_THICKNESS, ) # Bottom-right to center shadow cv2.line( frame, ( center_x + half_size + HITMARKER_SHADOW_OFFSET, center_y + half_size + HITMARKER_SHADOW_OFFSET, ), ( center_x + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, center_y + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, ), HITMARKER_SHADOW_COLOR, HITMARKER_THICKNESS, ) # Draw main hitmarker # Top-left to center cv2.line( frame, (center_x - half_size, center_y - half_size), (center_x - HITMARKER_GAP, center_y - HITMARKER_GAP), color, HITMARKER_THICKNESS, ) # Top-right to center cv2.line( frame, (center_x + half_size, center_y - half_size), (center_x + HITMARKER_GAP, center_y - HITMARKER_GAP), color, HITMARKER_THICKNESS, ) # Bottom-left to center cv2.line( frame, (center_x - half_size, center_y + half_size), (center_x - HITMARKER_GAP, center_y + HITMARKER_GAP), color, HITMARKER_THICKNESS, ) # Bottom-right to center cv2.line( frame, (center_x + half_size, center_y + half_size), (center_x + HITMARKER_GAP, center_y + HITMARKER_GAP), color, HITMARKER_THICKNESS, ) def draw_ad_boxes(frame, detected_objects, detect_keyword, model, box_style="censor"): height, width = frame.shape[:2] points = [] # Only get points if we need them for hitmarker or SAM styles if box_style in ["hitmarker", "sam", "sam-fast"]: frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) try: point_response = model.point(frame_pil, detect_keyword) if isinstance(point_response, dict) and "points" in point_response: points = point_response["points"] except Exception as e: print(f"Error during point detection: {str(e)}") points = [] # Only load SAM models and process points if we're using SAM styles and have points if box_style in ["sam", "sam-fast"] and points: # Start with the original PIL image frame_pil = 
Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) # Collect all masks and points all_masks = [] point_coords = [] point_labels = [] for point in points: try: center_x = int(float(point["x"]) * width) center_y = int(float(point["y"]) * height) # Get mask and visualization mask, _ = process_sam_detection( frame_pil, center_x, center_y, slim=(box_style == "sam-fast") ) # Collect mask and point data all_masks.append(mask) point_coords.append((center_x, center_y)) point_labels.append(detect_keyword) except Exception as e: print(f"Error processing individual SAM point: {str(e)}") print(f"Point data: {point}") if all_masks: # Create final visualization with all masks result_pil = create_mask_overlay( frame_pil, all_masks, point_coords, point_labels ) frame = cv2.cvtColor(np.array(result_pil), cv2.COLOR_RGB2BGR) # Process other visualization styles for detection in detected_objects: try: # Handle both tracked and untracked detections if len(detection) == 3: # Tracked detection with ID box, keyword, track_id = detection else: # Regular detection without tracking box, keyword = detection track_id = None x1 = int(box[0] * width) y1 = int(box[1] * height) x2 = int(box[2] * width) y2 = int(box[3] * height) x1 = max(0, min(x1, width - 1)) y1 = max(0, min(y1, height - 1)) x2 = max(0, min(x2, width - 1)) y2 = max(0, min(y2, height - 1)) if x2 > x1 and y2 > y1: if box_style == "censor": cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 0), -1) elif box_style == "bounding-box": cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 3) label = ( f"{detect_keyword}" if track_id is not None else detect_keyword ) label_size = cv2.getTextSize(label, FONT, 0.7, 2)[0] cv2.rectangle( frame, (x1, y1 - 25), (x1 + label_size[0], y1), (0, 0, 255), -1 ) cv2.putText( frame, label, (x1, y1 - 6), FONT, 0.7, (255, 255, 255), 2, cv2.LINE_AA, ) elif box_style == "fuzzy-blur": # Extract ROI roi = frame[y1:y2, x1:x2] # Apply Gaussian blur with much larger kernel for intense blur blurred_roi = 
cv2.GaussianBlur(roi, (125, 125), 0) # Replace original ROI with blurred version frame[y1:y2, x1:x2] = blurred_roi elif box_style == "pixelated-blur": # Extract ROI roi = frame[y1:y2, x1:x2] # Pixelate by resizing down and up h, w = roi.shape[:2] temp = cv2.resize(roi, (10, 10), interpolation=cv2.INTER_LINEAR) pixelated = cv2.resize( temp, (w, h), interpolation=cv2.INTER_NEAREST ) # Mix up the pixelated frame slightly by adding random noise noise = np.random.randint(0, 50, (h, w, 3), dtype=np.uint8) pixelated = cv2.add(pixelated, noise) # Apply stronger Gaussian blur to smooth edges blurred_pixelated = cv2.GaussianBlur(pixelated, (15, 15), 0) # Replace original ROI frame[y1:y2, x1:x2] = blurred_pixelated elif box_style == "obfuscated-pixel": # Calculate expansion amount based on 10% of object dimensions box_width = x2 - x1 box_height = y2 - y1 expand_x = int(box_width * 0.10) expand_y = int(box_height * 0.10) # Expand the bounding box by 10% in all directions x1_expanded = max(0, x1 - expand_x) y1_expanded = max(0, y1 - expand_y) x2_expanded = min(width - 1, x2 + expand_x) y2_expanded = min(height - 1, y2 + expand_y) # Extract ROI with much larger padding for true background sampling padding = 100 # Much larger padding to get true background y1_pad = max(0, y1_expanded - padding) y2_pad = min(height, y2_expanded + padding) x1_pad = max(0, x1_expanded - padding) x2_pad = min(width, x2_expanded + padding) # Get the padded region including background padded_roi = frame[y1_pad:y2_pad, x1_pad:x2_pad] # Create mask that excludes a larger region around the detection h, w = y2_expanded - y1_expanded, x2_expanded - x1_expanded bg_mask = np.ones(padded_roi.shape[:2], dtype=bool) # Exclude a larger region around the detection from background sampling exclusion_padding = 50 # Area to exclude around detection exclude_y1 = padding - exclusion_padding exclude_y2 = padding + h + exclusion_padding exclude_x1 = padding - exclusion_padding exclude_x2 = padding + w + exclusion_padding 
# Make sure exclusion coordinates are valid exclude_y1 = max(0, exclude_y1) exclude_y2 = min(padded_roi.shape[0], exclude_y2) exclude_x1 = max(0, exclude_x1) exclude_x2 = min(padded_roi.shape[1], exclude_x2) # Mark the exclusion zone in the mask bg_mask[exclude_y1:exclude_y2, exclude_x1:exclude_x2] = False # If we have enough background pixels, calculate average color if np.any(bg_mask): bg_color = np.mean(padded_roi[bg_mask], axis=0).astype(np.uint8) else: # Fallback to edges if we couldn't get enough background edge_samples = np.concatenate( [ padded_roi[0], # Top edge padded_roi[-1], # Bottom edge padded_roi[:, 0], # Left edge padded_roi[:, -1], # Right edge ] ) bg_color = np.mean(edge_samples, axis=0).astype(np.uint8) # Create base pixelated version (of the expanded region) temp = cv2.resize( frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded], (6, 6), interpolation=cv2.INTER_LINEAR, ) pixelated = cv2.resize( temp, (w, h), interpolation=cv2.INTER_NEAREST ) # Blend heavily towards background color blend_factor = 0.9 # Much stronger blend with background blended = cv2.addWeighted( pixelated, 1 - blend_factor, np.full((h, w, 3), bg_color, dtype=np.uint8), blend_factor, 0, ) # Replace original ROI with blended version (using expanded coordinates) frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded] = blended elif box_style == "intense-pixelated-blur": # Expand the bounding box by pixels in all directions x1_expanded = max(0, x1 - 15) y1_expanded = max(0, y1 - 15) x2_expanded = min(width - 1, x2 + 25) y2_expanded = min(height - 1, y2 + 25) # Extract ROI roi = frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded] # Pixelate by resizing down and up h, w = roi.shape[:2] temp = cv2.resize(roi, (10, 10), interpolation=cv2.INTER_LINEAR) pixelated = cv2.resize( temp, (w, h), interpolation=cv2.INTER_NEAREST ) # Mix up the pixelated frame slightly by adding random noise noise = np.random.randint(0, 50, (h, w, 3), dtype=np.uint8) pixelated = cv2.add(pixelated, noise) 
# Apply stronger Gaussian blur to smooth edges blurred_pixelated = cv2.GaussianBlur(pixelated, (15, 15), 0) # Replace original ROI frame[y1_expanded:y2_expanded, x1_expanded:x2_expanded] = ( blurred_pixelated ) elif box_style == "hitmarker": if points: for point in points: try: print(f"Processing point: {point}") center_x = int(float(point["x"]) * width) center_y = int(float(point["y"]) * height) print( f"Converted coordinates: ({center_x}, {center_y})" ) draw_hitmarker(frame, center_x, center_y) label = ( f"{detect_keyword}" if track_id is not None else detect_keyword ) label_size = cv2.getTextSize(label, FONT, 0.5, 1)[0] cv2.putText( frame, label, ( center_x - label_size[0] // 2, center_y - HITMARKER_SIZE - 5, ), FONT, 0.5, HITMARKER_COLOR, 1, cv2.LINE_AA, ) except Exception as e: print(f"Error processing individual point: {str(e)}") print(f"Point data: {point}") except Exception as e: print(f"Error drawing {box_style} style box: {str(e)}") print(f"Box data: {box}") print(f"Keyword: {keyword}") return frame def filter_temporal_outliers(detections_dict): """Filter out extremely large detections that take up most of the frame. Only keeps detections that are reasonable in size. 
Args: detections_dict: Dictionary of {frame_number: [(box, keyword, track_id), ...]} """ filtered_detections = {} for t, detections in detections_dict.items(): # Only keep detections that aren't too large valid_detections = [] for detection in detections: # Handle both tracked and untracked detections if len(detection) == 3: # Tracked detection with ID box, keyword, track_id = detection else: # Regular detection without tracking box, keyword = detection track_id = None # Calculate box size as percentage of frame width = box[2] - box[0] height = box[3] - box[1] area = width * height # If box is less than 90% of frame, keep it if area < 0.9: if track_id is not None: valid_detections.append((box, keyword, track_id)) else: valid_detections.append((box, keyword)) if valid_detections: filtered_detections[t] = valid_detections return filtered_detections def describe_frames( video_path, model, tokenizer, detect_keyword, test_mode=False, test_duration=DEFAULT_TEST_MODE_DURATION, grid_rows=1, grid_cols=1, ): """Extract and detect objects in frames.""" props = get_video_properties(video_path) fps = props["fps"] # Initialize DeepSORT tracker tracker = DeepSORTTracker() # If in test mode, only process first N seconds if test_mode: frame_count = min(int(fps * test_duration), props["frame_count"]) else: frame_count = props["frame_count"] ad_detections = {} # Store detection results by frame number print("Extracting frames and detecting objects...") video = cv2.VideoCapture(video_path) # Detect scenes first scenes = detect(video_path, scene_detector) scene_changes = set(end.get_frames() for _, end in scenes) print(f"Detected {len(scenes)} scenes") frame_count_processed = 0 with tqdm(total=frame_count) as pbar: while frame_count_processed < frame_count: ret, frame = video.read() if not ret: break # Check if current frame is a scene change if frame_count_processed in scene_changes: # Detect objects in the frame detected_objects = detect_objects_in_frame( model, tokenizer, frame, 
detect_keyword, grid_rows=grid_rows, grid_cols=grid_cols, ) # Update tracker with current detections tracked_objects = tracker.update(frame, detected_objects) # Store results for every frame, even if empty ad_detections[frame_count_processed] = tracked_objects frame_count_processed += 1 pbar.update(1) video.release() if frame_count_processed == 0: print("No frames could be read from video") return {} return ad_detections def create_detection_video( video_path, ad_detections, detect_keyword, model, output_path=None, ffmpeg_preset="medium", test_mode=False, test_duration=DEFAULT_TEST_MODE_DURATION, box_style="censor", ): """Create video with detection boxes while preserving audio.""" if output_path is None: # Create outputs directory if it doesn't exist outputs_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "outputs" ) os.makedirs(outputs_dir, exist_ok=True) # Clean the detect_keyword for filename safe_keyword = "".join( x for x in detect_keyword if x.isalnum() or x in (" ", "_", "-") ) safe_keyword = safe_keyword.replace(" ", "_") # Create output filename base_name = os.path.splitext(os.path.basename(video_path))[0] output_path = os.path.join( outputs_dir, f"{box_style}_{safe_keyword}_{base_name}.mp4" ) print(f"Will save output to: {output_path}") props = get_video_properties(video_path) fps, width, height = props["fps"], props["width"], props["height"] # If in test mode, only process first few seconds if test_mode: frame_count = min(int(fps * test_duration), props["frame_count"]) print( f"Test mode enabled: Processing first {test_duration} seconds ({frame_count} frames)" ) else: frame_count = props["frame_count"] print("Full video mode: Processing entire video") video = cv2.VideoCapture(video_path) # Create temp output path by adding _temp before the extension base, ext = os.path.splitext(output_path) temp_output = f"{base}_temp{ext}" temp_audio = f"{base}_audio.aac" # Temporary audio file out = cv2.VideoWriter( temp_output, 
cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height) ) print("Creating detection video...") frame_count_processed = 0 with tqdm(total=frame_count) as pbar: while frame_count_processed < frame_count: ret, frame = video.read() if not ret: break # Get detections for this exact frame if frame_count_processed in ad_detections: current_detections = ad_detections[frame_count_processed] if current_detections: frame = draw_ad_boxes( frame, current_detections, detect_keyword, model, box_style=box_style, ) out.write(frame) frame_count_processed += 1 pbar.update(1) video.release() out.release() # Extract audio from original video try: if test_mode: # In test mode, extract only the required duration of audio subprocess.run( [ "ffmpeg", "-y", "-i", video_path, "-t", str(test_duration), "-vn", # No video "-acodec", "copy", temp_audio, ], check=True, ) else: subprocess.run( [ "ffmpeg", "-y", "-i", video_path, "-vn", # No video "-acodec", "copy", temp_audio, ], check=True, ) except subprocess.CalledProcessError as e: print(f"Error extracting audio: {str(e)}") if os.path.exists(temp_output): os.remove(temp_output) return None # Merge processed video with original audio try: # Base FFmpeg command ffmpeg_cmd = [ "ffmpeg", "-y", "-i", temp_output, "-i", temp_audio, "-c:v", "libx264", "-preset", ffmpeg_preset, "-crf", "23", "-c:a", "aac", "-b:a", "192k", "-movflags", "+faststart", # Better web playback ] if test_mode: # In test mode, ensure output duration matches test_duration ffmpeg_cmd.extend( [ "-t", str(test_duration), "-shortest", # Ensure output duration matches shortest input ] ) ffmpeg_cmd.extend(["-loglevel", "error", output_path]) subprocess.run(ffmpeg_cmd, check=True) # Clean up temporary files os.remove(temp_output) os.remove(temp_audio) if not os.path.exists(output_path): print( f"Warning: FFmpeg completed but output file not found at {output_path}" ) return None return output_path except subprocess.CalledProcessError as e: print(f"Error merging audio with video: 
{str(e)}") if os.path.exists(temp_output): os.remove(temp_output) if os.path.exists(temp_audio): os.remove(temp_audio) return None def process_video( video_path, target_object, test_mode=False, test_duration=DEFAULT_TEST_MODE_DURATION, ffmpeg_preset="medium", grid_rows=1, grid_cols=1, box_style="censor", ): """Process a video to detect and visualize specified objects.""" try: print(f"\nProcessing: {video_path}") print(f"Looking for: {target_object}") # Load model print("Loading Moondream model...") model, tokenizer = load_moondream() # Get video properties props = get_video_properties(video_path) # Initialize scene detector with ContentDetector scene_detector = ContentDetector(threshold=30.0) # Adjust threshold as needed # Initialize DeepSORT tracker tracker = DeepSORTTracker() # If in test mode, only process first N seconds if test_mode: frame_count = min(int(props["fps"] * test_duration), props["frame_count"]) else: frame_count = props["frame_count"] ad_detections = {} # Store detection results by frame number print("Extracting frames and detecting objects...") video = cv2.VideoCapture(video_path) # Detect scenes first scenes = detect(video_path, scene_detector) scene_changes = set(end.get_frames() for _, end in scenes) print(f"Detected {len(scenes)} scenes") frame_count_processed = 0 with tqdm(total=frame_count) as pbar: while frame_count_processed < frame_count: ret, frame = video.read() if not ret: break # Check if current frame is a scene change if frame_count_processed in scene_changes: print( f"Scene change detected at frame {frame_count_processed}. Resetting tracker." 
) tracker.reset() # Detect objects in the frame detected_objects = detect_objects_in_frame( model, tokenizer, frame, target_object, grid_rows=grid_rows, grid_cols=grid_cols, ) # Update tracker with current detections tracked_objects = tracker.update(frame, detected_objects) # Store results for every frame, even if empty ad_detections[frame_count_processed] = tracked_objects frame_count_processed += 1 pbar.update(1) video.release() if frame_count_processed == 0: print("No frames could be read from video") return {} # Apply filtering filtered_ad_detections = filter_temporal_outliers(ad_detections) # Build detection data structure detection_data = { "video_metadata": { "file_name": os.path.basename(video_path), "fps": props["fps"], "width": props["width"], "height": props["height"], "total_frames": props["frame_count"], "duration_sec": props["frame_count"] / props["fps"], "detect_keyword": target_object, "test_mode": test_mode, "grid_size": f"{grid_rows}x{grid_cols}", "box_style": box_style, "timestamp": datetime.now().isoformat(), }, "frame_detections": [ { "frame": frame_num, "timestamp": frame_num / props["fps"], "objects": [ { "keyword": kw, "bbox": list(box), # Convert numpy array to list if needed "track_id": track_id if len(detection) == 3 else None, } for detection in filtered_ad_detections.get(frame_num, []) for box, kw, *track_id in [ detection ] # Unpack detection tuple, track_id will be empty list if not present ], } for frame_num in range( props["frame_count"] if not test_mode else min(int(props["fps"] * test_duration), props["frame_count"]) ) ], } # Save filtered data outputs_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "outputs" ) os.makedirs(outputs_dir, exist_ok=True) base_name = os.path.splitext(os.path.basename(video_path))[0] json_path = os.path.join( outputs_dir, f"{box_style}_{target_object}_{base_name}_detections.json" ) from persistence import save_detection_data if not save_detection_data(detection_data, json_path): 
print("Warning: Failed to save detection data") # Create video with filtered data output_path = create_detection_video( video_path, filtered_ad_detections, target_object, model, ffmpeg_preset=ffmpeg_preset, test_mode=test_mode, test_duration=test_duration, box_style=box_style, ) if output_path is None: print("\nError: Failed to create output video") return None print(f"\nOutput saved to: {output_path}") print(f"Detection data saved to: {json_path}") return output_path except Exception as e: print(f"Error processing video: {str(e)}") import traceback traceback.print_exc() return None def main(): """Process all videos in the inputs directory.""" parser = argparse.ArgumentParser( description="Detect objects in videos using Moondream2" ) parser.add_argument( "--test", action="store_true", help="Process only first 3 seconds of each video" ) parser.add_argument( "--test-duration", type=int, default=DEFAULT_TEST_MODE_DURATION, help=f"Number of seconds to process in test mode (default: {DEFAULT_TEST_MODE_DURATION})", ) parser.add_argument( "--preset", choices=FFMPEG_PRESETS, default="medium", help="FFmpeg encoding preset (default: medium). 
Faster presets = lower quality", ) parser.add_argument( "--detect", type=str, default="face", help='Object to detect in the video (default: face, use --detect "thing to detect" to override)', ) parser.add_argument( "--rows", type=int, default=1, help="Number of rows to split each frame into (default: 1)", ) parser.add_argument( "--cols", type=int, default=1, help="Number of columns to split each frame into (default: 1)", ) parser.add_argument( "--box-style", choices=[ "censor", "bounding-box", "hitmarker", "sam", "sam-fast", "fuzzy-blur", "pixelated-blur", "intense-pixelated-blur", "obfuscated-pixel", ], default="censor", help="Style of detection visualization (default: censor)", ) args = parser.parse_args() input_dir = "inputs" os.makedirs(input_dir, exist_ok=True) os.makedirs("outputs", exist_ok=True) video_files = [ f for f in os.listdir(input_dir) if f.lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm")) ] if not video_files: print("No video files found in 'inputs' directory") return print(f"Found {len(video_files)} videos to process") print(f"Will detect: {args.detect}") if args.test: print("Running in test mode - processing only first 3 seconds of each video") print(f"Using FFmpeg preset: {args.preset}") print(f"Grid size: {args.rows}x{args.cols}") print(f"Box style: {args.box_style}") success_count = 0 for video_file in video_files: video_path = os.path.join(input_dir, video_file) output_path = process_video( video_path, args.detect, test_mode=args.test, test_duration=args.test_duration, ffmpeg_preset=args.preset, grid_rows=args.rows, grid_cols=args.cols, box_style=args.box_style, ) if output_path: success_count += 1 print( f"\nProcessing complete. Successfully processed {success_count} out of {len(video_files)} videos." 
) if __name__ == "__main__": main() ================================================ FILE: recipes/promptable-content-moderation/packages.txt ================================================ libvips ffmpeg ================================================ FILE: recipes/promptable-content-moderation/persistence.py ================================================ import json import os def save_detection_data(data, output_file): """ Saves the detection data to a JSON file. Args: data (dict): The complete detection data structure. output_file (str): Path to the output JSON file. """ try: # Create directory if it doesn't exist os.makedirs(os.path.dirname(output_file), exist_ok=True) with open(output_file, "w") as f: json.dump(data, f, indent=4) print(f"Detection data saved to {output_file}") return True except Exception as e: print(f"Error saving data: {str(e)}") return False def load_detection_data(input_file): """ Loads the detection data from a JSON file. Args: input_file (str): Path to the JSON file. Returns: dict: The loaded detection data, or None if there was an error. 
""" try: with open(input_file, "r") as f: return json.load(f) except Exception as e: print(f"Error loading data: {str(e)}") return None ================================================ FILE: recipes/promptable-content-moderation/requirements.txt ================================================ gradio>=4.0.0 torch>=2.0.0 # if on windows: pip install torch==2.5.1+cu121 torchvision==0.20.1+cu121 --index-url https://download.pytorch.org/whl/cu121 transformers>=4.36.0 opencv-python>=4.8.0 pillow>=10.0.0 numpy>=1.24.0 tqdm>=4.66.0 ffmpeg-python einops pyvips-binary pyvips accelerate # for spaces --extra-index-url https://download.pytorch.org/whl/cu113 spaces # SAM dependencies torchvision>=0.20.1 matplotlib>=3.7.0 pandas>=2.0.0 plotly # DeepSORT dependencies deep-sort-realtime>=1.3.2 scikit-learn # Required for deep-sort-realtime # Scene detection dependencies (for intelligent scene-aware tracking) scenedetect[opencv]>=0.6.2 # Provides scene change detection capabilities ================================================ FILE: recipes/promptable-content-moderation/video_visualization.py ================================================ import os import tempfile import subprocess import matplotlib.pyplot as plt import pandas as pd import cv2 import numpy as np from tqdm import tqdm from persistence import load_detection_data def create_frame_data(json_path): """Create frame-by-frame detection data for visualization.""" try: data = load_detection_data(json_path) if not data: print("No data loaded from JSON file") return None if "video_metadata" not in data or "frame_detections" not in data: print("Invalid JSON structure: missing required fields") return None # Extract video metadata metadata = data["video_metadata"] if "fps" not in metadata or "total_frames" not in metadata: print("Invalid metadata: missing fps or total_frames") return None fps = metadata["fps"] total_frames = metadata["total_frames"] # Create frame data frame_counts = {} for frame_data in 
data["frame_detections"]: if "frame" not in frame_data or "objects" not in frame_data: continue # Skip invalid frame data frame_num = frame_data["frame"] frame_counts[frame_num] = len(frame_data["objects"]) # Fill in missing frames with 0 detections for frame in range(total_frames): if frame not in frame_counts: frame_counts[frame] = 0 if not frame_counts: print("No valid frame data found") return None # Convert to DataFrame df = pd.DataFrame(list(frame_counts.items()), columns=["frame", "detections"]) df["timestamp"] = df["frame"] / fps return df, metadata except Exception as e: print(f"Error creating frame data: {str(e)}") import traceback traceback.print_exc() return None def generate_frame_image(df, frame_num, temp_dir, max_y): """Generate and save a single frame of the visualization.""" # Set the style to dark background plt.style.use("dark_background") # Set global font to monospace plt.rcParams["font.family"] = "monospace" plt.rcParams["font.monospace"] = ["DejaVu Sans Mono"] plt.figure(figsize=(10, 6)) # Plot data up to current frame current_data = df[df["frame"] <= frame_num] plt.plot( df["frame"], df["detections"], color="#1a1a1a", alpha=0.5 ) # Darker background line plt.plot( current_data["frame"], current_data["detections"], color="#00ff41" ) # Matrix green # Add vertical line for current position plt.axvline( x=frame_num, color="#ff0000", linestyle="-", alpha=0.7 ) # Keep red for position # Set consistent axes plt.xlim(0, len(df) - 1) plt.ylim(0, max_y * 1.1) # Add 10% padding # Add labels with Matrix green color plt.title(f"FRAME {frame_num:04d} - DETECTIONS OVER TIME", color="#00ff41", pad=20) plt.xlabel("FRAME NUMBER", color="#00ff41") plt.ylabel("NUMBER OF DETECTIONS", color="#00ff41") # Add current stats in Matrix green with monospace formatting current_detections = df[df["frame"] == frame_num]["detections"].iloc[0] plt.text( 0.02, 0.98, f"CURRENT DETECTIONS: {current_detections:02d}", transform=plt.gca().transAxes, verticalalignment="top", 
color="#00ff41", family="monospace", ) # Style the grid and ticks plt.grid(True, color="#1a1a1a", linestyle="-", alpha=0.3) plt.tick_params(colors="#00ff41") # Save frame frame_path = os.path.join(temp_dir, f"frame_{frame_num:05d}.png") plt.savefig( frame_path, bbox_inches="tight", dpi=100, facecolor="black", edgecolor="none" ) plt.close() return frame_path def generate_gauge_frame(df, frame_num, temp_dir, detect_keyword="OBJECT"): """Generate a modern square-style binary gauge visualization frame.""" # Set the style to dark background plt.style.use("dark_background") # Set global font to monospace plt.rcParams["font.family"] = "monospace" plt.rcParams["font.monospace"] = ["DejaVu Sans Mono"] # Create figure with 16:9 aspect ratio plt.figure(figsize=(16, 9)) # Get current detection state current_detections = df[df["frame"] == frame_num]["detections"].iloc[0] has_detection = current_detections > 0 # Create a simple gauge visualization plt.axis("off") # Set colors if has_detection: color = "#00ff41" # Matrix green for YES status = "YES" indicator_pos = 0.8 # Right position else: color = "#ff0000" # Red for NO status = "NO" indicator_pos = 0.2 # Left position # Draw background rectangle background = plt.Rectangle( (0.1, 0.3), 0.8, 0.2, facecolor="#1a1a1a", edgecolor="#333333", linewidth=2 ) plt.gca().add_patch(background) # Draw indicator indicator_width = 0.05 indicator = plt.Rectangle( (indicator_pos - indicator_width / 2, 0.25), indicator_width, 0.3, facecolor=color, edgecolor=None, ) plt.gca().add_patch(indicator) # Add tick marks tick_positions = [0.2, 0.5, 0.8] # NO, CENTER, YES for x in tick_positions: plt.plot([x, x], [0.3, 0.5], color="#444444", linewidth=2) # Add YES/NO labels plt.text( 0.8, 0.2, "YES", color="#00ff41", fontsize=14, ha="center", va="center", family="monospace", ) plt.text( 0.2, 0.2, "NO", color="#ff0000", fontsize=14, ha="center", va="center", family="monospace", ) # Add status box at top with detection keyword plt.text( 0.5, 0.8, 
f"{detect_keyword.upper()} DETECTED?", color=color, fontsize=16, ha="center", va="center", family="monospace", bbox=dict(facecolor="#1a1a1a", edgecolor=color, linewidth=2, pad=10), ) # Add frame counter at bottom plt.text( 0.5, 0.1, f"FRAME: {frame_num:04d}", color="#00ff41", fontsize=14, ha="center", va="center", family="monospace", ) # Add subtle grid lines for depth for x in np.linspace(0.2, 0.8, 7): plt.plot([x, x], [0.3, 0.5], color="#222222", linewidth=1, zorder=0) # Add glow effect to indicator for i in range(3): glow = plt.Rectangle( (indicator_pos - (indicator_width + i * 0.01) / 2, 0.25 - i * 0.01), indicator_width + i * 0.01, 0.3 + i * 0.02, facecolor=color, alpha=0.1 / (i + 1), ) plt.gca().add_patch(glow) # Set consistent plot limits plt.xlim(0, 1) plt.ylim(0, 1) # Save frame with 16:9 aspect ratio frame_path = os.path.join(temp_dir, f"gauge_{frame_num:05d}.png") plt.savefig( frame_path, bbox_inches="tight", dpi=100, facecolor="black", edgecolor="none", pad_inches=0, ) plt.close() return frame_path def create_video_visualization(json_path, style="timeline"): """Create a video visualization of the detection data.""" try: if not json_path: return None, "No JSON file provided" if not os.path.exists(json_path): return None, f"File not found: {json_path}" # Load and process data result = create_frame_data(json_path) if result is None: return None, "Failed to load detection data from JSON file" frame_data, metadata = result if len(frame_data) == 0: return None, "No frame data found in JSON file" total_frames = metadata["total_frames"] detect_keyword = metadata.get( "detect_keyword", "OBJECT" ) # Get the detection keyword # Create temporary directory for frames with tempfile.TemporaryDirectory() as temp_dir: max_y = frame_data["detections"].max() # Generate each frame print("Generating frames...") frame_paths = [] with tqdm(total=total_frames, desc="Generating frames") as pbar: for frame in range(total_frames): try: if style == "gauge": frame_path = 
generate_gauge_frame( frame_data, frame, temp_dir, detect_keyword ) else: # default to timeline frame_path = generate_frame_image( frame_data, frame, temp_dir, max_y ) if frame_path and os.path.exists(frame_path): frame_paths.append(frame_path) else: print(f"Warning: Failed to generate frame {frame}") pbar.update(1) except Exception as e: print(f"Error generating frame {frame}: {str(e)}") continue if not frame_paths: return None, "Failed to generate any frames" # Create output video path output_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "outputs" ) os.makedirs(output_dir, exist_ok=True) output_video = os.path.join( output_dir, f"detection_visualization_{style}.mp4" ) # Create temp output path base, ext = os.path.splitext(output_video) temp_output = f"{base}_temp{ext}" # First pass: Create video with OpenCV VideoWriter print("Creating initial video...") # Get frame size from first image first_frame = cv2.imread(frame_paths[0]) height, width = first_frame.shape[:2] out = cv2.VideoWriter( temp_output, cv2.VideoWriter_fourcc(*"mp4v"), metadata["fps"], (width, height), ) with tqdm( total=total_frames, desc="Creating video" ) as pbar: # Use total_frames here too for frame_path in frame_paths: frame = cv2.imread(frame_path) out.write(frame) pbar.update(1) out.release() # Second pass: Convert to web-compatible format print("Converting to web format...") try: subprocess.run( [ "ffmpeg", "-y", "-i", temp_output, "-c:v", "libx264", "-preset", "medium", "-crf", "23", "-movflags", "+faststart", # Better web playback "-loglevel", "error", output_video, ], check=True, ) os.remove(temp_output) # Remove the temporary file if not os.path.exists(output_video): print( f"Warning: FFmpeg completed but output file not found at {output_video}" ) return None, "Failed to create video" # Return video path and stats stats = f"""Video Stats: FPS: {metadata['fps']} Total Frames: {metadata['total_frames']} Duration: {metadata['duration_sec']:.2f} seconds Max Detections in a 
Frame: {frame_data['detections'].max()} Average Detections per Frame: {frame_data['detections'].mean():.2f}""" return output_video, stats except subprocess.CalledProcessError as e: print(f"Error running FFmpeg: {str(e)}") if os.path.exists(temp_output): os.remove(temp_output) return None, f"Error creating visualization: {str(e)}" except Exception as e: print(f"Error creating video visualization: {str(e)}") import traceback traceback.print_exc() return None, f"Error creating visualization: {str(e)}" ================================================ FILE: recipes/promptable-content-moderation/visualization.py ================================================ import pandas as pd import matplotlib.pyplot as plt from persistence import load_detection_data import argparse def visualize_detections(json_path): """ Visualize detection data from a JSON file. Args: json_path (str): Path to the JSON file containing detection data. """ # Load the persisted JSON data data = load_detection_data(json_path) if not data: return # Convert the frame detections to a DataFrame rows = [] for frame_data in data["frame_detections"]: frame = frame_data["frame"] timestamp = frame_data["timestamp"] for obj in frame_data["objects"]: rows.append( { "frame": frame, "timestamp": timestamp, "keyword": obj["keyword"], "x1": obj["bbox"][0], "y1": obj["bbox"][1], "x2": obj["bbox"][2], "y2": obj["bbox"][3], "area": (obj["bbox"][2] - obj["bbox"][0]) * (obj["bbox"][3] - obj["bbox"][1]), } ) if not rows: print("No detections found in the data") return df = pd.DataFrame(rows) # Create a figure with multiple subplots fig = plt.figure(figsize=(15, 10)) # Plot 1: Number of detections per frame plt.subplot(2, 2, 1) detections_per_frame = df.groupby("frame").size() plt.plot(detections_per_frame.index, detections_per_frame.values) plt.xlabel("Frame") plt.ylabel("Number of Detections") plt.title("Detections Per Frame") # Plot 2: Distribution of detection areas plt.subplot(2, 2, 2) df["area"].hist(bins=30) 
plt.xlabel("Detection Area (normalized)") plt.ylabel("Count") plt.title("Distribution of Detection Areas") # Plot 3: Average detection area over time plt.subplot(2, 2, 3) avg_area = df.groupby("frame")["area"].mean() plt.plot(avg_area.index, avg_area.values) plt.xlabel("Frame") plt.ylabel("Average Detection Area") plt.title("Average Detection Area Over Time") # Plot 4: Heatmap of detection centers plt.subplot(2, 2, 4) df["center_x"] = (df["x1"] + df["x2"]) / 2 df["center_y"] = (df["y1"] + df["y2"]) / 2 plt.hist2d(df["center_x"], df["center_y"], bins=30) plt.colorbar() plt.xlabel("X Position") plt.ylabel("Y Position") plt.title("Detection Center Heatmap") # Adjust layout and display plt.tight_layout() plt.show() # Print summary statistics print("\nSummary Statistics:") print(f"Total frames analyzed: {len(data['frame_detections'])}") print(f"Total detections: {len(df)}") print( f"Average detections per frame: {len(df) / len(data['frame_detections']):.2f}" ) print(f"\nVideo metadata:") for key, value in data["video_metadata"].items(): print(f"{key}: {value}") def main(): parser = argparse.ArgumentParser(description="Visualize object detection data") parser.add_argument( "json_file", help="Path to the JSON file containing detection data" ) args = parser.parse_args() visualize_detections(args.json_file) if __name__ == "__main__": main() ================================================ FILE: recipes/promptable-video-redaction/.gitignore ================================================ # Python __pycache__/ *.py[cod] *$py.class *.so .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg # Virtual Environment venv/ env/ ENV/ .venv/ # IDE .idea/ .vscode/ *.swp *.swo # Project specific inputs/* outputs/* !inputs/.gitkeep !outputs/.gitkeep inputs/ outputs/ # Model files *.pth *.onnx *.pt # Logs *.log certificate.pem ================================================ FILE: 
recipes/promptable-video-redaction/README.md ================================================ # Promptable Video Redaction with Moondream This tool uses Moondream 2B, a powerful yet lightweight vision-language model, to detect and redact objects from videos. Moondream can recognize a wide variety of objects, people, text, and more with high accuracy while being much smaller than traditional models. [Try it now.](https://huggingface.co/spaces/moondream/promptable-video-redaction) ## About Moondream Moondream is a tiny yet powerful vision-language model that can analyze images and answer questions about them. It's designed to be lightweight and efficient while maintaining high accuracy. Some key features: - Only 2B parameters - Fast inference with minimal resource requirements - Supports CPU and GPU execution - Open source and free to use - Can detect almost anything you can describe in natural language Links: - [GitHub Repository](https://github.com/vikhyat/moondream) - [Hugging Face](https://huggingface.co/vikhyatk/moondream2) - [Build with Moondream](http://docs.moondream.ai/) ## Features - Real-time object detection in videos using Moondream - Multiple visualization styles: - Censor: Black boxes over detected objects - Bounding Box: Traditional bounding boxes with labels - Hitmarker: Call of Duty style crosshair markers - Optional grid-based detection for improved accuracy - Flexible object type detection using natural language - Frame-by-frame processing with IoU-based merging - Batch processing of multiple videos - Web-compatible output format - User-friendly web interface - Command-line interface for automation ## Requirements - Python 3.8+ - OpenCV (cv2) - PyTorch - Transformers - Pillow (PIL) - tqdm - ffmpeg - numpy - gradio (for web interface) ## Installation 1. 
Clone the repository, change into this recipe's directory, and create a new virtual environment:
```bash
git clone https://github.com/vikhyat/moondream.git
cd moondream/recipes/promptable-video-redaction
python -m venv .venv
source .venv/bin/activate
```

2. Install the required packages:
```bash
pip install -r requirements.txt
```

3. Install ffmpeg and libvips:
- On Ubuntu/Debian: `sudo apt-get install ffmpeg libvips`
- On macOS: `brew install ffmpeg vips`
- On Windows: Download ffmpeg from [ffmpeg.org](https://ffmpeg.org/download.html)

> Installing libvips on Windows requires some additional steps, see [here](https://docs.moondream.ai/quick-start)

## Usage

### Web Interface

1. Start the web interface:
```bash
python app.py
```
2. Open the provided URL in your browser
3. Use the interface to:
   - Upload your video
   - Specify what to censor (e.g., face, logo, text)
   - Adjust processing speed and quality
   - Configure grid size for detection
   - Process and download the censored video

### Command Line Interface

1. Create an `inputs` directory in the same folder as the script:
```bash
mkdir inputs
```
2. Place your video files in the `inputs` directory. Supported formats:
   - .mp4
   - .avi
   - .mov
   - .mkv
   - .webm
3. Run the script:
```bash
python main.py
```

### Optional Arguments:
- `--test`: Process only the first 3 seconds of each video (useful for testing detection settings)
```bash
python main.py --test
```
- `--preset`: Choose FFmpeg encoding preset (affects output quality vs.
speed) ```bash python main.py --preset ultrafast # Fastest, lower quality python main.py --preset veryslow # Slowest, highest quality ``` - `--detect`: Specify what object type to detect (using natural language) ```bash python main.py --detect person # Detect people python main.py --detect "red car" # Detect red cars python main.py --detect "person wearing a hat" # Detect people with hats ``` - `--box-style`: Choose visualization style ```bash python main.py --box-style censor # Black boxes (default) python main.py --box-style bounding-box # Bounding box-style boxes with labels python main.py --box-style hitmarker # COD-style hitmarkers ``` - `--rows` and `--cols`: Enable grid-based detection by splitting frames ```bash python main.py --rows 2 --cols 2 # Split each frame into 2x2 grid python main.py --rows 3 --cols 3 # Split each frame into 3x3 grid ``` You can combine arguments: ```bash python main.py --detect "person wearing sunglasses" --box-style bounding-box --test --preset "fast" --rows 2 --cols 2 ``` ### Visualization Styles The tool supports three different visualization styles for detected objects: 1. **Censor** (default) - Places solid black rectangles over detected objects - Best for privacy and content moderation - Completely obscures the detected region 2. **Bounding Box** - Traditional object detection style - Red bounding box around detected objects - Label showing object type above the box - Good for analysis and debugging 3. **Hitmarker** - Call of Duty inspired visualization - White crosshair marker at center of detected objects - Small label above the marker - Stylistic choice for gaming-inspired visualization Choose the style that best fits your use case using the `--box-style` argument. 
## Output Processed videos will be saved in the `outputs` directory with the format: `[style]_[object_type]_[original_filename].mp4` For example: - `censor_face_video.mp4` - `bounding-box_person_video.mp4` - `hitmarker_car_video.mp4` The output videos will include: - Original video content - Selected visualization style for detected objects - Web-compatible H.264 encoding ## Notes - Processing time depends on video length, grid size, and GPU availability - GPU is strongly recommended for faster processing - Requires sufficient disk space for temporary files - Detection quality varies based on video quality and Moondream's ability to recognize the specified object - Grid-based detection impacts performance significantly - use only when needed - Web interface shows progress updates and errors - Choose visualization style based on your use case - Moondream can detect almost anything you can describe in natural language ================================================ FILE: recipes/promptable-video-redaction/app.py ================================================ #!/usr/bin/env python3 import gradio as gr import os from main import load_moondream, process_video import shutil import torch # Get absolute path to workspace root WORKSPACE_ROOT = os.path.dirname(os.path.abspath(__file__)) # Check CUDA availability print(f"Is CUDA available: {torch.cuda.is_available()}") # We want to get True print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") # GPU Name # Initialize model globally for reuse print("Loading Moondream model...") model, tokenizer = load_moondream() def process_video_file( video_file, detect_keyword, box_style, ffmpeg_preset, rows, cols, test_mode ): """Process a video file through the Gradio interface.""" try: if not video_file: raise gr.Error("Please upload a video file") # Ensure input/output directories exist using absolute paths inputs_dir = os.path.join(WORKSPACE_ROOT, "inputs") outputs_dir = os.path.join(WORKSPACE_ROOT, 
"outputs") os.makedirs(inputs_dir, exist_ok=True) os.makedirs(outputs_dir, exist_ok=True) # Copy uploaded video to inputs directory video_filename = f"input_{os.path.basename(video_file)}" input_video_path = os.path.join(inputs_dir, video_filename) shutil.copy2(video_file, input_video_path) try: # Process the video output_path = process_video( input_video_path, detect_keyword, test_mode=test_mode, ffmpeg_preset=ffmpeg_preset, rows=rows, cols=cols, box_style=box_style, ) # Verify output exists and is readable if not output_path or not os.path.exists(output_path): print(f"Warning: Output path {output_path} does not exist") # Try to find the output based on expected naming convention expected_output = os.path.join( outputs_dir, f"{box_style}_{detect_keyword}_{video_filename}" ) if os.path.exists(expected_output): output_path = expected_output else: # Try searching in outputs directory for any matching file matching_files = [ f for f in os.listdir(outputs_dir) if f.startswith(f"{box_style}_{detect_keyword}_") ] if matching_files: output_path = os.path.join(outputs_dir, matching_files[0]) else: raise gr.Error("Failed to locate output video") # Convert output path to absolute path if it isn't already if not os.path.isabs(output_path): output_path = os.path.join(WORKSPACE_ROOT, output_path) print(f"Returning output path: {output_path}") return output_path finally: # Clean up input file try: if os.path.exists(input_video_path): os.remove(input_video_path) except: pass except Exception as e: print(f"Error in process_video_file: {str(e)}") raise gr.Error(f"Error processing video: {str(e)}") # Create the Gradio interface with gr.Blocks(title="Promptable Video Redaction") as app: gr.Markdown("# Promptable Video Redaction with Moondream") gr.Markdown( """ [Moondream 2B](https://github.com/vikhyat/moondream) is a lightweight vision model that detects and visualizes objects in videos. It can identify objects, people, text and more. Upload a video and specify what to detect. 
The app will process each frame and apply your chosen visualization style. For help, join the [Moondream Discord](https://discord.com/invite/tRUdpjDQfH). """ ) with gr.Row(): with gr.Column(): # Input components video_input = gr.Video(label="Upload Video") detect_input = gr.Textbox( label="What to Detect", placeholder="e.g. face, logo, text, person, car, dog, etc.", value="face", info="Moondream can detect anything that you can describe in natural language", ) process_btn = gr.Button("Process Video", variant="primary") with gr.Accordion("Advanced Settings", open=False): box_style_input = gr.Radio( choices=["censor", "bounding-box", "hitmarker"], value="censor", label="Visualization Style", info="Choose how to display detections", ) preset_input = gr.Dropdown( choices=[ "ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow", ], value="medium", label="Processing Speed (faster = lower quality)", ) with gr.Row(): rows_input = gr.Slider( minimum=1, maximum=4, value=1, step=1, label="Grid Rows" ) cols_input = gr.Slider( minimum=1, maximum=4, value=1, step=1, label="Grid Columns" ) test_mode_input = gr.Checkbox( label="Test Mode (Process first 3 seconds only)", value=True, info="Enable to quickly test settings on a short clip before processing the full video (recommended)", ) gr.Markdown( """ Note: Processing in test mode will only process the first 3 seconds of the video and is recommended for testing settings. """ ) gr.Markdown( """ We can get a rough estimate of how long the video will take to process by multiplying the videos framerate * seconds * the number of rows and columns and assuming 0.12 seconds processing time per detection. For example, a 3 second video at 30fps with 2x2 grid, the estimated time is 3 * 30 * 2 * 2 * 0.12 = 43.2 seconds (tested on a 4090 GPU). 
""" ) with gr.Column(): # Output components video_output = gr.Video(label="Processed Video") # About section under the video output gr.Markdown( """ ### Links: - [GitHub Repository](https://github.com/vikhyat/moondream) - [Hugging Face](https://huggingface.co/vikhyatk/moondream2) - [Python Package](https://pypi.org/project/moondream/) - [Moondream Recipes](https://docs.moondream.ai/recipes) """ ) # Event handlers process_btn.click( fn=process_video_file, inputs=[ video_input, detect_input, box_style_input, preset_input, rows_input, cols_input, test_mode_input, ], outputs=video_output, ) if __name__ == "__main__": app.launch(share=True) ================================================ FILE: recipes/promptable-video-redaction/main.py ================================================ #!/usr/bin/env python3 import cv2, os, subprocess, argparse from PIL import Image from transformers import AutoModelForCausalLM, AutoTokenizer from tqdm import tqdm import numpy as np # Constants TEST_MODE_DURATION = 3 # Process only first 3 seconds in test mode FFMPEG_PRESETS = [ "ultrafast", "superfast", "veryfast", "faster", "fast", "medium", "slow", "slower", "veryslow", ] FONT = cv2.FONT_HERSHEY_SIMPLEX # Font for bounding-box-style labels # Detection parameters IOU_THRESHOLD = 0.5 # IoU threshold for considering boxes related # Hitmarker parameters HITMARKER_SIZE = 20 # Size of the hitmarker in pixels HITMARKER_GAP = 3 # Size of the empty space in the middle (reduced from 8) HITMARKER_THICKNESS = 2 # Thickness of hitmarker lines HITMARKER_COLOR = (255, 255, 255) # White color for hitmarker HITMARKER_SHADOW_COLOR = (80, 80, 80) # Lighter gray for shadow effect HITMARKER_SHADOW_OFFSET = 1 # Smaller shadow offset def load_moondream(): """Load Moondream model and tokenizer.""" model = AutoModelForCausalLM.from_pretrained( "vikhyatk/moondream2", trust_remote_code=True, device_map={"": "cuda"} ) tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2") return model, tokenizer def 
get_video_properties(video_path): """Get basic video properties.""" video = cv2.VideoCapture(video_path) fps = video.get(cv2.CAP_PROP_FPS) frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) video.release() return {"fps": fps, "frame_count": frame_count, "width": width, "height": height} def is_valid_box(box): """Check if box coordinates are reasonable.""" x1, y1, x2, y2 = box width = x2 - x1 height = y2 - y1 # Reject boxes that are too large (over 90% of frame in both dimensions) if width > 0.9 and height > 0.9: return False # Reject boxes that are too small (less than 1% of frame) if width < 0.01 or height < 0.01: return False return True def split_frame_into_tiles(frame, rows, cols): """Split a frame into a grid of tiles.""" height, width = frame.shape[:2] tile_height = height // rows tile_width = width // cols tiles = [] tile_positions = [] for i in range(rows): for j in range(cols): y1 = i * tile_height y2 = (i + 1) * tile_height if i < rows - 1 else height x1 = j * tile_width x2 = (j + 1) * tile_width if j < cols - 1 else width tile = frame[y1:y2, x1:x2] tiles.append(tile) tile_positions.append((x1, y1, x2, y2)) return tiles, tile_positions def convert_tile_coords_to_frame(box, tile_pos, frame_shape): """Convert coordinates from tile space to frame space.""" frame_height, frame_width = frame_shape[:2] tile_x1, tile_y1, tile_x2, tile_y2 = tile_pos tile_width = tile_x2 - tile_x1 tile_height = tile_y2 - tile_y1 x1_tile_abs = box[0] * tile_width y1_tile_abs = box[1] * tile_height x2_tile_abs = box[2] * tile_width y2_tile_abs = box[3] * tile_height x1_frame_abs = tile_x1 + x1_tile_abs y1_frame_abs = tile_y1 + y1_tile_abs x2_frame_abs = tile_x1 + x2_tile_abs y2_frame_abs = tile_y1 + y2_tile_abs x1_norm = x1_frame_abs / frame_width y1_norm = y1_frame_abs / frame_height x2_norm = x2_frame_abs / frame_width y2_norm = y2_frame_abs / frame_height x1_norm = max(0.0, 
min(1.0, x1_norm)) y1_norm = max(0.0, min(1.0, y1_norm)) x2_norm = max(0.0, min(1.0, x2_norm)) y2_norm = max(0.0, min(1.0, y2_norm)) return [x1_norm, y1_norm, x2_norm, y2_norm] def merge_tile_detections(tile_detections, iou_threshold=0.5): """Merge detections from different tiles using NMS-like approach.""" if not tile_detections: return [] all_boxes = [] all_keywords = [] # Collect all boxes and their keywords for detections in tile_detections: for box, keyword in detections: all_boxes.append(box) all_keywords.append(keyword) if not all_boxes: return [] # Convert to numpy for easier processing boxes = np.array(all_boxes) # Calculate areas x1 = boxes[:, 0] y1 = boxes[:, 1] x2 = boxes[:, 2] y2 = boxes[:, 3] areas = (x2 - x1) * (y2 - y1) # Sort boxes by area order = areas.argsort()[::-1] keep = [] while order.size > 0: i = order[0] keep.append(i) if order.size == 1: break # Calculate IoU with rest of boxes xx1 = np.maximum(x1[i], x1[order[1:]]) yy1 = np.maximum(y1[i], y1[order[1:]]) xx2 = np.minimum(x2[i], x2[order[1:]]) yy2 = np.minimum(y2[i], y2[order[1:]]) w = np.maximum(0.0, xx2 - xx1) h = np.maximum(0.0, yy2 - yy1) inter = w * h ovr = inter / (areas[i] + areas[order[1:]] - inter) # Get indices of boxes with IoU less than threshold inds = np.where(ovr <= iou_threshold)[0] order = order[inds + 1] return [(all_boxes[i], all_keywords[i]) for i in keep] def detect_ads_in_frame(model, tokenizer, image, detect_keyword, rows=1, cols=1): """Detect objects in a frame using grid-based detection.""" if rows == 1 and cols == 1: return detect_ads_in_frame_single(model, tokenizer, image, detect_keyword) # Convert numpy array to PIL Image if needed if not isinstance(image, Image.Image): image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Split frame into tiles tiles, tile_positions = split_frame_into_tiles(image, rows, cols) # Process each tile tile_detections = [] for tile, tile_pos in zip(tiles, tile_positions): # Convert tile to PIL Image tile_pil = Image.fromarray(tile) # 
Detect objects in tile response = model.detect(tile_pil, detect_keyword) if response and "objects" in response and response["objects"]: objects = response["objects"] tile_objects = [] for obj in objects: if all(k in obj for k in ["x_min", "y_min", "x_max", "y_max"]): box = [obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]] if is_valid_box(box): # Convert tile coordinates to frame coordinates frame_box = convert_tile_coords_to_frame( box, tile_pos, image.shape ) tile_objects.append((frame_box, detect_keyword)) if tile_objects: # Only append if we found valid objects tile_detections.append(tile_objects) # Merge detections from all tiles merged_detections = merge_tile_detections(tile_detections) return merged_detections def detect_ads_in_frame_single(model, tokenizer, image, detect_keyword): """Single-frame detection function.""" detected_objects = [] # Convert numpy array to PIL Image if needed if not isinstance(image, Image.Image): image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) # Detect objects response = model.detect(image, detect_keyword) # Check if we have valid objects if response and "objects" in response and response["objects"]: objects = response["objects"] for obj in objects: if all(k in obj for k in ["x_min", "y_min", "x_max", "y_max"]): box = [obj["x_min"], obj["y_min"], obj["x_max"], obj["y_max"]] # If box is valid (not full-frame), add it if is_valid_box(box): detected_objects.append((box, detect_keyword)) return detected_objects def draw_hitmarker( frame, center_x, center_y, size=HITMARKER_SIZE, color=HITMARKER_COLOR, shadow=True ): """Draw a COD-style hitmarker cross with more space in the middle.""" half_size = size // 2 # Draw shadow first if enabled if shadow: # Top-left to center shadow cv2.line( frame, ( center_x - half_size + HITMARKER_SHADOW_OFFSET, center_y - half_size + HITMARKER_SHADOW_OFFSET, ), ( center_x - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, center_y - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, ), 
HITMARKER_SHADOW_COLOR, HITMARKER_THICKNESS, ) # Top-right to center shadow cv2.line( frame, ( center_x + half_size + HITMARKER_SHADOW_OFFSET, center_y - half_size + HITMARKER_SHADOW_OFFSET, ), ( center_x + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, center_y - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, ), HITMARKER_SHADOW_COLOR, HITMARKER_THICKNESS, ) # Bottom-left to center shadow cv2.line( frame, ( center_x - half_size + HITMARKER_SHADOW_OFFSET, center_y + half_size + HITMARKER_SHADOW_OFFSET, ), ( center_x - HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, center_y + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, ), HITMARKER_SHADOW_COLOR, HITMARKER_THICKNESS, ) # Bottom-right to center shadow cv2.line( frame, ( center_x + half_size + HITMARKER_SHADOW_OFFSET, center_y + half_size + HITMARKER_SHADOW_OFFSET, ), ( center_x + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, center_y + HITMARKER_GAP + HITMARKER_SHADOW_OFFSET, ), HITMARKER_SHADOW_COLOR, HITMARKER_THICKNESS, ) # Draw main hitmarker # Top-left to center cv2.line( frame, (center_x - half_size, center_y - half_size), (center_x - HITMARKER_GAP, center_y - HITMARKER_GAP), color, HITMARKER_THICKNESS, ) # Top-right to center cv2.line( frame, (center_x + half_size, center_y - half_size), (center_x + HITMARKER_GAP, center_y - HITMARKER_GAP), color, HITMARKER_THICKNESS, ) # Bottom-left to center cv2.line( frame, (center_x - half_size, center_y + half_size), (center_x - HITMARKER_GAP, center_y + HITMARKER_GAP), color, HITMARKER_THICKNESS, ) # Bottom-right to center cv2.line( frame, (center_x + half_size, center_y + half_size), (center_x + HITMARKER_GAP, center_y + HITMARKER_GAP), color, HITMARKER_THICKNESS, ) def draw_ad_boxes(frame, detected_objects, detect_keyword, box_style="censor"): """Draw detection visualizations over detected objects. 
Args: frame: The video frame to draw on detected_objects: List of (box, keyword) tuples detect_keyword: The detection keyword box_style: Visualization style ('censor', 'bounding-box', or 'hitmarker') """ height, width = frame.shape[:2] for box, keyword in detected_objects: try: # Convert normalized coordinates to pixel coordinates x1 = int(box[0] * width) y1 = int(box[1] * height) x2 = int(box[2] * width) y2 = int(box[3] * height) # Ensure coordinates are within frame boundaries x1 = max(0, min(x1, width - 1)) y1 = max(0, min(y1, height - 1)) x2 = max(0, min(x2, width - 1)) y2 = max(0, min(y2, height - 1)) # Only draw if box has reasonable size if x2 > x1 and y2 > y1: if box_style == "censor": # Draw solid black rectangle cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 0), -1) elif box_style == "bounding-box": # Draw red rectangle with thicker line cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 0, 255), 3) # Add label with background label = detect_keyword # Use exact capitalization label_size = cv2.getTextSize(label, FONT, 0.7, 2)[0] cv2.rectangle( frame, (x1, y1 - 25), (x1 + label_size[0], y1), (0, 0, 255), -1 ) cv2.putText( frame, label, (x1, y1 - 6), FONT, 0.7, (255, 255, 255), 2, cv2.LINE_AA, ) elif box_style == "hitmarker": # Calculate center of the box center_x = (x1 + x2) // 2 center_y = (y1 + y2) // 2 # Draw hitmarker at the center draw_hitmarker(frame, center_x, center_y) # Optional: Add small label above hitmarker label = detect_keyword # Use exact capitalization label_size = cv2.getTextSize(label, FONT, 0.5, 1)[0] cv2.putText( frame, label, (center_x - label_size[0] // 2, center_y - HITMARKER_SIZE - 5), FONT, 0.5, HITMARKER_COLOR, 1, cv2.LINE_AA, ) except Exception as e: print(f"Error drawing {box_style} style box: {str(e)}") return frame def filter_temporal_outliers(detections_dict): """Filter out extremely large detections that take up most of the frame. Only keeps detections that are reasonable in size. 
Args: detections_dict: Dictionary of {frame_number: [(box, keyword), ...]} """ filtered_detections = {} for t, detections in detections_dict.items(): # Only keep detections that aren't too large valid_detections = [] for box, keyword in detections: # Calculate box size as percentage of frame width = box[2] - box[0] height = box[3] - box[1] area = width * height # If box is less than 90% of frame, keep it if area < 0.9: valid_detections.append((box, keyword)) if valid_detections: filtered_detections[t] = valid_detections return filtered_detections def describe_frames( video_path, model, tokenizer, detect_keyword, test_mode=False, rows=1, cols=1 ): """Extract and detect objects in frames.""" props = get_video_properties(video_path) fps = props["fps"] # If in test mode, only process first 3 seconds if test_mode: frame_count = min(int(fps * TEST_MODE_DURATION), props["frame_count"]) else: frame_count = props["frame_count"] ad_detections = {} # Store detection results by frame number print("Extracting frames and detecting objects...") video = cv2.VideoCapture(video_path) # Process every frame frame_count_processed = 0 with tqdm(total=frame_count) as pbar: while frame_count_processed < frame_count: ret, frame = video.read() if not ret: break # Detect objects in the frame detected_objects = detect_ads_in_frame( model, tokenizer, frame, detect_keyword, rows=rows, cols=cols ) # Store results for every frame, even if empty ad_detections[frame_count_processed] = detected_objects frame_count_processed += 1 pbar.update(1) video.release() if frame_count_processed == 0: print("No frames could be read from video") return {} # Filter out only extremely large detections ad_detections = filter_temporal_outliers(ad_detections) return ad_detections def create_detection_video( video_path, ad_detections, detect_keyword, output_path=None, ffmpeg_preset="medium", test_mode=False, box_style="censor", ): """Create video with detection boxes.""" if output_path is None: # Create outputs 
directory if it doesn't exist outputs_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), "outputs" ) os.makedirs(outputs_dir, exist_ok=True) # Clean the detect_keyword for filename safe_keyword = "".join( x for x in detect_keyword if x.isalnum() or x in (" ", "_", "-") ) safe_keyword = safe_keyword.replace(" ", "_") # Create output filename base_name = os.path.splitext(os.path.basename(video_path))[0] output_path = os.path.join( outputs_dir, f"{box_style}_{safe_keyword}_{base_name}.mp4" ) print(f"Will save output to: {output_path}") props = get_video_properties(video_path) fps, width, height = props["fps"], props["width"], props["height"] # If in test mode, only process first few seconds if test_mode: frame_count = min(int(fps * TEST_MODE_DURATION), props["frame_count"]) else: frame_count = props["frame_count"] video = cv2.VideoCapture(video_path) # Create temp output path by adding _temp before the extension base, ext = os.path.splitext(output_path) temp_output = f"{base}_temp{ext}" out = cv2.VideoWriter( temp_output, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height) ) print("Creating detection video...") frame_count_processed = 0 with tqdm(total=frame_count) as pbar: while frame_count_processed < frame_count: ret, frame = video.read() if not ret: break # Get detections for this exact frame if frame_count_processed in ad_detections: current_detections = ad_detections[frame_count_processed] if current_detections: frame = draw_ad_boxes( frame, current_detections, detect_keyword, box_style=box_style ) out.write(frame) frame_count_processed += 1 pbar.update(1) video.release() out.release() # Convert to web-compatible format more efficiently try: subprocess.run( [ "ffmpeg", "-y", "-i", temp_output, "-c:v", "libx264", "-preset", ffmpeg_preset, "-crf", "23", "-movflags", "+faststart", # Better web playback "-loglevel", "error", output_path, ], check=True, ) os.remove(temp_output) # Remove the temporary file if not os.path.exists(output_path): print( 
f"Warning: FFmpeg completed but output file not found at {output_path}" ) return None return output_path except subprocess.CalledProcessError as e: print(f"Error running FFmpeg: {str(e)}") if os.path.exists(temp_output): os.remove(temp_output) return None def process_video( video_path, detect_keyword, test_mode=False, ffmpeg_preset="medium", rows=1, cols=1, box_style="censor", ): """Process a single video file.""" print(f"\nProcessing: {video_path}") print(f"Looking for: {detect_keyword}") # Load model print("Loading Moondream model...") model, tokenizer = load_moondream() # Process video - detect objects ad_detections = describe_frames( video_path, model, tokenizer, detect_keyword, test_mode, rows, cols ) # Create video with detection boxes output_path = create_detection_video( video_path, ad_detections, detect_keyword, ffmpeg_preset=ffmpeg_preset, test_mode=test_mode, box_style=box_style, ) if output_path is None: print("\nError: Failed to create output video") return None print(f"\nOutput saved to: {output_path}") return output_path def main(): """Process all videos in the inputs directory.""" parser = argparse.ArgumentParser( description="Detect objects in videos using Moondream2" ) parser.add_argument( "--test", action="store_true", help="Process only first 3 seconds of each video" ) parser.add_argument( "--preset", choices=FFMPEG_PRESETS, default="medium", help="FFmpeg encoding preset (default: medium). 
Faster presets = lower quality", ) parser.add_argument( "--detect", type=str, default="face", help='Object to detect in the video (default: face, use --detect "thing to detect" to override)', ) parser.add_argument( "--rows", type=int, default=1, help="Number of rows to split each frame into (default: 1)", ) parser.add_argument( "--cols", type=int, default=1, help="Number of columns to split each frame into (default: 1)", ) parser.add_argument( "--box-style", choices=["censor", "bounding-box", "hitmarker"], default="censor", help="Style of detection visualization (default: censor)", ) args = parser.parse_args() input_dir = "inputs" os.makedirs(input_dir, exist_ok=True) os.makedirs("outputs", exist_ok=True) video_files = [ f for f in os.listdir(input_dir) if f.lower().endswith((".mp4", ".avi", ".mov", ".mkv", ".webm")) ] if not video_files: print("No video files found in 'inputs' directory") return print(f"Found {len(video_files)} videos to process") print(f"Will detect: {args.detect}") if args.test: print("Running in test mode - processing only first 3 seconds of each video") print(f"Using FFmpeg preset: {args.preset}") print(f"Grid size: {args.rows}x{args.cols}") print(f"Box style: {args.box_style}") success_count = 0 for video_file in video_files: video_path = os.path.join(input_dir, video_file) output_path = process_video( video_path, args.detect, test_mode=args.test, ffmpeg_preset=args.preset, rows=args.rows, cols=args.cols, box_style=args.box_style, ) if output_path: success_count += 1 print( f"\nProcessing complete. Successfully processed {success_count} out of {len(video_files)} videos." 
) if __name__ == "__main__": main() ================================================ FILE: recipes/promptable-video-redaction/packages.txt ================================================ libvips ffmpeg ================================================ FILE: recipes/promptable-video-redaction/requirements.txt ================================================ gradio>=4.0.0 torch transformers opencv-python pillow numpy tqdm ffmpeg-python einops pyvips accelerate ================================================ FILE: requirements.txt ================================================ torch==2.8.0 Pillow-SIMD==9.5.0.post2 transformers==4.56.1 pyvips-binary==8.16.0 pyvips==2.2.3 accelerate==1.10.1 gradio==4.38.1 # Needed for running evals datasets==3.2.0 editdistance==0.8.1 ================================================ FILE: sample.py ================================================ import argparse from queue import Queue from threading import Thread import torch from PIL import Image from transformers import AutoTokenizer, TextIteratorStreamer from moondream.hf import LATEST_REVISION, Moondream, detect_device if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--image", type=str, required=True) parser.add_argument("--prompt", type=str, required=False) parser.add_argument("--caption", action="store_true") parser.add_argument("--cpu", action="store_true") args = parser.parse_args() if args.cpu: device = torch.device("cpu") dtype = torch.float32 else: device, dtype = detect_device() if device != torch.device("cpu"): print("Using device:", device) print("If you run into issues, pass the `--cpu` flag to this script.") print() image_path = args.image prompt = args.prompt model_id = "vikhyatk/moondream2" tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION) moondream = Moondream.from_pretrained( model_id, revision=LATEST_REVISION, torch_dtype=dtype, ).to(device=device) moondream.eval() image = Image.open(image_path) if 
args.caption: print(moondream.caption(images=[image], tokenizer=tokenizer)[0]) else: image_embeds = moondream.encode_image(image) if prompt is None: chat_history = "" while True: question = input("> ") result_queue = Queue() streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True) # Separate direct arguments from keyword arguments thread_args = (image_embeds, question, tokenizer, chat_history) thread_kwargs = {"streamer": streamer, "result_queue": result_queue} thread = Thread( target=moondream.answer_question, args=thread_args, kwargs=thread_kwargs, ) thread.start() buffer = "" for new_text in streamer: buffer += new_text if not new_text.endswith("<") and not new_text.endswith("END"): print(buffer, end="", flush=True) buffer = "" print(buffer) thread.join() answer = result_queue.get() chat_history += f"Question: {question}\n\nAnswer: {answer}\n\n" else: print(">", prompt) answer = moondream.answer_question(image_embeds, prompt, tokenizer) print(answer) ================================================ FILE: tests/test_image_crops.py ================================================ import numpy as np import torch from moondream.torch.image_crops import overlap_crop_image, reconstruct_from_crops def test_overlap_crop_basic(): # Create a test image test_image = np.zeros((800, 600, 3), dtype=np.uint8) # Add a recognizable pattern - white rectangle in the middle test_image[300:500, 200:400] = 255 result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12) # Check basic properties assert result["crops"][0].shape == (378, 378, 3) assert len(result["crops"]) > 1 assert all(crop.shape == (378, 378, 3) for crop in result["crops"]) assert len(result["tiling"]) == 2 def test_overlap_crop_small_image(): # Test with image smaller than crop size test_image = np.zeros((300, 200, 3), dtype=np.uint8) result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12) # Should still produce valid output assert result["crops"][0].shape == (378, 378, 3) 
assert len(result["crops"]) == 2 assert result["tiling"] == (1, 1) def test_reconstruction(): # Create a test image test_image = np.zeros((800, 600, 3), dtype=np.uint8) # Add a recognizable pattern test_image[300:500, 200:400] = 255 # Crop and reconstruct result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12) crops_tensor = [torch.from_numpy(crop) for crop in result["crops"][1:]] reconstructed = reconstruct_from_crops( crops_tensor, result["tiling"], overlap_margin=4 ) # Convert back to numpy for comparison reconstructed_np = reconstructed.numpy() # The reconstructed image should be similar to the input # We can't expect exact equality due to resizing operations # but the white rectangle should still be visible in the middle center_reconstructed = reconstructed_np[ reconstructed_np.shape[0] // 2 - 100 : reconstructed_np.shape[0] // 2 + 100, reconstructed_np.shape[1] // 2 - 100 : reconstructed_np.shape[1] // 2 + 100, ].mean() # The center region should be significantly brighter than the edges assert center_reconstructed > reconstructed_np[:100, :100].mean() + 100 ================================================ FILE: webcam_gradio_demo.py ================================================ import argparse import time from threading import Thread import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer from moondream.hf import LATEST_REVISION, detect_device parser = argparse.ArgumentParser() parser.add_argument("--cpu", action="store_true") args = parser.parse_args() if args.cpu: device = torch.device("cpu") dtype = torch.float32 else: device, dtype = detect_device() if device != torch.device("cpu"): print("Using device:", device) print("If you run into issues, pass the `--cpu` flag to this script.") print() model_id = "vikhyatk/moondream2" tokenizer = AutoTokenizer.from_pretrained(model_id, revision=LATEST_REVISION) moondream = AutoModelForCausalLM.from_pretrained( model_id, 
trust_remote_code=True, revision=LATEST_REVISION ).to(device=device, dtype=dtype) moondream.eval() def answer_question(img, prompt): image_embeds = moondream.encode_image(img) streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True) thread = Thread( target=moondream.answer_question, kwargs={ "image_embeds": image_embeds, "question": prompt, "tokenizer": tokenizer, "streamer": streamer, }, ) thread.start() buffer = "" for new_text in streamer: buffer += new_text yield buffer with gr.Blocks() as demo: gr.Markdown("# 🌔 moondream") gr.HTML( """ """ ) with gr.Row(): prompt = gr.Textbox( label="Prompt", value="What's going on? Respond with a single sentence.", interactive=True, ) with gr.Row(): img = gr.Image(type="pil", label="Upload an Image", streaming=True) output = gr.Markdown(elem_classes=["md_output"]) latest_img = None latest_prompt = prompt.value @img.change(inputs=[img]) def img_change(img): global latest_img latest_img = img @prompt.change(inputs=[prompt]) def prompt_change(prompt): global latest_prompt latest_prompt = prompt @demo.load(outputs=[output]) def live_video(): while True: if latest_img is None: time.sleep(0.1) else: for text in answer_question(latest_img, latest_prompt): if len(text) > 0: yield text demo.queue().launch(debug=True)