[
  {
    "path": ".dockerignore",
    "content": ".git\n.github\n.pytest_cache\n.ruff_cache\n.uv-cache\n.venv\n__pycache__/\n*.pyc\n*.pyo\n*.pyd\n\n# Docker files (no need to copy themselves into the image)\nDockerfile\ndocker-compose.yml\n\n# Documentation and repo metadata\nCONTRIBUTING.md\ndocs/\n.git-blame-ignore-revs\n\n# Tests (not needed at runtime)\ntests/\ntest_vram.py\n\n# Launcher scripts (host-only)\n*.bat\n*.sh\n\n# Large/generated data and local outputs\nClipsForInference/\nIgnoredClips/\nOutput/\nCorridorKeyModule/checkpoints/\ngvm_core/weights/\nVideoMaMaInferenceModule/checkpoints/\n"
  },
  {
    "path": ".git-blame-ignore-revs",
    "content": "# Automated code formatting — no behavioral changes\n# ruff format + lint fixes\nb0ad00efbc791ed097cd3fd241c10319beb8a631\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.yml",
    "content": "name: 🐛 Bug Report\ndescription: Report a reproducible issue or unexpected behavior in the project.\ntitle: \"[Bug]: \"\nlabels: [bug]\n\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        Thank you for reporting a bug! Detailed bug reports help us fix issues faster.\n\n        **Before submitting:**\n        - Search existing issues to avoid duplicates\n        - Verify the issue is reproducible with the latest version\n        - Isolate the problem (provide minimal steps to reproduce)\n\n  - type: input\n    id: os\n    attributes:\n      label: Operating System\n      description: Your operating system and version.\n      placeholder: \"Windows 11, macOS 14.2, Ubuntu 22.04 LTS\"\n    validations:\n      required: true\n\n  - type: input\n    id: installation_method\n    attributes:\n      label: Installation Method\n      description: How did you install CorridorKey? (e.g., Windows batch installer, `uv sync`, Docker, manual setup)\n      placeholder: \"Windows batch installer, uv sync --extra cuda, Docker\"\n    validations:\n      required: true\n\n  - type: input\n    id: gpu_info\n    attributes:\n      label: GPU/Hardware (Optional)\n      description: GPU model and VRAM if applicable, or other relevant hardware constraints.\n      placeholder: \"NVIDIA RTX 4090 (24GB), Apple M3 Pro (18GB unified memory)\"\n    validations:\n      required: false\n\n  - type: textarea\n    id: steps\n    attributes:\n      label: Steps to Reproduce\n      description: Provide clear, numbered steps to reproduce the issue. Be as specific as possible.\n      placeholder: |\n        1. Step 1\n        2. Step 2\n        3. Step 3\n        4. Step 4\n        5. Error occurs\n    validations:\n      required: true\n\n  - type: textarea\n    id: expected\n    attributes:\n      label: Expected Behavior\n      description: What should happen instead? Be clear and concise.\n      placeholder: \"The application should display a settings panel without errors.\"\n    validations:\n      required: false\n\n  - type: textarea\n    id: actual\n    attributes:\n      label: Actual Behavior\n      description: What happened instead? Describe the bug clearly and concisely.\n      placeholder: \"The application crashes with an Error\"\n    validations:\n      required: true\n\n  - type: textarea\n    id: logs\n    attributes:\n      label: Relevant Logs or Error Messages\n      description: Paste complete error logs, full stack traces, or screenshots. Use code blocks for readability.\n      placeholder: |\n        ```\n        Put your log in here\n        ```\n\n  - type: textarea\n    id: workaround\n    attributes:\n      label: Workaround (if available)\n      description: If you've found a way to work around this issue, please describe it.\n      placeholder: \"As a workaround, I can set the timeout environment variable before starting the app.\"\n    validations:\n      required: false\n\n  - type: checkboxes\n    id: checks\n    attributes:\n      label: Verification Checklist\n      options:\n        - label: I've verified this bug hasn't been reported before\n          required: true\n        - label: I can reproduce this issue consistently\n          required: true\n        - label: I've included all relevant logs, screenshots, and error messages\n          required: true\n        - label: I've tested with the latest version of the project\n          required: false\n"
  },
  {
    "path": ".github/pull_request_template.md",
    "content": "## What does this change?\n\n## How was it tested?\n\n## Checklist\n\n- [ ] `uv run pytest` passes\n- [ ] `uv run ruff check` passes\n- [ ] `uv run ruff format --check` passes\n"
  },
  {
    "path": ".github/workflows/ci.yml",
    "content": "name: CI\n\non:\n  push:\n    branches: [main]\n  pull_request:\n    branches: [main]\n\nenv:\n  UV_NO_SYNC: 1\n  UV_LOCKED: 1\n  OPENCV_IO_ENABLE_OPENEXR: 1\n\npermissions:\n  contents: read\n  checks: write\n  pull-requests: write\n\njobs:\n  lint:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v6\n        with:\n          persist-credentials: false\n\n      - name: Install uv\n        uses: astral-sh/setup-uv@v7\n        with:\n          enable-cache: true\n\n      - name: Install dependencies\n        run: uv sync --group dev\n\n      - name: Check formatting\n        run: uv run ruff format --check\n\n      - name: Check lint\n        run: uv run ruff check\n\n  test:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [\"3.10\", \"3.13\"]\n    steps:\n      - uses: actions/checkout@v6\n        with:\n          persist-credentials: false\n\n      - name: Install uv\n        uses: astral-sh/setup-uv@v7\n        with:\n          enable-cache: true\n\n      - name: Install dependencies\n        run: uv sync --group dev --python ${{ matrix.python-version }}\n\n      - name: Run tests\n        run: uv run pytest -v --tb=short -m \"not gpu\"\n"
  },
  {
    "path": ".github/workflows/docs.yml",
    "content": "name: Documentation\non:\n    push:\n        branches:\n            - main\n        paths:\n            - \"docs/**\"\n            - \"zensical.toml\"\n    workflow_dispatch:\n\npermissions:\n    contents: read\n    pages: write\n    id-token: write\n\njobs:\n    deploy:\n        environment:\n            name: github-pages\n            url: ${{ steps.deployment.outputs.page_url }}\n        runs-on: ubuntu-latest\n\n        steps:\n            - uses: actions/checkout@v6\n\n            - uses: actions/configure-pages@v5\n\n            - name: Install uv\n              uses: astral-sh/setup-uv@v7\n              with:\n                  enable-cache: true\n\n            - name: Install docs dependencies via uv\n              run: uv sync --locked --only-group docs\n\n            - name: Build docs with zensical via uv\n              run: uv run zensical build --clean\n\n            - uses: actions/upload-pages-artifact@v4\n              with:\n                  path: site\n\n            - uses: actions/deploy-pages@v4\n              id: deployment\n"
  },
  {
    "path": ".gitignore",
    "content": "# Python\n__pycache__/\n*.pyc\n*.pyo\n*.pyd\n.coverage\n.Python\n.pytest_cache/\n.venv/\n.hypothesis\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Project Specific\nClipsForInference/*\n!ClipsForInference/.gitkeep\nOutput/*\n!Output/.gitkeep\nIgnored*/*\n!Ignored*/.gitkeep\nCorridorKey_remote.bat\n.ipynb_checkpoints/\n.DS_Store\n\n# IDE\n.vscode/\n.idea/\n\n# Models & Checkpoints (Large Files)\n*.pth\n*.pt\n*.ckpt\n*.safetensors\n*.bin\n*.onnx\n\n# Checkpoint Directories\nCorridorKeyModule/checkpoints/*\n!CorridorKeyModule/checkpoints/.gitkeep\nCorridorKeyModule/IgnoredCheckpoints/*\n!CorridorKeyModule/IgnoredCheckpoints/.gitkeep\nVideoMaMaInferenceModule/checkpoints/*\n!VideoMaMaInferenceModule/checkpoints/.gitkeep\ngvm_core/weights/*\n!gvm_core/weights/.gitkeep\nBiRefNetModule/checkpoints/*\n!BiRefNetModule/checkpoints/.gitkeep\n\nsite\n"
  },
  {
    "path": ".python-version",
    "content": "3.13"
  },
  {
    "path": "BiRefNetModule/checkpoints/.gitkeep",
    "content": ""
  },
  {
    "path": "BiRefNetModule/wrapper.py",
    "content": "import logging\nimport os\nfrom pathlib import Path\nfrom typing import Tuple\n\nimport cv2\nimport numpy as np\nimport torch\nfrom huggingface_hub import snapshot_download\nfrom PIL import Image\nfrom torchvision import transforms\nfrom transformers import AutoModelForImageSegmentation\n\ntorch.set_float32_matmul_precision([\"high\", \"highest\"][0])\n\n\nclass ImagePreprocessor:\n    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:\n        self.transform_image = transforms.Compose(\n            [\n                transforms.Resize(resolution),\n                transforms.ToTensor(),\n                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n            ]\n        )\n\n    def proc(self, image: Image.Image) -> torch.Tensor:\n        image = self.transform_image(image)\n        return image\n\n\nusage_to_weights_file = {\n    \"General\": \"BiRefNet\",\n    \"General-dynamic\": \"BiRefNet_dynamic\",\n    \"General-HR\": \"BiRefNet_HR\",\n    \"General-Lite\": \"BiRefNet_lite\",\n    \"General-Lite-2K\": \"BiRefNet_lite-2K\",\n    \"General-reso_512\": \"BiRefNet_512x512\",\n    \"Matting\": \"BiRefNet-matting\",\n    \"Matting-dynamic\": \"BiRefNet_dynamic-matting\",\n    \"Matting-HR\": \"BiRefNet_HR-Matting\",\n    \"Matting-Lite\": \"BiRefNet_lite-matting\",\n    \"Portrait\": \"BiRefNet-portrait\",\n    \"DIS\": \"BiRefNet-DIS5K\",\n    \"HRSOD\": \"BiRefNet-HRSOD\",\n    \"COD\": \"BiRefNet-COD\",\n    \"DIS-TR_TEs\": \"BiRefNet-DIS5K-TR_TEs\",\n    \"General-legacy\": \"BiRefNet-legacy\",\n}\n\nhalf_precision = True\n\nbase_folder = os.path.join(os.path.dirname(__file__), \"checkpoints\")\n\n\nclass BiRefNetHandler:\n    def __init__(self, device=\"cpu\", usage=\"General\"):\n        self.device = device\n\n        # Set resolution\n        if usage in [\"General-Lite-2K\"]:\n            self.resolution = (2560, 1440)\n        elif usage in [\"General-reso_512\"]:\n            self.resolution 
= (512, 512)\n        elif usage in [\"General-HR\", \"Matting-HR\"]:\n            self.resolution = (2048, 2048)\n        else:\n            if \"-dynamic\" in usage:\n                self.resolution = None\n            else:\n                self.resolution = (1024, 1024)\n\n        repo_name = usage_to_weights_file[usage]\n        repo_id = f\"ZhengPeng7/{repo_name}\"\n        model_local_dir = os.path.join(base_folder, repo_name)\n\n        snapshot_download(\n            repo_id=repo_id,\n            local_dir=model_local_dir,\n            local_dir_use_symlinks=False,  # Ensures actual files are downloaded, not just symlinks to the cache\n        )\n\n        self.birefnet = AutoModelForImageSegmentation.from_pretrained(model_local_dir, trust_remote_code=True)\n\n        self.birefnet.to(device)\n        self.birefnet.eval()\n        if half_precision:\n            self.birefnet.half()\n\n    def cleanup(self):\n        \"\"\"Explicitly clear model and release GPU memory.\"\"\"\n        # Delete the model reference\n        if hasattr(self, \"birefnet\"):\n            del self.birefnet\n\n        # Clear Python garbage\n        import gc\n\n        gc.collect()\n\n        # Clear PyTorch CUDA cache\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n            torch.cuda.ipc_collect()\n\n    def process(self, input_path, alpha_output_dir=None, dilate_radius=0, on_frame_complete=None):\n        \"\"\"\n        Process a single video or directory of images.\n        \"\"\"\n        input_path = Path(input_path)\n        file_name = input_path.stem\n        is_video = input_path.suffix.lower() in [\".mp4\", \".mkv\", \".gif\", \".mov\", \".avi\"]\n\n        def get_frames():\n            \"\"\"Yields tuples of (image_numpy_array, output_file_name)\"\"\"\n            if is_video:\n                cap = cv2.VideoCapture(str(input_path))\n                count = 0\n                while True:\n                    success, img = 
cap.read()\n                    if not success:\n                        break\n                    yield img, f\"{file_name}_alpha_{count:05d}.png\"\n                    count += 1\n                cap.release()\n            else:\n                image_files = sorted(\n                    [\n                        f\n                        for f in input_path.iterdir()\n                        if f.is_file() and f.suffix.lower() in [\".jpg\", \".png\", \".jpeg\", \".exr\"]\n                    ]\n                )\n                if not image_files:\n                    logging.warning(f\"No images found in {input_path}\")\n                    return\n\n                # Setup EXR support once if needed\n                if \"OPENCV_IO_ENABLE_OPENEXR\" not in os.environ:\n                    os.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"\n\n                for img_path in image_files:\n                    img = cv2.imread(str(img_path), cv2.IMREAD_UNCHANGED)\n                    if img is None:\n                        continue\n                    # Keep original filename for image sequences\n                    yield img, f\"alphaSeq_{img_path.stem}.png\"\n\n        count = 0\n        for image, out_name in get_frames():\n            # Ensure correct conversion to RGB regardless of input format (EXR/PNG/JPG)\n            if len(image.shape) == 2:\n                image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)\n            elif image.shape[2] == 4:\n                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)\n            else:\n                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n\n            # EXR images load as float32. PIL expects uint8. 
Normalize if necessary.\n            if image_rgb.dtype != np.uint8:\n                image_rgb = cv2.normalize(image_rgb, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)\n\n            pil_image = Image.fromarray(image_rgb)\n\n            # Preprocess\n            if self.resolution is None:  # Account for dynamic models\n                resolution_div_by_32 = [int(int(reso) // 32 * 32) for reso in pil_image.size]\n                if resolution_div_by_32 != self.resolution:\n                    self.resolution = resolution_div_by_32\n            image_preprocessor = ImagePreprocessor(resolution=tuple(self.resolution))\n            image_proc = image_preprocessor.proc(pil_image).unsqueeze(0).to(self.device)\n            if half_precision:\n                image_proc = image_proc.half()\n\n            # Inference\n            with torch.no_grad():\n                preds = self.birefnet(image_proc)[-1].sigmoid().cpu()\n\n            pred = preds[0].squeeze()\n            pred_pil = transforms.ToPILImage()(pred.float())\n\n            # Post-Process\n            target_size = (image.shape[1], image.shape[0])\n            mask = pred_pil.resize(target_size)\n            mask_np = np.array(mask)\n\n            # Dilate\n            if dilate_radius != 0:\n                abs_radius = abs(dilate_radius)\n                k_size = abs_radius * 2 + 1\n                kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))\n                if dilate_radius > 0:\n                    mask_np = cv2.dilate(mask_np, kernel, iterations=1)  # Expansion\n                else:\n                    mask_np = cv2.erode(mask_np, kernel, iterations=1)  # Contraction\n\n            # Strict Binary Threshold\n            _, mask_np = cv2.threshold(mask_np, 10, 255, cv2.THRESH_BINARY)\n\n            # Save\n            if alpha_output_dir:\n                save_path = os.path.join(alpha_output_dir, out_name)\n                cv2.imwrite(save_path, mask_np)\n\n            if 
on_frame_complete:\n                on_frame_complete(count, 0)\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to CorridorKey\n\nThanks for your interest in improving CorridorKey! Whether you're a VFX artist, a pipeline TD, or a machine learning researcher, contributions of all kinds are welcome — bug reports, feature ideas, documentation fixes, and code.\n\n## Legal Agreement\n\nBy contributing to this project, you agree that your contributions will be licensed under the project's **[CorridorKey Licence](LICENSE)**.\n\nBy submitting a Pull Request, you specifically acknowledge and agree to the terms set forth in **Section 6 (CONTRIBUTIONS)** of the license. This ensures that Corridor Digital maintains the full right to use, distribute, and sublicense this codebase, including PR contributions. This is a project for the community, and will always remain freely available here.\n\n## Getting Started\n\n### Prerequisites\n\n- Python 3.10 or newer\n- [uv](https://docs.astral.sh/uv/) for dependency management\n\n### Dev Setup\n\n```bash\ngit clone https://github.com/nikopueringer/CorridorKey.git\ncd CorridorKey\nuv sync --group dev    # installs all dependencies + dev tools (pytest, ruff)\n```\n\nThat's it. No manual virtualenv creation, no `pip install` — uv handles everything.\n\n### Running Tests\n\n```bash\nuv run pytest              # run all tests\nuv run pytest -v           # verbose (shows each test name)\nuv run pytest -m \"not gpu\" # skip tests that need a CUDA GPU\nuv run pytest --cov        # show test coverage (sources and branch mode configured in pyproject.toml)\n```\n\nMost tests run in a few seconds and don't need a GPU or model weights. Tests that require CUDA are marked with `@pytest.mark.gpu` and will be skipped automatically if no GPU is available.\n\n### Apple Silicon (Mac) Notes\n\nIf you are contributing on an Apple Silicon Mac, there are a few extra things to be aware of.\n\n**`uv.lock` drift:** Running `uv run pytest` on macOS regenerates `uv.lock` with macOS-specific dependency markers. 
**Do not commit this file.** Before staging your changes, always run:\n\n```bash\ngit restore uv.lock\n```\n\n**Selecting the compute backend:** CorridorKey auto-detects MPS on Apple Silicon. To test with the MLX backend or force CPU, set the environment variable before running:\n\n```bash\nexport CORRIDORKEY_BACKEND=mlx   # use native MLX on Apple Silicon\nexport CORRIDORKEY_DEVICE=cpu    # force CPU (useful for isolating device bugs)\n```\n\n**MPS operator fallback:** If PyTorch raises an error about an unsupported MPS operator, enable CPU fallback for those ops:\n\n```bash\nexport PYTORCH_ENABLE_MPS_FALLBACK=1\n```\n\n### Linting and Formatting\n\n```bash\nuv run ruff check          # check for lint errors\nuv run ruff format --check # check formatting (no changes)\nuv run ruff format         # auto-format your code\n```\n\nCI runs both checks on every pull request. Running them locally before pushing saves a round-trip.\n\n## Making Changes\n\n### Pull Requests\n\n1. Fork the repo and create a branch for your change\n2. Make your changes\n3. Run `uv run pytest` and `uv run ruff check` to make sure everything passes\n4. Open a pull request against `main`\n\nIn your PR description, focus on **why** you made the change, not just what changed. If you're fixing a bug, describe the symptoms. If you're adding a feature, explain the use case. 
A couple of sentences is plenty.\n\n### What Makes a Good Contribution\n\n- **Bug fixes** — especially for edge cases in EXR/linear workflows, color space handling, or platform-specific issues\n- **Tests** — more test coverage is always welcome, particularly for `clip_manager.py` and `inference_engine.py`\n- **Documentation** — better explanations, usage examples, or clarifying comments in tricky code\n- **Performance** — reducing GPU memory usage, speeding up frame processing, or optimizing I/O\n\n### Code Style\n\n- The project uses [ruff](https://docs.astral.sh/ruff/) for both linting and formatting\n- Lint rules: `E, F, W, I, B` (basic style, unused imports, import sorting, common bug patterns)\n- Line length: 120 characters\n- Third-party code in `gvm_core/` and `VideoMaMaInferenceModule/` is excluded from lint enforcement — those are derived from research repos and we try to keep them close to upstream\n\n### Model Weights\n\nThe model checkpoint (`CorridorKey_v1.0.pth`) and optional GVM/VideoMaMa weights are **not** in the git repo. Most tests don't need them. If you're working on inference code and need the weights, follow the download instructions in the [README](README.md).\n\n## Questions?\n\nJoin the [Discord](https://discord.gg/zvwUrdWXJm) — it's the fastest way to get help or discuss ideas before opening a PR.\n"
  },
  {
    "path": "ClipsForInference/.gitkeep",
    "content": ""
  },
  {
    "path": "CorridorKeyModule/IgnoredCheckpoints/.gitkeep",
    "content": ""
  },
  {
    "path": "CorridorKeyModule/README.md",
    "content": "# CorridorKeyModule\n\nA self-contained, high-performance AI Chroma Keying engine. This module provides a simple API to access the `CorridorKey` architecture (Hiera Backbone + CNN Refiner) for processing green screen footage.\n\n## Features\n*   **Resolution Independent:** Automatically resizes input images to match the native training resolution of the model (2048x2048).\n*   **High Fidelity:** Preserves original input resolution using Lanczos4 resampling for final output.\n*   **Robust:** Supports explicit configurations for Linear (EXR) and sRGB (PNG/MP4) source inputs.\n\n## Installation\n\nDependencies for the engine are managed in the main project root `requirements.txt`.  \n*(Requires PyTorch, NumPy, OpenCV, Timm)*\n\n## Usage (GUI Wizard)\n\nFor most users, the easiest way to interact with the module is through the included wizard:\n`clip_manager.py` (or dragging and dropping folders onto the `.bat` / `.sh` scripts).\nThe wizard handles finding the latest `.pth` checkpoint automatically, prompting for configuration (gamma, despill strength, despeckling), and batch processing entire sequences.\n\n## Usage (Python API)\n\n### 1. Initialization\nInitialize the engine once. Point it to your `.pth` checkpoint. The engine is hardcoded to process at 2048x2048, representing the data it was trained on.\n\n```python\nfrom CorridorKeyModule import CorridorKeyEngine\n\n# Initialize standard engine (CUDA)\nengine = CorridorKeyEngine(\n    checkpoint_path=\"models/latest_model.pth\", \n    device='cuda', \n    img_size=2048\n)\n```\n\n### 2. Processing a Frame\nThe engine expects inputs as Numpy Arrays (`H, W, Channels`).\n*   It natively processes in **32-bit float** (`0.0 - 1.0`).\n*   If you pass an **8-bit integer** (`0 - 255`) array, the engine will automatically normalize it to `0.0 - 1.0` floats for you. 
\n*   If you pass a **16-bit or 32-bit float** array (like an EXR), it will process it at full precision without downgrading.\n\n```python\nimport cv2\nimport os\n\n# Enable EXR Support in OpenCV\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"\n\n# Load Image (Linear EXR - Read as 32-bit Float)\nimg_linear = cv2.imread(\"input.exr\", cv2.IMREAD_UNCHANGED)\nimg_linear_rgb = cv2.cvtColor(img_linear, cv2.COLOR_BGR2RGB)\n\n# Load Coarse Mask (Linear EXR - Read as 32-bit Float)\nmask = cv2.imread(\"mask.exr\", cv2.IMREAD_UNCHANGED)\nif mask.ndim == 3: \n    mask = mask[:,:,0] # Keep single channel\n\n# Process\nresult = engine.process_frame(\n    img_linear_rgb, \n    mask,\n    input_is_linear=True, # Critical: Tell the engine this is a Linear EXR\n)\n\n# Save Results (Preserving Float Precision as EXR)\n# 'processed' contains the final RGBA composite (Linear 0-1 float)\nproc_rgba = result['processed']\nproc_bgra = cv2.cvtColor(proc_rgba, cv2.COLOR_RGBA2BGRA)\n\nexr_flags = [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF, cv2.IMWRITE_EXR_COMPRESSION, cv2.IMWRITE_EXR_COMPRESSION_PXR24]\ncv2.imwrite(\"output_processed.exr\", proc_bgra, exr_flags)\n```\n\n## Module Structure\n*   `inference_engine.py`: The main API wrapper class `CorridorKeyEngine`. Handles automated input normalization (uint8 to float), tensor conversions, memory transfer, resizing to/from the 2K processing resolution, and packing the final analytical passes (FG, Alpha, Processed EXR, and Comp overlays).\n*   `core/model_transformer.py`: The architecture definition for the PyTorch model, combining the Hiera backbone and the convolutional refiner head.\n*   `core/color_utils.py`: Custom digital compositing math utilities, including logic for luminance-preserving despilling, straight/premultiplied compositing algorithms, true sRGB gamma conversions, and connected-components morphological matte cleaning.\n"
  },
  {
    "path": "CorridorKeyModule/__init__.py",
    "content": "from __future__ import annotations\n\nfrom .inference_engine import CorridorKeyEngine as CorridorKeyEngine\n"
  },
  {
    "path": "CorridorKeyModule/backend.py",
    "content": "\"\"\"Backend factory — selects Torch or MLX engine and normalizes output contracts.\"\"\"\n\nfrom __future__ import annotations\n\nimport errno\nimport glob\nimport logging\nimport os\nimport platform\nimport shutil\nimport sys\nimport urllib.request\nfrom pathlib import Path\n\nimport numpy as np\n\nlogger = logging.getLogger(__name__)\n\nCHECKPOINT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"checkpoints\")\nTORCH_EXT = \".pth\"\nMLX_EXT = \".safetensors\"\nDEFAULT_IMG_SIZE = 2048\n\nBACKEND_ENV_VAR = \"CORRIDORKEY_BACKEND\"\nVALID_BACKENDS = (\"auto\", \"torch\", \"mlx\")\n\n# Update HF_REPO_ID and HF_CHECKPOINT_FILENAME if a new model version is released.\nHF_REPO_ID = \"nikopueringer/CorridorKey_v1.0\"\nHF_CHECKPOINT_FILENAME = \"CorridorKey.pth\"\n\n\ndef resolve_backend(requested: str | None = None) -> str:\n    \"\"\"Resolve backend: CLI flag > env var > auto-detect.\n\n    Auto mode: Apple Silicon + corridorkey_mlx importable + .safetensors found → mlx.\n    Otherwise → torch.\n\n    Raises RuntimeError if explicit backend is unavailable.\n    \"\"\"\n    if requested is None or requested.lower() == \"auto\":\n        backend = os.environ.get(BACKEND_ENV_VAR, \"auto\").lower()\n    else:\n        backend = requested.lower()\n\n    if backend == \"auto\":\n        return _auto_detect_backend()\n\n    if backend not in VALID_BACKENDS:\n        raise RuntimeError(f\"Unknown backend '{backend}'. 
Valid: {', '.join(VALID_BACKENDS)}\")\n\n    if backend == \"mlx\":\n        _validate_mlx_available()\n\n    return backend\n\n\nCHECKPOINT_DIR = os.path.join(\"CorridorKeyModule\", \"checkpoints\")\nMLX_MODEL_URL = \"https://github.com/nikopueringer/corridorkey-mlx/releases/download/v1.0.0/corridorkey_mlx.safetensors\"\nMLX_MODEL_FILENAME = \"corridorkey_mlx.safetensors\"\n\n\ndef _auto_detect_backend() -> str:\n    \"\"\"Try MLX on Apple Silicon, fall back to Torch.\"\"\"\n    if sys.platform != \"darwin\" or platform.machine() != \"arm64\":\n        logger.info(\"Not Apple Silicon — using torch backend\")\n        return \"torch\"\n\n    try:\n        import corridorkey_mlx  # type: ignore[import-not-found]  # noqa: F401\n    except ImportError:\n        logger.info(\"corridorkey_mlx not installed — using torch backend\")\n        return \"torch\"\n\n        # Auto-download logic for the .safetensors file\n    model_path = os.path.join(CHECKPOINT_DIR, MLX_MODEL_FILENAME)\n    cache_path = model_path + \".tmp\"\n\n    if not os.path.exists(model_path):\n        logger.info(f\"MLX checkpoint not found. 
Downloading to {model_path}...\")\n        try:\n            if os.path.exists(cache_path):\n                os.remove(cache_path)\n\n            # Create CorridorKeyModule/checkpoints/ if it doesn't exist\n            os.makedirs(CHECKPOINT_DIR, exist_ok=True)\n\n            # Download the file\n            urllib.request.urlretrieve(MLX_MODEL_URL, cache_path)\n            os.rename(cache_path, model_path)\n            logger.info(\"Download complete.\")\n\n        except Exception as e:\n            logger.error(f\"Failed to download MLX checkpoint: {e}\")\n            logger.info(\"Falling back to torch backend due to download failure.\")\n\n            # Clean up corrupted/partial file if the download failed midway\n            if os.path.exists(model_path):\n                os.remove(model_path)\n\n            return \"torch\"\n\n    logger.info(\"Apple Silicon + MLX available — using mlx backend\")\n    return \"mlx\"\n\n\ndef _validate_mlx_available() -> None:\n    \"\"\"Raise RuntimeError with actionable message if MLX can't be used.\"\"\"\n    if sys.platform != \"darwin\" or platform.machine() != \"arm64\":\n        raise RuntimeError(\"MLX backend requires Apple Silicon (M1+ Mac)\")\n\n    try:\n        import corridorkey_mlx  # type: ignore[import-not-found]  # noqa: F401\n    except ImportError as err:\n        raise RuntimeError(\n            \"MLX backend requested but corridorkey_mlx is not installed. 
\"\n            \"Install with: uv pip install corridorkey-mlx@git+https://github.com/cmoyates/corridorkey-mlx.git\"\n        ) from err\n\n\ndef _ensure_torch_checkpoint() -> Path:\n    \"\"\"Download the Torch checkpoint from HuggingFace if not present.\n\n    Returns the path to the downloaded checkpoint file.\n\n    Raises:\n        RuntimeError: Network or download failure.\n        OSError: Disk space or filesystem error.\n    \"\"\"\n    dest = Path(CHECKPOINT_DIR) / HF_CHECKPOINT_FILENAME\n    hf_url = f\"https://huggingface.co/{HF_REPO_ID}\"\n\n    from huggingface_hub import hf_hub_download\n\n    logger.info(\"Downloading CorridorKey checkpoint from %s ...\", hf_url)\n\n    try:\n        cached_path = hf_hub_download(\n            repo_id=HF_REPO_ID,\n            filename=HF_CHECKPOINT_FILENAME,\n        )\n    except Exception as exc:\n        raise RuntimeError(\n            f\"Failed to download CorridorKey checkpoint from {hf_url}. \"\n            \"Check your network connection and try again. \"\n            f\"Original error: {exc}\"\n        ) from exc\n\n    try:\n        shutil.copy2(cached_path, dest)\n    except OSError as exc:\n        if exc.errno == errno.ENOSPC:\n            raise OSError(\n                errno.ENOSPC,\n                \"Not enough disk space to save checkpoint (~300 MB required). 
\"\n                f\"Free up space in {CHECKPOINT_DIR} and try again.\",\n            ) from exc\n        raise\n\n    logger.info(\"Checkpoint saved to %s\", dest)\n    return dest\n\n\ndef _discover_checkpoint(ext: str) -> Path:\n    \"\"\"Find exactly one checkpoint with the given extension.\n\n    Raises FileNotFoundError (0 found) or ValueError (>1 found).\n    Includes cross-reference hints when wrong extension files exist.\n    \"\"\"\n    matches = glob.glob(os.path.join(CHECKPOINT_DIR, f\"*{ext}\"))\n\n    if len(matches) == 0:\n        if ext == TORCH_EXT:\n            return _ensure_torch_checkpoint()\n        other_ext = MLX_EXT if ext == TORCH_EXT else TORCH_EXT\n        other_files = glob.glob(os.path.join(CHECKPOINT_DIR, f\"*{other_ext}\"))\n        hint = \"\"\n        if other_files:\n            other_backend = \"mlx\" if other_ext == MLX_EXT else \"torch\"\n            hint = f\" (Found {other_ext} files — did you mean --backend={other_backend}?)\"\n        raise FileNotFoundError(f\"No {ext} checkpoint found in {CHECKPOINT_DIR}.{hint}\")\n\n    if len(matches) > 1:\n        names = [os.path.basename(f) for f in matches]\n        raise ValueError(f\"Multiple {ext} checkpoints in {CHECKPOINT_DIR}: {names}. 
Keep exactly one.\")\n\n    return Path(matches[0])\n\n\ndef _wrap_mlx_output(raw: dict, despill_strength: float, auto_despeckle: bool, despeckle_size: int) -> dict:\n    \"\"\"Normalize MLX uint8 output to match Torch float32 contract.\n\n    Torch contract:\n      alpha:     [H,W,1] float32 0-1\n      fg:        [H,W,3] float32 0-1 sRGB\n      comp:      [H,W,3] float32 0-1 sRGB\n      processed: [H,W,4] float32 linear premul RGBA\n    \"\"\"\n    from CorridorKeyModule.core import color_utils as cu\n\n    # alpha: uint8 [H,W] → float32 [H,W,1]\n    alpha_raw = raw[\"alpha\"]\n    alpha = alpha_raw.astype(np.float32) / 255.0\n    if alpha.ndim == 2:\n        alpha = alpha[:, :, np.newaxis]\n\n    # fg: uint8 [H,W,3] → float32 [H,W,3] (sRGB)\n    fg = raw[\"fg\"].astype(np.float32) / 255.0\n\n    # Apply despeckle (MLX stubs this)\n    if auto_despeckle:\n        processed_alpha = cu.clean_matte(alpha, area_threshold=despeckle_size, dilation=25, blur_size=5)\n    else:\n        processed_alpha = alpha\n\n    # Apply despill (MLX stubs this)\n    fg_despilled = cu.despill(fg, green_limit_mode=\"average\", strength=despill_strength)\n\n    # Composite over checkerboard for comp output\n    h, w = fg.shape[:2]\n    bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55)\n    bg_lin = cu.srgb_to_linear(bg_srgb)\n    fg_despilled_lin = cu.srgb_to_linear(fg_despilled)\n    comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha)\n    comp_srgb = cu.linear_to_srgb(comp_lin)\n\n    # Build processed: [H,W,4] linear premul RGBA\n    fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha)\n    processed_rgba = np.concatenate([fg_premul_lin, processed_alpha], axis=-1)\n\n    return {\n        \"alpha\": alpha,  # raw prediction (before despeckle), matches Torch\n        \"fg\": fg,  # raw sRGB prediction, matches Torch\n        \"comp\": comp_srgb,  # sRGB composite on checker\n        \"processed\": processed_rgba,  # 
linear premul RGBA\n    }\n\n\nclass _MLXEngineAdapter:\n    \"\"\"Wraps CorridorKeyMLXEngine to match Torch output contract.\"\"\"\n\n    def __init__(self, raw_engine):\n        self._engine = raw_engine\n        logger.info(\"MLX adapter active: despill and despeckle are handled by the adapter layer, not native MLX\")\n\n    def process_frame(\n        self,\n        image,\n        mask_linear,\n        refiner_scale=1.0,\n        input_is_linear=False,\n        fg_is_straight=True,\n        despill_strength=1.0,\n        auto_despeckle=True,\n        despeckle_size=400,\n    ):\n        \"\"\"Delegate to MLX engine, then normalize output to Torch contract.\"\"\"\n        # MLX engine expects uint8 input — convert if float\n        if image.dtype != np.uint8:\n            image_u8 = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)\n        else:\n            image_u8 = image\n\n        if mask_linear.dtype != np.uint8:\n            mask_u8 = (np.clip(mask_linear, 0.0, 1.0) * 255).astype(np.uint8)\n        else:\n            mask_u8 = mask_linear\n\n        # Squeeze mask to 2D for MLX (it validates [H,W] or [H,W,1])\n        if mask_u8.ndim == 3:\n            mask_u8 = mask_u8[:, :, 0]\n\n        raw = self._engine.process_frame(\n            image_u8,\n            mask_u8,\n            refiner_scale=refiner_scale,\n            input_is_linear=input_is_linear,\n            fg_is_straight=fg_is_straight,\n            despill_strength=0.0,  # disable MLX stubs — adapter applies these\n            auto_despeckle=False,\n            despeckle_size=despeckle_size,\n        )\n\n        return _wrap_mlx_output(raw, despill_strength, auto_despeckle, despeckle_size)\n\n\nDEFAULT_MLX_TILE_SIZE = 512\nDEFAULT_MLX_TILE_OVERLAP = 64\n\n\ndef create_engine(\n    backend: str | None = None,\n    device: str | None = None,\n    img_size: int = DEFAULT_IMG_SIZE,\n    tile_size: int | None = DEFAULT_MLX_TILE_SIZE,\n    overlap: int = DEFAULT_MLX_TILE_OVERLAP,\n):\n    
\"\"\"Factory: returns an engine with process_frame() matching the Torch contract.\n\n    Args:\n        tile_size: MLX only — tile size for tiled inference (default 512).\n            Set to None to disable tiling and use full-frame inference.\n        overlap: MLX only — overlap pixels between tiles (default 64).\n    \"\"\"\n    backend = resolve_backend(backend)\n\n    if backend == \"mlx\":\n        ckpt = _discover_checkpoint(MLX_EXT)\n        from corridorkey_mlx import CorridorKeyMLXEngine  # type: ignore[import-not-found]\n\n        raw_engine = CorridorKeyMLXEngine(str(ckpt), img_size=img_size, tile_size=tile_size, overlap=overlap)\n        mode = f\"tiled (tile={tile_size}, overlap={overlap})\" if tile_size else \"full-frame\"\n        logger.info(\"MLX engine loaded: %s [%s]\", ckpt.name, mode)\n        return _MLXEngineAdapter(raw_engine)\n    else:\n        ckpt = _discover_checkpoint(TORCH_EXT)\n        from CorridorKeyModule.inference_engine import CorridorKeyEngine\n\n        logger.info(\"Torch engine loaded: %s (device=%s)\", ckpt.name, device)\n        return CorridorKeyEngine(checkpoint_path=str(ckpt), device=device or \"cpu\", img_size=img_size)\n"
  },
  {
    "path": "CorridorKeyModule/checkpoints/.gitkeep",
    "content": ""
  },
  {
    "path": "CorridorKeyModule/core/__init__.py",
    "content": ""
  },
  {
    "path": "CorridorKeyModule/core/color_utils.py",
    "content": "from __future__ import annotations\n\nimport functools\nfrom collections.abc import Callable\n\nimport cv2\nimport numpy as np\nimport torch\n\n\ndef _is_tensor(x: np.ndarray | torch.Tensor) -> bool:\n    return isinstance(x, torch.Tensor)\n\n\ndef _if_tensor(is_tensor: bool, tensor_func: Callable, numpy_func: Callable) -> Callable:\n    return tensor_func if is_tensor else numpy_func\n\n\ndef _power(x: np.ndarray | torch.Tensor, exponent: float) -> np.ndarray | torch.Tensor:\n    \"\"\"\n    Power function that supports both Numpy arrays and PyTorch tensors.\n    \"\"\"\n    power = _if_tensor(_is_tensor(x), torch.pow, np.power)\n    return power(x, exponent)\n\n\ndef _where(\n    condition: np.ndarray | torch.Tensor, x: np.ndarray | torch.Tensor, y: np.ndarray | torch.Tensor\n) -> np.ndarray | torch.Tensor:\n    \"\"\"\n    Where function that supports both Numpy arrays and PyTorch tensors.\n    \"\"\"\n    where = _if_tensor(_is_tensor(x), torch.where, np.where)\n    return where(condition, x, y)\n\n\ndef _clamp(x: np.ndarray | torch.Tensor, min: float) -> np.ndarray | torch.Tensor:\n    \"\"\"\n    Clamp function that supports both Numpy arrays and PyTorch tensors.\n    \"\"\"\n    if isinstance(x, torch.Tensor):\n        return x.clamp(min=0.0)\n    return np.clip(x, 0.0, None)\n\n\n_torch_stack = functools.partial(torch.stack, dim=-1)\n_numpy_stack = functools.partial(np.stack, axis=-1)\n\n\ndef linear_to_srgb(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:\n    \"\"\"\n    Converts Linear to sRGB using the official piecewise sRGB transfer function.\n    Supports both Numpy arrays and PyTorch tensors.\n    \"\"\"\n    x = _clamp(x, 0.0)\n    mask = x <= 0.0031308\n    return _where(mask, x * 12.92, 1.055 * _power(x, 1.0 / 2.4) - 0.055)\n\n\ndef srgb_to_linear(x: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:\n    \"\"\"\n    Converts sRGB to Linear using the official piecewise sRGB transfer function.\n    Supports both 
Numpy arrays and PyTorch tensors.\n    \"\"\"\n    x = _clamp(x, 0.0)\n    mask = x <= 0.04045\n    return _where(mask, x / 12.92, _power((x + 0.055) / 1.055, 2.4))\n\n\ndef premultiply(fg: np.ndarray | torch.Tensor, alpha: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:\n    \"\"\"\n    Premultiplies foreground by alpha.\n    fg: Color [..., C] or [C, ...]\n    alpha: Alpha [..., 1] or [1, ...]\n    \"\"\"\n    return fg * alpha\n\n\ndef unpremultiply(\n    fg: np.ndarray | torch.Tensor, alpha: np.ndarray | torch.Tensor, eps: float = 1e-6\n) -> np.ndarray | torch.Tensor:\n    \"\"\"\n    Un-premultiplies foreground by alpha.\n    Ref: fg_straight = fg_premul / (alpha + eps)\n    \"\"\"\n    return fg / (alpha + eps)\n\n\ndef composite_straight(\n    fg: np.ndarray | torch.Tensor, bg: np.ndarray | torch.Tensor, alpha: np.ndarray | torch.Tensor\n) -> np.ndarray | torch.Tensor:\n    \"\"\"\n    Composites Straight FG over BG.\n    Formula: FG * Alpha + BG * (1 - Alpha)\n    \"\"\"\n    return fg * alpha + bg * (1.0 - alpha)\n\n\ndef composite_premul(\n    fg: np.ndarray | torch.Tensor, bg: np.ndarray | torch.Tensor, alpha: np.ndarray | torch.Tensor\n) -> np.ndarray | torch.Tensor:\n    \"\"\"\n    Composites Premultiplied FG over BG.\n    Formula: FG + BG * (1 - Alpha)\n    \"\"\"\n    return fg + bg * (1.0 - alpha)\n\n\ndef rgb_to_yuv(image: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    Converts RGB to YUV (Rec. 601).\n    Input: [..., 3, H, W] or [..., 3] depending on layout.\n    Supports standard PyTorch BCHW.\n    \"\"\"\n    if not _is_tensor(image):\n        raise TypeError(\"rgb_to_yuv only supports dict/tensor inputs currently\")\n\n    # Weights for RGB -> Y\n    # Rec. 
601: 0.299, 0.587, 0.114\n\n    # Assume BCHW layout if 4 dims\n    if image.dim() == 4:\n        r = image[:, 0:1, :, :]\n        g = image[:, 1:2, :, :]\n        b = image[:, 2:3, :, :]\n    elif image.dim() == 3 and image.shape[0] == 3:  # CHW\n        r = image[0:1, :, :]\n        g = image[1:2, :, :]\n        b = image[2:3, :, :]\n    else:\n        # Last dim conversion\n        r = image[..., 0]\n        g = image[..., 1]\n        b = image[..., 2]\n\n    y = 0.299 * r + 0.587 * g + 0.114 * b\n    u = 0.492 * (b - y)\n    v = 0.877 * (r - y)\n\n    if image.dim() >= 3 and image.shape[-3] == 3:  # Concatenate along Channel dim\n        return torch.cat([y, u, v], dim=-3)\n    else:\n        return torch.stack([y, u, v], dim=-1)\n\n\ndef dilate_mask(mask: np.ndarray | torch.Tensor, radius: int) -> np.ndarray | torch.Tensor:\n    \"\"\"\n    Dilates a mask by a given radius.\n    Supports Numpy (using cv2) and PyTorch (using MaxPool).\n    radius: Int (pixels). 0 = No change.\n    \"\"\"\n    if radius <= 0:\n        return mask\n\n    kernel_size = int(radius * 2 + 1)\n\n    if isinstance(mask, torch.Tensor):\n        # PyTorch Dilation (using Max Pooling)\n        # Expects [B, C, H, W]\n        orig_dim = mask.dim()\n\n        if orig_dim == 2:\n            mask = mask.unsqueeze(0).unsqueeze(0)\n        elif orig_dim == 3:\n            mask = mask.unsqueeze(0)\n\n        padding = radius\n        dilated = torch.nn.functional.max_pool2d(mask, kernel_size, stride=1, padding=padding)\n\n        if orig_dim == 2:\n            return dilated.squeeze()\n        elif orig_dim == 3:\n            return dilated.squeeze(0)\n        return dilated\n\n    # Numpy Dilation (using OpenCV)\n    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))\n    return cv2.dilate(mask, kernel)\n\n\ndef apply_garbage_matte(\n    predicted_matte: np.ndarray | torch.Tensor,\n    garbage_matte_input: np.ndarray | torch.Tensor | None,\n    dilation: int = 
10,\n) -> np.ndarray | torch.Tensor:\n    \"\"\"\n    Multiplies predicted matte by a dilated garbage matte to clean up background.\n    \"\"\"\n    if garbage_matte_input is None:\n        return predicted_matte\n\n    garbage_mask = dilate_mask(garbage_matte_input, dilation)\n\n    # Ensure dimensions match for multiplication\n    if _is_tensor(predicted_matte):\n        # Handle broadcasting if needed\n        pass\n    elif garbage_mask.ndim == 2 and predicted_matte.ndim == 3:\n        # Numpy\n        garbage_mask = garbage_mask[:, :, np.newaxis]\n\n    return predicted_matte * garbage_mask\n\n\ndef despill(\n    image: np.ndarray | torch.Tensor, green_limit_mode: str = \"average\", strength: float = 1.0\n) -> np.ndarray | torch.Tensor:\n    \"\"\"\n    Removes green spill from an RGB image using a luminance-preserving method.\n    image: RGB float (0-1).\n    green_limit_mode: 'average' ((R+B)/2) or 'max' (max(R, B)).\n    strength: 0.0 to 1.0 multiplier for the despill effect.\n    \"\"\"\n    if strength <= 0.0:\n        return image\n\n    tensor = _is_tensor(image)\n    _maximum = _if_tensor(tensor, torch.max, np.maximum)\n    _stack = _if_tensor(tensor, _torch_stack, _numpy_stack)\n\n    r = image[..., 0]\n    g = image[..., 1]\n    b = image[..., 2]\n\n    if green_limit_mode == \"max\":\n        limit = _maximum(r, b)\n    else:\n        limit = (r + b) / 2.0\n\n    if isinstance(image, torch.Tensor):\n        # PyTorch Impl — g/limit are Tensor since image is Tensor\n        diff: torch.Tensor = g - limit  # type: ignore[assignment]\n        spill_amount = torch.clamp(diff, min=0.0)\n    else:\n        # Numpy Impl\n        spill_amount = np.maximum(g - limit, 0.0)\n\n    g_new = g - spill_amount\n    r_new = r + (spill_amount * 0.5)\n    b_new = b + (spill_amount * 0.5)\n\n    despilled = _stack([r_new, g_new, b_new])\n\n    if strength < 1.0:\n        return image * (1.0 - strength) + despilled * strength\n\n    return despilled\n\n\ndef 
clean_matte(alpha_np: np.ndarray, area_threshold: int = 300, dilation: int = 15, blur_size: int = 5) -> np.ndarray:\n    \"\"\"\n    Cleans up small disconnected components (like tracking markers) from a predicted alpha matte.\n    alpha_np: Numpy array [H, W] or [H, W, 1] float (0.0 - 1.0)\n    \"\"\"\n    # Needs to be 2D\n    is_3d = False\n    if alpha_np.ndim == 3:\n        is_3d = True\n        alpha_np = alpha_np[:, :, 0]\n\n    # Threshold to binary\n    mask_8u = (alpha_np > 0.5).astype(np.uint8) * 255\n\n    # Find connected components\n    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(mask_8u, connectivity=8)\n\n    # Create an empty mask for the cleaned components\n    cleaned_mask = np.zeros_like(mask_8u)\n\n    # Keep components larger than the threshold (skip label 0, which is background)\n    for i in range(1, num_labels):\n        if stats[i, cv2.CC_STAT_AREA] >= area_threshold:\n            cleaned_mask[labels == i] = 255\n\n    # Dilate\n    if dilation > 0:\n        kernel_size = int(dilation * 2 + 1)\n        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))\n        cleaned_mask = cv2.dilate(cleaned_mask, kernel)\n\n    # Blur\n    if blur_size > 0:\n        b_size = int(blur_size * 2 + 1)\n        cleaned_mask = cv2.GaussianBlur(cleaned_mask, (b_size, b_size), 0)\n\n    # Convert back to 0-1 float\n    safe_zone = cleaned_mask.astype(np.float32) / 255.0\n\n    # Multiply original alpha by the safe zone\n    result_alpha = alpha_np * safe_zone\n\n    if is_3d:\n        result_alpha = result_alpha[:, :, np.newaxis]\n\n    return result_alpha\n\n\ndef create_checkerboard(\n    width: int, height: int, checker_size: int = 64, color1: float = 0.2, color2: float = 0.4\n) -> np.ndarray:\n    \"\"\"\n    Creates a linear grayscale checkerboard pattern.\n    Returns: Numpy array [H, W, 3] float (0.0-1.0)\n    \"\"\"\n    # Create coordinate grids\n    x = np.arange(width)\n    y = 
np.arange(height)\n\n    # Determine tile parity\n    x_tiles = x // checker_size\n    y_tiles = y // checker_size\n\n    # Broadcast to 2D\n    x_grid, y_grid = np.meshgrid(x_tiles, y_tiles)\n\n    # XOR for checker pattern (1 if odd, 0 if even)\n    checker = (x_grid + y_grid) % 2\n\n    # Map 0 to color1 and 1 to color2\n    bg_img = np.where(checker == 0, color1, color2).astype(np.float32)\n\n    # Make it 3-channel\n    return np.stack([bg_img, bg_img, bg_img], axis=-1)\n"
  },
  {
    "path": "CorridorKeyModule/core/model_transformer.py",
    "content": "from __future__ import annotations\n\nimport logging\n\nimport timm\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nlogger = logging.getLogger(__name__)\n\n\nclass MLP(nn.Module):\n    \"\"\"Linear embedding: C_in -> C_out.\"\"\"\n\n    def __init__(self, input_dim: int = 2048, embed_dim: int = 768) -> None:\n        super().__init__()\n        self.proj = nn.Linear(input_dim, embed_dim)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.proj(x)\n\n\nclass DecoderHead(nn.Module):\n    def __init__(\n        self, feature_channels: list[int] | None = None, embedding_dim: int = 256, output_dim: int = 1\n    ) -> None:\n        super().__init__()\n        if feature_channels is None:\n            feature_channels = [112, 224, 448, 896]\n\n        # MLP layers to unify channel dimensions\n        self.linear_c4 = MLP(input_dim=feature_channels[3], embed_dim=embedding_dim)\n        self.linear_c3 = MLP(input_dim=feature_channels[2], embed_dim=embedding_dim)\n        self.linear_c2 = MLP(input_dim=feature_channels[1], embed_dim=embedding_dim)\n        self.linear_c1 = MLP(input_dim=feature_channels[0], embed_dim=embedding_dim)\n\n        # Fuse\n        self.linear_fuse = nn.Conv2d(embedding_dim * 4, embedding_dim, kernel_size=1, bias=False)\n        self.bn = nn.BatchNorm2d(embedding_dim)\n        self.relu = nn.ReLU(inplace=True)\n\n        # Predict\n        self.dropout = nn.Dropout(0.1)\n        self.classifier = nn.Conv2d(embedding_dim, output_dim, kernel_size=1)\n\n    def forward(self, features: list[torch.Tensor]) -> torch.Tensor:\n        c1, c2, c3, c4 = features\n\n        n, _, h, w = c4.shape\n\n        # Resize to C1 size (which is H/4)\n        _c4 = self.linear_c4(c4.flatten(2).transpose(1, 2)).transpose(1, 2).view(n, -1, c4.shape[2], c4.shape[3])\n        _c4 = F.interpolate(_c4, size=c1.shape[2:], mode=\"bilinear\", align_corners=False)\n\n        _c3 = 
self.linear_c3(c3.flatten(2).transpose(1, 2)).transpose(1, 2).view(n, -1, c3.shape[2], c3.shape[3])\n        _c3 = F.interpolate(_c3, size=c1.shape[2:], mode=\"bilinear\", align_corners=False)\n\n        _c2 = self.linear_c2(c2.flatten(2).transpose(1, 2)).transpose(1, 2).view(n, -1, c2.shape[2], c2.shape[3])\n        _c2 = F.interpolate(_c2, size=c1.shape[2:], mode=\"bilinear\", align_corners=False)\n\n        _c1 = self.linear_c1(c1.flatten(2).transpose(1, 2)).transpose(1, 2).view(n, -1, c1.shape[2], c1.shape[3])\n\n        _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))\n        _c = self.bn(_c)\n        _c = self.relu(_c)\n\n        x = self.dropout(_c)\n        x = self.classifier(x)\n\n        return x\n\n\nclass RefinerBlock(nn.Module):\n    \"\"\"\n    Residual Block with Dilation and GroupNorm (Safe for Batch Size 2).\n    \"\"\"\n\n    def __init__(self, channels: int, dilation: int = 1) -> None:\n        super().__init__()\n        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=dilation, dilation=dilation)\n        self.gn1 = nn.GroupNorm(8, channels)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=dilation, dilation=dilation)\n        self.gn2 = nn.GroupNorm(8, channels)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        residual = x\n        out = self.conv1(x)\n        out = self.gn1(out)\n        out = self.relu(out)\n        out = self.conv2(out)\n        out = self.gn2(out)\n        out += residual\n        out = self.relu(out)\n        return out\n\n\nclass CNNRefinerModule(nn.Module):\n    \"\"\"\n    Dilated Residual Refiner (Receptive Field ~65px).\n    designed to solve Macroblocking artifacts from Hiera.\n    Structure: Stem -> Res(d1) -> Res(d2) -> Res(d4) -> Res(d8) -> Projection.\n    \"\"\"\n\n    def __init__(self, in_channels: int = 7, hidden_channels: int = 64, out_channels: int = 4) -> None:\n        
super().__init__()\n\n        # Stem\n        self.stem = nn.Sequential(\n            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),\n            nn.GroupNorm(8, hidden_channels),\n            nn.ReLU(inplace=True),\n        )\n\n        # Dilated Residual Blocks (RF Expansion)\n        self.res1 = RefinerBlock(hidden_channels, dilation=1)\n        self.res2 = RefinerBlock(hidden_channels, dilation=2)\n        self.res3 = RefinerBlock(hidden_channels, dilation=4)\n        self.res4 = RefinerBlock(hidden_channels, dilation=8)\n\n        # Final Projection (No Activation, purely additive logits)\n        self.final = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)\n\n        # Tiny Noise Init (Whisper) - Provides gradients without shock\n        nn.init.normal_(self.final.weight, mean=0.0, std=1e-3)\n        nn.init.constant_(self.final.bias, 0)\n\n    def forward(self, img: torch.Tensor, coarse_pred: torch.Tensor) -> torch.Tensor:\n        # img: [B, 3, H, W]\n        # coarse_pred: [B, 4, H, W]\n        x = torch.cat([img, coarse_pred], dim=1)\n\n        x = self.stem(x)\n        x = self.res1(x)\n        x = self.res2(x)\n        x = self.res3(x)\n        x = self.res4(x)\n\n        # Output Scaling (10x Boost)\n        # Allows the Refiner to predict small stable values (e.g. 0.5) that become strong corrections (5.0).\n        return self.final(x) * 10.0\n\n\nclass GreenFormer(nn.Module):\n    def __init__(\n        self,\n        encoder_name: str = \"hiera_base_plus_224.mae_in1k_ft_in1k\",\n        in_channels: int = 4,\n        img_size: int = 512,\n        use_refiner: bool = True,\n    ) -> None:\n        super().__init__()\n\n        # --- Encoder ---\n        # Load Pretrained Hiera\n        # 1. 
Create Target Model (512x512, Random Weights)\n        # We use features_only=True, which wraps it in FeatureGetterNet\n        logger.info(\"Initializing %s (img_size=%d)\", encoder_name, img_size)\n        self.encoder = timm.create_model(encoder_name, pretrained=False, features_only=True, img_size=img_size)\n        # We skip downloading/loading base weights because the user's checkpoint\n        # (loaded immediately after this) contains all weights, including correctly\n        # trained/sized PosEmbeds. This keeps the project offline-capable using only local assets.\n        logger.info(\"Skipped downloading base weights (relying on custom checkpoint)\")\n\n        # Patch First Layer for 4 channels\n        if in_channels != 3:\n            self._patch_input_layer(in_channels)\n\n        # Get feature info\n        # Verified Hiera Base Plus channels: [112, 224, 448, 896]\n        # We can try to fetch dynamically\n        try:\n            feature_channels = self.encoder.feature_info.channels()\n        except (AttributeError, TypeError):\n            feature_channels = [112, 224, 448, 896]\n        logger.info(\"Feature channels: %s\", feature_channels)\n\n        # --- Decoders ---\n        embedding_dim = 256\n\n        # Alpha Decoder (Outputs 1 channel)\n        self.alpha_decoder = DecoderHead(feature_channels, embedding_dim, output_dim=1)\n\n        # Foreground Decoder (Outputs 3 channels)\n        self.fg_decoder = DecoderHead(feature_channels, embedding_dim, output_dim=3)\n\n        # --- Refiner ---\n        # CNN Refiner\n        # In Channels: 3 (RGB) + 4 (Coarse Pred) = 7\n        self.use_refiner = use_refiner\n        if self.use_refiner:\n            self.refiner = CNNRefinerModule(in_channels=7, hidden_channels=64, out_channels=4)\n        else:\n            self.refiner = None\n            logger.info(\"Refiner module DISABLED (backbone-only mode)\")\n\n    def _patch_input_layer(self, in_channels: int) -> None:\n        \"\"\"\n        
Modifies the first convolution layer to accept `in_channels`.\n        Copies existing RGB weights and initializes extras to zero.\n        \"\"\"\n        # Hiera: self.encoder.model.patch_embed.proj\n\n        try:\n            patch_embed = self.encoder.model.patch_embed.proj\n        except AttributeError:\n            # Fallback if timm changes structure or for other models\n            patch_embed = self.encoder.patch_embed.proj\n        weight = patch_embed.weight.data  # [Out, 3, K, K]\n        bias = patch_embed.bias.data if patch_embed.bias is not None else None\n\n        new_in_channels = in_channels\n        out_channels, _, k, k = weight.shape\n\n        # Create new conv\n        new_conv = nn.Conv2d(\n            new_in_channels,\n            out_channels,\n            kernel_size=k,\n            stride=patch_embed.stride,\n            padding=patch_embed.padding,\n            bias=(bias is not None),\n        )\n\n        # Copy weights\n        new_conv.weight.data[:, :3, :, :] = weight\n        # Initialize new channels to 0 (Weight Patching)\n        new_conv.weight.data[:, 3:, :, :] = 0.0\n\n        if bias is not None:\n            new_conv.bias.data = bias\n\n        # Replace in module\n        try:\n            self.encoder.model.patch_embed.proj = new_conv\n        except AttributeError:\n            self.encoder.patch_embed.proj = new_conv\n\n        logger.info(\"Patched input layer: 3 → %d channels (extra initialized to 0)\", in_channels)\n\n    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:\n        # x: [B, 4, H, W]\n        input_size = x.shape[2:]\n\n        # Encode\n        features = self.encoder(x)  # Returns list of features\n\n        # Decode Streams\n        alpha_logits = self.alpha_decoder(features)  # [B, 1, H/4, W/4]\n        fg_logits = self.fg_decoder(features)  # [B, 3, H/4, W/4]\n\n        # Upsample to full resolution (Bilinear)\n        # These are the \"Coarse\" LOGITS\n        alpha_logits_up = 
F.interpolate(alpha_logits, size=input_size, mode=\"bilinear\", align_corners=False)\n        fg_logits_up = F.interpolate(fg_logits, size=input_size, mode=\"bilinear\", align_corners=False)\n\n        # --- HUMILITY CLAMP REMOVED (Phase 3) ---\n        # User requested NO CLAMPING to preserve all backbone detail.\n        # Refiner sees raw logits (-inf to +inf).\n        # alpha_logits_up = torch.clamp(alpha_logits_up, -3.0, 3.0)\n        # fg_logits_up = torch.clamp(fg_logits_up, -3.0, 3.0)\n\n        # Coarse Probs (for Loss and Refiner Input)\n        alpha_coarse = torch.sigmoid(alpha_logits_up)\n        fg_coarse = torch.sigmoid(fg_logits_up)\n\n        # --- Refinement (CNN Hybrid) ---\n        # 4. Refine (CNN)\n        # Input to refiner: RGB Image (first 3 channels of x) + Coarse Predictions (Probs)\n        # We give the refiner 'Probs' as input features because they are normalized [0,1]\n        rgb = x[:, :3, :, :]\n\n        # Feed the Refiner\n        coarse_pred = torch.cat([alpha_coarse, fg_coarse], dim=1)  # [B, 4, H, W]\n\n        # Refiner outputs DELTA LOGITS\n        # The refiner predicts the correction in valid score space (-inf, inf)\n        if self.use_refiner and self.refiner is not None:\n            delta_logits = self.refiner(rgb, coarse_pred)\n        else:\n            # Zero Deltas\n            delta_logits = torch.zeros_like(coarse_pred)\n\n        delta_alpha = delta_logits[:, 0:1]\n        delta_fg = delta_logits[:, 1:4]\n\n        # Residual Addition in Logit Space\n        # This allows infinite correction capability and prevents saturation blocking\n        alpha_final_logits = alpha_logits_up + delta_alpha\n        fg_final_logits = fg_logits_up + delta_fg\n\n        # Final Activation\n        alpha_final = torch.sigmoid(alpha_final_logits)\n        fg_final = torch.sigmoid(fg_final_logits)\n\n        return {\"alpha\": alpha_final, \"fg\": fg_final}\n"
  },
  {
    "path": "CorridorKeyModule/inference_engine.py",
    "content": "from __future__ import annotations\n\nimport logging\nimport math\nimport os\nimport sys\n\nimport cv2\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom .core import color_utils as cu\nfrom .core.model_transformer import GreenFormer\n\nlogger = logging.getLogger(__name__)\n\n\nclass CorridorKeyEngine:\n    def __init__(\n        self,\n        checkpoint_path: str,\n        device: str = \"cpu\",\n        img_size: int = 2048,\n        use_refiner: bool = True,\n        mixed_precision: bool = True,\n        model_precision: torch.dtype = torch.float32,\n    ) -> None:\n        self.device = torch.device(device)\n        self.img_size = img_size\n        self.checkpoint_path = checkpoint_path\n        self.use_refiner = use_refiner\n\n        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)\n        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)\n\n        if mixed_precision or model_precision != torch.float32:\n            # Use faster matrix multiplication implementation\n            # This reduces the floating point precision a little bit,\n            # but it should be negligible compared to fp16 precision\n            torch.set_float32_matmul_precision(\"high\")\n\n        self.mixed_precision = mixed_precision\n        if mixed_precision and model_precision == torch.float16:\n            # using mixed precision, when the precision is already fp16, is slower\n            self.mixed_precision = False\n\n        self.model_precision = model_precision\n\n        model = self._load_model().to(model_precision)\n\n        # We only tested compilation on windows and linux. For other platforms compilation is disabled as a precaution.\n        if sys.platform == \"linux\" or sys.platform == \"win32\":\n            # Try compiling the model. 
Fallback to eager mode if it fails.\n            try:\n                self.model = torch.compile(model)\n                # Trigger compilation with a dummy input\n                dummy_input = torch.zeros(1, 4, img_size, img_size, dtype=model_precision, device=self.device)\n                with torch.inference_mode():\n                    self.model(dummy_input)\n            except Exception as e:\n                logger.info(f\"Model compilation failed with error: {e}\")\n                logger.warning(\"Model compilation failed. Falling back to eager mode.\")\n                torch.cuda.empty_cache()\n                self.model = model\n\n    def _load_model(self) -> GreenFormer:\n        logger.info(\"Loading CorridorKey from %s\", self.checkpoint_path)\n        # Initialize Model (Hiera Backbone)\n        model = GreenFormer(\n            encoder_name=\"hiera_base_plus_224.mae_in1k_ft_in1k\", img_size=self.img_size, use_refiner=self.use_refiner\n        )\n        model = model.to(self.device)\n        model.eval()\n\n        # Load Weights\n        if not os.path.isfile(self.checkpoint_path):\n            raise FileNotFoundError(f\"Checkpoint not found: {self.checkpoint_path}\")\n\n        checkpoint = torch.load(self.checkpoint_path, map_location=self.device, weights_only=True)\n        state_dict = checkpoint.get(\"state_dict\", checkpoint)\n\n        # Fix Compiled Model Prefix & Handle PosEmbed Mismatch\n        new_state_dict = {}\n        model_state = model.state_dict()\n\n        for k, v in state_dict.items():\n            if k.startswith(\"_orig_mod.\"):\n                k = k[10:]\n\n            # Check for PosEmbed Mismatch\n            if \"pos_embed\" in k and k in model_state:\n                if v.shape != model_state[k].shape:\n                    print(f\"Resizing {k} from {v.shape} to {model_state[k].shape}\")\n                    # v: [1, N_src, C]\n                    # target: [1, N_dst, C]\n                    # We assume square grid\n  
                  N_src = v.shape[1]\n                    N_dst = model_state[k].shape[1]\n                    C = v.shape[2]\n\n                    grid_src = int(math.sqrt(N_src))\n                    grid_dst = int(math.sqrt(N_dst))\n\n                    # Reshape to [1, C, H, W]\n                    v_img = v.permute(0, 2, 1).view(1, C, grid_src, grid_src)\n\n                    # Interpolate\n                    v_resized = F.interpolate(v_img, size=(grid_dst, grid_dst), mode=\"bicubic\", align_corners=False)\n\n                    # Reshape back\n                    v = v_resized.flatten(2).transpose(1, 2)\n\n            new_state_dict[k] = v\n\n        missing, unexpected = model.load_state_dict(new_state_dict, strict=False)\n        if len(missing) > 0:\n            print(f\"[Warning] Missing keys: {missing}\")\n        if len(unexpected) > 0:\n            print(f\"[Warning] Unexpected keys: {unexpected}\")\n\n        return model\n\n    @torch.inference_mode()\n    def process_frame(\n        self,\n        image: np.ndarray,\n        mask_linear: np.ndarray,\n        refiner_scale: float = 1.0,\n        input_is_linear: bool = False,\n        fg_is_straight: bool = True,\n        despill_strength: float = 1.0,\n        auto_despeckle: bool = True,\n        despeckle_size: int = 400,\n    ) -> dict[str, np.ndarray]:\n        \"\"\"\n        Process a single frame.\n        Args:\n            image: Numpy array [H, W, 3] (0.0-1.0 or 0-255).\n                   - If input_is_linear=False (Default): Assumed sRGB.\n                   - If input_is_linear=True: Assumed Linear.\n            mask_linear: Numpy array [H, W] or [H, W, 1] (0.0-1.0). Assumed Linear.\n            refiner_scale: Multiplier for Refiner Deltas (default 1.0).\n            input_is_linear: bool. If True, resizes in Linear then transforms to sRGB.\n                             If False, resizes in sRGB (standard).\n            fg_is_straight: bool. 
If True, assumes FG output is Straight (unpremultiplied).\n                            If False, assumes FG output is Premultiplied.\n            despill_strength: float. 0.0 to 1.0 multiplier for the despill effect.\n            auto_despeckle: bool. If True, cleans up small disconnected components from the predicted alpha matte.\n            despeckle_size: int. Minimum number of connected pixels required to keep an island.\n        Returns:\n             dict: {'alpha': np, 'fg': np (sRGB), 'comp': np (sRGB on Checkerboard),\n                    'processed': np (Linear Premul RGBA)}\n        \"\"\"\n        # 1. Inputs Check & Normalization\n        if image.dtype == np.uint8:\n            image = image.astype(np.float32) / 255.0\n\n        if mask_linear.dtype == np.uint8:\n            mask_linear = mask_linear.astype(np.float32) / 255.0\n\n        h, w = image.shape[:2]\n\n        # Ensure Mask Shape\n        if mask_linear.ndim == 2:\n            mask_linear = mask_linear[:, :, np.newaxis]\n\n        # 2. Resize to Model Size\n        # If input is linear, we resize in linear to preserve energy/highlights,\n        # THEN convert to sRGB for the model.\n        if input_is_linear:\n            # Resize in Linear\n            img_resized_lin = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)\n            # Convert to sRGB for Model\n            img_resized = cu.linear_to_srgb(img_resized_lin)\n        else:\n            # Standard sRGB Resize\n            img_resized = cv2.resize(image, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)\n\n        mask_resized = cv2.resize(mask_linear, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)\n\n        if mask_resized.ndim == 2:\n            mask_resized = mask_resized[:, :, np.newaxis]\n\n        # 3. Normalize (ImageNet)\n        # Model expects sRGB input normalized\n        img_norm = (img_resized - self.mean) / self.std\n\n        # 4.
Prepare Tensor\n        inp_np = np.concatenate([img_norm, mask_resized], axis=-1)  # [H, W, 4]\n        inp_t = torch.from_numpy(inp_np.transpose((2, 0, 1))).unsqueeze(0).to(self.model_precision).to(self.device)\n\n        # 5. Inference\n        # Hook for Refiner Scaling\n        handle = None\n        if refiner_scale != 1.0 and self.model.refiner is not None:\n\n            def scale_hook(module, input, output):\n                return output * refiner_scale\n\n            handle = self.model.refiner.register_forward_hook(scale_hook)\n\n        with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=self.mixed_precision):\n            out = self.model(inp_t)\n\n        if handle:\n            handle.remove()\n\n        pred_alpha = out[\"alpha\"]\n        pred_fg = out[\"fg\"]  # Output is sRGB (Sigmoid)\n\n        # 6. Post-Process (Resize Back to Original Resolution)\n        # We use Lanczos4 for high-quality resampling to minimize blur when going back to 4K/Original.\n        res_alpha = pred_alpha[0].permute(1, 2, 0).float().cpu().numpy()\n        res_fg = pred_fg[0].permute(1, 2, 0).float().cpu().numpy()\n        res_alpha = cv2.resize(res_alpha, (w, h), interpolation=cv2.INTER_LANCZOS4)\n        res_fg = cv2.resize(res_fg, (w, h), interpolation=cv2.INTER_LANCZOS4)\n\n        if res_alpha.ndim == 2:\n            res_alpha = res_alpha[:, :, np.newaxis]\n\n        # --- ADVANCED COMPOSITING ---\n\n        # A. Clean Matte (Auto-Despeckle)\n        if auto_despeckle:\n            processed_alpha = cu.clean_matte(res_alpha, area_threshold=despeckle_size, dilation=25, blur_size=5)\n        else:\n            processed_alpha = res_alpha\n\n        # B. Despill FG\n        # res_fg is sRGB.\n        fg_despilled = cu.despill(res_fg, green_limit_mode=\"average\", strength=despill_strength)\n\n        # C. Premultiply (for EXR Output)\n        # CONVERT TO LINEAR FIRST! 
EXRs must house linear color premultiplied by linear alpha.\n        fg_despilled_lin = cu.srgb_to_linear(fg_despilled)\n        fg_premul_lin = cu.premultiply(fg_despilled_lin, processed_alpha)\n\n        # D. Pack RGBA\n        # [H, W, 4] - All channels are now strictly Linear Float\n        processed_rgba = np.concatenate([fg_premul_lin, processed_alpha], axis=-1)\n\n        # ----------------------------\n\n        # 7. Composite (on Checkerboard) for checking\n        # Generate Dark/Light Gray Checkerboard (in sRGB, convert to Linear)\n        bg_srgb = cu.create_checkerboard(w, h, checker_size=128, color1=0.15, color2=0.55)\n        bg_lin = cu.srgb_to_linear(bg_srgb)\n\n        if fg_is_straight:\n            comp_lin = cu.composite_straight(fg_despilled_lin, bg_lin, processed_alpha)\n        else:\n            # If premultiplied model, we shouldn't multiply again (though our pipeline forces straight)\n            comp_lin = cu.composite_premul(fg_despilled_lin, bg_lin, processed_alpha)\n\n        comp_srgb = cu.linear_to_srgb(comp_lin)\n\n        return {  # type: ignore[return-value]  # cu.* returns ndarray|Tensor but inputs are always ndarray here\n            \"alpha\": res_alpha,  # Linear, Raw Prediction\n            \"fg\": res_fg,  # sRGB, Raw Prediction (Straight)\n            \"comp\": comp_srgb,  # sRGB, Composite\n            \"processed\": processed_rgba,  # Linear/Premul, RGBA, Garbage Matted & Despilled\n        }\n"
  },
  {
    "path": "CorridorKey_DRAG_CLIPS_HERE_local.bat",
    "content": "@echo off\nREM Corridor Key Launcher - Local\n\nREM Set script path (assumes corridorkey_cli.py is in the same directory as this batch file)\nset \"SCRIPT_DIR=%~dp0\"\nset \"LOCAL_SCRIPT=%SCRIPT_DIR%corridorkey_cli.py\"\n\nREM SAFETY CHECK: Ensure a folder was dragged onto the script\nif \"%~1\"==\"\" (\n    echo [ERROR] No target folder provided.\n    echo.\n    echo USAGE: \n    echo Please DRAG AND DROP a folder onto this script to process it.\n    echo Do not double-click this script directly.\n    echo.\n    pause\n    exit /b\n)\n\nREM Folder dragged? Use it as the target path.\nset \"WIN_PATH=%~1\"\n\nREM Strip trailing slash if present\nif \"%WIN_PATH:~-1%\"==\"\\\" set \"WIN_PATH=%WIN_PATH:~0,-1%\"\n\necho Starting Corridor Key locally...\necho Target: \"%WIN_PATH%\"\n\nREM Run via uv entry point (handles the virtual environment automatically)\ncd /d \"%SCRIPT_DIR%\"\nuv run --extra cuda corridorkey wizard \"%WIN_PATH%\"\n\npause\n"
  },
  {
    "path": "CorridorKey_DRAG_CLIPS_HERE_local.sh",
    "content": "#!/usr/bin/env bash\n# Corridor Key Launcher - Local Linux/macOS\n\ncd \"$(dirname \"$0\")\"\n\n# Get the directory where this script is located\nSCRIPT_DIR=\"$( cd \"$( dirname \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\"\nLOCAL_SCRIPT=\"$SCRIPT_DIR/corridorkey_cli.py\"\n\n# SAFETY CHECK: Ensure a folder was provided as an argument\nif [ -z \"$1\" ]; then\n    echo \"[ERROR] No target folder provided.\"\n    echo \"\"\n    echo \"USAGE:\"\n    echo \"You can either run this script from the terminal and provide a path:\"\n    echo \"  ./CorridorKey_DRAG_CLIPS_HERE_local.sh /path/to/your/clip/folder\"\n    echo \"\"\n    echo \"Or, in many Linux/macOS desktop environments, you can simply\"\n    echo \"DRAG AND DROP a folder onto this script icon to process it.\"\n    echo \"\"\n    read -p \"Press enter to exit...\"\n    exit 1\nfi\n\n# Folder dragged or provided via CLI? Use it as the target path.\nTARGET_PATH=\"$1\"\n\n# Strip trailing slash if present\nTARGET_PATH=\"${TARGET_PATH%/}\"\n\n# Ensure uv is available before attempting to run\nif ! command -v uv &> /dev/null; then\n    echo \"[ERROR] 'uv' is not installed or not on PATH.\"\n    echo \"\"\n    echo \"Install uv by running:\"\n    echo \"  curl -LsSf https://astral.sh/uv/install.sh | sh\"\n    echo \"\"\n    echo \"Then reopen your terminal and try again.\"\n    read -p \"Press enter to exit...\"\n    exit 1\nfi\n\necho \"Starting Corridor Key locally...\"\necho \"Target: $TARGET_PATH\"\n\n# Run via uv entry point (handles the virtual environment automatically)\nuv run corridorkey wizard \"$TARGET_PATH\"\n\nread -p \"Press enter to close...\"\n"
  },
  {
    "path": "Dockerfile",
    "content": "FROM ghcr.io/astral-sh/uv:0.7-python3.11-bookworm-slim\n\n# Create non-root user upfront.\nRUN useradd --create-home --uid 1000 appuser\n\nRUN mkdir /app && chown appuser:appuser /app\n\nWORKDIR /app\n\n# Runtime dependencies for OpenCV/video I/O.\nRUN --mount=type=cache,target=/var/cache/apt,sharing=locked \\\n    apt-get update && apt-get install -y --no-install-recommends \\\n    ffmpeg \\\n    git \\\n    libgl1 \\\n    libglib2.0-0 && \\\n    apt-get clean && \\\n    rm -rf /var/lib/apt/lists/*\n\nUSER appuser\n\n# Install Python dependencies first for better layer caching.\nCOPY --chown=appuser:appuser pyproject.toml uv.lock ./\nRUN --mount=type=cache,target=/root/.cache/uv \\\n    uv sync --frozen --no-dev --no-install-project\n\n# Copy project source.\nCOPY --chown=appuser:appuser . .\n\n# Install the project itself (cheap, just sets up editable/entry points).\nRUN --mount=type=cache,target=/root/.cache/uv \\\n    uv sync --frozen --no-dev\n\n# Enable OpenEXR support in OpenCV.\nENV OPENCV_IO_ENABLE_OPENEXR=1\n\nENTRYPOINT [\"/app/.venv/bin/python\", \"corridorkey_cli.py\"]\nCMD [\"--action\", \"list\"]\n"
  },
  {
    "path": "IgnoredClips/.gitkeep",
    "content": ""
  },
  {
    "path": "Install_CorridorKey_Linux_Mac.sh",
    "content": "#!/usr/bin/env bash\n\ncd \"$(dirname \"$0\")\"\n\necho \"===================================================\"\necho \"    CorridorKey - MacOS/Linux Auto-Installer\"\necho \"===================================================\"\necho \"\"\n\n# Detect the Operating System\nOS=\"$(uname -s)\"\nif [ \"$OS\" != \"Darwin\" ] && [ \"$OS\" != \"Linux\" ]; then\n    echo \"[ERROR] Unsupported operating system: $OS\"\n    read -p \"Press [Enter] to exit...\"\n    exit 1\nfi\n\n# 1. Check for uv — install it automatically if missing\nif ! command -v uv >/dev/null 2>&1; then\n    echo \"[INFO] uv is not installed. Installing now...\"\n    curl -LsSf https://astral.sh/uv/install.sh | sh\n    \n    if [ $? -ne 0 ]; then\n        echo \"[ERROR] Failed to install uv. Please visit https://docs.astral.sh/uv/ for manual instructions.\"\n        read -p \"Press [Enter] to exit...\"\n        exit 1\n    fi\n\n    # uv installer adds to PATH, but the current terminal session\n    # doesn't see it yet. Add the default install location so we can continue.\n    export PATH=\"$HOME/.local/bin:$PATH\"\n\n    if ! command -v uv >/dev/null 2>&1; then\n        echo \"[ERROR] uv was installed but cannot be found on PATH.\"\n        echo \"Please close this window, open a new terminal, and run this script again.\"\n        read -p \"Press [Enter] to exit...\"\n        exit 1\n    fi\n    echo \"[INFO] uv installed successfully.\"\n    echo \"\"\nfi\n\n# 2. Install all dependencies\necho \"[1/2] Installing Dependencies (This might take a while on first run)...\"\necho \"      uv will automatically download Python if needed.\"\n\nif [ \"$OS\" = \"Darwin\" ]; then\n    echo \"[INFO] macOS detected. Installing with MLX support...\"\n    uv sync --extra mlx\nelif [ \"$OS\" = \"Linux\" ]; then\n    echo \"[INFO] Linux detected. Installing with CUDA support...\"\n    uv sync --extra cuda\nfi\n\nif [ $? -ne 0 ]; then\n    echo \"[ERROR] uv sync failed. 
Please check the output above for details.\"\n    read -p \"Press [Enter] to exit...\"\n    exit 1\nfi\n\n# 3. Download Weights\necho \"\"\necho \"[2/2] Downloading CorridorKey Model Weights...\"\n\n# Use -p to create the folder only if it doesn't exist\nmkdir -p \"CorridorKeyModule/checkpoints\"\n\nif [ ! -f \"CorridorKeyModule/checkpoints/CorridorKey.pth\" ]; then\n    echo \"Downloading CorridorKey.pth...\"\n    curl -L -o \"CorridorKeyModule/checkpoints/CorridorKey.pth\" \"https://huggingface.co/nikopueringer/CorridorKey_v1.0/resolve/main/CorridorKey_v1.0.pth\"\nelse\n    echo \"CorridorKey.pth already exists!\"\nfi\n\necho \"\"\necho \"===================================================\"\necho \"  Setup Complete! You are ready to key!\"\necho \"===================================================\"\nread -p \"Press [Enter] to close...\""
  },
  {
    "path": "Install_CorridorKey_Windows.bat",
    "content": "@echo off\nTITLE CorridorKey Setup Wizard\necho ===================================================\necho     CorridorKey - Windows Auto-Installer\necho ===================================================\necho.\n\n:: 1. Check for uv — install it automatically if missing\nwhere uv >nul 2>&1\nif %errorlevel% equ 0 goto :uv_ready\n\necho [INFO] uv is not installed. Installing now...\npowershell -ExecutionPolicy ByPass -c \"irm https://astral.sh/uv/install.ps1 | iex\"\nif %errorlevel% neq 0 (\n    echo [ERROR] Failed to install uv. Please visit https://docs.astral.sh/uv/ for manual instructions.\n    pause\n    exit /b\n)\n\n:: uv installer adds to PATH via registry, but the current cmd session\n:: doesn't see it yet. Add the default install location so we can continue.\nset \"PATH=%USERPROFILE%\\.local\\bin;%PATH%\"\n\nwhere uv >nul 2>&1\nif %errorlevel% neq 0 (\n    echo [ERROR] uv was installed but cannot be found on PATH.\n    echo Please close this window, open a new terminal, and run this script again.\n    pause\n    exit /b\n)\necho [INFO] uv installed successfully.\necho.\n\n:uv_ready\n\n:: 2. Install all dependencies (Python, venv, and packages are handled automatically by uv)\necho [1/2] Installing Dependencies (This might take a while on first run)...\necho       uv will automatically download Python if needed.\nuv sync --extra cuda\nif %errorlevel% neq 0 (\n    echo [ERROR] uv sync failed. Please check the output above for details.\n    pause\n    exit /b\n)\n\n:: 3. 
Download Weights\necho.\necho [2/2] Downloading CorridorKey Model Weights...\nif not exist \"CorridorKeyModule\\checkpoints\" mkdir \"CorridorKeyModule\\checkpoints\"\n\nif not exist \"CorridorKeyModule\\checkpoints\\CorridorKey.pth\" (\n    echo Downloading CorridorKey.pth...\n    curl.exe -L -o \"CorridorKeyModule\\checkpoints\\CorridorKey.pth\" \"https://huggingface.co/nikopueringer/CorridorKey_v1.0/resolve/main/CorridorKey_v1.0.pth\"\n) else (\n    echo CorridorKey.pth already exists!\n)\n\necho.\necho ===================================================\necho   Setup Complete! You are ready to key!\necho   Drag and drop folders onto CorridorKey_DRAG_CLIPS_HERE_local.bat\necho ===================================================\npause\n"
  },
  {
    "path": "Install_GVM_Linux_Mac.sh",
    "content": "#!/usr/bin/env bash\n\ncd \"$(dirname \"$0\")\"\n\n# Set the Terminal window title\necho -n -e \"\\033]0;GVM Setup Wizard\\007\"\necho \"===================================================\"\necho \"    GVM (AlphaHint Generator) - Auto-Installer\"\necho \"===================================================\"\necho \"\"\n\n# Check that uv sync has been run (the .venv directory should exist)\n# The error message names the Linux/macOS CorridorKey installer, which runs uv sync.\nif [ ! -d \".venv\" ]; then\n    echo \"[ERROR] Project environment not found.\"\n    echo \"Please run Install_CorridorKey_Linux_Mac.sh first!\"\n    read -p \"Press [Enter] to exit...\"\n    exit 1\nfi\n\n# 1. Download Weights\necho \"[1/1] Downloading GVM Model Weights (WARNING: Massive 80GB+ Download)...\"\nmkdir -p \"gvm_core/weights\"\n\necho \"Downloading GVM weights from HuggingFace...\"\nuv run hf download geyongtao/gvm --local-dir \"gvm_core/weights\"\n\necho \"\"\necho \"===================================================\"\necho \"  GVM Setup Complete!\"\necho \"===================================================\"\nread -p \"Press [Enter] to close...\""
  },
  {
    "path": "Install_GVM_Windows.bat",
    "content": "@echo off\nTITLE GVM Setup Wizard\necho ===================================================\necho     GVM (AlphaHint Generator) - Auto-Installer\necho ===================================================\necho.\n\n:: Check that uv sync has been run (the .venv directory should exist)\nif not exist \".venv\" (\n    echo [ERROR] Project environment not found.\n    echo Please run Install_CorridorKey_Windows.bat first!\n    pause\n    exit /b\n)\n\n:: 1. Download Weights (all Python deps are already installed by uv sync)\necho [1/1] Downloading GVM Model Weights (WARNING: Massive 80GB+ Download)...\nif not exist \"gvm_core\\weights\" mkdir \"gvm_core\\weights\"\n\necho Downloading GVM weights from HuggingFace...\nuv run hf download geyongtao/gvm --local-dir gvm_core\\weights\n\necho.\necho ===================================================\necho   GVM Setup Complete!\necho ===================================================\npause\n"
  },
  {
    "path": "Install_VideoMaMa_Linux_Mac.sh",
    "content": "#!/usr/bin/env bash\n\ncd \"$(dirname \"$0\")\"\n\n# Set the Terminal window title\necho -n -e \"\\033]0;VideoMaMa Setup Wizard\\007\"\necho \"===================================================\"\necho \"   VideoMaMa (AlphaHint Generator) - Auto-Installer\"\necho \"===================================================\"\necho \"\"\n\n# Check that uv sync has been run (the .venv directory should exist)\n# The error message names the Linux/macOS CorridorKey installer, which runs uv sync.\nif [ ! -d \".venv\" ]; then\n    echo \"[ERROR] Project environment not found.\"\n    echo \"Please run Install_CorridorKey_Linux_Mac.sh first!\"\n    read -p \"Press [Enter] to exit...\"\n    exit 1\nfi\n\n# 1. Download Weights\necho \"[1/1] Downloading VideoMaMa Model Weights...\"\nmkdir -p \"VideoMaMaInferenceModule/checkpoints\"\n\necho \"Downloading VideoMaMa weights from HuggingFace...\"\nuv run hf download SammyLim/VideoMaMa --local-dir \"VideoMaMaInferenceModule/checkpoints\"\n\necho \"\"\necho \"===================================================\"\necho \"  VideoMaMa Setup Complete!\"\necho \"===================================================\"\nread -p \"Press [Enter] to close...\""
  },
  {
    "path": "Install_VideoMaMa_Windows.bat",
    "content": "@echo off\nTITLE VideoMaMa Setup Wizard\necho ===================================================\necho   VideoMaMa (AlphaHint Generator) - Auto-Installer\necho ===================================================\necho.\n\n:: Check that uv sync has been run (the .venv directory should exist)\nif not exist \".venv\" (\n    echo [ERROR] Project environment not found.\n    echo Please run Install_CorridorKey_Windows.bat first!\n    pause\n    exit /b\n)\n\n:: 1. Download Weights (all Python deps are already installed by uv sync)\necho [1/1] Downloading VideoMaMa Model Weights...\nif not exist \"VideoMaMaInferenceModule\\checkpoints\" mkdir \"VideoMaMaInferenceModule\\checkpoints\"\n\necho Downloading VideoMaMa weights from HuggingFace...\nuv run hf download SammyLim/VideoMaMa --local-dir VideoMaMaInferenceModule\\checkpoints\n\necho.\necho ===================================================\necho   VideoMaMa Setup Complete!\necho ===================================================\npause\n"
  },
  {
    "path": "LICENSE",
    "content": "CORRIDOR KEY LICENCE\n=======================================================================\n\nVersion 1.0\n\nCopyright (c) Corridor Digital. All rights reserved.\n\n\nADDITIONAL TERMS AND CONDITIONS\n=======================================================================\n\nThis work is licensed under the Creative Commons\nAttribution-NonCommercial-ShareAlike 4.0 International Public License\n(CC BY-NC-SA 4.0), the full text of which is included below, subject\nto the following additional terms and conditions. These additional\nterms supplement the CC BY-NC-SA 4.0 licence and take precedence\nwhere they conflict.\n\nBy exercising any rights to the Licensed Material, You accept and\nagree to be bound by both these Additional Terms and the\nCC BY-NC-SA 4.0 Public License.\n\n\n1. PERMITTED USE\n\n   You may use this tool for any lawful purpose, including for\n   processing images as part of a commercial project, provided that\n   such use complies with the restrictions set out below and the\n   terms of the CC BY-NC-SA 4.0 licence.\n\n\n2. RESTRICTIONS\n\n   In addition to the restrictions set out in the CC BY-NC-SA 4.0\n   licence, the following restrictions apply:\n\n   a. NO REPACKAGING OR RESALE\n      You may not repackage, redistribute, sublicense, or sell\n      this tool, in whole or in part, as a standalone product or\n      as part of a competing product.\n\n   b. NO PAID API OR INFERENCE SERVICES\n      You may not use this tool to provide inference as a paid\n      API service, whether directly or indirectly. This includes,\n      but is not limited to, offering access to this tool behind\n      a paywall, subscription model, or usage-based billing\n      system.\n\n   c. COMMERCIAL SOFTWARE INTEGRATION\n      If you operate a commercial software package or inference\n      service and wish to incorporate this tool into your\n      software, you must obtain a separate written agreement\n      from the Licensor. 
Please contact:\n\n          contact@corridordigital.com\n\n\n3. ATTRIBUTION\n\n   a. In addition to the attribution requirements set out in\n      Section 3(a) of the CC BY-NC-SA 4.0 licence, You must\n      include the name \"CorridorKey\" in any attribution notice.\n\n   b. Any fork, derivative work, variation, improvement, or\n      subsequent release of this tool must retain the\n      \"CorridorKey\" name in a reasonably prominent position.\n\n   c. You may interchange \"Corridor Key\" for \"CorridorKey\" as\n      needed.\n\n\n4. SHARE-ALIKE\n\n   Any variations, improvements, derivative works, or modified\n   versions of this tool that are publicly released must be\n   distributed under this same licence, including these Additional\n   Terms, or a licence with substantially equivalent terms.\n\n5. DISCLAIMER\n\n   This licence is provided on an \"as-is\" basis. The Licensor makes\n   no representations or warranties regarding the enforceability or\n   legal sufficiency of these terms. You are encouraged to seek\n   independent legal advice if you have questions about the scope\n   or applicability of this licence.\n\n6. CONTRIBUTIONS\n\n   By submitting any contribution (including source code,\n   documentation, or images) to this repository, You agree that:\n\n   a. Your contribution is provided under the terms of this\n      Corridor Key Licence.\n\n   b. You grant the Licensor (Corridor Digital) a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free,\n      irrevocable license to use, reproduce, prepare derivative\n      works of, publicly display, and sublicense your\n      contribution, including for commercial purposes.\n\n   c. 
You represent that you are the legal owner of the\n      contribution or have the authority to submit it under\n      these terms.\n\n=======================================================================\n\nhttps://creativecommons.org/licenses/by-nc-sa/4.0/\nhttps://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.txt\n\n=======================================================================\n\nAttribution-NonCommercial-ShareAlike 4.0 International\n\n=======================================================================\n\nCreative Commons Corporation (\"Creative Commons\") is not a law firm and\ndoes not provide legal services or legal advice. Distribution of\nCreative Commons public licenses does not create a lawyer-client or\nother relationship. Creative Commons makes its licenses and related\ninformation available on an \"as-is\" basis. Creative Commons gives no\nwarranties regarding its licenses, any material licensed under their\nterms and conditions, or any related information. Creative Commons\ndisclaims all liability for damages resulting from their use to the\nfullest extent possible.\n\nUsing Creative Commons Public Licenses\n\nCreative Commons public licenses provide a standard set of terms and\nconditions that creators and other rights holders may use to share\noriginal works of authorship and other material subject to copyright\nand certain other rights specified in the public license below. The\nfollowing considerations are for informational purposes only, are not\nexhaustive, and do not form part of our licenses.\n\n     Considerations for licensors: Our public licenses are\n     intended for use by those authorized to give the public\n     permission to use material in ways otherwise restricted by\n     copyright and certain other rights. Our licenses are\n     irrevocable. 
Licensors should read and understand the terms\n     and conditions of the license they choose before applying it.\n     Licensors should also secure all rights necessary before\n     applying our licenses so that the public can reuse the\n     material as expected. Licensors should clearly mark any\n     material not subject to the license. This includes other CC-\n     licensed material, or material used under an exception or\n     limitation to copyright. More considerations for licensors:\n    wiki.creativecommons.org/Considerations_for_licensors\n\n     Considerations for the public: By using one of our public\n     licenses, a licensor grants the public permission to use the\n     licensed material under specified terms and conditions. If\n     the licensor's permission is not necessary for any reason--for\n     example, because of any applicable exception or limitation to\n     copyright--then that use is not regulated by the license. Our\n     licenses grant only permissions under copyright and certain\n     other rights that a licensor has authority to grant. Use of\n     the licensed material may still be restricted for other\n     reasons, including because others have copyright or other\n     rights in the material. A licensor may make special requests,\n     such as asking that all changes be marked or described.\n     Although not required by our licenses, you are encouraged to\n     respect those requests where reasonable. More considerations\n     for the public:\n    wiki.creativecommons.org/Considerations_for_licensees\n\n=======================================================================\n\nCreative Commons Attribution-NonCommercial-ShareAlike 4.0 International\nPublic License\n\nBy exercising the Licensed Rights (defined below), You accept and agree\nto be bound by the terms and conditions of this Creative Commons\nAttribution-NonCommercial-ShareAlike 4.0 International Public License\n(\"Public License\"). 
To the extent this Public License may be\ninterpreted as a contract, You are granted the Licensed Rights in\nconsideration of Your acceptance of these terms and conditions, and the\nLicensor grants You such rights in consideration of benefits the\nLicensor receives from making the Licensed Material available under\nthese terms and conditions.\n\n\nSection 1 -- Definitions.\n\n  a. Adapted Material means material subject to Copyright and Similar\n     Rights that is derived from or based upon the Licensed Material\n     and in which the Licensed Material is translated, altered,\n     arranged, transformed, or otherwise modified in a manner requiring\n     permission under the Copyright and Similar Rights held by the\n     Licensor. For purposes of this Public License, where the Licensed\n     Material is a musical work, performance, or sound recording,\n     Adapted Material is always produced where the Licensed Material is\n     synched in timed relation with a moving image.\n\n  b. Adapter's License means the license You apply to Your Copyright\n     and Similar Rights in Your contributions to Adapted Material in\n     accordance with the terms and conditions of this Public License.\n\n  c. BY-NC-SA Compatible License means a license listed at\n     creativecommons.org/compatiblelicenses, approved by Creative\n     Commons as essentially the equivalent of this Public License.\n\n  d. Copyright and Similar Rights means copyright and/or similar rights\n     closely related to copyright including, without limitation,\n     performance, broadcast, sound recording, and Sui Generis Database\n     Rights, without regard to how the rights are labeled or\n     categorized. For purposes of this Public License, the rights\n     specified in Section 2(b)(1)-(2) are not Copyright and Similar\n     Rights.\n\n  e. 
Effective Technological Measures means those measures that, in the\n     absence of proper authority, may not be circumvented under laws\n     fulfilling obligations under Article 11 of the WIPO Copyright\n     Treaty adopted on December 20, 1996, and/or similar international\n     agreements.\n\n  f. Exceptions and Limitations means fair use, fair dealing, and/or\n     any other exception or limitation to Copyright and Similar Rights\n     that applies to Your use of the Licensed Material.\n\n  g. License Elements means the license attributes listed in the name\n     of a Creative Commons Public License. The License Elements of this\n     Public License are Attribution, NonCommercial, and ShareAlike.\n\n  h. Licensed Material means the artistic or literary work, database,\n     or other material to which the Licensor applied this Public\n     License.\n\n  i. Licensed Rights means the rights granted to You subject to the\n     terms and conditions of this Public License, which are limited to\n     all Copyright and Similar Rights that apply to Your use of the\n     Licensed Material and that the Licensor has authority to license.\n\n  j. Licensor means the individual(s) or entity(ies) granting rights\n     under this Public License.\n\n  k. NonCommercial means not primarily intended for or directed towards\n     commercial advantage or monetary compensation. For purposes of\n     this Public License, the exchange of the Licensed Material for\n     other material subject to Copyright and Similar Rights by digital\n     file-sharing or similar means is NonCommercial provided there is\n     no payment of monetary compensation in connection with the\n     exchange.\n\n  l. 
Share means to provide material to the public by any means or\n     process that requires permission under the Licensed Rights, such\n     as reproduction, public display, public performance, distribution,\n     dissemination, communication, or importation, and to make material\n     available to the public including in ways that members of the\n     public may access the material from a place and at a time\n     individually chosen by them.\n\n  m. Sui Generis Database Rights means rights other than copyright\n     resulting from Directive 96/9/EC of the European Parliament and of\n     the Council of 11 March 1996 on the legal protection of databases,\n     as amended and/or succeeded, as well as other essentially\n     equivalent rights anywhere in the world.\n\n  n. You means the individual or entity exercising the Licensed Rights\n     under this Public License. Your has a corresponding meaning.\n\n\nSection 2 -- Scope.\n\n  a. License grant.\n\n       1. Subject to the terms and conditions of this Public License,\n          the Licensor hereby grants You a worldwide, royalty-free,\n          non-sublicensable, non-exclusive, irrevocable license to\n          exercise the Licensed Rights in the Licensed Material to:\n\n            a. reproduce and Share the Licensed Material, in whole or\n               in part, for NonCommercial purposes only; and\n\n            b. produce, reproduce, and Share Adapted Material for\n               NonCommercial purposes only.\n\n       2. Exceptions and Limitations. For the avoidance of doubt, where\n          Exceptions and Limitations apply to Your use, this Public\n          License does not apply, and You do not need to comply with\n          its terms and conditions.\n\n       3. Term. The term of this Public License is specified in Section\n          6(a).\n\n       4. Media and formats; technical modifications allowed. 
The\n          Licensor authorizes You to exercise the Licensed Rights in\n          all media and formats whether now known or hereafter created,\n          and to make technical modifications necessary to do so. The\n          Licensor waives and/or agrees not to assert any right or\n          authority to forbid You from making technical modifications\n          necessary to exercise the Licensed Rights, including\n          technical modifications necessary to circumvent Effective\n          Technological Measures. For purposes of this Public License,\n          simply making modifications authorized by this Section 2(a)\n          (4) never produces Adapted Material.\n\n       5. Downstream recipients.\n\n            a. Offer from the Licensor -- Licensed Material. Every\n               recipient of the Licensed Material automatically\n               receives an offer from the Licensor to exercise the\n               Licensed Rights under the terms and conditions of this\n               Public License.\n\n            b. Additional offer from the Licensor -- Adapted Material.\n               Every recipient of Adapted Material from You\n               automatically receives an offer from the Licensor to\n               exercise the Licensed Rights in the Adapted Material\n               under the conditions of the Adapter's License You apply.\n\n            c. No downstream restrictions. You may not offer or impose\n               any additional or different terms or conditions on, or\n               apply any Effective Technological Measures to, the\n               Licensed Material if doing so restricts exercise of the\n               Licensed Rights by any recipient of the Licensed\n               Material.\n\n       6. No endorsement. 
Nothing in this Public License constitutes or\n          may be construed as permission to assert or imply that You\n          are, or that Your use of the Licensed Material is, connected\n          with, or sponsored, endorsed, or granted official status by,\n          the Licensor or others designated to receive attribution as\n          provided in Section 3(a)(1)(A)(i).\n\n  b. Other rights.\n\n       1. Moral rights, such as the right of integrity, are not\n          licensed under this Public License, nor are publicity,\n          privacy, and/or other similar personality rights; however, to\n          the extent possible, the Licensor waives and/or agrees not to\n          assert any such rights held by the Licensor to the limited\n          extent necessary to allow You to exercise the Licensed\n          Rights, but not otherwise.\n\n       2. Patent and trademark rights are not licensed under this\n          Public License.\n\n       3. To the extent possible, the Licensor waives any right to\n          collect royalties from You for the exercise of the Licensed\n          Rights, whether directly or through a collecting society\n          under any voluntary or waivable statutory or compulsory\n          licensing scheme. In all other cases the Licensor expressly\n          reserves any right to collect such royalties, including when\n          the Licensed Material is used other than for NonCommercial\n          purposes.\n\n\nSection 3 -- License Conditions.\n\nYour exercise of the Licensed Rights is expressly made subject to the\nfollowing conditions.\n\n  a. Attribution.\n\n       1. If You Share the Licensed Material (including in modified\n          form), You must:\n\n            a. retain the following if it is supplied by the Licensor\n               with the Licensed Material:\n\n                 i. 
identification of the creator(s) of the Licensed\n                    Material and any others designated to receive\n                    attribution, in any reasonable manner requested by\n                    the Licensor (including by pseudonym if\n                    designated);\n\n                ii. a copyright notice;\n\n               iii. a notice that refers to this Public License;\n\n                iv. a notice that refers to the disclaimer of\n                    warranties;\n\n                 v. a URI or hyperlink to the Licensed Material to the\n                    extent reasonably practicable;\n\n            b. indicate if You modified the Licensed Material and\n               retain an indication of any previous modifications; and\n\n            c. indicate the Licensed Material is licensed under this\n               Public License, and include the text of, or the URI or\n               hyperlink to, this Public License.\n\n       2. You may satisfy the conditions in Section 3(a)(1) in any\n          reasonable manner based on the medium, means, and context in\n          which You Share the Licensed Material. For example, it may be\n          reasonable to satisfy the conditions by providing a URI or\n          hyperlink to a resource that includes the required\n          information.\n       3. If requested by the Licensor, You must remove any of the\n          information required by Section 3(a)(1)(A) to the extent\n          reasonably practicable.\n\n  b. ShareAlike.\n\n     In addition to the conditions in Section 3(a), if You Share\n     Adapted Material You produce, the following conditions also apply.\n\n       1. The Adapter's License You apply must be a Creative Commons\n          license with the same License Elements, this version or\n          later, or a BY-NC-SA Compatible License.\n\n       2. You must include the text of, or the URI or hyperlink to, the\n          Adapter's License You apply. 
You may satisfy this condition\n          in any reasonable manner based on the medium, means, and\n          context in which You Share Adapted Material.\n\n       3. You may not offer or impose any additional or different terms\n          or conditions on, or apply any Effective Technological\n          Measures to, Adapted Material that restrict exercise of the\n          rights granted under the Adapter's License You apply.\n\n\nSection 4 -- Sui Generis Database Rights.\n\nWhere the Licensed Rights include Sui Generis Database Rights that\napply to Your use of the Licensed Material:\n\n  a. for the avoidance of doubt, Section 2(a)(1) grants You the right\n     to extract, reuse, reproduce, and Share all or a substantial\n     portion of the contents of the database for NonCommercial purposes\n     only;\n\n  b. if You include all or a substantial portion of the database\n     contents in a database in which You have Sui Generis Database\n     Rights, then the database in which You have Sui Generis Database\n     Rights (but not its individual contents) is Adapted Material,\n     including for purposes of Section 3(b); and\n\n  c. You must comply with the conditions in Section 3(a) if You Share\n     all or a substantial portion of the contents of the database.\n\nFor the avoidance of doubt, this Section 4 supplements and does not\nreplace Your obligations under this Public License where the Licensed\nRights include other Copyright and Similar Rights.\n\n\nSection 5 -- Disclaimer of Warranties and Limitation of Liability.\n\n  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE\n     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS\n     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF\n     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,\n     IMPLIED, STATUTORY, OR OTHER. 
THIS INCLUDES, WITHOUT LIMITATION,\n     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR\n     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,\n     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT\n     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT\n     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.\n\n  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE\n     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,\n     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,\n     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,\n     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR\n     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN\n     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR\n     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR\n     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.\n\n  c. The disclaimer of warranties and limitation of liability provided\n     above shall be interpreted in a manner that, to the extent\n     possible, most closely approximates an absolute disclaimer and\n     waiver of all liability.\n\n\nSection 6 -- Term and Termination.\n\n  a. This Public License applies for the term of the Copyright and\n     Similar Rights licensed here. However, if You fail to comply with\n     this Public License, then Your rights under this Public License\n     terminate automatically.\n\n  b. Where Your right to use the Licensed Material has terminated under\n     Section 6(a), it reinstates:\n\n       1. automatically as of the date the violation is cured, provided\n          it is cured within 30 days of Your discovery of the\n          violation; or\n\n       2. 
upon express reinstatement by the Licensor.\n\n     For the avoidance of doubt, this Section 6(b) does not affect any\n     right the Licensor may have to seek remedies for Your violations\n     of this Public License.\n\n  c. For the avoidance of doubt, the Licensor may also offer the\n     Licensed Material under separate terms or conditions or stop\n     distributing the Licensed Material at any time; however, doing so\n     will not terminate this Public License.\n\n  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public\n     License.\n\n\nSection 7 -- Other Terms and Conditions.\n\n  a. The Licensor shall not be bound by any additional or different\n     terms or conditions communicated by You unless expressly agreed.\n\n  b. Any arrangements, understandings, or agreements regarding the\n     Licensed Material not stated herein are separate from and\n     independent of the terms and conditions of this Public License.\n\n\nSection 8 -- Interpretation.\n\n  a. For the avoidance of doubt, this Public License does not, and\n     shall not be interpreted to, reduce, limit, restrict, or impose\n     conditions on any use of the Licensed Material that could lawfully\n     be made without permission under this Public License.\n\n  b. To the extent possible, if any provision of this Public License is\n     deemed unenforceable, it shall be automatically reformed to the\n     minimum extent necessary to make it enforceable. If the provision\n     cannot be reformed, it shall be severed from this Public License\n     without affecting the enforceability of the remaining terms and\n     conditions.\n\n  c. No term or condition of this Public License will be waived and no\n     failure to comply consented to unless expressly agreed to by the\n     Licensor.\n\n  d. 
Nothing in this Public License constitutes or may be interpreted\n     as a limitation upon, or waiver of, any privileges and immunities\n     that apply to the Licensor or You, including from the legal\n     processes of any jurisdiction or authority.\n\n=======================================================================\n\nCreative Commons is not a party to its public\nlicenses. Notwithstanding, Creative Commons may elect to apply one of\nits public licenses to material it publishes and in those instances\nwill be considered the “Licensor.” The text of the Creative Commons\npublic licenses is dedicated to the public domain under the CC0 Public\nDomain Dedication. Except for the limited purpose of indicating that\nmaterial is shared under a Creative Commons public license or as\notherwise permitted by the Creative Commons policies published at\ncreativecommons.org/policies, Creative Commons does not authorize the\nuse of the trademark \"Creative Commons\" or any other trademark or logo\nof Creative Commons without its prior written consent including,\nwithout limitation, in connection with any unauthorized modifications\nto any of its public licenses or any other arrangements,\nunderstandings, or agreements concerning use of licensed material. For\nthe avoidance of doubt, this paragraph does not form part of the\npublic licenses.\n\nCreative Commons may be contacted at creativecommons.org.\n"
  },
  {
    "path": "Output/.gitkeep",
    "content": ""
  },
  {
    "path": "README.md",
    "content": "# CorridorKey\n\n\nhttps://github.com/user-attachments/assets/1fb27ea8-bc91-4ebc-818f-5a3b5585af08\n\n\nWhen you film something against a green screen, the edges of your subject inevitably blend with the green background. This creates pixels that are a mix of your subject's color and the green screen's color. Traditional keyers struggle to untangle these colors, forcing you to spend hours building complex edge mattes or manually rotoscoping. Even modern \"AI Roto\" solutions typically output a harsh binary mask, completely destroying the delicate, semi-transparent pixels needed for a realistic composite.\n\nI built CorridorKey to solve this *unmixing* problem. \n\nYou input a raw green screen frame, and the neural network completely separates the foreground object from the green screen. For every single pixel, even the highly transparent ones like motion blur or out-of-focus edges, the model predicts the true, un-multiplied straight color of the foreground element, alongside a clean, linear alpha channel. It doesn't just guess what is opaque and what is transparent; it actively reconstructs the color of the foreground object as if the green screen was never there.\n\nNo more fighting with garbage mattes or agonizing over \"core\" vs \"edge\" keys. Give CorridorKey a hint of what you want, and it separates the light for you.\n\n## Alert!\n\nThis is a brand new release, I'm sure you will discover many ways it can be improved! I invite everyone to help. Join us on the \"Corridor Creates\" Discord to share ideas, work, forks, etc! https://discord.gg/zvwUrdWXJm\n\nIf you want an easy-install, artist-friendly user interface version of CorridorKey, check out [EZ-CorridorKey](https://github.com/edenaion/EZ-CorridorKey)\n\nThis project uses [uv](https://docs.astral.sh/uv/) to manage dependencies — it handles Python installation, virtual environments, and packages all in one step, so you don't need to worry about any of that. 
Just run the appropriate install script for your OS.\n\nNaturally, I have not tested everything. If you encounter errors, please consider patching the code as needed and submitting a pull request.\n\n## Features\n\n*   **Physically Accurate Unmixing:** Clean extraction of straight color foreground and linear alpha channels, preserving hair, motion blur, and translucency.\n*   **Resolution Independent:** The engine dynamically scales inference to handle 4K plates while predicting using its native 2048x2048 high-fidelity backbone.\n*   **VFX Standard Outputs:** Natively reads and writes 16-bit and 32-bit Linear float EXR files, preserving true color math for integration in Nuke, Fusion, or Resolve.\n*   **Auto-Cleanup:** Includes a morphological cleanup system to automatically prune any tracking markers or tiny background features that slip through CorridorKey's detection.\n\n## Hardware Requirements\n\nThis project was designed and built on a Linux workstation (Puget Systems PC) equipped with an NVIDIA RTX Pro 6000 with 96GB of VRAM. The community is ACTIVELY optimizing it for consumer GPUs.\n\nThe most recent build should work on computers with 6-8 GB of VRAM, and it can run on most M1+ Mac systems with unified memory. Yes, it might even work on your old MacBook Pro. Let us know on the Discord!\n\n*   **Windows Users:** To run GPU acceleration natively on Windows, your system MUST have NVIDIA drivers that support **CUDA 12.8 or higher** installed. If your drivers only support older CUDA versions, the installer will likely fall back to the CPU.\n*   **GVM (Optional):** Requires approximately **80 GB of VRAM** and utilizes massive Stable Video Diffusion models.\n*   **VideoMaMa (Optional):** Natively requires a massive chunk of VRAM as well (originally 80GB+). 
While the community has tweaked the architecture to run at less than 24GB, those extreme memory optimizations have not yet been fully implemented in this repository.\n*   **BiRefNet (Optional):** Lightweight AlphaHint generator option.\n\nBecause GVM and VideoMaMa have huge model file sizes and extreme hardware requirements, installing their modules is completely optional. You can always provide your own Alpha Hints generated from your editing program, BiRefNet, or any other method. The better the AlphaHint, the better the result.\n\n## Getting Started\n\n### 1. Installation\n\nThis project uses **[uv](https://docs.astral.sh/uv/)** to manage Python and all dependencies. uv is a fast, modern replacement for pip that automatically handles Python versions, virtual environments, and package installation in a single step. You do **not** need to install Python yourself — uv does it for you.\n\n**For Windows Users (Automated):**\n1.  Clone or download this repository to your local machine.\n2.  Double-click `Install_CorridorKey_Windows.bat`. This will automatically install uv (if needed), set up your Python environment, install all dependencies, and download the CorridorKey model.\n    > **Note:** If this is the first time installing uv, any terminal windows you already had open won't see it. The installer script handles the current window automatically, but if you open a new terminal and get \"'uv' is not recognized\", just close and reopen that terminal.\n3.  (Optional) Double-click `Install_GVM_Windows.bat` and `Install_VideoMaMa_Windows.bat` to download the heavy optional Alpha Hint generator weights.\n\n**For Linux / Mac Users (Automated):**\n1.  Clone or download this repository to your local machine.\n2.  Open a terminal and type `bash` followed by a space.\n3.  Drag and drop `Install_CorridorKey_Linux_Mac.sh` into the terminal, then press Enter.\n4.  (Optional) Repeat step 2. 
But now drag and drop `Install_GVM_Linux_Mac.sh` and `Install_VideoMaMa_Linux_Mac.sh` to download the heavy optional Alpha Hint generator weights.\n\n**For Linux / Mac Users (Manual):**\n1.  Clone or download this repository to your local machine.\n2.  Install uv if you don't have it:\n    ```bash\n    curl -LsSf https://astral.sh/uv/install.sh | sh\n    ```\n3.  Install all dependencies (uv will download Python 3.10+ automatically if needed):\n    ```bash\n    uv sync                  # CPU/MPS (default — works everywhere)\n    uv sync --extra cuda     # CUDA GPU acceleration (Linux/Windows)\n    uv sync --extra mlx      # Apple Silicon MLX acceleration\n    ```\n4.  **Download the Models:**\n    *   **CorridorKey v1.0 Model (~300MB):** Downloads automatically on first run. If no `.pth` file is found in `CorridorKeyModule/checkpoints/`, the engine fetches it from [CorridorKey's HuggingFace](https://huggingface.co/nikopueringer/CorridorKey_v1.0) and saves it as `CorridorKey.pth`. No manual download needed.\n    *   **GVM Weights (Optional):** [HuggingFace: geyongtao/gvm](https://huggingface.co/geyongtao/gvm)\n        *   Download using the CLI: `uv run hf download geyongtao/gvm --local-dir gvm_core/weights`\n    *   **VideoMaMa Weights (Optional):** [HuggingFace: SammyLim/VideoMaMa](https://huggingface.co/SammyLim/VideoMaMa)\n        *   Download the VideoMaMa fine-tuned weights:\n            ```\n            uv run hf download SammyLim/VideoMaMa --local-dir VideoMaMaInferenceModule/checkpoints/VideoMaMa\n            ```\n        *   VideoMaMa also requires the Stable Video Diffusion base model (VAE + image encoder only, ~2.5GB). 
Accept the license at [stabilityai/stable-video-diffusion-img2vid-xt](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt), then:\n            ```\n            uv run hf download stabilityai/stable-video-diffusion-img2vid-xt \\\n              --local-dir VideoMaMaInferenceModule/checkpoints/stable-video-diffusion-img2vid-xt \\\n              --include \"feature_extractor/*\" \"image_encoder/*\" \"vae/*\" \"model_index.json\"\n            ```\n        *   VideoMaMa is an amazing project, please go star their [repo](https://github.com/cvlab-kaist/VideoMaMa) and show them some support! \n### 2. How it Works\n\nCorridorKey requires two inputs to process a frame:\n1.  **The Original RGB Image:** The to-be-processed green screen footage. This requires the sRGB color gamut (interchangeable with REC709 gamut), and the engine can ingest either an sRGB gamma or Linear gamma curve. \n2.  **A Coarse Alpha Hint:** A rough black-and-white mask that generally isolates the subject. This does *not* need to be precise. It can be generated by you with a rough chroma key or AI roto.\n\nI've had the best results using GVM or VideoMaMa to create the AlphaHint, so I've repackaged those projects and integrated them here as optional modules inside `clip_manager.py`. Here is how they compare:\n\n*   **GVM:** Completely automatic and requires no additional input. It works exceptionally well for people, but can struggle with inanimate objects.\n*   **VideoMaMa:** Requires you to provide a rough VideoMamaMaskHint (often drawn by hand or AI) telling it what you want to key. If you choose to use this, place your mask hint in the `VideoMamaMaskHint/` folder that the wizard creates for your shot. VideoMaMa results are spectacular and can be controlled more easily than GVM due to this mask hint.\n*   **Please** go show the creators of these projects some love and star their repos. 
[VideoMaMa](https://github.com/cvlab-kaist/VideoMaMa) and [GVM](https://github.com/aim-uofa/GVM)\n\nPerhaps in the future, I will implement other generators for the AlphaHint! In the meantime, the better your Alpha Hint, the better CorridorKey's final result will be. Experiment with different amounts of mask erosion or feathering. The model was trained on coarse, blurry, eroded masks, and is exceptional at filling in details from the hint. However, it is generally less effective at subtracting unwanted mask details if your Alpha Hint is expanded too far. \n\nPlease give feedback and share your results!\n\n### Docker (Linux + NVIDIA GPU)\n\nIf you prefer not to install dependencies locally, you can run CorridorKey in Docker.\n\nPrerequisites:\n- Docker Engine + Docker Compose plugin installed.\n- NVIDIA driver installed on the host (Linux), with CUDA compatibility for the PyTorch CUDA 12.6 wheels used by this project.\n- NVIDIA Container Toolkit installed and configured for Docker (`nvidia-smi` should work on host, and `docker run --rm --gpus all nvidia/cuda:12.6.3-runtime-ubuntu22.04 nvidia-smi` should succeed).\n\n1. Build the image:\n   ```bash\n   docker build -t corridorkey:latest .\n   ```\n2. Run an action directly (example: inference):\n   ```bash\n   docker run --rm -it --gpus all \\\n     -e OPENCV_IO_ENABLE_OPENEXR=1 \\\n     -v \"$(pwd)/ClipsForInference:/app/ClipsForInference\" \\\n     -v \"$(pwd)/Output:/app/Output\" \\\n     -v \"$(pwd)/CorridorKeyModule/checkpoints:/app/CorridorKeyModule/checkpoints\" \\\n     -v \"$(pwd)/gvm_core/weights:/app/gvm_core/weights\" \\\n     -v \"$(pwd)/VideoMaMaInferenceModule/checkpoints:/app/VideoMaMaInferenceModule/checkpoints\" \\\n     corridorkey:latest run_inference --device cuda\n   ```\n3. 
Docker Compose (recommended for repeat runs):\n   ```bash\n   docker compose build\n   docker compose --profile gpu run --rm corridorkey run_inference --device cuda\n   docker compose --profile gpu run --rm corridorkey list\n   docker compose --profile cpu run --rm corridorkey-cpu run_inference --device cpu\n   ```\n4. Optional: pin to specific GPU(s) for multi-GPU workstations:\n   ```bash\n   NVIDIA_VISIBLE_DEVICES=0 docker compose --profile gpu run --rm corridorkey list\n   NVIDIA_VISIBLE_DEVICES=1,2 docker compose --profile gpu run --rm corridorkey run_inference --device cuda\n   ```\n\nNotes:\n- You still need to place model weights in the same folders used by native runs (mounted above).\n- The container does not include kernel GPU drivers; those always come from the host. The image provides user-space dependencies and relies on Docker's NVIDIA runtime to pass through driver libraries/devices.\n- The wizard works too, but use a path inside the container, for example:\n  ```bash\n  docker run --rm -it --gpus all \\\n    -e OPENCV_IO_ENABLE_OPENEXR=1 \\\n    -v \"$(pwd)/ClipsForInference:/app/ClipsForInference\" \\\n    -v \"$(pwd)/Output:/app/Output\" \\\n    -v \"$(pwd)/CorridorKeyModule/checkpoints:/app/CorridorKeyModule/checkpoints\" \\\n    -v \"$(pwd)/gvm_core/weights:/app/gvm_core/weights\" \\\n    -v \"$(pwd)/VideoMaMaInferenceModule/checkpoints:/app/VideoMaMaInferenceModule/checkpoints\" \\\n    corridorkey:latest wizard --win_path /app/ClipsForInference\n  docker compose --profile gpu run --rm corridorkey wizard --win_path /app/ClipsForInference\n  ```\n\n### 3. Usage: The Command Line Wizard\n\nFor the easiest experience, use the provided launcher scripts. These scripts launch a prompt-based configuration wizard in your terminal.\n\n*   **Windows:** Drag-and-drop a video file or folder onto `CorridorKey_DRAG_CLIPS_HERE_local.bat` (Note: Only launch via Drag-and-Drop or CMD. 
Double-clicking the `.bat` directly will throw an error).\n*   **Linux / Mac:** Run or drag-and-drop a video file or folder onto `./CorridorKey_DRAG_CLIPS_HERE_local.sh`.\n    *   Alternatively, type `bash` followed by a space in a terminal, then drag-and-drop `CorridorKey_DRAG_CLIPS_HERE_local.sh` and your clip folder into the terminal, in that order, and press Enter.\n\n**Workflow Steps:**\n1.  **Launch:** You can drag-and-drop a single loose video file (like an `.mp4`), a shot folder containing image sequences, or even a master \"batch\" folder containing multiple different shots all at once onto the launcher script.\n2.  **Organization:** The wizard will detect what you dragged in. If you dropped loose video files or unorganized folders, the first prompt will ask if you want it to organize your clips into the proper structure. \n    *   If you say Yes, the script will automatically create a shot folder, move your footage into an `Input/` sub-folder, and generate empty `AlphaHint/` and `VideoMamaMaskHint/` folders for you. This structure is required for the engine to pair your hints and footage correctly!\n3.  **Generate Hints (Optional):** If the wizard detects your shots are missing an `AlphaHint`, it will ask if you want to generate them automatically using the repackaged GVM or VideoMaMa modules.\n4.  **Configure:** Once your clips have both Inputs and AlphaHints, select \"Process Ready Clips\". The wizard will prompt you to configure the run:\n    *   **Gamma Space:** Tell the engine if your sequence uses a Linear or sRGB gamma curve.\n    *   **Despill Strength:** This is a traditional despill filter (0-10), if you wish to have it baked into the output now as opposed to applying it in your comp later.\n    *   **Auto-Despeckle:** Toggle automatic cleanup and define the size threshold. 
This isn't just for tracking dots, it removes any small, disconnected islands of pixels.\n    *   **Refiner Strength:** Use the default (1.0) unless you are experimenting with extreme detail pushing.\n5.  **Result:** The engine will generate several folders inside your shot directory:\n    *   `/Matte`: The raw Linear Alpha channel (EXR).\n    *   `/FG`: The raw Straight Foreground Color Object. (Note: The engine natively computes this in the sRGB gamut. You must manually convert this pass to linear gamma before being combined with the alpha in your compositing program).\n    *   `/Processed`: An RGBA image containing the Linear Foreground premultiplied against the Linear Alpha (EXR). This pass exists so you can immediately drop the footage into Premiere/Resolve for a quick preview without dealing with complex premultiplication routing. However, if you want more control over your image, working with the raw FG and Matte outputs will give you that.\n    *   `/Comp`: A simple preview of the key composited over a checkerboard (PNG).\n\n## But What About Training and Datasets?\n\nIf enough people find this project interesting I'll get the training program and datasets uploaded so we can all really go to town making the absolute best keyer fine tunes! Just hit me with some messages on the Corridor Creates discord or here. If enough people lock in, I'll get this stuff packaged up. 
Hardware requirements are beefy and the gigabytes are plentiful so I don't want to commit the time unless there's demand.\n\n## Device Selection\n\nBy default, CorridorKey auto-detects the best available compute device: **CUDA > MPS > CPU**.\n\n**Override via CLI flag:**\n```bash\nuv run python clip_manager.py --action wizard --win_path \"V:\\...\" --device mps\nuv run python clip_manager.py --action run_inference --device cpu\n```\n\n**Override via environment variable:**\n```bash\nexport CORRIDORKEY_DEVICE=cpu\nuv run python clip_manager.py --action wizard --win_path \"V:\\...\"\n```\n\nPriority: `--device` flag > `CORRIDORKEY_DEVICE` env var > auto-detect.\n\n### Apple Silicon / MPS Troubleshooting\n\n**Confirm MPS is active:** Run with verbose logging to see which device was selected:\n```bash\nuv run python clip_manager.py --action list 2>&1 | grep -i \"device\\|backend\\|mps\"\n```\n\n**MPS operator errors** (`NotImplementedError: ... not implemented for 'MPS'`): Some PyTorch operations are not yet supported on MPS. Enable CPU fallback for those ops:\n```bash\nexport PYTORCH_ENABLE_MPS_FALLBACK=1\nuv run python corridorkey_cli.py wizard --win_path \"/path/to/clips\"\n```\n\n**Silent CPU fallback**: If MPS silently falls back to CPU without this variable, the run will be much slower. Setting `PYTORCH_ENABLE_MPS_FALLBACK=1` in your shell profile (`~/.zshrc`) ensures it is always active.\n\n**Use native MLX instead of PyTorch MPS:** MLX avoids PyTorch's MPS layer entirely and typically runs faster on Apple Silicon. 
See the [Backend Selection](#backend-selection) section below for setup steps.\n\n## Backend Selection\n\nCorridorKey supports two inference backends:\n- **Torch** (default on Linux/Windows) — CUDA, MPS, or CPU\n- **MLX** (Apple Silicon) — native Metal acceleration, no Torch overhead\n\nResolution: `--backend` flag > `CORRIDORKEY_BACKEND` env var > auto-detect.\nAuto mode prefers MLX on Apple Silicon when available.\n\n**Override via CLI flag (corridorkey_cli.py):**\n```bash\nuv run python corridorkey_cli.py wizard --win_path \"/path/to/clips\" --backend mlx\nuv run python corridorkey_cli.py run_inference --backend torch\n```\n\n### MLX Setup (Apple Silicon)\n\n1. Install the MLX backend:\n   ```bash\n   uv sync --extra mlx\n   ```\n2. Obtain the MLX weights (`.safetensors`) — pick **one** option:\n\n   **Option A — Download pre-converted weights (simplest):**\n   ```bash\n   # Download weights from GitHub Releases into a local cache directory\n   uv run python -m corridorkey_mlx weights download\n\n   # Print the cached path, then copy to the checkpoints folder\n   WEIGHTS=$(uv run python -m corridorkey_mlx weights download --print-path)\n   cp \"$WEIGHTS\" CorridorKeyModule/checkpoints/corridorkey_mlx.safetensors\n   ```\n\n   **Option B — Convert from an existing `.pth` checkpoint:**\n   ```bash\n   # Clone the MLX repo (contains the conversion script)\n   git clone https://github.com/nikopueringer/corridorkey-mlx.git\n   cd corridorkey-mlx\n   uv sync\n\n   # Convert (point --checkpoint at your CorridorKey.pth)\n   uv run python scripts/convert_weights.py \\\n       --checkpoint ../CorridorKeyModule/checkpoints/CorridorKey_v1.0.pth \\\n       --output ../CorridorKeyModule/checkpoints/corridorkey_mlx.safetensors\n   cd ..\n   ```\n\n   Either way the final file must be at:\n   ```\n   CorridorKeyModule/checkpoints/corridorkey_mlx.safetensors\n   ```\n3. 
Run with auto-detection or explicit backend:\n   ```bash\n   CORRIDORKEY_BACKEND=mlx uv run python clip_manager.py --action run_inference\n   ```\n\nMLX uses img_size=2048 by default (same as Torch).\n\n### Troubleshooting\n- **\"No .safetensors checkpoint found\"** — place MLX weights in `CorridorKeyModule/checkpoints/`\n- **\"corridorkey_mlx not installed\"** — run `uv sync --extra mlx`\n- **\"MLX requires Apple Silicon\"** — MLX only works on M1+ Macs\n- **Auto picked Torch unexpectedly** — set `CORRIDORKEY_BACKEND=mlx` explicitly\n\n## Advanced Usage\n\nFor developers looking for more details on the specifics of what is happening in the CorridorKey engine, check out the README in the `/CorridorKeyModule` folder. We also have a dedicated handover document outlining the pipeline architecture for AI assistants in `/docs/LLM_HANDOVER.md`.\n\nYou can also explore the full, auto-generated codebase documentation on [DeepWiki](https://deepwiki.com/nikopueringer/CorridorKey).\n\n### Running Tests\n\nThe project includes unit tests for the color math and compositing pipeline. No GPU or model weights required — tests run in a few seconds on any machine.\n\n```bash\nuv sync --group dev   # install test dependencies (pytest)\nuv run pytest          # run all tests\nuv run pytest -v       # verbose output (shows each test name)\n```\n\n## CorridorKey Licensing and Permissions\n\nUse this tool for whatever you'd like, including for processing images as part of a commercial project! You MAY NOT repackage this tool and sell it, and any variations or improvements of this tool that are released must remain under the same license, and must include the name Corridor Key.\n\nYou MAY NOT offer inference with this model as a paid API service. If you run a commercial software package or inference service and wish to incorporate this tool into your software, shoot us an email to work out an agreement! I promise we're easy to work with. contact@corridordigital.com. 
Outside of the stipulations listed above, this license is effectively a variation of [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/)\n\nPlease keep the Corridor Key name in any future forks or releases!\n\n## Community Extensions\n* [CorridorKeyOpenVINO](https://github.com/daniil-lyakhov/CorridorKeyOpenVINO) - Run the CorridorKey model quickly on Intel hardware with the OpenVINO inference framework.\n\n## Acknowledgements and Licensing\n\nCorridorKey integrates several open-source modules for Alpha Hint generation. We would like to explicitly credit and thank the following research teams:\n\n*   **Generative Video Matting (GVM):** Developed by the Advanced Intelligent Machines (AIM) research team at Zhejiang University. The GVM code and models are heavily utilized in the `gvm_core` module. Their work is licensed under the [2-clause BSD License (BSD-2-Clause)](https://opensource.org/license/bsd-2-clause). You can find their source repository here: [aim-uofa/GVM](https://github.com/aim-uofa/GVM). Give them a star!\n*   **VideoMaMa:** Developed by the CVLAB at KAIST. The VideoMaMa architecture is utilized within the `VideoMaMaInferenceModule`. Their code is released under the [Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0)](https://creativecommons.org/licenses/by-nc/4.0/), and their specific foundation model checkpoints (`dino_projection_mlp.pth`, `unet/*`) are subject to the [Stability AI Community License](https://stability.ai/license). You can find their source repository here: [cvlab-kaist/VideoMaMa](https://github.com/cvlab-kaist/VideoMaMa). Give them a star!\n\nBy using these optional modules, you agree to abide by their respective Non-Commercial licenses. Please review their repositories for full terms.\n"
  },
  {
    "path": "RunGVMOnly.sh",
    "content": "#!/usr/bin/env bash\n\n# Ensure script stops on error\nset -e\n\n# Path to script directory\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\n\n# Enable OpenEXR Support\nexport OPENCV_IO_ENABLE_OPENEXR=1\n\necho \"Starting Coarse Alpha Generation...\"\necho \"Scanning ClipsForInference...\"\n\n# Run via uv entry point (handles the virtual environment automatically)\nuv run corridorkey generate-alphas\n\necho \"Done.\"\n"
  },
  {
    "path": "RunInferenceOnly.sh",
    "content": "#!/usr/bin/env bash\n\n# Ensure script stops on error\nset -e\n\n# Path to script directory\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\n\n# Enable OpenEXR Support\nexport OPENCV_IO_ENABLE_OPENEXR=1\n\necho \"Starting CorridorKey Inference...\"\necho \"Scanning ClipsForInference for Ready Clips (Input + Alpha)...\"\n\n# Run via uv entry point (handles the virtual environment automatically)\nuv run corridorkey run-inference\n\necho \"Inference Complete.\"\n"
  },
  {
    "path": "VideoMaMaInferenceModule/LICENSE.md",
    "content": "# VideoMaMa Licensing and Acknowledgements\n\nThis module (`VideoMaMaInferenceModule`) contains repackaged code and integrations from the **Video Masked Modeling (VideoMaMa)** project developed by the CVLAB at KAIST.\n\n## Original Repository\n*   **GitHub:** [https://github.com/cvlab-kaist/VideoMaMa](https://github.com/cvlab-kaist/VideoMaMa)\n*   **HuggingFace:** [https://huggingface.co/SammyLim/VideoMaMa](https://huggingface.co/SammyLim/VideoMaMa)\n\n## License\nThe VideoMaMa codebase is released under the **Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0)**. \nTo view a copy of this license, visit: [http://creativecommons.org/licenses/by-nc/4.0/](http://creativecommons.org/licenses/by-nc/4.0/)\n\nAdditionally, the model checkpoints are subject to the **Stability AI Community License**.\nTo view a copy of this license, visit: [https://stability.ai/license](https://stability.ai/license)\n\nBy utilizing this module and downloading the associated weights, you are subject to these Non-Commercial and Community licenses.\n"
  },
  {
    "path": "VideoMaMaInferenceModule/README.md",
    "content": "# VideoMaMa Inference Module\n\nThis module provides a standalone interface for running VideoMaMa inference.\n\n## Usage\n\n```python\nimport sys\n# Ensure the parent directory of this module is in sys.path\nsys.path.append(\"/path/to/parent/directory\")\n\nfrom VideoMaMa_Inference_Module import load_videomama_model, run_inference, extract_frames_from_video, save_video\n\n# 1. Load Model\n# By default, it loads checkpoints from the local 'checkpoints/' directory inside the module.\n# Ensure you have copied 'stable-video-diffusion-img2vid-xt' and 'VideoMaMa' into 'checkpoints/'.\npipeline = load_videomama_model(device=\"cuda\")\n\n# Alternatively, specify custom paths:\n# pipeline = load_videomama_model(base_model_path=\"/path/to/base\", unet_checkpoint_path=\"/path/to/unet\", device=\"cuda\")\n\n# 2. Prepare Inputs\n# You need a list of RGB frames and a list of mask frames (grayscale)\n# Helper function to extract from video:\nvideo_path = \"input_video.mp4\"\ninput_frames, fps = extract_frames_from_video(video_path, max_frames=24)\n\n# Load your masks (e.g. from file or other process)\n# masks = [ ... list of numpy arrays ... ]\n# Ensure len(masks) == len(input_frames)\n\n# 3. Run Inference\noutput_frames = run_inference(pipeline, input_frames, masks)\n\n# 4. Save Output\nsave_video(output_frames, \"output.mp4\", fps)\n```\n\n## Requirements\n\nInstall dependencies listed in `requirements.txt`.\n```bash\npip install -r requirements.txt\n```\n"
  },
  {
    "path": "VideoMaMaInferenceModule/__init__.py",
    "content": "from .inference import load_videomama_model, run_inference, extract_frames_from_video, save_video\nfrom .pipeline import VideoInferencePipeline\n\n__all__ = [\n    \"load_videomama_model\",\n    \"run_inference\",\n    \"extract_frames_from_video\",\n    \"save_video\",\n    \"VideoInferencePipeline\"\n]\n"
  },
  {
    "path": "VideoMaMaInferenceModule/checkpoints/.gitkeep",
    "content": ""
  },
  {
    "path": "VideoMaMaInferenceModule/inference.py",
    "content": "\"\"\"\nVideoMaMa Inference Module\nProvides functions to load the model and run inference on video inputs.\n\"\"\"\n\nimport os\nimport sys\nimport torch\nimport cv2\nimport numpy as np\nfrom PIL import Image\nfrom typing import List, Union, Optional\nfrom pathlib import Path\n\n# Add current directory to path so that pipeline.py's intra-package imports\n# (e.g. \"from pipeline import ...\") resolve when this module is imported from\n# outside the VideoMaMaInferenceModule directory.  This is a workaround for the\n# module's original structure — a cleaner fix would convert to proper relative\n# imports throughout.\ncurrent_dir = os.path.dirname(os.path.abspath(__file__))\nif current_dir not in sys.path:\n    sys.path.append(current_dir)\n\nfrom .pipeline import VideoInferencePipeline\n\ndef load_videomama_model(base_model_path: Optional[str] = None, unet_checkpoint_path: Optional[str] = None, device: str = \"cpu\") -> VideoInferencePipeline:\n    \"\"\"\n    Load VideoMaMa pipeline with pretrained weights.\n\n    Args:\n        base_model_path (str, optional): Path to the base Stable Video Diffusion model. 
\n                                         Defaults to 'checkpoints/stable-video-diffusion-img2vid-xt' in module dir.\n        unet_checkpoint_path (str, optional): Path to the fine-tuned UNet checkpoint.\n                                              Defaults to 'checkpoints/VideoMaMa' in module dir.\n        device (str): Device to run on (\"cuda\" or \"cpu\").\n\n    Returns:\n        VideoInferencePipeline: Loaded pipeline instance.\n    \"\"\"\n    # Default to local checkpoints if not provided\n    if base_model_path is None:\n        base_model_path = os.path.join(current_dir, \"checkpoints\", \"stable-video-diffusion-img2vid-xt\")\n    \n    if unet_checkpoint_path is None:\n        unet_checkpoint_path = os.path.join(current_dir, \"checkpoints\", \"VideoMaMa\")\n\n    print(f\"Loading Base model from {base_model_path}...\")\n    print(f\"Loading VideoMaMa UNet from {unet_checkpoint_path}...\")\n    \n    # Check if paths exist\n    if not os.path.exists(base_model_path):\n        raise FileNotFoundError(f\"Base model path not found: {base_model_path}\")\n    if not os.path.exists(unet_checkpoint_path):\n        raise FileNotFoundError(f\"UNet checkpoint path not found: {unet_checkpoint_path}\")\n\n    pipeline = VideoInferencePipeline(\n        base_model_path=base_model_path,\n        unet_checkpoint_path=unet_checkpoint_path,\n        weight_dtype=torch.float16, # Use float16 for inference by default\n        device=device\n    )\n    \n    print(\"VideoMaMa pipeline loaded successfully!\")\n    return pipeline\n\ndef extract_frames_from_video(video_path: str, max_frames: Optional[int] = None) -> tuple[List[np.ndarray], float]:\n    \"\"\"\n    Extract frames from video file.\n\n    Args:\n        video_path (str): Path to video file.\n        max_frames (int, optional): Maximum number of frames to extract.\n\n    Returns:\n        tuple: (List of numpy arrays (H,W,3) uint8 RGB, FPS)\n    \"\"\"\n    if not os.path.exists(video_path):\n        raise 
FileNotFoundError(f\"Video file not found: {video_path}\")\n\n    cap = cv2.VideoCapture(video_path)\n    original_fps = cap.get(cv2.CAP_PROP_FPS)\n    \n    all_frames = []\n    while cap.isOpened():\n        ret, frame = cap.read()\n        if not ret:\n            break\n        # Convert BGR to RGB\n        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n        all_frames.append(frame_rgb)\n    \n    cap.release()\n    \n    if max_frames and len(all_frames) > max_frames:\n        frames = all_frames[:max_frames]\n    else:\n        frames = all_frames\n    \n    return frames, original_fps\n\ndef run_inference(\n    pipeline: VideoInferencePipeline,\n    input_frames: List[np.ndarray],\n    mask_frames: List[np.ndarray],\n    chunk_size: int = 24  # Adjusted default chunk size\n) -> List[np.ndarray]:\n    \"\"\"\n    Run VideoMaMa inference on video frames with mask conditioning.\n\n    Args:\n        pipeline (VideoInferencePipeline): Loaded pipeline instance.\n        input_frames (List[np.ndarray]): List of RGB frames (H,W,3) uint8.\n        mask_frames (List[np.ndarray]): List of mask frames (H,W) uint8 (0-255) grayscale.\n        chunk_size (int): Number of frames to process at once to avoid OOM.\n\n    Returns:\n        List[np.ndarray]: List of output RGB frames (H,W,3) uint8.\n    \"\"\"\n    if len(input_frames) != len(mask_frames):\n        # Resize mask frames list to match input if needed (e.g. 
repeat or slice)\n        # For strict correctness, we'll raise an error or warn.\n        # But let's assume the user provides matching lengths or we might need to handle it.\n        # Here we just raise for clarity.\n        raise ValueError(f\"Input frames ({len(input_frames)}) and mask frames ({len(mask_frames)}) must have same length.\")\n\n    # Convert numpy arrays to PIL Images\n    frames_pil = [Image.fromarray(f) for f in input_frames]\n    \n    # Handle mask frames - ensure they are PIL \"L\" mode\n    mask_frames_pil = []\n    for m in mask_frames:\n        if m.ndim == 3:\n            # If RGB/BGR mask, convert to grayscale\n            m = cv2.cvtColor(m, cv2.COLOR_RGB2GRAY)\n        mask_frames_pil.append(Image.fromarray(m, mode='L'))\n    \n    # Resize to model input size (1024x576 is standard for SVD)\n    target_width, target_height = 1024, 576\n    frames_resized = [f.resize((target_width, target_height), Image.Resampling.BILINEAR) \n                     for f in frames_pil]\n    masks_resized = [m.resize((target_width, target_height), Image.Resampling.BILINEAR) \n                    for m in mask_frames_pil]\n    \n    print(f\"Processing {len(frames_resized)} frames in chunks of {chunk_size}...\")\n    \n    # Store original size for resizing back\n    if not frames_pil:\n        return []\n        \n    original_size = frames_pil[0].size\n    \n    for i in range(0, len(frames_resized), chunk_size):\n        chunk_frames = frames_resized[i:i + chunk_size]\n        chunk_masks = masks_resized[i:i + chunk_size]\n        \n        print(f\"  Running inference on chunk {i//chunk_size + 1}/{len(frames_resized)//chunk_size + 1} ({len(chunk_frames)} frames)...\")\n        \n        # Clear cache before each chunk\n        if pipeline.device.type == \"cuda\":\n            torch.cuda.empty_cache()\n        \n        chunk_output = pipeline.run(\n            cond_frames=chunk_frames,\n            mask_frames=chunk_masks,\n            seed=42, # Fixed 
seed for reproducibility\n            mask_cond_mode=\"vae\"\n        )\n        \n        # Resize back to original resolution immediately\n        chunk_output_resized = [f.resize(original_size, Image.Resampling.BILINEAR) \n                                for f in chunk_output]\n        \n        # Convert back to numpy arrays\n        chunk_output_np = [np.array(f) for f in chunk_output_resized]\n        \n        yield chunk_output_np\n\ndef save_video(frames: List[np.ndarray], output_path: str, fps: float):\n    \"\"\"\n    Save frames as a video file.\n\n    Args:\n        frames (List[np.ndarray]): List of frames (RGB).\n        output_path (str): Output video path.\n        fps (float): Frames per second.\n    \"\"\"\n    if not frames:\n        return\n    \n    height, width = frames[0].shape[:2]\n    fourcc = cv2.VideoWriter_fourcc(*'mp4v')\n    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))\n    \n    for frame in frames:\n        # Convert RGB to BGR for OpenCV\n        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\n        out.write(frame_bgr)\n    \n    out.release()\n    print(f\"Saved video to {output_path}\")\n\n"
  },
  {
    "path": "VideoMaMaInferenceModule/pipeline.py",
    "content": "# pipeline_svd_masked.py\n\nimport inspect\nfrom dataclasses import dataclass\nfrom typing import Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\n\nfrom diffusers.image_processor import PipelineImageInput\nfrom diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel\nfrom diffusers.schedulers import EulerDiscreteScheduler\nfrom diffusers.utils import BaseOutput, logging, replace_example_docstring\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.video_processor import VideoProcessor\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\n\n# Import necessary helpers from the original SVD pipeline\nfrom diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (\n    _append_dims,\n    retrieve_timesteps,\n    _resize_with_antialiasing,\n)\nimport torch.nn.functional as F\nfrom einops import rearrange\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> from pipeline_svd_masked import StableVideoDiffusionPipelineWithMask\n        >>> from diffusers.utils import load_image, export_to_video\n\n        >>> # Load your fine-tuned UNet, VAE, etc.\n        >>> pipe = StableVideoDiffusionPipelineWithMask.from_pretrained(\n        ...     \"path/to/your/finetuned_model\", torch_dtype=torch.float16, variant=\"fp16\"\n        ... )\n        >>> pipe.to(\"cuda\")\n\n        >>> # Load the conditioning image and the mask\n        >>> image = load_image(\"path/to/your/conditioning_image.png\").resize((1024, 576))\n        >>> mask = load_image(\"path/to/your/mask_image.png\").resize((1024, 576))\n\n        >>> # Generate frames\n        >>> frames = pipe(\n        ...     image=image,\n        ...     mask_image=mask,\n        ...     num_frames=25,\n        ...     
decode_chunk_size=8\n        ... ).frames[0]\n\n        >>> export_to_video(frames, \"generated_video.mp4\", fps=7)\n        ```\n\"\"\"\n\n\n@dataclass\nclass StableVideoDiffusionPipelineOutput(BaseOutput):\n    r\"\"\"\n    Output class for the custom Stable Video Diffusion pipeline.\n    Args:\n        frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.Tensor`]):\n            List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape\n            `(batch_size, num_frames, height, width, num_channels)`.\n    \"\"\"\n    frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]\n\n\nclass StableVideoDiffusionPipelineWithMask(DiffusionPipeline):\n    r\"\"\"\n    A custom pipeline based on Stable Video Diffusion that accepts an additional mask for conditioning.\n    This pipeline is designed to work with a UNet fine-tuned to accept 12 input channels\n    (4 for noise, 4 for VAE-encoded condition image, 4 for VAE-encoded mask).\n    \"\"\"\n\n    model_cpu_offload_seq = \"image_encoder->unet->vae\"\n    _callback_tensor_inputs = [\"latents\"]\n\n    def __init__(\n            self,\n            vae: AutoencoderKLTemporalDecoder,\n            image_encoder: CLIPVisionModelWithProjection,\n            unet: UNetSpatioTemporalConditionModel,\n            scheduler: EulerDiscreteScheduler,\n            feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            image_encoder=image_encoder,\n            unet=unet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)\n\n    def _encode_image(\n            self,\n            image: PipelineImageInput,\n            device: Union[str, torch.device],\n            
num_videos_per_prompt: int,\n    ) -> torch.Tensor:\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.video_processor.pil_to_numpy(image)\n            image = self.video_processor.numpy_to_pt(image)\n\n        image = image * 2.0 - 1.0\n        image = _resize_with_antialiasing(image, (224, 224))\n        image = (image + 1.0) / 2.0\n\n        image = self.feature_extractor(\n            images=image,\n            do_normalize=True,\n            do_center_crop=False,\n            do_resize=False,\n            do_rescale=False,\n            return_tensors=\"pt\",\n        ).pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        image_embeddings = self.image_encoder(image).image_embeds\n        image_embeddings = image_embeddings.unsqueeze(1)\n\n        bs_embed, seq_len, _ = image_embeddings.shape\n        image_embeddings = torch.zeros_like(image_embeddings)\n\n        return image_embeddings\n\n    def _encode_vae_image(\n            self,\n            image: torch.Tensor,\n            device: Union[str, torch.device],\n            num_videos_per_prompt: int,\n    ):\n        image = image.to(device=device, dtype=torch.float16)\n        image_latents = self.vae.encode(image).latent_dist.sample()\n        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)\n        return image_latents\n\n    def _get_add_time_ids(\n            self,\n            fps: int,\n            motion_bucket_id: int,\n            noise_aug_strength: float,\n            dtype: torch.dtype,\n            batch_size: int,\n            num_videos_per_prompt: int,\n    ):\n        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]\n        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n        if expected_add_embed_dim != passed_add_embed_dim:\n    
        raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created.\"\n            )\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)\n        return add_time_ids\n\n    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):\n        latents = latents.flatten(0, 1).to(dtype=torch.float16)\n        latents = 1 / self.vae.config.scaling_factor * latents\n        frames = []\n        for i in range(0, latents.shape[0], decode_chunk_size):\n            num_frames_in = latents[i: i + decode_chunk_size].shape[0]\n            frame = self.vae.decode(latents[i: i + decode_chunk_size], num_frames=num_frames_in).sample\n            frames.append(frame)\n        frames = torch.cat(frames, dim=0)\n        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)\n        frames = frames.float()\n        return frames\n\n    def check_inputs(self, image, height, width):\n        if (\n                not isinstance(image, torch.Tensor)\n                and not isinstance(image, PIL.Image.Image)\n                and not isinstance(image, list)\n        ):\n            raise ValueError(f\"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}\")\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n    def prepare_latents(\n            self,\n            batch_size: int,\n            num_frames: int,\n            height: int,\n            width: int,\n            dtype: torch.dtype,\n            device: Union[str, torch.device],\n            generator: torch.Generator,\n            latents: Optional[torch.Tensor] = None,\n            initial_latents: Optional[torch.Tensor] = 
None,\n            denoising_strength: float = 1.0,\n            timestep: Optional[torch.Tensor] = None,\n    ):\n        num_channels_latents = self.unet.config.out_channels\n        shape = (\n            batch_size,\n            num_frames,\n            num_channels_latents,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n\n        if initial_latents is not None:\n            # Noise is added to the initial latents\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            # Get the initial latents at the given timestep\n            latents = self.scheduler.add_noise(initial_latents, noise, timestep)\n        else:\n            # Standard pure noise generation\n            if latents is None:\n                latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n            else:\n                latents = latents.to(device)\n            # Scale the initial noise by the standard deviation required by the scheduler\n            latents = latents * self.scheduler.init_noise_sigma\n\n        return latents\n\n    def _encode_video_vae(\n            self,\n            video_frames: torch.Tensor,  # Expects (B, F, C, H, W)\n            device: Union[str, torch.device],\n    ):\n        video_frames = video_frames.to(device=device, dtype=self.vae.dtype)\n        batch_size, num_frames = video_frames.shape[:2]\n\n        # Reshape for VAE encoding\n        video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:])  # (B*F, C, H, W)\n        latents = self.vae.encode(video_frames_reshaped).latent_dist.sample()  # (B*F, C_latent, H_latent, W_latent)\n\n        # Reshape back to video format\n        latents = latents.reshape(batch_size, num_frames, *latents.shape[1:])  # (B, F, C_latent, H_latent, W_latent)\n\n        return latents\n\n    @torch.no_grad()\n    def __call__(\n            self,\n            
image: Union[List[PIL.Image.Image], torch.Tensor],\n            mask_image: Union[List[PIL.Image.Image], torch.Tensor],\n            alpha_matte_image: Optional[Union[List[PIL.Image.Image], torch.Tensor]] = None,\n            denoising_strength: float = 0.7,\n            height: int = 576,\n            width: int = 1024,\n            num_frames: Optional[int] = None,\n            num_inference_steps: int = 30,\n            sigmas: Optional[List[float]] = None,\n            fps: int = 7,\n            motion_bucket_id: int = 127,\n            noise_aug_strength: float = 0.02,\n            decode_chunk_size: Optional[int] = None,\n            num_videos_per_prompt: Optional[int] = 1,\n            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n            latents: Optional[torch.Tensor] = None,\n            output_type: Optional[str] = \"pil\",\n            return_dict: bool = True,\n            mask_noise_strength: float = 0.0,\n    ):\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        if num_frames is None:\n            if isinstance(image, list):\n                num_frames = len(image)\n            else:\n                num_frames = self.unet.config.num_frames\n\n        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames\n\n        self.check_inputs(image, height, width)\n        self.check_inputs(mask_image, height, width)\n        if alpha_matte_image:\n            self.check_inputs(alpha_matte_image, height, width)\n\n        batch_size = 1\n        device = self._execution_device\n        dtype = self.unet.dtype\n\n        image_for_clip = image[0] if isinstance(image, list) else image[0]\n        image_embeddings = self._encode_image(image_for_clip, device, num_videos_per_prompt)\n\n        fps = fps - 1\n\n        image_tensor = self.video_processor.preprocess(image, 
height=height, width=width).to(device).unsqueeze(0)\n        mask_tensor = self.video_processor.preprocess(mask_image, height=height, width=width).to(device).unsqueeze(0)\n\n        noise = randn_tensor(image_tensor.shape, generator=generator, device=device, dtype=dtype)\n        image_tensor = image_tensor + noise_aug_strength * noise\n\n        conditional_latents = self._encode_video_vae(image_tensor, device)\n        conditional_latents = conditional_latents / self.vae.config.scaling_factor\n\n        if self.unet.config.in_channels == 12:\n            mask_latents = self._encode_video_vae(mask_tensor, device)\n            mask_latents = mask_latents / self.vae.config.scaling_factor\n        elif self.unet.config.in_channels == 9:\n            mask_tensor_gray = mask_tensor.mean(dim=2, keepdim=True)\n            binarized_mask = (mask_tensor_gray > 0.0).to(dtype)\n            b, f, c, h, w = binarized_mask.shape\n            binarized_mask_reshaped = binarized_mask.reshape(b * f, c, h, w)\n            target_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)\n            interpolated_mask = F.interpolate(\n                binarized_mask_reshaped,\n                size=target_size,\n                mode='nearest',\n            )\n            mask_latents = interpolated_mask.reshape(b, f, *interpolated_mask.shape[1:])\n        else:\n            raise ValueError(f\"Unsupported number of UNet input channels: {self.unet.config.in_channels}.\")\n\n        if mask_noise_strength > 0.0:\n            mask_noise = randn_tensor(mask_latents.shape, generator=generator, device=device, dtype=dtype)\n            mask_latents = mask_latents + mask_noise_strength * mask_noise\n\n        added_time_ids = self._get_add_time_ids(\n            fps, motion_bucket_id, noise_aug_strength, image_embeddings.dtype, batch_size, num_videos_per_prompt\n        )\n        added_time_ids = added_time_ids.to(device)\n\n        # --- MODIFIED FOR ALPHA MATTE REFINEMENT 
---\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)\n\n        # self.scheduler.set_timesteps(num_inference_steps, device=device)\n        # timesteps = self.scheduler.timesteps\n        initial_latents = None\n\n        if alpha_matte_image is not None:\n            alpha_matte_tensor = self.video_processor.preprocess(alpha_matte_image, height=height, width=width).to(\n                device).unsqueeze(0)\n            initial_latents = self._encode_video_vae(alpha_matte_tensor, device)\n            initial_latents = initial_latents / self.vae.config.scaling_factor\n\n            # Adjust the number of steps and the timesteps to start from\n            t_start = max(num_inference_steps - int(num_inference_steps * denoising_strength), 0)\n            timesteps = timesteps[t_start:]\n            # We need the first timestep to add the correct amount of noise\n            start_timestep = timesteps[0]\n        else:\n            start_timestep = timesteps[0]  # Not used, but for clarity\n\n        latents = self.prepare_latents(\n            batch_size * num_videos_per_prompt,\n            num_frames,\n            height,\n            width,\n            dtype,\n            device,\n            generator,\n            latents,\n            initial_latents=initial_latents,\n            denoising_strength=denoising_strength,\n            timestep=start_timestep if initial_latents is not None else None,\n        )\n\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n\n        with self.progress_bar(total=len(timesteps)) as progress_bar:\n            for i, t in enumerate(timesteps):\n                latent_model_input = self.scheduler.scale_model_input(latents, t)\n                latent_model_input = torch.cat([latent_model_input, conditional_latents, mask_latents], dim=2)\n\n                noise_pred = self.unet(\n 
                   latent_model_input, t, encoder_hidden_states=image_embeddings, added_time_ids=added_time_ids,\n                    return_dict=False\n                )[0]\n\n                latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n        frames = self.decode_latents(latents, num_frames, decode_chunk_size)\n        frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)\n\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return frames\n        return StableVideoDiffusionPipelineOutput(frames=frames)\n\n\nclass StableVideoDiffusionPipelineOnestepWithMask(DiffusionPipeline):\n    r\"\"\"\n    A custom pipeline based on Stable Video Diffusion that accepts an additional mask for conditioning.\n    This pipeline is designed to work with a UNet fine-tuned to accept 12 input channels\n    (4 for noise, 4 for VAE-encoded condition image, 4 for VAE-encoded mask).\n    \"\"\"\n\n    model_cpu_offload_seq = \"image_encoder->unet->vae\"\n    _callback_tensor_inputs = [\"latents\"]\n\n    def __init__(\n            self,\n            vae: AutoencoderKLTemporalDecoder,\n            image_encoder: CLIPVisionModelWithProjection,\n            unet: UNetSpatioTemporalConditionModel,\n            scheduler: EulerDiscreteScheduler,\n            feature_extractor: CLIPImageProcessor,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            image_encoder=image_encoder,\n            unet=unet,\n            scheduler=scheduler,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)\n\n    def 
_encode_image(\n            self,\n            image: PipelineImageInput,\n            device: Union[str, torch.device],\n            num_videos_per_prompt: int,\n    ) -> torch.Tensor:\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.video_processor.pil_to_numpy(image)\n            image = self.video_processor.numpy_to_pt(image)\n\n        image = image * 2.0 - 1.0\n        image = _resize_with_antialiasing(image, (224, 224))\n        image = (image + 1.0) / 2.0\n\n        image = self.feature_extractor(\n            images=image,\n            do_normalize=True,\n            do_center_crop=False,\n            do_resize=False,\n            do_rescale=False,\n            return_tensors=\"pt\",\n        ).pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        image_embeddings = self.image_encoder(image).image_embeds\n        image_embeddings = image_embeddings.unsqueeze(1)\n\n        bs_embed, seq_len, _ = image_embeddings.shape\n        image_embeddings = torch.zeros_like(image_embeddings)\n\n        return image_embeddings\n\n    def _encode_vae_image(\n            self,\n            image: torch.Tensor,\n            device: Union[str, torch.device],\n            num_videos_per_prompt: int,\n    ):\n        image = image.to(device=device, dtype=torch.float16)\n        image_latents = self.vae.encode(image).latent_dist.sample()\n        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)\n        return image_latents\n\n    def _get_add_time_ids(\n            self,\n            fps: int,\n            motion_bucket_id: int,\n            noise_aug_strength: float,\n            dtype: torch.dtype,\n            batch_size: int,\n            num_videos_per_prompt: int,\n    ):\n        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]\n        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)\n        
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n        if expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created.\"\n            )\n        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)\n        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)\n        return add_time_ids\n\n    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):\n        latents = latents.flatten(0, 1).to(dtype=torch.float16)\n        latents = 1 / self.vae.config.scaling_factor * latents\n        frames = []\n        for i in range(0, latents.shape[0], decode_chunk_size):\n            num_frames_in = latents[i: i + decode_chunk_size].shape[0]\n            frame = self.vae.decode(latents[i: i + decode_chunk_size], num_frames=num_frames_in).sample\n            frames.append(frame)\n        frames = torch.cat(frames, dim=0)\n        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)\n        frames = frames.float()\n        return frames\n\n    def check_inputs(self, image, height, width):\n        if (\n                not isinstance(image, torch.Tensor)\n                and not isinstance(image, PIL.Image.Image)\n                and not isinstance(image, list)\n        ):\n            raise ValueError(f\"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}\")\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n    def prepare_latents(\n            self,\n            batch_size: int,\n            num_frames: int,\n            height: int,\n            width: int,\n            dtype: torch.dtype,\n            device: Union[str, torch.device],\n            
generator: torch.Generator,\n            latents: Optional[torch.Tensor] = None,\n    ):\n        # The number of channels for the initial noise is based on the UNet's out_channels\n        num_channels_latents = self.unet.config.out_channels\n        shape = (\n            batch_size,\n            num_frames,\n            num_channels_latents,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(f\"batch size {batch_size} must match the length of the generators {len(generator)}.\")\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def _encode_video_vae(\n            self,\n            video_frames: torch.Tensor,  # Expects (B, F, C, H, W)\n            device: Union[str, torch.device],\n    ):\n        video_frames = video_frames.to(device=device, dtype=self.vae.dtype)\n        batch_size, num_frames = video_frames.shape[:2]\n\n        # Reshape for VAE encoding\n        video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:])  # (B*F, C, H, W)\n        latents = self.vae.encode(video_frames_reshaped).latent_dist.sample()  # (B*F, C_latent, H_latent, W_latent)\n\n        # Reshape back to video format\n        latents = latents.reshape(batch_size, num_frames, *latents.shape[1:])  # (B, F, C_latent, H_latent, W_latent)\n\n        return latents\n\n    @torch.no_grad()\n    def __call__(\n            self,\n            image: Union[List[PIL.Image.Image], torch.Tensor],\n            mask_image: Union[List[PIL.Image.Image], torch.Tensor],\n            height: int = 576,\n            width: int = 1024,\n            num_frames: Optional[int] = None,\n            
fps: int = 7,\n            motion_bucket_id: int = 127,\n            noise_aug_strength: float = 0.0,\n            decode_chunk_size: Optional[int] = None,\n            num_videos_per_prompt: Optional[int] = 1,\n            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n            latents: Optional[torch.Tensor] = None,\n            output_type: Optional[str] = \"pil\",\n            return_dict: bool = True,\n            mask_noise_strength: float = 0.0,\n    ):\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        if num_frames is None:\n            if isinstance(image, list):\n                num_frames = len(image)\n            else:\n                num_frames = self.unet.config.num_frames\n\n        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames\n\n        self.check_inputs(image, height, width)\n        self.check_inputs(mask_image, height, width)\n        if isinstance(image, list) and isinstance(mask_image, list):\n            if len(image) != len(mask_image):\n                raise ValueError(\"`image` and `mask_image` must have the same number of frames.\")\n            if num_frames != len(image):\n                logger.warning(\n                    f\"Mismatch between `num_frames` ({num_frames}) and number of input images ({len(image)}). 
Using {len(image)}.\")\n                num_frames = len(image)\n\n        batch_size = 1\n        device = self._execution_device\n        dtype = self.unet.dtype\n\n        image_for_clip = image[0] if isinstance(image, list) else image[0]\n        image_embeddings = self._encode_image(image_for_clip, device, num_videos_per_prompt)\n\n        fps = fps - 1\n\n        image_tensor = self.video_processor.preprocess(image, height=height, width=width).to(device).unsqueeze(0)\n        mask_tensor = self.video_processor.preprocess(mask_image, height=height, width=width).to(\n            device).unsqueeze(0)\n\n        noise = randn_tensor(image_tensor.shape, generator=generator, device=device, dtype=dtype)\n        image_tensor = image_tensor + noise_aug_strength * noise\n\n        conditional_latents = self._encode_video_vae(image_tensor, device)\n        conditional_latents = conditional_latents / self.vae.config.scaling_factor\n\n        if self.unet.config.in_channels == 12:\n            mask_latents = self._encode_video_vae(mask_tensor, device)\n            mask_latents = mask_latents / self.vae.config.scaling_factor\n        elif self.unet.config.in_channels == 9:\n            mask_tensor_gray = mask_tensor.mean(dim=2, keepdim=True)\n            binarized_mask = (mask_tensor_gray > 0.0).to(dtype)\n            b, f, c, h, w = binarized_mask.shape\n            binarized_mask_reshaped = binarized_mask.reshape(b * f, c, h, w)\n            target_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)\n            interpolated_mask = F.interpolate(\n                binarized_mask_reshaped,\n                size=target_size,\n                mode='nearest',\n            )\n            mask_latents = interpolated_mask.reshape(b, f, *interpolated_mask.shape[1:])\n        else:\n            raise ValueError(\n                f\"Unsupported number of UNet input channels: {self.unet.config.in_channels}. 
\"\n                \"This pipeline only supports 9 (for interpolated mask) or 12 (for VAE mask).\"\n            )\n\n        if mask_noise_strength > 0.0:\n            mask_noise = randn_tensor(mask_latents.shape, generator=generator, device=device, dtype=dtype)\n            mask_latents = mask_latents + mask_noise_strength * mask_noise\n\n        added_time_ids = self._get_add_time_ids(\n            fps, motion_bucket_id, noise_aug_strength, image_embeddings.dtype, batch_size, num_videos_per_prompt\n        )\n        added_time_ids = added_time_ids.to(device)\n\n        # **MODIFIED FOR SINGLE-STEP**: Prepare initial noise\n        num_channels_latents = self.unet.config.out_channels\n        shape = (\n            batch_size * num_videos_per_prompt,\n            num_frames,\n            num_channels_latents,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        # **MODIFIED FOR SINGLE-STEP**: Set a fixed high timestep\n        timestep = torch.tensor([1.0], dtype=dtype, device=device)  # Use a high sigma value\n\n        # **MODIFIED FOR SINGLE-STEP**: Single forward pass\n        latent_model_input = torch.cat([latents, conditional_latents, mask_latents], dim=2)\n\n        noise_pred = self.unet(\n            latent_model_input, timestep, encoder_hidden_states=image_embeddings, added_time_ids=added_time_ids,\n            return_dict=False\n        )[0]\n\n        # The model's prediction is the final denoised latent\n        denoised_latents = noise_pred\n\n        frames = self.decode_latents(denoised_latents, num_frames, decode_chunk_size)\n        frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)\n\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return frames\n        return 
StableVideoDiffusionPipelineOutput(frames=frames)\n\n\nclass StableVideoDiffusionPipelineWithCrossAtnnMask(DiffusionPipeline):\n    model_cpu_offload_seq = \"image_encoder->unet->vae\"\n    _callback_tensor_inputs = [\"latents\"]\n\n    def __init__(\n            self,\n            vae: AutoencoderKLTemporalDecoder,\n            unet: UNetSpatioTemporalConditionModel,\n            scheduler: EulerDiscreteScheduler,\n            mask_projector: torch.nn.Module,\n            # CLIP models are not strictly needed for inference if embeddings are not used\n            image_encoder: CLIPVisionModelWithProjection = None,\n            feature_extractor: CLIPImageProcessor = None,\n    ):\n        super().__init__()\n        self.register_modules(\n            vae=vae,\n            unet=unet,\n            scheduler=scheduler,\n            mask_projector=mask_projector,\n            image_encoder=image_encoder,\n            feature_extractor=feature_extractor,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)\n\n    def _encode_image_vae(self, image: torch.Tensor, device: Union[str, torch.device]):\n        image = image.to(device=device, dtype=self.vae.dtype)\n        latent = self.vae.encode(image).latent_dist.sample()\n        return latent\n\n    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int):\n        latents = latents.flatten(0, 1).to(dtype=torch.float16)\n        latents = 1 / self.vae.config.scaling_factor * latents\n        frames = []\n        for i in range(0, latents.shape[0], decode_chunk_size):\n            frame = self.vae.decode(latents[i: i + decode_chunk_size], num_frames=decode_chunk_size).sample\n            frames.append(frame)\n\n        frames = torch.cat(frames, dim=0)\n        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)\n    
    frames = frames.float()\n        return frames\n\n    def _encode_video_vae(\n            self,\n            video_frames: torch.Tensor,  # Expects (B, F, C, H, W)\n            device: Union[str, torch.device],\n    ):\n        video_frames = video_frames.to(device=device, dtype=self.vae.dtype)\n        batch_size, num_frames = video_frames.shape[:2]\n\n        # Reshape for VAE encoding\n        video_frames_reshaped = video_frames.reshape(batch_size * num_frames, *video_frames.shape[2:])  # (B*F, C, H, W)\n        latents = self.vae.encode(video_frames_reshaped).latent_dist.sample()  # (B*F, C_latent, H_latent, W_latent)\n\n        # Reshape back to video format\n        latents = latents.reshape(batch_size, num_frames, *latents.shape[1:])  # (B, F, C_latent, H_latent, W_latent)\n\n        return latents\n\n    @torch.no_grad()\n    def __call__(\n            self,\n            image: Union[PIL.Image.Image, torch.Tensor],  # Static image for appearance\n            mask_image: List[PIL.Image.Image],  # Video mask for motion\n            height: int = 576,\n            width: int = 1024,\n            num_frames: Optional[int] = None,\n            num_inference_steps: int = 25,\n            fps: int = 7,\n            motion_bucket_id: int = 127,\n            noise_aug_strength: float = 0.0,  # Noise is added to latents now\n            decode_chunk_size: Optional[int] = 8,\n            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n            output_type: Optional[str] = \"pil\",\n            return_dict: bool = True,\n    ):\n        device = self._execution_device\n        dtype = self.unet.dtype\n        num_frames = num_frames if num_frames is not None else len(mask_image)\n        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames\n\n        # 1. 
PREPARE STATIC IMAGE CONDITION\n        image_tensor = self.video_processor.preprocess(image, height, width).to(device).unsqueeze(0)\n        conditional_latents = self._encode_video_vae(image_tensor, device)\n        conditional_latents = conditional_latents / self.vae.config.scaling_factor\n\n        # 2. PREPARE MASK MOTION CONDITION\n        mask_tensor = self.video_processor.preprocess(mask_image, height, width)\n        if mask_tensor.shape[1] > 1:\n            mask_tensor = mask_tensor.mean(dim=1, keepdim=True)\n\n        # Reshape for projector: (T, C, H, W)\n        mask_for_projection = rearrange(mask_tensor, \"f c h w -> f c h w\").to(device, dtype)\n        encoder_hidden_states = self.mask_projector(mask_for_projection)\n        encoder_hidden_states = encoder_hidden_states.unsqueeze(1)  # (T, 1, D)\n        # Add batch dimension for UNet\n        encoder_hidden_states = encoder_hidden_states.unsqueeze(0)  # (1, T, 1, D)\n        # The UNet will handle flattening this to (B*T, 1, D) where B=1\n        # To be safe, we pass it pre-flattened.\n        encoder_hidden_states = rearrange(encoder_hidden_states, \"b f s d -> (b f) s d\")\n\n        # 3. PREPARE LATENTS\n        shape = (1, num_frames, self.unet.config.out_channels, height // self.vae_scale_factor,\n                 width // self.vae_scale_factor)\n        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        if noise_aug_strength > 0:\n            latents += noise_aug_strength * randn_tensor(latents.shape, generator=generator, device=device,\n                                                         dtype=dtype)\n        latents = latents * self.scheduler.init_noise_sigma\n\n        # 4. GET ADDED TIME IDS\n        # For pipeline, batch size is 1\n        added_time_ids = [fps - 1, motion_bucket_id, 0.0]  # noise_aug_strength for add_time_ids is 0 for inference\n        added_time_ids = torch.tensor([added_time_ids], dtype=dtype, device=device)\n\n        # 5. 
DENOISING LOOP\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for t in timesteps:\n                latent_model_input = self.scheduler.scale_model_input(latents, t)\n                unet_input = torch.cat([latent_model_input, conditional_latents], dim=2)\n\n                noise_pred = self.unet(\n                    unet_input, t, encoder_hidden_states=encoder_hidden_states, added_time_ids=added_time_ids\n                ).sample\n\n                latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n                progress_bar.update()\n\n        # 6. DECODE\n        frames = self.decode_latents(latents, num_frames, decode_chunk_size)\n        frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)\n\n        if not return_dict:\n            return (frames,)\n        return StableVideoDiffusionPipelineOutput(frames=frames)\n\n\n# pipeline.py\n\nimport torch\nimport torch.nn.functional as F\nfrom PIL import Image\nfrom einops import rearrange\nfrom torchvision import transforms\nfrom diffusers import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\n\n\nclass VideoInferencePipeline:\n    \"\"\"\n    A reusable pipeline for single-step video diffusion inference.\n\n    This class encapsulates the models and the core inference logic,\n    separating it from data loading and saving, which can vary between tasks.\n    \"\"\"\n\n    def __init__(self, base_model_path: str, unet_checkpoint_path: str, device: str = \"cuda\",\n                 weight_dtype: torch.dtype = torch.float16):\n        \"\"\"\n        Loads all necessary models into memory.\n\n        Args:\n            base_model_path (str): Path to the base Stable Video Diffusion model.\n            
unet_checkpoint_path (str): Path to the fine-tuned UNet checkpoint.\n            device (str): The device to run models on ('cuda' or 'cpu').\n            weight_dtype (torch.dtype): The precision for model weights (float16 or bfloat16).\n        \"\"\"\n        logger.info(\"--- Initializing Inference Pipeline and Loading Models ---\")\n        self.device = torch.device(device if torch.cuda.is_available() else \"cpu\")\n        self.weight_dtype = weight_dtype\n\n        # Load models from pretrained paths\n        try:\n            self.feature_extractor = CLIPImageProcessor.from_pretrained(base_model_path, subfolder=\"feature_extractor\")\n            self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(base_model_path,\n                                                                               subfolder=\"image_encoder\",\n                                                                               variant=\"fp16\")\n            self.vae = AutoencoderKLTemporalDecoder.from_pretrained(base_model_path, subfolder=\"vae\", variant=\"fp16\")\n            self.unet = UNetSpatioTemporalConditionModel.from_pretrained(unet_checkpoint_path, subfolder=\"unet\")\n        except Exception as e:\n            raise IOError(f\"Fatal error loading models: {e}\")\n\n        # Move models to the specified device and set to evaluation mode\n        # CLIP must run in FP32 to avoid CUBLAS errors\n        self.image_encoder.to(self.device, dtype=torch.float32).eval()\n        # VAE must also run in FP32 to avoid CUBLAS errors\n        self.vae.to(self.device, dtype=torch.float32).eval()\n        self.unet.to(self.device, dtype=self.weight_dtype).eval()\n\n        logger.info(\"--- Models Loaded Successfully on %s ---\", self.device)\n\n    def run(self, cond_frames, mask_frames, seed=42, mask_cond_mode=\"vae\", fps=7, motion_bucket_id=127,\n            noise_aug_strength=0.0):\n        \"\"\"\n        Runs the core inference process on a sequence of conditioning 
and mask frames.\n\n        Args:\n            cond_frames (list[Image.Image]): List of PIL images for conditioning.\n            mask_frames (list[Image.Image]): List of PIL images for the masks.\n            seed (int): Random seed for generation.\n            mask_cond_mode (str): How the mask is conditioned (\"vae\" or \"interpolate\").\n            fps (int): Frames per second to condition the model with.\n            motion_bucket_id (int): Motion bucket ID for conditioning.\n            noise_aug_strength (float): Noise augmentation strength.\n\n        Returns:\n            list[Image.Image]: A list of the generated video frames as PIL Images.\n        \"\"\"\n        # --- 1. Prepare Tensors ---\n        cond_video_tensor = self._pil_to_tensor(cond_frames).to(self.device)\n        mask_video_tensor = self._pil_to_tensor(mask_frames).to(self.device)\n\n        if mask_video_tensor.shape[2] != 3:\n            mask_video_tensor = mask_video_tensor.repeat(1, 1, 3, 1, 1)\n\n        with torch.no_grad():\n            # --- 2. 
Get CLIP Image Embeddings ---\n            first_frame_tensor = cond_video_tensor[:, 0, :, :, :]\n            pixel_values_for_clip = self._resize_with_antialiasing(first_frame_tensor, (224, 224))\n            pixel_values_for_clip = ((pixel_values_for_clip + 1.0) / 2.0).clamp(0, 1)\n            pixel_values = self.feature_extractor(images=pixel_values_for_clip, do_rescale=False, return_tensors=\"pt\").pixel_values\n            # Run CLIP in FP32\n            image_embeddings = self.image_encoder(pixel_values.to(self.device, dtype=torch.float32)).image_embeds\n            \n            logger.debug(\"CLIP Embeds Max: %.4f, Mean: %.4f\", image_embeddings.max().item(), image_embeddings.mean().item())\n\n            # Setup for UNet which uses weight_dtype (likely FP16)\n            image_embeddings = image_embeddings.to(dtype=self.weight_dtype)\n            encoder_hidden_states = torch.zeros_like(image_embeddings).unsqueeze(1)\n\n            # --- 3. Prepare Latents ---\n            # VAE encoding must happen in FP32\n            cond_video_tensor_fp32 = cond_video_tensor.to(dtype=torch.float32)\n            cond_latents = self._tensor_to_vae_latent(cond_video_tensor_fp32)\n            \n            logger.debug(\"Cond Latents Max: %.4f, Mean: %.4f\", cond_latents.max().item(), cond_latents.mean().item())\n\n            # Cast back to weight_dtype (FP16) for UNet\n            cond_latents = cond_latents.to(dtype=self.weight_dtype)\n            cond_latents = cond_latents / self.vae.config.scaling_factor\n\n            if mask_cond_mode == \"vae\":\n                mask_video_tensor_fp32 = mask_video_tensor.to(dtype=torch.float32)\n                mask_latents = self._tensor_to_vae_latent(mask_video_tensor_fp32)\n                logger.debug(\"Mask Latents Max: %.4f, Mean: %.4f\", mask_latents.max().item(), mask_latents.mean().item())\n                mask_latents = mask_latents.to(dtype=self.weight_dtype)\n                mask_latents = mask_latents / 
self.vae.config.scaling_factor\n            elif mask_cond_mode == \"interpolate\":\n                target_shape = cond_latents.shape[-2:]\n                b, t, c, h, w = mask_video_tensor.shape\n                mask_video_reshaped = rearrange(mask_video_tensor, \"b t c h w -> (b t) c h w\")\n                interpolated_mask = F.interpolate(mask_video_reshaped, size=target_shape, mode='bilinear',\n                                                  align_corners=False)\n                mask_latents = rearrange(interpolated_mask, \"(b t) c h w -> b t c h w\", b=b)\n            else:\n                raise ValueError(f\"Unknown mask_cond_mode: {mask_cond_mode}\")\n\n            # --- 4. Run UNet Single-Step Inference ---\n            generator = torch.Generator(device=self.device).manual_seed(seed)\n            noisy_latents = torch.randn(cond_latents.shape, generator=generator, device=self.device,\n                                        dtype=self.weight_dtype)\n            timesteps = torch.full((1,), 1.0, device=self.device, dtype=torch.long)\n            added_time_ids = self._get_add_time_ids(fps, motion_bucket_id, noise_aug_strength, batch_size=1)\n\n            unet_input = torch.cat([noisy_latents, cond_latents, mask_latents], dim=2)\n            pred_latents = self.unet(unet_input, timesteps, encoder_hidden_states, added_time_ids=added_time_ids).sample\n            \n            logger.debug(\"Pred Latents Max: %.4f, Mean: %.4f\", pred_latents.max().item(), pred_latents.mean().item())\n\n            # --- 5. 
Decode Latents to Video Frames ---\n            pred_latents = (1 / self.vae.config.scaling_factor) * pred_latents.squeeze(0)\n\n            frames = []\n            # Process in chunks to avoid VRAM issues, especially for long videos\n            # Decode in FP32\n            pred_latents_fp32 = pred_latents.to(dtype=torch.float32)\n            for i in range(0, pred_latents_fp32.shape[0], 8):\n                chunk = pred_latents_fp32[i: i + 8]\n                decoded_chunk = self.vae.decode(chunk, num_frames=chunk.shape[0]).sample\n                frames.append(decoded_chunk)\n\n            video_tensor = torch.cat(frames, dim=0)\n            logger.debug(\"Video Tensor (Pre-Clamp) Max: %.4f, Mean: %.4f\", video_tensor.max().item(), video_tensor.mean().item())\n            video_tensor = (video_tensor / 2.0 + 0.5).clamp(0, 1).mean(dim=1, keepdim=True).repeat(1, 3, 1, 1)\n\n            # Return a list of PIL images\n            return [transforms.ToPILImage()(frame) for frame in video_tensor]\n\n    def _pil_to_tensor(self, frames: list[Image.Image]):\n        \"\"\"Converts a list of PIL images to a normalized video tensor.\"\"\"\n        video_tensor = torch.stack([transforms.ToTensor()(f) for f in frames]).unsqueeze(0)\n        return video_tensor * 2.0 - 1.0\n\n    def _tensor_to_vae_latent(self, t: torch.Tensor):\n        \"\"\"Encodes a video tensor into the VAE's latent space in chunks to avoid OOM.\"\"\"\n        video_length = t.shape[1]\n        t = rearrange(t, \"b f c h w -> (b f) c h w\")\n        \n        # Process in chunks of 8\n        chunk_size = 8\n        latents_list = []\n        \n        for i in range(0, t.shape[0], chunk_size):\n            chunk = t[i:i + chunk_size]\n            chunk_latents = self.vae.encode(chunk).latent_dist.sample()\n            latents_list.append(chunk_latents)\n            \n        latents = torch.cat(latents_list, dim=0)\n        latents = rearrange(latents, \"(b f) c h w -> b f c h w\", f=video_length)\n  
      return latents * self.vae.config.scaling_factor\n\n    def _get_add_time_ids(self, fps, motion_bucket_id, noise_aug_strength, batch_size):\n        \"\"\"Creates the additional time IDs for conditioning the UNet.\"\"\"\n        add_time_ids_list = [fps, motion_bucket_id, noise_aug_strength]\n        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids_list)\n        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features\n        if expected_add_embed_dim != passed_add_embed_dim:\n            raise ValueError(\n                f\"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created.\")\n        add_time_ids = torch.tensor([add_time_ids_list], dtype=self.weight_dtype, device=self.device)\n        return add_time_ids.repeat(batch_size, 1)\n\n    def _resize_with_antialiasing(self, input_tensor, size, interpolation=\"bicubic\", align_corners=True):\n        \"\"\"\n        Resizes a tensor with anti-aliasing for CLIP input, mirroring k-diffusion.\n        This is a direct copy of the helper function from your original scripts.\n        \"\"\"\n        h, w = input_tensor.shape[-2:]\n        factors = (h / size[0], w / size[1])\n        sigmas = (max((factors[0] - 1.0) / 2.0, 0.001), max((factors[1] - 1.0) / 2.0, 0.001))\n        ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))\n        if (ks[0] % 2) == 0: ks = ks[0] + 1, ks[1]\n        if (ks[1] % 2) == 0: ks = ks[0], ks[1] + 1\n\n        def _compute_padding(kernel_size):\n            computed = [k - 1 for k in kernel_size]\n            out_padding = 2 * len(kernel_size) * [0]\n            for i in range(len(kernel_size)):\n                computed_tmp = computed[-(i + 1)]\n                pad_front = computed_tmp // 2\n                pad_rear = computed_tmp - pad_front\n                out_padding[2 * i + 0] = pad_front\n                out_padding[2 * i + 1] = 
pad_rear\n            return out_padding\n\n        def _filter2d(input_tensor, kernel):\n            b, c, h, w = input_tensor.shape\n            tmp_kernel = kernel[:, None, ...].to(device=input_tensor.device, dtype=input_tensor.dtype)\n            tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)\n            height, width = tmp_kernel.shape[-2:]\n            padding_shape = _compute_padding([height, width])\n            input_tensor_padded = F.pad(input_tensor, padding_shape, mode=\"reflect\")\n            tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)\n            input_tensor_padded = input_tensor_padded.view(-1, tmp_kernel.size(0), input_tensor_padded.size(-2),\n                                                           input_tensor_padded.size(-1))\n            output = F.conv2d(input_tensor_padded, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)\n            return output.view(b, c, h, w)\n\n        def _gaussian(window_size, sigma):\n            if isinstance(sigma, float):\n                sigma = torch.tensor([[sigma]])\n            x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(\n                sigma.shape[0], -1)\n            if window_size % 2 == 0:\n                x = x + 0.5\n            gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))\n            return gauss / gauss.sum(-1, keepdim=True)\n\n        def _gaussian_blur2d(input_tensor, kernel_size, sigma):\n            if isinstance(sigma, tuple):\n                sigma = torch.tensor([sigma], dtype=input_tensor.dtype)\n            else:\n                sigma = sigma.to(dtype=input_tensor.dtype)\n            ky, kx = int(kernel_size[0]), int(kernel_size[1])\n            bs = sigma.shape[0]\n            kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))\n            kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))\n            out_x = _filter2d(input_tensor, kernel_x[..., None, :])\n            return _filter2d(out_x, 
kernel_y[..., None])\n\n        blurred_input = _gaussian_blur2d(input_tensor, ks, sigmas)\n        return F.interpolate(blurred_input, size=size, mode=interpolation, align_corners=align_corners)"
  },
  {
    "path": "backend/__init__.py",
    "content": "\"\"\"Backend service layer for ez-CorridorKey.\"\"\"\n\nfrom .clip_state import (\n    ClipAsset,\n    ClipEntry,\n    ClipState,\n    InOutRange,\n    scan_clips_dir,\n    scan_project_clips,\n)\nfrom .errors import CorridorKeyError\nfrom .job_queue import GPUJob, GPUJobQueue, JobStatus, JobType\nfrom .natural_sort import natsorted, natural_sort_key\nfrom .project import (\n    VIDEO_FILE_FILTER,\n    add_clips_to_project,\n    create_project,\n    get_clip_dirs,\n    get_display_name,\n    is_image_file,\n    is_v2_project,\n    is_video_file,\n    projects_root,\n    read_clip_json,\n    read_project_json,\n    sanitize_stem,\n    set_display_name,\n    write_clip_json,\n    write_project_json,\n)\nfrom .service import CorridorKeyService, InferenceParams, OutputConfig\n\n__all__ = [\n    # Service\n    \"CorridorKeyService\",\n    \"InferenceParams\",\n    \"OutputConfig\",\n    # Clip state\n    \"ClipAsset\",\n    \"ClipEntry\",\n    \"ClipState\",\n    \"InOutRange\",\n    \"scan_clips_dir\",\n    \"scan_project_clips\",\n    # Job queue\n    \"GPUJob\",\n    \"GPUJobQueue\",\n    \"JobType\",\n    \"JobStatus\",\n    # Errors\n    \"CorridorKeyError\",\n    # Project utilities\n    \"projects_root\",\n    \"create_project\",\n    \"add_clips_to_project\",\n    \"sanitize_stem\",\n    \"get_clip_dirs\",\n    \"is_v2_project\",\n    \"write_project_json\",\n    \"read_project_json\",\n    \"write_clip_json\",\n    \"read_clip_json\",\n    \"get_display_name\",\n    \"set_display_name\",\n    \"is_video_file\",\n    \"is_image_file\",\n    \"VIDEO_FILE_FILTER\",\n    # Natural sort\n    \"natural_sort_key\",\n    \"natsorted\",\n]\n"
  },
  {
    "path": "backend/clip_state.py",
    "content": "\"\"\"Clip entry data model and state machine.\n\nState Machine:\n    EXTRACTING — Video input being extracted to image sequence\n    RAW        — Input asset found, no alpha hint yet\n    MASKED     — User mask provided (for VideoMaMa workflow)\n    READY      — Alpha hint available (from GVM or VideoMaMa), ready for inference\n    COMPLETE   — Inference outputs written\n    ERROR      — Processing failed (can retry)\n\nTransitions:\n    EXTRACTING → RAW   (extraction completes)\n    EXTRACTING → ERROR (extraction fails)\n    RAW → MASKED       (user provides VideoMaMa mask)\n    RAW → READY        (GVM auto-generates alpha)\n    RAW → ERROR        (GVM/scan fails)\n    MASKED → READY     (VideoMaMa generates alpha from user mask)\n    MASKED → ERROR     (VideoMaMa fails)\n    READY → COMPLETE   (inference succeeds)\n    READY → ERROR      (inference fails)\n    ERROR → RAW        (retry from scratch)\n    ERROR → MASKED     (retry with mask)\n    ERROR → READY      (retry inference)\n    ERROR → EXTRACTING (retry extraction)\n    COMPLETE → READY   (reprocess with different params)\n\"\"\"\n\nfrom __future__ import annotations\n\nimport glob as glob_module\nimport logging\nimport os\nfrom dataclasses import dataclass, field\nfrom enum import Enum\n\nfrom .errors import ClipScanError, InvalidStateTransitionError\nfrom .natural_sort import natsorted\nfrom .project import is_image_file as _is_image_file\nfrom .project import is_video_file as _is_video_file\n\nlogger = logging.getLogger(__name__)\n\n\nclass ClipState(Enum):\n    EXTRACTING = \"EXTRACTING\"\n    RAW = \"RAW\"\n    MASKED = \"MASKED\"\n    READY = \"READY\"\n    COMPLETE = \"COMPLETE\"\n    ERROR = \"ERROR\"\n\n\n# Valid transitions: from_state -> set of allowed to_states\n_TRANSITIONS: dict[ClipState, set[ClipState]] = {\n    ClipState.EXTRACTING: {ClipState.RAW, ClipState.ERROR},\n    ClipState.RAW: {ClipState.MASKED, ClipState.READY, ClipState.ERROR},\n    ClipState.MASKED: 
{ClipState.READY, ClipState.ERROR},\n    ClipState.READY: {ClipState.COMPLETE, ClipState.ERROR},\n    ClipState.COMPLETE: {ClipState.READY},  # reprocess with different params\n    ClipState.ERROR: {ClipState.RAW, ClipState.MASKED, ClipState.READY, ClipState.EXTRACTING},\n}\n\n\n@dataclass\nclass ClipAsset:\n    \"\"\"Represents an input source — either an image sequence directory or a video file.\"\"\"\n\n    path: str\n    asset_type: str  # 'sequence' or 'video'\n    frame_count: int = 0\n\n    def __post_init__(self):\n        self._calculate_length()\n\n    def _calculate_length(self):\n        if self.asset_type == \"sequence\":\n            if os.path.isdir(self.path):\n                files = [f for f in os.listdir(self.path) if _is_image_file(f)]\n                self.frame_count = len(files)\n            else:\n                self.frame_count = 0\n        elif self.asset_type == \"video\":\n            try:\n                import cv2\n\n                cap = cv2.VideoCapture(self.path)\n                if cap.isOpened():\n                    self.frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n                    if self.frame_count == 0:\n                        logger.warning(f\"Video reports 0 frames, file may be corrupted: {self.path}\")\n                cap.release()\n            except Exception as e:\n                logger.debug(f\"Video frame count detection failed for {self.path}: {e}\")\n                self.frame_count = 0\n\n    def get_frame_files(self) -> list[str]:\n        \"\"\"Return naturally sorted list of frame filenames for sequence assets.\n\n        Uses natural sort so frame_2 sorts before frame_10 (not lexicographic).\n        \"\"\"\n        if self.asset_type != \"sequence\" or not os.path.isdir(self.path):\n            return []\n        return natsorted([f for f in os.listdir(self.path) if _is_image_file(f)])\n\n\n@dataclass\nclass InOutRange:\n    \"\"\"In/out frame range for sub-clip processing. 
Both indices inclusive, 0-based.\"\"\"\n\n    in_point: int\n    out_point: int\n\n    @property\n    def frame_count(self) -> int:\n        return self.out_point - self.in_point + 1\n\n    def contains(self, index: int) -> bool:\n        return self.in_point <= index <= self.out_point\n\n    def to_dict(self) -> dict:\n        return {\"in_point\": self.in_point, \"out_point\": self.out_point}\n\n    @classmethod\n    def from_dict(cls, d: dict) -> InOutRange:\n        return cls(in_point=d[\"in_point\"], out_point=d[\"out_point\"])\n\n\n@dataclass\nclass ClipEntry:\n    \"\"\"A single shot/clip with its assets and processing state.\"\"\"\n\n    name: str\n    root_path: str\n    state: ClipState = ClipState.RAW\n    input_asset: ClipAsset | None = None\n    alpha_asset: ClipAsset | None = None\n    mask_asset: ClipAsset | None = None  # User-provided VideoMaMa mask\n    in_out_range: InOutRange | None = None  # Per-clip in/out markers (None = full clip)\n    warnings: list[str] = field(default_factory=list)\n    error_message: str | None = None\n    extraction_progress: float = 0.0  # 0.0 to 1.0 during EXTRACTING\n    extraction_total: int = 0  # total frames expected during extraction\n    _processing: bool = field(default=False, repr=False)  # lock: watcher must not reclassify\n\n    @property\n    def is_processing(self) -> bool:\n        \"\"\"True while a GPU job is actively working on this clip.\"\"\"\n        return self._processing\n\n    def set_processing(self, value: bool) -> None:\n        \"\"\"Set processing lock. Watcher skips reclassification while True.\"\"\"\n        self._processing = value\n\n    def transition_to(self, new_state: ClipState) -> None:\n        \"\"\"Attempt a state transition. 
Raises InvalidStateTransitionError if not allowed.\"\"\"\n        if new_state not in _TRANSITIONS.get(self.state, set()):\n            raise InvalidStateTransitionError(self.name, self.state.value, new_state.value)\n        old = self.state\n        self.state = new_state\n        if new_state != ClipState.ERROR:\n            self.error_message = None\n        logger.debug(f\"Clip '{self.name}': {old.value} -> {new_state.value}\")\n\n    def set_error(self, message: str) -> None:\n        \"\"\"Transition to ERROR state with a message.\n\n        Works from any state that allows ERROR transition\n        (RAW, MASKED, READY — all can error now).\n        \"\"\"\n        self.transition_to(ClipState.ERROR)\n        self.error_message = message\n\n    @property\n    def output_dir(self) -> str:\n        return os.path.join(self.root_path, \"Output\")\n\n    @property\n    def has_outputs(self) -> bool:\n        \"\"\"Check if output directory exists with content.\"\"\"\n        out = self.output_dir\n        if not os.path.isdir(out):\n            return False\n        for subdir in (\"FG\", \"Matte\", \"Comp\", \"Processed\"):\n            d = os.path.join(out, subdir)\n            if os.path.isdir(d) and os.listdir(d):\n                return True\n        return False\n\n    def completed_frame_count(self) -> int:\n        \"\"\"Count existing output frames for resume support.\n\n        Manifest-aware: reads .corridorkey_manifest.json to determine which\n        outputs were enabled. 
Falls back to FG+Matte intersection if no manifest.\n        \"\"\"\n        return len(self.completed_stems())\n\n    def completed_stems(self) -> set[str]:\n        \"\"\"Return set of frame stems that have all enabled outputs complete.\n\n        Reads the run manifest to determine which outputs to check.\n        Falls back to FG+Matte intersection if no manifest exists.\n        \"\"\"\n        manifest = self._read_manifest()\n        if manifest:\n            enabled = manifest.get(\"enabled_outputs\", [])\n        else:\n            enabled = [\"fg\", \"matte\"]\n\n        dir_map = {\n            \"fg\": os.path.join(self.output_dir, \"FG\"),\n            \"matte\": os.path.join(self.output_dir, \"Matte\"),\n            \"comp\": os.path.join(self.output_dir, \"Comp\"),\n            \"processed\": os.path.join(self.output_dir, \"Processed\"),\n        }\n\n        stem_sets = []\n        for output_name in enabled:\n            d = dir_map.get(output_name)\n            if d and os.path.isdir(d):\n                stems = {os.path.splitext(f)[0] for f in os.listdir(d) if _is_image_file(f)}\n                stem_sets.append(stems)\n            else:\n                # Required dir missing → no complete frames\n                return set()\n\n        if not stem_sets:\n            return set()\n\n        # Intersection: frame complete only if ALL enabled outputs exist\n        result = stem_sets[0]\n        for s in stem_sets[1:]:\n            result &= s\n        return result\n\n    def _read_manifest(self) -> dict | None:\n        \"\"\"Read the run manifest if it exists.\"\"\"\n        manifest_path = os.path.join(self.output_dir, \".corridorkey_manifest.json\")\n        if not os.path.isfile(manifest_path):\n            return None\n        try:\n            import json\n\n            with open(manifest_path, \"r\") as f:\n                return json.load(f)\n        except Exception as e:\n            logger.debug(f\"Failed to read manifest at 
{manifest_path}: {e}\")\n            return None\n\n    def _resolve_original_path(self) -> str | None:\n        \"\"\"Resolve the original video path from clip.json or project.json.\"\"\"\n        from .project import _read_clip_or_project_json\n\n        data = _read_clip_or_project_json(self.root_path)\n        if not data:\n            return None\n        source = data.get(\"source\", {})\n        path = source.get(\"original_path\")\n        if path and os.path.isfile(path):\n            return path\n        return None\n\n    def find_assets(self) -> None:\n        \"\"\"Scan the clip directory for Input, AlphaHint, and mask assets.\n\n        Updates state accordingly. Supports both new format (Frames/, Source/)\n        and legacy format (Input/, Input.*) for backward compatibility.\n        \"\"\"\n        # Input asset — check new names first, fall back to legacy\n        frames_dir = os.path.join(self.root_path, \"Frames\")\n        input_dir = os.path.join(self.root_path, \"Input\")\n        source_dir = os.path.join(self.root_path, \"Source\")\n\n        if os.path.isdir(frames_dir) and os.listdir(frames_dir):\n            self.input_asset = ClipAsset(frames_dir, \"sequence\")\n        elif os.path.isdir(input_dir) and os.listdir(input_dir):\n            self.input_asset = ClipAsset(input_dir, \"sequence\")\n        elif os.path.isdir(source_dir):\n            videos = [f for f in os.listdir(source_dir) if _is_video_file(f)]\n            if videos:\n                self.input_asset = ClipAsset(\n                    os.path.join(source_dir, videos[0]),\n                    \"video\",\n                )\n            else:\n                # Source/ exists but is empty — check project.json for external reference\n                original = self._resolve_original_path()\n                if original:\n                    self.input_asset = ClipAsset(original, \"video\")\n                else:\n                    raise ClipScanError(f\"Clip '{self.name}': 
'Source' dir has no video.\")\n        else:\n            candidates = glob_module.glob(os.path.join(self.root_path, \"[Ii]nput.*\"))\n            candidates = [c for c in candidates if _is_video_file(c)]\n            if candidates:\n                self.input_asset = ClipAsset(candidates[0], \"video\")\n            elif os.path.isdir(input_dir):\n                raise ClipScanError(f\"Clip '{self.name}': Input dir is empty — no image files.\")\n            else:\n                raise ClipScanError(f\"Clip '{self.name}': no Input found.\")\n\n        # Load display name from project.json if available\n        from .project import get_display_name\n\n        display = get_display_name(self.root_path)\n        if display != os.path.basename(self.root_path):\n            self.name = display\n\n        # Alpha hint asset\n        alpha_dir = os.path.join(self.root_path, \"AlphaHint\")\n        if os.path.isdir(alpha_dir) and os.listdir(alpha_dir):\n            self.alpha_asset = ClipAsset(alpha_dir, \"sequence\")\n\n        # VideoMaMa mask hint — directory OR video file\n        mask_dir = os.path.join(self.root_path, \"VideoMamaMaskHint\")\n        if os.path.isdir(mask_dir) and os.listdir(mask_dir):\n            self.mask_asset = ClipAsset(mask_dir, \"sequence\")\n        else:\n            # Check for mask video file (VideoMamaMaskHint.mp4 etc.)\n            mask_candidates = glob_module.glob(os.path.join(self.root_path, \"VideoMamaMaskHint.*\"))\n            mask_candidates = [c for c in mask_candidates if _is_video_file(c)]\n            if mask_candidates:\n                self.mask_asset = ClipAsset(mask_candidates[0], \"video\")\n\n        # Load in/out range from project.json\n        from .project import load_in_out_range\n\n        self.in_out_range = load_in_out_range(self.root_path)\n\n        # Determine initial state\n        self._resolve_state()\n\n    def _resolve_state(self) -> None:\n        \"\"\"Set state based on what assets are present on 
disk.\n\n        Recovers the furthest pipeline stage from disk contents so the\n        user never loses completed work after a restart or crash.\n\n        Priority (highest first):\n          COMPLETE  — all input frames have matching outputs (manifest-aware)\n          READY     — AlphaHint exists (inference-ready)\n          MASKED    — VideoMaMa mask hint exists\n          EXTRACTING — video source exists but no frame sequence yet\n          RAW       — frame sequence exists, no alpha/mask/output\n        \"\"\"\n        # Check COMPLETE first: outputs exist and cover all input frames\n        if self.alpha_asset is not None and self.input_asset is not None:\n            completed = self.completed_stems()\n            if completed and len(completed) >= self.input_asset.frame_count:\n                self.state = ClipState.COMPLETE\n                return\n\n        # READY: AlphaHint must cover ALL input frames (not partial)\n        if self.alpha_asset is not None:\n            if self.input_asset is not None and self.alpha_asset.frame_count < self.input_asset.frame_count:\n                # Partial alpha — don't promote to READY, fall through\n                logger.info(\n                    f\"Clip '{self.name}': partial alpha \"\n                    f\"({self.alpha_asset.frame_count}/{self.input_asset.frame_count}), \"\n                    f\"staying at lower state\"\n                )\n            else:\n                self.state = ClipState.READY\n                return\n\n        if self.mask_asset is not None:\n            self.state = ClipState.MASKED\n        elif self.input_asset is not None and self.input_asset.asset_type == \"video\":\n            # Video input needs extraction to image sequence\n            self.state = ClipState.EXTRACTING\n        else:\n            self.state = ClipState.RAW\n\n\ndef scan_project_clips(project_dir: str) -> list[ClipEntry]:\n    \"\"\"Scan a single project directory for its clips.\n\n    v2 projects (with 
``clips/`` subdir): each subdirectory inside clips/ is a clip.\n    v1 projects (no ``clips/`` subdir): the project dir itself is a single clip.\n\n    Args:\n        project_dir: Absolute path to a project folder.\n\n    Returns:\n        List of ClipEntry objects with root_path pointing to clip subdirectories.\n    \"\"\"\n    from .project import is_v2_project\n\n    if is_v2_project(project_dir):\n        clips_dir = os.path.join(project_dir, \"clips\")\n        entries: list[ClipEntry] = []\n        for item in sorted(os.listdir(clips_dir)):\n            item_path = os.path.join(clips_dir, item)\n            if item.startswith(\".\") or item.startswith(\"_\"):\n                continue\n            if not os.path.isdir(item_path):\n                continue\n            clip = ClipEntry(name=item, root_path=item_path)\n            try:\n                clip.find_assets()\n                entries.append(clip)\n            except ClipScanError as e:\n                logger.debug(str(e))\n        logger.info(f\"Scanned v2 project {project_dir}: {len(entries)} clip(s)\")\n        return entries\n\n    # v1 fallback: project_dir is itself a single clip\n    clip = ClipEntry(name=os.path.basename(project_dir), root_path=project_dir)\n    try:\n        clip.find_assets()\n        return [clip]\n    except ClipScanError as e:\n        logger.debug(str(e))\n        return []\n\n\ndef scan_clips_dir(\n    clips_dir: str,\n    allow_standalone_videos: bool = True,\n) -> list[ClipEntry]:\n    \"\"\"Scan a directory for clip folders and optionally standalone video files.\n\n    For the Projects root: iterates project subdirectories and delegates to\n    scan_project_clips() for each, flattening results.\n\n    For non-Projects directories: scans subdirectories directly as clips\n    (legacy behavior for drag-and-dropped folders).\n\n    Folders without valid input assets are skipped (not added as broken clips).\n\n    Args:\n        clips_dir: Path to scan.\n        
allow_standalone_videos: If False, loose video files at top level are ignored.\n            Set False for the Projects root where videos live inside Source/ subdirs.\n    \"\"\"\n    entries: list[ClipEntry] = []\n    if not os.path.isdir(clips_dir):\n        logger.warning(f\"Clips directory not found: {clips_dir}\")\n        return entries\n\n    # If the directory itself is a v2 project, scan its clips directly\n    from .project import is_v2_project\n\n    if is_v2_project(clips_dir):\n        return scan_project_clips(clips_dir)\n\n    seen_names: set[str] = set()\n\n    for item in sorted(os.listdir(clips_dir)):\n        item_path = os.path.join(clips_dir, item)\n\n        # Skip hidden and special items\n        if item.startswith(\".\") or item.startswith(\"_\"):\n            continue\n\n        if os.path.isdir(item_path):\n            # Check if this is a v2 project container (has clips/ subdir)\n            from .project import is_v2_project\n\n            if is_v2_project(item_path):\n                # v2 project: scan its clips/ subdirectory\n                for clip in scan_project_clips(item_path):\n                    if clip.name not in seen_names:\n                        entries.append(clip)\n                        seen_names.add(clip.name)\n            else:\n                # Flat clip dir or v1 project\n                clip = ClipEntry(name=item, root_path=item_path)\n                try:\n                    clip.find_assets()\n                    entries.append(clip)\n                    seen_names.add(clip.name)\n                except ClipScanError as e:\n                    # Skip folders without valid input assets\n                    logger.debug(str(e))\n\n        elif allow_standalone_videos and os.path.isfile(item_path) and _is_video_file(item_path):\n            # Standalone video file → treat as a clip needing extraction\n            stem = os.path.splitext(item)[0]\n            if stem in seen_names:\n                continue  # 
folder clip already exists with this name\n            clip = ClipEntry(name=stem, root_path=clips_dir)\n            clip.input_asset = ClipAsset(item_path, \"video\")\n            clip.state = ClipState.EXTRACTING\n            entries.append(clip)\n            seen_names.add(stem)\n\n    logger.info(f\"Scanned {clips_dir}: {len(entries)} clip(s) found\")\n    return entries\n"
  },
  {
    "path": "backend/errors.py",
    "content": "\"\"\"Typed exceptions for the CorridorKey backend.\"\"\"\n\nimport sys\n\n\nclass CorridorKeyError(Exception):\n    \"\"\"Base exception for all CorridorKey backend errors.\"\"\"\n\n    pass\n\n\nclass ClipScanError(CorridorKeyError):\n    \"\"\"Raised when a clip directory cannot be scanned or is malformed.\"\"\"\n\n    pass\n\n\nclass FrameMismatchError(CorridorKeyError):\n    \"\"\"Raised when input and alpha frame counts don't match.\"\"\"\n\n    def __init__(self, clip_name: str, input_count: int, alpha_count: int):\n        self.clip_name = clip_name\n        self.input_count = input_count\n        self.alpha_count = alpha_count\n        super().__init__(f\"Clip '{clip_name}': frame count mismatch — input has {input_count}, alpha has {alpha_count}\")\n\n\nclass FrameReadError(CorridorKeyError):\n    \"\"\"Raised when a frame file cannot be read.\"\"\"\n\n    def __init__(self, clip_name: str, frame_index: int, path: str):\n        self.clip_name = clip_name\n        self.frame_index = frame_index\n        self.path = path\n        super().__init__(f\"Clip '{clip_name}': failed to read frame {frame_index} ({path})\")\n\n\nclass WriteFailureError(CorridorKeyError):\n    \"\"\"Raised when cv2.imwrite or similar write operation fails.\"\"\"\n\n    def __init__(self, clip_name: str, frame_index: int, path: str):\n        self.clip_name = clip_name\n        self.frame_index = frame_index\n        self.path = path\n        super().__init__(f\"Clip '{clip_name}': failed to write frame {frame_index} ({path})\")\n\n\nclass MaskChannelError(CorridorKeyError):\n    \"\"\"Raised when a mask has unexpected channel count that can't be resolved.\"\"\"\n\n    def __init__(self, clip_name: str, frame_index: int, channels: int):\n        self.clip_name = clip_name\n        self.frame_index = frame_index\n        self.channels = channels\n        super().__init__(f\"Clip '{clip_name}': mask frame {frame_index} has {channels} channels, expected 1 or 
3+\")\n\n\nclass VRAMInsufficientError(CorridorKeyError):\n    \"\"\"Raised when there isn't enough GPU VRAM for the requested operation.\"\"\"\n\n    def __init__(self, required_gb: float, available_gb: float):\n        self.required_gb = required_gb\n        self.available_gb = available_gb\n        super().__init__(f\"Insufficient VRAM: {required_gb:.1f}GB required, {available_gb:.1f}GB available\")\n\n\nclass InvalidStateTransitionError(CorridorKeyError):\n    \"\"\"Raised when a clip state transition is not allowed.\"\"\"\n\n    def __init__(self, clip_name: str, current_state: str, target_state: str):\n        self.clip_name = clip_name\n        self.current_state = current_state\n        self.target_state = target_state\n        super().__init__(f\"Clip '{clip_name}': invalid state transition {current_state} -> {target_state}\")\n\n\nclass JobCancelledError(CorridorKeyError):\n    \"\"\"Raised when a GPU job is cancelled by the user.\"\"\"\n\n    def __init__(self, clip_name: str, frame_index: int | None = None):\n        self.clip_name = clip_name\n        self.frame_index = frame_index\n        msg = f\"Clip '{clip_name}': job cancelled\"\n        if frame_index is not None:\n            msg += f\" at frame {frame_index}\"\n        super().__init__(msg)\n\n\nclass FFmpegNotFoundError(CorridorKeyError):\n    \"\"\"Raised when FFmpeg/FFprobe binaries cannot be located.\"\"\"\n\n    def __init__(self):\n        if sys.platform == \"darwin\":\n            hint = \"Install FFmpeg via Homebrew: brew install ffmpeg\"\n        elif sys.platform.startswith(\"linux\"):\n            hint = \"Install FFmpeg via your package manager: sudo apt install ffmpeg\"\n        else:\n            hint = r\"Place ffmpeg.exe in C:\\Program Files\\ffmpeg\\bin\\ or add it to PATH\"\n        super().__init__(f\"FFmpeg not found. 
{hint}\")\n\n\nclass ExtractionError(CorridorKeyError):\n    \"\"\"Raised when video frame extraction fails.\"\"\"\n\n    def __init__(self, clip_name: str, detail: str):\n        self.clip_name = clip_name\n        self.detail = detail\n        super().__init__(f\"Clip '{clip_name}': extraction failed — {detail}\")\n"
  },
  {
    "path": "backend/ffmpeg_tools.py",
    "content": "\"\"\"FFmpeg subprocess wrapper for video extraction and stitching.\n\nPure Python, no Qt deps. Provides:\n- find_ffmpeg() / find_ffprobe() — locate binaries\n- probe_video() — get fps, resolution, frame count, codec\n- extract_frames() — video -> image sequence (PNG)\n- stitch_video() — image sequence -> video (H.264)\n- write/read_video_metadata() — sidecar JSON for roundtrip fidelity\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport logging\nimport os\nimport re\nimport shutil\nimport subprocess\nimport threading\nfrom typing import Callable\n\nlogger = logging.getLogger(__name__)\n\n_METADATA_FILENAME = \".video_metadata.json\"\n\n# Common install locations on Windows\n_FFMPEG_SEARCH_PATHS = [\n    r\"C:\\Program Files\\ffmpeg\\bin\",\n    r\"C:\\Program Files (x86)\\ffmpeg\\bin\",\n    r\"C:\\ffmpeg\\bin\",\n]\n\n\ndef find_ffmpeg() -> str | None:\n    \"\"\"Locate ffmpeg binary. Checks PATH then common install dirs.\"\"\"\n    found = shutil.which(\"ffmpeg\")\n    if found:\n        return found\n    for d in _FFMPEG_SEARCH_PATHS:\n        candidate = os.path.join(d, \"ffmpeg.exe\")\n        if os.path.isfile(candidate):\n            return candidate\n    return None\n\n\ndef find_ffprobe() -> str | None:\n    \"\"\"Locate ffprobe binary. 
Checks PATH then common install dirs.\"\"\"\n    found = shutil.which(\"ffprobe\")\n    if found:\n        return found\n    for d in _FFMPEG_SEARCH_PATHS:\n        candidate = os.path.join(d, \"ffprobe.exe\")\n        if os.path.isfile(candidate):\n            return candidate\n    return None\n\n\ndef probe_video(path: str) -> dict:\n    \"\"\"Probe a video file for metadata.\n\n    Returns dict with keys: fps (float), width (int), height (int),\n    frame_count (int), codec (str), duration (float).\n    Raises RuntimeError if ffprobe fails.\n    \"\"\"\n    ffprobe = find_ffprobe()\n    if not ffprobe:\n        raise RuntimeError(\"ffprobe not found\")\n\n    cmd = [\n        ffprobe,\n        \"-v\",\n        \"quiet\",\n        \"-print_format\",\n        \"json\",\n        \"-show_streams\",\n        \"-show_format\",\n        path,\n    ]\n\n    result = subprocess.run(\n        cmd,\n        capture_output=True,\n        text=True,\n        timeout=30,\n        creationflags=subprocess.CREATE_NO_WINDOW if os.name == \"nt\" else 0,\n    )\n    if result.returncode != 0:\n        raise RuntimeError(f\"ffprobe failed: {result.stderr[:500]}\")\n\n    data = json.loads(result.stdout)\n\n    # Find first video stream\n    video_stream = None\n    for stream in data.get(\"streams\", []):\n        if stream.get(\"codec_type\") == \"video\":\n            video_stream = stream\n            break\n\n    if not video_stream:\n        raise RuntimeError(f\"No video stream found in {path}\")\n\n    # Parse fps from r_frame_rate (e.g. 
\"24000/1001\")\n    fps_str = video_stream.get(\"r_frame_rate\", \"24/1\")\n    if \"/\" in fps_str:\n        num, den = fps_str.split(\"/\")\n        fps = float(num) / float(den) if float(den) != 0 else 24.0\n    else:\n        fps = float(fps_str)\n\n    # Frame count: prefer nb_frames, fall back to duration * fps\n    frame_count = 0\n    if \"nb_frames\" in video_stream:\n        try:\n            frame_count = int(video_stream[\"nb_frames\"])\n        except (ValueError, TypeError):\n            pass\n\n    if frame_count <= 0:\n        duration = float(video_stream.get(\"duration\", 0) or data.get(\"format\", {}).get(\"duration\", 0))\n        if duration > 0:\n            frame_count = int(duration * fps)\n\n    return {\n        \"fps\": round(fps, 4),\n        \"width\": int(video_stream.get(\"width\", 0)),\n        \"height\": int(video_stream.get(\"height\", 0)),\n        \"frame_count\": frame_count,\n        \"codec\": video_stream.get(\"codec_name\", \"unknown\"),\n        \"duration\": float(video_stream.get(\"duration\", 0) or data.get(\"format\", {}).get(\"duration\", 0)),\n    }\n\n\ndef extract_frames(\n    video_path: str,\n    out_dir: str,\n    pattern: str = \"frame_%06d.png\",\n    on_progress: Callable[[int, int], None] | None = None,\n    cancel_event: threading.Event | None = None,\n    total_frames: int = 0,\n) -> int:\n    \"\"\"Extract video frames to PNG image sequence.\n\n    Args:\n        video_path: Path to input video file.\n        out_dir: Directory to write frames into (created if needed).\n        pattern: Frame filename pattern (FFmpeg style).\n        on_progress: Callback(current_frame, total_frames).\n        cancel_event: Set to cancel extraction.\n        total_frames: Expected total (for progress). 
Probed if 0.\n\n    Returns:\n        Number of frames extracted.\n\n    Raises:\n        RuntimeError if ffmpeg is not found or extraction fails.\n    \"\"\"\n    ffmpeg = find_ffmpeg()\n    if not ffmpeg:\n        raise RuntimeError(\"ffmpeg not found\")\n\n    os.makedirs(out_dir, exist_ok=True)\n\n    # Probe for total if not provided\n    video_info = None\n    if total_frames <= 0:\n        try:\n            video_info = probe_video(video_path)\n            total_frames = video_info.get(\"frame_count\", 0)\n        except Exception:\n            total_frames = 0\n\n    # Resume: detect existing frames and skip ahead with conservative rollback.\n    # Delete the last few frames (may be corrupt from mid-write or FFmpeg\n    # output buffering) and re-extract from that point.\n    _RESUME_ROLLBACK = 3  # frames to re-extract for safety\n    start_frame = 0\n    existing = sorted([f for f in os.listdir(out_dir) if f.lower().endswith(\".png\")])\n    if existing:\n        # Remove the last N frames — they may be corrupt or incomplete\n        remove_count = min(_RESUME_ROLLBACK, len(existing))\n        for fname in existing[-remove_count:]:\n            os.remove(os.path.join(out_dir, fname))\n        start_frame = max(0, len(existing) - remove_count)\n        if start_frame > 0:\n            logger.info(\n                f\"Resuming extraction from frame {start_frame} ({len(existing)} existed, rolled back {remove_count})\"\n            )\n\n    if start_frame > 0 and total_frames > 0:\n        # Seek to the resume point\n        if video_info is None:\n            video_info = probe_video(video_path)\n        fps = video_info.get(\"fps\", 24.0)\n        seek_sec = start_frame / fps\n        cmd = [\n            ffmpeg,\n            \"-ss\",\n            f\"{seek_sec:.4f}\",\n            \"-i\",\n            video_path,\n            \"-start_number\",\n            str(start_frame),\n            \"-vsync\",\n            \"passthrough\",\n            
os.path.join(out_dir, pattern),\n            \"-y\",\n        ]\n    else:\n        cmd = [\n            ffmpeg,\n            \"-i\",\n            video_path,\n            \"-start_number\",\n            \"0\",\n            \"-vsync\",\n            \"passthrough\",\n            os.path.join(out_dir, pattern),\n            \"-y\",\n        ]\n\n    logger.info(f\"Extracting frames: {video_path} -> {out_dir} (start_frame={start_frame})\")\n\n    proc = subprocess.Popen(\n        cmd,\n        stdin=subprocess.PIPE,\n        stderr=subprocess.PIPE,\n        stdout=subprocess.DEVNULL,\n        text=True,\n        creationflags=subprocess.CREATE_NO_WINDOW if os.name == \"nt\" else 0,\n    )\n\n    last_frame = start_frame\n    frame_re = re.compile(r\"frame=\\s*(\\d+)\")\n\n    # Read stderr in a background thread so cancel checks aren't blocked\n    import queue as _queue\n\n    line_q: _queue.Queue[str | None] = _queue.Queue()\n\n    def _reader():\n        for ln in proc.stderr:\n            line_q.put(ln)\n        line_q.put(None)  # sentinel\n\n    reader_thread = threading.Thread(target=_reader, daemon=True)\n    reader_thread.start()\n\n    try:\n        while True:\n            # Check cancellation every 0.2s even if no output\n            if cancel_event and cancel_event.is_set():\n                proc.kill()\n                try:\n                    proc.wait(timeout=5)\n                except subprocess.TimeoutExpired:\n                    pass\n                logger.info(\"Extraction cancelled — FFmpeg killed\")\n                return last_frame\n\n            try:\n                line = line_q.get(timeout=0.2)\n            except _queue.Empty:\n                # No output yet — check if process is still alive\n                if proc.poll() is not None:\n                    break\n                continue\n\n            if line is None:\n                break  # stderr closed — process ending\n\n            match = frame_re.search(line)\n            if 
match:\n                last_frame = start_frame + int(match.group(1))\n                if on_progress and total_frames > 0:\n                    on_progress(last_frame, total_frames)\n\n        proc.wait(timeout=30)\n    except subprocess.TimeoutExpired:\n        proc.kill()\n        raise RuntimeError(\"FFmpeg extraction timed out\") from None\n\n    if proc.returncode != 0 and not (cancel_event and cancel_event.is_set()):\n        raise RuntimeError(f\"FFmpeg extraction failed with code {proc.returncode}\")\n\n    # Count actual extracted frames\n    extracted = len([f for f in os.listdir(out_dir) if f.lower().endswith(\".png\")])\n    logger.info(f\"Extracted {extracted} frames to {out_dir}\")\n    return extracted\n\n\ndef stitch_video(\n    in_dir: str,\n    out_path: str,\n    fps: float = 24.0,\n    pattern: str = \"frame_%06d.png\",\n    codec: str = \"libx264\",\n    crf: int = 18,\n    on_progress: Callable[[int, int], None] | None = None,\n    cancel_event: threading.Event | None = None,\n) -> None:\n    \"\"\"Stitch image sequence back into a video file.\n\n    Args:\n        in_dir: Directory containing frame images.\n        out_path: Output video file path.\n        fps: Frame rate.\n        pattern: Frame filename pattern.\n        codec: Video codec (libx264, libx265, etc.).\n        crf: Quality (0-51, lower = better).\n        on_progress: Callback(current_frame, total_frames).\n        cancel_event: Set to cancel stitching.\n\n    Raises:\n        RuntimeError if ffmpeg is not found or stitching fails.\n    \"\"\"\n    ffmpeg = find_ffmpeg()\n    if not ffmpeg:\n        raise RuntimeError(\"ffmpeg not found\")\n\n    # Count total frames\n    total_frames = len([f for f in os.listdir(in_dir) if f.lower().endswith((\".png\", \".jpg\", \".jpeg\", \".exr\"))])\n\n    cmd = [\n        ffmpeg,\n        \"-framerate\",\n        str(fps),\n        \"-start_number\",\n        \"0\",\n        \"-i\",\n        os.path.join(in_dir, pattern),\n        
\"-c:v\",\n        codec,\n        \"-crf\",\n        str(crf),\n        \"-pix_fmt\",\n        \"yuv420p\",\n        out_path,\n        \"-y\",\n    ]\n\n    logger.info(f\"Stitching video: {in_dir} -> {out_path}\")\n\n    proc = subprocess.Popen(\n        cmd,\n        stdin=subprocess.PIPE,\n        stderr=subprocess.PIPE,\n        stdout=subprocess.DEVNULL,\n        text=True,\n        creationflags=subprocess.CREATE_NO_WINDOW if os.name == \"nt\" else 0,\n    )\n\n    frame_re = re.compile(r\"frame=\\s*(\\d+)\")\n\n    try:\n        for line in proc.stderr:\n            if cancel_event and cancel_event.is_set():\n                try:\n                    proc.stdin.write(\"q\\n\")\n                    proc.stdin.flush()\n                except Exception:\n                    pass\n                proc.wait(timeout=5)\n                logger.info(\"Stitching cancelled\")\n                return\n\n            match = frame_re.search(line)\n            if match:\n                current = int(match.group(1))\n                if on_progress and total_frames > 0:\n                    on_progress(current, total_frames)\n\n        proc.wait(timeout=60)\n    except subprocess.TimeoutExpired:\n        proc.kill()\n        raise RuntimeError(\"FFmpeg stitching timed out\") from None\n\n    if proc.returncode != 0 and not (cancel_event and cancel_event.is_set()):\n        raise RuntimeError(f\"FFmpeg stitching failed with code {proc.returncode}\")\n\n    logger.info(f\"Video stitched: {out_path}\")\n\n\ndef write_video_metadata(clip_root: str, metadata: dict) -> None:\n    \"\"\"Write video metadata sidecar JSON to clip root.\n\n    Metadata typically includes: source_path, fps, width, height,\n    frame_count, codec, duration.\n    \"\"\"\n    path = os.path.join(clip_root, _METADATA_FILENAME)\n    with open(path, \"w\") as f:\n        json.dump(metadata, f, indent=2)\n    logger.debug(f\"Video metadata written: {path}\")\n\n\ndef read_video_metadata(clip_root: str) -> 
dict | None:\n    \"\"\"Read video metadata sidecar from clip root. Returns None if not found.\"\"\"\n    path = os.path.join(clip_root, _METADATA_FILENAME)\n    if not os.path.isfile(path):\n        return None\n    try:\n        with open(path, \"r\") as f:\n            return json.load(f)\n    except (json.JSONDecodeError, OSError) as e:\n        logger.debug(f\"Failed to read video metadata: {e}\")\n        return None\n"
  },
  {
    "path": "backend/frame_io.py",
    "content": "\"\"\"Unified frame I/O — read images and video frames as float32 RGB.\n\nAll reading functions return float32 arrays in [0, 1] range with RGB channel\norder. EXR files are read as-is (linear float); standard formats (PNG, JPG,\netc.) are normalized from uint8.\n\nThis module consolidates frame-reading patterns that were previously duplicated\nacross service.py methods (_read_input_frame, reprocess_single_frame,\n_load_frames_for_videomama, _load_mask_frames_for_videomama).\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nfrom typing import Callable\n\nimport cv2\nimport numpy as np\n\nfrom CorridorKeyModule.core.color_utils import linear_to_srgb\n\nfrom .validators import normalize_mask_channels, normalize_mask_dtype\n\nlogger = logging.getLogger(__name__)\n\n# EXR write flags — PXR24 half-float (smallest working compression)\nEXR_WRITE_FLAGS = [\n    cv2.IMWRITE_EXR_TYPE,\n    cv2.IMWRITE_EXR_TYPE_HALF,\n    cv2.IMWRITE_EXR_COMPRESSION,\n    cv2.IMWRITE_EXR_COMPRESSION_PXR24,\n]\n\n\ndef read_image_frame(fpath: str, gamma_correct_exr: bool = False) -> np.ndarray | None:\n    \"\"\"Read an image file (EXR or standard) as float32 RGB [0, 1].\n\n    Args:\n        fpath: Absolute path to image file.\n        gamma_correct_exr: If True, apply piecewise sRGB transfer function\n            to EXR data (converts linear → sRGB for models expecting sRGB).\n\n    Returns:\n        float32 array [H, W, 3] in RGB order, or None if read fails.\n    \"\"\"\n    is_exr = fpath.lower().endswith(\".exr\")\n\n    if is_exr:\n        img = cv2.imread(fpath, cv2.IMREAD_UNCHANGED)\n        if img is None:\n            logger.warning(\"Could not read frame: %s\", fpath)\n            return None\n        # Strip alpha channel from BGRA EXR\n        if img.ndim == 3 and img.shape[2] == 4:\n            img = img[:, :, :3]\n        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n        result = np.maximum(img_rgb, 0.0).astype(np.float32)\n        if 
gamma_correct_exr:\n            result = linear_to_srgb(result).astype(np.float32)\n        return result\n    else:\n        img = cv2.imread(fpath)\n        if img is None:\n            logger.warning(\"Could not read frame: %s\", fpath)\n            return None\n        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n        return img_rgb.astype(np.float32) / 255.0\n\n\ndef read_video_frame_at(\n    video_path: str,\n    frame_index: int,\n) -> np.ndarray | None:\n    \"\"\"Read a single frame from a video by index, as float32 RGB [0, 1].\n\n    Args:\n        video_path: Path to video file.\n        frame_index: Zero-based frame index to seek to.\n\n    Returns:\n        float32 array [H, W, 3] in RGB order, or None if seek/read fails.\n    \"\"\"\n    if frame_index < 0:\n        logger.warning(\"Invalid frame_index %d (must be >= 0)\", frame_index)\n        return None\n    cap = cv2.VideoCapture(video_path)\n    try:\n        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)\n        ret, frame = cap.read()\n        if not ret:\n            logger.warning(\"Could not read video frame %d from: %s\", frame_index, video_path)\n            return None\n        return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0\n    finally:\n        cap.release()\n\n\ndef read_video_frames(\n    video_path: str,\n    processor: Callable[[np.ndarray], np.ndarray] | None = None,\n) -> list[np.ndarray]:\n    \"\"\"Read all frames from a video, optionally applying a processor to each.\n\n    Without a processor, frames are returned as float32 RGB [0, 1].\n\n    Args:\n        video_path: Path to video file.\n        processor: Optional callable (BGR uint8 frame) → processed array.\n            If None, default conversion to float32 RGB [0, 1] is applied.\n\n    Returns:\n        List of processed frames.\n    \"\"\"\n    frames: list[np.ndarray] = []\n    cap = cv2.VideoCapture(video_path)\n    try:\n        while True:\n            ret, frame = cap.read()\n       
     if not ret:\n                break\n            if processor is not None:\n                frames.append(processor(frame))\n            else:\n                img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0\n                frames.append(img_rgb)\n    finally:\n        cap.release()\n    return frames\n\n\ndef read_mask_frame(fpath: str, clip_name: str = \"\", frame_index: int = 0) -> np.ndarray | None:\n    \"\"\"Read a mask frame as float32 [H, W] in [0, 1].\n\n    Handles any channel count and dtype via normalize_mask_channels/dtype.\n\n    Args:\n        fpath: Path to mask image.\n        clip_name: For error context in normalization.\n        frame_index: For error context in normalization.\n\n    Returns:\n        float32 array [H, W] in [0, 1], or None if read fails.\n    \"\"\"\n    mask_in = cv2.imread(fpath, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_UNCHANGED)\n    if mask_in is None:\n        return None\n    # dtype normalization MUST happen before channel extraction, because\n    # normalize_mask_channels casts to float32 — which would make a uint8\n    # 255 into float32 255.0, skipping the /255 division in normalize_mask_dtype.\n    mask = normalize_mask_dtype(mask_in)\n    mask = normalize_mask_channels(mask, clip_name, frame_index)\n    return mask\n\n\ndef read_video_mask_at(\n    video_path: str,\n    frame_index: int,\n) -> np.ndarray | None:\n    \"\"\"Read a single mask frame from a video by index, as float32 [H, W] [0, 1].\n\n    Extracts channel index 2 of the BGR frame (i.e. the red channel), matching\n    the convention used by alpha-channel video masks.\n\n    Args:\n        video_path: Path to video file.\n        frame_index: Zero-based frame index.\n\n    Returns:\n        float32 array [H, W] in [0, 1], or None if seek/read fails.\n    \"\"\"\n    if frame_index < 0:\n        logger.warning(\"Invalid frame_index %d (must be >= 0)\", frame_index)\n        return None\n    cap = cv2.VideoCapture(video_path)\n    try:\n        
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)\n        ret, frame = cap.read()\n        if not ret:\n            return None\n        return frame[:, :, 2].astype(np.float32) / 255.0\n    finally:\n        cap.release()\n"
  },
  {
    "path": "backend/job_queue.py",
    "content": "\"\"\"GPU job queue with mutual exclusion.\n\nEnsures only ONE GPU job runs at a time across all job types\n(inference, GVM alpha gen, VideoMaMa alpha gen). This prevents VRAM\ncontention — CorridorKey alone needs ~22.7GB of 24GB.\n\nDesign:\n    - Thread-safe queue of GPUJob dataclasses\n    - Single consumer loop (designed to be driven by a QThread in the UI,\n      or called directly in CLI mode)\n    - Jobs carry a cancel flag checked between frames\n    - Callbacks for progress, warnings, completion, errors\n    - Jobs have stable IDs assigned at creation time\n    - Deduplication prevents double-submit of same clip+job_type\n    - Job history preserved for UI display (cancelled/completed/failed)\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nimport threading\nimport uuid\nfrom collections import deque\nfrom dataclasses import dataclass, field\nfrom enum import Enum\nfrom typing import Any, Callable\n\nfrom .errors import JobCancelledError\n\nlogger = logging.getLogger(__name__)\n\n\nclass JobType(Enum):\n    INFERENCE = \"inference\"\n    GVM_ALPHA = \"gvm_alpha\"\n    VIDEOMAMA_ALPHA = \"videomama_alpha\"\n    PREVIEW_REPROCESS = \"preview_reprocess\"\n    VIDEO_EXTRACT = \"video_extract\"\n    VIDEO_STITCH = \"video_stitch\"\n\n\nclass JobStatus(Enum):\n    QUEUED = \"queued\"\n    RUNNING = \"running\"\n    COMPLETED = \"completed\"\n    CANCELLED = \"cancelled\"\n    FAILED = \"failed\"\n\n\n@dataclass\nclass GPUJob:\n    \"\"\"A single GPU job to be executed.\"\"\"\n\n    job_type: JobType\n    clip_name: str\n    id: str = field(default_factory=lambda: uuid.uuid4().hex[:8])\n    params: dict[str, Any] = field(default_factory=dict)\n    status: JobStatus = JobStatus.QUEUED\n    _cancel_requested: bool = field(default=False, repr=False)\n    error_message: str | None = None\n\n    # Progress tracking\n    current_frame: int = 0\n    total_frames: int = 0\n\n    def request_cancel(self) -> None:\n        \"\"\"Signal that 
this job should stop at the next frame boundary.\"\"\"\n        self._cancel_requested = True\n\n    @property\n    def is_cancelled(self) -> bool:\n        return self._cancel_requested\n\n    def check_cancelled(self) -> None:\n        \"\"\"Raise JobCancelledError if cancel was requested. Call between frames.\"\"\"\n        if self._cancel_requested:\n            raise JobCancelledError(self.clip_name, self.current_frame)\n\n\n# Callback type aliases\nProgressCallback = Callable[[str, int, int], None]  # clip_name, current, total\nWarningCallback = Callable[[str], None]  # message\nCompletionCallback = Callable[[str], None]  # clip_name\nErrorCallback = Callable[[str, str], None]  # clip_name, error_message\n\n\nclass GPUJobQueue:\n    \"\"\"Thread-safe GPU job queue with mutual exclusion.\n\n    Usage (CLI mode):\n        queue = GPUJobQueue()\n        queue.submit(GPUJob(JobType.INFERENCE, \"shot1\", params={...}))\n        queue.submit(GPUJob(JobType.GVM_ALPHA, \"shot2\", params={...}))\n\n        # Process all jobs sequentially\n        while queue.has_pending:\n            job = queue.next_job()\n            if job:\n                queue.start_job(job)\n                try:\n                    run_the_job(job)  # your processing function\n                    queue.complete_job(job)\n                except Exception as e:\n                    queue.fail_job(job, str(e))\n\n    Usage (GUI mode):\n        The GPU worker QThread calls next_job() / start_job() / complete_job()\n        in its run loop. 
The UI submits jobs from the main thread.\n    \"\"\"\n\n    def __init__(self):\n        self._queue: deque[GPUJob] = deque()\n        self._lock = threading.Lock()\n        self._current_job: GPUJob | None = None\n        self._history: list[GPUJob] = []  # completed/cancelled/failed jobs for UI display\n\n        # Callbacks (set by UI or CLI)\n        self.on_progress: ProgressCallback | None = None\n        self.on_warning: WarningCallback | None = None\n        self.on_completion: CompletionCallback | None = None\n        self.on_error: ErrorCallback | None = None\n\n    def submit(self, job: GPUJob) -> bool:\n        \"\"\"Add a job to the queue. Returns False if duplicate detected.\n\n        PREVIEW_REPROCESS uses replacement semantics — any existing preview\n        reprocess in the queue is replaced by the new one (latest-only).\n        \"\"\"\n        with self._lock:\n            # PREVIEW_REPROCESS: replace existing queued preview jobs (latest-only)\n            if job.job_type == JobType.PREVIEW_REPROCESS:\n                replaced = [j for j in self._queue if j.job_type == JobType.PREVIEW_REPROCESS]\n                for old in replaced:\n                    self._queue.remove(old)\n                    old.status = JobStatus.CANCELLED\n                    logger.debug(f\"Preview reprocess [{old.id}] replaced by [{job.id}]\")\n            else:\n                # Deduplication: reject if same clip+job_type already queued or running\n                for existing in self._queue:\n                    if existing.clip_name == job.clip_name and existing.job_type == job.job_type:\n                        logger.warning(\n                            f\"Duplicate job rejected: {job.job_type.value} for '{job.clip_name}' \"\n                            f\"(already queued as {existing.id})\"\n                        )\n                        return False\n                if (\n                    self._current_job\n                    and 
self._current_job.clip_name == job.clip_name\n                    and self._current_job.job_type == job.job_type\n                    and self._current_job.status == JobStatus.RUNNING\n                ):\n                    logger.warning(\n                        f\"Duplicate job rejected: {job.job_type.value} for '{job.clip_name}' \"\n                        f\"(already running as {self._current_job.id})\"\n                    )\n                    return False\n\n            job.status = JobStatus.QUEUED\n            self._queue.append(job)\n            logger.info(f\"Job queued [{job.id}]: {job.job_type.value} for '{job.clip_name}'\")\n            return True\n\n    def next_job(self) -> GPUJob | None:\n        \"\"\"Get the next pending job without starting it. Returns None if empty.\"\"\"\n        with self._lock:\n            if self._queue:\n                return self._queue[0]\n            return None\n\n    def start_job(self, job: GPUJob) -> None:\n        \"\"\"Mark a job as running. 
Must be called before processing.\"\"\"\n        with self._lock:\n            if job in self._queue:\n                self._queue.remove(job)\n            job.status = JobStatus.RUNNING\n            self._current_job = job\n            logger.info(f\"Job started [{job.id}]: {job.job_type.value} for '{job.clip_name}'\")\n\n    def complete_job(self, job: GPUJob) -> None:\n        \"\"\"Mark a job as successfully completed.\"\"\"\n        with self._lock:\n            job.status = JobStatus.COMPLETED\n            if self._current_job is job:\n                self._current_job = None\n            self._history.append(job)\n            logger.info(f\"Job completed [{job.id}]: {job.job_type.value} for '{job.clip_name}'\")\n        # Emit AFTER lock release (Codex: no deadlock risk)\n        if self.on_completion:\n            self.on_completion(job.clip_name)\n\n    def fail_job(self, job: GPUJob, error: str) -> None:\n        \"\"\"Mark a job as failed.\"\"\"\n        with self._lock:\n            job.status = JobStatus.FAILED\n            job.error_message = error\n            if self._current_job is job:\n                self._current_job = None\n            self._history.append(job)\n            logger.error(f\"Job failed [{job.id}]: {job.job_type.value} for '{job.clip_name}': {error}\")\n        # Emit AFTER lock release\n        if self.on_error:\n            self.on_error(job.clip_name, error)\n\n    def mark_cancelled(self, job: GPUJob) -> None:\n        \"\"\"Mark a running job as cancelled AND clear _current_job.\n\n        This is the cancel-safe path that was missing — calling\n        job.request_cancel() alone doesn't clear _current_job, which\n        poisons queue state for subsequent jobs.\n        \"\"\"\n        with self._lock:\n            job.status = JobStatus.CANCELLED\n            if self._current_job is job:\n                self._current_job = None\n            self._history.append(job)\n            logger.info(f\"Job cancelled [{job.id}]: 
{job.job_type.value} for '{job.clip_name}'\")\n\n    def cancel_job(self, job: GPUJob) -> None:\n        \"\"\"Request cancellation of a specific job.\"\"\"\n        with self._lock:\n            if job.status == JobStatus.QUEUED:\n                if job in self._queue:\n                    self._queue.remove(job)\n                job.status = JobStatus.CANCELLED\n                self._history.append(job)\n                logger.info(f\"Job removed from queue [{job.id}]: {job.job_type.value} for '{job.clip_name}'\")\n            elif job.status == JobStatus.RUNNING:\n                # Signal cancel — worker calls mark_cancelled() after catching JobCancelledError\n                job.request_cancel()\n                logger.info(f\"Job cancel requested [{job.id}]: {job.job_type.value} for '{job.clip_name}'\")\n\n    def cancel_current(self) -> None:\n        \"\"\"Cancel the currently running job, if any.\"\"\"\n        with self._lock:\n            if self._current_job and self._current_job.status == JobStatus.RUNNING:\n                self._current_job.request_cancel()\n\n    def cancel_all(self) -> None:\n        \"\"\"Cancel current job and clear the queue.\"\"\"\n        with self._lock:\n            # Cancel current\n            if self._current_job and self._current_job.status == JobStatus.RUNNING:\n                self._current_job.request_cancel()\n            # Clear queue — preserve in history\n            for job in self._queue:\n                job.status = JobStatus.CANCELLED\n                self._history.append(job)\n            self._queue.clear()\n            logger.info(\"All jobs cancelled\")\n\n    def report_progress(self, clip_name: str, current: int, total: int) -> None:\n        \"\"\"Report progress for the current job. 
Called by processing code.\"\"\"\n        if self._current_job:\n            self._current_job.current_frame = current\n            self._current_job.total_frames = total\n        if self.on_progress:\n            self.on_progress(clip_name, current, total)\n\n    def report_warning(self, message: str) -> None:\n        \"\"\"Report a non-fatal warning. Called by processing code.\"\"\"\n        logger.warning(message)\n        if self.on_warning:\n            self.on_warning(message)\n\n    def find_job_by_id(self, job_id: str) -> GPUJob | None:\n        \"\"\"Find a job by ID in queue, current, or history.\"\"\"\n        with self._lock:\n            if self._current_job and self._current_job.id == job_id:\n                return self._current_job\n            for job in self._queue:\n                if job.id == job_id:\n                    return job\n            for job in self._history:\n                if job.id == job_id:\n                    return job\n        return None\n\n    def clear_history(self) -> None:\n        \"\"\"Clear job history (for UI reset).\"\"\"\n        with self._lock:\n            self._history.clear()\n\n    def remove_job(self, job_id: str) -> None:\n        \"\"\"Remove a single finished job from history.\"\"\"\n        with self._lock:\n            self._history = [j for j in self._history if j.id != job_id]\n\n    @property\n    def has_pending(self) -> bool:\n        with self._lock:\n            return len(self._queue) > 0\n\n    @property\n    def current_job(self) -> GPUJob | None:\n        with self._lock:\n            return self._current_job\n\n    @property\n    def pending_count(self) -> int:\n        with self._lock:\n            return len(self._queue)\n\n    @property\n    def queue_snapshot(self) -> list[GPUJob]:\n        \"\"\"Return a copy of the current queue for display purposes.\"\"\"\n        with self._lock:\n            return list(self._queue)\n\n    @property\n    def history_snapshot(self) -> 
list[GPUJob]:\n        \"\"\"Return a copy of job history for display purposes.\"\"\"\n        with self._lock:\n            return list(self._history)\n\n    @property\n    def all_jobs_snapshot(self) -> list[GPUJob]:\n        \"\"\"Return current + queued + history for full queue panel display.\"\"\"\n        with self._lock:\n            result = []\n            if self._current_job:\n                result.append(self._current_job)\n            result.extend(self._queue)\n            result.extend(self._history)\n            return result\n"
  },
  {
    "path": "backend/natural_sort.py",
    "content": "\"\"\"Natural sort key for frame filenames.\n\nHandles non-zero-padded frame numbers correctly:\n  frame_1, frame_2, frame_10  (not frame_1, frame_10, frame_2)\n\nNo external dependency — pure Python implementation.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport re\n\n_SPLIT_RE = re.compile(r\"(\\d+)\")\n\n\ndef natural_sort_key(text: str) -> list[str | int]:\n    \"\"\"Return a sort key that orders numeric substrings numerically.\n\n    >>> sorted(['f_1', 'f_10', 'f_2'], key=natural_sort_key)\n    ['f_1', 'f_2', 'f_10']\n    \"\"\"\n    parts: list[str | int] = []\n    for chunk in _SPLIT_RE.split(text):\n        if chunk.isdigit():\n            parts.append(int(chunk))\n        else:\n            parts.append(chunk.lower())\n    return parts\n\n\ndef natsorted(items: list[str]) -> list[str]:\n    \"\"\"Return a naturally sorted copy of a list of strings.\"\"\"\n    return sorted(items, key=natural_sort_key)\n"
  },
  {
    "path": "backend/project.py",
    "content": "\"\"\"Project folder management — creation, scanning, and metadata.\n\nA project is a timestamped container holding one or more clips:\n    Projects/\n        260301_093000_Woman_Jumps/\n            project.json                    (v2 — project-level metadata)\n            clips/\n                Woman_Jumps/                (ClipEntry.root_path → here)\n                    Source/\n                        Woman_Jumps_For_Joy.mp4\n                    Frames/\n                    AlphaHint/\n                    Output/FG/ Matte/ Comp/ Processed/\n                    clip.json               (per-clip metadata)\n                Man_Walks/\n                    Source/...\n\nLegacy v1 format (no clips/ dir) is still supported for backward compat.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport logging\nimport os\nimport re\nimport shutil\nimport sys\nfrom datetime import datetime\n\nlogger = logging.getLogger(__name__)\n\n_VIDEO_EXTS = frozenset({\".mp4\", \".mov\", \".avi\", \".mkv\", \".mxf\", \".webm\", \".m4v\"})\n_IMAGE_EXTS = frozenset({\".png\", \".jpg\", \".jpeg\", \".exr\", \".tif\", \".tiff\", \".bmp\", \".dpx\"})\nVIDEO_FILE_FILTER = \"Video Files (*.mp4 *.mov *.avi *.mkv *.mxf *.webm *.m4v);;All Files (*)\"\n\n_app_dir: str | None = None\n\n\ndef _dedupe_path(parent_dir: str, stem: str) -> tuple[str, str]:\n    \"\"\"Return a unique child path under *parent_dir* and its final stem.\n\n    If ``{parent_dir}/{stem}`` already exists, appends numeric suffixes\n    (``_2``, ``_3``, ...) 
until a free path is found.\n\n    Unlike fixed-range probes, this never silently falls back to an existing\n    path after enough collisions.\n    \"\"\"\n    path = os.path.join(parent_dir, stem)\n    if not os.path.exists(path):\n        return path, stem\n\n    index = 2\n    while True:\n        candidate_stem = f\"{stem}_{index}\"\n        candidate_path = os.path.join(parent_dir, candidate_stem)\n        if not os.path.exists(candidate_path):\n            return candidate_path, candidate_stem\n        index += 1\n\n\ndef set_app_dir(path: str) -> None:\n    \"\"\"Set the application directory. Called once at startup by main.py.\"\"\"\n    global _app_dir\n    _app_dir = path\n\n\ndef projects_root() -> str:\n    \"\"\"Return the Projects root directory, creating it if needed.\n\n    In dev mode: {repo_root}/Projects/\n    In frozen mode: {exe_dir}/Projects/\n    \"\"\"\n    if _app_dir:\n        root = os.path.join(_app_dir, \"Projects\")\n    elif getattr(sys, \"frozen\", False):\n        root = os.path.join(os.path.dirname(sys.executable), \"Projects\")\n    else:\n        # Fallback: two levels up from this file (backend/ -> repo root)\n        root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), \"Projects\")\n    os.makedirs(root, exist_ok=True)\n    return root\n\n\ndef sanitize_stem(filename: str, max_len: int = 60) -> str:\n    \"\"\"Clean a filename stem for use in folder names.\n\n    Strips extension, replaces non-alphanumeric chars with underscores,\n    collapses runs, and truncates.\n    \"\"\"\n    stem = os.path.splitext(filename)[0]\n    stem = re.sub(r\"[^\\w\\-]\", \"_\", stem)\n    stem = re.sub(r\"_+\", \"_\", stem).strip(\"_\")\n    return stem[:max_len]\n\n\ndef create_project(\n    source_video_paths: str | list[str],\n    *,\n    copy_source: bool = True,\n    display_name: str | None = None,\n) -> str:\n    \"\"\"Create a new project folder for one or more source videos.\n\n    Creates a v2 project with a 
``clips/`` subdirectory.  Each video\n    gets its own clip subfolder inside ``clips/``.\n\n    When *copy_source* is True (default), video files are copied into\n    each clip's ``Source/`` directory.  When False, the clip stores a\n    reference to the original file path.\n\n    Creates: Projects/YYMMDD_HHMMSS_{stem}/clips/{clip_stem}/Source/...\n\n    Args:\n        source_video_paths: Single video path (str) or list of paths.\n        copy_source: Whether to copy video files into clip folders.\n        display_name: Optional project name. If provided, used for both\n            the folder name stem and display_name in project.json.\n            If None, derived from the first video filename.\n\n    Returns:\n        Absolute path to the new project folder.\n    \"\"\"\n    # Accept single path for backward compat\n    if isinstance(source_video_paths, str):\n        source_video_paths = [source_video_paths]\n    if not source_video_paths:\n        raise ValueError(\"At least one source video path is required\")\n\n    root = projects_root()\n\n    if display_name and display_name.strip():\n        clean = display_name.strip()\n        # Sanitize for folder name (no splitext — it's not a filename)\n        name_stem = re.sub(r\"[^\\w\\-]\", \"_\", clean)\n        name_stem = re.sub(r\"_+\", \"_\", name_stem).strip(\"_\")[:60]\n        project_display_name = clean\n    else:\n        first_filename = os.path.basename(source_video_paths[0])\n        name_stem = sanitize_stem(first_filename)\n        project_display_name = name_stem.replace(\"_\", \" \")\n\n    timestamp = datetime.now().strftime(\"%y%m%d_%H%M%S\")\n    folder_name = f\"{timestamp}_{name_stem}\"\n\n    # Deduplicate if folder already exists (e.g. 
rapid imports)\n    project_dir, _ = _dedupe_path(root, folder_name)\n\n    clips_dir = os.path.join(project_dir, \"clips\")\n    os.makedirs(clips_dir, exist_ok=True)\n\n    clip_names: list[str] = []\n    for video_path in source_video_paths:\n        clip_name = _create_clip_folder(\n            clips_dir,\n            video_path,\n            copy_source=copy_source,\n        )\n        clip_names.append(clip_name)\n\n    # Write project.json (v2 — project-level metadata only)\n    write_project_json(\n        project_dir,\n        {\n            \"version\": 2,\n            \"created\": datetime.now().isoformat(),\n            \"display_name\": project_display_name,\n            \"clips\": clip_names,\n        },\n    )\n\n    return project_dir\n\n\ndef add_clips_to_project(\n    project_dir: str,\n    source_video_paths: list[str],\n    *,\n    copy_source: bool = True,\n) -> list[str]:\n    \"\"\"Add new clips to an existing project.\n\n    Args:\n        project_dir: Absolute path to the project folder.\n        source_video_paths: List of video file paths to add.\n        copy_source: Whether to copy videos into clip folders.\n\n    Returns:\n        List of new clip subfolder paths (absolute).\n    \"\"\"\n    clips_dir = os.path.join(project_dir, \"clips\")\n    os.makedirs(clips_dir, exist_ok=True)\n\n    new_paths: list[str] = []\n    for video_path in source_video_paths:\n        clip_name = _create_clip_folder(\n            clips_dir,\n            video_path,\n            copy_source=copy_source,\n        )\n        new_paths.append(os.path.join(clips_dir, clip_name))\n\n    # Update project.json clips list\n    data = read_project_json(project_dir) or {}\n    existing = data.get(\"clips\", [])\n    for p in new_paths:\n        existing.append(os.path.basename(p))\n    data[\"clips\"] = existing\n    write_project_json(project_dir, data)\n\n    return new_paths\n\n\ndef _create_clip_folder(\n    clips_dir: str,\n    video_path: str,\n    *,\n    
copy_source: bool = True,\n) -> str:\n    \"\"\"Create a single clip subfolder inside clips_dir.\n\n    Returns the clip folder name (not full path).\n    \"\"\"\n    filename = os.path.basename(video_path)\n    clip_name = sanitize_stem(filename)\n\n    # Deduplicate clip folder names within same project\n    clip_dir, clip_name = _dedupe_path(clips_dir, clip_name)\n\n    source_dir = os.path.join(clip_dir, \"Source\")\n    os.makedirs(source_dir, exist_ok=True)\n\n    if copy_source:\n        target = os.path.join(source_dir, filename)\n        if not os.path.isfile(target):\n            shutil.copy2(video_path, target)\n            logger.info(f\"Copied source video: {video_path} -> {target}\")\n    else:\n        logger.info(f\"Referencing source video in place: {video_path}\")\n\n    # Write clip.json (per-clip metadata)\n    write_clip_json(\n        clip_dir,\n        {\n            \"source\": {\n                \"original_path\": os.path.abspath(video_path),\n                \"filename\": filename,\n                \"copied\": copy_source,\n            },\n        },\n    )\n\n    return clip_name\n\n\ndef get_clip_dirs(project_dir: str) -> list[str]:\n    \"\"\"Return absolute paths to all clip subdirectories in a project.\n\n    For v2 projects (with clips/ dir), scans clips/ subdirectories.\n    For v1 projects (no clips/ dir), returns [project_dir] as a single clip.\n    \"\"\"\n    clips_dir = os.path.join(project_dir, \"clips\")\n    if os.path.isdir(clips_dir):\n        return sorted(\n            os.path.join(clips_dir, d)\n            for d in os.listdir(clips_dir)\n            if os.path.isdir(os.path.join(clips_dir, d)) and not d.startswith(\".\") and not d.startswith(\"_\")\n        )\n    # v1 fallback: project dir itself is the clip\n    return [project_dir]\n\n\ndef is_v2_project(project_dir: str) -> bool:\n    \"\"\"Check if a project uses the v2 nested clips structure.\"\"\"\n    return os.path.isdir(os.path.join(project_dir, 
\"clips\"))\n\n\ndef write_project_json(project_root: str, data: dict) -> None:\n    \"\"\"Atomic write of project.json.\"\"\"\n    path = os.path.join(project_root, \"project.json\")\n    tmp_path = path + \".tmp\"\n    with open(tmp_path, \"w\") as f:\n        json.dump(data, f, indent=2)\n    os.replace(tmp_path, path)\n\n\ndef read_project_json(project_root: str) -> dict | None:\n    \"\"\"Read project.json, returning None if missing or corrupt.\"\"\"\n    path = os.path.join(project_root, \"project.json\")\n    if not os.path.isfile(path):\n        return None\n    try:\n        with open(path) as f:\n            return json.load(f)\n    except (json.JSONDecodeError, OSError) as e:\n        logger.warning(f\"Failed to read project.json at {path}: {e}\")\n        return None\n\n\ndef write_clip_json(clip_root: str, data: dict) -> None:\n    \"\"\"Atomic write of clip.json (per-clip metadata).\"\"\"\n    path = os.path.join(clip_root, \"clip.json\")\n    tmp_path = path + \".tmp\"\n    with open(tmp_path, \"w\") as f:\n        json.dump(data, f, indent=2)\n    os.replace(tmp_path, path)\n\n\ndef read_clip_json(clip_root: str) -> dict | None:\n    \"\"\"Read clip.json, returning None if missing or corrupt.\"\"\"\n    path = os.path.join(clip_root, \"clip.json\")\n    if not os.path.isfile(path):\n        return None\n    try:\n        with open(path) as f:\n            return json.load(f)\n    except (json.JSONDecodeError, OSError) as e:\n        logger.warning(f\"Failed to read clip.json at {path}: {e}\")\n        return None\n\n\ndef _read_clip_or_project_json(root: str) -> dict | None:\n    \"\"\"Read clip.json first, falling back to project.json for v1 compat.\"\"\"\n    data = read_clip_json(root)\n    if data is not None:\n        return data\n    return read_project_json(root)\n\n\ndef get_display_name(root: str) -> str:\n    \"\"\"Get the user-visible name for a clip or project.\n\n    Checks clip.json first, then project.json, falling back to folder 
name.\n    \"\"\"\n    data = _read_clip_or_project_json(root)\n    if data and data.get(\"display_name\"):\n        return data[\"display_name\"]\n    return os.path.basename(root)\n\n\ndef set_display_name(root: str, name: str) -> None:\n    \"\"\"Update display_name. Writes to clip.json if it exists, else project.json.\"\"\"\n    if os.path.isfile(os.path.join(root, \"clip.json\")):\n        data = read_clip_json(root) or {}\n        data[\"display_name\"] = name\n        write_clip_json(root, data)\n    else:\n        data = read_project_json(root) or {}\n        data[\"display_name\"] = name\n        write_project_json(root, data)\n\n\ndef save_in_out_range(clip_root: str, in_out) -> None:\n    \"\"\"Persist in/out range to clip.json (v2) or project.json (v1).\n\n    Pass None to clear.\n    \"\"\"\n    if os.path.isfile(os.path.join(clip_root, \"clip.json\")):\n        data = read_clip_json(clip_root) or {}\n        if in_out is not None:\n            data[\"in_out_range\"] = in_out.to_dict()\n        else:\n            data.pop(\"in_out_range\", None)\n        write_clip_json(clip_root, data)\n    else:\n        data = read_project_json(clip_root) or {}\n        if in_out is not None:\n            data[\"in_out_range\"] = in_out.to_dict()\n        else:\n            data.pop(\"in_out_range\", None)\n        write_project_json(clip_root, data)\n\n\ndef load_in_out_range(clip_root: str):\n    \"\"\"Load in/out range from clip.json or project.json, or None if not set.\"\"\"\n    data = _read_clip_or_project_json(clip_root)\n    if data and \"in_out_range\" in data:\n        try:\n            from .clip_state import InOutRange\n\n            return InOutRange.from_dict(data[\"in_out_range\"])\n        except (KeyError, TypeError):\n            return None\n    return None\n\n\ndef is_video_file(filename: str) -> bool:\n    \"\"\"Check if a filename has a video extension.\"\"\"\n    return os.path.splitext(filename)[1].lower() in _VIDEO_EXTS\n\n\ndef 
is_image_file(filename: str) -> bool:\n    \"\"\"Check if a filename has an image extension.\"\"\"\n    return os.path.splitext(filename)[1].lower() in _IMAGE_EXTS\n"
  },
  {
    "path": "backend/service.py",
    "content": "\"\"\"CorridorKeyService — clean backend API for the UI and CLI.\n\nThis module wraps all processing logic from clip_manager.py into a\nservice layer. The UI never calls inference engines directly — it\ncalls methods here which handle validation, state transitions, and\nerror reporting.\n\nModel Residency Policy:\n    Only ONE heavy model is loaded at a time. Before loading a new\n    model type, the previous is unloaded and VRAM freed via\n    device_utils.clear_device_cache(). This prevents OOM on 24GB cards.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport json\nimport logging\nimport os\nimport sys\nimport threading\nimport time\nfrom dataclasses import asdict, dataclass\nfrom enum import Enum\nfrom typing import Any, Callable\n\nimport numpy as np\n\n# Enable OpenEXR support (must be before cv2 import)\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"\nimport cv2\n\nfrom .clip_state import (\n    ClipAsset,\n    ClipEntry,\n    ClipState,\n    scan_clips_dir,\n)\nfrom .errors import (\n    CorridorKeyError,\n    FrameReadError,\n    JobCancelledError,\n    WriteFailureError,\n)\nfrom .frame_io import (\n    EXR_WRITE_FLAGS,\n    read_image_frame,\n    read_mask_frame,\n    read_video_frame_at,\n    read_video_frames,\n    read_video_mask_at,\n)\nfrom .job_queue import GPUJob, GPUJobQueue\nfrom .validators import (\n    ensure_output_dirs,\n    validate_frame_counts,\n    validate_frame_read,\n    validate_write,\n)\n\nlogger = logging.getLogger(__name__)\n\n# Project paths — frozen-build aware\nif getattr(sys, \"frozen\", False):\n    BASE_DIR = sys._MEIPASS\nelse:\n    BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))\n\n\nclass _ActiveModel(Enum):\n    \"\"\"Tracks which heavy model is currently loaded in VRAM.\"\"\"\n\n    NONE = \"none\"\n    INFERENCE = \"inference\"\n    GVM = \"gvm\"\n    VIDEOMAMA = \"videomama\"\n\n\n@dataclass\nclass InferenceParams:\n    \"\"\"Frozen parameters for a single inference 
job.\"\"\"\n\n    input_is_linear: bool = False\n    despill_strength: float = 1.0  # 0.0 to 1.0\n    auto_despeckle: bool = True\n    despeckle_size: int = 400\n    refiner_scale: float = 1.0\n\n    def to_dict(self) -> dict:\n        return asdict(self)\n\n    @classmethod\n    def from_dict(cls, d: dict) -> \"InferenceParams\":\n        known = {f.name for f in cls.__dataclass_fields__.values()}\n        return cls(**{k: v for k, v in d.items() if k in known})\n\n\n@dataclass\nclass OutputConfig:\n    \"\"\"Which output types to produce and their format.\"\"\"\n\n    fg_enabled: bool = True\n    fg_format: str = \"exr\"  # \"exr\" or \"png\"\n    matte_enabled: bool = True\n    matte_format: str = \"exr\"\n    comp_enabled: bool = True\n    comp_format: str = \"png\"\n    processed_enabled: bool = True\n    processed_format: str = \"exr\"\n\n    def to_dict(self) -> dict:\n        return asdict(self)\n\n    @classmethod\n    def from_dict(cls, d: dict) -> \"OutputConfig\":\n        known = {f.name for f in cls.__dataclass_fields__.values()}\n        return cls(**{k: v for k, v in d.items() if k in known})\n\n    @property\n    def enabled_outputs(self) -> list[str]:\n        \"\"\"Return list of enabled output names for manifest.\"\"\"\n        out = []\n        if self.fg_enabled:\n            out.append(\"fg\")\n        if self.matte_enabled:\n            out.append(\"matte\")\n        if self.comp_enabled:\n            out.append(\"comp\")\n        if self.processed_enabled:\n            out.append(\"processed\")\n        return out\n\n\n@dataclass\nclass FrameResult:\n    \"\"\"Result summary for a single processed frame (no numpy in this struct).\"\"\"\n\n    frame_index: int\n    input_stem: str\n    success: bool\n    warning: str | None = None\n\n\nclass CorridorKeyService:\n    \"\"\"Main backend service — scan, validate, process, write.\n\n    Usage:\n        service = CorridorKeyService()\n        clips = 
service.scan_clips(\"/path/to/ClipsForInference\")\n        ready = service.get_clips_by_state(clips, ClipState.READY)\n\n        for clip in ready:\n            params = InferenceParams(despill_strength=0.8)\n            service.run_inference(clip, params, on_progress=my_callback)\n    \"\"\"\n\n    def __init__(self):\n        self._engine = None\n        self._gvm_processor = None\n        self._videomama_pipeline = None\n        self._active_model = _ActiveModel.NONE\n        self._device: str = \"cpu\"\n        self._job_queue: GPUJobQueue | None = None\n        # GPU mutex — serializes ALL model operations (Codex: thread safety)\n        self._gpu_lock = threading.Lock()\n\n    @property\n    def job_queue(self) -> GPUJobQueue:\n        \"\"\"Lazy-init GPU job queue (only needed when UI is running).\"\"\"\n        if self._job_queue is None:\n            self._job_queue = GPUJobQueue()\n        return self._job_queue\n\n    # --- Device & Engine Management ---\n\n    def detect_device(self) -> str:\n        \"\"\"Detect best available compute device using centralized device_utils.\"\"\"\n        try:\n            from device_utils import resolve_device\n\n            self._device = resolve_device()\n        except ImportError:\n            self._device = \"cpu\"\n            logger.warning(\"device_utils not available — using CPU\")\n        logger.info(f\"Compute device: {self._device}\")\n        return self._device\n\n    def get_vram_info(self) -> dict[str, float]:\n        \"\"\"Get GPU VRAM info in GB. 
Returns empty dict if not CUDA.\"\"\"\n        try:\n            import torch\n\n            if not torch.cuda.is_available():\n                return {}\n            props = torch.cuda.get_device_properties(0)\n            total_bytes = props.total_memory\n            reserved = torch.cuda.memory_reserved(0)\n            return {\n                \"total\": total_bytes / (1024**3),\n                \"reserved\": reserved / (1024**3),\n                \"allocated\": torch.cuda.memory_allocated(0) / (1024**3),\n                \"free\": (total_bytes - reserved) / (1024**3),\n                \"name\": torch.cuda.get_device_name(0),\n            }\n        except Exception as e:\n            logger.debug(f\"VRAM query failed: {e}\")\n            return {}\n\n    @staticmethod\n    def _vram_allocated_mb() -> float:\n        \"\"\"Return current VRAM allocated in MB, or 0 if unavailable.\"\"\"\n        try:\n            import torch\n\n            if torch.cuda.is_available():\n                return torch.cuda.memory_allocated(0) / (1024**2)\n        except Exception:\n            pass\n        return 0.0\n\n    @staticmethod\n    def _safe_offload(obj: object) -> None:\n        \"\"\"Move a model's GPU tensors to CPU before dropping the reference.\n\n        Handles diffusers pipelines (.to('cpu')), plain nn.Modules (.cpu()),\n        and objects with an explicit unload() method.\n        \"\"\"\n        if obj is None:\n            return\n        logger.debug(f\"Offloading model: {type(obj).__name__}\")\n        try:\n            if hasattr(obj, \"unload\"):\n                obj.unload()\n            elif hasattr(obj, \"to\"):\n                obj.to(\"cpu\")\n            elif hasattr(obj, \"cpu\"):\n                obj.cpu()\n        except Exception as e:\n            logger.debug(f\"Model offload warning: {e}\")\n\n    def _ensure_model(self, needed: _ActiveModel) -> None:\n        \"\"\"Model residency manager — unload current model if switching types.\n\n      
  Only ONE heavy model stays in VRAM at a time. Before loading a\n        different model, the previous is moved to CPU and dereferenced.\n        \"\"\"\n        if self._active_model == needed:\n            return\n\n        # Unload whatever is currently loaded\n        if self._active_model != _ActiveModel.NONE:\n            # Snapshot VRAM before unload for leak diagnosis\n            vram_before_mb = self._vram_allocated_mb()\n            logger.info(\n                f\"Unloading {self._active_model.value} model for {needed.value} (VRAM before: {vram_before_mb:.0f}MB)\"\n            )\n\n            if self._active_model == _ActiveModel.INFERENCE:\n                self._safe_offload(self._engine)\n                self._engine = None\n            elif self._active_model == _ActiveModel.GVM:\n                self._safe_offload(self._gvm_processor)\n                self._gvm_processor = None\n            elif self._active_model == _ActiveModel.VIDEOMAMA:\n                self._safe_offload(self._videomama_pipeline)\n                self._videomama_pipeline = None\n\n            import gc\n\n            gc.collect()\n\n            try:\n                from device_utils import clear_device_cache\n\n                clear_device_cache(self._device)\n            except ImportError:\n                logger.debug(\"device_utils not available for cache clear during model switch\")\n\n            vram_after_mb = self._vram_allocated_mb()\n            freed = vram_before_mb - vram_after_mb\n            logger.info(f\"VRAM after unload: {vram_after_mb:.0f}MB (freed {freed:.0f}MB)\")\n\n        self._active_model = needed\n\n    def _get_engine(self):\n        \"\"\"Lazy-load the CorridorKey inference engine.\"\"\"\n        self._ensure_model(_ActiveModel.INFERENCE)\n\n        if self._engine is not None:\n            return self._engine\n\n        try:\n            from CorridorKeyModule.backend import TORCH_EXT, _discover_checkpoint\n            from 
CorridorKeyModule.inference_engine import CorridorKeyEngine\n        except ImportError as exc:\n            raise RuntimeError(\"CorridorKeyModule is not installed. Run: uv sync\") from exc\n\n        ckpt_path = _discover_checkpoint(TORCH_EXT)\n        logger.info(f\"Loading checkpoint: {os.path.basename(ckpt_path)}\")\n        t0 = time.monotonic()\n        self._engine = CorridorKeyEngine(\n            checkpoint_path=ckpt_path,\n            device=self._device,\n            img_size=2048,\n        )\n        logger.info(f\"Engine loaded in {time.monotonic() - t0:.1f}s\")\n        return self._engine\n\n    def _get_gvm(self):\n        \"\"\"Lazy-load the GVM processor.\"\"\"\n        self._ensure_model(_ActiveModel.GVM)\n\n        if self._gvm_processor is not None:\n            return self._gvm_processor\n\n        from gvm_core import GVMProcessor\n\n        logger.info(\"Loading GVM processor...\")\n        t0 = time.monotonic()\n        self._gvm_processor = GVMProcessor(device=self._device)\n        logger.info(f\"GVM loaded in {time.monotonic() - t0:.1f}s\")\n        return self._gvm_processor\n\n    def _get_videomama_pipeline(self):\n        \"\"\"Lazy-load the VideoMaMa inference pipeline.\"\"\"\n        self._ensure_model(_ActiveModel.VIDEOMAMA)\n\n        if self._videomama_pipeline is not None:\n            return self._videomama_pipeline\n\n        sys.path.insert(0, os.path.join(BASE_DIR, \"VideoMaMaInferenceModule\"))\n        from VideoMaMaInferenceModule.inference import load_videomama_model\n\n        logger.info(\"Loading VideoMaMa pipeline...\")\n        t0 = time.monotonic()\n        self._videomama_pipeline = load_videomama_model(device=self._device)\n        logger.info(f\"VideoMaMa loaded in {time.monotonic() - t0:.1f}s\")\n        return self._videomama_pipeline\n\n    def unload_engines(self) -> None:\n        \"\"\"Free GPU memory by unloading all engines.\"\"\"\n        self._safe_offload(self._engine)\n        
self._safe_offload(self._gvm_processor)\n        self._safe_offload(self._videomama_pipeline)\n        self._engine = None\n        self._gvm_processor = None\n        self._videomama_pipeline = None\n        self._active_model = _ActiveModel.NONE\n        try:\n            from device_utils import clear_device_cache\n\n            clear_device_cache(self._device)\n        except ImportError:\n            logger.debug(\"device_utils not available for cache clear during unload\")\n        logger.info(\"All engines unloaded, VRAM freed\")\n\n    # --- Clip Scanning ---\n\n    def scan_clips(\n        self,\n        clips_dir: str,\n        allow_standalone_videos: bool = True,\n    ) -> list[ClipEntry]:\n        \"\"\"Scan a directory for clip folders.\"\"\"\n        return scan_clips_dir(clips_dir, allow_standalone_videos=allow_standalone_videos)\n\n    def get_clips_by_state(\n        self,\n        clips: list[ClipEntry],\n        state: ClipState,\n    ) -> list[ClipEntry]:\n        \"\"\"Filter clips by state.\"\"\"\n        return [c for c in clips if c.state == state]\n\n    # --- Frame I/O Helpers ---\n\n    def _read_input_frame(\n        self,\n        clip: ClipEntry,\n        frame_index: int,\n        input_files: list[str],\n        input_cap: Any | None,\n        input_is_linear: bool,\n    ) -> tuple[np.ndarray | None, str, bool]:\n        \"\"\"Read a single input frame.\n\n        Returns:\n            (image_float32, stem_name, is_linear_override)\n        \"\"\"\n        logger.debug(f\"Reading input frame {frame_index} for '{clip.name}'\")\n        input_stem = f\"{frame_index:05d}\"\n\n        if input_cap:\n            ret, frame = input_cap.read()\n            if not ret:\n                return None, input_stem, False\n            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n            return img_rgb.astype(np.float32) / 255.0, input_stem, input_is_linear\n        else:\n            if frame_index >= len(input_files):\n                
logger.warning(\n                    f\"Clip '{clip.name}': frame_index {frame_index} out of range (have {len(input_files)} frames)\"\n                )\n                return None, f\"{frame_index:05d}\", input_is_linear\n            fpath = os.path.join(clip.input_asset.path, input_files[frame_index])\n            input_stem = os.path.splitext(input_files[frame_index])[0]\n            img = read_image_frame(fpath)\n            validate_frame_read(img, clip.name, frame_index, fpath)\n            return img, input_stem, input_is_linear\n\n    def _read_alpha_frame(\n        self,\n        clip: ClipEntry,\n        frame_index: int,\n        alpha_files: list[str],\n        alpha_cap: Any | None,\n    ) -> np.ndarray | None:\n        \"\"\"Read a single alpha/mask frame and normalize to [H, W] float32.\"\"\"\n        if alpha_cap:\n            ret, frame = alpha_cap.read()\n            if not ret:\n                return None\n            return frame[:, :, 2].astype(np.float32) / 255.0\n        else:\n            fpath = os.path.join(clip.alpha_asset.path, alpha_files[frame_index])\n            mask = read_mask_frame(fpath, clip.name, frame_index)\n            validate_frame_read(mask, clip.name, frame_index, fpath)\n            return mask\n\n    def _write_image(\n        self,\n        img: np.ndarray,\n        path: str,\n        fmt: str,\n        clip_name: str,\n        frame_index: int,\n    ) -> None:\n        \"\"\"Write a single image in the requested format.\"\"\"\n        if fmt == \"exr\":\n            # EXR requires float32 — convert if uint8 (e.g. 
pre-converted comp)\n            if img.dtype == np.uint8:\n                img = img.astype(np.float32) / 255.0\n            elif img.dtype != np.float32:\n                img = img.astype(np.float32)\n            validate_write(cv2.imwrite(path, img, EXR_WRITE_FLAGS), clip_name, frame_index, path)\n        else:\n            # PNG 8-bit\n            if img.dtype != np.uint8:\n                img = (np.clip(img, 0.0, 1.0) * 255.0).astype(np.uint8)\n            validate_write(cv2.imwrite(path, img), clip_name, frame_index, path)\n\n    def _write_manifest(\n        self,\n        output_root: str,\n        output_config: OutputConfig,\n        params: InferenceParams,\n    ) -> None:\n        \"\"\"Write run manifest recording expected outputs/extensions per run.\n\n        Codex: base resume on manifest, not hardcoded FG/Matte intersection.\n        Uses atomic write (tmp + rename) to prevent corruption.\n        \"\"\"\n        manifest = {\n            \"version\": 1,\n            \"enabled_outputs\": output_config.enabled_outputs,\n            \"formats\": {\n                \"fg\": output_config.fg_format,\n                \"matte\": output_config.matte_format,\n                \"comp\": output_config.comp_format,\n                \"processed\": output_config.processed_format,\n            },\n            \"params\": params.to_dict(),\n        }\n        manifest_path = os.path.join(output_root, \".corridorkey_manifest.json\")\n        tmp_path = manifest_path + \".tmp\"\n        try:\n            with open(tmp_path, \"w\") as f:\n                json.dump(manifest, f, indent=2)\n            # Atomic replace (os.replace is atomic on both POSIX and Windows)\n            os.replace(tmp_path, manifest_path)\n        except Exception as e:\n            logger.warning(f\"Failed to write manifest: {e}\")\n\n    def _write_outputs(\n        self,\n        res: dict,\n        dirs: dict[str, str],\n        input_stem: str,\n        clip_name: str,\n        
frame_index: int,\n        output_config: OutputConfig | None = None,\n    ) -> None:\n        \"\"\"Write output types for a single frame respecting OutputConfig.\"\"\"\n        cfg = output_config or OutputConfig()\n        logger.debug(f\"Writing outputs for '{clip_name}' frame {frame_index} stem='{input_stem}'\")\n\n        pred_fg = res[\"fg\"]\n        pred_alpha = res[\"alpha\"]\n\n        # FG\n        if cfg.fg_enabled:\n            fg_bgr = cv2.cvtColor(pred_fg, cv2.COLOR_RGB2BGR)\n            fg_path = os.path.join(dirs[\"fg\"], f\"{input_stem}.{cfg.fg_format}\")\n            self._write_image(fg_bgr, fg_path, cfg.fg_format, clip_name, frame_index)\n\n        # Matte\n        if cfg.matte_enabled:\n            alpha = pred_alpha\n            if alpha.ndim == 3:\n                alpha = alpha[:, :, 0]\n            matte_path = os.path.join(dirs[\"matte\"], f\"{input_stem}.{cfg.matte_format}\")\n            self._write_image(alpha, matte_path, cfg.matte_format, clip_name, frame_index)\n\n        # Comp\n        if cfg.comp_enabled:\n            comp_srgb = res[\"comp\"]\n            comp_bgr = cv2.cvtColor(\n                (np.clip(comp_srgb, 0.0, 1.0) * 255.0).astype(np.uint8),\n                cv2.COLOR_RGB2BGR,\n            )\n            comp_path = os.path.join(dirs[\"comp\"], f\"{input_stem}.{cfg.comp_format}\")\n            self._write_image(comp_bgr, comp_path, cfg.comp_format, clip_name, frame_index)\n\n        # Processed (RGBA premultiplied)\n        if cfg.processed_enabled and \"processed\" in res:\n            proc_rgba = res[\"processed\"]\n            proc_bgra = cv2.cvtColor(proc_rgba, cv2.COLOR_RGBA2BGRA)\n            proc_path = os.path.join(dirs[\"processed\"], f\"{input_stem}.{cfg.processed_format}\")\n            self._write_image(proc_bgra, proc_path, cfg.processed_format, clip_name, frame_index)\n\n    # --- Processing ---\n\n    def run_inference(\n        self,\n        clip: ClipEntry,\n        params: InferenceParams,\n        
job: GPUJob | None = None,\n        on_progress: Callable[[str, int, int], None] | None = None,\n        on_warning: Callable[[str], None] | None = None,\n        skip_stems: set[str] | None = None,\n        output_config: OutputConfig | None = None,\n        frame_range: tuple[int, int] | None = None,\n    ) -> list[FrameResult]:\n        \"\"\"Run CorridorKey inference on a single clip.\n\n        Args:\n            clip: Must be in READY or COMPLETE state with both input_asset and alpha_asset.\n            params: Frozen inference parameters.\n            job: Optional GPUJob for cancel checking.\n            on_progress: Called with (clip_name, current_frame, total_frames).\n            on_warning: Called with warning messages for non-fatal issues.\n            skip_stems: Set of frame stems to skip (for resume support).\n            output_config: Which outputs to write and their formats.\n\n        Returns:\n            List of FrameResult for each frame.\n\n        Raises:\n            JobCancelledError: If job.is_cancelled becomes True.\n            Various CorridorKeyError subclasses for fatal issues.\n        \"\"\"\n        if clip.input_asset is None or clip.alpha_asset is None:\n            raise CorridorKeyError(f\"Clip '{clip.name}' missing input or alpha asset\")\n\n        t_start = time.monotonic()\n\n        with self._gpu_lock:\n            engine = self._get_engine()\n        dirs = ensure_output_dirs(clip.root_path)\n        cfg = output_config or OutputConfig()\n\n        # Write run manifest (Codex: resume must know which outputs were enabled)\n        self._write_manifest(dirs[\"root\"], cfg, params)\n\n        num_frames = validate_frame_counts(\n            clip.name,\n            clip.input_asset.frame_count,\n            clip.alpha_asset.frame_count,\n        )\n\n        # Open video captures or get file lists\n        input_cap = None\n        alpha_cap = None\n        input_files: list[str] = []\n        alpha_files: list[str] = 
[]\n\n        if clip.input_asset.asset_type == \"video\":\n            input_cap = cv2.VideoCapture(clip.input_asset.path)\n        else:\n            input_files = clip.input_asset.get_frame_files()\n\n        if clip.alpha_asset.asset_type == \"video\":\n            alpha_cap = cv2.VideoCapture(clip.alpha_asset.path)\n        else:\n            alpha_files = clip.alpha_asset.get_frame_files()\n\n        results: list[FrameResult] = []\n        skipped: list[int] = []\n        skip_stems = skip_stems or set()\n\n        # Determine frame range (in/out markers or full clip)\n        if frame_range is not None:\n            range_start = max(0, frame_range[0])\n            range_end = min(num_frames - 1, frame_range[1])\n            frame_indices = range(range_start, range_end + 1)\n            range_count = range_end - range_start + 1\n        else:\n            frame_indices = range(num_frames)\n            range_count = num_frames\n\n        try:\n            for progress_i, i in enumerate(frame_indices):\n                # Check cancellation between frames\n                if job and job.is_cancelled:\n                    raise JobCancelledError(clip.name, i)\n\n                # Report progress every frame (enables responsive cancel + timer)\n                if on_progress:\n                    on_progress(clip.name, progress_i, range_count)\n\n                try:\n                    # Read input\n                    img, input_stem, is_linear = self._read_input_frame(\n                        clip,\n                        i,\n                        input_files,\n                        input_cap,\n                        params.input_is_linear,\n                    )\n                    if img is None:\n                        skipped.append(i)\n                        results.append(FrameResult(i, f\"{i:05d}\", False, \"video read failed\"))\n                        continue\n\n                    # Resume: skip frames that already have outputs\n        
            if input_stem in skip_stems:\n                        results.append(FrameResult(i, input_stem, True, \"resumed (skipped)\"))\n                        continue\n\n                    # Read alpha\n                    mask = self._read_alpha_frame(clip, i, alpha_files, alpha_cap)\n                    if mask is None:\n                        skipped.append(i)\n                        results.append(FrameResult(i, input_stem, False, \"alpha read failed\"))\n                        continue\n\n                    # Resize mask if dimensions don't match input\n                    if mask.shape[:2] != img.shape[:2]:\n                        mask = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR)\n\n                    # Process (GPU-locked — process_frame mutates model hooks)\n                    t_frame = time.monotonic()\n                    with self._gpu_lock:\n                        res = engine.process_frame(\n                            img,\n                            mask,\n                            input_is_linear=is_linear,\n                            fg_is_straight=True,\n                            despill_strength=params.despill_strength,\n                            auto_despeckle=params.auto_despeckle,\n                            despeckle_size=params.despeckle_size,\n                            refiner_scale=params.refiner_scale,\n                        )\n                    logger.debug(f\"Clip '{clip.name}' frame {i}: process_frame {time.monotonic() - t_frame:.3f}s\")\n\n                    # Write outputs\n                    self._write_outputs(res, dirs, input_stem, clip.name, i, cfg)\n                    results.append(FrameResult(i, input_stem, True))\n\n                except FrameReadError as e:\n                    logger.warning(str(e))\n                    skipped.append(i)\n                    results.append(FrameResult(i, f\"{i:05d}\", False, str(e)))\n                    if on_warning:\n   
                     on_warning(str(e))\n\n                except WriteFailureError as e:\n                    logger.error(str(e))\n                    results.append(FrameResult(i, f\"{i:05d}\", False, str(e)))\n                    if on_warning:\n                        on_warning(str(e))\n\n            # Final progress\n            if on_progress:\n                on_progress(clip.name, range_count, range_count)\n\n        finally:\n            if input_cap:\n                input_cap.release()\n            if alpha_cap:\n                alpha_cap.release()\n\n        # Summary\n        processed = sum(1 for r in results if r.success)\n        if skipped:\n            msg = (\n                f\"Clip '{clip.name}': {len(skipped)} frame(s) skipped: \"\n                f\"{skipped[:20]}{'...' if len(skipped) > 20 else ''}\"\n            )\n            logger.warning(msg)\n            if on_warning:\n                on_warning(msg)\n\n        t_total = time.monotonic() - t_start\n        range_label = f\" (range {frame_range[0]}-{frame_range[1]})\" if frame_range else \"\"\n        logger.info(\n            f\"Clip '{clip.name}': inference complete{range_label}. 
{processed}/{range_count} frames \"\n            f\"in {t_total:.1f}s ({t_total / max(processed, 1):.2f}s/frame avg)\"\n        )\n\n        # State transition — only set COMPLETE if full clip was processed\n        is_full_clip = frame_range is None or (frame_range[0] == 0 and frame_range[1] >= num_frames - 1)\n        if processed == range_count and is_full_clip:\n            try:\n                clip.transition_to(ClipState.COMPLETE)\n            except Exception as e:\n                logger.warning(f\"Clip '{clip.name}': state transition to COMPLETE failed: {e}\")\n\n        return results\n\n    # --- Single-Frame Reprocess (Preview) ---\n\n    def is_engine_loaded(self) -> bool:\n        \"\"\"True if the inference engine is already loaded in VRAM.\"\"\"\n        return self._active_model == _ActiveModel.INFERENCE and self._engine is not None\n\n    def reprocess_single_frame(\n        self,\n        clip: ClipEntry,\n        params: InferenceParams,\n        frame_index: int,\n        job: GPUJob | None = None,\n    ) -> dict | None:\n        \"\"\"Reprocess a single frame with current params.\n\n        Returns the result dict (fg, alpha, comp, processed) or None.\n        This runs through the GPU lock for thread safety.\n        Does NOT write to disk — returns in-memory results for preview.\n        \"\"\"\n        t_start = time.monotonic()\n        if clip.input_asset is None or clip.alpha_asset is None:\n            return None\n\n        if job and job.is_cancelled:\n            return None\n\n        with self._gpu_lock:\n            engine = self._get_engine()\n\n        # Read the specific input frame\n        if clip.input_asset.asset_type == \"video\":\n            img = read_video_frame_at(clip.input_asset.path, frame_index)\n        else:\n            input_files = clip.input_asset.get_frame_files()\n            if frame_index >= len(input_files):\n                return None\n            img = 
read_image_frame(os.path.join(clip.input_asset.path, input_files[frame_index]))\n        if img is None:\n            return None\n\n        # Read the specific alpha frame\n        if clip.alpha_asset.asset_type == \"video\":\n            mask = read_video_mask_at(clip.alpha_asset.path, frame_index)\n        else:\n            alpha_files = clip.alpha_asset.get_frame_files()\n            if frame_index >= len(alpha_files):\n                return None\n            mask = read_mask_frame(\n                os.path.join(clip.alpha_asset.path, alpha_files[frame_index]),\n                clip.name,\n                frame_index,\n            )\n        if mask is None:\n            return None\n\n        if mask.shape[:2] != img.shape[:2]:\n            mask = cv2.resize(mask, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_LINEAR)\n\n        if job and job.is_cancelled:\n            return None\n\n        with self._gpu_lock:\n            res = engine.process_frame(\n                img,\n                mask,\n                input_is_linear=params.input_is_linear,\n                fg_is_straight=True,\n                despill_strength=params.despill_strength,\n                auto_despeckle=params.auto_despeckle,\n                despeckle_size=params.despeckle_size,\n                refiner_scale=params.refiner_scale,\n            )\n        logger.debug(f\"Clip '{clip.name}' frame {frame_index}: reprocess {time.monotonic() - t_start:.3f}s\")\n        return res\n\n    # --- GVM Alpha Generation ---\n\n    def run_gvm(\n        self,\n        clip: ClipEntry,\n        job: GPUJob | None = None,\n        on_progress: Callable[[str, int, int], None] | None = None,\n        on_warning: Callable[[str], None] | None = None,\n    ) -> None:\n        \"\"\"Run GVM auto alpha generation for a clip.\n\n        Transitions clip: RAW → READY (creates AlphaHint directory).\n\n        Args:\n            clip: Must be in RAW state with input_asset.\n            job: Optional 
GPUJob for cancel checking.\n            on_progress: Progress callback (GVM is monolithic, reports start/end).\n            on_warning: Warning callback.\n        \"\"\"\n        if clip.input_asset is None:\n            raise CorridorKeyError(f\"Clip '{clip.name}' missing input asset for GVM\")\n\n        t_start = time.monotonic()\n\n        with self._gpu_lock:\n            gvm = self._get_gvm()\n\n        alpha_dir = os.path.join(clip.root_path, \"AlphaHint\")\n        os.makedirs(alpha_dir, exist_ok=True)\n\n        if on_progress:\n            on_progress(clip.name, 0, 1)\n\n        # Check cancel before starting\n        if job and job.is_cancelled:\n            raise JobCancelledError(clip.name, 0)\n\n        # Per-batch progress callback — GVM iterates over frames internally\n        def _gvm_progress(batch_idx: int, total_batches: int) -> None:\n            if on_progress:\n                on_progress(clip.name, batch_idx, total_batches)\n            # Check cancel between batches\n            if job and job.is_cancelled:\n                raise JobCancelledError(clip.name, batch_idx)\n\n        try:\n            gvm.process_sequence(\n                input_path=clip.input_asset.path,\n                output_dir=clip.root_path,\n                num_frames_per_batch=1,\n                decode_chunk_size=1,\n                denoise_steps=1,\n                mode=\"matte\",\n                write_video=False,\n                direct_output_dir=alpha_dir,\n                progress_callback=_gvm_progress,\n            )\n        except JobCancelledError:\n            raise\n        except Exception as e:\n            if job and job.is_cancelled:\n                raise JobCancelledError(clip.name, 0) from None\n            raise CorridorKeyError(f\"GVM failed for '{clip.name}': {e}\") from e\n\n        # Refresh alpha asset\n        clip.alpha_asset = ClipAsset(alpha_dir, \"sequence\")\n\n        if on_progress:\n            on_progress(clip.name, 1, 1)\n\n     
   # Transition RAW → READY\n        try:\n            clip.transition_to(ClipState.READY)\n        except Exception as e:\n            if on_warning:\n                on_warning(f\"State transition after GVM: {e}\")\n\n        elapsed = time.monotonic() - t_start\n        logger.info(f\"GVM complete for '{clip.name}': {clip.alpha_asset.frame_count} alpha frames in {elapsed:.1f}s\")\n\n    # --- VideoMaMa Alpha Generation ---\n\n    def run_videomama(\n        self,\n        clip: ClipEntry,\n        job: GPUJob | None = None,\n        on_progress: Callable[[str, int, int], None] | None = None,\n        on_warning: Callable[[str], None] | None = None,\n        on_status: Callable[[str], None] | None = None,\n        chunk_size: int = 50,\n    ) -> None:\n        \"\"\"Run VideoMaMa guided alpha generation for a clip.\n\n        Transitions clip: MASKED → READY (creates AlphaHint directory).\n\n        Args:\n            clip: Must be in MASKED state with input_asset and mask_asset.\n            job: Optional GPUJob for cancel checking.\n            on_progress: Progress callback with per-chunk updates.\n            on_warning: Warning callback.\n            on_status: Phase status callback (e.g. 
\"Loading model...\").\n            chunk_size: Frames per chunk (lower = less RAM, default 50).\n        \"\"\"\n        if clip.input_asset is None:\n            raise CorridorKeyError(f\"Clip '{clip.name}' missing input asset for VideoMaMa\")\n        if clip.mask_asset is None:\n            raise CorridorKeyError(f\"Clip '{clip.name}' missing mask asset for VideoMaMa\")\n\n        def _status(msg: str) -> None:\n            logger.info(f\"VideoMaMa [{clip.name}]: {msg}\")\n            if on_status:\n                on_status(msg)\n\n        def _check_cancel(phase: str = \"\") -> None:\n            if job and job.is_cancelled:\n                raise JobCancelledError(clip.name, 0)\n\n        t_start = time.monotonic()\n\n        # ── Phase 1: Load model ──\n        _status(\"Loading model...\")\n        with self._gpu_lock:\n            pipeline = self._get_videomama_pipeline()\n        _check_cancel(\"model load\")\n\n        alpha_dir = os.path.join(clip.root_path, \"AlphaHint\")\n        os.makedirs(alpha_dir, exist_ok=True)\n\n        # Don't report progress yet — phase status is showing in the status bar.\n        # Sending on_progress(0, N) would switch the status bar to frame-counter\n        # mode and overwrite the phase text on every tick.\n\n        # ── Phase 2: Load input frames ──\n        _status(\"Loading frames...\")\n        input_frames = self._load_frames_for_videomama(\n            clip.input_asset,\n            clip.name,\n            job=job,\n            on_status=on_status,\n        )\n        _check_cancel(\"frame load\")\n\n        # ── Phase 3: Load + stem-match masks ──\n        _status(\"Loading masks...\")\n        mask_stems: dict[str, np.ndarray] = {}\n        if clip.mask_asset.asset_type == \"sequence\":\n            mask_files = clip.mask_asset.get_frame_files()\n            for _i, fname in enumerate(mask_files):\n                _check_cancel(\"mask load\")\n                fpath = os.path.join(clip.mask_asset.path, 
fname)\n                m = cv2.imread(fpath, cv2.IMREAD_GRAYSCALE)\n                if m is not None:\n                    _, binary = cv2.threshold(m, 10, 255, cv2.THRESH_BINARY)\n                    stem = os.path.splitext(fname)[0]\n                    mask_stems[stem] = binary\n        else:\n            raw_masks = self._load_mask_frames_for_videomama(clip.mask_asset, clip.name)\n            for i, m in enumerate(raw_masks):\n                mask_stems[f\"frame_{i:06d}\"] = m\n\n        # Build output filenames from input stems\n        if clip.input_asset and clip.input_asset.asset_type == \"sequence\":\n            input_names = clip.input_asset.get_frame_files()\n        else:\n            input_names = [f\"frame_{i:06d}.png\" for i in range(len(input_frames))]\n\n        # Align masks to input frames by stem, defaulting to all-black\n        num_frames = len(input_frames)\n        mask_frames = []\n        for fname in input_names:\n            stem = os.path.splitext(fname)[0]\n            if stem in mask_stems:\n                mask_frames.append(mask_stems[stem])\n            else:\n                h_m, w_m = input_frames[0].shape[:2] if input_frames else (4, 4)\n                mask_frames.append(np.zeros((h_m, w_m), dtype=np.uint8))\n\n        # ── Resume logic ──\n        existing_alpha = []\n        if os.path.isdir(alpha_dir):\n            existing_alpha = [f for f in os.listdir(alpha_dir) if f.lower().endswith((\".png\", \".jpg\", \".jpeg\"))]\n        n_existing = len(existing_alpha)\n        completed_chunks = n_existing // chunk_size\n        start_chunk = max(0, completed_chunks - 1)\n        start_frame = start_chunk * chunk_size\n        if start_frame > 0:\n            keep = set()\n            for i in range(start_frame):\n                if i < len(input_names):\n                    stem = os.path.splitext(input_names[i])[0]\n                    keep.add(f\"{stem}.png\")\n            for fname in existing_alpha:\n                if fname 
not in keep:\n                    os.remove(os.path.join(alpha_dir, fname))\n            logger.info(\n                f\"VideoMaMa resuming for '{clip.name}': {n_existing} alpha frames existed, \"\n                f\"rolling back to chunk {start_chunk} (frame {start_frame})\"\n            )\n\n        # ── Phase 4: Inference (per-chunk) ──\n        sys.path.insert(0, os.path.join(BASE_DIR, \"VideoMaMaInferenceModule\"))\n        from VideoMaMaInferenceModule.inference import run_inference\n\n        total_chunks = (num_frames + chunk_size - 1) // chunk_size\n        _status(f\"Running inference (chunk 1/{total_chunks})...\")\n        frames_written = start_frame\n        for chunk_idx, chunk_output in enumerate(\n            run_inference(pipeline, input_frames, mask_frames, chunk_size=chunk_size)\n        ):\n            _check_cancel(\"inference\")\n\n            # Skip already-completed chunks (resume)\n            if chunk_idx < start_chunk:\n                frames_written += len(chunk_output)\n                if on_progress:\n                    on_progress(clip.name, frames_written, num_frames)\n                continue\n\n            _status(f\"Processing chunk {chunk_idx + 1}/{total_chunks}...\")\n\n            # Write chunk frames\n            t_chunk = time.monotonic()\n            for frame in chunk_output:\n                out_bgr = cv2.cvtColor(\n                    (np.clip(frame, 0.0, 1.0) * 255.0).astype(np.uint8),\n                    cv2.COLOR_RGB2BGR,\n                )\n                if frames_written < len(input_names):\n                    stem = os.path.splitext(input_names[frames_written])[0]\n                    out_name = f\"{stem}.png\"\n                else:\n                    out_name = f\"frame_{frames_written:06d}.png\"\n                out_path = os.path.join(alpha_dir, out_name)\n                cv2.imwrite(out_path, out_bgr)\n                frames_written += 1\n            chunk_elapsed = time.monotonic() - t_chunk\n          
  logger.debug(f\"Clip '{clip.name}' chunk {chunk_idx}: {len(chunk_output)} frames in {chunk_elapsed:.3f}s\")\n\n            if on_progress:\n                on_progress(clip.name, frames_written, num_frames)\n\n        # Refresh alpha asset\n        clip.alpha_asset = ClipAsset(alpha_dir, \"sequence\")\n\n        # Transition MASKED → READY\n        try:\n            clip.transition_to(ClipState.READY)\n        except Exception as e:\n            if on_warning:\n                on_warning(f\"State transition after VideoMaMa: {e}\")\n\n        elapsed = time.monotonic() - t_start\n        logger.info(f\"VideoMaMa complete for '{clip.name}': {frames_written} alpha frames in {elapsed:.1f}s\")\n\n    def _load_frames_for_videomama(\n        self,\n        asset: ClipAsset,\n        clip_name: str,\n        job: GPUJob | None = None,\n        on_status: Callable[[str], None] | None = None,\n    ) -> list[np.ndarray]:\n        \"\"\"Load input frames for VideoMaMa as uint8 RGB [0, 255].\n\n        The VideoMaMa inference code expects uint8 arrays for PIL conversion.\n        Reports loading progress via on_status and checks cancel via job.\n        \"\"\"\n        if asset.asset_type == \"video\":\n            raw = read_video_frames(asset.path)\n            return [(np.clip(f, 0.0, 1.0) * 255.0).astype(np.uint8) for f in raw]\n        frames = []\n        files = asset.get_frame_files()\n        total = len(files)\n        for i, fname in enumerate(files):\n            if job and job.is_cancelled:\n                from .errors import JobCancelledError\n\n                raise JobCancelledError(clip_name, i)\n            fpath = os.path.join(asset.path, fname)\n            img = read_image_frame(fpath, gamma_correct_exr=True)\n            if img is not None:\n                frames.append((np.clip(img, 0.0, 1.0) * 255.0).astype(np.uint8))\n            if on_status and i % 20 == 0 and i > 0:\n                on_status(f\"Loading frames ({i}/{total})...\")\n        return 
frames\n\n    def _load_mask_frames_for_videomama(self, asset: ClipAsset, clip_name: str) -> list[np.ndarray]:\n        \"\"\"Load mask frames for VideoMaMa as uint8 grayscale [0, 255].\n\n        The VideoMaMa inference code expects uint8 arrays for PIL conversion.\n        Binary threshold at 10: anything above → 255 (foreground), else → 0.\n        \"\"\"\n\n        def _threshold_mask(bgr_frame: np.ndarray) -> np.ndarray:\n            gray = cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2GRAY)\n            _, binary = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)\n            return binary  # uint8\n\n        if asset.asset_type == \"video\":\n            return read_video_frames(asset.path, processor=_threshold_mask)\n        masks = []\n        for fname in asset.get_frame_files():\n            fpath = os.path.join(asset.path, fname)\n            mask = cv2.imread(fpath, cv2.IMREAD_GRAYSCALE)\n            if mask is None:\n                continue\n            _, binary = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)\n            masks.append(binary)  # uint8\n        return masks\n"
  },
  {
    "path": "backend/validators.py",
    "content": "\"\"\"Validation utilities for frame processing.\n\nAll validators either return cleaned data or raise typed exceptions from errors.py.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport logging\nimport os\n\nimport numpy as np\n\nfrom .errors import (\n    FrameMismatchError,\n    FrameReadError,\n    MaskChannelError,\n    WriteFailureError,\n)\n\nlogger = logging.getLogger(__name__)\n\n\ndef validate_frame_counts(\n    clip_name: str,\n    input_count: int,\n    alpha_count: int,\n    strict: bool = False,\n) -> int:\n    \"\"\"Validate that input and alpha frame counts are compatible.\n\n    Args:\n        clip_name: For error messages.\n        input_count: Number of input frames.\n        alpha_count: Number of alpha frames.\n        strict: If True, raises on mismatch. If False, logs warning and returns min.\n\n    Returns:\n        The number of frames to process (min of both).\n\n    Raises:\n        FrameMismatchError: If strict=True and counts differ.\n    \"\"\"\n    if input_count != alpha_count:\n        if strict:\n            raise FrameMismatchError(clip_name, input_count, alpha_count)\n        logger.warning(\n            f\"Clip '{clip_name}': frame count mismatch — \"\n            f\"input has {input_count}, alpha has {alpha_count}. 
\"\n            f\"Truncating to {min(input_count, alpha_count)}.\"\n        )\n    return min(input_count, alpha_count)\n\n\ndef normalize_mask_channels(\n    mask: np.ndarray,\n    clip_name: str = \"\",\n    frame_index: int = 0,\n) -> np.ndarray:\n    \"\"\"Reduce a mask to a single-channel 2D array.\n\n    Handles any channel count: extracts first channel from multi-channel masks.\n\n    Args:\n        mask: Input mask array, any shape [H, W] or [H, W, C].\n        clip_name: For error messages.\n        frame_index: For error messages.\n\n    Returns:\n        2D numpy array [H, W] with float32 values.\n    \"\"\"\n    if mask.ndim == 3:\n        if mask.shape[2] == 0:\n            raise MaskChannelError(clip_name, frame_index, 0)\n        # Always extract first channel regardless of channel count\n        mask = mask[:, :, 0]\n    elif mask.ndim != 2:\n        raise MaskChannelError(clip_name, frame_index, mask.ndim)\n\n    return mask.astype(np.float32) if mask.dtype != np.float32 else mask\n\n\ndef normalize_mask_dtype(mask: np.ndarray) -> np.ndarray:\n    \"\"\"Convert mask to float32 [0.0, 1.0] from any common dtype.\"\"\"\n    if mask.dtype == np.uint8:\n        return mask.astype(np.float32) / 255.0\n    elif mask.dtype == np.uint16:\n        return mask.astype(np.float32) / 65535.0\n    elif mask.dtype == np.float64:\n        return mask.astype(np.float32)\n    elif mask.dtype == np.float32:\n        return mask\n    else:\n        return mask.astype(np.float32)\n\n\ndef validate_frame_read(\n    frame: np.ndarray | None,\n    clip_name: str,\n    frame_index: int,\n    path: str,\n) -> np.ndarray:\n    \"\"\"Validate that a frame was read successfully.\n\n    Args:\n        frame: The result of cv2.imread() — None if read failed.\n        clip_name: For error messages.\n        frame_index: For error messages.\n        path: File path that was read.\n\n    Returns:\n        The frame array (unchanged).\n\n    Raises:\n        FrameReadError: If frame 
is None.\n    \"\"\"\n    if frame is None:\n        raise FrameReadError(clip_name, frame_index, path)\n    return frame\n\n\ndef validate_write(\n    success: bool,\n    clip_name: str,\n    frame_index: int,\n    path: str,\n) -> None:\n    \"\"\"Validate that a cv2.imwrite() call succeeded.\n\n    Args:\n        success: Return value of cv2.imwrite().\n        clip_name: For error messages.\n        frame_index: For error messages.\n        path: File path that was written.\n\n    Raises:\n        WriteFailureError: If success is False.\n    \"\"\"\n    if not success:\n        raise WriteFailureError(clip_name, frame_index, path)\n\n\ndef ensure_output_dirs(clip_root: str) -> dict[str, str]:\n    \"\"\"Create output subdirectories for a clip and return their paths.\n\n    Returns:\n        Dict with keys: 'root', 'fg', 'matte', 'comp', 'processed'\n    \"\"\"\n    out_root = os.path.join(clip_root, \"Output\")\n    dirs = {\n        \"root\": out_root,\n        \"fg\": os.path.join(out_root, \"FG\"),\n        \"matte\": os.path.join(out_root, \"Matte\"),\n        \"comp\": os.path.join(out_root, \"Comp\"),\n        \"processed\": os.path.join(out_root, \"Processed\"),\n    }\n    for d in dirs.values():\n        os.makedirs(d, exist_ok=True)\n    return dirs\n"
  },
  {
    "path": "clip_manager.py",
    "content": "from __future__ import annotations\n\nimport argparse\nimport glob\nimport logging\nimport os\nimport shutil\nimport sys\nfrom dataclasses import dataclass\nfrom typing import TYPE_CHECKING, Callable\n\n# Enable OpenEXR support in OpenCV — needed for EXR I/O throughout the pipeline.\n# Must be set before any cv2.imread/imwrite calls on .exr files.\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"\n\nimport cv2\nimport numpy as np\n\nfrom backend.frame_io import EXR_WRITE_FLAGS, read_image_frame\nfrom device_utils import resolve_device\n\nif TYPE_CHECKING:\n    from gvm_core import GVMProcessor\nfrom BiRefNetModule.wrapper import BiRefNetHandler, usage_to_weights_file\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass InferenceSettings:\n    \"\"\"Settings for CorridorKey inference, extracted from interactive prompts.\n\n    Can be constructed directly for non-interactive use (Nuke, Houdini, batch scripts).\n    \"\"\"\n\n    input_is_linear: bool = False\n    despill_strength: float = 0.5  # 0.0–1.0\n    auto_despeckle: bool = True\n    despeckle_size: int = 400\n    refiner_scale: float = 1.0\n\n\n# Core Paths\nBASE_DIR = os.path.dirname(os.path.abspath(__file__))\nCLIPS_DIR = os.path.join(BASE_DIR, \"ClipsForInference\")\nOUTPUT_DIR = os.path.join(BASE_DIR, \"Output\")\n\n# Network Mapping\n# Windows Drive -> Linux Mount Point\nWIN_DRIVE_ROOT = \"V:\\\\\"\nLINUX_MOUNT_ROOT = \"/mnt/ssd-storage\"\n\n\n# --- Helpers ---\ndef is_image_file(filename: str) -> bool:\n    return filename.lower().endswith((\".png\", \".jpg\", \".jpeg\", \".exr\", \".tif\", \".tiff\", \".bmp\"))\n\n\ndef is_video_file(filename: str) -> bool:\n    return filename.lower().endswith((\".mp4\", \".mov\", \".avi\", \".mkv\"))\n\n\ndef map_path(win_path: str) -> str:\n    r\"\"\"\n    Converts a Windows path (example: V:\\Projects\\Shot1) to the local Linux path.\n    \"\"\"\n    # Normalize slashes\n    win_path = win_path.strip()\n\n    # Check if it starts with the 
drive letter\n    if win_path.upper().startswith(WIN_DRIVE_ROOT.upper()):\n        # Remove drive letter\n        rel_path = win_path[len(WIN_DRIVE_ROOT) :]\n        # Combine and flip slashes\n        linux_path = os.path.join(LINUX_MOUNT_ROOT, rel_path).replace(\"\\\\\", \"/\")\n        return linux_path\n\n    # If not on V:, maybe it's already a linux path or invalid?\n    return win_path\n\n\n# --- Classes ---\nclass ClipAsset:\n    def __init__(self, path: str, asset_type: str) -> None:\n        self.path = path\n        self.type = asset_type  # 'sequence' or 'video'\n        self.frame_count = 0\n        self._calculate_length()\n\n    def _calculate_length(self) -> None:\n        if self.type == \"sequence\":\n            files = sorted([f for f in os.listdir(self.path) if is_image_file(f)])\n            self.frame_count = len(files)\n        elif self.type == \"video\":\n            cap = cv2.VideoCapture(self.path)\n            if not cap.isOpened():\n                self.frame_count = 0\n            else:\n                self.frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n            cap.release()\n\n\nclass ClipEntry:\n    def __init__(self, name: str, root_path: str) -> None:\n        self.name = name\n        self.root_path = root_path\n        self.input_asset: ClipAsset | None = None\n        self.alpha_asset: ClipAsset | None = None\n\n    def find_assets(self) -> None:\n        # 1. 
Look for Input\n        input_dir = os.path.join(self.root_path, \"Input\")\n\n        # Check for directory first\n        if os.path.isdir(input_dir):\n            if not os.listdir(input_dir):\n                raise ValueError(f\"Clip '{self.name}': 'Input' directory is empty.\")\n            self.input_asset = ClipAsset(input_dir, \"sequence\")\n        else:\n            # Check for video file (Case-Insensitive)\n            candidates = glob.glob(os.path.join(self.root_path, \"[Ii]nput.*\"))\n            candidates = [c for c in candidates if is_video_file(c)]\n\n            if candidates:\n                self.input_asset = ClipAsset(candidates[0], \"video\")\n            else:\n                # Fallback: Look for ANY video file in the directory\n                all_files = glob.glob(os.path.join(self.root_path, \"*\"))\n                video_files = [f for f in all_files if is_video_file(f)]\n\n                if video_files:\n                    logger.info(f\"Clip '{self.name}': Using '{os.path.basename(video_files[0])}' as Input.\")\n                    self.input_asset = ClipAsset(video_files[0], \"video\")\n                else:\n                    raise ValueError(f\"Clip '{self.name}': No 'Input' directory or video file found.\")\n\n        if self.input_asset.frame_count == 0:\n            raise ValueError(f\"Clip '{self.name}': Input asset has 0 frames or could not be read.\")\n\n        # 2. 
Look for Alpha\n        # Check for 'AlphaHint' or 'alphahint' directory\n        alpha_dir_upper = os.path.join(self.root_path, \"AlphaHint\")\n        alpha_dir_lower = os.path.join(self.root_path, \"alphahint\")\n\n        target_alpha_dir = None\n        if os.path.isdir(alpha_dir_upper):\n            target_alpha_dir = alpha_dir_upper\n        elif os.path.isdir(alpha_dir_lower):\n            target_alpha_dir = alpha_dir_lower\n\n        if target_alpha_dir:\n            if not os.listdir(target_alpha_dir):\n                logger.warning(f\"Clip '{self.name}': AlphaHint directory exists but is empty. Marking for generation.\")\n                self.alpha_asset = None\n            else:\n                # Check for image sequence first\n                self.alpha_asset = ClipAsset(target_alpha_dir, \"sequence\")\n                if self.alpha_asset.frame_count == 0:\n                    # Fallback: check for video file inside the AlphaHint directory\n                    video_candidates = [f for f in os.listdir(target_alpha_dir) if is_video_file(f)]\n                    if video_candidates:\n                        self.alpha_asset = ClipAsset(os.path.join(target_alpha_dir, video_candidates[0]), \"video\")\n                    else:\n                        logger.warning(\n                            f\"Clip '{self.name}': AlphaHint directory has no valid image or video files.\"\n                            \" Marking for generation.\"\n                        )\n                        self.alpha_asset = None\n        else:\n            # Check for video file (Case-Insensitive)\n            # Match AlphaHint.* or alphahint.*\n            candidates = glob.glob(os.path.join(self.root_path, \"[Aa]lpha[Hh]int.*\"))\n            candidates = [c for c in candidates if is_video_file(c)]\n\n            if candidates:\n                self.alpha_asset = ClipAsset(candidates[0], \"video\")\n            else:\n                self.alpha_asset = None  # Missing, needs 
generation\n\n    def validate_pair(self) -> None:\n        if self.input_asset and self.alpha_asset:\n            if self.input_asset.frame_count != self.alpha_asset.frame_count:\n                raise ValueError(\n                    f\"Clip '{self.name}': Frame count mismatch! \"\n                    f\"Input: {self.input_asset.frame_count}, Alpha: {self.alpha_asset.frame_count}\"\n                )\n\n\n# --- Logic ---\n\n\ndef get_gvm_processor(device: str = \"cpu\") -> GVMProcessor:\n    try:\n        from gvm_core import GVMProcessor\n\n        return GVMProcessor(device=device)\n    except ImportError:\n        raise ImportError(\n            \"Could not import gvm_core. Please ensure 'gvm_core' is in the project root and requirements are installed.\"\n        ) from None\n    except Exception as e:\n        raise RuntimeError(f\"Failed to initialize GVM Processor: {e}\") from e\n\n\ndef generate_alphas(\n    clips,\n    device=None,\n    *,\n    on_clip_start: Callable[[str, int], None] | None = None,\n):\n    clips_to_process = [c for c in clips if c.alpha_asset is None]\n\n    if not clips_to_process:\n        logger.info(\"All clips have valid Alpha assets. No generation needed.\")\n        return\n\n    logger.info(f\"Found {len(clips_to_process)} clips missing Alpha.\")\n\n    if device is None:\n        device = resolve_device()\n\n    try:\n        processor = get_gvm_processor(device=device)\n    except ImportError as e:\n        logger.error(f\"GVM Import Error: {e}\")\n        logger.error(\"Skipping GVM generation. 
Please install GVM requirements if you wish to use this feature.\")\n        return\n    except Exception as e:\n        logger.error(f\"GVM Initialization Error: {e}\")\n        return\n\n    for clip in clips_to_process:\n        logger.info(f\"Generating Alpha for: {clip.name}\")\n        if on_clip_start:\n            on_clip_start(clip.name, len(clips_to_process))\n\n        alpha_output_dir = os.path.join(clip.root_path, \"AlphaHint\")\n        if os.path.exists(alpha_output_dir):\n            shutil.rmtree(alpha_output_dir)\n        os.makedirs(alpha_output_dir, exist_ok=True)\n\n        try:\n            processor.process_sequence(\n                input_path=clip.input_asset.path,\n                output_dir=None,\n                num_frames_per_batch=1,\n                decode_chunk_size=1,\n                denoise_steps=1,\n                mode=\"matte\",\n                write_video=False,\n                direct_output_dir=alpha_output_dir,\n            )\n\n            # Post-Process: Naming Convention\n            generated_files = sorted([f for f in os.listdir(alpha_output_dir) if f.endswith(\".png\")])\n\n            if not generated_files:\n                logger.error(f\"GVM finished but no PNGs found in {alpha_output_dir}\")\n                continue\n\n            if clip.input_asset.type == \"sequence\":\n                in_files = sorted([f for f in os.listdir(clip.input_asset.path) if is_image_file(f)])\n                stems = [os.path.splitext(f)[0] for f in in_files]\n            else:\n                base_name = os.path.splitext(os.path.basename(clip.input_asset.path))[0]\n                stems = [base_name] * len(generated_files)\n\n            for i, gvm_file in enumerate(generated_files):\n                if i >= len(stems):\n                    break\n\n                stem = stems[i]\n                new_name = f\"{stem}_alphaHint_{i:04d}.png\"\n\n                old_path = os.path.join(alpha_output_dir, gvm_file)\n                
new_path = os.path.join(alpha_output_dir, new_name)\n\n                if old_path != new_path:\n                    os.rename(old_path, new_path)\n\n            logger.info(f\"Saved {len(generated_files)} alpha frames to {alpha_output_dir}\")\n\n        except Exception as e:\n            logger.error(f\"Error generating alpha for {clip.name}: {e}\")\n            import traceback\n\n            traceback.print_exc()\n\n\ndef get_birefnet_usage_options():\n    return list(usage_to_weights_file.keys())\n\n\ndef run_birefnet(\n    clips,\n    device=None,\n    usage=\"General\",\n    dilate_radius=0,\n    *,\n    on_clip_start: Callable[[str, int], None] | None = None,\n    on_frame_complete: Callable[[int, int], None] | None = None,\n):\n    clips_to_process = [c for c in clips if c.alpha_asset is None]\n\n    if not clips_to_process:\n        logger.info(\"All clips have valid Alpha assets. No BiRefNet generation needed.\")\n        return\n\n    if device is None:\n        device = resolve_device()\n\n    logger.info(f\"Found {len(clips_to_process)} clips missing Alpha.\")\n\n    logger.info(f\"Initializing BiRefNet ({usage}) on {device}...\")\n    # Initialize the handler once\n    try:\n        handler = BiRefNetHandler(device=device, usage=usage)\n    except ImportError as e:\n        logger.error(f\"BiRefNet Import Error: {e}\")\n        return\n    except Exception as e:\n        logger.error(f\"BiRefNet Initialization Error: {e}\")\n        return\n\n    try:\n        for clip in clips_to_process:\n            logger.info(f\"Generating BiRefNet Alpha for: {clip.name}\")\n            if on_clip_start:\n                on_clip_start(clip.name, clip.input_asset.frame_count)\n\n            alpha_output_dir = os.path.join(clip.root_path, \"AlphaHint\")\n            os.makedirs(alpha_output_dir, exist_ok=True)\n\n            try:\n                handler.process(\n                    input_path=clip.input_asset.path,\n                    
alpha_output_dir=alpha_output_dir,\n                    dilate_radius=dilate_radius,\n                    on_frame_complete=on_frame_complete,\n                )\n                logger.info(f\"BiRefNet complete for {clip.name}\")\n            except Exception as e:\n                logger.error(f\"BiRefNet failed for {clip.name}: {e}\")\n                import traceback\n\n                traceback.print_exc()\n\n    finally:\n        handler.cleanup()\n\n\ndef run_videomama(\n    clips: list[ClipEntry],\n    chunk_size: int = 50,\n    device: str | None = None,\n    *,\n    on_clip_start: Callable[[str, int], None] | None = None,\n    on_frame_complete: Callable[[int, int], None] | None = None,\n) -> None:\n    \"\"\"\n    Runs VideoMaMa on clips that have VideoMamaMaskHint but NO AlphaHint.\n    \"\"\"\n    # Process if:\n    # 1. Has VideoMamaMaskHint (File or Folder, Case-Insensitive)\n    # 2. AND (Alpha is Missing OR Alpha is a Video File we want to upgrade)\n\n    clips_to_process = []\n    clip_mask_paths = {}  # Store the resolved mask path for each clip\n\n    for c in clips:\n        # Search for 'videomamamaskhint' asset (Strict: videomamamaskhint.ext or VideoMamaMaskHint/)\n        candidates = []\n        for f in os.listdir(c.root_path):\n            stem, _ = os.path.splitext(f)\n            if stem.lower() == \"videomamamaskhint\":\n                candidates.append(f)\n\n        mask_asset_path = None\n        has_mask = False\n\n        # Priority: Directory > Video File\n        # Check directories first\n        for cand in candidates:\n            path = os.path.join(c.root_path, cand)\n            if os.path.isdir(path) and len(os.listdir(path)) > 0:\n                has_mask = True\n                mask_asset_path = path\n                break\n\n        # If no directory, check files\n        if not has_mask:\n            for cand in candidates:\n                path = os.path.join(c.root_path, cand)\n                if 
os.path.isfile(path) and is_video_file(path):\n                    has_mask = True\n                    mask_asset_path = path\n                    break\n\n        if not has_mask:\n            continue\n\n        # Store for later\n        clip_mask_paths[c.name] = mask_asset_path\n\n        if c.alpha_asset is None:\n            clips_to_process.append(c)\n        elif c.alpha_asset.type == \"video\":\n            clips_to_process.append(c)\n\n    if not clips_to_process:\n        logger.info(\"No candidates for VideoMaMa (looking for VideoMamaMaskHint + [NoAlpha OR VideoAlpha]).\")\n        return\n\n    logger.info(f\"Found {len(clips_to_process)} clips for VideoMaMa processing.\")\n\n    # Import locally — sys.path mutation is needed because VideoMaMaInferenceModule\n    # uses intra-package imports that assume its directory is on the path.\n    try:\n        sys.path.append(os.path.join(BASE_DIR, \"VideoMaMaInferenceModule\"))\n        from VideoMaMaInferenceModule.inference import load_videomama_model\n        from VideoMaMaInferenceModule.inference import run_inference as run_videomama_frames\n    except ImportError as e:\n        logger.error(f\"Failed to import VideoMaMa: {e}\")\n        return\n\n    if device is None:\n        device = resolve_device()\n\n    logger.info(\"Loading VideoMaMa Pipeline...\")\n    pipeline = load_videomama_model(device=device)\n\n    for clip in clips_to_process:\n        logger.info(f\"Running VideoMaMa on: {clip.name}\")\n        if on_clip_start:\n            on_clip_start(clip.name, len(clips_to_process))\n\n        # Retrieve resolved path\n        mask_hint_path = clip_mask_paths[clip.name]\n        logger.info(f\"  Using VideoMamaMaskHint: {os.path.basename(mask_hint_path)}\")\n\n        alpha_output_dir = os.path.join(clip.root_path, \"AlphaHint\")\n\n        # Load Inputs\n        # 1. 
Input Frames (RGB)\n        input_frames = []\n        if clip.input_asset.type == \"video\":\n            cap = cv2.VideoCapture(clip.input_asset.path)\n            while True:\n                ret, frame = cap.read()\n                if not ret:\n                    break\n                input_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))\n            cap.release()\n        else:\n            files = sorted([f for f in os.listdir(clip.input_asset.path) if is_image_file(f)])\n            for f in files:\n                fpath = os.path.join(clip.input_asset.path, f)\n                # Handle EXR (Float 0-1) vs Standard (Int 0-255)\n                if f.lower().endswith(\".exr\"):\n                    img = cv2.imread(fpath, cv2.IMREAD_UNCHANGED)\n                    if img is not None:\n                        # Normalize Float 0-1\n                        img = np.clip(img, 0.0, 1.0)\n                        # Linear -> sRGB (Gamma 2.2 Approximation) for VideoMaMa\n                        img = img ** (1.0 / 2.2)\n                        # 0-255 uint8\n                        img = (img * 255.0).astype(np.uint8)\n                        # Ensure 3 channels\n                        if img.ndim == 2:\n                            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)\n                        elif img.shape[2] == 4:\n                            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)\n                else:\n                    img = cv2.imread(fpath)\n\n                if img is not None:\n                    input_frames.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))\n\n        # 2. 
Mask Frames\n        mask_frames = []\n\n        # Check if VideoMamaMaskHint is a directory or a file (video)\n        if os.path.isdir(mask_hint_path):\n            # Directory of Images\n            mask_files = sorted([f for f in os.listdir(mask_hint_path) if is_image_file(f)])\n            for f in mask_files:\n                fpath = os.path.join(mask_hint_path, f)\n                m = None\n\n                # Handle EXR Masks\n                if f.lower().endswith(\".exr\"):\n                    m = cv2.imread(fpath, cv2.IMREAD_UNCHANGED)\n                    if m is not None:\n                        if m.ndim == 3:\n                            m = m[:, :, 0]\n                        m = np.clip(m, 0.0, 1.0)\n                        m = (m * 255.0).astype(np.uint8)\n                else:\n                    # Standard Masks\n                    m = cv2.imread(fpath, cv2.IMREAD_GRAYSCALE)\n\n                if m is not None:\n                    # Force Binary Thresholding\n                    _, m = cv2.threshold(m, 10, 255, cv2.THRESH_BINARY)\n                    mask_frames.append(m)\n\n        elif os.path.isfile(mask_hint_path):\n            # Handle Video File\n            cap = cv2.VideoCapture(mask_hint_path)\n            while True:\n                ret, frame = cap.read()\n                if not ret:\n                    break\n                # Convert to Grayscale\n                m = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)\n                # Force Binary Thresholding\n                _, m = cv2.threshold(m, 10, 255, cv2.THRESH_BINARY)\n                mask_frames.append(m)\n            cap.release()\n\n        # Validate Lengths\n        num_frames = min(len(input_frames), len(mask_frames))\n        input_frames = input_frames[:num_frames]\n        mask_frames = mask_frames[:num_frames]\n\n        if num_frames == 0:\n            logger.error(f\"Skipping {clip.name}: No valid frame pairs found.\")\n            continue\n\n        # Run 
Inference\n        try:\n            # Prepare Output Directory First\n            # Logic: If it exists as a FILE (legacy/error), delete it.\n            if os.path.exists(alpha_output_dir) and not os.path.isdir(alpha_output_dir):\n                logger.warning(f\"Removing file '{alpha_output_dir}' to create directory.\")\n                os.remove(alpha_output_dir)\n\n            # If there was a Video Alpha Asset (e.g. AlphaHint.mp4), rename it to backup so it doesn't conflict\n            if clip.alpha_asset and clip.alpha_asset.type == \"video\":\n                old_path = clip.alpha_asset.path\n                if os.path.exists(old_path):\n                    dir_name = os.path.dirname(old_path)\n                    base, ext = os.path.splitext(os.path.basename(old_path))\n                    backup_path = os.path.join(dir_name, f\"{base}_backup{ext}\")\n                    logger.info(\n                        f\"Backing up existing Alpha Video: \"\n                        f\"{os.path.basename(old_path)} -> {os.path.basename(backup_path)}\"\n                    )\n                    os.rename(old_path, backup_path)\n                    # Clear it from memory so we rely on the new one\n                    clip.alpha_asset = None\n\n            os.makedirs(alpha_output_dir, exist_ok=True)\n\n            # Name setup\n            if clip.input_asset.type == \"sequence\":\n                in_names = sorted(\n                    [os.path.splitext(f)[0] for f in os.listdir(clip.input_asset.path) if is_image_file(f)]\n                )\n            else:\n                stem = os.path.splitext(os.path.basename(clip.input_asset.path))[0]\n                in_names = [f\"{stem}_{i:05d}\" for i in range(num_frames)]\n\n            total_saved = 0\n\n            # Iterate generator\n            for chunk_frames in run_videomama_frames(pipeline, input_frames, mask_frames, chunk_size=chunk_size):\n                for frame in chunk_frames:\n                    if 
total_saved >= len(in_names):\n                        break\n\n                    name = in_names[total_saved]\n                    out_path = os.path.join(alpha_output_dir, f\"{name}.png\")\n\n                    # Convert to BGR and Save\n                    cv2.imwrite(out_path, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))\n                    total_saved += 1\n\n                    if on_frame_complete:\n                        on_frame_complete(total_saved, num_frames)\n\n                logger.info(f\"  Saved {total_saved}/{num_frames} frames...\")\n\n            logger.info(f\"VideoMaMa Complete: Saved {total_saved} frames to AlphaHint.\")\n\n            # Update clip state in memory (dummy) - re-scan will pick it up properly\n            clip.alpha_asset = ClipAsset(alpha_output_dir, \"sequence\")\n\n        except Exception as e:\n            logger.error(f\"VideoMaMa failed for {clip.name}: {e}\")\n            import traceback\n\n            traceback.print_exc()\n\n\ndef run_inference(\n    clips,\n    device=None,\n    backend=None,\n    max_frames=None,\n    skip_existing=False,\n    settings: InferenceSettings | None = None,\n    *,\n    on_clip_start: Callable[[str, int], None] | None = None,\n    on_frame_complete: Callable[[int, int], None] | None = None,\n):\n    ready_clips = [c for c in clips if c.input_asset and c.alpha_asset]\n\n    if not ready_clips:\n        logger.info(\"No clips found with both Input and Alpha assets. 
Run generate_coarse_alpha first?\")\n        return\n\n    logger.info(f\"Found {len(ready_clips)} clips ready for inference.\")\n\n    # Backward compat for callers that don't pass settings\n    if settings is None:\n        settings = InferenceSettings()\n\n    # Ensure Output Directory exists\n    if not os.path.exists(OUTPUT_DIR):\n        os.makedirs(OUTPUT_DIR, exist_ok=True)\n\n    import numpy as np\n\n    if device is None:\n        device = resolve_device()\n    from CorridorKeyModule.backend import create_engine\n\n    engine = create_engine(backend=backend, device=device)\n\n    for clip in ready_clips:\n        logger.info(f\"Running Inference on: {clip.name}\")\n\n        # Setup Outputs in ClipFolder/Output/...\n        clip_out_root = os.path.join(clip.root_path, \"Output\")\n        fg_dir = os.path.join(clip_out_root, \"FG\")\n        matte_dir = os.path.join(clip_out_root, \"Matte\")\n        comp_dir = os.path.join(clip_out_root, \"Comp\")\n        proc_dir = os.path.join(clip_out_root, \"Processed\")\n\n        for d in [fg_dir, matte_dir, comp_dir, proc_dir]:\n            os.makedirs(d, exist_ok=True)\n\n        num_frames = min(clip.input_asset.frame_count, clip.alpha_asset.frame_count)\n        if max_frames is not None:\n            num_frames = min(num_frames, max_frames)\n        logger.info(\n            f\"  Input frames: {clip.input_asset.frame_count},\"\n            f\" Alpha frames: {clip.alpha_asset.frame_count} -> Processing {num_frames} frames\"\n        )\n\n        if num_frames == 0:\n            logger.warning(f\"Clip '{clip.name}': 0 frames to process, skipping.\")\n            continue\n\n        input_cap = None\n        alpha_cap = None\n        input_files = []\n        alpha_files = []\n\n        if clip.input_asset.type == \"video\":\n            input_cap = cv2.VideoCapture(clip.input_asset.path)\n        else:\n            input_files = sorted([f for f in os.listdir(clip.input_asset.path) if is_image_file(f)])\n\n     
   if clip.alpha_asset.type == \"video\":\n            alpha_cap = cv2.VideoCapture(clip.alpha_asset.path)\n        else:\n            alpha_files = sorted([f for f in os.listdir(clip.alpha_asset.path) if is_image_file(f)])\n\n        if on_clip_start:\n            on_clip_start(clip.name, num_frames)\n\n        skipped_count = 0\n\n        for i in range(num_frames):\n            # Pre-compute output stem for skip-existing check (mirrors how input_stem\n            # is set later: video -> zero-padded index, sequence -> file stem)\n            if clip.input_asset.type == \"video\":\n                expected_stem = f\"{i:05d}\"\n            else:\n                expected_stem = os.path.splitext(input_files[i])[0]\n\n            if skip_existing and os.path.exists(os.path.join(comp_dir, f\"{expected_stem}.png\")):\n                logger.debug(\"Frame %d already rendered, skipping.\", i)\n                skipped_count += 1\n                if on_frame_complete:\n                    on_frame_complete(i, num_frames)\n                continue\n\n            # 1. 
Read Input\n            img_srgb = None\n            input_stem = f\"{i:05d}\"\n\n            # Use the settings-defined gamma\n            input_is_linear = settings.input_is_linear\n\n            if input_cap:\n                ret, frame = input_cap.read()\n                if not ret:\n                    break\n                img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n                img_srgb = img_rgb.astype(np.float32) / 255.0\n                input_stem = f\"{i:05d}\"\n            else:\n                fpath = os.path.join(clip.input_asset.path, input_files[i])\n                input_stem = os.path.splitext(input_files[i])[0]\n\n                is_exr = fpath.lower().endswith(\".exr\")\n                if is_exr:\n                    img_srgb = read_image_frame(fpath, gamma_correct_exr=not input_is_linear)\n                    if img_srgb is None:\n                        continue\n                else:\n                    img_bgr = cv2.imread(fpath)\n                    if img_bgr is None:\n                        continue\n                    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)\n                    img_srgb = img_rgb.astype(np.float32) / 255.0\n\n            # 2. 
Read Alpha (Mask)\n            mask_linear = None\n            if alpha_cap:\n                ret, frame = alpha_cap.read()\n                if not ret:\n                    break\n                mask_linear = frame[:, :, 2].astype(np.float32) / 255.0\n            else:\n                fpath = os.path.join(clip.alpha_asset.path, alpha_files[i])\n                mask_in = cv2.imread(fpath, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_UNCHANGED)\n\n                if mask_in is None:\n                    continue\n\n                if mask_in.ndim == 3:\n                    if mask_in.shape[2] == 3:\n                        mask_linear = mask_in[:, :, 0]\n                    else:\n                        mask_linear = mask_in\n                else:\n                    mask_linear = mask_in\n\n                if mask_linear.dtype == np.uint8:\n                    mask_linear = mask_linear.astype(np.float32) / 255.0\n                elif mask_linear.dtype == np.uint16:\n                    mask_linear = mask_linear.astype(np.float32) / 65535.0\n                else:\n                    mask_linear = mask_linear.astype(np.float32)\n\n            if mask_linear.shape[:2] != img_srgb.shape[:2]:\n                mask_linear = cv2.resize(\n                    mask_linear, (img_srgb.shape[1], img_srgb.shape[0]), interpolation=cv2.INTER_LINEAR\n                )\n\n            # 3. Process\n            USE_STRAIGHT_MODEL = True\n            res = engine.process_frame(\n                img_srgb,\n                mask_linear,\n                input_is_linear=input_is_linear,\n                fg_is_straight=USE_STRAIGHT_MODEL,\n                despill_strength=settings.despill_strength,\n                auto_despeckle=settings.auto_despeckle,\n                despeckle_size=settings.despeckle_size,\n                refiner_scale=settings.refiner_scale,\n            )\n\n            pred_fg = res[\"fg\"]  # sRGB\n            pred_alpha = res[\"alpha\"]  # Linear\n\n            # 4. 
Save (EXR half-float, PXR24 compression — see backend/frame_io.py)\n\n            # Save FG\n            # pred_fg is RGB 0-1 float. Convert to BGR for OpenCV\n            fg_bgr = cv2.cvtColor(pred_fg, cv2.COLOR_RGB2BGR)\n            cv2.imwrite(os.path.join(fg_dir, f\"{input_stem}.exr\"), fg_bgr, EXR_WRITE_FLAGS)\n\n            # Save Matte\n            if pred_alpha.ndim == 3:\n                pred_alpha = pred_alpha[:, :, 0]\n            # Matte is single channel linear float\n            cv2.imwrite(os.path.join(matte_dir, f\"{input_stem}.exr\"), pred_alpha, EXR_WRITE_FLAGS)\n\n            # 5. Generate Reference Comp\n            comp_srgb = res[\"comp\"]\n            # Save Comp (PNG 8-bit)\n            comp_bgr = cv2.cvtColor((np.clip(comp_srgb, 0.0, 1.0) * 255.0).astype(np.uint8), cv2.COLOR_RGB2BGR)\n            cv2.imwrite(os.path.join(comp_dir, f\"{input_stem}.png\"), comp_bgr)\n\n            # 6. Save Processed (RGBA EXR)\n            if \"processed\" in res:\n                # Result is RGBA\n                proc_rgba = res[\"processed\"]\n                # Convert to BGRA for OpenCV\n                proc_bgra = cv2.cvtColor(proc_rgba, cv2.COLOR_RGBA2BGRA)\n                cv2.imwrite(os.path.join(proc_dir, f\"{input_stem}.exr\"), proc_bgra, EXR_WRITE_FLAGS)\n\n            if on_frame_complete:\n                on_frame_complete(i, num_frames)\n\n        if input_cap:\n            input_cap.release()\n        if alpha_cap:\n            alpha_cap.release()\n\n        if skip_existing and skipped_count > 0:\n            logger.info(\n                \"  Skipped %d of %d frames (outputs already exist).\",\n                skipped_count,\n                num_frames,\n            )\n\n        # 7. 
Stitch comp frames into MP4 (if input was video)\n        if clip.input_asset and clip.input_asset.type == \"video\":\n            try:\n                from backend.ffmpeg_tools import find_ffmpeg, probe_video, stitch_video\n\n                if find_ffmpeg():\n                    # Get source fps\n                    try:\n                        video_info = probe_video(clip.input_asset.path)\n                        fps = video_info.get(\"fps\", 24.0)\n                    except Exception:\n                        fps = 24.0\n\n                    comp_video_path = os.path.join(clip_out_root, f\"{clip.name}_comp.mp4\")\n\n                    # Detect frame pattern from saved files\n                    comp_files = sorted(f for f in os.listdir(comp_dir) if f.endswith(\".png\"))\n                    if comp_files:\n                        # Frames are named {input_stem}.png — e.g. 00000.png\n                        # Build ffmpeg pattern from first file\n                        first = comp_files[0]\n                        stem = os.path.splitext(first)[0]\n                        if stem.isdigit():\n                            pattern = f\"%0{len(stem)}d.png\"\n                        else:\n                            pattern = \"frame_%06d.png\"\n\n                        logger.info(f\"Stitching comp video: {comp_dir} -> {comp_video_path} @ {fps} fps\")\n                        stitch_video(comp_dir, comp_video_path, fps=fps, pattern=pattern)\n                    else:\n                        logger.warning(f\"No comp frames found in {comp_dir}, skipping video stitch.\")\n                else:\n                    logger.info(\"ffmpeg not found — skipping comp video stitch.\")\n            except Exception as e:\n                logger.warning(f\"Comp video stitch failed (non-fatal): {e}\")\n\n        logger.info(f\"Clip {clip.name} Complete.\")\n\n\ndef organize_target(target_dir: str) -> None:\n    \"\"\"\n    Organizes a specific folder.\n    1. 
If loose video -> Rename to Input.ext (if safe).\n    2. If sequence -> Move to Input/.\n    3. Ensure AlphaHint and VideoMamaMaskHint folders exist.\n    \"\"\"\n    logger.info(f\"Organizing Target: {target_dir}\")\n\n    if not os.path.exists(target_dir):\n        logger.error(f\"Target directory not found: {target_dir}\")\n        return\n\n    # Check for loose video\n    # Strategy: Find largest video file that ISN'T named Input.*\n    candidates = [f for f in os.listdir(target_dir) if is_video_file(f)]\n    candidates = [f for f in candidates if not os.path.splitext(f)[0].lower() == \"input\"]\n\n    if candidates and not os.path.exists(os.path.join(target_dir, \"Input\")):\n        # If multiple, pick largest (heuristic for 'Main Plate')\n        candidates.sort(key=lambda f: os.path.getsize(os.path.join(target_dir, f)), reverse=True)\n        main_clip = candidates[0]\n        ext = os.path.splitext(main_clip)[1]\n\n        try:\n            shutil.move(os.path.join(target_dir, main_clip), os.path.join(target_dir, f\"Input{ext}\"))\n            logger.info(f\"Renamed '{main_clip}' to 'Input{ext}'\")\n        except Exception as e:\n            logger.error(f\"Failed to rename '{main_clip}': {e}\")\n\n    # Check for Image Sequence (Flat)\n    # Only if Input folder doesn't exist and Input video doesn't exist\n    has_input_dir = os.path.isdir(os.path.join(target_dir, \"Input\"))\n    has_input_video = any(\n        is_video_file(f) and os.path.basename(f).lower().startswith(\"input\") for f in os.listdir(target_dir)\n    )\n\n    if not has_input_dir and not has_input_video:\n        all_files = sorted(glob.glob(os.path.join(target_dir, \"*\")))\n        image_files = [f for f in all_files if is_image_file(f)]\n\n        if len(image_files) > 0:\n            try:\n                input_subdir = os.path.join(target_dir, \"Input\")\n                os.makedirs(input_subdir)\n                for img in image_files:\n                    shutil.move(img, 
os.path.join(input_subdir, os.path.basename(img)))\n                logger.info(\n                    f\"Organized: Moved {len(image_files)} images in '{os.path.basename(target_dir)}' to 'Input/'\"\n                )\n            except Exception as e:\n                logger.error(f\"Failed to organize sequence in '{target_dir}': {e}\")\n\n    # Create Hints\n    for hint in [\"AlphaHint\", \"VideoMamaMaskHint\"]:\n        hint_path = os.path.join(target_dir, hint)\n        if not os.path.exists(hint_path):\n            os.makedirs(hint_path)\n\n\ndef organize_clips(clips_dir: str) -> None:\n    \"\"\"\n    Legacy wrapper for backward compatibility with 'ClipsForInference' folder.\n    Organizes all subfolders in the given directory using the new logic.\n    \"\"\"\n    if not os.path.exists(clips_dir):\n        logger.warning(f\"Clips directory not found: {clips_dir}\")\n        return\n\n    logger.info(f\"Organizing Clips Directory: {clips_dir}\")\n\n    # Check for loose videos in root\n    loose_videos = [f for f in os.listdir(clips_dir) if is_video_file(f) and os.path.isfile(os.path.join(clips_dir, f))]\n\n    # Organize loose videos first\n    for v in loose_videos:\n        clip_name = os.path.splitext(v)[0]\n        ext = os.path.splitext(v)[1]\n        target_folder = os.path.join(clips_dir, clip_name)\n\n        if os.path.exists(target_folder):\n            logger.warning(f\"Skipping loose video '{v}': Target folder '{clip_name}' already exists.\")\n            continue\n\n        try:\n            os.makedirs(target_folder)\n            target_file = os.path.join(target_folder, f\"Input{ext}\")\n            shutil.move(os.path.join(clips_dir, v), target_file)\n            logger.info(f\"Organized: Moved '{v}' to '{clip_name}/Input{ext}'\")\n\n            # Also initialize hints immediately\n            for hint in [\"AlphaHint\", \"VideoMamaMaskHint\"]:\n                os.makedirs(os.path.join(target_folder, hint), exist_ok=True)\n        except 
Exception as e:\n            logger.error(f\"Failed to organize video '{v}': {e}\")\n\n    # Now iterate all subdirectories and run organize_target\n    for entry in os.listdir(clips_dir):\n        full_path = os.path.join(clips_dir, entry)\n        if os.path.isdir(full_path) and entry not in [\"IgnoredClips\", \"Output\"]:\n            organize_target(full_path)\n\n\ndef scan_clips() -> list[ClipEntry]:\n    if not os.path.exists(CLIPS_DIR):\n        os.makedirs(CLIPS_DIR, exist_ok=True)\n        return []\n\n    # Auto-organize first\n    organize_clips(CLIPS_DIR)\n\n    clip_dirs = [d for d in os.listdir(CLIPS_DIR) if os.path.isdir(os.path.join(CLIPS_DIR, d))]\n\n    valid_clips = []\n    invalid_clips = []\n\n    for d in clip_dirs:\n        if d.startswith(\".\") or d.startswith(\"_\") or d == \"IgnoredClips\":\n            continue\n\n        full_path = os.path.join(CLIPS_DIR, d)\n\n        try:\n            entry = ClipEntry(d, full_path)\n            entry.find_assets()\n            entry.validate_pair()\n            valid_clips.append(entry)\n        except ValueError as ve:\n            invalid_clips.append((d, str(ve)))\n        except Exception as e:\n            invalid_clips.append((d, f\"Unexpected error: {e}\"))\n\n    if invalid_clips:\n        logger.warning(\"INVALID OR SKIPPED CLIPS:\")\n        for name, reason in invalid_clips:\n            logger.warning(\"  - %s: %s\", name, reason)\n    else:\n        logger.info(\"All clip folders appear valid.\")\n\n    return valid_clips\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"CorridorKey Clip Manager\")\n    parser.add_argument(\"--action\", choices=[\"generate_alphas\", \"run_inference\", \"list\", \"wizard\"], required=True)\n    parser.add_argument(\"--win_path\", help=r\"Windows Path (example: V:\\...) 
for Wizard Mode\", default=None)\n    parser.add_argument(\n        \"--device\",\n        choices=[\"auto\", \"cuda\", \"mps\", \"cpu\"],\n        default=\"auto\",\n        help=\"Compute device (default: auto-detect CUDA > MPS > CPU)\",\n    )\n    parser.add_argument(\n        \"--backend\",\n        choices=[\"auto\", \"torch\", \"mlx\"],\n        default=\"auto\",\n        help=\"Inference backend (default: auto-detect MLX on Apple Silicon, else Torch)\",\n    )\n    parser.add_argument(\n        \"--max-frames\",\n        type=int,\n        default=None,\n        help=\"Limit number of frames to process per clip (e.g. 1 for first frame only)\",\n    )\n\n    args = parser.parse_args()\n\n    device = resolve_device(args.device)\n    logger.info(f\"Using device: {device}\")\n\n    if args.action == \"list\":\n        scan_clips()\n    elif args.action == \"generate_alphas\":\n        clips = scan_clips()\n        generate_alphas(clips, device=device)\n    elif args.action == \"run_inference\":\n        clips = scan_clips()\n        run_inference(clips, device=device, backend=args.backend, max_frames=args.max_frames)\n    elif args.action == \"wizard\":\n        if not args.win_path:\n            print(\"Error: --win_path required for wizard.\")\n        else:\n            raise NotImplementedError(\"interactive_wizard is not yet implemented\")\n"
  },
  {
    "path": "corridorkey_cli.py",
    "content": "\"\"\"CorridorKey command-line interface and interactive wizard.\n\nThis module handles CLI subcommands, environment setup, and the\ninteractive wizard workflow. The pipeline logic lives in clip_manager.py,\nwhich can be imported independently as a library.\n\nUsage:\n    uv run corridorkey wizard \"V:\\\\...\"\n    uv run corridorkey run-inference\n    uv run corridorkey generate-alphas\n    uv run corridorkey list-clips\n\"\"\"\n\nfrom __future__ import annotations\n\nimport glob\nimport logging\nimport os\nimport shutil\nimport sys\nimport warnings\nfrom typing import Annotated, Optional\n\nimport typer\nfrom rich.console import Console\nfrom rich.logging import RichHandler\nfrom rich.panel import Panel\nfrom rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TaskID, TextColumn, TimeElapsedColumn\nfrom rich.prompt import Confirm, IntPrompt, Prompt\nfrom rich.table import Table\n\nfrom clip_manager import (\n    LINUX_MOUNT_ROOT,\n    ClipEntry,\n    InferenceSettings,\n    generate_alphas,\n    get_birefnet_usage_options,\n    is_video_file,\n    map_path,\n    organize_target,\n    run_birefnet,\n    run_inference,\n    run_videomama,\n    scan_clips,\n)\nfrom device_utils import resolve_device\n\nlogger = logging.getLogger(__name__)\nconsole = Console()\n\napp = typer.Typer(\n    name=\"corridorkey\",\n    help=\"Neural network green screen keying for professional VFX pipelines.\",\n    rich_markup_mode=\"rich\",\n    no_args_is_help=True,\n)\n\n\n# ---------------------------------------------------------------------------\n# Environment setup\n# ---------------------------------------------------------------------------\n\n\ndef _configure_environment() -> None:\n    \"\"\"Set up logging and warnings for interactive CLI use.\"\"\"\n    warnings.filterwarnings(\"ignore\", category=FutureWarning)\n    warnings.filterwarnings(\"ignore\", category=UserWarning, module=\"torch\")\n    logging.basicConfig(\n        
level=logging.INFO,\n        format=\"%(message)s\",\n        datefmt=\"[%X]\",\n        handlers=[RichHandler(console=console, rich_tracebacks=True)],\n    )\n\n\n# ---------------------------------------------------------------------------\n# Progress helpers (callback protocol → rich.progress)\n# ---------------------------------------------------------------------------\n\n\nclass ProgressContext:\n    \"\"\"Context manager bridging clip_manager callbacks to Rich progress bars.\n\n    clip_manager's callback protocol doesn't know about Rich, so this class\n    owns the Progress instance and exposes bound methods as callbacks.\n    ``__exit__`` always cleans up, even if inference raises.\n    \"\"\"\n\n    def __init__(self) -> None:\n        self._progress = Progress(\n            SpinnerColumn(),\n            TextColumn(\"[progress.description]{task.description}\"),\n            BarColumn(),\n            MofNCompleteColumn(),\n            TimeElapsedColumn(),\n            console=console,\n        )\n        self._frame_task_id: TaskID | None = None\n\n    def __enter__(self) -> \"ProgressContext\":\n        self._progress.__enter__()\n        return self\n\n    def __exit__(self, *exc: object) -> None:\n        self._progress.__exit__(*exc)\n\n    def on_clip_start(self, clip_name: str, num_frames: int) -> None:\n        \"\"\"Callback: reset the progress bar for a new clip.\"\"\"\n        if self._frame_task_id is not None:\n            self._progress.remove_task(self._frame_task_id)\n        self._frame_task_id = self._progress.add_task(f\"[cyan]{clip_name}\", total=num_frames)\n\n    def on_frame_complete(self, frame_idx: int, num_frames: int) -> None:\n        \"\"\"Callback: advance the progress bar by one frame.\"\"\"\n        if self._frame_task_id is not None:\n            self._progress.advance(self._frame_task_id)\n\n\ndef _on_clip_start_log_only(clip_name: str, total_clips: int) -> None:\n    \"\"\"Clip-level callback for generate-alphas.\n\n    
Unlike ProgressContext.on_clip_start (frame-level granularity with a Rich\n    task per clip), GVM has no per-frame progress so we just log.\n    \"\"\"\n    console.print(f\"  Processing [bold]{clip_name}[/bold] ({total_clips} total)\")\n\n\n# ---------------------------------------------------------------------------\n# Inference settings prompt (rich.prompt — CLI layer only)\n# ---------------------------------------------------------------------------\n\n\ndef _prompt_inference_settings(\n    *,\n    default_linear: bool | None = None,\n    default_despill: int | None = None,\n    default_despeckle: bool | None = None,\n    default_despeckle_size: int | None = None,\n    default_refiner: float | None = None,\n) -> InferenceSettings:\n    \"\"\"Interactively prompt for inference settings, skipping any pre-filled values.\"\"\"\n    console.print(Panel(\"Inference Settings\", style=\"bold cyan\"))\n\n    if default_linear is not None:\n        input_is_linear = default_linear\n    else:\n        gamma_choice = Prompt.ask(\n            \"Input colorspace\",\n            choices=[\"linear\", \"srgb\"],\n            default=\"srgb\",\n        )\n        input_is_linear = gamma_choice == \"linear\"\n\n    if default_despill is not None:\n        despill_int = max(0, min(10, default_despill))\n    else:\n        despill_int = IntPrompt.ask(\n            \"Despill strength (0–10, 10 = max despill)\",\n            default=5,\n        )\n        despill_int = max(0, min(10, despill_int))\n    despill_strength = despill_int / 10.0\n\n    if default_despeckle is not None:\n        auto_despeckle = default_despeckle\n    else:\n        auto_despeckle = Confirm.ask(\n            \"Enable auto-despeckle (removes tracking dots)?\",\n            default=True,\n        )\n\n    despeckle_size = default_despeckle_size if default_despeckle_size is not None else 400\n    if auto_despeckle and default_despeckle_size is None and default_despeckle is None:\n        despeckle_size = 
IntPrompt.ask(\n            \"Despeckle size (min pixels for a spot)\",\n            default=400,\n        )\n        despeckle_size = max(0, despeckle_size)\n\n    if default_refiner is not None:\n        refiner_scale = default_refiner\n    else:\n        refiner_val = Prompt.ask(\n            \"Refiner strength multiplier [dim](experimental)[/dim]\",\n            default=\"1.0\",\n        )\n        try:\n            refiner_scale = float(refiner_val)\n        except ValueError:\n            refiner_scale = 1.0\n\n    return InferenceSettings(\n        input_is_linear=input_is_linear,\n        despill_strength=despill_strength,\n        auto_despeckle=auto_despeckle,\n        despeckle_size=despeckle_size,\n        refiner_scale=refiner_scale,\n    )\n\n\n# ---------------------------------------------------------------------------\n# Typer callback (shared options)\n# ---------------------------------------------------------------------------\n\n\n@app.callback()\ndef app_callback(\n    ctx: typer.Context,\n    device: Annotated[\n        str,\n        typer.Option(help=\"Compute device: auto, cuda, mps, cpu\"),\n    ] = \"auto\",\n) -> None:\n    \"\"\"Neural network green screen keying for professional VFX pipelines.\"\"\"\n    _configure_environment()\n    ctx.ensure_object(dict)\n    ctx.obj[\"device\"] = resolve_device(device)\n    logger.info(\"Using device: %s\", ctx.obj[\"device\"])\n\n\n# ---------------------------------------------------------------------------\n# Subcommands\n# ---------------------------------------------------------------------------\n\n\n@app.command(\"list-clips\")\ndef list_clips_cmd(ctx: typer.Context) -> None:\n    \"\"\"List all clips in ClipsForInference and their status.\"\"\"\n    scan_clips()\n\n\n@app.command(\"generate-alphas\")\ndef generate_alphas_cmd(ctx: typer.Context) -> None:\n    \"\"\"Generate coarse alpha hints via GVM for clips missing them.\"\"\"\n    clips = scan_clips()\n    with console.status(\"[bold 
green]Loading GVM model...\"):\n        generate_alphas(clips, device=ctx.obj[\"device\"], on_clip_start=_on_clip_start_log_only)\n    console.print(\"[bold green]Alpha generation complete.\")\n\n\n@app.command(\"run-inference\")\ndef run_inference_cmd(\n    ctx: typer.Context,\n    backend: Annotated[\n        str,\n        typer.Option(help=\"Inference backend: auto, torch, mlx\"),\n    ] = \"auto\",\n    max_frames: Annotated[\n        Optional[int],\n        typer.Option(\"--max-frames\", help=\"Limit frames per clip\"),\n    ] = None,\n    skip_existing: Annotated[\n        bool,\n        typer.Option(\"--skip-existing\", help=\"Skip frames whose output files already exist (resume a partial render)\"),\n    ] = False,\n    linear: Annotated[\n        Optional[bool],\n        typer.Option(\"--linear/--srgb\", help=\"Input colorspace (default: prompt)\"),\n    ] = None,\n    despill: Annotated[\n        Optional[int],\n        typer.Option(\"--despill\", help=\"Despill strength 0–10 (default: prompt)\"),\n    ] = None,\n    despeckle: Annotated[\n        Optional[bool],\n        typer.Option(\"--despeckle/--no-despeckle\", help=\"Auto-despeckle toggle (default: prompt)\"),\n    ] = None,\n    despeckle_size: Annotated[\n        Optional[int],\n        typer.Option(\"--despeckle-size\", help=\"Min pixel size for despeckle (default: prompt)\"),\n    ] = None,\n    refiner: Annotated[\n        Optional[float],\n        typer.Option(\"--refiner\", help=\"Refiner strength multiplier (default: prompt)\"),\n    ] = None,\n) -> None:\n    \"\"\"Run CorridorKey inference on clips with Input + AlphaHint.\n\n    Settings can be passed as flags for non-interactive use, or omitted to\n    prompt interactively.\n    \"\"\"\n    clips = scan_clips()\n\n    # despeckle_size excluded — sensible default even in headless mode\n    required_flags_set = all(v is not None for v in [linear, despill, despeckle, refiner])\n    if required_flags_set:\n        assert linear is not None 
and despill is not None and despeckle is not None and refiner is not None\n        despill_clamped = max(0, min(10, despill))\n        settings = InferenceSettings(\n            input_is_linear=linear,\n            despill_strength=despill_clamped / 10.0,\n            auto_despeckle=despeckle,\n            despeckle_size=despeckle_size if despeckle_size is not None else 400,\n            refiner_scale=refiner,\n        )\n    else:\n        settings = _prompt_inference_settings(\n            default_linear=linear,\n            default_despill=despill,\n            default_despeckle=despeckle,\n            default_despeckle_size=despeckle_size,\n            default_refiner=refiner,\n        )\n\n    with ProgressContext() as ctx_progress:\n        run_inference(\n            clips,\n            device=ctx.obj[\"device\"],\n            backend=backend,\n            max_frames=max_frames,\n            skip_existing=skip_existing,\n            settings=settings,\n            on_clip_start=ctx_progress.on_clip_start,\n            on_frame_complete=ctx_progress.on_frame_complete,\n        )\n\n    console.print(\"[bold green]Inference complete.\")\n\n\n@app.command()\ndef wizard(\n    ctx: typer.Context,\n    path: Annotated[str, typer.Argument(help=\"Target path (Windows or local)\")],\n) -> None:\n    \"\"\"Interactive wizard for organizing clips and running the pipeline.\"\"\"\n    interactive_wizard(path, device=ctx.obj[\"device\"])\n\n\n# ---------------------------------------------------------------------------\n# Wizard (rich-styled)\n# ---------------------------------------------------------------------------\n\n\ndef interactive_wizard(win_path: str, device: str | None = None) -> None:\n    console.print(Panel(\"[bold]CORRIDOR KEY — SMART WIZARD[/bold]\", style=\"cyan\"))\n\n    # 1. 
Resolve Path\n    console.print(f\"Windows Path: {win_path}\")\n\n    if os.path.exists(win_path):\n        process_path = win_path\n        console.print(f\"Running locally: [bold]{process_path}[/bold]\")\n    else:\n        process_path = map_path(win_path)\n        console.print(f\"Linux/Remote Path: [bold]{process_path}[/bold]\")\n\n        if not os.path.exists(process_path):\n            console.print(\n                f\"\\n[bold red]ERROR:[/bold red] Path does not exist locally OR on Linux mount!\\n\"\n                f\"Expected Linux Mount Root: {LINUX_MOUNT_ROOT}\"\n            )\n            raise typer.Exit(code=1)\n\n    # 2. Analyze — shot or project?\n    target_is_shot = False\n    if os.path.exists(os.path.join(process_path, \"Input\")) or glob.glob(os.path.join(process_path, \"Input.*\")):\n        target_is_shot = True\n\n    work_dirs: list[str] = []\n    # Pipeline output dirs, not clip sources\n    excluded_dirs = {\"Output\", \"AlphaHint\", \"VideoMamaMaskHint\", \".ipynb_checkpoints\"}\n    if target_is_shot:\n        work_dirs = [process_path]\n    else:\n        work_dirs = [\n            os.path.join(process_path, d)\n            for d in os.listdir(process_path)\n            if os.path.isdir(os.path.join(process_path, d)) and d not in excluded_dirs\n        ]\n\n    console.print(f\"\\nFound [bold]{len(work_dirs)}[/bold] potential clip folders.\")\n\n    # Files already named Input/AlphaHint/etc are organized, not \"loose\"\n    known_names = {\"input\", \"alphahint\", \"videomamamaskhint\"}\n    loose_videos = [\n        f\n        for f in os.listdir(process_path)\n        if is_video_file(f)\n        and os.path.isfile(os.path.join(process_path, f))\n        and os.path.splitext(f)[0].lower() not in known_names\n    ]\n\n    dirs_needing_org = []\n    for d in work_dirs:\n        has_input = os.path.exists(os.path.join(d, \"Input\")) or glob.glob(os.path.join(d, \"Input.*\"))\n        has_alpha = os.path.exists(os.path.join(d, 
\"AlphaHint\"))\n        has_mask = os.path.exists(os.path.join(d, \"VideoMamaMaskHint\"))\n        if not has_input or not has_alpha or not has_mask:\n            dirs_needing_org.append(d)\n\n    if loose_videos or dirs_needing_org:\n        if loose_videos:\n            console.print(f\"Found [yellow]{len(loose_videos)}[/yellow] loose video files:\")\n            for v in loose_videos:\n                console.print(f\"  • {v}\")\n\n        if dirs_needing_org:\n            console.print(f\"Found [yellow]{len(dirs_needing_org)}[/yellow] folders needing setup:\")\n            display_limit = 10\n            for d in dirs_needing_org[:display_limit]:\n                console.print(f\"  • {os.path.basename(d)}\")\n            if len(dirs_needing_org) > display_limit:\n                console.print(f\"  …and {len(dirs_needing_org) - display_limit} others.\")\n\n        # 3. Organize\n        if Confirm.ask(\"\\nOrganize clips & create hint folders?\", default=False):\n            for v in loose_videos:\n                clip_name = os.path.splitext(v)[0]\n                ext = os.path.splitext(v)[1]\n                target_folder = os.path.join(process_path, clip_name)\n\n                if os.path.exists(target_folder):\n                    logger.warning(f\"Skipping loose video '{v}': Target folder '{clip_name}' already exists.\")\n                    continue\n\n                try:\n                    os.makedirs(target_folder)\n                    target_file = os.path.join(target_folder, f\"Input{ext}\")\n                    shutil.move(os.path.join(process_path, v), target_file)\n                    logger.info(f\"Organized: Moved '{v}' to '{clip_name}/Input{ext}'\")\n                    for hint in [\"AlphaHint\", \"VideoMamaMaskHint\"]:\n                        os.makedirs(os.path.join(target_folder, hint), exist_ok=True)\n                except Exception as e:\n                    logger.error(f\"Failed to organize video '{v}': {e}\")\n\n            for d 
in work_dirs:\n                organize_target(d)\n            console.print(\"[green]Organization complete.[/green]\")\n\n            if not target_is_shot:\n                work_dirs = [\n                    os.path.join(process_path, d)\n                    for d in os.listdir(process_path)\n                    if os.path.isdir(os.path.join(process_path, d)) and d not in excluded_dirs\n                ]\n\n    # 4. Status Check Loop\n    while True:\n        ready: list[ClipEntry] = []\n        masked: list[ClipEntry] = []\n        raw: list[ClipEntry] = []\n\n        for d in work_dirs:\n            entry = ClipEntry(os.path.basename(d), d)\n            try:\n                entry.find_assets()\n            except (FileNotFoundError, ValueError, OSError):\n                pass\n\n            has_mask = False\n            mask_dir = os.path.join(d, \"VideoMamaMaskHint\")\n            if os.path.isdir(mask_dir) and len(os.listdir(mask_dir)) > 0:\n                has_mask = True\n            if not has_mask:\n                for f in os.listdir(d):\n                    stem, _ = os.path.splitext(f)\n                    if stem.lower() == \"videomamamaskhint\" and is_video_file(f):\n                        has_mask = True\n                        break\n\n            if entry.alpha_asset:\n                ready.append(entry)\n            elif has_mask:\n                masked.append(entry)\n            else:\n                raw.append(entry)\n\n        table = Table(title=\"Status Report\", show_lines=True)\n        table.add_column(\"Category\", style=\"bold\")\n        table.add_column(\"Count\", justify=\"right\")\n        table.add_column(\"Clips\")\n\n        table.add_row(\n            \"[green]Ready[/green] (AlphaHint)\",\n            str(len(ready)),\n            \", \".join(c.name for c in ready) or \"—\",\n        )\n        table.add_row(\n            \"[yellow]Masked[/yellow] (VideoMaMaMaskHint)\",\n            str(len(masked)),\n            \", 
\".join(c.name for c in masked) or \"—\",\n        )\n        table.add_row(\n            \"[red]Raw[/red] (Input only)\",\n            str(len(raw)),\n            \", \".join(c.name for c in raw) or \"—\",\n        )\n        console.print(table)\n\n        missing_alpha = masked + raw\n        actions: list[str] = []\n\n        if missing_alpha:\n            actions.append(f\"[bold]v[/bold] — Run VideoMaMa ({len(masked)} with masks)\")\n            actions.append(f\"[bold]g[/bold] — Run GVM (auto-matte {len(raw)} clips)\")\n            actions.append(f\"[bold]b[/bold] — Run BiRefNet (auto-matte {len(raw)} clips)\")\n        if ready:\n            actions.append(f\"[bold]i[/bold] — Run Inference ({len(ready)} ready clips)\")\n        actions.append(\"[bold]r[/bold] — Re-scan folders\")\n        actions.append(\"[bold]q[/bold] — Quit\")\n\n        console.print(Panel(\"\\n\".join(actions), title=\"Actions\", style=\"blue\"))\n\n        choice = Prompt.ask(\"Select action\", choices=[\"v\", \"g\", \"b\", \"i\", \"r\", \"q\"], default=\"q\")\n\n        if choice == \"v\":\n            console.print(Panel(\"VideoMaMa\", style=\"magenta\"))\n            run_videomama(missing_alpha, chunk_size=50, device=device)\n            Prompt.ask(\"VideoMaMa batch complete. Press Enter to re-scan\")\n\n        elif choice == \"g\":\n            console.print(Panel(\"GVM Auto-Matte\", style=\"magenta\"))\n            console.print(f\"Will generate alphas for {len(raw)} clips without mask hints.\")\n            if Confirm.ask(\"Proceed with GVM?\", default=False):\n                generate_alphas(raw, device=device)\n                Prompt.ask(\"GVM batch complete. 
Press Enter to re-scan\")\n\n        elif choice == \"b\":\n            console.print(Panel(\"BiRefNet Auto-Matte\", style=\"magenta\"))\n            usage_list = get_birefnet_usage_options()\n            for i, name in enumerate(usage_list, 1):\n                console.print(f\"[[bold]{i}[/bold]] {name}\")\n\n            idx = IntPrompt.ask(\"Select Model ID\", default=1)\n            try:\n                selected_usage = usage_list[idx - 1]\n                dilate = IntPrompt.ask(\"Enter dilation/erosion radius (-50 to 50, 0 to skip)\", default=0)\n\n                console.print(f\"Starting BiRefNet ({selected_usage}, Radius={dilate}) for {len(raw)} clips...\")\n                if Confirm.ask(f\"Proceed with {selected_usage}?\", default=True):\n                    with ProgressContext() as ctx_progress:\n                        run_birefnet(\n                            raw,\n                            device=device,\n                            usage=selected_usage,\n                            dilate_radius=dilate,\n                            on_clip_start=ctx_progress.on_clip_start,\n                            on_frame_complete=ctx_progress.on_frame_complete,\n                        )\n                    Prompt.ask(\"BiRefNet batch complete. 
Press Enter to re-scan\")\n            except IndexError:\n                console.print(\"[red]Invalid ID selected![/red]\")\n\n        elif choice == \"i\":\n            console.print(Panel(\"Corridor Key Inference\", style=\"magenta\"))\n            try:\n                settings = _prompt_inference_settings()\n                with ProgressContext() as ctx_progress:\n                    run_inference(\n                        ready,\n                        device=device,\n                        settings=settings,\n                        on_clip_start=ctx_progress.on_clip_start,\n                        on_frame_complete=ctx_progress.on_frame_complete,\n                    )\n            except (RuntimeError, FileNotFoundError) as e:\n                console.print(f\"[bold red]Inference failed:[/bold red] {e}\")\n            Prompt.ask(\"Inference batch complete. Press Enter to re-scan\")\n\n        elif choice == \"r\":\n            console.print(\"Re-scanning…\")\n\n        elif choice == \"q\":\n            break\n\n    console.print(\"[bold green]Wizard complete. Goodbye![/bold green]\")\n\n\n# ---------------------------------------------------------------------------\n# Entry point\n# ---------------------------------------------------------------------------\n\n\ndef main() -> None:\n    \"\"\"Entry point called by the `corridorkey` console script.\"\"\"\n    try:\n        app()\n    except KeyboardInterrupt:\n        console.print(\"\\n[yellow]Interrupted.[/yellow]\")\n        sys.exit(130)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "device_utils.py",
    "content": "\"\"\"Centralized cross-platform device selection for CorridorKey.\"\"\"\n\nimport logging\nimport os\n\nimport torch\n\nlogger = logging.getLogger(__name__)\n\nDEVICE_ENV_VAR = \"CORRIDORKEY_DEVICE\"\nVALID_DEVICES = (\"auto\", \"cuda\", \"mps\", \"cpu\")\n\n\ndef detect_best_device() -> str:\n    \"\"\"Auto-detect best available device: CUDA > MPS > CPU.\"\"\"\n    if torch.cuda.is_available():\n        device = \"cuda\"\n    elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n        device = \"mps\"\n    else:\n        device = \"cpu\"\n    logger.info(\"Auto-selected device: %s\", device)\n    return device\n\n\ndef resolve_device(requested: str | None = None) -> str:\n    \"\"\"Resolve device from explicit request > env var > auto-detect.\n\n    Args:\n        requested: Device string from CLI arg. None or \"auto\" triggers\n                   env var lookup then auto-detection.\n\n    Returns:\n        Validated device string (\"cuda\", \"mps\", or \"cpu\").\n\n    Raises:\n        RuntimeError: If the requested backend is unavailable.\n    \"\"\"\n    # CLI arg takes priority, then env var, then auto\n    device = requested\n    if device is None or device == \"auto\":\n        device = os.environ.get(DEVICE_ENV_VAR, \"auto\")\n\n    if device == \"auto\":\n        return detect_best_device()\n\n    device = device.lower()\n    if device not in VALID_DEVICES:\n        raise RuntimeError(f\"Unknown device '{device}'. Valid options: {', '.join(VALID_DEVICES)}\")\n\n    # Validate the explicit request\n    if device == \"cuda\":\n        if not torch.cuda.is_available():\n            raise RuntimeError(\n                \"CUDA requested but torch.cuda.is_available() is False. Install a CUDA-enabled PyTorch build.\"\n            )\n    elif device == \"mps\":\n        if not hasattr(torch.backends, \"mps\"):\n            raise RuntimeError(\n                \"MPS requested but this PyTorch build has no MPS support. 
Install PyTorch >= 1.12 with MPS backend.\"\n            )\n        if not torch.backends.mps.is_available():\n            raise RuntimeError(\n                \"MPS requested but not available on this machine. Requires Apple Silicon (M1+) with macOS 12.3+.\"\n            )\n\n    return device\n\n\ndef clear_device_cache(device: torch.device | str) -> None:\n    \"\"\"Clear GPU memory cache if applicable (no-op for CPU).\"\"\"\n    device_type = device.type if isinstance(device, torch.device) else device\n    if device_type == \"cuda\":\n        torch.cuda.empty_cache()\n    elif device_type == \"mps\":\n        torch.mps.empty_cache()\n"
  },
  {
    "path": "docker-compose.yml",
    "content": "services:\n  corridorkey:\n    profiles: [\"gpu\"]\n    build:\n      context: .\n      dockerfile: Dockerfile\n    image: corridorkey:latest\n    gpus: ${CK_GPUS:-all}\n    environment:\n      - OPENCV_IO_ENABLE_OPENEXR=1\n      # Use \"all\", \"0\", \"1\", or CSV like \"0,1\" for multi-GPU pinning.\n      - NVIDIA_VISIBLE_DEVICES=${NVIDIA_VISIBLE_DEVICES:-all}\n      # Add graphics if needed for GUI/OpenGL workloads.\n      - NVIDIA_DRIVER_CAPABILITIES=${NVIDIA_DRIVER_CAPABILITIES:-compute,utility,video}\n    volumes:\n      - ./ClipsForInference:/app/ClipsForInference\n      - ./Output:/app/Output\n      - ./CorridorKeyModule/checkpoints:/app/CorridorKeyModule/checkpoints\n      - ./gvm_core/weights:/app/gvm_core/weights\n      - ./VideoMaMaInferenceModule/checkpoints:/app/VideoMaMaInferenceModule/checkpoints\n    stdin_open: true\n    tty: true\n\n  corridorkey-cpu:\n    profiles: [\"cpu\"]\n    build:\n      context: .\n      dockerfile: Dockerfile\n    image: corridorkey:latest\n    volumes:\n      - ./ClipsForInference:/app/ClipsForInference\n      - ./Output:/app/Output\n      - ./CorridorKeyModule/checkpoints:/app/CorridorKeyModule/checkpoints\n      - ./gvm_core/weights:/app/gvm_core/weights\n      - ./VideoMaMaInferenceModule/checkpoints:/app/VideoMaMaInferenceModule/checkpoints\n    stdin_open: true\n    tty: true\n"
  },
  {
    "path": "docs/LLM_HANDOVER.md",
    "content": "# CorridorKey: LLM Handover Guide\n\nWelcome, fellow AI Assistant! You are picking up a highly specialized computer vision project called **CorridorKey**, an AI Chroma Keying engine designed for professional VFX pipelines. \n\nThis document is your technical entry point. It outlines the architecture, dataflow, design decisions, and common pitfalls of this codebase to help you assist the human user effectively.\n\n---\n\n## 1. Project Overview & Architecture\n\n**CorridorKey** is a neural-network-based green screen removal tool. It takes an RGB image and a \"Coarse Alpha Hint\" (generated by the user with a rough chroma key or AI roto, or by utilizing the GVM or VideoMaMa modules) and produces mathematically perfect, physically unmixed Alpha and Foreground Straight color, with the greenscreen unmixed from semi-transparent pixels.\n\n**Core Architecture (The `GreenFormer`):**\n*   **Backbone:** A `timm` implementation of `hiera_base_plus_224.mae_in1k_ft_in1k`.\n*   **Input Modification:** We patched the first layer to accept `4 channels` (RGB + Coarse Alpha Hint).\n*   **Decoders:** Multiscale feature fusion heads that predict \"Coarse\" Alpha (1ch) and Foreground (3ch) logits.\n*   **Refiner (`CNNRefinerModule`):** A custom CNN head (dilated residual blocks) that takes the original RGB input and the Coarse predictions, outputting purely additive \"Delta Logits\" that are applied directly to the backbone's outputs before final Sigmoid activation.\n\n**Key Files:**\n*   `CorridorKeyModule/core/model_transformer.py`: The PyTorch architecture described above.\n*   `CorridorKeyModule/inference_engine.py`: The `CorridorKeyEngine` class. It loads the `CorridorKey.pth` weights and handles the resizing API.\n*   `CorridorKeyModule/core/color_utils.py`: Pure math functions for digital compositing. **Crucial:** Pay attention to `srgb_to_linear()`, `premultiply()`, and luminance-preserving `despill()`.\n*   `clip_manager.py`: The user-facing Command Line Wizard. 
It handles scanning directories, prompting the user for inference settings, and piping data into the engine.\n\n---\n\n## 2. Critical Dataflow Properties (Do Not Break These)\n\nThe biggest challenge in this codebase revolves around **Color Space** and **Gamma Math**. When assisting the user with compositing bugs, check these rules first:\n\n1.  **Model Input/Output is strictly `[0.0, 1.0]` Float Tensors.**\n    *   The model assumes inputs are sRGB.\n    *   The predicted Output Foreground (`res['fg']`) is natively sRGB and the model is currently trained to predict the un-multiplied straight color fg element.\n    *   The predicted Output Alpha (`res['alpha']`) is inherently Linear.\n2.  **EXR Handling (`Processed` Output pass):**\n    *   EXRs are stored as Linear float data, premultiplied.\n    *   To build the `Processed` EXR, we take the sRGB foreground, pass it through `cu.srgb_to_linear()`, premultiply it by the Linear Alpha, pack them, and save them via OpenCV in `cv2.IMWRITE_EXR_TYPE_HALF`.\n    *   *Bug History:* Do not apply a pure mathematical `Gamma 2.2` curve; use the piecewise real sRGB transfer functions defined in `color_utils.py`.\n3.  **Inference Resizing (`img_size`):**\n    *   The engine is strictly trained on `2048x2048` crops.\n    *   In `inference_engine.py`, the `process_frame()` method uses OpenCV (Lanczos4) to upscale/downscale the user's arbitrary input resolution to 2048x2048, feeds the model, and then resizes the predictions *back* to the original resolution.\n\n---\n\n## 3. The Inference Pipeline (`clip_manager.py`)\n\nUsers generally run the system via local shell launcher scripts (`CorridorKey_DRAG_CLIPS_HERE_local.bat` or `CorridorKey_DRAG_CLIPS_HERE_local.sh`) which boot the `clip_manager.py` wizard.\n\nThe pipeline works as follows:\n1.  **Scan:** Looks for folders (or dragged-and-dropped paths) containing an `Input` sequence (RGB) and an `AlphaHint` sequence (BW).\n2.  
**Config:** Prompts the user for settings (Gamma space, Despill strength, Auto-Despeckle threshold, Refiner Strength).\n3.  **Execution:** Loops frame-by-frame, passing `[H, W, 3]` Numpy arrays to `engine.process_frame()`.\n4.  **Export:**\n    *   `FG` directory: Half-float EXR, RGB (**sRGB** — the model predicts straight FG in sRGB; convert to linear before compositing).\n    *   `Matte` directory: Half-float EXR, Grayscale (Linear).\n    *   `Processed` directory: Half-float EXR, RGBA (Linear, Premultiplied).\n    *   `Comp` directory: 8-bit PNG (sRGB composite over a checkerboard, for quick preview).\n\n---\n\n## 4. Helpful Pointers for Future Work\n\n*   **Training Code:** We deliberately stripped training-specific logic (like returning coarse logits, `.detach()` gradients, gradient checkpointing) out of the inference tool to keep `model_transformer.py` pristine for inference speed; that logic is built out in a separate program. If the user wants to resume training the Hiera backbone, utilize the Corridor Key trainer (coming soon, maybe?).\n*   **PointRend:** You may see \"PointRend\" mentioned in old commit messages. It was entirely replaced by the CNN Refiner.\n*   **GVM / VideoMaMa:** There are sub-modules for generating the Coarse Alpha Hints. `clip_manager.py`: `--action generate_alphas` handles piping footage into these external repos.\n\n## 5. Directives for the AI\n\n*   **Be Proactive:** The user is highly technical (a VFX professional/coder). Skip basic tutorials and dive straight into advanced implementation, but be sure to document math thoroughly.\n*   **Prioritize Performance:** This is video processing. Every `.numpy()` transfer or `cv2.resize` matters in a loop running on 4K footage.\n*   **Verify Gamma:** If the user complains about \"crushed shadows\" or \"dark fringes\", the problem is almost certainly an sRGB-to-Linear conversion step happening in the wrong order inside `color_utils.py`. \n\nGood luck, and build cool tools!\n"
  },
  {
    "path": "docs/index.md",
    "content": ""
  },
  {
    "path": "gvm_core/LICENSE.md",
    "content": "# GVM Licensing and Acknowledgements\n\nThis module (`gvm_core`) contains repackaged code and integrations from the **Generative Video Matting (GVM)** project hosted by the Advanced Intelligent Machines (AIM) research team at Zhejiang University (`aim-uofa`).\n\n## Original Repository\n*   **GitHub:** [https://github.com/aim-uofa/GVM](https://github.com/aim-uofa/GVM)\n\n## License\nThe GVM project is licensed under the **Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License (CC BY-NC-SA 4.0)**.\n\nTo view a copy of this license, visit: [http://creativecommons.org/licenses/by-nc-sa/4.0/](http://creativecommons.org/licenses/by-nc-sa/4.0/)\n\nBy utilizing this module and downloading the associated Stable Video Diffusion weights, you are subject to the terms of the CC BY-NC-SA 4.0 license, which strictly prohibits commercial use.\n"
  },
  {
    "path": "gvm_core/README.md",
    "content": "# GVM Core Module\n\nThis folder contains the core logic and pre-trained models for Generative Video Matting (GVM).\nIt is designed to be a self-contained, portable module that can be dropped into any Python project.\n\n## Directory Structure\n\n```\ngvm_core/\n├── __init__.py           # Exports the main GVMProcessor class\n├── wrapper.py            # High-level API for inference\n├── requirements.txt      # List of dependencies\n├── gvm/                  # The core library package\n│   ├── models/           # Spatio-temporal UNet definitions\n│   ├── pipelines/        # Diffusers-based pipeline logic\n│   └── utils/            # Video IO and processing utilities\n└── weights/              # Bundled Model Weights (Autoencoder, UNet)\n```\n\n## Installation\n\n1. **Install Dependencies**:\n   Ensure you have a Python environment set up (Python 3.10+ recommended).\n   Install the required packages:\n\n   ```bash\n   pip install -r requirements.txt\n   ```\n\n   *Note: You may need to install PyTorch separately first to match your CUDA version.*\n\n## Usage\n\nYou can use the `GVMProcessor` to run inference on videos or image sequences.\nThe processor automatically finds the bundled model weights in the `weights/` directory.\n\n### Basic Example\n\n```python\nfrom gvm_core import GVMProcessor\n\n# Initialize the processor\n# It will load models from ./weights automatically\nprocessor = GVMProcessor(device=\"cuda\")\n\n# Process a video file\nprocessor.process_sequence(\n    input_path=\"path/to/input_video.mp4\",\n    output_dir=\"path/to/output_folder\",\n    num_frames_per_batch=8,   # Adjust based on VRAM (try 4 if OOM)\n    denoise_steps=1           # 1-step inference is standard for this model\n)\n```\n\n### Advanced Usage\n\nYou can customize the inference parameters:\n\n```python\nprocessor.process_sequence(\n    input_path=\"path/to/sequence_folder\", # Can also be a folder of images\n    output_dir=\"output\",\n    num_frames_per_batch=8,\n   
 decode_chunk_size=4,      # Reduces VRAM usage during decoding\n    num_overlap_frames=1,     # Overlap between batches for temporal consistency\n    mode='matte'              # 'matte' is the default mode\n)\n```\n\n## Troubleshooting\n\n- **Out of Memory (OOM)**: Reduce `num_frames_per_batch` (e.g., to 4 or 2) and `decode_chunk_size`.\n- **Missing Weights**: Ensure the `weights/` folder exists inside `gvm_core/`. If you moved the code without the weights, you must download them or copy them separately.\n"
  },
  {
    "path": "gvm_core/__init__.py",
    "content": "from .wrapper import GVMProcessor\n"
  },
  {
    "path": "gvm_core/gvm/__init__.py",
    "content": ""
  },
  {
    "path": "gvm_core/gvm/models/__init__.py",
    "content": ""
  },
  {
    "path": "gvm_core/gvm/models/unet_spatio_temporal_condition.py",
    "content": "from dataclasses import dataclass\nfrom typing import Dict, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders import UNet2DConditionLoadersMixin\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.models.attention_processor import (\n    CROSS_ATTENTION_PROCESSORS,\n    AttentionProcessor,\n    AttnProcessor,\n)\nfrom diffusers.models.embeddings import TimestepEmbedding, Timesteps\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block\nfrom diffusers.loaders import PeftAdapterMixin\nfrom diffusers.models.unets.unet_spatio_temporal_condition import (\n    UNetSpatioTemporalConditionOutput,\n)\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass UNetSpatioTemporalConditionModel(\n    ModelMixin, \n    ConfigMixin, \n    UNet2DConditionLoadersMixin,\n    PeftAdapterMixin,\n    # LoraLoaderMixin,\n):\n    r\"\"\"\n    A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample\n    shaped output.\n\n    This model inherits from [`ModelMixin`]. 
Check the superclass documentation for its generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):\n            Height and width of input/output sample.\n        in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.\n        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"CrossAttnDownBlockSpatioTemporal\", \"CrossAttnDownBlockSpatioTemporal\", \"CrossAttnDownBlockSpatioTemporal\", \"DownBlockSpatioTemporal\")`):\n            The tuple of downsample blocks to use.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpBlockSpatioTemporal\", \"CrossAttnUpBlockSpatioTemporal\", \"CrossAttnUpBlockSpatioTemporal\", \"CrossAttnUpBlockSpatioTemporal\")`):\n            The tuple of upsample blocks to use.\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        addition_time_embed_dim (`int`, defaults to 256):\n            Dimension to encode the additional time ids.\n        projection_class_embeddings_input_dim (`int`, defaults to 768):\n            The dimension of the projection of encoded `added_time_ids`.\n        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.\n        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for\n            [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],\n            [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].\n        num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):\n            The number of attention heads.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        sample_size: Optional[int] = None,\n        in_channels: int = 8,\n        out_channels: int = 4,\n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlockSpatioTemporal\",\n            \"CrossAttnDownBlockSpatioTemporal\",\n            \"CrossAttnDownBlockSpatioTemporal\",\n            \"DownBlockSpatioTemporal\",\n        ),\n        up_block_types: Tuple[str] = (\n            \"UpBlockSpatioTemporal\",\n            \"CrossAttnUpBlockSpatioTemporal\",\n            \"CrossAttnUpBlockSpatioTemporal\",\n            \"CrossAttnUpBlockSpatioTemporal\",\n        ),\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        addition_time_embed_dim: int = 256,\n        projection_class_embeddings_input_dim: int = 768,\n        layers_per_block: Union[int, Tuple[int]] = 2,\n        cross_attention_dim: Union[int, Tuple[int]] = 1024,\n        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,\n        num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),\n        num_frames: int = 25,\n\n        class_embed_type: Optional[str] = None, # 'projection',\n        num_class_embeds: Optional[int] = None,\n        act_fn: str = \"silu\",\n    ):\n        super().__init__()\n\n        self.sample_size = sample_size\n\n        # Check inputs\n        if len(down_block_types) != len(up_block_types):\n            raise ValueError(\n                f\"Must provide the 
same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.\"\n            )\n\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(\n            down_block_types\n        ):\n            raise ValueError(\n                f\"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(\n            down_block_types\n        ):\n            raise ValueError(\n                f\"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(\n            down_block_types\n        ):\n            raise ValueError(\n                f\"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        # input\n        self.conv_in = nn.Conv2d(\n            in_channels,\n            block_out_channels[0],\n            kernel_size=3,\n            padding=1,\n        )\n\n        # time\n        time_embed_dim = block_out_channels[0] * 4\n\n        self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)\n        timestep_input_dim = block_out_channels[0]\n\n        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)\n        \n        # self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)\n        # self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n\n        self.down_blocks = nn.ModuleList([])\n        self.up_blocks = nn.ModuleList([])\n\n        if isinstance(num_attention_heads, int):\n            num_attention_heads = (num_attention_heads,) * len(down_block_types)\n\n        if isinstance(cross_attention_dim, int):\n            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)\n\n        if isinstance(layers_per_block, int):\n            layers_per_block = [layers_per_block] * len(down_block_types)\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * len(\n                down_block_types\n            )\n\n        blocks_time_embed_dim = time_embed_dim\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block[i],\n                transformer_layers_per_block=transformer_layers_per_block[i],\n                
in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=1e-5,\n                cross_attention_dim=cross_attention_dim[i],\n                num_attention_heads=num_attention_heads[i],\n                resnet_act_fn=\"silu\",\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        self.mid_block = UNetMidBlockSpatioTemporal(\n            block_out_channels[-1],\n            temb_channels=blocks_time_embed_dim,\n            transformer_layers_per_block=transformer_layers_per_block[-1],\n            cross_attention_dim=cross_attention_dim[-1],\n            num_attention_heads=num_attention_heads[-1],\n        )\n\n        # count how many layers upsample the images\n        self.num_upsamplers = 0\n\n        # up\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        reversed_num_attention_heads = list(reversed(num_attention_heads))\n        reversed_layers_per_block = list(reversed(layers_per_block))\n        reversed_cross_attention_dim = list(reversed(cross_attention_dim))\n        reversed_transformer_layers_per_block = list(\n            reversed(transformer_layers_per_block)\n        )\n\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            is_final_block = i == len(block_out_channels) - 1\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[\n                min(i + 1, len(block_out_channels) - 1)\n            ]\n\n            # add upsample block for all BUT final layer\n            if not is_final_block:\n                add_upsample = True\n                self.num_upsamplers += 1\n            else:\n                add_upsample = False\n\n            up_block 
= get_up_block(\n                up_block_type,\n                num_layers=reversed_layers_per_block[i] + 1,\n                transformer_layers_per_block=reversed_transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_upsample=add_upsample,\n                resnet_eps=1e-5,\n                resolution_idx=i,\n                cross_attention_dim=reversed_cross_attention_dim[i],\n                num_attention_heads=reversed_num_attention_heads[i],\n                resnet_act_fn=\"silu\",\n            )\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n\n        # out\n        self.conv_norm_out = nn.GroupNorm(\n            num_channels=block_out_channels[0], num_groups=32, eps=1e-5\n        )\n        self.conv_act = nn.SiLU()\n\n        self.conv_out = nn.Conv2d(\n            block_out_channels[0],\n            out_channels,\n            kernel_size=3,\n            padding=1,\n        )\n        # class embedding\n        if class_embed_type is not None:\n            self._set_class_embedding(\n                class_embed_type,\n                act_fn=act_fn,\n                num_class_embeds=num_class_embeds,\n                projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,\n                time_embed_dim=time_embed_dim,\n                timestep_input_dim=timestep_input_dim,\n            )\n\n    def _set_class_embedding(\n        self,\n        class_embed_type: Optional[str],\n        act_fn: str,\n        num_class_embeds: Optional[int],\n        projection_class_embeddings_input_dim: Optional[int],\n        time_embed_dim: int,\n        timestep_input_dim: int,\n    ):\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = 
nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        elif class_embed_type == \"projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except\n            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings\n            # 2. it projects from an arbitrary input dimension.\n            #\n            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.\n            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.\n            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.\n            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif class_embed_type == \"simple_projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n    def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:\n        class_emb = None\n        if self.class_embedding is not None:\n            if class_labels is None:\n                
raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n                # `Timesteps` does not contain any weights and will always return f32 tensors\n                # there might be better ways to encapsulate this.\n                class_labels = class_labels.to(dtype=sample.dtype)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)\n        return class_emb\n\n\n    @property\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model with\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(\n            name: str,\n            module: torch.nn.Module,\n            processors: Dict[str, AttentionProcessor],\n        ):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor(\n                    # return_deprecated_lora=True\n                )\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    def set_attn_processor(\n        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]\n    ):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes 
that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        if all(\n            proc.__class__ in CROSS_ATTENTION_PROCESSORS\n            for proc in self.attn_processors.values()\n        ):\n            processor = AttnProcessor()\n        else:\n            raise ValueError(\n                f\"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}\"\n            )\n\n        self.set_attn_processor(processor)\n\n    def _set_gradient_checkpointing(self, module, 
value=False):\n        if hasattr(module, \"gradient_checkpointing\"):\n            module.gradient_checkpointing = value\n\n    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking\n    def enable_forward_chunking(\n        self, chunk_size: Optional[int] = None, dim: int = 0\n    ) -> None:\n        \"\"\"\n        Sets the attention processor to use [feed forward\n        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).\n\n        Parameters:\n            chunk_size (`int`, *optional*):\n                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually\n                over each tensor of dim=`dim`.\n            dim (`int`, *optional*, defaults to `0`):\n                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)\n                or dim=1 (sequence length).\n        \"\"\"\n        if dim not in [0, 1]:\n            raise ValueError(f\"Make sure to set `dim` to either 0 or 1, not {dim}\")\n\n        # By default chunk size is 1\n        chunk_size = chunk_size or 1\n\n        def fn_recursive_feed_forward(\n            module: torch.nn.Module, chunk_size: int, dim: int\n        ):\n            if hasattr(module, \"set_chunk_feed_forward\"):\n                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)\n\n            for child in module.children():\n                fn_recursive_feed_forward(child, chunk_size, dim)\n\n        for module in self.children():\n            fn_recursive_feed_forward(module, chunk_size, dim)\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        return_dict: bool = True,\n        position_ids=None,\n        class_labels=None,\n    ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:\n        r\"\"\"\n        The 
[`UNetSpatioTemporalConditionModel`] forward method.\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.\n            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.FloatTensor`):\n                The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain\n                tuple.\n        Returns:\n            [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:\n                If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise\n                a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        # for dim in sample.shape[-2:]:\n        #     if dim % default_overall_up_factor != 0:\n        #         # Forward upsample size to force interpolation output size.\n        #         forward_upsample_size = True\n        #         break\n\n        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):\n            # logger.info(\"Forward upsample size to force interpolation output size.\")\n            forward_upsample_size = True\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        batch_size, num_frames = sample.shape[:2]\n        timesteps = timesteps.expand(batch_size)\n\n        t_emb = self.time_proj(timesteps)\n\n        # `Timesteps` does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n        emb = self.time_embedding(t_emb)\n\n\n        # time_embeds = self.add_time_proj(added_time_ids.flatten())\n        # time_embeds = time_embeds.reshape((batch_size, -1))\n        # time_embeds = time_embeds.to(emb.dtype)\n        # aug_emb = self.add_embedding(time_embeds)\n        # emb = emb + aug_emb\n\n        # if class_labels is not None:\n        #     class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)\n        #     emb = emb + class_emb\n\n        # else:\n        #     class_emb = None\n        # # import pdb;pdb.set_trace()\n        # if class_emb is not None:\n        #     emb = emb + class_emb\n\n        # Flatten the batch and frames dimensions\n        # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]\n        sample = sample.flatten(0, 1)\n        # Repeat the embeddings num_video_frames times\n        # emb: 
[batch, channels] -> [batch * frames, channels]\n        emb = emb.repeat_interleave(num_frames, dim=0)\n        # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]\n        encoder_hidden_states = encoder_hidden_states.repeat_interleave(\n            num_frames, dim=0\n        )\n\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        image_only_indicator = torch.zeros(\n            batch_size, num_frames, dtype=sample.dtype, device=sample.device\n        )\n\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if (\n                hasattr(downsample_block, \"has_cross_attention\")\n                and downsample_block.has_cross_attention\n            ):\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    image_only_indicator=image_only_indicator,\n                    # position_ids=position_ids,\n                )\n            else:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    image_only_indicator=image_only_indicator,\n                    # position_ids=position_ids,\n                )\n\n            down_block_res_samples += res_samples\n\n        # 4. mid\n        sample = self.mid_block(\n            hidden_states=sample,\n            temb=emb,\n            encoder_hidden_states=encoder_hidden_states,\n            image_only_indicator=image_only_indicator,\n            # position_ids=position_ids,\n        )\n\n        # 5. 
up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[\n                : -len(upsample_block.resnets)\n            ]\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if (\n                hasattr(upsample_block, \"has_cross_attention\")\n                and upsample_block.has_cross_attention\n            ):\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    upsample_size=upsample_size,\n                    image_only_indicator=image_only_indicator,\n                    # position_ids=position_ids,\n                )\n            else:\n                # print('unet 611 upsample_size:', upsample_size)\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    upsample_size=upsample_size,\n                    image_only_indicator=image_only_indicator,\n                    # position_ids=position_ids,\n                )\n\n        # 6. post-process\n        sample = self.conv_norm_out(sample)\n        sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        # 7. Reshape back to original shape\n        sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])\n\n        if not return_dict:\n            return (sample,)\n\n        return UNetSpatioTemporalConditionOutput(sample=sample)\n"
  },
  {
    "path": "gvm_core/gvm/pipelines/pipeline_gvm.py",
    "content": "import torch\nimport tqdm\nimport numpy as np\nfrom diffusers import DiffusionPipeline\nfrom diffusers.utils import (\n    BaseOutput, \n    USE_PEFT_BACKEND,     \n    is_peft_available,\n    is_peft_version,\n    is_torch_version,\n    logging\n)\nfrom diffusers.loaders.lora_pipeline import (\n    _LOW_CPU_MEM_USAGE_DEFAULT_LORA,\n    StableDiffusionLoraLoaderMixin\n)\nfrom peft import LoraConfig, LoraModel, set_peft_model_state_dict\nimport os\n\nimport matplotlib\nfrom typing import Union, Dict\nlogger = logging.get_logger(__name__)\n\n\nclass GVMLoraLoader(StableDiffusionLoraLoaderMixin):\n    _lora_loadable_modules = [\"unet\"]\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def load_lora_weights(\n        self, \n        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], \n        adapter_name=None, \n        hotswap: bool = False,\n        **kwargs\n    ):\n\n        unet_lora_config = LoraConfig.from_pretrained(pretrained_model_name_or_path_or_dict)\n        checkpoint = os.path.join(pretrained_model_name_or_path_or_dict, f\"pytorch_lora_weights.pt\")\n        unet_lora_ckpt = torch.load(checkpoint)\n        self.unet = LoraModel(self.unet, unet_lora_config, \"default\")\n        set_peft_model_state_dict(self.unet, unet_lora_ckpt)\n\n\nclass GVMOutput(BaseOutput):\n    r\"\"\"\n    Output class for zero-shot text-to-video pipeline.\n\n    Args:\n        frames (`[List[PIL.Image.Image]`, `np.ndarray`]):\n            List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,\n            num_channels)`.\n    \"\"\"\n    alpha: np.ndarray\n    image: np.ndarray\n\nclass GVMPipeline(DiffusionPipeline, GVMLoraLoader):\n    def __init__(self, vae, unet, scheduler):\n        super().__init__()\n        self.register_modules(\n            vae=vae, unet=unet, scheduler=scheduler\n        )\n\n    def encode(self, input):\n        
num_frames = input.shape[1]\n        input = input.flatten(0, 1)\n        latent = self.vae.encode(input.to(self.vae.dtype)).latent_dist.mode()\n        latent = latent * self.vae.config.scaling_factor\n        latent = latent.reshape(-1, num_frames, *latent.shape[1:])\n        return latent\n\n    def decode(self, latents, decode_chunk_size=16):\n        # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]\n        num_frames = latents.shape[1]\n        latents = latents.flatten(0, 1)\n        latents = latents / self.vae.config.scaling_factor\n\n        # decode decode_chunk_size frames at a time to avoid OOM\n        frames = []\n        for i in range(0, latents.shape[0], decode_chunk_size):\n            num_frames_in = latents[i : i + decode_chunk_size].shape[0]\n            frame = self.vae.decode(\n                latents[i : i + decode_chunk_size].to(self.vae.dtype),\n                num_frames=num_frames_in,\n            ).sample\n            frames.append(frame)\n        frames = torch.cat(frames, dim=0)\n\n        # [batch, frames, channels, height, width]\n        frames = frames.reshape(-1, num_frames, *frames.shape[1:])\n        return frames.to(torch.float32)\n\n    \n    def single_infer(self, rgb, position_ids=None, num_inference_steps=None, class_labels=None, noise_type=\"gaussian\"):\n        rgb_latent = self.encode(rgb)\n\n        self.scheduler.set_timesteps(num_inference_steps, device=rgb.device)\n\n        if noise_type == \"gaussian\":\n            noise_latent = torch.randn_like(rgb_latent)\n            timesteps = self.scheduler.timesteps\n        elif noise_type == \"zeros\":\n            noise_latent = torch.zeros_like(rgb_latent)\n            timesteps = torch.ones_like(self.scheduler.timesteps) * (self.scheduler.config.num_train_timesteps - 1) # 999\n            timesteps = timesteps.long()\n        else:\n            raise NotImplementedError\n            \n        image_embeddings = 
torch.zeros((noise_latent.shape[0], 1, 1024)).to(\n            noise_latent\n        )\n\n        for i, t in enumerate(timesteps):\n            latent_model_input = noise_latent\n            latent_model_input = torch.cat([latent_model_input, rgb_latent], dim=2)\n            # [batch_size, num_frame, 4, h, w]\n            model_output = self.unet(\n                latent_model_input,\n                t,\n                encoder_hidden_states=image_embeddings,\n                position_ids=position_ids,\n                class_labels=class_labels,\n            ).sample\n\n            if noise_type == 'zeros':\n                noise_latent = model_output\n            else:\n                # compute the previous noisy sample x_t -> x_t-1\n                noise_latent = self.scheduler.step(\n                    model_output, t, noise_latent\n                ).prev_sample\n\n        return noise_latent\n\n    \n    def __call__(\n        self,\n        image,\n        num_frames,\n        num_overlap_frames,\n        num_interp_frames,\n        decode_chunk_size,\n        num_inference_steps,\n        use_clip_img_emb=False,\n        noise_type='zeros',\n        mode='matte',\n        ensemble_size: int = 3,\n    ):\n\n        assert ensemble_size >= 1\n        self.vae.to(dtype=torch.float16)\n        class_embedding = None\n        \n        # (1, N, 3, H, W)\n        image = image.unsqueeze(0)\n        B, N = image.shape[:2]\n        rgb_norm = image * 2 - 1  # [-1, 1]\n\n        rgb = rgb_norm.expand(ensemble_size, -1, -1, -1, -1)\n        if N <= num_frames:\n            position_ids = torch.arange(N).unsqueeze(0).repeat(B, 1).to(rgb.device)\n            position_ids = torch.zeros_like(position_ids)\n            position_ids = None\n\n            latent_all = self.single_infer(\n                rgb,\n                num_inference_steps=num_inference_steps,\n                class_labels=class_embedding,\n                position_ids=position_ids,\n                
noise_type=noise_type\n            )\n        else:\n            # assert 2 <= num_overlap_frames <= (num_interp_frames + 2 + 1) // 2\n            assert num_frames % 2 == 0\n            # num_interp_frames = num_frames - 2\n            key_frame_indices = []\n            for i in range(0, N, num_frames - num_overlap_frames):\n                if (\n                    i + num_frames - 1 >= N\n                    or len(key_frame_indices) >= num_frames\n                ):\n\n                    # print(i)\n                    pass\n\n                key_frame_indices.append(i)\n                key_frame_indices.append(min(N - 1, i + num_frames - 1))\n\n            key_frame_indices = torch.tensor(key_frame_indices, device=rgb.device)\n            \n            latent_all = None\n            pre_latent = None\n\n            for i in tqdm.tqdm(range(0, len(key_frame_indices), 2)):\n                position_ids = torch.arange(0, key_frame_indices[i + 1] - key_frame_indices[i] + 1).to(rgb.device)\n                position_ids = position_ids.unsqueeze(0).repeat(B, 1)\n                position_ids = None\n                latent = self.single_infer(\n                    rgb[:, key_frame_indices[i] : key_frame_indices[i + 1] + 1],\n                    position_ids=position_ids,\n                    num_inference_steps=num_inference_steps,\n                    class_labels=class_embedding\n                )\n\n                if pre_latent is not None:\n                    ratio = (\n                        torch.linspace(0, 1, num_overlap_frames)\n                        .to(latent)\n                        .view(1, -1, 1, 1, 1)\n                    )\n                    try:\n                        latent_all[:, -num_overlap_frames:] = latent[:,:num_overlap_frames] * ratio + latent_all[:, -num_overlap_frames:] * (1 - ratio)\n                    except:\n                        num_overlap_frames = min(num_overlap_frames, latent.shape[1])\n                        ratio = 
(\n                                torch.linspace(0, 1, num_overlap_frames)\n                                .to(latent)\n                                .view(1, -1, 1, 1, 1)\n                        )\n                        latent_all[:, -num_overlap_frames:] = latent[:,:num_overlap_frames] * ratio + latent_all[:, -num_overlap_frames:] * (1 - ratio)\n                    latent_all = torch.cat([latent_all, latent[:,num_overlap_frames:]], dim=1)\n                else:\n                    latent_all = latent.clone()\n\n                pre_latent = latent\n                if rgb.device.type == \"cuda\":\n                    torch.cuda.empty_cache()\n\n            assert latent_all.shape[1] == image.shape[1]\n\n        alpha = self.decode(latent_all, decode_chunk_size=decode_chunk_size)\n\n        # (N_videos, num_frames, H, W, 3)\n        alpha = alpha.mean(dim=2, keepdim=True)\n        alpha, _ = torch.max(alpha, dim=0)\n        alpha = torch.clamp(alpha * 0.5 + 0.5, 0.0, 1.0)\n\n        if alpha.dim() == 5:\n            alpha = alpha.squeeze(0)\n        \n        # (N, H, W, 3)\n        image = image.squeeze(0)\n\n        return GVMOutput(\n            alpha=alpha,\n            image=image,\n        )"
  },
  {
    "path": "gvm_core/gvm/utils/__init__.py",
    "content": ""
  },
  {
    "path": "gvm_core/gvm/utils/inference_utils.py",
    "content": "import av\nimport os\nimport pims\nimport numpy as np\nimport torch\nfrom torch.utils.data import Dataset\nfrom torchvision.transforms.functional import to_pil_image\nfrom PIL import Image\nfrom fractions import Fraction\n\n\nclass VideoReader(Dataset):\n    def __init__(self, path, max_frames=None, transform=None):\n        self.video = pims.PyAVVideoReader(path)\n        self.rate = self.video.frame_rate\n        self.transform = transform\n        self.max_frames = max_frames\n        \n    @property\n    def frame_rate(self):\n        return self.rate\n    \n    @property\n    def origin_shape(self):\n        return self.video[0].shape[:2]\n\n    def __len__(self):\n        if self.max_frames is not None and self.max_frames > 0:\n            return min(len(self.video), self.max_frames)\n        else:\n            return len(self.video)\n        \n    def __getitem__(self, idx):\n        frame = self.video[idx]\n        frame = Image.fromarray(np.asarray(frame))\n        if self.transform is not None:\n            frame = self.transform(frame)\n        return frame\n\n\nclass VideoWriter:\n    def __init__(self, path, frame_rate, bit_rate=1000000):\n        self.container = av.open(path, mode='w')\n        # self.container.add_stream('h264', rate=30)\n        self.stream = self.container.add_stream('h264', rate=Fraction(frame_rate).limit_denominator())\n        self.stream.pix_fmt = 'yuv420p'\n        self.stream.bit_rate = bit_rate\n    \n    def write(self, frames):\n\n        # frames: [T, C, H, W]\n        self.stream.width = frames.size(3)\n        self.stream.height = frames.size(2)\n        if frames.size(1) == 1:\n            frames = frames.repeat(1, 3, 1, 1) # convert grayscale to RGB\n        frames = frames.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()\n\n        for t in range(frames.shape[0]):\n            frame = frames[t]\n            frame = av.VideoFrame.from_ndarray(frame, format='rgb24')\n            
self.container.mux(self.stream.encode(frame))\n\n    def write_numpy(self, frames):\n        \n        # frames: [T, H, W, C]\n        self.stream.height = frames.shape[1]\n        self.stream.width = frames.shape[2]\n\n        for t in range(frames.shape[0]):\n            frame = frames[t]\n            frame = av.VideoFrame.from_ndarray(frame, format='rgb24')\n            self.container.mux(self.stream.encode(frame))\n\n    def close(self):\n        self.container.mux(self.stream.encode())\n        self.container.close()\n\n\nclass ImageSequenceReader(Dataset):\n    def __init__(self, path, transform=None):\n        self.path = path\n        self.files = sorted(os.listdir(path))\n        self.transform = transform\n\n    @property\n    def origin_shape(self):\n        # Use cv2 for robustness\n        import cv2\n        img = cv2.imread(os.path.join(self.path, self.files[0]), cv2.IMREAD_UNCHANGED)\n        return img.shape[:2]\n\n    def __len__(self):\n        return len(self.files)\n    \n    def __getitem__(self, idx):\n        import cv2\n        fpath = os.path.join(self.path, self.files[idx])\n        is_exr = fpath.lower().endswith('.exr')\n        \n        if is_exr:\n            img = cv2.imread(fpath, cv2.IMREAD_UNCHANGED)\n            # Convert to RGB (OpenCV is BGR)\n            if img is None:\n                 raise ValueError(f\"Failed to read {fpath}\")\n            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n            \n            # Convert to PIL Image for Transforms (expected by torchvision)\n            # EXR is float32 (Linear). PIL usually handles 0-255 uint8 better for standard transforms?\n            # BUT: The pipeline might expect standard range. \n            # If I convert to float 0-1 range and then ToTensor() handles it?\n            # ToTensor handles np.ndarray. 
\n            # If float32, ToTensor() keeps it float32.\n            # If uint8, ToTensor() scales to 0-1 float32.\n            \n            # Simple tone map for passing to GVM if needed? \n            # GVM expects an RGB image. \n            # Let's assume simplest path: Clip linear to 0-1, sRGB gamma, then uint8 PIL?\n            # Or pass float tensor?\n            \n            # The transforms are: ToTensor(), Resize(). ToTensor accepts array.\n            # Resize accepts PIL or Tensor.\n            \n            # Let's normalize consistent with main.py: linear -> sRGB gamma\n            img = np.power(np.clip(img, 0.0, None), 1.0/2.2)\n            img = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)\n            img = Image.fromarray(img)\n            \n        else:\n            # Fallback to PIL for non-exr for safety/compatibility\n            with Image.open(fpath) as img:\n                img.load()\n        \n        origin_shape = torch.from_numpy(np.asarray(np.array(img).shape[:2]))\n\n        if self.transform is not None:\n            img, filename = self.transform(img), self.files[idx]\n        else:\n            filename = self.files[idx]\n\n        return {\"image\": img, \"filename\": filename, \"origin_shape\": origin_shape}\n\n\nclass ImageSequenceWriter:\n    def __init__(self, path, extension='jpg'):\n        self.path = path\n        self.extension = extension\n        self.counter = 0\n        os.makedirs(path, exist_ok=True)\n    \n    def write(self, frames, filenames=None):\n        # frames: [T, C, H, W]\n        for t in range(frames.shape[0]):\n            if filenames is None:\n                filename = str(self.counter).zfill(4) + '.' + self.extension\n            else:\n                filename = filenames[t].split('.')[0] + '.' 
+ self.extension\n\n            to_pil_image(frames[t]).save(os.path.join(\n                self.path, filename))\n            self.counter += 1\n            \n    def close(self):\n        pass\n        "
  },
  {
    "path": "gvm_core/weights/.gitkeep",
    "content": ""
  },
  {
    "path": "gvm_core/wrapper.py",
    "content": "import os\nimport os.path as osp\nimport cv2\nimport random\nimport logging\nimport time\nfrom pathlib import Path\n\nfrom easydict import EasyDict\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms\nfrom torchvision.transforms import ToTensor, Resize, Compose\nfrom diffusers import AutoencoderKLTemporalDecoder, FlowMatchEulerDiscreteScheduler\nfrom tqdm import tqdm\n\n# Relative imports from the internal gvm package\n# Assuming this file is inside gvm_core/\nfrom .gvm.pipelines.pipeline_gvm import GVMPipeline\nfrom .gvm.utils.inference_utils import VideoReader, VideoWriter, ImageSequenceReader, ImageSequenceWriter\nfrom .gvm.models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel\n\n\ndef seed_all(seed: int = 0):\n    \"\"\"Seed all random number generators for reproducibility.\n\n    WARNING: This mutates global state — Python's random, numpy's RNG,\n    and all PyTorch CUDA RNGs. 
Called from GVMProcessor.__init__.\n    \"\"\"\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n\ndef impad_multi(img, multiple=32):\n    # img: (N, C, H, W)\n    h, w = img.shape[2], img.shape[3]\n    \n    target_h = int(np.ceil(h / multiple) * multiple)\n    target_w = int(np.ceil(w / multiple) * multiple)\n\n    pad_top = (target_h - h) // 2\n    pad_bottom = target_h - h - pad_top\n    pad_left = (target_w - w) // 2\n    pad_right = target_w - w - pad_left\n\n    # F.pad expects (padding_left, padding_right, padding_top, padding_bottom)\n    padded = F.pad(img, (pad_left, pad_right, pad_top, pad_bottom), mode='reflect')\n\n    return padded, (pad_top, pad_left, pad_bottom, pad_right)\n\ndef sequence_collate_fn(examples):\n    rgb_values = torch.stack([example[\"image\"] for example in examples])\n    rgb_values = rgb_values.to(memory_format=torch.contiguous_format).float()\n    rgb_names = [example[\"filename\"] for example in examples]\n    return {'rgb_values': rgb_values, 'rgb_names': rgb_names}\n\nclass GVMProcessor:\n    def __init__(self, \n                 model_base=None,\n                 unet_base=None,\n                 lora_base=None,\n                 device=\"cpu\",\n                 seed=None):\n        self.device = torch.device(device)\n        \n        # Resolve default weights path relative to this file\n        if model_base is None:\n            model_base = osp.join(osp.dirname(__file__), \"weights\")\n            \n        self.model_base = model_base\n        self.unet_base = unet_base\n        self.lora_base = lora_base\n        \n        if seed is None:\n            seed = int(time.time())\n        seed_all(seed)\n        \n        logging.info(f\"Loading GVM models from {model_base}...\")\n        self.vae = AutoencoderKLTemporalDecoder.from_pretrained(model_base, subfolder=\"vae\", torch_dtype=torch.float16)\n        
self.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_base, subfolder=\"scheduler\")\n        \n        unet_folder = unet_base if unet_base is not None else model_base\n        self.unet = UNetSpatioTemporalConditionModel.from_pretrained(\n            unet_folder, \n            subfolder=\"unet\", \n            class_embed_type=None,\n            torch_dtype=torch.float16\n        )\n\n        self.pipe = GVMPipeline(vae=self.vae, unet=self.unet, scheduler=self.scheduler)\n        if lora_base:\n            # Check if lora_base is None or points to valid path, otherwise try default\n            if lora_base is None and osp.exists(osp.join(model_base, \"unet\")):\n                 # Often lora weights are just the unet weights in this codebase based on demo.py usage\n                 pass \n            elif lora_base:\n                self.pipe.load_lora_weights(lora_base)\n                \n        self.pipe = self.pipe.to(self.device, dtype=torch.float16)\n        logging.info(\"Models loaded.\")\n\n    def process_sequence(self, input_path, output_dir, \n                         num_frames_per_batch=8,\n                         denoise_steps=1,\n                         max_frames=None,\n                         decode_chunk_size=8,\n                         num_interp_frames=1,\n                         num_overlap_frames=1,\n                         use_clip_img_emb=False,\n                         noise_type='zeros',\n                         mode='matte',\n                         write_video=True,\n                         direct_output_dir=None,\n                         progress_callback=None):\n        \"\"\"\n        Process a single video or directory of images.\n        \"\"\"\n        input_path = Path(input_path)\n        file_name = input_path.stem\n        is_video = input_path.suffix.lower() in ['.mp4', '.mkv', '.gif', '.mov', '.avi']\n        \n        # --- Determine Resolution & Upscaling ---\n        if is_video:\n            
cap = cv2.VideoCapture(str(input_path))\n            orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n            orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n            cap.release()\n        else:\n            image_files = sorted([f for f in input_path.iterdir() if f.is_file() and f.suffix.lower() in ['.jpg', '.png', '.jpeg', '.exr']])\n            if not image_files:\n                logging.warning(f\"No images found in {input_path}\")\n                return\n            # Use cv2 for EXR support if needed\n            first_img_path = str(image_files[0])\n            if first_img_path.lower().endswith('.exr'):\n                 # import cv2 # Global import used\n                 if \"OPENCV_IO_ENABLE_OPENEXR\" not in os.environ:\n                     os.environ[\"OPENCV_IO_ENABLE_OPENEXR\"] = \"1\"\n                 img = cv2.imread(first_img_path, cv2.IMREAD_UNCHANGED)\n            else:\n                 img = cv2.imread(first_img_path)\n                 \n            if img is not None:\n                orig_h, orig_w = img.shape[:2]\n            else:\n                orig_h, orig_w = 1080, 1920 # Fallback\n\n        target_h = orig_h\n        if target_h < 1024:\n            scale_ratio = 1024 / target_h\n            target_h = 1024\n        \n        # Calculate max resolution / long edge\n        if orig_h < orig_w: # Landscape\n            ratio = orig_w / orig_h\n            new_long = int(1024 * ratio)\n        else:\n            ratio = orig_h / orig_w\n            new_long = int(1024 * ratio)\n            \n        scale_cap = 1920\n        if new_long > scale_cap:\n            new_long = scale_cap\n        \n        max_res_param = new_long \n\n        transform = Compose([\n            ToTensor(),\n            Resize(size=1024, max_size=max_res_param, antialias=True)\n        ])\n\n        if is_video:\n            reader = VideoReader(\n                str(input_path), \n                max_frames=max_frames,\n                
transform=transform\n            )\n        else:\n            reader = ImageSequenceReader(\n                str(input_path), \n                transform=transform\n            )\n\n        # Get upscaled shape from first frame\n        first_frame = reader[0]\n        if isinstance(first_frame, dict):\n             first_frame = first_frame['image']\n        \n        current_upscaled_shape = list(first_frame.shape[1:]) # H, W\n        if current_upscaled_shape[0] % 2 != 0: current_upscaled_shape[0] -= 1\n        if current_upscaled_shape[1] % 2 != 0: current_upscaled_shape[1] -= 1\n        current_upscaled_shape = tuple(current_upscaled_shape)\n\n        # Output preparation\n        fps = reader.frame_rate if hasattr(reader, 'frame_rate') else 24.0\n        \n        if direct_output_dir:\n            # Write directly to this folder\n            os.makedirs(direct_output_dir, exist_ok=True)\n            writer_alpha_seq = ImageSequenceWriter(direct_output_dir, extension='png')\n            writer_alpha = None\n            if write_video:\n                 # Warning: direct mode might not support video naming nicely without logic\n                 # Let's write video into the directory with fixed name\n                 writer_alpha = VideoWriter(osp.join(direct_output_dir, f\"{file_name}_alpha.mp4\"), frame_rate=fps)\n        else:\n            # Create output directory for this specific file\n            file_output_dir = osp.join(output_dir, file_name)\n            os.makedirs(file_output_dir, exist_ok=True)\n            logging.info(f\"Processing {input_path} -> {file_output_dir}\")\n            \n            writer_alpha = VideoWriter(osp.join(file_output_dir, f\"{file_name}_alpha.mp4\"), frame_rate=fps) if write_video else None\n            writer_alpha_seq = ImageSequenceWriter(osp.join(file_output_dir, \"alpha_seq\"), extension='png')\n        \n        # Dataloader\n        if is_video:\n            dataloader = DataLoader(reader, 
batch_size=num_frames_per_batch)\n        else:\n            dataloader = DataLoader(reader, batch_size=num_frames_per_batch, collate_fn=sequence_collate_fn)\n\n        upper_bound = 240./255.\n        lower_bound = 25./ 255.\n\n        for batch_id, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc=f\"Inferencing {file_name}\"):\n            filenames = []\n            if is_video:\n                b, _, h, w = batch.shape\n                for i in range(b):\n                    file_id = batch_id * b + i\n                    filenames.append(f\"{file_id:05d}.jpg\")\n            else:\n                filenames = batch['rgb_names']\n                batch = batch['rgb_values']\n\n            # Pad (Reflective)\n            batch, pad_info = impad_multi(batch)\n\n            # Inference\n            with torch.no_grad():\n                pipe_out = self.pipe(\n                    batch.to(self.device, dtype=torch.float16),\n                    num_frames=num_frames_per_batch,\n                    num_overlap_frames=num_overlap_frames,\n                    num_interp_frames=num_interp_frames,\n                    decode_chunk_size=decode_chunk_size,\n                    num_inference_steps=denoise_steps,\n                    mode=mode,\n                    use_clip_img_emb=use_clip_img_emb,\n                    noise_type=noise_type,\n                    ensemble_size=1,\n                )\n            image = pipe_out.image\n            alpha = pipe_out.alpha\n\n            # Crop padding\n            out_h, out_w = image.shape[2:]\n            pad_t, pad_l, pad_b, pad_r = pad_info\n            \n            end_h = out_h - pad_b\n            end_w = out_w - pad_r\n            \n            image = image[:, :, pad_t:end_h, pad_l:end_w]\n            alpha = alpha[:, :, pad_t:end_h, pad_l:end_w]\n\n            # Resize to ensure exact match if there's any discrepancy\n            alpha = F.interpolate(alpha, current_upscaled_shape, mode='bilinear')\n    
        \n            # Threshold\n            alpha[alpha>=upper_bound] = 1.0\n            alpha[alpha<=lower_bound] = 0.0\n\n            if writer_alpha: writer_alpha.write(alpha)\n            writer_alpha_seq.write(alpha, filenames=filenames)\n\n            if progress_callback is not None:\n                progress_callback(batch_id + 1, len(dataloader))\n        \n        if writer_alpha: writer_alpha.close()\n        writer_alpha_seq.close()\n        logging.info(f\"Finished {file_name}\")\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"corridorkey\"\nversion = \"1.0.0\"\ndescription = \"Neural network green screen keying for professional VFX pipelines\"\nreadme = \"README.md\"\nrequires-python = \">=3.10, <3.14\"\nlicense = { text = \"CC-BY-NC-SA-4.0\" }\nauthors = [{ name = \"Corridor Digital\" }]\n\ndependencies = [\n    # Core inference\n    \"torch==2.8.0\",\n    \"torchvision==0.23.0\",\n    \"timm==1.0.24\",\n    \"numpy\",\n    \"opencv-python\",\n    \"tqdm\",\n    \"setuptools\",\n    # Triton fix for Windows\n    # There might still be issues though https://github.com/triton-lang/triton-windows?tab=readme-ov-file#windows-file-path-length-limit-260-causes-compilation-failure\n    \"triton-windows==3.4.0.post21 ; sys_platform == 'win32'\",\n    # GVM alpha hint generator\n    \"diffusers\",\n    \"transformers\",\n    \"accelerate\",\n    \"peft\",\n    \"av\",\n    \"Pillow\",\n    \"PIMS\",\n    \"easydict\",\n    \"imageio\",\n    \"matplotlib\",\n    # VideoMaMa alpha hint generator\n    \"einops\",\n    # BiRefNet alpha hint generator\n    \"kornia\",\n    # CLI tools (huggingface-hub is also a transitive dep, but must be direct\n    # so that uv installs the \"hf\" console-script entry point)\n    \"huggingface-hub\",\n    # CLI framework + terminal output\n    \"typer>=0.12\",\n    \"rich>=13\",\n]\n\n[project.optional-dependencies]\ncuda = [\n    \"torch==2.8.0\",\n    \"torchvision==0.23.0\",\n]\nmlx = [\n    \"corridorkey-mlx ; python_version >= '3.11'\",\n]\n\n[dependency-groups]\ndev = [\"pytest\", \"pytest-cov\", \"ruff\", \"hypothesis\"]\ndocs = [\"zensical>=0.0.24\"]\n\n[project.scripts]\ncorridorkey = \"corridorkey_cli:main\"\n\n[build-system]\nrequires = [\"hatchling\"]\nbuild-backend = \"hatchling.build\"\n\n[tool.hatch.build.targets.wheel]\npackages = [\"CorridorKeyModule\", \"gvm_core\", \"VideoMaMaInferenceModule\"]\n\n[tool.ruff]\ntarget-version = \"py311\"\nline-length = 120\nextend-exclude = [\"gvm_core/\", 
\"VideoMaMaInferenceModule/\"]\n\n[tool.ruff.lint]\nselect = [\"E\", \"F\", \"W\", \"I\", \"B\"]\n\n[tool.ruff.format]\n# Third-party code excluded via top-level exclude\n\n[tool.pytest.ini_options]\ntestpaths = [\"tests\"]\nmarkers = [\n    \"gpu: requires CUDA GPU (skipped in CI)\",\n    \"slow: long-running test\",\n    \"mlx: requires Apple Silicon with MLX installed\",\n]\nenv = [\"OPENCV_IO_ENABLE_OPENEXR=1\", ]\naddopts = \"--tb=short\"\n\n[tool.coverage.run]\nbranch = true\n# \".\" covers root-level modules (clip_manager.py, corridorkey_cli.py once PR #8 merges).\n# Third-party and non-library code is excluded via omit.\nsource = [\"CorridorKeyModule\", \".\"]\nomit = [\n    \"gvm_core/*\",\n    \"VideoMaMaInferenceModule/*\",\n    \"tests/*\",\n    \"test_vram.py\", # manual GPU smoke test, not part of the pytest suite\n]\n\n[tool.uv]\n# Guard against transitive deps (diffusers, imageio, PIMS) silently pulling in\n# opencv-python-headless, which conflicts with opencv-python at the file level\n# (both install into the same cv2/ directory). If any future dep requests\n# opencv-python-headless, uv resolution will fail explicitly rather than\n# corrupting the environment.\nconstraint-dependencies = [\n    \"opencv-python-headless==99999\",\n]\nconflicts = [\n    [\n        { extra = \"cuda\" },\n        { extra = \"mlx\" },\n    ],\n]\n\n[[tool.uv.index]]\nname = \"pytorch\"\nurl = \"https://download.pytorch.org/whl/cu128\" # CUDA 12.6 doesn't support RTX 5000 Series\nexplicit = true\nextra = \"cuda\"\n\n[tool.uv.sources]\n# Use Hiera fix in order to utilize the FlashAttention Kernel\ntimm = { git = \"https://github.com/Raiden129/pytorch-image-models-fix\", branch = \"fix/hiera-flash-attention-global-4d\" }\ntorch = { index = \"pytorch\", extra = \"cuda\" }\ntorchvision = { index = \"pytorch\", extra = \"cuda\" }\ncorridorkey-mlx = { git = \"https://github.com/nikopueringer/corridorkey-mlx.git\", extra = \"mlx\" }\n"
  },
  {
    "path": "renovate.json",
    "content": "{\n  \"$schema\": \"https://docs.renovatebot.com/renovate-schema.json\",\n  \"extends\": [\"config:recommended\"],\n  \"pinDigests\": true,\n  \"lockFileMaintenance\": {\n    \"enabled\": true\n  },\n  \"semanticCommits\": \"enabled\",\n  \"addLabels\": [\"renovate\", \"dependencies\"]\n}\n"
  },
  {
    "path": "test_vram.py",
    "content": "import timeit\n\nimport numpy as np\nimport torch\n\nfrom CorridorKeyModule.inference_engine import CorridorKeyEngine\n\n\ndef process_frame(engine):\n    img = np.random.randint(0, 255, (2160, 3840, 3), dtype=np.uint8)\n    mask = np.random.randint(0, 255, (2160, 3840), dtype=np.uint8)\n\n    engine.process_frame(img, mask)\n\n\ndef test_vram():\n    print(\"Loading engine...\")\n    engine = CorridorKeyEngine(\n        checkpoint_path=\"CorridorKeyModule/checkpoints/CorridorKey_v1.0.pth\",\n        img_size=2048,\n        device=\"cuda\",\n        model_precision=torch.float16,\n    )\n\n    # Reset stats\n    torch.cuda.reset_peak_memory_stats()\n\n    iterations = 24\n    print(f\"Running {iterations} inference passes...\")\n    time = timeit.timeit(lambda: process_frame(engine), number=iterations)\n    print(f\"Seconds per frame: {time / iterations}\")\n\n    peak_vram = torch.cuda.max_memory_allocated() / (1024**3)\n    print(f\"Peak VRAM used: {peak_vram:.2f} GB\")\n\n\nif __name__ == \"__main__\":\n    test_vram()\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/conftest.py",
    "content": "\"\"\"Shared pytest configuration and fixtures for CorridorKey tests.\"\"\"\n\nimport platform\nimport sys\nfrom types import ModuleType\nfrom unittest.mock import MagicMock\n\nimport numpy as np\nimport pytest\nimport torch\n\n\ndef _has_gpu():\n    \"\"\"Check if any GPU backend (CUDA or MPS) is available.\"\"\"\n    if torch.cuda.is_available():\n        return True\n    if hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n        return True\n\n    return False\n\n\ndef _has_mlx():\n    \"\"\"Check if MLX is available (Apple Silicon + corridorkey_mlx installed).\"\"\"\n    if sys.platform != \"darwin\" or platform.machine() != \"arm64\":\n        return False\n    try:\n        import corridorkey_mlx  # noqa: F401\n\n        return True\n    except ImportError:\n        return False\n\n\ndef pytest_collection_modifyitems(config, items):\n    \"\"\"Auto-skip GPU/MLX tests when hardware is unavailable.\"\"\"\n    if not _has_gpu():\n        skip_gpu = pytest.mark.skip(reason=\"No GPU available (neither CUDA nor MPS)\")\n        for item in items:\n            if \"gpu\" in item.keywords:\n                item.add_marker(skip_gpu)\n\n    if not _has_mlx():\n        skip_mlx = pytest.mark.skip(reason=\"MLX not available (requires Apple Silicon + corridorkey_mlx)\")\n        for item in items:\n            if \"mlx\" in item.keywords:\n                item.add_marker(skip_mlx)\n\n\n# ---------------------------------------------------------------------------\n# Basic frame/mask fixtures (used by color_utils and inference_engine tests)\n# ---------------------------------------------------------------------------\n\n\n@pytest.fixture\ndef sample_frame_rgb():\n    \"\"\"Small 64x64 RGB frame as float32 in [0, 1] (sRGB).\"\"\"\n    rng = np.random.default_rng(42)\n    return rng.random((64, 64, 3), dtype=np.float32)\n\n\n@pytest.fixture\ndef sample_mask():\n    \"\"\"Matching 64x64 single-channel alpha mask as float32 in [0, 
1].\"\"\"\n    rng = np.random.default_rng(42)\n    mask = rng.random((64, 64), dtype=np.float32)\n    # Make it more mask-like: threshold to create distinct FG/BG regions\n    return (mask > 0.5).astype(np.float32)\n\n\n# ---------------------------------------------------------------------------\n# Clip directory structure fixtures (used by clip_manager tests)\n# ---------------------------------------------------------------------------\n\n\n@pytest.fixture\ndef tmp_clip_dir(tmp_path):\n    \"\"\"Creates a temporary directory with the expected clip structure.\n\n    Layout::\n\n        tmp_path/\n          shot_a/\n            Input/\n              frame_0000.png\n              frame_0001.png\n            AlphaHint/\n              frame_0000.png\n              frame_0001.png\n            VideoMamaMaskHint/\n          shot_b/\n            Input/\n              frame_0000.png\n            AlphaHint/         (empty — needs generation)\n            VideoMamaMaskHint/\n          shot_c/              (no AlphaHint directory)\n            Input/\n              frame_0000.png\n              frame_0001.png\n            VideoMamaMaskHint/\n              mask_0000.png\n              mask_0001.png\n\n    PNG files are tiny 4x4 images so tests run fast.\n    \"\"\"\n    import cv2\n\n    tiny_img = np.zeros((4, 4, 3), dtype=np.uint8)\n    tiny_img[1:3, 1:3] = 255  # small white square\n    tiny_mask = np.zeros((4, 4), dtype=np.uint8)\n    tiny_mask[1:3, 1:3] = 255\n\n    # shot_a — fully ready (Input + AlphaHint populated)\n    shot_a = tmp_path / \"shot_a\"\n    for subdir in [\"Input\", \"AlphaHint\", \"VideoMamaMaskHint\"]:\n        (shot_a / subdir).mkdir(parents=True)\n\n    for i in range(2):\n        cv2.imwrite(str(shot_a / \"Input\" / f\"frame_{i:04d}.png\"), tiny_img)\n        cv2.imwrite(str(shot_a / \"AlphaHint\" / f\"frame_{i:04d}.png\"), tiny_mask)\n\n    # shot_b — Input only, empty AlphaHint (needs generation)\n    shot_b = tmp_path / \"shot_b\"\n    for subdir in [\"Input\", \"AlphaHint\", \"VideoMamaMaskHint\"]:\n        (shot_b / subdir).mkdir(parents=True)\n\n    cv2.imwrite(str(shot_b / \"Input\" / \"frame_0000.png\"), tiny_img)\n\n    # shot_c: Valid candidate (2 frames + 
Mask sequence)\n    shot_c = tmp_path / \"shot_c\"\n    for subdir in [\"Input\", \"VideoMamaMaskHint\"]:\n        (shot_c / subdir).mkdir(parents=True)\n    for i in range(2):\n        cv2.imwrite(str(shot_c / \"Input\" / f\"frame_{i:04d}.png\"), tiny_img)\n        cv2.imwrite(str(shot_c / \"VideoMamaMaskHint\" / f\"mask_{i:04d}.png\"), tiny_mask)\n\n    return tmp_path\n\n\n# ---------------------------------------------------------------------------\n# Mock inference engine fixture (used by inference_engine tests)\n# ---------------------------------------------------------------------------\n\n\n@pytest.fixture\ndef mock_greenformer():\n    \"\"\"A mock GreenFormer model that returns deterministic tensors.\n\n    Returns alpha=0.8 and fg=0.6 everywhere, sized to match the input.\n    No GPU or model weights needed.\n    \"\"\"\n\n    def fake_forward(x):\n        b, c, h, w = x.shape\n        return {\n            \"alpha\": torch.full((b, 1, h, w), 0.8, device=x.device),\n            \"fg\": torch.full((b, 3, h, w), 0.6, device=x.device),\n        }\n\n    model = MagicMock()\n    model.side_effect = fake_forward\n    model.refiner = None\n    model.use_refiner = False\n    return model\n\n\n# ---------------------------------------------------------------------------\n# VideoMaMa Backend & Staging Fixtures (used by clip_manager tests)\n# ---------------------------------------------------------------------------\n\n\n@pytest.fixture(autouse=True)\ndef silent_backend_injection(monkeypatch):\n    \"\"\"\n    Mock the inference module globally to prevent real AI loading.\n    \"\"\"\n\n    parent_mod = ModuleType(\"VideoMaMaInferenceModule\")\n    inference_mod = ModuleType(\"VideoMaMaInferenceModule.inference\")\n\n    def fake_load(device=None):\n        return \"fake_handle\"\n\n    def fake_run(pipeline, input_frames, mask_frames, **kwargs):\n        yield [np.full_like(f, 255) for f in input_frames]\n\n    inference_mod.load_videomama_model = fake_load\n    
inference_mod.run_inference = fake_run\n    monkeypatch.setitem(sys.modules, \"VideoMaMaInferenceModule\", parent_mod)\n    monkeypatch.setitem(sys.modules, \"VideoMaMaInferenceModule.inference\", inference_mod)\n\n    yield\n\n\n@pytest.fixture\ndef stage_shot(tmp_path):\n    \"\"\"\n    STAGING: Physically builds the shot directory on demand.\n    Ensures ClipEntry finds its folders regardless of casing/regex.\n    \"\"\"\n    import cv2\n\n    def _stage(shot_name, create_alpha=False):\n        shot_path = tmp_path / shot_name\n        shot_path.mkdir(parents=True, exist_ok=True)\n\n        for folder_variant in [\"Input\", \"input\"]:\n            in_dir = shot_path / folder_variant\n            in_dir.mkdir(exist_ok=True)\n            for i in range(2):\n                img = np.zeros((4, 4, 3), dtype=np.uint8)\n                cv2.imwrite(str(in_dir / f\"frame_{i:04d}.png\"), img)\n\n        for mask_variant in [\"VideoMamaMaskHint\", \"videomamamaskhint\"]:\n            mask_dir = shot_path / mask_variant\n            mask_dir.mkdir(exist_ok=True)\n            cv2.imwrite(str(mask_dir / \"mask_0000.png\"), np.zeros((4, 4), np.uint8))\n            cv2.imwrite(str(mask_dir / \"mask_0001.png\"), np.zeros((4, 4), np.uint8))\n\n        if create_alpha:\n            a_dir = shot_path / \"AlphaHint\"\n            a_dir.mkdir(exist_ok=True)\n            cv2.imwrite(str(a_dir / \"frame_0000.png\"), np.zeros((4, 4), np.uint8))\n\n        return shot_path\n\n    return _stage\n\n\n@pytest.fixture(autouse=True)\ndef sandbox_clip_manager(tmp_path, monkeypatch):\n    \"\"\"\n    Forces all clip_manager operations into a temporary sandbox.\n    Prevents tests from touching real project files.\n    \"\"\"\n    import clip_manager\n\n    sandbox = tmp_path / \"Clips\"\n    sandbox.mkdir(parents=True, exist_ok=True)\n\n    monkeypatch.setattr(clip_manager, \"CLIPS_DIR\", str(sandbox))\n    monkeypatch.setattr(clip_manager, \"organize_clips\", MagicMock())\n\n    yield 
sandbox\n"
  },
  {
    "path": "tests/test_backend.py",
    "content": "\"\"\"Unit tests for CorridorKeyModule.backend — no GPU/MLX required.\"\"\"\n\nimport errno\nimport logging\nimport os\nfrom unittest import mock\n\nimport numpy as np\nimport pytest\n\nfrom CorridorKeyModule.backend import (\n    BACKEND_ENV_VAR,\n    HF_CHECKPOINT_FILENAME,\n    HF_REPO_ID,\n    MLX_EXT,\n    TORCH_EXT,\n    _discover_checkpoint,\n    _ensure_torch_checkpoint,\n    _wrap_mlx_output,\n    resolve_backend,\n)\n\n# --- resolve_backend ---\n\n\nclass TestResolveBackend:\n    def test_explicit_torch(self):\n        assert resolve_backend(\"torch\") == \"torch\"\n\n    def test_explicit_mlx_on_non_apple_raises(self):\n        with mock.patch(\"CorridorKeyModule.backend.sys\") as mock_sys:\n            mock_sys.platform = \"linux\"\n            with pytest.raises(RuntimeError, match=\"Apple Silicon\"):\n                resolve_backend(\"mlx\")\n\n    def test_env_var_torch(self):\n        with mock.patch.dict(os.environ, {BACKEND_ENV_VAR: \"torch\"}):\n            assert resolve_backend(None) == \"torch\"\n            assert resolve_backend(\"auto\") == \"torch\"\n\n    def test_auto_non_darwin(self):\n        with mock.patch(\"CorridorKeyModule.backend.sys\") as mock_sys:\n            mock_sys.platform = \"linux\"\n            assert resolve_backend(\"auto\") == \"torch\"\n\n    def test_auto_darwin_no_mlx_package(self):\n        with (\n            mock.patch(\"CorridorKeyModule.backend.sys\") as mock_sys,\n            mock.patch(\"CorridorKeyModule.backend.platform\") as mock_platform,\n        ):\n            mock_sys.platform = \"darwin\"\n            mock_platform.machine.return_value = \"arm64\"\n\n            # corridorkey_mlx not importable\n            import builtins\n\n            real_import = builtins.__import__\n\n            def fail_mlx(name, *args, **kwargs):\n                if name == \"corridorkey_mlx\":\n                    raise ImportError\n                return real_import(name, *args, **kwargs)\n\n            
with mock.patch(\"builtins.__import__\", side_effect=fail_mlx):\n                assert resolve_backend(\"auto\") == \"torch\"\n\n    def test_unknown_backend_raises(self):\n        with pytest.raises(RuntimeError, match=\"Unknown backend\"):\n            resolve_backend(\"tensorrt\")\n\n\n# --- _discover_checkpoint ---\n\n\nclass TestDiscoverCheckpoint:\n    def test_exactly_one(self, tmp_path):\n        ckpt = tmp_path / \"model.pth\"\n        ckpt.touch()\n        with mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(tmp_path)):\n            result = _discover_checkpoint(TORCH_EXT)\n            assert result == ckpt\n\n    def test_zero_torch_triggers_auto_download(self, tmp_path):\n        \"\"\"Empty dir + TORCH_EXT now calls _ensure_torch_checkpoint (auto-download).\"\"\"\n        with mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(tmp_path)):\n            with mock.patch(\"huggingface_hub.hf_hub_download\") as mock_dl:\n                # Simulate hf_hub_download returning a cached file\n                cached = tmp_path / \"hf_cache\" / \"CorridorKey.pth\"\n                cached.parent.mkdir()\n                cached.write_bytes(b\"fake-checkpoint\")\n                mock_dl.return_value = str(cached)\n\n                result = _discover_checkpoint(TORCH_EXT)\n                assert result.name == \"CorridorKey.pth\"\n                assert result.exists()\n                mock_dl.assert_called_once()\n\n    def test_zero_torch_download_failure_raises_runtime_error(self, tmp_path):\n        \"\"\"When auto-download fails, RuntimeError is raised with HF URL.\"\"\"\n        with mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(tmp_path)):\n            with mock.patch(\n                \"huggingface_hub.hf_hub_download\",\n                side_effect=ConnectionError(\"no network\"),\n            ):\n                with pytest.raises(RuntimeError, match=\"huggingface.co\"):\n                    
_discover_checkpoint(TORCH_EXT)\n\n    def test_zero_safetensors_with_cross_reference(self, tmp_path):\n        \"\"\"MLX ext with no .safetensors but .pth present gives cross-reference hint.\"\"\"\n        (tmp_path / \"model.pth\").touch()\n        with mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(tmp_path)):\n            with pytest.raises(FileNotFoundError, match=\"--backend=torch\"):\n                _discover_checkpoint(MLX_EXT)\n\n    def test_multiple_raises(self, tmp_path):\n        (tmp_path / \"a.pth\").touch()\n        (tmp_path / \"b.pth\").touch()\n        with mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(tmp_path)):\n            with pytest.raises(ValueError, match=\"Multiple\"):\n                _discover_checkpoint(TORCH_EXT)\n\n    def test_safetensors(self, tmp_path):\n        ckpt = tmp_path / \"model.safetensors\"\n        ckpt.touch()\n        with mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(tmp_path)):\n            result = _discover_checkpoint(MLX_EXT)\n            assert result == ckpt\n\n    def test_ensure_torch_checkpoint_happy_path(self, tmp_path):\n        \"\"\"Mock hf_hub_download, verify copy to CHECKPOINT_DIR/CorridorKey.pth.\"\"\"\n        cached = tmp_path / \"hf_cache\" / \"CorridorKey.pth\"\n        cached.parent.mkdir()\n        cached.write_bytes(b\"fake-checkpoint-data\")\n\n        with mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(tmp_path)):\n            with mock.patch(\"huggingface_hub.hf_hub_download\", return_value=str(cached)) as mock_dl:\n                result = _ensure_torch_checkpoint()\n\n                assert result == tmp_path / HF_CHECKPOINT_FILENAME\n                assert result.exists()\n                assert result.read_bytes() == b\"fake-checkpoint-data\"\n                mock_dl.assert_called_once_with(\n                    repo_id=HF_REPO_ID,\n                    filename=HF_CHECKPOINT_FILENAME,\n                )\n\n    def 
test_skip_when_present(self, tmp_path):\n        \"\"\"Existing .pth file means hf_hub_download is never called.\"\"\"\n        ckpt = tmp_path / \"model.pth\"\n        ckpt.write_bytes(b\"existing\")\n        with mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(tmp_path)):\n            with mock.patch(\"huggingface_hub.hf_hub_download\") as mock_dl:\n                result = _discover_checkpoint(TORCH_EXT)\n                assert result == ckpt\n                mock_dl.assert_not_called()\n\n    def test_mlx_not_triggered(self, tmp_path):\n        \"\"\"MLX ext with empty dir raises FileNotFoundError, no download attempted.\"\"\"\n        with mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(tmp_path)):\n            with mock.patch(\"huggingface_hub.hf_hub_download\") as mock_dl:\n                with pytest.raises(FileNotFoundError):\n                    _discover_checkpoint(MLX_EXT)\n                mock_dl.assert_not_called()\n\n    def test_network_error_wrapping(self, tmp_path):\n        \"\"\"ConnectionError from hf_hub_download becomes RuntimeError with HF URL.\"\"\"\n        with mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(tmp_path)):\n            with mock.patch(\n                \"huggingface_hub.hf_hub_download\",\n                side_effect=ConnectionError(\"connection refused\"),\n            ) as mock_dl:\n                with pytest.raises(RuntimeError, match=r\"huggingface\\.co/nikopueringer/CorridorKey_v1\\.0\"):\n                    _ensure_torch_checkpoint()\n                mock_dl.assert_called_once()\n\n    def test_disk_space_error(self, tmp_path):\n        \"\"\"OSError ENOSPC from copy2 produces message mentioning ~300 MB.\"\"\"\n        cached = tmp_path / \"hf_cache\" / \"CorridorKey.pth\"\n        cached.parent.mkdir()\n        cached.write_bytes(b\"data\")\n\n        with mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(tmp_path)):\n            with 
mock.patch(\"huggingface_hub.hf_hub_download\", return_value=str(cached)):\n                with mock.patch(\n                    \"CorridorKeyModule.backend.shutil.copy2\",\n                    side_effect=OSError(errno.ENOSPC, \"No space left on device\"),\n                ):\n                    with pytest.raises(OSError, match=\"300 MB\"):\n                        _ensure_torch_checkpoint()\n\n    def test_logging_on_download(self, tmp_path, caplog):\n        \"\"\"Info-level log messages emitted at download start and completion.\"\"\"\n        cached = tmp_path / \"hf_cache\" / \"CorridorKey.pth\"\n        cached.parent.mkdir()\n        cached.write_bytes(b\"data\")\n\n        with mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(tmp_path)):\n            with mock.patch(\"huggingface_hub.hf_hub_download\", return_value=str(cached)):\n                with caplog.at_level(logging.INFO, logger=\"CorridorKeyModule.backend\"):\n                    _ensure_torch_checkpoint()\n\n        assert any(\"Downloading\" in msg for msg in caplog.messages)\n        assert any(\"saved\" in msg.lower() for msg in caplog.messages)\n\n\n# --- _wrap_mlx_output ---\n\n\nclass TestWrapMlxOutput:\n    @pytest.fixture\n    def mlx_raw_output(self):\n        \"\"\"Simulated MLX engine output: uint8.\"\"\"\n        h, w = 64, 64\n        rng = np.random.default_rng(42)\n        return {\n            \"alpha\": rng.integers(0, 256, (h, w), dtype=np.uint8),\n            \"fg\": rng.integers(0, 256, (h, w, 3), dtype=np.uint8),\n            \"comp\": rng.integers(0, 256, (h, w, 3), dtype=np.uint8),\n            \"processed\": rng.integers(0, 256, (h, w, 3), dtype=np.uint8),\n        }\n\n    def test_output_keys(self, mlx_raw_output):\n        result = _wrap_mlx_output(mlx_raw_output, despill_strength=1.0, auto_despeckle=True, despeckle_size=400)\n        assert set(result.keys()) == {\"alpha\", \"fg\", \"comp\", \"processed\"}\n\n    def test_alpha_shape_dtype(self, 
mlx_raw_output):\n        result = _wrap_mlx_output(mlx_raw_output, despill_strength=1.0, auto_despeckle=False, despeckle_size=400)\n        assert result[\"alpha\"].shape == (64, 64, 1)\n        assert result[\"alpha\"].dtype == np.float32\n        assert result[\"alpha\"].min() >= 0.0\n        assert result[\"alpha\"].max() <= 1.0\n\n    def test_fg_shape_dtype(self, mlx_raw_output):\n        result = _wrap_mlx_output(mlx_raw_output, despill_strength=0.0, auto_despeckle=False, despeckle_size=400)\n        assert result[\"fg\"].shape == (64, 64, 3)\n        assert result[\"fg\"].dtype == np.float32\n\n    def test_processed_shape_dtype(self, mlx_raw_output):\n        result = _wrap_mlx_output(mlx_raw_output, despill_strength=1.0, auto_despeckle=False, despeckle_size=400)\n        assert result[\"processed\"].shape == (64, 64, 4)\n        assert result[\"processed\"].dtype == np.float32\n\n    def test_comp_shape_dtype(self, mlx_raw_output):\n        result = _wrap_mlx_output(mlx_raw_output, despill_strength=1.0, auto_despeckle=False, despeckle_size=400)\n        assert result[\"comp\"].shape == (64, 64, 3)\n        assert result[\"comp\"].dtype == np.float32\n\n    def test_value_ranges(self, mlx_raw_output):\n        result = _wrap_mlx_output(mlx_raw_output, despill_strength=1.0, auto_despeckle=False, despeckle_size=400)\n        # alpha and fg come from uint8 / 255 so strictly 0-1\n        for key in (\"alpha\", \"fg\"):\n            assert result[key].min() >= 0.0, f\"{key} has negative values\"\n            assert result[key].max() <= 1.0, f\"{key} exceeds 1.0\"\n        # comp/processed can slightly exceed 1.0 due to sRGB conversion + despill redistribution\n        # (same behavior as Torch engine — linear_to_srgb doesn't clamp)\n        for key in (\"comp\", \"processed\"):\n            assert result[key].min() >= 0.0, f\"{key} has negative values\"\n"
  },
  {
    "path": "tests/test_cli.py",
    "content": "\"\"\"Tests for the typer-based CLI in corridorkey_cli.py.\"\"\"\n\nfrom __future__ import annotations\n\nimport re\nfrom unittest.mock import patch\n\nfrom typer.testing import CliRunner\n\nfrom clip_manager import InferenceSettings\nfrom corridorkey_cli import app\n\nrunner = CliRunner()\n\nANSI_ESCAPE = re.compile(r\"\\x1b\\[[0-9;]*m\")\n\n\n# ---------------------------------------------------------------------------\n# Help output\n# ---------------------------------------------------------------------------\n\n\nclass TestHelpOutput:\n    def test_main_help(self):\n        result = runner.invoke(app, [\"--help\"])\n        assert result.exit_code == 0\n        assert \"list-clips\" in result.output\n        assert \"generate-alphas\" in result.output\n        assert \"run-inference\" in result.output\n        assert \"wizard\" in result.output\n\n    def test_list_clips_help(self):\n        result = runner.invoke(app, [\"list-clips\", \"--help\"])\n        assert result.exit_code == 0\n\n    def test_generate_alphas_help(self):\n        result = runner.invoke(app, [\"generate-alphas\", \"--help\"])\n        assert result.exit_code == 0\n\n    def test_run_inference_help(self):\n        result = runner.invoke(app, [\"run-inference\", \"--help\"])\n        assert result.exit_code == 0\n\n    def test_wizard_help(self):\n        result = runner.invoke(app, [\"wizard\", \"--help\"])\n        assert result.exit_code == 0\n\n\n# ---------------------------------------------------------------------------\n# Invalid arguments\n# ---------------------------------------------------------------------------\n\n\nclass TestInvalidArgs:\n    def test_wizard_requires_path(self):\n        result = runner.invoke(app, [\"wizard\"])\n        assert result.exit_code != 0\n\n    def test_unknown_subcommand(self):\n        result = runner.invoke(app, [\"nonexistent\"])\n        assert result.exit_code != 0\n\n\n# 
---------------------------------------------------------------------------\n# InferenceSettings defaults\n# ---------------------------------------------------------------------------\n\n\nclass TestInferenceSettings:\n    def test_defaults(self):\n        s = InferenceSettings()\n        assert s.input_is_linear is False\n        assert s.despill_strength == 0.5\n        assert s.auto_despeckle is True\n        assert s.despeckle_size == 400\n        assert s.refiner_scale == 1.0\n\n    def test_custom_values(self):\n        s = InferenceSettings(\n            input_is_linear=True,\n            despill_strength=0.8,\n            auto_despeckle=False,\n            despeckle_size=200,\n            refiner_scale=1.5,\n        )\n        assert s.input_is_linear is True\n        assert s.despill_strength == 0.8\n        assert s.auto_despeckle is False\n        assert s.despeckle_size == 200\n        assert s.refiner_scale == 1.5\n\n\n# ---------------------------------------------------------------------------\n# Callback protocol\n# ---------------------------------------------------------------------------\n\n\nclass TestCallbackProtocol:\n    @patch(\"corridorkey_cli.scan_clips\")\n    @patch(\"corridorkey_cli.run_inference\")\n    @patch(\"corridorkey_cli._prompt_inference_settings\")\n    def test_run_inference_passes_callbacks(self, mock_prompt, mock_run, mock_scan):\n        \"\"\"run-inference subcommand passes on_clip_start and on_frame_complete.\"\"\"\n        mock_scan.return_value = []\n        mock_prompt.return_value = InferenceSettings()\n\n        result = runner.invoke(app, [\"run-inference\"])\n        assert result.exit_code == 0\n\n        mock_run.assert_called_once()\n        _, kwargs = mock_run.call_args\n        assert \"on_clip_start\" in kwargs\n        assert \"on_frame_complete\" in kwargs\n        assert callable(kwargs[\"on_clip_start\"])\n        assert callable(kwargs[\"on_frame_complete\"])\n\n    def 
test_callback_signatures(self):\n        \"\"\"Callbacks accept the documented (name, count) / (idx, total) args.\"\"\"\n        from corridorkey_cli import ProgressContext\n\n        ctx = ProgressContext()\n        ctx.__enter__()\n        try:\n            # Should not raise\n            ctx.on_clip_start(\"test_clip\", 100)\n            ctx.on_frame_complete(0, 100)\n            ctx.on_frame_complete(99, 100)\n        finally:\n            ctx.__exit__(None, None, None)\n\n\n# ---------------------------------------------------------------------------\n# list-clips subcommand\n# ---------------------------------------------------------------------------\n\n\nclass TestListClips:\n    @patch(\"corridorkey_cli.scan_clips\")\n    def test_list_clips_calls_scan(self, mock_scan):\n        mock_scan.return_value = []\n        result = runner.invoke(app, [\"list-clips\"])\n        assert result.exit_code == 0\n        mock_scan.assert_called_once()\n\n\n# ---------------------------------------------------------------------------\n# Non-interactive flags for run-inference\n# ---------------------------------------------------------------------------\n\n\nclass TestNonInteractiveFlags:\n    @patch(\"corridorkey_cli.scan_clips\")\n    @patch(\"corridorkey_cli.run_inference\")\n    def test_all_flags_skips_prompts(self, mock_run, mock_scan):\n        \"\"\"When all settings flags are provided, no interactive prompts fire.\"\"\"\n        mock_scan.return_value = []\n\n        result = runner.invoke(\n            app,\n            [\n                \"run-inference\",\n                \"--linear\",\n                \"--despill\",\n                \"7\",\n                \"--despeckle\",\n                \"--despeckle-size\",\n                \"200\",\n                \"--refiner\",\n                \"1.5\",\n            ],\n        )\n        assert result.exit_code == 0\n\n        mock_run.assert_called_once()\n        _, kwargs = mock_run.call_args\n        settings = 
kwargs[\"settings\"]\n        assert settings.input_is_linear is True\n        assert settings.despill_strength == 0.7\n        assert settings.auto_despeckle is True\n        assert settings.despeckle_size == 200\n        assert settings.refiner_scale == 1.5\n\n    @patch(\"corridorkey_cli.scan_clips\")\n    @patch(\"corridorkey_cli.run_inference\")\n    def test_srgb_flag(self, mock_run, mock_scan):\n        \"\"\"--srgb sets input_is_linear=False.\"\"\"\n        mock_scan.return_value = []\n\n        result = runner.invoke(\n            app,\n            [\n                \"run-inference\",\n                \"--srgb\",\n                \"--despill\",\n                \"5\",\n                \"--no-despeckle\",\n                \"--refiner\",\n                \"1.0\",\n            ],\n        )\n        assert result.exit_code == 0\n\n        mock_run.assert_called_once()\n        _, kwargs = mock_run.call_args\n        settings = kwargs[\"settings\"]\n        assert settings.input_is_linear is False\n        assert settings.auto_despeckle is False\n\n    @patch(\"corridorkey_cli.scan_clips\")\n    @patch(\"corridorkey_cli.run_inference\")\n    def test_despill_clamped_to_range(self, mock_run, mock_scan):\n        \"\"\"Despill values outside 0-10 are clamped.\"\"\"\n        mock_scan.return_value = []\n\n        result = runner.invoke(\n            app,\n            [\n                \"run-inference\",\n                \"--srgb\",\n                \"--despill\",\n                \"15\",\n                \"--despeckle\",\n                \"--refiner\",\n                \"1.0\",\n            ],\n        )\n        assert result.exit_code == 0\n\n        mock_run.assert_called_once()\n        _, kwargs = mock_run.call_args\n        settings = kwargs[\"settings\"]\n        assert settings.despill_strength == 1.0  # clamped 15→10, then /10\n\n    def test_run_inference_help_shows_flags(self):\n        \"\"\"run-inference --help lists the settings flags.\"\"\"\n     
   result = runner.invoke(app, [\"run-inference\", \"--help\"])\n        assert result.exit_code == 0\n        plain = ANSI_ESCAPE.sub(\"\", result.output)\n        assert \"--despill\" in plain\n        assert \"--linear\" in plain\n        assert \"--refiner\" in plain\n        assert \"--despeckle-size\" in plain\n        assert \"--skip-existing\" in plain\n\n    @patch(\"corridorkey_cli.scan_clips\")\n    @patch(\"corridorkey_cli.run_inference\")\n    def test_skip_existing_passed_through(self, mock_run, mock_scan):\n        \"\"\"--skip-existing is forwarded to run_inference as skip_existing kwarg.\"\"\"\n        mock_scan.return_value = []\n        result = runner.invoke(\n            app,\n            [\"run-inference\", \"--skip-existing\", \"--srgb\", \"--despill\", \"5\", \"--despeckle\", \"--refiner\", \"1.0\"],\n        )\n        assert result.exit_code == 0\n        mock_run.assert_called_once()\n        _, kwargs = mock_run.call_args\n        assert kwargs[\"skip_existing\"] is True\n"
  },
  {
    "path": "tests/test_clip_manager.py",
    "content": "\"\"\"Tests for clip_manager.py utility functions and ClipEntry discovery.\n\nThese tests verify the non-interactive parts of clip_manager: file type\ndetection, Windows→Linux path mapping, and the ClipEntry asset discovery\nthat scans directory trees to find Input/AlphaHint pairs.\n\nNo GPU, model weights, or interactive input required.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nfrom unittest.mock import MagicMock, patch\n\nimport cv2\nimport numpy as np\nimport pytest\n\nfrom clip_manager import (\n    ClipAsset,\n    ClipEntry,\n    generate_alphas,\n    is_image_file,\n    is_video_file,\n    map_path,\n    organize_clips,\n    organize_target,\n    run_videomama,\n    scan_clips,\n)\n\n# ---------------------------------------------------------------------------\n# is_image_file / is_video_file\n# ---------------------------------------------------------------------------\n\n\nclass TestFileTypeDetection:\n    \"\"\"Verify extension-based file type helpers.\n\n    These are used everywhere in clip_manager to decide how to read inputs.\n    A missed extension means a valid frame silently disappears from the batch.\n    \"\"\"\n\n    @pytest.mark.parametrize(\n        \"filename\",\n        [\n            \"frame.png\",\n            \"SHOT_001.EXR\",\n            \"plate.jpg\",\n            \"ref.JPEG\",\n            \"scan.tif\",\n            \"deep.tiff\",\n            \"comp.bmp\",\n        ],\n    )\n    def test_image_extensions_recognized(self, filename):\n        assert is_image_file(filename)\n\n    @pytest.mark.parametrize(\n        \"filename\",\n        [\n            \"frame.mp4\",\n            \"CLIP.MOV\",\n            \"take.avi\",\n            \"rushes.mkv\",\n        ],\n    )\n    def test_video_extensions_recognized(self, filename):\n        assert is_video_file(filename)\n\n    @pytest.mark.parametrize(\n        \"filename\",\n        [\n            \"readme.txt\",\n            \"notes.pdf\",\n            
\"project.nk\",\n            \"scene.blend\",\n            \".DS_Store\",\n        ],\n    )\n    def test_non_media_rejected(self, filename):\n        assert not is_image_file(filename)\n        assert not is_video_file(filename)\n\n    def test_image_is_not_video(self):\n        \"\"\"Image and video extensions must not overlap.\"\"\"\n        assert not is_video_file(\"frame.png\")\n        assert not is_video_file(\"plate.exr\")\n\n    def test_video_is_not_image(self):\n        assert not is_image_file(\"clip.mp4\")\n        assert not is_image_file(\"rushes.mov\")\n\n\n# ---------------------------------------------------------------------------\n# map_path\n# ---------------------------------------------------------------------------\n\n\nclass TestMapPath:\n    r\"\"\"Windows→Linux path mapping.\n\n    The tool is designed for studios running a Linux render farm with\n    Windows workstations.  V:\\ maps to /mnt/ssd-storage.\n    \"\"\"\n\n    def test_basic_mapping(self):\n        result = map_path(r\"V:\\Projects\\Shot1\")\n        assert result == \"/mnt/ssd-storage/Projects/Shot1\"\n\n    def test_case_insensitive_drive_letter(self):\n        result = map_path(r\"v:\\projects\\shot1\")\n        assert result == \"/mnt/ssd-storage/projects/shot1\"\n\n    def test_trailing_whitespace_stripped(self):\n        result = map_path(r\"  V:\\Projects\\Shot1  \")\n        assert result == \"/mnt/ssd-storage/Projects/Shot1\"\n\n    def test_backslashes_converted(self):\n        result = map_path(r\"V:\\Deep\\Nested\\Path\\Here\")\n        assert \"\\\\\" not in result\n\n    def test_non_v_drive_passthrough(self):\n        \"\"\"Paths not on V: are returned as-is (may already be Linux paths).\"\"\"\n        linux_path = \"/mnt/other/data\"\n        assert map_path(linux_path) == linux_path\n\n    def test_drive_root_only(self):\n        result = map_path(\"V:\\\\\")\n        assert result == \"/mnt/ssd-storage/\"\n\n\n# 
---------------------------------------------------------------------------\n# ClipAsset\n# ---------------------------------------------------------------------------\n\n\nclass TestClipAsset:\n    \"\"\"ClipAsset wraps a directory of images or a video file and counts frames.\"\"\"\n\n    def test_sequence_frame_count(self, tmp_path):\n        \"\"\"Image sequence: frame count = number of image files in directory.\"\"\"\n        seq_dir = tmp_path / \"Input\"\n        seq_dir.mkdir()\n        tiny = np.zeros((4, 4, 3), dtype=np.uint8)\n        for i in range(5):\n            cv2.imwrite(str(seq_dir / f\"frame_{i:04d}.png\"), tiny)\n\n        asset = ClipAsset(str(seq_dir), \"sequence\")\n        assert asset.frame_count == 5\n\n    def test_sequence_ignores_non_image_files(self, tmp_path):\n        \"\"\"Non-image files (thumbs.db, .nk, etc.) should not be counted.\"\"\"\n        seq_dir = tmp_path / \"Input\"\n        seq_dir.mkdir()\n        tiny = np.zeros((4, 4, 3), dtype=np.uint8)\n        cv2.imwrite(str(seq_dir / \"frame_0000.png\"), tiny)\n        (seq_dir / \"thumbs.db\").write_text(\"junk\")\n        (seq_dir / \"notes.txt\").write_text(\"notes\")\n\n        asset = ClipAsset(str(seq_dir), \"sequence\")\n        assert asset.frame_count == 1\n\n    def test_empty_sequence(self, tmp_path):\n        \"\"\"Empty directory → 0 frames.\"\"\"\n        seq_dir = tmp_path / \"Input\"\n        seq_dir.mkdir()\n        asset = ClipAsset(str(seq_dir), \"sequence\")\n        assert asset.frame_count == 0\n\n\n# ---------------------------------------------------------------------------\n# ClipEntry.find_assets\n# ---------------------------------------------------------------------------\n\n\nclass TestClipEntryFindAssets:\n    \"\"\"ClipEntry.find_assets() discovers Input and AlphaHint from a shot directory.\n\n    This is the core discovery logic that decides what's ready for inference\n    vs. 
what still needs alpha generation.\n    \"\"\"\n\n    def test_finds_image_sequence_input(self, tmp_clip_dir):\n        \"\"\"shot_a has Input/ with 2 PNGs → input_asset is a sequence.\"\"\"\n        entry = ClipEntry(\"shot_a\", str(tmp_clip_dir / \"shot_a\"))\n        entry.find_assets()\n        assert entry.input_asset is not None\n        assert entry.input_asset.type == \"sequence\"\n        assert entry.input_asset.frame_count == 2\n\n    def test_finds_alpha_hint(self, tmp_clip_dir):\n        \"\"\"shot_a has AlphaHint/ with 2 PNGs → alpha_asset is populated.\"\"\"\n        entry = ClipEntry(\"shot_a\", str(tmp_clip_dir / \"shot_a\"))\n        entry.find_assets()\n        assert entry.alpha_asset is not None\n        assert entry.alpha_asset.type == \"sequence\"\n        assert entry.alpha_asset.frame_count == 2\n\n    def test_empty_alpha_hint_is_none(self, tmp_clip_dir):\n        \"\"\"shot_b has empty AlphaHint/ → alpha_asset is None (needs generation).\"\"\"\n        entry = ClipEntry(\"shot_b\", str(tmp_clip_dir / \"shot_b\"))\n        entry.find_assets()\n        assert entry.input_asset is not None\n        assert entry.alpha_asset is None\n\n    def test_missing_input_raises(self, tmp_path):\n        \"\"\"A shot with no Input directory or video raises ValueError.\"\"\"\n        empty_shot = tmp_path / \"empty_shot\"\n        empty_shot.mkdir()\n        entry = ClipEntry(\"empty_shot\", str(empty_shot))\n        with pytest.raises(ValueError, match=\"No 'Input' directory or video file found\"):\n            entry.find_assets()\n\n    def test_empty_input_dir_raises(self, tmp_path):\n        \"\"\"An empty Input/ directory raises ValueError.\"\"\"\n        shot = tmp_path / \"bad_shot\"\n        (shot / \"Input\").mkdir(parents=True)\n        entry = ClipEntry(\"bad_shot\", str(shot))\n        with pytest.raises(ValueError, match=\"'Input' directory is empty\"):\n            entry.find_assets()\n\n    def test_validate_pair_frame_count_mismatch(self, 
tmp_path):\n        \"\"\"Mismatched Input/AlphaHint frame counts raise ValueError.\"\"\"\n        shot = tmp_path / \"mismatch\"\n        (shot / \"Input\").mkdir(parents=True)\n        (shot / \"AlphaHint\").mkdir(parents=True)\n\n        tiny = np.zeros((4, 4, 3), dtype=np.uint8)\n        tiny_mask = np.zeros((4, 4), dtype=np.uint8)\n\n        # 3 input frames, 2 alpha frames\n        for i in range(3):\n            cv2.imwrite(str(shot / \"Input\" / f\"frame_{i:04d}.png\"), tiny)\n        for i in range(2):\n            cv2.imwrite(str(shot / \"AlphaHint\" / f\"frame_{i:04d}.png\"), tiny_mask)\n\n        entry = ClipEntry(\"mismatch\", str(shot))\n        entry.find_assets()\n        with pytest.raises(ValueError, match=\"Frame count mismatch\"):\n            entry.validate_pair()\n\n    def test_validate_pair_matching_counts_ok(self, tmp_clip_dir):\n        \"\"\"Matching frame counts pass validation without error.\"\"\"\n        entry = ClipEntry(\"shot_a\", str(tmp_clip_dir / \"shot_a\"))\n        entry.find_assets()\n        entry.validate_pair()  # should not raise\n\n\n# ---------------------------------------------------------------------------\n# generate_alphas\n# ---------------------------------------------------------------------------\n\n\nclass TestGenerateAlphas:\n    \"\"\"\n    Tests for the generate_alphas orchestrator.\n    Focuses on GVM integration, directory cleanup, and filename remapping.\n    \"\"\"\n\n    def test_all_clips_valid_skips_generation(self, caplog):\n        \"\"\"\n        Scenario: Every provided clip already has a valid alpha_asset.\n        Expected: Logs that generation is unnecessary and returns without invoking GVM.\n        \"\"\"\n        caplog.set_level(\"INFO\")\n        clip = ClipEntry(\"shot_a\", \"/tmp/shot_a\")\n        clip.alpha_asset = MagicMock()\n\n        generate_alphas([clip])\n\n        assert \"All clips have valid Alpha assets\" in caplog.text\n\n    @patch(\"clip_manager.get_gvm_processor\")\n   
 def test_gvm_missing_exits_gracefully(self, mock_get_processor, caplog):\n        \"\"\"\n        Scenario: GVM requirements are missing (ImportError) during initialization.\n        Expected: Logs a specific GVM Import Error and exits early without a crash.\n        \"\"\"\n        mock_get_processor.side_effect = ImportError(\"No module named 'gvm'\")\n\n        clip = ClipEntry(\"shot_a\", \"/tmp/shot_a\")\n        clip.alpha_asset = None\n\n        generate_alphas([clip])\n\n        assert \"GVM Import Error\" in caplog.text\n        assert \"Skipping GVM generation\" in caplog.text\n\n    @patch(\"clip_manager.get_gvm_processor\")\n    def test_existing_alpha_dir_is_cleaned(self, _mock_gvm, tmp_path):\n        \"\"\"\n        Scenario: A legacy AlphaHint folder exists from a previous failed run.\n        Expected: Deletes the existing directory physically before creating a fresh one.\n        \"\"\"\n        shot_dir = tmp_path / \"shot_a\"\n        shot_dir.mkdir()\n        alpha_dir = shot_dir / \"AlphaHint\"\n        alpha_dir.mkdir()\n        (alpha_dir / \"old_file.png\").write_text(\"junk\")\n\n        clip = ClipEntry(\"shot_a\", str(shot_dir))\n        clip.alpha_asset = None\n        clip.input_asset = ClipAsset(str(shot_dir / \"in.mp4\"), \"video\")\n\n        generate_alphas([clip])\n\n        assert alpha_dir.exists()\n        assert not (alpha_dir / \"old_file.png\").exists()\n\n    @patch(\"clip_manager.get_gvm_processor\")\n    def test_naming_remap_sequence(self, mock_get_processor, tmp_path):\n        \"\"\"\n        Scenario: Input is a sequence; GVM 'processor' is called with Path objects.\n        Expected: Mock processor creates a file, and the renamer finds it in the AlphaHint dir.\n        \"\"\"\n        shot_dir = tmp_path / \"shot_01\"\n        shot_dir.mkdir()\n        input_dir = shot_dir / \"Input\"\n        input_dir.mkdir()\n        alpha_dir = shot_dir / \"AlphaHint\"\n\n        (input_dir / 
\"frame_A.png\").write_text(\"fake_png\")\n\n        clip = ClipEntry(\"shot_01\", str(shot_dir))\n        clip.input_asset = ClipAsset(path=str(input_dir), asset_type=\"sequence\")\n\n        mock_processor = MagicMock()\n        mock_get_processor.return_value = mock_processor\n\n        def side_effect_create_file(*args, **kwargs):\n            from pathlib import Path\n\n            target = Path(kwargs.get(\"direct_output_dir\"))\n            target.mkdir(parents=True, exist_ok=True)\n            (target / \"output_0.png\").write_text(\"mask\")\n\n        mock_processor.process_sequence.side_effect = side_effect_create_file\n\n        generate_alphas([clip])\n\n        expected_name = \"frame_A_alphaHint_0000.png\"\n        assert (alpha_dir / expected_name).exists()\n\n    @patch(\"clip_manager.get_gvm_processor\")\n    def test_naming_remap_video(self, mock_get_processor, tmp_path):\n        \"\"\"\n        Scenario: Input is a video file; stem 'my_clip' is used for remapping.\n        Expected: Generic GVM output is renamed to 'my_clip_alphaHint_0000.png'.\n        \"\"\"\n        shot_dir = tmp_path / \"shot_01\"\n        shot_dir.mkdir()\n        alpha_dir = shot_dir / \"AlphaHint\"\n\n        video_path = shot_dir / \"my_clip.mp4\"\n        video_path.write_text(\"headers\")\n\n        clip = ClipEntry(\"shot_01\", str(shot_dir))\n        clip.input_asset = ClipAsset(path=str(video_path), asset_type=\"video\")\n\n        mock_processor = MagicMock()\n        mock_get_processor.return_value = mock_processor\n\n        def side_effect_create_file(*args, **kwargs):\n            from pathlib import Path\n\n            target = Path(kwargs.get(\"direct_output_dir\"))\n            target.mkdir(parents=True, exist_ok=True)\n            (target / \"frame_0.png\").write_text(\"mask\")\n\n        mock_processor.process_sequence.side_effect = side_effect_create_file\n\n        generate_alphas([clip])\n\n        assert (alpha_dir / 
\"my_clip_alphaHint_0000.png\").exists()\n\n    @patch(\"clip_manager.get_gvm_processor\")\n    def test_empty_output_logs_error(self, mock_get_processor, tmp_path, caplog):\n        \"\"\"\n        Scenario: GVM finishes (mocked) but the output directory is physically empty.\n        Expected: The runner logs that no PNGs were found in AlphaHint.\n        \"\"\"\n        caplog.set_level(\"ERROR\")\n\n        shot_dir = tmp_path / \"shot_a\"\n        shot_dir.mkdir()\n        (shot_dir / \"AlphaHint\").mkdir()\n\n        clip = ClipEntry(\"shot_a\", str(shot_dir))\n        clip.input_asset = ClipAsset(str(shot_dir / \"in.mp4\"), \"video\")\n\n        mock_processor = MagicMock()\n        mock_get_processor.return_value = mock_processor\n\n        generate_alphas([clip])\n\n        assert \"no pngs found\" in caplog.text.lower()\n\n\n# ---------------------------------------------------------------------------\n# run_videomama\n# ---------------------------------------------------------------------------\n\n\nclass TestVideoMaMa:\n    def test_videomama_skips_if_sequence_exists(self, stage_shot, caplog):\n        \"\"\"\n        Scenario: A clip already has a populated AlphaHint directory.\n        Expected: run_videomama identifies no candidates and skips processing.\n        \"\"\"\n        caplog.set_level(\"INFO\")\n        path = stage_shot(\"shot_exists\", create_alpha=True)\n        mask_path = path / \"VideoMamaMaskHint\"\n        if mask_path.exists():\n            import shutil\n\n            shutil.rmtree(mask_path)\n\n        clip = ClipEntry(\"shot_exists\", str(path))\n        clip.find_assets()\n\n        run_videomama([clip])\n\n        assert \"No candidates for VideoMaMa\" in caplog.text\n\n    def test_videomama_processes_valid_candidate(self, stage_shot):\n        \"\"\"\n        Scenario: A clip has Input and VideoMamaMaskHint but no AlphaHint.\n        Expected: AlphaHint is created and populated with generated frames.\n        \"\"\"\n        
path = stage_shot(\"shot_valid\")\n        clip = ClipEntry(\"shot_valid\", str(path))\n        clip.find_assets()\n        run_videomama([clip])\n        alpha_dir = os.path.join(str(path), \"AlphaHint\")\n        assert os.path.isdir(alpha_dir)\n        assert len(os.listdir(alpha_dir)) > 0\n\n    def test_videomama_skips_if_input_missing(self, tmp_path):\n        \"\"\"\n        Scenario: A clip directory is missing the Input folder.\n        Expected: ClipEntry raises ValueError during discovery.\n        \"\"\"\n        path = tmp_path / \"shot_no_input\"\n        path.mkdir()\n        clip = ClipEntry(\"shot_no_input\", str(path))\n        with pytest.raises(ValueError, match=\"No 'Input' directory\"):\n            clip.find_assets()\n\n    def test_videomama_skips_if_mask_missing(self, stage_shot, caplog):\n        \"\"\"\n        Scenario: A clip is missing all valid VideoMamaMaskHint variants.\n        Expected: run_videomama skips the clip.\n        \"\"\"\n        caplog.set_level(\"INFO\")\n        path = stage_shot(\"shot_no_mask\")\n        for d in [\"VideoMamaMaskHint\", \"videomamamaskhint\", \"VIDEOMAMAMASKHINT\"]:\n            mask_path = path / d\n            if mask_path.exists():\n                import shutil\n\n                shutil.rmtree(mask_path)\n\n        clip = ClipEntry(\"shot_no_mask\", str(path))\n        clip.find_assets()\n        run_videomama([clip])\n        assert \"No candidates for VideoMaMa\" in caplog.text\n\n    def test_videomama_mask_thresholding(self, stage_shot):\n        \"\"\"\n        Scenario: VideoMaMaMaskHint contains soft grayscale values.\n        Expected: Input masks are binarized before being passed to the model.\n        \"\"\"\n        path = stage_shot(\"shot_threshold\")\n        mask_path = path / \"VideoMamaMaskHint\" / \"mask_0000.png\"\n        soft_mask = np.full((4, 4), 128, dtype=np.uint8)\n        cv2.imwrite(str(mask_path), soft_mask)\n        clip = ClipEntry(\"shot_threshold\", str(path))\n 
       clip.find_assets()\n        run_videomama([clip])\n        assert os.path.isdir(os.path.join(str(path), \"AlphaHint\"))\n\n    def test_videomama_rgba_to_rgb_conversion(self, stage_shot):\n        \"\"\"\n        Scenario: Input directory contains 4-channel RGBA images.\n        Expected: Images are converted to 3-channel RGB without crashing.\n        \"\"\"\n        path = stage_shot(\"shot_rgba\")\n        in_file = path / \"Input\" / \"frame_0000.png\"\n        rgba = np.zeros((4, 4, 4), dtype=np.uint8)\n        cv2.imwrite(str(in_file), rgba)\n        clip = ClipEntry(\"shot_rgba\", str(path))\n        clip.find_assets()\n        run_videomama([clip])\n        assert os.path.exists(os.path.join(str(path), \"AlphaHint\"))\n\n    def test_videomama_exr_gamma_handling(self, stage_shot):\n        \"\"\"\n        Scenario: Input directory contains Linear EXR files.\n        Expected: Data is normalized and handled as linear float32.\n        \"\"\"\n        path = stage_shot(\"shot_exr\")\n        in_dir = path / \"Input\"\n        exr_file = str(in_dir / \"frame_0000.exr\")\n        img = np.zeros((4, 4, 3), dtype=np.float32)\n        cv2.imwrite(exr_file, img)\n        clip = ClipEntry(\"shot_exr\", str(path))\n        clip.find_assets()\n        run_videomama([clip])\n        assert os.path.exists(os.path.join(str(path), \"AlphaHint\"))\n\n    def test_safety_removes_file_blocking_dir(self, stage_shot):\n        \"\"\"\n        Scenario: A file exists where the AlphaHint directory needs to be created.\n        Expected: The blocking file is removed and replaced by a directory.\n        \"\"\"\n        path = stage_shot(\"shot_blocker\")\n        blocker = path / \"AlphaHint\"\n        blocker.write_text(\"i am a file\")\n        clip = ClipEntry(\"shot_blocker\", str(path))\n        clip.find_assets()\n        run_videomama([clip])\n        assert blocker.is_dir()\n        assert len(os.listdir(blocker)) > 0\n\n    def 
test_videomama_multiple_clips_batch(self, stage_shot):\n        \"\"\"\n        Scenario: Multiple valid clip candidates are passed to the runner.\n        Expected: All candidates are processed and receive generated AlphaHints.\n        \"\"\"\n        path_1 = stage_shot(\"shot_1\")\n        path_2 = stage_shot(\"shot_2\")\n        c1 = ClipEntry(\"shot_1\", str(path_1))\n        c2 = ClipEntry(\"shot_2\", str(path_2))\n        c1.find_assets()\n        c2.find_assets()\n        run_videomama([c1, c2])\n        assert os.path.isdir(os.path.join(str(path_1), \"AlphaHint\"))\n        assert os.path.isdir(os.path.join(str(path_2), \"AlphaHint\"))\n\n    def test_videomama_upgrades_video_alpha(self, stage_shot):\n        \"\"\"\n        Scenario: A clip uses a video file as input rather than a sequence.\n        Expected: VideoMaMa processes the video and outputs an image sequence alpha.\n        \"\"\"\n        path = stage_shot(\"shot_video\")\n        clip = ClipEntry(\"shot_video\", str(path))\n        clip.find_assets()\n        run_videomama([clip])\n        assert os.path.isdir(os.path.join(str(path), \"AlphaHint\"))\n\n    def test_videomama_handles_invalid_image_load(self, stage_shot, caplog):\n        \"\"\"\n        Scenario: The runner attempts to load a non-image file.\n        Expected: The failure is logged.\n        \"\"\"\n        caplog.set_level(\"INFO\")\n        path = stage_shot(\"shot_corrupt\")\n        corrupt = path / \"Input\" / \"frame_0000.png\"\n        corrupt.write_text(\"not an image\")\n\n        clip = ClipEntry(\"shot_corrupt\", str(path))\n        clip.find_assets()\n\n        with patch(\"clip_manager.run_inference\") as mock_run:\n            mock_run.side_effect = Exception(\"corrupt image data\")\n            try:\n                run_videomama([clip])\n            except Exception:\n                pass\n\n        assert any(x in caplog.text.lower() for x in [\"error\", \"fail\", \"corrupt\"])\n\n    def 
test_videomama_priority_folder_over_video(self, stage_shot):\n        \"\"\"\n        Scenario: Both a video file and an Input directory exist in the shot.\n        Expected: The Input directory takes priority for processing.\n        \"\"\"\n        path = stage_shot(\"shot_priority\")\n        (path / \"input_video.mp4\").write_text(\"dummy\")\n        clip = ClipEntry(\"shot_priority\", str(path))\n        clip.find_assets()\n        assert clip.input_asset.type == \"sequence\"\n        run_videomama([clip])\n        assert os.path.isdir(os.path.join(str(path), \"AlphaHint\"))\n\n    def test_loop_chunking_logic(self, tmp_path):\n        \"\"\"\n        Scenario: A 12-frame sequence is processed.\n        Expected: All 12 frames are saved to AlphaHint.\n        \"\"\"\n        path = tmp_path / \"shot_large\"\n        in_dir = path / \"Input\"\n        mask_dir = path / \"VideoMamaMaskHint\"\n        in_dir.mkdir(parents=True)\n        mask_dir.mkdir(parents=True)\n\n        for i in range(12):\n            cv2.imwrite(str(in_dir / f\"frame_{i:04d}.png\"), np.zeros((4, 4, 3), np.uint8))\n            cv2.imwrite(str(mask_dir / f\"mask_{i:04d}.png\"), np.zeros((4, 4), np.uint8))\n\n        clip = ClipEntry(\"shot_large\", str(path))\n        clip.find_assets()\n\n        run_videomama([clip])\n\n        alpha_dir = os.path.join(str(path), \"AlphaHint\")\n        files = [f for f in os.listdir(alpha_dir) if f.endswith(\".png\")]\n        assert len(files) == 12\n\n    def test_videomama_mask_from_video(self, stage_shot):\n        \"\"\"\n        Scenario: The mask hint is provided as a video file instead of a sequence.\n        Expected: The runner extracts frames from the mask video to guide inference.\n        \"\"\"\n        path = stage_shot(\"shot_mask_vid\")\n        clip = ClipEntry(\"shot_mask_vid\", str(path))\n        clip.find_assets()\n        run_videomama([clip])\n        assert os.path.isdir(os.path.join(str(path), \"AlphaHint\"))\n\n    def 
test_videomama_cleanup_on_failure(self, stage_shot, caplog):\n        \"\"\"\n        Scenario: An error occurs during the inference loop.\n        Expected: The error is caught by the runner's try/except and logged.\n        Note: Currently, load_videomama_model is outside the main loop's\n        try/except, so it raises directly to the caller. This should be fixed!\n        \"\"\"\n        import sys\n\n        caplog.set_level(\"ERROR\")\n        path = stage_shot(\"shot_fail\")\n        (path / \"Input\").mkdir(parents=True, exist_ok=True)\n        (path / \"VideoMamaMaskHint\").mkdir(parents=True, exist_ok=True)\n\n        clip = ClipEntry(\"shot_fail\", str(path))\n        clip.find_assets()\n\n        mock_inference_mod = MagicMock()\n        mock_inference_mod.load_videomama_model.side_effect = RuntimeError(\"GPU OOM\")\n\n        sys.modules[\"VideoMaMaInferenceModule.inference\"] = mock_inference_mod\n\n        try:\n            with pytest.raises(RuntimeError, match=\"GPU OOM\"):\n                run_videomama([clip])\n        finally:\n            if \"VideoMaMaInferenceModule.inference\" in sys.modules:\n                del sys.modules[\"VideoMaMaInferenceModule.inference\"]\n\n\n# ---------------------------------------------------------------------------\n# organize_target\n# ---------------------------------------------------------------------------\n\n\nclass TestOrganizeTarget:\n    \"\"\"organize_target() sets up the hint directory structure for a shot.\n\n    It creates AlphaHint/ and VideoMamaMaskHint/ directories if missing.\n    \"\"\"\n\n    def test_creates_hint_directories(self, tmp_path):\n        \"\"\"Missing hint directories should be created.\"\"\"\n        shot = tmp_path / \"shot_x\"\n        (shot / \"Input\").mkdir(parents=True)\n        tiny = np.zeros((4, 4, 3), dtype=np.uint8)\n        cv2.imwrite(str(shot / \"Input\" / \"frame_0000.png\"), tiny)\n\n        organize_target(str(shot))\n\n        assert (shot / 
\"AlphaHint\").is_dir()\n        assert (shot / \"VideoMamaMaskHint\").is_dir()\n\n    def test_existing_hint_dirs_preserved(self, tmp_clip_dir):\n        \"\"\"Existing hint directories and their contents are not disturbed.\"\"\"\n        shot_a = tmp_clip_dir / \"shot_a\"\n        alpha_files_before = sorted(os.listdir(shot_a / \"AlphaHint\"))\n\n        organize_target(str(shot_a))\n\n        alpha_files_after = sorted(os.listdir(shot_a / \"AlphaHint\"))\n        assert alpha_files_before == alpha_files_after\n\n    def test_moves_loose_images_to_input(self, tmp_path):\n        \"\"\"Loose image files in a shot dir get moved into Input/.\"\"\"\n        shot = tmp_path / \"messy_shot\"\n        shot.mkdir()\n        tiny = np.zeros((4, 4, 3), dtype=np.uint8)\n        cv2.imwrite(str(shot / \"frame_0000.png\"), tiny)\n        cv2.imwrite(str(shot / \"frame_0001.png\"), tiny)\n\n        organize_target(str(shot))\n\n        assert (shot / \"Input\").is_dir()\n        input_files = os.listdir(shot / \"Input\")\n        assert len(input_files) == 2\n        # Original loose files should be gone\n        assert not (shot / \"frame_0000.png\").exists()\n\n\n# ---------------------------------------------------------------------------\n# organize_clips\n# ---------------------------------------------------------------------------\n\n\nclass TestOrganizeClips:\n    \"\"\"\n    Tests for the legacy wrapper that organizes the main /Clips directory.\n    \"\"\"\n\n    def test_organize_loose_video_file(self, tmp_path):\n        \"\"\"\n        Tests that a loose .mp4 file is moved into its own folder.\n\n        Scenario: A directory contains a loose video file like 'shot_001.mp4'.\n        Expected: A new folder 'shot_001' is created, containing 'Input.mp4' and an empty 'AlphaHint' directory.\n        \"\"\"\n        clips_dir = tmp_path / \"ClipsForInference\"\n        clips_dir.mkdir()\n\n        video_file = clips_dir / \"shot_001.mp4\"\n        
video_file.write_text(\"test_video_data\")\n\n        with patch(\"clip_manager.organize_target\") as mock_target:\n            organize_clips(str(clips_dir))\n\n        target_folder = clips_dir / \"shot_001\"\n        assert target_folder.is_dir(), f\"Folder {target_folder} was not created!\"\n        assert (target_folder / \"Input.mp4\").exists()\n        assert (target_folder / \"AlphaHint\").exists()\n\n        mock_target.assert_called_with(str(target_folder))\n\n    def test_skips_video_if_folder_exists(self, tmp_path, caplog):\n        \"\"\"\n        Tests that a video is skipped if a folder with its name already exists.\n\n        Scenario: Both 'shot_001.mp4' and a folder named 'shot_001' exist.\n        Expected: The original file is left alone, and a conflict warning is logged.\n        \"\"\"\n        clips_dir = tmp_path / \"ClipsForInference\"\n        clips_dir.mkdir()\n\n        video_path = clips_dir / \"shot_001.mp4\"\n        video_path.write_text(\"data\")\n\n        conflict_folder = clips_dir / \"shot_001\"\n        conflict_folder.mkdir()\n\n        organize_clips(str(clips_dir))\n        assert video_path.exists(), \"The video was moved even though a folder existed!\"\n        assert \"already exists\" in caplog.text\n\n    def test_ignores_protected_folders(self, tmp_path):\n        \"\"\"\n        Tests that 'Output' and 'IgnoredClips' folders are not processed.\n\n        Scenario: Directory contains a valid shot folder plus 'Output' and 'IgnoredClips'.\n        Expected: 'organize_target' is called exactly once (only for the valid shot).\n        \"\"\"\n        clips_dir = tmp_path / \"ClipsForInference\"\n        clips_dir.mkdir()\n\n        (clips_dir / \"shot_001\").mkdir()\n        (clips_dir / \"Output\").mkdir()\n        (clips_dir / \"IgnoredClips\").mkdir()\n\n        with patch(\"clip_manager.organize_target\") as mock_target:\n            organize_clips(str(clips_dir))\n\n        mock_target.assert_any_call(str(clips_dir / 
\"shot_001\"))\n\n        assert mock_target.call_count == 1, f\"Expected 1 call, but got {mock_target.call_count}\"\n\n    def test_handles_nonexistent_directory(self, caplog):\n        \"\"\"\n        Tests that the function exits gracefully if the directory is missing.\n\n        Scenario: The provided path does not exist on the filesystem.\n        Expected: Function logs a 'directory not found' warning and returns early.\n        \"\"\"\n        fake_path = \"/tmp/ghost_directory_12345\"\n\n        organize_clips(fake_path)\n\n        assert \"directory not found\" in caplog.text\n        assert fake_path in caplog.text\n\n    def test_batch_organization_mix(self, tmp_path):\n        \"\"\"\n        Tests that the function handles a mix of loose videos and folders at once.\n\n        Scenario: Directory contains one loose video and one already existing folder.\n        Expected: The video is migrated, and 'organize_target' is called for both.\n        \"\"\"\n        clips_dir = tmp_path / \"ClipsForInference\"\n        clips_dir.mkdir()\n\n        video_a = clips_dir / \"shot_A.mp4\"\n        video_a.write_text(\"video_data\")\n\n        folder_b = clips_dir / \"shot_B\"\n        folder_b.mkdir()\n\n        with patch(\"clip_manager.organize_target\") as mock_target:\n            organize_clips(str(clips_dir))\n\n        assert (clips_dir / \"shot_A\" / \"Input.mp4\").exists()\n\n        mock_target.assert_any_call(str(clips_dir / \"shot_A\"))\n        mock_target.assert_any_call(str(clips_dir / \"shot_B\"))\n        assert mock_target.call_count == 2\n\n\n# ---------------------------------------------------------------------------\n# scan_clips\n# ---------------------------------------------------------------------------\n\n\nclass TestScanClips:\n    \"\"\"\n    Tests for the scan_clips file orchestrator.\n    Ensures directory health, automatic organization, and validation reporting.\n    Added additions from #118\n    \"\"\"\n\n    def 
test_creates_clips_dir_and_returns_empty_if_missing(self, tmp_path, monkeypatch):\n        \"\"\"A missing CLIPS_DIR is created automatically and [] is returned.\"\"\"\n        import clip_manager\n\n        missing = str(tmp_path / \"ClipsForInference\")\n        monkeypatch.setattr(clip_manager, \"CLIPS_DIR\", missing)\n\n        result = scan_clips()\n\n        assert result == []\n        assert os.path.isdir(missing)\n\n    def test_returns_clips_with_valid_input(self, tmp_clip_dir, monkeypatch):\n        \"\"\"Clips whose Input directories exist are included in the result.\"\"\"\n        import clip_manager\n\n        monkeypatch.setattr(clip_manager, \"CLIPS_DIR\", str(tmp_clip_dir))\n        result = scan_clips()\n        names = {c.name for c in result}\n\n        assert \"shot_a\" in names\n        assert \"shot_b\" in names  # valid input even without alpha\n\n    def test_excludes_frame_count_mismatch(self, tmp_clip_dir, monkeypatch):\n        \"\"\"A clip with mismatched Input/AlphaHint frame counts is excluded.\"\"\"\n        import clip_manager\n\n        mismatch = tmp_clip_dir / \"mismatch_shot\"\n        (mismatch / \"Input\").mkdir(parents=True)\n        (mismatch / \"AlphaHint\").mkdir()\n        tiny = np.zeros((4, 4, 3), dtype=np.uint8)\n        tiny_mask = np.zeros((4, 4), dtype=np.uint8)\n        for i in range(3):\n            cv2.imwrite(str(mismatch / \"Input\" / f\"frame_{i:04d}.png\"), tiny)\n        cv2.imwrite(str(mismatch / \"AlphaHint\" / \"frame_0000.png\"), tiny_mask)\n\n        monkeypatch.setattr(clip_manager, \"CLIPS_DIR\", str(tmp_clip_dir))\n        result = scan_clips()\n        names = {c.name for c in result}\n\n        assert \"mismatch_shot\" not in names\n        assert \"shot_a\" in names  # valid shot still found\n\n    def test_skips_hidden_and_underscore_dirs(self, tmp_clip_dir, monkeypatch):\n        \"\"\"Directories starting with '.' 
or '_' are never returned.\"\"\"\n        import clip_manager\n\n        (tmp_clip_dir / \".hidden\").mkdir()\n        (tmp_clip_dir / \"_temp\").mkdir()\n        monkeypatch.setattr(clip_manager, \"CLIPS_DIR\", str(tmp_clip_dir))\n\n        result = scan_clips()\n        names = {c.name for c in result}\n\n        assert \".hidden\" not in names\n        assert \"_temp\" not in names\n\n    def test_noise_filter_skips_hidden_folders(self, sandbox_clip_manager):\n        \"\"\"\n        Scenario: A .git folder and a 'shot_01' folder exist.\n        Expected: .git is ignored; shot_01 is returned.\n        \"\"\"\n        (sandbox_clip_manager / \".git\").mkdir()\n        valid_shot = sandbox_clip_manager / \"shot_01\"\n        valid_shot.mkdir()\n\n        input_dir = valid_shot / \"Input\"\n        input_dir.mkdir()\n        (input_dir / \"frame_0000.png\").write_text(\"data\")\n\n        results = scan_clips()\n\n        assert len(results) == 1\n        assert results[0].name == \"shot_01\"\n\n    def test_scanner_handles_multiple_shots(self, sandbox_clip_manager):\n        \"\"\"\n        Scenario: Multiple valid shot folders.\n        Expected: 3 ClipEntry objects found, verified in alphabetical order.\n        \"\"\"\n        for name in [\"shot_C\", \"shot_B\", \"shot_A\"]:\n            d = sandbox_clip_manager / name\n            d.mkdir()\n            (d / \"Input\").mkdir()\n            (d / \"Input\" / \"f.png\").write_text(\"data\")\n\n        results = scan_clips()\n\n        assert len(results) == 3\n        names = sorted([r.name for r in results])\n        assert names == [\"shot_A\", \"shot_B\", \"shot_C\"]\n\n    def test_ideal_organization_loose_videos(self, sandbox_clip_manager):\n        \"\"\"\n        Scenario: A loose video file 'my_clip.mp4' exists.\n        Expected: Folder 'my_clip' created with 'Input.mp4' inside.\n        \"\"\"\n        video_file = sandbox_clip_manager / \"my_clip.mp4\"\n        video_file.write_text(\"content\")\n     
   expected_folder = sandbox_clip_manager / \"my_clip\"\n        organize_clips(str(sandbox_clip_manager))\n\n        assert expected_folder.is_dir()\n        assert (expected_folder / \"Input.mp4\").exists() or (expected_folder / \"Input\" / \"my_clip.mp4\").exists()\n        assert (expected_folder / \"AlphaHint\").exists()\n\n    def test_organization_skips_existing_folders(self, sandbox_clip_manager, caplog):\n        \"\"\"\n        Scenario: Both 'collision.mp4' and folder 'collision' exist.\n        Expected: Conflict warning logged, file not moved.\n        \"\"\"\n        (sandbox_clip_manager / \"collision\").mkdir()\n        video_file = sandbox_clip_manager / \"collision.mp4\"\n        video_file.write_text(\"data\")\n        organize_clips(str(sandbox_clip_manager))\n\n        assert video_file.exists()\n        assert \"already exists\" in caplog.text.lower()\n\n    def test_batch_processing_mix(self, sandbox_clip_manager):\n        \"\"\"\n        Scenario: Mix of loose files and existing folders.\n        Expected: Loose files migrated; existing folders left intact.\n        \"\"\"\n        (sandbox_clip_manager / \"existing\").mkdir()\n        video_file = sandbox_clip_manager / \"new_shot.mp4\"\n        video_file.write_text(\"data\")\n        organize_clips(str(sandbox_clip_manager))\n\n        assert (sandbox_clip_manager / \"new_shot\").is_dir()\n        assert (sandbox_clip_manager / \"existing\").is_dir()\n\n    def test_nonexistent_directory_logging(self, caplog):\n        \"\"\"\n        Scenario: Path doesn't exist.\n        Expected: 'not found' warning logged.\n        \"\"\"\n        fake_path = \"/tmp/missing_dir_999\"\n        organize_clips(fake_path)\n\n        assert \"not found\" in caplog.text.lower()\n"
  },
  {
    "path": "tests/test_color_utils.py",
    "content": "\"\"\"Unit tests for CorridorKeyModule.core.color_utils.\n\nThese tests verify the color math that underpins CorridorKey's compositing\npipeline.  Every function is tested with both numpy arrays and PyTorch tensors\nbecause color_utils supports both backends and bugs can hide in one path.\n\nNo GPU or model weights required — pure math.\n\"\"\"\n\nimport numpy as np\nimport pytest\nimport torch\n\nfrom CorridorKeyModule.core import color_utils as cu\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _to_np(x):\n    \"\"\"Ensure value is a numpy float32 array.\"\"\"\n    return np.asarray(x, dtype=np.float32)\n\n\ndef _to_torch(x):\n    \"\"\"Ensure value is a float32 torch tensor.\"\"\"\n    return torch.tensor(x, dtype=torch.float32)\n\n\n# ---------------------------------------------------------------------------\n# linear_to_srgb  /  srgb_to_linear\n# ---------------------------------------------------------------------------\n\n\nclass TestSrgbLinearConversion:\n    \"\"\"sRGB ↔ linear transfer function tests.\n\n    The piecewise sRGB spec uses exponent 2.4 (not \"gamma 2.2\").\n    Breakpoints: 0.0031308 (linear side), 0.04045 (sRGB side).\n    \"\"\"\n\n    # Known identity values: 0 → 0, 1 → 1\n    @pytest.mark.parametrize(\"value\", [0.0, 1.0])\n    def test_identity_values_numpy(self, value):\n        x = _to_np(value)\n        assert cu.linear_to_srgb(x) == pytest.approx(value, abs=1e-7)\n        assert cu.srgb_to_linear(x) == pytest.approx(value, abs=1e-7)\n\n    @pytest.mark.parametrize(\"value\", [0.0, 1.0])\n    def test_identity_values_torch(self, value):\n        x = _to_torch(value)\n        assert cu.linear_to_srgb(x).item() == pytest.approx(value, abs=1e-7)\n        assert cu.srgb_to_linear(x).item() == pytest.approx(value, abs=1e-7)\n\n    # Mid-gray: sRGB 0.5 ≈ linear 0.214\n    def 
test_mid_gray_numpy(self):\n        srgb_half = _to_np(0.5)\n        linear_val = cu.srgb_to_linear(srgb_half)\n        assert linear_val == pytest.approx(0.214, abs=0.001)\n\n    def test_mid_gray_torch(self):\n        srgb_half = _to_torch(0.5)\n        linear_val = cu.srgb_to_linear(srgb_half)\n        assert linear_val.item() == pytest.approx(0.214, abs=0.001)\n\n    # Roundtrip: linear → sRGB → linear ≈ original\n    @pytest.mark.parametrize(\"value\", [0.0, 0.001, 0.0031308, 0.05, 0.214, 0.5, 0.8, 1.0])\n    def test_roundtrip_numpy(self, value):\n        x = _to_np(value)\n        roundtripped = cu.srgb_to_linear(cu.linear_to_srgb(x))\n        assert roundtripped == pytest.approx(value, abs=1e-5)\n\n    @pytest.mark.parametrize(\"value\", [0.0, 0.001, 0.0031308, 0.05, 0.214, 0.5, 0.8, 1.0])\n    def test_roundtrip_torch(self, value):\n        x = _to_torch(value)\n        roundtripped = cu.srgb_to_linear(cu.linear_to_srgb(x))\n        assert roundtripped.item() == pytest.approx(value, abs=1e-5)\n\n    # Piecewise continuity: both branches must agree at the breakpoint\n    def test_breakpoint_continuity_linear_to_srgb(self):\n        # At linear = 0.0031308, the two branches should produce the same sRGB value\n        bp = 0.0031308\n        below = cu.linear_to_srgb(_to_np(bp - 1e-7))\n        above = cu.linear_to_srgb(_to_np(bp + 1e-7))\n        at = cu.linear_to_srgb(_to_np(bp))\n        # All three should be very close (no discontinuity)\n        assert below == pytest.approx(float(at), abs=1e-4)\n        assert above == pytest.approx(float(at), abs=1e-4)\n\n    def test_breakpoint_continuity_srgb_to_linear(self):\n        bp = 0.04045\n        below = cu.srgb_to_linear(_to_np(bp - 1e-7))\n        above = cu.srgb_to_linear(_to_np(bp + 1e-7))\n        at = cu.srgb_to_linear(_to_np(bp))\n        assert below == pytest.approx(float(at), abs=1e-4)\n        assert above == pytest.approx(float(at), abs=1e-4)\n\n    # Negative inputs should be clamped to 0\n    
def test_negative_clamped_linear_to_srgb_numpy(self):\n        result = cu.linear_to_srgb(_to_np(-0.5))\n        assert float(result) == pytest.approx(0.0, abs=1e-7)\n\n    def test_negative_clamped_linear_to_srgb_torch(self):\n        result = cu.linear_to_srgb(_to_torch(-0.5))\n        assert result.item() == pytest.approx(0.0, abs=1e-7)\n\n    def test_negative_clamped_srgb_to_linear_numpy(self):\n        result = cu.srgb_to_linear(_to_np(-0.5))\n        assert float(result) == pytest.approx(0.0, abs=1e-7)\n\n    def test_negative_clamped_srgb_to_linear_torch(self):\n        result = cu.srgb_to_linear(_to_torch(-0.5))\n        assert result.item() == pytest.approx(0.0, abs=1e-7)\n\n    # Vectorized: works on arrays, not just scalars\n    def test_vectorized_numpy(self):\n        x = _to_np([0.0, 0.1, 0.5, 1.0])\n        result = cu.linear_to_srgb(x)\n        assert result.shape == (4,)\n        roundtripped = cu.srgb_to_linear(result)\n        np.testing.assert_allclose(roundtripped, x, atol=1e-5)\n\n    def test_vectorized_torch(self):\n        x = _to_torch([0.0, 0.1, 0.5, 1.0])\n        result = cu.linear_to_srgb(x)\n        assert result.shape == (4,)\n        roundtripped = cu.srgb_to_linear(result)\n        torch.testing.assert_close(roundtripped, x, atol=1e-5, rtol=1e-5)\n\n\n# ---------------------------------------------------------------------------\n# premultiply  /  unpremultiply\n# ---------------------------------------------------------------------------\n\n\nclass TestPremultiply:\n    \"\"\"Premultiply / unpremultiply tests.\n\n    The core compositing contract: premultiplied RGB = straight RGB * alpha.\n    \"\"\"\n\n    def test_roundtrip_numpy(self):\n        fg = _to_np([[0.8, 0.5, 0.2]])\n        alpha = _to_np([[0.6]])\n        premul = cu.premultiply(fg, alpha)\n        recovered = cu.unpremultiply(premul, alpha)\n        np.testing.assert_allclose(recovered, fg, atol=1e-5)\n\n    def test_roundtrip_torch(self):\n        fg = 
_to_torch([[0.8, 0.5, 0.2]])\n        alpha = _to_torch([[0.6]])\n        premul = cu.premultiply(fg, alpha)\n        recovered = cu.unpremultiply(premul, alpha)\n        torch.testing.assert_close(recovered, fg, atol=1e-5, rtol=1e-5)\n\n    def test_output_bounded_by_fg_numpy(self):\n        \"\"\"Premultiplied RGB must be <= straight RGB when 0 <= alpha <= 1.\"\"\"\n        fg = _to_np([[1.0, 0.5, 0.3]])\n        alpha = _to_np([[0.7]])\n        premul = cu.premultiply(fg, alpha)\n        assert np.all(premul <= fg + 1e-7)\n\n    def test_output_bounded_by_fg_torch(self):\n        fg = _to_torch([[1.0, 0.5, 0.3]])\n        alpha = _to_torch([[0.7]])\n        premul = cu.premultiply(fg, alpha)\n        assert torch.all(premul <= fg + 1e-7)\n\n    def test_zero_alpha_numpy(self):\n        \"\"\"Premultiply by zero alpha → zero RGB.\"\"\"\n        fg = _to_np([[0.8, 0.5, 0.2]])\n        alpha = _to_np([[0.0]])\n        premul = cu.premultiply(fg, alpha)\n        np.testing.assert_allclose(premul, 0.0, atol=1e-7)\n\n    def test_one_alpha_numpy(self):\n        \"\"\"Premultiply by alpha=1 → unchanged.\"\"\"\n        fg = _to_np([[0.8, 0.5, 0.2]])\n        alpha = _to_np([[1.0]])\n        premul = cu.premultiply(fg, alpha)\n        np.testing.assert_allclose(premul, fg, atol=1e-7)\n\n\n# ---------------------------------------------------------------------------\n# composite_straight  /  composite_premul\n# ---------------------------------------------------------------------------\n\n\nclass TestCompositing:\n    \"\"\"The Porter-Duff 'over' operator: A over B.\n\n    composite_straight and composite_premul must produce the same result\n    given equivalent inputs.\n    \"\"\"\n\n    def test_straight_vs_premul_equivalence_numpy(self):\n        fg = _to_np([0.9, 0.3, 0.1])\n        bg = _to_np([0.1, 0.2, 0.8])\n        alpha = _to_np(0.6)\n\n        result_straight = cu.composite_straight(fg, bg, alpha)\n        fg_premul = cu.premultiply(fg, alpha)\n        
result_premul = cu.composite_premul(fg_premul, bg, alpha)\n\n        np.testing.assert_allclose(result_straight, result_premul, atol=1e-6)\n\n    def test_straight_vs_premul_equivalence_torch(self):\n        fg = _to_torch([0.9, 0.3, 0.1])\n        bg = _to_torch([0.1, 0.2, 0.8])\n        alpha = _to_torch(0.6)\n\n        result_straight = cu.composite_straight(fg, bg, alpha)\n        fg_premul = cu.premultiply(fg, alpha)\n        result_premul = cu.composite_premul(fg_premul, bg, alpha)\n\n        torch.testing.assert_close(result_straight, result_premul, atol=1e-6, rtol=1e-6)\n\n    def test_alpha_zero_shows_background(self):\n        fg = _to_np([1.0, 0.0, 0.0])\n        bg = _to_np([0.0, 0.0, 1.0])\n        alpha = _to_np(0.0)\n        result = cu.composite_straight(fg, bg, alpha)\n        np.testing.assert_allclose(result, bg, atol=1e-7)\n\n    def test_alpha_one_shows_foreground(self):\n        fg = _to_np([1.0, 0.0, 0.0])\n        bg = _to_np([0.0, 0.0, 1.0])\n        alpha = _to_np(1.0)\n        result = cu.composite_straight(fg, bg, alpha)\n        np.testing.assert_allclose(result, fg, atol=1e-7)\n\n\n# ---------------------------------------------------------------------------\n# despill\n# ---------------------------------------------------------------------------\n\n\nclass TestDespill:\n    \"\"\"Green spill removal.\n\n    The despill function clamps excess green based on red/blue, then\n    redistributes the removed energy to preserve luminance.\n    \"\"\"\n\n    def test_pure_green_reduced_average_mode_numpy(self):\n        \"\"\"A pure green pixel should have green clamped to (R+B)/2 = 0.\"\"\"\n        img = _to_np([[0.0, 1.0, 0.0]])\n        result = cu.despill(img, green_limit_mode=\"average\", strength=1.0)\n        # Green should be 0 (clamped to avg of R=0, B=0)\n        assert result[0, 1] == pytest.approx(0.0, abs=1e-6)\n\n    def test_pure_green_reduced_max_mode_numpy(self):\n        \"\"\"With 'max' mode, green clamped to max(R, B) = 0 
for pure green.\"\"\"\n        img = _to_np([[0.0, 1.0, 0.0]])\n        result = cu.despill(img, green_limit_mode=\"max\", strength=1.0)\n        assert result[0, 1] == pytest.approx(0.0, abs=1e-6)\n\n    def test_pure_red_unchanged_numpy(self):\n        \"\"\"A pixel with no green excess should not be modified.\"\"\"\n        img = _to_np([[1.0, 0.0, 0.0]])\n        result = cu.despill(img, green_limit_mode=\"average\", strength=1.0)\n        np.testing.assert_allclose(result, img, atol=1e-6)\n\n    def test_strength_zero_is_noop_numpy(self):\n        \"\"\"strength=0 should return the input unchanged.\"\"\"\n        img = _to_np([[0.2, 0.9, 0.1]])\n        result = cu.despill(img, strength=0.0)\n        np.testing.assert_allclose(result, img, atol=1e-7)\n\n    def test_partial_green_average_mode_numpy(self):\n        \"\"\"Green slightly above (R+B)/2 should be reduced, not zeroed.\"\"\"\n        img = _to_np([[0.4, 0.8, 0.2]])\n        result = cu.despill(img, green_limit_mode=\"average\", strength=1.0)\n        limit = (0.4 + 0.2) / 2.0  # 0.3\n        expected_green = limit  # green clamped to limit\n        assert result[0, 1] == pytest.approx(expected_green, abs=1e-5)\n\n    def test_max_mode_higher_limit_than_average(self):\n        \"\"\"'max' mode uses max(R,B) which is >= (R+B)/2, so less despill.\"\"\"\n        img = _to_np([[0.6, 0.8, 0.1]])\n        result_avg = cu.despill(img, green_limit_mode=\"average\", strength=1.0)\n        result_max = cu.despill(img, green_limit_mode=\"max\", strength=1.0)\n        # max(R,B)=0.6 vs avg(R,B)=0.35, so max mode removes less green\n        assert result_max[0, 1] >= result_avg[0, 1]\n\n    def test_fractional_strength_interpolates(self):\n        \"\"\"strength=0.5 should produce a result between original and fully despilled.\"\"\"\n        img = _to_np([[0.2, 0.9, 0.1]])\n        full = cu.despill(img, green_limit_mode=\"average\", strength=1.0)\n        half = cu.despill(img, green_limit_mode=\"average\", 
strength=0.5)\n        # Half-strength green should be between original green and fully despilled green\n        assert half[0, 1] < img[0, 1]  # less green than original\n        assert half[0, 1] > full[0, 1]  # more green than full despill\n        # Verify it's actually the midpoint: img * 0.5 + full * 0.5\n        expected = img * 0.5 + full * 0.5\n        np.testing.assert_allclose(half, expected, atol=1e-6)\n\n    def test_despill_torch(self):\n        \"\"\"Verify torch path matches numpy path.\"\"\"\n        img_np = _to_np([[0.3, 0.9, 0.2]])\n        img_t = _to_torch([[0.3, 0.9, 0.2]])\n        result_np = cu.despill(img_np, green_limit_mode=\"average\", strength=1.0)\n        result_t = cu.despill(img_t, green_limit_mode=\"average\", strength=1.0)\n        np.testing.assert_allclose(result_np, result_t.numpy(), atol=1e-5)\n\n    def test_green_below_limit_unchanged_numpy(self):\n        \"\"\"spill_amount is clamped to zero when G < (R+B)/2 — pixel is returned unchanged.\n\n        When a pixel has less green than the luminance limit ((R+B)/2) it\n        carries no green spill.  The max(..., 0) clamp on spill_amount ensures\n        the pixel is left untouched.  
Without that clamp despill would\n        *increase* green and *decrease* red/blue, corrupting non-spill regions.\n        \"\"\"\n        # G=0.3 is well below the average limit (0.8+0.6)/2 = 0.7\n        # spill_amount = max(0.3 - 0.7, 0) = 0  →  output equals input\n        img = _to_np([[0.8, 0.3, 0.6]])\n        result = cu.despill(img, green_limit_mode=\"average\", strength=1.0)\n        np.testing.assert_allclose(result, img, atol=1e-6)\n\n\n# ---------------------------------------------------------------------------\n# clean_matte\n# ---------------------------------------------------------------------------\n\n\nclass TestCleanMatte:\n    \"\"\"Connected-component cleanup of alpha mattes.\n\n    Small disconnected blobs (tracking markers, noise) should be removed\n    while large foreground regions are preserved.\n    \"\"\"\n\n    def test_large_blob_preserved(self):\n        \"\"\"A single large opaque region should survive cleanup.\"\"\"\n        matte = np.zeros((100, 100), dtype=np.float32)\n        matte[20:80, 20:80] = 1.0  # 60x60 = 3600 pixels\n        result = cu.clean_matte(matte, area_threshold=300)\n        # Center of the blob should still be opaque\n        assert result[50, 50] > 0.9\n\n    def test_small_blob_removed(self):\n        \"\"\"A tiny blob below the threshold should be removed.\"\"\"\n        matte = np.zeros((100, 100), dtype=np.float32)\n        matte[5:8, 5:8] = 1.0  # 3x3 = 9 pixels\n        result = cu.clean_matte(matte, area_threshold=300)\n        assert result[6, 6] == pytest.approx(0.0, abs=1e-5)\n\n    def test_mixed_blobs(self):\n        \"\"\"Large blob kept, small blob removed.\"\"\"\n        matte = np.zeros((200, 200), dtype=np.float32)\n        # Large blob: 50x50 = 2500 px\n        matte[10:60, 10:60] = 1.0\n        # Small blob: 5x5 = 25 px\n        matte[150:155, 150:155] = 1.0\n\n        result = cu.clean_matte(matte, area_threshold=100)\n        assert result[35, 35] > 0.9  # large blob center preserved\n   
     assert result[152, 152] < 0.01  # small blob removed\n\n    def test_3d_input_preserved(self):\n        \"\"\"[H, W, 1] input should return [H, W, 1] output.\"\"\"\n        matte = np.zeros((50, 50, 1), dtype=np.float32)\n        matte[10:40, 10:40, 0] = 1.0\n        result = cu.clean_matte(matte, area_threshold=100)\n        assert result.ndim == 3\n        assert result.shape[2] == 1\n\n\n# ---------------------------------------------------------------------------\n# create_checkerboard\n# ---------------------------------------------------------------------------\n\n\nclass TestCheckerboard:\n    \"\"\"Checkerboard pattern generator used for QC composites.\"\"\"\n\n    def test_output_shape(self):\n        result = cu.create_checkerboard(640, 480)\n        assert result.shape == (480, 640, 3)\n\n    def test_output_range(self):\n        result = cu.create_checkerboard(100, 100, color1=0.2, color2=0.4)\n        assert result.min() >= 0.0\n        assert result.max() <= 1.0\n\n    def test_uses_specified_colors(self):\n        result = cu.create_checkerboard(128, 128, checker_size=64, color1=0.1, color2=0.9)\n        unique_vals = np.unique(result[:, :, 0])\n        np.testing.assert_allclose(sorted(unique_vals), [0.1, 0.9], atol=1e-6)\n\n\n# ---------------------------------------------------------------------------\n# rgb_to_yuv\n# ---------------------------------------------------------------------------\n\n\nclass TestRgbToYuv:\n    \"\"\"RGB to YUV (Rec. 601) conversion.\n\n    Three layout branches: BCHW (4D), CHW (3D channel-first), and\n    last-dim (3D/2D channel-last). Each independently indexes channels,\n    so a wrong index silently swaps color information.\n\n    Known Rec. 
601 coefficients: Y = 0.299R + 0.587G + 0.114B\n    \"\"\"\n\n    def test_pure_white_bchw(self):\n        \"\"\"Pure white (1,1,1) → Y=1, U=0, V=0 in any colorspace.\"\"\"\n        img = torch.ones(1, 3, 2, 2)  # BCHW\n        result = cu.rgb_to_yuv(img)\n        assert result.shape == (1, 3, 2, 2)\n        # Y channel should be 1.0\n        torch.testing.assert_close(result[:, 0], torch.ones(1, 2, 2), atol=1e-5, rtol=1e-5)\n        # U and V should be ~0 for achromatic input\n        assert result[:, 1].abs().max() < 1e-5\n        assert result[:, 2].abs().max() < 1e-5\n\n    def test_pure_red_known_values(self):\n        \"\"\"Pure red (1,0,0) → known Y, U, V from Rec. 601 coefficients.\"\"\"\n        img = torch.zeros(1, 3, 1, 1)\n        img[0, 0, 0, 0] = 1.0  # R=1, G=0, B=0\n        result = cu.rgb_to_yuv(img)\n        expected_y = 0.299\n        expected_u = 0.492 * (0.0 - expected_y)  # 0.492 * (B - Y)\n        expected_v = 0.877 * (1.0 - expected_y)  # 0.877 * (R - Y)\n        assert result[0, 0, 0, 0].item() == pytest.approx(expected_y, abs=1e-5)\n        assert result[0, 1, 0, 0].item() == pytest.approx(expected_u, abs=1e-5)\n        assert result[0, 2, 0, 0].item() == pytest.approx(expected_v, abs=1e-5)\n\n    def test_chw_layout(self):\n        \"\"\"3D CHW input (channel-first) should produce CHW output.\"\"\"\n        img = torch.zeros(3, 4, 4)\n        img[1, :, :] = 1.0  # Pure green\n        result = cu.rgb_to_yuv(img)\n        assert result.shape == (3, 4, 4)\n        expected_y = 0.587  # 0.299*0 + 0.587*1 + 0.114*0\n        assert result[0, 0, 0].item() == pytest.approx(expected_y, abs=1e-5)\n\n    def test_last_dim_layout(self):\n        \"\"\"2D [N, 3] input (channel-last) should produce [N, 3] output.\"\"\"\n        img = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])\n        result = cu.rgb_to_yuv(img)\n        assert result.shape == (3, 3)\n        # Row 0 is pure red: Y = 0.299\n        assert result[0, 0].item() == 
pytest.approx(0.299, abs=1e-5)\n        # Row 1 is pure green: Y = 0.587\n        assert result[1, 0].item() == pytest.approx(0.587, abs=1e-5)\n        # Row 2 is pure blue: Y = 0.114\n        assert result[2, 0].item() == pytest.approx(0.114, abs=1e-5)\n\n    def test_rejects_numpy(self):\n        \"\"\"rgb_to_yuv is torch-only — numpy input should raise TypeError.\"\"\"\n        img = np.zeros((3, 4, 4), dtype=np.float32)\n        with pytest.raises(TypeError):\n            cu.rgb_to_yuv(img)\n\n\n# ---------------------------------------------------------------------------\n# dilate_mask\n# ---------------------------------------------------------------------------\n\n\nclass TestDilateMask:\n    \"\"\"Mask dilation via cv2 (numpy) or max_pool2d (torch).\n\n    Both backends should expand the mask outward. radius=0 is a no-op.\n    \"\"\"\n\n    def test_radius_zero_noop_numpy(self):\n        mask = np.zeros((50, 50), dtype=np.float32)\n        mask[20:30, 20:30] = 1.0\n        result = cu.dilate_mask(mask, radius=0)\n        np.testing.assert_array_equal(result, mask)\n\n    def test_radius_zero_noop_torch(self):\n        mask = torch.zeros(50, 50)\n        mask[20:30, 20:30] = 1.0\n        result = cu.dilate_mask(mask, radius=0)\n        torch.testing.assert_close(result, mask)\n\n    def test_dilation_expands_numpy(self):\n        \"\"\"Dilated mask should be >= original at every pixel.\"\"\"\n        mask = np.zeros((50, 50), dtype=np.float32)\n        mask[20:30, 20:30] = 1.0\n        result = cu.dilate_mask(mask, radius=3)\n        assert np.all(result >= mask)\n        # Pixels just outside the original region should now be 1\n        assert result[19, 25] > 0  # above the original box\n        assert result[25, 19] > 0  # left of the original box\n\n    def test_dilation_expands_torch(self):\n        \"\"\"Dilated mask should be >= original at every pixel (torch path).\"\"\"\n        mask = torch.zeros(50, 50)\n        mask[20:30, 20:30] = 1.0\n        
result = cu.dilate_mask(mask, radius=3)\n        assert torch.all(result >= mask)\n        assert result[19, 25] > 0\n        assert result[25, 19] > 0\n\n    def test_preserves_2d_shape_numpy(self):\n        mask = np.zeros((40, 60), dtype=np.float32)\n        result = cu.dilate_mask(mask, radius=5)\n        assert result.shape == (40, 60)\n\n    def test_preserves_2d_shape_torch(self):\n        mask = torch.zeros(40, 60)\n        result = cu.dilate_mask(mask, radius=5)\n        assert result.shape == (40, 60)\n\n    def test_preserves_3d_shape_torch(self):\n        \"\"\"[C, H, W] input should return [C, H, W] output.\"\"\"\n        mask = torch.zeros(1, 40, 60)\n        result = cu.dilate_mask(mask, radius=5)\n        assert result.shape == (1, 40, 60)\n\n\n# ---------------------------------------------------------------------------\n# apply_garbage_matte\n# ---------------------------------------------------------------------------\n\n\nclass TestApplyGarbageMatte:\n    \"\"\"Garbage matte application: multiplies predicted matte by a dilated coarse mask.\n\n    Used to zero out regions outside the coarse matte (rigs, lights, etc.).\n    \"\"\"\n\n    def test_none_input_passthrough(self):\n        \"\"\"If no garbage matte is provided, the predicted matte is returned unchanged.\"\"\"\n        rng = np.random.default_rng(42)\n        matte = rng.random((100, 100)).astype(np.float32)\n        result = cu.apply_garbage_matte(matte, None)\n        np.testing.assert_array_equal(result, matte)\n\n    def test_zeros_outside_garbage_region(self):\n        \"\"\"Regions outside the garbage matte should be zeroed.\"\"\"\n        predicted = np.ones((50, 50), dtype=np.float32)\n        garbage = np.zeros((50, 50), dtype=np.float32)\n        garbage[10:40, 10:40] = 1.0  # only center is valid\n        result = cu.apply_garbage_matte(predicted, garbage, dilation=0)\n        # Outside the garbage matte region should be 0\n        assert result[0, 0] == pytest.approx(0.0, 
abs=1e-7)\n        # Inside should be preserved\n        assert result[25, 25] == pytest.approx(1.0, abs=1e-7)\n\n    def test_3d_matte_with_2d_garbage(self):\n        \"\"\"[H, W, 1] predicted matte with [H, W] garbage matte should broadcast.\"\"\"\n        predicted = np.ones((50, 50, 1), dtype=np.float32)\n        garbage = np.zeros((50, 50), dtype=np.float32)\n        garbage[10:40, 10:40] = 1.0\n        result = cu.apply_garbage_matte(predicted, garbage, dilation=0)\n        assert result.shape == (50, 50, 1)\n        assert result[0, 0, 0] == pytest.approx(0.0, abs=1e-7)\n        assert result[25, 25, 0] == pytest.approx(1.0, abs=1e-7)\n"
  },
  {
    "path": "tests/test_device_utils.py",
    "content": "\"\"\"Unit tests for device_utils — cross-platform device selection.\n\nTests cover all code paths in detect_best_device(), resolve_device(),\nand clear_device_cache() using monkeypatch to mock hardware availability.\nNo GPU required.\n\"\"\"\n\nfrom unittest.mock import MagicMock\n\nimport pytest\nimport torch\n\nfrom device_utils import (\n    DEVICE_ENV_VAR,\n    clear_device_cache,\n    detect_best_device,\n    resolve_device,\n)\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _patch_gpu(monkeypatch, *, cuda=False, mps=False):\n    \"\"\"Mock CUDA and MPS availability flags.\"\"\"\n    monkeypatch.setattr(torch.cuda, \"is_available\", lambda: cuda)\n    # MPS lives behind torch.backends.mps; ensure the attr path exists\n    mps_backend = MagicMock()\n    mps_backend.is_available = MagicMock(return_value=mps)\n    monkeypatch.setattr(torch.backends, \"mps\", mps_backend)\n\n\n# ---------------------------------------------------------------------------\n# detect_best_device\n# ---------------------------------------------------------------------------\n\n\nclass TestDetectBestDevice:\n    \"\"\"Priority chain: CUDA > MPS > CPU.\"\"\"\n\n    def test_returns_cuda_when_available(self, monkeypatch):\n        _patch_gpu(monkeypatch, cuda=True, mps=True)\n        assert detect_best_device() == \"cuda\"\n\n    def test_returns_mps_when_no_cuda(self, monkeypatch):\n        _patch_gpu(monkeypatch, cuda=False, mps=True)\n        assert detect_best_device() == \"mps\"\n\n    def test_returns_cpu_when_nothing(self, monkeypatch):\n        _patch_gpu(monkeypatch, cuda=False, mps=False)\n        assert detect_best_device() == \"cpu\"\n\n\n# ---------------------------------------------------------------------------\n# resolve_device\n# ---------------------------------------------------------------------------\n\n\nclass 
TestResolveDevice:\n    \"\"\"Priority chain: CLI arg > env var > auto-detect.\"\"\"\n\n    # --- auto-detect path ---\n\n    def test_none_triggers_auto_detect(self, monkeypatch):\n        _patch_gpu(monkeypatch, cuda=False, mps=False)\n        monkeypatch.delenv(DEVICE_ENV_VAR, raising=False)\n        assert resolve_device(None) == \"cpu\"\n\n    def test_auto_string_triggers_auto_detect(self, monkeypatch):\n        _patch_gpu(monkeypatch, cuda=True)\n        monkeypatch.delenv(DEVICE_ENV_VAR, raising=False)\n        assert resolve_device(\"auto\") == \"cuda\"\n\n    # --- env var fallback ---\n\n    def test_env_var_used_when_no_cli_arg(self, monkeypatch):\n        _patch_gpu(monkeypatch, cuda=True, mps=True)\n        monkeypatch.setenv(DEVICE_ENV_VAR, \"cpu\")\n        assert resolve_device(None) == \"cpu\"\n\n    def test_env_var_auto_triggers_detect(self, monkeypatch):\n        _patch_gpu(monkeypatch, cuda=False, mps=True)\n        monkeypatch.setenv(DEVICE_ENV_VAR, \"auto\")\n        assert resolve_device(None) == \"mps\"\n\n    # --- CLI arg overrides env var ---\n\n    def test_cli_arg_overrides_env_var(self, monkeypatch):\n        _patch_gpu(monkeypatch, cuda=True, mps=True)\n        monkeypatch.setenv(DEVICE_ENV_VAR, \"mps\")\n        assert resolve_device(\"cuda\") == \"cuda\"\n\n    # --- explicit valid devices ---\n\n    def test_explicit_cuda(self, monkeypatch):\n        _patch_gpu(monkeypatch, cuda=True)\n        assert resolve_device(\"cuda\") == \"cuda\"\n\n    def test_explicit_mps(self, monkeypatch):\n        _patch_gpu(monkeypatch, mps=True)\n        assert resolve_device(\"mps\") == \"mps\"\n\n    def test_explicit_cpu(self, monkeypatch):\n        assert resolve_device(\"cpu\") == \"cpu\"\n\n    def test_case_insensitive(self, monkeypatch):\n        assert resolve_device(\"CPU\") == \"cpu\"\n\n    # --- unavailable backend errors ---\n\n    def test_cuda_unavailable_raises(self, monkeypatch):\n        _patch_gpu(monkeypatch, cuda=False)\n      
  with pytest.raises(RuntimeError, match=\"CUDA requested\"):\n            resolve_device(\"cuda\")\n\n    def test_mps_no_backend_raises(self, monkeypatch):\n        # Simulate PyTorch build without MPS module in torch.backends\n        _patch_gpu(monkeypatch, cuda=False, mps=False)\n        # Replace torch.backends with an object that lacks \"mps\" entirely\n        fake_backends = type(\"Backends\", (), {})()\n        monkeypatch.setattr(\"device_utils.torch.backends\", fake_backends)\n        with pytest.raises(RuntimeError, match=\"no MPS support\"):\n            resolve_device(\"mps\")\n\n    def test_mps_unavailable_raises(self, monkeypatch):\n        _patch_gpu(monkeypatch, cuda=False, mps=False)\n        with pytest.raises(RuntimeError, match=\"not available on this machine\"):\n            resolve_device(\"mps\")\n\n    # --- invalid device string ---\n\n    def test_invalid_device_raises(self, monkeypatch):\n        with pytest.raises(RuntimeError, match=\"Unknown device\"):\n            resolve_device(\"tpu\")\n\n\n# ---------------------------------------------------------------------------\n# clear_device_cache\n# ---------------------------------------------------------------------------\n\n\nclass TestClearDeviceCache:\n    \"\"\"Dispatches to correct backend cache clear.\"\"\"\n\n    def test_cuda_clears_cache(self, monkeypatch):\n        mock_empty = MagicMock()\n        monkeypatch.setattr(torch.cuda, \"empty_cache\", mock_empty)\n        clear_device_cache(\"cuda\")\n        mock_empty.assert_called_once()\n\n    def test_mps_clears_cache(self, monkeypatch):\n        mock_empty = MagicMock()\n        monkeypatch.setattr(torch.mps, \"empty_cache\", mock_empty)\n        clear_device_cache(\"mps\")\n        mock_empty.assert_called_once()\n\n    def test_cpu_is_noop(self):\n        # Should not raise\n        clear_device_cache(\"cpu\")\n\n    def test_accepts_torch_device_object(self, monkeypatch):\n        mock_empty = MagicMock()\n        
monkeypatch.setattr(torch.cuda, \"empty_cache\", mock_empty)\n        clear_device_cache(torch.device(\"cuda\"))\n        mock_empty.assert_called_once()\n\n    def test_accepts_mps_device_object(self, monkeypatch):\n        mock_empty = MagicMock()\n        monkeypatch.setattr(torch.mps, \"empty_cache\", mock_empty)\n        clear_device_cache(torch.device(\"mps\"))\n        mock_empty.assert_called_once()\n"
  },
  {
    "path": "tests/test_e2e_workflow.py",
    "content": "\"\"\"End-to-end workflow integration tests for CorridorKey.\n\nThese tests exercise the full pipeline from ClipEntry asset discovery through\nrun_inference output file creation.  The neural network engine is mocked so\nno model weights or GPU are required.\n\nWhy integration-test run_inference?\n  Unit tests cover individual math functions.  This file verifies that the\n  orchestration layer (reading frames from disk, calling the engine, writing\n  output files to the right directories) works end-to-end on realistic\n  directory structures.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom unittest.mock import MagicMock, patch\n\nimport numpy as np\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _fake_result(h: int = 4, w: int = 4) -> dict:\n    \"\"\"Return a minimal but valid process_frame result dict sized to (h, w).\"\"\"\n    return {\n        \"alpha\": np.full((h, w, 1), 0.8, dtype=np.float32),\n        \"fg\": np.full((h, w, 3), 0.6, dtype=np.float32),\n        \"comp\": np.full((h, w, 3), 0.5, dtype=np.float32),\n        \"processed\": np.full((h, w, 4), 0.4, dtype=np.float32),\n    }\n\n\n# ---------------------------------------------------------------------------\n# End-to-end: ClipEntry discovery → run_inference → files on disk\n# ---------------------------------------------------------------------------\n\n\nclass TestE2EInferenceWorkflow:\n    \"\"\"End-to-end: ClipEntry discovery → run_inference → output files on disk.\n\n    Uses the ``tmp_clip_dir`` fixture (shot_a: 2 frames, shot_b: 1 frame /\n    no alpha) and a mocked engine.  
Verifies directory creation, frame I/O,\n    and file writing without a real engine or checkpoint.\n    \"\"\"\n\n    def test_output_directories_created(self, tmp_clip_dir, monkeypatch):\n        \"\"\"run_inference creates Output/{FG,Matte,Comp,Processed} for each clip.\"\"\"\n        from clip_manager import ClipEntry, run_inference\n\n        entry = ClipEntry(\"shot_a\", str(tmp_clip_dir / \"shot_a\"))\n        entry.find_assets()\n\n        # Supply blank answers to all interactive prompts inside run_inference\n        monkeypatch.setattr(\"builtins.input\", lambda prompt=\"\": \"\")\n\n        mock_engine = MagicMock()\n        mock_engine.process_frame.return_value = _fake_result()\n\n        with patch(\"CorridorKeyModule.backend.create_engine\", return_value=mock_engine):\n            run_inference([entry], device=\"cpu\")\n\n        out_root = tmp_clip_dir / \"shot_a\" / \"Output\"\n        assert (out_root / \"FG\").is_dir()\n        assert (out_root / \"Matte\").is_dir()\n        assert (out_root / \"Comp\").is_dir()\n        assert (out_root / \"Processed\").is_dir()\n\n    def test_output_files_written_per_frame(self, tmp_clip_dir, monkeypatch):\n        \"\"\"run_inference writes exactly one output file per input frame.\n\n        shot_a has 2 input frames and 2 alpha frames, so each output\n        subdirectory should contain exactly 2 files after inference.\n        \"\"\"\n        from clip_manager import ClipEntry, run_inference\n\n        entry = ClipEntry(\"shot_a\", str(tmp_clip_dir / \"shot_a\"))\n        entry.find_assets()\n\n        monkeypatch.setattr(\"builtins.input\", lambda prompt=\"\": \"\")\n\n        mock_engine = MagicMock()\n        mock_engine.process_frame.return_value = _fake_result()\n\n        with patch(\"CorridorKeyModule.backend.create_engine\", return_value=mock_engine):\n            run_inference([entry], device=\"cpu\")\n\n        out_root = tmp_clip_dir / \"shot_a\" / \"Output\"\n        # shot_a has 2 frames → 2 
files per output directory\n        assert len(list((out_root / \"FG\").glob(\"*.exr\"))) == 2\n        assert len(list((out_root / \"Matte\").glob(\"*.exr\"))) == 2\n        assert len(list((out_root / \"Comp\").glob(\"*.png\"))) == 2\n        assert len(list((out_root / \"Processed\").glob(\"*.exr\"))) == 2\n\n    def test_clip_without_alpha_skipped(self, tmp_clip_dir, monkeypatch):\n        \"\"\"Clips missing an alpha asset are silently skipped by run_inference.\n\n        shot_b has Input but an empty AlphaHint, so it has no alpha_asset.\n        run_inference should process zero frames and create no Output directory.\n        \"\"\"\n        from clip_manager import ClipEntry, run_inference\n\n        entry = ClipEntry(\"shot_b\", str(tmp_clip_dir / \"shot_b\"))\n        entry.find_assets()\n        assert entry.alpha_asset is None  # precondition\n\n        monkeypatch.setattr(\"builtins.input\", lambda prompt=\"\": \"\")\n\n        mock_engine = MagicMock()\n        mock_engine.process_frame.return_value = _fake_result()\n\n        with patch(\"CorridorKeyModule.backend.create_engine\", return_value=mock_engine):\n            run_inference([entry], device=\"cpu\")\n\n        # No engine calls — clip was filtered out before inference\n        mock_engine.process_frame.assert_not_called()\n        assert not (tmp_clip_dir / \"shot_b\" / \"Output\").exists()\n"
  },
  {
    "path": "tests/test_exr_gamma_bug_condition.py",
    "content": "\"\"\"Bug condition exploration tests for EXR gamma correction.\n\nThese tests encode the EXPECTED (correct) behavior for EXR+sRGB gamma handling.\nOn UNFIXED code they are expected to FAIL, confirming the bug exists.\n\n**Validates: Requirements 1.1, 1.2, 1.3**\n\nDefect 1 (1.1): run_inference EXR path ignores input_is_linear=False — no gamma correction applied.\nDefect 2 (1.2): process_frame receives raw linear data with input_is_linear=False semantics.\nDefect 3 (1.3): read_image_frame uses naive pow(1/2.2) instead of piecewise sRGB transfer function.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nfrom unittest.mock import MagicMock, patch\n\nimport cv2\nimport numpy as np\nfrom hypothesis import given, settings\nfrom hypothesis import strategies as st\nfrom hypothesis.extra.numpy import arrays\n\nfrom backend.frame_io import read_image_frame\nfrom CorridorKeyModule.core.color_utils import linear_to_srgb\n\n# ---------------------------------------------------------------------------\n# Strategies\n# ---------------------------------------------------------------------------\n\n# Float32 pixel values in [0, 1] — the valid linear range for EXR data\nlinear_pixel_values = st.floats(min_value=0.0, max_value=1.0, allow_nan=False, allow_infinity=False)\n\n# Small float32 RGB arrays representing EXR frame data\n# Using small sizes (4x4 to 16x16) to keep tests fast\nlinear_pixel_arrays = arrays(\n    dtype=np.float32,\n    shape=st.tuples(\n        st.integers(min_value=4, max_value=16),\n        st.integers(min_value=4, max_value=16),\n        st.just(3),\n    ),\n    elements=st.floats(min_value=0.0, max_value=1.0, allow_nan=False, allow_infinity=False),\n)\n\n# Strategy that specifically targets the sRGB piecewise threshold region\n# Values near 0.0031308 where naive pow(1/2.2) diverges from piecewise sRGB\nthreshold_pixel_values = st.one_of(\n    # Values below the threshold (linear segment of sRGB)\n    
st.floats(min_value=0.0, max_value=0.0031308, allow_nan=False, allow_infinity=False),\n    # Values right around the threshold\n    st.floats(min_value=0.002, max_value=0.005, allow_nan=False, allow_infinity=False),\n    # Values above the threshold (power segment of sRGB)\n    st.floats(min_value=0.0031308, max_value=1.0, allow_nan=False, allow_infinity=False),\n)\n\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _write_exr(path: str, rgb_data: np.ndarray) -> None:\n    \"\"\"Write a float32 RGB array as a BGR EXR file via OpenCV.\"\"\"\n    bgr = cv2.cvtColor(rgb_data.astype(np.float32), cv2.COLOR_RGB2BGR)\n    cv2.imwrite(path, bgr)\n\n\n# ---------------------------------------------------------------------------\n# Defect 3: naive pow(1/2.2) vs piecewise sRGB in read_image_frame\n# ---------------------------------------------------------------------------\n\n\nclass TestDefect3NaivePowVsPiecewiseSRGB:\n    \"\"\"**Validates: Requirements 1.3**\n\n    read_image_frame(exr_path, gamma_correct_exr=True) should produce output\n    matching linear_to_srgb() from color_utils.py, NOT np.power(data, 1/2.2).\n    \"\"\"\n\n    @given(data=linear_pixel_arrays)\n    @settings(max_examples=50, deadline=None)\n    def test_gamma_corrected_exr_matches_piecewise_srgb(self, data: np.ndarray) -> None:\n        \"\"\"When gamma_correct_exr=True, the result must match the piecewise\n        sRGB transfer function, not the naive pow(1/2.2) approximation.\n\n        **Validates: Requirements 1.3**\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            exr_path = os.path.join(tmpdir, \"test.exr\")\n            _write_exr(exr_path, data)\n\n            result = read_image_frame(exr_path, gamma_correct_exr=True)\n            assert result is not None, \"read_image_frame returned None\"\n\n            # The expected output: 
piecewise sRGB transfer function\n            raw = cv2.imread(exr_path, cv2.IMREAD_UNCHANGED)\n            if raw.ndim == 3 and raw.shape[2] == 4:\n                raw = raw[:, :, :3]\n            raw_rgb = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)\n            linear_clamped = np.maximum(raw_rgb, 0.0).astype(np.float32)\n            expected = linear_to_srgb(linear_clamped)\n\n            np.testing.assert_allclose(\n                result,\n                expected,\n                atol=1e-6,\n                err_msg=(\n                    \"read_image_frame with gamma_correct_exr=True does not match \"\n                    \"linear_to_srgb(). It likely uses naive pow(1/2.2) instead of \"\n                    \"the piecewise sRGB transfer function.\"\n                ),\n            )\n\n    @given(\n        pixel_val=threshold_pixel_values,\n    )\n    @settings(max_examples=100, deadline=None)\n    def test_threshold_region_divergence(self, pixel_val: float) -> None:\n        \"\"\"Specifically test values near the 0.0031308 threshold where\n        naive pow(1/2.2) and piecewise sRGB diverge most.\n\n        **Validates: Requirements 1.3**\n        \"\"\"\n        # Create a small 2x2 image with the test pixel value\n        data = np.full((2, 2, 3), pixel_val, dtype=np.float32)\n\n        with tempfile.TemporaryDirectory() as tmpdir:\n            exr_path = os.path.join(tmpdir, \"threshold_test.exr\")\n            _write_exr(exr_path, data)\n\n            result = read_image_frame(exr_path, gamma_correct_exr=True)\n            assert result is not None\n\n            expected = linear_to_srgb(data)\n\n            np.testing.assert_allclose(\n                result,\n                expected,\n                atol=1e-6,\n                err_msg=(\n                    f\"At pixel value {pixel_val} (near sRGB threshold 0.0031308), \"\n                    f\"read_image_frame produces {result.flat[0]:.8f} but piecewise \"\n                    f\"sRGB expects 
{expected.flat[0]:.8f}. \"\n                    f\"Naive pow(1/2.2) would give {np.power(pixel_val, 1.0 / 2.2):.8f}.\"\n                ),\n            )\n\n\n# ---------------------------------------------------------------------------\n# Defect 1 & 2: run_inference EXR path ignores gamma + wrong semantics\n# ---------------------------------------------------------------------------\n\n\nclass TestDefect1And2RunInferenceEXRPath:\n    \"\"\"**Validates: Requirements 1.1, 1.2**\n\n    When input is an EXR image sequence with input_is_linear=False,\n    the frame passed to process_frame() must have sRGB gamma applied\n    and process_frame must receive input_is_linear=False with properly\n    gamma-corrected data.\n    \"\"\"\n\n    @given(data=linear_pixel_arrays)\n    @settings(max_examples=30, deadline=None)\n    def test_exr_srgb_frame_is_gamma_corrected(self, data: np.ndarray) -> None:\n        \"\"\"Given an EXR image sequence with input_is_linear=False, the frame\n        passed to process_frame() must NOT be raw linear data — it must have\n        sRGB gamma applied.\n\n        **Validates: Requirements 1.1, 1.2**\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            # Set up a minimal clip structure with one EXR frame\n            input_dir = os.path.join(tmpdir, \"Input\")\n            alpha_dir = os.path.join(tmpdir, \"AlphaHint\")\n            os.makedirs(input_dir)\n            os.makedirs(alpha_dir)\n\n            h, w = data.shape[:2]\n\n            # Write EXR input frame\n            exr_path = os.path.join(input_dir, \"frame_00000.exr\")\n            _write_exr(exr_path, data)\n\n            # Write a matching alpha mask (simple grayscale PNG)\n            mask = np.full((h, w), 128, dtype=np.uint8)\n            cv2.imwrite(os.path.join(alpha_dir, \"frame_00000.png\"), mask)\n\n            # What we expect: the linear data should be gamma-corrected\n            # via piecewise sRGB before reaching process_frame\n       
     expected_srgb = linear_to_srgb(np.maximum(data, 0.0).astype(np.float32))\n\n            # We'll capture what process_frame actually receives by patching it\n            captured_args = {}\n\n            def mock_process_frame(image, mask_linear, *, input_is_linear=False, **kwargs):\n                captured_args[\"image\"] = image.copy()\n                captured_args[\"input_is_linear\"] = input_is_linear\n                # Return minimal valid result\n                h_img, w_img = image.shape[:2]\n                return {\n                    \"alpha\": np.zeros((h_img, w_img, 1), dtype=np.float32),\n                    \"fg\": np.zeros((h_img, w_img, 3), dtype=np.float32),\n                    \"comp\": np.zeros((h_img, w_img, 3), dtype=np.float32),\n                    \"processed\": np.zeros((h_img, w_img, 4), dtype=np.float32),\n                }\n\n            # Build a mock clip that looks like an EXR image sequence\n            mock_clip = MagicMock()\n            mock_clip.name = \"test_clip\"\n            mock_clip.root_path = tmpdir\n            mock_clip.input_asset.type = \"sequence\"\n            mock_clip.input_asset.path = input_dir\n            mock_clip.input_asset.frame_count = 1\n            mock_clip.alpha_asset.type = \"sequence\"\n            mock_clip.alpha_asset.path = alpha_dir\n            mock_clip.alpha_asset.frame_count = 1\n\n            # Create mock settings with input_is_linear=False (sRGB selected)\n            mock_settings = MagicMock()\n            mock_settings.input_is_linear = False\n            mock_settings.despill_strength = 1.0\n            mock_settings.auto_despeckle = False\n            mock_settings.despeckle_size = 400\n            mock_settings.refiner_scale = 1.0\n\n            # Mock the engine\n            mock_engine = MagicMock()\n            mock_engine.process_frame = mock_process_frame\n\n            # Patch create_engine where it's imported from inside run_inference\n            with 
patch(\"CorridorKeyModule.backend.create_engine\", return_value=mock_engine):\n                from clip_manager import run_inference\n\n                run_inference(\n                    [mock_clip],\n                    device=\"cpu\",\n                    max_frames=1,\n                    settings=mock_settings,\n                )\n\n            assert \"image\" in captured_args, \"process_frame was never called — clip setup may be wrong\"\n\n            actual_image = captured_args[\"image\"]\n            actual_is_linear = captured_args[\"input_is_linear\"]\n\n            # Defect 1: The frame should be gamma-corrected (sRGB), not raw linear\n            # On buggy code, img_srgb contains raw linear data\n            np.testing.assert_allclose(\n                actual_image,\n                expected_srgb,\n                atol=1e-5,\n                err_msg=(\n                    \"Frame passed to process_frame() is raw linear data, not \"\n                    \"sRGB gamma-corrected. The EXR branch in run_inference() \"\n                    \"does not apply gamma correction when input_is_linear=False.\"\n                ),\n            )\n\n            # Defect 2: input_is_linear should be False (sRGB semantics)\n            assert actual_is_linear is False, (\n                f\"process_frame received input_is_linear={actual_is_linear}, \"\n                f\"expected False. The EXR data has wrong semantics.\"\n            )\n"
  },
  {
    "path": "tests/test_exr_gamma_preservation.py",
    "content": "\"\"\"Preservation property tests for EXR gamma correction bugfix.\n\nThese tests verify that non-buggy code paths remain unchanged after the fix.\nThey MUST PASS on the current UNFIXED code — they establish the baseline\nbehavior that must be preserved.\n\n**Validates: Requirements 3.1, 3.2, 3.3, 3.4, 3.5**\n\n3.1: Linear EXR path (input_is_linear=True) passes data through unchanged.\n3.2: Standard image path (PNG/JPG/TIFF) reads uint8 normalized to [0,1] float32.\n3.3: Video path continues to decode frames as sRGB regardless of input_is_linear.\n3.4: read_image_frame() with gamma_correct_exr=False returns raw linear EXR data.\n3.5: Linear EXR inputs with input_is_linear=True produce byte-exact identical output.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nimport tempfile\nfrom unittest.mock import MagicMock, patch\n\nimport cv2\nimport numpy as np\nfrom hypothesis import given, settings\nfrom hypothesis import strategies as st\nfrom hypothesis.extra.numpy import arrays\n\nfrom backend.frame_io import read_image_frame\n\n# ---------------------------------------------------------------------------\n# Strategies\n# ---------------------------------------------------------------------------\n\n# Float32 pixel arrays in [0, 1] for EXR data\nlinear_pixel_arrays = arrays(\n    dtype=np.float32,\n    shape=st.tuples(\n        st.integers(min_value=4, max_value=16),\n        st.integers(min_value=4, max_value=16),\n        st.just(3),\n    ),\n    elements=st.floats(min_value=0.0, max_value=1.0, allow_nan=False, allow_infinity=False),\n)\n\n# uint8 pixel arrays for standard image data (PNG/JPG)\nuint8_pixel_arrays = arrays(\n    dtype=np.uint8,\n    shape=st.tuples(\n        st.integers(min_value=4, max_value=16),\n        st.integers(min_value=4, max_value=16),\n        st.just(3),\n    ),\n)\n\n\n# ---------------------------------------------------------------------------\n# Helpers\n# 
---------------------------------------------------------------------------\n\n\ndef _write_exr(path: str, rgb_data: np.ndarray) -> None:\n    \"\"\"Write a float32 RGB array as a BGR EXR file via OpenCV.\"\"\"\n    bgr = cv2.cvtColor(rgb_data.astype(np.float32), cv2.COLOR_RGB2BGR)\n    cv2.imwrite(path, bgr)\n\n\ndef _write_png(path: str, rgb_data: np.ndarray) -> None:\n    \"\"\"Write a uint8 RGB array as a BGR PNG file via OpenCV.\"\"\"\n    bgr = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2BGR)\n    cv2.imwrite(path, bgr)\n\n\n# ---------------------------------------------------------------------------\n# 3.4: gamma_correct_exr=False returns raw linear EXR data (no transformation)\n# ---------------------------------------------------------------------------\n\n\nclass TestPreservationLinearEXRRead:\n    \"\"\"**Validates: Requirements 3.1, 3.4**\n\n    read_image_frame() with gamma_correct_exr=False (the default) returns\n    raw linear EXR data with no gamma transformation applied. The result\n    must be identical to what cv2.imread produces (clamped to >= 0, as float32).\n    \"\"\"\n\n    @given(data=linear_pixel_arrays)\n    @settings(max_examples=50, deadline=None)\n    def test_exr_default_returns_raw_linear_data(self, data: np.ndarray) -> None:\n        \"\"\"For all float32 arrays in [0, 1], read_image_frame with default\n        params (gamma_correct_exr=False) returns data identical to raw\n        cv2.imread output — no transformation applied.\n\n        **Validates: Requirements 3.4**\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            exr_path = os.path.join(tmpdir, \"test.exr\")\n            _write_exr(exr_path, data)\n\n            result = read_image_frame(exr_path)\n            assert result is not None, \"read_image_frame returned None\"\n\n            # Reconstruct what raw cv2.imread would produce\n            raw = cv2.imread(exr_path, cv2.IMREAD_UNCHANGED)\n            if raw.ndim == 3 and raw.shape[2] == 4:\n   
             raw = raw[:, :, :3]\n            raw_rgb = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)\n            expected = np.maximum(raw_rgb, 0.0).astype(np.float32)\n\n            np.testing.assert_array_equal(\n                result,\n                expected,\n                err_msg=(\n                    \"read_image_frame with gamma_correct_exr=False (default) \"\n                    \"should return raw linear EXR data identical to cv2.imread \"\n                    \"output. No transformation should be applied.\"\n                ),\n            )\n\n    @given(data=linear_pixel_arrays)\n    @settings(max_examples=50, deadline=None)\n    def test_exr_explicit_false_returns_raw_linear_data(self, data: np.ndarray) -> None:\n        \"\"\"Explicitly passing gamma_correct_exr=False produces the same\n        result as the default — raw linear data, no transformation.\n\n        **Validates: Requirements 3.1, 3.4**\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            exr_path = os.path.join(tmpdir, \"test.exr\")\n            _write_exr(exr_path, data)\n\n            result_default = read_image_frame(exr_path)\n            result_explicit = read_image_frame(exr_path, gamma_correct_exr=False)\n\n            assert result_default is not None\n            assert result_explicit is not None\n\n            np.testing.assert_array_equal(\n                result_default,\n                result_explicit,\n                err_msg=(\n                    \"read_image_frame(path) and read_image_frame(path, \"\n                    \"gamma_correct_exr=False) should produce identical results.\"\n                ),\n            )\n\n\n# ---------------------------------------------------------------------------\n# 3.2: Standard image path reads uint8 normalized to [0,1] float32\n# ---------------------------------------------------------------------------\n\n\nclass TestPreservationStandardImageRead:\n    \"\"\"**Validates: Requirements 3.2**\n\n    
For PNG/JPG/TIFF with input_is_linear=False, data is read as uint8,\n    divided by 255, and the result is float32 in [0, 1].\n    \"\"\"\n\n    @given(data=uint8_pixel_arrays)\n    @settings(max_examples=50, deadline=None)\n    def test_png_read_returns_uint8_normalized(self, data: np.ndarray) -> None:\n        \"\"\"For all uint8 arrays, standard image read returns\n        array.astype(float32) / 255.0.\n\n        **Validates: Requirements 3.2**\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            png_path = os.path.join(tmpdir, \"test.png\")\n            _write_png(png_path, data)\n\n            result = read_image_frame(png_path)\n            assert result is not None, \"read_image_frame returned None for PNG\"\n\n            # Reconstruct expected: read as uint8 BGR, convert to RGB, / 255\n            raw = cv2.imread(png_path)\n            raw_rgb = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)\n            expected = raw_rgb.astype(np.float32) / 255.0\n\n            np.testing.assert_array_equal(\n                result,\n                expected,\n                err_msg=(\n                    \"Standard image read should return uint8 data normalized to [0,1] float32 via division by 255.\"\n                ),\n            )\n\n    @given(data=uint8_pixel_arrays)\n    @settings(max_examples=50, deadline=None)\n    def test_png_result_dtype_and_range(self, data: np.ndarray) -> None:\n        \"\"\"Standard image read always returns float32 in [0, 1].\n\n        **Validates: Requirements 3.2**\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            png_path = os.path.join(tmpdir, \"test.png\")\n            _write_png(png_path, data)\n\n            result = read_image_frame(png_path)\n            assert result is not None\n\n            assert result.dtype == np.float32, f\"Expected float32 dtype, got {result.dtype}\"\n            assert result.min() >= 0.0, \"Result contains negative values\"\n            
assert result.max() <= 1.0, \"Result contains values > 1.0\"\n\n\n# ---------------------------------------------------------------------------\n# 3.1, 3.5: Linear EXR with input_is_linear=True in run_inference\n# ---------------------------------------------------------------------------\n\n\nclass TestPreservationLinearEXRInference:\n    \"\"\"**Validates: Requirements 3.1, 3.5**\n\n    When input is an EXR image sequence with input_is_linear=True,\n    the data passes through without gamma correction and process_frame()\n    receives input_is_linear=True with raw linear data.\n    \"\"\"\n\n    @given(data=linear_pixel_arrays)\n    @settings(max_examples=30, deadline=None)\n    def test_linear_exr_passes_through_unchanged(self, data: np.ndarray) -> None:\n        \"\"\"For all linear EXR inputs with input_is_linear=True, the data\n        reaching process_frame() is byte-exact identical to the raw EXR\n        read — no gamma correction applied.\n\n        **Validates: Requirements 3.1, 3.5**\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tmpdir:\n            input_dir = os.path.join(tmpdir, \"Input\")\n            alpha_dir = os.path.join(tmpdir, \"AlphaHint\")\n            os.makedirs(input_dir)\n            os.makedirs(alpha_dir)\n\n            h, w = data.shape[:2]\n\n            # Write EXR input frame\n            exr_path = os.path.join(input_dir, \"frame_00000.exr\")\n            _write_exr(exr_path, data)\n\n            # Write a matching alpha mask\n            mask = np.full((h, w), 128, dtype=np.uint8)\n            cv2.imwrite(os.path.join(alpha_dir, \"frame_00000.png\"), mask)\n\n            # Read what the raw EXR data looks like after cv2 round-trip\n            raw = cv2.imread(exr_path, cv2.IMREAD_UNCHANGED)\n            if raw.ndim == 3 and raw.shape[2] == 4:\n                raw = raw[:, :, :3]\n            raw_rgb = cv2.cvtColor(raw, cv2.COLOR_BGR2RGB)\n            expected_linear = np.maximum(raw_rgb, 
0.0).astype(np.float32)\n\n            # Capture what process_frame actually receives\n            captured_args = {}\n\n            def mock_process_frame(image, mask_linear, *, input_is_linear=False, **kwargs):\n                captured_args[\"image\"] = image.copy()\n                captured_args[\"input_is_linear\"] = input_is_linear\n                h_img, w_img = image.shape[:2]\n                return {\n                    \"alpha\": np.zeros((h_img, w_img, 1), dtype=np.float32),\n                    \"fg\": np.zeros((h_img, w_img, 3), dtype=np.float32),\n                    \"comp\": np.zeros((h_img, w_img, 3), dtype=np.float32),\n                    \"processed\": np.zeros((h_img, w_img, 4), dtype=np.float32),\n                }\n\n            mock_clip = MagicMock()\n            mock_clip.name = \"test_clip\"\n            mock_clip.root_path = tmpdir\n            mock_clip.input_asset.type = \"sequence\"\n            mock_clip.input_asset.path = input_dir\n            mock_clip.input_asset.frame_count = 1\n            mock_clip.alpha_asset.type = \"sequence\"\n            mock_clip.alpha_asset.path = alpha_dir\n            mock_clip.alpha_asset.frame_count = 1\n\n            # input_is_linear=True — user confirmed linear EXR\n            mock_settings = MagicMock()\n            mock_settings.input_is_linear = True\n            mock_settings.despill_strength = 1.0\n            mock_settings.auto_despeckle = False\n            mock_settings.despeckle_size = 400\n            mock_settings.refiner_scale = 1.0\n\n            mock_engine = MagicMock()\n            mock_engine.process_frame = mock_process_frame\n\n            with patch(\"CorridorKeyModule.backend.create_engine\", return_value=mock_engine):\n                from clip_manager import run_inference\n\n                run_inference(\n                    [mock_clip],\n                    device=\"cpu\",\n                    max_frames=1,\n                    settings=mock_settings,\n                
)\n\n            assert \"image\" in captured_args, \"process_frame was never called\"\n\n            # The frame should be raw linear data — no gamma correction\n            np.testing.assert_allclose(\n                captured_args[\"image\"],\n                expected_linear,\n                atol=1e-6,\n                err_msg=(\n                    \"With input_is_linear=True, the EXR data reaching \"\n                    \"process_frame() should be raw linear — identical to \"\n                    \"cv2.imread output. No gamma correction should be applied.\"\n                ),\n            )\n\n            # input_is_linear should be True\n            assert captured_args[\"input_is_linear\"] is True, (\n                f\"process_frame received input_is_linear={captured_args['input_is_linear']}, expected True\"\n            )\n"
  },
  {
    "path": "tests/test_frame_io.py",
    "content": "\"\"\"Tests for backend.frame_io — frame reading utilities.\n\nFocuses on input validation and edge cases that don't require real video\nfiles or model weights.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom backend.frame_io import read_video_frame_at, read_video_mask_at\n\n\nclass TestReadVideoFrameAtNegativeIndex:\n    \"\"\"read_video_frame_at must return None for negative frame indices.\"\"\"\n\n    def test_negative_one_returns_none(self, tmp_path):\n        \"\"\"frame_index=-1 must return None without raising.\"\"\"\n        # A valid-looking path is enough — the guard fires before VideoCapture\n        result = read_video_frame_at(str(tmp_path / \"fake.mp4\"), frame_index=-1)\n        assert result is None\n\n    def test_large_negative_returns_none(self, tmp_path):\n        \"\"\"Large negative values must also return None.\"\"\"\n        result = read_video_frame_at(str(tmp_path / \"fake.mp4\"), frame_index=-999)\n        assert result is None\n\n    def test_zero_does_not_trigger_guard(self, tmp_path):\n        \"\"\"frame_index=0 is valid and must not be caught by the negative guard.\n\n        The file doesn't exist so cap.read() fails, returning None via the\n        existing 'ret' check — not the new guard. 
We just confirm no TypeError\n        or unexpected exception is raised.\n        \"\"\"\n        # Should return None (file not found path), not raise\n        result = read_video_frame_at(str(tmp_path / \"fake.mp4\"), frame_index=0)\n        assert result is None\n\n\nclass TestReadVideoMaskAtNegativeIndex:\n    \"\"\"read_video_mask_at must return None for negative frame indices.\"\"\"\n\n    def test_negative_one_returns_none(self, tmp_path):\n        \"\"\"frame_index=-1 must return None without raising.\"\"\"\n        result = read_video_mask_at(str(tmp_path / \"fake.mp4\"), frame_index=-1)\n        assert result is None\n\n    def test_large_negative_returns_none(self, tmp_path):\n        \"\"\"Large negative values must also return None.\"\"\"\n        result = read_video_mask_at(str(tmp_path / \"fake.mp4\"), frame_index=-999)\n        assert result is None\n\n    def test_zero_does_not_trigger_guard(self, tmp_path):\n        \"\"\"frame_index=0 is valid and must not be caught by the negative guard.\"\"\"\n        result = read_video_mask_at(str(tmp_path / \"fake.mp4\"), frame_index=0)\n        assert result is None\n"
  },
  {
    "path": "tests/test_gamma_consistency.py",
    "content": "\"\"\"Tests documenting the gamma 2.2 vs piecewise sRGB inconsistency.\n\nSTATUS: This test documents a KNOWN INCONSISTENCY, not desired behavior.\n\nThe codebase uses two different methods to convert between linear and sRGB:\n\n1. **Piecewise sRGB (correct)** — used by ``color_utils.linear_to_srgb()``\n   and ``color_utils.srgb_to_linear()``, called from ``inference_engine.py``.\n   This follows the IEC 61966-2-1 spec: exponent 2.4 with a linear segment\n   below the breakpoint.\n\n2. **Gamma 2.2 approximation** — used by ``clip_manager.py:383`` (VideoMaMa\n   frame loading) and ``gvm_core/gvm/utils/inference_utils.py:124`` (GVM\n   frame loading).  This uses a simple ``x ** (1/2.2)`` power curve.\n\nThe two methods produce visibly different results, especially in darks.\nAt linear 0.01, the difference is ~4.7% — enough to see in a waveform monitor.\n\n**Why this hasn't been fixed yet:**\nThe gamma 2.2 paths feed data into third-party models (VideoMaMa, GVM) that\nwere likely *trained* on gamma 2.2 converted data.  Switching to piecewise\nsRGB might degrade their output quality.  
Verifying this requires running\ninference with model weights, which isn't feasible in automated tests.\n\n**If you fix one path, fix the other too** — or these tests will tell you\nsomething changed.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport numpy as np\nimport pytest\n\nfrom CorridorKeyModule.core import color_utils as cu\n\n# ---------------------------------------------------------------------------\n# Document the divergence\n# ---------------------------------------------------------------------------\n\n\nclass TestGammaInconsistency:\n    \"\"\"These tests assert that the two conversion methods produce DIFFERENT\n    results — documenting the inconsistency so it isn't accidentally\n    \"fixed\" in only one place.\n    \"\"\"\n\n    def test_linear_to_srgb_differs_from_gamma_22(self):\n        \"\"\"Piecewise sRGB and gamma 2.2 must produce different results.\n\n        If this test fails, someone unified the conversion — which is good!\n        Remove this test and update the docstring above.\n        \"\"\"\n        # Test at a value where the difference is significant\n        linear_val = np.float32(0.1)\n\n        piecewise_srgb = float(cu.linear_to_srgb(linear_val))\n        gamma_22 = float(linear_val ** (1.0 / 2.2))\n\n        # They should NOT be equal\n        assert piecewise_srgb != pytest.approx(gamma_22, abs=1e-4), (\n            \"Piecewise sRGB and gamma 2.2 now match — has the inconsistency been fixed? If so, remove this test.\"\n        )\n\n    def test_srgb_to_linear_differs_from_gamma_22(self):\n        \"\"\"Inverse direction: srgb_to_linear vs x**2.2.\"\"\"\n        srgb_val = np.float32(0.5)\n\n        piecewise_linear = float(cu.srgb_to_linear(srgb_val))\n        gamma_22_linear = float(srgb_val**2.2)\n\n        assert piecewise_linear != pytest.approx(gamma_22_linear, abs=1e-4), (\n            \"Piecewise sRGB and gamma 2.2 inverse now match — has the \"\n            \"inconsistency been fixed? 
If so, remove this test.\"\n        )\n\n\n# ---------------------------------------------------------------------------\n# Quantify the divergence\n# ---------------------------------------------------------------------------\n\n\nclass TestGammaDivergenceMagnitude:\n    \"\"\"Quantify how far apart the two methods are at various values.\n\n    These tests serve as documentation — if someone changes the color math,\n    these show exactly where the drift happens and by how much.\n    \"\"\"\n\n    @pytest.mark.parametrize(\n        \"linear_val,expected_min_diff\",\n        [\n            # In darks, the piecewise linear segment causes the biggest gap.\n            # At linear 0.001, piecewise sRGB uses the linear segment (x*12.92)\n            # while gamma 2.2 uses x^(1/2.2) — very different behavior.\n            (0.001, 0.005),\n            # Mid-darks: still a measurable gap\n            (0.01, 0.01),\n            # Mid-tones: smaller but still measurable\n            (0.1, 0.001),\n            # Highlights: converges as both curves approach 1.0\n            (0.5, 0.001),\n        ],\n    )\n    def test_divergence_at_known_values(self, linear_val, expected_min_diff):\n        \"\"\"The two methods should differ by at least expected_min_diff.\"\"\"\n        x = np.float32(linear_val)\n        piecewise = float(cu.linear_to_srgb(x))\n        gamma_22 = float(x ** (1.0 / 2.2))\n\n        diff = abs(piecewise - gamma_22)\n        assert diff >= expected_min_diff, (\n            f\"At linear={linear_val}: piecewise={piecewise:.6f}, \"\n            f\"gamma2.2={gamma_22:.6f}, diff={diff:.6f} \"\n            f\"(expected >= {expected_min_diff})\"\n        )\n\n    def test_both_methods_agree_at_zero_and_one(self):\n        \"\"\"At the endpoints 0.0 and 1.0, both methods agree exactly.\"\"\"\n        for val in [0.0, 1.0]:\n            x = np.float32(val)\n            piecewise = float(cu.linear_to_srgb(x))\n            gamma_22 = float(x ** (1.0 / 2.2))\n          
  assert piecewise == pytest.approx(gamma_22, abs=1e-6), f\"At {val}, both methods should agree\"\n\n    def test_worst_case_divergence_in_darks(self):\n        \"\"\"Document the worst-case divergence across the 0-1 range.\n\n        This is informational — the exact value may shift if someone tweaks\n        tolerances, but the magnitude should stay in the ballpark.\n        \"\"\"\n        values = np.linspace(0.0, 1.0, 1000, dtype=np.float32)\n        piecewise = cu.linear_to_srgb(values).astype(np.float64)\n        gamma_22 = (values.astype(np.float64)) ** (1.0 / 2.2)\n\n        max_diff = float(np.max(np.abs(piecewise - gamma_22)))\n\n        # The worst-case difference should be in the low-mid range (~0.01-0.04)\n        # and definitely not zero (that would mean the inconsistency is gone)\n        assert max_diff > 0.01, \"Expected significant divergence in darks\"\n        assert max_diff < 0.10, \"Divergence larger than expected — check math\"\n"
  },
  {
    "path": "tests/test_imports.py",
    "content": "\"\"\"Smoke tests: verify all CorridorKey packages import without error.\n\nThese catch missing __init__.py files, broken relative imports, and\nmissing dependencies before any real logic runs.\n\"\"\"\n\n\ndef test_import_corridorkey_module():\n    import CorridorKeyModule  # noqa: F401\n\n\ndef test_import_color_utils():\n    from CorridorKeyModule.core import color_utils  # noqa: F401\n\n\ndef test_import_inference_engine():\n    from CorridorKeyModule import inference_engine  # noqa: F401\n\n\ndef test_import_model_transformer():\n    from CorridorKeyModule.core import model_transformer  # noqa: F401\n\n\ndef test_import_gvm_core():\n    import gvm_core  # noqa: F401\n\n\ndef test_import_gvm_wrapper():\n    from gvm_core import wrapper  # noqa: F401\n\n\ndef test_import_videomama():\n    import VideoMaMaInferenceModule  # noqa: F401\n\n\ndef test_import_videomama_inference():\n    from VideoMaMaInferenceModule import inference  # noqa: F401\n"
  },
  {
    "path": "tests/test_inference_engine.py",
    "content": "\"\"\"Tests for CorridorKeyModule.inference_engine.CorridorKeyEngine.process_frame.\n\nThese tests mock the GreenFormer model so they run without GPU or model\nweights. They verify the pre-processing (resize, normalize, color space\nconversion) and post-processing (upscale, despill, premultiply, composite)\npipeline that wraps the neural network.\n\nWhy mock the model?\n  The model requires a ~500MB checkpoint and CUDA. The pre/post-processing\n  pipeline is where compositing bugs hide (wrong color space, premul errors,\n  alpha dimension mismatches). Mocking the model isolates that logic.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport numpy as np\nimport pytest\nimport torch\n\nfrom CorridorKeyModule.core import color_utils as cu\n\n# ---------------------------------------------------------------------------\n# Helpers\n# ---------------------------------------------------------------------------\n\n\ndef _make_engine_with_mock(mock_greenformer, img_size=64):\n    \"\"\"Create a CorridorKeyEngine with a mocked model, bypassing __init__.\n\n    Manually sets the attributes that __init__ would create, avoiding the\n    need for checkpoint files or GPU.\n    \"\"\"\n    from CorridorKeyModule.inference_engine import CorridorKeyEngine\n\n    engine = object.__new__(CorridorKeyEngine)\n    engine.device = torch.device(\"cpu\")\n    engine.img_size = img_size\n    engine.checkpoint_path = \"/fake/checkpoint.pth\"\n    engine.use_refiner = False\n    engine.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3)\n    engine.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3)\n    engine.model = mock_greenformer\n    engine.model_precision = torch.float32\n    engine.mixed_precision = True\n    return engine\n\n\n# ---------------------------------------------------------------------------\n# process_frame output structure\n# 
---------------------------------------------------------------------------\n\n\nclass TestProcessFrameOutputs:\n    \"\"\"Verify shape, dtype, and key presence of process_frame outputs.\"\"\"\n\n    def test_output_keys(self, sample_frame_rgb, sample_mask, mock_greenformer):\n        \"\"\"process_frame must return alpha, fg, comp, and processed.\"\"\"\n        engine = _make_engine_with_mock(mock_greenformer)\n        result = engine.process_frame(sample_frame_rgb, sample_mask)\n\n        assert \"alpha\" in result\n        assert \"fg\" in result\n        assert \"comp\" in result\n        assert \"processed\" in result\n\n    def test_output_shapes_match_input(self, sample_frame_rgb, sample_mask, mock_greenformer):\n        \"\"\"All outputs should match the spatial dimensions of the input.\"\"\"\n        h, w = sample_frame_rgb.shape[:2]\n        engine = _make_engine_with_mock(mock_greenformer)\n        result = engine.process_frame(sample_frame_rgb, sample_mask)\n\n        assert result[\"alpha\"].shape[:2] == (h, w)\n        assert result[\"fg\"].shape[:2] == (h, w)\n        assert result[\"comp\"].shape == (h, w, 3)\n        assert result[\"processed\"].shape == (h, w, 4)\n\n    def test_output_dtype_float32(self, sample_frame_rgb, sample_mask, mock_greenformer):\n        \"\"\"All outputs should be float32 numpy arrays.\"\"\"\n        engine = _make_engine_with_mock(mock_greenformer)\n        result = engine.process_frame(sample_frame_rgb, sample_mask)\n\n        for key in (\"alpha\", \"fg\", \"comp\", \"processed\"):\n            assert result[key].dtype == np.float32, f\"{key} should be float32\"\n\n    def test_alpha_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer):\n        \"\"\"Alpha output must be in [0, 1] — values outside this range corrupt compositing.\"\"\"\n        engine = _make_engine_with_mock(mock_greenformer)\n        result = engine.process_frame(sample_frame_rgb, sample_mask)\n        alpha = 
result[\"alpha\"]\n        assert alpha.min() >= -0.01, f\"alpha min {alpha.min():.4f} is below 0\"\n        assert alpha.max() <= 1.01, f\"alpha max {alpha.max():.4f} is above 1\"\n\n    def test_fg_output_range_is_zero_to_one(self, sample_frame_rgb, sample_mask, mock_greenformer):\n        \"\"\"FG output must be in [0, 1] — required for downstream sRGB conversion and EXR export.\"\"\"\n        engine = _make_engine_with_mock(mock_greenformer)\n        result = engine.process_frame(sample_frame_rgb, sample_mask)\n        fg = result[\"fg\"]\n        assert fg.min() >= -0.01, f\"fg min {fg.min():.4f} is below 0\"\n        assert fg.max() <= 1.01, f\"fg max {fg.max():.4f} is above 1\"\n\n\n# ---------------------------------------------------------------------------\n# Input color space handling\n# ---------------------------------------------------------------------------\n\n\nclass TestProcessFrameColorSpace:\n    \"\"\"Verify the sRGB vs linear input paths.\n\n    When input_is_linear=True, process_frame resizes in linear space then\n    converts to sRGB before feeding the model (preserving HDR highlight detail).\n    When False (default), it resizes in sRGB directly.\n    \"\"\"\n\n    def test_srgb_input_default(self, sample_frame_rgb, sample_mask, mock_greenformer):\n        \"\"\"Default sRGB path should not crash and should return valid outputs.\"\"\"\n        engine = _make_engine_with_mock(mock_greenformer)\n        result = engine.process_frame(sample_frame_rgb, sample_mask, input_is_linear=False)\n\n        np.testing.assert_allclose(result[\"comp\"], 0.545655, atol=1e-4)\n\n    def test_linear_input_path(self, sample_frame_rgb, sample_mask, mock_greenformer):\n        \"\"\"Linear input path should convert to sRGB before model input.\"\"\"\n        engine = _make_engine_with_mock(mock_greenformer)\n        result = engine.process_frame(sample_frame_rgb, sample_mask, input_is_linear=True)\n        assert result[\"comp\"].shape == 
sample_frame_rgb.shape\n\n    def test_uint8_input_normalized(self, sample_mask, mock_greenformer):\n        \"\"\"uint8 input should be auto-converted to float32 [0, 1].\"\"\"\n        img_uint8 = np.random.default_rng(42).integers(0, 256, (64, 64, 3), dtype=np.uint8)\n        engine = _make_engine_with_mock(mock_greenformer)\n        # Should not crash — uint8 is auto-normalized to float32\n        result = engine.process_frame(img_uint8, sample_mask)\n        assert result[\"alpha\"].dtype == np.float32\n\n    def test_model_called_exactly_once(self, sample_frame_rgb, sample_mask, mock_greenformer):\n        \"\"\"The neural network model must be called exactly once per process_frame() call.\n\n        Double-inference would double latency and produce incorrect outputs.\n        \"\"\"\n        engine = _make_engine_with_mock(mock_greenformer)\n        engine.process_frame(sample_frame_rgb, sample_mask)\n        assert mock_greenformer.call_count == 1\n\n\n# ---------------------------------------------------------------------------\n# Post-processing pipeline\n# ---------------------------------------------------------------------------\n\n\nclass TestProcessFramePostProcessing:\n    \"\"\"Verify post-processing: despill, despeckle, premultiply, composite.\"\"\"\n\n    def test_despill_strength_reduces_green_in_spill_pixels(self, sample_frame_rgb, sample_mask):\n        \"\"\"despill_strength=1.0 must reduce green in spill pixels; strength=0.0 must leave it unchanged.\n\n        The default mock_greenformer returns uniform gray (R=G=B=0.6) which has no\n        green spill by definition: limit=(R+B)/2=0.6=G so spill_amount=0 always.\n        This test uses a green-heavy fg mock (R=0.2, G=0.8, B=0.2) to force\n        spill_amount > 0 and verify the despill path actually runs and reduces green.\n        \"\"\"\n        from unittest.mock import MagicMock\n\n        def green_heavy_forward(x):\n            b, c, h, w = x.shape\n            fg = torch.zeros(b, 3, 
h, w, dtype=torch.float32)\n            fg[:, 0, :, :] = 0.2  # R\n            fg[:, 1, :, :] = 0.8  # G — heavy green spill: G >> (R+B)/2\n            fg[:, 2, :, :] = 0.2  # B\n            return {\n                \"alpha\": torch.full((b, 1, h, w), 0.8, dtype=torch.float32),\n                \"fg\": fg,\n            }\n\n        green_mock = MagicMock()\n        green_mock.side_effect = green_heavy_forward\n        green_mock.refiner = None\n        green_mock.use_refiner = False\n\n        engine = _make_engine_with_mock(green_mock)\n        result_no_despill = engine.process_frame(sample_frame_rgb, sample_mask, despill_strength=0.0)\n        result_full_despill = engine.process_frame(sample_frame_rgb, sample_mask, despill_strength=1.0)\n\n        rgb_none = result_no_despill[\"processed\"][:, :, :3]\n        rgb_full = result_full_despill[\"processed\"][:, :, :3]\n\n        # Both outputs must be valid shapes and in-range\n        assert rgb_none.shape == rgb_full.shape\n        assert rgb_none.min() >= 0.0\n        assert rgb_full.min() >= 0.0\n\n        # Green channel must be reduced by despill (spill_amount > 0 is guaranteed by construction)\n        assert rgb_full[:, :, 1].mean() < rgb_none[:, :, 1].mean(), (\n            \"despill_strength=1.0 should reduce the green channel relative to strength=0.0 when G > (R+B)/2\"\n        )\n\n    def test_auto_despeckle_toggle(self, sample_frame_rgb, sample_mask, mock_greenformer):\n        \"\"\"auto_despeckle=False should skip clean_matte without crashing.\"\"\"\n        engine = _make_engine_with_mock(mock_greenformer)\n        result = engine.process_frame(sample_frame_rgb, sample_mask, auto_despeckle=False)\n        assert result[\"alpha\"].shape[:2] == sample_frame_rgb.shape[:2]\n\n    def test_processed_is_linear_premul_rgba(self, sample_frame_rgb, sample_mask, mock_greenformer):\n        \"\"\"The 'processed' output should be 4-channel RGBA (linear, premultiplied).\n\n        This is the EXR-ready output 
that compositors load into Nuke for\n        an Over operation. The RGB channels should be <= alpha (premultiplied\n        means color is already multiplied by alpha).\n        \"\"\"\n        engine = _make_engine_with_mock(mock_greenformer)\n        result = engine.process_frame(sample_frame_rgb, sample_mask)\n        processed = result[\"processed\"]\n        assert processed.shape[2] == 4\n\n        rgb = processed[:, :, :3]\n        alpha = processed[:, :, 3:4]\n        # Use srgb_to_linear rather than the gamma 2.2 approximation (x**2.2).\n        # LLM_HANDOVER.md Bug History: \"Do not apply a pure mathematical Gamma 2.2\n        # curve; use the piecewise real sRGB transfer functions defined in color_utils.py.\"\n        # The difference between the two at FG=0.6 is ~0.005, which the previous\n        # atol=1e-2 was too loose to catch — a gamma 2.2 regression would have passed.\n        expected_premul = cu.srgb_to_linear(np.float32(0.6)) * 0.8\n        np.testing.assert_allclose(alpha, 0.8, atol=1e-5)\n        np.testing.assert_allclose(rgb, expected_premul, atol=1e-4)\n\n    def test_mask_2d_vs_3d_input(self, sample_frame_rgb, mock_greenformer):\n        \"\"\"process_frame should accept both [H, W] and [H, W, 1] masks.\"\"\"\n        engine = _make_engine_with_mock(mock_greenformer)\n        mask_2d = np.ones((64, 64), dtype=np.float32) * 0.5\n        mask_3d = mask_2d[:, :, np.newaxis]\n\n        result_2d = engine.process_frame(sample_frame_rgb, mask_2d)\n        result_3d = engine.process_frame(sample_frame_rgb, mask_3d)\n\n        # Both should produce the same output\n        np.testing.assert_allclose(result_2d[\"alpha\"], result_3d[\"alpha\"], atol=1e-5)\n\n    def test_refiner_scale_parameter_accepted(self, sample_frame_rgb, sample_mask, mock_greenformer):\n        \"\"\"Non-default refiner_scale must not raise — the parameter must be threaded through.\"\"\"\n        engine = _make_engine_with_mock(mock_greenformer)\n        result = 
engine.process_frame(sample_frame_rgb, sample_mask, refiner_scale=0.5)\n        assert result[\"alpha\"].shape[:2] == sample_frame_rgb.shape[:2]\n\n\n# ---------------------------------------------------------------------------\n# NVIDIA Specific GPU test\n# ---------------------------------------------------------------------------\n\n\nclass TestNvidiaGPUProcess:\n    @pytest.mark.gpu\n    def test_process_frame_on_gpu(self, sample_frame_rgb, sample_mask, mock_greenformer):\n        \"\"\"\n        Scenario: Process a frame using a CUDA-configured engine.\n        Expected: Input tensors are moved to CUDA before the model is called,\n        confirmed by asserting the device of the tensor the mock received.\n        \"\"\"\n        if not torch.cuda.is_available():\n            pytest.skip(\"CUDA not available\")\n\n        captured_device: list[torch.device] = []\n        original_side_effect = mock_greenformer.side_effect\n\n        def spy_forward(x):\n            captured_device.append(x.device)\n            return original_side_effect(x)\n\n        mock_greenformer.side_effect = spy_forward\n\n        engine = _make_engine_with_mock(mock_greenformer)\n        engine.device = torch.device(\"cuda\")\n\n        result = engine.process_frame(sample_frame_rgb, sample_mask)\n        assert result[\"alpha\"].dtype == np.float32\n        assert len(captured_device) == 1, \"Model should be called exactly once\"\n        assert captured_device[0].type == \"cuda\", f\"Expected model input on cuda, got {captured_device[0]}\"\n"
  },
  {
    "path": "tests/test_mlx_smoke.py",
    "content": "\"\"\"MLX integration smoke test — requires Apple Silicon + corridorkey_mlx.\"\"\"\n\nimport numpy as np\nimport pytest\n\npytestmark = [pytest.mark.mlx, pytest.mark.slow]\n\n\n@pytest.fixture\ndef mlx_engine():\n    \"\"\"Load MLX engine via create_engine at 2048.\"\"\"\n    from CorridorKeyModule.backend import create_engine\n\n    return create_engine(backend=\"mlx\", img_size=2048)\n\n\ndef test_mlx_smoke_2048(mlx_engine):\n    \"\"\"Process one synthetic frame and verify output contract.\"\"\"\n    h, w = 2048, 2048\n\n    # Solid green image + white mask\n    image = np.zeros((h, w, 3), dtype=np.float32)\n    image[:, :, 1] = 1.0  # green channel\n    mask = np.ones((h, w, 1), dtype=np.float32)\n\n    result = mlx_engine.process_frame(image, mask)\n\n    # Keys\n    assert set(result.keys()) == {\"alpha\", \"fg\", \"comp\", \"processed\"}\n\n    # Shapes\n    assert result[\"alpha\"].shape == (h, w, 1), f\"alpha shape: {result['alpha'].shape}\"\n    assert result[\"fg\"].shape == (h, w, 3), f\"fg shape: {result['fg'].shape}\"\n    assert result[\"comp\"].shape == (h, w, 3), f\"comp shape: {result['comp'].shape}\"\n    assert result[\"processed\"].shape == (h, w, 4), f\"processed shape: {result['processed'].shape}\"\n\n    # Dtypes\n    for key in (\"alpha\", \"fg\", \"comp\", \"processed\"):\n        assert result[key].dtype == np.float32, f\"{key} dtype: {result[key].dtype}\"\n\n    # Value ranges (0-1 for alpha/fg; comp/processed may slightly exceed due to sRGB conversion)\n    for key in (\"alpha\", \"fg\"):\n        assert result[key].min() >= 0.0, f\"{key} has negative values\"\n        assert result[key].max() <= 1.0, f\"{key} exceeds 1.0\"\n    for key in (\"comp\", \"processed\"):\n        assert result[key].min() >= 0.0, f\"{key} has negative values\"\n"
  },
  {
    "path": "tests/test_pbt_auto_download.py",
    "content": "\"\"\"Feature: auto-model-download — Property-based tests for auto checkpoint download.\n\nProperties tested:\n  1: Missing checkpoint triggers download and returns valid path.\n  2: Existing checkpoint skips download.\n  3: Auto-download is Torch-only.\n  4: Network errors produce actionable messages.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport tempfile\nfrom pathlib import Path\nfrom unittest import mock\n\nimport pytest\nfrom hypothesis import given, settings\nfrom hypothesis import strategies as st\n\nfrom CorridorKeyModule.backend import (\n    HF_CHECKPOINT_FILENAME,\n    HF_REPO_ID,\n    TORCH_EXT,\n    _discover_checkpoint,\n    _ensure_torch_checkpoint,\n)\n\n# ---------------------------------------------------------------------------\n# Strategies\n# ---------------------------------------------------------------------------\n\n# File extensions that are NOT .pth — used to populate \"non-empty but no .pth\" dirs\n_non_pth_extensions = st.sampled_from(\n    [\n        \".txt\",\n        \".json\",\n        \".safetensors\",\n        \".bin\",\n        \".onnx\",\n        \".csv\",\n        \".log\",\n        \".yaml\",\n    ]\n)\n\n# Strategy: list of non-.pth filenames to place in the checkpoint dir\n_junk_filenames = st.lists(\n    st.tuples(\n        st.text(\n            alphabet=st.characters(whitelist_categories=(\"L\", \"N\"), whitelist_characters=\"_-\"),\n            min_size=1,\n            max_size=12,\n        ),\n        _non_pth_extensions,\n    ).map(lambda t: f\"{t[0]}{t[1]}\"),\n    min_size=0,\n    max_size=5,\n)\n\n\n# ---------------------------------------------------------------------------\n# Property 1: Missing checkpoint triggers download and returns valid path\n# ---------------------------------------------------------------------------\n\n\nclass TestMissingCheckpointTriggersDownload:\n    \"\"\"Property 1: For any empty checkpoint directory (no .pth files),\n    calling 
_discover_checkpoint(TORCH_EXT) invokes hf_hub_download with\n    the correct repo ID and filename, copies the result to\n    CHECKPOINT_DIR/CorridorKey.pth, and returns a Path that exists on disk.\n\n    Feature: auto-model-download, Property 1: Missing checkpoint triggers download and returns valid path\n\n    **Validates: Requirements 1.1, 1.2, 4.1, 4.2**\n    \"\"\"\n\n    @settings(max_examples=100)\n    @given(junk_files=_junk_filenames)\n    def test_missing_pth_triggers_download_and_returns_valid_path(\n        self,\n        junk_files: list[str],\n    ) -> None:\n        \"\"\"Feature: auto-model-download, Property 1: Missing checkpoint triggers download and returns valid path\n\n        **Validates: Requirements 1.1, 1.2, 4.1, 4.2**\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tmp:\n            ckpt_dir = Path(tmp) / \"checkpoints\"\n            ckpt_dir.mkdir()\n\n            # Populate with non-.pth junk files (may be empty list)\n            for fname in junk_files:\n                (ckpt_dir / fname).touch()\n\n            # Prepare a fake cached file that hf_hub_download would return\n            cache_dir = Path(tmp) / \"hf_cache\"\n            cache_dir.mkdir()\n            cached_file = cache_dir / HF_CHECKPOINT_FILENAME\n            cached_file.write_bytes(b\"fake-checkpoint-bytes\")\n\n            with (\n                mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(ckpt_dir)),\n                mock.patch(\n                    \"huggingface_hub.hf_hub_download\",\n                    return_value=str(cached_file),\n                ) as mock_dl,\n            ):\n                result = _discover_checkpoint(TORCH_EXT)\n\n                # The returned path must point to CHECKPOINT_DIR/CorridorKey.pth\n                expected = ckpt_dir / HF_CHECKPOINT_FILENAME\n                assert result == expected, f\"Expected {expected}, got {result}\"\n\n                # The file must actually exist on disk\n             
   assert result.exists(), f\"Returned path does not exist: {result}\"\n\n                # hf_hub_download must have been called with correct args\n                mock_dl.assert_called_once_with(\n                    repo_id=HF_REPO_ID,\n                    filename=HF_CHECKPOINT_FILENAME,\n                )\n\n\n# ---------------------------------------------------------------------------\n# Strategies for Property 2\n# ---------------------------------------------------------------------------\n\n# Strategy: valid .pth filenames (alphanumeric + underscore/dash, non-empty)\n_pth_basenames = st.text(\n    alphabet=st.characters(whitelist_categories=(\"L\", \"N\"), whitelist_characters=\"_-\"),\n    min_size=1,\n    max_size=20,\n).map(lambda s: f\"{s}.pth\")\n\n\n# ---------------------------------------------------------------------------\n# Property 2: Existing checkpoint skips download\n# ---------------------------------------------------------------------------\n\n\nclass TestExistingCheckpointSkipsDownload:\n    \"\"\"Property 2: For any checkpoint directory that already contains a .pth\n    file, calling _discover_checkpoint(TORCH_EXT) returns the existing file's\n    path without invoking hf_hub_download.\n\n    Feature: auto-model-download, Property 2: Existing checkpoint skips download\n\n    **Validates: Requirements 1.3**\n    \"\"\"\n\n    @settings(max_examples=100)\n    @given(pth_name=_pth_basenames)\n    def test_existing_pth_skips_download(self, pth_name: str) -> None:\n        \"\"\"Feature: auto-model-download, Property 2: Existing checkpoint skips download\n\n        **Validates: Requirements 1.3**\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tmp:\n            ckpt_dir = Path(tmp) / \"checkpoints\"\n            ckpt_dir.mkdir()\n\n            # Place exactly one .pth file in the checkpoint directory\n            existing_file = ckpt_dir / pth_name\n            existing_file.write_bytes(b\"fake-checkpoint-data\")\n\n          
  with (\n                mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(ckpt_dir)),\n                mock.patch(\n                    \"huggingface_hub.hf_hub_download\",\n                ) as mock_dl,\n            ):\n                result = _discover_checkpoint(TORCH_EXT)\n\n                # hf_hub_download must NOT have been called\n                mock_dl.assert_not_called()\n\n                # The returned path must match the existing file\n                assert result == existing_file, f\"Expected {existing_file}, got {result}\"\n\n\n# ---------------------------------------------------------------------------\n# Property 3: Auto-download is Torch-only\n# ---------------------------------------------------------------------------\n\n\nclass TestAutoDownloadIsTorchOnly:\n    \"\"\"Property 3: For any extension that is not TORCH_EXT, calling\n    _discover_checkpoint(ext) with zero matches raises FileNotFoundError\n    without invoking hf_hub_download.\n\n    Feature: auto-model-download, Property 3: Auto-download is Torch-only\n\n    **Validates: Requirements 1.4, 4.3**\n    \"\"\"\n\n    @settings(max_examples=100)\n    @given(ext=_non_pth_extensions.map(lambda e: e if e.startswith(\".\") else f\".{e}\"))\n    def test_non_pth_extension_raises_without_download(self, ext: str) -> None:\n        \"\"\"Feature: auto-model-download, Property 3: Auto-download is Torch-only\n\n        **Validates: Requirements 1.4, 4.3**\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tmp:\n            ckpt_dir = Path(tmp) / \"checkpoints\"\n            ckpt_dir.mkdir()\n\n            with (\n                mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(ckpt_dir)),\n                mock.patch(\n                    \"huggingface_hub.hf_hub_download\",\n                ) as mock_dl,\n            ):\n                with pytest.raises(FileNotFoundError):\n                    _discover_checkpoint(ext)\n\n                # hf_hub_download 
must NOT have been called\n                mock_dl.assert_not_called()\n\n\n# ---------------------------------------------------------------------------\n# Strategies for Property 4\n# ---------------------------------------------------------------------------\n\n\ndef _make_hf_hub_http_error(message: str) -> Exception:\n    \"\"\"Create an HfHubHTTPError with a mock response object.\"\"\"\n    import requests\n    from huggingface_hub.utils import HfHubHTTPError\n\n    response = requests.Response()\n    response.status_code = 503\n    return HfHubHTTPError(message, response=response)\n\n\n# Network-related exception factories: each takes a message and returns an exception\n_network_exception_factories = [\n    lambda msg: ConnectionError(msg),\n    lambda msg: TimeoutError(msg),\n    lambda msg: _make_hf_hub_http_error(msg),\n]\n\n_network_exception_strategy = st.tuples(\n    st.sampled_from(_network_exception_factories),\n    st.text(min_size=1, max_size=50),\n).map(lambda t: t[0](t[1]))\n\n\n# ---------------------------------------------------------------------------\n# Property 4: Network errors produce actionable messages\n# ---------------------------------------------------------------------------\n\n\nclass TestNetworkErrorsProduceActionableMessages:\n    \"\"\"Property 4: For any network-related exception raised by\n    hf_hub_download, _ensure_torch_checkpoint() raises a RuntimeError\n    whose message contains both the HuggingFace repository URL and a\n    connectivity troubleshooting hint.\n\n    Feature: auto-model-download, Property 4: Network errors produce actionable messages\n\n    **Validates: Requirements 3.1**\n    \"\"\"\n\n    @settings(max_examples=100)\n    @given(exc=_network_exception_strategy)\n    def test_network_errors_produce_actionable_messages(\n        self,\n        exc: Exception,\n    ) -> None:\n        \"\"\"Feature: auto-model-download, Property 4: Network errors produce actionable messages\n\n        **Validates: 
Requirements 3.1**\n        \"\"\"\n        with tempfile.TemporaryDirectory() as tmp:\n            ckpt_dir = Path(tmp) / \"checkpoints\"\n            ckpt_dir.mkdir()\n\n            with (\n                mock.patch(\"CorridorKeyModule.backend.CHECKPOINT_DIR\", str(ckpt_dir)),\n                mock.patch(\n                    \"huggingface_hub.hf_hub_download\",\n                    side_effect=exc,\n                ),\n            ):\n                with pytest.raises(RuntimeError) as exc_info:\n                    _ensure_torch_checkpoint()\n\n                error_msg = str(exc_info.value)\n\n                # Must contain the HuggingFace repo URL\n                expected_url = f\"https://huggingface.co/{HF_REPO_ID}\"\n                assert expected_url in error_msg, (\n                    f\"Error message missing HF repo URL.\\nExpected URL: {expected_url}\\nGot message: {error_msg}\"\n                )\n\n                # Must contain the connectivity hint\n                expected_hint = \"Check your network connection and try again\"\n                assert expected_hint in error_msg, (\n                    f\"Error message missing connectivity hint.\\n\"\n                    f\"Expected hint: {expected_hint}\\n\"\n                    f\"Got message: {error_msg}\"\n                )\n"
  },
  {
    "path": "tests/test_pbt_backend_resolution.py",
    "content": "\"\"\"Property-based test for backend resolution priority chain.\n\n# Feature: uv-lock-drift-fix, Property 1: Backend resolution priority chain\n\nValidates: Requirements 5.1\n\"\"\"\n\nfrom __future__ import annotations\n\nimport os\nfrom unittest import mock\n\nfrom hypothesis import given, settings\nfrom hypothesis import strategies as st\n\nfrom CorridorKeyModule.backend import BACKEND_ENV_VAR, resolve_backend\n\n# --- Strategies ---\n\n# CLI --backend values: None means not provided, \"auto\" means explicit auto\ncli_backend = st.sampled_from([None, \"auto\", \"torch\", \"mlx\"])\n\n# Environment variable: None means unset, otherwise a string value\nenv_backend = st.sampled_from([None, \"auto\", \"torch\", \"mlx\"])\n\n# Auto-detection result (what _auto_detect_backend would return)\nauto_detect_result = st.sampled_from([\"torch\", \"mlx\"])\n\n\ndef _expected_backend(cli: str | None, env: str | None, auto: str) -> str:\n    \"\"\"Reference implementation of the priority chain.\n\n    1. CLI value if explicit and not \"auto\"\n    2. Env var if set and not \"auto\"\n    3. Auto-detection result\n    \"\"\"\n    if cli is not None and cli != \"auto\":\n        return cli\n    if env is not None and env != \"auto\":\n        return env\n    return auto\n\n\n@settings(max_examples=200)\n@given(cli=cli_backend, env=env_backend, auto=auto_detect_result)\ndef test_backend_resolution_priority_chain(cli: str | None, env: str | None, auto: str) -> None:\n    \"\"\"For any combination of CLI flag, env var, and auto-detection result,\n    resolve_backend() returns the CLI value if explicit and not \"auto\",\n    else the env var if set and not \"auto\", else the auto-detection result.\n\n    **Validates: Requirements 5.1**\n    \"\"\"\n    expected = _expected_backend(cli, env, auto)\n\n    # Build the environment: set or unset CORRIDORKEY_BACKEND\n    env_dict = {BACKEND_ENV_VAR: env} if env is not None else {}\n\n    # We need to mock:\n    # 1. 
The environment variable\n    # 2. _auto_detect_backend() to return our generated auto value\n    # 3. _validate_mlx_available() to avoid platform checks when mlx is selected\n    with (\n        mock.patch.dict(os.environ, env_dict, clear=False),\n        mock.patch(\"CorridorKeyModule.backend._auto_detect_backend\", return_value=auto),\n        mock.patch(\"CorridorKeyModule.backend._validate_mlx_available\"),\n    ):\n        # Ensure env var is unset when env is None\n        if env is None:\n            os.environ.pop(BACKEND_ENV_VAR, None)\n\n        result = resolve_backend(cli)\n\n    assert result == expected, f\"cli={cli!r}, env={env!r}, auto={auto!r} → expected {expected!r}, got {result!r}\"\n"
  },
  {
    "path": "tests/test_pbt_dep_preservation.py",
    "content": "\"\"\"Property-based test for non-torch dependency preservation.\n\n# Feature: uv-lock-drift-fix, Property 2: Non-torch dependency preservation\n\nValidates: Requirements 6.1, 6.2, 6.3\n\"\"\"\n\nfrom __future__ import annotations\n\nimport sys\nfrom pathlib import Path\n\nimport pytest\n\nif sys.version_info >= (3, 11):\n    import tomllib\nelse:\n    try:\n        import tomli as tomllib  # type: ignore[no-redef]\n    except ImportError:\n        pytest.skip(\"tomli required for Python < 3.11\", allow_module_level=True)\n\nfrom hypothesis import given, settings\nfrom hypothesis import strategies as st\n\nPYPROJECT_PATH = Path(__file__).resolve().parents[1] / \"pyproject.toml\"\n\n# Known original non-torch, non-torchvision base dependencies from the design doc.\n# These are the dependencies that MUST be preserved after the restructuring.\nKNOWN_ORIGINAL_NON_TORCH_DEPS: list[str] = [\n    \"timm==1.0.24\",\n    \"numpy\",\n    \"opencv-python\",\n    \"tqdm\",\n    \"setuptools\",\n    \"diffusers\",\n    \"transformers\",\n    \"accelerate\",\n    \"peft\",\n    \"av\",\n    \"Pillow\",\n    \"PIMS\",\n    \"easydict\",\n    \"imageio\",\n    \"matplotlib\",\n    \"einops\",\n    \"huggingface-hub\",\n    \"typer>=0.12\",\n    \"rich>=13\",\n]\n\n\ndef _parse_base_dependencies() -> list[str]:\n    \"\"\"Parse the current pyproject.toml and return the base dependencies list.\"\"\"\n    with open(PYPROJECT_PATH, \"rb\") as f:\n        data = tomllib.load(f)\n    return data[\"project\"][\"dependencies\"]\n\n\ndef _normalize_dep(dep: str) -> str:\n    \"\"\"Normalize a dependency string for comparison.\n\n    Lowercases the package name portion (before any version specifier or marker),\n    strips whitespace, and normalizes underscores to hyphens in the package name.\n    \"\"\"\n    return dep.strip().lower().replace(\"_\", \"-\")\n\n\ndef _dep_present(dep: str, dep_list: list[str]) -> bool:\n    \"\"\"Check if a dependency string is present in a 
list, using normalized comparison.\"\"\"\n    normalized = _normalize_dep(dep)\n    return any(_normalize_dep(d) == normalized for d in dep_list)\n\n\n# --- Property-based test ---\n\n\n@settings(max_examples=200)\n@given(dep=st.sampled_from(KNOWN_ORIGINAL_NON_TORCH_DEPS))\ndef test_non_torch_dependency_preserved(dep: str) -> None:\n    \"\"\"For any non-torch, non-torchvision dependency from the original base\n    dependencies list, that dependency (with identical name and version\n    constraint) shall appear in the updated pyproject.toml base dependencies.\n\n    **Validates: Requirements 6.1, 6.2, 6.3**\n    \"\"\"\n    current_deps = _parse_base_dependencies()\n    assert _dep_present(dep, current_deps), (\n        f\"Dependency {dep!r} from the original pyproject.toml is missing \"\n        f\"or has a different version constraint in the current base dependencies.\\n\"\n        f\"Current dependencies: {current_deps}\"\n    )\n"
  },
  {
    "path": "tests/test_pyproject_structure.py",
    "content": "\"\"\"Structural validation tests for pyproject.toml extras configuration.\n\nValidates that the pyproject.toml correctly defines CUDA/MLX extras,\nscoped index sources, and conflict groups to eliminate lockfile drift.\n\"\"\"\n\nfrom __future__ import annotations\n\nimport sys\nfrom pathlib import Path\n\nimport pytest\n\nif sys.version_info >= (3, 11):\n    import tomllib\nelse:\n    try:\n        import tomli as tomllib  # type: ignore[no-redef]\n    except ImportError:\n        pytest.skip(\"tomli required for Python < 3.11\", allow_module_level=True)\n\nPYPROJECT_PATH = Path(__file__).resolve().parents[1] / \"pyproject.toml\"\n\n\n@pytest.fixture(scope=\"module\")\ndef pyproject() -> dict:\n    \"\"\"Parse and return the pyproject.toml as a dict.\"\"\"\n    with open(PYPROJECT_PATH, \"rb\") as f:\n        return tomllib.load(f)\n\n\n# ---------------------------------------------------------------------------\n# Requirement 2.1, 2.2: CUDA optional extra\n# ---------------------------------------------------------------------------\n\n\nclass TestCudaExtra:\n    \"\"\"Validates: Requirements 1.2, 1.3, 2.1, 2.2, 2.3\"\"\"\n\n    def test_cuda_extra_contains_torch(self, pyproject: dict) -> None:\n        cuda_deps = pyproject[\"project\"][\"optional-dependencies\"][\"cuda\"]\n        assert \"torch==2.8.0\" in cuda_deps\n\n    def test_cuda_extra_contains_torchvision(self, pyproject: dict) -> None:\n        cuda_deps = pyproject[\"project\"][\"optional-dependencies\"][\"cuda\"]\n        assert \"torchvision==0.23.0\" in cuda_deps\n\n\n# ---------------------------------------------------------------------------\n# Requirement 3.1, 3.2, 3.3: MLX optional extra\n# ---------------------------------------------------------------------------\n\n\nclass TestMlxExtra:\n    \"\"\"Validates: Requirements 3.1, 3.2, 3.3\"\"\"\n\n    def test_mlx_extra_contains_corridorkey_mlx(self, pyproject: dict) -> None:\n        mlx_deps = 
pyproject[\"project\"][\"optional-dependencies\"][\"mlx\"]\n        mlx_dep_names = [d.split(\";\")[0].strip() for d in mlx_deps]\n        assert \"corridorkey-mlx\" in mlx_dep_names\n\n\n# ---------------------------------------------------------------------------\n# Requirement 2.3: pytorch index scoped to cuda extra\n# ---------------------------------------------------------------------------\n\n\nclass TestPytorchIndex:\n    \"\"\"Validates: Requirements 2.3\"\"\"\n\n    def test_pytorch_index_has_cuda_extra(self, pyproject: dict) -> None:\n        indexes = pyproject[\"tool\"][\"uv\"][\"index\"]\n        pytorch_entries = [idx for idx in indexes if idx.get(\"name\") == \"pytorch\"]\n        assert len(pytorch_entries) == 1, \"Expected exactly one pytorch index entry\"\n        assert pytorch_entries[0].get(\"extra\") == \"cuda\"\n\n\n# ---------------------------------------------------------------------------\n# Requirement 2.1, 2.3: torch/torchvision source overrides scoped to cuda\n# ---------------------------------------------------------------------------\n\n\nclass TestUvSources:\n    \"\"\"Validates: Requirements 1.3, 2.1, 2.3\"\"\"\n\n    def test_torch_source_has_cuda_extra(self, pyproject: dict) -> None:\n        sources = pyproject[\"tool\"][\"uv\"][\"sources\"]\n        torch_src = sources[\"torch\"]\n        assert torch_src.get(\"extra\") == \"cuda\"\n        assert \"marker\" not in torch_src, \"torch source should not have platform markers\"\n\n    def test_torchvision_source_has_cuda_extra(self, pyproject: dict) -> None:\n        sources = pyproject[\"tool\"][\"uv\"][\"sources\"]\n        tv_src = sources[\"torchvision\"]\n        assert tv_src.get(\"extra\") == \"cuda\"\n        assert \"marker\" not in tv_src, \"torchvision source should not have platform markers\"\n\n\n# ---------------------------------------------------------------------------\n# Requirement 4.1: Conflict group between cuda and mlx\n# 
---------------------------------------------------------------------------\n\n\nclass TestConflicts:\n    \"\"\"Validates: Requirements 4.1\"\"\"\n\n    def test_cuda_mlx_conflict_declared(self, pyproject: dict) -> None:\n        conflicts = pyproject[\"tool\"][\"uv\"][\"conflicts\"]\n        # conflicts is a list of conflict groups; each group is a list of dicts\n        extras_in_groups = [\n            {entry[\"extra\"] for entry in group} for group in conflicts if all(\"extra\" in entry for entry in group)\n        ]\n        assert {\"cuda\", \"mlx\"} in extras_in_groups, \"Expected a conflict group containing both 'cuda' and 'mlx' extras\"\n\n\n# ---------------------------------------------------------------------------\n# Requirement 6.2: timm git source override preserved\n# ---------------------------------------------------------------------------\n\n\nclass TestTimmSourcePreserved:\n    \"\"\"Validates: Requirements 6.2\"\"\"\n\n    def test_timm_source_is_git(self, pyproject: dict) -> None:\n        timm_src = pyproject[\"tool\"][\"uv\"][\"sources\"][\"timm\"]\n        assert \"git\" in timm_src, \"timm source should be a git override\"\n\n    def test_timm_git_url(self, pyproject: dict) -> None:\n        timm_src = pyproject[\"tool\"][\"uv\"][\"sources\"][\"timm\"]\n        assert timm_src[\"git\"] == \"https://github.com/Raiden129/pytorch-image-models-fix\"\n\n    def test_timm_git_branch(self, pyproject: dict) -> None:\n        timm_src = pyproject[\"tool\"][\"uv\"][\"sources\"][\"timm\"]\n        assert timm_src[\"branch\"] == \"fix/hiera-flash-attention-global-4d\"\n\n\n# ---------------------------------------------------------------------------\n# Requirement 6.3: triton-windows platform-conditional dependency preserved\n# ---------------------------------------------------------------------------\n\n\nclass TestTritonWindowsPreserved:\n    \"\"\"Validates: Requirements 6.3\"\"\"\n\n    def test_triton_windows_in_base_deps(self, pyproject: 
dict) -> None:\n        deps = pyproject[\"project\"][\"dependencies\"]\n        triton_entries = [d for d in deps if \"triton-windows\" in d]\n        assert len(triton_entries) == 1, \"Expected exactly one triton-windows dependency\"\n\n    def test_triton_windows_has_win32_marker(self, pyproject: dict) -> None:\n        deps = pyproject[\"project\"][\"dependencies\"]\n        triton_entries = [d for d in deps if \"triton-windows\" in d]\n        assert \"sys_platform == 'win32'\" in triton_entries[0]\n\n\n# ---------------------------------------------------------------------------\n# Requirement 7.1: dev dependency group preserved\n# ---------------------------------------------------------------------------\n\n\nclass TestDevDependencyGroup:\n    \"\"\"Validates: Requirements 7.1\"\"\"\n\n    def test_dev_group_contains_pytest(self, pyproject: dict) -> None:\n        dev = pyproject[\"dependency-groups\"][\"dev\"]\n        assert \"pytest\" in dev\n\n    def test_dev_group_contains_pytest_cov(self, pyproject: dict) -> None:\n        dev = pyproject[\"dependency-groups\"][\"dev\"]\n        assert \"pytest-cov\" in dev\n\n    def test_dev_group_contains_ruff(self, pyproject: dict) -> None:\n        dev = pyproject[\"dependency-groups\"][\"dev\"]\n        assert \"ruff\" in dev\n\n\n# ---------------------------------------------------------------------------\n# Requirement 7.2: docs dependency group preserved\n# ---------------------------------------------------------------------------\n\n\nclass TestDocsDependencyGroup:\n    \"\"\"Validates: Requirements 7.2\"\"\"\n\n    def test_docs_group_contains_zensical(self, pyproject: dict) -> None:\n        docs = pyproject[\"dependency-groups\"][\"docs\"]\n        zensical_entries = [d for d in docs if \"zensical\" in d]\n        assert len(zensical_entries) == 1\n        assert \"zensical>=0.0.24\" in zensical_entries[0]\n"
  },
  {
    "path": "zensical.toml",
    "content": "[project]\nsite_name = \"CorridorKey\"\nsite_description = \"Perfect Green Screen Keys\"\nrepo_name = \"CorridorKey\"\nrepo_url = \"https://github.com/nikopueringer/CorridorKey\"\n\nnav = [{ Home = \"index.md\" }, { LLM_HANDOVER = \"LLM_HANDOVER.md\" }]\n\n[project.theme]\nlanguage = \"en\"\nfeatures = [\n    \"announce.dismiss\",\n    \"content.action.edit\",\n    \"content.action.view\",\n    \"content.code.annotate\",\n    \"content.code.copy\",\n    \"content.code.select\",\n    \"content.footnote.tooltips\",\n    \"content.tabs.link\",\n    \"content.tooltips\",\n    \"navigation.footer\",\n    \"navigation.indexes\",\n    \"navigation.instant\",\n    \"navigation.instant.prefetch\",\n    \"navigation.instant.progress\",\n    \"navigation.path\",\n    \"navigation.tabs\",\n    \"navigation.top\",\n    \"navigation.tracking\",\n    \"search.highlight\",\n]\n\n[[project.theme.palette]]\nscheme = \"default\"\ntoggle.icon = \"lucide/sun\"\ntoggle.name = \"Switch to dark mode\"\n\n[[project.theme.palette]]\nscheme = \"slate\"\ntoggle.icon = \"lucide/moon\"\ntoggle.name = \"Switch to light mode\"\n\n[project.markdown_extensions.admonition]\n[project.markdown_extensions.pymdownx.details]\n[project.markdown_extensions.pymdownx.superfences]\ncustom_fences = [\n    { name = \"mermaid\", class = \"mermaid\", format = \"pymdownx.superfences.fence_code_format\" },\n]\n\n[project.markdown_extensions.pymdownx.tabbed]\nalternate_style = true\n"
  }
]