Repository: bio-mlhui/LGRNet Branch: main Commit: e97ab7391f2b Files: 117 Total size: 13.2 MB Directory structure: gitextract__wc8fr5u/ ├── .gitignore ├── README.md ├── assets/ │ ├── DATA.md │ ├── INSTALL.md │ └── MODEL_ZOO.md ├── data_schedule/ │ ├── __init__.py │ ├── registry.py │ ├── utils/ │ │ ├── box_ops.py │ │ ├── sampler.py │ │ └── segmentation.py │ └── vis/ │ ├── __init__.py │ ├── apis.py │ ├── evaluator_fast.py │ ├── evaluator_utils.py │ ├── fibroid/ │ │ ├── __init__.py │ │ ├── evals.py │ │ ├── fibroid_dataset.py │ │ ├── fibroid_utils.py │ │ └── metrics.py │ ├── mapper.py │ ├── mapper_utils.py │ ├── polyp/ │ │ ├── __init__.py │ │ ├── evals.py │ │ ├── polyp_dataset.py │ │ └── polyp_utils.py │ ├── vis_aug_eval.py │ ├── vis_aug_train.py │ ├── vis_aug_utils.py │ └── vis_frame_sampler.py ├── handle_vps.py ├── main.py ├── models/ │ ├── VIS/ │ │ ├── BackboneEncoderDecoder_WithScaleConsistency.py │ │ ├── __init__.py │ │ └── aux_mapper.py │ ├── __init__.py │ ├── backbone/ │ │ ├── __init__.py │ │ ├── pvtv2.py │ │ ├── res2net.py │ │ └── utils.py │ ├── decoder/ │ │ ├── __init__.py │ │ └── mask2former_video.py │ ├── encoder/ │ │ ├── __init__.py │ │ ├── input_projs.py │ │ ├── localGlobal.py │ │ ├── neighborhood_qk.py │ │ └── ops/ │ │ ├── MultiScaleDeformableAttention.egg-info/ │ │ │ └── PKG-INFO │ │ ├── attention.py │ │ ├── build/ │ │ │ ├── lib.linux-x86_64-cpython-311/ │ │ │ │ ├── functions/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ └── modules/ │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn.py │ │ │ ├── lib.linux-x86_64-cpython-38/ │ │ │ │ ├── functions/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── ms_deform_attn_func.py │ │ │ │ └── modules/ │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn.py │ │ │ ├── temp.linux-x86_64-cpython-311/ │ │ │ │ ├── .ninja_deps │ │ │ │ ├── .ninja_log │ │ │ │ ├── build.ninja │ │ │ │ └── home/ │ │ │ │ └── xhh/ │ │ │ │ └── workspace/ │ │ │ │ └── rvos_encoder/ │ │ │ │ └── models/ │ │ │ │ └── ops/ │ │ │ │ └── src/ 
│ │ │ │ ├── cpu/ │ │ │ │ │ └── ms_deform_attn_cpu.o │ │ │ │ ├── cuda/ │ │ │ │ │ └── ms_deform_attn_cuda.o │ │ │ │ └── vision.o │ │ │ └── temp.linux-x86_64-cpython-38/ │ │ │ └── home/ │ │ │ └── xhh/ │ │ │ └── workspace/ │ │ │ └── ReferFormer/ │ │ │ └── models/ │ │ │ └── ops/ │ │ │ └── src/ │ │ │ ├── cpu/ │ │ │ │ └── ms_deform_attn_cpu.o │ │ │ ├── cuda/ │ │ │ │ └── ms_deform_attn_cuda.o │ │ │ └── vision.o │ │ ├── dist/ │ │ │ ├── MultiScaleDeformableAttention-1.0-py3.11-linux-x86_64.egg │ │ │ └── MultiScaleDeformableAttention-1.0-py3.8-linux-x86_64.egg │ │ ├── functions/ │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn_func.py │ │ ├── make.sh │ │ ├── modules/ │ │ │ ├── __init__.py │ │ │ ├── frame_query_ss2d.py │ │ │ └── ms_deform_attn.py │ │ ├── setup.py │ │ ├── src/ │ │ │ ├── cpu/ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ ├── cuda/ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ ├── ms_deform_attn.h │ │ │ └── vision.cpp │ │ └── test.py │ ├── layers/ │ │ ├── anyc_trans.py │ │ ├── decoder_layers.py │ │ ├── gilbert/ │ │ │ ├── demo/ │ │ │ │ ├── index.html │ │ │ │ ├── normalize.css │ │ │ │ ├── script.js │ │ │ │ ├── skeleton.css │ │ │ │ └── two.js │ │ │ ├── gilbert2d.py │ │ │ ├── gilbert3d.py │ │ │ ├── gilbert_d2xy.py │ │ │ ├── gilbert_d2xyz.py │ │ │ ├── gilbert_xy2d.py │ │ │ ├── gilbert_xyz2d.py │ │ │ ├── plotpath.m │ │ │ ├── ports/ │ │ │ │ ├── Makefile │ │ │ │ ├── gilbert.c │ │ │ │ └── gilbert.js │ │ │ ├── test.py │ │ │ └── tests/ │ │ │ └── runtests.sh │ │ ├── matching.py │ │ ├── position_encoding.py │ │ └── utils.py │ ├── modality_input_mappers/ │ │ ├── __init__.py │ │ └── hilbert_curve.py │ ├── optimization/ │ │ ├── optimizer.py │ │ └── scheduler.py │ └── registry.py ├── output/ │ └── VIS/ │ ├── cvc/ │ │ └── pvt.py │ ├── fibroid/ │ │ └── pvt.py │ └── sunseg/ │ ├── pvt/ │ │ └── pvt.py │ └── res/ │ └── res.py ├── reorganize_sunseg.py ├── trainers/ │ ├── Trainer.py │ └── 
__init__.py └── utils/ ├── __init__.py └── misc.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ **/__pycache__/** **/wandb/** **/.vscode/** *.out *.err *.zip *.tar *.pth stdout_train.txt stdout_eval.txt stdout_visualize.txt hostfile **/olds/** **/old/** *.pt *.txt miccai_generate_Dir generate.py visualzie.py ================================================ FILE: README.md ================================================ ## LGRNet: Local-Global Reciprocal Network for Video Polyp Segmentation [`Paper`](https://arxiv.org/abs/2407.05703) | [`BibTeX`](#citing) | [`Huggingface(UFUV Dataset)`](https://huggingface.co/datasets/huihuixu/uterine_fibroid_ultrasound_video_segmentation) Huihui Xu, Yijun Yang, Angelica Aviles-Rivero, Guang Yang, Jing Qin, and Lei Zhu [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/lgrnet-local-global-reciprocal-network-for/video-polyp-segmentation-on-sun-seg-hard)](https://paperswithcode.com/sota/video-polyp-segmentation-on-sun-seg-hard?p=lgrnet-local-global-reciprocal-network-for) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/lgrnet-local-global-reciprocal-network-for/video-polyp-segmentation-on-sun-seg-easy)](https://paperswithcode.com/sota/video-polyp-segmentation-on-sun-seg-easy?p=lgrnet-local-global-reciprocal-network-for) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/lgrnet-local-global-reciprocal-network-for/video-polyp-segmentation-on-sun-seg-easy-1)](https://paperswithcode.com/sota/video-polyp-segmentation-on-sun-seg-easy-1?p=lgrnet-local-global-reciprocal-network-for) 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/lgrnet-local-global-reciprocal-network-for/video-polyp-segmentation-on-sun-seg-hard-1)](https://paperswithcode.com/sota/video-polyp-segmentation-on-sun-seg-hard-1?p=lgrnet-local-global-reciprocal-network-for) This is the official implementation of LGRNet (MICCAI'24 Early Accept), which incorporates local **[Cyclic Neighborhood Propagation](https://github.com/bio-mlhui/LGRNet/blob/main/models/encoder/neighborhood_qk.py#L57)** and global **[Hilbert Selective Scan](https://github.com/bio-mlhui/LGRNet/blob/main/models/encoder/ops/modules/frame_query_ss2d.py#L531)**. Together with the notion of **[Frame Bottleneck Queries](https://github.com/bio-mlhui/LGRNet/blob/main/models/encoder/localGlobal.py#L185)**, LGRNet can both efficiently and effectively aggregate the local-global temporal context, achieving *state-of-the-art* results on the public [Video Polyp Segmentation (VPS)](https://paperswithcode.com/task/video-polyp-segmentation) benchmark.
For example, in ultrasound video a single frame is too noisy and insufficient for accurate lesion diagnosis. In practice, doctors need to check neighboring frames (local) and collect all visual clues (global) in the video to predict possible lesion regions and filter out irrelevant surrounding tissues.


In Cyclic Neighborhood Propagation (CNP), each token takes the neighborhood tokens (defined by a kernel) in the cyclic frame as attention keys. CNP enables aggregating the local (cyclic) temporal information into one token. In Hilbert Selective Scan, a set of frame bottleneck queries is used to aggregate spatial information from each frame. Then, we use Hilbert Selective Scan to efficiently parse the global temporal context based on these bottleneck queries. The global temporal context is then propagated back to the feature maps by a Distribute layer. Based on Mask2Former, the decoder can output a set of different mask predictions with corresponding confidence scores, which also facilitates comprehensive diagnosis.
## Items 1. Installation: Please refer to [INSTALL.md](assets/INSTALL.md) for more details. 2. Data preparation: Please refer to [DATA.md](assets/DATA.md) for more details. 3. Training: Change PORT_NUM for DDP and make sure the $CURRENT_TASK is 'VIS': ``` export CURRENT_TASK=VIS export MASTER_ADDR=127.0.0.1 export MASTER_PORT=PORT_NUM ``` Make sure the $PT_PATH and $DATASET_PATH are correctly set during installation and preparing data. The training on SUN-SEG is conducted using 2 4090-24GB GPUs: ``` CUDA_VISIBLE_DEVICES=0,1 TORCH_NUM_WORKERS=8 python main.py --config_file output/VIS/sunseg/pvt/pvt.py --trainer_mode train_attmpt ``` 4. logs, checkpoints, predictions | Backbone| Dataset | Dice | mIou | log | ckpt | predictions | | :----: | :----: | :----: | :----: | :----: | :----: |:----: | | PVTv2-B2 | SUN-SEG-Train | -- | -- | [log](https://drive.google.com/file/d/17MTOYW73RLbvZS3BLFBZEphY_0JzN6er/view?usp=sharing) | [ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing) | -- | PVTv2-B2 | SUN-SEG-Hard-Testing | 0.876 | 0.805 | [log](https://drive.google.com/file/d/1wdVMWMknSlURaBROWbMax4iS9V1Tbn9-/view?usp=sharing) |[ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing) | [mask predictions](https://drive.google.com/file/d/1V8CDMC87o7t4eyts4BVEwflDUrFpAOVX/view?usp=sharing) | PVTv2-B2 | SUN-SEG-Easy-Testing | 0.875 | 0.810 | [log](https://drive.google.com/file/d/1wdVMWMknSlURaBROWbMax4iS9V1Tbn9-/view?usp=sharing) |[ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing) | [mask predictions](https://drive.google.com/file/d/1V8CDMC87o7t4eyts4BVEwflDUrFpAOVX/view?usp=sharing) | PVTv2-B2 | SUN-SEG-Hard-Unseen-Testing | 0.865 | 0.792 | [log](https://drive.google.com/file/d/1obt_qvWCvslhRY-e4SrTJNS0r6Diad4e/view?usp=sharing) | [ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing) | [mask 
predictions](https://drive.google.com/file/d/1V8CDMC87o7t4eyts4BVEwflDUrFpAOVX/view?usp=sharing) | PVTv2-B2 | SUN-SEG-Easy-Unseen-Testing | 0.853 | 0.783 | [log](https://drive.google.com/file/d/1obt_qvWCvslhRY-e4SrTJNS0r6Diad4e/view?usp=sharing) | [ckpt](https://drive.google.com/file/d/1D4YAIfFCCQIsDfKgSCr9tCw7vDAgqf76/view?usp=sharing)| [mask predictions](https://drive.google.com/file/d/1V8CDMC87o7t4eyts4BVEwflDUrFpAOVX/view?usp=sharing) | Res2Net-50 | SUN-SEG-Hard-Testing | 0.841 | 0.765 | [log](https://drive.google.com/file/d/17pUxFMuHpPD_In5RVrJUsPFZGOgNFzb6/view?usp=sharing) | | Res2Net-50 | SUN-SEG-Easy-Testing | 0.843 | 0.774 | [log](https://drive.google.com/file/d/17pUxFMuHpPD_In5RVrJUsPFZGOgNFzb6/view?usp=sharing) | | PVTv2-B2 | CVC612V | 0.933 | 0.877 | [log](https://drive.google.com/file/d/1m36mJL0Fu3T9F73TqFGnFsWaCGh3JDeJ/view?usp=drive_link) | | PVTv2-B2 | CVC300TV | 0.916 | 0.852 | [log](https://drive.google.com/file/d/1m36mJL0Fu3T9F73TqFGnFsWaCGh3JDeJ/view?usp=drive_link) | | PVTv2-B2 | CVC612T | 0.875 | 0.814 | [log](https://drive.google.com/file/d/1m36mJL0Fu3T9F73TqFGnFsWaCGh3JDeJ/view?usp=drive_link) | 5. Evaluate: Evaluating on SUN-SEG-Easy AND SUN-SEG-Hard using 1 4090-24GPU GPUS (**modify the ckpt_path to the absolute path**): ``` CUDA_VISIBLE_DEVICES=0 TORCH_NUM_WORKERS=8 python main.py --config_file output/VIS/sunseg/pvt/pvt.py --trainer_mode eval --eval_path ckpt_path ``` ## citing ``` @article{xu2024lgrnet, title={LGRNet: Local-Global Reciprocal Network for Uterine Fibroid Segmentation in Ultrasound Videos}, author={Xu, Huihui and Yang, Yijun and Aviles-Rivero, Angelica I and Yang, Guang and Qin, Jing and Zhu, Lei}, journal={arXiv preprint arXiv:2407.05703}, year={2024} } ``` ## Acknowledgments - Thanks [Gilbert](https://github.com/jakubcerveny/gilbert) for the implementation of Hilbert curve generation. - Thanks GPT4 for helping me constructing idea of Hilbert Filling Curve v.s. 
Zigzag curve ================================================ FILE: assets/DATA.md ================================================ # Data Preparation ## UFUV (Private): please email the second author if you would like to use the UFUV dataset; I do not have the authority to distribute UFUV myself. ## VPS (Public) ### CVC/Kvasir/Mayo We follow [PNS-Net](https://github.com/GewelsJI/PNS-Net) to download the CVC/Kvasir/Mayo dataset. The download link is the same: [link](https://drive.google.com/file/d/1TyaRy4c4nHFDa3o2bOl4dP5Z7wes7HV2/view?usp=sharing) Put MICCAI-VPS-dataset.zip in $DATASET_PATH, then run the following script to change the directory structure: ``` cd $DATASET_PATH unzip -qq MICCAI-VPS-dataset.zip # cd to the LGRNet directory # normalize the VPS data structure python handle_vps.py ``` Now the structure should look like: ``` ${DATASET_PATH} -- MICCAI-VPS-dataset -- Kvasir-SEG -- * -- VPS-TestSet -- CVC-ColonDB-300 -- * -- CVC-ClinicDB-612-Valid -- * -- CVC-ClinicDB-612-Test -- * -- VPS-TrainSet -- ASU-Mayo_Clinic -- Train -- * -- CVC-ClinicDB-612 -- Train -- * -- CVC-ColonDB-300 -- Train -- * ``` where * means the following structure: ``` -- Frame -- vid1 -- img file -- GT -- vid1 -- mask file ``` ### SUN-SEG Please follow https://github.com/GewelsJI/VPS/blob/main/docs/DATA_PREPARATION.md to request SUN-SEG from its authors by email. 
Put part1, part2, and the annotation archive in $DATASET_PATH/SUN-SEG ``` # normalize the directory unzip -qq $DATASET_PATH/SUN-SEG/sundatabase_positive_part1.zip -d $DATASET_PATH/SUN-SEG/SUN-Positive unzip -qq $DATASET_PATH/SUN-SEG/sundatabase_positive_part2.zip -d $DATASET_PATH/SUN-SEG/SUN-Positive tar -xf $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation.tar -C $DATASET_PATH/SUN-SEG/ rm -rf $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestEasyDataset/Unseen/Frame find $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation -name "._.DS_Store" -type f -delete python reorganize_sunseg.py ``` Now the structure should look like: ``` ${DATASET_PATH} -- SUN-SEG -- SUN-SEG-Annotation -- TrainDataset -- * -- TestEasyDataset -- combine -- * -- TestHardDataset -- combine -- * ``` ================================================ FILE: assets/INSTALL.md ================================================ # Install ## Requirements We tested the code in the following environment: - CUDA 12.1 - Python 3.10.13 - PyTorch 2.1.1 - Torchvision 0.16.1 - detectron2 0.6 - mamba_ssm 1.2.0.post1 - natten 0.15.1 - timm 0.9.12 ## Install environment for LGRNet ``` conda create --name lgrnet python=3.10 conda activate lgrnet # make sure CUDA 12.1 is installed and activated in your environment variables. # install torch pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu121 # install detectron2 (building may take a long time) python -m pip install 'git+https://github.com/facebookresearch/detectron2.git' # install mamba, see https://github.com/state-spaces/mamba cd .. git clone https://github.com/state-spaces/mamba.git cd mamba pip install . 
--no-build-isolation cd ../LGRNet # install natten, see https://github.com/SHI-Labs/NATTEN/blob/main/docs/install.md pip install natten==0.15.1+torch210cu121 -f https://shi-labs.com/natten/wheels/cu121/torch2.1.0/natten-0.15.1%2Btorch210cu121-cp310-cp310-linux_x86_64.whl # misc pip install albumentations==1.3.1 pip install Pygments pip install imgaug pip install timm==0.9.12 # compile deform attention cd models/encoder/ops/ python setup.py build install --user # download the Res2Net/PVTv2 checkpoints; our model uses the same backbones as WeakPolyp (https://github.com/weijun88/WeakPolyp) wget -P $PT_PATH/pvt_v2 https://huggingface.co/huihuixu/lgrnet_ckpts/resolve/main/pvt_v2_b2.pth wget -P $PT_PATH/res2net https://huggingface.co/huihuixu/lgrnet_ckpts/resolve/main/res2net50_v1b_26w_4s-3cf99910.pth ``` ================================================ FILE: assets/MODEL_ZOO.md ================================================ ================================================ FILE: data_schedule/__init__.py ================================================ import os if os.getenv('CURRENT_TASK') == 'VIS': from . 
import vis else: raise ValueError() def build_schedule(configs, model_input_mapper, model_input_collate_fn): import logging from functools import partial import detectron2.utils.comm as comm from torch.utils.data import DataLoader, ConcatDataset from .registry import MAPPER_REGISTRY, EVALUATOR_REGISTRY from detectron2.data import DatasetCatalog, DatasetFromList, MapDataset, MetadataCatalog from data_schedule.utils.sampler import Evaluate_ExactSampler_Distributed, Train_InfiniteSampler_Distributed datasets = {'train': [], 'evaluate': []} meta_idx_shift = 0 for mode in ['train', 'evaluate']: for dataset_name in configs['data'][mode].keys(): dataset_assume_mode = MetadataCatalog.get(dataset_name).get('mode') if dataset_assume_mode != mode: logging.warning(f'default mode of {dataset_name} is {dataset_assume_mode} not {mode}') dataset_dicts = DatasetFromList(DatasetCatalog.get(dataset_name), copy=True, serialize=True) mapper = MAPPER_REGISTRY.get(configs['data'][mode][dataset_name]['mapper']['name'])(mode=mode, dataset_name=dataset_name, configs=configs, meta_idx_shift=meta_idx_shift if mode == 'train' else 0) meta_idx_shift += len(dataset_dicts) dataset = MapDataset(dataset_dicts, partial(composition, mappers=[mapper, partial(model_input_mapper, mode=mode)])) if mode == 'train': datasets[mode].append(dataset) else: datasets[mode].append((dataset_name, dataset)) train_dataset = ConcatDataset(datasets['train']) logging.debug(f'Total number of training meta: {len(train_dataset)}') train_loader_splits = configs['optim']['splits'] batch_sizes = configs['optim']['batch_sizes'] splits = list(zip(train_loader_splits[:-1], train_loader_splits[1:])) assert len(splits) == (len(batch_sizes)) inf_stream_fn = partial(infinite_indices, seed=configs['stream_idx_seed'], batch_sizes=configs['optim']['batch_sizes'], splits=configs['optim']['splits'], one_batch_two_epoch=configs['optim']['one_batch_two_epoch'], dataset_length=len(train_dataset), shuffle=True) train_samplers = [] 
train_loaders = [] for btch_size, (range_start, range_end) in zip(batch_sizes, splits): if range_end is not None: assert (range_end - range_start) % btch_size == 0, '' assert btch_size % comm.get_world_size() == 0, '' each_process_batch_size = int(btch_size / comm.get_world_size()) loader_sampler = Train_InfiniteSampler_Distributed(inf_stream_fn=inf_stream_fn, start_idx=range_start, end_idx=range_end,) train_samplers.append(loader_sampler) train_loaders.append(DataLoader(train_dataset, batch_size=each_process_batch_size, sampler=loader_sampler, collate_fn=partial(model_input_collate_fn, mode='train'), num_workers=int(os.getenv('TORCH_NUM_WORKERS')), pin_memory=True, persistent_workers=True)) evaluators = [] for eval_dataset_name, eval_dataset in datasets['evaluate']: logging.debug(f'Number of evaluate meta in {eval_dataset_name}: {len(eval_dataset)}') loader = DataLoader(eval_dataset, batch_size=1, sampler=Evaluate_ExactSampler_Distributed(eval_dataset), collate_fn=partial(model_input_collate_fn, mode='evaluate'), num_workers=int(os.getenv('TORCH_NUM_WORKERS')), pin_memory=True, persistent_workers=True) evaluator = EVALUATOR_REGISTRY.get(configs['data']['evaluate'][eval_dataset_name]['evaluator']['name'])(configs=configs, dataset_name=eval_dataset_name, data_loader=loader) evaluators.append((eval_dataset_name, evaluator)) return train_samplers, train_loaders, partial(evaluate_call, evaluators=evaluators) def composition(data_dict, mappers): for mappper in mappers: data_dict = mappper(data_dict) if data_dict is None: return None return data_dict def evaluate_call(evaluators, model, output_dir): import detectron2.utils.comm as comm ret = {} for eval_dataset_name, evaluator in evaluators: metric_dict = evaluator(model=model,output_dir=output_dir) if comm.is_main_process(): for key, value in metric_dict.items(): assert f'{key}_{eval_dataset_name}' not in ret ret[f'{key}_{eval_dataset_name}'] = value comm.synchronize() return ret def _infinite_indices(seed, 
dataset_length, shuffle=True,): import torch g = torch.Generator() g.manual_seed(seed) while True: if shuffle: yield from torch.randperm(dataset_length, generator=g).tolist() else: yield from torch.arange(dataset_length).tolist() def infinite_indices(seed, dataset_length, batch_sizes, splits, one_batch_two_epoch='just_use', shuffle=True): # 'abandon', 'just_use', 'pad' import torch import math g = torch.Generator() g.manual_seed(seed) split_ranges = list(zip(splits[:-1], splits[1:])) assert len(split_ranges) == (len(batch_sizes)) stream = _infinite_indices(seed, dataset_length=dataset_length, shuffle=shuffle) stream_throw_cnt = 0 cnt = 0 for (range_start, range_end), btch_size in zip(split_ranges, batch_sizes): assert cnt == range_start if range_end == None: range_end = math.inf while cnt < range_end: epoch_milestone = ((stream_throw_cnt // dataset_length) + 1 ) * dataset_length if (stream_throw_cnt < epoch_milestone) and (stream_throw_cnt + btch_size > epoch_milestone) and (one_batch_two_epoch != 'just_use'): if one_batch_two_epoch == 'abandon': for _ in range(epoch_milestone - stream_throw_cnt): abandon = next(stream) stream_throw_cnt += 1 elif one_batch_two_epoch == 'pad': diff = stream_throw_cnt + btch_size - epoch_milestone num_throw = btch_size - diff rand_idxs = torch.randperm(dataset_length, generator=g)[:diff].tolist() for _ in range(num_throw): cnt += 1 stream_throw_cnt += 1 yield next(stream) for idx in rand_idxs: cnt += 1 yield idx else: raise ValueError() else: for _ in range(btch_size): cnt += 1 stream_throw_cnt += 1 yield next(stream) assert cnt == range_end ================================================ FILE: data_schedule/registry.py ================================================ from detectron2.utils.registry import Registry EVALUATOR_REGISTRY = Registry('EVALUATOR') MAPPER_REGISTRY = Registry('MAPPER') class Mapper: def __init__(self, meta_idx_shift, dataset_meta,) -> None: self.meta_idx_shift = meta_idx_shift self.visualized_meta_idxs = 
dataset_meta.get('visualize_meta_idxs') def _call(self, data_dict): pass def __call__(self, data_dict): meta_idx = data_dict['meta_idx'] ret = self._call(data_dict) if ret is None: return None ret['meta_idx'] = meta_idx + self.meta_idx_shift if meta_idx in self.visualized_meta_idxs: ret['visualize'] = True else: ret['visualize'] = False return ret ================================================ FILE: data_schedule/utils/box_ops.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ Utilities for bounding box manipulation and GIoU. """ import torch from torchvision.ops.boxes import box_area def box_cxcywh_to_xyxy(x): x_c, y_c, w, h = x.unbind(-1) b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] return torch.stack(b, dim=-1) def box_xyxy_to_cxcywh(x): x0, y0, x1, y1 = x.unbind(-1) assert ((x1 - x0) >= 0).all() assert ((y1 - y0) >= 0).all() b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] return torch.stack(b, dim=-1) # modified from torchvision to also return the union def box_iou(boxes1, boxes2): area1 = box_area(boxes1) area2 = box_area(boxes2) lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] wh = (rb - lt).clamp(min=0) # [N,M,2] inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] union = area1[:, None] + area2 - inter iou = inter / union return iou, union def generalized_box_iou(boxes1, boxes2): """ Generalized IoU from https://giou.stanford.edu/ The boxes should be in [x0, y0, x1, y1] format Returns a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2) """ # degenerate boxes gives inf / nan results # so do an early check assert (boxes1[:, 2:] >= boxes1[:, :2]).all() assert (boxes2[:, 2:] >= boxes2[:, :2]).all() iou, union = box_iou(boxes1, boxes2) lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) wh = (rb - lt).clamp(min=0) # [N,M,2] 
area = wh[:, :, 0] * wh[:, :, 1] return iou - (area - union) / area def masks_to_boxes(masks): """Compute the bounding boxes around the provided masks The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. Returns a [N, 4] tensors, with the boxes in xyxy format """ if masks.numel() == 0: return torch.zeros((0, 4), device=masks.device) h, w = masks.shape[-2:] y = torch.arange(0, h, dtype=torch.float) x = torch.arange(0, w, dtype=torch.float) y, x = torch.meshgrid(y, x) x_mask = (masks * x.unsqueeze(0)) x_max = x_mask.flatten(1).max(-1)[0] x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] y_mask = (masks * y.unsqueeze(0)) y_max = y_mask.flatten(1).max(-1)[0] y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] return torch.stack([x_min, y_min, x_max, y_max], 1) ================================================ FILE: data_schedule/utils/sampler.py ================================================ import math import torch.distributed as dist from typing import TypeVar, Optional, Iterator T_co = TypeVar('T_co', covariant=True) from torch.utils.data.distributed import DistributedSampler import detectron2.utils.comm as comm from utils.misc import all_gather from torch.utils.data import Sampler import torch import logging import itertools class TrainRandomSampler_ByEpoch(Sampler[int]): def __init__(self, data_source, seed, ) -> None: self.data_source = data_source self.num_samples = len(self.data_source) self.seed = seed self.epoch = None def __iter__(self): seed = self.seed + self.epoch print(f'generating a new indices permutations for this epoch using seed {seed}') n = len(self.data_source) g = torch.Generator() g.manual_seed(seed) for _ in range(self.num_samples // n): yield from torch.randperm(n, generator=g).tolist() yield from torch.randperm(n, generator=g).tolist()[:self.num_samples % n] def __len__(self) -> int: return self.num_samples def set_epoch(self, epoch: int) -> None: 
r""" Args: epoch (int): Epoch number. """ self.epoch = epoch class Train_InfiniteSampler_Distributed(Sampler[T_co]): def __init__(self, inf_stream_fn, start_idx: int = 0, end_idx = None, ): self.rank = comm.get_rank() self.num_replicas = comm.get_world_size() self.start_idx = start_idx self.end_idx = end_idx self.inf_stream_fn = inf_stream_fn def set_iter_first_sample_idx(self, idx): self.start_idx = idx def set_iter_last_sample_idx(self, idx): self.end_idx = idx def __iter__(self) -> Iterator[T_co]: logging.debug(f'在 infinite stream 上定位到{self.start_idx} 为开头') yield from itertools.islice(self.inf_stream_fn(), self.start_idx + self.rank, self.end_idx, self.num_replicas) class Evaluate_ExactSampler_Distributed(Sampler[T_co]): def __init__(self, dataset) -> None: self.dataset = dataset self.rank = comm.get_rank() self.num_replicas = comm.get_world_size() indices = list(range(len(self.dataset))) self.indices = indices[self.rank:len(self.dataset):self.num_replicas] def __iter__(self): yield from self.indices def __len__(self): return len(self.indices) class TrainRandomSampler_ByEpoch_Distributed(Sampler[T_co]): def __init__(self, dataset, num_replicas, rank, seed: int = 0) -> None: if rank >= num_replicas or rank < 0: raise ValueError("Invalid rank {}, rank should be in the interval"" [0, {}]".format(rank, num_replicas - 1)) self.dataset = dataset self.num_replicas = num_replicas self.rank = rank self.epoch = None self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) self.total_size = self.num_samples * self.num_replicas self.seed = seed def __iter__(self) -> Iterator[T_co]: seed = self.seed + self.epoch logging.debug(f'generating a new indices permutations for this epoch using seed {seed}') g = torch.Generator() g.manual_seed(self.seed + self.epoch) indices = torch.randperm(len(self.dataset), generator=g).tolist() # add extra samples to make it evenly divisible padding_size = self.total_size - len(indices) if padding_size <= len(indices): indices += 
indices[:padding_size] else: indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] assert len(indices) == self.total_size # subsample indices = indices[self.rank:self.total_size:self.num_replicas] assert len(indices) == self.num_samples self.epoch = None return iter(indices) def __len__(self) -> int: return self.num_samples def set_epoch(self, epoch: int) -> None: r""" Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering. Args: epoch (int): Epoch number. """ self.epoch = epoch class InferenceSampler(Sampler): """ Produce indices for inference across all workers. Inference needs to run on the __exact__ set of samples, therefore when the total number of samples is not divisible by the number of workers, this sampler produces different number of samples on different workers. """ def __init__(self, size: int): """ Args: size (int): the total number of data of the underlying dataset to sample from """ self._size = size assert size > 0 self._rank = comm.get_rank() self._world_size = comm.get_world_size() self._local_indices = self._get_local_indices(size, self._world_size, self._rank) @staticmethod def _get_local_indices(total_size, world_size, rank): shard_size = total_size // world_size left = total_size % world_size shard_sizes = [shard_size + int(r < left) for r in range(world_size)] begin = sum(shard_sizes[:rank]) end = min(sum(shard_sizes[: rank + 1]), total_size) return range(begin, end) def __iter__(self): yield from self._local_indices def __len__(self): return len(self._local_indices) ================================================ FILE: data_schedule/utils/segmentation.py ================================================ import torch def bounding_box_from_mask(mask): if not mask.any(): return torch.zeros([4]).float() rows = torch.any(mask, dim=1) # h cols = torch.any(mask, dim=0) # w 
row_indexs = torch.where(rows)[0] rmin, rmax = row_indexs.min(), row_indexs.max() col_indexs = torch.where(cols)[0] cmin, cmax = col_indexs.min(), col_indexs.max() return torch.tensor([cmin, rmin, cmax, rmax]).float() # x1y1x2y2 ================================================ FILE: data_schedule/vis/__init__.py ================================================ from . import polyp from . import mapper from . import evaluator_fast from . import vis_aug_eval from . import vis_aug_train from . import vis_frame_sampler ================================================ FILE: data_schedule/vis/apis.py ================================================ class VIS_Dataset: """ """ class VIS_Aug_CallbackAPI: """ """ class VIS_Evaluator_OutAPI_EvalFn_API: """ """ class VIS_TrainAPI_clipped_video: """ """ class VIS_EvalAPI_clipped_video_request_ann: """ """ class VIS_FrameSampler_InputOutput_API: """ """ class GetFrames: """ """ ================================================ FILE: data_schedule/vis/evaluator_fast.py ================================================ import os from tqdm import tqdm from functools import partial import torch import detectron2.utils.comm as comm from utils.misc import to_device from detectron2.data import MetadataCatalog from data_schedule.registry import EVALUATOR_REGISTRY from .evaluator_utils import vis_metric_entrypoint from data_schedule.vis.apis import VIS_EvalAPI_clipped_video_request_ann, VIS_Aug_CallbackAPI, VIS_Evaluator_OutAPI_EvalFn_API from collections import defaultdict @EVALUATOR_REGISTRY.register() class VIS_Evaluator_FrameFast: def __init__(self, dataset_name, data_loader, configs) -> None: self.dataset_name = dataset_name self.loader = data_loader frame_metrics = configs['data']['evaluate'][dataset_name]['evaluator']['frame_metrics'] dataset_meta = MetadataCatalog.get(dataset_name) self.frame_metric_fns = [] for metric_name, metric_config in frame_metrics: metric_fn = vis_metric_entrypoint(metric_name) metric_fn = partial(metric_fn, 
dataset_meta=dataset_meta, **metric_config) self.frame_metric_fns.append(metric_fn) self.eval_meta_keys = dataset_meta.get('eval_meta_keys') metrics_aggregator = configs['data']['evaluate'][dataset_name]['evaluator']['metrics_aggregator'] self.eval_meta_keys = dataset_meta.get('eval_meta_keys') # { video_id: list[fnames] } self.metrics_aggregator = partial(vis_metric_entrypoint(metrics_aggregator[0]), dataset_meta=dataset_meta, eval_meta_keys=self.eval_meta_keys, **metrics_aggregator[1]) def visualize_path(self, meta_idxs, visualize, evaluator_path): return [os.path.join(evaluator_path, f'meta_{meta_idx}') if vis else None for (meta_idx, vis) in zip(meta_idxs, visualize)] @torch.no_grad() def __call__(self, model, output_dir): evaluator_path = os.path.join(output_dir, f'eval_{self.dataset_name}') os.makedirs(evaluator_path, exist_ok=True) macs, params = None, None metrics_by_video_id_frame = defaultdict(dict) for batch_dict in tqdm(self.loader): VIS_EvalAPI_clipped_video_request_ann eval_metas = batch_dict.pop('metas') request_anns = eval_metas['request_ann'][0] # t, bool tensor frame_strs = eval_metas['frames'][0] # t', list[str] video_id = eval_metas['video_id'][0] # str assert request_anns.int().sum() == len(frame_strs) callback_fns = eval_metas['callback_fns'][0] # list[fn] visualize_path = self.visualize_path(meta_idxs=batch_dict['meta_idxs'], visualize=batch_dict['visualize'], evaluator_path=os.path.join(evaluator_path, 'visualize_model')) batch_dict['visualize_paths'] = visualize_path batch_dict = to_device(batch_dict, device=model.device) VIS_Aug_CallbackAPI # if macs is None: # from detectron2.utils.analysis import ( # FlopCountAnalysis, # ) # flops = FlopCountAnalysis(model, batch_dict, inference_func=lambda model, *inputs: model.sample(*inputs)) # total_flops = flops.total() # # counts = flops.by_operator() # logging.debug(f'macs: {total_flops/ (10**9) / len(request_anns)}') model_outputs = model.sample(batch_dict) predictions = { 'video': 
model_outputs['video'][0], # t 3 h w 'pred_masks': [haosen for idx, haosen in enumerate(model_outputs['pred_masks'][0]) if request_anns[idx]], # list[nt h w], t' 'pred_class': [haosen for idx, haosen in enumerate(model_outputs['pred_class'][0]) if request_anns[idx]], # list[nt c], t', } if 'pred_boxes' in model_outputs: predictions.update({'pred_boxes': [haosen for idx, haosen in enumerate(model_outputs['pred_boxes'][0]) if request_anns[idx]]}) # # list[nt 4], t, for cardib in callback_fns: predictions = cardib(predictions) pred_masks = predictions['pred_masks'] pred_class = predictions['pred_class'] assert len(frame_strs) == len(pred_masks) for idx, (fname, fmk, fclass) in enumerate(zip(frame_strs, pred_masks, pred_class)): VIS_Evaluator_OutAPI_EvalFn_API frame_pred = {'masks': fmk, 'classes': fclass.tolist(), 'video_id': video_id, 'frame_name': fname} if 'pred_boxes' in predictions: frame_pred.update({'boxes': predictions['pred_boxes'][idx]}) meta_key_metrics = {} for metric_fn in self.frame_metric_fns: metric_values = metric_fn(frame_pred=frame_pred, output_dir=evaluator_path) for key, value in metric_values.items(): assert key not in meta_key_metrics meta_key_metrics[key] = value assert fname not in metrics_by_video_id_frame[video_id] metrics_by_video_id_frame[video_id][fname] = meta_key_metrics metrics_by_video_id_frame = comm.gather(dict(metrics_by_video_id_frame), dst=0) eval_metrics = {} if comm.is_main_process(): metrics_by_video = {} for video_id in tqdm(self.eval_meta_keys.keys(), desc='gathering different processes'): video_id_metrics = [haosen[video_id] for haosen in metrics_by_video_id_frame if video_id in haosen] video_id_frame_names = [list(haosen.keys()) for haosen in video_id_metrics] merged_video_id_frame_names = [item for sublist in video_id_frame_names for item in sublist] assert len(set(merged_video_id_frame_names)) == len(merged_video_id_frame_names),'' assert set(merged_video_id_frame_names).issubset(set(self.eval_meta_keys[video_id])) 
assert set(self.eval_meta_keys[video_id]).issubset(set(merged_video_id_frame_names)) # perframe metrics frame: predictions vid_frame_metrics = video_id_metrics[0] for haosen in video_id_metrics[1:]: vid_frame_metrics.update(haosen) metrics_by_video[video_id] = vid_frame_metrics eval_metrics = self.metrics_aggregator(metrics_by_video) comm.synchronize() return eval_metrics ================================================ FILE: data_schedule/vis/evaluator_utils.py ================================================ _vis_metric_entrypoints = {} def register_vis_metric(fn): vis_metric_name = fn.__name__ if vis_metric_name in _vis_metric_entrypoints: raise ValueError(f'vis_metric name {vis_metric_name} has been registered') _vis_metric_entrypoints[vis_metric_name] = fn return fn def vis_metric_entrypoint(vis_metric_name): try: return _vis_metric_entrypoints[vis_metric_name] except KeyError as e: print(f'vis_metric Name {vis_metric_name} not found') import numpy as np _EPS = np.spacing(1) _TYPE = np.float64 def _prepare_data(pred: np.ndarray, gt: np.ndarray) -> tuple: gt = gt > 128 pred = pred / 255 if pred.max() != pred.min(): pred = (pred - pred.min()) / (pred.max() - pred.min()) return pred, gt class Smeasure(object): def __init__(self, length, alpha: float = 0.5): self.sms = [] self.alpha = alpha def step(self, pred: np.ndarray, gt: np.ndarray, idx): pred, gt = _prepare_data(pred=pred, gt=gt) sm = self.cal_sm(pred, gt) self.sms.append(sm) def cal_sm(self, pred: np.ndarray, gt: np.ndarray) -> float: y = np.mean(gt) if y == 0: sm = 1 - np.mean(pred) elif y == 1: sm = np.mean(pred) else: sm = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) sm = max(0, sm) return sm def object(self, pred: np.ndarray, gt: np.ndarray) -> float: fg = pred * gt bg = (1 - pred) * (1 - gt) u = np.mean(gt) object_score = u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, 1 - gt) return object_score def s_object(self, pred: np.ndarray, gt: np.ndarray) -> float: x 
= np.mean(pred[gt == 1]) sigma_x = np.std(pred[gt == 1], ddof=1) score = 2 * x / (np.power(x, 2) + 1 + sigma_x + _EPS) return score def region(self, pred: np.ndarray, gt: np.ndarray) -> float: x, y = self.centroid(gt) part_info = self.divide_with_xy(pred, gt, x, y) w1, w2, w3, w4 = part_info['weight'] pred1, pred2, pred3, pred4 = part_info['pred'] gt1, gt2, gt3, gt4 = part_info['gt'] score1 = self.ssim(pred1, gt1) score2 = self.ssim(pred2, gt2) score3 = self.ssim(pred3, gt3) score4 = self.ssim(pred4, gt4) return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 def centroid(self, matrix: np.ndarray) -> tuple: """ To ensure consistency with the matlab code, one is added to the centroid coordinate, so there is no need to use the redundant addition operation when dividing the region later, because the sequence generated by ``1:X`` in matlab will contain ``X``. :param matrix: a bool data array :return: the centroid coordinate """ h, w = matrix.shape area_object = np.count_nonzero(matrix) if area_object == 0: x = np.round(w / 2) y = np.round(h / 2) else: # More details can be found at: https://www.yuque.com/lart/blog/gpbigm y, x = np.argwhere(matrix).mean(axis=0).round() return int(x) + 1, int(y) + 1 def divide_with_xy(self, pred: np.ndarray, gt: np.ndarray, x, y) -> dict: h, w = gt.shape area = h * w gt_LT = gt[0:y, 0:x] gt_RT = gt[0:y, x:w] gt_LB = gt[y:h, 0:x] gt_RB = gt[y:h, x:w] pred_LT = pred[0:y, 0:x] pred_RT = pred[0:y, x:w] pred_LB = pred[y:h, 0:x] pred_RB = pred[y:h, x:w] w1 = x * y / area w2 = y * (w - x) / area w3 = (h - y) * x / area w4 = 1 - w1 - w2 - w3 return dict(gt=(gt_LT, gt_RT, gt_LB, gt_RB), pred=(pred_LT, pred_RT, pred_LB, pred_RB), weight=(w1, w2, w3, w4)) def ssim(self, pred: np.ndarray, gt: np.ndarray) -> float: h, w = pred.shape N = h * w x = np.mean(pred) y = np.mean(gt) sigma_x = np.sum((pred - x) ** 2) / (N - 1) sigma_y = np.sum((gt - y) ** 2) / (N - 1) sigma_xy = np.sum((pred - x) * (gt - y)) / (N - 1) alpha = 4 * x * y * sigma_xy beta 
= (x ** 2 + y ** 2) * (sigma_x + sigma_y) if alpha != 0: score = alpha / (beta + _EPS) elif alpha == 0 and beta == 0: score = 1 else: score = 0 return score def get_results(self): sm = np.mean(np.array(self.sms, dtype=_TYPE)) return dict(Smeasure=sm) import torch import os from PIL import Image @register_vis_metric def mask_dice_iou(frame_pred, dataset_meta, **kwargs): video_id = frame_pred['video_id'] frame_name = frame_pred['frame_name'] masks = frame_pred['masks'] # nq h w get_frames_gt_mask_fn = dataset_meta.get('get_frames_gt_mask_fn') scores = torch.tensor(frame_pred['classes']) # nq c foreground_scores = scores[:, :-1].sum(-1) # nq max_idx = foreground_scores.argmax() pred_mask = masks[max_idx].int() # h w gt_mask, _ = get_frames_gt_mask_fn(video_id=video_id, frames=[frame_name]) # 1 h w gt_mask = gt_mask[0].int() # h w inter, union = (pred_mask*gt_mask).sum(), (pred_mask+gt_mask).sum() dice = (2*inter+1)/(union+1) iou = (inter+1)/(union-inter+1) return {'dice': dice, 'iou': iou} @register_vis_metric def mask_dice_iou_sen_mae_smeasure(frame_pred, dataset_meta, **kwargs): video_id = frame_pred['video_id'] frame_name = frame_pred['frame_name'] masks = frame_pred['masks'] # nq h w get_frames_gt_mask_fn = dataset_meta.get('get_frames_gt_mask_fn') scores = torch.tensor(frame_pred['classes']) # nq c foreground_scores = scores[:, :-1].sum(-1) # nq max_idx = foreground_scores.argmax() pred_mask = masks[max_idx].int() # h w gt_mask, _ = get_frames_gt_mask_fn(video_id=video_id, frames=[frame_name]) # 1 h w gt_mask = gt_mask[0].int() # h w # tp, tp*2 + fp + fn inter, union = (pred_mask*gt_mask).sum(), (pred_mask+gt_mask).sum() dice = (2*inter+1)/(union+1) # 2*tp / tp + tp + fp + fn iou = (inter+1)/(union-inter+1) # tp / tp + fp + fn tp = (pred_mask * gt_mask).sum().float() fp = (pred_mask.sum() - tp).float() fn = (gt_mask.sum() - tp).float() tn = (pred_mask.shape[0] * pred_mask.shape[1] - (tp + fp + fn)).float() their_dice = tp * 2 / (tp + fp + fn + tp) their_iou = tp 
/ (tp + fp + fn) # their_spe = tn / (tn + fp) their_sen = tp / (tp + fn) their_mae = (pred_mask.float() - gt_mask.float()).abs().mean() Np = gt_mask.sum() Nn = gt_mask.shape[0] * gt_mask.shape[1] - Np null = Smeasure(length=1, alpha=0.5) null.step(pred=(pred_mask.float() * 255 ).numpy(), gt=(gt_mask.float() * 255).numpy(), idx=None) their_smeasure = torch.tensor(null.get_results()['Smeasure']).float() return {'dice': dice, 'iou': iou, 'their_dice': their_dice, 'their_iou': their_iou, 'their_sen': their_sen, 'their_mae_abs': their_mae, 'their_smeasure': their_smeasure, 'tp': tp, # true positive 'fp': fp, # false positive 'fn': fn, # false negative 'tn': tn, # true negative 'Np': Np, # positive accumulation 'Nn': Nn} # negative accumulation @register_vis_metric def web(frame_pred, output_dir, **kwargs): os.makedirs(os.path.join(output_dir, 'web'), exist_ok=True) video_id = frame_pred['video_id'] frame_name = frame_pred['frame_name'] masks = frame_pred['masks'] # nq h w scores = torch.tensor(frame_pred['classes']) # nq c foreground_scores = scores[:, :-1].sum(-1) # nq max_idx = foreground_scores.argmax() pred_mask = masks[max_idx].int() # h w mask = Image.fromarray(255 * pred_mask.int().numpy()).convert('L') save_path = os.path.join(output_dir, 'web', video_id) os.makedirs(save_path, exist_ok=True) png_path = os.path.join(save_path, f'{frame_name}.png') if os.path.exists(png_path): os.remove(png_path) mask.save(png_path) return {} ================================================ FILE: data_schedule/vis/fibroid/__init__.py ================================================ # 注册fibrois数据集 from . import fibroid_dataset # 注册fibroid评估标准 from . 
import evals ================================================ FILE: data_schedule/vis/fibroid/evals.py ================================================ from data_schedule.vis.evaluator_utils import register_vis_metric import os from glob import glob from tqdm import tqdm import shutil from functools import partial from PIL import Image import numpy as np import torch import detectron2.utils.comm as comm import logging import pycocotools.mask as mask_util from pycocotools.mask import decode as decode_rle import data_schedule.vis.fibroid.metrics as metrics @register_vis_metric def fibroid_other_medi(model_preds, dataset_meta, **kwargs): assert comm.is_main_process() iou_by_test_sample = [] dice_by_test_sample = [] preds_by_test_sample = [] gt_by_test_sample = [] get_frames_gt_mask_fn = dataset_meta.get('get_frames_gt_mask_fn') for pred in model_preds: video_id = pred['video_id'] # str frame_name = pred['frame_name'] # list[str], t' masks = pred['masks']# list[rle], nq scores = pred['scores'] # nq max_idx = torch.tensor(scores).argmax() pred_mask = masks[max_idx] # rle pred_mask = decode_rle(pred_mask) pred_mask = torch.as_tensor(pred_mask, dtype=torch.uint8).contiguous() # h w gt_mask, _ = get_frames_gt_mask_fn(video_id=video_id, frames=[frame_name]) # 1 h w gt_mask = gt_mask[0].int() # 0/1 preds_by_test_sample.append(pred_mask) gt_by_test_sample.append(gt_mask) tp, fp, fn, tn = metrics.get_stats(pred_mask[None, None, ...], gt_mask[None, None, ...], mode='binary') iou_score = metrics.iou_score(tp, fp, fn, tn, reduction='micro') dice = metrics.dice(tp, fp, fn, tn, reduction='micro') iou_by_test_sample.append(iou_score) dice_by_test_sample.append(dice) mean_iou = torch.tensor(iou_by_test_sample).mean() mean_dice = torch.tensor(dice_by_test_sample).mean() preds_by_test_sample = torch.stack(preds_by_test_sample, dim=0).unsqueeze(1) # N 1 h w gt_by_test_sample = torch.stack(gt_by_test_sample, dim=0).unsqueeze(1) # N 1 h w tp, fp, fn, tn = 
metrics.get_stats(preds_by_test_sample, gt_by_test_sample, mode='binary') overall_iou = metrics.iou_score(tp, fp, fn, tn, reduction='micro') recall = metrics.recall(tp, fp, fn, tn, reduction='micro-imagewise') precision = metrics.precision(tp, fp, fn, tn, reduction='micro-imagewise') all_medi = { 'mean_iou': mean_iou, 'dice': mean_dice, 'overall_iou': overall_iou, # J/overallIoU 'recall': recall, 'precision': precision, 'F': 2 * precision * recall / (precision + recall) } return all_medi from collections import defaultdict # by_vid, by_frame iou_dict = defaultdict(dict) @register_vis_metric def fibroid_mask_dice_iou(frame_pred, dataset_meta, **kwargs): video_id = frame_pred['video_id'] frame_name = frame_pred['frame_name'] masks = frame_pred['masks'] # nq h w get_frames_gt_mask_fn = dataset_meta.get('get_frames_gt_mask_fn') scores = torch.tensor(frame_pred['classes']) # nq c, 保证c是2 foreground_scores = scores[:, :-1].sum(-1) # nq max_idx = foreground_scores.argmax() pred_mask = masks[max_idx].int() # h w gt_mask, _ = get_frames_gt_mask_fn(video_id=video_id, frames=[frame_name]) # 1 h w gt_mask = gt_mask[0].int() # h w inter, union = (pred_mask*gt_mask).sum(), (pred_mask+gt_mask).sum() dice = (2*inter+1)/(union+1) iou = (inter+1)/(union-inter+1) iou_dict[video_id][frame_name] = iou if iou > 0.6: print(f'video_id: {video_id}, frame: {frame_name}: dice {dice}, iou {iou}') return {'dice': dice, 'iou': iou} @register_vis_metric def fibroid_metric_aggregator(metrics_by_vid_frame, dataset_meta, eval_meta_keys, **kwargs): # output: eval_metrics # video: frame_name: metric/ vid_metrics eval_metrics = {} # video, frame_name # perframe metrics metric_names = metrics_by_vid_frame[list(eval_meta_keys.keys())[0]][eval_meta_keys[list(eval_meta_keys.keys())[0]][0]] for taylor_swift in metric_names: eval_metrics[taylor_swift] = torch.tensor([metrics_by_vid_frame[video][frame][taylor_swift] for video in eval_meta_keys.keys() for frame in eval_meta_keys[video]]).mean() # metrics by 
each video mean_iou_by_each_video = {} mean_dice_by_each_video = {} for video in eval_meta_keys: mean_iou_by_each_video[video] = torch.tensor([metrics_by_vid_frame[video][fname]['iou'] for fname in eval_meta_keys[video]]).mean() mean_dice_by_each_video[video] = torch.tensor([metrics_by_vid_frame[video][fname]['dice'] for fname in eval_meta_keys[video]]).mean() mean_iou_by_each_video = dict(sorted(mean_iou_by_each_video.items(), key=lambda x: x[1])) mean_dice_by_each_video = dict(sorted(mean_dice_by_each_video.items(), key=lambda x: x[1])) logging.debug(f'mean_iou_by_each_video: {mean_iou_by_each_video}') logging.debug(f'mean_dice_by_each_video: {mean_dice_by_each_video}') return eval_metrics ================================================ FILE: data_schedule/vis/fibroid/fibroid_dataset.py ================================================ from typing import Optional, Union import json import os from functools import partial import numpy as np import torch import logging from tqdm import tqdm import copy from detectron2.data import DatasetCatalog, MetadataCatalog from collections import defaultdict from data_schedule.vis.apis import VIS_Dataset from .fibroid_utils import get_frames, get_frames_mask, SET_NAME_TO_DIR,\ SET_NAME, SET_NAME_TO_NUM_VIDEOS, SET_NAME_TO_MODE, SET_NAME_TO_PREFIX, SET_NAME_TO_GT_TYPE def fibroid_train(step_size, # none / int; 0, 6, 13, 19 ... 
split_dataset_name, video_ids, video_to_frames): logging.debug(f'{split_dataset_name} Generating metas...') metas = [] for vid_id in tqdm(video_ids): all_frames = sorted(video_to_frames[vid_id]) if step_size is None: metas.append({ 'video_id': vid_id, 'all_frames' : all_frames, 'meta_idx': len(metas), 'all_objs': {1: {'class_label': 0,}} # 语义分割 }) else: for frame_idx in range(0, len(all_frames), step_size): metas.append({ 'video_id': vid_id, 'frame_idx': frame_idx, 'all_frames': all_frames, 'all_objs': {1: {'class_label': 0,}}, 'meta_idx': len(metas) }) logging.debug(f'{split_dataset_name} Total metas: [{len(metas)}]') return metas def fibroid_evaluate(eval_video_ids, split_dataset_name, step_size, video_to_frames): if (step_size is not None) and (step_size > 1): logging.warning('为什么 evaluate的时候step size大于1呢') raise ValueError() metas = [] for video_id in eval_video_ids: VIS_Dataset all_frames = sorted(video_to_frames[video_id]) if step_size == None: metas.append({ 'video_id': video_id, 'all_frames': all_frames, 'meta_idx': len(metas) }) else: for frame_idx in range(0, len(all_frames), step_size): metas.append({ 'video_id': video_id, 'frame_idx': frame_idx, 'all_frames': all_frames, 'meta_idx': len(metas) }) logging.debug(f'{split_dataset_name} Total metas: [{len(metas)}]') return metas _root = os.getenv('DATASET_PATH') root = os.path.join(_root, 'uterus_myoma/Dataset') visualize_meta_idxs = defaultdict(list) visualize_meta_idxs['fibroid_train_step[6]'] = [] visualize_meta_idxs['fibroid_train'] = [] visualize_meta_idxs['fibroid_train_ste[1]'] = [] visualize_meta_idxs['fibroid_validate'] = [] visualize_meta_idxs['fibroid_validate_step[1]'] = [] visualize_meta_idxs['weakPolyP_fibroid_validate_step[1]'] = [] fibroid_meta = { 'thing_classes': ['rumor', 'not rumor'], 'thing_colors': [(255., 140., 0.), (0., 255., 0.)], } for name in SET_NAME: set_dir = SET_NAME_TO_DIR[name] set_dir = os.path.join(root, set_dir) num_videos = SET_NAME_TO_NUM_VIDEOS[name] video_ids = 
os.listdir(os.path.join(set_dir, 'Frame')) assert len(video_ids) == num_videos video_to_frames = { vid: sorted([png[:-4] for png in os.listdir(os.path.join(set_dir, 'Frame', vid)) if png.endswith('.png')])\ for vid in video_ids } mode = SET_NAME_TO_MODE[name] prefix = SET_NAME_TO_PREFIX[name] if mode == 'train': train_meta = copy.deepcopy(fibroid_meta) gt_type = SET_NAME_TO_GT_TYPE[name] train_meta.update({ 'mode': 'train', 'get_frames_fn': partial(get_frames, frames_path=os.path.join(set_dir, 'Frame')), 'get_frames_mask_fn': partial(get_frames_mask, mask_path=os.path.join(set_dir, gt_type),), 'get_frames_gt_mask_fn': partial(get_frames_mask, mask_path=os.path.join(root, os.path.join(set_dir, 'GT')),), }) # train for step_size in [1, 6, None]: step_identifer = '' if step_size is None else f'_step[{step_size}]' split_name = f'{prefix}{step_identifer}' train_meta.update({'name': split_name}) DatasetCatalog.register(split_name, partial(fibroid_train, video_ids=video_ids, split_dataset_name=split_name, step_size=step_size, video_to_frames=video_to_frames,)) MetadataCatalog.get(split_name).set(**train_meta, step_size=step_size, visualize_meta_idxs=visualize_meta_idxs[split_name]) elif mode == 'evaluate': prefix = SET_NAME_TO_PREFIX[name] validate_meta = copy.deepcopy(fibroid_meta) validate_meta.update({ 'mode': 'evaluate', 'get_frames_fn': partial(get_frames, frames_path=os.path.join(root, os.path.join(set_dir, 'Frame'))), 'eval_set_name': SET_NAME_TO_DIR[name], 'get_frames_gt_mask_fn': partial(get_frames_mask, mask_path=os.path.join(root, os.path.join(set_dir, 'GT')),), 'eval_meta_keys': video_to_frames }) # validate for step_size in [1, None,]: step_identifer = '' if step_size is None else f'_step[{step_size}]' split_name = f'{prefix}{step_identifer}' validate_meta.update({'name': split_name}) DatasetCatalog.register(split_name, partial(fibroid_evaluate, eval_video_ids=video_ids, split_dataset_name=split_name, step_size=step_size, video_to_frames=video_to_frames)) 
MetadataCatalog.get(split_name).set(**validate_meta, step_size=step_size, visualize_meta_idxs=visualize_meta_idxs[split_name]) ================================================ FILE: data_schedule/vis/fibroid/fibroid_utils.py ================================================ import wandb import plotly.express as px import logging import os import numpy as np import torch import json from joblib import Parallel, delayed import multiprocessing import torch.distributed as dist import detectron2.utils.comm as comm import pycocotools.mask as mask_util from pycocotools.mask import encode, area from data_schedule.utils.segmentation import bounding_box_from_mask from data_schedule.utils.video_clips import generate_windows_of_video from glob import glob from PIL import Image def get_frames(frames_path, video_id, frames): return [Image.open(os.path.join(frames_path, video_id, f'{f}.png'),).convert('RGB') for f in frames] # t' h w, int, obj_ids ; has_ann t def get_frames_mask(mask_path, video_id, frames): masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.png')).convert('L') for f in frames] masks = [np.array(mk) for mk in masks] masks = torch.stack([torch.from_numpy(mk) for mk in masks], dim=0) # t h w masks = (masks > 0).int() return masks, torch.ones(len(frames)).bool() SET_NAME = [ 'fibroid_train', 'fibroid_validate', 'weakpolyp_train', 'fibroid_validate_temp7', 'fibroid_train_temp7', # 'weakpolyp_fibroid_train_temp7', 'fibroid_validate_temp8', 'fibroid_train_temp8', 'weakpolyp_fibroid_train_temp8' ] SET_NAME_TO_DIR = { 'fibroid_train': 'temp/train', 'fibroid_validate': 'temp/test', 'weakpolyp_train': 'temp/uterus_myoma_WeakPolyP_temp/train', 'fibroid_validate_temp7': 'temp7/test', 'fibroid_train_temp7': 'temp7/train', 'weakpolyp_fibroid_train_temp7': 'temp7/uterus_myoma_WeakPolyP_temp7/train', 'fibroid_validate_temp8': 'temp8/test', 'fibroid_train_temp8': 'temp8/train', 'weakpolyp_fibroid_train_temp8': 'temp8/uterus_myoma_WeakPolyP_temp8/train', } 
SET_NAME_TO_NUM_VIDEOS = { 'fibroid_train': 80, 'fibroid_validate': 20, 'weakpolyp_train': 80, 'fibroid_train_temp7': 85, 'fibroid_validate_temp7': 15, 'weakpolyp_fibroid_train_temp7': 85 , 'fibroid_train_temp8': 83, 'fibroid_validate_temp8': 17, 'weakpolyp_fibroid_train_temp8': 83 } SET_NAME_TO_MODE = { 'fibroid_train': 'train', 'fibroid_validate': 'evaluate', 'weakpolyp_train': 'train', 'fibroid_train_temp7': 'train', 'fibroid_validate_temp7': 'evaluate', 'weakpolyp_fibroid_train_temp7': 'train', 'fibroid_train_temp8': 'train', 'fibroid_validate_temp8': 'evaluate', 'weakpolyp_fibroid_train_temp8': 'train' } SET_NAME_TO_PREFIX = { 'fibroid_train': 'fibroid_train', 'fibroid_validate': 'fibroid_validate', 'weakpolyp_train': 'weakpolyp_fibroid_train', 'fibroid_train_temp7': 'fibroid_train_temp7', 'fibroid_validate_temp7': 'fibroid_validate_temp7', 'weakpolyp_fibroid_train_temp7': 'weakpolyp_fibroid_train_temp7' , 'fibroid_train_temp8': 'fibroid_train_temp8', 'fibroid_validate_temp8': 'fibroid_validate_temp8', 'weakpolyp_fibroid_train_temp8': 'weakpolyp_fibroid_train_temp8' } SET_NAME_TO_GT_TYPE = { 'fibroid_train': 'GT', 'fibroid_validate': 'GT', 'weakpolyp_train': 'Box', 'fibroid_train_temp7': 'GT', 'fibroid_validate_temp7': 'GT', 'weakpolyp_fibroid_train_temp7': 'Box', 'fibroid_train_temp8': 'GT', 'fibroid_validate_temp8': 'GT', 'weakpolyp_fibroid_train_temp8': 'Box' } ================================================ FILE: data_schedule/vis/fibroid/metrics.py ================================================ import warnings from typing import Optional, List, Tuple, Union import torch """Various metrics based on Type I and Type II errors. References: https://en.wikipedia.org/wiki/Confusion_matrix Example: .. 
code-block:: python import segmentation_models_pytorch as smp # lets assume we have multilabel prediction for 3 classes output = torch.rand([10, 3, 256, 256]) target = torch.rand([10, 3, 256, 256]).round().long() # first compute statistics for true positives, false positives, false negative and # true negative "pixels" tp, fp, fn, tn = smp.metrics.get_stats(output, target, mode='multilabel', threshold=0.5) # then compute metrics with required reduction (see metric docs) iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro") f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro") f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro") accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro") recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise") """ __all__ = [ "get_stats", "fbeta_score", "f1_score", "iou_score", "accuracy", "precision", "recall", "sensitivity", "specificity", "balanced_accuracy", "positive_predictive_value", "negative_predictive_value", "false_negative_rate", "false_positive_rate", "false_discovery_rate", "false_omission_rate", "positive_likelihood_ratio", "negative_likelihood_ratio", ] ################################################################################################### # Statistics computation (true positives, false positives, false negatives, false positives) ################################################################################################### def get_stats( output: Union[torch.LongTensor, torch.FloatTensor], target: torch.LongTensor, mode: str, ignore_index: Optional[int] = None, threshold: Optional[Union[float, List[float]]] = None, num_classes: Optional[int] = None, ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]: """Compute true positive, false positive, false negative, true negative 'pixels' for each image and each class. 
Args: output (Union[torch.LongTensor, torch.FloatTensor]): Model output with following shapes and types depending on the specified ``mode``: 'binary' shape (N, 1, ...) and ``torch.LongTensor`` or ``torch.FloatTensor`` 'multilabel' shape (N, C, ...) and ``torch.LongTensor`` or ``torch.FloatTensor`` 'multiclass' shape (N, ...) and ``torch.LongTensor`` target (torch.LongTensor): Targets with following shapes depending on the specified ``mode``: 'binary' shape (N, 1, ...) 'multilabel' shape (N, C, ...) 'multiclass' shape (N, ...) mode (str): One of ``'binary'`` | ``'multilabel'`` | ``'multiclass'`` ignore_index (Optional[int]): Label to ignore on for metric computation. **Not** supproted for ``'binary'`` and ``'multilabel'`` modes. Defaults to None. threshold (Optional[float, List[float]]): Binarization threshold for ``output`` in case of ``'binary'`` or ``'multilabel'`` modes. Defaults to None. num_classes (Optional[int]): Number of classes, necessary attribute only for ``'multiclass'`` mode. Class values should be in range 0..(num_classes - 1). If ``ignore_index`` is specified it should be outside the classes range, e.g. ``-1`` or ``255``. Raises: ValueError: in case of misconfiguration. Returns: Tuple[torch.LongTensor]: true_positive, false_positive, false_negative, true_negative tensors (N, C) shape each. """ if torch.is_floating_point(target): raise ValueError(f"Target should be one of the integer types, got {target.dtype}.") if torch.is_floating_point(output) and threshold is None: raise ValueError( f"Output should be one of the integer types if ``threshold`` is not None, got {output.dtype}." 
) if torch.is_floating_point(output) and mode == "multiclass": raise ValueError(f"For ``multiclass`` mode ``output`` should be one of the integer types, got {output.dtype}.") if mode not in {"binary", "multiclass", "multilabel"}: raise ValueError(f"``mode`` should be in ['binary', 'multiclass', 'multilabel'], got mode={mode}.") if mode == "multiclass" and threshold is not None: raise ValueError("``threshold`` parameter does not supported for this 'multiclass' mode") if output.shape != target.shape: raise ValueError( "Dimensions should match, but ``output`` shape is not equal to ``target`` " + f"shape, {output.shape} != {target.shape}" ) if mode != "multiclass" and ignore_index is not None: raise ValueError(f"``ignore_index`` parameter is not supproted for '{mode}' mode") if mode == "multiclass" and num_classes is None: raise ValueError("``num_classes`` attribute should be not ``None`` for 'multiclass' mode.") if ignore_index is not None and 0 <= ignore_index <= num_classes - 1: raise ValueError( f"``ignore_index`` should be outside the class values range, but got class values in range " f"0..{num_classes - 1} and ``ignore_index={ignore_index}``. Hint: if you have ``ignore_index = 0``" f"consirder subtracting ``1`` from your target and model output to make ``ignore_index = -1``" f"and relevant class values started from ``0``." 
) if mode == "multiclass": tp, fp, fn, tn = _get_stats_multiclass(output, target, num_classes, ignore_index) else: if threshold is not None: output = torch.where(output >= threshold, 1, 0) target = torch.where(target >= threshold, 1, 0) tp, fp, fn, tn = _get_stats_multilabel(output, target) return tp, fp, fn, tn @torch.no_grad() def _get_stats_multiclass( output: torch.LongTensor, target: torch.LongTensor, num_classes: int, ignore_index: Optional[int], ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]: batch_size, *dims = output.shape num_elements = torch.prod(torch.tensor(dims)).long() if ignore_index is not None: ignore = target == ignore_index output = torch.where(ignore, -1, output) target = torch.where(ignore, -1, target) ignore_per_sample = ignore.view(batch_size, -1).sum(1) tp_count = torch.zeros(batch_size, num_classes, dtype=torch.long) fp_count = torch.zeros(batch_size, num_classes, dtype=torch.long) fn_count = torch.zeros(batch_size, num_classes, dtype=torch.long) tn_count = torch.zeros(batch_size, num_classes, dtype=torch.long) for i in range(batch_size): target_i = target[i] output_i = output[i] mask = output_i == target_i matched = torch.where(mask, target_i, -1) tp = torch.histc(matched.float(), bins=num_classes, min=0, max=num_classes - 1) fp = torch.histc(output_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp fn = torch.histc(target_i.float(), bins=num_classes, min=0, max=num_classes - 1) - tp tn = num_elements - tp - fp - fn if ignore_index is not None: tn = tn - ignore_per_sample[i] tp_count[i] = tp.long() fp_count[i] = fp.long() fn_count[i] = fn.long() tn_count[i] = tn.long() return tp_count, fp_count, fn_count, tn_count @torch.no_grad() def _get_stats_multilabel( output: torch.LongTensor, target: torch.LongTensor, ) -> Tuple[torch.LongTensor, torch.LongTensor, torch.LongTensor, torch.LongTensor]: batch_size, num_classes, *dims = target.shape output = output.view(batch_size, num_classes, -1) target 
= target.view(batch_size, num_classes, -1) tp = (output * target).sum(2) fp = output.sum(2) - tp fn = target.sum(2) - tp tn = torch.prod(torch.tensor(dims)) - (tp + fp + fn) return tp, fp, fn, tn ################################################################################################### # Metrics computation ################################################################################################### def _handle_zero_division(x, zero_division): nans = torch.isnan(x) if torch.any(nans) and zero_division == "warn": warnings.warn("Zero division in metric calculation!") value = zero_division if zero_division != "warn" else 0 value = torch.tensor(value, dtype=x.dtype).to(x.device) x = torch.where(nans, value, x) return x def _compute_metric( metric_fn, tp, fp, fn, tn, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division="warn", **metric_kwargs, ) -> float: if class_weights is None and reduction is not None and "weighted" in reduction: raise ValueError(f"Class weights should be provided for `{reduction}` reduction") class_weights = class_weights if class_weights is not None else 1.0 class_weights = torch.tensor(class_weights).to(tp.device) class_weights = class_weights / class_weights.sum() if reduction == "micro": tp = tp.sum() fp = fp.sum() fn = fn.sum() tn = tn.sum() score = metric_fn(tp, fp, fn, tn, **metric_kwargs) elif reduction == "macro": tp = tp.sum(0) fp = fp.sum(0) fn = fn.sum(0) tn = tn.sum(0) score = metric_fn(tp, fp, fn, tn, **metric_kwargs) score = _handle_zero_division(score, zero_division) score = (score * class_weights).mean() elif reduction == "weighted": tp = tp.sum(0) fp = fp.sum(0) fn = fn.sum(0) tn = tn.sum(0) score = metric_fn(tp, fp, fn, tn, **metric_kwargs) score = _handle_zero_division(score, zero_division) score = (score * class_weights).sum() elif reduction == "micro-imagewise": tp = tp.sum(1) fp = fp.sum(1) fn = fn.sum(1) tn = tn.sum(1) score = metric_fn(tp, fp, fn, tn, **metric_kwargs) 
score = _handle_zero_division(score, zero_division) score = score.mean() elif reduction == "macro-imagewise" or reduction == "weighted-imagewise": score = metric_fn(tp, fp, fn, tn, **metric_kwargs) score = _handle_zero_division(score, zero_division) score = (score.mean(0) * class_weights).mean() elif reduction == "none" or reduction is None: score = metric_fn(tp, fp, fn, tn, **metric_kwargs) score = _handle_zero_division(score, zero_division) else: raise ValueError( "`reduction` should be in [micro, macro, weighted, micro-imagewise," + "macro-imagesize, weighted-imagewise, none, None]" ) return score # Logic for metric computation, all metrics are with the same interface def _fbeta_score(tp, fp, fn, tn, beta=1): beta_tp = (1 + beta**2) * tp beta_fn = (beta**2) * fn score = beta_tp / (beta_tp + beta_fn + fp) return score def _iou_score(tp, fp, fn, tn): return tp / (tp + fp + fn) def _accuracy(tp, fp, fn, tn): return (tp + tn) / (tp + fp + fn + tn) def _sensitivity(tp, fp, fn, tn): return tp / (tp + fn) def _specificity(tp, fp, fn, tn): return tn / (tn + fp) def _balanced_accuracy(tp, fp, fn, tn): return (_sensitivity(tp, fp, fn, tn) + _specificity(tp, fp, fn, tn)) / 2 def _dice(tp, fp, fn, tn): return tp * 2 / (tp + fp + fn + tp) def _positive_predictive_value(tp, fp, fn, tn): return tp / (tp + fp) def _negative_predictive_value(tp, fp, fn, tn): return tn / (tn + fn) def _false_negative_rate(tp, fp, fn, tn): return fn / (fn + tp) def _false_positive_rate(tp, fp, fn, tn): return fp / (fp + tn) def _false_discovery_rate(tp, fp, fn, tn): return 1 - _positive_predictive_value(tp, fp, fn, tn) def _false_omission_rate(tp, fp, fn, tn): return 1 - _negative_predictive_value(tp, fp, fn, tn) def _positive_likelihood_ratio(tp, fp, fn, tn): return _sensitivity(tp, fp, fn, tn) / _false_positive_rate(tp, fp, fn, tn) def _negative_likelihood_ratio(tp, fp, fn, tn): return _false_negative_rate(tp, fp, fn, tn) / _specificity(tp, fp, fn, tn) def fbeta_score( tp: torch.LongTensor, fp: 
torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, beta: float = 1.0, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """F beta score""" return _compute_metric( _fbeta_score, tp, fp, fn, tn, beta=beta, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def f1_score( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """F1 score""" return _compute_metric( _fbeta_score, tp, fp, fn, tn, beta=1.0, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def dice( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """IoU score or Jaccard index""" # noqa return _compute_metric( _dice, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def iou_score( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """IoU score or Jaccard index""" # noqa return _compute_metric( _iou_score, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def accuracy( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """Accuracy""" return _compute_metric( _accuracy, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def 
sensitivity( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """Sensitivity, recall, hit rate, or true positive rate (TPR)""" return _compute_metric( _sensitivity, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def specificity( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """Specificity, selectivity or true negative rate (TNR)""" return _compute_metric( _specificity, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def balanced_accuracy( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """Balanced accuracy""" return _compute_metric( _balanced_accuracy, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def positive_predictive_value( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """Precision or positive predictive value (PPV)""" return _compute_metric( _positive_predictive_value, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def negative_predictive_value( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: 
"""Negative predictive value (NPV)""" return _compute_metric( _negative_predictive_value, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def false_negative_rate( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """Miss rate or false negative rate (FNR)""" return _compute_metric( _false_negative_rate, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def false_positive_rate( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """Fall-out or false positive rate (FPR)""" return _compute_metric( _false_positive_rate, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def false_discovery_rate( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """False discovery rate (FDR)""" # noqa return _compute_metric( _false_discovery_rate, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def false_omission_rate( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """False omission rate (FOR)""" # noqa return _compute_metric( _false_omission_rate, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def positive_likelihood_ratio( tp: torch.LongTensor, fp: 
torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """Positive likelihood ratio (LR+)""" return _compute_metric( _positive_likelihood_ratio, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) def negative_likelihood_ratio( tp: torch.LongTensor, fp: torch.LongTensor, fn: torch.LongTensor, tn: torch.LongTensor, reduction: Optional[str] = None, class_weights: Optional[List[float]] = None, zero_division: Union[str, float] = 1.0, ) -> torch.Tensor: """Negative likelihood ratio (LR-)""" return _compute_metric( _negative_likelihood_ratio, tp, fp, fn, tn, reduction=reduction, class_weights=class_weights, zero_division=zero_division, ) _doc = """ Args: tp (torch.LongTensor): tensor of shape (N, C), true positive cases fp (torch.LongTensor): tensor of shape (N, C), false positive cases fn (torch.LongTensor): tensor of shape (N, C), false negative cases tn (torch.LongTensor): tensor of shape (N, C), true negative cases reduction (Optional[str]): Define how to aggregate metric between classes and images: - 'micro' Sum true positive, false positive, false negative and true negative pixels over all images and all classes and then compute score. - 'macro' Sum true positive, false positive, false negative and true negative pixels over all images for each label, then compute score for each label separately and average labels scores. This does not take label imbalance into account. - 'weighted' Sum true positive, false positive, false negative and true negative pixels over all images for each label, then compute score for each label separately and average weighted labels scores. - 'micro-imagewise' Sum true positive, false positive, false negative and true negative pixels for **each image**, then compute score for **each image** and average scores over dataset. 
All images contribute equally to final score, however takes into accout class imbalance for each image. - 'macro-imagewise' Compute score for each image and for each class on that image separately, then compute average score on each image over labels and average image scores over dataset. Does not take into account label imbalance on each image. - 'weighted-imagewise' Compute score for each image and for each class on that image separately, then compute weighted average score on each image over labels and average image scores over dataset. - 'none' or ``None`` Same as ``'macro-imagewise'``, but without any reduction. For ``'binary'`` case ``'micro' = 'macro' = 'weighted'`` and ``'micro-imagewise' = 'macro-imagewise' = 'weighted-imagewise'``. Prefixes ``'micro'``, ``'macro'`` and ``'weighted'`` define how the scores for classes will be aggregated, while postfix ``'imagewise'`` defines how scores between the images will be aggregated. class_weights (Optional[List[float]]): list of class weights for metric aggregation, in case of `weighted*` reduction is chosen. Defaults to None. zero_division (Union[str, float]): Sets the value to return when there is a zero division, i.e. when all predictions and labels are negative. If set to “warn”, this acts as 0, but warnings are also raised. Defaults to 1. 
Returns: torch.Tensor: if ``'reduction'`` is not ``None`` or ``'none'`` returns scalar metric, else returns tensor of shape (N, C) References: https://en.wikipedia.org/wiki/Confusion_matrix """ fbeta_score.__doc__ += _doc f1_score.__doc__ += _doc iou_score.__doc__ += _doc accuracy.__doc__ += _doc sensitivity.__doc__ += _doc specificity.__doc__ += _doc balanced_accuracy.__doc__ += _doc positive_predictive_value.__doc__ += _doc negative_predictive_value.__doc__ += _doc false_negative_rate.__doc__ += _doc false_positive_rate.__doc__ += _doc false_discovery_rate.__doc__ += _doc false_omission_rate.__doc__ += _doc positive_likelihood_ratio.__doc__ += _doc negative_likelihood_ratio.__doc__ += _doc precision = positive_predictive_value recall = sensitivity ================================================ FILE: data_schedule/vis/mapper.py ================================================ import json import os from typing import List import copy from functools import partial import random import numpy as np import torch import logging from einops import rearrange from detectron2.data import MetadataCatalog from data_schedule.registry import MAPPER_REGISTRY from .mapper_utils import VIS_TrainMapper, VIS_EvalMapper from .vis_frame_sampler import VIS_FRAMES_SAMPLER_REGISTRY from data_schedule.vis.apis import VIS_Dataset, VIS_Aug_CallbackAPI,\ VIS_TrainAPI_clipped_video, VIS_EvalAPI_clipped_video_request_ann @MAPPER_REGISTRY.register() class VIS_Video_EvalMapper(VIS_EvalMapper): def __init__(self, configs, dataset_name, mode, meta_idx_shift, ): assert mode == 'evaluate' dataset_meta = MetadataCatalog.get(dataset_name) assert dataset_meta.get('step_size') == None mapper_config = configs['data'][mode][dataset_name]['mapper'] super().__init__(meta_idx_shift=meta_idx_shift, dataset_meta=dataset_meta, mapper_config=mapper_config) def _call(self, data_dict): VIS_Dataset video_id, all_frames = data_dict['video_id'], data_dict['all_frames'] video_frames = 
self.get_frames_fn(video_id=video_id, frames=all_frames) aug_ret = { 'video': video_frames, 'callback_fns': [] } VIS_Aug_CallbackAPI aug_ret = self.augmentation(aug_ret) video = aug_ret.pop('video') callback_fns = aug_ret.pop('callback_fns')[::-1] VIS_EvalAPI_clipped_video_request_ann return { 'video_dict': {'video': video}, 'meta': { 'video_id': video_id, 'frames': all_frames, 'request_ann': torch.ones(len(all_frames)).bool(), 'callback_fns': callback_fns } } @MAPPER_REGISTRY.register() class VIS_Video_or_Step_To_Clip_TrainMapper(VIS_TrainMapper): def __init__(self, dataset_name, configs, mode, meta_idx_shift, ): assert mode == 'train' dataset_meta = MetadataCatalog.get(dataset_name) assert dataset_meta.get('name') == dataset_name mapper_config = configs['data'][mode][dataset_name]['mapper'] super().__init__(meta_idx_shift=meta_idx_shift, dataset_meta=dataset_meta, mapper_config=mapper_config) self.frames_sampler = VIS_FRAMES_SAMPLER_REGISTRY.get(\ mapper_config['frames_sampler']['name'])(sampler_configs=mapper_config['frames_sampler'], dataset_meta=dataset_meta) def _call(self, data_dict): VIS_Dataset video_id, all_frames, all_objs = data_dict['video_id'], data_dict['all_frames'], data_dict['all_objs'] frame_idx = data_dict['frame_idx'] if 'frame_idx' in data_dict else None all_obj_ids = list(all_objs.keys()) # [1, 2, 5, 4] assert len(list(set(all_obj_ids))) == len(all_obj_ids) class_labels = torch.tensor([all_objs[key]['class_label'] for key in all_obj_ids]) # [8, 10, 20 34] re_sample = True sampled_counts = 0 while re_sample: sampled_frames = self.frames_sampler(all_frames=all_frames, frame_idx=frame_idx, video_id=video_id) # t' h w, int, obj_ids ; has_ann t frames_mask, has_ann = self.get_frames_mask_fn(video_id=video_id, frames=sampled_frames) appear_objs = frames_mask.unique() # [0, 1, 2] assert set(appear_objs.tolist()).issubset(set([0] + all_obj_ids)) re_sample = (len(list(set(appear_objs.tolist()) & set(all_obj_ids))) == 0) # 只要出现某些个物体就行 sampled_counts += 
1 if sampled_counts > 2: logging.error('sampled two much times') raise RuntimeError() frames_mask = torch.stack([frames_mask == obj_id for obj_id in all_obj_ids], dim=0) # N t' h w, bool video_frames = self.get_frames_fn(video_id=video_id, frames=sampled_frames) width, height = video_frames[0].size aug_ret = { 'video': video_frames, 'masks': frames_mask, # N t' h w 'has_ann': has_ann, # t 'classes': class_labels, # N } VIS_Aug_CallbackAPI aug_ret = self.augmentation(aug_ret) video = aug_ret.pop('video') frame_targets = self.map_to_frame_targets(aug_ret) if self.clip_global_targets_map_to_local_targets: aug_ret = self.map_global_targets_to_local_targets(aug_ret) VIS_TrainAPI_clipped_video ret = {} ret['video_dict'] = {'video': video} ret['targets'] = aug_ret ret['frame_targets'] = frame_targets return ret ================================================ FILE: data_schedule/vis/mapper_utils.py ================================================ from .vis_aug_utils import VIS_EVAL_AUG_REGISTRY, VIS_TRAIN_AUG_REGISTRY import torch from copy import deepcopy as dcopy from data_schedule.registry import Mapper import copy from data_schedule.vis.apis import VIS_TrainAPI_clipped_video class VIS_Mapper(Mapper): def __init__(self, meta_idx_shift, dataset_meta,) -> None: super().__init__(meta_idx_shift=meta_idx_shift, dataset_meta=dataset_meta) self.get_frames_fn = dataset_meta.get('get_frames_fn') class VIS_TrainMapper(VIS_Mapper): def __init__(self, meta_idx_shift, dataset_meta, mapper_config) -> None: super().__init__(meta_idx_shift, dataset_meta) self.get_frames_mask_fn = dataset_meta.get('get_frames_mask_fn') self.clip_global_targets_map_to_local_targets = mapper_config['clip_global_targets_map_to_local_targets'] self.augmentation = VIS_TRAIN_AUG_REGISTRY.get(mapper_config['augmentation']['name'])(mapper_config['augmentation']) def map_to_frame_targets(self, clip_targets): VIS_TrainAPI_clipped_video clip_rets = copy.deepcopy(clip_targets) masks = 
clip_rets['masks'].transpose(0, 1).contiguous() # t' N h w class_labels = clip_rets['classes'] # [10, 32, 10, 4] has_box = 'boxes' in clip_rets if has_box: boxes = clip_rets['boxes'].transpose(0, 1).contiguous() # t' N 4 assert len(masks) == len(boxes) ret = [] for idx, frame_mk in enumerate(masks): frame_targets = { 'masks': frame_mk.unsqueeze(1), # N 1 h w 'classes': class_labels, # N } if has_box: frame_targets.update({'boxes': boxes[idx].unsqueeze(1)}) # N 1 4 if self.clip_global_targets_map_to_local_targets: frame_targets = self.map_global_targets_to_local_targets(frame_targets) frame_targets['masks'] = frame_targets['masks'].squeeze(1) if has_box: frame_targets['boxes'] = frame_targets['boxes'].squeeze(1) ret.append(frame_targets) return ret def map_global_targets_to_local_targets(self, ret): VIS_TrainAPI_clipped_video masks = ret['masks'] # N t' h w global_obj_appear = masks.flatten(1).any(-1) # N [True, False, True, False, False, False, True] ret['masks'] = ret['masks'][global_obj_appear] ret['classes'] = ret['classes'][global_obj_appear] if 'boxes' in ret: ret['boxes'] = ret['boxes'][global_obj_appear] # n t' 4 return ret class VIS_EvalMapper(VIS_Mapper): def __init__(self, meta_idx_shift, dataset_meta, mapper_config) -> None: super().__init__(meta_idx_shift, dataset_meta) assert mapper_config['augmentation']['name'] in ['WeakPolyP_EvalAug', 'Visha_EvalAug'] self.augmentation = VIS_EVAL_AUG_REGISTRY.get(mapper_config['augmentation']['name'])(mapper_config['augmentation']) ================================================ FILE: data_schedule/vis/polyp/__init__.py ================================================ from . import polyp_dataset from . 
import evals ================================================ FILE: data_schedule/vis/polyp/evals.py ================================================ from data_schedule.vis.evaluator_utils import register_vis_metric import os import torch import detectron2.utils.comm as comm import logging import subprocess @register_vis_metric def polyp_metric_aggregator(metrics_by_vid_frame, dataset_meta, eval_meta_keys, **kwargs): # output: eval_metrics # video: frame_name: metric/ vid_metrics eval_metrics = {} # video, frame_name # perframe metrics metric_names = metrics_by_vid_frame[list(eval_meta_keys.keys())[0]][eval_meta_keys[list(eval_meta_keys.keys())[0]][0]] for taylor_swift in metric_names: eval_metrics[taylor_swift] = torch.tensor([metrics_by_vid_frame[video][frame][taylor_swift] for video in eval_meta_keys.keys() for frame in eval_meta_keys[video]]).mean() # metrics by each video mean_iou_by_each_video = {} mean_dice_by_each_video = {} for billie_eilish in eval_meta_keys: mean_iou_by_each_video[billie_eilish] = torch.tensor([metrics_by_vid_frame[billie_eilish][fname]['iou'] for fname in eval_meta_keys[billie_eilish]]).mean() mean_dice_by_each_video[billie_eilish] = torch.tensor([metrics_by_vid_frame[billie_eilish][fname]['dice'] for fname in eval_meta_keys[billie_eilish]]).mean() mean_iou_by_each_video = dict(sorted(mean_iou_by_each_video.items(), key=lambda x: x[1])) mean_dice_by_each_video = dict(sorted(mean_dice_by_each_video.items(), key=lambda x: x[1])) logging.debug(f'mean_iou_by_each_video: {mean_iou_by_each_video}') logging.debug(f'mean_dice_by_each_video: {mean_dice_by_each_video}') return eval_metrics ================================================ FILE: data_schedule/vis/polyp/polyp_dataset.py ================================================ from typing import Optional, Union import json import os from functools import partial import numpy as np import torch import logging from tqdm import tqdm import copy from detectron2.data import DatasetCatalog, 
MetadataCatalog from collections import defaultdict from .polyp_utils import get_frames, get_frames_mask, SET_NAME_TO_DIR, SET_NAME, SET_NAME_TO_NUM_VIDEOS, SET_NAME_TO_MODE, SET_NAME_TO_PREFIX def polyp_train(step_size, split_dataset_name, video_ids, video_to_frames, root_path): logging.debug(f'{split_dataset_name} Generating metas...') metas = [] for vid_id in tqdm(video_ids): all_frames = sorted(video_to_frames[vid_id]) poly_class = 0 if step_size is None: metas.append({ 'video_id': vid_id, 'all_frames' : all_frames, 'all_objs': { 1: {'class_label': poly_class} }, 'meta_idx': len(metas) }) else: for frame_idx in range(0, len(all_frames), step_size): metas.append({ 'video_id': vid_id, 'frame_idx': frame_idx, 'all_frames': all_frames, 'all_objs': { 1: {'class_label': poly_class} }, 'meta_idx': len(metas) }) logging.debug(f'{split_dataset_name} Total metas: [{len(metas)}]') return metas def polyp_evaluate(eval_video_ids, split_dataset_name, step_size, video_to_frames): if (step_size is not None) and (step_size > 1): logging.warning('why?') raise ValueError() metas = [] for video_id in eval_video_ids: all_frames = sorted(video_to_frames[video_id]) if step_size == None: metas.append({ 'video_id': video_id, 'all_frames': all_frames, 'meta_idx': len(metas) }) else: for frame_idx in range(0, len(all_frames), step_size): metas.append({ 'video_id': video_id, 'frame_idx': frame_idx, 'all_frames': all_frames, 'meta_idx': len(metas) }) logging.debug(f'{split_dataset_name} Total metas: [{len(metas)}]') return metas _root = os.getenv('DATASET_PATH') root = os.path.join(_root, 'SUN/SUN-SEG2') visualize_meta_idxs = defaultdict(list) visualize_meta_idxs['polyp_train_step[6]'] = [] visualize_meta_idxs['polyp_train'] = [] visualize_meta_idxs['polyp_hard_unseen'] = [] visualize_meta_idxs['polyp_hard_seen'] = [] visualize_meta_idxs['polyp_easy_unseen'] = [] visualize_meta_idxs['polyp_easy_seen'] = [] polyp_meta = { 'thing_classes': ['polyp', 'not polyp'], 'thing_colors': [(255., 
140., 0.), (0., 255., 0.)], 'root': root } for name in SET_NAME: set_dir = SET_NAME_TO_DIR[name] set_dir = os.path.join(root, set_dir) num_videos = SET_NAME_TO_NUM_VIDEOS[name] video_ids = os.listdir(os.path.join(set_dir, 'Frame')) assert len(video_ids) == num_videos video_to_frames = { vid: sorted([png[:-4] for png in os.listdir(os.path.join(set_dir, 'Frame', vid)) if png.endswith('.jpg')])\ for vid in video_ids } mode = SET_NAME_TO_MODE[name] if mode == 'train': prefix = SET_NAME_TO_PREFIX[name] train_meta = copy.deepcopy(polyp_meta) train_meta.update({ 'mode': 'train', 'get_frames_fn': partial(get_frames, frames_path=os.path.join(set_dir, 'Frame')), 'get_frames_mask_fn': partial(get_frames_mask, mask_path=os.path.join(set_dir, 'GT'),), }) # train for step_size in [1, 3, 6, 9, 12, None]: step_identifer = '' if step_size is None else f'_step[{step_size}]' split_name = f'{prefix}{step_identifer}' train_meta.update({'name': split_name}) DatasetCatalog.register(split_name, partial(polyp_train, video_ids=video_ids, split_dataset_name=split_name, step_size=step_size, video_to_frames=video_to_frames, root_path=set_dir)) MetadataCatalog.get(split_name).set(**train_meta, step_size=step_size, visualize_meta_idxs=visualize_meta_idxs[split_name]) elif mode == 'evaluate': prefix = SET_NAME_TO_PREFIX[name] validate_meta = copy.deepcopy(polyp_meta) validate_meta.update({ 'mode': 'evaluate', 'get_frames_fn': partial(get_frames, frames_path=os.path.join(root, os.path.join(set_dir, 'Frame'))), 'eval_set_name': SET_NAME_TO_DIR[name], 'get_frames_gt_mask_fn': partial(get_frames_mask, mask_path=os.path.join(root, os.path.join(set_dir, 'GT')),), 'eval_meta_keys': video_to_frames }) for step_size in [1, None,]: step_identifer = '' if step_size is None else f'_step[{step_size}]' split_name = f'{prefix}{step_identifer}' validate_meta.update({'name': split_name}) DatasetCatalog.register(split_name, partial(polyp_evaluate, eval_video_ids=video_ids, split_dataset_name=split_name, 
step_size=step_size, video_to_frames=video_to_frames)) MetadataCatalog.get(split_name).set(**validate_meta, step_size=step_size, visualize_meta_idxs=visualize_meta_idxs[split_name]) ================================================ FILE: data_schedule/vis/polyp/polyp_utils.py ================================================ import os import numpy as np import torch from PIL import Image SET_NAME = ['polyp_train', 'polyp_hard_seen_validate', 'polyp_hard_unseen_validate', 'polyp_easy_seen_validate', 'polyp_easy_unseen_validate', 'polyp_hard_validate', 'polyp_easy_validate', 'Kvasir-train', 'Mayo-train', '300-train', '612-train', '300-tv', '612-test', '612-val' ] SET_NAME_TO_DIR = { 'polyp_train': 'TrainDataset', 'polyp_hard_seen_validate': 'TestHardDataset/Seen', 'polyp_hard_unseen_validate': 'TestHardDataset/Unseen', 'polyp_easy_seen_validate': 'TestEasyDataset/Seen', 'polyp_easy_unseen_validate': 'TestEasyDataset/Unseen', 'polyp_hard_validate': 'TestHardDataset/Combine', 'polyp_easy_validate': 'TestEasyDataset/Combine', 'Kvasir-train': 'MICCAI-VPS-dataset/Kvasir-SEG', 'Mayo-train': 'MICCAI-VPS-dataset/VPS-TrainSet/ASU-Mayo_Clinic/Train', '300-train': 'MICCAI-VPS-dataset/VPS-TrainSet/CVC-ColonDB-300/Train', '612-train': 'MICCAI-VPS-dataset/VPS-TrainSet/CVC-ClinicDB-612/Train', '300-tv': 'MICCAI-VPS-dataset/VPS-TestSet/CVC-ColonDB-300', '612-test': 'MICCAI-VPS-dataset/VPS-TestSet/CVC-ClinicDB-612-Test', '612-val': 'MICCAI-VPS-dataset/VPS-TestSet/CVC-ClinicDB-612-Valid' } SET_NAME_TO_NUM_VIDEOS = { 'polyp_train': 112, 'polyp_hard_seen_validate': 17, 'polyp_hard_unseen_validate': 37, 'polyp_easy_seen_validate': 33, 'polyp_easy_unseen_validate': 86, 'polyp_hard_validate': 54, 'polyp_easy_validate': 119, 'Kvasir-train': 1, 'Mayo-train': 10, '300-train': 6, '612-train': 18, '300-tv': 6, '612-test': 5, '612-val': 5 } SET_NAME_TO_MODE = { 'polyp_train': 'train', 'polyp_hard_seen_validate': 'evaluate', 'polyp_hard_unseen_validate': 'evaluate', 'polyp_easy_seen_validate': 
'evaluate', 'polyp_easy_unseen_validate': 'evaluate', 'polyp_hard_validate': 'evaluate', 'polyp_easy_validate': 'evaluate', 'Kvasir-train': 'train', 'Mayo-train': 'train', '300-train': 'train', '612-train': 'train', '300-tv': 'evaluate', '612-test': 'evaluate', '612-val': 'evaluate' } SET_NAME_TO_PREFIX = { 'polyp_train': 'polyp_train', 'polyp_hard_seen_validate': 'polyp_hard_seen_validate', 'polyp_hard_unseen_validate': 'polyp_hard_unseen_validate', 'polyp_easy_seen_validate': 'polyp_easy_seen_validate', 'polyp_easy_unseen_validate': 'polyp_easy_unseen_validate', 'polyp_hard_validate': 'polyp_hard_validate', 'polyp_easy_validate': 'polyp_easy_validate', 'Kvasir-train': 'Kvasir-train', 'Mayo-train': 'Mayo-train', '300-train': '300-train', '612-train': '612-train', '300-tv': '300-tv', '612-test': '612-test', '612-val': '612-val' } CLASS_TO_ID = { 'high_grade_adenoma':0, 'hyperplastic_polyp':1, 'invasive_cancer':2, 'low_grade_adenoma':3, 'sessile_serrated_lesion':4, 'traditional_serrated_adenoma':5 } def get_frames(frames_path, video_id, frames): return [Image.open(os.path.join(frames_path, video_id, f'{f}.jpg')).convert('RGB') for f in frames] def get_frames_mask(mask_path, video_id, frames): # masks = [cv2.imread(os.path.join(mask_path, video_id, f'{f}.jpg')) for f in frames] if os.path.exists(os.path.join(mask_path, video_id, f'{frames[0]}.png')): masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.png')).convert('L') for f in frames] elif os.path.exists(os.path.join(mask_path, video_id, f'{frames[0]}.jpg')): masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.jpg')).convert('L') for f in frames] else: raise ValueError() masks = [np.array(mk) for mk in masks] masks = torch.stack([torch.from_numpy(mk) for mk in masks], dim=0) # t h w # assert set(masks.unique().tolist()) == set([0, 255]), f'{masks.unique().tolist()}' masks = (masks > 0).int() return masks, torch.ones(len(frames)).bool() ================================================ FILE: 
data_schedule/vis/vis_aug_eval.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import random import torch import torchvision.transforms.functional as F from data_schedule.vis.apis import VIS_Aug_CallbackAPI from .vis_aug_utils import get_tgt_size from .vis_aug_utils import VIS_EVAL_AUG_REGISTRY class RandomResize: def __init__(self, sizes, max_size=None): assert isinstance(sizes, (list, tuple)) self.sizes = sizes self.max_size = max_size def __call__(self, ret): video = ret['video'] orig_size = video[0].size # w h tgt_size = get_tgt_size(video[0].size, random.choice(self.sizes), self.max_size) # h w resized_video = [F.resize(frame, tgt_size) for frame in video] ratio_width, ratio_height = tuple(float(s) / float(s_orig) for s, s_orig in zip(tgt_size[::-1], orig_size)) ret['video'] = resized_video if 'callback_fns' in ret: VIS_Aug_CallbackAPI ret['callback_fns'].append(RandomResize(sizes=[orig_size], max_size=None)) if "pred_masks" in ret: assert (len(self.sizes) == 1) and (self.max_size == None) VIS_Aug_CallbackAPI pred_masks = ret['pred_masks'] # list[nt h w], t pred_masks = [torch.nn.functional.interpolate(mk.unsqueeze(0).float(), tgt_size, mode='nearest')[0].bool() for mk in pred_masks] ret['pred_masks'] = pred_masks # list[nt h w], t if "pred_boxes" in ret: VIS_Aug_CallbackAPI pred_boxes = ret["pred_boxes"] # list[nt 4], t scaled_boxes = [bx * (torch.tensor([ratio_width, ratio_height, ratio_width, ratio_height])[None, :]) for bx in pred_boxes] ret["pred_boxes"] = scaled_boxes return ret class VideoToPIL: def __call__(self, ret): video = ret['video'] # t 3 h w -> assert video.dtype == torch.float and (video.max() <= 1) and (video.min() >=0) pil_video = [F.to_pil_image(frame, mode='RGB') for frame in video] # 3 h w, float, 0-1 ret['video'] = pil_video assert 'callback_fns' not in ret return ret class VideoToTensor: def __call__(self, ret): video = ret['video'] tensor_video = 
torch.stack([F.to_tensor(frame) for frame in video], dim=0) # t 3 h w, float, 0-1 ret['video'] = tensor_video if 'callback_fns' in ret: VIS_Aug_CallbackAPI ret['callback_fns'].append(VideoToPIL()) return ret @VIS_EVAL_AUG_REGISTRY.register() class WeakPolyP_EvalAug: def __init__(self, configs) -> None: self.resize = RandomResize( sizes=[[352, 352]], ) self.tensor_video = VideoToTensor() def __call__(self, ret): VIS_Aug_CallbackAPI ret = self.resize(ret) ret = self.tensor_video(ret) return ret ================================================ FILE: data_schedule/vis/vis_aug_train.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import random from PIL import Image import torch import torchvision.transforms.functional as F from einops import rearrange from copy import deepcopy as dcopy from data_schedule.vis.apis import VIS_Aug_CallbackAPI import albumentations as A import numpy as np from data_schedule.utils.segmentation import bounding_box_from_mask from .vis_aug_utils import VIS_TRAIN_AUG_REGISTRY, pil_torch_to_numpy, numpy_to_pil_torch import copy import imgaug.augmenters as iaa from imgaug.augmentables.segmaps import SegmentationMapsOnImage from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage import imgaug from datetime import datetime class RandomRotate90: def __init__(self) -> None: self.album_aug = A.ReplayCompose( [A.RandomRotate90(0.5)] ) def __call__(self, ret): video = ret['video'] masks = ret['masks'] has_ann = ret['has_ann'] # list[PIL], n t' h w -> # list[h w 3, 255rgb], t # list[list[h w, 01uint8]] t video, masks = pil_torch_to_numpy(video=video, masks=masks, has_ann=has_ann) replay = self.album_aug(image=video[0], mask=[masks[0][0]])['replay'] auged_video = [] auged_mask = [] for vid, mk in zip(video, masks): ret = self.album_aug.replay(replay, image=vid, mask=mk) auged_video.append(ret['image']) auged_mask.append(ret['mask']) auged_video, auged_mask = 
numpy_to_pil_torch(video=auged_video, auged_mask=auged_mask, has_ann=has_ann) ret['video'] = auged_video ret['mask'] = auged_mask return ret class ComputeBox: def __call__(self, ret): W, H = ret['video'][0].size N, T = ret['masks'].shape[:2] # n t' h w boxes = torch.stack([bounding_box_from_mask(mask) for mask in copy.deepcopy(ret['masks']).flatten(0, 1)], dim=0) # Nt' 4 boxes = rearrange(boxes, '(N T) c -> N T c', N=N, T=T) boxes[:, :, 0::2].clamp_(min=0, max=W) boxes[:, :, 1::2].clamp_(min=0, max=H) ret['boxes'] = boxes return ret class VideoToTensor: def __call__(self, ret): video = ret['video'] tensor_video = torch.stack([F.to_tensor(frame) for frame in video], dim=0) # t 3 h w, float, 0-1 ret['video'] = tensor_video return ret class Compose: def __init__(self, transforms): self.transforms = transforms def __call__(self, ret): for t in self.transforms: ret = t(ret) return ret def __repr__(self): format_string = self.__class__.__name__ + "(" for t in self.transforms: format_string += "\n" format_string += " {0}".format(t) format_string += "\n)" return format_string @VIS_TRAIN_AUG_REGISTRY.register() class WeakPolyP_TrainAug: def __init__(self, configs) -> None: self.transform = A.ReplayCompose([ A.Resize(352, 352), A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.RandomRotate90(p=0.5), ]) self.tensor_video = VideoToTensor() self.add_box = ComputeBox() def __call__(self, ret): VIS_Aug_CallbackAPI video = ret['video'] masks = ret['masks'] # n t' h w has_ann = ret['has_ann'] # t # list[PIL] -> list[h w 3, 0-1float], t # n t' h w -> list[list[h w, 01uint8], 没有annotation的帧box是空] t video, masks = pil_torch_to_numpy(video=video, masks=masks, has_ann=has_ann) replay = self.transform(image=video[0], masks=[masks[0][0]])['replay'] auged_video = [] auged_mask = [] for vid, mk in zip(video, masks): auged_each_frame = self.transform.replay(replay, image=vid, masks=mk) auged_video.append(auged_each_frame['image']) auged_mask.append(auged_each_frame['masks']) # list[h w, 
01uint8] auged_video, auged_mask = numpy_to_pil_torch(video=auged_video, masks=auged_mask, has_ann=has_ann) ret['video'] = auged_video ret['masks'] = auged_mask ret = self.add_box(ret) ret = self.tensor_video(ret) return ret @VIS_TRAIN_AUG_REGISTRY.register() class WeakPolyP_TrainAug_RotateImageToClip: def __init__(self, configs) -> None: self.ImageToSeqAugmenter = ImageToSeqAugmenter(perspective=True, affine=True, motion_blur=True, rotation_range=(-20, 20), perspective_magnitude=0.08, hue_saturation_range=(-5, 5), brightness_range=(-40, 40), motion_blur_prob=0.25, motion_blur_kernel_sizes=(9, 11), translate_range=(-0.1, 0.1)) self.num_frames = configs['num_frames'] self.transform = A.ReplayCompose([ A.Resize(352, 352), A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.RandomRotate90(p=0.5), ]) self.tensor_video = VideoToTensor() self.add_box = ComputeBox() def apply_random_sequence_shuffle(self, images, instance_masks): perm = list(range(self.num_frames)) random.shuffle(perm) images = [images[i] for i in perm] instance_masks = [instance_masks[i] for i in perm] return images, instance_masks def __call__(self, ret): VIS_Aug_CallbackAPI video = ret['video'] # list[pil], t masks = ret['masks'] # n t' h w has_ann = ret['has_ann'] # t # list[PIL] -> list[h w 3, uint8], t # n t' h w -> list[list[h w], n, uint8], t seq_images, seq_instance_masks = pil_torch_to_numpy(video=video, masks=masks, has_ann=has_ann, float_image=False) assert len(seq_images) == 1 and len(seq_instance_masks) == 1 static_img, static_mask = seq_images[0], seq_instance_masks[0] for t in range(self.num_frames - 1): im_trafo, instance_masks_trafo = self.ImageToSeqAugmenter(static_img, static_mask) # h w 3, uint8; list[h w], n, uint8 seq_images.append(np.uint8(im_trafo)) seq_instance_masks.append(instance_masks_trafo) # list[h w 3], t ; # list[list[h w, 01uint8]] t seq_images, seq_instance_masks = self.apply_random_sequence_shuffle(seq_images, seq_instance_masks) has_ann = 
torch.ones(self.num_frames).bool() # T seq_images = [np.float32(haosen) / 255.0 for haosen in seq_images] # list[h w 3, 0-1float], t replay = self.transform(image=seq_images[0], masks=[seq_instance_masks[0][0]])['replay'] auged_video = [] auged_mask = [] for vid, mk in zip(seq_images, seq_instance_masks): auged_each_frame = self.transform.replay(replay, image=vid, masks=mk) auged_video.append(auged_each_frame['image']) auged_mask.append(auged_each_frame['masks']) # list[h w, 01uint8] auged_video, auged_mask = numpy_to_pil_torch(video=auged_video, masks=auged_mask, has_ann=has_ann) # n t h w # [haosen.save(f'./test{idx}.png') for idx, haosen in enumerate(auged_video)] # import matplotlib.pyplot as plt # [plt.imsave( f'./mask{idx}.png', auged_mask[0][idx].float().numpy()) for idx in range(len(auged_mask[0]))] ret['video'] = auged_video ret['masks'] = auged_mask ret['has_ann'] = has_ann ret = self.add_box(ret) ret = self.tensor_video(ret) return ret class ImageToSeqAugmenter(object): def __init__(self, perspective=True, affine=True, motion_blur=True, brightness_range=(-50, 50), hue_saturation_range=(-15, 15), perspective_magnitude=0.12, scale_range=1.0, translate_range={"x": (-0.15, 0.15), "y": (-0.15, 0.15)}, rotation_range=(-20, 20), motion_blur_kernel_sizes=(7, 9), motion_blur_prob=0.5, seed=2024): self.basic_augmenter = iaa.SomeOf((1, None), [ iaa.Add(brightness_range), iaa.AddToHueAndSaturation(hue_saturation_range) ] ) transforms = [] if perspective: transforms.append(iaa.PerspectiveTransform(perspective_magnitude)) if affine: transforms.append(iaa.Affine(scale=scale_range, translate_percent=translate_range, rotate=rotation_range, order=1, # cv2.INTER_LINEAR backend='auto')) transforms = iaa.Sequential(transforms) transforms = [transforms] if motion_blur: blur = iaa.Sometimes(motion_blur_prob, iaa.OneOf( [ iaa.MotionBlur(ksize) for ksize in motion_blur_kernel_sizes ] )) transforms.append(blur) self.frame_shift_augmenter = iaa.Sequential(transforms) self.seed = 
seed @staticmethod def condense_masks(instance_masks): condensed_mask = np.zeros_like(instance_masks[0], dtype=np.int8) for instance_id, mask in enumerate(instance_masks, 1): condensed_mask = np.where(mask, instance_id, condensed_mask) return condensed_mask @staticmethod def expand_masks(condensed_mask, num_instances): return [(condensed_mask == instance_id).astype(np.uint8) for instance_id in range(1, num_instances + 1)] def __call__(self, image, masks=None, boxes=None): # n h w det_augmenter = self.frame_shift_augmenter.to_deterministic() if masks is not None: masks_np, is_binary_mask = [], [] boxs_np = [] for mask in masks: if isinstance(mask, np.ndarray): masks_np.append(mask.astype(np.bool_)) is_binary_mask.append(False) else: raise ValueError("Invalid mask type: {}".format(type(mask))) num_instances = len(masks_np) masks_np = SegmentationMapsOnImage(self.condense_masks(masks_np), shape=image.shape[:2]) # boxs_np = BoundingBoxesOnImage(boxs_np, shape=image.shape[:2]) seed = int(datetime.now().strftime('%M%S%f')[-8:]) imgaug.seed(seed) aug_image, aug_masks = det_augmenter(image=self.basic_augmenter(image=image) , segmentation_maps=masks_np) imgaug.seed(seed) invalid_pts_mask = det_augmenter(image=np.ones(image.shape[:2] + (1,), np.uint8)).squeeze(2) aug_masks = self.expand_masks(aug_masks.get_arr(), num_instances) # aug_boxes = aug_boxes.remove_out_of_image().clip_out_of_image() aug_masks = [mask for mask, is_bm in zip(aug_masks, is_binary_mask)] return aug_image, aug_masks #, aug_boxes.to_xyxy_array() else: masks = [SegmentationMapsOnImage(np.ones(image.shape[:2], np.bool), shape=image.shape[:2])] aug_image, invalid_pts_mask = det_augmenter(image=image, segmentation_maps=masks) return aug_image, invalid_pts_mask.get_arr() == 0 ================================================ FILE: data_schedule/vis/vis_aug_utils.py ================================================ from detectron2.utils.registry import Registry import torch import numpy as np import 
torchvision.transforms.functional as F from PIL import Image from einops import rearrange, reduce, repeat VIS_EVAL_AUG_REGISTRY = Registry('VIS_EVAL_AUG') VIS_TRAIN_AUG_REGISTRY = Registry('VIS_TRAIN_AUG') def get_size_with_aspect_ratio(image_size, size, max_size=None): w, h = image_size if max_size is not None: min_original_size = float(min((w, h))) max_original_size = float(max((w, h))) if max_original_size / min_original_size * size > max_size: size = int(round(max_size * min_original_size / max_original_size)) if (w <= h and w == size) or (h <= w and h == size): return (h, w) if w < h: ow = size oh = int(size * h / w) else: oh = size ow = int(size * w / h) return (oh, ow) def get_tgt_size(image_size, size, max_size=None): if isinstance(size, (list, tuple)): return size[::-1] else: return get_size_with_aspect_ratio(image_size, size, max_size) def pil_torch_to_numpy(video, masks, has_ann, float_image=True): # n t' h w # list[pil_image, rgb], t # t N, T = masks.shape[:2] has_ann_idx = torch.where(has_ann)[0] # time_idx # list[Image], t -> list[h w 3, 255uint8], t masks = masks.permute(1, 0, 2, 3).contiguous().unbind(0) # list[n h w] t' numpy_masks = [[]] * len(has_ann) # list[list[h w, 01_uint8], n], t assert len(has_ann_idx) == len(masks) for fmask, taylor in zip(masks, has_ann_idx): # n h w fnumpy_masks = [] for mk in fmask.unbind(0): # h w fnumpy_masks.append(mk.numpy().astype(np.uint8)) numpy_masks[taylor] = fnumpy_masks if float_image: # list[h w 3, 0-1float], t video = [F.to_tensor(frame).permute(1,2,0).numpy() for frame in video] else: # uint8 video = [np.array(frame) for frame in video] return video, numpy_masks def numpy_to_pil_torch(video, masks, has_ann): # numpy, numpy -> torch, torch # list[h w 3, 0-1float], t H, W = video[0].shape[:2] T = has_ann.int().sum() video = [Image.fromarray(np.uint8(aug_vid * 255), mode="RGB") for aug_vid in video] # t'n h w torch_masks = torch.stack([torch.from_numpy(obj_mk).bool() for frame_mk in masks for obj_mk in 
frame_mk], dim=0) torch_masks = rearrange(torch_masks, '(T N) h w -> N T h w', T=T) # n t' h w return video, torch_masks ================================================ FILE: data_schedule/vis/vis_frame_sampler.py ================================================ from detectron2.utils.registry import Registry import random import numpy as np import torch import logging from detectron2.utils import comm VIS_FRAMES_SAMPLER_REGISTRY = Registry('VIS_FRAMES_SAMPLER') import random @VIS_FRAMES_SAMPLER_REGISTRY.register() class Naive_ReferenceFrame_FrameSampler: def __init__(self, sampler_configs, dataset_meta, **kwargs): self.reference_frame_step_size = dataset_meta.get('step_size') self.clip_sizes = list(sampler_configs['clip_sizes']) # list[int] self.clip_distribute = sampler_configs['clip_distribute'] # dense, sparse, local_global self.clip_position = sampler_configs['clip_position'] # former, center, latter if max(self.clip_sizes) > self.reference_frame_step_size: if comm.is_main_process(): logging.warning('') def __call__(self, frame_idx=None, all_frames=None, # list[str] **kwargs): random_clip_size = random.choice(self.clip_sizes) video_len = len(all_frames) sample_indx = [frame_idx] if (self.clip_position == 'center') and (self.clip_distribute == 'local_global'): if random_clip_size != 1: sample_id_before = random.randint(1, 3) sample_id_after = random.randint(1, 3) local_indx = [max(0, frame_idx - sample_id_before), min(video_len - 1, frame_idx + sample_id_after)] sample_indx.extend(local_indx) if random_clip_size > 3: all_inds = list(range(video_len)) global_inds = all_inds[:min(sample_indx)] + all_inds[max(sample_indx):] global_n = random_clip_size - len(sample_indx) if len(global_inds) > global_n: select_id = random.sample(range(len(global_inds)), global_n) for s_id in select_id: sample_indx.append(global_inds[s_id]) elif video_len >= global_n: select_id = random.sample(range(video_len), global_n) for s_id in select_id: sample_indx.append(all_inds[s_id]) else: 
select_id = random.sample(range(video_len), global_n - video_len) + list(range(video_len)) for s_id in select_id: sample_indx.append(all_inds[s_id]) elif (self.clip_position == 'center') and (self.clip_distribute == 'dense'): half_size = (random_clip_size - 1) // 2 sample_indx += list(range(frame_idx - half_size, frame_idx)) sample_indx += list(range(frame_idx+1, half_size + frame_idx + 1)) if len(sample_indx) < random_clip_size: sample_indx = [min(sample_indx)] + sample_indx assert len(sample_indx) == random_clip_size sample_indx = torch.tensor(sample_indx) sample_indx = sample_indx.clamp_(min=0, max=video_len-1) sample_indx = sample_indx.tolist() else: raise ValueError() sample_indx.sort() sampled_frames = [all_frames[idx] for idx in sample_indx] return sampled_frames ================================================ FILE: handle_vps.py ================================================ import cv2 import numpy as np import os import shutil from PIL import Image import torch from tqdm import tqdm dataset_root = os.getenv('DATASET_PATH') # the original IVPS is the union of Kvasir and per-frame Mayo/CVC all_images = os.listdir(f'{dataset_root}/MICCAI-VPS-dataset/IVPS-TrainSet/Frame') ka_images = [b for b in all_images if b.startswith('K')] assert len(ka_images) == 1000 all_gts = os.listdir(f'{dataset_root}/MICCAI-VPS-dataset/IVPS-TrainSet/GT') ka_gts = [b for b in all_gts if b.startswith('K')] assert len(ka_gts) == 1000 os.makedirs(os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/Kvasir-SEG/Frame/1'),exist_ok=True) os.makedirs(os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/Kvasir-SEG/GT/1'),exist_ok=True) for image_id in tqdm(ka_images): shutil.copy(os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/IVPS-TrainSet/Frame', f'{image_id}'), os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/Kvasir-SEG/Frame/1', f'{image_id}'),) for image_id in tqdm(ka_gts): shutil.copy(os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/IVPS-TrainSet/GT', f'{image_id}'), 
os.path.join(f'{dataset_root}/MICCAI-VPS-dataset/Kvasir-SEG/GT/1', f'{image_id}'),)

# normalize train directory
# For each train split, copy <base_path>/<vid>/Frame and <base_path>/<vid>/GT into
# <base_path>/Frame/<vid> and <base_path>/GT/<vid>.
for base_path in [f'{dataset_root}/MICCAI-VPS-dataset/VPS-TrainSet/CVC-ColonDB-300/Train',
                  f'{dataset_root}/MICCAI-VPS-dataset/VPS-TrainSet/ASU-Mayo_Clinic/Train',
                  f'{dataset_root}/MICCAI-VPS-dataset/VPS-TrainSet/CVC-ClinicDB-612/Train']:
    # Listed before Frame/GT are created, so those two folders are not treated as video ids.
    # NOTE(review): on a second run 'Frame'/'GT' would be listed as videos and copytree would fail.
    video_ids = os.listdir(base_path)
    frame_path = os.path.join(base_path, 'Frame')
    gt_path = os.path.join(base_path, 'GT')
    os.makedirs(frame_path, exist_ok=True)
    os.makedirs(gt_path, exist_ok=True)

    # Iterate through each video ID directory
    for vid in video_ids:
        shutil.copytree(os.path.join(base_path, vid, 'Frame'), os.path.join(frame_path, vid))
        shutil.copytree(os.path.join(base_path, vid, 'GT'), os.path.join(gt_path, vid))

# TODO: dangerous: remove if you want
# remove non-mask frames of each training set
# WARNING: the loop at the bottom of this script deletes files in place (os.remove).
SET_NAME = [
    'Kvasir-train',
    'Mayo-train',
    '300-train',
    '612-train',
]

# Split name -> directory (relative to $DATASET_PATH) containing Frame/ and GT/.
SET_NAME_TO_DIR = {
    'Kvasir-train': 'MICCAI-VPS-dataset/Kvasir-SEG',
    'Mayo-train': 'MICCAI-VPS-dataset/VPS-TrainSet/ASU-Mayo_Clinic/Train',
    '300-train': 'MICCAI-VPS-dataset/VPS-TrainSet/CVC-ColonDB-300/Train',
    '612-train': 'MICCAI-VPS-dataset/VPS-TrainSet/CVC-ClinicDB-612/Train',
}

# NOTE: SET_NAME_TO_NUM_VIDEOS and SET_NAME_TO_PREFIX are not referenced later in this script.
SET_NAME_TO_NUM_VIDEOS = {
    'Kvasir-train': 1,
    'Mayo-train': 10,
    '300-train': 6,
    '612-train': 18,
    '300-tv': 6,
    '612-test': 5,
    '612-val': 5
}

SET_NAME_TO_PREFIX = {
    'Kvasir-train': 'Kvasir-train',
    'Mayo-train': 'Mayo-train',
    '300-train': '300-train',
    '612-train': '612-train',
}

root = os.getenv('DATASET_PATH')

def get_frames_mask(mask_path, video_id, frames):
    """Load the GT masks of ``frames`` of ``video_id`` under ``mask_path``.

    The file extension (.png, else .jpg) is chosen by probing the first frame. Masks are read as
    grayscale and binarized (> 0). Returns ``(masks, has_ann)``: a (t, h, w) int tensor of 0/1 and
    an all-True bool tensor of length t.

    Raises:
        ValueError: if neither a .png nor a .jpg mask exists for the first frame.
    """
    # masks = [cv2.imread(os.path.join(mask_path, video_id, f'{f}.jpg')) for f in frames]
    if os.path.exists(os.path.join(mask_path, video_id, f'{frames[0]}.png')):
        masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.png')).convert('L') for f in frames]
    elif os.path.exists(os.path.join(mask_path, video_id, f'{frames[0]}.jpg')):
        masks = [Image.open(os.path.join(mask_path, video_id, f'{f}.jpg')).convert('L') for f in frames]
    else:
        raise ValueError()
    masks = [np.array(mk) for mk in masks]
    masks = torch.stack([torch.from_numpy(mk) for mk in masks], dim=0) # t h w
    # assert set(masks.unique().tolist()) == set([0, 255]), f'{masks.unique().tolist()}'
    masks = (masks > 0).int()
    return masks, torch.ones(len(frames)).bool()

num_delted_frames = 0
for train_set_name in SET_NAME:
    set_dir = SET_NAME_TO_DIR[train_set_name]
    frames_dir = os.path.join(root, set_dir, 'Frame')
    mask_dir = os.path.join(root, set_dir, 'GT')
    video_ids = os.listdir(frames_dir)
    for vid in tqdm(video_ids):
        # Strip a 4-character extension (e.g. ".jpg") to get frame names.
        frames = [haosen[:-4] for haosen in os.listdir(os.path.join(frames_dir, vid))]
        # One flag per frame: True if its GT mask has any foreground pixel.
        frame_has_fore = [get_frames_mask(mask_dir, vid, [haosen])[0].any() for haosen in tqdm(frames)] # list[t]
        assert len(frame_has_fore) == len(frames)
        num_delted_frames += (~ torch.tensor(frame_has_fore)).int().sum()
        # Delete every frame without foreground, together with its GT file (.jpg or .png).
        # The frame image itself is assumed to be a .jpg.
        for haosen, frame_name in tqdm(zip(frame_has_fore, frames)):
            if not haosen:
                os.remove(os.path.join(frames_dir, vid, f'{frame_name}.jpg'))
                if os.path.exists(os.path.join(mask_dir, vid, f'{frame_name}.jpg')):
                    os.remove(os.path.join(mask_dir, vid, f'{frame_name}.jpg'))
                elif os.path.exists(os.path.join(mask_dir, vid, f'{frame_name}.png')):
                    os.remove(os.path.join(mask_dir, vid, f'{frame_name}.png'))
                else:
                    raise ValueError()
print(f'should be {num_delted_frames}/1546.') # should be 1546


================================================
FILE: main.py
================================================
import os
import argparse
import logging
import importlib
from trainers import task_to_trainer
import detectron2.utils.comm as comm
from termcolor import colored
import logging
import yaml
import torch
from utils.misc import setup_for_distributed


def _highlight(code, filename):
    # Syntax-highlight `code` for the terminal (Python lexer for *.py, YAML otherwise);
    # returns the text unchanged if pygments is not installed.
    try:
        import pygments
    except ImportError:
        return code

    from pygments.lexers import Python3Lexer, YamlLexer
    from pygments.formatters import Terminal256Formatter

    lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
    code = pygments.highlight(code, lexer,
Terminal256Formatter(style="monokai")) return code class _ColorfulFormatter(logging.Formatter): def __init__(self, *args, **kwargs): self._root_name = kwargs.pop("root_name") + "." self._abbrev_name = kwargs.pop("abbrev_name", "") if len(self._abbrev_name): self._abbrev_name = self._abbrev_name + "." super(_ColorfulFormatter, self).__init__(*args, **kwargs) def formatMessage(self, record): record.name = record.name.replace(self._root_name, self._abbrev_name) message = record.message # message, asctime, name, filename = record.message, record.asctime, record.name, record.filename log = super(_ColorfulFormatter, self).formatMessage(record) if (record.levelno == logging.WARNING) or (record.levelno == logging.ERROR) or (record.levelno == logging.CRITICAL): colored_message = colored(message, "red", attrs=["blink", "underline"]) elif record.levelno == logging.DEBUG: colored_message = colored(message, "yellow", attrs=["blink", "underline"]) else: # INFO/NOTSET colored_message = colored(message, "white") return log + colored_message def set_logging_file(output_dir, file_name, mode='a'): handler1 = logging.StreamHandler() handler2 = logging.FileHandler(os.path.join(output_dir, file_name), mode=mode) formatter = _ColorfulFormatter( colored("[%(asctime)s %(name)s %(filename)s]: ", "green"), datefmt="%m/%d %H:%M:%S", root_name=os.path.join(output_dir, file_name), abbrev_name=str('grey'), ) handler1.setFormatter(formatter) handler2.setFormatter(formatter) logger = logging.getLogger() logger.addHandler(handler1) logger.addHandler(handler2) logger.setLevel(logging.DEBUG) def init_process_group_and_set_device(world_size, process_id, device_id): """ This function needs to be called on each spawned process to initiate learning using DistributedDataParallel. The function initiates the process' process group and assigns it a single GPU to use during training. 
""" torch.cuda.set_device(device_id) device = torch.device(f'cuda:{device_id}') if world_size > 1: torch.distributed.init_process_group( torch.distributed.Backend.NCCL, world_size=world_size, rank=process_id ) comm.create_local_process_group(world_size) torch.distributed.barrier(device_ids=[device_id]) setup_for_distributed(process_id == 0) return device def run(rank, configs, world_size): os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512' os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ['PYDEVD_WARN_SLOW_RESOLVE_TIMEOUT'] = "4" os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = "1" os.environ["DGLBACKEND"] = "pytorch" logging.getLogger('penman').setLevel(logging.WARNING) logging.getLogger('PIL').setLevel(logging.WARNING) logging.getLogger('PIL.PngImagePlugin').setLevel(logging.WARNING) logging.getLogger('matplotlib').setLevel(logging.WARNING) logging.getLogger('urllib3').setLevel(logging.WARNING) logging.getLogger('h5py').setLevel(logging.WARNING) init_process_group_and_set_device(world_size, process_id=rank, device_id=rank) if comm.is_main_process(): mode = configs['trainer_mode'] out_dir = configs['out_dir'] if mode == 'eval': num_of_eval_times = len([eval_txt for eval_txt in os.listdir(out_dir) if eval_txt.endswith('eval.txt')]) set_logging_file(out_dir, f"eval.txt", mode='a') path = os.path.join(out_dir, f"config_eval.yaml") else: num_of_train_times = len([train_txt for train_txt in os.listdir(out_dir) if train_txt.endswith('train.txt')]) if 'resume' in mode: set_logging_file(out_dir, f"train.txt", mode='a') else: set_logging_file(out_dir, f"train.txt", mode='w') path = os.path.join(out_dir, f"config_train.yaml") logging.debug("Running with full config:\n{}".format(_highlight(yaml.dump(configs, default_flow_style=False), ".yaml"))) with open(path, "w") as f: f.write(yaml.dump(configs, default_flow_style=False)) logging.debug("Full config saved to {}".format(path)) comm.synchronize() trainer = task_to_trainer[configs['task']](configs=configs) 
comm.synchronize() if configs['trainer_mode'] == 'eval': eval_ckpts = configs['eval_ckpts'] for lunch in eval_ckpts: trainer.load_ckpt(lunch, load_model=True, load_schedule=True, load_random=False, load_optimize=False) trainer.evaluate() else: if configs['trainer_mode'] == 'train_resume': ckpt_dirs = os.listdir(configs['out_dir']) ckpt_dirs = sorted([a for a in ckpt_dirs if a.startswith('epc')], key=lambda x:int(x.split('sap[')[-1][:-1])) trainer_ckpt = '/'.join([configs['out_dir'], ckpt_dirs[-1], 'ckpt.pth.tar']) trainer.load_ckpt(trainer_ckpt, load_model=True, load_schedule=True, load_random=True, load_optimize=True) trainer.train() if __name__=="__main__": parser = argparse.ArgumentParser() parser.add_argument('--config_file', type=str, required=True) parser.add_argument('--trainer_mode', type=str, default='train_attmpt') parser.add_argument('--eval_path', type=str, default='') args = parser.parse_args() task, group, config, config2 = args.config_file.split('/')[-4:] assert config == config2[:-3] config_file = '.'.join(['output', task, group, config, config]) configs = importlib.import_module(config_file).trainer_configs configs['task'], configs['group'], configs['config'] = task, group, config configs['out_dir'] = os.path.join('./', 'output', task, group, config) configs['trainer_mode'] = args.trainer_mode if configs['trainer_mode'] == 'eval': eval_ckpts = [] eval_path = args.eval_path assert eval_path != '', f'eval path is none' if os.path.isfile(eval_path): eval_ckpts.append(eval_path) elif os.path.isdir(eval_path): ckpt_dirs = os.listdir(eval_path) ckpt_dirs = [taylor for taylor in ckpt_dirs if os.path.isdir(os.path.join(eval_path, taylor))] # epc[1]_iter[5000]_sap[60009] ckpt_dirs = sorted([billie for billie in ckpt_dirs if billie.startswith('epc')], key=lambda x:int(x.split('sap[')[-1][:-1])) eval_ckpts = [os.path.join(eval_path, cd, f'ckpt.pth.tar') for cd in ckpt_dirs] eval_ckpts = [eval_c for eval_c in eval_ckpts if os.path.exists(eval_c)] else: raise 
ValueError() configs['eval_ckpts'] = eval_ckpts else: pass gpu_ids = list(os.environ['CUDA_VISIBLE_DEVICES'].split(',')) assert len(set(gpu_ids)) == len(gpu_ids) gpu_ids = list(range(len(gpu_ids))) if len(gpu_ids) > 1: torch.multiprocessing.spawn(run, nprocs=len(gpu_ids), args=(configs, len(gpu_ids))) elif len(gpu_ids) == 1: run(rank=0, configs=configs, world_size=len(gpu_ids)) ================================================ FILE: models/VIS/BackboneEncoderDecoder_WithScaleConsistency.py ================================================ import matplotlib.pyplot as plt from typing import Any, Optional, List, Dict, Set, Callable import torch import torch.nn as nn import torch.nn.functional as F from data_schedule import build_schedule from torch import Tensor from einops import repeat, rearrange, reduce from functools import partial from einops.layers.torch import Rearrange from torch import einsum import numpy as np import logging from data_schedule.vis.apis import VIS_TrainAPI_clipped_video, VIS_Aug_CallbackAPI from data_schedule.vis.apis import VIS_EvalAPI_clipped_video_request_ann import torchvision.transforms.functional as Trans_F import copy from models.registry import register_model from models.optimization.optimizer import get_optimizer from models.optimization.scheduler import build_scheduler from models.backbone.utils import VideoMultiscale_Shape from detectron2.modeling import BACKBONE_REGISTRY, META_ARCH_REGISTRY class BackboneEncoderDecoder_WithScaleConsistency(nn.Module): def __init__( self, configs, pixel_mean = [0.485, 0.456, 0.406], pixel_std = [0.229, 0.224, 0.225],): super().__init__() self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) # 3 1 1 self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) self.loss_weight = configs['model']['loss_weight'] video_backbone_configs = configs['model']['video_backbone'] video_backbone_cls = BACKBONE_REGISTRY.get(video_backbone_configs['name']) 
self.video_backbone = video_backbone_cls(video_backbone_configs) self.max_stride = self.video_backbone.max_stride self.fusion_encoder = META_ARCH_REGISTRY.get(configs['model']['fusion']['name'])(configs['model']['fusion'], multiscale_shapes=self.video_backbone.multiscale_shapes) same_dim_multiscale_shapes = VideoMultiscale_Shape.set_multiscale_same_dim(shape_by_dim=self.video_backbone.multiscale_shapes, same_dim=configs['model']['fusion']['d_model']) self.decoder = META_ARCH_REGISTRY.get(configs['model']['decoder']['name'])(configs['model']['decoder'], multiscale_shapes=same_dim_multiscale_shapes) if configs['model']['fusion']['name'] == 'Video_Deform2D_DividedTemporal_MultiscaleEncoder_v2': self.fusion_encoder.hack_ref(query_norm=self.decoder.temporal_query_norm, mask_mlp=self.decoder.query_mask) self.test_clip_size = configs['model']['test_clip_size'] @property def device(self): return self.pixel_mean.device def model_preds(self, videos, video_aux_dict,): if (not self.training) and (self.test_clip_size is not None): nf = videos.shape[2] clip_outputs = [] # list[dict] for start_idx in range(0, nf, self.test_clip_size): multiscales = self.video_backbone(x=videos[:, :, start_idx:(start_idx + self.test_clip_size)]) # b c t h w multiscales = self.fusion_encoder(multiscales, video_aux_dict=video_aux_dict) clip_outputs.append(self.decoder(multiscales, video_aux_dict=video_aux_dict)[-1]) # b t nq h w return [{ 'pred_masks': torch.cat([haosen['pred_masks'] for haosen in clip_outputs], dim=1), # b t n h w 'pred_class': torch.cat([haosen['pred_class'] for haosen in clip_outputs], dim=1), }] # b 3 t h w -> b 3 t h w multiscales = self.video_backbone(x=videos) # b c t h w multiscales = self.fusion_encoder(multiscales, video_aux_dict=video_aux_dict) return self.decoder(multiscales, video_aux_dict=video_aux_dict) def forward(self, batch_dict): assert self.training VIS_TrainAPI_clipped_video videos = batch_dict['video_dict']['videos'] targets = batch_dict['targets'] batch_size, 
nf = videos.shape[:2] videos = (videos - self.pixel_mean) / self.pixel_std size1 = np.random.choice([256, 288, 320, 352, 384, 416, 448]) vid_1 = F.interpolate(videos.flatten(0, 1), size=size1, mode='bilinear') vid_1 = rearrange(vid_1, '(b T) c h w -> b c T h w',b=batch_size, T=nf) pred1 = self.model_preds(vid_1, video_aux_dict=batch_dict['video_dict']) # {pred_masks: b 1 t h w} pred1_loss = self.decoder.compute_loss(pred1, targets=targets, frame_targets=batch_dict['frame_targets'], video_aux_dict=batch_dict['video_dict']) loss_value_dict = {key: pred1_loss[key] for key in list(self.loss_weight.keys())} return loss_value_dict, self.loss_weight @torch.no_grad() def sample(self, batch_dict): assert not self.training VIS_EvalAPI_clipped_video_request_ann videos = batch_dict['video_dict']['videos'] # b t 3 h w, 0-1 orig_t, _, orig_h, orig_w = batch_dict['video_dict']['orig_sizes'][0] videos = (videos - self.pixel_mean) / self.pixel_std assert videos.shape[0] == 1 batch_size, T, _, H, W = videos.shape videos = videos.permute(0, 2, 1,3,4) # b c t h w decoder_output = self.model_preds(videos, video_aux_dict=batch_dict['video_dict']) # {pred_masks: b 1 t h w} if isinstance(decoder_output, list): decoder_output = decoder_output[-1] pred_masks = decoder_output['pred_masks'][0] # T n h w pred_masks = F.interpolate(pred_masks, size=(H, W), mode='bilinear') > 0 # T n h w pred_masks = pred_masks[:orig_t, :, :orig_h, :orig_w] # T n h w # pred_classes = decoder_output['pred_class'][0][:orig_t, :,:] # T n c, probability pred_classes = pred_classes.cpu().unbind(0) # list[n c], T pred_masks = pred_masks.cpu().unbind(0) # list[n h w], T VIS_Aug_CallbackAPI orig_video = videos[0][:, :orig_t, :orig_h, :orig_w].permute(1,0,2,3) # T 3 h w orig_video = Trans_F.normalize(orig_video, [0, 0, 0], 1 / self.pixel_std) orig_video = Trans_F.normalize(orig_video, -self.pixel_mean, [1, 1, 1]).cpu() return { 'video': [orig_video], # [t 3 h w], 1 'pred_masks': [pred_masks], # [list[n h w], t, bool], 1 
'pred_class': [pred_classes], # [list[n c], t, probability], 1 } @staticmethod def get_optim_params_group(model, configs): weight_decay_norm = configs['optim']['weight_decay_norm'] weight_decay_embed = configs['optim']['weight_decay_embed'] defaults = {} defaults['lr'] = configs['optim']['base_lr'] defaults['weight_decay'] = configs['optim']['weight_decay'] norm_module_types = ( torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm, # NaiveSyncBatchNorm inherits from BatchNorm2d torch.nn.GroupNorm, torch.nn.InstanceNorm1d, torch.nn.InstanceNorm2d, torch.nn.InstanceNorm3d, torch.nn.LayerNorm, torch.nn.LocalResponseNorm, ) params: List[Dict[str, Any]] = [] memo: Set[torch.nn.parameter.Parameter] = set() log_lr_group_idx = {'backbone':None, 'base':None} for module_name, module in model.named_modules(): for module_param_name, value in module.named_parameters(recurse=False): if not value.requires_grad: continue # Avoid duplicating parameters if value in memo: continue memo.add(value) hyperparams = copy.copy(defaults) if "video_backbone" in module_name: hyperparams["lr"] = hyperparams["lr"] * configs['optim']['backbone_lr_multiplier'] if log_lr_group_idx['backbone'] is None: log_lr_group_idx['backbone'] = len(params) else: if log_lr_group_idx['base'] is None: log_lr_group_idx['base'] = len(params) # pos_embed, norm, embedding的weight decay特殊对待 if ( "relative_position_bias_table" in module_param_name or "absolute_pos_embed" in module_param_name ): logging.debug(f'setting weight decay of {module_name}.{module_param_name} to zero') hyperparams["weight_decay"] = 0.0 if isinstance(module, norm_module_types): hyperparams["weight_decay"] = weight_decay_norm if isinstance(module, torch.nn.Embedding): hyperparams["weight_decay"] = weight_decay_embed params.append({"params": [value], **hyperparams}) return params, log_lr_group_idx @register_model def backbone_encoder_decoder_withScaleConsistency(configs, device): from .aux_mapper import 
AUXMapper_v1 model = BackboneEncoderDecoder_WithScaleConsistency(configs) model.to(device) params_group, log_lr_group_idx = BackboneEncoderDecoder_WithScaleConsistency.get_optim_params_group(model=model, configs=configs) to_train_num_parameters = len([n for n, p in model.named_parameters() if p.requires_grad]) assert len(params_group) == to_train_num_parameters, f'' optimizer = get_optimizer(params_group, configs) scheduler = build_scheduler(configs=configs, optimizer=optimizer) model_input_mapper = AUXMapper_v1(configs['model']['input_aux']) train_samplers, train_loaders, eval_function = build_schedule(configs, model_input_mapper.mapper, partial(model_input_mapper.collate, max_stride=model.max_stride)) return model, optimizer, scheduler, train_samplers, train_loaders, log_lr_group_idx, eval_function ================================================ FILE: models/VIS/__init__.py ================================================ from . import BackboneEncoderDecoder_WithScaleConsistency from .. import modality_input_mappers from .. import backbone from .. import decoder from .. 
import encoder ================================================ FILE: models/VIS/aux_mapper.py ================================================ import torch from torch.nn import functional as F from models.registry import register_model from data_schedule.utils.box_ops import box_xyxy_to_cxcywh from models.registry import MODELITY_INPUT_MAPPER_REGISTRY from data_schedule.vis.apis import VIS_TrainAPI_clipped_video from data_schedule.vis.apis import VIS_EvalAPI_clipped_video_request_ann from utils.misc import nested_tensor_from_videos_list_with_stride class AUXMapper_v1: def __init__(self, aux_configs): video_auxes = aux_configs['video_auxes'] video_auxes_names = [config['name'] for config in video_auxes] assert len(list(set(video_auxes_names))) == len(video_auxes_names), '每个aux的名字必须不一样' self.video_auxes_names = video_auxes_names self.video_auxes = [MODELITY_INPUT_MAPPER_REGISTRY.get(config['name'])(config) for config in video_auxes] self.targets_auxes = None def mapper(self, data_dict, mode,): if mode == 'train': VIS_TrainAPI_clipped_video video = data_dict['video_dict']['video'] for aux, aux_name in zip(self.video_auxes, self.video_auxes_names): data_dict['video_dict'][aux_name] = aux.mapper(video) elif mode == 'evaluate': VIS_EvalAPI_clipped_video_request_ann video = data_dict['video_dict']['video'] for aux, aux_name in zip(self.video_auxes, self.video_auxes_names): data_dict['video_dict'][aux_name] = aux.mapper(video) else: raise ValueError() return data_dict def collate(self, batch_dict, mode, max_stride): if mode == 'train': VIS_TrainAPI_clipped_video video_dict = self.collate_video_dict(batch_dict, max_stride=max_stride) targets = [sample['targets'] for sample in batch_dict] frame_has_ann = [clip_tgt['has_ann'] for clip_tgt in targets] # list[t], b frame_targets = [sample['frame_targets'] for sample in batch_dict] _, pad_T, _, pad_H, pad_W = video_dict['videos'].shape targets = self.collate_targets(targets=targets, pad_H=pad_H, pad_W=pad_W, pad_T=pad_T) 
frame_targets = self.collate_frame_targets(frame_targets=frame_targets, frame_has_ann=frame_has_ann, pad_H=pad_H, pad_W=pad_W, pad_T=pad_T) ret = { 'video_dict': video_dict, 'targets': targets, 'frame_targets': frame_targets, 'meta_idxs': [sample['meta_idx'] for sample in batch_dict], 'visualize': [sample['visualize'] for sample in batch_dict], } elif mode == 'evaluate': VIS_EvalAPI_clipped_video_request_ann assert len(batch_dict) == 1 video_dict = self.collate_video_dict(batch_dict, max_stride=max_stride) # 不pad metas = [sample['meta'] for sample in batch_dict] collated_metas = {} for key in metas[0].keys(): collated_metas[key] = [mt[key] for mt in metas] ret = { 'video_dict': video_dict, 'metas': collated_metas, 'meta_idxs': [sample['meta_idx'] for sample in batch_dict], 'visualize': [sample['visualize'] for sample in batch_dict], } debug_data = False if debug_data: self.visualize_input_target_for_debug_data(ret) # ./test.png return ret def collate_video_dict(self, batch_dict, max_stride): videos = [sample['video_dict']['video'] for sample in batch_dict] # list[ti 3 hi wi] -> b T 3 H W orig_sizes = [list(vid.shape) for vid in videos] # t 3 h w if type(max_stride) == int: # temporal max stride 为1, spatial max stride pad_stride = [1, max_stride] if (type(max_stride) == list) and (len(max_stride) == 2): pad_stride = max_stride videos = nested_tensor_from_videos_list_with_stride(videos, max_stride=pad_stride).tensors # b t c h w video_dicts = {'videos': videos, 'orig_sizes': orig_sizes} for aux_name, aux in zip(self.video_auxes_names, self.video_auxes): auxes = [sample['video_dict'][aux_name] for sample in batch_dict] # list[dict] / list[tensor] collated_auxes = aux.collate(auxes, batch_videos=videos) # list[dict] / tensor if isinstance(auxes[0], dict): keys = collated_auxes.keys() for key in keys: assert key not in video_dicts video_dicts[key] = collated_auxes[key] else: video_dicts[aux_name] = collated_auxes return video_dicts def collate_frame_targets(self, 
frame_targets, frame_has_ann, pad_H, pad_W, pad_T): # VIS_TrainAPI_clipped_video ret = {} has_ann = torch.stack([F.pad(ha.float(), pad=(0, pad_T - len(ha)), value=0.).bool() for ha in frame_has_ann], dim=0).flatten() # bT ret['has_ann'] = has_ann masks = [ftarget['masks'] for sample in frame_targets for ftarget in sample] # list[ni h w], bt' masks = [F.pad(m.float(), pad=(0, pad_W-m.shape[-1], 0, pad_H-m.shape[-2])).bool() for m in masks] # list[ni H W], bt' ret['masks'] = masks # list[ni h w], bt' classes = [ftarget['classes'] for sample in frame_targets for ftarget in sample] # list[ni], bt' ret['classes'] = classes if 'boxes' in frame_targets[0][0]: boxes = [ftarget['boxes'] for sample in frame_targets for ftarget in sample] # list[ni 4], x1y1x2y2, bt' boxes = [box_xyxy_to_cxcywh(bx) for bx in boxes] boxes = [bx / torch.tensor([pad_W, pad_H, pad_W, pad_H], dtype=bx.dtype) for bx in boxes] # 0-1 ret['boxes'] = boxes # list[ni 4], bt' return ret def collate_targets(self, targets, pad_H, pad_W, pad_T): VIS_TrainAPI_clipped_video has_ann = [sample['has_ann'] for sample in targets] # list[t], bool has_ann = torch.stack([F.pad(ha.float(), pad=(0, pad_T - len(ha)), value=0.).bool() for ha in has_ann], dim=0) # b T masks = [sample['masks'] for sample in targets] masks = [F.pad(m.float(), pad=(0, pad_W-m.shape[-1], 0, pad_H-m.shape[-2]), value=0.).bool() \ for m in masks] # list[ni T' H W] classes = [sample['classes'] for sample in targets] ret = { 'masks': masks, # list[ni T' h w] 'has_ann': has_ann, # b T 'classes': classes, # list[ni], b } if 'boxes' in targets[0]: boxes = [sample['boxes'] for sample in targets] # list[ni T' 4], x1y1x2y2 boxes = [box_xyxy_to_cxcywh(bx) for bx in boxes] boxes = [bx / torch.tensor([pad_W, pad_H, pad_W, pad_H], dtype=torch.float) for bx in boxes] # 0-1 ret.update({'boxes': boxes,}) return ret def visualize_input_target_for_debug_data(self, ret): videos = ret['video_dict']['videos'] # b T 3 H W pass 
================================================ FILE: models/__init__.py ================================================ import os from .registry import model_entrypoint if os.getenv('CURRENT_TASK') == 'VIS': from . import VIS else: raise ValueError() ================================================ FILE: models/backbone/__init__.py ================================================ from . import res2net, pvtv2 ================================================ FILE: models/backbone/pvtv2.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from timm.models.layers import DropPath from timm.models.registry import register_model class DWConv(nn.Module): def __init__(self, dim): super(DWConv, self).__init__() self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, groups=dim) def forward(self, x, H, W): B,N,C = x.shape x = x.transpose(1, 2).view(B, C, H, W) x = self.dwconv(x) x = x.flatten(2).transpose(1, 2) return x class Mlp(nn.Module): def __init__(self, in_features, hidden_features): super().__init__() self.fc1 = nn.Linear(in_features, hidden_features) self.dwconv = DWConv(hidden_features) self.fc2 = nn.Linear(hidden_features, in_features) def forward(self, x, H, W): x = self.fc1(x) x = F.gelu(self.dwconv(x, H, W)) x = self.fc2(x) return x class Attention(nn.Module): def __init__(self, dim, num_heads, sr_ratio): super().__init__() assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
self.num_heads = num_heads self.scale = (dim//num_heads)**(-0.5) self.q = nn.Linear(dim, dim) self.kv = nn.Linear(dim, dim*2) self.proj = nn.Linear(dim, dim) self.sr_ratio = sr_ratio if sr_ratio > 1: self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) self.norm = nn.LayerNorm(dim) def forward(self, x, H, W): B, N, C = x.shape q = self.q(x).reshape(B, N, self.num_heads, C//self.num_heads).permute(0, 2, 1, 3) if self.sr_ratio > 1: x_ = x.permute(0, 2, 1).reshape(B, C, H, W) x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) x_ = self.norm(x_) kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) else: kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) k, v = kv[0], kv[1] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio, drop_path, sr_ratio): super().__init__() self.norm1 = nn.LayerNorm(dim, eps=1e-6) self.attn = Attention(dim, num_heads=num_heads, sr_ratio=sr_ratio) self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.norm2 = nn.LayerNorm(dim, eps=1e-6) self.mlp = Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio)) def forward(self, x, H, W): x = x + self.drop_path(self.attn(self.norm1(x), H, W)) x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) return x class OverlapPatchEmbed(nn.Module): def __init__(self, patch_size, stride, in_chans, embed_dim): super().__init__() self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size//2, patch_size//2)) self.norm = nn.LayerNorm(embed_dim) def forward(self, x): x = self.proj(x) B,C,H,W = x.shape x = x.flatten(2).transpose(1, 2) x = self.norm(x) return x, H, W class PVT(nn.Module): def __init__(self, embed_dims, mlp_ratios, depths, snapshot, sr_ratios=[8, 4, 2, 1]): super().__init__() self.depths = depths self.snapshot = snapshot # patch_embed self.patch_embed1 = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=3, embed_dim=embed_dims[0]) self.patch_embed2 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]) self.patch_embed3 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]) self.patch_embed4 = OverlapPatchEmbed(patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]) # transformer encoder dpr = [x.item() for x in torch.linspace(0, 0.1, sum(depths))] # stochastic depth decay rule cur = 0 self.block1 = nn.ModuleList([Block(dim=embed_dims[0], num_heads=1, mlp_ratio=mlp_ratios[0], drop_path=dpr[cur + i], sr_ratio=sr_ratios[0]) for i in range(depths[0])]) self.norm1 = nn.LayerNorm(embed_dims[0], eps=1e-6) cur += depths[0] self.block2 = nn.ModuleList([Block(dim=embed_dims[1], num_heads=2, mlp_ratio=mlp_ratios[1], drop_path=dpr[cur + i], sr_ratio=sr_ratios[1]) for i in range(depths[1])]) self.norm2 = nn.LayerNorm(embed_dims[1], eps=1e-6) cur += depths[1] self.block3 = nn.ModuleList([Block(dim=embed_dims[2], num_heads=5, mlp_ratio=mlp_ratios[2], drop_path=dpr[cur + i], 
sr_ratio=sr_ratios[2]) for i in range(depths[2])]) self.norm3 = nn.LayerNorm(embed_dims[2], eps=1e-6) cur += depths[2] self.block4 = nn.ModuleList([Block(dim=embed_dims[3], num_heads=8, mlp_ratio=mlp_ratios[3], drop_path=dpr[cur + i], sr_ratio=sr_ratios[3]) for i in range(depths[3])]) self.norm4 = nn.LayerNorm(embed_dims[3], eps=1e-6) state_dict:dict = torch.load(self.snapshot, map_location='cpu') state_dict.pop("head.weight") state_dict.pop("head.bias") self.load_state_dict(state_dict, strict=True) del state_dict @torch.jit.ignore def no_weight_decay(self): return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better def forward(self, x): B = x.shape[0] # stage 1 out1, H, W = self.patch_embed1(x) for i, blk in enumerate(self.block1): out1 = blk(out1, H, W) out1 = self.norm1(out1).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() # stage 2 out2, H, W = self.patch_embed2(out1) for i, blk in enumerate(self.block2): out2 = blk(out2, H, W) out2 = self.norm2(out2).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() # stage 3 out3, H, W = self.patch_embed3(out2) for i, blk in enumerate(self.block3): out3 = blk(out3, H, W) out3 = self.norm3(out3).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() # stage 4 out4, H, W = self.patch_embed4(out3) for i, blk in enumerate(self.block4): out4 = blk(out4, H, W) out4 = self.norm4(out4).reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() return out1, out2, out3, out4 from detectron2.modeling import BACKBONE_REGISTRY from einops import rearrange, reduce, repeat from .utils import VideoMultiscale_Shape, ImageMultiscale_Shape import os import time @BACKBONE_REGISTRY.register() class PVT_V2(nn.Module): def __init__(self, configs) -> None: super().__init__() pt_path = os.getenv('PT_PATH') pvt_v2 = PVT(embed_dims=[64, 128, 320, 512], mlp_ratios=[8, 8, 4, 4], depths=[3, 4, 6, 3], snapshot=os.path.join(pt_path, 'pvt_v2/pvt_v2_b2.pth')) self.pvt_v2 = pvt_v2 freeze = configs['freeze'] 
if freeze: for p in self.parameters(): p.requires_grad_(False) self.multiscale_shapes = {} for name, spatial_stride, dim in zip(['res2', 'res3', 'res4', 'res5'], [4, 8, 16, 32], [64, 128, 320, 512]): self.multiscale_shapes[name] = ImageMultiscale_Shape(spatial_stride=spatial_stride, dim=dim) self.max_stride = 32 def forward(self, x): if not self.training: batch_feats = [] for haosen in x: feats = self.pvt_v2(haosen.unsqueeze(0)) batch_feats.append(feats) batch_feats = list(zip(*batch_feats)) # 4 batch_feats = [torch.cat(haosen, dim=0) for haosen in batch_feats] # list[bt c h w] ret = {} names = ['res2', 'res3', 'res4', 'res5'] for name, feat in zip(names, batch_feats): ret[name] = feat return ret else: layer_outputs = self.pvt_v2(x) ret = {} names = ['res2', 'res3', 'res4', 'res5'] for name, feat in zip(names, layer_outputs): ret[name] = feat return ret def num_parameters(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) @BACKBONE_REGISTRY.register() class Video2D_PVT_V2(nn.Module): def __init__(self, configs) -> None: super().__init__() self.image_homo = PVT_V2(configs=configs) self.multiscale_shapes = {} for name, temporal_stride, spatial_stride, dim in zip(['res2', 'res3', 'res4', 'res5'], [1, 1, 1, 1], [4, 8, 16, 32], [64, 128, 320, 512]): self.multiscale_shapes[name] = VideoMultiscale_Shape(temporal_stride=temporal_stride, spatial_stride=spatial_stride, dim=dim) self.max_stride = [1, 32] def forward(self, x): batch_size, _, T = x.shape[:3] x = rearrange(x, 'b c t h w -> (b t) c h w').contiguous() layer_outputs = self.image_homo(x) layer_outputs = {key: rearrange(value.contiguous(), '(b t) c h w -> b c t h w',b=batch_size, t=T).contiguous() \ for key, value in layer_outputs.items()} return layer_outputs def num_parameters(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) ================================================ FILE: models/backbone/res2net.py ================================================ import math 
import torch import torch.nn as nn import os import torch.nn.functional as F from detectron2.modeling import BACKBONE_REGISTRY from einops import rearrange, reduce, repeat from .utils import VideoMultiscale_Shape class Bottle2neck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'): super(Bottle2neck, self).__init__() width = int(math.floor(planes*(baseWidth/64.0))) self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(width*scale) self.nums = 1 if scale == 1 else scale - 1 if stype == 'stage': self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) convs, bns = [], [] for i in range(self.nums): convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False)) bns.append(nn.BatchNorm2d(width)) self.convs = nn.ModuleList(convs) self.bns = nn.ModuleList(bns) self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) self.downsample = downsample self.stype = stype self.scale = scale self.width = width def forward(self, x): out = F.relu(self.bn1(self.conv1(x)), inplace=True) spx = torch.split(out, self.width, 1) for i in range(self.nums): sp = spx[i] if i == 0 or self.stype == 'stage' else sp + spx[i] sp = self.convs[i](sp) sp = F.relu(self.bns[i](sp), inplace=True) out = sp if i == 0 else torch.cat((out, sp), 1) if self.scale != 1 and self.stype == 'normal': out = torch.cat((out, spx[self.nums]), 1) elif self.scale != 1 and self.stype == 'stage': out = torch.cat((out, self.pool(spx[self.nums])), 1) out = self.bn3(self.conv3(out)) if self.downsample is not None: x = self.downsample(x) return F.relu(out+x, inplace=True) class Res2Net(nn.Module): def __init__(self, layers, snapshot, baseWidth=26, scale=4): super(Res2Net, self).__init__() self.inplanes = 64 self.snapshot = snapshot self.baseWidth = baseWidth self.scale = scale self.conv1 
= nn.Sequential( nn.Conv2d(3, 32, 3, 2, 1, bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.Conv2d(32, 32, 3, 1, 1, bias=False), nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.Conv2d(32, 64, 3, 1, 1, bias=False) ) self.bn1 = nn.BatchNorm2d(64) self.layer1 = self._make_layer(Bottle2neck, 64, layers[0]) self.layer2 = self._make_layer(Bottle2neck, 128, layers[1], stride=2) self.layer3 = self._make_layer(Bottle2neck, 256, layers[2], stride=2) self.layer4 = self._make_layer(Bottle2neck, 512, layers[3], stride=2) state_dict:dict = torch.load(self.snapshot, map_location='cpu') self.load_state_dict(state_dict, strict=False) del state_dict def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False), nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [block(self.inplanes, planes, stride, downsample=downsample, stype='stage', baseWidth=self.baseWidth, scale=self.scale)] self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)) return nn.Sequential(*layers) def forward(self, x): out1 = F.relu(self.bn1(self.conv1(x)), inplace=True) out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1) out2 = self.layer1(out1) out3 = self.layer2(out2) out4 = self.layer3(out3) out5 = self.layer4(out4) return out2, out3, out4, out5 def initialize(self): self.load_state_dict(torch.load(self.snapshot), strict=False) @BACKBONE_REGISTRY.register() class Res2Net_50_EachFrame(nn.Module): def __init__(self, configs) -> None: super().__init__() pt_path = os.getenv('PT_PATH') res2net = Res2Net([3, 4, 6, 3], os.path.join(pt_path, 'res2net/res2net50_v1b_26w_4s-3cf99910.pth')) self.res2net = res2net freeze = 
configs['freeze'] if freeze: for p in self.parameters(): p.requires_grad_(False) self.multiscale_shapes = {} for name, temporal_stride, spatial_stride, dim in zip(['res2', 'res3', 'res4', 'res5'], [1, 1, 1, 1], [4, 8, 16, 32], [256, 512, 1024, 2048]): self.multiscale_shapes[name] = VideoMultiscale_Shape(temporal_stride=temporal_stride, spatial_stride=spatial_stride, dim=dim) self.max_stride = [1, 32] def forward(self, x): batch_size, _, T = x.shape[:3] x = rearrange(x, 'b c t h w -> (b t) c h w') layer_outputs = self.res2net(x) ret = {} names = ['res2', 'res3', 'res4', 'res5'] for name, feat in zip(names, layer_outputs): ret[name] = rearrange(feat.contiguous(), '(b t) c h w -> b c t h w',b=batch_size, t=T) return ret def num_parameters(self): return sum(p.numel() for p in self.parameters() if p.requires_grad) ================================================ FILE: models/backbone/utils.py ================================================ class VideoMultiscale_Shape: def __init__(self, temporal_stride, spatial_stride, dim) -> None: self.temporal_stride = temporal_stride self.spatial_stride = spatial_stride self.dim = dim @staticmethod def set_multiscale_same_dim(shape_by_dim, same_dim): return { key: VideoMultiscale_Shape(temporal_stride=value.temporal_stride, spatial_stride=value.spatial_stride, dim=same_dim) for key,value in shape_by_dim.items() } class ImageMultiscale_Shape: def __init__(self, spatial_stride, dim) -> None: self.spatial_stride = spatial_stride self.dim = dim ================================================ FILE: models/decoder/__init__.py ================================================ from . 
import mask2former_video ================================================ FILE: models/decoder/mask2former_video.py ================================================ # multi-scale features, b c h w -> module -> obj queries, predictions, b nq c import torch.nn as nn from models.layers.decoder_layers import CrossAttentionLayer, SelfAttentionLayer, FFNLayer from models.layers.anyc_trans import MLP import torch.nn.functional as F import torch import copy from models.layers.utils import zero_module, _get_clones from models.layers.position_encoding import build_position_encoding from einops import rearrange, reduce, repeat from scipy.optimize import linear_sum_assignment from models.layers.matching import batch_dice_loss, batch_sigmoid_ce_loss, batch_sigmoid_focal_loss, dice_loss, ce_mask_loss from detectron2.modeling import META_ARCH_REGISTRY import detectron2.utils.comm as comm import data_schedule.utils.box_ops as box_ops from models.layers.utils import zero_module from utils.misc import is_dist_avail_and_initialized from collections import defaultdict from detectron2.projects.point_rend.point_features import point_sample from torch.cuda.amp import autocast from detectron2.projects.point_rend.point_features import ( get_uncertain_point_coords_with_randomness, point_sample, ) def calculate_uncertainty(logits): """ We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the foreground class in `classes`. Args: logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is the number of foreground classes. The values are logits. Returns: scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most uncertain locations having the highest uncertainty score. 
""" assert logits.shape[1] == 1 gt_class_logits = logits.clone() return -(torch.abs(gt_class_logits)) class Video_SetMatchingLoss(nn.Module): def __init__(self, loss_config, num_classes,) -> None: super().__init__() self.num_classes = num_classes # n=1 / n=0 / n>1 self.matching_metrics = loss_config['matching_metrics'] # mask: mask/dice; point_sample_mask: .. self.losses = loss_config['losses'] self.aux_layer_weights = loss_config['aux_layer_weights'] # int/list empty_weight = torch.ones(self.num_classes + 1) empty_weight[-1] = loss_config['background_cls_eos'] self.register_buffer('empty_weight', empty_weight) # self.register_buffer('small_obj_weight', torch.tensor(loss_config['small_obj_weight']).float()) self._warmup_iters = 2000 self.register_buffer("_iter", torch.zeros([1])) @property def device(self,): return self.empty_weight.device def compute_loss(self, model_outs, targets, video_aux_dict, **kwargs): # list[n t' h w], batch if 'masks' in targets: num_objs = sum([haosen.flatten(1).any(-1).int().sum().item() for haosen in targets['masks']]) # list[n t' 4], batch elif 'boxes' in targets: # n t' 2 -> n t -> n num_objs = sum([(haosen[:, :, 2:] > 0).all(-1).any(-1).int().sum().item() for haosen in targets['boxes']]) else: raise ValueError('targets里没有boxes/masks, 需要确定数量') num_objs = torch.as_tensor([num_objs], dtype=torch.float, device=self.device) if is_dist_avail_and_initialized(): torch.distributed.all_reduce(num_objs) num_objs = torch.clamp(num_objs / comm.get_world_size(), min=1).item() if isinstance(self.aux_layer_weights, list): assert len(self.aux_layer_weights) == (len(model_outs) - 1) else: self.aux_layer_weights = [self.aux_layer_weights] * (len(model_outs) - 1) layer_weights = self.aux_layer_weights + [1.] 
loss_values = { 'mask_dice':0., 'mask_ce':0., 'box_l1': 0., 'box_giou': 0., 'class_ce':0., 'mask_dice_smobj':0., 'mask_ce_smobj':0., 'boxMask_dice':0., 'boxMask_ce':0., } if ('mask_ce_dice' in self.matching_metrics) or ('mask_ce_dice' in self.losses): # mask interpolate tgt_mask_shape = targets['masks'][0].shape[-2:] # list[n t H W], b for layer_idx in range(len(model_outs)): # b t nq h w batch_size, nf = model_outs[layer_idx]['pred_masks'].shape[:2] model_outs[layer_idx]['pred_masks'] = rearrange(F.interpolate(model_outs[layer_idx]['pred_masks'].flatten(0, 1), size=tgt_mask_shape, mode='bilinear', align_corners=False), '(b t) n h w -> b t n h w',b=batch_size, t=nf) for taylor, layer_out in zip(layer_weights, model_outs): if taylor != 0: matching_indices = self.matching(layer_out, targets) for loss in self.losses: loss_extra_param = self.losses[loss] if loss == 'mask_dice_ce' : loss_dict = self.loss_mask_dice_ce(layer_out, targets, matching_indices, num_objs, loss_extra_param=loss_extra_param) elif loss == 'class_ce': loss_dict = self.loss_class_ce(layer_out, targets, matching_indices, num_objs, loss_extra_param=loss_extra_param) elif loss == 'point_mask_dice_ce': loss_dict = self.loss_point_mask_dice_ce(layer_out, targets, matching_indices, num_objs, loss_extra_param=loss_extra_param) else: raise ValueError() for key, value in loss_dict.items(): loss_values[key] = loss_values[key] + value return loss_values @torch.no_grad() def matching(self, layer_out, targets): batch_size = len(targets['masks']) if 'masks' in targets else len(targets['boxes']) indices = [] has_ann = targets['has_ann'] for i in range(batch_size): C = 0. 
if 'class_prob' in self.matching_metrics: out_cls = layer_out['pred_class'][i].softmax(-1) # nq c tgt_cls = targets['classes'][i] # n cost_class = - out_cls[:, tgt_cls] # nq n C += self.matching_metrics['class_prob']['prob'] * cost_class if 'mask_dice_ce' in self.matching_metrics: out_mask = layer_out['pred_masks'][i][has_ann[i]].permute(1, 0, 2, 3).contiguous() # nq t' h w tgt_mask = targets['masks'][i].to(out_mask) # ni t' H W cost_mask = batch_sigmoid_ce_loss(out_mask.flatten(1), tgt_mask.flatten(1)) cost_dice = batch_dice_loss(out_mask.flatten(1), tgt_mask.flatten(1)) C += self.matching_metrics['mask_dice_ce']['ce'] * cost_mask + \ self.matching_metrics['mask_dice_ce']['dice'] * cost_dice if 'point_mask_dice_ce' in self.matching_metrics: out_mask = layer_out['pred_masks'][i][has_ann[i]].permute(1, 0, 2, 3).contiguous() # nq t' h w tgt_mask = targets['masks'][i].to(out_mask)# ni t' H W nf = out_mask.shape[1] out_mask = out_mask.flatten(0, 1)[:, None] tgt_mask = tgt_mask.flatten(0, 1)[:, None] # all masks share the same set of points for efficient matching! 
point_coords = torch.rand(1, self.matching_metrics['point_mask_dice_ce']['num_points'], 2, device=self.device) # get gt labels tgt_mask = point_sample( tgt_mask, point_coords.repeat(tgt_mask.shape[0], 1, 1), align_corners=False, ).squeeze(1) # nqt s tgt_mask = rearrange(tgt_mask, '(nq t) s -> nq t s',t=nf) out_mask = point_sample( out_mask, point_coords.repeat(out_mask.shape[0], 1, 1), align_corners=False, ).squeeze(1) # nit s out_mask = rearrange(out_mask, '(nq t) s -> nq t s',t=nf) with autocast(enabled=False): out_mask = out_mask.float().flatten(1) # nq num_points tgt_mask = tgt_mask.float().flatten(1) cost_mask = batch_sigmoid_ce_loss(out_mask, tgt_mask) cost_dice = batch_dice_loss(out_mask, tgt_mask) C += self.matching_metrics['point_mask_dice_ce']['ce'] * cost_mask + \ self.matching_metrics['point_mask_dice_ce']['dice'] * cost_dice C = C.cpu() indices.append(linear_sum_assignment(C)) return [ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices ] def loss_mask_dice_ce(self, outputs, targets, indices, num_objs, loss_extra_param): has_ann = targets['has_ann'] # b t src_masks = outputs['pred_masks'].permute(0, 2, 1, 3, 4).contiguous() # b nq t h w tgt_masks = targets['masks'] # list[n t' h w] # list[nq t' h w] -> n_sigma t' h w src_masks = torch.cat([t[J][:, haosen] for t, (J, _), haosen in zip(src_masks, indices, has_ann)],dim=0) tgt_masks = torch.cat([t[J] for t, (_, J) in zip(tgt_masks, indices)], dim=0) tgt_masks = tgt_masks.to(src_masks) losses = { "mask_ce": ce_mask_loss(src_masks.flatten(0, 1).flatten(1), tgt_masks.flatten(0, 1).flatten(1), num_boxes=num_objs), "mask_dice": dice_loss(src_masks.flatten(0, 1).flatten(1), tgt_masks.flatten(0, 1).flatten(1), num_boxes=num_objs), } return losses def loss_point_mask_dice_ce(self, outputs, targets, indices, num_objs, loss_extra_param): has_ann = targets['has_ann'] # b t src_masks = outputs['pred_masks'].permute(0, 2, 1, 3, 4).contiguous() # b nq t h w tgt_masks = 
targets['masks']  # list[n t' h w]
        # list[nq t' h w] -> n_sigma t' h w
        src_masks = torch.cat([t[J][:, haosen] for t, (J, _), haosen in zip(src_masks, indices, has_ann)],dim=0)
        tgt_masks = torch.cat([t[J] for t, (_, J) in zip(tgt_masks, indices)], dim=0)
        tgt_masks = tgt_masks.to(src_masks)
        nf = src_masks.shape[1]
        # No need to upsample predictions as we are using normalized coordinates :)
        # NT x 1 x H x W
        src_masks = src_masks.flatten(0, 1).unsqueeze(1).contiguous()  # nt' 1 h w
        target_masks = tgt_masks.flatten(0, 1).unsqueeze(1).contiguous()

        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                src_masks,
                lambda logits: calculate_uncertainty(logits),
                loss_extra_param['num_points'],
                loss_extra_param['oversample_ratio'],
                loss_extra_param['importance_sample_ratio'],
            )
            # get gt labels
            point_labels = point_sample(
                target_masks,
                point_coords,
                align_corners=False,
            ).squeeze(1)  # nt' s

        # Logits are sampled outside no_grad so the loss back-propagates into src_masks.
        point_logits = point_sample(
            src_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)  # nt' s

        # point_logits = rearrange(point_logits, '(n t) s -> n (t s)',t=nf)
        # point_labels = rearrange(point_labels, '(n t) s -> n (t s)',t=nf)

        # NOTE(review): the keys look swapped -- "mask_dice" holds ce_mask_loss and
        # "mask_ce" holds dice_loss. This only matters if the two keys carry different
        # loss weights in the config; confirm before renaming.
        losses = {
            "mask_dice": ce_mask_loss(point_logits, point_labels, num_objs),
            "mask_ce": dice_loss(point_logits, point_labels, num_objs),
        }

        del src_masks
        del target_masks
        return losses

    def loss_class_ce(self, outputs, targets, indices, num_objs, loss_extra_param):
        """Cross-entropy classification loss over every query.

        Matched queries (per batch element, via ``indices``) are supervised with the
        class ids stored in ``targets['classes']``; every other query gets the
        "no object" id ``self.num_classes``. ``self.empty_weight`` is passed to
        ``F.cross_entropy`` as the per-class weight. ``num_objs`` and
        ``loss_extra_param`` are not used by this loss.
        """
        src_logits = outputs["pred_class"].float()  # b nq c

        idx = self._get_src_permutation_idx(indices)  # list[n], b -> bn
        target_classes_o = torch.cat([t[J] for t, (_, J) in zip(targets['classes'], indices)])
        # Start from "no object" everywhere, then write the matched classes in.
        target_classes = torch.full(
            src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=self.device
        )
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {"class_ce": loss_ce}
        return losses

    def loss_box_l1_giou(self, outputs, targets, indices, num_objs, loss_extra_param):
        """L1 + generalized-IoU box losses for matched queries on annotated frames.

        Predicted boxes are passed through ``sigmoid`` and re-ordered with
        ``permute(0, 2, 1, 3)``. They are then restricted to the matched queries
        (``indices``) and to the frames selected by ``targets['has_ann']``. Boxes go
        through ``box_cxcywh_to_xyxy`` for the GIoU term, so both predictions and
        targets are treated as (cx, cy, w, h). Both losses are summed and divided
        by ``num_objs``.
        """
        tgt_boxes = targets['boxes']  # list[n tl 4], b
        has_ann = targets['has_ann']  # b t
        src_boxes = outputs['pred_boxes'].sigmoid().permute(0, 2, 1, 3).contiguous()  # b nq t 4
        src_boxes = torch.cat([t[J][:, haosen] for t, (J, _), haosen in zip(src_boxes, indices, has_ann)], dim=0)  # n_sum t' 4
        tgt_boxes = torch.cat([t[J] for t, (_, J) in zip(tgt_boxes, indices)], dim=0)  # n_sum t' 4

        nf = tgt_boxes.shape[1]
        loss_l1 = F.l1_loss(src_boxes, tgt_boxes, reduction='none').flatten(1)  # n_sum t'4

        # Pairwise GIoU over all (box, frame) pairs; the diagonal holds the matched pairs.
        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes.flatten(0, 1)),
            box_ops.box_cxcywh_to_xyxy(tgt_boxes.flatten(0, 1))))  # n_sumt'

        loss_giou = loss_giou.view(-1, nf).contiguous()
        return {
            'box_l1': loss_l1.sum(-1).sum() / num_objs,
            'box_giou': loss_giou.sum(-1).sum() / num_objs
        }

    def _get_src_permutation_idx(self, indices):
        """Return ``(batch_idx, src_idx)`` for advanced-indexing the matched predictions."""
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Return ``(batch_idx, tgt_idx)`` for advanced-indexing the matched targets."""
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx


@META_ARCH_REGISTRY.register()
class Video_MaskedAttn_MultiscaleMaskDecoder_v3(nn.Module):
    """Masked-attention video decoder with video-level queries.

    ``video_nqueries`` learned queries are refined by ``nlayers`` rounds of:
    masked cross-attention (cycling over ``memory_scales``), self-attention, then
    FFN. Class/mask heads run on the initial queries and after every layer, so
    ``forward`` returns ``nlayers + 1`` prediction dicts.
    """

    def __init__(self,
                 configs,
                 multiscale_shapes):
        super().__init__()
        d_model = configs['d_model']
        attn_configs = configs['attn']
        self.video_nqueries = configs['video_nqueries']
        self.nlayers = configs['nlayers']
        self.memory_scales = configs['memory_scales']
        self.mask_scale = configs['mask_scale']
        # Predicted masks are upsampled by this factor in forward_heads.
        self.mask_spatial_stride = multiscale_shapes[self.mask_scale].spatial_stride
        num_classes = configs['num_classes']

        inputs_projs = configs['inputs_projs']
        # An empty nn.Sequential acts as identity when no input projection is configured.
        self.inputs_projs = nn.Sequential()
        if inputs_projs is not None:
            self.inputs_projs = META_ARCH_REGISTRY.get(inputs_projs['name'])(inputs_projs,
                                                                             multiscale_shapes=multiscale_shapes,
                                                                             out_dim=d_model)
        self.level_embeds = nn.Embedding(len(self.memory_scales), d_model)

        # Layers cycle over memory_scales, so each scale gets the same number of layers.
        assert self.nlayers % len(self.memory_scales) == 0
        self.cross_layers = _get_clones(CrossAttentionLayer(d_model=d_model,
                                                            nhead=attn_configs['nheads'],
                                                            dropout=0.0,
                                                            normalize_before=attn_configs['normalize_before']),
                                        self.nlayers)
        self.self_layers = _get_clones(SelfAttentionLayer(d_model=d_model,
                                                          nhead=attn_configs['nheads'],
                                                          dropout=0.0,
                                                          normalize_before=attn_configs['normalize_before']),
                                       self.nlayers)
        self.ffn_layers = _get_clones(FFNLayer(d_model=d_model,
                                               dim_feedforward=attn_configs['dim_feedforward'],
                                               dropout=0.0,
                                               normalize_before=attn_configs['normalize_before']),
                                      self.nlayers)
        self.nheads = attn_configs['nheads']
        self.temporal_query_poses = nn.Embedding(self.video_nqueries, d_model)
        self.temporal_query_feats = nn.Embedding(self.video_nqueries, d_model)
        self.temporal_query_norm = nn.LayerNorm(d_model)
        self.pos_3d = build_position_encoding(hidden_dim=d_model, position_embedding_name='3d')  # b t c h w

        self.head_outputs = configs['head_outputs']
        assert 'mask' in self.head_outputs
        self.query_mask = MLP(d_model, d_model, d_model, 3)
        if 'class' in self.head_outputs:
            # +1 for the "no object" class.
            self.query_class = nn.Linear(d_model, num_classes + 1)

        self.loss_module = Video_SetMatchingLoss(loss_config=configs['loss'], num_classes=num_classes)

    @property
    def device(self,):
        return self.temporal_query_feats.weight.device

    def get_memories_and_mask_features(self, multiscales):
        """Prepare cross-attention memories and the mask-feature map.

        Returns ``(memories, memories_poses, mask_features, size_list)``:
        - ``memories`` / ``memories_poses``: one ``(t*h*w, b, c)`` tensor per scale in
          ``memory_scales``. The poses come from the 3D position encoding.
        - ``mask_features``: ``multiscales[self.mask_scale]`` (b c t h w).
        - ``size_list``: the spatial ``(h, w)`` of each memory scale.
        """
        # b c t h w
        memories = [multiscales[scale] for scale in self.memory_scales]
        size_list = [mem_feat.shape[-2:] for mem_feat in memories]
        memories_poses = [self.pos_3d(mem.permute(0, 2, 1,3, 4)).permute(0, 2, 1, 3, 4) for mem in memories]  # b c t h w
        memories = [rearrange(mem, 'b c t h w -> (t h w) b c').contiguous() for mem in memories]
        memories_poses = [rearrange(mem_pos, 'b c t h w -> (t h w) b c').contiguous() for mem_pos in memories_poses]
        mask_features = multiscales[self.mask_scale]  # b c t h w
        return memories, memories_poses, mask_features, size_list

    def forward(self,
                multiscales,  # b c t h w
                video_aux_dict=None
                ):
        """Decode video queries into per-layer class/mask predictions.

        Only ``multiscales[0]`` is used. Presumably it is the scale-name -> (b c t h w)
        dict returned first by the encoder; TODO confirm against the caller.
        Returns a list of ``nlayers + 1`` dicts with keys ``'pred_class'`` and
        ``'pred_masks'``.
        """
        multiscales = self.inputs_projs(multiscales[0])
        # thw b c; b c t h w
        memories, memories_poses, mask_features, size_list = self.get_memories_and_mask_features(multiscales)
        memories = [mem_feat + self.level_embeds.weight[i][None, None, :] for i, mem_feat in enumerate(memories)]
        batch_size, _, nf, *_ = mask_features.shape

        # nq b c
        temporal_query_poses = self.temporal_query_poses.weight.unsqueeze(1).repeat(1, batch_size, 1)
        temporal_query_feats = self.temporal_query_feats.weight.unsqueeze(1).repeat(1, batch_size, 1)

        vid_ret = []
        # class logits; masks laid out b t nq H W (see forward_heads); attn mask b*head nq thw
        vid_class, vid_mask, attn_mask = \
            self.forward_heads(temporal_query_feats=temporal_query_feats,
                               mask_features=mask_features, attn_mask_target_size=size_list[0])  # predictions of the initial, undecoded queries
        vid_ret.append({'pred_class':vid_class, 'pred_masks': vid_mask})

        for i in range(self.nlayers):
            level_index = i % len(self.memory_scales)
            # If a query's mask blocks every location, let it attend everywhere instead
            # (e.g. when there is padding).
            attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
            temporal_query_feats = self.cross_layers[i](
                tgt=temporal_query_feats,  # nq b c
                memory=memories[level_index],  # thw b c
                memory_mask=attn_mask,  # b*h nq thw
                memory_key_padding_mask=None,
                pos=memories_poses[level_index],  # thw b c
                query_pos=temporal_query_poses,  # nq b c
            )
            temporal_query_feats = self.self_layers[i](
                temporal_query_feats,
                tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=temporal_query_poses,
            )
            temporal_query_feats = self.ffn_layers[i](
                temporal_query_feats
            )
            # The next layer uses the next memory scale, so build the mask at that size.
            vid_class, vid_mask, attn_mask = \
                self.forward_heads(temporal_query_feats=temporal_query_feats,
                                   mask_features=mask_features, attn_mask_target_size=size_list[(i + 1) % len(self.memory_scales)])
            vid_ret.append({'pred_class':vid_class, 'pred_masks': vid_mask})

        return vid_ret

    def forward_heads(self, temporal_query_feats, mask_features, attn_mask_target_size):
        """Apply the class/mask heads and build the next cross-attention mask.

        Args:
            temporal_query_feats: (nq, b, c) query features.
            mask_features: (b, c, t, h, w).
            attn_mask_target_size: spatial size of the next memory scale.

        Returns ``(class_logits, mask_logits, attn_mask)``:
        - ``mask_logits``: (b, t, nq, H, W), per the final ``rearrange``, where H and W
          are upsampled by ``mask_spatial_stride``.
        - ``attn_mask``: a bool tensor of shape (b*nheads, nq, t*h*w) at the target size.
        - ``class_logits``: (b, nq, C+1) raw logits in training. In eval they are
          softmax probabilities repeated per frame, (b, t, nq, C+1). ``None`` when
          'class' is not in ``head_outputs``.
        """
        # nq b c; b c t h w
        batch_size, _, nf, *_ = mask_features.shape

        temporal_query_feats = self.temporal_query_norm(temporal_query_feats)  # nq b c
        temporal_query_feats = temporal_query_feats.transpose(0, 1).contiguous()  # b nq c

        class_logits = self.query_class(temporal_query_feats) if 'class' in self.head_outputs else None  # b n class+1
        mask_embeds = self.query_mask(temporal_query_feats)  # b n c
        mask_logits = torch.einsum("bqc,bcthw->bqthw", mask_embeds, mask_features)
        batch_size, nq, nf = mask_logits.shape[:3]
        # The time axis is treated as channels so 2D bilinear upsampling is applied per frame.
        mask_logits = F.interpolate(mask_logits.flatten(0, 1), scale_factor=self.mask_spatial_stride, mode='bilinear', align_corners=False)
        mask_logits = rearrange(mask_logits, '(b n) t h w -> b t n h w',b=batch_size, n=nq)

        # bt nq h w
        attn_mask = mask_logits.detach().clone().flatten(0, 1)
        # NOTE(review): the 0.5 threshold is applied to raw logits (no .sigmoid()), so it
        # blocks locations whose probability is below ~0.62. Masked-attention decoders
        # usually (and get_attn_mask in localGlobal.py) compare sigmoid(logits) < 0.5.
        # Changing it alters the outputs of trained checkpoints -- confirm before fixing.
        attn_mask = (F.interpolate(attn_mask, size=attn_mask_target_size, mode="bilinear", align_corners=False) < 0.5).bool()
        attn_mask = repeat(attn_mask, '(b t) nq h w -> (b head) nq (t h w)', b=batch_size, t=nf, head=self.nheads)

        if self.training:
            return class_logits, mask_logits, attn_mask
        else:
            return class_logits.softmax(-1).unsqueeze(1).repeat(1, nf, 1, 1) if class_logits is not None else None, mask_logits, attn_mask

    def compute_loss(self, outputs, targets, video_aux_dict, **kwargs):
        """Delegate to the set-matching loss module; training mode only."""
        assert self.training
        return self.loss_module.compute_loss(model_outs=outputs,
                                             targets=targets,
                                             video_aux_dict=video_aux_dict)

================================================
FILE: models/encoder/__init__.py
================================================
from . import localGlobal
from . import input_projs
from . import neighborhood_qk

================================================
FILE: models/encoder/input_projs.py
================================================
import torch.nn as nn
from detectron2.modeling import META_ARCH_REGISTRY
from models.layers.anyc_trans import Linear_NormAct
from models.layers.anyc_trans import Conv3d_NormAct, Conv2d_NormAct
from einops import rearrange


@META_ARCH_REGISTRY.register()
class VideoConv3d_TextLinear(nn.Module):
    """Project selected video scales with Conv3d and text features with a Linear.

    If ``multiscale_shapes`` is None, every projected scale is assumed to have
    input dim == ``out_dim``. Otherwise each scale's input dim is taken from
    ``multiscale_shapes[name].dim``.
    """
    def __init__(self,
                 configs,
                 out_dim,
                 text_dim=None,  # if None, assumed equal to out_dim
                 multiscale_shapes=None,  # scale_name: (dim, [temporal_scale, spatial_scale])
                 ) -> None:
        super().__init__()
        text_dim = out_dim if text_dim is None else text_dim
        multiscale_projs_config = configs['video_multiscale_projs']
        proj_names = multiscale_projs_config.keys()  # list[str]
        in_dims = {}
        if multiscale_shapes is not None:
            assert set(proj_names).issubset(set(list(multiscale_shapes.keys())))
            for name in proj_names:
                in_dims[name] = multiscale_shapes[name].dim
        else:
            for name in proj_names:
                in_dims[name] = out_dim
        projs = {}
        for name, config in multiscale_projs_config.items():
            projs[name] = Conv3d_NormAct(in_channels=in_dims[name], out_channels=out_dim, **config)
        self.video_multiscale_projs = nn.ModuleDict(projs)

        text_proj_config = configs['text_proj']
        if text_proj_config is None:
            self.text_proj = nn.Identity()
        else:
            self.text_proj = Linear_NormAct(in_features=text_dim, out_features=out_dim, **text_proj_config)

    def forward(self, multiscales, text_dict):
        """Project configured scales (others pass through) and the text features.

        Raises ValueError unless ``text_dict`` is an ``AMRData``.
        """
        ret = {}
        for scale_name, scale_feat in multiscales.items():
            # b c t h w
            if scale_name in self.video_multiscale_projs:
                scale_feat = self.video_multiscale_projs[scale_name](scale_feat)
                ret[scale_name] = scale_feat
            else:
                ret[scale_name] = scale_feat
        # NOTE(review): AMRData is not imported in this module's visible imports; unless it
        # is injected elsewhere, this line raises NameError.
        if isinstance(text_dict, AMRData):
            text_dict.amr_feats = self.text_proj(text_dict.amr_feats)
            text_dict.text_feats = self.text_proj(text_dict.text_feats)
        else:
            raise ValueError()
        return ret, text_dict


@META_ARCH_REGISTRY.register()
class VideoConv2d_TextLinear(nn.Module):
    """Project selected video scales frame-by-frame with Conv2d and text with a Linear.

    If ``multiscale_shapes`` is None, every projected scale is assumed to have
    input dim == ``out_dim``. Otherwise each scale's input dim is taken from
    ``multiscale_shapes[name].dim``.
    """
    def __init__(self,
                 configs,
                 out_dim,
                 text_dim=None,  # if None, assumed equal to out_dim
                 multiscale_shapes=None,  # scale_name: (dim, [temporal_scale, spatial_scale])
                 ) -> None:
        super().__init__()
        text_dim = out_dim if text_dim is None else text_dim
        multiscale_projs_config = configs['video_multiscale_projs']
        proj_names = multiscale_projs_config.keys()  # list[str]
        in_dims = {}
        if multiscale_shapes is not None:
            assert set(proj_names).issubset(set(list(multiscale_shapes.keys())))
            for name in proj_names:
                in_dims[name] = multiscale_shapes[name].dim
        else:
            for name in proj_names:
                in_dims[name] = out_dim
        projs = {}
        for name, config in multiscale_projs_config.items():
            projs[name] = Conv2d_NormAct(in_channels=in_dims[name], out_channels=out_dim, **config)
        self.video_multiscale_projs = nn.ModuleDict(projs)

        text_proj_config = configs['text_proj']
        if text_proj_config is None:
            self.text_proj = nn.Identity()
        else:
            self.text_proj = Linear_NormAct(in_features=text_dim, out_features=out_dim, **text_proj_config)

    def forward(self, multiscales, text_dict):
        """Project configured scales per frame (others pass through) and the text features.

        Raises ValueError unless ``text_dict`` is an ``AMRData``.
        """
        ret = {}
        for scale_name, scale_feat in multiscales.items():
            # b c t h w
            if scale_name in self.video_multiscale_projs:
                batch_size, _, nf = scale_feat.shape[:3]
                # Fold time into batch so the 2D conv runs per frame.
                scale_feat = rearrange(scale_feat, 'b c t h w -> (b t) c h w')
                scale_feat = self.video_multiscale_projs[scale_name](scale_feat)
                scale_feat = rearrange(scale_feat, '(b t) c h w -> b c t h w', b=batch_size, t=nf)
                ret[scale_name] = scale_feat
            else:
                ret[scale_name] = scale_feat
        # NOTE(review): AMRData is not imported in this module's visible imports.
        if isinstance(text_dict, AMRData):
            text_dict.amr_feats = self.text_proj(text_dict.amr_feats)
            text_dict.text_feats = self.text_proj(text_dict.text_feats)
        else:
            raise ValueError()
        return ret, text_dict


@META_ARCH_REGISTRY.register()
class ImageConv_MultiscaleProj(nn.Module):
    """Project the scales listed in ``configs['projs']`` with Conv2d; other scales pass through."""
    def __init__(self,
                 configs,
                 out_dim,
                 multiscale_shapes=None,
                 ) -> None:
        """If ``multiscale_shapes`` is None, every input dim is assumed to be ``out_dim``."""
        super().__init__()
        projs_configs = configs['projs']
        proj_names = list(projs_configs.keys())  # list[str]
        in_dims = {}
        if multiscale_shapes is not None:
            assert set(proj_names).issubset(set(list(multiscale_shapes.keys())))
            for name in proj_names:
                in_dims[name] = multiscale_shapes[name].dim
        else:
            for name in proj_names:
                in_dims[name] = out_dim
        projs = {}
        for name, config in projs_configs.items():
            projs[name] = Conv2d_NormAct(in_channels=in_dims[name], out_channels=out_dim, **config)
        self.multiscale_projs = nn.ModuleDict(projs)

    def forward(self, multiscales):
        ret = {}
        for scale_name, scale_feat in multiscales.items():
            if scale_name in self.multiscale_projs:
                scale_feat = self.multiscale_projs[scale_name](scale_feat)
                ret[scale_name] = scale_feat
            else:
                ret[scale_name] = scale_feat
        return ret


@META_ARCH_REGISTRY.register()
class Video2D_ImageConv_MultiscaleProj(nn.Module):
    """Apply ``ImageConv_MultiscaleProj`` frame-by-frame to (b c t h w) video scales."""
    def __init__(self,
                 configs,
                 out_dim,
                 multiscale_shapes=None,  # scale_name: (dim, [temporal_scale, spatial_scale])
                 ) -> None:
        super().__init__()
        self.image_homo = ImageConv_MultiscaleProj(configs=configs, out_dim=out_dim, multiscale_shapes=multiscale_shapes)

    def forward(self, multiscales):
        # Batch size and frame count are read from the first scale; all scales are
        # assumed to share them.
        batch_sisze, _, nf = multiscales[list(multiscales.keys())[0]].shape[:3]
        # b c t h w -> bt c h w
        multiscales = {key: value.permute(0, 2, 1, 3, 4).flatten(0, 1).contiguous() for key,value in multiscales.items()}
        multiscales = self.image_homo(multiscales)
        multiscales = {key: rearrange(value, '(b t) c h w -> b c t h w',b=batch_sisze, t=nf).contiguous()\
                       for key,value in multiscales.items()}
        return multiscales


@META_ARCH_REGISTRY.register()
class VideoConv_MultiscaleProj(nn.Module):
    """Project the scales listed in ``configs['projs']`` with Conv3d; other scales pass through."""
    def __init__(self,
                 configs,
                 out_dim,
                 multiscale_shapes=None,
                 ) -> None:
        """If ``multiscale_shapes`` is None, every input dim is assumed to be ``out_dim``."""
        super().__init__()
        projs_configs = configs['projs']
        proj_names = list(projs_configs.keys())  # list[str]
        in_dims = {}
        if multiscale_shapes is not None:
            assert set(proj_names).issubset(set(list(multiscale_shapes.keys())))
            for name in proj_names:
                in_dims[name] = multiscale_shapes[name].dim
        else:
            for name in proj_names:
                in_dims[name] = out_dim
        projs = {}
        for name, config in projs_configs.items():
            projs[name] = Conv3d_NormAct(in_channels=in_dims[name], out_channels=out_dim, **config)
        self.multiscale_projs = nn.ModuleDict(projs)

    def forward(self, multiscales):
        ret = {}
        for scale_name, scale_feat in multiscales.items():
            if scale_name in self.multiscale_projs:
                scale_feat = self.multiscale_projs[scale_name](scale_feat)
                ret[scale_name] = scale_feat
            else:
                ret[scale_name] = scale_feat
        return ret


@META_ARCH_REGISTRY.register()
class FrameQueryLinear_TextLinear(nn.Module):
    """Project frame queries (always) and text features (optional) to ``out_dim``."""
    def __init__(self,
                 configs,
                 out_dim,
                 text_dim=None,  # int
                 query_dim=None,  # scale_name: (dim, [temporal_scale, spatial_scale])
                 ) -> None:
        super().__init__()
        query_proj_config = configs['query_proj']
        query_dim = out_dim if query_dim is None else query_dim
        text_dim = out_dim if text_dim is None else text_dim
        self.query_proj = Linear_NormAct(in_features=query_dim, out_features=out_dim, **query_proj_config)

        text_proj_config = configs['text_proj']
        if text_proj_config is None:
            self.text_proj = nn.Identity()
        else:
            self.text_proj = Linear_NormAct(in_features=text_dim, out_features=out_dim, **text_proj_config)

    def forward(self, frame_query, text_dict):
        # b T nqf c
        # text_dict
        frame_query = self.query_proj(frame_query)
        # NOTE(review): AMRData is not imported in this module's visible imports.
        if isinstance(text_dict, AMRData):
            text_dict.amr_feats = self.text_proj(text_dict.amr_feats)
            text_dict.text_feats = self.text_proj(text_dict.text_feats)
        else:
            raise ValueError()
        return frame_query, text_dict


@META_ARCH_REGISTRY.register()
class VideoConv3d_FrameQueryLinear_TextLinear(nn.Module):
    """Project a mask feature (Conv3d), frame queries and text features to ``out_dim``.

    Query/text projections fall back to identity when their config is None.
    ``feat_proj`` has no such fallback: a None ``feat_proj`` config would fail
    when unpacked with ``**``.
    """
    def __init__(self,
                 configs,
                 out_dim,
                 feat_dim=None,
                 text_dim=None,  # int
                 query_dim=None,  # scale_name: (dim, [temporal_scale, spatial_scale])
                 ) -> None:
        super().__init__()
        query_proj_config = configs['query_proj']
        feat_proj_config = configs['feat_proj']
        text_proj_config = configs['text_proj']
        feat_dim = out_dim if feat_dim is None else feat_dim
        query_dim = out_dim if query_dim is None else query_dim
        text_dim = out_dim if text_dim is None else text_dim
        self.query_proj = Linear_NormAct(in_features=query_dim, out_features=out_dim, **query_proj_config) if query_proj_config is not None else nn.Identity()
        self.text_proj = Linear_NormAct(in_features=text_dim, out_features=out_dim, **text_proj_config) if text_proj_config is not None else nn.Identity()
        self.feat_proj = Conv3d_NormAct(in_channels=feat_dim, out_channels=out_dim, **feat_proj_config)

    def forward(self, mask_feat, frame_query, text_dict):
        mask_feat = self.feat_proj(mask_feat)
        frame_query = self.query_proj(frame_query)
        # NOTE(review): AMRData is not imported in this module's visible imports.
        if isinstance(text_dict, AMRData):
            text_dict.amr_feats = self.text_proj(text_dict.amr_feats)
            text_dict.text_feats = self.text_proj(text_dict.text_feats)
        else:
            raise ValueError()
        return mask_feat, frame_query, text_dict


# Each module should project its inputs once, into its own feature space.
@META_ARCH_REGISTRY.register()
class VideoConv3d_FrameQueryLinear(nn.Module):
    """Project selected video scales with Conv3d and frame queries with a Linear.

    If ``multiscale_shapes`` is None, every projected scale is assumed to have
    input dim == ``out_dim``. Otherwise each scale's input dim is taken from
    ``multiscale_shapes[name].dim``.
    """
    def __init__(self,
                 configs,
                 out_dim,
                 query_dim=None,  # if None, assumed equal to out_dim
                 multiscale_shapes=None,  # scale_name: (dim, [temporal_scale, spatial_scale])
                 ) -> None:
        super().__init__()
        query_dim = out_dim if query_dim is None else query_dim
        multiscale_projs_config = configs['video_multiscale_projs']
        proj_names = multiscale_projs_config.keys()  # list[str]
        in_dims = {}
        if multiscale_shapes is not None:
            assert set(proj_names).issubset(set(list(multiscale_shapes.keys())))
            for name in proj_names:
                in_dims[name] = multiscale_shapes[name].dim
        else:
            for name in proj_names:
                in_dims[name] = out_dim
        projs = {}
        for name, config in multiscale_projs_config.items():
            projs[name] = Conv3d_NormAct(in_channels=in_dims[name], out_channels=out_dim, **config)
        self.video_multiscale_projs = nn.ModuleDict(projs)

        query_proj_config = configs['query_proj']
        if query_proj_config is None:
            self.query_proj = nn.Identity()
        else:
            self.query_proj = Linear_NormAct(in_features=query_dim, out_features=out_dim, **query_proj_config)

    def forward(self, multiscales, frame_queries):
        # b t nq c
        ret = {}
        for scale_name, scale_feat in multiscales.items():
            # b c t h w
            if scale_name in self.video_multiscale_projs:
                scale_feat = self.video_multiscale_projs[scale_name](scale_feat)
                ret[scale_name] = scale_feat
            else:
                ret[scale_name] = scale_feat
        frame_queries = self.query_proj(frame_queries)
        return ret, frame_queries


@META_ARCH_REGISTRY.register()
class FrameQueryLinear(nn.Module):
    """Project frame queries to ``out_dim`` with a Linear + norm/activation."""
    def __init__(self,
                 configs,
                 out_dim,
                 query_dim=None,  # scale_name: (dim, [temporal_scale, spatial_scale])
                 ) -> None:
        super().__init__()
        query_proj_config = configs['query_proj']
        query_dim = out_dim if query_dim is None else query_dim
        self.query_proj = Linear_NormAct(in_features=query_dim, out_features=out_dim, **query_proj_config)

    def forward(self, frame_query):
        # b T nqf c
        frame_query = self.query_proj(frame_query)
        return frame_query

================================================
FILE: models/encoder/localGlobal.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
import logging import numpy as np from typing import Callable, Dict, List, Optional, Tuple, Union import fvcore.nn.weight_init as weight_init import torch from torch import nn from torch.nn import functional as F from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ from torch.cuda.amp import autocast from detectron2.config import configurable from detectron2.layers import Conv2d, ShapeSpec, get_norm from detectron2.modeling import META_ARCH_REGISTRY from models.layers.position_encoding import PositionEmbeddingSine from models.layers.utils import _get_clones, _get_activation_fn from .ops.modules import MSDeformAttn # MSDeformAttn Transformer encoder in deformable detr class MSDeformAttnTransformerEncoderOnly(nn.Module): def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, dim_feedforward=1024, dropout=0.1, activation="relu", num_feature_levels=4, enc_n_points=4, add_local = False, add_global=False, local_configs=None, global_configs=None, frame_nqueries=None, ): super().__init__() self.d_model = d_model self.nhead = nhead encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model = d_model, d_ffn = dim_feedforward, dropout = dropout, activation = activation, n_levels = num_feature_levels, n_heads = nhead, n_points = enc_n_points, add_local = add_local, add_global = add_global, local_configs = local_configs, global_configs = global_configs ) self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers, d_model=d_model, frame_nqueries=frame_nqueries) self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MSDeformAttn): m._reset_parameters() normal_(self.level_embed) def get_valid_ratio(self, mask): _, H, W = mask.shape # b h w valid_H = torch.sum(~mask[:, :, 0], 1) # b valid_W = torch.sum(~mask[:, 0, :], 1) # b valid_ratio_h = 
valid_H.float() / H valid_ratio_w = valid_W.float() / W valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) # b 2 return valid_ratio def forward(self, srcs=None, pos_embeds=None, video_aux_dict=None, **kwargs): masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs] # prepare input for encoder src_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] spatial_shapes = [] for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): bs, c, h, w = src.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) src = src.flatten(2).transpose(1, 2) mask = mask.flatten(1) pos_embed = pos_embed.flatten(2).transpose(1, 2) lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) lvl_pos_embed_flatten.append(lvl_pos_embed) src_flatten.append(src) mask_flatten.append(mask) src_flatten = torch.cat(src_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) # b #scale 2 # encoder memory, frame_feats, frame_poses = self.encoder(src=src_flatten, # bt hw_sigma c spatial_shapes=spatial_shapes, level_start_index=level_start_index, valid_ratios=valid_ratios, pos=lvl_pos_embed_flatten, padding_mask=mask_flatten, video_aux_dict=video_aux_dict) return memory, spatial_shapes, level_start_index, frame_feats, frame_poses class MSDeformAttnTransformerEncoderLayer(nn.Module): def __init__(self, d_model=256, d_ffn=1024, dropout=0.1, activation="relu", n_levels=4, n_heads=8, n_points=4, add_local=False, add_global=False, local_configs=None, global_configs=None): super().__init__() # deform2d self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) self.dropout1 = 
nn.Dropout(dropout) self.norm1 = nn.LayerNorm(d_model) self.add_local = add_local if self.add_local: from .neighborhood_qk import NA_qk_Layer # self self.local_cnp = NA_qk_Layer(d_model=d_model, configs=local_configs) self.add_global = add_global if self.add_global: from models.layers.decoder_layers import CrossAttentionLayer # cross self.frame_query_cross_multiscale = CrossAttentionLayer(d_model=d_model, nhead=8, dropout=0.0, activation="relu", normalize_before=False) self.cross_num_heads = 8 self.global_add_attn_mask = global_configs['add_attn_mask'] if 'add_attn_mask' in global_configs else False # self+ffn from models.encoder.ops.modules.frame_query_ss2d import FrameQuery_SS2DLayer_hilbert self.global_hiss = FrameQuery_SS2DLayer_hilbert(global_configs) self.multiscale_cross_query = CrossAttentionLayer(d_model=d_model, nhead=8, dropout=0.0, activation="relu", normalize_before=False) # ffn self.linear1 = nn.Linear(d_model, d_ffn) self.activation = _get_activation_fn(activation) self.dropout2 = nn.Dropout(dropout) self.linear2 = nn.Linear(d_ffn, d_model) self.dropout3 = nn.Dropout(dropout) self.norm2 = nn.LayerNorm(d_model) @staticmethod def with_pos_embed(tensor, pos): return tensor if pos is None else tensor + pos def forward_ffn(self, src): src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) src = src + self.dropout3(src2) src = self.norm2(src) return src @torch.no_grad() def get_attn_mask(self, frame_query_feats, src, spatial_shapes, level_start_index,): # nq bt c # bt hw_sigma c assert len(spatial_shapes) == 3 frame_query_feats = frame_query_feats.permute(1, 0, 2) # bt nq c feat = src[:, level_start_index[-1]: (level_start_index[-1] + spatial_shapes[-1][0] * spatial_shapes[-1][1])] feat = rearrange(feat, 'b (h w) c -> b c h w',h=spatial_shapes[-1][0],w=spatial_shapes[-1][1]) mask = torch.einsum('bnc, bchw -> b n h w',frame_query_feats, feat) mask_2 = F.interpolate(mask, size=spatial_shapes[0].tolist(), mode='bilinear',align_corners=False) 
mask_3 = F.interpolate(mask, size=spatial_shapes[1].tolist(), mode='bilinear', align_corners=False) attn_mask = torch.cat([mask_2.flatten(2), mask_3.flatten(2), mask.flatten(2)], dim=-1) #bt n hw_sigma attn_mask = (attn_mask.unsqueeze(1).repeat(1, self.cross_num_heads, 1, 1).flatten(0, 1).sigmoid() < 0.5).bool() return attn_mask def forward(self, src=None, pos=None, reference_points=None, spatial_shapes=None, level_start_index=None, padding_mask=None, video_aux_dict=None, frame_query_feats=None, # nq bt c frame_query_poses=None): if self.add_local: # local_self src = self.local_cnp(tgt=src, scale_shapes=spatial_shapes, level_start_idxs=level_start_index, nf=video_aux_dict['nf']) if self.add_global: if self.global_add_attn_mask: attn_mask = self.get_attn_mask(frame_query_feats, src, spatial_shapes, level_start_index,) # bthead nq hw attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False # 全masked掉的 全注意, 比如有padding else: attn_mask = None # cross frame_query_feats = self.frame_query_cross_multiscale( tgt=frame_query_feats, # nq bt c memory=src.permute(1, 0, 2), # hw_sigma bt c memory_mask=attn_mask, memory_key_padding_mask=None, pos= pos.permute(1,0,2), query_pos=frame_query_poses, ) # self+ffn frame_query_feats = self.global_hiss(frame_query_feats=frame_query_feats, frame_query_poses=frame_query_poses, video_aux_dict=video_aux_dict) # self src = self.multiscale_cross_query( tgt=src.permute(1, 0, 2), # hw_sigma bt c memory=frame_query_feats, # nq bt c memory_mask=None, memory_key_padding_mask=None, pos= frame_query_poses, query_pos=pos.permute(1,0,2), ).permute(1, 0, 2) # self attention src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)[0] src = src + self.dropout1(src2) src = self.norm1(src) # ffn src = self.forward_ffn(src) return src, frame_query_feats class MSDeformAttnTransformerEncoder(nn.Module): def __init__(self, encoder_layer=None, num_layers=None, d_model=None, 
frame_nqueries=None): super().__init__() self.layers = _get_clones(encoder_layer, num_layers) self.num_layers = num_layers self.frame_nqueries = frame_nqueries # 10 self.frame_query_feats = nn.Embedding(self.frame_nqueries, d_model) self.frame_query_poses = nn.Embedding(self.frame_nqueries, d_model) @staticmethod def get_reference_points(spatial_shapes, valid_ratios, device): # b #scale 2, valid_w(0-1), valid_h(0-1), 整个feature map有多少是非padding的 # list[h w] #scale reference_points_list = [] for lvl, (H_, W_) in enumerate(spatial_shapes): ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) # 1 hw / b 1 -> b hw(0-1), y的绝对坐标 ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) # 1 hw / b 1 -> b hw(0-1), x的绝对坐标 ref = torch.stack((ref_x, ref_y), -1) # b hw 2 reference_points_list.append(ref) reference_points = torch.cat(reference_points_list, 1) # b hw_sigma 2, 每个点的相对坐标(0-1) reference_points = reference_points[:, :, None] * valid_ratios[:, None] # b hw_sigma 1 2 * b 1 #scale 2 return reference_points # b hw_sigma #scale 2 def forward(self, src, # bt hw_sigma c spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None, video_aux_dict=None): output = src # bt hw_sigma c batch_size_nf = output.shape[0] frame_query_feats = self.frame_query_feats.weight.unsqueeze(1).repeat(1,batch_size_nf, 1) frame_query_poses = self.frame_query_poses.weight.unsqueeze(1).repeat(1,batch_size_nf,1) # n bt c reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) frame_feats = [] for _, layer in enumerate(self.layers): output, frame_query_feats = layer(src=output, pos=pos, reference_points=reference_points, spatial_shapes=spatial_shapes, level_start_index=level_start_index, padding_mask=padding_mask, video_aux_dict=video_aux_dict, 
frame_query_feats=frame_query_feats, frame_query_poses=frame_query_poses) frame_feats.append(frame_query_feats) return output, frame_feats, frame_query_poses import copy from einops import rearrange from models.layers.utils import _get_clones from models.layers.position_encoding import build_position_encoding # video multiscale, text_dict @META_ARCH_REGISTRY.register() class Video_Deform2D_DividedTemporal_MultiscaleEncoder_localGlobal(nn.Module): def __init__( self, configs, multiscale_shapes, # {'res2': .temporal_stride, .spatial_stride, .dim} ): super().__init__() d_model = configs['d_model'] fpn_norm = configs['fpn_norm'] # fpn的norm nlayers = configs['nlayers'] # 4, 8, 16, 32 self.multiscale_shapes = dict(sorted(copy.deepcopy(multiscale_shapes).items(), key=lambda x: x[1].spatial_stride)) self.encoded_scales = sorted(configs['encoded_scales'], key=lambda x:self.multiscale_shapes[x].spatial_stride) # res3, res4, res5 # 4 -> 8 -> 16 -> 32 self.scale_dims = [val.dim for val in multiscale_shapes.values()] self.video_projs = META_ARCH_REGISTRY.get(configs['video_projs']['name'])(configs=configs['video_projs'], multiscale_shapes=multiscale_shapes, out_dim=d_model) self.pos_2d = build_position_encoding(position_embedding_name='2d') deform_attn = configs['deform_attn'] self.transformer = MSDeformAttnTransformerEncoderOnly( d_model=d_model, dropout=deform_attn['dropout'], nhead=deform_attn['nheads'], dim_feedforward=deform_attn['dim_feedforward'], activation=deform_attn['activation'], num_encoder_layers=nlayers, num_feature_levels=len(self.encoded_scales), enc_n_points=deform_attn['enc_n_points'], add_local = configs['add_local'], add_global = configs['add_global'], local_configs = configs['local_configs'], global_configs = configs['global_configs'], frame_nqueries=configs['frame_nqueries'] ) min_encode_stride = self.multiscale_shapes[self.encoded_scales[0]].spatial_stride # 8 min_stride = list(self.multiscale_shapes.values())[0].spatial_stride # 4 self.num_fpn_levels = 
int(np.log2(min_encode_stride) - np.log2(min_stride)) lateral_convs = [] output_convs = [] use_bias = fpn_norm == "" for idx, in_channels in enumerate(self.scale_dims[:self.num_fpn_levels]): lateral_norm = get_norm(fpn_norm, d_model) output_norm = get_norm(fpn_norm, d_model) lateral_conv = Conv2d(in_channels, d_model, kernel_size=1, bias=use_bias, norm=lateral_norm) output_conv = Conv2d(d_model, d_model, kernel_size=3, padding=1, bias=use_bias, norm=output_norm, activation=F.relu) self.add_module("adapter_{}".format(idx + 1), lateral_conv) self.add_module("layer_{}".format(idx + 1), output_conv) lateral_convs.append(lateral_conv) output_convs.append(output_conv) # Place convs into top-down order (from low to high resolution) # to make the top-down computation in forward clearer. self.lateral_convs = lateral_convs[::-1] # 8 4 self.output_convs = output_convs[::-1] # 8 4 def forward(self, multiscales=None, # b c t h w video_aux_dict=None, # dict{} **kwargs): batch_size, _, nf = multiscales[list(multiscales.keys())[0]].shape[:3] video_aux_dict['nf'] = nf multiscales = self.video_projs(multiscales) assert set(list(multiscales.keys())).issubset(set(list(self.multiscale_shapes.keys()))) assert set(list(self.multiscale_shapes.keys())).issubset(set(list(multiscales.keys()))) srcs = [] poses = [] # 32, 16, 8 for idx, scale_name in enumerate(self.encoded_scales[::-1]): x = multiscales[scale_name].permute(0, 2, 1, 3, 4).flatten(0,1).contiguous() # bt c h w srcs.append(x) poses.append(self.pos_2d(torch.zeros_like(x)[:, 0, :, :].bool(), hidden_dim=x.shape[1])) memory, spatial_shapes, level_start_index, frame_feats, frame_poses = self.transformer(srcs=srcs, pos_embeds=poses, video_aux_dict=video_aux_dict) bs = memory.shape[0] spatial_index = 0 memory_features = [] # 32 16 8 for lvl in range(len(self.encoded_scales)): h, w = spatial_shapes[lvl] memory_lvl = memory[:, spatial_index : spatial_index + h * w, :].reshape(bs, h, w, -1).permute(0, 3, 1, 2).contiguous() 
memory_features.append(memory_lvl) spatial_index += h * w for idx, f in enumerate(list(self.multiscale_shapes.keys())[:self.num_fpn_levels][::-1]): x = multiscales[f].permute(0, 2, 1, 3, 4).flatten(0,1).contiguous() # bt c h w cur_fpn = self.lateral_convs[idx](x) y = cur_fpn + F.interpolate(memory_features[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False) y = self.output_convs[idx](y) memory_features.append(y) assert len(memory_features) == len(list(self.multiscale_shapes.keys())) ret = {} for key, out_feat in zip(list(self.multiscale_shapes.keys()), memory_features[::-1]): ret[key] = rearrange(out_feat, '(b t) c h w -> b c t h w', b=batch_size, t=nf) return ret, frame_feats[::-1], frame_poses # 32, 16, 8 ================================================ FILE: models/encoder/neighborhood_qk.py ================================================ from typing import Optional import torch from torch import nn, Tensor from torch.nn.functional import pad from torch.nn.init import trunc_normal_ from natten.functional import na2d_av, na2d_qk_with_bias from einops import rearrange from natten import NeighborhoodAttention2D from detectron2.modeling import META_ARCH_REGISTRY class NeighborhoodAttention2D_qk(nn.Module): """ Neighborhood Attention 2D Module """ def __init__( self, dim: int, num_heads: int, kernel_size: int, dilation: int = 1, bias: bool = True, qkv_bias: bool = True, qk_scale: Optional[float] = None, attn_drop: float = 0.0, proj_drop: float = 0.0, ): super().__init__() self.num_heads = num_heads self.head_dim = dim // self.num_heads self.scale = qk_scale or self.head_dim**-0.5 assert ( kernel_size > 1 and kernel_size % 2 == 1 ), f"Kernel size must be an odd number greater than 1, got {kernel_size}." self.kernel_size = kernel_size assert ( dilation is None or dilation >= 1 ), f"Dilation must be greater than or equal to 1, got {dilation}." 
self.dilation = dilation or 1 self.window_size = self.kernel_size * self.dilation self.q_linear = nn.Linear(dim, dim, bias=qkv_bias) self.kv_linear = nn.Linear(dim, dim * 2, bias=qkv_bias) if bias: self.rpb = nn.Parameter( torch.zeros(num_heads, (2 * kernel_size - 1), (2 * kernel_size - 1)) ) trunc_normal_(self.rpb, std=0.02, mean=0.0, a=-2.0, b=2.0) else: self.register_parameter("rpb", None) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x_q: Tensor, x_k: Tensor) -> Tensor: # bt h w c; bt h w c, 前一帧 if x_q.dim() != 4: raise ValueError( f"NeighborhoodAttention2D expected a rank-4 input tensor; got {x.dim()=}." ) B, H, W, C = x_q.shape # Pad if the input is small than the minimum supported size H_padded, W_padded = H, W padding_h = padding_w = 0 if H < self.window_size or W < self.window_size: padding_h = max(0, self.window_size - H_padded) padding_w = max(0, self.window_size - W_padded) x_q = pad(x_q, (0, 0, 0, padding_w, 0, padding_h)) x_k = pad(x_k, (0, 0, 0, padding_w, 0, padding_h)) _, H_padded, W_padded, _ = x_q.shape # b h w c -> b h w h c_h q = self.q_linear(x_q).reshape(B, H_padded, W_padded, self.num_heads, self.head_dim) q = q.permute(0, 3, 1, 2, 4) # b head h w c_h kv = ( self.kv_linear(x_k) .reshape(B, H_padded, W_padded, 2, self.num_heads, self.head_dim) .permute(3, 0, 4, 1, 2, 5) ) # b k, v = kv[0], kv[1] q = q * self.scale attn = na2d_qk_with_bias(q, k, self.rpb, self.kernel_size, self.dilation) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x_q = na2d_av(attn, v, self.kernel_size, self.dilation) # b head h w c_h x_q = x_q.permute(0, 2, 3, 1, 4).reshape(B, H_padded, W_padded, C) # b h w head c_h # Remove padding, if added any if padding_h or padding_w: x_q = x_q[:, :H, :W, :].contiguous() return self.proj_drop(self.proj(x_q)) def extra_repr(self) -> str: return ( f"head_dim={self.head_dim}, num_heads={self.num_heads}, " + f"kernel_size={self.kernel_size}, 
dilation={self.dilation}, " + f"has_bias={self.rpb is not None}" ) from models.layers.utils import _get_clones class NA_qk_Layer(nn.Module): def __init__(self, d_model, configs): super().__init__() self.self_attn = NeighborhoodAttention2D_qk(dim=configs['d_model'], num_heads=configs['num_heads'], kernel_size=configs['kernel_size'], dilation=configs['dilation'], bias=False, qkv_bias=False,) self.num_steps = configs['num_steps'] if 'num_steps' in configs else 1 self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(configs['dropout']) def forward(self, tgt=None, scale_shapes=None, level_start_idxs=None, nf=None): # bt hw_sigma c -> list[b t h w c], 3 video_feats = [tgt[:, start_idx:(start_idx + haosen[0]*haosen[1])].contiguous() for start_idx, haosen in zip(level_start_idxs, scale_shapes)] video_feats = [rearrange(haosen, '(b t) (h w) c -> b t h w c', t=nf, h=scale_shapes[idx][0], w=scale_shapes[idx][1]).contiguous() for idx, haosen in enumerate(video_feats)] video_key_feats = [] for haosen in video_feats: scale_feats = torch.stack([torch.roll(haosen, shifts=k, dims=1) for k in range(1, self.num_steps+1)], dim=0) # s b t h w c video_key_feats.append(scale_feats.flatten(0, 2)) #sbt h w c # sbt h w c video_feats = [haosen.unsqueeze(0).repeat(self.num_steps, 1,1,1,1,1).flatten(0, 2) for haosen in video_feats] local_feats = [] # list[sbt h w c] for idx, (q_feat, k_feat) in enumerate(zip(video_feats, video_key_feats)): local_feats.append(self.self_attn(q_feat, k_feat)) local_feats = [rearrange(haosen, '(s bt) h w c -> s bt h w c',s=self.num_steps) for haosen in local_feats] local_feats = [haosen.sum(dim=0) for haosen in local_feats] # bt h w c local_feats = torch.cat([haosen.flatten(1, 2) for haosen in local_feats], dim=1) # bt hw_sigma c tgt = tgt + self.dropout(local_feats) tgt = self.norm(tgt) return tgt @META_ARCH_REGISTRY.register() class NA_qk_Layer_v2(nn.Module): def __init__(self, configs): super().__init__() self.self_attn = 
NeighborhoodAttention2D_qk(dim=configs['d_model'], num_heads=configs['num_heads'], kernel_size=configs['kernel_size'], dilation=configs['dilation'], bias=False, qkv_bias=False,) def forward(self, query=None, spatial_shapes=None, level_start_index=None, video_aux_dict=None,): # bt hw_sigma c -> list[b t h w c], 3 video_feat = [query[:, start_idx:(start_idx + haosen[0]*haosen[1])].contiguous() for start_idx, haosen in zip(level_start_index, spatial_shapes)] video_feat = [rearrange(haosen, '(b t) (h w) c -> b t h w c',t=video_aux_dict['nf'], h=spatial_shapes[idx][0], w=spatial_shapes[idx][1]).contiguous() for idx, haosen in enumerate(video_feat)] video_key_feats = [torch.roll(haosen, shifts=1, dims=1).contiguous() for haosen in video_feat] local_feats = [] # list[bt h w c] for idx, (q_feat, k_feat) in enumerate(zip(video_feat, video_key_feats)): local_feats.append(self.self_attn(q_feat.flatten(0, 1), k_feat.flatten(0, 1))) local_feats = torch.cat([haosen.flatten(1, 2) for haosen in local_feats], dim=1) # bt hw_sigma c return local_feats, None ================================================ FILE: models/encoder/ops/MultiScaleDeformableAttention.egg-info/PKG-INFO ================================================ Metadata-Version: 2.1 Name: MultiScaleDeformableAttention Version: 1.0 Summary: PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention Home-page: https://github.com/fundamentalvision/Deformable-DETR Author: Weijie Su ================================================ FILE: models/encoder/ops/attention.py ================================================ from inspect import isfunction import math import torch from torch.nn.init import xavier_uniform_, constant_ import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat from .functions import MSDeformAttnFunction import warnings def exists(val): return val is not None def uniq(arr): return{el: True for el in arr}.keys() def default(val, d): if exists(val): return 
val return d() if isfunction(d) else d def max_neg_value(t): return -torch.finfo(t.dtype).max def init_(tensor): dim = tensor.shape[-1] std = 1 / math.sqrt(dim) tensor.uniform_(-std, std) return tensor # feedforward class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) return x * F.gelu(gate) class FeedForward(nn.Module): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential( nn.Linear(dim, inner_dim), nn.GELU() ) if not glu else GEGLU(dim, inner_dim) self.net = nn.Sequential( project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) ) def forward(self, x): return self.net(x) def zero_module(module): """ Zero out the parameters of a module and return it. """ for p in module.parameters(): p.detach().zero_() return module def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) class LinearAttention(nn.Module): def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) return self.to_out(out) class SpatialSelfAttention(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 
self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention b,c,h,w = q.shape q = rearrange(q, 'b c h w -> b (h w) c') k = rearrange(k, 'b c h w -> b c (h w)') w_ = torch.einsum('bij,bjk->bik', q, k) w_ = w_ * (int(c)**(-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values v = rearrange(v, 'b c h w -> b c (h w)') w_ = rearrange(w_, 'b i j -> b j i') h_ = torch.einsum('bij,bjk->bik', v, w_) h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = self.proj_out(h_) return x+h_ class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) self.scale = dim_head ** -0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) def forward(self, x, context=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) k = self.to_k(context) v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) sim = einsum('b i d, b j d -> b i j', q, k) * self.scale if exists(mask): mask = rearrange(mask, 'b ... 
-> b (...)') max_neg_value = -torch.finfo(sim.dtype).max mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of attn = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', attn, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=h) return self.to_out(out) class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True): super().__init__() self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) def forward(self, x, context=None): x = self.attn1(self.norm1(x)) + x x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x return x def _is_power_of_2(n): if (not isinstance(n, int)) or (n < 0): raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) return (n & (n-1) == 0) and n != 0 class DeformAttn(nn.Module): def __init__(self, d_model=256, nheads=8, npoints=4, nlevels=4, key_dim=None): super().__init__() query_dim = d_model key_dim = d_model head_dim = d_model // nheads if d_model % nheads != 0: raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, nheads)) # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation if not _is_power_of_2(head_dim): warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " "which is more efficient in our CUDA implementation.") self.im2col_step = 64 self.d_model = nheads * head_dim key_dim = default(key_dim, query_dim) self.value_proj = nn.Linear(key_dim, 
nheads * head_dim) self.sampling_offsets = nn.Linear(query_dim, nheads * nlevels * npoints * 2) self.attention_weights = nn.Linear(query_dim, nheads * nlevels * npoints) self.output_proj = nn.Linear(nheads * head_dim, query_dim) self.n_heads = nheads self.n_levels = nlevels self.head_dim = head_dim self.n_points = npoints self._reset_parameters() def _reset_parameters(self): constant_(self.sampling_offsets.weight.data, 0.) thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) for i in range(self.n_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) constant_(self.attention_weights.weight.data, 0.) constant_(self.attention_weights.bias.data, 0.) xavier_uniform_(self.value_proj.weight.data) constant_(self.value_proj.bias.data, 0.) xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.) 
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index,input_padding_mask=None): """ multi-scale deformable attention, self attention if query == input_flatten Input: - query: T(b n c) - reference_points: center or reference boxes, normalized, [0, 1], including padding area, add additional (w, h) to form reference boxes T(b n level 2) or T(b n level 4) - input_flatten: multi-scale特征 T(b (h_\sigma w_\sigma) c) - input_spatial_shapes: 每个level的大小 T(level 2) - input_level_start_index: [0, level1_start, level2_start] - input_padding_mask: True/False T(b), (h_\sigma w_\sigma)) Output: - query results: T(b, n c) - sampling_locations: normalized T(b, n, m*l*k, 2) - attention_weights: after softmax T(b, n, m*l*k) """ batch_size, Nq, _ = query.shape _, Nk, _ = input_flatten.shape assert Nk == (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() # B (h w) M * V value = self.value_proj(input_flatten) if input_padding_mask is not None: value = value.masked_fill(input_padding_mask[...,None], float(0)) value = value.view(batch_size, Nk, self.n_heads, self.head_dim) sampling_offesets = self.sampling_offsets(query).view(batch_size, Nq, self.n_heads, self.n_levels, self.n_points, 2) attention_weights= self.attention_weights(query).view(batch_size, Nq, self.n_heads, self.n_levels * self.n_points) attention_weights = F.softmax(attention_weights, dim=-1).view(batch_size, Nq, self.n_heads, self.n_levels, self.n_points) # b, n ,head, level, point, 2 if reference_points.shape[-1] == 2: # T(2 level) offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) sampling_locations = reference_points[:, :, None, :, None, :] + \ sampling_offesets / offset_normalizer[None, None, None, :, None,:] elif reference_points.shape[-1] == 4: sampling_locations = reference_points[:, :, None, :, None, :2] + \ sampling_offesets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 else: raise 
NotImplementedError output = MSDeformAttnFunction.apply(value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) output = self.output_proj(output) return output, sampling_locations, attention_weights class ContextuallSelfAttention(nn.Module): def __init__(self, d_model, n_points, n_heads, context_dim=None): super().__init__() context_dim = default(context_dim, d_model) query_dim = key_dim = d_model if d_model % n_heads != 0: raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) head_dim = d_model // n_heads # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation if not _is_power_of_2(head_dim): warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " "which is more efficient in our CUDA implementation.") self.im2col_step = 64 self.d_model = d_model self.nheads = n_heads self.head_dim = head_dim self.npoints = n_points self.nlevels = 1 self.value_proj = nn.Linear(key_dim, n_heads * head_dim) self.sampling_offsets = nn.Linear(query_dim, n_heads * n_points * 2) self.attention_weights = nn.Linear(query_dim, n_heads * n_points) self.output_proj = nn.Linear(n_heads * head_dim, query_dim) def forward(self, context, context_mask, query, reference_points, query_padding_mask = None,): """ contextual deformable attention Input: - context: T(b n c) - context_mask: T(b n) - query: T(b (h w) c) - reference_points: center or reference boxes, normalized, [0, 1], including padding area, T(b (h w) 2/4) - query_padding_mask: T(b (h w)) Output: - query results: T(b, n c) - sampling_locations: normalized T(b, n, m*l*k, 2) - attention_weights: after softmax T(b, n, m*l*k) """ key = query key_padding_mask = query_padding_mask batch_size, Nq, _ = query.shape Nk = Nq input_spatial_shapes = torch.tensor(query.shape[-2:]).unsqueeze(0) # T(1, 2) input_level_start_index = [0, ] # B (h w) M * V 
value = self.value_proj(key) if key_padding_mask is not None: value = value.masked_fill(key_padding_mask[..., None], float(0)) value = value.view(batch_size, Nk, self.nheads, self.head_dim) sampling_offesets = self.sampling_offsets(query).view(batch_size, Nq, self.n_heads, self.n_levels, self.n_points, 2) attention_weights= self.attention_weights(query).view(batch_size, Nq, self.n_heads, self.n_levels * self.n_points) attention_weights = F.softmax(attention_weights, dim=-1).view(batch_size, Nq, self.n_heads, self.n_levels, self.n_points) # b, n ,head, level, point, 2 if reference_points.shape[-1] == 2: # T(2 level) offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) sampling_locations = reference_points[:, :, None, :, None, :] + \ sampling_offesets / offset_normalizer[None, None, None, :, None,:] elif reference_points.shape[-1] == 4: sampling_locations = reference_points[:, :, None, :, None, :2] + \ sampling_offesets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 else: raise NotImplementedError output = MSDeformAttnFunction.apply(value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) output = self.output_proj(output) return output, sampling_locations, attention_weights class BasicTransformerBlock_v2(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True): super().__init__() self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) def forward(self, x, context=None): x = self.norm1(self.attn1(x, context=context) + x) x = self.norm2(self.ff(x) + x) return x class SpatialTransformer(nn.Module): """ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply standard transformer action. 
Finally, reshape to image """ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None): super().__init__() self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = Normalize(in_channels) self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) self.transformer_blocks = nn.ModuleList( [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) for d in range(depth)] ) self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) def forward(self, x, context=None): # note: if no context is given, cross-attention defaults to self-attention b, c, h, w = x.shape x_in = x x = self.norm(x) x = self.proj_in(x) x = rearrange(x, 'b c h w -> b (h w) c') for block in self.transformer_blocks: x = block(x, context=context) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) return x + x_in ================================================ FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-311/functions/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from .ms_deform_attn_func import MSDeformAttnFunction ================================================ FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-311/functions/ms_deform_attn_func.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import print_function from __future__ import division import torch import torch.nn.functional as F from torch.autograd import Function from torch.autograd.function import once_differentiable import MultiScaleDeformableAttention as MSDA class MSDeformAttnFunction(Function): @staticmethod def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): ctx.im2col_step = im2col_step output = MSDA.ms_deform_attn_forward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) return output @staticmethod @once_differentiable def backward(ctx, grad_output): 
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors grad_value, grad_sampling_loc, grad_attn_weight = \ MSDA.ms_deform_attn_backward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): # for debug and test only, # need to use cuda version instead N_, S_, M_, D_ = value.shape _, Lq_, M_, L_, P_, _ = sampling_locations.shape value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for lid_, (H_, W_) in enumerate(value_spatial_shapes): # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) # N_*M_, D_, Lq_, P_ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, mode='bilinear', padding_mode='zeros', align_corners=False) sampling_value_list.append(sampling_value_l_) # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) return output.transpose(1, 2).contiguous() ================================================ FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-311/modules/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from .ms_deform_attn import MSDeformAttn

================================================
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-311/modules/ms_deform_attn.py
================================================
# Modify for sample points visualization
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import warnings
import math

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, constant_

from ..functions import MSDeformAttnFunction


def _is_power_of_2(n):
    """Return True iff ``n`` is a positive power of two.

    Raises:
        ValueError: if ``n`` is not an ``int`` or is negative.
    """
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
    # `&` binds tighter than `==`: (n & (n-1)) == 0 holds when at most one bit is set;
    # `n != 0` then rules out zero.
    return (n & (n-1) == 0) and n != 0


class MSDeformAttn(nn.Module):
    """Multi-scale deformable attention layer backed by ``MSDeformAttnFunction``.

    ``forward`` returns ``(output, sampling_locations, attention_weights)``.
    Returning the last two presumably supports the "sample points
    visualization" noted in the file header — NOTE(review): confirm with callers.
    """

    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        """
        Build the projections of a multi-scale deformable attention layer.

        :param d_model   hidden dimension; must be divisible by n_heads
        :param n_levels  number of feature levels
        :param n_heads   number of attention heads
        :param n_points  number of sampling points per attention head per feature level

        Raises ValueError when d_model is not divisible by n_heads. Emits a
        warning (not an error) when the per-head dimension is not a power of 2.
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
        _d_per_head = d_model // n_heads
        # Per the upstream authors, a power-of-2 per-head dimension is more efficient in the CUDA kernel.
        if not _is_power_of_2(_d_per_head):
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
                          "which is more efficient in our CUDA implementation.")

        # Passed straight through to MSDeformAttnFunction.apply in forward().
        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        # Creation order fixes both parameter registration order and RNG consumption
        # of the default nn.Linear init; keep it unchanged.
        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)  # (x, y) offset per head/level/point
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)  # one logit per head/level/point
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize the projections in place.

        Sampling-offset weights are zeroed, so the initial offsets come from the
        bias alone. For head h the bias points along angle 2*pi*h/n_heads and is
        rescaled so its larger |component| equals 1; point i is then scaled by i + 1.
        Attention-weight weights and bias are zeroed, so the softmax in forward()
        starts uniform. Value/output projections use Xavier-uniform weights and
        zero bias.
        """
        constant_(self.sampling_offsets.weight.data, 0.)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        # (n_heads, 2) -> (n_heads, n_levels, n_points, 2), identical across levels and points before scaling
        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.)
        constant_(self.attention_weights.bias.data, 0.)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
        """
        Attend from each query to n_points sampled locations on every feature level.

        :param query                    (N, Len_q, C)
        :param reference_points         (N, Len_q, n_levels, 2): normalized points, range [0, 1],
                                        top-left (0, 0), bottom-right (1, 1), including padding area;
                                        or (N, Len_q, n_levels, 4): points plus (w, h) forming reference boxes
        :param input_flatten            (N, Len_in, C), all levels flattened and concatenated,
                                        Len_in = sum over levels of H_l * W_l
        :param input_spatial_shapes     (n_levels, 2), rows (H_l, W_l)
        :param input_level_start_index  (n_levels,), start offset of each level in Len_in:
                                        [0, H_0*W_0, H_0*W_0 + H_1*W_1, ...]
        :param input_padding_mask       (N, Len_in), True for padding elements (their values are zeroed)

        :return tuple of
                output               (N, Len_q, C)
                sampling_locations   (N, Len_q, n_heads, n_levels, n_points, 2)
                attention_weights    (N, Len_q, n_heads, n_levels, n_points), softmax-normalized
                                     over the n_levels * n_points samples of each head

        Raises AssertionError if input_spatial_shapes does not account for Len_in,
        and ValueError if the last dim of reference_points is neither 2 nor 4.
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        # Split channels into heads: (N, Len_in, n_heads, d_model // n_heads)
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
        # Softmax jointly over all levels and points of a head, then restore the level/point axes.
        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
        # sampling_locations: N, Len_q, n_heads, n_levels, n_points, 2
        if reference_points.shape[-1] == 2:
            # Rows are (H, W); reorder to (W, H) so pixel offsets (x, y) normalize per level.
            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            # Box case: offsets are scaled by half the box (w, h), divided by n_points.
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
        else:
            raise ValueError(
                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
        output = MSDeformAttnFunction.apply(
            value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
        output = self.output_proj(output)
        return output, sampling_locations, attention_weights

================================================
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-38/functions/__init__.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------------------
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
# ------------------------------------------------------------------------------------------------

from .ms_deform_attn_func import MSDeformAttnFunction

================================================
FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-38/functions/ms_deform_attn_func.py
================================================
# ------------------------------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import print_function from __future__ import division import torch import torch.nn.functional as F from torch.autograd import Function from torch.autograd.function import once_differentiable import MultiScaleDeformableAttention as MSDA class MSDeformAttnFunction(Function): @staticmethod def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): ctx.im2col_step = im2col_step output = MSDA.ms_deform_attn_forward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) return output @staticmethod @once_differentiable def backward(ctx, grad_output): value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors grad_value, grad_sampling_loc, grad_attn_weight = \ MSDA.ms_deform_attn_backward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): # for debug and test only, # need to use cuda version instead N_, S_, M_, D_ = value.shape _, Lq_, M_, L_, P_, _ = sampling_locations.shape value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for 
lid_, (H_, W_) in enumerate(value_spatial_shapes): # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) # N_*M_, D_, Lq_, P_ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, mode='bilinear', padding_mode='zeros', align_corners=False) sampling_value_list.append(sampling_value_l_) # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) return output.transpose(1, 2).contiguous() ================================================ FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-38/modules/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from .ms_deform_attn import MSDeformAttn ================================================ FILE: models/encoder/ops/build/lib.linux-x86_64-cpython-38/modules/ms_deform_attn.py ================================================ # Modify for sample points visualization # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. 
All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import print_function from __future__ import division import warnings import math import torch from torch import nn import torch.nn.functional as F from torch.nn.init import xavier_uniform_, constant_ from ..functions import MSDeformAttnFunction def _is_power_of_2(n): if (not isinstance(n, int)) or (n < 0): raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) return (n & (n-1) == 0) and n != 0 class MSDeformAttn(nn.Module): def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): """ Multi-Scale Deformable Attention Module :param d_model hidden dimension :param n_levels number of feature levels :param n_heads number of attention heads :param n_points number of sampling points per attention head per feature level """ super().__init__() if d_model % n_heads != 0: raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) _d_per_head = d_model // n_heads # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation if not _is_power_of_2(_d_per_head): warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " "which is more efficient in our CUDA implementation.") self.im2col_step = 64 self.d_model = d_model self.n_levels = n_levels self.n_heads = n_heads self.n_points = n_points self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) self.value_proj = 
nn.Linear(d_model, d_model) self.output_proj = nn.Linear(d_model, d_model) self._reset_parameters() def _reset_parameters(self): constant_(self.sampling_offsets.weight.data, 0.) thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) for i in range(self.n_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) constant_(self.attention_weights.weight.data, 0.) constant_(self.attention_weights.bias.data, 0.) xavier_uniform_(self.value_proj.weight.data) constant_(self.value_proj.bias.data, 0.) xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.) def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): """ :param query (N, Length_{query}, C) :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements :return output (N, Length_{query}, C) """ N, Len_q, _ = query.shape N, Len_in, _ = input_flatten.shape assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in value = self.value_proj(input_flatten) if input_padding_mask is not None: value = value.masked_fill(input_padding_mask[..., 
None], float(0)) value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) # N, Len_q, n_heads, n_levels, n_points, 2 if reference_points.shape[-1] == 2: offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) sampling_locations = reference_points[:, :, None, :, None, :] \ + sampling_offsets / offset_normalizer[None, None, None, :, None, :] elif reference_points.shape[-1] == 4: sampling_locations = reference_points[:, :, None, :, None, :2] \ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 else: raise ValueError( 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) output = MSDeformAttnFunction.apply( value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) output = self.output_proj(output) return output, sampling_locations, attention_weights ================================================ FILE: models/encoder/ops/build/temp.linux-x86_64-cpython-311/.ninja_log ================================================ # ninja log v5 0 5344 1685604027 /home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xhh/workspace/rvos_encoder/models/ops/src/cpu/ms_deform_attn_cpu.o 1eaabdd4515aceab 1 20910 1685604042 /home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xhh/workspace/rvos_encoder/models/ops/src/vision.o b8641c4a4f7766f9 0 21063 1685604042 /home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xhh/workspace/rvos_encoder/models/ops/src/cuda/ms_deform_attn_cuda.o 
d77fcd8ae1c377bb ================================================ FILE: models/encoder/ops/build/temp.linux-x86_64-cpython-311/build.ninja ================================================ ninja_required_version = 1.3 cxx = c++ nvcc = /usr/local/cuda/bin/nvcc cflags = -pthread -B /home/xhh/anaconda3/envs/natten/compiler_compat -DNDEBUG -fwrapv -O2 -Wall -fPIC -O2 -isystem /home/xhh/anaconda3/envs/natten/include -fPIC -O2 -isystem /home/xhh/anaconda3/envs/natten/include -fPIC -DWITH_CUDA -I/home/xhh/workspace/rvos_encoder/models/ops/src -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include/TH -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/xhh/anaconda3/envs/natten/include/python3.11 -c post_cflags = -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=MultiScaleDeformableAttention -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17 cuda_cflags = -DWITH_CUDA -I/home/xhh/workspace/rvos_encoder/models/ops/src -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include/TH -I/home/xhh/anaconda3/envs/natten/lib/python3.11/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/xhh/anaconda3/envs/natten/include/python3.11 -c cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ 
-D__CUDA_NO_HALF2_OPERATORS__ -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=MultiScaleDeformableAttention -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++17 cuda_dlink_post_cflags = ldflags = rule compile command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags depfile = $out.d deps = gcc rule cuda_compile depfile = $out.d deps = gcc command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags build /home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xhh/workspace/rvos_encoder/models/ops/src/cpu/ms_deform_attn_cpu.o: compile /home/xhh/workspace/rvos_encoder/models/ops/src/cpu/ms_deform_attn_cpu.cpp build /home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xhh/workspace/rvos_encoder/models/ops/src/cuda/ms_deform_attn_cuda.o: cuda_compile /home/xhh/workspace/rvos_encoder/models/ops/src/cuda/ms_deform_attn_cuda.cu build /home/xhh/workspace/rvos_encoder/models/ops/build/temp.linux-x86_64-cpython-311/home/xhh/workspace/rvos_encoder/models/ops/src/vision.o: compile /home/xhh/workspace/rvos_encoder/models/ops/src/vision.cpp ================================================ FILE: models/encoder/ops/build/temp.linux-x86_64-cpython-38/home/xhh/workspace/ReferFormer/models/ops/src/vision.o ================================================ [File too large to display: 12.3 MB] ================================================ FILE: models/encoder/ops/functions/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from .ms_deform_attn_func import MSDeformAttnFunction ================================================ FILE: models/encoder/ops/functions/ms_deform_attn_func.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import print_function from __future__ import division import torch import torch.nn.functional as F from torch.autograd import Function from torch.autograd.function import once_differentiable import MultiScaleDeformableAttention as MSDA class MSDeformAttnFunction(Function): @staticmethod def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): ctx.im2col_step = im2col_step output = MSDA.ms_deform_attn_forward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) return output @staticmethod @once_differentiable def backward(ctx, grad_output): value, value_spatial_shapes, 
value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors grad_value, grad_sampling_loc, grad_attn_weight = \ MSDA.ms_deform_attn_backward( value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): # for debug and test only, # need to use cuda version instead N_, S_, M_, D_ = value.shape _, Lq_, M_, L_, P_, _ = sampling_locations.shape value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) sampling_grids = 2 * sampling_locations - 1 sampling_value_list = [] for lid_, (H_, W_) in enumerate(value_spatial_shapes): # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) # N_*M_, D_, Lq_, P_ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, mode='bilinear', padding_mode='zeros', align_corners=False) sampling_value_list.append(sampling_value_l_) # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) return output.transpose(1, 2).contiguous() ================================================ FILE: models/encoder/ops/make.sh ================================================ #!/usr/bin/env bash # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ python setup.py build install ================================================ FILE: models/encoder/ops/modules/__init__.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from .ms_deform_attn import MSDeformAttn from . 
import frame_query_ss2d ================================================ FILE: models/encoder/ops/modules/frame_query_ss2d.py ================================================ from models.layers.position_encoding import build_position_encoding from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref import math import warnings import math import torch from torch import nn import torch.nn.functional as F from torch.nn.init import xavier_uniform_, constant_ from ..functions import MSDeformAttnFunction from mamba_ssm import Mamba from einops import rearrange, reduce, repeat from detectron2.modeling import META_ARCH_REGISTRY # v1 class SS2D(nn.Module): def __init__( self, d_model, d_state=16, # d_state="auto", # 20240109 d_conv=3, expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, dropout=0., conv_bias=True, bias=False, device=None, dtype=None, **kwargs, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.d_model = d_model self.d_state = d_state # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109 self.d_conv = d_conv self.expand = expand self.d_inner = int(self.expand * self.d_model) self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) self.conv2d = nn.Conv2d( in_channels=self.d_inner, out_channels=self.d_inner, groups=self.d_inner, bias=conv_bias, kernel_size=d_conv, padding=(d_conv - 1) // 2, **factory_kwargs, ) self.act = nn.SiLU() self.x_proj = ( nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), ) 
self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner) del self.x_proj self.dt_projs = ( self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), ) self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank) self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner) del self.dt_projs self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N) self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N) # self.selective_scan = selective_scan_fn self.forward_core = self.forward_corev0 self.out_norm = nn.LayerNorm(self.d_inner) self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) self.dropout = nn.Dropout(dropout) if dropout > 0. 
else None @staticmethod def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization dt_init_std = dt_rank**-0.5 * dt_scale if dt_init == "constant": nn.init.constant_(dt_proj.weight, dt_init_std) elif dt_init == "random": nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max dt = torch.exp( torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): dt_proj.bias.copy_(inv_dt) # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit dt_proj.bias._no_reinit = True return dt_proj @staticmethod def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): # S4D real initialization A = repeat( torch.arange(1, d_state + 1, dtype=torch.float32, device=device), "n -> d n", d=d_inner, ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 if copies > 1: A_log = repeat(A_log, "d n -> r d n", r=copies) if merge: A_log = A_log.flatten(0, 1) A_log = nn.Parameter(A_log) A_log._no_weight_decay = True return A_log @staticmethod def D_init(d_inner, copies=1, device=None, merge=True): # D "skip" parameter D = torch.ones(d_inner, device=device) if copies > 1: D = repeat(D, "n1 -> r n1", r=copies) if merge: D = D.flatten(0, 1) D = nn.Parameter(D) # Keep in fp32 D._no_weight_decay = True return D def forward_corev0(self, x: torch.Tensor): self.selective_scan = selective_scan_fn B, C, H, W = x.shape L = H * W K = 4 x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], 
dim=1).view(B, 2, -1, L) xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) xs = xs.float().view(B, -1, L) # (b, k * d, l) dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) Ds = self.Ds.float().view(-1) # (k * d) As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) out_y = self.selective_scan( xs, dts, As, Bs, Cs, Ds, z=None, delta_bias=dt_projs_bias, delta_softplus=True, return_last_state=False, ).view(B, K, -1, L) assert out_y.dtype == torch.float inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y def forward(self, x: torch.Tensor, **kwargs): B, H, W, C = x.shape xz = self.in_proj(x) x, z = xz.chunk(2, dim=-1) # (b, h, w, d) x = x.permute(0, 3, 1, 2).contiguous() x = self.act(self.conv2d(x)) # (b, d, h, w) y1, y2, y3, y4 = self.forward_core(x) # B C hw assert y1.dtype == torch.float32 y = y1 + y2 + y3 + y4 y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) y = self.out_norm(y) y = y * F.silu(z) out = self.out_proj(y) if self.dropout is not None: out = self.dropout(out) return out class SS2D_FrameQuery(nn.Module): def __init__(self, configs,): super().__init__() d_model = configs['d_model'] 
self.homo = SS2D(d_model=configs['d_model'], d_state=configs['d_state'] if 'd_state' in configs else 16, d_conv=configs['d_conv'] if 'd_conv' in configs else 3, expand=configs['expand'] if 'expand' in configs else 2, dt_rank=configs['dt_rank'] if 'dt_rank' in configs else 'auto', dt_min=configs['dt_min'] if 'dt_min' in configs else 0.001, dt_max=configs['dt_max'] if 'dt_max' in configs else 0.1, dt_init=configs['dt_init'] if 'dt_init' in configs else 'random', dt_scale=configs['dt_scale'] if 'dt_scale' in configs else 1.0, dt_init_floor=configs['dt_init_floor'] if 'dt_init_floor' in configs else 1e-4, dropout=configs['dropout'] if 'dropout' in configs else 0, conv_bias=configs['conv_bias'] if 'conv_bias' in configs else True, bias=configs['bias'] if 'bias' in configs else False, ) self.pos_1d = build_position_encoding(position_embedding_name='1d') # t上的position embedding def forward(self, frame_query_feats=None, # n bt c frame_query_poses=None, # n bt c # nq上的Position embedding nf=None, **kwargs ): batch_size = frame_query_feats.shape[1] // nf # b frame_query_feats += frame_query_poses frame_query_feats = rearrange(frame_query_feats, 'n (b t) c -> b t n c',b=batch_size,t=nf).contiguous() sin_poses = self.pos_1d(torch.zeros_like(frame_query_feats[..., 0].permute(0, 2, 1).flatten(0, 1)).bool(), hidden_dim=frame_query_feats.shape[-1]) # bn c t sin_poses = rearrange(sin_poses, '(b n) c t -> b t n c', b=batch_size) frame_query_feats += sin_poses frame_query_feats = self.homo(frame_query_feats) # b t n c frame_query_feats = frame_query_feats.permute(2, 0, 1, 3).flatten(1, 2).contiguous() # n bt c return frame_query_feats, None @META_ARCH_REGISTRY.register() class FrameQuery_SS2DLayer(nn.Module): def __init__(self, configs, dropout=0.0, activation="relu", normalize_before=False): super().__init__() d_model = configs['d_model'] dropout = configs['dropout'] self.self_attn = SS2D_FrameQuery(configs) self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) from 
models.layers.decoder_layers import FFNLayer self.ffn = FFNLayer(d_model=d_model, dim_feedforward=configs['dim_feedforward'], dropout=dropout,) def forward(self, frame_query_feats, # n bt c frame_query_poses, # n bt c # nq上的Position embedding nf=None, **kwargs): tgt2 = self.self_attn(frame_query_feats=frame_query_feats, # n bt c frame_query_poses=frame_query_poses, # n bt c # nq上的Position embedding nf=nf,)[0] frame_query_feats = frame_query_feats + self.dropout(tgt2) frame_query_feats = self.norm(frame_query_feats) frame_query_feats = self.ffn(frame_query_feats) return frame_query_feats from models.layers.decoder_layers import CrossAttentionLayer, SelfAttentionLayer, FFNLayer @META_ARCH_REGISTRY.register() class TemporalQuery_CrossSelf(nn.Module): def __init__(self, configs) -> None: super().__init__() d_model = configs['d_model'] attn_configs = configs['attn'] self.cross_layers = CrossAttentionLayer(d_model=d_model, nhead=attn_configs['nheads'], dropout=0.0, normalize_before=attn_configs['normalize_before']) self.self_layers = SelfAttentionLayer(d_model=d_model, nhead=attn_configs['nheads'], dropout=0.0, normalize_before=attn_configs['normalize_before']) self.ffn_layers = FFNLayer(d_model=d_model, dim_feedforward=attn_configs['dim_feedforward'], dropout=0.0, normalize_before=attn_configs['normalize_before']) def forward(self, temporal_query_feats, temporal_query_poses, frame_query_feats, frame_query_poses, video_aux_dict=None, **kwargs): # nq b c; nq bt c nq, batch_size, _ = temporal_query_feats.shape nf = frame_query_feats.shape[1] // batch_size nqf = frame_query_feats.shape[0] frame_query_feats = rearrange(frame_query_feats, 'nq (b t) c -> (t nq) b c',b=batch_size, t=nf) frame_query_poses = rearrange(frame_query_poses, 'nq (b t) c -> (t nq) b c',b=batch_size, t=nf) temporal_query_feats = self.cross_layers( tgt=temporal_query_feats, # n b c memory=frame_query_feats, # t nqf b c pos=frame_query_poses, query_pos=temporal_query_poses, ) temporal_query_feats = 
self.self_layers( temporal_query_feats, query_pos=temporal_query_poses, ) temporal_query_feats = self.ffn_layers( temporal_query_feats ) return temporal_query_feats # v2 多层 class SS2D_FrameQuery_v2(nn.Module): def __init__(self, configs,): super().__init__() d_model = configs['d_model'] self.homo = SS2D(d_model=configs['d_model'], d_state=configs['d_state'] if 'd_state' in configs else 16, d_conv=configs['d_conv'] if 'd_conv' in configs else 3, expand=configs['expand'] if 'expand' in configs else 2, dt_rank=configs['dt_rank'] if 'dt_rank' in configs else 'auto', dt_min=configs['dt_min'] if 'dt_min' in configs else 0.001, dt_max=configs['dt_max'] if 'dt_max' in configs else 0.1, dt_init=configs['dt_init'] if 'dt_init' in configs else 'random', dt_scale=configs['dt_scale'] if 'dt_scale' in configs else 1.0, dt_init_floor=configs['dt_init_floor'] if 'dt_init_floor' in configs else 1e-4, dropout=configs['dropout'] if 'dropout' in configs else 0, conv_bias=configs['conv_bias'] if 'conv_bias' in configs else True, bias=configs['bias'] if 'bias' in configs else False, ) self.pos_1d = build_position_encoding(position_embedding_name='1d') # t上的position embedding self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(configs['dropout']) def forward(self, frame_query_feats=None, # n bt c frame_query_poses=None, # n bt c # nq上的Position embedding nf=None, **kwargs ): batch_size = frame_query_feats.shape[1] // nf # b tgt2 = frame_query_feats + frame_query_poses tgt2 = rearrange(tgt2, 'n (b t) c -> b t n c',b=batch_size,t=nf).contiguous() sin_poses = self.pos_1d(torch.zeros_like(tgt2[..., 0].permute(0, 2, 1).flatten(0, 1)).bool(), hidden_dim=tgt2.shape[-1]) # bn c t sin_poses = rearrange(sin_poses, '(b n) c t -> b t n c', b=batch_size) tgt2 += sin_poses tgt2 = self.homo(tgt2) # b t n c tgt2 = tgt2.permute(2, 0, 1, 3).flatten(1, 2).contiguous() # n bt c frame_query_feats = frame_query_feats + self.dropout(tgt2) frame_query_feats = self.norm(frame_query_feats) return 
frame_query_feats, None from models.layers.utils import _get_clones @META_ARCH_REGISTRY.register() class FrameQuery_SS2DLayer_v2(nn.Module): def __init__(self, configs, dropout=0.0): super().__init__() d_model = configs['d_model'] n_layers = configs['nlayers'] if 'nlayers' in configs else 1 self.nlayers = n_layers self.self_attn = _get_clones(SS2D_FrameQuery_v2(configs), n_layers) from models.layers.decoder_layers import FFNLayer self.ffn = FFNLayer(d_model=d_model, dim_feedforward=configs['dim_feedforward'], dropout=configs['dropout'],) def forward(self, frame_query_feats, # n bt c frame_query_poses, # n bt c # nq上的Position embedding nf=None, **kwargs): for i in range(self.nlayers): frame_query_feats = self.self_attn[i](frame_query_feats=frame_query_feats, # n bt c frame_query_poses=frame_query_poses, # n bt c # nq上的Position embedding nf=nf,)[0] frame_query_feats = self.ffn(frame_query_feats) return frame_query_feats class Hilbert_2DSelectiveScan(nn.Module): def __init__( self, d_model, d_state=16, # d_state="auto", # 20240109 d_conv=3, expand=2, dt_rank="auto", dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, dropout=0., conv_bias=True, bias=False, device=None, dtype=None, scan_order=None, **kwargs, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.d_model = d_model self.d_state = d_state # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109 self.d_conv = d_conv self.expand = expand self.d_inner = int(self.expand * self.d_model) self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs) self.conv2d = nn.Conv2d( in_channels=self.d_inner, out_channels=self.d_inner, groups=self.d_inner, bias=conv_bias, kernel_size=d_conv, padding=(d_conv - 1) // 2, **factory_kwargs, ) self.act = nn.SiLU() self.x_proj = ( nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), 
bias=False, **factory_kwargs), nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), ) self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=2, N, inner) del self.x_proj self.dt_projs = ( self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs), ) self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=2, inner, rank) self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=2, inner) del self.dt_projs self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=2, merge=True) # (K=2, D, N) self.Ds = self.D_init(self.d_inner, copies=2, merge=True) # (K=2, D, N) # self.selective_scan = selective_scan_fn self.forward_core = self.forward_corev0 self.out_norm = nn.LayerNorm(self.d_inner) self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs) self.dropout = nn.Dropout(dropout) if dropout > 0. 
else None self.scan_order = scan_order @staticmethod def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization dt_init_std = dt_rank**-0.5 * dt_scale if dt_init == "constant": nn.init.constant_(dt_proj.weight, dt_init_std) elif dt_init == "random": nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max dt = torch.exp( torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): dt_proj.bias.copy_(inv_dt) # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit dt_proj.bias._no_reinit = True return dt_proj @staticmethod def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): # S4D real initialization A = repeat( torch.arange(1, d_state + 1, dtype=torch.float32, device=device), "n -> d n", d=d_inner, ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 if copies > 1: A_log = repeat(A_log, "d n -> r d n", r=copies) if merge: A_log = A_log.flatten(0, 1) A_log = nn.Parameter(A_log) A_log._no_weight_decay = True return A_log @staticmethod def D_init(d_inner, copies=1, device=None, merge=True): # D "skip" parameter D = torch.ones(d_inner, device=device) if copies > 1: D = repeat(D, "n1 -> r n1", r=copies) if merge: D = D.flatten(0, 1) D = nn.Parameter(D) # Keep in fp32 D._no_weight_decay = True return D def forward_corev0(self, x: torch.Tensor, hilbert_curve): # LongTensor[int] 按照hw进行flatten之后的hilbert排序 self.selective_scan = selective_scan_fn B, C, H, W = x.shape L = H * W K = 2 if self.scan_order == 
'zigzag': x_hw = x.view(B, -1, L).contiguous() # b c hw xs = torch.stack([x_hw, torch.flip(x_hw, dims=[-1])], dim=1) # (b, k, d, l) elif self.scan_order == 'hilbert': x_hw = x.flatten(2).contiguous() # b c hw x_hil = x_hw.index_select(dim=-1, index=hilbert_curve) xs = torch.stack([x_hil, torch.flip(x_hil, dims=[-1])], dim=1) # (b, k, d, l) else: raise ValueError() x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) xs = xs.float().view(B, -1, L) # (b, k * d, l) dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) Ds = self.Ds.float().view(-1) # (k * d) As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) out_y = self.selective_scan( xs, dts, As, Bs, Cs, Ds, z=None, delta_bias=dt_projs_bias, delta_softplus=True, return_last_state=False, ).view(B, K, -1, L) assert out_y.dtype == torch.float if self.scan_order == 'zigzag': hw_order = out_y[:, 0].contiguous().view(B, -1, H, W).contiguous() rhw_order = torch.flip(out_y[:, 1].contiguous(), dims=[-1]).contiguous() rhw_order = rhw_order.view(B, -1, H, W,).contiguous() return hw_order + rhw_order elif self.scan_order == 'hilbert': hil_order = out_y[:, 0].contiguous() # b c hw rhil_order = torch.flip(out_y[:, 1].contiguous(), dims=[-1]).contiguous() # b c hw sum_out = torch.zeros_like(hil_order) hilbert_curve = repeat(hilbert_curve, 'hw -> b c hw', b=hil_order.shape[0], c=hil_order.shape[1]) assert hil_order.shape == hilbert_curve.shape sum_out.scatter_add_(dim=-1, index=hilbert_curve, src=hil_order) 
sum_out.scatter_add_(dim=-1, index=hilbert_curve, src=rhil_order) sum_out = sum_out.view(B, -1, H, W).contiguous() return sum_out # def forward_corev0(self, x: torch.Tensor, hilbert_curve): # # LongTensor[int] 按照hw进行flatten之后的hilbert排序 # self.selective_scan = selective_scan_fn # B, C, H, W, T = x.shape # L = H * W * T # K = 2 # if self.scan_order == 'zigzag': # x_hw = x.view(B, -1, L).contiguous() # b c hwt # xs = torch.stack([x_hw, torch.flip(x_hw, dims=[-1])], dim=1) # (b, k, d, l) # elif self.scan_order == 'hilbert': # x_hw = x.flatten(2).contiguous() # b c hwt # x_hil = x_hw.index_select(dim=-1, index=hilbert_curve) # xs = torch.stack([x_hil, torch.flip(x_hil, dims=[-1])], dim=1) # (b, k, d, l) # else: # raise ValueError() # x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight) # # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) # dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2) # dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight) # # dts = dts + self.dt_projs_bias.view(1, K, -1, 1) # xs = xs.float().view(B, -1, L) # (b, k * d, l) # dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) # Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l) # Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) # Ds = self.Ds.float().view(-1) # (k * d) # As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) # (k * d, d_state) # dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) # out_y = self.selective_scan( # xs, dts, # As, Bs, Cs, Ds, z=None, # delta_bias=dt_projs_bias, # delta_softplus=True, # return_last_state=False, # ).view(B, K, -1, L) # assert out_y.dtype == torch.float # if self.scan_order == 'zigzag': # hw_order = out_y[:, 0].contiguous().view(B, -1, H, W).contiguous() # rhw_order = torch.flip(out_y[:, 1].contiguous(), dims=[-1]).contiguous() # rhw_order = rhw_order.view(B, -1, H, W,).contiguous() # return hw_order + 
rhw_order # elif self.scan_order == 'hilbert': # hil_order = out_y[:, 0].contiguous() # b c hw # rhil_order = torch.flip(out_y[:, 1].contiguous(), dims=[-1]).contiguous() # b c hw # sum_out = torch.zeros_like(hil_order) # hilbert_curve = repeat(hilbert_curve, 'hwt -> b c hwt', b=hil_order.shape[0], c=hil_order.shape[1]) # assert hil_order.shape == hilbert_curve.shape # sum_out.scatter_add_(dim=-1, index=hilbert_curve, src=hil_order) # sum_out.scatter_add_(dim=-1, index=hilbert_curve, src=rhil_order) # sum_out = sum_out.view(B, -1, H, W).contiguous() # return sum_out def forward(self, x: torch.Tensor, hilbert_curve, **kwargs): B, H, W, C = x.shape xz = self.in_proj(x) x, z = xz.chunk(2, dim=-1) # (b, h, w, d) x = x.permute(0, 3, 1, 2).contiguous() x = self.act(self.conv2d(x)) # (b, d, h, w) y = self.forward_core(x, hilbert_curve=hilbert_curve) # B C h w y = y.permute(0, 2, 3, 1).contiguous() # b h w c y = self.out_norm(y) y = y * F.silu(z) out = self.out_proj(y) if self.dropout is not None: out = self.dropout(out) return out class SS2D_FrameQuery_hilbert(nn.Module): def __init__(self, configs,): super().__init__() d_model = configs['d_model'] self.homo = Hilbert_2DSelectiveScan(d_model=configs['d_model'], d_state=configs['d_state'] if 'd_state' in configs else 16, d_conv=configs['d_conv'] if 'd_conv' in configs else 3, expand=configs['expand'] if 'expand' in configs else 2, dt_rank=configs['dt_rank'] if 'dt_rank' in configs else 'auto', dt_min=configs['dt_min'] if 'dt_min' in configs else 0.001, dt_max=configs['dt_max'] if 'dt_max' in configs else 0.1, dt_init=configs['dt_init'] if 'dt_init' in configs else 'random', dt_scale=configs['dt_scale'] if 'dt_scale' in configs else 1.0, dt_init_floor=configs['dt_init_floor'] if 'dt_init_floor' in configs else 1e-4, dropout=configs['dropout'] if 'dropout' in configs else 0, conv_bias=configs['conv_bias'] if 'conv_bias' in configs else True, bias=configs['bias'] if 'bias' in configs else False, 
scan_order=configs['scan_order'] ) self.pos_1d = build_position_encoding(position_embedding_name='1d') # t上的position embedding self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(configs['dropout']) def forward(self, frame_query_feats=None, # n bt c frame_query_poses=None, # n bt c # nq上的Position embedding hilbert_curve=None, nf=None, **kwargs ): batch_size = frame_query_feats.shape[1] // nf # b tgt2 = frame_query_feats + frame_query_poses tgt2 = rearrange(tgt2, 'n (b t) c -> b t n c',b=batch_size,t=nf).contiguous() sin_poses = self.pos_1d(torch.zeros_like(tgt2[..., 0].permute(0, 2, 1).flatten(0, 1)).bool(), hidden_dim=tgt2.shape[-1]) # bn c t sin_poses = rearrange(sin_poses, '(b n) c t -> b t n c', b=batch_size) tgt2 += sin_poses tgt2 = self.homo(tgt2, hilbert_curve=hilbert_curve) # b t n c tgt2 = tgt2.permute(2, 0, 1, 3).flatten(1, 2).contiguous() # n bt c frame_query_feats = frame_query_feats + self.dropout(tgt2) frame_query_feats = self.norm(frame_query_feats) return frame_query_feats, None from models.layers.utils import _get_clones @META_ARCH_REGISTRY.register() class FrameQuery_SS2DLayer_hilbert(nn.Module): def __init__(self, configs, dropout=0.0): super().__init__() d_model = configs['d_model'] n_layers = configs['nlayers'] if 'nlayers' in configs else 1 self.nlayers = n_layers self.self_attn = _get_clones(SS2D_FrameQuery_hilbert(configs), n_layers) from models.layers.decoder_layers import FFNLayer self.ffn = FFNLayer(d_model=d_model, dim_feedforward=configs['dim_feedforward'], dropout=configs['dropout'],) def forward(self, frame_query_feats, # n bt c frame_query_poses, # n bt c # nq上的Position embedding video_aux_dict=None, **kwargs): for i in range(self.nlayers): frame_query_feats = self.self_attn[i](frame_query_feats=frame_query_feats, # n bt c frame_query_poses=frame_query_poses, # n bt c # nq上的Position embedding nf=video_aux_dict['nf'], hilbert_curve=video_aux_dict['hilbert_curve'])[0] frame_query_feats = self.ffn(frame_query_feats) return 
frame_query_feats ================================================ FILE: models/encoder/ops/modules/ms_deform_attn.py ================================================ # Modify for sample points visualization # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import print_function from __future__ import division import warnings import math import torch from torch import nn import torch.nn.functional as F from torch.nn.init import xavier_uniform_, constant_ from ..functions import MSDeformAttnFunction def _is_power_of_2(n): if (not isinstance(n, int)) or (n < 0): raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) return (n & (n-1) == 0) and n != 0 class MSDeformAttn(nn.Module): def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): """ Multi-Scale Deformable Attention Module :param d_model hidden dimension :param n_levels number of feature levels :param n_heads number of attention heads :param n_points number of sampling points per attention head per feature level """ super().__init__() if d_model % n_heads != 0: raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) _d_per_head = d_model // n_heads # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation if not _is_power_of_2(_d_per_head): warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power 
of 2 " "which is more efficient in our CUDA implementation.") self.im2col_step = 512 self.d_model = d_model self.n_levels = n_levels self.n_heads = n_heads self.n_points = n_points self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) self.value_proj = nn.Linear(d_model, d_model) self.output_proj = nn.Linear(d_model, d_model) self._reset_parameters() def _reset_parameters(self): constant_(self.sampling_offsets.weight.data, 0.) thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) for i in range(self.n_points): grid_init[:, :, i, :] *= i + 1 with torch.no_grad(): self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) constant_(self.attention_weights.weight.data, 0.) constant_(self.attention_weights.bias.data, 0.) xavier_uniform_(self.value_proj.weight.data) constant_(self.value_proj.bias.data, 0.) xavier_uniform_(self.output_proj.weight.data) constant_(self.output_proj.bias.data, 0.) 
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): """ :param query (N, Length_{query}, C) :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements :return output (N, Length_{query}, C) """ N, Len_q, _ = query.shape N, Len_in, _ = input_flatten.shape assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in value = self.value_proj(input_flatten) if input_padding_mask is not None: value = value.masked_fill(input_padding_mask[..., None], float(0)) value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) # N, Len_q, n_heads, n_levels, n_points, 2 if reference_points.shape[-1] == 2: offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) sampling_locations = reference_points[:, :, None, :, None, :] \ + sampling_offsets / offset_normalizer[None, None, None, :, None, :] elif reference_points.shape[-1] == 4: sampling_locations = reference_points[:, :, None, :, None, :2] \ + sampling_offsets / self.n_points * 
reference_points[:, :, None, :, None, 2:] * 0.5 else: raise ValueError( 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) output = MSDeformAttnFunction.apply( value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) output = self.output_proj(output) return output, sampling_locations, attention_weights ================================================ FILE: models/encoder/ops/setup.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ import os import glob import torch from torch.utils.cpp_extension import CUDA_HOME from torch.utils.cpp_extension import CppExtension from torch.utils.cpp_extension import CUDAExtension from setuptools import find_packages from setuptools import setup requirements = ["torch", "torchvision"] def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, "src") main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) sources = main_file + source_cpu extension = CppExtension extra_compile_args = {"cxx": []} define_macros = [] if torch.cuda.is_available() and CUDA_HOME is not None: extension = CUDAExtension sources += source_cuda define_macros += [("WITH_CUDA", None)] extra_compile_args["nvcc"] = [ 
"-DCUDA_HAS_FP16=1", "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", ] else: raise NotImplementedError('Cuda is not availabel') sources = [os.path.join(extensions_dir, s) for s in sources] include_dirs = [extensions_dir] ext_modules = [ extension( "MultiScaleDeformableAttention", sources, include_dirs=include_dirs, define_macros=define_macros, extra_compile_args=extra_compile_args, ) ] return ext_modules setup( name="MultiScaleDeformableAttention", version="1.0", author="Weijie Su", url="https://github.com/fundamentalvision/Deformable-DETR", description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", packages=find_packages(exclude=("configs", "tests",)), ext_modules=get_extensions(), cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, ) ================================================ FILE: models/encoder/ops/src/cpu/ms_deform_attn_cpu.cpp ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ #include #include #include at::Tensor ms_deform_attn_cpu_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { AT_ERROR("Not implement on cpu"); } std::vector ms_deform_attn_cpu_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { AT_ERROR("Not implement on cpu"); } ================================================ FILE: models/encoder/ops/src/cpu/ms_deform_attn_cpu.h ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ #pragma once #include at::Tensor ms_deform_attn_cpu_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step); std::vector ms_deform_attn_cpu_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step); ================================================ FILE: models/encoder/ops/src/cuda/ms_deform_attn_cuda.cu ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ #include #include "cuda/ms_deform_im2col_cuda.cuh" #include #include #include #include at::Tensor ms_deform_attn_cuda_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); const int batch = value.size(0); const int spatial_size = value.size(1); const int num_heads = value.size(2); const int channels = value.size(3); const int num_levels = spatial_shapes.size(0); const int num_query = sampling_loc.size(1); const int num_point = sampling_loc.size(4); const int im2col_step_ = std::min(batch, im2col_step); AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); auto output = at::zeros({batch, num_query, 
num_heads, channels}, value.options()); const int batch_n = im2col_step_; auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); auto per_value_size = spatial_size * num_heads * channels; auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; for (int n = 0; n < batch/im2col_step_; ++n) { auto columns = output_n.select(0, n); AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), value.data() + n * im2col_step_ * per_value_size, spatial_shapes.data(), level_start_index.data(), sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, attn_weight.data() + n * im2col_step_ * per_attn_weight_size, batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, columns.data()); })); } output = output.view({batch, num_query, num_heads*channels}); return output; } std::vector ms_deform_attn_cuda_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must 
be a CUDA tensor"); AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); const int batch = value.size(0); const int spatial_size = value.size(1); const int num_heads = value.size(2); const int channels = value.size(3); const int num_levels = spatial_shapes.size(0); const int num_query = sampling_loc.size(1); const int num_point = sampling_loc.size(4); const int im2col_step_ = std::min(batch, im2col_step); AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); auto grad_value = at::zeros_like(value); auto grad_sampling_loc = at::zeros_like(sampling_loc); auto grad_attn_weight = at::zeros_like(attn_weight); const int batch_n = im2col_step_; auto per_value_size = spatial_size * num_heads * channels; auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); for (int n = 0; n < batch/im2col_step_; ++n) { auto grad_output_g = grad_output_n.select(0, n); AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), grad_output_g.data(), value.data() + n * im2col_step_ * per_value_size, spatial_shapes.data(), level_start_index.data(), sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, attn_weight.data() + n * im2col_step_ * per_attn_weight_size, batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, grad_value.data() + n * im2col_step_ * per_value_size, grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); })); } return { grad_value, grad_sampling_loc, 
grad_attn_weight }; } ================================================ FILE: models/encoder/ops/src/cuda/ms_deform_attn_cuda.h ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. * Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ #pragma once #include at::Tensor ms_deform_attn_cuda_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step); std::vector ms_deform_attn_cuda_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step); ================================================ FILE: models/encoder/ops/src/cuda/ms_deform_im2col_cuda.cuh ================================================ /*! ************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************** * Modified from DCN (https://github.com/msracver/Deformable-ConvNets) * Copyright (c) 2018 Microsoft ************************************************************************** */

// NOTE(review): in this extraction everything written inside angle brackets is
// missing: the six header names below, every `template` parameter list, and every
// kernel launch configuration (shown as `<<>>`). Restore them from the real source
// before building.
#include
#include
#include
#include
#include
#include

// Grid-stride loop: each thread starts at its global thread index and advances by
// the total number of launched threads (blockDim.x * gridDim.x) until it reaches n.
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
      i < (n); \
      i += blockDim.x * gridDim.x)

// Maximum threads per block used by the launchers at the bottom of this file.
const int CUDA_NUM_THREADS = 1024;

// Number of blocks needed to cover N work items with num_threads threads per block
// (ceiling division).
inline int GET_BLOCKS(const int N, const int num_threads)
{
  return (N + num_threads - 1) / num_threads;
}

// Bilinearly samples the (head m, channel c) entry of one feature level at the
// fractional pixel position (h, w) and returns the interpolated value.
// Element (y, x, m, c) is read at y * width * nheads * channels
// + x * nheads * channels + m * channels + c, i.e. a row-major
// (height, width, nheads, channels) layout for the level.
// A corner whose low index is < 0 or whose high index is > size - 1 contributes 0.
// NOTE(review): template parameter list (presumably for scalar_t) stripped by extraction.
template __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
                                                            const int &height, const int &width, const int &nheads, const int &channels,
                                                            const scalar_t &h, const scalar_t &w, const int &m, const int &c)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // Fractional parts (lh, lw) and their complements (hh, hw) give the four
  // bilinear corner weights below.
  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  // v1: top-left, v2: top-right, v3: bottom-left, v4: bottom-right corner.
  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
  }

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}

// Backward of ms_deform_attn_im2col_bilinear for one (head m, channel c) sample.
// - grad_value: atomically accumulates w_i * top_grad * attn_weight into each
//   in-range corner (several threads may hit the same value element).
// - grad_attn_weight: STORES top_grad * (interpolated value).
// - grad_sampling_loc[0] / [1]: STORE the w / h location gradients. They are scaled
//   by width / height because callers compute w = loc_w * width - 0.5 and
//   h = loc_h * height - 0.5.
// Both stores are plain writes (no accumulation), so callers point them at
// per-thread scratch slots.
// NOTE(review): template parameter list stripped by extraction.
template __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
                                                        const int &height, const int &width, const int &nheads, const int &channels,
                                                        const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                                        const scalar_t &top_grad,
                                                        const scalar_t &attn_weight,
                                                        scalar_t* &grad_value,
                                                        scalar_t* grad_sampling_loc,
                                                        scalar_t* grad_attn_weight)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  // Upstream gradient already multiplied by this sample's attention weight.
  const scalar_t top_grad_value = top_grad * attn_weight;
  // d(val)/dh and d(val)/dw, accumulated corner by corner.
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  *grad_attn_weight = top_grad * val;
  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
}

// Same computation as ms_deform_attn_col2im_bilinear, except the attention-weight
// and sampling-location gradients are atomically ADDED directly into their
// destination (global memory) instead of being stored into scratch slots.
// NOTE(review): template parameter list stripped by extraction.
template __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
                                                           const int &height, const int &width, const int &nheads, const int &channels,
                                                           const scalar_t &h, const scalar_t &w, const int &m, const int &c,
                                                           const scalar_t &top_grad,
                                                           const scalar_t &attn_weight,
                                                           scalar_t* &grad_value,
                                                           scalar_t* grad_sampling_loc,
                                                           scalar_t* grad_attn_weight)
{
  const int h_low = floor(h);
  const int w_low = floor(w);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  const scalar_t lh = h - h_low;
  const scalar_t lw = w - w_low;
  const scalar_t hh = 1 - lh, hw = 1 - lw;

  const int w_stride = nheads * channels;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
  const int base_ptr = m * channels + c;

  const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  const scalar_t top_grad_value = top_grad * attn_weight;
  scalar_t grad_h_weight = 0, grad_w_weight = 0;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
  {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_value+ptr1, w1*top_grad_value);
  }
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
  {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_value+ptr2, w2*top_grad_value);
  }
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
  {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_value+ptr3, w3*top_grad_value);
  }
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
  {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_value+ptr4, w4*top_grad_value);
  }

  const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  atomicAdd(grad_attn_weight, top_grad * val);
  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
}

// Forward kernel: one work item per output element, index in [0, n).
// The index decomposes as ((b_col * num_query + q_col) * num_heads + m_col) * channels + c_col.
// sampling_index = index / channels addresses the (b, q, m) triple.
// For every level l and point p the kernel reads:
//   - attention weight at (sampling_index * num_levels + l) * num_point + p;
//   - sampling location as a (w, h) pair at twice that offset;
//   - level shape (h, w) from data_spatial_shapes[2l], [2l+1];
//   - level start row from data_level_start_index[l] in the flattened spatial axis,
//     where each batch holds spatial_size rows of num_heads * channels values.
// The weighted sum of bilinear samples is written to data_col[index].
// q_col is decoded but not otherwise used.
// NOTE(review): all index arithmetic is 32-bit int; totals >= 2^31 would overflow.
// NOTE(review): template parameter list stripped by extraction.
template __global__ void ms_deformable_im2col_gpu_kernel(const int n,
                                                         const scalar_t *data_value,
                                                         const int64_t *data_spatial_shapes,
                                                         const int64_t *data_level_start_index,
                                                         const scalar_t *data_sampling_loc,
                                                         const scalar_t *data_attn_weight,
                                                         const int batch_size,
                                                         const int spatial_size,
                                                         const int num_heads,
                                                         const int channels,
                                                         const int num_levels,
                                                         const int num_query,
                                                         const int num_point,
                                                         scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    scalar_t *data_col_ptr = data_col + index;
    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
    scalar_t col = 0;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        // Scale the location to this level's pixel grid with a half-pixel offset.
        // loc is presumably normalized to [0, 1]; confirm against the caller.
        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;

        // Points entirely outside the level contribute nothing.
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
        }

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
      }
    }
    *data_col_ptr = col;
  }
}

// Backward kernel, compile-time block size, serial shared-memory reduction.
// Each thread handles one (b, q, m, c) element as in the forward kernel.
// - grad_value is updated with atomics inside ms_deform_attn_col2im_bilinear.
// - Per (level, point), each thread writes its partial attention-weight and
//   location gradients into static shared arrays sized by blockSize.
// - Thread 0 then sums the blockSize entries serially and STORES the result
//   (no atomic). This relies on every (b, q, m) being owned by exactly one block,
//   which holds when blockDim.x == blockSize == channels; the launcher uses
//   channels threads per block for this kernel.
// NOTE(review): grad_sampling_loc / grad_attn_weight are advanced in place and
// never reset, so this is only correct if each thread runs at most one
// CUDA_KERNEL_LOOP iteration (grid covers all n items).
// NOTE(review): the __syncthreads() calls inside the loop also require every
// thread of a block to run the same number of iterations.
// NOTE(review): template parameter list (scalar_t and blockSize) stripped by extraction.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
                                                                                      const scalar_t *grad_col,
                                                                                      const scalar_t *data_value,
                                                                                      const int64_t *data_spatial_shapes,
                                                                                      const int64_t *data_level_start_index,
                                                                                      const scalar_t *data_sampling_loc,
                                                                                      const scalar_t *data_attn_weight,
                                                                                      const int batch_size,
                                                                                      const int spatial_size,
                                                                                      const int num_heads,
                                                                                      const int channels,
                                                                                      const int num_levels,
                                                                                      const int num_query,
                                                                                      const int num_point,
                                                                                      scalar_t *grad_value,
                                                                                      scalar_t *grad_sampling_loc,
                                                                                      scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        // Zero this thread's scratch slots so out-of-range points reduce as 0.
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          // The loop variable shadows the outer `tid` and walks slots 1..blockSize-1.
          for (unsigned int tid = 1; tid < blockSize; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }

          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        // Keep the scratch slots from being overwritten before thread 0 reads them.
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Same as ..._blocksize_aware_reduce_v1, but the shared-memory sum is a
// log2(blockSize) tree reduction. Halving s each step only covers every entry
// when blockSize is a power of two; the launcher uses it for channels 64..1024
// (powers of two). Thread 0 STORES the result (one block per (b, q, m)).
// NOTE(review): same in-place pointer-advance and __syncthreads-in-loop caveats
// as ..._reduce_v1; template parameter list stripped by extraction.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
                                                                                      const scalar_t *grad_col,
                                                                                      const scalar_t *data_value,
                                                                                      const int64_t *data_spatial_shapes,
                                                                                      const int64_t *data_level_start_index,
                                                                                      const scalar_t *data_sampling_loc,
                                                                                      const scalar_t *data_attn_weight,
                                                                                      const int batch_size,
                                                                                      const int spatial_size,
                                                                                      const int num_heads,
                                                                                      const int channels,
                                                                                      const int num_levels,
                                                                                      const int num_query,
                                                                                      const int num_point,
                                                                                      scalar_t *grad_value,
                                                                                      scalar_t *grad_sampling_loc,
                                                                                      scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
    __shared__ scalar_t cache_grad_attn_weight[blockSize];
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        // Tree reduction: each step folds the upper half of the active range
        // onto the lower half.
        for (unsigned int s=blockSize/2; s>0; s>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Runtime-block-size variant of ..._blocksize_aware_reduce_v1. Scratch space lives
// in dynamic shared memory: 2 * blockDim.x location slots followed by blockDim.x
// attention-weight slots, so the launch must reserve at least
// 3 * blockDim.x * sizeof(scalar_t) bytes. Thread 0 sums serially and STORES the
// result. The launcher uses it for channels < 64 that are not powers of two.
// NOTE(review): launch config and template parameter list stripped by extraction;
// same in-place pointer-advance caveat as ..._reduce_v1.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
                                                                      const scalar_t *grad_col,
                                                                      const scalar_t *data_value,
                                                                      const int64_t *data_spatial_shapes,
                                                                      const int64_t *data_level_start_index,
                                                                      const scalar_t *data_sampling_loc,
                                                                      const scalar_t *data_attn_weight,
                                                                      const int batch_size,
                                                                      const int spatial_size,
                                                                      const int num_heads,
                                                                      const int channels,
                                                                      const int num_levels,
                                                                      const int num_query,
                                                                      const int num_point,
                                                                      scalar_t *grad_value,
                                                                      scalar_t *grad_sampling_loc,
                                                                      scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();
        if (tid == 0)
        {
          scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
          int sid=2;
          for (unsigned int tid = 1; tid < blockDim.x; ++tid)
          {
            _grad_w += cache_grad_sampling_loc[sid];
            _grad_h += cache_grad_sampling_loc[sid + 1];
            _grad_a += cache_grad_attn_weight[tid];
            sid += 2;
          }

          *grad_sampling_loc = _grad_w;
          *(grad_sampling_loc + 1) = _grad_h;
          *grad_attn_weight = _grad_a;
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Runtime-block-size tree reduction over dynamic shared memory (same
// 3 * blockDim.x scratch layout as ..._shm_reduce_v1). Each step uses
// s = floor(spre / 2). When spre (the active count) is odd, the leftover element
// at index 2s is folded in by thread 0 via the `tid + (s << 1) < spre` branch, so
// block sizes that are not powers of two are summed completely. Thread 0 STORES
// the result. The launcher uses it for channels >= 64 that are not powers of two.
// NOTE(review): launch config and template parameter list stripped by extraction;
// same in-place pointer-advance caveat as ..._reduce_v1.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
                                                                      const scalar_t *grad_col,
                                                                      const scalar_t *data_value,
                                                                      const int64_t *data_spatial_shapes,
                                                                      const int64_t *data_level_start_index,
                                                                      const scalar_t *data_sampling_loc,
                                                                      const scalar_t *data_attn_weight,
                                                                      const int batch_size,
                                                                      const int spatial_size,
                                                                      const int num_heads,
                                                                      const int channels,
                                                                      const int num_levels,
                                                                      const int num_query,
                                                                      const int num_point,
                                                                      scalar_t *grad_value,
                                                                      scalar_t *grad_sampling_loc,
                                                                      scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            // Odd active count: fold in the unpaired last element.
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        if (tid == 0)
        {
          *grad_sampling_loc = cache_grad_sampling_loc[0];
          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
          *grad_attn_weight = cache_grad_attn_weight[0];
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Identical to ..._shm_reduce_v2 except thread 0 atomically ADDS the block's
// partial sums into global memory. The launcher uses it when channels > 1024 and
// is a multiple of 1024 (1024 threads per block), so several blocks contribute
// to the same (b, q, m) gradient. The destination gradients must be
// zero-initialized by the caller.
// NOTE(review): launch config and template parameter list stripped by extraction;
// same in-place pointer-advance caveat as ..._reduce_v1.
template __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
                                                                                   const scalar_t *grad_col,
                                                                                   const scalar_t *data_value,
                                                                                   const int64_t *data_spatial_shapes,
                                                                                   const int64_t *data_level_start_index,
                                                                                   const scalar_t *data_sampling_loc,
                                                                                   const scalar_t *data_attn_weight,
                                                                                   const int batch_size,
                                                                                   const int spatial_size,
                                                                                   const int num_heads,
                                                                                   const int channels,
                                                                                   const int num_levels,
                                                                                   const int num_query,
                                                                                   const int num_point,
                                                                                   scalar_t *grad_value,
                                                                                   scalar_t *grad_sampling_loc,
                                                                                   scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    extern __shared__ int _s[];
    scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
    scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
    unsigned int tid = threadIdx.x;
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
        *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
        *(cache_grad_attn_weight+threadIdx.x)=0;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
        }

        __syncthreads();

        for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
        {
          if (tid < s) {
            const unsigned int xid1 = tid << 1;
            const unsigned int xid2 = (tid + s) << 1;
            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
            cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
            if (tid + (s << 1) < spre)
            {
              cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
              cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
              cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
            }
          }
          __syncthreads();
        }

        // Several blocks share this (b, q, m), so combine their partial sums atomically.
        if (tid == 0)
        {
          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
        }
        __syncthreads();

        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Backward kernel without shared memory: every thread atomically adds its
// attention-weight and location gradient contributions straight into global
// memory via ms_deform_attn_col2im_bilinear_gm. The launcher uses it when
// channels > 1024 and is not a multiple of 1024.
// NOTE(review): template parameter list stripped by extraction; same in-place
// pointer-advance caveat as ..._reduce_v1.
template __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
                                                            const scalar_t *grad_col,
                                                            const scalar_t *data_value,
                                                            const int64_t *data_spatial_shapes,
                                                            const int64_t *data_level_start_index,
                                                            const scalar_t *data_sampling_loc,
                                                            const scalar_t *data_attn_weight,
                                                            const int batch_size,
                                                            const int spatial_size,
                                                            const int num_heads,
                                                            const int channels,
                                                            const int num_levels,
                                                            const int num_query,
                                                            const int num_point,
                                                            scalar_t *grad_value,
                                                            scalar_t *grad_sampling_loc,
                                                            scalar_t *grad_attn_weight)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    int _temp = index;
    const int c_col = _temp % channels;
    _temp /= channels;
    const int sampling_index = _temp;
    const int m_col = _temp % num_heads;
    _temp /= num_heads;
    const int q_col = _temp % num_query;
    _temp /= num_query;
    const int b_col = _temp;

    const scalar_t top_grad = grad_col[index];

    int data_weight_ptr = sampling_index * num_levels * num_point;
    int data_loc_w_ptr = data_weight_ptr << 1;
    const int grad_sampling_ptr = data_weight_ptr;
    grad_sampling_loc += grad_sampling_ptr << 1;
    grad_attn_weight += grad_sampling_ptr;
    const int grad_weight_stride = 1;
    const int grad_loc_stride = 2;
    const int qid_stride = num_heads * channels;
    const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;

    for (int l_col=0; l_col < num_levels; ++l_col)
    {
      const int level_start_id = data_level_start_index[l_col];
      const int spatial_h_ptr = l_col << 1;
      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
      const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
      const scalar_t *data_value_ptr = data_value + value_ptr_offset;
      scalar_t *grad_value_ptr = grad_value + value_ptr_offset;

      for (int p_col=0; p_col < num_point; ++p_col)
      {
        const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
        const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
        const scalar_t weight = data_attn_weight[data_weight_ptr];

        const scalar_t h_im = loc_h * spatial_h - 0.5;
        const scalar_t w_im = loc_w * spatial_w - 0.5;
        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
        {
          ms_deform_attn_col2im_bilinear_gm(
            data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
            top_grad, weight, grad_value_ptr,
            grad_sampling_loc, grad_attn_weight);
        }
        data_weight_ptr += 1;
        data_loc_w_ptr += 2;
        grad_attn_weight += grad_weight_stride;
        grad_sampling_loc += grad_loc_stride;
      }
    }
  }
}

// Host launcher for the forward kernel on `stream`.
// It launches one work item per output element
// (batch_size * num_query * num_heads * channels) with CUDA_NUM_THREADS threads
// per block.
// NOTE(review): the launch configuration (shown as `<<>>`) and the template
// parameter list were stripped by extraction. Launch errors are only printf'd,
// not reported to the caller. num_actual_kernels duplicates num_kernels.
template void ms_deformable_im2col_cuda(cudaStream_t stream,
                                        const scalar_t* data_value,
                                        const int64_t* data_spatial_shapes,
                                        const int64_t* data_level_start_index,
                                        const scalar_t* data_sampling_loc,
                                        const scalar_t* data_attn_weight,
                                        const int batch_size,
                                        const int spatial_size,
                                        const int num_heads,
                                        const int channels,
                                        const int num_levels,
                                        const int num_query,
                                        const int num_point,
                                        scalar_t* data_col)
{
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  const int num_threads = CUDA_NUM_THREADS;
  ms_deformable_im2col_gpu_kernel <<>>(
      num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
      batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }
}

// Host launcher for the backward pass. Threads per block = min(channels, 1024).
// The kernel is chosen by channel count:
//   channels > 1024, multiple of 1024   -> ..._shm_reduce_v2_multi_blocks (atomic cross-block sums)
//   channels > 1024, otherwise          -> ..._gm (pure global-memory atomics)
//   channels in {1, 2, 4, 8, 16, 32}    -> ..._shm_blocksize_aware_reduce_v1 (serial sum)
//   channels in {64, 128, ..., 1024}    -> ..._shm_blocksize_aware_reduce_v2 (tree sum)
//   other channels < 64                 -> ..._shm_reduce_v1
//   other channels >= 64                -> ..._shm_reduce_v2
// The dynamic-shared-memory kernels need 3 * num_threads * sizeof(scalar_t) bytes.
// NOTE(review): every launch configuration (shown as `<<>>`) and template argument
// (including the blockSize values) was stripped by extraction. Launch errors are
// only printf'd, not reported to the caller.
template void ms_deformable_col2im_cuda(cudaStream_t stream,
                                        const scalar_t* grad_col,
                                        const scalar_t* data_value,
                                        const int64_t * data_spatial_shapes,
                                        const int64_t * data_level_start_index,
                                        const scalar_t * data_sampling_loc,
                                        const scalar_t * data_attn_weight,
                                        const int batch_size,
                                        const int spatial_size,
                                        const int num_heads,
                                        const int channels,
                                        const int num_levels,
                                        const int num_query,
                                        const int num_point,
                                        scalar_t* grad_value,
                                        scalar_t* grad_sampling_loc,
                                        scalar_t* grad_attn_weight)
{
  const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
  const int num_kernels = batch_size * num_query * num_heads * channels;
  const int num_actual_kernels = batch_size * num_query * num_heads * channels;
  if (channels > 1024)
  {
    if ((channels & 1023) == 0)
    {
      ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks <<>>(
          num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
          data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
          channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
          grad_attn_weight);
    }
    else
    {
      ms_deformable_col2im_gpu_kernel_gm <<>>(
          num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
          data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
          channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
          grad_attn_weight);
    }
  }
  else{
    switch(channels)
    {
      case 1:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>(
            num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
            data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
            channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
            grad_attn_weight);
        break;
      case 2:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>(
            num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
            data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
            channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
            grad_attn_weight);
        break;
      case 4:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>(
            num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
            data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
            channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
            grad_attn_weight);
        break;
      case 8:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>(
            num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
            data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
            channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
            grad_attn_weight);
        break;
      case 16:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>(
            num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
            data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
            channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
            grad_attn_weight);
        break;
      case 32:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1 <<>>(
            num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
            data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
            channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
            grad_attn_weight);
        break;
      case 64:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>(
            num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
            data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
            channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
            grad_attn_weight);
        break;
      case 128:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>(
            num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
            data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
            channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
            grad_attn_weight);
        break;
      case 256:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>(
            num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
            data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
            channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
            grad_attn_weight);
        break;
      case 512:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>(
            num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
            data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
            channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
            grad_attn_weight);
        break;
      case 1024:
        ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2 <<>>(
            num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
            data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
            channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
            grad_attn_weight);
        break;
      default:
        // Channel counts that are not powers of two use the runtime-sized kernels.
        if (channels < 64)
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v1 <<>>(
              num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
              data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
              channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
              grad_attn_weight);
        }
        else
        {
          ms_deformable_col2im_gpu_kernel_shm_reduce_v2 <<>>(
              num_kernels, grad_col, data_value, data_spatial_shapes, data_level_start_index,
              data_sampling_loc, data_attn_weight, batch_size, spatial_size, num_heads,
              channels, num_levels, num_query, num_point, grad_value, grad_sampling_loc,
              grad_attn_weight);
        }
    }
  }
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }
}
================================================ FILE: models/encoder/ops/src/ms_deform_attn.h ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ #pragma once #include "cpu/ms_deform_attn_cpu.h" #ifdef WITH_CUDA #include "cuda/ms_deform_attn_cuda.h" #endif at::Tensor ms_deform_attn_forward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const int im2col_step) { if (value.type().is_cuda()) { #ifdef WITH_CUDA return ms_deform_attn_cuda_forward( value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } std::vector ms_deform_attn_backward( const at::Tensor &value, const at::Tensor &spatial_shapes, const at::Tensor &level_start_index, const at::Tensor &sampling_loc, const at::Tensor &attn_weight, const at::Tensor &grad_output, const int im2col_step) { if (value.type().is_cuda()) { #ifdef WITH_CUDA return ms_deform_attn_cuda_backward( value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); #else AT_ERROR("Not compiled with GPU support"); #endif } AT_ERROR("Not implemented on the CPU"); } ================================================ FILE: models/encoder/ops/src/vision.cpp ================================================ /*! ************************************************************************************************** * Deformable DETR * Copyright (c) 2020 SenseTime. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 [see LICENSE for details] ************************************************************************************************** * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 ************************************************************************************************** */ #include "ms_deform_attn.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); } ================================================ FILE: models/encoder/ops/test.py ================================================ # ------------------------------------------------------------------------------------------------ # Deformable DETR # Copyright (c) 2020 SenseTime. All Rights Reserved. # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------------------------------ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ from __future__ import absolute_import from __future__ import print_function from __future__ import division import time import torch import torch.nn as nn from torch.autograd import gradcheck from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch N, M, D = 1, 2, 2 Lq, L, P = 2, 2, 2 shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) S = sum([(H*W).item() for H, W in shapes]) torch.manual_seed(3) @torch.no_grad() def check_forward_equal_with_pytorch_double(): value = torch.rand(N, S, M, D).cuda() * 0.01 sampling_locations = torch.rand(N, Lq, M, L, P, 
2).cuda() attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) im2col_step = 2 output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() fwdok = torch.allclose(output_cuda, output_pytorch) max_abs_err = (output_cuda - output_pytorch).abs().max() max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') @torch.no_grad() def check_forward_equal_with_pytorch_float(): value = torch.rand(N, S, M, D).cuda() * 0.01 sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) im2col_step = 2 output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) max_abs_err = (output_cuda - output_pytorch).abs().max() max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): value = torch.rand(N, S, M, channels).cuda() * 0.01 sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) im2col_step = 2 func = MSDeformAttnFunction.apply value.requires_grad = grad_value sampling_locations.requires_grad = grad_sampling_loc attention_weights.requires_grad = grad_attn_weight gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) print(f'* {gradok} check_gradient_numerical(D={channels})') if __name__ == '__main__': check_forward_equal_with_pytorch_double() check_forward_equal_with_pytorch_float() for channels in [30, 32, 64, 71, 1025, 2048, 3096]: check_gradient_numerical(channels, True, True, True) ================================================ FILE: models/layers/anyc_trans.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from einops import repeat, rearrange, reduce from typing import Any, Optional from torch import Tensor from .utils import _get_activation_fn class MLP(nn.Module): """ Very simple multi-layer perceptron (also called FFN)""" def __init__(self, input_dim, hidden_dim, output_dim, num_layers): super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) def forward(self, x): for i, layer in enumerate(self.layers): x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) return x class Linear_NormAct(nn.Linear): def __init__(self, *args, **kwargs): norm = kwargs.pop("norm", None) activation = kwargs.pop("activation", None) super().__init__(*args, **kwargs) out_features = kwargs['out_features'] if norm == None: self.norm = None elif norm == 'ln': self.norm = nn.LayerNorm(out_features) else: raise ValueError() self.activation = _get_activation_fn(activation) def forward(self, x): x = F.linear(x, self.weight, self.bias) if self.norm is not None: x = self.norm(x) if self.activation is 
not None: x = self.activation(x) return x class Conv2d_NormAct(torch.nn.Conv2d): def __init__(self, *args, **kwargs): norm = kwargs.pop("norm", None) activation = kwargs.pop("activation", None) super().__init__(*args, **kwargs) out_dim = kwargs['out_channels'] if norm is None: self.norm = None elif norm == 'bn2d': # b c h w self.norm = nn.BatchNorm2d(out_dim) elif 'gn' in norm: # b c .. n_groups = int(norm.split('_')[-1]) self.norm = nn.GroupNorm(n_groups, out_dim) else: raise ValueError() self.activation = _get_activation_fn(activation) def forward(self, x): # b c h w x = self._conv_forward(x, self.weight, self.bias) if self.norm is not None: x = self.norm(x) if self.activation is not None: x = self.activation(x) return x class Conv3d_NormAct(torch.nn.Conv3d): def __init__(self, *args, **kwargs): norm = kwargs.pop("norm", None) activation = kwargs.pop("activation", None) super().__init__(*args, **kwargs) out_dim = kwargs['out_channels'] if norm == None: self.norm = None elif 'gn' in norm: n_groups = int(norm.split('_')[-1]) self.norm = nn.GroupNorm(n_groups, out_dim) else: raise ValueError() self.activation = _get_activation_fn(activation) def forward(self, x): x = self._conv_forward(x, self.weight, self.bias) if self.norm is not None: x = self.norm(x) if self.activation is not None: x = self.activation(x) return x ================================================ FILE: models/layers/decoder_layers.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from einops import repeat, rearrange, reduce from typing import Any, Optional from torch import Tensor from .utils import _get_activation_fn class SelfAttentionLayer(nn.Module): def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): super().__init__() self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) self.activation = 
_get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): q = k = self.with_pos_embed(tgt, query_pos) tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) return tgt def forward_pre(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): tgt2 = self.norm(tgt) q = k = self.with_pos_embed(tgt2, query_pos) tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] tgt = tgt + self.dropout(tgt2) return tgt def forward(self, tgt, tgt_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, query_pos: Optional[Tensor] = None,): if self.normalize_before: return self.forward_pre(tgt, tgt_mask, tgt_key_padding_mask, query_pos) return self.forward_post(tgt, tgt_mask, tgt_key_padding_mask, query_pos) class CrossAttentionLayer(nn.Module): def __init__(self, d_model, nhead, dropout=0.0, activation="relu", normalize_before=False): super().__init__() self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.norm = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt, memory, memory_mask: 
Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0] tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) # n b d return tgt def forward_pre(self, tgt, memory, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): tgt2 = self.norm(tgt) tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0] tgt = tgt + self.dropout(tgt2) return tgt def forward(self, tgt, memory, memory_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None, ): if self.normalize_before: return self.forward_pre(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) return self.forward_post(tgt, memory, memory_mask, memory_key_padding_mask, pos, query_pos) class FFNLayer(nn.Module): def __init__(self, d_model, dim_feedforward=2048, dropout=0.0, activation="relu", normalize_before=False): super().__init__() # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm = nn.LayerNorm(d_model) self.activation = _get_activation_fn(activation) self.normalize_before = normalize_before self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) def with_pos_embed(self, tensor, pos: Optional[Tensor]): return tensor if pos is None else tensor + pos def forward_post(self, tgt): tgt2 = 
self.linear2(self.dropout(self.activation(self.linear1(tgt)))) tgt = tgt + self.dropout(tgt2) tgt = self.norm(tgt) return tgt def forward_pre(self, tgt): tgt2 = self.norm(tgt) tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) tgt = tgt + self.dropout(tgt2) return tgt def forward(self, tgt): if self.normalize_before: return self.forward_pre(tgt) return self.forward_post(tgt) ================================================ FILE: models/layers/gilbert/demo/index.html ================================================ Gilbert Curve

Gilbert Curve

W
H
================================================ FILE: models/layers/gilbert/demo/normalize.css ================================================ /*! normalize.css v3.0.2 | MIT License | git.io/normalize */ /** * 1. Set default font family to sans-serif. * 2. Prevent iOS text size adjust after orientation change, without disabling * user zoom. */ html { font-family: sans-serif; /* 1 */ -ms-text-size-adjust: 100%; /* 2 */ -webkit-text-size-adjust: 100%; /* 2 */ } /** * Remove default margin. */ body { margin: 0; } /* HTML5 display definitions ========================================================================== */ /** * Correct `block` display not defined for any HTML5 element in IE 8/9. * Correct `block` display not defined for `details` or `summary` in IE 10/11 * and Firefox. * Correct `block` display not defined for `main` in IE 11. */ article, aside, details, figcaption, figure, footer, header, hgroup, main, menu, nav, section, summary { display: block; } /** * 1. Correct `inline-block` display not defined in IE 8/9. * 2. Normalize vertical alignment of `progress` in Chrome, Firefox, and Opera. */ audio, canvas, progress, video { display: inline-block; /* 1 */ vertical-align: baseline; /* 2 */ } /** * Prevent modern browsers from displaying `audio` without controls. * Remove excess height in iOS 5 devices. */ audio:not([controls]) { display: none; height: 0; } /** * Address `[hidden]` styling not present in IE 8/9/10. * Hide the `template` element in IE 8/9/11, Safari, and Firefox < 22. */ [hidden], template { display: none; } /* Links ========================================================================== */ /** * Remove the gray background color from active links in IE 10. */ a { background-color: transparent; } /** * Improve readability when focused and also mouse hovered in all browsers. 
*/ a:active, a:hover { outline: 0; } /* Text-level semantics ========================================================================== */ /** * Address styling not present in IE 8/9/10/11, Safari, and Chrome. */ abbr[title] { border-bottom: 1px dotted; } /** * Address style set to `bolder` in Firefox 4+, Safari, and Chrome. */ b, strong { font-weight: bold; } /** * Address styling not present in Safari and Chrome. */ dfn { font-style: italic; } /** * Address variable `h1` font-size and margin within `section` and `article` * contexts in Firefox 4+, Safari, and Chrome. */ h1 { font-size: 2em; margin: 0.67em 0; } /** * Address styling not present in IE 8/9. */ mark { background: #ff0; color: #000; } /** * Address inconsistent and variable font size in all browsers. */ small { font-size: 80%; } /** * Prevent `sub` and `sup` affecting `line-height` in all browsers. */ sub, sup { font-size: 75%; line-height: 0; position: relative; vertical-align: baseline; } sup { top: -0.5em; } sub { bottom: -0.25em; } /* Embedded content ========================================================================== */ /** * Remove border when inside `a` element in IE 8/9/10. */ img { border: 0; } /** * Correct overflow not hidden in IE 9/10/11. */ svg:not(:root) { overflow: hidden; } /* Grouping content ========================================================================== */ /** * Address margin not present in IE 8/9 and Safari. */ figure { margin: 1em 40px; } /** * Address differences between Firefox and other browsers. */ hr { -moz-box-sizing: content-box; box-sizing: content-box; height: 0; } /** * Contain overflow in all browsers. */ pre { overflow: auto; } /** * Address odd `em`-unit font size rendering in all browsers. 
*/ code, kbd, pre, samp { font-family: monospace, monospace; font-size: 1em; } /* Forms ========================================================================== */ /** * Known limitation: by default, Chrome and Safari on OS X allow very limited * styling of `select`, unless a `border` property is set. */ /** * 1. Correct color not being inherited. * Known issue: affects color of disabled elements. * 2. Correct font properties not being inherited. * 3. Address margins set differently in Firefox 4+, Safari, and Chrome. */ button, input, optgroup, select, textarea { color: inherit; /* 1 */ font: inherit; /* 2 */ margin: 0; /* 3 */ } /** * Address `overflow` set to `hidden` in IE 8/9/10/11. */ button { overflow: visible; } /** * Address inconsistent `text-transform` inheritance for `button` and `select`. * All other form control elements do not inherit `text-transform` values. * Correct `button` style inheritance in Firefox, IE 8/9/10/11, and Opera. * Correct `select` style inheritance in Firefox. */ button, select { text-transform: none; } /** * 1. Avoid the WebKit bug in Android 4.0.* where (2) destroys native `audio` * and `video` controls. * 2. Correct inability to style clickable `input` types in iOS. * 3. Improve usability and consistency of cursor style between image-type * `input` and others. */ button, html input[type="button"], /* 1 */ input[type="reset"], input[type="submit"] { -webkit-appearance: button; /* 2 */ cursor: pointer; /* 3 */ } /** * Re-set default cursor for disabled elements. */ button[disabled], html input[disabled] { cursor: default; } /** * Remove inner padding and border in Firefox 4+. */ button::-moz-focus-inner, input::-moz-focus-inner { border: 0; padding: 0; } /** * Address Firefox 4+ setting `line-height` on `input` using `!important` in * the UA stylesheet. */ input { line-height: normal; } /** * It's recommended that you don't attempt to style these elements. * Firefox's implementation doesn't respect box-sizing, padding, or width. 
* * 1. Address box sizing set to `content-box` in IE 8/9/10. * 2. Remove excess padding in IE 8/9/10. */ input[type="checkbox"], input[type="radio"] { box-sizing: border-box; /* 1 */ padding: 0; /* 2 */ } /** * Fix the cursor style for Chrome's increment/decrement buttons. For certain * `font-size` values of the `input`, it causes the cursor style of the * decrement button to change from `default` to `text`. */ input[type="number"]::-webkit-inner-spin-button, input[type="number"]::-webkit-outer-spin-button { height: auto; } /** * 1. Address `appearance` set to `searchfield` in Safari and Chrome. * 2. Address `box-sizing` set to `border-box` in Safari and Chrome * (include `-moz` to future-proof). */ input[type="search"] { -webkit-appearance: textfield; /* 1 */ -moz-box-sizing: content-box; -webkit-box-sizing: content-box; /* 2 */ box-sizing: content-box; } /** * Remove inner padding and search cancel button in Safari and Chrome on OS X. * Safari (but not Chrome) clips the cancel button when the search input has * padding (and `textfield` appearance). */ input[type="search"]::-webkit-search-cancel-button, input[type="search"]::-webkit-search-decoration { -webkit-appearance: none; } /** * Define consistent border, margin, and padding. */ fieldset { border: 1px solid #c0c0c0; margin: 0 2px; padding: 0.35em 0.625em 0.75em; } /** * 1. Correct `color` not being inherited in IE 8/9/10/11. * 2. Remove padding so people aren't caught out if they zero out fieldsets. */ legend { border: 0; /* 1 */ padding: 0; /* 2 */ } /** * Remove default vertical scrollbar in IE 8/9/10/11. */ textarea { overflow: auto; } /** * Don't inherit the `font-weight` (applied by a rule above). * NOTE: the default cannot safely be changed in Chrome and Safari on OS X. */ optgroup { font-weight: bold; } /* Tables ========================================================================== */ /** * Remove most spacing between table cells. 
*/ table { border-collapse: collapse; border-spacing: 0; } td, th { padding: 0; } ================================================ FILE: models/layers/gilbert/demo/script.js ================================================ // SPDX-License-Identifier: BSD-2-Clause // Copyright (c) 2024 abetusk var info = { "W" : -1, "H": -1, "default": { "w": 28, "h": 18 }, "T": { "x": 10, "y": 10 }, "S": { "x": 10, "y": 10 }, "color": "color", "reverse_y": true, "two": null, "line": [], "two": null, "ctx": null, "container" : null, "canvas": null }; //--- // https://stackoverflow.com/a/18197341 CC-BY-SA // Matěj Pokorný (https://stackoverflow.com/users/2438165/mat%c4%9bj-pokorn%c3%bd) // function download(filename, text) { let element = document.createElement('a'); element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text)); element.setAttribute('download', filename); element.style.display = 'none'; document.body.appendChild(element); element.click(); document.body.removeChild(element); } // //--- function dl_svg() { info.two.render(); info.two.update(); info.svg = info.container.innerHTML; let txt = info.svg.toString(); let wxh_str = info.W.toString() + "x" + info.H.toString(); download("gilbert" + wxh_str + ".svg", txt); } function update_color() { if (info.color == "color") { info.color = "bw"; } else { info.color = "color"; } draw_curve(); } function update_wh(w,h) { let ui_width = document.getElementById("ui_width"); let ui_height = document.getElementById("ui_height"); ui_width.value = w; ui_height.value = h info.W = w; info.H = h; draw_curve(); } function update_preset() { let ele = document.getElementById("ui_preset"); let v = ele.value; if ( v.match( /^\d+x\d+$/ ) ){ let tok = v.split("x"); let _w = parseInt(tok[0]); let _h = parseInt(tok[1]); if (isNaN(_w)) { _w = 1; } if (isNaN(_h)) { _h = 1; } if (_w < 1) { _w = 1; } if (_h < 1) { _h = 1; } update_wh(_w,_h); return; } let ui_w = document.getElementById("ui_width"); let ui_h = 
document.getElementById("ui_height"); let _w = parseInt(ui_w.value); let _h = parseInt(ui_h.value); if (isNaN(_w)) { _w = 1; } if (isNaN(_h)) { _h = 1; } if (_w < 1) { _w = 1; } if (_h < 1) { _h = 1; } update_wh(_w,_h); } function update_num() { let ui_w = document.getElementById("ui_width"); let ui_h = document.getElementById("ui_height"); let _w = parseInt(ui_w.value); let _h = parseInt(ui_h.value); if (isNaN(_w)) { _w = 1; } if (isNaN(_h)) { _h = 1; } if (_w < 1) { _w = 1; } if (_h < 1) { _h = 1; } update_wh(_w,_h); } function draw_curve() { let two = info.two; let W = info.W; let H = info.H; let S = info.S; let T = info.T; let flip = info.reverse_y; //two.clear(); let N = (W*H); let n = N-1; if (n==0) { return; } if (n < info.line.length) { let m = info.line.length; for (let ii=n; ii info.line.length) { let m = info.line.length; for (let ii=m; ii .label-body { display: inline-block; margin-left: .5rem; font-weight: normal; } /* Lists –––––––––––––––––––––––––––––––––––––––––––––––––– */ ul { list-style: circle inside; } ol { list-style: decimal inside; } ol, ul { padding-left: 0; margin-top: 0; } ul ul, ul ol, ol ol, ol ul { margin: 1.5rem 0 1.5rem 3rem; font-size: 90%; } li { margin-bottom: 1rem; } /* Code –––––––––––––––––––––––––––––––––––––––––––––––––– */ code { padding: .2rem .5rem; margin: 0 .2rem; font-size: 90%; white-space: nowrap; background: #F1F1F1; border: 1px solid #E1E1E1; border-radius: 4px; } pre > code { display: block; padding: 1rem 1.5rem; white-space: pre; } /* Tables –––––––––––––––––––––––––––––––––––––––––––––––––– */ th, td { padding: 12px 15px; text-align: left; border-bottom: 1px solid #E1E1E1; } th:first-child, td:first-child { padding-left: 0; } th:last-child, td:last-child { padding-right: 0; } /* Spacing –––––––––––––––––––––––––––––––––––––––––––––––––– */ button, .button { margin-bottom: 1rem; } input, textarea, select, fieldset { margin-bottom: 1.5rem; } pre, blockquote, dl, figure, table, p, ul, ol, form { margin-bottom: 
2.5rem; } /* Utilities –––––––––––––––––––––––––––––––––––––––––––––––––– */ .u-full-width { width: 100%; box-sizing: border-box; } .u-max-full-width { max-width: 100%; box-sizing: border-box; } .u-pull-right { float: right; } .u-pull-left { float: left; } /* Misc –––––––––––––––––––––––––––––––––––––––––––––––––– */ hr { margin-top: 3rem; margin-bottom: 3.5rem; border-width: 0; border-top: 1px solid #E1E1E1; } /* Clearing –––––––––––––––––––––––––––––––––––––––––––––––––– */ /* Self Clearing Goodness */ .container:after, .row:after, .u-cf { content: ""; display: table; clear: both; } /* Media Queries –––––––––––––––––––––––––––––––––––––––––––––––––– */ /* Note: The best way to structure the use of media queries is to create the queries near the relevant code. For example, if you wanted to change the styles for buttons on small devices, paste the mobile query code up in the buttons section and style it there. */ /* Larger than mobile */ @media (min-width: 400px) {} /* Larger than phablet (also point when grid becomes active) */ @media (min-width: 550px) {} /* Larger than tablet */ @media (min-width: 750px) {} /* Larger than desktop */ @media (min-width: 1000px) {} /* Larger than Desktop HD */ @media (min-width: 1200px) {} ================================================ FILE: models/layers/gilbert/demo/two.js ================================================ /* MIT License Copyright (c) 2012 - 2021 @jonobr1 / http://jono.fyi Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the 
Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ var Two = (() => { var __defProp = Object.defineProperty; var __getOwnPropDesc = Object.getOwnPropertyDescriptor; var __getOwnPropNames = Object.getOwnPropertyNames; var __hasOwnProp = Object.prototype.hasOwnProperty; var __defNormalProp = (obj, key, value) => key in obj ? __defProp(obj, key, { enumerable: true, configurable: true, writable: true, value }) : obj[key] = value; var __export = (target, all) => { for (var name in all) __defProp(target, name, { get: all[name], enumerable: true }); }; var __copyProps = (to, from, except, desc) => { if (from && typeof from === "object" || typeof from === "function") { for (let key of __getOwnPropNames(from)) if (!__hasOwnProp.call(to, key) && key !== except) __defProp(to, key, { get: () => from[key], enumerable: !(desc = __getOwnPropDesc(from, key)) || desc.enumerable }); } return to; }; var __toCommonJS = (mod2) => __copyProps(__defProp({}, "__esModule", { value: true }), mod2); var __publicField = (obj, key, value) => { __defNormalProp(obj, typeof key !== "symbol" ? 
key + "" : key, value); return value; }; // src/two.js var two_exports = {}; __export(two_exports, { default: () => Two }); // src/utils/path-commands.js var Commands = { move: "M", line: "L", curve: "C", arc: "A", close: "Z" }; // src/utils/math.js var math_exports = {}; __export(math_exports, { HALF_PI: () => HALF_PI, NumArray: () => NumArray, TWO_PI: () => TWO_PI, decomposeMatrix: () => decomposeMatrix, getComputedMatrix: () => getComputedMatrix, getPoT: () => getPoT, lerp: () => lerp, mod: () => mod, setMatrix: () => setMatrix, toFixed: () => toFixed }); // src/utils/root.js var root; if (typeof window !== "undefined") { root = window; } else if (typeof global !== "undefined") { root = global; } else if (typeof self !== "undefined") { root = self; } // src/utils/math.js var Matrix; var TWO_PI = Math.PI * 2; var HALF_PI = Math.PI * 0.5; function decomposeMatrix(matrix, b, c, d, e, f) { let a; if (arguments.length <= 1) { a = matrix.a; b = matrix.b; c = matrix.c; d = matrix.d; e = matrix.e; f = matrix.f; } else { a = matrix; } return { translateX: e, translateY: f, scaleX: Math.sqrt(a * a + b * b), scaleY: Math.sqrt(c * c + d * d), rotation: 180 * Math.atan2(b, a) / Math.PI }; } function setMatrix(matrix) { Matrix = matrix; } function getComputedMatrix(object, matrix) { matrix = matrix && matrix.identity() || new Matrix(); let parent = object; const matrices = []; while (parent && parent._matrix) { matrices.push(parent._matrix); parent = parent.parent; } matrices.reverse(); for (let i = 0; i < matrices.length; i++) { const m = matrices[i]; const e = m.elements; matrix.multiply( e[0], e[1], e[2], e[3], e[4], e[5], e[6], e[7], e[8], e[9] ); } return matrix; } function lerp(a, b, t) { return t * (b - a) + a; } var pots = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]; function getPoT(value) { let i = 0; while (pots[i] && pots[i] < value) { i++; } return pots[i]; } function mod(v, l) { while (v < 0) { v += l; } return v % l; } var NumArray = root.Float32Array 
|| Array; var floor = Math.floor; function toFixed(v) { return floor(v * 1e6) / 1e6; } // src/utils/curves.js var curves_exports = {}; __export(curves_exports, { Curve: () => Curve, getAnchorsFromArcData: () => getAnchorsFromArcData, getComponentOnCubicBezier: () => getComponentOnCubicBezier, getControlPoints: () => getControlPoints, getCurveBoundingBox: () => getCurveBoundingBox, getCurveFromPoints: () => getCurveFromPoints, getCurveLength: () => getCurveLength, getReflection: () => getReflection, integrate: () => integrate, subdivide: () => subdivide }); // src/events.js var Events = class { _events = {}; _bound = false; constructor() { } addEventListener(name, handler) { const list = this._events[name] || (this._events[name] = []); list.push(handler); this._bound = true; return this; } on() { return this.addEventListener.apply(this, arguments); } bind() { return this.addEventListener.apply(this, arguments); } removeEventListener(name, handler) { if (!this._events) { return this; } if (!name && !handler) { this._events = {}; this._bound = false; return this; } const names = name ? [name] : Object.keys(this._events); for (let i = 0, l = names.length; i < l; i++) { name = names[i]; let list = this._events[name]; if (list) { let events = []; if (handler) { for (let j = 0, k = list.length; j < k; j++) { let e = list[j]; e = e.handler ? 
e.handler : e; if (handler !== e) { events.push(e); } } } this._events[name] = events; } } return this; } off() { return this.removeEventListener.apply(this, arguments); } unbind() { return this.removeEventListener.apply(this, arguments); } dispatchEvent(name) { if (!this._events) { return this; } const args = Array.prototype.slice.call(arguments, 1); const events = this._events[name]; if (events) { for (let i = 0; i < events.length; i++) { events[i].call(this, ...args); } } return this; } trigger() { return this.dispatchEvent.apply(this, arguments); } listen(obj, name, handler) { const scope = this; if (obj) { e.obj = obj; e.name = name; e.handler = handler; obj.on(name, e); } function e() { handler.apply(scope, arguments); } return scope; } ignore(obj, name, handler) { obj.off(name, handler); return this; } }; __publicField(Events, "Types", { play: "play", pause: "pause", update: "update", render: "render", resize: "resize", change: "change", remove: "remove", insert: "insert", order: "order", load: "load" }); __publicField(Events, "Methods", [ "addEventListener", "on", "removeEventListener", "off", "unbind", "dispatchEvent", "trigger", "listen", "ignore" ]); // src/vector.js var proto = { x: { enumerable: true, get: function() { return this._x; }, set: function(v) { if (this._x !== v) { this._x = v; if (this._bound) { this.dispatchEvent(Events.Types.change); } } } }, y: { enumerable: true, get: function() { return this._y; }, set: function(v) { if (this._y !== v) { this._y = v; if (this._bound) { this.dispatchEvent(Events.Types.change); } } } } }; var _Vector = class extends Events { _x = 0; _y = 0; constructor(x = 0, y = 0) { super(); for (let prop in proto) { Object.defineProperty(this, prop, proto[prop]); } this.x = x; this.y = y; } static add(v1, v2) { return new _Vector(v1.x + v2.x, v1.y + v2.y); } static sub(v1, v2) { return new _Vector(v1.x - v2.x, v1.y - v2.y); } static subtract(v1, v2) { return _Vector.sub(v1, v2); } static ratioBetween(v1, v2) { return 
(v1.x * v2.x + v1.y * v2.y) / (v1.length() * v2.length()); } static angleBetween(v1, v2) { if (arguments.length >= 4) { const dx2 = arguments[0] - arguments[2]; const dy2 = arguments[1] - arguments[3]; return Math.atan2(dy2, dx2); } const dx = v1.x - v2.x; const dy = v1.y - v2.y; return Math.atan2(dy, dx); } static distanceBetween(v1, v2) { return Math.sqrt(_Vector.distanceBetweenSquared(v1, v2)); } static distanceBetweenSquared(v1, v2) { const dx = v1.x - v2.x; const dy = v1.y - v2.y; return dx * dx + dy * dy; } set(x, y) { this.x = x; this.y = y; return this; } copy(v) { this.x = v.x; this.y = v.y; return this; } clear() { this.x = 0; this.y = 0; return this; } clone() { return new _Vector(this.x, this.y); } add(x, y) { if (arguments.length <= 0) { return this; } else if (arguments.length <= 1) { if (typeof x === "number") { this.x += x; this.y += x; } else if (x && typeof x.x === "number" && typeof x.y === "number") { this.x += x.x; this.y += x.y; } } else { this.x += x; this.y += y; } return this; } addSelf(v) { return this.add.apply(this, arguments); } sub(x, y) { if (arguments.length <= 0) { return this; } else if (arguments.length <= 1) { if (typeof x === "number") { this.x -= x; this.y -= x; } else if (x && typeof x.x === "number" && typeof x.y === "number") { this.x -= x.x; this.y -= x.y; } } else { this.x -= x; this.y -= y; } return this; } subtract() { return this.sub.apply(this, arguments); } subSelf(v) { return this.sub.apply(this, arguments); } subtractSelf(v) { return this.sub.apply(this, arguments); } multiply(x, y) { if (arguments.length <= 0) { return this; } else if (arguments.length <= 1) { if (typeof x === "number") { this.x *= x; this.y *= x; } else if (x && typeof x.x === "number" && typeof x.y === "number") { this.x *= x.x; this.y *= x.y; } } else { this.x *= x; this.y *= y; } return this; } multiplySelf(v) { return this.multiply.apply(this, arguments); } multiplyScalar(s) { return this.multiply(s); } divide(x, y) { if (arguments.length <= 
0) { return this; } else if (arguments.length <= 1) { if (typeof x === "number") { this.x /= x; this.y /= x; } else if (x && typeof x.x === "number" && typeof x.y === "number") { this.x /= x.x; this.y /= x.y; } } else { this.x /= x; this.y /= y; } if (isNaN(this.x)) { this.x = 0; } if (isNaN(this.y)) { this.y = 0; } return this; } divideSelf(v) { return this.divide.apply(this, arguments); } divideScalar(s) { return this.divide(s); } negate() { return this.multiply(-1); } dot(v) { return this.x * v.x + this.y * v.y; } length() { return Math.sqrt(this.lengthSquared()); } lengthSquared() { return this.x * this.x + this.y * this.y; } normalize() { return this.divideScalar(this.length()); } distanceTo(v) { return Math.sqrt(this.distanceToSquared(v)); } distanceToSquared(v) { const dx = this.x - v.x; const dy = this.y - v.y; return dx * dx + dy * dy; } setLength(l) { return this.normalize().multiplyScalar(l); } equals(v, eps) { eps = typeof eps === "undefined" ? 1e-4 : eps; return this.distanceTo(v) < eps; } lerp(v, t) { const x = (v.x - this.x) * t + this.x; const y = (v.y - this.y) * t + this.y; return this.set(x, y); } isZero(eps) { eps = typeof eps === "undefined" ? 
1e-4 : eps; return this.length() < eps; } toString() { return this.x + ", " + this.y; } toObject() { return { x: this.x, y: this.y }; } rotate(radians) { const x = this.x; const y = this.y; const cos7 = Math.cos(radians); const sin7 = Math.sin(radians); this.x = x * cos7 - y * sin7; this.y = x * sin7 + y * cos7; return this; } }; var Vector = _Vector; __publicField(Vector, "zero", new _Vector()); __publicField(Vector, "left", new _Vector(-1, 0)); __publicField(Vector, "right", new _Vector(1, 0)); __publicField(Vector, "up", new _Vector(0, -1)); __publicField(Vector, "down", new _Vector(0, 1)); // src/anchor.js var Anchor = class extends Vector { controls = { left: new Vector(), right: new Vector() }; _command = Commands.move; _relative = true; _rx = 0; _ry = 0; _xAxisRotation = 0; _largeArcFlag = 0; _sweepFlag = 1; constructor(x = 0, y = 0, ax = 0, ay = 0, bx = 0, by = 0, command = Commands.move) { super(x, y); for (let prop in proto2) { Object.defineProperty(this, prop, proto2[prop]); } this.command = command; this.relative = true; const broadcast = Anchor.makeBroadcast(this); this.controls.left.set(ax, ay).addEventListener(Events.Types.change, broadcast); this.controls.right.set(bx, by).addEventListener(Events.Types.change, broadcast); } static makeBroadcast(scope) { return broadcast; function broadcast() { if (scope._bound) { scope.dispatchEvent(Events.Types.change); } } } copy(v) { this.x = v.x; this.y = v.y; if (typeof v.command === "string") { this.command = v.command; } if (v.controls) { if (v.controls.left) { this.controls.left.copy(v.controls.left); } if (v.controls.right) { this.controls.right.copy(v.controls.right); } } if (typeof v.relative === "boolean") { this.relative = v.relative; } if (typeof v.rx === "number") { this.rx = v.rx; } if (typeof v.ry === "number") { this.ry = v.ry; } if (typeof v.xAxisRotation === "number") { this.xAxisRotation = v.xAxisRotation; } if (typeof v.largeArcFlag === "number") { this.largeArcFlag = v.largeArcFlag; } if 
(typeof v.sweepFlag === "number") { this.sweepFlag = v.sweepFlag; } return this; } clone() { return new Anchor().copy(this); } toObject() { return { x: this.x, y: this.y, command: this.command, relative: this.relative, controls: { left: this.controls.left.toObject(), right: this.controls.right.toObject() }, rx: this.rx, ry: this.ry, xAxisRotation: this.xAxisRotation, largeArcFlag: this.largeArcFlag, sweepFlag: this.sweepFlag }; } toString() { return JSON.stringify(this.toObject()); } }; var proto2 = { command: { enumerable: true, get: function() { return this._command; }, set: function(command) { if (this._command !== command) { this._command = command; if (this._bound) { this.dispatchEvent(Events.Types.change); } } } }, relative: { enumerable: true, get: function() { return this._relative; }, set: function(relative) { if (this._relative !== !!relative) { this._relative = !!relative; if (this._bound) { this.dispatchEvent(Events.Types.change); } } } }, rx: { enumerable: true, get: function() { return this._rx; }, set: function(rx) { if (this._rx !== rx) { this._rx = rx; if (this._bound) { this.dispatchEvent(Events.Types.change); } } } }, ry: { enumerable: true, get: function() { return this._ry; }, set: function(ry) { if (this._ry !== ry) { this._ry = ry; if (this._bound) { this.dispatchEvent(Events.Types.change); } } } }, xAxisRotation: { enumerable: true, get: function() { return this._xAxisRotation; }, set: function(xAxisRotation) { if (this._xAxisRotation !== xAxisRotation) { this._xAxisRotation = xAxisRotation; if (this._bound) { this.dispatchEvent(Events.Types.change); } } } }, largeArcFlag: { enumerable: true, get: function() { return this._largeArcFlag; }, set: function(largeArcFlag) { if (this._largeArcFlag !== largeArcFlag) { this._largeArcFlag = largeArcFlag; if (this._bound) { this.dispatchEvent(Events.Types.change); } } } }, sweepFlag: { get: function() { return this._sweepFlag; }, set: function(sweepFlag) { if (this._sweepFlag !== sweepFlag) { 
this._sweepFlag = sweepFlag; if (this._bound) { this.dispatchEvent(Events.Types.change); } } } } }; // src/constants.js var count = 0; var Constants = { nextFrameID: null, Types: { webgl: "WebGLRenderer", svg: "SVGRenderer", canvas: "CanvasRenderer" }, Version: "v0.8.12", PublishDate: "2023-10-16T17:55:26.551Z", Identifier: "two-", Resolution: 12, AutoCalculateImportedMatrices: true, Instances: [], uniqueId: function() { return count++; } }; // src/utils/curves.js var Curve = { CollinearityEpsilon: Math.pow(10, -30), RecursionLimit: 16, CuspLimit: 0, Tolerance: { distance: 0.25, angle: 0, epsilon: Number.EPSILON }, abscissas: [ [0.5773502691896257], [0, 0.7745966692414834], [0.33998104358485626, 0.8611363115940526], [0, 0.5384693101056831, 0.906179845938664], [0.2386191860831969, 0.6612093864662645, 0.932469514203152], [0, 0.4058451513773972, 0.7415311855993945, 0.9491079123427585], [0.1834346424956498, 0.525532409916329, 0.7966664774136267, 0.9602898564975363], [0, 0.3242534234038089, 0.6133714327005904, 0.8360311073266358, 0.9681602395076261], [0.14887433898163122, 0.4333953941292472, 0.6794095682990244, 0.8650633666889845, 0.9739065285171717], [0, 0.26954315595234496, 0.5190961292068118, 0.7301520055740494, 0.8870625997680953, 0.978228658146057], [0.1252334085114689, 0.3678314989981802, 0.5873179542866175, 0.7699026741943047, 0.9041172563704749, 0.9815606342467192], [0, 0.2304583159551348, 0.44849275103644687, 0.6423493394403402, 0.8015780907333099, 0.9175983992229779, 0.9841830547185881], [0.10805494870734367, 0.31911236892788974, 0.5152486363581541, 0.6872929048116855, 0.827201315069765, 0.9284348836635735, 0.9862838086968123], [0, 0.20119409399743451, 0.3941513470775634, 0.5709721726085388, 0.7244177313601701, 0.8482065834104272, 0.937273392400706, 0.9879925180204854], [0.09501250983763744, 0.2816035507792589, 0.45801677765722737, 0.6178762444026438, 0.755404408355003, 0.8656312023878318, 0.9445750230732326, 0.9894009349916499] ], weights: [ [1], 
[0.8888888888888888, 0.5555555555555556], [0.6521451548625461, 0.34785484513745385], [0.5688888888888889, 0.47862867049936647, 0.23692688505618908], [0.46791393457269104, 0.3607615730481386, 0.17132449237917036], [0.4179591836734694, 0.3818300505051189, 0.27970539148927664, 0.1294849661688697], [0.362683783378362, 0.31370664587788727, 0.22238103445337448, 0.10122853629037626], [0.3302393550012598, 0.31234707704000286, 0.26061069640293544, 0.1806481606948574, 0.08127438836157441], [0.29552422471475287, 0.26926671930999635, 0.21908636251598204, 0.1494513491505806, 0.06667134430868814], [0.2729250867779006, 0.26280454451024665, 0.23319376459199048, 0.18629021092773426, 0.1255803694649046, 0.05566856711617366], [0.24914704581340277, 0.2334925365383548, 0.20316742672306592, 0.16007832854334622, 0.10693932599531843, 0.04717533638651183], [0.2325515532308739, 0.22628318026289723, 0.2078160475368885, 0.17814598076194574, 0.13887351021978725, 0.09212149983772845, 0.04048400476531588], [0.2152638534631578, 0.2051984637212956, 0.18553839747793782, 0.15720316715819355, 0.12151857068790319, 0.08015808715976021, 0.03511946033175186], [0.2025782419255613, 0.19843148532711158, 0.1861610000155622, 0.16626920581699392, 0.13957067792615432, 0.10715922046717194, 0.07036604748810812, 0.03075324199611727], [0.1894506104550685, 0.18260341504492358, 0.16915651939500254, 0.14959598881657674, 0.12462897125553388, 0.09515851168249279, 0.062253523938647894, 0.027152459411754096] ] }; function getComponentOnCubicBezier(t, a, b, c, d) { const k = 1 - t; return k * k * k * a + 3 * k * k * t * b + 3 * k * t * t * c + t * t * t * d; } function subdivide(x1, y1, x2, y2, x3, y3, x4, y4, limit) { limit = limit || Curve.RecursionLimit; const amount = limit + 1; if (Math.abs(x1 - x4) < 1e-3 && Math.abs(y1 - y4) < 1e-3) { return [new Anchor(x4, y4)]; } const result = []; for (let i = 0; i < amount; i++) { const t = i / amount; const x = getComponentOnCubicBezier(t, x1, x2, x3, x4); const y = 
getComponentOnCubicBezier(t, y1, y2, y3, y4); result.push(new Anchor(x, y)); } return result; } function getCurveLength(x1, y1, x2, y2, x3, y3, x4, y4, limit) { if (x1 === x2 && y1 === y2 && x3 === x4 && y3 === y4) { const dx = x4 - x1; const dy = y4 - y1; return Math.sqrt(dx * dx + dy * dy); } const ax = 9 * (x2 - x3) + 3 * (x4 - x1), bx = 6 * (x1 + x3) - 12 * x2, cx = 3 * (x2 - x1), ay = 9 * (y2 - y3) + 3 * (y4 - y1), by = 6 * (y1 + y3) - 12 * y2, cy = 3 * (y2 - y1); function integrand(t) { const dx = (ax * t + bx) * t + cx, dy = (ay * t + by) * t + cy; return Math.sqrt(dx * dx + dy * dy); } return integrate( integrand, 0, 1, limit || Curve.RecursionLimit ); } function getCurveBoundingBox(x1, y1, x2, y2, x3, y3, x4, y4) { const tvalues = []; const bounds = [[], []]; let a, b, c, t, t1, t2, b2ac, sqrtb2ac; for (let i = 0; i < 2; ++i) { if (i == 0) { b = 6 * x1 - 12 * x2 + 6 * x3; a = -3 * x1 + 9 * x2 - 9 * x3 + 3 * x4; c = 3 * x2 - 3 * x1; } else { b = 6 * y1 - 12 * y2 + 6 * y3; a = -3 * y1 + 9 * y2 - 9 * y3 + 3 * y4; c = 3 * y2 - 3 * y1; } if (Math.abs(a) < 1e-12) { if (Math.abs(b) < 1e-12) { continue; } t = -c / b; if (0 < t && t < 1) { tvalues.push(t); } continue; } b2ac = b * b - 4 * c * a; sqrtb2ac = Math.sqrt(b2ac); if (b2ac < 0) { continue; } t1 = (-b + sqrtb2ac) / (2 * a); if (0 < t1 && t1 < 1) { tvalues.push(t1); } t2 = (-b - sqrtb2ac) / (2 * a); if (0 < t2 && t2 < 1) { tvalues.push(t2); } } let j = tvalues.length; let jlen = j; let mt; while (j--) { t = tvalues[j]; mt = 1 - t; bounds[0][j] = mt * mt * mt * x1 + 3 * mt * mt * t * x2 + 3 * mt * t * t * x3 + t * t * t * x4; bounds[1][j] = mt * mt * mt * y1 + 3 * mt * mt * t * y2 + 3 * mt * t * t * y3 + t * t * t * y4; } bounds[0][jlen] = x1; bounds[1][jlen] = y1; bounds[0][jlen + 1] = x4; bounds[1][jlen + 1] = y4; bounds[0].length = bounds[1].length = jlen + 2; return { min: { x: Math.min.apply(0, bounds[0]), y: Math.min.apply(0, bounds[1]) }, max: { x: Math.max.apply(0, bounds[0]), y: Math.max.apply(0, 
bounds[1]) } }; } function integrate(f, a, b, n) { let x = Curve.abscissas[n - 2], w = Curve.weights[n - 2], A = 0.5 * (b - a), B = A + a, i = 0, m = n + 1 >> 1, sum = n & 1 ? w[i++] * f(B) : 0; while (i < m) { const Ax = A * x[i]; sum += w[i++] * (f(B + Ax) + f(B - Ax)); } return A * sum; } function getCurveFromPoints(points, closed2) { const l = points.length, last = l - 1; for (let i = 0; i < l; i++) { const point = points[i]; const prev = closed2 ? mod(i - 1, l) : Math.max(i - 1, 0); const next = closed2 ? mod(i + 1, l) : Math.min(i + 1, last); const a = points[prev]; const b = point; const c = points[next]; getControlPoints(a, b, c); b.command = i === 0 ? Commands.move : Commands.curve; } } function getControlPoints(a, b, c) { const a1 = Vector.angleBetween(a, b); const a2 = Vector.angleBetween(c, b); let d1 = Vector.distanceBetween(a, b); let d2 = Vector.distanceBetween(c, b); let mid = (a1 + a2) / 2; if (d1 < 1e-4 || d2 < 1e-4) { if (typeof b.relative === "boolean" && !b.relative) { b.controls.left.copy(b); b.controls.right.copy(b); } return b; } d1 *= 0.33; d2 *= 0.33; if (a2 < a1) { mid += HALF_PI; } else { mid -= HALF_PI; } b.controls.left.x = Math.cos(mid) * d1; b.controls.left.y = Math.sin(mid) * d1; mid -= Math.PI; b.controls.right.x = Math.cos(mid) * d2; b.controls.right.y = Math.sin(mid) * d2; if (typeof b.relative === "boolean" && !b.relative) { b.controls.left.x += b.x; b.controls.left.y += b.y; b.controls.right.x += b.x; b.controls.right.y += b.y; } return b; } function getReflection(a, b, relative) { return new Vector( 2 * a.x - (b.x + a.x) - (relative ? a.x : 0), 2 * a.y - (b.y + a.y) - (relative ? 
a.y : 0) ); } function getAnchorsFromArcData(center, xAxisRotation, rx, ry, ts, td, ccw) { const resolution = Constants.Resolution; const anchors = []; for (let i = 0; i < resolution; i++) { let pct = (i + 1) / resolution; if (ccw) { pct = 1 - pct; } const theta = pct * td + ts; const x = rx * Math.cos(theta); const y = ry * Math.sin(theta); const anchor2 = new Anchor(x, y); anchor2.command = Commands.line; anchors.push(anchor2); } } // src/utils/device-pixel-ratio.js var devicePixelRatio = root.devicePixelRatio || 1; function getBackingStoreRatio(ctx) { return ctx.webkitBackingStorePixelRatio || ctx.mozBackingStorePixelRatio || ctx.msBackingStorePixelRatio || ctx.oBackingStorePixelRatio || ctx.backingStorePixelRatio || 1; } function getRatio(ctx) { return devicePixelRatio / getBackingStoreRatio(ctx); } // src/utils/underscore.js var slice = Array.prototype.slice; function isArrayLike(collection) { if (collection === null || collection === void 0) return false; const length = collection.length; return typeof length == "number" && length >= 0 && length < 4294967296; } var _ = { isNaN: function(obj) { return typeof obj === "number" && obj !== +obj; }, isElement: function(obj) { return !!(obj && obj.nodeType === 1); }, isObject: function(obj) { const type = typeof obj; return type === "function" || type === "object" && !!obj; }, extend: function(base) { const sources = slice.call(arguments, 1); for (let i = 0; i < sources.length; i++) { const obj = sources[i]; for (let k in obj) { base[k] = obj[k]; } } return base; }, defaults: function(base) { const sources = slice.call(arguments, 1); for (let i = 0; i < sources.length; i++) { const obj = sources[i]; for (let k in obj) { if (base[k] === void 0) { base[k] = obj[k]; } } } return base; }, each: function(obj, iteratee, context) { const ctx = context || this; const keys = !isArrayLike(obj) && Object.keys(obj); const length = (keys || obj).length; for (let i = 0; i < length; i++) { const k = keys ? 
keys[i] : i; iteratee.call(ctx, obj[k], k, obj); } return obj; }, performance: root.performance && root.performance.now ? root.performance : Date }; // src/element.js var Element = class extends Events { _flagId = false; _flagClassName = false; _renderer = {}; _id = ""; _className = ""; classList = []; constructor() { super(); for (let prop in proto3) { Object.defineProperty(this, prop, proto3[prop]); } } flagReset() { this._flagId = this._flagClassName = false; } }; var proto3 = { renderer: { enumerable: false, get: function() { return this._renderer; } }, id: { enumerable: true, get: function() { return this._id; }, set: function(v) { const id = this._id; if (v === this._id) { return; } this._id = v; this._flagId = true; if (this.parent) { delete this.parent.children.ids[id]; this.parent.children.ids[this._id] = this; } } }, className: { enumerable: true, get: function() { return this._className; }, set: function(v) { if (this._className !== v) { this._flagClassName = true; this.classList = v.split(/\s+?/); this._className = v; } } } }; // src/matrix.js var cos = Math.cos; var sin = Math.sin; var tan = Math.tan; var array = []; var _Matrix = class extends Events { elements = new NumArray(9); manual = false; constructor(a, b, c, d, e, f) { super(); let elements = a; if (!Array.isArray(elements)) { elements = Array.prototype.slice.call(arguments); } this.identity(); if (elements.length > 0) { this.set(elements); } } static Multiply(A, B, C) { if (B.length <= 3) { const e = A; let x, y, z; const a = B[0] || 0, b = B[1] || 0, c = B[2] || 0; x = e[0] * a + e[1] * b + e[2] * c; y = e[3] * a + e[4] * b + e[5] * c; z = e[6] * a + e[7] * b + e[8] * c; return [x, y, z]; } const A0 = A[0], A1 = A[1], A2 = A[2]; const A3 = A[3], A4 = A[4], A5 = A[5]; const A6 = A[6], A7 = A[7], A8 = A[8]; const B0 = B[0], B1 = B[1], B2 = B[2]; const B3 = B[3], B4 = B[4], B5 = B[5]; const B6 = B[6], B7 = B[7], B8 = B[8]; C = C || new NumArray(9); C[0] = A0 * B0 + A1 * B3 + A2 * B6; C[1] = A0 
* B1 + A1 * B4 + A2 * B7; C[2] = A0 * B2 + A1 * B5 + A2 * B8; C[3] = A3 * B0 + A4 * B3 + A5 * B6; C[4] = A3 * B1 + A4 * B4 + A5 * B7; C[5] = A3 * B2 + A4 * B5 + A5 * B8; C[6] = A6 * B0 + A7 * B3 + A8 * B6; C[7] = A6 * B1 + A7 * B4 + A8 * B7; C[8] = A6 * B2 + A7 * B5 + A8 * B8; return C; } set(a, b, c, d, e, f, g, h, i) { if (typeof b === "undefined") { const elements = a; a = elements[0]; b = elements[1]; c = elements[2]; d = elements[3]; e = elements[4]; f = elements[5]; g = elements[6]; h = elements[7]; i = elements[8]; } this.elements[0] = a; this.elements[1] = b; this.elements[2] = c; this.elements[3] = d; this.elements[4] = e; this.elements[5] = f; this.elements[6] = g; this.elements[7] = h; this.elements[8] = i; return this.trigger(Events.Types.change); } copy(m) { this.elements[0] = m.elements[0]; this.elements[1] = m.elements[1]; this.elements[2] = m.elements[2]; this.elements[3] = m.elements[3]; this.elements[4] = m.elements[4]; this.elements[5] = m.elements[5]; this.elements[6] = m.elements[6]; this.elements[7] = m.elements[7]; this.elements[8] = m.elements[8]; this.manual = m.manual; return this.trigger(Events.Types.change); } identity() { this.elements[0] = _Matrix.Identity[0]; this.elements[1] = _Matrix.Identity[1]; this.elements[2] = _Matrix.Identity[2]; this.elements[3] = _Matrix.Identity[3]; this.elements[4] = _Matrix.Identity[4]; this.elements[5] = _Matrix.Identity[5]; this.elements[6] = _Matrix.Identity[6]; this.elements[7] = _Matrix.Identity[7]; this.elements[8] = _Matrix.Identity[8]; return this.trigger(Events.Types.change); } multiply(a, b, c, d, e, f, g, h, i) { if (typeof b === "undefined") { this.elements[0] *= a; this.elements[1] *= a; this.elements[2] *= a; this.elements[3] *= a; this.elements[4] *= a; this.elements[5] *= a; this.elements[6] *= a; this.elements[7] *= a; this.elements[8] *= a; return this.trigger(Events.Types.change); } if (typeof c === "undefined") { c = 1; } if (typeof d === "undefined") { a = a || 0; b = b || 0; c = c || 
0; e = this.elements; const x = e[0] * a + e[1] * b + e[2] * c; const y = e[3] * a + e[4] * b + e[5] * c; const z = e[6] * a + e[7] * b + e[8] * c; return [x, y, z]; } const A = this.elements; const B = [a, b, c, d, e, f, g, h, i]; const A0 = A[0], A1 = A[1], A2 = A[2]; const A3 = A[3], A4 = A[4], A5 = A[5]; const A6 = A[6], A7 = A[7], A8 = A[8]; const B0 = B[0], B1 = B[1], B2 = B[2]; const B3 = B[3], B4 = B[4], B5 = B[5]; const B6 = B[6], B7 = B[7], B8 = B[8]; this.elements[0] = A0 * B0 + A1 * B3 + A2 * B6; this.elements[1] = A0 * B1 + A1 * B4 + A2 * B7; this.elements[2] = A0 * B2 + A1 * B5 + A2 * B8; this.elements[3] = A3 * B0 + A4 * B3 + A5 * B6; this.elements[4] = A3 * B1 + A4 * B4 + A5 * B7; this.elements[5] = A3 * B2 + A4 * B5 + A5 * B8; this.elements[6] = A6 * B0 + A7 * B3 + A8 * B6; this.elements[7] = A6 * B1 + A7 * B4 + A8 * B7; this.elements[8] = A6 * B2 + A7 * B5 + A8 * B8; return this.trigger(Events.Types.change); } inverse(out) { const a = this.elements; out = out || new _Matrix(); const a00 = a[0], a01 = a[1], a02 = a[2]; const a10 = a[3], a11 = a[4], a12 = a[5]; const a20 = a[6], a21 = a[7], a22 = a[8]; const b01 = a22 * a11 - a12 * a21; const b11 = -a22 * a10 + a12 * a20; const b21 = a21 * a10 - a11 * a20; let det = a00 * b01 + a01 * b11 + a02 * b21; if (!det) { return null; } det = 1 / det; out.elements[0] = b01 * det; out.elements[1] = (-a22 * a01 + a02 * a21) * det; out.elements[2] = (a12 * a01 - a02 * a11) * det; out.elements[3] = b11 * det; out.elements[4] = (a22 * a00 - a02 * a20) * det; out.elements[5] = (-a12 * a00 + a02 * a10) * det; out.elements[6] = b21 * det; out.elements[7] = (-a21 * a00 + a01 * a20) * det; out.elements[8] = (a11 * a00 - a01 * a10) * det; return out; } scale(sx, sy) { const l = arguments.length; if (l <= 1) { sy = sx; } return this.multiply(sx, 0, 0, 0, sy, 0, 0, 0, 1); } rotate(Number2) { const c = cos(Number2); const s = sin(Number2); return this.multiply(c, -s, 0, s, c, 0, 0, 0, 1); } translate(x, y) { return 
this.multiply(1, 0, x, 0, 1, y, 0, 0, 1); } skewX(Number2) { const a = tan(Number2); return this.multiply(1, a, 0, 0, 1, 0, 0, 0, 1); } skewY(Number2) { const a = tan(Number2); return this.multiply(1, 0, 0, a, 1, 0, 0, 0, 1); } toString(fullMatrix) { array.length = 0; this.toTransformArray(fullMatrix, array); return array.map(toFixed).join(" "); } toTransformArray(fullMatrix, output) { const elements = this.elements; const hasOutput = !!output; const a = elements[0]; const b = elements[1]; const c = elements[2]; const d = elements[3]; const e = elements[4]; const f = elements[5]; if (fullMatrix) { const g = elements[6]; const h = elements[7]; const i = elements[8]; if (hasOutput) { output[0] = a; output[1] = d; output[2] = g; output[3] = b; output[4] = e; output[5] = h; output[6] = c; output[7] = f; output[8] = i; return; } return [ a, d, g, b, e, h, c, f, i ]; } if (hasOutput) { output[0] = a; output[1] = d; output[2] = b; output[3] = e; output[4] = c; output[5] = f; return; } return [ a, d, b, e, c, f ]; } toArray(fullMatrix, output) { const elements = this.elements; const hasOutput = !!output; const a = elements[0]; const b = elements[1]; const c = elements[2]; const d = elements[3]; const e = elements[4]; const f = elements[5]; if (fullMatrix) { const g = elements[6]; const h = elements[7]; const i = elements[8]; if (hasOutput) { output[0] = a; output[1] = b; output[2] = c; output[3] = d; output[4] = e; output[5] = f; output[6] = g; output[7] = h; output[8] = i; return; } return [ a, b, c, d, e, f, g, h, i ]; } if (hasOutput) { output[0] = a; output[1] = b; output[2] = c; output[3] = d; output[4] = e; output[5] = f; return; } return [ a, b, c, d, e, f ]; } toObject() { return { elements: this.toArray(true), manual: !!this.manual }; } clone() { return new _Matrix().copy(this); } }; var Matrix2 = _Matrix; __publicField(Matrix2, "Identity", [ 1, 0, 0, 0, 1, 0, 0, 0, 1 ]); setMatrix(Matrix2); // src/shape.js var Shape = class extends Element { _flagMatrix = true; 
_flagScale = false; _matrix = null; _worldMatrix = null; _position = null; _rotation = 0; _scale = 1; _skewX = 0; _skewY = 0; constructor() { super(); for (let prop in proto4) { Object.defineProperty(this, prop, proto4[prop]); } this._renderer.flagMatrix = FlagMatrix.bind(this); this.isShape = true; this.id = Constants.Identifier + Constants.uniqueId(); this.matrix = new Matrix2(); this.worldMatrix = new Matrix2(); this.position = new Vector(); this.rotation = 0; this.scale = 1; this.skewX = 0; this.skewY = 0; } get renderer() { return this._renderer; } set renderer(v) { this._renderer = v; } get translation() { return proto4.position.get.apply(this, arguments); } set translation(v) { proto4.position.set.apply(this, arguments); } addTo(group) { group.add(this); return this; } remove() { if (!this.parent) { return this; } this.parent.remove(this); return this; } clone(parent) { const clone = new Shape(); clone.position.copy(this.position); clone.rotation = this.rotation; clone.scale = this.scale; clone.skewX = this.skewX; clone.skewY = this.skewY; if (this.matrix.manual) { clone.matrix.copy(this.matrix); } if (parent) { parent.add(clone); } return clone._update(); } _update(bubbles) { if (!this._matrix.manual && this._flagMatrix) { this._matrix.identity().translate(this.position.x, this.position.y); if (this._scale instanceof Vector) { this._matrix.scale(this._scale.x, this._scale.y); } else { this._matrix.scale(this._scale); } this._matrix.rotate(this.rotation); this._matrix.skewX(this.skewX); this._matrix.skewY(this.skewY); } if (bubbles) { if (this.parent && this.parent._update) { this.parent._update(); } } return this; } flagReset() { this._flagMatrix = this._flagScale = false; super.flagReset.call(this); return this; } }; var proto4 = { position: { enumerable: true, get: function() { return this._position; }, set: function(v) { if (this._position) { this._position.unbind(Events.Types.change, this._renderer.flagMatrix); } this._position = v; 
this._position.bind(Events.Types.change, this._renderer.flagMatrix); FlagMatrix.call(this); } }, rotation: { enumerable: true, get: function() { return this._rotation; }, set: function(v) { this._rotation = v; this._flagMatrix = true; } }, scale: { enumerable: true, get: function() { return this._scale; }, set: function(v) { if (this._scale instanceof Vector) { this._scale.unbind(Events.Types.change, this._renderer.flagMatrix); } this._scale = v; if (this._scale instanceof Vector) { this._scale.bind(Events.Types.change, this._renderer.flagMatrix); } this._flagMatrix = true; this._flagScale = true; } }, skewX: { enumerable: true, get: function() { return this._skewX; }, set: function(v) { this._skewX = v; this._flagMatrix = true; } }, skewY: { enumerable: true, get: function() { return this._skewY; }, set: function(v) { this._skewY = v; this._flagMatrix = true; } }, matrix: { enumerable: true, get: function() { return this._matrix; }, set: function(v) { this._matrix = v; this._flagMatrix = true; } }, worldMatrix: { enumerable: true, get: function() { getComputedMatrix(this, this._worldMatrix); return this._worldMatrix; }, set: function(v) { this._worldMatrix = v; } } }; function FlagMatrix() { this._flagMatrix = true; } // src/collection.js var Collection = class extends Array { _events = new Events(); get _bound() { return this._events._bound; } set _bound(v) { this._events._bound = v; } addEventListener() { return this._events.addEventListener.apply(this, arguments); } on() { return this._events.on.apply(this, arguments); } bind() { return this._events.bind.apply(this, arguments); } removeEventListener() { return this._events.removeEventListener.apply(this, arguments); } off() { return this._events.off.apply(this, arguments); } unbind() { return this._events.unbind.apply(this, arguments); } dispatchEvent() { return this._events.dispatchEvent.apply(this, arguments); } trigger() { return this._events.trigger.apply(this, arguments); } listen() { return 
this._events.listen.apply(this, arguments); } ignore() { return this._events.ignore.apply(this, arguments); } constructor() { super(); if (arguments[0] && Array.isArray(arguments[0])) { if (arguments[0].length > 0) { this.push.apply(this, arguments[0]); } } else if (arguments.length > 0) { this.push.apply(this, arguments); } } pop() { const popped = super.pop.apply(this, arguments); this.trigger(Events.Types.remove, [popped]); return popped; } shift() { const shifted = super.shift.apply(this, arguments); this.trigger(Events.Types.remove, [shifted]); return shifted; } push() { const pushed = super.push.apply(this, arguments); this.trigger(Events.Types.insert, arguments); return pushed; } unshift() { const unshifted = super.unshift.apply(this, arguments); this.trigger(Events.Types.insert, arguments); return unshifted; } splice() { const spliced = super.splice.apply(this, arguments); this.trigger(Events.Types.remove, spliced); if (arguments.length > 2) { const inserted = this.slice(arguments[0], arguments[0] + arguments.length - 2); this.trigger(Events.Types.insert, inserted); this.trigger(Events.Types.order); } return spliced; } sort() { super.sort.apply(this, arguments); this.trigger(Events.Types.order); return this; } reverse() { super.reverse.apply(this, arguments); this.trigger(Events.Types.order); return this; } indexOf() { return super.indexOf.apply(this, arguments); } map(func, scope) { const results = []; for (let key = 0; key < this.length; key++) { const value = this[key]; let result; if (scope) { result = func.call(scope, value, key); } else { result = func(value, key); } results.push(result); } return results; } }; // src/children.js var Children = class extends Collection { ids = {}; constructor(children) { children = Array.isArray(children) ? 
children : Array.prototype.slice.call(arguments); super(children); this.attach(children); this.on(Events.Types.insert, this.attach); this.on(Events.Types.remove, this.detach); } attach(children) { for (let i = 0; i < children.length; i++) { const child = children[i]; if (child && child.id) { this.ids[child.id] = child; } } return this; } detach(children) { for (let i = 0; i < children.length; i++) { delete this.ids[children[i].id]; } return this; } }; // src/group.js var min = Math.min; var max = Math.max; var _Group = class extends Shape { _flagAdditions = false; _flagSubtractions = false; _flagOrder = false; _flagOpacity = true; _flagBeginning = false; _flagEnding = false; _flagLength = false; _flagMask = false; _fill = "#fff"; _stroke = "#000"; _linewidth = 1; _opacity = 1; _visible = true; _cap = "round"; _join = "round"; _miter = 4; _closed = true; _curved = false; _automatic = true; _beginning = 0; _ending = 1; _length = 0; _mask = null; constructor(children) { super(); for (let prop in proto5) { Object.defineProperty(this, prop, proto5[prop]); } this._renderer.type = "group"; this.additions = []; this.subtractions = []; this.children = Array.isArray(children) ? 
children : Array.prototype.slice.call(arguments); } static InsertChildren(children) { for (let i = 0; i < children.length; i++) { replaceParent.call(this, children[i], this); } } static RemoveChildren(children) { for (let i = 0; i < children.length; i++) { replaceParent.call(this, children[i]); } } static OrderChildren(children) { this._flagOrder = true; } clone(parent) { const clone = new _Group(); const children = this.children.map(function(child) { return child.clone(); }); clone.add(children); clone.opacity = this.opacity; if (this.mask) { clone.mask = this.mask; } clone.translation.copy(this.translation); clone.rotation = this.rotation; clone.scale = this.scale; clone.className = this.className; if (this.matrix.manual) { clone.matrix.copy(this.matrix); } if (parent) { parent.add(clone); } return clone._update(); } toObject() { const result = { children: [], translation: this.translation.toObject(), rotation: this.rotation, scale: this.scale instanceof Vector ? this.scale.toObject() : this.scale, opacity: this.opacity, className: this.className, mask: this.mask ? 
this.mask.toObject() : null }; if (this.matrix.manual) { result.matrix = this.matrix.toObject(); } _.each(this.children, function(child, i) { result.children[i] = child.toObject(); }, this); return result; } corner() { const rect = this.getBoundingClientRect(true); for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; child.translation.x -= rect.left; child.translation.y -= rect.top; } if (this.mask) { this.mask.translation.x -= rect.left; this.mask.translation.y -= rect.top; } return this; } center() { const rect = this.getBoundingClientRect(true); const cx = rect.left + rect.width / 2 - this.translation.x; const cy = rect.top + rect.height / 2 - this.translation.y; for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; if (child.isShape) { child.translation.x -= cx; child.translation.y -= cy; } } if (this.mask) { this.mask.translation.x -= cx; this.mask.translation.y -= cy; } return this; } getById(id) { let found = null; function search(node) { if (node.id === id) { return node; } else if (node.children) { for (let i = 0; i < node.children.length; i++) { found = search(node.children[i]); if (found) { return found; } } } return null; } return search(this); } getByClassName(className) { const found = []; function search(node) { if (Array.prototype.indexOf.call(node.classList, className) >= 0) { found.push(node); } if (node.children) { for (let i = 0; i < node.children.length; i++) { const child = node.children[i]; search(child); } } return found; } return search(this); } getByType(type) { const found = []; function search(node) { if (node instanceof type) { found.push(node); } if (node.children) { for (let i = 0; i < node.children.length; i++) { const child = node.children[i]; search(child); } } return found; } return search(this); } add(objects) { if (!(objects instanceof Array)) { objects = Array.prototype.slice.call(arguments); } else { objects = objects.slice(); } for (let i = 0; i < objects.length; i++) { 
const child = objects[i]; if (!(child && child.id)) { continue; } const index = Array.prototype.indexOf.call(this.children, child); if (index >= 0) { this.children.splice(index, 1); } this.children.push(child); } return this; } remove(objects) { const l = arguments.length, grandparent = this.parent; if (l <= 0 && grandparent) { grandparent.remove(this); return this; } if (!(objects instanceof Array)) { objects = Array.prototype.slice.call(arguments); } else { objects = objects.slice(); } for (let i = 0; i < objects.length; i++) { const object = objects[i]; if (!object || !this.children.ids[object.id]) { continue; } const index = this.children.indexOf(object); if (index >= 0) { this.children.splice(index, 1); } } return this; } getBoundingClientRect(shallow) { let rect, matrix, tc, lc, rc, bc; this._update(true); let left = Infinity, right = -Infinity, top = Infinity, bottom = -Infinity; const regex3 = /texture|gradient/i; matrix = shallow ? this.matrix : this.worldMatrix; for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; if (!child.visible || regex3.test(child._renderer.type)) { continue; } rect = child.getBoundingClientRect(shallow); tc = typeof rect.top !== "number" || _.isNaN(rect.top) || !isFinite(rect.top); lc = typeof rect.left !== "number" || _.isNaN(rect.left) || !isFinite(rect.left); rc = typeof rect.right !== "number" || _.isNaN(rect.right) || !isFinite(rect.right); bc = typeof rect.bottom !== "number" || _.isNaN(rect.bottom) || !isFinite(rect.bottom); if (tc || lc || rc || bc) { continue; } if (shallow) { const [ax, ay] = matrix.multiply(rect.left, rect.top); const [bx, by] = matrix.multiply(rect.right, rect.top); const [cx, cy] = matrix.multiply(rect.left, rect.bottom); const [dx, dy] = matrix.multiply(rect.right, rect.bottom); top = min(ay, by, cy, dy); left = min(ax, bx, cx, dx); right = max(ax, bx, cx, dx); bottom = max(ay, by, cy, dy); } else { top = min(rect.top, top); left = min(rect.left, left); right = 
max(rect.right, right); bottom = max(rect.bottom, bottom); } } return { top, left, right, bottom, width: right - left, height: bottom - top }; } noFill() { this.children.forEach(function(child) { child.noFill(); }); return this; } noStroke() { this.children.forEach(function(child) { child.noStroke(); }); return this; } subdivide() { const args = arguments; this.children.forEach(function(child) { child.subdivide.apply(child, args); }); return this; } _update() { let i, l, child; if (this._flagBeginning || this._flagEnding) { const beginning = Math.min(this._beginning, this._ending); const ending = Math.max(this._beginning, this._ending); const length = this.length; let sum = 0; const bd = beginning * length; const ed = ending * length; for (i = 0; i < this.children.length; i++) { child = this.children[i]; l = child.length; if (bd > sum + l) { child.beginning = 1; child.ending = 1; } else if (ed < sum) { child.beginning = 0; child.ending = 0; } else if (bd > sum && bd < sum + l) { child.beginning = (bd - sum) / l; child.ending = 1; } else if (ed > sum && ed < sum + l) { child.beginning = 0; child.ending = (ed - sum) / l; } else { child.beginning = 0; child.ending = 1; } sum += l; } } return super._update.apply(this, arguments); } flagReset() { if (this._flagAdditions) { this.additions.length = 0; this._flagAdditions = false; } if (this._flagSubtractions) { this.subtractions.length = 0; this._flagSubtractions = false; } this._flagOrder = this._flagMask = this._flagOpacity = this._flagBeginning = this._flagEnding = false; super.flagReset.call(this); return this; } }; var Group = _Group; __publicField(Group, "Children", Children); __publicField(Group, "Properties", [ "fill", "stroke", "linewidth", "cap", "join", "miter", "closed", "curved", "automatic" ]); var proto5 = { visible: { enumerable: true, get: function() { return this._visible; }, set: function(v) { this._flagVisible = this._visible !== v || this._flagVisible; this._visible = v; } }, opacity: { enumerable: 
true, get: function() { return this._opacity; }, set: function(v) { this._flagOpacity = this._opacity !== v || this._flagOpacity; this._opacity = v; } }, beginning: { enumerable: true, get: function() { return this._beginning; }, set: function(v) { this._flagBeginning = this._beginning !== v || this._flagBeginning; this._beginning = v; } }, ending: { enumerable: true, get: function() { return this._ending; }, set: function(v) { this._flagEnding = this._ending !== v || this._flagEnding; this._ending = v; } }, length: { enumerable: true, get: function() { if (this._flagLength || this._length <= 0) { this._length = 0; if (!this.children) { return this._length; } for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; this._length += child.length; } } return this._length; } }, fill: { enumerable: true, get: function() { return this._fill; }, set: function(v) { this._fill = v; for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; child.fill = v; } } }, stroke: { enumerable: true, get: function() { return this._stroke; }, set: function(v) { this._stroke = v; for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; child.stroke = v; } } }, linewidth: { enumerable: true, get: function() { return this._linewidth; }, set: function(v) { this._linewidth = v; for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; child.linewidth = v; } } }, join: { enumerable: true, get: function() { return this._join; }, set: function(v) { this._join = v; for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; child.join = v; } } }, miter: { enumerable: true, get: function() { return this._miter; }, set: function(v) { this._miter = v; for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; child.miter = v; } } }, cap: { enumerable: true, get: function() { return this._cap; }, set: function(v) { this._cap = v; for (let i = 0; i < 
this.children.length; i++) { const child = this.children[i]; child.cap = v; } } }, closed: { enumerable: true, get: function() { return this._closed; }, set: function(v) { this._closed = v; for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; child.closed = v; } } }, curved: { enumerable: true, get: function() { return this._curved; }, set: function(v) { this._curved = v; for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; child.curved = v; } } }, automatic: { enumerable: true, get: function() { return this._automatic; }, set: function(v) { this._automatic = v; for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; child.automatic = v; } } }, children: { enumerable: true, get: function() { return this._children; }, set: function(children) { const insertChildren = Group.InsertChildren.bind(this); const removeChildren = Group.RemoveChildren.bind(this); const orderChildren = Group.OrderChildren.bind(this); if (this._children) { this._children.unbind(); if (this._children.length > 0) { removeChildren(this._children); } } this._children = new Children(children); this._children.bind(Events.Types.insert, insertChildren); this._children.bind(Events.Types.remove, removeChildren); this._children.bind(Events.Types.order, orderChildren); if (children.length > 0) { insertChildren(children); } } }, mask: { enumerable: true, get: function() { return this._mask; }, set: function(v) { this._mask = v; this._flagMask = true; if (_.isObject(v) && !v.clip) { v.clip = true; } } } }; function replaceParent(child, newParent) { const parent = child.parent; let index; if (parent === newParent) { add(); return; } if (parent && parent.children.ids[child.id]) { index = Array.prototype.indexOf.call(parent.children, child); parent.children.splice(index, 1); splice(); } if (newParent) { add(); return; } splice(); if (parent._flagAdditions && parent.additions.length === 0) { parent._flagAdditions = false; } if 
(parent._flagSubtractions && parent.subtractions.length === 0) { parent._flagSubtractions = false; } delete child.parent; function add() { if (newParent.subtractions.length > 0) { index = Array.prototype.indexOf.call(newParent.subtractions, child); if (index >= 0) { newParent.subtractions.splice(index, 1); } } if (newParent.additions.length > 0) { index = Array.prototype.indexOf.call(newParent.additions, child); if (index >= 0) { newParent.additions.splice(index, 1); } } child.parent = newParent; newParent.additions.push(child); newParent._flagAdditions = true; } function splice() { index = Array.prototype.indexOf.call(parent.additions, child); if (index >= 0) { parent.additions.splice(index, 1); } index = Array.prototype.indexOf.call(parent.subtractions, child); if (index < 0) { parent.subtractions.push(child); parent._flagSubtractions = true; } } } // src/renderers/canvas.js var emptyArray = []; var max2 = Math.max; var min2 = Math.min; var abs = Math.abs; var sin2 = Math.sin; var cos2 = Math.cos; var acos = Math.acos; var sqrt = Math.sqrt; var canvas = { isHidden: /(undefined|none|transparent)/i, alignments: { left: "start", middle: "center", right: "end" }, shim: function(elem, name) { elem.tagName = elem.nodeName = name || "canvas"; elem.nodeType = 1; elem.getAttribute = function(prop) { return this[prop]; }; elem.setAttribute = function(prop, val) { this[prop] = val; return this; }; return elem; }, group: { renderChild: function(child) { canvas[child._renderer.type].render.call(child, this.ctx, true, this.clip); }, render: function(ctx) { if (!this._visible) { return this; } this._update(); const matrix = this._matrix.elements; const parent = this.parent; this._renderer.opacity = this._opacity * (parent && parent._renderer ? 
parent._renderer.opacity : 1); const mask = this._mask; const defaultMatrix = isDefaultMatrix(matrix); const shouldIsolate = !defaultMatrix || !!mask; if (!this._renderer.context) { this._renderer.context = {}; } this._renderer.context.ctx = ctx; if (shouldIsolate) { ctx.save(); if (!defaultMatrix) { ctx.transform( matrix[0], matrix[3], matrix[1], matrix[4], matrix[2], matrix[5] ); } } if (mask) { canvas[mask._renderer.type].render.call(mask, ctx, true); } if (this._opacity > 0 && this._scale !== 0) { for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; canvas[child._renderer.type].render.call(child, ctx); } } if (shouldIsolate) { ctx.restore(); } return this.flagReset(); } }, path: { render: function(ctx, forced, parentClipped) { let matrix, stroke, linewidth, fill, opacity, visible, cap, join, miter, closed2, commands, length, last, prev, a, b, c, d, ux, uy, vx, vy, ar, bl, br, cl, x, y, mask, clip, defaultMatrix, isOffset, dashes, po; po = this.parent && this.parent._renderer ? 
this.parent._renderer.opacity : 1; mask = this._mask; clip = this._clip; opacity = this._opacity * (po || 1); visible = this._visible; if (!forced && (!visible || clip || opacity === 0)) { return this; } this._update(); matrix = this._matrix.elements; stroke = this._stroke; linewidth = this._linewidth; fill = this._fill; cap = this._cap; join = this._join; miter = this._miter; closed2 = this._closed; commands = this._renderer.vertices; length = commands.length; last = length - 1; defaultMatrix = isDefaultMatrix(matrix); dashes = this.dashes; if (!defaultMatrix) { ctx.save(); ctx.transform(matrix[0], matrix[3], matrix[1], matrix[4], matrix[2], matrix[5]); } if (mask) { canvas[mask._renderer.type].render.call(mask, ctx, true); } if (fill) { if (typeof fill === "string") { ctx.fillStyle = fill; } else { canvas[fill._renderer.type].render.call(fill, ctx, this); ctx.fillStyle = fill._renderer.effect; } } if (stroke) { if (typeof stroke === "string") { ctx.strokeStyle = stroke; } else { canvas[stroke._renderer.type].render.call(stroke, ctx, this); ctx.strokeStyle = stroke._renderer.effect; } if (linewidth) { ctx.lineWidth = linewidth; } if (miter) { ctx.miterLimit = miter; } if (join) { ctx.lineJoin = join; } if (!closed2 && cap) { ctx.lineCap = cap; } } if (typeof opacity === "number") { ctx.globalAlpha = opacity; } if (dashes && dashes.length > 0) { ctx.lineDashOffset = dashes.offset || 0; ctx.setLineDash(dashes); } ctx.beginPath(); let rx, ry, xAxisRotation, largeArcFlag, sweepFlag, ax, ay; for (let i = 0; i < length; i++) { b = commands[i]; x = b.x; y = b.y; switch (b.command) { case Commands.close: ctx.closePath(); break; case Commands.arc: rx = b.rx; ry = b.ry; xAxisRotation = b.xAxisRotation; largeArcFlag = b.largeArcFlag; sweepFlag = b.sweepFlag; prev = closed2 ? 
mod(i - 1, length) : max2(i - 1, 0); a = commands[prev]; ax = a.x; ay = a.y; canvas.renderSvgArcCommand( ctx, ax, ay, rx, ry, largeArcFlag, sweepFlag, xAxisRotation, x, y ); break; case Commands.curve: prev = closed2 ? mod(i - 1, length) : Math.max(i - 1, 0); a = commands[prev]; ar = a.controls && a.controls.right || Vector.zero; bl = b.controls && b.controls.left || Vector.zero; if (a._relative) { vx = ar.x + a.x; vy = ar.y + a.y; } else { vx = ar.x; vy = ar.y; } if (b._relative) { ux = bl.x + b.x; uy = bl.y + b.y; } else { ux = bl.x; uy = bl.y; } ctx.bezierCurveTo(vx, vy, ux, uy, x, y); if (i >= last && closed2) { c = d; br = b.controls && b.controls.right || Vector.zero; cl = c.controls && c.controls.left || Vector.zero; if (b._relative) { vx = br.x + b.x; vy = br.y + b.y; } else { vx = br.x; vy = br.y; } if (c._relative) { ux = cl.x + c.x; uy = cl.y + c.y; } else { ux = cl.x; uy = cl.y; } x = c.x; y = c.y; ctx.bezierCurveTo(vx, vy, ux, uy, x, y); } break; case Commands.line: ctx.lineTo(x, y); break; case Commands.move: d = b; ctx.moveTo(x, y); break; } } if (closed2) { ctx.closePath(); } if (!clip && !parentClipped) { if (!canvas.isHidden.test(fill)) { isOffset = fill._renderer && fill._renderer.offset; if (isOffset) { ctx.save(); ctx.translate( -fill._renderer.offset.x, -fill._renderer.offset.y ); ctx.scale(fill._renderer.scale.x, fill._renderer.scale.y); } ctx.fill(); if (isOffset) { ctx.restore(); } } if (!canvas.isHidden.test(stroke)) { isOffset = stroke._renderer && stroke._renderer.offset; if (isOffset) { ctx.save(); ctx.translate( -stroke._renderer.offset.x, -stroke._renderer.offset.y ); ctx.scale(stroke._renderer.scale.x, stroke._renderer.scale.y); ctx.lineWidth = linewidth / stroke._renderer.scale.x; } ctx.stroke(); if (isOffset) { ctx.restore(); } } } if (!defaultMatrix) { ctx.restore(); } if (clip && !parentClipped) { ctx.clip(); } if (dashes && dashes.length > 0) { ctx.setLineDash(emptyArray); } return this.flagReset(); } }, points: { render: 
function(ctx, forced, parentClipped) { let me, stroke, linewidth, fill, opacity, visible, size, commands, length, b, x, y, defaultMatrix, isOffset, dashes, po; po = this.parent && this.parent._renderer ? this.parent._renderer.opacity : 1; opacity = this._opacity * (po || 1); visible = this._visible; if (!forced && (!visible || opacity === 0)) { return this; } this._update(); me = this._matrix.elements; stroke = this._stroke; linewidth = this._linewidth; fill = this._fill; commands = this._renderer.collection; length = commands.length; defaultMatrix = isDefaultMatrix(me); dashes = this.dashes; size = this._size; if (!defaultMatrix) { ctx.save(); ctx.transform(me[0], me[3], me[1], me[4], me[2], me[5]); } if (fill) { if (typeof fill === "string") { ctx.fillStyle = fill; } else { canvas[fill._renderer.type].render.call(fill, ctx, this); ctx.fillStyle = fill._renderer.effect; } } if (stroke) { if (typeof stroke === "string") { ctx.strokeStyle = stroke; } else { canvas[stroke._renderer.type].render.call(stroke, ctx, this); ctx.strokeStyle = stroke._renderer.effect; } if (linewidth) { ctx.lineWidth = linewidth; } } if (typeof opacity === "number") { ctx.globalAlpha = opacity; } if (dashes && dashes.length > 0) { ctx.lineDashOffset = dashes.offset || 0; ctx.setLineDash(dashes); } ctx.beginPath(); let radius = size * 0.5, m; if (!this._sizeAttenuation) { m = this.worldMatrix.elements; m = decomposeMatrix(m[0], m[3], m[1], m[4], m[2], m[5]); radius /= Math.max(m.scaleX, m.scaleY); } for (let i = 0; i < length; i++) { b = commands[i]; x = b.x; y = b.y; ctx.moveTo(x + radius, y); ctx.arc(x, y, radius, 0, TWO_PI); } if (!parentClipped) { if (!canvas.isHidden.test(fill)) { isOffset = fill._renderer && fill._renderer.offset; if (isOffset) { ctx.save(); ctx.translate( -fill._renderer.offset.x, -fill._renderer.offset.y ); ctx.scale(fill._renderer.scale.x, fill._renderer.scale.y); } ctx.fill(); if (isOffset) { ctx.restore(); } } if (!canvas.isHidden.test(stroke)) { isOffset = 
stroke._renderer && stroke._renderer.offset; if (isOffset) { ctx.save(); ctx.translate( -stroke._renderer.offset.x, -stroke._renderer.offset.y ); ctx.scale(stroke._renderer.scale.x, stroke._renderer.scale.y); ctx.lineWidth = linewidth / stroke._renderer.scale.x; } ctx.stroke(); if (isOffset) { ctx.restore(); } } } if (!defaultMatrix) { ctx.restore(); } if (dashes && dashes.length > 0) { ctx.setLineDash(emptyArray); } return this.flagReset(); } }, text: { render: function(ctx, forced, parentClipped) { const po = this.parent && this.parent._renderer ? this.parent._renderer.opacity : 1; const opacity = this._opacity * po; const visible = this._visible; const mask = this._mask; const clip = this._clip; if (!forced && (!visible || clip || opacity === 0)) { return this; } this._update(); const matrix = this._matrix.elements; const stroke = this._stroke; const linewidth = this._linewidth; const fill = this._fill; const decoration = this._decoration; const defaultMatrix = isDefaultMatrix(matrix); const isOffset = fill._renderer && fill._renderer.offset && stroke._renderer && stroke._renderer.offset; const dashes = this.dashes; const alignment = canvas.alignments[this._alignment] || this._alignment; const baseline = this._baseline; let a, b, c, d, e, sx, sy, x1, y1, x2, y2; if (!defaultMatrix) { ctx.save(); ctx.transform(matrix[0], matrix[3], matrix[1], matrix[4], matrix[2], matrix[5]); } if (mask) { canvas[mask._renderer.type].render.call(mask, ctx, true); } if (!isOffset) { ctx.font = [this._style, this._weight, this._size + "px/" + this._leading + "px", this._family].join(" "); } ctx.textAlign = alignment; ctx.textBaseline = baseline; if (fill) { if (typeof fill === "string") { ctx.fillStyle = fill; } else { canvas[fill._renderer.type].render.call(fill, ctx, this); ctx.fillStyle = fill._renderer.effect; } } if (stroke) { if (typeof stroke === "string") { ctx.strokeStyle = stroke; } else { canvas[stroke._renderer.type].render.call(stroke, ctx, this); ctx.strokeStyle = 
stroke._renderer.effect; } if (linewidth) { ctx.lineWidth = linewidth; } } if (typeof opacity === "number") { ctx.globalAlpha = opacity; } if (dashes && dashes.length > 0) { ctx.lineDashOffset = dashes.offset || 0; ctx.setLineDash(dashes); } if (!clip && !parentClipped) { if (!canvas.isHidden.test(fill)) { if (fill._renderer && fill._renderer.offset) { sx = fill._renderer.scale.x; sy = fill._renderer.scale.y; ctx.save(); ctx.translate( -fill._renderer.offset.x, -fill._renderer.offset.y ); ctx.scale(sx, sy); a = this._size / fill._renderer.scale.y; b = this._leading / fill._renderer.scale.y; ctx.font = [ this._style, this._weight, a + "px/", b + "px", this._family ].join(" "); c = fill._renderer.offset.x / fill._renderer.scale.x; d = fill._renderer.offset.y / fill._renderer.scale.y; ctx.fillText(this.value, c, d); ctx.restore(); } else { ctx.fillText(this.value, 0, 0); } } if (!canvas.isHidden.test(stroke)) { if (stroke._renderer && stroke._renderer.offset) { sx = stroke._renderer.scale.x; sy = stroke._renderer.scale.y; ctx.save(); ctx.translate( -stroke._renderer.offset.x, -stroke._renderer.offset.y ); ctx.scale(sx, sy); a = this._size / stroke._renderer.scale.y; b = this._leading / stroke._renderer.scale.y; ctx.font = [ this._style, this._weight, a + "px/", b + "px", this._family ].join(" "); c = stroke._renderer.offset.x / stroke._renderer.scale.x; d = stroke._renderer.offset.y / stroke._renderer.scale.y; e = linewidth / stroke._renderer.scale.x; ctx.lineWidth = e; ctx.strokeText(this.value, c, d); ctx.restore(); } else { ctx.strokeText(this.value, 0, 0); } } } if (/(underline|strikethrough)/i.test(decoration)) { const metrics = ctx.measureText(this.value); let scalar = 1; switch (decoration) { case "underline": y1 = metrics.actualBoundingBoxAscent; y2 = metrics.actualBoundingBoxAscent; break; case "strikethrough": y1 = 0; y2 = 0; scalar = 0.5; break; } switch (baseline) { case "top": y1 += this._size * scalar; y2 += this._size * scalar; break; case "baseline": 
case "bottom": y1 -= this._size * scalar; y2 -= this._size * scalar; break; } switch (alignment) { case "left": case "start": x1 = 0; x2 = metrics.width; break; case "right": case "end": x1 = -metrics.width; x2 = 0; break; default: x1 = -metrics.width / 2; x2 = metrics.width / 2; } ctx.lineWidth = Math.max(Math.floor(this._size / 15), 1); ctx.strokeStyle = ctx.fillStyle; ctx.beginPath(); ctx.moveTo(x1, y1); ctx.lineTo(x2, y2); ctx.stroke(); } if (!defaultMatrix) { ctx.restore(); } if (clip && !parentClipped) { ctx.clip(); } if (dashes && dashes.length > 0) { ctx.setLineDash(emptyArray); } return this.flagReset(); } }, "linear-gradient": { render: function(ctx, parent) { if (!parent) { return; } this._update(); if (!this._renderer.effect || this._flagEndPoints || this._flagStops || this._flagUnits) { let rect; let lx = this.left._x; let ly = this.left._y; let rx = this.right._x; let ry = this.right._y; if (/objectBoundingBox/i.test(this._units)) { rect = parent.getBoundingClientRect(true); lx = (lx - 0.5) * rect.width; ly = (ly - 0.5) * rect.height; rx = (rx - 0.5) * rect.width; ry = (ry - 0.5) * rect.height; } this._renderer.effect = ctx.createLinearGradient(lx, ly, rx, ry); for (let i = 0; i < this.stops.length; i++) { const stop = this.stops[i]; this._renderer.effect.addColorStop(stop._offset, stop._color); } } return this.flagReset(); } }, "radial-gradient": { render: function(ctx, parent) { if (!parent) { return; } this._update(); if (!this._renderer.effect || this._flagCenter || this._flagFocal || this._flagRadius || this._flagStops || this._flagUnits) { let rect; let cx = this.center._x; let cy = this.center._y; let fx = this.focal._x; let fy = this.focal._y; let radius = this._radius; if (/objectBoundingBox/i.test(this._units)) { rect = parent.getBoundingClientRect(true); cx = cx * rect.width * 0.5; cy = cy * rect.height * 0.5; fx = fx * rect.width * 0.5; fy = fy * rect.height * 0.5; radius *= Math.min(rect.width, rect.height) * 0.5; } this._renderer.effect 
= ctx.createRadialGradient( cx, cy, 0, fx, fy, radius ); for (let i = 0; i < this.stops.length; i++) { const stop = this.stops[i]; this._renderer.effect.addColorStop(stop._offset, stop._color); } } return this.flagReset(); } }, texture: { render: function(ctx) { this._update(); const image = this.image; if (!this._renderer.effect || (this._flagLoaded || this._flagImage || this._flagVideo || this._flagRepeat) && this.loaded) { this._renderer.effect = ctx.createPattern(this.image, this._repeat); } if (this._flagOffset || this._flagLoaded || this._flagScale) { if (!(this._renderer.offset instanceof Vector)) { this._renderer.offset = new Vector(); } this._renderer.offset.x = -this._offset.x; this._renderer.offset.y = -this._offset.y; if (image) { this._renderer.offset.x += image.width / 2; this._renderer.offset.y += image.height / 2; if (this._scale instanceof Vector) { this._renderer.offset.x *= this._scale.x; this._renderer.offset.y *= this._scale.y; } else { this._renderer.offset.x *= this._scale; this._renderer.offset.y *= this._scale; } } } if (this._flagScale || this._flagLoaded) { if (!(this._renderer.scale instanceof Vector)) { this._renderer.scale = new Vector(); } if (this._scale instanceof Vector) { this._renderer.scale.copy(this._scale); } else { this._renderer.scale.set(this._scale, this._scale); } } return this.flagReset(); } }, renderSvgArcCommand: function(ctx, ax, ay, rx, ry, largeArcFlag, sweepFlag, xAxisRotation, x, y) { xAxisRotation = xAxisRotation * Math.PI / 180; rx = abs(rx); ry = abs(ry); const dx2 = (ax - x) / 2; const dy2 = (ay - y) / 2; const x1p = cos2(xAxisRotation) * dx2 + sin2(xAxisRotation) * dy2; const y1p = -sin2(xAxisRotation) * dx2 + cos2(xAxisRotation) * dy2; const x1ps = x1p * x1p; const y1ps = y1p * y1p; let rxs = rx * rx; let rys = ry * ry; const cr = x1ps / rxs + y1ps / rys; if (cr > 1) { const s = sqrt(cr); rx = s * rx; ry = s * ry; rxs = rx * rx; rys = ry * ry; } const dq = rxs * y1ps + rys * x1ps; const pq = (rxs * rys - dq) 
/ dq; let q = sqrt(max2(0, pq)); if (largeArcFlag === sweepFlag) q = -q; const cxp = q * rx * y1p / ry; const cyp = -q * ry * x1p / rx; const cx = cos2(xAxisRotation) * cxp - sin2(xAxisRotation) * cyp + (ax + x) / 2; const cy = sin2(xAxisRotation) * cxp + cos2(xAxisRotation) * cyp + (ay + y) / 2; const startAngle = svgAngle(1, 0, (x1p - cxp) / rx, (y1p - cyp) / ry); const delta = svgAngle( (x1p - cxp) / rx, (y1p - cyp) / ry, (-x1p - cxp) / rx, (-y1p - cyp) / ry ) % TWO_PI; const endAngle = startAngle + delta; const clockwise = sweepFlag === 0; renderArcEstimate( ctx, cx, cy, rx, ry, startAngle, endAngle, clockwise, xAxisRotation ); } }; var Renderer = class extends Events { constructor(params) { super(); const smoothing = params.smoothing !== false; this.domElement = params.domElement || document.createElement("canvas"); this.ctx = this.domElement.getContext("2d"); this.overdraw = params.overdraw || false; if (typeof this.ctx.imageSmoothingEnabled !== "undefined") { this.ctx.imageSmoothingEnabled = smoothing; } this.scene = new Group(); this.scene.parent = this; } setSize(width, height, ratio) { this.width = width; this.height = height; this.ratio = typeof ratio === "undefined" ? 
getRatio(this.ctx) : ratio; this.domElement.width = width * this.ratio; this.domElement.height = height * this.ratio; if (this.domElement.style) { _.extend(this.domElement.style, { width: width + "px", height: height + "px" }); } return this.trigger(Events.Types.resize, width, height, ratio); } render() { const isOne = this.ratio === 1; if (!isOne) { this.ctx.save(); this.ctx.scale(this.ratio, this.ratio); } if (!this.overdraw) { this.ctx.clearRect(0, 0, this.width, this.height); } canvas.group.render.call(this.scene, this.ctx); if (!isOne) { this.ctx.restore(); } return this; } }; __publicField(Renderer, "Utils", canvas); function renderArcEstimate(ctx, ox, oy, rx, ry, startAngle, endAngle, clockwise, xAxisRotation) { const delta = endAngle - startAngle; const epsilon = Curve.Tolerance.epsilon; const samePoints = Math.abs(delta) < epsilon; let deltaAngle = mod(delta, TWO_PI); if (deltaAngle < epsilon) { if (samePoints) { deltaAngle = 0; } else { deltaAngle = TWO_PI; } } if (clockwise === true && !samePoints) { if (deltaAngle === TWO_PI) { deltaAngle = -TWO_PI; } else { deltaAngle = deltaAngle - TWO_PI; } } for (let i = 0; i < Constants.Resolution; i++) { const t = i / (Constants.Resolution - 1); const angle = startAngle + t * deltaAngle; let x = ox + rx * Math.cos(angle); let y = oy + ry * Math.sin(angle); if (xAxisRotation !== 0) { const cos7 = Math.cos(xAxisRotation); const sin7 = Math.sin(xAxisRotation); const tx = x - ox; const ty = y - oy; x = tx * cos7 - ty * sin7 + ox; y = tx * sin7 + ty * cos7 + oy; } ctx.lineTo(x, y); } } function svgAngle(ux, uy, vx, vy) { const dot = ux * vx + uy * vy; const len = sqrt(ux * ux + uy * uy) * sqrt(vx * vx + vy * vy); let ang = acos(max2(-1, min2(1, dot / len))); if (ux * vy - uy * vx < 0) { ang = -ang; } return ang; } function isDefaultMatrix(m) { return m[0] == 1 && m[3] == 0 && m[1] == 0 && m[4] == 1 && m[2] == 0 && m[5] == 0; } // src/utils/canvas-shim.js var CanvasShim = { Image: null, isHeadless: false, shim: 
function(canvas3, Image) { Renderer.Utils.shim(canvas3); if (typeof Image !== "undefined") { CanvasShim.Image = Image; } CanvasShim.isHeadless = true; return canvas3; } }; // src/utils/dom.js var dom = { hasEventListeners: typeof root.addEventListener === "function", bind: function(elem, event, func, bool) { if (this.hasEventListeners) { elem.addEventListener(event, func, !!bool); } else { elem.attachEvent("on" + event, func); } return dom; }, unbind: function(elem, event, func, bool) { if (dom.hasEventListeners) { elem.removeEventListeners(event, func, !!bool); } else { elem.detachEvent("on" + event, func); } return dom; }, getRequestAnimationFrame: function() { const vendors = ["ms", "moz", "webkit", "o"]; let lastTime = 0; let request = root.requestAnimationFrame; if (!request) { for (let i = 0; i < vendors.length; i++) { request = root[vendors[i] + "RequestAnimationFrame"] || request; } request = request || fallbackRequest; } function fallbackRequest(callback, element) { const currTime = new Date().getTime(); const timeToCall = Math.max(0, 16 - (currTime - lastTime)); const id = root.setTimeout(nextRequest, timeToCall); lastTime = currTime + timeToCall; function nextRequest() { callback(currTime + timeToCall); } return id; } return request; } }; var temp = root.document ? 
root.document.createElement("div") : {}; temp.id = "help-two-load"; Object.defineProperty(dom, "temp", { enumerable: true, get: function() { if (_.isElement(temp) && !root.document.head.contains(temp)) { temp.style.display = "none"; root.document.head.appendChild(temp); } return temp; } }); // src/utils/error.js var TwoError = class extends Error { name = "Two.js"; message; constructor(message) { super(); this.message = message; } }; // src/registry.js var Registry = class { map = {}; constructor() { } add(id, obj) { this.map[id] = obj; return this; } remove(id) { delete this.map[id]; return this; } get(id) { return this.map[id]; } contains(id) { return id in this.map; } }; // src/utils/shape.js function contains(path, t) { if (t === 0 || t === 1) { return true; } const length = path._length; const target = length * t; let elapsed = 0; for (let i = 0; i < path._lengths.length; i++) { const dist = path._lengths[i]; if (elapsed >= target) { return target - elapsed >= 0; } elapsed += dist; } return false; } function getIdByLength(path, target) { const total = path._length; if (target <= 0) { return 0; } else if (target >= total) { return path._lengths.length - 1; } for (let i = 0, sum = 0; i < path._lengths.length; i++) { if (sum + path._lengths[i] >= target) { target -= sum; return Math.max(i - 1, 0) + target / path._lengths[i]; } sum += path._lengths[i]; } return -1; } function getCurveLength2(a, b, limit) { let x1, x2, x3, x4, y1, y2, y3, y4; const right = b.controls && b.controls.right; const left = a.controls && a.controls.left; x1 = b.x; y1 = b.y; x2 = (right || b).x; y2 = (right || b).y; x3 = (left || a).x; y3 = (left || a).y; x4 = a.x; y4 = a.y; if (right && b._relative) { x2 += b.x; y2 += b.y; } if (left && a._relative) { x3 += a.x; y3 += a.y; } return getCurveLength(x1, y1, x2, y2, x3, y3, x4, y4, limit); } function getSubdivisions(a, b, limit) { let x1, x2, x3, x4, y1, y2, y3, y4; const right = b.controls && b.controls.right; const left = a.controls && 
a.controls.left; x1 = b.x; y1 = b.y; x2 = (right || b).x; y2 = (right || b).y; x3 = (left || a).x; y3 = (left || a).y; x4 = a.x; y4 = a.y; if (right && b._relative) { x2 += b.x; y2 += b.y; } if (left && a._relative) { x3 += a.x; y3 += a.y; } return subdivide(x1, y1, x2, y2, x3, y3, x4, y4, limit); } // src/effects/stop.js var _Stop = class extends Element { _flagOffset = true; _flagOpacity = true; _flagColor = true; _offset = 0; _opacity = 1; _color = "#fff"; constructor(offset, color, opacity) { super(); for (let prop in proto6) { Object.defineProperty(this, prop, proto6[prop]); } this._renderer.type = "stop"; this.offset = typeof offset === "number" ? offset : _Stop.Index <= 0 ? 0 : 1; this.opacity = typeof opacity === "number" ? opacity : 1; this.color = typeof color === "string" ? color : _Stop.Index <= 0 ? "#fff" : "#000"; _Stop.Index = (_Stop.Index + 1) % 2; } clone(parent) { const clone = new _Stop(); _.each(_Stop.Properties, function(property) { clone[property] = this[property]; }, this); if (parent && parent.stops) { parent.stops.push(clone); } return clone; } toObject() { const result = {}; _.each(_Stop.Properties, function(k) { result[k] = this[k]; }, this); return result; } flagReset() { this._flagOffset = this._flagColor = this._flagOpacity = false; super.flagReset.call(this); return this; } }; var Stop = _Stop; __publicField(Stop, "Index", 0); __publicField(Stop, "Properties", ["offset", "opacity", "color"]); var proto6 = { offset: { enumerable: true, get: function() { return this._offset; }, set: function(v) { this._offset = v; this._flagOffset = true; if (this.parent) { this.parent._flagStops = true; } } }, opacity: { enumerable: true, get: function() { return this._opacity; }, set: function(v) { this._opacity = v; this._flagOpacity = true; if (this.parent) { this.parent._flagStops = true; } } }, color: { enumerable: true, get: function() { return this._color; }, set: function(v) { this._color = v; this._flagColor = true; if (this.parent) { 
this.parent._flagStops = true; } } } }; // src/effects/gradient.js var _Gradient = class extends Element { _flagStops = false; _flagSpread = false; _flagUnits = false; _spread = ""; _units = ""; constructor(stops) { super(); for (let prop in proto7) { Object.defineProperty(this, prop, proto7[prop]); } this._renderer.type = "gradient"; this.id = Constants.Identifier + Constants.uniqueId(); this.classList = []; this._renderer.flagStops = FlagStops.bind(this); this._renderer.bindStops = BindStops.bind(this); this._renderer.unbindStops = UnbindStops.bind(this); this.spread = "pad"; this.units = "objectBoundingBox"; if (stops) { this.stops = stops; } } clone(parent) { const stops = this.stops.map(function(s) { return s.clone(); }); const clone = new _Gradient(stops); _.each(_Gradient.Properties, function(k) { clone[k] = this[k]; }, this); if (parent) { parent.add(clone); } return clone; } toObject() { const result = { stops: this.stops.map(function(s) { return s.toObject(); }) }; _.each(_Gradient.Properties, function(k) { result[k] = this[k]; }, this); return result; } _update() { if (this._flagSpread || this._flagStops) { this.trigger(Events.Types.change); } return this; } flagReset() { this._flagSpread = this._flagUnits = this._flagStops = false; super.flagReset.call(this); return this; } }; var Gradient = _Gradient; __publicField(Gradient, "Stop", Stop); __publicField(Gradient, "Properties", ["spread", "stops", "renderer", "units"]); var proto7 = { spread: { enumerable: true, get: function() { return this._spread; }, set: function(v) { this._spread = v; this._flagSpread = true; } }, units: { enumerable: true, get: function() { return this._units; }, set: function(v) { this._units = v; this._flagUnits = true; } }, stops: { enumerable: true, get: function() { return this._stops; }, set: function(stops) { const bindStops = this._renderer.bindStops; const unbindStops = this._renderer.unbindStops; if (this._stops) { this._stops.unbind(Events.Types.insert, 
bindStops).unbind(Events.Types.remove, unbindStops); } this._stops = new Collection((stops || []).slice(0)); this._stops.bind(Events.Types.insert, bindStops).bind(Events.Types.remove, unbindStops); bindStops(this._stops); } } }; function FlagStops() { this._flagStops = true; } function BindStops(items) { let i = items.length; while (i--) { items[i].bind(Events.Types.change, this._renderer.flagStops); items[i].parent = this; } this._renderer.flagStops(); } function UnbindStops(items) { let i = items.length; while (i--) { items[i].unbind(Events.Types.change, this._renderer.flagStops); delete items[i].parent; } this._renderer.flagStops(); } // src/effects/linear-gradient.js var _LinearGradient = class extends Gradient { _flagEndPoints = false; _left = null; _right = null; constructor(x1, y1, x2, y2, stops) { super(stops); for (let prop in proto8) { Object.defineProperty(this, prop, proto8[prop]); } this._renderer.type = "linear-gradient"; this._renderer.flagEndPoints = FlagEndPoints.bind(this); this.left = new Vector(); this.right = new Vector(); if (typeof x1 === "number") { this.left.x = x1; } if (typeof y1 === "number") { this.left.y = y1; } if (typeof x2 === "number") { this.right.x = x2; } if (typeof y2 === "number") { this.right.y = y2; } } clone(parent) { const stops = this.stops.map(function(stop) { return stop.clone(); }); const clone = new _LinearGradient( this.left._x, this.left._y, this.right._x, this.right._y, stops ); _.each(Gradient.Properties, function(k) { clone[k] = this[k]; }, this); if (parent) { parent.add(clone); } return clone; } toObject() { const result = super.toObject.call(this); result.left = this.left.toObject(); result.right = this.right.toObject(); return result; } _update() { if (this._flagEndPoints || this._flagSpread || this._flagStops) { this.trigger(Events.Types.change); } return this; } flagReset() { this._flagEndPoints = false; super.flagReset.call(this); return this; } }; var LinearGradient = _LinearGradient; 
__publicField(LinearGradient, "Properties", ["left", "right"]); __publicField(LinearGradient, "Stop", Stop); var proto8 = { left: { enumerable: true, get: function() { return this._left; }, set: function(v) { if (this._left instanceof Vector) { this._left.unbind(Events.Types.change, this._renderer.flagEndPoints); } this._left = v; this._left.bind(Events.Types.change, this._renderer.flagEndPoints); this._flagEndPoints = true; } }, right: { enumerable: true, get: function() { return this._right; }, set: function(v) { if (this._right instanceof Vector) { this._right.unbind(Events.Types.change, this._renderer.flagEndPoints); } this._right = v; this._right.bind(Events.Types.change, this._renderer.flagEndPoints); this._flagEndPoints = true; } } }; function FlagEndPoints() { this._flagEndPoints = true; } // src/effects/radial-gradient.js var _RadialGradient = class extends Gradient { _flagRadius = false; _flagCenter = false; _flagFocal = false; _radius = 0; _center = null; _focal = null; constructor(cx, cy, r, stops, fx, fy) { super(stops); for (let prop in proto9) { Object.defineProperty(this, prop, proto9[prop]); } this._renderer.type = "radial-gradient"; this._renderer.flagCenter = FlagCenter.bind(this); this._renderer.flagFocal = FlagFocal.bind(this); this.center = new Vector(); this.radius = typeof r === "number" ? 
r : 1; this.focal = new Vector(); if (typeof cx === "number") { this.center.x = cx; } if (typeof cy === "number") { this.center.y = cy; } this.focal.copy(this.center); if (typeof fx === "number") { this.focal.x = fx; } if (typeof fy === "number") { this.focal.y = fy; } } clone(parent) { const stops = this.stops.map(function(stop) { return stop.clone(); }); const clone = new _RadialGradient( this.center._x, this.center._y, this._radius, stops, this.focal._x, this.focal._y ); _.each(Gradient.Properties.concat(_RadialGradient.Properties), function(k) { clone[k] = this[k]; }, this); if (parent) { parent.add(clone); } return clone; } toObject() { const result = super.toObject.call(this); _.each(_RadialGradient.Properties, function(k) { result[k] = this[k]; }, this); result.center = this.center.toObject(); result.focal = this.focal.toObject(); return result; } _update() { if (this._flagRadius || this._flatCenter || this._flagFocal || this._flagSpread || this._flagStops) { this.trigger(Events.Types.change); } return this; } flagReset() { this._flagRadius = this._flagCenter = this._flagFocal = false; super.flagReset.call(this); return this; } }; var RadialGradient = _RadialGradient; __publicField(RadialGradient, "Stop", Stop); __publicField(RadialGradient, "Properties", ["center", "radius", "focal"]); var proto9 = { radius: { enumerable: true, get: function() { return this._radius; }, set: function(v) { this._radius = v; this._flagRadius = true; } }, center: { enumerable: true, get: function() { return this._center; }, set: function(v) { if (this._center) { this._center.unbind(Events.Types.change, this._renderer.flagCenter); } this._center = v; this._center.bind(Events.Types.change, this._renderer.flagCenter); this._flagCenter = true; } }, focal: { enumerable: true, get: function() { return this._focal; }, set: function(v) { if (this._focal) { this._focal.unbind(Events.Types.change, this._renderer.flagFocal); } this._focal = v; this._focal.bind(Events.Types.change, 
this._renderer.flagFocal); this._flagFocal = true; } } }; function FlagCenter() { this._flagCenter = true; } function FlagFocal() { this._flagFocal = true; } // src/effects/texture.js var anchor; var regex = { video: /\.(mp4|webm|ogg)$/i, image: /\.(jpe?g|png|gif|tiff|webp)$/i, effect: /texture|gradient/i }; if (root.document) { anchor = document.createElement("a"); } var _Texture = class extends Element { _flagSrc = false; _flagImage = false; _flagVideo = false; _flagLoaded = false; _flagRepeat = false; _flagOffset = false; _flagScale = false; _src = ""; _image = null; _loaded = false; _repeat = "no-repeat"; _scale = 1; _offset = null; constructor(src, callback) { super(); this._renderer = {}; for (let prop in proto10) { Object.defineProperty(this, prop, proto10[prop]); } this._renderer.type = "texture"; this._renderer.flagOffset = FlagOffset.bind(this); this._renderer.flagScale = FlagScale.bind(this); this.id = Constants.Identifier + Constants.uniqueId(); this.classList = []; this.loaded = false; this.repeat = "no-repeat"; this.offset = new Vector(); if (typeof callback === "function") { const loaded = function() { this.unbind(Events.Types.load, loaded); if (typeof callback === "function") { callback(); } }.bind(this); this.bind(Events.Types.load, loaded); } if (typeof src === "string") { this.src = src; } else if (typeof src === "object") { const elemString = Object.prototype.toString.call(src); if (elemString === "[object HTMLImageElement]" || elemString === "[object HTMLCanvasElement]" || elemString === "[object HTMLVideoElement]" || elemString === "[object Image]") { this.image = src; } } this._update(); } static getAbsoluteURL(path) { if (!anchor) { return path; } anchor.href = path; return anchor.href; } static loadHeadlessBuffer(texture, loaded) { texture.image.onload = loaded; texture.image.src = texture.src; } static getTag(image) { return image && image.nodeName && image.nodeName.toLowerCase() || "img"; } static getImage(src) { const absoluteSrc = 
_Texture.getAbsoluteURL(src); if (_Texture.ImageRegistry.contains(absoluteSrc)) { return _Texture.ImageRegistry.get(absoluteSrc); } let image; if (CanvasShim.Image) { image = new CanvasShim.Image(); Renderer.Utils.shim(image, "img"); } else if (root.document) { if (regex.video.test(absoluteSrc)) { image = document.createElement("video"); } else { image = document.createElement("img"); } } else { console.warn("Two.js: no prototypical image defined for Two.Texture"); } image.crossOrigin = "anonymous"; image.referrerPolicy = "no-referrer"; return image; } static load(texture, callback) { let image = texture.image; let tag = _Texture.getTag(image); if (texture._flagImage) { if (/canvas/i.test(tag)) { _Texture.Register.canvas(texture, callback); } else { texture._src = !CanvasShim.isHeadless && image.getAttribute("two-src") || image.src; _Texture.Register[tag](texture, callback); } } if (texture._flagSrc) { if (!image) { image = _Texture.getImage(texture.src); texture.image = image; } tag = _Texture.getTag(image); _Texture.Register[tag](texture, callback); } } clone() { const clone = new _Texture(this.src); clone.repeat = this.repeat; clone.offset.copy(this.origin); clone.scale = this.scale; return clone; } toObject() { return { src: this.src, repeat: this.repeat, origin: this.origin.toObject(), scale: typeof this.scale === "number" ? 
this.scale : this.scale.toObject() }; } _update() { if (this._flagSrc || this._flagImage) { this.trigger(Events.Types.change); if (this._flagSrc || this._flagImage) { this.loaded = false; _Texture.load(this, function() { this.loaded = true; this.trigger(Events.Types.change).trigger(Events.Types.load); }.bind(this)); } } if (this._image && this._image.readyState >= 4) { this._flagVideo = true; } return this; } flagReset() { this._flagSrc = this._flagImage = this._flagLoaded = this._flagRepeat = this._flagVideo = this._flagScale = this._flagOffset = false; super.flagReset.call(this); return this; } }; var Texture = _Texture; __publicField(Texture, "Properties", [ "src", "loaded", "repeat", "scale", "offset", "image" ]); __publicField(Texture, "RegularExpressions", regex); __publicField(Texture, "ImageRegistry", new Registry()); __publicField(Texture, "Register", { canvas: function(texture, callback) { texture._src = "#" + texture.id; _Texture.ImageRegistry.add(texture.src, texture.image); if (typeof callback === "function") { callback(); } }, img: function(texture, callback) { const image = texture.image; const loaded = function(e) { if (!CanvasShim.isHeadless && image.removeEventListener && typeof image.removeEventListener === "function") { image.removeEventListener("load", loaded, false); image.removeEventListener("error", error, false); } if (typeof callback === "function") { callback(); } }; const error = function(e) { if (!CanvasShim.isHeadless && typeof image.removeEventListener === "function") { image.removeEventListener("load", loaded, false); image.removeEventListener("error", error, false); } throw new TwoError("unable to load " + texture.src); }; if (typeof image.width === "number" && image.width > 0 && typeof image.height === "number" && image.height > 0) { loaded(); } else if (!CanvasShim.isHeadless && typeof image.addEventListener === "function") { image.addEventListener("load", loaded, false); image.addEventListener("error", error, false); } 
texture._src = _Texture.getAbsoluteURL(texture._src); if (!CanvasShim.isHeadless && image && image.getAttribute("two-src")) { return; } if (!CanvasShim.isHeadless) { image.setAttribute("two-src", texture.src); } _Texture.ImageRegistry.add(texture.src, image); if (CanvasShim.isHeadless) { _Texture.loadHeadlessBuffer(texture, loaded); } else { texture.image.src = texture.src; } }, video: function(texture, callback) { if (CanvasShim.isHeadless) { throw new TwoError("video textures are not implemented in headless environments."); } const loaded = function(e) { texture.image.removeEventListener("canplaythrough", loaded, false); texture.image.removeEventListener("error", error, false); texture.image.width = texture.image.videoWidth; texture.image.height = texture.image.videoHeight; if (typeof callback === "function") { callback(); } }; const error = function(e) { texture.image.removeEventListener("canplaythrough", loaded, false); texture.image.removeEventListener("error", error, false); throw new TwoError("unable to load " + texture.src); }; texture._src = _Texture.getAbsoluteURL(texture._src); if (!texture.image.getAttribute("two-src")) { texture.image.setAttribute("two-src", texture.src); _Texture.ImageRegistry.add(texture.src, texture.image); } if (texture.image.readyState >= 4) { loaded(); } else { texture.image.addEventListener("canplaythrough", loaded, false); texture.image.addEventListener("error", error, false); texture.image.src = texture.src; texture.image.load(); } } }); var proto10 = { src: { enumerable: true, get: function() { return this._src; }, set: function(v) { this._src = v; this._flagSrc = true; } }, loaded: { enumerable: true, get: function() { return this._loaded; }, set: function(v) { this._loaded = v; this._flagLoaded = true; } }, repeat: { enumerable: true, get: function() { return this._repeat; }, set: function(v) { this._repeat = v; this._flagRepeat = true; } }, image: { enumerable: true, get: function() { return this._image; }, set: 
function(image) { const tag = Texture.getTag(image); let index; switch (tag) { case "canvas": index = "#" + image.id; break; default: index = image.src; } if (Texture.ImageRegistry.contains(index)) { this._image = Texture.ImageRegistry.get(image.src); } else { this._image = image; } this._flagImage = true; } }, offset: { enumerable: true, get: function() { return this._offset; }, set: function(v) { if (this._offset) { this._offset.unbind(Events.Types.change, this._renderer.flagOffset); } this._offset = v; this._offset.bind(Events.Types.change, this._renderer.flagOffset); this._flagOffset = true; } }, scale: { enumerable: true, get: function() { return this._scale; }, set: function(v) { if (this._scale instanceof Vector) { this._scale.unbind(Events.Types.change, this._renderer.flagScale); } this._scale = v; if (this._scale instanceof Vector) { this._scale.bind(Events.Types.change, this._renderer.flagScale); } this._flagScale = true; } } }; function FlagOffset() { this._flagOffset = true; } function FlagScale() { this._flagScale = true; } // src/path.js var min3 = Math.min; var max3 = Math.max; var ceil = Math.ceil; var floor2 = Math.floor; var vector = new Vector(); var _Path = class extends Shape { _flagVertices = true; _flagLength = true; _flagFill = true; _flagStroke = true; _flagLinewidth = true; _flagOpacity = true; _flagVisible = true; _flagCap = true; _flagJoin = true; _flagMiter = true; _flagMask = false; _flagClip = false; _length = 0; _fill = "#fff"; _stroke = "#000"; _linewidth = 1; _opacity = 1; _visible = true; _cap = "round"; _join = "round"; _miter = 4; _closed = true; _curved = false; _automatic = true; _beginning = 0; _ending = 1; _mask = null; _clip = false; _dashes = null; constructor(vertices, closed2, curved, manual) { super(); for (let prop in proto11) { Object.defineProperty(this, prop, proto11[prop]); } this._renderer.type = "path"; this._renderer.flagVertices = FlagVertices.bind(this); this._renderer.bindVertices = BindVertices.bind(this); 
this._renderer.unbindVertices = UnbindVertices.bind(this); this._renderer.flagFill = FlagFill.bind(this); this._renderer.flagStroke = FlagStroke.bind(this); this._renderer.vertices = []; this._renderer.collection = []; this.closed = !!closed2; this.curved = !!curved; this.beginning = 0; this.ending = 1; this.fill = "#fff"; this.stroke = "#000"; this.linewidth = 1; this.opacity = 1; this.className = ""; this.visible = true; this.cap = "butt"; this.join = "miter"; this.miter = 4; this.vertices = vertices; this.automatic = !manual; this.dashes = []; this.dashes.offset = 0; } clone(parent) { const clone = new _Path(); for (let j = 0; j < this.vertices.length; j++) { clone.vertices.push(this.vertices[j].clone()); } for (let i = 0; i < _Path.Properties.length; i++) { const k = _Path.Properties[i]; clone[k] = this[k]; } clone.className = this.className; clone.translation.copy(this.translation); clone.rotation = this.rotation; clone.scale = this.scale; clone.skewX = this.skewX; clone.skewY = this.skewY; if (this.matrix.manual) { clone.matrix.copy(this.matrix); } if (parent) { parent.add(clone); } return clone._update(); } toObject() { const result = { vertices: this.vertices.map(function(v) { return v.toObject(); }) }; _.each(_Path.Properties, function(k) { if (typeof this[k] !== "undefined") { if (this[k].toObject) { result[k] = this[k].toObject(); } else { result[k] = this[k]; } } }, this); result.className = this.className; result.translation = this.translation.toObject(); result.rotation = this.rotation; result.scale = this.scale instanceof Vector ? 
this.scale.toObject() : this.scale; result.skewX = this.skewX; result.skewY = this.skewY; if (this.matrix.manual) { result.matrix = this.matrix.toObject(); } return result; } noFill() { this.fill = "none"; return this; } noStroke() { this.stroke = "none"; return this; } corner() { const rect = this.getBoundingClientRect(true); const hw = rect.width / 2; const hh = rect.height / 2; const cx = rect.left + rect.width / 2; const cy = rect.top + rect.height / 2; for (let i = 0; i < this.vertices.length; i++) { const v = this.vertices[i]; v.x -= cx; v.y -= cy; v.x += hw; v.y += hh; } if (this.mask) { this.mask.translation.x -= cx; this.mask.translation.x += hw; this.mask.translation.y -= cy; this.mask.translation.y += hh; } return this; } center() { const rect = this.getBoundingClientRect(true); const cx = rect.left + rect.width / 2 - this.translation.x; const cy = rect.top + rect.height / 2 - this.translation.y; for (let i = 0; i < this.vertices.length; i++) { const v = this.vertices[i]; v.x -= cx; v.y -= cy; } if (this.mask) { this.mask.translation.x -= cx; this.mask.translation.y -= cy; } return this; } getBoundingClientRect(shallow) { let matrix, border, l, i, v0, v1; let left = Infinity, right = -Infinity, top = Infinity, bottom = -Infinity; this._update(true); matrix = shallow ? this.matrix : this.worldMatrix; border = (this.linewidth || 0) / 2; l = this._renderer.vertices.length; if (this.linewidth > 0 || this.stroke && !/(transparent|none)/i.test(this.stroke)) { if (this.matrix.manual) { const { scaleX, scaleY } = decomposeMatrix( matrix.elements[0], matrix.elements[3], matrix.elements[1], matrix.elements[4], matrix.elements[2], matrix.elements[5] ); if (typeof scaleX === "number" && typeof scaleY === "number") { border = Math.max(scaleX, scaleY) * (this.linewidth || 0) / 2; } } else { border *= typeof this.scale === "number" ? 
this.scale : Math.max(this.scale.x, this.scale.y); } } if (l <= 0) { return { width: 0, height: 0 }; } for (i = 0; i < l; i++) { v1 = this._renderer.vertices[i]; v0 = this._renderer.vertices[(i + l - 1) % l]; const [v0x, v0y] = matrix.multiply(v0.x, v0.y); const [v1x, v1y] = matrix.multiply(v1.x, v1.y); if (v0.controls && v1.controls) { let rx = v0.controls.right.x; let ry = v0.controls.right.y; if (v0.relative) { rx += v0.x; ry += v0.y; } let [c0x, c0y] = matrix.multiply(rx, ry); let lx = v1.controls.left.x; let ly = v1.controls.left.y; if (v1.relative) { lx += v1.x; ly += v1.y; } let [c1x, c1y] = matrix.multiply(lx, ly); const bb = getCurveBoundingBox( v0x, v0y, c0x, c0y, c1x, c1y, v1x, v1y ); top = min3(bb.min.y - border, top); left = min3(bb.min.x - border, left); right = max3(bb.max.x + border, right); bottom = max3(bb.max.y + border, bottom); } else { if (i <= 1) { top = min3(v0y - border, top); left = min3(v0x - border, left); right = max3(v0x + border, right); bottom = max3(v0y + border, bottom); } top = min3(v1y - border, top); left = min3(v1x - border, left); right = max3(v1x + border, right); bottom = max3(v1y + border, bottom); } } return { top, left, right, bottom, width: right - left, height: bottom - top }; } getPointAt(t, obj) { let ia, ib, result; let x, x1, x2, x3, x4, y, y1, y2, y3, y4, left, right; let target = this.length * Math.min(Math.max(t, 0), 1); const length = this.vertices.length; const last = length - 1; let a = null; let b = null; for (let i = 0, l = this._lengths.length, sum = 0; i < l; i++) { if (sum + this._lengths[i] >= target) { if (this._closed) { ia = mod(i, length); ib = mod(i - 1, length); if (i === 0) { ia = ib; ib = i; } } else { ia = i; ib = Math.min(Math.max(i - 1, 0), last); } a = this.vertices[ia]; b = this.vertices[ib]; target -= sum; if (this._lengths[i] !== 0) { t = target / this._lengths[i]; } else { t = 0; } break; } sum += this._lengths[i]; } if (a === null || b === null) { return null; } if (!a) { return b; } 
else if (!b) { return a; } right = b.controls && b.controls.right; left = a.controls && a.controls.left; x1 = b.x; y1 = b.y; x2 = (right || b).x; y2 = (right || b).y; x3 = (left || a).x; y3 = (left || a).y; x4 = a.x; y4 = a.y; if (right && b.relative) { x2 += b.x; y2 += b.y; } if (left && a.relative) { x3 += a.x; y3 += a.y; } x = getComponentOnCubicBezier(t, x1, x2, x3, x4); y = getComponentOnCubicBezier(t, y1, y2, y3, y4); const t1x = lerp(x1, x2, t); const t1y = lerp(y1, y2, t); const t2x = lerp(x2, x3, t); const t2y = lerp(y2, y3, t); const t3x = lerp(x3, x4, t); const t3y = lerp(y3, y4, t); const brx = lerp(t1x, t2x, t); const bry = lerp(t1y, t2y, t); const alx = lerp(t2x, t3x, t); const aly = lerp(t2y, t3y, t); if (_.isObject(obj)) { obj.x = x; obj.y = y; if (obj instanceof Anchor) { obj.controls.left.x = brx; obj.controls.left.y = bry; obj.controls.right.x = alx; obj.controls.right.y = aly; if (!(typeof obj.relative === "boolean") || obj.relative) { obj.controls.left.x -= x; obj.controls.left.y -= y; obj.controls.right.x -= x; obj.controls.right.y -= y; } } obj.t = t; return obj; } result = new Anchor( x, y, brx - x, bry - y, alx - x, aly - y, this._curved ? Commands.curve : Commands.line ); result.t = t; return result; } plot() { if (this.curved) { getCurveFromPoints(this._collection, this.closed); return this; } for (let i = 0; i < this._collection.length; i++) { this._collection[i].command = i === 0 ? 
Commands.move : Commands.line; } return this; } subdivide(limit) { this._update(); const last = this.vertices.length - 1; const closed2 = this._closed || this.vertices[last]._command === Commands.close; let b = this.vertices[last]; let points = [], verts; _.each(this.vertices, function(a, i) { if (i <= 0 && !closed2) { b = a; return; } if (a.command === Commands.move) { points.push(new Anchor(b.x, b.y)); if (i > 0) { points[points.length - 1].command = Commands.line; } b = a; return; } verts = getSubdivisions(a, b, limit); points = points.concat(verts); _.each(verts, function(v, i2) { if (i2 <= 0 && b.command === Commands.move) { v.command = Commands.move; } else { v.command = Commands.line; } }); if (i >= last) { if (this._closed && this._automatic) { b = a; verts = getSubdivisions(a, b, limit); points = points.concat(verts); _.each(verts, function(v, i2) { if (i2 <= 0 && b.command === Commands.move) { v.command = Commands.move; } else { v.command = Commands.line; } }); } else if (closed2) { points.push(new Anchor(a.x, a.y)); } points[points.length - 1].command = closed2 ? 
Commands.close : Commands.line; } b = a; }, this); this._automatic = false; this._curved = false; this.vertices = points; return this; } _updateLength(limit, silent) { if (!silent) { this._update(); } const length = this.vertices.length; const last = length - 1; const closed2 = false; let b = this.vertices[last]; let sum = 0; if (typeof this._lengths === "undefined") { this._lengths = []; } _.each(this.vertices, function(a, i) { if (i <= 0 && !closed2 || a.command === Commands.move) { b = a; this._lengths[i] = 0; return; } this._lengths[i] = getCurveLength2(a, b, limit); sum += this._lengths[i]; if (i >= last && closed2) { b = this.vertices[(i + 1) % length]; this._lengths[i + 1] = getCurveLength2(a, b, limit); sum += this._lengths[i + 1]; } b = a; }, this); this._length = sum; this._flagLength = false; return this; } _update() { if (this._flagVertices) { if (this._automatic) { this.plot(); } if (this._flagLength) { this._updateLength(void 0, true); } const l = this._collection.length; const closed2 = this._closed; const beginning = Math.min(this._beginning, this._ending); const ending = Math.max(this._beginning, this._ending); const bid = getIdByLength(this, beginning * this._length); const eid = getIdByLength(this, ending * this._length); const low = ceil(bid); const high = floor2(eid); let left, right, prev, next, v, i; this._renderer.vertices.length = 0; for (i = 0; i < l; i++) { if (this._renderer.collection.length <= i) { this._renderer.collection.push(new Anchor()); } if (i > high && !right) { v = this._renderer.collection[i].copy(this._collection[i]); this.getPointAt(ending, v); v.command = this._renderer.collection[i].command; this._renderer.vertices.push(v); right = v; prev = this._collection[i - 1]; if (prev && prev.controls) { if (v.relative) { v.controls.right.clear(); } else { v.controls.right.copy(v); } if (prev.relative) { this._renderer.collection[i - 1].controls.right.copy(prev.controls.right).lerp(Vector.zero, 1 - v.t); } else { 
this._renderer.collection[i - 1].controls.right.copy(prev.controls.right).lerp(prev, 1 - v.t); } } } else if (i >= low && i <= high) { v = this._renderer.collection[i].copy(this._collection[i]); this._renderer.vertices.push(v); if (i === high && contains(this, ending)) { right = v; if (!closed2 && right.controls) { if (right.relative) { right.controls.right.clear(); } else { right.controls.right.copy(right); } } } else if (i === low && contains(this, beginning)) { left = v; left.command = Commands.move; if (!closed2 && left.controls) { if (left.relative) { left.controls.left.clear(); } else { left.controls.left.copy(left); } } } } } if (low > 0 && !left) { i = low - 1; v = this._renderer.collection[i].copy(this._collection[i]); this.getPointAt(beginning, v); v.command = Commands.move; this._renderer.vertices.unshift(v); next = this._collection[i + 1]; if (next && next.controls) { v.controls.left.clear(); if (next.relative) { this._renderer.collection[i + 1].controls.left.copy(next.controls.left).lerp(Vector.zero, v.t); } else { vector.copy(next); this._renderer.collection[i + 1].controls.left.copy(next.controls.left).lerp(next, v.t); } } } } Shape.prototype._update.apply(this, arguments); return this; } flagReset() { this._flagVertices = this._flagLength = this._flagFill = this._flagStroke = this._flagLinewidth = this._flagOpacity = this._flagVisible = this._flagCap = this._flagJoin = this._flagMiter = this._flagClip = false; Shape.prototype.flagReset.call(this); return this; } }; var Path = _Path; __publicField(Path, "Properties", [ "fill", "stroke", "linewidth", "opacity", "visible", "cap", "join", "miter", "closed", "curved", "automatic", "beginning", "ending" ]); __publicField(Path, "Utils", { getCurveLength: getCurveLength2 }); var proto11 = { linewidth: { enumerable: true, get: function() { return this._linewidth; }, set: function(v) { this._linewidth = v; this._flagLinewidth = true; } }, opacity: { enumerable: true, get: function() { return this._opacity; }, 
set: function(v) { this._opacity = v; this._flagOpacity = true; } }, visible: { enumerable: true, get: function() { return this._visible; }, set: function(v) { this._visible = v; this._flagVisible = true; } }, cap: { enumerable: true, get: function() { return this._cap; }, set: function(v) { this._cap = v; this._flagCap = true; } }, join: { enumerable: true, get: function() { return this._join; }, set: function(v) { this._join = v; this._flagJoin = true; } }, miter: { enumerable: true, get: function() { return this._miter; }, set: function(v) { this._miter = v; this._flagMiter = true; } }, fill: { enumerable: true, get: function() { return this._fill; }, set: function(f) { if (this._fill instanceof Gradient || this._fill instanceof LinearGradient || this._fill instanceof RadialGradient || this._fill instanceof Texture) { this._fill.unbind(Events.Types.change, this._renderer.flagFill); } this._fill = f; this._flagFill = true; if (this._fill instanceof Gradient || this._fill instanceof LinearGradient || this._fill instanceof RadialGradient || this._fill instanceof Texture) { this._fill.bind(Events.Types.change, this._renderer.flagFill); } } }, stroke: { enumerable: true, get: function() { return this._stroke; }, set: function(f) { if (this._stroke instanceof Gradient || this._stroke instanceof LinearGradient || this._stroke instanceof RadialGradient || this._stroke instanceof Texture) { this._stroke.unbind(Events.Types.change, this._renderer.flagStroke); } this._stroke = f; this._flagStroke = true; if (this._stroke instanceof Gradient || this._stroke instanceof LinearGradient || this._stroke instanceof RadialGradient || this._stroke instanceof Texture) { this._stroke.bind(Events.Types.change, this._renderer.flagStroke); } } }, length: { get: function() { if (this._flagLength) { this._updateLength(); } return this._length; } }, closed: { enumerable: true, get: function() { return this._closed; }, set: function(v) { this._closed = !!v; this._flagVertices = true; } }, 
curved: { enumerable: true, get: function() { return this._curved; }, set: function(v) { this._curved = !!v; this._flagVertices = true; } }, automatic: { enumerable: true, get: function() { return this._automatic; }, set: function(v) { if (v === this._automatic) { return; } this._automatic = !!v; const method = this._automatic ? "ignore" : "listen"; _.each(this.vertices, function(v2) { v2[method](); }); } }, beginning: { enumerable: true, get: function() { return this._beginning; }, set: function(v) { this._beginning = v; this._flagVertices = true; } }, ending: { enumerable: true, get: function() { return this._ending; }, set: function(v) { this._ending = v; this._flagVertices = true; } }, vertices: { enumerable: true, get: function() { return this._collection; }, set: function(vertices) { const bindVertices = this._renderer.bindVertices; const unbindVertices = this._renderer.unbindVertices; if (this._collection) { this._collection.unbind(Events.Types.insert, bindVertices).unbind(Events.Types.remove, unbindVertices); } if (vertices instanceof Collection) { this._collection = vertices; } else { this._collection = new Collection(vertices || []); } this._collection.bind(Events.Types.insert, bindVertices).bind(Events.Types.remove, unbindVertices); bindVertices(this._collection); } }, mask: { enumerable: true, get: function() { return this._mask; }, set: function(v) { this._mask = v; this._flagMask = true; if (_.isObject(v) && !v.clip) { v.clip = true; } } }, clip: { enumerable: true, get: function() { return this._clip; }, set: function(v) { this._clip = v; this._flagClip = true; } }, dashes: { enumerable: true, get: function() { return this._dashes; }, set: function(v) { if (typeof v.offset !== "number") { v.offset = this.dashes && this._dashes.offset || 0; } this._dashes = v; } } }; function FlagVertices() { this._flagVertices = true; this._flagLength = true; if (this.parent) { this.parent._flagLength = true; } } function BindVertices(items) { let i = items.length; 
while (i--) { items[i].bind(Events.Types.change, this._renderer.flagVertices); } this._renderer.flagVertices(); } function UnbindVertices(items) { let i = items.length; while (i--) { items[i].unbind(Events.Types.change, this._renderer.flagVertices); } this._renderer.flagVertices(); } function FlagFill() { this._flagFill = true; } function FlagStroke() { this._flagStroke = true; } // src/shapes/rectangle.js var _Rectangle = class extends Path { constructor(x, y, width, height) { const points = [ new Anchor(), new Anchor(), new Anchor(), new Anchor() ]; super(points, true, false, true); for (let prop in proto12) { Object.defineProperty(this, prop, proto12[prop]); } this.width = typeof width === "number" ? width : 1; this.height = typeof height === "number" ? height : 1; this.origin = new Vector(); if (typeof x === "number") { this.translation.x = x; } if (typeof y === "number") { this.translation.y = y; } this._update(); } _flagWidth = 0; _flagHeight = 0; _width = 0; _height = 0; _origin = null; _update() { if (this._flagVertices || this._flagWidth || this._flagHeight) { const xr = this._width / 2; const yr = this._height / 2; if (!this._closed && this.vertices.length === 4) { this.vertices.push(new Anchor()); } this.vertices[0].set(-xr, -yr).sub(this._origin).command = Commands.move; this.vertices[1].set(xr, -yr).sub(this._origin).command = Commands.line; this.vertices[2].set(xr, yr).sub(this._origin).command = Commands.line; this.vertices[3].set(-xr, yr).sub(this._origin).command = Commands.line; if (this.vertices[4]) { this.vertices[4].set(-xr, -yr).sub(this._origin).command = Commands.line; } } super._update.call(this); return this; } flagReset() { this._flagWidth = this._flagHeight = false; super.flagReset.call(this); return this; } clone(parent) { const clone = new _Rectangle(0, 0, this.width, this.height); clone.translation.copy(this.translation); clone.rotation = this.rotation; clone.scale = this.scale; clone.skewX = this.skewX; clone.skewY = this.skewY; if 
(this.matrix.manual) { clone.matrix.copy(this.matrix); } for (let i = 0; i < Path.Properties.length; i++) { const k = Path.Properties[i]; clone[k] = this[k]; } if (parent) { parent.add(clone); } return clone; } toObject() { const object = super.toObject.call(this); object.width = this.width; object.height = this.height; object.origin = this.origin.toObject(); return object; } }; var Rectangle = _Rectangle; __publicField(Rectangle, "Properties", ["width", "height"]); var proto12 = { width: { enumerable: true, get: function() { return this._width; }, set: function(v) { this._width = v; this._flagWidth = true; } }, height: { enumerable: true, get: function() { return this._height; }, set: function(v) { this._height = v; this._flagHeight = true; } }, origin: { enumerable: true, get: function() { return this._origin; }, set: function(v) { if (this._origin) { this._origin.unbind(Events.Types.change, this._renderer.flagVertices); } this._origin = v; this._origin.bind(Events.Types.change, this._renderer.flagVertices); this._renderer.flagVertices(); } } }; // src/effects/sprite.js var _Sprite = class extends Rectangle { _flagTexture = false; _flagColumns = false; _flagRows = false; _flagFrameRate = false; _flagIndex = false; _amount = 1; _duration = 0; _startTime = 0; _playing = false; _firstFrame = 0; _lastFrame = 0; _loop = true; _texture = null; _columns = 1; _rows = 1; _frameRate = 0; _index = 0; _origin = null; constructor(path, ox, oy, cols, rows, frameRate) { super(ox, oy, 0, 0); for (let prop in proto13) { Object.defineProperty(this, prop, proto13[prop]); } this.noStroke(); this.noFill(); if (path instanceof Texture) { this.texture = path; } else if (typeof path === "string") { this.texture = new Texture(path); } this.origin = new Vector(); this._update(); if (typeof cols === "number") { this.columns = cols; } if (typeof rows === "number") { this.rows = rows; } if (typeof frameRate === "number") { this.frameRate = frameRate; } this.index = 0; } play(firstFrame, 
lastFrame, onLastFrame) { this._playing = true; this._firstFrame = 0; this._lastFrame = this.amount - 1; this._startTime = _.performance.now(); if (typeof firstFrame === "number") { this._firstFrame = firstFrame; } if (typeof lastFrame === "number") { this._lastFrame = lastFrame; } if (typeof onLastFrame === "function") { this._onLastFrame = onLastFrame; } else { delete this._onLastFrame; } if (this._index !== this._firstFrame) { this._startTime -= 1e3 * Math.abs(this._index - this._firstFrame) / this._frameRate; } return this; } pause() { this._playing = false; return this; } stop() { this._playing = false; this._index = 0; return this; } clone(parent) { const clone = new _Sprite( this.texture, this.translation.x, this.translation.y, this.columns, this.rows, this.frameRate ); if (this.playing) { clone.play(this._firstFrame, this._lastFrame); clone._loop = this._loop; } if (parent) { parent.add(clone); } return clone; } toObject() { const object = super.toObject.call(this); object.texture = this.texture.toObject(); object.columns = this.columns; object.rows = this.rows; object.frameRate = this.frameRate; object.index = this.index; object._firstFrame = this._firstFrame; object._lastFrame = this._lastFrame; object._loop = this._loop; return object; } _update() { const effect = this._texture; const cols = this._columns; const rows = this._rows; let width, height, elapsed, amount, duration; let index, iw, ih, frames; if (effect) { if (this._flagColumns || this._flagRows) { this._amount = this._columns * this._rows; } if (this._flagFrameRate) { this._duration = 1e3 * this._amount / this._frameRate; } if (this._flagTexture) { this.fill = effect; } if (effect.loaded) { iw = effect.image.width; ih = effect.image.height; width = iw / cols; height = ih / rows; amount = this._amount; if (this.width !== width) { this.width = width; } if (this.height !== height) { this.height = height; } if (this._playing && this._frameRate > 0) { if (_.isNaN(this._lastFrame)) { this._lastFrame 
= amount - 1; } elapsed = _.performance.now() - this._startTime; frames = this._lastFrame + 1; duration = 1e3 * (frames - this._firstFrame) / this._frameRate; if (this._loop) { elapsed = elapsed % duration; } else { elapsed = Math.min(elapsed, duration); } index = lerp(this._firstFrame, frames, elapsed / duration); index = Math.floor(index); if (index !== this._index) { this._index = index; if (index >= this._lastFrame - 1 && this._onLastFrame) { this._onLastFrame(); } } } const col = this._index % cols; const row = Math.floor(this._index / cols); const ox = -width * col + (iw - width) / 2; const oy = -height * row + (ih - height) / 2; if (ox !== effect.offset.x) { effect.offset.x = ox; } if (oy !== effect.offset.y) { effect.offset.y = oy; } } } super._update.call(this); return this; } flagReset() { this._flagTexture = this._flagColumns = this._flagRows = this._flagFrameRate = false; super.flagReset.call(this); return this; } }; var Sprite = _Sprite; __publicField(Sprite, "Properties", [ "texture", "columns", "rows", "frameRate", "index" ]); var proto13 = { texture: { enumerable: true, get: function() { return this._texture; }, set: function(v) { this._texture = v; this._flagTexture = true; } }, columns: { enumerable: true, get: function() { return this._columns; }, set: function(v) { this._columns = v; this._flagColumns = true; } }, rows: { enumerable: true, get: function() { return this._rows; }, set: function(v) { this._rows = v; this._flagRows = true; } }, frameRate: { enumerable: true, get: function() { return this._frameRate; }, set: function(v) { this._frameRate = v; this._flagFrameRate = true; } }, index: { enumerable: true, get: function() { return this._index; }, set: function(v) { this._index = v; this._flagIndex = true; } } }; // src/shapes/circle.js var cos3 = Math.cos; var sin3 = Math.sin; var _Circle = class extends Path { _flagRadius = false; _radius = 0; constructor(ox, oy, r, resolution) { const amount = resolution ? 
Math.max(resolution, 2) : 4; const points = []; for (let i = 0; i < amount; i++) { points.push(new Anchor(0, 0, 0, 0, 0, 0)); } super(points, true, true, true); for (let prop in proto14) { Object.defineProperty(this, prop, proto14[prop]); } if (typeof r === "number") { this.radius = r; } this._update(); if (typeof ox === "number") { this.translation.x = ox; } if (typeof oy === "number") { this.translation.y = oy; } } _update() { if (this._flagVertices || this._flagRadius) { let length = this.vertices.length; if (!this._closed && length > 2) { length -= 1; } const c = 4 / 3 * Math.tan(Math.PI / (length * 2)); const radius = this._radius; const rc = radius * c; for (let i = 0; i < this.vertices.length; i++) { const pct = i / length; const theta = pct * TWO_PI; const x = radius * cos3(theta); const y = radius * sin3(theta); const lx = rc * cos3(theta - HALF_PI); const ly = rc * sin3(theta - HALF_PI); const rx = rc * cos3(theta + HALF_PI); const ry = rc * sin3(theta + HALF_PI); const v = this.vertices[i]; v.command = i === 0 ? 
Commands.move : Commands.curve; v.set(x, y); v.controls.left.set(lx, ly); v.controls.right.set(rx, ry); } } super._update.call(this); return this; } flagReset() { this._flagRadius = false; super.flagReset.call(this); return this; } clone(parent) { const clone = new _Circle(0, 0, this.radius, this.vertices.length); clone.translation.copy(this.translation); clone.rotation = this.rotation; clone.scale = this.scale; clone.skewX = this.skewX; clone.skewY = this.skewY; if (this.matrix.manual) { clone.matrix.copy(this.matrix); } for (let i = 0; i < Path.Properties.length; i++) { const k = Path.Properties[i]; clone[k] = this[k]; } if (parent) { parent.add(clone); } return clone; } toObject() { const object = super.toObject.call(this); for (let i = 0; i < _Circle.Properties.length; i++) { const k = _Circle.Properties[i]; object[k] = this[k]; } return object; } }; var Circle = _Circle; __publicField(Circle, "Properties", ["radius"]); var proto14 = { radius: { enumerable: true, get: function() { return this._radius; }, set: function(v) { this._radius = v; this._flagRadius = true; } } }; // src/shapes/ellipse.js var cos4 = Math.cos; var sin4 = Math.sin; var _Ellipse = class extends Path { _flagWidth = false; _flagHeight = false; _width = 0; _height = 0; constructor(x, y, rx, ry, resolution) { if (typeof ry !== "number" && typeof rx === "number") { ry = rx; } const amount = resolution ? 
Math.max(resolution, 2) : 4; const points = []; for (let i = 0; i < amount; i++) { points.push(new Anchor()); } super(points, true, true, true); for (let prop in proto15) { Object.defineProperty(this, prop, proto15[prop]); } if (typeof rx === "number") { this.width = rx * 2; } if (typeof ry === "number") { this.height = ry * 2; } this._update(); if (typeof x === "number") { this.translation.x = x; } if (typeof y === "number") { this.translation.y = y; } } _update() { if (this._flagVertices || this._flagWidth || this._flagHeight) { let length = this.vertices.length; if (!this._closed && length > 2) { length -= 1; } const c = 4 / 3 * Math.tan(Math.PI / (this.vertices.length * 2)); const radiusX = this._width / 2; const radiusY = this._height / 2; for (let i = 0; i < this.vertices.length; i++) { const pct = i / length; const theta = pct * TWO_PI; const x = radiusX * cos4(theta); const y = radiusY * sin4(theta); const lx = radiusX * c * cos4(theta - HALF_PI); const ly = radiusY * c * sin4(theta - HALF_PI); const rx = radiusX * c * cos4(theta + HALF_PI); const ry = radiusY * c * sin4(theta + HALF_PI); const v = this.vertices[i]; v.command = i === 0 ? 
Commands.move : Commands.curve; v.set(x, y); v.controls.left.set(lx, ly); v.controls.right.set(rx, ry); } } super._update.call(this); return this; } flagReset() { this._flagWidth = this._flagHeight = false; super.flagReset.call(this); return this; } clone(parent) { const rx = this.width / 2; const ry = this.height / 2; const resolution = this.vertices.length; const clone = new _Ellipse(0, 0, rx, ry, resolution); clone.translation.copy(this.translation); clone.rotation = this.rotation; clone.scale = this.scale; clone.skewX = this.skewX; clone.skewY = this.skewY; if (this.matrix.manual) { clone.matrix.copy(this.matrix); } for (let i = 0; i < Path.Properties.length; i++) { const k = Path.Properties[i]; clone[k] = this[k]; } if (parent) { parent.add(clone); } return clone; } toObject() { const object = super.toObject.call(this); for (let i = 0; i < _Ellipse.Properties.length; i++) { const k = _Ellipse.Properties[i]; object[k] = this[k]; } return object; } }; var Ellipse = _Ellipse; __publicField(Ellipse, "Properties", ["width", "height"]); var proto15 = { width: { enumerable: true, get: function() { return this._width; }, set: function(v) { this._width = v; this._flagWidth = true; } }, height: { enumerable: true, get: function() { return this._height; }, set: function(v) { this._height = v; this._flagHeight = true; } } }; // src/shapes/line.js var Line = class extends Path { constructor(x1, y1, x2, y2) { const points = [ new Anchor(x1, y1), new Anchor(x2, y2) ]; super(points); for (let prop in proto16) { Object.defineProperty(this, prop, proto16[prop]); } this.vertices[0].command = Commands.move; this.vertices[1].command = Commands.line; this.automatic = false; } }; var proto16 = { left: { enumerable: true, get: function() { return this.vertices[0]; }, set: function(v) { if (_.isObject(v)) { this.vertices.splice(0, 1, v); } else { const error = new TwoError("Two.Line.x argument is not an object."); console.warn(error.name, error.message); } } }, right: { enumerable: 
true, get: function() { return this.vertices[1]; }, set: function(v) { if (_.isObject(v)) { this.vertices.splice(1, 1, v); } else { const error = new TwoError("Two.Line.y argument is not an object."); console.warn(error.name, error.message); } } } }; // src/shapes/rounded-rectangle.js var _RoundedRectangle = class extends Path { _flagWidth = false; _flagHeight = false; _flagRadius = false; _width = 0; _height = 0; _radius = 12; constructor(x, y, width, height, radius) { if (typeof radius === "undefined" && typeof width === "number" && typeof height === "number") { radius = Math.floor(Math.min(width, height) / 12); } const points = []; for (let i = 0; i < 10; i++) { points.push( new Anchor( 0, 0, 0, 0, 0, 0, i === 0 ? Commands.move : Commands.curve ) ); } super(points); for (let prop in proto17) { Object.defineProperty(this, prop, proto17[prop]); } this.closed = true; this.automatic = false; this._renderer.flagRadius = FlagRadius.bind(this); if (typeof width === "number") { this.width = width; } if (typeof height === "number") { this.height = height; } if (typeof radius === "number") { this.radius = radius; } this._update(); if (typeof x === "number") { this.translation.x = x; } if (typeof y === "number") { this.translation.y = y; } } _update() { if (this._flagVertices || this._flagWidth || this._flagHeight || this._flagRadius) { const width = this._width; const height = this._height; let rx, ry; if (this._radius instanceof Vector) { rx = this._radius.x; ry = this._radius.y; } else { rx = this._radius; ry = this._radius; } let v; let w = width / 2; let h = height / 2; v = this.vertices[0]; v.x = -(w - rx); v.y = -h; v = this.vertices[1]; v.x = w - rx; v.y = -h; v.controls.left.clear(); v.controls.right.x = rx; v.controls.right.y = 0; v = this.vertices[2]; v.x = w; v.y = -(h - ry); v.controls.right.clear(); v.controls.left.clear(); v = this.vertices[3]; v.x = w; v.y = h - ry; v.controls.left.clear(); v.controls.right.x = 0; v.controls.right.y = ry; v = 
this.vertices[4]; v.x = w - rx; v.y = h; v.controls.right.clear(); v.controls.left.clear(); v = this.vertices[5]; v.x = -(w - rx); v.y = h; v.controls.left.clear(); v.controls.right.x = -rx; v.controls.right.y = 0; v = this.vertices[6]; v.x = -w; v.y = h - ry; v.controls.left.clear(); v.controls.right.clear(); v = this.vertices[7]; v.x = -w; v.y = -(h - ry); v.controls.left.clear(); v.controls.right.x = 0; v.controls.right.y = -ry; v = this.vertices[8]; v.x = -(w - rx); v.y = -h; v.controls.left.clear(); v.controls.right.clear(); v = this.vertices[9]; v.copy(this.vertices[8]); } super._update.call(this); return this; } flagReset() { this._flagWidth = this._flagHeight = this._flagRadius = false; super.flagReset.call(this); return this; } clone(parent) { const width = this.width; const height = this.height; const radius = this.radius; const clone = new _RoundedRectangle(0, 0, width, height, radius); clone.translation.copy(this.translation); clone.rotation = this.rotation; clone.scale = this.scale; clone.skewX = this.skewX; clone.skewY = this.skewY; if (this.matrix.manual) { clone.matrix.copy(this.matrix); } for (let i = 0; i < Path.Properties.length; i++) { const k = Path.Properties[i]; clone[k] = this[k]; } if (parent) { parent.add(clone); } return clone; } toObject() { const object = super.toObject.call(this); for (let i = 0; i < _RoundedRectangle.Properties.length; i++) { const k = _RoundedRectangle.Properties[i]; object[k] = this[k]; } object.radius = typeof this.radius === "number" ? 
this.radius : this.radius.toObject(); return object; } }; var RoundedRectangle = _RoundedRectangle; __publicField(RoundedRectangle, "Properties", ["width", "height", "radius"]); var proto17 = { width: { enumerable: true, get: function() { return this._width; }, set: function(v) { this._width = v; this._flagWidth = true; } }, height: { enumerable: true, get: function() { return this._height; }, set: function(v) { this._height = v; this._flagHeight = true; } }, radius: { enumerable: true, get: function() { return this._radius; }, set: function(v) { if (this._radius instanceof Vector) { this._radius.unbind(Events.Types.change, this._renderer.flagRadius); } this._radius = v; if (this._radius instanceof Vector) { this._radius.bind(Events.Types.change, this._renderer.flagRadius); } this._flagRadius = true; } } }; function FlagRadius() { this._flagRadius = true; } // src/text.js var canvas2; var min4 = Math.min; var max4 = Math.max; if (root.document) { canvas2 = document.createElement("canvas"); } var _Text = class extends Shape { _flagValue = true; _flagFamily = true; _flagSize = true; _flagLeading = true; _flagAlignment = true; _flagBaseline = true; _flagStyle = true; _flagWeight = true; _flagDecoration = true; _flagFill = true; _flagStroke = true; _flagLinewidth = true; _flagOpacity = true; _flagVisible = true; _flagMask = false; _flagClip = false; _value = ""; _family = "sans-serif"; _size = 13; _leading = 17; _alignment = "center"; _baseline = "middle"; _style = "normal"; _weight = 500; _decoration = "none"; _fill = "#000"; _stroke = "none"; _linewidth = 1; _opacity = 1; _visible = true; _mask = null; _clip = false; _dashes = null; constructor(message, x, y, styles) { super(); for (let prop in proto18) { Object.defineProperty(this, prop, proto18[prop]); } this._renderer.type = "text"; this._renderer.flagFill = FlagFill2.bind(this); this._renderer.flagStroke = FlagStroke2.bind(this); this.value = message; if (typeof x === "number") { this.translation.x = x; } if 
(typeof y === "number") { this.translation.y = y; } this.dashes = []; this.dashes.offset = 0; if (!_.isObject(styles)) { return this; } for (let i = 0; i < _Text.Properties.length; i++) { const property = _Text.Properties[i]; if (property in styles) { this[property] = styles[property]; } } } static Measure(text) { if (canvas2) { const ctx = canvas2.getContext("2d"); ctx.font = [ text._style, text._weight, `${text._size}px/${text._leading}px`, text._family ].join(" "); const metrics = ctx.measureText(text.value, 0, 0); const height = metrics.actualBoundingBoxDescent + metrics.actualBoundingBoxAscent; return { width: metrics.width, height }; } else { const width = this.value.length * this.size * _Text.Ratio; const height = this.leading; console.warn("Two.Text: unable to accurately measure text, so using an approximation."); return { width, height }; } } clone(parent) { const clone = new _Text(this.value); clone.translation.copy(this.translation); clone.rotation = this.rotation; clone.scale = this.scale; for (let i = 0; i < _Text.Properties.length; i++) { const prop = _Text.Properties[i]; clone[prop] = this[prop]; } if (this.matrix.manual) { clone.matrix.copy(this.matrix); } if (parent) { parent.add(clone); } return clone._update(); } toObject() { const result = { translation: this.translation.toObject(), rotation: this.rotation, scale: this.scale }; if (this.matrix.manual) { result.matrix = this.matrix.toObject(); } for (let i = 0; i < _Text.Properties.length; i++) { const prop = _Text.Properties[i]; result[prop] = this[prop]; } return result; } noFill() { this.fill = "none"; return this; } noStroke() { this.stroke = "none"; this.linewidth = 0; return this; } getBoundingClientRect(shallow) { let matrix; let left, right, top, bottom; this._update(true); matrix = shallow ? 
this.matrix : this.worldMatrix; const { width, height } = _Text.Measure(this); const border = (this._linewidth || 0) / 2; switch (this.alignment) { case "left": left = -border; right = width + border; break; case "right": left = -(width + border); right = border; break; default: left = -(width / 2 + border); right = width / 2 + border; } switch (this.baseline) { case "middle": top = -(height / 2 + border); bottom = height / 2 + border; break; default: top = -(height + border); bottom = border; } const [ax, ay] = matrix.multiply(left, top); const [bx, by] = matrix.multiply(left, bottom); const [cx, cy] = matrix.multiply(right, top); const [dx, dy] = matrix.multiply(right, bottom); top = min4(ay, by, cy, dy); left = min4(ax, bx, cx, dx); right = max4(ax, bx, cx, dx); bottom = max4(ay, by, cy, dy); return { top, left, right, bottom, width: right - left, height: bottom - top }; } flagReset() { super.flagReset.call(this); this._flagValue = this._flagFamily = this._flagSize = this._flagLeading = this._flagAlignment = this._flagFill = this._flagStroke = this._flagLinewidth = this._flagOpacity = this._flagVisible = this._flagClip = this._flagDecoration = this._flagClassName = this._flagBaseline = this._flagWeight = this._flagStyle = false; return this; } }; var Text = _Text; __publicField(Text, "Ratio", 0.6); __publicField(Text, "Properties", [ "value", "family", "size", "leading", "alignment", "linewidth", "style", "weight", "decoration", "baseline", "opacity", "visible", "fill", "stroke" ]); var proto18 = { value: { enumerable: true, get: function() { return this._value; }, set: function(v) { this._value = v; this._flagValue = true; } }, family: { enumerable: true, get: function() { return this._family; }, set: function(v) { this._family = v; this._flagFamily = true; } }, size: { enumerable: true, get: function() { return this._size; }, set: function(v) { this._size = v; this._flagSize = true; } }, leading: { enumerable: true, get: function() { return this._leading; }, 
set: function(v) { this._leading = v; this._flagLeading = true; } }, alignment: { enumerable: true, get: function() { return this._alignment; }, set: function(v) { this._alignment = v; this._flagAlignment = true; } }, linewidth: { enumerable: true, get: function() { return this._linewidth; }, set: function(v) { this._linewidth = v; this._flagLinewidth = true; } }, style: { enumerable: true, get: function() { return this._style; }, set: function(v) { this._style = v; this._flagStyle = true; } }, weight: { enumerable: true, get: function() { return this._weight; }, set: function(v) { this._weight = v; this._flagWeight = true; } }, decoration: { enumerable: true, get: function() { return this._decoration; }, set: function(v) { this._decoration = v; this._flagDecoration = true; } }, baseline: { enumerable: true, get: function() { return this._baseline; }, set: function(v) { this._baseline = v; this._flagBaseline = true; } }, opacity: { enumerable: true, get: function() { return this._opacity; }, set: function(v) { this._opacity = v; this._flagOpacity = true; } }, visible: { enumerable: true, get: function() { return this._visible; }, set: function(v) { this._visible = v; this._flagVisible = true; } }, fill: { enumerable: true, get: function() { return this._fill; }, set: function(f) { if (this._fill instanceof Gradient || this._fill instanceof LinearGradient || this._fill instanceof RadialGradient || this._fill instanceof Texture) { this._fill.unbind(Events.Types.change, this._renderer.flagFill); } this._fill = f; this._flagFill = true; if (this._fill instanceof Gradient || this._fill instanceof LinearGradient || this._fill instanceof RadialGradient || this._fill instanceof Texture) { this._fill.bind(Events.Types.change, this._renderer.flagFill); } } }, stroke: { enumerable: true, get: function() { return this._stroke; }, set: function(f) { if (this._stroke instanceof Gradient || this._stroke instanceof LinearGradient || this._stroke instanceof RadialGradient || 
this._stroke instanceof Texture) { this._stroke.unbind(Events.Types.change, this._renderer.flagStroke); } this._stroke = f; this._flagStroke = true; if (this._stroke instanceof Gradient || this._stroke instanceof LinearGradient || this._stroke instanceof RadialGradient || this._stroke instanceof Texture) { this._stroke.bind(Events.Types.change, this._renderer.flagStroke); } } }, mask: { enumerable: true, get: function() { return this._mask; }, set: function(v) { this._mask = v; this._flagMask = true; if (_.isObject(v) && !v.clip) { v.clip = true; } } }, clip: { enumerable: true, get: function() { return this._clip; }, set: function(v) { this._clip = v; this._flagClip = true; } }, dashes: { enumerable: true, get: function() { return this._dashes; }, set: function(v) { if (typeof v.offset !== "number") { v.offset = this.dashes && this._dashes.offset || 0; } this._dashes = v; } } }; function FlagFill2() { this._flagFill = true; } function FlagStroke2() { this._flagStroke = true; } // src/utils/interpret-svg.js var regex2 = { path: /[+-]?(?:\d*\.\d+|\d+)(?:[eE][+-]\d+)?/g, cssBackgroundImage: /url\(['"]?#([\w\d-_]*)['"]?\)/i, unitSuffix: /[a-zA-Z%]*/i }; var alignments = { start: "left", middle: "center", end: "right" }; var reservedAttributesToRemove = ["id", "class", "transform", "xmlns", "viewBox"]; var overwriteAttrs = ["x", "y", "width", "height", "href", "xlink:href"]; function getAlignment(anchor2) { return alignments[anchor2]; } function getBaseline(node) { const a = node.getAttribute("dominant-baseline"); const b = node.getAttribute("alignment-baseline"); return a || b; } function getTagName(tag) { return tag.replace(/svg:/ig, "").toLowerCase(); } function applyTransformsToVector(transforms, vector2) { vector2.x += transforms.translateX; vector2.y += transforms.translateY; vector2.x *= transforms.scaleX; vector2.y *= transforms.scaleY; if (transforms.rotation !== 0) { const l = vector2.length(); vector2.x = l * Math.cos(transforms.rotation); vector2.y = l * 
Math.sin(transforms.rotation); } } function extractCSSText(text, styles) { if (!styles) { styles = {}; } const commands = text.split(";"); for (let i = 0; i < commands.length; i++) { const command = commands[i].split(":"); const name = command[0]; const value = command[1]; if (typeof name === "undefined" || typeof value === "undefined") { continue; } styles[name] = value.replace(/\s/, ""); } return styles; } function getSvgStyles(node) { const styles = {}; const attributes = getSvgAttributes(node); const length = Math.max(attributes.length, node.style.length); for (let i = 0; i < length; i++) { const command = node.style[i]; const attribute = attributes[i]; if (command) { styles[command] = node.style[command]; } if (attribute) { styles[attribute] = node.getAttribute(attribute); } } return styles; } function getSvgAttributes(node) { const attributes = node.getAttributeNames(); for (let i = 0; i < reservedAttributesToRemove.length; i++) { const keyword = reservedAttributesToRemove[i]; const index = Array.prototype.indexOf.call(attributes, keyword); if (index >= 0) { attributes.splice(index, 1); } } return attributes; } function applySvgViewBox(node, value) { const elements = value.split(/[\s,]/); const x = -parseFloat(elements[0]); const y = -parseFloat(elements[1]); const width = parseFloat(elements[2]); const height = parseFloat(elements[3]); if (x && y) { for (let i = 0; i < node.children.length; i++) { const child = node.children[i]; if ("translation" in child) { child.translation.add(x, y); } else if ("x" in child) { child.x = x; } else if ("y" in child) { child.y = y; } } } const xExists = typeof node.x === "number"; const yExists = typeof node.y === "number"; const widthExists = typeof node.width === "number"; const heightExists = typeof node.height === "number"; if (xExists) { node.translation.x += node.x; } if (yExists) { node.translation.y += node.y; } if (widthExists || heightExists) { node.scale = new Vector(1, 1); } if (widthExists) { node.scale.x = 
node.width / width; } if (heightExists) { node.scale.y = node.height / height; } node.mask = new Rectangle(0, 0, width, height); node.mask.origin.set(-width / 2, -height / 2); return node; } function applySvgAttributes(node, elem, parentStyles) { const styles = {}, attributes = {}, extracted = {}; let i, m, key, value, prop, attr; let transforms, x, y; let id, scene, ref, tagName; let ca, cb, cc, error; if (node === null) { return styles; } if (root.getComputedStyle) { const computedStyles = root.getComputedStyle(node); i = computedStyles.length; while (i--) { key = computedStyles[i]; value = computedStyles[key]; if (typeof value !== "undefined") { styles[key] = value; } } } for (i = 0; i < node.attributes.length; i++) { attr = node.attributes[i]; if (/style/i.test(attr.nodeName)) { extractCSSText(attr.value, extracted); } else { attributes[attr.nodeName] = attr.value; } } if (typeof styles.opacity !== "undefined") { styles["stroke-opacity"] = styles.opacity; styles["fill-opacity"] = styles.opacity; delete styles.opacity; } if (parentStyles) { _.defaults(styles, parentStyles); } _.extend(styles, extracted, attributes); styles.visible = !(typeof styles.display === "undefined" && /none/i.test(styles.display)) || typeof styles.visibility === "undefined" && /hidden/i.test(styles.visibility); for (key in styles) { value = styles[key]; switch (key) { case "gradientTransform": if (/none/i.test(value)) break; m = node.gradientTransform && node.gradientTransform.baseVal && node.gradientTransform.baseVal.length > 0 ? node.gradientTransform.baseVal[0].matrix : node.getCTM ? 
node.getCTM() : null; if (m === null) break; transforms = decomposeMatrix(m); switch (elem._renderer.type) { case "linear-gradient": applyTransformsToVector(transforms, elem.left); applyTransformsToVector(transforms, elem.right); break; case "radial-gradient": elem.center.x += transforms.translateX; elem.center.y += transforms.translateY; elem.focal.x += transforms.translateX; elem.focal.y += transforms.translateY; elem.radius *= Math.max(transforms.scaleX, transforms.scaleY); break; } break; case "transform": if (/none/i.test(value)) break; m = node.transform && node.transform.baseVal && node.transform.baseVal.length > 0 ? node.transform.baseVal[0].matrix : node.getCTM ? node.getCTM() : null; if (m === null) break; if (Constants.AutoCalculateImportedMatrices) { transforms = decomposeMatrix(m); elem.translation.set(transforms.translateX, transforms.translateY); elem.rotation = Math.PI * (transforms.rotation / 180); elem.scale = new Vector(transforms.scaleX, transforms.scaleY); x = parseFloat((styles.x + "").replace("px")); y = parseFloat((styles.y + "").replace("px")); if (x) { elem.translation.x = x; } if (y) { elem.translation.y = y; } } else { m = node.getCTM(); elem._matrix.manual = true; elem._matrix.set(m.a, m.b, m.c, m.d, m.e, m.f); } break; case "visible": if (elem instanceof Group) { elem._visible = value; break; } elem.visible = value; break; case "stroke-linecap": if (elem instanceof Group) { elem._cap = value; break; } elem.cap = value; break; case "stroke-linejoin": if (elem instanceof Group) { elem._join = value; break; } elem.join = value; break; case "stroke-miterlimit": if (elem instanceof Group) { elem._miter = value; break; } elem.miter = value; break; case "stroke-width": if (elem instanceof Group) { elem._linewidth = parseFloat(value); break; } elem.linewidth = parseFloat(value); break; case "opacity": case "stroke-opacity": case "fill-opacity": if (elem instanceof Group) { elem._opacity = parseFloat(value); break; } elem.opacity = 
parseFloat(value); break; case "clip-path": if (regex2.cssBackgroundImage.test(value)) { id = value.replace(regex2.cssBackgroundImage, "$1"); if (read.defs.current && read.defs.current.contains(id)) { ref = read.defs.current.get(id); if (ref && ref.childNodes.length > 0) { ref = ref.childNodes[0]; tagName = getTagName(ref.nodeName); elem.mask = read[tagName].call(this, ref, {}); switch (elem._renderer.type) { case "text": case "path": elem.position.add(elem.mask.position); elem.mask.position.clear(); break; } } } } break; case "fill": case "stroke": prop = (elem instanceof Group ? "_" : "") + key; if (regex2.cssBackgroundImage.test(value)) { id = value.replace(regex2.cssBackgroundImage, "$1"); if (read.defs.current && read.defs.current.contains(id)) { ref = read.defs.current.get(id); if (!ref.object) { tagName = getTagName(ref.nodeName); ref.object = read[tagName].call(this, ref, {}); } ref = ref.object; } else { scene = getScene(this); ref = scene.getById(id); } elem[prop] = ref; } else { elem[prop] = value; } break; case "id": elem.id = value; break; case "class": case "className": elem.classList = value.split(" "); elem._flagClassName = true; break; case "x": case "y": ca = elem instanceof Gradient; cb = elem instanceof LinearGradient; cc = elem instanceof RadialGradient; if (ca || cb || cc) { break; } if (value.match("[a-z%]$") && !value.endsWith("px")) { error = new TwoError( "only pixel values are supported with the " + key + " attribute." 
); console.warn(error.name, error.message); } elem.translation[key] = parseFloat(value); break; case "font-family": if (elem instanceof Text) { elem.family = value; } break; case "font-size": if (elem instanceof Text) { elem.size = value; } break; case "font-weight": if (elem instanceof Text) { elem.weight = value; } break; case "font-style": if (elem instanceof Text) { elem.style = value; } break; case "text-decoration": if (elem instanceof Text) { elem.decoration = value; } break; case "line-height": if (elem instanceof Text) { elem.leading = value; } break; } } if (Object.keys(node.dataset).length) elem.dataset = node.dataset; return styles; } function updateDefsCache(node, defsCache) { for (let i = 0, l = node.childNodes.length; i < l; i++) { const n = node.childNodes[i]; if (!n.id) continue; const tagName = getTagName(node.nodeName); if (tagName === "#text") continue; defsCache.add(n.id, n); } } function getScene(node) { while (node.parent) { node = node.parent; } return node.scene; } var read = { svg: function(node) { const defs = read.defs.current = new Registry(); const elements = node.getElementsByTagName("defs"); for (let i = 0; i < elements.length; i++) { updateDefsCache(elements[i], defs); } const svg2 = read.g.call(this, node); const viewBox = node.getAttribute("viewBox"); const x = node.getAttribute("x"); const y = node.getAttribute("y"); const width = node.getAttribute("width"); const height = node.getAttribute("height"); svg2.defs = defs; const viewBoxExists = viewBox !== null; const xExists = x !== null; const yExists = y !== null; const widthExists = width !== null; const heightExists = height !== null; if (xExists) { svg2.x = parseFloat(x.replace(regex2.unitSuffix, "")); } if (yExists) { svg2.y = parseFloat(y.replace(regex2.unitSuffix, "")); } if (widthExists) { svg2.width = parseFloat(width.replace(regex2.unitSuffix, "")); } if (heightExists) { svg2.height = parseFloat(height.replace(regex2.unitSuffix, "")); } if (viewBoxExists) { 
applySvgViewBox(svg2, viewBox); } delete read.defs.current; return svg2; }, defs: function(node) { return null; }, use: function(node, styles) { let error; const href = node.getAttribute("href") || node.getAttribute("xlink:href"); if (!href) { error = new TwoError("encountered with no href."); console.warn(error.name, error.message); return null; } const id = href.slice(1); if (!read.defs.current.contains(id)) { error = new TwoError( "unable to find element for reference " + href + "." ); console.warn(error.name, error.message); return null; } const template = read.defs.current.get(id); const fullNode = template.cloneNode(true); for (let i = 0; i < node.attributes.length; i++) { const attr = node.attributes[i]; const ca = overwriteAttrs.includes(attr.nodeName); const cb = !fullNode.hasAttribute(attr.nodeName); if (ca || cb) { fullNode.setAttribute(attr.nodeName, attr.value); } } const tagName = getTagName(fullNode.nodeName); return read[tagName].call(this, fullNode, styles); }, g: function(node, parentStyles) { const group = new Group(); applySvgAttributes.call(this, node, group, parentStyles); this.add(group); const styles = getSvgStyles.call(this, node); for (let i = 0, l = node.childNodes.length; i < l; i++) { const n = node.childNodes[i]; const tag = n.nodeName; if (!tag) return; const tagName = getTagName(tag); if (tagName in read) { const o = read[tagName].call(group, n, styles); if (!!o && !o.parent) { group.add(o); } } } return group; }, polygon: function(node, parentStyles) { let points; if (typeof node === "string") { points = node; } else { points = node.getAttribute("points"); } const verts = []; points.replace(/(-?[\d.eE-]+)[,|\s](-?[\d.eE-]+)/g, function(match, p1, p2) { verts.push(new Anchor(parseFloat(p1), parseFloat(p2))); }); const poly = new Path(verts, true).noStroke(); poly.fill = "black"; applySvgAttributes.call(this, node, poly, parentStyles); return poly; }, polyline: function(node, parentStyles) { const poly = read.polygon.call(this, node, 
parentStyles); poly.closed = false; return poly; }, path: function(node, parentStyles) { let path; if (typeof node === "string") { path = node; node = null; } else { path = node.getAttribute("d"); } let points = []; let closed2 = false, relative = false; if (path) { let coord = new Anchor(); let control, coords; let commands = path.match(/[a-df-z][^a-df-z]*/ig); const last = commands.length - 1; _.each(commands.slice(0), function(command, i) { const items = command.slice(1).trim().match(regex2.path); const type = command[0]; const lower = type.toLowerCase(); let bin, j, l, ct, times; const result = []; if (i === 0) { commands = []; } switch (lower) { case "h": case "v": if (items.length > 1) { bin = 1; } break; case "m": case "l": case "t": if (items.length > 2) { bin = 2; } break; case "s": case "q": if (items.length > 4) { bin = 4; } break; case "c": if (items.length > 6) { bin = 6; } break; case "a": if (items.length > 7) { bin = 7; } break; } if (bin) { for (j = 0, l = items.length, times = 0; j < l; j += bin) { ct = type; if (times > 0) { switch (type) { case "m": ct = "l"; break; case "M": ct = "L"; break; } } result.push(ct + items.slice(j, j + bin).join(" ")); times++; } commands = Array.prototype.concat.apply(commands, result); } else { commands.push(command); } }); _.each(commands, function(command, i) { let result, x, y; const type = command[0]; const lower = type.toLowerCase(); coords = command.slice(1).trim().match(regex2.path); relative = type === lower; let x1, y1, x2, y2, x3, y3, x4, y4, reflection; let a, b; let anchor2, rx, ry, xAxisRotation, largeArcFlag, sweepFlag; switch (lower) { case "z": if (i >= last) { closed2 = true; } else { x = coord.x; y = coord.y; result = new Anchor( x, y, void 0, void 0, void 0, void 0, Commands.close ); for (let j = points.length - 1; j >= 0; j--) { const point = points[j]; if (/m/i.test(point.command)) { coord = point; break; } } } break; case "m": case "l": control = void 0; x = parseFloat(coords[0]); y = 
parseFloat(coords[1]); result = new Anchor( x, y, void 0, void 0, void 0, void 0, /m/i.test(lower) ? Commands.move : Commands.line ); if (relative) { result.addSelf(coord); } coord = result; break; case "h": case "v": a = /h/i.test(lower) ? "x" : "y"; b = /x/i.test(a) ? "y" : "x"; result = new Anchor( void 0, void 0, void 0, void 0, void 0, void 0, Commands.line ); result[a] = parseFloat(coords[0]); result[b] = coord[b]; if (relative) { result[a] += coord[a]; } coord = result; break; case "c": case "s": x1 = coord.x; y1 = coord.y; if (!control) { control = new Vector(); } if (/c/i.test(lower)) { x2 = parseFloat(coords[0]); y2 = parseFloat(coords[1]); x3 = parseFloat(coords[2]); y3 = parseFloat(coords[3]); x4 = parseFloat(coords[4]); y4 = parseFloat(coords[5]); } else { reflection = getReflection(coord, control, relative); x2 = reflection.x; y2 = reflection.y; x3 = parseFloat(coords[0]); y3 = parseFloat(coords[1]); x4 = parseFloat(coords[2]); y4 = parseFloat(coords[3]); } if (relative) { x2 += x1; y2 += y1; x3 += x1; y3 += y1; x4 += x1; y4 += y1; } coord.controls.right.set(x2 - coord.x, y2 - coord.y); result = new Anchor( x4, y4, x3 - x4, y3 - y4, void 0, void 0, Commands.curve ); coord = result; control = result.controls.left; break; case "t": case "q": x1 = coord.x; y1 = coord.y; if (!control) { control = new Vector(); } if (/q/i.test(lower)) { x2 = parseFloat(coords[0]); y2 = parseFloat(coords[1]); x3 = parseFloat(coords[0]); y3 = parseFloat(coords[1]); x4 = parseFloat(coords[2]); y4 = parseFloat(coords[3]); } else { reflection = getReflection(coord, control, relative); x2 = reflection.x; y2 = reflection.y; x3 = reflection.x; y3 = reflection.y; x4 = parseFloat(coords[0]); y4 = parseFloat(coords[1]); } if (relative) { x2 += x1; y2 += y1; x3 += x1; y3 += y1; x4 += x1; y4 += y1; } coord.controls.right.set( (x2 - coord.x) * 0.33, (y2 - coord.y) * 0.33 ); result = new Anchor( x4, y4, x3 - x4, y3 - y4, void 0, void 0, Commands.curve ); coord = result; control = 
result.controls.left; break; case "a": x1 = coord.x; y1 = coord.y; rx = parseFloat(coords[0]); ry = parseFloat(coords[1]); xAxisRotation = parseFloat(coords[2]); largeArcFlag = parseFloat(coords[3]); sweepFlag = parseFloat(coords[4]); x4 = parseFloat(coords[5]); y4 = parseFloat(coords[6]); if (relative) { x4 += x1; y4 += y1; } anchor2 = new Anchor(x4, y4); anchor2.command = Commands.arc; anchor2.rx = rx; anchor2.ry = ry; anchor2.xAxisRotation = xAxisRotation; anchor2.largeArcFlag = largeArcFlag; anchor2.sweepFlag = sweepFlag; result = anchor2; coord = anchor2; control = void 0; break; } if (result) { if (Array.isArray(result)) { points = points.concat(result); } else { points.push(result); } } }); } path = new Path(points, closed2, void 0, true).noStroke(); path.fill = "black"; const rect = path.getBoundingClientRect(true); rect.centroid = { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 }; _.each(path.vertices, function(v) { v.subSelf(rect.centroid); }); applySvgAttributes.call(this, node, path, parentStyles); path.translation.addSelf(rect.centroid); return path; }, circle: function(node, parentStyles) { const x = parseFloat(node.getAttribute("cx")); const y = parseFloat(node.getAttribute("cy")); const r = parseFloat(node.getAttribute("r")); const circle = new Circle(0, 0, r).noStroke(); circle.fill = "black"; applySvgAttributes.call(this, node, circle, parentStyles); circle.translation.x = x; circle.translation.y = y; return circle; }, ellipse: function(node, parentStyles) { const x = parseFloat(node.getAttribute("cx")); const y = parseFloat(node.getAttribute("cy")); const width = parseFloat(node.getAttribute("rx")); const height = parseFloat(node.getAttribute("ry")); const ellipse = new Ellipse(0, 0, width, height).noStroke(); ellipse.fill = "black"; applySvgAttributes.call(this, node, ellipse, parentStyles); ellipse.translation.x = x; ellipse.translation.y = y; return ellipse; }, rect: function(node, parentStyles) { const rx = 
parseFloat(node.getAttribute("rx")); const ry = parseFloat(node.getAttribute("ry")); if (!_.isNaN(rx) || !_.isNaN(ry)) { return read["rounded-rect"](node); } const width = parseFloat(node.getAttribute("width")); const height = parseFloat(node.getAttribute("height")); const w2 = width / 2; const h2 = height / 2; const rect = new Rectangle(0, 0, width, height).noStroke(); rect.fill = "black"; applySvgAttributes.call(this, node, rect, parentStyles); rect.translation.x += w2; rect.translation.y += h2; return rect; }, "rounded-rect": function(node, parentStyles) { const rx = parseFloat(node.getAttribute("rx")) || 0; const ry = parseFloat(node.getAttribute("ry")) || 0; const width = parseFloat(node.getAttribute("width")); const height = parseFloat(node.getAttribute("height")); const w2 = width / 2; const h2 = height / 2; const radius = new Vector(rx, ry); const rect = new RoundedRectangle(0, 0, width, height, radius).noStroke(); rect.fill = "black"; applySvgAttributes.call(this, node, rect, parentStyles); rect.translation.x += w2; rect.translation.y += h2; return rect; }, line: function(node, parentStyles) { const x1 = parseFloat(node.getAttribute("x1")); const y1 = parseFloat(node.getAttribute("y1")); const x2 = parseFloat(node.getAttribute("x2")); const y2 = parseFloat(node.getAttribute("y2")); const line = new Line(x1, y1, x2, y2).noFill(); applySvgAttributes.call(this, node, line, parentStyles); return line; }, lineargradient: function(node, parentStyles) { let units = node.getAttribute("gradientUnits"); let spread = node.getAttribute("spreadMethod"); if (!units) { units = "objectBoundingBox"; } if (!spread) { spread = "pad"; } let x1 = parseFloat(node.getAttribute("x1") || 0); let y1 = parseFloat(node.getAttribute("y1") || 0); let x2 = parseFloat(node.getAttribute("x2") || 0); let y2 = parseFloat(node.getAttribute("y2") || 0); const ox = (x2 + x1) / 2; const oy = (y2 + y1) / 2; if (/userSpaceOnUse/i.test(units)) { x1 -= ox; y1 -= oy; x2 -= ox; y2 -= oy; } const 
stops = []; for (let i = 0; i < node.children.length; i++) { const child = node.children[i]; let offset = child.getAttribute("offset"); if (/%/ig.test(offset)) { offset = parseFloat(offset.replace(/%/ig, "")) / 100; } offset = parseFloat(offset); let color = child.getAttribute("stop-color"); let opacity = child.getAttribute("stop-opacity"); let style = child.getAttribute("style"); let matches; if (color === null) { matches = style ? style.match(/stop-color:\s?([#a-fA-F0-9]*)/) : false; color = matches && matches.length > 1 ? matches[1] : void 0; } if (opacity === null) { matches = style ? style.match(/stop-opacity:\s?([0-9.-]*)/) : false; opacity = matches && matches.length > 1 ? parseFloat(matches[1]) : 1; } else { opacity = parseFloat(opacity); } stops.push(new Stop(offset, color, opacity)); } const gradient = new LinearGradient(x1, y1, x2, y2, stops); gradient.spread = spread; gradient.units = units; applySvgAttributes.call(this, node, gradient, parentStyles); return gradient; }, radialgradient: function(node, parentStyles) { let units = node.getAttribute("gradientUnits"); let spread = node.getAttribute("spreadMethod"); if (!units) { units = "objectBoundingBox"; } if (!spread) { spread = "pad"; } let cx = parseFloat(node.getAttribute("cx")) || 0; let cy = parseFloat(node.getAttribute("cy")) || 0; let r = parseFloat(node.getAttribute("r")); let fx = parseFloat(node.getAttribute("fx")); let fy = parseFloat(node.getAttribute("fy")); if (_.isNaN(fx)) { fx = cx; } if (_.isNaN(fy)) { fy = cy; } const ox = Math.abs(cx + fx) / 2; const oy = Math.abs(cy + fy) / 2; if (/userSpaceOnUse/i.test(units)) { cx -= ox; cy -= oy; fx -= ox; fy -= oy; } const stops = []; for (let i = 0; i < node.children.length; i++) { const child = node.children[i]; let offset = child.getAttribute("offset"); if (/%/ig.test(offset)) { offset = parseFloat(offset.replace(/%/ig, "")) / 100; } offset = parseFloat(offset); let color = child.getAttribute("stop-color"); let opacity = 
child.getAttribute("stop-opacity"); let style = child.getAttribute("style"); let matches; if (color === null) { matches = style ? style.match(/stop-color:\s?([#a-fA-F0-9]*)/) : false; color = matches && matches.length > 1 ? matches[1] : void 0; } if (opacity === null) { matches = style ? style.match(/stop-opacity:\s?([0-9.-]*)/) : false; opacity = matches && matches.length > 1 ? parseFloat(matches[1]) : 1; } else { opacity = parseFloat(opacity); } stops.push(new Stop(offset, color, opacity)); } const gradient = new RadialGradient(cx, cy, r, stops, fx, fy); gradient.spread = spread; gradient.units = units; applySvgAttributes.call(this, node, gradient, parentStyles); return gradient; }, text: function(node, parentStyles) { const alignment = getAlignment(node.getAttribute("text-anchor")) || "left"; const baseline = getBaseline(node) || "baseline"; const message = node.textContent; const text = new Text(message); applySvgAttributes.call(this, node, text, parentStyles); text.alignment = alignment; text.baseline = baseline; return text; }, clippath: function(node, parentStyles) { if (read.defs.current && !read.defs.current.contains(node.id)) { read.defs.current.add(node.id, node); } return null; }, image: function(node, parentStyles) { let error; const href = node.getAttribute("href") || node.getAttribute("xlink:href"); if (!href) { error = new TwoError("encountered with no href."); console.warn(error.name, error.message); return null; } const x = parseFloat(node.getAttribute("x")) || 0; const y = parseFloat(node.getAttribute("y")) || 0; const width = parseFloat(node.getAttribute("width")); const height = parseFloat(node.getAttribute("height")); const sprite = new Sprite(href, x, y); if (!_.isNaN(width)) { sprite.width = width; } if (!_.isNaN(height)) { sprite.height = height; } applySvgAttributes.call(this, node, sprite, parentStyles); return sprite; } }; // src/utils/xhr.js function xhr(path, callback) { const xhr2 = new XMLHttpRequest(); xhr2.open("GET", path); 
xhr2.onreadystatechange = function() { if (xhr2.readyState === 4 && xhr2.status === 200) { callback(xhr2.responseText); } }; xhr2.send(); return xhr2; } // src/effects/image-sequence.js var _ImageSequence = class extends Rectangle { _flagTextures = false; _flagFrameRate = false; _flagIndex = false; _amount = 1; _duration = 0; _index = 0; _startTime = 0; _playing = false; _firstFrame = 0; _lastFrame = 0; _loop = true; _textures = null; _frameRate = 0; _origin = null; constructor(paths, ox, oy, frameRate) { super(ox, oy, 0, 0); for (let prop in proto19) { Object.defineProperty(this, prop, proto19[prop]); } this._renderer.flagTextures = FlagTextures.bind(this); this._renderer.bindTextures = BindTextures.bind(this); this._renderer.unbindTextures = UnbindTextures.bind(this); this.noStroke(); this.noFill(); if (Array.isArray(paths)) { this.textures = paths.map(GenerateTexture.bind(this)); } else { this.textures = [GenerateTexture(paths)]; } this.origin = new Vector(); this._update(); if (typeof frameRate === "number") { this.frameRate = frameRate; } else { this.frameRate = _ImageSequence.DefaultFrameRate; } this.index = 0; } play(firstFrame, lastFrame, onLastFrame) { this._playing = true; this._firstFrame = 0; this._lastFrame = this.amount - 1; this._startTime = _.performance.now(); if (typeof firstFrame === "number") { this._firstFrame = firstFrame; } if (typeof lastFrame === "number") { this._lastFrame = lastFrame; } if (typeof onLastFrame === "function") { this._onLastFrame = onLastFrame; } else { delete this._onLastFrame; } if (this._index !== this._firstFrame) { this._startTime -= 1e3 * Math.abs(this._index - this._firstFrame) / this._frameRate; } return this; } pause() { this._playing = false; return this; } stop() { this._playing = false; this._index = this._firstFrame; return this; } clone(parent) { const clone = new _ImageSequence( this.textures, this.translation.x, this.translation.y, this.frameRate ); clone._loop = this._loop; if (this._playing) { 
clone.play(); } if (parent) { parent.add(clone); } return clone; } toObject() { const object = super.toObject.call(this); object.textures = this.textures.map(function(texture) { return texture.toObject(); }); object.frameRate = this.frameRate; object.index = this.index; object._firstFrame = this._firstFrame; object._lastFrame = this._lastFrame; object._loop = this._loop; return object; } _update() { const effect = this._textures; let width, height, elapsed, amount, duration, texture; let index, frames; if (effect) { if (this._flagTextures) { this._amount = effect.length; } if (this._flagFrameRate) { this._duration = 1e3 * this._amount / this._frameRate; } if (this._playing && this._frameRate > 0) { amount = this._amount; if (_.isNaN(this._lastFrame)) { this._lastFrame = amount - 1; } elapsed = _.performance.now() - this._startTime; frames = this._lastFrame + 1; duration = 1e3 * (frames - this._firstFrame) / this._frameRate; if (this._loop) { elapsed = elapsed % duration; } else { elapsed = Math.min(elapsed, duration); } index = lerp(this._firstFrame, frames, elapsed / duration); index = Math.floor(index); if (index !== this._index) { this._index = index; texture = effect[this._index]; if (texture.loaded) { width = texture.image.width; height = texture.image.height; if (this.width !== width) { this.width = width; } if (this.height !== height) { this.height = height; } this.fill = texture; if (index >= this._lastFrame - 1 && this._onLastFrame) { this._onLastFrame(); } } } } else if (this._flagIndex || !(this.fill instanceof Texture)) { texture = effect[this._index]; if (texture.loaded) { width = texture.image.width; height = texture.image.height; if (this.width !== width) { this.width = width; } if (this.height !== height) { this.height = height; } } this.fill = texture; } } super._update.call(this); return this; } flagReset() { this._flagTextures = this._flagFrameRate = false; super.flagReset.call(this); return this; } }; var ImageSequence = _ImageSequence; 
// Static metadata: the serialisable property names and the default frame rate
// used by the ImageSequence constructor when none is passed.
__publicField(ImageSequence, "Properties", [
  "textures",
  "frameRate",
  "index"
]);
__publicField(ImageSequence, "DefaultFrameRate", 30);
// Accessor descriptors that the ImageSequence constructor installs on each
// instance via Object.defineProperty. Each setter stores the value and raises
// the matching dirty flag consumed by _update().
var proto19 = {
  frameRate: {
    enumerable: true,
    get: function() {
      return this._frameRate;
    },
    set: function(v) {
      this._frameRate = v;
      this._flagFrameRate = true;
    }
  },
  index: {
    enumerable: true,
    get: function() {
      return this._index;
    },
    set: function(v) {
      this._index = v;
      this._flagIndex = true;
    }
  },
  // Wraps a shallow copy of the given list in a Collection, moves the
  // insert/remove listeners from any previous collection, and binds every
  // texture's change event.
  textures: {
    enumerable: true,
    get: function() {
      return this._textures;
    },
    set: function(textures) {
      const bindTextures = this._renderer.bindTextures;
      const unbindTextures = this._renderer.unbindTextures;
      if (this._textures) {
        this._textures.unbind(Events.Types.insert, bindTextures).unbind(Events.Types.remove, unbindTextures);
      }
      this._textures = new Collection((textures || []).slice(0));
      this._textures.bind(Events.Types.insert, bindTextures).bind(Events.Types.remove, unbindTextures);
      bindTextures(this._textures);
    }
  }
};
// Renderer callback: mark the texture list as changed.
function FlagTextures() {
  this._flagTextures = true;
}
// Subscribe each texture's change event to the dirty-flag callback, then
// raise the flag once for the batch.
function BindTextures(items) {
  let i = items.length;
  while (i--) {
    items[i].bind(Events.Types.change, this._renderer.flagTextures);
  }
  this._renderer.flagTextures();
}
// Inverse of BindTextures: drop the change listeners, then raise the flag once.
function UnbindTextures(items) {
  let i = items.length;
  while (i--) {
    items[i].unbind(Events.Types.change, this._renderer.flagTextures);
  }
  this._renderer.flagTextures();
}
// Normalise one texture argument: a Texture is returned as-is, a string
// becomes a new Texture, and any other value yields undefined.
function GenerateTexture(obj) {
  if (obj instanceof Texture) {
    return obj;
  } else if (typeof obj === "string") {
    return new Texture(obj);
  }
}

// src/shapes/arc-segment.js
/**
 * Closed Path tracing an arc from startAngle to endAngle at outerRadius.
 * When innerRadius > 0 it also traces back along innerRadius (an annular
 * segment); otherwise an open arc is closed through the origin.
 *
 * Constructor arguments: x, y (translation), ir, or (inner/outer radius),
 * sa, ea (start/end angle) and res (vertex count, defaulting to
 * Constants.Resolution * 3). Non-number arguments leave the class defaults.
 */
var _ArcSegment = class extends Path {
  _flagStartAngle = false;
  _flagEndAngle = false;
  _flagInnerRadius = false;
  _flagOuterRadius = false;
  _startAngle = 0;
  _endAngle = TWO_PI;
  _innerRadius = 0;
  _outerRadius = 0;
  constructor(x, y, ir, or, sa, ea, res) {
    const amount = res || Constants.Resolution * 3;
    const points = [];
    for (let i = 0; i < amount; i++) {
      points.push(new Anchor());
    }
    super(points, true, false, true);
    for (let prop in proto20) {
      Object.defineProperty(this, prop, proto20[prop]);
    }
    if (typeof ir === "number") {
      this.innerRadius = ir;
    }
    if (typeof or === "number") {
      this.outerRadius = or;
    }
    if (typeof sa === "number") {
      this.startAngle = sa;
    }
    if (typeof ea === "number") {
      this.endAngle = ea;
    }
    this._update();
    if (typeof x === "number") {
      this.translation.x = x;
    }
    if (typeof y === "number") {
      this.translation.y = y;
    }
  }
  // Rewrites the existing vertices in place whenever the vertex list or any
  // angle/radius flag is dirty. The vertex count itself never changes here.
  _update() {
    if (this._flagVertices || this._flagStartAngle || this._flagEndAngle || this._flagInnerRadius || this._flagOuterRadius) {
      const sa = this._startAngle;
      const ea = this._endAngle;
      const ir = this._innerRadius;
      const or = this._outerRadius;
      // A full turn (start and end coincide modulo 2*PI).
      const connected = mod(sa, TWO_PI) === mod(ea, TWO_PI);
      const punctured = ir > 0;
      const vertices = this.vertices;
      // With an inner radius the vertices are shared between the two arcs.
      let length = punctured ? vertices.length / 2 : vertices.length;
      let command, id = 0;
      let i, last, pct, v, theta, step, x, y, amp;
      // Reserve the vertices needed for the closing segments below.
      if (connected) {
        length--;
      } else if (!punctured) {
        length -= 2;
      }
      // Outer arc, walking from sa to ea. The first vertex is a move and the
      // rest are curves whose handles are tangent to the circle.
      for (i = 0, last = length - 1; i < length; i++) {
        pct = i / last;
        v = vertices[id];
        theta = pct * (ea - sa) + sa;
        step = (ea - sa) / length;
        x = or * Math.cos(theta);
        y = or * Math.sin(theta);
        switch (i) {
          case 0:
            command = Commands.move;
            break;
          default:
            command = Commands.curve;
        }
        v.command = command;
        v.x = x;
        v.y = y;
        v.controls.left.clear();
        v.controls.right.clear();
        if (v.command === Commands.curve) {
          amp = or * step / Math.PI;
          v.controls.left.x = amp * Math.cos(theta - HALF_PI);
          v.controls.left.y = amp * Math.sin(theta - HALF_PI);
          v.controls.right.x = amp * Math.cos(theta + HALF_PI);
          v.controls.right.y = amp * Math.sin(theta + HALF_PI);
          if (i === 1) {
            v.controls.left.multiplyScalar(2);
          }
          if (i === last) {
            v.controls.right.multiplyScalar(2);
          }
        }
        id++;
      }
      if (punctured) {
        if (connected) {
          vertices[id].command = Commands.close;
          id++;
        } else {
          length--;
          last = length - 1;
        }
        // Inner arc, walking back from ea to sa with handles in the opposite
        // orientation.
        for (i = 0; i < length; i++) {
          pct = i / last;
          v = vertices[id];
          theta = (1 - pct) * (ea - sa) + sa;
          step = (ea - sa) / length;
          x = ir * Math.cos(theta);
          y = ir * Math.sin(theta);
          command = Commands.curve;
          if (i <= 0) {
            command = connected ? Commands.move : Commands.line;
          }
          v.command = command;
          v.x = x;
          v.y = y;
          v.controls.left.clear();
          v.controls.right.clear();
          if (v.command === Commands.curve) {
            amp = ir * step / Math.PI;
            v.controls.left.x = amp * Math.cos(theta + HALF_PI);
            v.controls.left.y = amp * Math.sin(theta + HALF_PI);
            v.controls.right.x = amp * Math.cos(theta - HALF_PI);
            v.controls.right.y = amp * Math.sin(theta - HALF_PI);
            if (i === 1) {
              v.controls.left.multiplyScalar(2);
            }
            if (i === last) {
              v.controls.right.multiplyScalar(2);
            }
          }
          id++;
        }
        vertices[id].copy(vertices[0]);
        vertices[id].command = Commands.line;
      } else if (!connected) {
        // Open wedge without an inner radius: a line to the origin, then
        // back to the first vertex.
        vertices[id].command = Commands.line;
        vertices[id].x = 0;
        vertices[id].y = 0;
        id++;
        vertices[id].copy(vertices[0]);
        vertices[id].command = Commands.line;
      }
    }
    super._update.call(this);
    return this;
  }
  flagReset() {
    super.flagReset.call(this);
    this._flagStartAngle = this._flagEndAngle = this._flagInnerRadius = this._flagOuterRadius = false;
    return this;
  }
  // Copy with the same geometry, transform and Path.Properties, optionally
  // added to `parent`.
  // NOTE(review): `clone.scale = this.scale` assigns the same scale object
  // rather than a copy when scale is a Vector; confirm that is acceptable.
  clone(parent) {
    const ir = this.innerRadius;
    const or = this.outerRadius;
    const sa = this.startAngle;
    const ea = this.endAngle;
    const resolution = this.vertices.length;
    const clone = new _ArcSegment(0, 0, ir, or, sa, ea, resolution);
    clone.translation.copy(this.translation);
    clone.rotation = this.rotation;
    clone.scale = this.scale;
    clone.skewX = this.skewX;
    clone.skewY = this.skewY;
    if (this.matrix.manual) {
      clone.matrix.copy(this.matrix);
    }
    for (let i = 0; i < Path.Properties.length; i++) {
      const k = Path.Properties[i];
      clone[k] = this[k];
    }
    if (parent) {
      parent.add(clone);
    }
    return clone;
  }
  toObject() {
    const object = super.toObject.call(this);
    for (let i = 0; i < _ArcSegment.Properties.length; i++) {
      const k = _ArcSegment.Properties[i];
      object[k] = this[k];
    }
    return object;
  }
};
var ArcSegment = _ArcSegment;
__publicField(ArcSegment, "Properties", ["startAngle", "endAngle", "innerRadius", "outerRadius"]);
// Accessors installed by the ArcSegment constructor; every setter marks its
// flag so _update() regenerates the vertices.
var proto20 = {
  startAngle: {
    enumerable: true,
    get: function() {
      return this._startAngle;
    },
    set: function(v) {
      this._startAngle = v;
      this._flagStartAngle = true;
    }
  },
  endAngle: {
    enumerable: true,
    get: function() {
      return this._endAngle;
    },
    set: function(v) {
      this._endAngle = v;
      this._flagEndAngle = true;
    }
  },
  innerRadius: {
    enumerable: true,
    get: function() {
      return this._innerRadius;
    },
    set: function(v) {
      this._innerRadius = v;
      this._flagInnerRadius = true;
    }
  },
  outerRadius: {
    enumerable: true,
    get: function() {
      return this._outerRadius;
    },
    set: function(v) {
      this._outerRadius = v;
      this._flagOuterRadius = true;
    }
  }
};

// src/shapes/points.js
var ceil2 = Math.ceil;
var floor3 = Math.floor;
// Shape whose renderer type is "points"; it reuses several Path prototype
// methods (assigned as class fields below). The class continues past this span.
var _Points = class extends Shape {
  _flagVertices = true;
  _flagLength = true;
  _flagFill = true;
  _flagStroke = true;
  _flagLinewidth = true;
  _flagOpacity = true;
  _flagVisible = true;
  _flagSize = true;
  _flagSizeAttenuation = true;
  _length = 0;
  _fill = "#fff";
  _stroke = "#000";
  _linewidth = 1;
  _opacity = 1;
  _visible = true;
  _size = 1;
  _sizeAttenuation = false;
  _beginning = 0;
  _ending = 1;
  _dashes = null;
  constructor(vertices) {
    super();
    for (let prop in proto21) {
      Object.defineProperty(this, prop, proto21[prop]);
    }
    this._renderer.type = "points";
    this._renderer.flagVertices = FlagVertices.bind(this);
    this._renderer.bindVertices = BindVertices.bind(this);
    this._renderer.unbindVertices = UnbindVertices.bind(this);
    this._renderer.flagFill = FlagFill.bind(this);
    this._renderer.flagStroke = FlagStroke.bind(this);
    this._renderer.vertices = null;
    this._renderer.collection = null;
    this.sizeAttenuation = false;
    this.beginning = 0;
    this.ending = 1;
    this.fill = "#fff";
    this.stroke = "#000";
    this.className = "";
    this.visible = true;
    this.vertices = vertices;
    this.dashes = [];
    this.dashes.offset = 0;
  }
  // Copy with cloned vertices, the same _Points.Properties and transform;
  // returns the clone after running its _update().
  clone(parent) {
    const clone = new _Points();
    for (let j = 0; j < this.vertices.length; j++) {
      clone.vertices.push(this.vertices[j].clone());
    }
    for (let i = 0; i < _Points.Properties.length; i++) {
      const k = _Points.Properties[i];
      clone[k] = this[k];
    }
    clone.className = this.className;
    clone.translation.copy(this.translation);
    clone.rotation = this.rotation;
    clone.scale = this.scale;
    clone.skewX = this.skewX;
    clone.skewY = this.skewY;
    if (this.matrix.manual) {
      clone.matrix.copy(this.matrix);
    }
    if (parent) {
      parent.add(clone);
    }
    return clone._update();
  }
  // Plain-object snapshot of vertices, _Points.Properties and transform.
  toObject() {
    const result = {
      vertices: this.vertices.map(function(v) {
        return v.toObject();
      })
    };
    _.each(_Points.Properties, function(k) {
      result[k] = this[k];
    }, this);
    result.className = this.className;
    result.translation = this.translation.toObject();
    result.rotation = this.rotation;
    result.scale = this.scale instanceof Vector ? this.scale.toObject() : this.scale;
    result.skewX = this.skewX;
    result.skewY = this.skewY;
    if (this.matrix.manual) {
      result.matrix = this.matrix.toObject();
    }
    return result;
  }
  noFill = Path.prototype.noFill;
  noStroke = Path.prototype.noStroke;
  corner = Path.prototype.corner;
  center = Path.prototype.center;
  getBoundingClientRect = Path.prototype.getBoundingClientRect;
  // Replace the vertex list with the points produced by the module-level
  // `subdivide` helper for each consecutive vertex pair (passed as a
  // degenerate cubic whose controls equal its endpoints).
  subdivide(limit) {
    this._update();
    let points = [];
    for (let i = 0; i < this.vertices.length; i++) {
      const a = this.vertices[i];
      const b = this.vertices[i - 1];
      if (!b) {
        continue;
      }
      const x1 = a.x;
      const y1 = a.y;
      const x2 = b.x;
      const y2 = b.y;
      const subdivisions = subdivide(x1, y1, x1, y1, x2, y2, x2, y2, limit);
      points = points.concat(subdivisions);
    }
    this.vertices = points;
    return this;
  }
  _updateLength = Path.prototype._updateLength;
  // Rebuild the renderer's vertex lists from the [beginning, ending] portion
  // of the shape's length (body continues beyond this span).
  _update() {
    if (this._flagVertices) {
      if (this._flagLength) {
        this._updateLength(void 0, true);
      }
      const beginning = Math.min(this._beginning, this._ending);
      const ending = Math.max(this._beginning, this._ending);
      const bid = getIdByLength(this, beginning * this._length);
      const eid = getIdByLength(this, ending * this._length);
      const low = ceil2(bid);
      const high = floor3(eid);
      let j = 0, v;
      this._renderer.vertices = [];
      this._renderer.collection = [];
      for (let i = 0; i < this._collection.length; i++) {
        if (i >= low && i <= high) {
v = this._collection[i]; this._renderer.collection.push(v); this._renderer.vertices[j * 2 + 0] = v.x; this._renderer.vertices[j * 2 + 1] = v.y; j++; } } } super._update.apply(this, arguments); return this; } flagReset() { this._flagVertices = this._flagLength = this._flagFill = this._flagStroke = this._flagLinewidth = this._flagOpacity = this._flagVisible = this._flagSize = this._flagSizeAttenuation = false; super.flagReset.call(this); return this; } }; var Points = _Points; __publicField(Points, "Properties", [ "fill", "stroke", "linewidth", "opacity", "visible", "size", "sizeAttenuation", "beginning", "ending" ]); var proto21 = { linewidth: { enumerable: true, get: function() { return this._linewidth; }, set: function(v) { this._linewidth = v; this._flagLinewidth = true; } }, opacity: { enumerable: true, get: function() { return this._opacity; }, set: function(v) { this._opacity = v; this._flagOpacity = true; } }, visible: { enumerable: true, get: function() { return this._visible; }, set: function(v) { this._visible = v; this._flagVisible = true; } }, size: { enumerable: true, get: function() { return this._size; }, set: function(v) { this._size = v; this._flagSize = true; } }, sizeAttenuation: { enumerable: true, get: function() { return this._sizeAttenuation; }, set: function(v) { this._sizeAttenuation = v; this._flagSizeAttenuation = true; } }, fill: { enumerable: true, get: function() { return this._fill; }, set: function(f) { if (this._fill instanceof Gradient || this._fill instanceof LinearGradient || this._fill instanceof RadialGradient || this._fill instanceof Texture) { this._fill.unbind(Events.Types.change, this._renderer.flagFill); } this._fill = f; this._flagFill = true; if (this._fill instanceof Gradient || this._fill instanceof LinearGradient || this._fill instanceof RadialGradient || this._fill instanceof Texture) { this._fill.bind(Events.Types.change, this._renderer.flagFill); } } }, stroke: { enumerable: true, get: function() { return 
this._stroke; }, set: function(f) { if (this._stroke instanceof Gradient || this._stroke instanceof LinearGradient || this._stroke instanceof RadialGradient || this._stroke instanceof Texture) { this._stroke.unbind(Events.Types.change, this._renderer.flagStroke); } this._stroke = f; this._flagStroke = true; if (this._stroke instanceof Gradient || this._stroke instanceof LinearGradient || this._stroke instanceof RadialGradient || this._stroke instanceof Texture) { this._stroke.bind(Events.Types.change, this._renderer.flagStroke); } } }, length: { get: function() { if (this._flagLength) { this._updateLength(); } return this._length; } }, beginning: { enumerable: true, get: function() { return this._beginning; }, set: function(v) { this._beginning = v; this._flagVertices = true; } }, ending: { enumerable: true, get: function() { return this._ending; }, set: function(v) { this._ending = v; this._flagVertices = true; } }, vertices: { enumerable: true, get: function() { return this._collection; }, set: function(vertices) { const bindVertices = this._renderer.bindVertices; const unbindVertices = this._renderer.unbindVertices; if (this._collection) { this._collection.unbind(Events.Types.insert, bindVertices).unbind(Events.Types.remove, unbindVertices); } if (vertices instanceof Collection) { this._collection = vertices; } else { this._collection = new Collection(vertices || []); } this._collection.bind(Events.Types.insert, bindVertices).bind(Events.Types.remove, unbindVertices); bindVertices(this._collection); } }, dashes: { enumerable: true, get: function() { return this._dashes; }, set: function(v) { if (typeof v.offset !== "number") { v.offset = this.dashes && this._dashes.offset || 0; } this._dashes = v; } } }; // src/shapes/polygon.js var cos5 = Math.cos; var sin5 = Math.sin; var _Polygon = class extends Path { _flagWidth = false; _flagHeight = false; _flagSides = false; _radius = 0; _width = 0; _height = 0; _sides = 0; constructor(x, y, radius, sides) { sides = 
Math.max(sides || 0, 3); super(); for (let prop in proto22) { Object.defineProperty(this, prop, proto22[prop]); } this.closed = true; this.automatic = false; if (typeof radius === "number") { this.radius = radius; } if (typeof sides === "number") { this.sides = sides; } this._update(); if (typeof x === "number") { this.translation.x = x; } if (typeof y === "number") { this.translation.y = y; } } _update() { if (this._flagVertices || this._flagWidth || this._flagHeight || this._flagSides) { const sides = this._sides; const amount = sides + 1; let length = this.vertices.length; if (length > sides) { this.vertices.splice(sides - 1, length - sides); length = sides; } for (let i = 0; i < amount; i++) { const pct = (i + 0.5) / sides; const theta = TWO_PI * pct + Math.PI / 2; const x = this._width * cos5(theta) / 2; const y = this._height * sin5(theta) / 2; if (i >= length) { this.vertices.push(new Anchor(x, y)); } else { this.vertices[i].set(x, y); } this.vertices[i].command = i === 0 ? Commands.move : Commands.line; } } super._update.call(this); return this; } flagReset() { this._flagWidth = this._flagHeight = this._flagSides = false; super.flagReset.call(this); return this; } clone(parent) { const clone = new _Polygon(0, 0, 0, this.sides); clone.translation.copy(this.translation); clone.rotation = this.rotation; clone.scale = this.scale; clone.skewX = this.skewX; clone.skewY = this.skewY; clone.width = this.width; clone.height = this.height; if (this.matrix.manual) { clone.matrix.copy(this.matrix); } for (let i = 0; i < Path.Properties.length; i++) { const k = Path.Properties[i]; clone[k] = this[k]; } if (parent) { parent.add(clone); } return clone; } toObject() { const object = super.toObject.call(this); for (let i = 0; i < _Polygon.Properties.length; i++) { const k = _Polygon.Properties[i]; object[k] = this[k]; } return object; } }; var Polygon = _Polygon; __publicField(Polygon, "Properties", ["width", "height", "sides"]); var proto22 = { radius: { enumerable: true, 
get: function() { return this._radius; }, set: function(v) { this._radius = v; this.width = v * 2; this.height = v * 2; } }, width: { enumerable: true, get: function() { return this._width; }, set: function(v) { this._width = v; this._flagWidth = true; this._radius = Math.max(this.width, this.height) / 2; } }, height: { enumerable: true, get: function() { return this._height; }, set: function(v) { this._height = v; this._flagHeight = true; this._radius = Math.max(this.width, this.height) / 2; } }, sides: { enumerable: true, get: function() { return this._sides; }, set: function(v) { this._sides = v; this._flagSides = true; } } }; // src/shapes/star.js var cos6 = Math.cos; var sin6 = Math.sin; var _Star = class extends Path { _flagInnerRadius = false; _flagOuterRadius = false; _flagSides = false; _innerRadius = 0; _outerRadius = 0; _sides = 0; constructor(x, y, innerRadius, outerRadius, sides) { if (arguments.length <= 3) { outerRadius = innerRadius; innerRadius = outerRadius / 2; } if (typeof sides !== "number" || sides <= 0) { sides = 5; } super(); for (let prop in proto23) { Object.defineProperty(this, prop, proto23[prop]); } this.closed = true; this.automatic = false; if (typeof innerRadius === "number") { this.innerRadius = innerRadius; } if (typeof outerRadius === "number") { this.outerRadius = outerRadius; } if (typeof sides === "number") { this.sides = sides; } this._update(); if (typeof x === "number") { this.translation.x = x; } if (typeof y === "number") { this.translation.y = y; } } _update() { if (this._flagVertices || this._flagInnerRadius || this._flagOuterRadius || this._flagSides) { const sides = this._sides * 2; const amount = sides + 1; let length = this.vertices.length; if (length > sides) { this.vertices.splice(sides - 1, length - sides); length = sides; } for (let i = 0; i < amount; i++) { const pct = (i + 0.5) / sides; const theta = TWO_PI * pct; const r = (!(i % 2) ? 
this._innerRadius : this._outerRadius) / 2; const x = r * cos6(theta); const y = r * sin6(theta); if (i >= length) { this.vertices.push(new Anchor(x, y)); } else { this.vertices[i].set(x, y); } this.vertices[i].command = i === 0 ? Commands.move : Commands.line; } } super._update.call(this); return this; } flagReset() { this._flagInnerRadius = this._flagOuterRadius = this._flagSides = false; super.flagReset.call(this); return this; } clone(parent) { const ir = this.innerRadius; const or = this.outerRadius; const sides = this.sides; const clone = new _Star(0, 0, ir, or, sides); clone.translation.copy(this.translation); clone.rotation = this.rotation; clone.scale = this.scale; clone.skewX = this.skewX; clone.skewY = this.skewY; if (this.matrix.manual) { clone.matrix.copy(this.matrix); } for (let i = 0; i < Path.Properties.length; i++) { const k = Path.Properties[i]; clone[k] = this[k]; } if (parent) { parent.add(clone); } return clone; } toObject() { const object = super.toObject.call(this); for (let i = 0; i < _Star.Properties.length; i++) { const k = _Star.Properties[i]; object[k] = this[k]; } return object; } }; var Star = _Star; __publicField(Star, "Properties", ["innerRadius", "outerRadius", "sides"]); var proto23 = { innerRadius: { enumerable: true, get: function() { return this._innerRadius; }, set: function(v) { this._innerRadius = v; this._flagInnerRadius = true; } }, outerRadius: { enumerable: true, get: function() { return this._outerRadius; }, set: function(v) { this._outerRadius = v; this._flagOuterRadius = true; } }, sides: { enumerable: true, get: function() { return this._sides; }, set: function(v) { this._sides = v; this._flagSides = true; } } }; // src/renderers/svg.js var svg = { version: 1.1, ns: "http://www.w3.org/2000/svg", xlink: "http://www.w3.org/1999/xlink", alignments: { left: "start", center: "middle", right: "end" }, createElement: function(name, attrs) { const tag = name; const elem = document.createElementNS(svg.ns, tag); if (tag === 
"svg") { attrs = _.defaults(attrs || {}, { version: svg.version }); } if (attrs && Object.keys(attrs).length > 0) { svg.setAttributes(elem, attrs); } return elem; }, setAttributes: function(elem, attrs) { const keys = Object.keys(attrs); for (let i = 0; i < keys.length; i++) { if (/href/.test(keys[i])) { elem.setAttributeNS(svg.xlink, keys[i], attrs[keys[i]]); } else { elem.setAttribute(keys[i], attrs[keys[i]]); } } return this; }, removeAttributes: function(elem, attrs) { for (let key in attrs) { elem.removeAttribute(key); } return this; }, toString: function(points, closed2) { let l = points.length, last = l - 1, d, string = ""; for (let i = 0; i < l; i++) { const b = points[i]; const prev = closed2 ? mod(i - 1, l) : Math.max(i - 1, 0); const a = points[prev]; let command, c; let vx, vy, ux, uy, ar, bl, br, cl; let rx, ry, xAxisRotation, largeArcFlag, sweepFlag; let x = toFixed(b.x); let y = toFixed(b.y); switch (b.command) { case Commands.close: command = Commands.close; break; case Commands.arc: rx = b.rx; ry = b.ry; xAxisRotation = b.xAxisRotation; largeArcFlag = b.largeArcFlag; sweepFlag = b.sweepFlag; command = Commands.arc + " " + rx + " " + ry + " " + xAxisRotation + " " + largeArcFlag + " " + sweepFlag + " " + x + " " + y; break; case Commands.curve: ar = a.controls && a.controls.right || Vector.zero; bl = b.controls && b.controls.left || Vector.zero; if (a.relative) { vx = toFixed(ar.x + a.x); vy = toFixed(ar.y + a.y); } else { vx = toFixed(ar.x); vy = toFixed(ar.y); } if (b.relative) { ux = toFixed(bl.x + b.x); uy = toFixed(bl.y + b.y); } else { ux = toFixed(bl.x); uy = toFixed(bl.y); } command = (i === 0 ? 
Commands.move : Commands.curve) + " " + vx + " " + vy + " " + ux + " " + uy + " " + x + " " + y; break; case Commands.move: d = b; command = Commands.move + " " + x + " " + y; break; default: command = b.command + " " + x + " " + y; } if (i >= last && closed2) { if (b.command === Commands.curve) { c = d; br = b.controls && b.controls.right || b; cl = c.controls && c.controls.left || c; if (b.relative) { vx = toFixed(br.x + b.x); vy = toFixed(br.y + b.y); } else { vx = toFixed(br.x); vy = toFixed(br.y); } if (c.relative) { ux = toFixed(cl.x + c.x); uy = toFixed(cl.y + c.y); } else { ux = toFixed(cl.x); uy = toFixed(cl.y); } x = toFixed(c.x); y = toFixed(c.y); command += " C " + vx + " " + vy + " " + ux + " " + uy + " " + x + " " + y; } if (b.command !== Commands.close) { command += " Z"; } } string += command + " "; } return string; }, pointsToString: function(points, size) { let string = ""; const r = size * 0.5; for (let i = 0; i < points.length; i++) { const x = points[i].x; const y = points[i].y - r; string += Commands.move + " " + x + " " + y + " "; string += "a " + r + " " + r + " 0 1 0 0.001 0 Z"; } return string; }, getClip: function(shape, domElement) { let clip = shape._renderer.clip; if (!clip) { clip = shape._renderer.clip = svg.createElement("clipPath", { "clip-rule": "nonzero" }); } if (clip.parentNode === null) { domElement.defs.appendChild(clip); } return clip; }, defs: { update: function(domElement) { const { defs } = domElement; if (defs._flagUpdate) { const children = Array.prototype.slice.call( defs.children, 0 ); for (let i = 0; i < children.length; i++) { const child = children[i]; const id = child.id; const selector = `[fill="url(#${id})"],[stroke="url(#${id})"],[clip-path="url(#${id})"]`; const exists = domElement.querySelector(selector); if (!exists) { defs.removeChild(child); } } defs._flagUpdate = false; } } }, group: { appendChild: function(object) { const elem = object._renderer.elem; if (!elem) { return; } const tag = elem.nodeName; if 
(!tag || /(radial|linear)gradient/i.test(tag) || object._clip) { return; } this.elem.appendChild(elem); }, removeChild: function(object) { const elem = object._renderer.elem; if (!elem || elem.parentNode != this.elem) { return; } const tag = elem.nodeName; if (!tag) { return; } if (object._clip) { return; } this.elem.removeChild(elem); }, orderChild: function(object) { this.elem.appendChild(object._renderer.elem); }, renderChild: function(child) { svg[child._renderer.type].render.call(child, this); }, render: function(domElement) { if (!this._visible && !this._flagVisible || this._opacity === 0 && !this._flagOpacity) { return this; } this._update(); if (!this._renderer.elem) { this._renderer.elem = svg.createElement("g", { id: this.id }); domElement.appendChild(this._renderer.elem); } const flagMatrix = this._matrix.manual || this._flagMatrix; const context = { domElement, elem: this._renderer.elem }; if (flagMatrix) { this._renderer.elem.setAttribute("transform", "matrix(" + this._matrix.toString() + ")"); } for (let i = 0; i < this.children.length; i++) { const child = this.children[i]; svg[child._renderer.type].render.call(child, domElement); } if (this._flagId) { this._renderer.elem.setAttribute("id", this._id); } if (this._flagOpacity) { this._renderer.elem.setAttribute("opacity", this._opacity); } if (this._flagVisible) { this._renderer.elem.setAttribute("display", this._visible ? 
"inline" : "none"); } if (this._flagClassName) { this._renderer.elem.setAttribute("class", this.classList.join(" ")); } if (this._flagAdditions) { this.additions.forEach(svg.group.appendChild, context); } if (this._flagSubtractions) { this.subtractions.forEach(svg.group.removeChild, context); } if (this._flagOrder) { this.children.forEach(svg.group.orderChild, context); } if (this._flagMask) { if (this._mask) { svg[this._mask._renderer.type].render.call(this._mask, domElement); this._renderer.elem.setAttribute("clip-path", "url(#" + this._mask.id + ")"); } else { this._renderer.elem.removeAttribute("clip-path"); } } if (this.dataset) { Object.assign(this._renderer.elem.dataset, this.dataset); } return this.flagReset(); } }, path: { render: function(domElement) { if (this._opacity === 0 && !this._flagOpacity) { return this; } this._update(); const changed = {}; const flagMatrix = this._matrix.manual || this._flagMatrix; if (flagMatrix) { changed.transform = "matrix(" + this._matrix.toString() + ")"; } if (this._flagId) { changed.id = this._id; } if (this._flagVertices) { const vertices = svg.toString(this._renderer.vertices, this._closed); changed.d = vertices; } if (this._fill && this._fill._renderer) { this._renderer.hasFillEffect = true; this._fill._update(); svg[this._fill._renderer.type].render.call(this._fill, domElement, true); } if (this._flagFill) { changed.fill = this._fill && this._fill.id ? "url(#" + this._fill.id + ")" : this._fill; if (this._renderer.hasFillEffect && typeof this._fill.id === "undefined") { domElement.defs._flagUpdate = true; delete this._renderer.hasFillEffect; } } if (this._stroke && this._stroke._renderer) { this._renderer.hasStrokeEffect = true; this._stroke._update(); svg[this._stroke._renderer.type].render.call(this._stroke, domElement, true); } if (this._flagStroke) { changed.stroke = this._stroke && this._stroke.id ? 
"url(#" + this._stroke.id + ")" : this._stroke; if (this._renderer.hasStrokeEffect && typeof this._stroke.id === "undefined") { domElement.defs._flagUpdate = true; delete this._renderer.hasStrokeEffect; } } if (this._flagLinewidth) { changed["stroke-width"] = this._linewidth; } if (this._flagOpacity) { changed["stroke-opacity"] = this._opacity; changed["fill-opacity"] = this._opacity; } if (this._flagClassName) { changed["class"] = this.classList.join(" "); } if (this._flagVisible) { changed.visibility = this._visible ? "visible" : "hidden"; } if (this._flagCap) { changed["stroke-linecap"] = this._cap; } if (this._flagJoin) { changed["stroke-linejoin"] = this._join; } if (this._flagMiter) { changed["stroke-miterlimit"] = this._miter; } if (this.dashes && this.dashes.length > 0) { changed["stroke-dasharray"] = this.dashes.join(" "); changed["stroke-dashoffset"] = this.dashes.offset || 0; } if (!this._renderer.elem) { changed.id = this._id; this._renderer.elem = svg.createElement("path", changed); domElement.appendChild(this._renderer.elem); } else { svg.setAttributes(this._renderer.elem, changed); } if (this._flagClip) { const clip = svg.getClip(this, domElement); const elem = this._renderer.elem; if (this._clip) { elem.removeAttribute("id"); clip.setAttribute("id", this.id); clip.appendChild(elem); } else { clip.removeAttribute("id"); elem.setAttribute("id", this.id); this.parent._renderer.elem.appendChild(elem); } } if (this._flagMask) { if (this._mask) { svg[this._mask._renderer.type].render.call(this._mask, domElement); this._renderer.elem.setAttribute("clip-path", "url(#" + this._mask.id + ")"); } else { this._renderer.elem.removeAttribute("clip-path"); } } return this.flagReset(); } }, points: { render: function(domElement) { if (this._opacity === 0 && !this._flagOpacity) { return this; } this._update(); const changed = {}; const flagMatrix = this._matrix.manual || this._flagMatrix; if (flagMatrix) { changed.transform = "matrix(" + this._matrix.toString() + 
")"; } if (this._flagId) { changed.id = this._id; } if (this._flagVertices || this._flagSize || this._flagSizeAttenuation) { let size = this._size; if (!this._sizeAttenuation) { const me = this.worldMatrix.elements; const m = decomposeMatrix(me[0], me[3], me[1], me[4], me[2], me[5]); size /= Math.max(m.scaleX, m.scaleY); } const vertices = svg.pointsToString(this._renderer.collection, size); changed.d = vertices; } if (this._fill && this._fill._renderer) { this._renderer.hasFillEffect = true; this._fill._update(); svg[this._fill._renderer.type].render.call(this._fill, domElement, true); } if (this._flagFill) { changed.fill = this._fill && this._fill.id ? "url(#" + this._fill.id + ")" : this._fill; if (this._renderer.hasFillEffect && typeof this._fill.id === "undefined") { domElement.defs._flagUpdate = true; delete this._renderer.hasFillEffect; } } if (this._stroke && this._stroke._renderer) { this._renderer.hasStrokeEffect = true; this._stroke._update(); svg[this._stroke._renderer.type].render.call(this._stroke, domElement, true); } if (this._flagStroke) { changed.stroke = this._stroke && this._stroke.id ? "url(#" + this._stroke.id + ")" : this._stroke; if (this._renderer.hasStrokeEffect && typeof this._stroke.id === "undefined") { domElement.defs._flagUpdate = true; delete this._renderer.hasStrokeEffect; } } if (this._flagLinewidth) { changed["stroke-width"] = this._linewidth; } if (this._flagOpacity) { changed["stroke-opacity"] = this._opacity; changed["fill-opacity"] = this._opacity; } if (this._flagClassName) { changed["class"] = this.classList.join(" "); } if (this._flagVisible) { changed.visibility = this._visible ? 
"visible" : "hidden"; } if (this.dashes && this.dashes.length > 0) { changed["stroke-dasharray"] = this.dashes.join(" "); changed["stroke-dashoffset"] = this.dashes.offset || 0; } if (!this._renderer.elem) { changed.id = this._id; this._renderer.elem = svg.createElement("path", changed); domElement.appendChild(this._renderer.elem); } else { svg.setAttributes(this._renderer.elem, changed); } return this.flagReset(); } }, text: { render: function(domElement) { this._update(); const changed = {}; const flagMatrix = this._matrix.manual || this._flagMatrix; if (flagMatrix) { changed.transform = "matrix(" + this._matrix.toString() + ")"; } if (this._flagId) { changed.id = this._id; } if (this._flagFamily) { changed["font-family"] = this._family; } if (this._flagSize) { changed["font-size"] = this._size; } if (this._flagLeading) { changed["line-height"] = this._leading; } if (this._flagAlignment) { changed["text-anchor"] = svg.alignments[this._alignment] || this._alignment; } if (this._flagBaseline) { changed["alignment-baseline"] = changed["dominant-baseline"] = this._baseline; } if (this._flagStyle) { changed["font-style"] = this._style; } if (this._flagWeight) { changed["font-weight"] = this._weight; } if (this._flagDecoration) { changed["text-decoration"] = this._decoration; } if (this._fill && this._fill._renderer) { this._renderer.hasFillEffect = true; this._fill._update(); svg[this._fill._renderer.type].render.call(this._fill, domElement, true); } if (this._flagFill) { changed.fill = this._fill && this._fill.id ? "url(#" + this._fill.id + ")" : this._fill; if (this._renderer.hasFillEffect && typeof this._fill.id === "undefined") { domElement.defs._flagUpdate = true; delete this._renderer.hasFillEffect; } } if (this._stroke && this._stroke._renderer) { this._renderer.hasStrokeEffect = true; this._stroke._update(); svg[this._stroke._renderer.type].render.call(this._stroke, domElement, true); } if (this._flagStroke) { changed.stroke = this._stroke && this._stroke.id ? 
"url(#" + this._stroke.id + ")" : this._stroke; if (this._renderer.hasStrokeEffect && typeof this._stroke.id === "undefined") { domElement.defs._flagUpdate = true; delete this._renderer.hasStrokeEffect; } } if (this._flagLinewidth) { changed["stroke-width"] = this._linewidth; } if (this._flagOpacity) { changed.opacity = this._opacity; } if (this._flagClassName) { changed["class"] = this.classList.join(" "); } if (this._flagVisible) { changed.visibility = this._visible ? "visible" : "hidden"; } if (this.dashes && this.dashes.length > 0) { changed["stroke-dasharray"] = this.dashes.join(" "); changed["stroke-dashoffset"] = this.dashes.offset || 0; } if (!this._renderer.elem) { changed.id = this._id; this._renderer.elem = svg.createElement("text", changed); domElement.appendChild(this._renderer.elem); } else { svg.setAttributes(this._renderer.elem, changed); } if (this._flagClip) { const clip = svg.getClip(this, domElement); const elem = this._renderer.elem; if (this._clip) { elem.removeAttribute("id"); clip.setAttribute("id", this.id); clip.appendChild(elem); } else { clip.removeAttribute("id"); elem.setAttribute("id", this.id); this.parent._renderer.elem.appendChild(elem); } } if (this._flagMask) { if (this._mask) { svg[this._mask._renderer.type].render.call(this._mask, domElement); this._renderer.elem.setAttribute("clip-path", "url(#" + this._mask.id + ")"); } else { this._renderer.elem.removeAttribute("clip-path"); } } if (this._flagValue) { this._renderer.elem.textContent = this._value; } return this.flagReset(); } }, "linear-gradient": { render: function(domElement, silent) { if (!silent) { this._update(); } const changed = {}; if (this._flagId) { changed.id = this._id; } if (this._flagEndPoints) { changed.x1 = this.left._x; changed.y1 = this.left._y; changed.x2 = this.right._x; changed.y2 = this.right._y; } if (this._flagSpread) { changed.spreadMethod = this._spread; } if (this._flagUnits) { changed.gradientUnits = this._units; } if (!this._renderer.elem) { 
changed.id = this._id; this._renderer.elem = svg.createElement("linearGradient", changed); } else { svg.setAttributes(this._renderer.elem, changed); } if (this._renderer.elem.parentNode === null) { domElement.defs.appendChild(this._renderer.elem); } if (this._flagStops) { const lengthChanged = this._renderer.elem.childNodes.length !== this.stops.length; if (lengthChanged) { while (this._renderer.elem.lastChild) { this._renderer.elem.removeChild(this._renderer.elem.lastChild); } } for (let i = 0; i < this.stops.length; i++) { const stop = this.stops[i]; const attrs = {}; if (stop._flagOffset) { attrs.offset = 100 * stop._offset + "%"; } if (stop._flagColor) { attrs["stop-color"] = stop._color; } if (stop._flagOpacity) { attrs["stop-opacity"] = stop._opacity; } if (!stop._renderer.elem) { stop._renderer.elem = svg.createElement("stop", attrs); } else { svg.setAttributes(stop._renderer.elem, attrs); } if (lengthChanged) { this._renderer.elem.appendChild(stop._renderer.elem); } stop.flagReset(); } } return this.flagReset(); } }, "radial-gradient": { render: function(domElement, silent) { if (!silent) { this._update(); } const changed = {}; if (this._flagId) { changed.id = this._id; } if (this._flagCenter) { changed.cx = this.center._x; changed.cy = this.center._y; } if (this._flagFocal) { changed.fx = this.focal._x; changed.fy = this.focal._y; } if (this._flagRadius) { changed.r = this._radius; } if (this._flagSpread) { changed.spreadMethod = this._spread; } if (this._flagUnits) { changed.gradientUnits = this._units; } if (!this._renderer.elem) { changed.id = this._id; this._renderer.elem = svg.createElement("radialGradient", changed); } else { svg.setAttributes(this._renderer.elem, changed); } if (this._renderer.elem.parentNode === null) { domElement.defs.appendChild(this._renderer.elem); } if (this._flagStops) { const lengthChanged = this._renderer.elem.childNodes.length !== this.stops.length; if (lengthChanged) { while (this._renderer.elem.lastChild) { 
this._renderer.elem.removeChild(this._renderer.elem.lastChild); } } for (let i = 0; i < this.stops.length; i++) { const stop = this.stops[i]; const attrs = {}; if (stop._flagOffset) { attrs.offset = 100 * stop._offset + "%"; } if (stop._flagColor) { attrs["stop-color"] = stop._color; } if (stop._flagOpacity) { attrs["stop-opacity"] = stop._opacity; } if (!stop._renderer.elem) { stop._renderer.elem = svg.createElement("stop", attrs); } else { svg.setAttributes(stop._renderer.elem, attrs); } if (lengthChanged) { this._renderer.elem.appendChild(stop._renderer.elem); } stop.flagReset(); } } return this.flagReset(); } }, texture: { render: function(domElement, silent) { if (!silent) { this._update(); } const changed = {}; const styles = { x: 0, y: 0 }; const image = this.image; if (this._flagId) { changed.id = this._id; } if (this._flagLoaded && this.loaded) { switch (image.nodeName.toLowerCase()) { case "canvas": styles.href = styles["xlink:href"] = image.toDataURL("image/png"); break; case "img": case "image": styles.href = styles["xlink:href"] = this.src; break; } } if (this._flagOffset || this._flagLoaded || this._flagScale) { changed.x = this._offset.x; changed.y = this._offset.y; if (image) { changed.x -= image.width / 2; changed.y -= image.height / 2; if (this._scale instanceof Vector) { changed.x *= this._scale.x; changed.y *= this._scale.y; } else { changed.x *= this._scale; changed.y *= this._scale; } } if (changed.x > 0) { changed.x *= -1; } if (changed.y > 0) { changed.y *= -1; } } if (this._flagScale || this._flagLoaded || this._flagRepeat) { changed.width = 0; changed.height = 0; if (image) { styles.width = changed.width = image.width; styles.height = changed.height = image.height; switch (this._repeat) { case "no-repeat": changed.width += 1; changed.height += 1; break; } if (this._scale instanceof Vector) { changed.width *= this._scale.x; changed.height *= this._scale.y; } else { changed.width *= this._scale; changed.height *= this._scale; } } } if 
(this._flagScale || this._flagLoaded) { if (!this._renderer.image) { this._renderer.image = svg.createElement("image", styles); } else { svg.setAttributes(this._renderer.image, styles); } } if (!this._renderer.elem) { changed.id = this._id; changed.patternUnits = "userSpaceOnUse"; this._renderer.elem = svg.createElement("pattern", changed); } else if (Object.keys(changed).length !== 0) { svg.setAttributes(this._renderer.elem, changed); } if (this._renderer.elem.parentNode === null) { domElement.defs.appendChild(this._renderer.elem); } if (this._renderer.elem && this._renderer.image && !this._renderer.appended) { this._renderer.elem.appendChild(this._renderer.image); this._renderer.appended = true; } return this.flagReset(); } } }; var Renderer2 = class extends Events { constructor(params) { super(); this.domElement = params.domElement || svg.createElement("svg"); this.scene = new Group(); this.scene.parent = this; this.defs = svg.createElement("defs"); this.defs._flagUpdate = false; this.domElement.appendChild(this.defs); this.domElement.defs = this.defs; this.domElement.style.overflow = "hidden"; } setSize(width, height) { this.width = width; this.height = height; svg.setAttributes(this.domElement, { width, height }); return this.trigger(Events.Types.resize, width, height); } render() { svg.group.render.call(this.scene, this.domElement); svg.defs.update(this.domElement); return this; } }; __publicField(Renderer2, "Utils", svg); // src/utils/shaders.js var shaders = { create: function(gl, source, type) { const shader = gl.createShader(gl[type]); gl.shaderSource(shader, source); gl.compileShader(shader); const compiled = gl.getShaderParameter(shader, gl.COMPILE_STATUS); if (!compiled) { const error = gl.getShaderInfoLog(shader); gl.deleteShader(shader); throw new TwoError("unable to compile shader " + shader + ": " + error); } return shader; }, types: { vertex: "VERTEX_SHADER", fragment: "FRAGMENT_SHADER" }, path: { vertex: ` precision mediump float; attribute vec2 
a_position; uniform mat3 u_matrix; uniform vec2 u_resolution; uniform vec4 u_rect; varying vec2 v_textureCoords; void main() { vec2 rectCoords = (a_position * (u_rect.zw - u_rect.xy)) + u_rect.xy; vec2 projected = (u_matrix * vec3(rectCoords, 1.0)).xy; vec2 normal = projected / u_resolution; vec2 clipspace = (normal * 2.0) - 1.0; gl_Position = vec4(clipspace * vec2(1.0, -1.0), 0.0, 1.0); v_textureCoords = a_position; } `, fragment: ` precision mediump float; uniform sampler2D u_image; varying vec2 v_textureCoords; void main() { vec4 texel = texture2D(u_image, v_textureCoords); if (texel.a == 0.0) { discard; } gl_FragColor = texel; } ` }, points: { vertex: ` precision mediump float; attribute vec2 a_position; uniform float u_size; uniform mat3 u_matrix; uniform vec2 u_resolution; varying vec2 v_textureCoords; void main() { vec2 projected = (u_matrix * vec3(a_position, 1.0)).xy; vec2 normal = projected / u_resolution; vec2 clipspace = (normal * 2.0) - 1.0; gl_PointSize = u_size; gl_Position = vec4(clipspace * vec2(1.0, -1.0), 0.0, 1.0); v_textureCoords = a_position; } `, fragment: ` precision mediump float; uniform sampler2D u_image; void main() { vec4 texel = texture2D(u_image, gl_PointCoord); if (texel.a == 0.0) { discard; } gl_FragColor = texel; } ` } }; // src/renderers/webgl.js var multiplyMatrix = Matrix2.Multiply; var identity = [1, 0, 0, 0, 1, 0, 0, 0, 1]; var transformation = new NumArray(9); var CanvasUtils = Renderer.Utils; var quad = new NumArray([ 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1 ]); var webgl = { precision: 0.9, isHidden: /(undefined|none|transparent)/i, canvas: root.document ? 
root.document.createElement("canvas") : { getContext: function() { } }, alignments: { left: "start", middle: "center", right: "end" }, matrix: new Matrix2(), group: { removeChild: function(child, gl) { if (child.children) { for (let i = 0; i < child.children.length; i++) { webgl.group.removeChild(child.children[i], gl); } } if (child._renderer.texture) { gl.deleteTexture(child._renderer.texture); delete child._renderer.texture; } if (child._renderer.positionBuffer) { gl.deleteBuffer(child._renderer.positionBuffer); delete child._renderer.positionBuffer; } }, render: function(gl, programs) { if (!this._visible) { return; } this._update(); const parent = this.parent; const flagParentMatrix = parent._matrix && parent._matrix.manual || parent._flagMatrix; const flagMatrix = this._matrix.manual || this._flagMatrix; if (flagParentMatrix || flagMatrix) { if (!this._renderer.matrix) { this._renderer.matrix = new NumArray(9); } this._matrix.toTransformArray(true, transformation); multiplyMatrix(transformation, parent._renderer.matrix, this._renderer.matrix); if (!(this._renderer.scale instanceof Vector)) { this._renderer.scale = new Vector(); } if (this._scale instanceof Vector) { this._renderer.scale.x = this._scale.x; this._renderer.scale.y = this._scale.y; } else { this._renderer.scale.x = this._scale; this._renderer.scale.y = this._scale; } if (!/renderer/i.test(parent._renderer.type)) { this._renderer.scale.x *= parent._renderer.scale.x; this._renderer.scale.y *= parent._renderer.scale.y; } if (flagParentMatrix) { this._flagMatrix = true; } } if (this._mask) { gl.clear(gl.STENCIL_BUFFER_BIT); gl.enable(gl.STENCIL_TEST); gl.stencilFunc(gl.ALWAYS, 1, 0); gl.stencilOp(gl.KEEP, gl.KEEP, gl.REPLACE); gl.colorMask(false, false, false, false); webgl[this._mask._renderer.type].render.call(this._mask, gl, programs, this); gl.stencilFunc(gl.EQUAL, 1, 255); gl.stencilOp(gl.KEEP, gl.KEEP, gl.KEEP); gl.colorMask(true, true, true, true); } this._flagOpacity = parent._flagOpacity || 
this._flagOpacity; this._renderer.opacity = this._opacity * (parent && parent._renderer ? parent._renderer.opacity : 1); let i; if (this._flagSubtractions) { for (i = 0; i < this.subtractions.length; i++) { webgl.group.removeChild(this.subtractions[i], gl); } } for (i = 0; i < this.children.length; i++) { const child = this.children[i]; webgl[child._renderer.type].render.call(child, gl, programs); } if (this._mask) { gl.disable(gl.STENCIL_TEST); } return this.flagReset(); } }, path: { updateCanvas: function(elem) { let prev, a, c, ux, uy, vx, vy, ar, bl, br, cl, x, y; let isOffset; const commands = elem._renderer.vertices; const canvas3 = this.canvas; const ctx = this.ctx; const scale = elem._renderer.scale; const stroke = elem._stroke; const linewidth = elem._linewidth; const fill = elem._fill; const opacity = elem._renderer.opacity || elem._opacity; const cap = elem._cap; const join = elem._join; const miter = elem._miter; const closed2 = elem._closed; const dashes = elem.dashes; const length = commands.length; const last = length - 1; canvas3.width = Math.max(Math.ceil(elem._renderer.rect.width * scale.x), 1); canvas3.height = Math.max(Math.ceil(elem._renderer.rect.height * scale.y), 1); const centroid = elem._renderer.rect.centroid; const cx = centroid.x; const cy = centroid.y; ctx.clearRect(0, 0, canvas3.width, canvas3.height); if (fill) { if (typeof fill === "string") { ctx.fillStyle = fill; } else { webgl[fill._renderer.type].render.call(fill, ctx, elem); ctx.fillStyle = fill._renderer.effect; } } if (stroke) { if (typeof stroke === "string") { ctx.strokeStyle = stroke; } else { webgl[stroke._renderer.type].render.call(stroke, ctx, elem); ctx.strokeStyle = stroke._renderer.effect; } if (linewidth) { ctx.lineWidth = linewidth; } if (miter) { ctx.miterLimit = miter; } if (join) { ctx.lineJoin = join; } if (!closed2 && cap) { ctx.lineCap = cap; } } if (typeof opacity === "number") { ctx.globalAlpha = opacity; } if (dashes && dashes.length > 0) { 
ctx.lineDashOffset = dashes.offset || 0; ctx.setLineDash(dashes); } let d, rx, ry, xAxisRotation, largeArcFlag, sweepFlag, ax, ay; ctx.save(); ctx.scale(scale.x, scale.y); ctx.translate(cx, cy); ctx.beginPath(); for (let i = 0; i < commands.length; i++) { const b = commands[i]; x = b.x; y = b.y; switch (b.command) { case Commands.close: ctx.closePath(); break; case Commands.arc: rx = b.rx; ry = b.ry; xAxisRotation = b.xAxisRotation; largeArcFlag = b.largeArcFlag; sweepFlag = b.sweepFlag; prev = closed2 ? mod(i - 1, length) : Math.max(i - 1, 0); a = commands[prev]; ax = a.x; ay = a.y; CanvasUtils.renderSvgArcCommand(ctx, ax, ay, rx, ry, largeArcFlag, sweepFlag, xAxisRotation, x, y); break; case Commands.curve: prev = closed2 ? mod(i - 1, length) : Math.max(i - 1, 0); a = commands[prev]; ar = a.controls && a.controls.right || Vector.zero; bl = b.controls && b.controls.left || Vector.zero; if (a._relative) { vx = ar.x + a.x; vy = ar.y + a.y; } else { vx = ar.x; vy = ar.y; } if (b._relative) { ux = bl.x + b.x; uy = bl.y + b.y; } else { ux = bl.x; uy = bl.y; } ctx.bezierCurveTo(vx, vy, ux, uy, x, y); if (i >= last && closed2) { c = d; br = b.controls && b.controls.right || Vector.zero; cl = c.controls && c.controls.left || Vector.zero; if (b._relative) { vx = br.x + b.x; vy = br.y + b.y; } else { vx = br.x; vy = br.y; } if (c._relative) { ux = cl.x + c.x; uy = cl.y + c.y; } else { ux = cl.x; uy = cl.y; } x = c.x; y = c.y; ctx.bezierCurveTo(vx, vy, ux, uy, x, y); } break; case Commands.line: ctx.lineTo(x, y); break; case Commands.move: d = b; ctx.moveTo(x, y); break; } } if (closed2) { ctx.closePath(); } if (!webgl.isHidden.test(fill)) { isOffset = fill._renderer && fill._renderer.offset; if (isOffset) { ctx.save(); ctx.translate( -fill._renderer.offset.x, -fill._renderer.offset.y ); ctx.scale(fill._renderer.scale.x, fill._renderer.scale.y); } ctx.fill(); if (isOffset) { ctx.restore(); } } if (!webgl.isHidden.test(stroke)) { isOffset = stroke._renderer && 
stroke._renderer.offset; if (isOffset) { ctx.save(); ctx.translate( -stroke._renderer.offset.x, -stroke._renderer.offset.y ); ctx.scale(stroke._renderer.scale.x, stroke._renderer.scale.y); ctx.lineWidth = linewidth / stroke._renderer.scale.x; } ctx.stroke(); if (isOffset) { ctx.restore(); } } ctx.restore(); }, getBoundingClientRect: function(vertices, border, rect) { let left = Infinity, right = -Infinity, top = Infinity, bottom = -Infinity, width, height; vertices.forEach(function(v) { const x = v.x, y = v.y, controls = v.controls; let a, b, c, d, cl, cr; top = Math.min(y, top); left = Math.min(x, left); right = Math.max(x, right); bottom = Math.max(y, bottom); if (!v.controls) { return; } cl = controls.left; cr = controls.right; if (!cl || !cr) { return; } a = v._relative ? cl.x + x : cl.x; b = v._relative ? cl.y + y : cl.y; c = v._relative ? cr.x + x : cr.x; d = v._relative ? cr.y + y : cr.y; if (!a || !b || !c || !d) { return; } top = Math.min(b, d, top); left = Math.min(a, c, left); right = Math.max(a, c, right); bottom = Math.max(b, d, bottom); }); if (typeof border === "number") { top -= border; left -= border; right += border; bottom += border; } width = right - left; height = bottom - top; rect.top = top; rect.left = left; rect.right = right; rect.bottom = bottom; rect.width = width; rect.height = height; if (!rect.centroid) { rect.centroid = {}; } rect.centroid.x = -left; rect.centroid.y = -top; }, render: function(gl, programs, forcedParent) { if (!this._visible || !this._opacity) { return this; } this._update(); const parent = forcedParent || this.parent; const program = programs[this._renderer.type]; const flagParentMatrix = parent._matrix.manual || parent._flagMatrix; const flagMatrix = this._matrix.manual || this._flagMatrix; const parentChanged = this._renderer.parent !== parent; const flagTexture = this._flagVertices || this._flagFill || this._fill instanceof LinearGradient && (this._fill._flagSpread || this._fill._flagStops || 
this._fill._flagEndPoints) || this._fill instanceof RadialGradient && (this._fill._flagSpread || this._fill._flagStops || this._fill._flagRadius || this._fill._flagCenter || this._fill._flagFocal) || this._fill instanceof Texture && (this._fill._flagLoaded && this._fill.loaded || this._fill._flagImage || this._fill._flagVideo || this._fill._flagRepeat || this._fill._flagOffset || this._fill._flagScale) || this._stroke instanceof LinearGradient && (this._stroke._flagSpread || this._stroke._flagStops || this._stroke._flagEndPoints) || this._stroke instanceof RadialGradient && (this._stroke._flagSpread || this._stroke._flagStops || this._stroke._flagRadius || this._stroke._flagCenter || this._stroke._flagFocal) || this._stroke instanceof Texture && (this._stroke._flagLoaded && this._stroke.loaded || this._stroke._flagImage || this._stroke._flagVideo || this._stroke._flagRepeat || this._stroke._flagOffset || this._fill._flagScale) || this._flagStroke || this._flagLinewidth || this._flagOpacity || parent._flagOpacity || this._flagVisible || this._flagCap || this._flagJoin || this._flagMiter || this._flagScale || this.dashes && this.dashes.length > 0 || !this._renderer.texture; if (flagParentMatrix || flagMatrix || parentChanged) { if (!this._renderer.matrix) { this._renderer.matrix = new NumArray(9); } this._matrix.toTransformArray(true, transformation); multiplyMatrix(transformation, parent._renderer.matrix, this._renderer.matrix); if (!(this._renderer.scale instanceof Vector)) { this._renderer.scale = new Vector(); } if (this._scale instanceof Vector) { this._renderer.scale.x = this._scale.x * parent._renderer.scale.x; this._renderer.scale.y = this._scale.y * parent._renderer.scale.y; } else { this._renderer.scale.x = this._scale * parent._renderer.scale.x; this._renderer.scale.y = this._scale * parent._renderer.scale.y; } if (parentChanged) { this._renderer.parent = parent; } } if (this._mask) { gl.clear(gl.STENCIL_BUFFER_BIT); gl.enable(gl.STENCIL_TEST); 
gl.stencilFunc(gl.ALWAYS, 1, 0); gl.stencilOp(gl.KEEP, gl.KEEP, gl.REPLACE); gl.colorMask(false, false, false, false); webgl[this._mask._renderer.type].render.call(this._mask, gl, programs, this); gl.stencilFunc(gl.EQUAL, 1, 255); gl.stencilOp(gl.KEEP, gl.KEEP, gl.KEEP); gl.colorMask(true, true, true, true); } if (flagTexture) { if (!this._renderer.rect) { this._renderer.rect = {}; } this._renderer.opacity = this._opacity * parent._renderer.opacity; webgl.path.getBoundingClientRect( this._renderer.vertices, this._linewidth, this._renderer.rect ); webgl.updateTexture.call(webgl, gl, this); } else { if (this._fill && this._fill._update) { this._fill._update(); } if (this._stroke && this._stroke._update) { this._stroke._update(); } } if (this._clip && !forcedParent || !this._renderer.texture) { return this; } if (programs.current !== program) { gl.useProgram(program); gl.bindBuffer(gl.ARRAY_BUFFER, programs.buffers.position); gl.vertexAttribPointer(program.position, 2, gl.FLOAT, false, 0, 0); gl.enableVertexAttribArray(program.position); gl.bufferData(gl.ARRAY_BUFFER, quad, gl.STATIC_DRAW); if (!programs.resolution.flagged) { gl.uniform2f( gl.getUniformLocation(program, "u_resolution"), programs.resolution.width, programs.resolution.height ); } programs.current = program; } if (programs.resolution.flagged) { gl.uniform2f( gl.getUniformLocation(program, "u_resolution"), programs.resolution.width, programs.resolution.height ); } gl.bindTexture(gl.TEXTURE_2D, this._renderer.texture); const rect = this._renderer.rect; gl.uniformMatrix3fv(program.matrix, false, this._renderer.matrix); gl.uniform4f(program.rect, rect.left, rect.top, rect.right, rect.bottom); gl.drawArrays(gl.TRIANGLES, 0, 6); if (this._mask) { gl.disable(gl.STENCIL_TEST); } return this.flagReset(); } }, points: { updateCanvas: function(elem) { let isOffset; const canvas3 = this.canvas; const ctx = this.ctx; const stroke = elem._stroke; const linewidth = elem._linewidth; const fill = elem._fill; const 
opacity = elem._renderer.opacity || elem._opacity; const dashes = elem.dashes; const size = elem._size; let dimension = size; if (!webgl.isHidden.test(stroke)) { dimension += linewidth; } canvas3.width = getPoT(dimension); canvas3.height = canvas3.width; const aspect = dimension / canvas3.width; const cx = canvas3.width / 2; const cy = canvas3.height / 2; ctx.clearRect(0, 0, canvas3.width, canvas3.height); if (fill) { if (typeof fill === "string") { ctx.fillStyle = fill; } else { webgl[fill._renderer.type].render.call(fill, ctx, elem); ctx.fillStyle = fill._renderer.effect; } } if (stroke) { if (typeof stroke === "string") { ctx.strokeStyle = stroke; } else { webgl[stroke._renderer.type].render.call(stroke, ctx, elem); ctx.strokeStyle = stroke._renderer.effect; } if (linewidth) { ctx.lineWidth = linewidth / aspect; } } if (typeof opacity === "number") { ctx.globalAlpha = opacity; } if (dashes && dashes.length > 0) { ctx.lineDashOffset = dashes.offset || 0; ctx.setLineDash(dashes); } ctx.save(); ctx.translate(cx, cy); ctx.scale(webgl.precision, webgl.precision); ctx.beginPath(); ctx.arc(0, 0, size / aspect * 0.5, 0, TWO_PI); ctx.restore(); if (closed) { ctx.closePath(); } if (!webgl.isHidden.test(fill)) { isOffset = fill._renderer && fill._renderer.offset; if (isOffset) { ctx.save(); ctx.translate( -fill._renderer.offset.x, -fill._renderer.offset.y ); ctx.scale(fill._renderer.scale.x, fill._renderer.scale.y); } ctx.fill(); if (isOffset) { ctx.restore(); } } if (!webgl.isHidden.test(stroke)) { isOffset = stroke._renderer && stroke._renderer.offset; if (isOffset) { ctx.save(); ctx.translate( -stroke._renderer.offset.x, -stroke._renderer.offset.y ); ctx.scale(stroke._renderer.scale.x, stroke._renderer.scale.y); ctx.lineWidth = linewidth / stroke._renderer.scale.x; } ctx.stroke(); if (isOffset) { ctx.restore(); } } }, render: function(gl, programs, forcedParent) { if (!this._visible || !this._opacity) { return this; } this._update(); let size = this._size; const parent 
= forcedParent || this.parent; const program = programs[this._renderer.type]; const sizeAttenuation = this._sizeAttenuation; const stroke = this._stroke; const linewidth = this._linewidth; const flagParentMatrix = parent._matrix.manual || parent._flagMatrix; const flagMatrix = this._matrix.manual || this._flagMatrix; const parentChanged = this._renderer.parent !== parent; const commands = this._renderer.vertices; const length = this._renderer.collection.length; const flagVertices = this._flagVertices; const flagTexture = this._flagFill || this._fill instanceof LinearGradient && (this._fill._flagSpread || this._fill._flagStops || this._fill._flagEndPoints) || this._fill instanceof RadialGradient && (this._fill._flagSpread || this._fill._flagStops || this._fill._flagRadius || this._fill._flagCenter || this._fill._flagFocal) || this._fill instanceof Texture && (this._fill._flagLoaded && this._fill.loaded || this._fill._flagImage || this._fill._flagVideo || this._fill._flagRepeat || this._fill._flagOffset || this._fill._flagScale) || this._stroke instanceof LinearGradient && (this._stroke._flagSpread || this._stroke._flagStops || this._stroke._flagEndPoints) || this._stroke instanceof RadialGradient && (this._stroke._flagSpread || this._stroke._flagStops || this._stroke._flagRadius || this._stroke._flagCenter || this._stroke._flagFocal) || this._stroke instanceof Texture && (this._stroke._flagLoaded && this._stroke.loaded || this._stroke._flagImage || this._stroke._flagVideo || this._stroke._flagRepeat || this._stroke._flagOffset || this._fill._flagScale) || this._flagStroke || this._flagLinewidth || this._flagOpacity || parent._flagOpacity || this._flagVisible || this._flagScale || this.dashes && this.dashes.length > 0 || !this._renderer.texture; if (flagParentMatrix || flagMatrix || parentChanged) { if (!this._renderer.matrix) { this._renderer.matrix = new NumArray(9); } this._matrix.toTransformArray(true, transformation); multiplyMatrix(transformation, 
parent._renderer.matrix, this._renderer.matrix); if (!(this._renderer.scale instanceof Vector)) { this._renderer.scale = new Vector(); } if (this._scale instanceof Vector) { this._renderer.scale.x = this._scale.x * parent._renderer.scale.x; this._renderer.scale.y = this._scale.y * parent._renderer.scale.y; } else { this._renderer.scale.x = this._scale * parent._renderer.scale.x; this._renderer.scale.y = this._scale * parent._renderer.scale.y; } if (parentChanged) { this._renderer.parent = parent; } } if (flagVertices) { const positionBuffer = this._renderer.positionBuffer; if (positionBuffer) { gl.deleteBuffer(positionBuffer); } this._renderer.positionBuffer = gl.createBuffer(); gl.bindBuffer(gl.ARRAY_BUFFER, this._renderer.positionBuffer); gl.vertexAttribPointer(program.position, 2, gl.FLOAT, false, 0, 0); gl.enableVertexAttribArray(program.position); gl.bufferData(gl.ARRAY_BUFFER, commands, gl.STATIC_DRAW); } if (flagTexture) { this._renderer.opacity = this._opacity * parent._renderer.opacity; webgl.updateTexture.call(webgl, gl, this); } else { if (this._fill && this._fill._update) { this._fill._update(); } if (this._stroke && this._stroke._update) { this._stroke._update(); } } if (this._clip && !forcedParent || !this._renderer.texture) { return this; } if (!webgl.isHidden.test(stroke)) { size += linewidth; } size /= webgl.precision; if (sizeAttenuation) { size *= Math.max(this._renderer.scale.x, this._renderer.scale.y); } if (programs.current !== program) { gl.useProgram(program); if (!programs.resolution.flagged) { gl.uniform2f( gl.getUniformLocation(program, "u_resolution"), programs.resolution.width, programs.resolution.height ); } programs.current = program; } if (programs.resolution.flagged) { gl.uniform2f( gl.getUniformLocation(program, "u_resolution"), programs.resolution.width, programs.resolution.height ); } gl.bindTexture(gl.TEXTURE_2D, this._renderer.texture); gl.uniformMatrix3fv(program.matrix, false, this._renderer.matrix); 
gl.uniform1f(program.size, size * programs.resolution.ratio); gl.drawArrays(gl.POINTS, 0, length); return this.flagReset(); } }, text: { updateCanvas: function(elem) { const canvas3 = this.canvas; const ctx = this.ctx; const scale = elem._renderer.scale; const stroke = elem._stroke; const linewidth = elem._linewidth * scale; const fill = elem._fill; const opacity = elem._renderer.opacity || elem._opacity; const dashes = elem.dashes; const decoration = elem._decoration; canvas3.width = Math.max(Math.ceil(elem._renderer.rect.width * scale.x), 1); canvas3.height = Math.max(Math.ceil(elem._renderer.rect.height * scale.y), 1); const centroid = elem._renderer.rect.centroid; const cx = centroid.x; const cy = centroid.y; let a, b, c, d, e, sx, sy, x1, y1, x2, y2; const isOffset = fill._renderer && fill._renderer.offset && stroke._renderer && stroke._renderer.offset; ctx.clearRect(0, 0, canvas3.width, canvas3.height); if (!isOffset) { ctx.font = [elem._style, elem._weight, elem._size + "px/" + elem._leading + "px", elem._family].join(" "); } ctx.textAlign = "center"; ctx.textBaseline = "middle"; if (fill) { if (typeof fill === "string") { ctx.fillStyle = fill; } else { webgl[fill._renderer.type].render.call(fill, ctx, elem); ctx.fillStyle = fill._renderer.effect; } } if (stroke) { if (typeof stroke === "string") { ctx.strokeStyle = stroke; } else { webgl[stroke._renderer.type].render.call(stroke, ctx, elem); ctx.strokeStyle = stroke._renderer.effect; } if (linewidth) { ctx.lineWidth = linewidth; } } if (typeof opacity === "number") { ctx.globalAlpha = opacity; } if (dashes && dashes.length > 0) { ctx.lineDashOffset = dashes.offset || 0; ctx.setLineDash(dashes); } ctx.save(); ctx.scale(scale.x, scale.y); ctx.translate(cx, cy); if (!webgl.isHidden.test(fill)) { if (fill._renderer && fill._renderer.offset) { sx = fill._renderer.scale.x; sy = fill._renderer.scale.y; ctx.save(); ctx.translate( -fill._renderer.offset.x, -fill._renderer.offset.y ); ctx.scale(sx, sy); a = 
elem._size / fill._renderer.scale.y; b = elem._leading / fill._renderer.scale.y; ctx.font = [ elem._style, elem._weight, a + "px/", b + "px", elem._family ].join(" "); c = fill._renderer.offset.x / fill._renderer.scale.x; d = fill._renderer.offset.y / fill._renderer.scale.y; ctx.fillText(elem.value, c, d); ctx.restore(); } else { ctx.fillText(elem.value, 0, 0); } } if (!webgl.isHidden.test(stroke)) { if (stroke._renderer && stroke._renderer.offset) { sx = stroke._renderer.scale.x; sy = stroke._renderer.scale.y; ctx.save(); ctx.translate( -stroke._renderer.offset.x, -stroke._renderer.offset.y ); ctx.scale(sx, sy); a = elem._size / stroke._renderer.scale.y; b = elem._leading / stroke._renderer.scale.y; ctx.font = [ elem._style, elem._weight, a + "px/", b + "px", elem._family ].join(" "); c = stroke._renderer.offset.x / stroke._renderer.scale.x; d = stroke._renderer.offset.y / stroke._renderer.scale.y; e = linewidth / stroke._renderer.scale.x; ctx.lineWidth = e; ctx.strokeText(elem.value, c, d); ctx.restore(); } else { ctx.strokeText(elem.value, 0, 0); } } if (/(underline|strikethrough)/i.test(decoration)) { const metrics = ctx.measureText(elem.value); switch (decoration) { case "underline": y1 = metrics.actualBoundingBoxAscent; y2 = metrics.actualBoundingBoxAscent; break; case "strikethrough": y1 = 0; y2 = 0; break; } x1 = -metrics.width / 2; x2 = metrics.width / 2; ctx.lineWidth = Math.max(Math.floor(elem._size / 15), 1); ctx.strokeStyle = ctx.fillStyle; ctx.beginPath(); ctx.moveTo(x1, y1); ctx.lineTo(x2, y2); ctx.stroke(); } ctx.restore(); }, getBoundingClientRect: function(elem, rect) { const ctx = webgl.ctx; ctx.font = [elem._style, elem._weight, elem._size + "px/" + elem._leading + "px", elem._family].join(" "); ctx.textAlign = "center"; ctx.textBaseline = elem._baseline; let width = ctx.measureText(elem._value).width * 1.25; let height = Math.max(elem._size, elem._leading) * 1.25; if (this._linewidth && !webgl.isHidden.test(this._stroke)) { width += 
this._linewidth * 2; height += this._linewidth * 2; } const w = width / 2; const h = height / 2; switch (webgl.alignments[elem._alignment] || elem._alignment) { case webgl.alignments.left: rect.left = 0; rect.right = width; break; case webgl.alignments.right: rect.left = -width; rect.right = 0; break; default: rect.left = -w; rect.right = w; } switch (elem._baseline) { case "bottom": rect.top = -height; rect.bottom = 0; break; case "top": rect.top = 0; rect.bottom = height; break; default: rect.top = -h; rect.bottom = h; } rect.width = width; rect.height = height; if (!rect.centroid) { rect.centroid = {}; } rect.centroid.x = w; rect.centroid.y = h; }, render: function(gl, programs, forcedParent) { if (!this._visible || !this._opacity) { return this; } this._update(); const parent = forcedParent || this.parent; const program = programs[this._renderer.type]; const flagParentMatrix = parent._matrix.manual || parent._flagMatrix; const flagMatrix = this._matrix.manual || this._flagMatrix; const parentChanged = this._renderer.parent !== parent; const flagTexture = this._flagVertices || this._flagFill || this._fill instanceof LinearGradient && (this._fill._flagSpread || this._fill._flagStops || this._fill._flagEndPoints) || this._fill instanceof RadialGradient && (this._fill._flagSpread || this._fill._flagStops || this._fill._flagRadius || this._fill._flagCenter || this._fill._flagFocal) || this._fill instanceof Texture && (this._fill._flagLoaded && this._fill.loaded || this._fill._flagImage || this._fill._flagVideo || this._fill._flagRepeat || this._fill._flagOffset || this._fill._flagScale) || this._stroke instanceof LinearGradient && (this._stroke._flagSpread || this._stroke._flagStops || this._stroke._flagEndPoints) || this._stroke instanceof RadialGradient && (this._stroke._flagSpread || this._stroke._flagStops || this._stroke._flagRadius || this._stroke._flagCenter || this._stroke._flagFocal) || this._stroke instanceof Texture && (this._stroke._flagLoaded && 
this._stroke.loaded || this._stroke._flagImage || this._stroke._flagVideo || this._stroke._flagRepeat || this._stroke._flagOffset || this._fill._flagScale) || this._flagStroke || this._flagLinewidth || this._flagOpacity || parent._flagOpacity || this._flagVisible || this._flagScale || this._flagValue || this._flagFamily || this._flagSize || this._flagLeading || this._flagAlignment || this._flagBaseline || this._flagStyle || this._flagWeight || this._flagDecoration || this.dashes && this.dashes.length > 0 || !this._renderer.texture; if (flagParentMatrix || flagMatrix || parentChanged) { if (!this._renderer.matrix) { this._renderer.matrix = new NumArray(9); } this._matrix.toTransformArray(true, transformation); multiplyMatrix(transformation, parent._renderer.matrix, this._renderer.matrix); if (!(this._renderer.scale instanceof Vector)) { this._renderer.scale = new Vector(); } if (this._scale instanceof Vector) { this._renderer.scale.x = this._scale.x * parent._renderer.scale.x; this._renderer.scale.y = this._scale.y * parent._renderer.scale.y; } else { this._renderer.scale.x = this._scale * parent._renderer.scale.x; this._renderer.scale.y = this._scale * parent._renderer.scale.y; } if (parentChanged) { this._renderer.parent = parent; } } if (this._mask) { gl.clear(gl.STENCIL_BUFFER_BIT); gl.enable(gl.STENCIL_TEST); gl.stencilFunc(gl.ALWAYS, 1, 0); gl.stencilOp(gl.KEEP, gl.KEEP, gl.REPLACE); gl.colorMask(false, false, false, false); webgl[this._mask._renderer.type].render.call(this._mask, gl, programs, this); gl.stencilFunc(gl.EQUAL, 1, 255); gl.stencilOp(gl.KEEP, gl.KEEP, gl.KEEP); gl.colorMask(true, true, true, true); } if (flagTexture) { if (!this._renderer.rect) { this._renderer.rect = {}; } this._renderer.opacity = this._opacity * parent._renderer.opacity; webgl.text.getBoundingClientRect(this, this._renderer.rect); webgl.updateTexture.call(webgl, gl, this); } else { if (this._fill && this._fill._update) { this._fill._update(); } if (this._stroke && 
this._stroke._update) { this._stroke._update(); } } if (this._clip && !forcedParent || !this._renderer.texture) { return this; } if (programs.current !== program) { gl.useProgram(program); gl.bindBuffer(gl.ARRAY_BUFFER, programs.buffers.position); gl.vertexAttribPointer(program.position, 2, gl.FLOAT, false, 0, 0); gl.enableVertexAttribArray(program.position); gl.bufferData(gl.ARRAY_BUFFER, quad, gl.STATIC_DRAW); if (!programs.resolution.flagged) { gl.uniform2f( gl.getUniformLocation(program, "u_resolution"), programs.resolution.width, programs.resolution.height ); } programs.current = program; } if (programs.resolution.flagged) { gl.uniform2f( gl.getUniformLocation(program, "u_resolution"), programs.resolution.width, programs.resolution.height ); } gl.bindTexture(gl.TEXTURE_2D, this._renderer.texture); const rect = this._renderer.rect; gl.uniformMatrix3fv(program.matrix, false, this._renderer.matrix); gl.uniform4f(program.rect, rect.left, rect.top, rect.right, rect.bottom); gl.drawArrays(gl.TRIANGLES, 0, 6); if (this._mask) { gl.disable(gl.STENCIL_TEST); } return this.flagReset(); } }, "linear-gradient": { render: function(ctx, parent) { if (!ctx.canvas.getContext("2d") || !parent) { return; } this._update(); if (!this._renderer.effect || this._flagEndPoints || this._flagStops || this._flagUnits) { let rect; let lx = this.left._x; let ly = this.left._y; let rx = this.right._x; let ry = this.right._y; if (/objectBoundingBox/i.test(this._units)) { rect = parent.getBoundingClientRect(true); lx = (lx - 0.5) * rect.width; ly = (ly - 0.5) * rect.height; rx = (rx - 0.5) * rect.width; ry = (ry - 0.5) * rect.height; } this._renderer.effect = ctx.createLinearGradient(lx, ly, rx, ry); for (let i = 0; i < this.stops.length; i++) { const stop = this.stops[i]; this._renderer.effect.addColorStop(stop._offset, stop._color); } } return this.flagReset(); } }, "radial-gradient": { render: function(ctx, parent) { if (!ctx.canvas.getContext("2d") || !parent) { return; } this._update(); 
if (!this._renderer.effect || this._flagCenter || this._flagFocal || this._flagRadius || this._flagStops || this._flagUnits) { let rect; let cx = this.center._x; let cy = this.center._y; let fx = this.focal._x; let fy = this.focal._y; let radius = this._radius; if (/objectBoundingBox/i.test(this._units)) { rect = parent.getBoundingClientRect(true); cx = cx * rect.width * 0.5; cy = cy * rect.height * 0.5; fx = fx * rect.width * 0.5; fy = fy * rect.height * 0.5; radius *= Math.min(rect.width, rect.height) * 0.5; } this._renderer.effect = ctx.createRadialGradient( cx, cy, 0, fx, fy, radius ); for (let i = 0; i < this.stops.length; i++) { const stop = this.stops[i]; this._renderer.effect.addColorStop(stop._offset, stop._color); } } return this.flagReset(); } }, texture: { render: function(ctx, elem) { if (!ctx.canvas.getContext("2d")) { return; } this._update(); const image = this.image; if ((this._flagLoaded || this._flagImage || this._flagVideo || this._flagRepeat) && this.loaded) { this._renderer.effect = ctx.createPattern(image, this._repeat); } else if (!this._renderer.effect) { return this.flagReset(); } if (this._flagOffset || this._flagLoaded || this._flagScale) { if (!(this._renderer.offset instanceof Vector)) { this._renderer.offset = new Vector(); } this._renderer.offset.x = -this._offset.x; this._renderer.offset.y = -this._offset.y; if (image) { this._renderer.offset.x += image.width / 2; this._renderer.offset.y += image.height / 2; if (this._scale instanceof Vector) { this._renderer.offset.x *= this._scale.x; this._renderer.offset.y *= this._scale.y; } else { this._renderer.offset.x *= this._scale; this._renderer.offset.y *= this._scale; } } } if (this._flagScale || this._flagLoaded) { if (!(this._renderer.scale instanceof Vector)) { this._renderer.scale = new Vector(); } if (this._scale instanceof Vector) { this._renderer.scale.copy(this._scale); } else { this._renderer.scale.set(this._scale, this._scale); } } return this.flagReset(); } }, updateTexture: 
function(gl, elem) { this[elem._renderer.type].updateCanvas.call(webgl, elem); if (this.canvas.width <= 0 || this.canvas.height <= 0) { if (elem._renderer.texture) { gl.deleteTexture(elem._renderer.texture); } delete elem._renderer.texture; return; } if (!elem._renderer.texture) { elem._renderer.texture = gl.createTexture(); } gl.bindTexture(gl.TEXTURE_2D, elem._renderer.texture); gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.LINEAR); gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, this.canvas); }, program: { create: function(gl, shaders2) { let program, linked, error; program = gl.createProgram(); _.each(shaders2, function(s) { gl.attachShader(program, s); }); gl.linkProgram(program); linked = gl.getProgramParameter(program, gl.LINK_STATUS); if (!linked) { error = gl.getProgramInfoLog(program); gl.deleteProgram(program); throw new TwoError("unable to link program: " + error); } return program; } }, TextureRegistry: new Registry() }; webgl.ctx = webgl.canvas.getContext("2d"); var Renderer3 = class extends Events { constructor(params) { super(); let gl, program, vs, fs; this.domElement = params.domElement || document.createElement("canvas"); if (typeof params.offscreenElement !== "undefined") { webgl.canvas = params.offscreenElement; webgl.ctx = webgl.canvas.getContext("2d"); } this.scene = new Group(); this.scene.parent = this; this._renderer = { type: "renderer", matrix: new NumArray(identity), scale: 1, opacity: 1 }; this._flagMatrix = true; params = _.defaults(params || {}, { antialias: false, alpha: true, premultipliedAlpha: true, stencil: true, preserveDrawingBuffer: true, overdraw: false }); this.overdraw = params.overdraw; gl = this.ctx = this.domElement.getContext("webgl", params) || this.domElement.getContext("experimental-webgl", params); if (!this.ctx) { throw new TwoError( "unable to 
create a webgl context. Try using another renderer." ); } vs = shaders.create(gl, shaders.path.vertex, shaders.types.vertex); fs = shaders.create(gl, shaders.path.fragment, shaders.types.fragment); this.programs = { current: null, buffers: { position: gl.createBuffer() }, resolution: { width: 0, height: 0, ratio: 1, flagged: false } }; program = this.programs.path = webgl.program.create(gl, [vs, fs]); this.programs.text = this.programs.path; program.position = gl.getAttribLocation(program, "a_position"); program.matrix = gl.getUniformLocation(program, "u_matrix"); program.rect = gl.getUniformLocation(program, "u_rect"); const positionBuffer = gl.createBuffer(); gl.bindBuffer(gl.ARRAY_BUFFER, positionBuffer); gl.vertexAttribPointer(program.position, 2, gl.FLOAT, false, 0, 0); gl.enableVertexAttribArray(program.position); gl.bufferData(gl.ARRAY_BUFFER, quad, gl.STATIC_DRAW); vs = shaders.create(gl, shaders.points.vertex, shaders.types.vertex); fs = shaders.create(gl, shaders.points.fragment, shaders.types.fragment); program = this.programs.points = webgl.program.create(gl, [vs, fs]); program.position = gl.getAttribLocation(program, "a_position"); program.matrix = gl.getUniformLocation(program, "u_matrix"); program.size = gl.getUniformLocation(program, "u_size"); gl.enable(gl.BLEND); gl.pixelStorei(gl.UNPACK_PREMULTIPLY_ALPHA_WEBGL, true); gl.blendEquation(gl.FUNC_ADD); gl.blendFunc(gl.ONE, gl.ONE_MINUS_SRC_ALPHA); } setSize(width, height, ratio) { let w, h; const ctx = this.ctx; this.width = width; this.height = height; this.ratio = typeof ratio === "undefined" ? 
getRatio(ctx) : ratio; this.domElement.width = width * this.ratio; this.domElement.height = height * this.ratio; if (_.isObject(this.domElement.style)) { _.extend(this.domElement.style, { width: width + "px", height: height + "px" }); } this._renderer.matrix[0] = this._renderer.matrix[4] = this._renderer.scale = this.ratio; this._flagMatrix = true; w = width * this.ratio; h = height * this.ratio; ctx.viewport(0, 0, w, h); this.programs.resolution.width = w; this.programs.resolution.height = h; this.programs.resolution.ratio = this.ratio; this.programs.resolution.flagged = true; return this.trigger(Events.Types.resize, width, height, ratio); } render() { const gl = this.ctx; if (!this.overdraw) { gl.clear(gl.COLOR_BUFFER_BIT); } webgl.group.render.call(this.scene, gl, this.programs); this._flagMatrix = false; this.programs.resolution.flagged = true; return this; } }; __publicField(Renderer3, "Utils", webgl); // src/two.js var Utils = _.extend({ Error: TwoError, getRatio, read, xhr }, _, CanvasShim, curves_exports, math_exports); var _Two = class { _events = new Events(); get _bound() { return this._events._bound; } set _bound(v) { this._events._bound = v; } addEventListener() { return this._events.addEventListener.apply(this, arguments); } on() { return this._events.addEventListener.apply(this, arguments); } bind() { return this._events.addEventListener.apply(this, arguments); } removeEventListener() { return this._events.removeEventListener.apply(this, arguments); } off() { return this._events.removeEventListener.apply(this, arguments); } unbind() { return this._events.removeEventListener.apply(this, arguments); } dispatchEvent() { return this._events.dispatchEvent.apply(this, arguments); } trigger() { return this._events.dispatchEvent.apply(this, arguments); } listen() { return this._events.listen.apply(this, arguments); } ignore() { return this._events.ignore.apply(this, arguments); } type = ""; renderer = null; scene = null; width = 0; height = 0; frameCount = 
0; timeDelta = 0; playing = false; constructor(options) { const params = _.defaults(options || {}, { fullscreen: false, fitted: false, width: 640, height: 480, type: _Two.Types.svg, autostart: false }); _.each(params, function(v, k) { if (/fullscreen/i.test(k) || /autostart/i.test(k)) { return; } this[k] = v; }, this); if (_.isElement(params.domElement)) { const tagName = params.domElement.tagName.toLowerCase(); if (!/^(CanvasRenderer-canvas|WebGLRenderer-canvas|SVGRenderer-svg)$/.test(this.type + "-" + tagName)) { this.type = _Two.Types[tagName]; } } this.renderer = new _Two[this.type](this); this.setPlaying(params.autostart); this.frameCount = 0; if (params.fullscreen) { this.fit = fitToWindow.bind(this); this.fit.domElement = window; this.fit.attached = true; _.extend(document.body.style, { overflow: "hidden", margin: 0, padding: 0, top: 0, left: 0, right: 0, bottom: 0, position: "fixed" }); _.extend(this.renderer.domElement.style, { display: "block", top: 0, left: 0, right: 0, bottom: 0, position: "fixed" }); dom.bind(this.fit.domElement, "resize", this.fit); this.fit(); } else if (params.fitted) { this.fit = fitToParent.bind(this); _.extend(this.renderer.domElement.style, { display: "block" }); } else if (!_.isElement(params.domElement)) { this.renderer.setSize(params.width, params.height, this.ratio); this.width = params.width; this.height = params.height; } this.renderer.bind(Events.Types.resize, updateDimensions.bind(this)); this.scene = this.renderer.scene; _Two.Instances.push(this); if (params.autostart) { raf.init(); } } appendTo(elem) { elem.appendChild(this.renderer.domElement); if (this.fit) { if (this.fit.domElement !== window) { this.fit.domElement = elem; this.fit.attached = false; } this.update(); } return this; } play() { this.playing = true; raf.init(); return this.trigger(Events.Types.play); } pause() { this.playing = false; return this.trigger(Events.Types.pause); } setPlaying(p) { this.playing = p; } release(obj) { let i, v, child; if 
(!_.isObject(obj)) { return this.release(this.scene); } if (typeof obj.unbind === "function") { obj.unbind(); } if (obj.vertices) { if (typeof obj.vertices.unbind === "function") { obj.vertices.unbind(); } for (i = 0; i < obj.vertices.length; i++) { v = obj.vertices[i]; if (typeof v.unbind === "function") { v.unbind(); } if (v.controls) { if (v.controls.left && typeof v.controls.left.unbind === "function") { v.controls.left.unbind(); } if (v.controls.right && typeof v.controls.right.unbind === "function") { v.controls.right.unbind(); } } } } if (obj.children) { for (i = 0; i < obj.children.length; i++) { child = obj.children[i]; this.release(child); } if (typeof obj.children.unbind === "function") { obj.children.unbind(); } } return obj; } update() { const animated = !!this._lastFrame; const now = _.performance.now(); if (animated) { this.timeDelta = parseFloat((now - this._lastFrame).toFixed(3)); } this._lastFrame = now; if (this.fit && this.fit.domElement && !this.fit.attached) { dom.bind(this.fit.domElement, "resize", this.fit); this.fit.attached = true; this.fit(); } const width = this.width; const height = this.height; const renderer = this.renderer; if (width !== renderer.width || height !== renderer.height) { renderer.setSize(width, height, this.ratio); } this.trigger(Events.Types.update, this.frameCount, this.timeDelta); return this.render(); } render() { this.renderer.render(); return this.trigger(Events.Types.render, this.frameCount++); } add(objects) { if (!(objects instanceof Array)) { objects = Array.prototype.slice.call(arguments); } this.scene.add(objects); return this; } remove(objects) { if (!(objects instanceof Array)) { objects = Array.prototype.slice.call(arguments); } this.scene.remove(objects); return this; } clear() { this.scene.remove(this.scene.children); return this; } makeLine(x1, y1, x2, y2) { const line = new Line(x1, y1, x2, y2); this.scene.add(line); return line; } makeArrow(x1, y1, x2, y2, size) { const headlen = typeof size === 
"number" ? size : 10; const angle = Math.atan2(y2 - y1, x2 - x1); const vertices = [ new Anchor(x1, y1, void 0, void 0, void 0, void 0, Commands.move), new Anchor(x2, y2, void 0, void 0, void 0, void 0, Commands.line), new Anchor( x2 - headlen * Math.cos(angle - Math.PI / 4), y2 - headlen * Math.sin(angle - Math.PI / 4), void 0, void 0, void 0, void 0, Commands.line ), new Anchor(x2, y2, void 0, void 0, void 0, void 0, Commands.move), new Anchor( x2 - headlen * Math.cos(angle + Math.PI / 4), y2 - headlen * Math.sin(angle + Math.PI / 4), void 0, void 0, void 0, void 0, Commands.line ) ]; const path = new Path(vertices, false, false, true); path.noFill(); path.cap = "round"; path.join = "round"; this.scene.add(path); return path; } makeRectangle(x, y, width, height) { const rect = new Rectangle(x, y, width, height); this.scene.add(rect); return rect; } makeRoundedRectangle(x, y, width, height, sides) { const rect = new RoundedRectangle(x, y, width, height, sides); this.scene.add(rect); return rect; } makeCircle(x, y, radius, resolution) { const circle = new Circle(x, y, radius, resolution); this.scene.add(circle); return circle; } makeEllipse(x, y, rx, ry, resolution) { const ellipse = new Ellipse(x, y, rx, ry, resolution); this.scene.add(ellipse); return ellipse; } makeStar(x, y, outerRadius, innerRadius, sides) { const star = new Star(x, y, outerRadius, innerRadius, sides); this.scene.add(star); return star; } makeCurve(points) { const l = arguments.length; if (!Array.isArray(points)) { points = []; for (let i = 0; i < l; i += 2) { const x = arguments[i]; if (typeof x !== "number") { break; } const y = arguments[i + 1]; points.push(new Anchor(x, y)); } } const last = arguments[l - 1]; const curve = new Path(points, !(typeof last === "boolean" ? 
last : void 0), true); const rect = curve.getBoundingClientRect(); curve.center().translation.set(rect.left + rect.width / 2, rect.top + rect.height / 2); this.scene.add(curve); return curve; } makePolygon(x, y, radius, sides) { const poly = new Polygon(x, y, radius, sides); this.scene.add(poly); return poly; } makeArcSegment(x, y, innerRadius, outerRadius, startAngle, endAngle, resolution) { const arcSegment = new ArcSegment( x, y, innerRadius, outerRadius, startAngle, endAngle, resolution ); this.scene.add(arcSegment); return arcSegment; } makePoints(p) { const l = arguments.length; let vertices = p; if (!Array.isArray(p)) { vertices = []; for (let i = 0; i < l; i += 2) { const x = arguments[i]; if (typeof x !== "number") { break; } const y = arguments[i + 1]; vertices.push(new Vector(x, y)); } } const points = new Points(vertices); this.scene.add(points); return points; } makePath(p) { const l = arguments.length; let points = p; if (!Array.isArray(p)) { points = []; for (let i = 0; i < l; i += 2) { const x = arguments[i]; if (typeof x !== "number") { break; } const y = arguments[i + 1]; points.push(new Anchor(x, y)); } } const last = arguments[l - 1]; const path = new Path(points, !(typeof last === "boolean" ? 
last : void 0)); const rect = path.getBoundingClientRect(); if (typeof rect.top === "number" && typeof rect.left === "number" && typeof rect.right === "number" && typeof rect.bottom === "number") { path.center().translation.set(rect.left + rect.width / 2, rect.top + rect.height / 2); } this.scene.add(path); return path; } makeText(message, x, y, styles) { const text = new Text(message, x, y, styles); this.add(text); return text; } makeLinearGradient(x1, y1, x2, y2) { const stops = Array.prototype.slice.call(arguments, 4); const gradient = new LinearGradient(x1, y1, x2, y2, stops); this.add(gradient); return gradient; } makeRadialGradient(x1, y1, radius) { const stops = Array.prototype.slice.call(arguments, 3); const gradient = new RadialGradient(x1, y1, radius, stops); this.add(gradient); return gradient; } makeSprite(pathOrTexture, x, y, columns, rows, frameRate, autostart) { const sprite = new Sprite(pathOrTexture, x, y, columns, rows, frameRate); if (autostart) { sprite.play(); } this.add(sprite); return sprite; } makeImageSequence(pathsOrTextures, x, y, frameRate, autostart) { const imageSequence = new ImageSequence(pathsOrTextures, x, y, frameRate); if (autostart) { imageSequence.play(); } this.add(imageSequence); return imageSequence; } makeTexture(pathOrSource, callback) { const texture = new Texture(pathOrSource, callback); return texture; } makeGroup(objects) { if (!(objects instanceof Array)) { objects = Array.prototype.slice.call(arguments); } const group = new Group(); this.scene.add(group); group.add(objects); return group; } interpret(svg2, shallow, add) { const tag = svg2.tagName.toLowerCase(); add = typeof add !== "undefined" ? add : true; if (!(tag in read)) { return null; } const node = read[tag].call(this, svg2); if (add) { this.add(shallow && node instanceof Group ? 
node.children : node); } else if (node.parent) { node.remove(); } return node; } load(pathOrSVGContent, callback) { const group = new Group(); let elem, i, child; const attach = function(data) { dom.temp.innerHTML = data; for (i = 0; i < dom.temp.children.length; i++) { elem = dom.temp.children[i]; child = this.interpret(elem, false, false); if (child !== null) { group.add(child); } } if (typeof callback === "function") { const svg2 = dom.temp.children.length <= 1 ? dom.temp.children[0] : dom.temp.children; callback(group, svg2); } }.bind(this); if (/\.svg$/i.test(pathOrSVGContent)) { xhr(pathOrSVGContent, attach); return group; } attach(pathOrSVGContent); return group; } }; var Two = _Two; __publicField(Two, "nextFrameID", Constants.nextFrameID); __publicField(Two, "Types", Constants.Types); __publicField(Two, "Version", Constants.Version); __publicField(Two, "PublishDate", Constants.PublishDate); __publicField(Two, "Identifier", Constants.Identifier); __publicField(Two, "Resolution", Constants.Resolution); __publicField(Two, "AutoCalculateImportedMatrices", Constants.AutoCalculateImportedMatrices); __publicField(Two, "Instances", Constants.Instances); __publicField(Two, "uniqueId", Constants.uniqueId); __publicField(Two, "Anchor", Anchor); __publicField(Two, "Collection", Collection); __publicField(Two, "Events", Events); __publicField(Two, "Group", Group); __publicField(Two, "Matrix", Matrix2); __publicField(Two, "Path", Path); __publicField(Two, "Registry", Registry); __publicField(Two, "Shape", Shape); __publicField(Two, "Text", Text); __publicField(Two, "Vector", Vector); __publicField(Two, "Gradient", Gradient); __publicField(Two, "ImageSequence", ImageSequence); __publicField(Two, "LinearGradient", LinearGradient); __publicField(Two, "RadialGradient", RadialGradient); __publicField(Two, "Sprite", Sprite); __publicField(Two, "Stop", Stop); __publicField(Two, "Texture", Texture); __publicField(Two, "ArcSegment", ArcSegment); __publicField(Two, "Circle", 
Circle); __publicField(Two, "Ellipse", Ellipse); __publicField(Two, "Line", Line); __publicField(Two, "Points", Points); __publicField(Two, "Polygon", Polygon); __publicField(Two, "Rectangle", Rectangle); __publicField(Two, "RoundedRectangle", RoundedRectangle); __publicField(Two, "Star", Star); __publicField(Two, "CanvasRenderer", Renderer); __publicField(Two, "SVGRenderer", Renderer2); __publicField(Two, "WebGLRenderer", Renderer3); __publicField(Two, "Commands", Commands); __publicField(Two, "Utils", Utils); function fitToWindow() { const wr = document.body.getBoundingClientRect(); const width = this.width = wr.width; const height = this.height = wr.height; this.renderer.setSize(width, height, this.ratio); } function fitToParent() { const parent = this.renderer.domElement.parentElement; if (!parent) { console.warn("Two.js: Attempting to fit to parent, but no parent found."); return; } const wr = parent.getBoundingClientRect(); const width = this.width = wr.width; const height = this.height = wr.height; this.renderer.setSize(width, height, this.ratio); } function updateDimensions(width, height) { this.width = width; this.height = height; this.trigger(Events.Types.resize, width, height); } var raf = dom.getRequestAnimationFrame(); function loop() { for (let i = 0; i < Two.Instances.length; i++) { const t = Two.Instances[i]; if (t.playing) { t.update(); } } Two.nextFrameID = raf(loop); } raf.init = function() { loop(); raf.init = function() { }; }; return __toCommonJS(two_exports); })().default; (function(){if(typeof exports==='object'&&typeof module!=='undefined'){module.exports=Two}})() ================================================ FILE: models/layers/gilbert/gilbert2d.py ================================================ #!/usr/bin/env python3 # SPDX-License-Identifier: BSD-2-Clause # Copyright (c) 2018 Jakub Červený def gilbert2d(width, height): """ Generalized Hilbert ('gilbert') space-filling curve for arbitrary-sized 2D rectangular grids. 
Generates discrete 2D coordinates to fill a rectangle of size (width x height). """ if width >= height: yield from generate2d(0, 0, width, 0, 0, height) else: yield from generate2d(0, 0, 0, height, width, 0) def gilbert2d_widthBigger(width, height): """ Generalized Hilbert ('gilbert') space-filling curve for arbitrary-sized 2D rectangular grids. Generates discrete 2D coordinates to fill a rectangle of size (width x height). """ yield from generate2d(0, 0, width, 0, 0, height) def sgn(x): return -1 if x < 0 else (1 if x > 0 else 0) def generate2d(x, y, ax, ay, bx, by): w = abs(ax + ay) h = abs(bx + by) (dax, day) = (sgn(ax), sgn(ay)) # unit major direction (dbx, dby) = (sgn(bx), sgn(by)) # unit orthogonal direction if h == 1: # trivial row fill for i in range(0, w): yield(x, y) (x, y) = (x + dax, y + day) return if w == 1: # trivial column fill for i in range(0, h): yield(x, y) (x, y) = (x + dbx, y + dby) return (ax2, ay2) = (ax//2, ay//2) (bx2, by2) = (bx//2, by//2) w2 = abs(ax2 + ay2) h2 = abs(bx2 + by2) if 2*w > 3*h: if (w2 % 2) and (w > 2): # prefer even steps (ax2, ay2) = (ax2 + dax, ay2 + day) # long case: split in two parts only yield from generate2d(x, y, ax2, ay2, bx, by) yield from generate2d(x+ax2, y+ay2, ax-ax2, ay-ay2, bx, by) else: if (h2 % 2) and (h > 2): # prefer even steps (bx2, by2) = (bx2 + dbx, by2 + dby) # standard case: one step up, one long horizontal, one step down yield from generate2d(x, y, bx2, by2, ax2, ay2) yield from generate2d(x+bx2, y+by2, ax, ay, bx-bx2, by-by2) yield from generate2d(x+(ax-dax)+(bx2-dbx), y+(ay-day)+(by2-dby), -bx2, -by2, -(ax-ax2), -(ay-ay2)) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument('width', type=int) parser.add_argument('height', type=int) args = parser.parse_args() for x, y in gilbert2d(args.width, args.height): print(x, y) ================================================ FILE: models/layers/gilbert/gilbert3d.py 
================================================
#!/usr/bin/env python3
# SPDX-License-Identifier: BSD-2-Clause
# Copyright (c) 2018 Jakub Červený


def gilbert3d(width, height, depth):
    """
    Generalized Hilbert ('Gilbert') space-filling curve for arbitrary-sized
    3D rectangular grids.

    Generates discrete 3D coordinates to fill a cuboid of size
    (width x height x depth). Even sizes are recommended in 3D.

    Args:
        width: extent of the cuboid along x.
        height: extent of the cuboid along y.
        depth: extent of the cuboid along z.

    Yields:
        (x, y, z) integer tuples in curve order, beginning at (0, 0, 0).

    The largest dimension becomes the major axis of the recursion (ties are
    resolved in the order width, height, depth).
    """
    if width >= height and width >= depth:
        yield from generate3d(0, 0, 0, width, 0, 0, 0, height, 0, 0, 0, depth)
    elif height >= width and height >= depth:
        yield from generate3d(0, 0, 0, 0, height, 0, width, 0, 0, 0, 0, depth)
    else: # depth >= width and depth >= height
        yield from generate3d(0, 0, 0, 0, 0, depth, width, 0, 0, 0, height, 0)


def sgn(x):
    """Return the sign of x as an int: -1, 0 or 1."""
    return -1 if x < 0 else (1 if x > 0 else 0)


# T H W
# NOTE(review): the "T H W" tag above is unexplained; presumably it records
# which (time, height, width) axes callers in this repo map onto the curve's
# axes -- confirm against the call sites before relying on it.
def generate3d(x, y, z, ax, ay, az, bx, by, bz, cx, cy, cz):
    """
    Recursively yield the Gilbert-curve cells of one box.

    Args:
        x, y, z: corner cell at which this sub-curve starts.
        ax, ay, az: major-axis vector ("right"); abs of its component sum is
            the box extent w.
        bx, by, bz: first orthogonal vector ("forward"); extent h.
        cx, cy, cz: second orthogonal vector ("up"); extent d.

    The vectors are expected to be axis-aligned (one non-zero component), as
    they are for every call made in this module. The extents therefore reduce
    to abs() of the component sums, and the component signs give the walking
    direction along each axis.

    Yields:
        (x, y, z) tuples for the box's cells, in curve order.
    """
    w = abs(ax + ay + az)
    h = abs(bx + by + bz)
    d = abs(cx + cy + cz)

    (dax, day, daz) = (sgn(ax), sgn(ay), sgn(az)) # unit major direction ("right")
    (dbx, dby, dbz) = (sgn(bx), sgn(by), sgn(bz)) # unit ortho direction ("forward")
    (dcx, dcy, dcz) = (sgn(cx), sgn(cy), sgn(cz)) # unit ortho direction ("up")

    # trivial row/column fills: when two extents are 1 the box is a straight
    # line of cells, walked one unit step at a time along the remaining axis.
    if h == 1 and d == 1:
        for i in range(0, w):
            yield(x, y, z)
            (x, y, z) = (x + dax, y + day, z + daz)
        return

    if w == 1 and d == 1:
        for i in range(0, h):
            yield(x, y, z)
            (x, y, z) = (x + dbx, y + dby, z + dbz)
        return

    if w == 1 and h == 1:
        for i in range(0, d):
            yield(x, y, z)
            (x, y, z) = (x + dcx, y + dcy, z + dcz)
        return

    # Halve each vector (floor division, so negative components round toward
    # -infinity).
    (ax2, ay2, az2) = (ax//2, ay//2, az//2)
    (bx2, by2, bz2) = (bx//2, by//2, bz//2)
    (cx2, cy2, cz2) = (cx//2, cy//2, cz//2)

    w2 = abs(ax2 + ay2 + az2)
    h2 = abs(bx2 + by2 + bz2)
    d2 = abs(cx2 + cy2 + cz2)

    # prefer even steps: an odd half-extent on an axis longer than 2 is grown
    # by one cell in that axis's direction.
    if (w2 % 2) and (w > 2):
        (ax2, ay2, az2) = (ax2 + dax, ay2 + day, az2 + daz)
    if (h2 % 2) and (h > 2):
        (bx2, by2, bz2) = (bx2 + dbx, by2 + dby, bz2 + dbz)
    if (d2 % 2) and (d > 2):
        (cx2, cy2, cz2) = (cx2 + dcx, cy2 + dcy, cz2 + dcz)

    # wide case, split in w only: w exceeds 1.5x both h and d, so the box is
    # cut into two parts along a, each traversed in turn.
    if (2*w > 3*h) and (2*w > 3*d):
        yield from generate3d(x, y, z, ax2, ay2, az2, bx, by, bz, cx, cy, cz)
        yield from generate3d(x+ax2, y+ay2, z+az2, ax-ax2, ay-ay2, az-az2, bx, by, bz, cx, cy, cz)

    # do not split in d: h exceeds 4/3 of d, so only a and b are split and
    # the box is covered by three sub-boxes (the c vector is passed whole).
    elif 3*h > 4*d:
        yield from generate3d(x, y, z, bx2, by2, bz2, cx, cy, cz, ax2, ay2, az2)
        yield from generate3d(x+bx2, y+by2, z+bz2, ax, ay, az, bx-bx2, by-by2, bz-bz2, cx, cy, cz)
        yield from generate3d(x+(ax-dax)+(bx2-dbx), y+(ay-day)+(by2-dby), z+(az-daz)+(bz2-dbz), -bx2, -by2, -bz2, cx, cy, cz, -(ax-ax2), -(ay-ay2), -(az-az2))

    # do not split in h: d exceeds 4/3 of h, so only a and c are split and
    # the box is covered by three sub-boxes (the b vector is passed whole).
    elif 3*d > 4*h:
        yield from generate3d(x, y, z, cx2, cy2, cz2, ax2, ay2, az2, bx, by, bz)
        yield from generate3d(x+cx2, y+cy2, z+cz2, ax, ay, az, bx, by, bz, cx-cx2, cy-cy2, cz-cz2)
        yield from generate3d(x+(ax-dax)+(cx2-dcx), y+(ay-day)+(cy2-dcy), z+(az-daz)+(cz2-dcz), -cx2, -cy2, -cz2, -(ax-ax2), -(ay-ay2), -(az-az2), bx, by, bz)

    # regular case, split in all w/h/d: five sub-boxes. The start corner and
    # vector signs of each call are arranged so that consecutive sub-curves
    # join end to start; the order of these calls is the curve order.
    else:
        yield from generate3d(x, y, z, bx2, by2, bz2, cx2, cy2, cz2, ax2, ay2, az2)
        yield from generate3d(x+bx2, y+by2, z+bz2, cx, cy, cz, ax2, ay2, az2, bx-bx2, by-by2, bz-bz2)
        yield from generate3d(x+(bx2-dbx)+(cx-dcx), y+(by2-dby)+(cy-dcy), z+(bz2-dbz)+(cz-dcz), ax, ay, az, -bx2, -by2, -bz2, -(cx-cx2), -(cy-cy2), -(cz-cz2))
        yield from generate3d(x+(ax-dax)+bx2+(cx-dcx), y+(ay-day)+by2+(cy-dcy), z+(az-daz)+bz2+(cz-dcz), -cx, -cy, -cz, -(ax-ax2), -(ay-ay2), -(az-az2), bx-bx2, by-by2, bz-bz2)
        yield from generate3d(x+(ax-dax)+(bx2-dbx), y+(ay-day)+(by2-dby), z+(az-daz)+(bz2-dbz), -bx2, -by2, -bz2, cx2, cy2, cz2, -(ax-ax2), -(ay-ay2), -(az-az2))


if __name__ == "__main__":
    # CLI: print every (x, y, z) of the curve for the given box size, one per line.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('width', type=int)
    parser.add_argument('height', type=int)
    parser.add_argument('depth', type=int)
    args = parser.parse_args()

    for x, y, z in gilbert3d(args.width, args.height, args.depth):
        print(x, y, z)

================================================
FILE: models/layers/gilbert/gilbert_d2xy.py
================================================
#!/usr/bin/env python3
# SPDX-License-Identifier: BSD-2-Clause
# Copyright (c) 2024 abetusk


def gilbert_d2xy(idx, w, h):
    """
    Generalized Hilbert ('gilbert') space-filling curve for arbitrary-sized
    2D rectangular grids.

    Takes a position along the gilbert curve and returns its 2D (x,y)
    coordinate.

    Args:
        idx: position along the curve; expected in [0, w*h). Not validated.
        w: grid width (extent along x).
        h: grid height (extent along y).

    Returns:
        (x, y) tuple of ints.

    The longer side (width on ties) becomes the major axis of the recursion.
    """
    if w >= h:
        return gilbert_d2xy_r(idx,0, 0,0, w,0, 0,h)
    return gilbert_d2xy_r(idx,0, 0,0, 0,h, w,0)


def sgn(x):
    """Return the sign of x as an int: -1, 0 or 1."""
    return -1 if x < 0 else (1 if x > 0 else 0)


def gilbert_d2xy_r(dst_idx, cur_idx, x,y, ax,ay, bx,by):
    """
    Recursive helper for gilbert_d2xy: locate curve index dst_idx inside one
    axis-aligned rectangle.

    Args:
        dst_idx: curve index being looked up.
        cur_idx: curve index of this rectangle's first cell.
        x, y: corner cell at which this rectangle's sub-curve starts.
        ax, ay: major-axis vector; abs of its component sum is the extent w.
        bx, by: orthogonal vector; extent h.

    Returns:
        (x, y) coordinate of dst_idx. Rather than enumerating cells, each level
        computes the cell count of each sub-rectangle in curve order. It skips
        those that end before dst_idx by advancing cur_idx, and recurses only
        into the one that contains it.

    NOTE(review): zero-sized extents are not guarded against (e.g. w == h == 0
    appears to recurse without shrinking) -- confirm callers never pass them.
    """
    w = abs(ax + ay)
    h = abs(bx + by)

    (dax, day) = (sgn(ax), sgn(ay)) # unit major direction
    (dbx, dby) = (sgn(bx), sgn(by)) # unit orthogonal direction

    # NOTE(review): dx and dy are computed but never read in this function.
    dx = dax + dbx
    dy = day + dby
    # Distance of the target from this rectangle's first cell, in curve steps.
    di = dst_idx - cur_idx

    # One cell thick: the sub-curve is a straight line along the other axis,
    # so the target is di unit steps from the start corner.
    if h == 1:
        return (x + dax*di, y + day*di)
    if w == 1:
        return (x + dbx*di, y + dby*di)

    (ax2, ay2) = (ax//2, ay//2)
    (bx2, by2) = (bx//2, by//2)

    w2 = abs(ax2 + ay2)
    h2 = abs(bx2 + by2)

    if 2*w > 3*h:
        if (w2 % 2) and (w > 2):
            # prefer even steps
            (ax2, ay2) = (ax2 + dax, ay2 + day)

        # long case: split in two parts only
        # nxt_idx is the curve index just past the first part (cur_idx + its area).
        nxt_idx = cur_idx + abs((ax2 + ay2)*(bx + by))
        if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
            return gilbert_d2xy_r(dst_idx, cur_idx, x, y, ax2, ay2, bx, by)
        cur_idx = nxt_idx

        return gilbert_d2xy_r(dst_idx, cur_idx, x+ax2, y+ay2, ax-ax2, ay-ay2, bx, by)

    if (h2 % 2) and (h > 2):
        # prefer even steps
        (bx2, by2) = (bx2 + dbx, by2 + dby)

    # standard case: one step up, one long horizontal, one step down
    # Each part is tried in curve order; a part that does not contain dst_idx
    # is skipped by advancing cur_idx past its cells.
    nxt_idx = cur_idx + abs((bx2 + by2)*(ax2 + ay2))
    if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
        return gilbert_d2xy_r(dst_idx, cur_idx, x,y, bx2,by2, ax2,ay2)
    cur_idx = nxt_idx

    nxt_idx = cur_idx + abs((ax + ay)*((bx - bx2) + (by - by2)))
    if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
        return gilbert_d2xy_r(dst_idx, cur_idx, x+bx2, y+by2, ax,ay, bx-bx2,by-by2)
    cur_idx = nxt_idx

    return gilbert_d2xy_r(dst_idx, cur_idx, x+(ax-dax)+(bx2-dbx), y+(ay-day)+(by2-dby), -bx2, -by2, -(ax-ax2), -(ay-ay2))


if __name__ == "__main__":
    # CLI: print the (x, y) coordinate of every curve index 0 .. width*height-1.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('width', type=int)
    parser.add_argument('height', type=int)
    args = parser.parse_args()

    width = args.width
    height = args.height

    for idx in range(width*height):
        (x,y) = gilbert_d2xy(idx, width,height)
        print(x,y)

================================================
FILE: models/layers/gilbert/gilbert_d2xyz.py
================================================
#!/usr/bin/env python3
# SPDX-License-Identifier: BSD-2-Clause
# Copyright (c) 2024 abetusk


def gilbert_d2xyz(idx, width, height, depth):
    """
    Generalized Hilbert ('Gilbert') space-filling curve for arbitrary-sized
    3D rectangular grids.

    Takes a position along the curve over a cuboid of size
    (width x height x depth) and returns its 3D (x, y, z) coordinate.
    Even sizes are recommended in 3D.

    Args:
        idx: position along the curve; expected in [0, width*height*depth).
            Not validated.
        width: extent along x.
        height: extent along y.
        depth: extent along z.

    Returns:
        (x, y, z) tuple of ints.

    The largest dimension becomes the major axis of the recursion (ties are
    resolved in the order width, height, depth).
    """
    if width >= height and width >= depth:
        return gilbert_d2xyz_r(idx, 0, 0, 0, 0, width, 0, 0, 0, height, 0, 0, 0, depth)
    elif height >= width and height >= depth:
        return gilbert_d2xyz_r(idx, 0, 0, 0, 0, 0, height, 0, width, 0, 0, 0, 0, depth)
    else: # depth >= width and depth >= height
        return gilbert_d2xyz_r(idx, 0, 0, 0, 0, 0, 0, depth, width, 0, 0, 0, height, 0)


def sgn(x):
    """Return the sign of x as an int: -1, 0 or 1."""
    return -1 if x < 0 else (1 if x > 0 else 0)


def gilbert_d2xyz_r(dst_idx, cur_idx, x, y, z, ax, ay, az, bx, by, bz, cx, cy, cz):
    """
    Recursive helper for gilbert_d2xyz: locate curve index dst_idx inside one
    axis-aligned box.

    Args:
        dst_idx: curve index being looked up.
        cur_idx: curve index of this box's first cell.
        x, y, z: corner cell at which this box's sub-curve starts.
        ax, ay, az: major-axis vector ("right"); abs of its component sum is
            the extent w.
        bx, by, bz: first orthogonal vector ("forward"); extent h.
        cx, cy, cz: second orthogonal vector ("up"); extent d.

    Returns:
        (x, y, z) coordinate of dst_idx. Each level computes the cell count
        (product of the three extents) of each sub-box in curve order. It skips
        sub-boxes that end before dst_idx by advancing cur_idx, and recurses
        into the one containing it.
    """
    w = abs(ax + ay + az)
    h = abs(bx + by + bz)
    d = abs(cx + cy + cz)

    (dax, day, daz) = (sgn(ax), sgn(ay), sgn(az)) # unit major direction ("right")
    (dbx, dby, dbz) = (sgn(bx), sgn(by), sgn(bz)) # unit ortho direction ("forward")
    (dcx, dcy, dcz) = (sgn(cx), sgn(cy), sgn(cz)) # unit ortho direction ("up")

    # NOTE(review): _dx, _dy and _dz are computed but never read in this function.
    _dx = dax + dbx + dcx
    _dy = day + dby + dcy
    _dz = daz + dbz + dcz
    # Distance of the target from this box's first cell, in curve steps.
    _di = dst_idx - cur_idx

    # trivial row/column fills: a box one cell thick in two dimensions is a
    # straight line, so the target is _di unit steps from the start corner.
    if h == 1 and d == 1:
        return (x + dax*_di, y + day*_di, z + daz*_di)
    if w == 1 and d == 1:
        return (x + dbx*_di, y + dby*_di, z + dbz*_di)
    if w == 1 and h == 1:
        return (x + dcx*_di, y + dcy*_di, z + dcz*_di)

    (ax2, ay2, az2) = (ax//2, ay//2, az//2)
    (bx2, by2, bz2) = (bx//2, by//2, bz//2)
    (cx2, cy2, cz2) = (cx//2, cy//2, cz//2)

    w2 = abs(ax2 + ay2 + az2)
    h2 = abs(bx2 + by2 + bz2)
    d2 = abs(cx2 + cy2 + cz2)

    # prefer even steps: an odd half-extent on an axis longer than 2 is grown
    # by one cell in that axis's direction.
    if (w2 % 2) and (w > 2):
        (ax2, ay2, az2) = (ax2 + dax, ay2 + day, az2 + daz)
    if (h2 % 2) and (h > 2):
        (bx2, by2, bz2) = (bx2 + dbx, by2 + dby, bz2 + dbz)
    if (d2 % 2) and (d > 2):
        (cx2, cy2, cz2) = (cx2 + dcx, cy2 + dcy, cz2 + dcz)

    # wide case, split in w only (two sub-boxes)
    if (2*w > 3*h) and (2*w > 3*d):
        nxt_idx = cur_idx + abs( (ax2 + ay2 + az2)*(bx + by + bz)*(cx + cy + cz) )
        if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
            return gilbert_d2xyz_r(dst_idx,cur_idx, x, y, z, ax2, ay2, az2, bx, by, bz, cx, cy, cz)
        cur_idx = nxt_idx

        return gilbert_d2xyz_r(dst_idx,cur_idx, x+ax2, y+ay2, z+az2, ax-ax2, ay-ay2, az-az2, bx, by, bz, cx, cy, cz)

    # do not split in d (three sub-boxes; c is passed whole)
    elif 3*h > 4*d:
        nxt_idx = cur_idx + abs( (bx2 + by2 + bz2)*(cx + cy + cz)*(ax2 + ay2 + az2) )
        if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
            return gilbert_d2xyz_r(dst_idx,cur_idx, x, y, z, bx2, by2, bz2, cx, cy, cz, ax2, ay2, az2)
        cur_idx = nxt_idx

        nxt_idx = cur_idx + abs( (ax + ay + az)*((bx - bx2) + (by - by2) + (bz - bz2))*(cx + cy + cz) )
        if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
            return gilbert_d2xyz_r(dst_idx,cur_idx, x+bx2, y+by2, z+bz2, ax, ay, az, bx-bx2, by-by2, bz-bz2, cx, cy, cz)
        cur_idx = nxt_idx

        return gilbert_d2xyz_r(dst_idx,cur_idx, x+(ax-dax)+(bx2-dbx), y+(ay-day)+(by2-dby), z+(az-daz)+(bz2-dbz), -bx2, -by2, -bz2, cx, cy, cz, -(ax-ax2), -(ay-ay2), -(az-az2))

    # do not split in h (three sub-boxes; b is passed whole)
    elif 3*d > 4*h:
        nxt_idx = cur_idx + abs( (cx2 + cy2 + cz2)*(ax2 + ay2 + az2)*(bx + by + bz) )
        if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
            return gilbert_d2xyz_r(dst_idx,cur_idx, x, y, z, cx2, cy2, cz2, ax2, ay2, az2, bx, by, bz)
        cur_idx = nxt_idx

        nxt_idx = cur_idx + abs( (ax + ay + az)*(bx + by + bz)*((cx - cx2) + (cy - cy2) + (cz - cz2)) )
        if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
            return gilbert_d2xyz_r(dst_idx,cur_idx, x+cx2, y+cy2, z+cz2, ax, ay, az, bx, by, bz, cx-cx2, cy-cy2, cz-cz2)
        cur_idx = nxt_idx

        return gilbert_d2xyz_r(dst_idx,cur_idx, x+(ax-dax)+(cx2-dcx), y+(ay-day)+(cy2-dcy), z+(az-daz)+(cz2-dcz), -cx2, -cy2, -cz2, -(ax-ax2), -(ay-ay2), -(az-az2), bx, by, bz)

    # regular case, split in all w/h/d (five sub-boxes, tried in curve order)
    else:
        nxt_idx = cur_idx + abs( (bx2 + by2 + bz2)*(cx2 + cy2 + cz2)*(ax2 + ay2 + az2) )
        if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
            return gilbert_d2xyz_r(dst_idx,cur_idx, x, y, z, bx2, by2, bz2, cx2, cy2, cz2, ax2, ay2, az2)
        cur_idx = nxt_idx

        nxt_idx = cur_idx + abs( (cx + cy + cz)*(ax2 + ay2 + az2)*((bx - bx2) + (by - by2) + (bz - bz2)) )
        if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
            return gilbert_d2xyz_r(dst_idx,cur_idx, x+bx2, y+by2, z+bz2, cx, cy, cz, ax2, ay2, az2, bx-bx2, by-by2, bz-bz2)
        cur_idx = nxt_idx

        nxt_idx = cur_idx + abs( (ax + ay + az)*(-bx2 - by2 - bz2)*(-(cx - cx2) - (cy - cy2) - (cz - cz2)) )
        if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
            return gilbert_d2xyz_r(dst_idx, cur_idx, x+(bx2-dbx)+(cx-dcx), y+(by2-dby)+(cy-dcy), z+(bz2-dbz)+(cz-dcz), ax, ay, az, -bx2, -by2, -bz2, -(cx-cx2), -(cy-cy2), -(cz-cz2))
        cur_idx = nxt_idx

        nxt_idx = cur_idx + abs( (-cx - cy - cz)*(-(ax - ax2) - (ay - ay2) - (az - az2))*((bx - bx2) + (by - by2) + (bz - bz2)) )
        if (cur_idx <= dst_idx) and (dst_idx < nxt_idx):
            return gilbert_d2xyz_r(dst_idx,cur_idx, x+(ax-dax)+bx2+(cx-dcx), y+(ay-day)+by2+(cy-dcy), z+(az-daz)+bz2+(cz-dcz), -cx, -cy, -cz, -(ax-ax2), -(ay-ay2), -(az-az2), bx-bx2, by-by2, bz-bz2)
        cur_idx = nxt_idx

        return gilbert_d2xyz_r(dst_idx,cur_idx, x+(ax-dax)+(bx2-dbx), y+(ay-day)+(by2-dby), z+(az-daz)+(bz2-dbz), -bx2, -by2, -bz2, cx2, cy2, cz2, -(ax-ax2), -(ay-ay2), -(az-az2))


if __name__ == "__main__":
    # CLI: print the (x, y, z) coordinate of every curve index 0 .. w*h*d-1.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('width', type=int)
    parser.add_argument('height', type=int)
    parser.add_argument('depth', type=int)
    args = parser.parse_args()

    w = args.width
    h = args.height
    d = args.depth

    n = w*h*d

    for idx in range(n):
        (x,y,z) = gilbert_d2xyz(idx,w,h,d)
        print(x,y,z)
================================================ FILE: models/layers/gilbert/gilbert_xy2d.py ================================================ #!/usr/bin/env python3 # SPDX-License-Identifier: BSD-2-Clause # Copyright (c) 2024 abetusk def gilbert_xy2d(x, y, w, h): """ Generalized Hilbert ('gilbert') space-filling curve for arbitrary-sized 2D rectangular grids. Takes a discrete 2D coordinate and maps it to the index position on the gilbert curve. """ if w >= h: return gilbert_xy2d_r(0, x,y, 0,0, w,0, 0,h) return gilbert_xy2d_r(0, x,y, 0,0, 0,h, w,0) def sgn(x): return -1 if x < 0 else (1 if x > 0 else 0) def in_bounds(x, y, x_s, y_s, ax, ay, bx, by): dx = ax + bx dy = ay + by if dx < 0: if (x > x_s) or (x <= (x_s + dx)): return False else: if (x < x_s) or (x >= (x_s + dx)): return False if dy < 0: if (y > y_s) or (y <= (y_s + dy)): return False else: if (y < y_s) or (y >= (y_s + dy)): return False return True def gilbert_xy2d_r(cur_idx, x_dst, y_dst, x, y, ax, ay, bx, by): w = abs(ax + ay) h = abs(bx + by) (dax, day) = (sgn(ax), sgn(ay)) # unit major direction (dbx, dby) = (sgn(bx), sgn(by)) # unit orthogonal direction dx = dax + dbx dy = day + dby if h == 1: if (dax==0): return cur_idx + (dy*(y_dst - y)) return cur_idx + (dx*(x_dst - x)) if w == 1: if (dbx==0): return cur_idx + (dy*(y_dst - y)) return cur_idx + (dx*(x_dst - x)) (ax2, ay2) = (ax//2, ay//2) (bx2, by2) = (bx//2, by//2) w2 = abs(ax2 + ay2) h2 = abs(bx2 + by2) if 2*w > 3*h: if (w2 % 2) and (w > 2): # prefer even steps (ax2, ay2) = (ax2 + dax, ay2 + day) if in_bounds(x_dst, y_dst, x, y, ax2, ay2, bx, by): return gilbert_xy2d_r(cur_idx, x_dst, y_dst, x, y, ax2, ay2, bx, by) cur_idx += abs((ax2 + ay2)*(bx + by)) return gilbert_xy2d_r(cur_idx, x_dst, y_dst, x+ax2, y+ay2, ax-ax2, ay-ay2, bx, by) else: if (h2 % 2) and (h > 2): # prefer even steps (bx2, by2) = (bx2 + dbx, by2 + dby) # standard case: one step up, one long horizontal, one step down if in_bounds(x_dst, y_dst, x, y, bx2, by2, ax2, ay2): return 
gilbert_xy2d_r(cur_idx, x_dst,y_dst, x,y, bx2,by2, ax2,ay2) cur_idx += abs((bx2 + by2)*(ax2 + ay2)) if in_bounds(x_dst, y_dst, x+bx2, y+by2, ax, ay, bx-bx2, by-by2): return gilbert_xy2d_r(cur_idx, x_dst,y_dst, x+bx2,y+by2, ax,ay, bx-bx2,by-by2) cur_idx += abs((ax + ay)*((bx - bx2) + (by - by2))) return gilbert_xy2d_r(cur_idx, x_dst,y_dst, x+(ax-dax)+(bx2-dbx), y+(ay-day)+(by2-dby), -bx2, -by2, -(ax-ax2), -(ay-ay2)) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument('width', type=int) parser.add_argument('height', type=int) args = parser.parse_args() width = args.width height = args.height for x in range(width): for y in range(height): idx = gilbert_xy2d(x,y, width,height) print(idx,x,y) ================================================ FILE: models/layers/gilbert/gilbert_xyz2d.py ================================================ #!/usr/bin/env python3 # SPDX-License-Identifier: BSD-2-Clause # Copyright (c) 2024 abetusk def gilbert_xyz2d(x, y, z, width, height, depth): """ Generalized Hilbert ('Gilbert') space-filling curve for arbitrary-sized 3D rectangular grids. Generates discrete 3D coordinates to fill a cuboid of size (width x height x depth). Even sizes are recommended in 3D. 
""" if width >= height and width >= depth: return gilbert_xyz2d_r(0,x,y,z, 0, 0, 0, width, 0, 0, 0, height, 0, 0, 0, depth) elif height >= width and height >= depth: return gilbert_xyz2d_r(0,x,y,z, 0, 0, 0, 0, height, 0, width, 0, 0, 0, 0, depth) else: # depth >= width and depth >= height return gilbert_xyz2d_r(0,x,y,z, 0, 0, 0, 0, 0, depth, width, 0, 0, 0, height, 0) def sgn(x): return -1 if x < 0 else (1 if x > 0 else 0) def in_bounds(x, y, z, x_s, y_s, z_s, ax, ay, az, bx, by, bz, cx, cy, cz): dx = ax + bx + cx dy = ay + by + cy dz = az + bz + cz if dx < 0: if (x > x_s) or (x <= (x_s + dx)): return False else: if (x < x_s) or (x >= (x_s + dx)): return False if dy < 0: if (y > y_s) or (y <= (y_s + dy)): return False else: if (y < y_s) or (y >= (y_s + dy)): return False if dz <0: if (z > z_s) or (z <= (z_s + dz)): return False else: if (z < z_s) or (z >= (z_s + dz)): return False return True def gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x, y, z, ax, ay, az, bx, by, bz, cx, cy, cz): w = abs(ax + ay + az) h = abs(bx + by + bz) d = abs(cx + cy + cz) (dax, day, daz) = (sgn(ax), sgn(ay), sgn(az)) # unit major direction ("right") (dbx, dby, dbz) = (sgn(bx), sgn(by), sgn(bz)) # unit ortho direction ("forward") (dcx, dcy, dcz) = (sgn(cx), sgn(cy), sgn(cz)) # unit ortho direction ("up") # trivial row/column fills if h == 1 and d == 1: return cur_idx + (dax*(x_dst - x)) + (day*(y_dst - y)) + (daz*(z_dst - z)) if w == 1 and d == 1: return cur_idx + (dbx*(x_dst - x)) + (dby*(y_dst - y)) + (dbz*(z_dst - z)) if w == 1 and h == 1: return cur_idx + (dcx*(x_dst - x)) + (dcy*(y_dst - y)) + (dcz*(z_dst - z)) (ax2, ay2, az2) = (ax//2, ay//2, az//2) (bx2, by2, bz2) = (bx//2, by//2, bz//2) (cx2, cy2, cz2) = (cx//2, cy//2, cz//2) w2 = abs(ax2 + ay2 + az2) h2 = abs(bx2 + by2 + bz2) d2 = abs(cx2 + cy2 + cz2) # prefer even steps if (w2 % 2) and (w > 2): (ax2, ay2, az2) = (ax2 + dax, ay2 + day, az2 + daz) if (h2 % 2) and (h > 2): (bx2, by2, bz2) = (bx2 + dbx, by2 + dby, bz2 + dbz) if (d2 
% 2) and (d > 2): (cx2, cy2, cz2) = (cx2 + dcx, cy2 + dcy, cz2 + dcz) # wide case, split in w only if (2*w > 3*h) and (2*w > 3*d): if in_bounds(x_dst,y_dst,z_dst, x,y,z, ax2,ay2,az2, bx,by,bz, cx,cy,cz): return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x, y, z, ax2, ay2, az2, bx, by, bz, cx, cy, cz) cur_idx += abs( (ax2 + ay2 + az2)*(bx + by + bz)*(cx + cy + cz) ) return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+ax2, y+ay2, z+az2, ax-ax2, ay-ay2, az-az2, bx, by, bz, cx, cy, cz) # do not split in d elif 3*h > 4*d: if in_bounds(x_dst,y_dst,z_dst, x,y,z, bx2,by2,bz2, cx,cy,cz, ax2,ay2,az2): return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x, y, z, bx2, by2, bz2, cx, cy, cz, ax2, ay2, az2) cur_idx += abs( (bx2 + by2 + bz2)*(cx + cy + cz)*(ax2 + ay2 + az2) ) if in_bounds(x_dst,y_dst,z_dst, x+bx2,y+by2,z+bz2, ax,ay,az, bx-bx2,by-by2,bz-bz2, cx,cy,cz): return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+bx2, y+by2, z+bz2, ax, ay, az, bx-bx2, by-by2, bz-bz2, cx, cy, cz) cur_idx += abs( (ax + ay + az)*((bx - bx2) + (by - by2) + (bz - bz2))*(cx + cy + cz) ) return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+(ax-dax)+(bx2-dbx), y+(ay-day)+(by2-dby), z+(az-daz)+(bz2-dbz), -bx2, -by2, -bz2, cx, cy, cz, -(ax-ax2), -(ay-ay2), -(az-az2)) # do not split in h elif 3*d > 4*h: if in_bounds(x_dst,y_dst,z_dst, x,y,z, cx2,cy2,cz2, ax2,ay2,az2, bx,by,bz): return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x, y, z, cx2, cy2, cz2, ax2, ay2, az2, bx, by, bz) cur_idx += abs( (cx2 + cy2 + cz2)*(ax2 + ay2 + az2)*(bx + by + bz) ) if in_bounds(x_dst,y_dst,z_dst, x+cx2,y+cy2,z+cz2, ax,ay,az, bx,by,bz, cx-cx2,cy-cy2,cz-cz2): return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+cx2, y+cy2, z+cz2, ax, ay, az, bx, by, bz, cx-cx2, cy-cy2, cz-cz2) cur_idx += abs( (ax + ay + az)*(bx + by + bz)*((cx - cx2) + (cy - cy2) + (cz - cz2)) ) return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+(ax-dax)+(cx2-dcx), y+(ay-day)+(cy2-dcy), z+(az-daz)+(cz2-dcz), -cx2, -cy2, -cz2, -(ax-ax2), -(ay-ay2), 
-(az-az2), bx, by, bz) # regular case, split in all w/h/d if in_bounds(x_dst,y_dst,z_dst, x,y,z, bx2,by2,bz2, cx2,cy2,cz2, ax2,ay2,az2): return gilbert_xyz2d_r(cur_idx,x_dst,y_dst,z_dst, x, y, z, bx2, by2, bz2, cx2, cy2, cz2, ax2, ay2, az2) cur_idx += abs( (bx2 + by2 + bz2)*(cx2 + cy2 + cz2)*(ax2 + ay2 + az2) ) if in_bounds(x_dst,y_dst,z_dst, x+bx2, y+by2, z+bz2, cx, cy, cz, ax2, ay2, az2, bx-bx2, by-by2, bz-bz2): return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+bx2, y+by2, z+bz2, cx, cy, cz, ax2, ay2, az2, bx-bx2, by-by2, bz-bz2) cur_idx += abs( (cx + cy + cz)*(ax2 + ay2 + az2)*((bx - bx2) + (by - by2) + (bz - bz2)) ) if in_bounds(x_dst,y_dst,z_dst, x+(bx2-dbx)+(cx-dcx), y+(by2-dby)+(cy-dcy), z+(bz2-dbz)+(cz-dcz), ax, ay, az, -bx2, -by2, -bz2, -(cx-cx2), -(cy-cy2), -(cz-cz2)): return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+(bx2-dbx)+(cx-dcx), y+(by2-dby)+(cy-dcy), z+(bz2-dbz)+(cz-dcz), ax, ay, az, -bx2, -by2, -bz2, -(cx-cx2), -(cy-cy2), -(cz-cz2)) cur_idx += abs( (ax + ay + az)*(-bx2 - by2 - bz2)*(-(cx - cx2) - (cy - cy2) - (cz - cz2)) ) if in_bounds(x_dst,y_dst,z_dst, x+(ax-dax)+bx2+(cx-dcx), y+(ay-day)+by2+(cy-dcy), z+(az-daz)+bz2+(cz-dcz), -cx, -cy, -cz, -(ax-ax2), -(ay-ay2), -(az-az2), bx-bx2, by-by2, bz-bz2): return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+(ax-dax)+bx2+(cx-dcx), y+(ay-day)+by2+(cy-dcy), z+(az-daz)+bz2+(cz-dcz), -cx, -cy, -cz, -(ax-ax2), -(ay-ay2), -(az-az2), bx-bx2, by-by2, bz-bz2) cur_idx += abs( (-cx - cy - cz)*(-(ax - ax2) - (ay - ay2) - (az - az2))*((bx - bx2) + (by - by2) + (bz - bz2)) ) return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+(ax-dax)+(bx2-dbx), y+(ay-day)+(by2-dby), z+(az-daz)+(bz2-dbz), -bx2, -by2, -bz2, cx2, cy2, cz2, -(ax-ax2), -(ay-ay2), -(az-az2)) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument('width', type=int) parser.add_argument('height', type=int) parser.add_argument('depth', type=int) args = parser.parse_args() w = args.width h = args.height d = 
args.depth n = w*h*d for x in range(w): for y in range(h): for z in range(d): idx = gilbert_xyz2d(x,y,z,w,h,d) print(idx,x,y,z) ================================================ FILE: models/layers/gilbert/plotpath.m ================================================ # Octave helper function to plot a 2D or 3D colored curve function h = plotpath(P) x = P(:,1)'; y = P(:,2)'; if (size(P,2) >= 3) z = P(:,3)'; else z = zeros(size(x)); endif col = 1:size(x,2); colormap jet; h = surface([x;x],[y;y],[z;z],[col;col],... 'facecolor','none',... 'edgecolor','interp',... 'linewidth',2); ================================================ FILE: models/layers/gilbert/ports/Makefile ================================================ CC := gcc CFLAGS := OPT := -O3 SRCFILES := gilbert.c all: gilbert gilbert: gilbert.c $(CC) gilbert.c -o gilbert $(CFLAGS) $(OPT) .PHONY: clean clean: rm -f gilbert ================================================ FILE: models/layers/gilbert/ports/gilbert.c ================================================ // SPDX-License-Identifier: BSD-2-Clause // Copyright (c) 2024 abetusk #include #include #include int gilbert_d2xy_r(int dst_idx, int cur_idx, int *xres, int *yres, int ax,int ay, int bx,int by ); int gilbert_xy2d_r(int cur_idx, int x_dst, int y_dst, int x, int y, int ax, int ay, int bx,int by ); int gilbert_xy2d(int x, int y, int w, int h) { if (w >= h) { return gilbert_xy2d_r(0, x,y, 0,0, w,0, 0,h); } return gilbert_xy2d_r(0, x,y, 0,0, 0,h, w,0); } int gilbert_d2xy(int *x, int *y, int idx,int w,int h) { *x = 0; *y = 0; if (w >= h) { return gilbert_d2xy_r(idx,0, x,y, w,0, 0,h); } return gilbert_d2xy_r(idx,0, x,y, 0,h, w,0); } int gilbert_d2xyz_r(int dst_idx, int cur_idx, int *x, int *y, int *z, int ax, int ay, int az, int bx, int by, int bz, int cx, int cy, int cz); int gilbert_xyz2d_r(int cur_idx, int x_dst, int y_dst, int z_dst, int x, int y, int z, int ax, int ay, int az, int bx, int by, int bz, int cx, int cy, int cz); int gilbert_xyz2d(int x, int y, int 
z, int width, int height, int depth) { if ((width >= height) && (width >= depth)) { return gilbert_xyz2d_r(0,x,y,z, 0, 0, 0, width, 0, 0, 0, height, 0, 0, 0, depth); } else if ((height >= width) && (height >= depth)) { return gilbert_xyz2d_r(0,x,y,z, 0, 0, 0, 0, height, 0, width, 0, 0, 0, 0, depth); } // depth >= width and depth >= height return gilbert_xyz2d_r(0,x,y,z, 0, 0, 0, 0, 0, depth, width, 0, 0, 0, height, 0); } int gilbert_d2xyz(int *x, int *y, int *z, int idx, int width, int height, int depth) { *x = 0; *y = 0; *z = 0; if ((width >= height) && (width >= depth)) { return gilbert_d2xyz_r(idx, 0, x,y,z, width, 0, 0, 0, height, 0, 0, 0, depth); } else if ((height >= width) && (height >= depth)) { return gilbert_d2xyz_r(idx, 0, x,y,z, 0, height, 0, width, 0, 0, 0, 0, depth); } // depth >= width and depth >= height return gilbert_d2xyz_r(idx, 0, x,y,z, 0, 0, depth, width, 0, 0, 0, height, 0); } static int sgn(int x) { if (x < 0) { return -1; } if (x > 0) { return 1; } return 0; } int in_bounds2(int x, int y, int x_s,int y_s, int ax, int ay, int bx, int by) { int dx, dy; dx = ax + bx; dy = ay + by; if (dx < 0) { if ((x > x_s) || (x <= (x_s + dx))) { return 0; } } else { if ((x < x_s) || (x >= (x_s + dx))) { return 0; } } if (dy < 0) { if ((y > y_s) || (y <= (y_s + dy))) { return 0; } } else { if ((y < y_s) || (y >= (y_s + dy))) { return 0; } } return 1; } int in_bounds3(int x, int y, int z, int x_s,int y_s,int z_s, int ax, int ay, int az, int bx, int by, int bz, int cx, int cy, int cz) { int dx, dy, dz; dx = ax + bx + cx; dy = ay + by + cy; dz = az + bz + cz; if (dx < 0) { if ((x > x_s) || (x <= (x_s + dx))) { return 0; } } else { if ((x < x_s) || (x >= (x_s + dx))) { return 0; } } if (dy < 0) { if ((y > y_s) || (y <= (y_s + dy))) { return 0; } } else { if ((y < y_s) || (y >= (y_s + dy))) { return 0; } } if (dz < 0) { if ((z > z_s) || (z <= (z_s + dz))) { return 0; } } else { if ((z < z_s) || (z >= (z_s + dz))) { return 0; } } return 1; } int gilbert_d2xy_r(int 
dst_idx, int cur_idx, int *xres, int *yres, int ax,int ay, int bx,int by ) { static int max_iter = 0; int nxt_idx; int w, h, x, y, dax, day, dbx, dby, di; int ax2, ay2, bx2, by2, w2, h2; if (max_iter > 100000) { return -1; } max_iter++; w = abs(ax + ay); h = abs(bx + by); x = *xres; y = *yres; // unit major direction dax = sgn(ax); day = sgn(ay); // unit orthogonal direction dbx = sgn(bx); dby = sgn(by); di = dst_idx - cur_idx; if (h == 1) { *xres = x + dax*di; *yres = y + day*di; return 0; } if (w == 1) { *xres = x + dbx*di; *yres = y + dby*di; return 0; } // floor function ax2 = (int)floor((double)ax/2.0); ay2 = (int)floor((double)ay/2.0); bx2 = (int)floor((double)bx/2.0); by2 = (int)floor((double)by/2.0); w2 = abs(ax2 + ay2); h2 = abs(bx2 + by2); if ((2*w) > (3*h)) { if ((w2 % 2) && (w > 2)) { // prefer even steps ax2 += dax; ay2 += day; } // long case: split in two parts only nxt_idx = cur_idx + abs((ax2 + ay2)*(bx + by)); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { *xres = x; *yres = y; return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, ax2, ay2, bx, by); } cur_idx = nxt_idx; *xres = x + ax2; *yres = y + ay2; return gilbert_d2xy_r(dst_idx, cur_idx, xres, yres, ax-ax2, ay-ay2, bx, by); } if ((h2 % 2) && (h > 2)) { // prefer even steps bx2 += dbx; by2 += dby; } // standard case: one step up, one long horizontal, one step down nxt_idx = cur_idx + abs((bx2 + by2)*(ax2 + ay2)); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { *xres = x; *yres = y; return gilbert_d2xy_r(dst_idx, cur_idx, xres,yres, bx2,by2, ax2,ay2); } cur_idx = nxt_idx; nxt_idx = cur_idx + abs((ax + ay)*((bx - bx2) + (by - by2))); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { *xres = x + bx2; *yres = y + by2; return gilbert_d2xy_r(dst_idx, cur_idx, xres,yres, ax,ay, bx-bx2,by-by2); } cur_idx = nxt_idx; *xres = x + (ax - dax) + (bx2 - dbx); *yres = y + (ay - day) + (by2 - dby); return gilbert_d2xy_r(dst_idx, cur_idx, xres,yres, -bx2, -by2, -(ax-ax2), -(ay-ay2)); } int 
gilbert_xy2d_r(int cur_idx, int x_dst, int y_dst, int x, int y, int ax, int ay, int bx,int by ) { int dax, day, dbx, dby, ax2, ay2, bx2, by2; int w, h, w2, h2; int dx, dy; w = abs(ax + ay); h = abs(bx + by); // unit major direction dax = sgn(ax); day = sgn(ay); // unit orthogonal direction dbx = sgn(bx); dby = sgn(by); dx = dax + dbx; dy = day + dby; if (h == 1) { if (dax == 0) { return cur_idx + (dy*(y_dst - y)); } return cur_idx + (dx*(x_dst - x)); } if (w == 1) { if (dbx == 0) { return cur_idx + (dy*(y_dst - y)); } return cur_idx + (dx*(x_dst - x)); } ax2 = (int)floor((double)ax/2.0); ay2 = (int)floor((double)ay/2.0); bx2 = (int)floor((double)bx/2.0); by2 = (int)floor((double)by/2.0); w2 = abs(ax2 + ay2); h2 = abs(bx2 + by2); if ((2*w) > (3*h)) { if ((w2 % 2) && (w > 2)) { // prefer even steps ax2 += dax; ay2 += day; } if (in_bounds2( x_dst, y_dst, x,y, ax2,ay2, bx,by )) { return gilbert_xy2d_r(cur_idx, x_dst, y_dst, x, y, ax2, ay2, bx, by); } cur_idx += abs((ax2 + ay2)*(bx + by)); return gilbert_xy2d_r(cur_idx, x_dst, y_dst, x+ax2, y+ay2, ax-ax2, ay-ay2, bx, by); } if ((h2 % 2) && (h > 2)) { // prefer even steps bx2 += dbx; by2 += dby; } // standard case: one step up, one long horizontal, one step down if (in_bounds2( x_dst,y_dst, x,y, bx2,by2, ax2,ay2 )) { return gilbert_xy2d_r(cur_idx, x_dst,y_dst, x,y, bx2,by2, ax2,ay2); } cur_idx += abs((bx2 + by2)*(ax2 + ay2)); if (in_bounds2( x_dst,y_dst, x+bx2,y+by2, ax,ay, bx-bx2,by-by2)) { return gilbert_xy2d_r(cur_idx, x_dst,y_dst, x+bx2,y+by2, ax,ay, bx-bx2,by-by2); } cur_idx += abs((ax + ay)*((bx - bx2) + (by - by2))); return gilbert_xy2d_r(cur_idx, x_dst,y_dst, x+(ax-dax)+(bx2-dbx), y+(ay-day)+(by2-dby), -bx2, -by2, -(ax-ax2), -(ay-ay2)); } int gilbert_d2xyz_r(int dst_idx, int cur_idx, int *xres, int *yres, int *zres, int ax, int ay, int az, int bx, int by, int bz, int cx, int cy, int cz) { int x, y, z; int _dx, _dy, _dz, _di; int nxt_idx; int w, h, d; int w2, h2, d2; int dax, day, daz, dbx, dby, dbz, dcx, dcy, 
dcz; int ax2, ay2, az2, bx2, by2, bz2, cx2, cy2, cz2; x = *xres; y = *yres; z = *zres; w = abs(ax + ay + az); h = abs(bx + by + bz); d = abs(cx + cy + cz); dax = sgn(ax); day = sgn(ay); daz = sgn(az); // unit major direction "right" dbx = sgn(bx); dby = sgn(by); dbz = sgn(bz); // unit ortho direction "forward" dcx = sgn(cx); dcy = sgn(cy); dcz = sgn(cz); // unit ortho direction "up" _dx = dax + dbx + dcx; _dy = day + dby + dcy; _dz = daz + dbz + dcz; _di = dst_idx - cur_idx; // trivial row/column fills if ((h == 1) && (d == 1)) { *xres = x + dax*_di; *yres = y + day*_di; *zres = z + daz*_di; return 0; } if ((w == 1) && (d == 1)) { *xres = x + dbx*_di; *yres = y + dby*_di; *zres = z + dbz*_di; return 0; } if ((w == 1) && (h == 1)) { *xres = x + dcx*_di; *yres = y + dcy*_di; *zres = z + dcz*_di; return 0; } ax2 = (int)floor((double)ax/2.0); ay2 = (int)floor((double)ay/2.0); az2 = (int)floor((double)az/2.0); bx2 = (int)floor((double)bx/2.0); by2 = (int)floor((double)by/2.0); bz2 = (int)floor((double)bz/2.0); cx2 = (int)floor((double)cx/2.0); cy2 = (int)floor((double)cy/2.0); cz2 = (int)floor((double)cz/2.0); w2 = abs(ax2 + ay2 + az2); h2 = abs(bx2 + by2 + bz2); d2 = abs(cx2 + cy2 + cz2); // prefer even steps if ((w2 % 2) && (w > 2)) { ax2 += dax; ay2 += day; az2 += daz; } if ((h2 % 2) && (h > 2)) { bx2 += dbx; by2 += dby; bz2 += dbz; } if ((d2 % 2) && (d > 2)) { cx2 += dcx; cy2 += dcy; cz2 += dcz; } // wide case, split in w only if (((2*w) > (3*h)) && ((2*w) > (3*d))) { nxt_idx = cur_idx + abs( (ax2 + ay2 + az2)*(bx + by + bz)*(cx + cy + cz) ); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { *xres = x; *yres = y; *zres = z; return gilbert_d2xyz_r(dst_idx,cur_idx, xres,yres,zres, ax2, ay2, az2, bx, by, bz, cx, cy, cz); } cur_idx = nxt_idx; *xres = x + ax2; *yres = y + ay2; *zres = z + az2; return gilbert_d2xyz_r(dst_idx,cur_idx, xres,yres,zres, ax-ax2, ay-ay2, az-az2, bx, by, bz, cx, cy, cz); } // do not split in d else if ((3*h) > (4*d)) { nxt_idx = cur_idx + abs( 
(bx2 + by2 + bz2)*(cx + cy + cz)*(ax2 + ay2 + az2) ); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { *xres = x; *yres = y; *zres = z; return gilbert_d2xyz_r(dst_idx,cur_idx, xres,yres,zres, bx2, by2, bz2, cx, cy, cz, ax2, ay2, az2); } cur_idx = nxt_idx; nxt_idx = cur_idx + abs( (ax + ay + az)*((bx - bx2) + (by - by2) + (bz - bz2))*(cx + cy + cz) ); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { *xres = x + bx2; *yres = y + by2; *zres = z + bz2; return gilbert_d2xyz_r(dst_idx,cur_idx, xres,yres,zres, ax, ay, az, bx-bx2, by-by2, bz-bz2, cx, cy, cz); } cur_idx = nxt_idx; *xres = x + (ax - dax) + (bx2 - dbx); *yres = y + (ay - day) + (by2 - dby); *zres = z + (az - daz) + (bz2 - dbz); return gilbert_d2xyz_r(dst_idx,cur_idx, xres,yres,zres, -bx2, -by2, -bz2, cx, cy, cz, -(ax-ax2), -(ay-ay2), -(az-az2)); } // do not split in h else if ((3*d) > (4*h)) { nxt_idx = cur_idx + abs( (cx2 + cy2 + cz2)*(ax2 + ay2 + az2)*(bx + by + bz) ); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { *xres = x; *yres = y; *zres = z; return gilbert_d2xyz_r(dst_idx,cur_idx, xres,yres,zres, cx2, cy2, cz2, ax2, ay2, az2, bx, by, bz); } cur_idx = nxt_idx; nxt_idx = cur_idx + abs( (ax + ay + az)*(bx + by + bz)*((cx-cx2) + (cy-cy2) + (cz-cz2)) ); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { *xres = x + cx2; *yres = y + cy2; *zres = z + cz2; return gilbert_d2xyz_r(dst_idx,cur_idx, xres,yres,zres, ax, ay, az, bx, by, bz, cx-cx2, cy-cy2, cz-cz2); } cur_idx = nxt_idx; *xres = x + (ax - dax) + (cx2 - dcx); *yres = y + (ay - day) + (cy2 - dcy); *zres = z + (az - daz) + (cz2 - dcz); return gilbert_d2xyz_r(dst_idx,cur_idx, xres,yres,zres, -cx2, -cy2, -cz2, -(ax-ax2), -(ay-ay2), -(az-az2), bx, by, bz); } // regular case, split in all w/h/d nxt_idx = cur_idx + abs( (bx2 + by2 + bz2)*(cx2 + cy2 + cz2)*(ax2 + ay2 + az2) ); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { *xres = x; *yres = y; *zres = z; return gilbert_d2xyz_r(dst_idx,cur_idx, xres,yres,zres, bx2, by2, bz2, cx2, cy2, cz2, ax2, 
ay2, az2); } cur_idx = nxt_idx; nxt_idx = cur_idx + abs( (cx + cy + cz)*(ax2 + ay2 + az2)*((bx-bx2) + (by-by2) + (bz-bz2)) ); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { *xres = x + bx2; *yres = y + by2; *zres = z + bz2; return gilbert_d2xyz_r(dst_idx,cur_idx, xres,yres,zres, cx, cy, cz, ax2, ay2, az2, bx-bx2, by-by2, bz-bz2); } cur_idx = nxt_idx; nxt_idx = cur_idx + abs( (ax + ay + az)*( -bx2 - by2 - bz2)*( -(cx - cx2) - (cy - cy2) - (cz - cz2)) ); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { *xres = x + (bx2 - dbx) + (cx - dcx); *yres = y + (by2 - dby) + (cy - dcy); *zres = z + (bz2 - dbz) + (cz - dcz); return gilbert_d2xyz_r(dst_idx, cur_idx, xres,yres,zres, ax, ay, az, -bx2, -by2, -bz2, -(cx-cx2), -(cy-cy2), -(cz-cz2)); } cur_idx = nxt_idx; nxt_idx = cur_idx + abs( ( -cx - cy - cz)*( -(ax - ax2) - (ay - ay2) - (az - az2))*((bx - bx2) + (by - by2) + (bz - bz2)) ); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { *xres = x + (ax - dax) + bx2 + (cx - dcx); *yres = y + (ay - day) + by2 + (cy - dcy); *zres = z + (az - daz) + bz2 + (cz - dcz); return gilbert_d2xyz_r(dst_idx,cur_idx, xres,yres,zres, -cx, -cy, -cz, -(ax-ax2), -(ay-ay2), -(az-az2), bx-bx2, by-by2, bz-bz2); } cur_idx = nxt_idx; *xres = x + (ax - dax) + (bx2 - dbx); *yres = y + (ay - day) + (by2 - dby); *zres = z + (az - daz) + (bz2 - dbz); return gilbert_d2xyz_r(dst_idx,cur_idx, xres,yres,zres, -bx2, -by2, -bz2, cx2, cy2, cz2, -(ax-ax2), -(ay-ay2), -(az-az2)); } int gilbert_xyz2d_r(int cur_idx, int x_dst, int y_dst, int z_dst, int x, int y, int z, int ax, int ay, int az, int bx, int by, int bz, int cx, int cy, int cz) { int w, h, d; int w2, h2, d2; int dax, day, daz, dbx, dby, dbz, dcx, dcy, dcz; int ax2, ay2, az2, bx2, by2, bz2, cx2, cy2, cz2; w = abs(ax + ay + az); h = abs(bx + by + bz); d = abs(cx + cy + cz); dax = sgn(ax); day = sgn(ay); daz = sgn(az); // unit major direction ("right") dbx = sgn(bx); dby = sgn(by); dbz = sgn(bz); // unit ortho direction ("forward") dcx = sgn(cx); dcy = 
sgn(cy); dcz = sgn(cz); // unit ortho direction ("up") // trivial row/column fills if ((h == 1) && (d == 1)) { return cur_idx + (dax*(x_dst - x)) + (day*(y_dst - y)) + (daz*(z_dst - z)); } if ((w == 1) && (d == 1)) { return cur_idx + (dbx*(x_dst - x)) + (dby*(y_dst - y)) + (dbz*(z_dst - z)); } if ((w == 1) && (h == 1)) { return cur_idx + (dcx*(x_dst - x)) + (dcy*(y_dst - y)) + (dcz*(z_dst - z)); } ax2 = (int)floor((double)ax/2.0); ay2 = (int)floor((double)ay/2.0); az2 = (int)floor((double)az/2.0); bx2 = (int)floor((double)bx/2.0); by2 = (int)floor((double)by/2.0); bz2 = (int)floor((double)bz/2.0); cx2 = (int)floor((double)cx/2.0); cy2 = (int)floor((double)cy/2.0); cz2 = (int)floor((double)cz/2.0); w2 = abs(ax2 + ay2 + az2); h2 = abs(bx2 + by2 + bz2); d2 = abs(cx2 + cy2 + cz2); // prefer even steps if ((w2 % 2) && (w > 2)) { ax2 += dax; ay2 += day; az2 += daz; } if ((h2 % 2) && (h > 2)) { bx2 += dbx; by2 += dby; bz2 += dbz; } if ((d2 % 2) && (d > 2)) { cx2 += dcx; cy2 += dcy; cz2 += dcz; } // wide case, split in w only if ((2*w > 3*h) && (2*w > 3*d)) { if (in_bounds3(x_dst,y_dst,z_dst, x,y,z, ax2,ay2,az2, bx,by,bz, cx,cy,cz)) { return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x, y, z, ax2, ay2, az2, bx, by, bz, cx, cy, cz); } cur_idx += abs( (ax2 + ay2 + az2)*(bx + by + bz)*(cx + cy + cz) ); return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+ax2, y+ay2, z+az2, ax-ax2, ay-ay2, az-az2, bx, by, bz, cx, cy, cz); } // do not split in d else if ((3*h) > (4*d)) { if (in_bounds3(x_dst,y_dst,z_dst, x,y,z, bx2,by2,bz2, cx,cy,cz, ax2,ay2,az2)) { return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x, y, z, bx2, by2, bz2, cx, cy, cz, ax2, ay2, az2); } cur_idx += abs( (bx2 + by2 + bz2)*(cx + cy + cz)*(ax2 + ay2 + az2) ); if (in_bounds3(x_dst,y_dst,z_dst, x+bx2,y+by2,z+bz2, ax,ay,az, bx-bx2,by-by2,bz-bz2, cx,cy,cz)) { return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+bx2, y+by2, z+bz2, ax, ay, az, bx-bx2, by-by2, bz-bz2, cx, cy, cz); } cur_idx += abs( (ax + ay + az)*((bx - 
bx2) + (by - by2) + (bz - bz2))*(cx + cy + cz) ); return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+(ax-dax)+(bx2-dbx), y+(ay-day)+(by2-dby), z+(az-daz)+(bz2-dbz), -bx2, -by2, -bz2, cx, cy, cz, -(ax-ax2), -(ay-ay2), -(az-az2)); } // do not split in h else if ((3*d) > (4*h)) { if (in_bounds3(x_dst,y_dst,z_dst, x,y,z, cx2,cy2,cz2, ax2,ay2,az2, bx,by,bz)) { return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x, y, z, cx2, cy2, cz2, ax2, ay2, az2, bx, by, bz); } cur_idx += abs( (cx2 + cy2 + cz2)*(ax2 + ay2 + az2)*(bx + by + bz) ); if (in_bounds3(x_dst,y_dst,z_dst, x+cx2,y+cy2,z+cz2, ax,ay,az, bx,by,bz, cx-cx2,cy-cy2,cz-cz2)) { return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+cx2, y+cy2, z+cz2, ax, ay, az, bx, by, bz, cx-cx2, cy-cy2, cz-cz2); } cur_idx += abs( (ax + ay + az)*(bx + by + bz)*((cx - cx2) + (cy - cy2) + (cz - cz2)) ); return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+(ax-dax)+(cx2-dcx), y+(ay-day)+(cy2-dcy), z+(az-daz)+(cz2-dcz), -cx2, -cy2, -cz2, -(ax-ax2), -(ay-ay2), -(az-az2), bx, by, bz); } // regular case, split in all w/h/d if (in_bounds3(x_dst,y_dst,z_dst, x,y,z, bx2,by2,bz2, cx2,cy2,cz2, ax2,ay2,az2)) { return gilbert_xyz2d_r(cur_idx,x_dst,y_dst,z_dst, x, y, z, bx2, by2, bz2, cx2, cy2, cz2, ax2, ay2, az2); } cur_idx += abs( (bx2 + by2 + bz2)*(cx2 + cy2 + cz2)*(ax2 + ay2 + az2) ); if (in_bounds3(x_dst,y_dst,z_dst, x+bx2, y+by2, z+bz2, cx, cy, cz, ax2, ay2, az2, bx-bx2, by-by2, bz-bz2)) { return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+bx2, y+by2, z+bz2, cx, cy, cz, ax2, ay2, az2, bx-bx2, by-by2, bz-bz2); } cur_idx += abs( (cx + cy + cz)*(ax2 + ay2 + az2)*((bx - bx2) + (by - by2) + (bz - bz2)) ); if (in_bounds3(x_dst,y_dst,z_dst, x+(bx2-dbx)+(cx-dcx), y+(by2-dby)+(cy-dcy), z+(bz2-dbz)+(cz-dcz), ax, ay, az, -bx2, -by2, -bz2, -(cx-cx2), -(cy-cy2), -(cz-cz2))) { return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+(bx2-dbx)+(cx-dcx), y+(by2-dby)+(cy-dcy), z+(bz2-dbz)+(cz-dcz), ax, ay, az, -bx2, -by2, -bz2, -(cx-cx2), -(cy-cy2), -(cz-cz2)); 
} cur_idx += abs( (ax + ay + az)*(-bx2 - by2 - bz2)*(-(cx - cx2) - (cy - cy2) - (cz - cz2)) ); if (in_bounds3(x_dst,y_dst,z_dst, x+(ax-dax)+bx2+(cx-dcx), y+(ay-day)+by2+(cy-dcy), z+(az-daz)+bz2+(cz-dcz), -cx, -cy, -cz, -(ax-ax2), -(ay-ay2), -(az-az2), bx-bx2, by-by2, bz-bz2)) { return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+(ax-dax)+bx2+(cx-dcx), y+(ay-day)+by2+(cy-dcy), z+(az-daz)+bz2+(cz-dcz), -cx, -cy, -cz, -(ax-ax2), -(ay-ay2), -(az-az2), bx-bx2, by-by2, bz-bz2); } cur_idx += abs( (-cx - cy - cz)*(-(ax - ax2) - (ay - ay2) - (az - az2))*((bx - bx2) + (by - by2) + (bz - bz2)) ); return gilbert_xyz2d_r(cur_idx, x_dst,y_dst,z_dst, x+(ax-dax)+(bx2-dbx), y+(ay-day)+(by2-dby), z+(az-daz)+(bz2-dbz), -bx2, -by2, -bz2, cx2, cy2, cz2, -(ax-ax2), -(ay-ay2), -(az-az2)); } #define GILBERT_MAIN #ifdef GILBERT_MAIN #include int main(int argc, char **argv) { int w, h, d; int x, y, z; int idx; char buf[1024]; w = 1; h = 1; d = 1; if (argc < 4) { printf("provide args\n"); printf("\n"); printf("usage:\n"); printf("\n"); printf(" gilbert [depth]\n"); printf("\n"); printf(" op - one of \"xy2d\",\"2dxy\",\"xyz2d\",\"d2xyz\"\n"); printf(" depth - default to 1 for 3D Gilbert with no depth specified\n"); printf("\n"); exit(-1); } strncpy(buf, argv[1], 1023); buf[1024]='\0'; w = atoi(argv[2]); h = atoi(argv[3]); if (argc > 4) { d = atoi(argv[4]); } if ((w <= 0) || (h <= 0) || (d <= 0)) { exit(-1); } if (strncmp("xy2d", buf, 1023) == 0) { for (x = 0; x < w; x++) { for (y = 0; y < h; y++) { idx = gilbert_xy2d( x, y, w, h ); printf("%i %i %i\n", idx, x, y); } } } else if (strncmp("d2xy", buf, 1023) == 0) { for (idx = 0; idx < (w*h); idx++) { gilbert_d2xy( &x, &y, idx, w, h ); printf("%i %i\n", x, y); } } else if (strncmp("xyz2d", buf, 1023) == 0) { for (x = 0; x < w; x++) { for (y = 0; y < h; y++) { for (z = 0; z < d; z++) { idx = gilbert_xyz2d( x,y,z, w,h,d ); printf("%i %i %i %i\n", idx, x, y, z); } } } } else if (strncmp("d2xyz", buf, 1023) == 0) { for (idx = 0; idx < (w*h*d); 
idx++) { gilbert_d2xyz( &x,&y,&z, idx, w,h,d ); printf("%i %i %i\n", x, y, z); } } exit(0); } #endif ================================================ FILE: models/layers/gilbert/ports/gilbert.js ================================================ // SPDX-License-Identifier: BSD-2-Clause // Copyright (c) 2024 abetusk "use strict"; var gilbert = { "xy2d": gilbert_xy2d, "d2xy": gilbert_d2xy, "xyz2d": gilbert_xyz2d, "d2xyz": gilbert_d2xyz, }; function sgn(x) { if (x < 0) { return -1; } if (x > 0) { return 1; } return 0; } function in_bounds2(p, s, a, b) { let d = { "x": a.x + b.x, "y": a.y + b.y }; if (d.x < 0) { if ((p.x > s.x) || (p.x <= (s.x + d.x))) { return false; } } else if ((p.x < s.x) || (p.x >= (s.x + d.x))) { return false; } if (d.y < 0) { if ((p.y > s.y) || (p.y <= (s.y + d.y))) { return false; } } else if ((p.y < s.y) || (p.y >= (s.y + d.y))) { return false; } return true; } function in_bounds3(p, s, a, b, c) { let d = { "x": a.x + b.x + c.x, "y": a.y + b.y + c.y, "z": a.z + b.z + c.z }; if (d.x < 0) { if ((p.x > s.x) || (p.x <= (s.x + d.x))) { return false; } } else if ((p.x < s.x) || (p.x >= (s.x + d.x))) { return false; } if (d.y < 0) { if ((p.y > s.y) || (p.y <= (s.y + d.y))) { return false; } } else if ((p.y < s.y) || (p.y >= (s.y + d.y))) { return false; } if (d.z < 0) { if ((p.z > s.z) || (p.z <= (s.z + d.z))) { return false; } } else if ((p.z < s.z) || (p.z >= (s.z + d.z))) { return false; } return true; } function gilbert_xy2d(x,y,w,h) { let _q = {"x": x, "y": y}; let _p = {"x": 0, "y": 0}; let _a = {"x": 0, "y": h}; let _b = {"x": w, "y": 0}; if (w >= h) { _a.x = w; _a.y = 0; _b.x = 0; _b.y = h; } return gilbert_xy2d_r(0, _q, _p, _a, _b); } function gilbert_d2xy(idx,w,h) { let _p = {"x": 0, "y": 0}; let _a = {"x": 0, "y": h}; let _b = {"x": w, "y": 0}; if (w >= h) { _a.x = w; _a.y = 0; _b.x = 0; _b.y = h; } return gilbert_d2xy_r(idx,0,_p,_a,_b); } function gilbert_xyz2d(x,y,z,w,h,d) { let _q = {"x": x, "y": y, "z": z}; let _p = {"x": 0, "y": 0, "z": 
0}; let _a = {"x": w, "y": 0, "z": 0}; let _b = {"x": 0, "y": h, "z": 0}; let _c = {"x": 0, "y": 0, "z": d}; if ((w >= h) && (w >= d)) { return gilbert_xyz2d_r(0, _q, _p, _a, _b, _c); } else if ((h >= w) && (h >= d)) { return gilbert_xyz2d_r(0, _q, _p, _b, _a, _c); } return gilbert_xyz2d_r(0, _q, _p, _c, _a, _b); } function gilbert_d2xyz(idx,w,h,d) { let _p = {"x": 0, "y": 0, "z": 0}; let _a = {"x": w, "y": 0, "z": 0}; let _b = {"x": 0, "y": h, "z": 0}; let _c = {"x": 0, "y": 0, "z": d}; if ((w >= h) && (w >= d)) { return gilbert_d2xyz_r(idx, 0, _p, _a, _b, _c); } else if ((h >= w) && (h >= d)) { return gilbert_d2xyz_r(idx, 0, _p, _b, _a, _c); } return gilbert_d2xyz_r(idx, 0, _p, _c, _a, _b); } function gilbert_d2xy_r( dst_idx,cur_idx, p, a, b) { let _p = {}, _a = {}, _b = {}; let nxt_idx = -1; let w = Math.abs( a.x + a.y ); let h = Math.abs( b.x + b.y ); let da = { "x": sgn(a.x), "y": sgn(a.y) }; let db = { "x": sgn(b.x), "y": sgn(b.y) }; let d = { "x": da.x + db.x, "y": da.y + db.y, "i": dst_idx - cur_idx }; if (h == 1) { return { "x": p.x + da.x*d.i, "y": p.y + da.y*d.i }; } if (w == 1) { return {"x": p.x + db.x*d.i, "y": p.y + db.y*d.i }; } let a2 = { "x": Math.floor(a.x/2), "y": Math.floor(a.y/2) }; let b2 = { "x": Math.floor(b.x/2), "y": Math.floor(b.y/2) }; let w2 = Math.abs(a2.x + a2.y); let h2 = Math.abs(b2.x + b2.y); if ((2*w) > (3*h)) { // prefer even steps if ((w2 % 2) && (w > 2)) { a2.x += da.x; a2.y += da.y; } nxt_idx = cur_idx + Math.abs((a2.x + a2.y)*(b.x + b.y)); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { return gilbert_d2xy_r(dst_idx,cur_idx, p, a2, b); } cur_idx = nxt_idx; _p = { "x": p.x + a2.x, "y": p.y + a2.y }; _a = { "x": a.x - a2.x, "y": a.y - a2.y }; return gilbert_d2xy_r(dst_idx,cur_idx, _p, _a, b); } // prefer event steps if ((h2 % 2) && (h > 2)) { b2.x += db.x; b2.y += db.y; } // standard case: one step up, on long horizontal, one step down nxt_idx = cur_idx + Math.abs((b2.x + b2.y)*(a2.x + a2.y)); if ((cur_idx <= dst_idx) && 
(dst_idx < nxt_idx)) { return gilbert_d2xy_r(dst_idx, cur_idx, p, b2, a2); } cur_idx = nxt_idx; nxt_idx = cur_idx + Math.abs((a.x + a.y)*((b.x - b2.x) + (b.y - b2.y))); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { _p = { "x": p.x + b2.x, "y": p.y + b2.y }; _b = { "x": b.x - b2.x, "y": b.y - b2.y }; return gilbert_d2xy_r(dst_idx, cur_idx, _p, a, _b); } cur_idx = nxt_idx; _p = { "x": p.x + (a.x - da.x) + (b2.x - db.x), "y": p.y + (a.y - da.y) + (b2.y - db.y) }; _a = { "x": -b2.x, "y": -b2.y }; _b = { "x": -(a.x - a2.x), "y": -(a.y - a2.y) }; return gilbert_d2xy_r(dst_idx, cur_idx, _p, _a, _b); } function gilbert_xy2d_r(idx, q, p, a, b) { let _p = {}, _a = {}, _b = {}; let w = Math.abs(a.x + a.y); let h = Math.abs(b.x + b.y); let da = { "x": sgn(a.x), "y": sgn(a.y) }; let db = { "x": sgn(b.x), "y": sgn(b.y) }; let d = {"x": da.x + db.x, "y": da.y + db.y }; if (h == 1) { return idx + (da.x*(q.x - p.x)) + (da.y*(q.y - p.y)); } if (w == 1) { return idx + (db.x*(q.x - p.x)) + (db.y*(q.y - p.y)); } let a2 = { "x": Math.floor(a.x/2), "y": Math.floor(a.y/2) }; let b2 = { "x": Math.floor(b.x/2), "y": Math.floor(b.y/2) }; let w2 = Math.abs(a2.x + a2.y); let h2 = Math.abs(b2.x + b2.y); if ((2*w) > (3*h)) { if ((w2 % 2) && (w > 2)) { a2.x += da.x; a2.y += da.y; } if (in_bounds2(q, p, a2, b)) { return gilbert_xy2d_r(idx, q, p, a2, b); } idx += Math.abs((a2.x + a2.y)*(b.x + b.y)); _p = { "x": p.x + a2.x, "y": p.y + a2.y }; _a = { "x": a.x - a2.x, "y": a.y - a2.y }; return gilbert_xy2d_r(idx, q, _p, _a, b); } if ((h2 % 2) && (h > 2)) { b2.x += db.x; b2.y += db.y; } if (in_bounds2(q, p, b2, a2)) { return gilbert_xy2d_r(idx, q, p, b2, a2); } idx += Math.abs((b2.x + b2.y)*(a2.x + a2.y)); _p = { "x": p.x + b2.x, "y": p.y + b2.y }; _b = { "x": b.x - b2.x, "y": b.y - b2.y }; if (in_bounds2(q, _p, a, _b)) { return gilbert_xy2d_r(idx, q, _p, a, _b); } idx += Math.abs((a.x + a.y)*((b.x - b2.x) + (b.y - b2.y))); _p = { "x" : p.x + (a.x - da.x) + (b2.x - db.x), "y" : p.y + (a.y - da.y) 
+ (b2.y - db.y) }; _a = { "x": -b2.x, "y": -b2.y }; _b = { "x": -(a.x - a2.x), "y": -(a.y - a2.y) }; return gilbert_xy2d_r(idx, q, _p, _a, _b); } function gilbert_xyz2d_r(cur_idx, q, p, a, b, c) { let _p = {}, _a = {}, _b = {}, _c = {}; let w = Math.abs(a.x + a.y + a.z); let h = Math.abs(b.x + b.y + b.z); let d = Math.abs(c.x + c.y + c.z); let da = { "x": sgn(a.x), "y": sgn(a.y), "z": sgn(a.z) }; let db = { "x": sgn(b.x), "y": sgn(b.y), "z": sgn(b.z) }; let dc = { "x": sgn(c.x), "y": sgn(c.y), "z": sgn(c.z) }; // trivial row/column fills if ((h == 1) && (d == 1)) { return cur_idx + (da.x*(q.x - p.x)) + (da.y*(q.y - p.y)) + (da.z*(q.z - p.z)); } else if ((w == 1) && (d == 1)) { return cur_idx + (db.x*(q.x - p.x)) + (db.y*(q.y - p.y)) + (db.z*(q.z - p.z)); } else if ((w == 1) && (h == 1)) { return cur_idx + (dc.x*(q.x - p.x)) + (dc.y*(q.y - p.y)) + (dc.z*(q.z - p.z)); } let a2 = { "x": Math.floor(a.x/2), "y": Math.floor(a.y/2), "z": Math.floor(a.z/2) }; let b2 = { "x": Math.floor(b.x/2), "y": Math.floor(b.y/2), "z": Math.floor(b.z/2) }; let c2 = { "x": Math.floor(c.x/2), "y": Math.floor(c.y/2), "z": Math.floor(c.z/2) }; let w2 = Math.abs(a2.x + a2.y + a2.z); let h2 = Math.abs(b2.x + b2.y + b2.z); let d2 = Math.abs(c2.x + c2.y + c2.z); // prefer even steps if ((w2 % 2) && (w > 2)) { a2.x += da.x; a2.y += da.y; a2.z += da.z; } if ((h2 % 2) && (h > 2)) { b2.x += db.x; b2.y += db.y; b2.z += db.z; } if ((d2 % 2) && (d > 2)) { c2.x += dc.x; c2.y += dc.y; c2.z += dc.z; } // wide case, split in w only if ( ((2*w) > (3*h)) && ((2*w) > (3*d)) ) { if (in_bounds3(q, p, a2, b, c)) { return gilbert_xyz2d_r(cur_idx, q, p, a2, b, c); } cur_idx += Math.abs( (a2.x + a2.y + a2.z)*(b.x + b.y + b.z)*(c.x + c.y + c.z) ); _p = { "x": p.x + a2.x, "y": p.y + a2.y, "z": p.z + a2.z }; _a = { "x": a.x - a2.x, "y": a.y - a2.y, "z": a.z - a2.z }; return gilbert_xyz2d_r(cur_idx, q, _p, _a, b, c); } else if ((3*h) > (4*d)) { if (in_bounds3(q, p, b2, c, a2)) { return 
gilbert_xyz2d_r(cur_idx,q,p,b2,c,a2); } cur_idx += Math.abs( (b2.x + b2.y + b2.z)*(c.x + c.y + c.z)*(a2.x + a2.y + a2.z) ); _p = { "x": p.x + b2.x, "y": p.y + b2.y, "z": p.z + b2.z }; _b = { "x": b.x - b2.x, "y": b.y - b2.y, "z": b.z - b2.z }; if (in_bounds3(q, _p, a, _b, c)) { return gilbert_xyz2d_r(cur_idx,q, _p, a, _b, c); } cur_idx += Math.abs( (a.x + a.y + a.z)*((b.x - b2.x) + (b.y - b2.y) + (b.z - b2.z))*(c.x + c.y + c.z) ); _p = { "x": p.x + (a.x - da.x) + (b2.x - db.x), "y": p.y + (a.y - da.y) + (b2.y - db.y), "z": p.z + (a.z - da.z) + (b2.z - db.z) }; _a = { "x": -b2.x, "y": -b2.y, "z": -b2.z }; _c = { "x": -(a.x - a2.x), "y": -(a.y - a2.y), "z": -(a.z - a2.z) }; return gilbert_xyz2d_r(cur_idx, q, _p, _a, c, _c); } else if ((3*d) > (4*h)) { if (in_bounds3(q, p, c2, a2, b)) { return gilbert_xyz2d_r(cur_idx,q,p,c2,a2,b); } cur_idx += Math.abs( (c2.x + c2.y + c2.z)*(a2.x + a2.y + a2.z)*(b.x + b.y + b.z) ); _p = { "x": p.x + c2.x, "y": p.y + c2.y, "z": p.z + c2.z }; _c = { "x": c.x - c2.x, "y": c.y - c2.y, "z": c.z - c2.z }; if (in_bounds3(q, _p, a, b, _c)) { return gilbert_xyz2d_r(cur_idx, q, _p, a, b, _c); } cur_idx += Math.abs( (a.x + a.y + a.z)*(b.x + b.y + b.z)*((c.x - c2.x) + (c.y - c2.y) + (c.z - c2.z)) ); _p = { "x": p.x + (a.x - da.x) + (c2.x - dc.x), "y": p.y + (a.y - da.y) + (c2.y - dc.y), "z": p.z + (a.z - da.z) + (c2.z - dc.z) } _a = { "x": -c2.x, "y": -c2.y, "z": -c2.z }; _b = { "x": -(a.x - a2.x), "y": -(a.y - a2.y), "z": -(a.z - a2.z) }; return gilbert_xyz2d_r(cur_idx, q, _p, _a, _b, b); } // regular case, split in all w/h/d if (in_bounds3(q, p, b2, c2, a2)) { return gilbert_xyz2d_r(cur_idx,q,p,b2,c2,a2); } cur_idx += Math.abs( (b2.x + b2.y + b2.z)*(c2.x + c2.y + c2.z)*(a2.x + a2.y + a2.z) ); _p = { "x": p.x + b2.x, "y": p.y + b2.y, "z": p.z + b2.z }; _c = { "x": b.x - b2.x, "y": b.y - b2.y, "z": b.z - b2.z }; if (in_bounds3(q, _p, c, a2, _c)) { return gilbert_xyz2d_r(cur_idx, q, _p, c, a2, _c); } cur_idx += Math.abs( (c.x + c.y + c.z)*(a2.x + 
a2.y + a2.z)*((b.x - b2.x) + (b.y - b2.y) + (b.z - b2.z)) ); _p = { "x" : p.x + (b2.x - db.x) + (c.x - dc.x), "y" : p.y + (b2.y - db.y) + (c.y - dc.y), "z" : p.z + (b2.z - db.z) + (c.z - dc.z) }; _b = { "x": -b2.x, "y": -b2.y, "z": -b2.z }; _c = { "x": -(c.x - c2.x), "y": -(c.y - c2.y), "z": -(c.z - c2.z) }; if (in_bounds3(q, _p, a, _b, _c)) { return gilbert_xyz2d_r(cur_idx, q, _p, a, _b, _c); } cur_idx += Math.abs( (a.x + a.y + a.z)*( -b2.x - b2.y - b2.z)*( -(c.x - c2.x) - (c.y - c2.y) - (c.z - c2.z)) ); _p = { "x": p.x + (a.x - da.x) + b2.x + (c.x - dc.x), "y": p.y + (a.y - da.y) + b2.y + (c.y - dc.y), "z": p.z + (a.z - da.z) + b2.z + (c.z - dc.z) }; _a = { "x": -c.x, "y": -c.y, "z": -c.z }; _b = { "x": -(a.x - a2.x), "y": -(a.y - a2.y), "z": -(a.z - a2.z) }; _c = { "x": b.x - b2.x, "y": b.y - b2.y, "z": b.z - b2.z }; if (in_bounds3(q, _p, _a, _b, _c)) { return gilbert_xyz2d_r(cur_idx, q, _p, _a, _b, _c); } cur_idx += Math.abs( ( -c.x - c.y - c.z)*( -(a.x - a2.x) - (a.y - a2.y) - (a.z - a2.z))*((b.x - b2.x) + (b.y - b2.y) + (b.z - b2.z)) ); _p = { "x": p.x + (a.x - da.x) + (b2.x - db.x), "y": p.y + (a.y - da.y) + (b2.y - db.y), "z": p.z + (a.z - da.z) + (b2.z - db.z) }; _a = { "x": -b2.x, "y": -b2.y, "z": -b2.z }; _c = { "x": -(a.x - a2.x), "y": -(a.y - a2.y), "z": -(a.z - a2.z) }; return gilbert_xyz2d_r(cur_idx, q, _p, _a, c2, _c); } function gilbert_d2xyz_r(dst_idx, cur_idx, p, a, b, c) { let _p = {}, _a = {}, _b = {}, _c = {}; let nxt_idx = -1; let w = Math.abs(a.x + a.y + a.z); let h = Math.abs(b.x + b.y + b.z); let d = Math.abs(c.x + c.y + c.z); let da = { "x": sgn(a.x), "y": sgn(a.y), "z": sgn(a.z) }; let db = { "x": sgn(b.x), "y": sgn(b.y), "z": sgn(b.z) }; let dc = { "x": sgn(c.x), "y": sgn(c.y), "z": sgn(c.z) }; let di = dst_idx - cur_idx; // trivial row/column fills if ((h == 1) && (d == 1)) { return { "x": p.x + da.x*di, "y": p.y + da.y*di, "z": p.z + da.z*di }; } else if ((w == 1) && (d == 1)) { return { "x": p.x + db.x*di, "y": p.y + db.y*di, "z": 
p.z + db.z*di }; } else if ((w == 1) && (h == 1)) { return { "x": p.x + dc.x*di, "y": p.y + dc.y*di, "z": p.z + dc.z*di }; } let a2 = { "x": Math.floor(a.x/2), "y": Math.floor(a.y/2), "z": Math.floor(a.z/2) }; let b2 = { "x": Math.floor(b.x/2), "y": Math.floor(b.y/2), "z": Math.floor(b.z/2) }; let c2 = { "x": Math.floor(c.x/2), "y": Math.floor(c.y/2), "z": Math.floor(c.z/2) }; let w2 = Math.abs(a2.x + a2.y + a2.z); let h2 = Math.abs(b2.x + b2.y + b2.z); let d2 = Math.abs(c2.x + c2.y + c2.z); // prefer even steps if ((w2 % 2) && (w > 2)) { a2.x += da.x; a2.y += da.y; a2.z += da.z; } if ((h2 % 2) && (h > 2)) { b2.x += db.x; b2.y += db.y; b2.z += db.z; } if ((d2 % 2) && (d > 2)) { c2.x += dc.x; c2.y += dc.y; c2.z += dc.z; } // wide case, split in w only if ( ((2*w) > (3*h)) && ((2*w) > (3*d)) ) { nxt_idx = cur_idx + Math.abs( (a2.x + a2.y + a2.z)*(b.x + b.y + b.z)*(c.x + c.y + c.z) ); if ((cur_idx <= nxt_idx) && (dst_idx < nxt_idx)) { return gilbert_d2xyz_r(dst_idx, cur_idx, p, a2, b, c); } cur_idx = nxt_idx; _p = { "x": p.x + a2.x, "y": p.y + a2.y, "z": p.z + a2.z }; _a = { "x": a.x - a2.x, "y": a.y - a2.y, "z": a.z - a2.z }; return gilbert_d2xyz_r(dst_idx, cur_idx, _p, _a, b, c); } else if ((3*h) > (4*d)) { nxt_idx = cur_idx + Math.abs( (b2.x + b2.y + b2.z)*(c.x + c.y + c.z)*(a2.x + a2.y + a2.z) ); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { return gilbert_d2xyz_r(dst_idx,cur_idx,p,b2,c,a2); } cur_idx = nxt_idx; nxt_idx = cur_idx + Math.abs( (a.x + a.y + a.z)*((b.x - b2.x) + (b.y - b2.y) + (b.z - b2.z))*(c.x + c.y + c.z) ); _p = { "x": p.x + b2.x, "y": p.y + b2.y, "z": p.z + b2.z }; _b = { "x": b.x - b2.x, "y": b.y - b2.y, "z": b.z - b2.z }; if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { return gilbert_d2xyz_r(dst_idx,cur_idx, _p, a, _b, c); } cur_idx = nxt_idx; _p = { "x": p.x + (a.x - da.x) + (b2.x - db.x), "y": p.y + (a.y - da.y) + (b2.y - db.y), "z": p.z + (a.z - da.z) + (b2.z - db.z) }; _a = { "x": -b2.x, "y": -b2.y, "z": -b2.z }; _c = { "x": -(a.x - 
a2.x), "y": -(a.y - a2.y), "z": -(a.z - a2.z) }; return gilbert_d2xyz_r(dst_idx, cur_idx, _p, _a, c, _c); } else if ((3*d) > (4*h)) { nxt_idx = cur_idx + Math.abs( (c2.x + c2.y + c2.z)*(a2.x + a2.y + a2.z)*(b.x + b.y + b.z) ); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { return gilbert_d2xyz_r(dst_idx,cur_idx,p,c2,a2,b); } cur_idx = nxt_idx; nxt_idx = cur_idx + Math.abs( (a.x + a.y + a.z)*(b.x + b.y + b.z)*((c.x - c2.x) + (c.y - c2.y) + (c.z - c2.z)) ); _p = { "x": p.x + c2.x, "y": p.y + c2.y, "z": p.z + c2.z }; _c = { "x": c.x - c2.x, "y": c.y - c2.y, "z": c.z - c2.z }; if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { return gilbert_d2xyz_r(dst_idx, cur_idx, _p, a, b, _c); } cur_idx = nxt_idx; _p = { "x": p.x + (a.x - da.x) + (c2.x - dc.x), "y": p.y + (a.y - da.y) + (c2.y - dc.y), "z": p.z + (a.z - da.z) + (c2.z - dc.z) } _a = { "x": -c2.x, "y": -c2.y, "z": -c2.z }; _b = { "x": -(a.x - a2.x), "y": -(a.y - a2.y), "z": -(a.z - a2.z) }; return gilbert_d2xyz_r(dst_idx, cur_idx, _p, _a, _b, b); } // regular case, split in all w/h/d nxt_idx = cur_idx + Math.abs( (b2.x + b2.y + b2.z)*(c2.x + c2.y + c2.z)*(a2.x + a2.y + a2.z) ); if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { return gilbert_d2xyz_r(dst_idx,cur_idx,p,b2,c2,a2); } cur_idx = nxt_idx; nxt_idx = cur_idx + Math.abs( (c.x + c.y + c.z)*(a2.x + a2.y + a2.z)*((b.x - b2.x) + (b.y - b2.y) + (b.z - b2.z)) ); _p = { "x": p.x + b2.x, "y": p.y + b2.y, "z": p.z + b2.z }; _c = { "x": b.x - b2.x, "y": b.y - b2.y, "z": b.z - b2.z }; if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { return gilbert_d2xyz_r(dst_idx, cur_idx, _p, c, a2, _c); } cur_idx = nxt_idx; nxt_idx = cur_idx + Math.abs( (a.x + a.y + a.z)*( -b2.x - b2.y - b2.z)*( -(c.x - c2.x) - (c.y - c2.y) - (c.z - c2.z)) ); _p = { "x" : p.x + (b2.x - db.x) + (c.x - dc.x), "y" : p.y + (b2.y - db.y) + (c.y - dc.y), "z" : p.z + (b2.z - db.z) + (c.z - dc.z) }; _b = { "x": -b2.x, "y": -b2.y, "z": -b2.z }; _c = { "x": -(c.x - c2.x), "y": -(c.y - c2.y), "z": -(c.z - 
c2.z) }; if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { return gilbert_d2xyz_r(dst_idx, cur_idx, _p, a, _b, _c); } cur_idx = nxt_idx; nxt_idx = cur_idx + Math.abs( ( -c.x - c.y - c.z)*( -(a.x - a2.x) - (a.y - a2.y) - (a.z - a2.z))*((b.x - b2.x) + (b.y - b2.y) + (b.z - b2.z)) ); _p = { "x": p.x + (a.x - da.x) + b2.x + (c.x - dc.x), "y": p.y + (a.y - da.y) + b2.y + (c.y - dc.y), "z": p.z + (a.z - da.z) + b2.z + (c.z - dc.z) }; _a = { "x": -c.x, "y": -c.y, "z": -c.z }; _b = { "x": -(a.x - a2.x), "y": -(a.y - a2.y), "z": -(a.z - a2.z) }; _c = { "x": b.x - b2.x, "y": b.y - b2.y, "z": b.z - b2.z }; if ((cur_idx <= dst_idx) && (dst_idx < nxt_idx)) { return gilbert_d2xyz_r(dst_idx, cur_idx, _p, _a, _b, _c); } cur_idx = nxt_idx; _p = { "x": p.x + (a.x - da.x) + (b2.x - db.x), "y": p.y + (a.y - da.y) + (b2.y - db.y), "z": p.z + (a.z - da.z) + (b2.z - db.z) }; _a = { "x": -b2.x, "y": -b2.y, "z": -b2.z }; _c = { "x": -(a.x - a2.x), "y": -(a.y - a2.y), "z": -(a.z - a2.z) }; return gilbert_d2xyz_r(dst_idx, cur_idx, _p, _a, c2, _c); } if (typeof module !== "undefined") { module.exports["d2xy"] = gilbert_xy2d; module.exports["xy2d"] = gilbert_d2xy; module.exports["d2xyz"] = gilbert_xyz2d; module.exports["xyz2d"] = gilbert_d2xyz; module.exports["main"] = _main; function _main(argv) { if (argv.length < 4) { console.log("provide args"); process.exit(-1); } let op = argv[1]; let w = parseInt(argv[2]); let h = parseInt(argv[3]); let d = 1; if (argv.length > 4) { d = parseInt(argv[4]); } if (op == "xy2d") { for (let x = 0; x < w; x++) { for (let y = 0; y < h; y++) { let idx = gilbert_xy2d(x,y,w,h); console.log(idx, x, y); } } } else if (op == "d2xy") { let n = w*h; for (let idx = 0; idx < n; idx++) { let xy = gilbert_d2xy(idx,w,h); console.log(xy.x,xy.y); } } else if (op == "xyz2d") { for (let x = 0; x < w; x++) { for (let y = 0; y < h; y++) { for (let z = 0; z < d; z++) { let idx = gilbert_xyz2d(x,y,z,w,h,d); console.log(idx,x,y,z); } } } } else if (op == "d2xyz") { let n = w*h*d; 
for (let idx = 0; idx < n; idx++) { let xyz = gilbert_d2xyz(idx,w,h,d); console.log(xyz.x,xyz.y,xyz.z); } } } //_main( process.argv.slice(1) ); } ================================================ FILE: models/layers/gilbert/test.py ================================================ ================================================ FILE: models/layers/gilbert/tests/runtests.sh ================================================ #!/bin/bash # # SPDX-License-Identifier: BSD-2-Clause # Copyright (c) 2018 abetusk ln -f -s ../gilbert2d.py . ln -f -s ../gilbert3d.py . ln -f -s ../gilbert_d2xy.py . ln -f -s ../gilbert_d2xyz.py . ln -f -s ../gilbert_xy2d.py . ln -f -s ../gilbert_xyz2d.py . ln -f -s ../ports/gilbert.js . ln -f -s ../ports/gilbert . pushd ../ports make popd gilbert_cmp2 () { local x=$1 local y=$2 echo -n "(python) xy2d[$x,$y]: " diff \ <( ./gilbert2d.py $x $y 2> /dev/null ) \ <( ./gilbert_xy2d.py $x $y | sort -n | cut -f2- -d' ' 2> /dev/null ) > /dev/null if [[ $? != 0 ]] ; then echo "FAIL" ; else echo "pass" ; fi echo -n "(python) d2xy[$x,$y]: " diff \ <( ./gilbert2d.py $x $y 2> /dev/null ) \ <( ./gilbert_d2xy.py $x $y 2> /dev/null ) > /dev/null if [[ $? != 0 ]] ; then echo "FAIL" ; else echo "pass" ; fi ### echo -n "(js) xy2d[$x,$y]: " diff \ <( ./gilbert2d.py $x $y 2> /dev/null ) \ <( node -e 'require("./gilbert.js").main(["gilbert.js","xy2d",'$x','$y']);' | sort -n | cut -f2- -d' ' 2> /dev/null ) > /dev/null if [[ $? != 0 ]] ; then echo "FAIL" ; else echo "pass" ; fi echo -n "(js) d2xy[$x,$y]: " diff \ <( ./gilbert2d.py $x $y 2> /dev/null ) \ <( node -e 'require("./gilbert.js").main(["gilbert.js","d2xy",'$x','$y']);' 2> /dev/null ) > /dev/null if [[ $? != 0 ]] ; then echo "FAIL" ; else echo "pass" ; fi ### echo -n "(c) xy2d[$x,$y]: " diff \ <( ./gilbert2d.py $x $y 2> /dev/null ) \ <( ./gilbert xy2d $x $y | sort -n | cut -f2- -d' ' 2> /dev/null ) > /dev/null if [[ $? 
!= 0 ]] ; then echo "FAIL" ; else echo "pass" ; fi echo -n "(c) d2xy[$x,$y]: " diff \ <( ./gilbert2d.py $x $y 2> /dev/null ) \ <( ./gilbert d2xy $x $y 2> /dev/null ) > /dev/null if [[ $? != 0 ]] ; then echo "FAIL" ; else echo "pass" ; fi } gilbert_cmp3 () { local x=$1 local y=$2 local z=$3 echo -n "(python) xyz2d[$x,$y,$z]: " # diff \ # <( ./gilbert3d.py $x $y $z 2> /dev/null ) \ # <( ./gilbert3d.py --op xyz2d $x $y $z | sort -n | cut -f2- -d' ' 2> /dev/null ) > /dev/null diff \ <( ./gilbert3d.py $x $y $z 2> /dev/null ) \ <( ./gilbert_xyz2d.py $x $y $z | sort -n | cut -f2- -d' ' 2> /dev/null ) > /dev/null if [[ $? != 0 ]] ; then echo "FAIL" ; else echo "pass" ; fi echo -n "(python) d2xyz[$x,$y,$z]: " # diff \ # <( ./gilbert3d.py $x $y $z 2> /dev/null ) \ # <( ./gilbert3d.py --op d2xyz $x $y $z 2> /dev/null ) > /dev/null diff \ <( ./gilbert3d.py $x $y $z 2> /dev/null ) \ <( ./gilbert_d2xyz.py $x $y $z 2> /dev/null ) > /dev/null if [[ $? != 0 ]] ; then echo "FAIL" ; else echo "pass" ; fi ### echo -n "(js) xyz2d[$x,$y,$z]: " diff \ <( ./gilbert3d.py $x $y $z 2> /dev/null ) \ <( node -e 'require("./gilbert.js").main(["gilbert.js","xyz2d",'$x','$y','$z']);' | sort -n | cut -f2- -d' ' 2> /dev/null ) > /dev/null if [[ $? != 0 ]] ; then echo "FAIL" ; else echo "pass" ; fi echo -n "(js) d2xyz[$x,$y,$z]: " diff \ <( ./gilbert3d.py $x $y $z 2> /dev/null ) \ <( node -e 'require("./gilbert.js").main(["gilbert.js","d2xyz",'$x','$y','$z']);' 2> /dev/null ) > /dev/null if [[ $? != 0 ]] ; then echo "FAIL" ; else echo "pass" ; fi ### echo -n "(c) xyz2d[$x,$y,$z]: " diff \ <( ./gilbert3d.py $x $y $z 2> /dev/null ) \ <( ./gilbert xyz2d $x $y $z | sort -n | cut -f2- -d' ' 2> /dev/null ) > /dev/null if [[ $? != 0 ]] ; then echo "FAIL" ; else echo "pass" ; fi echo -n "(c) d2xyz[$x,$y,$z]: " diff \ <( ./gilbert3d.py $x $y $z 2> /dev/null ) \ <( ./gilbert d2xyz $x $y $z 2> /dev/null ) > /dev/null if [[ $? 
!= 0 ]] ; then echo "FAIL" ; else echo "pass" ; fi } x=10 ; y=2 gilbert_cmp2 $x $y x=2 ; y=10 gilbert_cmp2 $x $y x=10 ; y=2 ; z=1 gilbert_cmp3 $x $y $z x=2 ; y=10 ; z=1 gilbert_cmp3 $x $y $z x=100 ; y=63 gilbert_cmp2 $x $y x=63 ; y=100 gilbert_cmp2 $x $y x=8 ; y=6 ; z=4 gilbert_cmp3 $x $y $z x=40 ; y=30 gilbert_cmp2 $x $y x=30 ; y=40 gilbert_cmp2 $x $y x=40 ; y=30 ; z=20 gilbert_cmp3 $x $y $z x=20 ; y=12 ; z=2 gilbert_cmp3 $x $y $z x=15 ; y=12 gilbert_cmp2 $x $y x=12 ; y=15 gilbert_cmp2 $x $y x=7 ; y=6 ; z=4 gilbert_cmp3 $x $y $z ================================================ FILE: models/layers/matching.py ================================================ import torch import torch.nn.functional as F def dice_loss(inputs, targets, num_boxes): """ Compute the DICE loss, similar to generalized IOU for masks Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). """ inputs = inputs.sigmoid() inputs = inputs.flatten(1) numerator = 2 * (inputs * targets).sum(1) denominator = inputs.sum(-1) + targets.sum(-1) loss = 1 - (numerator + 1) / (denominator + 1) return loss.sum() / num_boxes def ber_loss(inputs, targets, num_boxes): inputs = inputs.sigmoid() inputs = inputs.flatten(1) numerator = (inputs * targets).sum(1) # tp denominator = targets.sum(-1) loss = 1 - (numerator + 1) / (denominator + 1) return loss.sum() / num_boxes def ce_mask_loss(inputs, targets, num_boxes, weight=None): """ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Args: inputs: b=n_sigma thw targets: b=n_sigma thw (0 for the negative class and 1 for the positive class). 
Returns: Loss tensor """ # b hw ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") # weight: b hw if weight is not None: ce_loss = ce_loss * weight return ce_loss.mean(1).sum() / num_boxes def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): """ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). alpha: (optional) Weighting factor in range (0,1) to balance positive vs negative examples. Default = -1 (no weighting). gamma: Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. Returns: Loss tensor """ prob = inputs.sigmoid() ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") p_t = prob * targets + (1 - prob) * (1 - targets) loss = ce_loss * ((1 - p_t) ** gamma) if alpha >= 0: alpha_t = alpha * targets + (1 - alpha) * (1 - targets) loss = alpha_t * loss return loss.mean(1).sum() / num_boxes def batch_sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2): N, M = len(inputs), len(targets) inputs = inputs.flatten(1).unsqueeze(1).expand(-1, M, -1) # [N, M, THW] targets = targets.flatten(1).unsqueeze(0).expand(N, -1, -1) # [N, M, THW] prob = inputs.sigmoid() ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") p_t = prob * targets + (1 - prob) * (1 - targets) coef = ce_loss * ((1 - p_t) ** gamma) if alpha >= 0: alpha_t = alpha * targets + (1 - alpha) * (1 - targets) coef = alpha_t * coef return coef.mean(2) # [N, M] def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor): """ Compute the DICE loss, similar to generalized IOU for masks Args: inputs: A float tensor of arbitrary shape. The predictions for each example. 
targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). """ inputs = inputs.sigmoid() inputs = inputs.flatten(1) numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] loss = 1 - (numerator + 1) / (denominator + 1) return loss def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor,): """ Args: inputs: A float tensor of arbitrary shape. The predictions for each example. targets: A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). Returns: Loss tensor """ hw = inputs.shape[1] pos = F.binary_cross_entropy_with_logits( inputs, torch.ones_like(inputs), reduction="none" ) neg = F.binary_cross_entropy_with_logits( inputs, torch.zeros_like(inputs), reduction="none" ) loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( "nc,mc->nm", neg, (1 - targets) ) return loss / hw def get_src_permutation_idx(indices): # permute predictions following indices batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) src_idx = torch.cat([src for (src, _) in indices]) return batch_idx, src_idx def get_tgt_permutation_idx(indices): # permute targets following indices batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) tgt_idx = torch.cat([tgt for (_, tgt) in indices]) return batch_idx, tgt_idx ================================================ FILE: models/layers/position_encoding.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ Various positional encodings for the transformer. 
""" import math import torch from torch import nn from utils.misc import NestedTensor class PositionEmbeddingSine(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you need paper, generalized to work on images. """ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): super().__init__() self.num_pos_feats = num_pos_feats self.temperature = temperature self.normalize = normalize if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") if scale is None: scale = 2 * math.pi self.scale = scale def forward(self, tensor_list: NestedTensor): x = tensor_list.tensors mask = tensor_list.mask assert mask is not None not_mask = ~mask y_embed = not_mask.cumsum(1, dtype=torch.float32) x_embed = not_mask.cumsum(2, dtype=torch.float32) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) return pos class PositionEmbeddingLearned(nn.Module): """ Absolute pos embedding, learned. 
""" def __init__(self, num_pos_feats=256): super().__init__() self.row_embed = nn.Embedding(50, num_pos_feats) self.col_embed = nn.Embedding(50, num_pos_feats) self.reset_parameters() def reset_parameters(self): nn.init.uniform_(self.row_embed.weight) nn.init.uniform_(self.col_embed.weight) def forward(self, tensor_list: NestedTensor): x = tensor_list.tensors h, w = x.shape[-2:] i = torch.arange(w, device=x.device) j = torch.arange(h, device=x.device) x_emb = self.col_embed(i) y_emb = self.row_embed(j) pos = torch.cat([ x_emb.unsqueeze(0).repeat(h, 1, 1), y_emb.unsqueeze(1).repeat(1, w, 1), ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) return pos # dimension == 1 class PositionEmbeddingSine1D(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you need paper, generalized to work on images. """ def __init__(self, temperature=10000, normalize=True, scale=None): super().__init__() self.temperature = temperature self.normalize = normalize if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") if scale is None: scale = 2 * math.pi self.scale = scale def forward(self, mask, hidden_dim): device = mask.device num_pos_feats = hidden_dim assert mask is not None not_mask = ~mask x_embed = not_mask.cumsum(1, dtype=torch.float32) # [B, T] if self.normalize: eps = 1e-6 x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=device) dim_t = self.temperature ** (2 * (dim_t // 2) / num_pos_feats) pos_x = x_embed[:, :, None] / dim_t # [B, T, C] # n,c,t pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) pos = pos_x.permute(0, 2, 1) # [B, C, T] return pos # dimension == 3 class PositionEmbeddingSine3D(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you 
need paper, generalized to work on images. """ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): super().__init__() self.num_pos_feats = num_pos_feats self.temperature = temperature self.normalize = normalize if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") if scale is None: scale = 2 * math.pi self.scale = scale def forward(self, x, mask=None): # b, t, c, h, w assert x.dim() == 5, f"{x.shape} should be a 5-dimensional Tensor, got {x.dim()}-dimensional Tensor instead" if mask is None: mask = torch.zeros((x.size(0), x.size(1), x.size(3), x.size(4)), device=x.device, dtype=torch.bool) not_mask = ~mask z_embed = not_mask.cumsum(1, dtype=torch.float32) y_embed = not_mask.cumsum(2, dtype=torch.float32) x_embed = not_mask.cumsum(3, dtype=torch.float32) if self.normalize: eps = 1e-6 z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) dim_t_z = torch.arange((self.num_pos_feats * 2), dtype=torch.float32, device=x.device) dim_t_z = self.temperature ** (2 * (dim_t_z // 2) / (self.num_pos_feats * 2)) pos_x = x_embed[:, :, :, :, None] / dim_t pos_y = y_embed[:, :, :, :, None] / dim_t pos_z = z_embed[:, :, :, :, None] / dim_t_z pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4) pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4) pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4) pos = (torch.cat((pos_y, pos_x), dim=4) + pos_z).permute(0, 1, 4, 2, 3) # b, t, c, h, w return pos class PositionEmbeddingSine2D(nn.Module): """ This is a more standard version of 
the position embedding, very similar to the one used by the Attention is all you need paper, generalized to work on images. """ def __init__(self, temperature=10000, normalize=True, scale=None): super().__init__() self.temperature = temperature self.normalize = normalize if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") if scale is None: scale = 2 * math.pi self.scale = scale def forward(self, mask, hidden_dim: int): """ @param mask: a tensor of shape [B, H, W] @param hidden_dim: int @return: position embedding of the same shape [B, hidden_dim, H, W] """ num_pos_feats = hidden_dim // 2 not_mask = ~mask y_embed = not_mask.cumsum(1, dtype=torch.float32) x_embed = not_mask.cumsum(2, dtype=torch.float32) if self.normalize: eps = 1e-6 y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=mask.device) dim_t = self.temperature ** (2 * (dim_t // 2) / num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) pos = torch.cat((pos_y, pos_x), dim=3).permute(0,3,1,2) return pos # b c h w class PositionEmbeddingLearned1D(nn.Module): """ Absolute pos embedding, learned. 
""" def __init__(self, num_pos_feats=256): super().__init__() self.row_embed = nn.Embedding(50, num_pos_feats) self.reset_parameters() def reset_parameters(self): nn.init.uniform_(self.row_embed.weight) def forward(self, tensor_list: NestedTensor): """ Input: - tensor_list: NT(b s d, b s) """ x = tensor_list.tensors sequence_length = x.shape[-2] i = torch.arange(sequence_length, device=x.device) x_emb = self.row_embed(i) # s d pos = x_emb.unsqueeze(0).repeat(x.shape[0], 1, 1) return pos def build_position_encoding(hidden_dim=None, position_embedding_name='2d'): if position_embedding_name == 'original_2d': # TODO find a better way of exposing other arguments N_steps = hidden_dim // 2 return PositionEmbeddingSine(N_steps, normalize=True) elif position_embedding_name == 'learned_2d': N_steps = hidden_dim // 2 return PositionEmbeddingLearned(N_steps) elif position_embedding_name == 'learned_1d': assert hidden_dim is not None return PositionEmbeddingLearned1D(hidden_dim) elif position_embedding_name == '1d': return PositionEmbeddingSine1D(normalize=True) elif position_embedding_name == '2d': return PositionEmbeddingSine2D(normalize=True) elif position_embedding_name == '3d': N_steps = hidden_dim // 2 return PositionEmbeddingSine3D(N_steps, normalize=True) else: raise ValueError(f"not supported {position_embedding_name}") ================================================ FILE: models/layers/utils.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import copy def zero_module(module): """ Zero out the parameters of a module and return it. 
""" for p in module.parameters(): p.detach().zero_() return module def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) def _get_activation_fn(activation): """Return an activation function given a string""" if activation == "relu": return F.relu if activation == "gelu": return F.gelu if activation == "glu": return F.glu if activation == None: return None raise RuntimeError(F"activation should be relu/gelu, not {activation}.") def _get_activation_layer(activation): if activation == "relu": return nn.ReLU() if activation == "gelu": return nn.GELU() if activation == "glu": return nn.GLU() if activation == 'none': return nn.Identity() raise RuntimeError(F"activation should be relu/gelu, not {activation}.") def pad_1d_feats(feat_list): # list[ni c] -> b nmax c feat_len = [len(feat) for feat in feat_list] n_max = max(feat_len) batch_size = len(feat_list) pad_mask = torch.ones([batch_size, n_max], dtype=torch.bool, device=feat_list[0].device) for i in range(batch_size): feat_list[i] = F.pad(feat_list[i].clone(), pad=[0, 0, 0, n_max-feat_len[i]]) pad_mask[i, :feat_len[i]] = False feat_list = torch.stack(feat_list, dim=0) # b nmax c return feat_list, pad_mask ================================================ FILE: models/modality_input_mappers/__init__.py ================================================ from .hilbert_curve import ( HilbertCurve_FrameQuery ) ================================================ FILE: models/modality_input_mappers/hilbert_curve.py ================================================ from models.registry import MODELITY_INPUT_MAPPER_REGISTRY import logging import torch from models.layers.gilbert.gilbert2d import gilbert2d_widthBigger @MODELITY_INPUT_MAPPER_REGISTRY.register() class HilbertCurve_FrameQuery: def __init__(self, configs, ) -> None: self.frame_query_number = configs['frame_query_number'] def mapper(self, video): return { 'haosen': None, } def collate(self, list_of_haosen, batch_videos): batch_size, T = 
batch_videos.shape[:2] batch_size, T, _, H, W = batch_videos.shape hilbert_curve = list(gilbert2d_widthBigger(width=self.frame_query_number, height=T)) # list[(x(width), y(height))] hilbert_curve = torch.tensor(hilbert_curve).long() hilbert_curve = hilbert_curve[:, 1] * self.frame_query_number + hilbert_curve[:, 0] return { 'hilbert_curve': hilbert_curve, } ================================================ FILE: models/optimization/optimizer.py ================================================ from detectron2.solver.build import maybe_add_gradient_clipping from collections import OrderedDict from typing import Any, Dict, List, Set, Union, Iterable, Callable, Type, Optional import copy import itertools import torch from enum import Enum _GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]] _GradientClipper = Callable[[_GradientClipperInput], None] class GradientClipType(Enum): VALUE = "value" NORM = "norm" def maybe_add_full_model_gradient_clipping(optim, configs): # detectron2 doesn't have full model gradient clipping now clip_norm_val = configs['optim']['clip_gradients']['clip_value'] enable = ( configs['optim']['clip_gradients']['enabled'] and configs['optim']['clip_gradients']['clip_type'] == "full_model" and configs['optim']['clip_gradients']['clip_value'] > 0.0 ) class FullModelGradientClippingOptimizer(optim): def step(self, closure=None): all_params = itertools.chain(*[x["params"] for x in self.param_groups]) torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) super().step(closure=closure) return FullModelGradientClippingOptimizer if enable else optim def _create_gradient_clipper(cfg) -> _GradientClipper: """ Creates gradient clipping closure to clip by value or by norm, according to the provided config. 
""" cfg = copy.deepcopy(cfg) def clip_grad_norm(p: _GradientClipperInput): torch.nn.utils.clip_grad_norm_(p, cfg['clip_value'], cfg['norm_type']) def clip_grad_value(p: _GradientClipperInput): torch.nn.utils.clip_grad_value_(p, cfg['clip_value']) _GRADIENT_CLIP_TYPE_TO_CLIPPER = { GradientClipType.VALUE: clip_grad_value, GradientClipType.NORM: clip_grad_norm, } return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg['clip_type'])] def _generate_optimizer_class_with_gradient_clipping( optimizer: Type[torch.optim.Optimizer], *, per_param_clipper: Optional[_GradientClipper] = None, global_clipper: Optional[_GradientClipper] = None, ) -> Type[torch.optim.Optimizer]: """ Dynamically creates a new type that inherits the type of a given instance and overrides the `step` method to add gradient clipping """ assert ( per_param_clipper is None or global_clipper is None ), "Not allowed to use both per-parameter clipping and global clipping" def optimizer_wgc_step(self, closure=None): if per_param_clipper is not None: for group in self.param_groups: for p in group["params"]: per_param_clipper(p) else: # global clipper for future use with detr # (https://github.com/facebookresearch/detr/pull/287) all_params = itertools.chain(*[g["params"] for g in self.param_groups]) global_clipper(all_params) super(type(self), self).step(closure) OptimizerWithGradientClip = type( optimizer.__name__ + "WithGradientClip", (optimizer,), {"step": optimizer_wgc_step}, ) return OptimizerWithGradientClip def maybe_add_gradient_clipping( configs: dict, optimizer: torch.optim.Optimizer): """ If gradient clipping is enabled through config options, wraps the existing optimizer type to become a new dynamically created class OptimizerWithGradientClip that inherits the given optimizer and overrides the `step` method to include gradient clipping. Args: cfg: CfgNode, configuration options optimizer: type. 
A subclass of torch.optim.Optimizer Return: type: either the input `optimizer` (if gradient clipping is disabled), or a subclass of it with gradient clipping included in the `step` method. """ if not configs['optim']['clip_gradients']['enabled']: return optimizer if isinstance(optimizer, torch.optim.Optimizer): optimizer_type = type(optimizer) else: assert issubclass(optimizer, torch.optim.Optimizer), optimizer optimizer_type = optimizer grad_clipper = _create_gradient_clipper(configs['optim']['clip_gradients']) OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping( optimizer_type, per_param_clipper=grad_clipper ) if isinstance(optimizer, torch.optim.Optimizer): optimizer.__class__ = OptimizerWithGradientClip # a bit hacky, not recommended return optimizer else: return OptimizerWithGradientClip def get_optimizer(params, configs): optimizer_type = configs['optim']['name'] base_lr = configs['optim']['base_lr'] weight_decay = configs['optim']['weight_decay'] if 'weight_decay' in configs['optim'] else configs['optim']['base_wd'] if optimizer_type == "AdamW": optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW, configs)( params, base_lr, weight_decay=weight_decay, ) else: raise NotImplementedError(f"no optimizer type {optimizer_type}") if configs['optim']['clip_gradients']['clip_type'] != "full_model": optimizer = maybe_add_gradient_clipping(configs, optimizer) return optimizer ================================================ FILE: models/optimization/scheduler.py ================================================ import torch from functools import partial import logging import numpy as np def build_scheduler(configs, optimizer): name = configs['optim']['scheduler']['name'] scheduler_configs = configs['optim']['scheduler'] if name == 'multistep_lr': scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=scheduler_configs['milestones'], gamma=scheduler_configs['gamma'], verbose=scheduler_configs['verbose'],) return 
scheduler else: raise ValueError() ================================================ FILE: models/registry.py ================================================ _model_entrypoints = {} def register_model(fn): model_name = fn.__name__ if model_name in _model_entrypoints: raise ValueError(f'model name {model_name} has been registered') _model_entrypoints[model_name] = fn return fn def model_entrypoint(model_name): try: return _model_entrypoints[model_name] except KeyError as e: print(f'Model Name {model_name} not found') from detectron2.utils.registry import Registry MODELITY_INPUT_MAPPER_REGISTRY = Registry("MODELITY_INPUT_MAPPER") ================================================ FILE: output/VIS/cvc/pvt.py ================================================ from copy import deepcopy as dcopy import numpy as np frame_sampler = { 'name': 'VIS_Video_or_Step_To_Clip_TrainMapper', 'frames_sampler': { 'name': 'Naive_ReferenceFrame_FrameSampler', 'clip_sizes': [6], 'clip_distribute': 'local_global', 'clip_position': 'center',}, 'clip_global_targets_map_to_local_targets': True, # 把整个视频中这个clip没出现的物体消除 'augmentation': {'name': 'WeakPolyP_TrainAug'},} test_mapper_evaluator = { 'mapper': {'name': 'VIS_Video_EvalMapper', 'augmentation': {'name': 'WeakPolyP_EvalAug'},}, 'evaluator':{'name': 'VIS_Evaluator_FrameFast', 'frame_metrics': [('mask_dice_iou', {}), ('web', {})], 'video_metrics': [], 'metrics_aggregator': ('polyp_metric_aggregator', {}),},} attention_defaults = { 'attn': { 'dropout': 0.1, 'nheads': 8, 'dim_feedforward': 2048, 'activation': 'relu', 'normalize_before': False, 'enforce_input_proj': True, # try 每个module对input进行proj }, 'deform_attn':{ 'nheads': 8, 'dim_feedforward': 1024, 'activation': 'relu', 'dropout': 0., 'enc_n_points': 4 }, } d_model = 64 trainer_configs = { 'eval_seed': 2024, 'model_schedule_seed': 2024, 'stream_idx_seed': 2024, 'initckpt':{'path': '', 'load_schedule': False, 'load_model': True, 'load_optimizer': False, 'load_random': False, 'eval_init_ckpt': 
False,},
    'data': {
        'evaluate': {
            '300-tv': dcopy(test_mapper_evaluator),
            '612-test': dcopy(test_mapper_evaluator),
            '612-val': dcopy(test_mapper_evaluator),},
        'train': {  # 3292 clips
            'Kvasir-train_step[1]': {
                'mapper': {
                    'name': 'VIS_Video_or_Step_To_Clip_TrainMapper',
                    'frames_sampler': {
                        'name': 'Naive_ReferenceFrame_FrameSampler',
                        'clip_sizes': [1],
                        'clip_distribute': 'local_global',
                        'clip_position': 'center',
                    },
                    'clip_global_targets_map_to_local_targets': True,
                    # clip_sizes is [1] here; presumably this augmentation expands the
                    # single frame into a num_frames clip — TODO confirm
                    'augmentation': {'name': 'WeakPolyP_TrainAug_RotateImageToClip', 'num_frames': 6},
                },
            },
            'Mayo-train_step[6]': {'mapper': dcopy(frame_sampler),},
            '300-train_step[6]': {'mapper': dcopy(frame_sampler),},
            '612-train_step[6]': {'mapper': dcopy(frame_sampler),},
            'polyp_train_step[6]': {
                'mapper': {
                    'name': 'VIS_Video_or_Step_To_Clip_TrainMapper',
                    'frames_sampler': {
                        'name': 'Naive_ReferenceFrame_FrameSampler',
                        'clip_sizes': [6],
                        'clip_distribute': 'local_global',
                        'clip_position': 'center',
                    },
                    'clip_global_targets_map_to_local_targets': True,  # drop objects of the whole video that do not appear in this clip
                    'augmentation': {'name': 'WeakPolyP_TrainAug'},
                },
            },
        },
    },
    'optim': {
        'splits': [0, None],
        'batch_sizes': [4],
        'ckpted_iters': 1260,
        'one_batch_two_epoch': 'just_use',
        # LR milestones are multiples of ckpted_iters (1260)
        'scheduler': {'name': 'multistep_lr', 'milestones': [1260*3, 1260*6, 1260*9, 1260*12],
                      'gamma': 0.5, 'verbose': False},
        'name': 'AdamW',
        'base_lr': 1e-3,
        'backbone_lr_multiplier': 0.1,
        'weight_decay': 1e-4,
        'weight_decay_embed': 0.0,
        'weight_decay_norm': 0.0,
        'clip_gradients': {
            'clip_type': 'full_model',  # alternatives: 'norm' / 'value' (value clipping ~ grad.data.clamp_)
            'clip_value': 0.01,
            'enabled': True,
            'norm_type': 2.0
        },
    },
    'model': {
        'name': 'backbone_encoder_decoder_withScaleConsistency',
        'input_aux': {'video_auxes': [{'name': 'HilbertCurve_FrameQuery', 'frame_query_number': 20}],
                      'targets_auxes': [],},
        'test_clip_size': None,
        "video_backbone": {'name': 'Video2D_PVT_V2', 'freeze': False, },
        'fusion': {
            'name': 'Video_Deform2D_DividedTemporal_MultiscaleEncoder_localGlobal',
            'd_model': d_model,
            'video_projs': {
                'name': 'VideoConv_MultiscaleProj',
                'projs': {
                    'res3': {'kernel_size': 1, 'bias': False, 'norm': 'gn_32'},
                    'res4': {'kernel_size': 1, 'bias': False, 'norm': 'gn_32'},
                    'res5': {'kernel_size': 1, 'bias': False, 'norm': 'gn_32'},
                },
            },
            'nlayers': 3,
            'encoded_scales': ['res3', 'res4', 'res5'],
            'fpn_norm': 'GN',
            'deform_attn': dcopy(attention_defaults['deform_attn']),
            'frame_nqueries': 20,
            'add_local': True,
            'local_configs': {'d_model': d_model, 'num_heads': 8, 'kernel_size': 5, 'dilation': 1,
                              'dropout': 0.0, 'num_steps': 1},
            'add_global': True,
            'global_configs': {'d_model': d_model, 'dim_feedforward': 2048, 'dropout': 0.0,
                               'scan_order': 'hilbert', 'd_state': 16, 'd_conv': 3, 'nlayers': 3,
                               'add_attn_mask': False}
        },
        'decoder': {
            'name': 'Video_MaskedAttn_MultiscaleMaskDecoder_v3',
            'd_model': d_model,
            'attn': dcopy(attention_defaults['attn']),
            'video_nqueries': 10,
            'inputs_projs': None,
            'nlayers': 3,
            'memory_scales': ['res5', 'res4', 'res3'],
            'mask_scale': 'res2',
            'num_classes': 1,
            'head_outputs': ['mask', 'class'],  # polygon
            'temporal_self_layer': {
                'name': 'FrameQuery_SS2DLayer_v2',
                'd_model': d_model,
                'nlayers': 3,
                'dropout': 0.0,
                'd_state': 16,
                'd_conv': 3,
                'dim_feedforward': 2048,
            },
            'temporal_cross_layer': {
                'name': 'TemporalQuery_CrossSelf',
                'd_model': d_model,
                'attn': dcopy(attention_defaults['attn']),
            },
            'loss': {
                'losses': {
                    'point_mask_dice_ce': {'num_points': 12544, 'oversample_ratio': 3.0,
                                           'importance_sample_ratio': 0.75},
                    'class_ce': {},
                },
                'matching_metrics': {
                    'class_prob': {'prob': 2},
                    # 'mask_dice_ce': {'ce': 2, 'dice': 5},
                    'point_mask_dice_ce': {'ce': 2, 'dice': 2, 'num_points': 12544}
                },
                'aux_layer_weights': 1.,  # int (shared) or list (per layer)
                'background_cls_eos': 0.1,
            },
        },
        'loss_weight': {'mask_dice': 5, 'mask_ce': 2, 'class_ce': 2},
    },
}

================================================
FILE: output/VIS/fibroid/pvt.py
================================================
from copy import deepcopy as dcopy
import numpy as np

attention_defaults = {
    'attn': {
        'dropout': 0.1,
        'nheads': 8,
        'dim_feedforward': 2048,
        'activation': 'relu',
        'normalize_before': False,
        'enforce_input_proj': True,  # try: have every module project its input
    },
    'deform_attn': {
        'nheads': 8,
        'dim_feedforward': 1024,
        'activation': 'relu',
        'dropout': 0.,
        'enc_n_points': 4
    },
}
d_model = 64
trainer_configs = {
    'eval_seed': 2024,
    'model_schedule_seed': 2024,
    'stream_idx_seed': 2024,
    'initckpt': {'path': '', 'load_schedule': False, 'load_model': True, 'load_optimizer': False,
                 'load_random': False, 'eval_init_ckpt': False,},
    'data': {
        'evaluate': {
            'fibroid_validate_temp8': {
                'mapper': {'name': 'VIS_Video_EvalMapper', 'augmentation': {'name': 'WeakPolyP_EvalAug'}, },
                'evaluator': {'name': 'VIS_Evaluator_FrameFast',
                              'frame_metrics': [('mask_dice_iou', {}), ('web', {})], 'video_metrics': [],
                              'metrics_aggregator': ('polyp_metric_aggregator', {}),
                              },
            },
        },
        'train': {
            'fibroid_train_temp8_step[6]': {
                'mapper': {
                    'name': 'VIS_Video_or_Step_To_Clip_TrainMapper',
                    'frames_sampler': {
                        'name': 'Naive_ReferenceFrame_FrameSampler',
                        'clip_sizes': [6],
                        'clip_distribute': 'local_global',
                        'clip_position': 'center',
                    },
                    'clip_global_targets_map_to_local_targets': True,  # drop objects of the whole video that do not appear in this clip
                    'augmentation': {'name': 'WeakPolyP_TrainAug'},
                },
            },
        },
    },
    'optim': {
        'splits': [0, None],
        'batch_sizes': [4],
        'ckpted_iters': 186,
        'one_batch_two_epoch': 'just_use',
        # LR milestones are multiples of ckpted_iters (186)
        'scheduler': {'name': 'multistep_lr', 'milestones': [186*3, 186*6, 186*9, 186*12],
                      'gamma': 0.5, 'verbose': False},
        'name': 'AdamW',
        'base_lr': 1e-3,
        'backbone_lr_multiplier': 0.1,
        'weight_decay': 1e-4,
        'weight_decay_embed': 0.0,
        'weight_decay_norm': 0.0,
        'clip_gradients': {
            'clip_type': 'full_model',  # alternatives: 'norm' / 'value' (value clipping ~ grad.data.clamp_)
            'clip_value': 0.01,
            'enabled': True,
            'norm_type': 2.0
        },
    },
    'model': {
        'name': 'backbone_encoder_decoder_withScaleConsistency',
        'input_aux': {'video_auxes': [{'name': 'HilbertCurve_FrameQuery', 'frame_query_number': 20}],
                      'targets_auxes': [],},
        'test_clip_size': None,
        "video_backbone": {'name': 'Video2D_PVT_V2', 'freeze': False, },
        'fusion': {
            'name': 'Video_Deform2D_DividedTemporal_MultiscaleEncoder_localGlobal',
            'd_model': d_model,
            'video_projs': {
                'name': 'VideoConv_MultiscaleProj',
                'projs': {
                    'res3': {'kernel_size': 1, 'bias': False, 'norm': 'gn_32'},
                    'res4': {'kernel_size': 1, 'bias': False, 'norm': 'gn_32'},
                    'res5': {'kernel_size': 1, 'bias': False, 'norm': 'gn_32'},
                },
            },
            'nlayers': 3,
            'encoded_scales': ['res3', 'res4', 'res5'],
            'fpn_norm': 'GN',
            'deform_attn': dcopy(attention_defaults['deform_attn']),
            'frame_nqueries': 20,
            'add_local': True,
            'local_configs': {'d_model': d_model, 'num_heads': 8, 'kernel_size': 5, 'dilation': 2,
                              'dropout': 0.0, 'num_steps': 1},
            'add_global': True,
            'global_configs': {'d_model': d_model, 'dim_feedforward': 2048, 'dropout': 0.0,
                               'scan_order': 'hilbert', 'd_state': 16, 'd_conv': 3, 'nlayers': 3,
                               'add_attn_mask': True}
        },
        'decoder': {
            'name': 'Video_MaskedAttn_MultiscaleMaskDecoder_v3',
            'd_model': d_model,
            'attn': dcopy(attention_defaults['attn']),
            'video_nqueries': 10,
            'inputs_projs': None,
            'nlayers': 3,
            'memory_scales': ['res5', 'res4', 'res3'],
            'mask_scale': 'res2',
            'num_classes': 1,
            'head_outputs': ['mask', 'class'],  # polygon
            'temporal_self_layer': {
                'name': 'FrameQuery_SS2DLayer_v2',
                'd_model': d_model,
                'nlayers': 3,
                'dropout': 0.0,
                'd_state': 16,
                'd_conv': 3,
                'dim_feedforward': 2048,
            },
            'temporal_cross_layer': {
                'name': 'TemporalQuery_CrossSelf',
                'd_model': d_model,
                'attn': dcopy(attention_defaults['attn']),
            },
            'loss': {
                'losses': {
                    'point_mask_dice_ce': {'num_points': 12544, 'oversample_ratio': 3.0,
                                           'importance_sample_ratio': 0.75},
                    'class_ce': {},
                },
                'matching_metrics': {
                    'class_prob': {'prob': 2},
                    # 'mask_dice_ce': {'ce': 2, 'dice': 5},
                    'point_mask_dice_ce': {'ce': 2, 'dice': 2, 'num_points': 12544}
                },
                'aux_layer_weights': 1.,  # int (shared) or list (per layer)
                'background_cls_eos': 0.1,
            },
        },
        'loss_weight': {'mask_dice': 5, 'mask_ce': 2, 'class_ce': 2},
    },
}

================================================
FILE: output/VIS/sunseg/pvt/pvt.py
================================================
from copy import deepcopy as dcopy
import numpy as np

attention_defaults = {
    'attn': {
        'dropout': 0.1,
        'nheads': 8,
        'dim_feedforward': 2048,
        'activation': 'relu',
        'normalize_before': False,
        'enforce_input_proj': True,
    },
    'deform_attn': {
        'nheads': 8,
        'dim_feedforward': 1024,
        'activation': 'relu',
        'dropout': 0.,
        'enc_n_points': 4
    },
}
d_model = 64
trainer_configs = {
    'eval_seed': 2024,
    'model_schedule_seed': 2024,
    'stream_idx_seed': 2024,
    'initckpt': {'path': '', 'load_schedule': False, 'load_model': True, 'load_optimizer': False,
                 'load_random': False, 'eval_init_ckpt': False,},
    'data': {
        'evaluate': {
            'polyp_hard_unseen_validate': {
                'mapper': {'name': 'VIS_Video_EvalMapper', 'augmentation': {'name': 'WeakPolyP_EvalAug'}, },
                'evaluator': {'name': 'VIS_Evaluator_FrameFast',
                              'frame_metrics': [('mask_dice_iou', {}), ('web', {})], 'video_metrics': [],
                              'metrics_aggregator': ('polyp_metric_aggregator', {}),
                              },
            },
            'polyp_easy_unseen_validate': {
                'mapper': {'name': 'VIS_Video_EvalMapper', 'augmentation': {'name': 'WeakPolyP_EvalAug'}, },
                'evaluator': {'name': 'VIS_Evaluator_FrameFast',
                              'frame_metrics': [('mask_dice_iou', {}), ('web', {})], 'video_metrics': [],
                              'metrics_aggregator': ('polyp_metric_aggregator', {}),
                              },
            },
            # 'polyp_hard_validate': {
            #     'mapper': {'name': 'VIS_Video_EvalMapper', 'augmentation': {'name': 'WeakPolyP_EvalAug'}, },
            #     'evaluator': {'name': 'VIS_Evaluator_FrameFast',
            #                   'frame_metrics': [('mask_dice_iou', {}), ('web', {})], 'video_metrics': [],
            #                   'metrics_aggregator': ('polyp_metric_aggregator', {}),
            #                   },
            # },
            # 'polyp_easy_validate': {
            #     'mapper': {'name': 'VIS_Video_EvalMapper', 'augmentation': {'name': 'WeakPolyP_EvalAug'}, },
            #     'evaluator': {'name': 'VIS_Evaluator_FrameFast',
            #                   'frame_metrics': [('mask_dice_iou', {}), ('web', {})], 'video_metrics': [],
            #                   'metrics_aggregator': ('polyp_metric_aggregator', {}),
            #                   },
            # },
        },
        'train': {  # 3292 clips
            'polyp_train_step[6]': {
                'mapper': {
                    'name': 'VIS_Video_or_Step_To_Clip_TrainMapper',
                    'frames_sampler': {
                        'name':
'Naive_ReferenceFrame_FrameSampler',
                        'clip_sizes': [6],
                        'clip_distribute': 'local_global',
                        'clip_position': 'center',
                    },
                    'clip_global_targets_map_to_local_targets': True,  # drop objects of the whole video that do not appear in this clip
                    'augmentation': {'name': 'WeakPolyP_TrainAug'},
                },
            },
        },
    },
    'optim': {
        'splits': [0, None],
        'batch_sizes': [4],
        # checkpoint interval is 823*8 iterations, while the LR milestones below are multiples of 823
        'ckpted_iters': 823*8,
        'one_batch_two_epoch': 'just_use',
        'scheduler': {'name': 'multistep_lr', 'milestones': [823*3, 823*6, 823*9, 823*12],
                      'gamma': 0.5, 'verbose': False},
        'name': 'AdamW',
        'base_lr': 1e-3,
        'backbone_lr_multiplier': 0.1,
        'weight_decay': 1e-4,
        'weight_decay_embed': 0.0,
        'weight_decay_norm': 0.0,
        'clip_gradients': {
            'clip_type': 'full_model',
            'clip_value': 0.01,
            'enabled': True,
            'norm_type': 2.0
        },
    },
    'model': {
        'name': 'backbone_encoder_decoder_withScaleConsistency',
        'input_aux': {'video_auxes': [{'name': 'HilbertCurve_FrameQuery', 'frame_query_number': 20}],
                      'targets_auxes': [],},
        'test_clip_size': None,
        "video_backbone": {'name': 'Video2D_PVT_V2', 'freeze': False, },
        'fusion': {
            'name': 'Video_Deform2D_DividedTemporal_MultiscaleEncoder_localGlobal',
            'd_model': d_model,
            'video_projs': {
                'name': 'VideoConv_MultiscaleProj',
                'projs': {
                    'res3': {'kernel_size': 1, 'bias': False, 'norm': 'gn_32'},
                    'res4': {'kernel_size': 1, 'bias': False, 'norm': 'gn_32'},
                    'res5': {'kernel_size': 1, 'bias': False, 'norm': 'gn_32'},
                },
            },
            'nlayers': 3,
            'encoded_scales': ['res3', 'res4', 'res5'],
            'fpn_norm': 'GN',
            'deform_attn': dcopy(attention_defaults['deform_attn']),
            'frame_nqueries': 20,
            'add_local': True,
            'local_configs': {'d_model': d_model, 'num_heads': 8, 'kernel_size': 5, 'dilation': 1,
                              'dropout': 0.0, 'num_steps': 1},
            'add_global': True,
            'global_configs': {'d_model': d_model, 'dim_feedforward': 2048, 'dropout': 0.0,
                               'scan_order': 'hilbert', 'd_state': 16, 'd_conv': 3, 'nlayers': 3,
                               'add_attn_mask': False}
        },
        'decoder': {
            'name': 'Video_MaskedAttn_MultiscaleMaskDecoder_v3',
            'd_model': d_model,
            'attn': dcopy(attention_defaults['attn']),
            'video_nqueries': 10,
            'inputs_projs': None,
            'nlayers': 3,
            'memory_scales': ['res5', 'res4', 'res3'],
            'mask_scale': 'res2',
            'num_classes': 1,
            'head_outputs': ['mask', 'class'],  # polygon
            'temporal_self_layer': {
                'name': 'FrameQuery_SS2DLayer_v2',
                'd_model': d_model,
                'nlayers': 3,
                'dropout': 0.0,
                'd_state': 16,
                'd_conv': 3,
                'dim_feedforward': 2048,
            },
            'temporal_cross_layer': {
                'name': 'TemporalQuery_CrossSelf',
                'd_model': d_model,
                'attn': dcopy(attention_defaults['attn']),
            },
            'loss': {
                'losses': {
                    'point_mask_dice_ce': {'num_points': 12544, 'oversample_ratio': 3.0,
                                           'importance_sample_ratio': 0.75},
                    'class_ce': {},
                },
                'matching_metrics': {
                    'class_prob': {'prob': 2},
                    # 'mask_dice_ce': {'ce': 2, 'dice': 5},
                    'point_mask_dice_ce': {'ce': 2, 'dice': 2, 'num_points': 12544}
                },
                'aux_layer_weights': 1.,  # int (shared) or list (per layer)
                'background_cls_eos': 0.1,
            },
        },
        'loss_weight': {'mask_dice': 5, 'mask_ce': 5, 'class_ce': 2},
    },
}

================================================
FILE: output/VIS/sunseg/res/res.py
================================================
from copy import deepcopy as dcopy
import numpy as np

attention_defaults = {
    'attn': {
        'dropout': 0.1,
        'nheads': 8,
        'dim_feedforward': 2048,
        'activation': 'relu',
        'normalize_before': False,
        'enforce_input_proj': True,  # try: have every module project its input
    },
    'deform_attn': {
        'nheads': 8,
        'dim_feedforward': 1024,
        'activation': 'relu',
        'dropout': 0.,
        'enc_n_points': 4
    },
}
d_model = 128
trainer_configs = {
    'eval_seed': 2024,
    'model_schedule_seed': 2024,
    'stream_idx_seed': 2024,
    'initckpt': {'path': '', 'load_schedule': False, 'load_model': True, 'load_optimizer': False,
                 'load_random': False, 'eval_init_ckpt': False,},
    'data': {
        'evaluate': {
            'polyp_hard_validate': {
                'mapper': {'name': 'VIS_Video_EvalMapper', 'augmentation': {'name': 'WeakPolyP_EvalAug'}, },
                'evaluator': {'name': 'VIS_Evaluator_FrameFast',
                              'frame_metrics': [('mask_dice_iou', {}), ('web', {})], 'video_metrics': [],
                              'metrics_aggregator': ('polyp_metric_aggregator', {}),
                              },
            },
            'polyp_easy_validate': {
                'mapper': {'name': 'VIS_Video_EvalMapper', 'augmentation': {'name': 'WeakPolyP_EvalAug'}, },
                'evaluator': {'name': 'VIS_Evaluator_FrameFast',
                              'frame_metrics': [('mask_dice_iou', {}), ('web', {})], 'video_metrics': [],
                              'metrics_aggregator': ('polyp_metric_aggregator', {}),
                              },
            },
        },
        'train': {  # 3292 clips
            'polyp_train_step[6]': {
                'mapper': {
                    'name': 'VIS_Video_or_Step_To_Clip_TrainMapper',
                    'frames_sampler': {
                        'name': 'Naive_ReferenceFrame_FrameSampler',
                        'clip_sizes': [6],
                        'clip_distribute': 'local_global',
                        'clip_position': 'center',
                    },
                    'clip_global_targets_map_to_local_targets': True,
                    'augmentation': {'name': 'WeakPolyP_TrainAug'},
                },
            },
        },
    },
    'optim': {
        'splits': [0, None],
        'batch_sizes': [4],
        'ckpted_iters': 823,
        'one_batch_two_epoch': 'just_use',
        # LR milestones are multiples of ckpted_iters (823)
        'scheduler': {'name': 'multistep_lr', 'milestones': [823*3, 823*6, 823*9, 823*12],
                      'gamma': 0.5, 'verbose': False},
        'name': 'AdamW',
        'base_lr': 1e-3,
        'backbone_lr_multiplier': 0.1,
        'weight_decay': 1e-4,
        'weight_decay_embed': 0.0,
        'weight_decay_norm': 0.0,
        'clip_gradients': {
            'clip_type': 'full_model',
            'clip_value': 0.01,
            'enabled': True,
            'norm_type': 2.0
        },
    },
    'model': {
        'test_clip_size': None,
        'name': 'backbone_encoder_decoder_withScaleConsistency',
        'input_aux': {'video_auxes': [{'name': 'HilbertCurve_FrameQuery', 'frame_query_number': 20}],
                      'targets_auxes': [],},
        "video_backbone": {'name': 'Res2Net_50_EachFrame', 'freeze': False, },
        'fusion': {
            'name': 'Video_Deform2D_DividedTemporal_MultiscaleEncoder_localGlobal',
            'd_model': d_model,
            'video_projs': {
                'name': 'VideoConv_MultiscaleProj',
                'projs': {
                    'res3': {'kernel_size': 1, 'bias': False, 'norm': 'gn_32'},
                    'res4': {'kernel_size': 1, 'bias': False, 'norm': 'gn_32'},
                    'res5': {'kernel_size': 1, 'bias': False, 'norm': 'gn_32'},
                },
            },
            'nlayers': 3,
            'encoded_scales': ['res3', 'res4', 'res5'],
            'fpn_norm': 'GN',
            'deform_attn': dcopy(attention_defaults['deform_attn']),
            'frame_nqueries': 20,
            'add_local': True,
            'local_configs': {'d_model': d_model, 'num_heads': 8, 'kernel_size': 5, 'dilation': 1,
'dropout': 0.0,'num_steps': 1}, 'add_global': True, 'global_configs': {'d_model': d_model, 'dim_feedforward': 2048, 'dropout': 0.0, 'scan_order': 'hilbert', 'd_state': 16, 'd_conv': 3, 'nlayers': 3, 'add_attn_mask': False} }, 'decoder':{ 'name': 'Video_MaskedAttn_MultiscaleMaskDecoder_v3', 'd_model': d_model, 'attn': dcopy(attention_defaults['attn']), 'video_nqueries': 10, 'inputs_projs': None, 'nlayers': 3, 'memory_scales': ['res5','res4','res3'], 'mask_scale': 'res2', 'num_classes': 1, 'head_outputs': ['mask', 'class'], # polygon 'temporal_self_layer': { 'name': 'FrameQuery_SS2DLayer_v2', 'd_model': d_model, 'nlayers': 3, 'dropout': 0.0, 'd_state': 16, 'd_conv': 3, 'dim_feedforward': 2048, }, 'temporal_cross_layer': { 'name': 'TemporalQuery_CrossSelf', 'd_model': d_model, 'attn': dcopy(attention_defaults['attn']), }, 'loss':{ 'losses': { 'point_mask_dice_ce': {'num_points': 12544, 'oversample_ratio':3.0, 'importance_sample_ratio': 0.75}, 'class_ce': {}, }, 'matching_metrics': { 'class_prob': {'prob': 2}, # 'mask_dice_ce': {'ce': 2, 'dice': 5}, 'point_mask_dice_ce': {'ce':2, 'dice':2, 'num_points':12544} }, 'aux_layer_weights': 1., # int/list 'background_cls_eos': 0.1, }, }, 'loss_weight': {'mask_dice': 5, 'mask_ce': 2, 'class_ce':2}, }, } ================================================ FILE: reorganize_sunseg.py ================================================ import os, shutil, glob from tqdm import tqdm SUN_root = f"{os.getenv('DATASET_PATH')}/SUN-SEG/SUN-Positive/" SUNSEG_root = f"{os.getenv('DATASET_PATH')}/SUN-SEG/SUN-SEG-Annotation/" SUN_split_dict = {} SUNSEG_split_dict = {} SUNSEG_dataset_dict = {} image_list = [] # SUN_list = glob.glob(SUN_root + '*/*.jpg') SUNSEG_test_list = glob.glob(SUNSEG_root + 'Test*/*/GT/*/*.png') SUNSEG_train_list = glob.glob(SUNSEG_root + 'TrainDataset/GT/*/*.png') SUNSEG_list = SUNSEG_test_list + SUNSEG_train_list SUN_list = [os.path.join(SUN_root, name.split('/')[-2].split('_')[0] if len(name.split('/')[-2].split('_')) > 1 
else name.split('/')[-2], name.split('/')[-1].replace('.png', '')) for name in SUNSEG_list] for SUN_path, SUNSEG_path in zip(SUN_list, SUNSEG_list): """ @func: Get SUN and SUN-SEG case-to-image structure in a dictionary """ SUN_case_name, SUN_image_name = SUN_path.split('/')[-2], SUN_path.split('/')[-1] SUNSEG_dataset_name, SUNSEG_case_name, SUNSEG_image_name = SUNSEG_path.split('SUN-SEG-Annotation/')[1].split('/GT')[0], SUNSEG_path.split('/')[-2], SUNSEG_path.split('/')[-1].rstrip('.png') SUN_split_dict[SUN_image_name] = SUN_case_name SUNSEG_split_dict[SUNSEG_image_name] = SUNSEG_case_name SUNSEG_dataset_dict[SUNSEG_image_name] = SUNSEG_dataset_name image_list.append(SUN_image_name) for image in tqdm(image_list): """ @func: Change original SUN's structure """ SUN_case = SUN_split_dict[image] SUNSEG_case = SUNSEG_split_dict[image] dataset_split = SUNSEG_dataset_dict[image] os.makedirs(os.path.join(SUNSEG_root, dataset_split, 'Frame', SUNSEG_case), exist_ok=True) shutil.move(os.path.join(SUN_root, SUN_case, image + '.jpg'), os.path.join(SUNSEG_root, dataset_split, 'Frame', SUNSEG_case, image + '.jpg')) # combine Seen/Unseen to same directory os.makedirs(os.path.join(os.getenv('DATASET_PATH'), 'SUN-SEG/SUN-SEG-Annotation', 'TestEasyDataset', 'combine/Frame'), exist_ok=True) os.makedirs(os.path.join(os.getenv('DATASET_PATH'), 'SUN-SEG/SUN-SEG-Annotation', 'TestEasyDataset', 'combine/GT'), exist_ok=True) os.system('mv $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestEasyDataset/Seen/Frame/* $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestEasyDataset/combine/Frame/') os.system('mv $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestEasyDataset/Unseen/Frame/* $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestEasyDataset/combine/Frame/') os.system('mv $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestEasyDataset/Seen/GT/* $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestEasyDataset/combine/GT/') os.system('mv $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestEasyDataset/Unseen/GT/* 
$DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestEasyDataset/combine/GT/') os.makedirs(os.path.join(os.getenv('DATASET_PATH'), 'SUN-SEG/SUN-SEG-Annotation', 'TestHardDataset', 'combine/Frame'), exist_ok=True) os.makedirs(os.path.join(os.getenv('DATASET_PATH'), 'SUN-SEG/SUN-SEG-Annotation', 'TestHardDataset', 'combine/GT'), exist_ok=True) os.system('mv $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestHardDataset/Seen/Frame/* $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestHardDataset/combine/Frame/') os.system('mv $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestHardDataset/Unseen/Frame/* $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestHardDataset/combine/Frame/') os.system('mv $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestHardDataset/Seen/GT/* $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestHardDataset/combine/GT/') os.system('mv $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestHardDataset/Unseen/GT/* $DATASET_PATH/SUN-SEG/SUN-SEG-Annotation/TestHardDataset/combine/GT/') ================================================ FILE: trainers/Trainer.py ================================================ import torch import numpy as np import random import math import logging import time import os from utils.misc import reduce_dict, to_device, is_dist_avail_and_initialized import gc from utils.misc import SmoothedValue, MetricLogger from torch.nn.parallel import DistributedDataParallel as DDP import detectron2.utils.comm as comm import datetime import torch.distributed as dist from models import model_entrypoint from utils.misc import to_device __all__ = ['Trainer'] class Trainer: def __init__(self, configs): torch.autograd.set_detect_anomaly(False) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.benchmark = True seed = configs['model_schedule_seed'] random.seed(seed) np.random.seed(seed) torch.random.manual_seed(seed) with torch.cuda.device(self.device): torch.cuda.manual_seed(seed) # model and data create_model_schedule = 
model_entrypoint(configs['model']['name']) self.model, self.optimizer, self.scheduler, \ self.train_samplers, self.train_loaders, self.log_lr_group_name_to_idx, \ self.eval_function = create_model_schedule(configs, device=self.device,) self.register_metric_logger([f'lr_group_{haosen}' for haosen in list(self.log_lr_group_name_to_idx.keys())]) logging.debug(f'total number of parameters:{sum(p.numel() for p in self.model.parameters())}') logging.debug(f'total number of trainable parameters:{sum(p.numel() for p in self.model.parameters() if p.requires_grad)}') logging.debug(configs) self.eval_seed = configs['eval_seed'] self.out_dir = configs['out_dir'] self.ckpted_iters = configs['optim']['ckpted_iters'] # list[int] self.num_iterations = 0 self.num_samples = 0 assert self.train_samplers[0].start_idx == self.num_samples if comm.get_world_size() > 1: self.model = DDP(self.model, device_ids=[comm.get_local_rank()], find_unused_parameters=False, broadcast_buffers = False) random.seed(seed + comm.get_rank()) np.random.seed(seed + comm.get_rank()) torch.random.manual_seed(seed + comm.get_rank()) with torch.cuda.device(self.device): torch.cuda.manual_seed(seed + comm.get_rank()) if configs['initckpt']['path'] != '': self.load_ckpt(configs['initckpt']['path'], load_random=configs['initckpt']['load_random'], load_model=configs['initckpt']['load_model'], load_schedule=configs['initckpt']['load_schedule'], load_optimize=configs['initckpt']['load_optimizer']) self.save_ckpt() if configs['initckpt']['eval_init_ckpt']: self.evaluate() self.load_ckpt(os.path.join(self.iteration_dir, 'ckpt.pth.tar'), load_random=True, load_schedule=False, load_model=False, load_optimize=False,) def train(self): manual_stop_train = False for loader in self.train_loaders: for idx, batch_dict in enumerate(loader): if manual_stop_train: self.save_ckpt() self.model.train() meta_idxs = batch_dict.pop('meta_idxs') visualize = batch_dict.pop('visualize') batch_dict = to_device(batch_dict, self.device) 
batch_dict['visualize_paths'] = self.visualize_path(meta_idxs=meta_idxs, visualize=visualize) iteration_time = time.time() loss_dict_unscaled, loss_weight = self.model(batch_dict) loss = sum([loss_dict_unscaled[k] * loss_weight[k] for k in loss_weight.keys()]) assert math.isfinite(loss.item()), f"Loss is {loss.item()}, stopping training" loss.backward() self.optimizer.step() iteration_time = time.time() - iteration_time self.optimizer.zero_grad(set_to_none=True) self.scheduler.step() sample_idxs = comm.all_gather(meta_idxs) sample_idxs = [taylor for cardib in sample_idxs for taylor in cardib] self.num_samples += len(sample_idxs) self.num_iterations += 1 loss_dict_unscaled_item = {key: torch.tensor(value.detach().item(), device=self.device) for key, value in loss_dict_unscaled.items()} del loss, loss_dict_unscaled self._log(loss_dict_unscaled=loss_dict_unscaled_item, loss_weight=loss_weight, sample_idxs=sample_idxs, iteration_time=iteration_time) def save_ckpt(self): rng_state_dict = {comm.get_rank(): { 'cpu_rng_state': torch.get_rng_state(), 'gpu_rng_state': torch.cuda.get_rng_state(self.device), 'numpy_rng_state': np.random.get_state(), 'py_rng_state': random.getstate() }} rng_state_dict_by_rank = comm.gather(rng_state_dict, dst=0) if comm.is_main_process(): rng_state_dict_by_rank = {key : value for rs in rng_state_dict_by_rank for key,value in rs.items()} model_without_ddp = self.model.module if isinstance(self.model, DDP) else self.model checkpoint_dict = { 'model': model_without_ddp.state_dict(), 'optimizer': self.optimizer.state_dict(), 'scheduler': self.scheduler.state_dict(), 'num_samples': self.num_samples, 'num_iterations': self.num_iterations, 'rng_state_dict_by_rank': rng_state_dict_by_rank, 'metrics': {}, } os.makedirs(self.iteration_dir, exist_ok=True) torch.save(checkpoint_dict, os.path.join(self.iteration_dir, 'ckpt.pth.tar')) del checkpoint_dict if is_dist_avail_and_initialized(): dist.barrier() del rng_state_dict_by_rank @torch.no_grad() def 
evaluate(self): random.seed(self.eval_seed) np.random.seed(self.eval_seed) torch.random.manual_seed(self.eval_seed) with torch.cuda.device(self.device): torch.cuda.manual_seed(self.eval_seed) self.model.eval() eval_model = self.model.module if isinstance(self.model, DDP) else self.model ckpt_file = os.path.join(self.iteration_dir, 'ckpt.pth.tar') assert os.path.exists(ckpt_file), 'must save ckpt before evaluate' evaluate_metrics = self.eval_function(model = eval_model, output_dir = self.iteration_dir) if is_dist_avail_and_initialized(): dist.barrier() if comm.is_main_process(): checkpoint_dict = torch.load(ckpt_file, map_location='cpu') ckpt_metrics = checkpoint_dict['metrics'] to_update_metrics = {} for metric_key in evaluate_metrics.keys(): metric_value = evaluate_metrics[metric_key] if metric_key in ckpt_metrics: saved_value = ckpt_metrics[metric_key] if (metric_value - saved_value) > 1e-2: logging.error(f'{metric_key} different saved value') to_update_metrics[metric_key] = metric_value else: to_update_metrics[metric_key] = metric_value checkpoint_dict['metrics'] = evaluate_metrics metric_string = ' '.join([f'{key} : {value:.6f}' for key, value in evaluate_metrics.items()]) logging.debug(metric_string) torch.save(checkpoint_dict, ckpt_file) del checkpoint_dict if is_dist_avail_and_initialized(): dist.barrier() def load_ckpt(self, ckpt_path=None, load_schedule=False, load_optimize=False, load_model=False, load_random=False, ): assert os.path.exists(ckpt_path) model_without_ddp = self.model.module if isinstance(self.model, DDP) else self.model checkpoint = torch.load(ckpt_path, map_location='cpu') if load_model: model_without_ddp.load_state_dict(checkpoint['model'], strict=True) if load_optimize: self.optimizer.load_state_dict(checkpoint['optimizer']) self.scheduler.load_state_dict(checkpoint['scheduler']) if load_schedule: self.num_samples = checkpoint['num_samples'] self.num_iterations = checkpoint['num_iterations'] sampler = self.train_samplers[0] while 
(sampler.end_idx != None) and (self.num_samples > (sampler.end_idx - 1)): self.train_samplers.pop(0) self.train_loaders.pop(0) sampler = self.train_samplers[0] self.train_samplers[0].set_iter_first_sample_idx(self.num_samples) if load_random: rng_state_dict_by_rank = checkpoint['rng_state_dict_by_rank'] torch.set_rng_state(rng_state_dict_by_rank[comm.get_rank()]['cpu_rng_state']) torch.cuda.set_rng_state(rng_state_dict_by_rank[comm.get_rank()]['gpu_rng_state'], device=self.device) np.random.set_state(rng_state_dict_by_rank[comm.get_rank()]['numpy_rng_state']) random.setstate(rng_state_dict_by_rank[comm.get_rank()]['py_rng_state']) del checkpoint def _log(self, loss_dict_unscaled, loss_weight, sample_idxs, iteration_time,): loss_dict_unscaled_reduced = reduce_dict(loss_dict_unscaled) loss_value = sum([loss_dict_unscaled_reduced[key] * loss_weight[key] for key in loss_weight.keys()]) if comm.is_main_process(): for idx, sp_idx in enumerate(sample_idxs): pass logger_updates = {} for log_lr_group_name, log_lr_group_idx in self.log_lr_group_name_to_idx.items(): if log_lr_group_idx is None: logger_updates[f'lr_group_{log_lr_group_name}'] = 0 else: logger_updates[f'lr_group_{log_lr_group_name}'] = self.optimizer.param_groups[log_lr_group_idx]["lr"] logger_updates.update(loss_dict_unscaled_reduced) logger_updates.update({ 'loss_value': loss_value, 'iteration_time': iteration_time, }) self.metric_logger.update(**logger_updates) log_string = self.log_header(iteration_time, sample_idxs) + f'\n{str(self.metric_logger)}' wandb_log = self.metric_logger.to_dict() logging.debug(log_string) if is_dist_avail_and_initialized(): dist.barrier() if type(self.ckpted_iters) == int: do_ckpt = (self.num_iterations % self.ckpted_iters) == 0 elif type(self.ckpted_iters) == list: do_ckpt = self.num_iterations in self.ckpted_iters else: raise ValueError() if (self.num_iterations % 2000 == 0) or do_ckpt: gc.collect() torch.cuda.empty_cache() if do_ckpt: try: self.save_ckpt() self.evaluate() 
self.load_ckpt(os.path.join(self.iteration_dir, 'ckpt.pth.tar'), load_random=True, load_schedule=False, load_model=False, load_optimize=False,) except: if comm.is_main_process(): logging.error(f'Iteration {self.num_iterations} evaluate error') if is_dist_avail_and_initialized(): dist.barrier() @property def device(self): return torch.device(comm.get_local_rank()) @property def iteration_dir(self): return os.path.join(self.out_dir, f'epc[{self.epoch[-1]}]_iter[{self.num_iterations}]_sap[{self.num_samples}]') @property def epoch(self): dataset_length = len(self.train_loaders[0].dataset) epoch = self.num_samples / dataset_length int_part, dec_part = f'{epoch:.2f}'.split('.') return epoch, f'{int_part}_{dec_part}' def log_header(self, iteration_time, sample_idxs): one_epoch_iterations = len(self.train_loaders[0].dataset) // len(sample_idxs) eta = datetime.timedelta(seconds=one_epoch_iterations * iteration_time) return f'Epoch_ETA: [{str(eta)}] Epoch:[{self.epoch[0]:.2f}] Iter: [{(self.num_iterations):06d}] Sample: [{self.num_samples:06d}]' def visualize_path(self, meta_idxs, visualize): return [os.path.join(self.iteration_dir, 'visualize_model', f'train_meta_{str(meta_idx)}') if vis else None for (meta_idx, vis) in zip(meta_idxs, visualize)] def register_metric_logger(self, log_keys): if comm.is_main_process(): if not hasattr(self, 'metric_logger'): self.metric_logger = MetricLogger(delimiter='\t') for haosen in log_keys: if 'lr_group' in haosen: self.metric_logger.add_meter(haosen, SmoothedValue(window_size=1,fmt='{value:.8f}', handler='value')) elif haosen == 'iteration_time': self.metric_logger.add_meter(haosen, SmoothedValue(window_size=1,fmt='{value:2f}',handler='value')) else: raise ValueError() ================================================ FILE: trainers/__init__.py ================================================ from .Trainer import Trainer task_to_trainer = { 'VIS': Trainer, } ================================================ FILE: utils/__init__.py 
================================================ from . import misc ================================================ FILE: utils/misc.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved """ Misc functions, including distributed helpers. Mostly copy-paste from torchvision references. change SmoothedValue and MetricLogger, interpolate modify nested_tensor_from_tensor_list add nested_tensor_from_videos_list """ import os import subprocess import time from collections import defaultdict, deque import datetime import pickle from packaging import version from typing import Optional, List import torch import torch.distributed as dist from torch import Tensor # needed due to empty tensor bug in pytorch and torchvision 0.5 import torchvision if version.parse(torchvision.__version__) < version.parse('0.7'): from torchvision.ops import _new_empty_tensor from torchvision.ops.misc import _output_size class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a window or the global series average. """ def __init__(self, window_size=1, fmt='{value:.6f}', handler='value'): if fmt is None: fmt = "{median:.4f} ({global_avg:.4f})" self.deque = deque(maxlen=window_size) self.total = 0.0 self.count = 0 self.fmt = fmt self.handler = handler def update(self, value, n=1): self.deque.append(value) self.count += n self.total += value * n def synchronize_between_processes(self): """ Warning: does not synchronize the deque! 
""" if not is_dist_avail_and_initialized(): return t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') dist.barrier() dist.all_reduce(t) t = t.tolist() self.count = int(t[0]) self.total = t[1] @property def median(self): d = torch.tensor(list(self.deque)) return d.median().item() @property def avg(self): d = torch.tensor(list(self.deque), dtype=torch.float32) return d.mean().item() @property def global_avg(self): return self.total / self.count @property def max(self): return max(self.deque) @property def value(self): return self.deque[-1] @property def wandb_log_property(self): if self.handler == 'value': return self.value elif self.handler == 'avg': return self.avg else: raise NotImplementedError('other log property not implemented!') def __str__(self): return self.fmt.format( median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value) def all_gather(data): """ Run all_gather on arbitrary picklable data (not necessarily tensors) Args: data: any picklable object Returns: list[data]: list of data gathered from each rank """ world_size = get_world_size() if world_size == 1: return [data] # serialized to a Tensor buffer = pickle.dumps(data) storage = torch.ByteStorage.from_buffer(buffer) tensor = torch.ByteTensor(storage).to("cuda") # obtain Tensor size of each rank local_size = torch.tensor([tensor.numel()], device="cuda") size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] dist.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) # receiving Tensor from all ranks # we pad the tensor because torch all_gather does not support # gathering tensors of different shapes tensor_list = [] for _ in size_list: tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) if local_size != max_size: padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") tensor = torch.cat((tensor, padding), 
dim=0) dist.all_gather(tensor_list, tensor) data_list = [] for size, tensor in zip(size_list, tensor_list): buffer = tensor.cpu().numpy().tobytes()[:size] data_list.append(pickle.loads(buffer)) return data_list def all_gather_cpu(data): world_size = get_world_size() if world_size == 1: return [data] # serialized to a Tensor buffer = pickle.dumps(data) storage = torch.ByteStorage.from_buffer(buffer) tensor = torch.ByteTensor(storage) # obtain Tensor size of each rank local_size = torch.tensor([tensor.numel()]) size_list = [torch.tensor([0]) for _ in range(world_size)] dist.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list) # receiving Tensor from all ranks # we pad the tensor because torch all_gather does not support # gathering tensors of different shapes tensor_list = [] for _ in size_list: tensor_list.append(torch.empty((max_size,), dtype=torch.uint8)) if local_size != max_size: padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8) tensor = torch.cat((tensor, padding), dim=0) dist.all_gather(tensor_list, tensor) data_list = [] for size, tensor in zip(size_list, tensor_list): buffer = tensor.cpu().numpy().tobytes()[:size] data_list.append(pickle.loads(buffer)) return data_list def reduce_dict(input_dict, average=True): """ Args: input_dict (dict): all the values will be reduced average (bool): whether to do average or sum Reduce the values in the dictionary from all processes so that all processes have the averaged results. Returns a dict with the same fields as input_dict, after reduction. 
""" world_size = get_world_size() if world_size < 2: return input_dict with torch.no_grad(): names = [] values = [] # sort the keys so that they are consistent across processes for k in sorted(input_dict.keys()): names.append(k) values.append(input_dict[k]) values = torch.stack(values, dim=0) dist.all_reduce(values) if average: values /= world_size reduced_dict = {k: v for k, v in zip(names, values)} return reduced_dict def reduce_scalar(input, average=True): """ Args: input: all the values will be reduced average (bool): whether to do average or sum Reduce the values in the dictionary from all processes so that all processes have the averaged results. Returns a dict with the same fields as input, after reduction. """ world_size = get_world_size() if world_size < 2: return input.item() with torch.no_grad(): dist.all_reduce(input) if average: input /= world_size return input.item() class MetricLogger(object): def __init__(self, delimiter="\t"): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter def update(self, **kwargs): for k, v in kwargs.items(): if isinstance(v, torch.Tensor): v = v.item() assert isinstance(v, (float, int)) self.meters[k].update(v) def __getattr__(self, attr): if attr in self.meters: return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): loss_str.append( "{}: {}".format(name, str(meter)) ) return self.delimiter.join(loss_str) def to_dict(self): log_dict = {} for name in self.meters.keys(): log_dict[name] = self.meters[name].wandb_log_property return log_dict def synchronize_between_processes(self): for meter in self.meters.values(): meter.synchronize_between_processes() def add_meter(self, name, meter): self.meters[name] = meter def log_every(self, iterable, print_freq, header=None): i = 0 if not header: header = '' start_time = time.time() 
end = time.time() iter_time = SmoothedValue(fmt='{avg:.4f}') data_time = SmoothedValue(fmt='{avg:.4f}') space_fmt = ':' + str(len(str(len(iterable)))) + 'd' if torch.cuda.is_available(): log_msg = self.delimiter.join([ header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}', 'max mem: {memory:.0f}' ]) else: log_msg = self.delimiter.join([ header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}' ]) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) yield obj iter_time.update(time.time() - end) if i % print_freq == 0 or i == len(iterable) - 1: eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): print(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), memory=torch.cuda.max_memory_allocated() / MB)) else: print(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time))) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('{} Total time: {} ({:.4f} s / it)'.format( header, total_time_str, total_time / len(iterable))) def get_sha(): cwd = os.path.dirname(os.path.abspath(__file__)) def _run(command): return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() sha = 'N/A' diff = "clean" branch = 'N/A' try: sha = _run(['git', 'rev-parse', 'HEAD']) subprocess.check_output(['git', 'diff'], cwd=cwd) diff = _run(['git', 'diff-index', 'HEAD']) diff = "has uncommited changes" if diff else "clean" branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) except Exception: pass message = f"sha: {sha}, status: {diff}, branch: {branch}" return message def collate_fn(batch): batch = list(zip(*batch)) batch[0] = nested_tensor_from_tensor_list(batch[0]) return tuple(batch) def 
_max_by_axis(the_list): # type: (List[List[int]]) -> List[int] maxes = the_list[0] for sublist in the_list[1:]: for index, item in enumerate(sublist): maxes[index] = max(maxes[index], item) return maxes class NestedTensor(object): def __init__(self, tensors, mask: Optional[Tensor]): self.tensors = tensors self.mask = mask def to(self, device): # type: (Device) -> NestedTensor # noqa cast_tensor = self.tensors.to(device) mask = self.mask if mask is not None: assert mask is not None cast_mask = mask.to(device) else: cast_mask = None return NestedTensor(cast_tensor, cast_mask) def decompose(self): return self.tensors, self.mask def __repr__(self): return str(self.tensors) # region change this function # def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): # # TODO make this more general # if tensor_list[0].ndim == 3: # if torchvision._is_tracing(): # # nested_tensor_from_tensor_list() does not export well to ONNX # # call _onnx_nested_tensor_from_tensor_list() instead # return _onnx_nested_tensor_from_tensor_list(tensor_list) # # TODO make it support different-sized images # max_size = _max_by_axis([list(img.shape) for img in tensor_list]) # # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) # batch_shape = [len(tensor_list)] + max_size # b, c, h, w = batch_shape # dtype = tensor_list[0].dtype # device = tensor_list[0].device # tensor = torch.zeros(batch_shape, dtype=dtype, device=device) # mask = torch.ones((b, h, w), dtype=torch.bool, device=device) # for img, pad_img, m in zip(tensor_list, tensor, mask): # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) # m[: img.shape[1], :img.shape[2]] = False # else: # raise ValueError('not supported') # return NestedTensor(tensor, mask) # endregion def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): """ This function receives a list of image tensors and returns a NestedTensor of the padded images, along with their padding masks (true for padding areas, false 
otherwise). """ max_size = _max_by_axis([list(img.shape) for img in tensor_list]) batch_shape = [len(tensor_list)] + max_size b, c, h, w = batch_shape dtype = tensor_list[0].dtype device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) mask = torch.ones((b, h, w), dtype=torch.bool, device=device) for img, pad_img, m in zip(tensor_list, tensor, mask): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) m[: img.shape[1], :img.shape[2]] = False return NestedTensor(tensor, mask) def nested_tensor_from_tensor_list_visiblility(tensor_list: List[Tensor], size_divisibility=1, split=True): """ This function receives a list of image tensors and returns a NestedTensor of the padded images, along with their padding masks (true for padding areas, false otherwise). """ # TODO make this more general # if image tensor is stacked as [T*3, H, W], then use split if split: tensor_list = [tensor.split(3,dim=0) for tensor in tensor_list] tensor_list = [item for sublist in tensor_list for item in sublist] # list[tensor], length = batch_size x time if tensor_list[0].ndim == 3: # TODO make it support different-sized images max_size = _max_by_axis([list(img.shape) for img in tensor_list]) if size_divisibility > 1: # so that the mask dowmsample can be matched stride = size_divisibility # the last two dims are [H, W], both subject to divisibility requirement max_size[-2] = (max_size[-2] + (stride - 1)) // stride * stride max_size[-1] = (max_size[-1] + (stride - 1)) // stride * stride # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) batch_shape = [len(tensor_list)] + max_size b, c, h, w = batch_shape dtype = tensor_list[0].dtype device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) mask = torch.ones((b, h, w), dtype=torch.bool, device=device) for img, pad_img, m in zip(tensor_list, tensor, mask): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) m[: 
img.shape[1], :img.shape[2]] = False # valid locations else: raise ValueError('not supported') return NestedTensor(tensor, mask) def nested_tensor_from_tensor_list_with_stride(tensor_list: List[Tensor], max_stride=16): """ This function receives a list of image tensors and returns a NestedTensor of the padded images, along with their padding masks (true for padding areas, false otherwise). """ max_size = _max_by_axis([list(img.shape) for img in tensor_list]) *_, h,w = max_size if w % max_stride != 0: w += max_stride - (w % max_stride) if h % max_stride != 0: h += max_stride - (h % max_stride) max_size[-1] = w max_size[-2] = h batch_shape = [len(tensor_list)] + max_size b, c, h, w = batch_shape dtype = tensor_list[0].dtype device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) mask = torch.ones((b, h, w), dtype=torch.bool, device=device) for img, pad_img, m in zip(tensor_list, tensor, mask): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) m[: img.shape[1], :img.shape[2]] = False return NestedTensor(tensor, mask) def _get_nearest_scale_number(num, scale): res = num % scale if res > 0: return num + (scale - res) else: return num def nested_tensor_from_videos_list(videos_list: List[Tensor]): """ This function receives a list of videos (each of shape [T, C, H, W]) and returns a NestedTensor of the padded videos (shape [T, B, C, PH, PW], along with their padding masks (true for padding areas, false otherwise, of shape [T, B, PH, PW]. 
""" max_size = _max_by_axis([list(img.shape) for img in videos_list]) # max_size[2] = _get_nearest_scale_number(max_size[2], mask_decoder_max_stride) # max_size[3] = _get_nearest_scale_number(max_size[3], mask_decoder_max_stride) padded_batch_shape = [len(videos_list)] + max_size b, t, c, h, w = padded_batch_shape dtype = videos_list[0].dtype device = videos_list[0].device padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device) videos_pad_masks = torch.ones((b, t, h, w), dtype=torch.bool, device=device) for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, padded_videos, videos_pad_masks): pad_vid_frames[:vid_frames.shape[0], :, :vid_frames.shape[2], :vid_frames.shape[3]].copy_(vid_frames) vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames.shape[3]] = False # transpose the temporal and batch dims and create a NestedTensor: return NestedTensor(padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1)) # _onnx_nested_tensor_from_tensor_list() is an implementation of # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """Pad 3-D tensors to a common size using only ops ONNX tracing supports."""
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    (``print(..., force=True)`` still prints everywhere).
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    """``torch.save`` on rank 0 only."""
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Initialize the NCCL process group from torchrun or SLURM environment variables.

    Sets ``args.rank``, ``args.gpu``, ``args.distributed`` and ``args.dist_backend``
    (and ``args.world_size`` under torchrun); reads ``args.dist_url``.
    NOTE(review): the SLURM branch does not set ``args.world_size`` — it must
    already exist on ``args``; confirm against the argument parser.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Returns one 0-d tensor (percentage) per k; for an empty ``target`` a single
    zero tensor is returned.
    """
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: ``correct`` follows the transposed (non-contiguous)
        # layout of ``pred``, so flattening more than one row cannot be a view.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty
    batch sizes. This will eventually be supported natively by PyTorch, and this
    class can go away.

    The version gate is the same one the module-level imports use, so the
    ``_output_size``/``_new_empty_tensor`` helpers are only referenced when they
    were actually imported.
    """
    if version.parse(torchvision.__version__) < version.parse('0.7'):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    # Newer torchvision may no longer provide ops.misc.interpolate; torch's own
    # implementation handles empty batches in those versions.
    misc_interpolate = getattr(torchvision.ops.misc, 'interpolate', None)
    if misc_interpolate is None:
        return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
    return misc_interpolate(input, size, scale_factor, mode, align_corners)


def inverse_sigmoid(x, eps=1e-5):
    """log(x / (1 - x)) with x clamped to [0, 1] and both terms floored at ``eps``."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1/x2)


def to_device(sample, device):
    """Recursively move tensors in nested dicts/lists/tuples to ``device``.

    Tuples come back as lists; non-tensor leaves are returned unchanged.
    """
    if isinstance(sample, torch.Tensor):
        sample = sample.to(device)
    elif isinstance(sample, tuple) or isinstance(sample, list):
        sample = [to_device(s, device) for s in sample]
    elif isinstance(sample, dict):
        sample = {k: to_device(v, device) for k, v in sample.items()}
    return sample


def nested_tensor_from_videos_list_with_stride(videos_list, max_stride):
    """
    This function receives a list of videos (each of shape [T, C, H, W]) and returns a NestedTensor of the padded
    videos (shape [B, T, C, PH, PW], along with their padding masks (true for padding areas, false otherwise, of shape
    [B, T, PH, PW].

    ``max_stride`` is a ``(temporal_stride, spatial_stride)`` pair; the padded T is
    rounded up to a multiple of the first, and H and W to a multiple of the second.
    """
    temporal_max_stride, spatial_max_stride = max_stride
    max_size = _max_by_axis([list(video.shape) for video in videos_list])  # [t, c, h, w]
    t, *_, h, w = max_size
    if t % temporal_max_stride != 0:
        t += temporal_max_stride - (t % temporal_max_stride)
    if w % spatial_max_stride != 0:
        w += spatial_max_stride - (w % spatial_max_stride)
    if h % spatial_max_stride != 0:
        h += spatial_max_stride - (h % spatial_max_stride)
    max_size[0] = t
    max_size[2] = h
    max_size[3] = w

    padded_batch_shape = [len(videos_list)] + max_size  # b t 3 hp wp
    b, t, c, h, w = padded_batch_shape
    dtype = videos_list[0].dtype
    device = videos_list[0].device
    padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device)  # b t c hp wp
    videos_pad_masks = torch.ones((b, t, h, w), dtype=torch.bool, device=device)
    for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, padded_videos, videos_pad_masks):
        pad_vid_frames[:vid_frames.shape[0], :, :vid_frames.shape[2], :vid_frames.shape[3]].copy_(vid_frames)
        vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames.shape[3]] = False
    # b t c hp wp
    return NestedTensor(padded_videos, videos_pad_masks)