Repository: robertvoy/ComfyUI-Distributed
Branch: main
Commit: a91f9fb081eb
Files: 130
Total size: 895.4 KB
Directory structure:
gitextract_vxcmpaik/
├── .github/
│ ├── FUNDING.yml
│ └── workflows/
│ └── publish_action.yml
├── .gitignore
├── .nvmrc
├── LICENSE
├── README.md
├── __init__.py
├── api/
│ ├── __init__.py
│ ├── config_routes.py
│ ├── job_routes.py
│ ├── orchestration/
│ │ ├── __init__.py
│ │ ├── dispatch.py
│ │ ├── media_sync.py
│ │ └── prompt_transform.py
│ ├── queue_orchestration.py
│ ├── queue_request.py
│ ├── schemas.py
│ ├── tunnel_routes.py
│ ├── usdu_routes.py
│ └── worker_routes.py
├── conftest.py
├── distributed.py
├── docs/
│ ├── comfyui-distributed-api.md
│ ├── model-download-script.md
│ ├── video-upscaler-runpod-preset.md
│ └── worker-setup-guides.md
├── nodes/
│ ├── __init__.py
│ ├── collector.py
│ ├── distributed_upscale.py
│ └── utilities.py
├── package.json
├── pyproject.toml
├── scripts/
│ └── test-web.sh
├── tests/
│ ├── api/
│ │ ├── test_config_routes.py
│ │ ├── test_distributed_queue.py
│ │ ├── test_media_sync.py
│ │ ├── test_usdu_routes.py
│ │ └── test_worker_routes.py
│ ├── conftest.py
│ ├── test_async_helpers.py
│ ├── test_batch_dividers.py
│ ├── test_config.py
│ ├── test_detection.py
│ ├── test_dispatch_selection.py
│ ├── test_distributed_value.py
│ ├── test_job_timeout.py
│ ├── test_network_helpers.py
│ ├── test_payload_parsers.py
│ ├── test_prompt_transform.py
│ ├── test_queue_request.py
│ ├── test_static_mode.py
│ └── test_worker_process_runtime.py
├── upscale/
│ ├── __init__.py
│ ├── conditioning.py
│ ├── job_models.py
│ ├── job_state.py
│ ├── job_store.py
│ ├── job_timeout.py
│ ├── modes/
│ │ ├── __init__.py
│ │ ├── dynamic.py
│ │ ├── single_gpu.py
│ │ └── static.py
│ ├── payload_parsers.py
│ ├── result_collector.py
│ ├── tile_ops.py
│ └── worker_comms.py
├── utils/
│ ├── __init__.py
│ ├── async_helpers.py
│ ├── audio_payload.py
│ ├── cloudflare/
│ │ ├── __init__.py
│ │ ├── binary.py
│ │ ├── process_reader.py
│ │ ├── state.py
│ │ └── tunnel.py
│ ├── config.py
│ ├── constants.py
│ ├── crop_model_patch.py
│ ├── exceptions.py
│ ├── image.py
│ ├── logging.py
│ ├── network.py
│ ├── process.py
│ ├── trace_logger.py
│ ├── usdu_managment.py
│ └── usdu_utils.py
├── vitest.config.js
├── web/
│ ├── apiClient.js
│ ├── constants.js
│ ├── distributed.css
│ ├── distributedValue.js
│ ├── executionUtils.js
│ ├── image_batch_divider.js
│ ├── main.js
│ ├── masterDetection.js
│ ├── sidebar/
│ │ ├── actionsSection.js
│ │ ├── settingsSection.js
│ │ └── workersSection.js
│ ├── sidebarRenderer.js
│ ├── stateManager.js
│ ├── tests/
│ │ ├── apiClient.test.js
│ │ ├── executionUtils.test.js
│ │ ├── urlUtils.test.js
│ │ ├── workerLifecycle.test.js
│ │ └── workerSettings.test.js
│ ├── tunnelManager.js
│ ├── ui/
│ │ ├── buttonHelpers.js
│ │ ├── cloudflareWarning.js
│ │ ├── entityCard.js
│ │ ├── logModal.js
│ │ └── settingsForm.js
│ ├── ui.js
│ ├── urlUtils.js
│ ├── workerLifecycle.js
│ ├── workerSettings.js
│ └── workerUtils.js
├── workers/
│ ├── __init__.py
│ ├── detection.py
│ ├── process/
│ │ ├── __init__.py
│ │ ├── launch_builder.py
│ │ ├── lifecycle.py
│ │ ├── persistence.py
│ │ └── root_discovery.py
│ ├── process_manager.py
│ ├── startup.py
│ └── worker_monitor.py
└── workflows/
├── distributed-txt2img.json
├── distributed-upscale-video.json
├── distributed-upscale.json
├── distributed-wan-2.2_14b_t2v.json
└── distributed-wan.json
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/FUNDING.yml
================================================
# These are supported funding model platforms
github: robertvoy
================================================
FILE: .github/workflows/publish_action.yml
================================================
name: Publish to Comfy registry
on:
  workflow_dispatch:
  push:
    branches:
      - main
    paths:
      - "pyproject.toml"
permissions:
  issues: write
jobs:
  publish-node:
    name: Publish Custom Node to registry
    runs-on: ubuntu-latest
    if: ${{ github.repository_owner == 'robertvoy' }}
    steps:
      - name: Check out code
        uses: actions/checkout@v4
      - name: Publish Custom Node
        uses: Comfy-Org/publish-node-action@main
        with:
          personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
================================================
FILE: .gitignore
================================================
bin/
logs/
gpu_config.json
__pycache__/
**/__pycache__/
*.py[cod]
node_modules/
npm-debug.log*
AGENTS.md
================================================
FILE: .nvmrc
================================================
20
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
> **A powerful extension for ComfyUI that enables distributed and parallel processing across multiple GPUs and machines. Generate more images and videos, and accelerate your upscaling workflows, by leveraging all available GPU resources in your network and in the cloud.**

---
## Key Features
#### Parallel Workflow Processing
- Run your workflow on multiple GPUs simultaneously with varied seeds, collect results on the master
- Scale output with more workers
- Supports images and videos
#### Distributed Upscaling
- Accelerate Ultimate SD Upscale by distributing tiles across GPUs
- Intelligent distribution
- Handles single images and videos
#### Ease of Use
- Auto-setup local workers; easily add remote/cloud ones
- Convert any workflow to distributed with 2 nodes
- JSON configuration with UI controls
---
## Worker Types
ComfyUI Distributed supports three types of workers:
- **Local Workers** - Additional GPUs on the same machine (auto-configured on first launch)
- **Remote Workers** - GPUs on other computers in your network
- **Cloud Workers** - GPUs hosted on a cloud service like Runpod, accessible via secure tunnels
> For detailed setup instructions, see the [setup guide](/docs/worker-setup-guides.md)
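
Workers can also be registered programmatically through the extension's config API. The sketch below is illustrative only: the host, port, ids, and field values are placeholders, and the `"remote"` type string is an assumption based on the worker types above. For real setups, follow the setup guide.

```python
# Hypothetical example: add a remote worker entry via the config route.
# Assumes the master is reachable at 127.0.0.1:8188; all values are placeholders.
import requests

resp = requests.post(
    "http://127.0.0.1:8188/distributed/config/update_worker",
    json={
        "worker_id": "remote-1",   # unknown id + required fields -> entry is created
        "name": "Remote 4090",
        "host": "192.168.1.50",    # omit for local workers (defaults to localhost)
        "port": 8189,
        "cuda_device": 0,
        "type": "remote",
        "enabled": True,
    },
    timeout=10,
)
resp.raise_for_status()
print(resp.json())  # {"status": "success"}
```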
---
## Requirements
- ComfyUI
- Multiple NVIDIA GPUs
> No additional GPUs? Use [Cloud Workers](https://github.com/robertvoy/ComfyUI-Distributed/blob/main/docs/worker-setup-guides.md#cloud-workers)
- That's it
---
## Installation
1. **Clone this repository** into your ComfyUI custom nodes directory:
```bash
git clone https://github.com/robertvoy/ComfyUI-Distributed.git
```
2. **Restart ComfyUI**
- If you'll be using remote/cloud workers, add `--enable-cors-header` to your launch arguments on the master
3. Read the [setup guide](/docs/worker-setup-guides.md) for adding workers
---
## Official Sponsor
Join Runpod with [this link](https://get.runpod.io/0bw29uf3ug0p) and unlock a special bonus.
---
## Workflow Examples
### Basic Parallel Generation
Generate multiple images in the time it takes to generate one. Each worker uses a different seed.

> [Download workflow](/workflows/distributed-txt2img.json)
1. Open your ComfyUI workflow
2. Add **Distributed Seed** → connect to sampler's seed
3. Add **Distributed Collector** → after VAE Decode
4. Optional: enable `load_balance` on Distributed Collector to route the run to the single least-busy participant
5. Enable workers in the UI
6. Run the workflow!
### Parallel WAN Generation
Generate multiple videos in the time it takes to generate one. Each worker uses a different seed.

> [Download workflow](/workflows/distributed-wan.json)
1. Open your WAN ComfyUI workflow
2. Add **Distributed Seed** → connect to sampler's seed
3. Add **Distributed Collector** → after VAE Decode
4. Add **Image Batch Divider** → after Distributed Collector
5. Set the `divide_by` to the number of GPUs you have available
> For example: if you have a master and 2x workers, set it to 3
6. Enable workers in the UI
7. Run the workflow!
### Distributed Image Upscaling
Accelerate Ultimate SD Upscale by distributing tiles across multiple workers, with speed scaling as you add more GPUs.

> [Download workflow](/workflows/distributed-upscale.json)
1. Load your image
2. Upscale with ESRGAN or similar
3. Connect to **Ultimate SD Upscale Distributed**
4. Configure tile settings
5. Enable workers for faster processing
### Distributed Video Upscaling
Accelerate Ultimate SD Upscale by distributing video tiles across multiple workers, with speed scaling as you add more GPUs.

> [Download workflow](/workflows/distributed-upscale-video.json)
1. Load your video
2. Optional: upscale with ESRGAN or similar
3. Connect to **Ultimate SD Upscale Distributed**
4. Configure tile settings
5. Use RES4LYF (bong/res2) to get better results
6. Enable workers for faster processing
> You can run this workflow entirely on Runpod with minimal setup. [Check out the guide here.](https://github.com/robertvoy/ComfyUI-Distributed/blob/main/docs/video-upscaler-runpod-preset.md)
---
## Developer API
Control your distributed cluster programmatically without opening the browser.
* **Endpoint:** `POST /distributed/queue`
* **Functionality:** Accepts a standard ComfyUI workflow JSON, automatically distributes it to available workers, and returns the prompt ID along with the dispatched worker count.
* **Documentation:** [See API Examples & Scripts](https://github.com/robertvoy/ComfyUI-Distributed/blob/main/docs/comfyui-distributed-api.md)
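
A minimal sketch of queueing a workflow from Python, assuming the master runs at `127.0.0.1:8188`, that `workflow.json` holds a workflow in ComfyUI's API (prompt) format, and that the prompt is sent under a `prompt` key; see the linked API documentation for the authoritative payload schema.

```python
# Minimal sketch: queue a workflow on the distributed cluster.
import json
import requests

with open("workflow.json") as f:
    prompt = json.load(f)  # workflow in ComfyUI API format

resp = requests.post(
    "http://127.0.0.1:8188/distributed/queue",
    json={"prompt": prompt},  # payload key assumed; see the API docs
    timeout=30,
)
resp.raise_for_status()
result = resp.json()
# The endpoint reports the prompt id and how many workers were dispatched.
print(result["prompt_id"], result["worker_count"])
```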
> **⚠️ Security Warning:** Do not expose your ComfyUI port to the public internet. If you need remote access, run ComfyUI behind a secure proxy (like Cloudflare or a VPN).
---
## Distributed Value
Use **Distributed Value** when you want per-worker overrides (for example, different prompts/models/settings per worker).
- Output type adapts to the connected input where possible (`STRING`, `INT`, `FLOAT`, `COMBO`).
- The node shows only currently enabled workers.
- If worker enablement changes, worker fields update automatically.
- When disconnected, it resets to default string mode and clears per-worker overrides.
- On execution, the master uses `default_value`; each worker uses its mapped override, coerced to the connected type with fallback to the default (see the sketch below).
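
The sketch below illustrates the coercion-with-fallback idea only; it is not the node's actual implementation, and the function name and type labels are hypothetical.

```python
# Illustrative sketch of "typed coercion with fallback to default".
def coerce_override(raw, target_type, default):
    """Coerce a worker's override to the connected input type, else use default."""
    try:
        if target_type == "INT":
            return int(raw)
        if target_type == "FLOAT":
            return float(raw)
        return raw  # STRING/COMBO values pass through unchanged
    except (TypeError, ValueError):
        return default
```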
---
## Nodes
| Node | Description |
|------|-------------|
| **Distributed Seed** | Generates unique seeds for each worker |
| **Distributed Collector** | Collects results (image/video frames and optionally audio) from workers back to the master; `load_balance` can route the run to one least-busy participant |
| **Distributed Value** | Outputs per-worker override values with fallback to default |
| **Ultimate SD Upscale Distributed** | Distributes upscale tiles across workers |
| **Image Batch Divider** | Splits image batches for multi-GPU output |
| **Audio Batch Divider** | Splits audio batches for multi-GPU output |
| **Distributed Model Name** | Passes model paths to workers, enabling workflows to use models not present on the master in orchestrator-only mode |
| **Distributed Empty Image** | Produces an empty IMAGE batch used when the master delegates all work |
---
## FAQ
**Does it combine the VRAM of multiple GPUs?**

No, it does not combine the VRAM of multiple GPUs.

**Does it speed up the generation of a single image or video?**

No. Instead, it enables the generation of more images or videos simultaneously. It can, however, speed up the upscaling of a single image when using Ultimate SD Upscale Distributed.

**Does it work with the ComfyUI desktop app?**

Yes, it does now.

**Can I combine my RTX 5090 with a GTX 980 to get faster results?**

Yes, you can combine different GPUs, but performance is best with similar GPUs; a significant performance imbalance between GPUs may cause bottlenecks.

**Does this work with cloud providers?**

Yes, it is compatible with cloud providers. Refer to the setup guides for detailed instructions.

**Can I use my main machine just to coordinate workers without rendering?**

Yes. Open the Distributed panel and uncheck the master toggle to run in orchestrator-only mode. The master will distribute work to workers but won't render locally. If all workers become unavailable, the master automatically re-enables itself so your workflow still runs.

**Can I make this work with my Docker setup?**

Yes, it is compatible with Docker setups, but you will need to configure your Docker environment yourself; assistance with Docker configuration is not provided.
---
## Disclaimer
This software is provided "as is" without any warranties, express or implied, including merchantability, fitness for a particular purpose, or non-infringement. The developers and copyright holders are not liable for any claims, damages, or liabilities arising from the use, modification, or distribution of the software. Users are solely responsible for ensuring compliance with applicable laws and regulations and for securing their networks against unauthorized access, hacking, data breaches, or loss. The developers assume no liability for any damages or incidents resulting from misuse, improper configuration, or external threats.
---
## Support the Project
If my custom nodes have added value to your workflow, consider fueling future development with a coffee!
Your support helps keep this project thriving.
Buy me a coffee at: https://buymeacoffee.com/robertvoy
================================================
FILE: __init__.py
================================================
# Import everything needed from the main module
from .distributed import (
NODE_CLASS_MAPPINGS as DISTRIBUTED_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS as DISTRIBUTED_DISPLAY_NAME_MAPPINGS
)
# Import utilities
from .utils.config import ensure_config_exists, CONFIG_FILE
from .utils.logging import debug_log
# Import distributed upscale nodes
from .nodes.distributed_upscale import (
NODE_CLASS_MAPPINGS as UPSCALE_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS as UPSCALE_DISPLAY_NAME_MAPPINGS
)
WEB_DIRECTORY = "./web"
ensure_config_exists()
# Merge node mappings
NODE_CLASS_MAPPINGS = {**DISTRIBUTED_CLASS_MAPPINGS, **UPSCALE_CLASS_MAPPINGS}
NODE_DISPLAY_NAME_MAPPINGS = {**DISTRIBUTED_DISPLAY_NAME_MAPPINGS, **UPSCALE_DISPLAY_NAME_MAPPINGS}
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
debug_log("Loaded Distributed nodes.")
debug_log(f"Config file: {CONFIG_FILE}")
debug_log(f"Available nodes: {list(NODE_CLASS_MAPPINGS.keys())}")
================================================
FILE: api/__init__.py
================================================
from . import config_routes # noqa: F401
from . import tunnel_routes # noqa: F401
from . import worker_routes # noqa: F401
from . import job_routes # noqa: F401
from . import usdu_routes # noqa: F401
================================================
FILE: api/config_routes.py
================================================
import json
from contextlib import asynccontextmanager
from aiohttp import web
import server
try:
from ..utils.config import config_transaction, load_config, save_config
except ImportError:
from ..utils.config import load_config
try:
from ..utils.config import save_config
except ImportError:
def save_config(_config):
return True
@asynccontextmanager
async def config_transaction():
config = load_config()
original_snapshot = json.dumps(config, sort_keys=True)
yield config
if json.dumps(config, sort_keys=True) != original_snapshot:
save_config(config)
from ..utils.logging import debug_log, log
from ..utils.network import handle_api_error, normalize_host
def _positive_int(value):
return value > 0
CONFIG_SCHEMA = {
"workers": (list, None),
"master": (dict, None),
"settings": (dict, None),
"tunnel": (dict, None),
"managed_processes": (dict, None),
"worker_timeout_seconds": (int, _positive_int),
"debug": (bool, None),
"auto_launch_workers": (bool, None),
"stop_workers_on_master_exit": (bool, None),
"master_delegate_only": (bool, None),
"websocket_orchestration": (bool, None),
"has_auto_populated_workers": (bool, None),
}
_SETTINGS_FIELDS = {
"worker_timeout_seconds",
"debug",
"auto_launch_workers",
"stop_workers_on_master_exit",
"master_delegate_only",
"websocket_orchestration",
"has_auto_populated_workers",
}
_WORKER_FIELDS = [
("enabled", None, False),
("name", None, False),
("port", None, False),
("host", normalize_host, True),
("cuda_device", None, True),
("extra_args", None, True),
("type", None, False),
]
_MASTER_FIELDS = [
("name", None, False),
("host", normalize_host, True),
("port", None, False),
("cuda_device", None, False),
("extra_args", None, False),
]
def _apply_field_patch(target: dict, data: dict, field_rules: list) -> None:
"""Apply a partial update to a target dict based on field rules."""
for key, normalizer, remove_on_none in field_rules:
if key not in data:
continue
value = data[key]
if value is None and remove_on_none:
target.pop(key, None)
else:
target[key] = normalizer(value) if (normalizer and value is not None) else value
@server.PromptServer.instance.routes.get("/distributed/config")
async def get_config_endpoint(request):
config = load_config()
return web.json_response(config)
@server.PromptServer.instance.routes.post("/distributed/config")
async def update_config_endpoint(request):
"""Bulk config update with schema validation."""
try:
data = await request.json()
except Exception as e:
return await handle_api_error(request, f"Invalid JSON payload: {e}", 400)
if not isinstance(data, dict):
return await handle_api_error(request, "Config payload must be an object", 400)
validated_settings = {}
validated_root = {}
errors = []
for key, value in data.items():
if key not in CONFIG_SCHEMA:
errors.append(f"Unknown field: {key}")
continue
expected_type, validator = CONFIG_SCHEMA[key]
if not isinstance(value, expected_type):
errors.append(f"{key}: expected {expected_type.__name__}")
continue
if validator and not validator(value):
errors.append(f"{key}: value {value!r} failed validation")
continue
if key in _SETTINGS_FIELDS:
validated_settings[key] = value
else:
validated_root[key] = value
if errors:
return web.json_response({
"status": "error",
"error": errors,
"message": "; ".join(errors),
}, status=400)
try:
async with config_transaction() as config:
settings = config.setdefault("settings", {})
settings.update(validated_settings)
for key, value in validated_root.items():
config[key] = value
return web.json_response({"status": "success", "config": config})
except Exception as e:
return await handle_api_error(request, e)
@server.PromptServer.instance.routes.get("/distributed/queue_status/{job_id}")
async def queue_status_endpoint(request):
"""Check if a job queue is initialized."""
try:
job_id = request.match_info['job_id']
# Import to ensure initialization
from ..upscale.job_store import ensure_tile_jobs_initialized
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
exists = job_id in prompt_server.distributed_pending_tile_jobs
debug_log(f"Queue status check for job {job_id}: {'exists' if exists else 'not found'}")
return web.json_response({"exists": exists, "job_id": job_id})
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.post("/distributed/config/update_worker")
async def update_worker_endpoint(request):
try:
data = await request.json()
worker_id = data.get("worker_id")
if worker_id is None:
return await handle_api_error(request, "Missing worker_id", 400)
async with config_transaction() as config:
worker_found = False
workers = config.setdefault("workers", [])
for worker in workers:
if worker["id"] == worker_id:
_apply_field_patch(worker, data, _WORKER_FIELDS)
worker_found = True
break
if not worker_found:
# If worker not found and all required fields are provided, create new worker
if all(key in data for key in ["name", "port", "cuda_device"]):
new_worker = {
"id": worker_id,
"name": data["name"],
"host": normalize_host(data.get("host", "localhost")),
"port": data["port"],
"cuda_device": data["cuda_device"],
"enabled": data.get("enabled", False),
"extra_args": data.get("extra_args", ""),
"type": data.get("type", "local")
}
workers.append(new_worker)
else:
return await handle_api_error(
request,
f"Worker {worker_id} not found and missing required fields for creation",
404,
)
return web.json_response({"status": "success"})
except Exception as e:
return await handle_api_error(request, e, 400)
@server.PromptServer.instance.routes.post("/distributed/config/delete_worker")
async def delete_worker_endpoint(request):
try:
data = await request.json()
worker_id = data.get("worker_id")
if worker_id is None:
return await handle_api_error(request, "Missing worker_id", 400)
async with config_transaction() as config:
workers = config.get("workers", [])
# Find and remove the worker
worker_index = -1
for i, worker in enumerate(workers):
if worker["id"] == worker_id:
worker_index = i
break
if worker_index == -1:
return await handle_api_error(request, f"Worker {worker_id} not found", 404)
# Remove the worker
removed_worker = workers.pop(worker_index)
return web.json_response({
"status": "success",
"message": f"Worker {removed_worker.get('name', worker_id)} deleted"
})
except Exception as e:
return await handle_api_error(request, e, 400)
@server.PromptServer.instance.routes.post("/distributed/config/update_setting")
async def update_setting_endpoint(request):
"""Updates a specific key in the settings object."""
try:
data = await request.json()
key = data.get("key")
value = data.get("value")
if not key or value is None:
return await handle_api_error(request, "Missing 'key' or 'value' in request", 400)
if key not in _SETTINGS_FIELDS:
return await handle_api_error(request, f"Unknown setting: {key}", 400)
async with config_transaction() as config:
if 'settings' not in config:
config['settings'] = {}
config['settings'][key] = value
return web.json_response({"status": "success", "message": f"Setting '{key}' updated."})
except Exception as e:
return await handle_api_error(request, e, 400)
@server.PromptServer.instance.routes.post("/distributed/config/update_master")
async def update_master_endpoint(request):
"""Updates master configuration."""
try:
data = await request.json()
async with config_transaction() as config:
if 'master' not in config:
config['master'] = {}
_apply_field_patch(config['master'], data, _MASTER_FIELDS)
return web.json_response({"status": "success", "message": "Master configuration updated."})
except Exception as e:
return await handle_api_error(request, e, 400)
================================================
FILE: api/job_routes.py
================================================
import json
import asyncio
import io
import os
import base64
import binascii
import time
from aiohttp import web
import server
import torch
from PIL import Image
from ..utils.logging import debug_log
from ..utils.image import pil_to_tensor, ensure_contiguous
from ..utils.network import handle_api_error
from ..utils.constants import JOB_INIT_GRACE_PERIOD, MEMORY_CLEAR_DELAY
try:
from .queue_orchestration import ensure_distributed_state, orchestrate_distributed_execution
except ImportError:
from .queue_orchestration import orchestrate_distributed_execution
def ensure_distributed_state():
return None
from .queue_request import parse_queue_request_payload
prompt_server = server.PromptServer.instance
# Canonical worker result envelope accepted by POST /distributed/job_complete:
# { "job_id": str, "worker_id": str, "batch_idx": int, "image": , "is_last": bool }
def _decode_image_sync(image_path):
"""Decode image/video file and compute hash in a threadpool worker."""
import base64
import hashlib
import folder_paths
full_path = folder_paths.get_annotated_filepath(image_path)
if not os.path.exists(full_path):
raise FileNotFoundError(image_path)
hash_md5 = hashlib.md5()
with open(full_path, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
file_hash = hash_md5.hexdigest()
video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm'}
file_ext = os.path.splitext(full_path)[1].lower()
if file_ext in video_extensions:
with open(full_path, 'rb') as f:
file_data = f.read()
mime_types = {
'.mp4': 'video/mp4',
'.avi': 'video/x-msvideo',
'.mov': 'video/quicktime',
'.mkv': 'video/x-matroska',
'.webm': 'video/webm'
}
mime_type = mime_types.get(file_ext, 'video/mp4')
image_data = f"data:{mime_type};base64,{base64.b64encode(file_data).decode('utf-8')}"
else:
with Image.open(full_path) as img:
if img.mode not in ('RGB', 'RGBA'):
img = img.convert('RGB')
buffer = io.BytesIO()
img.save(buffer, format='PNG', compress_level=1)
image_data = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
return {
"status": "success",
"image_data": image_data,
"hash": file_hash,
}
def _check_file_sync(filename, expected_hash):
"""Check file presence and hash in a threadpool worker."""
import hashlib
import folder_paths
full_path = folder_paths.get_annotated_filepath(filename)
if not os.path.exists(full_path):
return {
"status": "success",
"exists": False,
}
hash_md5 = hashlib.md5()
with open(full_path, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
file_hash = hash_md5.hexdigest()
return {
"status": "success",
"exists": True,
"hash_matches": file_hash == expected_hash,
}
def _decode_canonical_png_tensor(image_payload):
"""Decode canonical base64 PNG payload into a contiguous IMAGE tensor."""
if not isinstance(image_payload, str) or not image_payload.strip():
raise ValueError("Field 'image' must be a non-empty base64 PNG string.")
encoded = image_payload.strip()
if encoded.startswith("data:"):
header, sep, data_part = encoded.partition(",")
if not sep:
raise ValueError("Field 'image' data URL is malformed.")
if not header.lower().startswith("data:image/png;base64"):
raise ValueError("Field 'image' must be a PNG data URL when using data:* format.")
encoded = data_part
try:
png_bytes = base64.b64decode(encoded, validate=True)
except (binascii.Error, ValueError) as exc:
raise ValueError("Field 'image' is not valid base64 PNG data.") from exc
if not png_bytes:
raise ValueError("Field 'image' decoded to empty PNG data.")
try:
with Image.open(io.BytesIO(png_bytes)) as img:
img = img.convert("RGB")
tensor = pil_to_tensor(img)
return ensure_contiguous(tensor)
except Exception as exc:
raise ValueError(f"Failed to decode PNG image payload: {exc}") from exc
def _decode_audio_payload(audio_payload):
"""Decode canonical audio payload into an AUDIO dict."""
from ..utils.audio_payload import decode_audio_payload
return decode_audio_payload(audio_payload)
@server.PromptServer.instance.routes.post("/distributed/prepare_job")
async def prepare_job_endpoint(request):
try:
data = await request.json()
multi_job_id = data.get('multi_job_id')
if not multi_job_id:
return await handle_api_error(request, "Missing multi_job_id", 400)
ensure_distributed_state()
async with prompt_server.distributed_jobs_lock:
if multi_job_id not in prompt_server.distributed_pending_jobs:
prompt_server.distributed_pending_jobs[multi_job_id] = asyncio.Queue()
debug_log(f"Prepared queue for job {multi_job_id}")
return web.json_response({"status": "success"})
except Exception as e:
return await handle_api_error(request, e)
@server.PromptServer.instance.routes.post("/distributed/clear_memory")
async def clear_memory_endpoint(request):
debug_log("Received request to clear VRAM.")
try:
# Use ComfyUI's prompt server queue system like the /free endpoint does
if hasattr(server.PromptServer.instance, 'prompt_queue'):
server.PromptServer.instance.prompt_queue.set_flag("unload_models", True)
server.PromptServer.instance.prompt_queue.set_flag("free_memory", True)
debug_log("Set queue flags for memory clearing.")
# Wait a bit for the queue to process
await asyncio.sleep(MEMORY_CLEAR_DELAY)
# Also do direct cleanup as backup, but with error handling
import gc
import comfy.model_management as mm
try:
mm.unload_all_models()
except AttributeError as e:
debug_log(f"Warning during model unload: {e}")
try:
mm.soft_empty_cache()
except Exception as e:
debug_log(f"Warning during cache clear: {e}")
for _ in range(3):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
debug_log("VRAM cleared successfully.")
return web.json_response({"status": "success", "message": "GPU memory cleared."})
except Exception as e:
# Even if there's an error, try to do basic cleanup
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
debug_log(f"Partial VRAM clear completed with warning: {e}")
return web.json_response({"status": "success", "message": "GPU memory cleared (with warnings)"})
@server.PromptServer.instance.routes.post("/distributed/queue")
async def distributed_queue_endpoint(request):
"""Queue a distributed workflow, mirroring the UI orchestration pipeline."""
try:
raw_payload = await request.json()
except Exception as exc:
return await handle_api_error(request, f"Invalid JSON payload: {exc}", 400)
try:
payload = parse_queue_request_payload(raw_payload)
except ValueError as exc:
return await handle_api_error(request, exc, 400)
try:
prompt_id, prompt_number, worker_count, node_errors = await orchestrate_distributed_execution(
payload.prompt,
payload.workflow_meta,
payload.client_id,
enabled_worker_ids=payload.enabled_worker_ids,
delegate_master=payload.delegate_master,
trace_execution_id=payload.trace_execution_id,
)
return web.json_response({
"prompt_id": prompt_id,
"number": prompt_number,
"node_errors": node_errors,
"worker_count": worker_count,
"auto_prepare_supported": True,
})
except Exception as exc:
return await handle_api_error(request, exc, 500)
@server.PromptServer.instance.routes.post("/distributed/load_image")
async def load_image_endpoint(request):
"""Load an image or video file and return it as base64 data with hash."""
try:
data = await request.json()
image_path = data.get("image_path")
if not image_path:
return await handle_api_error(request, "Missing image_path", 400)
loop = asyncio.get_running_loop()
payload = await loop.run_in_executor(None, _decode_image_sync, image_path)
return web.json_response(payload)
except FileNotFoundError:
return await handle_api_error(request, f"File not found: {image_path}", 404)
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.post("/distributed/check_file")
async def check_file_endpoint(request):
"""Check if a file exists and matches the given hash."""
try:
data = await request.json()
filename = data.get("filename")
expected_hash = data.get("hash")
if not filename or not expected_hash:
return await handle_api_error(request, "Missing filename or hash", 400)
loop = asyncio.get_running_loop()
payload = await loop.run_in_executor(None, _check_file_sync, filename, expected_hash)
return web.json_response(payload)
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.post("/distributed/job_complete")
async def job_complete_endpoint(request):
try:
data = await request.json()
except Exception as exc:
return await handle_api_error(request, f"Invalid JSON payload: {exc}", 400)
if not isinstance(data, dict):
return await handle_api_error(request, "Expected a JSON object body", 400)
try:
job_id = data.get("job_id")
worker_id = data.get("worker_id")
batch_idx = data.get("batch_idx")
image_payload = data.get("image")
audio_payload = data.get("audio")
is_last = data.get("is_last")
errors = []
if not isinstance(job_id, str) or not job_id.strip():
errors.append("job_id: expected non-empty string")
if not isinstance(worker_id, str) or not worker_id.strip():
errors.append("worker_id: expected non-empty string")
if not isinstance(batch_idx, int) or batch_idx < 0:
errors.append("batch_idx: expected non-negative integer")
if not isinstance(image_payload, str) or not image_payload.strip():
errors.append("image: expected non-empty base64 PNG string")
if audio_payload is not None and not isinstance(audio_payload, dict):
errors.append("audio: expected object when provided")
if not isinstance(is_last, bool):
errors.append("is_last: expected boolean")
if errors:
return await handle_api_error(request, errors, 400)
tensor = _decode_canonical_png_tensor(image_payload)
decoded_audio = _decode_audio_payload(audio_payload) if audio_payload is not None else None
multi_job_id = job_id.strip()
worker_id = worker_id.strip()
pending = None
queue_size = 0
deadline = time.monotonic() + float(JOB_INIT_GRACE_PERIOD)
while pending is None:
async with prompt_server.distributed_jobs_lock:
pending = prompt_server.distributed_pending_jobs.get(multi_job_id)
if pending is not None:
await pending.put(
{
"tensor": tensor,
"worker_id": worker_id,
"image_index": int(batch_idx),
"is_last": is_last,
"audio": decoded_audio,
}
)
queue_size = pending.qsize()
break
if time.monotonic() > deadline:
return await handle_api_error(request, "job not initialized", 404)
await asyncio.sleep(0.05)
debug_log(
f"job_complete received canonical envelope - job_id: {multi_job_id}, "
f"worker: {worker_id}, batch_idx: {batch_idx}, is_last: {is_last}, "
f"queue_size: {queue_size}"
)
return web.json_response({"status": "success"})
except Exception as e:
return await handle_api_error(request, e)
================================================
FILE: api/orchestration/__init__.py
================================================
# Orchestration helpers split out from queue_orchestration.py:
# - prompt_transform.py: graph pruning + hidden input overrides
# - media_sync.py: remote media/path normalization
# - dispatch.py: worker probe + prompt dispatch
================================================
FILE: api/orchestration/dispatch.py
================================================
import asyncio
import json
import uuid
import aiohttp
from ...utils.logging import debug_log, log
from ...utils.network import build_worker_url, get_client_session, probe_worker
try:
from ...utils.trace_logger import trace_debug, trace_info
except ImportError:
def trace_debug(*_args, **_kwargs):
return None
def trace_info(*_args, **_kwargs):
return None
try:
from ..schemas import parse_positive_int
except ImportError:
def parse_positive_int(value, default):
try:
parsed = int(value)
return parsed if parsed > 0 else default
except (TypeError, ValueError):
return default
_least_busy_rr_index = 0
async def worker_is_active(worker):
"""Ping worker's /prompt endpoint to confirm it's reachable."""
url = build_worker_url(worker)
return await probe_worker(url, timeout=3.0) is not None
async def worker_ws_is_active(worker):
"""Ping worker's websocket endpoint to confirm it's reachable."""
session = await get_client_session()
url = build_worker_url(worker, "/distributed/worker_ws")
try:
ws = await session.ws_connect(url, heartbeat=20, timeout=3)
await ws.close()
return True
except asyncio.TimeoutError:
debug_log(f"[Distributed] Worker WS probe timed out: {url}")
return False
except aiohttp.ClientConnectorError:
debug_log(f"[Distributed] Worker WS unreachable: {url}")
return False
except Exception as e:
debug_log(f"[Distributed] Worker WS probe unexpected error: {e}")
return False
async def _probe_worker_active(worker, use_websocket, semaphore):
async with semaphore:
is_active = await (worker_ws_is_active(worker) if use_websocket else worker_is_active(worker))
return worker, is_active
async def _dispatch_via_websocket(worker_url, payload, client_id, timeout=60.0):
"""Open a fresh worker websocket, dispatch one prompt, wait for ack, then close."""
request_id = uuid.uuid4().hex
ws_payload = {
"type": "dispatch_prompt",
"request_id": request_id,
"prompt": payload.get("prompt"),
"workflow": payload.get("workflow"),
"client_id": client_id,
}
ws_url = worker_url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_url}/distributed/worker_ws"
session = await get_client_session()
async with session.ws_connect(ws_url, heartbeat=20, timeout=timeout) as ws:
await ws.send_json(ws_payload)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT:
data = json.loads(msg.data or "{}")
if data.get("type") == "dispatch_ack" and data.get("request_id") == request_id:
if data.get("ok"):
return
error_text = data.get("error") or "Worker rejected websocket dispatch."
validation_error = data.get("validation_error")
node_errors = data.get("node_errors")
if validation_error:
error_text = f"{error_text} | validation_error={validation_error}"
if node_errors:
error_text = f"{error_text} | node_errors={node_errors}"
raise RuntimeError(error_text)
elif msg.type in (aiohttp.WSMsgType.ERROR, aiohttp.WSMsgType.CLOSED):
raise RuntimeError(f"Worker websocket closed unexpectedly: {msg.type}")
raise RuntimeError("Worker websocket closed before dispatch_ack was received.")
async def dispatch_worker_prompt(
worker,
prompt_obj,
workflow_meta,
client_id=None,
use_websocket=False,
trace_execution_id=None,
):
"""Send the prepared prompt to a worker ComfyUI instance."""
worker_url = build_worker_url(worker)
url = build_worker_url(worker, "/prompt")
payload = {"prompt": prompt_obj}
extra_data = {}
if workflow_meta:
extra_data.setdefault("extra_pnginfo", {})["workflow"] = workflow_meta
if extra_data:
payload["extra_data"] = extra_data
if use_websocket:
try:
await _dispatch_via_websocket(
worker_url,
{
"prompt": prompt_obj,
"workflow": workflow_meta,
},
client_id,
)
return
except Exception as exc:
worker_id = worker.get("id")
if trace_execution_id:
trace_info(trace_execution_id, f"Websocket dispatch failed for worker {worker_id}: {exc}")
else:
log(f"[Distributed] Websocket dispatch failed for worker {worker_id}: {exc}")
raise
session = await get_client_session()
async with session.post(
url,
json=payload,
timeout=aiohttp.ClientTimeout(total=60),
) as resp:
resp.raise_for_status()
async def select_active_workers(
workers,
use_websocket,
delegate_master,
trace_execution_id=None,
probe_concurrency=8,
):
"""Probe workers and return (active_workers, updated_delegate_master)."""
probe_limit = parse_positive_int(probe_concurrency, 8)
probe_semaphore = asyncio.Semaphore(probe_limit)
if trace_execution_id and workers:
trace_debug(
trace_execution_id,
f"Probing {len(workers)} workers with probe_concurrency={probe_limit}",
)
probe_results = await asyncio.gather(
*[
_probe_worker_active(worker, use_websocket, probe_semaphore)
for worker in workers
]
)
active_workers = []
for worker, is_active in probe_results:
if is_active:
active_workers.append(worker)
else:
if trace_execution_id:
trace_info(trace_execution_id, f"Worker {worker['name']} ({worker['id']}) is offline, skipping.")
else:
log(f"[Distributed] Worker {worker['name']} ({worker['id']}) is offline, skipping.")
if trace_execution_id and workers:
trace_debug(
trace_execution_id,
f"Worker probe complete: active={len(active_workers)}/{len(workers)}",
)
if not active_workers and delegate_master:
if trace_execution_id:
trace_debug(trace_execution_id, "All workers offline while delegate-only requested; enabling master participation.")
else:
debug_log("All workers offline while delegate-only requested; enabling master participation.")
delegate_master = False
return active_workers, delegate_master
def _extract_queue_remaining(payload):
if not isinstance(payload, dict):
return 0
try:
queue_remaining = int(payload.get("exec_info", {}).get("queue_remaining", 0))
except (TypeError, ValueError):
queue_remaining = 0
return max(queue_remaining, 0)
async def _probe_worker_queue(worker, semaphore, probe_timeout):
async with semaphore:
worker_url = build_worker_url(worker)
payload = await probe_worker(worker_url, timeout=probe_timeout)
if payload is None:
return None
return {
"worker": worker,
"queue_remaining": _extract_queue_remaining(payload),
}
def _select_idle_round_robin(statuses):
global _least_busy_rr_index
if not statuses:
return None
index = _least_busy_rr_index % len(statuses)
_least_busy_rr_index += 1
return statuses[index]
async def select_least_busy_worker(
workers,
trace_execution_id=None,
probe_concurrency=8,
probe_timeout=3.0,
):
"""Select one worker by queue depth, round-robin among idle workers."""
if not workers:
return None
probe_limit = parse_positive_int(probe_concurrency, 8)
probe_semaphore = asyncio.Semaphore(probe_limit)
statuses = await asyncio.gather(
*[
_probe_worker_queue(worker, probe_semaphore, probe_timeout)
for worker in workers
]
)
statuses = [status for status in statuses if status is not None]
if not statuses:
if trace_execution_id:
trace_info(trace_execution_id, "Least-busy selection failed: no worker queue probes succeeded.")
else:
log("[Distributed] Least-busy selection failed: no worker queue probes succeeded.")
return None
idle_statuses = [status for status in statuses if status["queue_remaining"] == 0]
if idle_statuses:
selected = _select_idle_round_robin(idle_statuses)
else:
selected = min(statuses, key=lambda status: status["queue_remaining"])
worker = selected["worker"]
queue_remaining = selected["queue_remaining"]
if trace_execution_id:
trace_debug(
trace_execution_id,
f"Least-busy worker selected: {worker.get('name')} ({worker.get('id')}), queue_remaining={queue_remaining}",
)
else:
debug_log(
f"Least-busy worker selected: {worker.get('name')} ({worker.get('id')}), queue_remaining={queue_remaining}"
)
return worker
================================================
FILE: api/orchestration/media_sync.py
================================================
import asyncio
import hashlib
import mimetypes
import os
import re
import aiohttp
from ...utils.logging import debug_log, log
from ...utils.network import build_worker_url, get_client_session
from ...utils.trace_logger import trace_debug, trace_info
LIKELY_FILENAME_RE = re.compile(
r"\.(ckpt|safetensors|pt|pth|bin|yaml|json|png|jpg|jpeg|webp|gif|bmp|mp4|avi|mov|mkv|webm|"
r"wav|mp3|flac|m4a|aac|ogg|opus|aiff|aif|wma|latent|txt|vae|lora|embedding)"
r"(\s*\[\w+\])?$",
re.IGNORECASE,
)
MEDIA_FILE_RE = re.compile(
r"\.(png|jpg|jpeg|webp|gif|bmp|mp4|avi|mov|mkv|webm|wav|mp3|flac|m4a|aac|ogg|opus|aiff|aif|wma)(\s*\[\w+\])?$",
re.IGNORECASE,
)
def _normalize_media_reference(value):
"""Normalize one media string value to a path-like reference or None."""
if not isinstance(value, str):
return None
cleaned = re.sub(r"\s*\[\w+\]$", "", value).strip().replace("\\", "/")
if MEDIA_FILE_RE.search(cleaned):
return cleaned
return None
def convert_paths_for_platform(obj, target_separator):
"""Recursively normalize likely file paths for the worker platform separator."""
if target_separator not in ("/", "\\"):
return obj
def _convert(value):
if isinstance(value, str):
if ("/" in value or "\\" in value) and LIKELY_FILENAME_RE.search(value):
trimmed = value.strip()
has_drive = bool(re.match(r"^[A-Za-z]:(\\\\|/)", trimmed))
is_absolute = trimmed.startswith("/") or trimmed.startswith("\\\\")
has_protocol = bool(re.match(r"^\w+://", trimmed))
# URLs are not local paths and should never be separator-normalized.
if has_protocol:
return trimmed
# Keep relative media-style paths in forward-slash form (Comfy-style annotated paths).
if not has_drive and not is_absolute and not has_protocol and MEDIA_FILE_RE.search(trimmed):
return re.sub(r"[\\]+", "/", trimmed)
if target_separator == "\\":
return re.sub(r"[\\/]+", r"\\", trimmed)
return re.sub(r"[\\/]+", "/", trimmed)
return value
if isinstance(value, list):
return [_convert(item) for item in value]
if isinstance(value, dict):
return {key: _convert(item) for key, item in value.items()}
return value
return _convert(obj)
def _find_media_references(prompt_obj):
"""Find media file references in image/video/audio/file inputs used by worker prompts."""
media_refs = set()
for node in prompt_obj.values():
if not isinstance(node, dict):
continue
inputs = node.get("inputs", {})
for key in ("image", "video", "audio", "file"):
cleaned = _normalize_media_reference(inputs.get(key))
if cleaned:
media_refs.add(cleaned)
return sorted(media_refs)
def _rewrite_prompt_media_inputs(prompt_obj, worker_media_paths):
"""Rewrite media string inputs to worker-local uploaded paths."""
if not isinstance(worker_media_paths, dict) or not worker_media_paths:
return
for node in prompt_obj.values():
if not isinstance(node, dict):
continue
inputs = node.get("inputs", {})
if not isinstance(inputs, dict):
continue
for key in ("image", "video", "audio", "file"):
value = inputs.get(key)
cleaned = _normalize_media_reference(value)
if not cleaned:
continue
worker_path = worker_media_paths.get(cleaned)
if worker_path:
inputs[key] = worker_path
def _load_media_file_sync(filename):
"""Load local media bytes and hash for worker upload sync."""
import folder_paths
full_path = folder_paths.get_annotated_filepath(filename)
if not os.path.exists(full_path):
raise FileNotFoundError(filename)
with open(full_path, "rb") as f:
file_bytes = f.read()
file_hash = hashlib.md5(file_bytes).hexdigest()
mime_type = mimetypes.guess_type(full_path)[0]
if not mime_type:
ext = os.path.splitext(full_path)[1].lower()
if ext in {".mp4", ".avi", ".mov", ".mkv", ".webm"}:
mime_type = "video/mp4"
else:
mime_type = "image/png"
return file_bytes, file_hash, mime_type
async def fetch_worker_path_separator(worker, trace_execution_id=None):
"""Best-effort fetch of a worker's path separator from /distributed/system_info."""
url = build_worker_url(worker, "/distributed/system_info")
session = await get_client_session()
try:
async with session.get(url, timeout=aiohttp.ClientTimeout(total=5)) as resp:
if resp.status != 200:
return None
payload = await resp.json()
separator = ((payload or {}).get("platform") or {}).get("path_separator")
return separator if separator in ("/", "\\") else None
except Exception as exc:
if trace_execution_id:
trace_debug(trace_execution_id, f"Failed to fetch worker system info ({worker.get('id')}): {exc}")
else:
debug_log(f"[Distributed] Failed to fetch worker system info ({worker.get('id')}): {exc}")
return None
async def _upload_media_to_worker(worker, filename, file_bytes, file_hash, mime_type, trace_execution_id=None):
"""Upload one media file to worker iff missing or hash-mismatched."""
session = await get_client_session()
normalized = filename.replace("\\", "/")
check_url = build_worker_url(worker, "/distributed/check_file")
try:
async with session.post(
check_url,
json={"filename": normalized, "hash": file_hash},
timeout=aiohttp.ClientTimeout(total=6),
) as resp:
if resp.status == 200:
payload = await resp.json()
if payload.get("exists") and payload.get("hash_matches"):
return False, normalized
except Exception as exc:
if trace_execution_id:
trace_debug(trace_execution_id, f"Media check failed for '{normalized}' on worker {worker.get('id')}: {exc}")
else:
debug_log(f"[Distributed] Media check failed for '{normalized}' on worker {worker.get('id')}: {exc}")
parts = normalized.split("/")
clean_name = parts[-1]
subfolder = "/".join(parts[:-1])
form = aiohttp.FormData()
form.add_field("image", file_bytes, filename=clean_name, content_type=mime_type)
form.add_field("type", "input")
form.add_field("subfolder", subfolder)
form.add_field("overwrite", "true")
upload_url = build_worker_url(worker, "/upload/image")
async with session.post(
upload_url,
data=form,
timeout=aiohttp.ClientTimeout(total=30),
) as resp:
resp.raise_for_status()
try:
payload = await resp.json()
except Exception:
payload = {}
name = str((payload or {}).get("name") or clean_name).strip()
subfolder = str((payload or {}).get("subfolder") or "").strip().replace("\\", "/").strip("/")
worker_path = f"{subfolder}/{name}" if subfolder else name
return True, worker_path
async def sync_worker_media(worker, prompt_obj, trace_execution_id=None):
"""Sync referenced media files from master to a remote worker before dispatch."""
media_refs = _find_media_references(prompt_obj)
if not media_refs:
return
loop = asyncio.get_running_loop()
uploaded = 0
skipped = 0
missing = 0
worker_media_paths = {}
for filename in media_refs:
try:
file_bytes, file_hash, mime_type = await loop.run_in_executor(
None, _load_media_file_sync, filename
)
except FileNotFoundError:
missing += 1
if trace_execution_id:
trace_info(trace_execution_id, f"Media file '{filename}' not found on master; worker may fail to load it.")
else:
log(f"[Distributed] Media file '{filename}' not found on master; worker may fail to load it.")
continue
except Exception as exc:
if trace_execution_id:
trace_info(trace_execution_id, f"Failed to load media '{filename}' for worker sync: {exc}")
else:
log(f"[Distributed] Failed to load media '{filename}' for worker sync: {exc}")
continue
try:
changed, worker_path = await _upload_media_to_worker(
worker,
filename,
file_bytes,
file_hash,
mime_type,
trace_execution_id=trace_execution_id,
)
if worker_path:
worker_media_paths[filename] = worker_path
if changed:
uploaded += 1
else:
skipped += 1
except Exception as exc:
if trace_execution_id:
trace_info(trace_execution_id, f"Failed to upload media '{filename}' to worker {worker.get('id')}: {exc}")
else:
log(f"[Distributed] Failed to upload media '{filename}' to worker {worker.get('id')}: {exc}")
_rewrite_prompt_media_inputs(prompt_obj, worker_media_paths)
summary = (
f"Media sync for worker {worker.get('id')}: "
f"uploaded={uploaded}, skipped={skipped}, missing={missing}, referenced={len(media_refs)}"
)
if trace_execution_id:
trace_debug(trace_execution_id, summary)
else:
debug_log(f"[Distributed] {summary}")
================================================
FILE: api/orchestration/prompt_transform.py
================================================
import json
from collections import deque
from ...utils.logging import debug_log
class PromptIndex:
"""Cache prompt metadata for faster worker/master prompt preparation."""
def __init__(self, prompt_obj):
self._prompt_json = json.dumps(prompt_obj)
self.nodes_by_class = {}
self.class_by_node = {}
self.inputs_by_node = {}
for node_id, node in _iter_prompt_nodes(prompt_obj):
class_type = node.get("class_type")
node_id_str = str(node_id)
if class_type:
self.nodes_by_class.setdefault(class_type, []).append(node_id_str)
self.class_by_node[node_id_str] = class_type
self.inputs_by_node[node_id_str] = node.get("inputs", {})
self._upstream_cache = {}
def copy_prompt(self):
return json.loads(self._prompt_json)
def nodes_for_class(self, class_name):
return self.nodes_by_class.get(class_name, [])
def has_upstream(self, start_node_id, target_class):
cache_key = (str(start_node_id), target_class)
if cache_key in self._upstream_cache:
return self._upstream_cache[cache_key]
visited = set()
stack = [str(start_node_id)]
while stack:
node_id = stack.pop()
if node_id in visited:
continue
visited.add(node_id)
inputs = self.inputs_by_node.get(node_id, {})
for value in inputs.values():
if isinstance(value, list) and len(value) == 2:
upstream_id = str(value[0])
if self.class_by_node.get(upstream_id) == target_class:
self._upstream_cache[cache_key] = True
return True
if upstream_id in self.inputs_by_node:
stack.append(upstream_id)
self._upstream_cache[cache_key] = False
return False
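# Hedged example: how PromptIndex is queried. The two-node prompt below is
# invented for illustration; the assertions follow directly from the class
# behavior above, and the function is never called at import time.
def _example_prompt_index_usage():
    """Illustrative only: index a prompt and walk its upstream links."""
    prompt = {
        "1": {"class_type": "KSampler", "inputs": {}},
        "2": {"class_type": "DistributedCollector", "inputs": {"images": ["1", 0]}},
    }
    index = PromptIndex(prompt)
    assert index.nodes_for_class("DistributedCollector") == ["2"]
    assert index.has_upstream("2", "KSampler") is True
    return index.copy_prompt()  # deep copy via the cached JSON snapshot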
def _iter_prompt_nodes(prompt_obj):
for node_id, node in prompt_obj.items():
if isinstance(node, dict):
yield str(node_id), node
def find_nodes_by_class(prompt_obj, class_name):
nodes = []
for node_id, node in _iter_prompt_nodes(prompt_obj):
if node.get("class_type") == class_name:
nodes.append(node_id)
return nodes
def _find_downstream_nodes(prompt_obj, start_ids):
"""Return all nodes reachable downstream from the provided IDs."""
adjacency = {}
for node_id, node in _iter_prompt_nodes(prompt_obj):
inputs = node.get("inputs", {})
for value in inputs.values():
if isinstance(value, list) and len(value) == 2:
source_id = str(value[0])
adjacency.setdefault(source_id, set()).add(str(node_id))
connected = set(start_ids)
queue = deque(start_ids)
while queue:
current = queue.popleft()
for dependent in adjacency.get(current, ()): # pragma: no branch - simple iteration
if dependent not in connected:
connected.add(dependent)
queue.append(dependent)
return connected
def _create_numeric_id_generator(prompt_obj):
"""Return a closure that yields new numeric string IDs."""
max_id = 0
for node_id in prompt_obj.keys():
try:
numeric = int(node_id)
except (TypeError, ValueError):
continue
max_id = max(max_id, numeric)
counter = max_id
def _next_id():
nonlocal counter
counter += 1
return str(counter)
return _next_id
def _find_upstream_nodes(prompt_obj, start_ids):
"""Return all nodes reachable upstream from start_ids, including start nodes."""
connected = set(str(node_id) for node_id in start_ids)
queue = deque(connected)
while queue:
node_id = queue.popleft()
node = prompt_obj.get(node_id) or {}
inputs = node.get("inputs", {})
for value in inputs.values():
if isinstance(value, list) and len(value) == 2:
source_id = str(value[0])
if source_id in prompt_obj and source_id not in connected:
connected.add(source_id)
queue.append(source_id)
return connected
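# Hedged example: upstream vs. downstream reachability over a three-node
# chain invented for illustration.
def _example_graph_traversal():
    """Illustrative only: the two reachability helpers complement each other."""
    prompt = {
        "1": {"class_type": "CheckpointLoader", "inputs": {}},
        "2": {"class_type": "KSampler", "inputs": {"model": ["1", 0]}},
        "3": {"class_type": "SaveImage", "inputs": {"images": ["2", 0]}},
    }
    assert _find_upstream_nodes(prompt, ["2"]) == {"1", "2"}
    assert _find_downstream_nodes(prompt, ["2"]) == {"2", "3"}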
def prune_prompt_for_worker(prompt_obj):
"""Prune worker prompt to distributed nodes and their upstream dependencies."""
collector_ids = find_nodes_by_class(prompt_obj, "DistributedCollector")
upscale_ids = find_nodes_by_class(prompt_obj, "UltimateSDUpscaleDistributed")
distributed_ids = collector_ids + upscale_ids
if not distributed_ids:
return prompt_obj
connected = _find_upstream_nodes(prompt_obj, distributed_ids)
pruned_prompt = {}
for node_id in connected:
node = prompt_obj.get(node_id)
if node is not None:
pruned_prompt[node_id] = json.loads(json.dumps(node))
# Generate IDs from the original prompt so we never reuse IDs from pruned downstream nodes.
next_id = _create_numeric_id_generator(prompt_obj)
for dist_id in distributed_ids:
if dist_id not in pruned_prompt:
continue
downstream = _find_downstream_nodes(prompt_obj, [dist_id])
has_removed_downstream = any(node_id != dist_id for node_id in downstream)
if has_removed_downstream:
preview_id = next_id()
pruned_prompt[preview_id] = {
"inputs": {
"images": [dist_id, 0],
},
"class_type": "PreviewImage",
"_meta": {
"title": "Preview Image (auto-added)",
},
}
return pruned_prompt
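# Hedged example: pruning an invented sampler -> collector -> save chain. The
# SaveImage node sits downstream of the collector, so it is dropped and a
# PreviewImage is auto-added in its place.
def _example_prune_for_worker():
    """Illustrative only: worker prompts keep the collector and its ancestors."""
    prompt = {
        "1": {"class_type": "KSampler", "inputs": {}},
        "2": {"class_type": "DistributedCollector", "inputs": {"images": ["1", 0]}},
        "3": {"class_type": "SaveImage", "inputs": {"images": ["2", 0]}},
    }
    pruned = prune_prompt_for_worker(prompt)
    assert "3" not in pruned  # downstream node dropped
    assert pruned["4"]["class_type"] == "PreviewImage"  # auto-added preview
    return pruned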
def prepare_delegate_master_prompt(prompt_obj, collector_ids):
"""Prune master prompt so it only executes post-collector nodes in delegate mode."""
downstream = _find_downstream_nodes(prompt_obj, collector_ids)
nodes_to_keep = set(collector_ids)
nodes_to_keep.update(downstream)
pruned_prompt = {}
for node_id in nodes_to_keep:
node = prompt_obj.get(node_id)
if node is not None:
pruned_prompt[node_id] = json.loads(json.dumps(node))
pruned_ids = set(pruned_prompt.keys())
for node_id, node in pruned_prompt.items():
inputs = node.get("inputs")
if not inputs:
continue
for input_name, input_value in list(inputs.items()):
if isinstance(input_value, list) and len(input_value) == 2:
source_id = str(input_value[0])
if source_id not in pruned_ids:
inputs.pop(input_name, None)
debug_log(
f"Removed upstream reference '{input_name}' from node {node_id} for delegate-only master prompt."
)
# Generate IDs from the original prompt to avoid ID collisions with pruned nodes.
next_id = _create_numeric_id_generator(prompt_obj)
for collector_id in collector_ids:
collector_entry = pruned_prompt.get(collector_id)
if not collector_entry:
continue
placeholder_id = next_id()
pruned_prompt[placeholder_id] = {
"class_type": "DistributedEmptyImage",
"inputs": {
"height": 64,
"width": 64,
"channels": 3,
},
"_meta": {
"title": "Distributed Empty Image (auto-added)",
},
}
collector_entry.setdefault("inputs", {})["images"] = [placeholder_id, 0]
debug_log(
f"Inserted placeholder node {placeholder_id} for collector {collector_id} in delegate-only master prompt."
)
return pruned_prompt
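# Hedged example: delegate-only pruning of the same invented chain. The
# collector stays on master, but its image feed is rewired to a tiny
# placeholder so no sampling happens locally.
def _example_delegate_master_prompt():
    """Illustrative only: master keeps the collector and downstream nodes."""
    prompt = {
        "1": {"class_type": "KSampler", "inputs": {}},
        "2": {"class_type": "DistributedCollector", "inputs": {"images": ["1", 0]}},
        "3": {"class_type": "SaveImage", "inputs": {"images": ["2", 0]}},
    }
    pruned = prepare_delegate_master_prompt(prompt, ["2"])
    assert "1" not in pruned  # upstream sampler dropped
    assert pruned["2"]["inputs"]["images"][0] == "4"  # rewired to placeholder
    assert pruned["4"]["class_type"] == "DistributedEmptyImage"
    return pruned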
def generate_job_id_map(prompt_index, prefix):
"""Create stable per-node job IDs for distributed nodes."""
job_map = {}
distributed_nodes = prompt_index.nodes_for_class("DistributedCollector") + prompt_index.nodes_for_class(
"UltimateSDUpscaleDistributed"
)
for node_id in distributed_nodes:
job_map[node_id] = f"{prefix}_{node_id}"
return job_map
def _override_seed_nodes(prompt_copy, prompt_index, is_master, participant_id, worker_index_map):
"""Configure DistributedSeed nodes for master or worker role."""
for node_id in prompt_index.nodes_for_class("DistributedSeed"):
node = prompt_copy.get(node_id)
if not isinstance(node, dict):
continue
inputs = node.setdefault("inputs", {})
inputs["is_worker"] = not is_master
if is_master:
inputs["worker_id"] = ""
else:
inputs["worker_id"] = f"worker_{worker_index_map.get(participant_id, 0)}"
def _override_collector_nodes(
prompt_copy,
prompt_index,
is_master,
participant_id,
job_id_map,
master_url,
enabled_json,
delegate_master,
):
"""Configure DistributedCollector nodes for master or worker role."""
for node_id in prompt_index.nodes_for_class("DistributedCollector"):
node = prompt_copy.get(node_id)
if not isinstance(node, dict):
continue
if prompt_index.has_upstream(node_id, "UltimateSDUpscaleDistributed"):
node.setdefault("inputs", {})["pass_through"] = True
continue
inputs = node.setdefault("inputs", {})
inputs["multi_job_id"] = job_id_map.get(node_id, node_id)
inputs["is_worker"] = not is_master
inputs["enabled_worker_ids"] = enabled_json
if is_master:
inputs["delegate_only"] = bool(delegate_master)
inputs.pop("master_url", None)
inputs.pop("worker_id", None)
else:
inputs["master_url"] = master_url
inputs["worker_id"] = participant_id
inputs["delegate_only"] = False
def _override_upscale_nodes(
prompt_copy,
prompt_index,
is_master,
participant_id,
job_id_map,
master_url,
enabled_json,
):
"""Configure UltimateSDUpscaleDistributed nodes for master or worker role."""
for node_id in prompt_index.nodes_for_class("UltimateSDUpscaleDistributed"):
node = prompt_copy.get(node_id)
if not isinstance(node, dict):
continue
inputs = node.setdefault("inputs", {})
inputs["multi_job_id"] = job_id_map.get(node_id, node_id)
inputs["is_worker"] = not is_master
inputs["enabled_worker_ids"] = enabled_json
if is_master:
inputs.pop("master_url", None)
inputs.pop("worker_id", None)
else:
inputs["master_url"] = master_url
inputs["worker_id"] = participant_id
def _override_value_nodes(prompt_copy, prompt_index, is_master, participant_id, worker_index_map):
"""Configure DistributedValue nodes for master or worker role."""
for node_id in prompt_index.nodes_for_class("DistributedValue"):
node = prompt_copy.get(node_id)
if not isinstance(node, dict):
continue
inputs = node.setdefault("inputs", {})
inputs["is_worker"] = not is_master
if is_master:
inputs["worker_id"] = ""
else:
inputs["worker_id"] = f"worker_{worker_index_map.get(participant_id, 0)}"
def apply_participant_overrides(
prompt_copy,
participant_id,
enabled_worker_ids,
job_id_map,
master_url,
delegate_master,
prompt_index,
):
"""Return a prompt copy with hidden inputs configured for master/worker."""
is_master = participant_id == "master"
worker_index_map = {wid: idx for idx, wid in enumerate(enabled_worker_ids)}
enabled_json = json.dumps(enabled_worker_ids)
_override_seed_nodes(prompt_copy, prompt_index, is_master, participant_id, worker_index_map)
_override_value_nodes(prompt_copy, prompt_index, is_master, participant_id, worker_index_map)
_override_collector_nodes(
prompt_copy,
prompt_index,
is_master,
participant_id,
job_id_map,
master_url,
enabled_json,
delegate_master,
)
_override_upscale_nodes(
prompt_copy,
prompt_index,
is_master,
participant_id,
job_id_map,
master_url,
enabled_json,
)
return prompt_copy
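# Hedged example: an end-to-end override sketch on an invented collector
# prompt. The worker ID and master URL are placeholders.
def _example_apply_overrides():
    """Illustrative only: produce a worker-side copy of a collector prompt."""
    prompt = {
        "1": {"class_type": "KSampler", "inputs": {}},
        "2": {"class_type": "DistributedCollector", "inputs": {"images": ["1", 0]}},
    }
    index = PromptIndex(prompt)
    job_map = generate_job_id_map(index, "exec_demo")
    worker_prompt = apply_participant_overrides(
        index.copy_prompt(),
        "worker-a",
        ["worker-a"],
        job_map,
        "http://192.168.1.10:8188",
        False,
        index,
    )
    inputs = worker_prompt["2"]["inputs"]
    assert inputs["is_worker"] is True
    assert inputs["multi_job_id"] == "exec_demo_2"
    assert inputs["worker_id"] == "worker-a"
    return worker_prompt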
================================================
FILE: api/queue_orchestration.py
================================================
import asyncio
import time
import uuid
import server
from ..utils.async_helpers import queue_prompt_payload
from ..utils.config import load_config
from ..utils.constants import (
ORCHESTRATION_MEDIA_SYNC_CONCURRENCY,
ORCHESTRATION_MEDIA_SYNC_TIMEOUT,
ORCHESTRATION_WORKER_PROBE_CONCURRENCY,
ORCHESTRATION_WORKER_PREP_CONCURRENCY,
)
from ..utils.logging import debug_log, log
from ..utils.network import build_master_url, build_master_callback_url
from ..utils.trace_logger import trace_debug
from .schemas import parse_positive_float, parse_positive_int
from .orchestration.dispatch import (
dispatch_worker_prompt,
select_active_workers,
select_least_busy_worker,
)
from .orchestration.media_sync import convert_paths_for_platform, fetch_worker_path_separator, sync_worker_media
from .orchestration.prompt_transform import (
PromptIndex,
apply_participant_overrides,
find_nodes_by_class,
generate_job_id_map,
prepare_delegate_master_prompt,
prune_prompt_for_worker,
)
prompt_server = server.PromptServer.instance
def _generate_execution_trace_id():
return f"exec_{int(time.time() * 1000)}_{uuid.uuid4().hex[:6]}"
def ensure_distributed_state(server_instance=None):
"""Ensure prompt_server has the state used by distributed queue orchestration."""
ps = server_instance or prompt_server
if not hasattr(ps, "distributed_pending_jobs"):
ps.distributed_pending_jobs = {}
if not hasattr(ps, "distributed_jobs_lock"):
ps.distributed_jobs_lock = asyncio.Lock()
# Initialize top-level distributed queue state at module import time.
ensure_distributed_state()
async def _ensure_distributed_queue(job_id):
"""Ensure a queue exists for the given distributed job ID."""
ensure_distributed_state()
async with prompt_server.distributed_jobs_lock:
if job_id not in prompt_server.distributed_pending_jobs:
prompt_server.distributed_pending_jobs[job_id] = asyncio.Queue()
def _resolve_enabled_workers(config, requested_ids=None):
"""Return a list of worker configs that should participate."""
workers = []
for worker in config.get("workers", []):
worker_id = str(worker.get("id") or "").strip()
if not worker_id:
continue
if requested_ids is not None:
if worker_id not in requested_ids:
continue
elif not worker.get("enabled", False):
continue
raw_port = worker.get("port", worker.get("listen_port", 8188))
try:
port = int(raw_port or 8188)
except (TypeError, ValueError):
log(f"[Distributed] Invalid port '{raw_port}' for worker {worker_id}; defaulting to 8188.")
port = 8188
workers.append(
{
"id": worker_id,
"name": worker.get("name", worker_id),
"host": worker.get("host"),
"port": port,
"type": worker.get("type", "local"),
}
)
return workers
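# Hedged example: worker resolution from a config fragment shaped like the
# keys read above; all values are invented.
def _example_resolve_enabled_workers():
    """Illustrative only: enabled workers survive unless IDs are forced."""
    config = {
        "workers": [
            {"id": "a", "enabled": True, "host": "192.168.1.20", "port": "8189"},
            {"id": "b", "enabled": False},
        ]
    }
    assert [w["id"] for w in _resolve_enabled_workers(config)] == ["a"]
    # Explicitly requested IDs override the enabled flag entirely.
    assert [w["id"] for w in _resolve_enabled_workers(config, {"b"})] == ["b"]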
def _resolve_orchestration_limits(config):
"""Resolve bounded concurrency/timeouts for worker preparation pipeline."""
settings = (config or {}).get("settings", {}) or {}
worker_probe_concurrency = parse_positive_int(
settings.get("worker_probe_concurrency"),
ORCHESTRATION_WORKER_PROBE_CONCURRENCY,
)
worker_prep_concurrency = parse_positive_int(
settings.get("worker_prep_concurrency"),
ORCHESTRATION_WORKER_PREP_CONCURRENCY,
)
media_sync_concurrency = parse_positive_int(
settings.get("media_sync_concurrency"),
ORCHESTRATION_MEDIA_SYNC_CONCURRENCY,
)
media_sync_timeout_seconds = parse_positive_float(
settings.get("media_sync_timeout_seconds"),
ORCHESTRATION_MEDIA_SYNC_TIMEOUT,
)
return (
worker_probe_concurrency,
worker_prep_concurrency,
media_sync_concurrency,
media_sync_timeout_seconds,
)
def _is_load_balance_enabled(value):
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return bool(value)
if isinstance(value, str):
return value.strip().lower() in {"1", "true", "yes", "on"}
return False
def _prompt_requests_load_balance(prompt_index):
for node_id in prompt_index.nodes_for_class("DistributedCollector"):
inputs = prompt_index.inputs_by_node.get(node_id, {})
if _is_load_balance_enabled(inputs.get("load_balance", False)):
return True
return False
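# Hedged example: the loose truthiness accepted for the load_balance flag,
# which may arrive as a bool, a number, or a string depending on the client.
def _example_load_balance_flags():
    """Illustrative only: accepted representations of an enabled flag."""
    assert _is_load_balance_enabled(True)
    assert _is_load_balance_enabled(1)
    assert _is_load_balance_enabled(" YES ")
    assert not _is_load_balance_enabled("off")
    assert not _is_load_balance_enabled(None)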
async def _prepare_worker_payload(
worker,
prompt_index,
enabled_ids,
job_id_map,
master_url,
config,
delegate_master,
trace_execution_id,
worker_prep_semaphore,
media_sync_semaphore,
media_sync_timeout_seconds,
):
"""Prepare one worker prompt payload with bounded concurrency and media-sync timeout."""
async with worker_prep_semaphore:
worker_prompt = prompt_index.copy_prompt()
worker_master_url = build_master_callback_url(
worker,
config=config,
prompt_server_instance=prompt_server,
)
worker_type = str(worker.get("type") or "local").strip().lower()
is_remote_like = bool(worker.get("host")) and worker_type != "local"
if is_remote_like:
path_separator = await fetch_worker_path_separator(worker, trace_execution_id=trace_execution_id)
if path_separator:
worker_prompt = convert_paths_for_platform(worker_prompt, path_separator)
worker_prompt = prune_prompt_for_worker(worker_prompt)
worker_prompt = apply_participant_overrides(
worker_prompt,
worker["id"],
enabled_ids,
job_id_map,
worker_master_url,
delegate_master,
prompt_index,
)
if is_remote_like:
async with media_sync_semaphore:
try:
await asyncio.wait_for(
sync_worker_media(worker, worker_prompt, trace_execution_id=trace_execution_id),
timeout=media_sync_timeout_seconds,
)
except asyncio.TimeoutError:
trace_debug(
trace_execution_id,
(
f"Media sync timed out after {media_sync_timeout_seconds:.1f}s "
f"for worker {worker.get('name')} ({worker.get('id')}); continuing dispatch."
),
)
return worker, worker_prompt
async def orchestrate_distributed_execution(
prompt_obj,
workflow_meta,
client_id,
enabled_worker_ids=None,
delegate_master=None,
trace_execution_id=None,
):
"""Core orchestration logic for the /distributed/queue endpoint.
Returns:
tuple[str, int, int, dict]: (prompt_id, number, worker_count, node_errors)
"""
ensure_distributed_state()
execution_trace_id = trace_execution_id or _generate_execution_trace_id()
config = load_config()
use_websocket = bool(config.get("settings", {}).get("websocket_orchestration", False))
master_url = build_master_url(config=config, prompt_server_instance=prompt_server)
(
worker_probe_concurrency,
worker_prep_concurrency,
media_sync_concurrency,
media_sync_timeout_seconds,
) = _resolve_orchestration_limits(config)
requested_ids = enabled_worker_ids if enabled_worker_ids is not None else None
workers = _resolve_enabled_workers(config, requested_ids)
prompt_index = PromptIndex(prompt_obj)
load_balance_requested = _prompt_requests_load_balance(prompt_index)
trace_debug(
execution_trace_id,
(
f"Orchestration start: requested_workers={len(workers)}, "
f"requested_ids={requested_ids if requested_ids is not None else 'enabled_only'}, "
f"websocket={use_websocket}, "
f"probe_concurrency={worker_probe_concurrency}, "
f"prep_concurrency={worker_prep_concurrency}, "
f"media_sync_concurrency={media_sync_concurrency}, "
f"media_sync_timeout={media_sync_timeout_seconds:.1f}s, "
f"load_balance={load_balance_requested}"
),
)
# Respect master delegate-only configuration
if delegate_master is None:
delegate_master = bool(config.get("settings", {}).get("master_delegate_only", False))
if not workers and delegate_master:
trace_debug(
execution_trace_id,
"Delegate-only requested but no workers are enabled. Falling back to master execution.",
)
delegate_master = False
active_workers, delegate_master = await select_active_workers(
workers,
use_websocket,
delegate_master,
trace_execution_id=execution_trace_id,
probe_concurrency=worker_probe_concurrency,
)
if load_balance_requested:
candidate_workers = list(active_workers)
if not delegate_master:
# Include master in load balancing only when master participation is enabled.
candidate_workers.append(
{
"id": "master",
"name": "Master",
"host": master_url,
"type": "local",
}
)
selected_worker = None
if candidate_workers:
selected_worker = await select_least_busy_worker(
candidate_workers,
trace_execution_id=execution_trace_id,
probe_concurrency=worker_probe_concurrency,
)
if selected_worker is None and candidate_workers:
trace_debug(
execution_trace_id,
"Load-balance selection probe failed; using first available candidate.",
)
selected_worker = candidate_workers[0]
if selected_worker is not None and str(selected_worker.get("id")) == "master":
# Master selected as least busy; run master workload only.
active_workers = []
delegate_master = False
trace_debug(
execution_trace_id,
"Load-balance selected master for execution (workers skipped).",
)
elif selected_worker is not None:
active_workers = [selected_worker]
# Worker selected as least busy; keep master orchestrator-only for this run.
delegate_master = True
trace_debug(
execution_trace_id,
f"Load-balance selected worker {selected_worker.get('id')} (master set to delegate-only).",
)
else:
trace_debug(
execution_trace_id,
"Load-balance requested but no execution candidates were available.",
)
active_workers = []
delegate_master = False
enabled_ids = [worker["id"] for worker in active_workers]
discovery_prefix = f"exec_{int(time.time() * 1000)}_{uuid.uuid4().hex[:6]}"
job_id_map = generate_job_id_map(prompt_index, discovery_prefix)
if not job_id_map:
trace_debug(execution_trace_id, "No distributed nodes detected; queueing prompt on master only.")
queue_result = await queue_prompt_payload(
prompt_obj,
workflow_meta,
client_id,
include_queue_metadata=True,
)
return (
queue_result["prompt_id"],
queue_result["number"],
0,
queue_result.get("node_errors", {}),
)
for job_id in job_id_map.values():
await _ensure_distributed_queue(job_id)
master_prompt = prompt_index.copy_prompt()
master_prompt = apply_participant_overrides(
master_prompt,
"master",
enabled_ids,
job_id_map,
master_url,
delegate_master,
prompt_index,
)
if delegate_master:
collector_ids = find_nodes_by_class(master_prompt, "DistributedCollector")
upscale_nodes = find_nodes_by_class(master_prompt, "UltimateSDUpscaleDistributed")
if upscale_nodes:
debug_log(
"Delegate-only master mode currently does not support UltimateSDUpscaleDistributed nodes; running full prompt on master."
)
elif not collector_ids:
debug_log(
"Delegate-only master mode requested but no collectors found in master prompt. Running full prompt on master."
)
else:
master_prompt = prepare_delegate_master_prompt(master_prompt, collector_ids)
if active_workers:
trace_debug(
execution_trace_id,
"Active distributed workers: "
+ ", ".join(f"{worker['name']} ({worker['id']})" for worker in active_workers),
)
worker_payloads = []
if active_workers:
worker_prep_semaphore = asyncio.Semaphore(worker_prep_concurrency)
media_sync_semaphore = asyncio.Semaphore(media_sync_concurrency)
worker_payloads = await asyncio.gather(
*[
_prepare_worker_payload(
worker,
prompt_index,
enabled_ids,
job_id_map,
master_url,
config,
delegate_master,
execution_trace_id,
worker_prep_semaphore,
media_sync_semaphore,
media_sync_timeout_seconds,
)
for worker in active_workers
]
)
if worker_payloads:
await asyncio.gather(
*[
dispatch_worker_prompt(
worker,
wprompt,
workflow_meta,
client_id,
use_websocket=use_websocket,
trace_execution_id=execution_trace_id,
)
for worker, wprompt in worker_payloads
]
)
queue_result = await queue_prompt_payload(
master_prompt,
workflow_meta,
client_id,
include_queue_metadata=True,
)
prompt_id = queue_result["prompt_id"]
prompt_number = queue_result["number"]
node_errors = queue_result.get("node_errors", {})
trace_debug(
execution_trace_id,
f"Orchestration complete: prompt_id={prompt_id}, dispatched_workers={len(worker_payloads)}, delegate_master={delegate_master}",
)
return prompt_id, prompt_number, len(worker_payloads), node_errors
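# Hedged example: how a route handler would typically drive the orchestrator.
# The prompt body, client ID, and worker IDs below are invented; this mirrors
# the call signature above and is never invoked at import time.
async def _example_orchestrate_call():
    """Illustrative only: queue one prompt across two named workers."""
    prompt_id, number, dispatched, node_errors = await orchestrate_distributed_execution(
        prompt_obj={"1": {"class_type": "KSampler", "inputs": {}}},
        workflow_meta=None,
        client_id="ui-session-1",
        enabled_worker_ids=["worker-a", "worker-b"],
    )
    return prompt_id, number, dispatched, node_errors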
================================================
FILE: api/queue_request.py
================================================
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
@dataclass(frozen=True)
class QueueRequestPayload:
prompt: Dict[str, Any]
workflow_meta: Any
client_id: str
delegate_master: Optional[bool]
enabled_worker_ids: List[str]
auto_prepare: bool
trace_execution_id: Optional[str]
def parse_queue_request_payload(data: Any) -> QueueRequestPayload:
"""Parse and validate /distributed/queue payload into a normalized shape."""
if not isinstance(data, dict):
raise ValueError("Expected a JSON object body")
    # Auto-prepare is always on server-side; keep the field for wire compatibility.
    auto_prepare_raw = data.get("auto_prepare", True)
    if not isinstance(auto_prepare_raw, bool):
        raise ValueError("auto_prepare must be a boolean when provided")
    auto_prepare = auto_prepare_raw
    prompt = data.get("prompt")
    if prompt is None:
workflow_payload = data.get("workflow")
if isinstance(workflow_payload, dict):
candidate_prompt = workflow_payload.get("prompt")
if isinstance(candidate_prompt, dict):
prompt = candidate_prompt
if not isinstance(prompt, dict):
raise ValueError("Field 'prompt' must be an object")
enabled_ids_raw = data.get("enabled_worker_ids")
workers_field = data.get("workers")
if enabled_ids_raw is None and workers_field is not None:
if not isinstance(workers_field, list):
raise ValueError("Field 'workers' must be a list when provided")
enabled_ids_raw = []
for entry in workers_field:
worker_id = entry.get("id") if isinstance(entry, dict) else entry
if worker_id is not None:
enabled_ids_raw.append(str(worker_id))
if enabled_ids_raw is None:
raise ValueError("enabled_worker_ids required")
else:
if not isinstance(enabled_ids_raw, list):
raise ValueError("enabled_worker_ids must be a list of worker IDs")
enabled_ids = [str(worker_id).strip() for worker_id in enabled_ids_raw if str(worker_id).strip()]
delegate_master = data.get("delegate_master")
if delegate_master is not None and not isinstance(delegate_master, bool):
raise ValueError("delegate_master must be a boolean when provided")
client_id = data.get("client_id")
if not isinstance(client_id, str) or not client_id.strip():
raise ValueError("client_id required")
client_id = client_id.strip()
trace_execution_id = data.get("trace_execution_id")
if trace_execution_id is not None:
if not isinstance(trace_execution_id, str):
raise ValueError("trace_execution_id must be a string when provided")
trace_execution_id = trace_execution_id.strip() or None
return QueueRequestPayload(
prompt=prompt,
workflow_meta=data.get("workflow"),
client_id=client_id,
delegate_master=delegate_master,
enabled_worker_ids=enabled_ids,
auto_prepare=auto_prepare,
trace_execution_id=trace_execution_id,
)
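# Hedged example: the smallest body the parser accepts, with invented values,
# and the normalization it applies.
def _example_parse_queue_request():
    """Illustrative only: parse a minimal /distributed/queue payload."""
    payload = parse_queue_request_payload({
        "prompt": {"1": {"class_type": "KSampler", "inputs": {}}},
        "client_id": "  ui-session-1  ",
        "enabled_worker_ids": ["worker-a", "  "],
    })
    assert payload.client_id == "ui-session-1"  # whitespace stripped
    assert payload.enabled_worker_ids == ["worker-a"]  # blank IDs dropped
    assert payload.delegate_master is None  # optional, left unset
    return payload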
================================================
FILE: api/schemas.py
================================================
def require_fields(data: dict, *fields) -> list[str]:
"""Return field names that are missing or empty in a JSON object."""
if not isinstance(data, dict):
return list(fields)
missing = []
for field in fields:
if field not in data:
missing.append(field)
continue
value = data.get(field)
if value is None:
missing.append(field)
continue
if isinstance(value, str) and not value.strip():
missing.append(field)
return missing
def validate_worker_id(worker_id: str, config: dict) -> bool:
"""Return True when worker_id exists in config['workers']."""
worker_id_str = str(worker_id)
workers = (config or {}).get("workers", [])
return any(str(worker.get("id")) == worker_id_str for worker in workers)
def validate_positive_int(value, field_name: str) -> str | None:
"""Validate positive integers and return an error string when invalid."""
try:
parsed = int(value)
except (TypeError, ValueError):
return f"Field '{field_name}' must be a positive integer."
if parsed <= 0:
return f"Field '{field_name}' must be a positive integer."
return None
def parse_positive_int(value, default: int) -> int:
"""Parse value as positive int, returning default on failure."""
try:
parsed = int(value)
except (TypeError, ValueError):
return max(1, int(default))
return max(1, parsed)
def parse_positive_float(value, default: float) -> float:
"""Parse value as positive float, returning default on failure."""
try:
parsed = float(value)
except (TypeError, ValueError):
return max(0.0, float(default))
return max(0.0, parsed)
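# Hedged example: fallback behavior of the lenient parsers above, with
# invented values.
def _example_schema_parsers():
    """Illustrative only: invalid values fall back, valid ones are clamped."""
    assert parse_positive_int("4", 2) == 4
    assert parse_positive_int("oops", 2) == 2  # unparseable -> default
    assert parse_positive_int(-3, 2) == 1  # clamped to at least 1
    assert parse_positive_float("1.5", 3.0) == 1.5
    assert parse_positive_float(None, 3.0) == 3.0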
================================================
FILE: api/tunnel_routes.py
================================================
from aiohttp import web
import server
from ..utils.cloudflare import cloudflare_tunnel_manager
from ..utils.config import load_config
from ..utils.logging import debug_log, log
from ..utils.network import handle_api_error
@server.PromptServer.instance.routes.get("/distributed/tunnel/status")
async def tunnel_status_endpoint(request):
"""Return Cloudflare tunnel status and last known details."""
try:
status = cloudflare_tunnel_manager.get_status()
config = load_config()
master_host = (config.get("master") or {}).get("host")
return web.json_response({
"status": "success",
"tunnel": status,
"master_host": master_host
})
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.post("/distributed/tunnel/start")
async def tunnel_start_endpoint(request):
"""Start a Cloudflare tunnel pointing at the current ComfyUI server."""
try:
result = await cloudflare_tunnel_manager.start_tunnel()
config = load_config()
return web.json_response({
"status": "success",
"tunnel": result,
"master_host": (config.get("master") or {}).get("host")
})
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.post("/distributed/tunnel/stop")
async def tunnel_stop_endpoint(request):
"""Stop the managed Cloudflare tunnel if running."""
try:
result = await cloudflare_tunnel_manager.stop_tunnel()
config = load_config()
return web.json_response({
"status": "success",
"tunnel": result,
"master_host": (config.get("master") or {}).get("host")
})
except Exception as e:
return await handle_api_error(request, e, 500)
================================================
FILE: api/usdu_routes.py
================================================
import asyncio
import io
import time
from aiohttp import web
from PIL import Image
import server
from ..upscale.job_models import BaseJobState, ImageJobState, TileJobState
from ..upscale.job_store import MAX_PAYLOAD_SIZE, ensure_tile_jobs_initialized
from ..upscale.payload_parsers import _parse_tiles_from_form
from ..utils.logging import debug_log
from ..utils.network import handle_api_error
@server.PromptServer.instance.routes.post("/distributed/heartbeat")
async def heartbeat_endpoint(request):
try:
data = await request.json()
worker_id = data.get('worker_id')
multi_job_id = data.get('multi_job_id')
if not worker_id or not multi_job_id:
return await handle_api_error(request, "Missing worker_id or multi_job_id", 400)
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
if multi_job_id in prompt_server.distributed_pending_tile_jobs:
job_data = prompt_server.distributed_pending_tile_jobs[multi_job_id]
if isinstance(job_data, BaseJobState):
job_data.worker_status[worker_id] = time.time()
debug_log(f"Heartbeat from worker {worker_id}")
return web.json_response({"status": "success"})
return await handle_api_error(request, "Worker status tracking not available", 400)
return await handle_api_error(request, "Job not found", 404)
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.post("/distributed/submit_tiles")
async def submit_tiles_endpoint(request):
"""Endpoint for workers to submit processed tiles in static mode."""
try:
content_length = request.headers.get('content-length')
if content_length and int(content_length) > MAX_PAYLOAD_SIZE:
return await handle_api_error(request, f"Payload too large: {content_length} bytes", 413)
data = await request.post()
multi_job_id = data.get('multi_job_id')
worker_id = data.get('worker_id')
is_last = data.get('is_last', 'False').lower() == 'true'
if multi_job_id is None or worker_id is None:
return await handle_api_error(request, "Missing multi_job_id or worker_id", 400)
prompt_server = ensure_tile_jobs_initialized()
batch_size = int(data.get('batch_size', 0))
# Handle completion signal
if batch_size == 0 and is_last:
async with prompt_server.distributed_tile_jobs_lock:
if multi_job_id in prompt_server.distributed_pending_tile_jobs:
job_data = prompt_server.distributed_pending_tile_jobs[multi_job_id]
if not isinstance(job_data, TileJobState):
return await handle_api_error(request, "Job not configured for tile submissions", 400)
await job_data.queue.put({
'worker_id': worker_id,
'is_last': True,
'tiles': [],
})
debug_log(f"Received completion signal from worker {worker_id}")
return web.json_response({"status": "success"})
try:
tiles = _parse_tiles_from_form(data)
except ValueError as e:
return await handle_api_error(request, str(e), 400)
# Submit tiles to queue
async with prompt_server.distributed_tile_jobs_lock:
if multi_job_id in prompt_server.distributed_pending_tile_jobs:
job_data = prompt_server.distributed_pending_tile_jobs[multi_job_id]
if not isinstance(job_data, TileJobState):
return await handle_api_error(request, "Job not configured for tile submissions", 400)
q = job_data.queue
if batch_size > 0 or len(tiles) > 0:
await q.put({
'worker_id': worker_id,
'tiles': tiles,
'is_last': is_last,
})
debug_log(f"Received {len(tiles)} tiles from worker {worker_id} (is_last={is_last})")
else:
await q.put({
'worker_id': worker_id,
'is_last': True,
'tiles': [],
})
return web.json_response({"status": "success"})
return await handle_api_error(request, "Job not found", 404)
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.post("/distributed/submit_image")
async def submit_image_endpoint(request):
"""Endpoint for workers to submit processed images in dynamic mode."""
try:
content_length = request.headers.get('content-length')
if content_length and int(content_length) > MAX_PAYLOAD_SIZE:
return await handle_api_error(request, f"Payload too large: {content_length} bytes", 413)
data = await request.post()
multi_job_id = data.get('multi_job_id')
worker_id = data.get('worker_id')
is_last = data.get('is_last', 'False').lower() == 'true'
if multi_job_id is None or worker_id is None:
return await handle_api_error(request, "Missing multi_job_id or worker_id", 400)
prompt_server = ensure_tile_jobs_initialized()
# Handle image submission
if 'full_image' in data and 'image_idx' in data:
image_idx = int(data.get('image_idx'))
img_data = data['full_image'].file.read()
img = Image.open(io.BytesIO(img_data)).convert("RGB")
debug_log(f"Received full image {image_idx} from worker {worker_id}")
async with prompt_server.distributed_tile_jobs_lock:
if multi_job_id in prompt_server.distributed_pending_tile_jobs:
job_data = prompt_server.distributed_pending_tile_jobs[multi_job_id]
if not isinstance(job_data, ImageJobState):
return await handle_api_error(request, "Job not configured for image submissions", 400)
await job_data.queue.put({
'worker_id': worker_id,
'image_idx': image_idx,
'image': img,
'is_last': is_last,
})
return web.json_response({"status": "success"})
# Handle completion signal (no image data)
elif is_last:
async with prompt_server.distributed_tile_jobs_lock:
if multi_job_id in prompt_server.distributed_pending_tile_jobs:
job_data = prompt_server.distributed_pending_tile_jobs[multi_job_id]
if not isinstance(job_data, ImageJobState):
return await handle_api_error(request, "Job not configured for image submissions", 400)
await job_data.queue.put({
'worker_id': worker_id,
'is_last': True,
})
debug_log(f"Received completion signal from worker {worker_id}")
return web.json_response({"status": "success"})
else:
return await handle_api_error(request, "Missing image data or invalid request", 400)
return await handle_api_error(request, "Job not found", 404)
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.post("/distributed/request_image")
async def request_image_endpoint(request):
"""Endpoint for workers to request tasks (images in dynamic mode, tiles in static mode)."""
try:
data = await request.json()
worker_id = data.get('worker_id')
multi_job_id = data.get('multi_job_id')
if not worker_id or not multi_job_id:
return await handle_api_error(request, "Missing worker_id or multi_job_id", 400)
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
if multi_job_id in prompt_server.distributed_pending_tile_jobs:
job_data = prompt_server.distributed_pending_tile_jobs[multi_job_id]
if not isinstance(job_data, BaseJobState):
return await handle_api_error(request, "Invalid job data structure", 500)
mode = job_data.mode
if isinstance(job_data, ImageJobState):
pending_queue = job_data.pending_images
elif isinstance(job_data, TileJobState):
pending_queue = job_data.pending_tasks
else:
return await handle_api_error(request, "Invalid job configuration", 400)
try:
task_idx = await asyncio.wait_for(pending_queue.get(), timeout=0.1)
job_data.assigned_to_workers.setdefault(worker_id, []).append(task_idx)
job_data.worker_status[worker_id] = time.time()
remaining = pending_queue.qsize()
if mode == 'dynamic':
debug_log(f"UltimateSDUpscale API - Assigned image {task_idx} to worker {worker_id}")
return web.json_response({"image_idx": task_idx, "estimated_remaining": remaining})
debug_log(f"UltimateSDUpscale API - Assigned tile {task_idx} to worker {worker_id}")
return web.json_response({
"tile_idx": task_idx,
"estimated_remaining": remaining,
"batched_static": job_data.batched_static,
})
except asyncio.TimeoutError:
if mode == 'dynamic':
return web.json_response({"image_idx": None})
return web.json_response({"tile_idx": None})
return await handle_api_error(request, "Job not found", 404)
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.get("/distributed/job_status")
async def job_status_endpoint(request):
"""Endpoint to check if a job is ready."""
multi_job_id = request.query.get('multi_job_id')
if not multi_job_id:
return web.json_response({"ready": False})
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
ready = bool(isinstance(job_data, BaseJobState) and job_data.queue is not None)
return web.json_response({"ready": ready})
================================================
FILE: api/worker_routes.py
================================================
import json
import asyncio
import os
import time
import platform
import subprocess
import socket
import torch
import aiohttp
from aiohttp import web
import server
from ..utils.config import load_config
from ..utils.logging import debug_log, log
from ..utils.network import (
build_worker_url,
get_client_session,
handle_api_error,
normalize_host,
probe_worker,
)
from ..utils.constants import CHUNK_SIZE
from ..workers import get_worker_manager
from .schemas import require_fields, validate_worker_id
from ..workers.detection import (
get_machine_id,
is_docker_environment,
is_runpod_environment,
)
try:
from ..utils.async_helpers import PromptValidationError, queue_prompt_payload
except ImportError:
from ..utils.async_helpers import queue_prompt_payload
class PromptValidationError(RuntimeError):
def __init__(self, message, validation_error=None, node_errors=None):
super().__init__(str(message))
self.validation_error = validation_error if isinstance(validation_error, dict) else {}
self.node_errors = node_errors if isinstance(node_errors, dict) else {}
@server.PromptServer.instance.routes.get("/distributed/worker_ws")
async def worker_ws_endpoint(request):
"""WebSocket endpoint for worker prompt dispatch."""
ws = web.WebSocketResponse(heartbeat=30)
await ws.prepare(request)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT:
try:
data = json.loads(msg.data or "{}")
except json.JSONDecodeError:
await ws.send_json({
"type": "dispatch_ack",
"request_id": None,
"ok": False,
"error": "Invalid JSON payload.",
})
continue
if data.get("type") != "dispatch_prompt":
await ws.send_json({
"type": "dispatch_ack",
"request_id": data.get("request_id"),
"ok": False,
"error": "Unsupported websocket message type.",
})
continue
prompt = data.get("prompt")
if not isinstance(prompt, dict):
await ws.send_json({
"type": "dispatch_ack",
"request_id": data.get("request_id"),
"ok": False,
"error": "Field 'prompt' must be an object.",
})
continue
try:
prompt_id = await queue_prompt_payload(
prompt,
workflow_meta=data.get("workflow"),
client_id=data.get("client_id"),
)
await ws.send_json({
"type": "dispatch_ack",
"request_id": data.get("request_id"),
"ok": True,
"prompt_id": prompt_id,
})
except PromptValidationError as exc:
await ws.send_json({
"type": "dispatch_ack",
"request_id": data.get("request_id"),
"ok": False,
"error": str(exc),
"validation_error": exc.validation_error,
"node_errors": exc.node_errors,
})
except Exception as exc:
await ws.send_json({
"type": "dispatch_ack",
"request_id": data.get("request_id"),
"ok": False,
"error": str(exc),
})
elif msg.type == aiohttp.WSMsgType.ERROR:
log(f"[Distributed] Worker websocket error: {ws.exception()}")
return ws
@server.PromptServer.instance.routes.post("/distributed/worker/clear_launching")
async def clear_launching_state(request):
"""Clear the launching flag when worker is confirmed running."""
try:
wm = get_worker_manager()
data = await request.json()
missing = require_fields(data, "worker_id")
if missing:
return await handle_api_error(request, f"Missing required field(s): {', '.join(missing)}", 400)
worker_id = str(data.get("worker_id")).strip()
config = load_config()
if not validate_worker_id(worker_id, config):
return await handle_api_error(request, f"Worker {worker_id} not found", 404)
# Clear launching flag in managed processes
if worker_id in wm.processes:
if 'launching' in wm.processes[worker_id]:
del wm.processes[worker_id]['launching']
wm.save_processes()
debug_log(f"Cleared launching state for worker {worker_id}")
return web.json_response({"status": "success"})
except Exception as e:
return await handle_api_error(request, e, 500)
def get_network_ips():
"""Get all network IPs, trying multiple methods."""
ips = []
hostname = socket.gethostname()
# Method 1: Try socket.getaddrinfo
try:
addr_info = socket.getaddrinfo(hostname, None)
for info in addr_info:
ip = info[4][0]
if ip and ip not in ips and not ip.startswith('::'): # Skip IPv6 for now
ips.append(ip)
except (socket.gaierror, OSError):
pass
# Method 2: Try to connect to external server and get local IP
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80)) # Google DNS
local_ip = s.getsockname()[0]
s.close()
if local_ip not in ips:
ips.append(local_ip)
except (OSError, socket.error):
pass
# Method 3: Platform-specific commands
try:
if platform.system() == "Windows":
# Windows ipconfig
result = subprocess.run(["ipconfig"], capture_output=True, text=True)
lines = result.stdout.split('\n')
            for line in lines:
                # ipconfig prints the address on the same line, e.g.
                # "IPv4 Address. . . : 192.168.1.5".
                if 'IPv4' in line:
                    ip = line.split(':')[-1].strip()
                    if ip and ip not in ips:
                        ips.append(ip)
else:
# Unix/Linux/Mac ifconfig or ip addr
try:
result = subprocess.run(["ip", "addr"], capture_output=True, text=True)
except (FileNotFoundError, OSError):
try:
result = subprocess.run(["ifconfig"], capture_output=True, text=True)
except (FileNotFoundError, OSError):
result = None
import re
ip_pattern = re.compile(r'inet\s+(\d+\.\d+\.\d+\.\d+)')
if result is not None:
for match in ip_pattern.finditer(result.stdout):
ip = match.group(1)
if ip and ip not in ips:
ips.append(ip)
except (OSError, subprocess.SubprocessError):
pass
return ips
def get_recommended_ip(ips):
"""Choose the best IP for master-worker communication."""
# Priority order:
# 1. Private network ranges (192.168.x.x, 10.x.x.x, 172.16-31.x.x)
# 2. Other non-localhost IPs
# 3. Localhost as last resort
private_ips = []
public_ips = []
for ip in ips:
if ip.startswith('127.') or ip == 'localhost':
continue
elif (ip.startswith('192.168.')
or ip.startswith('10.')
or (ip.startswith('172.') and 16 <= int(ip.split('.')[1]) <= 31)):
private_ips.append(ip)
else:
public_ips.append(ip)
# Prefer private IPs
if private_ips:
# Prefer 192.168 range as it's most common
for ip in private_ips:
if ip.startswith('192.168.'):
return ip
return private_ips[0]
elif public_ips:
return public_ips[0]
elif ips:
return ips[0]
else:
return None
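# Hedged example: the preference order encoded above, with invented addresses
# (203.0.113.7 is a documentation-range public IP).
def _example_recommended_ip():
    """Illustrative only: 192.168.x wins over other private and public IPs."""
    ips = ["127.0.0.1", "203.0.113.7", "10.0.0.5", "192.168.1.42"]
    assert get_recommended_ip(ips) == "192.168.1.42"
    assert get_recommended_ip(["203.0.113.7"]) == "203.0.113.7"
    assert get_recommended_ip([]) is None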
def _get_cuda_info():
"""Detect CUDA device index and total physical GPU count.
Returns (cuda_device, cuda_device_count, physical_device_count).
All three are 0/None if CUDA is unavailable.
"""
if not torch.cuda.is_available():
return None, 0, 0
try:
cuda_device_count = torch.cuda.device_count()
cuda_visible = os.environ.get('CUDA_VISIBLE_DEVICES', '')
if cuda_visible and cuda_visible.strip():
visible_devices = [int(d.strip()) for d in cuda_visible.split(',') if d.strip().isdigit()]
if visible_devices:
cuda_device = visible_devices[0]
try:
result = subprocess.run(
['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'],
capture_output=True,
text=True,
timeout=5,
)
physical_device_count = (
len(result.stdout.strip().split('\n'))
if result.returncode == 0
else max(visible_devices) + 1
)
except (FileNotFoundError, OSError, subprocess.SubprocessError):
physical_device_count = max(visible_devices) + 1
return cuda_device, cuda_device_count, physical_device_count
else:
return 0, cuda_device_count, cuda_device_count
else:
cuda_device = torch.cuda.current_device()
return cuda_device, cuda_device_count, cuda_device_count
except Exception as e:
debug_log(f"CUDA detection error: {e}")
return None, 0, 0
def _collect_network_info_sync():
"""Collect network/cuda info in a worker thread to avoid blocking route handlers."""
cuda_device, cuda_device_count, physical_device_count = _get_cuda_info()
hostname = socket.gethostname()
all_ips = get_network_ips()
recommended_ip = get_recommended_ip(all_ips)
return {
"hostname": hostname,
"all_ips": all_ips,
"recommended_ip": recommended_ip,
"cuda_device": cuda_device,
"cuda_device_count": physical_device_count if physical_device_count > 0 else cuda_device_count,
}
def _read_worker_log_sync(log_file, lines_to_read):
"""Read worker log content from disk in a threadpool worker."""
file_size = os.path.getsize(log_file)
with open(log_file, 'r', encoding='utf-8', errors='replace') as f:
if lines_to_read > 0 and file_size > 1024 * 1024:
# Read last N lines efficiently from end of file.
lines = []
f.seek(0, 2)
file_length = f.tell()
chunk_size = min(CHUNK_SIZE, file_length)
            while len(lines) < lines_to_read and f.tell() > 0:
                read_end = f.tell()
                current_pos = max(0, read_end - chunk_size)
                f.seek(current_pos)
                # Read only up to the previous position so chunks never
                # overlap (overlapping reads could re-add lines already seen).
                chunk = f.read(read_end - current_pos)
                chunk_lines = chunk.splitlines()
                if current_pos > 0:
                    # The first line may be partial; its head lives in the
                    # preceding chunk, so drop it here.
                    chunk_lines = chunk_lines[1:]
                lines = chunk_lines + lines
                f.seek(current_pos)
content = '\n'.join(lines[-lines_to_read:])
truncated = len(lines) > lines_to_read
else:
content = f.read()
truncated = False
return {
"content": content,
"file_size": file_size,
"truncated": truncated,
"lines_shown": lines_to_read if truncated else content.count('\n') + 1,
}
def _parse_positive_int_query(value, default, minimum=1, maximum=10000):
"""Parse bounded positive integer query params with sane fallback."""
try:
parsed = int(value)
except (TypeError, ValueError):
return default
parsed = max(minimum, parsed)
if maximum is not None:
parsed = min(maximum, parsed)
return parsed
def _find_worker_by_id(config, worker_id):
worker_id_str = str(worker_id).strip()
for worker in config.get("workers", []):
if str(worker.get("id")).strip() == worker_id_str:
return worker
return None
@server.PromptServer.instance.routes.get("/distributed/local_log")
async def get_local_log_endpoint(request):
"""Return this instance's in-memory ComfyUI log buffer."""
try:
from app.logger import get_logs
except Exception as e:
return await handle_api_error(request, f"Failed to import app.logger: {e}", 500)
try:
lines_to_read = _parse_positive_int_query(request.query.get("lines"), default=300, maximum=3000)
logs = get_logs()
if logs is None:
return web.json_response(
{
"status": "success",
"content": "",
"entries": 0,
"source": "memory",
"truncated": False,
"lines_shown": 0,
}
)
entries = list(logs)
selected_entries = entries[-lines_to_read:]
content = "".join(
entry.get("m", "") if isinstance(entry, dict) else str(entry)
for entry in selected_entries
)
lines_shown = content.count("\n") + (1 if content else 0)
return web.json_response(
{
"status": "success",
"content": content,
"entries": len(selected_entries),
"source": "memory",
"truncated": len(entries) > len(selected_entries),
"lines_shown": lines_shown,
}
)
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.get("/distributed/network_info")
async def get_network_info_endpoint(request):
"""Get network interfaces and recommend best IP for master."""
try:
loop = asyncio.get_running_loop()
info = await loop.run_in_executor(None, _collect_network_info_sync)
return web.json_response({
"status": "success",
**info,
"message": "Auto-detected network configuration"
})
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.get("/distributed/system_info")
async def get_system_info_endpoint(request):
"""Get system information including machine ID for local worker detection."""
try:
return web.json_response({
"status": "success",
"hostname": socket.gethostname(),
"machine_id": get_machine_id(),
"platform": {
"system": platform.system(),
"machine": platform.machine(),
"node": platform.node(),
"path_separator": os.sep, # Add path separator
"os_name": os.name # Add OS name (posix, nt, etc.)
},
"is_docker": is_docker_environment(),
"is_runpod": is_runpod_environment(),
"runpod_pod_id": os.environ.get('RUNPOD_POD_ID')
})
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.post("/distributed/launch_worker")
async def launch_worker_endpoint(request):
"""Launch a worker process from the UI."""
try:
wm = get_worker_manager()
data = await request.json()
missing = require_fields(data, "worker_id")
if missing:
return await handle_api_error(request, f"Missing required field(s): {', '.join(missing)}", 400)
worker_id = str(data.get("worker_id")).strip()
# Find worker config
config = load_config()
if not validate_worker_id(worker_id, config):
return await handle_api_error(request, f"Worker {worker_id} not found", 404)
worker = next((w for w in config.get("workers", []) if str(w.get("id")) == worker_id), None)
if not worker:
return await handle_api_error(request, f"Worker {worker_id} not found", 404)
# Ensure consistent string ID
worker_id_str = worker_id
# Check if already running (managed by this instance)
if worker_id_str in wm.processes:
proc_info = wm.processes[worker_id_str]
process = proc_info.get('process')
# Check if still running
is_running = False
if process:
is_running = process.poll() is None
else:
# Restored process without subprocess object
is_running = wm._is_process_running(proc_info['pid'])
if is_running:
return await handle_api_error(request, "Worker already running (managed by UI)", 409)
else:
# Process is dead, remove it
del wm.processes[worker_id_str]
wm.save_processes()
# Launch the worker
try:
loop = asyncio.get_running_loop()
pid = await loop.run_in_executor(None, wm.launch_worker, worker)
log_file = wm.processes[worker_id_str].get('log_file')
return web.json_response({
"status": "success",
"pid": pid,
"message": f"Worker {worker['name']} launched",
"log_file": log_file
})
except Exception as e:
return await handle_api_error(request, f"Failed to launch worker: {str(e)}", 500)
except Exception as e:
return await handle_api_error(request, e, 400)
@server.PromptServer.instance.routes.post("/distributed/stop_worker")
async def stop_worker_endpoint(request):
"""Stop a worker process that was launched from the UI."""
try:
wm = get_worker_manager()
data = await request.json()
missing = require_fields(data, "worker_id")
if missing:
return await handle_api_error(request, f"Missing required field(s): {', '.join(missing)}", 400)
worker_id = str(data.get("worker_id")).strip()
config = load_config()
if not validate_worker_id(worker_id, config):
return await handle_api_error(request, f"Worker {worker_id} not found", 404)
success, message = wm.stop_worker(worker_id)
if success:
return web.json_response({"status": "success", "message": message})
else:
return await handle_api_error(
request,
message,
404 if "not managed" in message else 409,
)
except Exception as e:
return await handle_api_error(request, e, 400)
@server.PromptServer.instance.routes.get("/distributed/managed_workers")
async def get_managed_workers_endpoint(request):
"""Get list of workers managed by this UI instance."""
try:
managed = get_worker_manager().get_managed_workers()
return web.json_response({
"status": "success",
"managed_workers": managed
})
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.get("/distributed/local-worker-status")
async def get_local_worker_status_endpoint(request):
"""Check status of all local workers (localhost/no host specified)."""
try:
config = load_config()
worker_statuses = {}
for worker in config.get("workers", []):
# Only check local workers
host = normalize_host(worker.get("host")) or ""
if not host or host in ["localhost", "127.0.0.1"]:
worker_id = worker["id"]
port = worker["port"]
# Check if worker is enabled
if not worker.get("enabled", False):
worker_statuses[worker_id] = {
"online": False,
"enabled": False,
"processing": False,
"queue_count": 0
}
continue
# Try to connect to worker
try:
worker_url = build_worker_url(worker)
data = await probe_worker(worker_url, timeout=2.0)
if data is None:
worker_statuses[worker_id] = {
"online": False,
"enabled": True,
"processing": False,
"queue_count": 0,
"error": "Unavailable",
}
continue
queue_remaining = data.get("exec_info", {}).get("queue_remaining", 0)
worker_statuses[worker_id] = {
"online": True,
"enabled": True,
"processing": queue_remaining > 0,
"queue_count": queue_remaining
}
except asyncio.TimeoutError:
worker_statuses[worker_id] = {
"online": False,
"enabled": True,
"processing": False,
"queue_count": 0,
"error": "Timeout"
}
except Exception as e:
worker_statuses[worker_id] = {
"online": False,
"enabled": True,
"processing": False,
"queue_count": 0,
"error": str(e)
}
return web.json_response({
"status": "success",
"worker_statuses": worker_statuses
})
except Exception as e:
debug_log(f"Error checking local worker status: {e}")
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.get("/distributed/worker_log/{worker_id}")
async def get_worker_log_endpoint(request):
"""Get log content for a specific worker."""
try:
wm = get_worker_manager()
worker_id = request.match_info['worker_id']
# Ensure worker_id is string
worker_id = str(worker_id)
# Check if we manage this worker
if worker_id not in wm.processes:
return await handle_api_error(request, f"Worker {worker_id} not managed by UI", 404)
proc_info = wm.processes[worker_id]
log_file = proc_info.get('log_file')
if not log_file or not os.path.exists(log_file):
return await handle_api_error(request, "Log file not found", 404)
# Read last N lines (or full file if small)
lines_to_read = _parse_positive_int_query(request.query.get('lines'), default=1000)
try:
loop = asyncio.get_running_loop()
payload = await loop.run_in_executor(None, _read_worker_log_sync, log_file, lines_to_read)
return web.json_response({
"status": "success",
"content": payload["content"],
"log_file": log_file,
"file_size": payload["file_size"],
"truncated": payload["truncated"],
"lines_shown": payload["lines_shown"],
})
except Exception as e:
return await handle_api_error(request, f"Error reading log file: {str(e)}", 500)
except Exception as e:
return await handle_api_error(request, e, 500)
@server.PromptServer.instance.routes.get("/distributed/remote_worker_log/{worker_id}")
async def get_remote_worker_log_endpoint(request):
"""Proxy a remote worker log request to the worker's local in-memory log endpoint."""
try:
worker_id = str(request.match_info["worker_id"]).strip()
config = load_config()
worker = _find_worker_by_id(config, worker_id)
if not worker:
return await handle_api_error(request, f"Worker {worker_id} not found", 404)
# Remote log proxy is only meaningful for remote/cloud workers.
host = normalize_host(worker.get("host")) or ""
if not host:
return await handle_api_error(
request,
f"Worker {worker_id} is local; use /distributed/worker_log/{worker_id} instead.",
400,
)
lines_to_read = _parse_positive_int_query(request.query.get("lines"), default=300, maximum=3000)
worker_url = build_worker_url(worker, "/distributed/local_log")
session = await get_client_session()
async with session.get(
worker_url,
params={"lines": str(lines_to_read)},
timeout=aiohttp.ClientTimeout(total=5),
) as resp:
if resp.status >= 400:
body = await resp.text()
return await handle_api_error(
request,
f"Remote worker {worker_id} returned HTTP {resp.status}: {body[:400]}",
resp.status,
)
try:
data = await resp.json()
except Exception as e:
return await handle_api_error(
request,
f"Remote worker {worker_id} returned invalid JSON: {e}",
502,
)
return web.json_response(data)
except Exception as e:
return await handle_api_error(request, e, 500)
================================================
FILE: conftest.py
================================================
# conftest.py — project-level pytest configuration.
#
# Problem: custom_nodes/ComfyUI-Distributed/__init__.py uses relative imports
# (from .distributed import ...) that fail when pytest tries to import it as a
# standalone module during Package.setup() for the root package node.
#
# Fix: patch Package.setup() to skip the root-package's __init__.py import.
# All actual package context is provided by each test module via
# importlib.util.spec_from_file_location with synthetic stub packages.
from _pytest.python import Package
_orig_pkg_setup = Package.setup
def _patched_pkg_setup(self) -> None:
# Skip the root package setup — its __init__.py uses relative imports
# that require a parent package (ComfyUI's plugin loader) which is not
# available in the test environment.
if self.path == self.config.rootpath:
return
_orig_pkg_setup(self)
Package.setup = _patched_pkg_setup
collect_ignore = [
"__init__.py",
"distributed.py",
"distributed_upscale.py",
]
================================================
FILE: distributed.py
================================================
"""
ComfyUI-Distributed: thin entry point.
All implementation lives in workers/, nodes/, api/.
"""
import atexit
import os
import server
from .utils.config import ensure_config_exists
from .utils.logging import debug_log
from .utils.network import cleanup_client_session
from .workers import get_worker_manager
from .workers.startup import delayed_auto_launch, register_async_signals, sync_cleanup
from .upscale.job_store import ensure_tile_jobs_initialized
from .nodes import (
NODE_CLASS_MAPPINGS,
NODE_DISPLAY_NAME_MAPPINGS,
ImageBatchDivider,
DistributedCollectorNode,
DistributedSeed,
DistributedModelName,
DistributedValue,
AudioBatchDivider,
DistributedEmptyImage,
AnyType,
ByPassTypeTuple,
any_type,
)
from . import api # noqa: F401 - triggers all @routes.* registrations
from .api.queue_orchestration import ensure_distributed_state
ensure_config_exists()
# Aiohttp session cleanup
async def _cleanup_session():
await cleanup_client_session()
atexit.register(lambda: None) # placeholder; real cleanup in sync_cleanup
# Initialize distributed job state on prompt_server
prompt_server = server.PromptServer.instance
ensure_distributed_state(prompt_server)
ensure_tile_jobs_initialized()
# Worker startup
if not os.environ.get('COMFYUI_IS_WORKER'):
atexit.register(sync_cleanup)
delayed_auto_launch()
register_async_signals()
================================================
FILE: docs/comfyui-distributed-api.md
================================================
# ComfyUI-Distributed API (Experimental)
This document describes the **public HTTP API** added to ComfyUI-Distributed to allow queueing *distributed* workflows from external tools (scripts, services, CI jobs, render farms, etc.) without using the ComfyUI web UI.
## Demo
- Video walkthrough: https://youtu.be/yiQlPd0MzLk
## Examples Repository
- Examples repo: https://github.com/umanets/ComfyUI-Distributed-API-examples.git
---
## Overview
### What this adds
- `POST /distributed/queue` — queues a workflow using the same distributed orchestration rules as the UI:
- Detects distributed nodes in the prompt (`DistributedCollector`, `UltimateSDUpscaleDistributed`).
- Resolves enabled/selected workers.
- Pings workers (`GET /prompt`) to include only reachable ones.
- Dispatches the workflow to workers (`POST /prompt`).
- Queues the master workflow in ComfyUI’s prompt queue.
- If any `DistributedCollector` has `load_balance=true`, selects one least-busy participant for this run.
### What it does *not* add
- Authentication/authorization.
- A separate “job status” API for distributed results (you still use ComfyUI’s normal prompt history / websocket flow, and the existing `/distributed/queue_status/{job_id}` behavior for collector queues).
---
## Endpoint: `POST /distributed/queue`
Queue a workflow for distributed execution.
### URL
- `http://<master-host>:<master-port>/distributed/queue`
### Headers
- `Content-Type: application/json`
### Request Body
```json
{
"prompt": { "": { "class_type": "...", "inputs": { } } },
"workflow": { },
"client_id": "external-client",
"delegate_master": false,
"enabled_worker_ids": ["1", "2"],
"workers": ["1", "2"],
"auto_prepare": true,
"trace_execution_id": "exec_1700000000_ab12cd"
}
```
#### Fields
- `prompt` (object; required unless `workflow.prompt` is present)
- The ComfyUI prompt/workflow graph, same shape as used by `POST /prompt`.
- `workflow` (optional, object)
- Workflow metadata that ComfyUI normally stores in `extra_pnginfo.workflow`.
- If you don’t care about UI metadata, you can omit it.
- `client_id` (required, string)
- Passed through as `extra_data.client_id` (useful if you consume ComfyUI websocket events).
- `delegate_master` (optional, boolean)
- If `true`, attempts “workers-only” execution for workflows based on `DistributedCollector`.
- Current limitation: delegate-only mode **does not support** `UltimateSDUpscaleDistributed` and will fall back to running the full prompt on master.
- `enabled_worker_ids` (required, array of strings)
- The explicit worker IDs to consider for this run.
- `workers` (optional, array of strings or objects with `id`)
- Transitional alias for `enabled_worker_ids` used by older clients.
- `auto_prepare` (optional, boolean)
- Kept for wire compatibility.
- Backend orchestration always runs with auto-prepare semantics.
- If the top-level `prompt` is omitted, the backend falls back to `workflow.prompt`.
- `trace_execution_id` (optional, string)
- Passed through to orchestration logs.
- Server log lines include the marker as `[exec:<trace_execution_id>]`.
##### How to get `enabled_worker_ids`
Worker IDs come from the plugin config (`GET /distributed/config`) under `workers[].id`.
Example (bash + `jq`):
```bash
curl -s "http://127.0.0.1:8188/distributed/config" \
| jq -r '.workers[] | "id=\(.id)\tname=\(.name)\tenabled=\(.enabled)\thost=\(.host)\tport=\(.port)\ttype=\(.type)"'
```
Example (PowerShell):
```powershell
$cfg = Invoke-RestMethod "http://127.0.0.1:8188/distributed/config"
$cfg.workers | Select-Object id,name,enabled,host,port,type | Format-Table -AutoSize
```
### Response Body
```json
{
"prompt_id": "",
"worker_count": 2,
"auto_prepare_supported": true
}
```
- `prompt_id` — the master prompt id queued into ComfyUI.
- `worker_count` — number of workers that received a dispatched prompt (only those that passed the health check).
- `auto_prepare_supported` — indicates the backend runs with auto-prepare semantics (kept for wire compatibility; see `auto_prepare` above).
### Status Codes
- `200` — queued.
- `400` — invalid JSON or invalid body.
- `500` — orchestration/dispatch failure (see server logs for details).
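Since there is no separate distributed job-status API, clients typically follow up by polling ComfyUI's standard history endpoint with the returned `prompt_id`. A minimal sketch, assuming a local master on the default port and a placeholder workflow graph:
```python
import time
import requests

MASTER = "http://127.0.0.1:8188"  # assumption: local master on the default port

payload = {
    "prompt": {"1": {"class_type": "KSampler", "inputs": {}}},  # placeholder graph
    "enabled_worker_ids": ["1", "2"],
    "client_id": "external-client",
}

# Queue the distributed run.
resp = requests.post(f"{MASTER}/distributed/queue", json=payload, timeout=60)
resp.raise_for_status()
prompt_id = resp.json()["prompt_id"]

# Poll ComfyUI's normal history endpoint until the master prompt completes.
while True:
    history = requests.get(f"{MASTER}/history/{prompt_id}", timeout=10).json()
    if prompt_id in history:
        break
    time.sleep(2)
print(history[prompt_id].get("outputs", {}))
```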
---
## Worker requirements (important)
For a worker to participate, it must be reachable from the master:
- Health check: `GET /prompt` must return HTTP 200.
- Dispatch: `POST /prompt` must accept the workflow.
Also, for collector-based flows:
- Workers will send results back to the master via `POST /distributed/job_complete` (that route must be reachable from workers).
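To verify these requirements from the master machine, you can issue the same probe the orchestrator uses. A minimal sketch (the worker address is a placeholder):
```python
import requests

WORKER = "http://192.168.1.50:8188"  # placeholder: your worker's address

# The orchestrator only dispatches to workers whose GET /prompt returns HTTP 200.
r = requests.get(f"{WORKER}/prompt", timeout=5)
print(r.status_code, r.json())  # e.g. 200 {"exec_info": {"queue_remaining": 0}}
```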
---
## Endpoint: `POST /distributed/job_complete`
Submit one completed worker image back to the master collector queue.
### URL
- `http://<master-host>:<master-port>/distributed/job_complete`
### Request Body
```json
{
"job_id": "exec_1234567890_17",
"worker_id": "worker-1",
"batch_idx": 0,
"image": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA...",
"is_last": false
}
```
### Canonical envelope (required fields)
- `job_id` (string, required)
- `worker_id` (string, required)
- `batch_idx` (integer >= 0, required)
- `image` (string, required)
- PNG payload as either:
- data URL: `data:image/png;base64,...`
- raw base64 PNG bytes
- `is_last` (boolean, required)
Legacy multipart/tensor payload formats are no longer accepted on this endpoint.
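For reference, a worker-side sketch that builds and submits one canonical envelope (the master address, job id, worker id, and PNG file are all placeholders):
```python
import base64
import requests

MASTER = "http://127.0.0.1:8188"  # placeholder: master address

with open("result.png", "rb") as f:  # placeholder: an image produced by the worker
    encoded = base64.b64encode(f.read()).decode("utf-8")

envelope = {
    "job_id": "exec_1234567890_17",
    "worker_id": "worker-1",
    "batch_idx": 0,
    "image": f"data:image/png;base64,{encoded}",
    "is_last": True,  # this is the worker's final image for the job
}
r = requests.post(f"{MASTER}/distributed/job_complete", json=envelope, timeout=60)
r.raise_for_status()
```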
### CORS note
If you call the API from a browser (not from a backend), ensure the master ComfyUI is started with `--enable-cors-header`.
---
## Log Endpoints
### `GET /distributed/worker_log/{worker_id}`
Read log files for workers launched locally by the master UI process manager.
- Intended for managed local workers.
- Query param: `lines` (optional, default `1000`).
### `GET /distributed/local_log`
Read this ComfyUI instance's in-memory runtime log buffer.
- Available on any ComfyUI-Distributed instance (master or worker).
- Query param: `lines` (optional, default `300`, max `3000`).
### `GET /distributed/remote_worker_log/{worker_id}`
Proxy endpoint on master that fetches logs from a configured remote/cloud worker's
`/distributed/local_log`.
- Intended for remote/cloud workers in master config.
- Query param: `lines` (optional, default `300`, max `3000`).
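A minimal sketch that reads both kinds of logs through the master (the master address and worker ids are placeholders):
```python
import requests

MASTER = "http://127.0.0.1:8188"  # placeholder: master address

# Tail a managed local worker's log file (worker id "1").
local = requests.get(f"{MASTER}/distributed/worker_log/1",
                     params={"lines": 200}, timeout=10)
print(local.json().get("content", ""))

# Tail a remote/cloud worker's in-memory log (worker id "2"), proxied by the master.
remote = requests.get(f"{MASTER}/distributed/remote_worker_log/2",
                      params={"lines": 200}, timeout=10)
print(remote.json())
```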
---
## Examples
### 1) Minimal `curl`
```bash
curl -X POST "http://127.0.0.1:8188/distributed/queue" \
-H "Content-Type: application/json" \
-d @payload.json
```
Where `payload.json` contains at least:
```json
{
"prompt": {
"1": {"class_type": "KSampler", "inputs": {} }
},
"enabled_worker_ids": [],
"client_id": "external-client"
}
```
### 2) Python (`requests`)
```python
import requests
url = "http://127.0.0.1:8188/distributed/queue"
payload = {
"prompt": {...},
"workflow": {...},
"client_id": "external-client",
"delegate_master": False,
"enabled_worker_ids": ["1", "2"],
}
r = requests.post(url, json=payload, timeout=60)
r.raise_for_status()
print(r.json())
```
### 3) JavaScript (`fetch`)
```js
const url = "http://127.0.0.1:8188/distributed/queue";
const payload = {
prompt: {/* ... */},
workflow: {/* ... */},
client_id: "external-client",
delegate_master: false,
enabled_worker_ids: ["1", "2"],
};
const resp = await fetch(url, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(payload),
});
if (!resp.ok) throw new Error(await resp.text());
console.log(await resp.json());
```
---
## Operational notes / gotchas
- If the workflow contains **no distributed nodes**, the endpoint falls back to normal master queueing and returns `worker_count: 0`.
- Worker selection is “best-effort”: offline workers are skipped.
- For public URLs/tunnels: prefer configuring `master.host` with an explicit scheme (`https://...`) to avoid ambiguity.
---
## Changelog (this feature)
- Added `POST /distributed/queue` endpoint.
- Added orchestration module used by the endpoint.
================================================
FILE: docs/model-download-script.md
================================================
## Automating ComfyUI Model Downloads
> This guide will walk you through creating a shell script to automatically download the necessary models for your ComfyUI workflow, leveraging an advanced Large Language Model (LLM).
1. In ComfyUI (on your local machine), export your workflow as an API workflow
2. Copy the below prompt and upload the API workflow to an LLM **that has access to the internet**
The full prompt:
```
Create a sh script that will download the models from this workflow into the correct folders. For reference, these are the paths:
base_path: /workspace/ComfyUI
checkpoints: models/checkpoints/
clip: models/clip/
clip_vision: models/clip_vision/
controlnet: models/controlnet/
diffusion_models: models/diffusion_models/
embeddings: models/embeddings/
florence2: models/florence2/
ipadapter: models/ipadapter/
loras: models/loras/
style_models: models/style_models/
text_encoders: models/text_encoders/
unet: models/unet/
upscale_models: models/upscale_models/
vae: models/vae/
---
Important:
Make sure you find the correct URLs for the models online.
Use the comfy CLI to download the models: `comfy model download --url <url> [--relative-path <path>] [--set-civitai-api-token <token>] [--set-hf-api-token <token>]`
Make sure you add `--set-civitai-api-token $CIVITAI_API_TOKEN` for CivitAI download and `--set-hf-api-token $HF_API_TOKEN` for Hugging Face downloads.
---
Example:
#!/bin/bash
# Download from CivitAI
comfy model download --url https://civitai.com/api/download/models/1759168 --relative-path /workspace/ComfyUI/models/checkpoints --set-civitai-api-token $CIVITAI_API_TOKEN
# Download model from Hugging Face
comfy model download --url https://huggingface.co/black-forest-labs/FLUX.1-dev/resolve/main/flux1-dev.safetensors --relative-path /workspace/ComfyUI/models/unet --set-hf-api-token $HF_API_TOKEN
# If a model in the workflow was in a subfolder
comfy model download --url https://civitai.com/api/download/models/1759168 --relative-path /workspace/ComfyUI/models/checkpoints/SDXL --set-civitai-api-token $CIVITAI_API_TOKEN
```
3. Review the LLM's output to make sure all download links are correct, then save it as a `.sh` file, for example `download_models.sh`
4. Launch the [ComfyUI Distributed Pod](https://console.runpod.io/deploy?template=m21ynvo8yo&ref=ak218p52) with these Environment Variables:
- `CIVITAI_API_TOKEN`: [get your token here](https://civitai.com/user/account)
- `HF_API_TOKEN`: [get your token here](https://huggingface.co/settings/tokens)
5. Upload the .sh file to your Runpod instance, into `/workspace`
6. Then run these commands:
- `chmod 755 /workspace/download_models.sh`
- `/workspace/download_models.sh`
7. Confirm each model name (sometimes you might need to rename them to match the name on your local machine)
================================================
FILE: docs/video-upscaler-runpod-preset.md
================================================

**Accelerated Creative Video Upscaler On Runpod:**
1. Use the [ComfyUI Distributed Pod](https://console.runpod.io/deploy?template=m21ynvo8yo&ref=0bw29uf3ug0p) template.
2. Filter instances by CUDA 12.8 (add filter in Additional Filters at the top of the page).
3. Choose 4x RTX 5090s.
4. Press Edit Template to configure the pod's Environment Variables:
- CIVITAI_API_TOKEN: Not necessary for this workflow.
- HF_API_TOKEN: [get your token here](https://huggingface.co/settings/tokens)
- SAGE_ATTENTION: optional optimisation (set to true/false). Recommended for this workflow.
- PRESET_VIDEO_UPSCALER: set to true. This will download everything you need.
5. Deploy your pod.
6. Once pod setup is complete, connect to ComfyUI running on your pod.
7. In ComfyUI, open the GPU panel on the left.
> If you set SAGE_ATTENTION to true, add "--use-sage-attention" to Extra Args on the workers.
8. Launch the workers.
9. [Load the workflow.](https://github.com/robertvoy/ComfyUI-Distributed/blob/main/workflows/distributed-upscale-video.json)
10. Upload video, add prompt and run workflow.
11. Right-click the Video Combine node and click Save Preview to save the video.
================================================
FILE: docs/worker-setup-guides.md
================================================
## Worker Setup Guide
**Master**: The main ComfyUI instance that coordinates and distributes work. This is where you load workflows, manage the queue, and view results.
**Worker**: A ComfyUI instance that receives and processes tasks from the master. Workers handle just the GPU computation and send results back to the master. You can have multiple workers connected to a single master, each utilizing their own GPU.
### Master participation modes
The master can either contribute GPU work or stay in **orchestrator-only** mode:
- **Participating**: Master renders alongside workers, useful when you want every available GPU.
- **Orchestrator-only**: Master sends jobs to selected workers but skips local rendering. Enable this by opening the Distributed panel and unchecking the master toggle. The master card will display *“Master disabled: running as orchestrator only.”*
- **Fallback**: If orchestrator-only is enabled but no workers remain selected, the master automatically re-enables execution to guarantee the workflow still runs. The UI shows a green *“Master fallback execution active”* badge so you know work is executing locally again.
### Types of Workers
- **Local workers**: Additional GPUs on the same machine as the master
- **Remote workers**: GPUs on different computers within your network
- **Cloud workers**: GPUs hosted on cloud services like Runpod
## Local workers
> These are added automatically on first launch, but you can add them manually if you need to.
📺 [Watch Tutorial](https://youtu.be/p6eE3IlAbOs?si=K7Km0_flmPHwRQwz&t=43)
1. **Open** the Distributed GPU panel.
2. **Click** "Add Worker" in the UI.
3. **Configure** your local worker:
- **Name**: A descriptive name for the worker (e.g., "Studio PC 1")
- **Port**: A unique port number for this worker (e.g., 8189, 8190...).
- **CUDA Device**: The GPU index from `nvidia-smi` (e.g., 0, 1).
- **Extra Args**: Optional ComfyUI arguments for this specific worker.
4. **Save** and launch the local worker.
## Remote workers
> ComfyUI instances running on completely different computers on your network. These allow you to harness GPU power from other machines. Remote workers must be manually started on their respective computers and are connected via IP address.
📺 [Watch Tutorial](https://youtu.be/p6eE3IlAbOs?si=Oxj3EzPyf4jKDvfG&t=140)
**On the Remote Worker Machine:**
1. **Launch** ComfyUI with the `--listen --enable-cors-header` arguments. ⚠️ **Required!**
- This ComfyUI instance will serve as a worker for your main master.
2. *Optionally* add additional local workers on this machine if it has multiple GPUs:
- Access the Distributed GPU panel in this ComfyUI instance
- Add workers for any additional GPUs (if they haven't been added automatically)
- Make sure they have `--listen` set in `Extra Args`
- Launch them
3. **Open** the ComfyUI port (e.g., 8188) and any additional worker ports (e.g., 8189, 8190) in the firewall.
**On the Main Machine:**
1. **Launch** ComfyUI with `--enable-cors-header` launch argument.
2. **Open** the Distributed GPU panel (sidebar on the left).
3. **Click** "Add Worker."
4. **Choose** "Remote".
5. **Configure** your remote worker:
- **Name**: A descriptive name for the worker (e.g., "Server Rack GPU 0")
- **Host**: The remote worker's IP address.
- **Port**: The port ComfyUI was launched with on the remote machine (e.g., 8188).
6. **Save** the remote worker configuration.
## Cloud workers
> ComfyUI instances running on a cloud service like Runpod.
### Deploy Cloud Worker on Runpod
📺 [Watch Tutorial](https://www.youtube.com/watch?v=wxKKWMQhYTk)
**On Runpod:**
> If using your own template, make sure you launch ComfyUI with the `--enable-cors-header` argument and you `git clone ComfyUI-Distributed` into custom_nodes. ⚠️ **Required!**
1. Register a [Runpod](https://get.runpod.io/0bw29uf3ug0p) account.
2. On Runpod, go to Storage > New Network Volume and create a volume that will store the models you need. Start with 40 GB; you can always add more later. Learn more [about Network Volumes](https://docs.runpod.io/pods/storage/create-network-volumes).
3. Use the [ComfyUI Distributed Pod](https://console.runpod.io/deploy?template=m21ynvo8yo&ref=0bw29uf3ug0p) template.
4. Make sure your Network Volume is mounted and choose a suitable GPU.
> ⚠️ To use the ComfyUI Distributed Pod template, you will need to filter instances by CUDA 12.8 (add filter in Additional Filters).
5. Press Edit Template to configure the pod's Environment Variables:
- CIVITAI_API_TOKEN: [get your token here](https://civitai.com/user/account)
- HF_API_TOKEN: [get your token here](https://huggingface.co/settings/tokens)
- SAGE_ATTENTION: optional optimisation (set to true/false)
6. Deploy your pod.
7. Connect to your pod using JupyterLab. This gives you access to the pod's file system.
8. Download models into /workspace/ComfyUI/models/ (these will remain on your network drive even after you terminate the pod). Example commands below:
```
# Download from CivitAI
comfy model download --url https://civitai.com/api/download/models/1759168 --relative-path /workspace/ComfyUI/models/checkpoints --set-civitai-api-token $CIVITAI_API_TOKEN
# Download model from Hugging Face
comfy model download --url https://huggingface.co/black-forest-labs/FLUX.1-dev/resolve/main/flux1-dev.safetensors --relative-path /workspace/ComfyUI/models/unet --set-hf-api-token $HF_API_TOKEN
```
> ℹ️ Use [this guide](model-download-script.md) to make this process easy. It will generate a shell script that automatically downloads the models for a given workflow.
9. Access ComfyUI through the Runpod URL.
10. Download any additional custom nodes you need using the ComfyUI Manager.
**On the Main Machine:**
1. **Launch** a Cloudflare tunnel.
- Download from here: [https://github.com/cloudflare/cloudflared/releases](https://github.com/cloudflare/cloudflared/releases)
- Then run, for example: `cloudflared-windows-amd64.exe tunnel --url http://localhost:8188`
> ℹ️ Cloudflare tunnels create secure connections without exposing ports directly to the internet and are required for Cloud Workers.
2. **Copy** the Cloudflare address
3. **Launch** ComfyUI with `--enable-cors-header` launch argument.
4. **Open** the Distributed GPU panel (sidebar on the left).
5. **Edit** the Master's settings to change the host address to the Cloudflare address.
6. **Click** "Add Worker."
7. **Choose** "Cloud".
8. **Configure** your cloud worker:
- **Host**: The ComfyUI Runpod address. For example: `wcegfo9tbbml9l-8188.proxy.runpod.net`
- **Port**: 443
9. **Save** the remote worker configuration.
---
### Deploy Cloud Worker on Other Platforms
**On the Cloud Worker machine:**
- Your cloud worker container needs to have the same models and custom nodes as the workflow you want to run on your local machine.
- If your cloud platform doesn't provide a secure connection, use Cloudflare to create a tunnel for the worker. Each GPU needs its own tunnel on its respective port.
- For example: `./cloudflared tunnel --url http://localhost:8188`
1. **Launch** ComfyUI with the `--listen --enable-cors-header` arguments. ⚠️ **Required!**
2. **Add** workers in the UI panel if the cloud machine has more than one GPU.
- Make sure that they also have `--listen` set in `Extra Args`.
- Then launch them.
**On the Main Machine:**
1. **Launch** a Cloudflare tunnel on your local machine.
- Download from here: [https://github.com/cloudflare/cloudflared/releases](https://github.com/cloudflare/cloudflared/releases)
- Then run, for example: `cloudflared-windows-amd64.exe tunnel --url http://localhost:8188`
2. **Copy** the Cloudflare address
3. **Launch** ComfyUI with `--enable-cors-header` launch argument.
4. **Open** the Distributed GPU panel (sidebar on the left).
5. **Edit** the Master's host address and replace it with the Cloudflare address.
6. **Click** "Add Worker."
7. **Choose** "Cloud".
8. **Configure** your cloud worker:
- **Host**: The cloud worker's domain or tunnel address
- **Port**: 443
9. **Save** the remote worker configuration.
================================================
FILE: nodes/__init__.py
================================================
from .utilities import (
DistributedSeed,
DistributedModelName,
DistributedValue,
ImageBatchDivider,
AudioBatchDivider,
DistributedEmptyImage,
AnyType,
ByPassTypeTuple,
any_type,
)
from .collector import DistributedCollectorNode
NODE_CLASS_MAPPINGS = {
"DistributedCollector": DistributedCollectorNode,
"DistributedSeed": DistributedSeed,
"DistributedModelName": DistributedModelName,
"DistributedValue": DistributedValue,
"ImageBatchDivider": ImageBatchDivider,
"AudioBatchDivider": AudioBatchDivider,
"DistributedEmptyImage": DistributedEmptyImage,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DistributedCollector": "Distributed Collector",
"DistributedSeed": "Distributed Seed",
"DistributedModelName": "Distributed Model Name",
"DistributedValue": "Distributed Value",
"ImageBatchDivider": "Image Batch Divider",
"AudioBatchDivider": "Audio Batch Divider",
"DistributedEmptyImage": "Distributed Empty Image",
}
================================================
FILE: nodes/collector.py
================================================
import torch
import io
import json
import asyncio
import time
import base64
import aiohttp
import server as _server
import comfy.model_management
from comfy.utils import ProgressBar
from ..utils.logging import debug_log, log
from ..utils.config import get_worker_timeout_seconds, load_config, is_master_delegate_only
from ..utils.constants import HEARTBEAT_INTERVAL
from ..utils.image import tensor_to_pil, pil_to_tensor, ensure_contiguous
from ..utils.network import build_worker_url, get_client_session, probe_worker
from ..utils.audio_payload import encode_audio_payload
from ..utils.async_helpers import run_async_in_server_loop
prompt_server = _server.PromptServer.instance
class DistributedCollectorNode:
EMPTY_AUDIO = {"waveform": torch.zeros(1, 2, 1), "sample_rate": 44100}
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"images": ("IMAGE",),
"load_balance": (
"BOOLEAN",
{
"default": False,
"tooltip": "Run this workflow on one least-busy participant (master included when participating).",
},
),
},
"optional": { "audio": ("AUDIO",) },
"hidden": {
"multi_job_id": ("STRING", {"default": ""}),
"is_worker": ("BOOLEAN", {"default": False}),
"master_url": ("STRING", {"default": ""}),
"enabled_worker_ids": ("STRING", {"default": "[]"}),
"worker_batch_size": ("INT", {"default": 1, "min": 1, "max": 1024}),
"worker_id": ("STRING", {"default": ""}),
"pass_through": ("BOOLEAN", {"default": False}),
"delegate_only": ("BOOLEAN", {"default": False}),
},
}
RETURN_TYPES = ("IMAGE", "AUDIO")
RETURN_NAMES = ("images", "audio")
FUNCTION = "run"
CATEGORY = "image"
def run(self, images, load_balance=False, audio=None, multi_job_id="", is_worker=False, master_url="", enabled_worker_ids="[]", worker_batch_size=1, worker_id="", pass_through=False, delegate_only=False):
# Create empty audio if not provided
empty_audio = {"waveform": torch.zeros(1, 2, 1), "sample_rate": 44100}
if not multi_job_id or pass_through:
if pass_through:
debug_log("Collector: pass-through mode enabled, returning images unchanged")
return (images, audio if audio is not None else empty_audio)
# Use async helper to run in server loop
result = run_async_in_server_loop(
self.execute(
images,
audio,
load_balance,
multi_job_id,
is_worker,
master_url,
enabled_worker_ids,
worker_batch_size,
worker_id,
delegate_only,
)
)
return result
async def send_batch_to_master(self, image_batch, audio, multi_job_id, master_url, worker_id):
"""Send image batch to master via canonical JSON envelopes."""
batch_size = image_batch.shape[0]
if batch_size == 0:
return
encoded_audio = encode_audio_payload(audio)
session = await get_client_session()
url = f"{master_url}/distributed/job_complete"
for batch_idx in range(batch_size):
img = tensor_to_pil(image_batch[batch_idx:batch_idx+1], 0)
byte_io = io.BytesIO()
img.save(byte_io, format='PNG', compress_level=0)
encoded_image = base64.b64encode(byte_io.getvalue()).decode('utf-8')
payload = {
"job_id": str(multi_job_id),
"worker_id": str(worker_id),
"batch_idx": int(batch_idx),
"image": f"data:image/png;base64,{encoded_image}",
"is_last": bool(batch_idx == batch_size - 1),
}
if payload["is_last"] and encoded_audio is not None:
payload["audio"] = encoded_audio
try:
async with session.post(
url,
json=payload,
timeout=aiohttp.ClientTimeout(total=60),
) as response:
response.raise_for_status()
except Exception as e:
log(f"Worker - Failed to send canonical image envelope to master: {e}")
debug_log(f"Worker - Full error details: URL={url}")
raise # Re-raise to handle at caller level
def _combine_audio(self, master_audio, worker_audio, empty_audio, worker_order=None):
"""Combine audio from master and workers into a single audio output.
Ordering: master first, then workers in `worker_order` (if provided),
then any unexpected worker ids in sorted order.
"""
audio_pieces = []
sample_rate = 44100
# Add master audio first if present
if master_audio is not None:
waveform = master_audio.get("waveform")
if waveform is not None and waveform.numel() > 0:
audio_pieces.append(waveform)
sample_rate = master_audio.get("sample_rate", 44100)
# Add worker audio in configured enabled-worker order first.
ordered_worker_ids = [str(worker_id) for worker_id in (worker_order or [])]
seen = set()
for worker_id_str in ordered_worker_ids:
seen.add(worker_id_str)
w_audio = worker_audio.get(worker_id_str)
if w_audio is not None:
waveform = w_audio.get("waveform")
if waveform is not None and waveform.numel() > 0:
audio_pieces.append(waveform)
# Use first available sample rate
if sample_rate == 44100:
sample_rate = w_audio.get("sample_rate", 44100)
# Append any audio from unexpected worker ids deterministically.
for worker_id_str in sorted(worker_audio.keys()):
if worker_id_str in seen:
continue
w_audio = worker_audio[worker_id_str]
if w_audio is not None:
waveform = w_audio.get("waveform")
if waveform is not None and waveform.numel() > 0:
audio_pieces.append(waveform)
if sample_rate == 44100:
sample_rate = w_audio.get("sample_rate", 44100)
if not audio_pieces:
return empty_audio
try:
# Concatenate along the samples dimension (dim=-1)
# Ensure all pieces have same batch and channel dimensions
combined_waveform = torch.cat(audio_pieces, dim=-1)
debug_log(f"Master - Combined audio: {len(audio_pieces)} pieces, final shape={combined_waveform.shape}")
return {"waveform": combined_waveform, "sample_rate": sample_rate}
except Exception as e:
log(f"[Distributed] Master - Audio combination failed, returning silence: {e}")
return empty_audio
def _store_worker_result(self, worker_images: dict, item: dict) -> int:
"""Store one canonical queue item in worker_images in-place.
Canonical format:
- item has 'worker_id', 'image_index', and 'tensor'
Returns 1 when stored, otherwise 0.
"""
worker_id = item['worker_id']
tensor = item.get('tensor')
image_index = item.get('image_index')
if tensor is None or image_index is None:
return 0
worker_images.setdefault(worker_id, {})
worker_images[worker_id][image_index] = tensor
return 1
def _reorder_and_combine_tensors(
self,
worker_images: dict,
worker_order: list,
master_batch_size: int,
images_on_cpu,
delegate_mode: bool,
fallback_images,
) -> torch.Tensor:
"""Assemble final tensor: master first, then workers in enabled order."""
ordered_tensors = []
if not delegate_mode and images_on_cpu is not None:
for i in range(master_batch_size):
ordered_tensors.append(images_on_cpu[i:i+1])
ordered_worker_ids = [str(worker_id) for worker_id in (worker_order or [])]
seen = set()
for worker_id_str in ordered_worker_ids:
seen.add(worker_id_str)
if worker_id_str not in worker_images:
continue
for idx in sorted(worker_images[worker_id_str].keys()):
ordered_tensors.append(worker_images[worker_id_str][idx])
# Append any unexpected worker ids deterministically.
for worker_id_str in sorted(worker_images.keys()):
if worker_id_str in seen:
continue
for idx in sorted(worker_images[worker_id_str].keys()):
ordered_tensors.append(worker_images[worker_id_str][idx])
cpu_tensors = []
for t in ordered_tensors:
if t.is_cuda:
t = t.cpu()
t = ensure_contiguous(t)
cpu_tensors.append(t)
if cpu_tensors:
return ensure_contiguous(torch.cat(cpu_tensors, dim=0))
elif fallback_images is not None:
return ensure_contiguous(fallback_images)
else:
raise ValueError("No image data collected from master or workers")
async def execute(self, images, audio, load_balance=False, multi_job_id="", is_worker=False, master_url="", enabled_worker_ids="[]", worker_batch_size=1, worker_id="", delegate_only=False):
if is_worker:
# Worker mode: send images and audio to master in a single batch
debug_log(f"Worker - Job {multi_job_id} complete. Sending {images.shape[0]} image(s) to master")
await self.send_batch_to_master(images, audio, multi_job_id, master_url, worker_id)
return (images, audio if audio is not None else self.EMPTY_AUDIO)
else:
delegate_mode = delegate_only or is_master_delegate_only()
# Master mode: collect images and audio from workers
enabled_workers_raw = json.loads(enabled_worker_ids)
enabled_workers = []
seen_enabled = set()
for worker_id in enabled_workers_raw:
worker_id_str = str(worker_id)
if worker_id_str in seen_enabled:
continue
seen_enabled.add(worker_id_str)
enabled_workers.append(worker_id_str)
expected_workers = set(enabled_workers)
num_workers = len(expected_workers)
if num_workers == 0:
return (images, audio if audio is not None else self.EMPTY_AUDIO)
# Create the queue before any expensive local work to avoid job_complete race.
async with prompt_server.distributed_jobs_lock:
if multi_job_id not in prompt_server.distributed_pending_jobs:
prompt_server.distributed_pending_jobs[multi_job_id] = asyncio.Queue()
debug_log(f"Master - Initialized queue early for job {multi_job_id}")
else:
existing_size = prompt_server.distributed_pending_jobs[multi_job_id].qsize()
debug_log(f"Master - Using existing queue for job {multi_job_id} (current size: {existing_size})")
if delegate_mode:
master_batch_size = 0
images_on_cpu = None
master_audio = None
debug_log(f"Master - Job {multi_job_id}: Delegate-only mode enabled, collecting exclusively from {num_workers} workers")
else:
images_on_cpu = images.cpu()
master_batch_size = images.shape[0]
master_audio = audio # Keep master's audio for later
debug_log(f"Master - Job {multi_job_id}: Master has {master_batch_size} images, collecting from {num_workers} workers...")
# Ensure master images are contiguous
images_on_cpu = ensure_contiguous(images_on_cpu)
# Initialize storage for collected images and audio
worker_images = {} # Dict to store images by worker_id and index
worker_audio = {} # Dict to store audio by worker_id
# Collect images until all workers report they're done
collected_count = 0
workers_done = set()
# Use unified worker timeout from config/UI with simple sliced waits
base_timeout = float(get_worker_timeout_seconds())
slice_timeout = min(max(0.1, HEARTBEAT_INTERVAL / 20.0), base_timeout)
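# Illustrative values: HEARTBEAT_INTERVAL = 10 gives 0.5s poll slices, never exceeding base_timeout.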
last_activity = time.time()
# Get queue size before starting
async with prompt_server.distributed_jobs_lock:
q = prompt_server.distributed_pending_jobs[multi_job_id]
initial_size = q.qsize()
# NEW: Initialize progress bar for workers (total = num_workers)
p = ProgressBar(num_workers)
def mark_worker_done(done_worker_id):
done_worker_id = str(done_worker_id)
if done_worker_id not in expected_workers:
debug_log(
f"Master - Ignoring completion from unexpected worker {done_worker_id} for job {multi_job_id}"
)
return
if done_worker_id in workers_done:
debug_log(
f"Master - Ignoring duplicate completion from worker {done_worker_id} for job {multi_job_id}"
)
return
workers_done.add(done_worker_id)
p.update(1) # +1 per completed expected worker
try:
while len(workers_done) < num_workers:
# Check for user interruption to abort collection promptly
comfy.model_management.throw_exception_if_processing_interrupted()
try:
# Get the queue again each time to ensure we have the right reference
async with prompt_server.distributed_jobs_lock:
q = prompt_server.distributed_pending_jobs[multi_job_id]
current_size = q.qsize()
result = await asyncio.wait_for(q.get(), timeout=slice_timeout)
worker_id = result['worker_id']
is_last = result.get('is_last', False)
count = self._store_worker_result(worker_images, result)
collected_count += count
debug_log(
f"Master - Got canonical result from worker {worker_id}, "
f"image {result.get('image_index', 0)}, is_last={is_last}"
)
# Collect audio data if present
result_audio = result.get('audio')
if result_audio is not None:
worker_audio[worker_id] = result_audio
debug_log(f"Master - Got audio from worker {worker_id}")
# Record activity and refresh timeout baseline
last_activity = time.time()
base_timeout = float(get_worker_timeout_seconds())
if is_last:
mark_worker_done(worker_id)
except asyncio.TimeoutError:
# If we still have time, continue polling; otherwise handle timeout
if (time.time() - last_activity) < base_timeout:
comfy.model_management.throw_exception_if_processing_interrupted()
continue
# Re-check for user interruption after timeout expiry
comfy.model_management.throw_exception_if_processing_interrupted()
missing_workers = set(str(w) for w in enabled_workers) - workers_done
elapsed = time.time() - last_activity
for missing_worker_id in sorted(missing_workers):
log(
"Master - Heartbeat timeout: "
f"worker={missing_worker_id}, elapsed={elapsed:.1f}s"
)
log(
f"Master - Heartbeat timeout. Still waiting for workers: {list(missing_workers)} "
f"(elapsed={elapsed:.1f}s)"
)
# Probe missing workers' /prompt endpoints to check if they are actively processing
any_busy = False
try:
cfg = load_config()
cfg_workers = cfg.get('workers', [])
for wid in list(missing_workers):
wrec = next((w for w in cfg_workers if str(w.get('id')) == str(wid)), None)
if not wrec:
debug_log(f"Collector probe: worker {wid} not found in config")
continue
worker_url = build_worker_url(wrec)
try:
payload = await probe_worker(worker_url, timeout=2.0)
queue_remaining = None
if payload is not None:
queue_remaining = int(payload.get('exec_info', {}).get('queue_remaining', 0))
debug_log(
"Collector probe: worker "
f"{wid} online={payload is not None} queue_remaining={queue_remaining}"
)
if payload is not None and queue_remaining and queue_remaining > 0:
any_busy = True
log(
f"Master - Probe grace: worker {wid} appears busy "
f"(queue_remaining={queue_remaining}). Continuing to wait."
)
break
except Exception as e:
debug_log(f"Collector probe failed for worker {wid}: {e}")
except Exception as e:
debug_log(f"Collector probe setup error: {e}")
if any_busy:
# Refresh last_activity and continue waiting
last_activity = time.time()
# Refresh base timeout in case the user changed it in UI
base_timeout = float(get_worker_timeout_seconds())
continue
# Check queue size again with lock
async with prompt_server.distributed_jobs_lock:
if multi_job_id in prompt_server.distributed_pending_jobs:
final_q = prompt_server.distributed_pending_jobs[multi_job_id]
final_size = final_q.qsize()
# Try to drain any remaining items
remaining_items = []
while not final_q.empty():
try:
item = final_q.get_nowait()
remaining_items.append(item)
except asyncio.QueueEmpty:
break
if remaining_items:
# Process them
for item in remaining_items:
worker_id = item['worker_id']
is_last = item.get('is_last', False)
collected_count += self._store_worker_result(worker_images, item)
if is_last:
mark_worker_done(worker_id)
else:
log(f"Master - Queue {multi_job_id} no longer exists!")
break
except comfy.model_management.InterruptProcessingException:
# Cleanup queue on interruption and re-raise to abort prompt cleanly
async with prompt_server.distributed_jobs_lock:
if multi_job_id in prompt_server.distributed_pending_jobs:
del prompt_server.distributed_pending_jobs[multi_job_id]
raise
total_collected = sum(len(imgs) for imgs in worker_images.values())
# Clean up job queue
async with prompt_server.distributed_jobs_lock:
if multi_job_id in prompt_server.distributed_pending_jobs:
del prompt_server.distributed_pending_jobs[multi_job_id]
try:
combined = self._reorder_and_combine_tensors(
worker_images, enabled_workers, master_batch_size, images_on_cpu, delegate_mode, images
)
debug_log(f"Master - Job {multi_job_id} complete. Combined {combined.shape[0]} images total "
f"(master: {master_batch_size}, workers: {combined.shape[0] - master_batch_size})")
# Combine audio from master and workers
combined_audio = self._combine_audio(master_audio, worker_audio, self.EMPTY_AUDIO, enabled_workers)
return (combined, combined_audio)
except Exception as e:
log(f"Master - Error combining images: {e}")
# Return just the master images as fallback
return (images, audio if audio is not None else self.EMPTY_AUDIO)
================================================
FILE: nodes/distributed_upscale.py
================================================
import json
import math
from functools import wraps
import comfy.samplers
from ..utils.logging import debug_log, log
from ..utils.async_helpers import run_async_in_server_loop
from ..upscale.job_store import ensure_tile_jobs_initialized
from ..upscale.tile_ops import TileOpsMixin
from ..upscale.result_collector import ResultCollectorMixin
from ..upscale.worker_comms import WorkerCommsMixin
from ..upscale.job_state import JobStateMixin
from ..upscale.modes.single_gpu import SingleGpuModeMixin
from ..upscale.modes.static import StaticModeMixin
from ..upscale.modes.dynamic import DynamicModeMixin
def sync_wrapper(async_func):
"""Decorator to wrap async methods for synchronous execution."""
@wraps(async_func)
def sync_func(self, *args, **kwargs):
# Use run_async_in_server_loop for ComfyUI compatibility
return run_async_in_server_loop(
async_func(self, *args, **kwargs),
timeout=600.0 # 10 minute timeout for long operations
)
return sync_func
def _parse_enabled_worker_ids(enabled_worker_ids):
"""Parse enabled worker IDs from either JSON or list input."""
if isinstance(enabled_worker_ids, list):
return [str(worker_id) for worker_id in enabled_worker_ids]
if not enabled_worker_ids:
return []
if isinstance(enabled_worker_ids, str):
try:
parsed = json.loads(enabled_worker_ids)
except json.JSONDecodeError:
log("USDU Dist: Invalid enabled_worker_ids JSON; defaulting to no workers.")
return []
if isinstance(parsed, list):
return [str(wid) for wid in parsed]
return []
class UltimateSDUpscaleDistributed(
DynamicModeMixin,
StaticModeMixin,
SingleGpuModeMixin,
ResultCollectorMixin,
WorkerCommsMixin,
JobStateMixin,
TileOpsMixin,
):
"""
Distributed version of Ultimate SD Upscale (No Upscale).
Supports three processing modes:
1. Single GPU: No workers available, process everything locally
2. Static Mode: Small batches, distributes tiles across workers (flattened)
3. Dynamic Mode: Large batches, assigns whole images to workers dynamically
Features:
- Multi-mode batch handling for efficient video/image upscaling
- Tiled VAE support for memory efficiency
- Dynamic load balancing for large batches
- Backward compatible with single-image workflows
Environment Variables:
- COMFYUI_MAX_BATCH: Chunk size for tile sending (default 20)
- COMFYUI_MAX_PAYLOAD_SIZE: Max API payload bytes (default 50MB)
Threshold: dynamic_threshold input controls mode switch (default 8)
"""
def __init__(self):
"""Initialize the node and ensure persistent storage exists."""
# Pre-initialize the persistent storage on node creation
ensure_tile_jobs_initialized()
debug_log("UltimateSDUpscaleDistributed - Node initialized")
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"upscaled_image": ("IMAGE",),
"model": ("MODEL",),
"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"vae": ("VAE",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
"sampler_name": (comfy.samplers.KSampler.SAMPLERS,),
"scheduler": (comfy.samplers.KSampler.SCHEDULERS,),
"denoise": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"tile_width": ("INT", {"default": 512, "min": 64, "max": 2048, "step": 8}),
"tile_height": ("INT", {"default": 512, "min": 64, "max": 2048, "step": 8}),
"padding": ("INT", {"default": 32, "min": 0, "max": 256, "step": 8}),
"mask_blur": ("INT", {"default": 8, "min": 0, "max": 256}),
"force_uniform_tiles": ("BOOLEAN", {"default": True}),
"tiled_decode": ("BOOLEAN", {"default": False}),
},
"hidden": {
"multi_job_id": ("STRING", {"default": ""}),
"is_worker": ("BOOLEAN", {"default": False}),
"master_url": ("STRING", {"default": ""}),
"enabled_worker_ids": ("STRING", {"default": "[]"}),
"worker_id": ("STRING", {"default": ""}),
"tile_indices": ("STRING", {"default": ""}), # Unused - kept for compatibility
"dynamic_threshold": ("INT", {"default": 8, "min": 1, "max": 64}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "run"
CATEGORY = "image/upscaling"
@classmethod
def IS_CHANGED(cls, **kwargs):
"""Force re-execution."""
return float("nan") # Always re-execute
def run(self, upscaled_image, model, positive, negative, vae, seed, steps, cfg,
sampler_name, scheduler, denoise, tile_width, tile_height, padding,
mask_blur, force_uniform_tiles, tiled_decode,
multi_job_id="", is_worker=False, master_url="", enabled_worker_ids="[]",
worker_id="", tile_indices="", dynamic_threshold=8):
"""Entry point - runs SYNCHRONOUSLY like Ultimate SD Upscaler."""
# Strict WAN/FLOW batching: error if batch is not 4n+1 (except allow 1)
try:
batch_size = int(getattr(upscaled_image, 'shape', [1])[0])
except Exception:
batch_size = 1
# Enforce 4n+1 batches globally for any model when batch > 1 (master only)
if not is_worker and batch_size != 1 and (batch_size % 4 != 1):
raise ValueError(
f"Batch size {batch_size} is not of the form 4n+1. "
"This node requires batch sizes of 1 or 4n+1 (1, 5, 9, 13, ...). "
"Please adjust the batch size."
)
if not multi_job_id:
# No distributed processing, run single GPU version
return self.process_single_gpu(upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur, force_uniform_tiles, tiled_decode)
if is_worker:
# Worker mode: process tiles synchronously
return self.process_worker(upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur,
force_uniform_tiles, tiled_decode, multi_job_id, master_url,
worker_id, enabled_worker_ids, dynamic_threshold)
else:
# Master mode: distribute and collect synchronously
return self.process_master(upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur,
force_uniform_tiles, tiled_decode, multi_job_id, enabled_worker_ids,
dynamic_threshold)
def process_worker(self, upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur,
force_uniform_tiles, tiled_decode, multi_job_id, master_url,
worker_id, enabled_worker_ids, dynamic_threshold):
"""Unified worker processing - handles both static and dynamic modes."""
# Get batch size to determine mode
batch_size = upscaled_image.shape[0]
# Ensure mode consistency across master/workers via shared threshold
# Determine mode (must match master's logic)
enabled_workers = json.loads(enabled_worker_ids)
num_workers = len(enabled_workers)
# Compute number of tiles for this image to decide if tile distribution makes sense
_, height, width, _ = upscaled_image.shape
all_tiles = self.calculate_tiles(width, height, self.round_to_multiple(tile_width), self.round_to_multiple(tile_height), force_uniform_tiles)
num_tiles_per_image = len(all_tiles)
mode = self._determine_processing_mode(batch_size, num_workers, dynamic_threshold)
# For USDU-style processing, we want tile distribution whenever workers are available
# and there is more than one tile to process, even if batch == 1.
if num_workers > 0 and num_tiles_per_image > 1:
mode = "static"
debug_log(f"USDU Dist Worker - Batch size {batch_size}")
if mode == "dynamic":
return self.process_worker_dynamic(upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur,
force_uniform_tiles, tiled_decode, multi_job_id, master_url,
worker_id, enabled_worker_ids, dynamic_threshold)
# Static mode - enhanced with health monitoring and retry logic
return self._process_worker_static_sync(upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur,
force_uniform_tiles, tiled_decode, multi_job_id, master_url,
worker_id, enabled_workers)
def process_master(self, upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur,
force_uniform_tiles, tiled_decode, multi_job_id, enabled_worker_ids,
dynamic_threshold):
"""Unified master processing with enhanced monitoring and failure handling."""
# Round tile dimensions
tile_width = self.round_to_multiple(tile_width)
tile_height = self.round_to_multiple(tile_height)
# Get image dimensions and batch size
batch_size, height, width, _ = upscaled_image.shape
# Calculate all tiles and grid
all_tiles = self.calculate_tiles(width, height, tile_width, tile_height, force_uniform_tiles)
num_tiles_per_image = len(all_tiles)
rows = math.ceil(height / tile_height)
cols = math.ceil(width / tile_width)
log(
f"USDU Dist: Canvas {width}x{height} | Tile {tile_width}x{tile_height} | Grid {rows}x{cols} ({num_tiles_per_image} tiles/image) | Batch {batch_size}"
)
# Parse enabled workers
enabled_workers = json.loads(enabled_worker_ids)
num_workers = len(enabled_workers)
# Determine processing mode
mode = self._determine_processing_mode(batch_size, num_workers, dynamic_threshold)
# Prefer tile-based static distribution when workers are available and there are multiple tiles,
# even for batch == 1, to spread tiles across GPUs like the legacy dynamic tile queue.
if num_workers > 0 and num_tiles_per_image > 1:
mode = "static"
log(f"USDU Dist: Workers {num_workers} | Mode {mode} | Threshold {dynamic_threshold}")
if mode == "single_gpu":
# No workers, process all tiles locally
return self.process_single_gpu(upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur, force_uniform_tiles, tiled_decode)
elif mode == "dynamic":
# Dynamic mode for large batches
return self.process_master_dynamic(upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur,
force_uniform_tiles, tiled_decode, multi_job_id, enabled_workers)
# Static mode - enhanced with unified job management
return self._process_master_static_sync(upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur,
force_uniform_tiles, tiled_decode, multi_job_id, enabled_workers,
all_tiles, num_tiles_per_image)
def _determine_processing_mode(self, batch_size: int, num_workers: int, dynamic_threshold: int) -> str:
"""Determines processing mode per requested policy:
- any workers => prefer static (tile-based) for USDU
- no workers => single_gpu
"""
if num_workers == 0:
return "single_gpu"
# Default to static when distributed; master/worker may still override if special cases arise
return "static"
# Ensure initialization before registering routes
ensure_tile_jobs_initialized()
# Node registration
NODE_CLASS_MAPPINGS = {
"UltimateSDUpscaleDistributed": UltimateSDUpscaleDistributed,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"UltimateSDUpscaleDistributed": "Ultimate SD Upscale Distributed (No Upscale)",
}
================================================
FILE: nodes/utilities.py
================================================
import torch
import json
from ..utils.logging import debug_log, log
def _chunk_bounds(total_items: int, n_splits: int) -> list[tuple[int, int]]:
"""Return contiguous [start, end) bounds for n_splits chunks."""
split_count = max(1, int(n_splits))
total = max(0, int(total_items))
base, remainder = divmod(total, split_count)
bounds: list[tuple[int, int]] = []
start = 0
for idx in range(split_count):
size = base + (1 if idx < remainder else 0)
end = start + size
bounds.append((start, end))
start = end
return bounds
class DistributedSeed:
"""
Distributes seed values across multiple GPUs.
On master: passes through the original seed.
On workers: adds offset based on worker ID.
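Example: worker_0 returns seed + 1, worker_1 returns seed + 2, and so on.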
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"seed": ("INT", {
"default": 1125899906842,
"min": 0,
"max": 1125899906842624,
"forceInput": False # Widget by default, can be converted to input
}),
},
"hidden": {
"is_worker": ("BOOLEAN", {"default": False}),
"worker_id": ("STRING", {"default": ""}),
},
}
RETURN_TYPES = ("INT",)
RETURN_NAMES = ("seed",)
FUNCTION = "distribute"
CATEGORY = "utils"
def distribute(self, seed, is_worker=False, worker_id=""):
if not is_worker:
# Master node: pass through original values
debug_log(f"Distributor - Master: seed={seed}")
return (seed,)
else:
# Worker node: apply offset based on worker index
# Find worker index from enabled_worker_ids
try:
# Worker IDs are passed as "worker_0", "worker_1", etc.
if worker_id.startswith("worker_"):
worker_index = int(worker_id.split("_")[1])
else:
# Fallback: try to parse as direct index
worker_index = int(worker_id)
offset = worker_index + 1
new_seed = seed + offset
debug_log(f"Distributor - Worker {worker_index}: seed={seed} → {new_seed}")
return (new_seed,)
except (ValueError, IndexError) as e:
debug_log(f"Distributor - Error parsing worker_id '{worker_id}': {e}")
# Fallback: return original seed
return (seed,)
# Define ByPassTypeTuple for flexible return types
class AnyType(str):
def __ne__(self, __value: object) -> bool:
return False
any_type = AnyType("*")
class DistributedValue:
"""
Outputs a different value per worker.
On master: returns default_value.
On workers: looks up the worker-specific value from a JSON map,
falling back to default_value if not set.
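Example worker_values: {"_type": "INT", "1": "42", "2": "43"} gives worker_0 the value 42 and worker_1 the value 43 (keys are 1-indexed).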
"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"default_value": ("STRING", {"default": ""}),
"worker_values": ("STRING", {"default": "{}"}),
},
"hidden": {
"is_worker": ("BOOLEAN", {"default": False}),
"worker_id": ("STRING", {"default": ""}),
},
}
RETURN_TYPES = (any_type,)
RETURN_NAMES = ("value",)
FUNCTION = "distribute"
CATEGORY = "utils"
@staticmethod
def _coerce(value, value_type):
"""Convert a string value to the requested type."""
if value_type == "INT":
return int(float(value))
if value_type == "FLOAT":
return float(value)
return value # STRING and COMBO stay as strings
@staticmethod
def _coerce_safe(value, value_type):
"""Best-effort coercion with graceful fallback to original value."""
try:
return DistributedValue._coerce(value, value_type)
except (TypeError, ValueError):
return value
def distribute(self, default_value, worker_values="{}", is_worker=False, worker_id=""):
values = {}
value_type = "STRING"
try:
values = json.loads(worker_values) if isinstance(worker_values, str) else worker_values
if not isinstance(values, dict):
values = {}
except json.JSONDecodeError as e:
debug_log(f"DistributedValue - Error parsing worker_values: {e}")
values = {}
value_type = values.get("_type", "STRING")
coerced_default = self._coerce_safe(default_value, value_type)
if not is_worker:
debug_log(f"DistributedValue - Master: returning default '{coerced_default}'")
return (coerced_default,)
try:
if worker_id.startswith("worker_"):
idx = int(worker_id.split("_")[1])
else:
idx = int(worker_id)
key = str(idx + 1) # worker_0 → key "1" (1-indexed)
raw = values.get(key, "")
if raw:
coerced = self._coerce(raw, value_type)
debug_log(f"DistributedValue - Worker {idx}: returning '{coerced}'")
return (coerced,)
except (ValueError, IndexError) as e:
debug_log(f"DistributedValue - Error: {e}")
debug_log(f"DistributedValue - Worker fallback: returning default '{coerced_default}'")
return (coerced_default,)
class DistributedModelName:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": ("STRING", {"default": ""}),
},
"hidden": {
"unique_id": "UNIQUE_ID",
"extra_pnginfo": "EXTRA_PNGINFO",
},
}
RETURN_TYPES = (any_type,)
RETURN_NAMES = ("output",)
FUNCTION = "log_input"
OUTPUT_NODE = True
CATEGORY = "utils"
def _stringify(self, value):
if isinstance(value, str):
return value
if isinstance(value, (int, float, bool)):
return str(value)
try:
return json.dumps(value, indent=4)
except Exception:
return str(value)
def _update_workflow(self, extra_pnginfo, unique_id, values):
if not extra_pnginfo:
return
info = extra_pnginfo[0] if isinstance(extra_pnginfo, list) else extra_pnginfo
if not isinstance(info, dict) or "workflow" not in info:
return
node_id = None
if isinstance(unique_id, list) and unique_id:
node_id = str(unique_id[0])
elif unique_id is not None:
node_id = str(unique_id)
if not node_id:
return
workflow = info["workflow"]
node = next((x for x in workflow["nodes"] if str(x.get("id")) == node_id), None)
if node:
node["widgets_values"] = [values]
def log_input(self, text, unique_id=None, extra_pnginfo=None):
values = []
if isinstance(text, list):
for val in text:
values.append(self._stringify(val))
else:
values.append(self._stringify(text))
# Keep widget display in workflow metadata if available.
self._update_workflow(extra_pnginfo, unique_id, values)
        if len(values) == 1:
return {"ui": {"text": values}, "result": (values[0],)}
return {"ui": {"text": values}, "result": (values,)}
class ByPassTypeTuple(tuple):
def __getitem__(self, index):
if index > 0:
index = 0
item = super().__getitem__(index)
if isinstance(item, str):
return any_type
return item
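# Illustrative sketch only (hypothetical helper): positive indices clamp to
# slot 0, and string entries come back as the wildcard AnyType, which never
# reports inequality against another socket type.
def _demo_bypass_type_tuple():  # pragma: no cover - documentation sketch
    types_tuple = ByPassTypeTuple(("IMAGE",))
    assert types_tuple[7] is any_type
    assert not (types_tuple[7] != "LATENT")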
class ImageBatchDivider:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"images": ("IMAGE",),
"divide_by": ("INT", {
"default": 2,
"min": 1,
"max": 10,
"step": 1,
"display": "number",
"tooltip": "Number of parts to divide the batch into"
}),
}
}
RETURN_TYPES = ByPassTypeTuple(("IMAGE", )) # Flexible for variable outputs
RETURN_NAMES = ByPassTypeTuple(tuple([f"batch_{i+1}" for i in range(10)]))
FUNCTION = "divide_batch"
OUTPUT_NODE = True
CATEGORY = "image"
def divide_batch(self, images, divide_by):
total_splits = max(1, min(int(divide_by), 10))
total_frames = images.shape[0]
empty_tensor = images[:0]
bounds = _chunk_bounds(total_frames, total_splits)
outputs = [images[start:end] if end > start else empty_tensor for start, end in bounds]
while len(outputs) < 10:
outputs.append(empty_tensor)
return tuple(outputs[:10])
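# Illustrative sketch only (hypothetical helper, assumes torch is installed):
# a 10-frame batch split three ways yields contiguous 4/3/3 chunks, and the
# remaining outputs are padded with zero-length batches.
def _demo_image_batch_divider():  # pragma: no cover - documentation sketch
    import torch
    images = torch.zeros(10, 8, 8, 3)
    parts = ImageBatchDivider().divide_batch(images, 3)
    assert [p.shape[0] for p in parts[:3]] == [4, 3, 3]
    assert all(p.shape[0] == 0 for p in parts[3:])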
class AudioBatchDivider:
"""Divides an audio waveform into multiple parts along the time/samples dimension."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"audio": ("AUDIO",),
"divide_by": ("INT", {
"default": 2,
"min": 1,
"max": 10,
"step": 1,
"display": "number",
"tooltip": "Number of parts to divide the audio into"
}),
}
}
RETURN_TYPES = ByPassTypeTuple(("AUDIO",)) # Flexible for variable outputs
RETURN_NAMES = ByPassTypeTuple(tuple([f"audio_{i+1}" for i in range(10)]))
FUNCTION = "divide_audio"
OUTPUT_NODE = True
CATEGORY = "audio"
def divide_audio(self, audio, divide_by):
import torch
waveform = audio.get("waveform")
sample_rate = audio.get("sample_rate", 44100)
if waveform is None or waveform.numel() == 0:
            # No samples to split: return a minimal silent waveform for every output
empty_audio = {"waveform": torch.zeros(1, 2, 1), "sample_rate": sample_rate}
return tuple([empty_audio] * 10)
total_splits = max(1, min(int(divide_by), 10))
total_samples = int(waveform.shape[-1])
bounds = _chunk_bounds(total_samples, total_splits)
outputs = []
empty_waveform = waveform[..., :0]
for start, end in bounds:
split_waveform = waveform[..., start:end] if end > start else empty_waveform
outputs.append({
"waveform": split_waveform,
"sample_rate": sample_rate
})
# Pad with empty audio up to max (10) to match RETURN_TYPES length
empty_audio = {
"waveform": empty_waveform,
"sample_rate": sample_rate
}
while len(outputs) < 10:
outputs.append(empty_audio)
return tuple(outputs)
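# Illustrative sketch only (hypothetical helper, assumes torch is installed):
# audio is split along the trailing samples dimension, so 8 samples over two
# outputs gives 4 + 4, with zero-length waveforms in the unused slots.
def _demo_audio_batch_divider():  # pragma: no cover - documentation sketch
    import torch
    audio = {"waveform": torch.zeros(1, 2, 8), "sample_rate": 24000}
    parts = AudioBatchDivider().divide_audio(audio, 2)
    assert parts[0]["waveform"].shape[-1] == 4
    assert parts[9]["waveform"].shape[-1] == 0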
class DistributedEmptyImage:
"""Produces an empty IMAGE batch used when the master delegates all work."""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"height": ("INT", {"default": 64, "min": 1, "max": 4096, "step": 1}),
"width": ("INT", {"default": 64, "min": 1, "max": 4096, "step": 1}),
"channels": ("INT", {"default": 3, "min": 1, "max": 4, "step": 1}),
}
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "create"
CATEGORY = "image"
def create(self, height, width, channels):
import torch
shape = (0, height, width, channels)
tensor = torch.zeros(shape, dtype=torch.float32)
return (tensor,)
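# Illustrative sketch only (hypothetical helper, assumes torch is installed):
# the node emits a zero-length IMAGE batch shaped [0, H, W, C], giving
# downstream collectors a valid empty tensor when the master delegates work.
def _demo_distributed_empty_image():  # pragma: no cover - documentation sketch
    (tensor,) = DistributedEmptyImage().create(64, 64, 3)
    assert tuple(tensor.shape) == (0, 64, 64, 3)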
================================================
FILE: package.json
================================================
{
"name": "comfyui-distributed-web-tests",
"private": true,
"type": "module",
"scripts": {
"test:web": "bash ./scripts/test-web.sh",
"test:web:watch": "bash ./scripts/test-web.sh --watch"
},
"devDependencies": {
"vitest": "^2.1.9"
}
}
================================================
FILE: pyproject.toml
================================================
[project]
name = "ComfyUI-Distributed"
description = "ComfyUI extension that enables multi-GPU processing locally, remotely and in the cloud"
version = "1.4.4"
license = {file = "LICENSE"}
dependencies = []
[project.urls]
Repository = "https://github.com/robertvoy/ComfyUI-Distributed"
# Used by Comfy Registry https://comfyregistry.org
[tool.comfy]
PublisherId = "robertvoy"
DisplayName = "ComfyUI-Distributed"
Icon = "https://raw.githubusercontent.com/robertvoy/ComfyUI-Distributed/refs/heads/main/web/distributed-logo-icon.png"
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["."]
================================================
FILE: scripts/test-web.sh
================================================
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd -- "${SCRIPT_DIR}/.." && pwd)"
export NVM_DIR="${NVM_DIR:-$HOME/.nvm}"
if [[ -s "${NVM_DIR}/nvm.sh" ]]; then
# shellcheck source=/dev/null
. "${NVM_DIR}/nvm.sh"
fi
if ! command -v node >/dev/null 2>&1 || ! command -v npm >/dev/null 2>&1; then
echo "[test-web] node/npm are not available." >&2
echo "[test-web] Install nvm and Node, or ensure node/npm are on PATH." >&2
exit 1
fi
if [[ -f "${REPO_ROOT}/.nvmrc" ]] && command -v nvm >/dev/null 2>&1; then
nvm use >/dev/null
fi
cd "${REPO_ROOT}"
if [[ "${1:-}" == "--watch" ]]; then
exec npx vitest web/tests
fi
exec npx vitest run web/tests
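# Example invocations (matching the npm scripts in package.json):
#   npm run test:web          # single run: `npx vitest run web/tests`
#   npm run test:web:watch    # watch mode: `npx vitest web/tests`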
================================================
FILE: tests/api/test_config_routes.py
================================================
import copy
import importlib.util
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import patch
class _FakeResponse:
def __init__(self, payload, status=200):
self.payload = payload
self.status = status
class _FakeRequest:
def __init__(self, payload=None):
self._payload = payload
async def json(self):
return self._payload
def _load_config_routes_module():
module_path = Path(__file__).resolve().parents[2] / "api" / "config_routes.py"
package_name = "dist_api_config_testpkg"
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
api_pkg = types.ModuleType(f"{package_name}.api")
api_pkg.__path__ = []
sys.modules[f"{package_name}.api"] = api_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
created_aiohttp_stub = False
if "aiohttp" not in sys.modules:
created_aiohttp_stub = True
aiohttp_module = types.ModuleType("aiohttp")
aiohttp_module.web = types.SimpleNamespace(
json_response=lambda payload, status=200: _FakeResponse(payload, status=status)
)
sys.modules["aiohttp"] = aiohttp_module
class _Routes:
def get(self, _path):
def _decorator(fn):
return fn
return _decorator
def post(self, _path):
def _decorator(fn):
return fn
return _decorator
server_module = types.ModuleType("server")
server_module.PromptServer = types.SimpleNamespace(instance=types.SimpleNamespace(routes=_Routes()))
sys.modules["server"] = server_module
logging_module = types.ModuleType(f"{package_name}.utils.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
logging_module.log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.logging"] = logging_module
network_module = types.ModuleType(f"{package_name}.utils.network")
async def _handle_api_error(_request, error, status=500):
return _FakeResponse({"status": "error", "message": str(error)}, status=status)
network_module.handle_api_error = _handle_api_error
network_module.normalize_host = lambda value: value
sys.modules[f"{package_name}.utils.network"] = network_module
default_config = {
"workers": [],
"master": {"host": ""},
"settings": {"debug": False},
"tunnel": {},
}
config_module = types.ModuleType(f"{package_name}.utils.config")
config_module.load_config = lambda: copy.deepcopy(default_config)
config_module.save_config = lambda _cfg: True
sys.modules[f"{package_name}.utils.config"] = config_module
spec = importlib.util.spec_from_file_location(f"{package_name}.api.config_routes", module_path)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
if created_aiohttp_stub:
sys.modules.pop("aiohttp", None)
return module
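# The loader above imports config_routes.py under a throwaway package name with
# stubbed dependencies (aiohttp, server, utils.*), so the route handlers can be
# exercised without a running ComfyUI server. The other API test modules follow
# the same pattern.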
config_routes = _load_config_routes_module()
class ConfigRoutesTests(unittest.IsolatedAsyncioTestCase):
async def test_get_config_returns_core_sections(self):
cfg = {"workers": [], "master": {}, "settings": {}, "tunnel": {}}
with patch.object(config_routes, "load_config", return_value=cfg):
response = await config_routes.get_config_endpoint(_FakeRequest())
self.assertEqual(response.status, 200)
self.assertIn("workers", response.payload)
self.assertIn("master", response.payload)
self.assertIn("settings", response.payload)
async def test_update_config_valid_field_persists(self):
cfg = {"workers": [], "master": {}, "settings": {"debug": False}, "tunnel": {}}
with patch.object(config_routes, "load_config", return_value=cfg), patch.object(
config_routes, "save_config", return_value=True
):
response = await config_routes.update_config_endpoint(_FakeRequest({"debug": True}))
self.assertEqual(response.status, 200)
self.assertEqual(response.payload["status"], "success")
self.assertTrue(response.payload["config"]["settings"]["debug"])
async def test_update_config_unknown_field_returns_400(self):
cfg = {"workers": [], "master": {}, "settings": {"debug": False}, "tunnel": {}}
with patch.object(config_routes, "load_config", return_value=cfg):
response = await config_routes.update_config_endpoint(_FakeRequest({"unknown_field": 1}))
self.assertEqual(response.status, 400)
self.assertIn("unknown_field", " ".join(response.payload.get("error", [])).lower())
async def test_update_config_wrong_type_returns_400(self):
cfg = {"workers": [], "master": {}, "settings": {"debug": False}, "tunnel": {}}
with patch.object(config_routes, "load_config", return_value=cfg):
response = await config_routes.update_config_endpoint(_FakeRequest({"debug": "true"}))
self.assertEqual(response.status, 400)
self.assertIn("debug", " ".join(response.payload.get("error", [])).lower())
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/api/test_distributed_queue.py
================================================
import importlib.util
import sys
import types
import unittest
import asyncio
import base64
from dataclasses import dataclass
from pathlib import Path
from unittest.mock import AsyncMock, patch
import numpy as np
import torch
class _FakeResponse:
def __init__(self, payload, status=200):
self.payload = payload
self.status = status
class _FakeRequest:
def __init__(self, payload):
self._payload = payload
async def json(self):
return self._payload
def _load_job_routes_module():
module_path = Path(__file__).resolve().parents[2] / "api" / "job_routes.py"
package_name = "dist_api_queue_testpkg"
# Reset package namespace to avoid stale module state across test runs.
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
api_pkg = types.ModuleType(f"{package_name}.api")
api_pkg.__path__ = []
sys.modules[f"{package_name}.api"] = api_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
# aiohttp.web stub
created_aiohttp_stub = False
if "aiohttp" not in sys.modules:
created_aiohttp_stub = True
aiohttp_module = types.ModuleType("aiohttp")
aiohttp_module.web = types.SimpleNamespace(
json_response=lambda payload, status=200: _FakeResponse(payload, status=status)
)
sys.modules["aiohttp"] = aiohttp_module
# server module stub with route decorators
class _Routes:
def get(self, _path):
def _decorator(fn):
return fn
return _decorator
def post(self, _path):
def _decorator(fn):
return fn
return _decorator
prompt_server_instance = types.SimpleNamespace(
routes=_Routes(),
distributed_jobs_lock=None,
distributed_pending_jobs={},
)
server_module = types.ModuleType("server")
server_module.PromptServer = types.SimpleNamespace(instance=prompt_server_instance)
sys.modules["server"] = server_module
# torch stub (only needed to satisfy import)
created_torch_stub = False
if "torch" not in sys.modules:
created_torch_stub = True
torch_module = types.ModuleType("torch")
torch_module.cuda = types.SimpleNamespace(
is_available=lambda: False,
empty_cache=lambda: None,
ipc_collect=lambda: None,
)
sys.modules["torch"] = torch_module
# PIL stub (only needed to satisfy import)
created_pil_stub = False
if "PIL" not in sys.modules:
created_pil_stub = True
pil_module = types.ModuleType("PIL")
image_module = types.ModuleType("PIL.Image")
pil_module.Image = image_module
sys.modules["PIL"] = pil_module
sys.modules["PIL.Image"] = image_module
logging_module = types.ModuleType(f"{package_name}.utils.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
logging_module.log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.logging"] = logging_module
image_module = types.ModuleType(f"{package_name}.utils.image")
image_module.pil_to_tensor = lambda *_args, **_kwargs: None
image_module.ensure_contiguous = lambda tensor: tensor
sys.modules[f"{package_name}.utils.image"] = image_module
audio_payload_module = types.ModuleType(f"{package_name}.utils.audio_payload")
def _decode_audio_payload(payload):
if payload is None:
return None
if not isinstance(payload, dict):
raise ValueError("Field 'audio' must be an object when provided.")
encoded = payload.get("data")
shape = payload.get("shape")
dtype = payload.get("dtype", "float32")
sample_rate = payload.get("sample_rate", 44100)
if not isinstance(encoded, str) or not encoded.strip():
raise ValueError("Field 'audio.data' must be a non-empty base64 string.")
if not isinstance(shape, list) or len(shape) != 3:
raise ValueError("Field 'audio.shape' must be a 3-item list.")
if dtype != "float32":
raise ValueError("Field 'audio.dtype' must be 'float32'.")
try:
shape_tuple = tuple(int(dim) for dim in shape)
except Exception as exc:
raise ValueError("Field 'audio.shape' must contain integers.") from exc
raw = base64.b64decode(encoded, validate=True)
expected_bytes = int(np.prod(shape_tuple, dtype=np.int64)) * 4
if len(raw) != expected_bytes:
raise ValueError("Field 'audio.data' byte size mismatch.")
waveform = torch.from_numpy(np.frombuffer(raw, dtype=np.float32).reshape(shape_tuple).copy())
return {"waveform": waveform, "sample_rate": int(sample_rate)}
audio_payload_module.decode_audio_payload = _decode_audio_payload
sys.modules[f"{package_name}.utils.audio_payload"] = audio_payload_module
network_module = types.ModuleType(f"{package_name}.utils.network")
async def _handle_api_error(_request, error, status=500):
return _FakeResponse({"status": "error", "message": str(error)}, status=status)
network_module.handle_api_error = _handle_api_error
sys.modules[f"{package_name}.utils.network"] = network_module
constants_module = types.ModuleType(f"{package_name}.utils.constants")
constants_module.MEMORY_CLEAR_DELAY = 0.0
constants_module.JOB_INIT_GRACE_PERIOD = 10.0
sys.modules[f"{package_name}.utils.constants"] = constants_module
async_helpers_module = types.ModuleType(f"{package_name}.utils.async_helpers")
async_helpers_module.queue_prompt_payload = AsyncMock(return_value="prompt_local")
sys.modules[f"{package_name}.utils.async_helpers"] = async_helpers_module
queue_orchestration_module = types.ModuleType(f"{package_name}.api.queue_orchestration")
queue_orchestration_module.orchestrate_distributed_execution = AsyncMock(return_value=("prompt_dist", 7, 1, {}))
sys.modules[f"{package_name}.api.queue_orchestration"] = queue_orchestration_module
@dataclass(frozen=True)
class _QueuePayload:
prompt: dict
workflow_meta: object
client_id: str
delegate_master: object
enabled_worker_ids: list
auto_prepare: bool
trace_execution_id: object
def _parse_queue_request_payload(data):
if not isinstance(data, dict):
raise ValueError("Expected a JSON object body")
prompt = data.get("prompt")
if not isinstance(prompt, dict):
raise ValueError("Field 'prompt' must be an object")
enabled = data.get("enabled_worker_ids")
if not isinstance(enabled, list):
raise ValueError("enabled_worker_ids required")
client_id = data.get("client_id")
if not isinstance(client_id, str) or not client_id.strip():
raise ValueError("client_id required")
return _QueuePayload(
prompt=prompt,
workflow_meta=data.get("workflow"),
client_id=client_id,
delegate_master=data.get("delegate_master"),
enabled_worker_ids=enabled,
auto_prepare=bool(data.get("auto_prepare", True)),
trace_execution_id=data.get("trace_execution_id"),
)
queue_request_module = types.ModuleType(f"{package_name}.api.queue_request")
queue_request_module.parse_queue_request_payload = _parse_queue_request_payload
sys.modules[f"{package_name}.api.queue_request"] = queue_request_module
spec = importlib.util.spec_from_file_location(f"{package_name}.api.job_routes", module_path)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
if created_aiohttp_stub:
sys.modules.pop("aiohttp", None)
if created_torch_stub:
sys.modules.pop("torch", None)
if created_pil_stub:
sys.modules.pop("PIL.Image", None)
sys.modules.pop("PIL", None)
return module
job_routes = _load_job_routes_module()
class DistributedQueueEndpointTests(unittest.IsolatedAsyncioTestCase):
async def test_distributed_queue_happy_path_returns_prompt_metadata(self):
request = _FakeRequest(
{
"prompt": {"1": {"class_type": "Node"}},
"enabled_worker_ids": ["w1"],
"client_id": "client-1",
"auto_prepare": True,
}
)
with patch.object(
job_routes,
"orchestrate_distributed_execution",
new=AsyncMock(return_value=("prompt_123", 42, 2, {})),
):
response = await job_routes.distributed_queue_endpoint(request)
self.assertEqual(response.status, 200)
self.assertEqual(response.payload.get("prompt_id"), "prompt_123")
self.assertEqual(response.payload.get("number"), 42)
self.assertEqual(response.payload.get("node_errors"), {})
self.assertTrue(response.payload.get("auto_prepare_supported"))
async def test_distributed_queue_missing_prompt_returns_400(self):
request = _FakeRequest(
{
"enabled_worker_ids": ["w1"],
"client_id": "client-1",
}
)
response = await job_routes.distributed_queue_endpoint(request)
self.assertEqual(response.status, 400)
self.assertIn("prompt", response.payload.get("message", "").lower())
async def test_distributed_queue_missing_enabled_worker_ids_returns_400(self):
request = _FakeRequest(
{
"prompt": {"1": {"class_type": "Node"}},
"client_id": "client-1",
}
)
response = await job_routes.distributed_queue_endpoint(request)
self.assertEqual(response.status, 400)
self.assertIn("enabled_worker_ids", response.payload.get("message", "").lower())
class JobCompleteAudioPayloadTests(unittest.IsolatedAsyncioTestCase):
def _encoded_audio_payload(self):
waveform = np.arange(8, dtype=np.float32).reshape(1, 2, 4)
return {
"sample_rate": 44100,
"shape": [1, 2, 4],
"dtype": "float32",
"data": base64.b64encode(waveform.tobytes()).decode("ascii"),
}
async def test_job_complete_accepts_audio_payload(self):
queue = asyncio.Queue()
job_routes.prompt_server.distributed_jobs_lock = asyncio.Lock()
job_routes.prompt_server.distributed_pending_jobs = {"job-1": queue}
request = _FakeRequest(
{
"job_id": "job-1",
"worker_id": "worker-1",
"batch_idx": 0,
"image": "data:image/png;base64,AAAA",
"audio": self._encoded_audio_payload(),
"is_last": True,
}
)
with patch.object(job_routes, "_decode_canonical_png_tensor", return_value="tensor-data"):
response = await job_routes.job_complete_endpoint(request)
self.assertEqual(response.status, 200)
queued = await queue.get()
self.assertEqual(queued["worker_id"], "worker-1")
self.assertTrue(queued["is_last"])
self.assertIsNotNone(queued["audio"])
self.assertEqual(queued["audio"]["sample_rate"], 44100)
self.assertEqual(tuple(queued["audio"]["waveform"].shape), (1, 2, 4))
def test_decode_audio_payload_rejects_bad_shape(self):
bad = {
"sample_rate": 44100,
"shape": [1, 2],
"dtype": "float32",
"data": base64.b64encode(b"\x00\x00\x00\x00").decode("ascii"),
}
with self.assertRaises(ValueError):
job_routes._decode_audio_payload(bad)
def test_decode_audio_payload_rejects_bad_dtype(self):
payload = {
"sample_rate": 44100,
"shape": [1, 2, 4],
"dtype": "float16",
"data": base64.b64encode((np.zeros((1, 2, 4), dtype=np.float32)).tobytes()).decode("ascii"),
}
with self.assertRaises(ValueError):
job_routes._decode_audio_payload(payload)
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/api/test_media_sync.py
================================================
import importlib.util
import sys
import types
import unittest
from pathlib import Path
def _load_media_sync_module():
module_path = Path(__file__).resolve().parents[2] / "api" / "orchestration" / "media_sync.py"
package_name = "dist_ms_testpkg"
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
api_pkg = types.ModuleType(f"{package_name}.api")
api_pkg.__path__ = []
sys.modules[f"{package_name}.api"] = api_pkg
orch_pkg = types.ModuleType(f"{package_name}.api.orchestration")
orch_pkg.__path__ = []
sys.modules[f"{package_name}.api.orchestration"] = orch_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
logging_module = types.ModuleType(f"{package_name}.utils.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
logging_module.log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.logging"] = logging_module
network_module = types.ModuleType(f"{package_name}.utils.network")
network_module.build_worker_url = lambda worker, endpoint="": f"http://localhost{endpoint}"
async def _fake_session():
raise RuntimeError("network calls not used in pure-function tests")
network_module.get_client_session = _fake_session
sys.modules[f"{package_name}.utils.network"] = network_module
trace_module = types.ModuleType(f"{package_name}.utils.trace_logger")
trace_module.trace_debug = lambda *_args, **_kwargs: None
trace_module.trace_info = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.trace_logger"] = trace_module
created_aiohttp_stub = False
if "aiohttp" not in sys.modules:
created_aiohttp_stub = True
aiohttp_module = types.ModuleType("aiohttp")
class _ClientTimeout:
def __init__(self, total=None):
pass
class _FormData:
def add_field(self, *args, **kwargs):
pass
aiohttp_module.ClientTimeout = _ClientTimeout
aiohttp_module.FormData = _FormData
sys.modules["aiohttp"] = aiohttp_module
spec = importlib.util.spec_from_file_location(
f"{package_name}.api.orchestration.media_sync",
module_path,
)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
if created_aiohttp_stub:
sys.modules.pop("aiohttp", None)
return module
ms = _load_media_sync_module()
# ---------------------------------------------------------------------------
# convert_paths_for_platform
# ---------------------------------------------------------------------------
class ConvertPathsForPlatformTests(unittest.TestCase):
def test_forward_slash_target_normalises_backslashes(self):
obj = {"ckpt_name": "C:\\Models\\model.safetensors"}
result = ms.convert_paths_for_platform(obj, "/")
self.assertEqual(result["ckpt_name"], "C:/Models/model.safetensors")
def test_backslash_target_normalises_forward_slashes(self):
obj = {"ckpt_name": "/models/checkpoints/model.safetensors"}
result = ms.convert_paths_for_platform(obj, "\\")
self.assertIn("\\", result["ckpt_name"])
self.assertNotIn("/", result["ckpt_name"])
def test_relative_media_paths_always_stay_forward_slash(self):
"""Relative image/video/audio paths (Comfy annotated style) must not be backslash-ified."""
obj = {"image": "subfolder/my_photo.png"}
result = ms.convert_paths_for_platform(obj, "\\")
self.assertEqual(result["image"], "subfolder/my_photo.png")
def test_relative_audio_paths_stay_forward_slash(self):
obj = {"audio": "subfolder/my_track.wav"}
result = ms.convert_paths_for_platform(obj, "\\")
self.assertEqual(result["audio"], "subfolder/my_track.wav")
def test_annotated_relative_media_path_stays_forward_slash(self):
obj = {"image": "input/frame.jpg [abc123]"}
result = ms.convert_paths_for_platform(obj, "\\")
self.assertIn("/", result["image"])
self.assertNotIn("\\", result["image"].split("[")[0])
def test_non_filename_strings_are_untouched(self):
obj = {"prompt": "a beautiful cat", "count": 5}
result = ms.convert_paths_for_platform(obj, "\\")
self.assertEqual(result["prompt"], "a beautiful cat")
self.assertEqual(result["count"], 5)
def test_url_strings_are_untouched(self):
obj = {"url": "https://example.com/model.safetensors"}
result = ms.convert_paths_for_platform(obj, "\\")
self.assertEqual(result["url"], "https://example.com/model.safetensors")
def test_invalid_separator_returns_obj_unchanged(self):
obj = {"ckpt_name": "/models/model.safetensors"}
result = ms.convert_paths_for_platform(obj, "|")
self.assertEqual(result, obj)
def test_nested_dict_is_processed_recursively(self):
obj = {"node": {"ckpt_name": "C:\\Models\\model.safetensors"}}
result = ms.convert_paths_for_platform(obj, "/")
self.assertEqual(result["node"]["ckpt_name"], "C:/Models/model.safetensors")
def test_list_items_are_processed_recursively(self):
obj = [{"ckpt_name": "C:\\Models\\model.safetensors"}, "plain string"]
result = ms.convert_paths_for_platform(obj, "/")
self.assertEqual(result[0]["ckpt_name"], "C:/Models/model.safetensors")
self.assertEqual(result[1], "plain string")
def test_non_string_scalar_values_are_untouched(self):
obj = {"seed": 42, "enabled": True, "ratio": 1.5}
result = ms.convert_paths_for_platform(obj, "/")
self.assertEqual(result["seed"], 42)
self.assertTrue(result["enabled"])
def test_absolute_unix_path_to_windows(self):
obj = {"lora": "/home/user/loras/my_lora.safetensors"}
result = ms.convert_paths_for_platform(obj, "\\")
self.assertNotIn("/", result["lora"])
def test_already_normalised_path_is_idempotent(self):
obj = {"ckpt": "C:/Models/model.safetensors"}
result = ms.convert_paths_for_platform(obj, "/")
self.assertEqual(result["ckpt"], "C:/Models/model.safetensors")
# ---------------------------------------------------------------------------
# _find_media_references
# ---------------------------------------------------------------------------
class FindMediaReferencesTests(unittest.TestCase):
def test_finds_image_input(self):
prompt = {"1": {"class_type": "LoadImage", "inputs": {"image": "photo.png"}}}
refs = ms._find_media_references(prompt)
self.assertIn("photo.png", refs)
def test_finds_video_input(self):
prompt = {"1": {"class_type": "LoadVideo", "inputs": {"video": "clip.mp4"}}}
refs = ms._find_media_references(prompt)
self.assertIn("clip.mp4", refs)
def test_finds_file_input_for_load_video(self):
prompt = {"1": {"class_type": "LoadVideo", "inputs": {"file": "1 - Copy.mp4"}}}
refs = ms._find_media_references(prompt)
self.assertIn("1 - Copy.mp4", refs)
def test_finds_audio_input(self):
prompt = {"1": {"class_type": "LoadAudio", "inputs": {"audio": "track.wav"}}}
refs = ms._find_media_references(prompt)
self.assertIn("track.wav", refs)
def test_strips_annotation_suffix(self):
prompt = {"1": {"class_type": "LoadImage", "inputs": {"image": "photo.jpg [abc123]"}}}
refs = ms._find_media_references(prompt)
self.assertIn("photo.jpg", refs)
self.assertFalse(any("[" in r for r in refs))
def test_normalises_backslashes_in_path(self):
prompt = {"1": {"class_type": "LoadImage", "inputs": {"image": "sub\\img.png"}}}
refs = ms._find_media_references(prompt)
self.assertIn("sub/img.png", refs)
def test_ignores_non_media_text_inputs(self):
prompt = {"1": {"class_type": "CLIPTextEncode", "inputs": {"text": "a cat"}}}
refs = ms._find_media_references(prompt)
self.assertEqual(refs, [])
def test_ignores_node_link_values(self):
"""Inputs that are [node_id, slot] lists should be ignored."""
prompt = {"1": {"class_type": "Anything", "inputs": {"image": ["2", 0]}}}
refs = ms._find_media_references(prompt)
self.assertEqual(refs, [])
def test_deduplicates_same_file_across_nodes(self):
prompt = {
"1": {"class_type": "LoadImage", "inputs": {"image": "cat.png"}},
"2": {"class_type": "LoadImage", "inputs": {"image": "cat.png"}},
}
refs = ms._find_media_references(prompt)
self.assertEqual(len(refs), 1)
def test_returns_sorted_list(self):
prompt = {
"1": {"class_type": "LoadImage", "inputs": {"image": "z_image.png"}},
"2": {"class_type": "LoadImage", "inputs": {"image": "a_image.jpg"}},
}
refs = ms._find_media_references(prompt)
self.assertEqual(refs, sorted(refs))
def test_ignores_non_dict_nodes(self):
prompt = {"1": "not a node dict", "2": {"class_type": "LoadImage", "inputs": {"image": "img.png"}}}
refs = ms._find_media_references(prompt)
self.assertIn("img.png", refs)
def test_empty_prompt_returns_empty_list(self):
self.assertEqual(ms._find_media_references({}), [])
def test_multiple_media_types_all_found(self):
prompt = {
"1": {"class_type": "LoadImage", "inputs": {"image": "frame.png"}},
"2": {"class_type": "LoadVideo", "inputs": {"video": "clip.mp4"}},
"3": {"class_type": "LoadAudio", "inputs": {"audio": "track.wav"}},
}
refs = ms._find_media_references(prompt)
self.assertIn("frame.png", refs)
self.assertIn("clip.mp4", refs)
self.assertIn("track.wav", refs)
class RewritePromptMediaInputsTests(unittest.TestCase):
def test_rewrites_video_file_input_to_worker_path(self):
prompt = {
"79": {"class_type": "LoadVideo", "inputs": {"file": "1 - Copy.mp4"}},
}
ms._rewrite_prompt_media_inputs(prompt, {"1 - Copy.mp4": "videos/1 - Copy.mp4"})
self.assertEqual(prompt["79"]["inputs"]["file"], "videos/1 - Copy.mp4")
def test_rewrites_audio_input_and_strips_annotation_when_matching(self):
prompt = {
"1": {"class_type": "LoadAudio", "inputs": {"audio": "song.wav [input]"}},
}
ms._rewrite_prompt_media_inputs(prompt, {"song.wav": "song.wav"})
self.assertEqual(prompt["1"]["inputs"]["audio"], "song.wav")
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/api/test_usdu_routes.py
================================================
import asyncio
import importlib.util
import io
import sys
import types
import unittest
from dataclasses import dataclass, field
from pathlib import Path
from PIL import Image
class _FakeResponse:
def __init__(self, payload, status=200):
self.payload = payload
self.status = status
class _FakeRequest:
def __init__(self, json_payload=None, post_payload=None, headers=None, query=None):
self._json_payload = json_payload
self._post_payload = post_payload or {}
self.headers = headers or {}
self.query = query or {}
async def json(self):
return self._json_payload
async def post(self):
return self._post_payload
class _Routes:
def post(self, _path):
def _decorator(fn):
return fn
return _decorator
def get(self, _path):
def _decorator(fn):
return fn
return _decorator
def _load_usdu_routes_module():
module_path = Path(__file__).resolve().parents[2] / "api" / "usdu_routes.py"
package_name = "dist_api_usdu_testpkg"
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
api_pkg = types.ModuleType(f"{package_name}.api")
api_pkg.__path__ = []
sys.modules[f"{package_name}.api"] = api_pkg
upscale_pkg = types.ModuleType(f"{package_name}.upscale")
upscale_pkg.__path__ = []
sys.modules[f"{package_name}.upscale"] = upscale_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
prompt_server_holder = {
"value": types.SimpleNamespace(
distributed_tile_jobs_lock=asyncio.Lock(),
distributed_pending_tile_jobs={},
)
}
created_aiohttp_stub = False
if "aiohttp" not in sys.modules:
created_aiohttp_stub = True
aiohttp_module = types.ModuleType("aiohttp")
aiohttp_module.web = types.SimpleNamespace(
json_response=lambda payload, status=200: _FakeResponse(payload, status=status)
)
sys.modules["aiohttp"] = aiohttp_module
server_module = types.ModuleType("server")
server_module.PromptServer = types.SimpleNamespace(instance=types.SimpleNamespace(routes=_Routes()))
sys.modules["server"] = server_module
logging_module = types.ModuleType(f"{package_name}.utils.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.logging"] = logging_module
network_module = types.ModuleType(f"{package_name}.utils.network")
async def _handle_api_error(_request, error, status=500):
return _FakeResponse({"status": "error", "message": str(error)}, status=status)
network_module.handle_api_error = _handle_api_error
sys.modules[f"{package_name}.utils.network"] = network_module
job_store_module = types.ModuleType(f"{package_name}.upscale.job_store")
job_store_module.MAX_PAYLOAD_SIZE = 1024
job_store_module.ensure_tile_jobs_initialized = lambda: prompt_server_holder["value"]
sys.modules[f"{package_name}.upscale.job_store"] = job_store_module
job_models_module = types.ModuleType(f"{package_name}.upscale.job_models")
class BaseJobState:
pass
@dataclass
class TileJobState(BaseJobState):
multi_job_id: str
mode: str = field(default="static", init=False)
queue: asyncio.Queue = field(default_factory=asyncio.Queue)
pending_tasks: asyncio.Queue = field(default_factory=asyncio.Queue)
completed_tasks: dict = field(default_factory=dict)
worker_status: dict = field(default_factory=dict)
assigned_to_workers: dict = field(default_factory=dict)
batch_size: int = 0
num_tiles_per_image: int = 0
batched_static: bool = False
@dataclass
class ImageJobState(BaseJobState):
multi_job_id: str
mode: str = field(default="dynamic", init=False)
queue: asyncio.Queue = field(default_factory=asyncio.Queue)
pending_images: asyncio.Queue = field(default_factory=asyncio.Queue)
completed_images: dict = field(default_factory=dict)
worker_status: dict = field(default_factory=dict)
assigned_to_workers: dict = field(default_factory=dict)
batch_size: int = 0
num_tiles_per_image: int = 0
batched_static: bool = False
@property
def pending_tasks(self):
return self.pending_images
@property
def completed_tasks(self):
return self.completed_images
job_models_module.BaseJobState = BaseJobState
job_models_module.TileJobState = TileJobState
job_models_module.ImageJobState = ImageJobState
sys.modules[f"{package_name}.upscale.job_models"] = job_models_module
parsers_module = types.ModuleType(f"{package_name}.upscale.payload_parsers")
parsers_module._parse_tiles_from_form = lambda _data: []
sys.modules[f"{package_name}.upscale.payload_parsers"] = parsers_module
spec = importlib.util.spec_from_file_location(f"{package_name}.api.usdu_routes", module_path)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
if created_aiohttp_stub:
sys.modules.pop("aiohttp", None)
module.web = types.SimpleNamespace(
json_response=lambda payload, status=200: _FakeResponse(payload, status=status)
)
module._prompt_server_holder = prompt_server_holder
module._TileJobState = TileJobState
module._ImageJobState = ImageJobState
return module
usdu_routes = _load_usdu_routes_module()
class _UploadField:
def __init__(self, data):
self.file = io.BytesIO(data)
def _tiny_png_bytes():
image = Image.new("RGB", (1, 1), (255, 0, 0))
buf = io.BytesIO()
image.save(buf, format="PNG")
return buf.getvalue()
class USDURoutesTests(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
usdu_routes._prompt_server_holder["value"] = types.SimpleNamespace(
distributed_tile_jobs_lock=asyncio.Lock(),
distributed_pending_tile_jobs={},
)
async def test_heartbeat_updates_worker_status(self):
prompt_server = usdu_routes._prompt_server_holder["value"]
job_data = usdu_routes._TileJobState("job-1")
prompt_server.distributed_pending_tile_jobs["job-1"] = job_data
request = _FakeRequest(json_payload={"worker_id": "worker-a", "multi_job_id": "job-1"})
response = await usdu_routes.heartbeat_endpoint(request)
self.assertEqual(response.status, 200)
self.assertEqual(response.payload.get("status"), "success")
self.assertIn("worker-a", job_data.worker_status)
async def test_heartbeat_missing_fields_returns_400(self):
request = _FakeRequest(json_payload={"worker_id": "worker-a"})
response = await usdu_routes.heartbeat_endpoint(request)
self.assertEqual(response.status, 400)
self.assertIn("missing", response.payload.get("message", "").lower())
async def test_request_image_dynamic_assigns_next_image(self):
prompt_server = usdu_routes._prompt_server_holder["value"]
job_data = usdu_routes._ImageJobState("job-2")
await job_data.pending_images.put(7)
prompt_server.distributed_pending_tile_jobs["job-2"] = job_data
request = _FakeRequest(json_payload={"worker_id": "worker-a", "multi_job_id": "job-2"})
response = await usdu_routes.request_image_endpoint(request)
self.assertEqual(response.status, 200)
self.assertEqual(response.payload.get("image_idx"), 7)
self.assertEqual(response.payload.get("estimated_remaining"), 0)
self.assertEqual(job_data.assigned_to_workers["worker-a"], [7])
self.assertIn("worker-a", job_data.worker_status)
async def test_request_image_static_assigns_tile_and_batched_flag(self):
prompt_server = usdu_routes._prompt_server_holder["value"]
job_data = usdu_routes._TileJobState("job-3")
job_data.batched_static = True
await job_data.pending_tasks.put(4)
prompt_server.distributed_pending_tile_jobs["job-3"] = job_data
request = _FakeRequest(json_payload={"worker_id": "worker-a", "multi_job_id": "job-3"})
response = await usdu_routes.request_image_endpoint(request)
self.assertEqual(response.status, 200)
self.assertEqual(response.payload.get("tile_idx"), 4)
self.assertTrue(response.payload.get("batched_static"))
self.assertEqual(job_data.assigned_to_workers["worker-a"], [4])
async def test_submit_tiles_completion_signal_enqueues_last_marker(self):
prompt_server = usdu_routes._prompt_server_holder["value"]
job_data = usdu_routes._TileJobState("job-4")
prompt_server.distributed_pending_tile_jobs["job-4"] = job_data
request = _FakeRequest(
post_payload={
"multi_job_id": "job-4",
"worker_id": "worker-a",
"batch_size": "0",
"is_last": "true",
},
headers={"content-length": "128"},
)
response = await usdu_routes.submit_tiles_endpoint(request)
self.assertEqual(response.status, 200)
queued = await job_data.queue.get()
self.assertEqual(queued["worker_id"], "worker-a")
self.assertTrue(queued["is_last"])
self.assertEqual(queued["tiles"], [])
async def test_submit_image_enqueues_processed_image_payload(self):
prompt_server = usdu_routes._prompt_server_holder["value"]
job_data = usdu_routes._ImageJobState("job-5")
prompt_server.distributed_pending_tile_jobs["job-5"] = job_data
request = _FakeRequest(
post_payload={
"multi_job_id": "job-5",
"worker_id": "worker-a",
"image_idx": "2",
"full_image": _UploadField(_tiny_png_bytes()),
"is_last": "false",
},
headers={"content-length": "256"},
)
response = await usdu_routes.submit_image_endpoint(request)
self.assertEqual(response.status, 200)
queued = await job_data.queue.get()
self.assertEqual(queued["worker_id"], "worker-a")
self.assertEqual(queued["image_idx"], 2)
self.assertIn("image", queued)
async def test_job_status_endpoint_reports_ready(self):
prompt_server = usdu_routes._prompt_server_holder["value"]
prompt_server.distributed_pending_tile_jobs["job-6"] = usdu_routes._TileJobState("job-6")
request = _FakeRequest(query={"multi_job_id": "job-6"})
response = await usdu_routes.job_status_endpoint(request)
self.assertEqual(response.status, 200)
self.assertTrue(response.payload.get("ready"))
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/api/test_worker_routes.py
================================================
import importlib.util
import os
import sys
import tempfile
import types
import unittest
from collections import deque
from pathlib import Path
from unittest.mock import patch
class _FakeResponse:
def __init__(self, payload, status=200):
self.payload = payload
self.status = status
class _FakeRequest:
def __init__(self, payload=None, match_info=None, query=None):
self._payload = payload
self.match_info = match_info or {}
self.query = query or {}
async def json(self):
return self._payload
class _FakeHTTPClientResponse:
def __init__(self, payload, status=200):
self._payload = payload
self.status = status
async def __aenter__(self):
return self
async def __aexit__(self, _exc_type, _exc, _tb):
return False
async def json(self):
return self._payload
async def text(self):
return str(self._payload)
class _FakeHTTPClientSession:
def __init__(self, payload, status=200):
self._payload = payload
self._status = status
self.calls = []
def get(self, url, params=None, timeout=None):
self.calls.append({"url": url, "params": params, "timeout": timeout})
return _FakeHTTPClientResponse(self._payload, status=self._status)
class _DummyWorkerManager:
def __init__(self):
self.processes = {}
def launch_worker(self, worker):
worker_id = str(worker["id"])
self.processes[worker_id] = {
"pid": 12345,
"log_file": f"/tmp/distributed_worker_{worker_id}.log",
"process": None,
}
return 12345
def _is_process_running(self, _pid):
return False
def save_processes(self):
return None
def stop_worker(self, _worker_id):
return True, "Stopped"
def get_managed_workers(self):
return []
class _ImmediateLoop:
async def run_in_executor(self, _executor, func, *args):
return func(*args)
def _load_worker_routes_module():
module_path = Path(__file__).resolve().parents[2] / "api" / "worker_routes.py"
package_name = "dist_api_worker_testpkg"
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
api_pkg = types.ModuleType(f"{package_name}.api")
api_pkg.__path__ = []
sys.modules[f"{package_name}.api"] = api_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
workers_pkg = types.ModuleType(f"{package_name}.workers")
workers_pkg.__path__ = []
workers_pkg.get_worker_manager = lambda: _DummyWorkerManager()
sys.modules[f"{package_name}.workers"] = workers_pkg
detection_module = types.ModuleType(f"{package_name}.workers.detection")
detection_module.is_local_worker = lambda *_args, **_kwargs: True
detection_module.is_same_physical_host = lambda *_args, **_kwargs: True
detection_module.get_machine_id = lambda: "machine-id"
detection_module.is_docker_environment = lambda: False
detection_module.is_runpod_environment = lambda: False
detection_module.get_comms_channel = lambda *_args, **_kwargs: "lan"
sys.modules[f"{package_name}.workers.detection"] = detection_module
created_aiohttp_stub = False
if "aiohttp" not in sys.modules:
created_aiohttp_stub = True
aiohttp_module = types.ModuleType("aiohttp")
class _ClientTimeout:
def __init__(self, total=None):
self.total = total
class _WSMsgType:
TEXT = "TEXT"
ERROR = "ERROR"
CLOSED = "CLOSED"
class _WebSocketResponse:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
async def prepare(self, _request):
return None
async def send_json(self, _payload):
return None
def __aiter__(self):
async def _empty():
if False:
yield None
return _empty()
aiohttp_module.ClientTimeout = _ClientTimeout
aiohttp_module.WSMsgType = _WSMsgType
aiohttp_module.web = types.SimpleNamespace(
json_response=lambda payload, status=200: _FakeResponse(payload, status=status),
WebSocketResponse=_WebSocketResponse,
)
sys.modules["aiohttp"] = aiohttp_module
class _Routes:
def get(self, _path):
def _decorator(fn):
return fn
return _decorator
def post(self, _path):
def _decorator(fn):
return fn
return _decorator
server_module = types.ModuleType("server")
server_module.PromptServer = types.SimpleNamespace(instance=types.SimpleNamespace(routes=_Routes()))
sys.modules["server"] = server_module
created_torch_stub = False
if "torch" not in sys.modules:
created_torch_stub = True
torch_module = types.ModuleType("torch")
torch_module.cuda = types.SimpleNamespace(
is_available=lambda: False,
empty_cache=lambda: None,
ipc_collect=lambda: None,
current_device=lambda: 0,
device_count=lambda: 0,
)
sys.modules["torch"] = torch_module
logging_module = types.ModuleType(f"{package_name}.utils.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
logging_module.log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.logging"] = logging_module
config_module = types.ModuleType(f"{package_name}.utils.config")
config_module.load_config = lambda: {"workers": []}
sys.modules[f"{package_name}.utils.config"] = config_module
network_module = types.ModuleType(f"{package_name}.utils.network")
async def _handle_api_error(_request, error, status=500):
return _FakeResponse({"status": "error", "message": str(error)}, status=status)
network_module.handle_api_error = _handle_api_error
network_module.normalize_host = lambda value: value
network_module.build_worker_url = lambda worker, endpoint="": f"http://localhost:{worker.get('port', 8188)}{endpoint}"
async def _probe_worker(*_args, **_kwargs):
return None
network_module.probe_worker = _probe_worker
async def _get_client_session():
raise RuntimeError("not used in these tests")
network_module.get_client_session = _get_client_session
sys.modules[f"{package_name}.utils.network"] = network_module
constants_module = types.ModuleType(f"{package_name}.utils.constants")
constants_module.CHUNK_SIZE = 8192
sys.modules[f"{package_name}.utils.constants"] = constants_module
async_helpers_module = types.ModuleType(f"{package_name}.utils.async_helpers")
async def _queue_prompt_payload(*_args, **_kwargs):
return "prompt-id"
async_helpers_module.queue_prompt_payload = _queue_prompt_payload
class _PromptValidationError(RuntimeError):
def __init__(self, message="invalid prompt", validation_error=None, node_errors=None):
super().__init__(message)
self.validation_error = validation_error if isinstance(validation_error, dict) else {}
self.node_errors = node_errors if isinstance(node_errors, dict) else {}
async_helpers_module.PromptValidationError = _PromptValidationError
sys.modules[f"{package_name}.utils.async_helpers"] = async_helpers_module
schemas_module = types.ModuleType(f"{package_name}.api.schemas")
def _require_fields(data, *fields):
missing = []
for field in fields:
value = data.get(field) if isinstance(data, dict) else None
if value is None or (isinstance(value, str) and not value.strip()):
missing.append(field)
return missing
def _validate_worker_id(worker_id, config):
return any(str(worker.get("id")) == str(worker_id) for worker in config.get("workers", []))
schemas_module.require_fields = _require_fields
schemas_module.validate_worker_id = _validate_worker_id
sys.modules[f"{package_name}.api.schemas"] = schemas_module
spec = importlib.util.spec_from_file_location(f"{package_name}.api.worker_routes", module_path)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
if created_aiohttp_stub:
sys.modules.pop("aiohttp", None)
if created_torch_stub:
sys.modules.pop("torch", None)
return module
worker_routes = _load_worker_routes_module()
class WorkerRoutesTests(unittest.IsolatedAsyncioTestCase):
async def test_launch_worker_valid_id_returns_200(self):
manager = _DummyWorkerManager()
config = {"workers": [{"id": "worker-a", "name": "Worker A", "port": 8188}]}
request = _FakeRequest({"worker_id": "worker-a"})
with patch.object(worker_routes, "get_worker_manager", return_value=manager), patch.object(
worker_routes, "load_config", return_value=config
), patch.object(
worker_routes.asyncio, "get_running_loop", return_value=_ImmediateLoop()
):
response = await worker_routes.launch_worker_endpoint(request)
self.assertEqual(response.status, 200)
self.assertEqual(response.payload.get("status"), "success")
self.assertEqual(response.payload.get("pid"), 12345)
async def test_launch_worker_unknown_id_returns_404(self):
manager = _DummyWorkerManager()
config = {"workers": [{"id": "worker-a", "name": "Worker A", "port": 8188}]}
request = _FakeRequest({"worker_id": "missing-worker"})
with patch.object(worker_routes, "get_worker_manager", return_value=manager), patch.object(
worker_routes, "load_config", return_value=config
):
response = await worker_routes.launch_worker_endpoint(request)
self.assertEqual(response.status, 404)
self.assertIn("not found", response.payload.get("message", "").lower())
async def test_worker_log_returns_content_json(self):
manager = _DummyWorkerManager()
with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8") as handle:
handle.write("line-1\nline-2\nline-3\n")
log_path = handle.name
manager.processes["worker-a"] = {
"pid": 9999,
"log_file": log_path,
"process": None,
}
request = _FakeRequest(match_info={"worker_id": "worker-a"}, query={"lines": "2"})
try:
with patch.object(worker_routes, "get_worker_manager", return_value=manager), patch.object(
worker_routes.asyncio, "get_running_loop", return_value=_ImmediateLoop()
):
response = await worker_routes.get_worker_log_endpoint(request)
finally:
if os.path.exists(log_path):
os.remove(log_path)
self.assertEqual(response.status, 200)
self.assertEqual(response.payload.get("status"), "success")
self.assertIn("content", response.payload)
self.assertIn("line-3", response.payload["content"])
async def test_local_log_reads_memory_buffer(self):
request = _FakeRequest(query={"lines": "2"})
fake_logs = deque(
[
{"m": "line-1\n"},
{"m": "line-2\n"},
{"m": "line-3\n"},
],
maxlen=300,
)
app_module = types.ModuleType("app")
app_module.__path__ = []
logger_module = types.ModuleType("app.logger")
logger_module.get_logs = lambda: fake_logs
app_module.logger = logger_module
with patch.dict(sys.modules, {"app": app_module, "app.logger": logger_module}):
response = await worker_routes.get_local_log_endpoint(request)
self.assertEqual(response.status, 200)
self.assertEqual(response.payload.get("status"), "success")
self.assertEqual(response.payload.get("source"), "memory")
self.assertEqual(response.payload.get("entries"), 2)
self.assertIn("line-3", response.payload.get("content", ""))
async def test_remote_worker_log_proxies_to_worker_local_log_endpoint(self):
config = {
"workers": [
{
"id": "worker-remote",
"name": "Remote Worker",
"host": "worker.example.com",
"port": 8188,
"type": "remote",
}
]
}
request = _FakeRequest(match_info={"worker_id": "worker-remote"}, query={"lines": "120"})
proxied_payload = {
"status": "success",
"content": "remote-log-content\n",
"entries": 1,
"source": "memory",
"truncated": False,
"lines_shown": 1,
}
fake_session = _FakeHTTPClientSession(proxied_payload)
async def _fake_get_client_session():
return fake_session
with patch.object(worker_routes, "load_config", return_value=config), patch.object(
worker_routes, "get_client_session", side_effect=_fake_get_client_session
):
response = await worker_routes.get_remote_worker_log_endpoint(request)
self.assertEqual(response.status, 200)
self.assertEqual(response.payload.get("content"), "remote-log-content\n")
self.assertEqual(len(fake_session.calls), 1)
self.assertEqual(fake_session.calls[0]["params"], {"lines": "120"})
self.assertTrue(fake_session.calls[0]["url"].endswith("/distributed/local_log"))
async def test_remote_worker_log_rejects_local_workers(self):
config = {"workers": [{"id": "worker-local", "name": "Local Worker", "port": 8188}]}
request = _FakeRequest(match_info={"worker_id": "worker-local"})
with patch.object(worker_routes, "load_config", return_value=config):
response = await worker_routes.get_remote_worker_log_endpoint(request)
self.assertEqual(response.status, 400)
self.assertIn("local", response.payload.get("message", "").lower())
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/conftest.py
================================================
# This conftest.py marks the tests/ directory as the pytest collection root,
# preventing pytest from traversing into the parent package's __init__.py.
================================================
FILE: tests/test_async_helpers.py
================================================
import importlib.util
import sys
import types
import unittest
from pathlib import Path
class _PromptQueue:
def __init__(self):
self.items = []
def put(self, item):
self.items.append(item)
def _load_async_helpers_module():
module_path = Path(__file__).resolve().parents[1] / "utils" / "async_helpers.py"
package_name = "dist_async_helpers_testpkg"
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
execution_module = types.ModuleType("execution")
async def _validate_prompt(prompt_id, prompt, partial_execution_targets):
return (True, None, ["9"], {})
execution_module.validate_prompt = _validate_prompt
execution_module.SENSITIVE_EXTRA_DATA_KEYS = []
sys.modules["execution"] = execution_module
prompt_server = types.SimpleNamespace(
trigger_on_prompt=lambda payload: payload,
number=12,
prompt_queue=_PromptQueue(),
)
server_module = types.ModuleType("server")
server_module.PromptServer = types.SimpleNamespace(instance=prompt_server)
sys.modules["server"] = server_module
network_module = types.ModuleType(f"{package_name}.utils.network")
network_module.get_server_loop = lambda: None
sys.modules[f"{package_name}.utils.network"] = network_module
spec = importlib.util.spec_from_file_location(f"{package_name}.utils.async_helpers", module_path)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
return module, prompt_server
async_helpers, prompt_server = _load_async_helpers_module()
class QueuePromptPayloadTests(unittest.IsolatedAsyncioTestCase):
async def test_queue_prompt_payload_includes_create_time_and_client_metadata(self):
result = await async_helpers.queue_prompt_payload(
{"1": {"class_type": "Node"}},
workflow_meta={"id": "workflow-1"},
client_id="client-1",
include_queue_metadata=True,
)
self.assertIsInstance(result["prompt_id"], str)
self.assertTrue(result["prompt_id"])
self.assertEqual(result["number"], 12)
self.assertEqual(result["node_errors"], {})
self.assertEqual(prompt_server.number, 13)
self.assertEqual(len(prompt_server.prompt_queue.items), 1)
queued_item = prompt_server.prompt_queue.items[0]
self.assertEqual(queued_item[0], 12)
extra_data = queued_item[3]
self.assertEqual(extra_data["client_id"], "client-1")
self.assertIn("create_time", extra_data)
self.assertIsInstance(extra_data["create_time"], int)
self.assertGreater(extra_data["create_time"], 0)
self.assertEqual(extra_data["extra_pnginfo"]["workflow"], {"id": "workflow-1"})
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/test_batch_dividers.py
================================================
import importlib.util
import sys
import types
import unittest
from pathlib import Path
import torch
def _load_utilities_module():
module_path = Path(__file__).resolve().parents[1] / "nodes" / "utilities.py"
package_name = "dist_divider_testpkg"
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
nodes_pkg = types.ModuleType(f"{package_name}.nodes")
nodes_pkg.__path__ = []
sys.modules[f"{package_name}.nodes"] = nodes_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
logging_module = types.ModuleType(f"{package_name}.utils.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
logging_module.log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.logging"] = logging_module
spec = importlib.util.spec_from_file_location(
f"{package_name}.nodes.utilities",
module_path,
)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
return module
utils = _load_utilities_module()
class ImageBatchDividerTests(unittest.TestCase):
def test_divides_images_into_contiguous_chunks(self):
divider = utils.ImageBatchDivider()
images = torch.arange(10, dtype=torch.float32).reshape(10, 1, 1, 1)
outputs = divider.divide_batch(images, 3)
self.assertEqual(outputs[0].shape[0], 4)
self.assertEqual(outputs[1].shape[0], 3)
self.assertEqual(outputs[2].shape[0], 3)
self.assertEqual(outputs[0][:, 0, 0, 0].tolist(), [0.0, 1.0, 2.0, 3.0])
self.assertEqual(outputs[1][:, 0, 0, 0].tolist(), [4.0, 5.0, 6.0])
self.assertEqual(outputs[2][:, 0, 0, 0].tolist(), [7.0, 8.0, 9.0])
def test_unused_image_outputs_are_empty(self):
divider = utils.ImageBatchDivider()
images = torch.arange(4, dtype=torch.float32).reshape(4, 1, 1, 1)
outputs = divider.divide_batch(images, 2)
self.assertEqual(len(outputs), 10)
for idx in range(2, 10):
self.assertEqual(outputs[idx].shape[0], 0)
class AudioBatchDividerTests(unittest.TestCase):
def test_divides_audio_samples_into_contiguous_chunks(self):
divider = utils.AudioBatchDivider()
audio = {
"waveform": torch.arange(10, dtype=torch.float32).reshape(1, 1, 10),
"sample_rate": 24000,
}
outputs = divider.divide_audio(audio, 3)
self.assertEqual(outputs[0]["waveform"][0, 0].tolist(), [0.0, 1.0, 2.0, 3.0])
self.assertEqual(outputs[1]["waveform"][0, 0].tolist(), [4.0, 5.0, 6.0])
self.assertEqual(outputs[2]["waveform"][0, 0].tolist(), [7.0, 8.0, 9.0])
def test_unused_audio_outputs_are_empty(self):
divider = utils.AudioBatchDivider()
audio = {
"waveform": torch.arange(8, dtype=torch.float32).reshape(1, 1, 8),
"sample_rate": 24000,
}
outputs = divider.divide_audio(audio, 2)
self.assertEqual(len(outputs), 10)
for idx in range(2, 10):
self.assertEqual(outputs[idx]["waveform"].shape[-1], 0)
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/test_config.py
================================================
import importlib.util
import json
import os
import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest.mock import patch
def _load_config_module():
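    """Load utils/config.py under a throwaway package name.

    A no-op logging module and a constants stub (HEARTBEAT_TIMEOUT) are
    registered in sys.modules so the file imports without ComfyUI.
    """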
module_path = Path(__file__).resolve().parents[1] / "utils" / "config.py"
package_name = "dist_cfg_testpkg"
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
logging_module = types.ModuleType(f"{package_name}.logging")
logging_module.log = lambda *_args, **_kwargs: None
logging_module.debug_log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.logging"] = logging_module
constants_module = types.ModuleType(f"{package_name}.constants")
constants_module.HEARTBEAT_TIMEOUT = 30
sys.modules[f"{package_name}.constants"] = constants_module
spec = importlib.util.spec_from_file_location(f"{package_name}.config", module_path)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
return module
config = _load_config_module()
# ---------------------------------------------------------------------------
# _merge_with_defaults
# ---------------------------------------------------------------------------
class MergeWithDefaultsTests(unittest.TestCase):
def test_non_dict_input_returns_defaults(self):
result = config._merge_with_defaults("not a dict", {"key": "default"})
self.assertEqual(result, {"key": "default"})
def test_fills_missing_keys_with_defaults(self):
result = config._merge_with_defaults({}, {"a": 1, "b": 2})
self.assertEqual(result, {"a": 1, "b": 2})
def test_loaded_value_overrides_default(self):
result = config._merge_with_defaults({"a": 99}, {"a": 1, "b": 2})
self.assertEqual(result["a"], 99)
self.assertEqual(result["b"], 2)
def test_nested_dict_merges_recursively(self):
defaults = {"settings": {"debug": False, "count": 5}}
loaded = {"settings": {"debug": True}}
result = config._merge_with_defaults(loaded, defaults)
self.assertTrue(result["settings"]["debug"])
self.assertEqual(result["settings"]["count"], 5)
def test_preserves_unknown_keys_for_forward_compatibility(self):
result = config._merge_with_defaults({"extra_key": "extra"}, {"a": 1})
self.assertEqual(result["extra_key"], "extra")
def test_none_loaded_value_overrides_default(self):
"""Explicitly set None in config should override non-None default."""
result = config._merge_with_defaults({"a": None}, {"a": "default"})
self.assertIsNone(result["a"])
def test_non_dict_nested_loaded_value_replaces_dict_default(self):
"""If loaded has a scalar where default has a dict, use the scalar."""
defaults = {"settings": {"debug": False}}
loaded = {"settings": "flat_string"}
result = config._merge_with_defaults(loaded, defaults)
self.assertEqual(result["settings"], "flat_string")
# ---------------------------------------------------------------------------
# load_config
# ---------------------------------------------------------------------------
class LoadConfigTests(unittest.TestCase):
def setUp(self):
config.invalidate_config_cache()
def tearDown(self):
config.invalidate_config_cache()
def test_returns_defaults_when_file_missing(self):
with patch.object(config, "CONFIG_FILE", "/nonexistent/path/config.json"):
cfg = config.load_config()
defaults = config.get_default_config()
self.assertEqual(cfg["settings"]["debug"], defaults["settings"]["debug"])
self.assertIn("workers", cfg)
def test_loads_valid_json_file(self):
data = {
"workers": [{"id": "w1"}],
"master": {"host": "test.host"},
"settings": {},
"tunnel": {},
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as f:
json.dump(data, f)
tmp_path = f.name
try:
with patch.object(config, "CONFIG_FILE", tmp_path):
cfg = config.load_config()
self.assertEqual(cfg["master"]["host"], "test.host")
self.assertEqual(len(cfg["workers"]), 1)
finally:
os.unlink(tmp_path)
def test_merges_loaded_file_with_defaults(self):
"""Loaded file with partial settings should be filled in from defaults."""
data = {"master": {"host": "h"}, "workers": [], "settings": {"debug": True}, "tunnel": {}}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as f:
json.dump(data, f)
tmp_path = f.name
try:
with patch.object(config, "CONFIG_FILE", tmp_path):
cfg = config.load_config()
# debug was set to True
self.assertTrue(cfg["settings"]["debug"])
# auto_launch_workers is a default key and should be present
self.assertIn("auto_launch_workers", cfg["settings"])
finally:
os.unlink(tmp_path)
def test_falls_back_to_defaults_on_invalid_json(self):
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as f:
f.write("{invalid json{{")
tmp_path = f.name
try:
with patch.object(config, "CONFIG_FILE", tmp_path):
cfg = config.load_config()
self.assertIn("settings", cfg)
self.assertIn("workers", cfg)
finally:
os.unlink(tmp_path)
def test_second_call_returns_cached_object(self):
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as f:
json.dump(config.get_default_config(), f)
tmp_path = f.name
try:
with patch.object(config, "CONFIG_FILE", tmp_path):
cfg1 = config.load_config()
cfg2 = config.load_config()
self.assertIs(cfg1, cfg2)
finally:
os.unlink(tmp_path)
def test_invalidate_cache_forces_reload(self):
data = config.get_default_config()
data["master"]["host"] = "first"
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as f:
json.dump(data, f)
tmp_path = f.name
try:
with patch.object(config, "CONFIG_FILE", tmp_path):
cfg1 = config.load_config()
config.invalidate_config_cache()
data["master"]["host"] = "second"
with open(tmp_path, "w", encoding="utf-8") as fh:
json.dump(data, fh)
cfg2 = config.load_config()
self.assertEqual(cfg1["master"]["host"], "first")
self.assertEqual(cfg2["master"]["host"], "second")
finally:
os.unlink(tmp_path)
# ---------------------------------------------------------------------------
# save_config
# ---------------------------------------------------------------------------
class SaveConfigTests(unittest.TestCase):
def setUp(self):
config.invalidate_config_cache()
def tearDown(self):
config.invalidate_config_cache()
def test_saves_and_reloads_correctly(self):
data = config.get_default_config()
data["master"]["host"] = "saved.host"
with tempfile.TemporaryDirectory() as tmpdir:
tmp_path = os.path.join(tmpdir, "config.json")
with patch.object(config, "CONFIG_FILE", tmp_path):
result = config.save_config(data)
self.assertTrue(result)
loaded = config.load_config()
self.assertEqual(loaded["master"]["host"], "saved.host")
def test_returns_false_when_path_unwritable(self):
with patch.object(config, "CONFIG_FILE", "/nonexistent_dir/config.json"):
result = config.save_config({})
self.assertFalse(result)
def test_save_invalidates_cache(self):
"""After saving, the cache should be cleared so next load re-reads."""
with tempfile.TemporaryDirectory() as tmpdir:
tmp_path = os.path.join(tmpdir, "config.json")
with patch.object(config, "CONFIG_FILE", tmp_path):
data = config.get_default_config()
config.save_config(data)
# Cache is now None; load_config should re-read
self.assertIsNone(config._config_cache)
def test_written_file_is_valid_json(self):
data = config.get_default_config()
with tempfile.TemporaryDirectory() as tmpdir:
tmp_path = os.path.join(tmpdir, "config.json")
with patch.object(config, "CONFIG_FILE", tmp_path):
config.save_config(data)
with open(tmp_path, encoding="utf-8") as fh:
parsed = json.load(fh)
self.assertEqual(parsed["master"], data["master"])
# ---------------------------------------------------------------------------
# get_worker_timeout_seconds
# ---------------------------------------------------------------------------
class GetWorkerTimeoutSecondsTests(unittest.TestCase):
def test_returns_configured_value(self):
cfg = config.get_default_config()
cfg["settings"]["worker_timeout_seconds"] = 120
with patch.object(config, "load_config", return_value=cfg):
self.assertEqual(config.get_worker_timeout_seconds(), 120)
def test_clamps_zero_to_one(self):
cfg = config.get_default_config()
cfg["settings"]["worker_timeout_seconds"] = 0
with patch.object(config, "load_config", return_value=cfg):
self.assertEqual(config.get_worker_timeout_seconds(), 1)
def test_clamps_negative_to_one(self):
cfg = config.get_default_config()
cfg["settings"]["worker_timeout_seconds"] = -10
with patch.object(config, "load_config", return_value=cfg):
self.assertEqual(config.get_worker_timeout_seconds(), 1)
def test_falls_back_to_provided_default_when_key_missing(self):
cfg = config.get_default_config()
        # Ensure worker_timeout_seconds is absent so the provided default applies
cfg["settings"].pop("worker_timeout_seconds", None)
with patch.object(config, "load_config", return_value=cfg):
result = config.get_worker_timeout_seconds(default=45)
self.assertEqual(result, 45)
def test_fallback_also_clamped_to_one(self):
cfg = config.get_default_config()
cfg["settings"].pop("worker_timeout_seconds", None)
with patch.object(config, "load_config", return_value=cfg):
result = config.get_worker_timeout_seconds(default=0)
self.assertEqual(result, 1)
# ---------------------------------------------------------------------------
# is_master_delegate_only
# ---------------------------------------------------------------------------
class IsMasterDelegateOnlyTests(unittest.TestCase):
def test_returns_false_by_default(self):
cfg = config.get_default_config()
with patch.object(config, "load_config", return_value=cfg):
self.assertFalse(config.is_master_delegate_only())
def test_returns_true_when_enabled(self):
cfg = config.get_default_config()
cfg["settings"]["master_delegate_only"] = True
with patch.object(config, "load_config", return_value=cfg):
self.assertTrue(config.is_master_delegate_only())
    def test_returns_false_on_exception(self):
        with patch.object(config, "load_config", side_effect=RuntimeError("boom")):
            self.assertFalse(config.is_master_delegate_only())
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/test_detection.py
================================================
import importlib.util
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import patch
def _load_detection_module():
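    """Load workers/detection.py in isolation.

    Network helpers are stubbed so any accidental real I/O raises, and an
    aiohttp stub is installed just for the import when the real package is
    absent, then removed again.
    """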
module_path = Path(__file__).resolve().parents[1] / "workers" / "detection.py"
package_name = "dist_det_testpkg"
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
workers_pkg = types.ModuleType(f"{package_name}.workers")
workers_pkg.__path__ = []
sys.modules[f"{package_name}.workers"] = workers_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
logging_module = types.ModuleType(f"{package_name}.utils.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.logging"] = logging_module
network_module = types.ModuleType(f"{package_name}.utils.network")
network_module.normalize_host = lambda value: value
async def _fake_session():
raise RuntimeError("network calls not used in these tests")
network_module.get_client_session = _fake_session
sys.modules[f"{package_name}.utils.network"] = network_module
created_aiohttp_stub = False
if "aiohttp" not in sys.modules:
created_aiohttp_stub = True
aiohttp_module = types.ModuleType("aiohttp")
class _ClientTimeout:
def __init__(self, total=None):
pass
aiohttp_module.ClientTimeout = _ClientTimeout
sys.modules["aiohttp"] = aiohttp_module
spec = importlib.util.spec_from_file_location(
f"{package_name}.workers.detection",
module_path,
)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
if created_aiohttp_stub:
sys.modules.pop("aiohttp", None)
return module
detection = _load_detection_module()
# ---------------------------------------------------------------------------
# is_docker_environment
# ---------------------------------------------------------------------------
class IsDockerEnvironmentTests(unittest.TestCase):
def test_true_when_dockerenv_file_exists(self):
with patch.object(detection.os.path, "exists", return_value=True), \
patch.dict(detection.os.environ, {}, clear=True), \
patch.object(detection.platform, "node", return_value="my-laptop"):
self.assertTrue(detection.is_docker_environment())
def test_true_when_docker_container_env_var_is_set(self):
with patch.object(detection.os.path, "exists", return_value=False), \
patch.dict(detection.os.environ, {"DOCKER_CONTAINER": "1"}, clear=True), \
patch.object(detection.platform, "node", return_value="my-laptop"):
self.assertTrue(detection.is_docker_environment())
def test_true_when_platform_node_contains_docker(self):
with patch.object(detection.os.path, "exists", return_value=False), \
patch.dict(detection.os.environ, {}, clear=True), \
patch.object(detection.platform, "node", return_value="my-docker-host"):
self.assertTrue(detection.is_docker_environment())
def test_false_when_none_of_the_signals_are_present(self):
with patch.object(detection.os.path, "exists", return_value=False), \
patch.dict(detection.os.environ, {}, clear=True), \
patch.object(detection.platform, "node", return_value="my-laptop"):
self.assertFalse(detection.is_docker_environment())
def test_docker_node_name_is_case_insensitive(self):
with patch.object(detection.os.path, "exists", return_value=False), \
patch.dict(detection.os.environ, {}, clear=True), \
patch.object(detection.platform, "node", return_value="My-Docker-Box"):
self.assertTrue(detection.is_docker_environment())
def test_docker_env_var_empty_string_is_falsy(self):
"""An empty DOCKER_CONTAINER env var should NOT trigger docker detection."""
with patch.object(detection.os.path, "exists", return_value=False), \
patch.dict(detection.os.environ, {"DOCKER_CONTAINER": ""}, clear=True), \
patch.object(detection.platform, "node", return_value="my-laptop"):
self.assertFalse(detection.is_docker_environment())
# ---------------------------------------------------------------------------
# is_runpod_environment
# ---------------------------------------------------------------------------
class IsRunpodEnvironmentTests(unittest.TestCase):
def test_true_when_runpod_pod_id_is_set(self):
with patch.dict(detection.os.environ, {"RUNPOD_POD_ID": "pod-abc"}, clear=True):
self.assertTrue(detection.is_runpod_environment())
def test_true_when_runpod_api_key_is_set(self):
with patch.dict(detection.os.environ, {"RUNPOD_API_KEY": "key-xyz"}, clear=True):
self.assertTrue(detection.is_runpod_environment())
def test_true_when_both_vars_are_set(self):
with patch.dict(
detection.os.environ,
{"RUNPOD_POD_ID": "pod-abc", "RUNPOD_API_KEY": "key-xyz"},
clear=True,
):
self.assertTrue(detection.is_runpod_environment())
def test_false_when_neither_var_is_set(self):
with patch.dict(detection.os.environ, {}, clear=True):
self.assertFalse(detection.is_runpod_environment())
def test_true_when_pod_id_is_empty_string(self):
"""is not None check means even empty string counts as detected."""
with patch.dict(detection.os.environ, {"RUNPOD_POD_ID": ""}, clear=True):
self.assertTrue(detection.is_runpod_environment())
# ---------------------------------------------------------------------------
# is_local_worker (paths that resolve without any network I/O)
# ---------------------------------------------------------------------------
class IsLocalWorkerTests(unittest.IsolatedAsyncioTestCase):
async def test_true_for_localhost_host(self):
result = await detection.is_local_worker({"host": "localhost", "port": 8188})
self.assertTrue(result)
async def test_true_for_127_0_0_1(self):
result = await detection.is_local_worker({"host": "127.0.0.1", "port": 8188})
self.assertTrue(result)
async def test_true_for_0_0_0_0(self):
result = await detection.is_local_worker({"host": "0.0.0.0", "port": 8188})
self.assertTrue(result)
async def test_true_when_type_is_local(self):
result = await detection.is_local_worker({"type": "local", "host": "remote.example.com"})
self.assertTrue(result)
async def test_false_for_remote_host(self):
result = await detection.is_local_worker({"host": "remote.example.com", "port": 8188})
self.assertFalse(result)
async def test_true_when_no_host_key(self):
"""Missing host defaults to 'localhost'."""
result = await detection.is_local_worker({"port": 8188})
self.assertTrue(result)
# ---------------------------------------------------------------------------
# get_machine_id
# ---------------------------------------------------------------------------
class GetMachineIdTests(unittest.TestCase):
def test_returns_a_string(self):
result = detection.get_machine_id()
self.assertIsInstance(result, str)
def test_returns_non_empty_string(self):
result = detection.get_machine_id()
self.assertTrue(len(result) > 0)
def test_stable_across_calls(self):
r1 = detection.get_machine_id()
r2 = detection.get_machine_id()
self.assertEqual(r1, r2)
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/test_dispatch_selection.py
================================================
import asyncio
import importlib.util
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import patch
def _load_dispatch_module():
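    """Load api/orchestration/dispatch.py in isolation.

    Logging and network helpers are stubbed; get_client_session raises so
    tests must patch the probe functions they exercise. When aiohttp is not
    installed, a stub is registered for the import and removed afterwards.
    """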
module_path = Path(__file__).resolve().parents[1] / "api" / "orchestration" / "dispatch.py"
package_name = "dist_dispatch_testpkg"
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
api_pkg = types.ModuleType(f"{package_name}.api")
api_pkg.__path__ = []
sys.modules[f"{package_name}.api"] = api_pkg
orch_pkg = types.ModuleType(f"{package_name}.api.orchestration")
orch_pkg.__path__ = []
sys.modules[f"{package_name}.api.orchestration"] = orch_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
logging_module = types.ModuleType(f"{package_name}.utils.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
logging_module.log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.logging"] = logging_module
network_module = types.ModuleType(f"{package_name}.utils.network")
network_module.build_worker_url = lambda *_args, **_kwargs: "http://example.invalid"
async def _probe_worker(*_args, **_kwargs):
return None
network_module.probe_worker = _probe_worker
async def _fake_session():
raise RuntimeError("get_client_session should be mocked in these tests")
network_module.get_client_session = _fake_session
sys.modules[f"{package_name}.utils.network"] = network_module
created_aiohttp_stub = False
if "aiohttp" not in sys.modules:
created_aiohttp_stub = True
aiohttp_module = types.ModuleType("aiohttp")
class _ClientTimeout:
def __init__(self, total=None):
self.total = total
class _ClientConnectorError(Exception):
pass
class _WSMsgType:
TEXT = "TEXT"
ERROR = "ERROR"
CLOSED = "CLOSED"
class _TCPConnector:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
class _ClientSession:
def __init__(self, *args, **kwargs):
self.closed = False
async def close(self):
self.closed = True
aiohttp_module.ClientTimeout = _ClientTimeout
aiohttp_module.ClientConnectorError = _ClientConnectorError
aiohttp_module.WSMsgType = _WSMsgType
aiohttp_module.TCPConnector = _TCPConnector
aiohttp_module.ClientSession = _ClientSession
aiohttp_module.web = types.SimpleNamespace(
json_response=lambda payload, status=200: {"payload": payload, "status": status}
)
sys.modules["aiohttp"] = aiohttp_module
spec = importlib.util.spec_from_file_location(
f"{package_name}.api.orchestration.dispatch",
module_path,
)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
if created_aiohttp_stub:
sys.modules.pop("aiohttp", None)
return module
dispatch = _load_dispatch_module()
class DispatchSelectionTests(unittest.IsolatedAsyncioTestCase):
async def test_select_active_workers_filters_offline(self):
workers = [
{"id": "w1", "name": "Worker 1"},
{"id": "w2", "name": "Worker 2"},
{"id": "w3", "name": "Worker 3"},
]
async def fake_probe(worker):
return worker["id"] != "w2"
with patch.object(dispatch, "worker_is_active", side_effect=fake_probe):
active_workers, delegate_master = await dispatch.select_active_workers(
workers,
use_websocket=False,
delegate_master=False,
probe_concurrency=3,
)
self.assertEqual([w["id"] for w in active_workers], ["w1", "w3"])
self.assertFalse(delegate_master)
async def test_select_active_workers_disables_delegate_when_all_offline(self):
workers = [{"id": "w1", "name": "Worker 1"}]
async def fake_probe(_worker):
return False
with patch.object(dispatch, "worker_is_active", side_effect=fake_probe):
active_workers, delegate_master = await dispatch.select_active_workers(
workers,
use_websocket=False,
delegate_master=True,
probe_concurrency=1,
)
self.assertEqual(active_workers, [])
self.assertFalse(delegate_master)
async def test_select_active_workers_uses_websocket_probe_when_enabled(self):
workers = [{"id": "w1", "name": "Worker 1"}, {"id": "w2", "name": "Worker 2"}]
async def fake_http_probe(_worker):
return False
async def fake_ws_probe(_worker):
return True
with patch.object(dispatch, "worker_is_active", side_effect=fake_http_probe) as http_probe, patch.object(
dispatch,
"worker_ws_is_active",
side_effect=fake_ws_probe,
) as ws_probe:
active_workers, _ = await dispatch.select_active_workers(
workers,
use_websocket=True,
delegate_master=False,
probe_concurrency=4,
)
self.assertEqual([w["id"] for w in active_workers], ["w1", "w2"])
self.assertEqual(ws_probe.call_count, 2)
self.assertEqual(http_probe.call_count, 0)
async def test_probe_concurrency_is_bounded(self):
workers = [{"id": f"w{i}", "name": f"Worker {i}"} for i in range(6)]
state = {"in_flight": 0, "max_in_flight": 0}
async def fake_probe(_worker):
state["in_flight"] += 1
state["max_in_flight"] = max(state["max_in_flight"], state["in_flight"])
await asyncio.sleep(0.01)
state["in_flight"] -= 1
return True
with patch.object(dispatch, "worker_is_active", side_effect=fake_probe):
active_workers, _ = await dispatch.select_active_workers(
workers,
use_websocket=False,
delegate_master=False,
probe_concurrency=2,
)
self.assertEqual(len(active_workers), len(workers))
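        # With six workers, a 10 ms fake probe, and probe_concurrency=2, the cap
        # must be respected (<= 2) and actually saturated (>= 2).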
self.assertLessEqual(state["max_in_flight"], 2)
self.assertGreaterEqual(state["max_in_flight"], 2)
async def test_select_least_busy_worker_round_robins_idle_workers(self):
workers = [
{"id": "w1", "name": "Worker 1"},
{"id": "w2", "name": "Worker 2"},
{"id": "w3", "name": "Worker 3"},
]
queue_map = {"w1": 0, "w2": 0, "w3": 2}
async def fake_probe(worker_url, timeout=3.0):
worker_id = worker_url.rsplit("/", 1)[-1]
return {"exec_info": {"queue_remaining": queue_map[worker_id]}}
with patch.object(dispatch, "build_worker_url", side_effect=lambda worker: f"http://host/{worker['id']}"), patch.object(
dispatch,
"probe_worker",
side_effect=fake_probe,
):
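            # Reset the module-level round-robin cursor so the idle-worker
            # selection order below is deterministic.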
dispatch._least_busy_rr_index = 0
selected1 = await dispatch.select_least_busy_worker(workers, probe_concurrency=3)
selected2 = await dispatch.select_least_busy_worker(workers, probe_concurrency=3)
selected3 = await dispatch.select_least_busy_worker(workers, probe_concurrency=3)
self.assertEqual(selected1["id"], "w1")
self.assertEqual(selected2["id"], "w2")
self.assertEqual(selected3["id"], "w1")
async def test_select_least_busy_worker_chooses_smallest_queue_when_all_busy(self):
workers = [
{"id": "w1", "name": "Worker 1"},
{"id": "w2", "name": "Worker 2"},
{"id": "w3", "name": "Worker 3"},
]
queue_map = {"w1": 5, "w2": 2, "w3": 4}
async def fake_probe(worker_url, timeout=3.0):
worker_id = worker_url.rsplit("/", 1)[-1]
return {"exec_info": {"queue_remaining": queue_map[worker_id]}}
with patch.object(dispatch, "build_worker_url", side_effect=lambda worker: f"http://host/{worker['id']}"), patch.object(
dispatch,
"probe_worker",
side_effect=fake_probe,
):
selected = await dispatch.select_least_busy_worker(workers, probe_concurrency=2)
self.assertEqual(selected["id"], "w2")
async def test_select_least_busy_worker_returns_none_when_all_probes_fail(self):
workers = [{"id": "w1", "name": "Worker 1"}]
async def fake_probe(_worker_url, timeout=3.0):
return None
with patch.object(dispatch, "build_worker_url", side_effect=lambda worker: f"http://host/{worker['id']}"), patch.object(
dispatch,
"probe_worker",
side_effect=fake_probe,
):
selected = await dispatch.select_least_busy_worker(workers, probe_concurrency=1)
self.assertIsNone(selected)
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/test_distributed_value.py
================================================
import json
import unittest
class DistributedValueTests(unittest.TestCase):
"""Unit tests for the DistributedValue node's distribute() method."""
def _make_node(self):
        # Import inline to avoid triggering the plugin's package-level imports
import importlib.util
import sys
import types
from pathlib import Path
from unittest.mock import MagicMock
module_path = Path(__file__).resolve().parents[1] / "nodes" / "utilities.py"
pkg_name = "dv_test_pkg"
for mod_name in list(sys.modules):
if mod_name == pkg_name or mod_name.startswith(f"{pkg_name}."):
del sys.modules[mod_name]
# Mock torch if not available
if "torch" not in sys.modules:
sys.modules["torch"] = MagicMock()
root_pkg = types.ModuleType(pkg_name)
root_pkg.__path__ = []
sys.modules[pkg_name] = root_pkg
utils_pkg = types.ModuleType(f"{pkg_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{pkg_name}.utils"] = utils_pkg
logging_mod = types.ModuleType(f"{pkg_name}.utils.logging")
logging_mod.debug_log = lambda *_a, **_k: None
logging_mod.log = lambda *_a, **_k: None
sys.modules[f"{pkg_name}.utils.logging"] = logging_mod
spec = importlib.util.spec_from_file_location(
f"{pkg_name}.nodes.utilities", module_path
)
        mod = importlib.util.module_from_spec(spec)
        assert spec is not None and spec.loader is not None
        spec.loader.exec_module(mod)
return mod.DistributedValue()
def setUp(self):
self.node = self._make_node()
def test_master_returns_default(self):
result = self.node.distribute(
default_value="model_a",
worker_values="{}",
is_worker=False,
worker_id="",
)
self.assertEqual(result, ("model_a",))
def test_master_coerces_default_int(self):
values = json.dumps({"_type": "INT"})
result = self.node.distribute(
default_value="42",
worker_values=values,
is_worker=False,
worker_id="",
)
self.assertEqual(result, (42,))
self.assertIsInstance(result[0], int)
def test_master_coerces_default_float(self):
values = json.dumps({"_type": "FLOAT"})
result = self.node.distribute(
default_value="2.5",
worker_values=values,
is_worker=False,
worker_id="",
)
self.assertEqual(result, (2.5,))
self.assertIsInstance(result[0], float)
def test_worker_returns_specific_value(self):
values = json.dumps({"1": "model_x", "2": "model_y"})
result = self.node.distribute(
default_value="default",
worker_values=values,
is_worker=True,
worker_id="worker_0",
)
self.assertEqual(result, ("model_x",))
def test_worker_second_index(self):
values = json.dumps({"1": "model_x", "2": "model_y"})
result = self.node.distribute(
default_value="default",
worker_values=values,
is_worker=True,
worker_id="worker_1",
)
self.assertEqual(result, ("model_y",))
def test_worker_falls_back_to_default_when_key_missing(self):
values = json.dumps({"_type": "INT", "1": "3"})
result = self.node.distribute(
default_value="9",
worker_values=values,
is_worker=True,
worker_id="worker_5",
)
self.assertEqual(result, (9,))
self.assertIsInstance(result[0], int)
def test_worker_falls_back_to_default_on_empty_value(self):
values = json.dumps({"1": ""})
result = self.node.distribute(
default_value="fallback",
worker_values=values,
is_worker=True,
worker_id="worker_0",
)
self.assertEqual(result, ("fallback",))
def test_worker_falls_back_on_invalid_json(self):
result = self.node.distribute(
default_value="safe",
worker_values="not-json",
is_worker=True,
worker_id="worker_0",
)
self.assertEqual(result, ("safe",))
def test_worker_falls_back_on_invalid_worker_id(self):
values = json.dumps({"1": "model_x"})
result = self.node.distribute(
default_value="safe",
worker_values=values,
is_worker=True,
worker_id="bad_id",
)
self.assertEqual(result, ("safe",))
def test_worker_id_as_direct_integer(self):
values = json.dumps({"1": "model_x"})
result = self.node.distribute(
default_value="default",
worker_values=values,
is_worker=True,
worker_id="0",
)
self.assertEqual(result, ("model_x",))
def test_type_int_coerces_value(self):
values = json.dumps({"_type": "INT", "1": "42"})
result = self.node.distribute(
default_value="0",
worker_values=values,
is_worker=True,
worker_id="worker_0",
)
self.assertEqual(result, (42,))
self.assertIsInstance(result[0], int)
def test_type_float_coerces_value(self):
values = json.dumps({"_type": "FLOAT", "1": "3.14"})
result = self.node.distribute(
default_value="0",
worker_values=values,
is_worker=True,
worker_id="worker_0",
)
self.assertAlmostEqual(result[0], 3.14)
self.assertIsInstance(result[0], float)
def test_type_combo_stays_string(self):
values = json.dumps({"_type": "COMBO", "1": "model_v2"})
result = self.node.distribute(
default_value="model_v1",
worker_values=values,
is_worker=True,
worker_id="worker_0",
)
self.assertEqual(result, ("model_v2",))
self.assertIsInstance(result[0], str)
def test_type_string_default_stays_string(self):
values = json.dumps({"1": "hello"})
result = self.node.distribute(
default_value="default",
worker_values=values,
is_worker=True,
worker_id="worker_0",
)
self.assertEqual(result, ("hello",))
self.assertIsInstance(result[0], str)
def test_int_coerce_handles_float_string(self):
"""INT coercion of '3.7' should truncate to 3."""
values = json.dumps({"_type": "INT", "1": "3.7"})
result = self.node.distribute(
default_value="0",
worker_values=values,
is_worker=True,
worker_id="worker_0",
)
self.assertEqual(result, (3,))
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/test_job_timeout.py
================================================
import asyncio
import importlib.util
import sys
import time
import types
import unittest
from dataclasses import dataclass, field
from pathlib import Path
def _load_job_timeout_module():
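    """Load upscale/job_timeout.py with swappable test seams.

    The returned module exposes _config_holder, _probe_holder and
    _prompt_server_holder so each test can swap the config, the worker
    probe, and the prompt-server state, plus stub ImageJobState and
    TileJobState dataclasses standing in for the real job models.
    """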
module_path = Path(__file__).resolve().parents[1] / "upscale" / "job_timeout.py"
package_name = "dist_job_timeout_testpkg"
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
upscale_pkg = types.ModuleType(f"{package_name}.upscale")
upscale_pkg.__path__ = []
sys.modules[f"{package_name}.upscale"] = upscale_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
config_holder = {"value": {"settings": {}, "workers": []}}
probe_holder = {"fn": None}
prompt_server_holder = {
"value": types.SimpleNamespace(
distributed_tile_jobs_lock=asyncio.Lock(),
distributed_pending_tile_jobs={},
)
}
config_module = types.ModuleType(f"{package_name}.utils.config")
config_module.load_config = lambda: config_holder["value"]
sys.modules[f"{package_name}.utils.config"] = config_module
constants_module = types.ModuleType(f"{package_name}.utils.constants")
constants_module.HEARTBEAT_TIMEOUT = 60
sys.modules[f"{package_name}.utils.constants"] = constants_module
logging_module = types.ModuleType(f"{package_name}.utils.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
logging_module.log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.logging"] = logging_module
network_module = types.ModuleType(f"{package_name}.utils.network")
network_module.build_worker_url = lambda worker: f"http://{worker.get('host', '127.0.0.1')}:{worker.get('port', 8188)}"
async def _probe_worker(url, timeout=2.0):
fn = probe_holder["fn"]
if fn is None:
return None
return await fn(url, timeout)
network_module.probe_worker = _probe_worker
sys.modules[f"{package_name}.utils.network"] = network_module
job_store_module = types.ModuleType(f"{package_name}.upscale.job_store")
job_store_module.ensure_tile_jobs_initialized = lambda: prompt_server_holder["value"]
sys.modules[f"{package_name}.upscale.job_store"] = job_store_module
job_models_module = types.ModuleType(f"{package_name}.upscale.job_models")
class BaseJobState:
pass
@dataclass
class ImageJobState(BaseJobState):
multi_job_id: str
mode: str = field(default="dynamic", init=False)
queue: asyncio.Queue = field(default_factory=asyncio.Queue)
pending_images: asyncio.Queue = field(default_factory=asyncio.Queue)
completed_images: dict = field(default_factory=dict)
worker_status: dict = field(default_factory=dict)
assigned_to_workers: dict = field(default_factory=dict)
batch_size: int = 0
num_tiles_per_image: int = 0
batched_static: bool = False
@property
def pending_tasks(self):
return self.pending_images
@property
def completed_tasks(self):
return self.completed_images
@dataclass
class TileJobState(BaseJobState):
multi_job_id: str
mode: str = field(default="static", init=False)
queue: asyncio.Queue = field(default_factory=asyncio.Queue)
pending_tasks: asyncio.Queue = field(default_factory=asyncio.Queue)
completed_tasks: dict = field(default_factory=dict)
worker_status: dict = field(default_factory=dict)
assigned_to_workers: dict = field(default_factory=dict)
batch_size: int = 0
num_tiles_per_image: int = 0
batched_static: bool = False
job_models_module.BaseJobState = BaseJobState
job_models_module.ImageJobState = ImageJobState
job_models_module.TileJobState = TileJobState
sys.modules[f"{package_name}.upscale.job_models"] = job_models_module
spec = importlib.util.spec_from_file_location(f"{package_name}.upscale.job_timeout", module_path)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
module._config_holder = config_holder
module._probe_holder = probe_holder
module._prompt_server_holder = prompt_server_holder
module._ImageJobState = ImageJobState
module._TileJobState = TileJobState
return module
jt = _load_job_timeout_module()
class JobTimeoutRequeueTests(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
jt._prompt_server_holder["value"] = types.SimpleNamespace(
distributed_tile_jobs_lock=asyncio.Lock(),
distributed_pending_tile_jobs={},
)
jt._config_holder["value"] = {
"settings": {"worker_timeout_seconds": 5},
"workers": [{"id": "worker-1", "host": "worker.local", "port": 8188}],
}
async def test_requeues_only_incomplete_dynamic_tasks_for_timed_out_worker(self):
async def _offline_probe(_url, _timeout):
return None
jt._probe_holder["fn"] = _offline_probe
prompt_server = jt._prompt_server_holder["value"]
job_data = jt._ImageJobState("job-1")
job_data.worker_status["worker-1"] = time.time() - 60.0
job_data.assigned_to_workers["worker-1"] = [0, 1]
job_data.completed_images[1] = "done"
prompt_server.distributed_pending_tile_jobs["job-1"] = job_data
requeued = await jt._check_and_requeue_timed_out_workers("job-1", total_tasks=2)
self.assertEqual(requeued, 1)
self.assertEqual(await job_data.pending_images.get(), 0)
self.assertNotIn("worker-1", job_data.worker_status)
self.assertEqual(job_data.assigned_to_workers["worker-1"], [])
async def test_busy_probe_graces_worker_and_skips_requeue(self):
async def _busy_probe(_url, _timeout):
return {"exec_info": {"queue_remaining": 3}}
jt._probe_holder["fn"] = _busy_probe
prompt_server = jt._prompt_server_holder["value"]
job_data = jt._ImageJobState("job-2")
old_heartbeat = time.time() - 60.0
job_data.worker_status["worker-1"] = old_heartbeat
job_data.assigned_to_workers["worker-1"] = [0]
prompt_server.distributed_pending_tile_jobs["job-2"] = job_data
requeued = await jt._check_and_requeue_timed_out_workers("job-2", total_tasks=1)
self.assertEqual(requeued, 0)
self.assertIn("worker-1", job_data.worker_status)
self.assertGreaterEqual(job_data.worker_status["worker-1"], old_heartbeat)
self.assertTrue(job_data.pending_images.empty())
async def test_completed_dynamic_task_is_not_requeued(self):
async def _offline_probe(_url, _timeout):
return None
jt._probe_holder["fn"] = _offline_probe
prompt_server = jt._prompt_server_holder["value"]
job_data = jt._ImageJobState("job-3")
job_data.worker_status["worker-1"] = time.time() - 60.0
job_data.assigned_to_workers["worker-1"] = [7]
job_data.completed_images[7] = "complete"
prompt_server.distributed_pending_tile_jobs["job-3"] = job_data
requeued = await jt._check_and_requeue_timed_out_workers("job-3", total_tasks=1)
self.assertEqual(requeued, 0)
self.assertTrue(job_data.pending_images.empty())
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/test_network_helpers.py
================================================
import importlib.util
import sys
import types
import unittest
from pathlib import Path
def _load_network_module():
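    """Load utils/network.py against a stub server.PromptServer.

    The stub pins address 127.0.0.1 and port 8188 so URL-building tests
    have a fixed runtime fallback; aiohttp is stubbed only when the real
    package is not installed.
    """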
module_path = Path(__file__).resolve().parents[1] / "utils" / "network.py"
package_name = "dist_utils_testpkg"
package_module = types.ModuleType(package_name)
package_module.__path__ = [] # mark as package
sys.modules[package_name] = package_module
logging_module = types.ModuleType(f"{package_name}.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.logging"] = logging_module
server_module = types.ModuleType("server")
server_module.PromptServer = types.SimpleNamespace(
instance=types.SimpleNamespace(address="127.0.0.1", port=8188, loop=None)
)
sys.modules["server"] = server_module
if "aiohttp" not in sys.modules:
aiohttp_module = types.ModuleType("aiohttp")
class _TCPConnector:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
class _ClientSession:
def __init__(self, *args, **kwargs):
self.closed = False
async def close(self):
self.closed = True
aiohttp_module.TCPConnector = _TCPConnector
aiohttp_module.ClientSession = _ClientSession
aiohttp_module.web = types.SimpleNamespace(
json_response=lambda payload, status=200: {"payload": payload, "status": status}
)
sys.modules["aiohttp"] = aiohttp_module
spec = importlib.util.spec_from_file_location(f"{package_name}.network", module_path)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
return module
network = _load_network_module()
class NetworkHelpersTests(unittest.TestCase):
def test_normalize_host_strips_protocol_and_path(self):
self.assertEqual(network.normalize_host(" https://example.com/a/b "), "example.com")
def test_normalize_host_keeps_none(self):
self.assertIsNone(network.normalize_host(None))
def test_build_worker_url_defaults_to_server_address(self):
worker = {"id": "w1", "port": 8189}
self.assertEqual(network.build_worker_url(worker, "/prompt"), "http://127.0.0.1:8189/prompt")
def test_build_worker_url_cloud_defaults_to_https(self):
worker = {"id": "w2", "host": "foo.proxy.runpod.net", "port": 443}
self.assertEqual(network.build_worker_url(worker), "https://foo.proxy.runpod.net")
def test_build_worker_url_keeps_explicit_scheme(self):
worker = {"id": "w3", "host": "https://worker.example.com", "port": 1234}
self.assertEqual(network.build_worker_url(worker, "/prompt"), "https://worker.example.com/prompt")
def test_build_master_url_uses_https_for_cloud_host(self):
cfg = {"master": {"host": "demo.proxy.runpod.net"}}
prompt_server = types.SimpleNamespace(address="127.0.0.1", port=8188)
self.assertEqual(
network.build_master_url(config=cfg, prompt_server_instance=prompt_server),
"https://demo.proxy.runpod.net",
)
def test_build_master_url_keeps_explicit_scheme(self):
cfg = {"master": {"host": "https://master.example.com/"}}
prompt_server = types.SimpleNamespace(address="127.0.0.1", port=8188)
self.assertEqual(
network.build_master_url(config=cfg, prompt_server_instance=prompt_server),
"https://master.example.com",
)
def test_build_master_url_ignores_stale_saved_port_and_uses_runtime_port(self):
cfg = {"master": {"host": "192.168.68.56", "port": 8001}}
prompt_server = types.SimpleNamespace(address="127.0.0.1", port=8188)
self.assertEqual(
network.build_master_url(config=cfg, prompt_server_instance=prompt_server),
"http://192.168.68.56:8188",
)
def test_build_master_url_keeps_explicit_port_in_host(self):
cfg = {"master": {"host": "192.168.68.56:8001"}}
prompt_server = types.SimpleNamespace(address="127.0.0.1", port=8188)
self.assertEqual(
network.build_master_url(config=cfg, prompt_server_instance=prompt_server),
"http://192.168.68.56:8001",
)
def test_build_master_url_falls_back_to_server_address(self):
cfg = {"master": {"host": "", "port": 8001}}
prompt_server = types.SimpleNamespace(address="0.0.0.0", port=8190)
self.assertEqual(
network.build_master_url(config=cfg, prompt_server_instance=prompt_server),
"http://127.0.0.1:8190",
)
def test_build_master_callback_url_uses_loopback_for_local_worker(self):
cfg = {"master": {"host": "192.168.68.56"}}
prompt_server = types.SimpleNamespace(address="127.0.0.1", port=8001)
worker = {"id": "w1", "type": "local", "host": "localhost", "port": 8189}
self.assertEqual(
network.build_master_callback_url(worker, config=cfg, prompt_server_instance=prompt_server),
"http://127.0.0.1:8001",
)
def test_build_master_callback_url_keeps_public_master_url_for_remote_worker(self):
cfg = {"master": {"host": "192.168.68.56"}}
prompt_server = types.SimpleNamespace(address="127.0.0.1", port=8001)
worker = {"id": "w2", "type": "remote", "host": "192.168.68.99", "port": 8189}
self.assertEqual(
network.build_master_callback_url(worker, config=cfg, prompt_server_instance=prompt_server),
"http://192.168.68.56:8001",
)
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/test_payload_parsers.py
================================================
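"""Tests for _parse_tiles_from_form in upscale/payload_parsers.py.

The whole suite is skipped when PIL is not installed.
"""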
import importlib.util
import io
import json
import sys
import types
import unittest
from pathlib import Path
try:
from PIL import Image as PILImage
PIL_AVAILABLE = True
except ImportError:
PIL_AVAILABLE = False
def _load_payload_parsers_module():
# payload_parsers.py has no relative imports; only stdlib + PIL
module_path = Path(__file__).resolve().parents[1] / "upscale" / "payload_parsers.py"
spec = importlib.util.spec_from_file_location("upscale_payload_parsers", module_path)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
return module
if PIL_AVAILABLE:
pp = _load_payload_parsers_module()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_png_bytes(width=64, height=64, color=(128, 64, 32)):
"""Return raw PNG bytes for a solid-colour image."""
img = PILImage.new("RGB", (width, height), color=color)
buf = io.BytesIO()
img.save(buf, format="PNG")
return buf.getvalue()
class _MockFileField:
"""Minimal multipart file-field stub."""
class _MockFile:
def __init__(self, data: bytes):
self._buf = io.BytesIO(data)
def read(self) -> bytes:
return self._buf.read()
def __init__(self, data: bytes):
self.file = self._MockFile(data)
def _make_form(n_tiles, *, padding=None, extra_meta=None, image_color=(128, 64, 32)):
"""Build a minimal form-data dict with `n_tiles` tile entries."""
image_bytes = _make_png_bytes(color=image_color)
metadata = []
for i in range(n_tiles):
entry = {
"tile_idx": i,
"x": i * 64,
"y": 0,
"extracted_width": 64,
"extracted_height": 64,
}
if extra_meta and i < len(extra_meta):
entry.update(extra_meta[i])
metadata.append(entry)
form = {"tiles_metadata": json.dumps(metadata)}
if padding is not None:
form["padding"] = str(padding)
for i in range(n_tiles):
form[f"tile_{i}"] = _MockFileField(image_bytes)
return form
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@unittest.skipUnless(PIL_AVAILABLE, "PIL not installed")
class ParseTilesFromFormTests(unittest.TestCase):
# --- happy paths ---
def test_single_tile_returns_one_entry(self):
tiles = pp._parse_tiles_from_form(_make_form(1))
self.assertEqual(len(tiles), 1)
def test_multiple_tiles_all_returned(self):
tiles = pp._parse_tiles_from_form(_make_form(3))
self.assertEqual(len(tiles), 3)
def test_tile_image_is_pil_image(self):
tiles = pp._parse_tiles_from_form(_make_form(1))
self.assertIsInstance(tiles[0]["image"], PILImage.Image)
def test_tile_metadata_fields_are_parsed(self):
tiles = pp._parse_tiles_from_form(_make_form(1))
tile = tiles[0]
self.assertEqual(tile["tile_idx"], 0)
self.assertEqual(tile["x"], 0)
self.assertEqual(tile["y"], 0)
self.assertEqual(tile["extracted_width"], 64)
self.assertEqual(tile["extracted_height"], 64)
def test_padding_is_parsed_from_form(self):
tiles = pp._parse_tiles_from_form(_make_form(1, padding=16))
self.assertEqual(tiles[0]["padding"], 16)
def test_default_padding_is_zero(self):
form = _make_form(1)
form.pop("padding", None)
tiles = pp._parse_tiles_from_form(form)
self.assertEqual(tiles[0]["padding"], 0)
def test_invalid_padding_string_falls_back_to_zero(self):
form = _make_form(1)
form["padding"] = "not_a_number"
tiles = pp._parse_tiles_from_form(form)
self.assertEqual(tiles[0]["padding"], 0)
def test_optional_batch_idx_included_when_present(self):
extra = [{"batch_idx": 2}]
tiles = pp._parse_tiles_from_form(_make_form(1, extra_meta=extra))
self.assertEqual(tiles[0]["batch_idx"], 2)
def test_optional_global_idx_included_when_present(self):
extra = [{"global_idx": 5}]
tiles = pp._parse_tiles_from_form(_make_form(1, extra_meta=extra))
self.assertEqual(tiles[0]["global_idx"], 5)
def test_batch_idx_and_global_idx_absent_when_not_in_metadata(self):
tiles = pp._parse_tiles_from_form(_make_form(1))
self.assertNotIn("batch_idx", tiles[0])
self.assertNotIn("global_idx", tiles[0])
def test_tile_indices_match_metadata_order(self):
tiles = pp._parse_tiles_from_form(_make_form(3))
for i, tile in enumerate(tiles):
self.assertEqual(tile["tile_idx"], i)
def test_x_coordinates_reflect_metadata(self):
tiles = pp._parse_tiles_from_form(_make_form(3))
self.assertEqual(tiles[1]["x"], 64)
self.assertEqual(tiles[2]["x"], 128)
# --- error cases ---
def test_missing_tiles_metadata_raises_value_error(self):
with self.assertRaises(ValueError, msg="Missing tiles_metadata"):
pp._parse_tiles_from_form({})
def test_invalid_json_metadata_raises_value_error(self):
form = {"tiles_metadata": "{not valid json}"}
with self.assertRaises(ValueError):
pp._parse_tiles_from_form(form)
def test_non_list_metadata_raises_value_error(self):
form = {"tiles_metadata": json.dumps({"not": "a list"})}
with self.assertRaises(ValueError):
pp._parse_tiles_from_form(form)
def test_missing_tile_file_field_raises_value_error(self):
form = {
"tiles_metadata": json.dumps([{"tile_idx": 0, "x": 0, "y": 0}]),
# tile_0 intentionally omitted
}
with self.assertRaises(ValueError):
pp._parse_tiles_from_form(form)
def test_tile_field_without_file_attr_raises_value_error(self):
form = {
"tiles_metadata": json.dumps([{"tile_idx": 0, "x": 0, "y": 0}]),
"tile_0": "plain string without .file",
}
with self.assertRaises(ValueError):
pp._parse_tiles_from_form(form)
def test_non_image_bytes_raises_value_error(self):
class _BadFileField:
class _BadFile:
def read(self):
return b"this is definitely not image data"
file = _BadFile()
form = {
"tiles_metadata": json.dumps([{"tile_idx": 0, "x": 0, "y": 0}]),
"tile_0": _BadFileField(),
}
with self.assertRaises(ValueError):
pp._parse_tiles_from_form(form)
def test_invalid_metadata_value_type_raises_value_error(self):
"""Non-integer metadata fields (x, y, etc.) should raise ValueError."""
form = {
"tiles_metadata": json.dumps([{"tile_idx": 0, "x": "not_int", "y": 0}]),
"tile_0": _MockFileField(_make_png_bytes()),
}
with self.assertRaises(ValueError):
pp._parse_tiles_from_form(form)
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/test_prompt_transform.py
================================================
import importlib.util
import json
import sys
import types
import unittest
from pathlib import Path
def _load_prompt_transform_module():
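    """Load api/orchestration/prompt_transform.py under a throwaway package,
    stubbing only the package's logging helpers."""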
module_path = Path(__file__).resolve().parents[1] / "api" / "orchestration" / "prompt_transform.py"
package_name = "dist_pt_testpkg"
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
api_pkg = types.ModuleType(f"{package_name}.api")
api_pkg.__path__ = []
sys.modules[f"{package_name}.api"] = api_pkg
orch_pkg = types.ModuleType(f"{package_name}.api.orchestration")
orch_pkg.__path__ = []
sys.modules[f"{package_name}.api.orchestration"] = orch_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
logging_module = types.ModuleType(f"{package_name}.utils.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
logging_module.log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.logging"] = logging_module
spec = importlib.util.spec_from_file_location(
f"{package_name}.api.orchestration.prompt_transform",
module_path,
)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
return module
pt = _load_prompt_transform_module()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _linear_prompt():
"""1 → 2 → 3 → 4(DistributedCollector) → 5(SaveImage)"""
return {
"1": {"class_type": "CheckpointLoaderSimple", "inputs": {}},
"2": {"class_type": "CLIPTextEncode", "inputs": {"clip": ["1", 1]}},
"3": {"class_type": "KSampler", "inputs": {"model": ["1", 0], "positive": ["2", 0]}},
"4": {"class_type": "DistributedCollector", "inputs": {"images": ["3", 0]}},
"5": {"class_type": "SaveImage", "inputs": {"images": ["4", 0]}},
}
def _collector_only_prompt():
"""1(Checkpoint) → 2(DistributedCollector) [no downstream from 2]"""
return {
"1": {"class_type": "CheckpointLoaderSimple", "inputs": {}},
"2": {"class_type": "DistributedCollector", "inputs": {"images": ["1", 0]}},
}
def _delegate_prompt():
"""1 → 2 → 3(DistributedCollector) → 4(SaveImage)"""
return {
"1": {"class_type": "CheckpointLoaderSimple", "inputs": {}},
"2": {"class_type": "KSampler", "inputs": {"model": ["1", 0]}},
"3": {"class_type": "DistributedCollector", "inputs": {"images": ["2", 0]}},
"4": {"class_type": "SaveImage", "inputs": {"images": ["3", 0]}},
}
def _apply(prompt, participant_id, enabled_worker_ids=None, delegate_master=False):
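    """Run apply_participant_overrides on ``prompt`` with a fresh PromptIndex
    and a job-id map generated for the fixed run id "run"."""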
if enabled_worker_ids is None:
enabled_worker_ids = ["worker-a", "worker-b"]
idx = pt.PromptIndex(prompt)
job_id_map = pt.generate_job_id_map(idx, "run")
return pt.apply_participant_overrides(
prompt,
participant_id=participant_id,
enabled_worker_ids=enabled_worker_ids,
job_id_map=job_id_map,
master_url="http://master.example.com",
delegate_master=delegate_master,
prompt_index=idx,
)
# ---------------------------------------------------------------------------
# PromptIndex
# ---------------------------------------------------------------------------
class PromptIndexTests(unittest.TestCase):
def test_nodes_by_class_groups_correctly(self):
prompt = {
"1": {"class_type": "CheckpointLoaderSimple", "inputs": {}},
"2": {"class_type": "DistributedCollector", "inputs": {}},
"3": {"class_type": "DistributedCollector", "inputs": {}},
}
idx = pt.PromptIndex(prompt)
self.assertCountEqual(idx.nodes_for_class("DistributedCollector"), ["2", "3"])
self.assertEqual(idx.nodes_for_class("CheckpointLoaderSimple"), ["1"])
def test_nodes_for_class_unknown_returns_empty(self):
idx = pt.PromptIndex({"1": {"class_type": "KSampler", "inputs": {}}})
self.assertEqual(idx.nodes_for_class("Nonexistent"), [])
def test_nodes_without_class_type_are_indexed_under_none(self):
prompt = {"1": {"inputs": {}}}
idx = pt.PromptIndex(prompt)
        # Should not raise; a node with no class_type must never match a real class name
self.assertEqual(idx.nodes_for_class("KSampler"), [])
def test_copy_prompt_is_a_deep_copy(self):
prompt = {"1": {"class_type": "KSampler", "inputs": {"seed": 42}}}
idx = pt.PromptIndex(prompt)
copy = idx.copy_prompt()
copy["1"]["inputs"]["seed"] = 999
self.assertEqual(prompt["1"]["inputs"]["seed"], 42)
def test_has_upstream_direct_connection(self):
"""Node 4 reads directly from node 3 (KSampler)."""
idx = pt.PromptIndex(_linear_prompt())
self.assertTrue(idx.has_upstream("4", "KSampler"))
def test_has_upstream_transitive_connection(self):
"""Node 4 → 3 → 2 → 1 (CheckpointLoaderSimple)."""
idx = pt.PromptIndex(_linear_prompt())
self.assertTrue(idx.has_upstream("4", "CheckpointLoaderSimple"))
def test_has_upstream_returns_false_when_no_path(self):
idx = pt.PromptIndex(_linear_prompt())
# CheckpointLoaderSimple has no upstream nodes
self.assertFalse(idx.has_upstream("1", "DistributedCollector"))
def test_has_upstream_result_is_cached(self):
idx = pt.PromptIndex(_linear_prompt())
r1 = idx.has_upstream("4", "KSampler")
r2 = idx.has_upstream("4", "KSampler")
self.assertEqual(r1, r2)
self.assertIn(("4", "KSampler"), idx._upstream_cache)
def test_has_upstream_does_not_infinite_loop_on_cycle(self):
"""Cyclic references in inputs should not cause infinite recursion."""
prompt = {
"1": {"class_type": "A", "inputs": {"x": ["2", 0]}},
"2": {"class_type": "B", "inputs": {"x": ["1", 0]}},
}
idx = pt.PromptIndex(prompt)
# Should terminate without error
result = idx.has_upstream("1", "NonExistent")
self.assertFalse(result)
# ---------------------------------------------------------------------------
# find_nodes_by_class
# ---------------------------------------------------------------------------
class FindNodesByClassTests(unittest.TestCase):
def test_finds_matching_nodes(self):
prompt = {
"1": {"class_type": "KSampler", "inputs": {}},
"2": {"class_type": "DistributedCollector", "inputs": {}},
}
result = pt.find_nodes_by_class(prompt, "KSampler")
self.assertEqual(result, ["1"])
def test_returns_empty_when_no_match(self):
prompt = {"1": {"class_type": "KSampler", "inputs": {}}}
self.assertEqual(pt.find_nodes_by_class(prompt, "DistributedCollector"), [])
def test_skips_non_dict_nodes(self):
prompt = {"1": "not a dict", "2": {"class_type": "KSampler", "inputs": {}}}
result = pt.find_nodes_by_class(prompt, "KSampler")
self.assertEqual(result, ["2"])
# ---------------------------------------------------------------------------
# prune_prompt_for_worker
# ---------------------------------------------------------------------------
class PrunePromptForWorkerTests(unittest.TestCase):
def test_no_distributed_nodes_returns_prompt_unchanged(self):
prompt = {
"1": {"class_type": "CheckpointLoaderSimple", "inputs": {}},
"2": {"class_type": "SaveImage", "inputs": {"images": ["1", 0]}},
}
result = pt.prune_prompt_for_worker(prompt)
self.assertCountEqual(result.keys(), ["1", "2"])
def test_keeps_collector_and_upstream(self):
prompt = _linear_prompt()
result = pt.prune_prompt_for_worker(prompt)
for node_id in ("1", "2", "3", "4"):
self.assertIn(node_id, result)
def test_removes_downstream_of_collector(self):
prompt = _linear_prompt()
result = pt.prune_prompt_for_worker(prompt)
self.assertNotIn("5", result)
def test_injects_preview_image_when_downstream_exists(self):
prompt = _linear_prompt()
result = pt.prune_prompt_for_worker(prompt)
preview_nodes = [n for n in result.values() if n.get("class_type") == "PreviewImage"]
self.assertEqual(len(preview_nodes), 1)
self.assertEqual(preview_nodes[0]["inputs"]["images"], ["4", 0])
def test_no_preview_image_when_no_downstream(self):
result = pt.prune_prompt_for_worker(_collector_only_prompt())
preview_nodes = [n for n in result.values() if n.get("class_type") == "PreviewImage"]
self.assertEqual(len(preview_nodes), 0)
def test_unrelated_nodes_are_pruned(self):
prompt = {
"1": {"class_type": "DistributedCollector", "inputs": {}},
"2": {"class_type": "UnrelatedNode", "inputs": {}}, # no connection to 1
}
result = pt.prune_prompt_for_worker(prompt)
self.assertIn("1", result)
self.assertNotIn("2", result)
def test_result_is_a_copy_not_same_object(self):
prompt = _linear_prompt()
result = pt.prune_prompt_for_worker(prompt)
# Mutating the result should not affect the original
original_keys = set(prompt.keys())
result["NEW"] = {"class_type": "Test", "inputs": {}}
self.assertEqual(set(prompt.keys()), original_keys)
def test_upscale_node_is_treated_as_distributed(self):
prompt = {
"1": {"class_type": "KSampler", "inputs": {}},
"2": {"class_type": "UltimateSDUpscaleDistributed", "inputs": {"image": ["1", 0]}},
"3": {"class_type": "SaveImage", "inputs": {"images": ["2", 0]}},
}
result = pt.prune_prompt_for_worker(prompt)
self.assertIn("1", result)
self.assertIn("2", result)
self.assertNotIn("3", result)
# ---------------------------------------------------------------------------
# prepare_delegate_master_prompt
# ---------------------------------------------------------------------------
class PrepareDelegateMasterPromptTests(unittest.TestCase):
def test_keeps_collector_and_downstream(self):
prompt = _delegate_prompt()
result = pt.prepare_delegate_master_prompt(prompt, ["3"])
self.assertIn("3", result)
self.assertIn("4", result)
self.assertNotIn("1", result)
self.assertNotIn("2", result)
def test_removes_dangling_upstream_refs(self):
"""Collector must not retain dangling refs to pruned upstream nodes."""
prompt = _delegate_prompt()
result = pt.prepare_delegate_master_prompt(prompt, ["3"])
collector_inputs = result["3"].get("inputs", {})
# Original "images" pointed at node 2, which is pruned.
# It should now point at a newly injected placeholder node.
self.assertIn("images", collector_inputs)
source_id = str(collector_inputs["images"][0])
self.assertNotEqual(source_id, "2")
self.assertIn(source_id, result)
self.assertEqual(result[source_id].get("class_type"), "DistributedEmptyImage")
def test_injects_empty_image_placeholder(self):
prompt = _delegate_prompt()
result = pt.prepare_delegate_master_prompt(prompt, ["3"])
empty_nodes = [(nid, n) for nid, n in result.items() if n.get("class_type") == "DistributedEmptyImage"]
self.assertEqual(len(empty_nodes), 1)
placeholder_id = empty_nodes[0][0]
self.assertEqual(result["3"]["inputs"]["images"], [placeholder_id, 0])
def test_one_placeholder_per_collector(self):
"""Two collectors → two placeholders."""
prompt = {
"1": {"class_type": "DistributedCollector", "inputs": {}},
"2": {"class_type": "DistributedCollector", "inputs": {}},
"3": {"class_type": "SaveImage", "inputs": {"images": ["1", 0]}},
}
result = pt.prepare_delegate_master_prompt(prompt, ["1", "2"])
empty_nodes = [n for n in result.values() if n.get("class_type") == "DistributedEmptyImage"]
self.assertEqual(len(empty_nodes), 2)
def test_result_is_independent_copy(self):
prompt = _delegate_prompt()
result = pt.prepare_delegate_master_prompt(prompt, ["3"])
result["3"]["inputs"]["NEW"] = "injected"
# Original should be untouched
self.assertNotIn("NEW", prompt["3"].get("inputs", {}))
# ---------------------------------------------------------------------------
# generate_job_id_map
# ---------------------------------------------------------------------------
class GenerateJobIdMapTests(unittest.TestCase):
def test_maps_collector_nodes(self):
prompt = {
"1": {"class_type": "DistributedCollector", "inputs": {}},
"2": {"class_type": "KSampler", "inputs": {}},
}
idx = pt.PromptIndex(prompt)
job_map = pt.generate_job_id_map(idx, "prefix")
self.assertEqual(job_map["1"], "prefix_1")
self.assertNotIn("2", job_map)
def test_maps_upscale_nodes(self):
prompt = {
"5": {"class_type": "UltimateSDUpscaleDistributed", "inputs": {}},
}
idx = pt.PromptIndex(prompt)
job_map = pt.generate_job_id_map(idx, "run")
self.assertEqual(job_map["5"], "run_5")
def test_empty_prompt_returns_empty_map(self):
idx = pt.PromptIndex({})
self.assertEqual(pt.generate_job_id_map(idx, "prefix"), {})
def test_stable_ids_across_calls(self):
prompt = {"1": {"class_type": "DistributedCollector", "inputs": {}}}
idx = pt.PromptIndex(prompt)
m1 = pt.generate_job_id_map(idx, "run")
m2 = pt.generate_job_id_map(idx, "run")
self.assertEqual(m1, m2)
# ---------------------------------------------------------------------------
# apply_participant_overrides – DistributedCollector
# ---------------------------------------------------------------------------
class ApplyOverridesCollectorTests(unittest.TestCase):
def _collector_prompt(self):
return {"1": {"class_type": "DistributedCollector", "inputs": {}}}
def test_worker_sets_is_worker_true(self):
result = _apply(self._collector_prompt(), "worker-a")
self.assertTrue(result["1"]["inputs"]["is_worker"])
def test_worker_sets_master_url(self):
result = _apply(self._collector_prompt(), "worker-a")
self.assertEqual(result["1"]["inputs"]["master_url"], "http://master.example.com")
def test_worker_sets_worker_id(self):
result = _apply(self._collector_prompt(), "worker-a")
self.assertEqual(result["1"]["inputs"]["worker_id"], "worker-a")
def test_worker_sets_delegate_only_false(self):
result = _apply(self._collector_prompt(), "worker-a")
self.assertFalse(result["1"]["inputs"]["delegate_only"])
def test_master_sets_is_worker_false(self):
result = _apply(self._collector_prompt(), "master")
self.assertFalse(result["1"]["inputs"]["is_worker"])
def test_master_clears_stale_master_url(self):
prompt = {"1": {"class_type": "DistributedCollector", "inputs": {"master_url": "stale"}}}
result = _apply(prompt, "master")
self.assertNotIn("master_url", result["1"]["inputs"])
def test_master_clears_stale_worker_id(self):
prompt = {"1": {"class_type": "DistributedCollector", "inputs": {"worker_id": "stale"}}}
result = _apply(prompt, "master")
self.assertNotIn("worker_id", result["1"]["inputs"])
def test_master_with_delegate_master_sets_delegate_only_true(self):
result = _apply(self._collector_prompt(), "master", delegate_master=True)
self.assertTrue(result["1"]["inputs"]["delegate_only"])
def test_master_without_delegate_master_sets_delegate_only_false(self):
result = _apply(self._collector_prompt(), "master", delegate_master=False)
self.assertFalse(result["1"]["inputs"]["delegate_only"])
def test_enabled_worker_ids_serialized_as_json(self):
enabled = ["worker-a", "worker-b"]
result = _apply(self._collector_prompt(), "master", enabled_worker_ids=enabled)
self.assertEqual(result["1"]["inputs"]["enabled_worker_ids"], json.dumps(enabled))
def test_multi_job_id_is_set_from_job_map(self):
prompt = {"1": {"class_type": "DistributedCollector", "inputs": {}}}
idx = pt.PromptIndex(prompt)
job_id_map = {"1": "run_abc_1"}
result = pt.apply_participant_overrides(
prompt,
participant_id="worker-a",
enabled_worker_ids=["worker-a"],
job_id_map=job_id_map,
master_url="http://master",
delegate_master=False,
prompt_index=idx,
)
self.assertEqual(result["1"]["inputs"]["multi_job_id"], "run_abc_1")
# ---------------------------------------------------------------------------
# apply_participant_overrides – DistributedSeed
# ---------------------------------------------------------------------------
class ApplyOverridesSeedTests(unittest.TestCase):
def _seed_prompt(self):
return {"1": {"class_type": "DistributedSeed", "inputs": {}}}
def test_worker_sets_is_worker_true(self):
result = _apply(self._seed_prompt(), "worker-a")
self.assertTrue(result["1"]["inputs"]["is_worker"])
def test_worker_id_reflects_index_in_enabled_list(self):
result = _apply(self._seed_prompt(), "worker-b", enabled_worker_ids=["worker-a", "worker-b"])
self.assertEqual(result["1"]["inputs"]["worker_id"], "worker_1")
def test_master_sets_is_worker_false(self):
result = _apply(self._seed_prompt(), "master")
self.assertFalse(result["1"]["inputs"]["is_worker"])
def test_master_sets_empty_worker_id(self):
result = _apply(self._seed_prompt(), "master")
self.assertEqual(result["1"]["inputs"]["worker_id"], "")
# ---------------------------------------------------------------------------
# apply_participant_overrides – UltimateSDUpscaleDistributed
# ---------------------------------------------------------------------------
class ApplyOverridesUpscaleTests(unittest.TestCase):
def _upscale_prompt(self):
return {"1": {"class_type": "UltimateSDUpscaleDistributed", "inputs": {}}}
def test_worker_sets_is_worker_true(self):
result = _apply(self._upscale_prompt(), "worker-a")
self.assertTrue(result["1"]["inputs"]["is_worker"])
def test_worker_sets_master_url_and_worker_id(self):
result = _apply(self._upscale_prompt(), "worker-a")
self.assertEqual(result["1"]["inputs"]["master_url"], "http://master.example.com")
self.assertEqual(result["1"]["inputs"]["worker_id"], "worker-a")
def test_master_clears_master_url_and_worker_id(self):
prompt = {"1": {"class_type": "UltimateSDUpscaleDistributed", "inputs": {"master_url": "x", "worker_id": "y"}}}
result = _apply(prompt, "master")
self.assertNotIn("master_url", result["1"]["inputs"])
self.assertNotIn("worker_id", result["1"]["inputs"])
def test_collector_downstream_of_upscale_gets_pass_through(self):
"""A DistributedCollector that is downstream of UltimateSDUpscaleDistributed → pass_through=True."""
prompt = {
"1": {"class_type": "UltimateSDUpscaleDistributed", "inputs": {}},
"2": {"class_type": "DistributedCollector", "inputs": {"images": ["1", 0]}},
}
result = _apply(prompt, "worker-a", enabled_worker_ids=["worker-a"])
self.assertTrue(result["2"]["inputs"].get("pass_through"))
# ---------------------------------------------------------------------------
# apply_participant_overrides – DistributedValue
# ---------------------------------------------------------------------------
class ApplyOverridesValueTests(unittest.TestCase):
def _value_prompt(self):
return {"1": {"class_type": "DistributedValue", "inputs": {}}}
def test_worker_sets_is_worker_true(self):
result = _apply(self._value_prompt(), "worker-a")
self.assertTrue(result["1"]["inputs"]["is_worker"])
def test_worker_id_reflects_index_in_enabled_list(self):
result = _apply(self._value_prompt(), "worker-b", enabled_worker_ids=["worker-a", "worker-b"])
self.assertEqual(result["1"]["inputs"]["worker_id"], "worker_1")
def test_master_sets_is_worker_false(self):
result = _apply(self._value_prompt(), "master")
self.assertFalse(result["1"]["inputs"]["is_worker"])
def test_master_sets_empty_worker_id(self):
result = _apply(self._value_prompt(), "master")
self.assertEqual(result["1"]["inputs"]["worker_id"], "")
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/test_queue_request.py
================================================
import importlib.util
import unittest
from pathlib import Path
def _load_queue_request_module():
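    # Load queue_request.py directly from its file path so the tests run without importing the package.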
module_path = Path(__file__).resolve().parents[1] / "api" / "queue_request.py"
spec = importlib.util.spec_from_file_location("queue_request", module_path)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
return module
queue_request = _load_queue_request_module()
parse_queue_request_payload = queue_request.parse_queue_request_payload
class QueueRequestPayloadTests(unittest.TestCase):
def _base_payload(self):
return {
"prompt": {"1": {"class_type": "Anything"}},
"enabled_worker_ids": ["worker-1"],
"client_id": "client-1",
}
def test_normalizes_enabled_worker_ids(self):
payload_data = self._base_payload()
payload_data["enabled_worker_ids"] = ["a", 2, 3]
payload_data["delegate_master"] = True
        payload = parse_queue_request_payload(payload_data)
self.assertEqual(payload.enabled_worker_ids, ["a", "2", "3"])
self.assertTrue(payload.delegate_master)
def test_supports_legacy_workers_field(self):
payload_data = self._base_payload()
payload_data.pop("enabled_worker_ids", None)
payload_data["workers"] = [{"id": "w1"}, "w2", {"id": 3}, {"name": "no-id"}]
        payload = parse_queue_request_payload(payload_data)
self.assertEqual(payload.enabled_worker_ids, ["w1", "w2", "3"])
def test_supports_auto_prepare_prompt_fallback(self):
payload_data = self._base_payload()
payload_data.pop("prompt", None)
payload_data["auto_prepare"] = True
payload_data["workflow"] = {
"prompt": {
"10": {"class_type": "DistributedCollector"},
}
}
        payload = parse_queue_request_payload(payload_data)
self.assertIn("10", payload.prompt)
self.assertTrue(payload.auto_prepare)
def test_normalizes_trace_execution_id(self):
payload_data = self._base_payload()
payload_data["trace_execution_id"] = " exec_123 "
        payload = parse_queue_request_payload(payload_data)
self.assertEqual(payload.trace_execution_id, "exec_123")
def test_blank_trace_execution_id_normalizes_to_none(self):
payload_data = self._base_payload()
payload_data["trace_execution_id"] = " "
        payload = parse_queue_request_payload(payload_data)
self.assertIsNone(payload.trace_execution_id)
def test_auto_prepare_defaults_true(self):
payload = parse_queue_request_payload(self._base_payload())
self.assertTrue(payload.auto_prepare)
def test_workers_field_must_be_list(self):
payload_data = self._base_payload()
payload_data.pop("enabled_worker_ids", None)
payload_data["workers"] = "worker-a"
with self.assertRaisesRegex(ValueError, "Field 'workers' must be a list"):
parse_queue_request_payload(payload_data)
def test_trace_execution_id_must_be_string(self):
payload_data = self._base_payload()
payload_data["trace_execution_id"] = 123
with self.assertRaisesRegex(ValueError, "trace_execution_id must be a string"):
parse_queue_request_payload(payload_data)
def test_auto_prepare_false_still_falls_back_to_workflow_prompt(self):
payload_data = self._base_payload()
payload_data.pop("prompt", None)
payload_data["auto_prepare"] = False
payload_data["workflow"] = {
"prompt": {"10": {"class_type": "DistributedCollector"}},
}
payload = parse_queue_request_payload(payload_data)
self.assertIn("10", payload.prompt)
self.assertFalse(payload.auto_prepare)
def test_auto_prepare_must_be_boolean(self):
payload_data = self._base_payload()
payload_data["auto_prepare"] = "true"
with self.assertRaisesRegex(ValueError, "auto_prepare must be a boolean"):
parse_queue_request_payload(payload_data)
def test_invalid_delegate_master_type_raises(self):
payload_data = self._base_payload()
payload_data["delegate_master"] = "yes"
with self.assertRaisesRegex(ValueError, "delegate_master must be a boolean"):
parse_queue_request_payload(payload_data)
def test_invalid_enabled_worker_ids_type_raises(self):
payload_data = self._base_payload()
payload_data["enabled_worker_ids"] = "worker-a"
with self.assertRaisesRegex(ValueError, "enabled_worker_ids must be a list"):
parse_queue_request_payload(payload_data)
def test_invalid_top_level_payload_raises(self):
with self.assertRaisesRegex(ValueError, "Expected a JSON object body"):
parse_queue_request_payload(["not", "an", "object"])
def test_missing_prompt_raises(self):
with self.assertRaisesRegex(ValueError, "Field 'prompt' must be an object"):
parse_queue_request_payload(
{
"workflow": {},
"enabled_worker_ids": ["worker-1"],
"client_id": "client-1",
}
)
def test_missing_enabled_worker_ids_raises(self):
payload_data = self._base_payload()
payload_data.pop("enabled_worker_ids", None)
with self.assertRaisesRegex(ValueError, "enabled_worker_ids required"):
parse_queue_request_payload(payload_data)
def test_missing_client_id_raises(self):
payload_data = self._base_payload()
payload_data.pop("client_id", None)
with self.assertRaisesRegex(ValueError, "client_id required"):
parse_queue_request_payload(payload_data)
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/test_static_mode.py
================================================
import asyncio
import importlib.util
import sys
import types
import unittest
from pathlib import Path
import torch
def _load_static_mode_module():
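    # Build a throwaway package namespace and stub the comfy/utils dependencies so static.py imports cleanly in isolation.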
module_path = Path(__file__).resolve().parents[1] / "upscale" / "modes" / "static.py"
package_name = "dist_static_mode_testpkg"
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
upscale_pkg = types.ModuleType(f"{package_name}.upscale")
upscale_pkg.__path__ = []
sys.modules[f"{package_name}.upscale"] = upscale_pkg
modes_pkg = types.ModuleType(f"{package_name}.upscale.modes")
modes_pkg.__path__ = []
sys.modules[f"{package_name}.upscale.modes"] = modes_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
created_comfy_stub = False
if "comfy" not in sys.modules:
created_comfy_stub = True
comfy_module = types.ModuleType("comfy")
model_mgmt = types.ModuleType("comfy.model_management")
class _InterruptProcessingException(Exception):
pass
model_mgmt.processing_interrupted = lambda: False
model_mgmt.throw_exception_if_processing_interrupted = lambda: None
model_mgmt.InterruptProcessingException = _InterruptProcessingException
comfy_module.model_management = model_mgmt
sys.modules["comfy"] = comfy_module
sys.modules["comfy.model_management"] = model_mgmt
logging_module = types.ModuleType(f"{package_name}.utils.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
logging_module.log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.logging"] = logging_module
image_module = types.ModuleType(f"{package_name}.utils.image")
from PIL import Image as PILImage
import numpy as np
def _tensor_to_pil(img_tensor, batch_index=0):
return PILImage.fromarray((255 * img_tensor[batch_index].cpu().numpy()).astype(np.uint8))
def _pil_to_tensor(image):
arr = np.array(image).astype(np.float32) / 255.0
return torch.from_numpy(arr).unsqueeze(0)
image_module.tensor_to_pil = _tensor_to_pil
image_module.pil_to_tensor = _pil_to_tensor
sys.modules[f"{package_name}.utils.image"] = image_module
async_helpers_module = types.ModuleType(f"{package_name}.utils.async_helpers")
def _run_async_in_server_loop(coro, timeout=None):
if timeout is not None:
return asyncio.run(asyncio.wait_for(coro, timeout=timeout))
return asyncio.run(coro)
async_helpers_module.run_async_in_server_loop = _run_async_in_server_loop
sys.modules[f"{package_name}.utils.async_helpers"] = async_helpers_module
config_module = types.ModuleType(f"{package_name}.utils.config")
config_module.get_worker_timeout_seconds = lambda: 60
sys.modules[f"{package_name}.utils.config"] = config_module
constants_module = types.ModuleType(f"{package_name}.utils.constants")
constants_module.HEARTBEAT_INTERVAL = 10.0
constants_module.JOB_POLL_INTERVAL = 0.0
constants_module.JOB_POLL_MAX_ATTEMPTS = 3
constants_module.MAX_BATCH = 20
constants_module.TILE_SEND_TIMEOUT = 1.0
constants_module.TILE_WAIT_TIMEOUT = 1.0
sys.modules[f"{package_name}.utils.constants"] = constants_module
job_store_module = types.ModuleType(f"{package_name}.upscale.job_store")
async def _noop(*_args, **_kwargs):
return None
job_store_module.ensure_tile_jobs_initialized = lambda: types.SimpleNamespace(
distributed_tile_jobs_lock=asyncio.Lock(),
distributed_pending_tile_jobs={},
)
job_store_module.init_static_job_batched = _noop
job_store_module._mark_task_completed = _noop
job_store_module._cleanup_job = _noop
job_store_module._drain_results_queue = _noop
job_store_module._get_completed_count = _noop
sys.modules[f"{package_name}.upscale.job_store"] = job_store_module
job_models_module = types.ModuleType(f"{package_name}.upscale.job_models")
class _TileJobState:
pass
job_models_module.TileJobState = _TileJobState
sys.modules[f"{package_name}.upscale.job_models"] = job_models_module
spec = importlib.util.spec_from_file_location(f"{package_name}.upscale.modes.static", module_path)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
if created_comfy_stub:
sys.modules.pop("comfy.model_management", None)
sys.modules.pop("comfy", None)
return module
static_mode = _load_static_mode_module()
class _FakeStaticWorker(static_mode.StaticModeMixin):
def __init__(self):
self.sent_batches = []
self.request_calls = 0
self.heartbeat_calls = 0
self.job_ready = True
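        # The first request yields one real tile; the second returns the None sentinel that ends the worker loop.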
self.tile_sequence = [(0, 0, True), (None, 0, True)]
def round_to_multiple(self, value):
return value
def calculate_tiles(self, _width, _height, _tile_width, _tile_height, _force_uniform_tiles):
return [(0, 0)]
def _poll_job_ready(self, *_args, **_kwargs):
return self.job_ready
async def _request_tile_from_master(self, *_args, **_kwargs):
self.request_calls += 1
return self.tile_sequence.pop(0)
async def _send_heartbeat_to_master(self, *_args, **_kwargs):
self.heartbeat_calls += 1
async def send_tiles_batch_to_master(
self,
processed_tiles,
_multi_job_id,
_master_url,
_padding,
_worker_id,
is_final_flush=False,
):
self.sent_batches.append(
{
"tiles": list(processed_tiles),
"is_final_flush": bool(is_final_flush),
}
)
def _extract_and_process_tile(self, upscaled_image, *_args, **_kwargs):
batch_size = upscaled_image.shape[0]
processed_batch = torch.zeros((batch_size, 2, 2, 3), dtype=torch.float32)
return processed_batch, 0, 0, 2, 2
def create_tile_mask(self, *_args, **_kwargs):
from PIL import Image
return Image.new("L", (4, 4), 255)
def blend_tile(self, base_image, *_args, **_kwargs):
return base_image
def _call_worker_static(fake_worker):
image = torch.zeros((1, 4, 4, 3), dtype=torch.float32)
return fake_worker._process_worker_static_sync(
image,
model=None,
positive=None,
negative=None,
vae=None,
seed=1,
steps=1,
cfg=1.0,
sampler_name="euler",
scheduler="normal",
denoise=0.5,
tile_width=4,
tile_height=4,
padding=8,
mask_blur=4,
force_uniform_tiles=True,
tiled_decode=False,
multi_job_id="job-1",
master_url="http://master:8188",
worker_id="worker-1",
enabled_workers=["worker-1"],
)
class StaticModeWorkerFlowTests(unittest.TestCase):
def test_worker_static_aborts_when_job_not_ready(self):
worker = _FakeStaticWorker()
worker.job_ready = False
result = _call_worker_static(worker)
self.assertEqual(result[0].shape[0], 1)
self.assertEqual(worker.request_calls, 0)
self.assertEqual(worker.heartbeat_calls, 0)
self.assertEqual(worker.sent_batches, [])
def test_worker_static_requests_tiles_and_flushes_final_batch(self):
worker = _FakeStaticWorker()
_call_worker_static(worker)
self.assertEqual(worker.request_calls, 2) # one tile, then sentinel
self.assertEqual(worker.heartbeat_calls, 1)
self.assertEqual(len(worker.sent_batches), 1)
self.assertTrue(worker.sent_batches[0]["is_final_flush"])
tiles = worker.sent_batches[0]["tiles"]
self.assertEqual(len(tiles), 1)
self.assertEqual(tiles[0]["tile_idx"], 0)
self.assertEqual(tiles[0]["global_idx"], 0)
self.assertEqual(tiles[0]["batch_idx"], 0)
def test_flush_empty_final_still_sends_completion_signal(self):
worker = _FakeStaticWorker()
returned = worker._flush_tiles_to_master(
[],
"job-1",
"http://master:8188",
8,
"worker-1",
is_final_flush=True,
)
self.assertEqual(returned, [])
self.assertEqual(len(worker.sent_batches), 1)
self.assertEqual(worker.sent_batches[0]["tiles"], [])
self.assertTrue(worker.sent_batches[0]["is_final_flush"])
if __name__ == "__main__":
unittest.main()
================================================
FILE: tests/test_worker_process_runtime.py
================================================
import importlib.util
import sys
import types
import unittest
from argparse import Namespace
from pathlib import Path
from unittest.mock import patch
def _load_process_module(module_filename: str):
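    # Load the process module by file path under a synthetic package so its relative imports resolve against the stubs below.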
module_path = Path(__file__).resolve().parents[1] / "workers" / "process" / module_filename
package_name = "dist_proc_testpkg"
module_name = module_filename[:-3]
for mod_name in list(sys.modules):
if mod_name == package_name or mod_name.startswith(f"{package_name}."):
del sys.modules[mod_name]
root_pkg = types.ModuleType(package_name)
root_pkg.__path__ = []
sys.modules[package_name] = root_pkg
workers_pkg = types.ModuleType(f"{package_name}.workers")
workers_pkg.__path__ = []
sys.modules[f"{package_name}.workers"] = workers_pkg
process_pkg = types.ModuleType(f"{package_name}.workers.process")
process_pkg.__path__ = []
sys.modules[f"{package_name}.workers.process"] = process_pkg
utils_pkg = types.ModuleType(f"{package_name}.utils")
utils_pkg.__path__ = []
sys.modules[f"{package_name}.utils"] = utils_pkg
logging_module = types.ModuleType(f"{package_name}.utils.logging")
logging_module.debug_log = lambda *_args, **_kwargs: None
logging_module.log = lambda *_args, **_kwargs: None
sys.modules[f"{package_name}.utils.logging"] = logging_module
process_module = types.ModuleType(f"{package_name}.utils.process")
process_module.get_python_executable = lambda: "/usr/bin/test-python"
sys.modules[f"{package_name}.utils.process"] = process_module
spec = importlib.util.spec_from_file_location(
f"{package_name}.workers.process.{module_name}",
module_path,
)
module = importlib.util.module_from_spec(spec)
assert spec is not None and spec.loader is not None
spec.loader.exec_module(module)
return module
root_discovery_module = _load_process_module("root_discovery.py")
launch_builder_module = _load_process_module("launch_builder.py")
class ComfyRootDiscoveryTests(unittest.TestCase):
def test_prefers_loaded_comfyui_module_path(self):
discovery = root_discovery_module.ComfyRootDiscovery()
server_module = types.SimpleNamespace(__file__="/opt/ComfyUI/server.py")
def fake_exists(path):
return path == "/opt/ComfyUI/main.py"
with patch.dict(sys.modules, {"server": server_module}, clear=False), \
patch.object(root_discovery_module.os.path, "exists", side_effect=fake_exists), \
patch.dict(root_discovery_module.os.environ, {}, clear=True):
self.assertEqual(discovery.find_comfy_root(), "/opt/ComfyUI")
class LaunchCommandBuilderTests(unittest.TestCase):
def test_inherits_runtime_layout_args_for_desktop(self):
builder = launch_builder_module.LaunchCommandBuilder()
runtime_args = Namespace(
listen="127.0.0.1",
base_directory="C:/Users/test/ComfyUI",
temp_directory=None,
input_directory="C:/Users/test/ComfyUI/input",
output_directory="C:/Users/test/ComfyUI/output",
user_directory="C:/Users/test/ComfyUI/user",
front_end_root="C:/Program Files/ComfyUI/web_custom_versions/desktop_app",
extra_model_paths_config=[["C:/Users/test/AppData/Roaming/ComfyUI/extra_models_config.yaml"]],
enable_manager=True,
disable_manager_ui=False,
enable_manager_legacy_ui=False,
windows_standalone_build=True,
log_stdout=True,
verbose="INFO",
enable_cors_header="*",
)
comfy_module = types.ModuleType("comfy")
comfy_cli_args = types.ModuleType("comfy.cli_args")
comfy_cli_args.args = runtime_args
worker_config = {
"port": 9001,
"extra_args": "--preview-method auto",
}
def fake_exists(path):
return path == "/desktop/ComfyUI/main.py"
with patch.dict(
sys.modules,
{"comfy": comfy_module, "comfy.cli_args": comfy_cli_args},
clear=False,
), patch.object(launch_builder_module.os.path, "exists", side_effect=fake_exists):
cmd = builder.build_launch_command(worker_config, "/desktop/ComfyUI")
self.assertEqual(cmd[:2], ["/usr/bin/test-python", "/desktop/ComfyUI/main.py"])
self.assertIn("--listen", cmd)
self.assertIn("127.0.0.1", cmd)
self.assertIn("--base-directory", cmd)
self.assertIn("C:/Users/test/ComfyUI", cmd)
self.assertIn("--input-directory", cmd)
self.assertIn("--output-directory", cmd)
self.assertIn("--user-directory", cmd)
self.assertIn("--front-end-root", cmd)
self.assertIn("--extra-model-paths-config", cmd)
self.assertIn("C:/Users/test/AppData/Roaming/ComfyUI/extra_models_config.yaml", cmd)
self.assertIn("--enable-manager", cmd)
self.assertIn("--windows-standalone-build", cmd)
self.assertIn("--log-stdout", cmd)
self.assertIn("--disable-auto-launch", cmd)
self.assertIn("--enable-cors-header", cmd)
self.assertIn("*", cmd)
self.assertIn("--port", cmd)
self.assertIn("9001", cmd)
self.assertNotIn("--auto-launch", cmd)
if __name__ == "__main__":
unittest.main()
================================================
FILE: upscale/__init__.py
================================================
================================================
FILE: upscale/conditioning.py
================================================
import copy
def clone_control_chain(control, clone_hint=True):
"""Shallow copy the ControlNet chain, optionally cloning hints but sharing models."""
if control is None:
return None
new_control = copy.copy(control)
if clone_hint and hasattr(control, 'cond_hint_original'):
hint = getattr(control, 'cond_hint_original', None)
new_control.cond_hint_original = hint.clone() if hint is not None else None
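    # Recurse through previous_controlnet so every link in the chain is shallow-copied while the underlying models stay shared.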
if hasattr(control, 'previous_controlnet'):
new_control.previous_controlnet = clone_control_chain(control.previous_controlnet, clone_hint)
return new_control
def clone_conditioning(cond_list, clone_hints=True):
"""Clone conditioning without duplicating ControlNet models."""
new_cond = []
for emb, cond_dict in cond_list:
new_emb = emb.clone() if emb is not None else None
new_dict = cond_dict.copy()
if 'control' in new_dict:
new_dict['control'] = clone_control_chain(new_dict['control'], clone_hints)
if 'mask' in new_dict and new_dict['mask'] is not None:
new_dict['mask'] = new_dict['mask'].clone()
if 'pooled_output' in new_dict and new_dict['pooled_output'] is not None:
new_dict['pooled_output'] = new_dict['pooled_output'].clone()
if 'area' in new_dict:
new_dict['area'] = new_dict['area'][:]
new_cond.append([new_emb, new_dict])
return new_cond
================================================
FILE: upscale/job_models.py
================================================
from dataclasses import dataclass, field
import asyncio
import time
class BaseJobState:
"""Marker base class for typed USDU job state containers."""
@dataclass
class TileJobState(BaseJobState):
"""Typed state container for static (tile) USDU jobs."""
multi_job_id: str
mode: str = field(default="static", init=False)
queue: asyncio.Queue = field(default_factory=asyncio.Queue)
pending_tasks: asyncio.Queue = field(default_factory=asyncio.Queue)
completed_tasks: dict = field(default_factory=dict)
worker_status: dict = field(default_factory=dict)
assigned_to_workers: dict = field(default_factory=dict)
batch_size: int = 0
num_tiles_per_image: int = 0
batched_static: bool = False
created_at: float = field(default_factory=time.monotonic)
@dataclass
class ImageJobState(BaseJobState):
"""Typed state container for dynamic (per-image) USDU jobs."""
multi_job_id: str
mode: str = field(default="dynamic", init=False)
queue: asyncio.Queue = field(default_factory=asyncio.Queue)
pending_images: asyncio.Queue = field(default_factory=asyncio.Queue)
completed_images: dict = field(default_factory=dict)
worker_status: dict = field(default_factory=dict)
assigned_to_workers: dict = field(default_factory=dict)
batch_size: int = 0
num_tiles_per_image: int = 0
batched_static: bool = False
created_at: float = field(default_factory=time.monotonic)
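    # The property aliases below let mode-agnostic helpers read pending_tasks/completed_tasks on either job type.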
@property
def pending_tasks(self):
return self.pending_images
@property
def completed_tasks(self):
return self.completed_images
================================================
FILE: upscale/job_state.py
================================================
import asyncio
from ..utils.logging import debug_log
from .job_store import ensure_tile_jobs_initialized
from .job_timeout import _check_and_requeue_timed_out_workers as _requeue_usdu
from .job_models import ImageJobState, TileJobState
class JobStateMixin:
async def _get_job_data(self, multi_job_id):
"""Return current job data reference while holding lock briefly."""
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
return prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
async def _get_all_completed_tasks(self, multi_job_id):
"""Helper to retrieve all completed tasks from the job data."""
job_data = await self._get_job_data(multi_job_id)
if isinstance(job_data, TileJobState):
return dict(job_data.completed_tasks)
if isinstance(job_data, ImageJobState):
return dict(job_data.completed_images)
return {}
async def _get_next_image_index(self, multi_job_id):
"""Get next image index from pending queue for master."""
prompt_server = ensure_tile_jobs_initialized()
pending_queue = None
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if isinstance(job_data, ImageJobState):
pending_queue = job_data.pending_images
if pending_queue is None:
return None
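        # The wait happens outside the lock; returning None lets the caller drain results and retry.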
try:
return await asyncio.wait_for(pending_queue.get(), timeout=1.0)
except asyncio.TimeoutError:
return None
async def _get_next_tile_index(self, multi_job_id):
"""Get next tile index from pending queue for master in static mode."""
prompt_server = ensure_tile_jobs_initialized()
pending_queue = None
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if isinstance(job_data, TileJobState):
pending_queue = job_data.pending_tasks
if pending_queue is None:
return None
try:
return await asyncio.wait_for(pending_queue.get(), timeout=0.1)
except asyncio.TimeoutError:
return None
async def _get_total_completed_count(self, multi_job_id):
"""Get total count of all completed images (master + workers)."""
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if isinstance(job_data, ImageJobState):
return len(job_data.completed_images)
if isinstance(job_data, TileJobState):
return len(job_data.completed_tasks)
return 0
async def _get_all_completed_images(self, multi_job_id):
"""Get all completed images."""
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if isinstance(job_data, ImageJobState):
return job_data.completed_images.copy()
return {}
async def _get_pending_count(self, multi_job_id):
"""Get count of pending images in the queue."""
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if isinstance(job_data, ImageJobState):
return job_data.pending_images.qsize()
if isinstance(job_data, TileJobState):
return job_data.pending_tasks.qsize()
return 0
async def _drain_worker_results_queue(self, multi_job_id):
"""Drain pending worker results from queue and update completed images."""
prompt_server = ensure_tile_jobs_initialized()
worker_queue = None
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if isinstance(job_data, ImageJobState):
worker_queue = job_data.queue
if worker_queue is None:
return 0
drained_results = []
while True:
try:
drained_results.append(worker_queue.get_nowait())
except asyncio.QueueEmpty:
break
if not drained_results:
return 0
collected = 0
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if not isinstance(job_data, ImageJobState):
return 0
for result in drained_results:
worker_id = result.get("worker_id")
if "image_idx" in result and "image" in result:
image_idx = result["image_idx"]
image_pil = result["image"]
if image_idx not in job_data.completed_images:
job_data.completed_images[image_idx] = image_pil
collected += 1
debug_log(f"Drained image {image_idx} from worker {worker_id}")
if collected > 0:
debug_log(f"Drained {collected} worker images during retry")
return collected
async def _check_and_requeue_timed_out_workers(self, multi_job_id, batch_size):
"""Check for timed out workers and requeue their assigned images."""
return await _requeue_usdu(multi_job_id, batch_size)
================================================
FILE: upscale/job_store.py
================================================
import asyncio
import os
import time
from typing import List, Optional
import server
from ..utils.logging import debug_log
from .job_models import BaseJobState, ImageJobState, TileJobState
# Configure maximum payload size (50MB default, configurable via environment variable)
MAX_PAYLOAD_SIZE = int(os.environ.get('COMFYUI_MAX_PAYLOAD_SIZE', str(50 * 1024 * 1024)))
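# e.g. COMFYUI_MAX_PAYLOAD_SIZE=104857600 raises the limit to 100MB.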
def ensure_tile_jobs_initialized():
"""Ensure tile job storage is initialized on the server instance."""
prompt_server = server.PromptServer.instance
if not hasattr(prompt_server, 'distributed_pending_tile_jobs'):
debug_log("Initializing persistent tile job queue on server instance.")
prompt_server.distributed_pending_tile_jobs = {}
prompt_server.distributed_tile_jobs_lock = asyncio.Lock()
else:
invalid_job_ids = [
job_id
for job_id, job_data in prompt_server.distributed_pending_tile_jobs.items()
if not isinstance(job_data, BaseJobState)
]
for job_id in invalid_job_ids:
debug_log(f"Removing invalid job state for {job_id}")
del prompt_server.distributed_pending_tile_jobs[job_id]
return prompt_server
async def _init_job_queue(
multi_job_id,
mode,
batch_size=None,
num_tiles_per_image=None,
all_indices=None,
enabled_workers=None,
batched_static: bool = False,
):
"""Unified initialization for job queues in static and dynamic modes."""
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
if multi_job_id in prompt_server.distributed_pending_tile_jobs:
debug_log(f"Queue already exists for {multi_job_id}")
return
if mode == 'dynamic':
job_data = ImageJobState(multi_job_id=multi_job_id)
elif mode == 'static':
job_data = TileJobState(multi_job_id=multi_job_id)
else:
raise ValueError(f"Unknown mode: {mode}")
job_data.worker_status = {w: time.time() for w in enabled_workers or []}
job_data.assigned_to_workers = {w: [] for w in enabled_workers or []}
if mode == 'dynamic':
job_data.batch_size = int(batch_size or 0)
pending_queue = job_data.pending_images
for i in (all_indices or range(int(batch_size or 0))):
await pending_queue.put(i)
debug_log(f"Initialized image queue with {batch_size} pending items")
elif mode == 'static':
job_data.num_tiles_per_image = int(num_tiles_per_image or 0)
job_data.batch_size = int(batch_size or 0)
job_data.batched_static = bool(batched_static)
# For batched static distribution, populate only tile ids [0..num_tiles_per_image-1]
pending_queue = job_data.pending_tasks
if batched_static and num_tiles_per_image is not None:
for i in range(num_tiles_per_image):
await pending_queue.put(i)
else:
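                # Non-batched static: enqueue one task per (image, tile) pair using flat global indices.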
total_tiles = int(batch_size or 0) * int(num_tiles_per_image or 0)
for i in range(total_tiles):
await pending_queue.put(i)
prompt_server.distributed_pending_tile_jobs[multi_job_id] = job_data
async def init_dynamic_job(
multi_job_id: str,
batch_size: int,
enabled_workers: List[str],
all_indices: Optional[List[int]] = None,
):
"""Initialize queue for dynamic mode (per-image), with collector fields."""
await _init_job_queue(
multi_job_id,
'dynamic',
batch_size=batch_size,
all_indices=all_indices or list(range(batch_size)),
enabled_workers=enabled_workers,
)
debug_log(f"Job {multi_job_id} initialized with {batch_size} images")
async def init_static_job_batched(
multi_job_id: str,
batch_size: int,
num_tiles_per_image: int,
enabled_workers: List[str],
):
"""Initialize queue for static mode (batched-per-tile)."""
await _init_job_queue(
multi_job_id,
'static',
batch_size=batch_size,
num_tiles_per_image=num_tiles_per_image,
enabled_workers=enabled_workers,
batched_static=True,
)
async def _drain_results_queue(multi_job_id):
"""Drain pending results from queue and update completed_tasks. Returns count drained."""
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if not isinstance(job_data, BaseJobState):
return 0
q = job_data.queue
completed_tasks = job_data.completed_tasks
collected = 0
while True:
try:
result = q.get_nowait()
except asyncio.QueueEmpty:
break
worker_id = result['worker_id']
is_last = result.get('is_last', False)
if 'image_idx' in result and 'image' in result:
task_id = result['image_idx']
if task_id not in completed_tasks:
completed_tasks[task_id] = result['image']
collected += 1
elif 'tiles' in result:
for tile_data in result['tiles']:
task_id = tile_data.get('global_idx', tile_data['tile_idx'])
if task_id not in completed_tasks:
completed_tasks[task_id] = tile_data
collected += 1
if is_last:
if worker_id in job_data.worker_status:
del job_data.worker_status[worker_id]
return collected
async def _get_completed_count(multi_job_id):
"""Get count of completed tasks."""
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if isinstance(job_data, BaseJobState):
return len(job_data.completed_tasks)
return 0
async def _mark_task_completed(multi_job_id, task_id, result):
"""Mark a task as completed."""
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if isinstance(job_data, BaseJobState):
job_data.completed_tasks[task_id] = result
async def _cleanup_job(multi_job_id):
"""Cleanup the job data."""
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
if multi_job_id in prompt_server.distributed_pending_tile_jobs:
del prompt_server.distributed_pending_tile_jobs[multi_job_id]
debug_log(f"Cleaned up job {multi_job_id}")
================================================
FILE: upscale/job_timeout.py
================================================
import time
from ..utils.config import load_config
from ..utils.constants import HEARTBEAT_TIMEOUT
from ..utils.logging import debug_log, log
from ..utils.network import build_worker_url, probe_worker
from .job_models import BaseJobState
from .job_store import ensure_tile_jobs_initialized
def _find_worker_record(worker_id):
"""Return worker config entry by id, or None when missing."""
workers = load_config().get("workers", [])
return next((w for w in workers if str(w.get("id")) == str(worker_id)), None)
async def _check_and_requeue_timed_out_workers(multi_job_id, total_tasks):
"""Check timed out workers and requeue their tasks. Returns requeued count."""
prompt_server = ensure_tile_jobs_initialized()
current_time = time.time()
# Allow override via config setting 'worker_timeout_seconds'
cfg = load_config()
hb_timeout = int(cfg.get("settings", {}).get("worker_timeout_seconds", HEARTBEAT_TIMEOUT))
# Snapshot timed-out workers and job details under lock.
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if not isinstance(job_data, BaseJobState):
return 0
completed_tasks_snapshot = set(job_data.completed_tasks.keys())
batched_static_snapshot = bool(job_data.batched_static)
num_tiles_per_image_snapshot = int(job_data.num_tiles_per_image or 1)
batch_size_snapshot = int(job_data.batch_size or 1)
timed_out_workers = []
for worker, last_heartbeat in list(job_data.worker_status.items()):
age = current_time - float(last_heartbeat)
debug_log(f"Timeout check: worker={worker} age={age:.1f}s threshold={hb_timeout}s")
if age > hb_timeout:
timed_out_workers.append(
{
"worker_id": worker,
"last_heartbeat": float(last_heartbeat),
"assigned_tasks": list(job_data.assigned_to_workers.get(worker, [])),
}
)
if not timed_out_workers:
return 0
# Probe outside lock to avoid lock contention on network latency.
workers_to_requeue = []
workers_graced = []
for worker_info in timed_out_workers:
worker = worker_info["worker_id"]
assigned = worker_info["assigned_tasks"]
age = current_time - worker_info["last_heartbeat"]
incomplete_assigned = 0
try:
if assigned:
if batched_static_snapshot:
for task_id in assigned:
for b in range(batch_size_snapshot):
gidx = b * num_tiles_per_image_snapshot + task_id
if gidx not in completed_tasks_snapshot:
incomplete_assigned += 1
break
else:
for task_id in assigned:
if task_id not in completed_tasks_snapshot:
incomplete_assigned += 1
debug_log(
f"Assigned diagnostics: total_assigned={len(assigned)} "
f"incomplete_assigned={incomplete_assigned}"
)
except Exception as e:
debug_log(f"Assigned diagnostics failed for worker {worker}: {e}")
busy = False
probe_queue = None
try:
worker_record = _find_worker_record(worker)
if worker_record:
worker_url = build_worker_url(worker_record)
debug_log(f"Probing worker {worker} at {worker_url}/prompt")
payload = await probe_worker(worker_url, timeout=2.0)
if payload is not None:
probe_queue = int(payload.get("exec_info", {}).get("queue_remaining", 0))
busy = probe_queue is not None and probe_queue > 0
else:
debug_log(f"Probe skipped; worker {worker} not found in config")
except Exception as e:
debug_log(f"Probe failed for worker {worker}: {e}")
finally:
debug_log(
f"Probe diagnostics: online={probe_queue is not None} queue_remaining={probe_queue}"
)
if busy:
workers_graced.append(worker)
debug_log(f"Heartbeat grace: worker {worker} busy via probe; skipping requeue")
continue
log(f"Worker {worker} heartbeat timed out after {age:.1f}s")
workers_to_requeue.append((worker, assigned))
# Re-acquire lock and apply requeue/cleanup decisions.
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if not isinstance(job_data, BaseJobState):
return 0
# Refresh heartbeat for workers that we proved are still busy.
for worker in workers_graced:
if worker in job_data.worker_status:
job_data.worker_status[worker] = current_time
requeued_count = 0
completed_tasks = job_data.completed_tasks
batched_static = bool(job_data.batched_static)
num_tiles_per_image = int(job_data.num_tiles_per_image or 1)
batch_size = int(job_data.batch_size or 1)
for worker, assigned_snapshot in workers_to_requeue:
# Use current assignments if present, falling back to the snapshot.
assigned_tasks = list(job_data.assigned_to_workers.get(worker, assigned_snapshot))
for task_id in assigned_tasks:
# If batched_static, task_id is a tile_idx; consider it complete only if
# all corresponding global_idx entries are present in completed_tasks.
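                # e.g. batch_size=2, num_tiles_per_image=4, task_id=1 → global indices 1 and 5 must both be complete.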
if batched_static:
all_done = True
for b in range(batch_size):
gidx = b * num_tiles_per_image + task_id
if gidx not in completed_tasks:
all_done = False
break
if not all_done:
await job_data.pending_tasks.put(task_id)
requeued_count += 1
else:
if task_id not in completed_tasks:
await job_data.pending_tasks.put(task_id)
requeued_count += 1
job_data.worker_status.pop(worker, None)
if worker in job_data.assigned_to_workers:
job_data.assigned_to_workers[worker] = []
return requeued_count
================================================
FILE: upscale/modes/__init__.py
================================================
================================================
FILE: upscale/modes/dynamic.py
================================================
import asyncio
import torch
from PIL import Image
import comfy.model_management
from ...utils.logging import debug_log, log
from ...utils.image import tensor_to_pil, pil_to_tensor
from ...utils.async_helpers import run_async_in_server_loop
from ...utils.config import get_worker_timeout_seconds
from ...utils.constants import TILE_WAIT_TIMEOUT, TILE_SEND_TIMEOUT
from ..job_store import ensure_tile_jobs_initialized, init_dynamic_job
class DynamicModeMixin:
"""
Dynamic (per-image queue) USDU mode behaviors for master and worker roles.
Expected co-mixins on `self`:
- TileOpsMixin (`calculate_tiles`, `_slice_conditioning`, `_process_and_blend_tile`).
- JobStateMixin (image queue/task completion helpers).
- WorkerCommsMixin (`_request_image_from_master`, `_send_full_image_to_master`, `_send_heartbeat_to_master`).
"""
def process_master_dynamic(self, upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur,
force_uniform_tiles, tiled_decode, multi_job_id, enabled_workers):
"""Dynamic mode for large batches - assigns whole images to workers dynamically, including master."""
# Get batch size and dimensions
batch_size, height, width, _ = upscaled_image.shape
num_workers = len(enabled_workers)
log(f"USDU Dist: Image queue distribution | Batch {batch_size} | Workers {num_workers} | Canvas {width}x{height} | Tile {tile_width}x{tile_height}")
# No fixed share - all images are dynamic
all_indices = list(range(batch_size))
debug_log(f"Processing {batch_size} images dynamically across master + {num_workers} workers.")
# Calculate tiles for processing
all_tiles = self.calculate_tiles(width, height, tile_width, tile_height, force_uniform_tiles)
# Initialize job queue for communication
try:
run_async_in_server_loop(
init_dynamic_job(multi_job_id, batch_size, enabled_workers, all_indices),
timeout=2.0
)
except Exception as e:
debug_log(f"UltimateSDUpscale Master - Queue initialization error: {e}")
raise RuntimeError(f"Failed to initialize dynamic mode queue: {e}")
# Convert batch to PIL list
result_images = [tensor_to_pil(upscaled_image[b:b+1], 0).convert('RGB').copy() for b in range(batch_size)]
# Process images dynamically with master participating
prompt_server = ensure_tile_jobs_initialized()
processed_count = 0
consecutive_retries = 0
max_consecutive_retries = 10
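        # Roughly 20s of idle polling (10 retries with 2s sleeps) before the loop falls through to the collection phase.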
# Process loop - master pulls from queue and processes synchronously
while processed_count < batch_size:
# Try to get an image to process
image_idx = run_async_in_server_loop(
self._get_next_image_index(multi_job_id),
timeout=5.0 # Short timeout to allow frequent checks
)
if image_idx is not None:
# Reset retry counter and process locally
consecutive_retries = 0
debug_log(f"Master processing image {image_idx} dynamically")
processed_count += 1
# Process locally
single_tensor = upscaled_image[image_idx:image_idx+1]
local_image = result_images[image_idx]
image_seed = seed
# Pre-slice conditioning once per image (not per tile)
positive_sliced, negative_sliced = self._slice_conditioning(positive, negative, image_idx)
for tile_idx, pos in enumerate(all_tiles):
source_tensor = pil_to_tensor(local_image)
if single_tensor.is_cuda:
source_tensor = source_tensor.cuda()
local_image = self._process_and_blend_tile(
tile_idx, pos, source_tensor, local_image,
model, positive_sliced, negative_sliced, vae, image_seed, steps, cfg,
sampler_name, scheduler, denoise, tile_width, tile_height,
padding, mask_blur, width, height, force_uniform_tiles,
tiled_decode, batch_idx=image_idx
)
# Yield after each tile to minimize worker downtime
run_async_in_server_loop(self._async_yield(), timeout=0.1)
# Note: No per-tile drain here – that's what makes this "per-image"
result_images[image_idx] = local_image
# Mark as completed
run_async_in_server_loop(
self._mark_image_completed(multi_job_id, image_idx, local_image),
timeout=5.0
)
                # Drain after the full image is marked complete (catches workers who finished during master's processing)
drained_count = run_async_in_server_loop(
self._drain_worker_results_queue(multi_job_id),
timeout=5.0
)
if drained_count > 0:
debug_log(f"Drained {drained_count} worker images after master's image {image_idx}")
                # Log overall progress (includes master's image plus any drained worker results)
completed_now = run_async_in_server_loop(
self._get_total_completed_count(multi_job_id),
timeout=1.0
)
log(f"USDU Dist: Images progress {completed_now}/{batch_size}")
# Yield to allow workers to get new images after completing one
run_async_in_server_loop(self._async_yield(), timeout=0.1)
else:
# Queue empty: collect any queued worker results to update progress
drained_count = run_async_in_server_loop(
self._drain_worker_results_queue(multi_job_id),
timeout=5.0
)
run_async_in_server_loop(self._async_yield(), timeout=0.1) # Yield after drain
# Check for timed out workers and requeue their images
requeued_count = run_async_in_server_loop(
self._check_and_requeue_timed_out_workers(multi_job_id, batch_size),
timeout=5.0
)
run_async_in_server_loop(self._async_yield(), timeout=0.1) # Yield after requeue
if requeued_count > 0:
log(f"Requeued {requeued_count} images from timed out workers")
consecutive_retries = 0 # Reset since we have work to do
continue
# Now check total completed (includes newly collected)
completed_now = run_async_in_server_loop(
self._get_total_completed_count(multi_job_id),
timeout=1.0
)
log(f"USDU Dist: Images progress {completed_now}/{batch_size}")
if completed_now >= batch_size:
break
run_async_in_server_loop(self._async_yield(), timeout=0.1) # Yield before pending check
# Check if there are pending images in the queue (could be requeued)
pending_count = run_async_in_server_loop(
self._get_pending_count(multi_job_id),
timeout=1.0
)
if pending_count > 0:
consecutive_retries = 0 # Reset retries since there's work to do
continue
consecutive_retries += 1
if consecutive_retries >= max_consecutive_retries:
log(f"Max retries ({max_consecutive_retries}) reached. Forcing collection of remaining results.")
break # Force exit to collection phase
debug_log("Waiting for workers")
# Use async sleep to allow event loop to process worker requests
run_async_in_server_loop(asyncio.sleep(2), timeout=3.0)
debug_log(f"Master processed {processed_count} images locally")
# Get all completed images to check what needs to be collected
all_completed = run_async_in_server_loop(
self._get_all_completed_images(multi_job_id),
timeout=5.0
)
# Calculate how many we still need to collect
remaining_to_collect = batch_size - len(all_completed)
if remaining_to_collect > 0:
debug_log(f"Waiting for {remaining_to_collect} more images from workers")
# Use the unified worker timeout for the collection phase
collection_timeout = float(get_worker_timeout_seconds())
collected_images = run_async_in_server_loop(
self._async_collect_dynamic_images(multi_job_id, remaining_to_collect, num_workers, batch_size, processed_count),
timeout=collection_timeout
)
# Merge collected with already completed
all_completed.update(collected_images)
# Update result images with all completed images
for idx, processed_img in all_completed.items():
if idx < batch_size:
result_images[idx] = processed_img
# Convert back to tensor
result_tensor = torch.cat([pil_to_tensor(img) for img in result_images], dim=0) if batch_size > 1 else pil_to_tensor(result_images[0])
if upscaled_image.is_cuda:
result_tensor = result_tensor.cuda()
debug_log(f"UltimateSDUpscale Master - Job {multi_job_id} complete")
log(f"Completed processing all {batch_size} images")
return (result_tensor,)
def process_worker_dynamic(self, upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur,
force_uniform_tiles, tiled_decode, multi_job_id, master_url,
worker_id, enabled_worker_ids, dynamic_threshold):
"""Worker processing in dynamic mode - processes whole images."""
# Round tile dimensions
tile_width = self.round_to_multiple(tile_width)
tile_height = self.round_to_multiple(tile_height)
# Get dimensions and tile grid
batch_size, height, width, _ = upscaled_image.shape
all_tiles = self.calculate_tiles(width, height, tile_width, tile_height, force_uniform_tiles)
log(f"USDU Dist Worker[{worker_id[:8]}]: Processing image queue | Batch {batch_size}")
# Keep track of processed images for is_last detection
processed_count = 0
# Poll for job readiness to avoid races during master init
max_poll_attempts = 20 # ~20s at 1s sleep
if not self._poll_job_ready(multi_job_id, master_url, worker_id=worker_id, max_attempts=max_poll_attempts):
log(f"Job {multi_job_id} not ready after {max_poll_attempts} attempts, aborting")
return (upscaled_image,)
# Loop to request and process images
while True:
# Request an image to process
image_idx, estimated_remaining = run_async_in_server_loop(
self._request_image_from_master(multi_job_id, master_url, worker_id),
timeout=TILE_WAIT_TIMEOUT
)
if image_idx is None:
debug_log(f"USDU Dist Worker - No more images to process")
break
debug_log(f"Worker[{worker_id[:8]}] - Assigned image {image_idx}")
processed_count += 1
# Determine if this should be marked as last for this worker
is_last_for_worker = (estimated_remaining == 0)
# Extract single image tensor
single_tensor = upscaled_image[image_idx:image_idx+1]
# Convert to PIL for processing
local_image = tensor_to_pil(single_tensor, 0).copy()
# Process all tiles for this image
image_seed = seed
# Pre-slice conditioning once per image (not per tile)
positive_sliced, negative_sliced = self._slice_conditioning(positive, negative, image_idx)
for tile_idx, pos in enumerate(all_tiles):
source_tensor = pil_to_tensor(local_image)
if single_tensor.is_cuda:
source_tensor = source_tensor.cuda()
local_image = self._process_and_blend_tile(
tile_idx, pos, source_tensor, local_image,
model, positive_sliced, negative_sliced, vae, image_seed, steps, cfg,
sampler_name, scheduler, denoise, tile_width, tile_height,
padding, mask_blur, width, height, force_uniform_tiles,
tiled_decode, batch_idx=image_idx
)
try:
run_async_in_server_loop(
self._send_heartbeat_to_master(multi_job_id, master_url, worker_id),
timeout=5.0
)
except Exception as e:
debug_log(f"Worker[{worker_id[:8]}] heartbeat failed: {e}")
# Send processed image back to master
try:
# is_last_for_worker was derived from estimated_remaining when the image was assigned
is_last = is_last_for_worker
run_async_in_server_loop(
self._send_full_image_to_master(local_image, image_idx, multi_job_id,
master_url, worker_id, is_last),
timeout=TILE_SEND_TIMEOUT
)
# Send heartbeat after processing
run_async_in_server_loop(
self._send_heartbeat_to_master(multi_job_id, master_url, worker_id),
timeout=5.0
)
if is_last:
break
except Exception as e:
log(f"USDU Dist Worker[{worker_id[:8]}] - Error sending image {image_idx}: {e}")
# Continue processing other images
# Send final is_last signal
debug_log(f"Worker[{worker_id[:8]}] processed {processed_count} images, sending completion signal")
try:
run_async_in_server_loop(
self._send_worker_complete_signal(multi_job_id, master_url, worker_id),
timeout=TILE_SEND_TIMEOUT
)
except Exception as e:
log(f"USDU Dist Worker[{worker_id[:8]}] - Error sending completion signal: {e}")
return (upscaled_image,)
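# Wire-protocol sketch for dynamic mode (field names as used by WorkerCommsMixin
# in upscale/worker_comms.py; a summary of the exchange above, not new behavior):
#   worker -> POST {master_url}/distributed/request_image
#             json: {"worker_id": ..., "multi_job_id": ...}
#   master -> json: {"image_idx": <int or null>, "estimated_remaining": <int>}
#   worker -> POST {master_url}/distributed/submit_image  (multipart form)
#             fields: multi_job_id, worker_id, image_idx, is_last, full_image (PNG)
#   A final submit_image with is_last=true and no image marks the worker done.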
================================================
FILE: upscale/modes/single_gpu.py
================================================
import math, torch
from PIL import Image
from ...utils.logging import debug_log, log
from ...utils.image import tensor_to_pil, pil_to_tensor
class SingleGpuModeMixin:
def process_single_gpu(self, upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur, force_uniform_tiles, tiled_decode):
"""Process all tiles on a single GPU (no distribution), batching per tile like USDU."""
# Round tile dimensions
tile_width = self.round_to_multiple(tile_width)
tile_height = self.round_to_multiple(tile_height)
# Get image dimensions and batch size
batch_size, height, width, _ = upscaled_image.shape
# Calculate all tiles
all_tiles = self.calculate_tiles(width, height, tile_width, tile_height, force_uniform_tiles)
rows = math.ceil(height / tile_height)
cols = math.ceil(width / tile_width)
log(
f"USDU Dist: Single GPU | Canvas {width}x{height} | Tile {tile_width}x{tile_height} | Grid {rows}x{cols} ({len(all_tiles)} tiles/image) | Batch {batch_size}"
)
# Prepare result images list
result_images = []
for b in range(batch_size):
image_pil = tensor_to_pil(upscaled_image[b:b+1], 0).convert('RGB')
result_images.append(image_pil.copy())
# Precompute tile masks once
tile_masks = []
for tx, ty in all_tiles:
tile_masks.append(self.create_tile_mask(width, height, tx, ty, tile_width, tile_height, mask_blur))
# Process tiles batched across images
for tile_idx, (tx, ty) in enumerate(all_tiles):
# Progressive state parity: extract each tile from the current updated image batch.
source_batch = torch.cat([pil_to_tensor(img) for img in result_images], dim=0)
if upscaled_image.is_cuda:
source_batch = source_batch.cuda()
# Extract batched tile
tile_batch, x1, y1, ew, eh = self.extract_batch_tile_with_padding(
source_batch, tx, ty, tile_width, tile_height, padding, force_uniform_tiles
)
# Process batch
region = (x1, y1, x1 + ew, y1 + eh)
processed_batch = self.process_tiles_batch(tile_batch, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tiled_decode, region, (width, height))
# Blend results back into each image using cached mask
tile_mask = tile_masks[tile_idx]
for b in range(batch_size):
tile_pil = tensor_to_pil(processed_batch, b)
# Resize back to extracted size
if tile_pil.size != (ew, eh):
tile_pil = tile_pil.resize((ew, eh), Image.LANCZOS)
result_images[b] = self.blend_tile(result_images[b], tile_pil, x1, y1, (ew, eh), tile_mask, padding)
# Convert back to tensor
result_tensors = [pil_to_tensor(img) for img in result_images]
result_tensor = torch.cat(result_tensors, dim=0)
if upscaled_image.is_cuda:
result_tensor = result_tensor.cuda()
return (result_tensor,)
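# Batching note (illustrative numbers): with batch_size=4 and a 2x3 grid
# (6 tiles/image), the loop above runs 6 sampler calls on [4, H', W', C] tile
# batches instead of 24 single-tile calls, mirroring USDU's per-tile batching.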
================================================
FILE: upscale/modes/static.py
================================================
import asyncio, time, torch
from PIL import Image
import comfy.model_management
from ...utils.logging import debug_log, log
from ...utils.image import tensor_to_pil, pil_to_tensor
from ...utils.async_helpers import run_async_in_server_loop
from ...utils.config import get_worker_timeout_seconds
from ...utils.constants import (
HEARTBEAT_INTERVAL,
JOB_POLL_INTERVAL,
JOB_POLL_MAX_ATTEMPTS,
MAX_BATCH,
TILE_SEND_TIMEOUT,
TILE_WAIT_TIMEOUT,
)
from ..job_store import (
ensure_tile_jobs_initialized, init_static_job_batched,
_mark_task_completed, _cleanup_job, _drain_results_queue, _get_completed_count,
)
from ..job_models import TileJobState
class StaticModeMixin:
"""
Static (tile-queue) USDU mode behaviors for master and worker roles.
Expected co-mixins on `self`:
- TileOpsMixin (`calculate_tiles`, tile extract/blend helpers).
- JobStateMixin (`_get_next_tile_index`, `_get_all_completed_tasks`, requeue checks).
- WorkerCommsMixin (`send_tiles_batch_to_master`, `_request_tile_from_master`, `_send_heartbeat_to_master`).
"""
def _poll_job_ready(self, multi_job_id, master_url, worker_id=None, max_attempts=JOB_POLL_MAX_ATTEMPTS):
"""Poll master for job readiness to avoid worker/master initialization race."""
for attempt in range(max_attempts):
ready = run_async_in_server_loop(
self._check_job_status(multi_job_id, master_url),
timeout=5.0
)
if ready:
if worker_id:
debug_log(f"Worker[{worker_id[:8]}] job {multi_job_id} ready after {attempt} attempts")
else:
debug_log(f"Job {multi_job_id} ready after {attempt} attempts")
return True
time.sleep(JOB_POLL_INTERVAL)
return False
def _extract_and_process_tile(
self,
upscaled_image,
tile_id,
all_tiles,
tile_width,
tile_height,
padding,
force_uniform_tiles,
model,
positive,
negative,
vae,
seed,
steps,
cfg,
sampler_name,
scheduler,
denoise,
tiled_decode,
width,
height,
):
"""Extract one tile position for the whole batch and process it."""
tx, ty = all_tiles[tile_id]
tile_batch, x1, y1, ew, eh = self.extract_batch_tile_with_padding(
upscaled_image, tx, ty, tile_width, tile_height, padding, force_uniform_tiles
)
region = (x1, y1, x1 + ew, y1 + eh)
processed_batch = self.process_tiles_batch(
tile_batch, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise, tiled_decode,
region, (width, height)
)
return processed_batch, x1, y1, ew, eh
def _flush_tiles_to_master(
self,
processed_tiles,
multi_job_id,
master_url,
padding,
worker_id,
is_final_flush=False,
):
"""Send accumulated tile payloads to master and return a fresh accumulator."""
if not processed_tiles:
if is_final_flush:
run_async_in_server_loop(
self.send_tiles_batch_to_master(
[],
multi_job_id,
master_url,
padding,
worker_id,
is_final_flush=True,
),
timeout=TILE_SEND_TIMEOUT,
)
return processed_tiles
run_async_in_server_loop(
self.send_tiles_batch_to_master(
processed_tiles,
multi_job_id,
master_url,
padding,
worker_id,
is_final_flush=is_final_flush,
),
timeout=TILE_SEND_TIMEOUT
)
return []
def _master_process_one_tile(
self,
tile_id,
all_tiles,
upscaled_image,
result_images,
tile_masks,
multi_job_id,
batch_size,
num_tiles_per_image,
tile_width,
tile_height,
padding,
force_uniform_tiles,
model,
positive,
negative,
vae,
seed,
steps,
cfg,
sampler_name,
scheduler,
denoise,
tiled_decode,
width,
height,
):
"""Process one tile_id across the batch and blend into result_images."""
source_batch = torch.cat([pil_to_tensor(img) for img in result_images], dim=0)
if upscaled_image.is_cuda:
source_batch = source_batch.cuda()
processed_batch, x1, y1, ew, eh = self._extract_and_process_tile(
source_batch,
tile_id,
all_tiles,
tile_width,
tile_height,
padding,
force_uniform_tiles,
model,
positive,
negative,
vae,
seed,
steps,
cfg,
sampler_name,
scheduler,
denoise,
tiled_decode,
width,
height,
)
tile_mask = tile_masks[tile_id]
out_bs = processed_batch.shape[0] if hasattr(processed_batch, "shape") else batch_size
processed_items = min(batch_size, out_bs)
for b in range(processed_items):
tile_pil = tensor_to_pil(processed_batch, b)
if tile_pil.size != (ew, eh):
tile_pil = tile_pil.resize((ew, eh), Image.LANCZOS)
result_images[b] = self.blend_tile(result_images[b], tile_pil, x1, y1, (ew, eh), tile_mask, padding)
global_idx = b * num_tiles_per_image + tile_id
run_async_in_server_loop(
_mark_task_completed(multi_job_id, global_idx, {'batch_idx': b, 'tile_idx': tile_id}),
timeout=5.0
)
return processed_items
def _process_worker_static_sync(self, upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur,
force_uniform_tiles, tiled_decode, multi_job_id, master_url,
worker_id, enabled_workers):
"""Worker static mode processing with optional dynamic queue pulling."""
# Round tile dimensions
tile_width = self.round_to_multiple(tile_width)
tile_height = self.round_to_multiple(tile_height)
# Get dimensions and calculate tiles
_, height, width, _ = upscaled_image.shape
all_tiles = self.calculate_tiles(width, height, tile_width, tile_height, force_uniform_tiles)
num_tiles_per_image = len(all_tiles)
batch_size = upscaled_image.shape[0]
total_tiles = batch_size * num_tiles_per_image
processed_tiles = []
working_images = []
for b in range(batch_size):
image_pil = tensor_to_pil(upscaled_image[b:b+1], 0)
working_images.append(image_pil.copy())
tile_masks = []
for tx, ty in all_tiles:
tile_masks.append(self.create_tile_mask(width, height, tx, ty, tile_width, tile_height, mask_blur))
# Static mode with dynamic queue pulling: workers fetch tile ids and process each id across the whole batch
log(f"USDU Dist Worker[{worker_id[:8]}]: Canvas {width}x{height} | Tile {tile_width}x{tile_height} | Tiles/image {num_tiles_per_image} | Batch {batch_size}")
processed_count = 0
max_poll_attempts = JOB_POLL_MAX_ATTEMPTS
if not self._poll_job_ready(multi_job_id, master_url, worker_id=worker_id, max_attempts=max_poll_attempts):
log(f"Job {multi_job_id} not ready after {max_poll_attempts} attempts, aborting")
return (upscaled_image,)
# Main processing loop - pull tile ids from queue
while True:
# Request a tile to process
tile_idx, estimated_remaining, batched_static = run_async_in_server_loop(
self._request_tile_from_master(multi_job_id, master_url, worker_id),
timeout=TILE_WAIT_TIMEOUT
)
if tile_idx is None:
debug_log(f"Worker[{worker_id[:8]}] - No more tiles to process")
break
# Always batched-per-tile in static mode
debug_log(f"Worker[{worker_id[:8]}] - Assigned tile_id {tile_idx}")
processed_count += batch_size
tile_id = tile_idx
source_batch = torch.cat([pil_to_tensor(img) for img in working_images], dim=0)
if upscaled_image.is_cuda:
source_batch = source_batch.cuda()
processed_batch, x1, y1, ew, eh = self._extract_and_process_tile(
source_batch,
tile_id,
all_tiles,
tile_width,
tile_height,
padding,
force_uniform_tiles,
model,
positive,
negative,
vae,
seed,
steps,
cfg,
sampler_name,
scheduler,
denoise,
tiled_decode,
width,
height,
)
# Queue results
for b in range(batch_size):
tile_pil = tensor_to_pil(processed_batch, b)
if tile_pil.size != (ew, eh):
tile_pil = tile_pil.resize((ew, eh), Image.LANCZOS)
working_images[b] = self.blend_tile(
working_images[b],
tile_pil,
x1,
y1,
(ew, eh),
tile_masks[tile_id],
padding,
)
processed_tiles.append({
'tile': processed_batch[b:b+1],
'tile_idx': tile_id,
'x': x1,
'y': y1,
'extracted_width': ew,
'extracted_height': eh,
'padding': padding,
'batch_idx': b,
'global_idx': b * num_tiles_per_image + tile_id
})
# Send heartbeat
try:
run_async_in_server_loop(
self._send_heartbeat_to_master(multi_job_id, master_url, worker_id),
timeout=5.0
)
except Exception as e:
debug_log(f"Worker[{worker_id[:8]}] heartbeat failed: {e}")
# Send tiles in batches within loop
if len(processed_tiles) >= MAX_BATCH:
processed_tiles = self._flush_tiles_to_master(
processed_tiles, multi_job_id, master_url, padding, worker_id, is_final_flush=False
)
# Send any remaining tiles
processed_tiles = self._flush_tiles_to_master(
processed_tiles, multi_job_id, master_url, padding, worker_id, is_final_flush=True
)
debug_log(f"Worker {worker_id} completed all assigned and requeued tiles")
return (upscaled_image,)
async def _async_collect_and_monitor_static(self, multi_job_id, total_tiles, expected_total):
"""Async helper for collection and monitoring in static mode.
Returns collected tasks dict. Caller should check if all tasks are complete."""
last_progress_log = time.time()
progress_interval = 5.0
last_heartbeat_check = time.time()
last_completed_count = 0
while True:
# Check for user interruption
if comfy.model_management.processing_interrupted():
log("Processing interrupted by user")
raise comfy.model_management.InterruptProcessingException()
# Drain any pending results
collected_count = await _drain_results_queue(multi_job_id)
# Check and requeue timed-out workers periodically
current_time = time.time()
if current_time - last_heartbeat_check >= HEARTBEAT_INTERVAL:
requeued_count = await self._check_and_requeue_timed_out_workers(multi_job_id, expected_total)
if requeued_count > 0:
log(f"Requeued {requeued_count} tasks from timed-out workers")
last_heartbeat_check = current_time
# Get current completion count
completed_count = await _get_completed_count(multi_job_id)
# Progress logging
if current_time - last_progress_log >= progress_interval:
log(f"Progress: {completed_count}/{expected_total} tasks completed")
last_progress_log = current_time
# Check if all tasks are completed
if completed_count >= expected_total:
debug_log(f"All {expected_total} tasks completed")
break
# If no active workers remain and there are pending tasks, return for local processing
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if isinstance(job_data, TileJobState):
pending_queue = job_data.pending_tasks
active_workers = list(job_data.worker_status.keys())
if pending_queue and not pending_queue.empty() and len(active_workers) == 0:
log(f"No active workers remaining with {expected_total - completed_count} tasks pending. Returning for local processing.")
break
# Wait a bit before next check
await asyncio.sleep(0.1)
# Get all completed tasks for return
return await self._get_all_completed_tasks(multi_job_id)
def _process_master_static_sync(self, upscaled_image, model, positive, negative, vae,
seed, steps, cfg, sampler_name, scheduler, denoise,
tile_width, tile_height, padding, mask_blur,
force_uniform_tiles, tiled_decode, multi_job_id, enabled_workers,
all_tiles, num_tiles_per_image):
"""Static mode master processing with optional dynamic queue pulling."""
batch_size = upscaled_image.shape[0]
_, height, width, _ = upscaled_image.shape
total_tiles = batch_size * num_tiles_per_image
# Convert batch to PIL list for processing
result_images = []
for b in range(batch_size):
image_pil = tensor_to_pil(upscaled_image[b:b+1], 0)
result_images.append(image_pil.copy())
# Initialize queue: pending queue holds tile ids (batched per tile)
log("USDU Dist: Using tile queue distribution")
run_async_in_server_loop(
init_static_job_batched(multi_job_id, batch_size, num_tiles_per_image, enabled_workers),
timeout=10.0
)
debug_log(
f"Initialized tile-id queue with {num_tiles_per_image} ids for batch {batch_size}"
)
# Precompute masks for all tile positions to avoid repeated Gaussian blur work during blending
tile_masks = []
for idx, (tx, ty) in enumerate(all_tiles):
tile_masks.append(self.create_tile_mask(width, height, tx, ty, tile_width, tile_height, mask_blur))
processed_count = 0
consecutive_no_tile = 0
max_consecutive_no_tile = 2
while processed_count < total_tiles:
comfy.model_management.throw_exception_if_processing_interrupted()
tile_idx = run_async_in_server_loop(
self._get_next_tile_index(multi_job_id),
timeout=5.0
)
if tile_idx is not None:
consecutive_no_tile = 0
tile_id = tile_idx
processed_count += self._master_process_one_tile(
tile_id,
all_tiles,
upscaled_image,
result_images,
tile_masks,
multi_job_id,
batch_size,
num_tiles_per_image,
tile_width,
tile_height,
padding,
force_uniform_tiles,
model,
positive,
negative,
vae,
seed,
steps,
cfg,
sampler_name,
scheduler,
denoise,
tiled_decode,
width,
height,
)
log(f"USDU Dist: Tiles progress {processed_count}/{total_tiles} (tile {tile_id})")
else:
consecutive_no_tile += 1
if consecutive_no_tile >= max_consecutive_no_tile:
debug_log(f"Master processed {processed_count} tiles, moving to collection phase")
break
time.sleep(0.1)
master_processed_count = processed_count
# Continue processing any remaining tiles while collecting worker results
remaining_tiles = total_tiles - master_processed_count
if remaining_tiles > 0:
debug_log(f"Master waiting for {remaining_tiles} tiles from workers")
# Collect worker results using async operations
try:
# Wait until either all tasks are collected or there are no active workers left
collected_tasks = run_async_in_server_loop(
self._async_collect_and_monitor_static(multi_job_id, total_tiles, expected_total=total_tiles),
timeout=None
)
except comfy.model_management.InterruptProcessingException:
# Clean up job on interruption
run_async_in_server_loop(_cleanup_job(multi_job_id), timeout=5.0)
raise
# Check if we need to process any remaining tasks locally after collection
completed_count = len(collected_tasks)
if completed_count < total_tiles:
log(f"Processing remaining {total_tiles - completed_count} tasks locally after worker failures")
# Process any remaining pending tasks (batched-per-tile)
while True:
# Check for user interruption
comfy.model_management.throw_exception_if_processing_interrupted()
# Get next tile_id from pending queue
tile_id = run_async_in_server_loop(
self._get_next_tile_index(multi_job_id),
timeout=5.0
)
if tile_id is None:
break
self._master_process_one_tile(
tile_id,
all_tiles,
upscaled_image,
result_images,
tile_masks,
multi_job_id,
batch_size,
num_tiles_per_image,
tile_width,
tile_height,
padding,
force_uniform_tiles,
model,
positive,
negative,
vae,
seed,
steps,
cfg,
sampler_name,
scheduler,
denoise,
tiled_decode,
width,
height,
)
else:
# Master processed all tiles
collected_tasks = run_async_in_server_loop(
self._get_all_completed_tasks(multi_job_id),
timeout=5.0
)
# Blend worker tiles synchronously in deterministic tile order.
def _sort_key(item):
global_idx, tile_data = item
batch_idx = tile_data.get('batch_idx', global_idx // num_tiles_per_image)
tile_idx = tile_data.get('tile_idx', global_idx % num_tiles_per_image)
return (tile_idx, batch_idx, global_idx)
for global_idx, tile_data in sorted(collected_tasks.items(), key=_sort_key):
# Skip entries with no pixel payload (tiles the master already blended locally)
if 'tensor' not in tile_data and 'image' not in tile_data:
continue
batch_idx = tile_data.get('batch_idx', global_idx // num_tiles_per_image)
tile_idx = tile_data.get('tile_idx', global_idx % num_tiles_per_image)
if batch_idx >= batch_size:
continue
# Blend tile synchronously
x = tile_data.get('x', 0)
y = tile_data.get('y', 0)
# Prefer PIL image if present to avoid reconversion
if 'image' in tile_data:
tile_pil = tile_data['image']
else:
tile_tensor = tile_data['tensor']
tile_pil = tensor_to_pil(tile_tensor, 0)
orig_x, orig_y = all_tiles[tile_idx]
tile_mask = tile_masks[tile_idx]
extracted_width = tile_data.get('extracted_width', tile_width + 2 * padding)
extracted_height = tile_data.get('extracted_height', tile_height + 2 * padding)
result_images[batch_idx] = self.blend_tile(result_images[batch_idx], tile_pil,
x, y, (extracted_width, extracted_height), tile_mask, padding)
try:
# Convert back to tensor
if batch_size == 1:
result_tensor = pil_to_tensor(result_images[0])
else:
result_tensors = [pil_to_tensor(img) for img in result_images]
result_tensor = torch.cat(result_tensors, dim=0)
if upscaled_image.is_cuda:
result_tensor = result_tensor.cuda()
log(f"UltimateSDUpscale Master - Job {multi_job_id} complete")
return (result_tensor,)
finally:
# Cleanup (async operation) - always execute
run_async_in_server_loop(_cleanup_job(multi_job_id), timeout=5.0)
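# Master static-mode flow (summary of _process_master_static_sync above):
#   1) pull tile ids from the shared pending queue alongside workers;
#   2) once the queue runs dry, collect and monitor worker results, requeueing
#      tasks from workers whose heartbeats time out;
#   3) process any still-pending (requeued) tile ids locally;
#   4) blend collected tiles in deterministic (tile_idx, batch_idx) order,
#      then clean up the job state.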
================================================
FILE: upscale/payload_parsers.py
================================================
import io
import json
from PIL import Image
def _parse_tiles_from_form(data):
"""Parse tiles submitted via multipart/form-data into a list of tile dicts."""
try:
padding = int(data.get('padding', 0)) if data.get('padding') is not None else 0
except Exception:
padding = 0
meta_raw = data.get('tiles_metadata')
if meta_raw is None:
raise ValueError("Missing tiles_metadata")
try:
metadata = json.loads(meta_raw)
except Exception as e:
raise ValueError(f"Invalid tiles_metadata JSON: {e}")
if not isinstance(metadata, list):
raise ValueError("tiles_metadata must be a list")
tiles = []
for i, meta in enumerate(metadata):
file_field = data.get(f'tile_{i}')
if file_field is None or not hasattr(file_field, 'file'):
raise ValueError(f"Missing tile data for index {i}")
raw = file_field.file.read()
try:
img = Image.open(io.BytesIO(raw)).convert("RGB")
except Exception as e:
raise ValueError(f"Invalid image data for tile {i}: {e}")
try:
tile_info = {
'image': img,
'tile_idx': int(meta.get('tile_idx', i)),
'x': int(meta.get('x', 0)),
'y': int(meta.get('y', 0)),
'extracted_width': int(meta.get('extracted_width', img.width)),
'extracted_height': int(meta.get('extracted_height', img.height)),
'padding': int(padding),
}
except Exception as e:
raise ValueError(f"Invalid metadata values for tile {i}: {e}")
if 'batch_idx' in meta:
try:
tile_info['batch_idx'] = int(meta['batch_idx'])
except Exception:
pass
if 'global_idx' in meta:
try:
tile_info['global_idx'] = int(meta['global_idx'])
except Exception:
pass
tiles.append(tile_info)
return tiles
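# Expected multipart payload (sketch with illustrative values; the producer is
# WorkerCommsMixin.send_tiles_batch_to_master in upscale/worker_comms.py):
#   padding:        "32"
#   tiles_metadata: '[{"tile_idx": 0, "x": 0, "y": 0, "extracted_width": 576,
#                      "extracted_height": 576, "batch_idx": 0, "global_idx": 0}]'
#   tile_0:         <PNG file field>  # one tile_{i} file per metadata entry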
================================================
FILE: upscale/result_collector.py
================================================
import asyncio, time
import comfy.model_management
import server
from ..utils.constants import DYNAMIC_MODE_MAX_POLL_TIMEOUT, HEARTBEAT_INTERVAL
from ..utils.logging import debug_log, log
from ..utils.config import get_worker_timeout_seconds
from .job_store import ensure_tile_jobs_initialized, _mark_task_completed
from .job_timeout import _check_and_requeue_timed_out_workers
from .job_models import BaseJobState, ImageJobState, TileJobState
class ResultCollectorMixin:
"""
Mixin for master-side result collection in USDU distributed jobs.
Expected co-mixins/attributes:
- JobStateMixin methods for queue/task access.
- `self._check_and_requeue_timed_out_workers(...)` coroutine.
- `self._async_yield(...)` optional helper from WorkerCommsMixin.
"""
def _log_worker_timeout_status(self, job_data, current_time: float, multi_job_id: str) -> list[str]:
"""Log timeout elapsed seconds for each tracked worker and return worker ids."""
if not isinstance(job_data, BaseJobState):
return []
worker_status = dict(job_data.worker_status)
for worker_id, last_seen in worker_status.items():
elapsed = max(0.0, current_time - float(last_seen))
log(
"UltimateSDUpscale Master - Heartbeat timeout: "
f"job={multi_job_id}, worker={worker_id}, elapsed={elapsed:.1f}s"
)
return list(worker_status.keys())
async def _async_collect_results(self, multi_job_id, num_workers, mode='static',
remaining_to_collect=None, batch_size=None):
"""Unified async helper to collect results from workers (tiles or images)."""
# Get the already initialized queue
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
if multi_job_id not in prompt_server.distributed_pending_tile_jobs:
raise RuntimeError(f"Job queue not initialized for {multi_job_id}")
job_data = prompt_server.distributed_pending_tile_jobs[multi_job_id]
if mode == 'dynamic':
if not isinstance(job_data, ImageJobState):
raise RuntimeError(
f"Mode mismatch: expected dynamic, got {getattr(job_data, 'mode', 'unknown')}"
)
q = job_data.queue
completed_images = job_data.completed_images
expected_count = remaining_to_collect or batch_size
elif mode == 'static':
if not isinstance(job_data, TileJobState):
raise RuntimeError(
f"Mode mismatch: expected static, got {getattr(job_data, 'mode', 'unknown')}"
)
q = job_data.queue
expected_count = len(job_data.completed_tasks) + job_data.pending_tasks.qsize()
else:
raise RuntimeError(f"Unsupported mode: {mode}")
item_type = "images" if mode == 'dynamic' else "tiles"
debug_log(f"UltimateSDUpscale Master - Starting collection, expecting {expected_count} {item_type} from {num_workers} workers")
collected_results = {}
workers_done = set()
# Unify collector/upscaler wait behavior with the UI worker timeout
timeout = float(get_worker_timeout_seconds())
last_heartbeat_check = time.time()
wait_started_at = time.time()
collected_count = 0
while len(workers_done) < num_workers:
# Check for user interruption
if comfy.model_management.processing_interrupted():
log("Processing interrupted by user")
raise comfy.model_management.InterruptProcessingException()
# For dynamic mode with remaining_to_collect, check if we've collected enough
if mode == 'dynamic' and remaining_to_collect and collected_count >= remaining_to_collect:
break
job_data_snapshot = None
async with prompt_server.distributed_tile_jobs_lock:
current_job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if isinstance(current_job_data, BaseJobState):
job_data_snapshot = current_job_data
try:
# Shorter poll for dynamic mode, but never exceed the configured timeout
wait_timeout = (min(DYNAMIC_MODE_MAX_POLL_TIMEOUT, timeout) if mode == 'dynamic' else timeout)
result = await asyncio.wait_for(q.get(), timeout=wait_timeout)
worker_id = result['worker_id']
is_last = result.get('is_last', False)
if mode == 'static':
# Handle tiles
tiles = result.get('tiles', [])
debug_log(
f"UltimateSDUpscale Master - Received batch of {len(tiles)} tiles from worker "
f"'{worker_id}' (is_last={is_last})"
)
for tile_data in tiles:
if 'batch_idx' not in tile_data:
log("UltimateSDUpscale Master - Missing batch_idx in tile data, skipping")
continue
tile_idx = tile_data['tile_idx']
key = tile_data.get('global_idx', tile_idx)
entry = {
'tile_idx': tile_idx,
'x': tile_data['x'],
'y': tile_data['y'],
'extracted_width': tile_data['extracted_width'],
'extracted_height': tile_data['extracted_height'],
'padding': tile_data['padding'],
'worker_id': worker_id,
'batch_idx': tile_data.get('batch_idx', 0),
'global_idx': tile_data.get('global_idx', tile_idx),
}
if 'image' in tile_data:
entry['image'] = tile_data['image']
elif 'tensor' in tile_data:
entry['tensor'] = tile_data['tensor']
collected_results[key] = entry
elif mode == 'dynamic':
# Handle full images
if 'image_idx' in result and 'image' in result:
image_idx = result['image_idx']
image_pil = result['image']
completed_images[image_idx] = image_pil
collected_results[image_idx] = image_pil
collected_count += 1
debug_log(f"UltimateSDUpscale Master - Received image {image_idx} from worker {worker_id}")
if is_last:
workers_done.add(worker_id)
debug_log(f"UltimateSDUpscale Master - Worker {worker_id} completed")
except asyncio.TimeoutError:
current_time = time.time()
waiting_workers = self._log_worker_timeout_status(job_data_snapshot, current_time, multi_job_id)
if mode == 'dynamic':
# Check for worker timeouts periodically
if current_time - last_heartbeat_check >= HEARTBEAT_INTERVAL:
# Use the class method to check and requeue
requeued = await self._check_and_requeue_timed_out_workers(multi_job_id, batch_size)
if requeued > 0:
log(f"UltimateSDUpscale Master - Requeued {requeued} images from timed out workers")
last_heartbeat_check = current_time
# Check if we've been waiting too long overall
if current_time - wait_started_at > timeout:
elapsed = current_time - wait_started_at
log(
"UltimateSDUpscale Master - Heartbeat timeout while waiting for images; "
f"workers={waiting_workers}, elapsed={elapsed:.1f}s"
)
break
else:
elapsed = current_time - wait_started_at
log(
f"UltimateSDUpscale Master - Heartbeat timeout waiting for {item_type}; "
f"workers={waiting_workers}, elapsed={elapsed:.1f}s"
)
break
debug_log(f"UltimateSDUpscale Master - Collection complete. Got {len(collected_results)} {item_type} from {len(workers_done)} workers")
# Clean up job queue
async with prompt_server.distributed_tile_jobs_lock:
if multi_job_id in prompt_server.distributed_pending_tile_jobs:
del prompt_server.distributed_pending_tile_jobs[multi_job_id]
return collected_results if mode == 'static' else completed_images
async def _async_collect_worker_tiles(self, multi_job_id, num_workers):
"""Async helper to collect tiles from workers."""
return await self._async_collect_results(multi_job_id, num_workers, mode='static')
async def _mark_image_completed(self, multi_job_id, image_idx, image_pil):
"""Mark an image as completed in the job data."""
# Mark the image as completed with the image data
await _mark_task_completed(multi_job_id, image_idx, {'image': image_pil})
prompt_server = ensure_tile_jobs_initialized()
async with prompt_server.distributed_tile_jobs_lock:
job_data = prompt_server.distributed_pending_tile_jobs.get(multi_job_id)
if isinstance(job_data, ImageJobState):
job_data.completed_images[image_idx] = image_pil
async def _async_collect_dynamic_images(self, multi_job_id, remaining_to_collect, num_workers, batch_size, master_processed_count):
"""Collect remaining processed images from workers."""
return await self._async_collect_results(multi_job_id, num_workers, mode='dynamic',
remaining_to_collect=remaining_to_collect,
batch_size=batch_size)
================================================
FILE: upscale/tile_ops.py
================================================
import math, torch
from contextlib import nullcontext
from PIL import Image, ImageFilter, ImageDraw
from typing import List, Tuple
import comfy.samplers, comfy.model_management
from ..utils.logging import debug_log, log
from ..utils.image import tensor_to_pil, pil_to_tensor
from ..utils.usdu_utils import crop_cond, get_crop_region, expand_crop
from ..utils.crop_model_patch import crop_model_cond
from .conditioning import clone_conditioning
class TileOpsMixin:
def round_to_multiple(self, value: int, multiple: int = 8) -> int:
"""Round value to nearest multiple."""
return round(value / multiple) * multiple
def calculate_tiles(self, image_width: int, image_height: int,
tile_width: int, tile_height: int, force_uniform_tiles: bool = True) -> List[Tuple[int, int]]:
"""Calculate tile positions to match Ultimate SD Upscale.
Positions are a simple grid starting at (0,0) with steps of
`tile_width` and `tile_height`, using ceil(rows/cols) to cover edges.
Uniform vs non-uniform affects only crop/resize, not positions.
"""
rows = math.ceil(image_height / tile_height)
cols = math.ceil(image_width / tile_width)
tiles: List[Tuple[int, int]] = []
for yi in range(rows):
for xi in range(cols):
tiles.append((xi * tile_width, yi * tile_height))
return tiles
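# Worked example (illustrative): calculate_tiles(1024, 768, 512, 512)
#   rows = ceil(768/512) = 2, cols = ceil(1024/512) = 2
#   -> [(0, 0), (512, 0), (0, 512), (512, 512)]
# Edge positions may overhang the canvas; the mask drawn in the extract helpers
# is canvas-sized, so get_crop_region clips the actual crop to image bounds.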
def extract_tile_with_padding(self, image: torch.Tensor, x: int, y: int,
tile_width: int, tile_height: int, padding: int,
force_uniform_tiles: bool) -> Tuple[torch.Tensor, int, int, int, int]:
"""Extract a tile region and resize to match USDU cropping logic.
Mirrors ComfyUI_UltimateSDUpscale processing:
- Build a mask with a white rectangle at the tile rect
- Compute crop_region via get_crop_region(mask, padding)
- If force_uniform_tiles: expand by crop/aspect ratio, then resize to
fixed processing size of round_to_multiple(tile + padding)
- Else: target is ceil(crop_size/8)*8 per dimension
- Extract the crop and resize to target tile_size
Returns the resized tensor and crop origin/size for blending.
"""
_, h, w, _ = image.shape
# Create mask and compute initial padded crop region
mask = Image.new('L', (w, h), 0)
draw = ImageDraw.Draw(mask)
draw.rectangle([x, y, x + tile_width, y + tile_height], fill=255)
x1, y1, x2, y2 = get_crop_region(mask, padding)
# Determine crop + processing size
if force_uniform_tiles:
process_w = self.round_to_multiple(tile_width + padding, 8)
process_h = self.round_to_multiple(tile_height + padding, 8)
crop_w = x2 - x1
crop_h = y2 - y1
crop_ratio = crop_w / crop_h if crop_h != 0 else 1.0
process_ratio = process_w / process_h if process_h != 0 else 1.0
if crop_ratio > process_ratio:
target_w = crop_w
target_h = round(crop_w / process_ratio) if process_ratio != 0 else crop_h
else:
target_w = round(crop_h * process_ratio)
target_h = crop_h
(x1, y1, x2, y2), _ = expand_crop((x1, y1, x2, y2), w, h, target_w, target_h)
target_w = process_w
target_h = process_h
else:
crop_w = x2 - x1
crop_h = y2 - y1
target_w = max(8, math.ceil(crop_w / 8) * 8)
target_h = max(8, math.ceil(crop_h / 8) * 8)
(x1, y1, x2, y2), (target_w, target_h) = expand_crop((x1, y1, x2, y2), w, h, target_w, target_h)
# Actual extracted size before resizing (for blending)
extracted_width = x2 - x1
extracted_height = y2 - y1
# Extract tile and resize to processing size
tile = image[:, y1:y2, x1:x2, :]
tile_pil = tensor_to_pil(tile, 0)
if tile_pil.size != (target_w, target_h):
tile_pil = tile_pil.resize((target_w, target_h), Image.LANCZOS)
tile_tensor = pil_to_tensor(tile_pil)
if image.is_cuda:
tile_tensor = tile_tensor.cuda()
return tile_tensor, x1, y1, extracted_width, extracted_height
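# Numeric sketch of the force_uniform_tiles path (illustrative values):
# a 512x512 tile with padding 32 -> process size round_to_multiple(544, 8) = 544x544.
# The padded crop is first expanded to the 1:1 processing aspect ratio, then the
# extracted pixels are resized to 544x544 for sampling; blend_tile later resizes
# the result back to the recorded extracted size.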
def extract_batch_tile_with_padding(self, images: torch.Tensor, x: int, y: int,
tile_width: int, tile_height: int, padding: int,
force_uniform_tiles: bool) -> Tuple[torch.Tensor, int, int, int, int]:
"""Extract a tile region for the entire batch and resize to USDU logic.
- Computes a single crop region from a mask at (x,y,w,h) with padding
- force_uniform_tiles controls target processing size logic
- Returns a batched tensor [B,H',W',C] and crop origin/size for blending
"""
batch, h, w, _ = images.shape
# Create mask and compute initial padded crop region (same for all images)
mask = Image.new('L', (w, h), 0)
draw = ImageDraw.Draw(mask)
draw.rectangle([x, y, x + tile_width, y + tile_height], fill=255)
x1, y1, x2, y2 = get_crop_region(mask, padding)
# Determine crop + processing size
if force_uniform_tiles:
process_w = self.round_to_multiple(tile_width + padding, 8)
process_h = self.round_to_multiple(tile_height + padding, 8)
crop_w = x2 - x1
crop_h = y2 - y1
crop_ratio = crop_w / crop_h if crop_h != 0 else 1.0
process_ratio = process_w / process_h if process_h != 0 else 1.0
if crop_ratio > process_ratio:
target_w = crop_w
target_h = round(crop_w / process_ratio) if process_ratio != 0 else crop_h
else:
target_w = round(crop_h * process_ratio)
target_h = crop_h
(x1, y1, x2, y2), _ = expand_crop((x1, y1, x2, y2), w, h, target_w, target_h)
target_w = process_w
target_h = process_h
else:
crop_w = x2 - x1
crop_h = y2 - y1
target_w = max(8, math.ceil(crop_w / 8) * 8)
target_h = max(8, math.ceil(crop_h / 8) * 8)
(x1, y1, x2, y2), (target_w, target_h) = expand_crop((x1, y1, x2, y2), w, h, target_w, target_h)
extracted_width = x2 - x1
extracted_height = y2 - y1
# Slice batch region
tiles = images[:, y1:y2, x1:x2, :]
# Resize each tile to target size
resized_tiles = []
for i in range(batch):
tile_pil = tensor_to_pil(tiles, i)
if tile_pil.size != (target_w, target_h):
tile_pil = tile_pil.resize((target_w, target_h), Image.LANCZOS)
resized_tiles.append(pil_to_tensor(tile_pil))
tile_batch = torch.cat(resized_tiles, dim=0)
if images.is_cuda:
tile_batch = tile_batch.cuda()
return tile_batch, x1, y1, extracted_width, extracted_height
def process_tile(self, tile_tensor: torch.Tensor, model, positive, negative, vae,
seed: int, steps: int, cfg: float, sampler_name: str,
scheduler: str, denoise: float, tiled_decode: bool = False,
batch_idx: int = 0, region: Tuple[int, int, int, int] = None,
image_size: Tuple[int, int] = None) -> torch.Tensor:
"""Process a single tile through SD sampling.
Note: positive and negative should already be pre-sliced for the current batch_idx."""
debug_log(f"[process_tile] Processing tile for batch_idx={batch_idx}, seed={seed}, region={region}")
# Import here to avoid circular dependencies
from nodes import common_ksampler, VAEEncode, VAEDecode
# Try to import tiled VAE nodes if available
try:
from nodes import VAEEncodeTiled, VAEDecodeTiled
tiled_vae_available = True
except ImportError:
tiled_vae_available = False
if tiled_decode:
debug_log("Tiled VAE nodes not available, falling back to standard VAE")
# Convert to PIL and back to ensure clean tensor without gradient tracking
tile_pil = tensor_to_pil(tile_tensor, 0)
clean_tensor = pil_to_tensor(tile_pil)
# Ensure tensor is detached and doesn't require gradients
clean_tensor = clean_tensor.detach()
if hasattr(clean_tensor, 'requires_grad_'):
clean_tensor.requires_grad_(False)
# Move to correct device
if tile_tensor.is_cuda:
clean_tensor = clean_tensor.cuda()
clean_tensor = clean_tensor.detach() # Detach again after device transfer
# Clone conditioning per tile (shares models, clones hints for cropping)
positive_tile = clone_conditioning(positive, clone_hints=True)
negative_tile = clone_conditioning(negative, clone_hints=True)
# Crop conditioning to tile region if provided (assumes hints at image resolution)
if region is not None and image_size is not None:
init_size = image_size # (width, height) of full image
canvas_size = image_size
tile_size = (tile_tensor.shape[2], tile_tensor.shape[1]) # (width, height)
w_pad = 0 # No extra pad needed; region already includes padding
h_pad = 0
positive_cropped = crop_cond(positive_tile, region, init_size, canvas_size, tile_size, w_pad, h_pad)
negative_cropped = crop_cond(negative_tile, region, init_size, canvas_size, tile_size, w_pad, h_pad)
else:
# No region cropping needed, use cloned conditioning as-is
positive_cropped = positive_tile
negative_cropped = negative_tile
# Encode to latent (always non-tiled, matching original node)
latent = VAEEncode().encode(vae, clean_tensor)[0]
# Sample with model patch cropping parity (ControlNet patch hints)
if region is not None and image_size is not None:
model_ctx = crop_model_cond(
model,
region,
image_size,
image_size,
(clean_tensor.shape[2], clean_tensor.shape[1]),
)
else:
model_ctx = nullcontext(model)
with model_ctx as model_for_sampling:
samples = common_ksampler(
model_for_sampling, seed, steps, cfg, sampler_name, scheduler,
positive_cropped, negative_cropped, latent, denoise=denoise
)[0]
# Decode back to image
if tiled_decode and tiled_vae_available:
image = VAEDecodeTiled().decode(vae, samples, tile_size=512)[0]
else:
image = VAEDecode().decode(vae, samples)[0]
return image
def process_tiles_batch(self, tile_batch: torch.Tensor, model, positive, negative, vae,
seed: int, steps: int, cfg: float, sampler_name: str,
scheduler: str, denoise: float, tiled_decode: bool,
region: Tuple[int, int, int, int], image_size: Tuple[int, int]) -> torch.Tensor:
"""Process a batch of tiles together (USDU behavior).
tile_batch: [B, H, W, C]
Returns image batch tensor [B, H, W, C]
"""
# Import locally to avoid circular deps
from nodes import common_ksampler, VAEEncode, VAEDecode
try:
from nodes import VAEEncodeTiled, VAEDecodeTiled
tiled_vae_available = True
except ImportError:
tiled_vae_available = False
# Detach and move device
clean = tile_batch.detach()
if hasattr(clean, 'requires_grad_'):
clean.requires_grad_(False)
if tile_batch.is_cuda:
clean = clean.cuda().detach()
# Clone/crop conditioning once for the region
positive_tile = clone_conditioning(positive, clone_hints=True)
negative_tile = clone_conditioning(negative, clone_hints=True)
init_size = image_size
canvas_size = image_size
tile_size = (clean.shape[2], clean.shape[1]) # (W,H)
w_pad = 0
h_pad = 0
positive_cropped = crop_cond(positive_tile, region, init_size, canvas_size, tile_size, w_pad, h_pad)
negative_cropped = crop_cond(negative_tile, region, init_size, canvas_size, tile_size, w_pad, h_pad)
# Encode -> Sample -> Decode
latent = VAEEncode().encode(vae, clean)[0]
with crop_model_cond(model, region, image_size, image_size, tile_size) as model_for_sampling:
samples = common_ksampler(
model_for_sampling, seed, steps, cfg, sampler_name, scheduler,
positive_cropped, negative_cropped, latent, denoise=denoise
)[0]
if tiled_decode and tiled_vae_available:
image = VAEDecodeTiled().decode(vae, samples, tile_size=512)[0]
else:
image = VAEDecode().decode(vae, samples)[0]
return image
def create_tile_mask(self, image_width: int, image_height: int,
x: int, y: int, tile_width: int, tile_height: int,
mask_blur: int) -> Image.Image:
"""Create a mask for blending tiles - matches Ultimate SD Upscale approach.
Creates a black image with a white rectangle at the tile position,
then applies blur to create soft edges.
"""
# Create a full-size mask matching the image dimensions
mask = Image.new('L', (image_width, image_height), 0) # Black background
# Draw white rectangle at tile position
draw = ImageDraw.Draw(mask)
draw.rectangle([x, y, x + tile_width, y + tile_height], fill=255)
# Apply blur to soften edges
if mask_blur > 0:
mask = mask.filter(ImageFilter.GaussianBlur(mask_blur))
return mask
def blend_tile(self, base_image: Image.Image, tile_image: Image.Image,
x: int, y: int, extracted_size: Tuple[int, int],
mask: Image.Image, padding: int) -> Image.Image:
"""Blend a processed tile back into the base image using Ultimate SD Upscale's exact approach.
This follows the exact method from ComfyUI_UltimateSDUpscale/modules/processing.py
"""
extracted_width, extracted_height = extracted_size
# Debug logging (uncomment if needed)
# debug_log(f"[Blend] Placing tile at ({x}, {y}), size: {extracted_width}x{extracted_height}")
# Calculate the crop region that was used for extraction
crop_region = (x, y, x + extracted_width, y + extracted_height)
# The mask is already full-size, no need to crop
# Resize the processed tile back to the extracted size
if tile_image.size != (extracted_width, extracted_height):
tile_resized = tile_image.resize((extracted_width, extracted_height), Image.LANCZOS)
else:
tile_resized = tile_image
# Follow Ultimate SD Upscale blending approach:
# Put the tile into position
image_tile_only = Image.new('RGBA', base_image.size)
image_tile_only.paste(tile_resized, crop_region[:2])
# Add the mask as an alpha channel
# Must make a copy due to the possibility of an edge becoming black
temp = image_tile_only.copy()
temp.putalpha(mask) # Use the full image mask
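# Two-argument Pillow paste(im, mask): image_tile_only serves as its own mask
# (alpha 255 inside the pasted tile, 0 elsewhere), so temp's blurred-alpha pixels
# replace content only within the tile rectangle.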
image_tile_only.paste(temp, image_tile_only)
# Add back the tile to the initial image according to the mask in the alpha channel
result = base_image.convert('RGBA')
result.alpha_composite(image_tile_only)
# Convert back to RGB
return result.convert('RGB')
def _slice_conditioning(self, positive, negative, batch_idx):
"""Helper to slice conditioning for a specific batch index."""
# Clone and slice conditioning properly, including ControlNet hints
positive_sliced = clone_conditioning(positive)
negative_sliced = clone_conditioning(negative)
for cond_list in [positive_sliced, negative_sliced]:
for i in range(len(cond_list)):
emb, cond_dict = cond_list[i]
if emb.shape[0] > 1:
cond_list[i][0] = emb[batch_idx:batch_idx+1]
if 'control' in cond_dict:
control = cond_dict['control']
while control is not None:
hint = control.cond_hint_original
if hint.shape[0] > 1:
control.cond_hint_original = hint[batch_idx:batch_idx+1]
control = control.previous_controlnet
if 'mask' in cond_dict and cond_dict['mask'].shape[0] > 1:
cond_dict['mask'] = cond_dict['mask'][batch_idx:batch_idx+1]
return positive_sliced, negative_sliced
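# Shape sketch (illustrative dims): with a batched text embedding of shape
# [B, 77, 768], slicing batch_idx=2 yields [1, 77, 768]; ControlNet
# cond_hint_original tensors and 'mask' entries are sliced the same way so the
# tile sampler sees only the guidance for its own image.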
def _process_and_blend_tile(self, tile_idx, tile_pos, upscaled_image, result_image,
model, positive, negative, vae, seed, steps, cfg,
sampler_name, scheduler, denoise, tile_width, tile_height,
padding, mask_blur, image_width, image_height, force_uniform_tiles,
tiled_decode, batch_idx: int = 0):
"""Process a single tile and blend it into the result image."""
x, y = tile_pos
# Extract and process tile
tile_tensor, x1, y1, ew, eh = self.extract_tile_with_padding(
upscaled_image, x, y, tile_width, tile_height, padding, force_uniform_tiles
)
processed_tile = self.process_tile(tile_tensor, model, positive, negative, vae,
seed, steps, cfg, sampler_name,
scheduler, denoise, tiled_decode, batch_idx=batch_idx,
region=(x1, y1, x1 + ew, y1 + eh), image_size=(image_width, image_height))
# Convert and blend
processed_pil = tensor_to_pil(processed_tile, 0)
# Create mask for this specific tile (no cache here; only used in single-tile path)
tile_mask = self.create_tile_mask(image_width, image_height, x, y, tile_width, tile_height, mask_blur)
# Use extraction position and size for blending
result_image = self.blend_tile(result_image, processed_pil,
x1, y1, (ew, eh), tile_mask, padding)
return result_image
def _process_single_tile(self, global_idx, num_tiles_per_image, upscaled_image, all_tiles,
model, positive, negative, vae, seed, steps, cfg, sampler_name,
scheduler, denoise, tiled_decode, tile_width, tile_height, padding,
width, height, force_uniform_tiles, sliced_conditioning_cache):
"""Process a single tile."""
# Calculate which image and tile this corresponds to
batch_idx = global_idx // num_tiles_per_image
tile_idx = global_idx % num_tiles_per_image
# Skip if batch_idx is out of range
if batch_idx >= upscaled_image.shape[0]:
debug_log(f"Warning: Calculated batch_idx {batch_idx} exceeds batch size {upscaled_image.shape[0]}")
return None
# Get or create sliced conditioning for this batch index
if batch_idx not in sliced_conditioning_cache:
positive_sliced, negative_sliced = self._slice_conditioning(positive, negative, batch_idx)
sliced_conditioning_cache[batch_idx] = (positive_sliced, negative_sliced)
else:
positive_sliced, negative_sliced = sliced_conditioning_cache[batch_idx]
x, y = all_tiles[tile_idx]
# Extract tile from the specific image in the batch
tile_tensor, x1, y1, ew, eh = self.extract_tile_with_padding(
upscaled_image[batch_idx:batch_idx+1], x, y, tile_width, tile_height, padding, force_uniform_tiles
)
# Process tile through SD with the exact seed (USDU parity)
image_seed = seed
processed_tile = self.process_tile(tile_tensor, model, positive_sliced, negative_sliced, vae,
image_seed, steps, cfg, sampler_name,
scheduler, denoise, tiled_decode, batch_idx=batch_idx,
region=(x1, y1, x1 + ew, y1 + eh), image_size=(width, height))
return {
'tile': processed_tile,
'global_idx': global_idx,
'batch_idx': batch_idx,
'tile_idx': tile_idx,
'x': x1,
'y': y1,
'extracted_width': ew,
'extracted_height': eh
}
================================================
FILE: upscale/worker_comms.py
================================================
import asyncio, io, json, time
import aiohttp
from PIL import Image
from ..utils.logging import debug_log, log
from ..utils.network import get_client_session
from ..utils.constants import TILE_SEND_TIMEOUT
from ..utils.usdu_managment import MAX_PAYLOAD_SIZE, _send_heartbeat_to_master
from ..utils.image import tensor_to_pil
class WorkerCommsMixin:
async def _send_heartbeat_to_master(self, multi_job_id, master_url, worker_id):
"""Proxy heartbeat helper used by worker processing mixins."""
await _send_heartbeat_to_master(multi_job_id, master_url, worker_id)
async def send_tiles_batch_to_master(self, processed_tiles, multi_job_id, master_url,
padding, worker_id, is_final_flush=False):
"""Send all processed tiles to master, chunked if large."""
if not processed_tiles:
if is_final_flush:
await self._send_tiles_completion_signal(multi_job_id, master_url, worker_id)
return # Early exit if empty
total_tiles = len(processed_tiles)
debug_log(f"Worker[{worker_id[:8]}] - Preparing to send {total_tiles} tiles (size-aware chunks)")
# Prepare encoded images and sizes to enable size-aware chunking
encoded = []
for idx, tile_data in enumerate(processed_tiles):
img = tensor_to_pil(tile_data['tile'], 0)
bio = io.BytesIO()
# compress_level=0 stores the PNG uncompressed: fastest encode, larger payloads
img.save(bio, format='PNG', compress_level=0)
raw = bio.getvalue()
encoded.append({
'bytes': raw,
'meta': {
'tile_idx': tile_data['tile_idx'],
'x': tile_data['x'],
'y': tile_data['y'],
'extracted_width': tile_data['extracted_width'],
'extracted_height': tile_data['extracted_height'],
**({'batch_idx': tile_data['batch_idx']} if 'batch_idx' in tile_data else {}),
**({'global_idx': tile_data['global_idx']} if 'global_idx' in tile_data else {}),
}
})
# Size-aware chunking
max_bytes = int(MAX_PAYLOAD_SIZE) - (1024 * 1024) # 1MB headroom
i = 0
chunk_index = 0
while i < total_tiles:
data = aiohttp.FormData()
data.add_field('multi_job_id', multi_job_id)
data.add_field('worker_id', str(worker_id))
data.add_field('padding', str(padding))
metadata = []
used = 0
j = i
while j < total_tiles:
img_bytes = encoded[j]['bytes']
meta = encoded[j]['meta']
# Rough overhead for fields + JSON
overhead = 1024
if used + len(img_bytes) + overhead > max_bytes and j > i:
break
# Accept this tile in this chunk
metadata.append(meta)
data.add_field(f'tile_{j - i}', io.BytesIO(img_bytes), filename=f'tile_{j}.png', content_type='image/png')
used += len(img_bytes) + overhead
j += 1
# Ensure at least one tile per chunk
if j == i:
# Single oversized tile, send anyway
meta = encoded[j]['meta']
metadata.append(meta)
data.add_field('tile_0', io.BytesIO(encoded[j]['bytes']), filename=f'tile_{j}.png', content_type='image/png')
j += 1
chunk_size = j - i
is_chunk_last = (j >= total_tiles)
data.add_field('is_last', str(bool(is_final_flush and is_chunk_last)))
data.add_field('batch_size', str(chunk_size))
data.add_field('tiles_metadata', json.dumps(metadata), content_type='application/json')
# Retry logic with exponential backoff
max_retries = 5
retry_delay = 0.5
for attempt in range(max_retries):
try:
session = await get_client_session()
url = f"{master_url}/distributed/submit_tiles"
async with session.post(url, data=data) as response:
response.raise_for_status()
break
except Exception as e:
if attempt < max_retries - 1:
await asyncio.sleep(retry_delay)
retry_delay = min(retry_delay * 2, 5.0)
else:
log(f"UltimateSDUpscale Worker - Failed to send chunk {chunk_index} after {max_retries} attempts: {e}")
raise
debug_log(f"Worker[{worker_id[:8]}] - Sent chunk {chunk_index} ({chunk_size} tiles, ~{used/1e6:.2f} MB)")
chunk_index += 1
i = j
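# Chunking invariants (summary; MAX_PAYLOAD_SIZE comes from utils.usdu_managment):
# every chunk stays under MAX_PAYLOAD_SIZE minus 1 MB of headroom, always carries
# at least one tile, and only the last chunk of the final flush sets is_last=True.
# Hypothetically, with a 50 MB limit and ~8 MB uncompressed PNG tiles, each chunk
# packs 6 tiles.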
async def _send_tiles_completion_signal(self, multi_job_id, master_url, worker_id):
"""Send completion signal to master in static mode when no tiles are left."""
data = aiohttp.FormData()
data.add_field('multi_job_id', multi_job_id)
data.add_field('worker_id', str(worker_id))
data.add_field('is_last', 'true')
data.add_field('batch_size', '0')
session = await get_client_session()
url = f"{master_url}/distributed/submit_tiles"
async with session.post(url, data=data) as response:
response.raise_for_status()
debug_log(f"Worker {worker_id} sent static completion signal")
async def _request_work_item_from_master(
self,
multi_job_id,
master_url,
worker_id,
endpoint="/distributed/request_image",
):
"""Request one work item from master with retry/backoff and total timeout."""
max_retries = 10
retry_delay = 0.5
start_time = time.monotonic()
url = f"{master_url}{endpoint}"
for attempt in range(max_retries):
if time.monotonic() - start_time > 30:
log(f"Total request timeout after 30s for worker {worker_id}")
return None
try:
session = await get_client_session()
async with session.post(url, json={
'worker_id': str(worker_id),
'multi_job_id': multi_job_id
}) as response:
if response.status == 200:
return await response.json()
if response.status == 404:
text = await response.text()
debug_log(f"Job not found (404), will retry: {text}")
await asyncio.sleep(1.0)
else:
text = await response.text()
debug_log(
f"Request work item failed ({response.status}) for worker {worker_id}: {text}"
)
except Exception as exc:
if attempt < max_retries - 1:
debug_log(f"Retry {attempt + 1}/{max_retries} after error: {exc}")
await asyncio.sleep(retry_delay)
retry_delay = min(retry_delay * 2, 5.0)
else:
log(f"Failed to request work item after {max_retries} attempts: {exc}")
raise
return None
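# Backoff schedule on transport errors (as coded above): sleeps of 0.5s, 1s, 2s,
# 4s, then capped at 5s per retry, under a hard 30s budget across all attempts.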
async def _request_image_from_master(self, multi_job_id, master_url, worker_id):
"""Request an image index to process from master in dynamic mode."""
data = await self._request_work_item_from_master(multi_job_id, master_url, worker_id)
if not data:
return None, 0
image_idx = data.get('image_idx')
estimated_remaining = data.get('estimated_remaining', 0)
return image_idx, estimated_remaining
async def _request_tile_from_master(self, multi_job_id, master_url, worker_id):
"""Request a tile index to process from master in static mode (reusing dynamic infrastructure)."""
data = await self._request_work_item_from_master(multi_job_id, master_url, worker_id)
if not data:
return None, 0, False
tile_idx = data.get('tile_idx')
estimated_remaining = data.get('estimated_remaining', 0)
batched_static = data.get('batched_static', False)
return tile_idx, estimated_remaining, batched_static
async def _send_full_image_to_master(self, image_pil, image_idx, multi_job_id,
master_url, worker_id, is_last):
"""Send a processed full image back to master in dynamic mode."""
# Serialize image to PNG
byte_io = io.BytesIO()
image_pil.save(byte_io, format='PNG', compress_level=0)
byte_io.seek(0)
# Prepare form data
data = aiohttp.FormData()
data.add_field('multi_job_id', multi_job_id)
data.add_field('worker_id', str(worker_id))
data.add_field('image_idx', str(image_idx))
data.add_field('is_last', str(is_last))
data.add_field('full_image', byte_io, filename=f'image_{image_idx}.png',
content_type='image/png')
# Retry logic
max_retries = 5
retry_delay = 0.5
for attempt in range(max_retries):
try:
session = await get_client_session()
url = f"{master_url}/distributed/submit_image"
async with session.post(url, data=data) as response:
response.raise_for_status()
debug_log(f"Successfully sent image {image_idx} to master")
return
except Exception as e:
if attempt < max_retries - 1:
debug_log(f"Retry {attempt + 1}/{max_retries} after error: {e}")
await asyncio.sleep(retry_delay)
retry_delay *= 2
else:
log(f"Failed to send image {image_idx} after {max_retries} attempts: {e}")
raise
async def _send_worker_complete_signal(self, multi_job_id, master_url, worker_id):
"""Send completion signal to master in dynamic mode."""
# Send a dummy request with is_last=True
data = aiohttp.FormData()
data.add_field('multi_job_id', multi_job_id)
data.add_field('worker_id', str(worker_id))
data.add_field('is_last', 'true')
# No image data - just completion signal
session = await get_client_session()
url = f"{master_url}/distributed/submit_image"
async with session.post(url, data=data) as response:
response.raise_for_status()
debug_log(f"Worker {worker_id} sent completion signal")
async def _check_job_status(self, multi_job_id, master_url):
"""Check if job is ready on the master."""
try:
session = await get_client_session()
url = f"{master_url}/distributed/job_status?multi_job_id={multi_job_id}"
async with session.get(url) as response:
if response.status == 200:
data = await response.json()
return data.get('ready', False)
return False
except Exception as e:
debug_log(f"Job status check failed: {e}")
return False
async def _async_yield(self):
"""Simple async yield to allow event loop processing."""
await asyncio.sleep(0)
================================================
FILE: utils/__init__.py
================================================
"""
Utility modules for ComfyUI-Distributed extension.
"""
# Make utils importable as a package
================================================
FILE: utils/async_helpers.py
================================================
"""
Async helper utilities for ComfyUI-Distributed.
"""
import asyncio
import threading
import time
import uuid
import execution
import server
from typing import Optional, Any, Coroutine
from .network import get_server_loop
def run_async_in_server_loop(coro: Coroutine, timeout: Optional[float] = None) -> Any:
"""
Run async coroutine in server's event loop and wait for result.
This is useful when you need to run async code from a synchronous context
but want to use the server's existing event loop instead of creating a new one.
Args:
coro: The coroutine to run
timeout: Optional timeout in seconds
Returns:
The result of the coroutine
Raises:
TimeoutError: If the operation times out
Exception: Any exception raised by the coroutine
"""
event = threading.Event()
result = None
error = None
async def wrapper():
nonlocal result, error
try:
result = await coro
except Exception as e:
error = e
finally:
event.set()
# Schedule on server's event loop
loop = get_server_loop()
asyncio.run_coroutine_threadsafe(wrapper(), loop)
# Wait for completion
if not event.wait(timeout):
raise TimeoutError(f"Async operation timed out after {timeout} seconds")
if error:
raise error
return result
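# Usage sketch (hypothetical caller): run a coroutine from a synchronous node
# body without creating a private event loop. `probe_worker` here is just an
# illustration of any awaitable:
#     from .network import probe_worker
#     info = run_async_in_server_loop(probe_worker("http://127.0.0.1:8189"), timeout=10)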
prompt_server = server.PromptServer.instance
def _summarize_node_errors(node_errors: dict) -> str:
if not isinstance(node_errors, dict) or not node_errors:
return ""
parts = []
for node_id, entry in node_errors.items():
if not isinstance(entry, dict):
continue
class_type = str(entry.get("class_type") or "UnknownNode")
for err in entry.get("errors", []):
if not isinstance(err, dict):
continue
message = str(err.get("message") or "validation error")
details = str(err.get("details") or "").strip()
parts.append(
f"{class_type}#{node_id}: {message}{f' ({details})' if details else ''}"
)
if len(parts) >= 5:
return " | ".join(parts)
return " | ".join(parts)
class PromptValidationError(RuntimeError):
"""Raised when a prompt fails ComfyUI validation with structured context."""
def __init__(self, error_payload, node_errors=None):
payload = error_payload if isinstance(error_payload, dict) else {
"type": "prompt_validation_failed",
"message": str(error_payload),
"details": "",
"extra_info": {},
}
self.validation_error = dict(payload)
self.node_errors = node_errors if isinstance(node_errors, dict) else {}
if self.node_errors:
details = str(self.validation_error.get("details") or "").strip()
if not details:
summary = _summarize_node_errors(self.node_errors)
if summary:
self.validation_error["details"] = summary
merged = dict(self.validation_error)
if self.node_errors:
merged["node_errors"] = self.node_errors
super().__init__(f"Invalid prompt: {merged}")
async def queue_prompt_payload(
prompt_obj,
workflow_meta=None,
client_id=None,
include_queue_metadata=False,
):
"""Validate and queue a prompt via ComfyUI's prompt queue."""
payload = {"prompt": prompt_obj}
payload = prompt_server.trigger_on_prompt(payload)
prompt = payload["prompt"]
prompt_id = str(uuid.uuid4())
valid = await execution.validate_prompt(prompt_id, prompt, None)
if not valid[0]:
error_payload = valid[1] if len(valid) > 1 else "Prompt outputs failed validation"
node_errors = valid[3] if len(valid) > 3 else {}
raise PromptValidationError(error_payload, node_errors)
extra_data = {"create_time": int(time.time() * 1000)}
if workflow_meta:
extra_data.setdefault("extra_pnginfo", {})["workflow"] = workflow_meta
if client_id:
extra_data["client_id"] = client_id
sensitive = {}
for key in getattr(execution, "SENSITIVE_EXTRA_DATA_KEYS", []):
if key in extra_data:
sensitive[key] = extra_data.pop(key)
number = getattr(prompt_server, "number", 0)
prompt_server.number = number + 1
prompt_queue_item = (number, prompt_id, prompt, extra_data, valid[2], sensitive)
prompt_server.prompt_queue.put(prompt_queue_item)
if include_queue_metadata:
return {
"prompt_id": prompt_id,
"number": number,
"node_errors": {},
}
return prompt_id
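# Usage sketch (hypothetical caller already on the server loop; assumes a
# `log` helper is in scope):
#     try:
#         prompt_id = await queue_prompt_payload(prompt_obj, client_id="distributed")
#     except PromptValidationError as exc:
#         log(f"Prompt rejected: {exc.validation_error.get('message')}")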
================================================
FILE: utils/audio_payload.py
================================================
import base64
import binascii
import os
import numpy as np
import torch
from .image import ensure_contiguous
MAX_AUDIO_PAYLOAD_BYTES = int(
os.environ.get("COMFYUI_MAX_AUDIO_PAYLOAD_BYTES", str(256 * 1024 * 1024))
)
def encode_audio_payload(audio_payload):
"""Serialize an AUDIO dict into JSON-safe canonical envelope payload."""
if not isinstance(audio_payload, dict):
return None
waveform = audio_payload.get("waveform")
if waveform is None or not isinstance(waveform, torch.Tensor) or waveform.numel() == 0:
return None
sample_rate = audio_payload.get("sample_rate", 44100)
try:
sample_rate = int(sample_rate)
except (TypeError, ValueError):
sample_rate = 44100
waveform_cpu = waveform.detach().to(device="cpu", dtype=torch.float32).contiguous()
data_bytes = waveform_cpu.numpy().tobytes()
if len(data_bytes) > MAX_AUDIO_PAYLOAD_BYTES:
raise ValueError(
f"Audio payload too large: {len(data_bytes)} bytes exceeds {MAX_AUDIO_PAYLOAD_BYTES}."
)
return {
"sample_rate": sample_rate,
"shape": [int(dim) for dim in waveform_cpu.shape],
"dtype": "float32",
"data": base64.b64encode(data_bytes).decode("ascii"),
}
def decode_audio_payload(audio_payload):
"""Decode canonical envelope audio payload into an AUDIO dict."""
if audio_payload is None:
return None
if not isinstance(audio_payload, dict):
raise ValueError("Field 'audio' must be an object when provided.")
encoded = audio_payload.get("data")
shape = audio_payload.get("shape")
sample_rate = audio_payload.get("sample_rate", 44100)
dtype = audio_payload.get("dtype", "float32")
if not isinstance(encoded, str) or not encoded.strip():
raise ValueError("Field 'audio.data' must be a non-empty base64 string.")
if not isinstance(shape, list) or len(shape) != 3:
raise ValueError("Field 'audio.shape' must be a 3-item list [batch, channels, samples].")
if dtype != "float32":
raise ValueError("Field 'audio.dtype' must be 'float32'.")
try:
shape_tuple = tuple(int(dim) for dim in shape)
except (TypeError, ValueError) as exc:
raise ValueError("Field 'audio.shape' must contain integers.") from exc
if shape_tuple[0] <= 0 or shape_tuple[1] <= 0 or shape_tuple[2] < 0:
raise ValueError(
"Field 'audio.shape' must be [batch>0, channels>0, samples>=0]."
)
try:
sample_rate = int(sample_rate)
except (TypeError, ValueError) as exc:
raise ValueError("Field 'audio.sample_rate' must be an integer.") from exc
if sample_rate <= 0:
raise ValueError("Field 'audio.sample_rate' must be positive.")
try:
raw = base64.b64decode(encoded, validate=True)
except (binascii.Error, ValueError) as exc:
raise ValueError("Field 'audio.data' is not valid base64.") from exc
if len(raw) > MAX_AUDIO_PAYLOAD_BYTES:
raise ValueError(
f"Field 'audio.data' too large: {len(raw)} bytes exceeds {MAX_AUDIO_PAYLOAD_BYTES}."
)
expected_bytes = int(np.prod(shape_tuple, dtype=np.int64)) * 4
if len(raw) != expected_bytes:
raise ValueError(
f"Field 'audio.data' byte size mismatch: expected {expected_bytes}, got {len(raw)}."
)
array = np.frombuffer(raw, dtype=np.float32).reshape(shape_tuple)
waveform = torch.from_numpy(array.copy())
return {
"waveform": ensure_contiguous(waveform),
"sample_rate": sample_rate,
}
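# Round-trip sketch: encode_audio_payload and decode_audio_payload are inverses
# for float32 AUDIO dicts (values illustrative):
#     audio = {"waveform": torch.zeros(1, 2, 48000), "sample_rate": 48000}
#     envelope = encode_audio_payload(audio)   # JSON-safe dict, base64 data
#     restored = decode_audio_payload(envelope)
#     assert torch.equal(restored["waveform"], audio["waveform"])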
================================================
FILE: utils/cloudflare/__init__.py
================================================
from .tunnel import CloudflareTunnelManager
cloudflare_tunnel_manager = CloudflareTunnelManager()
__all__ = ["CloudflareTunnelManager", "cloudflare_tunnel_manager"]
================================================
FILE: utils/cloudflare/binary.py
================================================
"""Cloudflared binary discovery and download helpers."""
import os
import platform
import shutil
import stat
from urllib import error as urlerror
from urllib import request
from ..logging import debug_log
def _get_project_root():
return os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
def _get_cloudflared_dir():
return os.path.join(_get_project_root(), "bin")
def _get_platform_binary_name():
system = platform.system().lower()
machine = platform.machine().lower()
if system == "windows":
if "arm" in machine:
return "cloudflared-windows-arm64.exe"
return "cloudflared-windows-amd64.exe"
if system == "darwin":
if machine in ("arm64", "aarch64"):
return "cloudflared-darwin-arm64"
return "cloudflared-darwin-amd64"
if system == "linux":
if machine in ("arm64", "aarch64"):
return "cloudflared-linux-arm64"
return "cloudflared-linux-amd64"
raise RuntimeError(f"Unsupported platform for cloudflared: {system}/{machine}")
def _get_binary_path(bin_dir=None):
bin_dir = bin_dir or _get_cloudflared_dir()
binary_name = "cloudflared.exe" if platform.system().lower() == "windows" else "cloudflared"
return os.path.join(bin_dir, binary_name)
def _download_cloudflared():
asset = _get_platform_binary_name()
url = f"https://github.com/cloudflare/cloudflared/releases/latest/download/{asset}"
bin_dir = _get_cloudflared_dir()
os.makedirs(bin_dir, exist_ok=True)
target_path = _get_binary_path(bin_dir)
debug_log(f"Downloading cloudflared from {url}")
try:
with request.urlopen(url, timeout=30) as resp:
with open(target_path, "wb") as f:
shutil.copyfileobj(resp, f)
except urlerror.URLError as exc:
raise RuntimeError(f"Failed to download cloudflared: {exc}") from exc
st = os.stat(target_path)
os.chmod(target_path, st.st_mode | stat.S_IEXEC)
debug_log(f"Downloaded cloudflared to {target_path}")
return target_path
def ensure_binary() -> str:
"""Return a usable cloudflared binary path, downloading if necessary."""
env_path = os.environ.get("CLOUDFLARED_PATH")
if env_path and os.path.exists(env_path):
return env_path
local_candidate = _get_binary_path()
if os.path.exists(local_candidate):
return local_candidate
path_binary = shutil.which("cloudflared")
if path_binary:
return path_binary
return _download_cloudflared()
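# Discovery order: CLOUDFLARED_PATH env var, then a previously downloaded
# binary under ./bin, then PATH, then a fresh download. To pin a specific
# binary (illustrative path):
#     CLOUDFLARED_PATH=/usr/local/bin/cloudflared python main.py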
================================================
FILE: utils/cloudflare/process_reader.py
================================================
"""Background cloudflared process output reader."""
import asyncio
import re
import threading
from ..constants import CLOUDFLARE_LOG_BUFFER_SIZE
from ..logging import debug_log
PUBLIC_URL_PATTERN = re.compile(
r"(https?://[\w.-]+\.(?:trycloudflare\.com|cloudflare\.dev))",
re.IGNORECASE,
)
class ProcessReader:
def __init__(self, log_file=None):
self._process = None
self._thread = None
self._loop = None
self._url_event = None
self._public_url = None
self._last_error = None
self._recent_logs = []
self._log_file = log_file
def set_log_file(self, log_file):
self._log_file = log_file
def _append_log(self, line):
if self._log_file:
try:
with open(self._log_file, "a", encoding="utf-8", errors="replace") as f:
f.write(line + "\n")
except Exception as exc: # pragma: no cover
debug_log(f"Failed to write tunnel log: {exc}")
self._recent_logs.append(line)
if len(self._recent_logs) > CLOUDFLARE_LOG_BUFFER_SIZE:
self._recent_logs = self._recent_logs[-CLOUDFLARE_LOG_BUFFER_SIZE:]
def _reader(self):
process = self._process
if process is None:
return
loop = self._loop
for raw_line in iter(process.stdout.readline, ""):
line = raw_line.strip()
if not line:
continue
self._append_log(line)
match = PUBLIC_URL_PATTERN.search(line)
if match and not self._public_url:
self._public_url = match.group(1).rstrip("/")
if self._url_event and loop:
loop.call_soon_threadsafe(self._url_event.set)
if "error" in line.lower() and not self._last_error:
self._last_error = line
if self._url_event and loop:
if not self._last_error and not self._public_url:
self._last_error = "Cloudflare tunnel exited before becoming ready"
loop.call_soon_threadsafe(self._url_event.set)
def start(self, process, loop):
self._process = process
self._loop = loop
self._url_event = asyncio.Event()
self._public_url = None
self._last_error = None
self._recent_logs = []
self._thread = threading.Thread(target=self._reader, daemon=True)
self._thread.start()
async def wait_for_url(self, timeout):
if not self._url_event:
return None
await asyncio.wait_for(self._url_event.wait(), timeout=timeout)
return self._public_url
def stop(self):
if self._thread and self._thread.is_alive():
self._thread.join(timeout=1)
self._thread = None
self._process = None
self._loop = None
self._url_event = None
def get_url(self):
return self._public_url
def get_last_error(self):
return self._last_error
def get_recent_logs(self):
return list(self._recent_logs)
================================================
FILE: utils/cloudflare/state.py
================================================
"""Cloudflare tunnel state persistence helpers."""
from ..config import load_config, save_config
from ..network import normalize_host
def _get_tunnel_config(cfg):
tunnel_cfg = cfg.get("tunnel", {})
if isinstance(tunnel_cfg, dict):
return tunnel_cfg
return {}
def load_tunnel_state():
cfg = load_config()
tunnel_cfg = _get_tunnel_config(cfg)
master_cfg = cfg.get("master", {}) if isinstance(cfg.get("master", {}), dict) else {}
return {
"status": tunnel_cfg.get("status", "stopped"),
"public_url": tunnel_cfg.get("public_url") or None,
"pid": tunnel_cfg.get("pid"),
"log_file": tunnel_cfg.get("log_file"),
"previous_master_host": tunnel_cfg.get("previous_master_host"),
"master_host": master_cfg.get("host"),
}
def persist_tunnel_state(
status=None,
public_url=None,
pid=None,
log_file=None,
previous_host=None,
master_host=None,
):
cfg = load_config()
tunnel_cfg = _get_tunnel_config(cfg)
if status is not None:
tunnel_cfg["status"] = status
if public_url is not None:
tunnel_cfg["public_url"] = public_url
if pid is not None:
tunnel_cfg["pid"] = pid
if log_file is not None:
tunnel_cfg["log_file"] = log_file
if previous_host is not None:
tunnel_cfg["previous_master_host"] = previous_host
if master_host is not None:
cfg.setdefault("master", {})["host"] = master_host
cfg["tunnel"] = tunnel_cfg
save_config(cfg)
def clear_tunnel_state(log_file=None, previous_host=None, master_host=None):
persist_tunnel_state(
status="stopped",
public_url="",
pid=None,
log_file=log_file,
previous_host=previous_host,
master_host=master_host,
)
def resolve_restore_master_host(previous_master_host):
"""Determine whether master host should be restored after tunnel stop."""
cfg = load_config()
tunnel_cfg = _get_tunnel_config(cfg)
active_url = tunnel_cfg.get("public_url")
current_master_host = (cfg.get("master") or {}).get("host")
if not active_url:
return None
active_host = normalize_host(active_url)
current_host = normalize_host(current_master_host)
if current_host == active_host:
return previous_master_host or ""
return None
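# Example: with an active tunnel at https://abc.trycloudflare.com, the master
# host is restored on stop only while it still points at
# "abc.trycloudflare.com"; a host the user changed in the meantime is left
# untouched (the function returns None).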
================================================
FILE: utils/cloudflare/tunnel.py
================================================
"""Cloudflare tunnel lifecycle manager."""
import asyncio
import os
import shutil
import signal
import subprocess
import time
from ..constants import TUNNEL_START_TIMEOUT
from ..logging import debug_log
from ..network import get_server_port, normalize_host
from ..process import is_process_alive, terminate_process
from .binary import ensure_binary
from .process_reader import ProcessReader
from .state import clear_tunnel_state, load_tunnel_state, persist_tunnel_state, resolve_restore_master_host
class CloudflareTunnelManager:
def __init__(self):
self.process = None
self.pid = None
self.public_url = None
self.last_error = None
self.log_file = None
self.status = "stopped"
self.previous_master_host = None
self._lock = asyncio.Lock()
self._reader = ProcessReader()
self.binary_path = None
self._restore_state()
@property
def base_dir(self):
return os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
def _restore_state(self):
state = load_tunnel_state()
self.public_url = state.get("public_url") or None
self.previous_master_host = state.get("previous_master_host")
self.log_file = state.get("log_file")
pid = state.get("pid")
if pid and is_process_alive(pid):
self.pid = pid
self.status = state.get("status") or "running"
debug_log(f"Detected existing cloudflared process (pid={pid})")
else:
clear_tunnel_state(log_file=self.log_file, previous_host=self.previous_master_host)
self.status = "stopped"
self.pid = None
async def start_tunnel(self):
async with self._lock:
if self.process and self.process.poll() is None:
return {
"status": self.status,
"public_url": self.public_url,
"pid": self.process.pid,
"log_file": self.log_file,
}
if self.pid and is_process_alive(self.pid):
debug_log(f"Stopping stale cloudflared pid {self.pid} before starting a new one")
                await self._stop_tunnel_locked()
binary = await asyncio.to_thread(ensure_binary)
self.binary_path = binary
port = get_server_port()
self.status = "starting"
self.last_error = None
self.public_url = None
state = load_tunnel_state()
master_host = state.get("master_host") or ""
if state.get("previous_master_host"):
self.previous_master_host = state.get("previous_master_host")
else:
self.previous_master_host = master_host
os.makedirs(os.path.join(self.base_dir, "logs"), exist_ok=True)
timestamp = time.strftime("%Y%m%d-%H%M%S")
self.log_file = os.path.join(self.base_dir, "logs", f"cloudflare-{timestamp}.log")
cmd = [
binary,
"tunnel",
"--no-autoupdate",
"--url",
f"http://127.0.0.1:{port}",
]
debug_log(f"Starting cloudflared: {' '.join(cmd)}")
try:
self.process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
)
except FileNotFoundError:
self.status = "error"
raise RuntimeError("cloudflared binary not found")
except Exception as exc:
self.status = "error"
raise RuntimeError(f"Failed to start cloudflared: {exc}") from exc
self.pid = self.process.pid
persist_tunnel_state(
status="starting",
pid=self.pid,
log_file=self.log_file,
previous_host=self.previous_master_host,
)
loop = asyncio.get_running_loop()
self._reader.set_log_file(self.log_file)
self._reader.start(self.process, loop)
try:
await self._reader.wait_for_url(timeout=TUNNEL_START_TIMEOUT)
except asyncio.TimeoutError:
self.last_error = "Timed out waiting for Cloudflare to assign a URL"
                await self._stop_tunnel_locked()
raise RuntimeError(self.last_error)
public_url = self._reader.get_url()
if not public_url:
self.last_error = self._reader.get_last_error() or "Cloudflare tunnel failed to start"
                await self._stop_tunnel_locked()
raise RuntimeError(self.last_error)
self.public_url = public_url
self.status = "running"
debug_log(f"Cloudflare tunnel ready at {self.public_url}")
persist_tunnel_state(
status="running",
public_url=self.public_url,
pid=self.pid,
log_file=self.log_file,
previous_host=self.previous_master_host or "",
master_host=normalize_host(self.public_url),
)
return {
"status": self.status,
"public_url": self.public_url,
"pid": self.pid,
"log_file": self.log_file,
}
    async def stop_tunnel(self):
        async with self._lock:
            return await self._stop_tunnel_locked()
    async def _stop_tunnel_locked(self):
        # Lock-free variant for callers (e.g. start_tunnel) that already hold
        # self._lock; asyncio.Lock is not reentrant, so awaiting stop_tunnel
        # from inside the lock would deadlock.
        pid = self.process.pid if self.process else self.pid
        if not pid:
            clear_tunnel_state(log_file=self.log_file, previous_host=self.previous_master_host)
            self.status = "stopped"
            return {"status": "stopped"}
        debug_log(f"Stopping cloudflared (pid={pid})")
        if self.process:
            terminate_process(self.process, timeout=5)
        else:
            try:
                os.kill(pid, signal.SIGTERM)
                time.sleep(0.5)
            except Exception as exc:  # pragma: no cover
                debug_log(f"Error stopping cloudflared pid {pid}: {exc}")
        restore_host = resolve_restore_master_host(self.previous_master_host)
        self.status = "stopped"
        self.public_url = None
        self.pid = None
        self.process = None
        self.last_error = None
        self._reader.stop()
        clear_tunnel_state(
            log_file=self.log_file,
            previous_host=self.previous_master_host,
            master_host=restore_host,
        )
        return {"status": "stopped"}
def get_status(self):
alive = False
pid = self.process.pid if self.process else self.pid
if pid:
alive = is_process_alive(pid)
if not alive and self.status == "running":
self.status = "stopped"
return {
"status": self.status,
"public_url": self.public_url,
"pid": pid,
"log_file": self.log_file,
"last_error": self.last_error or self._reader.get_last_error(),
"binary_path": self.binary_path or shutil.which("cloudflared"),
"recent_logs": self._reader.get_recent_logs()[-20:],
"previous_master_host": self.previous_master_host,
}
================================================
FILE: utils/config.py
================================================
"""
Configuration management for ComfyUI-Distributed.
"""
import asyncio
import os
import json
from contextlib import asynccontextmanager
from .logging import log
# Import defaults for timeout fallbacks
from .constants import HEARTBEAT_TIMEOUT
CONFIG_FILE = os.path.join(os.path.dirname(os.path.dirname(__file__)), "gpu_config.json")
_config_cache = None
_config_mtime = 0.0
_config_lock = asyncio.Lock()
def _config_path():
return CONFIG_FILE
def get_default_config():
"""Returns the default configuration dictionary. Single source of truth."""
return {
"master": {"host": ""},
"workers": [],
"settings": {
"debug": False,
"auto_launch_workers": False,
"stop_workers_on_master_exit": True,
"master_delegate_only": False,
"websocket_orchestration": True,
"worker_probe_concurrency": 8,
"worker_prep_concurrency": 4,
"media_sync_concurrency": 2,
"media_sync_timeout_seconds": 120
},
"tunnel": {
"status": "stopped",
"public_url": "",
"pid": None,
"log_file": "",
"previous_master_host": ""
}
}
def _merge_with_defaults(data, defaults):
"""Recursively merge loaded config data with default keys."""
if not isinstance(data, dict):
return defaults
merged = {}
for key, default_value in defaults.items():
loaded_value = data.get(key, default_value)
if isinstance(default_value, dict) and isinstance(loaded_value, dict):
merged[key] = _merge_with_defaults(loaded_value, default_value)
else:
merged[key] = loaded_value
# Preserve unknown keys for forward compatibility.
for key, value in data.items():
if key not in merged:
merged[key] = value
return merged
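# Merge sketch (illustrative): loaded values win, defaults fill gaps, and
# unknown keys survive:
#     _merge_with_defaults({"settings": {"debug": True}, "extra": 1},
#                          get_default_config())
#     # -> settings.debug is True, all other default keys are present,
#     #    and "extra" is preserved.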
def invalidate_config_cache():
"""Invalidate in-memory config cache so next load reads from disk."""
global _config_cache, _config_mtime
_config_cache = None
_config_mtime = 0.0
def load_config():
"""Loads the config, falling back to defaults if the file is missing or invalid."""
global _config_cache, _config_mtime
path = _config_path()
try:
mtime = os.path.getmtime(path)
except OSError:
if _config_cache is None:
_config_cache = get_default_config()
return _config_cache
if _config_cache is None or mtime != _config_mtime:
try:
with open(path, 'r', encoding='utf-8') as f:
loaded = json.load(f)
_config_cache = _merge_with_defaults(loaded, get_default_config())
except Exception as e:
log(f"Error loading config, using defaults: {e}")
_config_cache = get_default_config()
_config_mtime = mtime
return _config_cache
def save_config(config):
"""Saves the configuration to file."""
tmp_path = f"{_config_path()}.tmp"
try:
with open(tmp_path, 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, _config_path())
invalidate_config_cache()
return True
except Exception as e:
try:
os.unlink(tmp_path)
except OSError:
pass
log(f"Error saving config: {e}")
return False
@asynccontextmanager
async def config_transaction():
"""Acquire config lock, yield loaded config, and save if changed."""
async with _config_lock:
config = load_config()
original_snapshot = json.dumps(config, sort_keys=True)
yield config
updated_snapshot = json.dumps(config, sort_keys=True)
if updated_snapshot != original_snapshot:
if not save_config(config):
raise RuntimeError("Failed to save config")
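# Usage sketch: mutate the config under the lock; it is saved back only when
# the JSON snapshot actually changed:
#     async with config_transaction() as cfg:
#         cfg.setdefault("settings", {})["debug"] = True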
def ensure_config_exists():
"""Creates default config file if it doesn't exist. Used by __init__.py"""
if not os.path.exists(_config_path()):
default_config = get_default_config()
if save_config(default_config):
from .logging import debug_log
debug_log("Created default config file")
else:
log("Could not create default config file")
def get_worker_timeout_seconds(default: int = HEARTBEAT_TIMEOUT) -> int:
"""Return the unified worker timeout (seconds).
Priority:
1) UI-configured setting `settings.worker_timeout_seconds`
2) Fallback to provided `default` (defaults to HEARTBEAT_TIMEOUT which itself
can be overridden via the COMFYUI_HEARTBEAT_TIMEOUT env var)
This value should be used anywhere we consider a worker "timed out" from the
master's perspective (e.g., collector waits, upscaler result collection).
"""
try:
cfg = load_config()
val = int(cfg.get('settings', {}).get('worker_timeout_seconds', default))
return max(1, val)
except Exception:
return max(1, int(default))
def is_master_delegate_only() -> bool:
"""Returns True when master should skip local workload and act as orchestrator only."""
try:
cfg = load_config()
return bool(cfg.get('settings', {}).get('master_delegate_only', False))
except Exception:
return False
================================================
FILE: utils/constants.py
================================================
"""
Shared constants for ComfyUI-Distributed.
"""
import os
# Timeouts (in seconds)
WORKER_JOB_TIMEOUT = 30.0
TILE_COLLECTION_TIMEOUT = 30.0
TILE_WAIT_TIMEOUT = 30.0
PROCESS_TERMINATION_TIMEOUT = 5.0
# Process monitoring
WORKER_CHECK_INTERVAL = 2.0
STATUS_CHECK_INTERVAL = 5.0
# Cloudflare tunnel
TUNNEL_START_TIMEOUT = float(os.environ.get("TUNNEL_START_TIMEOUT", "25"))
CLOUDFLARE_LOG_BUFFER_SIZE = 200
# Network
CHUNK_SIZE = 8192
LOG_TAIL_BYTES = 65536 # 64KB
# File paths
WORKER_LOG_PATTERN = "distributed_worker_*.log"
# Worker management
WORKER_STARTUP_DELAY = 2.0
# Tile transfer
TILE_TRANSFER_TIMEOUT = 30.0
# Process cleanup
PROCESS_WAIT_TIMEOUT = 3.0
QUEUE_INIT_TIMEOUT = 5.0
TILE_SEND_TIMEOUT = 60.0
JOB_INIT_GRACE_PERIOD = 10.0
# Memory operations
MEMORY_CLEAR_DELAY = 0.5
# Batch processing
MAX_BATCH = int(os.environ.get('COMFYUI_MAX_BATCH', '20')) # Maximum items per batch to prevent timeouts/OOM (~100MB chunks for 512x512 PNGs)
# Heartbeat monitoring
HEARTBEAT_INTERVAL = float(os.environ.get('COMFYUI_HEARTBEAT_INTERVAL', '10')) # Heartbeat/check interval in seconds
HEARTBEAT_TIMEOUT = int(os.environ.get('COMFYUI_HEARTBEAT_TIMEOUT', '60')) # Worker heartbeat timeout in seconds (default 60s)
# USDU result collection
DYNAMIC_MODE_MAX_POLL_TIMEOUT = 10.0
# Static mode job poll loop
JOB_POLL_INTERVAL = 1.0
JOB_POLL_MAX_ATTEMPTS = 20
# Orchestration pipeline
ORCHESTRATION_WORKER_PROBE_CONCURRENCY = int(
os.environ.get('COMFYUI_ORCHESTRATION_WORKER_PROBE_CONCURRENCY', '8')
)
ORCHESTRATION_WORKER_PREP_CONCURRENCY = int(
os.environ.get('COMFYUI_ORCHESTRATION_WORKER_PREP_CONCURRENCY', '4')
)
ORCHESTRATION_MEDIA_SYNC_CONCURRENCY = int(
os.environ.get('COMFYUI_ORCHESTRATION_MEDIA_SYNC_CONCURRENCY', '2')
)
ORCHESTRATION_MEDIA_SYNC_TIMEOUT = float(
os.environ.get('COMFYUI_ORCHESTRATION_MEDIA_SYNC_TIMEOUT', '120')
)
================================================
FILE: utils/crop_model_patch.py
================================================
from contextlib import contextmanager
import torch
from .logging import debug_log
from .usdu_utils import resize_region
@contextmanager
def crop_model_cond(model, crop_regions, init_size, canvas_size, tile_size, latent_crop=False):
"""Clone model and crop compatible model patches for tile-local sampling."""
try:
patched_model = model.clone()
except Exception:
# Fallback to original model when clone/patch access is unavailable.
yield model
return
patches = (
patched_model
.model_options
.get("transformer_options", {})
.get("patches", {})
)
applied_croppers = {}
for _module, module_patches in patches.items():
for patch in module_patches:
if id(patch) in applied_croppers:
continue
if type(patch).__name__ not in ("DiffSynthCnetPatch", "ZImageControlPatch"):
continue
try:
cropper = ModelPatchCropper(patch).crop(crop_regions, canvas_size, latent_crop)
applied_croppers[id(patch)] = cropper
except Exception as exc:
debug_log(f"crop_model_cond: patch crop skipped for {type(patch).__name__}: {exc}")
try:
yield patched_model
finally:
        # Drop all references so each ModelPatchCropper.__del__ runs promptly
        # and restores the original patch tensors (the old per-item `del` only
        # unbound the loop variable, leaving the dict's references alive).
        applied_croppers.clear()
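# Usage sketch (hypothetical sampler call): sample one tile with compatible
# model patches cropped to the tile region, then let the context manager
# restore the original tensors:
#     with crop_model_cond(model, tile_regions, init_size, canvas_size,
#                          tile_size) as tile_model:
#         samples = sample_fn(tile_model, ...)  # any KSampler-style callable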
class ModelPatchCropper:
"""Stateful crop helper that restores model patch tensors on cleanup."""
def __init__(self, patch):
self.patch = patch
self.original_state = {
"image": patch.image.clone() if isinstance(patch.image, torch.Tensor) else patch.image,
"encoded_image": patch.encoded_image.clone() if isinstance(patch.encoded_image, torch.Tensor) else patch.encoded_image,
"encoded_image_size": patch.encoded_image_size,
}
self.patch_class = type(patch).__name__
required_attrs = (
"image",
"model_patch",
"vae",
"strength",
"encoded_image",
"encoded_image_size",
)
missing = [attr for attr in required_attrs if not hasattr(patch, attr)]
if missing:
raise AttributeError(
f"{self.patch_class} missing required attrs: {', '.join(missing)}"
)
def __del__(self):
self.patch.image = self.original_state["image"]
self.patch.encoded_image = self.original_state["encoded_image"]
self.patch.encoded_image_size = self.original_state["encoded_image_size"]
def crop(self, crop_regions, canvas_size, latent_crop=True):
patch = self.patch
if not isinstance(crop_regions, list):
crop_regions = [crop_regions]
image_size = (patch.image.shape[2], patch.image.shape[1]) # (W,H)
cropped_images = []
for crop_region in crop_regions:
resized_crop = resize_region(crop_region, canvas_size, image_size)
x1, y1, x2, y2 = resized_crop
cropped_image = patch.image[:, y1:y2, x1:x2, :]
cropped_images.append(cropped_image)
concatenated_image = torch.cat(cropped_images, dim=0)
patch.image = concatenated_image
patch.encoded_image_size = (
concatenated_image.shape[1],
concatenated_image.shape[2],
)
if latent_crop:
downscale_ratio = patch.vae.spacial_compression_encode()
cropped_latents = []
for crop_region in crop_regions:
resized_crop = resize_region(crop_region, canvas_size, image_size)
x1, y1, x2, y2 = tuple(x // downscale_ratio for x in resized_crop)
cropped_latent = patch.encoded_image[:, :, y1:y2, x1:x2]
cropped_latents.append(cropped_latent)
patch.encoded_image = torch.cat(cropped_latents, dim=0)
else:
patch.__init__(
patch.model_patch,
patch.vae,
concatenated_image,
patch.strength,
inpaint_image=patch.inpaint_image,
mask=patch.mask,
)
return self
================================================
FILE: utils/exceptions.py
================================================
"""Custom exceptions for ComfyUI-Distributed."""
class DistributedError(Exception):
"""Base exception for all ComfyUI-Distributed errors."""
class WorkerError(DistributedError):
"""Error related to a specific distributed worker."""
def __init__(self, message, worker_id=None):
super().__init__(message)
self.worker_id = worker_id
class WorkerTimeoutError(WorkerError):
"""Worker did not respond within the expected timeout."""
class WorkerNotAvailableError(WorkerError):
"""Worker is unreachable or not running."""
class JobQueueError(DistributedError):
"""Error in distributed job queue management."""
class TileCollectionError(DistributedError):
"""Error collecting processed tiles from workers."""
class ProcessError(DistributedError):
"""Error managing a worker subprocess."""
def __init__(self, message, pid=None, worker_id=None):
super().__init__(message)
self.pid = pid
self.worker_id = worker_id
class TunnelError(DistributedError):
"""Error managing the Cloudflare tunnel."""
================================================
FILE: utils/image.py
================================================
"""
Image and tensor conversion utilities for ComfyUI-Distributed.
"""
import torch
import numpy as np
from PIL import Image
def tensor_to_pil(img_tensor, batch_index=0):
"""Takes a batch of images in tensor form [B, H, W, C] and returns an RGB PIL Image."""
return Image.fromarray((255 * img_tensor[batch_index].cpu().numpy()).astype(np.uint8))
def pil_to_tensor(image):
"""Takes a PIL image and returns a tensor of shape [1, H, W, C]."""
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image).unsqueeze(0)
if len(image.shape) == 3: # If grayscale, add channel dimension
image = image.unsqueeze(-1)
return image
def ensure_contiguous(tensor):
"""Ensure tensor is contiguous in memory."""
if not tensor.is_contiguous():
return tensor.contiguous()
return tensor
================================================
FILE: utils/logging.py
================================================
"""
Shared logging utilities for ComfyUI-Distributed.
"""
import os
import json
import time
# Config file is in parent directory
CONFIG_FILE = os.path.join(os.path.dirname(os.path.dirname(__file__)), "gpu_config.json")
_debug_cache: bool | None = None
_debug_cache_time: float = 0.0
_DEBUG_TTL: float = 5.0
def is_debug_enabled():
"""Check if debug is enabled."""
global _debug_cache, _debug_cache_time
now = time.monotonic()
if _debug_cache is not None and (now - _debug_cache_time) < _DEBUG_TTL:
return _debug_cache
enabled = False
if os.path.exists(CONFIG_FILE):
try:
with open(CONFIG_FILE, 'r') as f:
config = json.load(f)
enabled = config.get("settings", {}).get("debug", False)
except (OSError, json.JSONDecodeError, ValueError):
pass
_debug_cache = enabled
_debug_cache_time = now
return enabled
def debug_log(message):
"""Log debug messages only if debug is enabled in config."""
if is_debug_enabled():
print(f"[Distributed] {message}")
def log(message):
"""Always log important messages."""
print(f"[Distributed] {message}")
================================================
FILE: utils/network.py
================================================
"""
Network and API utilities for ComfyUI-Distributed.
"""
import asyncio
import aiohttp
import re
import server
from aiohttp import web
from .logging import debug_log
# Shared session for connection pooling
_client_session = None
async def get_client_session():
"""Get or create a shared aiohttp client session."""
global _client_session
try:
asyncio.get_running_loop()
except RuntimeError as exc:
raise RuntimeError("get_client_session() requires an active asyncio event loop.") from exc
if _client_session is None or _client_session.closed:
connector = aiohttp.TCPConnector(limit=100, limit_per_host=30)
# Don't set timeout here - set it per request
_client_session = aiohttp.ClientSession(connector=connector)
return _client_session
async def cleanup_client_session():
"""Clean up the shared client session."""
global _client_session
if _client_session and not _client_session.closed:
await _client_session.close()
_client_session = None
async def handle_api_error(request, error, status=500):
"""Standardized error response handler."""
if isinstance(error, list):
messages = [str(item) for item in error]
debug_log(f"API Error [{status}]: {messages}")
return web.json_response({"errors": messages}, status=status)
message = str(error)
debug_log(f"API Error [{status}]: {message}")
return web.json_response({"error": message}, status=status)
def get_server_port():
    """Get the ComfyUI server port."""
    return server.PromptServer.instance.port
def get_server_loop():
    """Get the ComfyUI server event loop."""
    return server.PromptServer.instance.loop
def normalize_host(value):
if value is None:
return None
if not isinstance(value, str):
return value
host = value.strip()
if not host:
return host
host = re.sub(r"^https?://", "", host, flags=re.IGNORECASE)
return host.split("/")[0]
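# Examples:
#     normalize_host("https://abc.trycloudflare.com/x")  -> "abc.trycloudflare.com"
#     normalize_host("  HTTP://1.2.3.4:8188  ")          -> "1.2.3.4:8188"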
def _split_host_and_port(host):
if not host:
return host, None
if host.startswith("["):
match = re.match(r"^(\[[^\]]+\])(?::(\d+))?$", host)
if match:
parsed_port = int(match.group(2)) if match.group(2) else None
return match.group(1), parsed_port
return host, None
if host.count(":") == 1:
candidate_host, candidate_port = host.rsplit(":", 1)
if candidate_port.isdigit():
return candidate_host, int(candidate_port)
return host, None
def build_worker_url(worker, endpoint=""):
"""Construct the worker base URL with optional endpoint."""
host = (worker.get("host") or "").strip()
port = int(worker.get("port", worker.get("listen_port", 8188)) or 8188)
if not host:
host = getattr(server.PromptServer.instance, "address", "127.0.0.1") or "127.0.0.1"
if host.startswith(("http://", "https://")):
base = host.rstrip("/")
else:
is_cloud = worker.get("type") == "cloud" or host.endswith(".proxy.runpod.net") or port == 443
scheme = "https" if is_cloud else "http"
default_port = 443 if scheme == "https" else 80
port_part = "" if port == default_port else f":{port}"
base = f"{scheme}://{host}{port_part}"
return f"{base}{endpoint}"
async def probe_worker(worker_url: str, timeout: float = 5.0) -> dict | None:
"""GET {worker_url}/prompt. Returns parsed JSON or None on any failure."""
base_url = (worker_url or "").strip().rstrip("/")
if not base_url:
return None
probe_url = base_url if base_url.endswith("/prompt") else f"{base_url}/prompt"
session = await get_client_session()
try:
async with session.get(
probe_url,
timeout=aiohttp.ClientTimeout(total=float(timeout)),
) as response:
if response.status != 200:
debug_log(f"[Distributed] Worker probe non-200 status: {response.status} ({probe_url})")
return None
payload = await response.json()
if isinstance(payload, dict):
return payload
debug_log(f"[Distributed] Worker probe returned non-object JSON: {probe_url}")
return None
    except asyncio.TimeoutError:
        debug_log(f"Worker probe timed out: {probe_url}")
        return None
    except aiohttp.ClientConnectorError:
        debug_log(f"Worker unreachable: {probe_url}")
        return None
    except Exception as exc:
        debug_log(f"Worker probe error ({probe_url}): {exc}")
        return None
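# Usage sketch (assuming a worker dict as used elsewhere in this module):
#     info = await probe_worker(build_worker_url(worker))
#     online = info is not None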
def build_master_url(config=None, prompt_server_instance=None):
"""Build the best public URL workers should use to reach the master."""
if config is None:
from .config import load_config
config = load_config()
prompt_server_instance = prompt_server_instance or server.PromptServer.instance
master_cfg = (config or {}).get("master", {}) or {}
configured_host = (master_cfg.get("host") or "").strip()
runtime_port = getattr(prompt_server_instance, "port", 8188) or 8188
def _needs_https(hostname):
hostname = hostname.lower()
https_domains = (
".proxy.runpod.net",
".ngrok-free.app",
".ngrok-free.dev",
".ngrok.io",
".trycloudflare.com",
".cloudflare.dev",
)
return any(hostname.endswith(suffix) for suffix in https_domains)
if configured_host:
if configured_host.startswith(("http://", "https://")):
return configured_host.rstrip("/")
host, explicit_port = _split_host_and_port(configured_host)
port = explicit_port if explicit_port is not None else int(runtime_port)
scheme = "https" if _needs_https(host) or port == 443 else "http"
default_port_for_scheme = 443 if scheme == "https" else 80
if explicit_port is None and scheme == "https" and _needs_https(host):
port = default_port_for_scheme
port_part = "" if port == default_port_for_scheme else f":{port}"
return f"{scheme}://{host}{port_part}"
address = getattr(prompt_server_instance, "address", "127.0.0.1") or "127.0.0.1"
if address in ("0.0.0.0", "::"):
address = "127.0.0.1"
port = int(runtime_port)
scheme = "https" if port == 443 else "http"
default_port_for_scheme = 443 if scheme == "https" else 80
port_part = "" if port == default_port_for_scheme else f":{port}"
return f"{scheme}://{address}{port_part}"
def build_master_callback_url(worker, config=None, prompt_server_instance=None):
"""Build the callback URL a specific worker should use to reach the master."""
prompt_server_instance = prompt_server_instance or server.PromptServer.instance
worker_type = str((worker or {}).get("type") or "").strip().lower()
worker_host = normalize_host((worker or {}).get("host"))
local_hosts = {"", "localhost", "127.0.0.1", "::1", "[::1]", "0.0.0.0"}
is_local_worker = worker_type == "local" or worker_host in local_hosts
if is_local_worker:
port = int(getattr(prompt_server_instance, "port", 8188) or 8188)
scheme = "https" if port == 443 else "http"
default_port_for_scheme = 443 if scheme == "https" else 80
port_part = "" if port == default_port_for_scheme else f":{port}"
return f"{scheme}://127.0.0.1{port_part}"
return build_master_url(config=config, prompt_server_instance=prompt_server_instance)
================================================
FILE: utils/process.py
================================================
"""
Process management utilities for ComfyUI-Distributed.
"""
import os
import subprocess
import platform
import signal
def is_process_alive(pid):
"""Check if a process with given PID is still alive."""
try:
if platform.system() == "Windows":
# Windows: use tasklist
result = subprocess.run(['tasklist', '/FI', f'PID eq {pid}'],
capture_output=True, text=True)
return str(pid) in result.stdout
else:
# Unix: send signal 0
os.kill(pid, 0)
return True
except (OSError, subprocess.SubprocessError):
return False
def terminate_process(process, timeout=5):
"""Gracefully terminate a process with timeout."""
if process.poll() is None: # Still running
process.terminate()
try:
process.wait(timeout=timeout)
except subprocess.TimeoutExpired:
process.kill()
process.wait()
def get_python_executable():
"""Get the Python executable path."""
import sys
return sys.executable
================================================
FILE: utils/trace_logger.py
================================================
from .logging import debug_log, log
def trace_prefix(trace_execution_id: str) -> str:
return f"[Distributed][exec:{trace_execution_id}]"
def trace_debug(trace_execution_id: str, message: str) -> None:
debug_log(f"{trace_prefix(trace_execution_id)} {message}")
def trace_info(trace_execution_id: str, message: str) -> None:
log(f"{trace_prefix(trace_execution_id)} {message}")
================================================
FILE: utils/usdu_managment.py
================================================
"""Backward-compatibility shim for USDU helpers.
Route handlers and job logic now live in:
- upscale.job_store
- upscale.job_timeout
- upscale.payload_parsers
- api.usdu_routes
"""
from ..upscale.conditioning import clone_conditioning, clone_control_chain
from ..upscale.job_store import (
MAX_PAYLOAD_SIZE,
_cleanup_job,
_drain_results_queue,
_get_completed_count,
_init_job_queue,
_mark_task_completed,
ensure_tile_jobs_initialized,
init_dynamic_job,
init_static_job_batched,
)
from ..upscale.job_timeout import _check_and_requeue_timed_out_workers
from ..upscale.payload_parsers import _parse_tiles_from_form
from .logging import debug_log
from .network import get_client_session
async def _send_heartbeat_to_master(multi_job_id, master_url, worker_id):
"""Send heartbeat to master from worker-side processing loops."""
try:
data = {'multi_job_id': multi_job_id, 'worker_id': str(worker_id)}
session = await get_client_session()
url = f"{master_url}/distributed/heartbeat"
async with session.post(url, json=data) as response:
response.raise_for_status()
except Exception as e:
debug_log(f"Heartbeat failed: {e}")
__all__ = [
"MAX_PAYLOAD_SIZE",
"_check_and_requeue_timed_out_workers",
"_cleanup_job",
"_drain_results_queue",
"_get_completed_count",
"_init_job_queue",
"_mark_task_completed",
"_parse_tiles_from_form",
"_send_heartbeat_to_master",
"clone_conditioning",
"clone_control_chain",
"ensure_tile_jobs_initialized",
"init_dynamic_job",
"init_static_job_batched",
]
================================================
FILE: utils/usdu_utils.py
================================================
import numpy as np
from PIL import Image, ImageFilter
import torch
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
import math
if (not hasattr(Image, 'Resampling')): # For older versions of Pillow
Image.Resampling = Image
BLUR_KERNEL_SIZE = 15
def tensor_to_pil(img_tensor, batch_index=0):
# Takes a batch of images in the form of a tensor of shape [batch_size, height, width, channels]
# and returns an RGB PIL Image. Assumes channels=3
return Image.fromarray((255 * img_tensor[batch_index].cpu().numpy()).astype(np.uint8))
def pil_to_tensor(image):
# Takes a PIL image and returns a tensor of shape [1, height, width, channels]
image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image).unsqueeze(0)
if len(image.shape) == 3: # If the image is grayscale, add a channel dimension
image = image.unsqueeze(-1)
return image
def controlnet_hint_to_pil(tensor, batch_index=0):
return tensor_to_pil(tensor.movedim(1, -1), batch_index)
def pil_to_controlnet_hint(img):
return pil_to_tensor(img).movedim(-1, 1)
def crop_tensor(tensor, region):
# Takes a tensor of shape [batch_size, height, width, channels] and crops it to the given region
x1, y1, x2, y2 = region
return tensor[:, y1:y2, x1:x2, :]
def resize_tensor(tensor, size, mode="nearest-exact"):
# Takes a tensor of shape [B, C, H, W] and resizes
# it to a shape of [B, C, size[0], size[1]] using the given mode
return torch.nn.functional.interpolate(tensor, size=size, mode=mode)
def get_crop_region(mask, pad=0):
# Takes a black and white PIL image in 'L' mode and returns the coordinates of the white rectangular mask region
# Should be equivalent to the get_crop_region function from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/masking.py
coordinates = mask.getbbox()
if coordinates is not None:
x1, y1, x2, y2 = coordinates
else:
x1, y1, x2, y2 = mask.width, mask.height, 0, 0
# Apply padding
x1 = max(x1 - pad, 0)
y1 = max(y1 - pad, 0)
x2 = min(x2 + pad, mask.width)
y2 = min(y2 + pad, mask.height)
return fix_crop_region((x1, y1, x2, y2), (mask.width, mask.height))
def fix_crop_region(region, image_size):
# Remove the extra pixel added by the get_crop_region function
image_width, image_height = image_size
x1, y1, x2, y2 = region
if x2 < image_width:
x2 -= 1
if y2 < image_height:
y2 -= 1
return x1, y1, x2, y2
def expand_crop(region, width, height, target_width, target_height):
'''
Expands a crop region to a specified target size.
:param region: A tuple of the form (x1, y1, x2, y2) denoting the upper left and the lower right points
of the rectangular region. Expected to have x2 > x1 and y2 > y1.
:param width: The width of the image the crop region is from.
:param height: The height of the image the crop region is from.
:param target_width: The desired width of the crop region.
:param target_height: The desired height of the crop region.
'''
x1, y1, x2, y2 = region
actual_width = x2 - x1
actual_height = y2 - y1
# target_width = math.ceil(actual_width / 8) * 8
# target_height = math.ceil(actual_height / 8) * 8
# Try to expand region to the right of half the difference
width_diff = target_width - actual_width
x2 = min(x2 + width_diff // 2, width)
# Expand region to the left of the difference including the pixels that could not be expanded to the right
width_diff = target_width - (x2 - x1)
x1 = max(x1 - width_diff, 0)
# Try the right again
width_diff = target_width - (x2 - x1)
x2 = min(x2 + width_diff, width)
# Try to expand region to the bottom of half the difference
height_diff = target_height - actual_height
y2 = min(y2 + height_diff // 2, height)
# Expand region to the top of the difference including the pixels that could not be expanded to the bottom
height_diff = target_height - (y2 - y1)
y1 = max(y1 - height_diff, 0)
# Try the bottom again
height_diff = target_height - (y2 - y1)
y2 = min(y2 + height_diff, height)
return (x1, y1, x2, y2), (target_width, target_height)
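# Worked example: a 10x10 crop at (10, 10, 20, 20) in a 100x100 image expanded
# to a 64x64 target splits the growth across both sides, clamping at borders:
#     expand_crop((10, 10, 20, 20), 100, 100, 64, 64)
#     # -> ((0, 0, 64, 64), (64, 64))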
def resize_region(region, init_size, resize_size):
# Resize a crop so that it fits an image that was resized to the given width and height
x1, y1, x2, y2 = region
init_width, init_height = init_size
resize_width, resize_height = resize_size
x1 = math.floor(x1 * resize_width / init_width)
x2 = math.ceil(x2 * resize_width / init_width)
y1 = math.floor(y1 * resize_height / init_height)
y2 = math.ceil(y2 * resize_height / init_height)
return (x1, y1, x2, y2)
def pad_image(image, left_pad, right_pad, top_pad, bottom_pad, fill=False, blur=False):
'''
Pads an image with the given number of pixels on each side and fills the padding with data from the edges.
:param image: A PIL image
:param left_pad: The number of pixels to pad on the left side
:param right_pad: The number of pixels to pad on the right side
:param top_pad: The number of pixels to pad on the top side
:param bottom_pad: The number of pixels to pad on the bottom side
:param blur: Whether to blur the padded edges
:return: A PIL image with size (image.width + left_pad + right_pad, image.height + top_pad + bottom_pad)
'''
left_edge = image.crop((0, 1, 1, image.height - 1))
right_edge = image.crop((image.width - 1, 1, image.width, image.height - 1))
top_edge = image.crop((1, 0, image.width - 1, 1))
bottom_edge = image.crop((1, image.height - 1, image.width - 1, image.height))
new_width = image.width + left_pad + right_pad
new_height = image.height + top_pad + bottom_pad
padded_image = Image.new(image.mode, (new_width, new_height))
padded_image.paste(image, (left_pad, top_pad))
if fill:
for i in range(left_pad):
edge = left_edge.resize(
(1, new_height - i * (top_pad + bottom_pad) // left_pad), resample=Image.Resampling.NEAREST)
padded_image.paste(edge, (i, i * top_pad // left_pad))
for i in range(right_pad):
edge = right_edge.resize(
(1, new_height - i * (top_pad + bottom_pad) // right_pad), resample=Image.Resampling.NEAREST)
padded_image.paste(edge, (new_width - 1 - i, i * top_pad // right_pad))
for i in range(top_pad):
edge = top_edge.resize(
(new_width - i * (left_pad + right_pad) // top_pad, 1), resample=Image.Resampling.NEAREST)
padded_image.paste(edge, (i * left_pad // top_pad, i))
for i in range(bottom_pad):
edge = bottom_edge.resize(
(new_width - i * (left_pad + right_pad) // bottom_pad, 1), resample=Image.Resampling.NEAREST)
padded_image.paste(edge, (i * left_pad // bottom_pad, new_height - 1 - i))
if blur and not (left_pad == right_pad == top_pad == bottom_pad == 0):
padded_image = padded_image.filter(ImageFilter.GaussianBlur(BLUR_KERNEL_SIZE))
padded_image.paste(image, (left_pad, top_pad))
return padded_image
def pad_image2(image, left_pad, right_pad, top_pad, bottom_pad, fill=False, blur=False):
'''
Pads an image with the given number of pixels on each side and fills the padding with data from the edges.
Faster than pad_image, but only pads with edge data in straight lines.
:param image: A PIL image
:param left_pad: The number of pixels to pad on the left side
:param right_pad: The number of pixels to pad on the right side
:param top_pad: The number of pixels to pad on the top side
:param bottom_pad: The number of pixels to pad on the bottom side
:param blur: Whether to blur the padded edges
:return: A PIL image with size (image.width + left_pad + right_pad, image.height + top_pad + bottom_pad)
'''
left_edge = image.crop((0, 1, 1, image.height - 1))
right_edge = image.crop((image.width - 1, 1, image.width, image.height - 1))
top_edge = image.crop((1, 0, image.width - 1, 1))
bottom_edge = image.crop((1, image.height - 1, image.width - 1, image.height))
new_width = image.width + left_pad + right_pad
new_height = image.height + top_pad + bottom_pad
padded_image = Image.new(image.mode, (new_width, new_height))
padded_image.paste(image, (left_pad, top_pad))
if fill:
if left_pad > 0:
padded_image.paste(left_edge.resize((left_pad, new_height), resample=Image.Resampling.NEAREST), (0, 0))
if right_pad > 0:
padded_image.paste(right_edge.resize((right_pad, new_height),
resample=Image.Resampling.NEAREST), (new_width - right_pad, 0))
if top_pad > 0:
padded_image.paste(top_edge.resize((new_width, top_pad), resample=Image.Resampling.NEAREST), (0, 0))
if bottom_pad > 0:
padded_image.paste(bottom_edge.resize((new_width, bottom_pad),
resample=Image.Resampling.NEAREST), (0, new_height - bottom_pad))
if blur and not (left_pad == right_pad == top_pad == bottom_pad == 0):
padded_image = padded_image.filter(ImageFilter.GaussianBlur(BLUR_KERNEL_SIZE))
padded_image.paste(image, (left_pad, top_pad))
return padded_image
def pad_tensor(tensor, left_pad, right_pad, top_pad, bottom_pad, fill=False, blur=False):
'''
Pads an image tensor with the given number of pixels on each side and fills the padding with data from the edges.
    :param tensor: A tensor of shape [B, C, H, W]
:param left_pad: The number of pixels to pad on the left side
:param right_pad: The number of pixels to pad on the right side
:param top_pad: The number of pixels to pad on the top side
:param bottom_pad: The number of pixels to pad on the bottom side
:param blur: Whether to blur the padded edges
    :return: A tensor of shape [B, C, H + top_pad + bottom_pad, W + left_pad + right_pad]
'''
batch_size, channels, height, width = tensor.shape
h_pad = left_pad + right_pad
v_pad = top_pad + bottom_pad
new_width = width + h_pad
new_height = height + v_pad
# Create empty image
padded = torch.zeros((batch_size, channels, new_height, new_width), dtype=tensor.dtype)
    # Copy the original image into the center of the padded tensor
padded[:, :, top_pad:top_pad + height, left_pad:left_pad + width] = tensor
# Duplicate the edges of the original image into the padding
if top_pad > 0:
padded[:, :, :top_pad, :] = padded[:, :, top_pad:top_pad + 1, :] # Top edge
if bottom_pad > 0:
padded[:, :, -bottom_pad:, :] = padded[:, :, -bottom_pad - 1:-bottom_pad, :] # Bottom edge
if left_pad > 0:
padded[:, :, :, :left_pad] = padded[:, :, :, left_pad:left_pad + 1] # Left edge
if right_pad > 0:
padded[:, :, :, -right_pad:] = padded[:, :, :, -right_pad - 1:-right_pad] # Right edge
return padded
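# Shape example: pad_tensor(torch.zeros(1, 3, 64, 64), 8, 8, 4, 4).shape
#                -> torch.Size([1, 3, 72, 80])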
def resize_and_pad_image(image, width, height, fill=False, blur=False):
'''
Resizes an image to the given width and height and pads it to the given width and height.
:param image: A PIL image
:param width: The width of the resized image
:param height: The height of the resized image
:param fill: Whether to fill the padding with data from the edges
:param blur: Whether to blur the padded edges
:return: A PIL image of size (width, height)
'''
width_ratio = width / image.width
height_ratio = height / image.height
if height_ratio > width_ratio:
resize_ratio = width_ratio
else:
resize_ratio = height_ratio
resize_width = round(image.width * resize_ratio)
resize_height = round(image.height * resize_ratio)
resized = image.resize((resize_width, resize_height), resample=Image.Resampling.LANCZOS)
# Pad the sides of the image to get the image to the desired size that wasn't covered by the resize
horizontal_pad = (width - resize_width) // 2
vertical_pad = (height - resize_height) // 2
result = pad_image2(resized, horizontal_pad, horizontal_pad, vertical_pad, vertical_pad, fill, blur)
result = result.resize((width, height), resample=Image.Resampling.LANCZOS)
return result, (horizontal_pad, vertical_pad)
def resize_and_pad_tensor(tensor, width, height, fill=False, blur=False):
'''
Resizes an image tensor to the given width and height and pads it to the given width and height.
    :param tensor: A tensor of shape [B, C, H, W]
:param width: The width of the resized image
:param height: The height of the resized image
:param fill: Whether to fill the padding with data from the edges
:param blur: Whether to blur the padded edges
    :return: A tensor of shape [B, C, height, width]
'''
# Resize the image to the closest size that maintains the aspect ratio
width_ratio = width / tensor.shape[3]
height_ratio = height / tensor.shape[2]
if height_ratio > width_ratio:
resize_ratio = width_ratio
else:
resize_ratio = height_ratio
resize_width = round(tensor.shape[3] * resize_ratio)
resize_height = round(tensor.shape[2] * resize_ratio)
resized = F.interpolate(tensor, size=(resize_height, resize_width), mode='nearest-exact')
# Pad the sides of the image to get the image to the desired size that wasn't covered by the resize
horizontal_pad = (width - resize_width) // 2
vertical_pad = (height - resize_height) // 2
result = pad_tensor(resized, horizontal_pad, horizontal_pad, vertical_pad, vertical_pad, fill, blur)
result = F.interpolate(result, size=(height, width), mode='nearest-exact')
return result
def crop_controlnet(cond_dict, region, init_size, canvas_size, tile_size, w_pad, h_pad):
if "control" not in cond_dict:
return
c = cond_dict["control"]
controlnet = c.copy()
cond_dict["control"] = controlnet
while c is not None:
# hint is shape (B, C, H, W)
hint = controlnet.cond_hint_original
resized_crop = resize_region(region, canvas_size, hint.shape[:-3:-1])
hint = crop_tensor(hint.movedim(1, -1), resized_crop).movedim(-1, 1)
hint = resize_tensor(hint, tile_size[::-1])
controlnet.cond_hint_original = hint
c = c.previous_controlnet
controlnet.set_previous_controlnet(c.copy() if c is not None else None)
controlnet = controlnet.previous_controlnet
def region_intersection(region1, region2):
"""
Returns the coordinates of the intersection of two rectangular regions.
:param region1: A tuple of the form (x1, y1, x2, y2) denoting the upper left and the lower right points
of the first rectangular region. Expected to have x2 > x1 and y2 > y1.
:param region2: The second rectangular region with the same format as the first.
:return: A tuple of the form (x1, y1, x2, y2) denoting the rectangular intersection.
None if there is no intersection.
"""
x1, y1, x2, y2 = region1
x1_, y1_, x2_, y2_ = region2
x1 = max(x1, x1_)
y1 = max(y1, y1_)
x2 = min(x2, x2_)
y2 = min(y2, y2_)
if x1 >= x2 or y1 >= y2:
return None
return (x1, y1, x2, y2)
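# Examples:
#     region_intersection((0, 0, 10, 10), (5, 5, 20, 20))  -> (5, 5, 10, 10)
#     region_intersection((0, 0, 4, 4), (5, 5, 9, 9))      -> None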
def crop_gligen(cond_dict, region, init_size, canvas_size, tile_size, w_pad, h_pad):
if "gligen" not in cond_dict:
return
type, model, cond = cond_dict["gligen"]
if type != "position":
from warnings import warn
warn(f"Unknown gligen type {type}")
return
cropped = []
for c in cond:
emb, h, w, y, x = c
# Get the coordinates of the box in the upscaled image
x1 = x * 8
y1 = y * 8
x2 = x1 + w * 8
y2 = y1 + h * 8
gligen_upscaled_box = resize_region((x1, y1, x2, y2), init_size, canvas_size)
# Calculate the intersection of the gligen box and the region
intersection = region_intersection(gligen_upscaled_box, region)
if intersection is None:
continue
x1, y1, x2, y2 = intersection
# Offset the gligen box so that the origin is at the top left of the tile region
x1 -= region[0]
y1 -= region[1]
x2 -= region[0]
y2 -= region[1]
# Add the padding
x1 += w_pad
y1 += h_pad
x2 += w_pad
y2 += h_pad
# Set the new position params
h = (y2 - y1) // 8
w = (x2 - x1) // 8
x = x1 // 8
y = y1 // 8
cropped.append((emb, h, w, y, x))
cond_dict["gligen"] = (type, model, cropped)
def crop_area(cond_dict, region, init_size, canvas_size, tile_size, w_pad, h_pad):
if "area" not in cond_dict:
return
# Resize the area conditioning to the canvas size and confine it to the tile region
h, w, y, x = cond_dict["area"]
w, h, x, y = 8 * w, 8 * h, 8 * x, 8 * y
x1, y1, x2, y2 = resize_region((x, y, x + w, y + h), init_size, canvas_size)
intersection = region_intersection((x1, y1, x2, y2), region)
if intersection is None:
del cond_dict["area"]
del cond_dict["strength"]
return
x1, y1, x2, y2 = intersection
# Offset origin to the top left of the tile
x1 -= region[0]
y1 -= region[1]
x2 -= region[0]
y2 -= region[1]
# Add the padding
x1 += w_pad
y1 += h_pad
x2 += w_pad
y2 += h_pad
# Set the params for tile
w, h = (x2 - x1) // 8, (y2 - y1) // 8
x, y = x1 // 8, y1 // 8
cond_dict["area"] = (h, w, y, x)
def crop_mask(cond_dict, region, init_size, canvas_size, tile_size, w_pad, h_pad):
if "mask" not in cond_dict:
return
mask_tensor = cond_dict["mask"] # (B, H, W)
masks = []
for i in range(mask_tensor.shape[0]):
# Convert to PIL image
mask = tensor_to_pil(mask_tensor, i) # W x H
# Resize the mask to the canvas size
mask = mask.resize(canvas_size, Image.Resampling.BICUBIC)
# Crop the mask to the region
mask = mask.crop(region)
# Add padding
mask, _ = resize_and_pad_image(mask, tile_size[0], tile_size[1], fill=True)
# Resize the mask to the tile size
if tile_size != mask.size:
mask = mask.resize(tile_size, Image.Resampling.BICUBIC)
# Convert back to tensor
mask = pil_to_tensor(mask) # (1, H, W, 1)
mask = mask.squeeze(-1) # (1, H, W)
masks.append(mask)
cond_dict["mask"] = torch.cat(masks, dim=0) # (B, H, W)
# Flux-Kontext support: crop_reference_latents contributed by TBG ETUR
def crop_reference_latents(cond_dict, region, init_size, canvas_size, tile_size, w_pad, h_pad):
"""
1. Resize each latent to `canvas_size` in latent units.
2. Crop the rectangle `region` (pixel coordinates).
3. Down-sample the crop to `tile_size` expressed in latent units.
Expects a list of BCHW tensors under "reference_latents".
"""
latents = cond_dict.get("reference_latents")
if not isinstance(latents, list):
return # nothing to do
k = 8 # down-sample factor from pixel space → latent space (SD-type models)
W_can_px, H_can_px = canvas_size
# canvas size expressed in latent units
W_can_lat, H_can_lat = W_can_px // k, H_can_px // k
W_tile_px, H_tile_px = tile_size
W_tile_lat, H_tile_lat = max(1, W_tile_px // k), max(1, H_tile_px // k)
x1_px, y1_px, x2_px, y2_px = region
new_latents = []
for t in latents: # (B,C,H_lat_in,W_lat_in) or (B,C,1,H_lat_in,W_lat_in)
has_5d = False
if t.ndim == 5:
has_5d = True
t = t.squeeze(2)
if t.ndim != 4:
raise ValueError(f"expected BCHW or BC1HW, got {t.shape}")
# 1. Resize to canvas resolution in latent units only if needed
if t.shape[-2:] != (H_can_lat, W_can_lat):
t = F.interpolate(t,
size=(H_can_lat, W_can_lat),
mode="bilinear",
align_corners=False)
# 2. Convert pixel crop → latent slice
w0_lat = int(round(x1_px / k))
w1_lat = int(round(x2_px / k))
h0_lat = int(round(y1_px / k))
h1_lat = int(round(y2_px / k))
cropped = t[:, :, h0_lat:h1_lat, w0_lat:w1_lat] # view
# 3. Down-sample to latent-tile size
cropped = F.interpolate(cropped,
size=(H_tile_lat, W_tile_lat),
mode="bilinear",
align_corners=False)
if has_5d:
cropped = cropped.unsqueeze(2)
new_latents.append(cropped)
cond_dict["reference_latents"] = new_latents
def crop_cond(cond, region, init_size, canvas_size, tile_size, w_pad=0, h_pad=0):
cropped = []
for emb, x in cond:
cond_dict = x.copy()
n = [emb, cond_dict]
crop_controlnet(cond_dict, region, init_size, canvas_size, tile_size, w_pad, h_pad)
crop_gligen(cond_dict, region, init_size, canvas_size, tile_size, w_pad, h_pad)
crop_area(cond_dict, region, init_size, canvas_size, tile_size, w_pad, h_pad)
crop_mask(cond_dict, region, init_size, canvas_size, tile_size, w_pad, h_pad)
crop_reference_latents(cond_dict, region, init_size, canvas_size, tile_size, w_pad, h_pad)
cropped.append(n)
return cropped
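# Usage sketch (sizes illustrative): for each tile of the upscaled canvas the
# conditioning is cloned and confined to that tile, so hints, boxes, areas,
# masks and reference latents line up with the tile being sampled:
#   tile_cond = crop_cond(positive, region=(0, 0, 512, 512),
#                         init_size=(1024, 1024), canvas_size=(2048, 2048),
#                         tile_size=(512, 512))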
================================================
FILE: vitest.config.js
================================================
import { defineConfig } from "vitest/config";
export default defineConfig({
test: {
include: ["web/tests/**/*.test.js"],
environment: "node",
},
});
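// The web UI tests under web/tests/ run on Node via Vitest; a typical local
// invocation is `npx vitest run` (scripts/test-web.sh likely wraps this).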
================================================
FILE: web/apiClient.js
================================================
import { TIMEOUTS } from './constants.js';
import { normalizeWorkerUrl } from './urlUtils.js';
export function createApiClient(baseUrl) {
const normalizedBaseUrl = normalizeWorkerUrl(baseUrl);
const request = async (
endpoint,
options = {},
{ retries = TIMEOUTS.MAX_RETRIES, retry = true } = {},
) => {
const maxAttempts = retry ? retries : 1;
let lastError;
let delay = TIMEOUTS.RETRY_DELAY; // Initial delay for exponential backoff
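// With the default TIMEOUTS (RETRY_DELAY = 1000 ms, MAX_RETRIES = 3) a failing
// call is attempted at roughly t = 0 s, 1 s and 3 s before the last error is rethrown.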
for (let attempt = 0; attempt < maxAttempts; attempt++) {
try {
const headers = {
'Content-Type': 'application/json',
...(options.headers || {}),
};
const response = await fetch(`${normalizedBaseUrl}${endpoint}`, {
...options,
headers,
});
if (!response.ok) {
const error = await response.json().catch(() => ({}));
const message = error.message
|| error.error
|| (Array.isArray(error.errors) ? error.errors.join('; ') : null)
|| `HTTP ${response.status}`;
throw new Error(message);
}
return await response.json();
} catch (error) {
lastError = error;
console.log(`API Error (attempt ${attempt + 1}/${maxAttempts}): ${endpoint} - ${error.message}`);
if (attempt < maxAttempts - 1) {
await new Promise(resolve => setTimeout(resolve, delay));
delay *= 2; // Exponential backoff
}
}
}
throw lastError;
};
const requestUrl = async (
url,
options = {},
{ retries = TIMEOUTS.MAX_RETRIES, retry = true } = {},
) => {
const maxAttempts = retry ? retries : 1;
let lastError;
let delay = TIMEOUTS.RETRY_DELAY;
for (let attempt = 0; attempt < maxAttempts; attempt++) {
try {
const response = await fetch(url, options);
if (!response.ok) {
const error = await response.json().catch(() => ({}));
const message = error.message
|| error.error
|| (Array.isArray(error.errors) ? error.errors.join('; ') : null)
|| `HTTP ${response.status}`;
throw new Error(message);
}
return await response.json();
} catch (error) {
lastError = error;
console.log(`API Error (attempt ${attempt + 1}/${maxAttempts}): ${url} - ${error.message}`);
if (attempt < maxAttempts - 1) {
await new Promise(resolve => setTimeout(resolve, delay));
delay *= 2;
}
}
}
throw lastError;
};
return {
// Config endpoints
async getConfig() {
return request('/distributed/config');
},
async updateWorker(workerId, data) {
return request('/distributed/config/update_worker', {
method: 'POST',
body: JSON.stringify({ worker_id: workerId, ...data })
}, { retry: false });
},
async deleteWorker(workerId) {
return request('/distributed/config/delete_worker', {
method: 'POST',
body: JSON.stringify({ worker_id: workerId })
}, { retry: false });
},
async updateSetting(key, value) {
return request('/distributed/config/update_setting', {
method: 'POST',
body: JSON.stringify({ key, value })
}, { retry: false });
},
async updateMaster(data) {
return request('/distributed/config/update_master', {
method: 'POST',
body: JSON.stringify(data)
}, { retry: false });
},
// Worker management endpoints
async launchWorker(workerId) {
return request('/distributed/launch_worker', {
method: 'POST',
body: JSON.stringify({ worker_id: workerId })
}, { retry: false });
},
async stopWorker(workerId) {
return request('/distributed/stop_worker', {
method: 'POST',
body: JSON.stringify({ worker_id: workerId })
}, { retry: false });
},
async getManagedWorkers() {
return request('/distributed/managed_workers');
},
async getWorkerLog(workerId, lines = 1000) {
return request(`/distributed/worker_log/${workerId}?lines=${lines}`);
},
async getRemoteWorkerLog(workerId, lines = 300) {
return request(`/distributed/remote_worker_log/${workerId}?lines=${lines}`);
},
async clearLaunchingFlag(workerId) {
return request('/distributed/worker/clear_launching', {
method: 'POST',
body: JSON.stringify({ worker_id: workerId })
}, { retry: false });
},
async queueDistributed(payload) {
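// payload.trace_execution_id, when present, is forwarded as an
// X-Idempotency-Key header so the backend can recognize duplicate
// submissions of the same execution.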
return request('/distributed/queue', {
method: 'POST',
headers: {
...(payload?.trace_execution_id
? { 'X-Idempotency-Key': payload.trace_execution_id }
: {}),
},
body: JSON.stringify(payload)
}, { retry: false });
},
async probeWorker(workerUrl, timeoutMs = TIMEOUTS.STATUS_CHECK, signal = null) {
const normalizedWorkerUrl = normalizeWorkerUrl(workerUrl);
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), timeoutMs);
const effectiveSignal = signal
? AbortSignal.any([controller.signal, signal])
: controller.signal;
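// ComfyUI's GET /prompt responds with {"exec_info": {"queue_remaining": N}};
// any other shape is treated as a failed probe below.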
try {
const response = await fetch(`${normalizedWorkerUrl}/prompt`, {
method: 'GET',
mode: 'cors',
cache: 'no-store',
signal: effectiveSignal,
});
if (!response.ok) {
return { ok: false, status: response.status, queueRemaining: null };
}
let data;
try {
data = await response.json();
} catch {
return { ok: false, status: response.status, queueRemaining: null };
}
if (!data || typeof data !== "object" || Array.isArray(data)) {
return { ok: false, status: response.status, queueRemaining: null };
}
const execInfo = data.exec_info;
if (!execInfo || typeof execInfo !== "object" || Array.isArray(execInfo)) {
return { ok: false, status: response.status, queueRemaining: null };
}
const rawQueueRemaining = execInfo.queue_remaining;
const queueRemaining = Number(rawQueueRemaining);
if (!Number.isFinite(queueRemaining)) {
return { ok: false, status: response.status, queueRemaining: null };
}
return {
ok: true,
status: response.status,
queueRemaining: Math.max(0, queueRemaining),
};
} finally {
clearTimeout(timeoutId);
}
},
async dispatchToWorker(workerUrl, promptPayload) {
const normalizedWorkerUrl = normalizeWorkerUrl(workerUrl);
return requestUrl(`${normalizedWorkerUrl}/prompt`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
mode: 'cors',
body: JSON.stringify(promptPayload),
}, { retry: false });
},
// Network info
async getNetworkInfo() {
return request('/distributed/network_info');
},
// Status checking (with timeout)
async checkStatus(url, timeout = TIMEOUTS.DEFAULT_FETCH) {
const controller = new AbortController();
const timeoutId = setTimeout(() => controller.abort(), timeout);
try {
const response = await fetch(url, {
method: 'GET',
mode: 'cors',
signal: controller.signal
});
if (!response.ok) throw new Error(`HTTP ${response.status}`);
return await response.json();
} finally {
clearTimeout(timeoutId);
}
},
// Batch status checking
async checkMultipleStatuses(urls) {
return Promise.allSettled(
urls.map(url => this.checkStatus(url))
);
},
// Cloudflare tunnel management
async startTunnel() {
return request('/distributed/tunnel/start', {
method: 'POST',
body: JSON.stringify({})
}, { retry: false });
},
async stopTunnel() {
return request('/distributed/tunnel/stop', {
method: 'POST',
body: JSON.stringify({})
}, { retry: false });
},
async getTunnelStatus() {
return request('/distributed/tunnel/status');
}
};
}
================================================
FILE: web/constants.js
================================================
export const BUTTON_STYLES = {
// Base styles with unified padding
base: "width: 100%; padding: 4px 14px; color: white; border: none; border-radius: 4px; cursor: pointer; transition: all 0.2s; font-size: 12px; font-weight: 500;",
// Context-specific combined styles
workerControl: "flex: 1; font-size: 11px;",
// Layout modifiers
hidden: "display: none;",
marginLeftAuto: "margin-left: auto;",
// Color variants
cancel: "background-color: #555;",
info: "background-color: #333;",
success: "background-color: #4a7c4a;",
error: "background-color: #7c4a4a;",
launch: "background-color: #4a7c4a;",
stop: "background-color: #7c4a4a;",
log: "background-color: #685434;",
working: "background-color: #666;",
clearMemory: "background-color: #555; padding: 6px 14px;",
interrupt: "background-color: #555; padding: 6px 14px;",
};
export const STATUS_COLORS = {
DISABLED_GRAY: "#666",
OFFLINE_RED: "#c04c4c",
ONLINE_GREEN: "#3ca03c",
PROCESSING_YELLOW: "#f0ad4e"
};
export const UI_COLORS = {
MUTED_TEXT: "#888",
SECONDARY_TEXT: "#ccc",
BORDER_LIGHT: "#555",
BORDER_DARK: "#444",
BORDER_DARKER: "#3a3a3a",
BACKGROUND_DARK: "#2a2a2a",
BACKGROUND_DARKER: "#1e1e1e",
ICON_COLOR: "#666",
ACCENT_COLOR: "#777"
};
export const PULSE_ANIMATION_CSS = `
@keyframes pulse {
0% {
opacity: 1;
transform: scale(0.8);
box-shadow: 0 0 0 0 rgba(240, 173, 78, 0.7);
}
50% {
opacity: 0.3;
transform: scale(1.1);
box-shadow: 0 0 0 6px rgba(240, 173, 78, 0);
}
100% {
opacity: 1;
transform: scale(0.8);
box-shadow: 0 0 0 0 rgba(240, 173, 78, 0);
}
}
.status-pulsing {
animation: pulse 1.2s ease-in-out infinite;
transform-origin: center;
}
.worker-status--online {
background: var(--status-online, #3ca03c) !important;
}
.worker-status--offline {
background: var(--status-offline, #c04c4c) !important;
}
.worker-status--unknown {
background: var(--status-unknown, #888) !important;
}
.worker-status--processing {
background: var(--status-processing, #f0ad4e) !important;
}
/* Button hover effects */
.distributed-button:hover:not(:disabled) {
filter: brightness(1.2);
transition: filter 0.2s ease;
}
.distributed-button:disabled {
opacity: 0.6;
cursor: not-allowed;
}
/* Settings button animation */
.settings-btn {
transition: transform 0.2s ease;
}
/* Expanded settings panel */
.worker-settings {
max-height: 0;
overflow: hidden;
opacity: 0;
transition: max-height 0.3s ease, opacity 0.3s ease, padding 0.3s ease, margin 0.3s ease;
}
.worker-settings.expanded {
max-height: 500px;
opacity: 1;
padding: 12px;
margin-top: 8px;
margin-bottom: 8px;
}
/* Cloudflare tunnel spinner */
@keyframes tunnel-spin {
from { transform: rotate(0deg); }
to { transform: rotate(360deg); }
}
.tunnel-spinner {
width: 14px;
height: 14px;
border: 2px solid rgba(255, 255, 255, 0.35);
border-top-color: #fff;
border-radius: 50%;
display: inline-block;
animation: tunnel-spin 0.9s linear infinite;
margin-right: 8px;
vertical-align: middle;
}
`;
export const UI_STYLES = {
statusDot: "display: inline-block; width: 10px; height: 10px; border-radius: 50%; margin-right: 10px;",
controlsDiv: "padding: 0 12px 12px 12px; display: flex; gap: 6px;",
formGroup: "display: flex; flex-direction: column; gap: 5px;",
formLabel: "font-size: 12px; color: var(--dist-label-text, #ccc); font-weight: 500;",
formInput:
"padding: 6px 10px; color: var(--dist-input-text, white); background: var(--dist-input-bg, transparent); font-size: 12px; transition: border-color 0.2s;",
// Card styles
cardBase: "margin-bottom: 12px; overflow: hidden; display: flex;",
workerCard: "margin-bottom: 12px; overflow: hidden; display: flex;",
cardBlueprint: "cursor: pointer; transition: all 0.2s ease;",
cardAdd: "cursor: pointer; transition: all 0.2s ease;",
// Column styles
columnBase: "display: flex; align-items: center; justify-content: center;",
checkboxColumn: "flex: 0 0 44px; display: flex; align-items: center; justify-content: center; cursor: default;",
contentColumn: "flex: 1; display: flex; flex-direction: column; transition: background-color 0.2s ease;",
iconColumn: "width: 44px; flex-shrink: 0; font-size: 20px; color: var(--dist-placeholder-add-color, #666);",
// Row and content styles
infoRow: "display: flex; align-items: center; padding: 12px; cursor: pointer; min-height: 64px;",
workerContent: "display: flex; align-items: center; gap: 10px; flex: 1;",
// Form and controls styles
buttonGroup: "display: flex; gap: 4px; margin-top: 10px;",
settingsForm: "display: flex; flex-direction: column; gap: 10px;",
checkboxGroup: "display: flex; align-items: center; gap: 8px; margin: 5px 0;",
formLabelClickable: "font-size: 12px; color: var(--dist-label-text, #ccc); cursor: pointer;",
settingsToggle: "display: flex; align-items: center; gap: 6px; padding: 4px 0; cursor: pointer; user-select: none;",
controlsWrapper: "display: flex; gap: 6px; align-items: stretch; width: 100%;",
// Existing styles
settingsArrow:
"font-size: 12px; color: var(--dist-settings-arrow, #888); transition: all 0.2s ease; margin-left: auto; padding: 4px;",
infoBox:
"color: var(--dist-info-box-text, #999); padding: 5px 14px; font-size: 11px; text-align: center; flex: 1; font-weight: 500;",
workerSettings: "margin: 0 12px; padding: 0 12px;"
};
export const TIMEOUTS = {
DEFAULT_FETCH: 5000, // ms for general API calls
STATUS_CHECK: 1200, // ms for status checks
LAUNCH: 90000, // ms for worker launch (longer for model loading)
RETRY_DELAY: 1000, // initial delay for exponential backoff
MAX_RETRIES: 3, // max retry attempts
// UI feedback delays
BUTTON_RESET: 3000, // button text/state reset after actions
FLASH_SHORT: 1000, // brief success feedback
FLASH_MEDIUM: 1500, // medium error feedback
FLASH_LONG: 2000, // longer error feedback
// Operational delays
POST_ACTION_DELAY: 500, // delay after operations before status checks
STATUS_CHECK_DELAY: 100, // brief delay before status checks
// Background tasks
LOG_REFRESH: 2000, // log auto-refresh interval
IMAGE_CACHE_CLEAR: 30000 // delay before clearing image cache
};
export const ENDPOINTS = {
// ComfyUI core
PROMPT: '/prompt',
INTERRUPT: '/interrupt',
UPLOAD_IMAGE: '/upload/image',
SYSTEM_INFO: '/system_stats',
// Distributed API
CONFIG: '/distributed/config',
UPDATE_WORKER: '/distributed/config/update_worker',
DELETE_WORKER: '/distributed/config/delete_worker',
UPDATE_SETTING: '/distributed/config/update_setting',
UPDATE_MASTER: '/distributed/config/update_master',
LAUNCH_WORKER: '/distributed/launch_worker',
STOP_WORKER: '/distributed/stop_worker',
MANAGED_WORKERS: '/distributed/managed_workers',
WORKER_LOG: '/distributed/worker_log',
REMOTE_WORKER_LOG: '/distributed/remote_worker_log',
LOCAL_LOG: '/distributed/local_log',
CLEAR_LAUNCHING: '/distributed/worker/clear_launching',
PREPARE_JOB: '/distributed/prepare_job',
LOAD_IMAGE: '/distributed/load_image',
NETWORK_INFO: '/distributed/network_info',
CHECK_FILE: '/distributed/check_file',
CLEAR_MEMORY: '/distributed/clear_memory',
SYSTEM_INFO_DIST: '/distributed/system_info',
TUNNEL_START: '/distributed/tunnel/start',
TUNNEL_STOP: '/distributed/tunnel/stop',
TUNNEL_STATUS: '/distributed/tunnel/status',
};
export const NODE_CLASSES = {
DISTRIBUTED_COLLECTOR: 'DistributedCollector',
DISTRIBUTED_SEED: 'DistributedSeed',
DISTRIBUTED_EMPTY_IMAGE: 'DistributedEmptyImage',
UPSCALE_DISTRIBUTED: 'UltimateSDUpscaleDistributed',
PREVIEW_IMAGE: 'PreviewImage',
};
export function generateUUID() {
if (crypto.randomUUID) return crypto.randomUUID();
return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, c => {
const r = Math.random() * 16 | 0;
return (c === 'x' ? r : (r & 0x3 | 0x8)).toString(16);
});
}
================================================
FILE: web/distributed.css
================================================
:root {
--btn-stop: #7c4a4a;
--btn-launch: #4a7c4a;
--btn-log: #685434;
--btn-working: #666;
--btn-success: #3a6a3a;
--btn-error: #6a3a3a;
--tunnel-enable: #665533;
--tunnel-disable: #7c4a4a;
--master-badge-fallback-bg: #243024;
--master-badge-fallback-text: #6bd06b;
--master-badge-fallback-border: #335533;
--master-badge-delegate-bg: #3a3a3a;
--master-badge-delegate-text: #ffcc66;
--dist-divider: #444;
--dist-muted-text: #888;
--dist-label-text: #ccc;
--dist-settings-arrow: #888;
--dist-settings-arrow-hover: #fff;
--dist-card-bg: #2a2a2a;
--dist-card-title: #f4f5f7;
--dist-card-subtext: #a9afb9;
--dist-card-placeholder-title: #aaa;
--dist-left-col-border: #3a3a3a;
--dist-left-col-bg: rgba(0, 0, 0, 0.1);
--dist-info-box-bg: #333;
--dist-info-box-text: #999;
--dist-input-bg: #2a2a2a;
--dist-input-border: #444;
--dist-input-text: #fff;
--dist-settings-bg: #1e1e1e;
--dist-settings-border: #2a2a2a;
--dist-hover-bg: #333;
--dist-placeholder-blueprint-border: #555;
--dist-placeholder-blueprint-hover-border: #777;
--dist-placeholder-add-border: #444;
--dist-placeholder-add-hover-border: #666;
--dist-placeholder-blueprint-bg: rgba(255, 255, 255, 0.02);
--dist-placeholder-blueprint-hover-bg: rgba(255, 255, 255, 0.05);
--dist-placeholder-add-hover-bg: rgba(255, 255, 255, 0.02);
--dist-placeholder-blueprint-color: #777;
--dist-placeholder-blueprint-hover-color: #999;
--dist-placeholder-add-color: #555;
--dist-placeholder-add-hover-color: #888;
--dist-log-modal-bg: #1e1e1e;
--dist-log-modal-border: #444;
--dist-log-modal-header-border: #444;
--dist-log-modal-title: #fff;
--dist-log-modal-label: #ccc;
--dist-log-modal-body-bg: #0d0d0d;
--dist-log-modal-body-text: #ddd;
--dist-log-modal-status: #888;
}
.distributed-panel--light {
--dist-divider: #c7ced9;
--dist-muted-text: #5b6472;
--dist-label-text: #2f3a4a;
--dist-settings-arrow: #6b7483;
--dist-settings-arrow-hover: #253040;
--dist-card-bg: #f6f8fb;
--dist-card-title: #1e293b;
--dist-card-subtext: #4b5565;
--dist-card-placeholder-title: #3f4a5a;
--dist-left-col-border: #d8deea;
--dist-left-col-bg: #eef2f8;
--dist-info-box-bg: #e8edf5;
--dist-info-box-text: #4f5b6c;
--dist-input-bg: #ffffff;
--dist-input-border: #b8c2d3;
--dist-input-text: #1f2937;
--dist-settings-bg: #eef2f8;
--dist-settings-border: #d5dcea;
--dist-hover-bg: #e6ebf3;
--dist-placeholder-blueprint-border: #aeb8c9;
--dist-placeholder-blueprint-hover-border: #8e9ab0;
--dist-placeholder-add-border: #b7c1d3;
--dist-placeholder-add-hover-border: #919eb4;
--dist-placeholder-blueprint-bg: rgba(80, 100, 140, 0.05);
--dist-placeholder-blueprint-hover-bg: rgba(80, 100, 140, 0.08);
--dist-placeholder-add-hover-bg: rgba(80, 100, 140, 0.05);
--dist-placeholder-blueprint-color: #5f6c82;
--dist-placeholder-blueprint-hover-color: #47556e;
--dist-placeholder-add-color: #6a7588;
--dist-placeholder-add-hover-color: #4c586f;
--master-badge-delegate-bg: #efe7ce;
--master-badge-delegate-text: #6a4f00;
--master-badge-fallback-bg: #dff1df;
--master-badge-fallback-text: #1f5f1f;
--master-badge-fallback-border: #9fc79f;
--dist-log-modal-bg: #f9fbff;
--dist-log-modal-border: #bcc7db;
--dist-log-modal-header-border: #c7d1e3;
--dist-log-modal-title: #1f2937;
--dist-log-modal-label: #4a5568;
--dist-log-modal-body-bg: #f2f5fb;
--dist-log-modal-body-text: #1f2937;
--dist-log-modal-status: #5c6779;
}
.btn--stop {
background-color: var(--btn-stop) !important;
}
.btn--launch {
background-color: var(--btn-launch) !important;
}
.btn--log {
background-color: var(--btn-log) !important;
}
.btn--working {
background-color: var(--btn-working) !important;
}
.btn--success {
background-color: var(--btn-success) !important;
}
.btn--error {
background-color: var(--btn-error) !important;
}
.master-info-badge--fallback {
background-color: var(--master-badge-fallback-bg) !important;
color: var(--master-badge-fallback-text) !important;
border: 1px solid var(--master-badge-fallback-border) !important;
}
.master-info-badge--delegate {
background-color: var(--master-badge-delegate-bg) !important;
color: var(--master-badge-delegate-text) !important;
}
.entity-card-content--hovered {
background-color: var(--dist-hover-bg) !important;
}
.placeholder-card--blueprint {
border-color: var(--dist-placeholder-blueprint-border) !important;
background-color: var(--dist-placeholder-blueprint-bg) !important;
}
.placeholder-card--blueprint.is-hovered {
border-color: var(--dist-placeholder-blueprint-hover-border) !important;
background-color: var(--dist-placeholder-blueprint-hover-bg) !important;
}
.placeholder-card--add {
border-color: var(--dist-placeholder-add-border) !important;
background-color: transparent !important;
}
.placeholder-card--add.is-hovered {
border-color: var(--dist-placeholder-add-hover-border) !important;
background-color: var(--dist-placeholder-add-hover-bg) !important;
}
.placeholder-column--blueprint {
color: var(--dist-placeholder-blueprint-color) !important;
border-right-color: var(--dist-placeholder-blueprint-border) !important;
}
.placeholder-column--blueprint.is-hovered {
color: var(--dist-placeholder-blueprint-hover-color) !important;
}
.placeholder-column--add {
color: var(--dist-placeholder-add-color) !important;
border-color: var(--dist-placeholder-add-border) !important;
border-right-color: var(--dist-placeholder-add-border) !important;
}
.placeholder-column--add.is-hovered {
color: var(--dist-placeholder-add-hover-color) !important;
border-color: var(--dist-placeholder-add-hover-border) !important;
}
.tunnel-button--enable {
background-color: var(--tunnel-enable) !important;
}
.tunnel-button--disable {
background-color: var(--tunnel-disable) !important;
}
.tunnel-status--enable {
background-color: var(--tunnel-enable) !important;
}
.tunnel-status--disable {
background-color: var(--tunnel-disable) !important;
}
/* ---- Themeable card/column/input defaults (Classic) ---- */
.dist-card {
background: var(--dist-card-bg);
color: var(--dist-card-title);
border-radius: 6px;
}
.dist-card--blueprint {
border: 2px dashed var(--dist-placeholder-blueprint-border);
background: var(--dist-placeholder-blueprint-bg);
}
.dist-card--add {
border: 1px dashed var(--dist-placeholder-add-border);
background: transparent;
}
.dist-card__left-col {
border-right: 1px solid var(--dist-left-col-border);
background: var(--dist-left-col-bg);
}
.dist-info-box {
background-color: var(--dist-info-box-bg);
color: var(--dist-info-box-text);
border-radius: 4px;
}
.dist-form-input {
background: var(--dist-input-bg);
border: 1px solid var(--dist-input-border);
color: var(--dist-input-text);
border-radius: 4px;
}
.worker-settings {
background: var(--dist-settings-bg);
border: 1px solid var(--dist-settings-border);
border-radius: 4px;
}
.dist-worker-info__title {
color: var(--dist-card-title) !important;
font-size: 1.03em;
font-weight: 700;
}
.dist-worker-info__meta {
color: var(--dist-card-subtext) !important;
font-size: 0.88em;
}
.dist-worker-info__fallback {
color: var(--master-badge-fallback-text) !important;
font-size: 0.86em;
font-weight: 600;
}
/* ---- Nodes 2.0 Theme ---- */
.distributed-panel--nodes2 .dist-card {
background: var(--p-surface-800, #1e1e1e);
border: 1px solid var(--p-surface-700, #2c2c2c);
border-radius: 8px;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.3);
}
.distributed-panel--nodes2 .dist-card--blueprint {
border: 2px dashed var(--p-surface-600, #404040);
background: rgba(255, 255, 255, 0.015);
box-shadow: none;
}
.distributed-panel--nodes2 .dist-card--add {
border: 1px dashed var(--p-surface-600, #404040);
background: transparent;
box-shadow: none;
}
.distributed-panel--nodes2 .dist-card__left-col {
border-right-color: var(--p-surface-700, #2c2c2c);
background: var(--p-surface-900, #131313);
}
.distributed-panel--nodes2 .dist-info-box {
background-color: var(--p-surface-700, #2c2c2c);
color: var(--p-text-muted-color, #9e9e9e);
border-radius: 6px;
}
.distributed-panel--nodes2 .dist-form-input {
background: var(--p-surface-800, #1e1e1e);
border-color: var(--p-surface-600, #404040);
border-radius: 6px;
color: var(--p-text-color, #ffffff);
}
.distributed-panel--nodes2 .worker-settings {
background: var(--p-surface-900, #131313);
border-color: var(--p-surface-700, #2c2c2c);
border-radius: 6px;
}
.distributed-panel--nodes2 .entity-card-content--hovered {
background-color: var(--p-surface-700, #2c2c2c) !important;
}
.distributed-panel--nodes2 .placeholder-card--blueprint {
border-color: var(--p-surface-600, #404040) !important;
}
.distributed-panel--nodes2 .placeholder-card--blueprint.is-hovered {
border-color: var(--p-surface-500, #555555) !important;
}
.distributed-panel--nodes2 .placeholder-card--add {
border-color: var(--p-surface-600, #404040) !important;
}
.distributed-panel--nodes2 .placeholder-card--add.is-hovered {
border-color: var(--p-surface-500, #555555) !important;
}
/* ---- End Nodes 2.0 Theme ---- */
.is-hidden {
display: none !important;
}
.settings-arrow--expanded {
transform: rotate(90deg) !important;
}
.log-modal {
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background: rgba(0, 0, 0, 0.8);
display: flex;
align-items: center;
justify-content: center;
z-index: 10000;
}
.log-modal__content {
background: var(--dist-log-modal-bg);
border-radius: 8px;
width: 90%;
max-width: 1200px;
height: 80%;
display: flex;
flex-direction: column;
border: 1px solid var(--dist-log-modal-border);
}
.log-modal__header {
padding: 15px 20px;
border-bottom: 1px solid var(--dist-log-modal-header-border);
display: flex;
justify-content: space-between;
align-items: center;
}
.log-modal__title {
margin: 0;
color: var(--dist-log-modal-title);
}
.log-modal__header-buttons {
display: flex;
gap: 20px;
align-items: center;
}
.log-modal__refresh {
display: flex;
align-items: center;
gap: 4px;
}
.log-modal__refresh-input {
cursor: pointer;
}
.log-modal__refresh-label {
font-size: 12px;
color: var(--dist-log-modal-label);
cursor: pointer;
white-space: nowrap;
}
.log-modal__close {
background-color: #c04c4c !important;
padding: 5px 10px !important;
font-size: 14px !important;
font-weight: bold !important;
border-radius: 6px !important;
}
.log-modal__body {
flex: 1;
overflow: auto;
padding: 15px;
font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
font-size: 12px;
line-height: 1.4;
color: var(--dist-log-modal-body-text);
background: var(--dist-log-modal-body-bg);
white-space: pre-wrap;
word-wrap: break-word;
}
.log-modal__status {
padding: 10px 20px;
border-top: 1px solid var(--dist-log-modal-header-border);
font-size: 11px;
color: var(--dist-log-modal-status);
}
================================================
FILE: web/distributedValue.js
================================================
import { app } from "/scripts/app.js";
import { ENDPOINTS } from "./constants.js";
const NODE_CLASS = "DistributedValue";
const CONVERTED_WIDGET = "converted-widget";
const DYNAMIC_DEFAULT_WIDGET = "_dv_default";
const DYNAMIC_WORKER_WIDGET_PREFIX = "_dv_worker_";
const WORKERS_CHANGED_EVENT = "distributed:workers-changed";
const trackedNodes = new Set();
let workersChangedListenerAttached = false;
function filterEnabledWorkers(workers) {
if (!Array.isArray(workers)) return [];
return workers.filter((worker) => Boolean(worker?.enabled));
}
async function fetchWorkers() {
try {
const resp = await fetch(ENDPOINTS.CONFIG);
if (!resp.ok) return [];
const config = await resp.json();
return filterEnabledWorkers(config.workers);
} catch {
return [];
}
}
function getRawDefaultWidget(node) {
return node.widgets?.find((w) => w.name === "default_value");
}
function getRawWorkerValuesWidget(node) {
return node.widgets?.find((w) => w.name === "worker_values");
}
function getDynamicDefaultWidget(node) {
return node.widgets?.find((w) => w.name === DYNAMIC_DEFAULT_WIDGET);
}
function getDynamicWorkerWidgets(node) {
return (node.widgets || []).filter((w) => w.name.startsWith(DYNAMIC_WORKER_WIDGET_PREFIX));
}
function hideWidgetForGood(node, widget, suffix = "") {
if (!widget) return;
if (typeof widget.type === "string" && widget.type.startsWith(CONVERTED_WIDGET)) return;
widget.origType = widget.type;
widget.origComputeSize = widget.computeSize;
widget.origSerializeValue = widget.serializeValue;
widget.computeSize = () => [0, -4];
widget.type = `${CONVERTED_WIDGET}${suffix}`;
// Hide any attached DOM element (multiline widgets).
if (widget.element) widget.element.style.display = "none";
if (widget.inputEl) widget.inputEl.style.display = "none";
if (widget.linkedWidgets) {
for (const linked of widget.linkedWidgets) {
hideWidgetForGood(node, linked, `:${widget.name}`);
}
}
}
function hideRawWidgets(node) {
hideWidgetForGood(node, getRawDefaultWidget(node), ":default_value");
hideWidgetForGood(node, getRawWorkerValuesWidget(node), ":worker_values");
}
function removeDynamicDefaultWidget(node) {
const idx = node.widgets?.findIndex((w) => w.name === DYNAMIC_DEFAULT_WIDGET);
if (idx != null && idx >= 0) {
node.widgets.splice(idx, 1);
}
}
function removeDynamicWorkerWidgets(node) {
if (!node.widgets) return;
for (let i = node.widgets.length - 1; i >= 0; i--) {
if (node.widgets[i].name.startsWith(DYNAMIC_WORKER_WIDGET_PREFIX)) {
node.widgets.splice(i, 1);
}
}
}
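// The hidden worker_values widget persists per-worker overrides as JSON, e.g.
// (illustrative): {"_type": "INT", "1": "20", "_by_worker_id": {"abc123": "20"}}
// where numeric keys are 1-based positions in the enabled-worker list and
// _by_worker_id pins values to stable worker ids.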
function readWorkerStore(node) {
const raw = getRawWorkerValuesWidget(node);
if (!raw) return {};
try {
const parsed = JSON.parse(raw.value || "{}");
return typeof parsed === "object" && parsed !== null ? parsed : {};
} catch {
return {};
}
}
function writeWorkerStore(node, store) {
const raw = getRawWorkerValuesWidget(node);
if (!raw) return;
raw.value = JSON.stringify(store);
}
function normalizeComboOptions(options) {
if (!options) return null;
if (Array.isArray(options)) return options;
if (Array.isArray(options.values)) return options.values;
return null;
}
function resolveGraphLink(graph, linkId) {
const links = graph.links || graph._links;
if (!links) return null;
const link = links[linkId] ?? (typeof links.get === "function" ? links.get(linkId) : null);
if (!link) return null;
if (Array.isArray(link)) {
return {
target_id: link[2],
target_slot: link[3],
};
}
return link;
}
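// Infer the widget type for this node by following its first output link: the
// connected input's widget (combo/number) or the target node's nodeData input
// definition decides between COMBO, INT, FLOAT and STRING (the fallback).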
function detectTargetType(node) {
const out = node.outputs?.[0];
const linkIds = out?.links || [];
if (!linkIds.length) {
return { connected: false, type: "STRING", options: null };
}
const graph = node.graph || app.graph;
if (!graph) {
return { connected: false, type: "STRING", options: null };
}
const link = resolveGraphLink(graph, linkIds[0]);
if (!link) {
return { connected: false, type: "STRING", options: null };
}
const targetNode = graph.getNodeById(link.target_id);
if (!targetNode) {
return { connected: false, type: "STRING", options: null };
}
const targetInputName = targetNode.inputs?.[link.target_slot]?.name;
if (!targetInputName) {
return { connected: false, type: "STRING", options: null };
}
const targetWidget = targetNode.widgets?.find((w) => w.name === targetInputName);
if (targetWidget) {
if (targetWidget.type === "combo") {
const comboOptions = normalizeComboOptions(targetWidget.options);
return { connected: true, type: "COMBO", options: comboOptions };
}
if (targetWidget.type === "number") {
const step = targetWidget.options?.step;
const precision = targetWidget.options?.precision;
const isInt = Number.isInteger(step) && (precision === 0 || precision == null);
return { connected: true, type: isInt ? "INT" : "FLOAT", options: null };
}
}
const nodeDef = targetNode.constructor?.nodeData;
const inputDef = nodeDef?.input?.required?.[targetInputName] || nodeDef?.input?.optional?.[targetInputName];
if (inputDef) {
const defType = inputDef[0];
if (Array.isArray(defType)) {
return { connected: true, type: "COMBO", options: defType };
}
if (defType === "INT") return { connected: true, type: "INT", options: null };
if (defType === "FLOAT") return { connected: true, type: "FLOAT", options: null };
}
return { connected: true, type: "STRING", options: null };
}
function normalizeNumber(value, fallback) {
const parsed = Number(value);
return Number.isFinite(parsed) ? parsed : fallback;
}
function getDefaultInitialValue(node, inputType, comboOptions) {
const rawDefault = getRawDefaultWidget(node);
const current = rawDefault?.value;
if (inputType === "INT") {
return Math.trunc(normalizeNumber(current, 0));
}
if (inputType === "FLOAT") {
return normalizeNumber(current, 0);
}
if (inputType === "COMBO" && Array.isArray(comboOptions) && comboOptions.length) {
const currentText = current == null ? "" : String(current);
return comboOptions.includes(currentText) ? currentText : comboOptions[0];
}
return current == null ? "" : String(current);
}
function setRawDefaultValue(node, value) {
const rawDefault = getRawDefaultWidget(node);
if (!rawDefault) return;
rawDefault.value = value;
}
function serializeWorkerStoreFromWidgets(node, inputType, comboOptions) {
const nextStore = { _type: inputType };
if (inputType === "COMBO" && Array.isArray(comboOptions)) {
nextStore._options = comboOptions;
}
const valuesByWorkerId = {};
for (const widget of getDynamicWorkerWidgets(node)) {
const key = widget.name.slice(DYNAMIC_WORKER_WIDGET_PREFIX.length);
if (widget.value !== "" && widget.value !== null && widget.value !== undefined) {
const value = String(widget.value);
nextStore[key] = value;
if (widget._dvWorkerId) {
valuesByWorkerId[widget._dvWorkerId] = value;
}
}
}
if (Object.keys(valuesByWorkerId).length) {
nextStore._by_worker_id = valuesByWorkerId;
}
writeWorkerStore(node, nextStore);
}
function updateWorkerStoreTypeMetadata(node, inputType, comboOptions) {
const store = readWorkerStore(node);
store._type = inputType;
if (inputType === "COMBO" && Array.isArray(comboOptions)) {
store._options = comboOptions;
} else {
delete store._options;
}
writeWorkerStore(node, store);
}
function createDynamicDefaultWidget(node, inputType, comboOptions) {
removeDynamicDefaultWidget(node);
const initial = getDefaultInitialValue(node, inputType, comboOptions);
let widget;
if (inputType === "COMBO" && Array.isArray(comboOptions) && comboOptions.length) {
widget = node.addWidget(
"combo",
DYNAMIC_DEFAULT_WIDGET,
initial,
(value) => {
widget.value = value;
setRawDefaultValue(node, String(value));
},
{ values: comboOptions }
);
} else if (inputType === "INT") {
widget = node.addWidget(
"number",
DYNAMIC_DEFAULT_WIDGET,
initial,
(value) => {
widget.value = Math.trunc(normalizeNumber(value, 0));
setRawDefaultValue(node, widget.value);
},
{ min: -Infinity, max: Infinity, step: 1, precision: 0 }
);
} else if (inputType === "FLOAT") {
widget = node.addWidget(
"number",
DYNAMIC_DEFAULT_WIDGET,
initial,
(value) => {
widget.value = normalizeNumber(value, 0);
setRawDefaultValue(node, widget.value);
},
{ min: -Infinity, max: Infinity, step: 0.1, precision: 3 }
);
} else {
widget = node.addWidget(
"string",
DYNAMIC_DEFAULT_WIDGET,
initial,
(value) => {
widget.value = value ?? "";
setRawDefaultValue(node, widget.value);
},
{}
);
}
widget.label = "default_value";
}
function getWorkerInitialValue(store, key, workerId, inputType, comboOptions) {
const byWorkerId = store?._by_worker_id;
const saved = (byWorkerId && workerId && byWorkerId[workerId] != null)
? byWorkerId[workerId]
: store[key];
if (saved == null) {
if (inputType === "INT" || inputType === "FLOAT") return 0;
if (inputType === "COMBO" && Array.isArray(comboOptions) && comboOptions.length) {
return comboOptions[0];
}
return "";
}
if (inputType === "INT") return Math.trunc(normalizeNumber(saved, 0));
if (inputType === "FLOAT") return normalizeNumber(saved, 0);
if (inputType === "COMBO" && Array.isArray(comboOptions) && comboOptions.length) {
const savedText = String(saved);
return comboOptions.includes(savedText) ? savedText : comboOptions[0];
}
return String(saved);
}
function createWorkerWidgets(node, workers, inputType, comboOptions) {
removeDynamicWorkerWidgets(node);
const store = readWorkerStore(node);
for (let i = 0; i < workers.length; i++) {
const key = String(i + 1);
const worker = workers[i];
const label = worker.name || worker.id || `Worker ${key}`;
const widgetName = `${DYNAMIC_WORKER_WIDGET_PREFIX}${key}`;
const initial = getWorkerInitialValue(store, key, worker.id, inputType, comboOptions);
let widget;
if (inputType === "COMBO" && Array.isArray(comboOptions) && comboOptions.length) {
widget = node.addWidget(
"combo",
widgetName,
initial,
(value) => {
widget.value = value;
serializeWorkerStoreFromWidgets(node, inputType, comboOptions);
},
{ values: comboOptions }
);
} else if (inputType === "INT") {
widget = node.addWidget(
"number",
widgetName,
initial,
(value) => {
widget.value = Math.trunc(normalizeNumber(value, 0));
serializeWorkerStoreFromWidgets(node, inputType, comboOptions);
},
{ min: -Infinity, max: Infinity, step: 1, precision: 0 }
);
} else if (inputType === "FLOAT") {
widget = node.addWidget(
"number",
widgetName,
initial,
(value) => {
widget.value = normalizeNumber(value, 0);
serializeWorkerStoreFromWidgets(node, inputType, comboOptions);
},
{ min: -Infinity, max: Infinity, step: 0.1, precision: 3 }
);
} else {
widget = node.addWidget(
"string",
widgetName,
initial,
(value) => {
widget.value = value ?? "";
serializeWorkerStoreFromWidgets(node, inputType, comboOptions);
},
{}
);
}
widget.label = label;
widget._dvWorkerId = worker.id;
}
serializeWorkerStoreFromWidgets(node, inputType, comboOptions);
}
function rebuildWidgets(node) {
hideRawWidgets(node);
const workers = node._dvWorkers || [];
const store = readWorkerStore(node);
const detected = detectTargetType(node);
const disconnected = !detected.connected;
const inputType = disconnected ? "STRING" : detected.type;
const comboOptions = disconnected ? null : detected.options;
if (disconnected) {
// Reset disconnected node back to the neutral default state.
setRawDefaultValue(node, "");
writeWorkerStore(node, { _type: "STRING" });
}
createDynamicDefaultWidget(node, inputType, comboOptions);
if (workers.length > 0) {
createWorkerWidgets(node, workers, inputType, comboOptions);
} else {
removeDynamicWorkerWidgets(node);
updateWorkerStoreTypeMetadata(node, inputType, comboOptions);
}
const size = node.computeSize();
size[0] = Math.max(size[0], node.size?.[0] || 0);
node.setSize(size);
if (node.setDirtyCanvas) node.setDirtyCanvas(true, true);
}
function refreshNodeWorkers(node, workers) {
if (!node || !node.graph) return;
node._dvWorkers = workers;
rebuildWidgets(node);
}
async function refreshTrackedNodes(workers = null) {
const nextWorkers = workers || (await fetchWorkers());
for (const node of trackedNodes) {
refreshNodeWorkers(node, nextWorkers);
}
}
function attachWorkersChangedListener() {
if (workersChangedListenerAttached) return;
if (typeof window === "undefined" || typeof window.addEventListener !== "function") return;
window.addEventListener(WORKERS_CHANGED_EVENT, (event) => {
const changedWorkers = filterEnabledWorkers(event?.detail?.workers);
if (changedWorkers.length > 0 || Array.isArray(event?.detail?.workers)) {
void refreshTrackedNodes(changedWorkers);
return;
}
void refreshTrackedNodes();
});
workersChangedListenerAttached = true;
}
app.registerExtension({
name: "Distributed.DistributedValue",
async nodeCreated(node) {
if (node.comfyClass !== NODE_CLASS) return;
try {
attachWorkersChangedListener();
trackedNodes.add(node);
node._dvWorkers = await fetchWorkers();
rebuildWidgets(node);
const originalOnConnectionsChange = node.onConnectionsChange;
node.onConnectionsChange = function (type, index, connected, linkInfo, ioSlot) {
if (originalOnConnectionsChange) {
originalOnConnectionsChange.call(this, type, index, connected, linkInfo, ioSlot);
}
if (type === 2 && index === 0) {
setTimeout(() => rebuildWidgets(this), 20);
}
};
const originalConfigure = node.configure;
node.configure = function (data) {
const result = originalConfigure ? originalConfigure.call(this, data) : undefined;
setTimeout(() => rebuildWidgets(this), 20);
return result;
};
const originalOnRemoved = node.onRemoved;
node.onRemoved = function () {
trackedNodes.delete(this);
if (originalOnRemoved) {
return originalOnRemoved.call(this);
}
};
} catch (error) {
console.error("Error in DistributedValue extension:", error);
}
},
});
================================================
FILE: web/executionUtils.js
================================================
import { api } from "../../scripts/api.js";
import { applyProbeResultToWorkerDot, findNodesByClass } from './workerUtils.js';
import { TIMEOUTS, NODE_CLASSES, generateUUID } from './constants.js';
import { checkAllWorkerStatuses, getWorkerUrl } from './workerLifecycle.js';
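// Wraps ComfyUI's api.queuePrompt: while workers are enabled, workflows that
// contain a Distributed Collector or Ultimate SD Upscale (Distributed) node
// are routed through the distributed queue; everything else falls through to
// the original queuePrompt.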
export function setupInterceptor(extension) {
api.queuePrompt = async (number, prompt, ...rest) => {
if (extension.isEnabled) {
const hasCollector = findNodesByClass(prompt.output, NODE_CLASSES.DISTRIBUTED_COLLECTOR).length > 0;
const hasDistUpscale = findNodesByClass(prompt.output, NODE_CLASSES.UPSCALE_DISTRIBUTED).length > 0;
if (hasCollector || hasDistUpscale) {
const result = await executeParallelDistributed(extension, prompt);
// Immediate status check for instant feedback
checkAllWorkerStatuses(extension);
// Another check after a short delay to catch state changes
setTimeout(() => checkAllWorkerStatuses(extension), TIMEOUTS.POST_ACTION_DELAY);
return result;
}
}
return extension.originalQueuePrompt(number, prompt, ...rest);
};
}
export async function executeParallelDistributed(extension, promptWrapper) {
const traceExecutionId = `exec_${Date.now()}_${generateUUID().slice(0, 6)}`;
try {
const enabledWorkers = extension.enabledWorkers;
extension.log(`[exec:${traceExecutionId}] Starting distributed execution`, "debug");
// Pre-flight health check on all enabled workers
const activeWorkers = await performPreflightCheck(extension, enabledWorkers);
// Case: Enabled workers but all offline
if (activeWorkers.length === 0 && enabledWorkers.length > 0) {
extension.log("No active workers found. All enabled workers are offline.");
if (extension.ui?.showToast) {
extension.ui.showToast(extension.app, "error", "All Workers Offline",
`${enabledWorkers.length} worker(s) enabled but all are offline or unreachable. Check worker connections and try again.`, 5000);
}
// Fall back to master-only execution
return extension.originalQueuePrompt(0, promptWrapper);
}
extension.log(`Pre-flight check: ${activeWorkers.length} of ${enabledWorkers.length} workers are active`, "debug");
// Check if master host might be unreachable by workers (cloudflare tunnel down)
const masterHost = extension.config?.master?.host || '';
const isCloudflareHost = /\.(trycloudflare\.com|cloudflare\.dev)$/i.test(masterHost);
if (isCloudflareHost && activeWorkers.length > 0) {
// Try to verify if the cloudflare tunnel is actually up
try {
const testUrl = `${window.location.protocol}//${masterHost}/prompt`;
const response = await fetch(testUrl, {
method: 'GET',
mode: 'cors',
cache: 'no-cache',
signal: AbortSignal.timeout(3000) // 3 second timeout
});
if (!response.ok) {
throw new Error('Master not reachable');
}
} catch (error) {
// Cloudflare tunnel appears to be down
extension.log(`Master host ${masterHost} is not reachable - cloudflare tunnel may be down`, "error");
if (extension.ui?.showCloudflareWarning) {
extension.ui.showCloudflareWarning(extension, masterHost);
}
// Stop execution - workers won't be able to send results back
extension.log("Blocking execution - workers cannot reach master at cloudflare domain", "error");
return null; // This will prevent the workflow from running
}
}
const queueResponse = await extension.api.queueDistributed({
prompt: promptWrapper.output,
workflow: promptWrapper.workflow,
enabled_worker_ids: activeWorkers.map((worker) => worker.id),
workers: activeWorkers.map((worker) => ({ id: worker.id })),
client_id: api.clientId,
delegate_master: Boolean(extension.config?.settings?.master_delegate_only),
auto_prepare: true,
trace_execution_id: traceExecutionId,
});
if (queueResponse?.prompt_id) {
extension.log(
`[exec:${traceExecutionId}] Distributed queue accepted by backend (prompt_id=${queueResponse.prompt_id}, workers=${queueResponse.worker_count ?? activeWorkers.length})`,
"debug"
);
return queueResponse;
}
throw new Error(
`[exec:${traceExecutionId}] Backend did not return a prompt_id for distributed queue.`
);
} catch (error) {
extension.log(`[exec:${traceExecutionId}] Distributed execution failed: ${error.message}`, "error");
if (extension.ui?.showToast) {
extension.ui.showToast(extension.app, "error", "Distributed Failed", error.message, 5000);
}
return null;
}
}
export async function performPreflightCheck(extension, workers) {
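// Probes every worker in parallel; a timed-out probe (AbortError) is treated
// as active-but-uncertain so a slow worker is not silently dropped from the job.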
if (workers.length === 0) return [];
extension.log(`Performing pre-flight health check on ${workers.length} workers...`, "debug");
const startTime = Date.now();
const checkPromises = workers.map(async (worker) => {
const workerUrl = getWorkerUrl(extension, worker);
extension.log(`Pre-flight checking ${worker.name} at: ${workerUrl}`, "debug");
try {
const probeResult = await extension.api.probeWorker(workerUrl, TIMEOUTS.STATUS_CHECK);
if (probeResult.ok) {
extension.log(`Worker ${worker.name} is active`, "debug");
return { worker, active: true };
} else {
extension.log(`Worker ${worker.name} returned ${probeResult.status}`, "debug");
return { worker, active: false };
}
} catch (error) {
if (error?.name === 'AbortError') {
extension.log(`Worker ${worker.name} pre-flight check timed out; assuming active`, "debug");
return { worker, active: true, uncertain: true };
}
extension.log(`Worker ${worker.name} is offline or unreachable: ${error.message}`, "debug");
return { worker, active: false };
}
});
const results = await Promise.all(checkPromises);
const activeWorkers = results.filter(r => r.active).map(r => r.worker);
const elapsed = Date.now() - startTime;
extension.log(`Pre-flight check completed in ${elapsed}ms. Active workers: ${activeWorkers.length}/${workers.length}`, "debug");
// Update UI status indicators for inactive workers
results.filter(r => !r.active).forEach(r => {
applyProbeResultToWorkerDot(r.worker.id, { ok: false });
});
return activeWorkers;
}
================================================
FILE: web/image_batch_divider.js
================================================
import { app } from "/scripts/app.js";
// Configuration for each batch divider node type
const BATCH_DIVIDER_NODES = {
"ImageBatchDivider": { outputPrefix: "batch_", outputType: "IMAGE" },
"AudioBatchDivider": { outputPrefix: "audio_", outputType: "AUDIO" }
};
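// e.g. divide_by = 4 on an ImageBatchDivider yields outputs batch_1..batch_4;
// outputs are added or removed live as the widget value changes.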
app.registerExtension({
name: "Distributed.BatchDividers",
async nodeCreated(node) {
const config = BATCH_DIVIDER_NODES[node.comfyClass];
if (!config) return;
try {
const updateOutputs = () => {
if (!node.widgets) return;
const divideByWidget = node.widgets.find(w => w.name === "divide_by");
if (!divideByWidget) return;
const totalOutputs = parseInt(divideByWidget.value, 10) || 1;
// Ensure outputs array exists
if (!node.outputs) node.outputs = [];
// Remove excess outputs
while (node.outputs.length > totalOutputs) {
node.removeOutput(node.outputs.length - 1);
}
// Add missing outputs
while (node.outputs.length < totalOutputs) {
const outputIndex = node.outputs.length + 1;
node.addOutput(`${config.outputPrefix}${outputIndex}`, config.outputType);
}
if (node.setDirty) node.setDirty(true);
};
// Initial update with delay to allow workflow loading
setTimeout(updateOutputs, 200);
// Find the widget and set up responsive handlers
const divideByWidget = node.widgets.find(w => w.name === "divide_by");
if (divideByWidget) {
const originalCallback = divideByWidget.callback;
divideByWidget.callback = (value) => {
updateOutputs();
if (originalCallback) originalCallback.call(divideByWidget, value);
};
if (divideByWidget.inputEl) {
divideByWidget.inputEl.addEventListener('input', updateOutputs);
}
const observer = new MutationObserver(updateOutputs);
if (divideByWidget.element) {
observer.observe(divideByWidget.element, { attributes: true, childList: true, subtree: true });
}
node._batchDividerCleanup = () => {
observer.disconnect();
if (divideByWidget.inputEl) {
divideByWidget.inputEl.removeEventListener('input', updateOutputs);
}
divideByWidget.callback = originalCallback;
};
}
const originalConfigure = node.configure;
node.configure = function(data) {
const result = originalConfigure ? originalConfigure.call(this, data) : undefined;
updateOutputs();
return result;
};
} catch (error) {
console.error(`Error in ${node.comfyClass} extension:`, error);
}
},
nodeBeforeRemove(node) {
if (BATCH_DIVIDER_NODES[node.comfyClass] && node._batchDividerCleanup) {
node._batchDividerCleanup();
}
}
});
================================================
FILE: web/main.js
================================================
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
import { DistributedUI } from './ui.js';
import { createStateManager } from './stateManager.js';
import { createApiClient } from './apiClient.js';
import { renderSidebarContent, updateWorkerCard } from './sidebarRenderer.js';
import { handleInterruptWorkers, handleClearMemory } from './workerUtils.js';
import { setupInterceptor } from './executionUtils.js';
import { PULSE_ANIMATION_CSS, TIMEOUTS, STATUS_COLORS } from './constants.js';
import { updateTunnelUIElements, refreshTunnelStatus, handleTunnelToggle } from './tunnelManager.js';
import { checkAllWorkerStatuses, checkWorkerStatus, loadManagedWorkers } from './workerLifecycle.js';
import { detectMasterIP } from './masterDetection.js';
import { parseHostInput, getMasterUrl as buildMasterUrl } from './urlUtils.js';
const WORKERS_CHANGED_EVENT = "distributed:workers-changed";
class DistributedExtension {
constructor() {
this.config = null;
this.originalQueuePrompt = api.queuePrompt.bind(api);
this.logAutoRefreshInterval = null;
this.masterSettingsExpanded = false;
this.app = app; // Store app reference for toast notifications
this.tunnelStatus = { status: "unknown" };
this.tunnelElements = {};
// Initialize centralized state
this.state = createStateManager();
// Initialize UI component factory
this.ui = new DistributedUI();
// Initialize API client
this.api = createApiClient(window.location.origin);
// Initialize status check timeout reference
this.statusCheckTimeout = null;
// Initialize abort controller for status checks
this.statusCheckAbortController = null;
this.themeMutationObserver = null;
// Inject CSS for pulsing animation
this.injectStyles();
this.loadConfig().then(async () => {
this.registerSidebarTab();
this.setupInterceptor();
// Don't start polling until panel opens
// this.startStatusChecking();
loadManagedWorkers(this);
// Detect master IP after everything is set up
this.detectMasterIP();
// Listen for Nodes 2.0 setting changes (once, for the lifetime of the extension)
this._setupNodes2Listener();
});
}
// Debug logging helpers
log(message, level = "info") {
if (level === "debug" && !this.config?.settings?.debug) return;
if (level === "error") {
console.error(`[Distributed] ${message}`);
} else {
console.log(`[Distributed] ${message}`);
}
}
injectStyles() {
const styleId = 'distributed-styles';
if (!document.getElementById(styleId)) {
const style = document.createElement('style');
style.id = styleId;
style.textContent = PULSE_ANIMATION_CSS;
document.head.appendChild(style);
}
const fileStyleId = 'distributed-file-styles';
if (!document.getElementById(fileStyleId)) {
const style = document.createElement('style');
style.id = fileStyleId;
fetch(new URL('./distributed.css', import.meta.url))
.then((response) => response.text())
.then((cssText) => {
style.textContent = cssText;
})
.catch((error) => {
this.log(`Failed to load distributed.css: ${error.message}`, "error");
});
document.head.appendChild(style);
}
}
// --- State & Config Management (Single Source of Truth) ---
get enabledWorkers() {
return this.config?.workers?.filter(w => w.enabled) || [];
}
get isEnabled() {
return this.enabledWorkers.length > 0;
}
isMasterParticipationEnabled() {
return !Boolean(this.config?.settings?.master_delegate_only);
}
isMasterFallbackActive() {
return Boolean(this.config?.settings?.master_delegate_only) && this.enabledWorkers.length === 0;
}
isMasterParticipating() {
return this.isMasterParticipationEnabled() || this.isMasterFallbackActive();
}
async updateMasterParticipation(enabled) {
if (!this.config) {
this.config = {};
}
if (!this.config.settings) {
this.config.settings = {};
}
const delegateOnly = !enabled;
if (this.config.settings.master_delegate_only === delegateOnly) {
return;
}
await this._updateSetting('master_delegate_only', delegateOnly);
if (this.panelElement) {
renderSidebarContent(this, this.panelElement);
}
}
async loadConfig() {
try {
this.config = await this.api.getConfig();
this.log("Loaded config: " + JSON.stringify(this.config), "debug");
// Ensure default flag values
if (!this.config.settings) {
this.config.settings = {};
}
if (this.config.settings.has_auto_populated_workers === undefined) {
this.config.settings.has_auto_populated_workers = false;
}
// Load stored master CUDA device
this.masterCudaDevice = this.config?.master?.cuda_device ?? undefined;
// Sync to state
if (this.config.workers) {
this.config.workers.forEach(w => {
this.state.updateWorker(w.id, { enabled: w.enabled });
});
}
this._emitWorkersChanged();
} catch (error) {
this.log("Failed to load config: " + error.message, "error");
this.config = { workers: [], settings: { has_auto_populated_workers: false } };
}
}
_emitWorkersChanged() {
if (typeof window === "undefined" || typeof window.dispatchEvent !== "function") {
return;
}
window.dispatchEvent(new CustomEvent(WORKERS_CHANGED_EVENT, {
detail: { workers: this.config?.workers || [] },
}));
}
_applyMasterHost(host) {
if (!host || !this.config) return;
if (!this.config.master) this.config.master = {};
this.config.master.host = host;
const hostInput = document.getElementById('master-host');
if (hostInput) {
hostInput.value = host;
}
}
_parseHostInput(value) {
return parseHostInput(value);
}
updateTunnelUIElements(isRunning, isStarting) {
return updateTunnelUIElements(this, isRunning, isStarting);
}
async refreshTunnelStatus() {
return refreshTunnelStatus(this);
}
async handleTunnelToggle(button) {
return handleTunnelToggle(this, button);
}
async updateWorkerEnabled(workerId, enabled) {
const worker = this.config.workers.find(w => w.id === workerId);
if (worker) {
worker.enabled = enabled;
this.state.updateWorker(workerId, { enabled });
this._emitWorkersChanged();
// Immediately update status dot based on enabled state
const statusDot = document.getElementById(`status-${workerId}`);
if (statusDot) {
if (enabled) {
// Enabled: show a red "Checking status..." dot until the probe completes
this.ui.updateStatusDot(workerId, STATUS_COLORS.OFFLINE_RED, "Checking status...", false);
setTimeout(() => checkWorkerStatus(this, worker), TIMEOUTS.STATUS_CHECK_DELAY);
} else {
// Disabled: Set to gray
this.ui.updateStatusDot(workerId, STATUS_COLORS.DISABLED_GRAY, "Disabled", false);
}
}
}
try {
await this.api.updateWorker(workerId, { enabled });
} catch (error) {
this.log("Error updating worker: " + error.message, "error");
}
if (this.panelElement) {
await renderSidebarContent(this, this.panelElement);
}
}
async _updateSetting(key, value) {
// Update local config
if (!this.config.settings) {
this.config.settings = {};
}
this.config.settings[key] = value;
try {
await this.api.updateSetting(key, value);
const prettyKey = key.replace(/_/g, ' ').replace(/\b\w/g, l => l.toUpperCase());
let detail;
if (key === 'worker_timeout_seconds') {
const secs = parseInt(value, 10);
detail = `Worker Timeout set to ${Number.isFinite(secs) ? secs : value}s`;
} else if (typeof value === 'boolean') {
detail = `${prettyKey} ${value ? 'enabled' : 'disabled'}`;
} else {
detail = `${prettyKey} set to ${value}`;
}
app.extensionManager.toast.add({
severity: "success",
summary: "Setting Updated",
detail,
life: 2000
});
} catch (error) {
this.log(`Error updating setting '${key}': ${error.message}`, "error");
app.extensionManager.toast.add({
severity: "error",
summary: "Setting Update Failed",
detail: error.message,
life: 3000
});
}
}
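// Toast wording examples (illustrative): _updateSetting("debug", true) shows
// "Debug enabled"; _updateSetting("worker_timeout_seconds", 90) shows
// "Worker Timeout set to 90s"; other values show "<Pretty Key> set to <value>".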
// --- UI Rendering ---
registerSidebarTab() {
app.extensionManager.registerSidebarTab({
id: "distributed",
icon: "pi pi-server",
title: "Distributed",
tooltip: "Distributed Control Panel",
type: "custom",
render: (el) => {
this.panelElement = el;
this.onPanelOpen();
renderSidebarContent(this, el);
this._applyNodes2Style();
this._applyThemeToneClass();
},
destroy: () => {
this.onPanelClose();
}
});
}
onPanelOpen() {
this.log("Panel opened - starting status polling", "debug");
if (!this.statusCheckTimeout) {
checkAllWorkerStatuses(this);
}
this._startThemeObserver();
this._applyThemeToneClass();
}
onPanelClose() {
this.log("Panel closed - stopping status polling", "debug");
// Cancel any pending status checks
if (this.statusCheckAbortController) {
this.statusCheckAbortController.abort();
this.statusCheckAbortController = null;
}
// Clear the timeout
if (this.statusCheckTimeout) {
clearTimeout(this.statusCheckTimeout);
this.statusCheckTimeout = null;
}
this._stopThemeObserver();
this.panelElement = null;
}
_applyNodes2Style() {
if (!this.panelElement) return;
const enabled = app.ui.settings.getSettingValue("Comfy.VueNodes.Enabled") ?? false;
this.panelElement.classList.toggle('distributed-panel--nodes2', Boolean(enabled));
}
_parseColorToRgba(colorValue) {
if (!colorValue || typeof colorValue !== "string") {
return null;
}
const color = colorValue.trim().toLowerCase();
if (!color || color === "transparent") {
return null;
}
const rgbMatch = color.match(/^rgba?\(([^)]+)\)$/);
if (rgbMatch) {
const parts = rgbMatch[1].split(",").map((part) => Number(part.trim()));
if (parts.length < 3 || parts.slice(0, 3).some((part) => Number.isNaN(part))) {
return null;
}
const alpha = parts.length >= 4 && Number.isFinite(parts[3]) ? parts[3] : 1;
return {
r: Math.max(0, Math.min(255, parts[0])),
g: Math.max(0, Math.min(255, parts[1])),
b: Math.max(0, Math.min(255, parts[2])),
a: Math.max(0, Math.min(1, alpha)),
};
}
const hexMatch = color.match(/^#([0-9a-f]{3}|[0-9a-f]{6})$/i);
if (hexMatch) {
const value = hexMatch[1];
const expanded = value.length === 3
? value.split("").map((c) => `${c}${c}`).join("")
: value;
const r = parseInt(expanded.slice(0, 2), 16);
const g = parseInt(expanded.slice(2, 4), 16);
const b = parseInt(expanded.slice(4, 6), 16);
return { r, g, b, a: 1 };
}
return null;
}
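// Illustrative results, matching the parsing rules above:
//   this._parseColorToRgba("#fff")            -> { r: 255, g: 255, b: 255, a: 1 }
//   this._parseColorToRgba("rgb(16, 32, 48)") -> { r: 16, g: 32, b: 48, a: 1 }
//   this._parseColorToRgba("rgba(0,0,0,0.5)") -> { r: 0, g: 0, b: 0, a: 0.5 }
//   this._parseColorToRgba("transparent")     -> null (no usable background)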
_isPanelLightTheme() {
const fallbackLight = window.matchMedia?.("(prefers-color-scheme: light)")?.matches || false;
if (!this.panelElement) {
return fallbackLight;
}
let current = this.panelElement;
while (current) {
const bg = getComputedStyle(current).backgroundColor;
const rgba = this._parseColorToRgba(bg);
if (rgba && rgba.a > 0.02) {
// Relative luminance approximation (0..1)
const luminance = (0.2126 * rgba.r + 0.7152 * rgba.g + 0.0722 * rgba.b) / 255;
return luminance > 0.58;
}
current = current.parentElement;
}
return fallbackLight;
}
_applyThemeToneClass() {
if (!this.panelElement) {
return;
}
this.panelElement.classList.toggle("distributed-panel--light", this._isPanelLightTheme());
}
_startThemeObserver() {
if (this.themeMutationObserver) {
return;
}
this.themeMutationObserver = new MutationObserver(() => {
this._applyThemeToneClass();
});
this.themeMutationObserver.observe(document.documentElement, {
attributes: true,
attributeFilter: ["class", "style"],
});
if (document.body) {
this.themeMutationObserver.observe(document.body, {
attributes: true,
attributeFilter: ["class", "style"],
});
}
}
_stopThemeObserver() {
if (!this.themeMutationObserver) {
return;
}
this.themeMutationObserver.disconnect();
this.themeMutationObserver = null;
}
_setupNodes2Listener() {
app.ui.settings.addEventListener("Comfy.VueNodes.Enabled.change", (e) => {
const enabled = e.detail?.value ?? false;
if (this.panelElement) {
this.panelElement.classList.toggle('distributed-panel--nodes2', Boolean(enabled));
this._applyThemeToneClass();
}
});
}
// --- Core Logic & Execution ---
setupInterceptor() {
setupInterceptor(this);
}
updateWorkerCard(workerId, newStatus) {
return updateWorkerCard(this, workerId, newStatus);
}
/**
* Cleanup method to stop intervals and listeners
*/
cleanup() {
if (this.logAutoRefreshInterval) {
clearInterval(this.logAutoRefreshInterval);
this.logAutoRefreshInterval = null;
}
if (this.statusCheckTimeout) {
clearTimeout(this.statusCheckTimeout);
this.statusCheckTimeout = null;
}
this.log("Cleaned up intervals", "debug");
}
getMasterUrl() {
return buildMasterUrl(this.config, window.location, (message, level) => this.log(message, level));
}
async detectMasterIP() {
return detectMasterIP(this);
}
_handleInterruptWorkers(button) {
return handleInterruptWorkers(this, button);
}
_handleClearMemory(button) {
return handleClearMemory(this, button);
}
}
app.registerExtension({
name: "Distributed.Panel",
async setup() {
new DistributedExtension();
}
});
================================================
FILE: web/masterDetection.js
================================================
import { generateUUID } from './constants.js';
export async function detectMasterIP(extension) {
try {
const isRunpod = window.location.hostname.endsWith('.proxy.runpod.net');
if (isRunpod) {
extension.log("Detected Runpod environment", "info");
}
const data = await extension.api.getNetworkInfo();
extension.log("Network info: " + JSON.stringify(data), "debug");
if (data.cuda_device !== null && data.cuda_device !== undefined) {
extension.masterCudaDevice = data.cuda_device;
if (!extension.config.master) {
extension.config.master = {};
}
if (extension.config.master.cuda_device !== data.cuda_device) {
extension.config.master.cuda_device = data.cuda_device;
try {
await extension.api.updateMaster({ cuda_device: data.cuda_device });
extension.log(`Stored master CUDA device: ${data.cuda_device}`, "debug");
} catch (error) {
extension.log(`Error storing master CUDA device: ${error.message}`, "error");
}
}
extension.ui.updateMasterDisplay(extension);
}
if (data.cuda_device_count > 0) {
extension.cudaDeviceCount = data.cuda_device_count;
extension.log(`Detected ${extension.cudaDeviceCount} CUDA devices`, "info");
const shouldAutoPopulate =
!extension.config.settings.has_auto_populated_workers &&
(!extension.config.workers || extension.config.workers.length === 0);
extension.log(`Auto-population check: has_populated=${extension.config.settings.has_auto_populated_workers}, workers=${extension.config.workers ? extension.config.workers.length : 'null'}, should_populate=${shouldAutoPopulate}`, "debug");
if (shouldAutoPopulate) {
extension.log(`Auto-populating workers based on ${extension.cudaDeviceCount} CUDA devices (excluding master on CUDA ${extension.masterCudaDevice})`, "info");
const newWorkers = [];
let workerNum = 1;
let portOffset = 0;
for (let i = 0; i < extension.cudaDeviceCount; i++) {
if (i === extension.masterCudaDevice) {
extension.log(`Skipping CUDA ${i} (used by master)`, "debug");
continue;
}
const worker = {
id: generateUUID(),
name: `Worker ${workerNum}`,
host: isRunpod ? null : "localhost",
port: 8189 + portOffset,
cuda_device: i,
enabled: true,
extra_args: isRunpod ? "--listen" : "",
};
newWorkers.push(worker);
workerNum += 1;
portOffset += 1;
}
if (newWorkers.length > 0) {
extension.log(`Auto-populating ${newWorkers.length} workers`, "info");
extension.config.workers = newWorkers;
extension.config.settings.has_auto_populated_workers = true;
for (const worker of newWorkers) {
try {
await extension.api.updateWorker(worker.id, worker);
} catch (error) {
extension.log(`Error saving worker ${worker.name}: ${error.message}`, "error");
}
}
try {
await extension.api.updateSetting('has_auto_populated_workers', true);
} catch (error) {
extension.log(`Error saving auto-population flag: ${error.message}`, "error");
}
extension.log(`Auto-populated ${newWorkers.length} workers and saved config`, "info");
if (extension.app.extensionManager?.toast) {
extension.app.extensionManager.toast.add({
severity: "success",
summary: "Workers Auto-populated",
detail: `Automatically created ${newWorkers.length} workers based on detected CUDA devices`,
life: 5000,
});
}
await extension.loadConfig();
} else {
extension.log("No additional CUDA devices available for workers (all used by master)", "debug");
}
}
}
if (extension.config?.master?.host) {
extension.log(`Master host already configured: ${extension.config.master.host}`, "debug");
return;
}
if (isRunpod) {
const runpodHost = window.location.hostname;
extension.log(`Setting Runpod master host: ${runpodHost}`, "info");
await extension.api.updateMaster({ host: runpodHost });
if (!extension.config.master) {
extension.config.master = {};
}
extension.config.master.host = runpodHost;
if (extension.app.extensionManager?.toast) {
extension.app.extensionManager.toast.add({
severity: "info",
summary: "Runpod Auto-Configuration",
detail: `Master host set to ${runpodHost} with --listen flag for workers`,
life: 5000,
});
}
return;
}
if (data.recommended_ip && data.recommended_ip !== '127.0.0.1') {
extension.log(`Auto-detected master IP: ${data.recommended_ip}`, "info");
await extension.api.updateMaster({ host: data.recommended_ip });
if (!extension.config.master) {
extension.config.master = {};
}
extension.config.master.host = data.recommended_ip;
}
} catch (error) {
extension.log("Error detecting master IP: " + error.message, "error");
}
}
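// Example of an auto-populated local config on a 3-GPU machine with the master
// on CUDA 0 (illustrative; ids are generated UUIDs, ports start at 8189):
//   workers: [
//     { name: "Worker 1", host: "localhost", port: 8189, cuda_device: 1, enabled: true, extra_args: "" },
//     { name: "Worker 2", host: "localhost", port: 8190, cuda_device: 2, enabled: true, extra_args: "" },
//   ]
// On Runpod, host is null and extra_args is "--listen" instead.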
================================================
FILE: web/sidebar/actionsSection.js
================================================
import { BUTTON_STYLES } from "../constants.js";
export function renderActionsSection(extension) {
const actionsSection = document.createElement("div");
actionsSection.style.cssText =
"padding-top: 10px; margin-bottom: 15px; border-top: 1px solid var(--dist-divider, #444);";
const buttonRow = document.createElement("div");
buttonRow.style.cssText = "display: flex; gap: 8px;";
const clearMemButton = extension.ui.createButtonHelper(
"Clear Worker VRAM",
(event) => extension._handleClearMemory(event.target),
BUTTON_STYLES.clearMemory
);
clearMemButton.title = "Clear VRAM on all enabled worker GPUs (not master)";
clearMemButton.style.cssText = BUTTON_STYLES.base + " flex: 1;" + BUTTON_STYLES.clearMemory;
const interruptButton = extension.ui.createButtonHelper(
"Interrupt Workers",
(event) => extension._handleInterruptWorkers(event.target),
BUTTON_STYLES.interrupt
);
interruptButton.title = "Cancel/interrupt execution on all enabled worker GPUs";
interruptButton.style.cssText = BUTTON_STYLES.base + " flex: 1;" + BUTTON_STYLES.interrupt;
buttonRow.appendChild(clearMemButton);
buttonRow.appendChild(interruptButton);
actionsSection.appendChild(buttonRow);
return actionsSection;
}
================================================
FILE: web/sidebar/settingsSection.js
================================================
import { createCheckboxSetting, createNumberSetting } from "../ui/buttonHelpers.js";
export function renderSettingsSection(extension) {
const settingsSection = document.createElement("div");
settingsSection.style.cssText = "border-top: 1px solid var(--dist-divider, #444); margin-bottom: 10px;";
const settingsToggleArea = document.createElement("div");
settingsToggleArea.style.cssText = "padding: 16.5px 0; cursor: pointer; user-select: none;";
const settingsHeader = document.createElement("div");
settingsHeader.style.cssText = "display: flex; align-items: center; justify-content: space-between;";
const workerSettingsTitle = document.createElement("h4");
workerSettingsTitle.textContent = "Settings";
workerSettingsTitle.style.cssText = "margin: 0; font-size: 14px;";
const workerSettingsToggle = document.createElement("span");
workerSettingsToggle.textContent = "▶";
workerSettingsToggle.style.cssText =
"font-size: 12px; color: var(--dist-settings-arrow, #888); transition: all 0.2s ease;";
settingsHeader.appendChild(workerSettingsTitle);
settingsHeader.appendChild(workerSettingsToggle);
settingsToggleArea.appendChild(settingsHeader);
settingsToggleArea.onmouseover = () => {
workerSettingsToggle.style.color = "var(--dist-settings-arrow-hover, #fff)";
};
settingsToggleArea.onmouseout = () => {
workerSettingsToggle.style.color = "var(--dist-settings-arrow, #888)";
};
const settingsSeparator = document.createElement("div");
settingsSeparator.style.cssText = "border-bottom: 1px solid var(--dist-divider, #444); margin: 0;";
const settingsContent = document.createElement("div");
settingsContent.style.cssText =
"max-height: 0; overflow: hidden; opacity: 0; transition: max-height 0.3s ease, opacity 0.3s ease;";
const settingsDiv = document.createElement("div");
settingsDiv.style.cssText =
"display: grid; grid-template-columns: 1fr auto; row-gap: 10px; column-gap: 10px; padding-top: 10px; align-items: center;";
let settingsExpanded = false;
settingsToggleArea.onclick = () => {
settingsExpanded = !settingsExpanded;
if (settingsExpanded) {
settingsContent.style.maxHeight = "200px";
settingsContent.style.opacity = "1";
workerSettingsToggle.style.transform = "rotate(90deg)";
settingsSeparator.style.display = "none";
} else {
settingsContent.style.maxHeight = "0";
settingsContent.style.opacity = "0";
workerSettingsToggle.style.transform = "rotate(0deg)";
settingsSeparator.style.display = "block";
}
};
const generalLabel = document.createElement("div");
generalLabel.textContent = "GENERAL";
generalLabel.style.cssText =
"grid-column: 1 / span 2; font-size: 11px; color: var(--dist-muted-text, #888); letter-spacing: 0.06em; padding-top: 2px;";
const timeoutsLabel = document.createElement("div");
timeoutsLabel.textContent = "TIMEOUTS";
timeoutsLabel.style.cssText =
"grid-column: 1 / span 2; font-size: 11px; color: var(--dist-muted-text, #888); letter-spacing: 0.06em; padding-top: 4px;";
settingsDiv.appendChild(generalLabel);
settingsDiv.appendChild(
createCheckboxSetting(
"setting-debug",
"Debug Mode",
"Enable verbose logging in the browser console.",
extension.config?.settings?.debug || false,
(event) => extension._updateSetting("debug", event.target.checked)
)
);
settingsDiv.appendChild(
createCheckboxSetting(
"setting-auto-launch",
"Auto-launch Local Workers on Startup",
"Start local worker processes automatically when the master starts.",
extension.config?.settings?.auto_launch_workers || false,
(event) => extension._updateSetting("auto_launch_workers", event.target.checked)
)
);
settingsDiv.appendChild(
createCheckboxSetting(
"setting-stop-on-exit",
"Stop Local Workers on Master Exit",
"Stop local worker processes automatically when the master exits.",
extension.config?.settings?.stop_workers_on_master_exit !== false,
(event) => extension._updateSetting("stop_workers_on_master_exit", event.target.checked)
)
);
settingsDiv.appendChild(timeoutsLabel);
settingsDiv.appendChild(
createNumberSetting(
"setting-worker-timeout",
"Worker Timeout",
"Seconds without a heartbeat before a worker is considered timed out. Default 60.",
extension.config?.settings?.worker_timeout_seconds ?? 60,
10,
1,
(event) => {
const value = parseInt(event.target.value, 10);
if (!Number.isFinite(value) || value <= 0) {
return;
}
extension._updateSetting("worker_timeout_seconds", value);
}
)
);
settingsContent.appendChild(settingsDiv);
settingsSection.appendChild(settingsToggleArea);
settingsSection.appendChild(settingsSeparator);
settingsSection.appendChild(settingsContent);
return settingsSection;
}
================================================
FILE: web/sidebar/workersSection.js
================================================
import { addNewWorker } from "../workerSettings.js";
export function renderWorkersSection(extension) {
const workersSection = document.createElement("div");
workersSection.style.cssText = "flex: 1; overflow-y: auto; margin-bottom: 15px;";
const workersList = document.createElement("div");
const workers = extension.config?.workers || [];
if (workers.length === 0) {
const blueprintDiv = extension.ui.renderEntityCard(
"blueprint",
{ onClick: () => addNewWorker(extension) },
extension
);
workersList.appendChild(blueprintDiv);
}
workers.forEach((worker) => {
const workerCard = extension.ui.renderEntityCard("worker", worker, extension);
workersList.appendChild(workerCard);
});
workersSection.appendChild(workersList);
if (workers.length > 0) {
const addWorkerDiv = extension.ui.renderEntityCard(
"add",
{ onClick: () => addNewWorker(extension) },
extension
);
workersSection.appendChild(addWorkerDiv);
}
return workersSection;
}
================================================
FILE: web/sidebarRenderer.js
================================================
import { STATUS_COLORS } from './constants.js';
import { checkAllWorkerStatuses, loadManagedWorkers, updateWorkerControls } from './workerLifecycle.js';
import { renderActionsSection } from './sidebar/actionsSection.js';
import { renderSettingsSection } from './sidebar/settingsSection.js';
import { renderWorkersSection } from './sidebar/workersSection.js';
export function updateWorkerCard(extension, workerId, newStatus = {}) {
const card = document.querySelector(`[data-worker-id="${workerId}"]`);
if (!card) {
return false;
}
const worker = extension.config?.workers?.find((w) => w.id === workerId);
if (!worker) {
return false;
}
const workerState = extension.state.getWorker(workerId);
const isLaunching = Boolean(workerState?.launching);
if (isLaunching && !newStatus.online) {
extension.ui.updateStatusDot(workerId, STATUS_COLORS.PROCESSING_YELLOW, "Launching...", true);
} else if (newStatus.online && newStatus.processing) {
const queue = newStatus.queueCount || 0;
extension.ui.updateStatusDot(workerId, STATUS_COLORS.PROCESSING_YELLOW, `Online - Processing (${queue} in queue)`, false);
} else if (newStatus.online) {
extension.ui.updateStatusDot(workerId, STATUS_COLORS.ONLINE_GREEN, "Online - Idle", false);
} else if (worker.enabled) {
extension.ui.updateStatusDot(workerId, STATUS_COLORS.OFFLINE_RED, "Offline - Cannot connect", false);
}
updateWorkerControls(extension, workerId);
return true;
}
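// Status-dot precedence in updateWorkerCard (derived from the branches above):
// launching (yellow, pulsing) > processing (yellow, with queue count) >
// online-idle (green) > enabled-but-offline (red). Disabled workers are not
// touched here; updateWorkerEnabled greys them out directly.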
export async function renderSidebarContent(extension, el) {
// Panel is being opened/rendered
extension.log("Panel opened", "debug");
if (!el) {
extension.log("No element provided to renderSidebarContent", "debug");
return;
}
// Prevent infinite recursion
if (extension._isRendering) {
extension.log("Already rendering, skipping", "debug");
return;
}
extension._isRendering = true;
try {
// Store reference to the panel element
extension.panelElement = el;
// Show loading indicator
el.innerHTML = '';
const loadingDiv = document.createElement("div");
loadingDiv.style.cssText =
"display: flex; align-items: center; justify-content: center; height: calc(100vh - 100px); color: var(--dist-muted-text, #888);";
loadingDiv.innerHTML = `<svg width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><circle cx="12" cy="12" r="10" opacity="0.25"/><path d="M12 2a10 10 0 0 1 10 10"/></svg>`; // minimal spinner markup (placeholder for the original inline SVG)
el.appendChild(loadingDiv);
// Add rotation animation (injected once, keyed by id, to avoid stacking a new <style> tag on every render)
if (!document.getElementById('distributed-rotate-style')) {
const style = document.createElement('style');
style.id = 'distributed-rotate-style';
style.textContent = `
@keyframes rotate {
from { transform: rotate(0deg); }
to { transform: rotate(360deg); }
}
`;
document.head.appendChild(style);
}
const spinner = loadingDiv.querySelector('svg');
if (spinner) {
spinner.style.animation = 'rotate 1s linear infinite';
}
// Preload data outside render
await Promise.all([extension.loadConfig(), loadManagedWorkers(extension), extension.refreshTunnelStatus()]);
extension.tunnelElements = {};
el.innerHTML = '';
// Create toolbar header to match ComfyUI style
const toolbar = document.createElement("div");
toolbar.className = "p-toolbar p-component border-x-0 border-t-0 rounded-none px-2 py-1 min-h-8";
toolbar.style.cssText =
"border-bottom: 1px solid var(--dist-divider, #444); background: transparent; display: flex; align-items: center;";
const toolbarStart = document.createElement("div");
toolbarStart.className = "p-toolbar-start";
toolbarStart.style.cssText = "display: flex; align-items: center;";
const titleSpan = document.createElement("span");
titleSpan.className = "text-xs 2xl:text-sm truncate";
titleSpan.textContent = "COMFYUI DISTRIBUTED";
titleSpan.title = "ComfyUI Distributed";
toolbarStart.appendChild(titleSpan);
toolbar.appendChild(toolbarStart);
const toolbarCenter = document.createElement("div");
toolbarCenter.className = "p-toolbar-center";
toolbar.appendChild(toolbarCenter);
const toolbarEnd = document.createElement("div");
toolbarEnd.className = "p-toolbar-end";
toolbar.appendChild(toolbarEnd);
el.appendChild(toolbar);
// Main container with adjusted padding
const container = document.createElement("div");
container.style.cssText = "padding: 15px; display: flex; flex-direction: column; height: calc(100% - 32px);";
// Detect master info on panel open (in case CUDA info wasn't available at startup)
extension.log(`Panel opened. CUDA device count: ${extension.cudaDeviceCount}, Workers: ${extension.config?.workers?.length || 0}`, "debug");
if (!extension.cudaDeviceCount) {
await extension.detectMasterIP();
}
// Now render with guaranteed up-to-date config
// Master Node Section
const masterDiv = extension.ui.renderEntityCard('master', extension.config?.master, extension);
container.appendChild(masterDiv);
container.appendChild(renderWorkersSection(extension));
container.appendChild(renderActionsSection(extension));
container.appendChild(renderSettingsSection(extension));
el.appendChild(container);
extension._applyThemeToneClass?.();
// Start checking worker statuses immediately in parallel
setTimeout(() => checkAllWorkerStatuses(extension), 0);
} finally {
// Always reset the rendering flag
extension._isRendering = false;
}
}
================================================
FILE: web/stateManager.js
================================================
export function createStateManager() {
const state = {
workers: new Map(), // Unified worker state: { status, managed, launching, expanded, ... }
masterStatus: 'online',
};
return {
// Worker state management
getWorker(workerId) {
return state.workers.get(String(workerId)) || {};
},
updateWorker(workerId, updates) {
const id = String(workerId);
const current = state.workers.get(id) || {};
state.workers.set(id, { ...current, ...updates });
return state.workers.get(id);
},
setWorkerStatus(workerId, status) {
return this.updateWorker(workerId, { status });
},
setWorkerManaged(workerId, info) {
return this.updateWorker(workerId, { managed: info });
},
setWorkerLaunching(workerId, launching) {
return this.updateWorker(workerId, { launching });
},
setWorkerExpanded(workerId, expanded) {
return this.updateWorker(workerId, { expanded });
},
isWorkerLaunching(workerId) {
return this.getWorker(workerId).launching || false;
},
isWorkerExpanded(workerId) {
return this.getWorker(workerId).expanded || false;
},
isWorkerManaged(workerId) {
return !!this.getWorker(workerId).managed;
},
getWorkerStatus(workerId) {
return this.getWorker(workerId).status || {};
},
// Master state
setMasterStatus(status) {
state.masterStatus = status;
},
getMasterStatus() {
return state.masterStatus;
}
};
}
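// Minimal usage sketch (illustrative, not part of the original file):
//   const state = createStateManager();
//   state.updateWorker("w1", { enabled: true });
//   state.setWorkerLaunching("w1", true);
//   state.isWorkerLaunching("w1"); // -> true
//   state.getWorker("missing");    // -> {} (never undefined)
//   state.setMasterStatus("offline");
//   state.getMasterStatus();       // -> "offline"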
================================================
FILE: web/tests/apiClient.test.js
================================================
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { createApiClient } from "../apiClient.js";
describe("apiClient probeWorker", () => {
let originalFetch;
beforeEach(() => {
originalFetch = globalThis.fetch;
globalThis.fetch = vi.fn();
});
afterEach(() => {
globalThis.fetch = originalFetch;
vi.restoreAllMocks();
});
it("returns ok=true when /prompt returns valid exec_info payload", async () => {
globalThis.fetch.mockResolvedValue({
ok: true,
status: 200,
json: vi.fn().mockResolvedValue({ exec_info: { queue_remaining: 2 } }),
});
const client = createApiClient("http://127.0.0.1:8188");
const result = await client.probeWorker("http://worker.local:8190", 1000);
expect(result).toEqual({ ok: true, status: 200, queueRemaining: 2 });
});
it("returns ok=false on non-200 responses", async () => {
globalThis.fetch.mockResolvedValue({
ok: false,
status: 503,
json: vi.fn(),
});
const client = createApiClient("http://127.0.0.1:8188");
const result = await client.probeWorker("http://worker.local:8190", 1000);
expect(result).toEqual({ ok: false, status: 503, queueRemaining: null });
});
it("returns ok=false when response JSON is invalid", async () => {
globalThis.fetch.mockResolvedValue({
ok: true,
status: 200,
json: vi.fn().mockRejectedValue(new Error("invalid json")),
});
const client = createApiClient("http://127.0.0.1:8188");
const result = await client.probeWorker("http://worker.local:8190", 1000);
expect(result).toEqual({ ok: false, status: 200, queueRemaining: null });
});
it("returns ok=false when exec_info is missing", async () => {
globalThis.fetch.mockResolvedValue({
ok: true,
status: 200,
json: vi.fn().mockResolvedValue({}),
});
const client = createApiClient("http://127.0.0.1:8188");
const result = await client.probeWorker("http://worker.local:8190", 1000);
expect(result).toEqual({ ok: false, status: 200, queueRemaining: null });
});
it("returns ok=false when queue_remaining is not numeric", async () => {
globalThis.fetch.mockResolvedValue({
ok: true,
status: 200,
json: vi.fn().mockResolvedValue({ exec_info: { queue_remaining: "n/a" } }),
});
const client = createApiClient("http://127.0.0.1:8188");
const result = await client.probeWorker("http://worker.local:8190", 1000);
expect(result).toEqual({ ok: false, status: 200, queueRemaining: null });
});
it("clamps negative queue_remaining to zero", async () => {
globalThis.fetch.mockResolvedValue({
ok: true,
status: 200,
json: vi.fn().mockResolvedValue({ exec_info: { queue_remaining: -5 } }),
});
const client = createApiClient("http://127.0.0.1:8188");
const result = await client.probeWorker("http://worker.local:8190", 1000);
expect(result).toEqual({ ok: true, status: 200, queueRemaining: 0 });
});
});
================================================
FILE: web/tests/executionUtils.test.js
================================================
import { describe, expect, it } from "vitest";
import { buildWorkerWebSocketUrl } from "../urlUtils.js";
describe("execution decision helpers", () => {
it("buildWorkerWebSocketUrl converts http/https to ws/wss", () => {
expect(buildWorkerWebSocketUrl("http://worker.local:8188")).toBe(
"ws://worker.local:8188/distributed/worker_ws"
);
expect(buildWorkerWebSocketUrl("https://worker.example.com")).toBe(
"wss://worker.example.com/distributed/worker_ws"
);
});
});
================================================
FILE: web/tests/urlUtils.test.js
================================================
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import {
buildWorkerUrl,
buildWorkerWebSocketUrl,
getMasterUrl,
normalizeWorkerUrl,
parseHostInput,
} from "../urlUtils.js";
// ---------------------------------------------------------------------------
// normalizeWorkerUrl
// ---------------------------------------------------------------------------
describe("normalizeWorkerUrl", () => {
it("returns empty string for empty input", () => {
expect(normalizeWorkerUrl("")).toBe("");
});
it("returns empty string for null input", () => {
expect(normalizeWorkerUrl(null)).toBe("");
});
it("returns empty string for non-string input", () => {
expect(normalizeWorkerUrl(42)).toBe("");
});
it("preserves https protocol", () => {
expect(normalizeWorkerUrl("https://example.com")).toBe("https://example.com");
});
it("strips trailing slash", () => {
expect(normalizeWorkerUrl("http://example.com/")).toBe("http://example.com");
});
it("prepends http when protocol is missing", () => {
const result = normalizeWorkerUrl("worker.local:8188");
expect(result).toMatch(/^http:\/\//);
});
it("trims leading and trailing whitespace", () => {
expect(normalizeWorkerUrl(" http://example.com ")).toBe("http://example.com");
});
});
// ---------------------------------------------------------------------------
// parseHostInput
// ---------------------------------------------------------------------------
describe("parseHostInput", () => {
it("returns empty host and null port for null", () => {
expect(parseHostInput(null)).toEqual({ host: "", port: null });
});
it("returns empty host and null port for empty string", () => {
expect(parseHostInput("")).toEqual({ host: "", port: null });
});
it("parses hostname without port", () => {
const result = parseHostInput("worker.example.com");
expect(result.host).toBe("worker.example.com");
expect(result.port).toBeNull();
});
it("parses hostname with port", () => {
const result = parseHostInput("worker.example.com:9000");
expect(result.host).toBe("worker.example.com");
expect(result.port).toBe(9000);
});
it("strips http:// protocol prefix", () => {
const result = parseHostInput("http://worker.example.com:8188");
expect(result.host).toBe("worker.example.com");
expect(result.port).toBe(8188);
});
it("strips https:// protocol prefix", () => {
const result = parseHostInput("https://worker.example.com");
expect(result.host).toBe("worker.example.com");
expect(result.port).toBeNull();
});
it("ignores path after host:port", () => {
const result = parseHostInput("worker.example.com:8188/some/path");
expect(result.host).toBe("worker.example.com");
expect(result.port).toBe(8188);
});
});
// ---------------------------------------------------------------------------
// buildWorkerWebSocketUrl
// ---------------------------------------------------------------------------
describe("buildWorkerWebSocketUrl", () => {
it("converts http to ws", () => {
expect(buildWorkerWebSocketUrl("http://worker.local:8188")).toBe(
"ws://worker.local:8188/distributed/worker_ws"
);
});
it("converts https to wss", () => {
expect(buildWorkerWebSocketUrl("https://worker.example.com")).toBe(
"wss://worker.example.com/distributed/worker_ws"
);
});
it("always appends /distributed/worker_ws", () => {
const url = buildWorkerWebSocketUrl("http://worker.local:8188");
expect(url.endsWith("/distributed/worker_ws")).toBe(true);
});
});
// ---------------------------------------------------------------------------
// buildWorkerUrl (requires window.location stub)
// ---------------------------------------------------------------------------
describe("buildWorkerUrl", () => {
let originalWindow;
beforeEach(() => {
originalWindow = globalThis.window;
globalThis.window = {
location: {
hostname: "127.0.0.1",
protocol: "http:",
port: "8188",
origin: "http://127.0.0.1:8188",
},
};
});
afterEach(() => {
globalThis.window = originalWindow;
});
it("builds local worker URL using window hostname when no host set", () => {
const worker = { id: "w1", port: 8189 };
expect(buildWorkerUrl(worker, "/prompt")).toBe("http://127.0.0.1:8189/prompt");
});
it("builds remote worker URL using explicit host", () => {
const worker = { id: "w2", host: "worker.example.com", port: 9000 };
expect(buildWorkerUrl(worker, "/prompt")).toBe("http://worker.example.com:9000/prompt");
});
it("builds cloud worker URL with https when type=cloud", () => {
const worker = { id: "w3", host: "cloud.example.com", port: 443, type: "cloud" };
expect(buildWorkerUrl(worker, "/prompt")).toBe("https://cloud.example.com/prompt");
});
it("uses https for port 443 even without type=cloud", () => {
const worker = { id: "w4", host: "worker.example.com", port: 443 };
const result = buildWorkerUrl(worker, "");
expect(result.startsWith("https://")).toBe(true);
});
it("rewrites runpod proxy hostname for local port", () => {
globalThis.window = {
location: {
hostname: "podabc.proxy.runpod.net",
protocol: "https:",
port: "",
origin: "https://podabc.proxy.runpod.net",
},
};
const worker = { id: "w5", port: 8189 };
expect(buildWorkerUrl(worker, "/prompt")).toBe(
"https://podabc-8189.proxy.runpod.net/prompt"
);
});
it("returns URL without trailing slash when no endpoint given", () => {
const worker = { id: "w6", host: "worker.example.com", port: 8188 };
const result = buildWorkerUrl(worker, "");
expect(result.endsWith("/")).toBe(false);
});
});
// ---------------------------------------------------------------------------
// getMasterUrl
// ---------------------------------------------------------------------------
describe("getMasterUrl", () => {
const _loc = (hostname, protocol = "http:", port = "8188") => ({
hostname,
protocol,
port,
origin: `${protocol}//${hostname}${port ? `:${port}` : ""}`,
});
it("returns origin when master host not configured and hostname is non-localhost", () => {
const loc = _loc("192.168.1.10");
const result = getMasterUrl({}, loc);
expect(result).toBe(loc.origin);
});
it("returns origin for localhost when master host not configured", () => {
const loc = _loc("127.0.0.1");
const result = getMasterUrl({}, loc);
expect(result).toBe(loc.origin);
});
it("uses configured master host as-is when it includes http://", () => {
const config = { master: { host: "http://master.example.com" } };
const result = getMasterUrl(config, _loc("127.0.0.1"));
expect(result).toBe("http://master.example.com");
});
it("uses configured master host as-is when it includes https://", () => {
const config = { master: { host: "https://secure.master.com" } };
const result = getMasterUrl(config, _loc("127.0.0.1"));
expect(result).toBe("https://secure.master.com");
});
it("defaults to https for domain-name master hosts", () => {
const config = { master: { host: "master.example.com" } };
const result = getMasterUrl(config, _loc("127.0.0.1"));
expect(result).toBe("https://master.example.com");
});
it("does not force https for IP-address master hosts", () => {
const config = { master: { host: "192.168.1.100" } };
const result = getMasterUrl(config, _loc("127.0.0.1"));
expect(result.startsWith("https://")).toBe(false);
});
it("does not force https for localhost master host", () => {
const config = { master: { host: "localhost" } };
const result = getMasterUrl(config, _loc("127.0.0.1"));
expect(result.startsWith("https://")).toBe(false);
});
it("accepts null log parameter without throwing", () => {
const loc = _loc("127.0.0.1");
expect(() => getMasterUrl({}, loc, null)).not.toThrow();
});
});
================================================
FILE: web/tests/workerLifecycle.test.js
================================================
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import { getWorkerUrl } from "../workerLifecycle.js";
describe("workerLifecycle URL construction", () => {
let originalWindow;
beforeEach(() => {
originalWindow = globalThis.window;
globalThis.window = {
location: {
hostname: "127.0.0.1",
protocol: "http:",
port: "8190",
origin: "http://127.0.0.1:8190",
},
};
});
afterEach(() => {
globalThis.window = originalWindow;
});
it("builds local worker URL with explicit local port", () => {
const worker = { id: "w1", port: 8189, type: "local" };
expect(getWorkerUrl({}, worker, "/prompt")).toBe("http://127.0.0.1:8189/prompt");
});
it("builds remote worker URL with host:port", () => {
const worker = { id: "w2", host: "worker.example.com", port: 9000, type: "remote" };
expect(getWorkerUrl({}, worker, "/prompt")).toBe("http://worker.example.com:9000/prompt");
});
it("builds cloud worker URL as https", () => {
const worker = { id: "w3", host: "cloud.example.com", port: 443, type: "cloud" };
expect(getWorkerUrl({}, worker, "/prompt")).toBe("https://cloud.example.com/prompt");
});
it("rewrites runpod proxy hostname for local worker ports", () => {
globalThis.window = {
location: {
hostname: "podabc.proxy.runpod.net",
protocol: "https:",
port: "",
origin: "https://podabc.proxy.runpod.net",
},
};
const worker = { id: "w4", port: 8189, type: "local" };
expect(getWorkerUrl({}, worker, "/prompt")).toBe("https://podabc-8189.proxy.runpod.net/prompt");
});
});
================================================
FILE: web/tests/workerSettings.test.js
================================================
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { addNewWorker, isRemoteWorker } from "../workerSettings.js";
describe("workerSettings remote classification", () => {
let originalWindow;
beforeEach(() => {
originalWindow = globalThis.window;
globalThis.window = {
location: {
hostname: "127.0.0.1",
},
};
});
afterEach(() => {
globalThis.window = originalWindow;
});
it("treats explicit local worker type as local even with non-local host", () => {
const worker = { type: "local", host: "192.168.1.50" };
expect(isRemoteWorker({}, worker)).toBe(false);
});
it("treats explicit remote worker type as remote", () => {
const worker = { type: "remote", host: "127.0.0.1" };
expect(isRemoteWorker({}, worker)).toBe(true);
});
it("treats cloud worker type as remote", () => {
const worker = { type: "cloud", host: "worker.example.com" };
expect(isRemoteWorker({}, worker)).toBe(true);
});
it("falls back to host heuristic for legacy workers", () => {
expect(isRemoteWorker({}, { host: "127.0.0.1" })).toBe(false);
expect(isRemoteWorker({}, { host: "worker.example.com" })).toBe(true);
});
});
describe("addNewWorker GPU availability guard", () => {
it("falls back to a disabled remote worker when no local CUDA device is available", async () => {
const toastAdd = vi.fn();
const updateWorker = vi.fn().mockResolvedValue({});
const stateUpdateWorker = vi.fn();
const extension = {
cudaDeviceCount: 1,
masterCudaDevice: 0,
panelElement: null,
config: {
workers: [],
master: { cuda_device: 0 },
},
api: { updateWorker },
state: { updateWorker: stateUpdateWorker, setWorkerExpanded: vi.fn() },
app: { extensionManager: { toast: { add: toastAdd } } },
};
await addNewWorker(extension);
expect(updateWorker).toHaveBeenCalledTimes(1);
expect(updateWorker.mock.calls[0][1]).toEqual(
expect.objectContaining({
type: "remote",
enabled: false,
cuda_device: null,
host: "",
})
);
expect(extension.config.workers).toHaveLength(1);
expect(extension.config.workers[0]).toEqual(
expect.objectContaining({
type: "remote",
enabled: false,
cuda_device: null,
host: "",
})
);
expect(stateUpdateWorker).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({ enabled: false })
);
expect(toastAdd).toHaveBeenCalledWith(
expect.objectContaining({
severity: "warn",
summary: "Remote Worker Added",
})
);
});
it("assigns the first free local CUDA device when adding a worker", async () => {
const toastAdd = vi.fn();
const updateWorker = vi.fn().mockResolvedValue({});
const stateUpdateWorker = vi.fn();
const setWorkerExpanded = vi.fn();
const extension = {
cudaDeviceCount: 3,
masterCudaDevice: 0,
panelElement: null,
config: {
workers: [
{ id: "w-existing", type: "local", port: 8189, cuda_device: 1, enabled: true },
],
master: { cuda_device: 0 },
},
api: { updateWorker },
state: { updateWorker: stateUpdateWorker, setWorkerExpanded },
app: { extensionManager: { toast: { add: toastAdd } } },
};
await addNewWorker(extension);
expect(updateWorker).toHaveBeenCalledTimes(1);
expect(updateWorker.mock.calls[0][1]).toEqual(
expect.objectContaining({
cuda_device: 2,
})
);
});
});
================================================
FILE: web/tunnelManager.js
================================================
export function updateTunnelUIElements(extension, isRunning, isStarting) {
// Status is derived from extension.tunnelStatus; the boolean parameters are
// retained only for call-site compatibility.
void isRunning;
void isStarting;
const elements = extension.tunnelElements || {};
const status = (extension.tunnelStatus?.status || "stopped").toLowerCase();
const tunnelButtonColorClasses = ["tunnel-button--enable", "tunnel-button--disable"];
const tunnelStatusColorClasses = ["tunnel-status--enable", "tunnel-status--disable"];
if (elements.button) {
elements.button.disabled = status === "starting" || status === "stopping";
elements.button.classList.remove(...tunnelButtonColorClasses);
if (status === "starting") {
elements.button.textContent = "Starting...";
elements.button.classList.add("tunnel-button--enable");
} else if (status === "stopping") {
elements.button.textContent = "Stopping...";
elements.button.classList.add("tunnel-button--disable");
} else if (status === "running") {
elements.button.textContent = "Disable Cloudflare Tunnel";
elements.button.classList.add("tunnel-button--disable");
} else if (status === "error") {
elements.button.textContent = "Retry Cloudflare Tunnel";
elements.button.classList.add("tunnel-button--disable");
} else {
elements.button.textContent = "Enable Cloudflare Tunnel";
elements.button.classList.add("tunnel-button--enable");
}
}
if (elements.status) {
elements.status.textContent = status.toUpperCase();
elements.status.classList.remove(...tunnelStatusColorClasses);
if (status === "running" || status === "error" || status === "stopping") {
elements.status.classList.add("tunnel-status--disable");
} else {
elements.status.classList.add("tunnel-status--enable");
}
}
if (elements.url) {
const url = extension.tunnelStatus?.public_url;
if (url) {
elements.url.textContent = url;
} else {
elements.url.textContent = status === "starting" ? "Requesting public URL..." : "No tunnel active";
}
}
if (elements.copyBtn) {
const hasUrl = Boolean(extension.tunnelStatus?.public_url);
elements.copyBtn.disabled = !hasUrl;
elements.copyBtn.style.opacity = hasUrl ? "1" : "0.5";
}
}
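// Button state by tunnel status (derived from the branches above):
//   starting -> "Starting...", disabled, enable styling
//   stopping -> "Stopping...", disabled, disable styling
//   running  -> "Disable Cloudflare Tunnel"
//   error    -> "Retry Cloudflare Tunnel"
//   stopped  -> "Enable Cloudflare Tunnel" (default)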
export async function refreshTunnelStatus(extension) {
try {
const data = await extension.api.getTunnelStatus();
extension.tunnelStatus = data.tunnel || { status: "stopped" };
if (data.master_host !== undefined) {
extension._applyMasterHost(data.master_host);
}
return extension.tunnelStatus;
} catch (error) {
extension.tunnelStatus = { status: "error", last_error: error.message };
extension.log("Failed to fetch tunnel status: " + error.message, "error");
return extension.tunnelStatus;
} finally {
updateTunnelUIElements(extension);
}
}
export async function handleTunnelToggle(extension, button) {
const currentStatus = (extension.tunnelStatus?.status || "stopped").toLowerCase();
if (currentStatus === "starting" || currentStatus === "stopping") {
return;
}
const setStatus = (status) => {
extension.tunnelStatus = { ...(extension.tunnelStatus || {}), status };
updateTunnelUIElements(extension);
};
if (currentStatus === "running") {
setStatus("stopping");
try {
if (button) {
button.textContent = "Stopping...";
button.disabled = true;
}
const data = await extension.api.stopTunnel();
extension.tunnelStatus = data.tunnel || { status: "stopped" };
if (data.master_host !== undefined) {
extension._applyMasterHost(data.master_host);
}
updateTunnelUIElements(extension);
extension.ui.showToast(extension.app, "info", "Cloudflare Tunnel Disabled", "Master address restored", 4000);
} catch (error) {
extension.tunnelStatus = { status: "error", last_error: error.message };
updateTunnelUIElements(extension);
extension.ui.showToast(extension.app, "error", "Failed to stop tunnel", error.message, 5000);
} finally {
if (button) {
button.disabled = false;
}
}
return;
}
// Start tunnel
setStatus("starting");
if (button) {
button.textContent = "Starting...";
button.disabled = true;
}
try {
const data = await extension.api.startTunnel();
extension.tunnelStatus = data.tunnel || { status: "running" };
if (data.master_host !== undefined) {
extension._applyMasterHost(data.master_host);
}
updateTunnelUIElements(extension);
const url = data.tunnel?.public_url || data.master_host;
extension.ui.showToast(extension.app, "success", "Cloudflare Tunnel Ready", url || "Public URL created", 4500);
} catch (error) {
extension.tunnelStatus = { status: "error", last_error: error.message };
updateTunnelUIElements(extension);
extension.ui.showToast(extension.app, "error", "Failed to start tunnel", error.message, 5000);
} finally {
if (button) {
button.disabled = false;
}
}
}
================================================
FILE: web/ui/buttonHelpers.js
================================================
export function createButtonHelper(ui, text, onClick, style) {
return ui.createButton(text, onClick, style);
}
export function createCheckboxSetting(id, label, tooltip, checked, onChange) {
const group = document.createElement("div");
group.style.cssText = "grid-column: 1 / span 2; display: flex; align-items: center; gap: 8px;";
const checkbox = document.createElement("input");
checkbox.type = "checkbox";
checkbox.id = id;
checkbox.checked = checked;
checkbox.onchange = onChange;
const lbl = document.createElement("label");
lbl.htmlFor = id;
lbl.textContent = label;
lbl.style.cssText = "font-size: 12px; color: var(--dist-label-text, #ccc); cursor: pointer;";
if (tooltip) {
lbl.title = tooltip;
}
group.appendChild(checkbox);
group.appendChild(lbl);
return group;
}
export function createNumberSetting(id, label, tooltip, value, min, step, onChange) {
const group = document.createElement("div");
group.style.cssText = "grid-column: 1 / span 2; display: flex; align-items: center; gap: 6px;";
const lbl = document.createElement("label");
lbl.htmlFor = id;
lbl.textContent = label;
lbl.style.cssText = "font-size: 12px; color: var(--dist-label-text, #ccc);";
if (tooltip) {
lbl.title = tooltip;
}
const input = document.createElement("input");
input.type = "number";
input.id = id;
input.min = String(min);
input.step = String(step);
input.style.cssText =
"width: 80px; padding: 2px 6px; background: var(--dist-input-bg, #222); color: var(--dist-input-text, #ddd); border: 1px solid var(--dist-input-border, #333); border-radius: 3px;";
input.value = value;
input.onchange = onChange;
group.appendChild(lbl);
group.appendChild(input);
return group;
}
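// Usage sketch (illustrative): both helpers return a <div> that spans both
// columns of the settings grid and can be appended directly.
//   const debugRow = createCheckboxSetting(
//     "setting-debug", "Debug Mode", "Verbose console logging", false,
//     (e) => console.log("debug ->", e.target.checked)
//   );
//   const timeoutRow = createNumberSetting(
//     "setting-timeout", "Worker Timeout", "Seconds", 60, 10, 1,
//     (e) => console.log("timeout ->", e.target.value)
//   );
//   settingsDiv.appendChild(debugRow);
//   settingsDiv.appendChild(timeoutRow);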
================================================
FILE: web/ui/cloudflareWarning.js
================================================
export function showCloudflareWarning(extension, masterHost) {
const existingBanner = document.getElementById('cloudflare-warning-banner');
if (existingBanner) {
existingBanner.remove();
}
const banner = document.createElement('div');
banner.id = 'cloudflare-warning-banner';
banner.style.cssText = `
position: fixed;
top: 0;
left: 0;
right: 0;
background: #ff9800;
color: #333;
padding: 8px 16px;
text-align: center;
z-index: 10000;
display: flex;
align-items: center;
justify-content: center;
gap: 16px;
box-shadow: 0 2px 5px rgba(0,0,0,0.2);
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
`;
const messageSpan = document.createElement('span');
messageSpan.textContent = `Connection issue: Master address ${masterHost} is not reachable. The Cloudflare tunnel may be offline.`;
messageSpan.style.fontSize = '13px';
const resetButton = document.createElement('button');
resetButton.textContent = 'Reset Master Address';
resetButton.style.cssText = `
background: #333;
color: white;
border: none;
padding: 6px 14px;
border-radius: 4px;
cursor: pointer;
font-weight: 500;
font-size: 13px;
transition: background 0.2s;
`;
resetButton.onmouseover = () => {
resetButton.style.background = '#555';
};
resetButton.onmouseout = () => {
resetButton.style.background = '#333';
};
const dismissButton = document.createElement('button');
dismissButton.textContent = 'Dismiss';
dismissButton.style.cssText = `
background: transparent;
color: #333;
border: 1px solid #333;
padding: 6px 14px;
border-radius: 4px;
cursor: pointer;
font-size: 13px;
transition: opacity 0.2s;
`;
dismissButton.onmouseover = () => {
dismissButton.style.opacity = '0.7';
};
dismissButton.onmouseout = () => {
dismissButton.style.opacity = '1';
};
resetButton.onclick = async () => {
resetButton.disabled = true;
resetButton.textContent = 'Resetting...';
try {
await extension.api.updateMaster({
name: extension.config?.master?.name || "Master",
host: "",
});
if (extension.config?.master) {
extension.config.master.host = "";
}
await extension.detectMasterIP();
await extension.loadConfig();
const newMasterUrl = extension.getMasterUrl();
extension.log(`Master host reset. New URL: ${newMasterUrl}`, "info");
if (extension.panelElement) {
const hostInput = document.getElementById('master-host');
if (hostInput) {
hostInput.value = extension.config?.master?.host || "";
}
}
extension.app.extensionManager.toast.add({
severity: "success",
summary: "Master Host Reset",
detail: `New address: ${newMasterUrl}`,
life: 4000,
});
banner.remove();
} catch (error) {
resetButton.disabled = false;
resetButton.textContent = 'Reset Master Address';
extension.log(`Failed to reset master host: ${error.message}`, "error");
}
};
dismissButton.onclick = () => {
banner.remove();
};
banner.appendChild(messageSpan);
banner.appendChild(resetButton);
banner.appendChild(dismissButton);
document.body.prepend(banner);
setTimeout(() => {
if (document.getElementById('cloudflare-warning-banner')) {
banner.style.transition = 'opacity 0.5s';
banner.style.opacity = '0';
setTimeout(() => {
banner.remove();
}, 500);
}
}, 30000);
}
================================================
FILE: web/ui/entityCard.js
================================================
import { updateWorkerControls, toggleWorkerExpanded } from "../workerLifecycle.js";
import { isRemoteWorker } from "../workerSettings.js";
export function renderEntityCard(ui, cardConfigs, entityType, data, extension) {
const config = cardConfigs[entityType] || {};
const isPlaceholder = entityType === 'blueprint' || entityType === 'add';
const isWorker = entityType === 'worker';
const isMaster = entityType === 'master';
const isRemote = isWorker && isRemoteWorker(extension, data);
const cardOptions = {
onClick: isPlaceholder ? data?.onClick : null,
};
if (isPlaceholder) {
cardOptions.title = entityType === 'blueprint' ? "Click to add your first worker" : "Click to add a new worker";
}
const card = ui.createCard(entityType, cardOptions);
if (isWorker && data?.id) {
card.dataset.workerId = String(data.id);
}
const leftColumn = ui.createCheckboxOrIconColumn(config.checkbox, data, extension);
card.appendChild(leftColumn);
const rightColumn = ui.createCardColumn('content');
rightColumn.classList.add("entity-card-content");
const infoRow = ui.createInfoRow();
if (config.infoRowPadding) {
infoRow.style.padding = config.infoRowPadding;
}
if (config.minHeight === 'auto') {
infoRow.style.minHeight = 'auto';
} else if (config.minHeight) {
infoRow.style.minHeight = config.minHeight;
}
if (config.expand) {
infoRow.title = "Click to expand settings";
infoRow.onclick = () => {
if (isMaster) {
const masterSettingsExpanded = !extension.masterSettingsExpanded;
extension.masterSettingsExpanded = masterSettingsExpanded;
const masterSettingsDiv = document.getElementById("master-settings");
const arrow = infoRow.querySelector('.settings-arrow');
if (masterSettingsExpanded) {
masterSettingsDiv.classList.add("expanded");
masterSettingsDiv.style.padding = "12px";
masterSettingsDiv.style.marginTop = "8px";
masterSettingsDiv.style.marginBottom = "8px";
arrow.style.transform = "rotate(90deg)";
} else {
masterSettingsDiv.classList.remove("expanded");
masterSettingsDiv.style.padding = "0 12px";
masterSettingsDiv.style.marginTop = "0";
masterSettingsDiv.style.marginBottom = "0";
arrow.style.transform = "rotate(0deg)";
}
} else {
toggleWorkerExpanded(extension, data.id);
}
};
}
const workerContent = ui.createWorkerContent();
if (entityType === 'add') {
workerContent.style.alignItems = "center";
}
const statusDot = ui.createStatusDotHelper(config.statusDot, data, extension);
workerContent.appendChild(statusDot);
const infoSpan = document.createElement("span");
infoSpan.classList.add("dist-worker-info");
infoSpan.innerHTML = config.infoText(data, extension);
workerContent.appendChild(infoSpan);
infoRow.appendChild(workerContent);
let settingsArrow;
if (config.expand) {
const expandedId = config.settings?.expandedId || (isMaster ? 'master' : data?.id);
settingsArrow = ui.createSettingsToggleHelper(expandedId, extension);
if (isMaster && !extension.masterSettingsExpanded) {
settingsArrow.style.transform = "rotate(0deg)";
}
infoRow.appendChild(settingsArrow);
}
rightColumn.appendChild(infoRow);
if (config.hover === true) {
rightColumn.classList.add("entity-card-content--hoverable");
rightColumn.onmouseover = () => {
rightColumn.classList.add("entity-card-content--hovered");
if (settingsArrow) {
settingsArrow.style.color = "var(--dist-settings-arrow-hover, #fff)";
}
};
rightColumn.onmouseout = () => {
rightColumn.classList.remove("entity-card-content--hovered");
if (settingsArrow) {
settingsArrow.style.color = "var(--dist-settings-arrow, #888)";
}
};
}
const controlsDiv = ui.createControlsSection(config.controls, data, extension, isRemote);
if (controlsDiv) {
rightColumn.appendChild(controlsDiv);
}
if (config.settings) {
const settingsDiv = ui.createSettingsSection(config.settings, data, extension);
rightColumn.appendChild(settingsDiv);
}
card.appendChild(rightColumn);
if (config.hover === 'placeholder') {
ui.addPlaceholderHover(card, leftColumn, entityType);
}
if (isWorker && !isRemote) {
updateWorkerControls(extension, data.id);
}
return card;
}
================================================
FILE: web/ui/logModal.js
================================================
import { TIMEOUTS } from '../constants.js';
function formatFileSize(bytes) {
if (bytes < 1024) return bytes + ' B';
if (bytes < 1024 * 1024) return (bytes / 1024).toFixed(1) + ' KB';
return (bytes / (1024 * 1024)).toFixed(1) + ' MB';
}
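// Illustrative results: formatFileSize(512) -> "512 B",
// formatFileSize(2048) -> "2.0 KB", formatFileSize(3 * 1024 * 1024) -> "3.0 MB".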
export function createLogModal() {
let _modalEl = null;
let _keydownHandler = null;
let _refreshTimer = null;
let _fetching = false;
let _fetchLog = null;
let _onClose = null;
let _logContentEl = null;
let _statusBarEl = null;
let _refreshCheckbox = null;
const updateLogView = (logData) => {
if (!_logContentEl || !_statusBarEl || !logData) {
return;
}
const shouldAutoScroll =
_logContentEl.scrollTop + _logContentEl.clientHeight >= _logContentEl.scrollHeight - 50;
_logContentEl.textContent = logData.content || '';
if (shouldAutoScroll) {
_logContentEl.scrollTop = _logContentEl.scrollHeight;
}
let statusText;
if (logData.source === "memory") {
statusText = "Remote worker log (in-memory buffer)";
if (logData.truncated) {
statusText += ` (showing last ${logData.lines_shown || 0} lines)`;
}
} else {
statusText = `Log file: ${logData.log_file || 'unknown'}`;
if (logData.truncated) {
statusText += ` (showing last ${logData.lines_shown} lines of ${formatFileSize(logData.file_size || 0)})`;
}
}
_statusBarEl.textContent = statusText;
};
const refreshLog = async () => {
if (_fetching || !_fetchLog || !_refreshCheckbox?.checked) {
return;
}
_fetching = true;
try {
const data = await _fetchLog();
if (data) {
updateLogView(data);
}
} catch (_error) {
// Keep the modal open and continue retrying on next interval.
} finally {
_fetching = false;
}
};
const stopRefresh = () => {
if (_refreshTimer) {
clearInterval(_refreshTimer);
_refreshTimer = null;
}
};
const startRefresh = () => {
stopRefresh();
_refreshTimer = setInterval(() => {
refreshLog();
}, TIMEOUTS.LOG_REFRESH);
};
const unmount = () => {
stopRefresh();
if (_keydownHandler) {
document.removeEventListener('keydown', _keydownHandler);
_keydownHandler = null;
}
if (_modalEl) {
_modalEl.remove();
_modalEl = null;
}
// Drop element and fetch references so update()/refreshLog() become
// no-ops once the modal is gone.
_logContentEl = null;
_statusBarEl = null;
_refreshCheckbox = null;
_fetchLog = null;
const onClose = _onClose;
_onClose = null;
if (onClose) {
onClose();
}
};
const mount = (container, { workerName, logData, onClose, fetchLog, themeClass = "" }) => {
_onClose = onClose;
_fetchLog = fetchLog;
const modal = document.createElement('div');
modal.id = 'distributed-log-modal';
modal.className = 'log-modal';
if (themeClass) {
modal.classList.add(themeClass);
}
const content = document.createElement('div');
content.className = 'log-modal__content';
const header = document.createElement('div');
header.className = 'log-modal__header';
const title = document.createElement('h3');
title.className = 'log-modal__title';
title.textContent = `${workerName} - Log Viewer`;
const headerButtons = document.createElement('div');
headerButtons.className = 'log-modal__header-buttons';
const refreshContainer = document.createElement('div');
refreshContainer.className = 'log-modal__refresh';
const refreshCheckbox = document.createElement('input');
refreshCheckbox.type = 'checkbox';
refreshCheckbox.id = 'log-auto-refresh';
refreshCheckbox.className = 'log-modal__refresh-input';
refreshCheckbox.checked = true;
const refreshLabel = document.createElement('label');
refreshLabel.htmlFor = 'log-auto-refresh';
refreshLabel.className = 'log-modal__refresh-label';
refreshLabel.textContent = 'Auto-refresh';
refreshContainer.appendChild(refreshCheckbox);
refreshContainer.appendChild(refreshLabel);
const closeBtn = document.createElement('button');
closeBtn.className = 'distributed-button log-modal__close';
closeBtn.textContent = '✕';
headerButtons.appendChild(refreshContainer);
headerButtons.appendChild(closeBtn);
header.appendChild(title);
header.appendChild(headerButtons);
const logContainer = document.createElement('div');
logContainer.className = 'log-modal__body';
logContainer.id = 'distributed-log-content';
const statusBar = document.createElement('div');
statusBar.className = 'log-modal__status';
content.appendChild(header);
content.appendChild(logContainer);
content.appendChild(statusBar);
modal.appendChild(content);
closeBtn.addEventListener('click', unmount);
modal.addEventListener('click', (event) => {
if (event.target === modal) {
unmount();
}
});
_keydownHandler = (event) => {
if (event.key === 'Escape') {
unmount();
}
};
document.addEventListener('keydown', _keydownHandler);
_modalEl = modal;
_logContentEl = logContainer;
_statusBarEl = statusBar;
_refreshCheckbox = refreshCheckbox;
refreshCheckbox.addEventListener('change', () => {
if (refreshCheckbox.checked) {
refreshLog();
}
});
container.appendChild(modal);
updateLogView(logData);
requestAnimationFrame(() => {
if (_logContentEl) {
_logContentEl.scrollTop = _logContentEl.scrollHeight;
}
});
startRefresh();
};
return {
mount,
unmount,
update: updateLogView,
};
}
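// Usage sketch (editor's illustration; `fetchWorkerLog` is a hypothetical
// async function returning the same { content, source, ... } shape the
// backend produces):
// const modal = createLogModal();
// modal.mount(document.body, {
//   workerName: "Worker 1",
//   logData: { source: "memory", content: "booting..." },
//   fetchLog: fetchWorkerLog,
//   onClose: () => console.log("log modal closed"),
// });
// // modal.update(newLogData) pushes fresh content; modal.unmount() tears down.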
================================================
FILE: web/ui/settingsForm.js
================================================
import { BUTTON_STYLES } from '../constants.js';
import { cancelWorkerSettings, deleteWorker, isRemoteWorker, saveWorkerSettings } from '../workerSettings.js';
export function createWorkerSettingsForm(ui, extension, worker) {
const form = document.createElement("div");
form.style.cssText = "display: flex; flex-direction: column; gap: 8px;";
const nameGroup = ui.createFormGroup("Name:", worker.name, `name-${worker.id}`);
form.appendChild(nameGroup.group);
const typeGroup = document.createElement("div");
typeGroup.style.cssText = "display: flex; flex-direction: column; gap: 4px; margin: 5px 0;";
const typeLabel = document.createElement("label");
typeLabel.htmlFor = `worker-type-${worker.id}`;
typeLabel.textContent = "Worker Type:";
typeLabel.style.cssText = "font-size: 12px; color: var(--dist-label-text, #ccc);";
const typeSelect = document.createElement("select");
typeSelect.id = `worker-type-${worker.id}`;
typeSelect.style.cssText =
"padding: 4px 8px; background: var(--dist-input-bg, #333); color: var(--dist-input-text, #fff); border: 1px solid var(--dist-input-border, #555); border-radius: 4px; font-size: 12px;";
const localOption = document.createElement("option");
localOption.value = "local";
localOption.textContent = "Local";
const remoteOption = document.createElement("option");
remoteOption.value = "remote";
remoteOption.textContent = "Remote";
const cloudOption = document.createElement("option");
cloudOption.value = "cloud";
cloudOption.textContent = "Cloud";
typeSelect.appendChild(localOption);
typeSelect.appendChild(remoteOption);
typeSelect.appendChild(cloudOption);
const runpodText = document.createElement("a");
runpodText.id = `runpod-text-${worker.id}`;
runpodText.href = "https://github.com/robertvoy/ComfyUI-Distributed/blob/main/docs/worker-setup-guides.md#cloud-workers";
runpodText.target = "_blank";
runpodText.textContent = "Deploy Cloud Worker with Runpod";
runpodText.style.cssText = "font-size: 12px; color: #4a90e2; text-decoration: none; margin-top: 4px; display: none; cursor: pointer;";
const createOnChangeHandler = () => {
return (e) => {
const workerType = e.target.value;
const hostGroup = document.getElementById(`host-group-${worker.id}`);
const hostInput = document.getElementById(`host-${worker.id}`);
const portGroup = document.getElementById(`port-group-${worker.id}`);
const portInput = document.getElementById(`port-${worker.id}`);
const cudaGroup = document.getElementById(`cuda-group-${worker.id}`);
const argsGroup = document.getElementById(`args-group-${worker.id}`);
const runpodTextElem = document.getElementById(`runpod-text-${worker.id}`);
if (!hostGroup || !portGroup || !cudaGroup || !argsGroup || !runpodTextElem || !hostInput || !portInput) {
return;
}
if (workerType === "local") {
hostGroup.style.display = "none";
portGroup.style.display = "flex";
cudaGroup.style.display = "flex";
argsGroup.style.display = "flex";
runpodTextElem.style.display = "none";
} else if (workerType === "remote") {
hostGroup.style.display = "flex";
portGroup.style.display = "flex";
cudaGroup.style.display = "none";
argsGroup.style.display = "none";
runpodTextElem.style.display = "none";
hostInput.placeholder = "e.g., 192.168.1.100";
if (hostInput.value === "localhost" || hostInput.value === "127.0.0.1") {
hostInput.value = "";
}
} else if (workerType === "cloud") {
hostGroup.style.display = "flex";
portGroup.style.display = "flex";
cudaGroup.style.display = "none";
argsGroup.style.display = "none";
runpodTextElem.style.display = "block";
hostInput.placeholder = "e.g., your-cloud-worker.trycloudflare.com";
portInput.value = "443";
if (hostInput.value === "localhost" || hostInput.value === "127.0.0.1") {
hostInput.value = "";
}
}
};
};
typeGroup.appendChild(typeLabel);
typeGroup.appendChild(typeSelect);
typeGroup.appendChild(runpodText);
form.appendChild(typeGroup);
const hostGroup = ui.createFormGroup("Host:", worker.host || "", `host-${worker.id}`, "text", "e.g., 192.168.1.100");
hostGroup.group.id = `host-group-${worker.id}`;
hostGroup.group.style.display = (isRemoteWorker(extension, worker) || worker.type === "cloud") ? "flex" : "none";
form.appendChild(hostGroup.group);
const portGroup = ui.createFormGroup("Port:", worker.port, `port-${worker.id}`, "number");
portGroup.group.id = `port-group-${worker.id}`;
form.appendChild(portGroup.group);
const cudaGroup = ui.createFormGroup("CUDA Device:", worker.cuda_device || 0, `cuda-${worker.id}`, "number");
cudaGroup.group.id = `cuda-group-${worker.id}`;
cudaGroup.group.style.display = (isRemoteWorker(extension, worker) || worker.type === "cloud") ? "none" : "flex";
form.appendChild(cudaGroup.group);
const argsGroup = ui.createFormGroup("Extra Args:", worker.extra_args || "", `args-${worker.id}`);
argsGroup.group.id = `args-group-${worker.id}`;
argsGroup.group.style.display = (isRemoteWorker(extension, worker) || worker.type === "cloud") ? "none" : "flex";
form.appendChild(argsGroup.group);
const saveBtn = ui.createButton("Save", () => saveWorkerSettings(extension, worker.id), "background-color: #4a7c4a;");
saveBtn.style.cssText = BUTTON_STYLES.base + BUTTON_STYLES.success;
const cancelBtn = ui.createButton("Cancel", () => cancelWorkerSettings(extension, worker.id), "background-color: #555;");
cancelBtn.style.cssText = BUTTON_STYLES.base + BUTTON_STYLES.cancel;
const deleteBtn = ui.createButton("Delete", () => deleteWorker(extension, worker.id), "background-color: #7c4a4a;");
deleteBtn.style.cssText = BUTTON_STYLES.base + BUTTON_STYLES.error + BUTTON_STYLES.marginLeftAuto;
const buttonGroup = ui.createButtonGroup([saveBtn, cancelBtn, deleteBtn], " margin-top: 8px;");
form.appendChild(buttonGroup);
typeSelect.onchange = createOnChangeHandler();
if (worker.type === "cloud") {
typeSelect.value = "cloud";
runpodText.style.display = "block";
} else if (isRemoteWorker(extension, worker)) {
typeSelect.value = "remote";
} else {
typeSelect.value = "local";
}
typeSelect.dispatchEvent(new Event('change'));
return form;
}
================================================
FILE: web/ui.js
================================================
import { BUTTON_STYLES, UI_STYLES, STATUS_COLORS, UI_COLORS, TIMEOUTS } from './constants.js';
import { createButtonHelper as createButtonHelperFn } from './ui/buttonHelpers.js';
import { showCloudflareWarning as showCloudflareWarningFn } from './ui/cloudflareWarning.js';
import { createWorkerSettingsForm as createWorkerSettingsFormFn } from './ui/settingsForm.js';
import { renderEntityCard as renderEntityCardFn } from './ui/entityCard.js';
import { createLogModal } from './ui/logModal.js';
import { launchWorker, stopWorker, updateWorkerControls, viewWorkerLog } from './workerLifecycle.js';
import { isRemoteWorker } from './workerSettings.js';
const cardConfigs = {
master: {
checkbox: {
enabled: true,
masterToggle: true,
title: "Toggle master participation in workloads"
},
statusDot: {
id: 'master-status',
initialColor: (_, extension) => extension.isMasterParticipating() ? STATUS_COLORS.ONLINE_GREEN : STATUS_COLORS.DISABLED_GRAY,
initialTitle: (_, extension) => extension.isMasterParticipating() ? 'Master participating' : 'Master orchestrator only',
dynamic: true
},
infoText: (data, extension) => {
const cudaDevice = extension.config?.master?.cuda_device ?? extension.masterCudaDevice;
const cudaInfo = cudaDevice !== undefined ? `CUDA ${cudaDevice} • ` : '';
const port = window.location.port || (window.location.protocol === 'https:' ? '443' : '80');
const participationEnabled = extension.isMasterParticipationEnabled();
const fallbackActive = extension.isMasterFallbackActive();
let delegateBadge = '';
if (!participationEnabled && fallbackActive) {
delegateBadge = ` Fallback active • Master executing`;
}
return `${data?.name || extension.config?.master?.name || "Master"} ${cudaInfo}Port ${port}${delegateBadge}`;
},
controls: {
type: 'master'
},
settings: {
formType: 'master',
id: 'master-settings',
expandedTracker: 'masterSettingsExpanded'
},
hover: true,
expand: true,
border: 'solid'
},
worker: {
checkbox: {
enabled: true,
title: "Enable/disable this worker"
},
statusDot: {
dynamic: true,
initialColor: (data) => data.enabled ? STATUS_COLORS.OFFLINE_RED : STATUS_COLORS.DISABLED_GRAY,
initialTitle: (data) => data.enabled ? "Checking status..." : "Disabled",
id: (data) => `status-${data.id}`
},
infoText: (data, extension) => {
const isRemote = isRemoteWorker(extension, data);
const isCloud = data.type === 'cloud';
if (isCloud) {
// For cloud workers, don't show port (it's always 443)
return `${data.name} ${data.host}`;
} else if (isRemote) {
const hostLabel = data.host
? `${data.host}:${data.port}`
: `Unconfigured remote worker • Port ${data.port}`;
return `${data.name} ${hostLabel}`;
} else {
const cudaInfo = data.cuda_device !== undefined ? `CUDA ${data.cuda_device} • ` : '';
return `${data.name} ${cudaInfo}Port ${data.port}`;
}
},
controls: {
dynamic: true
},
settings: {
formType: 'worker',
id: (data) => `settings-${data.id}`,
expandedId: (data) => data?.id
},
hover: true,
expand: true,
border: 'solid'
},
blueprint: {
checkbox: {
type: 'icon',
content: '+',
width: 42,
style: `border-right: 2px dashed ${UI_COLORS.BORDER_LIGHT}; color: ${UI_COLORS.ACCENT_COLOR}; font-size: 24px; font-weight: 500;`
},
statusDot: {
color: 'transparent',
border: `1px solid ${UI_COLORS.BORDER_LIGHT}`
},
infoText: () => `Add New Worker [CUDA] • [Port]`,
controls: {
type: 'ghost',
text: 'Configure',
style: `border: 1px solid ${UI_COLORS.BORDER_DARK}; background: transparent; color: ${UI_COLORS.BORDER_LIGHT};`
},
hover: 'placeholder',
expand: false,
border: 'dashed'
},
add: {
checkbox: {
type: 'icon',
content: '+',
width: 43,
style: `border-right: 1px dashed ${UI_COLORS.BORDER_DARK}; color: ${UI_COLORS.BORDER_LIGHT}; font-size: 18px;`
},
statusDot: {
color: 'transparent',
border: `1px solid ${UI_COLORS.BORDER_LIGHT}`
},
infoText: () => `Add New Worker`,
controls: null,
hover: 'placeholder',
expand: false,
border: 'dashed',
minHeight: '48px'
}
};
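// Editor's note: renderEntityCard() looks entity types up in cardConfigs, so
// a new card variant is just another entry with the same shape. A minimal
// hypothetical example (not part of the original file):
// cardConfigs.readonly = {
//   checkbox: { enabled: false },
//   statusDot: { color: "transparent" },
//   infoText: (data) => `${data?.name || "Entity"} (read-only)`,
//   controls: null,
//   hover: false,
//   expand: false,
//   border: "dashed",
// };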
export class DistributedUI {
constructor() {
// UI element styles
this.styles = UI_STYLES;
}
createStatusDot(id, color = "#666", title = "Status") {
const dot = document.createElement("span");
if (id) dot.id = id;
dot.style.cssText = this.styles.statusDot + ` background-color: ${color};`;
dot.title = title;
return dot;
}
createButton(text, onClick, customStyle = "") {
const button = document.createElement("button");
button.textContent = text;
button.className = "distributed-button";
button.style.cssText = BUTTON_STYLES.base + customStyle;
if (onClick) button.onclick = onClick;
return button;
}
createButtonGroup(buttons, style = "") {
const group = document.createElement("div");
group.style.cssText = this.styles.buttonGroup + style;
buttons.forEach(button => group.appendChild(button));
return group;
}
createWorkerControls(workerId, handlers = {}) {
const controlsDiv = document.createElement("div");
controlsDiv.id = `controls-${workerId}`;
controlsDiv.style.cssText = this.styles.controlsDiv;
const buttons = [];
if (handlers.launch) {
const launchBtn = this.createButton('Launch', handlers.launch);
launchBtn.id = `launch-${workerId}`;
launchBtn.title = "Launch this worker instance";
buttons.push(launchBtn);
}
if (handlers.stop) {
const stopBtn = this.createButton('Stop', handlers.stop);
stopBtn.id = `stop-${workerId}`;
stopBtn.title = "Stop this worker instance";
buttons.push(stopBtn);
}
if (handlers.viewLog) {
const logBtn = this.createButton('View Log', handlers.viewLog);
logBtn.id = `log-${workerId}`;
logBtn.title = "View worker log file";
buttons.push(logBtn);
}
buttons.forEach(btn => controlsDiv.appendChild(btn));
return controlsDiv;
}
createFormGroup(label, value, id, type = "text", placeholder = "") {
const group = document.createElement("div");
group.style.cssText = this.styles.formGroup;
const labelEl = document.createElement("label");
labelEl.textContent = label;
labelEl.htmlFor = id;
labelEl.style.cssText = this.styles.formLabel;
const input = document.createElement("input");
input.type = type;
input.id = id;
input.value = value;
input.placeholder = placeholder;
input.classList.add('dist-form-input');
input.style.cssText = this.styles.formInput;
group.appendChild(labelEl);
group.appendChild(input);
return { group, input };
}
createInfoBox(text) {
const box = document.createElement("div");
box.classList.add('dist-info-box');
box.style.cssText = this.styles.infoBox;
box.textContent = text;
return box;
}
addHoverEffect(element, onHover, onLeave) {
element.onmouseover = onHover;
element.onmouseout = onLeave;
}
createCard(type = 'worker', options = {}) {
const card = document.createElement("div");
card.classList.add('dist-card');
switch(type) {
case 'master':
case 'worker':
card.style.cssText = this.styles.workerCard;
break;
case 'blueprint':
card.classList.add('dist-card--blueprint');
card.style.cssText = this.styles.cardBase + this.styles.cardBlueprint;
if (options.onClick) card.onclick = options.onClick;
if (options.title) card.title = options.title;
break;
case 'add':
card.classList.add('dist-card--add');
card.style.cssText = this.styles.cardBase + this.styles.cardAdd;
if (options.onClick) card.onclick = options.onClick;
if (options.title) card.title = options.title;
break;
}
if (options.onMouseEnter) {
card.addEventListener('mouseenter', options.onMouseEnter);
}
if (options.onMouseLeave) {
card.addEventListener('mouseleave', options.onMouseLeave);
}
return card;
}
createCardColumn(type = 'checkbox', options = {}) {
const column = document.createElement("div");
switch(type) {
case 'checkbox':
column.classList.add('dist-card__left-col');
column.style.cssText = this.styles.checkboxColumn;
if (options.title) column.title = options.title;
break;
case 'icon':
column.style.cssText = this.styles.columnBase + this.styles.iconColumn;
break;
case 'content':
column.style.cssText = this.styles.contentColumn;
break;
}
return column;
}
createInfoRow(options = {}) {
const row = document.createElement("div");
row.style.cssText = this.styles.infoRow;
if (options.onClick) row.onclick = options.onClick;
return row;
}
createWorkerContent() {
const content = document.createElement("div");
content.style.cssText = this.styles.workerContent;
return content;
}
createSettingsForm(fields = [], options = {}) {
const form = document.createElement("div");
form.style.cssText = this.styles.settingsForm;
fields.forEach(field => {
if (field.type === 'checkbox') {
const group = document.createElement("div");
group.style.cssText = this.styles.checkboxGroup;
const checkbox = document.createElement("input");
checkbox.type = "checkbox";
checkbox.id = field.id;
checkbox.checked = field.checked || false;
if (field.onChange) checkbox.onchange = field.onChange;
const label = document.createElement("label");
label.htmlFor = field.id;
label.textContent = field.label;
label.style.cssText = this.styles.formLabelClickable;
group.appendChild(checkbox);
group.appendChild(label);
form.appendChild(group);
} else {
const result = this.createFormGroup(field.label, field.value, field.id, field.type, field.placeholder);
if (field.groupId) result.group.id = field.groupId;
if (field.display) result.group.style.display = field.display;
form.appendChild(result.group);
}
});
if (options.buttons) {
const buttonGroup = this.createButtonGroup(options.buttons, options.buttonStyle || " margin-top: 8px;");
form.appendChild(buttonGroup);
}
return form;
}
createButtonHelper(text, onClick, style) {
return createButtonHelperFn(this, text, onClick, style);
}
updateMasterDisplay(extension) {
// Use persistent config value as fallback
const cudaDevice = extension?.config?.master?.cuda_device ?? extension?.masterCudaDevice;
// Update CUDA info if element exists
const cudaInfo = document.getElementById('master-cuda-info');
if (cudaInfo) {
const port = window.location.port || (window.location.protocol === 'https:' ? '443' : '80');
if (cudaDevice !== undefined && cudaDevice !== null) {
cudaInfo.textContent = `CUDA ${cudaDevice} • Port ${port}`;
} else {
cudaInfo.textContent = `Port ${port}`;
}
}
// Update name if changed
const nameDisplay = document.getElementById('master-name-display');
if (nameDisplay && extension?.config?.master?.name) {
nameDisplay.textContent = extension.config.master.name;
}
}
showToast(app, severity, summary, detail, life = 3000) {
if (app.extensionManager?.toast?.add) {
app.extensionManager.toast.add({ severity, summary, detail, life });
}
}
showCloudflareWarning(extension, masterHost) {
return showCloudflareWarningFn(extension, masterHost);
}
updateStatusDot(workerId, color, title, pulsing = false) {
const statusDot = document.getElementById(`status-${workerId}`);
if (!statusDot) return;
const statusClasses = [
"worker-status--online",
"worker-status--offline",
"worker-status--unknown",
"worker-status--processing",
];
statusDot.classList.remove(...statusClasses);
const colorClassMap = {
[STATUS_COLORS.ONLINE_GREEN]: "worker-status--online",
[STATUS_COLORS.OFFLINE_RED]: "worker-status--offline",
[STATUS_COLORS.DISABLED_GRAY]: "worker-status--unknown",
[STATUS_COLORS.PROCESSING_YELLOW]: "worker-status--processing",
};
const statusClass = colorClassMap[color] || "worker-status--unknown";
statusDot.classList.add(statusClass);
statusDot.style.backgroundColor = "";
statusDot.title = title;
statusDot.classList.toggle('status-pulsing', pulsing);
}
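// Editor's note: colors are mapped to CSS classes rather than set inline, so
// e.g. updateStatusDot("w1", STATUS_COLORS.PROCESSING_YELLOW, "Busy", true)
// swaps #status-w1 to .worker-status--processing and adds .status-pulsing.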
showLogModal(extension, workerId, logData, fetchLog = null) {
if (this._logModal) {
this._logModal.unmount();
this._logModal = null;
}
const worker = extension.config.workers.find(w => w.id === workerId);
const workerName = worker?.name || `Worker ${workerId}`;
const modal = createLogModal();
this._logModal = modal;
const themeClass =
extension.panelElement?.classList.contains("distributed-panel--light")
? "distributed-panel--light"
: "";
modal.mount(document.body, {
workerName,
logData,
fetchLog: fetchLog || (async () => extension.api.getWorkerLog(workerId, 1000)),
themeClass,
onClose: () => {
if (this._logModal === modal) {
this._logModal = null;
}
},
});
}
createWorkerSettingsForm(extension, worker) {
return createWorkerSettingsFormFn(this, extension, worker);
}
createSettingsToggle() {
const settingsRow = document.createElement("div");
settingsRow.style.cssText = this.styles.settingsToggle;
const settingsTitle = document.createElement("h4");
settingsTitle.textContent = "Settings";
settingsTitle.style.cssText = "margin: 0; font-size: 14px;";
const settingsToggle = document.createElement("span");
settingsToggle.textContent = "▶"; // Right arrow when collapsed
settingsToggle.style.cssText =
"font-size: 12px; color: var(--dist-settings-arrow, #888); transition: all 0.2s ease;";
settingsRow.appendChild(settingsToggle);
settingsRow.appendChild(settingsTitle);
return { settingsRow, settingsToggle };
}
createCheckboxOrIconColumn(config, data, extension) {
const column = this.createCardColumn('checkbox');
if (config?.type === 'icon') {
column.style.flex = `0 0 ${config.width || 44}px`;
column.innerHTML = config.content || '+';
if (config.style) {
const styles = config.style.split(';').filter(s => s.trim());
styles.forEach(style => {
const [prop, value] = style.split(':').map(s => s.trim());
if (prop && value) {
column.style[prop.replace(/-([a-z])/g, (g) => g[1].toUpperCase())] = value;
}
});
}
} else {
const checkbox = document.createElement("input");
checkbox.type = "checkbox";
checkbox.id = `gpu-${data?.id || 'master'}`;
checkbox.checked = config?.checked !== undefined ? config.checked : data?.enabled;
checkbox.disabled = config?.disabled || false;
checkbox.style.cssText = `cursor: ${config?.disabled ? 'default' : 'pointer'}; width: 16px; height: 16px;`;
if (config?.opacity) checkbox.style.opacity = config.opacity;
if (config?.title) column.title = config.title;
const isMasterToggle = config?.masterToggle && typeof extension.isMasterParticipating === 'function';
if (isMasterToggle) {
const participationEnabled = extension.isMasterParticipationEnabled();
const fallbackActive = extension.isMasterFallbackActive();
const buildTitle = (enabled, fallback) => {
if (enabled) {
return "Master participating • Click to switch to orchestrator-only";
}
if (fallback) {
return "No workers selected • Master fallback execution active";
}
return "Master orchestrator-only • Click to re-enable participation";
};
checkbox.checked = participationEnabled;
checkbox.style.pointerEvents = "none";
column.style.cursor = "pointer";
column.title = buildTitle(participationEnabled, fallbackActive);
column.onclick = async (event) => {
if (event) {
event.stopPropagation();
event.preventDefault();
}
const nextState = !extension.isMasterParticipationEnabled();
const nextFallback = !nextState && extension.enabledWorkers.length === 0;
checkbox.checked = nextState;
column.title = buildTitle(nextState, nextFallback);
await extension.updateMasterParticipation(nextState);
};
} else if (config?.enabled && !config?.disabled && data?.id) {
checkbox.style.pointerEvents = "none";
column.style.cursor = "pointer";
column.onclick = async () => {
checkbox.checked = !checkbox.checked;
await extension.updateWorkerEnabled(data.id, checkbox.checked);
};
}
column.appendChild(checkbox);
}
return column;
}
createStatusDotHelper(config, data, extension) {
let color = config.color || "#666";
let title = config.title || "Status";
let id = config.id;
if (typeof config.initialColor === 'function') {
color = config.initialColor(data, extension);
}
if (typeof config.initialTitle === 'function') {
title = config.initialTitle(data, extension);
}
if (typeof config.id === 'function') {
id = config.id(data);
}
const dot = this.createStatusDot(id, color, title);
if (config.border) {
dot.style.border = config.border;
}
if (config.pulsing && (typeof config.pulsing !== 'function' || config.pulsing(data))) {
dot.classList.add('status-pulsing');
}
return dot;
}
createSettingsToggleHelper(expandedId, extension) {
const arrow = document.createElement("span");
arrow.className = "settings-arrow";
arrow.innerHTML = "▶";
arrow.style.cssText = this.styles.settingsArrow;
const isExpanded = typeof expandedId === 'function' ?
extension.state.isWorkerExpanded(expandedId(extension)) :
(expandedId === 'master' ? false : extension.state.isWorkerExpanded(expandedId));
if (isExpanded) {
arrow.style.transform = "rotate(90deg)";
}
return arrow;
}
createControlsSection(config, data, extension, isRemote) {
if (!config) return null;
const controlsDiv = document.createElement("div");
controlsDiv.id = `controls-${data?.id || 'master'}`;
controlsDiv.style.cssText = this.styles.controlsDiv;
// Always create a wrapper div for consistent layout
const controlsWrapper = document.createElement("div");
controlsWrapper.style.cssText = this.styles.controlsWrapper;
if (config.type === 'master') {
const participationEnabled = extension.isMasterParticipationEnabled();
const fallbackActive = extension.isMasterFallbackActive();
let message;
const badge = document.createElement("div");
badge.classList.add("dist-info-box", "master-info-badge");
badge.style.cssText = this.styles.infoBox;
if (fallbackActive) {
message = "No workers selected. Master fallback execution active.";
badge.textContent = message;
badge.classList.add("master-info-badge--fallback");
} else if (!participationEnabled) {
message = "Master disabled: running as orchestrator only";
badge.textContent = message;
badge.classList.add("master-info-badge--delegate");
} else {
message = "Master participating in workflows";
badge.textContent = message;
}
controlsWrapper.appendChild(badge);
} else if (config.dynamic && data) {
if (isRemote) {
const isCloud = data.type === 'cloud';
const workerTypeText = isCloud ? "Cloud worker" : "Remote worker";
const workerTypeBadge = this.createInfoBox(workerTypeText);
workerTypeBadge.title = "Worker is externally hosted";
controlsWrapper.appendChild(workerTypeBadge);
const logBtn = this.createButton('View Log', () => viewWorkerLog(extension, data.id, true));
logBtn.id = `log-${data.id}`;
logBtn.style.cssText = BUTTON_STYLES.base + BUTTON_STYLES.workerControl;
logBtn.classList.add("btn--log");
logBtn.title = "View remote worker log";
controlsWrapper.appendChild(logBtn);
} else {
const controls = this.createWorkerControls(data.id, {
launch: () => launchWorker(extension, data.id),
stop: () => stopWorker(extension, data.id),
viewLog: () => viewWorkerLog(extension, data.id)
});
const launchBtn = controls.querySelector(`#launch-${data.id}`);
const stopBtn = controls.querySelector(`#stop-${data.id}`);
const logBtn = controls.querySelector(`#log-${data.id}`);
launchBtn.style.cssText = BUTTON_STYLES.base + BUTTON_STYLES.workerControl;
launchBtn.classList.add("btn--launch");
launchBtn.title = "Launch worker (runs in background with logging)";
stopBtn.style.cssText = BUTTON_STYLES.base + BUTTON_STYLES.workerControl + BUTTON_STYLES.hidden;
stopBtn.classList.add("btn--stop");
stopBtn.title = "Stop worker";
logBtn.style.cssText = BUTTON_STYLES.base + BUTTON_STYLES.workerControl + BUTTON_STYLES.hidden;
logBtn.classList.add("btn--log");
while (controls.firstChild) {
controlsWrapper.appendChild(controls.firstChild);
}
}
} else if (config.type === 'info') {
const infoBtn = this.createButton(config.text, null, config.style || "");
infoBtn.style.cssText = BUTTON_STYLES.base + BUTTON_STYLES.workerControl + (config.style || BUTTON_STYLES.info) + " cursor: default;";
infoBtn.disabled = true;
controlsWrapper.appendChild(infoBtn);
} else if (config.type === 'ghost') {
const ghostBtn = document.createElement("button");
ghostBtn.style.cssText = `flex: 1; padding: 5px 14px; font-size: 11px; font-weight: 500; border-radius: 4px; cursor: default; ${config.style || ""}`;
ghostBtn.textContent = config.text;
ghostBtn.disabled = true;
controlsWrapper.appendChild(ghostBtn);
}
controlsDiv.appendChild(controlsWrapper);
return controlsDiv;
}
createSettingsSection(config, data, extension) {
const settingsDiv = document.createElement("div");
const settingsId = typeof config.id === 'function' ? config.id(data) : config.id;
settingsDiv.id = settingsId;
settingsDiv.className = "worker-settings";
const expandedId = typeof config.expandedId === 'function' ? config.expandedId(data) : config.expandedId;
const isExpanded = expandedId === 'master-settings' ? false : extension.state.isWorkerExpanded(expandedId);
settingsDiv.style.cssText = this.styles.workerSettings;
if (isExpanded) {
settingsDiv.classList.add("expanded");
settingsDiv.style.padding = "12px";
settingsDiv.style.marginTop = "8px";
settingsDiv.style.marginBottom = "8px";
}
let settingsForm;
if (config.formType === 'master') {
settingsForm = this.createMasterSettingsForm(extension, data);
} else if (config.formType === 'worker') {
settingsForm = this.createWorkerSettingsForm(extension, data);
}
if (settingsForm) {
settingsDiv.appendChild(settingsForm);
}
return settingsDiv;
}
createMasterSettingsForm(extension, data) {
const settingsForm = document.createElement("div");
settingsForm.style.cssText = "display: flex; flex-direction: column; gap: 8px;";
const nameResult = this.createFormGroup("Name:", extension.config?.master?.name || "Master", "master-name");
settingsForm.appendChild(nameResult.group);
const hostResult = this.createFormGroup("Host:", extension.config?.master?.host || "", "master-host", "text", "Auto-detect if empty");
settingsForm.appendChild(hostResult.group);
// Cloudflare tunnel toggle (simple button inside master settings)
const tunnelBtn = this.createButton("Enable Cloudflare Tunnel", (e) => extension.handleTunnelToggle(e.target));
tunnelBtn.id = "cloudflare-tunnel-button";
tunnelBtn.style.cssText = BUTTON_STYLES.base + " margin: 4px 0 -5px 0;";
tunnelBtn.classList.add("tunnel-button", "tunnel-button--enable");
settingsForm.appendChild(tunnelBtn);
extension.tunnelElements = { button: tunnelBtn };
extension.updateTunnelUIElements();
const saveBtn = this.createButton("Save", async () => {
const nameInput = document.getElementById('master-name');
const hostInput = document.getElementById('master-host');
if (!extension.config.master) extension.config.master = {};
extension.config.master.name = nameInput.value.trim() || "Master";
const hostValue = hostInput.value.trim();
await extension.api.updateMaster({
host: hostValue,
name: extension.config.master.name
});
// Reload config to refresh any updated values
await extension.loadConfig();
// If host was emptied, trigger auto-detection
if (!hostValue) {
extension.log("Host field cleared, triggering IP auto-detection", "debug");
await extension.detectMasterIP();
// Reload config again to get the auto-detected IP
await extension.loadConfig();
// Update the input field with the detected IP
document.getElementById('master-host').value = extension.config?.master?.host || "";
}
document.getElementById('master-name-display').textContent = extension.config.master.name;
this.updateMasterDisplay(extension);
// Show toast notification
if (extension.app?.extensionManager?.toast) {
const message = !hostValue ?
"Master settings saved and IP auto-detected" :
"Master settings saved successfully";
extension.app.extensionManager.toast.add({
severity: "success",
summary: "Master Updated",
detail: message,
life: 3000
});
}
saveBtn.textContent = "Saved!";
setTimeout(() => { saveBtn.textContent = "Save"; }, TIMEOUTS.FLASH_LONG);
}, "background-color: #4a7c4a;");
saveBtn.style.cssText = BUTTON_STYLES.base + BUTTON_STYLES.success;
const cancelBtn = this.createButton("Cancel", () => {
document.getElementById('master-name').value = extension.config?.master?.name || "Master";
document.getElementById('master-host').value = extension.config?.master?.host || "";
}, "background-color: #555;");
cancelBtn.style.cssText = BUTTON_STYLES.base + BUTTON_STYLES.cancel;
const buttonGroup = this.createButtonGroup([saveBtn, cancelBtn], " margin-top: 8px;");
settingsForm.appendChild(buttonGroup);
return settingsForm;
}
addPlaceholderHover(card, leftColumn, entityType) {
const cardTypeClass = entityType === 'blueprint' ? 'placeholder-card--blueprint' : 'placeholder-card--add';
const columnTypeClass = entityType === 'blueprint' ? 'placeholder-column--blueprint' : 'placeholder-column--add';
card.classList.add('placeholder-card', cardTypeClass);
leftColumn.classList.add('placeholder-column', columnTypeClass);
card.onmouseover = () => {
card.classList.add('is-hovered');
leftColumn.classList.add('is-hovered');
};
card.onmouseout = () => {
card.classList.remove('is-hovered');
leftColumn.classList.remove('is-hovered');
};
}
renderEntityCard(entityType, data, extension) {
return renderEntityCardFn(this, cardConfigs, entityType, data, extension);
}
}
================================================
FILE: web/urlUtils.js
================================================
export function normalizeWorkerUrl(rawUrl) {
if (!rawUrl || typeof rawUrl !== "string") {
return "";
}
const trimmed = rawUrl.trim();
if (!trimmed) {
return "";
}
const withProtocol = /^https?:\/\//i.test(trimmed) ? trimmed : `http://${trimmed}`;
try {
const parsed = new URL(withProtocol);
if (parsed.pathname === "/") {
parsed.pathname = "";
}
return parsed.toString().replace(/\/$/, "");
} catch (error) {
return withProtocol.replace(/\/$/, "");
}
}
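// Examples (editor's illustration):
// normalizeWorkerUrl("192.168.1.10:8188/") → "http://192.168.1.10:8188"
// normalizeWorkerUrl("https://abc.trycloudflare.com/") → "https://abc.trycloudflare.com"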
export function parseHostInput(value) {
if (!value) {
return { host: "", port: null };
}
let cleaned = value.trim().replace(/^https?:\/\//i, "");
cleaned = cleaned.split("/")[0];
try {
const url = new URL(`http://${cleaned}`);
const port = url.port ? parseInt(url.port, 10) : null;
return {
host: url.hostname || cleaned,
port: Number.isFinite(port) ? port : null,
};
} catch (error) {
return { host: cleaned, port: null };
}
}
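// Examples (editor's illustration):
// parseHostInput("https://192.168.1.50:8189/anything") → { host: "192.168.1.50", port: 8189 }
// parseHostInput("my-worker.local") → { host: "my-worker.local", port: null }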
export function buildWorkerUrl(worker, endpoint = "", windowLocation = window.location) {
const parsed = parseHostInput(worker?.host || windowLocation.hostname);
const host = parsed.host || windowLocation.hostname;
const resolvedPort = parsed.port || worker?.port || 8188;
const isCloud = worker?.type === "cloud";
const isRunpodProxy = host.endsWith(".proxy.runpod.net");
let finalHost = host;
if (!worker?.host && isRunpodProxy) {
const match = host.match(/^(.*)\.proxy\.runpod\.net$/);
if (match) {
finalHost = `${match[1]}-${resolvedPort}.proxy.runpod.net`;
} else {
console.error(`[Distributed] Failed to parse Runpod proxy host: ${host}`);
}
}
const useHttps = isCloud || isRunpodProxy || resolvedPort === 443;
const protocol = useHttps ? "https" : "http";
const defaultPort = useHttps ? 443 : 80;
const needsPort = !isRunpodProxy && resolvedPort !== defaultPort;
const portPart = needsPort ? `:${resolvedPort}` : "";
return normalizeWorkerUrl(`${protocol}://${finalHost}${portPart}${endpoint}`);
}
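// Examples (editor's illustration; the third assumes the UI itself is served
// from a Runpod proxy hostname):
// buildWorkerUrl({ host: "192.168.1.50", port: 8189 }, "/prompt")
//   → "http://192.168.1.50:8189/prompt"
// buildWorkerUrl({ host: "abc.trycloudflare.com", port: 443, type: "cloud" }, "/prompt")
//   → "https://abc.trycloudflare.com/prompt"
// buildWorkerUrl({ port: 8188 }, "", { hostname: "pod123.proxy.runpod.net" })
//   → "https://pod123-8188.proxy.runpod.net"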
export function buildWorkerWebSocketUrl(workerUrl) {
const normalized = normalizeWorkerUrl(workerUrl);
const wsBase = normalized.replace(/^http:\/\//i, "ws://").replace(/^https:\/\//i, "wss://");
return `${wsBase}/distributed/worker_ws`;
}
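// Example (editor's illustration):
// buildWorkerWebSocketUrl("https://abc.trycloudflare.com")
//   → "wss://abc.trycloudflare.com/distributed/worker_ws"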
export function getMasterUrl(config, windowLocation = window.location, log = null) {
const masterHost = config?.master?.host;
if (masterHost) {
const configuredHost = masterHost;
// If the configured host already includes protocol, use as-is.
if (configuredHost.startsWith("http://") || configuredHost.startsWith("https://")) {
return configuredHost;
}
// For domain names (not IPs), default to HTTPS.
const isIP = /^(\d{1,3}\.){3}\d{1,3}$/.test(configuredHost);
const isLocalhost = configuredHost === "localhost" || configuredHost === "127.0.0.1";
if (!isIP && !isLocalhost && configuredHost.includes(".")) {
return `https://${configuredHost}`;
}
const protocol = windowLocation.protocol || "http:";
const port = windowLocation.port || (protocol === "https:" ? "443" : "80");
if ((protocol === "https:" && port === "443") || (protocol === "http:" && port === "80")) {
return `${protocol}//${configuredHost}`;
}
return `${protocol}//${configuredHost}:${port}`;
}
const hostname = windowLocation.hostname;
if (hostname !== "localhost" && hostname !== "127.0.0.1") {
return windowLocation.origin;
}
if (typeof log === "function") {
log(
"No master host configured - remote workers won't be able to connect. Master host should be auto-detected on startup.",
"debug",
);
}
return windowLocation.origin;
}
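// Examples (editor's illustration, assuming the page itself is served from
// http://localhost:8188):
// getMasterUrl({ master: { host: "my-tunnel.trycloudflare.com" } })
//   → "https://my-tunnel.trycloudflare.com"   (domain names default to HTTPS)
// getMasterUrl({ master: { host: "192.168.1.5" } })
//   → "http://192.168.1.5:8188"               (IPs inherit the page's scheme/port)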
================================================
FILE: web/workerLifecycle.js
================================================
import { TIMEOUTS, STATUS_COLORS } from './constants.js';
import { buildWorkerUrl, normalizeWorkerUrl } from './urlUtils.js';
import { isRemoteWorker } from './workerSettings.js';
import { applyProbeResultToWorkerDot } from './workerUtils.js';
export { normalizeWorkerUrl };
let _statusCheckRunning = false;
function setStatusDotClass(dot, statusClass) {
if (!dot) {
return;
}
const classes = [
"worker-status--online",
"worker-status--offline",
"worker-status--unknown",
"worker-status--processing",
];
dot.classList.remove(...classes);
if (statusClass) {
dot.classList.add(statusClass);
}
}
function setButtonClass(button, className) {
if (!button) {
return;
}
button.classList.remove("btn--stop", "btn--launch", "btn--log", "btn--working", "btn--success", "btn--error");
if (className) {
button.classList.add(className);
}
}
function setButtonVisibility(button, visible) {
if (!button) {
return;
}
button.classList.toggle("is-hidden", !visible);
button.style.display = visible ? "" : "none";
}
export async function checkAllWorkerStatuses(extension) {
if (_statusCheckRunning || !extension.panelElement) {
return;
}
_statusCheckRunning = true;
let nextInterval = 5000;
try {
// Create a fresh AbortController for this poll cycle.
extension.statusCheckAbortController = new AbortController();
await checkMasterStatus(extension);
if (extension.config?.workers) {
await Promise.all(
extension.config.workers.map(async (worker) => {
if (worker.enabled || extension.state.isWorkerLaunching(worker.id)) {
await checkWorkerStatus(extension, worker);
}
})
);
}
let isActive = extension.state.getMasterStatus() === "processing";
extension.config?.workers?.forEach((worker) => {
const workerState = extension.state.getWorker(worker.id);
if (workerState.launching || workerState.status?.processing) {
isActive = true;
}
});
nextInterval = isActive ? 1000 : 5000;
} finally {
_statusCheckRunning = false;
if (extension.panelElement) {
extension.statusCheckTimeout = setTimeout(() => checkAllWorkerStatuses(extension), nextInterval);
}
}
}
export async function checkMasterStatus(extension) {
try {
const signal = extension.statusCheckAbortController?.signal || null;
const probeResult = await extension.api.probeWorker(
window.location.origin,
TIMEOUTS.STATUS_CHECK,
signal,
);
if (!probeResult.ok) {
throw new Error(`HTTP ${probeResult.status}`);
}
const queueRemaining = probeResult.queueRemaining || 0;
const isProcessing = queueRemaining > 0;
// Update master status in state
extension.state.setMasterStatus(isProcessing ? "processing" : "online");
// Update master status dot
const statusDot = document.getElementById("master-status");
if (statusDot) {
if (!extension.isMasterParticipating()) {
if (isProcessing) {
setStatusDotClass(statusDot, "worker-status--processing");
statusDot.title = `Orchestrating (${queueRemaining} in queue)`;
} else {
setStatusDotClass(statusDot, "worker-status--unknown");
statusDot.title = "Master orchestrator only";
}
} else if (isProcessing) {
setStatusDotClass(statusDot, "worker-status--processing");
statusDot.title = `Processing (${queueRemaining} in queue)`;
} else {
setStatusDotClass(statusDot, "worker-status--online");
statusDot.title = "Online";
}
}
} catch (error) {
if (error?.name === "AbortError") {
return;
}
// The master is the server this UI is talking to, so a failed probe here is
// transient; keep the dot green (or gray when orchestrator-only).
const statusDot = document.getElementById("master-status");
if (statusDot) {
setStatusDotClass(
statusDot,
extension.isMasterParticipating() ? "worker-status--online" : "worker-status--unknown"
);
statusDot.title = extension.isMasterParticipating() ? "Online" : "Master orchestrator only";
}
}
}
// Helper to build worker URL
export function getWorkerUrl(extension, worker, endpoint = "") {
return buildWorkerUrl(worker, endpoint, window.location);
}
export async function checkWorkerStatus(extension, worker) {
// Caller is responsible for only probing enabled (or launching) workers.
const workerUrl = getWorkerUrl(extension, worker);
try {
const signal = extension.statusCheckAbortController?.signal || null;
const probeResult = await extension.api.probeWorker(
workerUrl,
TIMEOUTS.STATUS_CHECK,
signal,
);
if (!probeResult.ok) {
throw new Error(`HTTP ${probeResult.status}`);
}
const queueRemaining = probeResult.queueRemaining || 0;
const isProcessing = queueRemaining > 0;
// Update status
extension.state.setWorkerStatus(worker.id, {
online: true,
processing: isProcessing,
queueCount: queueRemaining,
});
// Update status dot based on probe result
applyProbeResultToWorkerDot(worker.id, probeResult);
// Clear launching state since worker is now online
if (extension.state.isWorkerLaunching(worker.id)) {
extension.state.setWorkerLaunching(worker.id, false);
clearLaunchingFlag(extension, worker.id);
}
} catch (error) {
// Don't process aborted requests
if (error.name === "AbortError") {
return;
}
// Worker is offline or unreachable
extension.state.setWorkerStatus(worker.id, {
online: false,
processing: false,
queueCount: 0,
});
// Check if worker is launching
if (extension.state.isWorkerLaunching(worker.id)) {
extension.ui.updateStatusDot(worker.id, STATUS_COLORS.PROCESSING_YELLOW, "Launching...", true);
} else if (worker.enabled) {
// Only update to red if not currently launching AND still enabled.
applyProbeResultToWorkerDot(worker.id, { ok: false });
}
// If disabled, don't update the dot (leave it gray)
extension.log(`Worker ${worker.id} status check failed: ${error.message}`, "debug");
}
// Update control buttons based on new status
const updatedInPlace = extension.updateWorkerCard?.(worker.id, extension.state.getWorkerStatus(worker.id));
if (!updatedInPlace) {
updateWorkerControls(extension, worker.id);
}
}
export async function launchWorker(extension, workerId) {
const worker = extension.config.workers.find((w) => w.id === workerId);
if (!worker) {
return;
}
// If worker is disabled, enable it first
if (!worker.enabled) {
await extension.updateWorkerEnabled(workerId, true);
// Update the checkbox UI
const checkbox = document.getElementById(`gpu-${workerId}`);
if (checkbox) {
checkbox.checked = true;
}
}
// Re-query button AFTER updateWorkerEnabled (which may re-render sidebar)
const launchBtn = document.querySelector(`#controls-${workerId} button`);
extension.ui.updateStatusDot(workerId, STATUS_COLORS.PROCESSING_YELLOW, "Launching...", true);
extension.state.setWorkerLaunching(workerId, true);
// Allow 90 seconds for worker to launch (model loading can take time)
setTimeout(() => {
extension.state.setWorkerLaunching(workerId, false);
}, TIMEOUTS.LAUNCH);
if (!launchBtn) {
return;
}
try {
// Disable button immediately
launchBtn.disabled = true;
const result = await extension.api.launchWorker(workerId);
if (result) {
extension.log(`Launched ${worker.name} (PID: ${result.pid})`, "info");
if (result.log_file) {
extension.log(`Log file: ${result.log_file}`, "debug");
}
extension.state.setWorkerManaged(workerId, {
pid: result.pid,
log_file: result.log_file,
started_at: Date.now(),
});
// Update controls immediately to hide launch button and show stop/log buttons
updateWorkerControls(extension, workerId);
setTimeout(() => checkWorkerStatus(extension, worker), TIMEOUTS.STATUS_CHECK);
}
} catch (error) {
// Check if worker was already running
if (error.message && error.message.includes("already running")) {
extension.log(`Worker ${worker.name} is already running`, "info");
updateWorkerControls(extension, workerId);
setTimeout(() => checkWorkerStatus(extension, worker), TIMEOUTS.STATUS_CHECK_DELAY);
} else {
extension.log(`Error launching worker: ${error.message || error}`, "error");
// Re-enable button on error
if (launchBtn) {
launchBtn.disabled = false;
}
}
}
}
export async function stopWorker(extension, workerId) {
const worker = extension.config.workers.find((w) => w.id === workerId);
if (!worker) {
return;
}
const stopBtn = document.querySelectorAll(`#controls-${workerId} button`)[1];
// Provide immediate feedback
if (stopBtn) {
stopBtn.disabled = true;
stopBtn.textContent = "Stopping...";
setButtonClass(stopBtn, "btn--working");
}
try {
const result = await extension.api.stopWorker(workerId);
if (result) {
extension.log(`Stopped worker: ${result.message}`, "info");
extension.state.setWorkerManaged(workerId, null);
// Immediately update status to offline
extension.ui.updateStatusDot(workerId, STATUS_COLORS.OFFLINE_RED, "Offline");
extension.state.setWorkerStatus(workerId, { online: false });
// Flash success feedback
if (stopBtn) {
setButtonClass(stopBtn, "btn--success");
stopBtn.textContent = "Stopped!";
setTimeout(() => {
updateWorkerControls(extension, workerId);
}, TIMEOUTS.FLASH_SHORT);
}
// Verify status after a short delay
setTimeout(() => checkWorkerStatus(extension, worker), TIMEOUTS.STATUS_CHECK);
} else {
// `result` can be falsy here, so don't dereference it blindly.
const message = result?.message || "unknown error";
extension.log(`Failed to stop worker: ${message}`, "error");
// Flash error feedback
if (stopBtn) {
setButtonClass(stopBtn, "btn--error");
stopBtn.textContent = message.includes("already stopped") ? "Not Running" : "Failed";
// If already stopped, update status immediately
if (message.includes("already stopped")) {
extension.ui.updateStatusDot(workerId, STATUS_COLORS.OFFLINE_RED, "Offline");
extension.state.setWorkerStatus(workerId, { online: false });
}
setTimeout(() => {
updateWorkerControls(extension, workerId);
}, TIMEOUTS.FLASH_MEDIUM);
}
}
} catch (error) {
extension.log(`Error stopping worker: ${error}`, "error");
// Reset button on error
if (stopBtn) {
setButtonClass(stopBtn, "btn--error");
stopBtn.textContent = "Error";
setTimeout(() => {
updateWorkerControls(extension, workerId);
}, TIMEOUTS.FLASH_MEDIUM);
}
}
}
export async function clearLaunchingFlag(extension, workerId) {
try {
await extension.api.clearLaunchingFlag(workerId);
extension.log(`Cleared launching flag for worker ${workerId}`, "debug");
} catch (error) {
extension.log(`Error clearing launching flag: ${error.message || error}`, "error");
}
}
export async function loadManagedWorkers(extension) {
try {
const result = await extension.api.getManagedWorkers();
// Check for launching workers
for (const [workerId, info] of Object.entries(result.managed_workers)) {
extension.state.setWorkerManaged(workerId, info);
// If worker is marked as launching, add to launchingWorkers set
if (info.launching) {
extension.state.setWorkerLaunching(workerId, true);
extension.log(`Worker ${workerId} is in launching state`, "debug");
}
}
// Update UI for all workers
if (extension.config?.workers) {
extension.config.workers.forEach((w) => updateWorkerControls(extension, w.id));
}
} catch (error) {
extension.log(`Error loading managed workers: ${error}`, "error");
}
}
export function updateWorkerControls(extension, workerId) {
const controlsDiv = document.getElementById(`controls-${workerId}`);
if (!controlsDiv) {
return;
}
const worker = extension.config.workers.find((w) => w.id === workerId);
if (!worker) {
return;
}
// Update button states - buttons are now inside a wrapper div
const launchBtn = document.getElementById(`launch-${workerId}`);
const stopBtn = document.getElementById(`stop-${workerId}`);
const logBtn = document.getElementById(`log-${workerId}`);
if (isRemoteWorker(extension, worker)) {
setButtonVisibility(launchBtn, false);
setButtonVisibility(stopBtn, false);
if (logBtn) {
setButtonVisibility(logBtn, true);
logBtn.disabled = false;
logBtn.textContent = "View Log";
setButtonClass(logBtn, "btn--log");
}
return;
}
// Managed info exists only for workers this UI launched (or reattached to).
const managedInfo = extension.state.getWorker(workerId).managed;
const status = extension.state.getWorkerStatus(workerId);
// Show log button immediately if we have log file info (even if worker is still starting)
if (logBtn) {
const showLog = Boolean(managedInfo?.log_file);
setButtonVisibility(logBtn, showLog);
if (showLog) {
setButtonClass(logBtn, "btn--log");
}
}
if (status?.online || managedInfo) {
// Worker is running or we just launched it
setButtonVisibility(launchBtn, false);
if (managedInfo) {
// Only show stop button if we manage this worker
setButtonVisibility(stopBtn, true);
stopBtn.disabled = false;
stopBtn.textContent = "Stop";
setButtonClass(stopBtn, "btn--stop");
} else {
// Hide stop button for workers launched outside UI
setButtonVisibility(stopBtn, false);
}
} else {
// Worker is not running
setButtonVisibility(launchBtn, true);
launchBtn.disabled = false;
launchBtn.textContent = "Launch";
setButtonClass(launchBtn, "btn--launch");
setButtonVisibility(stopBtn, false);
}
}
export async function viewWorkerLog(extension, workerId, isRemote = false) {
const worker = extension.config.workers.find((w) => w.id === workerId);
const isRemoteLog = isRemote || (worker ? isRemoteWorker(extension, worker) : false);
const managedInfo = extension.state.getWorker(workerId).managed;
if (!isRemoteLog && !managedInfo?.log_file) {
return;
}
const logBtn = document.getElementById(`log-${workerId}`);
// Provide immediate feedback
if (logBtn) {
logBtn.disabled = true;
logBtn.textContent = "Loading...";
setButtonClass(logBtn, "btn--working");
}
try {
const fetchLog = isRemoteLog
? async () => extension.api.getRemoteWorkerLog(workerId, 300)
: async () => extension.api.getWorkerLog(workerId, 1000);
const data = await fetchLog();
// Create modal dialog
extension.ui.showLogModal(extension, workerId, data, fetchLog);
// Restore button
if (logBtn) {
logBtn.disabled = false;
logBtn.textContent = "View Log";
setButtonClass(logBtn, "btn--log");
}
} catch (error) {
extension.log("Error viewing log: " + error.message, "error");
extension.app.extensionManager.toast.add({
severity: "error",
summary: "Error",
detail: `Failed to load log: ${error.message}`,
life: 5000,
});
// Flash error and restore button
if (logBtn) {
setButtonClass(logBtn, "btn--error");
logBtn.textContent = "Error";
setTimeout(() => {
logBtn.disabled = false;
logBtn.textContent = "View Log";
setButtonClass(logBtn, "btn--log");
}, TIMEOUTS.FLASH_LONG);
}
}
}
export async function refreshLog(extension, workerId, silent = false) {
const logContent = document.getElementById("distributed-log-content");
if (!logContent) {
return;
}
try {
const worker = extension.config.workers.find((w) => w.id === workerId);
const isRemoteLog = worker ? isRemoteWorker(extension, worker) : false;
const data = isRemoteLog
? await extension.api.getRemoteWorkerLog(workerId, 300)
: await extension.api.getWorkerLog(workerId, 1000);
// Update content
const shouldAutoScroll = logContent.scrollTop + logContent.clientHeight >= logContent.scrollHeight - 50;
logContent.textContent = data.content;
// Auto-scroll if was at bottom
if (shouldAutoScroll) {
logContent.scrollTop = logContent.scrollHeight;
}
// Only show toast if not in silent mode (manual refresh)
if (!silent) {
extension.app.extensionManager.toast.add({
severity: "success",
summary: "Log Refreshed",
detail: "Log content updated",
life: 2000,
});
}
} catch (error) {
// Only show error toast if not in silent mode
if (!silent) {
extension.app.extensionManager.toast.add({
severity: "error",
summary: "Refresh Failed",
detail: error.message,
life: 3000,
});
}
}
}
export function startLogAutoRefresh(extension, workerId) {
// Stop any existing auto-refresh
stopLogAutoRefresh(extension);
// Refresh every 2 seconds
extension.logAutoRefreshInterval = setInterval(() => {
refreshLog(extension, workerId, true); // silent mode
}, TIMEOUTS.LOG_REFRESH);
}
export function stopLogAutoRefresh(extension) {
if (extension.logAutoRefreshInterval) {
clearInterval(extension.logAutoRefreshInterval);
extension.logAutoRefreshInterval = null;
}
}
export function toggleWorkerExpanded(extension, workerId) {
const gpuDiv = document.querySelector(`[data-worker-id="${workerId}"]`);
const settingsDiv = gpuDiv?.querySelector(`#settings-${workerId}`) || document.getElementById(`settings-${workerId}`);
const settingsArrow = gpuDiv?.querySelector(".settings-arrow");
if (!settingsDiv) {
return;
}
if (extension.state.isWorkerExpanded(workerId)) {
extension.state.setWorkerExpanded(workerId, false);
settingsDiv.classList.remove("expanded");
settingsDiv.style.padding = "0 12px";
settingsDiv.style.marginTop = "0";
settingsDiv.style.marginBottom = "0";
if (settingsArrow) {
settingsArrow.classList.remove("settings-arrow--expanded");
}
} else {
extension.state.setWorkerExpanded(workerId, true);
settingsDiv.classList.add("expanded");
settingsDiv.style.padding = "12px";
settingsDiv.style.marginTop = "8px";
settingsDiv.style.marginBottom = "8px";
if (settingsArrow) {
settingsArrow.classList.add("settings-arrow--expanded");
}
}
}
================================================
FILE: web/workerSettings.js
================================================
import { renderSidebarContent } from './sidebarRenderer.js';
import { generateUUID } from './constants.js';
import { parseHostInput } from './urlUtils.js';
import { toggleWorkerExpanded } from './workerLifecycle.js';
const WORKERS_CHANGED_EVENT = "distributed:workers-changed";
function emitWorkersChanged(extension) {
if (typeof window === "undefined" || typeof window.dispatchEvent !== "function") {
return;
}
window.dispatchEvent(new CustomEvent(WORKERS_CHANGED_EVENT, {
detail: { workers: extension.config?.workers || [] },
}));
}
export function isRemoteWorker(extension, worker) {
const workerType = String(worker?.type || "").toLowerCase();
// Explicit type always wins over host heuristics.
if (workerType === "cloud" || workerType === "remote") {
return true;
}
if (workerType === "local") {
return false;
}
// Otherwise check by host (backward compatibility)
const parsed = parseHostInput(worker?.host || window.location.hostname);
const host = String(parsed.host || window.location.hostname || "").toLowerCase();
const currentHost = String(parseHostInput(window.location.hostname).host || window.location.hostname || "").toLowerCase();
const localHosts = new Set(["", "localhost", "127.0.0.1", "::1", "[::1]", "0.0.0.0"]);
return !(localHosts.has(host) || host === currentHost);
}
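// Examples (editor's illustration, with the UI served from http://localhost:8188):
// isRemoteWorker(ext, { type: "cloud" })        → true  (explicit type wins)
// isRemoteWorker(ext, { host: "" })             → false (falls back to localhost)
// isRemoteWorker(ext, { host: "192.168.1.40" }) → true  (host differs from page)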
export function isCloudWorker(extension, worker) {
return worker.type === "cloud";
}
export async function saveWorkerSettings(extension, workerId) {
const worker = extension.config.workers.find((w) => w.id === workerId);
if (!worker) {
return;
}
// Get form values
const name = document.getElementById(`name-${workerId}`).value;
const workerType = document.getElementById(`worker-type-${workerId}`).value;
const isRemote = workerType === "remote" || workerType === "cloud";
const isCloud = workerType === "cloud";
const rawHost = isRemote ? document.getElementById(`host-${workerId}`).value : window.location.hostname;
const parsedHost = isRemote ? parseHostInput(rawHost) : { host: window.location.hostname, port: null };
const host = isRemote ? parsedHost.host : window.location.hostname;
const hostTrimmed = (host || "").trim();
let port = parseInt(document.getElementById(`port-${workerId}`).value);
const cudaDevice = isRemote ? undefined : parseInt(document.getElementById(`cuda-${workerId}`).value);
const extraArgs = isRemote ? undefined : document.getElementById(`args-${workerId}`).value;
if (isRemote && Number.isFinite(parsedHost.port)) {
port = parsedHost.port;
}
// Validate
if (!name.trim()) {
extension.app.extensionManager.toast.add({
severity: "error",
summary: "Validation Error",
detail: "Worker name is required",
life: 3000,
});
return;
}
if ((workerType === "remote" || workerType === "cloud") && !hostTrimmed) {
extension.app.extensionManager.toast.add({
severity: "error",
summary: "Validation Error",
detail: "Host is required for remote workers",
life: 3000,
});
return;
}
if (!isCloud && (isNaN(port) || port < 1 || port > 65535)) {
extension.app.extensionManager.toast.add({
severity: "error",
summary: "Validation Error",
detail: "Port must be between 1 and 65535",
life: 3000,
});
return;
}
// Check for port conflicts
// Remote workers can reuse ports, but local workers cannot share ports with each other or master
if (!isRemote) {
// Check if port conflicts with master
const masterPort = parseInt(window.location.port) || (window.location.protocol === "https:" ? 443 : 80);
if (port === masterPort) {
extension.app.extensionManager.toast.add({
severity: "error",
summary: "Port Conflict",
detail: `Port ${port} is already in use by the master server`,
life: 3000,
});
return;
}
// Check if port conflicts with other local workers
const localPortConflict = extension.config.workers.some(
(w) => w.id !== workerId && w.port === port && !w.host // local workers have no host or host is null
);
if (localPortConflict) {
extension.app.extensionManager.toast.add({
severity: "error",
summary: "Port Conflict",
detail: `Port ${port} is already in use by another local worker`,
life: 3000,
});
return;
}
} else {
// For remote workers, only check conflicts with other workers on the same host
const sameHostConflict = extension.config.workers.some((w) => w.id !== workerId && w.port === port && w.host === hostTrimmed);
if (sameHostConflict) {
extension.app.extensionManager.toast.add({
severity: "error",
summary: "Port Conflict",
detail: `Port ${port} is already in use by another worker on ${host}`,
life: 3000,
});
return;
}
}
const wasUnconfiguredRemote =
(worker.type === "remote" || worker.type === "cloud") &&
(!String(worker.host || "").trim()) &&
!worker.enabled;
const nextEnabled = isRemote && hostTrimmed && wasUnconfiguredRemote ? true : worker.enabled;
try {
await extension.api.updateWorker(workerId, {
name: name.trim(),
type: workerType,
host: isRemote ? hostTrimmed : null,
port,
cuda_device: isRemote ? null : cudaDevice,
extra_args: isRemote ? null : extraArgs ? extraArgs.trim() : "",
enabled: nextEnabled,
});
// Update local config
worker.name = name.trim();
worker.type = workerType;
if (isRemote) {
worker.host = hostTrimmed;
delete worker.cuda_device;
delete worker.extra_args;
} else {
delete worker.host;
worker.cuda_device = cudaDevice;
worker.extra_args = extraArgs ? extraArgs.trim() : "";
}
worker.port = port;
worker.enabled = nextEnabled;
// Sync to state
extension.state.updateWorker(workerId, { enabled: nextEnabled });
emitWorkersChanged(extension);
extension.app.extensionManager.toast.add({
severity: "success",
summary: "Settings Saved",
detail: nextEnabled && wasUnconfiguredRemote
? `Worker ${name} configured and enabled`
: `Worker ${name} settings updated`,
life: 3000,
});
// Keep post-save card height consistent with the default collapsed layout.
extension.state.setWorkerExpanded(workerId, false);
// Refresh the UI
if (extension.panelElement) {
renderSidebarContent(extension, extension.panelElement);
}
} catch (error) {
extension.app.extensionManager.toast.add({
severity: "error",
summary: "Save Failed",
detail: error.message,
life: 5000,
});
}
}
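// For reference, the updateWorker payload assembled above has this shape
// (values are an example for a local worker, not real config):
//
//   {
//     name: "Worker 2",
//     type: "local",
//     host: null,        // remote/cloud workers carry the trimmed hostname instead
//     port: 8190,
//     cuda_device: 1,    // null for remote/cloud workers
//     extra_args: "",    // null for remote/cloud workers
//     enabled: true,
//   }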
export function cancelWorkerSettings(extension, workerId) {
// Collapse the settings panel
toggleWorkerExpanded(extension, workerId);
// Reset form values to original
const worker = extension.config.workers.find((w) => w.id === workerId);
if (worker) {
document.getElementById(`name-${workerId}`).value = worker.name;
document.getElementById(`host-${workerId}`).value = worker.host || "";
document.getElementById(`port-${workerId}`).value = worker.port;
document.getElementById(`cuda-${workerId}`).value = worker.cuda_device || 0;
document.getElementById(`args-${workerId}`).value = worker.extra_args || "";
// Reset the worker type select to match the saved configuration
const typeSelect = document.getElementById(`worker-type-${workerId}`);
if (typeSelect) {
typeSelect.value = worker.type || (isRemoteWorker(extension, worker) ? "remote" : "local");
}
}
}
export async function deleteWorker(extension, workerId) {
const worker = extension.config.workers.find((w) => w.id === workerId);
if (!worker) {
return;
}
// Confirm deletion
if (!confirm(`Are you sure you want to delete worker "${worker.name}"?`)) {
return;
}
try {
await extension.api.deleteWorker(workerId);
// Remove from local config
const index = extension.config.workers.findIndex((w) => w.id === workerId);
if (index !== -1) {
extension.config.workers.splice(index, 1);
}
emitWorkersChanged(extension);
extension.app.extensionManager.toast.add({
severity: "success",
summary: "Worker Deleted",
detail: `Worker ${worker.name} has been removed`,
life: 3000,
});
// Refresh the UI
if (extension.panelElement) {
renderSidebarContent(extension, extension.panelElement);
}
} catch (error) {
extension.app.extensionManager.toast.add({
severity: "error",
summary: "Delete Failed",
detail: error.message,
life: 5000,
});
}
}
export async function addNewWorker(extension) {
const toInt = (value) => {
const parsed = Number.parseInt(value, 10);
return Number.isFinite(parsed) ? parsed : null;
};
const totalCudaDevices = toInt(extension.cudaDeviceCount);
const masterCudaDevice = toInt(extension.masterCudaDevice ?? extension.config?.master?.cuda_device);
const localWorkers = (extension.config?.workers || []).filter((w) => !isRemoteWorker(extension, w));
let selectedCudaDevice = extension.config.workers.length;
let fallbackToRemote = false;
if (totalCudaDevices !== null && totalCudaDevices > 0) {
const usedCudaDevices = new Set();
for (const worker of localWorkers) {
const cudaIdx = toInt(worker.cuda_device);
if (cudaIdx !== null) {
usedCudaDevices.add(cudaIdx);
}
}
if (masterCudaDevice !== null) {
usedCudaDevices.add(masterCudaDevice);
}
const availableCudaDevices = [];
for (let i = 0; i < totalCudaDevices; i++) {
if (!usedCudaDevices.has(i)) {
availableCudaDevices.push(i);
}
}
if (availableCudaDevices.length === 0) {
fallbackToRemote = true;
selectedCudaDevice = null;
} else {
selectedCudaDevice = availableCudaDevices[0];
}
}
// Generate new worker ID using UUID (fallback for non-secure contexts)
const newId = generateUUID();
// Find next available port
const usedPorts = extension.config.workers.map((w) => w.port);
let nextPort = 8189;
while (usedPorts.includes(nextPort)) {
nextPort++;
}
// Create new worker object
const newWorker = {
id: newId,
name: `Worker ${extension.config.workers.length + 1}`,
port: nextPort,
type: fallbackToRemote ? "remote" : "local",
host: fallbackToRemote ? "" : null,
cuda_device: selectedCudaDevice,
enabled: fallbackToRemote ? false : true, // Remote fallback starts disabled until configured
extra_args: "",
};
// Add to config
extension.config.workers.push(newWorker);
// Save immediately
try {
await extension.api.updateWorker(newId, {
name: newWorker.name,
port: newWorker.port,
cuda_device: newWorker.cuda_device,
extra_args: newWorker.extra_args,
enabled: newWorker.enabled,
host: newWorker.host,
type: newWorker.type,
});
// Sync to state
extension.state.updateWorker(newId, { enabled: newWorker.enabled });
emitWorkersChanged(extension);
extension.app.extensionManager.toast.add({
severity: fallbackToRemote ? "warn" : "success",
summary: fallbackToRemote ? "Remote Worker Added" : "Worker Added",
detail: fallbackToRemote
? `No local GPU available, so a disabled remote worker was added on port ${nextPort}. Configure host and enable it.`
: `New worker created on port ${nextPort}`,
life: fallbackToRemote ? 5000 : 3000,
});
// Refresh UI and expand the new worker
extension.state.setWorkerExpanded(newId, true);
if (extension.panelElement) {
renderSidebarContent(extension, extension.panelElement);
}
} catch (error) {
extension.app.extensionManager.toast.add({
severity: "error",
summary: "Failed to Add Worker",
detail: error.message,
life: 5000,
});
}
}
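// Worked example of the CUDA selection above, assuming a 2-GPU box
// (cudaDeviceCount = 2) with the master pinned to device 0 and one existing
// local worker on device 1:
//
//   usedCudaDevices = {0, 1} -> availableCudaDevices = []
//   -> fallbackToRemote = true, selectedCudaDevice = null
//   -> the new entry is a disabled "remote" worker awaiting a host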
================================================
FILE: web/workerUtils.js
================================================
import { TIMEOUTS, ENDPOINTS } from './constants.js';
import { checkAllWorkerStatuses, getWorkerUrl } from './workerLifecycle.js';
export async function handleWorkerOperation(extension, button, operation, successText, errorText) {
const originalText = button.textContent;
const originalStyle = button.style.cssText;
const originalClasses = Array.from(button.classList);
const stateClasses = ["btn--working", "btn--success", "btn--error"];
const setButtonStateClass = (className) => {
button.classList.remove(...stateClasses);
if (className) {
button.classList.add(className);
}
};
button.textContent = operation.loadingText;
button.disabled = true;
setButtonStateClass("btn--working");
try {
const urlsToProcess = extension.enabledWorkers.map(w => ({
name: w.name,
url: getWorkerUrl(extension, w)
}));
if (urlsToProcess.length === 0) {
button.textContent = "No Workers";
setButtonStateClass("btn--error");
setTimeout(() => {
button.textContent = originalText;
button.style.cssText = originalStyle;
button.classList.remove(...stateClasses);
button.classList.add(...originalClasses);
button.disabled = false;
}, TIMEOUTS.BUTTON_RESET);
return;
}
const promises = urlsToProcess.map(target =>
fetch(`${target.url}${operation.endpoint}`, {
method: 'POST',
mode: 'cors'
})
.then(response => ({ ok: response.ok, name: target.name }))
.catch(() => ({ ok: false, name: target.name }))
);
const results = await Promise.all(promises);
const failures = results.filter(r => !r.ok);
if (failures.length === 0) {
button.textContent = successText;
setButtonStateClass("btn--success");
if (operation.onSuccess) operation.onSuccess();
} else {
button.textContent = errorText;
setButtonStateClass("btn--error");
extension.log(`${operation.name} failed on: ${failures.map(f => f.name).join(", ")}`, "error");
}
setTimeout(() => {
button.textContent = originalText;
button.style.cssText = originalStyle;
button.classList.remove(...stateClasses);
button.classList.add(...originalClasses);
}, TIMEOUTS.BUTTON_RESET);
} finally {
button.disabled = false;
}
}
export async function handleInterruptWorkers(extension, button) {
return handleWorkerOperation(extension, button, {
name: "Interrupt",
endpoint: ENDPOINTS.INTERRUPT,
loadingText: "Interrupting...",
onSuccess: () => setTimeout(() => checkAllWorkerStatuses(extension), TIMEOUTS.POST_ACTION_DELAY)
}, "Interrupted!", "Error! See Console");
}
export async function handleClearMemory(extension, button) {
return handleWorkerOperation(extension, button, {
name: "Clear memory",
endpoint: ENDPOINTS.CLEAR_MEMORY,
loadingText: "Clearing..."
}, "Success!", "Error! See Console");
}
export function findNodesByClass(apiPrompt, className) {
return Object.entries(apiPrompt)
.filter(([, nodeData]) => nodeData.class_type === className)
.map(([nodeId, nodeData]) => ({ id: nodeId, data: nodeData }));
}
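// Usage sketch: apiPrompt is the standard ComfyUI API-format prompt, a mapping
// of node id -> { class_type, inputs, ... } (the ids below are made up):
//
//   const prompt = {
//     "3": { class_type: "KSampler", inputs: {} },
//     "9": { class_type: "SaveImage", inputs: {} },
//   };
//   findNodesByClass(prompt, "KSampler");
//   // -> [{ id: "3", data: { class_type: "KSampler", inputs: {} } }]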
export function applyProbeResultToWorkerDot(workerId, probeResult) {
const dot = document.getElementById(`status-${workerId}`);
if (!dot) {
return;
}
dot.classList.remove(
'worker-status--online',
'worker-status--offline',
'worker-status--processing',
'worker-status--unknown',
'status-pulsing',
);
if (!probeResult || !probeResult.ok) {
dot.classList.add('worker-status--offline');
dot.title = 'Offline - Cannot connect';
return;
}
if ((probeResult.queueRemaining || 0) > 0) {
dot.classList.add('worker-status--processing');
dot.title = `Processing (${probeResult.queueRemaining} queued)`;
return;
}
dot.classList.add('worker-status--online');
dot.title = 'Online - Idle';
}
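// Example probe results and the dot state they map to (shapes inferred from
// the checks above, not from a documented schema):
//
//   applyProbeResultToWorkerDot(id, null);                            // offline
//   applyProbeResultToWorkerDot(id, { ok: false });                   // offline
//   applyProbeResultToWorkerDot(id, { ok: true, queueRemaining: 2 }); // processing
//   applyProbeResultToWorkerDot(id, { ok: true, queueRemaining: 0 }); // online, idle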
================================================
FILE: workers/__init__.py
================================================
from .process_manager import WorkerProcessManager
_worker_manager: WorkerProcessManager | None = None
def get_worker_manager() -> WorkerProcessManager:
global _worker_manager
if _worker_manager is None:
_worker_manager = WorkerProcessManager()
_worker_manager.queues = {}
return _worker_manager
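# Usage sketch: every caller in the master process shares one manager instance
# (the calls below mirror the accessor above; nothing new is introduced):
#
#   wm = get_worker_manager()
#   assert wm is get_worker_manager()  # repeat calls return the same singleton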
================================================
FILE: workers/detection.py
================================================
import os
import platform
import uuid
import aiohttp
from ..utils.network import normalize_host, get_client_session
from ..utils.logging import debug_log
async def is_local_worker(worker_config):
"""Check if a worker is running on the same machine as the master."""
host = normalize_host(worker_config.get('host', 'localhost')) or 'localhost'
if host in ['localhost', '127.0.0.1', '0.0.0.0', ''] or worker_config.get('type') == 'local':
return True
# For cloud workers, check if on same physical host
if worker_config.get('type') == 'cloud':
return await is_same_physical_host(worker_config)
return False
async def is_same_physical_host(worker_config):
"""Compare machine IDs to determine if worker is on same physical host."""
try:
# Get master machine ID
master_machine_id = get_machine_id()
# Fetch worker's machine ID via API
host = normalize_host(worker_config.get('host', 'localhost')) or 'localhost'
port = worker_config.get('port', 8188)
session = await get_client_session()
async with session.get(
f"http://{host}:{port}/distributed/system_info",
timeout=aiohttp.ClientTimeout(total=5)
) as resp:
if resp.status == 200:
data = await resp.json()
worker_machine_id = data.get('machine_id')
return worker_machine_id == master_machine_id
else:
debug_log(f"Failed to get system info from worker: HTTP {resp.status}")
return False
except Exception as e:
debug_log(f"Error checking same physical host: {e}")
return False
def get_machine_id():
"""Get a unique identifier for this machine."""
# Try multiple methods to get a stable machine ID
try:
# Method 1: hardware (MAC) address as a 48-bit integer from uuid.getnode()
return str(uuid.getnode())
except Exception:
try:
# Method 2: Platform + hostname
import socket
return f"{platform.machine()}_{socket.gethostname()}"
except Exception:
# Fallback
return platform.machine()
def is_docker_environment():
"""Check if running inside Docker container."""
return (os.path.exists('/.dockerenv') or
os.environ.get('DOCKER_CONTAINER', False) or
'docker' in platform.node().lower())
def is_runpod_environment():
"""Check if running in Runpod environment."""
return (os.environ.get('RUNPOD_POD_ID') is not None or
os.environ.get('RUNPOD_API_KEY') is not None)
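# Illustrative call pattern (the config keys match those read above; host and
# port values are examples only, and is_local_worker must be awaited):
#
#   cfg = {"host": "192.168.1.20", "port": 8189, "type": "cloud"}
#   same_box = await is_local_worker(cfg)   # cloud type -> machine-id comparison
#   in_docker = is_docker_environment()
#   on_runpod = is_runpod_environment()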
================================================
FILE: workers/process/__init__.py
================================================
from .launch_builder import LaunchCommandBuilder
from .lifecycle import ProcessLifecycle
from .persistence import ProcessPersistence
from .root_discovery import ComfyRootDiscovery
__all__ = [
"ComfyRootDiscovery",
"LaunchCommandBuilder",
"ProcessLifecycle",
"ProcessPersistence",
]
================================================
FILE: workers/process/launch_builder.py
================================================
import glob
import os
import shlex
import shutil
from ...utils.logging import debug_log
from ...utils.process import get_python_executable
class LaunchCommandBuilder:
"""Build command-lines for launching worker ComfyUI processes."""
def _extend_arg(self, cmd, flag, value):
if value in (None, "", [], ()):
return
cmd.extend([flag, str(value)])
def _extend_grouped_args(self, cmd, flag, values):
for group in values or []:
flattened = [str(item) for item in group if item]
if flattened:
cmd.append(flag)
cmd.extend(flattened)
def _get_runtime_args(self):
try:
from comfy.cli_args import args
return args
except Exception as exc:
debug_log(f"Could not read current ComfyUI CLI args for worker launch: {exc}")
return None
def _build_runtime_launch_args(self):
args = self._get_runtime_args()
if args is None:
return []
inherited = []
self._extend_arg(inherited, "--listen", getattr(args, "listen", None))
self._extend_arg(inherited, "--base-directory", getattr(args, "base_directory", None))
self._extend_arg(inherited, "--temp-directory", getattr(args, "temp_directory", None))
self._extend_arg(inherited, "--input-directory", getattr(args, "input_directory", None))
self._extend_arg(inherited, "--output-directory", getattr(args, "output_directory", None))
self._extend_arg(inherited, "--user-directory", getattr(args, "user_directory", None))
self._extend_arg(inherited, "--front-end-root", getattr(args, "front_end_root", None))
self._extend_grouped_args(
inherited,
"--extra-model-paths-config",
getattr(args, "extra_model_paths_config", None),
)
if getattr(args, "enable_manager", False):
inherited.append("--enable-manager")
if getattr(args, "disable_manager_ui", False):
inherited.append("--disable-manager-ui")
if getattr(args, "enable_manager_legacy_ui", False):
inherited.append("--enable-manager-legacy-ui")
if getattr(args, "windows_standalone_build", False):
inherited.append("--windows-standalone-build")
if getattr(args, "log_stdout", False):
inherited.append("--log-stdout")
verbose = getattr(args, "verbose", None)
if verbose and verbose != "INFO":
inherited.extend(["--verbose", str(verbose)])
return inherited
def _find_windows_terminal(self):
"""Find Windows Terminal executable."""
possible_paths = [
os.path.expandvars(r"%LOCALAPPDATA%\Microsoft\WindowsApps\wt.exe"),
os.path.expandvars(r"%PROGRAMFILES%\WindowsApps\Microsoft.WindowsTerminal_*\wt.exe"),
"wt.exe",
]
for path in possible_paths:
if os.path.exists(path):
return path
if "*" in path:
matches = glob.glob(path)
if matches:
return matches[0]
wt_path = shutil.which("wt")
if wt_path:
return wt_path
return None
def build_launch_command(self, worker_config, comfy_root):
"""Build the command to launch a worker."""
main_py = os.path.join(comfy_root, "main.py")
if os.path.exists(main_py):
cmd = [
get_python_executable(),
main_py,
]
cmd.extend(self._build_runtime_launch_args())
cmd.extend(["--port", str(worker_config["port"])])
current_args = self._get_runtime_args()
current_cors = getattr(current_args, "enable_cors_header", None) if current_args else None
cmd.append("--enable-cors-header")
if current_cors is not None:
cmd.append(str(current_cors))
if "--disable-auto-launch" not in cmd:
cmd.append("--disable-auto-launch")
debug_log(f"Using main.py: {main_py}")
else:
error_msg = f"Could not find main.py in {comfy_root}\n"
error_msg += f"Searched for: {main_py}\n"
error_msg += f"Directory contents of {comfy_root}:\n"
try:
if os.path.exists(comfy_root):
files = os.listdir(comfy_root)[:20]
error_msg += " " + "\n ".join(files)
if len(os.listdir(comfy_root)) > 20:
error_msg += f"\n ... and {len(os.listdir(comfy_root)) - 20} more files"
else:
error_msg += f" Directory {comfy_root} does not exist!"
except Exception as exc:
error_msg += f" Error listing directory: {exc}"
error_msg += "\n\nPossible solutions:\n"
error_msg += "1. Check if ComfyUI is installed in a different location\n"
error_msg += "2. For Docker: ComfyUI might be in /ComfyUI or /app\n"
error_msg += "3. Ensure the custom node is installed in the correct location\n"
raise RuntimeError(error_msg)
if worker_config.get("extra_args"):
raw_args = worker_config["extra_args"].strip()
if raw_args:
extra_args_list = shlex.split(raw_args)
forbidden_chars = set(";|>&<`$()[]{}*!?")
for arg in extra_args_list:
if any(char in forbidden_chars for char in arg):
forbidden = "".join(forbidden_chars)
raise ValueError(f"Invalid characters in extra_args: {arg}. Forbidden: {forbidden}")
cmd.extend(extra_args_list)
return cmd
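# Sketch of a resulting command for a worker on port 8189. Inherited flags
# depend on the master's own CLI args, so only the guaranteed parts are shown:
#
#   [<python>, "<comfy_root>/main.py", <inherited flags...>,
#    "--port", "8189", "--enable-cors-header", "--disable-auto-launch",
#    <validated extra_args...>]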
================================================
FILE: workers/process/lifecycle.py
================================================
import os
import platform
import signal
import subprocess
import time
from ...utils.config import load_config, save_config
from ...utils.constants import PROCESS_TERMINATION_TIMEOUT, PROCESS_WAIT_TIMEOUT, WORKER_CHECK_INTERVAL
from ...utils.logging import debug_log, log
from ...utils.process import get_python_executable, is_process_alive, terminate_process
try:
import psutil
PSUTIL_AVAILABLE = True
except ImportError:
log("psutil not available, using fallback process management")
PSUTIL_AVAILABLE = False
class ProcessLifecycle:
"""Worker process lifecycle operations operating on manager-owned state."""
def __init__(self, manager):
self._manager = manager
def launch_worker(self, worker_config, show_window=False):
"""Launch a worker process with logging."""
_ = show_window # Kept for API compatibility.
comfy_root = self._manager.find_comfy_root()
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(worker_config.get("cuda_device", 0))
env["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
env["COMFYUI_MASTER_PID"] = str(os.getpid())
env["COMFYUI_IS_WORKER"] = "1"
cmd = self._manager.build_launch_command(worker_config, comfy_root)
cwd = comfy_root
log_dir = os.path.join(comfy_root, "logs", "workers")
os.makedirs(log_dir, exist_ok=True)
date_stamp = time.strftime("%Y%m%d")
worker_name = worker_config.get("name", f"Worker{worker_config['id']}")
safe_name = "".join(char if char.isalnum() or char in ("-", "_") else "_" for char in worker_name)
log_file = os.path.join(log_dir, f"{safe_name}_{date_stamp}.log")
with open(log_file, "a", encoding="utf-8") as log_handle:
log_handle.write(f"\n\n{'=' * 50}\n")
log_handle.write("=== ComfyUI Worker Session Started ===\n")
log_handle.write(f"Worker: {worker_name}\n")
log_handle.write(f"Port: {worker_config['port']}\n")
log_handle.write(f"CUDA Device: {worker_config.get('cuda_device', 0)}\n")
log_handle.write(f"Started: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
log_handle.write(f"Command: {' '.join(cmd)}\n")
config = load_config()
stop_on_master_exit = config.get("settings", {}).get("stop_workers_on_master_exit", True)
if stop_on_master_exit:
log_handle.write("Note: Worker will stop when master shuts down\n")
else:
log_handle.write("Note: Worker will continue running after master shuts down\n")
log_handle.write("=" * 30 + "\n\n")
log_handle.flush()
if stop_on_master_exit and env.get("COMFYUI_MASTER_PID"):
monitor_script = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"worker_monitor.py",
)
monitored_cmd = [get_python_executable(), monitor_script] + cmd
log_handle.write(f"[Worker Monitor] Monitoring master PID: {env['COMFYUI_MASTER_PID']}\n")
log_handle.flush()
else:
monitored_cmd = cmd
if platform.system() == "Windows":
create_no_window = 0x08000000
process = subprocess.Popen(
monitored_cmd,
env=env,
cwd=cwd,
stdout=log_handle,
stderr=subprocess.STDOUT,
creationflags=create_no_window,
)
else:
process = subprocess.Popen(
monitored_cmd,
env=env,
cwd=cwd,
stdout=log_handle,
stderr=subprocess.STDOUT,
start_new_session=True,
)
worker_id = str(worker_config["id"])
self._manager.processes[worker_id] = {
"pid": process.pid,
"process": process,
"started_at": time.time(),
"config": worker_config,
"log_file": log_file,
"is_monitor": stop_on_master_exit and env.get("COMFYUI_MASTER_PID"),
"launching": True,
}
self._manager.save_processes()
if stop_on_master_exit and env.get("COMFYUI_MASTER_PID"):
debug_log(f"Launched worker {worker_name} via monitor (Monitor PID: {process.pid})")
else:
log(f"Launched worker {worker_name} directly (PID: {process.pid})")
debug_log(f"Log file: {log_file}")
return process.pid
def stop_worker(self, worker_id):
"""Stop a worker process."""
worker_id = str(worker_id)
if worker_id not in self._manager.processes:
return False, "Worker not managed by UI"
proc_info = self._manager.processes[worker_id]
process = proc_info.get("process")
pid = proc_info["pid"]
debug_log(f"Attempting to stop worker {worker_id} (PID: {pid})")
if not process:
try:
debug_log("[Distributed] Stopping restored process (no subprocess object)")
if self._kill_process_tree(pid):
del self._manager.processes[worker_id]
self._manager.save_processes()
debug_log(f"Successfully stopped worker {worker_id} and all child processes")
return True, "Worker stopped"
return False, "Failed to stop worker process"
except Exception as exc:
log(f"[Distributed] Exception during stop: {exc}")
return False, f"Error stopping worker: {str(exc)}"
if process.poll() is not None:
log(f"[Distributed] Worker {worker_id} already stopped")
del self._manager.processes[worker_id]
self._manager.save_processes()
return False, "Worker already stopped"
try:
debug_log(f"Using process tree kill for worker {worker_id}")
if self._kill_process_tree(pid):
del self._manager.processes[worker_id]
self._manager.save_processes()
debug_log(f"Successfully stopped worker {worker_id} and all child processes")
return True, "Worker stopped"
log("[Distributed] Process tree kill failed, trying normal termination")
terminate_process(process, timeout=PROCESS_TERMINATION_TIMEOUT)
del self._manager.processes[worker_id]
self._manager.save_processes()
return True, "Worker stopped (fallback)"
except Exception as exc:
log(f"[Distributed] Exception during stop: {exc}")
return False, f"Error stopping worker: {str(exc)}"
def get_managed_workers(self):
"""Get list of workers managed by this process."""
managed = {}
for worker_id, proc_info in list(self._manager.processes.items()):
is_running, _ = self._check_worker_process(worker_id, proc_info)
if is_running:
managed[worker_id] = {
"pid": proc_info["pid"],
"started_at": proc_info["started_at"],
"log_file": proc_info.get("log_file"),
"launching": proc_info.get("launching", False),
}
else:
del self._manager.processes[worker_id]
return managed
def cleanup_all(self):
"""Stop all managed workers (called on shutdown)."""
for worker_id in list(self._manager.processes.keys()):
try:
self.stop_worker(worker_id)
except Exception as exc:
log(f"[Distributed] Error stopping worker {worker_id}: {exc}")
config = load_config()
config["managed_processes"] = {}
save_config(config)
def _is_process_running(self, pid):
"""Check if a process with given PID is running."""
return is_process_alive(pid)
def _check_worker_process(self, worker_id, proc_info):
"""Check if a worker process is still running and return status."""
_ = worker_id # Signature retained for compatibility with existing callers.
process = proc_info.get("process")
pid = proc_info.get("pid")
if process:
return process.poll() is None, True
if pid:
return self._is_process_running(pid), False
return False, False
def _kill_process_tree(self, pid):
"""Kill a process and all its children."""
if PSUTIL_AVAILABLE:
try:
parent = psutil.Process(pid)
children = parent.children(recursive=True)
debug_log(f"Killing process tree for PID {pid} ({parent.name()})")
for child in children:
debug_log(f" - Child PID {child.pid} ({child.name()})")
for child in children:
try:
debug_log(f"Terminating child {child.pid}")
child.terminate()
except psutil.NoSuchProcess:
pass
_, alive = psutil.wait_procs(children, timeout=PROCESS_WAIT_TIMEOUT)
for child in alive:
try:
debug_log(f"Force killing child {child.pid}")
child.kill()
except psutil.NoSuchProcess:
pass
try:
debug_log(f"Terminating parent {pid}")
parent.terminate()
parent.wait(timeout=PROCESS_WAIT_TIMEOUT)
except psutil.TimeoutExpired:
debug_log(f"Force killing parent {pid}")
parent.kill()
except psutil.NoSuchProcess:
debug_log(f"Parent process {pid} already gone")
return True
except psutil.NoSuchProcess:
debug_log(f"Process {pid} does not exist")
return False
except Exception as exc:
debug_log(f"Error killing process tree: {exc}")
debug_log("[Distributed] Using OS commands to kill process tree")
if platform.system() == "Windows":
try:
result = subprocess.run(
["wmic", "process", "where", f"ParentProcessId={pid}", "get", "ProcessId"],
capture_output=True,
text=True,
)
if result.returncode == 0:
lines = result.stdout.strip().split("\n")[1:]
child_pids = [line.strip() for line in lines if line.strip().isdigit()]
debug_log(f"[Distributed] Found child processes: {child_pids}")
for child_pid in child_pids:
try:
subprocess.run(
["taskkill", "/F", "/PID", child_pid],
capture_output=True,
check=False,
)
except (FileNotFoundError, OSError) as exc:
debug_log(f"[Distributed] Warning: taskkill failed for PID {child_pid}: {exc}")
result = subprocess.run(
["taskkill", "/F", "/PID", str(pid), "/T"],
capture_output=True,
text=True,
)
debug_log(f"[Distributed] Taskkill result: {result.stdout.strip()}")
return result.returncode == 0
except Exception as exc:
log(f"[Distributed] Error with taskkill: {exc}")
return False
try:
subprocess.run(["pkill", "-TERM", "-P", str(pid)], check=False)
time.sleep(WORKER_CHECK_INTERVAL)
subprocess.run(["pkill", "-KILL", "-P", str(pid)], check=False)
os.kill(pid, signal.SIGKILL)
return True
except Exception as exc:
log(f"[Distributed] Error killing process tree for PID {pid}: {exc}")
return False
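# Note on the kill strategy above: _kill_process_tree prefers psutil
# (terminate the children, wait, force-kill stragglers, then the parent) and
# only falls back to OS tools (taskkill /T on Windows, pkill -P plus SIGKILL
# elsewhere) when psutil is unavailable or raises.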
================================================
FILE: workers/process/persistence.py
================================================
from ...utils.config import load_config, save_config
from ...utils.logging import debug_log
class ProcessPersistence:
"""Persist and restore manager-owned worker process metadata."""
def __init__(self, manager):
self._manager = manager
def load_processes(self):
"""Load persisted process information from config."""
config = load_config()
managed_processes = config.get("managed_processes", {})
for worker_id, proc_info in managed_processes.items():
pid = proc_info.get("pid")
if pid and self._manager._is_process_running(pid):
self._manager.processes[worker_id] = {
"pid": pid,
"process": None,
"started_at": proc_info.get("started_at"),
"config": proc_info.get("config"),
"log_file": proc_info.get("log_file"),
}
debug_log(f"[Distributed] Restored worker {worker_id} (PID: {pid})")
elif pid:
debug_log(f"[Distributed] Worker {worker_id} (PID: {pid}) is no longer running")
def save_processes(self):
"""Save process information to config."""
config = load_config()
managed_processes = {}
for worker_id, proc_info in self._manager.processes.items():
is_running, _ = self._manager._check_worker_process(worker_id, proc_info)
if not is_running:
continue
managed_processes[worker_id] = {
"pid": proc_info["pid"],
"started_at": proc_info["started_at"],
"config": proc_info["config"],
"log_file": proc_info.get("log_file"),
"launching": proc_info.get("launching", False),
}
config["managed_processes"] = managed_processes
save_config(config)
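# The persisted JSON shape under config["managed_processes"] looks like this
# (values are illustrative):
#
#   {
#     "worker-uuid": {
#       "pid": 12345,
#       "started_at": 1700000000.0,
#       "config": {...},   # the worker_config used at launch
#       "log_file": "logs/workers/Worker_1_20240101.log",
#       "launching": false
#     }
#   }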
================================================
FILE: workers/process/root_discovery.py
================================================
import os
import sys
from ...utils.logging import debug_log, log
class ComfyRootDiscovery:
"""Resolve the ComfyUI root directory across local and container layouts."""
def _find_root_from_loaded_modules(self):
"""Use already-imported ComfyUI modules to locate the runtime root."""
for module_name in ("server", "folder_paths", "main"):
module = sys.modules.get(module_name)
module_file = getattr(module, "__file__", None)
if not module_file:
continue
candidate = os.path.dirname(os.path.abspath(module_file))
if os.path.exists(os.path.join(candidate, "main.py")):
debug_log(f"Found ComfyUI root via loaded module {module_name}: {candidate}")
return candidate
return None
def find_comfy_root(self):
# Start from current file location.
current_dir = os.path.dirname(os.path.abspath(__file__))
potential_root = os.path.dirname(os.path.dirname(current_dir))
# Method 1: Check for environment variable override.
env_root = os.environ.get("COMFYUI_ROOT")
if env_root and os.path.exists(os.path.join(env_root, "main.py")):
debug_log(f"Found ComfyUI root via COMFYUI_ROOT environment variable: {env_root}")
return env_root
# Method 2: Inspect the already-loaded ComfyUI runtime modules.
runtime_root = self._find_root_from_loaded_modules()
if runtime_root:
return runtime_root
# Method 3: Try going up from custom_nodes directory.
if os.path.exists(os.path.join(potential_root, "main.py")):
debug_log(f"Found ComfyUI root via directory traversal: {potential_root}")
return potential_root
# Method 4: Look for common Docker paths.
docker_paths = [
"/basedir",
"/ComfyUI",
"/app",
"/workspace/ComfyUI",
"/comfyui",
"/opt/ComfyUI",
"/workspace",
]
for path in docker_paths:
if os.path.exists(path) and os.path.exists(os.path.join(path, "main.py")):
debug_log(f"Found ComfyUI root in Docker path: {path}")
return path
# Method 5: Search upwards for main.py.
search_dir = current_dir
for _ in range(5):
if os.path.exists(os.path.join(search_dir, "main.py")):
debug_log(f"Found ComfyUI root via upward search: {search_dir}")
return search_dir
parent = os.path.dirname(search_dir)
if parent == search_dir:
break
search_dir = parent
# Method 6: Try to import and use folder_paths.
try:
import folder_paths
if hasattr(folder_paths, "base_path") and os.path.exists(
os.path.join(folder_paths.base_path, "main.py")
):
debug_log(f"Found ComfyUI root via folder_paths: {folder_paths.base_path}")
return folder_paths.base_path
except Exception as exc:
debug_log(f"folder_paths root detection failed: {exc}")
log("Warning: Could not reliably determine ComfyUI root directory")
log(f"Current directory: {current_dir}")
log(f"Initial guess was: {potential_root}")
return potential_root
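# Method 1 makes container layouts explicit via an environment override, e.g.
# (deployment shown is assumed, not taken from the repo):
#
#   COMFYUI_ROOT=/workspace/ComfyUI python main.py --listen
#
# Each subsequent method validates its candidate by checking for main.py
# before returning it.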
================================================
FILE: workers/process_manager.py
================================================
from .process import ComfyRootDiscovery, LaunchCommandBuilder, ProcessLifecycle, ProcessPersistence
class WorkerProcessManager:
"""Thin composition wrapper around worker process subsystems."""
def __init__(self):
self.processes = {}
self._root_discovery = ComfyRootDiscovery()
self._launch_builder = LaunchCommandBuilder()
self._lifecycle = ProcessLifecycle(self)
self._persistence = ProcessPersistence(self)
self.load_processes()
def find_comfy_root(self):
return self._root_discovery.find_comfy_root()
def _find_windows_terminal(self):
return self._launch_builder._find_windows_terminal()
def build_launch_command(self, worker_config, comfy_root):
return self._launch_builder.build_launch_command(worker_config, comfy_root)
def launch_worker(self, worker_config, show_window=False):
return self._lifecycle.launch_worker(worker_config, show_window=show_window)
def stop_worker(self, worker_id):
return self._lifecycle.stop_worker(worker_id)
def get_managed_workers(self):
return self._lifecycle.get_managed_workers()
def cleanup_all(self):
return self._lifecycle.cleanup_all()
def load_processes(self):
return self._persistence.load_processes()
def save_processes(self):
return self._persistence.save_processes()
def _is_process_running(self, pid):
return self._lifecycle._is_process_running(pid)
def _check_worker_process(self, worker_id, proc_info):
return self._lifecycle._check_worker_process(worker_id, proc_info)
def _kill_process_tree(self, pid):
return self._lifecycle._kill_process_tree(pid)
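# Facade usage sketch (worker_config keys follow launch_worker in
# lifecycle.py; the values are examples only):
#
#   wm = WorkerProcessManager()
#   pid = wm.launch_worker({"id": "w1", "name": "Worker 1",
#                           "port": 8189, "cuda_device": 0})
#   wm.get_managed_workers()  # {"w1": {"pid": ..., "started_at": ..., ...}}
#   wm.stop_worker("w1")      # -> (True, "Worker stopped") on success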
================================================
FILE: workers/startup.py
================================================
import asyncio
import threading
import time
import atexit
import signal
import sys
import platform
import server
from ..utils.config import load_config, save_config
from ..utils.logging import debug_log, log
from ..utils.network import normalize_host
from ..utils.cloudflare import cloudflare_tunnel_manager
from ..utils.constants import WORKER_STARTUP_DELAY
from . import get_worker_manager
def auto_launch_workers():
"""Launch enabled workers if auto_launch_workers is set to true."""
wm = get_worker_manager()
try:
config = load_config()
if config.get('settings', {}).get('auto_launch_workers', False):
log("Auto-launch workers is enabled, checking for workers to start...")
# Clear managed_processes before launching new workers
# This handles cases where the master was killed without proper cleanup
if config.get('managed_processes'):
log("Clearing old managed_processes before auto-launch...")
config['managed_processes'] = {}
save_config(config)
workers = config.get('workers', [])
launched_count = 0
for worker in workers:
if worker.get('enabled', False):
worker_id = worker.get('id')
worker_name = worker.get('name', f'Worker {worker_id}')
# Skip remote workers
host = (normalize_host(worker.get('host', 'localhost')) or 'localhost').lower()
if host not in ['localhost', '127.0.0.1', '', None]:
debug_log(f"Skipping remote worker {worker_name} (host: {host})")
continue
# Check if already running
if str(worker_id) in wm.processes:
proc_info = wm.processes[str(worker_id)]
if wm._is_process_running(proc_info['pid']):
debug_log(f"Worker {worker_name} already running, skipping")
continue
# Launch the worker
try:
pid = wm.launch_worker(worker)
log(f"Auto-launched worker {worker_name} (PID: {pid})")
# Mark as launching in managed processes
if str(worker_id) in wm.processes:
wm.processes[str(worker_id)]['launching'] = True
wm.save_processes()
launched_count += 1
except Exception as e:
log(f"Failed to auto-launch worker {worker_name}: {e}")
if launched_count > 0:
log(f"Auto-launched {launched_count} worker(s)")
else:
debug_log("No workers to auto-launch")
else:
debug_log("Auto-launch workers is disabled")
except Exception as e:
log(f"Error during auto-launch: {e}")
# Schedule auto-launch after a short delay to ensure server is ready
def delayed_auto_launch():
"""Delay auto-launch to ensure server is fully initialized."""
timer = threading.Timer(WORKER_STARTUP_DELAY, auto_launch_workers)
timer.daemon = True
timer.start()
# Async cleanup function for proper shutdown
async def async_cleanup_and_exit(signum=None):
"""Async-friendly cleanup and exit."""
wm = get_worker_manager()
try:
config = load_config()
if config.get('settings', {}).get('stop_workers_on_master_exit', True):
print("\n[Distributed] Master shutting down, stopping all managed workers...")
wm.cleanup_all()
else:
print("\n[Distributed] Master shutting down, workers will continue running")
wm.save_processes()
try:
await cloudflare_tunnel_manager.stop_tunnel()
except Exception as tunnel_error:
log(f"[Distributed] Warning: Cloudflare tunnel did not stop cleanly during shutdown: {tunnel_error}")
except Exception as e:
print(f"[Distributed] Error during cleanup: {e}")
# On Windows, we need to exit differently
if platform.system() == "Windows":
# Force exit on Windows
sys.exit(0)
else:
# On Unix, stop the event loop gracefully
loop = asyncio.get_running_loop()
loop.stop()
def register_async_signals():
"""Register async signal handlers for graceful shutdown."""
wm = get_worker_manager()
if platform.system() == "Windows":
# Windows doesn't support add_signal_handler, use traditional signal handling
def signal_handler(signum, frame):
# Schedule the async cleanup in the event loop
loop = server.PromptServer.instance.loop
if loop and loop.is_running():
asyncio.run_coroutine_threadsafe(async_cleanup_and_exit(signum), loop)
else:
# Fallback to sync cleanup if loop isn't running
try:
config = load_config()
if config.get('settings', {}).get('stop_workers_on_master_exit', True):
print("\n[Distributed] Master shutting down, stopping all managed workers...")
wm.cleanup_all()
else:
print("\n[Distributed] Master shutting down, workers will continue running")
wm.save_processes()
except Exception as e:
print(f"[Distributed] Error during cleanup: {e}")
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
else:
# Unix-like systems support add_signal_handler
loop = server.PromptServer.instance.loop
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, lambda s=sig: asyncio.create_task(async_cleanup_and_exit(s)))
# SIGHUP is Unix-only
loop.add_signal_handler(signal.SIGHUP, lambda: asyncio.create_task(async_cleanup_and_exit(signal.SIGHUP)))
def sync_cleanup():
"""Synchronous wrapper for atexit."""
wm = get_worker_manager()
try:
# For atexit, we don't want to stop the loop or exit
config = load_config()
if config.get('settings', {}).get('stop_workers_on_master_exit', True):
print("\n[Distributed] Master shutting down, stopping all managed workers...")
wm.cleanup_all()
else:
print("\n[Distributed] Master shutting down, workers will continue running")
wm.save_processes()
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(cloudflare_tunnel_manager.stop_tunnel())
else:
loop.run_until_complete(cloudflare_tunnel_manager.stop_tunnel())
except RuntimeError:
# No running loop; create a temporary one
asyncio.run(cloudflare_tunnel_manager.stop_tunnel())
except Exception as tunnel_error:
log(f"[Distributed] Warning: Cloudflare tunnel did not stop cleanly during sync cleanup: {tunnel_error}")
except Exception as e:
print(f"[Distributed] Error during cleanup: {e}")
================================================
FILE: workers/worker_monitor.py
================================================
#!/usr/bin/env python3
"""
Worker process monitor - monitors if the master process is still alive
and terminates the worker if the master dies.
"""
import os
import sys
import time
import subprocess
import platform
import signal
# Add package root to path so this script works when launched by file path.
NODE_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if NODE_ROOT not in sys.path:
sys.path.insert(0, NODE_ROOT)
try:
from utils.process import is_process_alive, terminate_process
from utils.constants import WORKER_CHECK_INTERVAL, PROCESS_TERMINATION_TIMEOUT
except ImportError:
# Fallback if running from different context
def is_process_alive(pid):
"""Check if a process with given PID is still alive."""
try:
if platform.system() == "Windows":
# Windows: use tasklist
result = subprocess.run(['tasklist', '/FI', f'PID eq {pid}'],
capture_output=True, text=True)
return str(pid) in result.stdout
else:
# Unix: send signal 0
os.kill(pid, 0)
return True
except (OSError, subprocess.SubprocessError):
return False
WORKER_CHECK_INTERVAL = 2.0
PROCESS_TERMINATION_TIMEOUT = 5.0
def monitor_and_run(master_pid, command):
"""Run command and monitor master process."""
# Start the actual worker process
print(f"[Distributed] Launching worker command: {' '.join(command)}")
worker_process = subprocess.Popen(command)
print(f"[Distributed] Started worker PID: {worker_process.pid}")
print(f"[Distributed] Monitoring master PID: {master_pid}")
# Write worker PID to a file so parent can track it
monitor_pid = os.getpid()
pid_info_file = os.environ.get('WORKER_PID_FILE')
if pid_info_file:
try:
with open(pid_info_file, 'w') as f:
f.write(f"{monitor_pid},{worker_process.pid}")
print(f"[Distributed] Wrote PID info to {pid_info_file}")
except Exception as e:
print(f"[Distributed] Could not write PID file: {e}")
# Define cleanup function
def cleanup_worker(signum=None, frame=None):
"""Clean up worker process when monitor is terminated."""
if signum:
print(f"\n[Distributed] Received signal {signum}, terminating worker...")
else:
print("\n[Distributed] Terminating worker...")
if worker_process.poll() is None: # Still running
try:
terminate_process(worker_process, timeout=PROCESS_TERMINATION_TIMEOUT)
except NameError:
# Fallback if terminate_process wasn't imported
worker_process.terminate()
try:
worker_process.wait(timeout=PROCESS_TERMINATION_TIMEOUT)
except subprocess.TimeoutExpired:
print("[Distributed] Worker didn't terminate gracefully, forcing kill...")
worker_process.kill()
worker_process.wait()
print("[Distributed] Worker terminated.")
sys.exit(0)
# Register signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, cleanup_worker)
signal.signal(signal.SIGINT, cleanup_worker)
if platform.system() != "Windows":
signal.signal(signal.SIGHUP, cleanup_worker)
# Monitor loop
check_interval = WORKER_CHECK_INTERVAL
try:
while True:
# Check if worker is still running
if worker_process.poll() is not None:
print(f"[Distributed] Worker process exited with code: {worker_process.returncode}")
sys.exit(worker_process.returncode)
# Check if master is still alive
if not is_process_alive(master_pid):
print(f"[Distributed] Master process {master_pid} is no longer running. Terminating worker...")
cleanup_worker()
time.sleep(check_interval)
except KeyboardInterrupt:
cleanup_worker()
if __name__ == "__main__":
# Get master PID from environment
master_pid = os.environ.get('COMFYUI_MASTER_PID')
if not master_pid:
print("[Distributed] Error: COMFYUI_MASTER_PID not set")
sys.exit(1)
try:
master_pid = int(master_pid)
except ValueError:
print(f"[Distributed] Error: Invalid master PID: {master_pid}")
sys.exit(1)
# Get the actual command to run (all remaining arguments)
if len(sys.argv) < 2:
print("[Distributed] Error: No command specified")
sys.exit(1)
command = sys.argv[1:]
# Start monitoring
monitor_and_run(master_pid, command)
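# Invocation sketch, mirroring how lifecycle.py wraps the worker command
# (paths are illustrative):
#
#   COMFYUI_MASTER_PID=4242 python worker_monitor.py \
#       python /ComfyUI/main.py --port 8189 --disable-auto-launch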
================================================
FILE: workflows/distributed-txt2img.json
================================================
{
"id": "c9a4d248-9b83-408f-b45e-3ef61dd56ef5",
"revision": 0,
"last_node_id": 13,
"last_link_id": 19,
"nodes": [
{
"id": 8,
"type": "KSampler",
"pos": [
2190,
770
],
"size": [
315,
262
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 7
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 8
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 9
},
{
"name": "latent_image",
"type": "LATENT",
"link": 10
},
{
"name": "seed",
"type": "INT",
"widget": {
"name": "seed"
},
"link": 11
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"slot_index": 0,
"links": [
1
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.41",
"Node name for S&R": "KSampler"
},
"widgets_values": [
361252850620022,
"randomize",
20,
6,
"euler",
"normal",
1
]
},
{
"id": 9,
"type": "EmptyLatentImage",
"pos": [
2220,
1080
],
"size": [
270,
106
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": [
10
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "EmptyLatentImage"
},
"widgets_values": [
512,
512,
1
]
},
{
"id": 6,
"type": "CLIPTextEncode",
"pos": [
1846.3876953125,
968.6343994140625
],
"size": [
310,
180
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 6
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"slot_index": 0,
"links": [
9
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.41",
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"text, watermark"
]
},
{
"id": 5,
"type": "CLIPTextEncode",
"pos": [
1856.3876953125,
768.6343994140625
],
"size": [
300,
160
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 5
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"slot_index": 0,
"links": [
8
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.41",
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"beautiful scenery nature glass bottle landscape, , purple galaxy bottle,"
]
},
{
"id": 7,
"type": "CheckpointLoaderSimple",
"pos": [
1500,
780
],
"size": [
315,
98
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"slot_index": 0,
"links": [
7
]
},
{
"name": "CLIP",
"type": "CLIP",
"slot_index": 1,
"links": [
5,
6
]
},
{
"name": "VAE",
"type": "VAE",
"slot_index": 2,
"links": [
2
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.41",
"Node name for S&R": "CheckpointLoaderSimple",
"models": [
{
"name": "v1-5-pruned-emaonly-fp16.safetensors",
"url": "https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors?download=true",
"directory": "checkpoints"
}
]
},
"widgets_values": [
"SDXL\\juggernautXL_ragnarokBy.safetensors"
]
},
{
"id": 1,
"type": "VAEDecode",
"pos": [
2530,
790
],
"size": [
210,
46
],
"flags": {
"collapsed": true
},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "samples",
"type": "LATENT",
"link": 1
},
{
"name": "vae",
"type": "VAE",
"link": 2
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"slot_index": 0,
"links": [
3
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.41",
"Node name for S&R": "VAEDecode"
}
},
{
"id": 2,
"type": "DistributedCollector",
"pos": [
2690,
770
],
"size": [
166.50416564941406,
26
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 3
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
4
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"aux_id": "robertvoy/ComfyUI-Distributed",
"ver": "99021363d65cc2b2f0f3a0f12a76a358f0fb330f",
"Node name for S&R": "DistributedCollector"
}
},
{
"id": 3,
"type": "PreviewImage",
"pos": [
2880,
760
],
"size": [
410,
480
],
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 4
}
],
"outputs": [],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.41",
"Node name for S&R": "PreviewImage"
},
"widgets_values": []
},
{
"id": 4,
"type": "DistributedSeed",
"pos": [
1890,
1220
],
"size": [
270,
82
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "seed",
"type": "INT",
"links": [
11
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"aux_id": "robertvoy/ComfyUI-Distributed",
"ver": "99021363d65cc2b2f0f3a0f12a76a358f0fb330f",
"Node name for S&R": "DistributedSeed"
},
"widgets_values": [
504373561407102,
"randomize"
]
}
],
"links": [
[
1,
8,
0,
1,
0,
"LATENT"
],
[
2,
7,
2,
1,
1,
"VAE"
],
[
3,
1,
0,
2,
0,
"IMAGE"
],
[
4,
2,
0,
3,
0,
"IMAGE"
],
[
5,
7,
1,
5,
0,
"CLIP"
],
[
6,
7,
1,
6,
0,
"CLIP"
],
[
7,
7,
0,
8,
0,
"MODEL"
],
[
8,
5,
0,
8,
1,
"CONDITIONING"
],
[
9,
6,
0,
8,
2,
"CONDITIONING"
],
[
10,
9,
0,
8,
3,
"LATENT"
],
[
11,
4,
0,
8,
4,
"INT"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.6649272177973091,
"offset": [
-903.8525468054443,
-478.58804363769354
]
},
"frontendVersion": "1.23.4"
},
"version": 0.4
}
================================================
FILE: workflows/distributed-upscale-video.json
================================================
{
"id": "707da2be-c7d6-481f-b3b0-3ec8207924a1",
"revision": 0,
"last_node_id": 76,
"last_link_id": 122,
"nodes": [
{
"id": 18,
"type": "UltimateSDUpscaleDistributed",
"pos": [
3800,
1400
],
"size": [
380,
430
],
"flags": {},
"order": 15,
"mode": 0,
"inputs": [
{
"name": "upscaled_image",
"type": "IMAGE",
"link": 118
},
{
"name": "model",
"type": "MODEL",
"link": 32
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 33
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 104
},
{
"name": "vae",
"type": "VAE",
"link": 35
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
95
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "ComfyUI-Distributed",
"ver": "e076cf3455df72383d17f1d5f5b5aa5e709f2e4a",
"Node name for S&R": "UltimateSDUpscaleDistributed"
},
"widgets_values": [
157027921504581,
"fixed",
5,
1,
"res_2s",
"bong_tangent",
0.35,
1024,
1024,
32,
16,
true,
false
]
},
{
"id": 57,
"type": "VHS_VideoCombine",
"pos": [
4230,
1400
],
"size": [
550,
639.5
],
"flags": {},
"order": 16,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 95
},
{
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": null
},
{
"name": "meta_batch",
"shape": 7,
"type": "VHS_BatchManager",
"link": null
},
{
"name": "vae",
"shape": 7,
"type": "VAE",
"link": null
}
],
"outputs": [
{
"name": "Filenames",
"type": "VHS_FILENAMES",
"links": null
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui-videohelpersuite",
"ver": "a7ce59e381934733bfae03b1be029756d6ce936d",
"Node name for S&R": "VHS_VideoCombine"
},
"widgets_values": {
"frame_rate": 16,
"loop_count": 0,
"filename_prefix": "WAN",
"format": "video/h264-mp4",
"pix_fmt": "yuv420p",
"crf": 20,
"save_metadata": true,
"trim_to_audio": false,
"pingpong": false,
"save_output": true,
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"filename": "WAN_00065.mp4",
"subfolder": "",
"type": "output",
"format": "video/h264-mp4",
"frame_rate": 16,
"workflow": "WAN_00065.png",
"fullpath": "C:\\venvs\\ComfyUI\\ComfyUI\\output\\WAN_00065.mp4"
}
}
}
},
{
"id": 12,
"type": "CLIPTextEncode",
"pos": [
2980,
1610
],
"size": [
370,
160
],
"flags": {
"collapsed": false
},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 115
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
33,
56
]
}
],
"title": "Positive Prompt",
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "CLIPTextEncode",
"ue_properties": {
"version": "7.0.1",
"widget_ue_connectable": {}
}
},
"widgets_values": [
"A sea turtle swims amidst vibrant coral reefs, two sergeant major fish nearby. The water is clear and blue, showcasing the intricate details of the coral and turtle's shell. Photorealistic, underwater scene, 8k resolution."
],
"color": "#232",
"bgcolor": "#353"
},
{
"id": 11,
"type": "CLIPLoader",
"pos": [
2520,
1610
],
"size": [
387.5943603515625,
106
],
"flags": {
"collapsed": false
},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
115
]
}
],
"title": "CLIP",
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "CLIPLoader",
"ue_properties": {
"version": "7.0.1",
"widget_ue_connectable": {}
}
},
"widgets_values": [
"umt5_xxl_fp16.safetensors",
"wan",
"default"
]
},
{
"id": 74,
"type": "ImageFromBatch",
"pos": [
2280,
1970
],
"size": [
270,
82
],
"flags": {},
"order": 6,
"mode": 4,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 119
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
120
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.59",
"Node name for S&R": "ImageFromBatch"
},
"widgets_values": [
0,
1
]
},
{
"id": 63,
"type": "easy imageInterrogator",
"pos": [
2570,
1970
],
"size": [
280,
82
],
"flags": {},
"order": 10,
"mode": 4,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 120
}
],
"outputs": [
{
"name": "prompt",
"shape": 6,
"type": "STRING",
"links": [
108
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui-easy-use",
"ver": "1.3.3",
"Node name for S&R": "easy imageInterrogator"
},
"widgets_values": [
"fast",
true
]
},
{
"id": 8,
"type": "ModelSamplingSD3",
"pos": [
3410,
1020
],
"size": [
221.14166259765625,
88.32342529296875
],
"flags": {
"collapsed": false
},
"order": 12,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 121
}
],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"slot_index": 0,
"links": [
32
]
}
],
"title": "Shift",
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.36",
"Node name for S&R": "ModelSamplingSD3",
"ue_properties": {
"version": "7.0.1",
"widget_ue_connectable": {}
}
},
"widgets_values": [
8.000000000000002
],
"color": "#223",
"bgcolor": "#335"
},
{
"id": 26,
"type": "ConditioningZeroOut",
"pos": [
3440,
1710
],
"size": [
198.16665649414062,
26
],
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "conditioning",
"type": "CONDITIONING",
"link": 56
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
104
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.51",
"Node name for S&R": "ConditioningZeroOut"
},
"widgets_values": []
},
{
"id": 70,
"type": "ImageUpscaleWithModel",
"pos": [
2590,
1280
],
"size": [
310,
46
],
"flags": {},
"order": 7,
"mode": 4,
"inputs": [
{
"name": "upscale_model",
"type": "UPSCALE_MODEL",
"link": 111
},
{
"name": "image",
"type": "IMAGE",
"link": 112
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
113
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.59",
"Node name for S&R": "ImageUpscaleWithModel"
},
"widgets_values": []
},
{
"id": 54,
"type": "VHS_LoadVideo",
"pos": [
1480,
1300
],
"size": [
620,
654
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"name": "meta_batch",
"shape": 7,
"type": "VHS_BatchManager",
"link": null
},
{
"name": "vae",
"shape": 7,
"type": "VAE",
"link": null
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
112,
119
]
},
{
"name": "frame_count",
"type": "INT",
"links": []
},
{
"name": "audio",
"type": "AUDIO",
"links": null
},
{
"name": "video_info",
"type": "VHS_VIDEOINFO",
"links": null
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui-videohelpersuite",
"ver": "a7ce59e381934733bfae03b1be029756d6ce936d",
"Node name for S&R": "VHS_LoadVideo"
},
"widgets_values": {
"video": "ComfyUI_00005_.mp4",
"force_rate": 0,
"custom_width": 0,
"custom_height": 0,
"frame_load_cap": 5,
"skip_first_frames": 0,
"select_every_nth": 1,
"format": "Wan",
"choose video to upload": "image",
"videopreview": {
"hidden": false,
"paused": true,
"params": {
"filename": "ComfyUI_00005_.mp4",
"type": "input",
"format": "video/mp4",
"force_rate": 0,
"custom_width": 0,
"custom_height": 0,
"frame_load_cap": 5,
"skip_first_frames": 0,
"select_every_nth": 1
}
}
}
},
{
"id": 71,
"type": "ImageResize+",
"pos": [
3030,
1270
],
"size": [
270,
218
],
"flags": {},
"order": 11,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 113
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
117
]
},
{
"name": "width",
"type": "INT",
"links": null
},
{
"name": "height",
"type": "INT",
"links": null
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui_essentials",
"ver": "9d9f4bedfc9f0321c19faf71855e228c93bd0dc9",
"Node name for S&R": "ImageResize+"
},
"widgets_values": [
1920,
1080,
"lanczos",
"keep proportion",
"always",
0
]
},
{
"id": 14,
"type": "VAELoader",
"pos": [
3410,
1870
],
"size": [
270,
58
],
"flags": {
"collapsed": false
},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "VAE",
"type": "VAE",
"links": [
35
]
}
],
"title": "VAE",
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "VAELoader",
"ue_properties": {
"version": "7.0.1",
"widget_ue_connectable": {}
}
},
"widgets_values": [
"wan_2.1_vae.safetensors"
]
},
{
"id": 67,
"type": "DisplayAny",
"pos": [
2880,
1970
],
"size": [
360,
160
],
"flags": {},
"order": 13,
"mode": 4,
"inputs": [
{
"name": "input",
"type": "*",
"link": 108
}
],
"outputs": [
{
"name": "STRING",
"type": "STRING",
"links": null
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui_essentials",
"ver": "9d9f4bedfc9f0321c19faf71855e228c93bd0dc9",
"Node name for S&R": "DisplayAny"
},
"widgets_values": [
"raw value"
]
},
{
"id": 73,
"type": "Reroute",
"pos": [
3500,
1270
],
"size": [
75,
26
],
"flags": {},
"order": 14,
"mode": 0,
"inputs": [
{
"name": "",
"type": "*",
"link": 117
}
],
"outputs": [
{
"name": "",
"type": "IMAGE",
"links": [
118
]
}
],
"properties": {
"showOutputText": false,
"horizontal": false
}
},
{
"id": 72,
"type": "LoraLoaderModelOnly",
"pos": [
2940,
1020
],
"size": [
420,
82
],
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 122
}
],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
121
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.59",
"Node name for S&R": "LoraLoaderModelOnly"
},
"widgets_values": [
"Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors",
1
]
},
{
"id": 69,
"type": "UpscaleModelLoader",
"pos": [
2250,
1280
],
"size": [
300,
60
],
"flags": {},
"order": 3,
"mode": 4,
"inputs": [],
"outputs": [
{
"name": "UPSCALE_MODEL",
"type": "UPSCALE_MODEL",
"links": [
111
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.59",
"Node name for S&R": "UpscaleModelLoader"
},
"widgets_values": [
"RealESRGAN_x2.pth"
]
},
{
"id": 76,
"type": "UNETLoader",
"pos": [
2520,
1020
],
"size": [
380,
82
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
122
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.59",
"Node name for S&R": "UNETLoader"
},
"widgets_values": [
"wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors",
"default"
]
}
],
"links": [
[
32,
8,
0,
18,
1,
"MODEL"
],
[
33,
12,
0,
18,
2,
"CONDITIONING"
],
[
35,
14,
0,
18,
4,
"VAE"
],
[
56,
12,
0,
26,
0,
"CONDITIONING"
],
[
95,
18,
0,
57,
0,
"IMAGE"
],
[
104,
26,
0,
18,
3,
"CONDITIONING"
],
[
108,
63,
0,
67,
0,
"*"
],
[
111,
69,
0,
70,
0,
"UPSCALE_MODEL"
],
[
112,
54,
0,
70,
1,
"IMAGE"
],
[
113,
70,
0,
71,
0,
"IMAGE"
],
[
115,
11,
0,
12,
0,
"CLIP"
],
[
117,
71,
0,
73,
0,
"*"
],
[
118,
73,
0,
18,
0,
"IMAGE"
],
[
119,
54,
0,
74,
0,
"IMAGE"
],
[
120,
74,
0,
63,
0,
"IMAGE"
],
[
121,
72,
0,
8,
0,
"MODEL"
],
[
122,
76,
0,
72,
0,
"MODEL"
]
],
"groups": [
{
"id": 1,
"title": "Optional Upscale",
"bounding": [
2240,
1200,
690,
180
],
"color": "#3f789e",
"font_size": 24,
"flags": {}
},
{
"id": 2,
"title": "Prompt Generator",
"bounding": [
2260,
1880,
1000,
280
],
"color": "#3f789e",
"font_size": 24,
"flags": {}
}
],
"config": {},
"extra": {
"ds": {
"scale": 0.7360065561459117,
"offset": [
-1580.220925243604,
-519.0718505701544
]
},
"frontendVersion": "1.25.11",
"VHS_latentpreview": false,
"VHS_latentpreviewrate": 0,
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"version": 0.4
}
================================================
FILE: workflows/distributed-upscale.json
================================================
{
"id": "817bbfe2-06b8-44c8-8c14-b82b63b335d5",
"revision": 0,
"last_node_id": 137,
"last_link_id": 211,
"nodes": [
{
"id": 86,
"type": "Reroute",
"pos": [
2130,
1050
],
"size": [
75,
26
],
"flags": {},
"order": 11,
"mode": 0,
"inputs": [
{
"name": "",
"type": "*",
"link": 121
}
],
"outputs": [
{
"name": "",
"type": "VAE",
"links": [
127
]
}
],
"properties": {
"showOutputText": false,
"horizontal": false
}
},
{
"id": 88,
"type": "Reroute",
"pos": [
2130,
1000
],
"size": [
75,
26
],
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "",
"type": "*",
"link": 124
}
],
"outputs": [
{
"name": "",
"type": "MODEL",
"links": [
137
]
}
],
"properties": {
"showOutputText": false,
"horizontal": false
}
},
{
"id": 107,
"type": "Reroute",
"pos": [
2810,
1050
],
"size": [
75,
26
],
"flags": {},
"order": 18,
"mode": 0,
"inputs": [
{
"name": "",
"type": "*",
"link": 171
}
],
"outputs": [
{
"name": "",
"type": "VAE",
"links": [
172
]
}
],
"properties": {
"showOutputText": false,
"horizontal": false
}
},
{
"id": 89,
"type": "Reroute",
"pos": [
2810,
1000
],
"size": [
75,
26
],
"flags": {},
"order": 14,
"mode": 0,
"inputs": [
{
"name": "",
"type": "*",
"link": 137
}
],
"outputs": [
{
"name": "",
"type": "MODEL",
"links": [
170
]
}
],
"properties": {
"showOutputText": false,
"horizontal": false
}
},
{
"id": 110,
"type": "Reroute",
"pos": [
2810,
970
],
"size": [
75,
26
],
"flags": {},
"order": 21,
"mode": 0,
"inputs": [
{
"name": "",
"type": "*",
"link": 211
}
],
"outputs": [
{
"name": "",
"type": "IMAGE",
"links": [
176
]
}
],
"properties": {
"showOutputText": false,
"horizontal": false
}
},
{
"id": 7,
"type": "CheckpointLoaderSimple",
"pos": [
1700,
1270
],
"size": [
370,
98
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"slot_index": 0,
"links": [
124
]
},
{
"name": "CLIP",
"type": "CLIP",
"slot_index": 1,
"links": [
5,
6
]
},
{
"name": "VAE",
"type": "VAE",
"slot_index": 2,
"links": [
121
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.41",
"Node name for S&R": "CheckpointLoaderSimple",
"models": [
{
"name": "v1-5-pruned-emaonly-fp16.safetensors",
"url": "https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors?download=true",
"directory": "checkpoints"
}
]
},
"widgets_values": [
"SDXL\\juggernautXL_ragnarokBy.safetensors"
]
},
{
"id": 90,
"type": "Reroute",
"pos": [
2460,
1050
],
"size": [
75,
26
],
"flags": {},
"order": 15,
"mode": 0,
"inputs": [
{
"name": "",
"type": "*",
"link": 127
}
],
"outputs": [
{
"name": "",
"type": "VAE",
"links": [
129,
171
]
}
],
"properties": {
"showOutputText": false,
"horizontal": false
}
},
{
"id": 44,
"type": "ControlNetLoader",
"pos": [
1910,
1510
],
"size": [
390,
58
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CONTROL_NET",
"type": "CONTROL_NET",
"links": [
81
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "ControlNetLoader"
},
"widgets_values": [
"SDXL\\xinsir-controlnet-union-sdxl-1.0-promax.safetensors"
]
},
{
"id": 6,
"type": "CLIPTextEncode",
"pos": [
2130,
1370
],
"size": [
420,
88
],
"flags": {},
"order": 10,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 6
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"slot_index": 0,
"links": [
64
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.41",
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"text, watermark"
]
},
{
"id": 52,
"type": "SetUnionControlNetType",
"pos": [
2340,
1510
],
"size": [
210,
58
],
"flags": {},
"order": 12,
"mode": 0,
"inputs": [
{
"name": "control_net",
"type": "CONTROL_NET",
"link": 81
}
],
"outputs": [
{
"name": "CONTROL_NET",
"type": "CONTROL_NET",
"links": [
192
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "SetUnionControlNetType"
},
"widgets_values": [
"tile"
]
},
{
"id": 125,
"type": "Reroute",
"pos": [
2450,
880
],
"size": [
75,
26
],
"flags": {},
"order": 16,
"mode": 0,
"inputs": [
{
"name": "",
"type": "*",
"link": 200
}
],
"outputs": [
{
"name": "",
"type": "IMAGE",
"links": [
206
]
}
],
"properties": {
"showOutputText": false,
"horizontal": false
}
},
{
"id": 34,
"type": "ImageResize+",
"pos": [
1700,
970
],
"size": [
270,
218
],
"flags": {},
"order": 20,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 197
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
211
]
},
{
"name": "width",
"type": "INT",
"links": null
},
{
"name": "height",
"type": "INT",
"links": null
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui_essentials",
"ver": "9d9f4bedfc9f0321c19faf71855e228c93bd0dc9",
"Node name for S&R": "ImageResize+"
},
"widgets_values": [
2048,
2048,
"lanczos",
"keep proportion",
"always",
8
]
},
{
"id": 123,
"type": "ImageUpscaleWithModel",
"pos": [
1410,
970
],
"size": [
222.75416564941406,
46
],
"flags": {},
"order": 17,
"mode": 4,
"inputs": [
{
"name": "upscale_model",
"type": "UPSCALE_MODEL",
"link": 195
},
{
"name": "image",
"type": "IMAGE",
"link": 210
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
197
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "ImageUpscaleWithModel"
},
"widgets_values": []
},
{
"id": 134,
"type": "MarkdownNote",
"pos": [
1920,
1610
],
"size": [
370,
88
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [],
"properties": {},
"widgets_values": [
"[https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main](https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main)"
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 135,
"type": "LoadImage",
"pos": [
780,
880
],
"size": [
274.375,
314.00006103515625
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
207
]
},
{
"name": "MASK",
"type": "MASK",
"links": null
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"example.png",
"image"
]
},
{
"id": 132,
"type": "Note",
"pos": [
1460,
1280
],
"size": [
210,
88
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [],
"outputs": [],
"properties": {},
"widgets_values": [
"Choose an SDXL model"
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 5,
"type": "CLIPTextEncode",
"pos": [
2130,
1160
],
"size": [
420,
160
],
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 5
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"slot_index": 0,
"links": [
63
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.41",
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"beautiful scenery nature glass bottle landscape, , purple galaxy bottle,"
]
},
{
"id": 124,
"type": "Reroute",
"pos": [
1230,
880
],
"size": [
75,
26
],
"flags": {},
"order": 13,
"mode": 0,
"inputs": [
{
"name": "",
"type": "*",
"link": 207
}
],
"outputs": [
{
"name": "",
"type": "IMAGE",
"links": [
200,
210
]
}
],
"properties": {
"showOutputText": false,
"horizontal": false
}
},
{
"id": 122,
"type": "UpscaleModelLoader",
"pos": [
1130,
1070
],
"size": [
270,
58
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "UPSCALE_MODEL",
"type": "UPSCALE_MODEL",
"links": [
195
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "UpscaleModelLoader"
},
"widgets_values": [
"4xNomos8kDAT.pth"
]
},
{
"id": 131,
"type": "Note",
"pos": [
1420,
1070
],
"size": [
210,
88
],
"flags": {},
"order": 6,
"mode": 0,
"inputs": [],
"outputs": [],
"properties": {},
"widgets_values": [
"Optional"
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 137,
"type": "Note",
"pos": [
3000,
1600
],
"size": [
330,
90
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [],
"outputs": [],
"properties": {},
"widgets_values": [
"If all your GPUs are the same/similar, set static_distribution to true\n"
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 136,
"type": "PreviewImage",
"pos": [
3380,
1100
],
"size": [
490,
550
],
"flags": {},
"order": 23,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 209
}
],
"outputs": [],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "PreviewImage"
},
"widgets_values": []
},
{
"id": 43,
"type": "ControlNetApplyAdvanced",
"pos": [
2640,
1150
],
"size": [
270,
186
],
"flags": {},
"order": 19,
"mode": 0,
"inputs": [
{
"name": "positive",
"type": "CONDITIONING",
"link": 63
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 64
},
{
"name": "control_net",
"type": "CONTROL_NET",
"link": 192
},
{
"name": "image",
"type": "IMAGE",
"link": 206
},
{
"name": "vae",
"shape": 7,
"type": "VAE",
"link": 129
}
],
"outputs": [
{
"name": "positive",
"type": "CONDITIONING",
"links": [
190
]
},
{
"name": "negative",
"type": "CONDITIONING",
"links": [
191
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "ControlNetApplyAdvanced"
},
"widgets_values": [
1.0000000000000002,
0,
0.8000000000000002
]
},
{
"id": 30,
"type": "UltimateSDUpscaleDistributed",
"pos": [
3000,
1110
],
"size": [
326.691650390625,
450
],
"flags": {},
"order": 22,
"mode": 0,
"inputs": [
{
"name": "upscaled_image",
"type": "IMAGE",
"link": 176
},
{
"name": "model",
"type": "MODEL",
"link": 170
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 190
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 191
},
{
"name": "vae",
"type": "VAE",
"link": 172
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
209
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "ComfyUI-Distributed",
"ver": "dd23503883fdf319e8beb6e7a190445ecf89973c",
"Node name for S&R": "UltimateSDUpscaleDistributed"
},
"widgets_values": [
269777990474642,
"randomize",
20,
7,
"dpmpp_2m_sde",
"karras",
0.6000000000000001,
1024,
1024,
32,
16,
true,
false,
true
]
}
],
"links": [
[
5,
7,
1,
5,
0,
"CLIP"
],
[
6,
7,
1,
6,
0,
"CLIP"
],
[
63,
5,
0,
43,
0,
"CONDITIONING"
],
[
64,
6,
0,
43,
1,
"CONDITIONING"
],
[
81,
44,
0,
52,
0,
"CONTROL_NET"
],
[
121,
7,
2,
86,
0,
"*"
],
[
124,
7,
0,
88,
0,
"*"
],
[
127,
86,
0,
90,
0,
"*"
],
[
129,
90,
0,
43,
4,
"VAE"
],
[
137,
88,
0,
89,
0,
"*"
],
[
170,
89,
0,
30,
1,
"MODEL"
],
[
171,
90,
0,
107,
0,
"*"
],
[
172,
107,
0,
30,
4,
"VAE"
],
[
176,
110,
0,
30,
0,
"IMAGE"
],
[
190,
43,
0,
30,
2,
"CONDITIONING"
],
[
191,
43,
1,
30,
3,
"CONDITIONING"
],
[
192,
52,
0,
43,
2,
"CONTROL_NET"
],
[
195,
122,
0,
123,
0,
"UPSCALE_MODEL"
],
[
197,
123,
0,
34,
0,
"IMAGE"
],
[
200,
124,
0,
125,
0,
"*"
],
[
206,
125,
0,
43,
3,
"IMAGE"
],
[
207,
135,
0,
124,
0,
"*"
],
[
209,
30,
0,
136,
0,
"IMAGE"
],
[
210,
124,
0,
123,
1,
"IMAGE"
],
[
211,
34,
0,
110,
0,
"*"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 1.0152559799477252,
"offset": [
-2260.53316345765,
-499.7179536588252
]
},
"frontendVersion": "1.23.4",
"VHS_latentpreview": false,
"VHS_latentpreviewrate": 0,
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"version": 0.4
}
================================================
FILE: workflows/distributed-wan-2.2_14b_t2v.json
================================================
{
"id": "8968d33f-abd1-4e8a-8e55-5d87a104afb8",
"revision": 0,
"last_node_id": 92,
"last_link_id": 187,
"nodes": [
{
"id": 82,
"type": "CreateVideo",
"pos": [
640,
1460
],
"size": [
270,
78
],
"flags": {
"collapsed": true
},
"order": 16,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 172
},
{
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": null
}
],
"outputs": [
{
"name": "VIDEO",
"type": "VIDEO",
"links": [
187
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.45",
"Node name for S&R": "CreateVideo"
},
"widgets_values": [
16
]
},
{
"id": 80,
"type": "CreateVideo",
"pos": [
20,
1450
],
"size": [
270,
78
],
"flags": {
"collapsed": true
},
"order": 15,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 170
},
{
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": null
}
],
"outputs": [
{
"name": "VIDEO",
"type": "VIDEO",
"links": [
186
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.45",
"Node name for S&R": "CreateVideo"
},
"widgets_values": [
16
]
},
{
"id": 78,
"type": "CreateVideo",
"pos": [
970,
480
],
"size": [
270,
78
],
"flags": {
"collapsed": true
},
"order": 14,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 168
},
{
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": null
}
],
"outputs": [
{
"name": "VIDEO",
"type": "VIDEO",
"links": [
184
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.45",
"Node name for S&R": "CreateVideo"
},
"widgets_values": [
16
]
},
{
"id": 60,
"type": "CreateVideo",
"pos": [
80,
610
],
"size": [
270,
78
],
"flags": {
"collapsed": true
},
"order": 13,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 166
},
{
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": null
}
],
"outputs": [
{
"name": "VIDEO",
"type": "VIDEO",
"links": [
185
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.45",
"Node name for S&R": "CreateVideo"
},
"widgets_values": [
16
]
},
{
"id": 77,
"type": "ImageBatchDivider",
"pos": [
930,
380
],
"size": [
270,
118
],
"flags": {
"collapsed": true
},
"order": 12,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 165
}
],
"outputs": [
{
"name": "batch_1",
"type": "IMAGE",
"links": [
166
]
},
{
"name": "batch_2",
"type": "IMAGE",
"links": [
168
]
},
{
"name": "batch_3",
"type": "IMAGE",
"links": [
170
]
},
{
"name": "batch_4",
"type": "IMAGE",
"links": [
172
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "ComfyUI-Distributed",
"ver": "ae3201accb0161040bfd7c5705b08874726b1853",
"Node name for S&R": "ImageBatchDivider"
},
"widgets_values": [
4
]
},
{
"id": 67,
"type": "DistributedCollector",
"pos": [
700,
380
],
"size": [
166.50416564941406,
26
],
"flags": {
"collapsed": true
},
"order": 11,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 141
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
165
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "ComfyUI-Distributed",
"ver": "b54f2be19cab29543068d9c4355d9c5b773bee0d",
"Node name for S&R": "DistributedCollector"
},
"widgets_values": []
},
{
"id": 8,
"type": "VAEDecode",
"pos": [
520,
380
],
"size": [
210,
46
],
"flags": {
"collapsed": true
},
"order": 10,
"mode": 0,
"inputs": [
{
"name": "samples",
"type": "LATENT",
"link": 178
},
{
"name": "vae",
"type": "VAE",
"link": 76
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"slot_index": 0,
"links": [
141
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.45",
"Node name for S&R": "VAEDecode"
},
"widgets_values": []
},
{
"id": 7,
"type": "CLIPTextEncode",
"pos": [
580,
340
],
"size": [
430,
180
],
"flags": {
"collapsed": true
},
"order": 5,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 148
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"slot_index": 0,
"links": [
176
]
}
],
"title": "CLIP Text Encode (Negative Prompt)",
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.45",
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
],
"color": "#322",
"bgcolor": "#533"
},
{
"id": 57,
"type": "KSamplerAdvanced",
"pos": [
890,
10
],
"size": [
304.748046875,
334
],
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 139
},
{
"name": "positive",
"type": "CONDITIONING",
"link": 174
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 176
},
{
"name": "latent_image",
"type": "LATENT",
"link": 179
},
{
"name": "noise_seed",
"type": "INT",
"widget": {
"name": "noise_seed"
},
"link": 143
}
],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"links": [
178
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.45",
"Node name for S&R": "KSamplerAdvanced"
},
"widgets_values": [
"enable",
1070591872081175,
"randomize",
4,
1,
"euler",
"simple",
0,
1000,
"disable"
]
},
{
"id": 68,
"type": "DistributedSeed",
"pos": [
560,
100
],
"size": [
270,
82
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "seed",
"type": "INT",
"links": [
143
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "ComfyUI-Distributed",
"ver": "b54f2be19cab29543068d9c4355d9c5b773bee0d",
"Node name for S&R": "DistributedSeed"
},
"widgets_values": [
349924686792776,
"randomize"
]
},
{
"id": 66,
"type": "LoraLoaderModelOnly",
"pos": [
320,
0
],
"size": [
240,
82
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 134
}
],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
135
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "LoraLoaderModelOnly"
},
"widgets_values": [
"WAN\\Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors",
0.8000000000000002
]
},
{
"id": 71,
"type": "CLIPLoaderGGUF",
"pos": [
30,
70
],
"size": [
270,
82
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
148,
149
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui-gguf",
"ver": "1.1.1",
"Node name for S&R": "CLIPLoaderGGUF"
},
"widgets_values": [
"umt5-xxl-encoder-Q8_0.gguf",
"wan"
]
},
{
"id": 63,
"type": "UnetLoaderGGUF",
"pos": [
30,
0
],
"size": [
270,
58
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
134
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui-gguf",
"ver": "1.1.1",
"Node name for S&R": "UnetLoaderGGUF"
},
"widgets_values": [
"Wan2.2-T2V-A14B-LowNoise-Q8_0.gguf"
]
},
{
"id": 54,
"type": "ModelSamplingSD3",
"pos": [
610,
0
],
"size": [
210,
60
],
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 135
}
],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"slot_index": 0,
"links": [
139
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.45",
"Node name for S&R": "ModelSamplingSD3"
},
"widgets_values": [
8.000000000000002
]
},
{
"id": 90,
"type": "SaveVideo",
"pos": [
20,
420
],
"size": [
580,
678
],
"flags": {},
"order": 17,
"mode": 0,
"inputs": [
{
"name": "video",
"type": "VIDEO",
"link": 185
}
],
"outputs": [],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.47",
"Node name for S&R": "SaveVideo"
},
"widgets_values": [
"video/ComfyUI",
"auto",
"auto"
]
},
{
"id": 91,
"type": "SaveVideo",
"pos": [
20,
1130
],
"size": [
580,
678
],
"flags": {},
"order": 19,
"mode": 0,
"inputs": [
{
"name": "video",
"type": "VIDEO",
"link": 186
}
],
"outputs": [],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.47",
"Node name for S&R": "SaveVideo"
},
"widgets_values": [
"video/ComfyUI",
"auto",
"auto"
]
},
{
"id": 89,
"type": "SaveVideo",
"pos": [
610,
420
],
"size": [
580,
678
],
"flags": {},
"order": 18,
"mode": 0,
"inputs": [
{
"name": "video",
"type": "VIDEO",
"link": 184
}
],
"outputs": [],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.47",
"Node name for S&R": "SaveVideo"
},
"widgets_values": [
"video/ComfyUI",
"auto",
"auto"
]
},
{
"id": 92,
"type": "SaveVideo",
"pos": [
610,
1130
],
"size": [
580,
678
],
"flags": {},
"order": 20,
"mode": 0,
"inputs": [
{
"name": "video",
"type": "VIDEO",
"link": 187
}
],
"outputs": [],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.47",
"Node name for S&R": "SaveVideo"
},
"widgets_values": [
"video/ComfyUI",
"auto",
"auto"
]
},
{
"id": 39,
"type": "VAELoader",
"pos": [
30,
340
],
"size": [
320,
58
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "VAE",
"type": "VAE",
"slot_index": 0,
"links": [
76
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.45",
"Node name for S&R": "VAELoader",
"models": [
{
"name": "wan_2.1_vae.safetensors",
"url": "https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors",
"directory": "vae"
}
]
},
"widgets_values": [
"wan_2.1_vae.safetensors"
]
},
{
"id": 59,
"type": "EmptyHunyuanLatentVideo",
"pos": [
40,
180
],
"size": [
315,
130
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "LATENT",
"type": "LATENT",
"slot_index": 0,
"links": [
179
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.45",
"Node name for S&R": "EmptyHunyuanLatentVideo"
},
"widgets_values": [
704,
704,
33,
1
]
},
{
"id": 6,
"type": "CLIPTextEncode",
"pos": [
380,
210
],
"size": [
460,
88
],
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 149
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"slot_index": 0,
"links": [
174
]
}
],
"title": "CLIP Text Encode (Positive Prompt)",
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.45",
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"A beautiful woman holding a GPU at a photoshoot"
],
"color": "#232",
"bgcolor": "#353"
}
],
"links": [
[
76,
39,
0,
8,
1,
"VAE"
],
[
134,
63,
0,
66,
0,
"MODEL"
],
[
135,
66,
0,
54,
0,
"MODEL"
],
[
139,
54,
0,
57,
0,
"MODEL"
],
[
141,
8,
0,
67,
0,
"IMAGE"
],
[
143,
68,
0,
57,
4,
"INT"
],
[
148,
71,
0,
7,
0,
"CLIP"
],
[
149,
71,
0,
6,
0,
"CLIP"
],
[
165,
67,
0,
77,
0,
"IMAGE"
],
[
166,
77,
0,
60,
0,
"IMAGE"
],
[
168,
77,
1,
78,
0,
"IMAGE"
],
[
170,
77,
2,
80,
0,
"IMAGE"
],
[
172,
77,
3,
82,
0,
"IMAGE"
],
[
174,
6,
0,
57,
1,
"CONDITIONING"
],
[
176,
7,
0,
57,
2,
"CONDITIONING"
],
[
178,
57,
0,
8,
0,
"LATENT"
],
[
179,
59,
0,
57,
3,
"LATENT"
],
[
184,
78,
0,
89,
0,
"VIDEO"
],
[
185,
60,
0,
90,
0,
"VIDEO"
],
[
186,
80,
0,
91,
0,
"VIDEO"
],
[
187,
82,
0,
92,
0,
"VIDEO"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.693433494944177,
"offset": [
630.9170235538304,
115.03441263315318
]
},
"frontendVersion": "1.23.4",
"VHS_latentpreview": false,
"VHS_latentpreviewrate": 0,
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"version": 0.4
}
================================================
FILE: workflows/distributed-wan.json
================================================
{
"id": "00000000-0000-0000-0000-000000000000",
"revision": 0,
"last_node_id": 234,
"last_link_id": 79,
"nodes": [
{
"id": 67,
"type": "ModelSamplingSD3",
"pos": [
1700,
550
],
"size": [
270,
58
],
"flags": {},
"order": 14,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 23
}
],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
17,
26
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "ModelSamplingSD3"
},
"widgets_values": [
8.000000000000002
]
},
{
"id": 50,
"type": "WanImageToVideo",
"pos": [
1590,
230
],
"size": [
270,
210
],
"flags": {},
"order": 18,
"mode": 0,
"inputs": [
{
"name": "positive",
"type": "CONDITIONING",
"link": 7
},
{
"name": "negative",
"type": "CONDITIONING",
"link": 8
},
{
"name": "vae",
"type": "VAE",
"link": 9
},
{
"name": "clip_vision_output",
"shape": 7,
"type": "CLIP_VISION_OUTPUT",
"link": 10
},
{
"name": "start_image",
"shape": 7,
"type": "IMAGE",
"link": 11
},
{
"name": "width",
"type": "INT",
"widget": {
"name": "width"
},
"link": 5
},
{
"name": "height",
"type": "INT",
"widget": {
"name": "height"
},
"link": 6
}
],
"outputs": [
{
"name": "positive",
"type": "CONDITIONING",
"links": [
27
]
},
{
"name": "negative",
"type": "CONDITIONING",
"links": null
},
{
"name": "latent",
"type": "LATENT",
"links": [
22
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "WanImageToVideo"
},
"widgets_values": [
832,
480,
17,
1
]
},
{
"id": 60,
"type": "RandomNoise",
"pos": [
1850,
10
],
"size": [
270,
82
],
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "noise_seed",
"type": "INT",
"widget": {
"name": "noise_seed"
},
"link": 16
}
],
"outputs": [
{
"name": "NOISE",
"type": "NOISE",
"links": [
18
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "RandomNoise"
},
"widgets_values": [
200526686850175,
"randomize"
]
},
{
"id": 150,
"type": "DistributedSeed",
"pos": [
1540,
0
],
"size": [
270,
82
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "seed",
"type": "INT",
"links": [
16
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "ComfyUI-Distributed",
"ver": "9650280c7f50898720bf9773e8874d75e7c20846",
"Node name for S&R": "DistributedSeed"
},
"widgets_values": [
177896884005433,
"randomize"
]
},
{
"id": 61,
"type": "KSamplerSelect",
"pos": [
2140,
350
],
"size": [
260,
60
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "SAMPLER",
"type": "SAMPLER",
"links": [
20
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "KSamplerSelect"
},
"widgets_values": [
"lcm"
]
},
{
"id": 221,
"type": "DistributedCollector",
"pos": [
2660,
200
],
"size": [
166.50416564941406,
26
],
"flags": {
"collapsed": false
},
"order": 22,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 67
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
78
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "ComfyUI-Distributed",
"ver": "b4ed92a363fab14d944b87c610aa9a0b4c87c085",
"Node name for S&R": "DistributedCollector"
},
"widgets_values": []
},
{
"id": 231,
"type": "ImageBatchDivider",
"pos": [
2850,
200
],
"size": [
210,
118
],
"flags": {},
"order": 23,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 78
}
],
"outputs": [
{
"name": "batch_1",
"type": "IMAGE",
"links": [
74
]
},
{
"name": "batch_2",
"type": "IMAGE",
"links": [
75
]
},
{
"name": "batch_3",
"type": "IMAGE",
"links": [
76
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "ComfyUI-Distributed",
"ver": "fa09a2da10dfc63cecfbec1f8d4a0516e923b911",
"Node name for S&R": "ImageBatchDivider"
},
"widgets_values": [
3
]
},
{
"id": 139,
"type": "CLIPLoaderGGUF",
"pos": [
490,
60
],
"size": [
270,
82
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP",
"type": "CLIP",
"links": [
1,
2
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui-gguf",
"ver": "1.1.1",
"Node name for S&R": "CLIPLoaderGGUF"
},
"widgets_values": [
"umt5-xxl-encoder-Q8_0.gguf",
"wan"
]
},
{
"id": 142,
"type": "UnetLoaderGGUF",
"pos": [
930,
850
],
"size": [
340,
60
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
24
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui-gguf",
"ver": "1.1.1",
"Node name for S&R": "UnetLoaderGGUF"
},
"widgets_values": [
"wan2.1-i2v-14b-720p-Q8_0.gguf"
]
},
{
"id": 118,
"type": "LoraLoaderModelOnly",
"pos": [
1290,
840
],
"size": [
310,
82
],
"flags": {},
"order": 12,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 24
}
],
"outputs": [
{
"name": "MODEL",
"type": "MODEL",
"links": [
23
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "LoraLoaderModelOnly"
},
"widgets_values": [
"Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors",
0.8100000000000002
]
},
{
"id": 135,
"type": "ImageNoiseAugmentation",
"pos": [
1250,
310
],
"size": [
270,
106
],
"flags": {},
"order": 16,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 25
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
11
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui-kjnodes",
"ver": "1.1.2",
"Node name for S&R": "ImageNoiseAugmentation"
},
"widgets_values": [
0.10000000000000002,
1100164865582526,
"randomize"
]
},
{
"id": 49,
"type": "CLIPVisionLoader",
"pos": [
1260,
610
],
"size": [
290,
60
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "CLIP_VISION",
"type": "CLIP_VISION",
"links": [
12
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "CLIPVisionLoader"
},
"widgets_values": [
"clip_vision_h.safetensors"
]
},
{
"id": 7,
"type": "CLIPTextEncode",
"pos": [
790,
120
],
"size": [
380,
110
],
"flags": {},
"order": 11,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 2
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
8
]
}
],
"title": "CLIP Text Encode (Negative Prompt)",
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
]
},
{
"id": 64,
"type": "SamplerCustomAdvanced",
"pos": [
2140,
190
],
"size": [
260,
110
],
"flags": {},
"order": 20,
"mode": 0,
"inputs": [
{
"name": "noise",
"type": "NOISE",
"link": 18
},
{
"name": "guider",
"type": "GUIDER",
"link": 19
},
{
"name": "sampler",
"type": "SAMPLER",
"link": 20
},
{
"name": "sigmas",
"type": "SIGMAS",
"link": 21
},
{
"name": "latent_image",
"type": "LATENT",
"link": 22
}
],
"outputs": [
{
"name": "output",
"type": "LATENT",
"links": [
3
]
},
{
"name": "denoised_output",
"type": "LATENT",
"links": null
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "SamplerCustomAdvanced"
},
"widgets_values": []
},
{
"id": 233,
"type": "Note",
"pos": [
2840,
40
],
"size": [
220,
90
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [],
"outputs": [],
"properties": {},
"widgets_values": [
"Use this node to set how many videos to output.\n\nExample: if you have 1x master and 2x workers, set it to 3."
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 234,
"type": "Note",
"pos": [
1540,
-140
],
"size": [
270,
90
],
"flags": {},
"order": 6,
"mode": 0,
"inputs": [],
"outputs": [],
"properties": {},
"widgets_values": [
"This node will give you a different variation from each worker. Delete this node if you want them to be all the same."
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 227,
"type": "VHS_VideoCombine",
"pos": [
3090,
200
],
"size": [
380,
555.6923217773438
],
"flags": {},
"order": 24,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 74
},
{
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": null
},
{
"name": "meta_batch",
"shape": 7,
"type": "VHS_BatchManager",
"link": null
},
{
"name": "vae",
"shape": 7,
"type": "VAE",
"link": null
}
],
"outputs": [
{
"name": "Filenames",
"type": "VHS_FILENAMES",
"links": null
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui-videohelpersuite",
"ver": "a7ce59e381934733bfae03b1be029756d6ce936d",
"Node name for S&R": "VHS_VideoCombine"
},
"widgets_values": {
"frame_rate": 16,
"loop_count": 0,
"filename_prefix": "WAN",
"format": "video/h264-mp4",
"pix_fmt": "yuv420p",
"crf": 19,
"save_metadata": true,
"trim_to_audio": false,
"pingpong": false,
"save_output": true,
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"filename": "AnimateDiff_00163.mp4",
"subfolder": "",
"type": "output",
"format": "video/h264-mp4",
"frame_rate": 8,
"workflow": "AnimateDiff_00163.png",
"fullpath": "C:\\venvs\\ComfyUI\\ComfyUI\\output\\AnimateDiff_00163.mp4"
}
}
}
},
{
"id": 228,
"type": "VHS_VideoCombine",
"pos": [
3490,
190
],
"size": [
380,
334
],
"flags": {},
"order": 25,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 75
},
{
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": null
},
{
"name": "meta_batch",
"shape": 7,
"type": "VHS_BatchManager",
"link": null
},
{
"name": "vae",
"shape": 7,
"type": "VAE",
"link": null
}
],
"outputs": [
{
"name": "Filenames",
"type": "VHS_FILENAMES",
"links": null
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui-videohelpersuite",
"ver": "a7ce59e381934733bfae03b1be029756d6ce936d",
"Node name for S&R": "VHS_VideoCombine"
},
"widgets_values": {
"frame_rate": 16,
"loop_count": 0,
"filename_prefix": "WAN",
"format": "video/h264-mp4",
"pix_fmt": "yuv420p",
"crf": 19,
"save_metadata": true,
"trim_to_audio": false,
"pingpong": false,
"save_output": true,
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"filename": "AnimateDiff_00165.mp4",
"subfolder": "",
"type": "output",
"format": "video/h264-mp4",
"frame_rate": 8,
"workflow": "AnimateDiff_00165.png",
"fullpath": "C:\\venvs\\ComfyUI\\ComfyUI\\output\\AnimateDiff_00165.mp4"
}
}
}
},
{
"id": 229,
"type": "VHS_VideoCombine",
"pos": [
3880,
190
],
"size": [
390,
334
],
"flags": {},
"order": 26,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 76
},
{
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": null
},
{
"name": "meta_batch",
"shape": 7,
"type": "VHS_BatchManager",
"link": null
},
{
"name": "vae",
"shape": 7,
"type": "VAE",
"link": null
}
],
"outputs": [
{
"name": "Filenames",
"type": "VHS_FILENAMES",
"links": null
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui-videohelpersuite",
"ver": "a7ce59e381934733bfae03b1be029756d6ce936d",
"Node name for S&R": "VHS_VideoCombine"
},
"widgets_values": {
"frame_rate": 16,
"loop_count": 0,
"filename_prefix": "WAN",
"format": "video/h264-mp4",
"pix_fmt": "yuv420p",
"crf": 19,
"save_metadata": true,
"trim_to_audio": false,
"pingpong": false,
"save_output": true,
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"filename": "AnimateDiff_00164.mp4",
"subfolder": "",
"type": "output",
"format": "video/h264-mp4",
"frame_rate": 8,
"workflow": "AnimateDiff_00164.png",
"fullpath": "C:\\venvs\\ComfyUI\\ComfyUI\\output\\AnimateDiff_00164.mp4"
}
}
}
},
{
"id": 8,
"type": "VAEDecode",
"pos": [
2450,
190
],
"size": [
140,
46
],
"flags": {},
"order": 21,
"mode": 0,
"inputs": [
{
"name": "samples",
"type": "LATENT",
"link": 3
},
{
"name": "vae",
"type": "VAE",
"link": 4
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
67
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "VAEDecode"
},
"widgets_values": []
},
{
"id": 51,
"type": "CLIPVisionEncode",
"pos": [
1260,
490
],
"size": [
290.97918701171875,
78
],
"flags": {},
"order": 15,
"mode": 0,
"inputs": [
{
"name": "clip_vision",
"type": "CLIP_VISION",
"link": 12
},
{
"name": "image",
"type": "IMAGE",
"link": 13
}
],
"outputs": [
{
"name": "CLIP_VISION_OUTPUT",
"type": "CLIP_VISION_OUTPUT",
"links": [
10
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "CLIPVisionEncode"
},
"widgets_values": [
"none"
]
},
{
"id": 192,
"type": "LoadImage",
"pos": [
530,
310
],
"size": [
310,
370
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
79
]
},
{
"name": "MASK",
"type": "MASK",
"links": null
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"example.png",
"image"
]
},
{
"id": 55,
"type": "ImageResize+",
"pos": [
890,
320
],
"size": [
270,
218
],
"flags": {},
"order": 13,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 79
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
13,
25
]
},
{
"name": "width",
"type": "INT",
"links": [
5
]
},
{
"name": "height",
"type": "INT",
"links": [
6
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfyui_essentials",
"ver": "9d9f4bedfc9f0321c19faf71855e228c93bd0dc9",
"Node name for S&R": "ImageResize+"
},
"widgets_values": [
832,
480,
"lanczos",
"fill / crop",
"always",
0
]
},
{
"id": 6,
"type": "CLIPTextEncode",
"pos": [
790,
-10
],
"size": [
380,
88
],
"flags": {},
"order": 10,
"mode": 0,
"inputs": [
{
"name": "clip",
"type": "CLIP",
"link": 1
}
],
"outputs": [
{
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
7
]
}
],
"title": "CLIP Text Encode (Positive Prompt)",
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "CLIPTextEncode"
},
"widgets_values": [
"beautiful scenery nature glass bottle landscape, , purple galaxy bottle,"
]
},
{
"id": 39,
"type": "VAELoader",
"pos": [
1250,
200
],
"size": [
270,
58
],
"flags": {},
"order": 8,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "VAE",
"type": "VAE",
"links": [
4,
9
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "VAELoader"
},
"widgets_values": [
"wan_2.1_vae.safetensors"
]
},
{
"id": 148,
"type": "BasicGuider",
"pos": [
1980,
260
],
"size": [
156.0208282470703,
46
],
"flags": {
"collapsed": true
},
"order": 19,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 26
},
{
"name": "conditioning",
"type": "CONDITIONING",
"link": 27
}
],
"outputs": [
{
"name": "GUIDER",
"type": "GUIDER",
"links": [
19
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "BasicGuider"
},
"widgets_values": []
},
{
"id": 62,
"type": "BasicScheduler",
"pos": [
2130,
460
],
"size": [
270,
106
],
"flags": {},
"order": 17,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "MODEL",
"link": 17
}
],
"outputs": [
{
"name": "SIGMAS",
"type": "SIGMAS",
"links": [
21
]
}
],
"properties": {
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.3.43",
"Node name for S&R": "BasicScheduler"
},
"widgets_values": [
"simple",
4,
1
]
}
],
"links": [
[
1,
139,
0,
6,
0,
"CLIP"
],
[
2,
139,
0,
7,
0,
"CLIP"
],
[
3,
64,
0,
8,
0,
"LATENT"
],
[
4,
39,
0,
8,
1,
"VAE"
],
[
5,
55,
1,
50,
5,
"INT"
],
[
6,
55,
2,
50,
6,
"INT"
],
[
7,
6,
0,
50,
0,
"CONDITIONING"
],
[
8,
7,
0,
50,
1,
"CONDITIONING"
],
[
9,
39,
0,
50,
2,
"VAE"
],
[
10,
51,
0,
50,
3,
"CLIP_VISION_OUTPUT"
],
[
11,
135,
0,
50,
4,
"IMAGE"
],
[
12,
49,
0,
51,
0,
"CLIP_VISION"
],
[
13,
55,
0,
51,
1,
"IMAGE"
],
[
16,
150,
0,
60,
0,
"INT"
],
[
17,
67,
0,
62,
0,
"MODEL"
],
[
18,
60,
0,
64,
0,
"NOISE"
],
[
19,
148,
0,
64,
1,
"GUIDER"
],
[
20,
61,
0,
64,
2,
"SAMPLER"
],
[
21,
62,
0,
64,
3,
"SIGMAS"
],
[
22,
50,
2,
64,
4,
"LATENT"
],
[
23,
118,
0,
67,
0,
"MODEL"
],
[
24,
142,
0,
118,
0,
"MODEL"
],
[
25,
55,
0,
135,
0,
"IMAGE"
],
[
26,
67,
0,
148,
0,
"MODEL"
],
[
27,
50,
0,
148,
1,
"CONDITIONING"
],
[
67,
8,
0,
221,
0,
"IMAGE"
],
[
74,
231,
0,
227,
0,
"IMAGE"
],
[
75,
231,
1,
228,
0,
"IMAGE"
],
[
76,
231,
2,
229,
0,
"IMAGE"
],
[
78,
221,
0,
231,
0,
"IMAGE"
],
[
79,
192,
0,
55,
0,
"IMAGE"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.9090909090909091,
"offset": [
-1587.6131275248988,
280.3225523752761
]
},
"frontendVersion": "1.23.4",
"VHS_latentpreview": false,
"VHS_latentpreviewrate": 0,
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"version": 0.4
}