Repository: spacepxl/ComfyUI-Image-Filters
Branch: main
Commit: bbb3fb004546
Files: 10
Total size: 104.9 KB

Directory structure:
ComfyUI-Image-Filters/

├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── download_all_models.py
├── import_error_install.bat
├── install.bat
├── nodes.py
├── raft.py
└── requirements.txt

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# models
models/

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 spacepxl

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
## ComfyUI-Image-Filters

Started as just some image processing nodes, but now more of a kitchen-sink nodepack.

Two install batch files are provided: `install.bat`, which only installs requirements, and `import_error_install.bat`, which uninstalls all versions of opencv and then reinstalls all four variants at a matching version. Use the latter if you get import errors relating to opencv or cv2, which are caused by the manager or other node packs installing different variants and/or versions.

Or if you want to manage requirements manually, the only opencv variant you actually need is `opencv-contrib-python`; it covers all opencv requirements.
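
For example, to consolidate on the contrib variant manually (a sketch; adapt to your Python environment):

```
pip uninstall -y opencv-python opencv-python-headless opencv-contrib-python-headless
pip install opencv-contrib-python
```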

## Nodes

<details><summary>Latent</summary>

### AdaIN Latent

Normalizes latents to the mean and std dev of a reference input. Useful for getting rid of color shift from high denoise strength, or matching color to a reference in general.
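
A minimal sketch of the core operation (this is essentially what `AdainLatent` in `nodes.py` does):

```
# normalize to the reference's per-channel spatial mean/std
t_std, t_mean = torch.std_mean(latent, dim=(-2, -1), keepdim=True)
ref_std, ref_mean = torch.std_mean(reference, dim=(-2, -1), keepdim=True)
latent = (latent - t_mean) / t_std * ref_std + ref_mean
```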

### AdaIN Filter Latent

Same as AdaIN Latent, but with a spatial filter instead of the full frame, works like a latent color match.

### Batch Normalize Latent

Normalizes each frame in a batch to the overall mean and std dev, good for removing overall brightness flickering.

### Clamp Outliers

Clamps latents that are more than n standard deviations away from the mean. Could help with fireflies or stray noise that disrupt the VAE decode.
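
In essence, per channel (see `ClampOutliers` in `nodes.py`):

```
sd, mean = torch.std_mean(channel, dim=None)
channel = torch.clamp(channel, min=mean - sd * std_dev, max=mean + sd * std_dev)
```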

### Upscale Hunyuan3Dv2 Latent By

Nearest Neighbor upscaling for Hy3D latents, might be useful for hires fix.

### Latent Normalize/Shuffle

Can help break up residual image information in inversion noise.

### RandnLikeLatent

Create random noise in the same shape as the input latent, works with any latent. Useful for noise injection or other times when you just want to control noise manually.

### Offset Latent Image

Create an empty latent image with custom values, for offset noise with per-channel control. Can be combined with Latent Stats to get channel values.

### Sharpen Filter (Latent)

Increases local contrast between latent "pixels" with an image sharpening filter.
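
The core is a simple unsharp mask over the latent channels (see `SharpenFilterLatent` in `nodes.py`):

```
t_blurred = cv_blur_tensor(t, d, d)  # d = filter_size * 2 + 1
t = (t - t_blurred) * factor + t_blurred
```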

</details>

<details><summary>Image</summary>

### AdaIN Image

Normalizes images to the mean and std dev of a reference input. Useful for getting rid of color shift from high denoise strength, or matching color to a reference in general.

### Batch Align (RAFT)

Uses RAFT motion vectors to warp-align images.

### Batch Average Image

Returns the single average image of a batch.

### Batch Normalize Image

Normalizes each frame in a batch to the overall mean and std dev, good for removing overall brightness flickering.

### BetterFilmGrain

Yet another film grain node, but this one looks better (realistic grain structure, no pixel-perfect RGB glitter, natural luminance/intensity response) and is 10x faster than the next best option (ProPostFilmGrain).

### Bilateral Filter Image

Applies a bilateral filter; can be used to remove noise or high-frequency details while preserving edges.

### Blur Image (Fast)

Blurs images using opencv gaussian blur, which is >100x faster than comfy image blur. Supports larger blur radii and separate x/y controls.

### Clamp Image

Clamps image values outside of the blackpoint/whitepoint range.

### Color Match Image

Matches image color to a reference image, using either the overall mean or a blurred image (frequency separation).

### Convert Normals

Translate between different normal map color spaces, with optional normalization fix and black region fix.

### Depth to Normals

Converts a depthmap to a normal map.

### Difference Checker

Absolute value of the difference between inputs, with a multiplier to boost dark values for easier viewing. Alternative to the vanilla merge difference node, which is subtraction only, without the abs().
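
The whole operation is one line (see `DifferenceChecker` in `nodes.py`):

```
diff = torch.clamp(torch.abs(images1 - images2) * multiplier, 0, 1)
```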

### Enhance Detail

Increase or decrease details in an image or batch of images using a guided filter (as opposed to the typical gaussian blur used by most sharpening filters).

### Exposure Adjust

Linear exposure adjustment in f-stops, with optional tonemap.
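
The linear part is just a gain of 2^stops on positive values (see `exposure` in `nodes.py`):

```
image[image > 0] *= 2 ** stops
```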

### Frequency Separate/Combine

For manual frequency separation workflows.
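
In "subtract" mode the round trip looks like this (see `FrequencySeparate`/`FrequencyCombine` in `nodes.py`):

```
high = original - low + 0.5    # separate
recombined = low + high - 0.5  # combine
```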

### Game of Life

Runs the Game of Life simulation, with an optional mask input for the starting condition.

### Guided Filter Image

Use a guided filter to blur an image or mask based on RGB color similarity. Works best with a strong color separation between FG and BG.

### Image Constant Color (RGB/HSV)

Create images of any solid color, from either RGB or HSV values.

### Image Matting

Takes an image and trimap/mask, and refines the matte edges with [closed-form matting](https://github.com/pymatting/pymatting). Optionally extracts the foreground and background colors as well. Good for cleaning up SAM segments or hand drawn masks.

### Keyer

Basic image keyer with luma/sat/channel/greenscreen/etc. options.

### Median Filter Image

Applies a median filter to remove high-frequency information from images; useful for frequency separation workflows.

### Normal Map (Simple)

Simple high-frequency normal map from the Scharr operator.

### Relight (Simple)

Basic dot-product (Lambertian) relighting from a normal map.
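
A sketch of the math (see `RelightSimple` in `nodes.py`):

```
norm = normals * 2 - 1                                      # decode normals to [-1, 1]
light = F.normalize(torch.tensor([x, y, z]), dim=0)
diffuse = (norm * light).sum(-1, keepdim=True).clamp(0, 1)  # Lambert term
relit = (image * diffuse * brightness).clamp(0, 1)
```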

### Remap Range

Fits the color range of an image to a new blackpoint and whitepoint (clamped).
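
Equivalent to (see `RemapRange` in `nodes.py`):

```
remapped = np.clip((image - blackpoint) / (whitepoint - blackpoint), 0.0, 1.0)
```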

### Restore Detail

Transfers details from one image to another using frequency separation. Useful for restoring details lost in IC-Light or other img2img workflows. Has options for the add/subtract method (fewer artifacts, but mostly ignores highlights) or divide/multiply (more natural, but can create artifacts in areas that go from dark to bright), and for either gaussian blur or a guided filter (prevents oversharpened edges).
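
Schematically, with `blur` standing in for the selected gaussian/guided filter, and omitting the small stabilizing offset the node adds (see `RestoreDetail` in `nodes.py`):

```
restored = (detail - blur(detail)) + blur(images)  # add/subtract mode
restored = (detail / blur(detail)) * blur(images)  # divide/multiply mode
```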

### Shuffle Channels

Move channels around at will.

### Tonemap / UnTonemap

Apply or remove a log + contrast curve tonemap.

Apply tonemap:
```
power = 1.7
SLog3R = clamp((log10((r + 0.01)/0.19) * 261.5 + 420) / 1023, 0, 1)
SLog3G = clamp((log10((g + 0.01)/0.19) * 261.5 + 420) / 1023, 0, 1)
SLog3B = clamp((log10((b + 0.01)/0.19) * 261.5 + 420) / 1023, 0, 1)

r = r > 0.06 ? pow(1 / (1 + (1 / pow(SLog3R / (1 - SLog3R), power))), power) : r
g = g > 0.06 ? pow(1 / (1 + (1 / pow(SLog3G / (1 - SLog3G), power))), power) : g
b = b > 0.06 ? pow(1 / (1 + (1 / pow(SLog3B / (1 - SLog3B), power))), power) : b
```

Remove tonemap:
```
power = 1.7
SR = 1 / (1 + pow((-1/pow(r, 1/power)) * (pow(r, 1/power) - 1), 1/power))
SG = 1 / (1 + pow((-1/pow(g, 1/power)) * (pow(g, 1/power) - 1), 1/power))
SB = 1 / (1 + pow((-1/pow(b, 1/power)) * (pow(b, 1/power) - 1), 1/power))

r = r > 0.06 ? pow(10, (SR * 1023 - 420)/261.5) * 0.19 - 0.01 : r
g = g > 0.06 ? pow(10, (SG * 1023 - 420)/261.5) * 0.19 - 0.01 : g
b = b > 0.06 ? pow(10, (SB * 1023 - 420)/261.5) * 0.19 - 0.01 : b
```
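
A minimal NumPy sketch of the "apply" curve above, mirroring `linearToTonemap` in `nodes.py` (without the `tonemap_scale` pre/post division):

```
import numpy as np

def apply_tonemap(x, power=1.7):
    # SLog3 encode, then an S-curve for contrast; values <= 0.06 pass through
    slog3 = np.clip((np.log10((x + 0.01) / 0.19) * 261.5 + 420) / 1023, 0, 1)
    curved = np.power(1 / (1 + 1 / np.power(slog3 / (1 - slog3), power)), power)
    return np.where(x > 0.06, curved, x)
```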

### JitterImage, UnJitterImage, BatchAverageUnJittered

For supersampling/antialiasing workflows.

### Extract N Frames, Merge Frames By Index

For processing a smaller number of frames evenly distributed across a larger batch/video, then merging them back into the full batch.

</details>

<details><summary>Mask</summary>

### Blur Mask (Fast)

Same as Blur Image (Fast) but for masks instead of images.

### Dilate/Erode Mask

Dilate or erode masks, with either a box or circle filter.

### Mask Clean

Clean up holes and near-solid areas in a matte.

### Pack Video Mask

Compresses the frames of a video mask to match video VAE latent frames, to work around ComfyUI's naive temporal resizing of masks.

</details>

<details><summary>Conditioning</summary>

### Conditioning Subtract

Takes the difference of two text conditions; can have interesting effects that are different from negative prompts.

### Inpaint Condition Encode/Apply

Separates the VAE encode from the conditioning so you don't have to re-encode latents every time you change a prompt.

### IP2P Conditioning Advanced

Like Inpaint Condition Encode/Apply, but for IP2P (InstructPix2Pix) conditioning: separates the VAE encode from the conditioning so you don't have to re-encode latents every time you change a prompt.

</details>

<details><summary>Sampling</summary>

### Custom Noise

Use any latent as the noise for SamplerCustomAdvanced.

</details>

<details><summary>Utils</summary>

### Latent Stats

Get/print some stats about the latents (dimensions, and per-channel mean, std dev, min, and max).

### Model Test

Debugging node for examining model structure.

### Print Sigmas

Prints the noise schedule sigma values to see what a scheduler is doing.

### Visualize Latents

Shows the latent channels as a grid image.

</details>


================================================
FILE: __init__.py
================================================
# from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
from .nodes import COMBINED_MAPPINGS

NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}
for k, v in COMBINED_MAPPINGS.items():
    NODE_CLASS_MAPPINGS[k] = v[0]
    NODE_DISPLAY_NAME_MAPPINGS[k] = v[1]

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']

================================================
FILE: download_all_models.py
================================================
from raft import load_raft

load_raft()

================================================
FILE: import_error_install.bat
================================================
@echo off

set "requirements_txt=%~dp0\requirements.txt"
set "python_exec=..\..\..\python_embeded\python.exe"

echo installing requirements...

if exist "%python_exec%" (
    echo Installing with ComfyUI Portable
    %python_exec% -s -m pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless opencv-contrib-python-headless
    for /f "delims=" %%i in (%requirements_txt%) do (
        %python_exec% -s -m pip install "%%i"
    )
) else (
    echo Installing with system Python
    pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless opencv-contrib-python-headless
    for /f "delims=" %%i in (%requirements_txt%) do (
        pip install "%%i"
    )
)

pause

================================================
FILE: install.bat
================================================
@echo off

set "requirements_txt=%~dp0\requirements.txt"
set "python_exec=..\..\..\python_embeded\python.exe"

echo installing requirements...

if exist "%python_exec%" (
    echo Installing with ComfyUI Portable
    for /f "delims=" %%i in (%requirements_txt%) do (
        %python_exec% -s -m pip install "%%i"
    )
) else (
    echo Installing with system Python
    for /f "delims=" %%i in (%requirements_txt%) do (
        pip install "%%i"
    )
)

pause

================================================
FILE: nodes.py
================================================
import math
import copy
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from pymatting import estimate_alpha_cf, estimate_foreground_ml, fix_trimap
from tqdm import trange

try:
    from cv2.ximgproc import guidedFilter
except ImportError:
    print("\033[33mUnable to import guidedFilter, make sure you have only opencv-contrib-python or run the import_error_install.bat script\033[m")

import comfy.model_management
import node_helpers
from server import PromptServer
from comfy.utils import ProgressBar
from comfy_extras.nodes_post_processing import gaussian_kernel
from .raft import *

MAX_RESOLUTION=8192

# gaussian blur a tensor image batch in format [B x H x W x C] on H/W (spatial, per-image, per-channel)
def cv_blur_tensor(images, dx, dy):
    if min(dx, dy) > 100:
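        # very large blurs: downsample 10x, blur with a proportionally smaller kernel, then upsample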
        np_img = F.interpolate(images.detach().clone().movedim(-1,1), scale_factor=0.1, mode='bilinear').movedim(1,-1).cpu().numpy()
        for index, image in enumerate(np_img):
            np_img[index] = cv2.GaussianBlur(image, (dx // 20 * 2 + 1, dy // 20 * 2 + 1), 0)
        return F.interpolate(torch.from_numpy(np_img).movedim(-1,1), size=(images.shape[1], images.shape[2]), mode='bilinear').movedim(1,-1)
    else:
        np_img = images.detach().clone().cpu().numpy()
        for index, image in enumerate(np_img):
            np_img[index] = cv2.GaussianBlur(image, (dx, dy), 0)
        return torch.from_numpy(np_img)

# guided filter a tensor image batch in format [B x H x W x C] on H/W (spatial, per-image, per-channel)
def guided_filter_tensor(ref, images, d, s):
    if d > 100:
        np_img = F.interpolate(images.detach().clone().movedim(-1,1), scale_factor=0.1, mode='bilinear').movedim(1,-1).cpu().numpy()
        np_ref = F.interpolate(ref.detach().clone().movedim(-1,1), scale_factor=0.1, mode='bilinear').movedim(1,-1).cpu().numpy()
        for index, image in enumerate(np_img):
            np_img[index] = guidedFilter(np_ref[index], image, d // 20 * 2 + 1, s)
        return F.interpolate(torch.from_numpy(np_img).movedim(-1,1), size=(images.shape[1], images.shape[2]), mode='bilinear').movedim(1,-1)
    else:
        np_img = images.detach().clone().cpu().numpy()
        np_ref = ref.cpu().numpy()
        for index, image in enumerate(np_img):
            np_img[index] = guidedFilter(np_ref[index], image, d, s)
        return torch.from_numpy(np_img)

# std_dev and mean of tensor t within local spatial filter size d, per-image, per-channel [B x H x W x C]
def std_mean_filter(t, d):
    t_mean = cv_blur_tensor(t, d, d)
    t_diff_squared = (t - t_mean) ** 2
    t_std = torch.sqrt(cv_blur_tensor(t_diff_squared, d, d))
    return t_std, t_mean

def RGB2YCbCr(t):
    # BT.709 coefficients; Y in [0,1], Cb/Cr centered on 0
    YCbCr = t.detach().clone()
    YCbCr[:,:,:,0] = 0.2126 * t[:,:,:,0] + 0.7152 * t[:,:,:,1] + 0.0722 * t[:,:,:,2]
    YCbCr[:,:,:,1] = 0 - 0.1146 * t[:,:,:,0] - 0.3854 * t[:,:,:,1] + 0.5 * t[:,:,:,2]
    YCbCr[:,:,:,2] = 0.5 * t[:,:,:,0] - 0.4542 * t[:,:,:,1] - 0.0458 * t[:,:,:,2]
    return YCbCr

def YCbCr2RGB(t):
    RGB = t.detach().clone()
    RGB[:,:,:,0] = t[:,:,:,0] + 1.5748 * t[:,:,:,2]
    RGB[:,:,:,1] = t[:,:,:,0] - 0.1873 * t[:,:,:,1] - 0.4681 * t[:,:,:,2]
    RGB[:,:,:,2] = t[:,:,:,0] + 1.8556 * t[:,:,:,1]
    return RGB

def hsv_to_rgb(h, s, v):
    if s:
        if h == 1.0: h = 0.0
        i = int(h*6.0)
        f = h*6.0 - i
        
        w = v * (1.0 - s)
        q = v * (1.0 - s * f)
        t = v * (1.0 - s * (1.0 - f))
        
        if i==0: return (v, t, w)
        if i==1: return (q, v, w)
        if i==2: return (w, v, t)
        if i==3: return (w, q, v)
        if i==4: return (t, w, v)
        if i==5: return (v, w, q)
    else: return (v, v, v)

def sRGBtoLinear(npArray):
    less = npArray <= 0.0404482362771082
    npArray[less] = npArray[less] / 12.92
    npArray[~less] = np.power((npArray[~less] + 0.055) / 1.055, 2.4)

def linearToSRGB(npArray):
    less = npArray <= 0.0031308
    npArray[less] = npArray[less] * 12.92
    npArray[~less] = np.power(npArray[~less], 1/2.4) * 1.055 - 0.055

def sRGBtoLinear_pt(t: torch.Tensor):
    less = t <= 0.0404482362771082
    t[less] = t[less] / 12.92
    t[~less] = torch.pow((t[~less] + 0.055) / 1.055, 2.4)
    return t

def linearToSRGB_pt(t: torch.Tensor):
    less = t <= 0.0031308
    t[less] = t[less] * 12.92
    t[~less] = torch.pow(t[~less], 1 / 2.4) * 1.055 - 0.055
    return t

def linearToTonemap(npArray, tonemap_scale):
    npArray /= tonemap_scale
    more = npArray > 0.06
    SLog3 = np.clip((np.log10((npArray + 0.01)/0.19) * 261.5 + 420) / 1023, 0, 1)
    npArray[more] = np.power(1 / (1 + (1 / np.power(SLog3[more] / (1 - SLog3[more]), 1.7))), 1.7)
    npArray *= tonemap_scale

def tonemapToLinear(npArray, tonemap_scale):
    npArray /= tonemap_scale
    more = npArray > 0.06
    x = np.power(np.clip(npArray, 0.000001, 1), 1/1.7)
    ut = 1 / (1 + np.power((-1 / x) * (x - 1), 1/1.7))
    npArray[more] = np.power(10, (ut[more] * 1023 - 420)/261.5) * 0.19 - 0.01
    npArray *= tonemap_scale

def exposure(npArray, stops):
    more = npArray > 0
    npArray[more] *= pow(2, stops)

def randn_like_g(x, generator=None):
    device = generator.device if generator is not None else x.device
    r = torch.randn(x.size(), generator=generator, dtype=x.dtype, layout=x.layout, device=device)
    return r.to(x.device)


class AlphaClean:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "radius": ("INT", {"default": 8, "min": 1, "max": 64, "step": 1}),
                "fill_holes": ("INT", {"default": 1, "min": 0, "max": 16, "step": 1}),
                "white_threshold": ("FLOAT", {"default": 0.9, "min": 0.01, "max": 1.0, "step": 0.01}),
                "extra_clip": ("FLOAT", {"default": 0.98, "min": 0.01, "max": 1.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "alpha_clean"
    CATEGORY = "Image-Filters/image"
    DEPRECATED = True

    def alpha_clean(self, images: torch.Tensor, radius: int, fill_holes: int, white_threshold: float, extra_clip: float):
        d = radius * 2 + 1
        i_dup = copy.deepcopy(images.cpu().numpy())
        
        for index, image in enumerate(i_dup):
            
            cleaned = cv2.bilateralFilter(image, 9, 0.05, 8)
            
            alpha = np.clip((image - white_threshold) / (1 - white_threshold), 0, 1)
            rgb = image * alpha
            
            alpha = cv2.GaussianBlur(alpha, (d,d), 0) * 0.99 + np.average(alpha) * 0.01
            rgb = cv2.GaussianBlur(rgb, (d,d), 0) * 0.99 + np.average(rgb) * 0.01
            
            rgb = rgb / np.clip(alpha, 0.00001, 1)
            rgb = rgb * extra_clip
            
            cleaned = np.clip(cleaned / rgb, 0, 1)
            
            if fill_holes > 0:
                fD = fill_holes * 2 + 1
                gamma = cleaned * cleaned
                kD = np.ones((fD, fD), np.uint8)
                kE = np.ones((fD + 2, fD + 2), np.uint8)
                gamma = cv2.dilate(gamma, kD, iterations=1)
                gamma = cv2.erode(gamma, kE, iterations=1)
                gamma = cv2.GaussianBlur(gamma, (fD, fD), 0)
                cleaned = np.maximum(cleaned, gamma)

            i_dup[index] = cleaned
        
        return (torch.from_numpy(i_dup),)


class MaskClean:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
                "radius": ("INT", {"default": 8, "min": 1, "max": 64, "step": 1}),
                "fill_holes": ("INT", {"default": 1, "min": 0, "max": 16, "step": 1}),
                "white_threshold": ("FLOAT", {"default": 0.9, "min": 0.001, "max": 1.0, "step": 0.001}),
                "extra_clip": ("FLOAT", {"default": 0.98, "min": 0.001, "max": 1.0, "step": 0.001}),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "alpha_clean"
    CATEGORY = "Image-Filters/mask"

    def alpha_clean(self, mask, radius, fill_holes, white_threshold, extra_clip):
        d = radius * 2 + 1
        i_dup = mask.cpu().numpy()
        
        for index, image in enumerate(i_dup):
            cleaned = cv2.bilateralFilter(image, 9, 0.05, 8)
            
            alpha = np.clip((image - white_threshold) / (1 - white_threshold), 0, 1)
            rgb = image * alpha
            
            alpha = cv2.GaussianBlur(alpha, (d,d), 0) * 0.99 + np.average(alpha) * 0.01
            rgb = cv2.GaussianBlur(rgb, (d,d), 0) * 0.99 + np.average(rgb) * 0.01
            
            rgb = rgb / np.clip(alpha, 0.00001, 1)
            rgb = rgb * extra_clip
            
            cleaned = np.clip(cleaned / rgb, 0, 1)
            
            if fill_holes > 0:
                fD = fill_holes * 2 + 1
                gamma = cleaned * cleaned
                kD = np.ones((fD, fD), np.uint8)
                kE = np.ones((fD + 2, fD + 2), np.uint8)
                gamma = cv2.dilate(gamma, kD, iterations=1)
                gamma = cv2.erode(gamma, kE, iterations=1)
                gamma = cv2.GaussianBlur(gamma, (fD, fD), 0)
                cleaned = np.maximum(cleaned, gamma)

            i_dup[index] = cleaned
        
        return (torch.from_numpy(i_dup),)


class AlphaMatte:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "alpha_trimap": ("IMAGE",),
                "preblur": ("INT", {"default": 8, "min": 0, "max": 256, "step": 1}),
                "blackpoint": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 0.99, "step": 0.01}),
                "whitepoint": ("FLOAT", {"default": 0.99, "min": 0.01, "max": 1.0, "step": 0.01}),
                "max_iterations": ("INT", {"default": 1000, "min": 100, "max": 10000, "step": 100}),
                "estimate_fg": (["true", "false"],),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE",)
    RETURN_NAMES = ("alpha", "fg", "bg",)
    FUNCTION = "alpha_matte"
    CATEGORY = "Image-Filters/image"
    DEPRECATED = True

    def alpha_matte(self, images, alpha_trimap, preblur, blackpoint, whitepoint, max_iterations, estimate_fg):
        d = preblur * 2 + 1
        
        i_dup = images.cpu().numpy().astype(np.float64)
        a_dup = alpha_trimap.cpu().numpy().astype(np.float64)
        fg = images.cpu().numpy().astype(np.float64)
        bg = images.cpu().numpy().astype(np.float64)
        
        
        for index, image in enumerate(i_dup):
            trimap = a_dup[index][:,:,0] # convert to single channel
            if preblur > 0:
                trimap = cv2.GaussianBlur(trimap, (d, d), 0)
            trimap = fix_trimap(trimap, blackpoint, whitepoint)
            
            alpha = estimate_alpha_cf(image, trimap, laplacian_kwargs={"epsilon": 1e-6}, cg_kwargs={"maxiter":max_iterations})
            
            if estimate_fg == "true":
                fg[index], bg[index] = estimate_foreground_ml(image, alpha, return_background=True)
            
            a_dup[index] = np.stack([alpha, alpha, alpha], axis = -1) # convert back to rgb
        
        return (
            torch.from_numpy(a_dup.astype(np.float32)), # alpha
            torch.from_numpy(fg.astype(np.float32)), # fg
            torch.from_numpy(bg.astype(np.float32)), # bg
            )


class ImageMatting:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "trimap": ("MASK",),
                "preblur": ("INT", {"default": 8, "min": 0, "max": 256, "step": 1}),
                "blackpoint": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 0.99, "step": 0.01}),
                "whitepoint": ("FLOAT", {"default": 0.99, "min": 0.01, "max": 1.0, "step": 0.01}),
                "max_iterations": ("INT", {"default": 1000, "min": 10, "max": 10000, "step": 10}),
                "estimate_fg": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("MASK", "IMAGE", "IMAGE",)
    RETURN_NAMES = ("matte", "fg", "bg",)
    FUNCTION = "alpha_matte"
    CATEGORY = "Image-Filters/image"

    def alpha_matte(self, images, trimap, preblur, blackpoint, whitepoint, max_iterations, estimate_fg):
        d = preblur * 2 + 1
        
        i_dup = images.cpu().numpy().astype(np.float64)
        a_dup = trimap.cpu().numpy().astype(np.float64)
        fg = copy.deepcopy(i_dup)
        bg = copy.deepcopy(i_dup)
        
        
        for index, image in enumerate(i_dup):
            trimap = a_dup[index]
            if preblur > 0:
                trimap = cv2.GaussianBlur(trimap, (d, d), 0)
            trimap = fix_trimap(trimap, blackpoint, whitepoint)
            
            alpha = estimate_alpha_cf(image, trimap, laplacian_kwargs={"epsilon": 1e-6}, cg_kwargs={"maxiter":max_iterations})
            
            if estimate_fg:
                fg[index], bg[index] = estimate_foreground_ml(image, alpha, return_background=True)
            
            a_dup[index] = alpha
        
        return (
            torch.from_numpy(a_dup.astype(np.float32)), # matte
            torch.from_numpy(fg.astype(np.float32)), # fg
            torch.from_numpy(bg.astype(np.float32)), # bg
            )


class BetterFilmGrain:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "scale": ("FLOAT", {"default": 0.5, "min": 0.25, "max": 2.0, "step": 0.05}),
                "strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step": 0.01}),
                "saturation": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.01}),
                "toe": ("FLOAT", {"default": 0.0, "min": -0.2, "max": 0.5, "step": 0.001}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "grain"
    CATEGORY = "Image-Filters/image"

    def grain(self, image, scale, strength, saturation, toe, seed):
        t = image.detach().clone()
        torch.manual_seed(seed)
        grain = torch.rand(t.shape[0], int(t.shape[1] // scale), int(t.shape[2] // scale), 3)
        
        YCbCr = RGB2YCbCr(grain)
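        # blur chroma (Cb/Cr) more than luma so the grain clumps naturally instead of per-pixel RGB glitter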
        YCbCr[:,:,:,0] = cv_blur_tensor(YCbCr[:,:,:,0], 3, 3)
        YCbCr[:,:,:,1] = cv_blur_tensor(YCbCr[:,:,:,1], 15, 15)
        YCbCr[:,:,:,2] = cv_blur_tensor(YCbCr[:,:,:,2], 11, 11)
        
        grain = (YCbCr2RGB(YCbCr) - 0.5) * strength
        grain[:,:,:,0] *= 2
        grain[:,:,:,2] *= 3
        grain += 1
        grain = grain * saturation + grain[:,:,:,1].unsqueeze(3).repeat(1,1,1,3) * (1 - saturation)
        
        grain = F.interpolate(grain.movedim(-1,1), size=(t.shape[1], t.shape[2]), mode='bilinear').movedim(1,-1)
        t[:,:,:,:3] = torch.clip((1 - (1 - t[:,:,:,:3]) * grain) * (1 - toe) + toe, 0, 1)
        return(t,)


class BlurImageFast:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "radius_x": ("INT", {"default": 1, "min": 0, "max": 1023, "step": 1}),
                "radius_y": ("INT", {"default": 1, "min": 0, "max": 1023, "step": 1}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "blur_image"
    CATEGORY = "Image-Filters/image"

    def blur_image(self, images, radius_x, radius_y):
        if radius_x + radius_y == 0:
            return (images,)
        
        dx = radius_x * 2 + 1
        dy = radius_y * 2 + 1
        
        dup = copy.deepcopy(images.cpu().numpy())
        
        for index, image in enumerate(dup):
            dup[index] = cv2.GaussianBlur(image, (dx, dy), 0)
        
        return (torch.from_numpy(dup),)


class BlurMaskFast:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "masks": ("MASK",),
                "radius_x": ("INT", {"default": 1, "min": 0, "max": 1023, "step": 1}),
                "radius_y": ("INT", {"default": 1, "min": 0, "max": 1023, "step": 1}),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "blur_mask"
    CATEGORY = "Image-Filters/mask"

    def blur_mask(self, masks, radius_x, radius_y):
        if radius_x + radius_y == 0:
            return (masks,)
        
        dx = radius_x * 2 + 1
        dy = radius_y * 2 + 1
        
        dup = copy.deepcopy(masks.cpu().numpy())
        
        for index, mask in enumerate(dup):
            dup[index] = cv2.GaussianBlur(mask, (dx, dy), 0)
        
        return (torch.from_numpy(dup),)


class ColorMatchImage:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", ),
                "reference": ("IMAGE", ),
                "blur_type": (["blur", "guidedFilter"],),
                "blur_size": ("INT", {"default": 0, "min": 0, "max": 1023}),
                "factor": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01,  "round": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "batch_normalize"
    CATEGORY = "Image-Filters/image"

    def batch_normalize(self, images, reference, blur_type, blur_size, factor):
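        # offset values away from zero so the per-pixel divisions below stay stable; removed again before output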
        t = images.detach().clone() + 0.1
        ref = reference.detach().clone() + 0.1
        
        if ref.shape[0] < t.shape[0]:
            ref = ref[0].unsqueeze(0).repeat(t.shape[0], 1, 1, 1)
        
        if blur_size == 0:
            mean = torch.mean(t, (1,2), keepdim=True)
            mean_ref = torch.mean(ref, (1,2), keepdim=True)
            
            for i in range(t.shape[0]):
                for c in range(3):
                    t[i,:,:,c] /= mean[i,0,0,c]
                    t[i,:,:,c] *= mean_ref[i,0,0,c]
        else:
            d = blur_size * 2 + 1
            
            if blur_type == "blur":
                blurred = cv_blur_tensor(t, d, d)
                blurred_ref = cv_blur_tensor(ref, d, d)
            elif blur_type == "guidedFilter":
                blurred = guided_filter_tensor(t, t, d, 0.01)
                blurred_ref = guided_filter_tensor(ref, ref, d, 0.01)
            
            for i in range(t.shape[0]):
                for c in range(3):
                    t[i,:,:,c] /= blurred[i,:,:,c]
                    t[i,:,:,c] *= blurred_ref[i,:,:,c]
        
        t = t - 0.1
        t = torch.clamp(torch.lerp(images, t, factor), 0, 1)
        return (t,)


class RestoreDetail:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", ),
                "detail": ("IMAGE", ),
                "mode": (["add", "multiply"],),
                "blur_type": (["blur", "guidedFilter"],),
                "blur_size": ("INT", {"default": 1, "min": 1, "max": 1023}),
                "factor": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01,  "round": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "batch_normalize"
    CATEGORY = "Image-Filters/image"

    def batch_normalize(self, images, detail, mode, blur_type, blur_size, factor):
        t = images.detach().clone() + 0.1
        ref = detail.detach().clone() + 0.1
        
        if ref.shape[0] < t.shape[0]:
            ref = ref[0].unsqueeze(0).repeat(t.shape[0], 1, 1, 1)
        
        d = blur_size * 2 + 1
        
        if blur_type == "blur":
            blurred = cv_blur_tensor(t, d, d)
            blurred_ref = cv_blur_tensor(ref, d, d)
        elif blur_type == "guidedFilter":
            blurred = guided_filter_tensor(t, t, d, 0.01)
            blurred_ref = guided_filter_tensor(ref, ref, d, 0.01)
        
        if mode == "multiply":
            t = (ref / blurred_ref) * blurred
        else:
            t = (ref - blurred_ref) + blurred
        
        t = t - 0.1
        t = torch.clamp(torch.lerp(images, t, factor), 0, 1)
        return (t,)


class DilateErodeMask:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "masks": ("MASK",),
                "radius": ("INT", {"default": 0, "min": -1023, "max": 1023, "step": 1}),
                "shape": (["box", "circle"],),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "dilate_mask"
    CATEGORY = "Image-Filters/mask"

    def dilate_mask(self, masks, radius, shape):
        if radius == 0:
            return (masks,)
        
        s = abs(radius)
        d = s * 2 + 1
        k = np.zeros((d, d), np.uint8)
        if shape == "circle":
            k = cv2.circle(k, (s,s), s, 1, -1)
        else:
            k += 1
        
        dup = copy.deepcopy(masks.cpu().numpy())
        
        for index, mask in enumerate(dup):
            if radius > 0:
                dup[index] = cv2.dilate(mask, k, iterations=1)
            else:
                dup[index] = cv2.erode(mask, k, iterations=1)
        
        return (torch.from_numpy(dup),)


class EnhanceDetail:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "filter_radius": ("INT", {"default": 2, "min": 1, "max": 64, "step": 1}),
                "sigma": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 100.0, "step": 0.01}),
                "denoise": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 10.0, "step": 0.01}),
                "detail_mult": ("FLOAT", {"default": 2.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "enhance"
    CATEGORY = "Image-Filters/image"

    def enhance(self, images: torch.Tensor, filter_radius: int, sigma: float, denoise: float, detail_mult: float):
        if filter_radius == 0:
            return (images,)
        
        d = filter_radius * 2 + 1
        s = sigma / 10
        n = denoise / 10
        
        dup = copy.deepcopy(images.cpu().numpy())
        
        for index, image in enumerate(dup):
            imgB = image
            if denoise > 0.0:
                imgB = cv2.bilateralFilter(image, d, n, d)
            
            imgG = np.clip(guidedFilter(image, image, d, s), 0.001, 1)
            
            details = (imgB/imgG - 1) * detail_mult + 1
            dup[index] = np.clip(details*imgG - imgB + image, 0, 1)
        
        return (torch.from_numpy(dup),)


class GuidedFilterImage:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", ),
                "guide": ("IMAGE", ),
                "size": ("INT", {"default": 4, "min": 0, "max": 1023}),
                "sigma": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 100.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "filter_image"
    CATEGORY = "Image-Filters/image"

    def filter_image(self, images, guide, size, sigma):
        d = size * 2 + 1
        s = sigma / 10
        filtered = guided_filter_tensor(guide, images, d, s)
        return (filtered,)


class MedianFilterImage:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", ),
                "size": ("INT", {"default": 1, "min": 1, "max": 1023}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "filter_image"
    CATEGORY = "Image-Filters/image"

    def filter_image(self, images, size):
        np_images = images.detach().clone().cpu().numpy()
        d = size * 2 + 1
        for index, image in enumerate(np_images):
            if d > 5:
                work_image = image * 255
                work_image = cv2.medianBlur(work_image.astype(np.uint8), d)
                np_images[index] = work_image.astype(np.float32) / 255
            else:
                np_images[index] = cv2.medianBlur(image, d)
        return (torch.from_numpy(np_images),)


class BilateralFilterImage:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", ),
                "size": ("INT", {"default": 8, "min": 1, "max": 64}),
                "sigma_color": ("FLOAT", {"default": 0.5, "min": 0.01, "max": 1000.0, "step": 0.01}),
                "sigma_space": ("FLOAT", {"default": 100.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "filter_image"
    CATEGORY = "Image-Filters/image"

    def filter_image(self, images, size, sigma_color, sigma_space):
        np_images = images.detach().clone().cpu().numpy()
        d = size * 2 + 1
        for index, image in enumerate(np_images):
            np_images[index] = cv2.bilateralFilter(image, d, sigma_color, sigma_space)
        return (torch.from_numpy(np_images),)


class FrequencyCombine:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "high_frequency": ("IMAGE", ),
                "low_frequency": ("IMAGE", ),
                "mode": (["subtract", "divide"],),
                "eps": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 0.99, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "filter_image"
    CATEGORY = "Image-Filters/image"

    def filter_image(self, high_frequency, low_frequency, mode, eps):
        t = low_frequency.detach().clone()
        if mode == "subtract":
            t = t + high_frequency - 0.5
        else:
            t = (high_frequency * 2) * (t + eps) - eps
        return (torch.clamp(t, 0, 1),)


class FrequencySeparate:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "original": ("IMAGE", ),
                "low_frequency": ("IMAGE", ),
                "mode": (["subtract", "divide"],),
                "eps": ("FLOAT", {"default": 0.1, "min": 0.01, "max": 0.99, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("high_frequency",)
    FUNCTION = "filter_image"
    CATEGORY = "Image-Filters/image"

    def filter_image(self, original, low_frequency, mode, eps):
        t = original.detach().clone()
        if mode == "subtract":
            t = t - low_frequency + 0.5
        else:
            t = ((t + eps) / (low_frequency + eps)) * 0.5
        return (t,)


class RemapRange:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "blackpoint": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "whitepoint": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 1.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remap"
    CATEGORY = "Image-Filters/image"

    def remap(self, image: torch.Tensor, blackpoint: float, whitepoint: float):
        bp = min(blackpoint, whitepoint - 0.001)
        scale = 1 / (whitepoint - bp)
        
        i_dup = copy.deepcopy(image.cpu().numpy())
        i_dup = np.clip((i_dup - bp) * scale, 0.0, 1.0)
        
        return (torch.from_numpy(i_dup),)


class ClampImage:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "blackpoint": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "whitepoint": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "clamp_image"
    CATEGORY = "Image-Filters/image"

    def clamp_image(self, image: torch.Tensor, blackpoint: float, whitepoint: float):
        clamped_image = torch.clamp(torch.nan_to_num(image.detach().clone()), min=blackpoint, max=whitepoint)
        return (clamped_image,)


Channel_List = ["red", "green", "blue", "alpha", "white", "black"]
Alpha_List = ["red", "green", "blue", "alpha", "white", "black", "none"]

class ShuffleChannels:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "red": (Channel_List, {"default": "red"}),
                "green": (Channel_List, {"default": "green"}),
                "blue": (Channel_List, {"default": "blue"}),
                "alpha": (Alpha_List, {"default": "none"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "shuffle"
    CATEGORY = "Image-Filters/image"

    def shuffle(self, image, red, green, blue, alpha):
        ch = 3 if alpha == "none" else 4
        t = torch.zeros((image.shape[0], image.shape[1], image.shape[2], ch), dtype=image.dtype, device=image.device)
        image_copy = image.detach().clone()
        
        ch_key = [red, green, blue, alpha]
        for i in range(ch):
            if ch_key[i] == "white":
                t[:,:,:,i] = 1
            elif ch_key[i] == "red":
                t[:,:,:,i] = image_copy[:,:,:,0]
            elif ch_key[i] == "green":
                t[:,:,:,i] = image_copy[:,:,:,1]
            elif ch_key[i] == "blue":
                t[:,:,:,i] = image_copy[:,:,:,2]
            elif ch_key[i] == "alpha":
                if image.shape[3] > 3:
                    t[:,:,:,i] = image_copy[:,:,:,3]
                else:
                    t[:,:,:,i] = 1
        
        return(t,)


class ClampOutliers:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latents": ("LATENT", ),
                "std_dev": ("FLOAT", {"default": 3.0, "min": 0.1, "max": 100.0, "step": 0.1,  "round": 0.1}),
            },
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "clamp_outliers"
    CATEGORY = "Image-Filters/latent"

    def clamp_outliers(self, latents, std_dev):
        latents_copy = copy.deepcopy(latents)
        t = latents_copy["samples"]
        
        for i, latent in enumerate(t):
            for j, channel in enumerate(latent):
                sd, mean = torch.std_mean(channel, dim=None)
                t[i,j] = torch.clamp(channel, min = -sd * std_dev + mean, max = sd * std_dev + mean)
        
        latents_copy["samples"] = t
        return (latents_copy,)


class AdainLatent:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latents": ("LATENT", ),
                "reference": ("LATENT", ),
                "factor": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01,  "round": 0.01}),
            },
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "batch_normalize"
    CATEGORY = "Image-Filters/latent"

    def batch_normalize(self, latents, reference, factor):
        latents_copy = copy.deepcopy(latents)
        t = latents_copy["samples"]
        
        t_std, t_mean = torch.std_mean(t, dim=(-2, -1), keepdim=True)
        ref_std, ref_mean = torch.std_mean(reference["samples"], dim=(-2, -1), keepdim=True)
        t = (t - t_mean) / t_std
        t = t * ref_std + ref_mean
        
        latents_copy["samples"] = torch.lerp(latents["samples"], t, factor)
        return (latents_copy,)


class AdainFilterLatent:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latents": ("LATENT", ),
                "reference": ("LATENT", ),
                "filter_size": ("INT", {"default": 1, "min": 1, "max": 128}),
                "factor": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01,  "round": 0.01}),
            },
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "batch_normalize"
    CATEGORY = "Image-Filters/latent"

    def batch_normalize(self, latents, reference, filter_size, factor):
        latents_copy = copy.deepcopy(latents)
        t = latents_copy["samples"].movedim(1, -1) # BCHW -> BHWC or BCFHW -> BFHWC
        ref = reference["samples"].movedim(1, -1)
        d = filter_size * 2 + 1
        
        if t.dim() == 5:
            t_std, t_mean, ref_std, ref_mean = [], [], [], []
            for b in range(t.shape[0]):
                tb_std, tb_mean = std_mean_filter(t[b], d)
                rb_std, rb_mean = std_mean_filter(ref[b], d)
                t_std.append(tb_std)
                t_mean.append(tb_mean)
                ref_std.append(rb_std)
                ref_mean.append(rb_mean)
            t_std = torch.stack(t_std, dim=0)
            t_mean = torch.stack(t_mean, dim=0)
            ref_std = torch.stack(ref_std, dim=0)
            ref_mean = torch.stack(ref_mean, dim=0)
        else:
            t_std, t_mean = std_mean_filter(t, d)
            ref_std, ref_mean = std_mean_filter(ref, d)
        
        t = (t - t_mean) / t_std
        t = t * ref_std + ref_mean
        t = t.movedim(-1, 1) # BHWC -> BCHW or BFHWC -> BCFHW
        
        latents_copy["samples"] = torch.lerp(latents["samples"], t, factor)
        return (latents_copy,)


class SharpenFilterLatent:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latents": ("LATENT", ),
                "filter_size": ("INT", {"default": 1, "min": 1, "max": 128}),
                "factor": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01,  "round": 0.01}),
            },
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "filter_latent"
    CATEGORY = "Image-Filters/latent"

    def filter_latent(self, latents, filter_size, factor):
        latents_copy = copy.deepcopy(latents)
        t = latents_copy["samples"].movedim(1, -1) # BCHW -> BHWC or BCFHW -> BFHWC
        d = filter_size * 2 + 1
        
        if t.dim() == 5:
            t_blurred = []
            for b in range(t.shape[0]):
                t_blurred.append(cv_blur_tensor(t[b], d, d))
            t_blurred = torch.stack(t_blurred, dim=0)
        else:
            t_blurred = cv_blur_tensor(t, d, d)
        
        t = t - t_blurred
        t = t * factor
        t = t + t_blurred
        
        t = t.movedim(-1, 1) # BHWC -> BCHW or BFHWC -> BCFHW
        latents_copy["samples"] = t
        return (latents_copy,)


class AdainImage:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", ),
                "reference": ("IMAGE", ),
                "factor": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01,  "round": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "batch_normalize"
    CATEGORY = "Image-Filters/image"

    def batch_normalize(self, images, reference, factor):
        t = copy.deepcopy(images) # [B x H x W x C]
        
        t = t.movedim(-1,0) # [C x B x H x W]
        for c in range(t.size(0)):
            for i in range(t.size(1)):
                r_sd, r_mean = torch.std_mean(reference[i, :, :, c], dim=None) # index by original dim order
                i_sd, i_mean = torch.std_mean(t[c, i], dim=None)
                
                t[c, i] = ((t[c, i] - i_mean) / i_sd) * r_sd + r_mean
        
        t = torch.lerp(images, t.movedim(0,-1), factor) # [B x H x W x C]
        return (t,)


class BatchNormalizeLatent:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latents": ("LATENT", ),
                "factor": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01,  "round": 0.01}),
            },
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "batch_normalize"
    CATEGORY = "Image-Filters/latent"

    def batch_normalize(self, latents, factor):
        latents_copy = copy.deepcopy(latents)
        t = latents_copy["samples"] # [B x C x H x W]
        
        t = t.movedim(0,1) # [C x B x H x W]
        for c in range(t.size(0)):
            c_sd, c_mean = torch.std_mean(t[c], dim=None)
            
            for i in range(t.size(1)):
                i_sd, i_mean = torch.std_mean(t[c, i], dim=None)
                t[c, i] = (t[c, i] - i_mean) / i_sd
            
            t[c] = t[c] * c_sd + c_mean
        
        latents_copy["samples"] = torch.lerp(latents["samples"], t.movedim(1,0), factor) # [B x C x H x W]
        return (latents_copy,)


class BatchNormalizeImage:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE", ),
                "factor": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01,  "round": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "batch_normalize"
    CATEGORY = "Image-Filters/image"

    def batch_normalize(self, images, factor):
        t = copy.deepcopy(images) # [B x H x W x C]
        
        t = t.movedim(-1,0) # [C x B x H x W]
        for c in range(t.size(0)):
            c_sd, c_mean = torch.std_mean(t[c], dim=None)
            
            for i in range(t.size(1)):
                i_sd, i_mean = torch.std_mean(t[c, i], dim=None)
                
                t[c, i] = (t[c, i] - i_mean) / i_sd
            
            t[c] = t[c] * c_sd + c_mean
        
        t = torch.lerp(images, t.movedim(0,-1), factor) # [B x H x W x C]
        return (t,)


class DifferenceChecker:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images1": ("IMAGE", ),
                "images2": ("IMAGE", ),
                "multiplier": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 1000.0, "step": 0.01,  "round": 0.01}),
                "print_MAE": ("BOOLEAN", {"default": False}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "difference_checker"
    OUTPUT_NODE = True
    CATEGORY = "Image-Filters/image"

    def difference_checker(self, images1, images2, multiplier, print_MAE):
        t = torch.abs(images1 - images2)
        if print_MAE:
            print(f"MAE = {torch.mean(t)}")
        return (torch.clamp(t * multiplier, min=0, max=1),)


class ImageConstant:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                "red": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "green": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "blue": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate"
    CATEGORY = "Image-Filters/image"

    def generate(self, width, height, batch_size, red, green, blue):
        r = torch.full([batch_size, height, width, 1], red)
        g = torch.full([batch_size, height, width, 1], green)
        b = torch.full([batch_size, height, width, 1], blue)
        return (torch.cat((r, g, b), dim=-1), )


class ImageConstantHSV:
    def __init__(self, device="cpu"):
        self.device = device

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                "hue": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "saturation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "value": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "generate"
    CATEGORY = "Image-Filters/image"

    def generate(self, width, height, batch_size, hue, saturation, value):
        red, green, blue = hsv_to_rgb(hue, saturation, value)
        
        r = torch.full([batch_size, height, width, 1], red)
        g = torch.full([batch_size, height, width, 1], green)
        b = torch.full([batch_size, height, width, 1], blue)
        return (torch.cat((r, g, b), dim=-1), )


class OffsetLatentImage:
    def __init__(self):
        self.device = comfy.model_management.intermediate_device()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "width": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
                "height": ("INT", {"default": 512, "min": 16, "max": MAX_RESOLUTION, "step": 8}),
                "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                "offset_0": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.1,  "round": 0.1}),
                "offset_1": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.1,  "round": 0.1}),
                "offset_2": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.1,  "round": 0.1}),
                "offset_3": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.1,  "round": 0.1}),
            },
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"
    CATEGORY = "Image-Filters/latent"

    def generate(self, width, height, batch_size, offset_0, offset_1, offset_2, offset_3):
        latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
        latent[:,0,:,:] = offset_0
        latent[:,1,:,:] = offset_1
        latent[:,2,:,:] = offset_2
        latent[:,3,:,:] = offset_3
        return ({"samples":latent}, )


class RelightSimple:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "normals": ("IMAGE",),
                "x": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}),
                "y": ("FLOAT", {"default": 0.0, "min": -1.0, "max": 1.0, "step": 0.001}),
                "z": ("FLOAT", {"default": 1.0, "min": -1.0, "max": 1.0, "step": 0.001}),
                "brightness": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "relight"
    CATEGORY = "Image-Filters/image"

    def relight(self, image, normals, x, y, z, brightness):
        if image.shape[0] != normals.shape[0]:
            raise Exception("Batch size for image and normals must match")
        norm = normals.detach().clone() * 2 - 1
        norm = F.interpolate(norm.movedim(-1,1), size=(image.shape[1], image.shape[2]), mode='bilinear').movedim(1,-1)
        light = torch.tensor([x, y, z])
        light = F.normalize(light, dim=0)
        
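        # Lambertian diffuse term: per-pixel dot(N, L), clamped to [0, 1] and broadcast to RGB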
        diffuse = norm[:,:,:,0] * light[0] + norm[:,:,:,1] * light[1] + norm[:,:,:,2] * light[2]
        diffuse = torch.clip(diffuse.unsqueeze(3).repeat(1,1,1,3), 0, 1)
        
        relit = image.detach().clone()
        relit[:,:,:,:3] = torch.clip(relit[:,:,:,:3] * diffuse * brightness, 0, 1)
        return (relit,)


class LatentStats:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"latent": ("LATENT", ),}}

    RETURN_TYPES = ("STRING", "FLOAT", "FLOAT", "FLOAT", "FLOAT")
    RETURN_NAMES = ("stats", "c0_mean", "c1_mean", "c2_mean", "c3_mean")
    FUNCTION = "notify"
    OUTPUT_NODE = True
    CATEGORY = "Image-Filters/utils"

    def notify(self, latent):
        latents = latent["samples"]
        channels = latents.size(1)
        width, height = latents.size(3), latents.size(2)
        
        text = ["",]
        text[0] = f"batch size: {latents.size(0)}"
        text.append(f"channels: {channels}")
        text.append(f"width: {width} ({width * 8})")
        text.append(f"height: {height} ({height * 8})")
        
        cmean = [0,0,0,0]
        for i in range(channels):
            minimum = torch.min(latents[:,i,:,:]).item()
            maximum = torch.max(latents[:,i,:,:]).item()
            std_dev, mean = torch.std_mean(latents[:,i,:,:], dim=None)
            if i < 4:
                cmean[i] = mean.item()
            
            text.append(f"c{i} mean: {mean:.1f} std_dev: {std_dev:.1f} min: {minimum:.1f} max: {maximum:.1f}")
        
        
        printtext = "\033[36mLatent Stats:\033[m"
        for t in text:
            printtext += "\n    " + t
        
        returntext = ""
        for i in range(len(text)):
            if i > 0:
                returntext += "\n"
            returntext += text[i]
        
        print(printtext)
        return (returntext, cmean[0], cmean[1], cmean[2], cmean[3])


class Tonemap:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "input_mode": (["linear", "sRGB"],),
                "output_mode": (["sRGB", "linear"],),
                "tonemap_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "apply"
    CATEGORY = "Image-Filters/image"

    def apply(self, images, input_mode, output_mode, tonemap_scale):
        t = images.detach().clone().cpu().numpy().astype(np.float32)
        
        if input_mode == "sRGB":
            sRGBtoLinear(t[:,:,:,:3])
        
        linearToTonemap(t[:,:,:,:3], tonemap_scale)
        
        if output_mode == "sRGB":
            linearToSRGB(t[:,:,:,:3])
            t = np.clip(t, 0, 1)
        
        t = torch.from_numpy(t)
        return (t,)


class UnTonemap:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "input_mode": (["sRGB", "linear"],),
                "output_mode": (["linear", "sRGB"],),
                "tonemap_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "apply"
    CATEGORY = "Image-Filters/image"

    def apply(self, images, input_mode, output_mode, tonemap_scale):
        t = images.detach().clone().cpu().numpy().astype(np.float32)
        
        if input_mode == "sRGB":
            sRGBtoLinear(t[:,:,:,:3])
        
        tonemapToLinear(t[:,:,:,:3], tonemap_scale)
        
        if output_mode == "sRGB":
            linearToSRGB(t[:,:,:,:3])
            t = np.clip(t, 0, 1)
        
        t = torch.from_numpy(t)
        return (t,)


class ExposureAdjust:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "stops": ("FLOAT", {"default": 0.0, "min": -100, "max": 100, "step": 0.01}),
                "input_mode": (["sRGB", "linear"],),
                "output_mode": (["sRGB", "linear"],),
                "tonemap": (["linear", "Reinhard", "linlog"], {"default": "Reinhard"}),
                "tonemap_scale": ("FLOAT", {"default": 1, "min": 0.1, "max": 10, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "adjust_exposure"
    CATEGORY = "Image-Filters/image"

    def adjust_exposure(self, images, stops, input_mode, output_mode, tonemap, tonemap_scale):
        t = images.detach().clone().cpu().numpy().astype(np.float32)
        
        if input_mode == "sRGB":
            sRGBtoLinear(t[...,:3])
        
        if tonemap == "linlog":
            tonemapToLinear(t[...,:3], tonemap_scale)
        elif tonemap == "Reinhard":
            t = np.clip(t, 0, 0.999)
            t[...,:3] = -t[...,:3] / (t[...,:3] - 1)
        
        exposure(t[...,:3], stops)
        
        if tonemap == "linlog":
            linearToTonemap(t[...,:3], tonemap_scale)
        elif tonemap == "Reinhard":
            t[...,:3] = t[...,:3] / (t[...,:3] + 1)
        
        if output_mode == "sRGB":
            linearToSRGB(t[...,:3])
            t = np.clip(t, 0, 1)
        
        t = torch.from_numpy(t)
        return (t,)
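
# Note on the "Reinhard" path in ExposureAdjust: the forward curve is y = x / (x + 1),
# and the inverse applied before exposure() is x = -y / (y - 1) = y / (1 - y), so the
# exposure shift happens in (approximately) linear light before re-applying the curve.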


# Normal map standard coordinates: +r:+x:right, +g:+y:up, +b:+z:in
class ConvertNormals:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "normals": ("IMAGE",),
                "input_mode": (["BAE", "MiDaS", "Standard"],),
                "output_mode": (["BAE", "MiDaS", "Standard"],),
                "scale_XY": ("FLOAT",{"default": 1, "min": 0, "max": 100, "step": 0.001}),
                "normalize": ("BOOLEAN", {"default": True}),
                "fix_black": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "optional_fill": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "convert_normals"
    CATEGORY = "Image-Filters/image"

    def convert_normals(self, normals, input_mode, output_mode, scale_XY, normalize, fix_black, optional_fill=None):
        t = normals.detach().clone()
        
        if input_mode == "BAE":
            t[:,:,:,0] = 1 - t[:,:,:,0] # invert R
        elif input_mode == "MiDaS":
            t[:,:,:,:3] = torch.stack([1 - t[:,:,:,2], t[:,:,:,1], t[:,:,:,0]], dim=3) # BGR -> RGB and invert R
        
        if fix_black:
            key = torch.clamp(1 - t[:,:,:,2] * 2, min=0, max=1)
            if optional_fill is None:
                t[:,:,:,0] += key * 0.5
                t[:,:,:,1] += key * 0.5
                t[:,:,:,2] += key
            else:
                fill = optional_fill.detach().clone()
                if fill.shape[1:3] != t.shape[1:3]:
                    fill = F.interpolate(fill.movedim(-1,1), size=(t.shape[1], t.shape[2]), mode='bilinear').movedim(1,-1)
                if fill.shape[0] != t.shape[0]:
                    fill = fill[0].unsqueeze(0).expand(t.shape[0], -1, -1, -1)
                t[:,:,:,:3] += fill[:,:,:,:3] * key.unsqueeze(3).expand(-1, -1, -1, 3)
        
        t[:,:,:,:2] = (t[:,:,:,:2] - 0.5) * scale_XY + 0.5
        
        if normalize:
            t[:,:,:,:3] = F.normalize(t[:,:,:,:3] * 2 - 1, dim=3) / 2 + 0.5
        
        if output_mode == "BAE":
            t[:,:,:,0] = 1 - t[:,:,:,0] # invert R
        elif output_mode == "MiDaS":
            t[:,:,:,:3] = torch.stack([t[:,:,:,2], t[:,:,:,1], 1 - t[:,:,:,0]], dim=3) # invert R and BGR -> RGB
        
        return (t,)


class BatchAverageImage:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "operation": (["mean", "median"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "apply"
    CATEGORY = "Image-Filters/image"

    def apply(self, images, operation):
        t = images.detach().clone()
        if operation == "mean":
            return (torch.mean(t, dim=0, keepdim=True),)
        elif operation == "median":
            return (torch.median(t, dim=0, keepdim=True)[0],)
        return(t,)


class NormalMapSimple:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "scale_XY": ("FLOAT",{"default": 1, "min": 0, "max": 100, "step": 0.001}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "normal_map"
    CATEGORY = "Image-Filters/image"

    def normal_map(self, images, scale_XY):
        t = images.detach().clone().cpu().numpy().astype(np.float32)
        L = np.mean(t[:,:,:,:3], axis=3)
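        # treat mean luminance as a height map: Scharr derivatives give its gradient,
        # which becomes the XY tilt of the normal (Z fixed to 1, then normalized)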
        for i in range(t.shape[0]):
            t[i,:,:,0] = cv2.Scharr(L[i], -1, 1, 0, borderType=cv2.BORDER_REFLECT) * -1
            t[i,:,:,1] = cv2.Scharr(L[i], -1, 0, 1, borderType=cv2.BORDER_REFLECT)
        t[:,:,:,2] = 1
        t = torch.from_numpy(t)
        t[:,:,:,:2] *= scale_XY
        t[:,:,:,:3] = F.normalize(t[:,:,:,:3], dim=3) / 2 + 0.5
        return (t,)


class DepthToNormals:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "depth": ("IMAGE",),
                "scale": ("FLOAT",{"default": 1, "min": 0.001, "max": 1000, "step": 0.001}),
                "output_mode": (["Standard", "BAE", "MiDaS"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("normals",)
    FUNCTION = "normal_map"
    CATEGORY = "Image-Filters/image"

    def normal_map(self, depth, scale, output_mode):
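        # central-difference kernels: the x/y gradients of the position map act as
        # surface tangents, and their cross product (below) yields the surface normal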
        kernel_x = torch.Tensor([[0,0,0],[1,0,-1],[0,0,0]]).unsqueeze(0).unsqueeze(0).repeat(3, 1, 1, 1)
        kernel_y = torch.Tensor([[0,1,0],[0,0,0],[0,-1,0]]).unsqueeze(0).unsqueeze(0).repeat(3, 1, 1, 1)
        conv2d = F.conv2d
        pad = F.pad
        
        size_x = depth.size(2)
        size_y = depth.size(1)
        max_dim = max(size_x, size_y)
        position_map = depth.detach().clone() * scale
        xs = torch.linspace(-1 * size_x / max_dim, 1 * size_x / max_dim, steps=size_x)
        ys = torch.linspace(-1 * size_y / max_dim, 1 * size_y / max_dim, steps=size_y)
        grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy')
        position_map[..., 0] = grid_x.unsqueeze(0)
        position_map[..., 1] = grid_y.unsqueeze(0)
        
        position_map = position_map.movedim(-1, 1) # BCHW
        grad_x = conv2d(pad(position_map, (1,1,1,1), mode='replicate'), kernel_x, padding='valid', groups=3)
        grad_y = conv2d(pad(position_map, (1,1,1,1), mode='replicate'), kernel_y, padding='valid', groups=3)
        
        cross_product = torch.cross(grad_x, grad_y, dim=1)
        normals = F.normalize(cross_product)
        normals[:, 1] *= -1
        
        if output_mode != "Standard":
            normals[:, 0] *= -1
        
        if output_mode == "MiDaS":
            normals = torch.flip(normals, dims=[1,])
        
        normals = normals.movedim(1, -1) * 0.5 + 0.5 # BHWC
        return (normals,)


class Keyer:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "operation": (["luminance", "saturation", "max", "min", "red", "green", "blue", "redscreen", "greenscreen", "bluescreen"],),
                "low": ("FLOAT",{"default": 0, "step": 0.001}),
                "high": ("FLOAT",{"default": 1, "step": 0.001}),
                "gamma": ("FLOAT",{"default": 1.0, "min": 0.001, "step": 0.001}),
                "premult": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE", "MASK")
    RETURN_NAMES = ("image", "alpha", "mask")
    FUNCTION = "keyer"
    CATEGORY = "Image-Filters/image"

    def keyer(self, images, operation, low, high, gamma, premult):
        t = images[:,:,:,:3].detach().clone()
        
        if operation == "luminance":
            alpha = 0.2126 * t[:,:,:,0] + 0.7152 * t[:,:,:,1] + 0.0722 * t[:,:,:,2]
        elif operation == "saturation":
            minV = torch.min(t, 3)[0]
            maxV = torch.max(t, 3)[0]
            mask = maxV != 0
            alpha = maxV
            alpha[mask] = (maxV[mask] - minV[mask]) / maxV[mask]
        elif operation == "max":
            alpha = torch.max(t, 3)[0]
        elif operation == "min":
            alpha = torch.min(t, 3)[0]
        elif operation == "red":
            alpha = t[:,:,:,0]
        elif operation == "green":
            alpha = t[:,:,:,1]
        elif operation == "blue":
            alpha = t[:,:,:,2]
        elif operation == "redscreen":
            alpha = 0.7 * (t[:,:,:,1] + t[:,:,:,2]) - t[:,:,:,0] + 1
        elif operation == "greenscreen":
            alpha = 0.7 * (t[:,:,:,0] + t[:,:,:,2]) - t[:,:,:,1] + 1
        elif operation == "bluescreen":
            alpha = 0.7 * (t[:,:,:,0] + t[:,:,:,1]) - t[:,:,:,2] + 1
        else: # should never be reached
            alpha = t[:,:,:,0] * 0
        
        if low == high:
            alpha = (alpha > high).to(t.dtype)
        else:
            alpha = (alpha - low) / (high - low)
        
        if gamma != 1.0:
            alpha = torch.pow(alpha, 1/gamma)
        alpha = torch.clamp(alpha, min=0, max=1).unsqueeze(3).repeat(1,1,1,3)
        if premult:
            t *= alpha
        return (t, alpha, alpha[:,:,:,0])


jitter_matrix = torch.Tensor([[[1, 0, 0], [0, 1, 0]], [[1, 0, 1], [0, 1, 0]], [[1, 0, 1], [0, 1, 1]],
                              [[1, 0, 0], [0, 1, 1]], [[1, 0,-1], [0, 1, 1]], [[1, 0,-1], [0, 1, 0]],
                              [[1, 0,-1], [0, 1,-1]], [[1, 0, 0], [0, 1,-1]], [[1, 0, 1], [0, 1,-1]]])
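# ^ nine affine matrices: the identity plus the eight one-pixel offsets of a 3x3
#   neighborhood; the translation column is scaled by jitter_scale and normalized
#   to the image size before sampling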

class JitterImage:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "jitter_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "step": 0.1}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "jitter"
    CATEGORY = "Image-Filters/image/jitter"

    def jitter(self, images, jitter_scale):
        t = images.detach().clone().movedim(-1,1) # [B x C x H x W]
        
        theta = jitter_matrix.detach().clone().to(t.device)
        theta[:,0,2] *= jitter_scale * 2 / t.shape[3]
        theta[:,1,2] *= jitter_scale * 2 / t.shape[2]
        affine = F.affine_grid(theta, torch.Size([9, t.shape[1], t.shape[2], t.shape[3]]))
        
        batch = []
        for i in range(t.shape[0]):
            jb = t[i].repeat(9,1,1,1)
            jb = F.grid_sample(jb, affine, mode='bilinear', padding_mode='border', align_corners=None)
            batch.append(jb)
        
        t = torch.cat(batch, dim=0).movedim(1,-1) # [B x H x W x C]
        return (t,)


class UnJitterImage:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "jitter_scale": ("FLOAT", {"default": 1.0, "min": 0.1, "step": 0.1}),
                "oflow_align": ("BOOLEAN", {"default": False}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "jitter"
    CATEGORY = "Image-Filters/image/jitter"

    def jitter(self, images, jitter_scale, oflow_align):
        t = images.detach().clone().movedim(-1,1) # [B x C x H x W]
        
        if oflow_align:
            pbar = ProgressBar(t.shape[0] // 9)
            raft_model, raft_device = load_raft()
            batch = []
            for i in trange(t.shape[0] // 9):
                batch1 = t[i*9].unsqueeze(0).repeat(9,1,1,1)
                batch2 = t[i*9:i*9+9]
                flows = raft_flow(raft_model, raft_device, batch1, batch2)
                batch.append(flows)
                pbar.update(1)
            flows = torch.cat(batch, dim=0)
            t = flow_warp(t, flows)
        else:
            theta = jitter_matrix.detach().clone().to(t.device)
            theta[:,0,2] *= jitter_scale * -2 / t.shape[3]
            theta[:,1,2] *= jitter_scale * -2 / t.shape[2]
            affine = F.affine_grid(theta, torch.Size([9, t.shape[1], t.shape[2], t.shape[3]]))
            batch = []
            for i in range(t.shape[0] // 9):
                jb = t[i*9:i*9+9]
                jb = F.grid_sample(jb, affine, mode='bicubic', padding_mode='border', align_corners=None)
                batch.append(jb)
            t = torch.cat(batch, dim=0)
        
        t = t.movedim(1,-1) # [B x H x W x C]
        return (t,)


class BatchAverageUnJittered:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "operation": (["mean", "median"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "apply"
    CATEGORY = "Image-Filters/image/jitter"

    def apply(self, images, operation):
        t = images.detach().clone()
        
        batch = []
        for i in range(t.shape[0] // 9):
            if operation == "mean":
                batch.append(torch.mean(t[i*9:i*9+9], dim=0, keepdim=True))
            elif operation == "median":
                batch.append(torch.median(t[i*9:i*9+9], dim=0, keepdim=True)[0])
        
        return (torch.cat(batch, dim=0),)


class BatchAlign:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "ref_frame": ("INT", {"default": 0, "min": 0}),
                "direction": (["forward", "backward"],),
                "blur": ("INT", {"default": 0, "min": 0}),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE")
    RETURN_NAMES = ("aligned", "flow")
    FUNCTION = "apply"
    CATEGORY = "Image-Filters/image"

    def apply(self, images, ref_frame, direction, blur):
        t = images.detach().clone().movedim(-1,1) # [B x C x H x W]
        rf = min(ref_frame, t.shape[0] - 1)
        
        raft_model, raft_device = load_raft()
        ref = t[rf].unsqueeze(0).repeat(t.shape[0],1,1,1)
        if direction == "forward":
            flows = raft_flow(raft_model, raft_device, ref, t)
        else:
            flows = raft_flow(raft_model, raft_device, t, ref) * -1
        
        if blur > 0:
            d = blur * 2 + 1
            dup = flows.movedim(1,-1).detach().clone().cpu().numpy()
            blurred = []
            for img in dup:
                blurred.append(torch.from_numpy(cv2.GaussianBlur(img, (d,d), 0)).unsqueeze(0).movedim(-1,1))
            flows = torch.cat(blurred).to(flows.device)
        
        t = flow_warp(t, flows)
        
        t = t.movedim(1,-1) # [B x H x W x C]
        f = images.detach().clone() * 0
        f[:,:,:,:2] = flows.movedim(1,-1)
        return (t,f)


class InstructPixToPixConditioningAdvanced:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "positive": ("CONDITIONING", ),
                "negative": ("CONDITIONING", ),
                "new": ("LATENT", ),
                "new_scale": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 100.0, "step": 0.01}),
                "original": ("LATENT", ),
                "original_scale": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 100.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("CONDITIONING","CONDITIONING","CONDITIONING","LATENT")
    RETURN_NAMES = ("cond1", "cond2", "negative", "latent")
    FUNCTION = "encode"
    CATEGORY = "Image-Filters/conditioning"

    def encode(self, positive, negative, new, new_scale, original, original_scale):
        new_shape, orig_shape = new["samples"].shape, original["samples"].shape
        if new_shape != orig_shape:
            raise Exception(f"Latent shape mismatch: {new_shape} and {orig_shape}")
        
        out_latent = {}
        out_latent["samples"] = new["samples"] * new_scale
        out = []
        for conditioning in [positive, negative]:
            c = []
            for t in conditioning:
                d = t[1].copy()
                d["concat_latent_image"] = original["samples"] * original_scale
                n = [t[0], d]
                c.append(n)
            out.append(c)
        return (out[0], out[1], negative, out_latent)


class InpaintConditionEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "vae": ("VAE", ),
                "pixels": ("IMAGE", ),
                "mask": ("MASK", ),
            },}

    RETURN_TYPES = ("INPAINT_CONDITION",)
    RETURN_NAMES = ("inpaint_condition",)
    FUNCTION = "encode"
    CATEGORY = "Image-Filters/conditioning"

    def encode(self, vae, pixels, mask):
        x = (pixels.shape[1] // 8) * 8
        y = (pixels.shape[2] // 8) * 8
        mask = F.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")

        pixels = pixels.clone()
        if pixels.shape[1] != x or pixels.shape[2] != y:
            x_offset = (pixels.shape[1] % 8) // 2
            y_offset = (pixels.shape[2] % 8) // 2
            pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
            mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]

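        # gray out the masked region (to 0.5) before encoding, the same convention
        # ComfyUI's VAEEncodeForInpaint uses so the VAE only sees unmasked content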
        m = (1.0 - mask.round()).squeeze(1)
        for i in range(3):
            pixels[:,:,:,i] -= 0.5
            pixels[:,:,:,i] *= m
            pixels[:,:,:,i] += 0.5
        concat_latent = vae.encode(pixels)
        
        return ({"concat_latent_image": concat_latent, "concat_mask": mask},)


class InpaintConditionApply:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "positive": ("CONDITIONING", ),
                "negative": ("CONDITIONING", ),
                "inpaint_condition": ("INPAINT_CONDITION", ),
                "noise_mask": ("BOOLEAN", {"default": False, "tooltip": "Add a noise mask to the latent so sampling will only happen within the mask. Might improve results or completely break things depending on the model."}),
                },
            "optional": {
                "latents_optional": ("LATENT",),
            },}

    RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT")
    RETURN_NAMES = ("positive", "negative", "latent")
    FUNCTION = "encode"
    CATEGORY = "Image-Filters/conditioning"

    def encode(self, positive, negative, inpaint_condition, noise_mask=False, latents_optional=None):
        concat_latent = inpaint_condition["concat_latent_image"]
        concat_mask = inpaint_condition["concat_mask"]
        
        if latents_optional is not None:
            out_latent = latents_optional.copy()
        else:
            out_latent = {}
            out_latent["samples"] = torch.zeros_like(concat_latent)
        
        if noise_mask:
            out_latent["noise_mask"] = concat_mask

        out = []
        for conditioning in [positive, negative]:
            c = node_helpers.conditioning_set_values(conditioning, {"concat_latent_image": concat_latent,
                                                                    "concat_mask": concat_mask})
            out.append(c)
        return (out[0], out[1], out_latent)


class LatentNormalizeShuffle:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latents": ("LATENT", ),
                "flatten": ("INT", {"default": 0, "min": 0, "max": 16}),
                "normalize": ("BOOLEAN", {"default": True}),
                "shuffle": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "batch_normalize"
    CATEGORY = "Image-Filters/latent"

    def batch_normalize(self, latents, flatten, normalize, shuffle):
        latents_copy = copy.deepcopy(latents)
        t = latents_copy["samples"] # [B x C x H x W]
        
        if flatten > 0:
            d = flatten * 2 + 1
            channels = t.shape[1]
            kernel = gaussian_kernel(d, 1, device=t.device).repeat(channels, 1, 1).unsqueeze(1)
            t_blurred = F.conv2d(t, kernel, padding='same', groups=channels)
            t = t - t_blurred
        
        if normalize:
            for b in range(t.shape[0]):
                for c in range(4):
                    t_sd, t_mean = torch.std_mean(t[b,c])
                    t[b,c] = (t[b,c] - t_mean) / t_sd
        
        if shuffle:
            t_shuffle = []
            for i in (1,2,3,0):
                t_shuffle.append(t[:,i])
            t = torch.stack(t_shuffle, dim=1)
        
        latents_copy["samples"] = t
        return (latents_copy,)


class RandnLikeLatent:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latents": ("LATENT", ),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
            },
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "generate"
    CATEGORY = "Image-Filters/latent"

    def generate(self, latents, seed):
        latents_copy = copy.deepcopy(latents)
        gen_cpu = torch.Generator(device="cpu").manual_seed(seed)
        latents_copy["samples"] = randn_like_g(latents_copy["samples"], generator=gen_cpu)
        return (latents_copy,)


class PrintSigmas:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {"sigmas": ("SIGMAS",)}
        }

    RETURN_TYPES = ("SIGMAS",)
    FUNCTION = "notify"
    OUTPUT_NODE = True
    CATEGORY = "Image-Filters/utils"
    
    def notify(self, sigmas):
        print(sigmas)
        return (sigmas,)


class ShowSigmas:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {"sigmas": ("SIGMAS",)},
            "hidden": {"unique_id": "UNIQUE_ID",},
        }

    RETURN_TYPES = ("SIGMAS",)
    FUNCTION = "show_sigmas"
    OUTPUT_NODE = True
    CATEGORY = "Image-Filters/utils"
    
    def show_sigmas(self, sigmas, unique_id=None):
        if unique_id:
            PromptServer.instance.send_progress_text(f"{sigmas}", unique_id)
        return (sigmas,)


class VisualizeLatents:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {"latent": ("LATENT", ),}
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "visualize"
    CATEGORY = "Image-Filters/utils"

    def visualize(self, latent):
        latents = latent["samples"]
        batch, channels, height, width = latents.size()
        
        latents = latents - latents.mean()
        latents = latents / latents.std()
        latents = latents / 10 + 0.5
        
        scale = int(channels ** 0.5)
        vis = torch.zeros(batch, height * scale, width * scale)
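        # tile channels into a scale x scale mosaic; note this assumes the channel
        # count is a perfect square (4, 16, ...), or the grid cannot hold every channel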
        
        for i in range(channels):
            start_h  = (i % scale) * height
            end_h    = start_h + height
            start_w  = (i // scale) * width
            end_w    = start_w + width
            
            vis[:, start_h:end_h, start_w:end_w] = latents[:, i]
        
        return (vis.unsqueeze(-1).repeat(1, 1, 1, 3),)


class GameOfLife:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "width": ("INT", { "default": 32, "min": 8, "max": 1024, "step": 1}),
                "height": ("INT", { "default": 32, "min": 8, "max": 1024, "step": 1}),
                "cell_size": ("INT", { "default": 16, "min": 8, "max": 1024, "step": 8}),
                "seed": ("INT", { "default": 0, "min": 0, "max": 0xffffffffffffffff, "step": 1}),
                "threshold": ("FLOAT", { "default": 0.8, "min": 0.0, "max": 1.0, "step": 0.01}),
                "steps": ("INT", { "default": 64, "min": 1, "max": 999999, "step": 1}),
            },
            "optional": {
                "optional_start": ("MASK", ),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK", "MASK", "MASK")
    RETURN_NAMES = ("image", "mask", "off", "on")
    FUNCTION = "game"
    CATEGORY = "Image-Filters/image"

    def game(self, width, height, cell_size, seed, threshold, steps, optional_start=None):
        if optional_start is None:
            # base random initialization
            torch.manual_seed(seed)
            grid = torch.rand(1, 1, height, width)
        else:
            grid = optional_start[0].unsqueeze(0).unsqueeze(0)
            grid = F.interpolate(grid, size=(height, width))
        
        grid = (grid > threshold).type(torch.uint8)
        empty = torch.zeros(1, 1, height, width, dtype=torch.uint8)
        
        # neighbor convolution kernel
        kernel = torch.ones(1, 1, 3, 3, dtype=torch.uint8)
        kernel[0, 0, 1, 1] = 0
        
        game_states = [[], [], []] # grid, turn_off, turn_on
        game_states[0].append(grid.detach().clone())
        game_states[1].append(empty.detach().clone())
        game_states[2].append(empty.detach().clone())
        for step in range(steps - 1):
            new_state = grid.detach().clone()
            neighbors = F.conv2d(F.pad(new_state, pad=(1, 1, 1, 1), mode="circular"), kernel)
            
            # If a cell is ON and has fewer than two neighbors that are ON, it turns OFF
            new_state[(new_state == 1) & (neighbors < 2)] = 0
            
            # If a cell is ON and has either two or three neighbors that are ON, it remains ON.
            
            # If a cell is ON and has more than three neighbors that are ON, it turns OFF.
            new_state[(new_state == 1) & (neighbors > 3)] = 0
            
            # If a cell is OFF and has exactly three neighbors that are ON, it turns ON.
            new_state[(new_state == 0) & (neighbors == 3)] = 1
            
            turn_off = ((grid - new_state) == 1).type(torch.uint8)
            turn_on = ((new_state - grid) == 1).type(torch.uint8)
            
            game_states[0].append(new_state.detach().clone())
            game_states[1].append(turn_off.detach().clone())
            game_states[2].append(turn_on.detach().clone())
            
            grid = new_state
        
        def postprocess(tensorlist, to_image=False):
            game_anim = torch.cat(tensorlist, dim=0).type(torch.float32)
            game_anim = F.interpolate(game_anim, size=(height * cell_size, width * cell_size))
            game_anim = torch.squeeze(game_anim, dim=1) # BCHW -> BHW
            if to_image:
                game_anim = game_anim.unsqueeze(-1).repeat(1,1,1,3) # BHWC
            return game_anim
        
        image = postprocess(game_states[0], to_image=True)
        mask = postprocess(game_states[0])
        off = postprocess(game_states[1])
        on = postprocess(game_states[2])
        
        return (image, mask, off, on)
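
# Example (illustrative sketch, calling the node class directly): seed a period-2
# "blinker" and run two steps
#   >>> start = torch.zeros(1, 8, 8); start[0, 4, 3:6] = 1.0
#   >>> image, mask, off, on = GameOfLife().game(8, 8, 8, 0, 0.5, 2, optional_start=start)
#   >>> image.shape  # 2 steps, 8x8 cells upscaled by cell_size=8, RGB
#   torch.Size([2, 64, 64, 3])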


modeltest_code_default = """d = model.model.model_config.unet_config
for k in d.keys():
    print(k, d[k])"""

class ModelTest:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "code": ("STRING", {"multiline": True, "default": modeltest_code_default}),
            },
        }

    RETURN_TYPES = ()
    FUNCTION = "test"
    OUTPUT_NODE = True
    CATEGORY = "Image-Filters/utils"
    
    def test(self, model, code):
        exec(code)
        return ()


class ConditioningSubtract:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond_orig": ("CONDITIONING", ),
                "cond_subtract": ("CONDITIONING", ),
                "subtract_strength": ("FLOAT", {"default": 1.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "addWeighted"
    CATEGORY = "Image-Filters/conditioning"

    def addWeighted(self, cond_orig, cond_subtract, subtract_strength):
        out = []

        if len(cond_subtract) > 1:
            logging.warning("Warning: ConditioningSubtract cond_subtract contains more than 1 cond, only the first one will actually be applied to cond_orig.")

        cond_from = cond_subtract[0][0]
        pooled_output_from = cond_subtract[0][1].get("pooled_output", None)

        for i in range(len(cond_orig)):
            t1 = cond_orig[i][0]
            pooled_output_to = cond_orig[i][1].get("pooled_output", pooled_output_from)
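            # truncate (or zero-pad) the subtracted cond's tokens to match cond_orig's length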
            t0 = cond_from[:,:t1.shape[1]]
            if t0.shape[1] < t1.shape[1]:
                t0 = torch.cat([t0] + [torch.zeros((1, (t1.shape[1] - t0.shape[1]), t1.shape[2]))], dim=1)

            tw = t1 - torch.mul(t0, subtract_strength)
            t_to = cond_orig[i][1].copy()
            if pooled_output_from is not None and pooled_output_to is not None:
                t_to["pooled_output"] = pooled_output_to - torch.mul(pooled_output_from, subtract_strength)
            elif pooled_output_from is not None:
                t_to["pooled_output"] = pooled_output_from

            n = [tw, t_to]
            out.append(n)
        return (out, )


class Noise_CustomNoise:
    def __init__(self, noise_latent):
        self.seed = 0
        self.noise_latent = noise_latent

    def generate_noise(self, input_latent):
        return self.noise_latent.detach().clone().cpu()


class CustomNoise:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required":{"noise": ("LATENT",),}
        }

    RETURN_TYPES = ("NOISE",)
    FUNCTION = "get_noise"
    CATEGORY = "Image-Filters/sampling"

    def get_noise(self, noise):
        noise_latent = noise["samples"].detach().clone()
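        # normalize to zero mean / unit std per channel over the spatial dims,
        # matching the statistics samplers expect from gaussian noise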
        std, mean = torch.std_mean(noise_latent, dim=(-2, -1), keepdim=True)
        noise_latent = (noise_latent - mean) / std
        return (Noise_CustomNoise(noise_latent),)


class ExtractNFrames:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "frames": ("INT", {"default": 16, "min": 2}),
            },
            "optional": {
                "images": ("IMAGE",),
                "masks": ("MASK",),
            },
        }

    RETURN_TYPES = ("LIST", "IMAGE", "MASK")
    RETURN_NAMES = ("index_list", "images", "masks")
    FUNCTION = "extract"
    CATEGORY = "Image-Filters/image/frames"
    
    def extract(self, frames, images=None, masks=None):
        original_length = 2
        if images is not None: original_length = max(original_length, len(images))
        if masks is not None: original_length = max(original_length, len(masks))
        
        n = min(original_length, frames)
        step = (original_length - 1) / (n - 1)
        ids = [round(i * step) for i in range(n)]
        while len(ids) < frames:
            ids.append(ids[-1])
        
        new_images = []
        new_masks = []
        for i in ids:
            if images is not None:
                new_images.append(images[min(i, len(images) - 1)].detach().clone())
            else:
                new_images.append(torch.zeros(512, 512, 3))
            
            if masks is not None:
                new_masks.append(masks[min(i, len(masks) - 1)].detach().clone())
            else:
                new_masks.append(torch.zeros(512, 512))
        
        return (ids, torch.stack(new_images, dim=0), torch.stack(new_masks, dim=0))
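
# Example (illustrative sketch, calling the node class directly): evenly spaced
# indices when reducing a 10-frame clip to 4 frames
#   >>> ids, imgs, msks = ExtractNFrames().extract(4, images=torch.zeros(10, 64, 64, 3))
#   >>> ids
#   [0, 3, 6, 9]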


class MergeFramesByIndex:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "index_list": ("LIST",),
                "orig_images": ("IMAGE",),
                "images": ("IMAGE",),
            },
            "optional": {
                "orig_masks": ("MASK",),
                "masks": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("images", "masks")
    FUNCTION = "merge"
    CATEGORY = "Image-Filters/image/frames"
    
    def merge(self, index_list, orig_images, images, orig_masks=None, masks=None):
        new_images = orig_images.detach().clone()
        new_masks = torch.ones_like(new_images[..., 0]) # BHW
        if orig_masks is not None:
            for i in range(len(new_masks)):
                new_masks[i] = orig_masks[min(i, len(orig_masks) - 1)].detach().clone()
        
        for i, frame in enumerate(index_list):
            frame_mask = masks[i] if masks is not None else torch.ones_like(new_masks[i])
            new_images[frame] *= (1 - frame_mask[..., None])
            new_images[frame] += images[i].detach().clone() * frame_mask[..., None]
            new_masks[frame] *= 0
        
        return (new_images, new_masks)


class Hunyuan3Dv2LatentUpscaleBy:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT",),
                "scale_by": ("FLOAT", {"default": 2.0, "min": 0.01, "max": 8.0, "step": 0.01}),
            },
        }
    
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"
    CATEGORY = "Image-Filters/latent"

    def upscale(self, samples, scale_by):
        s = samples.copy()
        size = round(samples["samples"].shape[-1] * scale_by)
        s["samples"] = F.interpolate(samples["samples"], size=(size,), mode="nearest-exact")
        return (s,)


class PackVideoMask:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "mask": ("MASK",),
                "blend_mode": (["max", "min", "average"], {"default": "max"}),
                "causal": ("BOOLEAN", {"default": True, "tooltip": "First latent frame is a single frame"}),
                "stride": ("INT", {"default": 4, "min": 1, "tooltip": "downsampling factor to match VAE, ie 4 for Wan, 8 for LTXV"}),
            },
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "pack_mask"
    CATEGORY = "Image-Filters/mask"

    def pack_mask(self, mask, blend_mode, causal, stride):
        packed_mask = mask.detach().clone()
        
        # repeat first frame to match stride
        if causal:
            dup_first_frame = packed_mask[0].unsqueeze(0).repeat(stride - 1, 1, 1)
            packed_mask = torch.cat([dup_first_frame, packed_mask], dim=0)
        
        # repeat last frame to match stride
        remainder = packed_mask.shape[0] % stride
        if remainder > 0:
            dup_last_frame = packed_mask[-1].unsqueeze(0).repeat(stride - remainder, 1, 1)
            packed_mask = torch.cat([packed_mask, dup_last_frame], dim=0)
        
        # shuffle every n frame chunk to channels
        B, H, W = packed_mask.shape
        packed_mask = packed_mask.reshape(B // stride, stride, H, W)
        
        # squash channels
        if blend_mode == "max":
            squashed_mask = packed_mask.max(dim=1).values
        elif blend_mode == "min":
            squashed_mask = packed_mask.min(dim=1).values
        else: # average
            squashed_mask = packed_mask.mean(dim=1)
        
        return (squashed_mask,)
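
# Example (illustrative sketch, calling the node class directly): 9 mask frames
# packed at a stride of 4 (the Wan convention), causal
#   >>> m = torch.zeros(9, 4, 4); m[4:] = 1.0
#   >>> packed, = PackVideoMask().pack_mask(m, "max", True, 4)
#   >>> packed.shape  # 3 duplicated + 9 original frames -> 12 / 4 = 3 latent frames
#   torch.Size([3, 4, 4])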


class PoissonNoise:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "gain": ("FLOAT", {"default": 1000, "min": 0.001, "max": 1_000_000, "step": 0.001}),
                "gain_r": ("FLOAT", {"default": 1.0, "min": 0, "max": 1_000_000, "step": 0.001}),
                "gain_g": ("FLOAT", {"default": 2.0, "min": 0, "max": 1_000_000, "step": 0.001}),
                "gain_b": ("FLOAT", {"default": 0.5, "min": 0, "max": 1_000_000, "step": 0.001}),
                "clamp": ("BOOLEAN", {"default": True}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "poissson_noise"
    CATEGORY = "Image-Filters/image"

    def poisson_noise(self, image, gain, gain_r, gain_g, gain_b, clamp, seed):
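        # shot-noise model: convert to linear light, apply per-channel gains, draw a
        # Poisson sample at rate (linear * gain), then undo the scaling; higher gain
        # means more photons and therefore less relative noise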
        linear = sRGBtoLinear_pt(image.cpu().clone())
        
        linear[..., 0] *= gain_r
        linear[..., 1] *= gain_g
        linear[..., 2] *= gain_b
        
        generator = torch.Generator("cpu").manual_seed(seed)
        noise = torch.poisson(linear * gain, generator) * (1 / gain)
        
        noise[..., 0] *= 1 / gain_r
        noise[..., 1] *= 1 / gain_g
        noise[..., 2] *= 1 / gain_b
        
        output = linearToSRGB_pt(noise)
        if clamp: output = torch.clamp(output, min=0, max=1)
        return(output,)


COMBINED_MAPPINGS = {
    "AdainFilterLatent":          (AdainFilterLatent,          "AdaIN Filter (Latent)"),
    "AdainImage":                 (AdainImage,                 "AdaIN (Image)"),
    "AdainLatent":                (AdainLatent,                "AdaIN (Latent)"),
    "AlphaClean":                 (AlphaClean,                 "Alpha Clean (DEPRECATED, use MaskClean)"),
    "AlphaMatte":                 (AlphaMatte,                 "Alpha Matte (DEPRECATED, use ImageMatting)"),
    "BatchAlign":                 (BatchAlign,                 "Batch Align (RAFT)"),
    "BatchAverageImage":          (BatchAverageImage,          "Batch Average Image"),
    "BatchAverageUnJittered":     (BatchAverageUnJittered,     "Batch Average Un-Jittered"),
    "BatchNormalizeImage":        (BatchNormalizeImage,        "Batch Normalize (Image)"),
    "BatchNormalizeLatent":       (BatchNormalizeLatent,       "Batch Normalize (Latent)"),
    "BetterFilmGrain":            (BetterFilmGrain,            "Better Film Grain"),
    "BilateralFilterImage":       (BilateralFilterImage,       "Bilateral Filter Image"),
    "BlurImageFast":              (BlurImageFast,              "Blur Image (Fast)"),
    "BlurMaskFast":               (BlurMaskFast,               "Blur Mask (Fast)"),
    "ClampImage":                 (ClampImage,                 "Clamp Image"),
    "ClampOutliers":              (ClampOutliers,              "Clamp Outliers"),
    "ColorMatchImage":            (ColorMatchImage,            "Color Match Image"),
    "ConditioningSubtract":       (ConditioningSubtract,       "ConditioningSubtract"),
    "ConvertNormals":             (ConvertNormals,             "Convert Normals"),
    "CustomNoise":                (CustomNoise,                "CustomNoise"),
    "DepthToNormals":             (DepthToNormals,             "Depth To Normals"),
    "DifferenceChecker":          (DifferenceChecker,          "Difference Checker"),
    "DilateErodeMask":            (DilateErodeMask,            "Dilate/Erode Mask"),
    "EnhanceDetail":              (EnhanceDetail,              "Enhance Detail"),
    "ExposureAdjust":             (ExposureAdjust,             "Exposure Adjust"),
    "ExtractNFrames":             (ExtractNFrames,             "Extract N Frames"),
    "FrequencyCombine":           (FrequencyCombine,           "Frequency Combine"),
    "FrequencySeparate":          (FrequencySeparate,          "Frequency Separate"),
    "GameOfLife":                 (GameOfLife,                 "Game Of Life"),
    "GuidedFilterImage":          (GuidedFilterImage,          "Guided Filter Image"),
    "Hunyuan3Dv2LatentUpscaleBy": (Hunyuan3Dv2LatentUpscaleBy, "Upscale Hunyuan3Dv2 Latent By"),
    "ImageConstant":              (ImageConstant,              "Image Constant Color (RGB)"),
    "ImageConstantHSV":           (ImageConstantHSV,           "Image Constant Color (HSV)"),
    "ImageMatting":               (ImageMatting,               "Image Matting"),
    "InpaintConditionApply":      (InpaintConditionApply,      "Inpaint Condition Apply"),
    "InpaintConditionEncode":     (InpaintConditionEncode,     "Inpaint Condition Encode"),
    "InstructPixToPixConditioningAdvanced": (InstructPixToPixConditioningAdvanced, "IP2P Conditioning Advanced"),
    "JitterImage":                (JitterImage,                "Jitter Image"),
    "Keyer":                      (Keyer,                      "Keyer"),
    "LatentNormalizeShuffle":     (LatentNormalizeShuffle,     "LatentNormalizeShuffle"),
    "RandnLikeLatent":            (RandnLikeLatent,            "RandnLikeLatent"),
    "LatentStats":                (LatentStats,                "Latent Stats"),
    "MaskClean":                  (MaskClean,                  "Mask (Alpha) Clean"),
    "MedianFilterImage":          (MedianFilterImage,          "Median Filter Image"),
    "MergeFramesByIndex":         (MergeFramesByIndex,         "Merge Frames By Index"),
    "ModelTest":                  (ModelTest,                  "Model Test"),
    "NormalMapSimple":            (NormalMapSimple,            "Normal Map (Simple)"),
    "OffsetLatentImage":          (OffsetLatentImage,          "Offset Latent Image"),
    "PackVideoMask":              (PackVideoMask,              "Pack Video Mask"),
    "PoissonNoise":               (PoissonNoise,               "Poisson Noise Image"),
    "PrintSigmas":                (PrintSigmas,                "Print Sigmas"),
    "RelightSimple":              (RelightSimple,              "Relight (Simple)"),
    "RemapRange":                 (RemapRange,                 "Remap Range"),
    "RestoreDetail":              (RestoreDetail,              "Restore Detail"),
    "SharpenFilterLatent":        (SharpenFilterLatent,        "Sharpen Filter (Latent)"),
    "ShowSigmas":                 (ShowSigmas,                 "Show Sigmas"),
    "ShuffleChannels":            (ShuffleChannels,            "Shuffle Channels"),
    "Tonemap":                    (Tonemap,                    "Tonemap"),
    "UnJitterImage":              (UnJitterImage,              "Un-Jitter Image"),
    "UnTonemap":                  (UnTonemap,                  "UnTonemap"),
    "VisualizeLatents":           (VisualizeLatents,           "Visualize Latents"),
}

================================================
FILE: raft.py
================================================
import os
import torch
import torch.nn.functional as F
from torchvision.models.optical_flow import Raft_Large_Weights, raft_large


def load_raft():
    model_dir = os.path.join(os.path.split(__file__)[0], "models")
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    
    raft_weights = Raft_Large_Weights.DEFAULT
    raft_path = os.path.join(model_dir, str(raft_weights) + ".pth")
    
    if os.path.exists(raft_path):
        model = raft_large()
        model.load_state_dict(torch.load(raft_path))
    else:
        model = raft_large(weights=raft_weights, progress=True)
        torch.save(model.state_dict(), raft_path)
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()
    return (model, device)

def raft_flow(model, device, batch1, batch2):
    orig_H = batch1.shape[2]
    orig_W = batch1.shape[3]
    scale_factor = max(orig_H, orig_W) / 512
    new_H = int(((orig_H / scale_factor) // 8) * 8)
    new_W = int(((orig_W / scale_factor) // 8) * 8)
    
    if scale_factor > 1 or max(orig_H % 8, orig_W % 8) > 0:
        batch1_scaled = F.interpolate(batch1, size=(new_H, new_W), mode='bilinear')
        batch2_scaled = F.interpolate(batch2, size=(new_H, new_W), mode='bilinear')
        
        with torch.no_grad():
            flow = model(batch1_scaled.to(device), batch2_scaled.to(device))[-1]
        flow = F.interpolate(flow, size=(orig_H, orig_W), mode='bilinear')
        flow[:,0,:,:] *= orig_W / new_W
        flow[:,1,:,:] *= orig_H / new_H
    else:
        with torch.no_grad():
            flow = model(batch1.to(device), batch2.to(device))[-1]
    
    return flow.to(batch1.device)

def flow_warp(image, flow):
    B, C, H, W = image.size()
    # mesh grid
    xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
    yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
    xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
    yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
    grid = torch.cat((xx, yy), 1).float()
    
    grid = grid.to(image.device)
    vgrid = grid + flow
    
    # scale grid to [-1,1] for grid_sample
    vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
    vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
    vgrid = vgrid.permute(0, 2, 3, 1)
    output = F.grid_sample(image, vgrid, mode='bicubic', padding_mode='border', align_corners=True)
    return output
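
# Example (illustrative sketch): align frame b to frame a with RAFT + warping
#   >>> model, device = load_raft()
#   >>> a, b = torch.rand(1, 3, 256, 256), torch.rand(1, 3, 256, 256)
#   >>> flow = raft_flow(model, device, a, b)  # [1, 2, H, W] flow field
#   >>> aligned = flow_warp(b, flow)           # b resampled toward a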

================================================
FILE: requirements.txt
================================================
opencv-contrib-python>=4.7.0.72
opencv-contrib-python-headless>=4.7.0.72
opencv-python>=4.7.0.72
opencv-python-headless>=4.7.0.72
pymatting
SYMBOL INDEX (206 symbols across 2 files)

FILE: nodes.py
  function cv_blur_tensor (line 25) | def cv_blur_tensor(images, dx, dy):
  function guided_filter_tensor (line 38) | def guided_filter_tensor(ref, images, d, s):
  function std_mean_filter (line 53) | def std_mean_filter(t, d):
  function RGB2YCbCr (line 59) | def RGB2YCbCr(t):
  function YCbCr2RGB (line 66) | def YCbCr2RGB(t):
  function hsv_to_rgb (line 73) | def hsv_to_rgb(h, s, v):
  function sRGBtoLinear (line 91) | def sRGBtoLinear(npArray):
  function linearToSRGB (line 96) | def linearToSRGB(npArray):
  function sRGBtoLinear_pt (line 101) | def sRGBtoLinear_pt(t: torch.Tensor):
  function linearToSRGB_pt (line 107) | def linearToSRGB_pt(t: torch.Tensor):
  function linearToTonemap (line 113) | def linearToTonemap(npArray, tonemap_scale):
  function tonemapToLinear (line 120) | def tonemapToLinear(npArray, tonemap_scale):
  function exposure (line 128) | def exposure(npArray, stops):
  function randn_like_g (line 132) | def randn_like_g(x, generator=None):
  class AlphaClean (line 138) | class AlphaClean:
    method INPUT_TYPES (line 140) | def INPUT_TYPES(s):
    method alpha_clean (line 156) | def alpha_clean(self, images: torch.Tensor, radius: int, fill_holes: i...
  class MaskClean (line 190) | class MaskClean:
    method INPUT_TYPES (line 192) | def INPUT_TYPES(s):
    method alpha_clean (line 207) | def alpha_clean(self, mask, radius, fill_holes, white_threshold, extra...
  class AlphaMatte (line 240) | class AlphaMatte:
    method INPUT_TYPES (line 242) | def INPUT_TYPES(s):
    method alpha_matte (line 261) | def alpha_matte(self, images, alpha_trimap, preblur, blackpoint, white...
  class ImageMatting (line 290) | class ImageMatting:
    method INPUT_TYPES (line 292) | def INPUT_TYPES(s):
    method alpha_matte (line 310) | def alpha_matte(self, images, trimap, preblur, blackpoint, whitepoint,...
  class BetterFilmGrain (line 339) | class BetterFilmGrain:
    method INPUT_TYPES (line 341) | def INPUT_TYPES(s):
    method grain (line 357) | def grain(self, image, scale, strength, saturation, toe, seed):
  class BlurImageFast (line 378) | class BlurImageFast:
    method INPUT_TYPES (line 380) | def INPUT_TYPES(s):
    method blur_image (line 393) | def blur_image(self, images, radius_x, radius_y):
  class BlurMaskFast (line 408) | class BlurMaskFast:
    method INPUT_TYPES (line 410) | def INPUT_TYPES(s):
    method blur_mask (line 423) | def blur_mask(self, masks, radius_x, radius_y):
  class ColorMatchImage (line 438) | class ColorMatchImage:
    method INPUT_TYPES (line 440) | def INPUT_TYPES(s):
    method batch_normalize (line 455) | def batch_normalize(self, images, reference, blur_type, blur_size, fac...
  class RestoreDetail (line 490) | class RestoreDetail:
    method INPUT_TYPES (line 492) | def INPUT_TYPES(s):
    method batch_normalize (line 508) | def batch_normalize(self, images, detail, mode, blur_type, blur_size, ...
  class DilateErodeMask (line 534) | class DilateErodeMask:
    method INPUT_TYPES (line 536) | def INPUT_TYPES(s):
    method dilate_mask (line 549) | def dilate_mask(self, masks, radius, shape):
  class EnhanceDetail (line 572) | class EnhanceDetail:
    method INPUT_TYPES (line 574) | def INPUT_TYPES(s):
    method enhance (line 589) | def enhance(self, images: torch.Tensor, filter_radius: int, sigma: flo...
  class GuidedFilterImage (line 612) | class GuidedFilterImage:
    method INPUT_TYPES (line 614) | def INPUT_TYPES(s):
    method filter_image (line 628) | def filter_image(self, images, guide, size, sigma):
  class MedianFilterImage (line 635) | class MedianFilterImage:
    method INPUT_TYPES (line 637) | def INPUT_TYPES(s):
    method filter_image (line 649) | def filter_image(self, images, size):
  class BilateralFilterImage (line 662) | class BilateralFilterImage:
    method INPUT_TYPES (line 664) | def INPUT_TYPES(s):
    method filter_image (line 678) | def filter_image(self, images, size, sigma_color, sigma_space):
  class FrequencyCombine (line 686) | class FrequencyCombine:
    method INPUT_TYPES (line 688) | def INPUT_TYPES(s):
    method filter_image (line 702) | def filter_image(self, high_frequency, low_frequency, mode, eps):
  class FrequencySeparate (line 711) | class FrequencySeparate:
    method INPUT_TYPES (line 713) | def INPUT_TYPES(s):
    method filter_image (line 728) | def filter_image(self, original, low_frequency, mode, eps):
  class RemapRange (line 737) | class RemapRange:
    method INPUT_TYPES (line 739) | def INPUT_TYPES(s):
    method remap (line 752) | def remap(self, image: torch.Tensor, blackpoint: float, whitepoint: fl...
  class ClampImage (line 762) | class ClampImage:
    method INPUT_TYPES (line 764) | def INPUT_TYPES(s):
    method clamp_image (line 777) | def clamp_image(self, image: torch.Tensor, blackpoint: float, whitepoi...
  class ShuffleChannels (line 785) | class ShuffleChannels:
    method INPUT_TYPES (line 787) | def INPUT_TYPES(s):
    method shuffle (line 802) | def shuffle(self, image, red, green, blue, alpha):
  class ClampOutliers (line 826) | class ClampOutliers:
    method INPUT_TYPES (line 828) | def INPUT_TYPES(s):
    method clamp_outliers (line 840) | def clamp_outliers(self, latents, std_dev):
  class AdainLatent (line 853) | class AdainLatent:
    method INPUT_TYPES (line 855) | def INPUT_TYPES(s):
    method batch_normalize (line 868) | def batch_normalize(self, latents, reference, factor):
  class AdainFilterLatent (line 881) | class AdainFilterLatent:
    method INPUT_TYPES (line 883) | def INPUT_TYPES(s):
    method batch_normalize (line 897) | def batch_normalize(self, latents, reference, filter_size, factor):
  class SharpenFilterLatent (line 928) | class SharpenFilterLatent:
    method INPUT_TYPES (line 930) | def INPUT_TYPES(s):
    method filter_latent (line 943) | def filter_latent(self, latents, filter_size, factor):
  class AdainImage (line 965) | class AdainImage:
    method INPUT_TYPES (line 967) | def INPUT_TYPES(s):
    method batch_normalize (line 980) | def batch_normalize(self, images, reference, factor):
  class BatchNormalizeLatent (line 995) | class BatchNormalizeLatent:
    method INPUT_TYPES (line 997) | def INPUT_TYPES(s):
    method batch_normalize (line 1009) | def batch_normalize(self, latents, factor):
  class BatchNormalizeImage (line 1027) | class BatchNormalizeImage:
    method INPUT_TYPES (line 1029) | def INPUT_TYPES(s):
    method batch_normalize (line 1041) | def batch_normalize(self, images, factor):
  class DifferenceChecker (line 1059) | class DifferenceChecker:
    method INPUT_TYPES (line 1061) | def INPUT_TYPES(s):
    method difference_checker (line 1076) | def difference_checker(self, images1, images2, multiplier, print_MAE):
  class ImageConstant (line 1084) | class ImageConstant:
    method __init__ (line 1085) | def __init__(self, device="cpu"):
    method INPUT_TYPES (line 1089) | def INPUT_TYPES(s):
    method generate (line 1105) | def generate(self, width, height, batch_size, red, green, blue):
  class ImageConstantHSV (line 1112) | class ImageConstantHSV:
    method __init__ (line 1113) | def __init__(self, device="cpu"):
    method INPUT_TYPES (line 1117) | def INPUT_TYPES(s):
    method generate (line 1133) | def generate(self, width, height, batch_size, hue, saturation, value):
  class OffsetLatentImage (line 1142) | class OffsetLatentImage:
    method __init__ (line 1143) | def __init__(self):
    method INPUT_TYPES (line 1147) | def INPUT_TYPES(s):
    method generate (line 1164) | def generate(self, width, height, batch_size, offset_0, offset_1, offs...
  class RelightSimple (line 1173) | class RelightSimple:
    method INPUT_TYPES (line 1175) | def INPUT_TYPES(s):
    method relight (line 1191) | def relight(self, image, normals, x, y, z, brightness):
  class LatentStats (line 1207) | class LatentStats:
    method INPUT_TYPES (line 1209) | def INPUT_TYPES(s):
    method notify (line 1218) | def notify(self, latent):
  class Tonemap (line 1254) | class Tonemap:
    method INPUT_TYPES (line 1256) | def INPUT_TYPES(s):
    method apply (line 1270) | def apply(self, images, input_mode, output_mode, tonemap_scale):
  class UnTonemap (line 1286) | class UnTonemap:
    method INPUT_TYPES (line 1288) | def INPUT_TYPES(s):
    method apply (line 1302) | def apply(self, images, input_mode, output_mode, tonemap_scale):
  class ExposureAdjust (line 1318) | class ExposureAdjust:
    method INPUT_TYPES (line 1320) | def INPUT_TYPES(s):
    method adjust_exposure (line 1336) | def adjust_exposure(self, images, stops, input_mode, output_mode, tone...
  class ConvertNormals (line 1364) | class ConvertNormals:
    method INPUT_TYPES (line 1366) | def INPUT_TYPES(s):
    method convert_normals (line 1385) | def convert_normals(self, normals, input_mode, output_mode, scale_XY, ...
  class BatchAverageImage (line 1420) | class BatchAverageImage:
    method INPUT_TYPES (line 1422) | def INPUT_TYPES(s):
    method apply (line 1434) | def apply(self, images, operation):
  class NormalMapSimple (line 1443) | class NormalMapSimple:
    method INPUT_TYPES (line 1445) | def INPUT_TYPES(s):
    method normal_map (line 1457) | def normal_map(self, images, scale_XY):
  class DepthToNormals (line 1470) | class DepthToNormals:
    method INPUT_TYPES (line 1472) | def INPUT_TYPES(s):
    method normal_map (line 1486) | def normal_map(self, depth, scale, output_mode):
  class Keyer (line 1520) | class Keyer:
    method INPUT_TYPES (line 1522) | def INPUT_TYPES(s):
    method keyer (line 1539) | def keyer(self, images, operation, low, high, gamma, premult):
  class JitterImage (line 1586) | class JitterImage:
    method INPUT_TYPES (line 1588) | def INPUT_TYPES(s):
    method jitter (line 1600) | def jitter(self, images, jitter_scale):
  class UnJitterImage (line 1618) | class UnJitterImage:
    method INPUT_TYPES (line 1620) | def INPUT_TYPES(s):
    method jitter (line 1633) | def jitter(self, images, jitter_scale, oflow_align):
  class BatchAverageUnJittered (line 1664) | class BatchAverageUnJittered:
    method INPUT_TYPES (line 1666) | def INPUT_TYPES(s):
    method apply (line 1678) | def apply(self, images, operation):
  class BatchAlign (line 1691) | class BatchAlign:
    method INPUT_TYPES (line 1693) | def INPUT_TYPES(s):
    method apply (line 1708) | def apply(self, images, ref_frame, direction, blur):
  class InstructPixToPixConditioningAdvanced (line 1735) | class InstructPixToPixConditioningAdvanced:
    method INPUT_TYPES (line 1737) | def INPUT_TYPES(s):
    method encode (line 1754) | def encode(self, positive, negative, new, new_scale, original, origina...
  class InpaintConditionEncode (line 1773) | class InpaintConditionEncode:
    method INPUT_TYPES (line 1775) | def INPUT_TYPES(s):
    method encode (line 1788) | def encode(self, vae, pixels, mask):
  class InpaintConditionApply (line 1810) | class InpaintConditionApply:
    method INPUT_TYPES (line 1812) | def INPUT_TYPES(s):
    method encode (line 1829) | def encode(self, positive, negative, inpaint_condition, noise_mask=Tru...
  class LatentNormalizeShuffle (line 1850) | class LatentNormalizeShuffle:
    method INPUT_TYPES (line 1852) | def INPUT_TYPES(s):
    method batch_normalize (line 1866) | def batch_normalize(self, latents, flatten, normalize, shuffle):
  class RandnLikeLatent (line 1893) | class RandnLikeLatent:
    method INPUT_TYPES (line 1895) | def INPUT_TYPES(s):
    method generate (line 1907) | def generate(self, latents, seed):
  class PrintSigmas (line 1914) | class PrintSigmas:
    method INPUT_TYPES (line 1916) | def INPUT_TYPES(s):
    method notify (line 1926) | def notify(self, sigmas):
  class ShowSigmas (line 1931) | class ShowSigmas:
    method INPUT_TYPES (line 1933) | def INPUT_TYPES(s):
    method show_sigmas (line 1944) | def show_sigmas(self, sigmas, unique_id=None):
  class VisualizeLatents (line 1950) | class VisualizeLatents:
    method INPUT_TYPES (line 1952) | def INPUT_TYPES(s):
    method visualize (line 1961) | def visualize(self, latent):
  class GameOfLife (line 1983) | class GameOfLife:
    method INPUT_TYPES (line 1985) | def INPUT_TYPES(s):
    method game (line 2005) | def game(self, width, height, cell_size, seed, threshold, steps, optio...
  class ModelTest (line 2069) | class ModelTest:
    method INPUT_TYPES (line 2071) | def INPUT_TYPES(s):
    method test (line 2084) | def test(self, model, code):
  class ConditioningSubtract (line 2089) | class ConditioningSubtract:
    method INPUT_TYPES (line 2091) | def INPUT_TYPES(s):
    method addWeighted (line 2104) | def addWeighted(self, cond_orig, cond_subtract, subtract_strength):
  class Noise_CustomNoise (line 2132) | class Noise_CustomNoise:
    method __init__ (line 2133) | def __init__(self, noise_latent):
    method generate_noise (line 2137) | def generate_noise(self, input_latent):
  class CustomNoise (line 2141) | class CustomNoise:
    method INPUT_TYPES (line 2143) | def INPUT_TYPES(s):
    method get_noise (line 2152) | def get_noise(self, noise):
  class ExtractNFrames (line 2159) | class ExtractNFrames:
    method INPUT_TYPES (line 2161) | def INPUT_TYPES(s):
    method extract (line 2177) | def extract(self, frames, images=None, masks=None):
  class MergeFramesByIndex (line 2204) | class MergeFramesByIndex:
    method INPUT_TYPES (line 2206) | def INPUT_TYPES(s):
    method merge (line 2224) | def merge(self, index_list, orig_images, images, orig_masks=None, mask...
  class Hunyuan3Dv2LatentUpscaleBy (line 2240) | class Hunyuan3Dv2LatentUpscaleBy:
    method INPUT_TYPES (line 2242) | def INPUT_TYPES(s):
    method upscale (line 2254) | def upscale(self, samples, scale_by):
  class PackVideoMask (line 2261) | class PackVideoMask:
    method INPUT_TYPES (line 2263) | def INPUT_TYPES(s):
    method pack_mask (line 2277) | def pack_mask(self, mask, blend_mode, causal, stride):
  class PoissonNoise (line 2306) | class PoissonNoise:
    method INPUT_TYPES (line 2308) | def INPUT_TYPES(s):
    method poissson_noise (line 2325) | def poissson_noise(self, image, gain, gain_r, gain_g, gain_b, clamp, s...
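
Nearly every class indexed above follows the same ComfyUI custom-node pattern: a classmethod INPUT_TYPES(s) declaring the node's input sockets and widget defaults, plus RETURN_TYPES, FUNCTION, and CATEGORY attributes that ComfyUI reads at registration time. As a minimal sketch of that pattern (an illustration only, not the repo's actual implementation; the class name, defaults, and category string here are hypothetical), an AdainLatent-style node that matches a latent batch's per-channel statistics to a reference might look like:

import torch

class AdainLatentSketch:
    # Hypothetical example of the node pattern used throughout nodes.py
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latents": ("LATENT",),
                "reference": ("LATENT",),
                "factor": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "batch_normalize"
    CATEGORY = "latent/filters"

    def batch_normalize(self, latents, reference, factor):
        latents = latents.copy()
        t = latents["samples"]        # [B, C, H, W]; assumes reference matches this shape
        ref = reference["samples"]

        # AdaIN: shift/scale each channel of t to match the reference statistics
        t_mean = t.mean(dim=(2, 3), keepdim=True)
        t_std = t.std(dim=(2, 3), keepdim=True)
        ref_mean = ref.mean(dim=(2, 3), keepdim=True)
        ref_std = ref.std(dim=(2, 3), keepdim=True)

        normalized = (t - t_mean) / (t_std + 1e-8) * ref_std + ref_mean
        latents["samples"] = torch.lerp(t, normalized, factor)  # blend by factor
        return (latents,)

ComfyUI discovers such classes through the NODE_CLASS_MAPPINGS dict that the package exports from __init__.py (here built from COMBINED_MAPPINGS, as the __init__.py preview shows).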

FILE: raft.py
  function load_raft (line 7) | def load_raft():
  function raft_flow (line 26) | def raft_flow(model, device, batch1, batch2):
  function flow_warp (line 48) | def flow_warp(image, flow):
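
The three raft.py helpers compose into a simple optical-flow alignment pipeline: load the torchvision RAFT model once, estimate the flow between two frame batches, then warp one batch with that flow. A hedged usage sketch based only on the signatures above (the tensor layout, value range, and which batch gets warped are assumptions, not confirmed by the index):

import torch
from raft import load_raft, raft_flow, flow_warp

# raft.py imports torchvision's Raft_Large_Weights, so load_raft() presumably
# returns that model; exact device handling inside these helpers is assumed.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_raft()

batch1 = torch.rand(4, 512, 512, 3)  # reference frames (assumed BHWC in [0, 1], ComfyUI IMAGE layout)
batch2 = torch.rand(4, 512, 512, 3)  # frames to align against the reference

flow = raft_flow(model, device, batch1, batch2)  # per-pixel displacement field
aligned = flow_warp(batch2, flow)                # resample batch2 toward batch1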

About this extraction

This page contains the full source code of the spacepxl/ComfyUI-Image-Filters GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 10 files (104.9 KB, approximately 29.6k tokens) and includes a symbol index of 206 extracted functions, classes, methods, constants, and types. Use it with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
